├── mltc ├── dicts │ ├── userdict.dict │ └── stopwords.txt ├── outputs │ ├── log │ │ └── __init__.py │ ├── feature │ │ └── __init__.py │ ├── figure │ │ └── __init__.py │ ├── result │ │ └── __init__.py │ ├── checkpoints │ │ └── __init__.py │ └── embedding │ │ └── __init__.py ├── pretrain │ ├── __init__.py │ └── bert │ │ ├── __init__.py │ │ └── base-uncased │ │ └── __init__.py ├── models │ ├── textrcnn.py │ ├── textcnn.py │ ├── model.py │ └── bert_for_multi_label.py ├── run.sh ├── scheme │ └── error.py ├── train │ ├── losses.py │ ├── trainer.py │ └── metrics.py ├── pipeline.yml ├── preprocessors │ ├── processor.py │ ├── tests │ │ └── test_chinese.py │ ├── english.py │ └── chinese.py ├── postprocessors │ ├── processor.py │ ├── nopretrain.py │ └── bert.py ├── tokenizers │ └── tokenizer.py ├── configs │ └── basic_config.py ├── callback │ ├── progressbar.py │ ├── training_monitor.py │ └── model_checkpoint.py ├── predict │ └── predictor.py ├── dataio │ └── task_data.py ├── dataset │ └── labels.txt ├── engineerings │ └── engineering.py ├── utils │ └── utils.py └── main.py ├── pytest.ini ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /mltc/dicts/userdict.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/log/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/feature/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/figure/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/result/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/pretrain/bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/checkpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/outputs/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/pretrain/bert/base-uncased/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mltc/models/textrcnn.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class TextRCNN: 4 | pass -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = tests.py test_*.py tests_*.py 
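Note on `models/textrcnn.py` above: it is currently an empty stub (`class TextRCNN: pass`), even though `models/model.py` and the `rcnn` block in `configs/basic_config.py` already expect a working class. Below is a minimal sketch of what it could grow into, following the README's later description of TextRCNN (BiLSTM output concatenated with the token embedding, then max-pooled). It assumes the caller attaches `embedding_size`, `vocab_size` and `num_labels` to the config object (as `models/model.py` already does for TextCNN); it is an illustrative sketch, not the repository's own implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextRCNN(nn.Module):
    """Sketch only: assumes config carries vocab_size, embedding_size,
    rnn_hidden, num_layers, dropout and num_labels."""

    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        self.lstm = nn.LSTM(config.embedding_size, config.rnn_hidden,
                            num_layers=config.num_layers, batch_first=True,
                            bidirectional=True, dropout=config.dropout)
        # RCNN head: BiLSTM output concatenated with the embedding, max-pooled.
        self.fc = nn.Linear(config.rnn_hidden * 2 + config.embedding_size,
                            config.num_labels)

    def forward(self, *inputs):
        input_ids = inputs[0]
        emb = self.embedding(input_ids.long())           # [B, L, E]
        rnn_out, _ = self.lstm(emb)                      # [B, L, 2H]
        out = torch.cat((rnn_out, emb), dim=2)           # [B, L, 2H + E]
        out = F.relu(out).permute(0, 2, 1)               # [B, 2H + E, L]
        out = F.max_pool1d(out, out.size(2)).squeeze(2)  # [B, 2H + E]
        return self.fc(out)
```

The concatenation width (`2 * rnn_hidden + embedding_size`) is what the final linear layer must match; any other size will fail at the matrix multiply.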
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | pnlp 3 | demjson 4 | pytorch-transformers==1.1.0 5 | matplotlib 6 | jieba 7 | pandas 8 | scikit-learn -------------------------------------------------------------------------------- /mltc/run.sh: -------------------------------------------------------------------------------- 1 | python main.py --do_data 2 | python main.py --do_train --save_best --train_batch_size 16 --eval_batch_size 16 --learning_rate 1e-2 3 | python main.py --do_test -------------------------------------------------------------------------------- /mltc/scheme/error.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | class Error(Exception): 4 | pass 5 | 6 | 7 | @dataclass 8 | class PipelineReadError(Error): 9 | code = 10000 10 | desc = "Please check your pipeline." 11 | 12 | 13 | @dataclass 14 | class PipelineFieldNotDefinedError(Error): 15 | code = 10001 16 | desc = "You have used an undefined field in the pipeline config file." 17 | 18 | 19 | @dataclass 20 | class ModelNotDefinedError(Error): 21 | code = 20000 22 | desc = "You have used an invalid model." -------------------------------------------------------------------------------- /mltc/train/losses.py: -------------------------------------------------------------------------------- 1 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 2 | 3 | 4 | class CrossEntropy(object): 5 | def __init__(self): 6 | self.loss_f = CrossEntropyLoss() 7 | 8 | def __call__(self, output, target): 9 | loss = self.loss_f(input=output, target=target) 10 | return loss 11 | 12 | class BCEWithLogLoss(object): 13 | def __init__(self): 14 | self.loss_fn = BCEWithLogitsLoss() 15 | 16 | def __call__(self, output, target): 17 | output = output.float() 18 | target = target.float() 19 | loss = self.loss_fn(input=output, target=target) 20 | return loss 21 | -------------------------------------------------------------------------------- /mltc/pipeline.yml: -------------------------------------------------------------------------------- 1 | # pipeline: 2 | # preprocessor: "ChineseChar" 3 | # pretrain: "Bert" 4 | # postprocessor: "Bert" 5 | # classifier: "BertFC" 6 | 7 | # pipeline: 8 | # preprocessor: "ChineseChar" 9 | # pretrain: "Bert" 10 | # postprocessor: "Bert" 11 | # classifier: "BertCNN" 12 | 13 | pipeline: 14 | preprocessor: "ChineseChar" 15 | pretrain: "Bert" 16 | postprocessor: "Bert" 17 | classifier: "BertRCNN" 18 | 19 | # pipeline: 20 | # preprocessor: "ChineseChar" 21 | # pretrain: "Bert" 22 | # postprocessor: "Bert" 23 | # classifier: "BertDPCNN" 24 | 25 | # pipeline: 26 | # preprocessor: "ChineseChar" 27 | # pretrain: "Nopretrain" 28 | # postprocessor: "Nopretrain" 29 | # classifier: "TextCNN" 30 | 31 | # pipeline: 32 | # preprocessor: "ChineseWord" 33 | # pretrain: "Nopretrain" 34 | # postprocessor: "Nopretrain" 35 | # classifier: "TextCNN" 36 | 37 | -------------------------------------------------------------------------------- /mltc/preprocessors/processor.py: -------------------------------------------------------------------------------- 1 | from preprocessors.english import EnglishProcessor 2 | from preprocessors.chinese import ChineseCharProcessor, ChineseWordProcessor 3 | from scheme.error import PipelineFieldNotDefinedError 4 | 5 | 6 | class Preprocessor: 7 | 8 | def __init__(self, choose: str): 9 | 
self.choose = choose 10 | 11 | def __call__(self, stopwords_path: str, userdict_path: str): 12 | if self.choose == "English": 13 | processor = EnglishProcessor(stopwords_path=stopwords_path) 14 | 15 | elif self.choose == "ChineseChar": 16 | processor = ChineseCharProcessor(stopwords_path=stopwords_path) 17 | 18 | elif self.choose == "ChineseWord": 19 | processor = ChineseWordProcessor( 20 | stopwords_path=stopwords_path, userdict_path=userdict_path) 21 | 22 | else: 23 | raise PipelineFieldNotDefinedError 24 | 25 | return processor 26 | -------------------------------------------------------------------------------- /mltc/postprocessors/processor.py: -------------------------------------------------------------------------------- 1 | from postprocessors.bert import BertProcessor 2 | from postprocessors.nopretrain import NopretrainProcessor 3 | from scheme.error import PipelineFieldNotDefinedError 4 | from configs.basic_config import config 5 | 6 | class Postprocessor: 7 | 8 | def __init__(self, choose_post: str): 9 | self.choose_post = choose_post 10 | 11 | def __call__(self, do_lower_case: bool): 12 | if self.choose_post == "Bert": 13 | vocab_path = config.bert_vocab_path 14 | processor = BertProcessor( 15 | vocab_path=vocab_path, do_lower_case=do_lower_case) 16 | elif self.choose_post == "Nopretrain": 17 | vocab_path = config.nopretrain_vocab_path 18 | processor = NopretrainProcessor( 19 | vocab_path=vocab_path) 20 | 21 | else: 22 | raise PipelineFieldNotDefinedError 23 | 24 | return processor 25 | 26 | 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
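The `Preprocessor` and `Postprocessor` classes above are thin dispatchers keyed by the strings used in `pipeline.yml`. A minimal sketch of how they are expected to be wired together; the actual wiring lives in `main.py`, which is not shown in this section, and the YAML loading used here is an assumption (the repository may load the config through another helper):

```python
import yaml  # assumption: any YAML loader works here

from configs.basic_config import config
from preprocessors.processor import Preprocessor
from postprocessors.processor import Postprocessor

# Read the active (uncommented) pipeline block from pipeline.yml.
with open("pipeline.yml", encoding="utf-8") as f:
    pipe = yaml.safe_load(f)["pipeline"]

# Both dispatchers are two-step: construct with the chosen name,
# then call the instance to obtain the concrete processor.
preprocessor = Preprocessor(pipe["preprocessor"])(
    stopwords_path=config.stopwords_path,
    userdict_path=config.userdict_path)
postprocessor = Postprocessor(pipe["postprocessor"])(do_lower_case=True)
```

An unrecognized name in either field raises `PipelineFieldNotDefinedError`, so the strings in `pipeline.yml` must match the branches above exactly.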
22 | -------------------------------------------------------------------------------- /mltc/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | 4 | def load_vocab(vocab_file): 5 | """Loads a vocabulary file into a dictionary.""" 6 | vocab = collections.OrderedDict() 7 | with open(vocab_file, "r", encoding="utf-8") as reader: 8 | tokens = reader.readlines() 9 | for index, token in enumerate(tokens): 10 | token = token.rstrip('\n') 11 | vocab[token] = index 12 | return vocab 13 | 14 | 15 | class Tokenizer: 16 | 17 | def __init__(self, vocab_path: str): 18 | if not os.path.isfile(vocab_path): 19 | raise ValueError( 20 | "Can't find a vocabulary file at path '{}'.".format(vocab_path)) 21 | self.vocab = load_vocab(vocab_path) 22 | self.ids_to_tokens = collections.OrderedDict( 23 | [(ids, tok) for tok, ids in self.vocab.items()]) 24 | self.unk_token = "[UNK]" 25 | self.pad_token = "[PAD]" 26 | self.bos_token = "[BOS]" 27 | self.eos_token = "[EOS]" 28 | 29 | @property 30 | def vocab_size(self): 31 | return len(self.vocab) 32 | 33 | def _convert_token_to_id(self, token): 34 | """ Converts a token (str/unicode) in an id using the vocab. """ 35 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 36 | 37 | def tokenize(self, text: str): 38 | return text.split(" ") 39 | 40 | 41 | def convert_tokens_to_ids(self, token_list: list): 42 | ids = [] 43 | for token in token_list: 44 | ids.append(self._convert_token_to_id(token)) 45 | return ids 46 | 47 | 48 | -------------------------------------------------------------------------------- /mltc/configs/basic_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from utils.utils import AttrDict 3 | 4 | BASE_DIR = Path('.') 5 | 6 | dct = { 7 | 'raw_data_path': BASE_DIR / 'dataset/train.csv', 8 | 'test_path': BASE_DIR / 'dataset/test.csv', 9 | 'nopretrain_vocab_path': BASE_DIR / 'dataset/vocab.txt', 10 | 11 | 'data_dir': BASE_DIR / 'dataset', 12 | 'log_dir': BASE_DIR / 'outputs/log', 13 | 'writer_dir': BASE_DIR / "outputs/TSboard", 14 | 'figure_dir': BASE_DIR / "outputs/figure", 15 | 'checkpoint_dir': BASE_DIR / "outputs/checkpoints", 16 | 'cache_dir': BASE_DIR / 'models/', 17 | 'result': BASE_DIR / "outputs/result", 18 | 19 | 'bert_vocab_path': BASE_DIR / 'pretrain/bert/base-uncased/bert_vocab.txt', 20 | 'bert_config_file': BASE_DIR / 'pretrain/bert/base-uncased/config.json', 21 | 'bert_model_dir': BASE_DIR / 'pretrain/bert/base-uncased', 22 | 23 | 'word2vec_model_dir': BASE_DIR / 'pretrain/word2vec/word2vec.vec', 24 | 'word2vec_vocab_path': BASE_DIR / 'pretrain/word2vec/vocab.txt', 25 | 26 | 'stopwords_path': BASE_DIR / 'dicts/stopwords.txt', 27 | 'userdict_path': BASE_DIR / 'dicts/userdict.dict', 28 | 29 | 'embedding_size': 300, 30 | 'vocab_size': 0, 31 | 'dropout': 0.5, # 0.5 32 | 33 | 'cnn': { 34 | 'num_filters': 256, # 256 35 | 'filter_sizes': (2, 3, 4) 36 | }, 37 | 38 | 'rcnn': { 39 | 'rnn_hidden': 256, 40 | 'num_layers': 2, 41 | 'kernel_size': 256, 42 | 'dropout': 0.5 43 | }, 44 | 45 | 'dpcnn': { 46 | "num_filters": 256, # 256 47 | "kernel_size": 256 48 | } 49 | } 50 | 51 | config = AttrDict(dct) 52 | 53 | 54 | if __name__ == '__main__': 55 | print(BASE_DIR) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | notebook.ipynb 2 | 3 | # Byte-compiled / 
optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /mltc/callback/progressbar.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class ProgressBar(object): 4 | 5 | def __init__(self, n_total,width=30): 6 | self.width = width 7 | self.n_total = n_total 8 | self.start_time = time.time() 9 | 10 | def batch_step(self, step, info, bar_type='Training'): 11 | now = time.time() 12 | current = step + 1 13 | recv_per = current / self.n_total 14 | bar = f'[{bar_type}] {current}/{self.n_total} [' 15 | if recv_per >= 1: 16 | recv_per = 1 17 | prog_width = int(self.width * recv_per) 18 | if prog_width > 0: 19 | bar += '=' * (prog_width - 1) 20 | if current< self.n_total: 21 | bar += ">" 22 | else: 23 | bar += '=' 24 | bar += '.' 
* (self.width - prog_width) 25 | bar += ']' 26 | show_bar = f"\r{bar}" 27 | time_per_unit = (now - self.start_time) / current 28 | if current < self.n_total: 29 | eta = time_per_unit * (self.n_total - current) 30 | if eta > 3600: 31 | eta_format = ('%d:%02d:%02d' % 32 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 33 | elif eta > 60: 34 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 35 | else: 36 | eta_format = '%ds' % eta 37 | time_info = f' - ETA: {eta_format}' 38 | else: 39 | if time_per_unit >= 1: 40 | time_info = f' {time_per_unit:.1f}s/step' 41 | elif time_per_unit >= 1e-3: 42 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 43 | else: 44 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 45 | 46 | show_bar += time_info 47 | if len(info) != 0: 48 | show_info = f'{show_bar} ' + \ 49 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()]) 50 | print(show_info, end='') 51 | else: 52 | print(show_bar, end='') 53 | -------------------------------------------------------------------------------- /mltc/predict/predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.utils import model_device 4 | from callback.progressbar import ProgressBar 5 | from sklearn.metrics import f1_score 6 | 7 | 8 | class Predictor(object): 9 | def __init__(self, 10 | model, 11 | logger, 12 | n_gpu 13 | ): 14 | self.model = model 15 | self.logger = logger 16 | self.model, self.device = model_device(n_gpu=n_gpu, model=self.model) 17 | 18 | def predict(self, data, thresh): 19 | pbar = ProgressBar(n_total=len(data)) 20 | all_logits = None 21 | # y_true = torch.LongTensor() 22 | y_true = None 23 | self.model.eval() 24 | with torch.no_grad(): 25 | for step, batch in enumerate(data): 26 | batch = tuple(t.to(self.device) for t in batch) 27 | input_ids, input_mask, segment_ids, label_ids = batch 28 | # y_true = torch.cat((y_true, label_ids), 0) 29 | if y_true is None: 30 | y_true = label_ids.detach().cpu().numpy() 31 | else: 32 | y_true = np.concatenate( 33 | [y_true, label_ids.detach().cpu().numpy()], axis=0) 34 | logits = self.model(input_ids, segment_ids, input_mask) 35 | logits = logits.sigmoid() 36 | if all_logits is None: 37 | all_logits = logits.detach().cpu().numpy() 38 | else: 39 | all_logits = np.concatenate( 40 | [all_logits, logits.detach().cpu().numpy()], axis=0) 41 | pbar.batch_step(step=step, info={}, bar_type='Testing') 42 | y_pred = (all_logits > thresh) * 1 43 | micro = f1_score(y_true, y_pred, average='micro') 44 | macro = f1_score(y_true, y_pred, average='macro') 45 | score = (micro + macro) / 2 46 | self.logger.info( 47 | "\nScore: micro {}, macro {} Average {}".format( 48 | micro, macro, score)) 49 | if 'cuda' in str(self.device): 50 | torch.cuda.empty_cache() 51 | return all_logits, y_pred 52 | -------------------------------------------------------------------------------- /mltc/callback/training_monitor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | import matplotlib.pyplot as plt 4 | from pnlp import piop 5 | 6 | 7 | plt.switch_backend('agg') 8 | 9 | 10 | class TrainingMonitor(): 11 | def __init__(self, file_dir, arch, add_test=False): 12 | ''' 13 | :param startAt: 重新开始训练的epoch点 14 | ''' 15 | if isinstance(file_dir, Path): 16 | pass 17 | else: 18 | file_dir = Path(file_dir) 19 | file_dir.mkdir(parents=True, exist_ok=True) 20 | 21 | self.arch = arch 22 | self.file_dir = file_dir 23 | self.H = {} 24 
| self.add_test = add_test 25 | self.json_path = file_dir / (arch + "_training_monitor.json") 26 | 27 | def reset(self,start_at): 28 | if start_at > 0: 29 | if self.json_path is not None: 30 | if self.json_path.exists(): 31 | self.H = piop.read_json(self.json_path) 32 | for k in self.H.keys(): 33 | self.H[k] = self.H[k][:start_at] 34 | 35 | def epoch_step(self, logs={}): 36 | for (k, v) in logs.items(): 37 | l = self.H.get(k, []) 38 | # np.float32会报错 39 | if not isinstance(v, np.float): 40 | v = round(float(v), 4) 41 | l.append(v) 42 | self.H[k] = l 43 | 44 | # 写入文件 45 | if self.json_path is not None: 46 | piop.write_json(self.json_path, self.H) 47 | 48 | # 保存train图像 49 | if len(self.H["loss"]) == 1: 50 | self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()} 51 | 52 | if len(self.H["loss"]) > 1: 53 | # 指标变化 54 | # 曲线 55 | # 需要成对出现 56 | keys = [key for key, _ in self.H.items() if '_' not in key] 57 | for key in keys: 58 | N = np.arange(0, len(self.H[key])) 59 | plt.style.use("ggplot") 60 | plt.figure() 61 | plt.plot(N, self.H[key], label=f"train_{key}") 62 | plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}") 63 | if self.add_test: 64 | plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}") 65 | plt.legend() 66 | plt.xlabel("Epoch #") 67 | plt.ylabel(key) 68 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]") 69 | plt.savefig(str(self.paths[key])) 70 | plt.close() 71 | -------------------------------------------------------------------------------- /mltc/models/textcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.utils import get_embeddings_from_file, logger 6 | 7 | 8 | 9 | class TextCNN(nn.Module): 10 | def __init__(self, config): 11 | super(TextCNN, self).__init__() 12 | self.config = config 13 | if self.config.embedding_pretrained: 14 | embeddings = get_embeddings_from_file(self.config.embedding_pretrained) 15 | self.embedding = nn.Embedding.from_pretrained( 16 | self.config.embedding_pretrained, freeze=False) 17 | else: 18 | self.embedding = nn.Embedding( 19 | self.config.vocab_size, 20 | self.config.embedding_size) 21 | self.convs = nn.ModuleList( 22 | [nn.Conv2d(1, self.config.num_filters, (k, self.config.embedding_size)) 23 | for k in self.config.filter_sizes]) 24 | self.dropout = nn.Dropout(self.config.dropout) 25 | print(self.config) 26 | self.fc = nn.Linear(self.config.num_filters * 27 | len(self.config.filter_sizes), self.config.num_labels) 28 | 29 | def conv_and_pool(self, x, conv): 30 | x = F.relu(conv(x)).squeeze(3) 31 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 32 | return x 33 | 34 | def forward(self, *inputs): 35 | input_ids = inputs[0] 36 | out = self.embedding(input_ids.long()) 37 | out = out.unsqueeze(1) 38 | out = torch.cat([self.conv_and_pool(out, conv) 39 | for conv in self.convs], 1) 40 | out = self.dropout(out) 41 | out = self.fc(out) 42 | return out 43 | 44 | def save_pretrained(self, save_directory): 45 | """ Save a model and its configuration file to a directory, so that it 46 | can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method. 
47 | """ 48 | assert os.path.isdir( 49 | save_directory), "\ 50 | Saving path should be a directory where the model and configuration can be saved" 51 | 52 | # Only save the model itself if we are using distributed training 53 | model_to_save = self.module if hasattr(self, 'module') else self 54 | 55 | # Save configuration file 56 | # model_to_save.config.save_pretrained(save_directory) 57 | 58 | # If we save using the predefined names, 59 | # we can load using `from_pretrained` 60 | output_model_file = os.path.join(save_directory, "pytorch_model.bin") 61 | torch.save(model_to_save.state_dict(), output_model_file) 62 | logger.info("Model weights saved in {}".format(output_model_file)) 63 | -------------------------------------------------------------------------------- /mltc/dataio/task_data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import Counter 3 | import os 4 | import random 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from utils.utils import save_pickle, logger, deserializate 8 | from pnlp import piop 9 | from sklearn.model_selection import train_test_split 10 | 11 | 12 | class TaskData: 13 | 14 | def __init__(self, data_num: int = 0): 15 | self.data_num = data_num 16 | 17 | def train_val_split(self, X: list, y: list, valid_size: float, 18 | data_name=None, data_dir=None, save=True): 19 | logger.info('split train data into train and valid') 20 | Xy = [] 21 | for i in range(len(X)): 22 | Xy.append((X[i], y[i])) 23 | train, valid = train_test_split( 24 | Xy, test_size=valid_size, random_state=42) 25 | if save: 26 | train_path = data_dir / "{}.train.pkl".format(data_name) 27 | valid_path = data_dir / "{}.valid.pkl".format(data_name) 28 | save_pickle(data=train, file_path=train_path) 29 | save_pickle(data=valid, file_path=valid_path) 30 | return train, valid 31 | 32 | def read_data(self, raw_data_path: str, data_dir: str, 33 | preprocessor=None, is_train=True): 34 | targets, sents = [], [] 35 | data = pd.read_csv(raw_data_path) 36 | if self.data_num: 37 | data = data.head(self.data_num) 38 | label_path = data_dir / "labels.txt" 39 | all_cates = [lb for lb in piop.read_lines(label_path) if len(lb) > 0] 40 | 41 | if is_train: 42 | data["category"] = data["meta"].apply( 43 | lambda x: deserializate(x).get("accusation")) 44 | for cate in all_cates: 45 | data[cate] = data["category"].apply(lambda x: int(cate in x)) 46 | 47 | for row in data.values: 48 | if is_train: 49 | target = row[3:] 50 | sent = str(row[0]) 51 | else: 52 | target = [-1] * len(all_cates) 53 | sent = str(row[0]) 54 | 55 | if preprocessor: 56 | sent = preprocessor(sent) 57 | 58 | if sent: 59 | targets.append(target) 60 | sents.append(sent) 61 | return targets, sents 62 | 63 | def build_vocab(self, vocab_path: str, data_list: list, min_count: int): 64 | vocab = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"] 65 | lst = [] 66 | for sent in data_list: 67 | for token in sent.split(): 68 | lst.append(token) 69 | count = Counter(lst) 70 | for key, freq in count.most_common(): 71 | if freq >= min_count: 72 | vocab.append(key) 73 | piop.write_file(vocab_path, vocab) 74 | 75 | -------------------------------------------------------------------------------- /mltc/dataset/labels.txt: -------------------------------------------------------------------------------- 1 | 非法[制造、销售]非法制造的注册商标标识 2 | 非法[制造、买卖、运输、储存]危险物质 3 | 挪用特定款物 4 | 走私[武器、弹药] 5 | 徇私舞弊[不征、少征]税款 6 | 交通肇事 7 | 强迫交易 8 | 非法采矿 9 | 非法[生产、买卖]警用装备 10 | 故意伤害 11 | 介绍贿赂 12 | 传播性病 13 | 非法[生产、销售]间谍专用器材 14 | 
[盗窃、抢夺][枪支、弹药、爆炸物] 15 | 虚开发票 16 | 动植物检疫徇私舞弊 17 | 侮辱 18 | 妨害作证 19 | 聚众扰乱[公共场所秩序、交通秩序] 20 | [虚开增值税专用发票、用于骗取出口退税、抵扣税款发票] 21 | 危险驾驶 22 | 过失致人重伤 23 | 非法行医 24 | 破坏计算机信息系统 25 | [盗窃、侮辱]尸体 26 | 单位受贿 27 | 提供[侵入、非法控制计算机信息系统][程序、工具] 28 | 非法[猎捕、杀害][珍贵、濒危]野生动物 29 | 重大劳动安全事故 30 | 非法买卖制毒物品 31 | 倒卖文物 32 | 受贿 33 | 破坏易燃易爆设备 34 | 高利转贷 35 | 过失致人死亡 36 | 污染环境 37 | 破坏监管秩序 38 | 妨害公务 39 | 非法吸收公众存款 40 | [伪造、变造、买卖]国家机关[公文、证件、印章] 41 | 非法占用农用地 42 | 非法[持有、私藏][枪支、弹药] 43 | 伪证 44 | 巨额财产来源不明 45 | 假冒注册商标 46 | 非法获取国家秘密 47 | [生产、销售][有毒、有害]食品 48 | [伪造、变造、买卖]武装部队[公文、证件、印章] 49 | 非法经营 50 | 冒充军人招摇撞骗 51 | 对非国家工作人员行贿 52 | 非法[收购、运输、出售][珍贵、濒危野生动物、珍贵、濒危野生动物]制品 53 | 开设赌场 54 | 破坏生产经营 55 | 经济犯 56 | 信用卡诈骗 57 | 对单位行贿 58 | 走私普通[货物、物品] 59 | 寻衅滋事 60 | 走私国家禁止进出口的[货物、物品] 61 | 诬告陷害 62 | 非法[收购、运输][盗伐、滥伐]的林木 63 | 组织卖淫 64 | 侵犯著作权 65 | 徇私枉法 66 | 赌博 67 | 强迫他人吸毒 68 | 玩忽职守 69 | 金融凭证诈骗 70 | 破坏交通设施 71 | 传授犯罪方法 72 | 重大责任事故 73 | 敲诈勒索 74 | 破坏电力设备 75 | 徇私舞弊不移交刑事案件 76 | [伪造、变造]居民身份证 77 | [窝藏、转移、收购、销售]赃物 78 | [窝藏、包庇] 79 | 过失损坏[广播电视设施、公用电信设施] 80 | 绑架 81 | 挪用公款 82 | 非法进行节育手术 83 | 盗窃 84 | 非法获取公民个人信息 85 | 单位行贿 86 | [伪造、倒卖]伪造的有价票证 87 | 贪污 88 | 盗掘[古文化遗址、古墓葬] 89 | 非法捕捞水产品 90 | 非法拘禁 91 | 盗伐林木 92 | [窃取、收买、非法提供]信用卡信息 93 | 诈骗 94 | 聚众哄抢 95 | 非法侵入住宅 96 | [组织、领导、参加]黑社会性质组织 97 | 包庇毒品犯罪分子 98 | 强制[猥亵、侮辱]妇女 99 | 强迫卖淫 100 | [盗窃、抢夺][枪支、弹药、爆炸物、危险物质] 101 | 聚众斗殴 102 | [生产、销售]不符合安全标准的食品 103 | 故意毁坏财物 104 | 保险诈骗 105 | 非法[采伐、毁坏]国家重点保护植物 106 | [走私、贩卖、运输、制造]毒品 107 | 失火 108 | 协助组织卖淫 109 | 销售假冒注册商标的商品 110 | 帮助[毁灭、伪造]证据 111 | 收买被拐卖的[妇女、儿童] 112 | 票据诈骗 113 | [掩饰、隐瞒][犯罪所得、犯罪所得收益] 114 | [引诱、容留、介绍]卖淫 115 | 拐卖[妇女、儿童] 116 | 洗钱 117 | 帮助犯罪分子逃避处罚 118 | 爆炸 119 | 招收[公务员、学生]徇私舞弊 120 | 过失投放危险物质 121 | 非法[转让、倒卖]土地使用权 122 | 虐待 123 | 拐骗儿童 124 | 强奸 125 | 脱逃 126 | 扰乱无线电通讯管理秩序 127 | [生产、销售]伪劣[农药、兽药、化肥、种子] 128 | 妨害信用卡管理 129 | 走私 130 | 非法[制造、买卖、运输、邮寄、储存][枪支、弹药、爆炸物] 131 | 骗取[贷款、票据承兑、金融票证] 132 | 逃税 133 | 非法携带[枪支、弹药、管制刀具、危险物品]危及公共安全 134 | 非法[买卖、运输、携带、持有]毒品原植物[种子、幼苗] 135 | 虐待被监管人 136 | 非法出售发票 137 | 虚报注册资本 138 | 滥用职权 139 | 危险物品肇事 140 | 走私废物 141 | 抢夺 142 | 放火 143 | 非法[制造、出售]非法制造的发票 144 | [出售、购买、运输]假币 145 | [引诱、教唆、欺骗]他人吸毒 146 | 集资诈骗 147 | 违法发放贷款 148 | [持有、使用]假币 149 | 贷款诈骗 150 | [生产、销售]伪劣产品 151 | 以危险方法危害公共安全 152 | 招摇撞骗 153 | 利用影响力受贿 154 | 猥亵儿童 155 | 聚众冲击国家机关 156 | [隐匿、故意销毁][会计凭证、会计帐簿、财务会计报告] 157 | 拒不支付劳动报酬 158 | [编造、故意传播]虚假恐怖信息 159 | 滥伐林木 160 | 持有伪造的发票 161 | 遗弃 162 | 非法组织卖血 163 | 合同诈骗 164 | 非法[收购、运输、加工、出售][国家重点保护植物、国家重点保护植物制品] 165 | 强迫劳动 166 | [制造、贩卖、传播]淫秽物品 167 | 过失以危险方法危害公共安全 168 | 投放危险物质 169 | 非法种植毒品原植物 170 | 非国家工作人员受贿 171 | [生产、销售]假药 172 | 故意杀人 173 | [组织、领导]传销活动 174 | 打击报复证人 175 | 私分国有资产 176 | 串通投标 177 | 挪用资金 178 | 职务侵占 179 | 侵占 180 | 过失损坏[武器装备、军事设施、军事通信] 181 | 伪造[公司、企业、事业单位、人民团体]印章 182 | 重婚 183 | 传播淫秽物品 184 | 诽谤 185 | [组织、强迫、引诱、容留、介绍]卖淫 186 | 倒卖[车票、船票] 187 | [窝藏、转移、隐瞒][毒品、毒赃] 188 | 非法处置[查封、扣押、冻结]的财产 189 | [伪造、变造]金融票证 190 | 劫持[船只、汽车] 191 | 非法狩猎 192 | 行贿 193 | 破坏交通工具 194 | 破坏[广播电视设施、公用电信设施] 195 | 抢劫 196 | [制作、复制、出版、贩卖、传播]淫秽物品牟利 197 | 走私[珍贵动物、珍贵动物制品] 198 | 非法持有毒品 199 | 拒不执行[判决、裁定] 200 | 伪造货币 201 | 聚众扰乱社会秩序 202 | 容留他人吸毒 203 | -------------------------------------------------------------------------------- /mltc/models/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models.bert_for_multi_label import (BertFCForMultiLable, 4 | BertCNNForMultiLabel, 5 | BertRCNNForMultiLabel, 6 | BertDPCNNForMultiLabel) 7 | from models.textcnn import TextCNN 8 | from models.textrcnn import TextRCNN 9 | 10 | from scheme.error import ModelNotDefinedError 11 | from configs.basic_config import config 12 | 13 | 14 | class Classifier: 15 | 16 | def __init__(self, 
choose_model: str, 17 | choose_pretrain: str, 18 | resume_path: str): 19 | self.choose_model = choose_model 20 | self.choose_pretrain = choose_pretrain 21 | self.resume_path = resume_path 22 | 23 | def __call__(self, num_labels: int): 24 | if self.choose_pretrain == "Bert": 25 | if self.resume_path: 26 | model_dir = self.resume_path 27 | else: 28 | model_dir = config.bert_model_dir 29 | 30 | if self.choose_model == "BertFC": 31 | model = BertFCForMultiLable.from_pretrained( 32 | model_dir, num_labels=num_labels) 33 | elif self.choose_model == "BertCNN": 34 | model = BertCNNForMultiLabel.from_pretrained( 35 | model_dir, num_labels=num_labels) 36 | elif self.choose_model == "BertRCNN": 37 | model = BertRCNNForMultiLabel.from_pretrained( 38 | model_dir, num_labels=num_labels) 39 | elif self.choose_model == "BertDPCNN": 40 | model = BertDPCNNForMultiLabel.from_pretrained( 41 | model_dir, num_labels=num_labels) 42 | else: 43 | raise ModelNotDefinedError 44 | 45 | elif self.choose_pretrain in ["Word2vec", "Nopretrain"]: 46 | if self.resume_path: 47 | model_dir = self.resume_path 48 | else: 49 | model_dir = None 50 | 51 | if self.choose_pretrain == "Word2vec": 52 | pretrain_model_dir = config.word2vec_model_dir 53 | else: 54 | pretrain_model_dir = None 55 | 56 | if self.choose_model == "TextCNN": 57 | cnn_config = config.cnn 58 | cnn_config.embedding_pretrained = pretrain_model_dir 59 | cnn_config.embedding_size = config.embedding_size 60 | cnn_config.vocab_size = config.vocab_size 61 | cnn_config.dropout = config.dropout 62 | cnn_config.num_labels = num_labels 63 | if self.resume_path: 64 | model = TextCNN(cnn_config) 65 | state_dict_file = os.path.join( 66 | model_dir, "pytorch_model.bin") 67 | model.load_state_dict(torch.load(state_dict_file)) 68 | else: 69 | model = TextCNN(cnn_config) 70 | 71 | elif self.choose_model == "TextRCNN": 72 | rcnn_config = config.rcnn 73 | rcnn_config.num_labels = num_labels 74 | model = TextRCNN(rcnn_config) 75 | if self.resume_path: 76 | state_dict_file = os.path.join( 77 | model_dir, "pytorch_model.bin") 78 | model.load_state_dict(torch.load(state_dict_file)) 79 | else: 80 | raise ModelNotDefinedError 81 | 82 | else: 83 | raise ModelNotDefinedError 84 | 85 | return model 86 | -------------------------------------------------------------------------------- /mltc/callback/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | from utils.utils import load_pickle, logger 5 | 6 | 7 | class ModelCheckpoint(object): 8 | """Save the model after every epoch. 9 | # Arguments 10 | checkpoint_dir: string, path to save the model file. 11 | monitor: quantity to monitor. 12 | arch: string, model architecture name used to name the saved files. 13 | save_best_only: if `save_best_only=True`, 14 | the latest best model according to 15 | the quantity monitored will not be overwritten. 16 | mode: one of {min, max}. 17 | If `save_best_only=True`, the decision 18 | to overwrite the current save file is made 19 | based on either the maximization or the 20 | minimization of the monitored quantity. For `val_acc`, 21 | this should be `max`, for `val_loss` this should 22 | be `min`, etc. 23 | 
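        Example (illustrative only; the real call sites are in train/trainer.py,
        which is not shown in this section):
            checkpoint = ModelCheckpoint(checkpoint_dir=config.checkpoint_dir,
                                         monitor="valid_loss", arch="bert",
                                         mode="min", save_best_only=True)
            # after each validation epoch, with `state` holding at least
            # {"model": model, "epoch": epoch}:
            checkpoint.bert_epoch_step(state=state, current=valid_loss)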
24 | """ 25 | 26 | def __init__(self, checkpoint_dir, 27 | monitor, 28 | arch, 29 | mode='min', 30 | epoch_freq=1, 31 | best=None, 32 | save_best_only=True): 33 | if isinstance(checkpoint_dir, Path): 34 | checkpoint_dir = checkpoint_dir 35 | else: 36 | checkpoint_dir = Path(checkpoint_dir) 37 | assert checkpoint_dir.is_dir() 38 | checkpoint_dir.mkdir(exist_ok=True) 39 | self.base_path = checkpoint_dir 40 | self.arch = arch 41 | self.monitor = monitor 42 | self.epoch_freq = epoch_freq 43 | self.save_best_only = save_best_only 44 | 45 | # 计算模式 46 | if mode == 'min': 47 | self.monitor_op = np.less 48 | self.best = np.Inf 49 | 50 | elif mode == 'max': 51 | self.monitor_op = np.greater 52 | self.best = -np.Inf 53 | # 这里主要重新加载模型时候 54 | # 对best重新赋值 55 | if best: 56 | self.best = best 57 | 58 | if save_best_only: 59 | self.model_name = f"BEST_{arch}_MODEL.pth" 60 | 61 | def epoch_step(self, state, current): 62 | ''' 63 | :param state: 需要保存的信息 64 | :param current: 当前判断指标 65 | :return: 66 | ''' 67 | if self.save_best_only: 68 | if self.monitor_op(current, self.best): 69 | logger.info(f"\nEpoch {state['epoch']}: \ 70 | {self.monitor} improved from \ 71 | {self.best:.5f} to {current:.5f}") 72 | self.best = current 73 | state['best'] = self.best 74 | best_path = self.base_path / self.model_name 75 | torch.save(state, str(best_path)) 76 | 77 | else: 78 | filename = self.base_path / f"epoch_{state['epoch']}_{state[self.monitor]}_{self.arch}_model.bin" 79 | if state['epoch'] % self.epoch_freq == 0: 80 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.") 81 | torch.save(state, str(filename)) 82 | 83 | def bert_epoch_step(self, state, current): 84 | model_to_save = state['model'] 85 | if self.save_best_only: 86 | if self.monitor_op(current, self.best): 87 | logger.info(f"\nEpoch {state['epoch']}: \ 88 | {self.monitor} improved from \ 89 | {self.best:.5f} to {current:.5f}") 90 | self.best = current 91 | state['best'] = self.best 92 | model_to_save.save_pretrained(str(self.base_path)) 93 | output_config_file = self.base_path / 'config.json' 94 | with open(str(output_config_file), 'w') as f: 95 | f.write(model_to_save.config.to_json_string()) 96 | state.pop("model") 97 | torch.save(state, self.base_path / 'checkpoint_info.bin') 98 | 99 | else: 100 | if state['epoch'] % self.epoch_freq == 0: 101 | save_path = self.base_path / f"checkpoint-epoch-{state['epoch']}" 102 | save_path.mkdir(exist_ok=True) 103 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.") 104 | model_to_save.save_pretrained(save_path) 105 | output_config_file = save_path / 'config.json' 106 | with open(str(output_config_file), 'w') as f: 107 | f.write(model_to_save.config.to_json_string()) 108 | state.pop("model") 109 | torch.save(state, save_path / 'checkpoint_info.bin') 110 | -------------------------------------------------------------------------------- /mltc/postprocessors/nopretrain.py: -------------------------------------------------------------------------------- 1 | from pnlp import piop 2 | import torch 3 | import numpy as np 4 | from tokenizers.tokenizer import Tokenizer 5 | from torch.utils.data import TensorDataset 6 | from pytorch_transformers import BertTokenizer 7 | 8 | from callback.progressbar import ProgressBar 9 | from utils.utils import load_pickle, logger 10 | 11 | 12 | class InputExample: 13 | 14 | def __init__(self, guid, text, labels): 15 | self.guid = guid 16 | self.text = text 17 | self.labels = labels 18 | 19 | 20 | class InputFeature: 21 | 22 | def __init__(self, input_ids, label_ids, 
input_len): 23 | self.input_ids = input_ids 24 | self.label_ids = label_ids 25 | self.input_len = input_len 26 | 27 | 28 | class NopretrainProcessor: 29 | 30 | def __init__(self, vocab_path: str): 31 | self.tokenizer = Tokenizer(vocab_path) 32 | self.pad_id = self.tokenizer.vocab.get(self.tokenizer.pad_token, 0) 33 | 34 | @property 35 | def vocab_size(self): 36 | return self.tokenizer.vocab_size 37 | 38 | def get_train(self, data_file): 39 | return self.read_data(data_file) 40 | 41 | def get_dev(self, data_file): 42 | return self.read_data(data_file) 43 | 44 | def get_test(self, lines): 45 | return lines 46 | 47 | @classmethod 48 | def read_data(cls, input_file): 49 | return load_pickle(input_file) 50 | 51 | def get_labels(self, data_file): 52 | return [lb for lb in piop.read_lines(data_file) if len(lb) > 0] 53 | 54 | def create_examples(self, lines, example_type, cached_examples_file): 55 | pbar = ProgressBar(n_total=len(lines)) 56 | if cached_examples_file.exists(): 57 | logger.info("Loading examples from cached file %s", 58 | cached_examples_file) 59 | examples = torch.load(cached_examples_file) 60 | else: 61 | examples = [] 62 | for i, line in enumerate(lines): 63 | guid = '%s-%d' % (example_type, i) 64 | text = line[0] 65 | labels = line[1] 66 | if isinstance(labels, str): 67 | labels = [np.float(x) for x in labels.split(",")] 68 | else: 69 | labels = [np.float(x) for x in list(labels)] 70 | example = InputExample( 71 | guid=guid, text=text, labels=labels) 72 | examples.append(example) 73 | pbar.batch_step(step=i, info={}, bar_type='create examples') 74 | logger.info("Saving examples into cached file %s", 75 | cached_examples_file) 76 | torch.save(examples, cached_examples_file) 77 | return examples 78 | 79 | def create_features(self, examples, max_seq_len, cached_features_file): 80 | pbar = ProgressBar(n_total=len(examples)) 81 | if cached_features_file.exists(): 82 | logger.info("Loading features from cached file %s", 83 | cached_features_file) 84 | features = torch.load(cached_features_file) 85 | else: 86 | features = [] 87 | for ex_id, example in enumerate(examples): 88 | tokens = self.tokenizer.tokenize(example.text) 89 | label_ids = example.labels 90 | 91 | if len(tokens) > max_seq_len: 92 | tokens = tokens[:max_seq_len] 93 | 94 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 95 | padding = [self.pad_id] * (max_seq_len - len(input_ids)) 96 | input_len = len(input_ids) 97 | 98 | input_ids += padding 99 | 100 | assert len(input_ids) == max_seq_len 101 | 102 | if ex_id < 2: 103 | logger.info( 104 | "*** Example ***") 105 | logger.info( 106 | f"guid: {example.guid}" % ()) 107 | logger.info( 108 | f"tokens: {' '.join([str(x) for x in tokens])}") 109 | logger.info( 110 | f"input_ids: {' '.join([str(x) for x in input_ids])}") 111 | 112 | feature = InputFeature(input_ids=input_ids, 113 | label_ids=label_ids, 114 | input_len=input_len) 115 | features.append(feature) 116 | pbar.batch_step(step=ex_id, info={}, 117 | bar_type='create features') 118 | logger.info("Saving features into cached file %s", 119 | cached_features_file) 120 | torch.save(features, cached_features_file) 121 | return features 122 | 123 | def create_dataset(self, features, is_sorted=False): 124 | if is_sorted: 125 | logger.info("sorted data by th length of input") 126 | features = sorted( 127 | features, key=lambda x: x.input_len, reverse=True) 128 | all_input_ids = torch.tensor( 129 | [f.input_ids for f in features], dtype=torch.long) 130 | all_label_ids = torch.tensor( 131 | [f.label_ids for f in features], 
dtype=torch.long) 132 | dataset = TensorDataset(all_input_ids, all_input_ids, all_input_ids, all_label_ids) 133 | return dataset 134 | -------------------------------------------------------------------------------- /mltc/preprocessors/tests/test_chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pytest 4 | 5 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(root) 7 | 8 | from chinese import ChineseProcessor, Chinese2Arabic 9 | 10 | @pytest.fixture 11 | def setup(): 12 | cp = ChineseProcessor() 13 | return cp 14 | 15 | 16 | def test_clean_nums(setup): 17 | text = "1.3的-3说法 2% 发生2000的 1e9 发 1发" 18 | new = setup.clean_nums(text) 19 | assert new == "X的X说法 X 发生X的 X 发 X发" 20 | 21 | 22 | def test_clean_punctuation(setup): 23 | text = ",a.?!哈哈?,。我!你……《》就是~不”“说:嘿嘿" 24 | new = setup.clean_punctuation(text) 25 | assert new.split() == "a 哈哈 我 你 就是 不 说 嘿嘿".split() 26 | 27 | 28 | def test_clean_linkpic(setup): 29 | inp = """https://www.yam.com 或 30 | [网址](https://www.google.cn/)和 31 | ![](http://xx.jpg)""" 32 | new = setup.clean_linkpic(inp) 33 | assert new.replace("\n", "").replace(" ", "") == "或和" 34 | 35 | 36 | def test_cnnum2num(setup): 37 | text = "一万五千六百三十八元" 38 | new = setup.cnnum2num(text, "元") 39 | assert new == "15638元" 40 | 41 | text = "一百二十公斤" 42 | new = setup.cnnum2num(text, "公斤") 43 | assert new == "120公斤" 44 | 45 | text = "300吨" 46 | new = setup.cnnum2num(text, "吨") 47 | assert new == "300吨" 48 | 49 | text = "1万克" 50 | new = setup.cnnum2num(text, "克") 51 | assert new == "1万克" 52 | 53 | 54 | def test_concentration_convert(setup): 55 | inp = 2 56 | new = setup.concentration_convert(inp) 57 | assert new == "T" 58 | 59 | inp = 2000.32342 60 | new = setup.concentration_convert(inp) 61 | assert new == "V" 62 | 63 | inp = "哈哈" 64 | new = setup.concentration_convert(inp) 65 | assert new == "哈哈" 66 | 67 | 68 | def test_quantity_convert(setup): 69 | inp = "2万" 70 | new = setup.quantity_convert(inp) 71 | assert new == "数万" 72 | 73 | inp = "3000" 74 | new = setup.quantity_convert(inp) 75 | assert new == "数千" 76 | 77 | inp = "500.232" 78 | new = setup.quantity_convert(inp) 79 | assert new == "数百" 80 | 81 | inp = "3亿" 82 | new = setup.quantity_convert(inp) 83 | assert new == "数亿" 84 | 85 | inp = "哈哈" 86 | new = setup.quantity_convert(inp) 87 | assert new == "哈哈" 88 | 89 | inp = "3万多" 90 | new = setup.quantity_convert(inp) 91 | assert new == "3万多" 92 | 93 | inp = "30多万" 94 | new = setup.quantity_convert(inp) 95 | assert new == "30多万" 96 | 97 | inp = "3万亿" 98 | new = setup.quantity_convert(inp) 99 | assert new == "数" 100 | 101 | 102 | def test_clean_punctuation(setup): 103 | inp = "我,你。他?哈,后~《爱》" 104 | new = setup.clean_punctuation(inp) 105 | assert new == "我 你 他 哈 后 爱 " 106 | 107 | 108 | def test_clean_date(setup): 109 | inp = "2018年,18年3月,3月2日,2日晚,2018-11-1,1987/3/25,1987.04.22" 110 | new = setup.clean_date(inp) 111 | assert new == "X年,X年X月,X月X日,X日晚,X年X月X日,X年X月X日,X年X月X日" 112 | 113 | 114 | def test_clean_time(setup): 115 | inp = "十点,十点三十分,八时整,UTC+09:00,18:09:01,18:09" 116 | new = setup.clean_time(inp) 117 | assert new == "X时,X时X分,X时整,X点,X点X分,X点X分" 118 | 119 | 120 | def test_clean_money(setup): 121 | inp = "十元,三里,一千万元啊,这是两百元。给你。19.42万元,共8万元。18.32,万,千,百,亿元。" 122 | new = setup.clean_money(inp) 123 | assert new == "数十元,三里,数千万元啊,这是数百元。给你。数十万元,共数万元。18.32,万,千,百,亿元。" 124 | 125 | 126 | def test_clean_weight(setup): 127 | inp = "123千克,三百二十克,两百多吨,一百二十公斤,1万斤,20000吨,好多。" 128 | new = 
setup.clean_weight(inp) 129 | assert new == "数百千克,数百克,两百多吨,数百千克,数万斤,数万吨,好多。" 130 | 131 | inp = "3043克白粉,20斤白面,3000吨钢材,三千吨钢材。" 132 | new = setup.clean_weight(inp) 133 | assert new == "数千克白粉,数十斤白面,数千吨钢材,数千吨钢材。" 134 | 135 | 136 | def test_clean_concentration(setup): 137 | inp = "浓度达214,浓度分别超国家规定的排放标准8.38" 138 | new = setup.clean_concentration(inp) 139 | assert new == "浓度达H,浓度分别超国家规定的排放标准T" 140 | 141 | 142 | def test_clean_entity(setup): 143 | inp = "永顺县人民检察院指控,张三去爬珠穆朗玛峰了。" 144 | new = setup.clean_entity(inp) 145 | assert new == "LO指控,P去爬L了。" 146 | 147 | 148 | def test_clean_stopwords(setup): 149 | inp = ["我", "喜欢", "你"] 150 | new = setup.clean_stopwords(inp) 151 | assert new == " ".join(inp) 152 | 153 | root = os.path.dirname(os.path.dirname( 154 | os.path.dirname(os.path.abspath(__file__)))) 155 | stopwords_path = os.path.join(root, "dicts", "stopwords.txt") 156 | assert os.path.exists(stopwords_path) == True 157 | setup.reset(stopwords_path) 158 | 159 | inp = ["我", "喜欢", "你"] 160 | new = setup.clean_stopwords(inp) 161 | assert new == "喜欢" 162 | 163 | 164 | def test_clean_english(setup): 165 | inp = "Lenovo 是联想,Alibaba 是阿里巴巴。" 166 | new = setup.clean_english(inp) 167 | assert new == "E 是联想,E 是阿里巴巴。" 168 | 169 | 170 | def test_chinese2arabic(): 171 | ca = Chinese2Arabic() 172 | s = "一亿三千万" 173 | assert ca(s) == 130000000 174 | s = "一万五千六百三十八" 175 | assert ca(s) == 15638 176 | s = "壹仟两百" 177 | assert ca(s) == 1200 178 | s = "十一" 179 | assert ca(s) == 11 180 | s = "三" 181 | assert ca(s) == 3 182 | s = "两百五十" 183 | assert ca(s) == 250 184 | s = "两百零五" 185 | assert ca(s) == 205 186 | s = "二十万五千" 187 | assert ca(s) == 205000 188 | s = "两百三十九万四千八百二十三" 189 | assert ca(s) == 2394823 190 | s = "一千三百万" 191 | assert ca(s) == 13000000 192 | s = "万" 193 | assert ca(s) == "万" 194 | s = "亿" 195 | assert ca(s) == "亿" 196 | s = "千" 197 | assert ca(s) == "千" 198 | s = "百" 199 | assert ca(s) == "百" 200 | s = "零" 201 | assert ca(s) == 0 202 | 203 | 204 | if __name__ == '__main__': 205 | print(root) 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | This is a repo focused on Multi-Label Text Classification, the main structure was forked from [lonePatient/Bert-Multi-Label-Text-Classification](https://github.com/lonePatient/Bert-Multi-Label-Text-Classification). We did several improvements: 4 | 5 | - Add a pipeline to automatically configure the whole thing 6 | - Add a preprocessor for Chinese 7 | - Add an engineering part 8 | - Add a basic tokenizer 9 | 10 | ## Environment 11 | 12 | ```bash 13 | # create a venv and install the dependencies 14 | python3 -m venv env 15 | source env/bin/activate 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Usage 20 | 21 | - Prepare dataset 22 | 23 | - Read Dataset below. 24 | - Add `train.csv` and `test.csv` to `dataset/` 25 | - Each line of the `train.csv` has two fields (fact and meta). Each line of the `test.csv` has only one field: `fact`, the output is under `outputs/result` 26 | 27 | - If you want to evaluate your test score, please modify `main.py` line 181: `is_train=False` to `is_train=True`, make sure your test dataset has two fields like the train dataset. 
28 | - Paths and filenames can be defined in `configs/basic_config.py` 29 | 30 | - Prepare pretrained model 31 | 32 | - Add pretrained files to `pretrain/bert/base-uncased/`, here we used a [domain model](https://github.com/thunlp/OpenCLaP) 33 | - Paths and filenames can be defined in `configs/basic_config.py` 34 | - For example bert filenames should be: `config.json`, `pytorch_model.bin` and `bert_vocab.txt` 35 | 36 | - Define a pipeline 37 | 38 | - Edit `pipeline.yml` 39 | - We have already added several pipelines in the file 40 | - When the pipeline has been changed, it's better to clean the cached dataset 41 | 42 | - Run `./run.sh` 43 | 44 | - Also could run by hand 45 | 46 | ```bash 47 | python main.py --do_data 48 | python main.py --do_train --save_best 49 | python main.py --do_test 50 | ``` 51 | 52 | - Or set the train, test data number 53 | 54 | ```bash 55 | python main.py --do_data --train_data_num 100 56 | python main.py --do_train --save_best 57 | python main.py --do_test --test_data_num 10 58 | ``` 59 | 60 | # 中文处理 61 | 62 | ## Dataset 63 | 64 | [thunlp/CAIL: Chinese AI & Law Challenge](https://github.com/thunlp/CAIL) Task 1 65 | 66 | 类别信息是 meta 中的 accusation,共 202 种类别,每个 fact 可能有多个类别。 67 | 68 | 69 | ## Preprocess 70 | 71 | 预处理主要做了以下工作: 72 | 73 | - 剔除文本中的链接、图片等 74 | - 剔除标点符号 75 | - 剔除停用词 76 | - 统一处理日期、时间 77 | - 替换为 X,如 X年X月X日X时X分 78 | - 中文数字数字化 79 | - 金额 80 | - 重量 81 | - 浓度 82 | - 统一处理金额 83 | - 按大小区间分类 84 | - 统一处理重量 85 | - 按大小区间分类 86 | - 统一处理浓度 87 | - 按大小区间分类 88 | - 统一处理地点、人名等实体 89 | - 统一替换为特定的 Token 90 | 91 | ## Features 92 | 93 | 尝试做了一些特征工程,不过并未实际运用在模型当中。主要包括几个方面: 94 | 95 | - 基于长度的特征 96 | - 文本长度 97 | - 基于词的特征 98 | - 总词数 99 | - 标点符号占总词数比例 100 | - 数字占总词数比例 101 | - hapax_legomena1:出现一次的词占总词数的比例 102 | - hapax_legomena2:出现两次的词占总词数的比例 103 | - 一字词比例 104 | - 二字词比例 105 | - 三字词比例 106 | - 四字词比例 107 | - TTR:词 token 数/总词数 108 | - 基于句子的特征 109 | - 短句数 110 | - 整句数 111 | - 基于内容的特征 112 | - 被告数量 113 | - 被告中男性比例 114 | - 被告中法定代表人比例 115 | - 担保数 116 | - 关键词 117 | - 基于 TextRank 和 TF-IDF 118 | 119 | ## Models 120 | 121 | 模型可选择的非常多,不过总体来看可以分为使用预训练模型和不使用预训练模型两种。一般情况下,使用预训练模型的效果要好于不使用。预训练模型可以使用词向量,也可以使用 Bert 和基于 Bert 的不同变形。 122 | 123 | 这里我们首先选择基本的、不使用预训练模型的 TextCNN 作为 Baseline,该模型如下图所示: 124 | 125 | ![](https://pic3.zhimg.com/80/v2-bb10ad5bbdc5294d3041662f887e60a6_hd.png) 126 | 127 | TextCNN 类似于 Ngram 滑动窗口提取特征,MaxPooling 获取重要特征,多个通道获取不同类型的特征。最大的问题是 MaxPooling 丢失了内部结构信息。 128 | 129 | 然后选择 Bert 作为预训练模型,我们选择了[领域 Bert 模型](https://github.com/thunlp/OpenCLaP),分别尝试了直接使用 Bert 的 classification 信息,TextRCNN 和 TextDPCNN。之所以选择这两个模型,是因为它们在之前的测评中显示的[结果](https://github.com/Tencent/NeuralNLP-NeuralClassifier)较好。 130 | 131 | TextRCNN 相当于 RNN + CNN,其基本结构如下图所示: 132 | 133 | ![](https://pic3.zhimg.com/80/v2-263209ce34c0941fece21de00065aa92_hd.png) 134 | 135 | RNN 我们采用双向 LSTM,结果与 embedding 拼接后再接一个 Maxpooling 获取重要特征。 136 | 137 | DPCNN 可以看作是多个叠加的 CNN,结果如下图所示: 138 | 139 | ![](https://ask.qcloudimg.com/http-save/yehe-1178513/hon9vbfkku.jpeg?imageView2/2/w/1620) 140 | 141 | 先做了两次宽度为 3,filter 数量为 250 个的卷积,然后开始做两两相邻的 MaxPooling(丢失很少的信息,提取更抽象的特征)。每个 block 中,池化后的结果与卷积后的结果相加。 142 | 143 | 除了上面介绍的几种模型外,还有其他一些比较常见的模型: 144 | 145 | - TextRNN:就是 TextRCNN 去掉 MaxPooling 后的部分。 146 | - TextRNN + Attention:在 TextRNN 后面加了一个 Attention。这个模型还有个版本(HAN)是分层的,也就是先获取 word-level 的表示,在此基础上再获得 sentence-level 的表示,不同的层级分别对应有 Attention。 147 | - FastText:Word + Bigram + Trigram,拼接后取序列平均。 148 | - Transformer:完全的基于 Attention。 149 | 150 | 此外,一般也会做模型融合: 151 | 152 | - 机器学习 + 深度学习 153 | - 多个模型结果拼接后再用分类器分类 154 | - 使用差异大的模型:模型差异越大,融合效果越好 155 | - 
重新划分训练集、验证集和测试集:相当于改变模型输入制造模型差异 156 | 157 | 因为 GPU 太贵了,就没有一一尝试了。一般比较好的结果应该是: 158 | 159 | - 合理的数据集 160 | - 精细的预处理 161 | - 适当的特征工程 162 | - 不同模型融合 163 | 164 | ## Others 165 | 166 | - 数据不平衡问题 167 | - 补充数据,或用相关数据增强数据 168 | 169 | - 对数目小的类别进行过采样 170 | - 调整 loss 中样本权重 171 | 172 | - 标签相似问题 173 | 174 | 175 | ## References 176 | 177 | - [lonePatient/Bert-Multi-Label-Text-Classification: This repo contains a PyTorch implementation of a pretrained BERT model for multi-label text classification.](https://github.com/lonePatient/Bert-Multi-Label-Text-Classification) 178 | - [649453932/Bert-Chinese-Text-Classification-Pytorch: 使用Bert,ERNIE,进行中文文本分类](https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch) 179 | - [Tencent/NeuralNLP-NeuralClassifier: An Open-source Neural Hierarchical Multi-label Text Classification Toolkit](https://github.com/Tencent/NeuralNLP-NeuralClassifier) 180 | - [中文文本分类 pytorch实现 - 知乎](https://zhuanlan.zhihu.com/p/73176084) 181 | - [PyTorch 官方教程中文版 - PyTorch 官方教程中文版](http://pytorch.panchuang.net/) 182 | - [GuidoPaul/CAIL2019: 中国法研杯司法人工智能挑战赛之相似案例匹配第一名解决方案](https://github.com/GuidoPaul/CAIL2019) 183 | - [jingyihiter/mycail: 中国法研杯 - 司法人工智能挑战赛](https://github.com/jingyihiter/mycail) 184 | - [中文文本分类 pytorch实现 - 知乎](https://zhuanlan.zhihu.com/p/73176084) 185 | - [如何到 top5%?NLP 文本分类和情感分析竞赛总结](https://posts.careerengine.us/p/5c383710ed75772cc919313c) 186 | - [用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 - 知乎](https://zhuanlan.zhihu.com/p/25928551) 187 | - [“达观杯”文本分类挑战赛Top10经验分享 - 知乎](https://zhuanlan.zhihu.com/p/45391378) 188 | - [中国法研杯---司法人工智能挑战赛 - 知乎](https://zhuanlan.zhihu.com/p/47024891) 189 | - [达观数据曾彦能:如何用深度学习做好长文本分类与法律文书智能化处理 - 云 + 社区 - 腾讯云](https://cloud.tencent.com/developer/article/1519320) 190 | - [深度学习网络调参技巧 - 知乎](https://zhuanlan.zhihu.com/p/24720954?utm_source=zhihu&utm_medium=social) 191 | 192 | # Changelog 193 | 194 | - 191130 add 中文处理 195 | 196 | - 191127 updated usage details 197 | 198 | - 191126 created -------------------------------------------------------------------------------- /mltc/preprocessors/english.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | replacement = { 4 | "aren't": "are not", 5 | "can't": "cannot", 6 | "couldn't": "could not", 7 | "didn't": "did not", 8 | "doesn't": "does not", 9 | "don't": "do not", 10 | "hadn't": "had not", 11 | "hasn't": "has not", 12 | "haven't": "have not", 13 | "he'd": "he would", 14 | "he'll": "he will", 15 | "he's": "he is", 16 | "i'd": "I would", 17 | "i'll": "I will", 18 | "i'm": "I am", 19 | "isn't": "is not", 20 | "it's": "it is", 21 | "it'll": "it will", 22 | "i've": "I have", 23 | "let's": "let us", 24 | "mightn't": "might not", 25 | "mustn't": "must not", 26 | "shan't": "shall not", 27 | "she'd": "she would", 28 | "she'll": "she will", 29 | "she's": "she is", 30 | "shouldn't": "should not", 31 | "that's": "that is", 32 | "there's": "there is", 33 | "they'd": "they would", 34 | "they'll": "they will", 35 | "they're": "they are", 36 | "they've": "they have", 37 | "we'd": "we would", 38 | "we're": "we are", 39 | "weren't": "were not", 40 | "we've": "we have", 41 | "what'll": "what will", 42 | "what're": "what are", 43 | "what's": "what is", 44 | "what've": "what have", 45 | "where's": "where is", 46 | "who'd": "who would", 47 | "who'll": "who will", 48 | "who're": "who are", 49 | "who's": "who is", 50 | "who've": "who have", 51 | "won't": "will not", 52 | "wouldn't": "would not", 53 | "you'd": "you would", 54 | "you'll": "you will", 55 | "you're": 
"you are", 56 | "you've": "you have", 57 | "'re": " are", 58 | "wasn't": "was not", 59 | "we'll": " will", 60 | "tryin'": "trying", 61 | } 62 | 63 | 64 | class EnglishProcessor(object): 65 | def __init__(self, min_len=2, stopwords_path=None): 66 | self.min_len = min_len 67 | self.stopwords_path = stopwords_path 68 | self.reset() 69 | 70 | def lower(self, sentence): 71 | return sentence.lower() 72 | 73 | def reset(self): 74 | if self.stopwords_path: 75 | with open(self.stopwords_path, 'r') as fr: 76 | self.stopwords = {} 77 | for line in fr: 78 | word = line.strip(' ').strip('\n') 79 | self.stopwords[word] = 1 80 | 81 | def clean_length(self, sentence): 82 | if len([x for x in sentence]) >= self.min_len: 83 | return sentence 84 | 85 | def replace(self, sentence): 86 | # Replace words like gooood to good 87 | sentence = re.sub(r'(\w)\1{2,}', r'\1\1', sentence) 88 | # Normalize common abbreviations 89 | words = sentence.split(' ') 90 | words = [replacement[word] 91 | if word in replacement else word for word in words] 92 | sentence_repl = " ".join(words) 93 | return sentence_repl 94 | 95 | def remove_website(self, sentence): 96 | sentence_repl = sentence.replace(r"http\S+", "") 97 | sentence_repl = sentence_repl.replace(r"https\S+", "") 98 | sentence_repl = sentence_repl.replace(r"http", "") 99 | sentence_repl = sentence_repl.replace(r"https", "") 100 | return sentence_repl 101 | 102 | def remove_name_tag(self, sentence): 103 | sentence_repl = sentence.replace(r"@\S+", "") 104 | return sentence_repl 105 | 106 | def remove_time(self, sentence): 107 | # Remove time related text 108 | sentence_repl = sentence.replace( 109 | r'\w{3}[+-][0-9]{1,2}\:[0-9]{2}\b', "") # e.g. UTC+09:00 110 | sentence_repl = sentence_repl.replace( 111 | r'\d{1,2}\:\d{2}\:\d{2}', "") # e.g. 18:09:01 112 | sentence_repl = sentence_repl.replace( 113 | r'\d{1,2}\:\d{2}', "") # e.g. 18:09 114 | # Remove date related text 115 | # e.g. 11/12/19, 11-1-19, 1.12.19, 11/12/2019 116 | sentence_repl = sentence_repl.replace( 117 | r'\d{1,2}(?:\/|\-|\.)\d{1,2}(?:\/|\-|\.)\d{2,4}', "") 118 | # e.g. 11 dec, 2019 11 dec 2019 dec 11, 2019 119 | sentence_repl = sentence_repl.replace( 120 | r"([\d]{1,2}\s(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)|(jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)\s[\d]{1,2})(\s|\,|\,\s|\s\,)[\d]{2,4}", 121 | "") 122 | # e.g. 
11 december, 2019 11 december 2019 december 11, 2019 123 | sentence_repl = sentence_repl.replace( 124 | r"[\d]{1,2}\s(january|february|march|april|may|june|july|august|september|october|november|december)(\s|\,|\,\s|\s\,)[\d]{2,4}", 125 | "") 126 | return sentence_repl 127 | 128 | def remove_breaks(self, sentence): 129 | # Remove line breaks 130 | sentence_repl = sentence.replace("\r", "") 131 | sentence_repl = sentence_repl.replace("\n", "") 132 | sentence_repl = re.sub(r"\\n\n", ".", sentence_repl) 133 | return sentence_repl 134 | 135 | def remove_ip(self, sentence): 136 | # Remove phone number and IP address 137 | sentence_repl = sentence.replace(r'\d{8,}', "") 138 | sentence_repl = sentence_repl.replace( 139 | r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', "") 140 | return sentence_repl 141 | 142 | def adjust_common(self, sentence): 143 | # Adjust common abbreviation 144 | sentence_repl = sentence.replace(r" you re ", " you are ") 145 | sentence_repl = sentence_repl.replace(r" we re ", " we are ") 146 | sentence_repl = sentence_repl.replace(r" they re ", " they are ") 147 | sentence_repl = sentence_repl.replace(r"@", "at") 148 | return sentence_repl 149 | 150 | def remove_chinese(self, sentence): 151 | # Chinese bad word 152 | sentence_repl = re.sub(r"fucksex", "fuck sex", sentence) 153 | sentence_repl = re.sub(r"f u c k", "fuck", sentence_repl) 154 | sentence_repl = re.sub(r"幹", "fuck", sentence_repl) 155 | sentence_repl = re.sub(r"死", "die", sentence_repl) 156 | sentence_repl = re.sub(r"他妈的", "fuck", sentence_repl) 157 | sentence_repl = re.sub(r"去你妈的", "fuck off", sentence_repl) 158 | sentence_repl = re.sub(r"肏你妈", "fuck your mother", sentence_repl) 159 | sentence_repl = re.sub( 160 | r"肏你祖宗十八代", "your ancestors to the 18th generation", sentence_repl) 161 | return sentence_repl 162 | 163 | def full2half(self, sentence): 164 | ret_str = '' 165 | for i in sentence: 166 | if ord(i) >= 33 + 65248 and ord(i) <= 126 + 65248: 167 | ret_str += chr(ord(i) - 65248) 168 | else: 169 | ret_str += i 170 | return ret_str 171 | 172 | def remove_stopword(self, sentence): 173 | words = sentence.split() 174 | x = [word for word in words if word not in self.stopwords] 175 | return " ".join(x) 176 | 177 | # 主函数 178 | def __call__(self, sentence): 179 | x = sentence 180 | # x = self.lower(x) 181 | x = self.replace(x) 182 | x = self.remove_website(x) 183 | x = self.remove_name_tag(x) 184 | x = self.remove_time(x) 185 | x = self.remove_breaks(x) 186 | x = self.remove_ip(x) 187 | x = self.adjust_common(x) 188 | x = self.remove_chinese(x) 189 | return x 190 | 191 | -------------------------------------------------------------------------------- /mltc/engineerings/engineering.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import pandas as pd 3 | import jieba 4 | import jieba.analyse 5 | 6 | from pnlp import ptxt 7 | 8 | ALLOW_POS = ( 9 | 'a', 'ad', 'ag', 'an', 10 | 'n', 'ng', 'nr', 'nrfg', 'nrt', 'ns', 'nt', 'nz', 11 | 'vn', 'v' 12 | ) 13 | 14 | class Engineer: 15 | 16 | def __init__(self, text_list: list): 17 | self.df = pd.DataFrame(text_list, columns=["text"]) 18 | 19 | @property 20 | def length_related_features(self) -> pd.DataFrame: 21 | data = self.df.copy() 22 | 23 | data["len"] = data["text"].apply(len) 24 | 25 | need_cols = [ 26 | "len" 27 | ] 28 | 29 | return data[need_cols] 30 | 31 | @property 32 | def word_related_features(self) -> pd.DataFrame: 33 | data = self.df.copy() 34 | 35 | pzh = re.compile(r'[\u4e00-\u9fa5]+') 36 | 37 | 
data["words"] = data["text"].apply( 38 | lambda x: jieba.lcut(x)) 39 | data["words_zh"] = data["words"].apply( 40 | lambda x: [w for w in x if pzh.search(w)]) 41 | 42 | data["word_num"] = data["words"].apply(lambda x: len(x)) 43 | 44 | data["punc_num"] = data["text"].apply(lambda x: ptxt.Text(x).len_pun) 45 | data["punc_num_ratio"] = data["punc_num"] / data["word_num"] 46 | 47 | data["num_num"] = data["text"].apply(lambda x: ptxt.Text(x).len_num) 48 | data["num_num_ratio"] = data["num_num"] / data["word_num"] 49 | 50 | data["appear_once"] = data["words_zh"].apply( 51 | lambda x: len([w for (w, f) in Counter(x).items() if f == 1])) 52 | data["hapax_legomena1"] = data["appear_once"] / data["word_num"] 53 | 54 | data["appear_twice"] = data["words_zh"].apply( 55 | lambda x: len([w for (w, f) in Counter(x).items() if f == 2])) 56 | data["hapax_legomena2"] = data["appear_twice"] / data["word_num"] 57 | 58 | data["one_char_num"] = data["words_zh"].apply( 59 | lambda x: len([w for w in x if len(w) == 1])) 60 | data["one_char_ratio"] = data["one_char_num"] / data["word_num"] 61 | 62 | data["two_char_num"] = data["words_zh"].apply( 63 | lambda x: len([w for w in x if len(w) == 2])) 64 | data["two_char_ratio"] = data["two_char_num"] / data["word_num"] 65 | 66 | data["three_char_num"] = data["words_zh"].apply( 67 | lambda x: len([w for w in x if len(w) == 3])) 68 | data["three_char_ratio"] = data["three_char_num"] / data["word_num"] 69 | 70 | data["four_char_num"] = data["words_zh"].apply( 71 | lambda x: len([w for w in x if len(w) == 4])) 72 | data["four_char_ratio"] = data["four_char_num"] / data["word_num"] 73 | 74 | data["ttr"] = data["words"].apply( 75 | lambda x: len(set(x)) / len(x)) 76 | 77 | need_cols = [ 78 | "word_num", 79 | "punc_num_ratio", 80 | "num_num_ratio", 81 | "hapax_legomena1", 82 | "hapax_legomena2", 83 | "one_char_ratio", 84 | "two_char_ratio", 85 | "three_char_ratio", 86 | "four_char_ratio", 87 | "ttr" 88 | ] 89 | return data[need_cols] 90 | 91 | @property 92 | def sent_related_features(self) -> pd.DataFrame: 93 | data = self.df.copy() 94 | 95 | rule = re.compile(r'[,、。?!”……]+') 96 | data["short_sent_num"] = data["text"].apply( 97 | lambda x: len([ss for ss in rule.split(x) if len(ss) > 1])) 98 | 99 | rule = re.compile(r'[。?!”……]+') 100 | data["sent_num"] = data["text"].apply( 101 | lambda x: len([ss for ss in rule.split(x) if len(ss) > 1])) 102 | 103 | need_cols = [ 104 | "short_sent_num", 105 | "sent_num" 106 | ] 107 | 108 | return data[need_cols] 109 | 110 | @property 111 | def content_related_features(self) -> pd.DataFrame: 112 | data = self.df.copy() 113 | 114 | # 被告数量 115 | rule = re.compile(r'被告') 116 | data["defendant_num"] = data["text"].apply( 117 | lambda x: len(rule.findall(x))) 118 | 119 | # 被告中男性比例 120 | rule = re.compile(r'被告.*男.*[。!?……”]+') 121 | data["defendant_male_num"] = data["text"].apply( 122 | lambda x: len(rule.findall(x))) 123 | data["defendante_male_ratio"] = data[ 124 | "defendant_male_num"] / data["defendant_num"] 125 | 126 | # 被告中法定代表人比例 127 | rule = re.compile(r'被告.*法定代表人.*[。!?……”]+') 128 | data["defendant_company_num"] = data["text"].apply( 129 | lambda x: len(rule.findall(x))) 130 | data["defendante_company_ratio"] = data[ 131 | "defendant_company_num"] / data["defendant_num"] 132 | 133 | # 担保 134 | rule = re.compile(r'担保') 135 | data["guarantee_num"] = data["text"].apply( 136 | lambda x: len(rule.findall(x))) 137 | 138 | need_cols = [ 139 | "defendant_num", 140 | "defendante_male_ratio", 141 | "defendante_company_ratio", 142 | "guarantee_num" 
143 |         ]
144 | 
145 |         return data[need_cols]
146 | 
147 |     @property
148 |     def keywords(self):
149 |         data = self.df.copy()
150 | 
151 |         data["keywords"] = data["text"].apply(
152 |             lambda x: KeyWords(x)(10))
153 | 
154 |         return data[["keywords"]]
155 | 
156 |     def __call__(self):
157 |         data = pd.concat([
158 |             self.length_related_features,
159 |             self.word_related_features,
160 |             self.sent_related_features,
161 |             self.content_related_features
162 |         ], axis=1)
163 |         # qt = QuantileTransformer(random_state=2019)
164 |         # qt_features = qt.fit_transform(data)
165 |         return data
166 | 
167 | 
168 | class KeyWords:
169 | 
170 |     def __init__(self, text):
171 |         self.text = text
172 | 
173 |     @property
174 |     def tfidf(self) -> list:
175 |         kw_with_weight = jieba.analyse.extract_tags(
176 |             self.text, allowPOS=ALLOW_POS, withWeight=True)
177 |         return self.standardize(kw_with_weight)
178 | 
179 |     @property
180 |     def textrank(self) -> list:
181 |         kw_with_weight = jieba.analyse.textrank(
182 |             self.text, allowPOS=ALLOW_POS, withWeight=True)
183 |         return self.standardize(kw_with_weight)
184 | 
185 |     def standardize(self, kw_with_weight: list) -> list:
186 |         words, weights = [], []
187 |         for w, p in kw_with_weight:
188 |             words.append(w)
189 |             weights.append(p)
190 |         arr = np.array(weights)
191 |         sumw = np.sum(arr)
192 |         new_weights = arr / sumw
193 |         kw_standardized = [(words[i], new_weights[i])
194 |                            for i in range(len(words))]
195 |         return kw_standardized
196 | 
197 |     def __call__(self, topk=5) -> list:
198 |         union_kwd = {}
199 |         idf_kwd = dict(self.tfidf)
200 |         trk_kwd = dict(self.textrank)
201 |         union = set(idf_kwd.keys()) & set(trk_kwd.keys())
202 |         for w in union:
203 |             union_kwd[w] = idf_kwd[w] + trk_kwd[w]
204 |         sort_kws = sorted(union_kwd.items(), key=lambda x: x[1], reverse=True)
205 |         res = [w for (w, f) in sort_kws][:topk]
206 |         return res
207 | 
208 | 
209 | 
--------------------------------------------------------------------------------
/mltc/postprocessors/bert.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import torch
3 | import numpy as np
4 | from utils.utils import load_pickle, logger
5 | from callback.progressbar import ProgressBar
6 | from pnlp import piop
7 | from torch.utils.data import TensorDataset
8 | from pytorch_transformers import BertTokenizer
9 | 
10 | 
11 | class InputExample(object):
12 |     def __init__(self, guid, text_a, text_b=None, label=None):
13 |         """Constructs an InputExample.
14 |         Args:
15 |             guid: Unique id for the example.
16 |             text_a: string. The untokenized text of the first sequence. For single
17 |                 sequence tasks, only this sequence must be specified.
18 |             text_b: (Optional) string. The untokenized text of the second sequence.
19 |                 Must be specified only for sequence pair tasks.
20 |             label: (Optional) string. The label of the example. This should be
21 |                 specified for train and dev examples, but not for test examples.
22 |         """
23 |         self.guid = guid
24 |         self.text_a = text_a
25 |         self.text_b = text_b
26 |         self.label = label
27 | 
28 | 
29 | class InputFeature(object):
30 |     '''
31 |     A single set of features of data.
32 | ''' 33 | 34 | def __init__(self, input_ids, input_mask, segment_ids, 35 | label_id, input_len): 36 | self.input_ids = input_ids 37 | self.input_mask = input_mask 38 | self.segment_ids = segment_ids 39 | self.label_id = label_id 40 | self.input_len = input_len 41 | 42 | 43 | class BertProcessor(object): 44 | """Base class for data converters for sequence classification data sets.""" 45 | 46 | def __init__(self, vocab_path, do_lower_case): 47 | self.tokenizer = BertTokenizer(vocab_path, do_lower_case) 48 | 49 | @property 50 | def vocab_size(self): 51 | return self.tokenizer.vocab_size 52 | 53 | def get_train(self, data_file): 54 | """Gets a collection of `InputExample`s for the train set.""" 55 | return self.read_data(data_file) 56 | 57 | def get_dev(self, data_file): 58 | """Gets a collection of `InputExample`s for the dev set.""" 59 | return self.read_data(data_file) 60 | 61 | def get_test(self, lines): 62 | return lines 63 | 64 | def get_labels(self, data_file): 65 | return [lb for lb in piop.read_lines(data_file) if len(lb) > 0] 66 | 67 | @classmethod 68 | def read_data(cls, input_file): 69 | return load_pickle(input_file) 70 | 71 | def truncate_seq_pair(self, tokens_a, tokens_b, max_length): 72 | """ 73 | # This is a simple heuristic which will always truncate the longer sequence 74 | # one token at a time. This makes more sense than truncating an equal percent 75 | # of tokens from each, since if one sequence is very short then each token 76 | # that's truncated likely contains more information than a longer sequence. 77 | """ 78 | while True: 79 | total_length = len(tokens_a) + len(tokens_b) 80 | if total_length <= max_length: 81 | break 82 | if len(tokens_a) > len(tokens_b): 83 | tokens_a.pop() 84 | else: 85 | tokens_b.pop() 86 | 87 | def create_examples(self, lines, example_type, cached_examples_file): 88 | ''' 89 | Creates examples for data 90 | ''' 91 | pbar = ProgressBar(n_total=len(lines)) 92 | if cached_examples_file.exists(): 93 | logger.info("Loading examples from cached file %s", 94 | cached_examples_file) 95 | examples = torch.load(cached_examples_file) 96 | else: 97 | examples = [] 98 | for i, line in enumerate(lines): 99 | guid = '%s-%d' % (example_type, i) 100 | text_a = line[0] 101 | label = line[1] 102 | if isinstance(label, str): 103 | label = [np.float(x) for x in label.split(",")] 104 | else: 105 | label = [np.float(x) for x in list(label)] 106 | text_b = None 107 | example = InputExample( 108 | guid=guid, text_a=text_a, text_b=text_b, label=label) 109 | examples.append(example) 110 | pbar.batch_step(step=i, info={}, bar_type='create examples') 111 | logger.info("Saving examples into cached file %s", 112 | cached_examples_file) 113 | torch.save(examples, cached_examples_file) 114 | return examples 115 | 116 | def create_features(self, examples, max_seq_len, cached_features_file): 117 | ''' 118 | # The convention in BERT is: 119 | # (a) For sequence pairs: 120 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 121 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 122 | # (b) For single sequences: 123 | # tokens: [CLS] the dog is hairy . 
[SEP] 124 | # type_ids: 0 0 0 0 0 0 0 125 | ''' 126 | pbar = ProgressBar(n_total=len(examples)) 127 | if cached_features_file.exists(): 128 | logger.info("Loading features from cached file %s", 129 | cached_features_file) 130 | features = torch.load(cached_features_file) 131 | else: 132 | features = [] 133 | for ex_id, example in enumerate(examples): 134 | tokens_a = self.tokenizer.tokenize(example.text_a) 135 | tokens_b = None 136 | label_id = example.label 137 | 138 | if example.text_b: 139 | tokens_b = self.tokenizer.tokenize(example.text_b) 140 | # Modifies `tokens_a` and `tokens_b` in place 141 | # so that the total 142 | # length is less than the specified length. 143 | # Account for [CLS], [SEP], [SEP] with "- 3" 144 | self.truncate_seq_pair( 145 | tokens_a, tokens_b, max_length=max_seq_len - 3) 146 | else: 147 | # Account for [CLS] and [SEP] with '-2' 148 | if len(tokens_a) > max_seq_len - 2: 149 | tokens_a = tokens_a[:max_seq_len - 2] 150 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] 151 | segment_ids = [0] * len(tokens) 152 | if tokens_b: 153 | tokens += tokens_b + ['[SEP]'] 154 | segment_ids += [1] * (len(tokens_b) + 1) 155 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 156 | input_mask = [1] * len(input_ids) 157 | padding = [0] * (max_seq_len - len(input_ids)) 158 | input_len = len(input_ids) 159 | 160 | input_ids += padding 161 | input_mask += padding 162 | segment_ids += padding 163 | 164 | assert len(input_ids) == max_seq_len 165 | assert len(input_mask) == max_seq_len 166 | assert len(segment_ids) == max_seq_len 167 | 168 | if ex_id < 2: 169 | logger.info("*** Example ***") 170 | logger.info(f"guid: {example.guid}" % ()) 171 | logger.info(f"tokens: {' '.join([str(x) for x in tokens])}") 172 | logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}") 173 | logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}") 174 | logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}") 175 | 176 | feature = InputFeature(input_ids=input_ids, 177 | input_mask=input_mask, 178 | segment_ids=segment_ids, 179 | label_id=label_id, 180 | input_len=input_len) 181 | features.append(feature) 182 | pbar.batch_step(step=ex_id, info={}, 183 | bar_type='create features') 184 | logger.info("Saving features into cached file %s", 185 | cached_features_file) 186 | torch.save(features, cached_features_file) 187 | return features 188 | 189 | def create_dataset(self, features, is_sorted=False): 190 | # Convert to Tensors and build dataset 191 | if is_sorted: 192 | logger.info("sorted data by th length of input") 193 | features = sorted( 194 | features, key=lambda x: x.input_len, reverse=True) 195 | all_input_ids = torch.tensor( 196 | [f.input_ids for f in features], dtype=torch.long) 197 | all_input_mask = torch.tensor( 198 | [f.input_mask for f in features], dtype=torch.long) 199 | all_segment_ids = torch.tensor( 200 | [f.segment_ids for f in features], dtype=torch.long) 201 | all_label_ids = torch.tensor( 202 | [f.label_id for f in features], dtype=torch.long) 203 | dataset = TensorDataset( 204 | all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 205 | return dataset 206 | -------------------------------------------------------------------------------- /mltc/train/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from callback.progressbar import ProgressBar 4 | from utils.utils import (restore_checkpoint, model_device, 5 | summary, seed_everything, AverageMeter) 6 | from 
torch.nn.utils import clip_grad_norm_ 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, n_gpu, 11 | model, 12 | epochs, 13 | logger, 14 | criterion, 15 | optimizer, 16 | lr_scheduler, 17 | early_stopping, 18 | epoch_metrics, 19 | batch_metrics, 20 | gradient_accumulation_steps, 21 | grad_clip=0.0, 22 | verbose=1, 23 | fp16=None, 24 | resume_path=None, 25 | training_monitor=None, 26 | model_checkpoint=None 27 | ): 28 | self.start_epoch = 1 29 | self.global_step = 0 30 | self.n_gpu = n_gpu 31 | self.model = model 32 | self.epochs = epochs 33 | self.logger = logger 34 | self.fp16 = fp16 35 | self.grad_clip = grad_clip 36 | self.verbose = verbose 37 | self.criterion = criterion 38 | self.optimizer = optimizer 39 | self.lr_scheduler = lr_scheduler 40 | self.early_stopping = early_stopping 41 | self.epoch_metrics = epoch_metrics 42 | self.batch_metrics = batch_metrics 43 | self.model_checkpoint = model_checkpoint 44 | self.training_monitor = training_monitor 45 | self.gradient_accumulation_steps = gradient_accumulation_steps 46 | self.model, self.device = model_device( 47 | n_gpu=self.n_gpu, model=self.model) 48 | if self.fp16: 49 | try: 50 | from apex import amp 51 | except ImportError: 52 | raise ImportError( 53 | "Please install apex from \ 54 | https://www.github.com/nvidia/apex to use fp16 training.") 55 | 56 | if resume_path: 57 | self.logger.info(f"\nLoading checkpoint: {resume_path}") 58 | resume_dict = torch.load(resume_path / 'checkpoint_info.bin') 59 | best = resume_dict['epoch'] 60 | self.start_epoch = resume_dict['epoch'] 61 | if self.model_checkpoint: 62 | self.model_checkpoint.best = best 63 | self.logger.info(f"\nCheckpoint '{resume_path}' \ 64 | and epoch {self.start_epoch} loaded") 65 | 66 | def epoch_reset(self): 67 | self.outputs = [] 68 | self.targets = [] 69 | self.result = {} 70 | for metric in self.epoch_metrics: 71 | metric.reset() 72 | 73 | def batch_reset(self): 74 | self.info = {} 75 | for metric in self.batch_metrics: 76 | metric.reset() 77 | 78 | def save_info(self, epoch, best): 79 | model_save = self.model.module if hasattr( 80 | self.model, 'module') else self.model 81 | state = {"model": model_save, 82 | 'epoch': epoch, 83 | 'best': best} 84 | return state 85 | 86 | def valid_epoch(self, data): 87 | pbar = ProgressBar(n_total=len(data)) 88 | self.epoch_reset() 89 | self.model.eval() 90 | with torch.no_grad(): 91 | for step, batch in enumerate(data): 92 | batch = tuple(t.to(self.device) for t in batch) 93 | input_ids, input_mask, segment_ids, label_ids = batch 94 | logits = self.model(input_ids, input_mask, segment_ids) 95 | self.outputs.append(logits.cpu().detach()) 96 | self.targets.append(label_ids.cpu().detach()) 97 | pbar.batch_step(step=step, info={}, bar_type='Evaluating') 98 | self.outputs = torch.cat(self.outputs, dim=0).cpu().detach() 99 | self.targets = torch.cat(self.targets, dim=0).cpu().detach() 100 | loss = self.criterion(target=self.targets, output=self.outputs) 101 | self.result['valid_loss'] = loss.item() 102 | print("------------- valid result --------------") 103 | if self.epoch_metrics: 104 | for metric in self.epoch_metrics: 105 | metric(logits=self.outputs, target=self.targets) 106 | value = metric.value() 107 | if value: 108 | self.result[f'valid_{metric.name()}'] = value 109 | if 'cuda' in str(self.device): 110 | torch.cuda.empty_cache() 111 | return self.result 112 | 113 | def train_epoch(self, data): 114 | pbar = ProgressBar(n_total=len(data)) 115 | tr_loss = AverageMeter() 116 | self.epoch_reset() 117 | for step, batch in 
enumerate(data): 118 | self.batch_reset() 119 | self.model.train() 120 | batch = tuple(t.to(self.device) for t in batch) 121 | input_ids, input_mask, segment_ids, label_ids = batch 122 | print("input_ids, input_mask, segment_ids, label_ids SIZE: \n") 123 | print(input_ids.size(), input_mask.size(), 124 | segment_ids.size(), label_ids.size()) 125 | logits = self.model(input_ids, input_mask, segment_ids) 126 | print("logits and label ids size: ", 127 | logits.size(), label_ids.size()) 128 | loss = self.criterion(output=logits, target=label_ids) 129 | if len(self.n_gpu) >= 2: 130 | loss = loss.mean() 131 | if self.gradient_accumulation_steps > 1: 132 | loss = loss / self.gradient_accumulation_steps 133 | if self.fp16: 134 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 135 | scaled_loss.backward() 136 | clip_grad_norm_(amp.master_params( 137 | self.optimizer), self.grad_clip) 138 | else: 139 | loss.backward() 140 | clip_grad_norm_(self.model.parameters(), self.grad_clip) 141 | if (step + 1) % self.gradient_accumulation_steps == 0: 142 | self.lr_scheduler.step() 143 | self.optimizer.step() 144 | self.optimizer.zero_grad() 145 | self.global_step += 1 146 | if self.batch_metrics: 147 | for metric in self.batch_metrics: 148 | metric(logits=logits, target=label_ids) 149 | self.info[metric.name()] = metric.value() 150 | self.info['loss'] = loss.item() 151 | tr_loss.update(loss.item(), n=1) 152 | if self.verbose >= 1: 153 | pbar.batch_step(step=step, info=self.info, bar_type='Training') 154 | self.outputs.append(logits.cpu().detach()) 155 | self.targets.append(label_ids.cpu().detach()) 156 | print("\n------------- train result --------------") 157 | # epoch metric 158 | self.outputs = torch.cat(self.outputs, dim=0).cpu().detach() 159 | self.targets = torch.cat(self.targets, dim=0).cpu().detach() 160 | self.result['loss'] = tr_loss.avg 161 | if self.epoch_metrics: 162 | for metric in self.epoch_metrics: 163 | metric(logits=self.outputs, target=self.targets) 164 | value = metric.value() 165 | if value: 166 | self.result[f'{metric.name()}'] = value 167 | if "cuda" in str(self.device): 168 | torch.cuda.empty_cache() 169 | return self.result 170 | 171 | def train(self, train_data, valid_data, seed): 172 | seed_everything(seed) 173 | print("model summary info: ") 174 | for step, (input_ids, input_mask, segment_ids, 175 | label_ids) in enumerate(train_data): 176 | input_ids = input_ids.to(self.device) 177 | input_mask = input_mask.to(self.device) 178 | segment_ids = segment_ids.to(self.device) 179 | summary(self.model, *(input_ids, input_mask, 180 | segment_ids), show_input=True) 181 | break 182 | 183 | # *************************************************************** 184 | for epoch in range(self.start_epoch, self.start_epoch+self.epochs): 185 | self.logger.info(f"Epoch {epoch}/{self.epochs}") 186 | train_log = self.train_epoch(train_data) 187 | valid_log = self.valid_epoch(valid_data) 188 | 189 | logs = dict(train_log, **valid_log) 190 | show_info = f'\nEpoch: {epoch} - ' + "-".join( 191 | [f' {key}: {value:.4f} ' for 192 | key, value in logs.items()]) 193 | self.logger.info(show_info) 194 | 195 | # save 196 | if self.training_monitor: 197 | self.training_monitor.epoch_step(logs) 198 | 199 | # save model 200 | if self.model_checkpoint: 201 | state = self.save_info(epoch, best=logs['valid_loss']) 202 | self.model_checkpoint.bert_epoch_step( 203 | current=logs[self.model_checkpoint.monitor], state=state) 204 | 205 | # early_stopping 206 | if self.early_stopping: 207 | 
self.early_stopping.epoch_step( 208 | epoch=epoch, current=logs[self.early_stopping.monitor]) 209 | if self.early_stopping.stop_training: 210 | break 211 | -------------------------------------------------------------------------------- /mltc/models/bert_for_multi_label.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel 5 | from configs.basic_config import config as basic_config 6 | 7 | 8 | class BertFCForMultiLable(BertPreTrainedModel): 9 | def __init__(self, config): 10 | 11 | super(BertFCForMultiLable, self).__init__(config) 12 | # bert = BertModel.from_pretrained(bert_model_path) 13 | self.bert = BertModel(config) 14 | for param in self.bert.parameters(): 15 | param.requires_grad = True 16 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 17 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 18 | self.apply(self.init_weights) 19 | 20 | def forward(self, input_ids, 21 | attention_mask=None, token_type_ids=None, head_mask=None): 22 | """ 23 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 24 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 25 | Sequence of hidden-states at the output of the last layer of the model. 26 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 27 | Last layer hidden-state of the first token of the sequence (classification token) 28 | further processed by a Linear layer and a Tanh activation function. The Linear 29 | layer weights are trained from the next sentence prediction (classification) 30 | objective during Bert pretraining. This output is usually *not* a good summary 31 | of the semantic content of the input, you're often better with averaging or pooling 32 | the sequence of hidden-states for the whole input sequence. 33 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 34 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 35 | of shape ``(batch_size, sequence_length, hidden_size)``: 36 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 37 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 38 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 39 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
40 | Examples:: 41 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 42 | model = BertModel.from_pretrained('bert-base-uncased') 43 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 44 | outputs = model(input_ids) 45 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 46 | """ 47 | outputs = self.bert(input_ids, 48 | attention_mask=attention_mask, 49 | token_type_ids=token_type_ids, 50 | head_mask=head_mask) 51 | pooled_output = outputs[1] 52 | pooled_output = self.dropout(pooled_output) 53 | logits = self.classifier(pooled_output) 54 | return logits 55 | 56 | def unfreeze(self, start_layer, end_layer): 57 | def children(m): 58 | return m if isinstance(m, (list, tuple)) else list(m.children()) 59 | 60 | def set_trainable_attr(m, b): 61 | m.trainable = b 62 | for p in m.parameters(): 63 | p.requires_grad = b 64 | 65 | def apply_leaf(m, f): 66 | c = children(m) 67 | if isinstance(m, nn.Module): 68 | f(m) 69 | if len(c) > 0: 70 | for l in c: 71 | apply_leaf(l, f) 72 | 73 | def set_trainable(l, b): 74 | apply_leaf(l, lambda m: set_trainable_attr(m, b)) 75 | 76 | # You can unfreeze the last layer of bert 77 | # by calling set_trainable(model.bert.encoder.layer[23], True) 78 | set_trainable(self.bert, False) 79 | for i in range(start_layer, end_layer+1): 80 | set_trainable(self.bert.encoder.layer[i], True) 81 | 82 | 83 | class BertCNNForMultiLabel(BertPreTrainedModel): 84 | 85 | def __init__(self, config): 86 | super(BertPreTrainedModel, self).__init__(config) 87 | config.num_filters = basic_config.cnn.num_filters 88 | config.filter_sizes = basic_config.cnn.filter_sizes 89 | config.dropout = basic_config.dropout 90 | 91 | self.bert = BertModel(config) 92 | for param in self.bert.parameters(): 93 | param.requires_grad = True 94 | self.convs = nn.ModuleList( 95 | [nn.Conv2d(1, config.num_filters, (k, config.hidden_size)) 96 | for k in config.filter_sizes]) 97 | self.dropout = nn.Dropout(config.dropout) 98 | self.fc_cnn = nn.Linear(config.num_filters * 99 | len(config.filter_sizes), config.num_labels) 100 | 101 | def conv_and_pool(self, x, conv): 102 | x = F.relu(conv(x)).squeeze(3) 103 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 104 | return x 105 | 106 | def forward(self, input_ids, 107 | attention_mask=None, token_type_ids=None, head_mask=None): 108 | outputs = self.bert(input_ids, 109 | attention_mask=attention_mask, 110 | token_type_ids=token_type_ids, 111 | head_mask=head_mask) 112 | encoder_out, text_cls = outputs 113 | out = encoder_out.unsqueeze(1) 114 | out = torch.cat([self.conv_and_pool(out, conv) 115 | for conv in self.convs], 1) 116 | out = self.dropout(out) 117 | out = self.fc_cnn(out) 118 | return out 119 | 120 | 121 | class BertRCNNForMultiLabel(BertPreTrainedModel): 122 | 123 | def __init__(self, config): 124 | super(BertPreTrainedModel, self).__init__(config) 125 | config.rnn_hidden = basic_config.rcnn.rnn_hidden 126 | config.num_layers = basic_config.rcnn.num_layers 127 | config.kernel_size = basic_config.rcnn.kernel_size 128 | config.lstm_dropout = basic_config.rcnn.dropout 129 | 130 | self.bert = BertModel(config) 131 | for param in self.bert.parameters(): 132 | param.requires_grad = True 133 | self.lstm = nn.LSTM(config.hidden_size, 134 | config.rnn_hidden, 135 | config.num_layers, 136 | bidirectional=True, 137 | batch_first=True, 138 | dropout=config.lstm_dropout) 139 | self.maxpool = nn.MaxPool1d(config.kernel_size) 140 | self.fc = nn.Linear(config.rnn_hidden * 2 + 
141 | config.hidden_size, config.num_labels) 142 | def forward(self, input_ids, 143 | attention_mask=None, token_type_ids=None, head_mask=None): 144 | outputs = self.bert(input_ids, 145 | attention_mask=attention_mask, 146 | token_type_ids=token_type_ids, 147 | head_mask=head_mask) 148 | encoder_out, text_cls = outputs 149 | out, _ = self.lstm(encoder_out) 150 | out = torch.cat((encoder_out, out), 2) 151 | out = F.relu(out) 152 | out = out.permute(0, 2, 1) 153 | out = self.maxpool(out).squeeze() 154 | out = self.fc(out) 155 | return out 156 | 157 | 158 | class BertDPCNNForMultiLabel(BertPreTrainedModel): 159 | 160 | def __init__(self, config): 161 | super(BertPreTrainedModel, self).__init__(config) 162 | config.kernel_size = basic_config.dpcnn.kernel_size 163 | config.num_filters = basic_config.dpcnn.num_filters 164 | 165 | self.bert = BertModel(config) 166 | for param in self.bert.parameters(): 167 | param.requires_grad = True 168 | self.conv_region = nn.Conv2d( 169 | 1, config.num_filters, (3, config.hidden_size), stride=1) 170 | self.conv = nn.Conv2d(config.num_filters, 171 | config.num_filters, (3, 1), stride=1) 172 | self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2) 173 | self.padding1 = nn.ZeroPad2d((0, 0, 1, 1)) # top bottom 174 | self.padding2 = nn.ZeroPad2d((0, 0, 0, 1)) # bottom 175 | self.relu = nn.ReLU() 176 | self.fc = nn.Linear(config.num_filters, config.num_labels) 177 | 178 | def forward(self, input_ids, 179 | attention_mask=None, token_type_ids=None, head_mask=None): 180 | outputs = self.bert(input_ids, 181 | attention_mask=attention_mask, 182 | token_type_ids=token_type_ids, 183 | head_mask=head_mask) 184 | encoder_out, text_cls = outputs 185 | x = encoder_out.unsqueeze(1) # [batch_size, 1, seq_len, embed] 186 | x = self.conv_region(x) # [batch_size, num_filters, seq_len-3+1, 1] 187 | x = self.padding1(x) # [batch_size, num_filters, seq_len, 1] 188 | x = self.relu(x) 189 | x = self.conv(x) # [batch_size, num_filters, seq_len-3+1, 1] 190 | x = self.padding1(x) # [batch_size, num_filters, seq_len, 1] 191 | x = self.relu(x) 192 | x = self.conv(x) # [batch_size, num_filters, seq_len-3+1, 1] 193 | while x.size()[2] > 2: 194 | x = self._block(x) 195 | x = x.squeeze() # [batch_size, num_filters] 196 | x = self.fc(x) 197 | return x 198 | 199 | def _block(self, x): 200 | x = self.padding2(x) 201 | px = self.max_pool(x) 202 | x = self.padding1(px) 203 | x = F.relu(x) 204 | x = self.conv(x) 205 | x = self.padding1(x) 206 | x = F.relu(x) 207 | x = self.conv(x) 208 | x = x + px # short cut 209 | return x 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /mltc/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from addict import Dict 3 | import copy 4 | import demjson 5 | import logging 6 | import random 7 | import json 8 | import pickle 9 | from pathlib import Path 10 | from collections import OrderedDict 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | logger = logging.getLogger() 17 | 18 | 19 | class AttrDict(Dict): 20 | """Dict that can get attribute by dot""" 21 | 22 | def __init__(self, *args, **kwargs): 23 | super(AttrDict, self).__init__(*args, **kwargs) 24 | # self.__dict__ = self 25 | 26 | def to_dict(self): 27 | """Serializes this instance to a Python dictionary.""" 28 | output = copy.deepcopy(self) 29 | return output 30 | 31 | def to_json_string(self): 32 | """Serializes this instance to a JSON string.""" 33 
| return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 34 | 35 | def get_embeddings_from_file(embedding_file: str): 36 | pass 37 | 38 | 39 | def seed_everything(seed=1029): 40 | random.seed(seed) 41 | os.environ['PYTHONHASHSEED'] = str(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | # some cudnn methods can be random even after fixing the seed 47 | # unless you tell it to be deterministic 48 | torch.backends.cudnn.deterministic = True 49 | 50 | 51 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 52 | if isinstance(log_file, Path): 53 | log_file = str(log_file) 54 | 55 | log_format = logging.Formatter("%(message)s") 56 | logger = logging.getLogger() 57 | logger.setLevel(logging.INFO) 58 | console_handler = logging.StreamHandler() 59 | console_handler.setFormatter(log_format) 60 | logger.handlers = [console_handler] 61 | if log_file: 62 | file_handler = logging.FileHandler(log_file) 63 | file_handler.setLevel(log_file_level) 64 | file_handler.setFormatter(log_format) 65 | logger.addHandler(file_handler) 66 | return logger 67 | 68 | 69 | def save_pickle(data, file_path): 70 | if isinstance(file_path, Path): 71 | file_path = str(file_path) 72 | with open(file_path, "wb") as f: 73 | pickle.dump(data, f) 74 | 75 | 76 | def load_pickle(input_file): 77 | with open(str(input_file), 'rb') as f: 78 | data = pickle.load(f) 79 | return data 80 | 81 | 82 | def deserializate(data): 83 | jr = demjson.decode(data.encode('utf8'), 84 | encoding='utf8', return_errors=True) 85 | return jr.object 86 | 87 | 88 | def prepare_device(use_gpu): 89 | """ 90 | setup GPU device if available, move model into configured device 91 | # 如果n_gpu_use为数字,则使用range生成list 92 | # 如果输入的是一个list,则默认使用list[0]作为controller 93 | Example: 94 | use_gpu = '' : cpu 95 | use_gpu = '0': cuda:0 96 | use_gpu = '0,1' : cuda:0 and cuda:1 97 | """ 98 | n_gpu_use = [int(x) for x in use_gpu.split(",")] 99 | if not use_gpu: 100 | device_type = 'cpu' 101 | else: 102 | device_type = "cuda:{}".format(n_gpu_use[0]) 103 | n_gpu = torch.cuda.device_count() 104 | if len(n_gpu_use) > 0 and n_gpu == 0: 105 | logger.warning("Warning: There\'s no GPU available on this machine, \ 106 | training will be performed on CPU.") 107 | device_type = 'cpu' 108 | if len(n_gpu_use) > n_gpu: 109 | msg = f"Warning: The number of GPU\'s configured to use is {n_gpu}, \ 110 | but only {n_gpu} are available on this machine." 
111 | logger.warning(msg) 112 | n_gpu_use = range(n_gpu) 113 | device = torch.device(device_type) 114 | list_ids = n_gpu_use 115 | return device, list_ids 116 | 117 | 118 | def model_device(n_gpu, model): 119 | ''' 120 | :param n_gpu: 121 | :param model: 122 | :return: 123 | ''' 124 | device, device_ids = prepare_device(n_gpu) 125 | if len(device_ids) > 1: 126 | logger.info(f"current {len(device_ids)} GPUs") 127 | model = torch.nn.DataParallel(model, device_ids=device_ids) 128 | if len(device_ids) == 1: 129 | os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0]) 130 | model = model.to(device) 131 | return model, device 132 | 133 | 134 | def restore_checkpoint(resume_path, model=None): 135 | ''' 136 | 加载模型 137 | :param resume_path: 138 | :param model: 139 | :param optimizer: 140 | :return: 141 | 注意: 如果是加载Bert模型的话,需要调整,不能使用该模式 142 | 可以使用模块自带的Bert_model.from_pretrained(state_dict = your save state_dict) 143 | ''' 144 | if isinstance(resume_path, Path): 145 | resume_path = str(resume_path) 146 | checkpoint = torch.load(resume_path) 147 | best = checkpoint['best'] 148 | start_epoch = checkpoint['epoch'] + 1 149 | states = checkpoint['state_dict'] 150 | if isinstance(model, nn.DataParallel): 151 | model.module.load_state_dict(states) 152 | else: 153 | model.load_state_dict(states) 154 | return [model, best, start_epoch] 155 | 156 | 157 | class AverageMeter(object): 158 | ''' 159 | computes and stores the average and current value 160 | Example: 161 | >>> loss = AverageMeter() 162 | >>> for step,batch in enumerate(train_data): 163 | >>> pred = self.model(batch) 164 | >>> raw_loss = self.metrics(pred,target) 165 | >>> loss.update(raw_loss.item(),n = 1) 166 | >>> cur_loss = loss.avg 167 | ''' 168 | 169 | def __init__(self): 170 | self.reset() 171 | 172 | def reset(self): 173 | self.val = 0 174 | self.avg = 0 175 | self.sum = 0 176 | self.count = 0 177 | 178 | def update(self, val, n=1): 179 | self.val = val 180 | self.sum += val * n 181 | self.count += n 182 | self.avg = self.sum / self.count 183 | 184 | 185 | def summary(model, *inputs, batch_size=-1, show_input=True): 186 | ''' 187 | 打印模型结构信息 188 | :param model: 189 | :param inputs: 190 | :param batch_size: 191 | :param show_input: 192 | :return: 193 | Example: 194 | >>> print("model summary info: ") 195 | >>> for step,batch in enumerate(train_data): 196 | >>> summary(self.model,*batch,show_input=True) 197 | >>> break 198 | ''' 199 | 200 | def register_hook(module): 201 | def hook(module, input, output=None): 202 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 203 | module_idx = len(summary) 204 | 205 | m_key = f"{class_name}-{module_idx + 1}" 206 | summary[m_key] = OrderedDict() 207 | summary[m_key]["input_shape"] = list(input[0].size()) 208 | summary[m_key]["input_shape"][0] = batch_size 209 | 210 | if show_input is False and output is not None: 211 | if isinstance(output, (list, tuple)): 212 | for out in output: 213 | if isinstance(out, torch.Tensor): 214 | summary[m_key]["output_shape"] = [ 215 | [-1] + list(out.size())[1:] 216 | ][0] 217 | else: 218 | summary[m_key]["output_shape"] = [ 219 | [-1] + list(out[0].size())[1:] 220 | ][0] 221 | else: 222 | summary[m_key]["output_shape"] = list(output.size()) 223 | summary[m_key]["output_shape"][0] = batch_size 224 | 225 | params = 0 226 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 227 | params += torch.prod( 228 | torch.LongTensor(list(module.weight.size()))) 229 | summary[m_key]["trainable"] = module.weight.requires_grad 230 | if hasattr(module, "bias") 
and hasattr(module.bias, "size"): 231 | params += torch.prod( 232 | torch.LongTensor(list(module.bias.size()))) 233 | summary[m_key]["nb_params"] = params 234 | 235 | if (not isinstance(module, nn.Sequential) and 236 | not isinstance(module, nn.ModuleList) and 237 | not (module == model)): 238 | if show_input is True: 239 | hooks.append(module.register_forward_pre_hook(hook)) 240 | else: 241 | hooks.append(module.register_forward_hook(hook)) 242 | 243 | # create properties 244 | summary = OrderedDict() 245 | hooks = [] 246 | 247 | # register hook 248 | model.apply(register_hook) 249 | model(*inputs) 250 | 251 | # remove these hooks 252 | for h in hooks: 253 | h.remove() 254 | 255 | print("-------------------------------------------------------") 256 | if show_input is True: 257 | line_new = f"{'Layer(type)':>25} {'Input Shape':>25} {'Param #':>15}" 258 | else: 259 | line_new = f"{'Layer(type)':>25} {'Output Shape':>25} {'Param #':>15}" 260 | print(line_new) 261 | print("========================================================") 262 | 263 | total_params = 0 264 | total_output = 0 265 | trainable_params = 0 266 | for layer in summary: 267 | # input_shape, output_shape, trainable, nb_params 268 | if show_input is True: 269 | line_new = "{:>25} {:>25} {:>15}".format( 270 | layer, 271 | str(summary[layer]["input_shape"]), 272 | "{0:,}".format(summary[layer]["nb_params"]), 273 | ) 274 | else: 275 | line_new = "{:>25} {:>25} {:>15}".format( 276 | layer, 277 | str(summary[layer]["output_shape"]), 278 | "{0:,}".format(summary[layer]["nb_params"]), 279 | ) 280 | 281 | total_params += summary[layer]["nb_params"] 282 | if show_input is True: 283 | total_output += np.prod(summary[layer]["input_shape"]) 284 | else: 285 | total_output += np.prod(summary[layer]["output_shape"]) 286 | if "trainable" in summary[layer]: 287 | if summary[layer]["trainable"] == True: 288 | trainable_params += summary[layer]["nb_params"] 289 | 290 | print(line_new) 291 | 292 | print("============================================================") 293 | print(f"Total params: {total_params:0,}") 294 | print(f"Trainable params: {trainable_params:0,}") 295 | print(f"Non-trainable params: {(total_params - trainable_params):0,}") 296 | print("------------------------------------------------------------") 297 | -------------------------------------------------------------------------------- /mltc/train/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | from sklearn.metrics import roc_auc_score 5 | from sklearn.metrics import f1_score, classification_report 6 | 7 | __call__ = ['Accuracy', 'AUC', 'F1Score', 'EntityScore', 8 | 'ClassReport', 'MultiLabelReport', 'AccuracyThresh'] 9 | 10 | 11 | class Metric: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, outputs, target): 16 | raise NotImplementedError 17 | 18 | def reset(self): 19 | raise NotImplementedError 20 | 21 | def value(self): 22 | raise NotImplementedError 23 | 24 | def name(self): 25 | raise NotImplementedError 26 | 27 | 28 | class Accuracy(Metric): 29 | ''' 30 | 计算准确度 31 | 可以使用topK参数设定计算K准确度 32 | Examples: 33 | >>> metric = Accuracy(**) 34 | >>> for epoch in range(epochs): 35 | >>> metric.reset() 36 | >>> for batch in batchs: 37 | >>> logits = model() 38 | >>> metric(logits,target) 39 | >>> print(metric.name(),metric.value()) 40 | ''' 41 | 42 | def __init__(self, topK): 43 | super(Accuracy, self).__init__() 44 | self.topK = topK 45 | self.reset() 46 | 
47 | def __call__(self, logits, target): 48 | _, pred = logits.topk(self.topK, 1, True, True) 49 | pred = pred.t() 50 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 51 | self.correct_k = correct[:self.topK].view(-1).float().sum(0) 52 | self.total = target.size(0) 53 | 54 | def reset(self): 55 | self.correct_k = 0 56 | self.total = 0 57 | 58 | def value(self): 59 | return float(self.correct_k) / self.total 60 | 61 | def name(self): 62 | return 'accuracy' 63 | 64 | 65 | class AccuracyThresh(Metric): 66 | ''' 67 | 计算准确度 68 | 可以使用topK参数设定计算K准确度 69 | Example: 70 | >>> metric = AccuracyThresh(**) 71 | >>> for epoch in range(epochs): 72 | >>> metric.reset() 73 | >>> for batch in batchs: 74 | >>> logits = model() 75 | >>> metric(logits,target) 76 | >>> print(metric.name(),metric.value()) 77 | ''' 78 | 79 | def __init__(self, thresh=0.5): 80 | super(AccuracyThresh, self).__init__() 81 | self.thresh = thresh 82 | self.reset() 83 | 84 | def __call__(self, logits, target): 85 | self.y_pred = logits.sigmoid() 86 | self.y_true = target 87 | 88 | def reset(self): 89 | self.correct_k = 0 90 | self.total = 0 91 | 92 | def value(self): 93 | data_size = self.y_pred.size(0) 94 | acc = np.mean(((self.y_pred > self.thresh) == 95 | self.y_true.bool()).float().cpu().numpy(), axis=1).sum() 96 | return acc / data_size 97 | 98 | def name(self): 99 | return 'accuracy' 100 | 101 | 102 | class AUC(Metric): 103 | ''' 104 | AUC score 105 | micro: 106 | Calculate metrics globally by considering each element of the label 107 | indicator matrix as a label. 108 | macro: 109 | Calculate metrics for each label, and find their unweighted 110 | mean. This does not take label imbalance into account. 111 | weighted: 112 | Calculate metrics for each label, and find their average, weighted 113 | by support (the number of true instances for each label). 114 | samples: 115 | Calculate metrics for each instance, and find their average. 116 | Example: 117 | >>> metric = AUC(**) 118 | >>> for epoch in range(epochs): 119 | >>> metric.reset() 120 | >>> for batch in batchs: 121 | >>> logits = model() 122 | >>> metric(logits,target) 123 | >>> print(metric.name(),metric.value()) 124 | ''' 125 | 126 | def __init__(self, task_type='binary', average='binary'): 127 | super(AUC, self).__init__() 128 | 129 | assert task_type in ['binary', 'multiclass'] 130 | assert average in ['binary', 'micro', 'macro', 'samples', 'weighted'] 131 | 132 | self.task_type = task_type 133 | self.average = average 134 | 135 | def __call__(self, logits, target): 136 | ''' 137 | 计算整个结果 138 | ''' 139 | if self.task_type == 'binary': 140 | self.y_prob = logits.sigmoid().data.cpu().numpy() 141 | else: 142 | self.y_prob = logits.softmax(-1).data.cpu().detach().numpy() 143 | self.y_true = target.cpu().numpy() 144 | 145 | def reset(self): 146 | self.y_prob = 0 147 | self.y_true = 0 148 | 149 | def value(self): 150 | ''' 151 | 计算指标得分 152 | ''' 153 | auc = roc_auc_score(y_score=self.y_prob, 154 | y_true=self.y_true, average=self.average) 155 | return auc 156 | 157 | def name(self): 158 | return 'auc' 159 | 160 | 161 | class F1Score(Metric): 162 | ''' 163 | F1 Score 164 | binary: 165 | Only report results for the class specified by ``pos_label``. 166 | This is applicable only if targets (``y_{true,pred}``) are binary. 167 | micro: 168 | Calculate metrics globally by considering each element of the label 169 | indicator matrix as a label. 170 | macro: 171 | Calculate metrics for each label, and find their unweighted 172 | mean. 
This does not take label imbalance into account. 173 | weighted: 174 | Calculate metrics for each label, and find their average, weighted 175 | by support (the number of true instances for each label). 176 | samples: 177 | Calculate metrics for each instance, and find their average. 178 | Example: 179 | >>> metric = F1Score(**) 180 | >>> for epoch in range(epochs): 181 | >>> metric.reset() 182 | >>> for batch in batchs: 183 | >>> logits = model() 184 | >>> metric(logits,target) 185 | >>> print(metric.name(),metric.value()) 186 | ''' 187 | 188 | def __init__(self, thresh=0.5, 189 | normalizate=True, 190 | task_type='binary', 191 | average='binary', 192 | search_thresh=False): 193 | super(F1Score).__init__() 194 | assert task_type in ['binary', 'multiclass'] 195 | assert average in ['binary', 'micro', 'macro', 'samples', 'weighted'] 196 | 197 | self.thresh = thresh 198 | self.task_type = task_type 199 | self.normalizate = normalizate 200 | self.search_thresh = search_thresh 201 | self.average = average 202 | 203 | def thresh_search(self, y_prob): 204 | ''' 205 | 对于f1评分的指标,一般我们需要对阈值进行调整,一般不会使用默认的0.5值,因此 206 | 这里我们队Thresh进行优化 207 | :return: 208 | ''' 209 | best_threshold = 0 210 | best_score = 0 211 | for threshold in tqdm([i * 0.01 for i in range(100)], disable=True): 212 | self.y_pred = y_prob > threshold 213 | score = self.value() 214 | if score > best_score: 215 | best_threshold = threshold 216 | best_score = score 217 | return best_threshold, best_score 218 | 219 | def __call__(self, logits, target): 220 | ''' 221 | 计算整个结果 222 | :return: 223 | ''' 224 | self.y_true = target.cpu().numpy() 225 | if self.normalizate and self.task_type == 'binary': 226 | y_prob = logits.sigmoid().data.cpu().numpy() 227 | elif self.normalizate and self.task_type == 'multiclass': 228 | y_prob = logits.softmax(-1).data.cpu().detach().numpy() 229 | else: 230 | y_prob = logits.cpu().detach().numpy() 231 | 232 | if self.task_type == 'binary': 233 | if self.thresh and self.search_thresh == False: 234 | self.y_pred = (y_prob > self.thresh).astype(int) 235 | self.value() 236 | else: 237 | thresh, f1 = self.thresh_search(y_prob=y_prob) 238 | print(f"Best thresh: {thresh:.4f} - F1 Score: {f1:.4f}") 239 | 240 | if self.task_type == 'multiclass': 241 | self.y_pred = np.argmax(y_prob, 1) 242 | 243 | def reset(self): 244 | self.y_pred = 0 245 | self.y_true = 0 246 | 247 | def value(self): 248 | ''' 249 | 计算指标得分 250 | ''' 251 | f1 = f1_score(y_true=self.y_true, y_pred=self.y_pred, 252 | average=self.average) 253 | return f1 254 | 255 | def name(self): 256 | return 'f1' 257 | 258 | 259 | class ClassReport(Metric): 260 | ''' 261 | class report 262 | ''' 263 | 264 | def __init__(self, target_names=None): 265 | super(ClassReport).__init__() 266 | self.target_names = target_names 267 | 268 | def reset(self): 269 | self.y_pred = 0 270 | self.y_true = 0 271 | 272 | def value(self): 273 | ''' 274 | 计算指标得分 275 | ''' 276 | score = classification_report(y_true=self.y_true, 277 | y_pred=self.y_pred, 278 | target_names=self.target_names) 279 | print(f"\n\n classification report: {score}") 280 | 281 | def __call__(self, logits, target): 282 | _, y_pred = torch.max(logits.data, 1) 283 | self.y_pred = y_pred.cpu().numpy() 284 | self.y_true = target.cpu().numpy() 285 | 286 | def name(self): 287 | return "class_report" 288 | 289 | 290 | class MultiLabelReport(Metric): 291 | ''' 292 | multi label report 293 | ''' 294 | 295 | def __init__(self, id2label=None): 296 | super(MultiLabelReport).__init__() 297 | self.id2label = id2label 298 | 299 | 
def reset(self): 300 | self.y_prob = 0 301 | self.y_true = 0 302 | 303 | def __call__(self, logits, target): 304 | 305 | self.y_prob = logits.sigmoid().data.cpu().detach().numpy() 306 | self.y_true = target.cpu().numpy() 307 | 308 | def value(self): 309 | ''' 310 | 计算指标得分 311 | ''' 312 | for i, label in self.id2label.items(): 313 | try: 314 | auc = roc_auc_score( 315 | y_score=self.y_prob[:, i], y_true=self.y_true[:, i]) 316 | print(f"label:{label} - auc: {auc:.4f}") 317 | except ValueError as e: 318 | print(f"label:{label} - auc: can not calculate") 319 | 320 | def name(self): 321 | return "multilabel_report" 322 | -------------------------------------------------------------------------------- /mltc/preprocessors/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from pnlp import ptxt, piop 4 | import jieba 5 | import jieba.posseg as pseg 6 | 7 | 8 | CN_NUM = { 9 | '〇': 0, '一': 1, '二': 2, '三': 3, '四': 4, '五': 5, 10 | '六': 6, '七': 7, '八': 8, '九': 9, '零': 0, 11 | '壹': 1, '贰': 2, '叁': 3, '肆': 4, '伍': 5, 12 | '陆': 6, '柒': 7, '捌': 8, '玖': 9, '貮': 2, '两': 2, 13 | } 14 | 15 | CN_UNIT = { 16 | '十': 10, 17 | '拾': 10, 18 | '百': 100, 19 | '佰': 100, 20 | '千': 1000, 21 | '仟': 1000, 22 | '万': 10000, 23 | '萬': 10000, 24 | '亿': 100000000, 25 | '億': 100000000, 26 | '兆': 1000000000000, 27 | } 28 | 29 | 30 | class ChineseProcessor: 31 | 32 | """ 33 | 34 | Nodes 35 | ------ 36 | replace info: 37 | - S ==> Small 38 | - T ==> Tiny 39 | - B ==> Big 40 | - H ==> Huge 41 | - V ==> Very xx 42 | 43 | - P ==> Person 44 | - L ==> Location 45 | - O ==> Organization 46 | 47 | - E ==> English 48 | """ 49 | 50 | def __init__(self): 51 | self.stopwords_set = set() 52 | self.pun_zh = r",。;、?!:“”‘’()「」『』〔〕【】《》〈〉…——\-—~~·" 53 | self.pun_en = r",.;?!\(\)\[\]\{\}<>_" 54 | self.cn_num = "".join(list(CN_NUM.keys())) 55 | self.cn_unit = "".join(list(CN_UNIT.keys())) 56 | self.year = "〇一二三四五六七八九十零" 57 | self.month = "一二三四五六七八九十" 58 | self.weight = "一二三四五六七八九十百千万亿" 59 | 60 | def reset(self, stopwords_path): 61 | if stopwords_path and os.path.exists(stopwords_path): 62 | self.stopwords_set = set(piop.read_lines(stopwords_path)) 63 | 64 | def cnnum2num(self, text: str, unit: str): 65 | rule = re.compile(rf'[{self.cn_num + self.cn_unit}]+{unit}') 66 | ca = Chinese2Arabic() 67 | text = rule.sub(lambda x: str(ca(x.group()[:-len(unit)])) + unit, text) 68 | return text 69 | 70 | def concentration_convert(self, concern: float): 71 | num = concern 72 | try: 73 | num = float(num) 74 | except Exception as e: 75 | return concern 76 | 77 | if num < 1.0: 78 | return "S" 79 | elif num < 10.0: 80 | return "T" 81 | elif num < 100.0: 82 | return "B" 83 | elif num < 1000.0: 84 | return "H" 85 | else: 86 | return "V" 87 | 88 | def quantity_convert(self, input_quantity): 89 | dct = { 90 | "万": 10000, 91 | "亿": 100000000 92 | } 93 | times = [] 94 | quantity = input_quantity 95 | while quantity and quantity[-1] in ["万", "亿"]: 96 | times.append(quantity[-1]) 97 | quantity = quantity[:-1] 98 | try: 99 | quantity = float(quantity) 100 | except Exception as e: 101 | return input_quantity 102 | for t in times: 103 | quantity *= dct.get(t) 104 | 105 | if quantity < 100: 106 | return "数十" 107 | elif quantity < 1000: 108 | return "数百" 109 | elif quantity < 10000: 110 | return "数千" 111 | elif quantity < 100000: 112 | return "数万" 113 | elif quantity < 1000000: 114 | return "数十万" 115 | elif quantity < 10000000: 116 | return "数百万" 117 | elif quantity < 100000000: 118 | return "数千万" 119 | elif 
quantity < 1000000000: 120 | return "数亿" 121 | elif quantity < 10000000000: 122 | return "数十亿" 123 | elif quantity < 100000000000: 124 | return "数百亿" 125 | elif quantity < 1000000000000: 126 | return "数千亿" 127 | else: 128 | return "数" 129 | 130 | def clean_punctuation(self, text): 131 | rule = re.compile(rf'[{self.pun_zh + self.pun_en}]+') 132 | text = rule.sub(" ", text) 133 | return text 134 | 135 | def clean_linkpic(self, text): 136 | text = ptxt.Text(text, 'pic').clean 137 | text = ptxt.Text(text, 'lnk').clean 138 | return text 139 | 140 | def clean_date(self, text): 141 | rule = re.compile(rf'[\d{self.year}]+年') 142 | text = rule.sub("X年", text) 143 | 144 | rule = re.compile(rf'[\d{self.month}]+月') 145 | text = rule.sub("X月", text) 146 | 147 | rule = re.compile(rf'[\d{self.month}]+日') 148 | text = rule.sub("X日", text) 149 | 150 | # e.g. 11/12/19, 11-1-19, 1.12.19, 11/12/2019 151 | rule = re.compile( 152 | r'(?:19|20)\d{2}(?:\/|\-|\.)\d{1,2}(?:\/|\-|\.)\d{1,2}') 153 | text = rule.sub("X年X月X日", text) 154 | return text 155 | 156 | def clean_time(self, text): 157 | rule = re.compile(rf'[\d{self.month}]+[时点]') 158 | text = rule.sub("X时", text) 159 | 160 | rule = re.compile(rf'[时点][\d{self.month}]+[分]') 161 | text = rule.sub("时X分", text) 162 | # e.g. UTC+09:00 163 | rule = re.compile(r'\w{3}[+-][0-9]{1,2}\:[0-9]{2}\b') 164 | text = rule.sub("X点", text) 165 | # e.g. 18:09:01 166 | rule = re.compile(r'\d{1,2}\:\d{2}\:\d{2}') 167 | text = rule.sub("X点X分", text) 168 | # e.g. 18:09 169 | rule = re.compile(r'\d{1,2}\:\d{2}') 170 | text = rule.sub("X点X分", text) 171 | return text 172 | 173 | def clean_money(self, text): 174 | text = self.cnnum2num(text, "元") 175 | 176 | rule = re.compile(r'\d+[.]?\d*[万亿]?元') 177 | text = rule.sub(lambda x: self.quantity_convert( 178 | x.group()[:-1]) + "元", text) 179 | return text 180 | 181 | def clean_weight(self, text): 182 | text = self.cnnum2num(text, "千克") 183 | text = self.cnnum2num(text, "公斤") 184 | 185 | rule = re.compile(rf'[\d{self.weight}]+(?:千克|公斤)') 186 | text = rule.sub(lambda x: self.quantity_convert( 187 | x.group()[:-2]) + "千克", text) 188 | 189 | text = self.cnnum2num(text, "斤") 190 | rule = re.compile(rf'[\d{self.weight}]+斤') 191 | text = rule.sub(lambda x: self.quantity_convert( 192 | x.group()[:-1]) + "斤", text) 193 | 194 | text = self.cnnum2num(text, "吨") 195 | rule = re.compile(rf'[\d{self.weight}]+吨') 196 | text = rule.sub(lambda x: self.quantity_convert( 197 | x.group()[:-1]) + "吨", text) 198 | 199 | text = self.cnnum2num(text, "克") 200 | rule = re.compile(rf'[\d{self.weight}]+克') 201 | text = rule.sub(lambda x: self.quantity_convert( 202 | x.group()[:-1]) + "克", text) 203 | return text 204 | 205 | def clean_concentration(self, text): 206 | 207 | def convert_combine(text): 208 | pt = ptxt.Text(text, "num") 209 | pure_text = pt.clean 210 | converted = self.concentration_convert(pt.extract.mats[0]) 211 | return pure_text + converted 212 | 213 | rule = re.compile(r'浓度\w*[.]?\d+[.]?\d*') 214 | text = rule.sub(lambda x: convert_combine(x.group()), text) 215 | return text 216 | 217 | def clean_entity(self, text): 218 | wps = pseg.cut(text) 219 | res = [] 220 | for w, pos in wps: 221 | # 人名 222 | if pos == "nr": 223 | res.append("P") 224 | # 地名 225 | elif pos == "ns": 226 | res.append("L") 227 | # 机构名 228 | elif pos == "nt": 229 | res.append("O") 230 | else: 231 | res.append(w) 232 | return "".join(res) 233 | 234 | def clean_stopwords(self, token_list): 235 | if self.stopwords_set: 236 | res = [t for t in token_list if t not in self.stopwords_set] 
237 | else: 238 | res = token_list 239 | return " ".join(res) 240 | 241 | def clean_nums(self, text): 242 | rule = re.compile(r"[.-]?[\d.]+[e%]?[\d]?") 243 | text = rule.sub("X", text) 244 | return text 245 | 246 | def clean_english(self, text): 247 | rule = re.compile(r'[a-zA-Z]+') 248 | text = rule.sub("E", text) 249 | return text 250 | 251 | 252 | class ChineseCharProcessor(ChineseProcessor): 253 | 254 | def __init__(self, stopwords_path="", *args, **kwargs): 255 | super().__init__(*args, **kwargs) 256 | self.reset(stopwords_path) 257 | 258 | def __call__(self, sent): 259 | sent = ptxt.Text(sent, "whi").clean 260 | sent = self.clean_linkpic(sent) 261 | 262 | sent = self.clean_english(sent) 263 | 264 | sent = self.clean_date(sent) 265 | sent = self.clean_time(sent) 266 | 267 | sent = self.clean_money(sent) 268 | sent = self.clean_weight(sent) 269 | sent = self.clean_concentration(sent) 270 | 271 | sent = self.clean_entity(sent) 272 | 273 | sent = self.clean_nums(sent) 274 | 275 | clist = list(sent) 276 | sent = self.clean_stopwords(clist) 277 | sent = self.clean_punctuation(sent) 278 | 279 | return sent 280 | 281 | 282 | class ChineseWordProcessor(ChineseProcessor): 283 | 284 | def __init__(self, stopwords_path="", userdict_path="", *args, **kwargs): 285 | super().__init__(*args, **kwargs) 286 | if userdict_path and os.path.exists(userdict_path): 287 | jieba.load_userdict(str(userdict_path)) 288 | self.reset(stopwords_path) 289 | 290 | def __call__(self, sent): 291 | sent = ptxt.Text(sent, "whi").clean 292 | sent = self.clean_linkpic(sent) 293 | 294 | sent = self.clean_english(sent) 295 | 296 | sent = self.clean_date(sent) 297 | sent = self.clean_time(sent) 298 | 299 | sent = self.clean_money(sent) 300 | sent = self.clean_weight(sent) 301 | sent = self.clean_concentration(sent) 302 | 303 | sent = self.clean_entity(sent) 304 | 305 | sent = self.clean_nums(sent) 306 | 307 | wlist = jieba.lcut(sent) 308 | sent = self.clean_stopwords(wlist) 309 | sent = self.clean_punctuation(sent) 310 | 311 | return sent 312 | 313 | 314 | class Chinese2Arabic: 315 | """ 316 | Chinese_to_arabic 317 | modifed from https://github.com/bamtercelboo/corpus_process_script/blob/master/cn_to_arabic/cn_to_arabic.py 318 | """ 319 | 320 | def __init__(self): 321 | self.CN_NUM = CN_NUM 322 | self.CN_UNIT = CN_UNIT 323 | 324 | def __call__(self, cn: str): 325 | unit = 0 326 | ldig = [] 327 | for cndig in reversed(cn): 328 | if cndig in self.CN_UNIT: 329 | unit = self.CN_UNIT.get(cndig) 330 | if unit == 10000 or unit == 100000000: 331 | ldig.append(unit) 332 | unit = 1 333 | else: 334 | dig = self.CN_NUM.get(cndig) 335 | if unit: 336 | dig *= unit 337 | unit = 0 338 | ldig.append(dig) 339 | if unit == 10: 340 | ldig.append(10) 341 | val, tmp = 0, 0 342 | for x in reversed(ldig): 343 | if x == 10000 or x == 100000000: 344 | val += tmp * x 345 | tmp = 0 346 | else: 347 | tmp += x 348 | val += tmp 349 | if val == 0 and cn != "零": 350 | return cn 351 | else: 352 | return val 353 | 354 | 355 | if __name__ == '__main__': 356 | ccp = ChineseCharProcessor(stopwords_path="../dicts/stopwords.txt") 357 | cwp = ChineseWordProcessor(stopwords_path="../dicts/stopwords.txt") 358 | text = """ 359 | 一元,三里,十元。 360 | 朱镕基总理不错。张三去爬珠穆朗玛峰。 361 | 多福多寿一千万元啊,这是两百元。给你。我与你,也好。19.42万元,共8万元。18.32,万,千,百,亿元。 362 | 123千克,三百二十千克,两百多千克,一百二十公斤,1万千克,20000千克,好多。 363 | 3043克白粉,20斤白面,3000吨钢材,三千吨钢材。 364 | 浓度达214,浓度分别超国家规定的排放标准8.38 365 | """ 366 | 367 | res = ccp(text) 368 | print(res) 369 | 370 | res = cwp(text) 371 | print(res) 372 | 373 | print(text) 374 | 375 
| print(ccp.clean_money(text)) 376 | print(ccp.clean_weight(text)) 377 | -------------------------------------------------------------------------------- /mltc/dicts/stopwords.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | " 4 | # 5 | $ 6 | & 7 | ' 8 | ( 9 | ) 10 | * 11 | + 12 | , 13 | - 14 | -- 15 | . 16 | ... 17 | ...... 18 | ................... 19 | ./ 20 | .一 21 | .数 22 | .日 23 | / 24 | // 25 | 0 26 | 1 27 | 2 28 | 3 29 | 4 30 | 5 31 | 6 32 | 7 33 | 8 34 | 9 35 | : 36 | :// 37 | :: 38 | ; 39 | < 40 | = 41 | > 42 | ? 43 | @ 44 | Lex 45 | [ 46 | ] 47 | _ 48 | all 49 | an 50 | and 51 | are 52 | as 53 | at 54 | be 55 | by 56 | can 57 | exp 58 | for 59 | from 60 | has 61 | have 62 | if 63 | in 64 | is 65 | it 66 | not 67 | of 68 | on 69 | one 70 | or 71 | sub 72 | sup 73 | that 74 | the 75 | then 76 | this 77 | to 78 | we 79 | which 80 | with 81 | you 82 | } 83 | ~~~~ 84 | · 85 | × 86 | ××× 87 | Δ 88 | Ψ 89 | γ 90 | μ 91 | φ 92 | φ. 93 | В 94 | — 95 | —— 96 | ——— 97 | ‘ 98 | ’ 99 | ’‘ 100 | “ 101 | ” 102 | ”, 103 | … 104 | …… 105 | …………………………………………………③ 106 | ′∈ 107 | ′| 108 | ℃ 109 | Ⅲ 110 | ↑ 111 | → 112 | ∈[ 113 | ∪φ∈ 114 | ≈ 115 | ① 116 | ② 117 | ②c 118 | ③ 119 | ③] 120 | ④ 121 | ⑤ 122 | ⑥ 123 | ⑦ 124 | ⑧ 125 | ⑨ 126 | ⑩ 127 | ── 128 | ■ 129 | ▲ 130 | 、 131 | 。 132 | 〉 133 | 《 134 | 》 135 | 》), 136 | 「 137 | 」 138 | 『 139 | 』 140 | 【 141 | 】 142 | 〔 143 | 〕 144 | 〕〔 145 | ㈧ 146 | 一 147 | 一. 148 | 一一 149 | 一个 150 | 一些 151 | 一何 152 | 一個 153 | 一切 154 | 一则 155 | 一则通过 156 | 一方面 157 | 一旦 158 | 一来 159 | 一样 160 | 一番 161 | 一直 162 | 一般 163 | 一转眼 164 | 万一 165 | 三天两头 166 | 三番两次 167 | 三番五次 168 | 上 169 | 上下 170 | 上去 171 | 上来 172 | 下 173 | 不 174 | 不下 175 | 不了 176 | 不亦乐乎 177 | 不仅 178 | 不仅仅 179 | 不仅仅是 180 | 不会 181 | 不但 182 | 不光 183 | 不免 184 | 不再 185 | 不力 186 | 不单 187 | 不只 188 | 不可开交 189 | 不可抗拒 190 | 不同 191 | 不外 192 | 不外乎 193 | 不大 194 | 不如 195 | 不妨 196 | 不定 197 | 不对 198 | 不少 199 | 不尽 200 | 不尽然 201 | 不巧 202 | 不已 203 | 不常 204 | 不得 205 | 不得不 206 | 不得了 207 | 不得已 208 | 不必 209 | 不怎么 210 | 不怕 211 | 不惟 212 | 不成 213 | 不拘 214 | 不择手段 215 | 不料 216 | 不日 217 | 不时 218 | 不是 219 | 不曾 220 | 不止 221 | 不止一次 222 | 不比 223 | 不消 224 | 不满 225 | 不然 226 | 不然的话 227 | 不特 228 | 不独 229 | 不由得 230 | 不知不觉 231 | 不管 232 | 不管怎样 233 | 不经意 234 | 不胜 235 | 不能 236 | 不能不 237 | 不至于 238 | 不若 239 | 不要 240 | 不论 241 | 不起 242 | 不过 243 | 不迭 244 | 不问 245 | 不限 246 | 与 247 | 与其 248 | 与其说 249 | 与否 250 | 与此同时 251 | 且 252 | 且不说 253 | 且说 254 | 两者 255 | 个 256 | 个人 257 | 个别 258 | 临 259 | 临到 260 | 为 261 | 为了 262 | 为什么 263 | 为何 264 | 为止 265 | 为此 266 | 为着 267 | 举凡 268 | 乃 269 | 乃至 270 | 乃至于 271 | 么 272 | 之 273 | 之一 274 | 之所以 275 | 之类 276 | 乌乎 277 | 乎 278 | 乘 279 | 乘势 280 | 乘机 281 | 乘虚 282 | 乘隙 283 | 也 284 | 也好 285 | 也就是说 286 | 也罢 287 | 了 288 | 二来 289 | 二话不说 290 | 二话没说 291 | 于 292 | 于是 293 | 于是乎 294 | 云云 295 | 云尔 296 | 互相 297 | 些 298 | 交口 299 | 亦 300 | 亲口 301 | 亲手 302 | 亲眼 303 | 亲自 304 | 亲身 305 | 人 306 | 人人 307 | 人们 308 | 人家 309 | 什么 310 | 什么样 311 | 今 312 | 介于 313 | 仍 314 | 仍旧 315 | 仍然 316 | 从 317 | 从不 318 | 从严 319 | 从中 320 | 从今以后 321 | 从优 322 | 从古到今 323 | 从古至今 324 | 从头 325 | 从宽 326 | 从小 327 | 从新 328 | 从无到有 329 | 从早到晚 330 | 从未 331 | 从来 332 | 从此 333 | 从此以后 334 | 从而 335 | 从轻 336 | 从速 337 | 从重 338 | 他 339 | 他人 340 | 他们 341 | 他們 342 | 他是 343 | 以 344 | 以上 345 | 以为 346 | 以便 347 | 以免 348 | 以及 349 | 以故 350 | 以期 351 | 以来 352 | 以至 353 | 以至于 354 | 以致 355 | 们 356 | 任 357 | 任何 358 | 任凭 359 | 伙同 360 | 会 361 | 传说 362 | 传闻 363 | 似的 364 | 但 365 | 但凡 366 | 但愿 367 | 但是 368 | 何 369 | 何乐而不为 370 | 何以 371 | 何况 372 | 何处 373 | 何妨 374 | 何尝 
375 | 何必 376 | 何时 377 | 何止 378 | 何苦 379 | 何须 380 | 余外 381 | 作为 382 | 你 383 | 你们 384 | 你們 385 | 你是 386 | 使 387 | 使得 388 | 例如 389 | 依 390 | 依据 391 | 依照 392 | 便于 393 | 俺 394 | 俺们 395 | 倍加 396 | 倍感 397 | 倒不如 398 | 倒不如说 399 | 倒是 400 | 倘 401 | 倘使 402 | 倘或 403 | 倘然 404 | 倘若 405 | 借 406 | 借以 407 | 借此 408 | 假使 409 | 假如 410 | 假若 411 | 偏偏 412 | 偶尔 413 | 偶而 414 | 傥然 415 | 像 416 | 儿 417 | 元/吨 418 | 充其极 419 | 充其量 420 | 充分 421 | 先不先 422 | 光是 423 | 全体 424 | 全力 425 | 全年 426 | 全然 427 | 全身心 428 | 全部 429 | 全都 430 | 八成 431 | 公然 432 | 兮 433 | 共总 434 | 关于 435 | 其 436 | 其一 437 | 其中 438 | 其二 439 | 其他 440 | 其余 441 | 其后 442 | 其它 443 | 其实 444 | 其次 445 | 具体地说 446 | 具体来说 447 | 具体说来 448 | 兼之 449 | 内 450 | 再 451 | 再其次 452 | 再则 453 | 再有 454 | 再次 455 | 再者 456 | 再者说 457 | 再说 458 | 冒 459 | 冲 460 | 决不 461 | 决非 462 | 况且 463 | 凑巧 464 | 凝神 465 | 几 466 | 几乎 467 | 几度 468 | 几时 469 | 几番 470 | 几经 471 | 凡 472 | 凡是 473 | 凭 474 | 凭借 475 | 出于 476 | 出去 477 | 出来 478 | 分别 479 | 分头 480 | 分期分批 481 | 切不可 482 | 切切 483 | 切勿 484 | 切莫 485 | 则 486 | 则甚 487 | 刚好 488 | 刚巧 489 | 刚才 490 | 别 491 | 别人 492 | 别处 493 | 别是 494 | 别的 495 | 别管 496 | 别说 497 | 到 498 | 到了儿 499 | 到处 500 | 到头 501 | 到头来 502 | 到底 503 | 到目前为止 504 | 前后 505 | 前此 506 | 前者 507 | 加上 508 | 加之 509 | 加以 510 | 动不动 511 | 动辄 512 | 勃然 513 | 匆匆 514 | 千万千万 515 | 单单 516 | 单纯 517 | 即 518 | 即令 519 | 即使 520 | 即便 521 | 即刻 522 | 即如 523 | 即将 524 | 即或 525 | 即是说 526 | 即若 527 | 却 528 | 去 529 | 又 530 | 又及 531 | 及 532 | 及其 533 | 及至 534 | 反之 535 | 反之亦然 536 | 反之则 537 | 反倒 538 | 反倒是 539 | 反手 540 | 反而 541 | 反过来 542 | 反过来说 543 | 取道 544 | 受到 545 | 另 546 | 另一个 547 | 另一方面 548 | 另外 549 | 另悉 550 | 另方面 551 | 另行 552 | 只 553 | 只当 554 | 只怕 555 | 只是 556 | 只有 557 | 只消 558 | 只要 559 | 只限 560 | 叫 561 | 叮咚 562 | 可 563 | 可以 564 | 可好 565 | 可是 566 | 可能 567 | 可见 568 | 各 569 | 各个 570 | 各位 571 | 各式 572 | 各种 573 | 各自 574 | 同 575 | 同时 576 | 后 577 | 后来 578 | 后者 579 | 向 580 | 向使 581 | 向着 582 | 吓 583 | 吗 584 | 否则 585 | 吧 586 | 吧哒 587 | 吱 588 | 呀 589 | 呃 590 | 呆呆地 591 | 呕 592 | 呗 593 | 呜 594 | 呜呼 595 | 呢 596 | 呵 597 | 呵呵 598 | 呸 599 | 呼哧 600 | 呼啦 601 | 咋 602 | 和 603 | 咚 604 | 咦 605 | 咧 606 | 咱 607 | 咱们 608 | 咳 609 | 哇 610 | 哈 611 | 哈哈 612 | 哉 613 | 哎 614 | 哎呀 615 | 哎哟 616 | 哗 617 | 哗啦 618 | 哟 619 | 哦 620 | 哩 621 | 哪 622 | 哪个 623 | 哪些 624 | 哪儿 625 | 哪天 626 | 哪年 627 | 哪怕 628 | 哪样 629 | 哪边 630 | 哪里 631 | 哼 632 | 哼唷 633 | 唉 634 | 唯有 635 | 啊 636 | 啊呀 637 | 啊哈 638 | 啊哟 639 | 啐 640 | 啥 641 | 啦 642 | 啪达 643 | 啷当 644 | 喂 645 | 喏 646 | 喔唷 647 | 喽 648 | 嗡 649 | 嗡嗡 650 | 嗬 651 | 嗯 652 | 嗳 653 | 嘎 654 | 嘎嘎 655 | 嘎登 656 | 嘘 657 | 嘛 658 | 嘻 659 | 嘿 660 | 嘿嘿 661 | 因 662 | 因为 663 | 因了 664 | 因此 665 | 因着 666 | 因而 667 | 固然 668 | 在 669 | 在下 670 | 在于 671 | 地 672 | 基于 673 | 基本 674 | 基本上 675 | 处在 676 | 处处 677 | 多 678 | 多么 679 | 多亏 680 | 多多 681 | 多多少少 682 | 多多益善 683 | 多少 684 | 多年前 685 | 多年来 686 | 多次 687 | 够瞧的 688 | 大 689 | 大不了 690 | 大举 691 | 大体上 692 | 大凡 693 | 大多 694 | 大大 695 | 大家 696 | 大张旗鼓 697 | 大抵 698 | 大概 699 | 大略 700 | 大约 701 | 大致 702 | 大都 703 | 大面儿上 704 | 奋勇 705 | 她 706 | 她们 707 | 她們 708 | 她是 709 | 好 710 | 好在 711 | 如 712 | 如上 713 | 如上所述 714 | 如下 715 | 如今 716 | 如何 717 | 如其 718 | 如前所述 719 | 如同 720 | 如常 721 | 如是 722 | 如期 723 | 如果 724 | 如次 725 | 如此 726 | 如此等等 727 | 如若 728 | 妳們 729 | 始而 730 | 姑且 731 | 存心 732 | 孰料 733 | 孰知 734 | 宁 735 | 宁可 736 | 宁愿 737 | 宁肯 738 | 它 739 | 它们 740 | 它是 741 | 对 742 | 对于 743 | 对待 744 | 对方 745 | 对比 746 | 将 747 | 将才 748 | 将要 749 | 将近 750 | 小 751 | 尔 752 | 尔后 753 | 尔尔 754 | 尔等 755 | 尚且 756 | 就 757 | 就地 758 | 就是 759 | 就是了 760 | 就是说 761 | 就此 762 | 就算 763 | 就要 764 | 尽 765 | 尽可能 766 | 尽如人意 767 | 尽心尽力 768 | 尽心竭力 769 | 尽快 770 | 尽早 771 | 尽然 772 | 
尽管 773 | 尽管如此 774 | 尽量 775 | 局外 776 | 居然 777 | 届时 778 | 屡屡 779 | 屡次 780 | 屡次三番 781 | 岂但 782 | 岂止 783 | 岂非 784 | 川流不息 785 | 差一点 786 | 差不多 787 | 己 788 | 已 789 | 已矣 790 | 巴 791 | 巴巴 792 | 常言说 793 | 常言说得好 794 | 常言道 795 | 平素 796 | 年复一年 797 | 并 798 | 并且 799 | 并排 800 | 并无 801 | 并没 802 | 并没有 803 | 并肩 804 | 并非 805 | 庶乎 806 | 庶几 807 | 开外 808 | 开始 809 | 弹指之间 810 | 归 811 | 归根到底 812 | 归根结底 813 | 归齐 814 | 当 815 | 当下 816 | 当中 817 | 当儿 818 | 当即 819 | 当口儿 820 | 当地 821 | 当场 822 | 当头 823 | 当庭 824 | 当然 825 | 当真 826 | 当着 827 | 彻夜 828 | 彼 829 | 彼时 830 | 彼此 831 | 往 832 | 待 833 | 待到 834 | 很 835 | 很多 836 | 很少 837 | 得 838 | 得了 839 | 得天独厚 840 | 得起 841 | 必定 842 | 必将 843 | 必须 844 | 快要 845 | 忽地 846 | 忽然 847 | 怎 848 | 怎么 849 | 怎么办 850 | 怎么样 851 | 怎奈 852 | 怎样 853 | 急匆匆 854 | 怪不得 855 | 总之 856 | 总的来看 857 | 总的来说 858 | 总的说来 859 | 总而言之 860 | 恍然 861 | 恐怕 862 | 恰似 863 | 恰好 864 | 恰如 865 | 恰巧 866 | 恰恰 867 | 恰恰相反 868 | 恰逢 869 | 您 870 | 您们 871 | 您是 872 | 惟其 873 | 惯常 874 | 愤然 875 | 慢说 876 | 成年累月 877 | 成心 878 | 我 879 | 我们 880 | 我們 881 | 我是 882 | 或 883 | 或则 884 | 或多或少 885 | 或是 886 | 或曰 887 | 或者 888 | 或许 889 | 截然 890 | 截至 891 | 所 892 | 所以 893 | 所在 894 | 所幸 895 | 所有 896 | 才 897 | 才能 898 | 扑通 899 | 打 900 | 打从 901 | 打开天窗说亮话 902 | 把 903 | 抑或 904 | 抽冷子 905 | 拦腰 906 | 拿 907 | 按 908 | 按时 909 | 按期 910 | 按照 911 | 按理 912 | 按说 913 | 挨个 914 | 挨家挨户 915 | 挨次 916 | 挨着 917 | 挨门挨户 918 | 挨门逐户 919 | 换句话说 920 | 换言之 921 | 据 922 | 据实 923 | 据悉 924 | 据我所知 925 | 据此 926 | 据称 927 | 据说 928 | 接下来 929 | 接着 930 | 接连不断 931 | 故 932 | 故意 933 | 故此 934 | 故而 935 | 敞开儿 936 | 敢于 937 | 敢情 938 | 数/ 939 | 断然 940 | 方才 941 | 方能 942 | 旁人 943 | 无 944 | 无宁 945 | 无论 946 | 既 947 | 既往 948 | 既是 949 | 既然 950 | 日复一日 951 | 日渐 952 | 日益 953 | 日臻 954 | 日见 955 | 时候 956 | 昂然 957 | 是 958 | 是以 959 | 是否 960 | 是的 961 | 暗中 962 | 暗地里 963 | 暗自 964 | 更为 965 | 更加 966 | 更进一步 967 | 曾 968 | 替 969 | 替代 970 | 最 971 | 最后 972 | 有 973 | 有些 974 | 有关 975 | 有及 976 | 有时 977 | 有的 978 | 有的是 979 | 望 980 | 朝 981 | 朝着 982 | 末##末 983 | 本 984 | 本人 985 | 本地 986 | 本着 987 | 本身 988 | 权时 989 | 来 990 | 来不及 991 | 来得及 992 | 来看 993 | 来着 994 | 来自 995 | 来讲 996 | 来说 997 | 极为 998 | 极了 999 | 极其 1000 | 极力 1001 | 极大 1002 | 极度 1003 | 极端 1004 | 果然 1005 | 果真 1006 | 某 1007 | 某个 1008 | 某些 1009 | 某某 1010 | 根据 1011 | 格外 1012 | 次第 1013 | 欤 1014 | 正值 1015 | 正如 1016 | 正巧 1017 | 正是 1018 | 此 1019 | 此中 1020 | 此后 1021 | 此地 1022 | 此处 1023 | 此外 1024 | 此时 1025 | 此次 1026 | 此间 1027 | 毋宁 1028 | 每 1029 | 每当 1030 | 每时每刻 1031 | 每每 1032 | 每逢 1033 | 比 1034 | 比及 1035 | 比如 1036 | 比如说 1037 | 比方 1038 | 比照 1039 | 比起 1040 | 毕竟 1041 | 毫不 1042 | 毫无 1043 | 毫无例外 1044 | 毫无保留地 1045 | 沒有 1046 | 沙沙 1047 | 没奈何 1048 | 没有 1049 | 沿 1050 | 沿着 1051 | 漫说 1052 | 焉 1053 | 然则 1054 | 然后 1055 | 然而 1056 | 照 1057 | 照着 1058 | 牢牢 1059 | 犹且 1060 | 犹自 1061 | 独自 1062 | 猛然 1063 | 猛然间 1064 | 率尔 1065 | 率然 1066 | 理应 1067 | 理当 1068 | 理该 1069 | 瑟瑟 1070 | 甚且 1071 | 甚么 1072 | 甚或 1073 | 甚而 1074 | 甚至 1075 | 甚至于 1076 | 用 1077 | 用来 1078 | 甭 1079 | 由 1080 | 由于 1081 | 由是 1082 | 由此 1083 | 由此可见 1084 | 略为 1085 | 略加 1086 | 略微 1087 | 的 1088 | 的确 1089 | 的话 1090 | 皆可 1091 | 直到 1092 | 相对而言 1093 | 省得 1094 | 看 1095 | 看上去 1096 | 看来 1097 | 看样子 1098 | 看起来 1099 | 眨眼 1100 | 着 1101 | 着呢 1102 | 矣 1103 | 矣乎 1104 | 矣哉 1105 | 砰 1106 | 碰巧 1107 | 离 1108 | 种 1109 | 究竟 1110 | 穷年累月 1111 | 立刻 1112 | 立地 1113 | 立时 1114 | 立马 1115 | 竟然 1116 | 竟而 1117 | 第 1118 | 第二 1119 | 等 1120 | 等到 1121 | 等等 1122 | 策略地 1123 | 简直 1124 | 简而言之 1125 | 简言之 1126 | 管 1127 | 类如 1128 | 精光 1129 | 紧接着 1130 | 累年 1131 | 累次 1132 | 纯粹 1133 | 纵 1134 | 纵令 1135 | 纵使 1136 | 纵然 1137 | 经 1138 | 经常 1139 | 经过 1140 | 结果 1141 | 给 1142 | 绝不 1143 | 绝对 1144 | 绝非 1145 | 绝顶 1146 | 
继之 1147 | 继后 1148 | 继而 1149 | 综上所述 1150 | 缕缕 1151 | 罢了 1152 | 老是 1153 | 老老实实 1154 | 者 1155 | 而 1156 | 而且 1157 | 而况 1158 | 而又 1159 | 而后 1160 | 而外 1161 | 而已 1162 | 而是 1163 | 而言 1164 | 而论 1165 | 联袂 1166 | 背地里 1167 | 背靠背 1168 | 能 1169 | 能否 1170 | 腾 1171 | 自 1172 | 自个儿 1173 | 自从 1174 | 自各儿 1175 | 自后 1176 | 自家 1177 | 自己 1178 | 自打 1179 | 自身 1180 | 至 1181 | 至于 1182 | 至今 1183 | 至若 1184 | 致 1185 | 與 1186 | 般的 1187 | 若 1188 | 若夫 1189 | 若是 1190 | 若果 1191 | 若非 1192 | 莫不 1193 | 莫不然 1194 | 莫如 1195 | 莫若 1196 | 莫非 1197 | 著 1198 | 藉以 1199 | 虽 1200 | 虽则 1201 | 虽然 1202 | 虽说 1203 | 被 1204 | 要 1205 | 要不 1206 | 要不是 1207 | 要不然 1208 | 要么 1209 | 要是 1210 | 譬喻 1211 | 譬如 1212 | 让 1213 | 许多 1214 | 论 1215 | 论说 1216 | 设使 1217 | 设或 1218 | 设若 1219 | 诚如 1220 | 诚然 1221 | 话说 1222 | 该 1223 | 该当 1224 | 说来 1225 | 请勿 1226 | 诸 1227 | 诸位 1228 | 诸如 1229 | 谁 1230 | 谁人 1231 | 谁料 1232 | 谁知 1233 | 豁然 1234 | 贼死 1235 | 赖以 1236 | 赶 1237 | 赶快 1238 | 赶早不赶晚 1239 | 起 1240 | 起先 1241 | 起初 1242 | 起头 1243 | 起来 1244 | 起见 1245 | 起首 1246 | 趁 1247 | 趁便 1248 | 趁势 1249 | 趁早 1250 | 趁机 1251 | 趁热 1252 | 趁着 1253 | 越是 1254 | 距 1255 | 跟 1256 | 路经 1257 | 轰然 1258 | 较 1259 | 较为 1260 | 较之 1261 | 较比 1262 | 边 1263 | 达旦 1264 | 过 1265 | 过于 1266 | 近几年来 1267 | 近年来 1268 | 近来 1269 | 还 1270 | 还是 1271 | 还有 1272 | 还要 1273 | 这 1274 | 这一来 1275 | 这个 1276 | 这么 1277 | 这么些 1278 | 这么样 1279 | 这么点儿 1280 | 这些 1281 | 这会儿 1282 | 这儿 1283 | 这就是说 1284 | 这时 1285 | 这样 1286 | 这次 1287 | 这般 1288 | 这边 1289 | 这里 1290 | 进去 1291 | 进来 1292 | 进而 1293 | 连 1294 | 连同 1295 | 连声 1296 | 连日 1297 | 连日来 1298 | 连袂 1299 | 连连 1300 | 迟早 1301 | 迫于 1302 | 逐步 1303 | 通过 1304 | 遵循 1305 | 遵照 1306 | 那 1307 | 那个 1308 | 那么 1309 | 那么些 1310 | 那么样 1311 | 那些 1312 | 那会儿 1313 | 那儿 1314 | 那时 1315 | 那末 1316 | 那样 1317 | 那般 1318 | 那边 1319 | 那里 1320 | 都 1321 | 鄙人 1322 | 鉴于 1323 | 针对 1324 | 长期以来 1325 | 长此下去 1326 | 长话短说 1327 | 间或 1328 | 阿 1329 | 陡然 1330 | 除 1331 | 除了 1332 | 除却 1333 | 除去 1334 | 除外 1335 | 除开 1336 | 除此 1337 | 除此之外 1338 | 除此以外 1339 | 除此而外 1340 | 除非 1341 | 随 1342 | 随后 1343 | 随时 1344 | 随着 1345 | 隔夜 1346 | 隔日 1347 | 难得 1348 | 难怪 1349 | 难说 1350 | 难道 1351 | 难道说 1352 | 非但 1353 | 非常 1354 | 非徒 1355 | 非得 1356 | 非特 1357 | 非独 1358 | 靠 1359 | 顶多 1360 | 顷刻 1361 | 顷刻之间 1362 | 顷刻间 1363 | 顺 1364 | 顺着 1365 | 顿时 1366 | 风雨无阻 1367 | 首先 1368 | 马上 1369 | 高低 1370 | 默然 1371 | 默默地 1372 | ! 1373 | # 1374 | % 1375 | & 1376 | ' 1377 | ( 1378 | ) 1379 | )÷(1- 1380 | )、 1381 | * 1382 | + 1383 | +ξ 1384 | ++ 1385 | , 1386 | ,也 1387 | - 1388 | -β 1389 | -- 1390 | -[*]- 1391 | . 1392 | / 1393 | 0:2 1394 | 1. 1395 | 12% 1396 | 2.3% 1397 | 5:0 1398 | : 1399 | ; 1400 | < 1401 | <± 1402 | <Δ 1403 | <λ 1404 | <φ 1405 | << 1406 | = 1407 | =″ 1408 | =☆ 1409 | =( 1410 | =- 1411 | =[ 1412 | ={ 1413 | > 1414 | >λ 1415 | ? 1416 | A 1417 | LI 1418 | R.L. 
1419 | ZXFITL 1420 | [ 1421 | [①①] 1422 | [①②] 1423 | [①③] 1424 | [①④] 1425 | [①⑤] 1426 | [①⑥] 1427 | [①⑦] 1428 | [①⑧] 1429 | [①⑨] 1430 | [①A] 1431 | [①B] 1432 | [①C] 1433 | [①D] 1434 | [①E] 1435 | [①] 1436 | [①a] 1437 | [①c] 1438 | [①d] 1439 | [①e] 1440 | [①f] 1441 | [①g] 1442 | [①h] 1443 | [①i] 1444 | [①o] 1445 | [② 1446 | [②①] 1447 | [②②] 1448 | [②③] 1449 | [②④ 1450 | [②⑤] 1451 | [②⑥] 1452 | [②⑦] 1453 | [②⑧] 1454 | [②⑩] 1455 | [②B] 1456 | [②G] 1457 | [②] 1458 | [②a] 1459 | [②b] 1460 | [②c] 1461 | [②d] 1462 | [②e] 1463 | [②f] 1464 | [②g] 1465 | [②h] 1466 | [②i] 1467 | [②j] 1468 | [③①] 1469 | [③⑩] 1470 | [③F] 1471 | [③] 1472 | [③a] 1473 | [③b] 1474 | [③c] 1475 | [③d] 1476 | [③e] 1477 | [③g] 1478 | [③h] 1479 | [④] 1480 | [④a] 1481 | [④b] 1482 | [④c] 1483 | [④d] 1484 | [④e] 1485 | [⑤] 1486 | [⑤]] 1487 | [⑤a] 1488 | [⑤b] 1489 | [⑤d] 1490 | [⑤e] 1491 | [⑤f] 1492 | [⑥] 1493 | [⑦] 1494 | [⑧] 1495 | [⑨] 1496 | [⑩] 1497 | [*] 1498 | [- 1499 | [] 1500 | ] 1501 | ]∧′=[ 1502 | ][ 1503 | _ 1504 | a] 1505 | b] 1506 | c] 1507 | e] 1508 | f] 1509 | ng昉 1510 | {- 1511 | } 1512 | }> 1513 | ~ 1514 | ~± 1515 | ~+ 1516 | -------------------------------------------------------------------------------- /mltc/main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from argparse import ArgumentParser 3 | import os 4 | import warnings 5 | import torch 6 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 7 | from pytorch_transformers import AdamW, WarmupLinearSchedule 8 | from pnlp import piop 9 | import pandas as pd 10 | 11 | from configs.basic_config import config 12 | from utils.utils import seed_everything, init_logger, logger, AttrDict 13 | from preprocessors.processor import Preprocessor 14 | from postprocessors.processor import Postprocessor 15 | from scheme.error import PipelineReadError 16 | from callback.training_monitor import TrainingMonitor 17 | from callback.model_checkpoint import ModelCheckpoint 18 | from models.model import Classifier 19 | from train.losses import BCEWithLogLoss 20 | from train.trainer import Trainer 21 | from train.metrics import AUC, AccuracyThresh, MultiLabelReport 22 | 23 | import sys 24 | 25 | root = os.path.dirname(os.path.abspath(__file__)) 26 | sys.path.append(root) 27 | 28 | warnings.simplefilter('ignore') 29 | 30 | 31 | def train(args): 32 | ########### data ########### 33 | processor = Postprocessor(config["postprocessor"])( 34 | do_lower_case=args.do_lower_case) 35 | label_list = processor.get_labels(config['data_dir'] / "labels.txt") 36 | id2label = {i: label for i, label in enumerate(label_list)} 37 | 38 | train_data = processor.get_train( 39 | config['data_dir'] / "{}.train.pkl".format(args.data_name)) 40 | train_examples = processor.create_examples( 41 | lines=train_data, 42 | example_type='train', 43 | cached_examples_file=config[ 44 | "data_dir"] / "cached_train_examples_{}".format(args.pretrain)) 45 | train_features = processor.create_features( 46 | examples=train_examples, 47 | max_seq_len=args.train_max_seq_len, 48 | cached_features_file=config[ 49 | "data_dir"] / "cached_train_features_{}_{}".format( 50 | args.train_max_seq_len, args.pretrain)) 51 | 52 | train_dataset = processor.create_dataset( 53 | train_features, is_sorted=args.sorted) 54 | if args.sorted: 55 | train_sampler = SequentialSampler(train_dataset) 56 | else: 57 | train_sampler = RandomSampler(train_dataset) 58 | train_dataloader = DataLoader( 59 | train_dataset, sampler=train_sampler, 
batch_size=args.train_batch_size) 60 | 61 | valid_data = processor.get_dev( 62 | config["data_dir"] / "{}.valid.pkl".format(args.data_name)) 63 | valid_examples = processor.create_examples( 64 | lines=valid_data, 65 | example_type='valid', 66 | cached_examples_file=config[ 67 | "data_dir"] / "cached_valid_examples_{}".format(args.pretrain)) 68 | valid_features = processor.create_features( 69 | examples=valid_examples, 70 | max_seq_len=args.eval_max_seq_len, 71 | cached_features_file=config[ 72 | "data_dir"] / "cached_valid_features_{}_{}".format( 73 | args.eval_max_seq_len, args.pretrain)) 74 | valid_dataset = processor.create_dataset(valid_features) 75 | valid_sampler = SequentialSampler(valid_dataset) 76 | valid_dataloader = DataLoader( 77 | valid_dataset, sampler=valid_sampler, batch_size=args.eval_batch_size) 78 | 79 | if config["pretrain"] == "Nopretrain": 80 | config["vocab_size"] = processor.vocab_size 81 | 82 | ########### model ########### 83 | logger.info("========= initializing model =========") 84 | if args.resume_path: 85 | resume_path = Path(args.resume_path) 86 | model = Classifier( 87 | config["classifier"], config["pretrain"], resume_path)( 88 | num_labels=len(label_list)) 89 | else: 90 | model = Classifier( 91 | config["classifier"], config["pretrain"], "")( 92 | num_labels=len(label_list)) 93 | 94 | t_total = int(len(train_dataloader) / 95 | args.gradient_accumulation_steps * args.epochs) 96 | param_optimizer = list(model.named_parameters()) 97 | no_decay = ['bias', 'LayerNorm.weight'] 98 | optimizer_grouped_parameters = [ 99 | { 100 | 'params': [p for n, p in param_optimizer if not any( 101 | nd in n for nd in no_decay)], 102 | 'weight_decay': args.weight_decay 103 | }, 104 | { 105 | 'params': [p for n, p in param_optimizer if any( 106 | nd in n for nd in no_decay)], 107 | 'weight_decay': 0.0 108 | } 109 | ] 110 | warmup_steps = int(t_total * args.warmup_proportion) 111 | optimizer = AdamW(optimizer_grouped_parameters, 112 | lr=args.learning_rate, eps=args.adam_epsilon) 113 | lr_scheduler = WarmupLinearSchedule( 114 | optimizer, warmup_steps=warmup_steps, t_total=t_total) 115 | 116 | if args.fp16: 117 | try: 118 | from apex import amp 119 | except ImportError as e: 120 | raise ImportError( 121 | "Please install apex github.com/nvidia/apex to use fp16.") 122 | model, optimizer = amp.initialize( 123 | model, optimizer, opt_level=args.fp16_opt_level) 124 | 125 | ########### callback ########### 126 | logger.info("========= initializing callbacks =========") 127 | train_monitor = TrainingMonitor( 128 | file_dir=config['figure_dir'], arch=args.pretrain) 129 | model_checkpoint = ModelCheckpoint(checkpoint_dir=config['checkpoint_dir'], 130 | mode=args.mode, 131 | monitor=args.monitor, 132 | arch=args.pretrain, 133 | save_best_only=args.save_best) 134 | 135 | ########### train ########### 136 | logger.info("========= Running training =========") 137 | logger.info(" Num examples = {}".format(len(train_examples))) 138 | logger.info(" Num Epochs = {}".format(args.epochs)) 139 | logger.info(" Total train batch size \ 140 | (w. 
parallel, distributed & accumulation) = {}".format(
141 |         args.train_batch_size * args.gradient_accumulation_steps * (
142 |             torch.distributed.get_world_size() if args.local_rank != -1 else 1)
143 |     ))
144 |     logger.info(" Gradient Accumulation steps = {}".format(
145 |         args.gradient_accumulation_steps))
146 |     logger.info(" Total optimization steps = {}".format(t_total))
147 | 
148 |     trainer = Trainer(
149 |         n_gpu=args.n_gpu,
150 |         model=model,
151 |         epochs=args.epochs,
152 |         logger=logger,
153 |         criterion=BCEWithLogLoss(),
154 |         optimizer=optimizer,
155 |         lr_scheduler=lr_scheduler,
156 |         early_stopping=None,
157 |         training_monitor=train_monitor,
158 |         fp16=args.fp16,
159 |         resume_path=args.resume_path,
160 |         grad_clip=args.grad_clip,
161 |         model_checkpoint=model_checkpoint,
162 |         gradient_accumulation_steps=args.gradient_accumulation_steps,
163 |         batch_metrics=[AccuracyThresh(thresh=0.5)],
164 |         epoch_metrics=[AUC(average='micro', task_type='binary'),
165 |                        MultiLabelReport(id2label=id2label)])
166 |     trainer.train(train_data=train_dataloader,
167 |                   valid_data=valid_dataloader,
168 |                   seed=args.seed)
169 | 
170 | 
171 | def test(args):
172 |     from dataio.task_data import TaskData
173 |     from predict.predictor import Predictor
174 |     data = TaskData(args.test_data_num)
175 |     labels, sents = data.read_data(
176 |         raw_data_path=config["test_path"],
177 |         data_dir=config["data_dir"],
178 |         preprocessor=Preprocessor(config["preprocessor"])(
179 |             stopwords_path=config["stopwords_path"],
180 |             userdict_path=config["userdict_path"]),
181 |         is_train=False)
182 |     lines = list(zip(sents, labels))
183 | 
184 |     processor = Postprocessor(config["postprocessor"])(
185 |         do_lower_case=args.do_lower_case)
186 |     label_list = processor.get_labels(config['data_dir'] / "labels.txt")
187 |     id2label = {i: label for i, label in enumerate(label_list)}
188 | 
189 |     test_data = processor.get_test(lines=lines)
190 |     test_examples = processor.create_examples(
191 |         lines=test_data,
192 |         example_type='test',
193 |         cached_examples_file=config[
194 |             'data_dir'] / "cached_test_examples_{}".format(args.pretrain))
195 |     test_features = processor.create_features(
196 |         examples=test_examples,
197 |         max_seq_len=args.eval_max_seq_len,
198 |         cached_features_file=config[
199 |             'data_dir'] / "cached_test_features_{}_{}".format(
200 |                 args.eval_max_seq_len, args.pretrain))
201 | 
202 |     test_dataset = processor.create_dataset(test_features)
203 |     test_sampler = SequentialSampler(test_dataset)
204 |     test_dataloader = DataLoader(
205 |         test_dataset, sampler=test_sampler, batch_size=args.eval_batch_size)
206 | 
207 |     if config["pretrain"] == "Nopretrain":
208 |         config["vocab_size"] = processor.vocab_size
209 | 
210 |     model = Classifier(config["classifier"],
211 |                        config["pretrain"],
212 |                        config["checkpoint_dir"])(
213 |         num_labels=len(label_list))
214 | 
215 |     ########### predict ###########
216 |     logger.info('model predicting...')
217 |     predictor = Predictor(model=model, logger=logger, n_gpu=args.n_gpu)
218 |     logits, y_pred = predictor.predict(data=test_dataloader, thresh=0.5)
219 | 
220 |     pred_labels = []
221 |     for item in y_pred.tolist():
222 |         tmp = []
223 |         for i, v in enumerate(item):
224 |             if v == 1:
225 |                 tmp.append(label_list[i])
226 |         pred_labels.append(",".join(tmp))
227 | 
228 |     assert len(pred_labels) == y_pred.shape[0]
229 |     df_pred_labels = pd.DataFrame(pred_labels, columns=["predict_labels"])
230 | 
231 | 
232 |     df_test_raw = pd.read_csv(config["test_path"])
233 |     if args.test_data_num > 0:
234 |         df_test_raw = 
df_test_raw.head(args.test_data_num) 235 | df_labels = pd.DataFrame(logits, columns=label_list) 236 | df = pd.concat([df_test_raw, df_pred_labels, df_labels], axis=1) 237 | 238 | df.to_csv(config["result"] / "output.csv", index=False) 239 | # from sklearn.metrics import f1_score 240 | # from pnlp import piop 241 | # import numpy as np 242 | # y_pred = (result > 0.5) * 1 243 | # ytest = piop.read_json(config['data_dir'] / "ytest.json") 244 | # if args.test_data_num: 245 | # ytest = ytest[:args.test_data_num] 246 | # y_true = np.array(ytest) 247 | # micro = f1_score(y_true, y_pred, average='micro') 248 | # macro = f1_score(y_true, y_pred, average='macro') 249 | # score = (micro + macro) / 2 250 | # print("Score: micro {}, macro {} Average {}".format(micro, macro, score)) 251 | 252 | 253 | def main(): 254 | 255 | parser = ArgumentParser() 256 | parser.add_argument("--pretrain", default="bert", type=str) 257 | parser.add_argument("--do_data", action="store_true") 258 | parser.add_argument("--do_train", action="store_true") 259 | parser.add_argument("--do_test", action="store_true") 260 | parser.add_argument("--save_best", action="store_true") 261 | parser.add_argument("--do_lower_case", action='store_true') 262 | parser.add_argument("--data_name", default="law", type=str) 263 | parser.add_argument("--train_data_num", default=0, type=int) 264 | parser.add_argument("--test_data_num", default=0, type=int) 265 | parser.add_argument("--epochs", default=5, type=int) 266 | parser.add_argument("--resume_path", default="", type=str) 267 | parser.add_argument("--mode", default="min", type=str) 268 | parser.add_argument("--monitor", default="valid_loss", type=str) 269 | parser.add_argument("--valid_size", default=0.2, type=float) 270 | parser.add_argument("--local_rank", type=int, default=-1) 271 | parser.add_argument("--sorted", default=1, type=int, 272 | help="1 : True 0:False") 273 | parser.add_argument("--n_gpu", type=str, default="0", 274 | help='"0,1,.." 
or "0" or "" ') 275 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 276 | parser.add_argument("--train_batch_size", default=8, type=int) 277 | parser.add_argument("--eval_batch_size", default=8, type=int) 278 | parser.add_argument("--train_max_seq_len", default=256, type=int) 279 | parser.add_argument("--eval_max_seq_len", default=256, type=int) 280 | parser.add_argument("--loss_scale", type=float, default=0) 281 | parser.add_argument("--warmup_proportion", default=0.1, type=int, ) 282 | parser.add_argument("--weight_decay", default=0.01, type=float) 283 | parser.add_argument("--adam_epsilon", default=1e-8, type=float) 284 | parser.add_argument("--grad_clip", default=1.0, type=float) 285 | parser.add_argument("--learning_rate", default=2e-5, type=float) 286 | parser.add_argument("--seed", type=int, default=42) 287 | parser.add_argument("--fp16", action="store_true") 288 | parser.add_argument("--fp16_opt_level", type=str, default="O1") 289 | 290 | args = parser.parse_args() 291 | 292 | try: 293 | pipeline = piop.read_yml("pipeline.yml") 294 | pl = AttrDict(pipeline["pipeline"]) 295 | config["preprocessor"] = pl.preprocessor 296 | config["pretrain"] = pl.pretrain 297 | config["postprocessor"] = pl.postprocessor 298 | config["classifier"] = pl.classifier 299 | except Exception as e: 300 | raise PipelineReadError 301 | 302 | config["checkpoint_dir"] = config["checkpoint_dir"] / config["classifier"] 303 | config["checkpoint_dir"].mkdir(exist_ok=True) 304 | 305 | torch.save(args, config["checkpoint_dir"] / "training_args.bin") 306 | seed_everything(args.seed) 307 | init_logger(log_file=config["log_dir"] / 308 | "{}.log".format(config["classifier"])) 309 | 310 | logger.info("Training/evaluation parameters %s", args) 311 | 312 | if args.do_data: 313 | from dataio.task_data import TaskData 314 | data = TaskData(args.train_data_num) 315 | labels, sents = data.read_data( 316 | raw_data_path=config["raw_data_path"], 317 | data_dir=config["data_dir"], 318 | preprocessor=Preprocessor(config["preprocessor"])( 319 | stopwords_path=config["stopwords_path"], 320 | userdict_path=config["userdict_path"]), 321 | is_train=True) 322 | data.train_val_split(X=sents, y=labels, 323 | valid_size=args.valid_size, 324 | data_dir=config["data_dir"], 325 | data_name=args.data_name) 326 | if config["pretrain"] == "Nopretrain": 327 | data.build_vocab( 328 | config["nopretrain_vocab_path"], sents, min_count=5) 329 | 330 | if args.do_train: 331 | train(args) 332 | 333 | if args.do_test: 334 | test(args) 335 | 336 | 337 | if __name__ == "__main__": 338 | main() 339 | --------------------------------------------------------------------------------