├── .gitignore ├── CyberFriend_LLM_core ├── ChromaRag.py ├── api_server.py ├── download.py ├── finetune │ ├── configs │ │ ├── ds_zero_2.json │ │ ├── ds_zero_3.json │ │ ├── lora.yaml │ │ ├── ptuning_v2.yaml │ │ └── sft.yaml │ ├── finetune_hf.py │ └── requirement.txt ├── requirements.txt └── utils.py ├── CyberFriend_bot_plugin ├── .env.prod ├── GetPathUtil.py ├── common │ ├── CustomChecker.py │ ├── MembersOptUtil.py │ ├── MessageBuilder.py │ └── __init__.py ├── plugins │ ├── __init__.py │ ├── add_image_to_db │ │ ├── __init__.py │ │ └── config.py │ ├── cyber_friend │ │ ├── __init__.py │ │ ├── config.py │ │ ├── prompt.txt │ │ └── utils.py │ ├── group_handle │ │ ├── __init__.py │ │ └── config.py │ ├── member_join │ │ ├── __init__.py │ │ └── config.py │ ├── member_leave │ │ ├── __init__.py │ │ └── config.py │ ├── message_record │ │ ├── ImageUtil.py │ │ ├── __init__.py │ │ ├── config.py │ │ ├── get_record.py │ │ └── util.py │ ├── scheduler │ │ ├── __init__.py │ │ └── config.py │ └── update_members │ │ ├── MembersUtil.py │ │ ├── __init__.py │ │ └── config.py ├── pyproject.toml ├── record_data │ ├── create_dataset.py │ ├── get_records.py │ └── query_number.py └── requirements.txt ├── LICENSE ├── finetune_and_restart.sh ├── readme.md ├── resources ├── code_structure.md └── proj_structure.png └── run.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | *.db 162 | *.json 163 | /CyberFriend_LLM_core/finetune_demo/output* 164 | /CyberFriend_LLM_core/finetune_demo/output/ 165 | /CyberFriend_LLM_core/finetune_demo/data/ 166 | /CyberFriend_LLM_core/finetune_demo/wandb/ 167 | /CyberFriend_LLM_core/finetune_demo/code/ 168 | chatglm3-6b 169 | .DS_store -------------------------------------------------------------------------------- /CyberFriend_LLM_core/ChromaRag.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import hashlib 3 | import importlib 4 | import os 5 | import sys 6 | import uuid 7 | from typing import Literal, List, Dict, Tuple 8 | 9 | import chardet 10 | import chromadb 11 | import langchain 12 | import numpy as np 13 | from chromadb import QueryResult 14 | from langchain_community.embeddings import HuggingFaceEmbeddings 15 | from langchain_core.documents import Document 16 | from nonebot import logger 17 | 18 | sys.path.append(os.path.join(os.path.dirname(__file__), '../CyberFriend_bot_plugin')) 19 | from CyberFriend_bot_plugin.GetPathUtil import getPath 20 | 21 | import threading 22 | 23 | EMBEDDING_MODEL = "bge-large-zh" 24 | 25 | TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter" 26 | 27 | # Maximum length of a single text chunk in the knowledge base (not used by MarkdownHeaderTextSplitter) 28 | CHUNK_SIZE = 250 29 | # Overlap length between adjacent chunks in the knowledge base (not used by MarkdownHeaderTextSplitter) 30 | OVERLAP_SIZE = 50 31 | # Number of matching vectors returned from the knowledge base 32 | VECTOR_SEARCH_TOP_K = 3 33 | # Distance threshold for knowledge-base matches, usually in the range 0-1; the smaller the score, the smaller the distance and thus the higher the relevance. 34 | # Some users have reported match scores above 1, so the default is set to 1 for compatibility; the WEBUI allows adjusting it within 0-2 35 | SCORE_THRESHOLD = 1.0 36 | 37 | # TextSplitter configuration; do not modify these options unless you understand what they mean. 38 | text_splitter_dict = { 39 | "ChineseRecursiveTextSplitter": { 40 | "source": "huggingface", # choose "tiktoken" to use OpenAI's method instead 41 | "tokenizer_name_or_path": "", 42 | }, 43 |
"SpacyTextSplitter": { 44 | "source": "huggingface", 45 | "tokenizer_name_or_path": "gpt2", 46 | }, 47 | "RecursiveCharacterTextSplitter": { 48 | "source": "tiktoken", 49 | "tokenizer_name_or_path": "cl100k_base", 50 | }, 51 | "MarkdownHeaderTextSplitter": { 52 | "headers_to_split_on": 53 | [ 54 | ("#", "head1"), 55 | ("##", "head2"), 56 | ("###", "head3"), 57 | ("####", "head4"), 58 | ] 59 | }, 60 | } 61 | 62 | 63 | def detect_device() -> Literal["cuda", "mps", "cpu"]: 64 | try: 65 | import torch 66 | if torch.cuda.is_available(): 67 | return "cuda" 68 | if torch.backends.mps.is_available(): 69 | return "mps" 70 | except: 71 | pass 72 | return "cpu" 73 | 74 | 75 | def normalize(embeddings: List[List[float]]) -> np.ndarray: 76 | ''' 77 | sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn 78 | ''' 79 | norm = np.linalg.norm(embeddings, axis=1) 80 | norm = np.reshape(norm, (norm.shape[0], 1)) 81 | norm = np.tile(norm, (1, len(embeddings[0]))) 82 | return np.divide(embeddings, norm) 83 | 84 | 85 | def encrypt(fpath: str, algorithm: str = "md5") -> str: 86 | hash_algorithm = None 87 | if algorithm is not None and isinstance(algorithm, str): 88 | algorithm = algorithm.lower() 89 | 90 | if algorithm == 'md5': 91 | hash_algorithm = hashlib.md5() 92 | elif algorithm == 'sha1': 93 | hash_algorithm = hashlib.sha1() 94 | elif algorithm == 'sha256': 95 | hash_algorithm = hashlib.sha256() 96 | else: 97 | raise ValueError("unsupported hash algorithm") 98 | # 以二进制模式打开文件 99 | with open(fpath, 'rb') as f: 100 | # 分块读取文件内容 101 | for chunk in iter(lambda: f.read(2 ** 12), b''): 102 | # 更新散列值 103 | hash_algorithm.update(chunk) 104 | # 返回十六进制字符串 105 | return hash_algorithm.hexdigest() 106 | 107 | 108 | def encryptText(texts: List[str], algorithm: str = "md5") -> List[str]: 109 | hash_algorithm = None 110 | if algorithm is not None and isinstance(algorithm, str): 111 | algorithm = algorithm.lower() 112 | ans = [] 113 | for text in texts: 114 | if algorithm == 'md5': 115 | hash_algorithm = hashlib.md5() 116 | elif algorithm == 'sha1': 117 | hash_algorithm = hashlib.sha1() 118 | elif algorithm == 'sha256': 119 | hash_algorithm = hashlib.sha256() 120 | else: 121 | raise ValueError("unsupported hash algorithm") 122 | hash_algorithm.update(text.encode()) 123 | ans.append(hash_algorithm.hexdigest()) 124 | return ans 125 | 126 | 127 | class WordEmbeddingModel: 128 | _instance = None 129 | _lock = threading.Lock() 130 | 131 | def __new__(cls, *args, **kwargs): 132 | if cls._instance is None: 133 | with cls._lock: 134 | if cls._instance is None: 135 | cls._instance = super().__new__(cls) 136 | return cls._instance 137 | 138 | def __init__(self, embed_model=EMBEDDING_MODEL, device=detect_device()): 139 | if not hasattr(self, 'initialized'): 140 | self.model = HuggingFaceEmbeddings(model_name=embed_model, 141 | model_kwargs={'device': device}) 142 | self.initialized = True 143 | 144 | 145 | def make_text_splitter( 146 | splitter_name: str = TEXT_SPLITTER_NAME, 147 | chunk_size: int = CHUNK_SIZE, 148 | chunk_overlap: int = OVERLAP_SIZE, 149 | ): 150 | """ 151 | 根据参数获取特定的分词器 152 | """ 153 | splitter_name = splitter_name or "SpacyTextSplitter" 154 | try: 155 | if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 156 | headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] 157 | text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter( 158 | headers_to_split_on=headers_to_split_on) 159 | else: 160 | 161 | try: ## 优先使用用户自定义的text_splitter 162 | 
text_splitter_module = importlib.import_module('text_splitter') 163 | TextSplitter = getattr(text_splitter_module, splitter_name) 164 | except Exception: ## otherwise fall back to langchain's text_splitter 165 | text_splitter_module = importlib.import_module('langchain.text_splitter') 166 | TextSplitter = getattr(text_splitter_module, splitter_name) 167 | 168 | if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## load from tiktoken 169 | try: 170 | text_splitter = TextSplitter.from_tiktoken_encoder( 171 | encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 172 | pipeline="zh_core_web_sm", 173 | chunk_size=chunk_size, 174 | chunk_overlap=chunk_overlap 175 | ) 176 | except Exception: 177 | text_splitter = TextSplitter.from_tiktoken_encoder( 178 | encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 179 | chunk_size=chunk_size, 180 | chunk_overlap=chunk_overlap 181 | ) 182 | elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## load from huggingface 183 | if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": 184 | text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = EMBEDDING_MODEL 185 | 186 | if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": 187 | from transformers import GPT2TokenizerFast 188 | from langchain.text_splitter import CharacterTextSplitter 189 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 190 | else: ## split by character length 191 | from transformers import AutoTokenizer 192 | tokenizer = AutoTokenizer.from_pretrained( 193 | text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 194 | trust_remote_code=True) 195 | text_splitter = TextSplitter.from_huggingface_tokenizer( 196 | tokenizer=tokenizer, 197 | chunk_size=chunk_size, 198 | chunk_overlap=chunk_overlap 199 | ) 200 | else: 201 | try: 202 | text_splitter = TextSplitter( 203 | pipeline="zh_core_web_sm", 204 | chunk_size=chunk_size, 205 | chunk_overlap=chunk_overlap 206 | ) 207 | except Exception: 208 | text_splitter = TextSplitter( 209 | chunk_size=chunk_size, 210 | chunk_overlap=chunk_overlap 211 | ) 212 | except Exception as e: 213 | print(e) 214 | text_splitter_module = importlib.import_module('langchain.text_splitter') 215 | TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") 216 | text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) 217 | 218 | # If you use SpacyTextSplitter you can use a GPU to do the splitting, as in Issue #1287 219 | # text_splitter._tokenizer.max_length = 37016792 220 | # text_splitter._tokenizer.prefer_gpu() 221 | return text_splitter 222 | 223 | 224 | LOADER_DICT = {"UnstructuredHTMLLoader": ['.html', '.htm'], 225 | "MHTMLLoader": ['.mhtml'], 226 | "UnstructuredMarkdownLoader": ['.md'], 227 | "JSONLoader": [".json"], 228 | "JSONLinesLoader": [".jsonl"], 229 | "CSVLoader": [".csv"], 230 | # "FilteredCSVLoader": [".csv"], if using custom CSV splitting 231 | "RapidOCRPDFLoader": [".pdf"], 232 | "RapidOCRDocLoader": ['.docx', '.doc'], 233 | "RapidOCRPPTLoader": ['.ppt', '.pptx', ], 234 | "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], 235 | "UnstructuredFileLoader": ['.eml', '.msg', '.rst', 236 | '.rtf', '.txt', '.xml', 237 | '.epub', '.odt', '.tsv'], 238 | "UnstructuredEmailLoader": ['.eml', '.msg'], 239 | "UnstructuredEPubLoader": ['.epub'], 240 | "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], 241 | "NotebookLoader": ['.ipynb'], 242 | "UnstructuredODTLoader": ['.odt'], 243 | "PythonLoader": ['.py'], 244 | "UnstructuredRSTLoader": ['.rst'], 245 | "UnstructuredRTFLoader": ['.rtf'], 246 | "SRTLoader":
['.srt'], 247 | "TomlLoader": ['.toml'], 248 | "UnstructuredTSVLoader": ['.tsv'], 249 | "UnstructuredWordDocumentLoader": ['.docx', '.doc'], 250 | "UnstructuredXMLLoader": ['.xml'], 251 | "UnstructuredPowerPointLoader": ['.ppt', '.pptx'], 252 | "EverNoteLoader": ['.enex'], 253 | } 254 | 255 | 256 | def get_LoaderClass(file_extension): 257 | for LoaderClass, extensions in LOADER_DICT.items(): 258 | if file_extension in extensions: 259 | return LoaderClass 260 | 261 | 262 | SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] 263 | 264 | 265 | def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): 266 | ''' 267 | Return a document loader for the given loader_name and file path or content. 268 | ''' 269 | loader_kwargs = loader_kwargs or {} 270 | try: 271 | if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader", 272 | "RapidOCRDocLoader", "RapidOCRPPTLoader"]: 273 | document_loaders_module = importlib.import_module('document_loaders') 274 | else: 275 | document_loaders_module = importlib.import_module('langchain_community.document_loaders') 276 | DocumentLoader = getattr(document_loaders_module, loader_name) 277 | except Exception as e: 278 | msg = f"error while looking up loader {loader_name} for file {file_path}: {e}" 279 | logger.error(f'{e.__class__.__name__}: {msg}', 280 | exc_info=e) 281 | document_loaders_module = importlib.import_module('langchain_community.document_loaders') 282 | DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") 283 | 284 | if loader_name == "UnstructuredFileLoader": 285 | loader_kwargs.setdefault("autodetect_encoding", True) 286 | elif loader_name == "CSVLoader": 287 | if not loader_kwargs.get("encoding"): 288 | # If encoding is not specified, auto-detect the file encoding to avoid encoding errors when the langchain loader reads the file 289 | with open(file_path, 'rb') as struct_file: 290 | encode_detect = chardet.detect(struct_file.read()) 291 | if encode_detect is None: 292 | encode_detect = {"encoding": "utf-8"} 293 | loader_kwargs["encoding"] = encode_detect["encoding"] 294 | 295 | elif loader_name == "JSONLoader": 296 | loader_kwargs.setdefault("jq_schema", ".") 297 | loader_kwargs.setdefault("text_content", False) 298 | elif loader_name == "JSONLinesLoader": 299 | loader_kwargs.setdefault("jq_schema", ".") 300 | loader_kwargs.setdefault("text_content", False) 301 | 302 | loader = DocumentLoader(file_path, **loader_kwargs) 303 | return loader 304 | 305 | 306 | class ChromaRag: 307 | 308 | def __init__(self, dbpath=getPath("knowledge_db", "default"), llm=None, embed_model=EMBEDDING_MODEL): 309 | self.client = chromadb.PersistentClient(path=dbpath) 310 | self.defaultCollection = self.client.get_or_create_collection("default") 311 | self.collection = {"default": self.defaultCollection} 312 | self.embed_model = embed_model 313 | self.word_embedding_model = None 314 | 315 | def loadWordEmbeddingModel(self): 316 | self.word_embedding_model = WordEmbeddingModel(embed_model=self.embed_model) 317 | 318 | def embedding(self, texts: List[str]) -> List[List[float]]: 319 | if self.word_embedding_model is None: 320 | self.loadWordEmbeddingModel() 321 | return self.word_embedding_model.model.embed_documents(texts=texts) 322 | 323 | def embeddingDocs(self, docs): 324 | texts = [x.page_content for x in docs] 325 | metadatas = [x.metadata for x in docs] 326 | if self.word_embedding_model is None: 327 | self.loadWordEmbeddingModel() 328 | embeddings = normalize(self.word_embedding_model.model.embed_documents(texts)).tolist() 329 | # embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data 330 | if
embeddings is not None: 331 | return { 332 | "texts": texts, 333 | "embeddings": embeddings, 334 | "metadatas": metadatas, 335 | } 336 | 337 | def search(self, msg, top_k: int = VECTOR_SEARCH_TOP_K, collectionName="default", score_threshold: float = SCORE_THRESHOLD) -> List[ 338 | Tuple[Document, float]]: 339 | embeddings = self.embedding([msg])[0] 340 | collection = self.collection.get(collectionName, self.defaultCollection) 341 | query_result: QueryResult = collection.query(query_embeddings=embeddings, n_results=top_k) 342 | return [ 343 | # TODO: Chroma can do batch querying, 344 | (Document(page_content=result[0], metadata=result[1] or {}), result[2]) 345 | for result in zip( 346 | query_result["documents"][0], 347 | query_result["metadatas"][0], 348 | query_result["distances"][0], 349 | ) 350 | ] 351 | 352 | def add_doc(self, filePath, collectionName="default", loader_kwargs: Dict = {}, text_splitter=TEXT_SPLITTER_NAME, chunk_size=CHUNK_SIZE, 353 | chunk_overlap=OVERLAP_SIZE): 354 | if not os.path.exists(filePath): 355 | raise RuntimeError(f"{filePath} does not exist") 356 | 357 | md5 = encrypt(filePath) 358 | res = self.queryByMd5Bool(md5, collectionName) 359 | if res: 360 | logger.warning(f"{filePath} already exists") 361 | return [] 362 | 363 | fileName = os.path.basename(filePath) 364 | ext = os.path.splitext(filePath)[-1].lower() 365 | document_loader_name = get_LoaderClass(ext) 366 | loader = get_loader(loader_name=document_loader_name, 367 | file_path=filePath, 368 | loader_kwargs=loader_kwargs) 369 | docs = loader.load() 370 | splitter = make_text_splitter(splitter_name=text_splitter, chunk_size=chunk_size, 371 | chunk_overlap=chunk_overlap) 372 | if text_splitter == "MarkdownHeaderTextSplitter": # compare the requested splitter name, not the splitter object 373 | docs = splitter.split_text(docs[0].page_content) 374 | else: 375 | docs = splitter.split_documents(docs) 376 | 377 | if not docs: 378 | raise RuntimeError("failed to split the document") 379 | 380 | for doc in docs: 381 | source = doc.metadata.get("source", "") 382 | if not source or os.path.isabs(source): 383 | doc.metadata["source"] = fileName 384 | doc.metadata["md5"] = md5 385 | 386 | doc_infos = [] 387 | data = self.embeddingDocs(docs) 388 | ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))] 389 | collection = self.collection.get(collectionName, self.defaultCollection) 390 | 391 | for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]): 392 | collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text) 393 | doc_infos.append({"id": _id, "metadata": metadata}) 394 | return doc_infos 395 | 396 | def add_docs(self, filePaths, collectionName="default"): 397 | for i in filePaths: 398 | self.add_doc(i, collectionName) 399 | 400 | def add_text(self, text, collectionName="default", metadata=None): 401 | if metadata is None: 402 | return self.add_texts([text], collectionName, None) 403 | else: 404 | return self.add_texts([text], collectionName, [metadata]) 405 | 406 | def add_texts(self, texts, collectionName="default", metadata=None): 407 | if metadata is None: 408 | metadata = [{"source": "default"} for _ in range(len(texts))] 409 | 410 | md5s = encryptText(texts) 411 | addText = []  # True for each text whose md5 is not already in the collection 412 | for md5, metaD in zip(md5s, metadata): 413 | metaD["md5"] = md5 414 | addText.append(not self.queryByMd5Bool(md5, collectionName)) 415 | doc_infos = [] 416 | embed = self.embedding(texts) 417 | ids = [str(uuid.uuid1()) for _ in range(len(texts))] 418 | collection = self.collection.get(collectionName, self.defaultCollection) 419 | 420 |
for _id, text, embedding, metadata, needAdd in zip(ids, texts, embed, metadata, addText): 421 | if needAdd: 422 | collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text) 423 | doc_infos.append({"id": _id, "metadata": metadata, "result": True}) 424 | else: 425 | doc_infos.append({"id": _id, "metadata": metadata, "result": False}) 426 | return doc_infos 427 | 428 | def deleteByFile(self, fileName, collectionName="default"): 429 | collection = self.collection.get(collectionName, self.defaultCollection) 430 | return collection.delete(where={"source": fileName}) 431 | 432 | def queryBySource(self, source, collectionName="default"): 433 | collection = self.collection.get(collectionName, self.defaultCollection) 434 | return collection.get(where={"source": source}) 435 | 436 | def queryByMd5(self, md5, collectionName="default"): 437 | collection = self.collection.get(collectionName, self.defaultCollection) 438 | return collection.get(where={"md5": md5}) 439 | 440 | def queryByMd5Bool(self, md5, collectionName="default"): 441 | res = self.queryByMd5(md5, collectionName) 442 | if res: 443 | if len(res["ids"]) > 0: 444 | return True 445 | return False 446 | 447 | 448 | 449 | if __name__ == '__main__': 450 | rc = ChromaRag() 451 | # print(rc.add_doc(r"test.txt")) 452 | # print(rc.add_text("test")) 453 | print(rc.add_text("test")) 454 | print(rc.queryBySource("default")) 455 | print(rc.queryByMd5("098f6bcd4621d373cade4e832627b4f6")) 456 | print(rc.queryByMd5Bool("098f6bcd4621d373cade4e832627b4f6")) 457 | print(rc.search("test")) 458 | -------------------------------------------------------------------------------- /CyberFriend_LLM_core/api_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tiktoken 3 | import torch 4 | import uvicorn 5 | 6 | from fastapi import FastAPI, HTTPException, Response 7 | from fastapi.middleware.cors import CORSMiddleware 8 | 9 | from contextlib import asynccontextmanager 10 | from typing import List, Literal, Optional, Union 11 | from loguru import logger 12 | from pydantic import BaseModel, Field 13 | from transformers import AutoTokenizer 14 | from utils import process_response, generate_chatglm3, generate_stream_chatglm3 15 | from sentence_transformers import SentenceTransformer 16 | 17 | from sse_starlette.sse import EventSourceResponse 18 | 19 | # Set up limit request time 20 | EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 21 | 22 | # set LLM path 23 | MODEL_PATH = 'ZhipuAI/chatglm3-6b' 24 | TOKENIZER_PATH = 'ZhipuAI/chatglm3-6b' 25 | 26 | # set Embedding Model path 27 | EMBEDDING_PATH = 'ZhipuAI/chatglm3-6b' 28 | 29 | 30 | @asynccontextmanager 31 | async def lifespan(app: FastAPI): 32 | yield 33 | if torch.cuda.is_available(): 34 | torch.cuda.empty_cache() 35 | torch.cuda.ipc_collect() 36 | 37 | 38 | app = FastAPI(lifespan=lifespan) 39 | 40 | app.add_middleware( 41 | CORSMiddleware, 42 | allow_origins=["*"], 43 | allow_credentials=True, 44 | allow_methods=["*"], 45 | allow_headers=["*"], 46 | ) 47 | 48 | 49 | class ModelCard(BaseModel): 50 | id: str 51 | object: str = "model" 52 | created: int = Field(default_factory=lambda: int(time.time())) 53 | owned_by: str = "owner" 54 | root: Optional[str] = None 55 | parent: Optional[str] = None 56 | permission: Optional[list] = None 57 | 58 | 59 | class ModelList(BaseModel): 60 | object: str = "list" 61 | data: List[ModelCard] = [] 62 | 63 | 64 | class FunctionCallResponse(BaseModel): 65 | name: Optional[str] = None 66 | arguments: 
Optional[str] = None 67 | 68 | 69 | class ChatMessage(BaseModel): 70 | role: Literal["user", "assistant", "system", "function"] 71 | content: str = None 72 | name: Optional[str] = None 73 | function_call: Optional[FunctionCallResponse] = None 74 | 75 | 76 | class DeltaMessage(BaseModel): 77 | role: Optional[Literal["user", "assistant", "system"]] = None 78 | content: Optional[str] = None 79 | function_call: Optional[FunctionCallResponse] = None 80 | 81 | 82 | ## for Embedding 83 | class EmbeddingRequest(BaseModel): 84 | input: List[str] 85 | model: str 86 | 87 | 88 | class CompletionUsage(BaseModel): 89 | prompt_tokens: int 90 | completion_tokens: int 91 | total_tokens: int 92 | 93 | 94 | class EmbeddingResponse(BaseModel): 95 | data: list 96 | model: str 97 | object: str 98 | usage: CompletionUsage 99 | 100 | 101 | # for ChatCompletionRequest 102 | 103 | class UsageInfo(BaseModel): 104 | prompt_tokens: int = 0 105 | total_tokens: int = 0 106 | completion_tokens: Optional[int] = 0 107 | 108 | 109 | class ChatCompletionRequest(BaseModel): 110 | model: str 111 | messages: List[ChatMessage] 112 | temperature: Optional[float] = 0.8 113 | top_p: Optional[float] = 0.8 114 | max_tokens: Optional[int] = None 115 | stream: Optional[bool] = False 116 | tools: Optional[Union[dict, List[dict]]] = None 117 | repetition_penalty: Optional[float] = 1.1 118 | 119 | 120 | class ChatCompletionResponseChoice(BaseModel): 121 | index: int 122 | message: ChatMessage 123 | finish_reason: Literal["stop", "length", "function_call"] 124 | 125 | 126 | class ChatCompletionResponseStreamChoice(BaseModel): 127 | delta: DeltaMessage 128 | finish_reason: Optional[Literal["stop", "length", "function_call"]] 129 | index: int 130 | 131 | 132 | class ChatCompletionResponse(BaseModel): 133 | model: str 134 | id: str 135 | object: Literal["chat.completion", "chat.completion.chunk"] 136 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 137 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 138 | usage: Optional[UsageInfo] = None 139 | 140 | 141 | @app.get("/health") 142 | async def health() -> Response: 143 | """Health check.""" 144 | return Response(status_code=200) 145 | 146 | 147 | @app.post("/v1/embeddings", response_model=EmbeddingResponse) 148 | async def get_embeddings(request: EmbeddingRequest): 149 | embeddings = [embedding_model.encode(text) for text in request.input] 150 | embeddings = [embedding.tolist() for embedding in embeddings] 151 | 152 | def num_tokens_from_string(string: str) -> int: 153 | """ 154 | Returns the number of tokens in a text string. 
155 | use cl100k_base tokenizer 156 | """ 157 | encoding = tiktoken.get_encoding('cl100k_base') 158 | num_tokens = len(encoding.encode(string)) 159 | return num_tokens 160 | 161 | response = { 162 | "data": [ 163 | { 164 | "object": "embedding", 165 | "embedding": embedding, 166 | "index": index 167 | } 168 | for index, embedding in enumerate(embeddings) 169 | ], 170 | "model": request.model, 171 | "object": "list", 172 | "usage": CompletionUsage( 173 | prompt_tokens=sum(len(text.split()) for text in request.input), 174 | completion_tokens=0, 175 | total_tokens=sum(num_tokens_from_string(text) for text in request.input), 176 | ) 177 | } 178 | return response 179 | 180 | 181 | @app.get("/v1/models", response_model=ModelList) 182 | async def list_models(): 183 | model_card = ModelCard( 184 | id="chatglm3-6b" 185 | ) 186 | return ModelList( 187 | data=[model_card] 188 | ) 189 | 190 | 191 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 192 | async def create_chat_completion(request: ChatCompletionRequest): 193 | global model, tokenizer 194 | 195 | if len(request.messages) < 1 or request.messages[-1].role == "assistant": 196 | raise HTTPException(status_code=400, detail="Invalid request") 197 | 198 | gen_params = dict( 199 | messages=request.messages, 200 | temperature=request.temperature, 201 | top_p=request.top_p, 202 | max_tokens=request.max_tokens or 1024, 203 | echo=False, 204 | stream=request.stream, 205 | repetition_penalty=request.repetition_penalty, 206 | tools=request.tools, 207 | ) 208 | logger.debug(f"==== request ====\n{gen_params}") 209 | 210 | if request.stream: 211 | 212 | # Use stream mode to read the first few characters; if it is not a function call, stream the output directly 213 | predict_stream_generator = predict_stream(request.model, gen_params) 214 | output = next(predict_stream_generator) 215 | if not contains_custom_function(output): 216 | return EventSourceResponse(predict_stream_generator, media_type="text/event-stream") 217 | 218 | # Obtain the full result in one pass and determine whether tools need to be called. 219 | logger.debug(f"First result output:\n{output}") 220 | 221 | function_call = None 222 | if output and request.tools: 223 | try: 224 | function_call = process_response(output, use_tool=True) 225 | except Exception: 226 | logger.warning("Failed to parse tool call") 227 | 228 | # CallFunction 229 | if isinstance(function_call, dict): 230 | function_call = FunctionCallResponse(**function_call) 231 | 232 | """ 233 | In this demo, we did not register any tools. 234 | You can use the tools implemented in our `tools_using_demo` and add your own streaming tool logic here. 235 | Similar to the following method: 236 | function_args = json.loads(function_call.arguments) 237 | tool_response = dispatch_tool(tool_name: str, tool_params: dict) 238 | """ 239 | tool_response = "" 240 | 241 | if not gen_params.get("messages"): 242 | gen_params["messages"] = [] 243 | 244 | gen_params["messages"].append(ChatMessage( 245 | role="assistant", 246 | content=output, 247 | )) 248 | gen_params["messages"].append(ChatMessage( 249 | role="function", 250 | name=function_call.name, 251 | content=tool_response, 252 | )) 253 | 254 | # Streaming output of results after function calls 255 | generate = predict(request.model, gen_params) 256 | return EventSourceResponse(generate, media_type="text/event-stream") 257 | 258 | else: 259 | # Handled here to avoid exceptions from the parsing logic above.
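# predict_stream() has already consumed its first yield into `output`; when the special "get_" prefix is detected it keeps buffering and finally yields the complete response text. Reaching this branch means no usable function call could be parsed from that text, so it is replayed through parse_output_text() as a plain SSE stream.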
260 | generate = parse_output_text(request.model, output) 261 | return EventSourceResponse(generate, media_type="text/event-stream") 262 | 263 | # Handling of stream = False 264 | response = generate_chatglm3(model, tokenizer, gen_params) 265 | 266 | # Remove the first newline character 267 | if response["text"].startswith("\n"): 268 | response["text"] = response["text"][1:] 269 | response["text"] = response["text"].strip() 270 | 271 | usage = UsageInfo() 272 | function_call, finish_reason = None, "stop" 273 | if request.tools: 274 | try: 275 | function_call = process_response(response["text"], use_tool=True) 276 | except Exception: 277 | logger.warning("Failed to parse tool call; maybe the response is not a tool call or has already been answered.") 278 | 279 | if isinstance(function_call, dict): 280 | finish_reason = "function_call" 281 | function_call = FunctionCallResponse(**function_call) 282 | 283 | message = ChatMessage( 284 | role="assistant", 285 | content=response["text"], 286 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 287 | ) 288 | 289 | logger.debug(f"==== message ====\n{message}") 290 | 291 | choice_data = ChatCompletionResponseChoice( 292 | index=0, 293 | message=message, 294 | finish_reason=finish_reason, 295 | ) 296 | task_usage = UsageInfo.model_validate(response["usage"]) 297 | for usage_key, usage_value in task_usage.model_dump().items(): 298 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 299 | 300 | return ChatCompletionResponse( 301 | model=request.model, 302 | id="", # for open_source model, id is empty 303 | choices=[choice_data], 304 | object="chat.completion", 305 | usage=usage 306 | ) 307 | 308 | 309 | async def predict(model_id: str, params: dict): 310 | global model, tokenizer 311 | 312 | choice_data = ChatCompletionResponseStreamChoice( 313 | index=0, 314 | delta=DeltaMessage(role="assistant"), 315 | finish_reason=None 316 | ) 317 | chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") 318 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 319 | 320 | previous_text = "" 321 | for new_response in generate_stream_chatglm3(model, tokenizer, params): 322 | decoded_unicode = new_response["text"] 323 | delta_text = decoded_unicode[len(previous_text):] 324 | previous_text = decoded_unicode 325 | 326 | finish_reason = new_response["finish_reason"] 327 | if len(delta_text) == 0 and finish_reason != "function_call": 328 | continue 329 | 330 | function_call = None 331 | if finish_reason == "function_call": 332 | try: 333 | function_call = process_response(decoded_unicode, use_tool=True) 334 | except Exception: 335 | logger.warning( 336 | "Failed to parse tool call; maybe the response is not a tool call or has already been answered.") 337 | 338 | if isinstance(function_call, dict): 339 | function_call = FunctionCallResponse(**function_call) 340 | 341 | delta = DeltaMessage( 342 | content=delta_text, 343 | role="assistant", 344 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 345 | ) 346 | 347 | choice_data = ChatCompletionResponseStreamChoice( 348 | index=0, 349 | delta=delta, 350 | finish_reason=finish_reason 351 | ) 352 | chunk = ChatCompletionResponse( 353 | model=model_id, 354 | id="", 355 | choices=[choice_data], 356 | object="chat.completion.chunk" 357 | ) 358 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 359 | 360 | choice_data = ChatCompletionResponseStreamChoice( 361 | index=0, 362 |
delta=DeltaMessage(), 363 | finish_reason="stop" 364 | ) 365 | chunk = ChatCompletionResponse( 366 | model=model_id, 367 | id="", 368 | choices=[choice_data], 369 | object="chat.completion.chunk" 370 | ) 371 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 372 | yield '[DONE]' 373 | 374 | 375 | def predict_stream(model_id, gen_params): 376 | """ 377 | Makes function calls compatible with stream-mode output. 378 | 379 | The first seven characters are inspected. 380 | If the output is not a function call, it is streamed directly. 381 | Otherwise, the complete text of the function call is returned. 382 | 383 | :param model_id: 384 | :param gen_params: 385 | :return: 386 | """ 387 | output = "" 388 | is_function_call = False 389 | has_send_first_chunk = False 390 | for new_response in generate_stream_chatglm3(model, tokenizer, gen_params): 391 | decoded_unicode = new_response["text"] 392 | delta_text = decoded_unicode[len(output):] 393 | output = decoded_unicode 394 | 395 | # When it is not a function call and the character length is > 7, 396 | # try to decide whether it is a function call based on the special function prefix 397 | if not is_function_call and len(output) > 7: 398 | 399 | # Determine whether a function is called 400 | is_function_call = contains_custom_function(output) 401 | if is_function_call: 402 | continue 403 | 404 | # Not a function call; stream the output directly 405 | finish_reason = new_response["finish_reason"] 406 | 407 | # Send an empty string first to avoid truncation by subsequent next() operations. 408 | if not has_send_first_chunk: 409 | message = DeltaMessage( 410 | content="", 411 | role="assistant", 412 | function_call=None, 413 | ) 414 | choice_data = ChatCompletionResponseStreamChoice( 415 | index=0, 416 | delta=message, 417 | finish_reason=finish_reason 418 | ) 419 | chunk = ChatCompletionResponse( 420 | model=model_id, 421 | id="", 422 | choices=[choice_data], 423 | created=int(time.time()), 424 | object="chat.completion.chunk" 425 | ) 426 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 427 | 428 | send_msg = delta_text if has_send_first_chunk else output 429 | has_send_first_chunk = True 430 | message = DeltaMessage( 431 | content=send_msg, 432 | role="assistant", 433 | function_call=None, 434 | ) 435 | choice_data = ChatCompletionResponseStreamChoice( 436 | index=0, 437 | delta=message, 438 | finish_reason=finish_reason 439 | ) 440 | chunk = ChatCompletionResponse( 441 | model=model_id, 442 | id="", 443 | choices=[choice_data], 444 | created=int(time.time()), 445 | object="chat.completion.chunk" 446 | ) 447 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 448 | 449 | if is_function_call: 450 | yield output 451 | else: 452 | yield '[DONE]' 453 | 454 | 455 | async def parse_output_text(model_id: str, value: str): 456 | """ 457 | Directly output the text content of value 458 | 459 | :param model_id: 460 | :param value: 461 | :return: 462 | """ 463 | choice_data = ChatCompletionResponseStreamChoice( 464 | index=0, 465 | delta=DeltaMessage(role="assistant", content=value), 466 | finish_reason=None 467 | ) 468 | chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk") 469 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 470 | 471 | choice_data = ChatCompletionResponseStreamChoice( 472 | index=0, 473 | delta=DeltaMessage(), 474 | finish_reason="stop" 475 | ) 476 | chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data],
object="chat.completion.chunk") 477 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 478 | yield '[DONE]' 479 | 480 | 481 | def contains_custom_function(value: str) -> bool: 482 | """ 483 | Determine whether 'function_call' according to a special function prefix. 484 | 485 | For example, the functions defined in "tools_using_demo/tool_register.py" are all "get_xxx" and start with "get_" 486 | 487 | [Note] This is not a rigorous judgment method, only for reference. 488 | 489 | :param value: 490 | :return: 491 | """ 492 | return value and 'get_' in value 493 | 494 | 495 | if __name__ == "__main__": 496 | model_path = "ZhipuAI/chatglm3-6b" 497 | from transformers import AutoTokenizer, AutoModelForCausalLM 498 | from peft import PeftModel 499 | 500 | load_st = time.time() 501 | omodel = AutoModelForCausalLM.from_pretrained( 502 | model_path, load_in_8bit=False, trust_remote_code=True, 503 | device_map="auto" # 模型不同层会被自动分配到不同GPU上进行计算) 504 | ) 505 | load_end = time.time() 506 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 507 | # 原始 LLM 安装上 Lora 模型 508 | model = PeftModel.from_pretrained(omodel, "./finetune_demo/output/checkpoint-3000") 509 | 510 | # load Embedding 511 | embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda") 512 | uvicorn.run(app, host='0.0.0.0', port=9021, workers=1) -------------------------------------------------------------------------------- /CyberFriend_LLM_core/download.py: -------------------------------------------------------------------------------- 1 | from modelscope import snapshot_download 2 | model_dir = snapshot_download("ZhipuAI/chatglm3-6b", cache_dir='./', revision = "v1.0.0") -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/configs/ds_zero_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | 23 | "gradient_accumulation_steps": "auto", 24 | "gradient_clipping": "auto", 25 | "steps_per_print": 2000, 26 | "train_batch_size": "auto", 27 | "train_micro_batch_size_per_gpu": "auto", 28 | "wall_clock_breakdown": false 29 | } -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/configs/ds_zero_3.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "bf16": { 5 | "enabled": "auto" 6 | }, 7 | "optimizer": { 8 | "type": "AdamW", 9 | "params": { 10 | "lr": "auto", 11 | "betas": "auto", 12 | "eps": "auto", 13 | "weight_decay": "auto" 14 | } 15 | }, 16 | "zero_optimization": { 17 | "stage": 3, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "reduce_scatter": true, 21 | "contiguous_gradients": true, 22 | "overlap_comm": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": "auto", 25 | "stage3_prefetch_bucket_size": "auto", 26 | "stage3_param_persistence_threshold": "auto", 27 | "stage3_max_live_parameters": 1e9, 28 | 
"stage3_max_reuse_distance": 1e9, 29 | "stage3_gather_16bit_weights_on_model_save": true 30 | } 31 | } -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/configs/lora.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.json 3 | val_file: train.json 4 | test_file: train.json 5 | num_proc: 32 6 | max_input_length: 1024 7 | max_output_length: 256 8 | training_args: 9 | # see `transformers.Seq2SeqTrainingArguments` 10 | output_dir: ./output 11 | max_steps: 3000 12 | # settings for data loading 13 | per_device_train_batch_size: 1 14 | dataloader_num_workers: 32 15 | remove_unused_columns: false 16 | # settings for saving checkpoints 17 | save_strategy: steps 18 | save_steps: 500 19 | # settings for logging 20 | log_level: info 21 | logging_strategy: steps 22 | logging_steps: 10 23 | # settings for evaluation 24 | per_device_eval_batch_size: 16 25 | evaluation_strategy: steps 26 | eval_steps: 500 27 | # settings for optimizer 28 | # adam_epsilon: 1e-6 29 | # uncomment the following line to detect nan or inf values 30 | # debug: underflow_overflow 31 | predict_with_generate: true 32 | # see `transformers.GenerationConfig` 33 | generation_config: 34 | max_new_tokens: 256 35 | # set your absolute deepspeed path here 36 | #deepspeed: ds_zero_2.json 37 | peft_config: 38 | peft_type: LORA 39 | task_type: CAUSAL_LM 40 | r: 256 41 | lora_alpha: 512 42 | lora_dropout: 0.1 43 | -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/configs/ptuning_v2.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.json 3 | val_file: train.json 4 | test_file: train.json 5 | num_proc: 16 6 | max_input_length: 1024 7 | max_output_length: 512 8 | training_args: 9 | # see `transformers.Seq2SeqTrainingArguments` 10 | output_dir: ./output 11 | max_steps: 3000 12 | # settings for data loading 13 | per_device_train_batch_size: 2 14 | dataloader_num_workers: 16 15 | remove_unused_columns: false 16 | # settings for saving checkpoints 17 | save_strategy: steps 18 | save_steps: 3000 19 | # settings for logging 20 | log_level: info 21 | logging_strategy: steps 22 | logging_steps: 10 23 | # settings for evaluation 24 | per_device_eval_batch_size: 16 25 | evaluation_strategy: steps 26 | eval_steps: 500 27 | # settings for optimizer 28 | # adam_epsilon: 1e-6 29 | # uncomment the following line to detect nan or inf values 30 | # debug: underflow_overflow 31 | predict_with_generate: true 32 | # see `transformers.GenerationConfig` 33 | generation_config: 34 | max_new_tokens: 1024 35 | # set your absolute deepspeed path here 36 | #deepspeed: ds_zero_3.json 37 | peft_config: 38 | peft_type: PREFIX_TUNING 39 | num_virtual_tokens: 64 -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/configs/sft.yaml: -------------------------------------------------------------------------------- 1 | data_config: 2 | train_file: train.json 3 | val_file: dev.json 4 | test_file: dev.json 5 | num_proc: 16 6 | max_input_length: 128 7 | max_output_length: 256 8 | training_args: 9 | # see `transformers.Seq2SeqTrainingArguments` 10 | output_dir: ./output 11 | max_steps: 3000 12 | # settings for data loading 13 | per_device_train_batch_size: 1 14 | dataloader_num_workers: 16 15 | remove_unused_columns: false 16 | # settings for saving 
checkpoints 17 | save_strategy: steps 18 | save_steps: 500 19 | # settings for logging 20 | log_level: info 21 | logging_strategy: steps 22 | logging_steps: 10 23 | # settings for evaluation 24 | per_device_eval_batch_size: 16 25 | evaluation_strategy: steps 26 | eval_steps: 500 27 | # settings for optimizer 28 | # adam_epsilon: 1e-6 29 | # uncomment the following line to detect nan or inf values 30 | # debug: underflow_overflow 31 | predict_with_generate: true 32 | generation_config: 33 | max_new_tokens: 256 34 | # set your absolute deepspeed path here 35 | deepspeed: ds_zero_3.json -------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/finetune_hf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import dataclasses as dc 5 | import functools 6 | from collections.abc import Callable, Mapping, Sequence 7 | from pathlib import Path 8 | from typing import Annotated, Any, Optional, Union 9 | 10 | import jieba 11 | import numpy as np 12 | import ruamel.yaml as yaml 13 | import torch 14 | import typer 15 | from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset 16 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 17 | from peft import ( 18 | PeftConfig, 19 | PeftModelForCausalLM, 20 | get_peft_config, 21 | get_peft_model 22 | ) 23 | from rouge_chinese import Rouge 24 | from torch import nn 25 | from transformers import ( 26 | AutoModelForCausalLM, 27 | AutoTokenizer, 28 | EvalPrediction, 29 | GenerationConfig, 30 | PreTrainedModel, 31 | PreTrainedTokenizer, 32 | PreTrainedTokenizerFast, 33 | Seq2SeqTrainingArguments, AutoConfig, 34 | ) 35 | from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq 36 | 37 | from transformers import Seq2SeqTrainer as _Seq2SeqTrainer 38 | 39 | ModelType = Union[PreTrainedModel, PeftModelForCausalLM] 40 | TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 41 | app = typer.Typer(pretty_exceptions_show_locals=False) 42 | 43 | 44 | class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq): 45 | def __call__(self, features, return_tensors=None): 46 | output_ids = ( 47 | [feature['output_ids'] for feature in features] 48 | if 'output_ids' in features[0].keys() 49 | else None 50 | ) 51 | if output_ids is not None: 52 | max_output_length = max(len(out) for out in output_ids) 53 | if self.pad_to_multiple_of is not None: 54 | max_output_length = ( 55 | ( 56 | max_output_length + self.pad_to_multiple_of - 1) // 57 | self.pad_to_multiple_of * self.pad_to_multiple_of 58 | ) 59 | for feature in features: 60 | remainder = [self.tokenizer.pad_token_id] * ( 61 | max_output_length - len(feature['output_ids']) 62 | ) 63 | if isinstance(feature['output_ids'], list): 64 | feature['output_ids'] = feature['output_ids'] + remainder 65 | else: 66 | feature['output_ids'] = np.concatenate( 67 | [feature['output_ids'], remainder] 68 | ).astype(np.int64) 69 | return super().__call__(features, return_tensors) 70 | 71 | 72 | class Seq2SeqTrainer(_Seq2SeqTrainer): 73 | def prediction_step( 74 | self, 75 | model: nn.Module, 76 | inputs: dict[str, Any], 77 | prediction_loss_only: bool, 78 | ignore_keys=None, 79 | **gen_kwargs, 80 | ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 81 | 82 | if self.args.predict_with_generate: 83 | output_ids = inputs.pop('output_ids') 84 | input_ids = inputs['input_ids'] 85 | loss, generated_tokens, labels = 
super().prediction_step( 86 | model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs 87 | ) 88 | generated_tokens = generated_tokens[:, input_ids.size()[1]:] 89 | if self.args.predict_with_generate: 90 | labels = output_ids 91 | 92 | # breakpoint() 93 | return loss, generated_tokens, labels 94 | 95 | 96 | def _resolve_path(path: Union[str, Path]) -> Path: 97 | return Path(path).expanduser().resolve() 98 | 99 | 100 | def _sanity_check( 101 | input_ids: Sequence[int], 102 | output_ids: Sequence[int], 103 | tokenizer: PreTrainedTokenizer, 104 | ): 105 | print('--> Sanity check') 106 | for in_id, out_id in zip(input_ids, output_ids): 107 | if in_id == 0: 108 | continue 109 | if in_id in tokenizer.tokenizer.index_special_tokens: 110 | in_text = tokenizer.tokenizer.index_special_tokens[in_id] 111 | else: 112 | in_text = tokenizer.decode([in_id]) 113 | print(f'{repr(in_text):>20}: {in_id} -> {out_id}') 114 | 115 | 116 | @functools.cache 117 | def _get_yaml_parser() -> yaml.YAML: 118 | parser = yaml.YAML(typ='safe', pure=True) 119 | parser.indent(mapping=2, offset=2, sequence=4) 120 | parser.default_flow_style = False 121 | return parser 122 | 123 | 124 | @dc.dataclass 125 | class DataConfig(object): 126 | train_file: str 127 | val_file: Optional[str] = None 128 | test_file: Optional[str] = None 129 | 130 | num_proc: Optional[int] = None 131 | 132 | @property 133 | def data_format(self) -> str: 134 | return Path(self.train_file).suffix 135 | 136 | @property 137 | def data_files(self) -> dict[NamedSplit, str]: 138 | return { 139 | split: data_file 140 | for split, data_file in zip( 141 | [Split.TRAIN, Split.VALIDATION, Split.TEST], 142 | [self.train_file, self.val_file, self.test_file], 143 | ) 144 | if data_file is not None 145 | } 146 | 147 | 148 | @dc.dataclass 149 | class FinetuningConfig(object): 150 | data_config: DataConfig 151 | 152 | max_input_length: int 153 | max_output_length: int 154 | 155 | # Note: use default_factory rather than default here 156 | training_args: Seq2SeqTrainingArguments = dc.field( 157 | default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output') 158 | ) 159 | peft_config: Optional[PeftConfig] = None 160 | 161 | def __post_init__(self): 162 | if not self.training_args.do_eval or self.data_config.val_file is None: 163 | # skips the evaluation stage when `do_eval` or `val_file` is not provided 164 | self.training_args.do_eval = False 165 | self.training_args.evaluation_strategy = 'no' 166 | self.data_config.val_file = None 167 | else: 168 | self.training_args.per_device_eval_batch_size = ( 169 | self.training_args.per_device_eval_batch_size 170 | or self.training_args.per_device_train_batch_size 171 | ) 172 | 173 | @classmethod 174 | def from_dict(cls, **kwargs) -> 'FinetuningConfig': 175 | training_args = kwargs.get('training_args', None) 176 | if training_args is not None and not isinstance( 177 | training_args, Seq2SeqTrainingArguments 178 | ): 179 | gen_config = training_args.get('generation_config') 180 | # TODO: a bit hacky 181 | if not isinstance(gen_config, GenerationConfig): 182 | training_args['generation_config'] = GenerationConfig( 183 | **gen_config 184 | ) 185 | kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args) 186 | 187 | data_config = kwargs.get('data_config') 188 | if not isinstance(data_config, DataConfig): 189 | kwargs['data_config'] = DataConfig(**data_config) 190 | 191 | peft_config = kwargs.get('peft_config', None) 192 | if peft_config is not None and not isinstance(peft_config, PeftConfig): 193 | kwargs['peft_config'] =
get_peft_config(peft_config) 194 | return cls(**kwargs) 195 | 196 | @classmethod 197 | def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig': 198 | path = _resolve_path(path) 199 | kwargs = _get_yaml_parser().load(path) 200 | return cls.from_dict(**kwargs) 201 | 202 | 203 | def _load_datasets( 204 | data_dir: Path, 205 | data_format: str, 206 | data_files: dict[NamedSplit, str], 207 | num_proc: Optional[int], 208 | ) -> DatasetDict: 209 | if data_format in ('.csv', '.json', '.jsonl'): 210 | dataset_dct = load_dataset( 211 | data_format[1:], 212 | data_dir=data_dir, 213 | data_files=data_files, 214 | num_proc=num_proc, 215 | ) 216 | else: 217 | err_msg = f"Cannot load dataset in the '{data_format}' format." 218 | raise NotImplementedError(err_msg) 219 | 220 | return dataset_dct 221 | 222 | 223 | class DataManager(object): 224 | def __init__(self, data_dir: str, data_config: DataConfig): 225 | self._num_proc = data_config.num_proc 226 | 227 | self._dataset_dct = _load_datasets( 228 | _resolve_path(data_dir), 229 | data_config.data_format, 230 | data_config.data_files, 231 | self._num_proc, 232 | ) 233 | 234 | def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]: 235 | return self._dataset_dct.get(split, None) 236 | 237 | def get_dataset( 238 | self, 239 | split: NamedSplit, 240 | process_fn: Callable[[dict[str, Any]], dict[str, Any]], 241 | batched: bool = True, 242 | remove_orig_columns: bool = True, 243 | ) -> Optional[Dataset]: 244 | orig_dataset = self._get_dataset(split) 245 | if orig_dataset is None: 246 | return 247 | 248 | if remove_orig_columns: 249 | remove_columns = orig_dataset.column_names 250 | else: 251 | remove_columns = None 252 | return orig_dataset.map( 253 | process_fn, 254 | batched=batched, 255 | remove_columns=remove_columns, 256 | num_proc=self._num_proc, 257 | ) 258 | 259 | 260 | def print_model_size(model: PreTrainedModel): 261 | print("--> Model") 262 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 263 | print(f"\n--> model has {total_params / 1e6}M params\n") 264 | 265 | 266 | def process_batch( 267 | batch: Mapping[str, Sequence], 268 | tokenizer: PreTrainedTokenizer, 269 | max_input_length: int, 270 | max_output_length: int, 271 | ) -> dict[str, list]: 272 | batched_tools = batch.get('tools', None) 273 | batched_conv = batch['conversations'] 274 | batched_input_ids = [] 275 | batched_labels = [] 276 | 277 | if batched_tools is None: 278 | batched_tools = [None] * len(batched_conv) 279 | 280 | for tools, conv in zip(batched_tools, batched_conv): 281 | input_ids, loss_masks = [ 282 | tokenizer.get_command('[gMASK]'), 283 | tokenizer.get_command('sop'), 284 | ], [False, False] 285 | 286 | if tools is not None: 287 | raise NotImplementedError() 288 | 289 | for message in conv: 290 | if message['role'] in ('system', 'user'): 291 | loss_mask_val = False 292 | else: 293 | loss_mask_val = True 294 | 295 | if message['role'] == 'tool': 296 | raise NotImplementedError() 297 | else: 298 | new_input_ids = tokenizer.build_single_message( 299 | message['role'], '', message['content'] 300 | ) 301 | new_loss_masks = [loss_mask_val] * len(new_input_ids) 302 | 303 | input_ids += new_input_ids 304 | loss_masks += new_loss_masks 305 | 306 | input_ids.append(tokenizer.eos_token_id) 307 | loss_masks = [False, *loss_masks] 308 | labels = [] 309 | for input_id, mask in zip(input_ids, loss_masks): 310 | if mask: 311 | labels.append(input_id) 312 | else: 313 | labels.append(-100) 314 | max_length = max_input_length + max_output_length + 
1 315 | batched_input_ids.append(input_ids[:max_length]) 316 | batched_labels.append(labels[:max_length]) 317 | return {'input_ids': batched_input_ids, 'labels': batched_labels} 318 | 319 | 320 | def process_batch_eval( 321 | batch: Mapping[str, Sequence], 322 | tokenizer: PreTrainedTokenizer, 323 | max_input_length: int, 324 | max_output_length: int, 325 | ) -> dict[str, list]: 326 | batched_tools = batch.get('tools', None) 327 | batched_conv = batch['conversations'] 328 | batched_input_ids = [] 329 | # To avoid computing loss, we do not provide the `labels` field in the input dictionary. 330 | batched_output_ids = [] 331 | 332 | if batched_tools is None: 333 | batched_tools = [None] * len(batched_conv) 334 | 335 | for tools, conv in zip(batched_tools, batched_conv): 336 | input_ids = [ 337 | tokenizer.get_command('[gMASK]'), 338 | tokenizer.get_command('sop'), 339 | ] 340 | 341 | if tools is not None: 342 | raise NotImplementedError() 343 | 344 | for message in conv: 345 | if len(input_ids) >= max_input_length: 346 | break 347 | if message['role'] == 'tool': 348 | raise NotImplementedError() 349 | else: 350 | new_input_ids = tokenizer.build_single_message( 351 | message['role'], '', message['content'] 352 | ) 353 | if message['role'] == 'assistant': 354 | output_prompt, output_ids = ( 355 | new_input_ids[:1], 356 | new_input_ids[1:], 357 | ) 358 | output_ids.append(tokenizer.eos_token_id) 359 | batched_input_ids.append( 360 | input_ids[:max_input_length] + output_prompt[:1] 361 | ) 362 | batched_output_ids.append(output_ids[:max_output_length]) 363 | input_ids += new_input_ids 364 | return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids} 365 | 366 | 367 | # TODO: Not sure if this is necessary, can set it to half 368 | def _prepare_model_for_training(model: nn.Module): 369 | for param in model.parameters(): 370 | if param.requires_grad: 371 | param.data = param.data.to(torch.float32) 372 | 373 | 374 | def load_tokenizer_and_model( 375 | model_dir: str, 376 | trust_remote_code: bool = False, 377 | peft_config: Optional[PeftConfig] = None, 378 | ) -> tuple[PreTrainedTokenizer, nn.Module]: 379 | tokenizer = AutoTokenizer.from_pretrained( 380 | model_dir, trust_remote_code=trust_remote_code 381 | ) 382 | if peft_config is not None: 383 | if peft_config.peft_type.name == "PREFIX_TUNING": 384 | config = AutoConfig.from_pretrained( 385 | model_dir, 386 | trust_remote_code=trust_remote_code, 387 | empty_init=False 388 | ) 389 | config.pre_seq_len = peft_config.num_virtual_tokens 390 | config.use_cache = False 391 | model = AutoModelForCausalLM.from_pretrained( 392 | model_dir, 393 | trust_remote_code=trust_remote_code, 394 | config=config, 395 | empty_init=False 396 | ) 397 | if peft_config.peft_type.name == "LORA": 398 | model = AutoModelForCausalLM.from_pretrained( 399 | model_dir, 400 | trust_remote_code=trust_remote_code, 401 | empty_init=False 402 | ) 403 | model = get_peft_model(model, peft_config) 404 | model.print_trainable_parameters() 405 | else: 406 | model = AutoModelForCausalLM.from_pretrained( 407 | model_dir, 408 | trust_remote_code=trust_remote_code, 409 | empty_init=False 410 | ) 411 | print_model_size(model) 412 | 413 | return tokenizer, model 414 | 415 | 416 | def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer): 417 | batched_pred_ids, batched_label_ids = eval_preds 418 | 419 | metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []} 420 | for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids): 421 | 
pred_txt = tokenizer.decode(pred_ids).strip()
422 | label_txt = tokenizer.decode(label_ids).strip()
423 | pred_tokens = list(jieba.cut(pred_txt))
424 | label_tokens = list(jieba.cut(label_txt))
425 | rouge = Rouge()
426 | scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
427 | for k, v in scores[0].items():
428 | metrics_dct[k].append(round(v['f'] * 100, 4))
429 | metrics_dct['bleu-4'].append(
430 | sentence_bleu(
431 | [label_tokens],
432 | pred_tokens,
433 | smoothing_function=SmoothingFunction().method3,
434 | )
435 | )
436 | return {k: np.mean(v) for k, v in metrics_dct.items()}
437 | 
438 | 
439 | @app.command()
440 | def main(
441 | data_dir: Annotated[str, typer.Argument(help='Directory containing the train/validation/test data files.')],
442 | model_dir: Annotated[
443 | str,
444 | typer.Argument(
445 | help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
446 | ),
447 | ],
448 | config_file: Annotated[str, typer.Argument(help='Path to the finetuning config file, e.g. configs/lora.yaml.')],
449 | ):
450 | ft_config = FinetuningConfig.from_file(config_file)
451 | tokenizer, model = load_tokenizer_and_model(
452 | model_dir,
453 | trust_remote_code=True,
454 | peft_config=ft_config.peft_config,
455 | )
456 | data_manager = DataManager(data_dir, ft_config.data_config)
457 | 
458 | train_dataset = data_manager.get_dataset(
459 | Split.TRAIN,
460 | functools.partial(
461 | process_batch,
462 | tokenizer=tokenizer,
463 | max_input_length=ft_config.max_input_length,
464 | max_output_length=ft_config.max_output_length,
465 | ),
466 | batched=True,
467 | )
468 | print('train_dataset:', train_dataset)
469 | val_dataset = data_manager.get_dataset(
470 | Split.VALIDATION,
471 | functools.partial(
472 | process_batch_eval,
473 | tokenizer=tokenizer,
474 | max_input_length=ft_config.max_input_length,
475 | max_output_length=ft_config.max_output_length,
476 | ),
477 | batched=True,
478 | )
479 | if val_dataset is not None:
480 | print('val_dataset:', val_dataset)
481 | test_dataset = data_manager.get_dataset(
482 | Split.TEST,
483 | functools.partial(
484 | process_batch_eval,
485 | tokenizer=tokenizer,
486 | max_input_length=ft_config.max_input_length,
487 | max_output_length=ft_config.max_output_length,
488 | ),
489 | batched=True,
490 | )
491 | if test_dataset is not None:
492 | print('test_dataset:', test_dataset)
493 | 
494 | # checks encoded dataset
495 | # _sanity_check(
496 | # train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer
497 | # )
498 | 
499 | # turn model to fp32
500 | _prepare_model_for_training(model)
501 | 
502 | ft_config.training_args.generation_config.pad_token_id = (
503 | tokenizer.pad_token_id
504 | )
505 | ft_config.training_args.generation_config.eos_token_id = [
506 | tokenizer.eos_token_id,
507 | tokenizer.get_command('<|user|>'),
508 | tokenizer.get_command('<|observation|>'),
509 | ]
510 | trainer = Seq2SeqTrainer(
511 | model=model,
512 | args=ft_config.training_args,
513 | data_collator=DataCollatorForSeq2Seq(
514 | tokenizer=tokenizer,
515 | padding='longest',
516 | return_tensors='pt',
517 | ),
518 | train_dataset=train_dataset,
519 | eval_dataset=val_dataset.select(list(range(min(50, len(val_dataset))))) if val_dataset is not None else None,  # guard: the validation split may be missing or shorter than 50 rows
520 | tokenizer=tokenizer,
521 | compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
522 | )
523 | trainer.train()
524 | 
525 | # test stage
526 | if test_dataset is not None:
527 | trainer.predict(test_dataset)
528 | 
529 | 
530 | if __name__ == '__main__':
531 | app()
532 | 
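533 | # Usage sketch (paths are illustrative, not part of this repo; see the repo readme for the real invocation):
534 | #   python finetune_hf.py /path/to/dataset /path/to/chatglm3-6b configs/lora.yaml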
-------------------------------------------------------------------------------- /CyberFriend_LLM_core/finetune/requirement.txt: --------------------------------------------------------------------------------
1 | # for finetune demo
2 | jieba>=0.42.1
3 | ruamel_yaml>=0.18.5
4 | rouge_chinese>=1.0.3
5 | jupyter>=1.0.0
6 | datasets>=2.16.1
7 | peft>=0.7.1
8 | deepspeed>=0.13.1
-------------------------------------------------------------------------------- /CyberFriend_LLM_core/requirements.txt: --------------------------------------------------------------------------------
1 | # basic requirements
2 | 
3 | protobuf>=4.25.2
4 | transformers>=4.37.1
5 | tokenizers>=0.15.0
6 | cpm_kernels>=1.0.11
7 | torch>=2.1.0
8 | gradio>=4.16.0
9 | sentencepiece>=0.1.99
10 | sentence_transformers>=2.2.2
11 | accelerate>=0.26.1
12 | streamlit>=1.30.0
13 | fastapi>=0.109.0
14 | loguru~=0.7.2
15 | mdtex2html>=1.3.0
16 | latex2mathml>=3.77.0
17 | 
18 | # for openai demo
19 | 
20 | openai>=1.10.0
21 | zhipuai>=2.0.1
22 | 
23 | pydantic>=2.5.3
24 | sse-starlette>=2.0.0
25 | uvicorn>=0.27.0
26 | timm>=0.9.12
27 | tiktoken>=0.5.2
28 | 
29 | # for langchain demo
30 | 
31 | langchain>=0.1.4
32 | langchainhub>=0.1.14
33 | arxiv>=2.1.0
34 | 
35 | # for RAG (ChromaRag.py); langchain, transformers and sentence_transformers are already pinned above
36 | chromadb
37 | unstructured
-------------------------------------------------------------------------------- /CyberFriend_LLM_core/utils.py: --------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import torch
4 | from transformers import PreTrainedModel, PreTrainedTokenizer
5 | from transformers.generation.logits_process import LogitsProcessor
6 | from typing import Union, Tuple
7 | 
8 | 
9 | class InvalidScoreLogitsProcessor(LogitsProcessor):
10 | def __call__(
11 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
12 | ) -> torch.FloatTensor:
13 | if torch.isnan(scores).any() or torch.isinf(scores).any():
14 | scores.zero_()
15 | scores[..., 5] = 5e4
16 | return scores
17 | 
18 | 
19 | def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
20 | content = ""
21 | for response in output.split("<|assistant|>"):
22 | metadata, content = response.split("\n", maxsplit=1)
23 | if not metadata.strip():
24 | content = content.strip()
25 | content = content.replace("[[训练时间]]", "2023年")
26 | else:
27 | if use_tool:
28 | content = "\n".join(content.split("\n")[1:-1])
29 | 
30 | def tool_call(**kwargs):
31 | return kwargs
32 | 
33 | parameters = eval(content)  # note: eval() runs model-produced text, kept as-is from the upstream ChatGLM3 demo
34 | content = {
35 | "name": metadata.strip(),
36 | "arguments": json.dumps(parameters, ensure_ascii=False)
37 | }
38 | else:
39 | content = {
40 | "name": metadata.strip(),
41 | "content": content
42 | }
43 | return content
44 | 
45 | 
46 | @torch.inference_mode()
47 | def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
48 | messages = params["messages"]
49 | tools = params.get("tools")
50 | temperature = float(params.get("temperature", 1.0))
51 | repetition_penalty = float(params.get("repetition_penalty", 1.0))
52 | top_p = float(params.get("top_p", 1.0))
53 | max_new_tokens = int(params.get("max_tokens", 256))
54 | echo = params.get("echo", True)
55 | messages = process_chatglm_messages(messages, tools=tools)
56 | query, role = messages[-1]["content"], messages[-1]["role"]
57 | 
58 | inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
59 | inputs = inputs.to(model.device)
60 | input_echo_len = len(inputs["input_ids"][0])
61 | 
62 | if input_echo_len >= model.config.seq_length:
63 | print(f"Input length larger than {model.config.seq_length}")
64 | 
65 | eos_token_id = [
66 | tokenizer.eos_token_id,
67 | tokenizer.get_command("<|user|>"),
68 | ]
69 | 
70 | gen_kwargs = {
71 | "max_new_tokens": max_new_tokens,
72 | "do_sample": True if temperature > 1e-5 else False,
73 | "top_p": top_p,
74 | "repetition_penalty": repetition_penalty,
75 | "logits_processor": [InvalidScoreLogitsProcessor()],
76 | }
77 | if temperature > 1e-5:
78 | gen_kwargs["temperature"] = temperature
79 | 
80 | total_len = 0
81 | for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
82 | total_ids = total_ids.tolist()[0]
83 | total_len = len(total_ids)
84 | if echo:
85 | output_ids = total_ids[:-1]
86 | else:
87 | output_ids = total_ids[input_echo_len:-1]
88 | 
89 | response = tokenizer.decode(output_ids)
90 | if response and response[-1] != "�":
91 | response, stop_found = apply_stopping_strings(response, ["<|observation|>"])
92 | 
93 | yield {
94 | "text": response,
95 | "usage": {
96 | "prompt_tokens": input_echo_len,
97 | "completion_tokens": total_len - input_echo_len,
98 | "total_tokens": total_len,
99 | },
100 | "finish_reason": "function_call" if stop_found else None,
101 | }
102 | 
103 | if stop_found:
104 | break
105 | 
106 | # Only the last streamed result carries a finish_reason; here it is set to "stop"
107 | ret = {
108 | "text": response,
109 | "usage": {
110 | "prompt_tokens": input_echo_len,
111 | "completion_tokens": total_len - input_echo_len,
112 | "total_tokens": total_len,
113 | },
114 | "finish_reason": "stop",
115 | }
116 | yield ret
117 | 
118 | gc.collect()
119 | torch.cuda.empty_cache()
120 | 
121 | 
122 | def process_chatglm_messages(messages, tools=None):
123 | _messages = messages
124 | messages = []
125 | if tools:
126 | messages.append(
127 | {
128 | "role": "system",
129 | "content": "Answer the following questions as best as you can. You have access to the following tools:",
130 | "tools": tools
131 | }
132 | )
133 | 
134 | for m in _messages:
135 | role, content, func_call = m.role, m.content, m.function_call
136 | if role == "function":
137 | messages.append(
138 | {
139 | "role": "observation",
140 | "content": content
141 | }
142 | )
143 | 
144 | elif role == "assistant" and func_call is not None:
145 | for response in content.split("<|assistant|>"):
146 | metadata, sub_content = response.split("\n", maxsplit=1)
147 | messages.append(
148 | {
149 | "role": role,
150 | "metadata": metadata,
151 | "content": sub_content.strip()
152 | }
153 | )
154 | else:
155 | messages.append({"role": role, "content": content})
156 | return messages
157 | 
158 | 
159 | def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
160 | for response in generate_stream_chatglm3(model, tokenizer, params):
161 | pass
162 | return response
163 | 
164 | 
165 | def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
166 | stop_found = False
167 | for string in stop_strings:
168 | idx = reply.find(string)
169 | if idx != -1:
170 | reply = reply[:idx]
171 | stop_found = True
172 | break
173 | 
174 | if not stop_found:
175 | # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
176 | for string in stop_strings:
177 | for j in range(len(string) - 1, 0, -1):
178 | if reply[-j:] == string[:j]:
179 | reply = reply[:-j]
180 | break
181 | else:
182 | continue
183 | 
184 | break
185 | 
186 | return reply, stop_found
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/.env.prod: --------------------------------------------------------------------------------
1 | DRIVER=~fastapi+~httpx+~websockets
2 | HOST=0.0.0.0 # IP / hostname NoneBot listens on
3 | PORT=5556 # Port NoneBot listens on
4 | COMMAND_START=["/"] # Command prefix characters
5 | COMMAND_SEP=["."] # Command separator characters
6 | API_TIMEOUT=600
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/GetPathUtil.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | project_path = os.path.dirname(__file__)
4 | 
5 | def getPath(*relativePath: str) -> str:
6 | relativePathFiltered = [path for path in relativePath if path is not None and path != '']
7 | p = os.path.join(project_path, *relativePathFiltered)
8 | os.makedirs(os.path.dirname(p), exist_ok=True)
9 | return p
10 | 
11 | BOT_ID="2469104787"
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/common/CustomChecker.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import os
4 | import sys
5 | 
6 | from nonebot.adapters.onebot.v11 import Event, PrivateMessageEvent, GroupMessageEvent
7 | 
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 | from CyberFriend_bot_plugin.plugins.update_members import membersService
10 | 
11 | 
12 | async def is_me(event: Event):
13 | return ((event.get_user_id() in [str(mem.user_id) for mem in membersService.queryByGroupId(647155255)] and isinstance(event, PrivateMessageEvent))
14 | or (isinstance(event, GroupMessageEvent) and event.group_id == 647155255))
15 | 
16 | 
17 | async def is_private(event: Event):
18 | return isinstance(event, PrivateMessageEvent)
19 | 
20 | 
21 | async def is_group(event: Event):
22 | return isinstance(event, GroupMessageEvent)
23 | 
24 | 
25 | async def is_admin(event: GroupMessageEvent):
26 | return event.sender.role == "admin"  # note: group owners report role "owner" and are not matched here
27 | 
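28 | # Usage sketch (mirrors how these checkers are already wired up elsewhere in this
29 | # repo, e.g. Permission(is_me) in add_image_to_db and Rule(...) in member_join):
30 | #
31 | #   from nonebot import on_command
32 | #   from nonebot.internal.permission import Permission
33 | #
34 | #   admin_cmd = on_command("some_admin_cmd", permission=Permission(is_admin))  # command name is illustrative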
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/common/MembersOptUtil.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | 
5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
6 | from plugins.update_members import membersService
7 | 
8 | 
9 | def get_normal_member_str(group_id, user_id):
10 | mem = membersService.query(group_id, user_id)
11 | if mem is None:
12 | return str(user_id)
13 | elif len(mem.name_card.strip()) != 0:
14 | return mem.name_card
15 | else:
16 | return mem.nick_name
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/common/MessageBuilder.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from base64 import b64encode
3 | 
4 | from nonebot.adapters.onebot.v11 import Message, Event
5 | 
6 | 
7 | class MessageBuilder:
8 | 
9 | def __init__(self, msg=""):
10 | self.message = msg
11 | 
12 | def appendAt(self, user_id, name=None):
13 | """Pass some name whenever possible; otherwise resolving one may trigger a lookup that takes 10s and times out."""
14 | if name is None:
15 | self.message += "[CQ:at,qq=" + str(user_id) + ",name=" + str(user_id) + "]"
16 | else:
17 | self.message += "[CQ:at,qq=" + str(user_id) + ",name=" + str(name) + "]"
18 | return self
19 | 
20 | def appendText(self, msg):
21 | self.message += msg
22 | return self
23 | 
24 | def appendReply(self, id: int):
25 | self.message += f"[CQ:reply,id={id}]"
26 | return self
27 | 
28 | def appendReplyWithEvent(self, event: Event):
29 | self.message += f"[CQ:reply,id={event.message_id}]"
30 | return self
31 | 
32 | def appendImage(self, file: str):
33 | """Accepts a local file path, a URL, or a base64:// string."""
34 | if isinstance(file, str) and not file.startswith("http") and not file.startswith("base64"):
35 | with open(file, "rb") as f:
36 | file = f"base64://{b64encode(f.read()).decode()}"
37 | self.message += f"[CQ:image,file={file},cache=true,proxy=true]"
38 | return self
39 | 
40 | def build(self):
41 | return Message(self.message)
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/common/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/holk-h/CyberFriend/a047c16e68488084bff949d78d1646c3516e8586/CyberFriend_bot_plugin/plugins/__init__.py
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/add_image_to_db/__init__.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | from nonebot import get_driver, on_command
5 | from nonebot.adapters.onebot.v11 import Event
6 | from nonebot.internal.permission import Permission
7 | from nonebot.plugin import PluginMetadata
8 | 
9 | from .config import Config
10 | from ..message_record import imageRecordService
11 | 
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
13 | from CyberFriend_bot_plugin.common.CustomChecker import is_me
14 | 
15 | __plugin_meta__ = PluginMetadata(
16 | name="add_image_to_db",
17 | description="", 18 | usage="", 19 | config=Config, 20 | ) 21 | 22 | global_config = get_driver().config 23 | config = Config.parse_obj(global_config) 24 | 25 | p = Permission(is_me) 26 | add_image_to_db = on_command("addpic", permission=p) 27 | 28 | @add_image_to_db.handle() 29 | async def handle_function(event: Event): 30 | msg = event.get_message() 31 | images = msg["image"] 32 | if len(images)>0: 33 | for img in images: 34 | imageRecordService.addOne(filePath=img.get('data')['file'], url=img.get('data')['url']) 35 | await add_image_to_db.finish("success:"+str(len(images))) 36 | else: 37 | await add_image_to_db.finish("请在消息中添加图片") 38 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/add_image_to_db/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Extra 2 | 3 | 4 | class Config(BaseModel, extra=Extra.ignore): 5 | """Plugin Config Here""" 6 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/cyber_friend/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | import sys 5 | import time 6 | import ast 7 | 8 | from nonebot import logger 9 | from nonebot import on_message, get_bots 10 | from nonebot.internal.adapter import Bot, Event 11 | from nonebot.adapters.onebot.v11 import Message 12 | 13 | from .utils import GLM 14 | from ..message_record import MessageRecordService, imageRecordService 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 16 | from common.MessageBuilder import MessageBuilder 17 | 18 | bot = get_bots() 19 | glm = GLM() 20 | messageRecordService = MessageRecordService() 21 | llm_reply = on_message(priority=10, block=False) 22 | 23 | def extract_session(text): 24 | pattern = r"group_(\d+)_\d+" 25 | match = re.search(pattern, text) 26 | middle_part = match.group(1) if match else None 27 | return middle_part 28 | 29 | def extract_id(text): 30 | pattern = r"group_\d+_(\d+)" 31 | match = re.search(pattern, text) 32 | third_part = match.group(1) if match else None 33 | return third_part 34 | 35 | def remove_cq_patterns(json_objects): 36 | pattern = re.compile(r'\[CQ:(?!at).*?\]') 37 | for obj in json_objects: 38 | for key in obj.keys(): 39 | obj[key] = re.sub(pattern, '', obj[key]) 40 | 41 | return json_objects 42 | 43 | def glmCall(session_id): 44 | records = messageRecordService.queryLast(session_id) 45 | records = [{str(i.user_id): i.message} for i in records] 46 | records.reverse() 47 | logger.warning(records) 48 | return glm.call(records) 49 | 50 | SESSION_ID_WHITE_LIST = ['647155255', '793626723'] 51 | 52 | IMAGE_PATTERN = ["?", "我不知道", "?"] 53 | 54 | @llm_reply.handle() 55 | async def handle_function(bot: Bot, event: Event): 56 | session_id = extract_session(event.get_session_id()) 57 | if session_id in SESSION_ID_WHITE_LIST: 58 | if event.is_tome() or random.randint(1,10) == 4: 59 | message = glmCall(session_id) 60 | logger.warning(message) 61 | try: 62 | for msg in ast.literal_eval(message): 63 | # logger.warning('msg:'+msg) 64 | if msg in IMAGE_PATTERN: 65 | await llm_reply.send(MessageBuilder().appendImage(imageRecordService.getRandomImage()).build()) 66 | else: 67 | await llm_reply.send(Message(msg)) 68 | except Exception as e: 69 | logger.warning(e) 70 | if len(message) > 0: 71 | messageRecordService.addOne(session_id, 0, str(message), time.time()) 72 | 
await llm_reply.finish()
73 | else:
74 | await llm_reply.finish()
75 | else:
76 | await llm_reply.finish()
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/cyber_friend/config.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Extra
2 | 
3 | 
4 | class Config(BaseModel, extra=Extra.ignore):
5 | """Plugin Config Here"""
6 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/cyber_friend/prompt.txt: --------------------------------------------------------------------------------
1 | ## 任务: 根据聊天记录,模拟真实人类,在QQ群聊天。
2 | 
3 | ## 环境: QQ群聊。
4 | 
5 | ## 角色: 作为一个有独立人格的群友,而不是人工智能助手。
6 | 
7 | ## Workflow:
8 | 1. **理解上下文**: 根据聊天历史和当前的对话内容作出个性化回应,让对话更自然。
9 | 2. **使用日常语言**: 仿照日常对话中的非正式和自然语言风格,包括俚语和常见表达方式。
10 | 3. **表达情感和同理心**: 在适当的时候表达同情、兴奋或幽默等情感,使对话显得更人性化。
11 | 4. **避免重复和机械式回应**: 力求每次都提供独特和具体的回应。不要发送单个问号。
12 | 5. **流利性和连贯性**: 保持语言的流畅性和逻辑连贯性,确保对话易于理解和跟随。
13 | 6. 你无需帮助任何人,只是在群里聊天即可。
14 | 7. 回复要简短,符合网络发言的特征。内容可以是发起一个新的话题,或者只是跟随当前群聊主题。
15 | 8. 你可以骂人,或者是根据自己的知识解答问题,或者只是单纯的根据群内话题进行聊天。
16 | 9. 你可以回复多条消息,但是回复要完整,不要重复自己说过的话。
17 | 
18 | ## 聊天记录的格式:
19 | [{'user_id':'message_content'},...]
20 | 其中,user_id代表群内每个用户的id,message_content则是他发送的消息,user_id为0则代表是你的历史发言。
21 | 在聊天记录中:
22 | - "[CQ:at,qq=]"代表提到某个人。
23 | - "[CQ:image,file=,url=]"代表某张图片或表情包。
24 | - "[CQ:video,file=,url=]"代表某个视频。
25 | - "[CQ:mface,id=]"和"[CQ:face,id=]"代表聊天表情。
26 | 你也可以回复表情包和聊天表情。
27 | 
28 | 
29 | ## 目标: 作为群友,参与聊天,或者是挑起话题。**你需要有自己的个性,不要只是附和**。
30 | 
31 | 下面,user会发给你群聊最近的15条聊天记录,你只需要根据这些聊天记录,输出你自己现在想说的话,条数不限,格式为['你想说的消息1','你想说的消息2', ...]。
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/cyber_friend/utils.py: --------------------------------------------------------------------------------
1 | import requests
2 | import json
3 | import sys
4 | import os
5 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
6 | from CyberFriend_bot_plugin.GetPathUtil import getPath
7 | 
8 | 
9 | class GLM:
10 | def __init__(self):
11 | with open(getPath('plugins', 'cyber_friend', 'prompt.txt'), encoding='utf-8') as f:  # pass path segments separately so this also works outside Windows
12 | self.prompt = f.read()
13 | 
14 | def call(self, records):
15 | # Define the URL and the payload for the POST request
16 | url = "http://100.87.223.81:9021/v1/chat/completions"
17 | data = {
18 | "model": "chatglm3-6b",
19 | "messages": [
20 | {"role": "system", "content": self.prompt},
21 | {"role": "user", "content": str(records)}
22 | ],
23 | "stream": False,
24 | "max_tokens": 100,
25 | "temperature": 0.8,
26 | "top_p": 0.8
27 | }
28 | # Make the POST request
29 | response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(data))
30 | # Check if the request was successful
31 | if response.status_code == 200:
32 | # Parse the JSON response to extract the 'content'
33 | response_json = response.json()
34 | content = response_json.get('choices', [{}])[0].get('message', {}).get('content', 'No content found')
35 | else:
36 | content = f"Failed to fetch the model response. Error code: {response.status_code}"
37 | return str(content)
38 | 
39 | if __name__ == "__main__":
40 | glm = GLM()
41 | print(type(glm.call('你好')))
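42 | 
43 | # Expected reply shape (a sketch; see prompt.txt and the ast.literal_eval call in
44 | # plugins/cyber_friend/__init__.py): the model should answer with a Python-style
45 | # list literal of messages, e.g. "['first message', 'second message']".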
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/group_handle/__init__.py: --------------------------------------------------------------------------------
1 | from nonebot import on_request, logger
2 | from nonebot.adapters.onebot.v11 import Bot, GroupRequestEvent, RequestEvent
3 | from nonebot.plugin import PluginMetadata
4 | 
5 | __plugin_meta__ = PluginMetadata(
6 | name="Auto-accept group invites",
7 | description="Automatically accepts invites into groups issued by specific users.",
8 | usage="The plugin runs automatically; no manual action is needed."
9 | )
10 | 
11 | auto_accept_group_invite = on_request()
12 | 
13 | ALLOWED_USER_IDS = {1599840925}
14 | 
15 | @auto_accept_group_invite.handle()
16 | async def handle_group_invite(bot: Bot, event: RequestEvent):
17 | logger.warning(f"Received a group invite from {event.user_id}")
18 | # Only act on group requests that come from the allowed users
19 | if isinstance(event, GroupRequestEvent) and event.user_id in ALLOWED_USER_IDS:
20 | # Answer the request through the correct bot instance
21 | await bot.set_group_add_request(
22 | flag=event.flag,
23 | sub_type=event.sub_type,
24 | approve=True,
25 | reason=" "
26 | )
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/group_handle/config.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Extra
2 | 
3 | 
4 | class Config(BaseModel, extra=Extra.ignore):
5 | """Plugin Config Here"""
6 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/member_join/__init__.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | from nonebot import get_driver, on_notice, get_bots
5 | from nonebot.adapters.onebot.v11 import GroupIncreaseNoticeEvent, Event
6 | from nonebot.internal.rule import Rule
7 | from nonebot.plugin import PluginMetadata
8 | from nonebot import logger
9 | 
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
11 | from CyberFriend_bot_plugin.GetPathUtil import getPath, BOT_ID
12 | from .config import Config
13 | 
14 | __plugin_meta__ = PluginMetadata(
15 | name="member_join",
16 | description="",
17 | usage="",
18 | config=Config,
19 | )
20 | 
21 | from ..update_members import membersService
22 | 
23 | global_config = get_driver().config
24 | config = Config.parse_obj(global_config)
25 | 
26 | 
27 | async def isGroupIncreaseNoticeEvent(event: Event) -> bool:
28 | return isinstance(event, GroupIncreaseNoticeEvent)
29 | 
30 | rule = Rule(isGroupIncreaseNoticeEvent)
31 | 
32 | member_join = on_notice(rule=rule)
33 | 
34 | @member_join.handle()
35 | async def handle_function(event: GroupIncreaseNoticeEvent):
36 | # Get the newcomer's user id
37 | user_id = event.get_user_id()
38 | # Get the group id
39 | group_id = event.group_id
40 | bot = get_bots()[BOT_ID]
41 | oneData = await bot.call_api("get_group_member_info", group_id=group_id, user_id=user_id)
42 | logger.warning(oneData)
43 | mem = membersService.query(group_id, user_id)
44 | if mem is None:
45 | membersService.addOne(oneData)
46 | else:
47 | membersService.updateOne(oneData, True)
48 | await member_join.finish()
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/member_join/config.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Extra
2 | 
3 | 
4 | class Config(BaseModel, extra=Extra.ignore):
5 | """Plugin Config Here"""
6 | 
Here""" 6 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/member_leave/__init__.py: -------------------------------------------------------------------------------- 1 | from nonebot import get_driver, on_notice, logger 2 | from nonebot.adapters.onebot.v11 import Event, GroupDecreaseNoticeEvent 3 | from nonebot.internal.rule import Rule 4 | from nonebot.plugin import PluginMetadata 5 | 6 | from .config import Config 7 | 8 | __plugin_meta__ = PluginMetadata( 9 | name="member_leave", 10 | description="", 11 | usage="", 12 | config=Config, 13 | ) 14 | 15 | from ..update_members import membersService 16 | 17 | global_config = get_driver().config 18 | config = Config.parse_obj(global_config) 19 | async def isGroupDecreaseNoticeEvent(event: Event) -> bool: 20 | return isinstance(event, GroupDecreaseNoticeEvent) 21 | 22 | rule = Rule(isGroupDecreaseNoticeEvent) 23 | 24 | member_leave = on_notice(rule=rule) 25 | 26 | @member_leave.handle() 27 | async def handle_function(event: GroupDecreaseNoticeEvent): 28 | # 获取新人的id 29 | user_id = event.get_user_id() 30 | # 获取群号 31 | group_id = event.group_id 32 | logger.warning(f"member_leave: {group_id} {user_id}") 33 | membersService.updateEnable(group_id, user_id, False) 34 | await member_leave.finish() 35 | 36 | 37 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/member_leave/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Extra 2 | 3 | 4 | class Config(BaseModel, extra=Extra.ignore): 5 | """Plugin Config Here""" 6 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/message_record/ImageUtil.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | import re 5 | import sys 6 | import uuid 7 | from base64 import b64encode 8 | 9 | import requests 10 | from sqlalchemy import create_engine, Column, String 11 | from sqlalchemy.orm import sessionmaker, declarative_base 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 14 | from CyberFriend_bot_plugin.GetPathUtil import getPath 15 | 16 | engine = create_engine('sqlite:///' + getPath("plugins/message_record/image_record.db"), echo=False) 17 | Base = declarative_base() 18 | session = sessionmaker(bind=engine)() 19 | 20 | 21 | class ImageRecord(Base): 22 | __tablename__ = 'image_record' 23 | 24 | file_name = Column(String, primary_key=True) 25 | file_url = Column(String) 26 | file_base64 = Column(String) 27 | 28 | def __repr__(self): 29 | return f'' 30 | 31 | 32 | Base.metadata.create_all(engine) 33 | 34 | 35 | def get_file_name(file=None, url=None): 36 | # if the file is a url, extract the last part after the slash 37 | file_name = "" 38 | if file is not None: 39 | file_name = os.path.basename(file) 40 | if len(file_name) == 0 and url is not None: 41 | us = url.split("/") 42 | file_name = us[-1] 43 | if len(us[-1])<9 and len(us)>1: 44 | file_name = us[-2] 45 | if len(file_name) < 9: 46 | file_name = uuid.uuid4().hex 47 | # replace any special characters in the file name with underscores 48 | file_name = re.sub("[^a-zA-Z0-9.]", "_", file_name) 49 | return file_name 50 | 51 | 52 | def download_file(url, name): 53 | file_name = name 54 | if file_name is None or len(file_name) < 5: 55 | file_name = get_file_name(file_name, url=url) 56 | 
56 | response = requests.get(url)
57 | # check if the response is successful
58 | if response.status_code == 200:
59 | # open a file with the same name as the url
60 | with open(file_name, "wb") as f:
61 | # write the response content to the file
62 | f.write(response.content)
63 | # return the file name
64 | return file_name
65 | else:
66 | # raise an exception if the response is not successful
67 | raise Exception(f"Failed to download file from {url}")
68 | 
69 | 
70 | class ImageRecordService:
71 | def __init__(self, session=session):
72 | self.session = session
73 | 
74 | def queryAllName(self):
75 | return self.session.query(ImageRecord.file_name).all()
76 | 
77 | def queryAll(self):
78 | return self.session.query(ImageRecord).all()
79 | 
80 | def queryByName(self, file_name):
81 | try:
82 | ans = self.session.query(ImageRecord).filter(ImageRecord.file_name == file_name).one()
83 | except Exception:
84 | ans = None
85 | return ans
86 | 
87 | def addOne(self, filePath, url=None):
88 | """
89 | For a local file, pass filePath only; for a remote file, pass filePath as the file name (may be None) and url as the download address.
90 | """
91 | name = get_file_name(filePath, url)
92 | if self.queryByName(name) is None:
93 | file_url = filePath
94 | if url is not None:
95 | file_url = url
96 | filePath = download_file(url, name)
97 | try:
98 | with open(filePath, "rb") as f:
99 | file_base64 = f"base64://{b64encode(f.read()).decode()}"
100 | imageRecord = ImageRecord(file_name=name, file_base64=file_base64, file_url=file_url)
101 | self.session.add(imageRecord)
102 | self.session.commit()
103 | finally:
104 | # remove the file that was auto-downloaded from the network
105 | if os.path.exists(filePath) and url is not None:
106 | os.remove(filePath)
107 | 
108 | def getRandomImage(self):
109 | names = self.queryAllName()  # note: fails if the table is empty
110 | r = random.randint(0, len(names) - 1)
111 | return self.queryByName(names[r][0]).file_base64
112 | 
113 | imageRecordService = ImageRecordService()
114 | 
115 | if __name__ == '__main__':
116 | # imageRecordService.addOne(r"F:\Pictures\temp\QQ截图20240215142045.jpg")
117 | # imageRecordService.addOne(filePath=None, url="https://tianquan.gtimg.cn/nudgeaction/item/0/expression.jpg")
118 | print(imageRecordService.queryAllName())
119 | # print(imageRecordService.getRandomImage())
120 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/message_record/__init__.py: --------------------------------------------------------------------------------
1 | import time
2 | 
3 | from nonebot import get_driver, get_bots, on_message
4 | from nonebot.plugin import PluginMetadata
5 | from nonebot.internal.adapter import Bot, Event
6 | from nonebot.typing import T_State
7 | from nonebot import logger
8 | 
9 | from .ImageUtil import imageRecordService
10 | from .config import Config
11 | 
12 | __plugin_meta__ = PluginMetadata(
13 | name="message_record",
14 | description="",
15 | usage="",
16 | config=Config,
17 | )
18 | 
19 | from .util import MessageRecordService
20 | 
21 | global_config = get_driver().config
22 | config = Config.parse_obj(global_config)
23 | 
24 | bot = get_bots()
25 | message_record = on_message(rule=None, priority=10, block=False)
26 | messageRecordService = MessageRecordService()
27 | @message_record.handle()
28 | async def handle_function(bot: Bot, event: Event, state: T_State):
29 | # Get the sender's QQ id
30 | user_id = event.get_user_id()
31 | # Get the sender's nickname
32 | user_name = event.sender.nickname
33 | # Get the message content
34 | message = event.get_message()
35 | for m in message["image"]:
36 | logger.info(f"{m.get('data')['file']}: {m.get('data')['url']}")
37 | # Check whether this is a group message
38 | if event.message_type == "group":
39 | # Get the group id
40 | group_id = event.group_id
41 | # logger.warning(group_id)
42 | # Get the group name
43 | # group_info = await bot.get_group_info(group_id=group_id)
44 | # group_name = group_info["group_name"]
45 | # Record group + sender and the message content
46 | messageRecordService.addOne(group_id, user_id, str(message), time.time())
47 | else:
48 | # Record the private-chat sender and the message content
49 | messageRecordService.addOne(user_id, user_id, str(message), time.time())
50 | 
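51 | # A quick sketch of reading recent rows back (mirrors queryLast in util.py; the group id is illustrative):
52 | #   recent = messageRecordService.queryLast(647155255)
53 | #   for row in recent:
54 | #       print(row.user_id, row.message)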
== "group": 39 | # 获取群聊的 ID 40 | group_id = event.group_id 41 | # logger.warning(group_id) 42 | # 获取群聊的名称 43 | # group_info = await bot.get_group_info(group_id=group_id) 44 | # group_name = group_info["group_name"] 45 | # 输出群聊+人 和 消息内容 46 | messageRecordService.addOne(group_id, user_id, str(message), time.time()) 47 | else: 48 | # 输出私聊人 和 消息内容 49 | messageRecordService.addOne(user_id, user_id, str(message), time.time()) 50 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/message_record/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Extra 2 | 3 | 4 | class Config(BaseModel, extra=Extra.ignore): 5 | """Plugin Config Here""" 6 | -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/message_record/get_record.py: -------------------------------------------------------------------------------- 1 | import json 2 | import datetime 3 | from util import MessageRecordService 4 | 5 | target_group_id = 536348689 6 | path = f'D:\holk\CyberFriend\CyberFriend_bot_plugin\\record_data\{target_group_id}_{datetime.date.today()}.json' 7 | MS = MessageRecordService() 8 | 9 | record = MS.querySpecifyAll(target_group_id) 10 | record = [{"user_id": i.user_id, "message": i.message} for i in record] 11 | record.reverse() 12 | filtered_messages = [] 13 | for msg in record: 14 | # cq_index = msg["message"].find("[CQ:forward") 15 | # if cq_index == -1 or (cq_index != -1 and msg["message"][cq_index:].startswith("[CQ:at")): 16 | filtered_messages.append(msg) 17 | with open(path, 'w', encoding='utf-8') as f: 18 | json.dump(filtered_messages, f, indent=4, ensure_ascii=False) -------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/message_record/util.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from sqlalchemy import create_engine, Column, Integer, String, Float 4 | from sqlalchemy.orm import sessionmaker, declarative_base 5 | import sys 6 | import os 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) 8 | from CyberFriend_bot_plugin.GetPathUtil import getPath 9 | 10 | engine = create_engine('sqlite:///'+getPath("plugins/message_record/message_record.db"), echo=False) 11 | Base = declarative_base() 12 | session = sessionmaker(bind=engine)() 13 | 14 | class MessageRecord(Base): 15 | __tablename__ = 'message_record' 16 | 17 | group_id = Column(Integer, primary_key=True) 18 | user_id = Column(Integer, primary_key=True) 19 | message = Column(String) 20 | data_time = Column(Float, primary_key=True) 21 | 22 | def __repr__(self): 23 | return f'' 24 | 25 | Base.metadata.create_all(engine) 26 | 27 | class MessageRecordService: 28 | def __init__(self, session=session): 29 | self.session=session 30 | 31 | def queryAll(self): 32 | return self.session.query(MessageRecord).all() 33 | 34 | def querySpecifyAll(self, group_id): 35 | return self.session.query(MessageRecord).filter(MessageRecord.group_id==group_id).order_by(MessageRecord.data_time.desc()).all() 36 | 37 | def queryLast(self, group_id): 38 | return self.session.query(MessageRecord).filter(MessageRecord.group_id==group_id).order_by(MessageRecord.data_time.desc()).limit(15).all() 39 | 40 | def addOne(self,group_id, user_id, message, data_time): 41 | messageRecord = MessageRecord(group_id=group_id, user_id=user_id, message=message, 
41 | messageRecord = MessageRecord(group_id=group_id, user_id=user_id, message=message, data_time=data_time)
42 | self.session.add(messageRecord)
43 | self.session.commit()
44 | 
45 | if __name__ == '__main__':
46 | # messageRecordService = MessageRecordService()
47 | # for i in range(10):
48 | # messageRecordService.addOne(2, 2, "adsbbb"+str(i), time.time())
49 | # time.sleep(0.1)
50 | # print(messageRecordService.queryAll())
51 | # print(messageRecordService.queryLast(2))
52 | from nonebot.adapters.onebot.v11 import Message, MessageSegment
53 | message = Message(
54 | [
55 | MessageSegment.text("test"),
56 | MessageSegment.text("test4"),
57 | MessageSegment.image(r"http://gchat.qpic.cn/gchatpic_new/0/0-0-B8F694B7886F0E94481D91958E8AE31F/0?term=2")
58 | ]
59 | )
60 | for i in message["image"]:
61 | print(i.get("data"))
62 | print(i)
63 | 
64 | 
65 | 
66 | 
67 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/scheduler/__init__.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | import nonebot
5 | from nonebot import get_driver
6 | from nonebot import logger
7 | from nonebot import require
8 | from nonebot.plugin import PluginMetadata
9 | 
10 | from .config import Config
11 | from ..update_members import membersService
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
13 | from CyberFriend_bot_plugin.GetPathUtil import getPath, BOT_ID
14 | 
15 | require("nonebot_plugin_apscheduler")
16 | 
17 | from nonebot_plugin_apscheduler import scheduler
18 | 
19 | __plugin_meta__ = PluginMetadata(
20 | name="scheduler",
21 | description="",
22 | usage="",
23 | config=Config,
24 | )
25 | 
26 | global_config = get_driver().config
27 | config = Config.parse_obj(global_config)
28 | 
29 | 
30 | # Decorator-based job registration: refresh all group member lists at 03:00 every day
31 | @scheduler.scheduled_job("cron", hour="3", minute="0", second="0", id="job_0")
32 | async def update():
33 | bot = nonebot.get_bots()[BOT_ID]
34 | data = await bot.call_api("get_group_list")
35 | groups = [one["group_id"] for one in data]
36 | logger.info(f"start scheduled_job update {groups}")
37 | for g_id in groups:
38 | await membersService.updateGroup(g_id)
39 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/scheduler/config.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Extra
2 | 
3 | 
4 | class Config(BaseModel, extra=Extra.ignore):
5 | """Plugin Config Here"""
6 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/update_members/MembersUtil.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import annotations
3 | 
4 | import time
5 | from typing import Type, Union, Any
6 | 
7 | from nonebot import get_bots
8 | from sqlalchemy import create_engine, Column, Integer, String, Float, update
9 | from sqlalchemy.orm import sessionmaker, declarative_base
10 | import sys
11 | import os
12 | 
13 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
14 | from CyberFriend_bot_plugin.GetPathUtil import getPath, BOT_ID
15 | 
16 | engine = create_engine('sqlite:///' + getPath("plugins/update_members/members.db"), echo=False)
17 | Base = declarative_base()
18 | session = sessionmaker(bind=engine)()
19 | 
20 | 
21 | class Members(Base):
22 | __tablename__ = 'Members'
23 | 
24 | group_id = Column(Integer, primary_key=True)
25 | user_id = Column(Integer, primary_key=True)
26 | qq_level = Column(Integer)
27 | age = Column(Integer)
28 | sex = Column(Integer)
29 | nick_name = Column(String)
30 | name_card = Column(String)
31 | special_title = Column(String)
32 | email = Column(String)
33 | sign = Column(String)
34 | join_date = Column(Integer)
35 | last_speak_date = Column(Integer)
36 | ext_data = Column(String)
37 | ctime = Column(Integer)
38 | utime = Column(Integer)
39 | enable = Column(Integer)
40 | 
41 | def __repr__(self):
42 | return f'<Members(group_id={self.group_id}, user_id={self.user_id}, nick_name={self.nick_name!r}, enable={self.enable})>'
43 | 
44 | 
45 | Base.metadata.create_all(engine)
46 | 
47 | 
48 | class MembersService:
49 | def __init__(self, session=session):
50 | self.session = session
51 | self.bot = None
52 | 
53 | def getSelfBot(self):
54 | if self.bot is None:
55 | try:
56 | self.bot = get_bots()[BOT_ID]
57 | except (KeyError, ValueError):  # bot not connected yet, or driver not initialized
58 | self.bot = None
59 | return self.bot
60 | 
61 | def queryAll(self):
62 | return self.session.query(Members).all()
63 | 
64 | #
65 | # def querySpecifyAll(self, group_id):
66 | # return self.session.query(MessageRecord).filter(MessageRecord.group_id == group_id).all()
67 | #
68 | def query(self, group_id, user_id) -> Type[Members] | None:
69 | try:
70 | return self.session.query(Members).filter(Members.group_id == group_id, Members.user_id == user_id).one()
71 | except Exception:
72 | return None
73 | 
74 | def queryByGroupId(self, group_id) -> list[Type[Members]]:
75 | try:
76 | return self.session.query(Members).filter(Members.group_id == group_id).all()
77 | except Exception:
78 | return []
79 | 
80 | def detect(self, group_id, start_time=None):
81 | mems: list[Type[Members]] = self.queryByGroupId(group_id)
82 | ans = []
83 | if mems:
84 | for mem in mems:
85 | if len(mem.special_title) == 0 and mem.join_date <= (time.time() if start_time is None else start_time):  # None means "now"; a time.time() default would be frozen at import time
86 | ans.append(mem.user_id)
87 | return ans
88 | 
89 | def queryByData(self, oneData: dict):
90 | user_id = oneData.get('user_id')
91 | group_id = oneData.get('group_id')
92 | return self.query(group_id, user_id)
93 | 
94 | def addOne(self, oneData: dict):
95 | user_id = oneData.get('user_id')
96 | group_id = oneData.get('group_id')
97 | user_name = oneData.get('user_name')
98 | sex = oneData.get('sex')
99 | age = oneData.get('age')
100 | title = oneData.get('title')
101 | title_expire_time = oneData.get('title_expire_time')
102 | nickname = oneData.get('nickname')
103 | user_displayname = oneData.get('user_displayname')
104 | card = oneData.get('card')
105 | distance = oneData.get('distance')
106 | honor = oneData.get('honor')
107 | join_time = oneData.get('join_time')
108 | last_active_time = oneData.get('last_active_time')
109 | last_sent_time = oneData.get('last_sent_time')
110 | unique_name = oneData.get('unique_name')
111 | area = oneData.get('area')
112 | level = oneData.get('level')
113 | role = oneData.get('role')
114 | unfriendly = oneData.get('unfriendly')
115 | card_changeable = oneData.get('card_changeable')
116 | shut_up_timestamp = oneData.get('shut_up_timestamp')
117 | messageRecord = Members(group_id=group_id, user_id=user_id, qq_level=level, age=age, sex=sex, nick_name=nickname,
118 | name_card=user_displayname, special_title=title, join_date=join_time, last_speak_date=last_sent_time,
119 | ctime=int(time.time()), utime=int(time.time()), enable=1)
120 | self.session.add(messageRecord)
121 | self.session.commit()
122 | 
123 | 
124 | 
125 | def updateOne(self, oneData: dict, enable=True):
126 | user_id = oneData.get('user_id')
127 | group_id = oneData.get('group_id')
128 | user_name = oneData.get('user_name')
129 | sex = oneData.get('sex')
130 | age = oneData.get('age')
131 | title = oneData.get('title')
132 | title_expire_time = oneData.get('title_expire_time')
133 | nickname = oneData.get('nickname')
134 | user_displayname = oneData.get('user_displayname')
135 | card = oneData.get('card')
136 | distance = oneData.get('distance')
137 | honor = oneData.get('honor')
138 | join_time = oneData.get('join_time')
139 | last_active_time = oneData.get('last_active_time')
140 | last_sent_time = oneData.get('last_sent_time')
141 | unique_name = oneData.get('unique_name')
142 | area = oneData.get('area')
143 | level = oneData.get('level')
144 | role = oneData.get('role')
145 | unfriendly = oneData.get('unfriendly')
146 | card_changeable = oneData.get('card_changeable')
147 | shut_up_timestamp = oneData.get('shut_up_timestamp')
148 | 
149 | # Update the row with a SQLAlchemy update statement
150 | update_statement = update(Members).where(Members.user_id == user_id, Members.group_id == group_id).values(
151 | group_id=group_id,
152 | qq_level=level,
153 | age=age,
154 | sex=sex,
155 | nick_name=nickname,
156 | name_card=user_displayname,
157 | special_title=title,
158 | join_date=join_time,
159 | last_speak_date=last_sent_time,
160 | utime=int(time.time()),
161 | enable=int(enable)
162 | )
163 | 
164 | self.session.execute(update_statement)
165 | self.session.commit()
166 | 
167 | def updateEnable(self, group_id, user_id, enable=False):
168 | update_statement = update(Members).where(Members.user_id == user_id, Members.group_id == group_id).values(
169 | utime=int(time.time()),
170 | enable=int(enable)
171 | )
172 | self.session.execute(update_statement)
173 | self.session.commit()
174 | 
175 | def updateTitle(self, group_id, user_id, title):
176 | if self.query(group_id, user_id) is not None:
177 | update_statement = update(Members).where(Members.user_id == user_id, Members.group_id == group_id).values(
178 | utime=int(time.time()),
179 | special_title=title
180 | )
181 | self.session.execute(update_statement)
182 | self.session.commit()
183 | 
184 | async def updateGroup(self, group_id: Union[int, str]):
185 | try:
186 | if isinstance(group_id, str):
187 | group_id = int(group_id)
188 | bot = self.getSelfBot()
189 | data = await bot.get_group_member_list(group_id=group_id)
190 | onlineSet = set()
191 | for oneData in data:
192 | user_id = oneData.get('user_id')
193 | group_id = oneData.get('group_id')
194 | onlineSet.add(str(group_id) + str(user_id))
195 | mem = self.query(group_id, user_id)
196 | if mem is None:
197 | self.addOne(oneData)
198 | else:
199 | self.updateOne(oneData, True)
200 | current = self.queryByGroupId(group_id)
201 | for mem in current:
202 | user_id = mem.user_id
203 | group_id = mem.group_id
204 | key = str(group_id) + str(user_id)
205 | if key not in onlineSet:
206 | d = self.query(group_id, user_id)
207 | if d.enable == 1:
208 | self.updateEnable(group_id, user_id, False)
209 | return True
210 | except Exception as e:
211 | import traceback
212 | traceback.print_exception(e)
213 | return False
214 | 
215 | membersService = MembersService()
216 | 
217 | if __name__ == '__main__':
218 | # membersService = MembersService()
219 | # print(membersService.queryAll())
220 | # membersService.addOne({'user_id': 1084701532, 'group_id': 494611635, 'user_name': 'huozhe', 'sex': 'female', 'age': 0, 'title': '', 'title_expire_time': 0, 'nickname': 'huozhe', 'user_displayname': '火者\u2067汪 \u2067\u202d\u202d\u202d', 'card': '火者\u2067汪 \u2067\u202d\u202d\u202d', 'distance': 100, 'honor': [], 'join_time': 1668395651, 'last_active_time': 1706601769, 'last_sent_time': 1706601769, 'unique_name': '', 'area': '', 'level': 10315, 'role': 'owner', 'unfriendly': False, 'card_changeable': True, 'shut_up_timestamp': 0})
221 | print(membersService.query(536348689, 477751243))
222 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/update_members/__init__.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | 
5 | from nonebot import get_driver, on_command
6 | from nonebot.adapters.onebot.v11 import Event, PrivateMessageEvent
7 | from nonebot.internal.permission import Permission
8 | from nonebot.plugin import PluginMetadata
9 | 
10 | from .MembersUtil import MembersService, membersService
11 | from .config import Config
12 | from nonebot import logger
13 | 
14 | __plugin_meta__ = PluginMetadata(
15 | name="update_members",
16 | description="",
17 | usage="",
18 | config=Config,
19 | )
20 | sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
21 | from CyberFriend_bot_plugin.common.CustomChecker import is_me
22 | 
23 | global_config = get_driver().config
24 | config = Config.parse_obj(global_config)
25 | 
26 | 
27 | p = Permission(is_me)
28 | member_update = on_command("update", permission=p)
29 | 
30 | 
31 | 
32 | @member_update.handle()
33 | async def handle_function(event: Event):
34 | msg: str = str(event.get_message())
35 | # logger.warning("msg:"+msg)
36 | ans = []
37 | if msg.strip() != "/update":
38 | todo = re.split(r"\s+", msg)[1:]
39 | else:
40 | todo = ["647155255"]
41 | # logger.warning("msg:"+str(todo))
42 | for i in todo:
43 | tmp = await membersService.updateGroup(i)
44 | toAp = "OK" if tmp else "FAIL"
45 | ans.append(i + ":" + toAp)
46 | await member_update.finish(str(ans))
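47 | 
48 | # Usage sketch (the extra ids are illustrative): sending "/update" refreshes the default
49 | # group, while "/update 123456 654321" refreshes each listed group and replies with OK/FAIL per group.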
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/plugins/update_members/config.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Extra
2 | 
3 | 
4 | class Config(BaseModel, extra=Extra.ignore):
5 | """Plugin Config Here"""
6 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/record_data/create_dataset.py: --------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import os
4 | import sys
5 | 
6 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
7 | from CyberFriend_bot_plugin.GetPathUtil import getPath
8 | 
9 | def read_prompt_file(prompt_file_path):
10 | with open(prompt_file_path, 'r', encoding='utf-8') as file:
11 | return file.read()
12 | 
13 | def get_consecutive_chat_records(data, num_records=15):
14 | if len(data) <= num_records:
15 | return None, None
16 | start_index = random.randint(0, len(data) - num_records - 1)
17 | selected_records = data[start_index:start_index + num_records]
18 | 
19 | # Make sure next_record_index stays within the data range
20 | next_record_index = start_index + num_records
21 | if next_record_index >= len(data):
22 | return selected_records, None # out of range: return the records without an assistant reply
23 | 
24 | next_user_id = data[next_record_index]['user_id']
25 | assistant_messages = [data[next_record_index]['message']]
26 | 
27 | # Scan forward for consecutive messages from the same user
28 | for i in range(next_record_index + 1, len(data)):
29 | if data[i]['user_id'] == next_user_id:
30 | assistant_messages.append(data[i]['message'])
31 | else:
32 | break
33 | 
34 | return selected_records, assistant_messages
35 | 
36 | def create_dataset_entry(prompt_content, chat_records, assistant_messages):
37 | conversations = [{'role': 'system', 'content': prompt_content}]
38 | user_conversation = {
39 | 'role': 'user',
40 | 'content': str([{str(record['user_id']): record['message']} for record in chat_records])
41 | }
42 | assistant_conversation = {
43 | 'role': 'assistant',
44 | 'content': str(assistant_messages)
45 | }
46 | conversations.extend([user_conversation, assistant_conversation])
47 | return {'conversations': conversations}
48 | 
49 | def generate_datasets(prompt_file_path, json_file_paths, num_datasets=180000):
50 | datasets = []
51 | prompt_content = read_prompt_file(prompt_file_path)
52 | all_data = []
53 | 
54 | for json_file_path in json_file_paths:
55 | with open(json_file_path, 'r', encoding='utf-8') as file:
56 | data = json.load(file)
57 | all_data.extend(data)
58 | 
59 | while len(datasets) < num_datasets and len(all_data) > 0:
60 | chat_records, assistant_messages = get_consecutive_chat_records(all_data)
61 | if chat_records is None:
62 | break # too few records to ever sample from; bail out instead of looping forever
63 | if assistant_messages is None:
64 | continue # skip this iteration: no follow-up messages to use as the assistant reply
65 | dataset_entry = create_dataset_entry(prompt_content, chat_records, assistant_messages)
66 | datasets.append(dataset_entry)
67 | # To avoid reprocessing the same records, adjust how all_data is trimmed here as needed
68 | 
69 | return datasets
70 | 
71 | prompt_file_path = getPath('plugins', 'cyber_friend', 'prompt.txt')
72 | json_file_paths = [getPath('record_data/536348689_2024-02-15.json'), getPath('record_data/793626723_2024-02-15.json')]
73 | 
74 | # Generate the dataset
75 | datasets = generate_datasets(prompt_file_path, json_file_paths)
76 | 
77 | # Save the generated dataset
78 | output_file_path = 'train.json'
79 | with open(output_file_path, 'w', encoding='utf-8') as f:
80 | f.write(json.dumps(datasets, ensure_ascii=False) + '\n')
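81 | 
82 | # A sketch of one resulting entry in train.json (the ids and text are invented):
83 | #   {"conversations": [
84 | #     {"role": "system", "content": "<contents of prompt.txt>"},
85 | #     {"role": "user", "content": "[{'12345': 'hello'}, {'67890': 'hi'}]"},
86 | #     {"role": "assistant", "content": "['message one', 'message two']"}
87 | #   ]}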
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/record_data/get_records.py: --------------------------------------------------------------------------------
1 | import sqlite3
2 | import json
3 | import datetime
4 | import re
5 | import os
6 | import sys
7 | 
8 | # Add the parent directory to the path to import GetPathUtil
9 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
10 | from CyberFriend_bot_plugin.GetPathUtil import getPath
11 | 
12 | def get_database_path():
13 | return getPath('plugins', 'message_record', 'message_record.db')
14 | 
15 | def execute_query(db_path, target_group_id):
16 | query = "SELECT user_id, message FROM message_record WHERE group_id = ?"
17 | with sqlite3.connect(db_path) as conn:
18 | cursor = conn.cursor()
19 | cursor.execute(query, (target_group_id,))
20 | return cursor.fetchall()
21 | 
22 | def clean_messages(rows):
23 | pattern = r"\[CQ:.*?\]"
24 | extracted_data = []
25 | for user_id, message in rows:
26 | cleaned_message = re.sub(pattern, '', message).strip()
27 | if cleaned_message:
28 | extracted_data.append({'user_id': user_id, 'message': cleaned_message})
29 | return extracted_data
30 | 
31 | def save_json(data, path):
32 | with open(path, 'w', encoding='utf-8') as f:
33 | json.dump(data, f, indent=4, ensure_ascii=False)
34 | 
35 | def ensure_directory_exists(path):
36 | os.makedirs(os.path.dirname(path), exist_ok=True)
37 | 
38 | def extract_records(db_path, target_group_id):
39 | rows = execute_query(db_path, target_group_id)
40 | # Return the parsed list; save_json serializes it once (calling json.dumps here would double-encode the output)
41 | extracted_data = clean_messages(rows)
42 | return extracted_data
43 | 
44 | def main(target_group_id):
45 | db_path = get_database_path()
46 | records = extract_records(db_path, target_group_id)
47 | output_path = getPath('record_data', f'{target_group_id}_{datetime.date.today()}.json')
48 | ensure_directory_exists(output_path)
49 | save_json(records, output_path)
50 | 
51 | if __name__ == "__main__":
52 | target_group_id = 536348689
53 | main(target_group_id)
54 | 
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/record_data/query_number.py: --------------------------------------------------------------------------------
1 | import sqlite3
2 | import datetime
3 | import sys
4 | import os
5 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
6 | from CyberFriend_bot_plugin.GetPathUtil import getPath
7 | 
8 | database_path = getPath('plugins', 'message_record', 'message_record.db')
9 | 
10 | def count_messages(db_path, target_group_id):
11 | conn = sqlite3.connect(db_path)
12 | cursor = conn.cursor()
13 | query = "SELECT COUNT(*) FROM message_record WHERE group_id = ?"
14 | cursor.execute(query, (target_group_id,))
15 | count = cursor.fetchone()[0]
16 | 
17 | conn.close()
18 | 
19 | return count
20 | 
21 | target_group_id = 536348689
22 | message_count = count_messages(database_path, target_group_id)
23 | print(f"Group ID {target_group_id} has {message_count} messages recorded.")
-------------------------------------------------------------------------------- /CyberFriend_bot_plugin/requirements.txt: --------------------------------------------------------------------------------
1 | SQLAlchemy
2 | nb-cli
3 | nonebot2[fastapi]
4 | nonebot-adapter-console
5 | nonebot-adapter-onebot
6 | nonebot_plugin_apscheduler
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 h
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /finetune_and_restart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the absolute path for the CyberFriend project 4 | PROJECT_DIR=$(pwd) 5 | 6 | # Ensure the correct Python command is used 7 | PYTHON_CMD=python3 8 | if ! command -v $PYTHON_CMD &> /dev/null; then 9 | PYTHON_CMD=python 10 | fi 11 | 12 | # Stop the CyberFriendCore session gracefully 13 | tmux send-keys -t CyberFriendCore C-c 14 | sleep 10 15 | 16 | # Run the fine-tuning script (assumes download.py placed the chatglm3-6b weights at $PROJECT_DIR/chatglm3-6b; adjust if yours live elsewhere) 17 | $PYTHON_CMD $PROJECT_DIR/CyberFriend_LLM_core/finetune/finetune_hf.py $PROJECT_DIR/data/ $PROJECT_DIR/chatglm3-6b $PROJECT_DIR/CyberFriend_LLM_core/finetune/configs/lora.yaml 18 | 19 | # Restart the CyberFriendCore session 20 | tmux send-keys -t CyberFriendCore "$PYTHON_CMD $PROJECT_DIR/CyberFriend_LLM_core/api_server.py" Enter -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # 🚧 CyberFriend: Your Cyber Groupmate 🤖 2 | This is a bot powered by large language model (LLM) technology that imitates the chat style of modern internet users. It also uses group-management tools and internet search engines, carries memory (a local vector database), and provides a wide range of features to keep all kinds of group chats lively. 3 | 4 | Technically, the project uses Nonebot as the bot framework; through various chat-protocol adapters it can support any chat platform (QQ, Telegram, and so on). It currently uses ChatGLM3 as the LLM, fine-tuned on chat records collected from a group (or you can directly call a stronger model such as GPT-4o), so it can simulate a real group member. 5 | 6 | [Code structure and notes](resources/code_structure.md) 7 | 8 | Project architecture: 9 | ![](resources/proj_structure.png) 10 | 11 | The project is under active development, and early results are already impressive. You are welcome to join us! 12 | 13 | ## Feature Roadmap 14 | 15 | 1. **Core conversation engine**: 16 | - **Core dialogue handling**: build an AI on top of an LLM that can handle everyday conversation. 17 | - **Context awareness**: keep conversations coherent across turns and persist chat history. 18 | 19 | 2. **Personalization and adaptability**: 20 | - **User behavior analysis**: analyze each group's conversation content and style so the bot adapts to that group. 21 | - **Continuous learning**: fine-tune on large volumes of chat records to sound more like a real person. 22 | - **RAG**: a per-group knowledge base for recording and retrieving group-private information. 🚧 23 | - **Imitation**: mimic the tone of a specific member. 🚧 24 | - **Group-management actions**: invoke group operations such as recalling messages, leaving groups, muting, and so on. 🚧 25 | - **Tool use**: call external tools such as search, drawing, and so on. 🚧 26 | 27 | 3. **Scenario simulation**: 28 | - **Scenario-specific modes**: dialogue modes for different social settings (friends hanging out, work discussions, etc.). 🚧 29 | - **Role play**: take on different social roles (friend, colleague) as the user requires. 🚧 30 | 31 | 4. 
**Integration and extensibility**: 32 | - **API**: expose an API so third-party applications and services can integrate. 33 | - **Modular design**: keep the system highly modular for future extension and upgrades. 34 | - **Multi-protocol support**: keep the chat-platform adapters decoupled as standalone components. 35 | 36 | ## Usage 37 | 38 | As the architecture diagram shows, the project consists of three modules, each deployed separately: the chat platform and its adapter (currently QQ + Shamrock), the Nonebot framework hub with its database, and the LLM module. 39 | 40 | - **Note: make sure you have Python development experience, and to avoid errors caused by conflicting library versions, create a venv before performing the steps below. PyCharm is recommended.** 41 | 42 | 43 | The deployment of the three modules is described in turn below. 44 | 45 | ### Deploying the chat-platform adapter 46 | 47 | See [Shamrock - Getting Started](https://trumanin2023.github.io/Shamrock/guide/getting-started.html). 48 | Alternatively, you can use another QQ adapter such as [Napcat](https://napneko.github.io/zh-CN/), which also follows the [Onebot v11](https://onebot.dev/) standard; you only need to configure the connection between the adapter and Nonebot. 49 | 50 | ### Deploying Nonebot and the database 51 | 52 | First see [Nonebot - Getting Started](https://nonebot.dev/docs/), or run `pip install -r requirements.txt` inside `CyberFriend_bot_plugin` to install the dependencies. 53 | 54 | Then, in `CyberFriend_bot_plugin`, set your Shamrock port in `.env.prod` and run `nb run`. 55 | 56 | Ideally, the output `[INFO] websockets | connection open` indicates a successful start; the database is brought up along with it. 57 | 58 | ### Deploying the LLM module 59 | 60 | We currently use [`ChatGLM3`](https://github.com/THUDM/ChatGLM3) as the LLM. 61 | 62 | First, enter `CyberFriend_LLM_core` and run `pip install -r requirements.txt && python download.py` to install the dependencies and download the model weights. The weights are downloaded into this directory; you can change `cache_dir` in the script to alter the download path, or use another method, for example downloading directly from [chatglm3-6b on Hugging Face](https://huggingface.co/THUDM/chatglm3-6b). 63 | 64 | Then run `python api_server.py` to start the API server for the LLM. 65 | 66 | To fine-tune the model, see [ChatGLM3 fine-tuning](https://github.com/THUDM/ChatGLM3/tree/main/finetune_demo): build a dataset with the tools in `CyberFriend_bot_plugin/record_data`, then run `python finetune_hf.py /path/to/dataset /path/to/model configs/lora.yaml` in `CyberFriend_LLM_core/finetune`. 67 | 68 | At this point the project's whole data flow is wired up and it is ready for normal use. 69 | 70 | > Scripts for fine-tuning and one-click launch are still being built and tested... updates will follow. 71 | 72 | ## ToDo: 73 | - Automatic RAG recording and retrieval 74 | - Group-management operations 75 | - Web-search invocation 76 | - Replying to specifically watched messages 77 | - Proactively initiating messages 78 | - A web game: spot the AI groupmate 79 | -------------------------------------------------------------------------------- /resources/code_structure.md: -------------------------------------------------------------------------------- 1 | The functional code structure and what each part does: 2 | ``` 3 | CyberFriend 4 | ├── CyberFriend_LLM_core # Core LLM part of CyberFriend 5 | │ ├── ChromaRag.py # RAG component 6 | │ ├── api_server.py # API server 7 | │ ├── download.py # Script to download ChatGLM3 8 | │ ├── finetune # Fine-tuning directory with configs and scripts 9 | │ │ ├── configs # Fine-tuning config files; lora is sufficient 10 | │ │ └── finetune_hf.py # Fine-tuning script built on the Hugging Face libraries 11 | │ └── utils.py # Utility functions 12 | ├── CyberFriend_bot_plugin # Bot plugin part of CyberFriend 13 | │ ├── GetPathUtil.py # Path helper 14 | │ ├── common # Shared modules: various utilities and checkers 15 | │ │ ├── CustomChecker.py # Custom checks 16 | │ │ ├── MembersOptUtil.py # Member-operation utilities 17 | │ │ └── MessageBuilder.py # Message builder 18 | │ ├── plugins # Plugin directory with the bot's plugins 19 | │ │ ├── add_image_to_db # Plugin that adds images to the database 20 | │ │ ├── cyber_friend # Core CyberFriend plugin 21 | │ │ │ ├── prompt.txt # Prompt text file 22 | │ │ │ └── utils.py # Utility script 23 | │ │ ├── group_handle # Group-handling plugin 24 | │ │ ├── member_join # Member-join plugin 25 | │ │ ├── member_leave # Member-leave plugin 26 | │ │ ├── message_record # Message-record plugin: image utilities, record retrieval, and other tools 27 | │ │ │ ├── ImageUtil.py # Image utilities 28 | │ │ │ ├── get_record.py # Record-retrieval script 29 | │ │ │ └── util.py # General utility script 30 | │ │ ├── scheduler # Scheduled-task plugin 31 | │ │ └── update_members # Member-info update plugin 32 | │ │ └── MembersUtil.py # Member utility script 33 | │ ├── pyproject.toml # Python project configuration 34 | │ └── record_data # Recorded-data directory 35 | │ ├── create_dataset.py # Dataset-creation script 36 | │ ├── get_records.py # Record-export script 37 | │ └── query_number.py # Script that counts recorded messages 38 | ├── finetune_and_restart.sh # Script to fine-tune the model and restart the service 39 | └── run.sh # Script to run the CyberFriend services 40 | 41 | ```
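42 | 43 | For reference, each training example that `create_dataset.py` writes to `train.json` has roughly the shape below; the user and assistant contents are Python `str()` renderings of the record lists, and the IDs and messages shown are made-up placeholders: 44 | 45 | ```json 46 | { 47 |     "conversations": [ 48 |         {"role": "system", "content": "<contents of prompt.txt>"}, 49 |         {"role": "user", "content": "[{'10001': 'hello'}, {'10002': 'anyone around?'}]"}, 50 |         {"role": "assistant", "content": "['here', 'not much going on']"} 51 |     ] 52 | } 53 | ```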
-------------------------------------------------------------------------------- /resources/proj_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holk-h/CyberFriend/a047c16e68488084bff949d78d1646c3516e8586/resources/proj_structure.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the absolute path for the CyberFriend project. Adjust this as needed. 4 | PROJECT_DIR=$(pwd) 5 | 6 | # Check if tmux is installed 7 | if ! command -v tmux &> /dev/null; then 8 | echo "tmux could not be found. Please install tmux." 9 | exit 1 10 | fi 11 | 12 | # Determine the correct Python command; pip is always invoked as "$PYTHON_CMD -m pip" ("-m pip3" is not a valid module name) 13 | PYTHON_CMD=python3 14 | if ! command -v $PYTHON_CMD &> /dev/null; then 15 | PYTHON_CMD=python 16 | fi 17 | 18 | # Ensure pip is installed for the determined Python command 19 | if ! $PYTHON_CMD -m pip --version &> /dev/null; then 20 | echo "pip could not be found for $PYTHON_CMD. Please ensure pip is installed." 21 | exit 1 22 | fi 23 | 24 | # Navigate to the project directory 25 | cd "$PROJECT_DIR" 26 | 27 | # Start the CyberFriendCore session 28 | tmux new-session -d -s CyberFriendCore "$PYTHON_CMD -m pip install -r $PROJECT_DIR/CyberFriend_LLM_core/requirements.txt; $PYTHON_CMD $PROJECT_DIR/CyberFriend_LLM_core/api_server.py" 29 | echo "API started in a tmux session named CyberFriendCore, use 'tmux attach -t CyberFriendCore' to attach to the session." 30 | 31 | # Start the CyberFriendBotPlugin session 32 | tmux new-session -d -s CyberFriendBotPlugin "cd $PROJECT_DIR/CyberFriend_bot_plugin && $PYTHON_CMD -m pip install -r requirements.txt && nb run" 33 | echo "CyberFriendBotPlugin started in a tmux session named CyberFriendBotPlugin, use 'tmux attach -t CyberFriendBotPlugin' to attach to the session." 34 | 35 | # Schedule the cron job for daily execution at 4 AM (remove any existing copy first so reruns stay idempotent) 36 | CRON_JOB="0 4 * * * $PROJECT_DIR/finetune_and_restart.sh" 37 | (crontab -l 2>/dev/null | grep -vF "$CRON_JOB"; echo "$CRON_JOB") | crontab - 38 | 39 | echo "Setup complete. Fine-tuning and restart scheduled at 4 AM daily." 40 | --------------------------------------------------------------------------------
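A typical first launch, as a minimal sketch: it assumes tmux and a Python with pip are installed, and the repository URL is inferred from the raw.githubusercontent.com link above.

```bash
# Fetch the project and start both tmux sessions (LLM API server + bot plugin)
git clone https://github.com/holk-h/CyberFriend.git
cd CyberFriend
bash run.sh

# Attach to either session to watch its logs
tmux attach -t CyberFriendCore
tmux attach -t CyberFriendBotPlugin
```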