├── __init__.py
├── pic
│   ├── eval.png
│   └── test.png
├── __pycache__
│   ├── utils.cpython-36.pyc
│   └── train_eval.cpython-36.pyc
├── models
│   ├── __pycache__
│   │   ├── textcnn.cpython-36.pyc
│   │   └── lstm_embedding.cpython-36.pyc
│   ├── lstm_embedding.py
│   └── textcnn.py
├── README.md
├── utils.py
├── run.py
└── train_eval.py
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/pic/eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/pic/eval.png
--------------------------------------------------------------------------------
/pic/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/pic/test.png
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/train_eval.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/__pycache__/train_eval.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/textcnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/models/__pycache__/textcnn.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/lstm_embedding.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/models/__pycache__/lstm_embedding.cpython-36.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# text_classify
Classification on the CNews dataset, using torchtext for text preprocessing and TextCNN / LSTM to extract features for classification.
The need_bertembedding file used by the code can be generated automatically with this tool: https://github.com/xmxoxo/BERT-Vector

### Dataset
Link: https://pan.baidu.com/s/1t-MGwuntLgjOwlJKHh3oNg
Extraction code: j2yr

### Code layout
- models
  Network definitions; contains the TextCNN and LSTM model builds.
- train_eval.py
  Training code, validation code, and single-sentence test code; can be turned into a Flask service with minor changes.
- utils.py
  Data processing; uses torchtext for vocabulary mapping, padding, shuffling, and so on.
- run.py
  Entry point for training and validation, plus single-sentence tests.
- data
  Download the data folder from the Baidu Netdisk link above and copy it straight in.
- data_tgt
  Files generated along the way, including models and logs (run.py's --target_path, default data_tgt/).
### Results
Results vary between runs, but look roughly like this:
![eval](https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/pic/eval.png)

![test](https://raw.githubusercontent.com/mathCrazyy/text_classify/HEAD/pic/test.png)

### Blog posts covering this code:
https://blog.csdn.net/qq_25992377/article/details/105012948
https://blog.csdn.net/qq_25992377/article/details/105013476
https://blog.csdn.net/qq_25992377/article/details/105019786

### References
https://github.com/649453932/Chinese-Text-Classification-Pytorch
http://mlexplained.com/2018/02/08/a-comprehensive-tutorial-to-torchtext/
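
### Usage
The two model modules live under models/ and are selected with run.py's required --model flag (the flag value is the module name):

```
python run.py --model textcnn
python run.py --model lstm_embedding
```

Optional flags, per run.py: --data_path (default data/) for the input files and --target_path (default data_tgt/) for generated checkpoints and logs.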
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from torchtext.data import Field, Iterator, BucketIterator, TabularDataset
from torchtext.vocab import Vectors


def generate_data(config):
    ## Field definitions for the individual columns.

    tokenizer = lambda x: [one for one in x]  # character-level tokenization
    TEXT = Field(sequential=True, tokenize=tokenizer, fix_length=config.sen_max_length)  ## the truncation length directly affects accuracy!
    LABEL = Field(sequential=False, use_vocab=False)  ## the labels are already numeric

    datafields = [("context", TEXT), ("label_id", LABEL)]  ## TEXT field, LABEL field
    test_field = [("context", TEXT), ("label_id", LABEL)]
    train_file, valid_file = TabularDataset.splits(
        path=config.data_ori,
        train=config.train_path,
        validation=config.valid_path,
        format="csv",
        skip_header=True,
        fields=datafields
    )
    test_file = TabularDataset(
        path=config.data_ori + config.test_path,
        format="csv",
        skip_header=True,
        fields=test_field
    )
    ## Build the vocabulary, initialized from the pretrained vectors.
    vectors = Vectors(name=config.data_ori + config.embedding_path, cache="./")
    TEXT.build_vocab(train_file, max_size=config.vocab_maxsize, min_freq=config.vocab_minfreq, vectors=vectors)
    TEXT.vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)

    train_iter, val_iter = BucketIterator.splits(
        (train_file, valid_file),
        batch_sizes=(config.batch_size, config.batch_size),
        device=config.device,
        sort_key=lambda x: len(x.context),
        sort_within_batch=True,
        # sort_within_batch must be True if pack_padded_sequence is used later;
        # it lets padded sequences be converted to PackedSequence objects
        repeat=False
    )

    test_iter = Iterator(test_file, batch_size=config.batch_size, device=config.device, sort=False, sort_within_batch=False, repeat=False)

    return train_iter, val_iter, test_iter, TEXT


if __name__ == "__main__":
    print("test data")
    # Example (needs a Config instance, e.g. models.textcnn.Config):
    # train_iter, valid_iter, test_iter, TEXT = generate_data(config)
    # batch = next(iter(train_iter))
    # print(batch.context)
--------------------------------------------------------------------------------
/models/lstm_embedding.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import codecs
import numpy as np


class Config(object):
    def __init__(self, data_ori, data_tgt):
        self.model_name = "lstm_embedding"
        self.data_ori = data_ori + "/"
        self.train_path = "train.csv"
        self.valid_path = "valid.csv"
        self.test_path = "test.csv"
        self.embedding_path = "need_bertembedding"

        self.sen_max_length = 150

        self.embedding_dim = 768
        self.hidden_dim = 128
        self.class_num = 10
        self.num_lstm_layers = 1
        self.num_linear = 2
        self.dropout = 0.3
        self.batch_size = 64
        self.learning_rate = 1e-3
        self.epochs = 10

        ### vocabulary settings
        self.vocab_maxsize = 4000
        self.vocab_minfreq = 10

        self.save_path = data_tgt + self.model_name + ".ckpt"
        self.log_path = data_tgt + "/log/" + self.model_name

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # logging interval and early-stopping patience, in training batches
        self.print_circle = 100
        self.require_improvement = 1000


class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
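        # The block below loads the pretrained character embeddings from a
        # plain-text word2vec-style file: the [1:-1] slice drops the first line
        # (the word2vec-style header) and the last (presumably a trailing
        # blank), each remaining line holds "<token> <embedding_dim floats>",
        # and any row that fails to parse keeps its random initialization.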
        lines = codecs.open(config.data_ori + config.embedding_path, encoding="utf-8")
        embeddings_vec = [line.replace("\n", "") for line in lines][1:-1]

        embeddings = np.random.rand(len(embeddings_vec), config.embedding_dim)
        for index, line in enumerate(embeddings_vec):
            line_seg = line.split(" ")
            try:
                embeddings[index] = [float(one) for one in line_seg[1:]]
            except ValueError:
                pass

        pretrained_weight = np.array(embeddings)
        embeds = nn.Embedding(len(embeddings), config.embedding_dim)
        embeds.weight.data.copy_(torch.from_numpy(pretrained_weight))

        self.embedding = embeds
        # note: nn.LSTM only applies dropout between stacked layers, so it is a
        # no-op while num_lstm_layers == 1
        self.encoder = nn.LSTM(config.embedding_dim, config.hidden_dim, num_layers=config.num_lstm_layers, dropout=config.dropout)
        self.dropout = nn.Dropout(config.dropout)
        self.linear_layers = []
        for _ in range(config.num_linear):
            self.linear_layers.append(nn.Linear(config.hidden_dim, config.hidden_dim))
        self.linear_layers = nn.ModuleList(self.linear_layers)  # note: no activation between these linear layers

        self.predictor = nn.Linear(config.hidden_dim, config.class_num)

    def forward(self, seq):
        hdn, _ = self.encoder(self.embedding(seq))
        feature = hdn[-1, :, :]  # hidden state at the last time step, for the whole batch
        feature = self.dropout(feature)
        for layer in self.linear_layers:
            feature = layer(feature)
        preds = self.predictor(feature)
        return preds
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import time
import torch
import numpy as np
from importlib import import_module
from utils import generate_data
from train_eval import train, test

import argparse

parser = argparse.ArgumentParser(description="text classification")
parser.add_argument("--model", type=str, required=True, help="choose a model: lstm_embedding or textcnn")
parser.add_argument("--embedding", default="pre_trained", type=str, help="random or pre_trained")
parser.add_argument("--data_path", type=str, default="data/", help="directory holding all input files")
parser.add_argument("--target_path", type=str, default="data_tgt/", help="directory for all generated files")

args = parser.parse_args()

print(args)

if __name__ == "__main__":
    model_name = args.model
    data_path = args.data_path
    target_path = args.target_path

    which_model = import_module("models." + model_name)
    config = which_model.Config(data_path, target_path)

    # fix random seeds for reproducibility
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True

    start_time = time.time()
    train_iter, valid_iter, test_iter, TEXT = generate_data(config)
    end_time = time.time()
    print("time usage: ", end_time - start_time)
    model = which_model.Model(config).to(config.device)
    train(config, model, train_iter, valid_iter, test_iter)

    ## test on single sentences
    sentence = """鲍勃库西奖归谁属?
NCAA最强控卫是坎巴还是弗神新浪体育讯如今,本赛季的NCAA进入到了末段,各项奖项的评选结果也即将出炉,其中评选最佳控卫的鲍勃-库西奖就将在下周最终四强战时公布,鲍勃-库西奖是由奈史密斯篮球名人堂提供,旨在奖励年度最佳大学控卫。最终获奖的球员也即将在以下几名热门人选中产生。〈〈〈 NCAA疯狂三月专题主页上线,点击链接查看精彩内容吉梅尔-弗雷戴特,杨百翰大学“弗神”吉梅尔-弗雷戴特一直都备受关注,他不仅仅是一名射手,他会用“终结对手脚踝”一样的变向过掉面前的防守者,并且他可以用任意一支手完成得分,如果他被犯规了,可以提前把这两份划入他的帐下了,因为他是一名命中率高达90%的罚球手。弗雷戴特具有所有伟大控卫都具备的一点特质,他是一位赢家也是一位领导者。“他整个赛季至始至终的稳定领导着球队前进,这是无可比拟的。”杨百翰大学主教练戴夫-罗斯称赞道,“他的得分能力毋庸置疑,但是我认为他带领球队获胜的能力才是他最重要的控卫职责。我们在主场之外的比赛(客场或中立场)共取胜19场,他都表现的很棒。”弗雷戴特能否在NBA取得成功?当然,但是有很多专业人士比我们更有资格去做出这样的判断。“我喜爱他。”凯尔特人主教练多克-里弗斯说道,“他很棒,我看过ESPN的片段剪辑,从剪辑来看,他是个超级巨星,我认为他很成为一名优秀的NBA球员。”诺兰-史密斯,杜克大学当赛季初,球队宣布大一天才控卫凯瑞-厄尔文因脚趾的伤病缺席赛季大部分比赛后,诺兰-史密斯便开始接管球权,他在进攻端上足发条,在ACC联盟(杜克大学所在分区)的得分榜上名列前茅,但同时他在分区助攻榜上也占据头名,这在众强林立的ACC联盟前无古人。“我不认为全美有其他的球员能在凯瑞-厄尔文受伤后,如此好的接管球队,并且之前毫无准备。”杜克主教练迈克-沙舍夫斯基赞扬道,“他会将比赛带入自己的节奏,得分,组织,领导球队,无所不能。而且他现在是攻防俱佳,对持球人的防守很有提高。总之他拥有了辉煌的赛季。”坎巴-沃克,康涅狄格大学坎巴-沃克带领康涅狄格在赛季初的毛伊岛邀请赛一路力克密歇根州大和肯塔基等队夺冠,他场均30分4助攻得到最佳球员。在大东赛区锦标赛和全国锦标赛中,他场均27.1分,6.1个篮板,5.1次助攻,依旧如此给力。他以疯狂的表现开始这个赛季,也将以疯狂的表现结束这个赛季。“我们在全国锦标赛中前进着,并且之前曾经5天连赢5场,赢得了大东赛区锦标赛的冠军,这些都归功于坎巴-沃克。”康涅狄格大学主教练吉姆-卡洪称赞道,“他是一名纯正的控卫而且能为我们得分,他有过单场42分,有过单场17助攻,也有过单场15篮板。这些都是一名6英尺175镑的球员所完成的啊!我们有很多好球员,但他才是最好的领导者,为球队所做的贡献也是最大。”乔丹-泰勒,威斯康辛大学全美没有一个持球者能像乔丹-泰勒一样很少失误,他4.26的助攻失误在全美遥遥领先,在大十赛区的比赛中,他平均35.8分钟才会有一次失误。他还是名很出色的得分手,全场砍下39分击败印第安纳大学的比赛就是最好的证明,其中下半场他曾经连拿18分。“那个夜晚他证明自己值得首轮顺位。”当时的见证者印第安纳大学主教练汤姆-克雷恩说道。“对一名控卫的所有要求不过是领导球队、使球队变的更好、带领球队成功,乔丹-泰勒全做到了。”威斯康辛教练博-莱恩说道"""
    print(sentence)
    res = test(config, model, TEXT, sentence)
    print(res)
    sentence = """景顺长城就参与新股询价问题作出说明⊙本报记者 黄金滔 安仲文 景顺长城基金管理有限公司11日在其网站上就参与新股询价有关问题作出说明,表示该公司本着独立、客观、诚信的原则参与新股询价,遵守相关法律法规和公司内部制度,并切实保障了基金持有人利益,并表示将严肃对待中国证券业协会对其提出的自律处理。景顺长城表示,作为中国证券业协会认定的IPO询价对象,该公司认真履行询价义务,2008年度共参与7只新股询价。根据该公司报价区间和实际发行价格的比较,不存在较大价格偏离现象及操纵价格嫌疑。对于公司参与询价的股票,除其中1只股票由于公司旗下基金参与了同一发行期内另外5只股票网上申购而没有申购外,其余6只股票全部参与网上申购。据了解,根据《关于基金投资非公开发行股票等流通受限证券有关问题的通知》(证监基金字[2006]141号),为了切实保护基金持有人利益、防范基金资产的流动性风险,景顺长城公司于2006年7月修改投资管理制度,明确规定“不投资带有锁定期的证券”,因此,从2006年8月起,该公司没有投资任何带有锁定期的股票,包括IPO网下申购;并且在每次参与新股询价时,公司在递交的《新股询价信息表》中的“在报价区间的申购意向”一栏均明确填写“不确定”或“不申购”。不过,景顺长城也强调,由于中国证券市场处于快速发展时期,法规制度不断完善,该公司将严肃对待中国证券业协会对其提出的自律处理,一如既往地以审慎诚信的态度为投资人服务"""
    res = test(config, model, TEXT, sentence)
    print(res)
--------------------------------------------------------------------------------
/models/textcnn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import codecs
import numpy as np
import torch.nn.functional as F


class Config(object):
    def __init__(self, data_ori, data_tgt):
        self.model_name = "textcnn"
        self.data_ori = data_ori + "/"
        self.train_path = "train_100.csv"
        self.valid_path = "valid_100.csv"
        self.test_path = "test_100.csv"
        self.embedding_path = "need_bertembedding"

        self.sen_max_length = 150

        self.embedding_dim = 768
        self.hidden_dim = 128
        self.class_num = 10
        self.dropout = 0.3
        self.batch_size = 64
        self.learning_rate = 1e-3
        self.epochs = 10

        self.filter_sizes = (2, 3, 4)
        self.num_filters = 256  ## number of output channels per filter size

        ### vocabulary settings
        self.vocab_maxsize = 4000
        self.vocab_minfreq = 10

        self.save_path = data_tgt + self.model_name + ".ckpt"
        self.log_path = data_tgt + "/log/" + self.model_name

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
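        # logging interval and early-stopping patience, both in training batches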
        self.print_circle = 100
        self.require_improvement = 1000


class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        # load the pretrained character embeddings (same procedure as in
        # models/lstm_embedding.py)
        lines = codecs.open(config.data_ori + config.embedding_path, encoding="utf-8")
        embeddings_vec = [line.replace("\n", "") for line in lines][1:-1]

        embeddings = np.random.rand(len(embeddings_vec), config.embedding_dim)
        for index, line in enumerate(embeddings_vec):
            line_seg = line.split(" ")
            try:
                embeddings[index] = [float(one) for one in line_seg[1:]]
            except ValueError:
                pass

        pretrained_weight = np.array(embeddings)
        embeds = nn.Embedding(len(embeddings), config.embedding_dim)
        embeds.weight.data.copy_(torch.from_numpy(pretrained_weight))

        self.embedding = embeds
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embedding_dim)) for k in config.filter_sizes]
        )
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.class_num)  # one pooled vector per filter size

    def conv_and_pool(self, inputs, conv):
        inputs = F.relu(conv(inputs))  ## the conv spans the whole embedding axis, leaving it with width 1
        inputs = inputs.squeeze(3)     ## -> (batch, num_filters, seq_len - k + 1)
        inputs = F.max_pool1d(inputs, inputs.size(2))  ## max-pool over the whole sequence
        inputs = inputs.squeeze(2)     ## -> (batch, num_filters)
        return inputs

    def forward(self, seq):
        ## Rough idea: however the features are computed, the tensor reaching
        ## the fc layer just has to be a [batch_size, n] feature matrix.
        seq_embedings = self.embedding(seq.t())  # embed, with batch moved to the first dimension
        seq_embedings_batch = seq_embedings.unsqueeze(1)  # add a channel dimension for Conv2d
        ## concatenate the pooled results of all filter sizes (note the concat dimension)
        concat_res = torch.cat([self.conv_and_pool(seq_embedings_batch, conv) for conv in self.convs], 1)
        out = self.dropout(concat_res)
        out = self.fc(out)
        return out
--------------------------------------------------------------------------------
/train_eval.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report
from sklearn import metrics

import numpy as np
from tensorboardX import SummaryWriter
import time


def evaluate(config, model, eval_iter, test=False):
    model.eval()
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    loss_total = 0.0
    with torch.no_grad():
        for batch in eval_iter:
            preds = model(batch.context)
            loss = F.cross_entropy(preds.cpu(), batch.label_id.cpu())
            loss_total += loss.item()
            predic = torch.max(preds.data, 1)[1].cpu().numpy()
            predict_all = np.append(predict_all, predic)
            labels_all = np.append(labels_all, batch.label_id.cpu())

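    # aggregate the loss, accuracy, and per-class report over the whole eval set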
    acc = metrics.accuracy_score(labels_all, predict_all)
    class_report = classification_report(labels_all, predict_all)
    return loss_total / len(eval_iter), acc, class_report


def test(config, model, TEXT, sentence):
    model.eval()
    sentence_seq = [TEXT.vocab.stoi[one] for one in sentence]
    sentence_seq = sentence_seq[:config.sen_max_length]  # truncate, matching the fix_length used in training
    need_pad = config.sen_max_length - len(sentence_seq)
    for _ in range(need_pad):
        sentence_seq.append(1)  # 1 is torchtext's default <pad> index

    example = torch.Tensor(sentence_seq).long().to(config.device)
    example = example.unsqueeze(1)  # [seq_len, 1]: a batch holding one sentence

    with torch.no_grad():
        preds = model(example)
    predic = torch.max(preds.data, 1)[1].cpu().numpy()
    return predic


def train(config, model, train_iter, valid_iter, test_iter):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    eval_best_loss = float("inf")
    epochs = config.epochs
    writer = SummaryWriter(log_dir=config.log_path + "/" + time.strftime("%m-%d_%H.%M", time.localtime()))
    total_batch = 0
    last_improve = 0
    flag = False

    for epoch in range(1, epochs + 1):
        # optional learning-rate decay:
        # if epoch % 5 == 0:
        #     for p in optimizer.param_groups:
        #         p["lr"] *= 0.9
        model.train()

        for batch in train_iter:
            total_batch += 1
            model.zero_grad()
            preds = model(batch.context)  # [batch_size, class_num]
            y_p = batch.label_id
            # https://blog.csdn.net/ccbrid/article/details/90610599
            loss = F.cross_entropy(preds, y_p.long())

            loss.backward()
            optimizer.step()
            if total_batch % config.print_circle == 0:
                pred_res = torch.max(preds.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(y_p.cpu(), pred_res)
                eval_loss, eval_acc, eval_report = evaluate(config, model, valid_iter)
                test_loss, test_acc, test_report = evaluate(config, model, test_iter)
                print("train_loss: ", loss.item(), "train_acc: ", train_acc, total_batch)
                print("eval_loss: ", eval_loss, "eval_acc: ", eval_acc, total_batch)
                print("test_loss: ", test_loss, "test_acc: ", test_acc, total_batch)
                if eval_loss < eval_best_loss:
                    eval_best_loss = eval_loss
                    torch.save(model.state_dict(), config.save_path)
                    last_improve = total_batch

                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", eval_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", eval_acc, total_batch)
                model.train()
            if total_batch - last_improve > config.require_improvement and last_improve != 0:
                print("no improvement for more than", config.require_improvement, "batches, stopping early")
                print("eval_report: ", eval_report)
                print("test_report: ", test_report)
                flag = True
                break
        if flag:
            break
--------------------------------------------------------------------------------
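
A note on serving: the README says train_eval.py's test() can be turned into a Flask service with minor changes. Below is a minimal sketch of such a wrapper. The file name serve.py, the /predict route, and the port are illustrative only (they are not part of the repo); the sketch assumes the data/ files described in the README are in place and that a checkpoint was already written to config.save_path by a previous training run.

# serve.py (hypothetical, not part of the repo)
import torch
from importlib import import_module
from flask import Flask, request, jsonify

from utils import generate_data
from train_eval import test

which_model = import_module("models.textcnn")        # or models.lstm_embedding
config = which_model.Config("data/", "data_tgt/")
_, _, _, TEXT = generate_data(config)                # rebuild the vocabulary
model = which_model.Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location=config.device))
model.eval()

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    sentence = request.json["sentence"]
    label_id = test(config, model, TEXT, sentence)   # returns a length-1 numpy array
    return jsonify({"label_id": int(label_id[0])})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)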