├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── collator.cpython-36.pyc │ ├── dataset.cpython-36.pyc │ └── classification_dataset.cpython-36.pyc ├── data_preprocessor.py ├── classification_dataset.py ├── collator.py └── dataset.py ├── dict_msf ├── doc_char.dict ├── doc_token.dict ├── doc_topic.dict ├── doc_keyword.dict ├── doc_token_ngram.dict └── doc_label.dict ├── requirements.txt ├── __pycache__ ├── util.cpython-36.pyc └── config.cpython-36.pyc ├── model ├── __pycache__ │ ├── rnn.cpython-36.pyc │ ├── layers.cpython-36.pyc │ ├── loss.cpython-36.pyc │ ├── attention.cpython-36.pyc │ ├── embedding.cpython-36.pyc │ ├── model_util.cpython-36.pyc │ ├── optimizer.cpython-36.pyc │ └── transformer_encoder.cpython-36.pyc ├── classification │ ├── __pycache__ │ │ ├── dpcnn.cpython-36.pyc │ │ ├── drnn.cpython-36.pyc │ │ ├── fasttext.cpython-36.pyc │ │ ├── textcnn.cpython-36.pyc │ │ ├── textrcnn.cpython-36.pyc │ │ ├── textrnn.cpython-36.pyc │ │ ├── classifier.cpython-36.pyc │ │ ├── textvdcnn.cpython-36.pyc │ │ ├── transformer.cpython-36.pyc │ │ ├── region_embedding.cpython-36.pyc │ │ └── attentive_convolution.cpython-36.pyc │ ├── bert.py │ ├── region_embedding.py │ ├── dpcnn.py │ ├── textrcnn.py │ ├── textrnn.py │ ├── textcnn.py │ ├── drnn.py │ ├── transformer.py │ ├── textvdcnn.py │ ├── classifier.py │ ├── fasttext.py │ └── attentive_convolution.py ├── attention.py ├── transformer_encoder.py ├── rnn.py ├── model_util.py ├── loss.py ├── layers.py ├── optimizer.py └── embedding.py ├── evaluate ├── __pycache__ │ └── classification_evaluate.cpython-36.pyc └── classification_evaluate.py ├── config.py ├── README.md ├── util.py ├── conf ├── train_cnn.json ├── train_rcnn.json ├── train_rnn.json └── train_bert.json ├── eval.py └── readme └── Configuration.md /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dict_msf/doc_char.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dict_msf/doc_token.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dict_msf/doc_topic.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dict_msf/doc_keyword.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dict_msf/doc_token_ngram.dict: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch_nightly>=1.0.0.dev20190325 2 | numpy>=1.16.2 3 | torch>=1.0.1.post2 4 | -------------------------------------------------------------------------------- /__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmshi-trio/MSL/HEAD/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/config.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmshi-trio/MSL/HEAD/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /model/classification/__pycache__/transformer.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmshi-trio/MSL/HEAD/model/classification/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /model/classification/__pycache__/region_embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmshi-trio/MSL/HEAD/model/classification/__pycache__/region_embedding.cpython-36.pyc -------------------------------------------------------------------------------- /model/classification/__pycache__/attentive_convolution.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmshi-trio/MSL/HEAD/model/classification/__pycache__/attentive_convolution.cpython-36.pyc -------------------------------------------------------------------------------- /dict_msf/doc_label.dict: -------------------------------------------------------------------------------- 1 | 头痛 829 2 | 腹痛 802 3 | 胸痛 317 4 | 疼痛 282 5 | 咽喉痛 264 6 | 腰痛 247 7 | 发热 198 8 | 关节疼痛 171 9 | 肢体疼痛 170 10 | 头晕 161 11 | 背痛 155 12 | 痛风 148 13 | 腹胀 142 14 | 牙痛 136 15 | 咳嗽 121 16 | 感冒 99 17 | 呕吐 99 18 | 恶心 97 19 | 无力 94 20 | 下腹痛 89 21 | 全身酸痛 86 22 | 腹泻 82 23 | 尿痛 78 24 | 胸闷 65 25 | 失眠 51 26 | 眼痛 47 27 | 便秘 44 28 | 尿频 41 29 | 腰酸 40 30 | -------------------------------------------------------------------------------- /model/classification/bert.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | import torch 4 | 5 | from dataset.classification_dataset import ClassificationDataset as cDataset 6 | from model.classification.classifier import Classifier 7 | from transformers import * 8 | 9 | class BERT(Classifier): 10 | def __init__(self, dataset, config): 11 | super(BERT, self).__init__(dataset, config) 12 | self.bert_model = BertModel.from_pretrained(config.data.pretrained_bert_embedding) 13 | self.linear = torch.nn.Linear(config.embedding.dimension, len(dataset.label_map)) 14 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 15 | 16 | def forward(self, batch): 17 | embedding = self.bert_model(torch.LongTensor(batch['doc_token']).to(self.config.device))[1] 18 | return self.dropout(self.linear(embedding)) 19 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import json 16 | 17 | 18 | class Config(object): 19 | """Config load from json file 20 | """ 21 | 22 | def __init__(self, config=None, config_file=None): 23 | if config_file: 24 | with open(config_file, 'r') as fin: 25 | config = json.load(fin) 26 | 27 | self.dict = config 28 | if config: 29 | self._update(config) 30 | 31 | def __getitem__(self, key): 32 | return self.dict[key] 33 | 34 | def __contains__(self, item): 35 | return item in self.dict 36 | 37 | def items(self): 38 | return self.dict.items() 39 | 40 | def add(self, key, value): 41 | """Add key value pair 42 | """ 43 | self.__dict__[key] = value 44 | 45 | def _update(self, config): 46 | if not isinstance(config, dict): 47 | return 48 | 49 | for key in config: 50 | if isinstance(config[key], dict): 51 | config[key] = Config(config[key]) 52 | 53 | if isinstance(config[key], list): 54 | config[key] = [Config(x) if isinstance(x, dict) else x for x in 55 | config[key]] 56 | 57 | self.__dict__.update(config) 58 | -------------------------------------------------------------------------------- /model/classification/region_embedding.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import torch 16 | 17 | from dataset.classification_dataset import ClassificationDataset as cDataset 18 | from model.classification.classifier import Classifier 19 | from model.model_util import InitType 20 | from model.model_util import init_tensor 21 | 22 | 23 | class RegionEmbedding(Classifier): 24 | """Implement region embedding classification method 25 | Reference: "A New Method of Region Embedding for Text Classification" 26 | """ 27 | 28 | def __init__(self, dataset, config): 29 | super(RegionEmbedding, self).__init__(dataset, config) 30 | self.region_size = config.embedding.region_size 31 | self.radius = int(self.region_size / 2) 32 | self.linear = torch.nn.Linear(config.embedding.dimension, 33 | len(dataset.label_map)) 34 | init_tensor(self.linear.weight, init_type=InitType.XAVIER_UNIFORM) 35 | init_tensor(self.linear.bias, init_type=InitType.UNIFORM, low=0, high=0) 36 | 37 | def get_parameter_optimizer_dict(self): 38 | params = super(RegionEmbedding, self).get_parameter_optimizer_dict() 39 | return params 40 | 41 | def forward(self, batch): 42 | embedding, _, mask = self.get_embedding( 43 | batch, [self.radius, self.radius], cDataset.VOCAB_PADDING) 44 | # mask should have same dim with padded embedding 45 | mask = torch.nn.functional.pad(mask, (self.radius, self.radius, 0, 0), "constant", 0) 46 | mask = mask.unsqueeze(2) 47 | embedding = embedding * mask 48 | doc_embedding = torch.sum(embedding, 1) 49 | doc_embedding = self.dropout(doc_embedding) 50 | return self.linear(doc_embedding) 51 | -------------------------------------------------------------------------------- /dataset/data_preprocessor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | 16 | # Same as "A New Method of Region Embedding for Text Classification" 17 | # https://github.com/text-representation/local-context-unit/blob/master/bin/prepare.py 18 | 19 | import csv 20 | import json 21 | import re 22 | import sys 23 | 24 | 25 | def clean_str(string): 26 | """ 27 | Tokenization/string cleaning for all datasets except for SST. 28 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 29 | """ 30 | string = string.strip().strip('"') 31 | string = re.sub(r"[^A-Za-z0-9(),!?\.\'\`]", " ", string) 32 | string = re.sub(r"\'s", " \'s", string) 33 | string = re.sub(r"\'ve", " \'ve", string) 34 | string = re.sub(r"n\'t", " n\'t", string) 35 | string = re.sub(r"\'re", " \'re", string) 36 | string = re.sub(r"\'d", " \'d", string) 37 | string = re.sub(r"\'ll", " \'ll", string) 38 | string = re.sub(r",", " , ", string) 39 | string = re.sub(r"\.", " \. ", string) 40 | string = re.sub(r"\"", " , ", string) 41 | string = re.sub(r"!", " ! 
", string) 42 | string = re.sub(r"\(", " \( ", string) 43 | string = re.sub(r"\)", " \) ", string) 44 | string = re.sub(r"\?", " \? ", string) 45 | string = re.sub(r"\s{2,}", " ", string) 46 | return string.strip().lower() 47 | 48 | 49 | def convert_multi_slots_to_single_slots(slots): 50 | """ 51 | covert the data which text_data are saved as multi-slots, e.g() 52 | """ 53 | if len(slots) == 1: 54 | return slots[0] 55 | else: 56 | return ' '.join(slots) 57 | 58 | 59 | def preprocess(csv_file, json_file): 60 | with open(json_file, "w") as fout: 61 | with open(csv_file, 'rb') as fin: 62 | lines = csv.reader(fin) 63 | for items in lines: 64 | text_data = convert_multi_slots_to_single_slots(items[1:]) 65 | text_data = clean_str(text_data) 66 | sample = dict() 67 | sample['doc_label'] = [items[0]] 68 | sample['doc_token'] = text_data.split(" ") 69 | sample['doc_keyword'] = [] 70 | sample['doc_topic'] = [] 71 | json_str = json.dumps(sample, ensure_ascii=False) 72 | fout.write(json_str) 73 | 74 | 75 | if __name__ == '__main__': 76 | train_csv = sys.argv[1] 77 | train_json = sys.argv[2] 78 | test_csv = sys.argv[3] 79 | test_json = sys.argv[4] 80 | preprocess(train_csv, train_json) 81 | preprocess(test_csv, test_json) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralBERTClassifier for Medical Slot Filling 2 | 3 | ## Introduction 4 | 5 | NeuralBERTClassifier is designed for quick implementation of neural models for multi-label classification problem: Medical Slot Filling (MSF). A salient feature is that NeuralBERTClassifier currently provides a variety of text encoders, such as FastText, TextCNN, TextRNN, RCNN, VDCNN, DPCNN, DRNN, AttentiveConvNet, Transformer encoder, and BERT etc. It also supports other text classification scenarios, including binary-class and multi-class classification. It is built on [PyTorch](https://pytorch.org/). Corresponding paper **Understanding Medical Conversations with Scattered Keyword Attention and Weak Supervision from Responses** was accepted by [AAAI 2020](https://aaai.org/ojs/index.php/AAAI/article/view/6412). 
 6 | 
 7 | 
 8 | ## Notice
 9 | **According to Tencent's regulations, the dataset can only be used for research purposes.**
 10 | 
 11 | 
 12 | ## Supported tasks
 13 | 
 14 | * Binary-class text classification
 15 | * Multi-class text classification
 16 | * Multi-label text classification
 17 | * Hierarchical (multi-label) text classification (HMC)
 18 | 
 19 | ## Supported text encoders
 20 | 
 21 | * TextCNN ([Kim, 2014](https://arxiv.org/pdf/1408.5882.pdf))
 22 | * RCNN ([Lai et al., 2015](https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/download/9745/9552))
 23 | * TextRNN ([Liu et al., 2016](https://arxiv.org/pdf/1605.05101.pdf))
 24 | * FastText ([Joulin et al., 2016](https://arxiv.org/pdf/1607.01759.pdf))
 25 | * VDCNN ([Conneau et al., 2016](https://arxiv.org/pdf/1606.01781.pdf))
 26 | * DPCNN ([Johnson and Zhang, 2017](https://www.aclweb.org/anthology/P17-1052))
 27 | * AttentiveConvNet ([Yin and Schutze, 2017](https://arxiv.org/pdf/1710.00519.pdf))
 28 | * DRNN ([Wang, 2018](https://www.aclweb.org/anthology/P18-1215))
 29 | * Region embedding ([Qiao et al., 2018](http://research.baidu.com/Public/uploads/5acc1e230d179.pdf))
 30 | * Transformer encoder ([Vaswani et al., 2017](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf))
 31 | * Star-Transformer encoder ([Guo et al., 2019](https://arxiv.org/pdf/1902.09113.pdf))
 32 | 
 33 | ## Requirements
 34 | 
 35 | * Python 3
 36 | * PyTorch 0.4+
 37 | * Numpy 1.14.3+
 38 | 
 39 | ## Usage
 40 | 
 41 | ### Training
 42 | 
 43 |     python train.py conf/train.json
 44 | 
 45 | ***For detailed configuration options and explanations, see [Configuration](readme/Configuration.md).***
 46 | 
 47 | The training info will be written to standard output and to log.logger\_file.
 48 | 
 49 | ### Evaluation
 50 |     python eval.py conf/train.json
 51 | 
 52 | * If eval.is\_flat = false, hierarchical evaluation results will be output.
 53 | * eval.model\_dir is the model to evaluate.
 54 | * data.test\_json\_files is the input text file to evaluate.
 55 | 
 56 | The evaluation info will be written to eval.dir.
 57 | 
 58 | ## Input Data Format
 59 | 
 60 | JSON example:
 61 | 
 62 |     {
 63 |         "doc_label": ["Computer--MachineLearning--DeepLearning", "Neuro--ComputationalNeuro"],
 64 |         "doc_token": ["I", "love", "deep", "learning"],
 65 |         "doc_keyword": ["deep learning"],
 66 |         "doc_topic": ["AI", "Machine learning"]
 67 |     }
 68 | 
 69 | "doc_keyword" and "doc_topic" are optional.
 70 | 
 71 | 
 72 | ## Update
 73 | 
 74 | * 2020-10-27
 75 | 
-------------------------------------------------------------------------------- /util.py: --------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # coding:utf-8
 3 | """
 4 | Tencent is pleased to support the open source community by making NeuralClassifier available.
 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance
 7 | with the License. You may obtain a copy of the License at
 8 | http://opensource.org/licenses/MIT
 9 | Unless required by applicable law or agreed to in writing, software distributed under the License
 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 11 | or implied. See the License for the specific language governing permissions and limitations under
 12 | the License.
13 | """ 14 | 15 | 16 | import logging 17 | import sys 18 | 19 | EPS = 1e-7 20 | 21 | 22 | class Type(object): 23 | @classmethod 24 | def str(cls): 25 | raise NotImplementedError 26 | 27 | 28 | class ModeType(Type): 29 | """Standard names for model modes. 30 | The following standard keys are defined: 31 | * `TRAIN`: training mode. 32 | * `EVAL`: evaluation mode. 33 | * `PREDICT`: inference mode. 34 | """ 35 | TRAIN = 'train' 36 | EVAL = 'eval' 37 | PREDICT = 'infer' 38 | 39 | @classmethod 40 | def str(cls): 41 | return ",".join([cls.TRAIN, cls.EVAL, cls.PREDICT]) 42 | 43 | 44 | class Logger(object): 45 | _instance = None 46 | 47 | def __new__(cls, *args, **kw): 48 | if not cls._instance: 49 | cls._instance = super(Logger, cls).__new__(cls) 50 | return cls._instance 51 | 52 | def __init__(self, config): 53 | if config.log.log_level == "debug": 54 | logging_level = logging.DEBUG 55 | elif config.log.log_level == "info": 56 | logging_level = logging.INFO 57 | elif config.log.log_level == "warn": 58 | logging_level = logging.WARN 59 | elif config.log.log_level == "error": 60 | logging_level = logging.ERROR 61 | else: 62 | raise TypeError( 63 | "No logging type named %s, candidate is: info, debug, error") 64 | logging.basicConfig(filename=config.log.logger_file, 65 | level=logging_level, 66 | format='%(asctime)s : %(levelname)s %(message)s', 67 | filemode="a", datefmt='%Y-%m-%d %H:%M:%S') 68 | 69 | @staticmethod 70 | def debug(msg): 71 | """Log debug message 72 | msg: Message to log 73 | """ 74 | logging.debug(msg) 75 | sys.stdout.write(msg + "\n") 76 | 77 | @staticmethod 78 | def info(msg): 79 | """"Log info message 80 | msg: Message to log 81 | """ 82 | logging.info(msg) 83 | sys.stdout.write(msg + "\n") 84 | 85 | @staticmethod 86 | def warn(msg): 87 | """Log warn message 88 | msg: Message to log 89 | """ 90 | logging.warning(msg) 91 | sys.stdout.write(msg + "\n") 92 | 93 | @staticmethod 94 | def error(msg): 95 | """Log error message 96 | msg: Message to log 97 | """ 98 | logging.error(msg) 99 | sys.stderr.write(msg + "\n") 100 | -------------------------------------------------------------------------------- /model/classification/dpcnn.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | from dataset.classification_dataset import ClassificationDataset as cDataset 19 | from model.classification.classifier import Classifier 20 | 21 | 22 | class DPCNN(Classifier): 23 | """ 24 | Reference: 25 | Deep Pyramid Convolutional Neural Networks for Text Categorization 26 | """ 27 | 28 | def __init__(self, dataset, config): 29 | super(DPCNN, self).__init__(dataset, config) 30 | self.num_kernels = config.DPCNN.num_kernels 31 | self.pooling_stride = config.DPCNN.pooling_stride 32 | self.kernel_size = config.DPCNN.kernel_size 33 | self.radius = int(self.kernel_size / 2) 34 | assert self.kernel_size % 2 == 1, "DPCNN kernel should be odd!" 35 | self.convert_conv = torch.nn.Sequential( 36 | torch.nn.Conv1d( 37 | config.embedding.dimension, self.num_kernels, 38 | self.kernel_size, padding=self.radius) 39 | ) 40 | 41 | self.convs = torch.nn.ModuleList([torch.nn.Sequential( 42 | torch.nn.ReLU(), 43 | torch.nn.Conv1d( 44 | self.num_kernels, self.num_kernels, 45 | self.kernel_size, padding=self.radius), 46 | torch.nn.ReLU(), 47 | torch.nn.Conv1d( 48 | self.num_kernels, self.num_kernels, 49 | self.kernel_size, padding=self.radius) 50 | ) for _ in range(config.DPCNN.blocks + 1)]) 51 | 52 | self.linear = torch.nn.Linear(self.num_kernels, len(dataset.label_map)) 53 | 54 | def get_parameter_optimizer_dict(self): 55 | params = super(DPCNN, self).get_parameter_optimizer_dict() 56 | params.append({'params': self.convert_conv.parameters()}) 57 | params.append({'params': self.convs.parameters()}) 58 | params.append({'params': self.linear.parameters()}) 59 | return params 60 | 61 | def forward(self, batch): 62 | if self.config.feature.feature_names[0] == "token": 63 | embedding = self.token_embedding( 64 | batch[cDataset.DOC_TOKEN].to(self.config.device)) 65 | else: 66 | embedding = self.char_embedding( 67 | batch[cDataset.DOC_CHAR]).to(self.config.device) 68 | embedding = embedding.permute(0, 2, 1) 69 | conv_embedding = self.convert_conv(embedding) 70 | conv_features = self.convs[0](conv_embedding) 71 | conv_features = conv_embedding + conv_features 72 | for i in range(1, len(self.convs)): 73 | block_features = F.max_pool1d( 74 | conv_features, self.kernel_size, self.pooling_stride) 75 | conv_features = self.convs[i](block_features) 76 | conv_features = conv_features + block_features 77 | doc_embedding = F.max_pool1d( 78 | conv_features, conv_features.size(2)).squeeze() 79 | return self.dropout(self.linear(doc_embedding)) 80 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from model.model_util import init_tensor 21 | 22 | 23 | class ScaledDotProductAttention(nn.Module): 24 | ''' Scaled Dot-Product Attention ''' 25 | 26 | def __init__(self, temperature, attn_dropout=0.1): 27 | super(ScaledDotProductAttention, self).__init__() 28 | self.temperature = temperature 29 | self.dropout = nn.Dropout(attn_dropout) 30 | self.softmax = nn.Softmax(dim=2) 31 | 32 | def forward(self, q, k, v, mask=None): 33 | 34 | attn = torch.bmm(q, k.transpose(1, 2)) 35 | attn = attn / self.temperature 36 | 37 | if mask is not None: 38 | attn = attn.masked_fill(mask, -np.inf) 39 | 40 | attn = self.softmax(attn) 41 | attn = self.dropout(attn) 42 | output = torch.bmm(attn, v) 43 | 44 | return output, attn 45 | 46 | 47 | class MultiHeadAttention(nn.Module): 48 | ''' Multi-Head Attention module ''' 49 | 50 | def __init__(self, n_head, d_model, d_k, d_v, use_star=False, dropout=0.1): 51 | super(MultiHeadAttention, self).__init__() 52 | 53 | self.n_head = n_head 54 | self.d_k = d_k 55 | self.d_v = d_v 56 | self.use_star = use_star 57 | 58 | self.w_qs = nn.Linear(d_model, n_head * d_k) 59 | self.w_ks = nn.Linear(d_model, n_head * d_k) 60 | self.w_vs = nn.Linear(d_model, n_head * d_v) 61 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 62 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 63 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 64 | 65 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 66 | self.layer_norm = nn.LayerNorm(d_model) 67 | 68 | self.fc = nn.Linear(n_head * d_v, d_model) 69 | nn.init.xavier_normal_(self.fc.weight) 70 | 71 | self.dropout = nn.Dropout(dropout) 72 | 73 | 74 | def forward(self, q, k, v, mask=None): 75 | 76 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 77 | 78 | sz_b, len_q, _ = q.size() 79 | sz_b, len_k, _ = k.size() 80 | sz_b, len_v, _ = v.size() 81 | 82 | residual = q 83 | 84 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 85 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 86 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 87 | 88 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 89 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 90 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 91 | 92 | if mask is not None: 93 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 94 | output, attn = self.attention(q, k, v, mask=mask) 95 | 96 | output = output.view(n_head, sz_b, len_q, d_v) 97 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 98 | 99 | if self.use_star: 100 | output = self.dropout(F.relu(self.fc(output))) 101 | output = self.layer_norm(output) 102 | else: 103 | output = self.dropout(self.fc(output)) 104 | output = self.layer_norm(output + residual) 105 | 106 | return output, attn 107 | -------------------------------------------------------------------------------- /model/classification/textrcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 
6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | from dataset.classification_dataset import ClassificationDataset as cDataset 19 | from model.classification.classifier import Classifier 20 | from model.rnn import RNN 21 | 22 | 23 | class TextRCNN(Classifier): 24 | """TextRNN + TextCNN 25 | """ 26 | def __init__(self, dataset, config): 27 | super(TextRCNN, self).__init__(dataset, config) 28 | self.rnn = RNN( 29 | config.embedding.dimension, config.TextRCNN.hidden_dimension, 30 | num_layers=config.TextRCNN.num_layers, 31 | batch_first=True, bidirectional=config.TextRCNN.bidirectional, 32 | rnn_type=config.TextRCNN.rnn_type) 33 | 34 | hidden_dimension = config.TextRCNN.hidden_dimension 35 | if config.TextRCNN.bidirectional: 36 | hidden_dimension *= 2 37 | self.kernel_sizes = config.TextRCNN.kernel_sizes 38 | self.convs = torch.nn.ModuleList() 39 | for kernel_size in self.kernel_sizes: 40 | self.convs.append(torch.nn.Conv1d( 41 | hidden_dimension, config.TextRCNN.num_kernels, 42 | kernel_size, padding=kernel_size - 1)) 43 | 44 | self.top_k = self.config.TextRCNN.top_k_max_pooling 45 | hidden_size = len(config.TextRCNN.kernel_sizes) * \ 46 | config.TextRCNN.num_kernels * self.top_k 47 | 48 | self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map)) 49 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 50 | 51 | def get_parameter_optimizer_dict(self): 52 | params = list() 53 | params.append({'params': self.token_embedding.parameters()}) 54 | params.append({'params': self.char_embedding.parameters()}) 55 | params.append({'params': self.rnn.parameters()}) 56 | params.append({'params': self.convs.parameters()}) 57 | params.append({'params': self.linear.parameters()}) 58 | return params 59 | 60 | def update_lr(self, optimizer, epoch): 61 | """ 62 | """ 63 | if epoch > self.config.train.num_epochs_static_embedding: 64 | for param_group in optimizer.param_groups[:2]: 65 | param_group["lr"] = self.config.optimizer.learning_rate 66 | else: 67 | for param_group in optimizer.param_groups[:2]: 68 | param_group["lr"] = 0 69 | 70 | def forward(self, batch): 71 | #if self.config.feature.feature_names[0] == "token": 72 | # embedding = self.token_embedding( 73 | # batch[cDataset.DOC_TOKEN].to(self.config.device)) 74 | # seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 75 | #else: 76 | # embedding = self.char_embedding( 77 | # batch[cDataset.DOC_CHAR].to(self.config.device)) 78 | # seq_length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device) 79 | embedding = batch['doc_token'].to(self.config.device) 80 | seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 81 | output, _ = self.rnn(embedding, seq_length) 82 | 83 | doc_embedding = output.transpose(1, 2) 84 | pooled_outputs = [] 85 | for _, conv in enumerate(self.convs): 86 | convolution = F.relu(conv(doc_embedding)) 87 | pooled = torch.topk(convolution, self.top_k)[0].view( 88 | convolution.size(0), -1) 89 | pooled_outputs.append(pooled) 90 | 91 | 
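# Each kernel size contributes a top-k max-pooled feature vector; concatenating them
# yields a fixed-size document representation of
# len(kernel_sizes) * num_kernels * top_k dimensions, matching self.linear's input size.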
doc_embedding = torch.cat(pooled_outputs, 1) 92 | 93 | return self.dropout(self.linear(doc_embedding)) 94 | -------------------------------------------------------------------------------- /model/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | """ 16 | Transformer Encoder: 17 | Heavily borrowed from https://github.com/jadore801120/attention-is-all-you-need-pytorch/ 18 | Star-Transformer Encode: 19 | https://arxiv.org/pdf/1902.09113v2.pdf 20 | """ 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | from model.attention import MultiHeadAttention 27 | 28 | 29 | class PositionwiseFeedForward(nn.Module): 30 | ''' A two-feed-forward-layer module ''' 31 | 32 | def __init__(self, d_in, d_hid, dropout=0.1): 33 | super(PositionwiseFeedForward, self).__init__() 34 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 35 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 36 | self.layer_norm = nn.LayerNorm(d_in) 37 | self.dropout = nn.Dropout(dropout) 38 | 39 | def forward(self, x): 40 | residual = x 41 | output = x.transpose(1, 2) 42 | output = self.w_2(F.relu(self.w_1(output))) 43 | output = output.transpose(1, 2) 44 | output = self.dropout(output) 45 | output = self.layer_norm(output + residual) 46 | return output 47 | 48 | 49 | class EncoderLayer(nn.Module): 50 | ''' Compose with two layers ''' 51 | 52 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 53 | super(EncoderLayer, self).__init__() 54 | self.slf_attn = MultiHeadAttention( 55 | n_head, d_model, d_k, d_v, dropout=dropout) 56 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 57 | 58 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 59 | enc_output, enc_slf_attn = self.slf_attn( 60 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 61 | enc_output *= non_pad_mask 62 | 63 | enc_output = self.pos_ffn(enc_output) 64 | enc_output *= non_pad_mask 65 | 66 | return enc_output, enc_slf_attn 67 | 68 | 69 | class StarEncoderLayer(nn.Module): 70 | ''' Star-Transformer: https://arxiv.org/pdf/1902.09113v2.pdf ''' 71 | 72 | def __init__(self, d_model, n_head, d_k, d_v, dropout=0.1): 73 | super(StarEncoderLayer, self).__init__() 74 | self.slf_attn_satellite = MultiHeadAttention( 75 | n_head, d_model, d_k, d_v, use_star=True, dropout=dropout) 76 | self.slf_attn_relay = MultiHeadAttention( 77 | n_head, d_model, d_k, d_v, use_star=True, dropout=dropout) 78 | 79 | def forward(self, h, e, s, non_pad_mask=None, slf_attn_mask=None): 80 | # satellite node 81 | batch_size, seq_len, d_model = h.size() 82 | h_extand = torch.zeros(batch_size, seq_len+2, d_model, dtype=torch.float, device=h.device) 83 | h_extand[:, 1:seq_len+1, :] = 
h  # head and tail padding (not cyclic)
 84 |         s = s.reshape([batch_size, 1, d_model])
 85 |         s_expand = s.expand([batch_size, seq_len, d_model])
 86 |         context = torch.cat((h_extand[:, 0:seq_len, :],
 87 |                              h_extand[:, 1:seq_len+1, :],
 88 |                              h_extand[:, 2:seq_len+2, :],
 89 |                              e,
 90 |                              s_expand),
 91 |                             2)
 92 |         context = context.reshape([batch_size*seq_len, 5, d_model])
 93 |         h = h.reshape([batch_size*seq_len, 1, d_model])
 94 | 
 95 |         h, _ = self.slf_attn_satellite(
 96 |             h, context, context, mask=slf_attn_mask)
 97 |         h = torch.squeeze(h, 1).reshape([batch_size, seq_len, d_model])
 98 |         if non_pad_mask is not None:
 99 |             h *= non_pad_mask
 100 | 
 101 |         # virtual relay node
 102 |         s_h = torch.cat((s, h), 1)
 103 |         s, _ = self.slf_attn_relay(
 104 |             s, s_h, s_h, mask=slf_attn_mask)
 105 |         s = torch.squeeze(s, 1)
 106 | 
 107 |         return h, s
 108 | 
-------------------------------------------------------------------------------- /model/rnn.py: --------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # coding:utf-8
 3 | """
 4 | Tencent is pleased to support the open source community by making NeuralClassifier available.
 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance
 7 | with the License. You may obtain a copy of the License at
 8 | http://opensource.org/licenses/MIT
 9 | Unless required by applicable law or agreed to in writing, software distributed under the License
 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 11 | or implied. See the License for the specific language governing permissions and limitations under
 12 | the License.
 13 | """
 14 | 
 15 | import torch
 16 | 
 17 | from util import Type
 18 | 
 19 | 
 20 | class RNNType(Type):
 21 |     RNN = 'RNN'
 22 |     LSTM = 'LSTM'
 23 |     GRU = 'GRU'
 24 | 
 25 |     @classmethod
 26 |     def str(cls):
 27 |         return ",".join([cls.RNN, cls.LSTM, cls.GRU])
 28 | 
 29 | 
 30 | class RNN(torch.nn.Module):
 31 |     """
 32 |     One-layer RNN (supports RNN/LSTM/GRU).
 33 |     """
 34 | 
 35 |     def __init__(self, input_size, hidden_size, num_layers=1,
 36 |                  nonlinearity="tanh", bias=True, batch_first=False, dropout=0,
 37 |                  bidirectional=False, rnn_type=RNNType.GRU):
 38 |         super(RNN, self).__init__()
 39 |         self.rnn_type = rnn_type
 40 |         self.num_layers = num_layers
 41 |         self.batch_first = batch_first
 42 |         self.bidirectional = bidirectional
 43 |         if rnn_type == RNNType.LSTM:
 44 |             self.rnn = torch.nn.LSTM(
 45 |                 input_size, hidden_size, num_layers=num_layers, bias=bias,
 46 |                 batch_first=batch_first, dropout=dropout,
 47 |                 bidirectional=bidirectional)
 48 |         elif rnn_type == RNNType.GRU:
 49 |             self.rnn = torch.nn.GRU(
 50 |                 input_size, hidden_size, num_layers=num_layers, bias=bias,
 51 |                 batch_first=batch_first, dropout=dropout,
 52 |                 bidirectional=bidirectional)
 53 |         elif rnn_type == RNNType.RNN:
 54 |             self.rnn = torch.nn.RNN(
 55 |                 input_size, hidden_size, num_layers=num_layers, nonlinearity=nonlinearity,
 56 |                 bias=bias, batch_first=batch_first, dropout=dropout,
 57 |                 bidirectional=bidirectional)
 58 |         else:
 59 |             raise TypeError(
 60 |                 "Unsupported rnn init type: %s. Supported rnn types are: %s" % (
 61 |                     rnn_type, RNNType.str()))
 62 | 
 63 |     def forward(self, inputs, seq_lengths=None, init_state=None,
 64 |                 ori_state=False):
 65 |         """
 66 |         Args:
 67 |             inputs:
 68 |             seq_lengths:
 69 |             init_state:
 70 |             ori_state: If true, will return the original state generated by the rnn.
Else will 71 | will return formatted state 72 | :return: 73 | """ 74 | if seq_lengths is not None: 75 | seq_lengths = seq_lengths.int() 76 | sorted_seq_lengths, indices = torch.sort(seq_lengths, 77 | descending=True) 78 | if self.batch_first: 79 | sorted_inputs = inputs[indices] 80 | else: 81 | sorted_inputs = inputs[:, indices] 82 | packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( 83 | sorted_inputs, sorted_seq_lengths, batch_first=self.batch_first) 84 | outputs, state = self.rnn(packed_inputs, init_state) 85 | else: 86 | outputs, state = self.rnn(inputs, init_state) 87 | 88 | if ori_state: 89 | return outputs, state 90 | if self.rnn_type == RNNType.LSTM: 91 | state = state[0] 92 | if self.bidirectional: 93 | last_layers_hn = state[2 * (self.num_layers - 1):] 94 | last_layers_hn = torch.cat( 95 | (last_layers_hn[0], last_layers_hn[1]), 1) 96 | else: 97 | last_layers_hn = state[self.num_layers - 1:] 98 | last_layers_hn = last_layers_hn[0] 99 | 100 | _, revert_indices = torch.sort(indices, descending=False) 101 | last_layers_hn = last_layers_hn[revert_indices] 102 | pad_output, _ = torch.nn.utils.rnn.pad_packed_sequence( 103 | outputs, batch_first=self.batch_first) 104 | if self.batch_first: 105 | pad_output = pad_output[revert_indices] 106 | else: 107 | pad_output = pad_output[:, revert_indices] 108 | return pad_output, last_layers_hn 109 | -------------------------------------------------------------------------------- /conf/train_cnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "task_info":{ 3 | "label_type": "multi_label", 4 | "hierarchical": false, 5 | "hierar_taxonomy": "data/rcv1.taxonomy", 6 | "hierar_penalty": 0.000001, 7 | "weak_data_augmentation": true, 8 | "weak_pretrain": true, 9 | "top_n_teacher": 3, 10 | "Augmentation_Method": "Self Learning", 11 | "add_noise": false 12 | }, 13 | "device": "cuda", 14 | "model_name": "TextCNN", 15 | "checkpoint_dir": "checkpoint_dir_msf", 16 | "model_dir": "trained_model_msf", 17 | "data": { 18 | "train_json_files": [ 19 | "data/train_target.json" 20 | ], 21 | "validate_json_files": [ 22 | "data/dev_target.json" 23 | ], 24 | "test_json_files": [ 25 | "data/test_target.json" 26 | ], 27 | "unlabeled_train_json_files":[ 28 | "data/train_relative.json" 29 | ], 30 | "unlabeled_dev_json_files":[ 31 | "data/dev_relative.json" 32 | ], 33 | "unlabeled_test_json_files":[ 34 | "data/test_relative.json" 35 | ], 36 | "weak_labeled_json_files":[ 37 | "data/train_relative_labeled_x.json" 38 | ], 39 | "generate_dict_using_json_files": false, 40 | "generate_dict_using_all_json_files": true, 41 | "generate_dict_using_pretrained_embedding": false, 42 | "dict_dir": "dict_msf", 43 | "num_worker": 4, 44 | "pretrained_bert_embedding": "/dockerdata/xiaomingshi/chinese_L-12_H-768_A-12" 45 | }, 46 | "feature": { 47 | "feature_names": [ 48 | "token" 49 | ], 50 | "min_token_count": 2, 51 | "min_char_count": 2, 52 | "token_ngram": 0, 53 | "min_token_ngram_count": 0, 54 | "min_keyword_count": 0, 55 | "min_topic_count": 2, 56 | "max_token_dict_size": 1000000, 57 | "max_char_dict_size": 150000, 58 | "max_token_ngram_dict_size": 10000000, 59 | "max_keyword_dict_size": 100, 60 | "max_topic_dict_size": 100, 61 | "max_token_len": 256, 62 | "max_char_len": 1024, 63 | "max_char_len_per_token": 4, 64 | "token_pretrained_file": "", 65 | "keyword_pretrained_file": "" 66 | }, 67 | "train": { 68 | "K": 1, 69 | "batch_size": [64, 128, 256], 70 | "start_epoch": 1, 71 | "pretrain_num_epochs": [10, 15, 20], 72 | "num_epochs": 100, 
73 | "self_num_epochs": 100, 74 | "num_epochs_static_embedding": 0, 75 | "decay_steps": 1000, 76 | "decay_rate": 1.0, 77 | "clip_gradients": 100.0, 78 | "l2_lambda": 0.0, 79 | "loss_type": "BCEWithLogitsLoss", 80 | "sampler": "fixed", 81 | "num_sampled": 5, 82 | "visible_device_list": "0", 83 | "hidden_layer_dropout": 0.5 84 | }, 85 | "embedding": { 86 | "type": "embedding", 87 | "dimension": 768, 88 | "region_embedding_type": "context_word", 89 | "region_size": 5, 90 | "initializer": "uniform", 91 | "fan_mode": "FAN_IN", 92 | "uniform_bound": 0.25, 93 | "random_stddev": 0.01, 94 | "dropout": 0.0 95 | }, 96 | "optimizer": { 97 | "optimizer_type": "Adam", 98 | "learning_rate": 0.008, 99 | "adadelta_decay_rate": 0.95, 100 | "adadelta_epsilon": 1e-08 101 | }, 102 | "TextCNN": { 103 | "kernel_sizes": [ 104 | 1, 105 | 2, 106 | 3 107 | ], 108 | "num_kernels": 100, 109 | "top_k_max_pooling": 1 110 | }, 111 | "TextRNN": { 112 | "hidden_dimension": 64, 113 | "rnn_type": "LSTM", 114 | "num_layers": 1, 115 | "doc_embedding_type": "Attention", 116 | "attention_dimension": 16, 117 | "bidirectional": true 118 | }, 119 | "DRNN": { 120 | "hidden_dimension": 5, 121 | "window_size": 3, 122 | "rnn_type": "GRU", 123 | "bidirectional": true, 124 | "cell_hidden_dropout": 0.1 125 | }, 126 | "eval": { 127 | "text_file": "data/test_target.json", 128 | "threshold": 0.5, 129 | "dir": "eval_dir", 130 | "batch_size": 1024, 131 | "is_flat": true, 132 | "top_k": 5, 133 | "model_dir": "checkpoint_dir_msf" 134 | }, 135 | "TextVDCNN": { 136 | "vdcnn_depth": 9, 137 | "top_k_max_pooling": 8 138 | }, 139 | "DPCNN": { 140 | "kernel_size": 3, 141 | "pooling_stride": 2, 142 | "num_kernels": 16, 143 | "blocks": 2 144 | }, 145 | "TextRCNN": { 146 | "kernel_sizes": [ 147 | 1, 148 | 2, 149 | 3 150 | ], 151 | "num_kernels": 100, 152 | "top_k_max_pooling": 1, 153 | "hidden_dimension":64, 154 | "rnn_type": "GRU", 155 | "num_layers": 1, 156 | "bidirectional": true 157 | }, 158 | "Transformer": { 159 | "d_inner": 128, 160 | "d_k": 32, 161 | "d_v": 32, 162 | "n_head": 4, 163 | "n_layers": 1, 164 | "dropout": 0.1, 165 | "use_star": true 166 | }, 167 | "AttentiveConvNet": { 168 | "attention_type": "bilinear", 169 | "margin_size": 3, 170 | "type": "advanced", 171 | "hidden_size": 64 172 | }, 173 | "log": { 174 | "logger_file": "./log/TextCNN_SelfLearning", 175 | "log_level": "info" 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /conf/train_rcnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "task_info":{ 3 | "label_type": "multi_label", 4 | "hierarchical": false, 5 | "hierar_taxonomy": "data/rcv1.taxonomy", 6 | "hierar_penalty": 0.000001, 7 | "weak_data_augmentation": true, 8 | "weak_pretrain": true, 9 | "top_n_teacher": 3, 10 | "Augmentation_Method": "Self Learning", 11 | "add_noise": false 12 | }, 13 | "device": "cuda", 14 | "model_name": "TextRCNN", 15 | "checkpoint_dir": "checkpoint_dir_msf", 16 | "model_dir": "trained_model_msf", 17 | "data": { 18 | "train_json_files": [ 19 | "data/train_target.json" 20 | ], 21 | "validate_json_files": [ 22 | "data/dev_target.json" 23 | ], 24 | "test_json_files": [ 25 | "data/test_target.json" 26 | ], 27 | "unlabeled_train_json_files":[ 28 | "data/train_relative.json" 29 | ], 30 | "unlabeled_dev_json_files":[ 31 | "data/dev_relative.json" 32 | ], 33 | "unlabeled_test_json_files":[ 34 | "data/test_relative.json" 35 | ], 36 | "weak_labeled_json_files":[ 37 | "data/train_relative_labeled_x.json" 38 | 
], 39 | "generate_dict_using_json_files": false, 40 | "generate_dict_using_all_json_files": true, 41 | "generate_dict_using_pretrained_embedding": false, 42 | "dict_dir": "dict_msf", 43 | "num_worker": 4, 44 | "pretrained_bert_embedding": "/dockerdata/xiaomingshi/chinese_L-12_H-768_A-12" 45 | }, 46 | "feature": { 47 | "feature_names": [ 48 | "token" 49 | ], 50 | "min_token_count": 2, 51 | "min_char_count": 2, 52 | "token_ngram": 0, 53 | "min_token_ngram_count": 0, 54 | "min_keyword_count": 0, 55 | "min_topic_count": 2, 56 | "max_token_dict_size": 1000000, 57 | "max_char_dict_size": 150000, 58 | "max_token_ngram_dict_size": 10000000, 59 | "max_keyword_dict_size": 100, 60 | "max_topic_dict_size": 100, 61 | "max_token_len": 256, 62 | "max_char_len": 1024, 63 | "max_char_len_per_token": 4, 64 | "token_pretrained_file": "", 65 | "keyword_pretrained_file": "" 66 | }, 67 | "train": { 68 | "K": 1, 69 | "batch_size": [64, 128, 256], 70 | "start_epoch": 1, 71 | "pretrain_num_epochs": [10, 15, 20], 72 | "num_epochs": 100, 73 | "self_num_epochs": 100, 74 | "num_epochs_static_embedding": 0, 75 | "decay_steps": 1000, 76 | "decay_rate": 1.0, 77 | "clip_gradients": 100.0, 78 | "l2_lambda": 0.0, 79 | "loss_type": "BCEWithLogitsLoss", 80 | "sampler": "fixed", 81 | "num_sampled": 5, 82 | "visible_device_list": "2", 83 | "hidden_layer_dropout": 0.5 84 | }, 85 | "embedding": { 86 | "type": "embedding", 87 | "dimension": 768, 88 | "region_embedding_type": "context_word", 89 | "region_size": 5, 90 | "initializer": "uniform", 91 | "fan_mode": "FAN_IN", 92 | "uniform_bound": 0.25, 93 | "random_stddev": 0.01, 94 | "dropout": 0.0 95 | }, 96 | "optimizer": { 97 | "optimizer_type": "Adam", 98 | "learning_rate": 0.008, 99 | "adadelta_decay_rate": 0.95, 100 | "adadelta_epsilon": 1e-08 101 | }, 102 | "TextCNN": { 103 | "kernel_sizes": [ 104 | 1, 105 | 2, 106 | 3 107 | ], 108 | "num_kernels": 100, 109 | "top_k_max_pooling": 1 110 | }, 111 | "TextRNN": { 112 | "hidden_dimension": 64, 113 | "rnn_type": "LSTM", 114 | "num_layers": 1, 115 | "doc_embedding_type": "Attention", 116 | "attention_dimension": 16, 117 | "bidirectional": true 118 | }, 119 | "DRNN": { 120 | "hidden_dimension": 5, 121 | "window_size": 3, 122 | "rnn_type": "GRU", 123 | "bidirectional": true, 124 | "cell_hidden_dropout": 0.1 125 | }, 126 | "eval": { 127 | "text_file": "data/test_target.json", 128 | "threshold": 0.5, 129 | "dir": "eval_dir", 130 | "batch_size": 1024, 131 | "is_flat": true, 132 | "top_k": 5, 133 | "model_dir": "checkpoint_dir_msf" 134 | }, 135 | "TextVDCNN": { 136 | "vdcnn_depth": 9, 137 | "top_k_max_pooling": 8 138 | }, 139 | "DPCNN": { 140 | "kernel_size": 3, 141 | "pooling_stride": 2, 142 | "num_kernels": 16, 143 | "blocks": 2 144 | }, 145 | "TextRCNN": { 146 | "kernel_sizes": [ 147 | 1, 148 | 2, 149 | 3 150 | ], 151 | "num_kernels": 100, 152 | "top_k_max_pooling": 1, 153 | "hidden_dimension":64, 154 | "rnn_type": "GRU", 155 | "num_layers": 1, 156 | "bidirectional": true 157 | }, 158 | "Transformer": { 159 | "d_inner": 128, 160 | "d_k": 32, 161 | "d_v": 32, 162 | "n_head": 4, 163 | "n_layers": 1, 164 | "dropout": 0.1, 165 | "use_star": true 166 | }, 167 | "AttentiveConvNet": { 168 | "attention_type": "bilinear", 169 | "margin_size": 3, 170 | "type": "advanced", 171 | "hidden_size": 64 172 | }, 173 | "log": { 174 | "logger_file": "./log/TextRCNN_SelfLearning", 175 | "log_level": "info" 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /conf/train_rnn.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "task_info":{ 3 | "label_type": "multi_label", 4 | "hierarchical": false, 5 | "hierar_taxonomy": "data/rcv1.taxonomy", 6 | "hierar_penalty": 0.000001, 7 | "weak_data_augmentation": true, 8 | "weak_pretrain": true, 9 | "top_n_teacher": 3, 10 | "Augmentation_Method": "Self Learning", 11 | "add_noise": false 12 | }, 13 | "device": "cuda", 14 | "model_name": "TextRNN", 15 | "checkpoint_dir": "checkpoint_dir_msf", 16 | "model_dir": "trained_model_msf", 17 | "data": { 18 | "train_json_files": [ 19 | "data/train_target.json" 20 | ], 21 | "validate_json_files": [ 22 | "data/dev_target.json" 23 | ], 24 | "test_json_files": [ 25 | "data/test_target.json" 26 | ], 27 | "unlabeled_train_json_files":[ 28 | "data/train_relative.json" 29 | ], 30 | "unlabeled_dev_json_files":[ 31 | "data/dev_relative.json" 32 | ], 33 | "unlabeled_test_json_files":[ 34 | "data/test_relative.json" 35 | ], 36 | "weak_labeled_json_files":[ 37 | "data/train_relative_labeled_x.json" 38 | ], 39 | "generate_dict_using_json_files": false, 40 | "generate_dict_using_all_json_files": true, 41 | "generate_dict_using_pretrained_embedding": false, 42 | "dict_dir": "dict_msf", 43 | "num_worker": 4, 44 | "pretrained_bert_embedding": "/dockerdata/xiaomingshi/chinese_L-12_H-768_A-12" 45 | }, 46 | "feature": { 47 | "feature_names": [ 48 | "token" 49 | ], 50 | "min_token_count": 2, 51 | "min_char_count": 2, 52 | "token_ngram": 0, 53 | "min_token_ngram_count": 0, 54 | "min_keyword_count": 0, 55 | "min_topic_count": 2, 56 | "max_token_dict_size": 1000000, 57 | "max_char_dict_size": 150000, 58 | "max_token_ngram_dict_size": 10000000, 59 | "max_keyword_dict_size": 100, 60 | "max_topic_dict_size": 100, 61 | "max_token_len": 256, 62 | "max_char_len": 1024, 63 | "max_char_len_per_token": 4, 64 | "token_pretrained_file": "", 65 | "keyword_pretrained_file": "" 66 | }, 67 | "train": { 68 | "K": 1, 69 | "batch_size": [64, 128, 256], 70 | "start_epoch": 1, 71 | "pretrain_num_epochs": [10, 15, 20], 72 | "num_epochs": 100, 73 | "self_num_epochs": 100, 74 | "num_epochs_static_embedding": 0, 75 | "decay_steps": 1000, 76 | "decay_rate": 1.0, 77 | "clip_gradients": 100.0, 78 | "l2_lambda": 0.0, 79 | "loss_type": "BCEWithLogitsLoss", 80 | "sampler": "fixed", 81 | "num_sampled": 5, 82 | "visible_device_list": "1", 83 | "hidden_layer_dropout": 0.5 84 | }, 85 | "embedding": { 86 | "type": "embedding", 87 | "dimension": 768, 88 | "region_embedding_type": "context_word", 89 | "region_size": 5, 90 | "initializer": "uniform", 91 | "fan_mode": "FAN_IN", 92 | "uniform_bound": 0.25, 93 | "random_stddev": 0.01, 94 | "dropout": 0.0 95 | }, 96 | "optimizer": { 97 | "optimizer_type": "Adam", 98 | "learning_rate": 0.008, 99 | "adadelta_decay_rate": 0.95, 100 | "adadelta_epsilon": 1e-08 101 | }, 102 | "TextCNN": { 103 | "kernel_sizes": [ 104 | 1, 105 | 2, 106 | 3 107 | ], 108 | "num_kernels": 100, 109 | "top_k_max_pooling": 1 110 | }, 111 | "TextRNN": { 112 | "hidden_dimension": 64, 113 | "rnn_type": "LSTM", 114 | "num_layers": 1, 115 | "doc_embedding_type": "Attention", 116 | "attention_dimension": 16, 117 | "bidirectional": true 118 | }, 119 | "DRNN": { 120 | "hidden_dimension": 5, 121 | "window_size": 3, 122 | "rnn_type": "GRU", 123 | "bidirectional": true, 124 | "cell_hidden_dropout": 0.1 125 | }, 126 | "eval": { 127 | "text_file": "data/test_target.json", 128 | "threshold": 0.5, 129 | "dir": "eval_dir", 130 | "batch_size": 1024, 131 | "is_flat": true, 132 | "top_k": 5, 133 | 
"model_dir": "checkpoint_dir_msf" 134 | }, 135 | "TextVDCNN": { 136 | "vdcnn_depth": 9, 137 | "top_k_max_pooling": 8 138 | }, 139 | "DPCNN": { 140 | "kernel_size": 3, 141 | "pooling_stride": 2, 142 | "num_kernels": 16, 143 | "blocks": 2 144 | }, 145 | "TextRCNN": { 146 | "kernel_sizes": [ 147 | 1, 148 | 2, 149 | 3 150 | ], 151 | "num_kernels": 100, 152 | "top_k_max_pooling": 1, 153 | "hidden_dimension":64, 154 | "rnn_type": "GRU", 155 | "num_layers": 1, 156 | "bidirectional": true 157 | }, 158 | "Transformer": { 159 | "d_inner": 128, 160 | "d_k": 32, 161 | "d_v": 32, 162 | "n_head": 4, 163 | "n_layers": 1, 164 | "dropout": 0.1, 165 | "use_star": true 166 | }, 167 | "AttentiveConvNet": { 168 | "attention_type": "bilinear", 169 | "margin_size": 3, 170 | "type": "advanced", 171 | "hidden_size": 64 172 | }, 173 | "log": { 174 | "logger_file": "./log/TextRNN_SelfLearning", 175 | "log_level": "info" 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /conf/train_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "task_info":{ 3 | "label_type": "multi_label", 4 | "hierarchical": false, 5 | "hierar_taxonomy": "data/rcv1.taxonomy", 6 | "hierar_penalty": 0.000001, 7 | "weak_data_augmentation": false, 8 | "weak_pretrain": true, 9 | "top_n_teacher": false, 10 | "Augmentation_Method": "Union", 11 | "add_noise": false 12 | }, 13 | "device": "cuda", 14 | "model_name": "BERT", 15 | "checkpoint_dir": "/dockerdata/xiaomingshi/model", 16 | "model_dir": "/dockerdata/xiaomingshi/model", 17 | "data": { 18 | "train_json_files": [ 19 | "data/train_target.json" 20 | ], 21 | "validate_json_files": [ 22 | "data/dev_target.json" 23 | ], 24 | "test_json_files": [ 25 | "data/test_target.json" 26 | ], 27 | "unlabeled_train_json_files":[ 28 | "data/train_relative.json" 29 | ], 30 | "unlabeled_dev_json_files":[ 31 | "data/dev_relative.json" 32 | ], 33 | "unlabeled_test_json_files":[ 34 | "data/test_relative.json" 35 | ], 36 | "weak_labeled_json_files":[ 37 | "data/train_relative_labeled_x.json" 38 | ], 39 | "generate_dict_using_json_files": false, 40 | "generate_dict_using_all_json_files": true, 41 | "generate_dict_using_pretrained_embedding": false, 42 | "dict_dir": "dict_msf", 43 | "num_worker": 4, 44 | "pretrained_bert_embedding": "/dockerdata/xiaomingshi/chinese_L-12_H-768_A-12" 45 | }, 46 | "feature": { 47 | "feature_names": [ 48 | "token" 49 | ], 50 | "min_token_count": 2, 51 | "min_char_count": 2, 52 | "token_ngram": 0, 53 | "min_token_ngram_count": 0, 54 | "min_keyword_count": 0, 55 | "min_topic_count": 2, 56 | "max_token_dict_size": 1000000, 57 | "max_char_dict_size": 150000, 58 | "max_token_ngram_dict_size": 10000000, 59 | "max_keyword_dict_size": 100, 60 | "max_topic_dict_size": 100, 61 | "max_token_len": 256, 62 | "max_char_len": 1024, 63 | "max_char_len_per_token": 4, 64 | "token_pretrained_file": "", 65 | "keyword_pretrained_file": "" 66 | }, 67 | "train": { 68 | "K": 1, 69 | "batch_size": [32, 64, 128], 70 | "start_epoch": 1, 71 | "pretrain_num_epochs": [10, 15, 20], 72 | "num_epochs": 50, 73 | "self_num_epochs": 100, 74 | "num_epochs_static_embedding": 0, 75 | "decay_steps": 1000, 76 | "decay_rate": 1.0, 77 | "clip_gradients": 100.0, 78 | "l2_lambda": 0.0, 79 | "loss_type": "BCEWithLogitsLoss", 80 | "sampler": "fixed", 81 | "num_sampled": 5, 82 | "visible_device_list": "3", 83 | "hidden_layer_dropout": 0.5 84 | }, 85 | "embedding": { 86 | "type": "embedding", 87 | "dimension": 768, 88 | 
"region_embedding_type": "context_word", 89 | "region_size": 5, 90 | "initializer": "uniform", 91 | "fan_mode": "FAN_IN", 92 | "uniform_bound": 0.25, 93 | "random_stddev": 0.01, 94 | "dropout": 0.0 95 | }, 96 | "optimizer": { 97 | "optimizer_type": "Adam", 98 | "learning_rate": 0.008, 99 | "adadelta_decay_rate": 0.95, 100 | "adadelta_epsilon": 1e-08 101 | }, 102 | "TextCNN": { 103 | "kernel_sizes": [ 104 | 1, 105 | 2, 106 | 3 107 | ], 108 | "num_kernels": 100, 109 | "top_k_max_pooling": 1 110 | }, 111 | "TextRNN": { 112 | "hidden_dimension": 64, 113 | "rnn_type": "LSTM", 114 | "num_layers": 1, 115 | "doc_embedding_type": "Attention", 116 | "attention_dimension": 16, 117 | "bidirectional": true 118 | }, 119 | "DRNN": { 120 | "hidden_dimension": 5, 121 | "window_size": 3, 122 | "rnn_type": "GRU", 123 | "bidirectional": true, 124 | "cell_hidden_dropout": 0.1 125 | }, 126 | "eval": { 127 | "text_file": "data/test_target.json", 128 | "threshold": 0.5, 129 | "dir": "eval_dir", 130 | "batch_size": 128, 131 | "is_flat": true, 132 | "top_k": 5, 133 | "model_dir": "/dockerdata/xiaomingshi/model" 134 | }, 135 | "TextVDCNN": { 136 | "vdcnn_depth": 9, 137 | "top_k_max_pooling": 8 138 | }, 139 | "DPCNN": { 140 | "kernel_size": 3, 141 | "pooling_stride": 2, 142 | "num_kernels": 16, 143 | "blocks": 2 144 | }, 145 | "TextRCNN": { 146 | "kernel_sizes": [ 147 | 1, 148 | 2, 149 | 3 150 | ], 151 | "num_kernels": 100, 152 | "top_k_max_pooling": 1, 153 | "hidden_dimension":64, 154 | "rnn_type": "GRU", 155 | "num_layers": 1, 156 | "bidirectional": true 157 | }, 158 | "Transformer": { 159 | "d_inner": 128, 160 | "d_k": 32, 161 | "d_v": 32, 162 | "n_head": 4, 163 | "n_layers": 1, 164 | "dropout": 0.1, 165 | "use_star": true 166 | }, 167 | "AttentiveConvNet": { 168 | "attention_type": "bilinear", 169 | "margin_size": 3, 170 | "type": "advanced", 171 | "hidden_size": 64 172 | }, 173 | "log": { 174 | "logger_file": "./log/BERT_WeakPretrain", 175 | "log_level": "info" 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /model/classification/textrnn.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import torch 16 | 17 | from dataset.classification_dataset import ClassificationDataset as cDataset 18 | from model.classification.classifier import Classifier 19 | from model.layers import SumAttention 20 | from model.rnn import RNN 21 | from util import Type 22 | from transformers import * 23 | 24 | 25 | class DocEmbeddingType(Type): 26 | """Standard names for doc embedding type. 
27 | """ 28 | AVG = 'AVG' 29 | ATTENTION = 'Attention' 30 | LAST_HIDDEN = 'LastHidden' 31 | 32 | @classmethod 33 | def str(cls): 34 | return ",".join( 35 | [cls.AVG, cls.ATTENTION, cls.LAST_HIDDEN]) 36 | 37 | 38 | class TextRNN(Classifier): 39 | """Implement TextRNN, contains LSTM,BiLSTM,GRU,BiGRU 40 | Reference: "Effective LSTMs for Target-Dependent Sentiment Classification" 41 | "Bidirectional LSTM-CRF Models for Sequence Tagging" 42 | "Generative and discriminative text classification 43 | with recurrent neural networks" 44 | """ 45 | 46 | def __init__(self, dataset, config): 47 | super(TextRNN, self).__init__(dataset, config) 48 | self.rnn = RNN( 49 | config.embedding.dimension, config.TextRNN.hidden_dimension, 50 | num_layers=config.TextRNN.num_layers, batch_first=True, 51 | bidirectional=config.TextRNN.bidirectional, 52 | rnn_type=config.TextRNN.rnn_type) 53 | hidden_dimension = config.TextRNN.hidden_dimension 54 | if config.TextRNN.bidirectional: 55 | hidden_dimension *= 2 56 | self.sum_attention = SumAttention(hidden_dimension, 57 | config.TextRNN.attention_dimension, 58 | config.device) 59 | self.linear = torch.nn.Linear(hidden_dimension, len(dataset.label_map)) 60 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 61 | 62 | def get_parameter_optimizer_dict(self): 63 | params = super(TextRNN, self).get_parameter_optimizer_dict() 64 | params.append({'params': self.rnn.parameters()}) 65 | params.append({'params': self.linear.parameters()}) 66 | return params 67 | 68 | def update_lr(self, optimizer, epoch): 69 | if epoch > self.config.train.num_epochs_static_embedding: 70 | for param_group in optimizer.param_groups[:2]: 71 | param_group["lr"] = self.config.optimizer.learning_rate 72 | else: 73 | for param_group in optimizer.param_groups[:2]: 74 | param_group["lr"] = 0.0 75 | 76 | def forward(self, batch): 77 | # delete embedding process 78 | #if self.config.feature.feature_names[0] == "token": 79 | # embedding = self.token_embedding( 80 | # batch[cDataset.DOC_TOKEN].to(self.config.device)) 81 | # length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 82 | #else: 83 | # embedding = self.char_embedding( 84 | # batch[cDataset.DOC_CHAR].to(self.config.device)) 85 | # length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device) 86 | #bert_model = BertModel.from_pretrained('/users3/xmshi/2018_1/contextualLUDSTC4_Bert/bert/model').to(self.config.device) 87 | #with torch.no_grad(): 88 | # embedding = bert_model(batch[cDataset.DOC_TOKEN].to(self.config.device))[0] 89 | embedding = batch['doc_token'].to(self.config.device) 90 | length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 91 | output, last_hidden = self.rnn(embedding, length) 92 | 93 | doc_embedding_type = self.config.TextRNN.doc_embedding_type 94 | if doc_embedding_type == DocEmbeddingType.AVG: 95 | doc_embedding = torch.sum(output, 1) / length.unsqueeze(1) 96 | elif doc_embedding_type == DocEmbeddingType.ATTENTION: 97 | doc_embedding = self.sum_attention(output) 98 | elif doc_embedding_type == DocEmbeddingType.LAST_HIDDEN: 99 | doc_embedding = last_hidden 100 | else: 101 | raise TypeError( 102 | "Unsupported rnn init type: %s. 
Supported rnn type is: %s" % ( 103 | doc_embedding_type, DocEmbeddingType.str())) 104 | 105 | return self.dropout(self.linear(doc_embedding)) 106 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import sys 16 | import time 17 | 18 | import torch 19 | from torch.utils.data import DataLoader 20 | 21 | import util 22 | from config import Config 23 | from dataset.classification_dataset import ClassificationDataset 24 | from dataset.collator import ClassificationCollator 25 | from dataset.collator import ClassificationType 26 | from dataset.collator import FastTextCollator 27 | from evaluate.classification_evaluate import \ 28 | ClassificationEvaluator as cEvaluator 29 | from model.classification.drnn import DRNN 30 | from model.classification.fasttext import FastText 31 | from model.classification.textcnn import TextCNN 32 | from model.classification.textvdcnn import TextVDCNN 33 | from model.classification.textrnn import TextRNN 34 | from model.classification.textrcnn import TextRCNN 35 | from model.classification.transformer import Transformer 36 | from model.classification.dpcnn import DPCNN 37 | from model.classification.attentive_convolution import AttentiveConvNet 38 | from model.classification.region_embedding import RegionEmbedding 39 | from model.model_util import get_optimizer, get_hierar_relations 40 | from util import ModeType 41 | 42 | ClassificationDataset, ClassificationCollator, FastTextCollator, cEvaluator, 43 | FastText, TextCNN, TextRNN, TextRCNN, DRNN, TextVDCNN, Transformer, DPCNN, 44 | AttentiveConvNet, RegionEmbedding 45 | 46 | 47 | def get_classification_model(model_name, dataset, conf): 48 | model = globals()[model_name](dataset, conf) 49 | model = model.cuda(conf.device) if conf.device.startswith("cuda") else model 50 | return model 51 | 52 | 53 | def load_checkpoint(file_name, conf, model, optimizer): 54 | checkpoint = torch.load(file_name) 55 | conf.train.start_epoch = checkpoint["epoch"] 56 | model.load_state_dict(checkpoint["state_dict"]) 57 | optimizer.load_state_dict(checkpoint["optimizer"]) 58 | 59 | 60 | def eval(conf): 61 | logger = util.Logger(conf) 62 | model_name = conf.model_name 63 | dataset_name = "ClassificationDataset" 64 | collate_name = "FastTextCollator" if model_name == "FastText" \ 65 | else "ClassificationCollator" 66 | 67 | test_dataset = globals()[dataset_name](conf, conf.data.test_json_files) 68 | collate_fn = globals()[collate_name](conf, len(test_dataset.label_map)) 69 | test_data_loader = DataLoader( 70 | test_dataset, batch_size=conf.eval.batch_size, shuffle=False, 71 | num_workers=conf.data.num_worker, collate_fn=collate_fn, 72 | pin_memory=True) 73 | 74 
| empty_dataset = globals()[dataset_name](conf, []) 75 | model = get_classification_model(model_name, empty_dataset, conf) 76 | optimizer = get_optimizer(conf, model) 77 | load_checkpoint(conf.eval.model_dir, conf, model, optimizer) 78 | model.eval() 79 | is_multi = False 80 | if conf.task_info.label_type == ClassificationType.MULTI_LABEL: 81 | is_multi = True 82 | predict_probs = [] 83 | standard_labels = [] 84 | evaluator = cEvaluator(conf.eval.dir) 85 | for batch in test_data_loader: 86 | logits = model(batch) 87 | if not is_multi: 88 | result = torch.nn.functional.softmax(logits, dim=1).cpu().tolist() 89 | else: 90 | result = torch.sigmoid(logits).cpu().tolist() 91 | predict_probs.extend(result) 92 | standard_labels.extend(batch[ClassificationDataset.DOC_LABEL_LIST]) 93 | (_, precision_list, recall_list, fscore_list, right_list, 94 | predict_list, standard_list) = \ 95 | evaluator.evaluate( 96 | predict_probs, standard_label_ids=standard_labels, label_map=empty_dataset.label_map, 97 | threshold=conf.eval.threshold, top_k=conf.eval.top_k, 98 | is_flat=conf.eval.is_flat, is_multi=is_multi) 99 | logger.warn( 100 | "Performance is precision: %f, " 101 | "recall: %f, fscore: %f, right: %d, predict: %d, standard: %d." % ( 102 | precision_list[0][cEvaluator.MICRO_AVERAGE], 103 | recall_list[0][cEvaluator.MICRO_AVERAGE], 104 | fscore_list[0][cEvaluator.MICRO_AVERAGE], 105 | right_list[0][cEvaluator.MICRO_AVERAGE], 106 | predict_list[0][cEvaluator.MICRO_AVERAGE], 107 | standard_list[0][cEvaluator.MICRO_AVERAGE])) 108 | evaluator.save() 109 | 110 | 111 | if __name__ == '__main__': 112 | config = Config(config_file=sys.argv[1]) 113 | eval(config) 114 | -------------------------------------------------------------------------------- /model/classification/textcnn.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import torch 16 | 17 | from dataset.classification_dataset import ClassificationDataset as cDataset 18 | from model.classification.classifier import Classifier 19 | 20 | 21 | class TextCNN(Classifier): 22 | def __init__(self, dataset, config): 23 | super(TextCNN, self).__init__(dataset, config) 24 | 25 | self.kernel_sizes = config.TextCNN.kernel_sizes 26 | self.convs = torch.nn.ModuleList() 27 | for kernel_size in self.kernel_sizes: 28 | self.convs.append(torch.nn.Conv1d( 29 | config.embedding.dimension, config.TextCNN.num_kernels, 30 | kernel_size, padding=kernel_size - 1)) 31 | 32 | self.top_k = self.config.TextCNN.top_k_max_pooling 33 | hidden_size = len(config.TextCNN.kernel_sizes) * \ 34 | config.TextCNN.num_kernels * self.top_k 35 | self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map)) 36 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 37 | 38 | def get_parameter_optimizer_dict(self): 39 | params = list() 40 | params.append({'params': self.token_embedding.parameters()}) 41 | params.append({'params': self.char_embedding.parameters()}) 42 | params.append({'params': self.convs.parameters()}) 43 | params.append({'params': self.linear.parameters()}) 44 | return params 45 | 46 | def update_lr(self, optimizer, epoch): 47 | """Update lr 48 | """ 49 | if epoch > self.config.train.num_epochs_static_embedding: 50 | for param_group in optimizer.param_groups[:2]: 51 | param_group["lr"] = self.config.optimizer.learning_rate 52 | else: 53 | for param_group in optimizer.param_groups[:2]: 54 | param_group["lr"] = 0 55 | 56 | def forward(self, batch): 57 | #if self.config.feature.feature_names[0] == "token": 58 | # embedding = self.token_embedding( 59 | # batch[cDataset.DOC_TOKEN].to(self.config.device)) 60 | #else: 61 | # embedding = self.char_embedding( 62 | # batch[cDataset.DOC_CHAR].to(self.config.device)) 63 | #embedding = embedding.transpose(1, 2) 64 | embedding = batch['doc_token'].to(self.config.device) 65 | embedding = embedding.transpose(1, 2) 66 | pooled_outputs = [] 67 | for i, conv in enumerate(self.convs): 68 | #convolution = torch.nn.ReLU(conv(embedding)) 69 | convolution = torch.nn.functional.relu(conv(embedding)) 70 | pooled = torch.topk(convolution, self.top_k)[0].view( 71 | convolution.size(0), -1) 72 | pooled_outputs.append(pooled) 73 | 74 | doc_embedding = torch.cat(pooled_outputs, 1) 75 | return self.dropout(self.linear(doc_embedding)) 76 | 77 | def token_similarity_attention(self, output): 78 | # output: (batch, sentence length, embedding dim) 79 | symptom_id_list = [6, 134, 15, 78, 2616, 257, 402, 281, 14848, 71, 82, 96, 352, 60, 227, 204, 178, 175, 233, 192, 416, 91, 232, 317, 17513, 628, 1047] 80 | symptom_embedding = self.token_embedding(torch.LongTensor(symptom_id_list).cuda()) 81 | # symptom_embedding: torch.tensor(symptom_num, embedding dim) 82 | batch_symptom_embedding = torch.cat([symptom_embedding.view(1, symptom_embedding.shape[0], -1)] * output.shape[0], dim=0) 83 | similarity = torch.sigmoid(torch.bmm(torch.nn.functional.normalize(output, dim=2), torch.nn.functional.normalize(batch_symptom_embedding.permute(0, 2, 1), dim=2))) 84 | #similarity = torch.bmm(torch.nn.functional.normalize(output, dim=2), torch.nn.functional.normalize(batch_symptom_embedding.permute(0, 2, 1), dim=2)) 85 | #similarity = torch.sigmoid(torch.max(similarity, dim=2)[0]) 86 | similarity = torch.max(similarity, dim=2)[0] 87 | #similarity = torch.sigmoid(torch.sum(similarity, dim=2)) 88 | # similarity: torch.tensor(batch, sentence_len) 89 | 
similarity = torch.cat([similarity.view(similarity.shape[0], -1, 1)] * output.shape[2], dim=2) 90 | print(similarity) 91 | # similarity: torch.tensor(batch, batch, sentence_len, embedding dim) 92 | #sentence_embedding = torch.sum(torch.mul(similarity, output), dim=1) 93 | # sentence_embedding: (batch, embedding) 94 | sentence_embedding = torch.mul(similarity, output) 95 | # sentence_embedding: (batch, sentence len, embedding) 96 | return sentence_embedding 97 | -------------------------------------------------------------------------------- /model/classification/drnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | 16 | import torch 17 | 18 | from dataset.classification_dataset import ClassificationDataset as cDataset 19 | from model.classification.classifier import Classifier 20 | from model.rnn import RNN 21 | from model.rnn import RNNType 22 | 23 | 24 | class DRNN(Classifier): 25 | def __init__(self, dataset, config): 26 | super(DRNN, self).__init__(dataset, config) 27 | self.rnn_type = config.DRNN.rnn_type 28 | self.forward_rnn = RNN( 29 | config.embedding.dimension, config.DRNN.hidden_dimension, 30 | batch_first=True, rnn_type=config.DRNN.rnn_type) 31 | if config.DRNN.bidirectional: 32 | self.backward_rnn = RNN( 33 | config.embedding.dimension, config.DRNN.hidden_dimension, 34 | batch_first=True, rnn_type=config.DRNN.rnn_type) 35 | self.window_size = config.DRNN.window_size 36 | self.dropout = torch.nn.Dropout(p=config.DRNN.cell_hidden_dropout) 37 | self.hidden_dimension = config.DRNN.hidden_dimension 38 | if config.DRNN.bidirectional: 39 | self.hidden_dimension *= 2 40 | self.batch_norm = torch.nn.BatchNorm1d(self.hidden_dimension) 41 | 42 | self.mlp = torch.nn.Linear(self.hidden_dimension, self.hidden_dimension) 43 | self.linear = torch.nn.Linear(self.hidden_dimension, 44 | len(dataset.label_map)) 45 | 46 | def get_parameter_optimizer_dict(self): 47 | params = super(DRNN, self).get_parameter_optimizer_dict() 48 | params.append({'params': self.forward_rnn.parameters()}) 49 | if self.config.DRNN.bidirectional: 50 | params.append({'params': self.backward_rnn.parameters()}) 51 | params.append({'params': self.batch_norm.parameters()}) 52 | params.append({'params': self.mlp.parameters()}) 53 | params.append({'params': self.linear.parameters()}) 54 | return params 55 | 56 | def forward(self, batch): 57 | front_pad_embedding, _, mask = self.get_embedding( 58 | batch, [self.window_size - 1, 0], cDataset.VOCAB_PADDING_LEARNABLE) 59 | if self.config.DRNN.bidirectional: 60 | tail_pad_embedding, _, _ = self.get_embedding( 61 | batch, [0, self.window_size - 1], 62 | cDataset.VOCAB_PADDING_LEARNABLE) 63 | batch_size = front_pad_embedding.size(0) 64 | mask = mask.unsqueeze(2) 65 | 66 | 
front_slice_embedding_list = \ 67 | [front_pad_embedding[:, i:i + self.window_size, :] for i in 68 | range(front_pad_embedding.size(1) - self.window_size + 1)] 69 | 70 | front_slice_embedding = torch.cat(front_slice_embedding_list, dim=0) 71 | 72 | state = None 73 | for i in range(front_slice_embedding.size(1)): 74 | _, state = self.forward_rnn(front_slice_embedding[:, i:i + 1, :], 75 | init_state=state, ori_state=True) 76 | if self.rnn_type == RNNType.LSTM: 77 | state[0] = self.dropout(state[0]) 78 | else: 79 | state = self.dropout(state) 80 | front_state = state[0] if self.rnn_type == RNNType.LSTM else state 81 | front_state = front_state.transpose(0, 1) 82 | front_hidden = torch.cat(front_state.split(batch_size, dim=0), dim=1) 83 | front_hidden = front_hidden * mask 84 | 85 | hidden = front_hidden 86 | if self.config.DRNN.bidirectional: 87 | tail_slice_embedding_list = list() 88 | for i in range(tail_pad_embedding.size(1) - self.window_size + 1): 89 | slice_embedding = \ 90 | tail_pad_embedding[:, i:i + self.window_size, :] 91 | tail_slice_embedding_list.append(slice_embedding) 92 | tail_slice_embedding = torch.cat(tail_slice_embedding_list, dim=0) 93 | 94 | state = None 95 | for i in range(tail_slice_embedding.size(1), 0, -1): 96 | _, state = self.backward_rnn( 97 | tail_slice_embedding[:, i - 1:i, :], 98 | init_state=state, ori_state=True) 99 | if i != tail_slice_embedding.size(1) - 1: 100 | if self.rnn_type == RNNType.LSTM: 101 | state[0] = self.dropout(state[0]) 102 | else: 103 | state = self.dropout(state) 104 | tail_state = state[0] if self.rnn_type == RNNType.LSTM else state 105 | tail_state = tail_state.transpose(0, 1) 106 | tail_hidden = torch.cat(tail_state.split(batch_size, dim=0), dim=1) 107 | tail_hidden = tail_hidden * mask 108 | hidden = torch.cat([hidden, tail_hidden], dim=2) 109 | 110 | hidden = hidden.transpose(1, 2).contiguous() 111 | 112 | batch_normed = self.batch_norm(hidden).transpose(1, 2) 113 | batch_normed = batch_normed * mask 114 | mlp_hidden = self.mlp(batch_normed) 115 | mlp_hidden = mlp_hidden * mask 116 | neg_mask = (mask - 1) * 65500.0 117 | mlp_hidden = mlp_hidden + neg_mask 118 | max_pooling = torch.nn.functional.max_pool1d( 119 | mlp_hidden.transpose(1, 2), mlp_hidden.size(1)).squeeze() 120 | return self.linear(self.dropout(max_pooling)) 121 | -------------------------------------------------------------------------------- /model/classification/transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from dataset.classification_dataset import ClassificationDataset as cDataset 19 | from model.classification.classifier import Classifier 20 | from model.transformer_encoder import EncoderLayer, StarEncoderLayer 21 | from model.embedding import PositionEmbedding 22 | 23 | 24 | class Transformer(Classifier): 25 | def __init__(self, dataset, config): 26 | super(Transformer, self).__init__(dataset, config) 27 | 28 | self.pad = dataset.token_map[dataset.VOCAB_PADDING] 29 | 30 | if config.feature.feature_names[0] == "token": 31 | seq_max_len = config.feature.max_token_len 32 | else: 33 | seq_max_len = config.feature.max_char_len 34 | self.position_enc = PositionEmbedding(seq_max_len, 35 | config.embedding.dimension, 36 | self.pad) 37 | 38 | if config.Transformer.use_star: 39 | self.layer_stack = nn.ModuleList([ 40 | StarEncoderLayer(config.embedding.dimension, 41 | config.Transformer.n_head, 42 | config.Transformer.d_k, 43 | config.Transformer.d_v, 44 | dropout=config.Transformer.dropout) 45 | for _ in range(config.Transformer.n_layers)]) 46 | else: 47 | self.layer_stack = nn.ModuleList([ 48 | EncoderLayer(config.embedding.dimension, 49 | config.Transformer.d_inner, 50 | config.Transformer.n_head, 51 | config.Transformer.d_k, 52 | config.Transformer.d_v, 53 | dropout=config.Transformer.dropout) 54 | for _ in range(config.Transformer.n_layers)]) 55 | 56 | hidden_size = config.embedding.dimension 57 | self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map)) 58 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 59 | 60 | def get_parameter_optimizer_dict(self): 61 | params = list() 62 | params.append({'params': self.token_embedding.parameters()}) 63 | params.append({'params': self.char_embedding.parameters()}) 64 | for i in range(0, len(self.layer_stack)): 65 | params.append({'params': self.layer_stack[i].parameters()}) 66 | params.append({'params': self.linear.parameters()}) 67 | return params 68 | 69 | def update_lr(self, optimizer, epoch): 70 | if epoch > self.config.train.num_epochs_static_embedding: 71 | for param_group in optimizer.param_groups[:2]: 72 | param_group["lr"] = self.config.optimizer.learning_rate 73 | else: 74 | for param_group in optimizer.param_groups[:2]: 75 | param_group["lr"] = 0 76 | 77 | def forward(self, batch): 78 | def _get_non_pad_mask(seq, pad): 79 | assert seq.dim() == 2 80 | return seq.ne(pad).type(torch.float).unsqueeze(-1) 81 | 82 | def _get_attn_key_pad_mask(seq_k, seq_q, pad): 83 | ''' For masking out the padding part of key sequence. ''' 84 | 85 | # Expand to fit the shape of key query attention matrix. 
86 | len_q = seq_q.size(1) 87 | padding_mask = seq_k.eq(pad) 88 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 89 | 90 | return padding_mask 91 | 92 | if self.config.feature.feature_names[0] == "token": 93 | src_seq = batch[cDataset.DOC_TOKEN].to(self.config.device) 94 | embedding = self.token_embedding(src_seq) 95 | else: 96 | src_seq = batch[cDataset.DOC_CHAR].to(self.config.device) 97 | embedding = self.char_embedding(src_seq) 98 | 99 | # Prepare masks 100 | slf_attn_mask = _get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq, pad=self.pad) 101 | non_pad_mask = _get_non_pad_mask(src_seq, self.pad) 102 | 103 | batch_lens = (src_seq != self.pad).sum(dim=-1) 104 | src_pos = torch.zeros_like(src_seq, dtype=torch.long) 105 | for row, length in enumerate(batch_lens): 106 | src_pos[row][:length] = torch.arange(1, length + 1) 107 | 108 | enc_output = embedding + self.position_enc(src_pos) 109 | 110 | if self.config.Transformer.use_star: 111 | s = torch.mean(embedding, 1) # virtual relay node 112 | h = enc_output 113 | for enc_layer in self.layer_stack: 114 | h, s = enc_layer(h, embedding, s, 115 | non_pad_mask=non_pad_mask, 116 | slf_attn_mask=None) 117 | h_max, _ = torch.max(h, 1) 118 | enc_output = h_max + s 119 | else: 120 | for enc_layer in self.layer_stack: 121 | enc_output, _ = enc_layer(enc_output, 122 | non_pad_mask=non_pad_mask, 123 | slf_attn_mask=slf_attn_mask) 124 | enc_output = torch.mean(enc_output, 1) 125 | 126 | return self.dropout(self.linear(enc_output)) 127 | -------------------------------------------------------------------------------- /readme/Configuration.md: -------------------------------------------------------------------------------- 1 | Configuration of NeuralClassifier uses JSON. 2 | 3 | ## Common 4 | 5 | * **task\_info** 6 | * **label_type**: Candidates: "single-label", "multi-label". 7 | * **hierarchical**: Boolean. Indicates whether it is a hierarchical classification. 8 | * **hierar_taxonomy**: A text file describes taxonomy. 9 | * **hierar_penalty**: Float. 10 | * **device**: Candidates: "cuda", "cpu". 11 | * **model\_name**: Candidates: "FastText", "TextCNN", "TextRNN", "TextRCNN", "DRNN", "VDCNN", "DPCNN", "AttentiveConvNet", "Transformer". 12 | * **checkpoint\_dir**: checkpoint directory 13 | * **model\_dir**: model directory 14 | * **data** 15 | * **train\_json\_files**: train input data. 16 | * **validate\_json\_files**: validation input data. 17 | * **test\_json\_files**: test input data. 18 | * **generate\_dict\_using\_json\_files**: generate dict using train data. 19 | * **generate\_dict\_using\_all\_json\_files**: generate dict using train, validate, test data. 20 | * **generate\_dict\_using\_pretrained\_embedding**: generate dict from pre-trained embedding. 21 | * **dict\_dir**: dict directory. 22 | * **num\_worker**: number of porcess to load data. 23 | 24 | 25 | ## Feature 26 | 27 | * **feature\_names**: Candidates: "token", "char". 28 | * **min\_token\_count** 29 | * **min\_char\_count** 30 | * **token\_ngram**: N-Gram, for example, 2 means bigram. 31 | * **min\_token\_ngram\_count** 32 | * **min\_keyword\_count** 33 | * **min\_topic\_count** 34 | * **max\_token\_dict\_size** 35 | * **max\_char\_dict\_size** 36 | * **max\_token\_ngram\_dict\_size** 37 | * **max\_keyword\_dict\_size** 38 | * **max\_topic\_dict\_size** 39 | * **max\_token\_len** 40 | * **max\_char\_len** 41 | * **max\_char\_len\_per\_token** 42 | * **token\_pretrained\_file**: token pre-trained embedding. 
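  (Both pretrained-file options may be left empty, as in the conf examples above; embeddings are then initialized according to the **initializer** settings in the Embedding section. A purely illustrative value would be `"token_pretrained_file": "data/token_emb.txt"`.)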
43 | * **keyword\_pretrained\_file**: keyword pre-trained embedding. 44 | 45 | 46 | ## Train 47 | 48 | * **batch\_size** 49 | * **eval\_train\_data**: whether to evaluate the training data during training. 50 | * **start\_epoch**: starting epoch number. 51 | * **num\_epochs**: number of epochs. 52 | * **num\_epochs\_static\_embedding**: number of initial epochs during which the input embedding is not updated. 53 | * **decay\_steps**: decay the learning rate every decay\_steps steps. 54 | * **decay\_rate**: Rate of decay for the learning rate. 55 | * **clip\_gradients**: Clip gradients whose absolute value exceeds this threshold. 56 | * **l2\_lambda**: l2 regularization lambda value. 57 | * **loss\_type**: Candidates: "SoftmaxCrossEntropy", "SoftmaxFocalCrossEntropy", "SigmoidFocalCrossEntropy", "BCEWithLogitsLoss". 58 | * **sampler**: If the loss type is NCE, a sampler is needed. Candidates: "fixed", "log", "learned", "uniform". 59 | * **num\_sampled**: If the loss type is NCE, the number of negative labels to sample. 60 | * **hidden\_layer\_dropout**: dropout of the hidden layer. 61 | * **visible\_device\_list**: GPU list to use. 62 | 63 | 64 | ## Embedding 65 | 66 | * **type**: Candidates: "embedding", "region_embedding". 67 | * **dimension**: dimension of embedding. 68 | * **region\_embedding\_type**: config for Region embedding. Candidates: "word\_context", "context\_word". 69 | * **region\_size**: region size, must be an odd number. Config for Region embedding. 70 | * **initializer**: Candidates: "uniform", "normal", "xavier\_uniform", "xavier\_normal", "kaiming\_uniform", "kaiming\_normal", "orthogonal". 71 | * **fan\_mode**: Candidates: "FAN\_IN", "FAN\_OUT". 72 | * **uniform\_bound**: If the initializer is "uniform", this param will be used as the bound, e.g. [-uniform\_bound, uniform\_bound]. 73 | * **random\_stddev**: If the initializer is "normal", this param will be used as the stddev. 74 | * **dropout**: dropout of the embedding layer. 75 | 76 | 77 | ## Optimizer 78 | 79 | * **optimizer\_type**: Candidates: "Adam", "Adadelta". 80 | * **learning\_rate**: learning rate. 81 | * **adadelta\_decay\_rate**: useful when optimizer\_type is Adadelta. 82 | * **adadelta\_epsilon**: useful when optimizer\_type is Adadelta. 83 | 84 | 85 | ## Eval 86 | 87 | * **text\_file** 88 | * **threshold**: float truncation threshold for predicted probabilities. 89 | * **dir**: output dir of evaluation. 90 | * **batch\_size**: batch size of evaluation. 91 | * **is\_flat**: Boolean, flat evaluation or hierarchical evaluation. 92 | 93 | 94 | ## Log 95 | 96 | * **logger\_file**: log file path. 97 | * **log\_level**: Candidates: "debug", "info", "warn", "error". 98 | 99 | 100 | ## Encoder 101 | 102 | ### TextCNN 103 | 104 | * **kernel\_sizes**: kernel sizes. 105 | * **num\_kernels**: number of kernels. 106 | * **top\_k\_max\_pooling**: top-k max pooling. 107 | 108 | ### TextRNN 109 | 110 | * **hidden\_dimension**: dimension of hidden layer. 111 | * **rnn\_type**: Candidates: "RNN", "LSTM", "GRU". 112 | * **num\_layers**: number of layers. 113 | * **doc\_embedding\_type**: Candidates: "AVG", "Attention", "LastHidden". 114 | * **attention\_dimension**: dimension of self-attention. 115 | * **bidirectional**: Boolean, use Bi-RNNs. 116 | 117 | ### RCNN 118 | 119 | see TextCNN and TextRNN 120 | 121 | ### DRNN 122 | 123 | * **hidden\_dimension**: dimension of hidden layer. 124 | * **window\_size**: window size. 125 | * **rnn\_type**: Candidates: "RNN", "LSTM", "GRU". 126 | * **bidirectional**: Boolean. 
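  (For example, with **window\_size** 3 the forward RNN encodes position i from tokens i-2..i; when **bidirectional** is true, a backward RNN also encodes it from tokens i..i+2 and the two states are concatenated, doubling the effective **hidden\_dimension**; see `drnn.py` above.)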
127 | * **cell\_hidden\_dropout** 128 | 129 | ### VDCNN 130 | 131 | * **vdcnn\_depth**: depth of VDCNN. 132 | * **top\_k\_max\_pooling**: top-k max pooling. 133 | 134 | ### DPCNN 135 | 136 | * **kernel\_size**: kernel size. 137 | * **pooling\_stride**: stride of pooling. 138 | * **num\_kernels**: number of kernels. 139 | * **blocks**: number of blocks for DPCNN. 140 | 141 | ### AttentiveConvNet 142 | 143 | * **attention\_type**: Candidates: "dot", "bilinear", "additive_projection". 144 | * **margin\_size**: attentive width, must be odd. 145 | * **type**: Candidates: "light", "advanced". 146 | * **hidden\_size**: size of hidden layer. 147 | 148 | ### Transformer 149 | 150 | * **d\_inner**: dimension of inner nodes. 151 | * **d\_k**: dimension of key. 152 | * **d\_v**: dimension of value. 153 | * **n\_head**: number of heads. 154 | * **n\_layers**: number of layers. 155 | * **dropout** 156 | * **use\_star**: whether to use Star-Transformer, see [Star-Transformer](https://arxiv.org/pdf/1902.09113v2.pdf "Star-Transformer") 157 | -------------------------------------------------------------------------------- /model/classification/textvdcnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for the specific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | # Implement model of "Very deep convolutional networks for text classification" 16 | # which can be seen at "http://www.aclweb.org/anthology/E17-1104" 17 | 18 | import torch 19 | 20 | import numpy as np 21 | 22 | from dataset.classification_dataset import ClassificationDataset as cDataset 23 | from model.classification.classifier import Classifier 24 | 25 | 26 | class TextVDCNN(Classifier): 27 | def __init__(self, dataset, config): 28 | """all convolutional blocks 29 | 4 kinds of conv blocks, whose #feature_map sizes are 64, 128, 256, 512 30 | Depth: 9 17 29 49 31 | ------------------------------ 32 | conv block 512: 2 4 4 6 33 | conv block 256: 2 4 4 10 34 | conv block 128: 2 4 10 16 35 | conv block 64: 2 4 10 16 36 | First conv. 
layer: 1 1 1 1 37 | """ 38 | super(TextVDCNN, self).__init__(dataset, config) 39 | 40 | self.vdcnn_num_convs = {} 41 | self.vdcnn_num_convs[9] = [2, 2, 2, 2] 42 | self.vdcnn_num_convs[17] = [4, 4, 4, 4] 43 | self.vdcnn_num_convs[29] = [10, 10, 4, 4] 44 | self.vdcnn_num_convs[49] = [16, 16, 10, 6] 45 | self.num_kernels = [64, 128, 256, 512] 46 | 47 | self.vdcnn_depth = config.TextVDCNN.vdcnn_depth 48 | self.first_conv = torch.nn.Conv1d(config.embedding.dimension, 64, 3, 49 | padding=2) 50 | last_num_kernel = 64 51 | self.convs = torch.nn.ModuleList() 52 | self.batch_norms = torch.nn.ModuleList() 53 | for i, num_kernel in enumerate(self.num_kernels): 54 | tmp_convs = torch.nn.ModuleList() 55 | tmp_batch_norms = torch.nn.ModuleList() 56 | for _ in range(0, self.vdcnn_num_convs[self.vdcnn_depth][i]): 57 | tmp_convs.append( 58 | torch.nn.Conv1d(last_num_kernel, num_kernel, 3, padding=2)) 59 | tmp_batch_norms.append(torch.nn.BatchNorm1d(num_kernel)) 60 | last_num_kernel = num_kernel 61 | self.convs.append(tmp_convs) 62 | self.batch_norms.append(tmp_batch_norms) 63 | 64 | self.top_k = self.config.TextVDCNN.top_k_max_pooling 65 | hidden_size = self.num_kernels[-1] * self.top_k 66 | self.linear1 = torch.nn.Linear(hidden_size, 2048) 67 | self.linear2 = torch.nn.Linear(2048, 2048) 68 | self.linear = torch.nn.Linear(2048, len(dataset.label_map)) 69 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 70 | 71 | def get_parameter_optimizer_dict(self): 72 | params = list() 73 | params.append({'params': self.token_embedding.parameters()}) 74 | params.append({'params': self.char_embedding.parameters()}) 75 | params.append({'params': self.first_conv.parameters()}) 76 | for i in range(0, len(self.num_kernels)): 77 | params.append({'params': self.convs[i].parameters()}) 78 | params.append({'params': self.batch_norms[i].parameters()}) 79 | params.append({'params': self.linear1.parameters()}) 80 | params.append({'params': self.linear2.parameters()}) 81 | params.append({'params': self.linear3.parameters()}) 82 | return params 83 | 84 | def update_lr(self, optimizer, epoch): 85 | """Update lr 86 | """ 87 | if epoch > self.config.train.num_epochs_static_embedding: 88 | for param_group in optimizer.param_groups[:2]: 89 | param_group["lr"] = self.config.optimizer.learning_rate 90 | else: 91 | for param_group in optimizer.param_groups[:2]: 92 | param_group["lr"] = 0 93 | 94 | def forward(self, batch): 95 | def convolutional_block(inputs, num_layers, convs, batch_norms): 96 | """Convolutional Block of VDCNN 97 | Convolutional block contains 2 conv layers, and can be repeated 98 | Temp Conv-->Batch Norm-->ReLU-->Temp Conv-->Batch Norm-->ReLU 99 | """ 100 | hidden_layer = inputs 101 | for i in range(0, num_layers): 102 | batch_norm = batch_norms[i](convs[i](inputs)) 103 | hidden_layer = torch.nn.functional.relu(batch_norm) 104 | return hidden_layer 105 | 106 | if self.config.feature.feature_names[0] == "token": 107 | embedding = self.token_embedding( 108 | batch[cDataset.DOC_TOKEN].to(self.config.device)) 109 | else: 110 | embedding = self.char_embedding( 111 | batch[cDataset.DOC_CHAR].to(self.config.device)) 112 | embedding = embedding.transpose(1, 2) 113 | 114 | # first conv layer (kernel_size=3, #feature_map=64) 115 | first_conv = self.first_conv(embedding) 116 | first_conv = torch.nn.functional.relu(first_conv) 117 | 118 | # all convolutional blocks 119 | conv_block = first_conv 120 | for i in range(0, len(self.num_kernels)): 121 | conv_block = convolutional_block( 122 | conv_block, 123 | 
num_layers=self.vdcnn_num_convs[self.vdcnn_depth][i], 124 | convs=self.convs[i], 125 | batch_norms=self.batch_norms[i]) 126 | if i < len(self.num_kernels) - 1: 127 | # max-pooling with stride=2 128 | pool = torch.nn.functional.max_pool1d(conv_block, 129 | kernel_size=3, stride=2) 130 | else: 131 | # k-max-pooling 132 | pool = torch.topk(conv_block, self.top_k)[0].view( 133 | conv_block.size(0), -1) 134 | 135 | pool_shape = int(np.prod(pool.size()[1:])) 136 | doc_embedding = torch.reshape(pool, (-1, pool_shape)) 137 | fc1 = self.linear1(doc_embedding) 138 | fc2 = self.linear2(fc1) 139 | return self.dropout(self.linear(fc2)) 140 | -------------------------------------------------------------------------------- /model/model_util.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import codecs as cs 16 | import torch 17 | 18 | from util import Type 19 | from model.optimizer import BertAdam 20 | 21 | class ActivationType(Type): 22 | """Standard names for activation 23 | """ 24 | SIGMOID = 'sigmoid' 25 | TANH = "tanh" 26 | RELU = 'relu' 27 | LEAKY_RELU = 'leaky_relu' 28 | NONE = 'linear' 29 | 30 | @classmethod 31 | def str(cls): 32 | return ",".join( 33 | [cls.SIGMOID, cls.TANH, cls.RELU, cls.LEAKY_RELU, cls.NONE]) 34 | 35 | 36 | class InitType(Type): 37 | """Standard names for init 38 | """ 39 | UNIFORM = 'uniform' 40 | NORMAL = "normal" 41 | XAVIER_UNIFORM = 'xavier_uniform' 42 | XAVIER_NORMAL = 'xavier_normal' 43 | KAIMING_UNIFORM = 'kaiming_uniform' 44 | KAIMING_NORMAL = 'kaiming_normal' 45 | ORTHOGONAL = 'orthogonal' 46 | 47 | def str(self): 48 | return ",".join( 49 | [self.UNIFORM, self.NORMAL, self.XAVIER_UNIFORM, self.XAVIER_NORMAL, 50 | self.KAIMING_UNIFORM, self.KAIMING_NORMAL, self.ORTHOGONAL]) 51 | 52 | 53 | class FAN_MODE(Type): 54 | """Standard names for fan mode 55 | """ 56 | FAN_IN = 'FAN_IN' 57 | FAN_OUT = "FAN_OUT" 58 | 59 | def str(self): 60 | return ",".join([self.FAN_IN, self.FAN_OUT]) 61 | 62 | 63 | def init_tensor(tensor, init_type=InitType.XAVIER_UNIFORM, low=0, high=1, 64 | mean=0, std=1, activation_type=ActivationType.NONE, 65 | fan_mode=FAN_MODE.FAN_IN, negative_slope=0): 66 | """Init torch.Tensor 67 | Args: 68 | tensor: Tensor to be initialized. 69 | init_type: Init type, candidate can be found in InitType. 70 | low: The lower bound of the uniform distribution, 71 | useful when init_type is uniform. 72 | high: The upper bound of the uniform distribution, 73 | useful when init_type is uniform. 74 | mean: The mean of the normal distribution, 75 | useful when init_type is normal. 76 | std: The standard deviation of the normal distribution, 77 | useful when init_type is normal. 
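(With the conf files above, the embedding layers take the uniform branch with low = -uniform_bound and high = uniform_bound, i.e. U(-0.25, 0.25).)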
78 | activation_type: For xavier and kaiming init, 79 | coefficient is calculate according the activation_type. 80 | fan_mode: For kaiming init, fan mode is needed 81 | negative_slope: For kaiming init, 82 | coefficient is calculate according the negative_slope. 83 | Returns: 84 | """ 85 | if init_type == InitType.UNIFORM: 86 | return torch.nn.init.uniform_(tensor, a=low, b=high) 87 | elif init_type == InitType.NORMAL: 88 | return torch.nn.init.normal_(tensor, mean=mean, std=std) 89 | elif init_type == InitType.XAVIER_UNIFORM: 90 | return torch.nn.init.xavier_uniform_( 91 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 92 | elif init_type == InitType.XAVIER_NORMAL: 93 | return torch.nn.init.xavier_normal_( 94 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 95 | elif init_type == InitType.KAIMING_UNIFORM: 96 | return torch.nn.init.kaiming_uniform_( 97 | tensor, a=negative_slope, mode=fan_mode, 98 | nonlinearity=activation_type) 99 | elif init_type == InitType.KAIMING_NORMAL: 100 | return torch.nn.init.kaiming_normal_( 101 | tensor, a=negative_slope, mode=fan_mode, 102 | nonlinearity=activation_type) 103 | elif init_type == InitType.ORTHOGONAL: 104 | return torch.nn.init.orthogonal_( 105 | tensor, gain=torch.nn.init.calculate_gain(activation_type)) 106 | else: 107 | raise TypeError( 108 | "Unsupported tensor init type: %s. Supported init type is: %s" % ( 109 | init_type, InitType.str())) 110 | 111 | 112 | class OptimizerType(Type): 113 | """Standard names for optimizer 114 | """ 115 | ADAM = "Adam" 116 | ADADELTA = "Adadelta" 117 | BERT_ADAM = "BERTAdam" 118 | 119 | def str(self): 120 | return ",".join([self.ADAM, self.ADADELTA]) 121 | 122 | 123 | def get_optimizer(config, params): 124 | params = params.get_parameter_optimizer_dict() 125 | if config.optimizer.optimizer_type == OptimizerType.ADAM: 126 | return torch.optim.Adam(lr=config.optimizer.learning_rate, 127 | params=params) 128 | elif config.optimizer.optimizer_type == OptimizerType.ADADELTA: 129 | return torch.optim.Adadelta( 130 | lr=config.optimizer.learning_rate, 131 | rho=config.optimizer.adadelta_decay_rate, 132 | eps=config.optimizer.adadelta_epsilon, 133 | params=params) 134 | elif config.optimizer.optimizer_type == OptimizerType.BERT_ADAM: 135 | return BertAdam(params, 136 | lr=config.optimizer.learning_rate, 137 | weight_decay=0, max_grad_norm=-1) 138 | else: 139 | raise TypeError( 140 | "Unsupported tensor optimizer type: %s.Supported optimizer " 141 | "type is: %s" % (config.optimizer_type, OptimizerType.str())) 142 | 143 | 144 | def get_hierar_relations(hierar_taxonomy, label_map): 145 | """ get parent-children relationships from given hierar_taxonomy 146 | hierar_taxonomy: parent_label \t child_label_0 \t child_label_1 \n 147 | """ 148 | hierar_relations = {} 149 | with cs.open(hierar_taxonomy, "r", "utf8") as f: 150 | for line in f: 151 | line_split = line.strip("\n").split("\t") 152 | parent_label, children_label = line_split[0], line_split[1:] 153 | if parent_label not in label_map: 154 | continue 155 | parent_label_id = label_map[parent_label] 156 | children_label_ids = [label_map[child_label] \ 157 | for child_label in children_label if child_label in label_map] 158 | hierar_relations[parent_label_id] = children_label_ids 159 | return hierar_relations 160 | -------------------------------------------------------------------------------- /model/classification/classifier.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # 
coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import torch 16 | 17 | from dataset.classification_dataset import ClassificationDataset as cDataset 18 | from model.embedding import Embedding 19 | from model.embedding import EmbeddingProcessType 20 | from model.embedding import EmbeddingType 21 | from model.embedding import RegionEmbeddingLayer 22 | from model.model_util import ActivationType 23 | 24 | 25 | class Classifier(torch.nn.Module): 26 | def __init__(self, dataset, config): 27 | super(Classifier, self).__init__() 28 | self.config = config 29 | assert len(self.config.feature.feature_names) == 1 30 | assert self.config.feature.feature_names[0] == "token" or \ 31 | self.config.feature.feature_names[0] == "char" 32 | if config.embedding.type == EmbeddingType.EMBEDDING: 33 | self.token_embedding = \ 34 | Embedding(dataset.token_map, config.embedding.dimension, 35 | cDataset.DOC_TOKEN, config, dataset.VOCAB_PADDING, 36 | pretrained_embedding_file= 37 | config.feature.token_pretrained_file, 38 | mode=EmbeddingProcessType.FLAT, 39 | dropout=self.config.embedding.dropout, 40 | init_type=self.config.embedding.initializer, 41 | low=-self.config.embedding.uniform_bound, 42 | high=self.config.embedding.uniform_bound, 43 | std=self.config.embedding.random_stddev, 44 | fan_mode=self.config.embedding.fan_mode, 45 | activation_type=ActivationType.NONE) 46 | self.char_embedding = \ 47 | Embedding(dataset.char_map, config.embedding.dimension, 48 | cDataset.DOC_CHAR, config, dataset.VOCAB_PADDING, 49 | mode=EmbeddingProcessType.FLAT, 50 | dropout=self.config.embedding.dropout, 51 | init_type=self.config.embedding.initializer, 52 | low=-self.config.embedding.uniform_bound, 53 | high=self.config.embedding.uniform_bound, 54 | std=self.config.embedding.random_stddev, 55 | fan_mode=self.config.embedding.fan_mode, 56 | activation_type=ActivationType.NONE) 57 | elif config.embedding.type == EmbeddingType.REGION_EMBEDDING: 58 | self.token_embedding = RegionEmbeddingLayer( 59 | dataset.token_map, config.embedding.dimension, 60 | config.embedding.region_size, cDataset.DOC_TOKEN, config, 61 | padding=dataset.VOCAB_PADDING, 62 | pretrained_embedding_file= 63 | config.feature.token_pretrained_file, 64 | dropout=self.config.embedding.dropout, 65 | init_type=self.config.embedding.initializer, 66 | low=-self.config.embedding.uniform_bound, 67 | high=self.config.embedding.uniform_bound, 68 | std=self.config.embedding.random_stddev, 69 | fan_mode=self.config.embedding.fan_mode, 70 | region_embedding_type=config.embedding.region_embedding_type) 71 | 72 | self.char_embedding = RegionEmbeddingLayer( 73 | dataset.char_map, config.embedding.dimension, 74 | config.embedding.region_size, cDataset.DOC_CHAR, config, 75 | padding=dataset.VOCAB_PADDING, 76 | dropout=self.config.embedding.dropout, 77 | init_type=self.config.embedding.initializer, 78 
| low=-self.config.embedding.uniform_bound, 79 | high=self.config.embedding.uniform_bound, 80 | std=self.config.embedding.random_stddev, 81 | fan_mode=self.config.embedding.fan_mode, 82 | region_embedding_type=config.embedding.region_embedding_type) 83 | else: 84 | raise TypeError( 85 | "Unsupported embedding type: %s. " % config.embedding.type) 86 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 87 | 88 | def get_embedding(self, batch, pad_shape=None, pad_value=0): 89 | if self.config.feature.feature_names[0] == "token": 90 | token_id = batch[cDataset.DOC_TOKEN].to(self.config.device) 91 | if pad_shape is not None: 92 | token_id = torch.nn.functional.pad( 93 | token_id, pad_shape, mode='constant', value=pad_value) 94 | #embedding = self.token_embedding(token_id) 95 | embedding = token_id 96 | length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 97 | mask = batch[cDataset.DOC_TOKEN_MASK].to(self.config.device) 98 | else: 99 | char_id = batch[cDataset.DOC_TOKEN].to(self.config.device) 100 | if pad_shape is not None: 101 | char_id = torch.nn.functional.pad( 102 | char_id, pad_shape, mode='constant', value=pad_value) 103 | embedding = self.token_embedding(char_id) 104 | length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device) 105 | mask = batch[cDataset.DOC_CHAR_MASK].to(self.config.device) 106 | return embedding, length, mask 107 | 108 | def get_parameter_optimizer_dict(self): 109 | params = list() 110 | params.append( 111 | {'params': self.token_embedding.parameters(), 'is_embedding': True}) 112 | params.append( 113 | {'params': self.char_embedding.parameters(), 'is_embedding': True}) 114 | return params 115 | 116 | def update_lr(self, optimizer, epoch): 117 | """Update lr 118 | """ 119 | if epoch > self.config.train.num_epochs_static_embedding: 120 | for param_group in optimizer.param_groups[:2]: 121 | param_group["lr"] = self.config.optimizer.learning_rate 122 | else: 123 | for param_group in optimizer.param_groups[:2]: 124 | param_group["lr"] = 0 125 | 126 | def forward(self, batch): 127 | raise NotImplementedError 128 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from util import Type 19 | 20 | 21 | class LossType(Type): 22 | """Standard names for loss type 23 | """ 24 | SOFTMAX_CROSS_ENTROPY = "SoftmaxCrossEntropy" 25 | SOFTMAX_FOCAL_CROSS_ENTROPY = "SoftmaxFocalCrossEntropy" 26 | SIGMOID_FOCAL_CROSS_ENTROPY = "SigmoidFocalCrossEntropy" 27 | BCE_WITH_LOGITS = "BCEWithLogitsLoss" 28 | 29 | @classmethod 30 | def str(cls): 31 | return ",".join([cls.SOFTMAX_CROSS_ENTROPY, 32 | cls.SOFTMAX_FOCAL_CROSS_ENTROPY, 33 | cls.SIGMOID_FOCAL_CROSS_ENTROPY, 34 | cls.BCE_WITH_LOGITS]) 35 | 36 | 37 | class ActivationType(Type): 38 | """Standard names for activation type 39 | """ 40 | SOFTMAX = "Softmax" 41 | SIGMOID = "Sigmoid" 42 | 43 | @classmethod 44 | def str(cls): 45 | return ",".join([cls.SOFTMAX, 46 | cls.SIGMOID]) 47 | 48 | 49 | class FocalLoss(nn.Module): 50 | """Softmax focal loss 51 | references: Focal Loss for Dense Object Detection 52 | https://github.com/Hsuxu/FocalLoss-PyTorch 53 | """ 54 | 55 | def __init__(self, label_size, activation_type=ActivationType.SOFTMAX, 56 | gamma=2.0, alpha=0.25, epsilon=1.e-9): 57 | super(FocalLoss, self).__init__() 58 | self.num_cls = label_size 59 | self.activation_type = activation_type 60 | self.gamma = gamma 61 | self.alpha = alpha 62 | self.epsilon = epsilon 63 | 64 | def forward(self, logits, target): 65 | """ 66 | Args: 67 | logits: model's output, shape of [batch_size, num_cls] 68 | target: ground truth labels, shape of [batch_size] 69 | Returns: 70 | shape of [batch_size] 71 | """ 72 | if self.activation_type == ActivationType.SOFTMAX: 73 | idx = target.view(-1, 1).long() 74 | one_hot_key = torch.zeros(idx.size(0), self.num_cls, 75 | dtype=torch.float, 76 | device=idx.device) 77 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 78 | logits = torch.softmax(logits, dim=-1) 79 | loss = -self.alpha * one_hot_key * \ 80 | torch.pow((1 - logits), self.gamma) * \ 81 | (logits + self.epsilon).log() 82 | loss = loss.sum(1) 83 | elif self.activation_type == ActivationType.SIGMOID: 84 | multi_hot_key = target 85 | logits = torch.sigmoid(logits) 86 | zero_hot_key = 1 - multi_hot_key 87 | loss = -self.alpha * multi_hot_key * \ 88 | torch.pow((1 - logits), self.gamma) * \ 89 | (logits + self.epsilon).log() 90 | loss += -(1 - self.alpha) * zero_hot_key * \ 91 | torch.pow(logits, self.gamma) * \ 92 | (1 - logits + self.epsilon).log() 93 | else: 94 | raise TypeError("Unknown activation type: " + self.activation_type 95 | + "Supported activation types: " + 96 | ActivationType.str()) 97 | return loss.mean() 98 | 99 | 100 | class ClassificationLoss(torch.nn.Module): 101 | def __init__(self, label_size, class_weight=None, 102 | loss_type=LossType.SOFTMAX_CROSS_ENTROPY): 103 | super(ClassificationLoss, self).__init__() 104 | self.label_size = label_size 105 | self.loss_type = loss_type 106 | if loss_type == LossType.SOFTMAX_CROSS_ENTROPY: 107 | self.criterion = torch.nn.CrossEntropyLoss(class_weight) 108 | elif loss_type == LossType.SOFTMAX_FOCAL_CROSS_ENTROPY: 109 | self.criterion = FocalLoss(label_size, ActivationType.SOFTMAX) 110 | elif loss_type == LossType.SIGMOID_FOCAL_CROSS_ENTROPY: 111 | self.criterion = FocalLoss(label_size, ActivationType.SIGMOID) 112 | elif loss_type == LossType.BCE_WITH_LOGITS: 113 | self.criterion = torch.nn.BCEWithLogitsLoss() 114 | else: 115 | raise TypeError( 116 | "Unsupported loss type: %s. 
Supported loss type is: %s" % ( 117 | loss_type, LossType.str())) 118 | 119 | def forward(self, logits, target, 120 | use_hierar=False, 121 | is_multi=False, 122 | *argvs): 123 | if use_hierar: 124 | assert self.loss_type in [LossType.BCE_WITH_LOGITS, 125 | LossType.SIGMOID_FOCAL_CROSS_ENTROPY] 126 | device = logits.device 127 | if not is_multi: 128 | target = torch.eye(self.label_size)[target].to(device) 129 | hierar_penalty, hierar_paras, hierar_relations = argvs[0:3] 130 | return self.criterion(logits, target) + \ 131 | hierar_penalty * self.cal_recursive_regularize(hierar_paras, 132 | hierar_relations, 133 | device) 134 | else: 135 | return self.criterion(logits, target) 136 | 137 | def cal_recursive_regularize(self, paras, hierar_relations, device="cpu"): 138 | """ Only support hierarchical text classification with BCELoss 139 | references: http://www.cse.ust.hk/~yqsong/papers/2018-WWW-Text-GraphCNN.pdf 140 | http://www.cs.cmu.edu/~sgopal1/papers/KDD13.pdf 141 | """ 142 | recursive_loss = 0.0 143 | for i in range(len(paras)): 144 | if i not in hierar_relations: 145 | continue 146 | children_ids = hierar_relations[i] 147 | if not children_ids: 148 | continue 149 | children_ids_list = torch.tensor(children_ids, dtype=torch.long).to( 150 | device) 151 | children_paras = torch.index_select(paras, 0, children_ids_list) 152 | parent_para = torch.index_select(paras, 0, 153 | torch.tensor(i).to(device)) 154 | parent_para = parent_para.repeat(children_ids_list.size()[0], 1) 155 | diff_paras = parent_para - children_paras 156 | diff_paras = diff_paras.view(diff_paras.size()[0], -1) 157 | recursive_loss += 1.0 / 2 * torch.norm(diff_paras, p=2) ** 2 158 | return recursive_loss 159 | -------------------------------------------------------------------------------- /model/classification/fasttext.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
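The FastText model below sums EmbeddingBag lookups for tokens (plus optional
token n-grams, keywords and topics), divides by the total feature length, and
feeds the averaged document vector through dropout to a single linear layer.

A stripped-down sketch of that averaging step with toy sizes (this is not the
model's actual API, which takes a dataset and a config object):

    import torch
    emb = torch.nn.EmbeddingBag(100, 16, mode="sum")
    linear = torch.nn.Linear(16, 3)
    token_ids = torch.tensor([4, 7, 9, 2, 5])  # two documents, flattened
    offsets = torch.tensor([0, 3])             # start index of each document
    lengths = torch.tensor([[3.0], [2.0]])     # token count per document
    doc_embedding = emb(token_ids, offsets) / lengths
    logits = linear(doc_embedding)             # [2, num_labels]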
13 | """ 14 | 15 | import torch 16 | 17 | from dataset.classification_dataset import ClassificationDataset as cDataset 18 | from model.embedding import Embedding 19 | from model.embedding import EmbeddingProcessType 20 | from model.model_util import ActivationType 21 | 22 | 23 | class FastText(torch.nn.Module): 24 | """Implement fasttext classification method 25 | Reference: "Bag of Tricks for Efficient Text Classification" 26 | """ 27 | 28 | def __init__(self, dataset, config): 29 | super(FastText, self).__init__() 30 | self.config = config 31 | assert "token" in self.config.feature.feature_names 32 | self.token_embedding = \ 33 | Embedding(dataset.token_map, 34 | config.embedding.dimension, 35 | cDataset.DOC_TOKEN, config, 36 | padding_idx=dataset.VOCAB_PADDING, 37 | pretrained_embedding_file= 38 | config.feature.token_pretrained_file, 39 | mode=EmbeddingProcessType.SUM, dropout=0, 40 | init_type=config.embedding.initializer, 41 | low=-config.embedding.uniform_bound, 42 | high=config.embedding.uniform_bound, 43 | std=config.embedding.random_stddev, 44 | activation_type=ActivationType.NONE) 45 | if self.config.feature.token_ngram > 1: 46 | self.token_ngram_embedding = \ 47 | Embedding(dataset.token_ngram_map, 48 | config.embedding.dimension, 49 | cDataset.DOC_TOKEN_NGRAM, config, 50 | padding_idx=dataset.VOCAB_PADDING, 51 | mode=EmbeddingProcessType.SUM, dropout=0, 52 | init_type=config.embedding.initializer, 53 | low=-config.embedding.uniform_bound, 54 | high=config.embedding.uniform_bound, 55 | std=config.embedding.random_stddev, 56 | activation_type=ActivationType.NONE) 57 | if "keyword" in self.config.feature.feature_names: 58 | self.keyword_embedding = \ 59 | Embedding(dataset.keyword_map, 60 | config.embedding.dimension, 61 | cDataset.DOC_KEYWORD, config, 62 | padding_idx=dataset.VOCAB_PADDING, 63 | pretrained_embedding_file= 64 | config.feature.keyword_pretrained_file, 65 | mode=EmbeddingProcessType.SUM, dropout=0, 66 | init_type=config.embedding.initializer, 67 | low=-config.embedding.uniform_bound, 68 | high=config.embedding.uniform_bound, 69 | std=config.embedding.random_stddev, 70 | activation_type=ActivationType.NONE) 71 | if "topic" in self.config.feature.feature_names: 72 | self.topic_embedding = \ 73 | Embedding(dataset.topic_map, 74 | config.embedding.dimension, 75 | cDataset.DOC_TOPIC, config, 76 | padding_idx=dataset.VOCAB_PADDING, 77 | mode=EmbeddingProcessType.SUM, dropout=0, 78 | init_type=config.embedding.initializer, 79 | low=-config.embedding.uniform_bound, 80 | high=config.embedding.uniform_bound, 81 | std=config.embedding.random_stddev, 82 | activation_type=ActivationType.NONE) 83 | self.linear = torch.nn.Linear( 84 | config.embedding.dimension, len(dataset.label_map)) 85 | self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout) 86 | 87 | def get_parameter_optimizer_dict(self): 88 | params = list() 89 | params.append({'params': self.token_embedding.parameters()}) 90 | if self.config.feature.token_ngram > 1: 91 | params.append({'params': self.token_ngram_embedding.parameters()}) 92 | if "keyword" in self.config.feature.feature_names: 93 | params.append({'params': self.keyword_embedding.parameters()}) 94 | if "topic" in self.config.feature.feature_names: 95 | params.append({'params': self.topic_embedding.parameters()}) 96 | params.append({'params': self.linear.parameters()}) 97 | return params 98 | 99 | def update_lr(self, optimizer, epoch): 100 | """Update lr 101 | """ 102 | if epoch > self.config.train.num_epochs_static_embedding: 103 | for 
param_group in optimizer.param_groups: 104 | param_group["lr"] = self.config.optimizer.learning_rate 105 | else: 106 | for param_group in optimizer.param_groups: 107 | param_group["lr"] = 0 108 | 109 | def forward(self, batch): 110 | doc_embedding = self.token_embedding( 111 | batch[cDataset.DOC_TOKEN].to(self.config.device), 112 | batch[cDataset.DOC_TOKEN_OFFSET].to(self.config.device)) 113 | length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device) 114 | if self.config.feature.token_ngram > 1: 115 | doc_embedding += self.token_ngram_embedding( 116 | batch[cDataset.DOC_TOKEN_NGRAM].to(self.config.device), 117 | batch[cDataset.DOC_TOKEN_NGRAM_OFFSET].to(self.config.device)) 118 | length += batch[cDataset.DOC_TOKEN_NGRAM_LEN].to(self.config.device) 119 | if "keyword" in self.config.feature.feature_names: 120 | doc_embedding += self.keyword_embedding( 121 | batch[cDataset.DOC_KEYWORD].to(self.config.device), 122 | batch[cDataset.DOC_KEYWORD_OFFSET].to(self.config.device)) 123 | length += batch[cDataset.DOC_KEYWORD_LEN].to(self.config.device) 124 | if "topic" in self.config.feature.feature_names: 125 | doc_embedding += self.topic_embedding( 126 | batch[cDataset.DOC_TOPIC].to(self.config.device), 127 | batch[cDataset.DOC_TOPIC_OFFSET].to(self.config.device)) 128 | length += batch[cDataset.DOC_TOPIC_LEN].to(self.config.device) 129 | 130 | doc_embedding /= length.resize_(doc_embedding.size()[0], 1) 131 | doc_embedding = self.dropout(doc_embedding) 132 | return self.linear(doc_embedding) 133 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
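This module collects the attention building blocks shared by the classifiers:
SumAttention, additive (Bahdanau-style) attention in 1D and 2D variants,
scaled dot-product attention, multi-head attention, and a Highway layer.

For reference, standard scaled dot-product attention written out with plain
tensors (shapes are illustrative assumptions):

    import math
    import torch
    q = torch.randn(2, 5, 8)                   # [batch, seq_len, dim]
    k, v = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
    scores = q.matmul(k.transpose(1, 2)) / math.sqrt(q.size(-1))
    weights = torch.softmax(scores, dim=-1)
    context = weights.matmul(v)                # [batch, seq_len, dim]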
13 | """ 14 | 15 | import math 16 | 17 | import torch 18 | 19 | from model.model_util import init_tensor 20 | 21 | 22 | class SumAttention(torch.nn.Module): 23 | """ 24 | Reference: Hierarchical Attention Networks for Document Classification 25 | """ 26 | 27 | def __init__(self, input_dimension, attention_dimension, device, dropout=0): 28 | super(SumAttention, self).__init__() 29 | self.attention_matrix = \ 30 | init_tensor(torch.empty(input_dimension, attention_dimension)).to(device) 31 | self.bias = torch.zeros(attention_dimension).to(device) 32 | self.attention_vector = init_tensor(torch.empty(attention_dimension, 1)).to(device) 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | 35 | def forward(self, inputs): 36 | if inputs.size(1) == 1: 37 | return self.dropout(inputs.squeeze()) 38 | u = torch.tanh(torch.matmul(inputs, self.attention_matrix) + self.bias) 39 | v = torch.matmul(u, self.attention_vector) 40 | alpha = torch.nn.functional.softmax(v, 1).squeeze().unsqueeze(1) 41 | return self.dropout(torch.matmul(alpha, inputs).squeeze()) 42 | 43 | 44 | class AdditiveAttention(torch.nn.Module): 45 | """Also known as Soft Attention or Bahdanau Attention 46 | Reference: 47 | Neural machine translation by jointly learning to align and translate 48 | """ 49 | 50 | def __init__(self, dim, dropout=0): 51 | super(AdditiveAttention, self).__init__() 52 | self.w_attention_matrix = init_tensor(torch.empty(dim, dim)) 53 | self.u_attention_matrix = init_tensor(torch.empty(dim, dim)) 54 | self.v_attention_vector = init_tensor(torch.empty(dim, 1)) 55 | 56 | self.dropout = torch.nn.Dropout(p=dropout) 57 | 58 | def forward(self, s, h): 59 | raise NotImplementedError 60 | 61 | 62 | class AdditiveAttention1D(AdditiveAttention): 63 | """ 64 | Input shape is: [batch, dim] and [batch, seq_len, dim] 65 | Output is same with the first input 66 | """ 67 | 68 | def forward(self, s, h): 69 | s_attention = s.matmul(self.w_attention_matrix).unsqueeze(1) 70 | h_attention = h.matmul(self.u_attention_matrix) 71 | attention = torch.tanh(s_attention + h_attention) 72 | attention = attention.matmul(self.v_attention_vector).squeeze() 73 | attention_weight = torch.nn.functional.softmax(attention, -1) 74 | return self.dropout(attention_weight.unsqueeze(1).matmul(h).squeeze()) 75 | 76 | 77 | class AdditiveAttention2D(AdditiveAttention): 78 | """ 79 | Input shape is: [batch, seq_len, dim] and [batch, seq_len, dim] 80 | Output is same with the first input 81 | """ 82 | 83 | def forward(self, s, h): 84 | s_attention = s.matmul(self.w_attention_matrix).unsqueeze(2) 85 | h_attention = h.matmul(self.u_attention_matrix).unsqueeze(1) 86 | seq_len = h.size(1) 87 | h_attention = h_attention.expand(-1, seq_len, -1, -1) 88 | attention = torch.nn.functional.tanh(s_attention + h_attention) 89 | attention = attention.matmul(self.v_attention_vector).squeeze() 90 | attention_weight = torch.nn.functional.softmax(attention, -1) 91 | return self.dropout(attention_weight.unsqueeze(2).matmul(h).squeeze()) 92 | 93 | 94 | class DotProductAttention(torch.nn.Module): 95 | """ 96 | Reference: Attention is all you need 97 | Input shape is: [batch, seq_len, dim_k] and [batch, seq_len, dim_k] 98 | [batch, seq_len, dim_v] 99 | Output is same with the third input 100 | """ 101 | 102 | def __init__(self, scaling_factor=None, dropout=0): 103 | super(DotProductAttention, self).__init__() 104 | self.scaling_factor = scaling_factor 105 | self.dropout = torch.nn.Dropout(p=dropout) 106 | 107 | def forward(self, q, k, v): 108 | if self.scaling_factor is None: 109 | 
self.scaling_factor = 1 / math.sqrt(q.size(2)) 110 | e = q.matmul(k.permute(0, 2, 1)) / self.scaling_factor 111 | attention_weight = torch.nn.functional.softmax(e, -1) 112 | return self.dropout(attention_weight.matmul(v)) 113 | 114 | 115 | class MultiHeadAttention(torch.nn.Module): 116 | """ 117 | Reference: Attention is all you need 118 | """ 119 | 120 | def __init__(self, dimension, dk, dv, head_number, 121 | scaling_factor, dropout=0): 122 | super(MultiHeadAttention, self).__init__() 123 | self.dk = dk 124 | self.dv = dv 125 | self.head_number = head_number 126 | self.q_linear = torch.nn.Linear(dimension, head_number * dk) 127 | self.k_linear = torch.nn.Linear(dimension, head_number * dk) 128 | self.v_linear = torch.nn.Linear(dimension, head_number * dv) 129 | self.scaling_factor = scaling_factor 130 | self.dropout = torch.nn.Dropout(p=dropout) 131 | 132 | def forward(self, q, k, v): 133 | def _reshape_permute(x, d, head_number): 134 | x = x.view(x.size(0), x.size(1), head_number, d) 135 | return x.permute(0, 2, 1, 3) 136 | 137 | q_trans = _reshape_permute(self.q_linear(q), self.dk, self.head_number) 138 | k_trans = _reshape_permute(self.k_linear(k), self.dk, self.head_number) 139 | v_trans = _reshape_permute(self.v_linear(v), self.dv, self.head_number) 140 | 141 | e = q_trans.matmul(k_trans.permute(0, 1, 3, 2)) / self.scaling_factor 142 | attention_weight = torch.nn.functional.softmax(e, -1) 143 | output = attention_weight.matmul(v_trans).permute(0, 2, 1, 3) 144 | output = output.view(output.size(0), output.size(1), 145 | output.size(2) * output.size(3)) 146 | return self.dropout(output) 147 | 148 | 149 | class Highway(torch.nn.Module): 150 | """ 151 | Reference: Highway Networks. 152 | For now we don't limit the type of the gate and forward. 153 | Caller should init Highway with transformer and carry and guarantee the dim 154 | to be matching. 155 | """ 156 | 157 | def __init__(self, transformer_gate, transformer_forward): 158 | super(Highway, self).__init__() 159 | self.transformer_forward = transformer_forward 160 | self.transformer_gate = transformer_gate 161 | 162 | def forward(self, x, gate_input=None, forward_input=None): 163 | if gate_input is None: 164 | gate_input = x 165 | if forward_input is None: 166 | forward_input = x 167 | gate = self.transformer_gate(gate_input) 168 | forward = self.transformer_forward(forward_input) 169 | return gate * forward + (1 - gate) * x 170 | -------------------------------------------------------------------------------- /dataset/classification_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
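ClassificationDataset (defined below) reads one JSON object per line and looks
up the fields named by its DOC_* constants. An illustrative input line (the
values are invented; "--" separates levels of a hierarchical label):

    {"doc_label": ["Sports--Football"], "doc_token": ["the", "match", "ended"],
     "doc_keyword": ["football"], "doc_topic": ["sports"]}

Token sequences are additionally converted to BERT wordpiece ids via
get_sentence_bert_id before batching.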
13 | """ 14 | 15 | from dataset.dataset import DatasetBase 16 | from dataset.dataset import InsertVocabMode 17 | from util import ModeType 18 | import torch 19 | from transformers import * 20 | 21 | class ClassificationDataset(DatasetBase): 22 | CLASSIFICATION_LABEL_SEPARATOR = "--" 23 | DOC_LABEL = "doc_label" 24 | DOC_LABEL_LIST = "doc_label_list" 25 | 26 | DOC_TOKEN = "doc_token" 27 | DOC_CHAR = "doc_char" 28 | DOC_CHAR_IN_TOKEN = "doc_char_in_token" 29 | DOC_TOKEN_NGRAM = "doc_token_ngram" 30 | DOC_KEYWORD = "doc_keyword" 31 | DOC_TOPIC = "doc_topic" 32 | 33 | DOC_TOKEN_OFFSET = "doc_token_offset" 34 | DOC_TOKEN_NGRAM_OFFSET = "doc_token_ngram_offset" 35 | DOC_KEYWORD_OFFSET = "doc_keyword_offset" 36 | DOC_TOPIC_OFFSET = "doc_topic_offset" 37 | 38 | DOC_TOKEN_LEN = "doc_token_len" 39 | DOC_CHAR_LEN = "doc_char_len" 40 | DOC_CHAR_IN_TOKEN_LEN = "doc_char_in_token_len" 41 | DOC_TOKEN_NGRAM_LEN = "doc_token_ngram_len" 42 | DOC_KEYWORD_LEN = "doc_keyword_len" 43 | DOC_TOPIC_LEN = "doc_topic_len" 44 | 45 | DOC_TOKEN_MASK = "doc_token_mask" 46 | DOC_CHAR_MASK = "doc_char_mask" 47 | DOC_CHAR_IN_TOKEN_MASK = "doc_char_in_token_mask" 48 | 49 | DOC_TOKEN_MAX_LEN = "doc_token_max_len" 50 | DOC_CHAR_MAX_LEN = "doc_char_max_len" 51 | DOC_CHAR_IN_TOKEN_MAX_LEN = "doc_char_in_token_max_len" 52 | 53 | def __init__(self, config, json_files, generate_dict=False, 54 | mode=ModeType.EVAL): 55 | super(ClassificationDataset, self).__init__( 56 | config, json_files, generate_dict=generate_dict, mode=mode) 57 | self.tokenizer = BertTokenizer.from_pretrained('/dockerdata/xiaomingshi/chinese_L-12_H-768_A-12') 58 | 59 | def _init_dict(self): 60 | self.dict_names = \ 61 | [self.DOC_LABEL, self.DOC_TOKEN, self.DOC_CHAR, 62 | self.DOC_TOKEN_NGRAM, self.DOC_KEYWORD, self.DOC_TOPIC] 63 | 64 | self.dict_files = [] 65 | for dict_name in self.dict_names: 66 | self.dict_files.append( 67 | self.config.data.dict_dir + "/" + dict_name + ".dict") 68 | self.label_dict_file = self.dict_files[0] 69 | 70 | # By default keep all labels 71 | self.min_count = [0, 72 | self.config.feature.min_token_count, 73 | self.config.feature.min_char_count, 74 | self.config.feature.min_token_ngram_count, 75 | self.config.feature.min_keyword_count, 76 | self.config.feature.min_topic_count] 77 | 78 | # By default keep all labels 79 | self.max_dict_size = [self.BIG_VALUE, 80 | self.config.feature.max_token_dict_size, 81 | self.config.feature.max_char_dict_size, 82 | self.config.feature.max_token_ngram_dict_size, 83 | self.config.feature.max_keyword_dict_size, 84 | self.config.feature.max_topic_dict_size] 85 | 86 | self.max_sequence_length = [ 87 | self.config.feature.max_token_len, 88 | self.config.feature.max_char_len] 89 | 90 | # When generating dict, the following map store vocab count. 
91 | # Then clear dict and load vocab of word index 92 | self.label_map = dict() 93 | self.token_map = dict() 94 | self.char_map = dict() 95 | self.token_ngram_map = dict() 96 | self.keyword_map = dict() 97 | self.topic_map = dict() 98 | self.dicts = [self.label_map, self.token_map, self.char_map, 99 | self.token_ngram_map, self.keyword_map, self.topic_map] 100 | 101 | # Save sorted dict according to the count 102 | self.label_count_list = [] 103 | self.token_count_list = [] 104 | self.char_count_list = [] 105 | self.token_ngram_count_list = [] 106 | self.keyword_count_list = [] 107 | self.topic_count_list = [] 108 | self.count_list = [self.label_count_list, self.token_count_list, 109 | self.char_count_list, self.token_ngram_count_list, 110 | self.keyword_count_list, self.topic_count_list] 111 | 112 | self.id_to_label_map = dict() 113 | self.id_to_token_map = dict() 114 | self.id_to_char_map = dict() 115 | self.id_to_token_gram_map = dict() 116 | self.id_to_keyword_map = dict() 117 | self.id_to_topic_map = dict() 118 | self.id_to_vocab_dict_list = [ 119 | self.id_to_label_map, self.id_to_token_map, self.id_to_char_map, 120 | self.id_to_token_gram_map, self.id_to_keyword_map, 121 | self.id_to_topic_map] 122 | 123 | self.pretrained_dict_names = [self.DOC_TOKEN, self.DOC_KEYWORD] 124 | self.pretrained_dict_files = \ 125 | [self.config.feature.token_pretrained_file, 126 | self.config.feature.keyword_pretrained_file] 127 | self.pretrained_min_count = \ 128 | [self.config.feature.min_token_count, 129 | self.config.feature.min_keyword_count] 130 | 131 | def _insert_vocab(self, json_obj, mode=InsertVocabMode.ALL): 132 | """Insert vocab to dict 133 | """ 134 | if mode == InsertVocabMode.ALL or mode == InsertVocabMode.LABEL: 135 | doc_labels = json_obj[self.DOC_LABEL] 136 | self._insert_sequence_vocab(doc_labels, self.label_map) 137 | if mode == InsertVocabMode.ALL or mode == InsertVocabMode.OTHER: 138 | doc_tokens = \ 139 | json_obj[self.DOC_TOKEN][0:self.config.feature.max_token_len] 140 | doc_keywords = json_obj[self.DOC_KEYWORD] 141 | doc_topics = json_obj[self.DOC_TOPIC] 142 | 143 | self._insert_sequence_tokens( 144 | doc_tokens, self.token_map, self.token_ngram_map, self.char_map, 145 | self.config.feature.token_ngram) 146 | self._insert_sequence_vocab(doc_keywords, self.keyword_map) 147 | self._insert_sequence_vocab(doc_topics, self.topic_map) 148 | 149 | def get_sentence_bert_id(self, sentence): 150 | ''' 151 | sentence: str sequence 152 | ''' 153 | sentence = '[CLS] ' + ' '.join(sentence) + ' [SEP]' 154 | token_ids = self.tokenizer.encode(sentence, add_special_tokens=False, max_length=50, truncation=True) 155 | return token_ids 156 | 157 | def _get_vocab_id_list(self, json_obj): 158 | """Use dict to convert all vocabs to ids 159 | """ 160 | doc_labels = json_obj[self.DOC_LABEL] 161 | doc_tokens = \ 162 | json_obj[self.DOC_TOKEN][0:self.config.feature.max_token_len] 163 | doc_keywords = json_obj[self.DOC_KEYWORD] 164 | doc_topics = json_obj[self.DOC_TOPIC] 165 | 166 | token_ids, char_ids, char_in_token_ids, token_ngram_ids = \ 167 | self._token_to_id(doc_tokens, self.token_map, self.char_map, 168 | self.config.feature.token_ngram, 169 | self.token_ngram_map, 170 | self.config.feature.max_char_len, 171 | self.config.feature.max_char_len_per_token) 172 | return {self.DOC_LABEL: self._label_to_id(doc_labels, self.label_map), 173 | #self.DOC_TOKEN: token_ids, 174 | # 修改_get_vocab_id_list函数,使之只对label进行处理,对doc_token获取其对应的embedding 175 | #self.DOC_TOKEN: doc_tokens, 176 | #self.DOC_TOKEN: 
self.get_sentence_bert_embedding(doc_tokens), 177 | self.DOC_TOKEN: self.get_sentence_bert_id(doc_tokens), 178 | self.DOC_CHAR: char_ids, 179 | self.DOC_CHAR_IN_TOKEN: char_in_token_ids, 180 | self.DOC_TOKEN_NGRAM: token_ngram_ids, 181 | self.DOC_KEYWORD: 182 | self._vocab_to_id(doc_keywords, self.keyword_map), 183 | self.DOC_TOPIC: self._vocab_to_id(doc_topics, self.topic_map)} 184 | -------------------------------------------------------------------------------- /model/classification/attentive_convolution.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import torch 16 | 17 | from model.classification.classifier import Classifier 18 | from model.layers import AdditiveAttention2D 19 | from model.layers import DotProductAttention 20 | from model.layers import Highway 21 | from model.model_util import init_tensor 22 | from util import Type 23 | 24 | 25 | class AttentiveConvNetType(Type): 26 | LIGHT = "light" 27 | ADVANCED = "advanced" 28 | 29 | @classmethod 30 | def str(cls): 31 | return ",".join(cls.LIGHT, cls.ADVANCED) 32 | 33 | 34 | class AttentionType(Type): 35 | DOT = "dot" 36 | BILINEAR = "bilinear" 37 | ADDITIVE_PROJECTION = "additive_projection" 38 | 39 | @classmethod 40 | def str(cls): 41 | return ",".join(cls.DOT, cls.BILINEAR, cls.ADDITIVE_PROJECTION) 42 | 43 | 44 | class AttentiveConvNet(Classifier): 45 | """Attentive Convolution: 46 | Equipping CNNs with RNN-style Attention Mechanisms 47 | """ 48 | 49 | def __init__(self, dataset, config): 50 | super(AttentiveConvNet, self).__init__(dataset, config) 51 | self.attentive_conv_net_type = config.AttentiveConvNet.type 52 | self.attention_type = config.AttentiveConvNet.attention_type 53 | self.dim = config.embedding.dimension 54 | self.attention_dim = self.dim 55 | self.margin_size = config.AttentiveConvNet.margin_size 56 | assert self.margin_size % 2 == 1, \ 57 | "AttentiveConvNet margin size should be odd!" 
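# An odd margin_size lets the Conv1d layers below use padding = margin_size // 2
# (self.radius), which keeps the sequence length unchanged after each convolution.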
58 | 59 | self.radius = int(self.margin_size / 2) 60 | if self.attentive_conv_net_type == AttentiveConvNetType.ADVANCED: 61 | self.attention_dim *= 2 62 | self.x_context_highway = self.get_highway(self.dim, 63 | self.margin_size) 64 | 65 | self.x_self_highway = self.get_highway(self.dim, 1) 66 | 67 | self.a_context_highway = self.get_highway(self.dim, 68 | self.margin_size) 69 | self.a_self_highway = self.get_highway(self.dim, 1) 70 | self.beneficiary_highway = self.get_highway(self.dim, 1) 71 | 72 | if self.attention_type == AttentionType.DOT: 73 | self.dot_product_attention = DotProductAttention(1.0) 74 | elif self.attention_type == AttentionType.BILINEAR: 75 | self.bilinear_matrix = init_tensor( 76 | torch.empty(self.attention_dim, self.attention_dim)).to( 77 | config.device) 78 | self.dot_product_attention = DotProductAttention(1.0) 79 | elif self.attention_type == AttentionType.ADDITIVE_PROJECTION: 80 | self.additive_projection = AdditiveAttention2D(self.attention_dim) 81 | else: 82 | raise TypeError( 83 | "Unsupported AttentionType: %s." % self.attention_type) 84 | 85 | self.attentive_conv = init_tensor( 86 | torch.empty(self.attention_dim, self.dim)).to(config.device) 87 | self.x_conv = torch.nn.Sequential( 88 | torch.nn.Conv1d(self.dim, self.dim, self.margin_size, 89 | padding=self.radius), 90 | torch.nn.Tanh()) 91 | self.bias = torch.zeros([self.dim]).to(config.device) 92 | self.hidden_size = config.AttentiveConvNet.hidden_size 93 | self.hidden1_matrix = init_tensor( 94 | torch.empty(self.dim, self.hidden_size)).to(config.device) 95 | self.hidden2_matrix = init_tensor( 96 | torch.empty(self.hidden_size, self.hidden_size)).to(config.device) 97 | self.linear = torch.nn.Linear(self.dim + 2 * self.hidden_size, 98 | len(dataset.label_map)) 99 | 100 | @staticmethod 101 | def get_highway(dimension, margin_size): 102 | radius = int(margin_size / 2) 103 | transformer_gate = torch.nn.Sequential( 104 | torch.nn.Conv1d(dimension, dimension, margin_size, padding=radius), 105 | torch.nn.Sigmoid()) 106 | transformer_forward = torch.nn.Sequential( 107 | torch.nn.Conv1d(dimension, dimension, margin_size, padding=radius), 108 | torch.nn.Tanh()) 109 | return Highway(transformer_gate, transformer_forward) 110 | 111 | def get_parameter_optimizer_dict(self): 112 | params = super(AttentiveConvNet, 113 | self).get_parameter_optimizer_dict() 114 | if self.attentive_conv_net_type == AttentiveConvNetType.ADVANCED: 115 | params.append({'params': self.x_context_highway.parameters()}) 116 | params.append({'params': self.x_self_highway.parameters()}) 117 | params.append({'params': self.a_context_highway.parameters()}) 118 | params.append({'params': self.a_self_highway.parameters()}) 119 | params.append({'params': self.beneficiary_highway.parameters()}) 120 | if self.attention_type == AttentionType.DOT: 121 | params.append({'params': self.dot_product_attention.parameters()}) 122 | elif self.attention_type == AttentionType.BILINEAR: 123 | params.append({'params': self.bilinear_matrix}) 124 | params.append({'params': self.dot_product_attention.parameters()}) 125 | elif self.attention_type == AttentionType.ADDITIVE_PROJECTION: 126 | params.append({'params': self.additive_projection.parameters()}) 127 | 128 | params.append({'params': self.attentive_conv}) 129 | params.append({'params': self.x_conv.parameters()}) 130 | params.append({'params': self.hidden1_matrix}) 131 | params.append({'params': self.hidden2_matrix}) 132 | params.append({'params': self.linear.parameters()}) 133 | 134 | return params 135 | 136 | def 
forward(self, batch): 137 | 138 | embedding, _, _ = self.get_embedding(batch) 139 | if self.attentive_conv_net_type == AttentiveConvNetType.LIGHT: 140 | x_multi_granularity, a_multi_granularity, x_beneficiary = \ 141 | embedding, embedding, embedding 142 | elif self.attentive_conv_net_type == AttentiveConvNetType.ADVANCED: 143 | embedding = embedding.permute(0, 2, 1) 144 | source_context = self.x_context_highway(embedding) 145 | source_self = self.x_self_highway(embedding) 146 | x_multi_granularity = \ 147 | torch.cat([source_context, source_self], 1).permute(0, 2, 1) 148 | 149 | focus_context = self.a_context_highway(embedding) 150 | focus_self = self.a_self_highway(embedding) 151 | a_multi_granularity = \ 152 | torch.cat([focus_context, focus_self], 1).permute(0, 2, 1) 153 | 154 | x_beneficiary = self.beneficiary_highway( 155 | embedding).permute(0, 2, 1) 156 | else: 157 | raise TypeError( 158 | "Unsupported AttentiveConvNetType: %s." % 159 | self.attentive_conv_net_type) 160 | 161 | if self.attention_type == AttentionType.DOT: 162 | attentive_context = self.dot_product_attention( 163 | x_multi_granularity, a_multi_granularity, a_multi_granularity) 164 | elif self.attention_type == AttentionType.BILINEAR: 165 | x_trans = x_multi_granularity.matmul(self.bilinear_matrix) 166 | attentive_context = self.dot_product_attention( 167 | x_trans, a_multi_granularity, a_multi_granularity) 168 | elif self.attention_type == AttentionType.ADDITIVE_PROJECTION: 169 | attentive_context = self.additive_projection( 170 | a_multi_granularity, x_multi_granularity) 171 | 172 | attentive_conv = attentive_context.matmul(self.attentive_conv) 173 | x_conv = self.x_conv(x_beneficiary.permute(0, 2, 1)).permute(0, 2, 1) 174 | attentive_convolution = \ 175 | torch.tanh(attentive_conv + x_conv + self.bias).permute(0, 2, 1) 176 | hidden = torch.nn.functional.max_pool1d( 177 | attentive_convolution, 178 | kernel_size=attentive_convolution.size()[-1]).squeeze() 179 | hidden1 = hidden.matmul(self.hidden1_matrix) 180 | hidden2 = hidden1.matmul(self.hidden2_matrix) 181 | hidden_layer = torch.cat([hidden, hidden1, hidden2], 1) 182 | 183 | return self.dropout(self.linear(hidden_layer)) 184 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | #copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """PyTorch optimization for BERT model.""" 29 | 30 | import math 31 | 32 | import torch 33 | from torch.nn.utils import clip_grad_norm_ 34 | from torch.optim import Optimizer 35 | from torch.optim.optimizer import required 36 | 37 | 38 | def warmup_cosine(x, warmup=0.002): 39 | if x < warmup: 40 | return x / warmup 41 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 42 | 43 | 44 | def warmup_constant(x, warmup=0.002): 45 | if x < warmup: 46 | return x / warmup 47 | return 1.0 48 | 49 | 50 | def warmup_linear(x, warmup=0.002): 51 | if x < warmup: 52 | return x / warmup 53 | return 1.0 - x 54 | 55 | 56 | SCHEDULES = { 57 | 'warmup_cosine': warmup_cosine, 58 | 'warmup_constant': warmup_constant, 59 | 'warmup_linear': warmup_linear, 60 | } 61 | 62 | 63 | class BertAdam(Optimizer): 64 | """Implements BERT version of Adam algorithm with weight decay fix. 65 | Params: 66 | lr: learning rate 67 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 68 | t_total: total number of training steps for the learning 69 | rate schedule, -1 means constant learning rate. Default: -1 70 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 71 | b1: Adams b1. Default: 0.9 72 | b2: Adams b2. Default: 0.999 73 | e: Adams epsilon. Default: 1e-6 74 | weight_decay: Weight decay. Default: 0.01 75 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 76 | """ 77 | 78 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, 79 | schedule='warmup_linear', 80 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 81 | max_grad_norm=1.0): 82 | if lr is not required and lr < 0.0: 83 | raise ValueError( 84 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 85 | if schedule not in SCHEDULES: 86 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 87 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 88 | raise ValueError( 89 | "Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format( 90 | warmup)) 91 | if not 0.0 <= b1 < 1.0: 92 | raise ValueError( 93 | "Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 94 | if not 0.0 <= b2 < 1.0: 95 | raise ValueError( 96 | "Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 97 | if not e >= 0.0: 98 | raise ValueError( 99 | "Invalid epsilon value: {} - should be >= 0.0".format(e)) 100 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, 101 | t_total=t_total, 102 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 103 | max_grad_norm=max_grad_norm) 104 | super(BertAdam, self).__init__(params, defaults) 105 | 106 | def get_lr(self): 107 | lr = [] 108 | for group in self.param_groups: 109 | for p in group['params']: 110 | state = self.state[p] 111 | if len(state) == 0: 112 | return [0] 113 | if group['t_total'] != -1: 114 | schedule_fct = SCHEDULES[group['schedule']] 115 | lr_scheduled = group['lr'] * schedule_fct( 116 | state['step'] / group['t_total'], group['warmup']) 117 | else: 118 | lr_scheduled = group['lr'] 119 | lr.append(lr_scheduled) 120 | return lr 121 | 122 | def step(self, closure=None): 123 | """Performs a single optimization step. 124 | Arguments: 125 | closure (callable, optional): A closure that reevaluates the model 126 | and returns the loss. 
127 | """ 128 | loss = None 129 | if closure is not None: 130 | loss = closure() 131 | 132 | for group in self.param_groups: 133 | for p in group['params']: 134 | if p.grad is None: 135 | continue 136 | grad = p.grad.data 137 | if grad.is_sparse: 138 | raise RuntimeError( 139 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 140 | state = self.state[p] 141 | device = p.device 142 | # State initialization 143 | if len(state) == 0: 144 | state['step'] = 0 145 | # Exponential moving average of gradient values 146 | state['next_m'] = torch.zeros_like(p.data) 147 | # Exponential moving average of squared gradient values 148 | state['next_v'] = torch.zeros_like(p.data) 149 | if 'is_embedding' in group and group['is_embedding']: 150 | vocab_size = p.data.size(0) 151 | state['b1_correction'] = torch.ones([vocab_size], 152 | device=device) 153 | state['b1_correction'][:] = group['b1'] 154 | state['b2_correction'] = torch.ones([vocab_size], 155 | device=device) 156 | state['b2_correction'][:] = group['b2'] 157 | state['ones'] = torch.ones([vocab_size], device=device) 158 | state['zeros'] = torch.zeros([vocab_size], 159 | device=device) 160 | 161 | state['b1'] = torch.ones([vocab_size], device=device) 162 | state['b1'][:] = group['b1'] 163 | state['b2'] = torch.ones([vocab_size], device=device) 164 | state['b2'][:] = group['b2'] 165 | 166 | next_m, next_v = state['next_m'], state['next_v'] 167 | beta1, beta2 = group['b1'], group['b2'] 168 | 169 | # Add grad clipping 170 | if group['max_grad_norm'] > 0: 171 | clip_grad_norm_(p, group['max_grad_norm']) 172 | 173 | # Decay the first and second moment running average coefficient 174 | # In-place operations to update the averages at the same time 175 | next_m.mul_(beta1).add_(1 - beta1, grad) 176 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 177 | update = next_m / (next_v.sqrt() + group['e']) 178 | 179 | # Just adding the square of the weights to the loss function is *not* 180 | # the correct way of using L2 regularization/weight decay with Adam, 181 | # since that will interact with the m and v parameters in strange ways. 182 | # 183 | # Instead we want to decay the weights in a manner that doesn't interact 184 | # with the m/v parameters. This is equivalent to adding the square 185 | # of the weights to the loss with plain (non-momentum) SGD. 
186 | if group['weight_decay'] > 0.0: 187 | update += group['weight_decay'] * p.data 188 | 189 | if group['t_total'] != -1: 190 | schedule_fct = SCHEDULES[group['schedule']] 191 | lr_scheduled = group['lr'] * schedule_fct( 192 | state['step'] / group['t_total'], group['warmup']) 193 | else: 194 | lr_scheduled = group['lr'] 195 | 196 | if 'is_embedding' in group and group['is_embedding']: 197 | bias_correction1 = 1 - state['b1_correction'] 198 | bias_correction2 = 1 - state['b2_correction'] 199 | step_size = lr_scheduled * bias_correction2.sqrt() / bias_correction1 200 | step_size = step_size.unsqueeze(1) 201 | lr_scheduled = lr_scheduled * step_size 202 | grad_condition = torch.ge(torch.abs(grad).sum(1), 1e-6) 203 | 204 | update_embedding = torch.where(grad_condition, 205 | state['ones'], 206 | state['zeros']) 207 | lr_scheduled = lr_scheduled * update_embedding.unsqueeze(-1) 208 | beta1_tensor = torch.where(grad_condition, state['ones'], 209 | state['b1']) 210 | state['b1_correction'].mul_(beta1_tensor) 211 | beta2_tensor = torch.where(grad_condition, state['ones'], 212 | state['b2']) 213 | state['b2_correction'].mul_(beta2_tensor) 214 | 215 | update_with_lr = lr_scheduled * update 216 | p.data.add_(-update_with_lr) 217 | 218 | state['step'] += 1 219 | return loss 220 | 221 | -------------------------------------------------------------------------------- /dataset/collator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 
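The collators below turn a list of dataset samples into padded tensors keyed
by the ClassificationDataset DOC_* constants. For multi-label tasks the labels
are expanded into a multi-hot matrix; the core of that expansion as a
standalone sketch (toy label ids):

    import torch
    doc_labels = [[0, 1], [2]]          # label ids per document
    label_size = 3
    y_onehot = torch.zeros(len(doc_labels), label_size)
    for i, labels in enumerate(doc_labels):
        y_onehot[i, labels] = 1.0       # -> [[1., 1., 0.], [0., 0., 1.]]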
13 | """ 14 | 15 | """Collator for NeuralClassifier""" 16 | 17 | import torch 18 | 19 | from dataset.classification_dataset import ClassificationDataset as cDataset 20 | from util import Type 21 | 22 | 23 | class Collator(object): 24 | def __init__(self, device): 25 | self.device = device 26 | 27 | def __call__(self, batch): 28 | raise NotImplementedError 29 | 30 | 31 | class ClassificationType(Type): 32 | SINGLE_LABEL = "single_label" 33 | MULTI_LABEL = "multi_label" 34 | 35 | @classmethod 36 | def str(cls): 37 | return ",".join([cls.SINGLE_LABEL, cls.MULTI_LABEL]) 38 | 39 | 40 | class ClassificationCollator(Collator): 41 | def __init__(self, conf, label_size): 42 | super(ClassificationCollator, self).__init__(conf.device) 43 | self.classification_type = conf.task_info.label_type 44 | min_seq = 1 45 | if conf.model_name == "TextCNN": 46 | min_seq = conf.TextCNN.top_k_max_pooling 47 | elif conf.model_name == "DPCNN": 48 | min_seq = conf.DPCNN.kernel_size * 2 ** conf.DPCNN.blocks 49 | elif conf.model_name == "RegionEmbedding": 50 | min_seq = conf.feature.max_token_len 51 | self.min_token_max_len = min_seq 52 | self.min_char_max_len = min_seq 53 | self.label_size = label_size 54 | 55 | def _get_multi_hot_label(self, doc_labels): 56 | """For multi-label classification 57 | Generate multi-hot for input labels 58 | e.g. input: [[0,1], [2]] 59 | output: [[1,1,0], [0,0,1]] 60 | """ 61 | batch_size = len(doc_labels) 62 | max_label_num = max([len(x) for x in doc_labels]) 63 | doc_labels_extend = \ 64 | [[doc_labels[i][0] for x in range(max_label_num)] for i in range(batch_size)] 65 | for i in range(0, batch_size): 66 | doc_labels_extend[i][0 : len(doc_labels[i])] = doc_labels[i] 67 | y = torch.Tensor(doc_labels_extend).long() 68 | y_onehot = torch.zeros(batch_size, self.label_size).scatter_(1, y, 1) 69 | return y_onehot 70 | 71 | def _append_label(self, doc_labels, sample): 72 | if self.classification_type == ClassificationType.SINGLE_LABEL: 73 | assert len(sample[cDataset.DOC_LABEL]) == 1 74 | doc_labels.extend(sample[cDataset.DOC_LABEL]) 75 | elif self.classification_type == ClassificationType.MULTI_LABEL: 76 | doc_labels.append(sample[cDataset.DOC_LABEL]) 77 | else: 78 | raise TypeError( 79 | "Unsupported classification type: %s. 
Supported " 80 | "classification type is: %s" % 81 | (self.classification_type, ClassificationType.str())) 82 | 83 | def __call__(self, batch): 84 | def _append_vocab(ori_vocabs, vocabs, max_len): 85 | padding = [cDataset.VOCAB_PADDING] * (max_len - len(ori_vocabs)) 86 | vocabs.append(ori_vocabs + padding) 87 | 88 | doc_labels = [] 89 | 90 | doc_token = [] 91 | doc_char = [] 92 | doc_char_in_token = [] 93 | 94 | doc_token_len = [] 95 | doc_char_len = [] 96 | doc_char_in_token_len = [] 97 | 98 | doc_token_max_len = self.min_token_max_len 99 | doc_char_max_len = self.min_char_max_len 100 | doc_char_in_token_max_len = 0 101 | 102 | for _, value in enumerate(batch): 103 | doc_token_max_len = max(doc_token_max_len, 104 | len(value[cDataset.DOC_TOKEN])) 105 | doc_char_max_len = max(doc_char_max_len, 106 | len(value[cDataset.DOC_CHAR])) 107 | for char_in_token in value[cDataset.DOC_CHAR_IN_TOKEN]: 108 | doc_char_in_token_max_len = max(doc_char_in_token_max_len, 109 | len(char_in_token)) 110 | 111 | for _, value in enumerate(batch): 112 | self._append_label(doc_labels, value) 113 | _append_vocab(value[cDataset.DOC_TOKEN], doc_token, 114 | doc_token_max_len) 115 | doc_token_len.append(len(value[cDataset.DOC_TOKEN])) 116 | _append_vocab(value[cDataset.DOC_CHAR], doc_char, doc_char_max_len) 117 | doc_char_len.append(len(value[cDataset.DOC_CHAR])) 118 | 119 | doc_char_in_token_len_tmp = [] 120 | for char_in_token in value[cDataset.DOC_CHAR_IN_TOKEN]: 121 | _append_vocab(char_in_token, doc_char_in_token, 122 | doc_char_in_token_max_len) 123 | doc_char_in_token_len_tmp.append(len(char_in_token)) 124 | 125 | padding = [cDataset.VOCAB_PADDING] * doc_char_in_token_max_len 126 | for _ in range( 127 | len(value[cDataset.DOC_CHAR_IN_TOKEN]), doc_token_max_len): 128 | doc_char_in_token.append(padding) 129 | doc_char_in_token_len_tmp.append(0) 130 | doc_char_in_token_len.append(doc_char_in_token_len_tmp) 131 | 132 | if self.classification_type == ClassificationType.SINGLE_LABEL: 133 | tensor_doc_labels = torch.tensor(doc_labels) 134 | doc_label_list = [[x] for x in doc_labels] 135 | elif self.classification_type == ClassificationType.MULTI_LABEL: 136 | tensor_doc_labels = self._get_multi_hot_label(doc_labels) 137 | doc_label_list = doc_labels 138 | 139 | batch_map = { 140 | cDataset.DOC_LABEL: tensor_doc_labels, 141 | cDataset.DOC_LABEL_LIST: doc_label_list, 142 | 143 | cDataset.DOC_TOKEN: torch.tensor(doc_token), 144 | cDataset.DOC_CHAR: torch.tensor(doc_char), 145 | cDataset.DOC_CHAR_IN_TOKEN: torch.tensor(doc_char_in_token), 146 | 147 | cDataset.DOC_TOKEN_MASK: torch.tensor(doc_token).gt(0).float(), 148 | cDataset.DOC_CHAR_MASK: torch.tensor(doc_char).gt(0).float(), 149 | cDataset.DOC_CHAR_IN_TOKEN_MASK: 150 | torch.tensor(doc_char_in_token).gt(0).float(), 151 | 152 | cDataset.DOC_TOKEN_LEN: torch.tensor( 153 | doc_token_len, dtype=torch.float32), 154 | cDataset.DOC_CHAR_LEN: torch.tensor( 155 | doc_char_len, dtype=torch.float32), 156 | cDataset.DOC_CHAR_IN_TOKEN_LEN: torch.tensor( 157 | doc_char_in_token_len, dtype=torch.float32), 158 | 159 | cDataset.DOC_TOKEN_MAX_LEN: 160 | torch.tensor([doc_token_max_len], dtype=torch.float32), 161 | cDataset.DOC_CHAR_MAX_LEN: 162 | torch.tensor([doc_char_max_len], dtype=torch.float32), 163 | cDataset.DOC_CHAR_IN_TOKEN_MAX_LEN: 164 | torch.tensor([doc_char_in_token_max_len], dtype=torch.float32) 165 | } 166 | return batch_map 167 | 168 | 169 | class FastTextCollator(ClassificationCollator): 170 | """FastText Collator 171 | Extra support features: token, token-ngrams, 
keywords, topics. 172 | """ 173 | def __call__(self, batch): 174 | def _append_vocab(sample, vocabs, offsets, lens, name): 175 | filtered_vocab = [x for x in sample[name] if 176 | x is not cDataset.VOCAB_UNKNOWN] 177 | vocabs.extend(filtered_vocab) 178 | offsets.append(offsets[-1] + len(filtered_vocab)) 179 | lens.append(len(filtered_vocab)) 180 | 181 | doc_labels = [] 182 | 183 | doc_tokens = [] 184 | doc_token_ngrams = [] 185 | doc_keywords = [] 186 | doc_topics = [] 187 | 188 | doc_tokens_offset = [0] 189 | doc_token_ngrams_offset = [0] 190 | doc_keywords_offset = [0] 191 | doc_topics_offset = [0] 192 | 193 | doc_tokens_len = [] 194 | doc_token_ngrams_len = [] 195 | doc_keywords_len = [] 196 | doc_topics_len = [] 197 | for _, value in enumerate(batch): 198 | self._append_label(doc_labels, value) 199 | _append_vocab(value, doc_tokens, doc_tokens_offset, 200 | doc_tokens_len, 201 | cDataset.DOC_TOKEN) 202 | _append_vocab(value, doc_token_ngrams, doc_token_ngrams_offset, 203 | doc_token_ngrams_len, 204 | cDataset.DOC_TOKEN_NGRAM) 205 | _append_vocab(value, doc_keywords, doc_keywords_offset, 206 | doc_keywords_len, cDataset.DOC_KEYWORD) 207 | _append_vocab(value, doc_topics, doc_topics_offset, 208 | doc_topics_len, cDataset.DOC_TOPIC) 209 | doc_tokens_offset.pop() 210 | doc_token_ngrams_offset.pop() 211 | doc_keywords_offset.pop() 212 | doc_topics_offset.pop() 213 | 214 | if self.classification_type == ClassificationType.SINGLE_LABEL: 215 | tensor_doc_labels = torch.tensor(doc_labels) 216 | doc_label_list = [[x] for x in doc_labels] 217 | elif self.classification_type == ClassificationType.MULTI_LABEL: 218 | tensor_doc_labels = self._get_multi_hot_label(doc_labels) 219 | doc_label_list = doc_labels 220 | 221 | batch_map = { 222 | cDataset.DOC_LABEL: tensor_doc_labels, 223 | cDataset.DOC_LABEL_LIST: doc_label_list, 224 | 225 | cDataset.DOC_TOKEN: torch.tensor(doc_tokens), 226 | cDataset.DOC_TOKEN_NGRAM: torch.tensor(doc_token_ngrams), 227 | cDataset.DOC_KEYWORD: torch.tensor(doc_keywords), 228 | cDataset.DOC_TOPIC: torch.tensor(doc_topics), 229 | 230 | cDataset.DOC_TOKEN_OFFSET: torch.tensor(doc_tokens_offset), 231 | cDataset.DOC_TOKEN_NGRAM_OFFSET: 232 | torch.tensor(doc_token_ngrams_offset), 233 | cDataset.DOC_KEYWORD_OFFSET: torch.tensor(doc_keywords_offset), 234 | cDataset.DOC_TOPIC_OFFSET: torch.tensor(doc_topics_offset), 235 | 236 | cDataset.DOC_TOKEN_LEN: 237 | torch.tensor(doc_tokens_len, dtype=torch.float32), 238 | cDataset.DOC_TOKEN_NGRAM_LEN: 239 | torch.tensor(doc_token_ngrams_len, dtype=torch.float32), 240 | cDataset.DOC_KEYWORD_LEN: 241 | torch.tensor(doc_keywords_len, dtype=torch.float32), 242 | cDataset.DOC_TOPIC_LEN: 243 | torch.tensor(doc_topics_len, dtype=torch.float32)} 244 | return batch_map 245 | -------------------------------------------------------------------------------- /model/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. 
You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | 19 | from model.model_util import ActivationType 20 | from model.model_util import FAN_MODE 21 | from model.model_util import InitType 22 | from model.model_util import init_tensor 23 | from util import Logger 24 | from util import Type 25 | 26 | 27 | class EmbeddingType(Type): 28 | """Standard names for embedding type 29 | The following keys are defined: 30 | * `EMBEDDING`: Return the embedding after lookup. 31 | * `REGION_EMBEDDING`: Return the region embedding. 32 | Reference: A New Method of Region Embedding for Text Classification 33 | """ 34 | EMBEDDING = 'embedding' 35 | REGION_EMBEDDING = 'region_embedding' 36 | 37 | @classmethod 38 | def str(cls): 39 | return ",".join([cls.EMBEDDING, cls.REGION_EMBEDDING]) 40 | 41 | 42 | class EmbeddingProcessType(Type): 43 | """Standard names for embedding mode 44 | Given the vocab tensor shape[batch_size, sequence_len]. 45 | The following keys are defined: 46 | * `FLAT`: Normal mode, return tensor shape will be 47 | * [batch_size, sequence_len, embedding_size] 48 | * `MEAN`: Mean mode, return tensor shape will be 49 | * [batch_size, embedding_size] 50 | * `SUM`: Sum mode, return tensor shape will be 51 | * [batch_size, embedding_size] 52 | """ 53 | FLAT = 'flat' 54 | MEAN = 'mean' 55 | SUM = 'sum' 56 | 57 | @classmethod 58 | def str(cls): 59 | return ",".join([cls.FLAT, cls.MEAN, cls.SUM]) 60 | 61 | 62 | class Embedding(torch.nn.Module): 63 | def __init__(self, dict_map, embedding_dim, name, config, padding_idx=None, 64 | pretrained_embedding_file=None, mode=EmbeddingProcessType.FLAT, 65 | dropout=0, init_type=InitType.XAVIER_UNIFORM, low=0, high=1, 66 | mean=0, std=1, activation_type=ActivationType.NONE, 67 | fan_mode=FAN_MODE.FAN_IN, negative_slope=0): 68 | super(Embedding, self).__init__() 69 | self.logger = Logger(config) 70 | self.dropout = torch.nn.Dropout(p=dropout) 71 | self.mode = mode 72 | if self.mode == EmbeddingProcessType.FLAT: 73 | self.embedding = torch.nn.Embedding( 74 | len(dict_map), embedding_dim, padding_idx=padding_idx) 75 | else: 76 | self.embedding = torch.nn.EmbeddingBag( 77 | len(dict_map), embedding_dim, mode=mode) 78 | embedding_lookup_table = init_tensor( 79 | tensor=torch.empty(len(dict_map), embedding_dim), 80 | init_type=init_type, low=low, high=high, mean=mean, std=std, 81 | activation_type=activation_type, fan_mode=fan_mode, 82 | negative_slope=negative_slope) 83 | if pretrained_embedding_file is not None and \ 84 | pretrained_embedding_file != "": 85 | self.load_pretrained_embedding( 86 | embedding_lookup_table, dict_map, embedding_dim, name, 87 | pretrained_embedding_file) 88 | if padding_idx is not None: 89 | embedding_lookup_table[padding_idx] = 0.0 90 | self.embedding.weight.data.copy_(embedding_lookup_table) 91 | 92 | def forward(self, vocab_ids, offset=None): 93 | if self.mode == EmbeddingProcessType.FLAT: 94 | embedding = self.embedding(vocab_ids) 95 | else: 96 | embedding = self.embedding(vocab_ids, offset) 97 | return self.dropout(embedding) 98 | 99 | def load_pretrained_embedding( 100 | self, embedding_lookup_table, 
dict_map, embedding_dim, name,
101 |                                   pretrained_embedding_file):
102 |         self.logger.warn(
103 |             "Load %s embedding from %s" % (name, pretrained_embedding_file))
104 |         with open(pretrained_embedding_file) as fin:
105 |             num_pretrained = 0
106 |             for line in fin:
107 |                 data = line.strip().split(' ')
108 |                 # Check embedding info
109 |                 if len(data) == 2:
110 |                     assert int(data[1]) == embedding_dim, \
111 |                         "Pretrained embedding dim not matching: %s, %d" % (
112 |                             data[1], embedding_dim)
113 |                     continue
114 |                 if data[0] not in dict_map:
115 |                     continue
116 |                 embedding = torch.FloatTensor([float(i) for i in data[1:]])
117 |                 embedding_lookup_table[dict_map[data[0]]] = embedding
118 |                 num_pretrained += 1
119 |         self.logger.warn(
120 |             "Total dict size of %s is %d" % (name, len(dict_map)))
121 |         self.logger.warn("Size of pretrained %s embedding is %d" % (
122 |             name, num_pretrained))
123 |         self.logger.warn(
124 |             "Size of randomly initialized %s embedding is %d" % (
125 |                 name, len(dict_map) - num_pretrained))
126 | 
127 | 
128 | class RegionEmbeddingType(Type):
129 |     """Standard names for region embedding type
130 |     """
131 |     WC = 'word_context'
132 |     CW = 'context_word'
133 | 
134 |     @classmethod
135 |     def str(cls):
136 |         return ",".join([cls.WC, cls.CW])
137 | 
138 | 
139 | class RegionEmbeddingLayer(torch.nn.Module):
140 |     """
141 |     Reference: A New Method of Region Embedding for Text Classification
142 |     """
143 | 
144 |     def __init__(self, dict_map, embedding_dim, region_size, name, config,
145 |                  padding=None, pretrained_embedding_file=None, dropout=0,
146 |                  init_type=InitType.XAVIER_UNIFORM, low=0, high=1, mean=0,
147 |                  std=1, fan_mode=FAN_MODE.FAN_IN,
148 |                  region_embedding_type=RegionEmbeddingType.WC):
149 |         super(RegionEmbeddingLayer, self).__init__()
150 |         self.region_embedding_type = region_embedding_type
151 |         self.region_size = region_size
152 |         assert self.region_size % 2 == 1
153 |         self.radius = int(region_size / 2)
154 |         self.embedding_dim = embedding_dim
155 |         self.embedding = Embedding(
156 |             dict_map, embedding_dim, "RegionWord" + name, config=config,
157 |             padding_idx=padding,
158 |             pretrained_embedding_file=pretrained_embedding_file,
159 |             dropout=dropout, init_type=init_type, low=low, high=high, mean=mean,
160 |             std=std, fan_mode=fan_mode)
161 |         self.context_embedding = Embedding(
162 |             dict_map, embedding_dim * region_size, "RegionContext" + name,
163 |             config=config, padding_idx=padding, dropout=dropout,
164 |             init_type=init_type, low=low, high=high, mean=mean, std=std,
165 |             fan_mode=fan_mode)
166 | 
167 |     def forward(self, vocab_ids):
168 |         seq_length = vocab_ids.size(1)
169 |         actual_length = vocab_ids.size(1) - self.radius * 2
170 |         trim_vocab_id = vocab_ids[:, self.radius:seq_length - self.radius]
171 |         slice_vocabs = \
172 |             [vocab_ids[:, i:i + self.region_size] for i in
173 |              range(actual_length)]
174 |         slice_vocabs = torch.cat(slice_vocabs, 1)
175 |         slice_vocabs = \
176 |             slice_vocabs.view(-1, actual_length, self.region_size)
177 | 
178 |         if self.region_embedding_type == RegionEmbeddingType.WC:
179 |             vocab_embedding = self.embedding(slice_vocabs)
180 |             context_embedding = self.context_embedding(trim_vocab_id)
181 |             context_embedding = context_embedding.view(
182 |                 -1, actual_length, self.region_size, self.embedding_dim)
183 |             region_embedding = vocab_embedding * context_embedding
184 |             region_embedding, _ = region_embedding.max(2)
185 |         elif self.region_embedding_type == RegionEmbeddingType.CW:
186 |             vocab_embedding = self.embedding(trim_vocab_id).unsqueeze(2)
187 |             context_embedding = self.context_embedding(slice_vocabs)
188 |             size = context_embedding.size()
189 |             context_embedding = context_embedding.view(
190 |                 size[0], size[1], size[2], self.region_size, self.embedding_dim)
191 |             mask = torch.ones(
192 |                 [self.region_size, self.region_size, self.embedding_dim])
193 | 
194 |             for i in range(self.region_size):
195 |                 mask[i][self.region_size - i - 1] = 0.
196 |             neg_mask = mask * -65500.0
197 |             mask = mask.le(0).float()
198 |             mask = mask.unsqueeze(0).unsqueeze(0)
199 |             context_embedding = context_embedding * mask
200 |             context_embedding = context_embedding + neg_mask
201 |             context_embedding, _ = context_embedding.max(3)
202 |             region_embedding = vocab_embedding * context_embedding
203 |             region_embedding, _ = region_embedding.max(2)
204 |         else:
205 |             raise TypeError(
206 |                 "Unsupported region embedding type: %s." %
207 |                 self.region_embedding_type)
208 | 
209 |         return region_embedding
210 | 
211 | 
212 | class PositionEmbedding(torch.nn.Module):
213 |     ''' Reference: Attention Is All You Need '''
214 | 
215 |     def __init__(self, seq_max_len, embedding_dim, padding_idx):
216 |         super(PositionEmbedding, self).__init__()
217 | 
218 |         self.position_enc = nn.Embedding.from_pretrained(
219 |             self.get_sinusoid_encoding_table(seq_max_len + 1,
220 |                                              embedding_dim,
221 |                                              padding_idx=padding_idx),
222 |             freeze=True)
223 | 
224 |     def forward(self, src_pos):
225 |         return self.position_enc(src_pos)
226 | 
227 |     @staticmethod
228 |     def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
229 |         def cal_angle(position, hid_idx):
230 |             return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
231 | 
232 |         def get_posi_angle_vec(position):
233 |             return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
234 | 
235 |         sinusoid_table = np.array(
236 |             [get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
237 | 
238 |         sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
239 |         sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
240 | 
241 |         if padding_idx is not None:
242 |             # zero vector for padding dimension
243 |             sinusoid_table[padding_idx] = 0.
244 | 
245 |         return torch.FloatTensor(sinusoid_table)
246 | 
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding:utf-8
3 | """
4 | Tencent is pleased to support the open source community by making NeuralClassifier available.
5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance
7 | with the License. You may obtain a copy of the License at
8 | http://opensource.org/licenses/MIT
9 | Unless required by applicable law or agreed to in writing, software distributed under the License
10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11 | or implied. See the License for the specific language governing permissions and limitations under
12 | the License.
13 | """
14 | 
15 | import json
16 | import os
17 | 
18 | import torch
19 | 
20 | from util import Logger
21 | from util import ModeType
22 | from util import Type
23 | from torch.utils.data import TensorDataset
24 | 
25 | 
26 | class InsertVocabMode(Type):
27 |     """Standard names for vocab insertion mode.
28 |     Controls which fields of a json sample are inserted into the dicts.
29 |     The following keys are defined:
30 |     * `ALL`: insert vocab from every field of the sample,
31 |       e.g. tokens, chars and labels.
32 |     * `LABEL`: insert vocab from the label field only.
33 |     * `OTHER`: insert vocab from fields other than the label,
34 |       e.g. when generate_dict_using_all_json_files adds the
35 |       validate/test json files.
36 |     """
37 |     ALL = 'all'
38 |     LABEL = 'label'
39 |     OTHER = 'other'
40 | 
41 |     def str(self):
42 |         return ",".join(
43 |             [self.ALL, self.LABEL, self.OTHER])
44 | 
45 | 
46 | class DatasetBase(torch.utils.data.dataset.Dataset):
47 |     #class DatasetBase(TensorDataset):
48 |     """Base dataset class
49 |     """
50 |     CHARSET = "utf-8"
51 | 
52 |     VOCAB_PADDING = 0  # Embedding is all zero and not learnable
53 |     VOCAB_UNKNOWN = 1
54 |     VOCAB_PADDING_LEARNABLE = 2  # Embedding is randomly initialized and learnable
55 | 
56 |     BIG_VALUE = 1000 * 1000 * 1000
57 | 
58 |     def __init__(self, config, json_files, generate_dict=False,
59 |                  mode=ModeType.EVAL):
60 |         """
61 |         Another way to do this is to keep the file handler open. But when
62 |         DataLoader's num_worker is bigger than 1, an error will occur.
63 |         Args:
64 |             config:
65 |         """
66 | 
67 |         self.config = config
68 |         self.logger = Logger(config)
69 |         self._init_dict()
70 |         self.sample_index = []
71 |         self.sample_size = 0
72 |         self.mode = mode
73 | 
74 |         self.files = json_files
75 |         for i, json_file in enumerate(json_files):
76 |             with open(json_file) as fin:
77 |                 self.sample_index.append([i, 0])
78 |                 while True:
79 |                     json_str = fin.readline()
80 |                     if not json_str:
81 |                         self.sample_index.pop()
82 |                         break
83 |                     self.sample_size += 1
84 |                     self.sample_index.append([i, fin.tell()])
85 | 
86 |         def _insert_vocab(files, _mode=InsertVocabMode.ALL):
87 |             for _i, _json_file in enumerate(files):
88 |                 with open(_json_file) as _fin:
89 |                     for _json_str in _fin:
90 |                         try:
91 |                             self._insert_vocab(json.loads(_json_str), _mode)
92 |                         except Exception:
93 |                             print(_json_str)  # skip samples that fail to parse
94 | 
95 |         # Dict can be generated using:
96 |         # json files and/or pretrained embedding
97 |         if generate_dict:
98 |             # Use train json files to generate dict
99 |             # If generate_dict_using_json_files is true, then all vocab in train
100 |             # will be used, else only part vocab will be used. e.g.
label 101 | vocab_json_files = config.data.train_json_files 102 | mode = InsertVocabMode.LABEL 103 | if self.config.data.generate_dict_using_json_files: 104 | mode = InsertVocabMode.ALL 105 | self.logger.info("Use dataset to generate dict.") 106 | _insert_vocab(vocab_json_files, mode) 107 | 108 | if self.config.data.generate_dict_using_all_json_files: 109 | vocab_json_files += self.config.data.validate_json_files + \ 110 | self.config.data.test_json_files 111 | _insert_vocab(vocab_json_files, InsertVocabMode.OTHER) 112 | 113 | if self.config.data.generate_dict_using_pretrained_embedding: 114 | self.logger.info("Use pretrained embedding to generate dict.") 115 | self._load_pretrained_dict() 116 | self._print_dict_info() 117 | 118 | self._shrink_dict() 119 | self.logger.info("Shrink dict over.") 120 | self._print_dict_info(True) 121 | #self._save_dict() 122 | self._clear_dict() 123 | self._load_dict() 124 | 125 | def __len__(self): 126 | return self.sample_size 127 | 128 | def __getitem__(self, idx): 129 | if idx >= self.sample_size: 130 | raise IndexError 131 | index = self.sample_index[idx] 132 | with open(self.files[index[0]]) as fin: 133 | fin.seek(index[1]) 134 | json_str = fin.readline() 135 | return self._get_vocab_id_list(json.loads(json_str)) 136 | 137 | def _init_dict(self): 138 | """Init all dict 139 | """ 140 | raise NotImplementedError 141 | 142 | def _save_dict(self, dict_name=None): 143 | """Save vocab to file and generate id_to_vocab_dict_map 144 | Args: 145 | dict_name: Dict name, if None save all dict. Default None. 146 | """ 147 | if dict_name is None: 148 | if not os.path.exists(self.config.data.dict_dir): 149 | os.makedirs(self.config.data.dict_dir) 150 | for name in self.dict_names: 151 | self._save_dict(name) 152 | else: 153 | dict_idx = self.dict_names.index(dict_name) 154 | dict_file = open(self.dict_files[dict_idx], "w") 155 | id_to_vocab_dict_map = self.id_to_vocab_dict_list[dict_idx] 156 | index = 0 157 | for vocab, count in self.count_list[dict_idx]: 158 | id_to_vocab_dict_map[index] = vocab 159 | index += 1 160 | dict_file.write("%s\t%d\n" % (vocab, count)) 161 | dict_file.close() 162 | 163 | def _load_dict(self, dict_name=None): 164 | """Load dict from file. 165 | Args: 166 | dict_name: Dict name, if None load all dict. Default None. 167 | Returns: 168 | dict. 
169 | """ 170 | if dict_name is None: 171 | for name in self.dict_names: 172 | self._load_dict(name) 173 | else: 174 | dict_idx = self.dict_names.index(dict_name) 175 | if not os.path.exists(self.dict_files[dict_idx]): 176 | self.logger.warn("Not exists %s for %s" % ( 177 | self.dict_files[dict_idx], dict_name)) 178 | else: 179 | dict_map = self.dicts[dict_idx] 180 | id_to_vocab_dict_map = self.id_to_vocab_dict_list[dict_idx] 181 | if dict_name != self.DOC_LABEL: 182 | dict_map[self.VOCAB_PADDING] = 0 183 | dict_map[self.VOCAB_UNKNOWN] = 1 184 | dict_map[self.VOCAB_PADDING_LEARNABLE] = 2 185 | id_to_vocab_dict_map[0] = self.VOCAB_PADDING 186 | id_to_vocab_dict_map[1] = self.VOCAB_UNKNOWN 187 | id_to_vocab_dict_map[2] = self.VOCAB_PADDING_LEARNABLE 188 | 189 | for line in open(self.dict_files[dict_idx], "rb"): 190 | vocab = line.decode().strip("\n").split("\t") 191 | dict_idx = len(dict_map) 192 | dict_map[vocab[0]] = dict_idx 193 | id_to_vocab_dict_map[dict_idx] = vocab[0] 194 | 195 | def _load_pretrained_dict(self, dict_name=None, 196 | pretrained_file=None, min_count=0): 197 | """Use pretrained embedding to generate dict 198 | """ 199 | if dict_name is None: 200 | for i, _ in enumerate(self.pretrained_dict_names): 201 | self._load_pretrained_dict( 202 | self.pretrained_dict_names[i], 203 | self.pretrained_dict_files[i], 204 | self.pretrained_min_count[i]) 205 | 206 | else: 207 | if pretrained_file is None or pretrained_file == "": 208 | return 209 | index = self.dict_names.index(dict_name) 210 | dict_map = self.dicts[index] 211 | with open(pretrained_file) as fin: 212 | for line in fin: 213 | data = line.strip().split(' ') 214 | if len(data) == 2: 215 | continue 216 | if data[0] not in dict_map: 217 | dict_map[data[0]] = 0 218 | dict_map[data[0]] += min_count + 1 219 | 220 | def _insert_vocab(self, json_obj, mode=InsertVocabMode.ALL): 221 | """Insert vocab to dict 222 | """ 223 | raise NotImplementedError 224 | 225 | def _shrink_dict(self, dict_name=None): 226 | if dict_name is None: 227 | for name in self.dict_names: 228 | self._shrink_dict(name) 229 | else: 230 | dict_idx = self.dict_names.index(dict_name) 231 | self.count_list[dict_idx] = sorted(self.dicts[dict_idx].items(), 232 | key=lambda x: (x[1], x[0]), 233 | reverse=True) 234 | self.count_list[dict_idx] = \ 235 | [(k, v) for k, v in self.count_list[dict_idx] if 236 | v >= self.min_count[dict_idx]][0:self.max_dict_size[dict_idx]] 237 | 238 | def _clear_dict(self): 239 | """Clear all dict 240 | """ 241 | for dict_map in self.dicts: 242 | dict_map.clear() 243 | for id_to_vocab_dict in self.id_to_vocab_dict_list: 244 | id_to_vocab_dict.clear() 245 | 246 | def _print_dict_info(self, count_list=False): 247 | """Print dict info 248 | """ 249 | for i, dict_name in enumerate(self.dict_names): 250 | if count_list: 251 | self.logger.info( 252 | "Size of %s dict is %d" % ( 253 | dict_name, len(self.count_list[i]))) 254 | else: 255 | self.logger.info( 256 | "Size of %s dict is %d" % (dict_name, len(self.dicts[i]))) 257 | 258 | def _insert_sequence_tokens(self, sequence_tokens, token_map, 259 | token_ngram_map, char_map, ngram=0): 260 | for token in sequence_tokens: 261 | for char in token: 262 | self._add_vocab_to_dict(char_map, char) 263 | self._add_vocab_to_dict(token_map, token) 264 | if ngram > 1: 265 | for j in range(2, ngram + 1): 266 | for token_ngram in ["".join(sequence_tokens[k:k + j]) for k in 267 | range(len(sequence_tokens) - j + 1)]: 268 | self._add_vocab_to_dict(token_ngram_map, 269 | token_ngram) 270 | 271 | def 
_insert_sequence_vocab(self, sequence_vocabs, dict_map): 272 | for vocab in sequence_vocabs: 273 | self._add_vocab_to_dict(dict_map, vocab) 274 | 275 | @staticmethod 276 | def _add_vocab_to_dict(dict_map, vocab): 277 | if vocab not in dict_map: 278 | dict_map[vocab] = 0 279 | dict_map[vocab] += 1 280 | 281 | def _get_vocab_id_list(self, json_obj): 282 | """Use dict to convert all vocabs to ids 283 | """ 284 | return json_obj 285 | 286 | def _label_to_id(self, sequence_labels, dict_map): 287 | """Convert label to id. The reason that label is not in label map may be 288 | label is filtered or label in validate/test does not occur in train set 289 | """ 290 | label_id_list = [] 291 | for label in sequence_labels: 292 | if label not in dict_map: 293 | self.logger.warn("Label not in label map: %s" % label) 294 | else: 295 | label_id_list.append(self.label_map[label]) 296 | assert label_id_list, "Label is empty: %s" % " ".join(sequence_labels) 297 | 298 | return label_id_list 299 | 300 | def _token_to_id(self, sequence_tokens, token_map, char_map, ngram=0, 301 | token_ngram_map=None, max_char_sequence_length=-1, 302 | max_char_length_per_token=-1): 303 | """Convert token to id. Vocab not in dict map will be map to _UNK 304 | """ 305 | token_id_list = [] 306 | char_id_list = [] 307 | char_in_token_id_list = [] 308 | ngram_id_list = [] 309 | for token in sequence_tokens: 310 | char_id = [char_map.get(x, self.VOCAB_UNKNOWN) for x in token] 311 | char_id_list.extend(char_id[0:max_char_sequence_length]) 312 | char_in_token = [char_map.get(x, self.VOCAB_UNKNOWN) 313 | for x in token[0:max_char_length_per_token]] 314 | char_in_token_id_list.append(char_in_token) 315 | 316 | token_id_list.append( 317 | token_map.get(token, token_map[self.VOCAB_UNKNOWN])) 318 | if ngram > 1: 319 | for j in range(2, ngram + 1): 320 | ngram_id_list.extend( 321 | token_ngram_map[x] for x in 322 | ["".join(sequence_tokens[k:k + j]) for k in 323 | range(len(sequence_tokens) - j + 1)] if x in 324 | token_ngram_map) 325 | if not sequence_tokens: 326 | token_id_list.append(self.VOCAB_PADDING) 327 | char_id_list.append(self.VOCAB_PADDING) 328 | char_in_token_id_list.append([self.VOCAB_PADDING]) 329 | if not ngram_id_list: 330 | ngram_id_list.append(token_ngram_map[self.VOCAB_PADDING]) 331 | return token_id_list, char_id_list, char_in_token_id_list, ngram_id_list 332 | 333 | def _vocab_to_id(self, sequence_vocabs, dict_map): 334 | """Convert vocab to id. Vocab not in dict map will be map to _UNK 335 | """ 336 | vocab_id_list = \ 337 | [dict_map.get(x, self.VOCAB_UNKNOWN) for x in sequence_vocabs] 338 | if not vocab_id_list: 339 | vocab_id_list.append(self.VOCAB_PADDING) 340 | return vocab_id_list 341 | -------------------------------------------------------------------------------- /evaluate/classification_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """ 4 | Tencent is pleased to support the open source community by making NeuralClassifier available. 5 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 6 | Licensed under the MIT License (the "License"); you may not use this file except in compliance 7 | with the License. 
You may obtain a copy of the License at 8 | http://opensource.org/licenses/MIT 9 | Unless required by applicable law or agreed to in writing, software distributed under the License 10 | is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 | or implied. See the License for thespecific language governing permissions and limitations under 12 | the License. 13 | """ 14 | 15 | # Provide function that calculate the precision, recall, F1-score 16 | # and output confusion_matrix. 17 | 18 | 19 | import json 20 | import os 21 | 22 | import numpy as np 23 | 24 | from dataset.classification_dataset import ClassificationDataset as cDataset 25 | 26 | 27 | class ClassificationEvaluator(object): 28 | MACRO_AVERAGE = "macro_average" 29 | MICRO_AVERAGE = "micro_average" 30 | """Not thread safe, will keep the latest eval result 31 | """ 32 | 33 | def __init__(self, eval_dir): 34 | self.confusion_matrix_list = None 35 | self.precision_list = None 36 | self.recall_list = None 37 | self.fscore_list = None 38 | self.right_list = None 39 | self.predict_list = None 40 | self.standard_list = None 41 | 42 | self.eval_dir = eval_dir 43 | if not os.path.exists(self.eval_dir): 44 | os.makedirs(self.eval_dir) 45 | 46 | @staticmethod 47 | def _calculate_prf(right_count, predict_count, standard_count): 48 | """Calculate precision, recall, fscore 49 | Args: 50 | standard_count: Standard count 51 | predict_count: Predict count 52 | right_count: Right count 53 | Returns: 54 | precision, recall, f_score 55 | """ 56 | precision, recall, f_score = 0, 0, 0 57 | if predict_count > 0: 58 | precision = right_count / predict_count 59 | if standard_count > 0: 60 | recall = right_count / standard_count 61 | if precision + recall > 0: 62 | f_score = precision * recall * 2 / (precision + recall) 63 | 64 | return precision, recall, f_score 65 | 66 | @staticmethod 67 | def _judge_label_in(label_name, label_to_id_maps): 68 | cnt = 0 69 | for label in label_name: 70 | for i in range(0, len(label_to_id_maps)): 71 | if label in label_to_id_maps[i]: 72 | cnt += 1 73 | break 74 | return cnt == len(label_name) 75 | def calculate_level_performance( 76 | self, id_to_label_map, right_count_category, predict_count_category, 77 | standard_count_category, other_text='其他', 78 | exclude_method="contain"): 79 | """Calculate the level performance. 80 | Args: 81 | id_to_label_map: Label id to label name. 82 | other_text: Text to judge the other label. 83 | right_count_category: Right count. 84 | predict_count_category: Predict count. 85 | standard_count_category: Standard count. 86 | exclude_method: The method to judge the other label. Can be 87 | contain(label_name contains other_text) or 88 | start(label_name start with other_text). 89 | Returns: 90 | precision_dict, recall_dict, fscore_dict. 
91 | """ 92 | other_label = dict() 93 | for _, label_name in id_to_label_map.items(): 94 | if exclude_method == "contain": 95 | if other_text in label_name: 96 | other_label[label_name] = 1 97 | elif exclude_method == "start": 98 | if label_name.startswith(other_text): 99 | other_label[label_name] = 1 100 | else: 101 | raise TypeError( 102 | "Cannot find exclude_method: " + 103 | exclude_method) 104 | 105 | precision_dict = dict() 106 | recall_dict = dict() 107 | fscore_dict = dict() 108 | precision_dict[self.MACRO_AVERAGE] = 0 109 | recall_dict[self.MACRO_AVERAGE] = 0 110 | fscore_dict[self.MACRO_AVERAGE] = 0 111 | right_total = 0 112 | predict_total = 0 113 | standard_total = 0 114 | 115 | for _, label_name in id_to_label_map.items(): 116 | if label_name in other_label: 117 | continue 118 | precision_dict[label_name], recall_dict[label_name], \ 119 | fscore_dict[label_name] = self._calculate_prf( 120 | right_count_category[label_name], 121 | predict_count_category[label_name], 122 | standard_count_category[label_name]) 123 | right_total += right_count_category[label_name] 124 | predict_total += predict_count_category[label_name] 125 | standard_total += standard_count_category[label_name] 126 | precision_dict[self.MACRO_AVERAGE] += precision_dict[label_name] 127 | recall_dict[self.MACRO_AVERAGE] += recall_dict[label_name] 128 | fscore_dict[self.MACRO_AVERAGE] += fscore_dict[label_name] 129 | num_label_eval = len(id_to_label_map) - len(other_label) 130 | 131 | precision_dict[self.MACRO_AVERAGE] = \ 132 | precision_dict[self.MACRO_AVERAGE] / num_label_eval 133 | recall_dict[self.MACRO_AVERAGE] = \ 134 | recall_dict[self.MACRO_AVERAGE] / num_label_eval 135 | fscore_dict[self.MACRO_AVERAGE] = 0 \ 136 | if (recall_dict[self.MACRO_AVERAGE] + 137 | precision_dict[self.MACRO_AVERAGE]) == 0 else \ 138 | 2 * precision_dict[self.MACRO_AVERAGE] * \ 139 | recall_dict[self.MACRO_AVERAGE] / \ 140 | (recall_dict[self.MACRO_AVERAGE] 141 | + precision_dict[self.MACRO_AVERAGE]) 142 | 143 | right_count_category[self.MICRO_AVERAGE] = right_total 144 | predict_count_category[self.MICRO_AVERAGE] = predict_total 145 | standard_count_category[self.MICRO_AVERAGE] = standard_total 146 | 147 | (precision_dict[self.MICRO_AVERAGE], recall_dict[self.MICRO_AVERAGE], 148 | fscore_dict[self.MICRO_AVERAGE]) = \ 149 | self._calculate_prf(right_total, predict_total, standard_total) 150 | return precision_dict, recall_dict, fscore_dict 151 | 152 | def evaluate(self, predicts, standard_label_names=None, 153 | standard_label_ids=None, label_map=None, threshold=0, top_k=3, 154 | is_prob=True, is_flat=False, is_multi=False, other_text='其他'): 155 | """Eval the predict result. 156 | Args: 157 | predicts: Predict probability or 158 | predict text label(is_prob is false) 159 | fmt: 160 | if is_multi: [[p1,p2],[p2],[p3], ...] 161 | else: [[p1], [p2], [p3], ...] 162 | standard_label_names: Standard label names. If standard_label_names 163 | is None, standard_label_ids should be given. 164 | standard_label_ids: Standard label ids. If standard_label_ids 165 | is None, standard_label_names should be given. 166 | label_map: Label dict. If is_prob is false and label_map is None, 167 | label_map will be generated using labels. 168 | threshold: Threshold to filter probs. 169 | top_k: if is_multi true, top_k is used for truncating the predicts. 170 | is_prob: The predict is prob list or label id. 171 | is_flat: If true, only calculate flat result. 172 | Else, calculate hierarchical result. 173 | is_multi: multi-label evaluation. 
174 | other_text: Label name contains other_text will not be calculate. 175 | Returns: 176 | confusion_matrix_list contain all result, 177 | filtered_confusion_matrix_list contains result that max predict prob 178 | is greater than threshold and will be used to calculate prf, 179 | precision_list, recall_list, fscore_list, 180 | right_count_list, predict_count_list, standard_count_list, turn_accuracy 181 | """ 182 | 183 | def _init_confusion_matrix(label_map): 184 | """Init confusion matrix. 185 | Args: 186 | label_map: Label map. 187 | Returns: 188 | confusion_matrix. 189 | """ 190 | confusion_matrix = dict() 191 | for label_name in label_map.keys(): 192 | confusion_matrix[label_name] = dict() 193 | for label_name_other in label_map.keys(): 194 | confusion_matrix[label_name][label_name_other] = 0 195 | return confusion_matrix 196 | 197 | def _init_count_dict(label_map): 198 | """Init count dict. 199 | Args: 200 | label_map: Label map. 201 | Returns: 202 | count_dict. 203 | """ 204 | count_dict = dict() 205 | for label_name in label_map.keys(): 206 | count_dict[label_name] = 0 207 | return count_dict 208 | 209 | assert (standard_label_names is not None or 210 | standard_label_ids is not None) 211 | sep = cDataset.CLASSIFICATION_LABEL_SEPARATOR 212 | depth = 0 213 | if not is_prob and label_map is None: 214 | assert standard_label_names is not None 215 | label_map = dict() 216 | # Use standard_label_names to generate label_map 217 | for label_list in standard_label_names: 218 | for label in label_list: 219 | if label not in label_map: 220 | label_map[label] = len(label_map) 221 | if not is_flat: 222 | for label in label_map.keys(): 223 | hierarchical_labels = label.split(sep) 224 | depth = max(len(hierarchical_labels), depth) 225 | label_to_id_maps = [] 226 | id_to_label_maps = [] 227 | for i in range(depth + 1): 228 | label_to_id_maps.append(dict()) 229 | id_to_label_maps.append(dict()) 230 | for label_name, label_id in label_map.items(): 231 | label_to_id_maps[0][label_name] = label_id 232 | id_to_label_maps[0][label_id] = label_name 233 | if not is_flat: 234 | hierarchical_labels = label_name.split(sep) 235 | for i in range(1, len(hierarchical_labels) + 1): 236 | label = sep.join(hierarchical_labels[:i]) 237 | if label not in label_to_id_maps[i]: 238 | index = len(label_to_id_maps[i]) 239 | label_to_id_maps[i][label] = index 240 | id_to_label_maps[i][index] = label 241 | 242 | confusion_matrix_list = [] 243 | right_category_count_list = [] 244 | predict_category_count_list = [] 245 | standard_category_count_list = [] 246 | for i in range(depth + 1): 247 | confusion_matrix_list.append( 248 | _init_confusion_matrix(label_to_id_maps[i])) 249 | right_category_count_list.append( 250 | _init_count_dict(label_to_id_maps[i])) 251 | predict_category_count_list.append( 252 | _init_count_dict(label_to_id_maps[i])) 253 | standard_category_count_list.append( 254 | _init_count_dict(label_to_id_maps[i])) 255 | 256 | line_count = 0 257 | debug_file = open("probs.txt", "w", encoding=cDataset.CHARSET) 258 | accuracy = 0 259 | for predict in predicts: 260 | if is_prob: 261 | prob_np = np.array(predict, dtype=np.float32) 262 | if not is_multi: 263 | predict_label_ids = [prob_np.argmax()] 264 | else: 265 | predict_label_ids = [] 266 | predict_label_idx = np.argsort(-prob_np) 267 | for j in range(0, top_k): 268 | if prob_np[predict_label_idx[j]] > threshold: 269 | predict_label_ids.append(predict_label_idx[j]) 270 | 271 | predict_label_name = [id_to_label_maps[0][predict_label_id] \ 272 | for predict_label_id 
in predict_label_ids] 273 | debug_file.write(json.dumps(prob_np.tolist())) 274 | debug_file.write("\n") 275 | else: 276 | predict_label_name = predict 277 | 278 | if standard_label_names is not None: 279 | standard_label_name = standard_label_names[line_count] 280 | else: 281 | standard_label_name = [id_to_label_maps[0][standard_label_ids[line_count][i]] \ 282 | for i in range(len(standard_label_ids[line_count]))] 283 | if (not self. _judge_label_in(predict_label_name, label_to_id_maps)) or \ 284 | (not self._judge_label_in(standard_label_name, label_to_id_maps)): 285 | line_count += 1 286 | continue 287 | for std_name in standard_label_name: 288 | for pred_name in predict_label_name: 289 | confusion_matrix_list[0][std_name][pred_name] += 1 290 | for pred_name in predict_label_name: 291 | predict_category_count_list[0][pred_name] += 1 292 | for std_name in standard_label_name: 293 | standard_category_count_list[0][std_name] += 1 294 | for pred_name in predict_label_name: 295 | if std_name == pred_name: 296 | right_category_count_list[0][pred_name] += 1 297 | 298 | if standard_label_name == predict_label_name: 299 | accuracy += 1 300 | 301 | if not is_flat: 302 | standard_hierarchical_labels = \ 303 | [std_name.split(sep) for std_name in standard_label_name] 304 | predict_hierarchical_labels = \ 305 | [pred_name.split(sep) for pred_name in predict_label_name] 306 | 307 | standard_label_map = {} 308 | predict_label_map = {} 309 | for std_label in standard_hierarchical_labels: 310 | for i in range(0, len(std_label)): 311 | if i + 1 not in standard_label_map: 312 | standard_label_map[i + 1] = set() 313 | standard_label_map[i + 1].add(sep.join(std_label[:i+1])) 314 | for pred_label in predict_hierarchical_labels: 315 | for i in range(0, len(pred_label)): 316 | if i + 1 not in predict_label_map: 317 | predict_label_map[i + 1] = set() 318 | predict_label_map[i + 1].add(sep.join(pred_label[:i+1])) 319 | for level, std_label_set in standard_label_map.items(): 320 | for std_label in std_label_set: 321 | standard_category_count_list[level][std_label] += 1 322 | for level, pred_label_set in predict_label_map.items(): 323 | for pred_label in pred_label_set: 324 | predict_category_count_list[level][pred_label] += 1 325 | for level, std_label_set in standard_label_map.items(): 326 | for std_label in std_label_set: 327 | if level in predict_label_map: 328 | for pred_label in predict_label_map[level]: 329 | confusion_matrix_list[level][std_label][pred_label] += 1 330 | if std_label == pred_label: 331 | right_category_count_list[level][pred_label] += 1 332 | line_count += 1 333 | turn_accuracy = float(accuracy)/float(line_count) 334 | debug_file.close() 335 | precision_list = [] 336 | recall_list = [] 337 | fscore_list = [] 338 | precision_dict, recall_dict, fscore_dict = \ 339 | self.calculate_level_performance( 340 | id_to_label_maps[0], right_category_count_list[0], 341 | predict_category_count_list[0], standard_category_count_list[0], 342 | exclude_method="start") 343 | 344 | precision_list.append(precision_dict) 345 | recall_list.append(recall_dict) 346 | fscore_list.append(fscore_dict) 347 | 348 | for i in range(1, depth + 1): 349 | precision_dict, recall_dict, fscore_dict = \ 350 | self.calculate_level_performance( 351 | id_to_label_maps[i], right_category_count_list[i], 352 | predict_category_count_list[i], 353 | standard_category_count_list[i], other_text) 354 | precision_list.append(precision_dict) 355 | recall_list.append(recall_dict) 356 | fscore_list.append(fscore_dict) 357 | 358 | 
self.confusion_matrix_list, self.precision_list, self.recall_list,\ 359 | self.fscore_list, self.right_list, self.predict_list,\ 360 | self.standard_list = ( 361 | confusion_matrix_list, precision_list, recall_list, fscore_list, 362 | right_category_count_list, predict_category_count_list, 363 | standard_category_count_list) 364 | return (confusion_matrix_list, precision_list, recall_list, fscore_list, 365 | right_category_count_list, predict_category_count_list, 366 | standard_category_count_list, turn_accuracy) 367 | 368 | @staticmethod 369 | def save_confusion_matrix(file_name, confusion_matrix): 370 | """Save confusion matrix 371 | Args: 372 | file_name: File to save to. 373 | confusion_matrix: Confusion Matrix. 374 | Returns: 375 | """ 376 | with open(file_name, "w", encoding=cDataset.CHARSET) as cm_file: 377 | cm_file.write("\t") 378 | for category_fist in sorted(confusion_matrix.keys()): 379 | cm_file.write(category_fist + "\t") 380 | cm_file.write("\n") 381 | for category_fist in sorted(confusion_matrix.keys()): 382 | cm_file.write(category_fist + "\t") 383 | for category_second in sorted(confusion_matrix.keys()): 384 | cm_file.write( 385 | str(confusion_matrix[category_fist][ 386 | category_second]) + "\t") 387 | cm_file.write("\n") 388 | 389 | def save_prf(self, file_name, precision_category, recall_category, 390 | fscore_category, right_category, predict_category, 391 | standard_category): 392 | """Save precision, recall, fscore 393 | Args: 394 | file_name: File to save to. 395 | precision_category: Precision dict. 396 | recall_category: Recall dict. 397 | fscore_category: Fscore dict. 398 | right_category: Right dict. 399 | predict_category: Predict dict. 400 | standard_category: Standard dict. 401 | Returns: 402 | """ 403 | 404 | def _format(category): 405 | """Format evaluation string. 406 | Args: 407 | category: Category evaluation to format. 408 | Returns: 409 | """ 410 | if category == self.MACRO_AVERAGE: 411 | return "%s, precision: %f, recall: %f, fscore: %f, " % ( 412 | category, precision_category[category], 413 | recall_category[category], fscore_category[category]) 414 | return "%s, precision: %f, recall: %f, fscore: %f, " \ 415 | "right_count: %d, predict_count: %d, " \ 416 | "standard_count: %d" % ( 417 | category, precision_category[category], 418 | recall_category[category], fscore_category[category], 419 | right_category[category], predict_category[category], 420 | standard_category[category]) 421 | 422 | with open(file_name, "w", encoding=cDataset.CHARSET) as prf_file: 423 | prf_file.write(_format(self.MACRO_AVERAGE) + "\n") 424 | prf_file.write(_format(self.MICRO_AVERAGE) + "\n") 425 | prf_file.write("\n") 426 | for category in precision_category: 427 | if category != self.MICRO_AVERAGE and \ 428 | category != self.MACRO_AVERAGE: 429 | prf_file.write(_format(category) + "\n") 430 | 431 | def save(self): 432 | """Save the latest evaluation. 433 | """ 434 | for i, confusion_matrix in enumerate(self.confusion_matrix_list): 435 | if i == 0: 436 | eval_name = "all" 437 | else: 438 | eval_name = "level_%s" % i 439 | self.save_confusion_matrix( 440 | self.eval_dir + "/" + eval_name + "_confusion_matrix", 441 | confusion_matrix) 442 | self.save_prf( 443 | self.eval_dir + "/" + eval_name + "_prf", 444 | self.precision_list[i], self.recall_list[i], 445 | self.fscore_list[i], self.right_list[i], 446 | self.predict_list[i], self.standard_list[i]) 447 | --------------------------------------------------------------------------------
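
Note (not part of the repository): the sketch below mirrors the sinusoidal table built by PositionEmbedding.get_sinusoid_encoding_table in model/embedding.py, vectorized with numpy instead of the per-element loops, and shows one plausible way the frozen lookup can be queried. The toy sizes (seq_max_len=8, embedding_dim=6) and the assumption that position 0 is reserved for padding are illustrative, not taken from the repo's configs.

```python
# Hedged usage sketch of the sinusoidal position encoding used by PositionEmbedding.
import numpy as np
import torch
import torch.nn as nn


def sinusoid_table(n_position, d_hid, padding_idx=None):
    # angle(pos, i) = pos / 10000^(2*(i//2)/d_hid); sin on even dims, cos on odd dims
    pos = np.arange(n_position)[:, None]
    idx = np.arange(d_hid)[None, :]
    table = pos / np.power(10000, 2 * (idx // 2) / d_hid)
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    if padding_idx is not None:
        table[padding_idx] = 0.0  # padding position maps to the zero vector
    return torch.FloatTensor(table)


if __name__ == "__main__":
    seq_max_len, embedding_dim, padding_idx = 8, 6, 0  # toy sizes, assumed
    enc = nn.Embedding.from_pretrained(
        sinusoid_table(seq_max_len + 1, embedding_dim, padding_idx), freeze=True)
    src_pos = torch.tensor([[1, 2, 3, 0]])  # 0 acts as the padding position here
    print(enc(src_pos).shape)  # torch.Size([1, 4, 6])
```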
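
Note (not part of the repository): a minimal, self-contained sketch of the precision/recall/F1 arithmetic implemented by ClassificationEvaluator._calculate_prf and the macro/micro averaging in calculate_level_performance. It deliberately omits the '其他' (other-label) exclusion and the hierarchical levels, and the per-label counts are made-up toy numbers.

```python
# Hedged sketch of the PRF arithmetic used by ClassificationEvaluator.
from typing import Dict, Tuple


def calculate_prf(right: int, predict: int, standard: int) -> Tuple[float, float, float]:
    # precision = right/predict, recall = right/standard, F1 = harmonic mean
    precision = right / predict if predict > 0 else 0.0
    recall = right / standard if standard > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1


def macro_micro(right: Dict[str, int], predict: Dict[str, int],
                standard: Dict[str, int]) -> Dict[str, Tuple[float, float, float]]:
    # macro: average per-label precision and recall, then take their harmonic mean;
    # micro: pool the counts over all labels, then compute P/R/F1 once.
    per_label = {k: calculate_prf(right[k], predict[k], standard[k]) for k in standard}
    macro_p = sum(p for p, _, _ in per_label.values()) / len(per_label)
    macro_r = sum(r for _, r, _ in per_label.values()) / len(per_label)
    macro_f = 2 * macro_p * macro_r / (macro_p + macro_r) if macro_p + macro_r > 0 else 0.0
    micro = calculate_prf(sum(right.values()), sum(predict.values()), sum(standard.values()))
    per_label["macro_average"] = (macro_p, macro_r, macro_f)
    per_label["micro_average"] = micro
    return per_label


if __name__ == "__main__":
    # Toy counts: right = correctly predicted, predict = predicted, standard = gold.
    right = {"A": 8, "B": 2}
    predict = {"A": 10, "B": 5}
    standard = {"A": 12, "B": 3}
    for name, (p, r, f) in macro_micro(right, predict, standard).items():
        print("%s precision=%.3f recall=%.3f f1=%.3f" % (name, p, r, f))
```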