├── README_zh.md ├── webui ├── __init__.py └── chatbot_demo.py ├── weaverbird ├── __init__.py ├── chains │ ├── __init__.py │ └── chat_retro │ │ ├── __init__.py │ │ └── prompt.py ├── models │ ├── neural_retriever.py │ ├── ranker.py │ ├── __init__.py │ ├── chat_model.py │ ├── chat_weaverbird.py │ ├── chat_glm2.py │ ├── chat_llama2.py │ ├── llm_loader.py │ └── template.py ├── embeddings │ ├── __init__.py │ └── query_ref_encoder.py ├── document_loaders │ ├── __init__.py │ └── local_kb_loader.py ├── config_factory │ ├── base_config.py │ ├── __init__.py │ ├── retro_config.py │ ├── generation_config.py │ ├── basemodel_config.py │ └── fintune_config.py ├── utils │ ├── kb_utils.py │ ├── __init__.py │ ├── const.py │ ├── chatbot_utils.py │ ├── log_utils.py │ ├── misc.py │ └── registrable.py ├── retrievers │ ├── __init__.py │ ├── adaptive_retro.py │ └── web_searcher.py └── cn_text_splitter.py ├── docs └── figures │ └── webui.jpg ├── requirements.txt ├── scripts ├── check_prompts.py ├── load_llm_model.py ├── run_web_searcher.py ├── chat_llm_model.py ├── init_local_kb.py └── train_encoder.py ├── README.md └── .gitignore /README_zh.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /webui/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weaverbird/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weaverbird/chains/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weaverbird/models/neural_retriever.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weaverbird/chains/chat_retro/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/figures/webui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/fin_domain_llm/HEAD/docs/figures/webui.jpg -------------------------------------------------------------------------------- /weaverbird/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.embeddings.query_ref_encoder import QueryRefEncoder 2 | 3 | __all__ = ['QueryRefEncoder'] 4 | -------------------------------------------------------------------------------- /weaverbird/document_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.document_loaders.local_kb_loader import LocalKnowledgeBaseLoader 2 | 3 | __all__ = ['LocalKnowledgeBaseLoader'] 4 | -------------------------------------------------------------------------------- /weaverbird/config_factory/base_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | 4 | @dataclass 5 | class BaseConfig: 6 | 7 | def dict(self): 8 | return {k: v for k, v in asdict(self).items()} 9 | 
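All configuration classes in `config_factory` (e.g. `RetroConfig`, `GenerationConfig`, `BaseModelConfig`, `FinetuningConfig`) inherit from this small dataclass, so each of them gets the same `dict()` serialization for free. As a rough illustration of the pattern (a minimal sketch; `DemoConfig` below is a hypothetical class, not part of the repository):

```python
from dataclasses import dataclass, field

from weaverbird.config_factory.base_config import BaseConfig


@dataclass
class DemoConfig(BaseConfig):  # hypothetical subclass, for illustration only
    engine: str = field(default="google", metadata={"help": "Which search engine to use."})
    num_search_results: int = field(default=10, metadata={"help": "Number of pages per search"})


# BaseConfig.dict() flattens the dataclass fields into a plain dict,
# which is convenient for logging or for passing the values on as **kwargs.
print(DemoConfig().dict())  # -> {'engine': 'google', 'num_search_results': 10}
```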
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.1 2 | transformers>=4.29.1 3 | nltk 4 | pandas>=2.0.1 5 | numpy>=1.25.0 6 | sentence-transformers 7 | gradio>=3.38.3 8 | langchain>=0.0.266 9 | google-search-results 10 | dateparser 11 | tqdm 12 | tiktoken 13 | -------------------------------------------------------------------------------- /scripts/check_prompts.py: -------------------------------------------------------------------------------- 1 | from weaverbird.chains.chat_retro.prompt import CHAT_RETRO_EN_PROMPT 2 | 3 | if __name__ == '__main__': 4 | prompt = CHAT_RETRO_EN_PROMPT.format(context='hello', date='20200930', question='what is nasdaq close price') 5 | print(prompt) 6 | -------------------------------------------------------------------------------- /weaverbird/utils/kb_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_kbs_list(kb_root_dir): 5 | lst_default = ["None"] 6 | if not os.path.exists(kb_root_dir): 7 | return lst_default 8 | lst = os.listdir(kb_root_dir) 9 | if not lst: 10 | return lst_default 11 | lst.sort() 12 | return lst_default + lst -------------------------------------------------------------------------------- /scripts/load_llm_model.py: -------------------------------------------------------------------------------- 1 | from weaverbird.models import load_model_and_tokenizer 2 | from weaverbird.utils import parse_configs 3 | 4 | 5 | def main(): 6 | model_config_dict = {'model_name_or_dir': 'chatglm2-6b'} 7 | 8 | configs = parse_configs(model_config_dict) 9 | 10 | load_model_and_tokenizer(configs['model_config']) 11 | 12 | return 13 | 14 | 15 | if __name__ == '__main__': 16 | main() -------------------------------------------------------------------------------- /weaverbird/config_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.config_factory.basemodel_config import BaseModelConfig 2 | from weaverbird.config_factory.fintune_config import FinetuningConfig 3 | from weaverbird.config_factory.generation_config import GenerationConfig 4 | from weaverbird.config_factory.retro_config import RetroConfig 5 | 6 | __all__ = ['BaseModelConfig', 7 | 'GenerationConfig', 8 | 'FinetuningConfig', 9 | 'RetroConfig'] 10 | -------------------------------------------------------------------------------- /weaverbird/models/ranker.py: -------------------------------------------------------------------------------- 1 | from weaverbird.utils import Registrable 2 | 3 | 4 | class ReRanker(Registrable): 5 | def __init__(self, top_j): 6 | self.top_j = top_j 7 | 8 | def rank(self, docs): 9 | pass 10 | 11 | 12 | @ReRanker.register(name='score_reranker') 13 | class ScoreReranker(ReRanker): 14 | def __init__(self, top_j=5): 15 | super(ScoreReranker, self).__init__(top_j=top_j) 16 | 17 | def rank(self, docs): 18 | docs.sort(key=lambda x: x.metadata["score"], reverse=True) 19 | return docs[:self.top_j] 20 | -------------------------------------------------------------------------------- /weaverbird/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.config_factory import RetroConfig 2 | from weaverbird.retrievers.web_searcher import WebSearcher 3 | 4 | __all__ = ['WebSearcher'] 5 | 6 | 7 | class BaseRetro: 8 | @staticmethod 9 | def 
build_from_config(retro_config: RetroConfig): 10 | 11 | for retro_cls in __all__: 12 | if eval(retro_cls).Config.retro_name == retro_config.retro_name: 13 | return eval(retro_cls).build_from_config(retro_config) 14 | 15 | raise NotImplementedError('Retro Model retro_config.retro_name not implemented.') 16 | -------------------------------------------------------------------------------- /scripts/run_web_searcher.py: -------------------------------------------------------------------------------- 1 | from weaverbird.retrievers import BaseRetro 2 | from weaverbird.utils import parse_configs 3 | 4 | 5 | def main(): 6 | search_config = {'model_name_or_path': None, 7 | 'serp_api_token': 'xxx'} 8 | 9 | configs = parse_configs(search_config) 10 | 11 | search_config = configs['retro_config'] 12 | 13 | web_searcher_cls = BaseRetro.build_from_config(search_config) 14 | 15 | results = web_searcher_cls.get_relevant_documents('what does Elon Musk think of BYD') 16 | 17 | print(results) 18 | 19 | return 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /weaverbird/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.utils.chatbot_utils import get_base_url, parse_text 2 | from weaverbird.utils.const import Language 3 | from weaverbird.utils.kb_utils import get_kbs_list 4 | from weaverbird.utils.log_utils import default_logger as logger 5 | from weaverbird.utils.misc import count_parameters, parse_configs, dispatch_model, get_logits_processor, load_yaml_config 6 | from weaverbird.utils.registrable import Registrable 7 | 8 | __all__ = ['get_kbs_list', 9 | 'get_base_url', 10 | 'parse_text', 11 | 'Language', 12 | 'Registrable', 13 | 'count_parameters', 14 | 'parse_configs', 15 | 'dispatch_model', 16 | 'get_logits_processor', 17 | 'load_yaml_config'] 18 | -------------------------------------------------------------------------------- /scripts/chat_llm_model.py: -------------------------------------------------------------------------------- 1 | from langchain import LLMChain 2 | 3 | from weaverbird.chains.chat_retro.prompt import CHAT_RETRO_EN_PROMPT 4 | from weaverbird.models import ChatGLM2 5 | from weaverbird.utils import parse_configs 6 | 7 | 8 | def main(): 9 | model_config_dict = {'model_name_or_path': 'chatglm2-6b'} 10 | 11 | configs = parse_configs(model_config_dict) 12 | 13 | chat_model = ChatGLM2(configs['model_config'], generation_config=configs['generation_config']) 14 | 15 | chat_prompt = CHAT_RETRO_EN_PROMPT 16 | 17 | chain = LLMChain(prompt=chat_prompt, llm=chat_model, verbose=True) 18 | 19 | print(chain({'context': 'hello', 'date': '20200930', 'question': 'what is nasdaq close price'})) 20 | 21 | return 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /weaverbird/utils/const.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ExplicitEnum(str, Enum): 5 | """ 6 | Enum with more explicit error message for missing values. 7 | """ 8 | 9 | def __str__(self): 10 | return str(self.value) 11 | 12 | @classmethod 13 | def _missing_(cls, value): 14 | raise ValueError( 15 | f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" 16 | ) 17 | 18 | class LogConst(ExplicitEnum): 19 | """Format for log handler. 
20 | """ 21 | DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s' 22 | DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s]' \ 23 | ' - %(levelname)s: %(message)s' 24 | 25 | 26 | class Language(ExplicitEnum): 27 | EN = 'en' 28 | CN = 'cn' -------------------------------------------------------------------------------- /weaverbird/config_factory/retro_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Literal, Optional 3 | 4 | from weaverbird.config_factory.base_config import BaseConfig 5 | 6 | 7 | @dataclass 8 | class RetroConfig(BaseConfig): 9 | retro_name: Optional[Literal["web_searcher_retro", "adaptive_retro"]] = field( 10 | default="web_searcher_retro", 11 | metadata={"help": "Which retro model cls to initialize."} 12 | ) 13 | 14 | serp_api_token: Optional[bool] = field( 15 | default=None, 16 | metadata={"help": "Token for serpapi. See https://serpapi.com/search-api"} 17 | ) 18 | 19 | engine: Optional[Literal["google", "bing", "baidu"]] = field( 20 | default="google", 21 | metadata={"help": "Which search engine to use."} 22 | ) 23 | 24 | num_search_results: Optional[int] = field( 25 | default=10, 26 | metadata={"help": "Number of pages per search"} 27 | ) 28 | -------------------------------------------------------------------------------- /weaverbird/models/__init__.py: -------------------------------------------------------------------------------- 1 | from weaverbird.models.chat_glm2 import ChatGLM2 2 | from weaverbird.models.chat_llama2 import ChatLlama2 3 | from weaverbird.models.chat_weaverbird import ChatWeaverBird 4 | from weaverbird.models.llm_loader import load_model_and_tokenizer 5 | 6 | __all__ = ['load_model_and_tokenizer', 7 | 'ChatGLM2', 8 | 'ChatLlama2', 9 | 'ChatWeaverBird'] 10 | 11 | 12 | class BaseModel: 13 | @staticmethod 14 | def build_from_config(model_config, **kwargs): 15 | """Build up the runner from runner config. 16 | 17 | Args: 18 | runner_config (RunnerConfig): config for the runner. 19 | 20 | Returns: 21 | Runner: the corresponding runner class. 
22 | """ 23 | if 'glm' in model_config.model_name_or_path.lower(): 24 | model_cls = ChatGLM2 25 | else: 26 | model_cls = ChatLlama2 27 | return model_cls(model_config, **kwargs) 28 | -------------------------------------------------------------------------------- /scripts/init_local_kb.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain import FAISS 4 | from langchain.text_splitter import TextSplitter 5 | 6 | from weaverbird.document_loaders import LocalKnowledgeBaseLoader 7 | from weaverbird.embeddings import QueryRefEncoder 8 | 9 | 10 | class NewFinDocTextSplitter(TextSplitter): 11 | def __init__(self): 12 | super(NewFinDocTextSplitter, self).__init__() 13 | 14 | def split_text(self, text: str) -> List[str]: 15 | # split by "Doc" because text has \n 16 | documents = text.split("Doc ")[1:] 17 | 18 | return documents 19 | 20 | 21 | def main(): 22 | text_splitter = NewFinDocTextSplitter() 23 | loader = LocalKnowledgeBaseLoader("report_cn_v0724.txt", text_splitter=text_splitter) 24 | docs = loader.load() 25 | print(len(docs)) 26 | 27 | model_dir = 'encoder' 28 | embeddings = QueryRefEncoder(model_dir=model_dir) 29 | db = FAISS.from_documents(docs, embeddings) 30 | 31 | query = "迈瑞医疗(300760)2022 年三季报发布的业绩是多少" 32 | docs = db.similarity_search(query) 33 | 34 | print(docs[0].page_content) 35 | 36 | return 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /weaverbird/utils/chatbot_utils.py: -------------------------------------------------------------------------------- 1 | def get_base_url(url): 2 | return url.split('/')[-1] 3 | 4 | 5 | def parse_text(text): 6 | """ref: https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" 7 | lines = text.split("\n") 8 | lines = [line for line in lines if line != ""] 9 | count = 0 10 | for i, line in enumerate(lines): 11 | if "```" in line: 12 | count += 1 13 | items = line.split('`') 14 | if count % 2 == 1: 15 | lines[i] = f'
<pre><code class="language-{items[-1]}">' 16 |             else: 17 |                 lines[i] = f'<br></code></pre>' 18 |         else: 19 |             if i > 0: 20 |                 if count % 2 == 1: 21 |                     line = line.replace("`", "\`") 22 |                     line = line.replace("<", "&lt;") 23 |                     line = line.replace(">", "&gt;") 24 |                     line = line.replace(" ", "&nbsp;") 25 |                     line = line.replace("*", "&ast;") 26 |                     line = line.replace("_", "&lowbar;") 27 |                     line = line.replace("-", "&#45;") 28 |                     line = line.replace(".", "&#46;") 29 |                     line = line.replace("!", "&#33;") 30 |                     line = line.replace("(", "&#40;") 31 |                     line = line.replace(")", "&#41;") 32 |                     line = line.replace("$", "&#36;") 33 |                 lines[i] = "<br>" + line 34 |     text = "".join(lines) 35 |     return text 36 | 
-------------------------------------------------------------------------------- /weaverbird/chains/chat_retro/prompt.py: -------------------------------------------------------------------------------- 1 | from langchain import PromptTemplate 2 | 3 | _EN_TEMPLATE = """Through a search, you have obtained some information. Each line represents a piece of information, and each piece is independent. It includes the posting time, title, and a snippet of the information. The closer the posting time is to the present, the more important the information. The information is not complete, and the ellipsis (...) indicates omitted sections. Here is the search result: 4 | {context} 5 | 6 | The current date is {date}. You need to answer the user's questions based on the information provided above. If there are multiple questions in a query, please answer all of them. If the user's question includes keywords like 'recent' or 'latest' to indicate a recent time frame, pay attention to the correspondence between the current date and the date of the information. You MUST respond in the same language as the question! The question is: {question}""" 7 | 8 | _CN_TEMPLATE = """通过搜索你得到了一些信息，每一行是一条信息，每条信息是独立的，其中包含了这条信息的发布时间，标题和片段，发布时间离现在越近的信息越重要，信息并不是完整的，句子中的"..."表示省略部分，以下为搜索到的信息： 9 | {context} 10 | 11 | 当前日期为{date}。你需要根据上面这些信息来回答用户的问题。如果提问中有多个问题，请一并回答。如果用户的问题中提到了类似“最近”或“最新”这样表示近期的关键词，需要注意当前日期和信息的日期对应关系。要求回答完整，答案必须使用和问题同样的语种！ 问题是：{question}""" 12 | 13 | CHAT_RETRO_EN_PROMPT = PromptTemplate(input_variables=['context', 'date', 'question'], 14 |                                       template=_EN_TEMPLATE) 15 | 16 | CHAT_RETRO_CN_PROMPT = PromptTemplate(input_variables=['context', 'date', 'question'], 17 |                                       template=_CN_TEMPLATE) 18 | 
-------------------------------------------------------------------------------- /weaverbird/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import typing 4 | 5 | from .const import LogConst 6 | 7 | # -------- log setting --------- 8 | DEFAULT_LOGGER = "weaverbird.logger" 9 | 10 | 11 | class CustomFormatter(logging.Formatter): 12 |     grey = "\x1b[38;20m" 13 |     yellow = "\x1b[33;20m" 14 |     red = "\x1b[31;20m" 15 |     bold_red = "\x1b[31;1m" 16 |     reset = "\x1b[0m" 17 |     format = LogConst.DEFAULT_FORMAT_LONG 18 | 19 |     FORMATS = { 20 |         logging.DEBUG: grey + format + reset, 21 |         logging.INFO: grey + format + reset, 22 |         logging.WARNING: yellow + format + reset, 23 |         logging.ERROR: red + format + reset, 24 |         logging.CRITICAL: bold_red + format + reset 25 |     } 26 | 27 |     def format(self, record): 28 |         log_fmt = self.FORMATS.get(record.levelno) 29 |         formatter = logging.Formatter(log_fmt) 30 |         return formatter.format(record) 31 | 32 | 33 | DEFAULT_FORMATTER = CustomFormatter() 34 | 35 | _ch = logging.StreamHandler(stream=sys.stdout) 36 | _ch.setFormatter(DEFAULT_FORMATTER) 37 | 38 | _DEFAULT_HANDLERS = [_ch] 39 | 40 | _LOGGER_CACHE = {}  # type: typing.Dict[str, logging.Logger] 41 | 42 | 43 | def get_logger(name, level="INFO", handlers=None, update=False): 44 |     if name in _LOGGER_CACHE and not update: 45 |         return _LOGGER_CACHE[name] 46 |     logger = logging.getLogger(name) 47 |     logger.setLevel(level) 48 |     logger.handlers = handlers or _DEFAULT_HANDLERS 49 |     logger.propagate = False 50 |     return logger 51 | 52 | 53 | # -------------------------- Singleton Object -------------------------- 54 | default_logger = get_logger(DEFAULT_LOGGER) 
-------------------------------------------------------------------------------- /weaverbird/retrievers/adaptive_retro.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | from langchain.callbacks.manager import CallbackManagerForRetrieverRun 6 | from langchain.embeddings.base import Embeddings 7 | from langchain.pydantic_v1 import Field 8 | from langchain.schema import BaseRetriever, Document 9 | from langchain.utilities import SerpAPIWrapper 10 | from langchain.vectorstores import VectorStore 11 | 12 | from weaverbird.config_factory import WebSearchConfig 13 | 14 | 15 | def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: 16 | """ 17 | Create an index of embeddings for a list of contexts. 18 | 19 | Borrowed from https://github.com/langchain-ai/langchain/blob/ddd07001f354cd09a76a61e1f5c678bf885506d2/ 20 | libs/langchain/langchain/retrievers/knn.py 21 | 22 | Args: 23 | contexts: List of contexts to embed. 24 | embeddings: Embeddings model to use. 25 | 26 | Returns: 27 | Index of embeddings. 28 | """ 29 | with concurrent.futures.ThreadPoolExecutor() as executor: 30 | return np.array(list(executor.map(embeddings.embed_query, contexts))) 31 | 32 | 33 | class AdaptiveRetriever(BaseRetriever): 34 | local_kb: Optional[VectorStore] = None 35 | """local knowledge base to store documents.""" 36 | 37 | web_searcher: SerpAPIWrapper = Field(..., description="Web Search API Wrapper") 38 | 39 | @classmethod 40 | def build_from_config(cls, search_config: WebSearchConfig): 41 | pass 42 | 43 | def _get_relevant_documents( 44 | self, 45 | query: str, 46 | *, 47 | run_manager: CallbackManagerForRetrieverRun, 48 | ) -> List[Document]: 49 | """Get documents relevant for a query. 50 | 51 | Args: 52 | query: string to find relevant documents for 53 | 54 | Returns: 55 | List of relevant documents 56 | """ 57 | -------------------------------------------------------------------------------- /weaverbird/models/chat_model.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | List, 4 | Optional, 5 | ) 6 | 7 | from langchain.callbacks.manager import CallbackManagerForLLMRun 8 | from langchain.chat_models.base import BaseChatModel 9 | from langchain.schema import BaseMessage, ChatResult 10 | 11 | from weaverbird.config_factory import FinetuningConfig, GenerationConfig, BaseModelConfig 12 | from weaverbird.models.llm_model import LLMModel 13 | 14 | 15 | class WeaverBirdChat(BaseChatModel): 16 | model_name: str = "weaverbird_chat" 17 | """model name of WeaverBird, default is `weaverbird_chat`""" 18 | 19 | request_timeout: Optional[int] = 60 20 | """request timeout for chat http requests""" 21 | 22 | max_retries: int = 6 23 | """Maximum number of retries to make when generating""" 24 | 25 | streaming: Optional[bool] = True 26 | """streaming mode. 
not supported yet""" 27 | 28 | llm_model: Optional[LLMModel] = None 29 | """LLM model to use in weaverbird""" 30 | 31 | retriever_model: Optional[LLMModel] = None 32 | """retriever model to use in weaverbird""" 33 | 34 | @classmethod 35 | def build_from_config(cls, 36 | llm_model_config: BaseModelConfig, 37 | llm_finetuning_config: Optional[FinetuningConfig] = None, 38 | llm_generation_config: Optional[GenerationConfig] = None): 39 | llm_model = LLMModel(model_config=llm_model_config, 40 | finetuning_config=llm_finetuning_config, 41 | generation_config=llm_generation_config) 42 | 43 | 44 | 45 | return cls(llm_model=llm_model) 46 | 47 | if retro_config is not None: 48 | self.retriever = None 49 | 50 | def _generate( 51 | self, 52 | messages: List[BaseMessage], 53 | stop: Optional[List[str]] = None, 54 | run_manager: Optional[CallbackManagerForLLMRun] = None, 55 | **kwargs: Any, 56 | ) -> ChatResult: 57 | pass 58 | 59 | @property 60 | def _llm_type(self) -> str: 61 | return "weaverbird_chat" 62 | -------------------------------------------------------------------------------- /weaverbird/config_factory/generation_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | from weaverbird.config_factory.base_config import BaseConfig 5 | 6 | 7 | @dataclass 8 | class GenerationConfig(BaseConfig): 9 | """ 10 | Arguments pertaining to specify the decoding parameters. 11 | source: https://github.com/hiyouga/LLaMA-Efficient-Tuning 12 | """ 13 | do_sample: Optional[bool] = field( 14 | default=True, 15 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} 16 | ) 17 | temperature: Optional[float] = field( 18 | default=0.95, 19 | metadata={"help": "The value used to modulate the next token probabilities."} 20 | ) 21 | top_p: Optional[float] = field( 22 | default=0.7, 23 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} 24 | ) 25 | top_k: Optional[int] = field( 26 | default=50, 27 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} 28 | ) 29 | num_beams: Optional[int] = field( 30 | default=1, 31 | metadata={"help": "Number of beams for beam search. 1 means no beam search."} 32 | ) 33 | max_length: Optional[int] = field( 34 | default=10000, 35 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} 36 | ) 37 | max_new_tokens: Optional[int] = field( 38 | default=512, 39 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} 40 | ) 41 | repetition_penalty: Optional[float] = field( 42 | default=1.0, 43 | metadata={"help": "The parameter for repetition penalty. 
1.0 means no penalty."} 44 | ) 45 | length_penalty: Optional[float] = field( 46 | default=1.0, 47 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} 48 | ) 49 | 50 | def to_dict(self) -> Dict[str, Any]: 51 | args = asdict(self) 52 | if args.get("max_new_tokens", None): 53 | args.pop("max_length", None) 54 | return args -------------------------------------------------------------------------------- /weaverbird/cn_text_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | from langchain.text_splitter import CharacterTextSplitter 5 | 6 | 7 | class ChineseTextSplitter(CharacterTextSplitter): 8 | """ 9 | borrowed from https://github.com/chatchat-space/Langchain-Chatchat/blob/ 10 | f1f8ab80e4f72156abeb12afd8566ff90beca350/text_splitter/chinese_text_splitter.py 11 | """ 12 | 13 | def __init__(self, sentence_size: int, pdf: bool = False, **kwargs): 14 | super().__init__(**kwargs) 15 | self.pdf = pdf 16 | self.sentence_size = sentence_size 17 | 18 | def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 19 | if self.pdf: 20 | text = re.sub(r"\n{3,}", r"\n", text) 21 | text = re.sub('\s', " ", text) 22 | text = re.sub("\n\n", "", text) 23 | 24 | text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # 单字符断句符 25 | text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 26 | text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 27 | text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) 28 | # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 29 | text = text.rstrip() # 段尾如果有多余的\n就去掉它 30 | # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 31 | ls = [i for i in text.split("\n") if i] 32 | for ele in ls: 33 | if len(ele) > self.sentence_size: 34 | ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) 35 | ele1_ls = ele1.split("\n") 36 | for ele_ele1 in ele1_ls: 37 | if len(ele_ele1) > self.sentence_size: 38 | ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) 39 | ele2_ls = ele_ele2.split("\n") 40 | for ele_ele2 in ele2_ls: 41 | if len(ele_ele2) > self.sentence_size: 42 | ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) 43 | ele2_id = ele2_ls.index(ele_ele2) 44 | ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ 45 | ele2_id + 1:] 46 | ele_id = ele1_ls.index(ele_ele1) 47 | ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] 48 | 49 | id = ls.index(ele) 50 | ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] 51 | return ls 52 | -------------------------------------------------------------------------------- /weaverbird/embeddings/query_ref_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain.embeddings.base import Embeddings 4 | from transformers import AutoTokenizer, AutoModel 5 | import torch 6 | 7 | 8 | class QueryRefEncoder(Embeddings): 9 | """ 10 | Produce embeddings of query and references using a pretrained encoder 11 | """ 12 | 13 | def __init__(self, model_dir, device=None, max_batch_size=400): 14 | super(QueryRefEncoder, self).__init__() 15 | self.tokenizer = AutoTokenizer.from_pretrained(model_dir) 16 | self.query_encoder = AutoModel.from_pretrained(model_dir + "/query_encoder") 17 | self.reference_encoder = AutoModel.from_pretrained(model_dir + "/reference_encoder") 18 | self.device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") if not device else device 19 | self.query_encoder = self.query_encoder.to(self.device).eval() 20 | self.reference_encoder = self.reference_encoder.to(self.device).eval() 21 | assert max_batch_size > 0 22 | self.max_batch_size = max_batch_size 23 | 24 | def get_embeddings(self, sentences: List[str], encoder) -> torch.Tensor: 25 | # Tokenization and Inference 26 | torch.cuda.empty_cache() 27 | with torch.no_grad(): 28 | inputs = self.tokenizer(sentences, padding=True, 29 | truncation=True, return_tensors='pt') 30 | for key in inputs: 31 | inputs[key] = inputs[key].to(self.device) 32 | outputs = encoder(**inputs) 33 | # Mean Pool 34 | token_embeddings = outputs[0] 35 | mask = inputs["attention_mask"] 36 | token_embeddings = token_embeddings.masked_fill( 37 | ~mask[..., None].bool(), 0.) 38 | sentence_embeddings = token_embeddings.sum( 39 | dim=1) / mask.sum(dim=1)[..., None] 40 | return sentence_embeddings 41 | 42 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 43 | """Compute doc embeddings using a trained retriever model. 44 | 45 | Args: 46 | texts: The list of texts to embed. 47 | 48 | Returns: 49 | List of embeddings, one for each text. 50 | """ 51 | texts = list(map(lambda x: x.replace("\n", " "), texts)) 52 | embeddings = self.get_embeddings(texts, self.reference_encoder) 53 | return embeddings.tolist() 54 | 55 | def embed_query(self, text: str) -> List[float]: 56 | """Compute query embeddings using a trained retriever model. 57 | 58 | Args: 59 | text: The text to embed. 60 | 61 | Returns: 62 | Embeddings for the text. 63 | """ 64 | text = text.replace("\n", " ") 65 | embedding = self.get_embeddings([text], self.query_encoder)[0] 66 | return embedding.tolist() 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## WeaverBird: Empowering Financial Decision-Making with Large Language Model, Knowledge Base, and Search Engine 2 | 3 | ![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg) 4 | ![](https://img.shields.io/badge/license-Apache-000000.svg) 5 | ![PRs Welcome](https://img.shields.io/badge/PRs-Welcome-green) 6 | ![GitHub last commit](https://img.shields.io/github/last-commit/ant-research/fin_domain_llm) 7 | ![Stars](https://img.shields.io/github/stars/ant-research/fin_domain_llm) 8 | 9 | Hugging Face 10 | 11 |

📃 [Paper](https://arxiv.org/abs/2308.05361) • 🌐 [中文 README](README_zh.md) • 📺 [Demo, version 0.0.1]

12 | 13 | 14 | WeaverBird is an intelligent dialogue system designed specifically for the finance domain. Our system harnesses a large language model of GPT architecture that has been tuned using extensive corpora of finance-related text. 15 | 16 | 17 | **We are actively updating the repo and will finish shortly** 18 | 19 | 20 | 21 | 22 | 23 | ## News 24 | 25 | 26 | - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [08-10-2023] We released the paper [WeaverBird](https://arxiv.org/abs/2308.05361)! 27 | 28 | 29 | ## Citation [Back to Top] 30 | 31 | 32 | 33 | If you find `WeaverBird` useful for your research or development, please cite the following paper: 34 | ``` 35 | @article{xue2023weaverbird, 36 | title={WeaverBird: Empowering Financial Decision-Making with Large Language Model, Knowledge Base, and Search Engine}, 37 | author={Siqiao Xue and Fan Zhou and Yi Xu and Ming Jin and Qingsong Wen and Hongyan Hao and Qingyang Dai and Caigao Jiang and Hongyu Zhao and Shuo Xie and Jianshan He and James Zhang and Hongyuan Mei}, 38 | journal={arXiv preprint arXiv:2308.05361}, 39 | year={2023} 40 | } 41 | ``` 42 | 43 | ## Acknowledgment [Back to Top] 44 | 45 | 46 | The project is developed by researchers from Ant Group, University of Chicago and TTIC. We thank our colleagues at Machine 47 | Intelligence Team, Interbank Technologies Team of Ant Group and Alibaba DAMO Academy for helpful comments. 48 | 49 | The following repositories are used in `WeaverBird`, either in close to original form or as an inspiration: 50 | 51 | - [Langchain](https://github.com/langchain-ai/langchain) 52 | - [Langchain-chatglm](https://github.com/chatchat-space/Langchain-Chatchat) 53 | - [WebGLM](https://github.com/THUDM/WebGLM) 54 | - [EasyTPP](https://github.com/ant-research/EasyTemporalPointProcess) 55 | - [Huggingface - transformers](https://github.com/huggingface/transformers) 56 | - [LLaMA-Efficient-Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning) 57 | 58 | ## Star History [Back to Top] 59 | 60 | 61 | ![Star History Chart](https://api.star-history.com/svg?repos=ant-research/fin_domain_llm&type=Date) 62 | 63 | 64 | -------------------------------------------------------------------------------- /weaverbird/models/chat_weaverbird.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | List, 4 | Optional, ) 5 | 6 | from langchain import PromptTemplate 7 | from langchain.callbacks.manager import CallbackManagerForLLMRun 8 | from langchain.chat_models.base import BaseChatModel 9 | from langchain.llms.base import LLM 10 | from langchain.schema import ( 11 | BaseMessage, 12 | ChatResult, Document, 13 | ) 14 | from langchain.schema import BaseRetriever 15 | 16 | from weaverbird.utils.misc import get_current_time 17 | 18 | 19 | class ChatWeaverBird(BaseChatModel): 20 | model_name: str = "chat_weaverbird" 21 | """model name of WeaverBird, default is `chat_weaverbird`""" 22 | 23 | llm_model: LLM 24 | """LLM model to use in weaverbird""" 25 | 26 | retriever_model: Optional[BaseRetriever] = None 27 | """retriever model to use in weaverbird""" 28 | 29 | prompt_template: PromptTemplate 30 | """template to construct the prompt """ 31 | 32 | streaming: bool = False 33 | """Whether to stream the results or not.""" 34 | 35 | def __init__(self, llm_model, retriever_model, prompt_template): 36 | super(ChatWeaverBird, self).__init__() 37 | self.llm_model = llm_model 38 | self.retriever_model = retriever_model 39 | 
self.prompt_template = prompt_template 40 | 41 | def _generate( 42 | self, 43 | messages: List[BaseMessage], 44 | stop: Optional[List[str]] = None, 45 | run_manager: Optional[CallbackManagerForLLMRun] = None, 46 | stream: Optional[bool] = None, 47 | **kwargs: Any, 48 | ) -> ChatResult: 49 | chat_history = kwargs.get('chat_history', []) 50 | docs = [] 51 | if self.retriever_model is not None: 52 | docs = self.retriever_model._get_relevant_documents() 53 | 54 | should_stream = stream if stream is not None else self.streaming 55 | 56 | if len(docs) > 0: 57 | prompt = self._generate_prompt(docs, messages) 58 | else: 59 | prompt = messages 60 | 61 | for answer_result in self.llm_model._generate_answer(prompt=prompt, 62 | history=chat_history, 63 | streaming=should_stream): 64 | resp = answer_result.generatios 65 | history = answer_result.llm_output['history'] 66 | history[-1][0] = messages 67 | response = { 68 | "prompt": prompt, 69 | "query": messages, 70 | "result": resp, 71 | "source_documents": docs 72 | } 73 | yield response, history 74 | 75 | def _generate_prompt(self, 76 | related_docs: List[Document], 77 | query: List[BaseMessage]): 78 | cur_time = get_current_time() 79 | 80 | if len(related_docs): 81 | context = "\n".join( 82 | [f"{doc.metadata.get('date', '')} {doc.metadata.get('title', '')} {doc.page_content}" for doc in 83 | related_docs]) 84 | else: 85 | context = '' 86 | # do a concate for query here 87 | query = ''.join(query) 88 | kwargs = {'question': query, 'date': cur_time, 'context': context} 89 | return self.prompt_template.format(kwargs) 90 | 91 | @property 92 | def _llm_type(self) -> str: 93 | return "chat_weaverbird" 94 | -------------------------------------------------------------------------------- /weaverbird/retrievers/web_searcher.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import dateparser 4 | from langchain.callbacks.manager import CallbackManagerForRetrieverRun 5 | from langchain.pydantic_v1 import Field 6 | from langchain.schema import BaseRetriever, Document 7 | from langchain.utilities import SerpAPIWrapper 8 | 9 | from weaverbird.config_factory import RetroConfig 10 | from weaverbird.utils import logger 11 | 12 | 13 | class WebSearcher(BaseRetriever): 14 | """Simplest WebSearcher 15 | borrowed from 16 | https://github.com/langchain-ai/langchain/blob/ddd07001f354cd09a76a61e1f5c678bf885506d2/libs/langchain/langchain/retrievers/web_research.py 17 | """ 18 | 19 | web_searcher: SerpAPIWrapper = Field(..., description="Web Search API Wrapper") 20 | 21 | class Config: 22 | 23 | """Configuration for this pydantic object.""" 24 | 25 | retro_name = 'web_searcher_retro' 26 | 27 | @classmethod 28 | def build_from_config(cls, search_config: RetroConfig): 29 | serpapi_api_key = search_config.serp_api_token 30 | search_config_dict = search_config.dict() 31 | search_config_dict.pop('serp_api_token') 32 | return cls(web_searcher=SerpAPIWrapper(serpapi_api_key=serpapi_api_key, params=search_config_dict)) 33 | 34 | def _search_result2docs(self, search_results): 35 | docs = [] 36 | logger.info(f'# search_results {len(search_results)}') 37 | for result in search_results: 38 | doc = Document(page_content=result["snippet"].replace('\n', '') if "snippet" in result.keys() else "", 39 | metadata={"link": result["link"] if "link" in result.keys() else "", 40 | "title": result["title"] if "title" in result.keys() else "", 41 | "source": result["source"] if "source" in result.keys() else "", 42 | "filename": 
result["title"] if "title" in result.keys() else "", 43 | "date": dateparser.parse(result['date']).strftime( 44 | "%Y-%m-%d") if 'date' in result.keys() else "", 45 | "score": 100}) # for the moment we fix the score 46 | docs.append(doc) 47 | return docs 48 | 49 | def _get_relevant_documents( 50 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun 51 | ) -> List[Document]: 52 | """Search websites for documents related to the query input. 53 | 54 | Args: 55 | query: user query 56 | 57 | Returns: 58 | Relevant documents from various urls. 59 | """ 60 | search_results = self.web_searcher.results(self.clean_search_query(query)) 61 | 62 | return self._search_result2docs(search_results['organic_results']) 63 | 64 | def clean_search_query(self, query: str) -> str: 65 | # Some search tools (e.g., Google) will 66 | # fail to return results if query has a 67 | # leading digit: 1. "LangCh..." 68 | # Check if the first character is a digit 69 | if query[0].isdigit(): 70 | # Find the position of the first quote 71 | first_quote_pos = query.find('"') 72 | if first_quote_pos != -1: 73 | # Extract the part of the string after the quote 74 | query = query[first_quote_pos + 1:] 75 | # Remove the trailing quote if present 76 | if query.endswith('"'): 77 | query = query[:-1] 78 | return query.strip() 79 | -------------------------------------------------------------------------------- /weaverbird/models/chat_glm2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from langchain.llms.base import LLM 4 | from langchain.schema import BaseMessage, ChatResult 5 | from transformers import PreTrainedModel, PreTrainedTokenizer 6 | 7 | from weaverbird.config_factory import BaseModelConfig, FinetuningConfig, GenerationConfig 8 | from weaverbird.models.llm_loader import load_model_and_tokenizer 9 | from weaverbird.utils import dispatch_model 10 | from weaverbird.utils.misc import torch_gc 11 | 12 | 13 | class ChatGLM2(LLM): 14 | """ GLM2 from THU """ 15 | model: Optional[PreTrainedModel] = None 16 | 17 | tokenizer: Optional[PreTrainedTokenizer] = None 18 | 19 | generation_config: Optional[GenerationConfig] = None 20 | 21 | def __init__( 22 | self, 23 | model_config: BaseModelConfig, 24 | finetuning_config: Optional[FinetuningConfig] = None, 25 | generation_config: Optional[GenerationConfig] = None 26 | ) -> None: 27 | super(ChatGLM2, self).__init__() 28 | self.model, self.tokenizer = load_model_and_tokenizer(model_config, finetuning_config) 29 | self.model = dispatch_model(self.model) 30 | self.model = self.model.eval() # enable evaluation mode 31 | self.generation_config = generation_config 32 | 33 | @classmethod 34 | def build_from_config(cls, configs): 35 | return cls(model_config=configs['model_config'], finetuning_config=configs['finetuning_config'], 36 | generation_config=configs['generation_config']) 37 | 38 | def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str: 39 | history = kwargs.get('hisotry', []) 40 | response, _ = self.model.chat( 41 | self.tokenizer, 42 | prompt, 43 | history=history, 44 | max_length=self.generation_config.max_length, 45 | temperature=self.generation_config.temperature 46 | ) 47 | print(f"response:{response}") 48 | return response 49 | 50 | def _generate_answer(self, prompt: str, history: List[List[BaseMessage]] = [], streaming: bool = False): 51 | if streaming: 52 | history += [[]] 53 | for inum, (stream_resp, _) in enumerate(self.model.stream_chat( 54 | 
self.tokenizer, 55 | prompt, 56 | history=history[ 57 | -self.generation_config.max_history_message_length:-1] if self.generation_config.max_history_message_length > 1 else [], 58 | max_length=self.generation_config.max_length, 59 | temperature=self.generation_config.temperature 60 | )): 61 | history[-1] = [prompt, stream_resp] 62 | llm_output = {'history': history} 63 | yield ChatResult(generations=stream_resp, llm_output=llm_output) 64 | else: 65 | response, _ = self.model.chat( 66 | self.tokenizer, 67 | prompt, 68 | history=history[ 69 | -self.generation_config.max_history_message_length:-1] if self.generation_config.max_history_message_length > 1 else [], 70 | max_length=self.generation_config.max_length, 71 | temperature=self.generation_config.temperature 72 | ) 73 | torch_gc() 74 | history += [[prompt, response]] 75 | llm_output = {'history': history} 76 | yield ChatResult(generations=response, llm_output=llm_output) 77 | 78 | @property 79 | def _llm_type(self) -> str: 80 | """Return type of llm.""" 81 | return "chat_glm2" 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /weaverbird/config_factory/basemodel_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Literal, Optional 3 | from dataclasses import dataclass, field 4 | from weaverbird.config_factory.base_config import BaseConfig 5 | 6 | 7 | @dataclass 8 | class BaseModelConfig(BaseConfig): 9 | """ 10 | Arguments pertaining to which model/config/tokenizer. 
11 | source: https://github.com/hiyouga/LLaMA-Efficient-Tuning 12 | """ 13 | model_name_or_path: str = field( 14 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 15 | ) 16 | cache_dir: Optional[str] = field( 17 | default=None, 18 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 19 | ) 20 | use_fast_tokenizer: Optional[bool] = field( 21 | default=False, 22 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 23 | ) 24 | use_auth_token: Optional[bool] = field( 25 | default=False, 26 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 27 | ) 28 | model_revision: Optional[str] = field( 29 | default="main", 30 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 31 | ) 32 | padding_side: Optional[Literal["left", "right"]] = field( 33 | default="left", 34 | metadata={"help": "The side on which the model should have padding applied."} 35 | ) 36 | quantization_bit: Optional[int] = field( 37 | default=None, 38 | metadata={"help": "The number of bits to quantize the model."} 39 | ) 40 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 41 | default="nf4", 42 | metadata={"help": "Quantization data type to use in int4 training."} 43 | ) 44 | double_quantization: Optional[bool] = field( 45 | default=True, 46 | metadata={"help": "Whether to use double quantization in int4 training or not."} 47 | ) 48 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field( 49 | default=None, 50 | metadata={"help": "Adopt scaled rotary positional embeddings."} 51 | ) 52 | checkpoint_dir: Optional[str] = field( 53 | default=None, 54 | metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} 55 | ) 56 | reward_model: Optional[str] = field( 57 | default=None, 58 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 59 | ) 60 | plot_loss: Optional[bool] = field( 61 | default=False, 62 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 63 | ) 64 | hf_auth_token: Optional[str] = field( 65 | default=None, 66 | metadata={"help": "Auth token to log in with Hugging Face Hub."} 67 | ) 68 | compute_dtype: Optional[torch.dtype] = field( 69 | default=None, 70 | metadata={"help": "Used in quantization configs. Do not specify this argument manually."} 71 | ) 72 | model_max_length: Optional[int] = field( 73 | default=None, 74 | metadata={"help": "Used in rope scaling. Do not specify this argument manually."} 75 | ) 76 | 77 | def __post_init__(self): 78 | if self.compute_dtype is not None or self.model_max_length is not None: 79 | raise ValueError("These arguments cannot be specified.") 80 | 81 | if self.checkpoint_dir is not None: # support merging multiple lora weights 82 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] 83 | 84 | if self.quantization_bit is not None: 85 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
86 | 87 | if self.use_auth_token == True and self.hf_auth_token is not None: 88 | from huggingface_hub.hf_api import HfFolder # lazy load 89 | HfFolder.save_token(self.hf_auth_token) -------------------------------------------------------------------------------- /weaverbird/document_loaders/local_kb_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from langchain.document_loaders import TextLoader, UnstructuredMarkdownLoader, UnstructuredFileLoader 5 | from langchain.document_loaders.base import BaseLoader 6 | from langchain.text_splitter import TextSplitter 7 | from tqdm import tqdm 8 | 9 | from weaverbird.cn_text_splitter import ChineseTextSplitter 10 | from weaverbird.utils import logger 11 | 12 | 13 | def tree(filepath, ignore_dir_names=None, ignore_file_names=None): 14 | """ 15 | Return two list, the first one is the dirs of all files under filepath, the second one is the file names of all 16 | corresponding files. 17 | borrowed from https://github.com/chatchat-space/Langchain-Chatchat/ 18 | 19 | """ 20 | if ignore_dir_names is None: 21 | ignore_dir_names = [] 22 | if ignore_file_names is None: 23 | ignore_file_names = [] 24 | ret_list = [] 25 | if isinstance(filepath, str): 26 | if not os.path.exists(filepath): 27 | print("Directory not existed") 28 | return None, None 29 | elif os.path.isfile(filepath) and os.path.basename(filepath) not in ignore_file_names: 30 | return [filepath], [os.path.basename(filepath)] 31 | elif os.path.isdir(filepath) and os.path.basename(filepath) not in ignore_dir_names: 32 | for file in os.listdir(filepath): 33 | fullfilepath = os.path.join(filepath, file) 34 | if os.path.isfile(fullfilepath) and os.path.basename(fullfilepath) not in ignore_file_names: 35 | ret_list.append(fullfilepath) 36 | if os.path.isdir(fullfilepath) and os.path.basename(fullfilepath) not in ignore_dir_names: 37 | ret_list.extend(tree(fullfilepath, ignore_dir_names, ignore_file_names)[0]) 38 | return ret_list, [os.path.basename(p) for p in ret_list] 39 | 40 | 41 | def load_file(file_dir, text_splitter): 42 | if file_dir.lower().endswith(".md"): 43 | loader = UnstructuredMarkdownLoader(file_dir) 44 | elif file_dir.lower().endswith(".txt"): 45 | loader = TextLoader(file_dir, autodetect_encoding=True) 46 | else: 47 | loader = UnstructuredFileLoader(file_dir, mode="elements") 48 | 49 | docs = loader.load_and_split(text_splitter=text_splitter) 50 | 51 | return docs 52 | 53 | 54 | class LocalKnowledgeBaseLoader(BaseLoader): 55 | def __init__(self, 56 | file_dir: str or List[str], 57 | text_splitter: Optional[TextSplitter] = ChineseTextSplitter 58 | ): 59 | self.file_dir = file_dir 60 | self.text_splitter = text_splitter 61 | 62 | def _load_from_single_dir(self, file_dir): 63 | docs = [] 64 | loaded_files = [] 65 | failed_files = [] 66 | if not os.path.exists(file_dir): 67 | logger.info("Directory not existed") 68 | return None 69 | elif os.path.isfile(file_dir): 70 | file = os.path.split(file_dir)[-1] 71 | try: 72 | docs = load_file(file_dir, self.text_splitter) 73 | logger.info(f"{file} loaded") 74 | loaded_files.append(file_dir) 75 | except Exception as e: 76 | logger.error(e) 77 | logger.info(f"{file} failed to load") 78 | failed_files.append(file) 79 | elif os.path.isdir(file_dir): 80 | docs = [] 81 | for single_file_dir, file in tqdm(zip(*tree(file_dir)), desc="loading files"): 82 | try: 83 | docs += load_file(single_file_dir, self.text_splitter) 84 | 
loaded_files.append(single_file_dir) 85 | except Exception as e: 86 | logger.error(e) 87 | failed_files.append(single_file_dir) 88 | 89 | return docs, loaded_files, failed_files 90 | 91 | def _load_from_multiple_dir(self, file_dir): 92 | docs = [] 93 | loaded_files = [] 94 | failed_files = [] 95 | 96 | for file in file_dir: 97 | docs_, loaded_files_, failed_files_ = self._load_from_single_dir(file) 98 | docs.extend(docs_) 99 | loaded_files.extend(loaded_files_) 100 | failed_files.extend(failed_files_) 101 | return docs, loaded_files, failed_files 102 | 103 | def load(self): 104 | if isinstance(self.file_dir, str): 105 | docs, loaded_files, failed_files = self._load_from_single_dir(self.file_dir) 106 | else: 107 | docs, loaded_files, failed_files = self._load_from_multiple_dir(self.file_dir) 108 | 109 | return docs 110 | -------------------------------------------------------------------------------- /weaverbird/config_factory/fintune_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Literal, Optional 3 | from dataclasses import asdict, dataclass, field 4 | 5 | from weaverbird.config_factory.base_config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class FinetuningConfig(BaseConfig): 10 | """ 11 | Arguments pertaining to which techniques we are going to fine-tuning with. 12 | source: https://github.com/hiyouga/LLaMA-Efficient-Tuning 13 | """ 14 | finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( 15 | default="lora", 16 | metadata={"help": "Which fine-tuning method to use."} 17 | ) 18 | num_hidden_layers: Optional[int] = field( 19 | default=32, 20 | metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \ 21 | LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ 22 | LLaMA-2 choices: [\"32\", \"40\", \"80\"], \ 23 | BLOOM choices: [\"24\", \"30\", \"70\"], \ 24 | Falcon choices: [\"32\", \"60\"], \ 25 | Baichuan choices: [\"32\", \"40\"] \ 26 | Qwen choices: [\"32\"], \ 27 | XVERSE choices: [\"40\"], \ 28 | ChatGLM2 choices: [\"28\"]"} 29 | ) 30 | num_layer_trainable: Optional[int] = field( 31 | default=3, 32 | metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} 33 | ) 34 | name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( 35 | default="mlp", 36 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ 37 | LLaMA choices: [\"mlp\", \"self_attn\"], \ 38 | BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ 39 | Baichuan choices: [\"mlp\", \"self_attn\"], \ 40 | Qwen choices: [\"mlp\", \"attn\"], \ 41 | LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} 42 | ) 43 | lora_rank: Optional[int] = field( 44 | default=8, 45 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 46 | ) 47 | lora_alpha: Optional[float] = field( 48 | default=32.0, 49 | metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} 50 | ) 51 | lora_dropout: Optional[float] = field( 52 | default=0.1, 53 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 54 | ) 55 | lora_target: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. 
\ 58 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 59 | BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ 60 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 61 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ 62 | LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} 63 | ) 64 | resume_lora_training: Optional[bool] = field( 65 | default=True, 66 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 67 | ) 68 | ppo_score_norm: Optional[bool] = field( 69 | default=False, 70 | metadata={"help": "Use score normalization in PPO Training."} 71 | ) 72 | dpo_beta: Optional[float] = field( 73 | default=0.1, 74 | metadata={"help": "The beta parameter for the DPO loss."} 75 | ) 76 | 77 | def __post_init__(self): 78 | if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA 79 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] 80 | 81 | if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 82 | trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)] 83 | else: # fine-tuning the first n layers if num_layer_trainable < 0 84 | trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] 85 | 86 | self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] 87 | 88 | assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method." 89 | 90 | def save_to_json(self, json_path: str): 91 | r"""Saves the content of this instance in JSON format inside `json_path`.""" 92 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" 93 | with open(json_path, "w", encoding="utf-8") as f: 94 | f.write(json_string) 95 | 96 | @classmethod 97 | def load_from_json(cls, json_path: str): 98 | r"""Creates an instance from the content of `json_path`.""" 99 | with open(json_path, "r", encoding="utf-8") as f: 100 | text = f.read() 101 | return cls(**json.loads(text)) -------------------------------------------------------------------------------- /weaverbird/models/chat_llama2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Tuple, Dict 2 | 3 | import torch 4 | from langchain.llms.base import LLM 5 | from transformers import GenerationConfig 6 | from transformers import PreTrainedModel, PreTrainedTokenizer 7 | 8 | from weaverbird.config_factory import BaseModelConfig, FinetuningConfig 9 | from weaverbird.config_factory import GenerationConfig as WBGenerationConfig 10 | from weaverbird.models.llm_loader import load_model_and_tokenizer 11 | from weaverbird.models.template import get_template_and_fix_tokenizer, Template 12 | from weaverbird.utils import dispatch_model, get_logits_processor 13 | 14 | 15 | class ChatLlama2(LLM): 16 | """ 17 | LLAMA2 from Meta 18 | 19 | Borrowed from https://github.com/hiyouga/LLaMA-Efficient-Tuning/blob/469f859161dec0e34f4cc849f20e43d442680b5c/src/llmtuner/chat/stream_chat.py 20 | """ 21 | model: Optional[PreTrainedModel] = None 22 | 23 | tokenizer: Optional[PreTrainedTokenizer] = None 24 | 25 | generation_config: Optional[WBGenerationConfig] = None 26 | 27 | template: Optional[Template] = None 28 | 29 | def __init__( 30 | self, 31 | 
model_config: BaseModelConfig, 32 | finetuning_config: Optional[FinetuningConfig] = None, 33 | generation_config: Optional[WBGenerationConfig] = None 34 | ) -> None: 35 | super(ChatLlama2, self).__init__() 36 | self.model, self.tokenizer = load_model_and_tokenizer(model_config, finetuning_config) 37 | self.model = dispatch_model(self.model) 38 | self.model = self.model.eval() # enable evaluation mode 39 | self.generation_config = generation_config 40 | self.template = get_template_and_fix_tokenizer("llama2", self.tokenizer) 41 | 42 | @classmethod 43 | def build_from_config(cls, configs): 44 | return cls(model_config=configs['model_config'], finetuning_config=configs['finetuning_config'], 45 | generation_config=configs['generation_config']) 46 | 47 | def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str: 48 | response, _ = self.chat( 49 | prompt, 50 | history=[] 51 | ) 52 | print(f"response:{response}") 53 | print(f"+++++++++++++++++++++++++++++++++++") 54 | return response 55 | 56 | def process_args( 57 | self, 58 | query: str, 59 | history: Optional[List[Tuple[str, str]]] = None, 60 | system: Optional[str] = None, 61 | **input_kwargs 62 | ) -> Tuple[Dict[str, Any], int]: 63 | system = system or "" 64 | 65 | prompt, _ = self.template.encode_oneturn( 66 | tokenizer=self.tokenizer, query=query, resp="", history=history, system=system 67 | ) 68 | input_ids = torch.tensor([prompt], device=self.model.device) 69 | prompt_length = len(input_ids[0]) 70 | 71 | do_sample = input_kwargs.pop("do_sample", None) 72 | temperature = input_kwargs.pop("temperature", None) 73 | top_p = input_kwargs.pop("top_p", None) 74 | top_k = input_kwargs.pop("top_k", None) 75 | repetition_penalty = input_kwargs.pop("repetition_penalty", None) 76 | max_length = input_kwargs.pop("max_length", None) 77 | max_new_tokens = input_kwargs.pop("max_new_tokens", None) 78 | 79 | generation_config = self.generation_config.dict() 80 | generation_config.update(dict( 81 | do_sample=do_sample if do_sample is not None else generation_config["do_sample"], 82 | temperature=temperature or generation_config["temperature"], 83 | top_p=top_p or generation_config["top_p"], 84 | top_k=top_k or generation_config["top_k"], 85 | repetition_penalty=repetition_penalty or generation_config["repetition_penalty"], 86 | eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, 87 | pad_token_id=self.tokenizer.pad_token_id 88 | )) 89 | 90 | if max_length: 91 | generation_config.pop("max_new_tokens", None) 92 | generation_config["max_length"] = max_length 93 | 94 | if max_new_tokens: 95 | generation_config.pop("max_length", None) 96 | generation_config["max_new_tokens"] = max_new_tokens 97 | 98 | gen_kwargs = dict( 99 | inputs=input_ids, 100 | generation_config=GenerationConfig(**generation_config), 101 | logits_processor=get_logits_processor() 102 | ) 103 | 104 | return gen_kwargs, prompt_length 105 | 106 | @torch.inference_mode() 107 | def chat( 108 | self, 109 | prompt: str, 110 | history: Optional[List[Tuple[str, str]]] = None, 111 | system: Optional[str] = None, 112 | **input_kwargs 113 | ) -> Tuple[str, Tuple[int, int]]: 114 | gen_kwargs, prompt_length = self.process_args(prompt, history, system, **input_kwargs) 115 | generation_output = self.model.generate(**gen_kwargs) 116 | outputs = generation_output.tolist()[0][prompt_length:] 117 | response = self.tokenizer.decode(outputs, skip_special_tokens=True) 118 | response_length = len(outputs) 119 | return response, (prompt_length, 
response_length) 120 | 121 | @property 122 | def _llm_type(self) -> str: 123 | """Return type of llm.""" 124 | return "chat_llama2" 125 | -------------------------------------------------------------------------------- /weaverbird/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from typing import Tuple, Optional, Dict, Any 5 | 6 | import torch 7 | import yaml 8 | from transformers import HfArgumentParser 9 | from transformers.generation.logits_process import LogitsProcessor 10 | from transformers.generation.utils import LogitsProcessorList 11 | from transformers.modeling_utils import PreTrainedModel 12 | 13 | from weaverbird.config_factory import BaseModelConfig, FinetuningConfig, GenerationConfig, RetroConfig 14 | 15 | 16 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: 17 | """ 18 | Returns the number of trainable parameters and number of all parameters in the model. 19 | source: https://github.com/hiyouga/LLaMA-Efficient-Tuning 20 | """ 21 | trainable_params, all_param = 0, 0 22 | for param in model.parameters(): 23 | num_params = param.numel() 24 | # if using DS Zero 3 and the weights are initialized empty 25 | if num_params == 0 and hasattr(param, "ds_numel"): 26 | num_params = param.ds_numel 27 | 28 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 29 | if param.__class__.__name__ == "Params4bit": 30 | num_params = num_params * 2 31 | 32 | all_param += num_params 33 | if param.requires_grad: 34 | trainable_params += num_params 35 | 36 | return trainable_params, all_param 37 | 38 | 39 | def parse_configs(configs: Optional[Dict[str, Any]] = None): 40 | parser = HfArgumentParser(( 41 | BaseModelConfig, 42 | FinetuningConfig, 43 | GenerationConfig, 44 | RetroConfig 45 | )) 46 | 47 | parsed_config = _parse_args(parser, configs) 48 | return {'model_config': parsed_config[0], 49 | 'finetuning_config': parsed_config[1], 50 | 'generation_config': parsed_config[2], 51 | 'retro_config': parsed_config[3]} 52 | 53 | 54 | def load_yaml_config(config_dir): 55 | """ Load yaml config file from disk. 56 | 57 | Args: 58 | config_dir: str or Path 59 | The path of the config file. 60 | 61 | Returns: 62 | Config: dict. 63 | """ 64 | with open(config_dir) as config_file: 65 | # load configs 66 | config = yaml.load(config_file, Loader=yaml.FullLoader) 67 | 68 | return config 69 | 70 | 71 | def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: 72 | if args is not None: 73 | return parser.parse_dict(args) 74 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 75 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 76 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 77 | return parser.parse_json_file(os.path.abspath(sys.argv[1])) 78 | else: 79 | return parser.parse_args_into_dataclasses() 80 | 81 | 82 | def torch_gc() -> None: 83 | """Collects GPU memory. 84 | """ 85 | if torch.cuda.is_available(): 86 | torch.cuda.empty_cache() 87 | torch.cuda.ipc_collect() 88 | 89 | 90 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: 91 | """Configures device map for ChatGLM2. 
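Non-layer modules (word embeddings, final layernorm, output layer, rotary embedding, prefix encoder and lm_head) are pinned to GPU 0, while the 28 transformer encoder layers are spread across all GPUs, i.e. roughly 30 / num_gpus weight groups per device.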
92 | 93 | Source: https://github.com/hiyouga/ChatGLM-Efficient-Tuning 94 | """ 95 | num_layers = 28 96 | layers_per_gpu = 30 / num_gpus 97 | device_map = { 98 | "transformer.embedding.word_embeddings": 0, 99 | "transformer.encoder.final_layernorm": 0, 100 | "transformer.output_layer": 0, 101 | "transformer.rotary_pos_emb": 0, 102 | "transformer.prefix_encoder": 0, 103 | "lm_head": 0 104 | } 105 | 106 | added_layers = 2 107 | target_gpu = 0 108 | 109 | for i in range(num_layers): 110 | if added_layers >= layers_per_gpu: 111 | target_gpu += 1 112 | added_layers = 0 113 | assert target_gpu < num_gpus 114 | device_map[f"transformer.encoder.layers.{i}"] = target_gpu 115 | added_layers += 1 116 | 117 | return device_map 118 | 119 | 120 | def get_current_time(): 121 | now = time.time() 122 | time_arr = time.localtime(now) 123 | return time.strftime("%Y-%m-%d", time_arr) 124 | 125 | 126 | # Avoid runtime error in model.generate(do_sample=True). 127 | # Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54 128 | class InvalidScoreLogitsProcessor(LogitsProcessor): 129 | 130 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 131 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 132 | scores.zero_() 133 | scores[..., 5] = 5e4 134 | return scores 135 | 136 | 137 | def get_logits_processor() -> LogitsProcessorList: 138 | logits_processor = LogitsProcessorList() 139 | logits_processor.append(InvalidScoreLogitsProcessor()) 140 | return logits_processor 141 | 142 | 143 | def dispatch_model(model: PreTrainedModel) -> PreTrainedModel: 144 | """Dispatches a pre-trained model to GPUs with balanced memory. 145 | 146 | Source: https://github.com/hiyouga/LLaMA-Efficient-Tuning/blob/main/src/llmtuner/extras/misc.py 147 | """ 148 | if torch.cuda.device_count() > 1: 149 | from accelerate import dispatch_model 150 | 151 | if 'chatglm' in model.name_or_path: 152 | device_map = auto_configure_device_map(torch.cuda.device_count()) 153 | else: 154 | from accelerate.utils import infer_auto_device_map, get_balanced_memory 155 | 156 | if model._no_split_modules is None: 157 | raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") 158 | 159 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} 160 | max_memory = get_balanced_memory(model, **kwargs) 161 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) 162 | 163 | model.tie_weights() 164 | return dispatch_model(model, device_map) 165 | else: 166 | return model.cuda() 167 | -------------------------------------------------------------------------------- /weaverbird/models/llm_loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from types import MethodType 3 | from typing import TYPE_CHECKING, Tuple, Optional 4 | 5 | import torch 6 | from transformers import ( 7 | AutoConfig, 8 | AutoModelForCausalLM, 9 | AutoModel, 10 | AutoTokenizer, 11 | BitsAndBytesConfig, 12 | PretrainedConfig, 13 | PreTrainedModel, 14 | PreTrainedTokenizerBase 15 | ) 16 | from transformers.deepspeed import is_deepspeed_zero3_enabled 17 | from transformers.utils.versions import require_version 18 | 19 | from weaverbird.config_factory import BaseModelConfig, FinetuningConfig 20 | from weaverbird.utils import logger, count_parameters 21 | 22 | if TYPE_CHECKING: 23 | from transformers import PreTrainedTokenizer 24 | 
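# Illustrative usage sketch (the checkpoint name below is only an example, not a project default):
#   configs = parse_configs({'model_name_or_path': 'meta-llama/Llama-2-7b-chat-hf'})
#   model, tokenizer = load_model_and_tokenizer(configs['model_config'], configs['finetuning_config'])
# parse_configs lives in weaverbird.utils.misc; the dict keys must match BaseModelConfig fields.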
25 | 26 | def load_model_and_tokenizer( 27 | model_config: BaseModelConfig, 28 | finetuning_args: Optional[FinetuningConfig] = None 29 | ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]: 30 | """ 31 | Loads pretrained model and tokenizer. 32 | source: https://github.com/hiyouga/LLaMA-Efficient-Tuning 33 | """ 34 | 35 | config_kwargs = { 36 | "trust_remote_code": True, 37 | "cache_dir": model_config.cache_dir, 38 | "revision": model_config.model_revision, 39 | "use_auth_token": True if model_config.use_auth_token else None, 40 | } 41 | 42 | tokenizer = AutoTokenizer.from_pretrained( 43 | Path(model_config.model_name_or_path), 44 | use_fast=model_config.use_fast_tokenizer, 45 | padding_side=model_config.padding_side, 46 | **config_kwargs 47 | ) 48 | 49 | if finetuning_args is not None and finetuning_args.finetuning_type == "full" and model_config.checkpoint_dir is not None: 50 | model_to_load = model_config.checkpoint_dir[0] 51 | else: 52 | model_to_load = model_config.model_name_or_path 53 | 54 | config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) 55 | 56 | if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config 57 | if model_config.compute_dtype == torch.bfloat16: 58 | setattr(config, "bf16", True) 59 | else: 60 | setattr(config, "fp16", True) 61 | 62 | # Set RoPE scaling 63 | if model_config.rope_scaling is not None: 64 | if hasattr(config, "use_dynamic_ntk"): # for Qwen models 65 | setattr(config, "use_dynamic_ntk", True) 66 | setattr(config, "use_logn_attn", True) 67 | logger.info("Using dynamic NTK scaling.") 68 | 69 | elif hasattr(config, "rope_scaling"): # for LLaMA models 70 | require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0") 71 | 72 | scaling_factor = 2.0 73 | 74 | setattr(config, "rope_scaling", {"type": model_config.rope_scaling, "factor": scaling_factor}) 75 | logger.info("Using {} scaling strategy and setting scaling factor to {}".format( 76 | model_config.rope_scaling, scaling_factor 77 | )) 78 | 79 | else: 80 | logger.warning("Current model does not support RoPE scaling.") 81 | 82 | # Quantization configurations (using bitsandbytes library). 83 | if model_config.quantization_bit is not None: 84 | if is_deepspeed_zero3_enabled(): 85 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") 86 | 87 | if model_config.quantization_bit == 8: 88 | require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") 89 | config_kwargs["load_in_8bit"] = True 90 | config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) 91 | 92 | elif model_config.quantization_bit == 4: 93 | require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") 94 | config_kwargs["load_in_4bit"] = True 95 | config_kwargs["quantization_config"] = BitsAndBytesConfig( 96 | load_in_4bit=True, 97 | bnb_4bit_compute_dtype=model_config.compute_dtype, 98 | bnb_4bit_use_double_quant=model_config.double_quantization, 99 | bnb_4bit_quant_type=model_config.quantization_type 100 | ) 101 | 102 | config_kwargs["device_map"] = "auto" 103 | logger.info("Quantizing model to {} bit.".format(model_config.quantization_bit)) 104 | 105 | # Load and prepare pre-trained models (without valuehead). 
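# Checkpoints whose name contains 'glm' are loaded through AutoModel, which is how ChatGLM exposes
# its chat model; every other checkpoint is loaded as a causal LM with the configured compute dtype
# and, outside DeepSpeed ZeRO-3, low_cpu_mem_usage enabled.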
106 | if 'glm' in model_config.model_name_or_path.lower(): 107 | model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs) 108 | else: 109 | model = AutoModelForCausalLM.from_pretrained( 110 | model_to_load, 111 | config=config, 112 | torch_dtype=model_config.compute_dtype, 113 | low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), 114 | **config_kwargs 115 | ) 116 | 117 | # Disable custom generate method (for Qwen) 118 | if "GenerationMixin" not in str(model.generate.__func__): 119 | model.generate = MethodType(PreTrainedModel.generate, model) 120 | 121 | # Fix LM head (for ChatGLM2) 122 | if not hasattr(model, "lm_head") and hasattr(model, "transformer"): 123 | setattr(model, "lm_head", model.transformer.output_layer) 124 | 125 | # Register auto class to save the custom code files. 126 | if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}): 127 | config.__class__.register_for_auto_class() 128 | if isinstance(model, PreTrainedModel) and ("AutoModelForCausalLM" in getattr(config, "auto_map", {}) or 129 | "AutoModel" in getattr(config, "auto_map", {})): 130 | model.__class__.register_for_auto_class() 131 | if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): 132 | tokenizer.__class__.register_for_auto_class() 133 | 134 | # Prepare model for inference 135 | model.requires_grad_(False) # fix all model params 136 | infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability 137 | model = model.to(infer_dtype) if model_config.quantization_bit is None else model 138 | 139 | _, all_param = count_parameters(model) 140 | logger.info("num params: {:d}".format(all_param)) 141 | 142 | return model, tokenizer 143 | -------------------------------------------------------------------------------- /weaverbird/utils/registrable.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from .log_utils import default_logger as logger 4 | 5 | 6 | class Registrable: 7 | """Any class that inherits from ``Registrable`` gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod ``@BaseClass.register(name)``. 8 | 9 | After which you can call ``BaseClass.list_available()`` to get the keys for the registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass. 10 | 11 | Note that the registry stores the subclasses themselves; not class instances. In most cases you would then call ``from_params(params)`` on the returned subclass. 12 | """ 13 | 14 | _registry = defaultdict(dict) 15 | _default_impl = None 16 | 17 | @classmethod 18 | def register(cls, name, constructor=None, overwrite=False): 19 | """Register a class under a particular name. 20 | Args: 21 | name (str): The name to register the class under. 22 | constructor (str): optional (default=None) 23 | The name of the method to use on the class to construct the object. If this is given, 24 | we will use this method (which must be a ``classmethod``) instead of the default 25 | constructor. 26 | overwrite (bool) : optional (default=False) 27 | If True, overwrites any existing models registered under ``name``. Else, 28 | throws an error if a model is already registered under ``name``. 
29 | 30 | # Examples 31 | To use this class, you would typically have a base class that inherits from ``Registrable``: 32 | ```python 33 | class Transform(Registrable): 34 | ... 35 | ``` 36 | Then, if you want to register a subclass, you decorate it like this: 37 | ```python 38 | @Transform.register("shift-transform") 39 | class ShiftTransform(Transform): 40 | def __init__(self, param1: int, param2: str): 41 | ... 42 | ``` 43 | Registering a class like this will let you instantiate a class from a config file, where you 44 | give ``"type": "shift-transform"``, and keys corresponding to the parameters of the ``__init__`` 45 | method (note that for this to work, those parameters must have type annotations). 46 | If you want to have the instantiation from a config file call a method other than the 47 | constructor, either because you have several different construction paths that could be 48 | taken for the same object (as we do in ``Transform``) or because you have logic you want to 49 | happen before you get to the constructor, you can register a specific ``@classmethod`` as the constructor to use. 50 | """ 51 | registry = Registrable._registry[cls] 52 | 53 | def add_subclass_to_registry(subclass): 54 | # Add to registry, raise an error if key has already been used. 55 | if name in registry: 56 | if overwrite: 57 | message = ( 58 | f"{name} has already been registered as {registry[name][0].__name__}, but " 59 | f"overwrite=True, so overwriting with {cls.__name__}" 60 | ) 61 | logger.info(message) 62 | else: 63 | message = ( 64 | f"Cannot register {name} as {cls.__name__}; " 65 | f"name already in use for {registry[name][0].__name__}" 66 | ) 67 | raise RuntimeError(message) 68 | registry[name] = (subclass, constructor) 69 | return subclass 70 | 71 | return add_subclass_to_registry 72 | 73 | @classmethod 74 | def by_name(cls, name): 75 | """ 76 | Returns a callable function that constructs an argument of the registered class. Because 77 | you can register particular functions as constructors for specific names, this isn't 78 | necessarily the ``__init__`` method of some class. 79 | """ 80 | logger.debug(f"instantiating registered subclass {name} of {cls}") 81 | subclass, constructor = cls.resolve_class_name(name) 82 | if not constructor: 83 | return subclass 84 | else: 85 | return getattr(subclass, constructor) 86 | 87 | @classmethod 88 | def resolve_class_name(cls, name): 89 | """ 90 | Returns the subclass that corresponds to the given ``name``, along with the name of the 91 | method that was registered as a constructor for that ``name``, if any. 92 | This method also allows ``name`` to be a fully-specified module name, instead of a name that 93 | was already added to the ``Registry``. In that case, you cannot use a separate function as 94 | a constructor (as you need to call ``cls.register()`` in order to tell us what separate 95 | function to use). 96 | """ 97 | if name in Registrable._registry[cls]: 98 | subclass, constructor = Registrable._registry[cls].get(name) 99 | return subclass, constructor 100 | else: 101 | for base_cls, v in Registrable._registry.items(): 102 | if name in v: 103 | subclass, constructor = Registrable._registry[base_cls].get(name) 104 | return subclass, constructor 105 | 106 | if "." in name: 107 | # This might be a fully qualified class name, so we'll try importing its "module" 108 | # and finding it there. 
109 | parts = name.split(".") 110 | submodule = ".".join(parts[:-1]) 111 | class_name = parts[-1] 112 | import importlib 113 | try: 114 | module = importlib.import_module(submodule) 115 | except ModuleNotFoundError: 116 | raise RuntimeError( 117 | f"tried to interpret {name} as a path to a class " 118 | f"but unable to import module {submodule}" 119 | ) 120 | 121 | try: 122 | subclass = getattr(module, class_name) 123 | constructor = None 124 | return subclass, constructor 125 | except AttributeError: 126 | raise RuntimeError( 127 | f"tried to interpret {name} as a path to a class " 128 | f"but unable to find class {class_name} in {submodule}" 129 | ) 130 | 131 | else: 132 | # is not a qualified class name 133 | raise RuntimeError( 134 | f"{name} is not a registered name for {cls.__name__}. " 135 | "You probably need to use the --include-package flag " 136 | "to load your custom code. Alternatively, you can specify your choices " 137 | """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """ 138 | "in which case they will be automatically imported correctly." 139 | ) 140 | 141 | @classmethod 142 | def list_available(cls): 143 | """List default first if it exists""" 144 | keys = list(Registrable._registry[cls].keys()) 145 | default = cls._default_impl 146 | 147 | if default is None: 148 | return keys 149 | elif default not in keys: 150 | raise RuntimeError(f"Default implementation {default} is not registered") 151 | else: 152 | return [default] + [k for k in keys if k != default] 153 | 154 | @classmethod 155 | def registry_dict(cls): 156 | return Registrable._registry[cls] -------------------------------------------------------------------------------- /scripts/train_encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import time 6 | 7 | import datasets 8 | import torch 9 | from datasets.load import load_from_disk 10 | from torch.optim import AdamW 11 | from torch.utils.data import DataLoader 12 | from transformers import AutoTokenizer, AutoModel 13 | 14 | 15 | class QueryRefEncoderMainModel(torch.nn.Module): 16 | """ 17 | heavily borrowed from WebGLM: https://github.com/THUDM/WebGLM 18 | """ 19 | 20 | def __init__(self, model_dir) -> None: 21 | super().__init__() 22 | self.question_encoder = AutoModel.from_pretrained(model_dir) 23 | self.reference_encoder = AutoModel.from_pretrained(model_dir) 24 | 25 | total = sum([param.nelement() for param in self.parameters()]) 26 | print("Number of parameter: %.2fM" % (total / 1e6)) 27 | 28 | def mean_pooling(self, token_embeddings, mask): 29 | token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.) 
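# padded positions were zeroed out above, so summing over the sequence dimension and dividing by
# the number of valid tokens yields a mean over real tokens only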
30 | sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None] 31 | return sentence_embeddings 32 | 33 | def forward(self, question, pos, neg): 34 | global args 35 | 36 | q = self.question_encoder(**question) 37 | r_pos = self.reference_encoder(**pos) 38 | r_neg = self.reference_encoder(**neg) 39 | cls_q = self.mean_pooling(q[0], question["attention_mask"]) 40 | cls_q /= args.temp 41 | cls_r_pos = self.mean_pooling(r_pos[0], pos["attention_mask"]) 42 | cls_r_neg = self.mean_pooling(r_neg[0], neg["attention_mask"]) 43 | 44 | method = "cos" 45 | 46 | if method == "inner_product": 47 | l_pos = torch.matmul(cls_q, torch.transpose(cls_r_pos, 0, 1)) 48 | l_neg = torch.matmul(cls_q, torch.transpose(cls_r_neg, 0, 1)) 49 | elif method == "cos": 50 | l_pos = torch.matmul(cls_q, torch.transpose(cls_r_pos, 0, 1)) / (cls_q.norm() * cls_r_pos.norm()) 51 | l_neg = torch.matmul(cls_q, torch.transpose(cls_r_neg, 0, 1)) / (cls_q.norm() * cls_r_neg.norm()) 52 | else: 53 | raise NotImplementedError 54 | 55 | return l_pos, l_neg 56 | 57 | @staticmethod 58 | def loss(l_pos, l_neg): 59 | return torch.nn.functional.cross_entropy(torch.cat([l_pos, l_neg], dim=1), 60 | torch.arange(0, len(l_pos), dtype=torch.long, device=args.device)) 61 | 62 | @staticmethod 63 | def num_correct(l_pos, l_neg): 64 | return ((torch.diag(l_pos) > torch.diag(l_neg)) == True).sum() 65 | 66 | @staticmethod 67 | def acc(l_pos, l_neg): 68 | return ((torch.diag(l_pos) > torch.diag(l_neg)) == True).sum() / len(l_pos) 69 | 70 | 71 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): 72 | def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1): 73 | self.warmup = warmup 74 | self.total = total 75 | self.ratio = ratio 76 | super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 77 | 78 | def lr_lambda(self, step): 79 | if step < self.warmup: 80 | return (1 - self.ratio) * step / float(max(1, self.warmup)) 81 | 82 | return max( 83 | 0.0, 84 | 1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)), 85 | ) 86 | 87 | 88 | def move_dict_to_device(obj, device): 89 | for key in obj: 90 | obj[key] = obj[key].to(device) 91 | 92 | 93 | def collate(data): 94 | question = tokenizer([item["question"] for item in data], return_tensors="pt", padding=True, truncation=True) 95 | positive_reference = tokenizer([item["positive_reference"] for item in data], return_tensors="pt", padding=True, 96 | truncation=True) 97 | negative_reference = tokenizer([item["negative_reference"] for item in data], return_tensors="pt", padding=True, 98 | truncation=True) 99 | 100 | for key in question: question[key] = question[key].to(args.device) 101 | for key in positive_reference: positive_reference[key] = positive_reference[key].to(args.device) 102 | for key in negative_reference: negative_reference[key] = negative_reference[key].to(args.device) 103 | 104 | return question, positive_reference, negative_reference 105 | 106 | 107 | def eval(): 108 | model.eval() 109 | with torch.no_grad(): 110 | total_acc = 0 111 | for q, pos, neg in eval_loader: 112 | results = model(q, pos, neg) 113 | tot_cr = model.num_correct(*results) 114 | total_acc += tot_cr 115 | 116 | print("EVALUATION, Acc: %10.6f" % (total_acc / len(eval_set))) 117 | 118 | 119 | def save(name): 120 | os.makedirs(log_dir, exist_ok=True) 121 | model.question_encoder.save_pretrained(os.path.join(log_dir, name, "query_encoder")) 122 | model.reference_encoder.save_pretrained(os.path.join(log_dir, name, 
"reference_encoder")) 123 | 124 | 125 | def train(max_epoch=10, eval_step=200, save_step=400, print_step=50): 126 | step = 0 127 | for epoch in range(0, max_epoch): 128 | print("EPOCH %d" % epoch) 129 | for q, pos, neg in train_loader: 130 | model.train() 131 | step += 1 132 | opt.zero_grad() 133 | results = model(q, pos, neg) 134 | loss = model.loss(*results) 135 | 136 | if step % print_step == 0: 137 | print("Step %4d, Loss, Acc: %10.6f, %10.6f" % (step, loss, model.acc(*results))) 138 | 139 | loss.backward() 140 | opt.step() 141 | 142 | scheduler.step() 143 | model.zero_grad() 144 | if step % eval_step == 0: 145 | eval() 146 | pass 147 | if step % save_step == 0: 148 | save("step-%d" % (step)) 149 | 150 | save("step-%d-epoch-%d" % (step, epoch)) 151 | # eval() 152 | 153 | 154 | def data_process(): 155 | """Use preprocesed data """ 156 | with open("../raw_data/retro_source_report.txt", "r") as file: 157 | document = file.read() 158 | documents = document.split("Doc ") 159 | documents = ["Doc " + document for document in documents[1:]] 160 | 161 | with open("../raw_data/retro_qa.json", "r") as file: 162 | data = json.load(file) 163 | 164 | features = [] 165 | for item in data: 166 | print("item: ", item) 167 | positive_ids = [int(i) for i in item["pos_index"].split(",")] 168 | negative_ids = [int(i) for i in item["neg_index"].split(",")] 169 | for pos, neg in itertools.product(positive_ids, negative_ids): 170 | print(len(features), pos, neg) 171 | features.append({ 172 | "question": item["input"], 173 | "positive_label": pos, 174 | "positive_reference": documents[pos], 175 | "negative_label": neg, 176 | "negative_reference": documents[neg] 177 | }) 178 | num_training = 5 * int(len(data) * 0.8) 179 | train_data = features[:num_training] 180 | eval_data = features[num_training:] 181 | train_data = datasets.Dataset.from_list(train_data) 182 | eval_data = datasets.Dataset.from_list(eval_data) 183 | train_data.save_to_disk("../raw_data/retriever/train") 184 | eval_data.save_to_disk("../raw_data/retriever/eval") 185 | 186 | 187 | if __name__ == "__main__": 188 | args = argparse.ArgumentParser() 189 | args.add_argument("--max_epoch", type=int, default=3) 190 | args.add_argument("--eval_step", type=int, default=40) 191 | args.add_argument("--save_step", type=int, default=40) 192 | args.add_argument("--print_step", type=int, default=40) 193 | args.add_argument("--device", type=str, default="cpu") 194 | args.add_argument("--temp", type=float, default=0.05) 195 | args.add_argument("--train_batch_size", type=int, default=64) 196 | args.add_argument("--eval_batch_size", type=int, default=32) 197 | args.add_argument("--lr", type=float, default=1e-6) 198 | args.add_argument("--warmup", type=int, default=100) 199 | args.add_argument("--total", type=int, default=1000) 200 | args.add_argument("--ratio", type=float, default=0.0) 201 | args.add_argument("--save_dir", type=str, default="./retriever_runs") 202 | args.add_argument("--train_data_dir", type=str, default="../raw_data/retriever") 203 | args.add_argument("--train_data_dir", type=str, default="m3e-small") 204 | 205 | args = args.parse_args() 206 | 207 | log_dir = os.path.join(args.save_dir, time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time()))) 208 | 209 | train_set = load_from_disk(os.path.join(args.train_data_dir, "train")) 210 | eval_set = load_from_disk(os.path.join(args.train_data_dir, "eval")) 211 | 212 | tokenizer = AutoTokenizer.from_pretrained(args.model_dir) 213 | train_loader = DataLoader(train_set, batch_size=args.train_batch_size, 
collate_fn=collate) 214 | eval_loader = DataLoader(eval_set, batch_size=args.eval_batch_size, collate_fn=collate) 215 | 216 | model = QueryRefEncoderMainModel(args.model_dir) 217 | model = model.to(args.device) 218 | opt = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01) 219 | scheduler_args = { 220 | "warmup": args.warmup, 221 | "total": args.total, 222 | "ratio": args.ratio, 223 | } 224 | scheduler = WarmupLinearScheduler(opt, **scheduler_args) 225 | temp = args.temp 226 | 227 | train(max_epoch=args.max_epoch, eval_step=args.eval_step, save_step=args.save_step, print_step=args.print_step) 228 | -------------------------------------------------------------------------------- /webui/chatbot_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from weaverbird.utils import get_kbs_list 4 | 5 | 6 | class Demo: 7 | theme = gr.themes.Soft() 8 | 9 | block_css = """.importantButton { 10 | background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; 11 | border: none !important; 12 | } 13 | .importantButton:hover { 14 | background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important; 15 | border: none !important; 16 | } 17 | a { 18 | cursor: pointer; 19 | text-decoration: none !important; 20 | align-items: center; 21 | justify-content: center; 22 | min-width: max-content; 23 | height: 24px; 24 | border-radius: 4px; 25 | box-sizing: border-box !important; 26 | padding: 0px 8px; 27 | color: #174ae4 !important; 28 | background-color: #d1dbfa; 29 | margin-inline-end: 6px; 30 | } 31 | 32 | a:hover { 33 | text-decoration: underline; 34 | } 35 | .message { 36 | width: auto !important; 37 | } 38 | .custom_btn { 39 | font-size: small !important; 40 | } 41 | .custom_btn_2 { 42 | font-size: 1.5em; 43 | } 44 | img { 45 | max-width: 100%; 46 | max-height:100%; 47 | } 48 | .container { 49 | display: grid; 50 | align-items: center; 51 | grid-template-columns: 1fr 1fr 1fr; 52 | column-gap: 1px; 53 | } 54 | .column { 55 | float: left; 56 | width: 33%; 57 | } 58 | 59 | #tracker { 60 | background-color: transparent; 61 | margin-inline-end: unset; 62 | padding: unset; 63 | } 64 | #tracker img { 65 | margin-left: 4em; 66 | } 67 | .custom_height { 68 | height: 2.5em; 69 | } 70 | .custom_width { 71 | max-width: 2.5em; 72 | min-width: 2.5em !important; 73 | } 74 | """ 75 | 76 | en_title = """ 77 |

WeaverBird

78 | """ 79 | 80 | en_sub_title = """ 81 |

An Open and Light GPT for Finance

82 | """ 83 | 84 | en_examples = [ 85 | ["How will ChatGPT contribute to Nvidia's AI business in the short term?"], 86 | ["What does Tesla’s Elon Musk think of BYD rivalry"], 87 | ['What are the growth prospects for Microsoft Corporation in coming years'] 88 | ] 89 | 90 | en_input_text = gr.Textbox( 91 | show_label=False, 92 | placeholder="""Ask a question and press ENTER. Be specific: use company names and specify times for best results. 93 | """, 94 | container=False) 95 | 96 | cn_title = """ 97 |

织工鸟

98 | """ 99 | 100 | cn_sub_title = """ 101 |

一个开源且轻量级的金融领域GPT

102 | """ 103 | 104 | cn_examples = [ 105 | ['阿里巴巴的2023年Q1净利润多少?'], 106 | ['请写一篇公司简评,标题为比亚迪(002594.SZ):2023年一季度业绩高速增长'], 107 | ["半夏资本李蓓的最新投资观点是什么"], 108 | ] 109 | 110 | cn_input_text = gr.Textbox( 111 | show_label=False, 112 | placeholder="输入问题,按回车键提交。请具体一些并包含公司名和时间段,这样效果会更好。", 113 | container=False) 114 | 115 | kb_root_dir = '' 116 | 117 | def __init__(self, **kwargs): 118 | self.theme = kwargs.pop('theme', self.theme) 119 | self.block_css = kwargs.pop('block_css', self.block_css) 120 | self.en_title = kwargs.pop('en_title', self.en_title) 121 | self.en_examples = kwargs.pop('en_examples', self.en_examples) 122 | self.cn_title = kwargs.pop('cn_title', self.cn_title) 123 | self.cn_examples = kwargs.pop('cn_examples', self.cn_examples) 124 | self.kb_root_dir = kwargs.pop('kb_root_dir', self.kb_root_dir) 125 | 126 | @staticmethod 127 | def set_example(example: list) -> dict: 128 | return gr.Textbox.update(value=example[0]) 129 | 130 | @staticmethod 131 | def reset_history(): 132 | return [], [] 133 | 134 | def init_model(self): 135 | return 136 | 137 | def get_answer(self, query, chatbot, history, vs_name, search_engine): 138 | return 139 | 140 | def run(self): 141 | with gr.Blocks(css=self.block_css, theme=self.theme) as demo: 142 | chat_history = gr.State([]) 143 | with gr.Tab('English'): 144 | gr.HTML(self.en_title) 145 | gr.HTML(self.en_sub_title) 146 | with gr.Tab('Chat'): 147 | with gr.Row(): 148 | with gr.Column(scale=1): 149 | search_engine = gr.Radio(["Off", "On"], label="Search Engine", value="On") 150 | kb_list = get_kbs_list(self.kb_root_dir) 151 | select_kb = gr.Dropdown( 152 | kb_list, 153 | label="Knowledge Base", 154 | interactive=True, 155 | value=kb_list[0] if len(kb_list) > 0 else None 156 | ) 157 | with gr.Accordion("Try Asking About"): 158 | example_text = gr.Examples(examples=self.en_examples, 159 | fn=self.set_example, 160 | inputs=self.en_input_text, 161 | outputs=self.en_input_text, 162 | label="Examples") 163 | with gr.Column(scale=5): 164 | with gr.Row(): 165 | chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=500) 166 | with gr.Row(): 167 | empty_btn = gr.Button("🗑️ ", 168 | elem_classes=['custom_height', 'custom_width', 'custom_btn_2']) 169 | self.en_input_text.render() 170 | sub_btn = gr.Button("➡️", 171 | elem_classes=['custom_height', 'custom_width', 'custom_btn_2']) 172 | sub_btn.click(self.get_answer, 173 | [self.en_input_text, chatbot, chat_history, select_kb, search_engine], 174 | [chatbot, chat_history, self.en_input_text]) 175 | 176 | empty_btn.click(self.reset_history, outputs=[chatbot, chat_history], show_progress=True) 177 | 178 | self.en_input_text.submit(self.get_answer, 179 | [self.en_input_text, chatbot, chat_history, select_kb, search_engine], 180 | [chatbot, chat_history, self.en_input_text]) 181 | 182 | with gr.Tab('中文'): 183 | gr.HTML(self.cn_title) 184 | gr.HTML(self.cn_sub_title) 185 | with gr.Tab('对话'): 186 | with gr.Row(): 187 | with gr.Column(scale=1): 188 | search_engine = gr.Radio(["关", "开"], label="搜索引擎", value="开") 189 | kb_list = get_kbs_list(self.kb_root_dir) 190 | select_kb = gr.Dropdown( 191 | kb_list, 192 | label="知识库", 193 | interactive=True, 194 | value=kb_list[0] if len(kb_list) > 0 else None 195 | ) 196 | with gr.Accordion("可以尝试问这些问题"): 197 | example_text = gr.Examples(examples=self.cn_examples, 198 | fn=self.set_example, 199 | inputs=self.cn_input_text, 200 | outputs=self.cn_input_text, 201 | label="参考问题") 202 | with gr.Column(scale=5): 203 | with gr.Row(): 204 | chatbot = gr.Chatbot(elem_id="chat-box", 
show_label=False, height=500)
205 | with gr.Row():
206 | empty_btn = gr.Button("🗑️ ",
207 | elem_classes=['custom_height', 'custom_width', 'custom_btn_2'])
208 | self.cn_input_text.render()
209 | sub_btn = gr.Button("➡️",
210 | elem_classes=['custom_height', 'custom_width', 'custom_btn_2'])
211 | sub_btn.click(self.get_answer,
212 | [self.cn_input_text, chatbot, chat_history, select_kb, search_engine],
213 | [chatbot, chat_history, self.cn_input_text])
214 | 
215 | empty_btn.click(self.reset_history, outputs=[chatbot, chat_history], show_progress=True)
216 | 
217 | self.cn_input_text.submit(self.get_answer,
218 | [self.cn_input_text, chatbot, chat_history, select_kb, search_engine],
219 | [chatbot, chat_history, self.cn_input_text])
220 | 
221 | demo.queue(concurrency_count=50).launch(
222 | server_name='0.0.0.0',
223 | show_api=False,
224 | share=False,
225 | inbrowser=False)
226 | 
227 | 
228 | def main():
229 | my_demo = Demo()
230 | 
231 | my_demo.run()
232 | 
233 | 
234 | if __name__ == '__main__':
235 | main()
236 | 
--------------------------------------------------------------------------------
/weaverbird/models/template.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
3 | 
4 | import tiktoken
5 | 
6 | if TYPE_CHECKING:
7 | from transformers import PreTrainedTokenizer
8 | from weaverbird.utils import logger
9 | 
10 | 
11 | # Source: https://github.com/hiyouga/LLaMA-Efficient-Tuning/blob/469f859161dec0e34f4cc849f20e43d442680b5c/src/llmtuner/extras/template.py#L183
12 | 
13 | @dataclass
14 | class Template:
15 | prefix: List[Union[str, Dict[str, str]]]
16 | prompt: List[Union[str, Dict[str, str]]]
17 | system: str
18 | sep: List[Union[str, Dict[str, str]]]
19 | stop_words: List[str]
20 | use_history: bool
21 | efficient_eos: bool
22 | 
23 | def encode_oneturn(
24 | self,
25 | tokenizer: "PreTrainedTokenizer",
26 | query: str,
27 | resp: str,
28 | history: Optional[List[Tuple[str, str]]] = None,
29 | system: Optional[str] = None
30 | ) -> Tuple[List[int], List[int]]:
31 | r"""
32 | Returns a single pair of token ids representing prompt and response respectively.
33 | """
34 | system, history = self._format(query, resp, history, system)
35 | encoded_pairs = self._encode(tokenizer, system, history)
36 | prompt_ids = []
37 | for query_ids, resp_ids in encoded_pairs[:-1]:
38 | prompt_ids = prompt_ids + query_ids + resp_ids
39 | prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
40 | return prompt_ids, answer_ids
41 | 
42 | def encode_multiturn(
43 | self,
44 | tokenizer: "PreTrainedTokenizer",
45 | query: str,
46 | resp: str,
47 | history: Optional[List[Tuple[str, str]]] = None,
48 | system: Optional[str] = None
49 | ) -> List[Tuple[List[int], List[int]]]:
50 | r"""
51 | Returns multiple pairs of token ids representing prompts and responses respectively.
52 | """
53 | system, history = self._format(query, resp, history, system)
54 | encoded_pairs = self._encode(tokenizer, system, history)
55 | return encoded_pairs
56 | 
57 | def _format(
58 | self,
59 | query: str,
60 | resp: str,
61 | history: Optional[List[Tuple[str, str]]] = None,
62 | system: Optional[str] = None
63 | ) -> Tuple[str, List[Tuple[str, str]]]:
64 | r"""
65 | Aligns inputs to the standard format.
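The system prompt falls back to the template default when none is given, previous turns are kept only if the template uses history, and the current (query, resp) pair is appended as the final turn.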
66 | """ 67 | system = system or self.system # use system if provided 68 | history = history if (history and self.use_history) else [] 69 | history = history + [(query, resp)] 70 | return system, history 71 | 72 | def _get_special_ids( 73 | self, 74 | tokenizer: "PreTrainedTokenizer" 75 | ) -> Tuple[List[int], List[int]]: 76 | if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): 77 | bos_ids = [tokenizer.bos_token_id] 78 | else: # baichuan, qwen and gpt2 models have no bos token 79 | bos_ids = [] 80 | 81 | if tokenizer.eos_token_id is None: 82 | raise ValueError("EOS token is required.") 83 | 84 | if self.efficient_eos: # used in baichuan, qwen, chatglm, etc. 85 | eos_ids = [] 86 | else: 87 | eos_ids = [tokenizer.eos_token_id] 88 | 89 | return bos_ids, eos_ids 90 | 91 | def _encode( 92 | self, 93 | tokenizer: "PreTrainedTokenizer", 94 | system: str, 95 | history: List[Tuple[str, str]] 96 | ) -> List[Tuple[List[int], List[int]]]: 97 | r""" 98 | Encodes formatted inputs to pairs of token ids. 99 | Turn 0: bos + prefix + sep + query resp + eos 100 | Turn t: sep + bos + query resp + eos 101 | """ 102 | bos_ids, eos_ids = self._get_special_ids(tokenizer) 103 | sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) 104 | encoded_pairs = [] 105 | for turn_idx, (query, resp) in enumerate(history): 106 | if turn_idx == 0: 107 | prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) 108 | if len(prefix_ids) != 0: # has prefix 109 | prefix_ids = bos_ids + prefix_ids + sep_ids 110 | else: 111 | prefix_ids = bos_ids 112 | else: 113 | prefix_ids = sep_ids + bos_ids 114 | 115 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) 116 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) 117 | encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) 118 | return encoded_pairs 119 | 120 | def _convert_inputs_to_ids( 121 | self, 122 | tokenizer: "PreTrainedTokenizer", 123 | context: List[Union[str, Dict[str, str]]], 124 | system: Optional[str] = None, 125 | query: Optional[str] = None, 126 | idx: Optional[str] = None 127 | ) -> List[int]: 128 | """ 129 | Converts context to token ids. 130 | """ 131 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) 132 | kwargs = dict(allowed_special="all") 133 | else: 134 | kwargs = dict(add_special_tokens=False) 135 | 136 | token_ids = [] 137 | for elem in context: 138 | if isinstance(elem, str): 139 | if len(elem) == 0: 140 | continue 141 | elem = elem.replace("{{system}}", system, 1) if system is not None else elem 142 | elem = elem.replace("{{query}}", query, 1) if query is not None else elem 143 | elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem 144 | token_ids = token_ids + tokenizer.encode(elem, **kwargs) 145 | elif isinstance(elem, dict): 146 | token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] 147 | else: 148 | raise NotImplementedError 149 | 150 | return token_ids 151 | 152 | 153 | @dataclass 154 | class Llama2Template(Template): 155 | 156 | def _encode( 157 | self, 158 | tokenizer: "PreTrainedTokenizer", 159 | system: str, 160 | history: List[Tuple[str, str]] 161 | ) -> List[Tuple[List[int], List[int]]]: 162 | """ 163 | Encodes formatted inputs to pairs of token ids. 
164 | Turn 0: bos + prefix + query resp + eos
165 | Turn t: bos + query resp + eos
166 | """
167 | bos_ids, eos_ids = self._get_special_ids(tokenizer)
168 | encoded_pairs = []
169 | for turn_idx, (query, resp) in enumerate(history):
170 | if turn_idx == 0: # llama2 template has no sep_ids
171 | query = self.prefix[0].replace("{{system}}", system) + query
172 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
173 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
174 | encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
175 | return encoded_pairs
176 | 
177 | 
178 | templates: Dict[str, Template] = {}
179 | 
180 | 
181 | def register_template(
182 | name: str,
183 | prefix: List[Union[str, Dict[str, str]]],
184 | prompt: List[Union[str, Dict[str, str]]],
185 | system: str,
186 | sep: List[Union[str, Dict[str, str]]],
187 | stop_words: Optional[List[str]] = [],
188 | use_history: Optional[bool] = True,
189 | efficient_eos: Optional[bool] = False
190 | ) -> None:
191 | template_class = Llama2Template if "llama2" in name else Template
192 | templates[name] = template_class(
193 | prefix=prefix,
194 | prompt=prompt,
195 | system=system,
196 | sep=sep,
197 | stop_words=stop_words,
198 | use_history=use_history,
199 | efficient_eos=efficient_eos
200 | )
201 | 
202 | 
203 | def get_template_and_fix_tokenizer(
204 | name: str,
205 | tokenizer: "PreTrainedTokenizer"
206 | ) -> Template:
207 | if tokenizer.eos_token_id is None:
208 | tokenizer.eos_token = "<|endoftext|>"
209 | logger.info("Add eos token: {}".format(tokenizer.eos_token))
210 | 
211 | if tokenizer.pad_token_id is None:
212 | tokenizer.pad_token = tokenizer.eos_token
213 | logger.info("Add pad token: {}".format(tokenizer.pad_token))
214 | 
215 | if name is None:
216 | return None
217 | 
218 | template = templates.get(name, None)
219 | assert template is not None, "Template {} does not exist.".format(name)
220 | tokenizer.add_special_tokens(
221 | dict(additional_special_tokens=template.stop_words),
222 | replace_additional_special_tokens=False
223 | )
224 | return template
225 | 
226 | 
227 | r"""
228 | Supports language model inference without histories.
229 | """
230 | register_template(
231 | name="vanilla",
232 | prefix=[],
233 | prompt=[
234 | "{{query}}"
235 | ],
236 | system="",
237 | sep=[],
238 | use_history=False
239 | )
240 | 
241 | r"""
242 | Default template.
243 | """
244 | register_template(
245 | name="default",
246 | prefix=[
247 | "{{system}}"
248 | ],
249 | prompt=[
250 | "Human: {{query}}\nAssistant: "
251 | ],
252 | system=(
253 | "A chat between a curious user and an artificial intelligence assistant. "
254 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
255 | ),
256 | sep=[
257 | "\n"
258 | ]
259 | )
260 | 
261 | r"""
262 | Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
263 | https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
264 | https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
265 | """
266 | register_template(
267 | name="llama2",
268 | prefix=[
269 | "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
270 | ],
271 | prompt=[
272 | "[INST] {{query}} [/INST] "
273 | ],
274 | system=(
275 | "You are a helpful, respectful and honest assistant. "
276 | "Always answer as helpfully as possible, while being safe. "
277 | "Your answers should not include any harmful, unethical, "
278 | "racist, sexist, toxic, dangerous, or illegal content. "
" 279 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 280 | "If a question does not make any sense, or is not factually coherent, " 281 | "explain why instead of answering something not correct. " 282 | "If you don't know the answer to a question, please don't share false information." 283 | ), 284 | sep=[] 285 | ) 286 | 287 | r""" 288 | Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 289 | https://huggingface.co/ziqingyang/chinese-alpaca-2-7b 290 | """ 291 | register_template( 292 | name="llama2_zh", 293 | prefix=[ 294 | "<>\n{{system}}\n<>\n\n" 295 | ], 296 | prompt=[ 297 | "[INST] {{query}} [/INST] " 298 | ], 299 | system="You are a helpful assistant. 你是一个乐于助人的助手。", 300 | sep=[] 301 | ) 302 | 303 | r""" 304 | Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff 305 | https://github.com/ymcui/Chinese-LLaMA-Alpaca 306 | """ 307 | register_template( 308 | name="alpaca", 309 | prefix=[ 310 | "{{system}}" 311 | ], 312 | prompt=[ 313 | "### Instruction:\n{{query}}\n\n### Response:\n" 314 | ], 315 | system=( 316 | "Below is an instruction that describes a task. " 317 | "Write a response that appropriately completes the request." 318 | ), 319 | sep=[ 320 | "\n\n" 321 | ] 322 | ) 323 | 324 | r""" 325 | Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 326 | https://huggingface.co/lmsys/vicuna-13b-delta-v1.1 327 | """ 328 | register_template( 329 | name="vicuna", 330 | prefix=[ 331 | "{{system}}" 332 | ], 333 | prompt=[ 334 | "USER: {{query}} ASSISTANT: " 335 | ], 336 | system=( 337 | "A chat between a curious user and an artificial intelligence assistant. " 338 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 339 | ), 340 | sep=[] 341 | ) 342 | 343 | r""" 344 | Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B 345 | """ 346 | register_template( 347 | name="belle", 348 | prefix=[ 349 | "{{system}}" 350 | ], 351 | prompt=[ 352 | "Human: {{query}}\n\nBelle: " 353 | ], 354 | system="", 355 | sep=[ 356 | "\n\n" 357 | ] 358 | ) 359 | 360 | r""" 361 | Supports: https://github.com/CVI-SZU/Linly 362 | """ 363 | register_template( 364 | name="linly", 365 | prefix=[ 366 | "{{system}}" 367 | ], 368 | prompt=[ 369 | "User: {{query}}\nBot: " 370 | ], 371 | system="", 372 | sep=[ 373 | "\n" 374 | ] 375 | ) 376 | 377 | r""" 378 | Supports: https://github.com/Neutralzz/BiLLa 379 | """ 380 | register_template( 381 | name="billa", 382 | prefix=[ 383 | "{{system}}" 384 | ], 385 | prompt=[ 386 | "Human: {{query}}\nAssistant: " 387 | ], 388 | system="", 389 | sep=[ 390 | "\n" 391 | ] 392 | ) 393 | 394 | r""" 395 | Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 396 | """ 397 | register_template( 398 | name="ziya", 399 | prefix=[ 400 | "{{system}}" 401 | ], 402 | prompt=[ 403 | {"token": ""}, 404 | ":{{query}}\n", 405 | {"token": ""}, 406 | ":" 407 | ], 408 | system="", 409 | sep=[ 410 | "\n" 411 | ] 412 | ) 413 | 414 | r""" 415 | Supports: https://huggingface.co/qhduan/aquilachat-7b 416 | """ 417 | register_template( 418 | name="aquila", 419 | prefix=[ 420 | "{{system}}" 421 | ], 422 | prompt=[ 423 | "Human: {{query}}###Assistant: " 424 | ], 425 | system=( 426 | "A chat between a curious human and an artificial intelligence assistant. " 427 | "The assistant gives helpful, detailed, and polite answers to the human's questions." 
428 | ),
429 | sep=[
430 | "###"
431 | ]
432 | )
433 | 
434 | r"""
435 | Supports: https://huggingface.co/internlm/internlm-chat-7b
436 | """
437 | register_template(
438 | name="intern",
439 | prefix=[
440 | "{{system}}"
441 | ],
442 | prompt=[
443 | "<|User|>:{{query}}",
444 | {"token": "<eoh>"},
445 | "\n<|Bot|>:"
446 | ],
447 | system="",
448 | sep=[
449 | {"token": "<eoa>"},
450 | "\n"
451 | ],
452 | stop_words=[
453 | "<eoa>"
454 | ],
455 | efficient_eos=True
456 | )
457 | 
458 | r"""
459 | Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
460 | """
461 | register_template(
462 | name="baichuan",
463 | prefix=[
464 | "{{system}}"
465 | ],
466 | prompt=[
467 | {"token": "<reserved_102>"}, # user token
468 | "{{query}}",
469 | {"token": "<reserved_103>"} # assistant token
470 | ],
471 | system="",
472 | sep=[],
473 | efficient_eos=True
474 | )
475 | 
476 | r"""
477 | Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
478 | https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
479 | """
480 | register_template(
481 | name="baichuan2",
482 | prefix=[
483 | "{{system}}"
484 | ],
485 | prompt=[
486 | {"token": "<reserved_106>"}, # user token
487 | "{{query}}",
488 | {"token": "<reserved_107>"} # assistant token
489 | ],
490 | system="",
491 | sep=[],
492 | efficient_eos=True
493 | )
494 | 
495 | r"""
496 | Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
497 | https://huggingface.co/HuggingFaceH4/starchat-beta
498 | """
499 | register_template(
500 | name="starchat",
501 | prefix=[
502 | {"token": "<|system|>"},
503 | "\n{{system}}",
504 | ],
505 | prompt=[
506 | {"token": "<|user|>"},
507 | "\n{{query}}",
508 | {"token": "<|end|>"},
509 | "\n",
510 | {"token": "<|assistant|>"}
511 | ],
512 | system="",
513 | sep=[
514 | {"token": "<|end|>"},
515 | "\n"
516 | ],
517 | stop_words=[
518 | "<|end|>"
519 | ],
520 | efficient_eos=True
521 | )
522 | 
523 | r"""
524 | Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
525 | """
526 | register_template(
527 | name="chatml",
528 | prefix=[
529 | {"token": "<|im_start|>"},
530 | "system\n{{system}}"
531 | ],
532 | prompt=[
533 | {"token": "<|im_start|>"},
534 | "user\n{{query}}",
535 | {"token": "<|im_end|>"},
536 | "\n",
537 | {"token": "<|im_start|>"},
538 | "assistant\n"
539 | ],
540 | system="You are a helpful assistant.",
541 | sep=[
542 | {"token": "<|im_end|>"},
543 | "\n"
544 | ],
545 | stop_words=[
546 | "<|im_end|>"
547 | ],
548 | efficient_eos=True
549 | )
550 | 
551 | r"""
552 | Supports: https://huggingface.co/THUDM/chatglm2-6b
553 | """
554 | register_template(
555 | name="chatglm2",
556 | prefix=[
557 | {"token": "[gMASK]"},
558 | {"token": "sop"},
559 | "{{system}}"
560 | ],
561 | prompt=[
562 | "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
563 | ],
564 | system="",
565 | sep=[
566 | "\n\n"
567 | ],
568 | efficient_eos=True
569 | )
570 | 
571 | r"""
572 | Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
573 | """
574 | register_template(
575 | name="xverse",
576 | prefix=[
577 | "{{system}}"
578 | ],
579 | prompt=[
580 | "Human: {{query}}\n\nAssistant: "
581 | ],
582 | system="",
583 | sep=[]
584 | )
585 | 
--------------------------------------------------------------------------------