├── models
│   ├── __pycache__
│   │   ├── crf.cpython-36.pyc
│   │   ├── BertMark.cpython-36.pyc
│   │   ├── CnnText.cpython-36.pyc
│   │   ├── BertClass.cpython-36.pyc
│   │   ├── Bert_baseline.cpython-36.pyc
│   │   └── MaskedCELoss.cpython-36.pyc
│   └── BertClass.py
├── data
│   └── scripts
│       ├── __pycache__
│       │   ├── dataset.cpython-36.pyc
│       │   ├── vocab.cpython-36.pyc
│       │   ├── dataloader.cpython-36.pyc
│       │   └── tokenization.cpython-36.pyc
│       ├── process.py
│       ├── dataset.py
│       └── dataloader.py
├── scripts
│   ├── utils.py
│   ├── add_key_entity.py
│   └── train_class.py
└── README.md
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
import logging


def config_logger(log_path):
    # Configure a logger that writes to both a log file and the console.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fhandler = logging.FileHandler(log_path, mode='w')
    shandler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    fhandler.setFormatter(formatter)
    shandler.setFormatter(formatter)
    logger.addHandler(fhandler)
    logger.addHandler(shandler)

    return logger

--------------------------------------------------------------------------------
/models/BertClass.py:
--------------------------------------------------------------------------------
from mxnet.gluon import nn, rnn
import mxnet as mx
from mxnet import nd
import numpy as np


class BertClass(nn.Block):
    def __init__(self, bert, max_seq_len, ctx=mx.cpu(), **kwargs):
        super(BertClass, self).__init__(**kwargs)
        self.ctx = ctx
        self.max_seq_len = max_seq_len
        self.bert = bert
        self.output_dense = nn.Dense(2)

    def forward(self, content, token_types, valid_len):
        # Run BERT, keep the [CLS] position and project it to the two classes.
        bert_output = self.bert(content, token_types, valid_len)
        bert_output = bert_output[:, 0, :]
        output = self.output_dense(bert_output)
        return output

--------------------------------------------------------------------------------
/data/scripts/process.py:
--------------------------------------------------------------------------------
import csv
from itertools import islice
import jieba.posseg
import multiprocessing
from gluonnlp.data import Counter
import pandas as pd


def split_to_train_valid():
    # Randomly split Train_Data.csv into a training set and a validation set.
    valid_ratio = 0.1
    data = pd.read_csv("../Train_Data.csv", doublequote=True)
    print("Total: {} samples".format(len(data)))
    train = data.sample(frac=1 - valid_ratio, random_state=0, axis=0)
    valid = data[~data.index.isin(train.index)]
    train.to_csv("../train.csv", index=False)
    valid.to_csv("../valid.csv", index=False)
    print("Training set: {} samples".format(len(train)))
    print("Validation set: {} samples".format(len(valid)))


if __name__ == "__main__":
    split_to_train_valid()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CCF_Negative_Financial_Information_and_Subject_Judgment
#### A simple and effective baseline!
### Intended for participants who split this task into two subtasks: negative classification and key-entity (subject) judgment.
~~1. This code does not include the negative-classification subtask; refer to any of the BERT / RoBERTa / ALBERT classification baselines.~~
1. Since some people asked for it, my own classification code is now included. It is written in MXNet, so running it directly requires setting up the corresponding environment; see http://zh.gluon.ai/chapter_prerequisite/install.html#%E4%BD%BF%E7%94%A8GPU%E7%89%88%E7%9A%84MXNet. That is the Gluon documentation, which is rich enough to serve as a full introduction to deep learning. If you would rather not install MXNet, the code still works as a reference: it differs from PyTorch mostly in function names.
2. Just feed the result file of the first subtask (id, negative) into this code to get the final result (id, negative, key_entity).
3. The code filters entities in only two ways: "longest-string selection" and "NIKE entity filtering" (a minimal sketch of the longest-string filter is shown at the end of this README).
a. Longest-string selection means: if one entity string contains another entity string, keep only the containing one. For example, given 小资易贷 and 小资易贷有限公司, keep only 小资易贷有限公司.
b. In "NIKE entity filtering", NIKE stands for "Not In Key Entity" (but in entity). For every entity, count how often it appears in entity and also in key_entity, and how often it appears in entity but not in key_entity.
Set a ratio threshold: entities whose in-key_entity ratio falls below it are "NIKE" entities and are filtered out directly.

#### With this approach alone, key-entity judgment reaches about 0.558 online. If the negative-classification score F1_s of your first subtask reaches 0.96, the total is 0.96 x 0.4 + 0.558 = 0.9415; if your F1_s reaches 0.98, it is 0.98 x 0.4 + 0.558 = 0.9495!
#### You can keep adding filtering ideas of your own inside func_on_row() to improve the result further.
#### Finally, if you find this useful, please star the repo. Thanks!
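
For illustration, a minimal sketch of the longest-string filter from 3a (hypothetical helper name; the full version, together with the NIKE filter, lives in `scripts/add_key_entity.py`):

```python
def keep_longest_entities(entities):
    # Drop any entity that is a substring of another entity in the list.
    entities = list(set(entities))
    return [e for e in entities
            if not any(e != other and e in other for other in entities)]

# keep_longest_entities(["小资易贷", "小资易贷有限公司"]) -> ["小资易贷有限公司"]
```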
--------------------------------------------------------------------------------
/data/scripts/dataset.py:
--------------------------------------------------------------------------------
from mxnet import gluon
import pandas as pd
from itertools import islice
from mxnet.gluon.data import ArrayDataset


class ClassDataset(gluon.data.Dataset):
    # Training dataset: yields (id, title + content, negative label) per example.
    def __init__(self, file_path, **kwargs):
        super(ClassDataset, self).__init__(**kwargs)
        self.id, self.context, self.label = self._get_data(file_path)

    def _get_data(self, file_path):
        example_id = []
        example_content = []
        example_label = []
        data = pd.read_csv(file_path, doublequote=True)
        rows = data.values.tolist()
        for row in rows:
            eid = row[0]
            title = row[1]
            content = row[2]
            label = row[4]
            example_id.append(eid)
            # Missing fields are read as NaN, and NaN != NaN, so replace them with "".
            if title != title:
                title = ""
            if content != content:
                content = ""
            example_content.append(title + content)
            example_label.append(label)
        return example_id, example_content, example_label

    def __getitem__(self, item):
        return self.id[item], self.context[item], self.label[item]

    def __len__(self):
        return len(self.id)


class ClassTestDataset(gluon.data.Dataset):
    # Test dataset: yields (id, title + content, entity string) per example.
    def __init__(self, file_path, **kwargs):
        super(ClassTestDataset, self).__init__(**kwargs)
        self.id, self.context, self.example_entitys = self._get_data(file_path)

    def _get_data(self, file_path):
        example_id = []
        example_content = []
        example_entitys = []
        data = pd.read_csv(file_path, doublequote=True)
        rows = data.values.tolist()
        for row in rows:
            eid = row[0]
            title = row[1]
            content = row[2]
            entity = row[3]
            example_id.append(eid)
            # Missing fields are read as NaN, and NaN != NaN, so replace them with "".
            if title != title:
                title = ""
            if content != content:
                content = ""
            if entity != entity:
                entity = ""
            example_content.append(title + content)
            example_entitys.append(entity)
        return example_id, example_content, example_entitys

    def __getitem__(self, item):
        return self.id[item], self.context[item], self.example_entitys[item]

    def __len__(self):
        return len(self.id)
--------------------------------------------------------------------------------
/data/scripts/dataloader.py:
--------------------------------------------------------------------------------
import json
import multiprocessing
import os
import re
import sys
sys.path.append("../data/scripts")

import gluonnlp as nlp
from gluonnlp import Vocab
from gluonnlp.data import BERTSentenceTransform, BERTTokenizer, Counter
from mxnet import nd
from mxnet.gluon import data


class DatasetAssiantTransformer():
    def __init__(self, ch_vocab=None, max_seq_len=None, istrain=True):
        self.ch_vocab = ch_vocab
        self.max_seq_len = max_seq_len
        self.istrain = istrain
        # The BERT tokenizer is not actually used below; plain character splitting seemed to work better.
        self.tokenizer = BERTTokenizer(ch_vocab)

    def ClassProcess(self, *data):
        if self.istrain:
            example_id, source, label = data
        else:
            example_id, source, entity = data
        # Normalize whitespace and collapse repeated question marks.
        content = re.sub("[ \n\t\\n\u3000]", " ", source)
        content = re.sub("[??]+", "?", content)
        content = [char for char in content]
        # Prepend the [CLS] token, then map all characters to vocabulary indices.
        content = [self.ch_vocab.cls_token] + content
        content = self.ch_vocab(content)
        if self.max_seq_len and len(content) > self.max_seq_len:
            content = content[:self.max_seq_len]
        valid_len = len(content)
        token_type = [0] * valid_len
        if self.istrain:
            return content, token_type, valid_len, label, example_id
        else:
            return content, token_type, valid_len, source, entity, example_id


class ClassDataLoader(object):
    def __init__(self, dataset, batch_size, assiant, shuffle=False, num_workers=3, lazy=True):
        trans_func = assiant.ClassProcess
        self.istrain = assiant.istrain
        self.assiant = assiant
        self.dataset = dataset.transform(trans_func, lazy=lazy)
        self.batch_size = batch_size
        self.pad_val = assiant.ch_vocab[assiant.ch_vocab.padding_token]
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.dataloader = self._build_dataloader()

    def _build_dataloader(self):
        # Pad content and token types, stack lengths/labels, keep raw fields as lists.
        if self.istrain:
            batchify_fn = nlp.data.batchify.Tuple(
                nlp.data.batchify.Pad(pad_val=self.pad_val),
                nlp.data.batchify.Pad(pad_val=0),
                nlp.data.batchify.Stack(dtype="float32"),
                nlp.data.batchify.Stack(dtype="float32"),
                nlp.data.batchify.List()
            )
        else:
            batchify_fn = nlp.data.batchify.Tuple(
                nlp.data.batchify.Pad(pad_val=self.pad_val),
                nlp.data.batchify.Pad(pad_val=0),
                nlp.data.batchify.Stack(dtype="float32"),
                nlp.data.batchify.List(),
                nlp.data.batchify.List(),
                nlp.data.batchify.List()
            )
        dataloader = data.DataLoader(dataset=self.dataset, batch_size=self.batch_size,
                                     shuffle=self.shuffle, batchify_fn=batchify_fn,
                                     num_workers=self.num_workers)
        return dataloader

    @property
    def dataiter(self):
        return self.dataloader

    @property
    def data_lengths(self):
        return len(self.dataset)
--------------------------------------------------------------------------------
/scripts/add_key_entity.py:
--------------------------------------------------------------------------------
import pandas as pd


class PickKeyEntity(object):
    # neg_filepath: your negative-classification result, containing only (id, negative) with 0/1 labels
    # train_filepath: the original "Train_Data.csv"
    # test_filepath: the original "Test_Data.csv"
    def __init__(self, neg_filepath, train_filepath, test_filepath):
        self.train_filepath = train_filepath
        self.test_filepath = test_filepath
        self.neg_filepath = neg_filepath
        # Entities whose in-key_entity frequency is below this ratio are treated as non-key ("NIKE") entities.
        self.ratio = 0.1
        self.nike = self._generate_nike_data()

    def run(self):
        test_data = pd.read_csv(self.test_filepath)
        test_data = test_data.loc[:, ["id", "entity"]]
        neg_data = pd.read_csv(self.neg_filepath)
        test_data = pd.merge(neg_data, test_data, on="id", how="inner")
        test_data["key_entity"] = test_data.apply(self.func_on_row, axis=1)
        test_data.to_csv("./result.csv", index=False,
                         columns=["id", "negative", "key_entity"])

    # Add your own ideas here for filtering entities if you want to improve the result further.
    def func_on_row(self, row):
        key_entitys = []
        if row["negative"] == 1:
            entitys = row["entity"].split(";")
            for entity in entitys:
                if entity not in self.nike:
                    key_entitys.append(entity)
            key_entitys = self._remove_substring(key_entitys)
        return ";".join(key_entitys)

    # A "NIKE" entity is one from the training data that appears in entity but rarely
    # (below self.ratio) in key_entity.
    def _generate_nike_data(self):
        nike = []
        nums_of_entity_as_key = {}
        train_data = pd.read_csv(self.train_filepath)
        train_data = train_data.loc[:, ["negative", "entity", "key_entity"]]
        train_data = train_data[train_data.negative == 1]
        for index, row in train_data.iterrows():
            entitys = row["entity"]
            key_entitys = row["key_entity"]
            entitys = entitys.split(";")
            key_entitys = key_entitys.split(";")
            for entity in entitys:
                if nums_of_entity_as_key.get(entity, -1) == -1:
                    if entity in key_entitys:
                        nums_of_entity_as_key.update({entity: {"in": 1, "out": 0}})
                    else:
                        nums_of_entity_as_key.update({entity: {"in": 0, "out": 1}})
                else:
                    if entity in key_entitys:
                        nums_of_entity_as_key[entity]["in"] += 1
                    else:
                        nums_of_entity_as_key[entity]["out"] += 1
        for entity, nums in nums_of_entity_as_key.items():
            num_in = nums["in"]
            num_out = nums["out"]
            freq_in = num_in / (num_in + num_out)
            if freq_in < self.ratio:
                nike.append(entity)
        return nike
    # Remove entities that are substrings of other entities.
    # e.g. given 资易贷, 小资易贷, 资易贷有限公司 we keep 小资易贷 and 资易贷有限公司.
    def _remove_substring(self, entities):
        entities = list(set(entities))
        longest_entities = []
        for entity in entities:
            flag = 0
            for entity_ in entities:
                if entity == entity_:
                    continue
                if entity_.find(entity) != -1:
                    flag = 1
            if flag == 0:
                longest_entities.append(entity)
        return longest_entities


if __name__ == "__main__":
    pickKeyEntity = PickKeyEntity("../results/negative_result.csv",
                                  "../data/Train_Data.csv", "../data/Test_Data.csv")
    pickKeyEntity.run()
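
# Illustrative only and not called anywhere: a toy, self-contained version of the
# NIKE statistic computed in _generate_nike_data above, on hypothetical rows of
# (entity, key_entity) strings. Entities whose in-key_entity frequency falls below
# `ratio` are the ones that get filtered out.
def _nike_demo(rows, ratio=0.1):
    counts = {}
    for entity_str, key_entity_str in rows:
        key_entities = key_entity_str.split(";")
        for entity in entity_str.split(";"):
            hit, miss = counts.get(entity, (0, 0))
            if entity in key_entities:
                counts[entity] = (hit + 1, miss)
            else:
                counts[entity] = (hit, miss + 1)
    return [e for e, (hit, miss) in counts.items() if hit / (hit + miss) < ratio]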
--------------------------------------------------------------------------------
/scripts/train_class.py:
--------------------------------------------------------------------------------
import argparse
import csv
import math
import os
import re
import sys
sys.path.append("..")

import gluonnlp
import jieba.posseg
import mxnet as mx
import numpy as np
from gluonnlp.data import train_valid_split
from gluonnlp.model import BeamSearchScorer
from mxboard import *
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss
from numpy import random
from tqdm import tqdm

from data.scripts.dataloader import ClassDataLoader, DatasetAssiantTransformer
from data.scripts.dataset import ClassDataset, ClassTestDataset
from models.BertClass import BertClass
from utils import config_logger


np.random.seed(100)
random.seed(100)
mx.random.seed(10000)


def train_and_valid(ch_bert, model, ch_vocab, train_dataiter, dev_dataiter, trainer, finetune_trainer, epochs, loss_func, ctx, lr, batch_size, params_save_step, params_save_path_root, eval_step, log_step, check_step, logger, num_train_examples, warmup_ratio):
    batches = len(train_dataiter)

    num_train_steps = int(num_train_examples / batch_size * epochs)
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    global_step = 0

    for epoch in range(epochs):
        for content, token_types, valid_len, label, example_id in train_dataiter:
            # Learning-rate schedule: linear warmup, then linear decay to zero.
            if global_step < num_warmup_steps:
                new_lr = lr * global_step / num_warmup_steps
            else:
                non_warmup_steps = global_step - num_warmup_steps
                offset = non_warmup_steps / (num_train_steps - num_warmup_steps)
                new_lr = lr - offset * lr
            trainer.set_learning_rate(new_lr)

            content = content.as_in_context(ctx)
            token_types = token_types.as_in_context(ctx)
            valid_len = valid_len.as_in_context(ctx)
            label = label.as_in_context(ctx)

            with autograd.record():
                output = model(content, token_types, valid_len)
                loss_mean = loss_func(output, label)
                loss_mean = nd.sum(loss_mean) / batch_size
            loss_mean.backward()
            loss_scalar = loss_mean.asscalar()

            trainer.step(1)
            finetune_trainer.step(1)

            if global_step and global_step % log_step == 0:
                acc = nd.sum(nd.equal(nd.argmax(nd.softmax(
                    output, axis=-1), axis=-1), label)) / batch_size
                acc = acc.asscalar()
                logger.info("epoch:{}, batch:{}/{}, acc:{}, loss:{}, lr:{}".format(
                    epoch, global_step % batches, batches, acc, loss_scalar, trainer.learning_rate))
            global_step += 1
        # Evaluate and save a checkpoint at the end of every epoch.
        F1 = dev(ch_bert, model, ch_vocab, dev_dataiter, logger, ctx)
        if not os.path.exists(params_save_path_root):
            os.makedirs(params_save_path_root)
        model_params_file = params_save_path_root + \
            "model_step_{}_{}.params".format(global_step, F1)
        model.save_parameters(model_params_file)
        logger.info("{} Save Completed.".format(model_params_file))


def dev(ch_bert, model, ch_vocab, dev_dataiter, logger, ctx):
    # Compute the F1-style score used for model selection on the dev set.
    TP_s = 0
    FP_s = 0
    FN_s = 0
    example_ids = []
    for content, token_types, valid_len, label, example_id in tqdm(dev_dataiter):
        example_ids.extend(example_id)
        content = content.as_in_context(ctx)
        token_types = token_types.as_in_context(ctx)
        valid_len = valid_len.as_in_context(ctx)
        label = label.as_in_context(ctx)

        output = model(content, token_types, valid_len)
        predict = nd.argmax(nd.softmax(output, axis=-1), axis=-1)
        tp_s = int(nd.sum(nd.equal(predict, label)).asscalar())
        fp_s = int(nd.sum(nd.not_equal(predict, label) * nd.equal(label, 0)).asscalar())
        fn_s = int(nd.sum(nd.not_equal(predict, label) * nd.equal(label, 1)).asscalar())
        TP_s += tp_s
        FP_s += fp_s
        FN_s += fn_s

    P_s = TP_s / (TP_s + FP_s)
    R_s = TP_s / (TP_s + FN_s)
    F = (2 * P_s * R_s) / (P_s + R_s)

    logger.info("F:{}".format(F))
    return F


def predict(ch_bert, model, ch_vocab, test_dataiter, logger, ctx):
    example_ids = []
    sources = []
    pred_labels = []
    entities = []

    neg_example_ids = []
    neg_sources = []
    neg_pred_labels = []
    neg_entities = []

    for content, token_type, valid_len, source, entity, example_id in tqdm(test_dataiter):
        content = content.as_in_context(ctx)
        token_type = token_type.as_in_context(ctx)
        valid_len = valid_len.as_in_context(ctx)

        outputs = model(content, token_type, valid_len)
        predicts = nd.argmax(nd.softmax(outputs, axis=-1), axis=-1).asnumpy().tolist()
        predicts = [int(label) for label in predicts]
        for label, eid, text, en in zip(predicts, example_id, source, entity):
            if label == 0:
                example_ids.append(eid)
                sources.append(text)
                pred_labels.append(label)
                entities.append(en)
            else:
                neg_example_ids.append(eid)
                neg_sources.append(text)
                neg_pred_labels.append(label)
                neg_entities.append(en)

    # Write (id, negative) with an empty key_entity column; key entities are added
    # later by scripts/add_key_entity.py.
    if not os.path.exists("../results"):
        os.makedirs("../results")
    f = open("../results/result_only_neg.csv", "w", encoding="utf-8")
    writer = csv.writer(f)
    writer.writerow(["id", "negative", "key_entity"])
    example_ids.extend(neg_example_ids)
    pred_labels.extend(neg_pred_labels)
    for ids, label in zip(example_ids, pred_labels):
        row = [ids, label, ""]
        writer.writerow(row)
    f.close()


def main(args):
    # Configure logging.
    log_path = os.path.join(args.log_root, '{}.log'.format(args.model_name))
    logger = config_logger(log_path)

    gpu_idx = args.gpu
    if not gpu_idx:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(gpu_idx - 1)
    logger.info("Using ctx: {}".format(ctx))

    # Load the pretrained Chinese BERT and its vocabulary.
    ch_bert, ch_vocab = gluonnlp.model.get_model(args.bert_model,
                                                 dataset_name=args.ch_bert_dataset,
                                                 pretrained=True,
                                                 ctx=ctx,
                                                 use_pooler=False,
                                                 use_decoder=False,
                                                 use_classifier=False)
    model = BertClass(bert=ch_bert,
                      max_seq_len=args.max_seq_len, ctx=ctx)
    logger.info("Model Creating Completed.")

    # Initialize the classification head for training, or load saved parameters for prediction.
    if args.istrain:
        model.output_dense.initialize(init.Xavier(), ctx)
    else:
        model.load_parameters(args.model_params_path, ctx=ctx)
    logger.info("Parameters Initializing and Loading Completed")

    model.hybridize()

    if args.istrain:
        # Build the train/dev dataloaders.
        assiant = DatasetAssiantTransformer(
            ch_vocab=ch_vocab, max_seq_len=args.max_seq_len)
        dataset = ClassDataset(args.train_file_path)
        train_dataset, dev_dataset = train_valid_split(dataset, valid_ratio=0.1)
        train_dataiter = ClassDataLoader(train_dataset, batch_size=args.batch_size,
                                         assiant=assiant, shuffle=True).dataiter
        dev_dataiter = ClassDataLoader(dev_dataset, batch_size=args.batch_size,
                                       assiant=assiant, shuffle=True).dataiter
        logger.info("Data Loading Completed")
    else:
        assiant = DatasetAssiantTransformer(
            ch_vocab=ch_vocab, max_seq_len=args.max_seq_len, istrain=args.istrain)
        test_dataset = ClassTestDataset(args.test_file_path)
        test_dataiter = ClassDataLoader(test_dataset, batch_size=args.batch_size,
                                        assiant=assiant, shuffle=True).dataiter

    # Build the trainers: one for the BERT body (fine-tuning) and one for the new dense head.
    finetune_trainer = gluon.Trainer(ch_bert.collect_params(),
                                     args.optimizer, {"learning_rate": args.finetune_lr})
    trainer = gluon.Trainer(model.collect_params("dense*"), args.optimizer,
                            {"learning_rate": args.train_lr})

    loss_func = gloss.SoftmaxCELoss()

    if args.istrain:
        logger.info("## Training Start ##")
        train_and_valid(
            ch_bert=ch_bert, model=model, ch_vocab=ch_vocab,
            train_dataiter=train_dataiter, dev_dataiter=dev_dataiter, trainer=trainer, finetune_trainer=finetune_trainer, epochs=args.epochs,
            loss_func=loss_func, ctx=ctx, lr=args.train_lr, batch_size=args.batch_size, params_save_step=args.params_save_step,
            params_save_path_root=args.params_save_path_root, eval_step=args.eval_step, log_step=args.log_step, check_step=args.check_step,
            logger=logger, num_train_examples=len(train_dataset), warmup_ratio=args.warmup_ratio
        )
    else:
        predict(ch_bert=ch_bert, model=model, ch_vocab=ch_vocab,
                test_dataiter=test_dataiter, logger=logger, ctx=ctx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="new_entites_find")
    parser.add_argument("--train_file_path", type=str,
                        default="../data/Train_Data.csv")
    parser.add_argument("--test_file_path", type=str, default="../data/Test_Data.csv")
    parser.add_argument("--bert_model", type=str,
                        default="bert_12_768_12")
    parser.add_argument("--ch_bert_dataset", type=str,
                        default="wiki_cn_cased")
    parser.add_argument("--model_params_path", type=str,
                        default="../parameters/xxx.params")
    # Note: with type=bool, any non-empty value (including "False") parses as True;
    # pass an empty string ("") to get False, or edit the default here.
    parser.add_argument("--istrain", type=bool,
                        default=True)
    parser.add_argument("--score", type=str,
                        default="0")
    parser.add_argument("--gpu", type=int,
                        default=1, help='which gpu to use for finetuning. CPU is used if set 0.')
    parser.add_argument("--optimizer", type=str, default="adam")
    parser.add_argument("--bert_optimizer", type=str, default="bertadam")
    parser.add_argument("--train_lr", type=float, default=5e-5)
    parser.add_argument("--finetune_lr", type=float, default=2e-5)
    parser.add_argument("--batch_size", type=int,
                        default=64)
    parser.add_argument("--epochs", type=int,
                        default=3)
    parser.add_argument("--log_root", type=str, default="../logs/")
    parser.add_argument("--log_step", type=int, default=10)
    parser.add_argument("--eval_step", type=int, default=1000)
    parser.add_argument("--check_step", type=int, default=5)
    parser.add_argument("--params_save_step", type=int, default=300)
    parser.add_argument("--params_save_path_root", type=str, default="../parameters/")
    parser.add_argument('--warmup_ratio', type=float, default=0.1,
                        help='ratio of warmup steps that linearly increase learning rate from '
                             '0 to target learning rate. default is 0.1')
    parser.add_argument("--max_seq_len", type=int,
                        default=256)
    # model parameter settings
    parser.add_argument("--model_dim", type=int,
                        default=768)
    parser.add_argument("--dropout", type=float,
                        default=0.1)

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
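
The learning-rate schedule inlined at the top of train_and_valid() is linear warmup followed by linear decay to zero. As a standalone sketch of the same formula (hypothetical helper name, not part of the repository):

def warmup_linear_decay(step, base_lr, num_train_steps, warmup_ratio=0.1):
    # Rise linearly from 0 to base_lr during warmup, then decay linearly back to 0.
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    if step < num_warmup_steps:
        return base_lr * step / num_warmup_steps
    offset = (step - num_warmup_steps) / (num_train_steps - num_warmup_steps)
    return base_lr - offset * base_lr

# Equivalent to the in-loop computation:
# trainer.set_learning_rate(warmup_linear_decay(global_step, lr, num_train_steps))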