├── .idea
│   ├── .gitignore
│   ├── LLM-TextClassification.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── data
│   └── datasets
│       ├── longnews
│       │   ├── class.txt
│       │   ├── dev.json
│       │   └── train.json
│       └── thucnews
│           ├── class.txt
│           ├── dev.json
│           ├── test.json
│           └── train.json
├── lora_predict.py
├── main.py
├── module
│   ├── TemporalAttention.py
│   ├── adapter.py
│   ├── argument.py
│   └── others.py
├── predict.py
├── requirements.txt
├── run.sh
└── scripts
    └── download_model.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# LLM Text Classification Toolkit

This project provides a flexible and efficient text-classification solution built on large language models (LLMs) such as Qwen and DeepSeek. It supports two main modes: using a pretrained LLM with a custom classification head directly, or fine-tuning the LLM with LoRA (Low-Rank Adaptation) and then adding the classification head for more accurate results.

## Key features

- **Two model families**: integrates both the Qwen and DeepSeek model lines.
- **Multiple training setups**: plain LLM + classification head, or LLM + LoRA + classification head.
- **Easy to extend**: the modular design makes it straightforward to adjust or swap components.

Contributions, issues and usage reports are welcome!

## Setting up, downloading models and running the project

### Environment setup

#### 1. Install the dependencies

First, make sure Python is installed (version 3.10 or later is recommended), then install the required packages:

```bash
pip install -r requirements.txt
```

#### 2. Download the models

The project supports several pretrained language models, including Qwen and DeepSeek. You can download them as follows (Qwen and DeepSeek shown as examples):

**Qwen**

```python
# Download the model
from modelscope import snapshot_download
model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir="./ckpt")
```

**DeepSeek**

```python
# Download the model
from modelscope import snapshot_download
model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', cache_dir="./ckpt")
```

#### 3. Prepare the dataset

The code targets text classification and is trained on a purpose-built long-news dataset. During preprocessing the original data is split with a 0.15 ratio into an evaluation set for measuring model performance and a training set for learning: 5,950 training samples and 1,050 evaluation samples. The label set covers seven domains: 时尚 (fashion), 财经 (finance), 时政 (politics), 家居 (home), 房产 (real estate), 教育 (education) and 科技 (technology).
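The repository does not document the data format explicitly, but from the way `main.py`, `predict.py` and `lora_predict.py` read the files, `train.json` / `dev.json` / `test.json` appear to be JSON Lines: one record per line with a `content` string, a `label` holding the class index from `class.txt` (stored as a string), and a `metadata` field that is dropped during preprocessing. A minimal sketch of writing one such record (the text and values are purely illustrative):

```python
import json

# Illustrative record in the format main.py / predict.py expect:
# "label" is the 0-based index of the class in class.txt (0 = 教育 for longnews),
# "metadata" is removed by the preprocessing step in main.py.
record = {"content": "某地新建多所中小学，进一步缓解入学压力……", "label": "0", "metadata": ""}

with open("data/datasets/longnews/train.json", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```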
#### 4. Training

**LLM (full-parameter fine-tuning) + classification head:**

```bash
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1

MODEL="Qwen2.5-0.5B-Instruct"
DATA="data/datasets/longnews"

function usage() {
    echo '
Usage: bash finetune/finetune_lora_single_gpu.sh [-m MODEL_PATH] [-d DATA_PATH]
'
}

while [[ "$1" != "" ]]; do
    case $1 in
        -m | --model )
            shift
            MODEL=$1
            ;;
        -d | --data )
            shift
            DATA=$1
            ;;
        -h | --help )
            usage
            exit 0
            ;;
        * )
            echo "Unknown argument ${1}"
            exit 1
            ;;
    esac
    shift
done

export CUDA_VISIBLE_DEVICES=0
python main.py \
  --model_name_or_path $MODEL \
  --is_training True \
  --data_path $DATA \
  --bf16 True \
  --output_dir output_qwen/longnews \
  --num_train_epochs 3 \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 2 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 2000 \
  --save_total_limit 10 \
  --learning_rate 3e-4 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 512 \
  --lazy_preprocess True \
  --gradient_checkpointing
```

**LLM + LoRA + classification head:**

```bash
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1

MODEL="Qwen2.5-0.5B-Instruct"
DATA="data/datasets/longnews"

function usage() {
    echo '
Usage: bash finetune/finetune_lora_single_gpu.sh [-m MODEL_PATH] [-d DATA_PATH]
'
}

while [[ "$1" != "" ]]; do
    case $1 in
        -m | --model )
            shift
            MODEL=$1
            ;;
        -d | --data )
            shift
            DATA=$1
            ;;
        -h | --help )
            usage
            exit 0
            ;;
        * )
            echo "Unknown argument ${1}"
            exit 1
            ;;
    esac
    shift
done

export CUDA_VISIBLE_DEVICES=0
python main.py \
  --model_name_or_path $MODEL \
  --is_training True \
  --data_path $DATA \
  --bf16 True \
  --output_dir output_qwen/longnews \
  --num_train_epochs 3 \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 2 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 2000 \
  --save_total_limit 10 \
  --learning_rate 3e-4 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 512 \
  --lazy_preprocess True \
  --gradient_checkpointing \
  --use_lora
```

#### 5. Experimental results

(1) Qwen, full-parameter fine-tuning + classification head

```bash
              precision    recall  f1-score   support

          教育       0.98      0.94      0.96       154
          财经       0.95      0.96      0.95       130
          科技       0.95      0.98      0.96       135
          房产       0.99      0.94      0.96       156
          时政       0.90      0.95      0.92       130
          家居       0.95      0.94      0.95       158
          时尚       0.99      1.00      0.99       138

    accuracy                           0.96      1001
   macro avg       0.96      0.96      0.96      1001
weighted avg       0.96      0.96      0.96      1001
```

(2) DeepSeek + LoRA + classification head

```bash
              precision    recall  f1-score   support

          教育       0.92      0.89      0.90       154
          财经       0.82      0.93      0.87       130
          科技       0.84      0.96      0.90       135
          房产       0.87      0.88      0.87       156
          时政       0.91      0.79      0.85       130
          家居       0.91      0.76      0.83       158
          时尚       0.92      0.98      0.95       138

    accuracy                           0.88      1001
   macro avg       0.88      0.88      0.88      1001
weighted avg       0.89      0.88      0.88      1001
```

(3) DeepSeek + LoRA + TemporalAttention classification head

```bash
              precision    recall  f1-score   support

          教育       0.92      0.87      0.90       154
          财经       0.81      0.90      0.85       130
          科技       0.78      0.97      0.87       135
          房产       0.94      0.85      0.90       156
          时政       0.85      0.87      0.86       130
          家居       0.91      0.74      0.82       158
          时尚       0.94      0.96      0.95       138

    accuracy                           0.88      1001
   macro avg       0.88      0.88      0.88      1001
weighted avg       0.88      0.88      0.88      1001
```

(4) Comparison with [bert-base-chinese classification](https://github.com/Dylan9897/ai-nlp-project/tree/main/TextClassification):

```bash
              precision    recall  f1-score   support

          教育       0.97      0.98      0.97       154
          财经       0.97      0.94      0.95       130
          科技       0.97      0.99      0.98       135
          房产       0.95      0.95      0.95       156
          时政       0.95      0.95      0.95       130
          家居       0.96      0.96      0.96       158
          时尚       0.99      1.00      1.00       138

    accuracy                           0.97      1001
   macro avg       0.97      0.97      0.97      1001
weighted avg       0.97      0.97      0.97      1001
```

Results on the THUCNews dataset, using Qwen full-parameter fine-tuning as an example:

```bash
               precision    recall  f1-score   support

      finance       0.92      0.87      0.89      1000
       realty       0.92      0.93      0.92      1000
       stocks       0.83      0.84      0.84      1000
    education       0.94      0.94      0.94      1000
      science       0.83      0.86      0.85      1000
      society       0.88      0.91      0.89      1000
     politics       0.88      0.88      0.88      1000
       sports       0.95      0.94      0.94      1000
         game       0.93      0.91      0.92      1000
entertainment       0.90      0.91      0.91      1000

     accuracy                           0.90     10000
    macro avg       0.90      0.90      0.90     10000
 weighted avg       0.90      0.90      0.90     10000
```

Comparison with [bert-base-chinese classification](https://github.com/Dylan9897/ai-nlp-project/tree/main/TextClassification):

```bash
               precision    recall  f1-score   support

      finance       0.92      0.93      0.92      1000
       realty       0.96      0.95      0.95      1000
       stocks       0.91      0.89      0.90      1000
    education       0.96      0.97      0.97      1000
      science       0.91      0.90      0.91      1000
      society       0.90      0.95      0.93      1000
     politics       0.92      0.92      0.92      1000
       sports       0.98      0.98      0.98      1000
         game       0.97      0.94      0.95      1000
entertainment       0.95      0.97      0.96      1000

     accuracy                           0.94     10000
    macro avg       0.94      0.94      0.94     10000
 weighted avg       0.94      0.94      0.94     10000
```

#### 6. FAQ

**(1) Is the project perfect?**

A: It is an evenings-and-weekends side project. Please send feedback and we will keep improving it.

**(2) Why not just prompt the LLM directly?**

A: The classifier has to return exactly one valid class. A generative LLM may answer something like "**Based on the given content, the category is \*\*\*, well... it depends...**", and writing a parser for that kind of output is tedious.
--------------------------------------------------------------------------------
/data/datasets/longnews/class.txt:
--------------------------------------------------------------------------------
教育
财经
科技
房产
时政
家居
时尚

--------------------------------------------------------------------------------
/data/datasets/thucnews/class.txt:
--------------------------------------------------------------------------------
finance
realty
stocks
education
science
society
politics
sports
game
entertainment

--------------------------------------------------------------------------------
/lora_predict.py:
--------------------------------------------------------------------------------
import json

import torch
import transformers
from tqdm import tqdm
from sklearn import metrics
from safetensors.torch import load_file
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from module.adapter import create_and_replace
from module.argument import ModelArguments, DataArguments, TrainingArguments, LoraArguments

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)

(
    model_args,
    data_args,
    training_args,
    lora_args,
) = parser.parse_args_into_dataclasses()

device_map = None

# Load the model config (caching disabled, training flag forwarded)
config = transformers.AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    trust_remote_code=True,
    is_training=model_args.is_training
)
config.use_cache = False
print(f"checkpoint for config is {config}")

model = AutoModelForSequenceClassification.from_pretrained("output_qwen/longnews", num_labels=7)
tokenizer = AutoTokenizer.from_pretrained("output_qwen/longnews")
model.config.pad_token_id = 151643  # Qwen's <|endoftext|> token is used for padding
model.cuda()

lora_config = LoraConfig(
    r=lora_args.lora_r,
    lora_alpha=lora_args.lora_alpha,
    target_modules=lora_args.lora_target_modules,
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.lora_bias,
    # task_type="SEQ_CLS"
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
create_and_replace(model)

# Load the saved adapter weights
weights = load_file("output_qwen/longnews/adapter_model.safetensors")
model.load_state_dict(weights, strict=False)
model.eval()

y_true = []
y_pred = []
with open("data/datasets/longnews/dev.json", "r", encoding="utf-8") as f:
    for line in tqdm(f.readlines()):
        example = json.loads(line)
        content = example["content"]
        label = int(example["label"])
        y_true.append(label)
        input_demo = tokenizer(content, padding="max_length", truncation=True, return_tensors="pt")
        for key in input_demo.keys():
            input_demo[key] = input_demo[key].cuda()
        with torch.no_grad():
            output = model(**input_demo)
        pred = output.logits.argmax().item()
        y_pred.append(pred)

report = metrics.classification_report(y_true, y_pred)
print(report)
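`lora_predict.py` prints the classification report with numeric label ids only. If you want the Chinese class names that `predict.py` shows, its last two lines can be replaced with something like the following sketch (it assumes the label ids follow the line order of the longnews `class.txt`):

```python
# Drop-in replacement for the final report lines of lora_predict.py.
with open("data/datasets/longnews/class.txt", "r", encoding="utf-8") as f:
    target_names = [line.strip() for line in f if line.strip()]

report = metrics.classification_report(y_true, y_pred, target_names=target_names)
print(report)
```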
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os

import transformers
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model

from module.adapter import create_and_replace
from module.argument import ModelArguments, DataArguments, TrainingArguments, LoraArguments


def train(verbose=False):
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = None

    # Load the model config (caching disabled, training flag forwarded)
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        is_training=model_args.is_training
    )
    config.use_cache = False

    # num_labels=7 matches data/datasets/longnews/class.txt; adjust it for other datasets.
    model = AutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, num_labels=7)
    model.config.pad_token_id = 151643  # Qwen's <|endoftext|> token is used for padding

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )

    if training_args.use_lora:
        # Keep the classification head and the embeddings fully trainable next to the LoRA weights.
        modules_to_save = ["score", "embed_tokens"]

        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
            # task_type="SEQ_CLS",
            modules_to_save=modules_to_save  # This argument serves for adding new tokens.
        )
        model = get_peft_model(model, lora_config)

        # Print the PEFT trainable parameters
        model.print_trainable_parameters()

        for name, param in model.named_parameters():
            # Make sure the classification head and the embeddings stay updateable.
            if name.startswith("base_model.model.score."):
                print(f"Setting {name} to be updateable.")
                param.requires_grad = True
            elif name == "base_model.model.model.embed_tokens.weight":
                print(f"Setting {name} to be updateable.")
                param.requires_grad = True

    if model_args.add_adapter:
        create_and_replace(model)
    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # Optionally report which parameters will be updated.
    if verbose:
        for name, param in model.named_parameters():
            print(f"Parameter Name: {name}, Updateable: {param.requires_grad}")

    def process_function(examples):
        tokenized = tokenizer(examples["content"], padding="max_length", truncation=True)
        tokenized["label"] = [int(unit) for unit in examples["label"]]
        return tokenized

    def load_data(data_args):
        # Load the training and validation splits
        dataset = load_dataset("json", data_files={"train": os.path.join(data_args.data_path, "train.json"),
                                                   "valid": os.path.join(data_args.data_path, "dev.json")})
        # Tokenize in batches with map
        processed_dataset = dataset.map(process_function, batched=True, batch_size=16)
        # Drop the columns that are no longer needed
        processed_dataset = processed_dataset.remove_columns(["content", "metadata"])
        return processed_dataset

    processed_data = load_data(data_args)

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=processed_data["train"],
        eval_dataset=processed_data["valid"],
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train(verbose=True)
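`main.py` (and `lora_predict.py`) hard-code `num_labels=7`, which matches the longnews label set but not thucnews (10 classes in its `class.txt`). A sketch of deriving the value from the dataset folder instead (hypothetical helper, not part of the current code):

```python
import os

def count_labels(data_path: str) -> int:
    """Count the classes listed in <data_path>/class.txt (one per line)."""
    with open(os.path.join(data_path, "class.txt"), "r", encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())

# e.g. inside train():
#   num_labels = count_labels(data_args.data_path)
#   model = AutoModelForSequenceClassification.from_pretrained(
#       model_args.model_name_or_path, num_labels=num_labels)
```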
--------------------------------------------------------------------------------
/module/TemporalAttention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalAttention(nn.Module):
    """Attention-based replacement for the linear classification head.

    In this project (see module/adapter.py), input_seq_len is set to the LLM hidden size,
    output_seq_len to the number of classes, and feature_dim to the padded token length
    (model_max_length, 512 in the training scripts).
    """

    def __init__(self, input_seq_len, output_seq_len, feature_dim):
        super(TemporalAttention, self).__init__()
        # Linear layer that computes an attention score for each position;
        # the same attention mechanism is shared across all features.
        self.attention_linear = nn.Linear(feature_dim, 1)
        # The output sequence length must be defined in advance.
        self.output_seq_len = output_seq_len
        # Optional transformation layer (kept for checkpoint compatibility, not used in forward).
        self.transform = nn.Linear(input_seq_len, output_seq_len)

    def forward(self, x):
        # x arrives as (batch, feature_dim, input_seq_len); move input_seq_len to the middle.
        x = x.permute(0, 2, 1)
        batch_size, seq_len, feature_dim = x.size()

        # Attention scores per position (batch_size, seq_len)
        attention_scores = self.attention_linear(x).squeeze(-1)

        # Softmax over the sequence dimension -> attention weights (batch_size, seq_len)
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Reshape and broadcast the weights over the features (batch_size, seq_len, feature_dim)
        attention_weights = attention_weights.unsqueeze(-1)
        attention_weights = attention_weights.expand_as(x)

        # Weight the inputs and pool over the sequence dimension (batch_size, feature_dim)
        weighted_input = x * attention_weights
        attended_features = weighted_input.sum(dim=1)

        # Repeat the pooled features to the desired output length (batch_size, output_seq_len, feature_dim)
        output = attended_features.unsqueeze(1).expand(batch_size, self.output_seq_len, feature_dim)
        output = output.permute(0, 2, 1)
        return output


if __name__ == '__main__':
    # Example usage mirroring how the module is applied as a classification head:
    # the input is (batch, feature_dim, input_seq_len), i.e. (batch, padded tokens, hidden size).
    batch_size = 2
    input_seq_len = 1536
    output_seq_len = 7
    feature_dim = 512

    x = torch.randn(batch_size, feature_dim, input_seq_len)
    attention_layer = TemporalAttention(input_seq_len=input_seq_len, output_seq_len=output_seq_len, feature_dim=feature_dim)
    output = attention_layer(x)

    print(output.shape)  # torch.Size([2, 512, 7])
--------------------------------------------------------------------------------
/module/adapter.py:
--------------------------------------------------------------------------------
import sys

from transformers import AutoModelForSequenceClassification

sys.path.append('/home/root123/workspace/handx/LLM-TextClassification')  # adjust this to your local checkout
from module.TemporalAttention import TemporalAttention


def create_and_replace(model):
    """Swap the model's linear `score` head for a TemporalAttention head."""
    if hasattr(model, 'score'):
        target_model = model.score
        in_features = target_model.in_features
        out_features = target_model.out_features

        # feature_dim must equal the padded sequence length (model_max_length, 512 in run.sh).
        new_model = TemporalAttention(input_seq_len=in_features, output_seq_len=out_features, feature_dim=512)

        setattr(model, "score", new_model)
    else:
        raise Exception("Please confirm that the classification layer of the model is really named 'score'")


if __name__ == "__main__":
    model_path = "ckpt/DeepSeek-R1-Distill-Qwen-1___5B"
    num_labels = 7
    model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=num_labels)
    create_and_replace(model)
    print(model)
    # for name, param in model.named_parameters():
    #     print(f"Parameter Name: {name}, Updateable: {param.requires_grad}")
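A quick way to sanity-check the swapped head in isolation (a sketch; the hidden size 1536 corresponds to DeepSeek-R1-Distill-Qwen-1.5B, and 512 is the padded length used by the training scripts):

```python
import torch

from module.TemporalAttention import TemporalAttention

hidden_size, num_labels, seq_len = 1536, 7, 512

# Same construction as create_and_replace: in_features -> input_seq_len,
# out_features -> output_seq_len, feature_dim fixed to the padded token length.
head = TemporalAttention(input_seq_len=hidden_size, output_seq_len=num_labels, feature_dim=seq_len)

hidden_states = torch.randn(2, seq_len, hidden_size)  # (batch, tokens, hidden) as emitted by the LLM
logits = head(hidden_states)
print(logits.shape)  # torch.Size([2, 512, 7]) -> (batch, tokens, num_labels)
```

The sequence-classification wrapper in Transformers then selects the logits at the last non-padding position, so a per-token output of shape (batch, 512, 7) still yields a single label per text.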
--------------------------------------------------------------------------------
/module/argument.py:
--------------------------------------------------------------------------------
# encoding : utf-8 -*-
# @author : 冬瓜
# @mail : dylan_han@126.com
# @Time : 2025/3/2 11:39
import transformers
from typing import Optional, List
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
    bert_name_or_path: Optional[str] = field(default="bert-base-chinese")  # path to a BERT checkpoint
    is_training: bool = False  # whether we are running in training mode
    update_bmlp: bool = False  # whether the BERT weights should be updated
    add_adapter: bool = False  # replace the linear score head with TemporalAttention (used by main.py / run.sh)


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=8192,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = False
    remove_unused_columns: bool = False


@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False

--------------------------------------------------------------------------------
/module/others.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import softplus


class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()
        print('Mish activation loaded')

    def forward(self, x):
        # Mish: x * tanh(softplus(x))
        return x * torch.tanh(softplus(x))


class FeedForwardNetwork(nn.Module):
    def __init__(self, input_size, output_size, init_method='xavier_uniform', bias=False):
        super(FeedForwardNetwork, self).__init__()
        # Fully connected layer
        if not bias:
            self.fc = nn.Linear(input_size, output_size, bias=False)
        else:
            self.fc = nn.Linear(input_size, output_size)

        # Weight initialisation
        if init_method == 'xavier_uniform':
            nn.init.xavier_uniform_(self.fc.weight)
            if bias:
                self.fc.bias.data.fill_(0)  # zero-initialise the bias
        elif init_method == 'kaiming_uniform':
            nn.init.kaiming_uniform_(self.fc.weight, nonlinearity='relu')
            if bias:
                self.fc.bias.data.fill_(0)  # zero-initialise the bias
        else:
            raise ValueError('Unsupported initialization method')

    def forward(self, x):
        out = self.fc(x)
        return out


class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()

        # Gates and linear transforms
        self.input_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)
        self.update_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)
        self.reset_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)

        # Sigmoid / tanh activations
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, input_, hidden_state):
        combined = torch.cat((input_, hidden_state), dim=2)  # e.g. torch.Size([2, 512, 4096])

        # Update and reset gate signals
        z = self.sigmoid(self.update_gate(combined))
        r = self.sigmoid(self.reset_gate(combined))
        # Candidate hidden state
        h_prime = self.tanh(self.input_gate(torch.cat((input_, r * hidden_state), dim=2)))
        # New hidden state
        hidden_state = (1 - z) * hidden_state + z * h_prime
        return hidden_state


class TextCNN(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.mish = Mish()
        self.convs = nn.ModuleList([nn.Conv2d(1, 512, (k, in_features)) for k in [2, 3, 4]])
        self.dropout = nn.Dropout(0.5)
        self.fc = FeedForwardNetwork(512 * 3, out_features)

    def conv_and_pool(self, x, conv):
        x = self.mish(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        out = x.unsqueeze(1)  # (batch, 1, seq_len, in_features)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
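The building blocks in `module/others.py` are not wired into `main.py`; they appear to be alternative classification heads that could be swapped in the same way as `TemporalAttention`. A small shape sketch (dimensions chosen to match the rest of the project — 512 padded tokens, hidden size 1536; the zero initial hidden state is only an assumption for the demo):

```python
import torch

from module.others import FeedForwardNetwork, GRUCell, TextCNN

batch, seq_len, hidden, num_labels = 2, 512, 1536, 7
x = torch.randn(batch, seq_len, hidden)

ffn = FeedForwardNetwork(hidden, num_labels)   # plain linear head
print(ffn(x).shape)                            # torch.Size([2, 512, 7])

gru = GRUCell(hidden, hidden)                  # gated update over (batch, seq, hidden) tensors
h0 = torch.zeros(batch, seq_len, hidden)       # assumed zero initial hidden state
print(gru(x, h0).shape)                        # torch.Size([2, 512, 1536])

cnn = TextCNN(hidden, num_labels)              # treats the token sequence as a 1-channel image
print(cnn(x).shape)                            # torch.Size([2, 7])
```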
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import json

import torch
from tqdm import tqdm
from sklearn import metrics
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("output_qwen/thucnews")
tokenizer = AutoTokenizer.from_pretrained("output_qwen/thucnews")
model.cuda()
model.eval()

y_true = []
y_pred = []
with open("data/datasets/thucnews/test.json", "r", encoding="utf-8") as f:
    for line in tqdm(f.readlines()):
        example = json.loads(line)
        content = example["content"]
        label = int(example["label"])
        y_true.append(label)

        input_demo = tokenizer(content, padding="max_length", truncation=True, return_tensors="pt")

        for key in input_demo.keys():
            input_demo[key] = input_demo[key].cuda()

        with torch.no_grad():
            output = model(**input_demo)

        pred = output.logits.argmax().item()
        y_pred.append(pred)

columns = open("data/datasets/thucnews/class.txt", "r", encoding="utf-8").readlines()
columns = [x.strip("\n") for x in columns]

report = metrics.classification_report(y_true, y_pred, target_names=columns)
print(report)
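`predict.py` scores a whole test file. For a quick check on a single text, a minimal sketch along the same lines (the checkpoint path and example text are illustrative):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumes a checkpoint produced by the training scripts above.
model = AutoModelForSequenceClassification.from_pretrained("output_qwen/thucnews")
tokenizer = AutoTokenizer.from_pretrained("output_qwen/thucnews")
model.eval()

labels = [line.strip() for line in open("data/datasets/thucnews/class.txt", encoding="utf-8")]

text = "央行宣布下调金融机构存款准备金率……"  # illustrative headline
inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(labels[logits.argmax(-1).item()])
```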
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
datasets
modelscope
peft
safetensors
scikit-learn
torch
tqdm
transformers

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1

MODEL="ckpt/DeepSeek-R1-Distill-Qwen-1___5B" # Set the path if you do not want to load from huggingface directly
# ATTENTION: specify the path to your dataset folder, which must contain
# train.json / dev.json (JSON Lines) and class.txt. See the README for details.
DATA="data/datasets/longnews"

function usage() {
    echo '
Usage: bash finetune/finetune_lora_single_gpu.sh [-m MODEL_PATH] [-d DATA_PATH]
'
}

while [[ "$1" != "" ]]; do
    case $1 in
        -m | --model )
            shift
            MODEL=$1
            ;;
        -d | --data )
            shift
            DATA=$1
            ;;
        -h | --help )
            usage
            exit 0
            ;;
        * )
            echo "Unknown argument ${1}"
            exit 1
            ;;
    esac
    shift
done

export CUDA_VISIBLE_DEVICES=0
python lora_predict.py \
  --model_name_or_path $MODEL \
  --is_training True \
  --add_adapter True \
  --data_path $DATA \
  --bf16 True \
  --output_dir output_qwen/longnews \
  --num_train_epochs 3 \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 2 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 2000 \
  --save_total_limit 10 \
  --learning_rate 5e-4 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --warmup_steps 1000 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 512 \
  --lazy_preprocess True \
  --gradient_checkpointing \
  --use_lora

--------------------------------------------------------------------------------
/scripts/download_model.py:
--------------------------------------------------------------------------------
# encoding : utf-8 -*-
# @author : 冬瓜
# @mail : dylan_han@126.com
# @Time : 2025/3/2 11:49

from modelscope import snapshot_download

snapshot_download('Qwen/Qwen2.5-0.5B-Instruct', cache_dir="./ckpt")
snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', cache_dir="./ckpt")

--------------------------------------------------------------------------------