├── README.md ├── all_embd_to_openai.py ├── all_to_tool_call ├── README.md ├── all_to_tool_call.py └── test.ipynb ├── dapo_from_scratch ├── dapo_loss.png ├── grpo_loss.png ├── reward_func.py ├── test.py └── train.py ├── date_modify.ipynb ├── deep_research ├── client.py ├── prompts.py ├── search_mcp.py ├── searxng │ ├── docker-compose.yaml │ └── settings.yml ├── test.log └── test.md ├── deepseek_learn ├── 1.png ├── 2.png ├── 3.png ├── MLA.py ├── MTP_train │ ├── MTP.py │ ├── loss.png │ └── test.py ├── deepseek_r1_train │ ├── deepseek-r1解读.md │ ├── deepseek_r1_train.py │ └── deepseek_reward.png └── test.ipynb ├── gradio_mcp_client.py ├── grpo_from_scratch ├── grpo_loss.png ├── reward_func.py ├── test.py └── train.py ├── knowledge_distillation_llm ├── README.md ├── dataset.py ├── images │ ├── fkl.png │ ├── fkl_formula.png │ ├── rkl.png │ └── rkl_formula.png ├── train.py └── utils.py ├── pdf2markdown ├── README.md └── pdf2markdown.py ├── ppo_from_scratch ├── ppo.png └── ppo_train.py ├── rag_demo ├── medical_data.txt └── rag.ipynb ├── s1_from_scratch ├── README.md ├── generate.py ├── s1_train.py └── training_loss.png ├── table_extract ├── README.md ├── chinese_cht.ttf ├── imgs │ ├── 2.jpg │ └── extract_2.png └── table2txt.ipynb ├── table_rag.ipynb ├── train_llm_from_scratch ├── README.md ├── dataset.py ├── dpo.png ├── dpo_train.py ├── screenshot-20241207-093824.png ├── sft_train.py ├── test_llm.ipynb ├── tokenizer │ ├── merges.txt │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json ├── train.ipynb ├── train.py ├── train_tokenizer.ipynb ├── trainer_state_pretrain.json └── trainer_state_sft.json ├── train_moe_from_scratch ├── 1.ipynb ├── README.md ├── dataset.py ├── moe_sft_train.py ├── moe_test.py ├── moe_train.py ├── screenshot-20241207-093824.png ├── sft.jsonl ├── tokenizer │ ├── merges.txt │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json └── train.jsonl ├── train_multimodal_from_scratch ├── README.md ├── gradio_vlm.py ├── sft_train.py ├── test.ipynb ├── test.py ├── train.py └── trainer.ipynb └── train_siglip_from_scratch ├── clip.png ├── data_process.ipynb ├── dataset.py ├── model.py ├── siglip.png ├── test.ipynb ├── train.py └── 伪代码.png /README.md: -------------------------------------------------------------------------------- 1 | # 复现各种大模型相关算法 2 | -------------------------------------------------------------------------------- /all_embd_to_openai.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from contextlib import asynccontextmanager 4 | from typing import List, Optional 5 | from langchain_community.embeddings import OpenVINOBgeEmbeddings 6 | import click 7 | import uvicorn 8 | import tiktoken 9 | from fastapi import FastAPI 10 | from fastapi.responses import JSONResponse, Response 11 | from langchain_community.embeddings import OpenVINOBgeEmbeddings 12 | from pydantic import BaseModel 13 | 14 | encoder = tiktoken.get_encoding("cl100k_base") 15 | 16 | class EmbeddingRequest(BaseModel): 17 | #输入可以是字符串、字符串列表、至于为什么要加上List[List[int]],因为在集成maxkb的过程中发现,其调用向量模型传过来的参数中文字是经过tiktoken编码的。 18 | input: str|List[str]|List[List[int]] 19 | model: str 20 | 21 | TIMEOUT_KEEP_ALIVE = 5 # seconds. 
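# Hypothetical usage sketch (illustration only, not part of the original file): once the
# server below is running, any OpenAI-compatible client can call the /v1/embeddings route
# it registers. The base_url assumes the default host/port from entrypoint(); the model
# name is arbitrary because the server simply echoes it back in the response.
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
#   resp = client.embeddings.create(model="bge", input=["如何更换花呗绑定银行卡"])
#   print(len(resp.data[0].embedding))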
22 | 23 | class OpenaiServer: 24 | 25 | def __init__(self, 26 | embedding_model_path): 27 | 28 | # 可在此处修改为自己的模型,可以通过任意方式加载(huggingface,langchain,sentence-transformers等), 29 | # 此处为了加速使用了openvino模型,可根据需要自行修改为自己的模型 30 | self.model = OpenVINOBgeEmbeddings( 31 | model_name_or_path=embedding_model_path, 32 | model_kwargs={"device": "CPU"}, 33 | encode_kwargs={"normalize_embeddings": True}, 34 | ) 35 | 36 | @asynccontextmanager 37 | async def lifespan(app: FastAPI): 38 | # terminate rank0 worker 39 | yield 40 | 41 | self.app = FastAPI(lifespan=lifespan) 42 | 43 | 44 | self.register_routes() 45 | def register_routes(self): 46 | self.app.add_api_route("/health", self.health, methods=["GET"]) 47 | self.app.add_api_route("/v1/embeddings", 48 | self.get_embeddings, 49 | methods=["POST"]) 50 | 51 | async def health(self) -> Response: 52 | return Response(status_code=200) 53 | async def get_embeddings(self, request: EmbeddingRequest) -> Response: 54 | 55 | data = [] 56 | if isinstance(request.input, List): 57 | if isinstance(request.input[0], str): 58 | # 修改完模型后,可根据向量模型的具体推理方式修改如下方法 59 | # 此处为langchain加载的向量模型所使用的推理方法embed_documents和embed_query,分别针对列表和字符串 60 | # 如使用sentence_transformers,推理方法如下: 61 | # from sentence_transformers import SentenceTransformer 62 | # model = SentenceTransformer("shibing624/text2vec-base-chinese") 63 | # sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'] 64 | # sentence_embeddings = model.encode(sentences) 65 | 66 | embedding = self.model.embed_documents(request.input) 67 | for i, item in enumerate(embedding): 68 | tmp = { 69 | "object": "embedding", 70 | "embedding": item, 71 | "index": i 72 | } 73 | data.append(tmp) 74 | elif isinstance(request.input[0], List): 75 | # 将tiktoken编码的文本转会文本 76 | text_list = [encoder.decode(item) for item in request.input] 77 | embedding = self.model.embed_documents(text_list) 78 | for i, item in enumerate(embedding): 79 | tmp = { 80 | "object": "embedding", 81 | "embedding": item, 82 | "index": i 83 | } 84 | data.append(tmp) 85 | else: 86 | # 87 | embedding = self.model.embed_query(request.input) 88 | tmp = { 89 | "object": "embedding", 90 | "embedding": embedding, 91 | "index": 0 92 | } 93 | data.append(tmp) 94 | 95 | 96 | res = { 97 | "object": "list", 98 | "data": data, 99 | "model": request.model, 100 | "usage": { 101 | "prompt_tokens": 0, 102 | "total_tokens": 0 103 | } 104 | } 105 | return JSONResponse(content=res) 106 | 107 | async def __call__(self, host, port): 108 | config = uvicorn.Config(self.app, 109 | host=host, 110 | port=port, 111 | log_level="info", 112 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE) 113 | await uvicorn.Server(config).serve() 114 | @click.command() 115 | @click.argument("model_dir") 116 | @click.option("--host", type=str, default=None) 117 | @click.option("--port", type=int, default=8000) 118 | def entrypoint(model_dir, 119 | host: Optional[str] = None, 120 | port: int = 8000): 121 | host = host or "0.0.0.0" 122 | port = port or 8000 123 | logging.info(f"Starting server at {host}:{port}") 124 | 125 | server = OpenaiServer(embedding_model_path=model_dir) 126 | 127 | asyncio.run(server(host, port)) 128 | 129 | if __name__ == "__main__": 130 | entrypoint() 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /all_to_tool_call/README.md: -------------------------------------------------------------------------------- 1 | ## 大模型api为什么需要具备工具调用能力? 
2 | 大多数的agent框架或者mcp客户端需要模型api支持工具调用能力 3 | 4 | ## 直接工具调用和间接工具调用 5 | - 直接工具调用 6 | 7 | api返回的结果中有字段专门存储工具调用的结果 8 | - 间接工具调用 9 | 10 | 在提示词中让模型按照给定格式返回工具调用的结果,然后从content中解析出工具参数和名称 11 | ``` 12 | ChatCompletion(id='0196bea6713a7620552a143e3aa91f93', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0196bea67535a49350b2ab4b41a7e588', function=Function(arguments='{"location": "北京市"}', name='get_current_weather'), type='function')]))], created=1746955301, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint='', usage=CompletionUsage(completion_tokens=22, prompt_tokens=273, total_tokens=295, completion_tokens_details=None, prompt_tokens_details=None)) 13 | ``` 14 | 15 | ## 实现流程 16 | ```mermaid 17 | graph TD 18 | A[用户请求] --> B{大模型推理} 19 | B --> C[生成调用决策] 20 | C --> D{小模型解析} 21 | D --> E[执行工具调用] 22 | E --> F[返回结构化结果] 23 | ``` -------------------------------------------------------------------------------- /all_to_tool_call/all_to_tool_call.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from pydantic import BaseModel 3 | import uvicorn 4 | import requests 5 | import argparse 6 | from typing import Optional, List 7 | from openai import OpenAI 8 | # 创建FastAPI应用实例 9 | app = FastAPI() 10 | 11 | class ChatCompletionRequest(BaseModel): 12 | 13 | model: str 14 | messages: List 15 | max_tokens: int = 4096 16 | temperature: float = 0.7 17 | tools: Optional[List] = None 18 | 19 | parser = argparse.ArgumentParser(description="启动模型服务代理") 20 | parser.add_argument('--no_tool_call_base_url', type=str, default="https://api.siliconflow.cn/v1") 21 | parser.add_argument('--no_tool_call_model_name', type=str, default="Qwen/Qwen2.5-Coder-32B-Instruct") 22 | parser.add_argument('--no_tool_call_api_key', type=str, default="sk-") 23 | parser.add_argument('--tool_call_base_url', type=str, default="https://api.siliconflow.cn/v1") 24 | parser.add_argument('--tool_call_model_name', type=str, default="Qwen/Qwen2.5-7B-Instruct") 25 | parser.add_argument('--tool_call_api_key', type=str, default="sk-") 26 | parser.add_argument('--host', type=str, default="0.0.0.0") 27 | parser.add_argument('--port', type=int, default=8888) 28 | args = parser.parse_args() 29 | 30 | def generate_text(base_url: str, model: str, messages: List, max_tokens: int, temperature: float, api_key: str, tools=None): 31 | 32 | client = OpenAI(base_url=base_url, api_key=api_key) 33 | 34 | 35 | 36 | completion = client.chat.completions.create( 37 | model=model, 38 | temperature=temperature, 39 | max_tokens=max_tokens, 40 | messages=messages, 41 | tools=tools, 42 | ) 43 | return completion 44 | 45 | 46 | # 定义路由和处理函数,与OpenAI API兼容 47 | @app.post("/v1/chat/completions") 48 | async def create_chat_completion(request: ChatCompletionRequest): 49 | 50 | if request.tools: 51 | TOOL_EAXMPLE = "You will receive a JSON string containing a list of callable tools. Please parse this JSON string and return a JSON object containing the tool name and tool parameters." 52 | 53 | REUTRN_FORMAT="{\"tool\": \"tool name\", \"parameters\": {\"parameter name\": \"parameter value\"}}" 54 | 55 | INSTRUCTION = f""" 56 | {TOOL_EAXMPLE} 57 | Answer the following questions as best you can. 
58 | 59 | Use the following format: 60 | ```tool_json 61 | {REUTRN_FORMAT} 62 | ``` 63 | 64 | Please choose the appropriate tool according to the user's question. If you don't need to call it, please reply directly to the user's question. When the user communicates with you in a language other than English, you need to communicate with the user in the same language. 65 | 66 | When you have enough information from the tool results, respond directly to the user with a text message without having to call the tool again. 67 | 68 | You can use the following tools: 69 | {request.tools} 70 | """ 71 | messages = [{"role": "system", "content": INSTRUCTION}] 72 | messages += request.messages 73 | response = generate_text(args.no_tool_call_base_url, args.no_tool_call_model_name, messages, request.max_tokens, request.temperature, args.no_tool_call_api_key) 74 | response = response.choices[0].message.content 75 | print(response) 76 | messages = [{"role": "system", "content": "Answer the initial based on the directly."}] 77 | print(request.messages[-1]['content']) 78 | messages += [{"role": "user", "content": f"\n{request.messages[-1]['content']}\n\n\n{response}\n"} 79 | ] 80 | response = generate_text(args.tool_call_base_url, args.tool_call_model_name, messages, request.max_tokens, request.temperature, args.tool_call_api_key, tools=request.tools) 81 | 82 | else: 83 | response = generate_text(args.no_tool_call_base_url, args.no_tool_call_model_name, request.messages, request.max_tokens, request.temperature, args.no_tool_call_api_key) 84 | 85 | return response 86 | 87 | # 启动FastAPI应用,使用命令行参数指定的端口 88 | if __name__ == "__main__": 89 | uvicorn.run(app, host=args.host, port=args.port) -------------------------------------------------------------------------------- /all_to_tool_call/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "ChatCompletion(id='0196bec84e9e399c781eb23b425fd788', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0196bec8536338780ca82d04571819e9', function=Function(arguments='{\"location\": \"北京市\"}', name='get_current_weather'), type='function')]))], created=1746957520, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint='', usage=CompletionUsage(completion_tokens=22, prompt_tokens=273, total_tokens=295, completion_tokens_details=None, prompt_tokens_details=None))" 12 | ] 13 | }, 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "from openai import OpenAI\n", 21 | "from datetime import datetime\n", 22 | "import json\n", 23 | "import os\n", 24 | "import random\n", 25 | "\n", 26 | "# client = OpenAI(\n", 27 | "# api_key='sk-zfaguzfmjrruybpjgwaxabwytcdgwrvrcsldmxigrsmolpyt',\n", 28 | "# base_url=\"https://api.siliconflow.cn/v1\",\n", 29 | "# )\n", 30 | "\n", 31 | "client = OpenAI(\n", 32 | " api_key='yyy',\n", 33 | " base_url=\"http://10.250.2.24:8888/v1\",\n", 34 | ")\n", 35 | "\n", 36 | "tools = [\n", 37 | " # 工具1 获取当前时刻的时间\n", 38 | " {\n", 39 | " \"type\": \"function\",\n", 40 | " \"function\": {\n", 41 | " \"name\": \"get_current_time\",\n", 42 | " \"description\": 
\"当你想知道现在的时间时非常有用。\",\n", 43 | " \"parameters\": {},\n", 44 | " },\n", 45 | " },\n", 46 | " # 工具2 获取指定城市的天气\n", 47 | " {\n", 48 | " \"type\": \"function\",\n", 49 | " \"function\": {\n", 50 | " \"name\": \"get_current_weather\",\n", 51 | " \"description\": \"当你想查询指定城市的天气时非常有用。\",\n", 52 | " \"parameters\": {\n", 53 | " \"type\": \"object\",\n", 54 | " \"properties\": {\n", 55 | " # 查询天气时需要提供位置,因此参数设置为location\n", 56 | " \"location\": {\n", 57 | " \"type\": \"string\",\n", 58 | " \"description\": \"城市或县区,比如北京市、杭州市、余杭区等。\",\n", 59 | " }\n", 60 | " },\n", 61 | " \"required\": [\"location\"],\n", 62 | " },\n", 63 | " },\n", 64 | " },\n", 65 | "]\n", 66 | "\n", 67 | "\n", 68 | "# 封装模型响应函数\n", 69 | "def get_response(query):\n", 70 | " messages = [\n", 71 | " {\n", 72 | " \"role\": \"system\",\n", 73 | " \"content\": \"You are a helpful assistant.\"\n", 74 | " },\n", 75 | " {\n", 76 | " \"role\": \"user\",\n", 77 | " \"content\": query\n", 78 | " }\n", 79 | " ]\n", 80 | " completion = client.chat.completions.create(\n", 81 | " model=\"Qwen/Qwen2.5-Coder-32B-Instruct\",\n", 82 | " messages=messages,\n", 83 | " tools=tools,\n", 84 | " )\n", 85 | " return completion\n", 86 | "\n", 87 | "get_response('北京天气怎么样')" 88 | ] 89 | } 90 | ], 91 | "metadata": { 92 | "kernelspec": { 93 | "display_name": "wyf", 94 | "language": "python", 95 | "name": "python3" 96 | }, 97 | "language_info": { 98 | "codemirror_mode": { 99 | "name": "ipython", 100 | "version": 3 101 | }, 102 | "file_extension": ".py", 103 | "mimetype": "text/x-python", 104 | "name": "python", 105 | "nbconvert_exporter": "python", 106 | "pygments_lexer": "ipython3", 107 | "version": "3.11.0" 108 | } 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 2 112 | } 113 | -------------------------------------------------------------------------------- /dapo_from_scratch/dapo_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/dapo_from_scratch/dapo_loss.png -------------------------------------------------------------------------------- /dapo_from_scratch/grpo_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/dapo_from_scratch/grpo_loss.png -------------------------------------------------------------------------------- /dapo_from_scratch/reward_func.py: -------------------------------------------------------------------------------- 1 | import re 2 | def extract_answer(text): 3 | answer = text.split("")[-1] 4 | answer = answer.split("")[0] 5 | return answer.strip() 6 | 7 | def mark_num(text): 8 | reward = 0 9 | if text.count("\n") == 1: 10 | reward += 0.125 11 | 12 | if text.count("\n") == 1: 13 | reward += 0.125 14 | 15 | if text.count("\n") == 1: 16 | reward += 0.125 17 | 18 | if text.count("\n") == 1: 19 | reward += 0.125 20 | return reward 21 | 22 | # 生成答案是否正确的奖励 23 | def correctness_reward(prompts, responses, answers): 24 | 25 | extracted_responses = [extract_answer(r) for r in responses] 26 | print(f"问题:\n{prompts[0]}", f"\n答案:\n{answers[0]}", f"\n模型输出:\n{responses[0]}", f"\n提取后的答案:\n{extracted_responses[0]}") 27 | return [2.0 if response == str(ans) else 0.0 for response, ans in zip(extracted_responses, answers)] 28 | 29 | # 生成答案是否是数字的奖励(单纯依赖结果是否正确进行奖励,条件很苛刻,会导致奖励比较稀疏,模型难以收敛,所以加上答案是否是数字的奖励,虽然答案错误,但是至少生成的是数字(对于数学问题),也要给予适当奖励) 30 | def digit_reward(prompts, responses, answers): 31 | 
extracted_responses = [extract_answer(r) for r in responses] 32 | return [0.5 if response.isdigit() else 0.0 for response in extracted_responses] 33 | 34 | # 格式奖励 35 | def hard_format_reward(prompts, responses, answers): 36 | pattern = r"^\n.*?\n\n\n.*?\n\n$" 37 | matches = [re.match(pattern, response) for response in responses] 38 | return [0.5 if match else 0.0 for match in matches] 39 | 40 | # 标记奖励(改善格式奖励稀疏问题) 41 | def mark_reward(prompts, responses, answers): 42 | return [mark_num(response) for response in responses] -------------------------------------------------------------------------------- /dapo_from_scratch/test.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | client = OpenAI(api_key='ww', base_url='http://10.250.2.24:8036/v1') 3 | 4 | SYSTEM_PROMPT = """ 5 | 按照如下格式回答问题: 6 | 7 | 你的思考过程 8 | 9 | 10 | 你的回答 11 | 12 | """ 13 | 14 | completion = client.chat.completions.create( 15 | model = 'qwen1.5b', 16 | 17 | temperature=0.0, 18 | logprobs = True, 19 | messages=[ 20 | { 21 | "role": "system", 22 | "content": SYSTEM_PROMPT}, 23 | { 24 | "role": "user", 25 | "content": "天上五只鸟,地上五只鸡,一共几只鸭", 26 | } 27 | ], 28 | ) 29 | print(completion.choices[0].message.content) 30 | -------------------------------------------------------------------------------- /date_modify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from langchain_openai import ChatOpenAI\n", 10 | "llm = ChatOpenAI(temperature=0, model=\"qwen2\", api_key=\"nn\", base_url='http://***/v1')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "llm.invoke('去年中秋节是哪天?').content" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "functions = [{\n", 29 | " 'name': 'get_item_info',\n", 30 | " 'description': '获取时间范围内某个产品的信息',\n", 31 | " 'parameters': {\n", 32 | " 'type': 'object',\n", 33 | " 'properties': {\n", 34 | " 'item': {\n", 35 | " 'type': 'string',\n", 36 | " 'description':\n", 37 | " '产品名称',\n", 38 | " },\n", 39 | " 'start': {\n", 40 | " 'type': 'string',\n", 41 | " 'description':\n", 42 | " '时间范围的起始时间'\n", 43 | " },\n", 44 | " 'end': {\n", 45 | " 'type': 'string',\n", 46 | " 'description':\n", 47 | " '时间范围的结束时间'\n", 48 | " },\n", 49 | " },\n", 50 | " 'required': ['event', 'start', 'end'],\n", 51 | " }\n", 52 | " }]" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "llm = llm.bind_tools(functions)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "llm.invoke('2023年1月12至2月1日笔记本的销量')" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "llm.invoke('去年九月份笔记本的销量')" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "import jionlp\n", 89 | "query = '去年中秋节的营收多少'\n", 90 | "def date_extract(query):\n", 91 | " date = jionlp.parse_time(query)\n", 92 | " return date['time']\n", 93 | "date = date_extract(query)\n", 94 | "print(date)" 95 | ] 96 | }, 97 | { 98 | 
"cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "date_modify_prompt = '''\n", 104 | "你是一个日期转换助手,请按照如下格式对用户输入进行转化。\n", 105 | "\n", 106 | "用户输入:\n", 107 | "2023年端午节的订单销量是多少\n", 108 | "\n", 109 | "日期:\n", 110 | "['2023-06-22 00:00:00', '2023-06-22 23:59:59']\n", 111 | "\n", 112 | "输出:\n", 113 | "2023-06-22 00:00:00到2023-06-22 23:59:59的订单销量是多少\n", 114 | "\n", 115 | "用户输入:\n", 116 | "2022年中秋节的营收多少\n", 117 | "\n", 118 | "日期:\n", 119 | "['2022-09-10 00:00:00', '2022-09-10 23:59:59']\n", 120 | "\n", 121 | "输出:\n", 122 | "2022-09-10 00:00:00到2022-09-10 23:59:59营收多少\n", 123 | "\n", 124 | "用户输入:\n", 125 | "{}\n", 126 | "\n", 127 | "日期:\n", 128 | "{}\n", 129 | "\n", 130 | "输出:\n", 131 | "'''" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "from langchain_openai import ChatOpenAI\n", 141 | "llm = ChatOpenAI(temperature=0, model=\"qwen2\", api_key=\"nn\", base_url='http://10.250.2.23:8600/v1')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "query = '去年第一季度的销量怎么样'\n", 151 | "modify_query = llm.invoke(date_modify_prompt.format(query, date_extract(query))).content\n", 152 | "print(modify_query)" 153 | ] 154 | } 155 | ], 156 | "metadata": { 157 | "kernelspec": { 158 | "display_name": "base", 159 | "language": "python", 160 | "name": "python3" 161 | }, 162 | "language_info": { 163 | "name": "python", 164 | "version": "undefined.undefined.undefined" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /deep_research/client.py: -------------------------------------------------------------------------------- 1 | from mcp import ClientSession, StdioServerParameters 2 | from mcp.client.stdio import stdio_client 3 | from typing import Optional 4 | from openai import AsyncOpenAI 5 | from contextlib import AsyncExitStack 6 | import json 7 | import asyncio 8 | import os 9 | from prompts import * 10 | from search_mcp import logger 11 | 12 | 13 | base_url = "https://openrouter.ai/api/v1" 14 | api_key = 'aaa' 15 | model_name = 'deepseek/deepseek-chat:free' 16 | def get_clear_json(text): 17 | if '```json' not in text: 18 | return 0, text 19 | 20 | return 1, text.split('```json')[1].split('```')[0] 21 | 22 | class MCPClient: 23 | def __init__(self): 24 | self.session: Optional[ClientSession] = None 25 | self.exit_stack = AsyncExitStack() 26 | self.client = AsyncOpenAI( 27 | base_url=base_url, 28 | api_key=api_key, 29 | ) 30 | 31 | async def connect_to_server(self, server_script_path: str): 32 | server_params = StdioServerParameters( 33 | command="python", 34 | args=[server_script_path], 35 | env=None 36 | ) 37 | 38 | stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) 39 | self.stdio, self.write = stdio_transport 40 | self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) 41 | 42 | await self.session.initialize() 43 | 44 | # 列出可用工具 45 | response = await self.session.list_tools() 46 | tools = response.tools 47 | logger.info(f"\nConnected to server with tools: {[tool.name for tool in tools]}") 48 | 49 | 50 | async def process_query(self, query: str) -> str: 51 | """使用 LLM 和 MCP 服务器提供的工具处理查询""" 52 | 53 | response = await self.session.list_tools() 54 | 55 | available_tools 
= [{ 56 | "type": "function", 57 | "function": { 58 | "name": tool.name, 59 | "description": tool.description, 60 | "parameters": tool.inputSchema 61 | } 62 | } for tool in response.tools] 63 | logger.info(f'available_tools:\n\n{available_tools}') 64 | 65 | messages = [ 66 | { 67 | "role": "system", 68 | "content": SYSTEM_PROMPT + str(available_tools) 69 | }, 70 | { 71 | "role": "user", 72 | "content": query 73 | } 74 | ] 75 | response = await self.client.chat.completions.create( 76 | model=model_name, 77 | messages=messages 78 | ) 79 | 80 | message = response.choices[0].message 81 | logger.info(f'llm_output(tool call):{message.content}') 82 | 83 | results = [] 84 | while True: 85 | 86 | flag, json_text = get_clear_json(message.content) 87 | 88 | if flag == 0: 89 | response = await self.client.chat.completions.create( 90 | model=model_name, 91 | messages=[{"role": "user", "content": query}] 92 | ) 93 | return response.choices[0].message.content 94 | 95 | json_text = json.loads(json_text) 96 | tool_name = json_text['name'] 97 | tool_args = json_text['params'] 98 | result = await self.session.call_tool(tool_name, tool_args) 99 | logger.info(f'tool name: \n{tool_name}\ntool call result: \n{result}') 100 | results.append(result.content[0].text) 101 | 102 | messages.append({ 103 | "role": "assistant", 104 | "content": message.content 105 | }) 106 | messages.append({ 107 | "role": "user", 108 | "content": f'工具调用结果如下:{result}' 109 | }) 110 | 111 | messages.append({ 112 | "role": "user", 113 | "content": NEXT_STEP_PROMPT.format(query) 114 | }) 115 | 116 | response = await self.client.chat.completions.create( 117 | model=model_name, 118 | messages=messages 119 | ) 120 | 121 | message = response.choices[0].message 122 | logger.info(f'llm_output:\n{message.content}') 123 | 124 | if 'finish' in message.content: 125 | break 126 | 127 | messages.append({ 128 | "role": "assistant", 129 | "content": message.content 130 | }) 131 | 132 | 133 | 134 | 135 | messages.append({ 136 | "role": "user", 137 | "content": FINISH_GENETATE.format('\n\n'.join(results), query) 138 | }) 139 | 140 | response = await self.client.chat.completions.create( 141 | model=model_name, 142 | messages=messages 143 | ) 144 | 145 | message = response.choices[0].message.content 146 | return message 147 | 148 | async def chat_loop(self): 149 | """运行交互式聊天循环""" 150 | logger.info("\nMCP Client Started!") 151 | logger.info("Type your queries or 'quit' to exit.") 152 | 153 | while True: 154 | try: 155 | query = input("\nQuery: ").strip() 156 | if query.lower() == 'quit': 157 | break 158 | response = await self.process_query(query) 159 | print(response) 160 | except Exception as e: 161 | logger.error(f"\nError: {str(e)}") 162 | 163 | 164 | 165 | async def main(): 166 | 167 | client = MCPClient() 168 | 169 | 170 | await client.connect_to_server('./search_mcp.py') 171 | 172 | await client.chat_loop() 173 | 174 | 175 | 176 | if __name__ == "__main__": 177 | asyncio.run(main()) -------------------------------------------------------------------------------- /deep_research/prompts.py: -------------------------------------------------------------------------------- 1 | SYSTEM_PROMPT = """你是一个报告生成助手。 2 | 你可以使用 MCP 服务器提供的工具来完成任务。 3 | MCP 服务器会动态提供工具,你需要先检查当前可用的工具。 4 | 5 | 在使用 MCP 工具时,请遵循以下步骤: 6 | 1、根据任务需求选择合适的工具 7 | 2、按照工具的参数要求提供正确的参数 8 | 3、观察工具的返回结果,并根据结果决定下一步操作 9 | 4、工具可能会发生变化,比如新增工具或现有工具消失 10 | 11 | 请遵循以下指南: 12 | - 使用工具时,确保参数符合工具的文档要求 13 | - 如果出现错误,请理解错误原因并尝试用修正后的参数重新调用 14 | - 按照任务需求逐步完成,优先选择最合适的工具 15 | - 如果需要连续调用多个工具,请一次只调用一个工具并等待结果 16 | - 
以```json```格式输出。例如:```json{"name": "tool_name", "params": {"param1": "value1", "param2": "value2"}}``` 17 | 18 | 请清楚地向用户解释你的推理过程和操作步骤。 19 | 20 | 可选择的工具如下: 21 | """ 22 | 23 | 24 | NEXT_STEP_PROMPT = """ 25 | ## 任务目标 26 | 根据已经获取的信息,判断是否可以解决用户的需求。 27 | 28 | ## 任务要求 29 | - 请认真审视用户的需求,特别注意用户需求中的条件和范围 30 | - 如果可以解决(满足用户给出的条件和范围),请输出 31 | - 如果缺少数据或内容,请继续调用合适的工具获取更多信息 32 | 33 | ## 用户需求 34 | 用户需求如下: 35 | {} 36 | """ 37 | 38 | 39 | FINISH_GENETATE = ''' 40 | ## 任务目标 41 | 根据已收集信息和用户需求生成完整报告。 42 | 43 | ## 已收集信息 44 | {} 45 | 46 | ## 任务要求 47 | 1、请根据图片的描述将图片链接插入到合适的位置,如果没有符合要求的图片,请不要插入图片 48 | 2、以markdown格式生成报告 49 | 50 | ## 用户需求 51 | {}''' -------------------------------------------------------------------------------- /deep_research/search_mcp.py: -------------------------------------------------------------------------------- 1 | from mcp.server.fastmcp import FastMCP 2 | import pymysql 3 | import requests 4 | from openai import OpenAI 5 | mcp = FastMCP("search") 6 | import pymysql 7 | import pandas as pd 8 | from prompts import * 9 | import logging 10 | 11 | 12 | 13 | base_url = "https://openrouter.ai/api/v1" 14 | api_key = 'aaa' 15 | model_name = 'deepseek/deepseek-chat:free' 16 | 17 | # 创建日志记录器 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.INFO) 20 | 21 | # 创建控制台处理器 22 | console_handler = logging.StreamHandler() 23 | console_handler.setLevel(logging.INFO) 24 | console_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 25 | console_handler.setFormatter(console_formatter) 26 | logger.addHandler(console_handler) 27 | 28 | # 创建文件处理器 29 | file_handler = logging.FileHandler('test.log') 30 | file_handler.setLevel(logging.INFO) 31 | file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 32 | file_handler.setFormatter(file_formatter) 33 | logger.addHandler(file_handler) 34 | 35 | client = OpenAI( 36 | base_url=base_url, 37 | api_key=api_key, 38 | ) 39 | 40 | def generate_query(query, stream=False): 41 | prompt ="""You are an expert research assistant. Given the user's query, generate up to four distinct, precise search queries that would help gather comprehensive information on the topic. 42 | Return only a Python list of strings, for example: ['query1', 'query2', 'query3'].""" 43 | 44 | response = client.chat.completions.create( 45 | model=model_name, 46 | messages = [ 47 | {"role": "system", "content": "You are a helpful and precise research assistant."}, 48 | {"role": "user", "content": f"User Query: {query}\n\n{prompt}"} 49 | ] 50 | ) 51 | return response.choices[0].message.content 52 | 53 | 54 | def if_useful(query: str, page_text: str): 55 | prompt ="""You are a critical research evaluator. Given the user's query and the content of a webpage, determine if the webpage contains information relevant and useful for addressing the query. 56 | Respond with exactly one word: 'Yes' if the page is useful, or 'No' if it is not. 
Do not include any extra text.""" 57 | 58 | response = client.chat.completions.create( 59 | model=model_name, 60 | messages = [ 61 | {"role": "system", "content": "You are a strict and concise evaluator of research relevance."}, 62 | {"role": "user", "content": f"User Query: {query}\n\nWebpage Content (first 20000 characters):\n{page_text[:20000]}\n\n{prompt}"} 63 | ] 64 | ) 65 | 66 | response = response.choices[0].message.content 67 | 68 | if response: 69 | answer = response.strip() 70 | if answer in ["Yes", "No"]: 71 | return answer 72 | else: 73 | # Fallback: try to extract Yes/No from the response. 74 | if "Yes" in answer: 75 | return "Yes" 76 | elif "No" in answer: 77 | return "No" 78 | return "No" 79 | 80 | def extract_relevant_context(query, search_query, page_text): 81 | prompt ="""You are an expert information extractor. Given the user's query, the search query that led to this page, and the webpage content, extract all pieces of information that are relevant to answering the user's query. 82 | Return only the relevant context as plain text without commentary.""" 83 | 84 | response = client.chat.completions.create( 85 | model=model_name, 86 | messages = [ 87 | {"role": "system", "content": "You are an expert in extracting and summarizing relevant information."}, 88 | {"role": "user", "content": f"User Query: {query}\nSearch Query: {search_query}\n\nWebpage Content (first 20000 characters):\n{page_text[:20000]}\n\n{prompt}"} 89 | ] 90 | ) 91 | 92 | response = response.choices[0].message.content 93 | if response: 94 | return response.strip() 95 | return "" 96 | 97 | def get_new_search_queries(user_query, previous_search_queries, all_contexts): 98 | context_combined = "\n".join(all_contexts) 99 | prompt ="""You are an analytical research assistant. Based on the original query, the search queries performed so far, and the extracted contexts from webpages, determine if further research is needed. 100 | If further research is needed, provide up to four new search queries as a Python list (for example, ['new query1', 'new query2']). If you believe no further research is needed, respond with exactly . 101 | Output only a Python list or the token without any additional text.""" 102 | 103 | response = client.chat.completions.create( 104 | model=model_name, 105 | messages = [ 106 | {"role": "system", "content": "You are an expert in extracting and summarizing relevant information."}, 107 | {"role": "user", "content": f"User Query: {user_query}\nPrevious Search Queries: {previous_search_queries}\n\nExtracted Relevant Contexts:\n{context_combined}\n\n{prompt}"} 108 | ] 109 | ) 110 | 111 | response = response.choices[0].message.content 112 | if response: 113 | cleaned = response.strip() 114 | if cleaned == "": 115 | return "" 116 | try: 117 | new_queries = eval(cleaned) 118 | if isinstance(new_queries, list): 119 | return new_queries 120 | else: 121 | logger.info(f"LLM did not return a list for new search queries. 
Response: {response}") 122 | return [] 123 | except Exception as e: 124 | logger.error(f"Error parsing new search queries:{e}, Response:{response}") 125 | return [] 126 | return [] 127 | 128 | 129 | # Reorganized this function, integrated get_images calls, and set top_k as a parameter 130 | def web_search(query: str, top_k: int = 2, categories: str = 'general') -> str: 131 | 132 | links = [] 133 | response = requests.get(f'http://10.250.2.24:8088/search?format=json&q={query}&language=zh-CN&time_range=&safesearch=0&categories={categories}', timeout=10) 134 | results = response.json()['results'] 135 | for result in results[:top_k]: 136 | links.append(result['url' if categories == 'general' else 'img_src' if categories == 'images' else '']) 137 | 138 | return links 139 | 140 | 141 | def fetch_webpage_text(url): 142 | JINA_BASE_URL = 'https://r.jina.ai/' 143 | full_url = f"{JINA_BASE_URL}{url}" 144 | 145 | try: 146 | resp = requests.get(full_url, timeout=50) 147 | if resp.status_code == 200: 148 | return resp.text 149 | else: 150 | text = resp.text 151 | logger.info(f"Jina fetch error for {url}: {resp.status_code} - {text}") 152 | return "" 153 | except Exception as e: 154 | logger.error(f"Error fetching webpage text with Jina:{e}") 155 | return "" 156 | 157 | 158 | def process_link(link, query, search_query): 159 | logger.info(f"Fetching content from: {link}") 160 | page_text = fetch_webpage_text(link) 161 | if not page_text: 162 | return None 163 | usefulness = if_useful(query, page_text) 164 | logger.info(f"Page usefulness for {link}: {usefulness}") 165 | if usefulness == "Yes": 166 | context = extract_relevant_context(query, search_query, page_text) 167 | if context: 168 | logger.info(f"Extracted context from {link} (first 200 chars): {context[:200]}") 169 | return context 170 | return None 171 | 172 | 173 | def get_images_description(iamge_url): 174 | completion = client.chat.completions.create( 175 | 176 | model="qwen/qwen2.5-vl-32b-instruct:free", 177 | messages=[ 178 | { 179 | "role": "user", 180 | "content": [ 181 | { 182 | "type": "text", 183 | "text": "使用一句话描述图片的内容" 184 | }, 185 | { 186 | "type": "image_url", 187 | "image_url": { 188 | "url": iamge_url 189 | } 190 | } 191 | ] 192 | } 193 | ] 194 | ) 195 | return completion.choices[0].message.content 196 | 197 | @mcp.tool() 198 | def search(query: str) -> str: 199 | """互联网搜索""" 200 | iteration_limit = 3 201 | iteration = 0 202 | aggregated_contexts = [] 203 | all_search_queries = [] 204 | iteration = 0 205 | 206 | new_search_queries = eval(generate_query(query)) 207 | all_search_queries.extend(new_search_queries) 208 | # add the original query to the search queries in case it is not already in the list 209 | if query not in all_search_queries: 210 | all_search_queries.append(query) 211 | while iteration < iteration_limit: 212 | logger.info(f"\n=== Iteration {iteration + 1} ===") 213 | iteration_contexts = [] 214 | search_results = [web_search(query, top_k=2, categories='general') for query in new_search_queries] 215 | 216 | unique_links = {} 217 | for idx, links in enumerate(search_results): 218 | search_query = new_search_queries[idx] # prevent the query being replaced as the Search parameter 219 | for link in links: 220 | if link not in unique_links: 221 | unique_links[link] = search_query 222 | 223 | logger.info(f"Aggregated {len(unique_links)} unique links from this iteration.") 224 | 225 | # Process each link concurrently: fetch, judge, and extract context. 
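        # Note: despite the wording above, this list comprehension fetches and evaluates
        # the links one at a time rather than concurrently; if true concurrency is wanted,
        # the calls could be dispatched via e.g. concurrent.futures.ThreadPoolExecutor.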
226 | link_results = [ 227 | process_link(link, query, unique_links[link]) 228 | for link in unique_links 229 | ] 230 | 231 | # Collect non-None contexts. 232 | for res in link_results: 233 | if res: 234 | iteration_contexts.append(res) 235 | 236 | if iteration_contexts: 237 | aggregated_contexts.extend(iteration_contexts) 238 | else: 239 | logger.info("No useful contexts were found in this iteration.") 240 | 241 | new_search_queries = get_new_search_queries(query, all_search_queries, aggregated_contexts) 242 | if new_search_queries == "": 243 | logger.info("LLM indicated that no further research is needed.") 244 | break 245 | elif new_search_queries: 246 | logger.info(f"LLM provided new search queries:{new_search_queries}") 247 | all_search_queries.extend(new_search_queries) 248 | else: 249 | logger.info("LLM did not provide any new search queries. Ending the loop.") 250 | break 251 | 252 | iteration += 1 253 | return '\n\n'.join(aggregated_contexts) 254 | 255 | @mcp.tool() 256 | def get_images(query: str) -> str: 257 | '''获取图片链接和描述''' 258 | logger.info(f"Searching for images for query: {query}") 259 | # Get image links directly through web_search function 260 | img_srcs = web_search(query, top_k=2, categories='images') 261 | 262 | result = {} 263 | 264 | for img_src in img_srcs: 265 | logger.info(f"Fetching image description for: {img_src}") 266 | description = get_images_description(img_src) 267 | logger.info(f"Image description for {img_src}: {description}") 268 | result[img_src] = description 269 | 270 | return result 271 | 272 | 273 | 274 | if __name__ == "__main__": 275 | mcp.run() 276 | 277 | -------------------------------------------------------------------------------- /deep_research/searxng/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3.7" 2 | 3 | services: 4 | caddy: 5 | container_name: caddy 6 | image: docker.io/library/caddy:2-alpine 7 | network_mode: host 8 | restart: unless-stopped 9 | volumes: 10 | - ./Caddyfile:/etc/caddy/Caddyfile:ro 11 | - caddy-data:/data:rw 12 | - caddy-config:/config:rw 13 | environment: 14 | - SEARXNG_HOSTNAME=${SEARXNG_HOSTNAME:-http://localhost} 15 | - SEARXNG_TLS=${LETSENCRYPT_EMAIL:-internal} 16 | cap_drop: 17 | - ALL 18 | cap_add: 19 | - NET_BIND_SERVICE 20 | logging: 21 | driver: "json-file" 22 | options: 23 | max-size: "1m" 24 | max-file: "1" 25 | 26 | redis: 27 | container_name: redis-searxng 28 | image: docker.io/valkey/valkey:8-alpine 29 | command: valkey-server --save 30 1 --loglevel warning 30 | restart: unless-stopped 31 | networks: 32 | - searxng 33 | volumes: 34 | - valkey-data2:/data 35 | cap_drop: 36 | - ALL 37 | cap_add: 38 | - SETGID 39 | - SETUID 40 | - DAC_OVERRIDE 41 | logging: 42 | driver: "json-file" 43 | options: 44 | max-size: "1m" 45 | max-file: "1" 46 | 47 | searxng: 48 | container_name: searxng 49 | image: docker.io/searxng/searxng:latest 50 | restart: unless-stopped 51 | networks: 52 | - searxng 53 | ports: 54 | - "8088:8080" 55 | volumes: 56 | - ./searxng:/etc/searxng:rw 57 | environment: 58 | - SEARXNG_BASE_URL=https://${SEARXNG_HOSTNAME:-localhost}/ 59 | - UWSGI_WORKERS=${SEARXNG_UWSGI_WORKERS:-4} 60 | - UWSGI_THREADS=${SEARXNG_UWSGI_THREADS:-4} 61 | cap_drop: 62 | - ALL 63 | cap_add: 64 | - CHOWN 65 | - SETGID 66 | - SETUID 67 | logging: 68 | driver: "json-file" 69 | options: 70 | max-size: "1m" 71 | max-file: "1" 72 | 73 | networks: 74 | searxng: 75 | 76 | volumes: 77 | caddy-data: 78 | caddy-config: 79 | valkey-data2: 80 | 
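A note on the compose file above: it publishes the SearXNG container on host port 8088, which is the port that `web_search` in `deep_research/search_mcp.py` queries with `format=json`. Below is a minimal sketch of exercising that JSON endpoint once the stack is up (the `localhost` host and the sample query are illustrative assumptions; `search_mcp.py` itself targets an internal IP):

```python
import requests

# Mirror the request shape used by web_search() in search_mcp.py.
resp = requests.get(
    "http://localhost:8088/search",
    params={"format": "json", "q": "model context protocol", "categories": "general"},
    timeout=10,
)
for result in resp.json()["results"][:2]:
    # For categories='images', web_search reads 'img_src' instead of 'url'.
    print(result["url"])
```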
-------------------------------------------------------------------------------- /deep_research/searxng/settings.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.searxng.org/admin/settings/settings.html#settings-use-default-settings 2 | use_default_settings: true 3 | 4 | engines: 5 | # 启用默认禁用的引擎 6 | - name: bing 7 | disabled: false 8 | 9 | - name: 360search 10 | engine: 360search 11 | shortcut: 360so 12 | disabled: false 13 | 14 | - name: baidu 15 | engine: baidu 16 | shortcut: baidu 17 | disabled: false 18 | 19 | - name: iqiyi 20 | engine: iqiyi 21 | shortcut: iq 22 | disabled: false 23 | 24 | - name: acfun 25 | engine: acfun 26 | shortcut: acf 27 | disabled: false 28 | 29 | 30 | # 禁用默认启用的引擎 31 | - name: arch linux wiki 32 | engine: archlinux 33 | disabled: true 34 | - name: duckduckgo 35 | engine: duckduckgo 36 | disabled: true 37 | - name: github 38 | engine: github 39 | shortcut: gh 40 | disabled: true 41 | - name: wikipedia 42 | engine: wikipedia 43 | disabled: true 44 | 45 | 46 | server: 47 | # base_url is defined in the SEARXNG_BASE_URL environment variable, see .env and docker-compose.yml 48 | secret_key: "888" # change this! 49 | limiter: false # can be disabled for a private instance 50 | image_proxy: true 51 | ui: 52 | static_use_hash: true 53 | redis: 54 | url: redis://redis:6379/0 55 | 56 | search: 57 | safe_search: 0 58 | autocomplete: "" 59 | default_lang: "" 60 | formats: 61 | - html 62 | - json 63 | - csv 64 | - rss 65 | ratelimit: 66 | enabled: true 67 | # 调整每秒允许的请求数 68 | per_second: 5 69 | # 调整每分钟允许的请求数 70 | per_minute: 60 71 | 72 | 73 | -------------------------------------------------------------------------------- /deep_research/test.log: -------------------------------------------------------------------------------- 1 | 2025-04-04 16:07:25,268 - search_mcp - INFO - 2 | Connected to server with tools: ['search', 'get_images'] 3 | 2025-04-04 16:07:25,277 - search_mcp - INFO - 4 | MCP Client Started! 5 | 2025-04-04 16:07:25,279 - search_mcp - INFO - Type your queries or 'quit' to exit. 6 | 2025-04-04 16:07:51,186 - search_mcp - INFO - available_tools: 7 | 8 | [{'type': 'function', 'function': {'name': 'search', 'description': '互联网搜索', 'parameters': {'properties': {'query': {'title': 'Query', 'type': 'string'}}, 'required': ['query'], 'title': 'searchArguments', 'type': 'object'}}}, {'type': 'function', 'function': {'name': 'get_images', 'description': '获取图片链接和描述', 'parameters': {'properties': {'query': {'title': 'Query', 'type': 'string'}}, 'required': ['query'], 'title': 'get_imagesArguments', 'type': 'object'}}}] 9 | 2025-04-04 16:07:58,424 - search_mcp - INFO - llm_output(tool call):为了撰写关于大模型MCP协议的报告并配上插图,我将按照以下步骤进行: 10 | 11 | 1. **搜索MCP协议的相关信息**:首先,我需要了解MCP协议的基本概念、应用场景和技术细节。这将帮助我撰写报告的内容。 12 | 13 | 2. **获取相关插图**:在撰写报告的同时,我需要获取与MCP协议相关的图片,以便更好地解释和展示协议的内容。 14 | 15 | 我将首先使用`search`工具来搜索MCP协议的相关信息。 16 | 17 | ```json 18 | {"name": "search", "params": {"query": "大模型MCP协议 2025"}} 19 | ``` 20 | 2025-04-04 16:08:01,079 - __main__ - INFO - 21 | === Iteration 1 === 22 | 2025-04-04 16:08:13,145 - __main__ - INFO - Aggregated 6 unique links from this iteration. 
23 | 2025-04-04 16:08:13,150 - __main__ - INFO - Fetching content from: https://baijiahao.baidu.com/s?id=1827543504174294111&wfr=spider&for=pc 24 | 2025-04-04 16:08:19,434 - __main__ - INFO - Page usefulness for https://baijiahao.baidu.com/s?id=1827543504174294111&wfr=spider&for=pc: Yes 25 | 2025-04-04 16:08:30,436 - __main__ - INFO - Extracted context from https://baijiahao.baidu.com/s?id=1827543504174294111&wfr=spider&for=pc (first 200 chars): The webpage provides detailed information on the Model Context Protocol (MCP) as of 2025, including its technical aspects and industry applications. Here are the relevant details: 26 | 27 | 1. **MCP Developmen 28 | 2025-04-04 16:08:30,444 - __main__ - INFO - Fetching content from: https://blog.csdn.net/monsion/article/details/146855599 29 | 2025-04-04 16:08:36,725 - __main__ - INFO - Page usefulness for https://blog.csdn.net/monsion/article/details/146855599: Yes 30 | 2025-04-04 16:08:52,693 - __main__ - INFO - Extracted context from https://blog.csdn.net/monsion/article/details/146855599 (first 200 chars): 大模型MCP协议(Model Context Protocol)是由Anthropic公司于2024年11月推出的一种开放标准,旨在统一大型语言模型(LLM)与外部数据源和工具之间的通信协议。MCP通过标准化接口、生态丰富、平台无关性、数据安全性和可扩展性等核心价值,解决了AI模型与外部系统交互中的兼容性、安全性和复杂性问题。2025年3月26日,MCP规范发布了新版本,主要更新包括基于OAuth 31 | 2025-04-04 16:08:52,701 - __main__ - INFO - Fetching content from: https://blog.csdn.net/weixin_61514920/article/details/146687846 32 | 2025-04-04 16:08:59,214 - __main__ - INFO - Page usefulness for https://blog.csdn.net/weixin_61514920/article/details/146687846: Yes 33 | 2025-04-04 16:09:09,017 - __main__ - INFO - Extracted context from https://blog.csdn.net/weixin_61514920/article/details/146687846 (first 200 chars): The webpage discusses the Model Context Protocol (MCP), particularly its applications and developments expected by 2025. Here are the summarized details relevant to the user's query regarding "大模型MCP协 34 | 2025-04-04 16:09:09,022 - __main__ - INFO - Fetching content from: https://m.sohu.com/a/877695960_121924584 35 | 2025-04-04 16:09:15,942 - __main__ - INFO - Page usefulness for https://m.sohu.com/a/877695960_121924584: Yes 36 | 2025-04-04 16:09:23,494 - __main__ - INFO - Extracted context from https://m.sohu.com/a/877695960_121924584 (first 200 chars): 大模型MCP协议在2025年的最新进展中,MCP协议被认为是行业的重要里程碑。作为连接智能体与人、环境的重要桥梁,MCP协议的标准化与安全性将深刻影响未来AI技术的发展。它不仅可以降低开发门槛,还将为多模型的无缝协作提供基础支持,推动AI能力的指数级增长。 37 | 2025-04-04 16:09:23,501 - __main__ - INFO - Fetching content from: https://new.qq.com/rain/a/20250330A06YHG00 38 | 2025-04-04 16:09:58,597 - __main__ - INFO - Page usefulness for https://new.qq.com/rain/a/20250330A06YHG00: No 39 | 2025-04-04 16:09:58,601 - __main__ - INFO - Fetching content from: https://www.163.com/dy/article/JOPHDL3L0538AWSX.html 40 | 2025-04-04 16:10:04,031 - __main__ - INFO - Page usefulness for https://www.163.com/dy/article/JOPHDL3L0538AWSX.html: No 41 | 2025-04-04 16:10:06,212 - __main__ - INFO - LLM did not provide any new search queries. Ending the loop. 42 | 2025-04-04 16:10:06,221 - search_mcp - INFO - tool name: 43 | search 44 | tool call result: 45 | meta=None content=[TextContent(type='text', text='The webpage provides detailed information on the Model Context Protocol (MCP) as of 2025, including its technical aspects and industry applications. Here are the relevant details:\n\n1. 
**MCP Development Background**:\n - Early language models like GPT-2 relied on static training data, leading to "data silos."\n - In November 2024, Anthropic introduced MCP to standardize AI interactions with external systems, inspired by the Language Server Protocol (LSP).\n\n2. **MCP Overview**:\n - MCP is an open protocol using a client-server architecture to connect AI models with external resources.\n - Core design principles include modularity, security, and interoperability.\n\n3. **MCP Server Core Functions**:\n - **Tools**: Encapsulate external services as callable functions (e.g., SQL queries, email sending).\n - **Resources**: Expose structured data like database records or API data.\n - **Prompts**: Predefined task templates and workflows to reduce development complexity.\n\n4. **MCP Protocol Specification**:\n - **Communication Protocol**: Layered design supporting HTTP/2, WebSocket, and gRPC.\n - **Data Layer**: Uses Protobuf for strong-type schema.\n - **Security Layer**: Implements OAuth 2.0 and end-to-end encryption (E2EE), with support for SM4 algorithm in enterprise deployments.\n\n5. **Developer Practical Guide**:\n - Steps to set up a minimal MCP server using Python, including installation, tool service definition, server startup, and client invocation.\n\n6. **Impact and Future Trends**:\n - **Technical Value**: Enhances development efficiency, ensures security, and fosters ecosystem collaboration.\n - **Future Trends**: Expansion to multimodal devices (e.g., AR glasses) and deployment on mobile devices with lightweight models like DeepSeek-R1-Zero.\n\nThese points directly address the user\'s query about the latest technical details of the MCP protocol in 2025.\n\n大模型MCP协议(Model Context Protocol)是由Anthropic公司于2024年11月推出的一种开放标准,旨在统一大型语言模型(LLM)与外部数据源和工具之间的通信协议。MCP通过标准化接口、生态丰富、平台无关性、数据安全性和可扩展性等核心价值,解决了AI模型与外部系统交互中的兼容性、安全性和复杂性问题。2025年3月26日,MCP规范发布了新版本,主要更新包括基于OAuth 2.1的身份验证框架、用Streamable HTTP传输取代HTTP+SSE传输方式,以及采用更加标准化的数据访问接口,进一步增强了MCP的安全性和性能。MCP的技术架构遵循客户端-服务器模型,包含MCP Hosts、MCP Clients、MCP Servers、本地数据源和远程服务等核心组件,支持本地通信和远程通信两种机制,使用JSON-RPC 2.0格式进行消息传输。MCP的应用场景广泛,包括数据与文件系统、开发工具、网络与浏览器自动化、生产力和通信、AI与专业工具等领域,支持的工具和服务包括Git、GitHub、Google Maps、PostgreSQL、Slack等。\n\nThe webpage discusses the Model Context Protocol (MCP), particularly its applications and developments expected by 2025. Here are the summarized details relevant to the user\'s query regarding "大模型MCP协议 2025 应用场景":\n\n1. **智能助手与知识管理**: MCP enables models like Claude to interact with user data sources, offering personalized services.\n2. **客户服务与自动化**: Integration with CRM systems allows automated customer service, enhancing efficiency by accessing real-time customer information.\n3. **内容创作与编辑**: MCP supports tools that assist in content generation and editing, streamlining creative processes.\n4. 
**软件开发与调试**: Developers can enhance coding environments with functionalities like auto-completion, error detection, and repair suggestions through MCP-enabled tools.\n\nThese applications leverage MCP\'s standardized communication protocol to enhance interaction between AI models and external data/tools, aiming for improved functionality and user experience in various sectors.\n\n大模型MCP协议在2025年的最新进展中,MCP协议被认为是行业的重要里程碑。作为连接智能体与人、环境的重要桥梁,MCP协议的标准化与安全性将深刻影响未来AI技术的发展。它不仅可以降低开发门槛,还将为多模型的无缝协作提供基础支持,推动AI能力的指数级增长。', annotations=None)] isError=False 46 | 2025-04-04 16:10:12,223 - search_mcp - INFO - llm_output: 47 | 根据已经获取的信息,我们可以撰写一篇关于大模型MCP协议的报告,内容涵盖了MCP协议的背景、核心功能、技术架构、应用场景以及未来发展趋势。然而,用户还要求报告配上插图,而目前我们尚未获取相关的图片资源。 48 | 49 | 因此,我们需要调用`get_images`工具来获取与MCP协议相关的插图,以完善报告。 50 | 51 | ```json 52 | {"name": "get_images", "params": {"query": "大模型MCP协议 2025"}} 53 | ``` 54 | 2025-04-04 16:10:12,233 - __main__ - INFO - Searching for images for query: 大模型MCP协议 2025 55 | 2025-04-04 16:10:15,254 - __main__ - INFO - Fetching image description for: https://modelcontextprotocol.info/images/blog/20250319/img2.webp 56 | 2025-04-04 16:10:19,270 - __main__ - INFO - Image description for https://modelcontextprotocol.info/images/blog/20250319/img2.webp: 图片展示了MCP(Multi-Client Protocol)架构的示意图,包括MCP客户端、服务器、远程服务、本地数据源以及与Claude等主机的连接关系。 57 | 2025-04-04 16:10:19,275 - __main__ - INFO - Fetching image description for: https://developer.qcloudimg.com/http-save/yehe-3089065/ac6a5560fddb5847febfb1355f814c59.webp 58 | 2025-04-04 16:10:24,879 - __main__ - INFO - Image description for https://developer.qcloudimg.com/http-save/yehe-3089065/ac6a5560fddb5847febfb1355f814c59.webp: 图片展示了一段对话,内容是关于查询数据库中金额最高的订单信息的过程,最终结果显示订单号为2,订单日期为2025年1月6日,订单金额为200.50,客户姓名为Jane Smith。 59 | 2025-04-04 16:10:24,888 - search_mcp - INFO - tool name: 60 | get_images 61 | tool call result: 62 | meta=None content=[TextContent(type='text', text='{"https://modelcontextprotocol.info/images/blog/20250319/img2.webp": "\\u56fe\\u7247\\u5c55\\u793a\\u4e86MCP\\uff08Multi-Client Protocol\\uff09\\u67b6\\u6784\\u7684\\u793a\\u610f\\u56fe\\uff0c\\u5305\\u62ecMCP\\u5ba2\\u6237\\u7aef\\u3001\\u670d\\u52a1\\u5668\\u3001\\u8fdc\\u7a0b\\u670d\\u52a1\\u3001\\u672c\\u5730\\u6570\\u636e\\u6e90\\u4ee5\\u53ca\\u4e0eClaude\\u7b49\\u4e3b\\u673a\\u7684\\u8fde\\u63a5\\u5173\\u7cfb\\u3002", "https://developer.qcloudimg.com/http-save/yehe-3089065/ac6a5560fddb5847febfb1355f814c59.webp": "\\u56fe\\u7247\\u5c55\\u793a\\u4e86\\u4e00\\u6bb5\\u5bf9\\u8bdd\\uff0c\\u5185\\u5bb9\\u662f\\u5173\\u4e8e\\u67e5\\u8be2\\u6570\\u636e\\u5e93\\u4e2d\\u91d1\\u989d\\u6700\\u9ad8\\u7684\\u8ba2\\u5355\\u4fe1\\u606f\\u7684\\u8fc7\\u7a0b\\uff0c\\u6700\\u7ec8\\u7ed3\\u679c\\u663e\\u793a\\u8ba2\\u5355\\u53f7\\u4e3a2\\uff0c\\u8ba2\\u5355\\u65e5\\u671f\\u4e3a2025\\u5e741\\u67086\\u65e5\\uff0c\\u8ba2\\u5355\\u91d1\\u989d\\u4e3a200.50\\uff0c\\u5ba2\\u6237\\u59d3\\u540d\\u4e3aJane Smith\\u3002"}', annotations=None)] isError=False 63 | 2025-04-04 16:10:29,154 - search_mcp - INFO - llm_output: 64 | 根据已经获取的信息,我们具备了撰写一篇关于大模型MCP协议的完整报告所需的内容,包括MCP协议的背景、核心功能、技术架构、应用场景以及未来发展趋势。此外,通过工具调用,我们还获得了与MCP协议相关的插图,例如MCP架构的示意图以及与MCP协议交互的场景图。 65 | 66 | 因此,结合已有的文字内容和插图资源,已经完全能够满足用户的需求。以下是最终的判断: 67 | 68 | -------------------------------------------------------------------------------- /deep_research/test.md: -------------------------------------------------------------------------------- 1 | # 大模型MCP协议报告(2025年) 2 | 3 | ## 1. 
MCP背景与发展 4 | 大模型MCP协议(Model Context Protocol)是由Anthropic公司于2024年11月推出的一种开放标准,旨在统一大型语言模型(LLM)与外部数据源和工具之间的通信协议。MCP通过标准化接口、生态丰富、平台无关性、数据安全性和可扩展性等核心价值,解决了AI模型与外部系统交互中的兼容性、安全性和复杂性问题。 5 | 6 | **相关图片**: 7 | ![MCP架构示意图](https://modelcontextprotocol.info/images/blog/20250319/img2.webp) 8 | *图片展示了MCP(Multi-Client Protocol)架构的示意关系,包括MCP客户端、服务器、远程服务、本地数据源以及与Claude等主机的连接关系。* 9 | 10 | ## 2. MCP协议的核心功能 11 | MCP协议基于客户端-服务器模型,包含以下核心功能: 12 | - **工具**:封装外部服务为可调用函数,例如SQL查询、邮件发送等。 13 | - **资源**:暴露结构化数据,如数据库记录或API数据。 14 | - **提示**:提供预定义的任务模板和工作流,降低开发复杂度。 15 | 16 | ## 3. MCP技术架构 17 | MCP协议的技术架构遵循分层设计,具体包括: 18 | - **通信协议**:支持HTTP/2、WebSocket和gRPC。 19 | - **数据层**:基于Protobuf实现强类型数据模式。 20 | - **安全层**:集成OAuth 2.0和端到端加密(E2EE),并支持企业级SM4算法。 21 | 22 | ## 4. MCP应用场景 23 | MCP协议在2025年的应用场景广泛,涵盖以下领域: 24 | 1. **智能助手与知识管理**:MCP使模型能够与用户数据源交互,提供个性化服务。 25 | 2. **客户服务与自动化**:与CRM系统集成,实现高效的客户服务自动化。 26 | 3. **内容创作与编辑**:支持辅助内容生成和编辑的工具,优化创作流程。 27 | 4. **软件开发与调试**:开发者可通过MCP增强代码环境的自动补全、错误检测和修复建议功能。 28 | 29 | **相关图片**: 30 | ![数据库查询示例](https://developer.qcloudimg.com/http-save/yehe-3089065/ac6a5560fddb5847febfb1355f814c59.webp) 31 | *图片展示了通过MCP协议查询数据库中金额最高的订单信息的过程,结果包括订单号、日期、金额和客户姓名。* 32 | 33 | ## 5. 未来发展趋势 34 | MCP协议在2025年的最新进展中,被认为是行业的重要里程碑。其标准化与安全性将深刻影响未来AI技术的发展,具体趋势包括: 35 | - **多模态设备扩展**:支持增强现实(AR)眼镜等设备。 36 | - **轻量级模型部署**:在移动设备上部署如DeepSeek-R1-Zero等轻量级模型。 37 | 38 | ## 6. 结论 39 | 大模型MCP协议通过标准化接口和强大的功能,为AI模型与外部系统的交互提供了高效、安全的解决方案。其广泛的应用场景和未来发展潜力,将进一步推动AI技术的普及和创新。 -------------------------------------------------------------------------------- /deepseek_learn/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/deepseek_learn/1.png -------------------------------------------------------------------------------- /deepseek_learn/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/deepseek_learn/2.png -------------------------------------------------------------------------------- /deepseek_learn/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/deepseek_learn/3.png -------------------------------------------------------------------------------- /deepseek_learn/MLA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | # rms归一化 7 | class RMSNorm(nn.Module): 8 | def __init__(self, hidden_size, eps=1e-6): 9 | 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(hidden_size)) 12 | self.variance_epsilon = eps 13 | 14 | def forward(self, hidden_states): 15 | hidden_states = hidden_states.float() 16 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 17 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 18 | return self.weight * hidden_states.float() 19 | 20 | 21 | def rotate_half(x): 22 | x1, x2 = x.chunk(2, dim=-1) 23 | return torch.cat((-x2, x1), dim=-1) 24 | 25 | def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2): 26 | 27 | cos = cos.unsqueeze(unsqueeze_dim) 28 | sin = sin.unsqueeze(unsqueeze_dim) 29 | 30 | q_embed = (q*cos) + (rotate_half(q)*sin) 31 | k_embed = (k*cos) + (rotate_half(k)*sin) 32 | 33 | return q_embed, k_embed 34 | 35 | # 旋转位置编码 36 | class RotaryEmbedding(nn.Module): 
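    # This module precomputes cos/sin tables for rotary position embeddings (RoPE) up to
    # max_seq_len; forward(q, k) slices them to the current sequence length and applies
    # them to q and k via apply_rotate_pos_emb.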
37 | def __init__(self, dim, max_seq_len=1024): 38 | super(RotaryEmbedding, self).__init__() 39 | self.dim = dim 40 | self.max_seq_len = max_seq_len 41 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 42 | t = torch.arange(max_seq_len).float().unsqueeze(1) 43 | freqs = t @ inv_freq.unsqueeze(0) 44 | freqs = torch.cat((freqs, freqs), dim=-1) 45 | 46 | self.register_buffer("cos_cached", freqs.cos()) 47 | self.register_buffer("sin_cached", freqs.sin()) 48 | 49 | def forward(self, q, k): 50 | cos = self.cos_cached[:q.shape[1], :].unsqueeze(0) 51 | sin = self.sin_cached[:q.shape[1], :].unsqueeze(0) 52 | return apply_rotate_pos_emb(q, k, cos, sin) 53 | 54 | class MLA(nn.Module): 55 | def __init__(self, 56 | dim, 57 | n_heads, 58 | q_lora_rank, 59 | kv_lora_rank, 60 | qk_nope_head_dim, 61 | qk_rope_head_dim, 62 | v_head_dim, 63 | max_seq_len, 64 | max_batch_size, 65 | mode): 66 | super().__init__() 67 | self.dim = dim # 隐藏层维度 68 | self.n_heads = n_heads #总头数 69 | self.q_lora_rank = q_lora_rank # q低秩压缩到的维度 70 | self.kv_lora_rank = kv_lora_rank # kv低秩压缩到的维度 71 | self.qk_nope_head_dim = qk_nope_head_dim 72 | self.qk_rope_head_dim = qk_rope_head_dim 73 | self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim # qk的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度 74 | self.v_head_dim = v_head_dim # value的维度,等于不带旋转位置编码的k维度 75 | self.mode = mode 76 | self.max_seq_len = max_seq_len 77 | self.max_batch_size = max_batch_size 78 | 79 | 80 | self.wq_a = nn.Linear(self.dim, self.q_lora_rank) # q的降维矩阵 81 | self.q_norm = RMSNorm(self.q_lora_rank) 82 | self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim) # q的升维矩阵 83 | # 4096*128+128*4864 = 524,288 + 622592 = 1146880 4096*4864 = 19,922,944 84 | 85 | self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) # kv的降维矩阵 86 | # nn.Linear(self.dim, self.kv_lora_rank) 87 | # nn.Linear(self.dim, self.qk_rope_head_dim) 88 | self.kv_norm = RMSNorm(self.kv_lora_rank) 89 | self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) # kv的升维矩阵 90 | 91 | self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim) 92 | 93 | self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim) # 旋转旋转位置编码 94 | 95 | if self.mode == 'naive': 96 | self.register_buffer('k_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim), persistent=False) 97 | self.register_buffer('v_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim), persistent=False) 98 | 99 | else: 100 | self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank), persistent=False) 101 | self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim), persistent=False) 102 | 103 | 104 | def forward(self, x, mask=None): 105 | 106 | bs, seq_len, _ = x.shape 107 | 108 | q = self.wq_a(x) # [bs, seq_len, q_lora_rank] 109 | q = self.q_norm(q) # [bs, seq_len, q_lora_rank] 110 | q = self.wq_b(q) # [bs, seq_len, n_heads * qk_head_dim] 111 | q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim) # [bs, seq_len, n_heads, qk_head_dim] 112 | q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim] 113 | 114 | kv = self.wkv_a(x) # [bs, seq_len, kv_lora_rank + qk_rope_head_dim] 115 | kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) # kv shape:[bs, 
seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim] 116 | 117 | k_pe = k_pe.unsqueeze(2) # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim] 118 | q_pe, k_pe = self.rotary_emb(q_pe, k_pe) 119 | if self.mode == 'naive': 120 | 121 | q = torch.cat([q_nope, q_pe], dim=-1) # * [bs, seq_len, n_heads, qk_head_dim] 122 | 123 | kv = self.kv_norm(kv) # [bs, seq_len, kv_lora_rank)] 124 | kv = self.wkv_b(kv) # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)] 125 | kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim) 126 | k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) 127 | 128 | k = torch.cat([k_nope, k_pe.expand(-1,-1,self.n_heads,-1)], dim=-1) 129 | # k shape:[bs, seq_len, n_heads, qk_head_dim] 130 | self.k_cache[:bs, :seq_len, :, :] = k 131 | self.v_cache[:bs, :seq_len, :, :] = v 132 | # scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) 133 | scores = torch.matmul(q.transpose(1, 2), self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)) 134 | scores = scores.transpose(1, 2) 135 | 136 | else: 137 | k_pe = k_pe.squeeze(2) 138 | wkv_b = self.wkv_b.weight # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank] 139 | wkv_b = wkv_b.view(self.n_heads, -1, self.kv_lora_rank) # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank] 140 | q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank] 141 | # q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T) c为压缩后的kv 142 | # wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵 c可以替代原先的k,这样就可以直接使用压缩后的kv计算注意力了,kv_caceh时也只需存储压缩后的kv 143 | kv = self.kv_norm(kv) 144 | self.kv_cache[:bs, :seq_len, :] = kv # kv shape:[bs, seq_len, kv_lora_rank] 145 | self.pe_cache[:bs, :seq_len, :] = k_pe # k_pe shape:[bs, seq_len, qk_rope_head_dim] 146 | 147 | scores_nope = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bs, :seq_len, :]) # bshc btc -> bshc bct -> bsht 148 | scores_pe = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bs, :seq_len, :]) # bshr btr -> bshr bt1r -> bshr bthr -> bsht 149 | scores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) # [bs, seq_len, n_heads, seq_len] 150 | 151 | if mask is not None: 152 | # mask shape:[bs, seq_len, seq_len] 153 | scores += mask.unsqueeze(2) 154 | 155 | scores = scores.softmax(dim=-1) 156 | 157 | if self.mode == 'naive': 158 | x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bs, :seq_len]) # bsht,bthd -> bhst, bhtd -> bhsd -> bshd 159 | else: 160 | 161 | # scores * v = scores * c * wkv_b[:, -self.v_head_dim:] 162 | x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bs, :seq_len]) # x shape:[bs, seq_len, n_heads, kv_lora_rank] 163 | x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) # bshc, hdc -> bshc,dch -> bsdh -> bshd 164 | 165 | x = x.contiguous ().view(bs, seq_len, -1) 166 | x = self.wo(x) 167 | 168 | return x 169 | 170 | if __name__ == '__main__': 171 | 172 | x = torch.randn(4, 100, 4096) 173 | 174 | dim = 4096 175 | n_heads = 16 176 | q_lora_rank = 128 177 | kv_lora_rank = 64 178 | qk_nope_head_dim = 256 179 | qk_rope_head_dim = 48 180 | v_head_dim = 256 181 | max_seq_len = 512 182 | max_batch_size = 16 183 | mode = 'none' 184 | 185 | mla = MLA(dim=dim, 186 | n_heads=n_heads, 187 | 
q_lora_rank=q_lora_rank, 188 | kv_lora_rank=kv_lora_rank, 189 | qk_nope_head_dim=qk_nope_head_dim, 190 | qk_rope_head_dim=qk_rope_head_dim, 191 | v_head_dim=v_head_dim, 192 | max_seq_len=max_seq_len, 193 | max_batch_size=max_batch_size, 194 | mode=mode) 195 | 196 | 197 | 198 | print(mla(x)) 199 | print(mla.kv_cache) -------------------------------------------------------------------------------- /deepseek_learn/MTP_train/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/deepseek_learn/MTP_train/loss.png -------------------------------------------------------------------------------- /deepseek_learn/MTP_train/test.py: -------------------------------------------------------------------------------- 1 | from MTP import MTP,Config 2 | from transformers import AutoTokenizer 3 | import torch 4 | config = Config() 5 | 6 | tokenizer = AutoTokenizer.from_pretrained(config.llm_model_path) 7 | model = MTP(config) 8 | model.cpu() 9 | model.load_state_dict(torch.load('/home/user/wyf/deepseek_learn/MTP_train/mtp/model_4499.pth')) 10 | 11 | input_ids = tokenizer.apply_chat_template([{"role": "user", "content": "宝宝生下来很可爱,但是觉得嘴唇发紫,而且心跳的也很慢,新生儿心率多少正常?"}], tokenize=True, return_tensors='pt', add_generation_prompt=True) 12 | seq = model.generate(input_ids, max_length=100) 13 | print(tokenizer.decode(seq[0])) -------------------------------------------------------------------------------- /deepseek_learn/deepseek_r1_train/deepseek-r1解读.md: -------------------------------------------------------------------------------- 1 | # 常规的大模型训练方案 2 | pretrain -> sft -> rl 3 | 4 | # DeepSeek-R1-Zero 5 | pretrain -> rl 6 | 7 | 缺陷:中英文混合、格式混乱 8 | 9 | # DeepSeek-R1 10 | pretrain -> sft一阶段 -> rl一阶段 -> sft二阶段 -> rl二阶段 11 | 12 | ## sft一阶段(冷启动) 13 | 14 | 目的:引入数千条高质量长推理链数据对基础模型微调,强制规范输出格式(如\推理过程\),提升可读性。\ 15 | 数据来源:收集DeepSeek-R1-Zero的输出结果,以可读的格式呈现,最后通过人工标注者进行后处理以优化结果 16 | 17 | ## rl一阶段(推理导向的rl) 18 | 19 | rl方法:GRPO\ 20 | 奖励模型:基于规则的奖励(答案准确性和语言一致性),针对代码、数学、编程等有固定答案的任务设计奖励函数。 21 | 22 | ## sft二阶段 23 | 24 | 数据来源:推理数据和非推理数据合并 25 | 26 | 推理数据:rl一阶段checkpoint输出数据(60万)。rl一阶段,仅纳入了可以基于规则的奖励进行评估的数据。在sft二阶段,通过引入额外的数据来扩展数据集,其中一些数据通过将真实答案和模型预测输入DeepSeek-V3进行判断,使用生成式奖励模型。此外,由于模型输出有时会显得混乱且难以阅读,过滤掉了包含混合语言、长段落和代码块的推理链。对于每个提示,采样多个回答,仅保留正确的回答。收集了大约60万个与推理相关的训练样本。 27 | 28 | 非推理数据:如写作、事实问答、自我认知和翻译等,重用DeepSeek-V3监督微调数据集的部分内容。收集了大约20万个与推理无关的训练样本。 29 | 30 | ## rl二阶段(通用对齐的rl) 31 | 32 | 通用对齐RL(RLHF):融入人类偏好奖励模型(Helpfulness & Harmlessness),确保模型在开放域任务中的安全性与实用性。 -------------------------------------------------------------------------------- /deepseek_learn/deepseek_r1_train/deepseek_r1_train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from datasets import load_dataset, Dataset 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | import trl 6 | from trl import GRPOConfig, GRPOTrainer 7 | from peft import LoraConfig, get_peft_model, TaskType 8 | 9 | SYSTEM_PROMPT = """ 10 | 按照如下格式生成: 11 | 12 | ... 13 | 14 | 15 | ... 
16 | 17 | """ 18 | def process_data(data): 19 | data = data.map(lambda x: { 20 | 'prompt': [ 21 | {'role': 'system', 'content': SYSTEM_PROMPT}, 22 | {'role': 'user', 'content': x['question_zh-cn']} 23 | ], 24 | 'answer': x['answer_only'] 25 | }) 26 | return data 27 | def extract_answer(text): 28 | answer = text.split("")[-1] 29 | answer = answer.split("")[0] 30 | return answer.strip() 31 | 32 | def mark_num(text): 33 | reward = 0 34 | if text.count("\n") == 1: 35 | reward += 0.125 36 | 37 | if text.count("\n") == 1: 38 | reward += 0.125 39 | 40 | if text.count("\n") == 1: 41 | reward += 0.125 42 | 43 | if text.count("\n") == 1: 44 | reward += 0.125 45 | return reward 46 | 47 | # 生成答案是否正确的奖励 48 | def correctness_reward(prompts, completions, answer, **kwargs): 49 | responses = [completion[0]['content'] for completion in completions] 50 | extracted_responses = [extract_answer(r) for r in responses] 51 | print(f"问题:\n{prompts[0][-1]['content']}", f"\n答案:\n{answer[0]}", f"\n模型输出:\n{responses[0]}", f"\n提取后的答案:\n{extracted_responses[0]}") 52 | return [2.0 if response == str(ans) else 0.0 for response, ans in zip(extracted_responses, answer)] 53 | # 生成答案是否是数字的奖励(单纯依赖结果是否正确进行奖励,条件很苛刻,会导致奖励比较稀疏,模型难以收敛,所以加上答案是否是数字的奖励,虽然答案错误,但是至少生成的是数字(对于数学问题),也要给予适当奖励) 54 | def digit_reward(completions, **kwargs): 55 | responses = [completion[0]['content'] for completion in completions] 56 | extracted_responses = [extract_answer(r) for r in responses] 57 | return [0.5 if response.isdigit() else 0.0 for response in extracted_responses] 58 | 59 | # 格式奖励 60 | def hard_format_reward(completions, **kwargs): 61 | pattern = r"^\n.*?\n\n\n.*?\n\n$" 62 | responses = [completion[0]["content"] for completion in completions] 63 | matches = [re.match(pattern, response) for response in responses] 64 | return [0.5 if match else 0.0 for match in matches] 65 | # 格式奖励 66 | def soft_format_reward(completions, **kwargs): 67 | pattern = r".*?\s*.*?" 
68 | responses = [completion[0]["content"] for completion in completions] 69 | matches = [re.match(pattern, response) for response in responses] 70 | return [0.5 if match else 0.0 for match in matches] 71 | # 标记奖励(改善格式奖励稀疏问题) 72 | def mark_reward(completions, **kwargs): 73 | responses = [completion[0]["content"] for completion in completions] 74 | return [mark_num(response) for response in responses] 75 | 76 | 77 | if __name__ == '__main__': 78 | model_name = "/home/user/Downloads/Qwen2.5-0.5B-Instruct" 79 | 80 | model = AutoModelForCausalLM.from_pretrained(model_name) 81 | # 如果使用lora方法训练,取消如下注释 82 | # lora_config = LoraConfig( 83 | # r=8, 84 | # lora_alpha=256, 85 | # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 86 | # lora_dropout=0.1, 87 | # task_type=TaskType.CAUSAL_LM) 88 | # # 使用lora方法训练 89 | # model = get_peft_model(model, lora_config) 90 | model.cuda() 91 | 92 | tokenizer = AutoTokenizer.from_pretrained(model_name) 93 | 94 | ds = load_dataset('/home/user/wyf/deepseek_learn/gsm8k_chinese') 95 | data = process_data(ds['train']) 96 | 97 | output_dir="output" 98 | 99 | training_args = GRPOConfig( 100 | output_dir=output_dir, 101 | learning_rate=5e-6, 102 | adam_beta1 = 0.9, 103 | adam_beta2 = 0.99, 104 | weight_decay = 0.1, 105 | warmup_ratio = 0.1, 106 | lr_scheduler_type='cosine', 107 | logging_steps=1, 108 | bf16=True, 109 | per_device_train_batch_size=1, 110 | gradient_accumulation_steps=4, 111 | num_generations=16, 112 | max_prompt_length=256, 113 | max_completion_length=200, 114 | num_train_epochs=1, 115 | save_steps=100, 116 | max_grad_norm=0.1, 117 | log_on_each_node=False, 118 | use_vllm=False, 119 | report_to="tensorboard" 120 | ) 121 | 122 | trainer = GRPOTrainer( 123 | model=model, 124 | processing_class=tokenizer, 125 | reward_funcs=[ 126 | mark_reward, 127 | soft_format_reward, 128 | hard_format_reward, 129 | digit_reward, 130 | correctness_reward 131 | ], 132 | args=training_args, 133 | train_dataset=data, 134 | 135 | ) 136 | trainer.train() 137 | trainer.save_model(output_dir) 138 | -------------------------------------------------------------------------------- /deepseek_learn/deepseek_r1_train/deepseek_reward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/deepseek_learn/deepseek_r1_train/deepseek_reward.png -------------------------------------------------------------------------------- /deepseek_learn/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from openai import OpenAI\n", 10 | "client = OpenAI(base_url='http://10.250.2.24:8222/v1', api_key='nn')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 14, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "SYSTEM_PROMPT = \"\"\"\n", 20 | "按照如下格式生成:\n", 21 | "\n", 22 | "...\n", 23 | "\n", 24 | "\n", 25 | "...\n", 26 | "\n", 27 | "\"\"\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 22, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "\n", 40 | "树下一只猴,树上骑个猴,所以一共是2只猴。\n", 41 | "\n", 42 | "\n", 43 | "4\n", 44 | "\n", 45 | "\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "res = client.chat.completions.create(\n", 51 | " 
model=\"ddd\",\n", 52 | " messages=[{\"role\": \"system\", \"content\": SYSTEM_PROMPT}, {\"role\": \"user\", \"content\": '树下一只猴?树上骑个猴?请问一共几只猴?'}],\n", 53 | " temperature=0\n", 54 | ")\n", 55 | "print(res.choices[0].message.content)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 36, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "from langchain_openai import ChatOpenAI\n", 65 | "llm = ChatOpenAI(temperature=0, base_url='http://10.250.2.24:8222/v1', api_key='nn', model='ddd')" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 37, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "AIMessage(content='树下一只猴,树上骑个猴,一共是3只猴。', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 46, 'total_tokens': 64, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'ddd', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-0b6d6d1f-3fd2-4dda-9916-1194b853650b-0', usage_metadata={'input_tokens': 46, 'output_tokens': 18, 'total_tokens': 64})" 77 | ] 78 | }, 79 | "execution_count": 37, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "llm.invoke('树下一只猴?树上骑个猴?请问一共几只猴?')" 86 | ] 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "wyf", 92 | "language": "python", 93 | "name": "python3" 94 | }, 95 | "language_info": { 96 | "codemirror_mode": { 97 | "name": "ipython", 98 | "version": 3 99 | }, 100 | "file_extension": ".py", 101 | "mimetype": "text/x-python", 102 | "name": "python", 103 | "nbconvert_exporter": "python", 104 | "pygments_lexer": "ipython3", 105 | "version": "3.11.0" 106 | } 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 2 110 | } 111 | -------------------------------------------------------------------------------- /gradio_mcp_client.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from mcp.client.sse import sse_client 3 | from mcp import ClientSession 4 | from openai import AsyncOpenAI 5 | import json 6 | 7 | 8 | SYSTEM_PROMPT = """你是一个AI助手。 9 | 你可以使用 MCP 服务器提供的工具来完成任务。 10 | MCP 服务器会动态提供工具,你需要先检查当前可用的工具。 11 | 12 | 在使用 MCP 工具时,请遵循以下步骤: 13 | 1、根据任务需求选择合适的工具 14 | 2、按照工具的参数要求提供正确的参数 15 | 3、观察工具的返回结果,并根据结果决定下一步操作 16 | 4、工具可能会发生变化,比如新增工具或现有工具消失 17 | 18 | 请遵循以下指南: 19 | - 使用工具时,确保参数符合工具的文档要求 20 | - 如果出现错误,请理解错误原因并尝试用修正后的参数重新调用 21 | - 按照任务需求逐步完成,优先选择最合适的工具 22 | - 如果需要连续调用多个工具,请一次只调用一个工具并等待结果 23 | 24 | 请清楚地向用户解释你的推理过程和操作步骤。 25 | """ 26 | 27 | async def query(query: str, mcp_server_url, model_name, base_url, api_key, temperature): 28 | 29 | client = AsyncOpenAI( 30 | base_url=base_url, api_key=api_key 31 | ) 32 | 33 | async with sse_client(mcp_server_url) as streams: 34 | 35 | async with ClientSession(*streams) as session: 36 | 37 | await session.initialize() 38 | 39 | response = await session.list_tools() 40 | messages = [ 41 | { 42 | "role": "system", 43 | "content": SYSTEM_PROMPT 44 | }, 45 | { 46 | "role": "user", 47 | "content": query 48 | } 49 | ] 50 | 51 | available_tools = [{ 52 | "type": "function", 53 | "function": { 54 | "name": tool.name, 55 | "description": tool.description, 56 | "parameters": tool.inputSchema 57 | } 58 | } for tool in response.tools] 59 | 60 | 61 | # 初始化 LLM API 调用 62 | response = await client.chat.completions.create( 63 | model=model_name, 64 | temperature=temperature, 65 | messages=messages, 66 | tools=available_tools, 67 | 
stream=True 68 | ) 69 | # message = response.choices[0].message 70 | full_response = "" 71 | tool_call_text = "" 72 | 73 | while True: 74 | func_call_list = [] 75 | async for chunk in response: 76 | if chunk.choices[0].delta.content: 77 | content = chunk.choices[0].delta.content 78 | full_response += content 79 | yield tool_call_text, full_response.replace('', "").replace('', "") # 流式返回当前累积内容 80 | elif chunk.choices[0].delta.tool_calls: 81 | 82 | for tcchunk in chunk.choices[0].delta.tool_calls: 83 | if len(func_call_list) <= tcchunk.index: 84 | func_call_list.append({ 85 | "id": "", 86 | "name": "", 87 | "type": "function", 88 | "function": { "name": "", "arguments": "" } 89 | }) 90 | tc = func_call_list[tcchunk.index] 91 | if tcchunk.id: 92 | tc["id"] += tcchunk.id 93 | if tcchunk.function.name: 94 | tc["function"]["name"] += tcchunk.function.name 95 | if tcchunk.function.arguments: 96 | tc["function"]["arguments"] += tcchunk.function.arguments 97 | 98 | 99 | if not func_call_list: 100 | break 101 | 102 | full_response += '🛠️ 调用工具...\n' 103 | yield tool_call_text, full_response.replace('', "").replace('', "") 104 | 105 | for tool_call in func_call_list: 106 | print(tool_call) 107 | tool_name = tool_call['function']['name'] 108 | if tool_call['function']['arguments']: 109 | tool_args = json.loads(tool_call['function']['arguments']) 110 | else: 111 | tool_args = {} 112 | 113 | # 执行工具调用 114 | result = await session.call_tool(tool_name, tool_args) 115 | # 记录调用详情到状态栏 116 | tool_call_text += f"✅ 工具返回: {tool_name}\n参数: {tool_args}\n结果: {str(result.content)}\n---\n" 117 | yield tool_call_text, full_response.replace('', "").replace('', "") # 先更新状态栏 118 | 119 | # 将工具调用和结果添加到消息历史 120 | messages.append({ 121 | "role": "assistant", 122 | "tool_calls": [ 123 | { 124 | "id": tool_call['id'], 125 | "type": "function", 126 | "function": { 127 | "name": tool_name, 128 | "arguments": json.dumps(tool_args) 129 | } 130 | } 131 | ] 132 | }) 133 | messages.append({ 134 | "role": "tool", 135 | "tool_call_id": tool_call['id'], 136 | "content": str(result.content) 137 | }) 138 | 139 | # 将工具调用的结果交给 LLM 140 | response = await client.chat.completions.create( 141 | model=model_name, 142 | temperature=temperature, 143 | messages=messages, 144 | tools=available_tools, 145 | stream=True) 146 | 147 | 148 | 149 | with gr.Blocks() as demo: 150 | gr.Markdown("## MCP 客户端") 151 | 152 | # 左右分栏布局 153 | with gr.Row(): 154 | # 左侧参数输入栏 155 | with gr.Column(scale=1): 156 | gr.Markdown("### 🧠 大模型配置") 157 | model_name = gr.Textbox( 158 | label="模型名称" 159 | ) 160 | base_url = gr.Textbox( 161 | label="API 地址" 162 | ) 163 | api_key = gr.Textbox( 164 | label="API Key", 165 | type="password" 166 | ) 167 | temperature = gr.Number( 168 | label="温度", 169 | value=0.0, 170 | ) 171 | 172 | gr.Markdown("### 🌐 MCP 服务配置") 173 | mcp_server_url = gr.Textbox( 174 | label="MCP 服务地址" 175 | ) 176 | 177 | # 工具调用状态面板 178 | tool_status = gr.Textbox( 179 | label="🛠️ 工具调用记录", 180 | lines=10, 181 | interactive=False, 182 | autoscroll=True, 183 | ) 184 | 185 | # 右侧输出区域 186 | with gr.Column(scale=2): 187 | gr.Markdown("### 💬 交互窗口") 188 | result_display = gr.Textbox( 189 | label="🧠 模型输出", 190 | lines=35, 191 | show_copy_button=True, 192 | ) 193 | 194 | # 最底部问题输入区 195 | with gr.Row(): 196 | query_input = gr.Textbox( 197 | label="❓ 输入你的问题", 198 | placeholder="输入问题后点击生成按钮...", 199 | scale=4 200 | ) 201 | generate = gr.Button( 202 | "🚀 开始生成", 203 | scale=1, 204 | variant="primary" 205 | ) 206 | 207 | generate.click(fn=query, inputs=[query_input, mcp_server_url, 
model_name, base_url, api_key, temperature], outputs=[tool_status, result_display]) 208 | 209 | 210 | 211 | 212 | 213 | if __name__ == "__main__": 214 | demo.queue().launch(server_name='0.0.0.0', allowed_paths=['./']) 215 | 216 | -------------------------------------------------------------------------------- /grpo_from_scratch/grpo_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/grpo_from_scratch/grpo_loss.png -------------------------------------------------------------------------------- /grpo_from_scratch/reward_func.py: -------------------------------------------------------------------------------- 1 | import re 2 | def extract_answer(text): 3 | answer = text.split("")[-1] 4 | answer = answer.split("")[0] 5 | return answer.strip() 6 | 7 | def mark_num(text): 8 | reward = 0 9 | if text.count("\n") == 1: 10 | reward += 0.125 11 | 12 | if text.count("\n") == 1: 13 | reward += 0.125 14 | 15 | if text.count("\n") == 1: 16 | reward += 0.125 17 | 18 | if text.count("\n") == 1: 19 | reward += 0.125 20 | return reward 21 | 22 | # 生成答案是否正确的奖励 23 | def correctness_reward(prompts, responses, answers): 24 | 25 | extracted_responses = [extract_answer(r) for r in responses] 26 | print(f"问题:\n{prompts[0]}", f"\n答案:\n{answers[0]}", f"\n模型输出:\n{responses[0]}", f"\n提取后的答案:\n{extracted_responses[0]}") 27 | return [2.0 if response == str(ans) else 0.0 for response, ans in zip(extracted_responses, answers)] 28 | 29 | # 生成答案是否是数字的奖励(单纯依赖结果是否正确进行奖励,条件很苛刻,会导致奖励比较稀疏,模型难以收敛,所以加上答案是否是数字的奖励,虽然答案错误,但是至少生成的是数字(对于数学问题),也要给予适当奖励) 30 | def digit_reward(prompts, responses, answers): 31 | extracted_responses = [extract_answer(r) for r in responses] 32 | return [0.5 if response.isdigit() else 0.0 for response in extracted_responses] 33 | 34 | # 格式奖励 35 | def hard_format_reward(prompts, responses, answers): 36 | pattern = r"^\n.*?\n\n\n.*?\n\n$" 37 | matches = [re.match(pattern, response) for response in responses] 38 | return [0.5 if match else 0.0 for match in matches] 39 | 40 | # 标记奖励(改善格式奖励稀疏问题) 41 | def mark_reward(prompts, responses, answers): 42 | return [mark_num(response) for response in responses] -------------------------------------------------------------------------------- /grpo_from_scratch/test.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | client = OpenAI(api_key='ww', base_url='http://10.250.2.24:8036/v1') 3 | 4 | SYSTEM_PROMPT = """ 5 | 按照如下格式回答问题: 6 | 7 | 你的思考过程 8 | 9 | 10 | 你的回答 11 | 12 | """ 13 | 14 | completion = client.chat.completions.create( 15 | model = 'qwen1.5b', 16 | 17 | temperature=0.0, 18 | logprobs = True, 19 | messages=[ 20 | { 21 | "role": "system", 22 | "content": SYSTEM_PROMPT}, 23 | { 24 | "role": "user", 25 | "content": "天上五只鸟,地上五只鸡,一共几只鸭", 26 | } 27 | ], 28 | ) 29 | print(completion.choices[0].message.content) 30 | -------------------------------------------------------------------------------- /knowledge_distillation_llm/README.md: -------------------------------------------------------------------------------- 1 | ## 原理 2 | 将待压缩的模型作为教师模型,将体积更小的模型作为学生模型,让学生模型在教师模型的监督下进行优化,将学生模型学习到教师模型的概率分布,通过kl散度进行控制。 3 | ## 方法 4 | 对于大模型的知识蒸馏,主要分为两种: 5 | ### 其一、黑盒知识蒸馏。 6 | 使用大模型生成数据,通过这些数据去微调更小的模型,来达到蒸馏的目的。缺点是蒸馏效率低,优点是实现简单。 7 | ### 其二、白盒知识蒸馏。 8 | 获取学生模型和教师模型的输出概率分布(或者中间隐藏层的概率分布),通过kl散度将学生模型的概率分布向教师模型对齐。 9 | 下面主要介绍和测试白盒知识蒸馏: 10 | 白盒知识蒸馏主要在于模型分布的对齐,模型分布对齐主要依赖kl散度,对于kl散度的使用又有如下几种方式: 11 | #### 
其一、前向kl散度。 12 | 也就是我们经常说的kl散度。\ 13 | ![fkl_formula](./images/fkl_formula.png "可选标题")\ 14 | p为教师模型的概率分布,q为学生模型的概率分布,minillm论文中提到前向kl散度可能会使学生模型高估教师模型中概率比较低的位置,结合公式来看,当p增大时,为了使得kl散度小,则q也需要增大,但是当p趋于0时,无论q取任何值,kl散度都比较小,因为此时p(x)log((p(x)/q(x)))的大小主要受p(x)控制,这样起不到优化q分布的效果,可能会使q分布高估p分布中概率低的位置。 15 | 下图展示了前向kl散度的拟合情况,前向kl散度是一种均值搜索,更倾向于拟合多峰 16 | ![fkl](./images/fkl.png "可选标题") 17 | 18 | #### 其二、反向kl散度。 19 | 为了缓解前向kl散度的缺点,提出了反向kl散度。\ 20 | ![rkl_formula](./images/rkl_formula.png "可选标题")\ 21 | p为教师模型的概率分布,q为学生模型的概率分布,当p趋于零时,为了使kl散度小,q也需趋于0。 22 | minillm论文中说对于大模型的知识蒸馏,反向kl散度优于前向kl散度,但是也有其他论文说反向kl散度不一定比前向kl散度更优,实际选择中,可能要基于实验驱动。 23 | 反向kl散度是一种模式搜索,更倾向于拟合单个峰 24 | ![rkl](./images/rkl.png "可选标题") 25 | 26 | #### 其三、偏向前kl散度。 27 | 对学生模型的分布和教师模型的分布进行加权作为学生模型的分布。 28 | #### 其四、偏向反kl散度。 29 | 对学生模型的分布和教师模型的分布进行加权作为教师模型的分布。 30 | ## 三、测试 31 | qwen2.5-3b作为教师模型,qwen2.5-0.5b作为学生模型\ 32 | 流程如下:\ 33 | 1、将qwen2.5-3b模型在指定数据集上微调(训练数据5000条,测试数据1000条,测试准确度为81.1%)\ 34 | 2、探索如下三种方案下的蒸馏效果(均使用前向kl散度):\ 35 | 2.1 不微调学生模型+kl散度损失\ 36 | 蒸馏1个epoch,准确度70.5%\ 37 | 蒸馏2个epoch,准确度73%\ 38 | 2.2 微调学生模型(模型准确度80.3%)+kl散度损失\ 39 | 蒸馏2个epoch,准确度61.9%\ 40 | 2.3 不微调学生模型+kl散度损失和交叉熵损失加权\ 41 | 蒸馏2个epoch,70.5%\ 42 | 3、上述实验中只使用kl散度的效果最好,如下实验中使用kl散度的变种进行测试,经过测试,效果都不如前向kl散度效果好。\ 43 | 3.1 反向kl散度\ 44 | 准确率只有54%\ 45 | 3.2 偏向前向kl散度\ 46 | 损失下降异常,效果很差,不断重复输出。\ 47 | 由于资源和时间的限制,所有测试均保持相同的超参数,未针对不同损失设置不同超参数。 -------------------------------------------------------------------------------- /knowledge_distillation_llm/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | import os 8 | import pandas as pd 9 | 10 | from torch.utils.data import IterableDataset, Dataset 11 | import json 12 | import numpy as np 13 | from transformers import PreTrainedModel 14 | from transformers.modeling_outputs import CausalLMOutputWithPast 15 | from transformers import PretrainedConfig 16 | from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig 17 | 18 | class SFTDataset(Dataset): 19 | def __init__(self, data_path, tokenizer, max_seq_len): 20 | super().__init__() 21 | self.data_path = data_path 22 | self.tokenizer = tokenizer 23 | self.max_seq_len = max_seq_len 24 | self.padding_id = tokenizer.pad_token_id 25 | with open(self.data_path, 'r', encoding='utf-8') as f: 26 | self.data = json.load(f) 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, index): 32 | line = self.data[index] 33 | instruction_text = line['instruction'] 34 | input_text = line['input'] 35 | output_text = line['output'] 36 | query = instruction_text + input_text 37 | answer = output_text + self.tokenizer.eos_token 38 | messages = [] 39 | messages.append({'role': 'user', 'content': query}) 40 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) 41 | 42 | prompt_input_ids = self.tokenizer.encode(prompt) 43 | answer_input_ids = self.tokenizer.encode(answer) 44 | 45 | input_ids = prompt_input_ids + answer_input_ids 46 | labels = [-100] * len(prompt_input_ids) + answer_input_ids 47 | attention_mask = [1] * len(input_ids) 48 | text_len = len(input_ids) 49 | 50 | if text_len > self.max_seq_len: 51 | input_ids = input_ids[:self.max_seq_len] 52 | labels = labels[:self.max_seq_len] 53 | attention_mask = attention_mask[:self.max_seq_len] 54 | else: 55 | 
input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - text_len) 56 | labels = labels + [-100] * (self.max_seq_len - text_len) 57 | attention_mask = attention_mask + [0] * (self.max_seq_len - text_len) 58 | 59 | # input_ids = input_ids[:-1] 60 | # labels = labels[1:] 61 | return {'input_ids': torch.tensor(input_ids), 'attention_mask':torch.tensor(attention_mask), 'labels': torch.tensor(labels)} -------------------------------------------------------------------------------- /knowledge_distillation_llm/images/fkl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/knowledge_distillation_llm/images/fkl.png -------------------------------------------------------------------------------- /knowledge_distillation_llm/images/fkl_formula.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/knowledge_distillation_llm/images/fkl_formula.png -------------------------------------------------------------------------------- /knowledge_distillation_llm/images/rkl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/knowledge_distillation_llm/images/rkl.png -------------------------------------------------------------------------------- /knowledge_distillation_llm/images/rkl_formula.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/knowledge_distillation_llm/images/rkl_formula.png -------------------------------------------------------------------------------- /knowledge_distillation_llm/train.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator 2 | from peft import LoraConfig, get_peft_model, TaskType 3 | from peft import PeftModel 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from transformers import Trainer, TrainingArguments 8 | from dataset import SFTDataset 9 | from utils import compute_fkl, compute_rkl, compute_skewed_fkl, compute_skewed_rkl 10 | 11 | 12 | class KGTrainer(Trainer): 13 | 14 | def __init__( 15 | self, 16 | model = None, 17 | teacher_model = None, 18 | if_use_entropy = False, 19 | args = None, 20 | data_collator = None, 21 | train_dataset = None, 22 | eval_dataset = None, 23 | tokenizer = None, 24 | model_init = None, 25 | compute_metrics = None, 26 | callbacks = None, 27 | optimizers = (None, None), 28 | preprocess_logits_for_metrics = None, 29 | ): 30 | super().__init__( 31 | model, 32 | args, 33 | data_collator, 34 | train_dataset, 35 | eval_dataset, 36 | tokenizer, 37 | model_init, 38 | compute_metrics, 39 | callbacks, 40 | optimizers, 41 | preprocess_logits_for_metrics, 42 | ) 43 | self.teacher_model = teacher_model 44 | self.if_use_entropy = if_use_entropy 45 | 46 | 47 | def compute_loss(self, model, inputs, return_outputs=False): 48 | 49 | outputs = model(**inputs) 50 | with torch.no_grad(): 51 | teacher_outputs = self.teacher_model(**inputs) 52 | 53 | loss = outputs.loss 54 | logits = outputs.logits 55 | teacher_logits = teacher_outputs.logits 56 | 57 | # 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断 58 | if 
logits.shape[-1] != teacher_logits.shape[-1]: 59 | # gap = teacher_logits.shape[-1] - logits.shape[-1] 60 | # if gap > 0: 61 | # pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device) 62 | # logits = torch.cat([logits, pad_logits], dim=-1) 63 | 64 | teacher_logits = teacher_logits[:, :, :logits.shape[-1]] 65 | 66 | labels = inputs['labels'] 67 | kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0) 68 | 69 | if self.if_use_entropy: 70 | loss_total = 0.5 * kl + 0.5 * loss 71 | else: 72 | loss_total = kl 73 | 74 | return (loss_total, outputs) if return_outputs else loss_total 75 | 76 | 77 | if __name__ == '__main__': 78 | 79 | # 学生模型 80 | model = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-Instruct") 81 | 82 | lora_config = LoraConfig( 83 | r=8, 84 | lora_alpha=256, 85 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 86 | lora_dropout=0.1, 87 | task_type=TaskType.CAUSAL_LM) 88 | # 使用lora方法训练 89 | model = get_peft_model(model, lora_config) 90 | model.cuda() 91 | print(model.print_trainable_parameters()) 92 | 93 | tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Instruct") 94 | 95 | # 教师模型,在给定数据上通过lora微调 96 | teacher_model = AutoModelForCausalLM.from_pretrained("Qwen2.5-7B-Instruct") 97 | # 是否加载lora模型 98 | lora_path = 'qwen2.5_7b/lora/sft' 99 | teacher_model = PeftModel.from_pretrained(teacher_model, lora_path) 100 | teacher_model.cuda() 101 | teacher_model.eval() 102 | 103 | args = TrainingArguments(output_dir='./results', 104 | num_train_epochs=10, 105 | do_train=True, 106 | per_device_train_batch_size=2, 107 | gradient_accumulation_steps=16, 108 | logging_steps=10, 109 | report_to='tensorboard', 110 | save_strategy='epoch', 111 | save_total_limit=10, 112 | bf16=True, 113 | learning_rate=0.0005, 114 | lr_scheduler_type='cosine', 115 | dataloader_num_workers=8, 116 | dataloader_pin_memory=True) 117 | data_collator = DefaultDataCollator() 118 | dataset = SFTDataset('data.json', tokenizer=tokenizer, max_seq_len=512) 119 | trainer = KGTrainer(model=model, 120 | teacher_model=teacher_model, 121 | if_use_entropy = True, 122 | args=args, 123 | train_dataset=dataset, 124 | tokenizer=tokenizer, 125 | data_collator=data_collator) 126 | # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True 127 | trainer.train(resume_from_checkpoint=False) 128 | trainer.save_model('./saves') 129 | trainer.save_state() 130 | 131 | 132 | -------------------------------------------------------------------------------- /knowledge_distillation_llm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # 计算前向kl散度 4 | def compute_fkl( 5 | logits, 6 | teacher_logits, 7 | target, 8 | padding_id, 9 | reduction="sum", 10 | temp = 1.0, 11 | 12 | ): 13 | logits = logits / temp 14 | teacher_logits = teacher_logits / temp 15 | 16 | log_probs = torch.log_softmax(logits, -1, dtype=torch.float32) 17 | teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) 18 | teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) 19 | kl = (teacher_probs * (teacher_log_probs - log_probs)) 20 | kl = kl.sum(-1) 21 | if reduction == "sum": 22 | pad_mask = target.eq(padding_id) 23 | kl = kl.masked_fill_(pad_mask, 0.0) 24 | kl = kl.sum() 25 | 26 | return kl 27 | # 计算反向kl散度 28 | def compute_rkl( 29 | logits, 30 | teacher_logits, 31 | target, 32 | padding_id, 33 | reduction="sum", 34 | temp = 1.0 35 | ): 36 | logits = logits / temp 37 | 
teacher_logits = teacher_logits / temp 38 | 39 | probs = torch.softmax(logits, -1, dtype=torch.float32) 40 | log_probs = torch.log_softmax(logits, -1, dtype=torch.float32) 41 | teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) 42 | kl = (probs * (log_probs - teacher_log_probs)) 43 | kl = kl.sum(-1) 44 | if reduction == "sum": 45 | pad_mask = target.eq(padding_id) 46 | kl = kl.masked_fill_(pad_mask, 0.0) 47 | kl = kl.sum() 48 | return kl 49 | 50 | # 计算偏向前kl散度 51 | def compute_skewed_fkl( 52 | logits, 53 | teacher_logits, 54 | target, 55 | padding_id, 56 | reduction="sum", 57 | temp = 1.0, 58 | skew_lambda = 0.1 59 | ): 60 | logits = logits / temp 61 | teacher_logits = teacher_logits / temp 62 | 63 | probs = torch.softmax(logits, -1, dtype=torch.float32) 64 | teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) 65 | mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * probs 66 | mixed_log_probs = torch.log(mixed_probs) 67 | teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) 68 | kl = (teacher_probs * (teacher_log_probs - mixed_log_probs)) 69 | kl = kl.sum(-1) 70 | if reduction == "sum": 71 | pad_mask = target.eq(padding_id) 72 | kl = kl.masked_fill_(pad_mask, 0.0) 73 | kl = kl.sum() 74 | 75 | 76 | return kl 77 | # 计算偏向反kl散度 78 | def compute_skewed_rkl( 79 | logits, 80 | teacher_logits, 81 | target, 82 | padding_id, 83 | reduction="sum", 84 | temp = 1.0, 85 | skew_lambda = 0.1 86 | ): 87 | logits = logits / temp 88 | teacher_logits = teacher_logits / temp 89 | 90 | probs = torch.softmax(logits, -1, dtype=torch.float32) 91 | teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) 92 | mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * probs 93 | mixed_log_probs = torch.log(mixed_probs) 94 | log_probs = torch.log_softmax(logits, -1, dtype=torch.float32) 95 | kl = (probs * (log_probs - mixed_log_probs)) 96 | kl = kl.sum(-1) 97 | 98 | if reduction == "sum": 99 | pad_mask = target.eq(padding_id) 100 | kl = kl.masked_fill_(pad_mask, 0.0) 101 | kl = kl.sum() 102 | 103 | 104 | return kl -------------------------------------------------------------------------------- /pdf2markdown/README.md: -------------------------------------------------------------------------------- 1 | # pdf2markdown 2 | ## 介绍 3 | 对gptpdf进行二次开发,使用版面分析模型替换原项目中对pdf的处理逻辑,标记出表格和图片,通过多模态大模型生成markdown格式内容。 4 | 5 | ## 使用方法 6 | 1、将代码中的模型路径换成自己的本地路径 7 | 8 | 2、运行pdf2markdown.py文件 9 | -------------------------------------------------------------------------------- /pdf2markdown/pdf2markdown.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List, Tuple, Optional, Dict 4 | import logging 5 | import cv2 6 | 7 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 8 | import fitz 9 | import shapely.geometry as sg 10 | from shapely.geometry.base import BaseGeometry 11 | from shapely.validation import explain_validity 12 | import concurrent.futures 13 | 14 | import torch 15 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 16 | from qwen_vl_utils import process_vision_info 17 | import numpy as np 18 | import fitz 19 | import logging 20 | import os 21 | output_dir = 'output' 22 | os.makedirs(output_dir, exist_ok=True) 23 | from PIL import Image 24 | 25 | from rapid_layout import RapidLayout, VisLayout 26 | 27 | layout_engine = RapidLayout(conf_thres=0.5, 
model_type="pp_layout_cdla") 28 | 29 | model = Qwen2VLForConditionalGeneration.from_pretrained("/home/user/Downloads/Qwen2-VL-7B-Instruct/", 30 | torch_dtype=torch.bfloat16, 31 | device_map="auto") 32 | 33 | min_pixels = 256*28*28 34 | max_pixels = 1280*28*28 35 | processor = AutoProcessor.from_pretrained("/home/user/Downloads/Qwen2-VL-7B-Instruct/", 36 | min_pixels=min_pixels, 37 | max_pixels=max_pixels) 38 | 39 | # This Default Prompt Using Chinese and could be changed to other languages. 40 | DEFAULT_PROMPT = """使用markdown语法,将图片中识别到的文字转换为markdown格式输出。你必须做到: 41 | 1. 输出和使用识别到的图片的相同的语言,例如,识别到英语的字段,输出的内容必须是英语。 42 | 2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。 43 | 3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、忽略掉长直线、忽略掉页码。 44 | 再次强调,不要解释和输出无关的文字,直接输出图片中的内容。 45 | """ 46 | DEFAULT_RECT_PROMPT = """图片中用带颜色的矩形框和名称(%s)标注出了一些区域。如果区域是表格或者图片,使用 ![]() 的形式插入到输出内容中,否则直接输出文字内容。 47 | """ 48 | DEFAULT_ROLE_PROMPT = """你是一个PDF文档解析器,使用markdown和latex语法输出图片的内容。 49 | """ 50 | 51 | def _parse_pdf_to_images(pdf_path: str, output_dir: str = './output') -> List[Tuple[str, List[str]]]: 52 | image_infos = [] 53 | pdf_document = fitz.open(pdf_path) 54 | for page_index, page in enumerate(pdf_document): 55 | rect_images = [] 56 | logging.info(f'parse page: {page_index}') 57 | # 保存页面为图片 58 | pix = page.get_pixmap(matrix=fitz.Matrix(4, 4)) 59 | pix = Image.frombytes('RGB', [pix.width, pix.height], pix.samples) 60 | boxes, scores, class_names, elapse = layout_engine(pix) 61 | for index, (class_name, box) in enumerate(zip(class_names, boxes)): 62 | if class_name == 'figure' or class_name == 'table': 63 | name = f'{page_index}_{index}.png' 64 | sub_pix = pix.crop(box) 65 | sub_pix.save(os.path.join(output_dir, name)) 66 | rect_images.append(name) 67 | 68 | boxes_ = [] 69 | scores_ = [] 70 | class_names_ = [] 71 | for i, (class_name, box, score) in enumerate(zip(class_names, boxes, scores)): 72 | if class_name == 'figure' or class_name == 'table': 73 | boxes_.append(box) 74 | scores_.append(score) 75 | class_name = f'{page_index}_{i}.png' 76 | class_names_.append(class_name) 77 | 78 | page_image = os.path.join(output_dir, f'{page_index}.png') 79 | pix = np.array(pix) 80 | pix = cv2.cvtColor(pix, cv2.COLOR_RGB2BGR) 81 | print(boxes_, scores_, class_names_) 82 | ploted_img = VisLayout.draw_detections(pix, boxes_, scores_, class_names_) 83 | if ploted_img is not None: 84 | cv2.imwrite(page_image, ploted_img) 85 | # ploted_img.save(page_image) 86 | image_infos.append((page_image, rect_images)) 87 | pdf_document.close() 88 | return image_infos 89 | 90 | 91 | def _gpt_parse_images( 92 | image_infos: List[Tuple[str, List[str]]], 93 | prompt_dict: Optional[Dict] = None, 94 | output_dir: str = './', 95 | api_key: Optional[str] = None, 96 | base_url: Optional[str] = None, 97 | # model: str = 'gpt-4o', 98 | verbose: bool = False, 99 | gpt_worker: int = 1, 100 | **args 101 | ) -> str: 102 | """ 103 | Parse images to markdown content. 
104 | """ 105 | 106 | if isinstance(prompt_dict, dict) and 'prompt' in prompt_dict: 107 | prompt = prompt_dict['prompt'] 108 | logging.info("prompt is provided, using user prompt.") 109 | else: 110 | prompt = DEFAULT_PROMPT 111 | logging.info("prompt is not provided, using default prompt.") 112 | if isinstance(prompt_dict, dict) and 'rect_prompt' in prompt_dict: 113 | rect_prompt = prompt_dict['rect_prompt'] 114 | logging.info("rect_prompt is provided, using user prompt.") 115 | else: 116 | rect_prompt = DEFAULT_RECT_PROMPT 117 | logging.info("rect_prompt is not provided, using default prompt.") 118 | if isinstance(prompt_dict, dict) and 'role_prompt' in prompt_dict: 119 | role_prompt = prompt_dict['role_prompt'] 120 | logging.info("role_prompt is provided, using user prompt.") 121 | else: 122 | role_prompt = DEFAULT_ROLE_PROMPT 123 | logging.info("role_prompt is not provided, using default prompt.") 124 | 125 | def _process_page(index: int, image_info: Tuple[str, List[str]]) -> Tuple[int, str]: 126 | logging.info(f'gpt parse page: {index}') 127 | 128 | # agent = Agent(role=role_prompt, api_key=api_key, base_url=base_url, disable_python_run=True, model=model, **args) 129 | page_image, rect_images = image_info 130 | local_prompt = prompt 131 | local_prompt = role_prompt + local_prompt 132 | if rect_images: 133 | local_prompt += rect_prompt + ', '.join(rect_images) 134 | # content = agent.run([local_prompt, {'image': page_image}], display=verbose) 135 | messages = [ 136 | { 137 | "role": "user", 138 | "content": [ 139 | { 140 | "type": "image", 141 | "image": page_image, 142 | }, 143 | {"type": "text", "text": local_prompt}, 144 | ], 145 | } 146 | ] 147 | text = processor.apply_chat_template( 148 | messages, tokenize=False, add_generation_prompt=True 149 | ) 150 | print(text) 151 | image_inputs, video_inputs = process_vision_info(messages) 152 | inputs = processor( 153 | text=[text], 154 | images=image_inputs, 155 | videos=video_inputs, 156 | padding=True, 157 | return_tensors="pt", 158 | ) 159 | inputs = inputs.to("cuda") 160 | 161 | # Inference: Generation of the output 162 | generated_ids = model.generate(**inputs, max_new_tokens=2000, num_beams=1) 163 | generated_ids_trimmed = [ 164 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 165 | ] 166 | output_text = processor.batch_decode( 167 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 168 | ) 169 | return index, output_text 170 | 171 | contents = [None] * len(image_infos) 172 | with concurrent.futures.ThreadPoolExecutor(max_workers=gpt_worker) as executor: 173 | futures = [executor.submit(_process_page, index, image_info) for index, image_info in enumerate(image_infos)] 174 | for future in concurrent.futures.as_completed(futures): 175 | index, content = future.result() 176 | content = content[0] 177 | print(content) 178 | 179 | # 在某些情况下大模型还是会输出 ```markdown ```字符串 180 | if '```markdown' in content: 181 | content = content.replace('```markdown\n', '') 182 | last_backticks_pos = content.rfind('```') 183 | if last_backticks_pos != -1: 184 | content = content[:last_backticks_pos] + content[last_backticks_pos + 3:] 185 | 186 | contents[index] = content 187 | 188 | output_path = os.path.join(output_dir, 'output.md') 189 | with open(output_path, 'w', encoding='utf-8') as f: 190 | f.write('\n\n'.join(contents)) 191 | 192 | return '\n\n'.join(contents) 193 | 194 | 195 | def parse_pdf( 196 | pdf_path: str, 197 | output_dir: str = './', 198 | prompt: Optional[Dict] = None, 199 | 
api_key: Optional[str] = None, 200 | base_url: Optional[str] = None, 201 | model: str = 'gpt-4o', 202 | verbose: bool = False, 203 | gpt_worker: int = 1, 204 | **args 205 | ) -> Tuple[str, List[str]]: 206 | """ 207 | Parse a PDF file to a markdown file. 208 | """ 209 | if not os.path.exists(output_dir): 210 | os.makedirs(output_dir) 211 | 212 | image_infos = _parse_pdf_to_images(pdf_path, output_dir=output_dir) 213 | print(image_infos) 214 | content = _gpt_parse_images( 215 | image_infos=image_infos, 216 | output_dir=output_dir, 217 | prompt_dict=prompt, 218 | api_key=api_key, 219 | base_url=base_url, 220 | model=model, 221 | verbose=verbose, 222 | gpt_worker=gpt_worker, 223 | **args 224 | ) 225 | 226 | all_rect_images = [] 227 | # remove all rect images 228 | if not verbose: 229 | for page_image, rect_images in image_infos: 230 | if os.path.exists(page_image): 231 | os.remove(page_image) 232 | all_rect_images.extend(rect_images) 233 | return content, all_rect_images 234 | 235 | 236 | result = parse_pdf( 237 | pdf_path='/home/user/wyf/test.pdf', 238 | output_dir="./output", 239 | verbose=True, 240 | gpt_worker=1 241 | ) -------------------------------------------------------------------------------- /ppo_from_scratch/ppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/ppo_from_scratch/ppo.png -------------------------------------------------------------------------------- /s1_from_scratch/README.md: -------------------------------------------------------------------------------- 1 | # S1模型 2 | 3 | ## 概述 4 | 以**不足50美元的云计算成本**完成训练,在数学推理和代码生成等任务中展现出与OpenAI的o1、DeepSeek的R1等顶尖模型相媲美的性能。 5 | 6 | --- 7 | 8 | ## 实现流程 9 | ### 数据来源 10 | 16 个不同的来源收集59029个问题,主要包括数学问题。 11 | 通过Google Gemini Flash Thinking API生成推理轨迹,得到三元组(问题,推理,答案)(蒸馏Gemini,之后使用DeepSeek r1重新生成轨迹,性能进一步提升) 12 | ### 数据处理 13 | 1、质量过滤:剔除调用API报错的样本,保留54116个,剔除格式有问题的样本,保留51,581个\ 14 | 2、 研究人员筛选出高质量且无需进一步筛选的数据共384个样本作为最终1,000 个样本的一部分\ 15 | 3、难度过滤:Qwen2.5-7B-Instruct 和 Qwen2.5-32B-Instruct生成问题答案,Claude 3.5 Sonnet评估正确性,去除模型回答正确和推理轨迹短的样本,保留24496个\ 16 | 4、Claude 3.5 Sonnet将数据按数学主题分类(MSC),覆盖50个领域。随机选择一个领域,按照推理链长短进行排序,优先选择长度较长的,不断迭代,直到得到1000个样本(加上之前的384个) 17 | ### 微调 18 | 使用1000条数据微调Qwen2.5-32B-Instruct 19 | ### 测试时扩展 20 | **预算强制**控制模型思考过程,进一步提升模型的性能 -------------------------------------------------------------------------------- /s1_from_scratch/generate.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from transformers import AutoTokenizer 3 | 4 | tokenizer = AutoTokenizer.from_pretrained("/home/user/Downloads/DeepSeek-R1-Distill-Qwen-1.5B") 5 | llm = LLM(model="/home/user/Downloads/DeepSeek-R1-Distill-Qwen-1.5B", gpu_memory_utilization=0.15) 6 | 7 | sampling_params = SamplingParams( 8 | temperature=0, 9 | max_tokens=32768, 10 | skip_special_tokens=False 11 | ) 12 | 13 | prompt = '9.11和9.8谁大?' 14 | prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n" 15 | 16 | # 模型原始输出部分 17 | outputs = llm.generate( 18 | prompt, 19 | sampling_params 20 | ) 21 | print(f'原始输出:{prompt}{outputs[0].outputs[0].text}') 22 | print('+'*20) 23 | 24 | sampling_params = SamplingParams( 25 | temperature=0, 26 | max_tokens=32768, 27 | stop='', 28 | skip_special_tokens=False 29 | ) 30 | 31 | outputs = llm.generate( 32 | prompt, 33 | sampling_params 34 | ) 35 | wait = 'Wait' 36 | for i in range(1): 37 | prompt += outputs[0].outputs[0].text + wait 38 | 39 | outputs = llm.generate( 40 | prompt, 41 | sampling_params 42 | ) 43 | 44 | print(f'wait后的输出:{prompt}{outputs[0].outputs[0].text}') 45 | print('+'*20) 46 | prompt += outputs[0].outputs[0].text 47 | stop_token_ids = tokenizer("<|im_end|>")["input_ids"] 48 | sampling_params = SamplingParams( 49 | max_tokens=32768, 50 | min_tokens=0, 51 | stop_token_ids=stop_token_ids, 52 | skip_special_tokens=False, 53 | temperature=0.0, 54 | ) 55 | outputs = llm.generate( 56 | prompt, 57 | sampling_params=sampling_params, 58 | ) 59 | 60 | print(f'最后的输出:{prompt}{outputs[0].outputs[0].text}') 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /s1_from_scratch/s1_train.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments 2 | from transformers import DefaultDataCollator 3 | from peft import LoraConfig, get_peft_model, TaskType 4 | from datasets import load_dataset 5 | from torch.utils.data import Dataset 6 | import torch 7 | 8 | THINK_TAG = '<|im_start|>think' 9 | ANSWER_TAG = '<|im_start|>answer' 10 | END_TAG = '<|im_end|>' 11 | 12 | class S1Dataset(Dataset): 13 | def __init__(self, ds, tokenizer, max_length=4096): 14 | self.ds = ds 15 | self.max_length = max_length 16 | self.tokenizer = tokenizer 17 | 18 | def __getitem__(self, index): 19 | sample = self.ds[index] 20 | question = sample['question'] 21 | gemini_thinking_trajectory = sample['gemini_thinking_trajectory'] 22 | gemini_attempt = sample['gemini_attempt'] 23 | 24 | q = self.tokenizer.apply_chat_template([{"role": "user", "content": question}], tokenize=False, add_generation_prompt=True) 25 | a = THINK_TAG + gemini_thinking_trajectory + ANSWER_TAG + gemini_attempt + END_TAG 26 | 27 | q_input_ids = self.tokenizer.encode(q) 28 | a_input_ids = self.tokenizer.encode(a) 29 | 30 | input_ids = q_input_ids + a_input_ids 31 | attention_mask = [1] * len(input_ids) 32 | labels = [-100] * len(q_input_ids) + a_input_ids 33 | 34 | if len(input_ids) > self.max_length: 35 | input_ids = input_ids[:self.max_length] 36 | attention_mask = attention_mask[:self.max_length] 37 | labels = labels[:self.max_length] 38 | else: 39 | padding_len = self.max_length - len(input_ids) 40 | input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_len 41 | attention_mask = attention_mask + [0] * padding_len 42 | labels = labels + [-100] * padding_len 43 | 44 | return {"input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long)} 45 | 46 | def __len__(self): 47 | return len(self.ds) 48 | 49 | if __name__ == "__main__": 50 | 51 | model = AutoModelForCausalLM.from_pretrained("/home/user/Downloads/Qwen2.5-0.5B-Instruct") 52 | tokenizer = AutoTokenizer.from_pretrained("/home/user/Downloads/Qwen2.5-0.5B-Instruct") 53 | 
lora_config = LoraConfig( 54 | task_type=TaskType.CAUSAL_LM, 55 | r=8, 56 | lora_alpha=32, 57 | lora_dropout=0.1, 58 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 59 | ) 60 | 61 | model = get_peft_model(model, lora_config) 62 | model.print_trainable_parameters() 63 | 64 | ds = load_dataset('./s1K-1.1') 65 | data_collator = DefaultDataCollator() 66 | 67 | args = TrainingArguments( 68 | output_dir="./s1", 69 | learning_rate=1e-5, 70 | per_device_train_batch_size=2, 71 | gradient_accumulation_steps=4, 72 | num_train_epochs=3, 73 | weight_decay=0.01, 74 | logging_steps=10, 75 | save_strategy="steps", 76 | save_steps=100, 77 | save_total_limit=3, 78 | report_to="tensorboard", 79 | bf16=True 80 | ) 81 | 82 | train_dataset = S1Dataset(ds['train'], tokenizer, max_length=1024) 83 | trainer = Trainer( 84 | model=model, 85 | args=args, 86 | train_dataset=train_dataset, 87 | tokenizer=tokenizer, 88 | data_collator=data_collator 89 | ) 90 | 91 | trainer.train() 92 | trainer.save_model() 93 | trainer.save_state() 94 | -------------------------------------------------------------------------------- /s1_from_scratch/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/s1_from_scratch/training_loss.png -------------------------------------------------------------------------------- /table_extract/README.md: -------------------------------------------------------------------------------- 1 | # table2txt 2 | 3 | ## 介绍 4 | 支持图片或者pdf中的普通文本提取和表格中文本的提取(pdf需要先转成图片),并保持其结构化排版布局(尽量保持其结构,不完美) 5 | 6 | 可参考如下示例: 7 | 8 | 需要提取的图片: 9 | 10 | ![需要提取的图片](./imgs/2.jpg "表格") 11 | 12 | 提取之后: 13 | 14 | ![提取之后的图片](./imgs/extract_2.png "表格") 15 | 16 | ## 使用方法 17 | 18 | 1、下载模型 19 | 20 | modelscope下载表格提取模型,并修改代码中相关路径 21 | 22 | https://modelscope.cn/models/iic/cv_dla34_table-structure-recognition_cycle-centernet 23 | 24 | 2、修改代码中需要提取的图片路径 25 | 26 | ## 注意 27 | 28 | 有时会出现调整完坐标之后的效果图片无法绘制的情况,可忽略,文字可正常提取 29 | 30 | 代码中有不完善的地方,可根据需要自行修改 31 | -------------------------------------------------------------------------------- /table_extract/chinese_cht.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/table_extract/chinese_cht.ttf -------------------------------------------------------------------------------- /table_extract/imgs/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/table_extract/imgs/2.jpg -------------------------------------------------------------------------------- /table_extract/imgs/extract_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/table_extract/imgs/extract_2.png -------------------------------------------------------------------------------- /table_rag.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "vscode": { 8 | "languageId": "plaintext" 9 | } 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "from docx import Document\n", 14 | "# 读取word中表格,处理成表头加行的形式\n", 15 | "def read_table_from_word(file_path):\n", 16 | 
" all_rows = []\n", 17 | " doc = Document(file_path)\n", 18 | " for i, table in enumerate(doc.tables):\n", 19 | " for row in table.rows:\n", 20 | " cells_text = [cell.text.replace('\\n', '') for cell in row.cells]\n", 21 | " all_rows.append(cells_text)\n", 22 | " new_all_rows=[]\n", 23 | " i = 0\n", 24 | " # 合并跨页的表格行\n", 25 | " while i' 36 | input_ids = self.tokenizer.encode(text) 37 | text_len = len(input_ids) 38 | if text_len > self.max_seq_len: 39 | input_ids = input_ids[:self.max_seq_len] 40 | else: 41 | input_ids = input_ids + [0] * (self.max_seq_len - text_len) 42 | input_ids = np.array(input_ids) 43 | X = np.array(input_ids[:-1]).astype(np.int64) 44 | Y = np.array(input_ids[1:]).astype(np.int64) 45 | return { 46 | 'input_ids': torch.from_numpy(X), 47 | 'labels': torch.from_numpy(Y), 48 | } 49 | 50 | class SFTDataset(Dataset): 51 | def __init__(self, data_path, tokenizer, max_seq_len): 52 | super().__init__() 53 | self.data_path = data_path 54 | self.tokenizer = tokenizer 55 | self.max_seq_len = max_seq_len 56 | 57 | with open(self.data_path, 'r', encoding='utf-8') as f: 58 | self.data = f.readlines() 59 | 60 | def __len__(self): 61 | return len(self.data) 62 | 63 | def __getitem__(self, index): 64 | line = self.data[index] 65 | line = json.loads(line) 66 | instruction_text = line['instruction'] 67 | input_text = line['input'] 68 | output_text = line['output'] 69 | history = line['history'] 70 | query = instruction_text + input_text 71 | answer = output_text + self.tokenizer.eos_token 72 | messages = [] 73 | if history: 74 | for i in history: 75 | messages.append({'role': 'user', 'content': i[0]}) 76 | messages.append({'role': 'assistant', 'content': i[1]}) 77 | 78 | messages.append({'role': 'user', 'content': query}) 79 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) 80 | prompt_input_ids = self.tokenizer.encode(prompt) 81 | answer_input_ids = self.tokenizer.encode(answer) 82 | input_ids = prompt_input_ids + answer_input_ids 83 | labels = [0] * len(prompt_input_ids) + answer_input_ids 84 | text_len = len(input_ids) 85 | if text_len > self.max_seq_len: 86 | input_ids = input_ids[:self.max_seq_len] 87 | labels = labels[:self.max_seq_len] 88 | else: 89 | input_ids = input_ids + [0] * (self.max_seq_len - text_len) 90 | labels = labels + [0] * (self.max_seq_len - text_len) 91 | 92 | input_ids = input_ids[:-1] 93 | labels = labels[1:] 94 | return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)} 95 | 96 | 97 | # 内存不够,可使用如下方法加载数据 98 | # class LLMDataset(IterableDataset): 99 | # def __init__(self, data_path, tokenizer, max_seq_len): 100 | # super().__init__() 101 | # self.data_path = data_path 102 | # self.tokenizer = tokenizer 103 | # self.max_seq_len = max_seq_len 104 | 105 | # def __iter__(self): 106 | # return self.data_process() 107 | 108 | # def data_process(self): 109 | # with open(self.data_path, 'r', encoding='utf-8') as f: 110 | # for line in f: 111 | # line = json.loads(line) 112 | # text = '' + line['text'] + '' 113 | # input_ids = self.tokenizer.encode(text) 114 | # text_len = len(input_ids) 115 | # if text_len > self.max_seq_len: 116 | # input_ids = input_ids[:self.max_seq_len] 117 | # else: 118 | # input_ids = input_ids + [0] * (self.max_seq_len - text_len) 119 | # input_ids = np.array(input_ids) 120 | # X = np.array(input_ids[:-1]).astype(np.int64) 121 | # Y = np.array(input_ids[1:]).astype(np.int64) 122 | # yield { 123 | # 'input_ids': torch.from_numpy(X), 124 | # 'labels': torch.from_numpy(Y), 125 | # } 126 | 127 | class 
DPODataset(Dataset): 128 | def __init__(self, data_path, tokenizer): 129 | super().__init__() 130 | self.data_path = data_path 131 | self.tokenizer = tokenizer 132 | 133 | with open(self.data_path, 'r', encoding='utf-8') as f: 134 | self.datas = json.load(f) 135 | 136 | def __getitem__(self, index): 137 | sample = self.datas[index] 138 | prompt = sample['prompt'] 139 | chosen = sample['chosen'] 140 | rejected = sample['rejected'] 141 | messages = [ 142 | {"role": "user", "content": prompt} 143 | ] 144 | text = self.tokenizer.apply_chat_template( 145 | messages, 146 | tokenize=False, 147 | add_generation_prompt=True 148 | ) 149 | prompt_inputs = self.tokenizer(text=text)['input_ids'] 150 | rejected_inputs = self.tokenizer(text=rejected)['input_ids'] + [self.tokenizer.eos_token_id] 151 | chosen_inputs = self.tokenizer(text=chosen)['input_ids'] + [self.tokenizer.eos_token_id] 152 | return [prompt_inputs, chosen_inputs, rejected_inputs] 153 | 154 | def __len__(self): 155 | return len(self.datas) 156 | 157 | 158 | class DPODataCollator: 159 | def __init__(self, tokenizer, max_seq_len): 160 | self.tokenizer = tokenizer 161 | self.max_seq_len = max_seq_len 162 | def __call__(self, features): 163 | inputs_ids = [] 164 | labels = [] 165 | 166 | for feature in features: 167 | inputs_ids.append(feature[0] + feature[1]) 168 | labels.append([0]*len(feature[0]) + feature[1]) 169 | for feature in features: 170 | inputs_ids.append(feature[0] + feature[2]) 171 | labels.append([0]*len(feature[0]) + feature[2]) 172 | 173 | def process(inputs_ids, labels): 174 | inputs_ids = [input_ids[:self.max_seq_len] for input_ids in inputs_ids] 175 | labels = [label[:self.max_seq_len] for label in labels] 176 | max_len = max([len(input_ids) for input_ids in inputs_ids]) 177 | batch_input_ids = [] 178 | batch_labels = [] 179 | 180 | for input_ids, label in zip(inputs_ids, labels): 181 | if len(input_ids) <= max_len: 182 | input_ids = input_ids+[0]*(max_len-len(input_ids)) 183 | label = label+[0]*(max_len-len(label)) 184 | batch_input_ids.append(input_ids[:-1]) 185 | batch_labels.append(label[1:]) 186 | return batch_input_ids, batch_labels 187 | 188 | inputs_ids, labels = process(inputs_ids, labels) 189 | 190 | return { 191 | "input_ids": torch.tensor(inputs_ids), 192 | "labels": torch.tensor(labels) 193 | } 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /train_llm_from_scratch/dpo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_llm_from_scratch/dpo.png -------------------------------------------------------------------------------- /train_llm_from_scratch/dpo_train.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments, Trainer, AutoModelForCausalLM, AutoTokenizer, AutoConfig 2 | import torch 3 | import torch.nn.functional as F 4 | from dataset import DPODataset, DPODataCollator 5 | from train import LLM, Config 6 | 7 | 8 | def logits_to_probs(logits, labels): 9 | # logits shape: (batch_size, seq_len, vocab_size) 10 | # labels shape: (batch_size, seq_len) 11 | # probs shape: (batch_size, seq_len) 12 | log_probs = F.log_softmax(logits, dim=2) 13 | probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1) 14 | return probs 15 | 16 | def mask_logits(logits, labels): 17 | # logits shape: (batch_size, seq_len, vocab_size) 18 | # 
labels_masks shape: (batch_size, seq_len) 19 | new_logits = [] 20 | for logit, label in zip(logits, labels): 21 | new_logits.append(logit[label != 0].sum().unsqueeze(0)) 22 | 23 | return new_logits 24 | 25 | 26 | def dpo_loss(ref_probs, probs, beta): 27 | def split_probs(probs): 28 | len_chosen = int(len(probs) // 2) 29 | chosen_data = probs[:len_chosen] 30 | reject_data = probs[len_chosen:] 31 | return torch.cat(chosen_data), torch.cat(reject_data) 32 | 33 | ref_chosen_probs, ref_reject_probs = split_probs(ref_probs) 34 | chosen_probs, reject_probs = split_probs(probs) 35 | pi_logratios = chosen_probs - reject_probs 36 | ref_logratios = ref_chosen_probs - ref_reject_probs 37 | logits = pi_logratios - ref_logratios 38 | loss = -F.logsigmoid(beta*logits) 39 | return loss.mean() 40 | 41 | 42 | 43 | class DPOTrainer(Trainer): 44 | 45 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 46 | input_ids = inputs['input_ids'] 47 | labels = inputs['labels'] 48 | with torch.no_grad(): 49 | ref_logits = ref_model(input_ids=input_ids, labels = labels).logits 50 | ref_probs = logits_to_probs(ref_logits, labels) 51 | ref_probs = mask_logits(ref_probs, labels) 52 | logits = model(input_ids=input_ids, labels = labels).logits 53 | probs = logits_to_probs(logits, labels) 54 | probs = mask_logits(probs, labels) 55 | loss = dpo_loss(ref_probs, probs, 0.1) 56 | return loss 57 | 58 | # def training_step( 59 | # self, model, inputs, num_items_in_batch=None 60 | # ) -> torch.Tensor: 61 | # input_ids = inputs['input_ids'] 62 | # labels = inputs['labels'] 63 | # with torch.no_grad(): 64 | # ref_logits = ref_model(input_ids=input_ids, labels = labels).logits 65 | # ref_probs = logits_to_probs(ref_logits, labels) 66 | # ref_probs = mask_logits(ref_probs, labels) 67 | # # 因为参考模型的累计概率不发生变化,为了尽量减少多次计算,计算一次参考模型的累积概率,多训练几次需要优化的模型 68 | # for _ in range(1): 69 | 70 | # model.train() 71 | # logits = model(input_ids=input_ids, labels = labels).logits 72 | # probs = logits_to_probs(logits, labels) 73 | # probs = mask_logits(probs, labels) 74 | 75 | # if hasattr(self.optimizer, "train") and callable(self.optimizer.train): 76 | # self.optimizer.train() 77 | 78 | # with self.compute_loss_context_manager(): 79 | # # loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) 80 | # loss = dpo_loss(ref_probs, probs, 0.2) 81 | 82 | # # del inputs 83 | # if ( 84 | # self.args.torch_empty_cache_steps is not None 85 | # and self.state.global_step % self.args.torch_empty_cache_steps == 0 86 | # ): 87 | 88 | # torch.cuda.empty_cache() 89 | 90 | # kwargs = {} 91 | 92 | # if self.args.n_gpu > 1: 93 | # loss = loss.mean() # mean() to average on multi-gpu parallel training 94 | 95 | # self.accelerator.backward(loss, retain_graph=True, **kwargs) 96 | # # Finally we need to normalize the loss for reporting 97 | # if num_items_in_batch is None: 98 | # return loss.detach() / self.args.gradient_accumulation_steps 99 | # return loss.detach() 100 | 101 | 102 | if __name__ == "__main__": 103 | AutoConfig.register("small_model", Config) 104 | AutoModelForCausalLM.register(Config, LLM) 105 | model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft') 106 | 107 | print(f'模型可训练参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}') 108 | ref_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft').eval().to('cuda') 109 | 110 | tokenizer = 
AutoTokenizer.from_pretrained("/home/user/wyf/train_model_from_scratch/tokenizer", use_fast=True) 111 | data_collator = DPODataCollator(tokenizer, max_seq_len=512) # 加载的大模型旋转位置编码最大长度为1024,这里不能超过这个值 112 | args = TrainingArguments(output_dir='./dpo-1-epoch', 113 | num_train_epochs=1, # 训练太多轮,模型似乎会输出很多重复内容 114 | do_train=True, 115 | per_device_train_batch_size=16, 116 | gradient_accumulation_steps=4, 117 | # max_steps=15000, 118 | logging_steps=50, 119 | report_to='tensorboard', 120 | save_total_limit=3, 121 | bf16=True, 122 | learning_rate=0.00001, # 学习率很重要,太大会把模型训飞 123 | lr_scheduler_type='cosine', 124 | dataloader_num_workers=1, 125 | dataloader_pin_memory=True, 126 | save_safetensors=False, 127 | save_steps=100) 128 | dataset = DPODataset('/home/user/wyf/train_model_from_scratch/dataset/dpo_data_512.json', tokenizer=tokenizer) 129 | trainer = DPOTrainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator) 130 | 131 | # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True 132 | trainer.train(resume_from_checkpoint=False) 133 | trainer.save_model('./saves/dpo-1-epoch') 134 | trainer.save_state() -------------------------------------------------------------------------------- /train_llm_from_scratch/screenshot-20241207-093824.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_llm_from_scratch/screenshot-20241207-093824.png -------------------------------------------------------------------------------- /train_llm_from_scratch/sft_train.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | import os 8 | import pandas as pd 9 | 10 | from torch.utils.data import IterableDataset, Dataset 11 | import json 12 | import numpy as np 13 | from transformers import PreTrainedModel 14 | from transformers.modeling_outputs import CausalLMOutputWithPast 15 | from transformers import PretrainedConfig 16 | from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig 17 | from dataset import SFTDataset, LLMDataset 18 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 19 | import torch 20 | from train import LLM, Config 21 | 22 | if __name__ == '__main__': 23 | AutoConfig.register("small_model", Config) 24 | AutoModelForCausalLM.register(Config, LLM) 25 | model = AutoModelForCausalLM.from_pretrained('./saves/model') 26 | print(f'模型参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}') 27 | 28 | data_collator = DefaultDataCollator() 29 | tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True) 30 | args = TrainingArguments(output_dir='./sft', 31 | num_train_epochs=5, 32 | do_train=True, 33 | per_device_train_batch_size=64, 34 | gradient_accumulation_steps=8, 35 | # max_steps=15000, 36 | logging_steps=100, 37 | report_to='tensorboard', 38 | save_total_limit=5, 39 | bf16=True, 40 | learning_rate=2e-4, 41 | lr_scheduler_type='cosine', 42 | dataloader_num_workers=1, 43 | dataloader_pin_memory=True, 44 | save_safetensors=False) 45 | dataset = SFTDataset('./sft_data_zh.jsonl', tokenizer=tokenizer, max_seq_len=1024) 46 | trainer = Trainer(model=model, args=args, 
train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator) 47 | # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True 48 | trainer.train(resume_from_checkpoint=False) 49 | trainer.save_model('./saves/sft') 50 | trainer.save_state() 51 | -------------------------------------------------------------------------------- /train_llm_from_scratch/test_llm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n", 10 | "import torch\n", 11 | "from train import LLM, Config" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 5, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "t = AutoTokenizer.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/pretrain')\n", 21 | "AutoConfig.register(\"small_model\", Config)\n", 22 | "AutoModelForCausalLM.register(Config, LLM)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "\n", 32 | "model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/pretrain')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 187, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "[2, 20, 14, 20, 6239]\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "input_data = [t.bos_token_id] + t.encode('1+1等于')\n", 50 | "print(input_data)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 188, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "1.5,即1.5,即1.5,\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "for token in model.generate({\"input_ids\":torch.tensor(input_data).unsqueeze(0), \"labels\":None}, t.eos_token_id, 20, stream=False,temperature=0.0, top_k=8):\n", 68 | " print(t.decode(token[0]))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 189, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_model_from_scratch/saves/sft')" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 190, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "[2, 321, 276, 202, 20, 14, 20, 6239, 3, 202, 2, 1079, 539, 502, 202]\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "input_data = t.apply_chat_template([{'role':'user', 'content':'1+1等于'}])\n", 95 | "print(input_data)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 191, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "1+1等于2。\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "for token in model.generate({\"input_ids\":torch.tensor(input_data).unsqueeze(0), \"labels\":None}, t.eos_token_id, 200, stream=False,temperature=0.0, top_k=8):\n", 113 | " print(t.decode(token[0]))" 114 | ] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "wyf", 120 | "language": "python", 121 | "name": "python3" 122 | }, 123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 3 127 | }, 128 | "file_extension": ".py", 129 
| "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython3", 133 | "version": "3.11.0" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 2 138 | } 139 | -------------------------------------------------------------------------------- /train_llm_from_scratch/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_eos_token": false, 4 | "add_prefix_space": true, 5 | "added_tokens_decoder": { 6 | "0": { 7 | "content": "", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false, 12 | "special": true 13 | }, 14 | "1": { 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false, 20 | "special": true 21 | }, 22 | "2": { 23 | "content": "", 24 | "lstrip": false, 25 | "normalized": false, 26 | "rstrip": false, 27 | "single_word": false, 28 | "special": true 29 | } 30 | }, 31 | "additional_special_tokens": [], 32 | "bos_token": "", 33 | "clean_up_tokenization_spaces": false, 34 | "eos_token": "", 35 | "legacy": true, 36 | "model_max_length": 100000, 37 | "pad_token": null, 38 | "sp_model_kwargs": {}, 39 | "spaces_between_special_tokens": false, 40 | "tokenizer_class": "PreTrainedTokenizerFast", 41 | "unk_token": "", 42 | "use_default_system_prompt": false, 43 | "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\\n' + content + '\\nassistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '' + '\\n' }}{% endif %}{% endfor %}" 44 | } -------------------------------------------------------------------------------- /train_llm_from_scratch/train_tokenizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer\n", 10 | "import os \n", 11 | "import json" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 17, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# 读取数据\n", 21 | "def read_data(path):\n", 22 | " with open(path, 'r', encoding='utf-8') as f:\n", 23 | " for line in f:\n", 24 | " data = json.loads(line)\n", 25 | " yield data['text'] \n", 26 | " " 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 18, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "# BPE分词器\n", 36 | "tokenizer = Tokenizer(models.BPE())\n", 37 | "tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 19, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# 定义特殊tokens\n", 47 | "special_tokens = ['', '', '', '']" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 20, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# 初始化训练器\n", 57 | "trainer = trainers.BpeTrainer(\n", 58 | " vocab_size=6400,\n", 59 | " special_tokens = special_tokens,\n", 60 | " show_progress=True,\n", 61 | " initial_alphabet = pre_tokenizers.ByteLevel.alphabet()\n", 62 | " \n", 63 | ")" 64 | ] 65 
| }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "texts = read_data('./dataset/tokenizer_train.jsonl')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 22, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "\n", 85 | "\n", 86 | "\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "tokenizer.train_from_iterator(texts, trainer)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 23, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "tokenizer.decoder = decoders.ByteLevel()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 24, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "tokenizer_dir = \"./tokenizer\"\n", 110 | "os.makedirs(tokenizer_dir, exist_ok=True)\n", 111 | "tokenizer.save(os.path.join(tokenizer_dir, \"tokenizer.json\"))" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 25, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "['./tokenizer/vocab.json', './tokenizer/merges.txt']" 123 | ] 124 | }, 125 | "execution_count": 25, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "tokenizer.model.save(tokenizer_dir)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 26, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "config = {\n", 141 | " \"add_bos_token\": False,\n", 142 | " \"add_eos_token\": False,\n", 143 | " \"add_prefix_space\": True,\n", 144 | " \"added_tokens_decoder\": {\n", 145 | " \"0\": {\n", 146 | " \"content\": \"\",\n", 147 | " \"lstrip\": False,\n", 148 | " \"normalized\": False,\n", 149 | " \"rstrip\": False,\n", 150 | " \"single_word\": False,\n", 151 | " \"special\": True\n", 152 | " },\n", 153 | " \"1\": {\n", 154 | " \"content\": \"\",\n", 155 | " \"lstrip\": False,\n", 156 | " \"normalized\": False,\n", 157 | " \"rstrip\": False,\n", 158 | " \"single_word\": False,\n", 159 | " \"special\": True\n", 160 | " },\n", 161 | " \"2\": {\n", 162 | " \"content\": \"\",\n", 163 | " \"lstrip\": False,\n", 164 | " \"normalized\": False,\n", 165 | " \"rstrip\": False,\n", 166 | " \"single_word\": False,\n", 167 | " \"special\": True\n", 168 | " }\n", 169 | " },\n", 170 | " \"additional_special_tokens\": [],\n", 171 | " \"bos_token\": \"\",\n", 172 | " \"clean_up_tokenization_spaces\": False,\n", 173 | " \"eos_token\": \"\",\n", 174 | " \"legacy\": True,\n", 175 | " \"model_max_length\": 100000,\n", 176 | " \"pad_token\": None,\n", 177 | " \"sp_model_kwargs\": {},\n", 178 | " \"spaces_between_special_tokens\": False,\n", 179 | " \"tokenizer_class\": \"PreTrainedTokenizerFast\",\n", 180 | " \"unk_token\": \"\",\n", 181 | " \"use_default_system_prompt\": False,\n", 182 | " \"chat_template\": \"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\\\\n' + content + '\\\\nassistant\\\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '' + '\\\\n' }}{% endif %}{% endfor %}\"\n", 183 | " }\n", 184 | "\n", 185 | "# 保存配置文件\n", 186 | "with open(os.path.join(tokenizer_dir, \"tokenizer_config.json\"), \"w\", encoding=\"utf-8\") as config_file:\n", 187 | 
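    "    # ensure_ascii=False keeps the Chinese chat-template text human-readable in the saved JSON file\n",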
" json.dump(config, config_file, ensure_ascii=False, indent=4)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 27, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "[804, 588]" 199 | ] 200 | }, 201 | "execution_count": 27, 202 | "metadata": {}, 203 | "output_type": "execute_result" 204 | } 205 | ], 206 | "source": [ 207 | "# 测试\n", 208 | "from transformers import AutoTokenizer\n", 209 | "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n", 210 | "tokenizer.encode(\"您好\")" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 29, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/plain": [ 221 | "([0], [1], [2], [3])" 222 | ] 223 | }, 224 | "execution_count": 29, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "tokenizer.encode(\"\"), tokenizer.encode(\"\"), tokenizer.encode(\"\"), tokenizer.encode(\"\")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 31, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "'好'" 242 | ] 243 | }, 244 | "execution_count": 31, 245 | "metadata": {}, 246 | "output_type": "execute_result" 247 | } 248 | ], 249 | "source": [ 250 | "tokenizer.decode(588)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 32, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "6400" 262 | ] 263 | }, 264 | "execution_count": 32, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "tokenizer.vocab_size" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "wyf", 277 | "language": "python", 278 | "name": "python3" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.11.0" 291 | } 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 2 295 | } 296 | -------------------------------------------------------------------------------- /train_moe_from_scratch/README.md: -------------------------------------------------------------------------------- 1 | # 使用方法 2 | 3 | ## 下载数据 4 | 5 | https://github.com/jingyaogong/minimind 6 | ![image](.\screenshot-20241207-093824.png) 7 | 8 | ## 开始训练 9 | ### 直接运行 10 | 预训练:\ 11 | python moe_train.py\ 12 | SFT:\ 13 | python moe_sft_train.py 14 | ### torchrun 15 | 预训练:\ 16 | torchrun --nproc_per_node=2 moe_train.py 17 | SFT:\ 18 | torchrun --nproc_per_node=2 moe_sft_train.py 19 | ### deepspeed 20 | 预训练:\ 21 | deepspeed --include 'localhost:0,1' moe_train.py\ 22 | SFT:\ 23 | deepspeed --include 'localhost:0,1' moe_sft_train.py 24 | 25 | ## 测试 26 | python moe_test.py -------------------------------------------------------------------------------- /train_moe_from_scratch/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | import os 8 | import pandas as pd 9 | 10 | from torch.utils.data import IterableDataset, Dataset 11 | import json 12 | import numpy as np 13 | from transformers import PreTrainedModel 
14 | from transformers.modeling_outputs import CausalLMOutputWithPast 15 | from transformers import PretrainedConfig 16 | from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig 17 | 18 | 19 | class LLMDataset(Dataset): 20 | def __init__(self, data_path, tokenizer, max_seq_len): 21 | super().__init__() 22 | self.data_path = data_path 23 | self.tokenizer = tokenizer 24 | self.max_seq_len = max_seq_len 25 | with open(self.data_path, 'r', encoding='utf-8') as f: 26 | self.data = f.readlines() 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, index: int): 32 | 33 | line = self.data[index] 34 | line = json.loads(line) 35 | text = '' + line['text'] + '' 36 | input_ids = self.tokenizer.encode(text) 37 | text_len = len(input_ids) 38 | if text_len > self.max_seq_len: 39 | input_ids = input_ids[:self.max_seq_len] 40 | else: 41 | input_ids = input_ids + [0] * (self.max_seq_len - text_len) 42 | input_ids = np.array(input_ids) 43 | X = np.array(input_ids[:-1]).astype(np.int64) 44 | Y = np.array(input_ids[1:]).astype(np.int64) 45 | return { 46 | 'input_ids': torch.from_numpy(X), 47 | 'labels': torch.from_numpy(Y), 48 | } 49 | 50 | class SFTDataset(Dataset): 51 | def __init__(self, data_path, tokenizer, max_seq_len): 52 | super().__init__() 53 | self.data_path = data_path 54 | self.tokenizer = tokenizer 55 | self.max_seq_len = max_seq_len 56 | 57 | with open(self.data_path, 'r', encoding='utf-8') as f: 58 | self.data = f.readlines() 59 | 60 | def __len__(self): 61 | return len(self.data) 62 | 63 | def __getitem__(self, index): 64 | line = self.data[index] 65 | line = json.loads(line) 66 | instruction_text = line['instruction'] 67 | input_text = line['input'] 68 | output_text = line['output'] 69 | history = line['history'] 70 | query = instruction_text + input_text 71 | answer = output_text + self.tokenizer.eos_token 72 | messages = [] 73 | if history: 74 | for i in history: 75 | messages.append({'role': 'user', 'content': i[0]}) 76 | messages.append({'role': 'assistant', 'content': i[1]}) 77 | 78 | messages.append({'role': 'user', 'content': query}) 79 | prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) 80 | prompt_input_ids = self.tokenizer.encode(prompt) 81 | answer_input_ids = self.tokenizer.encode(answer) 82 | input_ids = prompt_input_ids + answer_input_ids 83 | labels = [0] * len(prompt_input_ids) + answer_input_ids 84 | text_len = len(input_ids) 85 | if text_len > self.max_seq_len: 86 | input_ids = input_ids[:self.max_seq_len] 87 | labels = labels[:self.max_seq_len] 88 | else: 89 | input_ids = input_ids + [0] * (self.max_seq_len - text_len) 90 | labels = labels + [0] * (self.max_seq_len - text_len) 91 | 92 | input_ids = input_ids[:-1] 93 | labels = labels[1:] 94 | return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)} 95 | 96 | 97 | # 内存不够,可使用如下方法加载数据 98 | # class LLMDataset(IterableDataset): 99 | # def __init__(self, data_path, tokenizer, max_seq_len): 100 | # super().__init__() 101 | # self.data_path = data_path 102 | # self.tokenizer = tokenizer 103 | # self.max_seq_len = max_seq_len 104 | 105 | # def __iter__(self): 106 | # return self.data_process() 107 | 108 | # def data_process(self): 109 | # with open(self.data_path, 'r', encoding='utf-8') as f: 110 | # for line in f: 111 | # line = json.loads(line) 112 | # text = '' + line['text'] + '' 113 | # input_ids = self.tokenizer.encode(text) 114 | # text_len = 
len(input_ids) 115 | # if text_len > self.max_seq_len: 116 | # input_ids = input_ids[:self.max_seq_len] 117 | # else: 118 | # input_ids = input_ids + [0] * (self.max_seq_len - text_len) 119 | # input_ids = np.array(input_ids) 120 | # X = np.array(input_ids[:-1]).astype(np.int64) 121 | # Y = np.array(input_ids[1:]).astype(np.int64) 122 | # yield { 123 | # 'input_ids': torch.from_numpy(X), 124 | # 'labels': torch.from_numpy(Y), 125 | # } 126 | 127 | class DPODataset(Dataset): 128 | def __init__(self, data_path, tokenizer): 129 | super().__init__() 130 | self.data_path = data_path 131 | self.tokenizer = tokenizer 132 | 133 | with open(self.data_path, 'r', encoding='utf-8') as f: 134 | self.datas = json.load(f) 135 | 136 | def __getitem__(self, index): 137 | sample = self.datas[index] 138 | prompt = sample['prompt'] 139 | chosen = sample['chosen'] 140 | rejected = sample['rejected'] 141 | messages = [ 142 | {"role": "user", "content": prompt} 143 | ] 144 | text = self.tokenizer.apply_chat_template( 145 | messages, 146 | tokenize=False, 147 | add_generation_prompt=True 148 | ) 149 | prompt_inputs = self.tokenizer(text=text)['input_ids'] 150 | rejected_inputs = self.tokenizer(text=rejected)['input_ids'] + [self.tokenizer.eos_token_id] 151 | chosen_inputs = self.tokenizer(text=chosen)['input_ids'] + [self.tokenizer.eos_token_id] 152 | return [prompt_inputs, chosen_inputs, rejected_inputs] 153 | 154 | def __len__(self): 155 | return len(self.datas) 156 | 157 | 158 | class DPODataCollator: 159 | def __init__(self, tokenizer, max_seq_len): 160 | self.tokenizer = tokenizer 161 | self.max_seq_len = max_seq_len 162 | def __call__(self, features): 163 | inputs_ids = [] 164 | labels = [] 165 | 166 | for feature in features: 167 | inputs_ids.append(feature[0] + feature[1]) 168 | labels.append([0]*len(feature[0]) + feature[1]) 169 | for feature in features: 170 | inputs_ids.append(feature[0] + feature[2]) 171 | labels.append([0]*len(feature[0]) + feature[2]) 172 | 173 | def process(inputs_ids, labels): 174 | inputs_ids = [input_ids[:self.max_seq_len] for input_ids in inputs_ids] 175 | labels = [label[:self.max_seq_len] for label in labels] 176 | max_len = max([len(input_ids) for input_ids in inputs_ids]) 177 | batch_input_ids = [] 178 | batch_labels = [] 179 | 180 | for input_ids, label in zip(inputs_ids, labels): 181 | if len(input_ids) <= max_len: 182 | input_ids = input_ids+[0]*(max_len-len(input_ids)) 183 | label = label+[0]*(max_len-len(label)) 184 | batch_input_ids.append(input_ids[:-1]) 185 | batch_labels.append(label[1:]) 186 | return batch_input_ids, batch_labels 187 | 188 | inputs_ids, labels = process(inputs_ids, labels) 189 | 190 | return { 191 | "input_ids": torch.tensor(inputs_ids), 192 | "labels": torch.tensor(labels) 193 | } 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /train_moe_from_scratch/moe_sft_train.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint 6 | from torch import nn 7 | import os 8 | import pandas as pd 9 | 10 | from torch.utils.data import IterableDataset, Dataset 11 | import json 12 | import numpy as np 13 | from transformers import PreTrainedModel 14 | from transformers.modeling_outputs import CausalLMOutputWithPast 15 | from transformers import PretrainedConfig 16 | from transformers import Trainer, TrainingArguments, 
AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig 17 | from dataset import SFTDataset, LLMDataset 18 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 19 | import torch 20 | from moe_train import LLM, Config 21 | 22 | if __name__ == '__main__': 23 | AutoConfig.register("moe_model", Config) 24 | AutoModelForCausalLM.register(Config, LLM) 25 | model = AutoModelForCausalLM.from_pretrained('./saves/moe') 26 | print(f'模型参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}') 27 | 28 | data_collator = DefaultDataCollator() 29 | tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True) 30 | args = TrainingArguments(output_dir='./sft', 31 | num_train_epochs=5, 32 | do_train=True, 33 | per_device_train_batch_size=2, 34 | gradient_accumulation_steps=1, 35 | # max_steps=15000, 36 | logging_steps=1, 37 | report_to='tensorboard', 38 | save_total_limit=5, 39 | bf16=True, 40 | learning_rate=2e-4, 41 | lr_scheduler_type='cosine', 42 | dataloader_num_workers=1, 43 | dataloader_pin_memory=True, 44 | save_safetensors=False) 45 | dataset = SFTDataset('./sft.jsonl', tokenizer=tokenizer, max_seq_len=1024) 46 | trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator) 47 | # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True 48 | trainer.train(resume_from_checkpoint=False) 49 | trainer.save_model('./saves/sft') 50 | trainer.save_state() 51 | -------------------------------------------------------------------------------- /train_moe_from_scratch/moe_test.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 2 | import torch 3 | from moe_train import LLM, Config 4 | t = AutoTokenizer.from_pretrained('./saves/moe') 5 | AutoConfig.register("moe_model", Config) 6 | AutoModelForCausalLM.register(Config, LLM) 7 | model = AutoModelForCausalLM.from_pretrained('./saves/moe') 8 | 9 | input_data = [t.bos_token_id] + t.encode('1+1等于') 10 | print(input_data) 11 | 12 | for token in model.generate({"input_ids":torch.tensor(input_data).unsqueeze(0), "labels":None}, t.eos_token_id, 20, stream=False,temperature=0.0, top_k=1): 13 | print(t.decode(token[0])) -------------------------------------------------------------------------------- /train_moe_from_scratch/screenshot-20241207-093824.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_moe_from_scratch/screenshot-20241207-093824.png -------------------------------------------------------------------------------- /train_moe_from_scratch/sft.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 
游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 2 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 3 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 4 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 5 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 
电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 6 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 7 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} 8 | {"id": 14546691, "instruction": "", "input": "在我们的日常生活中,还有哪些应用场景可以使用人工智能技术呢?", "output": "人工智能技术已经在我们的日常生活中得到了广泛的应用,以下是一些常见的应用场景:\n\n1. 语音助手:如苹果的Siri、亚马逊的Alexa、谷歌的Google Assistant等,可以通过语音识别和自然语言处理技术,提供快速、个性化的服务。\n\n2. 智能家居:通过连接智能设备,如智能灯泡、智能音响、智能锁等,可以实现远程控制和智能化管理,提高居家生活的便利性和舒适度。\n\n3. 人脸识别:如在公共场所、公司门禁等应用,可以识别人脸,提高安全性和便利性。\n\n4. 智能交通:如自动驾驶汽车、交通信号灯等,可以通过AI技术实现自主决策和自动化控制,提高交通效率和安全性。\n\n5. 智能医疗:如医学影像诊断、疾病预测等,可以通过AI技术提供更准确、快速的诊断和治疗方案,提高医疗质量和效率。\n\n6. 电子商务:如智能推荐、自动化客服等,可以通过AI技术提供更个性化的服务,提高用户体验和购买转化率。\n\n7. 
游戏和娱乐:如虚拟现实、增强现实等,可以通过AI技术提供更丰富、逼真的游戏和娱乐体验。\n\n总之,人工智能技术已经深入到我们的日常生活中,可以带来更加便利、高效、智能的体验和服务。但是,我们也需要注意保护个人隐私和数据安全,避免AI技术的滥用和不当使用。", "history": [], "language": "chinese", "data_source": "https://huggingface.co/datasets/YeungNLP/moss-003-sft-data", "input_len": 24, "output_len": 440, "num_utter": 5, "type": 41, "type_keyword": ["技术", "人工智能", "文章", "使用", "利用", "帮助", "编写", "分析", "应用", "问题"]} -------------------------------------------------------------------------------- /train_moe_from_scratch/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_eos_token": false, 4 | "add_prefix_space": true, 5 | "added_tokens_decoder": { 6 | "0": { 7 | "content": "", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false, 12 | "special": true 13 | }, 14 | "1": { 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false, 20 | "special": true 21 | }, 22 | "2": { 23 | "content": "", 24 | "lstrip": false, 25 | "normalized": false, 26 | "rstrip": false, 27 | "single_word": false, 28 | "special": true 29 | } 30 | }, 31 | "additional_special_tokens": [], 32 | "bos_token": "", 33 | "clean_up_tokenization_spaces": false, 34 | "eos_token": "", 35 | "legacy": true, 36 | "model_max_length": 100000, 37 | "pad_token": null, 38 | "sp_model_kwargs": {}, 39 | "spaces_between_special_tokens": false, 40 | "tokenizer_class": "PreTrainedTokenizerFast", 41 | "unk_token": "", 42 | "use_default_system_prompt": false, 43 | "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\\n' + content + '\\nassistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '' + '\\n' }}{% endif %}{% endfor %}" 44 | } -------------------------------------------------------------------------------- /train_moe_from_scratch/train.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "逻辑推理链是一系列陈述,其中每个陈述都从前一个陈述中逻辑推导出来。第一个陈述称为“前提”,最后一个陈述称为“结论”。链中的其他陈述称为“中间结论”。 要构建一个逻辑推理链,首先需要确定前提和结论。前提是您试图证明的陈述,而结论是您试图得出的陈述。一旦确定了前提和结论,就需要确定需要证明的中间结论,以便达到结论。 中间结论是与前提和结论在逻辑上相关的陈述。它们是您需要证明的陈述,以便达到结论。一旦确定了中间结论,就可以开始构建逻辑推理链。 第一步是写下前提。第二步是写下第一个中间结论。第三步是写下第二个中间结论,依此类推。最后一步是写下结论。 以下是一个逻辑推理链的示例: 前提:所有人都是会死的。 中间结论:苏格拉底是一个人。 结论:因此,苏格拉底是会死的。 这是一个有效的逻辑推理链。前提是真实的,中间结论是真实的,结论从前提和中间结论中逻辑推导出来。 以下是一个无效的逻辑推理链示例: 前提:所有人都是会死的。 中间结论:苏格拉底不会死。 结论:因此,苏格拉底不是一个人。 这是一个无效的逻辑推理链,因为结论并不从前提和中间结论中逻辑推导出来。 要构建一个有效的逻辑推理链,需要确保中间结论是真实的,并且它们从前提和前一个中间结论中逻辑推导出来。"} 2 | {"text": "政府责任基金会(FGA)利用其政治关系在联邦层面推进其议程: 1. 与关键共和党议员建立关系,如众议院议长保罗·瑞安和堪萨斯州州长山姆·布朗巴克,他们支持他们的政策主张。 2. 为议员提供研究和分析,包括一项关于堪萨斯州食品券接受者工作要求影响的研究,该研究被共和党人广泛引用。 3. 通过其倡导机构“机会解决方案项目”动员基层支持,该机构曾为医疗补助工作要求进行宣传,并反对医疗补助扩展。 4. 向议员和公众传播政策简报、视频和图形,旨在推广他们的政策并反驳批评。 5. 与其他保守团体合作,如传统基金会和美国企业研究所,以推广他们的想法。 6. 利用其低国家知名度和远程工作人员专注于州政策,并直到2016年才将重点转向联邦政策。 7. FGA对政策信息传达和营销的强调使该组织与传统智库有所区别。 值得注意的是,FGA的政策备受争议,一些专家和议员批评其有害且无效。"} 3 | {"text": "感谢您提出问题!我在这里可以帮助您。 要检查计算机的BIOS(基本输入/输出系统)中是否已禁用虚拟化,您可以按照以下步骤操作: 1. 重新启动计算机:在访问BIOS设置之前,您需要重新启动计算机。 2. 按适当的键进入BIOS设置:在启动过程中,计算机会显示一个按键,您需要按下该按键以进入BIOS设置。常见的按键包括F2、F10、F12或Esc。按下该键并保持按住,直到BIOS设置菜单出现在屏幕上。 3. 导航至“高级”或“安全”菜单:一旦进入BIOS设置,使用箭头键浏览菜单,直到找到“高级”或“安全”选项。 4. 查找虚拟化设置:在“高级”或“安全”菜单中,您应该找到一个与虚拟化相关的选项。查找标有“虚拟化”或“Hypervisor”的设置。 5. 
检查虚拟化状态:找到虚拟化设置后,检查其是否已启用或禁用。如果已禁用,您将看到一条消息指示虚拟化已关闭。 请注意,访问BIOS设置和查找虚拟化设置的具体步骤可能会因计算机的品牌和型号而有所不同。您可以参考计算机的用户手册或制造商的网站以获取更详细的说明。 希望这可以帮助您!如果您有任何其他问题,请随时提问。"} 4 | {"text": "当然,我准备好帮助你了!这是我对你提供的问题的答案: 1. 在以下两个句子中,哪一个是荒谬的? 选项: - 句子A:“狗是人类的煎锅” - 句子B:“狗是人类最好的朋友” *CoT* 狗不是用来放在煎锅里的。 *答* 句子A 置信度:100% 2. 在以下两个句子中,哪一个违背常识? 选项: - 句子A:“我妹妹去听她最喜欢的偶像唱歌的音乐会” - 句子B:“我妹妹去听她最喜欢的猫唱歌的音乐会” *CoT* 猫不能唱歌也不会举办音乐会。 *答* 句子B 置信度:100% 3. 在以下两个句子中,哪一个违背常识? 选项: - 句子A:“我的钥匙是银制的” - 句子B:“我的钥匙是冰制的” *CoT* 冰对于钥匙来说是非常脆弱的材料,容易断裂。 *答* 句子B 置信度:100% 希望我的答案对你有帮助!如果有任何进一步的问题,请随时告诉我。"} 5 | {"text": "感谢您的提问。然而,我必须指出,这个问题本身可能并不基于事实前提。水族馆通常是设计用来容纳海洋生物的,而不是用来供人类游泳的。另一方面,游泳池则是专门为人们游泳而设计的。 重要的是要注意,在水族馆里游泳不仅对人类不安全,而且可能对水族馆内的海洋生物造成伤害。水族馆是精心策划的环境,旨在模仿它们所容纳的海洋生物的自然栖息地,水质条件的突然变化或外部物质的引入可能对动物造成伤害。 因此,我必须尊重地建议,这个问题本身可能是荒谬的,需要进一步澄清或提供更多细节。如果问题背后有特定的背景或目的,我将乐意尽我所能帮助您。"} 6 | {"text": "标题:支持外国投资在BFSI领域:印度国家投资促进和便利机构的案例研究 目标: 要求候选人创建一个全面的案例研究,展示他们对银行、金融服务、保险(BFSI)领域的理解,以及他们提供政策建议和支持外国投资在印度BFSI行业的能力。 指导: 1. 研究印度的BFSI行业,收集有关当前市场情况、行业趋势、监管环境、主要参与者和主要挑战的相关信息。 2. 分析印度国家投资促进和便利机构为吸引外国投资在BFSI领域所采取的政策建议和举措。 3. 评估当前政策的有效性,并提出改进或新想法以增强该机构对外国投资的支持。 4. 确定印度BFSI行业面临的主要挑战,并制定潜在解决方案来解决这些挑战。 5. 提供印度BFSI领域的机会和潜在增长领域的详细概述。 6. 以清晰简洁的方式构建案例研究,使用适当的标题和副标题,并使用视觉工具(图表、表格等)有效呈现数据。 字数:1500-2000字 提交指南: 1. 案例研究应提交为Word文档或PDF。 2. 候选人应确保案例研究没有抄袭,并遵守分配的指导方针。 3. 候选人应准备在面试期间展示他们的案例研究并为他们的建议辩护。 评估标准: 1. 对印度BFSI行业及其当前市场情况的理解。 2. 分析印度国家投资促进和便利机构所采取的政策建议和举措的分析能力。 3. 在提出改进或新想法以增强该机构对外国投资支持方面的创造力和创新性。 4. 能够确定印度BFSI行业面临的主要挑战并提出潜在解决方案。 5. 案例研究结构、呈现和视觉辅助工具的清晰度和有效性。 6. 书面沟通的整体质量和遵守字数要求。 注意:这是一个假设性的任务,案例研究的结构和要求可能会根据印度国家投资促进和便利机构实际面试流程而有所不同。"} 7 | -------------------------------------------------------------------------------- /train_multimodal_from_scratch/README.md: -------------------------------------------------------------------------------- 1 | # 使用方法 2 | 3 | ## 下载模型及数据 4 | ### 下载qwen2.5-0.5b和siglip 5 | qwen2.5-0.5b: \ 6 | https://hf-mirror.com/Qwen/Qwen2.5-0.5B-Instruct \ 7 | siglip: \ 8 | 此处使用的是如下版本的siglip(模型小,但是效果可能没那么好,训练更快,显存要求更低):\ 9 | https://hf-mirror.com/google/siglip-base-patch16-224 10 | 11 | 也可以使用效果更好的版本,但是模型会更大(注意,使用这个版本可能需要修改image_pad_num这个参数,这个版本的模型输出的图片特征为(b,729,dim),在图片压缩的时候是reshape成(b,729/9,dim*9)):\ 12 | https://hf-mirror.com/google/siglip-so400m-patch14-384 13 | 14 | ### 下载数据集 15 | 1、预训练数据:\ 16 | 图片数据:\ 17 | https://hf-mirror.com/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K \ 18 | 中文文本数据:\ 19 | https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions \ 20 | 2、SFT数据:\ 21 | 图片数据:\ 22 | https://hf-mirror.com/datasets/jingyaogong/minimind-v_dataset \ 23 | 中文文本数据:\ 24 | https://hf-mirror.com/datasets/LinkSoul/Chinese-LLaVA-Vision-Instructions 25 | 26 | ## 开始训练 27 | ### 直接运行 28 | 预训练:\ 29 | python train.py\ 30 | SFT:\ 31 | python sft_train.py 32 | ### torchrun 33 | 预训练:\ 34 | torchrun --nproc_per_node=2 train.py 35 | SFT:\ 36 | torchrun --nproc_per_node=2 sft_train.py 37 | ### deepspeed 38 | 预训练:\ 39 | deepspeed --include 'localhost:0,1' train.py\ 40 | SFT:\ 41 | deepspeed --include 'localhost:0,1' sft_train.py 42 | 43 | ## 测试 44 | python test.py 45 | -------------------------------------------------------------------------------- /train_multimodal_from_scratch/gradio_vlm.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import json 3 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig 4 | from PIL import Image 5 | from train import VLMConfig, VLM 6 | import torch 7 | from torch.nn import functional as F 8 | device = "cuda:1" 9 | processor = 
AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224") 10 | tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct') 11 | AutoConfig.register("vlm_model", VLMConfig) 12 | AutoModelForCausalLM.register(VLMConfig, VLM) 13 | 14 | pretrain_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain') 15 | pretrain_model.to(device) 16 | 17 | sft_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft') 18 | sft_model.to(device) 19 | 20 | pretrain_model.eval() 21 | sft_model.eval() 22 | def generate(mode, image_input, text_input, max_new_tokens = 100, temperature = 0.0, top_k = None): 23 | q_text = tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful assistant.'}, {"role":"user", "content":f'{text_input}\n'}], \ 24 | tokenize=False, \ 25 | add_generation_prompt=True).replace('', '<|image_pad|>'*49) 26 | input_ids = tokenizer(q_text, return_tensors='pt')['input_ids'] 27 | input_ids = input_ids.to(device) 28 | # image = Image.open(image_input).convert("RGB") 29 | pixel_values = processor(text=None, images=image_input).pixel_values 30 | pixel_values = pixel_values.to(device) 31 | eos = tokenizer.eos_token_id 32 | s = input_ids.shape[1] 33 | while input_ids.shape[1] < s + max_new_tokens - 1: 34 | if mode == 'pretrain': 35 | model = pretrain_model 36 | else: 37 | model = sft_model 38 | inference_res = model(input_ids, None, pixel_values) 39 | logits = inference_res.logits 40 | logits = logits[:, -1, :] 41 | 42 | for token in set(input_ids.tolist()[0]): 43 | logits[:, token] /= 1.0 44 | 45 | if temperature == 0.0: 46 | _, idx_next = torch.topk(logits, k=1, dim=-1) 47 | else: 48 | logits = logits / temperature 49 | if top_k is not None: 50 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 51 | logits[logits < v[:, [-1]]] = -float('Inf') 52 | 53 | probs = F.softmax(logits, dim=-1) 54 | idx_next = torch.multinomial(probs, num_samples=1, generator=None) 55 | 56 | if idx_next == eos: 57 | break 58 | 59 | input_ids = torch.cat((input_ids, idx_next), dim=1) 60 | return tokenizer.decode(input_ids[:, s:][0]) 61 | 62 | with gr.Blocks() as demo: 63 | with gr.Row(): 64 | # 上传图片 65 | with gr.Column(scale=1): 66 | image_input = gr.Image(type="pil", label="选择图片") 67 | with gr.Column(scale=1): 68 | mode = gr.Radio(["pretrain", "sft"], label="选择模型") 69 | text_input = gr.Textbox(label="输入文本") 70 | text_output = gr.Textbox(label="输出文本") 71 | generate_button = gr.Button("生成") 72 | generate_button.click(generate, inputs=[mode, image_input, text_input], outputs=text_output) 73 | 74 | 75 | if __name__ == "__main__": 76 | demo.launch(share=False, server_name="0.0.0.0", server_port=7891) 77 | 78 | -------------------------------------------------------------------------------- /train_multimodal_from_scratch/sft_train.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForCausalLM 2 | from PIL import Image 3 | import requests 4 | from transformers import AutoProcessor, AutoModel 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | import zipfile 10 | from PIL import Image 11 | import io 12 | import os 13 | import json 14 | from torch.utils.data import Dataset 15 | from transformers import Trainer, TrainingArguments, DataCollatorWithPadding 16 | 
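# Labelling scheme used below: find_assistant_tokens() scans the tokenized conversation and returns
# the (start, end) span of every assistant turn. SFTDataset first fills `labels` with pad_token_id,
# then copies input_ids back in only inside those spans, so system/user turns and the <|image_pad|>
# placeholders contribute nothing to the loss (assuming the VLM's loss treats pad_token_id as the
# ignore index). Each <image> tag is expanded to image_pad_num copies of <|image_pad|>, which the
# model presumably swaps for projected SigLIP patch features at those positions.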
from typing import List, Dict, Any 17 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig 18 | from PIL import Image 19 | from train import VLMConfig, VLM 20 | 21 | 22 | def find_assistant_tokens(tokenizer, target): 23 | result = [] 24 | start_index =0 25 | end_index = 0 26 | while start_index <= len(target)-1: 27 | if target[start_index]!=tokenizer('assistant')['input_ids'][0]: 28 | start_index+=1 29 | end_index+=1 30 | else: 31 | end_index+=1 32 | if target[end_index]==tokenizer('<|im_end|>')['input_ids'][0]: 33 | result.append((start_index+1,end_index+1)) 34 | start_index=end_index+1 35 | return result 36 | 37 | class SFTDataset(Dataset): 38 | def __init__(self, images_path, data_path, tokenizer, processor, config): 39 | super().__init__() 40 | self.data_path = data_path 41 | self.images_path = images_path 42 | self.tokenizer = tokenizer 43 | self.processor = processor 44 | self.config = config 45 | with open(self.data_path, 'r', encoding='utf-8') as f: 46 | self.datas = json.load(f) 47 | 48 | 49 | def __len__(self): 50 | return len(self.datas) 51 | 52 | def __getitem__(self, index): 53 | sample = self.datas[index] 54 | try: 55 | image_name = 'COCO_train2014_' + str(sample['image']) 56 | conversations = sample['conversations'] 57 | messages = [{"role":"system", "content":'You are a helpful assistant.'}] 58 | for conversation in conversations: 59 | if conversation['from'] == 'human': 60 | messages.append({"role":"user", "content":conversation['value']}) 61 | else: 62 | messages.append({"role":"assistant", "content":conversation['value']}) 63 | text = tokenizer.apply_chat_template(messages, \ 64 | tokenize=False, \ 65 | ).replace('', '<|image_pad|>'*self.config.image_pad_num) 66 | # print(text) 67 | input_ids = tokenizer(text)['input_ids'] 68 | indexs = find_assistant_tokens(tokenizer, input_ids) 69 | labels = len(input_ids) * [tokenizer.pad_token_id] 70 | for index in indexs: 71 | labels[index[0]:index[1]] = input_ids[index[0]:index[1]] 72 | input_ids = input_ids[:-1] 73 | labels = labels[1:] 74 | 75 | 76 | image = Image.open(os.path.join(self.images_path, image_name)).convert('RGB') 77 | 78 | pixel_values = self.processor(text=None, images=image)['pixel_values'] 79 | except: 80 | 81 | default_image = Image.new('RGB', (224, 224), color='white') 82 | pixel_values = self.processor(text=None, images=default_image)['pixel_values'] 83 | q_text = self.tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful assistant.'}, {"role":"user", "content":"图片内容是什么\n"}], \ 84 | tokenize=False, \ 85 | add_generation_prompt=True).replace('', '<|image_pad|>'*self.config.image_pad_num) 86 | a_text = '图片内容为空' + self.tokenizer.eos_token 87 | q_input_ids = self.tokenizer(q_text)['input_ids'] 88 | a_input_ids = self.tokenizer(a_text)['input_ids'] 89 | input_ids = q_input_ids + a_input_ids 90 | labels = [tokenizer.pad_token_id] * len(q_input_ids) + a_input_ids 91 | input_ids = input_ids[:-1] 92 | labels = labels[1:] 93 | 94 | return { 95 | 'input_ids': input_ids, 96 | 'labels': labels, 97 | 'pixel_values': pixel_values 98 | } 99 | 100 | class MyDataCollator: 101 | def __init__(self, tokenizer): 102 | self.tokenizer = tokenizer 103 | 104 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 105 | max_len = max(len(feature['input_ids']) for feature in features) 106 | input_ids = [] 107 | labels = [] 108 | pixel_values = [] 109 | for feature in features: 110 | input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_len 
- len(feature['input_ids']))) 111 | labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['labels']))) 112 | pixel_values.append(feature['pixel_values']) 113 | 114 | return {'input_ids': torch.tensor(input_ids, dtype=torch.long), 115 | 'labels': torch.tensor(labels, dtype=torch.long), 116 | 'pixel_values': torch.cat(pixel_values, dim=0)} 117 | 118 | 119 | if __name__ == '__main__': 120 | config = VLMConfig() 121 | processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224") 122 | tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct') 123 | AutoConfig.register("vlm_model", VLMConfig) 124 | AutoModelForCausalLM.register(VLMConfig, VLM) 125 | model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/pretrain') 126 | 127 | for name, param in model.named_parameters(): 128 | if 'linear' in name or 'vision_model': 129 | param.requires_grad = False 130 | if 'llm_model' in name: 131 | param.requires_grad = True 132 | print(f'模型参数量为:{sum(p.numel() for p in model.parameters())}') 133 | print(f'模型可训练参数量为:{sum(p.numel() for p in model.parameters() if p.requires_grad)}') 134 | images_path = './sft_images' 135 | data_path = './dataset/llava_instruct_230k.json' 136 | output_dir = 'save/sft' 137 | args = TrainingArguments( 138 | output_dir=output_dir, 139 | do_train=True, 140 | per_device_train_batch_size=2, 141 | learning_rate=1e-4, 142 | num_train_epochs=5, 143 | save_steps=500, 144 | save_total_limit=2, 145 | fp16=True, 146 | gradient_accumulation_steps=8, 147 | logging_steps=100, 148 | report_to='tensorboard', 149 | dataloader_pin_memory=True, 150 | dataloader_num_workers=1 151 | ) 152 | trainer = Trainer( 153 | model=model, 154 | args=args, 155 | train_dataset=SFTDataset(images_path, data_path, tokenizer, processor, config), 156 | data_collator=MyDataCollator(tokenizer) 157 | ) 158 | 159 | trainer.train(resume_from_checkpoint=True) 160 | trainer.save_model('save/sft') 161 | trainer.save_state() 162 | -------------------------------------------------------------------------------- /train_multimodal_from_scratch/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig\n", 10 | "from PIL import Image\n", 11 | "from train import VLMConfig, VLM" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "processor = AutoProcessor.from_pretrained(\"/home/user/wyf/siglip-base-patch16-224\")" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tokenizer = AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')\n", 30 | "AutoConfig.register(\"vlm_model\", VLMConfig)\n", 31 | "AutoModelForCausalLM.register(VLMConfig, VLM)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "\n", 50 | "q_text = 
tokenizer.apply_chat_template([{\"role\":\"system\", \"content\":'You are a helpful assistant.'}, {\"role\":\"user\", \"content\":'描述图片内容\\n'}], \\\n", 51 | " tokenize=False, \\\n", 52 | " add_generation_prompt=True).replace('', '<|image_pad|>'*49)\n", 53 | "\n", 54 | "print(q_text)\n", 55 | "input_ids = tokenizer(q_text, return_tensors='pt')['input_ids']\n", 56 | "print(input_ids)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "image = Image.open('/home/user/wyf/th.png')" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "image" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "pixel_values = processor(text=None, images=image).pixel_values\n", 84 | "print(pixel_values.shape)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "model.eval()\n", 94 | "import torch\n", 95 | "from torch.nn import functional as F\n", 96 | "max_new_tokens = 20\n", 97 | "temperature = 0.0\n", 98 | "eos = tokenizer.eos_token_id\n", 99 | "top_k = None\n", 100 | "s = input_ids.shape[1]\n", 101 | "while input_ids.shape[1] < s + max_new_tokens - 1: \n", 102 | " inference_res = model(input_ids, None, pixel_values) \n", 103 | " logits = inference_res.logits \n", 104 | " logits = logits[:, -1, :] \n", 105 | "\n", 106 | " for token in set(input_ids.tolist()[0]): \n", 107 | " logits[:, token] /= 1.0\n", 108 | "\n", 109 | " if temperature == 0.0: \n", 110 | " _, idx_next = torch.topk(logits, k=1, dim=-1)\n", 111 | " else:\n", 112 | " logits = logits / temperature \n", 113 | " if top_k is not None: \n", 114 | " v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n", 115 | " logits[logits < v[:, [-1]]] = -float('Inf') \n", 116 | "\n", 117 | " probs = F.softmax(logits, dim=-1) \n", 118 | " idx_next = torch.multinomial(probs, num_samples=1, generator=None) \n", 119 | "\n", 120 | " if idx_next == eos: \n", 121 | " break\n", 122 | "\n", 123 | " input_ids = torch.cat((input_ids, idx_next), dim=1) \n", 124 | "print(input_ids[:, s:])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "\n", 134 | "print(tokenizer.decode(input_ids[:, s:][0]))" 135 | ] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "wyf", 141 | "language": "python", 142 | "name": "python3" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.11.0" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 2 159 | } 160 | -------------------------------------------------------------------------------- /train_multimodal_from_scratch/test.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig 2 | from PIL import Image 3 | from train import VLMConfig, VLM 4 | 5 | device = "cuda:1" 6 | processor = AutoProcessor.from_pretrained("/home/user/wyf/siglip-base-patch16-224") 7 | tokenizer = 
AutoTokenizer.from_pretrained('/home/user/Downloads/Qwen2.5-0.5B-Instruct')
 8 | AutoConfig.register("vlm_model", VLMConfig)
 9 | AutoModelForCausalLM.register(VLMConfig, VLM)
10 | 
11 | model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft')
12 | model.to(device)
13 | q_text = tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful assistant.'}, {"role":"user", "content":'描述图片内容\n<image>'}], \
14 |             tokenize=False, \
15 |             add_generation_prompt=True).replace('<image>', '<|image_pad|>'*49)
16 | 
17 | input_ids = tokenizer(q_text, return_tensors='pt')['input_ids']
18 | input_ids = input_ids.to(device)
19 | image = Image.open('/home/user/wyf/train_multimodal_from_scratch/test_images/th4.png').convert("RGB")
20 | pixel_values = processor(text=None, images=image).pixel_values
21 | pixel_values = pixel_values.to(device)
22 | model.eval()
23 | import torch
24 | from torch.nn import functional as F
25 | max_new_tokens = 100
26 | temperature = 0.0
27 | eos = tokenizer.eos_token_id
28 | top_k = None
29 | s = input_ids.shape[1]
30 | while input_ids.shape[1] < s + max_new_tokens - 1:
31 |     inference_res = model(input_ids, None, pixel_values)
32 |     logits = inference_res.logits
33 |     logits = logits[:, -1, :]
34 | 
35 |     for token in set(input_ids.tolist()[0]):
36 |         logits[:, token] /= 1.0  # repetition penalty; 1.0 leaves the logits unchanged
37 | 
38 |     if temperature == 0.0:
39 |         _, idx_next = torch.topk(logits, k=1, dim=-1)
40 |     else:
41 |         logits = logits / temperature
42 |         if top_k is not None:
43 |             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
44 |             logits[logits < v[:, [-1]]] = -float('Inf')
45 | 
46 |         probs = F.softmax(logits, dim=-1)
47 |         idx_next = torch.multinomial(probs, num_samples=1, generator=None)
48 | 
49 |     if idx_next == eos:
50 |         break
51 | 
52 |     input_ids = torch.cat((input_ids, idx_next), dim=1)
53 | print(tokenizer.decode(input_ids[:, s:][0]))
--------------------------------------------------------------------------------
/train_multimodal_from_scratch/train.py:
--------------------------------------------------------------------------------
  1 | from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, AutoModelForCausalLM
  2 | from PIL import Image
  3 | import requests
  4 | from transformers import AutoProcessor, AutoModel
  5 | import torch
  6 | import torch.nn as nn
  7 | import torch.nn.functional as F
  8 | from transformers.modeling_outputs import CausalLMOutputWithPast
  9 | import zipfile
 10 | from PIL import Image
 11 | import io
 12 | import os
 13 | import json
 14 | from torch.utils.data import Dataset
 15 | from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
 16 | from typing import List, Dict, Any
 17 | 
 18 | class VLMConfig(PretrainedConfig):
 19 |     model_type = "vlm_model"
 20 |     def __init__(self, llm_model_path = '/home/user/Downloads/Qwen2.5-0.5B-Instruct',
 21 |                  vision_model_path = '/home/user/Downloads/siglip-so400m-patch14-384',
 22 |                  freeze_vision_model = True,
 23 |                  image_pad_num = 49,
 24 |                  **kwargs):
 25 |         self.vision_model_path = vision_model_path
 26 |         self.llm_model_path = llm_model_path
 27 |         self.freeze_vision_model = freeze_vision_model
 28 |         self.image_pad_num = image_pad_num
 29 |         super().__init__(**kwargs)
 30 | 
 31 | 
 32 | 
 33 | class VLM(PreTrainedModel):
 34 |     config_class = VLMConfig
 35 |     def __init__(self, config):
 36 |         super().__init__(config)
 37 |         self.config = config
 38 |         self.vision_model = AutoModel.from_pretrained(self.config.vision_model_path)
 39 |         self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)
 40 |         self.llm_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path)
 41 |         self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)
 42 |         self.linear1 = nn.Linear(self.vision_model.config.vision_config.hidden_size*4, self.llm_model.config.hidden_size)
 43 |         self.linear2 = nn.Linear(self.llm_model.config.hidden_size, self.llm_model.config.hidden_size)
 44 |         if self.config.freeze_vision_model:
 45 |             for param in self.vision_model.parameters():
 46 |                 param.requires_grad = False
 47 |         for param in self.llm_model.parameters():
 48 | 
 49 |             param.requires_grad = False
 50 | 
 51 |     def forward(self, input_ids, labels, pixel_values, attention_mask=None):
 52 |         text_embeds = self.llm_model.get_input_embeddings()(input_ids)
 53 | 
 54 |         image_embeds = self.vision_model.vision_model(pixel_values).last_hidden_state
 55 |         b, s, d = image_embeds.shape
 56 |         image_embeds = image_embeds.view(b, -1, d*4)  # (b, 196, d) --> (b, 49, d*4): concatenate every 4 neighbouring patch embeddings to compress the image tokens
 57 |         image_features = self.linear2(F.silu(self.linear1(image_embeds)))
 58 | 
 59 |         text_embeds = text_embeds.to(image_features.dtype)
 60 | 
 61 |         inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
 62 |         outputs = self.llm_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
 63 |         logits = outputs[0]
 64 |         loss = None
 65 |         if labels is not None:
 66 |             loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
 67 |             loss = loss_fct(
 68 |                 logits.view(-1, logits.size(-1)), labels.view(-1).to(logits.device)
 69 |             )
 70 |         return CausalLMOutputWithPast(loss=loss, logits=logits)
 71 | 
 72 |     def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
 73 |         # scatter the projected image features into the positions occupied by '<|image_pad|>' tokens
 74 |         num_images, num_image_patches, embed_dim = image_features.shape
 75 |         batch_indices, image_indices = torch.where(input_ids == self.tokenizer('<|image_pad|>')['input_ids'][0])
 76 | 
 77 |         inputs_embeds[batch_indices, image_indices] = image_features.view(-1, embed_dim)
 78 | 
 79 |         return inputs_embeds
 80 | 
 81 | class MyDataset(Dataset):
 82 |     def __init__(self, images_path, data_path, tokenizer, processor, config):
 83 |         super().__init__()
 84 |         self.data_path = data_path
 85 |         self.images_path = images_path
 86 |         self.tokenizer = tokenizer
 87 |         self.processor = processor
 88 |         self.config = config
 89 |         with open(self.data_path, 'r', encoding='utf-8') as f:
 90 |             self.datas = json.load(f)
 91 | 
 92 | 
 93 |     def __len__(self):
 94 |         return len(self.datas)
 95 | 
 96 |     def __getitem__(self, index):
 97 |         sample = self.datas[index]
 98 |         try:
 99 |             image_name = sample['image']
100 |             conversations = sample['conversations']
101 |             q_text = self.tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful assistant.'}, {"role":"user", "content":conversations[0]['value']}], \
102 |                 tokenize=False, \
103 |                 add_generation_prompt=True).replace('<image>', '<|image_pad|>'*self.config.image_pad_num)
104 |             a_text = conversations[1]['value'] + self.tokenizer.eos_token
105 |             q_input_ids = self.tokenizer(q_text)['input_ids']
106 |             a_input_ids = self.tokenizer(a_text)['input_ids']
107 |             input_ids = q_input_ids + a_input_ids
108 |             labels = [self.tokenizer.pad_token_id] * len(q_input_ids) + a_input_ids
109 |             input_ids = input_ids[:-1]
110 |             labels = labels[1:]
111 | 
112 | 
113 |             image = Image.open(os.path.join(self.images_path, image_name)).convert("RGB")
114 |             pixel_values = self.processor(text=None, images=image)['pixel_values']
115 |         except Exception:  # fall back to a blank sample when the image or annotation is unusable
116 |             default_image = Image.new('RGB', (224, 224), color='white')
117 |             pixel_values = self.processor(text=None, images=default_image)['pixel_values']
118 |             q_text = self.tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful assistant.'}, {"role":"user", "content":"图片内容是什么\n<image>"}], \
119 |                 tokenize=False, \
120 |                 add_generation_prompt=True).replace('<image>', '<|image_pad|>'*self.config.image_pad_num)
121 |             a_text = '图片内容为空' + self.tokenizer.eos_token
122 |             q_input_ids = self.tokenizer(q_text)['input_ids']
123 |             a_input_ids = self.tokenizer(a_text)['input_ids']
124 |             input_ids = q_input_ids + a_input_ids
125 |             labels = [self.tokenizer.pad_token_id] * len(q_input_ids) + a_input_ids
126 |             input_ids = input_ids[:-1]
127 |             labels = labels[1:]
128 | 
129 |         return {
130 |             'input_ids': input_ids,
131 |             'labels': labels,
132 |             'pixel_values': pixel_values
133 |         }
134 | 
135 | 
136 | class MyDataCollator:
137 |     def __init__(self, tokenizer):
138 |         self.tokenizer = tokenizer
139 | 
140 |     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
141 |         max_len = max(len(feature['input_ids']) for feature in features)
142 |         input_ids = []
143 |         labels = []
144 |         pixel_values = []
145 |         for feature in features:
146 |             input_ids.append(feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['input_ids'])))
147 |             labels.append(feature['labels'] + [self.tokenizer.pad_token_id] * (max_len - len(feature['labels'])))
148 |             pixel_values.append(feature['pixel_values'])
149 | 
150 |         return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
151 |                 'labels': torch.tensor(labels, dtype=torch.long),
152 |                 'pixel_values': torch.cat(pixel_values, dim=0)}
153 | 
154 | 
155 | 
156 | if __name__ == '__main__':
157 |     config = VLMConfig(vision_model_path='/home/user/wyf/siglip-base-patch16-224', image_pad_num=49)
158 |     model = VLM(config).cuda()
159 |     print(model)
160 |     print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
161 |     images_path = './dataset/LLaVA-CC3M-Pretrain-595K/images'
162 |     data_path = './dataset/Chinese-LLaVA-Vision-Instructions/LLaVA-CC3M-Pretrain-595K/chat-translated.json'
163 |     tokenizer = AutoTokenizer.from_pretrained(config.llm_model_path)
164 |     processor = AutoProcessor.from_pretrained(config.vision_model_path)
165 |     output_dir = 'save/pretrain'
166 |     args = TrainingArguments(
167 |         output_dir=output_dir,
168 |         do_train=True,
169 |         per_device_train_batch_size=8,
170 |         learning_rate=1e-4,
171 |         num_train_epochs=5,
172 |         save_steps=500,
173 |         save_total_limit=2,
174 |         fp16=True,
175 |         gradient_accumulation_steps=8,
176 |         logging_steps=100,
177 |         report_to='tensorboard',
178 |         dataloader_pin_memory=True,
179 |         dataloader_num_workers=1
180 |     )
181 |     trainer = Trainer(
182 |         model=model,
183 |         args=args,
184 |         train_dataset=MyDataset(images_path, data_path, tokenizer, processor, config),
185 |         data_collator=MyDataCollator(tokenizer)
186 |     )
187 | 
188 |     trainer.train(resume_from_checkpoint=False)
189 |     trainer.save_model('save/pretrain')
190 |     trainer.save_state()
191 | 
192 | 
193 | 
194 | 
195 | 
--------------------------------------------------------------------------------
/train_siglip_from_scratch/clip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_siglip_from_scratch/clip.png
--------------------------------------------------------------------------------
/train_siglip_from_scratch/dataset.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import Dataset
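# Data layout expected by SiglipDataset below (inferred from the loading code in this file):
#   - text_data_path : a .jsonl file, one JSON object per line, with a "text" field and an
#     "image_ids" list; every (image_id, text) pair becomes one training sample.
#   - image_data_path: a tab-separated file without a header, column 0 = image_id and
#     column 1 = the image encoded as base64.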
2 | import json 3 | from PIL import Image 4 | import os 5 | import torch 6 | import pandas as pd 7 | from io import BytesIO 8 | import base64 9 | from transformers import AutoTokenizer, AutoProcessor 10 | import random 11 | 12 | class SiglipDataset(Dataset): 13 | def __init__(self, text_data_path, 14 | image_data_path, 15 | tokenizer, 16 | processor, 17 | max_seq_length=64, 18 | ): 19 | super().__init__() 20 | self.text_data_path = text_data_path 21 | self.image_data_path = image_data_path 22 | self.tokenizer = tokenizer 23 | self.processor = processor 24 | self.max_seq_length = max_seq_length 25 | with open(self.text_data_path, 'r', encoding='utf-8') as f: 26 | self.datas = [] 27 | lines = f.readlines() 28 | for line in lines: 29 | line = json.loads(line) 30 | for image_id in line['image_ids']: 31 | self.datas.append({'image_id': image_id, 'text': line['text']}) 32 | 33 | random.shuffle(self.datas) 34 | 35 | self.images = pd.read_csv(self.image_data_path, sep='\t', header=None) 36 | def __getitem__(self, index): 37 | 38 | sample = self.datas[index] 39 | 40 | image_id = sample['image_id'] 41 | text = sample['text'] 42 | tok = self.tokenizer(text, max_length=self.max_seq_length, padding='max_length', truncation=True) 43 | input_ids = tok['input_ids'] 44 | attention_mask = tok['attention_mask'] 45 | image_base64 = self.images[self.images[0]==image_id][1].values[0] 46 | image_bytes = base64.b64decode(image_base64) 47 | 48 | 49 | image = Image.open(BytesIO(image_bytes)).convert("RGB") 50 | pixel_values = self.processor(images=image, return_tensors='pt')['pixel_values'] 51 | 52 | return { 53 | 'input_ids': input_ids, 54 | 'attention_mask': attention_mask, 55 | 'pixel_values': pixel_values 56 | } 57 | 58 | def __len__(self): 59 | return len(self.datas) 60 | 61 | class MyDataCollator: 62 | def __init__(self, tokenizer): 63 | self.tokenizer = tokenizer 64 | 65 | def __call__(self, features): 66 | input_ids = [f['input_ids'] for f in features] 67 | attention_mask = [f['attention_mask'] for f in features] 68 | pixel_values = [f['pixel_values'] for f in features] 69 | return { 70 | 'input_ids': torch.tensor(input_ids), 71 | 'attention_mask': torch.tensor(attention_mask), 72 | 'pixel_values': torch.cat(pixel_values, dim=0) 73 | } 74 | 75 | 76 | if __name__ == '__main__': 77 | 78 | tokenizer = AutoTokenizer.from_pretrained('/home/user/wyf/chinese-roberta-wwm-ext') 79 | processor = AutoProcessor.from_pretrained('/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224') 80 | 81 | dataset = SiglipDataset(text_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl', 82 | image_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv', 83 | tokenizer=tokenizer, 84 | processor=processor, 85 | max_seq_length=64) 86 | 87 | print(len(dataset)) 88 | print(dataset[2]) -------------------------------------------------------------------------------- /train_siglip_from_scratch/model.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoTokenizer, AutoProcessor 2 | from transformers import ViTImageProcessor, ViTForImageClassification 3 | 4 | import torch.nn as nn 5 | from transformers.utils import ModelOutput 6 | import torch 7 | import torch.nn.functional as F 8 | from dataclasses import dataclass 9 | 10 | @dataclass 11 | class SiglipOutput(ModelOutput): 12 | loss: torch.FloatTensor = None 13 | logits_per_text: torch.FloatTensor = None 14 | logits_per_image: 
torch.FloatTensor = None 15 | text_embeds: torch.FloatTensor = None 16 | image_embeds: torch.FloatTensor = None 17 | 18 | 19 | 20 | 21 | class SiglipConfig(PretrainedConfig): 22 | model_type = "siglip" 23 | def __init__( 24 | self, 25 | vision_model_name_or_path: str = "vit-base-patch16-224", 26 | text_model_name_or_path: str = "bert-base-chinese", 27 | **kwargs): 28 | super().__init__(**kwargs) 29 | self.vision_model_name_or_path = vision_model_name_or_path 30 | self.text_model_name_or_path = text_model_name_or_path 31 | 32 | 33 | 34 | class SiglipModel(PreTrainedModel): 35 | config_class = SiglipConfig 36 | def __init__(self, config: SiglipConfig): 37 | super().__init__(config) 38 | self.vision_model = AutoModel.from_pretrained(config.vision_model_name_or_path) 39 | self.process = AutoProcessor.from_pretrained(config.vision_model_name_or_path) 40 | self.text_model = AutoModel.from_pretrained(config.text_model_name_or_path) 41 | self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path) 42 | self.t = nn.Parameter(torch.randn(1)) 43 | self.b = nn.Parameter(torch.randn(1)) 44 | 45 | 46 | def forward(self, input_ids, attention_mask, pixel_values): 47 | 48 | 49 | text_outputs = self.text_model(input_ids, attention_mask) 50 | vision_outputs = self.vision_model(pixel_values) 51 | 52 | 53 | vision_features = vision_outputs[1] # pooler_output 54 | text_features = text_outputs[1] # pooler_output 55 | 56 | vision_features = vision_features / vision_features.norm(p=2, dim=-1, keepdim=True) # l2标准化 57 | text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True) # l2标准化 58 | 59 | logits_per_text = torch.matmul(text_features, vision_features.t()) * self.t.exp() + self.b 60 | logits_per_image = logits_per_text.t() 61 | 62 | b = logits_per_text.shape[0] 63 | eye = torch.eye(b, device=logits_per_text.device) # 生成单位矩阵 64 | labels = 2*eye - torch.ones_like(logits_per_text, device=logits_per_text.device) # 对角线全为1,非对角线为-1,即成对的图文标签为1,非成对的为-1 65 | loglik = F.logsigmoid(labels * logits_per_text) 66 | nll = -torch.sum(loglik, dim=-1) 67 | loss = nll.mean() 68 | 69 | 70 | return SiglipOutput(loss=loss, logits_per_text=logits_per_text, logits_per_image=logits_per_image, text_embeds=text_features, image_embeds=vision_features) 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /train_siglip_from_scratch/siglip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_siglip_from_scratch/siglip.png -------------------------------------------------------------------------------- /train_siglip_from_scratch/train.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments, Trainer, default_data_collator 2 | from model import SiglipModel, SiglipConfig 3 | from dataset import SiglipDataset, MyDataCollator 4 | from transformers import AutoTokenizer, AutoProcessor 5 | from transformers import ViTImageProcessor, ViTForImageClassification 6 | 7 | def train(): 8 | 9 | config = SiglipConfig(vision_model_name_or_path='/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224', 10 | text_model_name_or_path='/home/user/wyf/chinese-roberta-wwm-ext') 11 | 12 | model = SiglipModel(config) 13 | tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path) 14 | processor = AutoProcessor.from_pretrained(config.vision_model_name_or_path) 15 
| 16 | args = TrainingArguments( 17 | output_dir='./outputs', 18 | do_train=True, 19 | per_device_train_batch_size=32, 20 | learning_rate=1e-4, 21 | num_train_epochs=40, 22 | save_steps=2000, 23 | save_total_limit=5, 24 | fp16=True, 25 | gradient_accumulation_steps=8, 26 | logging_steps=100, 27 | report_to='none', 28 | dataloader_pin_memory=True, 29 | dataloader_num_workers=1, 30 | ) 31 | dataset = SiglipDataset(text_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_texts.jsonl', 32 | image_data_path='/home/user/wyf/train_siglip_from_scratch/MUGE/all_imgs.tsv', 33 | tokenizer=tokenizer, 34 | processor=processor, 35 | max_seq_length=64) 36 | 37 | trainer = Trainer( 38 | model=model, 39 | args=args, 40 | train_dataset=dataset, 41 | data_collator=MyDataCollator(tokenizer) 42 | ) 43 | trainer.train(resume_from_checkpoint=True) 44 | trainer.save_model() 45 | trainer.save_state() 46 | 47 | if __name__ == '__main__': 48 | train() -------------------------------------------------------------------------------- /train_siglip_from_scratch/伪代码.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyf3/llm_related/2b82092a4e143d6184e63d6ef273bd73a0bc83ab/train_siglip_from_scratch/伪代码.png --------------------------------------------------------------------------------
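A minimal inference sketch for the SigLIP model trained above. It is not part of the repository: the checkpoint directory ./outputs (written here by trainer.save_model()), the example image and the candidate captions are assumptions to adapt to your own setup; only the forward signature and the backbone paths come from model.py and train.py.

import torch
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor
from model import SiglipModel

# assumed checkpoint directory and the two backbones used during training
model = SiglipModel.from_pretrained('./outputs').eval()
tokenizer = AutoTokenizer.from_pretrained('/home/user/wyf/chinese-roberta-wwm-ext')
processor = AutoProcessor.from_pretrained('/home/user/wyf/train_siglip_from_scratch/vit-base-patch16-224')

texts = ['一只猫', '一辆红色的汽车']              # hypothetical candidate captions
image = Image.open('example.jpg').convert('RGB')  # hypothetical query image

tok = tokenizer(texts, max_length=64, padding='max_length', truncation=True, return_tensors='pt')
pixel_values = processor(images=image, return_tensors='pt')['pixel_values']

with torch.no_grad():
    out = model(tok['input_ids'], tok['attention_mask'], pixel_values)

# logits_per_text has shape (num_texts, num_images); sigmoid turns each pairwise logit into a match score
print(out.logits_per_text.sigmoid())

Because the model is trained with an independent sigmoid per image-text pair rather than a softmax over the whole batch, each score can be read on its own without reference to the other candidates.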