├── .gitignore
├── README.md
├── chat
│   ├── __init__.py
│   ├── chatbot.py
│   ├── chatglm3
│   │   └── chatglm3.py
│   ├── qwen
│   │   ├── qwen.py
│   │   └── tokenization_util.py
│   └── utils.py
├── config.ini
├── data
│   ├── db_tpu
│   │   └── .gitkeep
│   └── uploaded
│       └── .gitkeep
├── doc_processor
│   ├── __init__.py
│   ├── document_loaders
│   │   ├── FilteredCSVloader.py
│   │   ├── __init__.py
│   │   ├── mydocloader.py
│   │   ├── myimgloader.py
│   │   ├── mypdfloader.py
│   │   ├── mypptloader.py
│   │   └── ocr.py
│   ├── knowledge_file.py
│   └── text_splitter
│       ├── __init__.py
│       ├── ali_text_splitter.py
│       ├── chinese_recursive_text_splitter.py
│       ├── chinese_text_splitter.py
│       └── zh_title_enhance.py
├── docs
│   ├── Environment_Install_Guide.md
│   └── Sail_Install_Guide.md
├── embedding
│   ├── __init__.py
│   ├── embedding.py
│   ├── npuengine.py
│   └── sentence_model.py
├── models
│   └── .gitkeep
├── requirements.txt
├── run.sh
├── scripts
│   ├── compile.sh
│   └── export_onnx.py
├── static
│   ├── embedding.png
│   ├── img1.png
│   └── img2.png
└── web_demo_st.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
.DS_Store
data/uploaded/*
!data/uploaded/.gitkeep

data/db/*
!data/db/.gitkeep

data/db_tpu/*
!data/db_tpu/.gitkeep

__pycache__

nltk_data/

*.zip
*.tar
*.tar.gz

models/
!models/.gitkeep

*.bmodel
*.bin
*.pt

*.pdf

sophon-sail/
python_wheels/
*.whl
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# ChatDoc-TPU

这个项目是基于 Sophgo TPU 实现的文档对话工具。项目可在 BM1684X 上独立部署运行。

- [介绍](#介绍)
- [特点](#特点)
- [安装](#安装)
  - [安装第三方库](#安装第三方库)
  - [安装sail](#安装sail)
- [项目结构树](#项目结构树)
- [启动](#启动)
- [操作说明](#操作说明)
  - [界面简介](#界面简介)
  - [上传文档](#上传文档)
  - [持久化知识库](#持久化知识库)
  - [导入知识库](#导入知识库)
  - [删除知识库](#删除知识库)
  - [重命名知识库](#重命名知识库)
  - [清除聊天记录](#清除聊天记录)
  - [移除选中文档](#移除选中文档)


## 介绍

该项目的主要目标是通过使用自然语言来简化与文档的交互,并提取有价值的信息。此项目使用 LangChain、[ChatGLM3-TPU](https://github.com/sophgo/sophon-demo/tree/release/sample/ChatGLM3) 或 [QWEN-TPU](https://github.com/sophgo/sophon-demo/tree/release/sample/Qwen) 构建,以向用户提供流畅自然的对话体验。

以 ChatGPT 为例(可替换为其他 LLM,本仓库已支持 ChatGLM3-6B 和 Qwen-7B,需要保证接口一致),本地知识库问答流程如下:
![Flow](<./static/embedding.png>)

## 特点

- 完全本地推理。
- 支持多种文档格式:PDF、DOCX、TXT。
- 与文档内容进行聊天,提出问题,根据文档获得相关答案。
- 用户友好的界面,确保流畅的交互。


## 安装

按照以下步骤,可以将这个项目部署到 Sophgo 的设备上。

### 安装第三方库
```bash
cd ChatDoc-TPU
# 考虑到 langchain 和 sail 版本依赖,推荐在 python>=3.8 环境运行
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
### 安装sail

此例程`依赖新版本sail`,旧版本需要更新,安装方法请参考 [Sail_Install_Guide](./docs/Sail_Install_Guide.md)。

## 项目结构树
```
|-- ChatDoc-TPU
    |-- data
        |-- db_tpu           -- 知识库持久化目录
        |-- uploaded         -- 已上传文件目录
    |-- models
        |-- bert_model       -- BERT 模型
        |-- glm3_model       -- chatglm3-6B 模型
        |-- qwen_model       -- qwen-7B 模型
    |-- chat
        |-- chatbot.py       -- ChatDoc 业务逻辑脚本
        |-- chatglm3         -- chatglm3 代码
        |-- qwen             -- qwen 代码
    |-- embedding            -- 文本嵌入模型
    |-- docs                 -- 环境安装文档
    |-- static               -- README 中图片文件
    |-- README.md            -- README
    |-- config.ini           -- 推理模型配置文件
    |-- requirements.txt     -- 项目依赖
    |-- run.sh               -- 启动脚本
    |-- web_demo_st.py       -- 页面交互脚本
```

## 启动

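下面是一个最小的启动示例(以 ChatGLM3 模型、0 号 TPU 设备为例;完整参数说明和更多组合见下文):

```bash
cd ChatDoc-TPU
# 首次启动会自动下载模型和配置文件,使用默认路径
./run.sh --model chatglm3 --dev_id 0
# 启动成功后,浏览器访问 http://{host_ip}:8501 即可使用
```
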
回到`ChatDoc-TPU`主目录,启动程序,模型和配置文件自动下载,使用默认路径。

| Model | Cmd |
| :-------------- | :------------------------------------|
| ChatGLM3-6B | ./run.sh --model chatglm3 --dev_id 0 |
| Qwen-7B | ./run.sh --model qwen --dev_id 0 |

```bash
usage: ./run.sh [--model MODEL] [--dev_id DEV_ID] [--server_address SERVER_ADDRESS] [--server_port SERVER_PORT]
--model: 选择模型,可选项为 chatglm3/qwen。默认为 "chatglm3"。
--dev_id: 用于推理的 TPU 设备 ID。默认为 0。
--server_address: web server 地址。默认为 "0.0.0.0"。
--server_port: web server 端口。如不设置,从 8501 起自动分配。
```

启动后您可以通过浏览器打开 `http://{host_ip}:8501`,其中 host_ip 为启动 ChatDoc 的设备 IP,或者您通过参数设置的 `server_address`。

> **说明**:
>1. 在 `config.ini` 中可修改模型路径,默认使用 int4 模型
>2. dev_id 需设置为 BM1684X 设备 id
>3. 默认使用 2k seq_len 模型,如果需要其他参数的模型,可参考[ChatGLM3模型导出与编译](https://github.com/sophgo/sophon-demo/blob/release/sample/ChatGLM3/docs/ChatGLM3_Export_Guide.md)和[Qwen模型导出与编译](https://github.com/sophgo/sophon-demo/blob/release/sample/Qwen/docs/Qwen_Export_Guide.md)
>4. embedding 模型默认使用 [shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese),导出模型方法可参考 [export_onnx.py](./scripts/export_onnx.py)

## 操作说明

![UI](<./static/img1.png>)

### 界面简介
ChatDoc 由控制区和聊天对话区组成。控制区用于管理文档和知识库,聊天对话区用于输入和接收消息。

上图中的 10 号区域是 ChatDoc 当前选中的文档。若 10 号区域为空,即 ChatDoc 没有选中任何文档,仍在聊天对话区与 ChatDoc 对话,则此时的 ChatDoc 是一个单纯依托 LLM 的 ChatBot。

### 上传文档
点击`1`选择要上传的文档,然后点击按钮`4`构建知识库。随后将对文档进行 embedding,完成后文档将被选中并显示在 10 号区域,接着就可以开始对话。我们可以重复上传文档,embedding 成功的文档均会进入 10 号区域。

### 持久化知识库
10 号区域选中的文档在用户刷新或者关闭页面时将会清空。若用户需要保存这些已经 embedding 的文档,可以选择持久化知识库,下次进入时无需 embedding 计算即可加载知识库。具体做法是,在 10 号区域不为空的情况下,点击按钮`5`即可持久化知识库,知识库的名称是所有文档名称以逗号连接而成。

### 导入知识库

用户可以从选择框`2`查看目前已持久化的知识库。选中我们需要加载的知识库后,点击按钮`3`导入知识库,完成后即可开始对话。注意 CPU 版的知识库和 TPU 版的知识库不能混用:若启动 TPU 版程序,则不能加载已持久化的 CPU 版知识库;若启动 CPU 版程序,则不能加载已持久化的 TPU 版知识库。

### 删除知识库

当用户需要删除本地已经持久化的知识库时,可从选择框`2`选择要删除的知识库,然后点击按钮`6`删除知识库。

### 重命名知识库

![Rename](<./static/img2.png>)

由于知识库的命名是由其文档的名称组合而来,难免造成知识库名称过长的问题。ChatDoc 提供了一个修改知识库名称的功能:在选择框`2`选择我们要修改的知识库,然后点击按钮`9`重命名知识库,随后 ChatDoc 将弹出一个输入框和一个确认按钮,如上图。在输入框中输入修改后的名称,然后点击`确认重命名`按钮。

### 清除聊天记录

点击按钮`7`即可清除聊天对话区的聊天记录,其他不受影响。

### 移除选中文档

点击按钮`8`将清空 10 号区域,同时清除聊天记录。
--------------------------------------------------------------------------------

/chat/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#===----------------------------------------------------------------------===#
#
# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
#
# SOPHON-DEMO is licensed under the 2-Clause BSD License except for the
# third-party components.
#
#===----------------------------------------------------------------------===#
from .chatbot import DocChatbot
--------------------------------------------------------------------------------

/chat/chatbot.py:
--------------------------------------------------------------------------------
# coding=utf-8
#===----------------------------------------------------------------------===#
#
# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
#
# SOPHON-DEMO is licensed under the 2-Clause BSD License except for the
# third-party components.
8 | # 9 | #===----------------------------------------------------------------------===# 10 | import os 11 | import shutil 12 | import time 13 | import numpy as np 14 | from datetime import datetime 15 | import faiss 16 | from langchain.document_loaders import UnstructuredPowerPointLoader, UnstructuredWordDocumentLoader, \ 17 | UnstructuredPDFLoader, UnstructuredFileLoader 18 | import logging 19 | import pickle 20 | from langchain.text_splitter import RecursiveCharacterTextSplitter 21 | from typing import List 22 | from glob import glob 23 | from tqdm import tqdm 24 | 25 | from embedding import Word2VecEmbedding 26 | from .chatglm3.chatglm3 import Chatglm3 27 | from .qwen.qwen import Qwen 28 | import doc_processor 29 | from doc_processor.knowledge_file import KnowledgeFile 30 | 31 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 32 | 33 | 34 | class DocChatbot: 35 | _instance = None 36 | 37 | def __init__(self) -> None: 38 | self.llm = None 39 | 40 | llm_model = os.getenv("LLM_MODEL") 41 | dev_id = 0 42 | if os.getenv("DEVICE_ID"): 43 | dev_id = int(os.getenv("DEVICE_ID")) 44 | else: 45 | logging.warning("DEVICE_ID is empty in env var, use default {}".format(dev_id)) 46 | if llm_model == "chatglm3": 47 | self.llm = Chatglm3(dev_id) 48 | elif llm_model == "qwen7b": 49 | self.llm = Qwen(dev_id) 50 | else: 51 | self.llm = Chatglm3(dev_id) 52 | logging.warning("llm_model env var empty, use default chatglm3") 53 | 54 | self.vector_db = None 55 | self.string_db = None 56 | self.files = None 57 | 58 | self.db_base_path = "data/db_tpu" 59 | # embeddings_size hard code here, can read from model output size 60 | self.embeddings_size = 1024 61 | self.embeddings = Word2VecEmbedding() 62 | logging.info("chatbot init success!") 63 | 64 | def docs2embedding(self, docs): 65 | emb = [] 66 | for i in tqdm(range(len(docs) // 4)): 67 | emb += self.embeddings.embed_documents(docs[i * 4: i * 4 + 4]) 68 | if len(docs) % 4 != 0: 69 | residue = docs[-(len(docs) % 4):] + [" " for _ in range(4 - len(docs) % 4)] 70 | emb += self.embeddings.embed_documents(residue)[:len(docs) % 4] 71 | 72 | return emb 73 | 74 | def query_from_doc(self, query_string, k=1): 75 | query_vec = self.embeddings.embed_query(query_string) 76 | _, i = self.vector_db.search(x=np.array([query_vec]), k=k) 77 | return [self.string_db[ind] for ind in i[0]] 78 | 79 | # split documents, generate embeddings and ingest to vector db 80 | def init_vector_db_from_documents(self, file_list: List[str]): 81 | docs = [] 82 | for file in file_list: 83 | kb_file = KnowledgeFile(filename=file) 84 | doc = kb_file.docs2texts() 85 | docs.extend(doc) 86 | 87 | # 文件解析失败 88 | if len(docs) == 0: 89 | return False 90 | 91 | emb_num = 0 92 | start_time = time.time() 93 | if self.vector_db is None: 94 | self.files = ", ".join([item.split("/")[-1] for item in file_list]) 95 | emb = self.docs2embedding([x.page_content for x in docs]) 96 | emb = np.array(emb).astype(np.float32) 97 | if not emb.flags['C_CONTIGUOUS']: 98 | emb = np.ascontiguousarray(emb) 99 | emb_num = len(emb) 100 | self.vector_db = faiss.IndexFlatL2(self.embeddings_size) 101 | self.vector_db.add(emb) 102 | self.string_db = docs 103 | else: 104 | self.files = self.files + ", " + ", ".join([item.split("/")[-1] for item in file_list]) 105 | emb = self.docs2embedding([x.page_content for x in docs]) 106 | emb_num = len(emb) 107 | self.vector_db.add(np.array(emb)) 108 | self.string_db += docs 109 | 110 | logging.info("Total embedding docs time {}, embedding vector size 
{}, embedding vector num {}".format(time.time()- start_time, self.embeddings_size, emb_num)) 111 | return True 112 | 113 | def load_vector_db_from_local(self, index_name: str): 114 | with open(f"{self.db_base_path}/{index_name}/db.string", "rb") as file: 115 | byte_stream = file.read() 116 | self.string_db = pickle.loads(byte_stream) 117 | self.vector_db = faiss.read_index(f"{self.db_base_path}/{index_name}/db.index") 118 | self.files = open(f"{self.db_base_path}/{index_name}/name.txt", 'r', encoding='utf-8').read() 119 | 120 | def save_vector_db_to_local(self): 121 | now = datetime.now() 122 | folder_name = now.strftime("%Y-%m-%d_%H-%M-%S-%f") 123 | os.mkdir(f"{self.db_base_path}/{folder_name}") 124 | faiss.write_index(self.vector_db, f"{self.db_base_path}/{folder_name}/db.index") 125 | byte_stream = pickle.dumps(self.string_db) 126 | with open(f"{self.db_base_path}/{folder_name}/db.string", "wb") as file: 127 | file.write(byte_stream) 128 | with open(f"{self.db_base_path}/{folder_name}/name.txt", "w", encoding="utf-8") as file: 129 | file.write(self.files) 130 | 131 | def del_vector_db(self, file_name): 132 | shutil.rmtree(f"{self.db_base_path}/" + file_name) 133 | self.vector_db = None 134 | 135 | def get_vector_db(self): 136 | file_list = glob(f"{self.db_base_path}/*") 137 | return [x.split("/")[-1] for x in file_list] 138 | 139 | def time2file_name(self, path): 140 | return open(f"{self.db_base_path}/{path}/name.txt", 'r', encoding='utf-8').read() 141 | 142 | def load_first_vector_db(self): 143 | file_list = glob(f"{self.db_base_path}/*") 144 | index_name = file_list[0].split("/")[-1] 145 | self.load_vector_db_from_local(index_name) 146 | 147 | def rename(self, file_list, new_name): 148 | with open(f"{self.db_base_path}/{file_list}/name.txt", "w", encoding="utf-8") as file: 149 | file.write(new_name) 150 | 151 | def stream_predict(self, query, history): 152 | history.append((query, '')) 153 | res = '' 154 | response = "根据文件内容,这个合同的甲方(购买方)是内蒙古北方航空科技有限公司。" 155 | for i in response: 156 | res += i 157 | time.sleep(0.01) 158 | history[-1] = (query, res) 159 | yield res, history 160 | 161 | def filter_space(self, string): 162 | result = "" 163 | count = 0 164 | for char in string: 165 | if char == " " or char == '\t': 166 | count += 1 167 | if count < 4: 168 | result += char 169 | else: 170 | result += char 171 | count = 0 172 | return result 173 | 174 | @classmethod 175 | def get_instance(cls): 176 | if cls._instance is None: 177 | cls._instance = DocChatbot() 178 | return cls._instance 179 | 180 | -------------------------------------------------------------------------------- /chat/chatglm3/chatglm3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 
8 | # 9 | #===----------------------------------------------------------------------===# 10 | import configparser 11 | from transformers import AutoTokenizer 12 | import numpy as np 13 | import time 14 | import logging 15 | import sophon.sail as sail 16 | 17 | from ..utils import fp16_cast, type_convert 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 19 | 20 | 21 | class Chatglm3: 22 | def __init__(self, dev_id=0): 23 | config = configparser.ConfigParser() 24 | config.read('config.ini') 25 | bmodel_path = config.get('chatglm3', 'bmodel_path') 26 | token_path = config.get('chatglm3', 'token_path') 27 | # load tokenizer 28 | self.input_str = "" 29 | self.system = [{"role":"system", 30 | "content":"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown."}] 31 | self.history = [] 32 | self.sp = AutoTokenizer.from_pretrained(token_path, trust_remote_code=True) 33 | logging.info("load {} success!".format(token_path)) 34 | # warm up 35 | self.sp.decode([0]) 36 | self.EOS = self.sp.eos_token_id 37 | 38 | # load bmodel 39 | # 这里devio,后面都没有创建系统内存的tensor 40 | self.net = sail.Engine(bmodel_path, dev_id, sail.IOMode.DEVIO) 41 | logging.info("load {} success, dev_id {}".format(bmodel_path, dev_id)) 42 | self.handle = sail.Handle(dev_id) 43 | self.graph_names = self.net.get_graph_names() 44 | 45 | # initialize glm parameters 46 | self.NUM_LAYERS = (len(self.graph_names) - 2) // 2 47 | self.first_hidden_input_shape = self.net.get_input_shape("block_0", self.net.get_input_names("block_0")[0]) 48 | self.SEQLEN, _, self.HIDDEN_SIZE = self.first_hidden_input_shape 49 | 50 | self.name_embed = "embedding" 51 | self.name_embed_cache = "embedding_cache" 52 | self.name_lm = "lm_head" 53 | self.name_blocks = ["block_"+str(i) for i in range(self.NUM_LAYERS)] 54 | self.name_blocks_cache = ["block_cache_"+str(i) for i in range(self.NUM_LAYERS)] 55 | 56 | # tensors: 57 | # forward_first: embedding_tensor 58 | self.first_embed_input = self.init_sail_tensor(self.name_embed, 0, [1, self.SEQLEN]) 59 | self.first_embed_output = self.init_sail_tensor(self.name_embed, 0, [1, self.SEQLEN, self.HIDDEN_SIZE], False) 60 | 61 | # forward_next: embedding_tensor 62 | self.next_embed_input = self.init_sail_tensor(self.name_embed_cache, 0, [1, 1]) 63 | self.next_embed_output = self.init_sail_tensor(self.name_embed_cache, 0, [1, 1, self.HIDDEN_SIZE], False) 64 | 65 | # forward_first: hidden_state 66 | self.first_hidden_input = self.init_sail_tensor(self.name_blocks[0], 0) 67 | self.first_hidden_output = self.init_sail_tensor(self.name_blocks[0], 0, None, False) 68 | 69 | # forward_next: hidden_state 70 | self.next_hidden_input = self.init_sail_tensor(self.name_blocks_cache[0], 0) 71 | self.next_hidden_output = self.init_sail_tensor(self.name_blocks_cache[0], 0, None, False) 72 | 73 | # forward_first: position_id_tensor 和 attention_mask_tensor 74 | self.first_pid = self.init_sail_tensor(self.name_blocks[0], 1) 75 | self.first_attention = self.init_sail_tensor(self.name_blocks[0], 2) 76 | 77 | # forward_next: position_id_tensor and attention_mask_tensor 78 | self.next_pid = self.init_sail_tensor(self.name_blocks_cache[0], 1) 79 | self.next_attention = self.init_sail_tensor(self.name_blocks_cache[0], 2) 80 | 81 | # forward_next: present_key / present_value (for update kv_cache) 82 | self.present_key = self.init_sail_tensor(self.name_blocks_cache[0], 1, None, False) 83 | self.present_value = 
self.init_sail_tensor(self.name_blocks_cache[0], 2, None, False) 84 | 85 | # forward_first: key_tensor 和 value_tensor 86 | self.past_key_output = [] 87 | self.past_value_output = [] 88 | 89 | # forward_next: cache block的kv tensor名 90 | self.cache_key_input = [] 91 | self.cache_key_output = [] 92 | self.cache_value_input = [] 93 | self.cache_value_output = [] 94 | 95 | for i in range(self.NUM_LAYERS): 96 | self.past_key_output.append(self.init_sail_tensor(self.name_blocks[0], 1, None, False)) 97 | self.past_value_output.append(self.init_sail_tensor(self.name_blocks[0], 2, None, False)) 98 | self.past_key_output[i]["data"].memory_set(0) 99 | self.past_value_output[i]["data"].memory_set(0) 100 | 101 | self.cache_key_input.append(self.init_sail_tensor(self.name_blocks_cache[0], 3)) 102 | self.cache_key_output.append(self.init_sail_tensor(self.name_blocks_cache[0], 1, None, False)) 103 | 104 | self.cache_value_input.append(self.init_sail_tensor(self.name_blocks_cache[0], 4)) 105 | self.cache_value_output.append(self.init_sail_tensor(self.name_blocks_cache[0], 2, None, False)) 106 | 107 | # lm_head tensor 108 | self.lm_input = self.init_sail_tensor(self.name_lm, 0) 109 | self.lm_output = self.init_sail_tensor(self.name_lm, 0, None, False) 110 | 111 | self.token_length = 0 112 | self.round = 0 113 | 114 | def init_sail_tensor(self, name, tensor_idx, shape=None, is_input=True): 115 | """ 116 | init a sail tensor of sail.engine. 117 | parameters: 118 | input: 119 | name: str, graph_name/net_name 120 | tensor_idx: int, input/output tensor id 121 | shape: list[int], shape of tensor 122 | is_input: bool, is input tensor or not 123 | return: 124 | dict 125 | """ 126 | tensor = {} 127 | if is_input: 128 | tensor["name"] = self.net.get_input_names(name)[tensor_idx] 129 | tensor["shape"] = self.net.get_input_shape(name, tensor["name"]) if shape is None else shape 130 | tensor["dtype"] = self.net.get_input_dtype(name, tensor["name"]) 131 | tensor["data"] = sail.Tensor(self.handle, tensor["shape"], tensor["dtype"], False, True) 132 | else: 133 | tensor["name"] = self.net.get_output_names(name)[tensor_idx] 134 | tensor["shape"] = self.net.get_output_shape(name, tensor["name"]) if shape is None else shape 135 | tensor["dtype"] = self.net.get_output_dtype(name, tensor["name"]) 136 | tensor["data"] = sail.Tensor(self.handle, tensor["shape"], tensor["dtype"], False, True) 137 | return tensor 138 | 139 | def generate_tokens(self, input_str): 140 | if not self.history or self.history[0]["role"] != "system": 141 | self.history = self.system + self.history 142 | tokens = self.sp.build_chat_input(input_str, history=self.history, role="user") 143 | return tokens 144 | 145 | def forward_first(self, token): 146 | # Keep 147 | input_ids = np.zeros(self.SEQLEN, type_convert(self.first_embed_input["dtype"])) 148 | input_ids[:min(self.SEQLEN, len(token))] = token 149 | self.token_length = len(token) 150 | input_ids = input_ids.reshape(1, -1) 151 | 152 | position_id = np.zeros(self.SEQLEN, type_convert(self.first_pid["dtype"])) 153 | for i in range(self.token_length): 154 | position_id[i] = i 155 | 156 | attention_mask = np.zeros(self.SEQLEN*self.SEQLEN, type_convert(self.first_attention["dtype"])) #这里的type要从模型获取。 157 | for i in range(self.SEQLEN): 158 | for j in range(self.SEQLEN): 159 | if not (j <= i and i < self.token_length): 160 | attention_mask[i*self.SEQLEN + j] = -10000.0 161 | # embedding 162 | self.first_embed_input["data"].update_data(fp16_cast(input_ids)) 163 | input_embed_tensors = 
{self.first_embed_input["name"]: self.first_embed_input["data"]} 164 | output_embed_tensors = {self.first_embed_output["name"]: self.first_embed_output["data"]} 165 | self.net.process(self.name_embed, input_embed_tensors, output_embed_tensors) 166 | 167 | # blocks 168 | self.first_hidden_tensor = self.first_embed_output["data"] 169 | self.first_hidden_tensor.reshape(self.first_hidden_input["shape"]) 170 | self.first_pid["data"].update_data(fp16_cast(position_id.reshape(self.first_pid["shape"]))) 171 | self.first_attention["data"].update_data(fp16_cast(attention_mask.reshape(self.first_attention["shape"]))) 172 | 173 | input_blocks_tensors = {self.first_hidden_input["name"]: self.first_hidden_tensor, 174 | self.first_pid["name"]: self.first_pid["data"], 175 | self.first_attention["name"]: self.first_attention["data"]} 176 | 177 | for i in range(self.NUM_LAYERS): 178 | output_blocks_tensors = {self.first_hidden_output["name"]: self.first_hidden_tensor, 179 | self.past_key_output[i]["name"]: self.present_key["data"], 180 | self.past_value_output[i]["name"]: self.present_value["data"],} 181 | self.net.process(self.name_blocks[i], input_blocks_tensors, output_blocks_tensors) 182 | 183 | unit_size = np.prod(self.present_key["shape"][1:]) 184 | self.past_key_output[i]["data"].sync_d2d(self.present_key["data"], 0, (self.SEQLEN - self.token_length)*unit_size, self.token_length * unit_size) 185 | self.past_value_output[i]["data"].sync_d2d(self.present_value["data"], 0, (self.SEQLEN - self.token_length)*unit_size, self.token_length * unit_size) 186 | 187 | # lm_head 188 | # hidden_states 的最后一个位置的元素取出来作为 lm_head的输入 189 | copy_len = self.first_hidden_tensor.shape()[-1] 190 | self.lm_input["data"].sync_d2d(self.first_hidden_tensor, 191 | (self.token_length-1)* copy_len, 192 | 0, 193 | copy_len) 194 | 195 | input_lm_tensors = {self.lm_input["name"]: self.lm_input["data"]} 196 | output_lm_tensors = {self.lm_output["name"]: self.lm_output["data"]} 197 | 198 | self.net.process(self.name_lm, input_lm_tensors, output_lm_tensors) 199 | return int(self.lm_output["data"].asnumpy()) 200 | 201 | def forward_next(self, ): 202 | attention_mask = np.zeros(self.SEQLEN+1, type_convert(self.next_attention["dtype"])) 203 | for i in range(self.SEQLEN - self.token_length + 1): 204 | attention_mask[i] = -10000.0 205 | position_id = np.array(self.token_length - 1, type_convert(self.next_pid["dtype"])) 206 | 207 | # embedding 208 | self.next_embed_input["data"] = self.lm_output["data"] 209 | self.next_embed_input["data"].reshape(self.next_embed_input["shape"]) 210 | 211 | input_embed_tensors = {self.next_embed_input["name"]: self.next_embed_input["data"]} 212 | output_embed_tensors = {self.next_embed_output["name"]: self.next_embed_output["data"]} 213 | self.net.process(self.name_embed_cache, input_embed_tensors, output_embed_tensors) 214 | 215 | # blocks 216 | self.next_pid["data"].update_data(fp16_cast(position_id.reshape(self.next_pid["shape"]))) 217 | self.next_attention["data"].update_data(fp16_cast(attention_mask.reshape(self.next_attention["shape"]))) 218 | 219 | self.next_hidden_tensor = self.next_embed_output["data"] 220 | self.next_hidden_tensor.reshape(self.next_hidden_input["shape"]) 221 | 222 | for i in range(self.NUM_LAYERS): 223 | inputs_block_cache_tensors = {self.next_hidden_input["name"]: self.next_hidden_tensor, 224 | self.next_pid["name"]: self.next_pid["data"], 225 | self.next_attention["name"]: self.next_attention["data"], 226 | self.cache_key_input[i]["name"]: self.past_key_output[i]["data"], 227 | 
self.cache_value_input[i]["name"]: self.past_value_output[i]["data"]} 228 | outputs_block_cache_tensors = {self.next_hidden_output["name"]: self.next_hidden_tensor, 229 | self.cache_key_output[i]["name"]: self.past_key_output[i]["data"], 230 | self.cache_value_output[i]["name"]: self.past_value_output[i]["data"]} 231 | self.net.process(self.name_blocks_cache[i], inputs_block_cache_tensors, outputs_block_cache_tensors) 232 | 233 | self.lm_input_tensor = self.next_hidden_tensor 234 | self.lm_input_tensor.reshape(self.lm_input["shape"]) 235 | 236 | input_lm_tensors = {self.lm_input["name"]: self.lm_input_tensor} 237 | output_lm_tensors = {self.lm_output["name"]: self.lm_output["data"]} 238 | self.net.process(self.name_lm, input_lm_tensors, output_lm_tensors) 239 | return int(self.lm_output["data"].asnumpy()) #int32 240 | 241 | 242 | def build_prompt(self, query, history): 243 | prompt = [] 244 | # import pdb; pdb.set_trace() 245 | for i in range(0, len(history)): 246 | prompt.extend([{"role":"user", "content":history[i][0]}, 247 | {"role":"assistant", "content":history[i][1]}]) 248 | prompt = self.system + prompt 249 | prompt = self.sp.build_chat_input(query, history=prompt, role="user") 250 | return prompt 251 | 252 | def stream_predict(self, input_str, history): 253 | # import pdb; pdb.set_trace() 254 | prompt = self.build_prompt(input_str, history) 255 | history.append((input_str, '')) 256 | tok_num = 0 257 | answer_cur = [] 258 | tokens = prompt 259 | # input is empty 260 | if not tokens: 261 | logging.error("Sorry: your question is too wierd!!") 262 | return 263 | if len(tokens) > self.SEQLEN: 264 | logging.warning("The maximum question length should be shorter than {} but we get {} instead, \ 265 | history will be cleared, please ask again".format(self.SEQLEN, len(tokens))) 266 | return 267 | 268 | first_start = time.time() 269 | token = self.forward_first(tokens) 270 | first_end = time.time() 271 | pre_token = 30910 272 | pre_ids = [pre_token] 273 | pre_word= self.sp.decode(pre_ids) 274 | # Sentencepiece will remove space token if the token list it receive has only one token, we add a pre_token so that space token will not be removed. 275 | while token != self.EOS and self.token_length < self.SEQLEN: 276 | ids = [pre_token, token] 277 | word = self.sp.decode(ids) 278 | diff = word[len(pre_word):] 279 | answer_cur += [token] 280 | yield self.sp.decode(answer_cur), history 281 | print(diff, flush=True, end='') 282 | self.token_length += 1 283 | tok_num += 1 284 | token = self.forward_next() 285 | 286 | # 计时 287 | next_end = time.time() 288 | first_duration = first_end-first_start 289 | next_duration = next_end-first_end 290 | tps = tok_num / next_duration 291 | 292 | print() 293 | print(f"FTL: {first_duration:.3f} s") 294 | print(f"TPS: {tps:.3f} token/s") 295 | -------------------------------------------------------------------------------- /chat/qwen/qwen.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 
8 | # 9 | #===----------------------------------------------------------------------===# 10 | import configparser 11 | import time 12 | from .tokenization_util import make_context 13 | from transformers import AutoTokenizer 14 | import numpy as np 15 | import logging 16 | import sophon.sail as sail 17 | 18 | from ..utils import fp16_cast, type_convert 19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 20 | 21 | 22 | class Qwen: 23 | def __init__(self, dev_id = 0): 24 | config = configparser.ConfigParser() 25 | config.read('config.ini') 26 | bmodel_path = config.get('qwen7b', 'bmodel_path') 27 | token_path = config.get('qwen7b', 'token_path') 28 | self.input_str = "" 29 | self.system_prompt = "You are QWEN, a large language model. Follow the user's instructions carefully." 30 | self.history = [] 31 | 32 | # load tokenizer 33 | self.sp = AutoTokenizer.from_pretrained(token_path, trust_remote_code=True) 34 | logging.info("load {} success, dev_id {}".format(bmodel_path, dev_id)) 35 | # warm up 36 | self.sp.decode([0]) 37 | self.EOS = self.sp.im_end_id 38 | 39 | # load bmodel 40 | # 这里devio,后面都没有创建系统内存的tensor 41 | self.net = sail.Engine(bmodel_path, dev_id, sail.IOMode.DEVIO) 42 | logging.info("load {} success!".format(bmodel_path)) 43 | self.handle = sail.Handle(dev_id) 44 | self.graph_names = self.net.get_graph_names() 45 | 46 | # initialize qwen parameters 47 | self.NUM_LAYERS = (len(self.graph_names) - 2) // 2 48 | self.first_hidden_input_shape = self.net.get_input_shape("block_0", self.net.get_input_names("block_0")[0]) 49 | _, self.SEQLEN, self.HIDDEN_SIZE = self.first_hidden_input_shape 50 | 51 | # initialize net name 52 | self.name_embed = "embedding" 53 | self.name_embed_cache = "embedding_cache" 54 | self.name_lm = "lm_head" 55 | self.name_blocks = ["block_"+str(i) for i in range(self.NUM_LAYERS)] 56 | self.name_blocks_cache = ["block_cache_"+str(i) for i in range(self.NUM_LAYERS)] 57 | 58 | # initialize tensors (inputs & outputs) 59 | # forward_first: embedding_tensor 60 | self.first_embed_input = self.init_sail_tensor(self.name_embed, 0, [1, self.SEQLEN]) 61 | self.first_embed_output = self.init_sail_tensor(self.name_embed, 0, [1, self.SEQLEN, self.HIDDEN_SIZE], False) 62 | 63 | # forward_next: embedding_tensor 64 | self.next_embed_input = self.init_sail_tensor(self.name_embed_cache, 0, [1, 1]) 65 | self.next_embed_output = self.init_sail_tensor(self.name_embed_cache, 0, [1, self.HIDDEN_SIZE], False) 66 | 67 | # forward_first: hidden_state 68 | self.first_hidden_input = self.init_sail_tensor(self.name_blocks[0], 0) 69 | self.first_hidden_output = self.init_sail_tensor(self.name_blocks[0], 0, None, False) 70 | 71 | # forward_next: hidden_state 72 | self.next_hidden_input = self.init_sail_tensor(self.name_blocks_cache[0], 0) 73 | self.next_hidden_output = self.init_sail_tensor(self.name_blocks_cache[0], 0, None, False) 74 | 75 | # forward_first: position_id_tensor and attention_mask_tensor 76 | self.first_pid = self.init_sail_tensor(self.name_blocks[0], 1) 77 | self.first_attention = self.init_sail_tensor(self.name_blocks[0], 2) 78 | 79 | # forward_next: position_id_tensor and attention_mask_tensor 80 | self.next_pid = self.init_sail_tensor(self.name_blocks_cache[0], 1) 81 | self.next_attention = self.init_sail_tensor(self.name_blocks_cache[0], 2) 82 | 83 | # forward_next: present_key / present_value (for update kv_cache) 84 | self.present_key = self.init_sail_tensor(self.name_blocks_cache[0], 1, None, False) 85 | self.present_value = 
self.init_sail_tensor(self.name_blocks_cache[0], 2, None, False) 86 | 87 | # forward_first: key_tensor and value_tensor 88 | self.past_key_output = [] 89 | self.past_value_output = [] 90 | 91 | # forward_next: kv cache block 92 | self.cache_key_input = [] 93 | self.cache_key_output = [] 94 | self.cache_value_input = [] 95 | self.cache_value_output = [] 96 | 97 | for _ in range(self.NUM_LAYERS): 98 | self.past_key_output.append(self.init_sail_tensor(self.name_blocks[0], 1, None, False)) 99 | self.past_value_output.append(self.init_sail_tensor(self.name_blocks[0], 2, None, False)) 100 | 101 | self.cache_key_input.append(self.init_sail_tensor(self.name_blocks_cache[0], 3)) 102 | self.cache_key_output.append(self.init_sail_tensor(self.name_blocks_cache[0], 1, None, False)) 103 | 104 | self.cache_value_input.append(self.init_sail_tensor(self.name_blocks_cache[0], 4)) 105 | self.cache_value_output.append(self.init_sail_tensor(self.name_blocks_cache[0], 2, None, False)) 106 | 107 | # lm_head tensor 108 | self.lm_input = self.init_sail_tensor(self.name_lm, 0) 109 | self.lm_output = self.init_sail_tensor(self.name_lm, 0, None, False) 110 | 111 | self.token_length = 0 112 | 113 | def init_sail_tensor(self, name, tensor_idx, shape=None, is_input=True): 114 | """ 115 | init a sail tensor of sail.engine. 116 | parameters: 117 | input: 118 | name: str, graph_name/net_name 119 | tensor_idx: int, input/output tensor id 120 | shape: list[int], shape of tensor 121 | is_input: bool, is input tensor or not 122 | return: 123 | dict 124 | """ 125 | tensor = {} 126 | if is_input: 127 | tensor["name"] = self.net.get_input_names(name)[tensor_idx] 128 | tensor["shape"] = self.net.get_input_shape(name, tensor["name"]) if shape is None else shape 129 | tensor["dtype"] = self.net.get_input_dtype(name, tensor["name"]) 130 | tensor["data"] = sail.Tensor(self.handle, tensor["shape"], tensor["dtype"], False, True) 131 | else: 132 | tensor["name"] = self.net.get_output_names(name)[tensor_idx] 133 | tensor["shape"] = self.net.get_output_shape(name, tensor["name"]) if shape is None else shape 134 | tensor["dtype"] = self.net.get_output_dtype(name, tensor["name"]) 135 | tensor["data"] = sail.Tensor(self.handle, tensor["shape"], tensor["dtype"], False, True) 136 | return tensor 137 | 138 | # inference for the first token 139 | def forward_first(self, token): 140 | input_ids = np.zeros(self.SEQLEN, type_convert(self.first_embed_input["dtype"])) 141 | input_ids[:min(self.SEQLEN, len(token))] = token 142 | input_ids = input_ids.reshape(1, -1) 143 | self.token_length = len(token) 144 | position_id = np.zeros(self.SEQLEN, type_convert(self.first_pid["dtype"])) 145 | for i in range(self.token_length): 146 | position_id[i] = i 147 | 148 | attention_mask = np.ones(self.SEQLEN*self.SEQLEN, type_convert(self.first_attention["dtype"])) * (-10000.0) 149 | for i in range(self.token_length): 150 | for j in range(self.SEQLEN): 151 | if (j <= i): 152 | attention_mask[i*self.SEQLEN + j] = 0 153 | 154 | # embedding 155 | self.first_embed_input["data"].update_data(input_ids) 156 | input_embed_tensors = {self.first_embed_input["name"]: self.first_embed_input["data"]} 157 | output_embed_tensors = {self.first_embed_output["name"]: self.first_embed_output["data"]} 158 | 159 | # Embedding Layer Inference 160 | self.net.process(self.name_embed, input_embed_tensors, output_embed_tensors) 161 | 162 | # blocks 163 | self.first_hidden_tensor = self.first_embed_output["data"] 164 | self.first_hidden_tensor.reshape(self.first_hidden_input["shape"]) 165 | 
self.first_pid["data"].update_data(position_id.reshape(self.first_pid["shape"])) 166 | self.first_attention["data"].update_data(fp16_cast(attention_mask.reshape(self.first_attention["shape"]))) # set bf16 in the future. 167 | 168 | input_blocks_tensors = {self.first_hidden_input["name"]: self.first_hidden_tensor, 169 | self.first_pid["name"]: self.first_pid["data"], 170 | self.first_attention["name"]: self.first_attention["data"]} 171 | 172 | # Transformer Block Inference 173 | for i in range(self.NUM_LAYERS): 174 | output_blocks_tensors = {self.first_hidden_output["name"]: self.first_hidden_tensor, 175 | self.past_key_output[i]["name"]: self.past_key_output[i]["data"], 176 | self.past_value_output[i]["name"]: self.past_value_output[i]["data"]} 177 | 178 | self.net.process(self.name_blocks[i], input_blocks_tensors, output_blocks_tensors) 179 | 180 | # get the last token info as Lm head input 181 | copy_len = self.first_hidden_tensor.shape()[-1] 182 | self.lm_input["data"].sync_d2d(self.first_hidden_tensor, 183 | (self.token_length-1)* copy_len, 184 | 0, 185 | copy_len) 186 | 187 | input_lm_tensors = {self.lm_input["name"]: self.lm_input["data"]} 188 | output_lm_tensors = {self.lm_output["name"]: self.lm_output["data"]} 189 | 190 | # Lm_head Inference 191 | self.net.process(self.name_lm, input_lm_tensors, output_lm_tensors) 192 | return int(self.lm_output["data"].asnumpy()) 193 | 194 | # The following tokens prediction 195 | def forward_next(self, ): 196 | attention_mask = np.zeros(self.SEQLEN+1, type_convert(self.next_attention["dtype"])) 197 | for i in range(self.token_length-1, self.SEQLEN): 198 | attention_mask[i] = -10000.0 199 | position_id = np.array(self.token_length - 1, type_convert(self.next_pid["dtype"])) 200 | 201 | # embedding 202 | self.next_embed_input["data"] = self.lm_output["data"] 203 | self.next_embed_input["data"].reshape(self.next_embed_input["shape"]) 204 | 205 | input_embed_tensors = {self.next_embed_input["name"]: self.next_embed_input["data"]} 206 | output_embed_tensors = {self.next_embed_output["name"]: self.next_embed_output["data"]} 207 | # Embedding Layer Inference 208 | self.net.process(self.name_embed_cache, input_embed_tensors, output_embed_tensors) 209 | 210 | # blocks 211 | self.next_pid["data"].update_data(position_id.reshape(self.next_pid["shape"])) 212 | self.next_attention["data"].update_data(fp16_cast(attention_mask.reshape(self.next_attention["shape"]))) 213 | 214 | self.next_hidden_tensor = self.next_embed_output["data"] 215 | self.next_hidden_tensor.reshape(self.next_hidden_input["shape"]) 216 | 217 | # Transformer Block Inference 218 | for i in range(self.NUM_LAYERS): 219 | inputs_block_cache_tensors = {self.next_hidden_input["name"]: self.next_hidden_tensor, 220 | self.next_pid["name"]: self.next_pid["data"], 221 | self.next_attention["name"]: self.next_attention["data"], 222 | self.cache_key_input[i]["name"]: self.past_key_output[i]["data"], 223 | self.cache_value_input[i]["name"]: self.past_value_output[i]["data"]} 224 | outputs_block_cache_tensors = {self.next_hidden_output["name"]: self.next_hidden_tensor, 225 | self.cache_key_output[i]["name"]: self.present_key["data"], 226 | self.cache_value_output[i]["name"]: self.present_value["data"]} 227 | self.net.process(self.name_blocks_cache[i], inputs_block_cache_tensors, outputs_block_cache_tensors) 228 | 229 | # update kv_cache() 230 | unit_size = self.present_key["shape"][-1]*self.present_key["shape"][-2] 231 | self.past_key_output[i]["data"].sync_d2d(self.present_key["data"], 0, 
(self.token_length-1)*unit_size, unit_size) 232 | self.past_value_output[i]["data"].sync_d2d(self.present_value["data"], 0, (self.token_length-1)*unit_size, unit_size) 233 | 234 | self.lm_input_tensor = self.next_hidden_tensor 235 | self.lm_input_tensor.reshape(self.lm_input["shape"]) 236 | 237 | input_lm_tensors = {self.lm_input["name"]: self.lm_input_tensor} 238 | output_lm_tensors = {self.lm_output["name"]: self.lm_output["data"]} 239 | 240 | # Lm_head Inference 241 | self.net.process(self.name_lm, input_lm_tensors, output_lm_tensors) 242 | return int(self.lm_output["data"].asnumpy()) 243 | 244 | 245 | def stream_predict(self, input_str, history): 246 | tokens = make_context(self.sp, 247 | input_str, 248 | history=history, 249 | system=self.system_prompt, 250 | max_window_size=self.SEQLEN, 251 | chat_format="chatml") 252 | history.append((input_str, '')) 253 | tok_num = 0 254 | answer_cur = [] 255 | 256 | if not tokens: 257 | logging.error("Sorry: your question is too wierd!!") 258 | return 259 | if self.token_length > self.SEQLEN: 260 | logging.error("The maximum question length should be shorter than {} but we get {} instead.".format(self.SEQLEN, self.token_length)) 261 | return 262 | 263 | # First token 264 | first_start = time.time() 265 | token = self.forward_first(tokens) 266 | first_end = time.time() 267 | pre_token = 30910 268 | pre_ids = [pre_token] 269 | pre_word= self.sp.decode(pre_ids) 270 | # Sentencepiece will remove space token if the token list it receive has only one token, we add a pre_token so that space token will not be removed. 271 | while token != self.EOS and self.token_length < self.SEQLEN: 272 | ids = [pre_token, token] 273 | word = self.sp.decode(ids) 274 | diff = word[len(pre_word):] 275 | answer_cur += [token] 276 | yield self.sp.decode(answer_cur), history 277 | print(diff, flush=True, end='') 278 | self.token_length += 1 279 | tok_num += 1 280 | token = self.forward_next() 281 | 282 | # counting time 283 | next_end = time.time() 284 | first_duration = first_end-first_start 285 | next_duration = next_end-first_end 286 | tps = tok_num / next_duration 287 | 288 | 289 | print() 290 | print(f"FTL: {first_duration:.3f} s") 291 | print(f"TPS: {tps:.3f} token/s") 292 | -------------------------------------------------------------------------------- /chat/qwen/tokenization_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | from transformers import PreTrainedTokenizer 3 | 4 | def make_context( 5 | tokenizer: PreTrainedTokenizer, 6 | query: str, 7 | history: List[Tuple[str, str]] = None, 8 | system: str = "", 9 | max_window_size: int = 6144, 10 | chat_format: str = "chatml", 11 | ): 12 | if history is None: 13 | history = [] 14 | 15 | if chat_format == "chatml": 16 | im_start, im_end = "<|im_start|>", "<|im_end|>" 17 | im_start_tokens = [tokenizer.im_start_id] 18 | im_end_tokens = [tokenizer.im_end_id] 19 | nl_tokens = tokenizer.encode("\n") 20 | 21 | def _tokenize_str(role, content): 22 | return f"{role}\n{content}", tokenizer.encode( 23 | role, allowed_special=set() 24 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) 25 | 26 | system_text, system_tokens_part = _tokenize_str("system", system) 27 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 28 | 29 | raw_text = "" 30 | context_tokens = [] 31 | 32 | for turn_query, turn_response in reversed(history): 33 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 34 | query_tokens = 
im_start_tokens + query_tokens_part + im_end_tokens 35 | response_text, response_tokens_part = _tokenize_str( 36 | "assistant", turn_response 37 | ) 38 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 39 | 40 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 41 | prev_chat = ( 42 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 43 | ) 44 | 45 | current_context_size = ( 46 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 47 | ) 48 | if current_context_size < max_window_size: 49 | context_tokens = next_context_tokens + context_tokens 50 | raw_text = prev_chat + raw_text 51 | else: 52 | break 53 | 54 | context_tokens = system_tokens + context_tokens 55 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 56 | context_tokens += ( 57 | nl_tokens 58 | + im_start_tokens 59 | + _tokenize_str("user", query)[1] 60 | + im_end_tokens 61 | + nl_tokens 62 | + im_start_tokens 63 | + tokenizer.encode("assistant") 64 | + nl_tokens 65 | ) 66 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 67 | 68 | elif chat_format == "raw": 69 | raw_text = query 70 | context_tokens = tokenizer.encode(raw_text) 71 | else: 72 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 73 | 74 | return context_tokens -------------------------------------------------------------------------------- /chat/utils.py: -------------------------------------------------------------------------------- 1 | #===----------------------------------------------------------------------===# 2 | # 3 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 4 | # 5 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 6 | # third-party components. 7 | # 8 | #===----------------------------------------------------------------------===# 9 | import sophon.sail as sail 10 | import numpy as np 11 | 12 | #convert sail_dtype to numpy dtype 13 | def type_convert(sail_dtype): 14 | if sail_dtype == sail.Dtype.BM_FLOAT32: 15 | return np.float32 16 | if sail_dtype == sail.Dtype.BM_FLOAT16: 17 | return np.float16 18 | if sail_dtype == sail.Dtype.BM_INT32: 19 | return np.int32 20 | if sail_dtype == sail.Dtype.BM_BFLOAT16: # 后续需要修改bf16的接口,现在先用fp16的代替 21 | return np.float16 22 | 23 | raise TypeError("only support float32 and int32 right now") 24 | 25 | def fp16_cast(arr:np.ndarray): 26 | """ 27 | reinterpret an array with int16 instead of float16, because pybind11 do not support float16. 
    """
    if arr.dtype == np.float16:
        return arr.view(np.uint16)
    else:
        return arr

--------------------------------------------------------------------------------

/config.ini:
--------------------------------------------------------------------------------
[bert_model]
bmodel_path = ./models/bert_model/bge_large_512_fp16.bmodel
token_path = ./models/bert_model/token_config

[chatglm3]
bmodel_path = ./models/glm3_model/chatglm3-6b_int4_1dev_2k.bmodel
token_path = ./models/glm3_model/token_config

[qwen7b]
bmodel_path = ./models/qwen_model/qwen-7b_int4_1dev_2k.bmodel
token_path = ./models/qwen_model/token_config
--------------------------------------------------------------------------------

/data/db_tpu/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/data/db_tpu/.gitkeep
--------------------------------------------------------------------------------

/data/uploaded/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/data/uploaded/.gitkeep
--------------------------------------------------------------------------------

/doc_processor/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#===----------------------------------------------------------------------===#
#
# Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
#
# SOPHON-DEMO is licensed under the 2-Clause BSD License except for the
# third-party components.
8 | # 9 | #===----------------------------------------------------------------------===# 10 | from .knowledge_file import KnowledgeFile 11 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/FilteredCSVloader.py: -------------------------------------------------------------------------------- 1 | ## 指定制定列的csv文件加载器 2 | 3 | from langchain.document_loaders import CSVLoader 4 | import csv 5 | from io import TextIOWrapper 6 | from typing import Dict, List, Optional 7 | from langchain.docstore.document import Document 8 | from langchain.document_loaders.helpers import detect_file_encodings 9 | 10 | 11 | class FilteredCSVLoader(CSVLoader): 12 | def __init__( 13 | self, 14 | file_path: str, 15 | columns_to_read: List[str], 16 | source_column: Optional[str] = None, 17 | metadata_columns: List[str] = [], 18 | csv_args: Optional[Dict] = None, 19 | encoding: Optional[str] = None, 20 | autodetect_encoding: bool = False, 21 | ): 22 | super().__init__( 23 | file_path=file_path, 24 | source_column=source_column, 25 | metadata_columns=metadata_columns, 26 | csv_args=csv_args, 27 | encoding=encoding, 28 | autodetect_encoding=autodetect_encoding, 29 | ) 30 | self.columns_to_read = columns_to_read 31 | 32 | def load(self) -> List[Document]: 33 | """Load data into document objects.""" 34 | 35 | docs = [] 36 | try: 37 | with open(self.file_path, newline="", encoding=self.encoding) as csvfile: 38 | docs = self.__read_file(csvfile) 39 | except UnicodeDecodeError as e: 40 | if self.autodetect_encoding: 41 | detected_encodings = detect_file_encodings(self.file_path) 42 | for encoding in detected_encodings: 43 | try: 44 | with open( 45 | self.file_path, newline="", encoding=encoding.encoding 46 | ) as csvfile: 47 | docs = self.__read_file(csvfile) 48 | break 49 | except UnicodeDecodeError: 50 | continue 51 | else: 52 | raise RuntimeError(f"Error loading {self.file_path}") from e 53 | except Exception as e: 54 | raise RuntimeError(f"Error loading {self.file_path}") from e 55 | 56 | return docs 57 | 58 | def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: 59 | docs = [] 60 | csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore 61 | for i, row in enumerate(csv_reader): 62 | if self.columns_to_read[0] in row: 63 | content = row[self.columns_to_read[0]] 64 | # Extract the source if available 65 | source = ( 66 | row.get(self.source_column, None) 67 | if self.source_column is not None 68 | else self.file_path 69 | ) 70 | metadata = {"source": source, "row": i} 71 | 72 | for col in self.metadata_columns: 73 | if col in row: 74 | metadata[col] = row[col] 75 | 76 | doc = Document(page_content=content, metadata=metadata) 77 | docs.append(doc) 78 | else: 79 | raise ValueError(f"Column '{self.columns_to_read[0]}' not found in CSV file.") 80 | 81 | return docs 82 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .mypdfloader import RapidOCRPDFLoader 2 | from .myimgloader import RapidOCRLoader 3 | from .mydocloader import RapidOCRDocLoader 4 | from .mypptloader import RapidOCRPPTLoader 5 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/mydocloader.py: -------------------------------------------------------------------------------- 1 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 2 
| from typing import List 3 | import tqdm 4 | 5 | 6 | class RapidOCRDocLoader(UnstructuredFileLoader): 7 | def _get_elements(self) -> List: 8 | def doc2text(filepath): 9 | from docx.table import _Cell, Table 10 | from docx.oxml.table import CT_Tbl 11 | from docx.oxml.text.paragraph import CT_P 12 | from docx.text.paragraph import Paragraph 13 | from docx import Document, ImagePart 14 | from PIL import Image 15 | from io import BytesIO 16 | import numpy as np 17 | from rapidocr_onnxruntime import RapidOCR 18 | ocr = RapidOCR() 19 | doc = Document(filepath) 20 | resp = "" 21 | 22 | def iter_block_items(parent): 23 | from docx.document import Document 24 | if isinstance(parent, Document): 25 | parent_elm = parent.element.body 26 | elif isinstance(parent, _Cell): 27 | parent_elm = parent._tc 28 | else: 29 | raise ValueError("RapidOCRDocLoader parse fail") 30 | 31 | for child in parent_elm.iterchildren(): 32 | if isinstance(child, CT_P): 33 | yield Paragraph(child, parent) 34 | elif isinstance(child, CT_Tbl): 35 | yield Table(child, parent) 36 | 37 | b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables), 38 | desc="RapidOCRDocLoader block index: 0") 39 | for i, block in enumerate(iter_block_items(doc)): 40 | b_unit.set_description( 41 | "RapidOCRDocLoader block index: {}".format(i)) 42 | b_unit.refresh() 43 | if isinstance(block, Paragraph): 44 | resp += block.text.strip() + "\n" 45 | images = block._element.xpath('.//pic:pic') # 获取所有图片 46 | for image in images: 47 | for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id 48 | part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片 49 | if isinstance(part, ImagePart): 50 | image = Image.open(BytesIO(part._blob)) 51 | result, _ = ocr(np.array(image)) 52 | if result: 53 | ocr_result = [line[1] for line in result] 54 | resp += "\n".join(ocr_result) 55 | elif isinstance(block, Table): 56 | for row in block.rows: 57 | for cell in row.cells: 58 | for paragraph in cell.paragraphs: 59 | resp += paragraph.text.strip() + "\n" 60 | b_unit.update(1) 61 | return resp 62 | 63 | text = doc2text(self.file_path) 64 | from unstructured.partition.text import partition_text 65 | return partition_text(text=text, **self.unstructured_kwargs) 66 | 67 | 68 | if __name__ == '__main__': 69 | loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx") 70 | docs = loader.load() 71 | print(docs) 72 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/myimgloader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 3 | # from document_loaders.ocr import get_ocr 4 | from .ocr import get_ocr 5 | 6 | 7 | class RapidOCRLoader(UnstructuredFileLoader): 8 | def _get_elements(self) -> List: 9 | def img2text(filepath): 10 | resp = "" 11 | ocr = get_ocr() 12 | result, _ = ocr(filepath) 13 | if result: 14 | ocr_result = [line[1] for line in result] 15 | resp += "\n".join(ocr_result) 16 | return resp 17 | 18 | text = img2text(self.file_path) 19 | from unstructured.partition.text import partition_text 20 | return partition_text(text=text, **self.unstructured_kwargs) 21 | 22 | 23 | if __name__ == "__main__": 24 | loader = RapidOCRLoader(file_path="/home/junqian/workspace/junqian_warehouse/test_text_splitter/pngtest.png") 25 | docs = loader.load() 26 | print(docs) 27 | -------------------------------------------------------------------------------- 
/doc_processor/document_loaders/mypdfloader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 3 | # from configs import PDF_OCR_THRESHOLD 4 | from .ocr import get_ocr 5 | import tqdm 6 | PDF_OCR_THRESHOLD = (0.6, 0.6) 7 | 8 | class RapidOCRPDFLoader(UnstructuredFileLoader): 9 | def _get_elements(self) -> List: 10 | def pdf2text(filepath): 11 | import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆 12 | import numpy as np 13 | ocr = get_ocr() 14 | doc = fitz.open(filepath) 15 | resp = "" 16 | 17 | b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") 18 | for i, page in enumerate(doc): 19 | b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) 20 | b_unit.refresh() 21 | text = page.get_text("") 22 | resp += text + "\n" 23 | 24 | img_list = page.get_image_info(xrefs=True) 25 | for img in img_list: 26 | if xref := img.get("xref"): 27 | bbox = img["bbox"] 28 | # 检查图片尺寸是否超过设定的阈值 29 | if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0] 30 | or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]): 31 | continue 32 | pix = fitz.Pixmap(doc, xref) 33 | img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) 34 | result, _ = ocr(img_array) 35 | if result: 36 | ocr_result = [line[1] for line in result] 37 | resp += "\n".join(ocr_result) 38 | 39 | # 更新进度 40 | b_unit.update(1) 41 | return resp 42 | 43 | text = pdf2text(self.file_path) 44 | from unstructured.partition.text import partition_text 45 | return partition_text(text=text, **self.unstructured_kwargs) 46 | 47 | 48 | if __name__ == "__main__": 49 | loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf") 50 | docs = loader.load() 51 | print(docs) 52 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/mypptloader.py: -------------------------------------------------------------------------------- 1 | from langchain.document_loaders.unstructured import UnstructuredFileLoader 2 | from typing import List 3 | import tqdm 4 | 5 | 6 | class RapidOCRPPTLoader(UnstructuredFileLoader): 7 | def _get_elements(self) -> List: 8 | def ppt2text(filepath): 9 | from pptx import Presentation 10 | from PIL import Image 11 | import numpy as np 12 | from io import BytesIO 13 | from rapidocr_onnxruntime import RapidOCR 14 | ocr = RapidOCR() 15 | prs = Presentation(filepath) 16 | resp = "" 17 | 18 | def extract_text(shape): 19 | nonlocal resp 20 | if shape.has_text_frame: 21 | resp += shape.text.strip() + "\n" 22 | if shape.has_table: 23 | for row in shape.table.rows: 24 | for cell in row.cells: 25 | for paragraph in cell.text_frame.paragraphs: 26 | resp += paragraph.text.strip() + "\n" 27 | if shape.shape_type == 13: # 13 表示图片 28 | image = Image.open(BytesIO(shape.image.blob)) 29 | result, _ = ocr(np.array(image)) 30 | if result: 31 | ocr_result = [line[1] for line in result] 32 | resp += "\n".join(ocr_result) 33 | elif shape.shape_type == 6: # 6 表示组合 34 | for child_shape in shape.shapes: 35 | extract_text(child_shape) 36 | 37 | b_unit = tqdm.tqdm(total=len(prs.slides), 38 | desc="RapidOCRPPTLoader slide index: 1") 39 | # 遍历所有幻灯片 40 | for slide_number, slide in enumerate(prs.slides, start=1): 41 | b_unit.set_description( 42 | "RapidOCRPPTLoader slide index: {}".format(slide_number)) 43 | b_unit.refresh() 44 | sorted_shapes = 
sorted(slide.shapes, 45 | key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历 46 | for shape in sorted_shapes: 47 | extract_text(shape) 48 | b_unit.update(1) 49 | return resp 50 | 51 | text = ppt2text(self.file_path) 52 | from unstructured.partition.text import partition_text 53 | return partition_text(text=text, **self.unstructured_kwargs) 54 | 55 | 56 | if __name__ == '__main__': 57 | loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx") 58 | docs = loader.load() 59 | print(docs) 60 | -------------------------------------------------------------------------------- /doc_processor/document_loaders/ocr.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | 4 | if TYPE_CHECKING: 5 | try: 6 | from rapidocr_paddle import RapidOCR 7 | except ImportError: 8 | from rapidocr_onnxruntime import RapidOCR 9 | 10 | 11 | def get_ocr(use_cuda: bool = True) -> "RapidOCR": 12 | try: 13 | from rapidocr_paddle import RapidOCR 14 | ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda) 15 | except ImportError: 16 | from rapidocr_onnxruntime import RapidOCR 17 | ocr = RapidOCR() 18 | return ocr 19 | -------------------------------------------------------------------------------- /doc_processor/knowledge_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from .text_splitter import zh_title_enhance as func_zh_title_enhance 4 | import langchain.document_loaders 5 | from langchain.docstore.document import Document 6 | from langchain.text_splitter import TextSplitter 7 | from pathlib import Path 8 | import json 9 | from typing import List, Union,Dict, Tuple, Generator 10 | import chardet 11 | from configparser import ConfigParser 12 | 13 | KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") 14 | if not os.path.exists(KB_ROOT_PATH): 15 | os.mkdir(KB_ROOT_PATH) 16 | 17 | TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter" 18 | ZH_TITLE_ENHANCE = False 19 | CHUNK_SIZE = 250 20 | OVERLAP_SIZE = 50 21 | text_splitter_dict = { 22 | "ChineseRecursiveTextSplitter": { 23 | "source": "huggingface", # 选择tiktoken则使用openai的方法 24 | "tokenizer_name_or_path": "", 25 | }, 26 | "SpacyTextSplitter": { 27 | "source": "huggingface", 28 | "tokenizer_name_or_path": "gpt2", 29 | }, 30 | "RecursiveCharacterTextSplitter": { 31 | "source": "tiktoken", 32 | "tokenizer_name_or_path": "cl100k_base", 33 | }, 34 | "MarkdownHeaderTextSplitter": { 35 | "headers_to_split_on": 36 | [ 37 | ("#", "head1"), 38 | ("##", "head2"), 39 | ("###", "head3"), 40 | ("####", "head4"), 41 | ] 42 | }, 43 | } 44 | LLM_MODELS = os.getenv("LLM_MODEL") 45 | 46 | 47 | LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], 48 | "MHTMLLoader": ['.mhtml'], 49 | "UnstructuredMarkdownLoader": ['.md'], 50 | "JSONLoader": [".json"], 51 | "JSONLinesLoader": [".jsonl"], 52 | "CSVLoader": [".csv"], 53 | # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv 54 | "RapidOCRPDFLoader": [".pdf"], 55 | "RapidOCRDocLoader": ['.docx', '.doc'], 56 | "RapidOCRPPTLoader": ['.ppt', '.pptx', ], 57 | "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], 58 | "UnstructuredFileLoader": ['.eml', '.msg', '.rst', 59 | '.rtf', '.txt', '.xml', 60 | '.epub', '.odt','.tsv'], 61 | "UnstructuredEmailLoader": ['.eml', '.msg'], 62 | "UnstructuredEPubLoader": ['.epub'], 63 | "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], 64 | "NotebookLoader": ['.ipynb'], 65 | "UnstructuredODTLoader": 
['.odt'], 66 | "PythonLoader": ['.py'], 67 | "UnstructuredRSTLoader": ['.rst'], 68 | "UnstructuredRTFLoader": ['.rtf'], 69 | "SRTLoader": ['.srt'], 70 | "TomlLoader": ['.toml'], 71 | "UnstructuredTSVLoader": ['.tsv'], 72 | "UnstructuredWordDocumentLoader": ['.docx', '.doc'], 73 | "UnstructuredXMLLoader": ['.xml'], 74 | "UnstructuredPowerPointLoader": ['.ppt', '.pptx'], 75 | "EverNoteLoader": ['.enex'], 76 | } 77 | SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] 78 | 79 | def get_LoaderClass(file_extension): 80 | for LoaderClass, extensions in LOADER_DICT.items(): 81 | if file_extension in extensions: 82 | return LoaderClass 83 | 84 | def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): 85 | ''' 86 | 根据loader_name和文件路径或内容返回文档加载器。 87 | ''' 88 | loader_kwargs = loader_kwargs or {} 89 | try: 90 | if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", 91 | "RapidOCRDocLoader", "RapidOCRPPTLoader"]: 92 | document_loaders_module = importlib.import_module('document_loaders') 93 | else: 94 | document_loaders_module = importlib.import_module('langchain.document_loaders') 95 | DocumentLoader = getattr(document_loaders_module, loader_name) 96 | except Exception as e: 97 | msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" 98 | # logger.error(f'{e.__class__.__name__}: {msg}', 99 | # exc_info=e if log_verbose else None) 100 | document_loaders_module = importlib.import_module('langchain.document_loaders') 101 | DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") 102 | 103 | if loader_name == "UnstructuredFileLoader": 104 | loader_kwargs.setdefault("autodetect_encoding", True) 105 | elif loader_name == "CSVLoader": 106 | if not loader_kwargs.get("encoding"): 107 | # 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误 108 | with open(file_path, 'rb') as struct_file: 109 | encode_detect = chardet.detect(struct_file.read()) 110 | if encode_detect is None: 111 | encode_detect = {"encoding": "utf-8"} 112 | loader_kwargs["encoding"] = encode_detect["encoding"] 113 | 114 | elif loader_name == "JSONLoader": 115 | loader_kwargs.setdefault("jq_schema", ".") 116 | loader_kwargs.setdefault("text_content", False) 117 | elif loader_name == "JSONLinesLoader": 118 | loader_kwargs.setdefault("jq_schema", ".") 119 | loader_kwargs.setdefault("text_content", False) 120 | 121 | loader = DocumentLoader(file_path, **loader_kwargs) 122 | return loader 123 | 124 | 125 | 126 | def make_text_splitter( 127 | splitter_name: str = TEXT_SPLITTER_NAME, 128 | chunk_size: int = CHUNK_SIZE, 129 | chunk_overlap: int = OVERLAP_SIZE, 130 | llm_model: str = LLM_MODELS, 131 | ): 132 | """ 133 | 根据参数获取特定的分词器 134 | """ 135 | splitter_name = splitter_name or "SpacyTextSplitter" 136 | try: 137 | if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 138 | headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] 139 | text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter( 140 | headers_to_split_on=headers_to_split_on) 141 | else: 142 | 143 | try: ## 优先使用用户自定义的text_splitter 144 | text_splitter_module = importlib.import_module('text_splitter') 145 | TextSplitter = getattr(text_splitter_module, splitter_name) 146 | except: ## 否则使用langchain的text_splitter 147 | text_splitter_module = importlib.import_module('langchain.text_splitter') 148 | TextSplitter = getattr(text_splitter_module, splitter_name) 149 | 150 | if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 151 
| try: 152 | text_splitter = TextSplitter.from_tiktoken_encoder( 153 | encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 154 | pipeline="zh_core_web_sm", 155 | chunk_size=chunk_size, 156 | chunk_overlap=chunk_overlap 157 | ) 158 | except: 159 | text_splitter = TextSplitter.from_tiktoken_encoder( 160 | encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 161 | chunk_size=chunk_size, 162 | chunk_overlap=chunk_overlap 163 | ) 164 | elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载 165 | if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": 166 | config = ConfigParser() 167 | config.read('config.ini') 168 | token_path = config.get(llm_model, "token_path") 169 | text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = token_path 170 | 171 | if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": 172 | from transformers import GPT2TokenizerFast 173 | from langchain.text_splitter import CharacterTextSplitter 174 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 175 | else: ## 字符长度加载 176 | from transformers import AutoTokenizer 177 | tokenizer = AutoTokenizer.from_pretrained( 178 | text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 179 | trust_remote_code=True) 180 | text_splitter = TextSplitter.from_huggingface_tokenizer( 181 | tokenizer=tokenizer, 182 | chunk_size=chunk_size, 183 | chunk_overlap=chunk_overlap 184 | ) 185 | else: 186 | try: 187 | text_splitter = TextSplitter( 188 | pipeline="zh_core_web_sm", 189 | chunk_size=chunk_size, 190 | chunk_overlap=chunk_overlap 191 | ) 192 | except: 193 | text_splitter = TextSplitter( 194 | chunk_size=chunk_size, 195 | chunk_overlap=chunk_overlap 196 | ) 197 | except Exception as e: 198 | print(e) 199 | text_splitter_module = importlib.import_module('langchain.text_splitter') 200 | TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") 201 | text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) 202 | 203 | # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287 204 | # text_splitter._tokenizer.max_length = 37016792 205 | # text_splitter._tokenizer.prefer_gpu() 206 | return text_splitter 207 | 208 | 209 | class KnowledgeFile: 210 | def __init__( 211 | self, 212 | filename: str, 213 | loader_kwargs: Dict = {}, 214 | ): 215 | ''' 216 | 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 217 | ''' 218 | self.filename = str(Path(filename).as_posix()) 219 | self.ext = os.path.splitext(filename)[-1].lower() 220 | if self.ext not in SUPPORTED_EXTS: 221 | raise ValueError(f"暂未支持的文件格式 {self.filename}") 222 | self.loader_kwargs = loader_kwargs 223 | self.filepath = filename 224 | self.docs = None 225 | self.splited_docs = None 226 | self.document_loader_name = get_LoaderClass(self.ext) 227 | self.text_splitter_name = TEXT_SPLITTER_NAME 228 | 229 | def file2docs(self, refresh: bool = False): 230 | if self.docs is None or refresh: 231 | # logger.info(f"{self.document_loader_name} used for {self.filepath}") 232 | loader = get_loader(loader_name=self.document_loader_name, 233 | file_path=self.filepath, 234 | loader_kwargs=self.loader_kwargs) 235 | self.docs = loader.load() 236 | return self.docs 237 | 238 | def docs2texts( 239 | self, 240 | docs: List[Document] = None, 241 | zh_title_enhance: bool = ZH_TITLE_ENHANCE, 242 | refresh: bool = False, 243 | chunk_size: int = CHUNK_SIZE, 244 | chunk_overlap: int = OVERLAP_SIZE, 245 | text_splitter: TextSplitter = None, 246 | ): 247 | docs 
= docs or self.file2docs(refresh=refresh) 248 | if not docs: 249 | return [] 250 | if self.ext not in [".csv"]: 251 | if text_splitter is None: 252 | text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, 253 | chunk_overlap=chunk_overlap) 254 | if self.text_splitter_name == "MarkdownHeaderTextSplitter": 255 | docs = text_splitter.split_text(docs[0].page_content) 256 | else: 257 | docs = text_splitter.split_documents(docs) 258 | 259 | if not docs: 260 | return [] 261 | 262 | print(f"文档切分示例:{docs[0]}") 263 | if zh_title_enhance: 264 | docs = func_zh_title_enhance(docs) 265 | self.splited_docs = docs 266 | return self.splited_docs 267 | 268 | def file2text( 269 | self, 270 | zh_title_enhance: bool = ZH_TITLE_ENHANCE, 271 | refresh: bool = False, 272 | chunk_size: int = CHUNK_SIZE, 273 | chunk_overlap: int = OVERLAP_SIZE, 274 | text_splitter: TextSplitter = None, 275 | ): 276 | if self.splited_docs is None or refresh: 277 | docs = self.file2docs() 278 | self.splited_docs = self.docs2texts(docs=docs, 279 | zh_title_enhance=zh_title_enhance, 280 | refresh=refresh, 281 | chunk_size=chunk_size, 282 | chunk_overlap=chunk_overlap, 283 | text_splitter=text_splitter) 284 | return self.splited_docs 285 | 286 | def file_exist(self): 287 | return os.path.isfile(self.filepath) 288 | 289 | def get_mtime(self): 290 | return os.path.getmtime(self.filepath) 291 | 292 | def get_size(self): 293 | return os.path.getsize(self.filepath) 294 | 295 | 296 | 297 | if __name__ == "__main__": 298 | from pprint import pprint 299 | 300 | kb_file = KnowledgeFile( 301 | filename="/home/junqian/workspace/junqian_warehouse/ChatDoc-TPU/doc_processor/test_samples/SOPHON-SAIL_zh.pdf") 302 | # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" 303 | # docs = kb_file.file2docs() 304 | text = kb_file.file2text() 305 | # print(text) 306 | # pprint(docs[-1]) -------------------------------------------------------------------------------- /doc_processor/text_splitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .chinese_text_splitter import ChineseTextSplitter 2 | from .ali_text_splitter import AliTextSplitter 3 | from .zh_title_enhance import zh_title_enhance 4 | from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter -------------------------------------------------------------------------------- /doc_processor/text_splitter/ali_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | 5 | 6 | class AliTextSplitter(CharacterTextSplitter): 7 | def __init__(self, pdf: bool = False, **kwargs): 8 | super().__init__(**kwargs) 9 | self.pdf = pdf 10 | 11 | def split_text(self, text: str) -> List[str]: 12 | # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278 13 | # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 14 | # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id 15 | if self.pdf: 16 | text = re.sub(r"\n{3,}", r"\n", text) 17 | text = re.sub('\s', " ", text) 18 | text = re.sub("\n\n", "", text) 19 | try: 20 | from modelscope.pipelines import pipeline 21 | except ImportError: 22 | raise ImportError( 23 | "Could not import modelscope python package. 
" 24 | "Please install modelscope with `pip install modelscope`. " 25 | ) 26 | 27 | 28 | p = pipeline( 29 | task="document-segmentation", 30 | model='damo/nlp_bert_document-segmentation_chinese-base', 31 | device="cpu") 32 | result = p(documents=text) 33 | sent_list = [i for i in result["text"].split("\n\t") if i] 34 | return sent_list 35 | -------------------------------------------------------------------------------- /doc_processor/text_splitter/chinese_recursive_text_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Optional, Any 3 | from langchain.text_splitter import RecursiveCharacterTextSplitter 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def _split_text_with_regex_from_end( 10 | text: str, separator: str, keep_separator: bool 11 | ) -> List[str]: 12 | # Now that we have the separator, split the text 13 | if separator: 14 | if keep_separator: 15 | # The parentheses in the pattern keep the delimiters in the result. 16 | _splits = re.split(f"({separator})", text) 17 | splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] 18 | if len(_splits) % 2 == 1: 19 | splits += _splits[-1:] 20 | # splits = [_splits[0]] + splits 21 | else: 22 | splits = re.split(separator, text) 23 | else: 24 | splits = list(text) 25 | return [s for s in splits if s != ""] 26 | 27 | 28 | class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): 29 | def __init__( 30 | self, 31 | separators: Optional[List[str]] = None, 32 | keep_separator: bool = True, 33 | is_separator_regex: bool = True, 34 | **kwargs: Any, 35 | ) -> None: 36 | """Create a new TextSplitter.""" 37 | super().__init__(keep_separator=keep_separator, **kwargs) 38 | self._separators = separators or [ 39 | "\n\n", 40 | "\n", 41 | "。|!|?", 42 | "\.\s|\!\s|\?\s", 43 | ";|;\s", 44 | ",|,\s" 45 | ] 46 | self._is_separator_regex = is_separator_regex 47 | 48 | def _split_text(self, text: str, separators: List[str]) -> List[str]: 49 | """Split incoming text and return chunks.""" 50 | final_chunks = [] 51 | # Get appropriate separator to use 52 | separator = separators[-1] 53 | new_separators = [] 54 | for i, _s in enumerate(separators): 55 | _separator = _s if self._is_separator_regex else re.escape(_s) 56 | if _s == "": 57 | separator = _s 58 | break 59 | if re.search(_separator, text): 60 | separator = _s 61 | new_separators = separators[i + 1:] 62 | break 63 | 64 | _separator = separator if self._is_separator_regex else re.escape(separator) 65 | splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) 66 | 67 | # Now go merging things, recursively splitting longer texts. 
68 | _good_splits = [] 69 | _separator = "" if self._keep_separator else separator 70 | for s in splits: 71 | if self._length_function(s) < self._chunk_size: 72 | _good_splits.append(s) 73 | else: 74 | if _good_splits: 75 | merged_text = self._merge_splits(_good_splits, _separator) 76 | final_chunks.extend(merged_text) 77 | _good_splits = [] 78 | if not new_separators: 79 | final_chunks.append(s) 80 | else: 81 | other_info = self._split_text(s, new_separators) 82 | final_chunks.extend(other_info) 83 | if _good_splits: 84 | merged_text = self._merge_splits(_good_splits, _separator) 85 | final_chunks.extend(merged_text) 86 | return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""] 87 | 88 | 89 | if __name__ == "__main__": 90 | text_splitter = ChineseRecursiveTextSplitter( 91 | keep_separator=True, 92 | is_separator_regex=True, 93 | chunk_size=50, 94 | chunk_overlap=0 95 | ) 96 | ls = [ 97 | """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", 98 | ] 99 | # text = """""" 100 | for inum, text in enumerate(ls): 101 | print(inum) 102 | chunks = text_splitter.split_text(text) 103 | for chunk in chunks: 104 | print(chunk) 105 | -------------------------------------------------------------------------------- /doc_processor/text_splitter/chinese_text_splitter.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import CharacterTextSplitter 2 | import re 3 | from typing import List 4 | 5 | 6 | class ChineseTextSplitter(CharacterTextSplitter): 7 | def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): 8 | super().__init__(**kwargs) 9 | self.pdf = pdf 10 | self.sentence_size = sentence_size 11 | 12 | def split_text1(self, text: str) -> List[str]: 13 | if self.pdf: 14 | text = re.sub(r"\n{3,}", "\n", text) 15 | text = re.sub('\s', ' ', text) 16 | text = text.replace("\n\n", "") 17 | sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :; 18 | sent_list = [] 19 | for ele in sent_sep_pattern.split(text): 20 | if sent_sep_pattern.match(ele) and sent_list: 21 | sent_list[-1] += ele 22 | elif ele: 23 | sent_list.append(ele) 24 | return sent_list 25 | 26 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 27 | if self.pdf: 28 | text = re.sub(r"\n{3,}", r"\n", text) 29 | text = re.sub('\s', " ", text) 30 | text = re.sub("\n\n", "", text) 31 | 32 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符 33 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", 
text) # 英文省略号 34 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 35 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 36 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 37 | text = text.rstrip() # 段尾如果有多余的\n就去掉它 38 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 39 | ls = [i for i in text.split("\n") if i] 40 | for ele in ls: 41 | if len(ele) > self.sentence_size: 42 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 43 | ele1_ls = ele1.split("\n") 44 | for ele_ele1 in ele1_ls: 45 | if len(ele_ele1) > self.sentence_size: 46 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 47 | ele2_ls = ele_ele2.split("\n") 48 | for ele_ele2 in ele2_ls: 49 | if len(ele_ele2) > self.sentence_size: 50 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 51 | ele2_id = ele2_ls.index(ele_ele2) 52 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 53 | ele2_id + 1:] 54 | ele_id = ele1_ls.index(ele_ele1) 55 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 56 | 57 | id = ls.index(ele) 58 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 59 | return ls 60 | -------------------------------------------------------------------------------- /doc_processor/text_splitter/zh_title_enhance.py: -------------------------------------------------------------------------------- 1 | from langchain.docstore.document import Document 2 | import re 3 | 4 | 5 | def under_non_alpha_ratio(text: str, threshold: float = 0.5): 6 | """Checks if the proportion of non-alpha characters in the text snippet exceeds a given 7 | threshold. This helps prevent text like "-----------BREAK---------" from being tagged 8 | as a title or narrative text. The ratio does not count spaces. 9 | 10 | Parameters 11 | ---------- 12 | text 13 | The input string to test 14 | threshold 15 | If the proportion of non-alpha characters exceeds this threshold, the function 16 | returns False 17 | """ 18 | if len(text) == 0: 19 | return False 20 | 21 | alpha_count = len([char for char in text if char.strip() and char.isalpha()]) 22 | total_count = len([char for char in text if char.strip()]) 23 | try: 24 | ratio = alpha_count / total_count 25 | return ratio < threshold 26 | except: 27 | return False 28 | 29 | 30 | def is_possible_title( 31 | text: str, 32 | title_max_word_length: int = 20, 33 | non_alpha_threshold: float = 0.5, 34 | ) -> bool: 35 | """Checks to see if the text passes all of the checks for a valid title. 36 | 37 | Parameters 38 | ---------- 39 | text 40 | The input text to check 41 | title_max_word_length 42 | The maximum number of words a title can contain 43 | non_alpha_threshold 44 | The minimum number of alpha characters the text needs to be considered a title 45 | """ 46 | 47 | # 文本长度为0的话,肯定不是title 48 | if len(text) == 0: 49 | print("Not a title. 
Text is empty.") 50 | return False 51 | 52 | # 文本中有标点符号,就不是title 53 | ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z" 54 | ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN) 55 | if ENDS_IN_PUNCT_RE.search(text) is not None: 56 | return False 57 | 58 | # 文本长度不能超过设定值,默认20 59 | # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it 60 | # is less expensive and actual tokenization doesn't add much value for the length check 61 | if len(text) > title_max_word_length: 62 | return False 63 | 64 | # 文本中数字的占比不能太高,否则不是title 65 | if under_non_alpha_ratio(text, threshold=non_alpha_threshold): 66 | return False 67 | 68 | # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles 69 | if text.endswith((",", ".", ",", "。")): 70 | return False 71 | 72 | if text.isnumeric(): 73 | print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore 74 | return False 75 | 76 | # 开头的字符内应该有数字,默认5个字符内 77 | if len(text) < 5: 78 | text_5 = text 79 | else: 80 | text_5 = text[:5] 81 | alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5)))) 82 | if not alpha_in_text_5: 83 | return False 84 | 85 | return True 86 | 87 | 88 | def zh_title_enhance(docs: Document) -> Document: 89 | title = None 90 | if len(docs) > 0: 91 | for doc in docs: 92 | if is_possible_title(doc.page_content): 93 | doc.metadata['category'] = 'cn_Title' 94 | title = doc.page_content 95 | elif title: 96 | doc.page_content = f"下文与({title})有关。{doc.page_content}" 97 | return docs 98 | else: 99 | print("文件不存在") 100 | -------------------------------------------------------------------------------- /docs/Environment_Install_Guide.md: -------------------------------------------------------------------------------- 1 | # sophon-demo环境安装指南 2 | ## 目录 3 | - [sophon-demo环境安装指南](#sophon-demo环境安装指南) 4 | - [目录](#目录) 5 | - [1 TPU-MLIR环境搭建](#1-tpu-mlir环境搭建) 6 | - [2 TPU-NNTC环境搭建](#2-tpu-nntc环境搭建) 7 | - [3 x86 PCIe平台的开发和运行环境搭建](#3-x86-pcie平台的开发和运行环境搭建) 8 | - [3.1 安装libsophon](#31-安装libsophon) 9 | - [3.2 安装sophon-ffmpeg和sophon-opencv](#32-安装sophon-ffmpeg和sophon-opencv) 10 | - [3.3 编译安装sophon-sail](#33-编译安装sophon-sail) 11 | - [4 SoC平台的开发和运行环境搭建](#4-soc平台的开发和运行环境搭建) 12 | - [4.1 交叉编译环境搭建](#41-交叉编译环境搭建) 13 | - [4.2 交叉编译安装sophon-sail](#42-交叉编译安装sophon-sail) 14 | - [5 arm PCIe平台的开发和运行环境搭建](#5-arm-pcie平台的开发和运行环境搭建) 15 | - [5.1 安装libsophon](#51-安装libsophon) 16 | - [5.2 安装sophon-ffmpeg和sophon-opencv](#52-安装sophon-ffmpeg和sophon-opencv) 17 | - [5.3 编译安装sophon-sail](#53-编译安装sophon-sail) 18 | - [6 riscv PCIe平台的开发和运行环境搭建](#6-riscv-pcie平台的开发和运行环境搭建) 19 | - [6.1 安装libsophon](#61-安装libsophon) 20 | - [6.2 安装sophon-ffmpeg和sophon-opencv](#62-安装sophon-ffmpeg和sophon-opencv) 21 | - [6.3 编译安装sophon-sail](#63-编译安装sophon-sail) 22 | 23 | Sophon Demo所依赖的环境主要包括用于编译和量化模型的TPU-NNTC、TPU-MLIR环境,用于编译C++程序的开发环境以及用于部署程序的运行环境。 24 | 25 | ## 1 TPU-MLIR环境搭建 26 | 使用TPU-MLIR编译BModel,通常需要在x86主机上安装TPU-MLIR环境,x86主机已安装Ubuntu16.04/18.04/20.04系统,并且运行内存在12GB以上。TPU-MLIR环境安装步骤主要包括: 27 | 28 | 1. 安装Docker 29 | 30 | 若已安装docker,请跳过本节。 31 | ```bash 32 | # 安装docker 33 | sudo apt-get install docker.io 34 | # docker命令免root权限执行 35 | # 创建docker用户组,若已有docker组会报错,没关系可忽略 36 | sudo groupadd docker 37 | # 将当前用户加入docker组 38 | sudo usermod -aG docker $USER 39 | # 切换当前会话到新group或重新登录重启X会话 40 | newgrp docker​ 41 | ``` 42 | > **提示**:需要logout系统然后重新登录,再使用docker就不需要sudo了。 43 | 44 | 2. 
创建并进入docker 45 | 46 | TPU-MLIR使用的docker是sophgo/tpuc_dev:latest, docker镜像和tpu-mlir有绑定关系,少数情况下有可能更新了tpu-mlir,需要新的镜像。 47 | ```bash 48 | docker pull sophgo/tpuc_dev:latest 49 | # 这里将本级目录映射到docker内的/workspace目录,用户需要根据实际情况将demo的目录映射到docker里面 50 | # myname只是举个名字的例子, 请指定成自己想要的容器的名字 51 | docker run --privileged --name myname -v $PWD:/workspace -it sophgo/tpuc_dev:latest 52 | # 此时已经进入docker,并在/workspace目录下 53 | ``` 54 | 55 | 3. 安装TPU-MLIR 56 | 57 | 目前支持两种安装方法: 58 | 59 | (1)直接从pypi下载并安装: 60 | ```bash 61 | pip install tpu_mlir 62 | ``` 63 | (2)从[TPU-MLIR Github](https://github.com/sophgo/tpu-mlir/releases)下载最新`tpu_mlir-*-py3-none-any.whl`,然后使用pip安装: 64 | ```bash 65 | pip install tpu_mlir-*-py3-none-any.whl 66 | ``` 67 | 68 | TPU-MLIR在对不同框架模型处理时所需的依赖不同,对于onnx或torch生成的模型文件, 69 | 使用下面命令安装额外的依赖环境: 70 | ```bash 71 | pip install tpu_mlir[onnx] 72 | pip install tpu_mlir[torch] 73 | ``` 74 | 目前支持五种配置: onnx, torch, tensorflow, caffe, paddle。可使用一条命令安装多个配置,也可直接安装全部依赖环境: 75 | ```bash 76 | pip install tpu_mlir[onnx,torch,caffe] 77 | pip install tpu_mlir[all] 78 | ``` 79 | 80 | 建议TPU-MLIR的镜像仅用于编译和量化模型,程序编译和运行请在开发和运行环境中进行。更多TPU-MLIR的教程请参考[算能官网](https://developer.sophgo.com/site/index/material/31/all.html)的《TPU-MLIR快速入门手册》和《TPU-MLIR开发参考手册》。 81 | 82 | ## 2 TPU-NNTC环境搭建 83 | 使用TPU-NNTC编译BModel,通常需要在x86主机上安装TPU-NNTC环境,x86主机已安装Ubuntu16.04/18.04/20.04系统,并且运行内存在12GB以上。TPU-NNTC环境安装步骤主要包括: 84 | 85 | 1. 安装Docker 86 | 87 | 若已安装docker,请跳过本节。 88 | ```bash 89 | # 安装docker 90 | sudo apt-get install docker.io 91 | # docker命令免root权限执行 92 | # 创建docker用户组,若已有docker组会报错,没关系可忽略 93 | sudo groupadd docker 94 | # 将当前用户加入docker组 95 | sudo usermod -aG docker $USER 96 | # 切换当前会话到新group或重新登录重启X会话 97 | newgrp docker​ 98 | ``` 99 | > **提示**:需要logout系统然后重新登录,再使用docker就不需要sudo了。 100 | 101 | 2. 下载并解压TPU-NNTC 102 | 103 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的TPU-NNTC压缩包,命名如tpu-nntc_vx.y.z-hash-date.tar.gz,x.y.z表示版本号,并进行解压。 104 | ```bash 105 | mkdir tpu-nntc 106 | # 将压缩包解压到tpu-nntc 107 | tar zxvf tpu-nntc_vx.y.z--.tar.gz --strip-components=1 -C tpu-nntc 108 | ``` 109 | 110 | 3. 
创建并进入docker 111 | 112 | TPU-NNTC使用的docker是sophgo/tpuc_dev:2.1, docker镜像和tpu-nntc有绑定关系,少数情况下有可能更新了tpu-nntc,需要新的镜像。 113 | ```bash 114 | cd tpu-nntc 115 | # 进入docker,如果当前系统没有对应镜像,会自动从docker hub上下载 116 | # 这里将tpu-nntc的上一级目录映射到docker内的/workspace目录,用户需要根据实际情况将demo的目录映射到docker里面 117 | # 这里用了8001到8001端口映射,之后在使用ufw可视化工具会用到 118 | # 如果端口已经占用,请更换其他未占用端口,后面根据需要更换进行调整 119 | docker run --privileged --name myname -v $PWD/..:/workspace -p 8001:8001 -it sophgo/tpuc_dev:v2.1 120 | # 此时已经进入docker,并在/workspace目录下 121 | # 下面初始化软件环境 122 | cd /workspace/tpu-nntc 123 | source scripts/envsetup.sh 124 | ``` 125 | 此镜像仅用于编译和量化模型,程序编译和运行请在开发和运行环境中进行。更多TPU-NNTC的教程请参考[算能官网](https://developer.sophgo.com/site/index/material/31/all.html)的《TPU-NNTC快速入门指南》和《TPU-NNTC开发参考手册》。 126 | 127 | 128 | 129 | ## 3 x86 PCIe平台的开发和运行环境搭建 130 | 如果您在x86平台安装了PCIe加速卡,开发环境与运行环境可以是统一的,您可以直接在宿主机上搭建开发和运行环境。 131 | 132 | **注意:** mlir提供的docker环境用来编译模型的,不建议与运行环境混用,如果您需要在主机上搭建docker测试环境,请参考官网《LIBSOPHON使用手册.pdf》第6章-使用Docker搭建测试环境。 133 | 134 | ### 3.1 安装libsophon 135 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的libsophon安装包,包括: 136 | * sophon-driver_x.y.z_amd64.deb 137 | * sophon-libsophon_x.y.z_amd64.deb 138 | * sophon-libsophon-dev_x.y.z_amd64.deb 139 | 140 | 其中:x.y.z表示版本号;sophon-driver包含了PCIe加速卡驱动;sophon-libsophon包含了运行时环境(库文件、工具等);sophon-libsophon-dev包含了开发环境(头文件等)。如果只是在部署环境上安装,则不需要安装 sophon-libsophon-dev。 141 | ```bash 142 | # 安装依赖库,只需要执行一次 143 | sudo apt install dkms libncurses5 144 | # 安装libsophon 145 | sudo dpkg -i sophon-*amd64.deb 146 | # 在终端执行如下命令,或者登出再登入当前用户后即可使用bm-smi等命令: 147 | source /etc/profile 148 | ``` 149 | 150 | 更多libsophon信息请参考《LIBSOPHON使用手册.pdf》。 151 | 152 | ### 3.2 安装sophon-ffmpeg和sophon-opencv 153 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的sophon-mw安装包,包括: 154 | * sophon-mw-sophon-ffmpeg_x.y.z_amd64.deb 155 | * sophon-mw-sophon-ffmpeg-dev_x.y.z_amd64.deb 156 | * sophon-mw-sophon-opencv_x.y.z_amd64.deb 157 | * sophon-mw-sophon-opencv-dev_x.y.z_amd64.deb 158 | 159 | 其中:x.y.z表示版本号;sophon-ffmpeg/sophon-opencv包含了ffmpeg/opencv运行时环境(库文件、工具等);sophon-ffmpeg-dev/sophon-opencv-dev包含了开发环境(头文件、pkgconfig、cmake等)。如果只是在部署环境上安装,则不需要安装 sophon-ffmpeg-dev/sophon-opencv-dev。 160 | 161 | sophon-mw-sophon-ffmpeg依赖sophon-libsophon包,而sophon-mw-sophon-opencv依赖sophon-mw-sophon-ffmpeg,因此在安装次序上必须 162 | 先安装libsophon, 然后sophon-mw-sophon-ffmpeg, 最后安装sophon-mw-sophon-opencv。 163 | 164 | 如果运行环境中使用的libstdc++库使用GCC5.1之前的旧版本ABI接口(典型的有CENTOS系统),请使用sophon-mw-sophon-opencv-abi0相关安装包。 165 | 166 | ```bash 167 | # 安装sophon-ffmpeg 168 | sudo dpkg -i sophon-mw-sophon-ffmpeg_*amd64.deb sophon-mw-sophon-ffmpeg-dev_*amd64.deb 169 | # 安装sophon-opencv 170 | sudo dpkg -i sophon-mw-sophon-opencv_*amd64.deb sophon-mw-sophon-opencv-dev_*amd64.deb 171 | # 在终端执行如下命令,或者logout再login当前用户后即可使用安装的工具 172 | source /etc/profile 173 | ``` 174 | 175 | 更多sophon-mw信息请参考《MULTIMEDIA使用手册.pdf》、《MULTIMEDIA开发参考手册.pdf》。 176 | 177 | ### 3.3 编译安装sophon-sail 178 | 如果例程依赖sophon-sail则需要编译和安装sophon-sail,否则可跳过本章节。 179 | 180 | 需从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的SDK,里面有sophon-sail的压缩包,命名如sophon-sail_x.y.z.tar.gz,x.y.z表示版本号。 181 | 您可以打开sophon-sail压缩包里面提供的用户手册(命名为sophon-sail_zh.pdf),参考编译安装指南章节,选择您需要的模式(C++/Python,PCIE MODE)进行安装。 182 | 183 | 184 | ## 4 SoC平台的开发和运行环境搭建 185 | 
对于SoC平台,安装好SophonSDK(>=v22.09.02)后内部已经集成了相应的libsophon、sophon-opencv和sophon-ffmpeg运行库包,位于`/opt/sophon/`下,可直接用于运行环境。通常在x86主机上交叉编译程序,使之能够在SoC平台运行。SophonSDK固件刷新方法可参考[FAQ文档](./FAQ.md#12-soc模式下如何使用sd卡刷更新固件). 186 | 187 | ### 4.1 交叉编译环境搭建 188 | 需要在x86主机上使用SOPHON SDK搭建交叉编译环境,将程序所依赖的头文件和库文件打包至soc-sdk目录中。 189 | 1. 安装交叉编译工具链 190 | ```bash 191 | sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu 192 | ``` 193 | 如果报错:`/lib/aarch64-linux-gnu/libc.so.6: version 'GLIBC_2.33' not found`。 194 | 这是由于您主机上的交叉编译工具链版本太高导致,可以在[linaro官方网站](https://releases.linaro.org/components/toolchain/binaries/7.5-2019.12/aarch64-linux-gnu/)下载不高于边缘设备gcc版本的交叉编译工具链。 195 | 196 | 这里提供一个ubuntu配置的例子: 197 | ```bash 198 | sudo apt remove cpp-*-aarch64-linux-gnu 199 | 200 | wget -nd https://releases.linaro.org/components/toolchain/binaries/7.5-2019.12/aarch64-linux-gnu/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz 201 | 202 | tar xvf gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz 203 | 204 | echo "export PATH=$PWD/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu/bin/:$PATH" >> ~/.bashrc 205 | 206 | source ~/.bashrc 207 | ``` 208 | 209 | 2. 打包libsophon 210 | 211 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的sophon-img安装包,其中包括libsophon_soc_x.y.z_aarch64.tar.gz,x.y.z表示版本号,并进行解压。 212 | 213 | ```bash 214 | # 创建依赖文件的根目录 215 | mkdir -p soc-sdk 216 | # 解压libsophon_soc_x.y.z_aarch64.tar.gz 217 | tar -zxf libsophon_soc_${x.y.z}_aarch64.tar.gz 218 | # 将相关的库目录和头文件目录拷贝到依赖文件根目录下 219 | cp -rf libsophon_soc_${x.y.z}_aarch64/opt/sophon/libsophon-${x.y.z}/lib ${soc-sdk} 220 | cp -rf libsophon_soc_${x.y.z}_aarch64/opt/sophon/libsophon-${x.y.z}/include ${soc-sdk} 221 | ``` 222 | 223 | 3. 打包sophon-ffmpeg和sophon-opencv 224 | 225 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的sophon-mw安装包,其中包括sophon-mw-soc_x.y.z_aarch64.tar.gz,x.y.z表示版本号,并进行解压,如果您使用BM1688 SOPHONSDK,“sophon-mw”这个名字或许需要替换成“sophon-media”。 226 | ```bash 227 | # 解压sophon-mw-soc_x.y.z_aarch64.tar.gz 228 | tar -zxf sophon-mw-soc_${x.y.z}_aarch64.tar.gz 229 | # 将ffmpeg和opencv的库目录和头文件目录拷贝到soc-sdk目录下 230 | cp -rf sophon-mw-soc_${x.y.z}_aarch64/opt/sophon/sophon-ffmpeg_${x.y.z}/lib ${soc-sdk} 231 | cp -rf sophon-mw-soc_${x.y.z}_aarch64/opt/sophon/sophon-ffmpeg_${x.y.z}/include ${soc-sdk} 232 | cp -rf sophon-mw-soc_${x.y.z}_aarch64/opt/sophon/sophon-opencv_${x.y.z}/lib ${soc-sdk} 233 | cp -rf sophon-mw-soc_${x.y.z}_aarch64/opt/sophon/sophon-opencv_${x.y.z}/include ${soc-sdk} 234 | ``` 235 | 236 | 4. 
如果您使用BM1688的GeminiSDK1.3以上版本,您还需要做这些操作: 237 | 从sdk中获取sophon-img/bsp-debs/目录下的sophon-soc-libisp_${x.y.z}_arm64.deb,然后运行如下命令: 238 | ``` 239 | dpkg -x sophon-soc-libisp_${x.y.z}_arm64.deb sophon-libisp 240 | cp -rf sophon-libisp/opt/sophon/sophon-soc-libisp_${x.y.z}/lib ${soc-sdk} 241 | ``` 242 | 243 | 这里,交叉编译环境已经搭建完成,接下来可以使用打包好的soc-sdk编译需要在SoC平台上运行的程序。更多交叉编译信息请参考《LIBSOPHON使用手册.pdf》。 244 | 245 | ### 4.2 交叉编译安装sophon-sail 246 | 如果例程依赖sophon-sail则需要编译和安装sophon-sail,否则可跳过本章节。需要在x86主机上交叉编译sophon-sail,并在SoC平台上安装。 247 | 248 | 需从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的SOPHONSDK,进入sophon-sail_${date}文件夹,sophon-sail的发布包命名如sophon-sail_x.y.z.tar.gz,x.y.z表示版本号,您可以打开同级目录下的用户手册(命名为sophon-sail_zh.pdf或SOPHON-SAIL_zh.pdf)。 249 | 参考编译安装指南章节,选择您需要的模式(C++/Python,SoC MODE)进行安装,**注意需要选择包含ffmpeg和opencv的编译方式。** 250 | 251 | 252 | 在您按照教程将sophon-sail的库文件拷贝到目标soc上之后,您还需要设置以下环境变量: 253 | ```bash 254 | echo 'export LD_LIBRARY_PATH=/opt/sophon/sophon-sail/lib/:$LD_LIBRARY_PATH' >> ~/.bashrc 255 | source ~/.bashrc 256 | ``` 257 | ## 5 arm PCIe平台的开发和运行环境搭建 258 | 如果您在arm平台安装了PCIe加速卡,开发环境与运行环境可以是统一的,您可以直接在宿主机上搭建开发和运行环境。 259 | 这里提供银河麒麟v10机器的环境安装方法,其他类型机器具体请参考官网开发手册。 260 | ### 5.1 安装libsophon 261 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的libsophon安装包, 262 | 安装包由一个文件构成,其中“$arch”为当前机器的硬件架构,使用以下命令可以获取当前服务器的arch: 263 | ``` 264 | uname -m 265 | ``` 266 | 通常x86_64机器对应的硬件架构为x86_64,arm64机器对应的硬件架构为aarch64: 267 | ``` 268 | libsophon_x.y.z_$arch.tar.gz,x.y.z表示版本号 269 | ``` 270 | 可以通过如下步骤安装: 271 | 272 | **注意:如果有旧版本,先参考下面的卸载方式步骤卸载旧版本。** 273 | ``` 274 | tar -xzvf libsophon_${x.y.z}_aarch64.tar.gz 275 | sudo cp -r libsophon_${x.y.z}_aarch64/* / 276 | sudo ln -s /opt/sophon/libsophon-${x.y.z} /opt/sophon/libsophon-current 277 | ``` 278 | 接下来请先按照您所使用Linux发行版的要求搭建驱动编译环境,然后做如下操作: 279 | ``` 280 | sudo ln -s /opt/sophon/driver-${x.y.z}/$bin /lib/firmware/bm1684x_firmware.bin 281 | sudo ln -s /opt/sophon/driver-${x.y.z}/$bin /lib/firmware/bm1684_ddr_firmware.bin 282 | sudo ln -s /opt/sophon/driver-${x.y.z}/$bin /lib/firmware/bm1684_tcm_firmware.bin 283 | cd /opt/sophon/driver-${x.y.z} 284 | ``` 285 | 此处“$bin”是带有版本号的bin文件全名, 对于bm1684x板卡,为a53lite_pkg.bin,对于bm1684板卡,如bm1684_ddr.bin_v3.1.1-63a8614d-220906和bm1684_tcm.bin_v3.1.1-63a8614d-220906。 286 | 287 | 之后就可以编译驱动了(这里不依赖于dkms): 288 | ``` 289 | sudo make SOC_MODE=0 PLATFORM=asic SYNC_API_INT_MODE=1 \ 290 | TARGET_PROJECT=sg_pcie_device FW_SIMPLE=0 \ 291 | PCIE_MODE_ENABLE_CPU=1 292 | sudo cp ./bmsophon.ko /lib/modules/$(uname -r)/kernel/ 293 | sudo depmod 294 | sudo modprobe bmsophon 295 | ``` 296 | 最后是一些配置工作: 297 | 298 | 添加库和可执行文件路径: 299 | ``` 300 | sudo cp /opt/sophon/libsophon-current/data/libsophon.conf /etc/ld.so.conf.d/ 301 | sudo ldconfig 302 | sudo cp /opt/sophon/libsophon-current/data/libsophon-bin-path.sh /etc/profile.d/ 303 | ``` 304 | 在终端执行如下命令,或者登出再登入当前用户后即可使用bm-smi等命令: 305 | ``` 306 | source /etc/profile 307 | ``` 308 | 添加cmake config文件: 309 | ``` 310 | sudo mkdir -p /usr/lib/cmake/libsophon 311 | sudo cp /opt/sophon/libsophon-current/data/libsophon-config.cmake /usr/lib/cmake/libsophon/ 312 | ``` 313 | 卸载方式: 314 | ``` 315 | sudo rm -f /etc/ld.so.conf.d/libsophon.conf 316 | sudo ldconfig 317 | sudo rm -f /etc/profile.d/libsophon-bin-path.sh 318 | sudo rm -rf /usr/lib/cmake/libsophon 319 | sudo rmmod bmsophon 320 | sudo rm -f /lib/modules/$(uname -r)/kernel/bmsophon.ko 321 | sudo depmod 322 | sudo rm -f /lib/firmware/bm1684x_firmware.bin 323 | sudo rm -f 
/lib/firmware/bm1684_ddr_firmware.bin 324 | sudo rm -f /lib/firmware/bm1684_tcm_firmware.bin 325 | sudo rm -f /opt/sophon/libsophon-current 326 | sudo rm -rf /opt/sophon/libsophon-0.4.6 327 | sudo rm -rf /opt/sophon/driver-0.4.6 328 | ``` 329 | 其他平台机器请参考《LIBSOPHON使用手册.pdf》 330 | 331 | ### 5.2 安装sophon-ffmpeg和sophon-opencv 332 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的sophon-mw安装包, 333 | 安装包由一个文件构成: 334 | ``` 335 | sophon-mw_x.y.z_aarch64.tar.gz,x.y.z表示版本号 336 | ``` 337 | 可以通过如下步骤安装: 338 | 339 | 先按照《LIBSOPHON使用手册》安装好libsophon包,然后, 340 | ``` 341 | tar -xzvf sophon-mw_${x.y.z}_aarch64.tar.gz 342 | sudo cp -r sophon-mw_${x.y.z}_aarch64/* / 343 | sudo ln -s /opt/sophon/sophon-ffmpeg_${x.y.z} /opt/sophon/sophon-ffmpeg-latest 344 | sudo ln -s /opt/sophon/sophon-opencv_${x.y.z} /opt/sophon/sophon-opencv-latest 345 | sudo ln -s /opt/sophon/sophon-sample_${x.y.z} /opt/sophon/sophon-sample-latest 346 | sudo sed -i "s/usr\/local/opt\/sophon\/sophon-ffmpeg-latest/g" /opt/sophon/sophon-ffmpeg-latest/lib/pkgconfig/*.pc 347 | sudo sed -i "s/^prefix=.*$/prefix=\/opt\/sophon\/sophon-opencv-latest/g" /opt/sophon/sophon-opencv-latest/lib/pkgconfig/opencv4.pc 348 | ``` 349 | 最后,**安装bz2 libc6 libgcc依赖库**(这部分需要根据操作系统不同,选择对应的安装包,这里不统一介绍) 350 | 然后是一些配置工作: 351 | 352 | 添加库和可执行文件路径: 353 | ``` 354 | sudo cp /opt/sophon/sophon-ffmpeg-latest/data/01_sophon-ffmpeg.conf /etc/ld.so.conf.d/ 355 | sudo cp /opt/sophon/sophon-opencv-latest/data/02_sophon-opencv.conf /etc/ld.so.conf.d/ 356 | sudo ldconfig 357 | sudo cp /opt/sophon/sophon-ffmpeg-latest/data/sophon-ffmpeg-autoconf.sh /etc/profile.d/ 358 | sudo cp /opt/sophon/sophon-opencv-latest/data/sophon-opencv-autoconf.sh /etc/profile.d/ 359 | sudo cp /opt/sophon/sophon-sample-latest/data/sophon-sample-autoconf.sh /etc/profile.d/ 360 | source /etc/profile 361 | ``` 362 | 其他平台机器请参考《MULTIMEDIA使用手册.pdf》、《MULTIMEDIA开发参考手册.pdf》。 363 | 364 | ### 5.3 编译安装sophon-sail 365 | 如果例程依赖sophon-sail则需要编译和安装sophon-sail,否则可跳过本章节。 366 | 367 | 需从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的SDK,里面有sophon-sail的压缩包,命名如sophon-sail_x.y.z.tar.gz,x.y.z表示版本号。 368 | 您可以打开sophon-sail压缩包里面提供的用户手册(命名为sophon-sail_zh.pdf),参考编译安装指南章节,选择您需要的模式(C++/Python, ARM PCIE MODE)进行安装。 369 | 370 | ## 6 riscv PCIe平台的开发和运行环境搭建 371 | 如果您在riscv平台安装了PCIe加速卡,开发环境与运行环境可以是统一的,您可以直接在宿主机上搭建开发和运行环境。 372 | 这里提供SG2042服务器的环境安装方法,其他类型机器具体请参考官网开发手册。 373 | 374 | ### 6.1 安装libsophon 375 | 从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的libsophon安装包, 376 | 安装包由以下3个文件构成: 377 | ```bash 378 | sophon-libsophon-dev-{x.y.z}.riscv64.rpm 379 | sophon-libsophon-{x.y.z}.riscv64.rpm 380 | sophon-driver-{x.y.z}.riscv64.rpm 381 | ``` 382 | 安装前需要通过后面“卸载方式”中的步骤卸载旧版本libsophon,可以通过如下步骤安装: 383 | ```bash 384 | 安装依赖库,只需要执行一次: 385 | sudo yum install -y epel-release 386 | sudo yum install -y dkms 387 | sudo yum install -y ncurses* 388 | 安装libsophon: 389 | sudo rpm -ivh sophon-driver-{x.y.z}.riscv64.rpm 390 | sudo rpm -ivh sophon-libsophon-{x.y.z}.riscv64.rpm 391 | sudo rpm -ivh --force sophon-libsophon-dev-{x.y.z}.riscv64.rpm 392 | 在终端执行如下命令,或者登出再登入当前用户后即可使用bm-smi等命令: 393 | source /etc/profile 394 | ``` 395 | 卸载方式: 396 | ```bash 397 | sudo rpm -e sophon-driver 398 | sudo rpm -e sophon-libsophon-dev 399 | sudo rpm -e sophon-libsophon 400 | ``` 401 | 其他平台机器请参考《LIBSOPHON使用手册.pdf》。 402 | 403 | ### 6.2 安装sophon-ffmpeg和sophon-opencv 404 | 
从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的sophon-mw安装包, 405 | 406 | sophon-mw安装包由一个文件构成: 407 | ```bash 408 | sophon-mw_{x.y.z}_riscv.tar.gz 409 | ``` 410 | 411 | 安装之前需要保证libsophon已安装完毕,如果有老版本,请参考"卸载方式"卸载,安装步骤如下: 412 | ```bash 413 | tar -xzvf sophon-mw_{x.y.z}_riscv_64.tar.gz 414 | sudo cp -r sophon-mw_{x.y.z}_riscv_64/* / 415 | sudo ln -s /opt/sophon/sophon-ffmpeg_{x.y.z} /opt/sophon/sophon-ffmpeg-latest 416 | sudo ln -s /opt/sophon/sophon-opencv_{x.y.z} /opt/sophon/sophon-opencv-latest 417 | sudo ln -s /opt/sophon/sophon-sample_{x.y.z} /opt/sophon/sophon-sample-latest 418 | sudo sed -i "s/usr\/local/opt\/sophon\/sophon-ffmpeg-latest/g" /opt/sophon/sophon-ffmpeg-latest/lib/pkgconfig/*.pc 419 | sudo sed -i "s/^prefix=.*$/prefix=\/opt\/sophon\/sophon-opencv-latest/g" /opt/sophon/sophon-opencv-latest/lib/pkgconfig/opencv4.pc 420 | sudo cp /opt/sophon/sophon-ffmpeg-latest/data/01_sophon-ffmpeg.conf /etc/ld.so.conf.d/ 421 | sudo cp /opt/sophon/sophon-opencv-latest/data/02_sophon-opencv.conf /etc/ld.so.conf.d/ 422 | sudo ldconfig 423 | sudo cp /opt/sophon/sophon-ffmpeg-latest/data/sophon-ffmpeg-autoconf.sh /etc/profile.d/ 424 | sudo cp /opt/sophon/sophon-opencv-latest/data/sophon-opencv-autoconf.sh /etc/profile.d/ 425 | sudo cp /opt/sophon/sophon-sample-latest/data/sophon-sample-autoconf.sh /etc/profile.d/ 426 | source /etc/profile 427 | ``` 428 | 429 | 卸载方式: 430 | ```bash 431 | sudo rm -f /etc/ld.so.conf.d/01_sophon-ffmpeg.conf 432 | sudo rm -f /etc/ld.so.conf.d/02_sophon-opencv.conf 433 | sudo ldconfig 434 | sudo rm -f /etc/profile.d/sophon-ffmpeg-autoconf.sh 435 | sudo rm -f /etc/profile.d/sophon-opencv-autoconf.sh 436 | sudo rm -f /etc/profile.d/sophon-sample-autoconf.sh 437 | sudo rm -f /opt/sophon/sophon-ffmpeg-latest 438 | sudo rm -f /opt/sophon/sophon-opencv-latest 439 | sudo rm -f /opt/sophon/sophon-sample-latest 440 | sudo rm -rf /opt/sophon/sophon-ffmpeg_{x.y.z} /opt/sophon/sophon-opencv_{x.y.z} /opt/sophon/sophon-sample_{x.y.z} 441 | ``` 442 | 443 | 其他平台机器请参考《MULTIMEDIA使用手册.pdf》、《MULTIMEDIA开发参考手册.pdf》。 444 | 445 | ### 6.3 编译安装sophon-sail 446 | 如果例程依赖sophon-sail则需要编译和安装sophon-sail,否则可跳过本章节。 447 | 448 | 需从[算能官网](https://developer.sophgo.com/site/index/material/28/all.html)上下载符合[环境依赖](../README.md)的SDK,里面有sophon-sail的压缩包,命名如sophon-sail_x.y.z.tar.gz,x.y.z表示版本号。 449 | 您可以打开sophon-sail压缩包里面提供的用户手册(命名为sophon-sail_zh.pdf),参考编译安装指南章节,选择您需要的模式(C++/Python,RSICV PCIE MODE)进行安装。 -------------------------------------------------------------------------------- /docs/Sail_Install_Guide.md: -------------------------------------------------------------------------------- 1 | # Sail_Install_Guide 2 | 3 | 以下提供两种方式分别在不同平台(PCIe、SoC)安装sail。 4 | 1. 源码编译安装 5 | 2. 
通过预编译的whl包安装 6 | 7 | 更详细的Sail的安装可以参考[sophon-sail编译安装指南](https://doc.sophgo.com/sdk-docs/v23.07.01/docs_latest_release/docs/sophon-sail/docs/zh/html/1_build.html#)。 8 | 9 | 通常情况下,在运行环境通过源码编译sail不会有`python或者驱动版本依赖`问题,PCIe环境下推荐[源码编译安装](#源码编译安装)。Soc环境下python版本较为固定,建议使用[预编译的whl包安装](#预编译的whl包安装)。 10 | 11 | 此例程无需安装sophon-mw,源码编译时可设置`-DONLY_RUNTIME=ON`。 12 | 13 | - [源码编译安装](#源码编译安装) 14 | - [x86/arm PCIe平台](#x86arm-pcie平台) 15 | - [SoC平台](#soc平台) 16 | - [预编译的whl包安装](#预编译的whl包安装) 17 | 18 | ## 源码编译安装 19 | ### x86/arm PCIe平台 20 | 21 | 如果您在x86/arm平台安装了PCIe加速卡(如SC系列加速卡),并使用它测试本例程,您需要安装libsophon,具体请参考[x86-pcie平台的开发和运行环境搭建](./Environment_Install_Guide.md#3-x86-pcie平台的开发和运行环境搭建)或[arm-pcie平台的开发和运行环境搭建](./Environment_Install_Guide.md#5-arm-pcie平台的开发和运行环境搭建)。 22 | 23 | 下载SOPHON-SAIL源码,解压后进入其源码目录,编译`不包含bmcv,sophon-ffmpeg,sophon-opencv`的SAIL, 通过此方式编译出来的SAIL无法使用其Decoder、Bmcv等多媒体相关接口。 24 | ```bash 25 | pip3 install dfss -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade 26 | python3 -m dfss --url=open@sophgo.com:sophon-demo/ChatGLM3/sail/sophon-sail_20240226.tar.gz 27 | tar xvf sophon-sail_20240226.tar.gz 28 | ``` 29 | 30 | 创建编译文件夹build,并进入build文件夹 31 | ```bash 32 | cd sophon-sail 33 | mkdir build && cd build 34 | ``` 35 | 执行编译命令 36 | 37 | ```bash 38 | cmake -DONLY_RUNTIME=ON .. 39 | make pysail 40 | ``` 41 | 打包生成python wheel,生成的wheel包的路径为‘python/pcie/dist’,文件名为‘sophon-3.7.0-py3-none-any.whl’ 42 | ```bash 43 | cd ../python/pcie 44 | chmod +x sophon_pcie_whl.sh 45 | ./sophon_pcie_whl.sh 46 | ``` 47 | 安装python wheel 48 | 49 | ```bash 50 | pip3 install ./dist/sophon-3.7.0-py3-none-any.whl --force-reinstall 51 | ``` 52 | 53 | ### SoC平台 54 | 55 | 如果您使用SoC平台(如SE、SM系列边缘设备),并使用它测试本例程,刷机后在`/opt/sophon/`下已经预装了相应的libsophon、sophon-opencv和sophon-ffmpeg运行库包。 56 | 57 | `使用指定版本的python3(和目标SOC上的python3保持一致)`,通过交叉编译的方式,编译出`不包含bmcv,sophon-ffmpeg,sophon-opencv的SAIL`, python3的安装方式可通过python官方网站获取, 也可以根据[获取在X86主机上进行交叉编译的Python3]获取已经编译好的python3。 本示例使用的python3路径为‘python_3.8.2/bin/python3’,python3的动态库目录‘python_3.8.2/lib’。 58 | 59 | 如果您需要其他版本的sophon-sail,可以参考上一小节,下载源码自己编译,参考[sail交叉编译方法](https://doc.sophgo.com/sdk-docs/v23.07.01/docs_latest_release/docs/sophon-sail/docs/zh/html/1_build.html#id5)。 60 | 61 | 下载SOPHON-SAIL源码,解压后进入其源码目录。通过此方式编译出来的SAIL无法使用其Decoder、Bmcv等多媒体相关接口。 62 | 63 | 创建编译文件夹build,并进入build文件夹 64 | ```bash 65 | mkdir build && cd build 66 | ``` 67 | 执行编译命令 68 | ```bash 69 | cmake -DBUILD_TYPE=soc \ 70 | -DONLY_RUNTIME=ON \ 71 | -DCMAKE_TOOLCHAIN_FILE=../cmake/BM168x_SOC/ToolChain_aarch64_linux.cmake \ 72 | -DPYTHON_EXECUTABLE=python_3.8.2/bin/python3 \ 73 | -DCUSTOM_PY_LIBDIR=python_3.8.2/lib \ 74 | -DLIBSOPHON_BASIC_PATH=libsophon_soc_0.4.1_aarch64/opt/sophon/libsophon-0.4.1 .. 
75 | make pysail 76 | ``` 77 | 打包生成python wheel,生成的wheel包的路径为‘python/soc/dist’,文件名为‘sophon_arm-3.7.0-py3-none-any.whl’ 78 | ```bash 79 | cd ../python/soc 80 | chmod +x sophon_soc_whl.sh 81 | ./sophon_soc_whl.sh 82 | ``` 83 | 安装python wheel 84 | 85 | 将‘sophon_arm-3.7.0-py3-none-any.whl’拷贝到目标SOC上,然后执行如下安装命令 86 | ```bash 87 | pip3 install sophon_arm-3.7.0-py3-none-any.whl --force-reinstall 88 | ``` 89 | 90 | ## 预编译的whl包安装 91 | 92 | 目前大多设备python版本为3.8.2,可直接下载此whl包 93 | 94 | ```bash 95 | pip3 install dfss -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade 96 | python3 -m dfss --url=open@sophgo.com:sophon-demo/ChatGLM3/sail/soc/sophon_arm-3.7.0-py3-none-any.whl 97 | ``` 98 | 99 | 下面提供了在不同版本python、libsophon、sophon-mw下预编译好的whl包,用户可根据自己的机器环境选择安装。可通过`ls /opt/sophon`查看运行环境各sdk版本 100 | ```bash 101 | pip3 install dfss -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade 102 | python3 -m dfss --url=open@sophgo.com:ezoo/chatdoc/python_wheels.zip 103 | unzip python_wheels.zip 104 | ``` 105 | 106 | 文件目录如下 107 | ``` 108 | python_wheels 109 | ├── arm_pcie 110 | │ ├── libsophon-0.4.4_sophonmw-0.5.1 111 | │ │ ├── py310 112 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 113 | │ │ ├── py35 114 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 115 | │ │ ├── py36 116 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 117 | │ │ ├── py37 118 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 119 | │ │ ├── py38 120 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 121 | │ │ └── py39 122 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 123 | │ ├── libsophon-0.4.6_sophonmw-0.6.0 124 | │ │ ├── py310 125 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 126 | │ │ ├── py35 127 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 128 | │ │ ├── py36 129 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 130 | │ │ ├── py37 131 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 132 | │ │ ├── py38 133 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 134 | │ │ └── py39 135 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 136 | │ ├── libsophon-0.4.8_sophonmw-0.6.2 137 | │ │ ├── py310 138 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 139 | │ │ ├── py35 140 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 141 | │ │ ├── py36 142 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 143 | │ │ ├── py37 144 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 145 | │ │ ├── py38 146 | │ │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 147 | │ │ └── py39 148 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 149 | │ └── libsophon-0.4.9_sophonmw-0.7.0 150 | │ ├── py310 151 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 152 | │ ├── py35 153 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 154 | │ ├── py36 155 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 156 | │ ├── py37 157 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 158 | │ ├── py38 159 | │ │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 160 | │ └── py39 161 | │ └── sophon_arm_pcie-3.7.0-py3-none-any.whl 162 | ├── loongarch 163 | │ ├── libsophon-0.4.8_runtime 164 | │ │ ├── py310 165 | │ │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 166 | │ │ ├── py37 167 | │ │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 168 | │ │ ├── py38 169 | │ │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 170 | │ │ └── py39 171 | │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 172 | │ └── libsophon-0.4.9_sophonmw-0.7.0 173 | │ ├── py310 174 | │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 175 | │ ├── py37 176 | │ │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 177 | │ ├── py38 178 | │ 
│ └── sophon_loongarch64-3.7.0-py3-none-any.whl 179 | │ └── py39 180 | │ └── sophon_loongarch64-3.7.0-py3-none-any.whl 181 | ├── soc_BM1684_BM1684X 182 | │ ├── libsophon-0.4.4_sophonmw-0.5.1 183 | │ │ ├── py310 184 | │ │ │ └── sophon_arm-3.7.0-py3-none-any.whl 185 | │ │ └── py38 186 | │ │ └── sophon_arm-3.7.0-py3-none-any.whl 187 | │ ├── libsophon-0.4.6_sophonmw-0.6.0 188 | │ │ ├── py310 189 | │ │ │ └── sophon_arm-3.7.0-py3-none-any.whl 190 | │ │ └── py38 191 | │ │ └── sophon_arm-3.7.0-py3-none-any.whl 192 | │ ├── libsophon-0.4.8_sophonmw-0.6.2 193 | │ │ ├── py310 194 | │ │ │ └── sophon_arm-3.7.0-py3-none-any.whl 195 | │ │ └── py38 196 | │ │ └── sophon_arm-3.7.0-py3-none-any.whl 197 | │ └── libsophon-0.4.9_sophonmw-0.7.0 198 | │ ├── py310 199 | │ │ └── sophon_arm-3.7.0-py3-none-any.whl 200 | │ └── py38 201 | │ └── sophon_arm-3.7.0-py3-none-any.whl 202 | └── soc_BM1688 203 | └── libsophon-0.4.9_sophonmw-1.2.0 204 | ├── py310 205 | │ └── sophon_arm-3.7.0-py3-none-any.whl 206 | └── py38 207 | └── sophon_arm-3.7.0-py3-none-any.whl 208 | ``` 209 | 210 | 选择适合运行环境的版本安装,例如 211 | ```bash 212 | pip3 install sophon_arm-3.7.0-py3-none-any.whl --force-reinstall 213 | ``` 214 | -------------------------------------------------------------------------------- /embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 8 | # 9 | #===----------------------------------------------------------------------===# 10 | from .embedding import Word2VecEmbedding -------------------------------------------------------------------------------- /embedding/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 8 | # 9 | #===----------------------------------------------------------------------===# 10 | from langchain.embeddings.base import Embeddings 11 | from typing import List 12 | from .sentence_model import SentenceModel 13 | 14 | 15 | class Word2VecEmbedding(Embeddings): 16 | model = SentenceModel() 17 | 18 | def embed_query(self, text: str) -> List[float]: 19 | embeddings_tpu = self.model.encode_tpu([text, "", "", ""]) 20 | return embeddings_tpu.tolist()[0] 21 | 22 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 23 | embeddings_tpu = self.model.encode_tpu(texts) 24 | return embeddings_tpu.tolist() 25 | 26 | -------------------------------------------------------------------------------- /embedding/npuengine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 
8 | # 9 | #===----------------------------------------------------------------------===# 10 | import logging 11 | import sophon.sail as sail 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 13 | 14 | class EngineOV: 15 | def __init__(self, model_path="./models/bert_model/bge_large_512_fp16.bmodel", device_id=0) : 16 | self.net = sail.Engine(model_path, device_id, sail.IOMode.SYSIO) 17 | logging.info("load {} success, dev_id {}".format(model_path, device_id)) 18 | self.graph_name = self.net.get_graph_names()[0] 19 | self.input_names = self.net.get_input_names(self.graph_name) 20 | self.output_names = self.net.get_output_names(self.graph_name) 21 | 22 | 23 | def __call__(self, input_ids, attention_mask, token_type_ids): 24 | input_data = {self.input_names[0]: input_ids, 25 | self.input_names[1]: attention_mask, 26 | self.input_names[2]: token_type_ids} 27 | outputs = self.net.process(self.graph_name, input_data) 28 | return outputs[self.output_names[0]] 29 | 30 | -------------------------------------------------------------------------------- /embedding/sentence_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #===----------------------------------------------------------------------===# 3 | # 4 | # Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. 5 | # 6 | # SOPHON-DEMO is licensed under the 2-Clause BSD License except for the 7 | # third-party components. 8 | # 9 | #===----------------------------------------------------------------------===# 10 | from enum import Enum 11 | from typing import List, Union, Optional 12 | import numpy as np 13 | import torch 14 | from tqdm.autonotebook import trange 15 | from transformers import AutoTokenizer 16 | import configparser 17 | import logging 18 | import os 19 | 20 | from .npuengine import EngineOV 21 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 22 | 23 | 24 | class EncoderType(Enum): 25 | FIRST_LAST_AVG = 0 26 | LAST_AVG = 1 27 | CLS = 2 28 | POOLER = 3 29 | MEAN = 4 30 | 31 | def __str__(self): 32 | return self.name 33 | 34 | @staticmethod 35 | def from_string(s): 36 | try: 37 | return EncoderType[s] 38 | except KeyError: 39 | raise ValueError() 40 | 41 | 42 | class SentenceModel: 43 | def __init__( 44 | self, 45 | model_name_or_path: str = "shibing624/text2vec-base-chinese", 46 | encoder_type: Union[str, EncoderType] = "MEAN", 47 | max_seq_length: int = 256, 48 | device: Optional[str] = None, 49 | ): 50 | """ 51 | Initializes the base sentence model. 52 | 53 | :param model_name_or_path: The name of the model to load from the huggingface models library. 54 | :param encoder_type: The type of encoder to use, See the EncoderType enum for options: 55 | FIRST_LAST_AVG, LAST_AVG, CLS, POOLER(cls + dense), MEAN(mean of last_hidden_state) 56 | :param max_seq_length: The maximum sequence length. 57 | :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if GPU. 58 | 59 | bert model: https://huggingface.co/transformers/model_doc/bert.html?highlight=bert#transformers.BertModel.forward 60 | BERT return: , [hidden_states, attentions] 61 | Note that: in doc, it says is better semantic summery than . 62 | thus, we use . 
63 | """ 64 | config = configparser.ConfigParser() 65 | config.read('config.ini') 66 | bmodel_path = config.get('bert_model', 'bmodel_path') 67 | token_path = config.get('bert_model', 'token_path') 68 | dev_id = 0 69 | if os.getenv("DEVICE_ID"): 70 | dev_id = int(os.getenv("DEVICE_ID")) 71 | else: 72 | logging.warning("DEVICE_ID is empty in env var, use default {}".format(dev_id)) 73 | self.model_name_or_path = model_name_or_path 74 | encoder_type = EncoderType.from_string(encoder_type) if isinstance(encoder_type, str) else encoder_type 75 | if encoder_type not in list(EncoderType): 76 | raise ValueError(f"encoder_type must be in {list(EncoderType)}") 77 | self.encoder_type = encoder_type 78 | self.max_seq_length = max_seq_length 79 | self.tokenizer = AutoTokenizer.from_pretrained(token_path) 80 | 81 | self.bert = EngineOV(model_path=bmodel_path, 82 | device_id=dev_id) 83 | self.bert.padding_to = 512 84 | 85 | 86 | def __str__(self): 87 | return f"" 89 | 90 | def get_sentence_embedding_dimension(self): 91 | """ 92 | Get the dimension of the sentence embeddings. 93 | 94 | Returns 95 | ------- 96 | int or None 97 | The dimension of the sentence embeddings, or None if it cannot be determined. 98 | """ 99 | # Use getattr to safely access the out_features attribute of the pooler's dense layer 100 | return getattr(self.bert.pooler.dense, "out_features", None) 101 | 102 | def get_sentence_embeddings_tpu(self, input_ids, attention_mask, token_type_ids=None): 103 | """ 104 | Returns the model output by encoder_type as embeddings. 105 | 106 | Utility function for self.bert() method. 107 | """ 108 | input_ids, attention_mask, token_type_ids = input_ids.numpy(), attention_mask.numpy(), token_type_ids.numpy() 109 | if input_ids.shape[1] > self.bert.padding_to: 110 | input_ids = input_ids[:, :self.bert.padding_to] 111 | attention_mask = attention_mask[:, :self.bert.padding_to] 112 | token_type_ids = token_type_ids[:, :self.bert.padding_to] 113 | elif input_ids.shape[1] < self.bert.padding_to: 114 | input_ids = np.pad(input_ids, 115 | ((0, 0), (0, self.bert.padding_to - input_ids.shape[1])), 116 | mode='constant', constant_values=0) 117 | attention_mask = np.pad(attention_mask, 118 | ((0, 0), (0, self.bert.padding_to - attention_mask.shape[1])), 119 | mode='constant', constant_values=0) 120 | token_type_ids = np.pad(token_type_ids, 121 | ((0, 0), (0, self.bert.padding_to - token_type_ids.shape[1])), 122 | mode='constant', constant_values=0) 123 | model_output = self.bert(input_ids.astype(np.float32), 124 | attention_mask.astype(np.float32), 125 | token_type_ids.astype(np.float32)) 126 | 127 | if self.encoder_type == EncoderType.MEAN: 128 | """ 129 | Mean Pooling - Take attention mask into account for correct averaging 130 | """ 131 | token_embeddings = torch.from_numpy(model_output) # Contains all token embeddings 132 | input_mask_expanded = torch.from_numpy(attention_mask).unsqueeze(-1).expand(token_embeddings.size()).float() 133 | final_encoding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 134 | input_mask_expanded.sum(1), min=1e-9) 135 | return final_encoding # [batch, hid_size] 136 | else: 137 | raise NotImplementedError 138 | 139 | def encode_tpu( 140 | self, 141 | sentences: Union[str, List[str]], 142 | batch_size: int = 32, 143 | show_progress_bar: bool = False, 144 | convert_to_numpy: bool = True, 145 | convert_to_tensor: bool = False, 146 | device: str = None, 147 | normalize_embeddings: bool = False, 148 | max_seq_length: int = None, 149 | ): 150 | """ 151 | Returns the 
embeddings for a batch of sentences. 152 | 153 | :param sentences: str/list, Input sentences 154 | :param batch_size: int, Batch size 155 | :param show_progress_bar: bool, Whether to show a progress bar for the sentences 156 | :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. 157 | :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy 158 | :param device: Which device to use for the computation 159 | :param normalize_embeddings: If true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. 160 | :param max_seq_length: Override value for max_seq_length 161 | """ 162 | if max_seq_length is None: 163 | max_seq_length = self.max_seq_length 164 | if convert_to_tensor: 165 | convert_to_numpy = False 166 | input_is_string = False 167 | if isinstance(sentences, str) or not hasattr(sentences, "__len__"): 168 | sentences = [sentences] 169 | input_is_string = True 170 | 171 | all_embeddings = [] 172 | length_sorted_idx = np.argsort([-len(s) for s in sentences]) 173 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 174 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): 175 | sentences_batch = sentences_sorted[start_index: start_index + batch_size] 176 | # Compute sentences embeddings 177 | with torch.no_grad(): 178 | features = self.tokenizer( 179 | sentences_batch, max_length=max_seq_length, 180 | padding=True, truncation=True, return_tensors='pt' 181 | ) 182 | embeddings = self.get_sentence_embeddings_tpu(**features) 183 | embeddings = embeddings.detach() 184 | if normalize_embeddings: 185 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 186 | 187 | if convert_to_numpy: 188 | embeddings = embeddings.cpu() 189 | all_embeddings.extend(embeddings) 190 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 191 | if convert_to_tensor: 192 | all_embeddings = torch.stack(all_embeddings) 193 | elif convert_to_numpy: 194 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 195 | 196 | if input_is_string: 197 | all_embeddings = all_embeddings[0] 198 | return all_embeddings 199 | 200 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/models/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-cpu==1.7.4 2 | Flask==2.2.2 3 | langchain==0.0.189 4 | numpy==1.24.4 5 | sentencepiece==0.1.99 6 | streamlit==1.25.0 7 | tiktoken==0.4.0 8 | torch==2.0.1 9 | tqdm==4.66.1 10 | transformers==4.27.1 11 | unstructured==0.7.1 12 | rapidocr_onnxruntime==1.3.8 13 | PyMuPDF==1.23.16 14 | pdf2image 15 | tabulate -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | res=$(which unzip) 5 | 6 | if [ $? != 0 ]; 7 | then 8 | echo "Please install unzip on your system!" 
9 | exit 10 | fi 11 | 12 | pip3 install dfss -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade 13 | 14 | # default param 15 | llm_model="chatglm3" 16 | dev_id="0" 17 | server_address="0.0.0.0" 18 | server_port="" 19 | 20 | # Args 21 | parse_args() { 22 | while [[ $# -gt 0 ]]; do 23 | key="$1" 24 | 25 | case $key in 26 | --model) 27 | llm_model="$2" 28 | shift 2 29 | ;; 30 | --dev_id) 31 | dev_id="$2" 32 | shift 2 33 | ;; 34 | --server_address) 35 | server_address="$2" 36 | shift 2 37 | ;; 38 | --server_port) 39 | server_port="$2" 40 | shift 2 41 | ;; 42 | *) 43 | echo "Invalid option: $key" >&2 44 | exit 1 45 | ;; 46 | :) 47 | echo "Option -$OPTARG requires an argument." >&2 48 | exit 1 49 | ;; 50 | esac 51 | done 52 | } 53 | 54 | # Process Args 55 | parse_args "$@" 56 | 57 | 58 | # nltk_data & bert_model are required 59 | if [ ! -d "$HOME/nltk_data" ]; then 60 | echo "$HOME/nltk_data does not exist, download..." 61 | python3 -m dfss --url=open@sophgo.com:ezoo/chatdoc/nltk_data.zip 62 | unzip nltk_data.zip 63 | mv nltk_data ~ 64 | rm nltk_data.zip 65 | echo "nltk_data downloaded!" 66 | else 67 | echo "$HOME/nltk_data already exists..." 68 | fi 69 | 70 | # download bert_model 71 | if [ ! -d "./models/bert_model" ]; then 72 | echo "./models/bert_model does not exist, download..." 73 | python3 -m dfss --url=open@sophgo.com:ezoo/chatdoc/bert_model.zip 74 | unzip bert_model.zip -d ./models 75 | rm bert_model.zip 76 | echo "bert_model downloaded!" 77 | else 78 | echo "./models/bert_model already exists..." 79 | fi 80 | 81 | # download LLM models 82 | if [ "$llm_model" == "chatglm3" ]; then 83 | if [ ! -d "./models/glm3_model" ]; then 84 | echo "./models/glm3_model does not exist, download..." 85 | python3 -m dfss --url=open@sophgo.com:ezoo/chatdoc/glm3_model.tar.gz 86 | tar -zxvf glm3_model.tar.gz -C ./models 87 | rm glm3_model.tar.gz 88 | echo "glm3_model downloaded!" 89 | else 90 | echo "./models/glm3_model already exists..." 91 | fi 92 | elif [ "$llm_model" == "qwen" ]; then 93 | if [ ! -d "./models/qwen_model" ]; then 94 | echo "./models/qwen_model does not exist, download..." 95 | python3 -m dfss --url=open@sophgo.com:ezoo/chatdoc/qwen_model.tar.gz 96 | tar -zxvf qwen_model.tar.gz -C ./models 97 | rm qwen_model.tar.gz 98 | echo "qwen_model downloaded!" 99 | else 100 | echo "./models/qwen_model already exists..." 101 | fi 102 | else 103 | echo "Error: --model is not recognized. Must be 'chatglm3' or 'qwen'." 
104 | exit 1 105 | fi 106 | 107 | 108 | export LLM_MODEL=$llm_model 109 | export DEVICE_ID=$dev_id 110 | 111 | if [ "$server_port" == "" ]; then 112 | # auto server port 113 | streamlit run web_demo_st.py --server.address $server_address 114 | else 115 | streamlit run web_demo_st.py --server.address $server_address --server.port $server_port 116 | fi -------------------------------------------------------------------------------- /scripts/compile.sh: -------------------------------------------------------------------------------- 1 | model_transform.py \ 2 | --model_name bge_large_512 \ 3 | --model_def text2vec-bge-large-chinese.onnx \ 4 | --input_shapes [[4,512],[4,512],[4,512]] \ 5 | --mlir bge_large_512.mlir 6 | 7 | model_deploy.py \ 8 | --mlir bge_large_512.mlir \ 9 | --quantize F16 \ 10 | --chip bm1684x \ 11 | --model bge_large_512_fp16.bmodel \ 12 | --compare_all \ 13 | --debug 14 | -------------------------------------------------------------------------------- /scripts/export_onnx.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | import torch 3 | import numpy as np 4 | 5 | # Mean Pooling - Take attention mask into account for correct averaging 6 | def mean_pooling(model_output, attention_mask): 7 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings 8 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 9 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 10 | 11 | # Load model from HuggingFace Hub 12 | tokenizer = BertTokenizer.from_pretrained('shibing624/text2vec-bge-large-chinese') 13 | model = BertModel.from_pretrained('shibing624/text2vec-bge-large-chinese') 14 | sentences = ['如何更换花呗绑定银行卡'] 15 | # Tokenize sentences 16 | encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') 17 | 18 | 19 | # model.eval() 20 | input_ids = encoded_input['input_ids'] 21 | token_type_ids = encoded_input['token_type_ids'] 22 | attention_mask = encoded_input['attention_mask'] 23 | 24 | input_ids, attention_mask, token_type_ids = input_ids.numpy(), attention_mask.numpy(), token_type_ids.numpy() 25 | if input_ids.shape[1] > 512: 26 | input_ids = input_ids[:, :512] 27 | attention_mask = attention_mask[:, :512] 28 | token_type_ids = token_type_ids[:, :512] 29 | elif input_ids.shape[1] < 512: 30 | input_ids = np.pad(input_ids, 31 | ((0, 0), (0, 512 - input_ids.shape[1])), 32 | mode='constant', constant_values=0) 33 | attention_mask = np.pad(attention_mask, 34 | ((0, 0), (0, 512 - attention_mask.shape[1])), 35 | mode='constant', constant_values=0) 36 | token_type_ids = np.pad(token_type_ids, 37 | ((0, 0), (0, 512 - token_type_ids.shape[1])), 38 | mode='constant', constant_values=0) 39 | input_ids = torch.tensor(input_ids) 40 | token_type_ids = torch.tensor(token_type_ids) 41 | attention_mask = torch.tensor(attention_mask) 42 | 43 | # Compute token embeddings 44 | with torch.no_grad(): 45 | model_output = model(input_ids, attention_mask, token_type_ids) 46 | 47 | print(input_ids) 48 | print(token_type_ids) 49 | print(attention_mask) 50 | 51 | torch.onnx.export(model, (input_ids,attention_mask,token_type_ids), "text2vec-bge-large-chinese.onnx", input_names=['input_ids', 'attention_mask', 'token_type_ids'],dynamic_axes={'input_ids': {0: 'batch'}, 'attention_mask': {0: 'batch'}, 'token_type_ids': {0: 'batch'}}) 52 | # Perform pooling. 
In this case, mean pooling. 53 | sentence_embeddings = mean_pooling(model_output, attention_mask) 54 | print("Sentence embeddings:") 55 | print(sentence_embeddings) 56 | -------------------------------------------------------------------------------- /static/embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/static/embedding.png -------------------------------------------------------------------------------- /static/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/static/img1.png -------------------------------------------------------------------------------- /static/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyifan2018/ChatDoc-TPU/1559fce5bcc972b6fee49a211beeaa66e26b37f1/static/img2.png -------------------------------------------------------------------------------- /web_demo_st.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from chat import DocChatbot 3 | import os 4 | import streamlit as st 5 | import time 6 | import sys 7 | import logging 8 | sys.path.append(".") 9 | sys.path.append(os.path.join(os.path.dirname(__file__), 'doc_processor')) 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 11 | 12 | 13 | @st.cache_resource 14 | def load_model(): 15 | return DocChatbot.get_instance() 16 | 17 | 18 | chatbot_st = load_model() 19 | 20 | # TODO: use glm2 format and hard code now, new to opt 21 | def cut_history(u_input): 22 | if 'messages' not in st.session_state: 23 | return [] 24 | 25 | history = [] 26 | for idx in range(0, len(st.session_state['messages']), 2): 27 | history.append((st.session_state['messages'][idx]['content'], st.session_state['messages'][idx + 1]['content'])) 28 | 29 | spilt = -1 30 | if len(history) > 0: 31 | while spilt >= -len(history): 32 | prompt_str = ["[Round {}]\n\n问:{}\n\n答:{}\n\n".format(idx + 1, x[0], x[1]) for idx, x in 33 | enumerate(history[spilt:])] 34 | if len("".join(prompt_str) + u_input) < 450: 35 | spilt -= 1 36 | else: 37 | spilt += 1 38 | break 39 | else: 40 | spilt = 0 41 | 42 | if spilt != 0: 43 | history = history[spilt:] 44 | 45 | return history 46 | 47 | 48 | with st.sidebar: 49 | st.title("💬 ChatDoc-TPU") 50 | st.write("上传一个文档,然后与我对话.") 51 | with st.form("Upload and Process", True): 52 | uploaded_file = st.file_uploader("上传文档", type=None, accept_multiple_files=True, help = "目前支持多种文件格式,包括pdf、ppt、html、图片、表格等" 53 | ) 54 | 55 | option = st.selectbox( 56 | "选择已保存的知识库", 57 | chatbot_st.get_vector_db(), 58 | format_func=lambda x: chatbot_st.time2file_name(x) 59 | ) 60 | 61 | col1, col2 = st.columns(2) 62 | with col1: 63 | import_repository = st.form_submit_button("导入知识库") 64 | with col2: 65 | add_repository = st.form_submit_button("添加知识库") 66 | col3, col4 = st.columns(2) 67 | with col3: 68 | save_repository = st.form_submit_button("保存知识库") 69 | with col4: 70 | del_repository = st.form_submit_button("删除知识库") 71 | 72 | col5, col6 = st.columns(2) 73 | with col5: 74 | clear = st.form_submit_button("清除聊天记录") 75 | with col6: 76 | clear_file = st.form_submit_button("移除选中文档") 77 | 78 | if st.form_submit_button("重命名知识库") or 'renaming' in st.session_state: 79 | if len([x for x in 
chatbot_st.get_vector_db()]) == 0: 80 | st.error("无可选的本地知识库。") 81 | st.stop() 82 | 83 | st.session_state['renaming'] = True 84 | title = st.text_input('新知识库名称') 85 | if st.form_submit_button("确认重命名"): 86 | if title == "": 87 | st.error("请输入新的知识库名称。") 88 | else: 89 | chatbot_st.rename(option, title) 90 | st.success("重命名成功。") 91 | del st.session_state["renaming"] 92 | time.sleep(0.1) 93 | st.experimental_rerun() 94 | 95 | if save_repository and 'files' not in st.session_state: 96 | st.error("先上传文件构建知识库,才能保存知识库。") 97 | 98 | if not uploaded_file and add_repository: 99 | st.error("请先上传文件,再点击构建知识库。") 100 | 101 | if import_repository and len([x for x in chatbot_st.get_vector_db()]) == 0: 102 | st.error("无可选的本地知识库。") 103 | 104 | if clear: 105 | if 'files' not in st.session_state: 106 | if "messages" in st.session_state: 107 | del st.session_state["messages"] 108 | else: 109 | st.session_state["messages"] = [{"role": "assistant", "content": "嗨!"}] 110 | 111 | if clear_file: 112 | if 'files' in st.session_state: 113 | del st.session_state["files"] 114 | if 'messages' in st.session_state: 115 | del st.session_state["messages"] 116 | 117 | if uploaded_file and add_repository: 118 | with st.spinner("Initializing vector db..."): 119 | files_name = [] 120 | for i, item in enumerate(uploaded_file): 121 | ext_name = os.path.splitext(item.name)[-1] 122 | file_name = f"""./data/uploaded/{item.name}""" 123 | with open(file_name, "wb") as f: 124 | f.write(item.getbuffer()) 125 | f.close() 126 | files_name.append(file_name) 127 | if chatbot_st.init_vector_db_from_documents(files_name): 128 | if 'files' in st.session_state: 129 | st.session_state['files'] = st.session_state['files'] + files_name 130 | else: 131 | st.session_state['files'] = files_name 132 | 133 | st.session_state["messages"] = [{"role": "assistant", "content": "嗨!"}] 134 | st.success('知识库添加完成!', icon='🎉') 135 | st.balloons() 136 | else: 137 | st.error("文件解析失败!") 138 | 139 | if save_repository and 'files' in st.session_state: 140 | chatbot_st.save_vector_db_to_local() 141 | st.success('知识库保存成功!', icon='🎉') 142 | st.experimental_rerun() 143 | 144 | if import_repository and option: 145 | chatbot_st.load_vector_db_from_local(option) 146 | st.session_state["messages"] = [{"role": "assistant", "content": "嗨!"}] 147 | st.success('知识库导入完成!', icon='🎉') 148 | st.session_state['files'] = chatbot_st.time2file_name(option).split(", ") 149 | st.balloons() 150 | 151 | if del_repository and option: 152 | chatbot_st.del_vector_db(option) 153 | st.success('知识库删除完成!', icon='🎉') 154 | st.experimental_rerun() 155 | 156 | if 'files' in st.session_state: 157 | st.markdown("\n".join([str(i + 1) + ". 
" + x.split("/")[-1] for i, x in enumerate(st.session_state.files)])) 158 | else: 159 | st.info( 160 | '点击Browse files选择要上传的文档,然后点击添加知识库按钮构建知识库。或者选择已持久化的知识库,然后点击导入知识库按钮导入知识库。', 161 | icon="ℹ️") 162 | 163 | if 'messages' in st.session_state: 164 | for msg in st.session_state.messages: 165 | st.chat_message(msg["role"]).write(msg["content"]) 166 | 167 | if user_input := st.chat_input(): 168 | # import pdb;pdb.set_trace() 169 | if 'files' not in st.session_state: 170 | his = cut_history(user_input) 171 | if 'messages' not in st.session_state: 172 | st.session_state["messages"] = [{"role": "user", "content": user_input}] 173 | else: 174 | st.session_state["messages"].append({"role": "user", "content": user_input}) 175 | st.chat_message("user").write(user_input) 176 | with st.chat_message("assistant"): 177 | answer_container = st.empty() 178 | for result_answer, _ in chatbot_st.llm.stream_predict(user_input, his): 179 | answer_container.markdown(result_answer) 180 | st.session_state["messages"].append({"role": "assistant", "content": result_answer}) 181 | else: 182 | st.session_state["messages"].append({"role": "user", "content": user_input}) 183 | st.chat_message("user").write(user_input) 184 | with st.chat_message("assistant"): 185 | answer_container = st.empty() 186 | start_time = time.time() 187 | docs = chatbot_st.query_from_doc(user_input, 3) 188 | logging.info("Total query time {}".format(time.time() - start_time)) 189 | refer = "\n".join([x.page_content.replace("\n", '\t') for x in docs]) 190 | PROMPT = """{}。\n请根据下面的参考文档回答上述问题。\n{}\n""" 191 | prompt = PROMPT.format(user_input, refer) 192 | 193 | for result_answer, _ in chatbot_st.llm.stream_predict(prompt, []): 194 | answer_container.markdown(result_answer) 195 | 196 | with st.expander("References"): 197 | for i, doc in enumerate(docs): 198 | source_str = os.path.basename(doc.metadata["source"]) if "source" in doc.metadata else "" 199 | page_str = doc.metadata['page'] + 1 if "page" in doc.metadata else "" 200 | st.write(f"""### Reference [{i + 1}] {source_str} P{page_str}""") 201 | st.write(doc.page_content) 202 | i += 1 203 | 204 | st.session_state["messages"].append({"role": "assistant", "content": result_answer}) 205 | --------------------------------------------------------------------------------