├── .idea
│   ├── .gitignore
│   ├── LLM-TextClassification.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── data
│   └── datasets
│       ├── longnews
│       │   ├── class.txt
│       │   ├── dev.json
│       │   └── train.json
│       └── thucnews
│           ├── class.txt
│           ├── dev.json
│           ├── test.json
│           └── train.json
├── lora_predict.py
├── main.py
├── module
│   ├── TemporalAttention.py
│   ├── adapter.py
│   ├── argument.py
│   └── others.py
├── predict.py
├── requirements.txt
├── run.sh
└── scripts
    └── download_model.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM Text Classification Toolkit
2 |
3 | This project provides a flexible and efficient text classification solution built on large language models (LLMs) such as Qwen and DeepSeek. It supports two main modes: using a pretrained LLM with a custom classification head directly, and fine-tuning the LLM with LoRA (Low-Rank Adaptation) before adding the classification head for more accurate results.
4 |
5 | ## Key Features
6 | - **Dual-model support**: integrates both the Qwen and DeepSeek language models.
7 | - **Flexible deployment options**: supports the plain LLM + classification head mode and the LLM + LoRA + classification head mode.
8 | - **Easy to extend**: the modular design makes it simple to adjust or swap out components.
9 |
10 | Contributions, issues, and shared use cases are welcome!
11 |
12 | ## Environment Setup, Model Download, and Running the Project
13 |
14 | ### Environment setup
15 |
16 | #### 1. Install dependencies
17 | First, make sure Python is installed (version 3.10 or later is recommended). Then install the required dependencies with:
18 |
19 | ```bash
20 | pip install -r requirements.txt
21 | ```
22 |
23 | #### 2. Download models
24 |
25 | The project supports several pretrained language models, including Qwen and DeepSeek. Download the model you need as follows (Qwen and DeepSeek shown as examples):
26 |
27 | **Using the Qwen model**
28 |
29 | ```python
30 | # Download the model
31 | from modelscope import snapshot_download
32 | model_dir = snapshot_download('Qwen/Qwen2.5-0.5B-Instruct',cache_dir="./ckpt")
33 | ```
34 |
35 | **Using the DeepSeek model**
36 |
37 | ```python
38 | # Download the model
39 | from modelscope import snapshot_download
40 | model_dir = snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',cache_dir="./ckpt")
41 | ```
42 |
43 | #### 3. Prepare the dataset
44 |
45 | The code targets a text classification task whose data comes from a purpose-built long-news classification dataset. During preprocessing, the raw data is split with a 0.15 ratio into a test set for evaluating the model and a training set for learning: the training set contains 5,950 samples and the test set 1,050. The label set covers seven domains: 时尚 (fashion), 财经 (finance), 时政 (politics), 家居 (home & living), 房产 (real estate), 教育 (education), and 科技 (technology). The expected record format is illustrated below.
46 |
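Each line of `train.json` / `dev.json` is a standalone JSON object with a `content` field (the raw text), a `label` field (the class index as a string, following the line order of `class.txt`), and a `metadata` field that is dropped during preprocessing in `main.py`. A minimal sketch of one record, with made-up field values, looks like this:

```python
# Illustrative only: the text and label value here are invented, but the keys match
# what main.py / predict.py read from train.json and dev.json.
import json

record = {
    "content": "春季家居软装搭配的几个实用技巧……",  # raw news text
    "label": "5",       # class index as a string; 5 corresponds to 家居 in class.txt
    "metadata": "",     # carried along in the file, dropped during preprocessing
}
print(json.dumps(record, ensure_ascii=False))
```
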
47 | #### 4. Training
48 |
49 | **LLM (full-parameter fine-tuning) + classification head:**
50 |
51 | ```bash
52 | #!/bin/bash
53 | export CUDA_DEVICE_MAX_CONNECTIONS=1
54 |
55 | MODEL="Qwen2.5-0.5B-Instruct"
56 | DATA="data/datasets/longnews"
57 |
58 | function usage() {
59 | echo '
60 | Usage: bash run.sh [-m MODEL_PATH] [-d DATA_PATH]
61 | '
62 | }
63 |
64 | while [[ "$1" != "" ]]; do
65 | case $1 in
66 | -m | --model )
67 | shift
68 | MODEL=$1
69 | ;;
70 | -d | --data )
71 | shift
72 | DATA=$1
73 | ;;
74 | -h | --help )
75 | usage
76 | exit 0
77 | ;;
78 | * )
79 | echo "Unknown argument ${1}"
80 | exit 1
81 | ;;
82 | esac
83 | shift
84 | done
85 |
86 | export CUDA_VISIBLE_DEVICES=0
87 | python main.py \
88 | --model_name_or_path $MODEL \
89 | --is_training True \
90 | --data_path $DATA \
91 | --bf16 True \
92 | --output_dir output_qwen/longnews \
93 | --num_train_epochs 3 \
94 | --per_device_train_batch_size 2 \
95 | --per_device_eval_batch_size 1 \
96 | --gradient_accumulation_steps 2 \
97 | --evaluation_strategy "no" \
98 | --save_strategy "steps" \
99 | --save_steps 2000 \
100 | --save_total_limit 10 \
101 | --learning_rate 3e-4 \
102 | --weight_decay 0.1 \
103 | --adam_beta2 0.95 \
104 | --warmup_ratio 0.01 \
105 | --lr_scheduler_type "cosine" \
106 | --logging_steps 1 \
107 | --report_to "none" \
108 | --model_max_length 512 \
109 | --lazy_preprocess True \
110 | --gradient_checkpointing
111 | ```
112 |
113 | **LLM + LoRA + classification head:**
114 |
115 | ```bash
116 | #!/bin/bash
117 | export CUDA_DEVICE_MAX_CONNECTIONS=1
118 |
119 | MODEL="Qwen2.5-0.5B-Instruct"
120 | DATA="data/datasets/longnews"
121 |
122 | function usage() {
123 | echo '
124 | Usage: bash run.sh [-m MODEL_PATH] [-d DATA_PATH]
125 | '
126 | }
127 |
128 | while [[ "$1" != "" ]]; do
129 | case $1 in
130 | -m | --model )
131 | shift
132 | MODEL=$1
133 | ;;
134 | -d | --data )
135 | shift
136 | DATA=$1
137 | ;;
138 | -h | --help )
139 | usage
140 | exit 0
141 | ;;
142 | * )
143 | echo "Unknown argument ${1}"
144 | exit 1
145 | ;;
146 | esac
147 | shift
148 | done
149 |
150 | export CUDA_VISIBLE_DEVICES=0
151 | python main.py \
152 | --model_name_or_path $MODEL \
153 | --is_training True \
154 | --data_path $DATA \
155 | --bf16 True \
156 | --output_dir output_qwen/longnews \
157 | --num_train_epochs 3 \
158 | --per_device_train_batch_size 2 \
159 | --per_device_eval_batch_size 1 \
160 | --gradient_accumulation_steps 2 \
161 | --evaluation_strategy "no" \
162 | --save_strategy "steps" \
163 | --save_steps 2000 \
164 | --save_total_limit 10 \
165 | --learning_rate 3e-4 \
166 | --weight_decay 0.1 \
167 | --adam_beta2 0.95 \
168 | --warmup_ratio 0.01 \
169 | --lr_scheduler_type "cosine" \
170 | --logging_steps 1 \
171 | --report_to "none" \
172 | --model_max_length 512 \
173 | --lazy_preprocess True \
174 | --gradient_checkpointing \
175 | --use_lora
176 | ```
177 |
178 | #### 5. Results
179 |
180 | (1) Qwen full-parameter fine-tuning + classification head
181 |
182 | ```bash
183 | precision recall f1-score support
184 |
185 | 教育 0.98 0.94 0.96 154
186 | 财经 0.95 0.96 0.95 130
187 | 科技 0.95 0.98 0.96 135
188 | 房产 0.99 0.94 0.96 156
189 | 时政 0.90 0.95 0.92 130
190 | 家居 0.95 0.94 0.95 158
191 | 时尚 0.99 1.00 0.99 138
192 |
193 | accuracy 0.96 1001
194 | macro avg 0.96 0.96 0.96 1001
195 | weighted avg 0.96 0.96 0.96 1001
196 | ```
197 |
198 | (2) DeepSeek + LoRA + classification head
199 |
200 | ```bash
201 | precision recall f1-score support
202 |
203 | 教育 0.92 0.89 0.90 154
204 | 财经 0.82 0.93 0.87 130
205 | 科技 0.84 0.96 0.90 135
206 | 房产 0.87 0.88 0.87 156
207 | 时政 0.91 0.79 0.85 130
208 | 家居 0.91 0.76 0.83 158
209 | 时尚 0.92 0.98 0.95 138
210 |
211 | accuracy 0.88 1001
212 | macro avg 0.88 0.88 0.88 1001
213 | weighted avg 0.89 0.88 0.88 1001
214 | ```
215 |
216 | (3) DeepSeek + LoRA + TemporalAttention classification head
217 |
218 | ```bash
219 | precision recall f1-score support
220 |
221 | 教育 0.92 0.87 0.90 154
222 | 财经 0.81 0.90 0.85 130
223 | 科技 0.78 0.97 0.87 135
224 | 房产 0.94 0.85 0.90 156
225 | 时政 0.85 0.87 0.86 130
226 | 家居 0.91 0.74 0.82 158
227 | 时尚 0.94 0.96 0.95 138
228 |
229 | accuracy 0.88 1001
230 | macro avg 0.88 0.88 0.88 1001
231 | weighted avg 0.88 0.88 0.88 1001
232 | ```
233 |
234 | (4) Comparison with [bert-base-chinese classification](https://github.com/Dylan9897/ai-nlp-project/tree/main/TextClassification):
235 |
236 | ```bash
237 | precision recall f1-score support
238 |
239 | 教育 0.97 0.98 0.97 154
240 | 财经 0.97 0.94 0.95 130
241 | 科技 0.97 0.99 0.98 135
242 | 房产 0.95 0.95 0.95 156
243 | 时政 0.95 0.95 0.95 130
244 | 家居 0.96 0.96 0.96 158
245 | 时尚 0.99 1.00 1.00 138
246 |
247 | accuracy 0.97 1001
248 | macro avg 0.97 0.97 0.97 1001
249 | weighted avg 0.97 0.97 0.97 1001
250 | ```
251 |
252 | Results on the THUCNews dataset, with Qwen full-parameter fine-tuning as an example:
253 |
254 | ```bash
255 | precision recall f1-score support
256 |
257 | finance 0.92 0.87 0.89 1000
258 | realty 0.92 0.93 0.92 1000
259 | stocks 0.83 0.84 0.84 1000
260 | education 0.94 0.94 0.94 1000
261 | science 0.83 0.86 0.85 1000
262 | society 0.88 0.91 0.89 1000
263 | politics 0.88 0.88 0.88 1000
264 | sports 0.95 0.94 0.94 1000
265 | game 0.93 0.91 0.92 1000
266 | entertainment 0.90 0.91 0.91 1000
267 |
268 | accuracy 0.90 10000
269 | macro avg 0.90 0.90 0.90 10000
270 | weighted avg 0.90 0.90 0.90 10000
271 | ```
272 |
273 | Comparison with [bert-base-chinese classification](https://github.com/Dylan9897/ai-nlp-project/tree/main/TextClassification):
274 |
275 | ```bash
276 | precision recall f1-score support
277 |
278 | finance 0.92 0.93 0.92 1000
279 | realty 0.96 0.95 0.95 1000
280 | stocks 0.91 0.89 0.90 1000
281 | education 0.96 0.97 0.97 1000
282 | science 0.91 0.90 0.91 1000
283 | society 0.90 0.95 0.93 1000
284 | politics 0.92 0.92 0.92 1000
285 | sports 0.98 0.98 0.98 1000
286 | game 0.97 0.94 0.95 1000
287 | entertainment 0.95 0.97 0.96 1000
288 |
289 | accuracy 0.94 10000
290 | macro avg 0.94 0.94 0.94 10000
291 | weighted avg 0.94 0.94 0.94 10000
292 | ```
293 |
294 | #### 6. FAQ
295 |
296 | **(1) Is the project polished?**
297 |
298 | A: This is a weekly late-night hackathon project; please send us feedback and we will keep improving it.
299 |
300 | **(2) Why not just use the LLM directly?**
301 |
302 | A: A classifier has to output an exact, valid class. An LLM may instead answer something like "**Based on the given content, the category is \*\*\*, well... it depends...**", and writing a parser for that is a hassle.
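
With a classification head, by contrast, the prediction is just an argmax over the fixed label set, so no output parsing is needed. A minimal sketch along the lines of `predict.py` (it assumes a checkpoint trained by `main.py` already exists under `output_qwen/longnews`):

```python
# Minimal sketch, following predict.py: the classification head returns logits over
# the fixed label set, so the prediction is a plain argmax with no text parsing.
# Assumes a trained checkpoint exists under output_qwen/longnews.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("output_qwen/longnews")
tokenizer = AutoTokenizer.from_pretrained("output_qwen/longnews")
classes = [line.strip() for line in open("data/datasets/longnews/class.txt", encoding="utf-8")]

inputs = tokenizer("多地出台楼市调控新政,二手房挂牌量明显上升……", truncation=True, return_tensors="pt")
pred = model(**inputs).logits.argmax(-1).item()
print(classes[pred])  # e.g. 房产
```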
303 |
304 |
--------------------------------------------------------------------------------
/data/datasets/longnews/class.txt:
--------------------------------------------------------------------------------
1 | 教育
2 | 财经
3 | 科技
4 | 房产
5 | 时政
6 | 家居
7 | 时尚
--------------------------------------------------------------------------------
/data/datasets/thucnews/class.txt:
--------------------------------------------------------------------------------
1 | finance
2 | realty
3 | stocks
4 | education
5 | science
6 | society
7 | politics
8 | sports
9 | game
10 | entertainment
--------------------------------------------------------------------------------
/lora_predict.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from sklearn import metrics
4 | import transformers
5 | from safetensors.torch import load_file
6 | from peft import LoraConfig, get_peft_model
7 | from module.adapter import create_and_replace
8 |
9 |
10 |
11 | from transformers import AutoModelForSequenceClassification,AutoTokenizer
12 | from module.argument import ModelArguments,DataArguments,TrainingArguments,LoraArguments
13 |
14 | parser = transformers.HfArgumentParser(
15 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
16 | )
17 |
18 | (
19 | model_args,
20 | data_args,
21 | training_args,
22 | lora_args,
23 | ) = parser.parse_args_into_dataclasses()
24 |
25 | device_map = None
26 |
27 | # Set RoPE scaling factor
28 | config = transformers.AutoConfig.from_pretrained(
29 | model_args.model_name_or_path,
30 | cache_dir=training_args.cache_dir,
31 | trust_remote_code=True,
32 | is_training=model_args.is_training
33 | )
34 | config.use_cache = False
35 | print(f"checkpoint for config is {config}")
36 |
37 | model = AutoModelForSequenceClassification.from_pretrained("output_qwen/longnews",num_labels=7)
38 | tokenizer = AutoTokenizer.from_pretrained("output_qwen/longnews")
39 | model.config.pad_token_id = 151643
40 | model.cuda()
41 |
42 | lora_config = LoraConfig(
43 | r=lora_args.lora_r,
44 | lora_alpha=lora_args.lora_alpha,
45 | target_modules=lora_args.lora_target_modules,
46 | lora_dropout=lora_args.lora_dropout,
47 | bias=lora_args.lora_bias,
48 | # task_type="SEQ_CLS"
49 | task_type="CAUSAL_LM"
50 | )
51 | model = get_peft_model(model, lora_config)
52 | create_and_replace(model)
53 | # print(model)
54 | # s = input()
55 | # Load the saved fine-tuned weights (LoRA adapters and classification head)
56 | weights = load_file("output_qwen/longnews/adapter_model.safetensors")
57 | model.load_state_dict(weights, strict=False)
58 |
59 | y_true = []
60 | y_pred = []
61 | with open("data/datasets/longnews/dev.json","r",encoding="utf-8") as f:
62 | for line in tqdm(f.readlines()):
63 | example = json.loads(line)
64 | content = example["content"]
65 |         label = int(example["label"])
66 | y_true.append(label)
67 | input_demo = tokenizer(content, padding="max_length",truncation=True,return_tensors="pt")
68 | for key in input_demo.keys():
69 | input_demo[key] = input_demo[key].cuda()
70 | output = model(**input_demo)
71 | pred = output.logits.argmax().item()
72 | y_pred.append(pred)
73 |
74 | report = metrics.classification_report(y_true, y_pred)
75 | print(report)
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import transformers
3 | from transformers import GPTQConfig
4 | from datasets import load_dataset
5 | from transformers import AutoModelForSequenceClassification
6 | from module.argument import ModelArguments,DataArguments,TrainingArguments,LoraArguments
7 |
8 | from peft import LoraConfig, get_peft_model
9 | from module.adapter import create_and_replace
10 |
11 | def train(verbose=False):
12 | global local_rank
13 |
14 | parser = transformers.HfArgumentParser(
15 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
16 | )
17 | (
18 | model_args,
19 | data_args,
20 | training_args,
21 | lora_args,
22 | ) = parser.parse_args_into_dataclasses()
23 |
24 |
25 | device_map = None
26 |
27 | # Set RoPE scaling factor
28 | config = transformers.AutoConfig.from_pretrained(
29 | model_args.model_name_or_path,
30 | cache_dir=training_args.cache_dir,
31 | trust_remote_code=True,
32 | is_training=model_args.is_training
33 | )
34 | config.use_cache = False
35 |
36 | model = AutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path , num_labels=7)
37 | model.config.pad_token_id = 151643
38 |
39 | tokenizer = transformers.AutoTokenizer.from_pretrained(
40 | model_args.model_name_or_path,
41 | cache_dir=training_args.cache_dir,
42 | model_max_length=training_args.model_max_length,
43 | padding_side="right",
44 | use_fast=False,
45 | trust_remote_code=True,
46 | )
47 |
48 |
49 |
50 |
51 | # exit()
52 | if training_args.use_lora:
53 | modules_to_save = ["score",'embed_tokens']
54 |
55 | lora_config = LoraConfig(
56 | r=lora_args.lora_r,
57 | lora_alpha=lora_args.lora_alpha,
58 | target_modules=lora_args.lora_target_modules,
59 | lora_dropout=lora_args.lora_dropout,
60 | bias=lora_args.lora_bias,
61 | task_type="CAUSAL_LM",
62 | # task_type="SEQ_CLS",
63 |             modules_to_save=modules_to_save  # also fully train and save these non-LoRA modules (classification head, embeddings)
64 | )
65 | model = get_peft_model(model, lora_config)
66 |
67 | # Print peft trainable params
68 | model.print_trainable_parameters()
69 |
70 |
71 |
72 | for name, param in model.named_parameters():
73 |         # Check whether this parameter should be made trainable
74 | if name.startswith("base_model.model.score."):
75 | print(f"Setting {name} to be updateable.")
76 | param.requires_grad = True
77 | elif name == "base_model.model.model.embed_tokens.weight":
78 | print(f"Setting {name} to be updateable.")
79 | param.requires_grad = True
80 | else:
81 | pass
82 | if model_args.add_adapter:
83 | create_and_replace(model)
84 | if training_args.gradient_checkpointing:
85 | model.enable_input_require_grads()
86 |     # Inspect which parameters will actually be updated
87 | if verbose:
88 | for name, param in model.named_parameters():
89 | if param.requires_grad:
90 | print(f"Parameter Name: {name}, Updateable: True")
91 | else:
92 | print(f"Parameter Name: {name}, Updateable: False")
93 |
94 |
95 |
96 | def process_function(examples):
97 | examples["label"] = [int(unit) for unit in examples["label"]]
98 | return tokenizer(examples["content"], padding="max_length",truncation=True)
99 |
100 |     def load_data():
101 |         # Load the training and validation splits
102 |         dataset = load_dataset("json", data_files={"train": os.path.join(data_args.data_path, "train.json"),
103 |                                                    "valid": os.path.join(data_args.data_path, "dev.json")})
104 |         # Apply the preprocessing function in batches via map (batched=True)
105 |         processed_dataset = dataset.map(process_function, batched=True, batch_size=16)
106 |         # Drop columns that are no longer needed, such as 'content' and 'metadata'
107 |         processed_dataset = processed_dataset.remove_columns(["content", "metadata"])
108 |         return processed_dataset
109 |
110 |     processed_data = load_data()
111 |
112 | trainer = transformers.Trainer(
113 | model=model,
114 | args=training_args,
115 | train_dataset=processed_data["train"],
116 | eval_dataset=processed_data["valid"],
117 |
118 | )
119 | # print(training_args.output_dir)
120 | trainer.train()
121 | trainer.save_state()
122 | trainer.save_model(output_dir=training_args.output_dir)
123 | tokenizer.save_pretrained(training_args.output_dir)
124 |
125 |
126 | if __name__ == "__main__":
127 | train(verbose=True)
128 |
--------------------------------------------------------------------------------
/module/TemporalAttention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class TemporalAttention(nn.Module):
6 | def __init__(self, input_seq_len, output_seq_len, feature_dim):
7 | super(TemporalAttention, self).__init__()
8 | # Define a linear layer to compute attention scores for each time step.
9 | # We assume the same attention mechanism is applied across all features.
10 | self.attention_linear = nn.Linear(feature_dim, 1)
11 | # Output sequence length must be defined in advance.
12 | self.output_seq_len = output_seq_len
13 | # Optionally, you can add a transformation layer if needed.
14 | self.transform = nn.Linear(input_seq_len, output_seq_len)
15 |
16 | def forward(self, x):
17 | x = x.permute(0, 2, 1)
18 | # print(f"checkpoint for TemporalAttention's transform layers params is:")
19 | # print("Weight:", self.transform.weight)
20 | # print("Bias:", self.transform.bias)
21 | batch_size, seq_len, feature_dim = x.size()
22 |
23 | # Compute attention scores (batch_size, seq_len, 1)
24 | attention_scores = self.attention_linear(x).squeeze(-1) # Remove the last dimension
25 |
26 | # Apply softmax along the sequence length to get attention weights (batch_size, seq_len)
27 | attention_weights = F.softmax(attention_scores, dim=-1)
28 |
29 | # Reshape attention weights to match the input dimensions for multiplication (batch_size, seq_len, 1)
30 | attention_weights = attention_weights.unsqueeze(-1)
31 |
32 | # Expand attention weights to match the input dimensions (batch_size, seq_len, feature_dim)
33 | attention_weights = attention_weights.expand_as(x)
34 |
35 | # Apply the attention weights to the input features (batch_size, seq_len, feature_dim)
36 | weighted_input = x * attention_weights
37 |
38 | # Sum over the sequence length dimension to get the attended features (batch_size, feature_dim)
39 | attended_features = weighted_input.sum(dim=1)
40 |
41 | # Transform the attended features to match the desired output sequence length (batch_size, output_seq_len, feature_dim)
42 | output = attended_features.unsqueeze(1).expand(batch_size, self.output_seq_len, feature_dim)
43 | output = output.permute(0, 2, 1)
44 | return output
45 |
46 | if __name__ == '__main__':
47 | # Example usage
48 | batch_size = 2
49 | input_seq_len = 1536
50 | output_seq_len = 7
51 | feature_dim = 512
52 |
53 | x = torch.randn(batch_size, input_seq_len, feature_dim) # Example input tensor with shape (2, 1536, 512)
54 | attention_layer = TemporalAttention(input_seq_len=input_seq_len, output_seq_len=output_seq_len, feature_dim=feature_dim)
55 | output = attention_layer(x) # Output tensor with shape (2, 7, 512)
56 |
57 | print(output.shape)
--------------------------------------------------------------------------------
/module/adapter.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForSequenceClassification
2 | import sys
3 | sys.path.append('/home/root123/workspace/handx/LLM-TextClassification')
4 | from module.TemporalAttention import TemporalAttention
5 |
6 | def create_and_replace(model):
7 | if hasattr(model, 'score'):
8 | target_model = model.score
9 | in_features = target_model.in_features
10 | out_features = target_model.out_features
11 |
12 | new_model = TemporalAttention(input_seq_len=in_features,output_seq_len=out_features,feature_dim=512)
13 |
14 | setattr(model,"score",new_model)
15 | else:
16 | raise Exception("Please confirm whether the name of the layer in the model is correct")
17 |
18 |
19 |
20 | if __name__=="__main__":
21 | model_path = "ckpt/DeepSeek-R1-Distill-Qwen-1___5B"
22 | num_labels = 7
23 | model = AutoModelForSequenceClassification.from_pretrained(model_path,num_labels=num_labels)
24 | create_and_replace(model)
25 | print(model)
26 | # for name, param in model.named_parameters():
27 | # if param.requires_grad:
28 | # print(f"Parameter Name: {name}, Updateable: True")
29 | # else:
30 | # print(f"Parameter Name: {name}, Updateable: False")
--------------------------------------------------------------------------------
/module/argument.py:
--------------------------------------------------------------------------------
1 | # encoding : utf-8 -*-
2 | # @author : 冬瓜
3 | # @mail : dylan_han@126.com
4 | # @Time : 2025/3/2 11:39
9 | import transformers
10 | from typing import Optional, List
11 | from dataclasses import dataclass, field
12 |
13 | @dataclass
14 | class ModelArguments:
15 | model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
16 |     bert_name_or_path: Optional[str] = field(default="bert-base-chinese")  # path to the BERT model files
17 |     is_training: bool = False  # whether running in training mode
18 |     update_bmlp: bool = False  # whether the BERT weights are updated
19 |     add_adapter: bool = False  # whether to swap the classification head for TemporalAttention (used by main.py and run.sh)
20 |
21 | @dataclass
22 | class DataArguments:
23 | data_path: str = field(
24 | default=None, metadata={"help": "Path to the training data."}
25 | )
26 | eval_data_path: str = field(
27 | default=None, metadata={"help": "Path to the evaluation data."}
28 | )
29 | lazy_preprocess: bool = False
30 |
31 |
32 | @dataclass
33 | class TrainingArguments(transformers.TrainingArguments):
34 | cache_dir: Optional[str] = field(default=None)
35 | optim: str = field(default="adamw_torch")
36 | model_max_length: int = field(
37 | default=8192,
38 | metadata={
39 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
40 | },
41 | )
42 | use_lora: bool = False
43 | remove_unused_columns:bool = False
44 |
45 | @dataclass
46 | class LoraArguments:
47 | lora_r: int = 64
48 | lora_alpha: int = 16
49 | lora_dropout: float = 0.05
50 | lora_target_modules: List[str] = field(
51 | default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
52 | )
53 | lora_weight_path: str = ""
54 | lora_bias: str = "none"
55 | q_lora: bool = False
56 |
57 |
--------------------------------------------------------------------------------
/module/others.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.functional import softplus
5 |
6 |
7 | class Mish(nn.Module):
8 | def __init__(self):
9 | super(Mish, self).__init__()
10 | print('Mish activation loaded')
11 |
12 |     def forward(self, x):
13 |         # Mish activation: x * tanh(softplus(x))
14 |         x = x * torch.tanh(softplus(x))
15 |         return x
16 |
17 | class FeedForwardNetwork(nn.Module):
18 | def __init__(self, input_size, output_size,init_method='xavier_uniform',bias=False):
19 | super(FeedForwardNetwork, self).__init__()
20 |         # Define the fully connected layer
21 | if not bias:
22 | self.fc = nn.Linear(input_size, output_size,bias=False)
23 | else:
24 | self.fc = nn.Linear(input_size, output_size)
25 |
26 |
27 |         # Parameter initialization
28 | if init_method == 'xavier_uniform':
29 | nn.init.xavier_uniform_(self.fc.weight)
30 | if bias:
31 |                 self.fc.bias.data.fill_(0)  # initialize the bias to zero
32 | elif init_method == 'kaiming_uniform':
33 | nn.init.kaiming_uniform_(self.fc.weight, nonlinearity='relu')
34 | if bias:
35 |                 self.fc.bias.data.fill_(0)  # initialize the bias to zero
36 | else:
37 | raise ValueError('Unsupported initialization method')
38 |
39 | def forward(self, x):
40 | out = self.fc(x)
41 | return out
42 |
43 | class GRUCell(nn.Module):
44 | def __init__(self, input_size, hidden_size):
45 | super(GRUCell, self).__init__()
46 |
47 |
48 |         # Gating and linear transformation layers
49 | self.input_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)
50 | self.update_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)
51 | self.reset_gate = FeedForwardNetwork(input_size + hidden_size, hidden_size)
52 |
53 |         # Sigmoid and tanh activations
54 | self.sigmoid = nn.Sigmoid()
55 | self.tanh = nn.Tanh()
56 |
57 | def forward(self, input_, hidden_state):
58 | combined = torch.cat((input_, hidden_state), dim=2) # torch.Size([2, 512, 4096])
59 |
60 |         # Compute the update and reset gate signals
61 | z = self.sigmoid(self.update_gate(combined))
62 | r = self.sigmoid(self.reset_gate(combined))
63 |         # Compute the candidate hidden state
64 | h_prime = self.tanh(self.input_gate(torch.cat((input_, r * hidden_state), dim=2)))
65 |         # Update the hidden state
66 | hidden_state = (1 - z) * hidden_state + z * h_prime
67 | return hidden_state
68 |
69 |
70 | class TextCNN(nn.Module):
71 | def __init__(self,in_features,out_features):
72 | super().__init__()
73 | self.mish = Mish()
74 | self.convs = nn.ModuleList([nn.Conv2d(1, 512, (k, in_features)) for k in [2,3,4]])
75 | self.dropout = nn.Dropout(0.5)
76 |
77 | self.fc = FeedForwardNetwork(512 * 3,out_features)
78 |
79 |
80 |     def conv_and_pool(self, x, conv):
81 |         # Convolve, apply Mish, then max-pool over the remaining sequence dimension
82 |         x = self.mish(conv(x)).squeeze(3)
83 |         x = F.max_pool1d(x, x.size(2)).squeeze(2)
84 |         return x
85 |
86 |
87 |     def forward(self, x):
88 |         # Add a channel dimension for Conv2d: (batch, 1, seq_len, in_features)
89 |         out = x.unsqueeze(1)
90 |         out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
91 |         out = self.dropout(out)
92 |         out = self.fc(out)
93 |         # out = torch.index_select(out, dim=1, index=indices)
94 |         return out
95 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from sklearn import metrics
4 | from transformers import AutoModelForSequenceClassification,AutoTokenizer
5 |
6 | model = AutoModelForSequenceClassification.from_pretrained("output_qwen/thucnews")
7 | tokenizer = AutoTokenizer.from_pretrained("output_qwen/thucnews")
8 | model.cuda()
9 |
10 |
11 |
12 |
13 | y_true = []
14 | y_pred = []
15 | with open("data/datasets/thucnews/test.json","r",encoding="utf-8") as f:
16 | for line in tqdm(f.readlines()):
17 | example = json.loads(line)
18 | content = example["content"]
19 |         label = int(example["label"])
20 | y_true.append(label)
21 |
22 | input_demo = tokenizer(content, padding="max_length",truncation=True,return_tensors="pt")
23 |
24 | for key in input_demo.keys():
25 | input_demo[key] = input_demo[key].cuda()
26 |
27 | output = model(**input_demo)
28 |
29 | pred = output.logits.argmax().item()
30 |
31 |
32 | y_pred.append(pred)
33 |
34 | columns = open("data/datasets/thucnews/class.txt","r",encoding="utf-8").readlines()
35 | columns = [x.strip("\n") for x in columns]
36 |
37 | report = metrics.classification_report(y_true, y_pred, target_names=columns)
38 | print(report)
39 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | peft
2 | transformers
3 | torch
4 | datasets
5 | safetensors
6 | scikit-learn
7 | tqdm
8 | modelscope
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 |
4 | MODEL="ckpt/DeepSeek-R1-Distill-Qwen-1___5B" # Set the path if you do not want to load from huggingface directly
5 | # ATTENTION: specify the path to your data directory, which should contain train.json and dev.json
6 | # in the JSON-lines format described in the README's dataset section.
7 | DATA="data/datasets/longnews"
8 |
9 | function usage() {
10 | echo '
11 | Usage: bash run.sh [-m MODEL_PATH] [-d DATA_PATH]
12 | '
13 | }
14 |
15 | while [[ "$1" != "" ]]; do
16 | case $1 in
17 | -m | --model )
18 | shift
19 | MODEL=$1
20 | ;;
21 | -d | --data )
22 | shift
23 | DATA=$1
24 | ;;
25 | -h | --help )
26 | usage
27 | exit 0
28 | ;;
29 | * )
30 | echo "Unknown argument ${1}"
31 | exit 1
32 | ;;
33 | esac
34 | shift
35 | done
36 |
37 | export CUDA_VISIBLE_DEVICES=0
38 | python lora_predict.py \
39 | --model_name_or_path $MODEL \
40 | --is_training True \
41 | --add_adapter True \
42 | --data_path $DATA \
43 | --bf16 True \
44 |     --output_dir output_qwen/longnews \
45 | --num_train_epochs 3 \
46 | --per_device_train_batch_size 2 \
47 | --per_device_eval_batch_size 1 \
48 | --gradient_accumulation_steps 2 \
49 | --evaluation_strategy "no" \
50 | --save_strategy "steps" \
51 | --save_steps 2000 \
52 | --save_total_limit 10 \
53 | --learning_rate 5e-4 \
54 | --weight_decay 0.1 \
55 | --adam_beta2 0.95 \
56 | --warmup_ratio 0.01 \
57 | --warmup_steps 1000 \
58 | --lr_scheduler_type "cosine" \
59 | --logging_steps 1 \
60 | --report_to "none" \
61 | --model_max_length 512 \
62 | --lazy_preprocess True \
63 | --gradient_checkpointing \
64 | --use_lora
65 |
66 |
--------------------------------------------------------------------------------
/scripts/download_model.py:
--------------------------------------------------------------------------------
1 | # encoding : utf-8 -*-
2 | # @author : 冬瓜
3 | # @mail : dylan_han@126.com
4 | # @Time : 2025/3/2 11:49
5 |
6 | from modelscope import snapshot_download
7 | snapshot_download('Qwen/Qwen2.5-0.5B-Instruct',cache_dir="./ckpt")
8 | snapshot_download('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',cache_dir="./ckpt")
--------------------------------------------------------------------------------