├── AIChatbotWithLLM_Report.pdf ├── AIChatbotWithLLM_SLIDES.pdf ├── AIChatbotWithLLM_onePage.pdf ├── LlaMATH-3-8B-Instruct-4bit.yaml ├── README.md └── app.py /AIChatbotWithLLM_Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GusLovesMath/Local_LLM_Training_Apple_Silicon/HEAD/AIChatbotWithLLM_Report.pdf -------------------------------------------------------------------------------- /AIChatbotWithLLM_SLIDES.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GusLovesMath/Local_LLM_Training_Apple_Silicon/HEAD/AIChatbotWithLLM_SLIDES.pdf -------------------------------------------------------------------------------- /AIChatbotWithLLM_onePage.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GusLovesMath/Local_LLM_Training_Apple_Silicon/HEAD/AIChatbotWithLLM_onePage.pdf -------------------------------------------------------------------------------- /LlaMATH-3-8B-Instruct-4bit.yaml: -------------------------------------------------------------------------------- 1 | original_repo: GusLovesMath/LlaMATH-3-8B-Instruct-4bit 2 | mlx-repo: GusLovesMath/LlaMATH-3-8B-Instruct-4bit -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Local LLM Training on Apple Silicon - Project README 2 | 3 | This repository contains the resources and documentation for the project "Local LLM Training on Apple Silicon", in which the Llama3 model was fine-tuned to efficiently solve verbose mathematical word problems on an Apple Silicon device with a 16-core GPU. The project demonstrates how the MLX library and the Metal API can deliver high computational performance and privacy on hardware not traditionally used for LLM training. 4 | 5 | ***Note: The notebook will be posted soon.*** 6 | 7 | ## Repository Contents 8 | - **LLM_Local_Training_Llama3.ipynb**: Jupyter notebook containing all the code for setting up, training, and evaluating the LlaMATH3 model. 9 | - **AIChatbotWithLLM_SLIDES.pdf**: Presentation slides detailing the project's approach, architecture, and outcomes. 10 | - **AIChatbotWithLLM_Report.pdf**: Comprehensive report discussing the project in detail. 11 | - **AIChatbotWithLLM_onePage.pdf**: One-page summary of the project report for quick reference. 12 | - **app.py**: Updated GUI application file that replaces the original in the [chat-with-mlx](https://github.com/qnguyen3/chat-with-mlx.git) repository for enhanced user interaction. 13 | - **LlaMATH-3-8B-Instruct-4bit.yaml**: Configuration file to be added to the `../chat-with-mlx/chat_with_mlx/models/config` directory so the chatbot can use the custom-trained model (see the placement sketch below). 
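The last two files are meant to be dropped into a local checkout of chat-with-mlx. A minimal placement sketch is shown below; it assumes chat-with-mlx is cloned as a sibling directory and that the stock GUI file lives at `chat_with_mlx/app.py` — adjust the paths to match your own checkout.

```bash
# Clone chat-with-mlx next to this repository (location is an assumption; adjust as needed)
git clone https://github.com/qnguyen3/chat-with-mlx.git ../chat-with-mlx

# Register the custom model with the chatbot
cp LlaMATH-3-8B-Instruct-4bit.yaml ../chat-with-mlx/chat_with_mlx/models/config/

# Swap in the updated GUI from this repository
cp app.py ../chat-with-mlx/chat_with_mlx/app.py
```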
14 | 15 | ## Installation 16 | To set up the project environment and run the models, you will need to install the following software and libraries: 17 | 18 | ```bash 19 | conda create -n localLLM python=3.11 20 | conda activate localLLM 21 | pip install mlx-lm==0.12.1 22 | pip install torch==2.3.0 23 | pip install transformers==4.40.1 24 | pip install datasets==2.19.0 25 | pip install pandas==2.2.2 26 | ``` 27 | 28 | ## Usage 29 | To use the trained LlaMATH3 model for generating responses to mathematical prompts, follow these steps: 30 | 31 | ```python 32 | from mlx_lm import load, generate 33 | 34 | # Load the model 35 | model, tokenizer = load("GusLovesMath/LlaMATH-3-8B-Instruct-4bit") 36 | 37 | # Example prompt 38 | prompt = """ 39 | Q: A new program had 60 downloads in the first month. 40 | The number of downloads in the second month was three times as many as the first month, 41 | but then reduced by 30% in the third month. How many downloads did the program have total over the three months? 42 | """ 43 | 44 | # Generate response 45 | response = generate(model, tokenizer, prompt=prompt, max_tokens=132, temp=0.0, verbose=True) 46 | print('LlaMATH Response:', response) 47 | ``` 48 | 49 | ## Model Details 50 | - **Source**: The model was converted to MLX format from `mlx-community/Meta-Llama-3-8B-Instruct-4bit` using mlx-lm version 0.12.1. 51 | - **Training Hardware**: Apple M2 Pro chip with 16 GB of unified memory and a 16-core GPU. 52 | - **Model Card**: For more detailed information about the model's capabilities and training, refer to the original [model card](https://huggingface.co/GusLovesMath/LlaMATH-3-8B-Instruct-4bit). 53 | 54 | ## Interface with [chat-with-mlx](https://github.com/qnguyen3/chat-with-mlx.git) and the Updated `app.py` File 55 | *(Screenshot: the LlaMATH3 chatbot running in the chat-with-mlx interface.)* 56 | 57 | Here we have our locally trained Llama3 model, fine-tuned on verbose math problems (I call it LlaMATH3), being used in a local chatbot. A minimal launch sketch is shown below. 
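The updated `app.py` keeps the original command-line interface, so once the YAML and `app.py` from this repository are in place the chatbot starts like the stock application. A minimal launch sketch, assuming an editable install of the cloned chat-with-mlx checkout and its `chat-with-mlx` console entry point (the `--port` and `--share` flags come from `app.py`'s own argument parser):

```bash
cd ../chat-with-mlx
pip install -e .            # install chat-with-mlx and its dependencies (assumed install path)
chat-with-mlx --port 7860   # launches the Gradio app and opens it in the browser
```

In the web UI, pick the LlaMATH model in the "Select Model" dropdown and click "Load Model"; on first use this downloads the weights from the Hugging Face Hub and starts a local `mlx_lm.server` that the chat interface talks to.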
58 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Built upon the great work at 2 | # https://github.com/qnguyen3/chat-with-mlx 3 | 4 | import argparse 5 | import os 6 | import subprocess 7 | import time 8 | 9 | import gradio as gr 10 | from huggingface_hub import snapshot_download 11 | from langchain.text_splitter import RecursiveCharacterTextSplitter 12 | from langchain_community.document_loaders import ( 13 | Docx2txtLoader, 14 | PyPDFLoader, 15 | TextLoader, 16 | YoutubeLoader, 17 | ) 18 | from typing import Iterable 19 | from gradio.themes.base import Base 20 | from gradio.themes.utils import colors, fonts, sizes 21 | 22 | from langchain_community.embeddings import HuggingFaceEmbeddings 23 | from langchain_community.vectorstores import Chroma 24 | from openai import OpenAI 25 | 26 | from chat_with_mlx import __version__ 27 | from chat_with_mlx.models.utils import model_info 28 | from chat_with_mlx.rag.utils import get_prompt 29 | 30 | os.environ["TOKENIZERS_PARALLELISM"] = "False" 31 | SUPPORTED_LANG = [ 32 | "default", 33 | "English", 34 | "Spanish", 35 | "Chinese", 36 | "Vietnamese", 37 | "Japanese", 38 | "Korean", 39 | "Indian", 40 | "Turkish", 41 | "German", 42 | "French", 43 | "Italian", 44 | ] 45 | openai_api_base = "http://127.0.0.1:8080/v1" 46 | model_dicts, yml_path, cfg_list, mlx_config = model_info() 47 | model_list = list(cfg_list.keys()) 48 | client = OpenAI(api_key="EMPTY", base_url=openai_api_base) 49 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50) 50 | emb = HuggingFaceEmbeddings( 51 | model_name="nomic-ai/nomic-embed-text-v1.5", 52 | model_kwargs={"trust_remote_code": True}, 53 | ) 54 | vectorstore = None 55 | 56 | 57 | def load_model(model_name, lang): 58 | global process, rag_prompt, rag_his_prompt, sys_prompt, default_lang 59 | default_lang = "default" 60 | prompts, sys_prompt = get_prompt(f"{yml_path[cfg_list[model_name]]}", lang) 61 | rag_prompt, rag_his_prompt = prompts[0], prompts[1] 62 | model_name_list = cfg_list[model_name].split("/") 63 | directory_path = os.path.dirname(os.path.abspath(__file__)) 64 | local_model_dir = os.path.join( 65 | directory_path, "models", "download", model_name_list[1] 66 | ) 67 | 68 | if not os.path.exists(local_model_dir): 69 | snapshot_download(repo_id=mlx_config[model_name], local_dir=local_model_dir) 70 | 71 | command = ["python3", "-m", "mlx_lm.server", "--model", local_model_dir] 72 | 73 | try: 74 | process = subprocess.Popen( 75 | command, stdin=subprocess.PIPE, stderr=subprocess.PIPE, text=True 76 | ) 77 | process.stdin.write("y\n") 78 | process.stdin.flush() 79 | return {model_status: "Model Loaded"} 80 | except Exception as e: 81 | return {model_status: f"Exception occurred: {str(e)}"} 82 | 83 | 84 | def kill_process(): 85 | global process 86 | process.terminate() 87 | time.sleep(2) 88 | if process.poll() is None: # Check if the process has indeed terminated 89 | process.kill() # Force kill if still running 90 | 91 | print("Model Killed") 92 | return {model_status: "Model Unloaded"} 93 | 94 | 95 | def check_file_type(file_path): 96 | # Check for document file extensions 97 | if ( 98 | file_path.endswith(".pdf") 99 | or file_path.endswith(".txt") 100 | or file_path.endswith(".doc") 101 | or file_path.endswith(".docx") 102 | ): 103 | return True 104 | # Check for YouTube link formats 105 | elif ( 106 | file_path.startswith("https://www.youtube.com/") 107 | or 
file_path.startswith("https://youtube.com/") 108 | or file_path.startswith("https://youtu.be/") 109 | ): 110 | return True 111 | else: 112 | return False 113 | 114 | 115 | def upload(files): 116 | supported = check_file_type(files) 117 | if supported: 118 | return {url: files, index_status: "Not Done"} 119 | else: 120 | return {url: "File type not supported", index_status: "Not Done"} 121 | 122 | 123 | def indexing(mode, url): 124 | global vectorstore 125 | 126 | try: 127 | if mode == "Files (docx, pdf, txt)": 128 | if url.endswith(".pdf"): 129 | loader = PyPDFLoader(url) 130 | elif url.endswith(".docx"): 131 | loader = Docx2txtLoader(url) 132 | elif url.endswith(".txt"): 133 | loader = TextLoader(url) 134 | splits = loader.load_and_split(text_splitter) 135 | elif mode == "YouTube (url)": 136 | loader = YoutubeLoader.from_youtube_url( 137 | url, add_video_info=False, language=["en", "vi"] 138 | ) 139 | splits = loader.load_and_split(text_splitter) 140 | 141 | vectorstore = Chroma.from_documents(documents=splits, embedding=emb) 142 | return {index_status: "Indexing Done"} 143 | except Exception as e: 144 | # Print the error message or return it as part of the response 145 | print(f"Error: {e}") # This will print the error to the console or log 146 | return {"index_status": "Indexing Error", "error_message": str(e)} 147 | 148 | 149 | def kill_index(): 150 | global vectorstore 151 | vectorstore = None 152 | return {index_status: "Indexing Undone"} 153 | 154 | 155 | def build_rag_context(docs): 156 | context = "" 157 | for doc in docs: 158 | context += doc.page_content + "\n" 159 | 160 | return context 161 | 162 | 163 | def chatbot(query, history, temp, max_tokens, freq_penalty, k_docs): 164 | global chat_history, sys_prompt 165 | 166 | if "vectorstore" in globals() and vectorstore is not None: 167 | if len(history) == 0: 168 | chat_history = [] 169 | if sys_prompt is not None: 170 | chat_history.append({"role": "system", "content": sys_prompt}) 171 | docs = vectorstore.similarity_search(query, k=k_docs) 172 | else: 173 | history_str = "" 174 | for i, message in enumerate(history): 175 | history_str += f"User: {message[0]}\n" 176 | history_str += f"AI: {message[1]}\n" 177 | 178 | if sys_prompt is not None: 179 | chat_history.append({"role": "system", "content": sys_prompt}) 180 | chat_history.append({"role": "user", "content": history_str}) 181 | docs = vectorstore.similarity_search(history_str) 182 | 183 | context = build_rag_context(docs) 184 | 185 | if len(history) == 0: 186 | prompt = rag_prompt.format(context=context, question=query) 187 | else: 188 | prompt = rag_his_prompt.format( 189 | chat_history=history_str, context=context, question=query 190 | ) 191 | messages = [{"role": "user", "content": prompt}] 192 | else: 193 | if len(history) == 0: 194 | chat_history = [] 195 | if sys_prompt is not None: 196 | chat_history.append({"role": "system", "content": sys_prompt}) 197 | else: 198 | chat_history = [] 199 | if sys_prompt is not None: 200 | chat_history.append({"role": "system", "content": sys_prompt}) 201 | for i, message in enumerate(history): 202 | chat_history.append({"role": "user", "content": message[0]}) 203 | chat_history.append({"role": "assistant", "content": message[1]}) 204 | chat_history.append({"role": "user", "content": query}) 205 | messages = chat_history 206 | 207 | # Uncomment for debugging 208 | # print(messages) 209 | 210 | response = client.chat.completions.create( 211 | model="gpt", 212 | messages=messages, 213 | temperature=temp, 214 | 
frequency_penalty=freq_penalty, 215 | max_tokens=max_tokens, 216 | stream=True, 217 | ) 218 | stop = ["<|im_end|>", "<|endoftext|>"] 219 | partial_message = "" 220 | for chunk in response: 221 | if len(chunk.choices) != 0: 222 | if chunk.choices[0].delta.content not in stop: 223 | partial_message = partial_message + chunk.choices[0].delta.content 224 | else: 225 | partial_message = partial_message + "" 226 | yield partial_message 227 | 228 | # NEW STYLE 229 | class GusStyle(Base): 230 | def __init__( 231 | self, 232 | *, 233 | primary_hue: colors.Color | str = colors.sky, 234 | secondary_hue: colors.Color | str = colors.blue, 235 | neutral_hue: colors.Color | str = colors.gray, 236 | spacing_size: sizes.Size | str = sizes.spacing_md, 237 | radius_size: sizes.Size | str = sizes.radius_md, 238 | text_size: sizes.Size | str = sizes.text_lg, 239 | font: fonts.Font 240 | | str 241 | | Iterable[fonts.Font | str] = ( 242 | fonts.GoogleFont("Quicksand"), 243 | "ui-sans-serif", 244 | "sans-serif", 245 | ), 246 | font_mono: fonts.Font 247 | | str 248 | | Iterable[fonts.Font | str] = ( 249 | fonts.GoogleFont("IBM Plex Mono"), 250 | "ui-monospace", 251 | "monospace", 252 | ), 253 | ): 254 | super().__init__( 255 | primary_hue=primary_hue, 256 | secondary_hue=secondary_hue, 257 | neutral_hue=neutral_hue, 258 | spacing_size=spacing_size, 259 | radius_size=radius_size, 260 | text_size=text_size, 261 | font=font, 262 | font_mono=font_mono, 263 | ) 264 | 265 | # UPDATED LAYOUT 266 | with gr.Blocks(fill_height=True, theme=GusStyle()) as demo: 267 | with gr.Row(): 268 | with gr.Column(scale=2): 269 | temp_slider = gr.State(0.2) 270 | max_gen_token = gr.State(512) 271 | freq_penalty = gr.State(1.05) 272 | retrieve_docs = gr.State(3) 273 | language = gr.State("default") 274 | gr.ChatInterface( 275 | chatbot=gr.Chatbot(height=800, render=False), 276 | fn=chatbot, # Function to call on user input 277 | title="🍏 MLX Chat", # Title of the web page 278 | retry_btn='Retry', 279 | undo_btn='Undo', 280 | clear_btn='Clear', 281 | additional_inputs=[temp_slider, max_gen_token, freq_penalty, retrieve_docs], 282 | ) 283 | with gr.Column(scale=1): 284 | ## SELECT MODEL 285 | model_name = gr.Dropdown( 286 | label="Select Model", 287 | info="Select your model", 288 | choices=sorted(model_list), 289 | interactive=True, 290 | render=False, 291 | ) 292 | model_name.render() 293 | language = gr.Dropdown( 294 | label="Language", 295 | choices=sorted(SUPPORTED_LANG), 296 | info="Chose Supported Language", 297 | value="default", 298 | interactive=True, 299 | ) 300 | btn1 = gr.Button("Load Model", variant="primary") 301 | btn3 = gr.Button("Unload Model", variant="stop") 302 | 303 | # FILE 304 | mode = gr.Dropdown( 305 | label="Dataset", 306 | info="Choose your dataset type", 307 | choices=["Files (docx, pdf, txt)", "YouTube (url)"], 308 | scale=5, 309 | ) 310 | url = gr.Textbox( 311 | label="URL", 312 | info="Enter your filepath (URL for Youtube)", 313 | interactive=True, 314 | ) 315 | upload_button = gr.UploadButton( 316 | label="Upload File", variant="primary" 317 | ) 318 | # MODEL STATUS 319 | # data = gr.Textbox(visible=lambda mode: mode == 'YouTube') 320 | model_status = gr.Textbox("Model Not Loaded", label="Model Status") 321 | index_status = gr.Textbox("Not Index", label="Index Status") 322 | btn1.click( 323 | load_model, 324 | inputs=[model_name, language], 325 | outputs=[model_status], 326 | ) 327 | btn3.click(kill_process, outputs=[model_status]) 328 | upload_button.upload( 329 | upload, inputs=upload_button, outputs=[url, 
index_status] 330 | ) 331 | 332 | index_button = gr.Button("Start Indexing", variant="primary") 333 | index_button.click( 334 | indexing, inputs=[mode, url], outputs=[index_status] 335 | ) 336 | stop_index_button = gr.Button("Stop Indexing") 337 | stop_index_button.click(kill_index, outputs=[index_status]) 338 | 339 | 340 | with gr.Accordion("Advanced Setting", open=False): 341 | with gr.Row(): 342 | with gr.Column(scale=1): 343 | temp_slider = gr.Slider( 344 | label="Temperature", 345 | value=0.2, 346 | minimum=0.0, 347 | maximum=1.0, 348 | step=0.05, 349 | interactive=True, 350 | ) 351 | max_gen_token = gr.Slider( 352 | label="Max Tokens", 353 | value=512, 354 | minimum=512, 355 | maximum=4096, 356 | step=256, 357 | interactive=True, 358 | ) 359 | with gr.Column(scale=1): 360 | freq_penalty = gr.Slider( 361 | label="Frequency Penalty", 362 | value=1.05, 363 | minimum=-2, 364 | maximum=2, 365 | step=0.05, 366 | interactive=True, 367 | ) 368 | retrieve_docs = gr.Slider( 369 | label="No. Retrieval Docs", 370 | value=3, 371 | minimum=1, 372 | maximum=10, 373 | step=1, 374 | interactive=True, 375 | ) 376 | 377 | def app(port, share): 378 | print(f"Starting MLX Chat on port {port}") 379 | print(f"Sharing: {share}") 380 | demo.launch(inbrowser=True, share=share, server_port=port) 381 | 382 | 383 | def main(): 384 | parser = argparse.ArgumentParser( 385 | description="Chat with MLX \n" 386 | "Native RAG on MacOS and Apple Silicon with MLX 🧑‍💻" 387 | ) 388 | parser.add_argument( 389 | "--version", action="version", version=f"Chat with MLX {__version__}" 390 | ) 391 | parser.add_argument( 392 | "--port", 393 | type=int, 394 | default=7860, 395 | help="Port number to run the app", 396 | ) 397 | parser.add_argument( 398 | "--share", 399 | default=False, 400 | help="Enable sharing the app", 401 | ) 402 | args = parser.parse_args() 403 | app(port=args.port, share=args.share) 404 | --------------------------------------------------------------------------------