├── .DS_Store ├── README.md ├── data ├── .DS_Store └── training_sample.txt ├── img └── CNN_result.png ├── requirements.txt ├── setup.py ├── src ├── .DS_Store ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── preprocess.cpython-37.pyc ├── config.py ├── config │ ├── base.py │ └── text_classifier.py ├── data │ ├── __init__.py │ ├── augmentation.py │ ├── base.py │ ├── cnews_dataset.py │ ├── dataset.py │ └── preprocess.py ├── evaluate.py ├── evaluate │ ├── base.py │ └── text_classifier.py ├── main.py ├── model.py ├── models │ ├── __init__.py │ ├── base.py │ ├── bert.py │ ├── lstm.py │ └── textcnn.py ├── preprocess.py ├── train.py ├── train │ ├── base.py │ └── text_classifier.py └── utils │ ├── __init__.py │ ├── metrics.py │ └── visualization.py └── tests ├── test_data.py └── test_models.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 中文文本分类系统 2 | 3 | 一个基于深度学习的中文文本分类系统,支持多种模型架构和训练策略。 4 | 5 | ## 功能特性 6 | 7 | - 支持多种深度学习模型: 8 | - TextCNN:使用卷积神经网络进行文本分类 9 | - LSTM:使用长短期记忆网络进行文本分类 10 | - BERT:使用预训练的中文BERT模型进行文本分类 11 | 12 | - 数据处理功能: 13 | - 文本预处理(URL移除、邮件移除、标点符号处理等) 14 | - 数据增强(同义词替换、随机删除、随机交换、随机插入) 15 | - 支持自定义词表和标签映射 16 | 17 | - 训练功能: 18 | - 支持多种优化器(Adam、SGD) 19 | - 学习率调度 20 | - 早停机制 21 | - 模型检查点保存 22 | - TensorBoard可视化 23 | 24 | - 评估功能: 25 | - 分类报告生成 26 | - 混淆矩阵可视化 27 | - ROC曲线分析 28 | - 错误样本分析 29 | - 特征重要性分析(适用于可解释模型) 30 | 31 | - 配置管理: 32 | - 支持JSON和YAML格式的配置文件 33 | - 配置验证和默认值 34 | - 灵活的配置更新机制 35 | 36 | ## 项目结构 37 | 38 | ``` 39 | Text_Classification/ 40 | ├── data/ # 数据目录 41 | │ ├── raw/ # 原始数据 42 | │ ├── processed/ # 处理后的数据 43 | │ └── vocab/ # 词表文件 44 | ├── models/ # 模型目录 45 | │ ├── checkpoints/ # 模型检查点 46 | │ └── logs/ # TensorBoard日志 47 | ├── evaluation/ # 评估结果 48 | ├── src/ # 源代码 49 | │ ├── config/ # 配置管理 50 | │ ├── data/ # 数据处理 51 | │ ├── models/ # 模型定义 52 | │ ├── train/ # 训练相关 53 | │ ├── evaluate/ # 评估相关 54 | │ └── utils/ # 工具函数 55 | ├── tests/ # 测试代码 56 | ├── notebooks/ # Jupyter notebooks 57 | ├── requirements.txt # 项目依赖 58 | └── setup.py # 安装脚本 59 | ``` 60 | 61 | ## 安装说明 62 | 63 | 1. 克隆项目: 64 | ```bash 65 | git clone https://github.com/yourusername/Text_Classification.git 66 | cd Text_Classification 67 | ``` 68 | 69 | 2. 创建虚拟环境(推荐): 70 | ```bash 71 | python -m venv venv 72 | source venv/bin/activate # Linux/Mac 73 | # 或 74 | venv\Scripts\activate # Windows 75 | ``` 76 | 77 | 3. 安装依赖: 78 | ```bash 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | 4. 安装项目: 83 | ```bash 84 | pip install -e . 85 | ``` 86 | 87 | ## 使用说明 88 | 89 | 1. 准备数据: 90 | - 将训练数据放在 `data/raw/` 目录下 91 | - 数据格式:每行一个样本,格式为 "标签\t文本" 92 | 93 | 2. 配置模型: 94 | - 复制 `configs/default.yaml` 为 `configs/my_config.yaml` 95 | - 根据需要修改配置参数 96 | 97 | 3. 
训练模型: 98 | ```python 99 | from src.config import TextClassifierConfig 100 | from src.train import TextClassifierTrainer 101 | from src.data import CNewsDataset 102 | 103 | # 加载配置 104 | config = TextClassifierConfig('configs/my_config.yaml') 105 | 106 | # 准备数据 107 | dataset = CNewsDataset(config.get_data_config()) 108 | 109 | # 创建训练器 110 | trainer = TextClassifierTrainer( 111 | model_type=config['model']['model_type'], 112 | model_config=config.get_model_config(), 113 | dataset=dataset, 114 | training_config=config.get_training_config() 115 | ) 116 | 117 | # 训练模型 118 | trainer.train() 119 | ``` 120 | 121 | 4. 评估模型: 122 | ```python 123 | from src.evaluate import TextClassifierEvaluator 124 | 125 | # 创建评估器 126 | evaluator = TextClassifierEvaluator( 127 | model=trainer.model, 128 | dataset=dataset, 129 | config=config.get_evaluation_config() 130 | ) 131 | 132 | # 评估模型 133 | metrics = evaluator.evaluate(test_data) 134 | ``` 135 | 136 | 5. 使用模型预测: 137 | ```python 138 | # 预测单个文本 139 | text = "这是一条测试文本" 140 | prediction, probability = evaluator.predict([text]) 141 | print(f"预测类别: {prediction[0]}") 142 | print(f"预测概率: {probability[0]}") 143 | ``` 144 | 145 | ## 配置说明 146 | 147 | 配置文件支持JSON和YAML格式,主要包含以下配置节: 148 | 149 | - `data`: 数据相关配置 150 | - `train_file`: 训练数据文件路径 151 | - `val_file`: 验证数据文件路径 152 | - `test_file`: 测试数据文件路径 153 | - `max_length`: 文本最大长度 154 | - `preprocessing`: 文本预处理选项 155 | 156 | - `model`: 模型相关配置 157 | - `model_type`: 模型类型(textcnn/lstm/bert) 158 | - `vocab_size`: 词表大小 159 | - `embedding_dim`: 词向量维度 160 | - `num_classes`: 类别数量 161 | - 模型特定配置(如TextCNN的filter_sizes等) 162 | 163 | - `training`: 训练相关配置 164 | - `batch_size`: 批处理大小 165 | - `epochs`: 训练轮数 166 | - `learning_rate`: 学习率 167 | - `optimizer`: 优化器类型 168 | - 其他训练参数(早停、学习率调度等) 169 | 170 | - `augmentation`: 数据增强配置 171 | - 各种增强方法的参数 172 | 173 | - `evaluation`: 评估相关配置 174 | - 评估指标和可视化选项 175 | 176 | ## 开发说明 177 | 178 | 1. 代码风格: 179 | - 遵循PEP 8规范 180 | - 使用类型注解 181 | - 编写详细的文档字符串 182 | 183 | 2. 测试: 184 | - 运行单元测试:`pytest tests/` 185 | - 运行代码覆盖率:`pytest --cov=src tests/` 186 | 187 | 3. 
贡献: 188 | - Fork项目 189 | - 创建特性分支 190 | - 提交更改 191 | - 发起Pull Request 192 | 193 | ## 结果 194 | ### CNN 195 | 速度相当快,效果也不错,precision与recall都趋近于0.9 196 | ![image](https://github.com/sun830910/Text_Classification/blob/master/img/CNN_result.png) 197 | 198 | 199 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/data/.DS_Store -------------------------------------------------------------------------------- /data/training_sample.txt: -------------------------------------------------------------------------------- 1 | 体育 马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的困扰。7月31日下午6点,国奥队的日常训练再度受到大雨的干扰,无奈之下队员们只慢跑了25分钟就草草收场。31日上午10点,国奥队在奥体中心外场训练的时候,天就是阴沉沉的,气象预报显示当天下午沈阳就有大雨,但幸好队伍上午的训练并没有受到任何干扰。下午6点,当球队抵达训练场时,大雨已经下了几个小时,而且丝毫没有停下来的意思。抱着试一试的态度,球队开始了当天下午的例行训练,25分钟过去了,天气没有任何转好的迹象,为了保护球员们,国奥队决定中止当天的训练,全队立即返回酒店。在雨中训练对足球队来说并不是什么稀罕事,但在奥运会即将开始之前,全队变得“娇贵”了。在沈阳最后一周的训练,国奥队首先要保证现有的球员不再出现意外的伤病情况以免影响正式比赛,因此这一阶段控制训练受伤、控制感冒等疾病的出现被队伍放在了相当重要的位置。而抵达沈阳之后,中后卫冯萧霆就一直没有训练,冯萧霆是7月27日在长春患上了感冒,因此也没有参加29日跟塞尔维亚的热身赛。队伍介绍说,冯萧霆并没有出现发烧症状,但为了安全起见,这两天还是让他静养休息,等感冒彻底好了之后再恢复训练。由于有了冯萧霆这个例子,因此国奥队对雨中训练就显得特别谨慎,主要是担心球员们受凉而引发感冒,造成非战斗减员。而女足队员马晓旭在热身赛中受伤导致无缘奥运的前科,也让在沈阳的国奥队现在格外警惕,“训练中不断嘱咐队员们要注意动作,我们可不能再出这样的事情了。”一位工作人员表示。从长春到沈阳,雨水一路伴随着国奥队,“也邪了,我们走到哪儿雨就下到哪儿,在长春几次训练都被大雨给搅和了,没想到来沈阳又碰到这种事情。”一位国奥球员也对雨水的“青睐”有些不解。 2 | 体育 商瑞华首战复仇心切 中国玫瑰要用美国方式攻克瑞典多曼来了,瑞典来了,商瑞华首战求3分的信心也来了。距离首战72小时当口,中国女足彻底从“恐瑞症”当中获得解脱,因为商瑞华已经找到了瑞典人的软肋。找到软肋,保密4月20日奥运会分组抽签结果出来后,中国姑娘就把瑞典锁定为关乎奥运成败的头号劲敌,因为除了浦玮等个别老将之外,现役女足将士竟然没有人尝过击败瑞典的滋味。在中瑞两队共计15次交锋的历史上,中国队6胜3平6负与瑞典队平分秋色,但从2001年起至今近8年时间,中国在同瑞典连续5次交锋中均未尝胜绩,战绩为2平3负。尽管八年不胜瑞典曾一度成为谢亚龙聘请多曼斯基的理由之一,但这份战绩表也成为压在姑娘们身上的一座大山,在奥运备战过程中越发凸显沉重。或许正源于此,商瑞华才在首战前3天召开了一堂完全针对“恐瑞症”的战术分析课。3日中午在喜来登大酒店中国队租用的会议室里,商瑞华给大家播放了瑞典队的比赛录像剪辑。这是7月6日瑞典在主场同美国队进行的一场奥运热身赛,当时美国队头牌射手瓦姆巴赫还没有受伤,比赛当中双方均尽遣主力,占据主场优势的瑞典队曾一度占据上风,但终究还是在定位球防守上吃了亏,美国队在下半场的一次角球配合中,通过远射打进唯一进球,尽管慢动作显示是打在瑞典队后卫腿上弹射入网,但从过程到结果,均显示了相同内容——“瑞典队的防守并非无懈可击”。商瑞华让科研教练曹晓东等人对这场比赛进行精心剪辑,尤其是瑞典队失球以及美国队形成有威胁射门的片段,更是被放大进行动作分解,每一个中国姑娘都可以一目了然地看清瑞典队哪些地方有机可乘。甚至之前被商瑞华称为“恐怖杀手”的瑞典8号谢琳,也在这次战术分解过程中被发现了不足之处。姑娘们心知肚明并开心享受对手软肋被找到的欢悦,但却必须对记者保密,某主力球员说:“这可不能告诉外界,反正我们心里有数了,知道对付这个速度奇快的谢琳该怎么办!”老帅的“瑞典情结”就像中国队8年不胜瑞典一样,瑞典队也连续遭遇对美国队的溃败:去年阿尔加夫杯瑞典0比1不敌美国,世界杯小组赛上瑞典0比2被美国完胜,今年阿尔加夫杯美国人再次击败瑞典,算上7月6日一役,瑞典遇到美国连平局都没有,竟然是4连败!3日中午的这堂战术分析课后,“用美国人的方式击败瑞典”已经成为中国女足将士的共同心声。姑娘们当然有理由这样去憧憬,因为在7月30日奥运会前最后一场热身赛中,中国女足便曾0比0与强大的美国队握手言和。在3日中午的战术分析课上,这场中美热身也被梳理出片段,与姑娘们一道完成总结。点评过程中,商瑞华对大家在同美国队比赛中表现出来的逼抢意识,给予很高评价。主帅的认可和称赞,更是在大家潜意识里强化了“逼抢的价值”,就连排在18人名单外的候补球员都说:“既然我们能跟最擅长逼抢的美国队玩逼抢,当然有信心跟瑞典队也这么打”,姑娘们都对打好首战的信心越来越强。“我们当然需要低调,但内心深处已经再也没有对瑞典的恐惧,只要把我们训练中的内容打出来,就完全有可能击败瑞典”,这堂战术讨论课后,不止一名球员向记者表达着对首战获胜的渴望。商瑞华心中的“瑞典情结”其实最重,不仅仅因为他是主帅,17年前商瑞华首次担任中国女足主帅时所遭遇的滑铁卢,正是源自1991年世界杯中国0比1被瑞典淘汰所致。抽签出来后面对与瑞典同组,64岁的老帅那份复仇的雄心也潜滋暗长,在5月上旬和7月中旬,商瑞华曾连续两次前往欧洲在现场刺探瑞典军情,真正对瑞典队的特点达到了如指掌的程度。昨天中午的战术讨论课后,商瑞华告诉记者:“瑞典队和所有其他队伍一样有优点也有不足,我们如果能够扬长避短拿对方的不足做文章,就有可能击败对手。从奥运会备战阶段看来,中国队战术上打得比较快的特点基本已经成型,边路也不错,在前场抢截后快速发动进攻的能力也逐步增强。我很清楚瑞典队世界排名第3而我们排第14,但承认差距不等于接受失败,我当然想赢。” 3 | -------------------------------------------------------------------------------- /img/CNN_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/img/CNN_result.png -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | tensorflow>=2.8.0 2 | numpy>=1.19.2 3 | pandas>=1.2.0 4 | scikit-learn>=0.24.0 5 | matplotlib>=3.3.0 6 | tqdm>=4.50.0 7 | pytest>=6.0.0 8 | black>=21.5b2 9 | flake8>=3.9.0 10 | jupyter>=1.0.0 11 | transformers>=4.5.0 12 | torch>=1.8.0 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | with open("requirements.txt", "r", encoding="utf-8") as fh: 7 | requirements = fh.read().splitlines() 8 | 9 | setup( 10 | name="text-classification", 11 | version="0.1.0", 12 | author="Justin", 13 | author_email="your.email@example.com", 14 | description="中文文本分类项目", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/yourusername/Text_Classification", 18 | packages=find_packages(), 19 | classifiers=[ 20 | "Development Status :: 3 - Alpha", 21 | "Intended Audience :: Science/Research", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | ], 29 | python_requires=">=3.8", 30 | install_requires=requirements, 31 | ) 32 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/.DS_Store -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/preprocess.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/__pycache__/preprocess.cpython-37.pyc -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Created on 2020-07-19 00:20 5 | @Author : Justin Jiang 6 | @Email : jw_jiang@pku.edu.com 7 | 8 | 配置模型、路径、与训练相关参数 9 | """ 10 | 11 | import os 12 | from pathlib import Path 13 | 14 | # 项目根目录 15 | ROOT_DIR = Path(__file__).parent.parent 16 | 17 | # 数据目录 18 | DATA_DIR = ROOT_DIR / "data" 19 | RAW_DATA_DIR = DATA_DIR / "raw" 20 | PROCESSED_DATA_DIR = DATA_DIR / "processed" 21 | 
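# 说明:模型产物默认保存在 data/ 下的 models 子目录;
# 下面的目录创建循环会在路径不存在时自动创建(parents=True, exist_ok=True)。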
MODEL_SAVE_DIR = DATA_DIR / "models" 22 | 23 | # 创建必要的目录 24 | for dir_path in [RAW_DATA_DIR, PROCESSED_DATA_DIR, MODEL_SAVE_DIR]: 25 | dir_path.mkdir(parents=True, exist_ok=True) 26 | 27 | # 数据集配置 28 | DATASET_CONFIG = { 29 | "train_file": RAW_DATA_DIR / "cnews.train.txt", 30 | "val_file": RAW_DATA_DIR / "cnews.val.txt", 31 | "test_file": RAW_DATA_DIR / "cnews.test.txt", 32 | "vocab_file": RAW_DATA_DIR / "cnews.vocab.txt", 33 | "categories": ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"], 34 | "vocab_size": 5000, 35 | "max_sequence_length": 600, 36 | } 37 | 38 | # 模型通用配置 39 | MODEL_CONFIG = { 40 | "embedding_dim": 128, 41 | "num_classes": len(DATASET_CONFIG["categories"]), 42 | "dropout_rate": 0.5, 43 | "learning_rate": 0.001, 44 | "batch_size": 64, 45 | "epochs": 10, 46 | "early_stopping_patience": 3, 47 | } 48 | 49 | # TextCNN 模型配置 50 | TEXTCNN_CONFIG = { 51 | **MODEL_CONFIG, 52 | "num_filters": 128, 53 | "filter_sizes": [2, 3, 4, 5], 54 | "model_name": "textcnn", 55 | } 56 | 57 | # LSTM 模型配置 58 | LSTM_CONFIG = { 59 | **MODEL_CONFIG, 60 | "lstm_units": 128, 61 | "num_layers": 2, 62 | "model_name": "lstm", 63 | } 64 | 65 | # BERT 模型配置 66 | BERT_CONFIG = { 67 | **MODEL_CONFIG, 68 | "bert_model_name": "bert-base-chinese", 69 | "max_sequence_length": 512, 70 | "model_name": "bert", 71 | } 72 | 73 | # 训练配置 74 | TRAIN_CONFIG = { 75 | "log_dir": ROOT_DIR / "logs", 76 | "checkpoint_dir": MODEL_SAVE_DIR / "checkpoints", 77 | "tensorboard_dir": ROOT_DIR / "logs" / "tensorboard", 78 | "save_best_only": True, 79 | "save_weights_only": True, 80 | "monitor": "val_accuracy", 81 | "mode": "max", 82 | } 83 | 84 | # 评估配置 85 | EVAL_CONFIG = { 86 | "metrics": ["accuracy", "precision", "recall", "f1"], 87 | "confusion_matrix": True, 88 | "classification_report": True, 89 | } 90 | 91 | # 日志配置 92 | LOGGING_CONFIG = { 93 | "version": 1, 94 | "disable_existing_loggers": False, 95 | "formatters": { 96 | "standard": { 97 | "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s" 98 | }, 99 | }, 100 | "handlers": { 101 | "console": { 102 | "class": "logging.StreamHandler", 103 | "level": "INFO", 104 | "formatter": "standard", 105 | "stream": "ext://sys.stdout", 106 | }, 107 | "file": { 108 | "class": "logging.FileHandler", 109 | "level": "INFO", 110 | "formatter": "standard", 111 | "filename": ROOT_DIR / "logs" / "app.log", 112 | "mode": "a", 113 | }, 114 | }, 115 | "loggers": { 116 | "": { 117 | "handlers": ["console", "file"], 118 | "level": "INFO", 119 | "propagate": True 120 | } 121 | } 122 | } 123 | 124 | class Config(object): 125 | def __init__(self): 126 | self.config_dict = { 127 | "data_path": { 128 | "vocab_path": "../data/cnews.vocab.txt", 129 | "trainingSet_path": "../data/cnews.train.txt", 130 | "valSet_path": "../data/cnews.val.txt", 131 | "testingSet_path": "../data/cnews.test.txt" 132 | }, 133 | "CNN_training_rule": { 134 | "embedding_dim": 64, 135 | "seq_length": 600, 136 | "num_classes": 10, 137 | 138 | "conv1_num_filters": 128, 139 | "conv1_kernel_size": 1, 140 | 141 | "conv2_num_filters": 64, 142 | "conv2_kernel_size": 1, 143 | 144 | "vocab_size": 5000, 145 | 146 | "hidden_dim": 128, 147 | 148 | "dropout_keep_prob": 0.5, 149 | "learning_rate": 1e-3, 150 | 151 | "batch_size": 64, 152 | "epochs": 5, 153 | 154 | "print_per_batch": 100, 155 | "save_per_batch": 1000 156 | }, 157 | "LSTM": { 158 | "seq_length": 600, 159 | "num_classes": 10, 160 | "vocab_size": 5000, 161 | "batch_size": 64 162 | }, 163 | "result": { 164 | "CNN_model_path": "../result/CNN_model.h5", 165 | 
"LSTM_model_path": "../result/LSTM_model.h5" 166 | } 167 | } 168 | 169 | def get(self, section, name): 170 | return self.config_dict[section][name] -------------------------------------------------------------------------------- /src/config/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any, Optional 3 | import json 4 | import yaml 5 | import os 6 | from pathlib import Path 7 | from ..utils.logger import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | class BaseConfig(ABC): 12 | """配置管理基类""" 13 | 14 | def __init__(self, config_path: Optional[str] = None): 15 | """初始化配置管理器 16 | 17 | Args: 18 | config_path: 配置文件路径 19 | """ 20 | self.config: Dict[str, Any] = {} 21 | self.config_path = config_path 22 | 23 | if config_path: 24 | self.load_config(config_path) 25 | 26 | @abstractmethod 27 | def validate_config(self) -> bool: 28 | """验证配置是否有效 29 | 30 | Returns: 31 | 配置是否有效 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def get_default_config(self) -> Dict[str, Any]: 37 | """获取默认配置 38 | 39 | Returns: 40 | 默认配置字典 41 | """ 42 | pass 43 | 44 | def load_config(self, config_path: str): 45 | """加载配置文件 46 | 47 | Args: 48 | config_path: 配置文件路径 49 | """ 50 | config_path = Path(config_path) 51 | if not config_path.exists(): 52 | logger.warning(f"配置文件不存在: {config_path},将使用默认配置") 53 | self.config = self.get_default_config() 54 | return 55 | 56 | # 根据文件扩展名选择加载方式 57 | if config_path.suffix == '.json': 58 | with open(config_path, 'r', encoding='utf-8') as f: 59 | self.config = json.load(f) 60 | elif config_path.suffix in ['.yml', '.yaml']: 61 | with open(config_path, 'r', encoding='utf-8') as f: 62 | self.config = yaml.safe_load(f) 63 | else: 64 | raise ValueError(f"不支持的配置文件格式: {config_path.suffix}") 65 | 66 | # 验证配置 67 | if not self.validate_config(): 68 | logger.warning("配置验证失败,将使用默认配置") 69 | self.config = self.get_default_config() 70 | else: 71 | logger.info(f"成功加载配置文件: {config_path}") 72 | 73 | def save_config(self, config_path: Optional[str] = None): 74 | """保存配置到文件 75 | 76 | Args: 77 | config_path: 配置文件路径,如果为None则使用初始化时的路径 78 | """ 79 | if config_path is None: 80 | if self.config_path is None: 81 | raise ValueError("未指定配置文件路径") 82 | config_path = self.config_path 83 | 84 | config_path = Path(config_path) 85 | config_path.parent.mkdir(parents=True, exist_ok=True) 86 | 87 | # 根据文件扩展名选择保存方式 88 | if config_path.suffix == '.json': 89 | with open(config_path, 'w', encoding='utf-8') as f: 90 | json.dump(self.config, f, ensure_ascii=False, indent=2) 91 | elif config_path.suffix in ['.yml', '.yaml']: 92 | with open(config_path, 'w', encoding='utf-8') as f: 93 | yaml.safe_dump(self.config, f, allow_unicode=True, default_flow_style=False) 94 | else: 95 | raise ValueError(f"不支持的配置文件格式: {config_path.suffix}") 96 | 97 | logger.info(f"配置已保存到: {config_path}") 98 | 99 | def update_config(self, new_config: Dict[str, Any], validate: bool = True): 100 | """更新配置 101 | 102 | Args: 103 | new_config: 新的配置字典 104 | validate: 是否验证更新后的配置 105 | """ 106 | # 递归更新配置 107 | def deep_update(d: Dict[str, Any], u: Dict[str, Any]) -> Dict[str, Any]: 108 | for k, v in u.items(): 109 | if isinstance(v, dict) and k in d and isinstance(d[k], dict): 110 | d[k] = deep_update(d[k], v) 111 | else: 112 | d[k] = v 113 | return d 114 | 115 | self.config = deep_update(self.config, new_config) 116 | 117 | # 验证更新后的配置 118 | if validate and not self.validate_config(): 119 | raise ValueError("更新后的配置验证失败") 120 | 121 | logger.info("配置已更新") 
122 | 123 | def get(self, key: str, default: Any = None) -> Any: 124 | """获取配置项 125 | 126 | Args: 127 | key: 配置项键名 128 | default: 默认值 129 | 130 | Returns: 131 | 配置项值 132 | """ 133 | return self.config.get(key, default) 134 | 135 | def set(self, key: str, value: Any, validate: bool = True): 136 | """设置配置项 137 | 138 | Args: 139 | key: 配置项键名 140 | value: 配置项值 141 | validate: 是否验证更新后的配置 142 | """ 143 | self.config[key] = value 144 | 145 | # 验证更新后的配置 146 | if validate and not self.validate_config(): 147 | raise ValueError("更新后的配置验证失败") 148 | 149 | logger.info(f"配置项已更新: {key}") 150 | 151 | def __getitem__(self, key: str) -> Any: 152 | """通过字典方式访问配置项 153 | 154 | Args: 155 | key: 配置项键名 156 | 157 | Returns: 158 | 配置项值 159 | """ 160 | return self.config[key] 161 | 162 | def __setitem__(self, key: str, value: Any): 163 | """通过字典方式设置配置项 164 | 165 | Args: 166 | key: 配置项键名 167 | value: 配置项值 168 | """ 169 | self.set(key, value) 170 | 171 | def __contains__(self, key: str) -> bool: 172 | """检查配置项是否存在 173 | 174 | Args: 175 | key: 配置项键名 176 | 177 | Returns: 178 | 配置项是否存在 179 | """ 180 | return key in self.config 181 | 182 | def __str__(self) -> str: 183 | """返回配置的字符串表示""" 184 | return json.dumps(self.config, ensure_ascii=False, indent=2) -------------------------------------------------------------------------------- /src/config/text_classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | from .base import BaseConfig 3 | from ..utils.logger import get_logger 4 | 5 | logger = get_logger(__name__) 6 | 7 | class TextClassifierConfig(BaseConfig): 8 | """文本分类配置类""" 9 | 10 | def validate_config(self) -> bool: 11 | """验证配置是否有效 12 | 13 | Returns: 14 | 配置是否有效 15 | """ 16 | required_keys = { 17 | 'data': ['train_file', 'val_file', 'test_file', 'max_length'], 18 | 'model': ['model_type', 'vocab_size', 'embedding_dim', 'num_classes'], 19 | 'training': ['batch_size', 'epochs', 'learning_rate', 'optimizer'] 20 | } 21 | 22 | # 检查必需配置项 23 | for section, keys in required_keys.items(): 24 | if section not in self.config: 25 | logger.error(f"缺少配置节: {section}") 26 | return False 27 | 28 | for key in keys: 29 | if key not in self.config[section]: 30 | logger.error(f"缺少配置项: {section}.{key}") 31 | return False 32 | 33 | # 验证模型类型 34 | valid_model_types = ['textcnn', 'lstm', 'bert'] 35 | if self.config['model']['model_type'] not in valid_model_types: 36 | logger.error(f"无效的模型类型: {self.config['model']['model_type']}") 37 | return False 38 | 39 | # 验证优化器 40 | valid_optimizers = ['adam', 'sgd'] 41 | if self.config['training']['optimizer'] not in valid_optimizers: 42 | logger.error(f"无效的优化器: {self.config['training']['optimizer']}") 43 | return False 44 | 45 | # 验证数值范围 46 | if self.config['data']['max_length'] <= 0: 47 | logger.error("max_length必须大于0") 48 | return False 49 | 50 | if self.config['model']['vocab_size'] <= 0: 51 | logger.error("vocab_size必须大于0") 52 | return False 53 | 54 | if self.config['model']['embedding_dim'] <= 0: 55 | logger.error("embedding_dim必须大于0") 56 | return False 57 | 58 | if self.config['model']['num_classes'] <= 0: 59 | logger.error("num_classes必须大于0") 60 | return False 61 | 62 | if self.config['training']['batch_size'] <= 0: 63 | logger.error("batch_size必须大于0") 64 | return False 65 | 66 | if self.config['training']['epochs'] <= 0: 67 | logger.error("epochs必须大于0") 68 | return False 69 | 70 | if self.config['training']['learning_rate'] <= 0: 71 | logger.error("learning_rate必须大于0") 72 | return False 73 | 74 | return True 75 | 76 
| def get_default_config(self) -> Dict[str, Any]: 77 | """获取默认配置 78 | 79 | Returns: 80 | 默认配置字典 81 | """ 82 | return { 83 | 'data': { 84 | 'train_file': 'data/raw/cnews.train.txt', 85 | 'val_file': 'data/raw/cnews.val.txt', 86 | 'test_file': 'data/raw/cnews.test.txt', 87 | 'max_length': 512, 88 | 'shuffle_buffer_size': 10000, 89 | 'preprocessing': { 90 | 'remove_urls': True, 91 | 'remove_emails': True, 92 | 'remove_numbers': False, 93 | 'remove_punctuation': False, 94 | 'remove_whitespace': True, 95 | 'lowercase': True 96 | } 97 | }, 98 | 'model': { 99 | 'model_type': 'textcnn', 100 | 'vocab_size': 50000, 101 | 'embedding_dim': 300, 102 | 'num_classes': 10, 103 | 'textcnn': { 104 | 'num_filters': 128, 105 | 'filter_sizes': [2, 3, 4, 5], 106 | 'dropout_rate': 0.5 107 | }, 108 | 'lstm': { 109 | 'lstm_units': 128, 110 | 'bidirectional': True, 111 | 'dropout_rate': 0.5 112 | }, 113 | 'bert': { 114 | 'model_name': 'bert-base-chinese', 115 | 'dropout_rate': 0.1, 116 | 'fine_tune': True 117 | } 118 | }, 119 | 'training': { 120 | 'batch_size': 32, 121 | 'epochs': 10, 122 | 'learning_rate': 1e-3, 123 | 'optimizer': 'adam', 124 | 'momentum': 0.9, 125 | 'early_stopping_patience': 5, 126 | 'reduce_lr_patience': 2, 127 | 'reduce_lr_factor': 0.5, 128 | 'min_lr': 1e-6, 129 | 'metrics': ['accuracy', 'precision', 'recall', 'f1'], 130 | 'checkpoint_dir': 'models/checkpoints', 131 | 'log_dir': 'models/logs', 132 | 'save_best_only': True 133 | }, 134 | 'augmentation': { 135 | 'enabled': True, 136 | 'methods': { 137 | 'synonym_replacement': { 138 | 'enabled': True, 139 | 'max_words': 3, 140 | 'prob': 0.3 141 | }, 142 | 'random_deletion': { 143 | 'enabled': True, 144 | 'prob': 0.2 145 | }, 146 | 'random_swap': { 147 | 'enabled': True, 148 | 'max_swaps': 3, 149 | 'prob': 0.3 150 | }, 151 | 'random_insertion': { 152 | 'enabled': True, 153 | 'max_insertions': 3, 154 | 'prob': 0.3 155 | } 156 | } 157 | }, 158 | 'evaluation': { 159 | 'batch_size': 32, 160 | 'output_dir': 'evaluation', 161 | 'save_predictions': True, 162 | 'plot_confusion_matrix': True, 163 | 'plot_roc_curves': True, 164 | 'analyze_errors': True, 165 | 'top_k_errors': 5, 166 | 'analyze_feature_importance': False 167 | } 168 | } 169 | 170 | def get_model_config(self) -> Dict[str, Any]: 171 | """获取模型配置 172 | 173 | Returns: 174 | 模型配置字典 175 | """ 176 | model_type = self.config['model']['model_type'] 177 | model_config = { 178 | 'vocab_size': self.config['model']['vocab_size'], 179 | 'embedding_dim': self.config['model']['embedding_dim'], 180 | 'num_classes': self.config['model']['num_classes'], 181 | 'max_length': self.config['data']['max_length'] 182 | } 183 | 184 | # 添加模型特定配置 185 | if model_type == 'textcnn': 186 | model_config.update(self.config['model']['textcnn']) 187 | elif model_type == 'lstm': 188 | model_config.update(self.config['model']['lstm']) 189 | elif model_type == 'bert': 190 | model_config.update(self.config['model']['bert']) 191 | 192 | return model_config 193 | 194 | def get_training_config(self) -> Dict[str, Any]: 195 | """获取训练配置 196 | 197 | Returns: 198 | 训练配置字典 199 | """ 200 | return self.config['training'] 201 | 202 | def get_data_config(self) -> Dict[str, Any]: 203 | """获取数据配置 204 | 205 | Returns: 206 | 数据配置字典 207 | """ 208 | return self.config['data'] 209 | 210 | def get_augmentation_config(self) -> Dict[str, Any]: 211 | """获取数据增强配置 212 | 213 | Returns: 214 | 数据增强配置字典 215 | """ 216 | return self.config['augmentation'] 217 | 218 | def get_evaluation_config(self) -> Dict[str, Any]: 219 | """获取评估配置 220 | 221 | Returns: 222 | 
评估配置字典 223 | """ 224 | return self.config['evaluation'] -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import jieba 3 | import synonyms 4 | from typing import List, Tuple 5 | import numpy as np 6 | from googletrans import Translator 7 | from tqdm import tqdm 8 | 9 | class TextAugmenter: 10 | def __init__(self, prob: float = 0.3): 11 | """ 12 | 文本增强器 13 | :param prob: 每个词被替换的概率 14 | """ 15 | self.prob = prob 16 | self.translator = Translator() 17 | 18 | def synonym_replacement(self, text: str, n: int = 1) -> str: 19 | """ 20 | 同义词替换 21 | :param text: 输入文本 22 | :param n: 替换词数量 23 | :return: 增强后的文本 24 | """ 25 | words = list(jieba.cut(text)) 26 | n = min(n, len(words)) 27 | new_words = words.copy() 28 | random_word_list = list(set([word for word in words if len(word) > 1])) 29 | random.shuffle(random_word_list) 30 | num_replaced = 0 31 | 32 | for random_word in random_word_list: 33 | synonyms_list = synonyms.nearby(random_word)[0] 34 | if len(synonyms_list) > 1: 35 | synonym = random.choice(synonyms_list[1:]) 36 | for idx, word in enumerate(new_words): 37 | if word == random_word and random.random() < self.prob: 38 | new_words[idx] = synonym 39 | num_replaced += 1 40 | break 41 | if num_replaced >= n: 42 | break 43 | 44 | return ''.join(new_words) 45 | 46 | def back_translation(self, text: str, target_lang: str = 'en') -> str: 47 | """ 48 | 回译增强 49 | :param text: 输入文本 50 | :param target_lang: 目标语言 51 | :return: 增强后的文本 52 | """ 53 | try: 54 | # 翻译成目标语言 55 | translated = self.translator.translate(text, dest=target_lang) 56 | # 翻译回中文 57 | back_translated = self.translator.translate(translated.text, dest='zh-cn') 58 | return back_translated.text 59 | except Exception as e: 60 | print(f"回译失败: {e}") 61 | return text 62 | 63 | def augment_batch(self, texts: List[str], labels: List[int], 64 | methods: List[str] = ['synonym', 'back_translation'], 65 | n_augment: int = 1) -> Tuple[List[str], List[int]]: 66 | """ 67 | 批量数据增强 68 | :param texts: 文本列表 69 | :param labels: 标签列表 70 | :param methods: 增强方法列表 71 | :param n_augment: 每个样本增强次数 72 | :return: 增强后的文本和标签 73 | """ 74 | augmented_texts = [] 75 | augmented_labels = [] 76 | 77 | for text, label in tqdm(zip(texts, labels), total=len(texts), desc="数据增强"): 78 | augmented_texts.append(text) 79 | augmented_labels.append(label) 80 | 81 | for _ in range(n_augment): 82 | method = random.choice(methods) 83 | if method == 'synonym': 84 | aug_text = self.synonym_replacement(text) 85 | elif method == 'back_translation': 86 | aug_text = self.back_translation(text) 87 | else: 88 | continue 89 | 90 | augmented_texts.append(aug_text) 91 | augmented_labels.append(label) 92 | 93 | return augmented_texts, augmented_labels 94 | 95 | def get_augmented_dataset(texts: List[str], labels: List[int], 96 | augment_ratio: float = 0.3) -> Tuple[List[str], List[int]]: 97 | """ 98 | 获取增强后的数据集 99 | :param texts: 原始文本列表 100 | :param labels: 原始标签列表 101 | :param augment_ratio: 增强比例 102 | :return: 增强后的文本和标签 103 | """ 104 | augmenter = TextAugmenter() 105 | n_augment = int(len(texts) * augment_ratio) 106 | 107 | # 
随机选择样本进行增强 108 | indices = np.random.choice(len(texts), n_augment, replace=False) 109 | texts_to_augment = [texts[i] for i in indices] 110 | labels_to_augment = [labels[i] for i in indices] 111 | 112 | # 进行数据增强 113 | aug_texts, aug_labels = augmenter.augment_batch( 114 | texts_to_augment, 115 | labels_to_augment, 116 | methods=['synonym', 'back_translation'], 117 | n_augment=1 118 | ) 119 | 120 | # 合并原始数据和增强数据 121 | all_texts = texts + aug_texts 122 | all_labels = labels + aug_labels 123 | 124 | return all_texts, all_labels -------------------------------------------------------------------------------- /src/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, List, Dict, Any 3 | import numpy as np 4 | 5 | class BaseDataset(ABC): 6 | """数据集处理的基础类""" 7 | 8 | def __init__(self, data_dir: str, max_length: int = 512): 9 | self.data_dir = data_dir 10 | self.max_length = max_length 11 | self.label_to_id: Dict[str, int] = {} 12 | self.id_to_label: Dict[int, str] = {} 13 | self.vocab: Dict[str, int] = {} 14 | self.vocab_size: int = 0 15 | 16 | @abstractmethod 17 | def load_data(self, filename: str) -> Tuple[List[str], List[int]]: 18 | """加载数据文件 19 | 20 | Args: 21 | filename: 数据文件名 22 | 23 | Returns: 24 | texts: 文本列表 25 | labels: 标签列表 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def preprocess_text(self, text: str) -> str: 31 | """文本预处理 32 | 33 | Args: 34 | text: 原始文本 35 | 36 | Returns: 37 | 处理后的文本 38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | def encode_texts(self, texts: List[str]) -> np.ndarray: 43 | """将文本转换为模型输入格式 44 | 45 | Args: 46 | texts: 文本列表 47 | 48 | Returns: 49 | 编码后的文本数组 50 | """ 51 | pass 52 | 53 | def get_label_mapping(self) -> Tuple[Dict[str, int], Dict[int, str]]: 54 | """获取标签映射 55 | 56 | Returns: 57 | label_to_id: 标签到ID的映射 58 | id_to_label: ID到标签的映射 59 | """ 60 | return self.label_to_id, self.id_to_label 61 | 62 | def get_vocab_info(self) -> Tuple[Dict[str, int], int]: 63 | """获取词表信息 64 | 65 | Returns: 66 | vocab: 词表字典 67 | vocab_size: 词表大小 68 | """ 69 | return self.vocab, self.vocab_size 70 | 71 | def save_vocab(self, filepath: str): 72 | """保存词表 73 | 74 | Args: 75 | filepath: 保存路径 76 | """ 77 | import json 78 | with open(filepath, 'w', encoding='utf-8') as f: 79 | json.dump(self.vocab, f, ensure_ascii=False, indent=2) 80 | 81 | def load_vocab(self, filepath: str): 82 | """加载词表 83 | 84 | Args: 85 | filepath: 词表文件路径 86 | """ 87 | import json 88 | with open(filepath, 'r', encoding='utf-8') as f: 89 | self.vocab = json.load(f) 90 | self.vocab_size = len(self.vocab) 91 | 92 | def save_label_mapping(self, filepath: str): 93 | """保存标签映射 94 | 95 | Args: 96 | filepath: 保存路径 97 | """ 98 | import json 99 | mapping = { 100 | 'label_to_id': self.label_to_id, 101 | 'id_to_label': {str(k): v for k, v in self.id_to_label.items()} 102 | } 103 | with open(filepath, 'w', encoding='utf-8') as f: 104 | json.dump(mapping, f, ensure_ascii=False, indent=2) 105 | 106 | def load_label_mapping(self, filepath: str): 107 | """加载标签映射 108 | 109 | Args: 110 | filepath: 映射文件路径 111 | """ 112 | import json 113 | with open(filepath, 'r', encoding='utf-8') as f: 114 | mapping = json.load(f) 115 | self.label_to_id = mapping['label_to_id'] 116 | self.id_to_label = {int(k): v for k, v in mapping['id_to_label'].items()} -------------------------------------------------------------------------------- /src/data/cnews_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import jieba 4 | import numpy as np 5 | from typing import Tuple, List, Dict 6 | from .base import BaseDataset 7 | 8 | class CNewsDataset(BaseDataset): 9 | """中文新闻数据集处理类""" 10 | 11 | def __init__(self, data_dir: str, max_length: int = 512, min_freq: int = 2): 12 | super().__init__(data_dir, max_length) 13 | self.min_freq = min_freq 14 | self._build_vocab() 15 | 16 | def _build_vocab(self): 17 | """构建词表""" 18 | # 加载训练数据 19 | train_texts, _ = self.load_data('cnews.train.txt') 20 | 21 | # 统计词频 22 | word_freq = {} 23 | for text in train_texts: 24 | words = jieba.lcut(text) 25 | for word in words: 26 | word_freq[word] = word_freq.get(word, 0) + 1 27 | 28 | # 过滤低频词 29 | word_freq = {k: v for k, v in word_freq.items() if v >= self.min_freq} 30 | 31 | # 构建词表 32 | self.vocab = { 33 | '': 0, 34 | '': 1, 35 | **{word: idx + 2 for idx, word in enumerate(word_freq.keys())} 36 | } 37 | self.vocab_size = len(self.vocab) 38 | 39 | def load_data(self, filename: str) -> Tuple[List[str], List[int]]: 40 | """加载数据文件 41 | 42 | Args: 43 | filename: 数据文件名 44 | 45 | Returns: 46 | texts: 文本列表 47 | labels: 标签列表 48 | """ 49 | texts, labels = [], [] 50 | filepath = os.path.join(self.data_dir, filename) 51 | 52 | with open(filepath, 'r', encoding='utf-8') as f: 53 | for line in f: 54 | label, text = line.strip().split('\t') 55 | 56 | # 更新标签映射 57 | if label not in self.label_to_id: 58 | label_id = len(self.label_to_id) 59 | self.label_to_id[label] = label_id 60 | self.id_to_label[label_id] = label 61 | 62 | texts.append(text) 63 | labels.append(self.label_to_id[label]) 64 | 65 | return texts, np.array(labels) 66 | 67 | def preprocess_text(self, text: str) -> str: 68 | """文本预处理 69 | 70 | Args: 71 | text: 原始文本 72 | 73 | Returns: 74 | 处理后的文本 75 | """ 76 | # 去除特殊字符和数字 77 | text = re.sub(r'[^\u4e00-\u9fa5]', '', text) 78 | # 分词 79 | words = jieba.lcut(text) 80 | # 截断或填充 81 | if len(words) > self.max_length: 82 | words = words[:self.max_length] 83 | else: 84 | words.extend([''] * (self.max_length - len(words))) 85 | return ' '.join(words) 86 | 87 | def encode_texts(self, texts: List[str]) -> np.ndarray: 88 | """将文本转换为模型输入格式 89 | 90 | Args: 91 | texts: 文本列表 92 | 93 | Returns: 94 | 编码后的文本数组 95 | """ 96 | encoded_texts = [] 97 | for text in texts: 98 | # 预处理文本 99 | processed_text = self.preprocess_text(text) 100 | # 转换为词ID 101 | word_ids = [ 102 | self.vocab.get(word, self.vocab['']) 103 | for word in processed_text.split() 104 | ] 105 | encoded_texts.append(word_ids) 106 | return np.array(encoded_texts) 107 | 108 | def get_class_weights(self) -> Dict[int, float]: 109 | """计算类别权重,用于处理类别不平衡 110 | 111 | Returns: 112 | 类别权重字典 113 | """ 114 | _, labels = self.load_data('cnews.train.txt') 115 | class_counts = np.bincount(labels) 116 | total_samples = len(labels) 117 | class_weights = { 118 | i: total_samples / (len(class_counts) * count) 119 | for i, count in enumerate(class_counts) 120 | } 121 | return class_weights -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.preprocessing.sequence import pad_sequences 3 | from tensorflow.keras.utils import to_categorical 4 | from .preprocess import read_txt, get_vocab, get_category_id, tokenize 5 | from typing import Tuple, List 6 | 7 | class TextDataset: 8 | def __init__(self, txt_path: str, 
vocab_path: str, max_length: int, categories: List[str] = None): 9 | self.txt_path = txt_path 10 | self.vocab, self.vocab_dict = get_vocab(vocab_path) 11 | self.max_length = max_length 12 | self.categories = categories or ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"] 13 | self.cate_dict = get_category_id(self.categories) 14 | self.num_classes = len(self.categories) 15 | 16 | def encode_samples(self) -> Tuple[np.ndarray, np.ndarray]: 17 | labels, contents = read_txt(self.txt_path) 18 | labels_idx = [self.cate_dict[label] for label in labels] 19 | contents_idx = [] 20 | for content in contents: 21 | # 可选:分词(如数据已分好可跳过) 22 | # tokens = tokenize(content) 23 | tokens = list(content) 24 | idxs = [self.vocab_dict.get(word, 5000) for word in tokens] 25 | contents_idx.append(idxs) 26 | x_pad = pad_sequences(contents_idx, self.max_length) 27 | y_pad = to_categorical(labels_idx, num_classes=self.num_classes) 28 | return x_pad, y_pad 29 | 30 | def encode_single(self, sentence: str) -> np.ndarray: 31 | # tokens = tokenize(sentence) 32 | tokens = list(sentence) 33 | idxs = [self.vocab_dict.get(word, 5000) for word in tokens] 34 | x_pad = pad_sequences([idxs], self.max_length) 35 | return x_pad 36 | -------------------------------------------------------------------------------- /src/data/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 文本预处理工具函数 4 | """ 5 | import jieba 6 | from pathlib import Path 7 | from typing import List, Tuple, Dict 8 | 9 | def read_txt(txt_path: str) -> Tuple[List[str], List[str]]: 10 | """ 11 | 读取文档数据 12 | :param txt_path: 文档路径 13 | :return: 标签数组与文本数组 14 | """ 15 | with open(txt_path, "r", encoding='utf-8') as f: 16 | data = f.readlines() 17 | labels = [] 18 | contents = [] 19 | for line in data: 20 | label, content = line.strip().split('\t') 21 | labels.append(label) 22 | contents.append(content) 23 | return labels, contents 24 | 25 | def get_vocab(vocab_path: str) -> Tuple[List[str], Dict[str, int]]: 26 | """ 27 | 读取词汇表 28 | :param vocab_path: 词汇表路径 29 | :return: 词汇表list和词到索引的dict 30 | """ 31 | with open(vocab_path, "r", encoding="utf-8") as f: 32 | infile = f.readlines() 33 | vocabs = [word.strip() for word in infile] 34 | vocabs_dict = {word: idx for idx, word in enumerate(vocabs)} 35 | return vocabs, vocabs_dict 36 | 37 | def get_category_id(categories: List[str] = None) -> Dict[str, int]: 38 | """ 39 | 返回分类种类的索引 40 | :param categories: 类别列表 41 | :return: 分类到索引的dict 42 | """ 43 | if categories is None: 44 | categories = ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"] 45 | return {cat: idx for idx, cat in enumerate(categories)} 46 | 47 | def tokenize(text: str) -> List[str]: 48 | """ 49 | 使用jieba分词 50 | :param text: 输入文本 51 | :return: 分词结果 52 | """ 53 | return list(jieba.cut(text)) 54 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from config import TEXTCNN_CONFIG, LSTM_CONFIG, DATASET_CONFIG, MODEL_SAVE_DIR 3 | from data.dataset import TextDataset 4 | from tensorflow.keras.models import load_model 5 | from sklearn.metrics import classification_report, accuracy_score 6 | import numpy as np 7 | import os 8 | 9 | MODEL_CONFIG_MAP = { 10 | 'textcnn': TEXTCNN_CONFIG, 11 | 'lstm': LSTM_CONFIG, 12 | } 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser(description="中文文本分类评估脚本") 16 | 
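    # 命令行用法示意(假设模型已训练并保存为 MODEL_SAVE_DIR 下的 textcnn_model.h5 / lstm_model.h5):
    #     python src/evaluate.py --model textcnn
    #     python src/evaluate.py --model lstm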
parser.add_argument('--model', type=str, default='textcnn', choices=['textcnn', 'lstm'], help='选择模型类型') 17 | args = parser.parse_args() 18 | 19 | model_config = MODEL_CONFIG_MAP[args.model] 20 | print(f"评估模型: {args.model}") 21 | 22 | # 加载测试集 23 | test_dataset = TextDataset( 24 | txt_path=str(DATASET_CONFIG['test_file']), 25 | vocab_path=str(DATASET_CONFIG['vocab_file']), 26 | max_length=model_config['max_sequence_length'], 27 | categories=DATASET_CONFIG['categories'] 28 | ) 29 | x_test, y_test = test_dataset.encode_samples() 30 | 31 | # 加载模型 32 | model_path = os.path.join(MODEL_SAVE_DIR, f"{args.model}_model.h5") 33 | if not os.path.exists(model_path): 34 | print(f"模型文件不存在: {model_path}") 35 | return 36 | model = load_model(model_path) 37 | print(f"模型已加载: {model_path}") 38 | 39 | # 预测 40 | y_pred = model.predict(x_test) 41 | y_true = np.argmax(y_test, axis=1) 42 | y_pred_label = np.argmax(y_pred, axis=1) 43 | 44 | # 输出评估指标 45 | acc = accuracy_score(y_true, y_pred_label) 46 | print(f"准确率: {acc:.4f}") 47 | print("分类报告:") 48 | print(classification_report(y_true, y_pred_label, target_names=DATASET_CONFIG['categories'])) 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /src/evaluate/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any, List, Tuple, Optional 3 | import numpy as np 4 | import tensorflow as tf 5 | from sklearn.metrics import classification_report, confusion_matrix 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | from ..models.base import BaseModel 9 | from ..data.base import BaseDataset 10 | from ..utils.logger import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | class BaseEvaluator(ABC): 15 | """评估器基类""" 16 | 17 | def __init__( 18 | self, 19 | model: BaseModel, 20 | dataset: BaseDataset, 21 | config: Dict[str, Any], 22 | output_dir: str = "evaluation" 23 | ): 24 | """初始化评估器 25 | 26 | Args: 27 | model: 模型实例 28 | dataset: 数据集实例 29 | config: 评估配置 30 | output_dir: 评估结果输出目录 31 | """ 32 | self.model = model 33 | self.dataset = dataset 34 | self.config = config 35 | self.output_dir = output_dir 36 | 37 | # 获取标签映射 38 | self.label_to_id, self.id_to_label = dataset.get_label_mapping() 39 | self.num_classes = len(self.label_to_id) 40 | 41 | @abstractmethod 42 | def evaluate( 43 | self, 44 | test_data: Tuple[np.ndarray, np.ndarray], 45 | **kwargs 46 | ) -> Dict[str, float]: 47 | """评估模型性能 48 | 49 | Args: 50 | test_data: 测试数据 (X_test, y_test) 51 | **kwargs: 其他评估参数 52 | 53 | Returns: 54 | 评估指标 55 | """ 56 | pass 57 | 58 | def predict( 59 | self, 60 | texts: List[str], 61 | batch_size: int = 32 62 | ) -> Tuple[np.ndarray, np.ndarray]: 63 | """预测文本类别 64 | 65 | Args: 66 | texts: 文本列表 67 | batch_size: 批处理大小 68 | 69 | Returns: 70 | predictions: 预测的类别ID 71 | probabilities: 预测的概率分布 72 | """ 73 | # 文本预处理和编码 74 | encoded_texts = self.dataset.encode_texts(texts) 75 | 76 | # 批量预测 77 | dataset = tf.data.Dataset.from_tensor_slices(encoded_texts) 78 | dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE) 79 | 80 | # 获取预测结果 81 | probabilities = self.model.predict(dataset) 82 | predictions = np.argmax(probabilities, axis=1) 83 | 84 | return predictions, probabilities 85 | 86 | def get_classification_report( 87 | self, 88 | y_true: np.ndarray, 89 | y_pred: np.ndarray, 90 | output_file: Optional[str] = None 91 | ) -> str: 92 | """生成分类报告 93 | 94 | Args: 95 | y_true: 真实标签 96 | y_pred: 
预测标签 97 | output_file: 输出文件路径 98 | 99 | Returns: 100 | 分类报告文本 101 | """ 102 | # 将标签ID转换为标签名称 103 | y_true_names = [self.id_to_label[y] for y in y_true] 104 | y_pred_names = [self.id_to_label[y] for y in y_pred] 105 | 106 | # 生成分类报告 107 | report = classification_report( 108 | y_true_names, 109 | y_pred_names, 110 | target_names=list(self.label_to_id.keys()), 111 | digits=4 112 | ) 113 | 114 | # 保存报告 115 | if output_file: 116 | with open(output_file, 'w', encoding='utf-8') as f: 117 | f.write(report) 118 | logger.info(f"分类报告已保存到: {output_file}") 119 | 120 | return report 121 | 122 | def plot_confusion_matrix( 123 | self, 124 | y_true: np.ndarray, 125 | y_pred: np.ndarray, 126 | output_file: Optional[str] = None, 127 | figsize: Tuple[int, int] = (10, 8) 128 | ): 129 | """绘制混淆矩阵 130 | 131 | Args: 132 | y_true: 真实标签 133 | y_pred: 预测标签 134 | output_file: 输出文件路径 135 | figsize: 图像大小 136 | """ 137 | # 计算混淆矩阵 138 | cm = confusion_matrix(y_true, y_pred) 139 | 140 | # 绘制混淆矩阵 141 | plt.figure(figsize=figsize) 142 | sns.heatmap( 143 | cm, 144 | annot=True, 145 | fmt='d', 146 | cmap='Blues', 147 | xticklabels=list(self.label_to_id.keys()), 148 | yticklabels=list(self.label_to_id.keys()) 149 | ) 150 | plt.title('混淆矩阵') 151 | plt.xlabel('预测标签') 152 | plt.ylabel('真实标签') 153 | plt.xticks(rotation=45, ha='right') 154 | plt.yticks(rotation=0) 155 | plt.tight_layout() 156 | 157 | # 保存图像 158 | if output_file: 159 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 160 | logger.info(f"混淆矩阵已保存到: {output_file}") 161 | 162 | plt.close() 163 | 164 | def plot_learning_curves( 165 | self, 166 | history: Dict[str, List[float]], 167 | output_file: Optional[str] = None, 168 | figsize: Tuple[int, int] = (10, 6) 169 | ): 170 | """绘制学习曲线 171 | 172 | Args: 173 | history: 训练历史记录 174 | output_file: 输出文件路径 175 | figsize: 图像大小 176 | """ 177 | plt.figure(figsize=figsize) 178 | 179 | # 绘制训练和验证损失 180 | plt.subplot(1, 2, 1) 181 | plt.plot(history['loss'], label='训练损失') 182 | plt.plot(history['val_loss'], label='验证损失') 183 | plt.title('模型损失') 184 | plt.xlabel('轮次') 185 | plt.ylabel('损失') 186 | plt.legend() 187 | 188 | # 绘制训练和验证准确率 189 | plt.subplot(1, 2, 2) 190 | plt.plot(history['accuracy'], label='训练准确率') 191 | plt.plot(history['val_accuracy'], label='验证准确率') 192 | plt.title('模型准确率') 193 | plt.xlabel('轮次') 194 | plt.ylabel('准确率') 195 | plt.legend() 196 | 197 | plt.tight_layout() 198 | 199 | # 保存图像 200 | if output_file: 201 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 202 | logger.info(f"学习曲线已保存到: {output_file}") 203 | 204 | plt.close() 205 | 206 | def analyze_errors( 207 | self, 208 | texts: List[str], 209 | y_true: np.ndarray, 210 | y_pred: np.ndarray, 211 | top_k: int = 5, 212 | output_file: Optional[str] = None 213 | ): 214 | """分析预测错误的样本 215 | 216 | Args: 217 | texts: 文本列表 218 | y_true: 真实标签 219 | y_pred: 预测标签 220 | top_k: 每个类别展示的错误样本数 221 | output_file: 输出文件路径 222 | """ 223 | # 找出预测错误的样本 224 | errors = [] 225 | for text, true_label, pred_label in zip(texts, y_true, y_pred): 226 | if true_label != pred_label: 227 | errors.append({ 228 | 'text': text, 229 | 'true_label': self.id_to_label[true_label], 230 | 'pred_label': self.id_to_label[pred_label] 231 | }) 232 | 233 | # 按真实标签分组 234 | error_groups = {} 235 | for error in errors: 236 | true_label = error['true_label'] 237 | if true_label not in error_groups: 238 | error_groups[true_label] = [] 239 | error_groups[true_label].append(error) 240 | 241 | # 生成错误分析报告 242 | report = [] 243 | report.append("预测错误分析报告") 244 | report.append("=" * 50) 245 | 246 | for 
true_label, errors in error_groups.items(): 247 | report.append(f"\n真实标签: {true_label}") 248 | report.append("-" * 30) 249 | 250 | # 统计预测错误的分布 251 | pred_dist = {} 252 | for error in errors: 253 | pred_label = error['pred_label'] 254 | pred_dist[pred_label] = pred_dist.get(pred_label, 0) + 1 255 | 256 | # 输出预测分布 257 | report.append("预测分布:") 258 | for pred_label, count in sorted(pred_dist.items(), key=lambda x: x[1], reverse=True): 259 | report.append(f" - 预测为 {pred_label}: {count} 个样本") 260 | 261 | # 输出错误样本 262 | report.append("\n错误样本示例:") 263 | for error in errors[:top_k]: 264 | report.append(f" - 文本: {error['text'][:100]}...") 265 | report.append(f" 预测为: {error['pred_label']}") 266 | report.append() 267 | 268 | # 保存报告 269 | if output_file: 270 | with open(output_file, 'w', encoding='utf-8') as f: 271 | f.write('\n'.join(report)) 272 | logger.info(f"错误分析报告已保存到: {output_file}") 273 | 274 | return '\n'.join(report) -------------------------------------------------------------------------------- /src/evaluate/text_classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Tuple, Optional 2 | import numpy as np 3 | import tensorflow as tf 4 | from sklearn.metrics import roc_curve, auc 5 | import matplotlib.pyplot as plt 6 | from .base import BaseEvaluator 7 | from ..utils.logger import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | class TextClassifierEvaluator(BaseEvaluator): 12 | """文本分类评估器""" 13 | 14 | def evaluate( 15 | self, 16 | test_data: Tuple[np.ndarray, np.ndarray], 17 | **kwargs 18 | ) -> Dict[str, float]: 19 | """评估模型性能 20 | 21 | Args: 22 | test_data: 测试数据 (X_test, y_test) 23 | **kwargs: 其他评估参数 24 | 25 | Returns: 26 | 评估指标 27 | """ 28 | X_test, y_test = test_data 29 | 30 | # 准备测试数据 31 | test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)) 32 | test_dataset = test_dataset.batch( 33 | self.config.get('batch_size', 32) 34 | ).prefetch( 35 | tf.data.AUTOTUNE 36 | ) 37 | 38 | # 评估模型 39 | logger.info("开始评估模型...") 40 | metrics = self.model.evaluate(test_dataset, return_dict=True) 41 | 42 | # 获取预测结果 43 | y_pred, y_prob = self.predict(X_test) 44 | 45 | # 生成评估报告 46 | report = self.get_classification_report( 47 | y_test, 48 | y_pred, 49 | output_file=f"{self.output_dir}/classification_report.txt" 50 | ) 51 | logger.info("\n分类报告:\n" + report) 52 | 53 | # 绘制混淆矩阵 54 | self.plot_confusion_matrix( 55 | y_test, 56 | y_pred, 57 | output_file=f"{self.output_dir}/confusion_matrix.png" 58 | ) 59 | 60 | # 绘制ROC曲线 61 | self.plot_roc_curves( 62 | y_test, 63 | y_prob, 64 | output_file=f"{self.output_dir}/roc_curves.png" 65 | ) 66 | 67 | # 分析错误样本 68 | if 'texts' in kwargs: 69 | self.analyze_errors( 70 | kwargs['texts'], 71 | y_test, 72 | y_pred, 73 | output_file=f"{self.output_dir}/error_analysis.txt" 74 | ) 75 | 76 | return metrics 77 | 78 | def plot_roc_curves( 79 | self, 80 | y_true: np.ndarray, 81 | y_prob: np.ndarray, 82 | output_file: Optional[str] = None, 83 | figsize: Tuple[int, int] = (10, 8) 84 | ): 85 | """绘制ROC曲线 86 | 87 | Args: 88 | y_true: 真实标签(one-hot编码) 89 | y_prob: 预测概率 90 | output_file: 输出文件路径 91 | figsize: 图像大小 92 | """ 93 | plt.figure(figsize=figsize) 94 | 95 | # 计算每个类别的ROC曲线 96 | fpr = dict() 97 | tpr = dict() 98 | roc_auc = dict() 99 | 100 | for i in range(self.num_classes): 101 | # 将真实标签转换为二分类形式 102 | y_true_binary = (y_true == i).astype(int) 103 | 104 | # 计算ROC曲线 105 | fpr[i], tpr[i], _ = roc_curve(y_true_binary, y_prob[:, i]) 106 | roc_auc[i] = auc(fpr[i], tpr[i]) 107 | 108 | # 
绘制ROC曲线 109 | plt.plot( 110 | fpr[i], 111 | tpr[i], 112 | label=f'{self.id_to_label[i]} (AUC = {roc_auc[i]:.3f})' 113 | ) 114 | 115 | # 绘制对角线 116 | plt.plot([0, 1], [0, 1], 'k--') 117 | 118 | # 设置图表属性 119 | plt.xlim([0.0, 1.0]) 120 | plt.ylim([0.0, 1.05]) 121 | plt.xlabel('假正例率 (False Positive Rate)') 122 | plt.ylabel('真正例率 (True Positive Rate)') 123 | plt.title('各类别的ROC曲线') 124 | plt.legend(loc="lower right") 125 | plt.grid(True) 126 | 127 | # 保存图像 128 | if output_file: 129 | plt.savefig(output_file, dpi=300, bbox_inches='tight') 130 | logger.info(f"ROC曲线已保存到: {output_file}") 131 | 132 | plt.close() 133 | 134 | def analyze_feature_importance( 135 | self, 136 | texts: List[str], 137 | top_k: int = 10, 138 | output_file: Optional[str] = None 139 | ): 140 | """分析特征重要性(仅适用于可解释的模型) 141 | 142 | Args: 143 | texts: 文本列表 144 | top_k: 每个类别展示的重要特征数 145 | output_file: 输出文件路径 146 | """ 147 | # 获取模型的特征重要性(需要模型支持) 148 | if not hasattr(self.model, 'get_feature_importance'): 149 | logger.warning("当前模型不支持特征重要性分析") 150 | return 151 | 152 | # 获取预测结果 153 | predictions, _ = self.predict(texts) 154 | 155 | # 分析每个类别的特征重要性 156 | report = [] 157 | report.append("特征重要性分析报告") 158 | report.append("=" * 50) 159 | 160 | for class_id in range(self.num_classes): 161 | # 获取该类别的样本 162 | class_texts = [text for text, pred in zip(texts, predictions) if pred == class_id] 163 | if not class_texts: 164 | continue 165 | 166 | # 获取特征重要性 167 | importance = self.model.get_feature_importance(class_texts) 168 | 169 | # 生成报告 170 | report.append(f"\n类别: {self.id_to_label[class_id]}") 171 | report.append("-" * 30) 172 | report.append("重要特征:") 173 | 174 | # 输出top-k重要特征 175 | for feature, score in sorted(importance.items(), key=lambda x: x[1], reverse=True)[:top_k]: 176 | report.append(f" - {feature}: {score:.4f}") 177 | 178 | # 保存报告 179 | if output_file: 180 | with open(output_file, 'w', encoding='utf-8') as f: 181 | f.write('\n'.join(report)) 182 | logger.info(f"特征重要性分析报告已保存到: {output_file}") 183 | 184 | return '\n'.join(report) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Created on 2020-07-19 01:32 5 | @Author : Justin Jiang 6 | @Email : jw_jiang@pku.edu.com 7 | """ 8 | 9 | import tensorflow.keras as keras 10 | import numpy as np 11 | from sklearn import metrics 12 | import os 13 | 14 | from preprocess import preprocesser 15 | from config import Config 16 | from model import TextCNN 17 | 18 | 19 | if __name__ == '__main__': 20 | CNN_model = TextCNN() 21 | CNN_model.train(3) 22 | CNN_model.test() -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Created on 2020-07-19 00:12 5 | @Author : Justin Jiang 6 | @Email : jw_jiang@pku.edu.com 7 | """ 8 | 9 | import tensorflow.keras as keras 10 | from config import Config 11 | from preprocess import preprocesser 12 | import os 13 | from sklearn import metrics 14 | import numpy as np 15 | 16 | 17 | class TextCNN(object): 18 | 19 | def __init__(self): 20 | self.config = Config() 21 | self.pre = preprocesser() 22 | 23 | def model(self): 24 | num_classes = self.config.get("CNN_training_rule", "num_classes") 25 | vocab_size = self.config.get("CNN_training_rule", "vocab_size") 26 | seq_length = self.config.get("CNN_training_rule", "seq_length") 27 | 28 | 
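        # 下方继续读取卷积与全连接相关配置,并构建如下网络:
        # Embedding(vocab_size+1, 256) -> Conv1D(conv1) -> Conv1D(conv2) -> GlobalMaxPool1D
        # -> Dense(hidden_dim) -> Dropout -> ReLU -> Dense(num_classes, softmax),
        # 以 categorical_crossentropy 损失与 adam 优化器编译。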
conv1_num_filters = self.config.get("CNN_training_rule", "conv1_num_filters") 29 | conv1_kernel_size = self.config.get("CNN_training_rule", "conv1_kernel_size") 30 | 31 | conv2_num_filters = self.config.get("CNN_training_rule", "conv2_num_filters") 32 | conv2_kernel_size = self.config.get("CNN_training_rule", "conv2_kernel_size") 33 | 34 | hidden_dim = self.config.get("CNN_training_rule", "hidden_dim") 35 | dropout_keep_prob = self.config.get("CNN_training_rule", "dropout_keep_prob") 36 | 37 | model_input = keras.layers.Input((seq_length,), dtype='float64') 38 | embedding_layer = keras.layers.Embedding(vocab_size+1, 256, input_length=seq_length) 39 | embedded = embedding_layer(model_input) 40 | 41 | # conv1形状[batch_size, seq_length, conv1_num_filters] 42 | conv_1 = keras.layers.Conv1D(conv1_num_filters, conv1_kernel_size, padding="SAME")(embedded) 43 | conv_2 = keras.layers.Conv1D(conv2_num_filters, conv2_kernel_size, padding="SAME")(conv_1) 44 | max_poolinged = keras.layers.GlobalMaxPool1D()(conv_2) 45 | 46 | full_connect = keras.layers.Dense(hidden_dim)(max_poolinged) 47 | droped = keras.layers.Dropout(dropout_keep_prob)(full_connect) 48 | relued = keras.layers.ReLU()(droped) 49 | model_output = keras.layers.Dense(num_classes, activation="softmax")(relued) 50 | model = keras.models.Model(inputs=model_input, outputs=model_output) 51 | model.compile(loss="categorical_crossentropy", 52 | optimizer="adam", 53 | metrics=["accuracy"]) 54 | print(model.summary()) 55 | return model 56 | 57 | def train(self, epochs): 58 | trainingSet_path = self.config.get("data_path", "trainingSet_path") 59 | valSet_path = self.config.get("data_path", "valSet_path") 60 | seq_length = self.config.get("CNN_training_rule", "seq_length") 61 | model_save_path = self.config.get("result", "CNN_model_path") 62 | batch_size = self.config.get("CNN_training_rule", "batch_size") 63 | 64 | x_train, y_train = self.pre.word2idx(trainingSet_path, max_length=seq_length) 65 | x_val, y_val = self.pre.word2idx(valSet_path, max_length=seq_length) 66 | 67 | model = self.model() 68 | for _ in range(epochs): 69 | model.fit(x_train, y_train, 70 | batch_size=batch_size, 71 | epochs=1, 72 | validation_data=(x_val, y_val)) 73 | model.save(model_save_path, overwrite=True) 74 | 75 | def test(self): 76 | model_save_path = self.config.get("result", "CNN_model_path") 77 | testingSet_path = self.config.get("data_path", "testingSet_path") 78 | seq_length = self.config.get("CNN_training_rule", "seq_length") 79 | 80 | 81 | if os.path.exists(model_save_path): 82 | model = keras.models.load_model(model_save_path) 83 | print("-----model loaded-----") 84 | model.summary() 85 | 86 | x_test, y_test = self.pre.word2idx(testingSet_path, max_length=seq_length) 87 | # print(x_test.shape) 88 | # print(type(x_test)) 89 | # print(y_test.shape) 90 | # print(type(y_test)) 91 | pre_test = model.predict(x_test) 92 | # print(pre_test.shape) 93 | # metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1), digits=4, output_dict=True) 94 | print(metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1))) 95 | 96 | 97 | class LSTM(object): 98 | 99 | def __init__(self): 100 | self.config = Config() 101 | self.pre = preprocesser() 102 | 103 | def model(self): 104 | seq_length = self.config.get("LSTM", "seq_length") 105 | num_classes = self.config.get("LSTM", "num_classes") 106 | vocab_size = self.config.get("LSTM", "vocab_size") 107 | 108 | 109 | model_input = keras.layers.Input((seq_length)) 110 | embedding = 
keras.layers.Embedding(vocab_size+1, 256, input_length=seq_length)(model_input) 111 | LSTM = keras.layers.LSTM(256)(embedding) 112 | FC1 = keras.layers.Dense(256, activation="relu")(LSTM) 113 | droped = keras.layers.Dropout(0.5)(FC1) 114 | FC2 = keras.layers.Dense(num_classes, activation="softmax")(droped) 115 | 116 | model = keras.models.Model(inputs=model_input, outputs=FC2) 117 | 118 | model.compile(loss="categorical_crossentropy", 119 | optimizer=keras.optimizers.RMSprop(), 120 | metrics=["accuracy"]) 121 | model.summary() 122 | return model 123 | 124 | def train(self, epochs): 125 | trainingSet_path = self.config.get("data_path", "trainingSet_path") 126 | valSet_path = self.config.get("data_path", "valSet_path") 127 | seq_length = self.config.get("LSTM", "seq_length") 128 | model_save_path = self.config.get("result", "LSTM_model_path") 129 | batch_size = self.config.get("LSTM", "batch_size") 130 | 131 | model = self.model() 132 | 133 | x_train, y_train = self.pre.word2idx(trainingSet_path, max_length=seq_length) 134 | x_val, y_val = self.pre.word2idx(valSet_path, max_length=seq_length) 135 | 136 | for _ in range(epochs): 137 | model.fit(x_train, y_train, 138 | batch_size=batch_size, 139 | validation_data=(x_val, y_val), 140 | epochs=1) 141 | model.save(model_save_path, overwrite=True) 142 | 143 | def test(self): 144 | model_save_path = self.config.get("result", "LSTM_model_path") 145 | testingSet_path = self.config.get("data_path", "testingSet_path") 146 | seq_length = self.config.get("LSTM", "seq_length") 147 | 148 | 149 | if os.path.exists(model_save_path): 150 | model = keras.models.load_model(model_save_path) 151 | print("-----model loaded-----") 152 | model.summary() 153 | 154 | x_test, y_test = self.pre.word2idx(testingSet_path, max_length=seq_length) 155 | pre_test = model.predict(x_test) 156 | 157 | # metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1), digits=4, output_dict=True) 158 | print(metrics.classification_report(np.argmax(pre_test, axis=1), np.argmax(y_test, axis=1))) 159 | 160 | 161 | 162 | if __name__ == '__main__': 163 | test = TextCNN() 164 | # test.train(3) 165 | test.test() 166 | 167 | # LSTMTest = LSTM() 168 | # LSTMTest.train(3) 169 | # LSTMTest.test() -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .textcnn import TextCNN 2 | from .lstm import LSTM 3 | from .bert import BERTClassifier 4 | 5 | __all__ = ['TextCNN', 'LSTM', 'BERTClassifier'] 6 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import tensorflow as tf 3 | 4 | class BaseModel(ABC): 5 | """文本分类模型的基础类""" 6 | 7 | def __init__(self, vocab_size: int, embedding_dim: int, num_classes: int, max_length: int): 8 | self.vocab_size = vocab_size 9 | self.embedding_dim = embedding_dim 10 | self.num_classes = num_classes 11 | self.max_length = max_length 12 | self.model = None 13 | 14 | @abstractmethod 15 | def build(self) -> tf.keras.Model: 16 | """构建模型架构""" 17 | pass 18 | 19 | def compile(self, optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']): 20 | """编译模型""" 21 | if self.model is None: 22 | self.model = self.build() 23 | self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics) 24 | 25 | def fit(self, *args, 
**kwargs): 26 | """训练模型""" 27 | if self.model is None: 28 | raise ValueError("Model must be compiled before training") 29 | return self.model.fit(*args, **kwargs) 30 | 31 | def evaluate(self, *args, **kwargs): 32 | """评估模型""" 33 | if self.model is None: 34 | raise ValueError("Model must be compiled before evaluation") 35 | return self.model.evaluate(*args, **kwargs) 36 | 37 | def predict(self, *args, **kwargs): 38 | """模型预测""" 39 | if self.model is None: 40 | raise ValueError("Model must be compiled before prediction") 41 | return self.model.predict(*args, **kwargs) 42 | 43 | def save(self, filepath: str): 44 | """保存模型""" 45 | if self.model is None: 46 | raise ValueError("No model to save") 47 | self.model.save(filepath) 48 | 49 | @classmethod 50 | def load(cls, filepath: str): 51 | """加载模型""" 52 | model = tf.keras.models.load_model(filepath) 53 | return model 54 | -------------------------------------------------------------------------------- /src/models/bert.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from transformers import TFBertModel, BertTokenizer 3 | from .base import BaseModel 4 | 5 | class BERTClassifier(BaseModel): 6 | """BERT模型实现""" 7 | 8 | def __init__(self, vocab_size: int, embedding_dim: int, num_classes: int, max_length: int, 9 | model_name: str = 'bert-base-chinese', dropout_rate: float = 0.1): 10 | super().__init__(vocab_size, embedding_dim, num_classes, max_length) 11 | self.model_name = model_name 12 | self.dropout_rate = dropout_rate 13 | self.tokenizer = BertTokenizer.from_pretrained(model_name) 14 | self.bert_model = TFBertModel.from_pretrained(model_name) 15 | 16 | def build(self) -> tf.keras.Model: 17 | """构建BERT模型架构""" 18 | # 输入层 19 | input_ids = tf.keras.layers.Input(shape=(self.max_length,), dtype=tf.int32, name='input_ids') 20 | attention_mask = tf.keras.layers.Input(shape=(self.max_length,), dtype=tf.int32, name='attention_mask') 21 | token_type_ids = tf.keras.layers.Input(shape=(self.max_length,), dtype=tf.int32, name='token_type_ids') 22 | 23 | # BERT层 24 | bert_output = self.bert_model( 25 | input_ids=input_ids, 26 | attention_mask=attention_mask, 27 | token_type_ids=token_type_ids 28 | ) 29 | 30 | # 使用[CLS]标记的输出 31 | pooled_output = bert_output[1] 32 | 33 | # Dropout层 34 | dropout = tf.keras.layers.Dropout(self.dropout_rate)(pooled_output) 35 | 36 | # 全连接层 37 | dense = tf.keras.layers.Dense(256, activation='relu')(dropout) 38 | outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(dense) 39 | 40 | return tf.keras.Model( 41 | inputs=[input_ids, attention_mask, token_type_ids], 42 | outputs=outputs 43 | ) 44 | 45 | def preprocess_text(self, texts): 46 | """预处理文本为BERT输入格式""" 47 | return self.tokenizer( 48 | texts, 49 | padding='max_length', 50 | truncation=True, 51 | max_length=self.max_length, 52 | return_tensors='tf' 53 | ) 54 | 55 | def predict(self, texts, *args, **kwargs): 56 | """重写预测方法以处理文本输入""" 57 | inputs = self.preprocess_text(texts) 58 | return super().predict(inputs, *args, **kwargs) 59 | 60 | def fit(self, texts, labels, *args, **kwargs): 61 | """重写训练方法以处理文本输入""" 62 | inputs = self.preprocess_text(texts) 63 | return super().fit(inputs, labels, *args, **kwargs) 64 | 65 | def evaluate(self, texts, labels, *args, **kwargs): 66 | """重写评估方法以处理文本输入""" 67 | inputs = self.preprocess_text(texts) 68 | return super().evaluate(inputs, labels, *args, **kwargs) 69 | 70 | def encode_text(text: str, tokenizer, max_length: int = 512): 71 | """ 72 | 使用BERT tokenizer对文本进行编码 
73 | :param text: 输入文本 74 | :param tokenizer: BERT tokenizer 75 | :param max_length: 最大序列长度 76 | :return: 编码后的input_ids, attention_mask, token_type_ids 77 | """ 78 | encoding = tokenizer( 79 | text, 80 | max_length=max_length, 81 | padding='max_length', 82 | truncation=True, 83 | return_tensors='tf' 84 | ) 85 | return ( 86 | encoding['input_ids'], 87 | encoding['attention_mask'], 88 | encoding['token_type_ids'] 89 | ) 90 | -------------------------------------------------------------------------------- /src/models/lstm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .base import BaseModel 3 | 4 | class LSTM(BaseModel): 5 | """LSTM模型实现""" 6 | 7 | def __init__(self, vocab_size: int, embedding_dim: int, num_classes: int, max_length: int, 8 | lstm_units: int = 128, dropout_rate: float = 0.5, bidirectional: bool = True): 9 | super().__init__(vocab_size, embedding_dim, num_classes, max_length) 10 | self.lstm_units = lstm_units 11 | self.dropout_rate = dropout_rate 12 | self.bidirectional = bidirectional 13 | 14 | def build(self) -> tf.keras.Model: 15 | """构建LSTM模型架构""" 16 | inputs = tf.keras.layers.Input(shape=(self.max_length,)) 17 | 18 | # 词嵌入层 19 | embedding = tf.keras.layers.Embedding( 20 | self.vocab_size, 21 | self.embedding_dim, 22 | input_length=self.max_length 23 | )(inputs) 24 | 25 | # LSTM层 26 | if self.bidirectional: 27 | lstm = tf.keras.layers.Bidirectional( 28 | tf.keras.layers.LSTM(self.lstm_units, return_sequences=True) 29 | )(embedding) 30 | lstm = tf.keras.layers.Bidirectional( 31 | tf.keras.layers.LSTM(self.lstm_units) 32 | )(lstm) 33 | else: 34 | lstm = tf.keras.layers.LSTM(self.lstm_units, return_sequences=True)(embedding) 35 | lstm = tf.keras.layers.LSTM(self.lstm_units)(lstm) 36 | 37 | # Dropout层 38 | dropout = tf.keras.layers.Dropout(self.dropout_rate)(lstm) 39 | 40 | # 全连接层 41 | dense = tf.keras.layers.Dense(128, activation='relu')(dropout) 42 | outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(dense) 43 | 44 | return tf.keras.Model(inputs=inputs, outputs=outputs) 45 | -------------------------------------------------------------------------------- /src/models/textcnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .base import BaseModel 3 | 4 | class TextCNN(BaseModel): 5 | """TextCNN模型实现""" 6 | 7 | def __init__(self, vocab_size: int, embedding_dim: int, num_classes: int, max_length: int, 8 | num_filters: int = 128, filter_sizes: list = [3, 4, 5], dropout_rate: float = 0.5): 9 | super().__init__(vocab_size, embedding_dim, num_classes, max_length) 10 | self.num_filters = num_filters 11 | self.filter_sizes = filter_sizes 12 | self.dropout_rate = dropout_rate 13 | 14 | def build(self) -> tf.keras.Model: 15 | """构建TextCNN模型架构""" 16 | inputs = tf.keras.layers.Input(shape=(self.max_length,)) 17 | 18 | # 词嵌入层 19 | embedding = tf.keras.layers.Embedding( 20 | self.vocab_size, 21 | self.embedding_dim, 22 | input_length=self.max_length 23 | )(inputs) 24 | 25 | # 卷积层 26 | conv_outputs = [] 27 | for filter_size in self.filter_sizes: 28 | conv = tf.keras.layers.Conv1D( 29 | filters=self.num_filters, 30 | kernel_size=filter_size, 31 | activation='relu', 32 | padding='same' 33 | )(embedding) 34 | pool = tf.keras.layers.GlobalMaxPooling1D()(conv) 35 | conv_outputs.append(pool) 36 | 37 | # 合并所有卷积输出 38 | concat = tf.keras.layers.Concatenate()(conv_outputs) 39 | 40 | # Dropout层 41 | dropout = 
tf.keras.layers.Dropout(self.dropout_rate)(concat) 42 | 43 | # 全连接层 44 | dense = tf.keras.layers.Dense(128, activation='relu')(dropout) 45 | outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(dense) 46 | 47 | return tf.keras.Model(inputs=inputs, outputs=outputs) 48 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Created on 2020-07-18 10:50 5 | @Author : Justin Jiang 6 | @Email : jw_jiang@pku.edu.com 7 | 8 | 数据加载 9 | """ 10 | 11 | import numpy as np 12 | import tensorflow.keras as keras 13 | from config import Config 14 | import jieba 15 | 16 | class preprocesser(object): 17 | 18 | def __init__(self): 19 | self.config = Config() 20 | 21 | def read_txt(self, txt_path): 22 | """ 23 | 读取文档数据 24 | :param txt_path:文档路径 25 | :return: 该文档中的标签数组与文本数组 26 | """ 27 | with open(txt_path, "r", encoding='utf-8') as f: 28 | data = f.readlines() 29 | labels = [] 30 | contents = [] 31 | for line in data: 32 | label, content = line.strip().split('\t') 33 | labels.append(label) 34 | contents.append(content) 35 | return labels, contents 36 | 37 | def get_vocab_id(self): 38 | """ 39 | 读取分词文档 40 | :return:分词数组与各分词索引的字典 41 | """ 42 | vocab_path = self.config.get("data_path", "vocab_path") 43 | with open(vocab_path, "r", encoding="utf-8") as f: 44 | infile = f.readlines() 45 | vocabs = list([word.replace("\n", "") for word in infile]) 46 | vocabs_dict = dict(zip(vocabs, range(len(vocabs)))) 47 | return vocabs, vocabs_dict 48 | 49 | def get_category_id(self): 50 | """ 51 | 返回分类种类的索引 52 | :return: 返回分类种类的字典 53 | """ 54 | categories = ["体育", "财经", "房产", "家居", "教育", "科技", "时尚", "时政", "游戏", "娱乐"] 55 | cates_dict = dict(zip(categories, range(len(categories)))) 56 | return cates_dict 57 | 58 | def word2idx(self, txt_path, max_length): 59 | """ 60 | 将语料中各文本转换成固定max_length后返回各文本的标签与文本tokens 61 | :param txt_path: 语料路径 62 | :param max_length: pad后的长度 63 | :return: 语料pad后表示与标签 64 | """ 65 | # vocabs:分词词汇表 66 | # vocabs_dict:各分词的索引 67 | vocabs, vocabs_dict = self.get_vocab_id() 68 | # cates_dict:各分类的索引 69 | cates_dict = self.get_category_id() 70 | 71 | # 读取语料 72 | labels, contents = self.read_txt(txt_path) 73 | # labels_idx:用来存放语料中的分类 74 | labels_idx = [] 75 | # contents_idx:用来存放语料中各样本的索引 76 | contents_idx = [] 77 | 78 | # 遍历语料 79 | for idx in range(len(contents)): 80 | # tmp:存放当前语句index 81 | tmp = [] 82 | # 将该idx(样本)的标签加入至labels_idx中 83 | labels_idx.append(cates_dict[labels[idx]]) 84 | # contents[idx]:为该语料中的样本遍历项 85 | # 遍历contents中各词并将其转换为索引后加入contents_idx中 86 | for word in contents[idx]: 87 | if word in vocabs: 88 | tmp.append(vocabs_dict[word]) 89 | else: 90 | # 第5000位设置为未知字符 91 | tmp.append(5000) 92 | # 将该样本index后结果存入contents_idx作为结果等待传回 93 | contents_idx.append(tmp) 94 | 95 | 96 | # 将各样本长度pad至max_length 97 | x_pad = keras.preprocessing.sequence.pad_sequences(contents_idx, max_length) 98 | y_pad = keras.utils.to_categorical(labels_idx, num_classes=len(cates_dict)) 99 | 100 | return x_pad, y_pad 101 | 102 | 103 | def word2idx_for_sample(self, sentence, max_length): 104 | # vocabs:分词词汇表 105 | # vocabs_dict:各分词的索引 106 | vocabs, vocabs_dict = self.get_vocab_id() 107 | result = [] 108 | # 遍历语料 109 | for word in sentence: 110 | # tmp:存放当前语句index 111 | if word in vocabs: 112 | result.append(vocabs_dict[word]) 113 | else: 114 | # 第5000位设置为未知字符,实际中为vocabs_dict[5000],使得vocabs_dict长度变成len(vocabs_dict+1) 115 | result.append(5000) 116 | 
117 | x_pad = keras.preprocessing.sequence.pad_sequences([result], max_length) 118 | return x_pad 119 | 120 | 121 | if __name__ == '__main__': 122 | test = preprocesser() 123 | # tmp_path = '../data/training_sample.txt' 124 | # 125 | # test.word2idx(tmp_path, 600) 126 | # x_pad, y_pad = test.word2idx(tmp_path, 600) 127 | # 128 | # print(len(x_pad[0])) 129 | # print(x_pad[0]) 130 | # print(len(x_pad[1])) 131 | # print(x_pad[1]) 132 | # 133 | # print(y_pad) 134 | print(test.word2idx_for_sample("马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的困扰。7月31日下午6点,国奥队的日常训练再度受到大雨的干扰,无奈之下队员们只慢跑了25分钟就草草收场。31日上午10点,国奥队在奥体中心外场训练的时候,天就是阴沉沉的,气象预报显示当天下午沈阳就有大雨,但幸好队伍上午的训练并没有受到任何干扰。下午6点,当球队抵达训练场时,大雨已经下了几个小时,而且丝毫没有停下来的意思。抱着试一试的态度,球队开始了当天下午的例行训练,25分钟过去了,天气没有任何转好的迹象,为了保护球员们,国奥队决定中止当天的训练,全队立即返回酒店。在雨中训练对足球队来说并不是什么稀罕事,但在奥运会即将开始之前,全队变得“娇贵”了。在沈阳最后一周的训练,国奥队首先要保证现有的球员不再出现意外的伤病情况以免影响正式比赛,因此这一阶段控制训练受伤、控制感冒等疾病的出现被队伍放在了相当重要的位置。而抵达沈阳之后,中后卫冯萧霆就一直没有训练,冯萧霆是7月27日在长春患上了感冒,因此也没有参加29日跟塞尔维亚的热身赛。队伍介绍说,冯萧霆并没有出现发烧症状,但为了安全起见,这两天还是让他静养休息,等感冒彻底好了之后再恢复训练。由于有了冯萧霆这个例子,因此国奥队对雨中训练就显得特别谨慎,主要是担心球员们受凉而引发感冒,造成非战斗减员。而女足队员马晓旭在热身赛中受伤导致无缘奥运的前科,也让在沈阳的国奥队现在格外警惕,“训练中不断嘱咐队员们要注意动作,我们可不能再出这样的事情了。”一位工作人员表示。从长春到沈阳,雨水一路伴随着国奥队,“也邪了,我们走到哪儿雨就下到哪儿,在长春几次训练都被大雨给搅和了,没想到来沈阳又碰到这种事情。”一位国奥球员也对雨水的“青睐”有些不解。", 600)) -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from config import TEXTCNN_CONFIG, LSTM_CONFIG, DATASET_CONFIG, MODEL_SAVE_DIR 3 | from data.dataset import TextDataset 4 | from models.textcnn import build_textcnn 5 | from models.lstm import build_lstm 6 | import os 7 | 8 | MODEL_MAP = { 9 | 'textcnn': (build_textcnn, TEXTCNN_CONFIG), 10 | 'lstm': (build_lstm, LSTM_CONFIG), 11 | } 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description="中文文本分类训练脚本") 15 | parser.add_argument('--model', type=str, default='textcnn', choices=['textcnn', 'lstm'], help='选择模型类型') 16 | parser.add_argument('--epochs', type=int, default=10, help='训练轮数') 17 | args = parser.parse_args() 18 | 19 | # 选择模型和配置 20 | build_fn, model_config = MODEL_MAP[args.model] 21 | print(f"使用模型: {args.model}") 22 | 23 | # 加载数据 24 | train_dataset = TextDataset( 25 | txt_path=str(DATASET_CONFIG['train_file']), 26 | vocab_path=str(DATASET_CONFIG['vocab_file']), 27 | max_length=model_config['max_sequence_length'], 28 | categories=DATASET_CONFIG['categories'] 29 | ) 30 | val_dataset = TextDataset( 31 | txt_path=str(DATASET_CONFIG['val_file']), 32 | vocab_path=str(DATASET_CONFIG['vocab_file']), 33 | max_length=model_config['max_sequence_length'], 34 | categories=DATASET_CONFIG['categories'] 35 | ) 36 | x_train, y_train = train_dataset.encode_samples() 37 | x_val, y_val = val_dataset.encode_samples() 38 | 39 | # 构建模型 40 | model = build_fn( 41 | vocab_size=DATASET_CONFIG['vocab_size'], 42 | seq_length=model_config['max_sequence_length'], 43 | num_classes=len(DATASET_CONFIG['categories']), 44 | embedding_dim=model_config['embedding_dim'], 45 | **{k: v for k, v in model_config.items() if k not in ['embedding_dim', 'max_sequence_length']} 46 | ) 47 | model.summary() 48 | 49 | # 训练 50 | model.fit( 51 | x_train, y_train, 52 | batch_size=model_config['batch_size'], 53 | epochs=args.epochs, 54 | validation_data=(x_val, y_val), 55 | verbose=1 56 | ) 57 | 58 | # 保存模型 59 | save_path = os.path.join(MODEL_SAVE_DIR, f"{args.model}_model.h5") 60 | model.save(save_path) 61 | print(f"模型已保存到: {save_path}") 62 | 63 | if __name__ == '__main__': 64 | main() 
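The command-line trainer above pulls `TEXTCNN_CONFIG`, `LSTM_CONFIG`, `DATASET_CONFIG` and `MODEL_SAVE_DIR` from a dict-style `config` module that is not shown in this dump. Below is a minimal sketch of what such a module might contain; only the key names are implied by the lookups in `src/train.py`, while every path, size and hyperparameter value here is an assumption to be adjusted to the real project (the category list, vocab size 5000 and sequence length 600 mirror the values used elsewhere in this repository).

```python
# config.py (sketch) -- hypothetical values; only the key names are implied
# by how src/train.py indexes these objects.
from pathlib import Path

MODEL_SAVE_DIR = "models/checkpoints"

DATASET_CONFIG = {
    "train_file": Path("data/raw/train.txt"),   # one "标签\t文本" sample per line
    "val_file": Path("data/raw/val.txt"),
    "vocab_file": Path("data/vocab/vocab.txt"),
    "vocab_size": 5000,
    "categories": ["体育", "财经", "房产", "家居", "教育",
                   "科技", "时尚", "时政", "游戏", "娱乐"],
}

TEXTCNN_CONFIG = {
    "max_sequence_length": 600,
    "embedding_dim": 128,
    "batch_size": 64,
    # any remaining keys (including batch_size) are forwarded to the model
    # builder by the **kwargs filter in main(), so the builder must accept them
    "num_filters": 128,
    "filter_sizes": [3, 4, 5],
    "dropout_rate": 0.5,
}

LSTM_CONFIG = {
    "max_sequence_length": 600,
    "embedding_dim": 128,
    "batch_size": 64,
    "lstm_units": 128,
    "dropout_rate": 0.5,
}
```

With a module like this in place, the script would be invoked as `python src/train.py --model textcnn --epochs 10`.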
-------------------------------------------------------------------------------- /src/train/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any, Optional, Tuple 3 | import tensorflow as tf 4 | import numpy as np 5 | from ..models.base import BaseModel 6 | from ..data.base import BaseDataset 7 | from ..utils.logger import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | class BaseTrainer(ABC): 12 | """训练器基类""" 13 | 14 | def __init__( 15 | self, 16 | model: BaseModel, 17 | dataset: BaseDataset, 18 | config: Dict[str, Any], 19 | model_dir: str = "models" 20 | ): 21 | """初始化训练器 22 | 23 | Args: 24 | model: 模型实例 25 | dataset: 数据集实例 26 | config: 训练配置 27 | model_dir: 模型保存目录 28 | """ 29 | self.model = model 30 | self.dataset = dataset 31 | self.config = config 32 | self.model_dir = model_dir 33 | 34 | # 训练相关属性 35 | self.optimizer = None 36 | self.loss_fn = None 37 | self.metrics = None 38 | self.callbacks = [] 39 | 40 | # 初始化训练组件 41 | self._init_optimizer() 42 | self._init_loss() 43 | self._init_metrics() 44 | self._init_callbacks() 45 | 46 | @abstractmethod 47 | def _init_optimizer(self): 48 | """初始化优化器""" 49 | pass 50 | 51 | @abstractmethod 52 | def _init_loss(self): 53 | """初始化损失函数""" 54 | pass 55 | 56 | @abstractmethod 57 | def _init_metrics(self): 58 | """初始化评估指标""" 59 | pass 60 | 61 | def _init_callbacks(self): 62 | """初始化回调函数""" 63 | # 模型检查点 64 | checkpoint = tf.keras.callbacks.ModelCheckpoint( 65 | filepath=f"{self.model_dir}/checkpoints/model-{{epoch:02d}}-{{val_loss:.4f}}.h5", 66 | monitor='val_loss', 67 | save_best_only=True, 68 | save_weights_only=True, 69 | mode='min', 70 | verbose=1 71 | ) 72 | 73 | # 早停 74 | early_stopping = tf.keras.callbacks.EarlyStopping( 75 | monitor='val_loss', 76 | patience=self.config.get('early_stopping_patience', 5), 77 | restore_best_weights=True, 78 | verbose=1 79 | ) 80 | 81 | # 学习率调度器 82 | lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau( 83 | monitor='val_loss', 84 | factor=0.5, 85 | patience=2, 86 | min_lr=1e-6, 87 | verbose=1 88 | ) 89 | 90 | # TensorBoard 91 | tensorboard = tf.keras.callbacks.TensorBoard( 92 | log_dir=f"{self.model_dir}/logs", 93 | histogram_freq=1, 94 | write_graph=True, 95 | update_freq='epoch' 96 | ) 97 | 98 | self.callbacks.extend([ 99 | checkpoint, 100 | early_stopping, 101 | lr_scheduler, 102 | tensorboard 103 | ]) 104 | 105 | def train( 106 | self, 107 | train_data: Tuple[np.ndarray, np.ndarray], 108 | val_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, 109 | **kwargs 110 | ) -> Dict[str, Any]: 111 | """训练模型 112 | 113 | Args: 114 | train_data: 训练数据 (X_train, y_train) 115 | val_data: 验证数据 (X_val, y_val) 116 | **kwargs: 其他训练参数 117 | 118 | Returns: 119 | 训练历史记录 120 | """ 121 | logger.info("开始训练模型...") 122 | 123 | # 准备训练数据 124 | X_train, y_train = train_data 125 | train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) 126 | train_dataset = train_dataset.shuffle( 127 | buffer_size=self.config.get('shuffle_buffer_size', 10000) 128 | ).batch( 129 | self.config.get('batch_size', 32) 130 | ).prefetch( 131 | tf.data.AUTOTUNE 132 | ) 133 | 134 | # 准备验证数据 135 | if val_data is not None: 136 | X_val, y_val = val_data 137 | val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)) 138 | val_dataset = val_dataset.batch( 139 | self.config.get('batch_size', 32) 140 | ).prefetch( 141 | tf.data.AUTOTUNE 142 | ) 143 | else: 144 | val_dataset = None 145 | 146 | # 训练模型 147 | history = 
self.model.fit( 148 | train_dataset, 149 | validation_data=val_dataset, 150 | epochs=self.config.get('epochs', 10), 151 | callbacks=self.callbacks, 152 | **kwargs 153 | ) 154 | 155 | logger.info("模型训练完成") 156 | return history.history 157 | 158 | def evaluate( 159 | self, 160 | test_data: Tuple[np.ndarray, np.ndarray] 161 | ) -> Dict[str, float]: 162 | """评估模型 163 | 164 | Args: 165 | test_data: 测试数据 (X_test, y_test) 166 | 167 | Returns: 168 | 评估指标 169 | """ 170 | logger.info("开始评估模型...") 171 | 172 | X_test, y_test = test_data 173 | test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)) 174 | test_dataset = test_dataset.batch( 175 | self.config.get('batch_size', 32) 176 | ).prefetch( 177 | tf.data.AUTOTUNE 178 | ) 179 | 180 | # 评估模型 181 | metrics = self.model.evaluate(test_dataset, return_dict=True) 182 | 183 | # 记录评估结果 184 | for metric_name, metric_value in metrics.items(): 185 | logger.info(f"{metric_name}: {metric_value:.4f}") 186 | 187 | return metrics 188 | 189 | def save_model(self, filepath: str): 190 | """保存模型 191 | 192 | Args: 193 | filepath: 保存路径 194 | """ 195 | logger.info(f"保存模型到 {filepath}") 196 | self.model.save(filepath) 197 | 198 | def load_model(self, filepath: str): 199 | """加载模型 200 | 201 | Args: 202 | filepath: 模型文件路径 203 | """ 204 | logger.info(f"从 {filepath} 加载模型") 205 | self.model = tf.keras.models.load_model(filepath) -------------------------------------------------------------------------------- /src/train/text_classifier.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from typing import Dict, Any, List 3 | from .base import BaseTrainer 4 | from ..utils.logger import get_logger 5 | import numpy as np 6 | 7 | logger = get_logger(__name__) 8 | 9 | class TextClassifierTrainer(BaseTrainer): 10 | """文本分类训练器""" 11 | 12 | def _init_optimizer(self): 13 | """初始化优化器""" 14 | learning_rate = self.config.get('learning_rate', 1e-3) 15 | optimizer_name = self.config.get('optimizer', 'adam') 16 | 17 | if optimizer_name == 'adam': 18 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 19 | elif optimizer_name == 'sgd': 20 | momentum = self.config.get('momentum', 0.9) 21 | self.optimizer = tf.keras.optimizers.SGD( 22 | learning_rate=learning_rate, 23 | momentum=momentum 24 | ) 25 | else: 26 | raise ValueError(f"不支持的优化器: {optimizer_name}") 27 | 28 | logger.info(f"使用优化器: {optimizer_name}, 学习率: {learning_rate}") 29 | 30 | def _init_loss(self): 31 | """初始化损失函数""" 32 | loss_name = self.config.get('loss', 'categorical_crossentropy') 33 | 34 | if loss_name == 'categorical_crossentropy': 35 | self.loss_fn = tf.keras.losses.CategoricalCrossentropy( 36 | from_logits=self.config.get('from_logits', False) 37 | ) 38 | elif loss_name == 'sparse_categorical_crossentropy': 39 | self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( 40 | from_logits=self.config.get('from_logits', False) 41 | ) 42 | else: 43 | raise ValueError(f"不支持的损失函数: {loss_name}") 44 | 45 | logger.info(f"使用损失函数: {loss_name}") 46 | 47 | def _init_metrics(self): 48 | """初始化评估指标""" 49 | metrics: List[str] = self.config.get('metrics', ['accuracy']) 50 | self.metrics = [] 51 | 52 | for metric_name in metrics: 53 | if metric_name == 'accuracy': 54 | self.metrics.append(tf.keras.metrics.CategoricalAccuracy()) 55 | elif metric_name == 'precision': 56 | self.metrics.append(tf.keras.metrics.Precision()) 57 | elif metric_name == 'recall': 58 | self.metrics.append(tf.keras.metrics.Recall()) 59 | elif metric_name == 'f1': 60 | 
self.metrics.append(tf.keras.metrics.F1Score()) 61 | else: 62 | raise ValueError(f"不支持的评估指标: {metric_name}") 63 | 64 | logger.info(f"使用评估指标: {metrics}") 65 | 66 | def train_with_augmentation( 67 | self, 68 | train_data: tuple, 69 | val_data: tuple = None, 70 | augmentation_config: Dict[str, Any] = None, 71 | **kwargs 72 | ) -> Dict[str, Any]: 73 | """使用数据增强进行训练 74 | 75 | Args: 76 | train_data: 训练数据 (X_train, y_train) 77 | val_data: 验证数据 (X_val, y_val) 78 | augmentation_config: 数据增强配置 79 | **kwargs: 其他训练参数 80 | 81 | Returns: 82 | 训练历史记录 83 | """ 84 | from ..data.augmentation import TextAugmenter 85 | 86 | if augmentation_config is None: 87 | augmentation_config = {} 88 | 89 | # 创建数据增强器 90 | augmenter = TextAugmenter(**augmentation_config) 91 | 92 | # 获取原始训练数据 93 | X_train, y_train = train_data 94 | 95 | # 应用数据增强 96 | logger.info("开始数据增强...") 97 | augmented_texts = [] 98 | augmented_labels = [] 99 | 100 | for text, label in zip(X_train, y_train): 101 | # 对每个样本进行增强 102 | augmented = augmenter.augment(text) 103 | augmented_texts.extend(augmented) 104 | augmented_labels.extend([label] * len(augmented)) 105 | 106 | # 将增强后的数据转换为模型输入格式 107 | X_train_aug = self.dataset.encode_texts(augmented_texts) 108 | y_train_aug = np.array(augmented_labels) 109 | 110 | # 合并原始数据和增强数据 111 | X_train_combined = np.concatenate([X_train, X_train_aug]) 112 | y_train_combined = np.concatenate([y_train, y_train_aug]) 113 | 114 | logger.info(f"数据增强完成,原始样本数: {len(X_train)}, 增强后样本数: {len(X_train_combined)}") 115 | 116 | # 使用增强后的数据进行训练 117 | return self.train( 118 | train_data=(X_train_combined, y_train_combined), 119 | val_data=val_data, 120 | **kwargs 121 | ) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JustinJiang1994/Text_Classification/24b263fdece9827e1ce1b9e6b2f2e0971aa6c1df/src/utils/metrics.py -------------------------------------------------------------------------------- /src/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | from sklearn.metrics import confusion_matrix 5 | import tensorflow as tf 6 | from typing import List, Dict, Any 7 | import pandas as pd 8 | 9 | def plot_training_history(history: Dict[str, List[float]], save_path: str = None): 10 | """ 11 | 绘制训练历史 12 | :param history: 训练历史数据 13 | :param save_path: 保存路径 14 | """ 15 | plt.figure(figsize=(12, 4)) 16 | 17 | # 绘制损失 18 | plt.subplot(1, 2, 1) 19 | plt.plot(history['loss'], label='训练损失') 20 | plt.plot(history['val_loss'], label='验证损失') 21 | plt.title('模型损失') 22 | plt.xlabel('Epoch') 23 | plt.ylabel('Loss') 24 | plt.legend() 25 | 26 | # 绘制准确率 27 | plt.subplot(1, 2, 2) 28 | plt.plot(history['accuracy'], label='训练准确率') 29 | plt.plot(history['val_accuracy'], label='验证准确率') 30 | plt.title('模型准确率') 31 | plt.xlabel('Epoch') 32 | plt.ylabel('Accuracy') 33 | plt.legend() 34 | 35 | plt.tight_layout() 36 | if save_path: 37 | plt.savefig(save_path) 38 | plt.close() 39 | 40 | def plot_confusion_matrix(y_true: np.ndarray, y_pred: 
np.ndarray, 41 | labels: List[str], save_path: str = None): 42 | """ 43 | 绘制混淆矩阵 44 | :param y_true: 真实标签 45 | :param y_pred: 预测标签 46 | :param labels: 标签名称列表 47 | :param save_path: 保存路径 48 | """ 49 | cm = confusion_matrix(y_true, y_pred) 50 | plt.figure(figsize=(10, 8)) 51 | sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 52 | xticklabels=labels, yticklabels=labels) 53 | plt.title('混淆矩阵') 54 | plt.xlabel('预测标签') 55 | plt.ylabel('真实标签') 56 | plt.tight_layout() 57 | if save_path: 58 | plt.savefig(save_path) 59 | plt.close() 60 | 61 | def plot_classification_metrics(metrics: Dict[str, float], save_path: str = None): 62 | """ 63 | 绘制分类指标 64 | :param metrics: 分类指标字典 65 | :param save_path: 保存路径 66 | """ 67 | plt.figure(figsize=(10, 6)) 68 | metrics_df = pd.DataFrame(metrics.items(), columns=['Metric', 'Value']) 69 | sns.barplot(x='Metric', y='Value', data=metrics_df) 70 | plt.title('分类指标') 71 | plt.xticks(rotation=45) 72 | plt.tight_layout() 73 | if save_path: 74 | plt.savefig(save_path) 75 | plt.close() 76 | 77 | def plot_attention_weights(attention_weights: np.ndarray, 78 | tokens: List[str], 79 | save_path: str = None): 80 | """ 81 | 绘制注意力权重 82 | :param attention_weights: 注意力权重矩阵 83 | :param tokens: 词元列表 84 | :param save_path: 保存路径 85 | """ 86 | plt.figure(figsize=(12, 8)) 87 | sns.heatmap(attention_weights, 88 | xticklabels=tokens, 89 | yticklabels=tokens, 90 | cmap='YlOrRd') 91 | plt.title('注意力权重可视化') 92 | plt.xlabel('词元') 93 | plt.ylabel('词元') 94 | plt.xticks(rotation=45) 95 | plt.tight_layout() 96 | if save_path: 97 | plt.savefig(save_path) 98 | plt.close() 99 | 100 | def plot_model_comparison(models_metrics: Dict[str, Dict[str, float]], 101 | metric_name: str = 'accuracy', 102 | save_path: str = None): 103 | """ 104 | 绘制模型比较图 105 | :param models_metrics: 各模型指标字典 106 | :param metric_name: 要比较的指标名称 107 | :param save_path: 保存路径 108 | """ 109 | plt.figure(figsize=(10, 6)) 110 | models = list(models_metrics.keys()) 111 | metrics = [metrics[metric_name] for metrics in models_metrics.values()] 112 | 113 | plt.bar(models, metrics) 114 | plt.title(f'模型{metric_name}比较') 115 | plt.xlabel('模型') 116 | plt.ylabel(metric_name) 117 | plt.xticks(rotation=45) 118 | plt.tight_layout() 119 | if save_path: 120 | plt.savefig(save_path) 121 | plt.close() 122 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from src.data.dataset import TextDataset 4 | from src.data.augmentation import TextAugmenter 5 | 6 | class TestDataProcessing(unittest.TestCase): 7 | def setUp(self): 8 | # 创建测试数据 9 | self.test_texts = [ 10 | "这是一个测试句子", 11 | "另一个测试句子", 12 | "第三个测试句子" 13 | ] 14 | self.test_labels = [0, 1, 2] 15 | self.vocab = ["这", "是", "一个", "测试", "句子", "另", "第三"] 16 | self.vocab_dict = {word: idx for idx, word in enumerate(self.vocab)} 17 | 18 | # 创建测试数据集 19 | self.dataset = TextDataset( 20 | txt_path="test.txt", 21 | vocab_path="test_vocab.txt", 22 | max_length=10, 23 | categories=["类别1", "类别2", "类别3"] 24 | ) 25 | self.dataset.vocab = self.vocab 26 | self.dataset.vocab_dict = self.vocab_dict 27 | 28 | def test_text_encoding(self): 29 | # 测试文本编码 30 | encoded = self.dataset.encode_single(self.test_texts[0]) 31 | self.assertEqual(encoded.shape, (1, 10)) # 检查padding后的长度 32 | 33 | # 测试批量编码 34 | x, y = self.dataset.encode_samples() 35 | self.assertEqual(len(x), len(self.test_texts)) 36 | self.assertEqual(len(y), len(self.test_labels)) 37 | 38 | 
def test_data_augmentation(self): 39 | augmenter = TextAugmenter(prob=0.3) 40 | 41 | # 测试同义词替换 42 | aug_text = augmenter.synonym_replacement(self.test_texts[0]) 43 | self.assertIsInstance(aug_text, str) 44 | self.assertTrue(len(aug_text) > 0) 45 | 46 | # 测试回译 47 | aug_text = augmenter.back_translation(self.test_texts[0]) 48 | self.assertIsInstance(aug_text, str) 49 | self.assertTrue(len(aug_text) > 0) 50 | 51 | # 测试批量增强 52 | aug_texts, aug_labels = augmenter.augment_batch( 53 | self.test_texts, 54 | self.test_labels, 55 | methods=['synonym'], 56 | n_augment=1 57 | ) 58 | self.assertEqual(len(aug_texts), len(self.test_texts) * 2) 59 | self.assertEqual(len(aug_labels), len(self.test_labels) * 2) 60 | 61 | if __name__ == '__main__': 62 | unittest.main() -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import tensorflow as tf 4 | from src.models.textcnn import build_textcnn 5 | from src.models.lstm import build_lstm 6 | from src.models.bert import build_bert 7 | 8 | class TestModels(unittest.TestCase): 9 | def setUp(self): 10 | self.vocab_size = 5000 11 | self.seq_length = 100 12 | self.num_classes = 10 13 | self.batch_size = 32 14 | 15 | def test_textcnn(self): 16 | model = build_textcnn( 17 | vocab_size=self.vocab_size, 18 | seq_length=self.seq_length, 19 | num_classes=self.num_classes 20 | ) 21 | 22 | # 测试模型输出形状 23 | x = np.random.randint(0, self.vocab_size, (self.batch_size, self.seq_length)) 24 | y = model.predict(x) 25 | self.assertEqual(y.shape, (self.batch_size, self.num_classes)) 26 | 27 | # 测试模型编译 28 | self.assertTrue(model.optimizer is not None) 29 | self.assertTrue(model.loss is not None) 30 | 31 | def test_lstm(self): 32 | model = build_lstm( 33 | vocab_size=self.vocab_size, 34 | seq_length=self.seq_length, 35 | num_classes=self.num_classes 36 | ) 37 | 38 | # 测试模型输出形状 39 | x = np.random.randint(0, self.vocab_size, (self.batch_size, self.seq_length)) 40 | y = model.predict(x) 41 | self.assertEqual(y.shape, (self.batch_size, self.num_classes)) 42 | 43 | # 测试模型编译 44 | self.assertTrue(model.optimizer is not None) 45 | self.assertTrue(model.loss is not None) 46 | 47 | def test_bert(self): 48 | model, tokenizer = build_bert( 49 | num_classes=self.num_classes, 50 | max_length=self.seq_length 51 | ) 52 | 53 | # 测试模型输入 54 | text = "这是一个测试句子" 55 | input_ids, attention_mask, token_type_ids = tokenizer( 56 | text, 57 | max_length=self.seq_length, 58 | padding='max_length', 59 | truncation=True, 60 | return_tensors='tf' 61 | ) 62 | 63 | # 测试模型输出形状 64 | y = model.predict([input_ids, attention_mask, token_type_ids]) 65 | self.assertEqual(y.shape, (1, self.num_classes)) 66 | 67 | # 测试模型编译 68 | self.assertTrue(model.optimizer is not None) 69 | self.assertTrue(model.loss is not None) 70 | 71 | if __name__ == '__main__': 72 | unittest.main() --------------------------------------------------------------------------------
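Note that `tests/test_models.py` imports `build_textcnn`, `build_lstm` and `build_bert` factories, while `src/models` exposes the class-based `TextCNN`, `LSTM` and `BERTClassifier` shown earlier. The sketch below shows thin factory wrappers that would bridge the two; the function names and defaults are taken from the tests, not from the repository itself, so treat them as assumptions. A matching `build_bert` would analogously construct a `BERTClassifier` and return `(model, tokenizer)`.

```python
# Hypothetical factory helpers bridging tests/test_models.py and the
# class-based models in src/models; not part of the repository as shown.
from src.models.textcnn import TextCNN
from src.models.lstm import LSTM


def build_textcnn(vocab_size: int, seq_length: int, num_classes: int,
                  embedding_dim: int = 128, **kwargs):
    """Build and compile a TextCNN, returning the underlying tf.keras.Model."""
    wrapper = TextCNN(vocab_size=vocab_size, embedding_dim=embedding_dim,
                      num_classes=num_classes, max_length=seq_length, **kwargs)
    wrapper.compile()        # BaseModel.compile() builds and compiles the Keras model
    return wrapper.model     # tests check model.optimizer / model.loss on this object


def build_lstm(vocab_size: int, seq_length: int, num_classes: int,
               embedding_dim: int = 128, **kwargs):
    """Same interface for the (optionally bidirectional) LSTM classifier."""
    wrapper = LSTM(vocab_size=vocab_size, embedding_dim=embedding_dim,
                   num_classes=num_classes, max_length=seq_length, **kwargs)
    wrapper.compile()
    return wrapper.model
```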