├── requirements.txt ├── BatGPT-15B-sirius ├── configuration_batgpt.py ├── tokenization_batgpt.py └── modeling_batgpt.py ├── batgpt_web_demo.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0 2 | datasets 3 | matplotlib 4 | huggingface_hub 5 | streamlit 6 | gradio 7 | mdtex2html 8 | protobuf 9 | transformers 10 | cpm_kernels 11 | streamlit 12 | sentencepiece 13 | accelerate 14 | deepspeed -------------------------------------------------------------------------------- /BatGPT-15B-sirius/configuration_batgpt.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class BatGPTConfig(PretrainedConfig): 5 | 6 | model_type = "batgpt" 7 | 8 | def __init__( 9 | self, 10 | vocab_size=65024, 11 | emb_dim=5632, 12 | hidden_size=5632, 13 | n_layer=48, 14 | n_head=44, 15 | layer_norm_epsilon=1e-5, 16 | use_multi_query_attn=True, 17 | num_heads_per_kv=2, 18 | qkv_bias=True, 19 | use_native_attn_impl=True, 20 | mlp_activation="swiglu", 21 | hidden_dropout=0.0, 22 | ffn_hidden_size=13696, 23 | prefix_size=None, 24 | prefix_proj=False, 25 | max_seq_len=32768, 26 | pos_emb_impl="rope", 27 | use_emb_factorization=False, 28 | empty_init=True, 29 | **kwargs 30 | ): 31 | self.vocab_size = vocab_size 32 | self.emb_dim = emb_dim 33 | self.hidden_size = hidden_size 34 | self.n_layer = n_layer 35 | self.n_head = n_head 36 | self.layer_norm_epsilon = layer_norm_epsilon 37 | self.use_multi_query_attn = use_multi_query_attn 38 | self.num_heads_per_kv = num_heads_per_kv 39 | self.qkv_bias = qkv_bias 40 | self.use_native_attn_impl = use_native_attn_impl 41 | self.mlp_activation = mlp_activation 42 | self.hidden_dropout = hidden_dropout 43 | self.ffn_hidden_size = ffn_hidden_size 44 | self.prefix_size = prefix_size 45 | self.prefix_proj = prefix_proj 46 | self.max_seq_len = max_seq_len 47 | self.pos_emb_impl = pos_emb_impl 48 | self.use_emb_factorization = use_emb_factorization 49 | self.empty_init = empty_init 50 | super().__init__(**kwargs) -------------------------------------------------------------------------------- /batgpt_web_demo.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import streamlit as st 3 | import torch 4 | import sys 5 | 6 | st.set_page_config( 7 | page_title="BatGPT-15B", 8 | page_icon=":robot:", 9 | layout='wide' 10 | ) 11 | 12 | @st.cache_resource 13 | def get_model(): 14 | from transformers import AutoModelForCausalLM, AutoTokenizer 15 | tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True) 16 | model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda() 17 | model = model.eval() 18 | return tokenizer, model 19 | 20 | 21 | tokenizer, model = get_model() 22 | 23 | st.title("BatGPT-15B Demo") 24 | 25 | max_length = st.sidebar.slider( 26 | 'max_length', 0, 32768, 8192, step=1 27 | ) 28 | top_p = st.sidebar.slider( 29 | 'top_p', 0.0, 1.0, 0.8, step=0.01 30 | ) 31 | temperature = st.sidebar.slider( 32 | 'temperature', 0.0, 1.0, 0.8, step=0.01 33 | ) 34 | 35 | if 'history' not in st.session_state: 36 | st.session_state.history = [] 37 | 38 | if 'past_key_values' not in st.session_state: 39 | st.session_state.past_key_values = None 40 | 41 | for i, (query, response) in enumerate(st.session_state.history): 42 | with 
st.chat_message(name="user", avatar="user"): 43 | st.markdown(query) 44 | with st.chat_message(name="assistant", avatar="assistant"): 45 | st.markdown(response) 46 | with st.chat_message(name="user", avatar="user"): 47 | input_placeholder = st.empty() 48 | with st.chat_message(name="assistant", avatar="assistant"): 49 | message_placeholder = st.empty() 50 | 51 | prompt_text = st.text_area(label="User Input", 52 | height=100, 53 | placeholder="Please enter") 54 | 55 | button = st.button("Send", key="predict") 56 | 57 | if button: 58 | input_placeholder.markdown(prompt_text) 59 | history, past_key_values = st.session_state.history, st.session_state.past_key_values 60 | for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history, 61 | past_key_values=past_key_values, 62 | max_length=max_length, top_p=top_p, 63 | temperature=temperature, 64 | return_past_key_values=True): 65 | message_placeholder.markdown(response) 66 | 67 | st.session_state.history = history 68 | st.session_state.past_key_values = past_key_values -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BatGPT 2 | Bidirectional Autoregressive Talker from Generative Pre-trained Transformer 3 | 4 | ## 开源计划 5 | 6 | ### 模型 7 | 8 | - [**BatGPT-15B-sirius**](https://huggingface.co/MLP-lab/BatGPT-15B-sirius): BatGPT第一个开源模型,具有 150 亿参数,在高质量中英文语料上进行双向自回归预训练得到,并进行了指令微调与强化对齐的学习,具有指令遵循能力、多轮对话能力、推理等能力。 9 | 10 | - **mBatGPT-15B-sirius**: 具有图像,语音多模态理解能力的大模型,基于 BatGPT-15B-sirius 在500万图像文本,语音文本对上进行二阶预训练得到。 11 | 12 | 13 | *** 14 | 15 | ## Demo 16 | 17 | 我们提供了一个基于Streamlit实现的网页Demo,您可以使用streamlit运行本仓库中的batgpt_web_demo.py来打开网页Demo: 18 | 19 | ```bash 20 | streamlit run batgpt_web_demo.py 21 | ``` 22 | 23 | *** 24 | 25 | ## **BatGPT-15B-sirius** 26 | 27 | ### 介绍 (Introduction) 28 | 29 | BatGPT-15B-sirius 是上海交通大学与武汉大学(或武汉大学与上海交通大学,排名不分先后)联合自然语言处理团队设计、预训练、对齐的系列大型语言模型 [BatGPT](https://github.com/zcli-charlie/BatGPT) 中的一个开源可商用版本。 30 | BatGPT系列模型中还包括BatGPT-30B-orion,BatGPT-70B-alhena,以及BatGPT-140B-menkalinan。 31 | 32 | BatGPT-15B-sirius 包含 150 亿参数,在中英文 1T 语料上进行了预训练,在权威的中文和英文 benchmark 上均取得同不错的效果。BatGPT-15B-sirius 有如下几个特点: 33 | 34 | 1. **支持长达32K的上下文**:BatGPT-15B-sirius 采用旋转位置编码RoPE,在预训练阶段采用 2048 序列长度,并且在指令微调阶段逐步扩展到了 32K 上下文。 35 | 2. **高效的预训练目标与模型架构**:BatGPT-15B-sirius 采用双向自回归预训练目标,以提高对于训练数据的运用程度,并且基于 [Multi-Query Attention](http://arxiv.org/abs/1911.02150) 技术,在保证参数规模的前提下尽可能的减少推理显存的占用,提高推理速度。 36 | 3. **商业友好的开放协议**:BatGPT-15B-sirius 的源码以及权重不仅支持自由的学术研究使用,也允许免费开源商用,助推大模型进一步帮助人类的日常生活。 37 | 38 | BatGPT-15B-sirius is an open-source commercially available version of the series of large-scale language models [BatGPT](https://github.com/zcli-charlie/BatGPT), designed, pretrained, and aligned by the joint natural language processing teams of Shanghai Jiao Tong University and Wuhan University (or Wuhan University and Shanghai Jiao Tong University, in no particular order). 39 | 40 | The BatGPT series of models also include BatGPT-30B-orion, BatGPT-70B-alhena, and BatGPT-140B-menkalinan. 41 | 42 | BatGPT-15B-sirius contains 15 billion parameters and has been pretrained on 1T Chinese and English corpora. It achieves excellent performance on authoritative Chinese and English benchmarks. BatGPT-15B-sirius has the following characteristics: 43 | 44 | 1. **Supports Contexts Up to 32K Tokens**: BatGPT-15B-sirius uses rotated positional encoding (RoPE) and is pretrained with a sequence length of 2048 tokens. 
During fine-tuning, it gradually expands to support contexts up to 32K tokens. 45 | 2. **Efficient Pre-training Objectives and Model Architecture**: BatGPT-15B-sirius employs a bidirectional autoregressive pretraining objective to better utilize the training data. It also utilizes the [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique to reduce inference memory consumption and improve inference speed while maintaining model size. 46 | 3. **Business-friendly Open License**: The source code and weights of BatGPT-15B-sirius are not only available for academic research but also allow free and open-source commercial use, further facilitating the integration of large language models into human daily life. 47 | 48 | 49 | ### 软件依赖 50 | 51 | ```shell 52 | pip install protobuf transformers cpm_kernels torch>=2.0 streamlit sentencepiece accelerate deepspeed 53 | ``` 54 | 55 | ### 简易使用 56 | 57 | 如下是一个使用 BatGPT-15B-sirius 进行对话的示例: 58 | 59 | ```python 60 | import torch 61 | from transformers import AutoModelForCausalLM, AutoTokenizer 62 | tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True) 63 | model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda() 64 | model = model.eval() 65 | history = [] 66 | system_prompt = None # 你也可以指定系统提示 67 | response, history = model.chat(tokenizer, "你好", history=history, system_prompt=system_prompt) 68 | print(response) 69 | response, history = model.chat(tokenizer, "介绍一下你自己", history=history, system_prompt=system_prompt) 70 | print(response) 71 | ``` 72 | 73 | Here is an example of a conversation using BatGPT-15B-sirius: 74 | 75 | ```python 76 | import torch 77 | from transformers import AutoModelForCausalLM, AutoTokenizer 78 | tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True) 79 | model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda() 80 | model = model.eval() 81 | history = [] 82 | system_prompt = None # You can give a system prompt here. 
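# Hypothetical system prompt, shown only as an illustration -- build_inputs wraps any free-form string in a "[System]" block:
# system_prompt = "You are BatGPT, a helpful bilingual assistant."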
83 | response, history = model.chat(tokenizer, "Hello", history=history, system_prompt=system_prompt) 84 | print(response) 85 | response, history = model.chat(tokenizer, "Please introduce yourself", history=history, system_prompt=system_prompt) 86 | print(response) 87 | ``` 88 | 89 | 90 | ### 模型详情 (Model Details) 91 | 92 | 93 | BatGPT-15B-sirius 具体参数和见下表: 94 | 95 | | 模型名称 | 隐含层维度 | 层数 | Query头数 | Key/Value头数 |词表大小 | 总参数量 | 训练数据(tokens) | 位置编码 | 最大长度 | 96 | |-------------------------|-------|------------|------------|------------|-----------------|--------|--------|----------------|---------| 97 | | BatGPT-15B-sirius | 5,632 | 48 | 44 | 2 | 65,536 | 15,030,081,024 | 1 万亿 | [RoPE](https://arxiv.org/abs/2104.09864) | 32K | 98 | 99 | 100 | 101 | The specific parameters of BatGPT-15B-sirius are as follows: 102 | | Model Name | Hidden Size | Num Layers | Query Heads | Key/Value Heads |Vocab Size | Total Params | Training Dats(tokens) | Position Embedding | Max Length | 103 | |-------------------------|-------|------------|------------|------------|-----------------|--------|--------|----------------|---------| 104 | | BatGPT-15B-sirius | 5,632 | 48 | 44 | 2 | 65,536 | 15,030,081,024 | 1T | [RoPE](https://arxiv.org/abs/2104.09864) | 32K | 105 | 106 | 107 | 108 | - **Developed by:** MLP Lab of Wuhan University, Shanghai Jiao Tong University 109 | - **Email**: zcli-charlie@whu.edu.cn, zhaohai@cs.sjtu.edu.cn 110 | - **Language(s) (NLP):** Chinese/English 111 | - **License:** The code in this project is licensed under the Apache 2.0 license, the model weights are licensed under the GNU AGPL 3.0 license. If you intend to use the models included in this project for commercial purposes or public deployment, please email to us to obtain authorization. Commercial usage information will be used for record purposes only and no fees will be charged. 112 | 113 | 114 | ## 免责声明 (Disclaimers) 115 | 116 | BatGPT-15B-sirius 模型的使用应当遵循社会的公序良俗,不能被用于任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 BatGPT-15B-sirius 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。 117 | 118 | 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。如使用本项目所含模型及其修改版本提供服务产生误导性或有害性言论,造成不良影响,由服务提供方负责,与本项目无关。 119 | 120 | The use of the BatGPT-15B-sirius model should adhere to societal norms and not be used for any activities that jeopardize national or social security or violate the law. Additionally, we also request users not to use the BatGPT-15B-sirius model for internet services that have not undergone appropriate security review and documentation. We hope that all users will abide by this principle to ensure that technological development occurs in a regulated and legal environment. 121 | 122 | We have done our best to ensure the compliance of the data used during the model training process. However, despite our significant efforts, unforeseen issues may still arise due to the complexity of the model and data. If misleading or harmful statements are generated through the use of the models included in this project or their modified versions while providing services, the responsibility lies with the service provider and is not associated with this project. 
123 | 124 | *** 125 | 126 | ## 引用 127 | 128 | 如果你觉得我们的工作有帮助的话,请考虑引用我们的BatGPT论文: 129 | 130 | If you find our work helpful, please consider citing our BatGPT paper: 131 | 132 | ``` 133 | @article{li2023batgpt, 134 | title={BatGPT: A Bidirectional Autoregessive Talker from Generative Pre-trained Transformer}, 135 | author={Li, Zuchao and Zhang, Shitou and Zhao, Hai and Yang, Yifei and Yang, Dongjie}, 136 | journal={arXiv preprint arXiv:2307.00360}, 137 | year={2023} 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /BatGPT-15B-sirius/tokenization_batgpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import List, Optional, Union, Dict, Tuple 4 | from sentencepiece import SentencePieceProcessor 5 | from transformers import PreTrainedTokenizer 6 | from transformers.utils import logging, PaddingStrategy 7 | from transformers.tokenization_utils_base import EncodedInput, BatchEncoding 8 | 9 | SPECIAL_TOKENS = ["", "", "", "<para>", "<eop>", "<eot>", "<eod>"] + ["[User]", "[Assistant]", "[System]"] + ["[Turn {}]".format(i+1) for i in range(100)] 10 | 11 | class SPTokenizer: 12 | def __init__(self, model_path: str): 13 | # reload tokenizer 14 | assert os.path.isfile(model_path), model_path 15 | self.sp_model = SentencePieceProcessor(model_file=model_path) 16 | 17 | # BOS / EOS token IDs 18 | self.n_words: int = self.sp_model.vocab_size() 19 | self.bos_id: int = self.sp_model.bos_id() 20 | self.eos_id: int = self.sp_model.eos_id() 21 | self.pad_id: int = self.sp_model.unk_id() 22 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 23 | 24 | self.special_tokens = {} 25 | self.index_special_tokens = {} 26 | for token in SPECIAL_TOKENS: 27 | self.special_tokens[token] = self.n_words 28 | self.index_special_tokens[self.n_words] = token 29 | self.n_words += 1 30 | 31 | def tokenize(self, s: str): 32 | return self.sp_model.EncodeAsPieces(s) 33 | 34 | def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: 35 | assert type(s) is str 36 | t = self.sp_model.encode(s) 37 | if bos: 38 | t = [self.bos_id] + t 39 | if eos: 40 | t = t + [self.eos_id] 41 | return t 42 | 43 | def decode(self, t: List[int]) -> str: 44 | return self.sp_model.decode(t) 45 | 46 | def decode_tokens(self, tokens: List[str]) -> str: 47 | text = self.sp_model.DecodePieces(tokens) 48 | return text 49 | 50 | def convert_token_to_id(self, token): 51 | """ Converts a token (str) in an id using the vocab. 
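Special tokens added on top of the SentencePiece model (e.g. [User], [Assistant], [Turn 1]) are
resolved from the in-memory table first; anything else falls back to sp_model.PieceToId.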
""" 52 | if token in self.special_tokens: 53 | return self.special_tokens[token] 54 | return self.sp_model.PieceToId(token) 55 | 56 | def convert_id_to_token(self, index): 57 | """Converts an index (integer) in a token (str) using the vocab.""" 58 | if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: 59 | return "" 60 | return self.sp_model.IdToPiece(index) 61 | 62 | 63 | class BatGPTTokenizer(PreTrainedTokenizer): 64 | vocab_files_names = {"vocab_file": "tokenizer.model"} 65 | 66 | model_input_names = ["input_ids", "attention_mask", "position_ids"] 67 | 68 | def __init__(self, vocab_file, padding_side="left", **kwargs): 69 | super().__init__(padding_side=padding_side, **kwargs) 70 | self.name = "BatGPTTokenizer" 71 | 72 | self.vocab_file = vocab_file 73 | self.tokenizer = SPTokenizer(vocab_file) 74 | self.special_tokens = { 75 | "<bos>": self.tokenizer.bos_id, 76 | "<eos>": self.tokenizer.eos_id, 77 | "<pad>": self.tokenizer.pad_id 78 | } 79 | 80 | # 81 | self.unk_token = "<unk>" 82 | self.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS}) 83 | 84 | def get_command(self, token): 85 | if token in self.special_tokens: 86 | return self.special_tokens[token] 87 | assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}" 88 | return self.tokenizer.special_tokens[token] 89 | 90 | @property 91 | def pad_token(self) -> str: 92 | return "<unk>" 93 | 94 | @property 95 | def pad_token_id(self): 96 | return self.get_command("<pad>") 97 | 98 | @property 99 | def eos_token(self) -> str: 100 | return "</s>" 101 | 102 | @property 103 | def eos_token_id(self): 104 | return self.get_command("<eos>") 105 | 106 | @property 107 | def vocab_size(self): 108 | return self.tokenizer.n_words 109 | 110 | def get_vocab(self): 111 | """ Returns vocab as a dict """ 112 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} 113 | vocab.update(self.added_tokens_encoder) 114 | return vocab 115 | 116 | def _tokenize(self, text, **kwargs): 117 | return self.tokenizer.tokenize(text) 118 | 119 | def _convert_token_to_id(self, token): 120 | """ Converts a token (str) in an id using the vocab. 
""" 121 | return self.tokenizer.convert_token_to_id(token) 122 | 123 | def _convert_id_to_token(self, index): 124 | """Converts an index (integer) in a token (str) using the vocab.""" 125 | return self.tokenizer.convert_id_to_token(index) 126 | 127 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 128 | return self.tokenizer.decode_tokens(tokens) 129 | 130 | def save_vocabulary(self, save_directory, filename_prefix=None): 131 | if os.path.isdir(save_directory): 132 | vocab_file = os.path.join( 133 | save_directory, self.vocab_files_names["vocab_file"] 134 | ) 135 | else: 136 | vocab_file = save_directory 137 | 138 | with open(self.vocab_file, 'rb') as fin: 139 | proto_str = fin.read() 140 | 141 | with open(vocab_file, "wb") as writer: 142 | writer.write(proto_str) 143 | 144 | return (vocab_file,) 145 | 146 | def get_prefix_tokens(self): 147 | prefix_tokens = [self.get_command("<doc>"), self.get_command("<para>")] 148 | return prefix_tokens 149 | 150 | def build_inputs(self, query, history=None, system_prompt=None): 151 | if history is None: 152 | history = [] 153 | role_user = "[User]" 154 | role_assistant = "[Assistant]" 155 | if system_prompt: 156 | prompt = "[System]\n\n {}\n\n<eot>".format(system_prompt) 157 | else: 158 | prompt = "" 159 | for i, (old_query, response) in enumerate(history): 160 | prompt += "[Turn {}]\n\n{} {}\n\n{} {}\n\n<eop>".format(i + 1, role_user, old_query, role_assistant, response) 161 | prompt += "[Turn {}]\n\n{} {}\n\n{}".format(len(history) + 1, role_user, query, role_assistant) 162 | inputs = self([prompt], return_tensors="pt") 163 | return inputs 164 | 165 | def build_stream_inputs(self, query: str, history: List[Tuple[str, str]] = None, system_prompt = None): 166 | role_user = "[User]" 167 | role_assistant = "[Assistant]" 168 | if history: 169 | prompt = "\n\n[Turn {}]\n\n{} {}\n\n{}".format(len(history) + 1, role_user, query, role_assistant) 170 | input_ids = self.encode(prompt, add_special_tokens=False) 171 | input_ids = input_ids[1:] 172 | inputs = self.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) 173 | else: 174 | if system_prompt: 175 | prompt = "[System]\n\n {}\n\n[Turn {}]\n\n{} {}\n\n{} ".format(system_prompt, len(history) + 1, role_user, query, role_assistant) 176 | else: 177 | prompt = "[Turn {}]\n\n{} {}\n\n{} ".format(len(history) + 1, role_user, query, role_assistant) 178 | inputs = self([prompt], return_tensors="pt") 179 | return inputs 180 | 181 | def build_inputs_with_special_tokens( 182 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 183 | ) -> List[int]: 184 | prefix_tokens = self.get_prefix_tokens() 185 | token_ids_0 = prefix_tokens + token_ids_0 186 | if token_ids_1 is not None: 187 | token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")] 188 | return token_ids_0 189 | 190 | def _pad( 191 | self, 192 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 193 | max_length: Optional[int] = None, 194 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 195 | pad_to_multiple_of: Optional[int] = None, 196 | return_attention_mask: Optional[bool] = None, 197 | ) -> dict: 198 | # Load from model defaults 199 | assert self.padding_side == "left" 200 | 201 | required_input = encoded_inputs[self.model_input_names[0]] 202 | seq_length = len(required_input) 203 | 204 | if padding_strategy == PaddingStrategy.LONGEST: 205 | max_length = len(required_input) 206 | 207 | if max_length is not None and pad_to_multiple_of is not None and 
(max_length % pad_to_multiple_of != 0): 208 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 209 | 210 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 211 | 212 | # Initialize attention mask if not present. 213 | if "attention_mask" not in encoded_inputs: 214 | encoded_inputs["attention_mask"] = [1] * seq_length 215 | 216 | if "position_ids" not in encoded_inputs: 217 | encoded_inputs["position_ids"] = list(range(seq_length)) 218 | 219 | if needs_to_be_padded: 220 | difference = max_length - len(required_input) 221 | 222 | if "attention_mask" in encoded_inputs: 223 | encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] 224 | if "position_ids" in encoded_inputs: 225 | encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] 226 | encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input 227 | 228 | return encoded_inputs 229 | -------------------------------------------------------------------------------- /BatGPT-15B-sirius/modeling_batgpt.py: -------------------------------------------------------------------------------- 1 | # This code serves as a port of the models described in BatGPT. 2 | # It is based on the bloom codebase, which provides the initial framework for our model implementation. 3 | # To understand how to use these models, please refer to the documentation and usage instructions provided in the bloom models repository. 4 | # Additionally, we draw inspiration from the ChatGLM and Baichuan codebase, which includes implementations for prefix encoder, chat, and stream_chat functionalities. These components are utilized in our ported models. 5 | # Feel free to explore the ChatGLM and Baichuan codebase for further insights on how these components can be utilized effectively. 
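# Architecture summary (defaults are in configuration_batgpt.py): a stack of n_layer pre-norm blocks
# (RMSNorm -> self-attention -> RMSNorm -> MLP) with multi-query attention (num_heads_per_kv key/value
# heads shared by n_head query heads), rotary ("rope") or ALiBi position handling, a SwiGLU or SiLU
# feed-forward network, and an optional prefix encoder for prefix tuning. The BatGPT-15B-sirius defaults
# are 48 layers, 44 query heads, 2 key/value heads, hidden size 5632 and FFN size 13696.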
6 | 7 | import math 8 | import warnings 9 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any 10 | 11 | import torch 12 | import torch.utils.checkpoint 13 | from torch import nn 14 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss 15 | from torch.nn import functional as F 16 | from torch.nn.utils import skip_init 17 | 18 | import copy 19 | import re 20 | import sys 21 | 22 | from transformers.modeling_outputs import ( 23 | BaseModelOutputWithPast, 24 | CausalLMOutputWithPast, 25 | ) 26 | from transformers.modeling_utils import PreTrainedModel 27 | from transformers.utils import logging 28 | from transformers.generation.logits_process import LogitsProcessor 29 | from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput 30 | 31 | from .configuration_batgpt import BatGPTConfig 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | 36 | # flags required to enable jit fusion kernels 37 | 38 | if sys.platform != 'darwin': 39 | torch._C._jit_set_profiling_mode(False) 40 | torch._C._jit_set_profiling_executor(False) 41 | torch._C._jit_override_can_fuse_on_cpu(True) 42 | torch._C._jit_override_can_fuse_on_gpu(True) 43 | 44 | 45 | # For faster llm model initilization 46 | def module_init(cls, empty_init, *args, **kwargs): 47 | if empty_init: 48 | return skip_init(cls, *args, **kwargs) 49 | else: 50 | return cls(*args, **kwargs) 51 | 52 | class InvalidScoreLogitsProcessor(LogitsProcessor): 53 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 54 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 55 | scores.zero_() 56 | scores[..., 5] = 5e4 57 | return scores 58 | 59 | 60 | class PrefixEncoder(torch.nn.Module): 61 | """ 62 | The torch.nn model to encode the prefix 63 | Input shape: (batch-size, prefix-length) 64 | Output shape: (batch-size, prefix-length, 2*layers*hidden) 65 | """ 66 | 67 | def __init__(self, config: BatGPTConfig): 68 | super().__init__() 69 | self.prefix_proj = config.prefix_proj 70 | self.head_dim = config.hidden_size // config.n_head 71 | if self.prefix_proj: 72 | # Use a two-layer MLP to encode the prefix 73 | kv_size = config.n_layer * self.head_dim * config.num_heads_per_kv * 2 74 | self.embedding = torch.nn.Embedding(config.prefix_size, kv_size) 75 | self.trans = torch.nn.Sequential( 76 | torch.nn.Linear(kv_size, config.hidden_size), 77 | torch.nn.Tanh(), 78 | torch.nn.Linear(config.hidden_size, kv_size) 79 | ) 80 | else: 81 | self.embedding = torch.nn.Embedding(config.prefix_size, 82 | config.n_layer * self.head_dim * config.num_heads_per_kv * 2) 83 | 84 | def forward(self, prefix: torch.Tensor): 85 | if self.prefix_proj: 86 | prefix_tokens = self.embedding(prefix) 87 | past_key_values = self.trans(prefix_tokens) 88 | else: 89 | past_key_values = self.embedding(prefix) 90 | return past_key_values 91 | 92 | 93 | def _get_interleave(n): 94 | def _get_interleave_power_of_2(n): 95 | start = (2 ** (-2 ** -(math.log2(n) - 3))) 96 | ratio = start 97 | return [start * ratio ** i for i in range(n)] 98 | 99 | if math.log2(n).is_integer(): 100 | return _get_interleave_power_of_2(n) 101 | else: 102 | closest_power_of_2 = 2 ** math.floor(math.log2(n)) 103 | return _get_interleave_power_of_2(closest_power_of_2) + \ 104 | _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2] 105 | 106 | def _fill_with_neg_inf(t): 107 | """FP16-compatible function that fills a tensor with -inf.""" 108 | return 
t.float().fill_(float("-inf")).type_as(t) 109 | 110 | def _gen_alibi_mask(n_head, max_pos): 111 | """used in inference only""" 112 | slopes = torch.Tensor(_get_interleave(n_head)) 113 | alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand( 114 | n_head, -1, -1) 115 | alibi = alibi.view(n_head, 1, max_pos) 116 | alibi_mask = torch.triu( 117 | _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1 118 | ) 119 | alibi_mask = alibi_mask.unsqueeze(0) + alibi 120 | return alibi_mask 121 | 122 | def _build_position_ids(input_ids, device): 123 | batch_size, seq_length = input_ids.shape 124 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 125 | return position_ids 126 | 127 | def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): 128 | """used in training only""" 129 | dim = tensor.size(0) 130 | _future_mask = torch.triu( 131 | _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1 132 | ) 133 | _future_mask = _future_mask.unsqueeze(0) + alibi 134 | _future_mask = _future_mask.to(tensor) 135 | return _future_mask[:tensor.shape[1] * attn_heads, :maxpos, :maxpos] 136 | 137 | @torch.jit.script 138 | def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: 139 | # x: [sq, b, np, hn] 140 | sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) 141 | rot_dim = rope_cache.shape[-2] * 2 142 | x, x_pass = x[..., :rot_dim], x[..., rot_dim:] 143 | # truncate to support variable sizes 144 | rope_cache = rope_cache[:sq] 145 | xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) 146 | rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) 147 | x_out2 = torch.stack( 148 | [ 149 | xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], 150 | xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], 151 | ], 152 | -1, 153 | ) 154 | x_out2 = x_out2.flatten(3) 155 | return torch.cat((x_out2, x_pass), dim=-1) 156 | 157 | 158 | 159 | 160 | 161 | class RMSNorm(torch.nn.Module): 162 | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): 163 | super().__init__() 164 | self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) 165 | self.eps = eps 166 | 167 | def forward(self, hidden_states: torch.Tensor): 168 | input_dtype = hidden_states.dtype 169 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 170 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 171 | 172 | return (self.weight * hidden_states).to(input_dtype) 173 | 174 | 175 | class SelfAttention(torch.nn.Module): 176 | def __init__(self, config: BatGPTConfig, device=None): 177 | super(SelfAttention, self).__init__() 178 | 179 | self.num_heads = config.n_head 180 | self.use_multi_query_attn = config.use_multi_query_attn 181 | self.num_heads_per_kv = config.num_heads_per_kv 182 | self.qkv_bias = config.qkv_bias 183 | self.use_native_attn_impl = config.use_native_attn_impl 184 | if not self.use_multi_query_attn: 185 | assert self.num_heads_per_kv == self.num_heads, "num_heads_per_kv must equal to num_heads when not use_multi_query_attn" 186 | 187 | self.head_dim = config.hidden_size // config.n_head 188 | 189 | self.query_proj = nn.Linear( 190 | config.hidden_size, config.hidden_size, bias=self.qkv_bias, 191 | device=device, **_config_to_kwargs(config) 192 | ) 193 | 194 | self.key_proj = nn.Linear( 195 | config.hidden_size, self.head_dim * self.num_heads_per_kv, bias=self.qkv_bias, 
196 | device=device, **_config_to_kwargs(config) 197 | ) 198 | self.value_proj = nn.Linear( 199 | config.hidden_size, self.head_dim * self.num_heads_per_kv, bias=self.qkv_bias, 200 | device=device, **_config_to_kwargs(config) 201 | ) 202 | 203 | # Output. 204 | self.dense = nn.Linear( 205 | config.hidden_size, config.hidden_size, bias=False, 206 | device=device, **_config_to_kwargs(config) 207 | ) 208 | 209 | def forward( 210 | self, 211 | hidden_states, 212 | attention_mask, 213 | rotary_pos_emb, 214 | kv_cache=None, 215 | use_cache=True 216 | ): 217 | # 1. query/key/value mapping 218 | # hidden_states: [seq_len, batch_size, hidden_size] 219 | seq_len, batch_size, hidden_size = hidden_states.shape 220 | query_layer = self.query_proj(hidden_states) 221 | key_layer = self.key_proj(hidden_states) 222 | value_layer = self.value_proj(hidden_states) 223 | 224 | query_layer = query_layer.view(seq_len, batch_size, self.num_heads, self.head_dim) 225 | 226 | key_layer = key_layer.view(seq_len, batch_size, self.num_heads_per_kv, self.head_dim) 227 | 228 | value_layer = value_layer.view(seq_len, batch_size, self.num_heads_per_kv, self.head_dim) 229 | 230 | # 2. apply the rotary position embedding 231 | if rotary_pos_emb is not None: 232 | query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) 233 | key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) 234 | 235 | # 3. adjust key and value for inference 236 | if kv_cache is not None: 237 | cache_k, cache_v = kv_cache 238 | key_layer = torch.cat((cache_k, key_layer), dim=0) 239 | value_layer = torch.cat((cache_v, value_layer), dim=0) 240 | if use_cache: 241 | kv_cache = (key_layer, value_layer) 242 | else: 243 | kv_cache = None 244 | 245 | # 4. repeat the key and value for attention 246 | if self.num_heads_per_kv != self.num_heads: 247 | key_layer = key_layer.unsqueeze(-2) 248 | key_layer = key_layer.expand( 249 | -1, -1, -1, self.num_heads // self.num_heads_per_kv, -1 250 | ) 251 | key_layer = key_layer.contiguous().view( 252 | key_layer.size()[:2] + (self.num_heads, self.head_dim) 253 | ) 254 | value_layer = value_layer.unsqueeze(-2) 255 | value_layer = value_layer.expand( 256 | -1, -1, -1, self.num_heads // self.num_heads_per_kv, -1 257 | ) 258 | value_layer = value_layer.contiguous().view( 259 | value_layer.size()[:2] + (self.num_heads, self.head_dim) 260 | ) 261 | 262 | # 5. 
attention [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim] 263 | query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] 264 | 265 | pytorch_version = int(torch.__version__.split('.')[0]) 266 | if self.use_native_attn_impl and pytorch_version >= 2: 267 | if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: 268 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, 269 | is_causal=True) 270 | else: 271 | if attention_mask is not None: 272 | attention_mask = ~attention_mask 273 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, 274 | attention_mask) 275 | else: 276 | attention_scores = torch.matmul(query_layer, key_layer.transpose(2, 3)) / math.sqrt(self.head_dim) 277 | 278 | if attention_mask is not None: 279 | if seq_len == 1: # inference with cache 280 | if len(attention_mask.size()) == 4: 281 | attention_mask = attention_mask[:, :, -1:, :] 282 | else: 283 | attention_mask = attention_mask[:, -1:, :] 284 | attention_scores = attention_scores + attention_mask 285 | attention_scores = torch.max(attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)) 286 | 287 | attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1) 288 | 289 | context_layer = torch.matmul(attention_probs, value_layer) 290 | 291 | # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim] 292 | context_layer = context_layer.permute(2, 0, 1, 3) 293 | 294 | # [seq_len, batch_size, hidden_size] 295 | context_layer = context_layer.reshape(seq_len, batch_size, hidden_size) 296 | 297 | # 298 | output = self.dense(context_layer) 299 | 300 | return output, kv_cache 301 | 302 | 303 | def _config_to_kwargs(args): 304 | common_kwargs = { 305 | "dtype": args.torch_dtype, 306 | } 307 | return common_kwargs 308 | 309 | 310 | class MLP(torch.nn.Module): 311 | def __init__(self, config: BatGPTConfig, device=None): 312 | super(MLP, self).__init__() 313 | self.mlp_activation = config.mlp_activation 314 | 315 | def swiglu(x): 316 | x = torch.chunk(x, 2, dim=-1) 317 | return F.silu(x[0]) * x[1] 318 | 319 | def silu(x): 320 | return F.silu(x) 321 | 322 | # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf 323 | if self.mlp_activation == "swiglu": 324 | self.activation_func = swiglu 325 | 326 | self.gate_proj = None 327 | 328 | self.dense_h_to_4h = nn.Linear( 329 | config.hidden_size, 330 | config.ffn_hidden_size * 2, 331 | bias=False, 332 | device=device, 333 | **_config_to_kwargs(config) 334 | ) 335 | elif self.mlp_activation == "silu": 336 | self.activation_func = silu 337 | 338 | self.gate_proj = nn.Linear( 339 | config.hidden_size, 340 | config.ffn_hidden_size, 341 | bias=False, 342 | device=device, 343 | **_config_to_kwargs(config) 344 | ) 345 | 346 | self.dense_h_to_4h = nn.Linear( 347 | config.hidden_size, 348 | config.ffn_hidden_size, 349 | bias=False, 350 | device=device, 351 | **_config_to_kwargs(config) 352 | ) 353 | else: 354 | raise NotImplementedError("mlp_activation {} not supported".format(self.mlp_activation)) 355 | 356 | # Project back to h. 
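# Whichever activation is selected, the tensor entering this projection has width ffn_hidden_size:
# for "swiglu" the doubled dense_h_to_4h output is chunked and gated back down, while for "silu"
# gate_proj and dense_h_to_4h both emit ffn_hidden_size. With the defaults this is 13696 -> 5632.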
357 | self.dense_4h_to_h = nn.Linear( 358 | config.ffn_hidden_size, 359 | config.hidden_size, 360 | bias=False, 361 | device=device, 362 | **_config_to_kwargs(config) 363 | ) 364 | 365 | def forward(self, hidden_states): 366 | 367 | # [s, b, 4hp] 368 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 369 | 370 | if self.mlp_activation == "swiglu": 371 | intermediate_parallel = self.activation_func(intermediate_parallel) 372 | elif self.mlp_activation == "silu": 373 | gated_weight = self.activation_func(self.gate_proj(hidden_states)) 374 | intermediate_parallel = gated_weight * intermediate_parallel 375 | else: 376 | raise NotImplementedError("mlp_activation {} not supported".format(self.mlp_activation)) 377 | 378 | # [s, b, h] 379 | output = self.dense_4h_to_h(intermediate_parallel) 380 | 381 | return output 382 | 383 | 384 | class BatGPTLayer(torch.nn.Module): 385 | """A single transformer layer. 386 | 387 | Transformer layer takes input with size [s, b, h] and returns an 388 | output of the same size. 389 | """ 390 | 391 | def __init__(self, config: BatGPTConfig, device=None): 392 | super(BatGPTLayer, self).__init__() 393 | 394 | # Layernorm on the input data. 395 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device, 396 | dtype=config.torch_dtype) 397 | 398 | # Self attention. 399 | self.self_attention = SelfAttention(config, device=device) 400 | 401 | self.hidden_dropout = config.hidden_dropout 402 | 403 | # Layernorm on the attention output 404 | self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device, 405 | dtype=config.torch_dtype) 406 | 407 | # MLP 408 | self.mlp = MLP(config, device=device) 409 | 410 | def forward( 411 | self, 412 | hidden_states, 413 | attention_mask, 414 | rotary_pos_emb, 415 | kv_cache=None, 416 | use_cache=True, 417 | ): 418 | # hidden_states: [s, b, h] 419 | residual = hidden_states 420 | 421 | # Layer norm at the beginning of the transformer layer. 422 | layernorm_output = self.input_layernorm(hidden_states) 423 | 424 | # Self attention. 425 | attention_output, kv_cache = self.self_attention( 426 | layernorm_output, 427 | attention_mask, 428 | rotary_pos_emb, 429 | kv_cache=kv_cache, 430 | use_cache=use_cache 431 | ) 432 | 433 | # Residual connection. 434 | layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) 435 | 436 | layernorm_input = residual + layernorm_input 437 | 438 | # Layer norm post the self attention. 439 | layernorm_output = self.post_attention_layernorm(layernorm_input) 440 | 441 | # MLP. 442 | mlp_output = self.mlp(layernorm_output) 443 | 444 | # Second residual connection. 445 | residual = layernorm_input 446 | 447 | output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) 448 | 449 | output = residual + output 450 | 451 | return output, kv_cache 452 | 453 | 454 | class BatGPTTransformer(torch.nn.Module): 455 | """Transformer class.""" 456 | 457 | def __init__(self, config: BatGPTConfig, device=None): 458 | super(BatGPTTransformer, self).__init__() 459 | 460 | # Number of layers. 461 | self.num_layers = config.n_layer 462 | 463 | # Transformer layers. 464 | def build_layer(): 465 | return BatGPTLayer(config, device=device) 466 | 467 | self.layers = torch.nn.ModuleList([build_layer() for i in range(self.num_layers)]) 468 | 469 | # final layer norm before output. 
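# Applied once in forward() after the layer loop, so the entries collected in all_hidden_states
# are the pre-final-norm activations.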
470 | self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device, 471 | dtype=config.torch_dtype) 472 | 473 | self.gradient_checkpointing = False 474 | 475 | def _get_layer(self, layer_number): 476 | return self.layers[layer_number] 477 | 478 | def forward( 479 | self, 480 | hidden_states, 481 | attention_mask, 482 | rotary_pos_emb, 483 | kv_caches=None, 484 | use_cache: Optional[bool] = True, 485 | output_hidden_states: Optional[bool] = False, 486 | ): 487 | if not kv_caches: 488 | kv_caches = [None for _ in range(self.num_layers)] 489 | presents = () if use_cache else None 490 | if self.gradient_checkpointing and self.training: 491 | if use_cache: 492 | logger.warning_once( 493 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 494 | ) 495 | use_cache = False 496 | 497 | all_self_attentions = None 498 | all_hidden_states = () if output_hidden_states else None 499 | for index in range(self.num_layers): 500 | if output_hidden_states: 501 | all_hidden_states = all_hidden_states + (hidden_states,) 502 | 503 | layer = self._get_layer(index) 504 | if self.gradient_checkpointing and self.training: 505 | layer_ret = torch.utils.checkpoint.checkpoint( 506 | layer, 507 | hidden_states, 508 | attention_mask, 509 | rotary_pos_emb, 510 | kv_caches[index], 511 | use_cache 512 | ) 513 | else: 514 | layer_ret = layer( 515 | hidden_states, 516 | attention_mask, 517 | rotary_pos_emb, 518 | kv_cache=kv_caches[index], 519 | use_cache=use_cache 520 | ) 521 | hidden_states, kv_cache = layer_ret 522 | if use_cache: 523 | presents = presents + (kv_cache,) 524 | 525 | if output_hidden_states: 526 | all_hidden_states = all_hidden_states + (hidden_states,) 527 | 528 | hidden_states = self.ln_f(hidden_states) 529 | 530 | return hidden_states, presents, all_hidden_states, all_self_attentions 531 | 532 | 533 | class BatGPTPreTrainedModel(PreTrainedModel): 534 | """ 535 | An abstract class to handle weights initialization and 536 | a simple interface for downloading and loading pretrained models. 
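_init_weights below is left as a no-op, and _set_gradient_checkpointing simply toggles the flag on
the wrapped BatGPTTransformer.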
537 | """ 538 | 539 | is_parallelizable = False 540 | supports_gradient_checkpointing = True 541 | config_class = BatGPTConfig 542 | base_model_prefix = "transformer" 543 | _no_split_modules = ["BatGPTLayer"] 544 | 545 | def _init_weights(self, module: nn.Module): 546 | """Initialize the weights.""" 547 | return 548 | 549 | 550 | 551 | def _set_gradient_checkpointing(self, module, value=False): 552 | if isinstance(module, BatGPTTransformer): 553 | module.gradient_checkpointing = value 554 | 555 | 556 | 557 | class BatGPTModel(BatGPTPreTrainedModel): 558 | def __init__(self, config: BatGPTConfig, device=None): 559 | super().__init__(config) 560 | 561 | self.num_layers = config.n_layer 562 | self.num_heads = config.n_head 563 | self.head_dim = config.hidden_size // config.n_head 564 | self.max_seq_len = config.max_seq_len 565 | self.pos_emb_impl = config.pos_emb_impl 566 | self.model_cache_seq_len = 1024 567 | 568 | # word embedding 569 | self.word_embeddings = module_init(nn.Embedding, 570 | config.empty_init, 571 | config.vocab_size, 572 | config.emb_dim, 573 | dtype=config.torch_dtype, 574 | device=device 575 | ) 576 | 577 | self.emb_fact = None 578 | if config.use_emb_factorization or config.emb_dim != config.hidden_size: 579 | self.emb_fact = nn.Linear(config.emb_dim, config.hidden_size, bias=False, 580 | dtype=config.torch_dtype, device=device) 581 | 582 | init_kwargs = {} 583 | if device is not None: 584 | init_kwargs["device"] = device 585 | 586 | self.encoder = module_init(BatGPTTransformer, config.empty_init, config, **init_kwargs) 587 | 588 | self.first_run = True 589 | self.alibi_mask = None 590 | 591 | self.prefix_size = config.prefix_size 592 | self.prefix_proj = config.prefix_proj 593 | if self.prefix_size is not None: 594 | for param in self.parameters(): 595 | param.requires_grad = False 596 | self.prefix_tokens = torch.arange(self.prefix_size).long() 597 | self.prefix_encoder = PrefixEncoder(config) 598 | self.dropout = torch.nn.Dropout(0.1) 599 | 600 | def get_input_embeddings(self): 601 | return self.word_embeddings 602 | 603 | def get_prompt(self, batch_size, device, dtype=torch.half): 604 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) 605 | past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) 606 | past_key_values = past_key_values.view( 607 | batch_size, 608 | self.prefix_size, 609 | self.num_layers * 2, 610 | self.multi_query_group_num, 611 | self.kv_channels 612 | ) 613 | # seq_len, b, nh, hidden_size 614 | past_key_values = self.dropout(past_key_values) 615 | past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) 616 | return past_key_values 617 | 618 | def get_rotary_tensor(self, seq_len: int, head_dim: int, dtype: torch.dtype, device: torch.device, base: int = 10000): 619 | 620 | n_elem = head_dim // 2 621 | 622 | # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 623 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) 624 | 625 | # Create position indexes `[0, 1, ..., seq_len - 1]` 626 | seq_idx = torch.arange(seq_len, dtype=dtype, device=device) 627 | 628 | # Calculate the product of position index and $\theta_i$ 629 | idx_theta = torch.outer(seq_idx, theta).float() 630 | 631 | cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) 632 | 633 | # this is to mimic the behaviour of complex32, else we will get different results 634 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 635 | cache = 
cache.bfloat16() if dtype == torch.bfloat16 else cache.half() 636 | 637 | return cache 638 | 639 | def get_causal_mask(self, input_ids, past_key_values, attention_mask=None) -> torch.BoolTensor: 640 | 641 | batch_size, seq_length = input_ids.shape 642 | 643 | # B x L x L 644 | causal_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) 645 | causal_mask.tril_() 646 | 647 | past_length = 0 648 | if past_key_values: 649 | past_length = past_key_values[0][0].shape[0] 650 | 651 | if past_length: 652 | causal_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, 653 | device=input_ids.device), causal_mask), dim=-1) 654 | 655 | if attention_mask is not None: 656 | causal_mask = causal_mask * attention_mask.unsqueeze(1) 657 | 658 | if not past_length and attention_mask is not None: 659 | causal_mask -= attention_mask.unsqueeze(-1) - 1 660 | 661 | causal_mask = (causal_mask < 0.5).bool() 662 | causal_mask.unsqueeze_(1) 663 | 664 | return causal_mask 665 | 666 | def get_alibi_mask(self, tensor, seq_length_with_past): 667 | if self.training: 668 | slopes = torch.Tensor(_get_interleave(self.num_heads)) 669 | alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand( 670 | self.num_heads, 671 | -1, -1) 672 | alibi = alibi.view(self.num_heads, 1, seq_length_with_past) 673 | mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.num_heads) 674 | else: 675 | if self.first_run: 676 | self.first_run = False 677 | self.register_buffer("future_mask", _gen_alibi_mask(self.num_heads, self.model_cache_seq_len).to(tensor), persistent=False) 678 | if seq_length_with_past > self.model_cache_seq_len: 679 | self.model_cache_seq_len = seq_length_with_past 680 | self.register_buffer("future_mask", _gen_alibi_mask(self.num_heads, self.model_cache_seq_len).to(tensor), persistent=False) 681 | mask = self.future_mask[:self.num_heads, :seq_length_with_past, :seq_length_with_past] 682 | return mask 683 | 684 | 685 | def forward( 686 | self, 687 | input_ids, 688 | position_ids: Optional[torch.Tensor] = None, 689 | attention_mask: Optional[torch.BoolTensor] = None, 690 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 691 | inputs_embeds: Optional[torch.Tensor] = None, 692 | use_cache: Optional[bool] = None, 693 | output_hidden_states: Optional[bool] = None, 694 | return_dict: Optional[bool] = None, 695 | ): 696 | output_hidden_states = ( 697 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 698 | ) 699 | use_cache = use_cache if use_cache is not None else self.config.use_cache 700 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 701 | 702 | batch_size, seq_length = input_ids.shape 703 | 704 | seq_length_with_past = seq_length 705 | 706 | # -> word embedding 707 | if inputs_embeds is None: 708 | inputs_embeds = self.word_embeddings(input_ids) 709 | # [b s h] --> [s b h]. 
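# The encoder below runs sequence-first, and each layer's KV cache is likewise stored as
# [seq_len, batch, num_heads_per_kv, head_dim] and grows along dim 0 during generation.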
710 | inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() 711 | 712 | if self.prefix_size is not None: 713 | if past_key_values is None: 714 | past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, 715 | dtype=inputs_embeds.dtype) 716 | if attention_mask is not None: 717 | attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.prefix_size)), 718 | attention_mask], dim=-1) 719 | 720 | if past_key_values is not None: 721 | past_key_values_length = past_key_values[0][0].shape[0] 722 | seq_length_with_past = seq_length_with_past + past_key_values_length 723 | 724 | 725 | full_attention_mask = None 726 | rotary_pos_emb=None 727 | if self.pos_emb_impl == "alibi": 728 | if self.training: 729 | if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past: 730 | self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) 731 | alibi_mask = self.alibi_mask 732 | else: 733 | alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) 734 | 735 | 736 | if attention_mask is not None: 737 | 738 | if len(attention_mask.shape) == 2: 739 | expanded_mask = attention_mask.to(alibi_mask.dtype) 740 | expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0) 741 | ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0) 742 | else: 743 | expanded_mask = attention_mask 744 | src_len, tgt_len = alibi_mask.size()[-2:] 745 | expanded_mask = expanded_mask.unsqueeze(1).expand(batch_size, 1, src_len, tgt_len).to(alibi_mask.dtype) 746 | # Target sizes: [1, 1, 41, 41]. Tensor sizes: [1, 1, 8, 8] 747 | inverted_mask = 1.0 - expanded_mask 748 | inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min) 749 | full_attention_mask = inverted_mask + alibi_mask.unsqueeze(0) 750 | else: 751 | full_attention_mask = alibi_mask 752 | elif self.pos_emb_impl == "rope": 753 | if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): 754 | # B x 1 x L x L 755 | full_attention_mask = self.get_causal_mask(input_ids, past_key_values, attention_mask) 756 | 757 | # Rotary positional embeddings 758 | rotary_pos_emb = self.get_rotary_tensor(self.max_seq_len, self.head_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device) 759 | if position_ids is not None: 760 | rotary_pos_emb = rotary_pos_emb[position_ids] 761 | else: 762 | rotary_pos_emb = rotary_pos_emb[None, :seq_length] 763 | rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() 764 | else: 765 | raise NotImplementedError("position embedding type: {} not supported!".format(self.pos_emb_impl)) 766 | 767 | 768 | # Run encoder. 
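# The encoder returns the final hidden states [seq_len, batch, hidden], the per-layer KV caches
# ("presents", only when use_cache is on), the optional all_hidden_states tuple, and a placeholder
# for attentions that is always None in this implementation.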
769 | hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( 770 | inputs_embeds, 771 | full_attention_mask, 772 | rotary_pos_emb=rotary_pos_emb, 773 | kv_caches=past_key_values, 774 | use_cache=use_cache, 775 | output_hidden_states=output_hidden_states 776 | ) 777 | 778 | if not return_dict: 779 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 780 | 781 | return BaseModelOutputWithPast( 782 | last_hidden_state=hidden_states, 783 | past_key_values=presents, 784 | hidden_states=all_hidden_states, 785 | attentions=all_self_attentions, 786 | ) 787 | 788 | 789 | class BatGPTForCausalLM(BatGPTPreTrainedModel): 790 | def __init__(self, config: BatGPTConfig, device=None): 791 | super().__init__(config) 792 | 793 | self.max_sequence_length = config.max_length 794 | 795 | self.model = BatGPTModel(config, device=device) 796 | 797 | self.lm_head = module_init(nn.Linear, config.empty_init, config.hidden_size, config.vocab_size, bias=False, 798 | dtype=config.torch_dtype, device=device) 799 | 800 | self.config = config 801 | 802 | def get_input_embeddings(self): 803 | return self.model.get_input_embeddings() 804 | 805 | def _update_model_kwargs_for_generation( 806 | self, 807 | outputs: ModelOutput, 808 | model_kwargs: Dict[str, Any], 809 | is_encoder_decoder: bool = False, 810 | standardize_cache_format: bool = False, 811 | ) -> Dict[str, Any]: 812 | # update past_key_values 813 | model_kwargs["past_key_values"] = self._extract_past_from_model_output( 814 | outputs, standardize_cache_format=standardize_cache_format 815 | ) 816 | 817 | # update attention mask 818 | if "attention_mask" in model_kwargs: 819 | attention_mask = model_kwargs["attention_mask"] 820 | model_kwargs["attention_mask"] = torch.cat( 821 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 822 | ) 823 | 824 | # update position ids 825 | if "position_ids" in model_kwargs: 826 | position_ids = model_kwargs["position_ids"] 827 | new_position_id = position_ids[..., -1:].clone() 828 | new_position_id += 1 829 | model_kwargs["position_ids"] = torch.cat( 830 | [position_ids, new_position_id], dim=-1 831 | ) 832 | 833 | model_kwargs["is_first_forward"] = False 834 | return model_kwargs 835 | 836 | def prepare_inputs_for_generation( 837 | self, 838 | input_ids: torch.LongTensor, 839 | past_key_values: Optional[torch.Tensor] = None, 840 | attention_mask: Optional[torch.Tensor] = None, 841 | position_ids: Optional[torch.Tensor] = None, 842 | is_first_forward: bool = True, 843 | **kwargs 844 | ) -> dict: 845 | 846 | # only last token for input_ids if past is not None 847 | if position_ids is None: 848 | position_ids = _build_position_ids(input_ids, device=input_ids.device) 849 | 850 | if not is_first_forward: 851 | position_ids = position_ids[..., -1:] 852 | input_ids = input_ids[:, -1:] 853 | 854 | return { 855 | "input_ids": input_ids, 856 | "past_key_values": past_key_values, 857 | "position_ids": position_ids, 858 | "attention_mask": attention_mask, 859 | "return_last_logit": True 860 | } 861 | 862 | def forward( 863 | self, 864 | input_ids: Optional[torch.Tensor] = None, 865 | position_ids: Optional[torch.Tensor] = None, 866 | attention_mask: Optional[torch.Tensor] = None, 867 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 868 | inputs_embeds: Optional[torch.Tensor] = None, 869 | labels: Optional[torch.Tensor] = None, 870 | use_cache: Optional[bool] = None, 871 | output_attentions: Optional[bool] = None, 872 | 
output_hidden_states: Optional[bool] = None, 873 | return_dict: Optional[bool] = None, 874 | return_last_logit: Optional[bool] = False, 875 | ): 876 | use_cache = use_cache if use_cache is not None else self.config.use_cache 877 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 878 | 879 | encodings = self.model( 880 | input_ids=input_ids, 881 | position_ids=position_ids, 882 | attention_mask=attention_mask, 883 | past_key_values=past_key_values, 884 | inputs_embeds=inputs_embeds, 885 | use_cache=use_cache, 886 | output_hidden_states=output_hidden_states, 887 | return_dict=return_dict, 888 | ) 889 | 890 | hidden_states = encodings[0] 891 | if return_last_logit: 892 | hidden_states = hidden_states[-1:] 893 | 894 | lm_logits = self.lm_head(hidden_states) 895 | 896 | lm_logits = lm_logits.transpose(0, 1).contiguous() 897 | 898 | loss = None 899 | if labels is not None: 900 | lm_logits = lm_logits.to(torch.float32) 901 | 902 | # Shift so that tokens < n predict n 903 | shift_logits = lm_logits[..., :-1, :].contiguous() 904 | shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) 905 | # Flatten the tokens 906 | loss_fct = CrossEntropyLoss(ignore_index=-100) 907 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 908 | 909 | lm_logits = lm_logits.to(hidden_states.dtype) 910 | loss = loss.to(hidden_states.dtype) 911 | 912 | if not return_dict: 913 | output = (lm_logits,) + encodings[1:] 914 | return ((loss,) + output) if loss is not None else output 915 | 916 | return CausalLMOutputWithPast( 917 | loss=loss, 918 | logits=lm_logits, 919 | past_key_values=encodings.past_key_values, 920 | hidden_states=encodings.hidden_states, 921 | attentions=encodings.attentions, 922 | ) 923 | 924 | @staticmethod 925 | def _reorder_cache( 926 | past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor 927 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: 928 | """ 929 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 930 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 931 | beam_idx at every generation step. 932 | 933 | Output shares the same memory storage as `past`. 
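Each cached key/value tensor is laid out as [seq_len, batch, num_kv_heads, head_dim], so the beam
indices are applied along dim 1 (the batch dimension).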
934 | """ 935 | return tuple( 936 | ( 937 | layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), 938 | layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), 939 | ) 940 | for layer_past in past 941 | ) 942 | 943 | def process_response(self, response): 944 | response = response.strip() 945 | return response 946 | 947 | def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt = None): 948 | inputs = tokenizer.build_inputs(query, history=history, system_prompt=system_prompt) 949 | inputs = inputs.to(self.device) 950 | return inputs 951 | 952 | def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt = None): 953 | inputs = tokenizer.build_stream_inputs(query, history=history, system_prompt=system_prompt) 954 | inputs = inputs.to(self.device) 955 | return inputs 956 | 957 | @torch.no_grad() 958 | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt=None, max_length: int = 8192, num_beams=1, 959 | do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): 960 | if history is None: 961 | history = [] 962 | if logits_processor is None: 963 | logits_processor = LogitsProcessorList() 964 | logits_processor.append(InvalidScoreLogitsProcessor()) 965 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, 966 | "temperature": temperature, **kwargs} #, "logits_processor": logits_processor 967 | inputs = self.build_inputs(tokenizer, query, history=history, system_prompt=system_prompt) 968 | outputs = self.generate(**inputs, **gen_kwargs) 969 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 970 | response = tokenizer.decode(outputs, skip_special_tokens=True) # 971 | response = self.process_response(response) 972 | history = history + [(query, response)] 973 | return response, history 974 | 975 | @torch.no_grad() 976 | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt=None, past_key_values=None, 977 | max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, 978 | return_past_key_values=False, **kwargs): 979 | if history is None: 980 | history = [] 981 | if logits_processor is None: 982 | logits_processor = LogitsProcessorList() 983 | logits_processor.append(InvalidScoreLogitsProcessor()) 984 | gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, 985 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 986 | if past_key_values is None and not return_past_key_values: 987 | inputs = self.build_inputs(tokenizer, query, history=history, system_prompt=system_prompt) 988 | else: 989 | inputs = self.build_stream_inputs(tokenizer, query, history=history, system_prompt=system_prompt) 990 | if past_key_values is not None: 991 | past_length = past_key_values[0][0].shape[0] 992 | if self.model.prefix_size is not None: 993 | past_length -= self.transformer.prefix_size 994 | inputs.position_ids += past_length 995 | attention_mask = inputs.attention_mask 996 | attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) 997 | inputs['attention_mask'] = attention_mask 998 | for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, 999 | return_past_key_values=return_past_key_values, **gen_kwargs): 1000 | if return_past_key_values: 1001 | outputs, past_key_values = outputs 1002 | outputs = 
975 |     @torch.no_grad()
976 |     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt=None, past_key_values=None,
977 |                     max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
978 |                     return_past_key_values=False, **kwargs):
979 |         if history is None:
980 |             history = []
981 |         if logits_processor is None:
982 |             logits_processor = LogitsProcessorList()
983 |             logits_processor.append(InvalidScoreLogitsProcessor())
984 |         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
985 |                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
986 |         if past_key_values is None and not return_past_key_values:
987 |             inputs = self.build_inputs(tokenizer, query, history=history, system_prompt=system_prompt)
988 |         else:
989 |             inputs = self.build_stream_inputs(tokenizer, query, history=history, system_prompt=system_prompt)
990 |         if past_key_values is not None:
991 |             past_length = past_key_values[0][0].shape[0]  # cached keys are sequence-first
992 |             if self.model.prefix_size is not None:
993 |                 past_length -= self.model.prefix_size  # exclude prefix-tuning tokens from the past length
994 |             inputs.position_ids += past_length
995 |             attention_mask = inputs.attention_mask
996 |             attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
997 |             inputs['attention_mask'] = attention_mask
998 |         for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
999 |                                             return_past_key_values=return_past_key_values, **gen_kwargs):
1000 |             if return_past_key_values:
1001 |                 outputs, past_key_values = outputs
1002 |             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1003 |             response = tokenizer.decode(outputs)
1004 |             if response and response[-1] != "�":
1005 |                 response = self.process_response(response)
1006 |                 new_history = history + [(query, response)]
1007 |                 if return_past_key_values:
1008 |                     yield response, new_history, past_key_values
1009 |                 else:
1010 |                     yield response, new_history
1011 | 
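    # `stream_generate` is a generator: after every decoding step it yields the full
    # `input_ids` accumulated so far (prompt + generated tokens), optionally paired with
    # `past_key_values`; `stream_chat` above slices off the prompt before decoding.
    # Illustrative consumption sketch (hypothetical generation settings):
    #
    #   for ids in model.stream_generate(input_ids, max_new_tokens=64):
    #       print(tokenizer.decode(ids[0, input_ids.shape[-1]:].tolist()))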
1012 |     @torch.no_grad()
1013 |     def stream_generate(
1014 |         self,
1015 |         input_ids,
1016 |         generation_config: Optional[GenerationConfig] = None,
1017 |         logits_processor: Optional[LogitsProcessorList] = None,
1018 |         stopping_criteria: Optional[StoppingCriteriaList] = None,
1019 |         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1020 |         return_past_key_values=False,
1021 |         **kwargs,
1022 |     ):
1023 |         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1024 | 
1025 |         if generation_config is None:
1026 |             generation_config = self.generation_config
1027 |         generation_config = copy.deepcopy(generation_config)
1028 |         model_kwargs = generation_config.update(**kwargs)
1029 |         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1030 | 
1031 |         if isinstance(eos_token_id, int):
1032 |             eos_token_id = [eos_token_id]
1033 | 
1034 |         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1035 |         if has_default_max_length and generation_config.max_new_tokens is None:
1036 |             warnings.warn(
1037 |                 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1038 |                 "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1039 |                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
1040 |                 UserWarning,
1041 |             )
1042 |         elif generation_config.max_new_tokens is not None:
1043 |             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1044 |             if not has_default_max_length:
1045 |                 logger.warning(
1046 |                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1047 |                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1048 |                     "Please refer to the documentation for more information. "
1049 |                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
1050 |                 )
1051 | 
1052 | 
1053 |         if input_ids_seq_length >= generation_config.max_length:
1054 |             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1055 |             logger.warning(
1056 |                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1057 |                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1058 |                 " increasing `max_new_tokens`."
1059 |             )
1060 | 
1061 |         # Set generation parameters if not already defined
1062 |         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1063 |         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1064 | 
1065 |         logits_processor = self._get_logits_processor(
1066 |             generation_config=generation_config,
1067 |             input_ids_seq_length=input_ids_seq_length,
1068 |             encoder_input_ids=input_ids,
1069 |             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1070 |             logits_processor=logits_processor,
1071 |         )
1072 | 
1073 |         stopping_criteria = self._get_stopping_criteria(
1074 |             generation_config=generation_config, stopping_criteria=stopping_criteria
1075 |         )
1076 |         logits_warper = self._get_logits_warper(generation_config)
1077 | 
1078 |         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1079 |         scores = None
1080 |         while True:
1081 |             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1082 |             # forward pass to get next token
1083 |             outputs = self(
1084 |                 **model_inputs,
1085 |                 return_dict=True,
1086 |                 output_attentions=False,
1087 |                 output_hidden_states=False,
1088 |             )
1089 | 
1090 |             next_token_logits = outputs.logits[:, -1, :]
1091 | 
1092 |             # pre-process distribution
1093 |             next_token_scores = logits_processor(input_ids, next_token_logits)
1094 |             next_token_scores = logits_warper(input_ids, next_token_scores)
1095 | 
1096 |             # sample
1097 |             probs = nn.functional.softmax(next_token_scores, dim=-1)
1098 |             if generation_config.do_sample:
1099 |                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1100 |             else:
1101 |                 next_tokens = torch.argmax(probs, dim=-1)
1102 | 
1103 |             # update generated ids, model inputs, and length for next step
1104 |             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1105 |             model_kwargs = self._update_model_kwargs_for_generation(
1106 |                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1107 |             )
1108 |             unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())  # drops to 0 once a sequence emits an EOS token
1109 |             if return_past_key_values:
1110 |                 yield input_ids, outputs.past_key_values
1111 |             else:
1112 |                 yield input_ids
1113 |             # stop when each sentence is finished, or if we exceed the maximum length
1114 |             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1115 |                 break
1116 | 
1117 | 
--------------------------------------------------------------------------------