├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── chat.py ├── data └── ultrachat_small.jsonl ├── requirements.txt ├── scripts └── download_ultrachat.py ├── train_mamba.py └── trainer ├── __init__.py ├── data.py └── mamba_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | .venv/ 4 | .vscode/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mamba-Chat 🐍 2 | 3 | **Mamba-Chat is the first chat language model based on a state-space model architecture, not a transformer.** 4 | 5 | The model is based on Albert Gu's and Tri Dao's work *Mamba: Linear-Time Sequence Modeling with Selective State Spaces* ([paper](https://arxiv.org/pdf/2312.00752.pdf)) as well as their [model implementation](https://github.com/state-spaces/mamba). 
This repository provides code for training and fine-tuning the model, built on a lightly modified Hugging Face `Trainer` class. 6 | 7 | Mamba-Chat is based on Mamba-2.8B and was fine-tuned on 16,000 samples of the [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset. To learn more, you can: 8 | 9 | - Take a look at the model on [Huggingface](https://huggingface.co/havenhq/mamba-chat) 🤗 10 | - Talk to us on the [Haven](https://haven.run/) Community [Discord](https://discord.com/invite/JDjbfp6q2G) 🧑‍🤝‍🧑 11 | - Talk to Mamba-Chat on [Google Colab](https://colab.research.google.com/drive/1dUlEYnRbgJYg4_kofNpsCddLCh6vltNK?usp=sharing) 12 | 13 | 14 |
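If you just want to query the model from Python, here is a minimal single-turn sketch, condensed from `chat.py` further down in this repository (it assumes a CUDA GPU and the pinned packages from `requirements.txt`):

```python
# Minimal sketch of single-turn inference, mirroring chat.py.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
# Mamba-Chat was fine-tuned on Zephyr-formatted conversations, so borrow Zephyr's chat template.
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template

model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)

messages = [dict(role="user", content="What is a state-space model?")]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace("<|endoftext|>", ""))
```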
15 | 16 | ## Run Mamba-Chat 17 | 18 | We provide code for testing and fine-tuning our model. Here's how to get started and what you can do with it: 19 | 20 |
21 | 22 | 23 | **Clone repository and install dependencies:** 24 | ``` 25 | git clone https://github.com/havenhq/mamba-chat.git 26 | cd mamba-chat 27 | pip install -r requirements.txt 28 | ``` 29 | 30 |
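Note that `mamba-ssm` and `causal-conv1d` ship custom CUDA kernels, so you need an NVIDIA GPU and a CUDA-enabled PyTorch build. A quick sanity check before going further:
```
python -c "import torch; print(torch.cuda.is_available())"
```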
31 | 32 | **Talk to Mamba-Chat (CLI chatbot):** 33 | ``` 34 | python chat.py 35 | ``` 36 | 37 |
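Under the hood, `chat.py` serializes the conversation with Zephyr's chat template before tokenizing. Roughly, a turn is framed like this (the template appends the EOS token, here `<|endoftext|>`, after each message), which is why the script recovers the reply by splitting the decoded output on `<|assistant|>\n`:
```
<|user|>
Your message<|endoftext|>
<|assistant|>
Model reply<|endoftext|>
```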
38 | 39 | **Talk to Mamba-Chat (gradio app):** 40 | ``` 41 | pip install gradio==4.8.0 42 | python app.py --share 43 | ``` 44 | 45 |
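`app.py` also accepts `--port`, `--device`, and `--model` flags (see its argument parser below); `--share` is only needed if you want a public Gradio link. For example, to serve locally on a different port:
```
python app.py --port 8080
```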
46 | 47 | **Fine-Tune Mamba (the base model) on a subset of the Ultrachat dataset:** 48 | ``` 49 | python train_mamba.py --model state-spaces/mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 4 --data_path ./data/ultrachat_small.jsonl --num_epochs 3 50 | ``` 51 | 52 |
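`--data_path` expects a JSONL file with one conversation per line, each carrying a `messages` list of role/content turns (the format that `trainer/data.py` parses and `scripts/download_ultrachat.py` produces). An illustrative record:
```
{"messages": [{"role": "user", "content": "What is Mamba?"}, {"role": "assistant", "content": "Mamba is a selective state-space sequence model."}]}
```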
53 | 54 | **If you have a 24GB card (3090, 4090, etc.) you can use these settings:** per-device batch size 1 with 4 gradient accumulation steps keeps the effective batch size at 4, while `paged_adamw_8bit` shrinks the optimizer state enough to fit. 55 | ``` 56 | python train_mamba.py --model state-spaces/mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 1 --gradient_accumulation_steps 4 --optim paged_adamw_8bit --data_path ./data/ultrachat_small.jsonl --num_epochs 3 57 | ``` 58 | 59 | ## Citation 60 | 61 | ```bibtex 62 | @misc{haven2023mambachat, 63 | title = {Mamba-Chat}, 64 | author = {Justus Mattern and Konstantin Hohr}, 65 | year = {2023}, 66 | howpublished = {GitHub}, 67 | url = {https://github.com/havenhq/mamba-chat} 68 | } 69 | ``` 70 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 5 | from argparse import ArgumentParser 6 | 7 | def get_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--port", type=int, default=7860) 10 | parser.add_argument("--device", type=str, default='cuda', help='Device to run the model on') 11 | parser.add_argument("--model", type=str, default='havenhq/mamba-chat', help='Model to use') 12 | parser.add_argument( 13 | "--share", 14 | action="store_true", 15 | default=False, 16 | help="share your instance publicly through gradio", 17 | ) 18 | try: 19 | args = parser.parse_args() 20 | except SystemExit: 21 | parser.print_help() 22 | exit(0) 23 | return args 24 | 25 | 26 | if __name__ == "__main__": 27 | args = get_args() 28 | 29 | device = args.device 30 | model_name = args.model 31 | eos = "<|endoftext|>" 32 | tokenizer = AutoTokenizer.from_pretrained(model_name) 33 | tokenizer.eos_token = eos 34 | tokenizer.pad_token = tokenizer.eos_token 35 | tokenizer.chat_template = AutoTokenizer.from_pretrained( 36 | "HuggingFaceH4/zephyr-7b-beta" 37 | ).chat_template 38 | 39 | model = MambaLMHeadModel.from_pretrained( 40 | model_name, device=device, dtype=torch.float16 41 | ) 42 | 43 | def chat_with_mamba( 44 | user_message, 45 | history: list[list[str]], 46 | temperature: float = 0.9, 47 | top_p: float = 0.7, 48 | max_length: int = 2000, 49 | ): 50 | history_dict: list[dict[str, str]] = [] 51 | for user_m, assistant_m in history: 52 | history_dict.append(dict(role="user", content=user_m)) 53 | history_dict.append(dict(role="assistant", content=assistant_m)) 54 | history_dict.append(dict(role="user", content=user_message)) 55 | 56 | input_ids = tokenizer.apply_chat_template( 57 | history_dict, return_tensors="pt", add_generation_prompt=True 58 | ).to(device) 59 | 60 | out = model.generate( 61 | input_ids=input_ids, 62 | max_length=max_length, 63 | temperature=temperature, 64 | top_p=top_p, 65 | eos_token_id=tokenizer.eos_token_id, 66 | ) 67 | 68 | decoded = tokenizer.batch_decode(out) 69 | assistant_message = ( 70 | decoded[0].split("<|assistant|>\n")[-1].replace(eos, "") 71 | ) 72 | return assistant_message 73 | 74 | 75 | demo = gr.ChatInterface( 76 | fn=chat_with_mamba, 77 | # examples=[ 78 | # "Explain what is state space model", 79 | # "Nice to meet you!", 80 | # "'Mamba is way better than ChatGPT.' 
Is this statement correct?", 81 | # ], 82 | additional_inputs=[ 83 | gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="temperature"), 84 | gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="top_p"), 85 | gr.Number(value=2000, label="max_length"), 86 | ], 87 | title="Mamba Chat", 88 | ) 89 | demo.launch(server_port=args.port, share=args.share) 90 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 4 | 5 | device = "cuda" 6 | tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat") 7 | tokenizer.eos_token = "<|endoftext|>" 8 | tokenizer.pad_token = tokenizer.eos_token 9 | tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template 10 | 11 | model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16) 12 | 13 | messages = [] 14 | while True: 15 | user_message = input("\nYour message: ") 16 | messages.append(dict( 17 | role="user", 18 | content=user_message 19 | )) 20 | 21 | input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda") 22 | 23 | out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id) 24 | 25 | decoded = tokenizer.batch_decode(out) 26 | # Strip the trailing <|endoftext|> so it does not accumulate in the chat history. 27 | assistant_message = decoded[0].split("<|assistant|>\n")[-1].replace("<|endoftext|>", "") 28 | messages.append(dict( 29 | role="assistant", 30 | content=assistant_message 31 | )) 32 | 33 | print("Model:", assistant_message) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | torch==2.1.0 3 | transformers==4.35.0 4 | causal-conv1d==1.0.0 5 | mamba-ssm==1.0.1 6 | accelerate==0.25.0 7 | bitsandbytes==0.41.3 8 | scipy==1.11.4 -------------------------------------------------------------------------------- /scripts/download_ultrachat.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datasets import load_dataset 3 | 4 | data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") 5 | 6 | 7 | # The output path is relative, so run this script from inside scripts/. 8 | with open("../data/ultrachat.jsonl", "w") as f: 9 | for d in data: 10 | f.write(json.dumps(dict(messages=d["messages"]))+"\n") 11 | -------------------------------------------------------------------------------- /train_mamba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 5 | from transformers import AutoTokenizer, TrainingArguments 6 | from trainer.data import ChatDataModule 7 | from trainer.mamba_trainer import MambaTrainer 8 | 9 | 10 | def run(args): 11 | 12 | model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda") 13 | 14 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 15 | tokenizer.eos_token = "<|endoftext|>" 16 | tokenizer.pad_token = tokenizer.eos_token 17 | tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template 18 | 19 | 20 | data_module = ChatDataModule( 21 | tokenizer=tokenizer, 22 | data_path=args.data_path, 23 | conversation_template=tokenizer.chat_template, 24 | 
max_tokens=2048 25 | ) 26 | 27 | 28 | trainer = MambaTrainer( 29 | model=model, 30 | train_dataset=data_module.dataset, 31 | tokenizer=tokenizer, 32 | args=TrainingArguments( 33 | learning_rate=args.learning_rate, 34 | num_train_epochs=args.num_epochs, 35 | per_device_train_batch_size=args.batch_size, 36 | gradient_accumulation_steps=args.gradient_accumulation_steps, 37 | optim=args.optim, 38 | output_dir="mamba-chat", 39 | logging_steps=50, 40 | save_steps=500, 41 | ), 42 | data_collator=data_module.data_collator, 43 | ) 44 | 45 | trainer.train() 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--model", type=str, default="state-spaces/mamba-2.8b") 51 | parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") 52 | parser.add_argument("--learning_rate", type=float, default=5e-5) 53 | parser.add_argument("--batch_size", type=int, default=4) 54 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 55 | parser.add_argument("--optim", type=str, default="adamw_torch") 56 | parser.add_argument("--data_path", type=str, default="./data/ultrachat_small.jsonl") 57 | parser.add_argument("--num_epochs", type=int, default=1) 58 | args = parser.parse_args() 59 | 60 | run(args) 61 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redotvideo/mamba-chat/68e60823eb99d94d71d6c29cc203795a6312aea3/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import json 4 | 5 | from dataclasses import dataclass 6 | from typing import Dict, Sequence 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class ChatDataset(Dataset): 12 | def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int): 13 | super(ChatDataset, self).__init__() 14 | data = [] 15 | with open(data_path, "r") as file: 16 | for line in file: 17 | try: 18 | data.append(json.loads(line)) 19 | except Exception as e: 20 | print("json processing exception", e) 21 | continue 22 | 23 | 24 | data_dict = preprocess(data, tokenizer, conversation_template, max_tokens) 25 | 26 | self.input_ids = data_dict["input_ids"] 27 | self.labels = data_dict["labels"] 28 | 29 | def __len__(self): 30 | return len(self.input_ids) 31 | 32 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 33 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 34 | 35 | 36 | @dataclass 37 | class DataCollatorForChatDataset(object): 38 | """ 39 | Collate examples for supervised fine-tuning. 
40 | """ 41 | 42 | tokenizer: transformers.PreTrainedTokenizer 43 | 44 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 45 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 46 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 47 | # Pad labels with -100, the ignore_index of CrossEntropyLoss. 48 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) 49 | 50 | return dict( 51 | input_ids=input_ids, 52 | labels=labels, 53 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 54 | ) 55 | 56 | 57 | class ChatDataModule(): 58 | def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int): 59 | self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens) 60 | self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer) 61 | 62 | 63 | def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict: 64 | """ 65 | Preprocess the data by applying the chat template and tokenizing each conversation. 66 | """ 67 | all_input_ids = [] 68 | tokenizer.use_default_system_prompt = False 69 | 70 | print("Tokenizing dataset...") 71 | for conv in tqdm(conversations): 72 | current_conv = conv["messages"] 73 | tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True) 74 | all_input_ids.append(torch.LongTensor(tokenized_conv)) 75 | 76 | # Labels mirror input_ids, so loss is computed on every token (user and assistant 77 | # turns alike); note that MambaTrainer below derives its own shifted labels from input_ids. 78 | return dict(input_ids=all_input_ids, labels=all_input_ids) -------------------------------------------------------------------------------- /trainer/mamba_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | import torch 3 | import os 4 | 5 | 6 | class MambaTrainer(Trainer): 7 | def compute_loss(self, model, inputs, return_outputs=False): 8 | input_ids = inputs.pop("input_ids") 9 | lm_logits = model(input_ids).logits 10 | 11 | # Next-token prediction: drop the last logit and the first label token. 12 | labels = input_ids.to(lm_logits.device) 13 | shift_logits = lm_logits[:, :-1, :].contiguous() 14 | labels = labels[:, 1:].contiguous() 15 | 16 | loss_fct = torch.nn.CrossEntropyLoss() 17 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) 18 | 19 | return (lm_loss, {"logits": lm_logits}) if return_outputs else lm_loss 20 | 21 | def save_model(self, output_dir, _internal_call=False): 22 | if not os.path.exists(output_dir): 23 | os.makedirs(output_dir) 24 | 25 | # MambaLMHeadModel is not a transformers PreTrainedModel, so write the raw state_dict. 26 | torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin") 27 | self.tokenizer.save_pretrained(output_dir) --------------------------------------------------------------------------------
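A closing note on checkpoints: `MambaTrainer.save_model` writes only the raw `state_dict` (plus tokenizer files), not a `config.json`, so `MambaLMHeadModel.from_pretrained` cannot load the output directory directly. One way to reload the fine-tuned weights, sketched under the assumption that training started from `state-spaces/mamba-2.8b` and that you point at whichever checkpoint directory training produced (e.g. `mamba-chat/checkpoint-500`):

```python
# Sketch: reload weights written by MambaTrainer.save_model.
# Assumes the checkpoint was fine-tuned from state-spaces/mamba-2.8b;
# the checkpoint path is illustrative.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.bfloat16)
model.load_state_dict(torch.load("mamba-chat/checkpoint-500/pytorch_model.bin", map_location="cuda"))
model.eval()
```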