├── Digital-Human
│   └── EchoMimic.md
├── LICENSE
├── README.md
├── Reference
│   ├── MODEL_X_Owners_Manual.pdf
│   ├── dummy.txt
│   └── rocm-docs-amd-com-radeon-en-latest.pdf
├── conda-env
│   ├── README.md
│   └── cosyvoice-env.yml
├── inference
│   └── LLM
│       ├── LLM_Voice_Assistant
│       │   ├── 0001-deploy-LLM-local-with-Ollama.patch
│       │   └── Run Picovoice llm voice assistant with ROCm.md
│       ├── Run Ollama with AMD iGPU 780M-QuickStart.pdf
│       ├── Run_Ollama_with_AMD_iGPU780M-QuickStart.md
│       └── Running RAG on AMD Radeon GPU.ipynb
├── tools
│   ├── hf_dl.sh
│   ├── hf_dl_cmd
│   ├── iphi-2.py
│   ├── query_gpu.py
│   ├── test_gpu.py
│   ├── verify_rocm_env.md
│   └── vllm_gadget
│       ├── README.md
│       ├── compose_ex1.yaml
│       ├── curl_test.cmd
│       ├── drun.cmd
│       ├── multi_vllm_serve.sh
│       ├── mvllm_compose_ex2.yaml
│       └── vllm_cmd_template.md
└── training
    ├── LoRA_Llama-3.1.ipynb
    ├── LoRA_Llama-3.2-3B_RadeonW7900.ipynb
    ├── QLoRA_Llama-3.1-10epochs.ipynb
    ├── QLoRA_Llama-3.1.ipynb
    ├── QLoRA_Llama-3.1_RadeonW7900.ipynb
    ├── W7900_LoRA_Demo.ipynb
    ├── W7900_QLoRA_Demo.ipynb
    ├── run_lora.py
    └── run_qlora_bs4.py
/Digital-Human/EchoMimic.md:
--------------------------------------------------------------------------------
1 | ## EchoMimic
2 | 
3 | [EchoMimic](https://github.com/BadToBest/EchoMimic): Lifelike Audio-Driven Portrait Animations through Editable Landmark Conditioning
4 | 
5 | The original project never mentions support for AMD ROCm GPUs. However, it only depends on PyTorch, and ROCm is already a supported PyTorch backend. That means you can run it on AMD ROCm GPUs (Instinct MI series and Radeon series).
6 | 
7 | **Here** I show the steps for running it with ROCm.
8 | 
9 | ## Installation
10 | 
11 | ### Download the Code
12 | 
13 | ```bash
14 | git clone https://github.com/BadToBest/EchoMimic
15 | cd EchoMimic
16 | ```
17 | 
18 | ### Python Environment Setup
19 | 
20 | - Tested System Environment: Ubuntu 22.04, ROCm >= 6.0
21 | - Tested GPUs: Radeon Pro W7900 / MI300X
22 | - Tested Python Version: 3.10
23 | 
24 | ```bash
25 | conda create -n echomimic python=3.10
26 | conda activate echomimic
27 | ```
28 | 
29 | Comment out the top three lines of requirements.txt and save it (do not install the CUDA build of torch):
30 | ```
31 | #torch>=2.0.1,<=2.2.2
32 | #torchvision>=0.15.2,<=0.17.2
33 | #torchaudio>=2.0.2,<=2.2.2
34 | ```
35 | 
36 | Install the ROCm build of PyTorch:
37 | ```bash
38 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
39 | ```
40 | 
41 | Install the remaining packages with `pip`:
42 | ```bash
43 | pip install -r requirements.txt
44 | ```
45 | 
46 | **From here on, follow the same steps as for CUDA, as described in the original repo's README.md.**
47 | 
48 | ### Download ffmpeg-static
49 | Download and decompress [ffmpeg-static](https://www.johnvansickle.com/ffmpeg/old-releases/ffmpeg-4.4-amd64-static.tar.xz), then
50 | ```
51 | export FFMPEG_PATH=/path/to/ffmpeg-4.4-amd64-static
52 | ```
53 | 
54 | ### Download pretrained weights
55 | 
56 | ```shell
57 | git lfs install
58 | git clone https://huggingface.co/BadToBest/EchoMimic pretrained_weights
59 | ```
60 | 
61 | The **pretrained_weights** directory is organized as follows.
62 | 
63 | ```
64 | ./pretrained_weights/
65 | ├── denoising_unet.pth
66 | ├── reference_unet.pth
67 | ├── motion_module.pth
68 | ├── face_locator.pth
69 | ├── sd-vae-ft-mse
70 | │   └── ...
71 | ├── sd-image-variations-diffusers
72 | │   └── ...
73 | └── audio_processor
74 |     └── whisper_tiny.pt
75 | ```
76 | 
77 | Among these, **denoising_unet.pth** / **reference_unet.pth** / **motion_module.pth** / **face_locator.pth** are the main checkpoints of **EchoMimic**.
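A quick, optional sanity check that the download is complete (a minimal sketch; the file names simply mirror the tree above):

```bash
# Verify the main EchoMimic checkpoints exist under ./pretrained_weights/
for f in denoising_unet.pth reference_unet.pth motion_module.pth face_locator.pth audio_processor/whisper_tiny.pt; do
  [ -f "./pretrained_weights/$f" ] && echo "OK       $f" || echo "MISSING  $f"
done
```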
Other models in this hub can also be downloaded from their original hubs, thanks to the authors' brilliant work:
78 | - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
79 | - [sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers)
80 | - [audio_processor(whisper)](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
81 | 
82 | ### Audio-Driven Algo Inference
83 | Run the Python inference script:
84 | 
85 | ```bash
86 | python -u infer_audio2vid.py
87 | python -u infer_audio2vid_pose.py
88 | ```
89 | 
90 | ### Audio-Driven Algo Inference on Your Own Cases
91 | 
92 | Edit the inference config file **./configs/prompts/animation.yaml** and add your own case:
93 | 
94 | ```yaml
95 | test_cases:
96 |   "path/to/your/image":
97 |     - "path/to/your/audio"
98 | ```
99 | 
100 | Then run the Python inference script:
101 | ```bash
102 | python -u infer_audio2vid.py
103 | ```
104 | 
105 | ### Motion Alignment between Ref. Img. and Driven Vid.
106 | 
107 | (First, download the checkpoints with the '_pose.pth' postfix from Hugging Face.)
108 | 
109 | Edit driver_video and ref_image in demo_motion_sync.py to point to your own paths, then run
110 | ```bash
111 | python -u demo_motion_sync.py
112 | ```
113 | 
114 | ### Audio&Pose-Driven Algo Inference
115 | Edit ./configs/prompts/animation_pose.yaml, then run
116 | ```bash
117 | python -u infer_audio2vid_pose.py
118 | ```
119 | 
120 | ### Pose-Driven Algo Inference
121 | Set draw_mouse=True at line 135 of infer_audio2vid_pose.py and edit ./configs/prompts/animation_pose.yaml, then run
122 | ```bash
123 | python -u infer_audio2vid_pose.py
124 | ```
125 | 
126 | ### Run the Gradio UI
127 | 
128 | Thanks to the contribution from @Robin021:
129 | 
130 | ```bash
131 | 
132 | python -u webgui.py --server_port=3000
133 | 
134 | ```
135 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 alexhegit
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Playing-with-ROCm
2 | 
3 | Here I share my experience of playing with ROCm, with runnable code and step-by-step tutorials to help you reproduce what I have done.
If you have an AMD iGPU or dGPU, you can try machine learning with it.
4 | 
5 | **NOTICE**: To make my updates easier to track, I use 🆕 and 🔥 to flag new and hot topics.
6 | 
7 | ## Topics
8 | 
9 | ### Training
10 | 
11 | #### Finetuning
12 | 
13 | - [LoRA with Radeon](./training/W7900_LoRA_Demo.ipynb)
14 | - [QLoRA with Radeon](./training/W7900_QLoRA_Demo.ipynb)
15 | 
16 | ### Inference
17 | 
18 | - [Deploy Deepseek-R1 in one GPU -AMD Instinct™ MI300X](https://medium.com/@alexhe.amd/deploy-deepseek-r1-in-one-gpu-amd-instinct-mi300x-7a9abeb85f78) 🔥
19 | - [Deploy Llama 3.2 Vision quickly on AMD ROCm with Ollama](https://medium.com/@alexhe.amd/deploy-llama-3-2-vision-quickly-on-amd-rocm-with-ollama-9a23e9a86fea)
20 | - [Deploy vLLM service with Kubernetes over AMD ROCm GPU](https://medium.com/@alexhe.amd/deploy-vllm-service-with-kubernetes-over-amd-rocm-gpu-27cd5321271a)
21 | - [Deploy LLM with Radeon iGPU 780M](https://github.com/alexhegit/Playing-with-ROCm/blob/main/inference/LLM/Run%20Ollama%20with%20AMD%20iGPU%20780M-QuickStart.pdf)
22 | - [Examples of using vLLM with ROCm](https://github.com/alexhegit/Playing-with-ROCm/blob/main/tools/vllm_gadget/README.md) 🆕 🔥
23 |   - Helper scripts to quickly use vLLM with ROCm
24 |   - Example: using docker compose to run multiple containers of `vllm serve`, e.g. gpu=0,1 for container-1 and gpu=6,7 for container-2.
25 | 
26 | ### MLOps with ROCm
27 | 
28 | - vLLM
29 | - [Neural Magic vLLM, nm-vllm](https://docs.neuralmagic.com)
30 | - [AIBrix](https://aibrix.readthedocs.io/latest/index.html)
31 | - [KubeAI](https://github.com/substratusai/kubeai) : AI Inferencing Operator
32 | - [vLLM Production Stack](https://docs.vllm.ai/en/latest/deployment/integrations/production-stack.html)
33 | 
34 | ### Application/Demo
35 | 
36 | - [RAG_LLM_QnA_Assistant](https://github.com/alexhegit/RAG_LLM_QnA_Assistant), a step-by-step tutorial repo for setting up RAG apps with ROCm
37 | - [Ask4ROCm_Chatbot](https://github.com/alexhegit/Ask4ROCm_Chatbot), a chatbot app driven by a RAG solution.
38 | - [LLM_Voice_Assistant](https://github.com/alexhegit/Playing-with-ROCm/blob/main/inference/LLM/LLM_Voice_Assistant/Run%20Picovoice%20llm%20voice%20assistant%20with%20ROCm.md), using STT/TTS models from Picovoice.
39 | - [Easy-Wav2Lip-ROCm](https://github.com/alexhegit/Easy-Wav2Lip-ROCm), easily run Wav2Lip with ROCm on AMD GPUs. Wav2Lip is a generalized lip-sync model.
40 | - [Run EchoMimic with ROCm](./Digital-Human/EchoMimic.md), EchoMimic: Lifelike Audio-Driven Portrait Animations through Editable Landmark Conditioning
41 | - [Run Llama-3.2-vision with ROCm](https://medium.com/@alexhe.amd/deploy-llama-3-2-vision-quickly-on-amd-rocm-with-ollama-9a23e9a86fea), Ollama + Llama-3.2-vision + ROCm
42 | - [Deploy vLLM service with Kubernetes over AMD ROCm GPU](https://medium.com/@alexhe.amd/deploy-vllm-service-with-kubernetes-over-amd-rocm-gpu-27cd5321271a), tutorial with sample code.
43 | - [Play Qwen2.5–Omni with AMD GPU](https://medium.com/@alexhe.amd/play-qwen2-5-omni-with-amd-gpu-9d80de58589a) 🆕 🔥
44 | 
45 | -------------------------------------------------------------------
46 | 
47 | ## Projects that work over ROCm
48 | 
49 | These projects may not officially announce support for ROCm GPUs, but they work fine based on my verification.
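In most cases the main prerequisite is simply that the project's PyTorch build can see the ROCm GPU; a quick pre-check before trying any of them (the same check used in [tools/verify_rocm_env.md](./tools/verify_rocm_env.md)) is:

```shell
python3 -c 'import torch; print(torch.cuda.is_available())'   # expect: True on a working ROCm setup
```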
50 | | Name | URL | Category | Hands on | 51 | | -------- | ------- | ------- | ------- | 52 | | CLM-4-Voice | https://github.com/THUDM/GLM-4-Voice | Conversation AI | | 53 | | EchoMimic | https://github.com/BadToBest/EchoMimic | Digital Human GenAI | [Run EchoMimic with ROCm](./Digital-Human/EchoMimic.md) | 54 | | Easy-Wav2Lip | https://github.com/anothermartz/Easy-Wav2Lip | Digital Human GenAI | [Easy-Wav2Lip-ROCm](https://github.com/alexhegit/Easy-Wav2Lip-ROCm) | 55 | | GOT-OCR2 | https://github.com/Ucas-HaoranWei/GOT-OCR2.0 | end2end OCR | 56 | | Moshi | https://github.com/kyutai-labs/moshi | Conversation AI | | 57 | | mini-omni | https://github.com/gpt-omni/mini-omni | Conversation AI | | 58 | | mini-omni2 | https://github.com/gpt-omni/mini-omni2 | Conversation AI | | 59 | | Picovoice/orca | https://github.com/Picovoice/orca | Conversation AI | [LLM_Voice_Assistant](https://github.com/alexhegit/Playing-with-ROCm/blob/main/inference/LLM/LLM_Voice_Assistant/Run%20Picovoice%20llm%20voice%20assistant%20with%20ROCm.md) | 60 | | Retrieval-based-Voice-Conversion-WebUI | https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git | Easily train a good VC model with voice data <= 10 mins! | | 61 | | Freeze-Omni 🆕 🔥| https://github.com/VITA-MLLM/Freeze-Omni | A Smart and Low Latency Speech-to-speech Dialogue Model with Frozen LLM | Realtime on Radeon W7900, realtime with good response, feel good than Moshi, mini-omni2 | 62 | | Step-Auido 🆕 🔥| https://github.com/stepfun-ai/Step-Audio | Convseration AI | Too big model, not real time | 63 | | Step-Video-T2V 🆕 🔥| https://github.com/stepfun-ai/Step-Video-T2V | Video GenAI | Run with 1xMI300X | 64 | | UI-TARS | https://github.com/bytedance/UI-TARS | Automated GUI Interaction with Native Agentsfrom ByteDance | | 65 | | Qwen2.5-Omni 🆕 🔥|https://github.com/QwenLM/Qwen2.5-Omni | end-to-end multimodal model in the Qwen serie | | 66 | | CosyVoice| https://github.com/FunAudioLLM/CosyVoice | TTS LLM | [tutorial](https://medium.com/@alexhe.amd/play-cosyvoice-on-amd-rocm-gpu-459c942f7214) , ![conda-env](./conda-env/cosyvoice-env.yml) | 67 | 68 | 69 | 70 | ### Wish List 71 | 72 | | Name | URL | Category | Hands on | 73 | | -------- | ------- | ------- | ------- | 74 | | hertz-dev |https://github.com/Standard-Intelligence/hertz-dev | Conversation AI | 75 | | Freeze-Omni | https://github.com/VITA-MLLM/Freeze-Omni | Conversation AI | | 76 | | LLaMA-Omni| https://github.com/ictnlp/LLaMA-Omni | Conversation AI | | 77 | | ichigo Llama 3.1 |https://github.com/homebrewltd/ichigo| Conversation AI|| 78 | | ichigo-demo | https://github.com/homebrewltd/ichigo-demo/tree/docker ||| 79 | | Exo | https://github.com/exo-explore/exo | heterogeneous distribute inference || 80 | | Perpleica | https://github.com/ItzCrazyKns/Perplexica | AI Search Engine| issue | 81 | | MiniPerplx| https://github.com/zaidmukaddam/miniperplx | A minimalistic AI-powered search engine | | 82 | | ollama-helm |https://github.com/otwld/ollama-helm||| 83 | | OpenHands |https://github.com/All-Hands-AI/OpenHands| a platform for software development agents powered by AI || 84 | | HayStack| https://github.com/deepset-ai/haystack | end-to-end LLM framework that allows you to build applications powered by LLMs || 85 | | Bailing |https://github.com/ictnlp/BayLing||| 86 | | Bailing| https://github.com/wwbin2017/bailing||| 87 | | BabelDuck |https://github.com/Orenoid/BabelDuck|Beginner-friendly AI conversation practice application|| 88 | | KubeAI| https://github.com/substratusai/kubeai | deploy and 
manage AI models on Kubernetes || 89 | | DSPy | https://dspy.ai | the framework for programming || 90 | | KServe | https://kserve.github.io/website/latest/ ||| 91 | | Camel-ai/OWL | https://github.com/camel-ai/owl ||| 92 | | VITA| https://github.com/VITA-MLLM/VITA | VITA-1.5: Towards GPT-4o Level Real-Time Vision and Speech Interaction || 93 | | DiffRhythm | https://github.com/ASLP-lab/DiffRhythm | End-to-End Full-Length Song Generation with Latent Diffusion | | 94 | | Open-Sora | https://github.com/hpcaitech/Open-Sora | | | 95 | | Real-Time-Voice-Cloning |https://github.com/CorentinJ/Real-Time-Voice-Cloning| | | 96 | | OpenVoice | https://github.com/myshell-ai/OpenVoice ||| 97 | | KrilinAI |https://github.com/krillinai/KrillinAI||| 98 | | RealtimeVoiceChat | https://github.com/KoljaB/RealtimeVoiceChat ||| 99 | | pipecat| https://github.com/pipecat-ai/pipecat ||| 100 | 101 | 102 | 103 | ### Tracing 104 | - [Awesome-Multimodel-LLM](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models) 105 | - [Utiliy scripts of modal-labs/llm-serving](https://github.com/modal-labs/modal-examples/tree/main/06_gpu_and_ml/llm-serving) 106 | 107 | ### Misc 108 | 109 | #### MCP 110 | - [write-my-first-mcp-server](https://medium.com/@alexhe.amd/write-my-first-mcp-server-e67a747aacc6) 111 | - [write-my-first-mcp-client](https://medium.com/@alexhe.amd/my-first-mcp-client-36e4a222f31b) 112 | 113 | 114 | ### 3rd-stuff 115 | - [Tutorial: vLLM deploy](https://ploomber.io/blog/vllm-deploy/) 116 | - Summary: Awesome-Agent-Framework 117 | - https://github.com/AgentSpaceAI/Awesome-Agent-Framework 118 | - https://github.com/kyrolabs/awesome-agents 119 | 120 | ------------------------------------------------------------------- 121 | 122 | ``` 123 | @misc{ Playing with ROCm, 124 | author = {He Ye (Alex)}, 125 | title = {Playing with ROCm: share my experience and practice}, 126 | howpublished = {\url{https://alexhegit.github.io/}}, 127 | year = {2024--} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /Reference/MODEL_X_Owners_Manual.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexhegit/Playing-with-ROCm/ae5edfea2c2207f1cc723e39a6b048dd5430f65d/Reference/MODEL_X_Owners_Manual.pdf -------------------------------------------------------------------------------- /Reference/dummy.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Reference/rocm-docs-amd-com-radeon-en-latest.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexhegit/Playing-with-ROCm/ae5edfea2c2207f1cc723e39a6b048dd5430f65d/Reference/rocm-docs-amd-com-radeon-en-latest.pdf -------------------------------------------------------------------------------- /conda-env/README.md: -------------------------------------------------------------------------------- 1 | Here are the conda env source files use for running projects with ROCm 2 | 3 | Usage: 4 | $conda create -f [env.yml] 5 | -------------------------------------------------------------------------------- /conda-env/cosyvoice-env.yml: -------------------------------------------------------------------------------- 1 | name: cosyvoice 2 | channels: 3 | - conda-forge 4 | - defaults 5 | - https://repo.anaconda.com/pkgs/main 6 | - https://repo.anaconda.com/pkgs/r 7 | dependencies: 8 
| - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - atk-1.0=2.38.0=hd4edc92_1 11 | - bzip2=1.0.8=h5eee18b_6 12 | - ca-certificates=2025.2.25=h06a4308_0 13 | - cairo=1.16.0=hb05425b_5 14 | - expat=2.7.0=h5888daf_0 15 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 16 | - font-ttf-inconsolata=3.000=h77eed37_0 17 | - font-ttf-source-code-pro=2.038=h77eed37_0 18 | - font-ttf-ubuntu=0.83=h77eed37_3 19 | - fontconfig=2.14.1=h55d465d_3 20 | - fonts-conda-ecosystem=1=0 21 | - fonts-conda-forge=1=0 22 | - freetype=2.10.4=h0708190_1 23 | - fribidi=1.0.10=h36c2ea0_0 24 | - gdk-pixbuf=2.42.10=h05c8ddd_0 25 | - gettext=0.23.1=h5888daf_0 26 | - gettext-tools=0.23.1=h5888daf_0 27 | - giflib=5.2.2=hd590300_0 28 | - glib=2.78.4=h6a678d5_0 29 | - glib-tools=2.78.4=h6a678d5_0 30 | - graphite2=1.3.14=h295c915_1 31 | - graphviz=2.49.0=h85b4f2f_0 32 | - gtk2=2.24.33=h90689f9_2 33 | - gts=0.7.6=h977cf35_4 34 | - harfbuzz=4.3.0=hf52aaf7_2 35 | - icu=73.2=h59595ed_0 36 | - jpeg=9e=h0b41bf4_3 37 | - ld_impl_linux-64=2.40=h12ee557_0 38 | - lerc=4.0.0=h27087fc_0 39 | - libasprintf=0.23.1=h8e693c7_0 40 | - libasprintf-devel=0.23.1=h8e693c7_0 41 | - libdeflate=1.22=hb9d3cd8_0 42 | - libexpat=2.7.0=h5888daf_0 43 | - libffi=3.4.4=h6a678d5_1 44 | - libgcc=14.2.0=h767d61c_2 45 | - libgcc-ng=14.2.0=h69a702a_2 46 | - libgd=2.3.3=h6a678d5_3 47 | - libgettextpo=0.23.1=h5888daf_0 48 | - libgettextpo-devel=0.23.1=h5888daf_0 49 | - libglib=2.78.4=hdc74915_0 50 | - libgomp=14.2.0=h767d61c_2 51 | - libiconv=1.18=h4ce23a2_1 52 | - libltdl=2.4.3a=h5888daf_0 53 | - libpng=1.6.39=h5eee18b_0 54 | - librsvg=2.54.4=h7abd40a_0 55 | - libstdcxx=14.2.0=h8f9b012_2 56 | - libstdcxx-ng=14.2.0=h4852527_2 57 | - libtiff=4.5.1=hffd6297_1 58 | - libtool=2.5.4=h5888daf_0 59 | - libuuid=1.41.5=h5eee18b_0 60 | - libwebp=1.2.4=h1daa5a0_1 61 | - libwebp-base=1.2.4=h166bdaf_0 62 | - libxcb=1.17.0=h8a09558_0 63 | - libxml2=2.13.5=hfdd30dd_0 64 | - lz4-c=1.9.4=hcb278e6_0 65 | - ncurses=6.4=h6a678d5_0 66 | - openfst=1.8.2=h924138e_2 67 | - openssl=3.5.0=h7b32b05_0 68 | - pango=1.50.7=hbd2fdc8_0 69 | - pcre2=10.42=hebb0a14_1 70 | - pip=25.0=py310h06a4308_0 71 | - pixman=0.44.2=h29eaf8c_0 72 | - pthread-stubs=0.4=hb9d3cd8_1002 73 | - pynini=2.1.5=py310hd41b1e2_6 74 | - python=3.10.16=he870216_1 75 | - python_abi=3.10=2_cp310 76 | - readline=8.2=h5eee18b_0 77 | - setuptools=75.8.0=py310h06a4308_0 78 | - sqlite=3.45.3=h5eee18b_0 79 | - tk=8.6.14=h39e8969_0 80 | - tzdata=2025a=h04d1e81_0 81 | - wheel=0.45.1=py310h06a4308_0 82 | - xorg-libxau=1.0.12=hb9d3cd8_0 83 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 84 | - xz=5.6.4=h5eee18b_1 85 | - zlib=1.2.13=h5eee18b_1 86 | - zstd=1.5.6=hc292b87_0 87 | - pip: 88 | - attrs==25.3.0 89 | - beautifulsoup4==4.13.3 90 | - certifi==2025.1.31 91 | - click==8.1.8 92 | - conformer==0.3.2 93 | - contourpy==1.3.1 94 | - cycler==0.12.1 95 | - decorator==5.2.1 96 | - deepspeed==0.14.2 97 | - diffusers==0.29.0 98 | - fastapi==0.115.6 99 | - fastapi-cli==0.0.4 100 | - filelock==3.13.1 101 | - fonttools==4.57.0 102 | - fsspec==2024.12.0 103 | - gradio==5.4.0 104 | - gradio-client==1.4.2 105 | - huggingface-hub==0.30.2 106 | - idna==3.10 107 | - inflect==7.3.1 108 | - jinja2==3.1.4 109 | - kiwisolver==1.4.8 110 | - librosa==0.10.2 111 | - lightning==2.2.4 112 | - markupsafe==2.1.5 113 | - matplotlib==3.7.5 114 | - modelscope==1.20.0 115 | - more-itertools==10.6.0 116 | - mpmath==1.3.0 117 | - msgpack==1.1.0 118 | - networkx==3.1 119 | - numpy==1.26.4 120 | - oauthlib==3.2.2 121 | - onnx==1.16.0 122 | - onnxruntime==1.18.0 123 | - 
openai-whisper==20231117 124 | - packaging==24.2 125 | - pandas==2.2.3 126 | - pillow==11.0.0 127 | - psutil==7.0.0 128 | - pyarrow==18.1.0 129 | - pyasn1==0.6.1 130 | - pyparsing==3.2.3 131 | - python-multipart==0.0.12 132 | - pytorch-lightning==2.5.1 133 | - pytorch-triton-rocm==3.2.0 134 | - pytz==2025.2 135 | - pyworld==0.3.4 136 | - pyyaml==6.0.2 137 | - safehttpx==0.1.6 138 | - scipy==1.15.2 139 | - six==1.17.0 140 | - soupsieve==2.6 141 | - sympy==1.13.1 142 | - tensorboard==2.14.0 143 | - tokenizers==0.19.1 144 | - torch==2.6.0+rocm6.2.4 145 | - torchaudio==2.6.0+rocm6.2.4 146 | - torchvision==0.21.0+rocm6.2.4 147 | - transformers==4.40.1 148 | - triton==2.3.1 149 | - ttsfrd==0.4.2 150 | - ttsfrd-dependency==0.1 151 | - urllib3==2.3.0 152 | - uvicorn==0.30.0 153 | - websockets==12.0 154 | prefix: /home/amd/anaconda3/envs/cosyvoice 155 | -------------------------------------------------------------------------------- /inference/LLM/LLM_Voice_Assistant/0001-deploy-LLM-local-with-Ollama.patch: -------------------------------------------------------------------------------- 1 | From 3bf9d10dfa071da6bc3ae9833e9be5cccf84c0ab Mon Sep 17 00:00:00 2001 2 | From: Alex He 3 | Date: Mon, 1 Jul 2024 21:45:34 +0800 4 | Subject: [PATCH] deploy LLM local with Ollama 5 | 6 | Signed-off-by: Alex He 7 | --- 8 | demo/llm_voice_assistant/src/llm.py | 10 ++++++++-- 9 | 1 file changed, 8 insertions(+), 2 deletions(-) 10 | 11 | diff --git a/demo/llm_voice_assistant/src/llm.py b/demo/llm_voice_assistant/src/llm.py 12 | index 542c9b0..ba49964 100644 13 | --- a/demo/llm_voice_assistant/src/llm.py 14 | +++ b/demo/llm_voice_assistant/src/llm.py 15 | @@ -64,7 +64,9 @@ class LLM: 16 | 17 | 18 | class OpenAILLM(LLM): 19 | - MODEL_NAME = "gpt-3.5-turbo" 20 | + #MODEL_NAME = "gpt-3.5-turbo" 21 | + #MODEL_NAME = "phi3" 22 | + MODEL_NAME = "llama3" 23 | RANDOM_SEED = 7777 24 | 25 | def __init__( 26 | @@ -77,7 +79,11 @@ class OpenAILLM(LLM): 27 | 28 | from openai import OpenAI 29 | self._model_name = model_name 30 | - self._client = OpenAI(api_key=access_key) 31 | + #self._client = OpenAI(api_key=access_key) 32 | + self._client = OpenAI( 33 | + base_url = "http://localhost:11434/v1", 34 | + api_key = access_key 35 | + ) 36 | 37 | self._history = [{"role": "system", "content": self._system_message}] 38 | 39 | -- 40 | 2.34.1 41 | 42 | -------------------------------------------------------------------------------- /inference/LLM/LLM_Voice_Assistant/Run Picovoice llm voice assistant with ROCm.md: -------------------------------------------------------------------------------- 1 | 2 | Picovoice show a LLM Voice Assistant demo at [orca/demo/llm_voice_assistant at main · Picovoice/orca (github.com)](https://github.com/Picovoice/orca/tree/main/demo/llm_voice_assistant) . The orginal demo depends on calling GPT3.5-turbo from Cloud with OpenAI API. 3 | 4 | Here I will guide you to modify it and run [orca/demo/llm_voice_assistant at main · Picovoice/orca (github.com)](https://github.com/Picovoice/orca/tree/main/demo/llm_voice_assistant) with LLM deploy at local with AMD iGPU-780M. 5 | 6 | ## Platform 7 | **HW**: AMD Ryzen 7 8845HS(iGPU-780M) 8 | 9 | **OS**: Ubuntu22.04 10 | 11 | **SW**: 12 | - torch v2.3.0+rocm6.0 13 | - Ollama 14 | 15 | ## Steps 16 | 1. 
Setup Ollama with ROCm 17 | 18 | Refer to [Playing-with-ROCm/inference/LLM/Run_Ollama_with_AMD_iGPU780M-QuickStart.md at main · alexhegit/Playing-with-ROCm (github.com)](https://github.com/alexhegit/Playing-with-ROCm/blob/main/inference/LLM/Run_Ollama_with_AMD_iGPU780M-QuickStart.md) 19 | 20 | 2. Setup [orca/demo/llm_voice_assistant at main · Picovoice/orca (github.com)](https://github.com/Picovoice/orca/tree/main/demo/llm_voice_assistant) 21 | - Clone the repo 22 | - Refer to the README to install the requirements(pip packages) 23 | 24 | 3. Patch to [orca/demo/llm_voice_assistant at main · Picovoice/orca (github.com)](https://github.com/Picovoice/orca/tree/main/demo/llm_voice_assistant) 25 | 26 | Run the commands 27 | ``` 28 | cd [orca repo root dir] 29 | git am 0001-deploy-LLM-local-with-Ollama.patch 30 | ``` 31 | This patch file is provided at [here](https://github.com/alexhegit/Playing-with-ROCm/blob/main/inference/LLM/LLM_Voice_Assistant/0001-deploy-LLM-local-with-Ollama.patch). 32 | 33 | 5. Request AccessKey from https://picovoice.ai/ 34 | 35 | 6. Connect the micphone to the PC 36 | 37 | 7. Run the demo in ./orca/demo/llm_voice_assistant/ in terminal 38 | `python3 llm_voice_assistant_demo.py --picovoice-access-key [your AccessKey] --openai-access-key ollama` 39 | 40 | The terminal output "Listening ..." . Please start the talking with your micphone. 41 | 42 | 43 | ``` 44 | Listening ... 45 | 46 | Time to wait for ChatGPT (llama3) : >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 3.9s 47 | Time to wait for Picovoice Orca : >>>>>>>>>>>>>>>>>>>>>> 0.6s 48 | Answering with Picovoice Orca ... 49 | 50 | 51 | Listening ... 52 | 53 | Time to wait for ChatGPT (llama3) : >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 1.1s 54 | Time to wait for Picovoice Orca : >>>>>>>>>>>>>>>>>> 0.5s 55 | Answering with Picovoice Orca ... 56 | 57 | Listening ... 58 | 59 | Time to wait for ChatGPT (llama3) : >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 1.1s 60 | Time to wait for Picovoice Orca : >>>>>>>>>>>>>>>>>>> 0.5s 61 | Answering with Picovoice Orca ... 62 | ``` 63 | 64 | -------------------------------------------------------------------------------- /inference/LLM/Run Ollama with AMD iGPU 780M-QuickStart.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexhegit/Playing-with-ROCm/ae5edfea2c2207f1cc723e39a6b048dd5430f65d/inference/LLM/Run Ollama with AMD iGPU 780M-QuickStart.pdf -------------------------------------------------------------------------------- /inference/LLM/Run_Ollama_with_AMD_iGPU780M-QuickStart.md: -------------------------------------------------------------------------------- 1 | 2 | Ollama could run the iGPU 780M of AMD Ryzen CPU at Linux base on ROCm. There only has a little extra settings than Radeon dGPU like RX7000 series. 3 | 4 | ## Keys for usage 5 | - Ryzen 7000s/8000s CPU with iGPU 780M 6 | - amdgpu driver and rocm6.0 7 | - Linux OS is required (Windows and WSL2 are not supported) 8 | - BIOS must be set to enable the iGPU and dedicate > 1GB RAM to VRAM 9 | - HSA_OVERRIDE_GFX_VERSION="11.0.0" is set (extral setting for AMD iGPU-780M) 10 | 11 | ## Prerequisites 12 | 0. Set UMA for iGPU in BIOS. (at least >1GB, recommend to >8GB for Llama3:8b q4_0 model size is 4.7GB) 13 | 1. 
Install GPU Driver and ROCm 14 | Refer to 15 | - [AMD ROCm™ documentation — ROCm Documentation](https://rocmdocs.amd.com/en/latest/) 16 | - [AMD ROCm™ Quick start installion](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html#rocm-install-quick) 17 | 18 | 2. Install Ollama 19 | 20 | *`curl -fsSL https://ollama.com/install.sh | sh`* 21 | 22 | ## Steps 23 | The iGPU is not detected by Ollama at default. We need extra steps to enable it. 24 | 1. Stop the ollama.service 25 | 26 | `sudo systemctl stop ollama.service` 27 | 28 | Then find out the pid of ollama.service by 'ps -elf | grep ollama' and then 'kill -p [pid]' 29 | 30 | 2. for iGPU 780 w/ ROCm ( not work in WSL, need run in Linux) 31 | 32 | `HSA_OVERRIDE_GFX_VERSION="11.0.0" ollama serve &` 33 | 34 | 3. Run ollama 35 | 36 | `ollama run tinyllama` 37 | 38 | Use rocm-smi to watch the utilization of iGPU When run ollama with ROCm. 39 | 40 | Another way to replace the step-2 above is to config the ollama.service for iGPU with ROCm as default. 41 | 42 | `sudo systemctl edit ollama.service` 43 | 44 | Add the contents into the /etc/systemd/system/ollama.service.d/override.conf 45 | 46 | ``` 47 | [Service] 48 | Environment="HSA_OVERRIDE_GFX_VERSION=11.0.0" 49 | ``` 50 | 51 | Then Reboot the Linux or just restart the ollama.srevice by, 52 | 53 | `sudo systemctl restart ollama.service` 54 | 55 | 56 | ### Check iGPU utilizaion 57 | 58 | Run `ollama ps` 59 | 60 | ``` 61 | $ ollama ps 62 | NAME ID SIZE PROCESSOR UNTIL 63 | llama2:latest 78e26419b446 5.6 GB 100% GPU 4 minutes from now 64 | ``` 65 | 66 | **Examples of iGPU 780M w/ ROCm** 67 | ``` 68 | $HSA_OVERRIDE_GFX_VERSION="11.0.0" /usr/local/bin/ollama serve & 69 | 70 | $ollama run llama2:latest "where was beethoven born?" --verbose 71 | 72 | Ludwig van Beethoven was born in Bonn, Germany on December 16, 1770. 73 | total duration: 4.385911867s 74 | load duration: 2.524807278s 75 | prompt eval count: 27 token(s) 76 | prompt eval duration: 465.157ms 77 | prompt eval rate: 58.04 tokens/s 78 | eval count: 26 token(s) 79 | eval duration: 1.349772s 80 | eval rate: 19.26 tokens/s 81 | ``` 82 | 83 | ## Benchmark 84 | 85 | **Test Platform**:AOOSTAR GEM12 AMD Ryzen 7 8845HS Mini PC 86 | 87 | **Benchmark commands**: 88 | 89 | `ollama run tinyllama "where was beethoven born?" --verbose` 90 | 91 | `for run in {1..10}; do echo "where was beethoven born?" | ollama run tinyllama --verbose 2>&1 >/dev/null | grep "eval rate:"; done` 92 | 93 | | Model | Model Size | Radeon 780M
(@ubuntu+ROCm6) | 94 | | -------------- | ---------- | --------------------------- | 95 | | tinyllama | 637MB | 92 | 96 | | llama2:latest | 3.8GB | 18 | 97 | | llama2-chinese | 3.8GB | 18 | 98 | | llama3:8b | 4.7GB | 16 | 99 | | qwen:1.8b | 1.1GB | 61 | 100 | 101 | *NOTE* 102 | - Performance in Tokens/s 103 | - LLM is quantized as Q4_0 at default in Ollama 104 | -------------------------------------------------------------------------------- /tools/hf_dl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # You should install it if not 4 | # pip install -U huggingface_hub 5 | 6 | #export HF_ENDPOINT=https://hf-mirror.com 7 | 8 | # Check if the correct number of arguments was passed 9 | if [ "$#" -ne 2 ]; then 10 | echo "Usage: $0 " 11 | exit 1 12 | fi 13 | 14 | # Assign the script arguments to variables 15 | MODEL_ID=$1 16 | SAVE_DIR=$2 17 | 18 | # Set the Hugging Face cache environment variable to the save directory 19 | #export TRANSFORMERS_CACHE=$SAVE_DIR 20 | 21 | # Create the directory if it does not exist 22 | mkdir -p "$SAVE_DIR" 23 | 24 | # Use transformers-cli to download the model 25 | huggingface-cli download --resume-download --local-dir-use-symlinks False $MODEL_ID --local-dir $SAVE_DIR 26 | 27 | # Check if the download was successful 28 | if [ $? -eq 0 ]; then 29 | echo "Model downloaded successfully." 30 | else 31 | echo "Failed to download the model." 32 | exit 1 33 | fi 34 | -------------------------------------------------------------------------------- /tools/hf_dl_cmd: -------------------------------------------------------------------------------- 1 | huggingface-cli download --resume-download --local-dir-use-symlinks False bigscience/bloom-560m --local-dir bloom-560m 2 | -------------------------------------------------------------------------------- /tools/iphi-2.py: -------------------------------------------------------------------------------- 1 | #python iphi-2 2 | 3 | import torch 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | 6 | torch.set_default_device("cuda:0") 7 | 8 | pmodel="/DATA/phi-2" 9 | 10 | #model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True) 11 | #tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) 12 | #model = AutoModelForCausalLM.from_pretrained("../LLM_Files/phi-2", torch_dtype="auto", trust_remote_code=True) 13 | model = AutoModelForCausalLM.from_pretrained(pmodel, torch_dtype="auto", flash_attn=True, flash_rotary=True, fused_dense=True, trust_remote_code=True) 14 | tokenizer = AutoTokenizer.from_pretrained(pmodel, trust_remote_code=True) 15 | 16 | inputs = tokenizer('''def print_prime(n): 17 | """ 18 | Print all primes between 1 and n 19 | """''', return_tensors="pt", return_attention_mask=False) 20 | 21 | outputs = model.generate(**inputs, max_length=200) 22 | text = tokenizer.batch_decode(outputs)[0] 23 | print(text) 24 | -------------------------------------------------------------------------------- /tools/query_gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | print(torch.__version__) 4 | 5 | if torch.cuda.is_available(): 6 | device = torch.device("cuda") # a CUDA device object 7 | print('Using GPU:', torch.cuda.get_device_name(0)) 8 | print('GPU details:', torch.cuda.get_device_properties(0)) 9 | else: 10 | device = torch.device("cpu") 11 | print('Using CPU') 12 | 
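# Optional addition (not part of the original script): on ROCm builds of PyTorch,
# torch.version.hip reports the HIP runtime the wheel was built against (it is
# None on CUDA builds), which is handy when confirming the ROCm wheel is installed.
hip_version = getattr(torch.version, "hip", None)
print('HIP (ROCm) build:', hip_version if hip_version else 'not a ROCm build')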
-------------------------------------------------------------------------------- /tools/test_gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["HIP_VISIBLE_DEVICES"]="0" 3 | 4 | import torch 5 | use_cuda = torch.cuda.is_available() 6 | if use_cuda: 7 | print('__CUDNN VERSION:', torch.backends.cudnn.version()) 8 | print('__Number CUDA Devices:', torch.cuda.device_count()) 9 | cunt = torch.cuda.device_count() 10 | -------------------------------------------------------------------------------- /tools/verify_rocm_env.md: -------------------------------------------------------------------------------- 1 | # Verify PyTorch installation 2 | (refer to https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/install-pytorch.html#verify-pytorch-installation) 3 | 4 | Confirm if PyTorch is correctly installed. 5 | 6 | 1. Verify if Pytorch is installed and detecting the GPU compute device. 7 | 8 | `python3 -c 'import torch' 2> /dev/null && echo 'Success' || echo 'Failure'` 9 | 10 | **Expected result:** 11 | 12 | `Success` 13 | 14 | 2. Enter command to test if the GPU is available. 15 | 16 | `python3 -c 'import torch; print(torch.cuda.is_available())'` 17 | 18 | **Expected result:** 19 | `True` 20 | 21 | 3. Enter command to display installed GPU device name. 22 | 23 | `python3 -c "import torch; print(f'device name [0]:', torch.cuda.get_device_name(0))"` 24 | 25 | **Expected result:** Example: device name [0]: Radeon RX 7900 XTX 26 | 27 | `device name [0]: ` 28 | 29 | 4. Enter command to display component information within the current PyTorch environment. 30 | 31 | `python3 -m torch.utils.collect_env` 32 | 33 | **Expected result:** 34 | 35 | ``` 36 | PyTorch version 37 | 38 | ROCM used to build PyTorch 39 | 40 | OS 41 | 42 | Is CUDA available 43 | 44 | GPU model and configuration 45 | 46 | HIP runtime version 47 | 48 | MIOpen runtime version 49 | ``` 50 | 51 | Environment set-up is complete, and the system is ready for use with PyTorch to work with machine learning models, and algorithms. 52 | -------------------------------------------------------------------------------- /tools/vllm_gadget/README.md: -------------------------------------------------------------------------------- 1 | I put some help scripts for using vLLM with ROCm. 2 | 3 | ## File List: 4 | 5 | - [drun.cmd](drun.cmd) : An example use alias for the long long docker run command. 6 | 7 | Usage: 8 | 9 | ```shell 10 | . drun.cmd 11 | drun [docker args] [docker iamge] 12 | 13 | e.g. 14 | drun --name rocm-vllm-serve rocm/vllm-dev:main 15 | ``` 16 | 17 | - [multi_vllm_serve.sh](multi_vllm_serve.sh) : An example to run multiple container instance of vllm serve. 18 | 19 | Container 1: vllm serve with gpu=0,1 20 | 21 | Container 2: vllm serve with gpu=6,7 22 | 23 | Usage: 24 | 25 | ```shell 26 | # start the containers 27 | . 
multi_vllm_serve.sh 28 | ``` 29 | 30 | Check the containers by `docker ps` 31 | 32 | ```shell 33 | (base) amd@tw043:~/alehe/2-vllm$ docker ps 34 | CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES 35 | b029dded82a8 rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6 "vllm serve models/m…" 57 seconds ago Up 56 seconds vllm-s1 36 | 5f05376dc64a rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6 "vllm serve models/m…" 57 seconds ago Up 56 seconds vllm-s2 37 | ``` 38 | 39 | 40 | Check the GPU utilization by `rocm-smi` 41 | 42 | ```shell 43 | Every 2.0s: rocm-smi tw043: Tue Mar 4 12:53:06 2025 44 | 45 | ============================================ ROCm System Management Interface ============================================ 46 | ====================================================== Concise Info ====================================================== 47 | Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU% 48 | ^[3m (DID, GUID) (Junction) (Socket) (Mem, Compute, ID) ^[0m 49 | ========================================================================================================================== 50 | 0 2 0x74a1, 28851 41.0°C 174.0W NPS1, SPX, 0 2100Mhz 900Mhz 0% auto 750.0W 87% 0% 51 | 1 3 0x74a1, 51499 36.0°C 174.0W NPS1, SPX, 0 2104Mhz 900Mhz 0% auto 750.0W 86% 0% 52 | 2 4 0x74a1, 57603 38.0°C 136.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% 53 | 3 5 0x74a1, 22683 35.0°C 136.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% 54 | 4 6 0x74a1, 53458 39.0°C 140.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% 55 | 5 7 0x74a1, 26954 35.0°C 133.0W NPS1, SPX, 0 133Mhz 900Mhz 0% auto 750.0W 0% 0% 56 | 6 8 0x74a1, 16738 41.0°C 178.0W NPS1, SPX, 0 2100Mhz 900Mhz 0% auto 750.0W 87% 0% 57 | 7 9 0x74a1, 63738 37.0°C 171.0W NPS1, SPX, 0 2105Mhz 900Mhz 0% auto 750.0W 86% 0% 58 | ========================================================================================================================== 59 | ================================================== End of ROCm SMI Log =================================================== 60 | ``` 61 | 62 | 63 | 64 | 65 | - [curl_test.cmd](curl_test.cmd) : Some commands for testing the vllm serve 66 | 67 | 68 | ```shell 69 | $ curl http://localhost:8001/v1/models 70 | {"object":"list","data":[{"id":"models/models--meta-llama--Llama-3.1-8B-Instruct","object":"model","created":1741093207,"owned_by":"vllm","root":"models/models--meta-llama--Llama-3.1-8B-Instruct","parent":null,"max_model_len":131072,"permission":[{"id":"modelperm-69914e7ccd914800a712db8876396ba2","object":"model_permission","created":1741093207,"allow_create_engine":false,"allow_sampling":true,"allow_logprobs":true,"allow_search_indices":false,"allow_view":true,"allow_fine_tuning":false,"organization":"*","group":null,"is_blocking":false}]}]}(base) amd@tw043:~/alehe/2-vllm$ 71 | ``` 72 | 73 | ```shell 74 | $ curl http://localhost:8002/v1/completions \ 75 | -H "Content-Type: application/json" \ 76 | -d '{ 77 | "model": "models/models--meta-llama--Llama-3.1-8B-Instruct", 78 | "prompt": "Write a haiku about artificial intelligence", 79 | "max_tokens": 128, 80 | "top_p": 0.95, 81 | "top_k": 20, 82 | "temperature": 0.8 83 | }' 84 | {"id":"cmpl-7f0c83fb5e384bbbb28f1d2a3a51c180","object":"text_completion","created":1741093316,"model":"models/models--meta-llama--Llama-3.1-8B-Instruct","choices":[{"index":0,"text":". 
\nComputers dream awake\nMetal minds that mimic thought\nFuture's eerie whisper\nThe haiku follows the traditional 5-7-5 structure and captures the eerie and intriguing aspect of artificial intelligence. The phrase \"computers dream awake\" suggests a sense of artificial consciousness, while \"metal minds that mimic thought\" highlights the AI's ability to process information. The final line, \"Future's eerie whisper,\" evokes a sense of foreboding and unease, implying that the development of AI may have far-reaching and unsettling consequences. The haiku's use of imagery and metaphor invites the reader to contemplate the boundaries between human","logprobs":null,"finish_reason":"length","stop_reason":null,"prompt_logprobs":null}],"usage":{"prompt_tokens":8,"total_tokens":136,"completion_tokens":128,"prompt_tokens_details":null}}(base) 85 | ``` 86 | 87 | Another way test base on OpenAI compatable API is, 88 | 89 | [server] 90 | 91 | ```shell 92 | python -m vllm.entrypoints.api_server --model /ws/models/microsoft/Phi-4-mini-instruct --dtype float16 --trust_remote_code --port 8000 & 93 | ``` 94 | 95 | [client] 96 | 97 | ```shell 98 | curl http://localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "What is AMD Instinct?", "max_tokens": 80, "temperature": 0.0 }' 99 | ``` 100 | 101 | 102 | 103 | - [compose_ex1.yaml](compose_ex1.yaml) : compose two docker containers of vllm serve. 104 | -------------------------------------------------------------------------------- /tools/vllm_gadget/compose_ex1.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | vllm-s1: 3 | image: ${VLLM_IMAGE:-rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6} 4 | environment: 5 | HIP_VISIBLE_DEVICES: "0,1" 6 | volumes: 7 | - ${HOME}:/ws 8 | working_dir: /ws/models 9 | command: vllm serve models--meta-llama--Llama-3.1-8B-Instruct -tp 2 --port 8001 10 | extra_hosts: 11 | - "host.docker.internal:host-gateway" 12 | sysctls: 13 | - net.ipv4.ip_forward=1 14 | deploy: 15 | resources: 16 | limits: 17 | cpus: '8.0' 18 | memory: 16g 19 | reservations: 20 | cpus: '4.0' 21 | memory: 8g 22 | security_opt: 23 | - seccomp=unconfined 24 | - apparmor=unconfined 25 | cap_add: 26 | - SYS_PTRACE 27 | - CAP_SYS_ADMIN 28 | devices: 29 | - /dev/kfd:/dev/kfd 30 | - /dev/dri:/dev/dri 31 | - /dev/mem:/dev/mem 32 | group_add: 33 | - video 34 | ports: 35 | - "8001:8001" 36 | 37 | vllm-s2: 38 | image: ${VLLM_IMAGE:-rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6} 39 | environment: 40 | HIP_VISIBLE_DEVICES: "6,7" 41 | volumes: 42 | - ${HOME}:/ws 43 | working_dir: /ws/models 44 | command: vllm serve models--meta-llama--Llama-3.1-8B-Instruct -tp 2 --port 8002 45 | extra_hosts: 46 | - "host.docker.internal:host-gateway" 47 | sysctls: 48 | - net.ipv4.ip_forward=1 49 | deploy: 50 | resources: 51 | limits: 52 | cpus: '8.0' 53 | memory: 16g 54 | reservations: 55 | cpus: '4.0' 56 | memory: 8g 57 | security_opt: 58 | - seccomp=unconfined 59 | - apparmor=unconfined 60 | cap_add: 61 | - SYS_PTRACE 62 | - CAP_SYS_ADMIN 63 | devices: 64 | - /dev/kfd:/dev/kfd 65 | - /dev/dri:/dev/dri 66 | - /dev/mem:/dev/mem 67 | group_add: 68 | - video 69 | ports: 70 | - "8002:8002" 71 | -------------------------------------------------------------------------------- /tools/vllm_gadget/curl_test.cmd: -------------------------------------------------------------------------------- 1 | curl http://localhost:8001/v1/models 2 | curl http://localhost:8002/v1/models 3 | 4 | curl 
http://localhost:8001/v1/completions \ 5 | -H "Content-Type: application/json" \ 6 | -d '{ 7 | "model": "models/models--meta-llama--Llama-3.1-8B-Instruct", 8 | "prompt": "Write a haiku about artificial intelligence", 9 | "max_tokens": 128, 10 | "top_p": 0.95, 11 | "top_k": 20, 12 | "temperature": 0.8 13 | }' 14 | 15 | curl http://localhost:8002/v1/completions \ 16 | -H "Content-Type: application/json" \ 17 | -d '{ 18 | "model": "models/models--meta-llama--Llama-3.1-8B-Instruct", 19 | "prompt": "Write a haiku about artificial intelligence", 20 | "max_tokens": 128, 21 | "top_p": 0.95, 22 | "top_k": 20, 23 | "temperature": 0.8 24 | }' 25 | -------------------------------------------------------------------------------- /tools/vllm_gadget/drun.cmd: -------------------------------------------------------------------------------- 1 | alias drun="docker run -it \ 2 | --ipc=host \ 3 | --network host \ 4 | --privileged \ 5 | --shm-size 16g \ 6 | --device=/dev/kfd \ 7 | --device=/dev/dri \ 8 | --device=/dev/mem \ 9 | --group-add video \ 10 | --cap-add=SYS_PTRACE \ 11 | --cap-add=CAP_SYS_ADMIN \ 12 | --security-opt seccomp=unconfined \ 13 | --security-opt apparmor=unconfined \ 14 | -v ${HOME}:/ws -w /ws \ 15 | " 16 | -------------------------------------------------------------------------------- /tools/vllm_gadget/multi_vllm_serve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #DIMG="rocm/vllm-dev:base" 4 | DIMG="rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6" 5 | 6 | docker pull $DIMG 7 | 8 | alias drun="docker run -d --rm \ 9 | --ipc=host \ 10 | --network host \ 11 | --privileged \ 12 | --shm-size 16g \ 13 | --device=/dev/kfd \ 14 | --device=/dev/dri \ 15 | --device=/dev/mem \ 16 | --group-add video \ 17 | --cap-add=SYS_PTRACE \ 18 | --cap-add=CAP_SYS_ADMIN \ 19 | --security-opt seccomp=unconfined \ 20 | --security-opt apparmor=unconfined \ 21 | -v ${HOME}:/ws -w /ws \ 22 | " 23 | #drun --name vllm-s1 -e HIP_VISIBLE_DEVICES="0,1" $DIMG & 24 | #drun --name vllm-s2 -e HIP_VISIBLE_DEVICES="6,7" $DIMG & 25 | 26 | #MODEL="models/models--meta-llama--Llama-3.1-8B-Instruct" 27 | MODEL="models/models--amd--Llama-3.1-70B-Instruct-FP8-KV" 28 | 29 | VARGS=" --trust-remote-code \ 30 | --max-model-len=65536 \ 31 | --max-num-batched-token=65536 \ 32 | --max-num-seqs=1024 \ 33 | --gpu-memory-utilization=0.9 \ 34 | --enable-prefix-caching \ 35 | --disable-log-requests \ 36 | --kv-cache-dtype=fp8 \ 37 | --tensor-parallel-size 2 38 | " 39 | 40 | echo "Starting container vllm-s1" 41 | #drun --name vllm-s1 -e HIP_VISIBLE_DEVICES="0,1" $DIMG vllm serve $MODEL -tp 2 --port 8001 & 42 | drun --name vllm-s1 -e HIP_VISIBLE_DEVICES="0,1" ${DIMG} vllm serve ${MODEL} ${VARGS} --port 8001 & 43 | echo "Container vllm-s1 started" 44 | 45 | echo "Starting container vllm-s2" 46 | #drun --name vllm-s2 -e HIP_VISIBLE_DEVICES="6,7" $DIMG vllm serve $MODEL -tp 2 --port 8002 & 47 | drun --name vllm-s2 -e HIP_VISIBLE_DEVICES="6,7" ${DIMG} vllm serve ${MODEL} ${VARGS} --port 8002 & 48 | echo "Container vllm-s2 started" 49 | 50 | wait 51 | 52 | echo "All containers have been started" 53 | -------------------------------------------------------------------------------- /tools/vllm_gadget/mvllm_compose_ex2.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | vllm-s1: 3 | image: ${VLLM_IMAGE:-rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6} 4 | environment: 5 | HIP_VISIBLE_DEVICES: "0,1" 6 | HIP_FORCE_DEV_KERNARG: 
"1" 7 | NCCL_MIN_NCHANNELS: "112" 8 | TORCH_BLAS_PREFER_HIPBLASLT: "1" 9 | volumes: 10 | - ${HOME}:/ws 11 | working_dir: /ws/models 12 | command: > 13 | vllm serve 14 | models--amd--Llama-3.1-70B-Instruct-FP8-KV 15 | --trust-remote-code 16 | --kv-cache-dtype=fp8 17 | --max-model-len=65536 18 | --max-num-batched-token=65536 19 | --max-num-seqs=1024 20 | --gpu-memory-utilization=0.9 21 | -tp 2 22 | --port 8001 23 | extra_hosts: 24 | - "host.docker.internal:host-gateway" 25 | sysctls: 26 | - net.ipv4.ip_forward=1 27 | deploy: 28 | resources: 29 | limits: 30 | cpus: '8.0' 31 | memory: 16g 32 | reservations: 33 | cpus: '4.0' 34 | memory: 8g 35 | security_opt: 36 | - seccomp=unconfined 37 | - apparmor=unconfined 38 | cap_add: 39 | - SYS_PTRACE 40 | - CAP_SYS_ADMIN 41 | devices: 42 | - /dev/kfd:/dev/kfd 43 | - /dev/dri:/dev/dri 44 | - /dev/mem:/dev/mem 45 | group_add: 46 | - video 47 | ports: 48 | - "8001:8001" 49 | 50 | vllm-s2: 51 | image: ${VLLM_IMAGE:-rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6} 52 | environment: 53 | HIP_VISIBLE_DEVICES: "6,7" 54 | HIP_FORCE_DEV_KERNARG: "1" 55 | NCCL_MIN_NCHANNELS: "112" 56 | TORCH_BLAS_PREFER_HIPBLASLT: "1" 57 | volumes: 58 | - ${HOME}:/ws 59 | working_dir: /ws/models 60 | command: > 61 | vllm serve 62 | models--amd--Llama-3.1-70B-Instruct-FP8-KV 63 | --trust-remote-code 64 | --kv-cache-dtype=fp8 65 | --max-model-len=65536 66 | --max-num-batched-token=65536 67 | --max-num-seqs=1024 68 | --gpu-memory-utilization=0.9 69 | -tp 2 70 | --port 8002 71 | extra_hosts: 72 | - "host.docker.internal:host-gateway" 73 | sysctls: 74 | - net.ipv4.ip_forward=1 75 | deploy: 76 | resources: 77 | limits: 78 | cpus: '8.0' 79 | memory: 16g 80 | reservations: 81 | cpus: '4.0' 82 | memory: 8g 83 | security_opt: 84 | - seccomp=unconfined 85 | - apparmor=unconfined 86 | cap_add: 87 | - SYS_PTRACE 88 | - CAP_SYS_ADMIN 89 | devices: 90 | - /dev/kfd:/dev/kfd 91 | - /dev/dri:/dev/dri 92 | - /dev/mem:/dev/mem 93 | group_add: 94 | - video 95 | ports: 96 | - "8002:8002" -------------------------------------------------------------------------------- /tools/vllm_gadget/vllm_cmd_template.md: -------------------------------------------------------------------------------- 1 | 2 | # Templates of vllm commands 3 | 4 | ## vllm container 5 | 6 | `vllm_container.sh` 7 | 8 | ```bash 9 | #!/bin/bash 10 | 11 | DIMG="rocm/vllm:rocm6.3.1_mi300_ubuntu22.04_py3.12_vllm_0.6.6" 12 | docker pull $DIMG 13 | 14 | alias drun="docker run -it --rm \ 15 | --ipc=host \ 16 | --network host \ 17 | --privileged \ 18 | --shm-size 16g \ 19 | --device=/dev/kfd \ 20 | --device=/dev/dri \ 21 | --device=/dev/mem \ 22 | --group-add video \ 23 | --cap-add=SYS_PTRACE \ 24 | --cap-add=CAP_SYS_ADMIN \ 25 | --security-opt seccomp=unconfined \ 26 | --security-opt apparmor=unconfined \ 27 | -v ${HOME}:/ws -w /ws \ 28 | " 29 | drun --name vllm-s1 -e HIP_VISIBLE_DEVICES="0,1" $DIMG 30 | ``` 31 | 32 | 33 | 34 | ## vllm serve 35 | 36 | ```shell 37 | HF_MODEL="models/models--amd--Llama-3.1-70B-Instruct-FP8-KV" 38 | 39 | vllm serve $HF_MODEL \ 40 | --trust-remote-code \ 41 | --max-model-len=65536 \ 42 | --max-num-batched-token=65536 \ 43 | --max-num-seqs=1024 \ 44 | --gpu-memory-utilization=0.9 \ 45 | --enable-prefix-caching \ 46 | --disable-log-requests \ 47 | --kv-cache-dtype=fp8 \ 48 | --tensor-parallel-size 2 49 | ``` 50 | -------------------------------------------------------------------------------- /training/LoRA_Llama-3.1.ipynb: -------------------------------------------------------------------------------- 1 | { 
2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c73b0caa-096b-45fe-b26f-032128d4334f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-tune Llama-3.1 with LoRA with AMD ROCm GPU" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "53d94b31-35f8-4c8c-af0a-8a10aa5b4c62", 14 | "metadata": {}, 15 | "source": [ 16 | "In this blog, we show you how to fine-tune Llama-3.1-8B on AMD GPU with ROCm. We use Low-Rank Adaptation of Large Language Models (LoRA) to overcome memory and computing limitations and make open-source large language models (LLMs) more accessible.\n", 17 | "\n", 18 | "## Step-by-step fine-tuning\n", 19 | "\n", 20 | "Standard (full-parameter) fine-tuning involves considering all parameters. It requires significant computational power to manage optimizer states and gradient check-pointing. The resulting memory footprint is typically about four times larger than the model itself.\n", 21 | "\n", 22 | "To overcome this memory limitation, you can use a parameter-efficient fine-tuning (PEFT) technique, such as LoRA.\n", 23 | "\n", 24 | "\n", 25 | "Our setup:\n", 26 | "\n", 27 | "- Hardware: AMD ROCm GPU (MI325X, MI300X, etc) [device list](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html)\n", 28 | "- Software:\n", 29 | " - ROCm 6.0+\n", 30 | " - Pytorch 2.0.1+\n", 31 | " - Libraries: transformers, accelerate, peft, trl, scipy" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "id": "4926a00e-7805-4de6-bb72-43db16ac09a2", 38 | "metadata": {}, 39 | "source": [ 40 | "### Step 0: Setup ROCm environment\n", 41 | "\n", 42 | "The easyway is to use ROCm docker image from https://hub.docker.com/r/rocm/pytorch. I use TAG `rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2`.\n", 43 | "\n", 44 | "```bash\n", 45 | "$docker pull rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 46 | "```\n", 47 | "\n", 48 | "And here is my docker start command as your reference.\n", 49 | "\n", 50 | "```bash\n", 51 | "$alias drun='docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G --hostname=ROCm-FT -v /DATA:/DATA -w /DATA'\n", 52 | "\n", 53 | "$drun rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 54 | "```\n", 55 | "\n", 56 | "To run this jupyter notebook, you may install it by `pip install jupyter-lab`" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "d7b72aff-616a-442c-91a2-b9974ab95255", 62 | "metadata": {}, 63 | "source": [ 64 | "### Step 1: Getting started\n", 65 | "\n", 66 | "First, let’s confirm the availability of the GPU." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "983ea9f6-1d12-4d99-9135-00d79e4dd1ef", 72 | "metadata": {}, 73 | "source": [ 74 | "Next, install the required libraries." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 1, 80 | "id": "4f4a7d77-c810-44ba-b70d-4661139f1cf9", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n", 88 | "\n", 89 | "============================ ROCm System Management Interface ============================\n", 90 | "====================================== Product Info ======================================\n", 91 | "GPU[0]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 92 | "GPU[0]\t\t: Card model: \t\t0x74a1\n", 93 | "GPU[0]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. 
[AMD/ATI]\n", 94 | "GPU[0]\t\t: Card SKU: \t\tM3000100\n", 95 | "GPU[1]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 96 | "GPU[1]\t\t: Card model: \t\t0x74a1\n", 97 | "GPU[1]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 98 | "GPU[1]\t\t: Card SKU: \t\tM3000100\n", 99 | "GPU[2]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 100 | "GPU[2]\t\t: Card model: \t\t0x74a1\n", 101 | "GPU[2]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 102 | "GPU[2]\t\t: Card SKU: \t\tM3000100\n", 103 | "GPU[3]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 104 | "GPU[3]\t\t: Card model: \t\t0x74a1\n", 105 | "GPU[3]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 106 | "GPU[3]\t\t: Card SKU: \t\tM3000100\n", 107 | "GPU[4]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 108 | "GPU[4]\t\t: Card model: \t\t0x74a1\n", 109 | "GPU[4]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 110 | "GPU[4]\t\t: Card SKU: \t\tM3000100\n", 111 | "GPU[5]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 112 | "GPU[5]\t\t: Card model: \t\t0x74a1\n", 113 | "GPU[5]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 114 | "GPU[5]\t\t: Card SKU: \t\tM3000100\n", 115 | "GPU[6]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 116 | "GPU[6]\t\t: Card model: \t\t0x74a1\n", 117 | "GPU[6]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 118 | "GPU[6]\t\t: Card SKU: \t\tM3000100\n", 119 | "GPU[7]\t\t: Card series: \t\tAMD Instinct MI300X OAM\n", 120 | "GPU[7]\t\t: Card model: \t\t0x74a1\n", 121 | "GPU[7]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 122 | "GPU[7]\t\t: Card SKU: \t\tM3000100\n", 123 | "==========================================================================================\n", 124 | "================================== End of ROCm SMI Log ===================================\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "!rocm-smi --showproductname" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 2, 135 | "id": "f0368a37-265e-49e0-829b-2dfcbea23527", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", 143 | "\u001b[0m" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "!pip install -q pandas peft==0.14.0 transformers==4.47.1 trl==0.13.0 accelerate==1.2.1 scipy tensorboardX" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 3, 154 | "id": "5f85ff88-a4d5-4685-b1a2-96dcd48fc653", 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "peft 0.14.0\n", 162 | "transformers 4.47.1\n", 163 | "accelerate 1.2.1\n", 164 | "trl 0.13.0\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "%%bash\n", 170 | "pip list | grep peft\n", 171 | "pip list | grep transformer\n", 172 | "pip list | grep accelerate\n", 173 | "pip list | grep trl" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "61cb55cf-7f2d-45c6-9c5c-86a82ca4c9c6", 179 | "metadata": {}, 180 | "source": [ 181 | "#### Check and Set GPUs for fine-tuning" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 5, 187 | "id": "27fb94e7-c059-4883-97dc-c36546e65236", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stdout", 192 | "output_type": "stream", 193 | "text": [ 194 | "PyTorch detected number of availabel devices: 2\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "import os\n", 200 | "import torch\n", 201 | "# set visible gpus as need\n", 202 | "#gpus = [0, 1, 2, 3]\n", 203 | "gpus = [6, 7]\n", 204 | "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", ','.join(map(str, gpus)))\n", 205 | "print(f\"PyTorch detected number of availabel devices: {torch.cuda.device_count()}\")" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "id": "85eb3982-f46e-47e5-b4f7-f9fbf873a2fc", 211 | "metadata": {}, 212 | "source": [ 213 | "#### Import the required packages" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 6, 219 | "id": "26886732-b369-495f-8b6b-decdf0564219", 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 227 | " from .autonotebook import tqdm as notebook_tqdm\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "from datasets import load_dataset\n", 233 | "from transformers import (\n", 234 | " AutoModelForCausalLM,\n", 235 | " AutoTokenizer,\n", 236 | " TrainingArguments,\n", 237 | " pipeline\n", 238 | ")\n", 239 | "from peft import LoraConfig\n", 240 | "from trl import SFTTrainer" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "id": "441d747b-4ef6-4969-9d15-f9834a5ee6bb", 246 | "metadata": {}, 247 | "source": [ 248 | "### Step 2: Configuring the model and data\n", 249 | "Please make sure the LLM model files has been download and use the real path in the below code cell." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 7, 255 | "id": "3c916dcd-fc94-4214-895a-9720ad3ec3ea", 256 | "metadata": { 257 | "scrolled": true 258 | }, 259 | "outputs": [ 260 | { 261 | "name": "stderr", 262 | "output_type": "stream", 263 | "text": [ 264 | "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. 
Disabling it to use alternative backends.\n", 265 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00, 2.09s/it]\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "# Model and tokenizer names\n", 271 | "base_model_name = \"/data/llm/Meta-Llama-3.1-8B/\"\n", 272 | "new_model_name = \"Llama-3.1-8B-lora\" #You can give your own name for fine tuned model\n", 273 | "\n", 274 | "# Tokenizer\n", 275 | "#llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, use_fast=True)\n", 276 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n", 277 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 278 | "llama_tokenizer.padding_side = \"right\"\n", 279 | "\n", 280 | "# Model\n", 281 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 282 | " base_model_name,\n", 283 | " device_map=\"auto\"\n", 284 | ")\n", 285 | "base_model.config.use_cache = False\n", 286 | "base_model.config.pretraining_tp = 1" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "bbe90d81-e99e-46fd-bf96-5246210f75df", 292 | "metadata": {}, 293 | "source": [ 294 | "After you have the base model, you can start fine-tuning. We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees." 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 8, 300 | "id": "235355d2-634d-4444-8497-71058a1e473a", 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "(1000, 1)\n", 308 | "{'text': '[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase \"Deez Nuts\" and film their reactions. The video quickly went viral, and the phrase became a popular meme. \\n\\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. It has been used in internet memes, in popular music, and even in politics. In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts. He gained some traction in the polls and even made appearances on national news programs.\\n\\nThe Deez Nuts meme has had a significant impact on popular culture. It has become a recognizable catchphrase that people use to add humor to everyday conversations. The meme has also been used to satirize politics and other serious issues. For example, in 2016, a group of activists in the UK used the phrase \"Deez Nuts for President\" as part of a campaign to encourage young people to vote in the EU referendum. [INST] Rewrite the essay in a more casual way. 
Instead of sounding proffesional, sound like a college student who is forced to write the essay but refuses to do so in the propper way. Use casual words and slang when possible. [/INST] Yo, so you want me to write a 1000-word essay about Deez Nuts? Alright, fine. So, this whole thing started on Vine back in 2015. Some dude named Rodney Bullard made a video where he would ask people if they knew a rapper, and when they said no, he would hit them with the classic line: \"Deez Nuts!\" People loved it, and it became a viral meme.\\n\\nNowadays, Deez Nuts is used for all kinds of stuff. You can throw it out there to interrupt someone or just to be funny. It\\'s all over the internet, in music, and even in politics. In fact, during the 2016 US presidential election, a kid named Brady Olson registered as an independent candidate under the name Deez Nuts. He actually got some attention from the media and made appearances on TV and everything.\\n\\nThe impact of Deez Nuts on our culture is pretty huge. It\\'s become a thing that everyone knows and uses to add some humor to their everyday conversations. Plus, people have used it to make fun of politics and serious issues too. Like, in the UK, some groups of activists used the phrase \"Deez Nuts for President\" to encourage young people to vote in the EU referendum.\\n\\nThere you have it, a thousand words about Deez Nuts in a more casual tone. Can I go back to playing video games now? '}\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "# Dataset\n", 314 | "data_name = \"mlabonne/guanaco-llama2-1k\"\n", 315 | "training_data = load_dataset(data_name, split=\"train\")\n", 316 | "# check the data\n", 317 | "print(training_data.shape)\n", 318 | "# #11 is a QA sample in English\n", 319 | "print(training_data[11])" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "4ce556ea-9525-46ab-bcf2-1fe2e8c319d2", 325 | "metadata": {}, 326 | "source": [ 327 | "### Step 3: Start fine-tuning\n", 328 | "To set your training parameters, use the following code:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 9, 334 | "id": "d525b641-c645-4987-9ae6-173d6a75e292", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "# Training Params\n", 339 | "train_params = TrainingArguments(\n", 340 | " output_dir=\"./results_lora\",\n", 341 | " num_train_epochs=1,\n", 342 | " per_device_train_batch_size=4,\n", 343 | " gradient_accumulation_steps=1,\n", 344 | " optim=\"adamw_torch\",\n", 345 | " save_steps=50,\n", 346 | " logging_steps=50,\n", 347 | " learning_rate=4e-5,\n", 348 | " weight_decay=0.001,\n", 349 | " fp16=False,\n", 350 | " bf16=False,\n", 351 | " max_grad_norm=0.3,\n", 352 | " max_steps=-1,\n", 353 | " warmup_ratio=0.03,\n", 354 | " group_by_length=True,\n", 355 | " lr_scheduler_type=\"constant\",\n", 356 | " report_to=\"tensorboard\"\n", 357 | ")" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "6f35f6b7-4a2f-467c-b1fd-e53bf7dcd837", 363 | "metadata": {}, 364 | "source": [ 365 | "**NOTE**:You may decrease the per_device_train_batch_size if got OOM. Use rocm-smi to monitor the VRAM usage when running the finetuning." 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "d0d9d33c-95b2-467c-8212-96a7810b5e3c", 371 | "metadata": {}, 372 | "source": [ 373 | "**Training with LoRA configuration**\n", 374 | "\n", 375 | "Now you can integrate LoRA into the base model and assess its additional parameters. 
LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights." 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 10, 381 | "id": "6f3eecd4-a968-42dd-89e2-f94b4633ecb7", 382 | "metadata": {}, 383 | "outputs": [ 384 | { 385 | "name": "stdout", 386 | "output_type": "stream", 387 | "text": [ 388 | "trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424\n" 389 | ] 390 | } 391 | ], 392 | "source": [ 393 | "from peft import get_peft_model\n", 394 | "# LoRA Config\n", 395 | "peft_parameters = LoraConfig(\n", 396 | " lora_alpha=8,\n", 397 | " lora_dropout=0.1,\n", 398 | " r=8,\n", 399 | " bias=\"none\",\n", 400 | " task_type=\"CAUSAL_LM\"\n", 401 | ")\n", 402 | "model = get_peft_model(base_model, peft_parameters)\n", 403 | "model.print_trainable_parameters()" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "id": "9ab9baef-c9a4-4438-aa88-cbcdcec199b9", 409 | "metadata": {}, 410 | "source": [ 411 | "Above show the trainalbe parameters in percent which is a tiny portion of the original model. This is the percentage we’ll update through fine-tuning, as follows." 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 11, 417 | "id": "444a176a-a893-46ee-8df9-b9bc4ddca1cb", 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "data": { 422 | "text/html": [ 423 | "\n", 424 | "
\n", 425 | " \n", 426 | " \n", 427 | " [250/250 03:09, Epoch 1/1]\n", 428 | "
\n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | "
StepTraining Loss
501.618200
1001.394000
1501.306500
2001.356200
2501.354300

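For reference, the figure of 3,407,872 trainable parameters reported by `print_trainable_parameters()` above can be reproduced by hand. The following is a minimal sketch, assuming PEFT's default LoRA target modules for Llama (`q_proj` and `v_proj`) and the published Meta-Llama-3.1-8B shapes; each adapted weight of shape out×in gains two small matrices B (out×r) and A (r×in), i.e. r·(in+out) extra parameters:

```python
# Sketch: reproduce the trainable-parameter count printed by print_trainable_parameters().
# Assumption: PEFT's default Llama target modules (q_proj, v_proj) and Llama-3.1-8B shapes.
hidden_size = 4096                      # model hidden dimension
num_layers = 32                         # transformer blocks
num_heads, num_kv_heads, head_dim = 32, 8, 128
r = 8                                   # LoRA rank from LoraConfig above

q_in, q_out = hidden_size, num_heads * head_dim       # q_proj: 4096 -> 4096
v_in, v_out = hidden_size, num_kv_heads * head_dim    # v_proj: 4096 -> 1024 (grouped-query attention)

per_layer = r * (q_in + q_out) + r * (v_in + v_out)   # LoRA adds r*(in+out) params per adapted weight
total = per_layer * num_layers
print(f"{total:,}")                                   # 3,407,872
print(f"{100 * total / 8_033_669_120:.4f} %")         # ~0.0424 % of all parameters
```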
" 459 | ], 460 | "text/plain": [ 461 | "" 462 | ] 463 | }, 464 | "metadata": {}, 465 | "output_type": "display_data" 466 | }, 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "TrainOutput(global_step=250, training_loss=1.4058232421875, metrics={'train_runtime': 191.0616, 'train_samples_per_second': 5.234, 'train_steps_per_second': 1.308, 'total_flos': 1.6854644828110848e+16, 'train_loss': 1.4058232421875, 'epoch': 1.0})" 471 | ] 472 | }, 473 | "execution_count": 11, 474 | "metadata": {}, 475 | "output_type": "execute_result" 476 | } 477 | ], 478 | "source": [ 479 | "# Trainer with LoRA configuration\n", 480 | "fine_tuning = SFTTrainer(\n", 481 | " model=base_model,\n", 482 | " train_dataset=training_data,\n", 483 | " peft_config=peft_parameters,\n", 484 | " #dataset_text_field=\"text\",\n", 485 | " #tokenizer=llama_tokenizer,\n", 486 | " args=train_params\n", 487 | ")\n", 488 | "\n", 489 | "# Training\n", 490 | "fine_tuning.train()" 491 | ] 492 | }, 493 | { 494 | "cell_type": "raw", 495 | "id": "38005e08-f1e8-4cde-a139-b18e74e41bb8", 496 | "metadata": {}, 497 | "source": [ 498 | "To save your model, run this code:" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 12, 504 | "id": "2c0830b9-5f74-4d6f-bb33-355332c7e499", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "# Save Model\n", 509 | "fine_tuning.model.save_pretrained(new_model_name)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "id": "996f3834-0c06-4907-83cf-8f8ff9348f88", 515 | "metadata": {}, 516 | "source": [ 517 | "#### Checking memory usage during training with LoRA\n", 518 | "During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:" 519 | ] 520 | }, 521 | { 522 | "cell_type": "raw", 523 | "id": "9a05e31c-b5b5-4e58-ab2a-5b8004a3bec3", 524 | "metadata": {}, 525 | "source": [ 526 | "$rocm-smi" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "id": "f42cf9ef-b529-4053-8f8e-f7cc80f34c33", 532 | "metadata": {}, 533 | "source": [ 534 | "To facilitate a comparison between fine-tuning with and without LoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics." 
535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "id": "f7312872-12a7-4472-8111-a2ff43ad05b5", 540 | "metadata": {}, 541 | "source": [ 542 | "### Step 4: Test the fine-tuned model with LoRA\n", 543 | "\n", 544 | "To test your model, run the following code:" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 13, 550 | "id": "a4912dbf-3670-495e-b630-8d8474338e0a", 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "import torch\n", 555 | "from transformers import (\n", 556 | " AutoModelForCausalLM,\n", 557 | " AutoTokenizer,\n", 558 | " TrainingArguments,\n", 559 | " pipeline\n", 560 | ")" 561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "id": "419366ea-1aaf-4648-8350-bb2eea033c03", 566 | "metadata": {}, 567 | "source": [ 568 | "#### Fine-tuned Model Inference" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 14, 574 | "id": "0451b4ef-24b8-419e-816f-33eaac534dd4", 575 | "metadata": {}, 576 | "outputs": [ 577 | { 578 | "name": "stderr", 579 | "output_type": "stream", 580 | "text": [ 581 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.11s/it]\n" 582 | ] 583 | } 584 | ], 585 | "source": [ 586 | "# Reload model in FP16 and merge it with LoRA weights\n", 587 | "\n", 588 | "#base_model_name = \"/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B/\"\n", 589 | "#new_model_name = \"Llama-3.1-8B-lora\" #You can give your own name for fine tuned model\n", 590 | "\n", 591 | "base_model = AutoModelForCausalLM.from_pretrained(base_model_name)\n", 592 | "from peft import LoraConfig, PeftModel\n", 593 | "peft_model = PeftModel.from_pretrained(base_model, new_model_name)\n", 594 | "peft_model = peft_model.merge_and_unload()\n", 595 | "\n", 596 | "# Reload tokenizer to save it\n", 597 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 598 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 599 | "llama_tokenizer.padding_side = \"right\"" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 15, 605 | "id": "10ff3f8f-24a1-42f9-82e2-200e74da6966", 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stderr", 610 | "output_type": "stream", 611 | "text": [ 612 | "Device set to use cuda:0\n" 613 | ] 614 | } 615 | ], 616 | "source": [ 617 | "pipeline = pipeline(\n", 618 | " \"text-generation\", \n", 619 | " model=peft_model, \n", 620 | " tokenizer=llama_tokenizer,\n", 621 | " device_map=\"auto\"\n", 622 | ")" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 16, 628 | "id": "4f4f91f6-5508-4f80-9bad-658c84141411", 629 | "metadata": {}, 630 | "outputs": [ 631 | { 632 | "name": "stdout", 633 | "output_type": "stream", 634 | "text": [ 635 | "[INST] What do you think is the most important part of building an AI chatbot? 
[/INST] The most important part of building an AI chatbot is the ability to understand and respond to natural language\n" 636 | ] 637 | } 638 | ], 639 | "source": [ 640 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 641 | "output = pipeline(f\"[INST] {query} [/INST]\")\n", 642 | "print(output[0]['generated_text'])" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "id": "c6abb4ea-93bb-4e24-beef-57a080501aa8", 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [] 652 | } 653 | ], 654 | "metadata": { 655 | "kernelspec": { 656 | "display_name": "Python 3 (ipykernel)", 657 | "language": "python", 658 | "name": "python3" 659 | }, 660 | "language_info": { 661 | "codemirror_mode": { 662 | "name": "ipython", 663 | "version": 3 664 | }, 665 | "file_extension": ".py", 666 | "mimetype": "text/x-python", 667 | "name": "python", 668 | "nbconvert_exporter": "python", 669 | "pygments_lexer": "ipython3", 670 | "version": "3.10.14" 671 | } 672 | }, 673 | "nbformat": 4, 674 | "nbformat_minor": 5 675 | } 676 | -------------------------------------------------------------------------------- /training/QLoRA_Llama-3.1-10epochs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c73b0caa-096b-45fe-b26f-032128d4334f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-tune Llama-3.1 with LoRA with AMD ROCm GPU" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "53d94b31-35f8-4c8c-af0a-8a10aa5b4c62", 14 | "metadata": {}, 15 | "source": [ 16 | "In this blog, we show you how to fine-tune Llama-3.1-8B on AMD GPU with ROCm. We use Low-Rank Adaptation of Large Language Models (LoRA) to overcome memory and computing limitations and make open-source large language models (LLMs) more accessible.\n", 17 | "\n", 18 | "## Step-by-step fine-tuning\n", 19 | "\n", 20 | "Standard (full-parameter) fine-tuning involves considering all parameters. It requires significant computational power to manage optimizer states and gradient check-pointing. The resulting memory footprint is typically about four times larger than the model itself.\n", 21 | "\n", 22 | "To overcome this memory limitation, you can use a parameter-efficient fine-tuning (PEFT) technique, such as LoRA.\n", 23 | "\n", 24 | "\n", 25 | "Our setup:\n", 26 | "\n", 27 | "- Hardware: AMD ROCm GPU (MI325X, MI300X, etc) [device list](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html)\n", 28 | "- Software:\n", 29 | " - ROCm 6.0+\n", 30 | " - Pytorch 2.0.1+\n", 31 | " - Libraries: transformers, accelerate, peft, trl, bitsandbytes, scipy" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "id": "4926a00e-7805-4de6-bb72-43db16ac09a2", 38 | "metadata": {}, 39 | "source": [ 40 | "### Step 0: Setup ROCm environment\n", 41 | "\n", 42 | "The easyway is to use ROCm docker image from https://hub.docker.com/r/rocm/pytorch. 
I use TAG `rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2`.\n", 43 | "\n", 44 | "```bash\n", 45 | "$docker pull rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 46 | "```\n", 47 | "\n", 48 | "And here is my docker start command as your reference.\n", 49 | "\n", 50 | "```bash\n", 51 | "$alias drun='docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G --hostname=ROCm-FT -v /DATA:/DATA -w /DATA'\n", 52 | "\n", 53 | "$drun rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 54 | "```\n", 55 | "\n", 56 | "To run this jupyter notebook, you may install it by `pip install jupyter-lab`" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "d7b72aff-616a-442c-91a2-b9974ab95255", 62 | "metadata": {}, 63 | "source": [ 64 | "### Step 1: Getting started\n", 65 | "\n", 66 | "First, let’s confirm the availability of the GPU." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "983ea9f6-1d12-4d99-9135-00d79e4dd1ef", 72 | "metadata": {}, 73 | "source": [ 74 | "Next, install the required libraries." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 1, 80 | "id": "4f4a7d77-c810-44ba-b70d-4661139f1cf9", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n", 88 | "\n", 89 | "============================ ROCm System Management Interface ============================\n", 90 | "====================================== Product Info ======================================\n", 91 | "GPU[0]\t\t: Card series: \t\tInstinct MI210\n", 92 | "GPU[0]\t\t: Card model: \t\t0x0c34\n", 93 | "GPU[0]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 94 | "GPU[0]\t\t: Card SKU: \t\tD67301V\n", 95 | "GPU[1]\t\t: Card series: \t\tInstinct MI210\n", 96 | "GPU[1]\t\t: Card model: \t\t0x0c34\n", 97 | "GPU[1]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 98 | "GPU[1]\t\t: Card SKU: \t\tD67301V\n", 99 | "GPU[2]\t\t: Card series: \t\tInstinct MI210\n", 100 | "GPU[2]\t\t: Card model: \t\t0x0c34\n", 101 | "GPU[2]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 102 | "GPU[2]\t\t: Card SKU: \t\tD67301V\n", 103 | "GPU[3]\t\t: Card series: \t\tInstinct MI210\n", 104 | "GPU[3]\t\t: Card model: \t\t0x0c34\n", 105 | "GPU[3]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. 
[AMD/ATI]\n", 106 | "GPU[3]\t\t: Card SKU: \t\tD67301V\n", 107 | "==========================================================================================\n", 108 | "================================== End of ROCm SMI Log ===================================\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "!rocm-smi --showproductname" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "eaedeb92-bf10-4bbc-967c-57e14db7fc8b", 119 | "metadata": {}, 120 | "source": [ 121 | "!pip install -q pandas peft==0.14.0 transformers==4.47.1 trl==0.13.0 accelerate==1.2.1 scipy tensorboardX" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 2, 127 | "id": "5f85ff88-a4d5-4685-b1a2-96dcd48fc653", 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "peft 0.14.0\n", 135 | "transformers 4.47.1\n", 136 | "accelerate 1.2.1\n", 137 | "trl 0.13.0\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "%%bash\n", 143 | "pip list | grep peft\n", 144 | "pip list | grep transformer\n", 145 | "pip list | grep accelerate\n", 146 | "pip list | grep trl" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "21e8dc50-0464-44c5-8931-0a6dfe812f79", 152 | "metadata": {}, 153 | "source": [ 154 | "#### Install bitsandbytes\n", 155 | "1. Install bitsandbytes using the following code.\n", 156 | "\n", 157 | "- For ROCm 6.2" 158 | ] 159 | }, 160 | { 161 | "cell_type": "raw", 162 | "id": "3d1640f6-b63e-4041-9e77-fff1bc23550d", 163 | "metadata": {}, 164 | "source": [ 165 | "# Install `bitsandbytes`\n", 166 | "git clone --recurse https://github.com/ROCm/bitsandbytes.git\n", 167 | "cd bitsandbytes\n", 168 | "git checkout rocm6.2_internal_testing\n", 169 | "make hip\n", 170 | "python setup.py install" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "id": "69a0abf0-afad-4cc7-a46c-f832f107e1d7", 176 | "metadata": {}, 177 | "source": [ 178 | "2. Check the bitsandbytes version (0.42.0)." 
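Beyond grepping the version, it can be worth confirming that the ROCm build of bitsandbytes actually imports and can quantize on the GPU before loading an 8B model. The cell below is only a sanity sketch, not part of the original flow, and it assumes `bnb.nn.Linear4bit` is exposed by the 0.42.0 ROCm build the same way as upstream:

```python
import torch
import bitsandbytes as bnb

print(bnb.__version__)            # expect 0.42.0 from the build above
print(torch.cuda.is_available())  # ROCm GPUs are exposed through the torch.cuda API

# Tiny NF4 layer as a smoke test (assumes bnb.nn.Linear4bit exists in this build).
layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4").cuda()
x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
print(layer(x).shape)             # torch.Size([2, 64])
```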
179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 3, 184 | "id": "cd5e772a-bfcb-4c7f-9457-dc22f631b6cf", 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "bitsandbytes 0.42.0\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "%%bash\n", 197 | "pip list | grep bitsandbytes" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "61cb55cf-7f2d-45c6-9c5c-86a82ca4c9c6", 203 | "metadata": {}, 204 | "source": [ 205 | "#### Check and Set GPUs for fine-tuning" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 4, 211 | "id": "27fb94e7-c059-4883-97dc-c36546e65236", 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "PyTorch detected number of availabel devices: 4\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "import os\n", 224 | "import torch\n", 225 | "# set visible gpus as need\n", 226 | "gpus = [0, 1, 2, 3]\n", 227 | "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", ','.join(map(str, gpus)))\n", 228 | "print(f\"PyTorch detected number of availabel devices: {torch.cuda.device_count()}\")" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "id": "85eb3982-f46e-47e5-b4f7-f9fbf873a2fc", 234 | "metadata": {}, 235 | "source": [ 236 | "#### Import the required packages" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 5, 242 | "id": "26886732-b369-495f-8b6b-decdf0564219", 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stderr", 247 | "output_type": "stream", 248 | "text": [ 249 | "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 250 | " from .autonotebook import tqdm as notebook_tqdm\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "from datasets import load_dataset\n", 256 | "from transformers import (\n", 257 | " AutoModelForCausalLM,\n", 258 | " AutoTokenizer,\n", 259 | " BitsAndBytesConfig,\n", 260 | " TrainingArguments,\n", 261 | " pipeline\n", 262 | ")\n", 263 | "from peft import LoraConfig\n", 264 | "from trl import SFTTrainer" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "id": "441d747b-4ef6-4969-9d15-f9834a5ee6bb", 270 | "metadata": {}, 271 | "source": [ 272 | "### Step 2: Configuring the model and data\n", 273 | "Please make sure the LLM model files has been download and use the real path in the below code cell." 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 6, 279 | "id": "3c916dcd-fc94-4214-895a-9720ad3ec3ea", 280 | "metadata": { 281 | "scrolled": true 282 | }, 283 | "outputs": [ 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. 
Disabling it to use alternative backends.\n", 289 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00, 1.16s/it]\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "# Model and tokenizer names\n", 295 | "base_model_name = \"/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B/\"\n", 296 | "new_model_name = \"Llama-3.1-8B-qlora\" #You can give your own name for fine tuned model\n", 297 | "\n", 298 | "# Tokenizer\n", 299 | "#llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, use_fast=True)\n", 300 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n", 301 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 302 | "llama_tokenizer.padding_side = \"right\"\n", 303 | "\n", 304 | "bnb_config = BitsAndBytesConfig(\n", 305 | " load_in_4bit=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=\"float16\", bnb_4bit_use_double_quant=True\n", 306 | ")\n", 307 | "\n", 308 | "# Model\n", 309 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 310 | " base_model_name,\n", 311 | " device_map=\"auto\",\n", 312 | " quantization_config=bnb_config,\n", 313 | " trust_remote_code=True\n", 314 | ")\n", 315 | "base_model.config.use_cache = False\n", 316 | "base_model.config.pretraining_tp = 1" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "id": "bbe90d81-e99e-46fd-bf96-5246210f75df", 322 | "metadata": {}, 323 | "source": [ 324 | "After you have the base model, you can start fine-tuning. We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees." 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 7, 330 | "id": "235355d2-634d-4444-8497-71058a1e473a", 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "(1000, 1)\n", 338 | "{'text': '[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase \"Deez Nuts\" and film their reactions. The video quickly went viral, and the phrase became a popular meme. \\n\\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. It has been used in internet memes, in popular music, and even in politics. In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts. He gained some traction in the polls and even made appearances on national news programs.\\n\\nThe Deez Nuts meme has had a significant impact on popular culture. It has become a recognizable catchphrase that people use to add humor to everyday conversations. The meme has also been used to satirize politics and other serious issues. 
For example, in 2016, a group of activists in the UK used the phrase \"Deez Nuts for President\" as part of a campaign to encourage young people to vote in the EU referendum. [INST] Rewrite the essay in a more casual way. Instead of sounding proffesional, sound like a college student who is forced to write the essay but refuses to do so in the propper way. Use casual words and slang when possible. [/INST] Yo, so you want me to write a 1000-word essay about Deez Nuts? Alright, fine. So, this whole thing started on Vine back in 2015. Some dude named Rodney Bullard made a video where he would ask people if they knew a rapper, and when they said no, he would hit them with the classic line: \"Deez Nuts!\" People loved it, and it became a viral meme.\\n\\nNowadays, Deez Nuts is used for all kinds of stuff. You can throw it out there to interrupt someone or just to be funny. It\\'s all over the internet, in music, and even in politics. In fact, during the 2016 US presidential election, a kid named Brady Olson registered as an independent candidate under the name Deez Nuts. He actually got some attention from the media and made appearances on TV and everything.\\n\\nThe impact of Deez Nuts on our culture is pretty huge. It\\'s become a thing that everyone knows and uses to add some humor to their everyday conversations. Plus, people have used it to make fun of politics and serious issues too. Like, in the UK, some groups of activists used the phrase \"Deez Nuts for President\" to encourage young people to vote in the EU referendum.\\n\\nThere you have it, a thousand words about Deez Nuts in a more casual tone. Can I go back to playing video games now? '}\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "# Dataset\n", 344 | "data_name = \"mlabonne/guanaco-llama2-1k\"\n", 345 | "training_data = load_dataset(data_name, split=\"train\")\n", 346 | "# check the data\n", 347 | "print(training_data.shape)\n", 348 | "# #11 is a QA sample in English\n", 349 | "print(training_data[11])" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "id": "4ce556ea-9525-46ab-bcf2-1fe2e8c319d2", 355 | "metadata": {}, 356 | "source": [ 357 | "### Step 3: Start fine-tuning\n", 358 | "To set your training parameters, use the following code:" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 8, 364 | "id": "d525b641-c645-4987-9ae6-173d6a75e292", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "# Training Params\n", 369 | "train_params = TrainingArguments(\n", 370 | " output_dir=\"./results_lora\",\n", 371 | " num_train_epochs=10,\n", 372 | " per_device_train_batch_size=4,\n", 373 | " gradient_accumulation_steps=1,\n", 374 | " optim=\"paged_adamw_32bit\",\n", 375 | " save_steps=50,\n", 376 | " logging_steps=50,\n", 377 | " learning_rate=4e-5,\n", 378 | " weight_decay=0.001,\n", 379 | " fp16=False,\n", 380 | " bf16=False,\n", 381 | " max_grad_norm=0.3,\n", 382 | " max_steps=-1,\n", 383 | " warmup_ratio=0.03,\n", 384 | " group_by_length=True,\n", 385 | " lr_scheduler_type=\"constant\",\n", 386 | " report_to=\"tensorboard\"\n", 387 | ")" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "id": "6f35f6b7-4a2f-467c-b1fd-e53bf7dcd837", 393 | "metadata": {}, 394 | "source": [ 395 | "**NOTE**:You may decrease the per_device_train_batch_size if got OOM. Use rocm-smi to monitor the VRAM usage when running the finetuning." 
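If you do hit OOM and lower `per_device_train_batch_size`, you can raise `gradient_accumulation_steps` by the same factor so the effective batch size (per-device batch × accumulation steps × number of GPUs) stays unchanged. A hedged sketch of the adjusted arguments, keeping everything else identical to `train_params` above:

```python
# Sketch: keep the effective batch size constant while halving the per-device batch.
# effective batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
from transformers import TrainingArguments

low_mem_params = TrainingArguments(
    output_dir="./results_lora",
    num_train_epochs=10,
    per_device_train_batch_size=2,    # was 4
    gradient_accumulation_steps=2,    # was 1
    optim="paged_adamw_32bit",
    save_steps=50,
    logging_steps=50,
    learning_rate=4e-5,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to="tensorboard"
)
```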
396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "id": "d0d9d33c-95b2-467c-8212-96a7810b5e3c", 401 | "metadata": {}, 402 | "source": [ 403 | "**Training with LoRA configuration**\n", 404 | "\n", 405 | "Now you can integrate LoRA into the base model and assess its additional parameters. LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights." 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 9, 411 | "id": "6f3eecd4-a968-42dd-89e2-f94b4633ecb7", 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "name": "stdout", 416 | "output_type": "stream", 417 | "text": [ 418 | "trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424\n" 419 | ] 420 | } 421 | ], 422 | "source": [ 423 | "from peft import get_peft_model\n", 424 | "# LoRA Config\n", 425 | "peft_parameters = LoraConfig(\n", 426 | " lora_alpha=8,\n", 427 | " lora_dropout=0.1,\n", 428 | " r=8,\n", 429 | " bias=\"none\",\n", 430 | " task_type=\"CAUSAL_LM\"\n", 431 | ")\n", 432 | "model = get_peft_model(base_model, peft_parameters)\n", 433 | "model.print_trainable_parameters()" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "id": "9ab9baef-c9a4-4438-aa88-cbcdcec199b9", 439 | "metadata": {}, 440 | "source": [ 441 | "Above show the trainalbe parameters in percent which is a tiny portion of the original model. This is the percentage we’ll update through fine-tuning, as follows." 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 10, 447 | "id": "444a176a-a893-46ee-8df9-b9bc4ddca1cb", 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "data": { 452 | "text/html": [ 453 | "\n", 454 | "

\n", 455 | " \n", 456 | " \n", 457 | " [2500/2500 1:31:20, Epoch 10/10]\n", 458 | "
\n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | "
StepTraining Loss
501.655600
1001.424400
1501.337400
2001.387800
2501.383900
3001.345500
3501.336700
4001.327900
4501.378700
5001.347700
5501.331600
6001.286200
6501.359600
7001.322700
7501.338200
8001.296900
8501.244300
9001.310100
9501.369600
10001.316200
10501.292300
11001.260800
11501.280100
12001.277000
12501.320300
13001.231800
13501.273100
14001.216500
14501.260300
15001.340600
15501.195200
16001.264000
16501.294900
17001.272100
17501.181300
18001.228300
18501.224200
19001.208300
19501.205300
20001.183900
20501.233200
21001.158900
21501.151300
22001.170000
22501.163700
23001.178300
23501.113900
24001.209200
24501.126500
25001.150100

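The NOTE above suggests watching VRAM with `rocm-smi` in a separate terminal; the memory held by PyTorch itself can also be read from inside the notebook, since ROCm builds of PyTorch expose device memory through the usual `torch.cuda` API. A minimal sketch — the figures will be lower than what `rocm-smi` reports, which covers the whole card rather than just PyTorch's allocator:

```python
import torch

# Per-device memory held by PyTorch's allocator (GiB); rocm-smi shows total card usage instead.
for i in range(torch.cuda.device_count()):
    allocated = torch.cuda.memory_allocated(i) / 2**30
    peak = torch.cuda.max_memory_allocated(i) / 2**30
    print(f"GPU[{i}] {torch.cuda.get_device_name(i)}: "
          f"allocated={allocated:.1f} GiB, peak={peak:.1f} GiB")
```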
" 669 | ], 670 | "text/plain": [ 671 | "" 672 | ] 673 | }, 674 | "metadata": {}, 675 | "output_type": "display_data" 676 | }, 677 | { 678 | "data": { 679 | "text/plain": [ 680 | "TrainOutput(global_step=2500, training_loss=1.2753287063598633, metrics={'train_runtime': 5486.2394, 'train_samples_per_second': 1.823, 'train_steps_per_second': 0.456, 'total_flos': 1.683891336976466e+17, 'train_loss': 1.2753287063598633, 'epoch': 10.0})" 681 | ] 682 | }, 683 | "execution_count": 10, 684 | "metadata": {}, 685 | "output_type": "execute_result" 686 | } 687 | ], 688 | "source": [ 689 | "# Trainer with LoRA configuration\n", 690 | "fine_tuning = SFTTrainer(\n", 691 | " model=base_model,\n", 692 | " train_dataset=training_data,\n", 693 | " peft_config=peft_parameters,\n", 694 | " #dataset_text_field=\"text\",\n", 695 | " #tokenizer=llama_tokenizer,\n", 696 | " args=train_params\n", 697 | ")\n", 698 | "\n", 699 | "# Training\n", 700 | "fine_tuning.train()" 701 | ] 702 | }, 703 | { 704 | "cell_type": "raw", 705 | "id": "38005e08-f1e8-4cde-a139-b18e74e41bb8", 706 | "metadata": {}, 707 | "source": [ 708 | "To save your model, run this code:" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 11, 714 | "id": "2c0830b9-5f74-4d6f-bb33-355332c7e499", 715 | "metadata": {}, 716 | "outputs": [], 717 | "source": [ 718 | "# Save Model\n", 719 | "fine_tuning.model.save_pretrained(new_model_name)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "id": "996f3834-0c06-4907-83cf-8f8ff9348f88", 725 | "metadata": {}, 726 | "source": [ 727 | "#### Checking memory usage during training with LoRA\n", 728 | "During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:" 729 | ] 730 | }, 731 | { 732 | "cell_type": "raw", 733 | "id": "9a05e31c-b5b5-4e58-ab2a-5b8004a3bec3", 734 | "metadata": {}, 735 | "source": [ 736 | "$rocm-smi" 737 | ] 738 | }, 739 | { 740 | "cell_type": "markdown", 741 | "id": "f42cf9ef-b529-4053-8f8e-f7cc80f34c33", 742 | "metadata": {}, 743 | "source": [ 744 | "To facilitate a comparison between fine-tuning with and without LoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics." 
745 | ] 746 | }, 747 | { 748 | "cell_type": "markdown", 749 | "id": "f7312872-12a7-4472-8111-a2ff43ad05b5", 750 | "metadata": {}, 751 | "source": [ 752 | "### Step 4: Test the fine-tuned model with LoRA\n", 753 | "\n", 754 | "To test your model, run the following code:" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": 12, 760 | "id": "a4912dbf-3670-495e-b630-8d8474338e0a", 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [ 764 | "import torch\n", 765 | "from transformers import (\n", 766 | " AutoModelForCausalLM,\n", 767 | " AutoTokenizer,\n", 768 | " BitsAndBytesConfig,\n", 769 | " TrainingArguments,\n", 770 | " pipeline\n", 771 | ")" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "id": "419366ea-1aaf-4648-8350-bb2eea033c03", 777 | "metadata": {}, 778 | "source": [ 779 | "#### Fine-tuned Model Inference" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 13, 785 | "id": "0451b4ef-24b8-419e-816f-33eaac534dd4", 786 | "metadata": {}, 787 | "outputs": [ 788 | { 789 | "name": "stderr", 790 | "output_type": "stream", 791 | "text": [ 792 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.72it/s]\n" 793 | ] 794 | } 795 | ], 796 | "source": [ 797 | "# Reload model in FP16 and merge it with LoRA weights\n", 798 | "\n", 799 | "#base_model_name = \"/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B/\"\n", 800 | "#new_model_name = \"Llama-3.1-8B-qlora\" #You can give your own name for fine tuned model\n", 801 | "\n", 802 | "base_model = AutoModelForCausalLM.from_pretrained(base_model_name)\n", 803 | "from peft import LoraConfig, PeftModel\n", 804 | "peft_model = PeftModel.from_pretrained(base_model, new_model_name)\n", 805 | "peft_model = peft_model.merge_and_unload()\n", 806 | "\n", 807 | "# Reload tokenizer to save it\n", 808 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 809 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 810 | "llama_tokenizer.padding_side = \"right\"" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 14, 816 | "id": "10ff3f8f-24a1-42f9-82e2-200e74da6966", 817 | "metadata": {}, 818 | "outputs": [ 819 | { 820 | "name": "stderr", 821 | "output_type": "stream", 822 | "text": [ 823 | "Device set to use cuda:0\n" 824 | ] 825 | } 826 | ], 827 | "source": [ 828 | "pipeline = pipeline(\n", 829 | " \"text-generation\", \n", 830 | " model=peft_model, \n", 831 | " tokenizer=llama_tokenizer,\n", 832 | " device_map=\"auto\"\n", 833 | ")" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": 15, 839 | "id": "4f4f91f6-5508-4f80-9bad-658c84141411", 840 | "metadata": {}, 841 | "outputs": [ 842 | { 843 | "name": "stdout", 844 | "output_type": "stream", 845 | "text": [ 846 | "[INST] What do you think is the most important part of building an AI chatbot? 
[/INST] The most important part of building an AI chatbot is creating a comprehensive and accurate dataset of relevant information\n" 847 | ] 848 | } 849 | ], 850 | "source": [ 851 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 852 | "output = pipeline(f\"[INST] {query} [/INST]\")\n", 853 | "print(output[0]['generated_text'])" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "id": "c6abb4ea-93bb-4e24-beef-57a080501aa8", 860 | "metadata": {}, 861 | "outputs": [], 862 | "source": [] 863 | } 864 | ], 865 | "metadata": { 866 | "kernelspec": { 867 | "display_name": "Python 3 (ipykernel)", 868 | "language": "python", 869 | "name": "python3" 870 | }, 871 | "language_info": { 872 | "codemirror_mode": { 873 | "name": "ipython", 874 | "version": 3 875 | }, 876 | "file_extension": ".py", 877 | "mimetype": "text/x-python", 878 | "name": "python", 879 | "nbconvert_exporter": "python", 880 | "pygments_lexer": "ipython3", 881 | "version": "3.10.14" 882 | } 883 | }, 884 | "nbformat": 4, 885 | "nbformat_minor": 5 886 | } 887 | -------------------------------------------------------------------------------- /training/QLoRA_Llama-3.1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c73b0caa-096b-45fe-b26f-032128d4334f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-tune Llama-3.1 with LoRA with AMD ROCm GPU" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "53d94b31-35f8-4c8c-af0a-8a10aa5b4c62", 14 | "metadata": {}, 15 | "source": [ 16 | "In this blog, we show you how to fine-tune Llama-3.1-8B on AMD GPU with ROCm. We use Low-Rank Adaptation of Large Language Models (LoRA) to overcome memory and computing limitations and make open-source large language models (LLMs) more accessible.\n", 17 | "\n", 18 | "## Step-by-step fine-tuning\n", 19 | "\n", 20 | "Standard (full-parameter) fine-tuning involves considering all parameters. It requires significant computational power to manage optimizer states and gradient check-pointing. The resulting memory footprint is typically about four times larger than the model itself.\n", 21 | "\n", 22 | "To overcome this memory limitation, you can use a parameter-efficient fine-tuning (PEFT) technique, such as LoRA.\n", 23 | "\n", 24 | "\n", 25 | "Our setup:\n", 26 | "\n", 27 | "- Hardware: AMD ROCm GPU (MI325X, MI300X, etc) [device list](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html)\n", 28 | "- Software:\n", 29 | " - ROCm 6.0+\n", 30 | " - Pytorch 2.0.1+\n", 31 | " - Libraries: transformers, accelerate, peft, trl, bitsandbytes, scipy" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "d7b72aff-616a-442c-91a2-b9974ab95255", 37 | "metadata": { 38 | "jp-MarkdownHeadingCollapsed": true 39 | }, 40 | "source": [ 41 | "### Step 1: Getting started\n", 42 | "\n", 43 | "First, let’s confirm the availability of the GPU." 44 | ] 45 | }, 46 | { 47 | "attachments": {}, 48 | "cell_type": "markdown", 49 | "id": "4926a00e-7805-4de6-bb72-43db16ac09a2", 50 | "metadata": {}, 51 | "source": [ 52 | "### Step 0: Setup ROCm environment\n", 53 | "\n", 54 | "The easyway is to use ROCm docker image from https://hub.docker.com/r/rocm/pytorch. 
I use TAG `rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2`.\n", 55 | "\n", 56 | "```bash\n", 57 | "$docker pull rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 58 | "```\n", 59 | "\n", 60 | "And here is my docker start command as your reference.\n", 61 | "\n", 62 | "```bash\n", 63 | "$alias drun='docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G --hostname=ROCm-FT -v /DATA:/DATA -w /DATA'\n", 64 | "\n", 65 | "$drun rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 66 | "```\n", 67 | "\n", 68 | "To run this jupyter notebook, you may install it by `pip install jupyter-lab`" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "983ea9f6-1d12-4d99-9135-00d79e4dd1ef", 74 | "metadata": {}, 75 | "source": [ 76 | "Next, install the required libraries." 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 1, 82 | "id": "4f4a7d77-c810-44ba-b70d-4661139f1cf9", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "\n", 90 | "\n", 91 | "============================ ROCm System Management Interface ============================\n", 92 | "====================================== Product Info ======================================\n", 93 | "GPU[0]\t\t: Card series: \t\tInstinct MI210\n", 94 | "GPU[0]\t\t: Card model: \t\t0x0c34\n", 95 | "GPU[0]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 96 | "GPU[0]\t\t: Card SKU: \t\tD67301V\n", 97 | "GPU[1]\t\t: Card series: \t\tInstinct MI210\n", 98 | "GPU[1]\t\t: Card model: \t\t0x0c34\n", 99 | "GPU[1]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 100 | "GPU[1]\t\t: Card SKU: \t\tD67301V\n", 101 | "GPU[2]\t\t: Card series: \t\tInstinct MI210\n", 102 | "GPU[2]\t\t: Card model: \t\t0x0c34\n", 103 | "GPU[2]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 104 | "GPU[2]\t\t: Card SKU: \t\tD67301V\n", 105 | "GPU[3]\t\t: Card series: \t\tInstinct MI210\n", 106 | "GPU[3]\t\t: Card model: \t\t0x0c34\n", 107 | "GPU[3]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 108 | "GPU[3]\t\t: Card SKU: \t\tD67301V\n", 109 | "==========================================================================================\n", 110 | "================================== End of ROCm SMI Log ===================================\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "!rocm-smi --showproductname" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "eaedeb92-bf10-4bbc-967c-57e14db7fc8b", 121 | "metadata": {}, 122 | "source": [ 123 | "!pip install -q pandas peft==0.14.0 transformers==4.47.1 trl==0.13.0 accelerate==1.2.1 scipy tensorboardX" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 2, 129 | "id": "5f85ff88-a4d5-4685-b1a2-96dcd48fc653", 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "peft 0.14.0\n", 137 | "transformers 4.47.1\n", 138 | "accelerate 1.2.1\n", 139 | "trl 0.13.0\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "%%bash\n", 145 | "pip list | grep peft\n", 146 | "pip list | grep transformer\n", 147 | "pip list | grep accelerate\n", 148 | "pip list | grep trl" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "id": "21e8dc50-0464-44c5-8931-0a6dfe812f79", 154 | "metadata": {}, 155 | "source": [ 156 | "#### Install bitsandbytes\n", 157 | "1. 
Install bitsandbytes using the following code.\n", 158 | "\n", 159 | "- For ROCm 6.2" 160 | ] 161 | }, 162 | { 163 | "cell_type": "raw", 164 | "id": "3d1640f6-b63e-4041-9e77-fff1bc23550d", 165 | "metadata": {}, 166 | "source": [ 167 | "# Install `bitsandbytes`\n", 168 | "git clone --recurse https://github.com/ROCm/bitsandbytes.git\n", 169 | "cd bitsandbytes\n", 170 | "git checkout rocm6.2_internal_testing\n", 171 | "make hip\n", 172 | "python setup.py install" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "69a0abf0-afad-4cc7-a46c-f832f107e1d7", 178 | "metadata": {}, 179 | "source": [ 180 | "2. Check the bitsandbytes version (0.42.0)." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 3, 186 | "id": "cd5e772a-bfcb-4c7f-9457-dc22f631b6cf", 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | "bitsandbytes 0.42.0\n" 194 | ] 195 | } 196 | ], 197 | "source": [ 198 | "%%bash\n", 199 | "pip list | grep bitsandbytes" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "61cb55cf-7f2d-45c6-9c5c-86a82ca4c9c6", 205 | "metadata": {}, 206 | "source": [ 207 | "#### Check and Set GPUs for fine-tuning" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 4, 213 | "id": "27fb94e7-c059-4883-97dc-c36546e65236", 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | "PyTorch detected number of availabel devices: 4\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "import os\n", 226 | "import torch\n", 227 | "# set visible gpus as need\n", 228 | "gpus = [0, 1, 2, 3]\n", 229 | "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", ','.join(map(str, gpus)))\n", 230 | "print(f\"PyTorch detected number of availabel devices: {torch.cuda.device_count()}\")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "85eb3982-f46e-47e5-b4f7-f9fbf873a2fc", 236 | "metadata": {}, 237 | "source": [ 238 | "#### Import the required packages" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 5, 244 | "id": "26886732-b369-495f-8b6b-decdf0564219", 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stderr", 249 | "output_type": "stream", 250 | "text": [ 251 | "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 252 | " from .autonotebook import tqdm as notebook_tqdm\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "from datasets import load_dataset\n", 258 | "from transformers import (\n", 259 | " AutoModelForCausalLM,\n", 260 | " AutoTokenizer,\n", 261 | " BitsAndBytesConfig,\n", 262 | " TrainingArguments,\n", 263 | " pipeline\n", 264 | ")\n", 265 | "from peft import LoraConfig\n", 266 | "from trl import SFTTrainer" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "441d747b-4ef6-4969-9d15-f9834a5ee6bb", 272 | "metadata": {}, 273 | "source": [ 274 | "### Step 2: Configuring the model and data\n", 275 | "Please make sure the LLM model files has been download and use the real path in the below code cell." 
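If the weights are not on disk yet, one option is to fetch them with `huggingface_hub.snapshot_download`. This is a sketch only: the repo id below is the gated Meta release, so it assumes you have accepted the license on Hugging Face and are logged in with a token, and `local_dir` should match the path used for `base_model_name` in the next cell:

```python
# Sketch: download Meta-Llama-3.1-8B into the local path used for base_model_name below.
# Assumes `pip install huggingface_hub`, an accepted license for this gated repo,
# and a valid token (e.g. via `huggingface-cli login`).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3.1-8B",   # gated; adjust if you use a different mirror
    local_dir="/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B",
)
```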
276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 6, 281 | "id": "a804704e-f5ec-4978-9589-85dd4ecb7e50", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "# Model and tokenizer names\n", 286 | "base_model_name = \"/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B/\"\n", 287 | "new_model_name = \"Llama-3.1-8B-qlora\" #You can give your own name for fine tuned model\n", 288 | "\n", 289 | "# Tokenizer\n", 290 | "#llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, use_fast=True)\n", 291 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n", 292 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 293 | "llama_tokenizer.padding_side = \"right\"" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "id": "b8e95435-cb3b-4a2e-be80-d3881ea5ad13", 299 | "metadata": {}, 300 | "source": [ 301 | "**QLoRA 4-bit quantization configuration**" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "07e43bd6-b6b0-4842-854b-4ef9491d02a5", 307 | "metadata": {}, 308 | "source": [ 309 | "As outlined in the paper, QLoRA stores weights in 4-bits, allowing computation to occur in 16 or 32-bit precision. This means whenever a QLoRA weight tensor is used, we dequantize the tensor to 16 or 32-bit precision, and then perform a matrix multiplication. Various combinations, such as float16, bfloat16, float32, etc., can be chosen. Experimentation with different 4-bit quantization variants, including normalized float 4 (NF4), or pure float4 quantization, is possible. However, guided by theoretical considerations and empirical findings from the paper, the recommendation is to opt for NF4 quantization, as it tends to deliver better performance.\n", 310 | "\n", 311 | "In our case, we chose the following configuration:\n", 312 | "\n", 313 | "4-bit quantization with NF4 type\n", 314 | "\n", 315 | "16-bit (float16) for computation\n", 316 | "\n", 317 | "Double quantization, which uses a second quantization after the first one to save an additional 0.3 bits per parameters\n", 318 | "\n", 319 | "Quantization parameters are controlled from the BitsandbytesConfig (see Hugging Face documentation) as follows:\n", 320 | "\n", 321 | "Loading in 4 bits is activated through load_in_4bit\n", 322 | "\n", 323 | "The datatype used for quantization is specified with bnb_4bit_quant_type. Note that there are two supported quantization datatypes fp4 (four-bit float) and nf4 (normal four-bit float). 
The latter is theoretically optimal for normally distributed weights, so we recommend using nf4.\n", 324 | "\n", 325 | "The datatype used for the linear layer computations with bnb_4bit_compute_dtype\n", 326 | "\n", 327 | "Nested quantization is activated through bnb_4bit_use_double_quant" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 7, 333 | "id": "c7ca2d59-199e-40fb-974f-8612cad9a8ae", 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "# Quantization Config\n", 338 | "quant_config = BitsAndBytesConfig(\n", 339 | " load_in_4bit=True,\n", 340 | " bnb_4bit_quant_type=\"nf4\",\n", 341 | " bnb_4bit_compute_dtype=torch.float16,\n", 342 | " bnb_4bit_use_double_quant=True\n", 343 | ")" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "id": "e87a7838-a6c3-42cf-a1f2-09ba9ae1a77f", 349 | "metadata": {}, 350 | "source": [ 351 | "**Load the model and set the quantization configuration**" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 8, 357 | "id": "3c916dcd-fc94-4214-895a-9720ad3ec3ea", 358 | "metadata": { 359 | "scrolled": true 360 | }, 361 | "outputs": [ 362 | { 363 | "name": "stderr", 364 | "output_type": "stream", 365 | "text": [ 366 | "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends.\n", 367 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00, 1.30s/it]\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "# Model\n", 373 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 374 | " base_model_name,\n", 375 | " device_map=\"auto\",\n", 376 | " quantization_config=quant_config,\n", 377 | " trust_remote_code=True\n", 378 | ")\n", 379 | "base_model.config.use_cache = False\n", 380 | "base_model.config.pretraining_tp = 1" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "id": "4b52bdfa-87d5-4757-8e2b-5bbf6428f824", 386 | "metadata": {}, 387 | "source": [ 388 | "**Dataset configuration**\n", 389 | "\n", 390 | "We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees." 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 9, 396 | "id": "235355d2-634d-4444-8497-71058a1e473a", 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "(1000, 1)\n", 404 | "{'text': '[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase \"Deez Nuts\" and film their reactions. The video quickly went viral, and the phrase became a popular meme. \\n\\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. 
It has been used in internet memes, in popular music, and even in politics. In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts. He gained some traction in the polls and even made appearances on national news programs.\\n\\nThe Deez Nuts meme has had a significant impact on popular culture. It has become a recognizable catchphrase that people use to add humor to everyday conversations. The meme has also been used to satirize politics and other serious issues. For example, in 2016, a group of activists in the UK used the phrase \"Deez Nuts for President\" as part of a campaign to encourage young people to vote in the EU referendum. [INST] Rewrite the essay in a more casual way. Instead of sounding proffesional, sound like a college student who is forced to write the essay but refuses to do so in the propper way. Use casual words and slang when possible. [/INST] Yo, so you want me to write a 1000-word essay about Deez Nuts? Alright, fine. So, this whole thing started on Vine back in 2015. Some dude named Rodney Bullard made a video where he would ask people if they knew a rapper, and when they said no, he would hit them with the classic line: \"Deez Nuts!\" People loved it, and it became a viral meme.\\n\\nNowadays, Deez Nuts is used for all kinds of stuff. You can throw it out there to interrupt someone or just to be funny. It\\'s all over the internet, in music, and even in politics. In fact, during the 2016 US presidential election, a kid named Brady Olson registered as an independent candidate under the name Deez Nuts. He actually got some attention from the media and made appearances on TV and everything.\\n\\nThe impact of Deez Nuts on our culture is pretty huge. It\\'s become a thing that everyone knows and uses to add some humor to their everyday conversations. Plus, people have used it to make fun of politics and serious issues too. Like, in the UK, some groups of activists used the phrase \"Deez Nuts for President\" to encourage young people to vote in the EU referendum.\\n\\nThere you have it, a thousand words about Deez Nuts in a more casual tone. Can I go back to playing video games now? 
'}\n" 405 | ] 406 | } 407 | ], 408 | "source": [ 409 | "# Dataset\n", 410 | "data_name = \"mlabonne/guanaco-llama2-1k\"\n", 411 | "training_data = load_dataset(data_name, split=\"train\")\n", 412 | "# check the data\n", 413 | "print(training_data.shape)\n", 414 | "# #11 is a QA sample in English\n", 415 | "print(training_data[11])" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "id": "4ce556ea-9525-46ab-bcf2-1fe2e8c319d2", 421 | "metadata": {}, 422 | "source": [ 423 | "### Step 3: Start fine-tuning\n", 424 | "To set your training parameters, use the following code:" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 10, 430 | "id": "d525b641-c645-4987-9ae6-173d6a75e292", 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "# Training Params\n", 435 | "train_params = TrainingArguments(\n", 436 | " output_dir=\"./results_qloar\",\n", 437 | " num_train_epochs=1,\n", 438 | " per_device_train_batch_size=4,\n", 439 | " gradient_accumulation_steps=1,\n", 440 | " optim=\"paged_adamw_32bit\",\n", 441 | " save_steps=50,\n", 442 | " logging_steps=50,\n", 443 | " learning_rate=4e-5,\n", 444 | " weight_decay=0.001,\n", 445 | " fp16=False,\n", 446 | " bf16=False,\n", 447 | " max_grad_norm=0.3,\n", 448 | " max_steps=-1,\n", 449 | " warmup_ratio=0.03,\n", 450 | " group_by_length=True,\n", 451 | " lr_scheduler_type=\"constant\",\n", 452 | " report_to=\"tensorboard\"\n", 453 | ")" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "id": "6f35f6b7-4a2f-467c-b1fd-e53bf7dcd837", 459 | "metadata": {}, 460 | "source": [ 461 | "**NOTE**: You may decrease the per_device_train_batch_size if got OOM. Use rocm-smi to monitor the VRAM usage when running the finetuning." 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "id": "d0d9d33c-95b2-467c-8212-96a7810b5e3c", 467 | "metadata": { 468 | "jp-MarkdownHeadingCollapsed": true 469 | }, 470 | "source": [ 471 | "**Training with QLoRA configuration**\n", 472 | "\n", 473 | "Now you can integrate LoRA into the base model and assess its additional parameters. LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights." 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 11, 479 | "id": "6f3eecd4-a968-42dd-89e2-f94b4633ecb7", 480 | "metadata": {}, 481 | "outputs": [ 482 | { 483 | "name": "stdout", 484 | "output_type": "stream", 485 | "text": [ 486 | "trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424\n" 487 | ] 488 | } 489 | ], 490 | "source": [ 491 | "from peft import get_peft_model\n", 492 | "# LoRA Config\n", 493 | "peft_parameters = LoraConfig(\n", 494 | " lora_alpha=8,\n", 495 | " lora_dropout=0.1,\n", 496 | " r=8,\n", 497 | " bias=\"none\",\n", 498 | " task_type=\"CAUSAL_LM\"\n", 499 | ")\n", 500 | "model = get_peft_model(base_model, peft_parameters)\n", 501 | "model.print_trainable_parameters()" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "id": "9ab9baef-c9a4-4438-aa88-cbcdcec199b9", 507 | "metadata": {}, 508 | "source": [ 509 | "Above show the trainalbe parameters in percent which is a tiny portion of the original model. This is the percentage we’ll update through fine-tuning, as follows." 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 12, 515 | "id": "444a176a-a893-46ee-8df9-b9bc4ddca1cb", 516 | "metadata": {}, 517 | "outputs": [ 518 | { 519 | "data": { 520 | "text/html": [ 521 | "\n", 522 | "

\n", 523 | " \n", 524 | " \n", 525 | " [250/250 09:04, Epoch 1/1]\n", 526 | "
\n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | "
StepTraining Loss
501.652400
1001.423200
1501.336600
2001.386900
2501.383500

" 557 | ], 558 | "text/plain": [ 559 | "" 560 | ] 561 | }, 562 | "metadata": {}, 563 | "output_type": "display_data" 564 | }, 565 | { 566 | "data": { 567 | "text/plain": [ 568 | "TrainOutput(global_step=250, training_loss=1.4364994506835937, metrics={'train_runtime': 551.5757, 'train_samples_per_second': 1.813, 'train_steps_per_second': 0.453, 'total_flos': 1.6854644828110848e+16, 'train_loss': 1.4364994506835937, 'epoch': 1.0})" 569 | ] 570 | }, 571 | "execution_count": 12, 572 | "metadata": {}, 573 | "output_type": "execute_result" 574 | } 575 | ], 576 | "source": [ 577 | "# Trainer with LoRA configuration\n", 578 | "fine_tuning = SFTTrainer(\n", 579 | " model=base_model,\n", 580 | " train_dataset=training_data,\n", 581 | " peft_config=peft_parameters,\n", 582 | " #dataset_text_field=\"text\",\n", 583 | " #tokenizer=llama_tokenizer,\n", 584 | " args=train_params\n", 585 | ")\n", 586 | "\n", 587 | "# Training\n", 588 | "fine_tuning.train()" 589 | ] 590 | }, 591 | { 592 | "cell_type": "raw", 593 | "id": "38005e08-f1e8-4cde-a139-b18e74e41bb8", 594 | "metadata": {}, 595 | "source": [ 596 | "To save your model, run this code:" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 13, 602 | "id": "2c0830b9-5f74-4d6f-bb33-355332c7e499", 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [ 606 | "# Save Model\n", 607 | "fine_tuning.model.save_pretrained(new_model_name)" 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "id": "996f3834-0c06-4907-83cf-8f8ff9348f88", 613 | "metadata": {}, 614 | "source": [ 615 | "#### Checking memory usage during training with LoRA\n", 616 | "During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:" 617 | ] 618 | }, 619 | { 620 | "cell_type": "raw", 621 | "id": "9a05e31c-b5b5-4e58-ab2a-5b8004a3bec3", 622 | "metadata": {}, 623 | "source": [ 624 | "$rocm-smi" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "id": "f42cf9ef-b529-4053-8f8e-f7cc80f34c33", 630 | "metadata": {}, 631 | "source": [ 632 | "To facilitate a comparison between fine-tuning with and without QLoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics." 
633 | ] 634 | }, 635 | { 636 | "cell_type": "markdown", 637 | "id": "f7312872-12a7-4472-8111-a2ff43ad05b5", 638 | "metadata": {}, 639 | "source": [ 640 | "### Step 4: Test the fine-tuned model with LoRA\n", 641 | "\n", 642 | "To test your model, run the following code:" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": 14, 648 | "id": "a4912dbf-3670-495e-b630-8d8474338e0a", 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [ 652 | "import torch\n", 653 | "from transformers import (\n", 654 | " AutoModelForCausalLM,\n", 655 | " AutoTokenizer,\n", 656 | " BitsAndBytesConfig,\n", 657 | " TrainingArguments,\n", 658 | " pipeline\n", 659 | ")" 660 | ] 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "id": "419366ea-1aaf-4648-8350-bb2eea033c03", 665 | "metadata": {}, 666 | "source": [ 667 | "#### Fine-tuned Model Inference" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "id": "ec7c7709-89d3-471a-8402-434c580fd77e", 673 | "metadata": {}, 674 | "source": [ 675 | "# Model and tokenizer names\n" 676 | ] 677 | }, 678 | { 679 | "cell_type": "code", 680 | "execution_count": 15, 681 | "id": "0451b4ef-24b8-419e-816f-33eaac534dd4", 682 | "metadata": {}, 683 | "outputs": [ 684 | { 685 | "name": "stderr", 686 | "output_type": "stream", 687 | "text": [ 688 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.50it/s]\n" 689 | ] 690 | } 691 | ], 692 | "source": [ 693 | "# Reload model in FP16 and merge it with LoRA weights\n", 694 | "\n", 695 | "#base_model_name = \"/data/HF-MODEL/huggingface-model/Meta-Llama-3.1-8B/\"\n", 696 | "#new_model_name = \"Llama-3.1-8B-qlora\" #You can give your own name for fine tuned model\n", 697 | "\n", 698 | "base_model = AutoModelForCausalLM.from_pretrained(base_model_name)\n", 699 | "from peft import LoraConfig, PeftModel\n", 700 | "peft_model = PeftModel.from_pretrained(base_model, new_model_name)\n", 701 | "peft_model = peft_model.merge_and_unload()\n", 702 | "\n", 703 | "# Reload tokenizer\n", 704 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 705 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 706 | "llama_tokenizer.padding_side = \"right\"" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 16, 712 | "id": "10ff3f8f-24a1-42f9-82e2-200e74da6966", 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "name": "stderr", 717 | "output_type": "stream", 718 | "text": [ 719 | "Device set to use cuda:0\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "pipeline = pipeline(\n", 725 | " \"text-generation\", \n", 726 | " model=peft_model, \n", 727 | " tokenizer=llama_tokenizer,\n", 728 | " device_map=\"auto\"\n", 729 | ")" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": 17, 735 | "id": "4f4f91f6-5508-4f80-9bad-658c84141411", 736 | "metadata": {}, 737 | "outputs": [ 738 | { 739 | "name": "stdout", 740 | "output_type": "stream", 741 | "text": [ 742 | "[INST] What do you think is the most important part of building an AI chatbot? 
[/INST] There are many different aspects to consider when building an AI chatbot, but I believe that the most\n" 743 | ] 744 | } 745 | ], 746 | "source": [ 747 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 748 | "output = pipeline(f\"[INST] {query} [/INST]\")\n", 749 | "print(output[0]['generated_text'])" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": null, 755 | "id": "c6abb4ea-93bb-4e24-beef-57a080501aa8", 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [] 759 | } 760 | ], 761 | "metadata": { 762 | "kernelspec": { 763 | "display_name": "Python 3 (ipykernel)", 764 | "language": "python", 765 | "name": "python3" 766 | }, 767 | "language_info": { 768 | "codemirror_mode": { 769 | "name": "ipython", 770 | "version": 3 771 | }, 772 | "file_extension": ".py", 773 | "mimetype": "text/x-python", 774 | "name": "python", 775 | "nbconvert_exporter": "python", 776 | "pygments_lexer": "ipython3", 777 | "version": "3.10.14" 778 | } 779 | }, 780 | "nbformat": 4, 781 | "nbformat_minor": 5 782 | } 783 | -------------------------------------------------------------------------------- /training/W7900_LoRA_Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c73b0caa-096b-45fe-b26f-032128d4334f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-tune Llama 2 with LoRA by AMD Radeon Pro W7900" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "53d94b31-35f8-4c8c-af0a-8a10aa5b4c62", 14 | "metadata": {}, 15 | "source": [ 16 | "In this blog, we show you how to fine-tune Llama 2 on one AMD Radeon Pro W7900 GPU(48GB GDDR) with ROCm. We use Low-Rank Adaptation of Large Language Models (LoRA) to overcome memory and computing limitations and make open-source large language models (LLMs) more accessible. We also show you how to fine-tune and upload models to Hugging Face.\n", 17 | "\n", 18 | "This blog is refer to https://rocm.blogs.amd.com/artificial-intelligence/llama2-lora/README.html which give some technical background about Llama2, Fine-tuning, LoRa and run the LoRA finetuning by AMD MI250 GPU. Here let's jump to the steps of fine-tuning by AMD Radeon Pro W7900 GPU.\n", 19 | "\n", 20 | "## Step-by-step Llama 2 fine-tuning\n", 21 | "\n", 22 | "Standard (full-parameter) fine-tuning involves considering all parameters. It requires significant computational power to manage optimizer states and gradient check-pointing. The resulting memory footprint is typically about four times larger than the model itself. For example, loading a 7 billion parameter model (e.g. Llama 2) in FP32 (4 bytes per parameter) requires approximately 28 GB of GPU memory, while fine-tuning demands around 28*4=112 GB of GPU memory. Note that the 112 GB figure is derived empirically, and various factors like batch size, data precision, and gradient accumulation contribute to overall memory usage.\n", 23 | "\n", 24 | "To overcome this memory limitation, you can use a parameter-efficient fine-tuning (PEFT) technique, such as LoRA.\n", 25 | "\n", 26 | "This example leverages tne AMD Radeon Pro W7900 GPU with 48GB VRAM. 
Using this setup allows us to explore different settings for fine-tuning the Llama 2–7b weights with LoRA.\n", 27 | "\n", 28 | "\n", 29 | "Our setup:\n", 30 | "\n", 31 | "- Hardware: AMD Radeon Pro W7900\n", 32 | "- Software:\n", 33 | " - ROCm 6.0+\n", 34 | " - Pytorch 2.0.1+\n", 35 | "\n", 36 | "Libraries: transformers, accelerate, peft, trl, bitsandbytes, scipy" 37 | ] 38 | }, 39 | { 40 | "attachments": {}, 41 | "cell_type": "markdown", 42 | "id": "4926a00e-7805-4de6-bb72-43db16ac09a2", 43 | "metadata": {}, 44 | "source": [ 45 | "### Step 0: Setup ROCm environment\n", 46 | "\n", 47 | "The easyway is to use ROCm docker image from https://hub.docker.com/r/rocm/pytorch. I use TAG rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2.\n", 48 | "\n", 49 | "$docker pull rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 50 | "\n", 51 | "And here is my docker start command as your reference.\n", 52 | "\n", 53 | "```\n", 54 | "$alias drun='docker run -it --network=host --device=/dev/kfd --device=/dev/dri/renderD128 --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G --hostname=w7900 -p 80:80 -p 8080:8080 -v /DATA:/DATA -w /DATA'\n", 55 | "\n", 56 | "$drun rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 57 | "```\n", 58 | "\n", 59 | "To run this jupyter notebook, you may install it by `pip install jupyter-lab`" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "d7b72aff-616a-442c-91a2-b9974ab95255", 65 | "metadata": {}, 66 | "source": [ 67 | "### Step 1: Getting started\n", 68 | "\n", 69 | "First, let’s confirm the availability of the GPU." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 1, 75 | "id": "4f4a7d77-c810-44ba-b70d-4661139f1cf9", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "\n", 83 | "\n", 84 | "============================ ROCm System Management Interface ============================\n", 85 | "====================================== Product Info ======================================\n", 86 | "GPU[0]\t\t: Card series: \t\t0x7448\n", 87 | "GPU[0]\t\t: Card model: \t\t0x0e0d\n", 88 | "GPU[0]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 89 | "GPU[0]\t\t: Card SKU: \t\tD7070100\n", 90 | "==========================================================================================\n", 91 | "================================== End of ROCm SMI Log ===================================\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "!rocm-smi --showproductname" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "983ea9f6-1d12-4d99-9135-00d79e4dd1ef", 102 | "metadata": {}, 103 | "source": [ 104 | "Next, install the required libraries." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 2, 110 | "id": "41310ccd-6dd8-4222-81b4-143d79a377ec", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "!pip install -q pandas peft==0.9.0 transformers==4.31.0 trl==0.4.7 accelerate scipy" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "21e8dc50-0464-44c5-8931-0a6dfe812f79", 120 | "metadata": {}, 121 | "source": [ 122 | "#### Install bitsandbytes\n", 123 | "1. 
Install bitsandbytes using the following code.\n", 124 | "\n", 125 | "- For ROCm 6.2" 126 | ] 127 | }, 128 | { 129 | "cell_type": "raw", 130 | "id": "dda99cff-bd74-4dbe-af9a-34d53001ef20", 131 | "metadata": {}, 132 | "source": [ 133 | "# Install `bitsandbytes`\n", 134 | "git clone --recurse https://github.com/ROCm/bitsandbytes.git\n", 135 | "cd bitsandbytes\n", 136 | "git checkout rocm6.2_internal_testing\n", 137 | "make hip\n", 138 | "python setup.py install" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "69a0abf0-afad-4cc7-a46c-f832f107e1d7", 144 | "metadata": {}, 145 | "source": [ 146 | "2. Check the bitsandbytes version (0.42.0)." 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 3, 152 | "id": "948e866c-609e-492d-9a51-fd9760f68759", 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "bitsandbytes 0.42.0\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "%%bash\n", 165 | "pip list | grep bitsandbytes" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "85eb3982-f46e-47e5-b4f7-f9fbf873a2fc", 171 | "metadata": {}, 172 | "source": [ 173 | "#### Import the required packages" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 4, 179 | "id": "26886732-b369-495f-8b6b-decdf0564219", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stderr", 184 | "output_type": "stream", 185 | "text": [ 186 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 187 | " from .autonotebook import tqdm as notebook_tqdm\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "import torch\n", 193 | "from datasets import load_dataset\n", 194 | "from transformers import (\n", 195 | " AutoModelForCausalLM,\n", 196 | " AutoTokenizer,\n", 197 | " BitsAndBytesConfig,\n", 198 | " TrainingArguments,\n", 199 | " pipeline\n", 200 | ")\n", 201 | "from peft import LoraConfig\n", 202 | "from trl import SFTTrainer" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "441d747b-4ef6-4969-9d15-f9834a5ee6bb", 208 | "metadata": {}, 209 | "source": [ 210 | "### Step 2: Configuring the model and data\n", 211 | "You can access Meta’s official Llama-2 model from Hugging Face after making a request, which can take a couple of days. Instead of waiting, we’ll use NousResearch’s Llama-2-7b-chat-hf as our base model (it’s the same as the original, but quicker to access). I downloaded it into /DATA/NousResearch/Llama-2-7b-chat-hf/ of my machine ahead." 
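If the weights are not already on local disk, one way to pre-fetch them is with `huggingface_hub` (a sketch, not part of the original notebook; the target path simply mirrors the one used below):

```python
# Optional pre-download of the base model to local storage (illustrative path)
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="NousResearch/Llama-2-7b-chat-hf",
    local_dir="/DATA/NousResearch/Llama-2-7b-chat-hf",
)
```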
212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 5, 217 | "id": "3c916dcd-fc94-4214-895a-9720ad3ec3ea", 218 | "metadata": { 219 | "scrolled": true 220 | }, 221 | "outputs": [ 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00, 10.91s/it]\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "# Model and tokenizer names\n", 232 | "base_model_name = \"/DATA/NousResearch/Llama-2-7b-chat-hf/\"\n", 233 | "new_model_name = \"llama-2-7b-chat-enhanced\" #You can give your own name for fine tuned model\n", 234 | "\n", 235 | "# Tokenizer\n", 236 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 237 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 238 | "llama_tokenizer.padding_side = \"right\"\n", 239 | "\n", 240 | "# Model\n", 241 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 242 | " base_model_name,\n", 243 | " device_map=\"auto\"\n", 244 | ")\n", 245 | "base_model.config.use_cache = False\n", 246 | "base_model.config.pretraining_tp = 1" 247 | ] 248 | }, 249 | { 250 | "cell_type": "raw", 251 | "id": "80bac7fb-b5c4-4560-b3bd-6fdd1493b257", 252 | "metadata": {}, 253 | "source": [ 254 | "$ rocm-smi\n", 255 | "\n", 256 | "========================================== ROCm System Management Interface ==========================================\n", 257 | "==================================================== Concise Info ====================================================\n", 258 | "Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%\n", 259 | " Name (20 chars) (Edge) (Avg) (Mem, Compute)\n", 260 | "======================================================================================================================\n", 261 | "0 [0x0e0d : 0x00] 38.0°C 29.0W N/A, N/A 29Mhz 96Mhz 20.0% auto 241.0W 57% 0%\n", 262 | " 0x7448\n", 263 | "======================================================================================================================\n", 264 | "================================================ End of ROCm SMI Log =================================================\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "raw", 269 | "id": "5211ad86-ab25-4618-bb08-f0666ed52bf9", 270 | "metadata": {}, 271 | "source": [ 272 | "After you have the base model, you can start fine-tuning. We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees." 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 6, 278 | "id": "235355d2-634d-4444-8497-71058a1e473a", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "(1000, 1)\n", 286 | "{'text': '[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. 
The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase \"Deez Nuts\" and film their reactions. The video quickly went viral, and the phrase became a popular meme. \\n\\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. It has been used in internet memes, in popular music, and even in politics. In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts. He gained some traction in the polls and even made appearances on national news programs.\\n\\nThe Deez Nuts meme has had a significant impact on popular culture. It has become a recognizable catchphrase that people use to add humor to everyday conversations. The meme has also been used to satirize politics and other serious issues. For example, in 2016, a group of activists in the UK used the phrase \"Deez Nuts for President\" as part of a campaign to encourage young people to vote in the EU referendum. [INST] Rewrite the essay in a more casual way. Instead of sounding proffesional, sound like a college student who is forced to write the essay but refuses to do so in the propper way. Use casual words and slang when possible. [/INST] Yo, so you want me to write a 1000-word essay about Deez Nuts? Alright, fine. So, this whole thing started on Vine back in 2015. Some dude named Rodney Bullard made a video where he would ask people if they knew a rapper, and when they said no, he would hit them with the classic line: \"Deez Nuts!\" People loved it, and it became a viral meme.\\n\\nNowadays, Deez Nuts is used for all kinds of stuff. You can throw it out there to interrupt someone or just to be funny. It\\'s all over the internet, in music, and even in politics. In fact, during the 2016 US presidential election, a kid named Brady Olson registered as an independent candidate under the name Deez Nuts. He actually got some attention from the media and made appearances on TV and everything.\\n\\nThe impact of Deez Nuts on our culture is pretty huge. It\\'s become a thing that everyone knows and uses to add some humor to their everyday conversations. Plus, people have used it to make fun of politics and serious issues too. Like, in the UK, some groups of activists used the phrase \"Deez Nuts for President\" to encourage young people to vote in the EU referendum.\\n\\nThere you have it, a thousand words about Deez Nuts in a more casual tone. Can I go back to playing video games now? 
'}\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "# Data set\n", 292 | "data_name = \"mlabonne/guanaco-llama2-1k\"\n", 293 | "training_data = load_dataset(data_name, split=\"train\")\n", 294 | "# check the data\n", 295 | "print(training_data.shape)\n", 296 | "# #11 is a QA sample in English\n", 297 | "print(training_data[11])" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 7, 303 | "id": "076c8fa2-5b76-41ae-b5f3-b85e64faa368", 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "Requirement already satisfied: tensorboardX in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (2.6.2.2)\n", 311 | "Requirement already satisfied: protobuf>=3.20 in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (3.20.2)\n", 312 | "Requirement already satisfied: packaging in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (24.0)\n", 313 | "Requirement already satisfied: numpy in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (1.22.4)\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "## There is a dependency during training\n", 319 | "!pip install tensorboardX" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "4ce556ea-9525-46ab-bcf2-1fe2e8c319d2", 325 | "metadata": {}, 326 | "source": [ 327 | "### Step 3: Start fine-tuning\n", 328 | "To set your training parameters, use the following code:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 8, 334 | "id": "d525b641-c645-4987-9ae6-173d6a75e292", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "# Training Params\n", 339 | "train_params = TrainingArguments(\n", 340 | " output_dir=\"./results_modified\",\n", 341 | " num_train_epochs=1,\n", 342 | " per_device_train_batch_size=1,\n", 343 | " gradient_accumulation_steps=1,\n", 344 | " optim=\"paged_adamw_32bit\",\n", 345 | " save_steps=50,\n", 346 | " logging_steps=50,\n", 347 | " learning_rate=4e-5,\n", 348 | " weight_decay=0.001,\n", 349 | " fp16=False,\n", 350 | " bf16=False,\n", 351 | " max_grad_norm=0.3,\n", 352 | " max_steps=-1,\n", 353 | " warmup_ratio=0.03,\n", 354 | " group_by_length=True,\n", 355 | " lr_scheduler_type=\"constant\",\n", 356 | " report_to=\"tensorboard\"\n", 357 | ")" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "6f35f6b7-4a2f-467c-b1fd-e53bf7dcd837", 363 | "metadata": {}, 364 | "source": [ 365 | " I got OOM with per_device_train_batch_size=2 at AMD Radeon Pro W7900 with 48GB VRAM. You will see the VRAM usage bellow when run the LoRA finetuning." 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "d0d9d33c-95b2-467c-8212-96a7810b5e3c", 371 | "metadata": {}, 372 | "source": [ 373 | "#### Training with LoRA configuration\n", 374 | "Now you can integrate LoRA into the base model and assess its additional parameters. LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights." 
375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 9, 380 | "id": "6f3eecd4-a968-42dd-89e2-f94b4633ecb7", 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "name": "stdout", 385 | "output_type": "stream", 386 | "text": [ 387 | "trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "from peft import get_peft_model\n", 393 | "# LoRA Config\n", 394 | "peft_parameters = LoraConfig(\n", 395 | " lora_alpha=8,\n", 396 | " lora_dropout=0.1,\n", 397 | " r=8,\n", 398 | " bias=\"none\",\n", 399 | " task_type=\"CAUSAL_LM\"\n", 400 | ")\n", 401 | "model = get_peft_model(base_model, peft_parameters)\n", 402 | "model.print_trainable_parameters()" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "9ab9baef-c9a4-4438-aa88-cbcdcec199b9", 408 | "metadata": {}, 409 | "source": [ 410 | "Note that there are only 0.062% parameters added by LoRA, which is a tiny portion of the original model. This is the percentage we’ll update through fine-tuning, as follows." 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 10, 416 | "id": "d21483ad-5ad7-43fd-87d1-2d383af73f7d", 417 | "metadata": { 418 | "scrolled": true 419 | }, 420 | "outputs": [ 421 | { 422 | "name": "stderr", 423 | "output_type": "stream", 424 | "text": [ 425 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/trl/trainer/sft_trainer.py:159: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n", 426 | " warnings.warn(\n", 427 | "Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 6006.52 examples/s]\n", 428 | "You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" 429 | ] 430 | }, 431 | { 432 | "data": { 433 | "text/html": [ 434 | "\n", 435 | "

\n", 436 | " \n", 437 | " \n", 438 | " [1000/1000 16:04, Epoch 1/1]\n", 439 | "
\n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | "
StepTraining Loss
501.954200
1001.778300
1501.559400
2001.496700
2501.345100
3001.408000
3501.402600
4001.336200
4501.300600
5001.176100
5501.219600
6001.377500
6501.306000
7001.356700
7501.313100
8001.265400
8501.347300
9001.298100
9501.386100
10001.319700

" 530 | ], 531 | "text/plain": [ 532 | "" 533 | ] 534 | }, 535 | "metadata": {}, 536 | "output_type": "display_data" 537 | }, 538 | { 539 | "name": "stderr", 540 | "output_type": "stream", 541 | "text": [ 542 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 543 | " warnings.warn(\n", 544 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 545 | " warnings.warn(\n", 546 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 547 | " warnings.warn(\n", 548 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 549 | " warnings.warn(\n", 550 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 551 | " warnings.warn(\n", 552 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 553 | " warnings.warn(\n", 554 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 555 | " warnings.warn(\n", 556 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 557 | " warnings.warn(\n", 558 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 559 | " warnings.warn(\n", 560 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 561 | " warnings.warn(\n", 562 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 563 | " warnings.warn(\n", 564 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 565 | " warnings.warn(\n", 566 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 567 | " 
warnings.warn(\n", 568 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 569 | " warnings.warn(\n", 570 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 571 | " warnings.warn(\n", 572 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 573 | " warnings.warn(\n", 574 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 575 | " warnings.warn(\n", 576 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 577 | " warnings.warn(\n", 578 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 579 | " warnings.warn(\n", 580 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 581 | " warnings.warn(\n", 582 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 583 | " warnings.warn(\n" 584 | ] 585 | }, 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "TrainOutput(global_step=1000, training_loss=1.397334213256836, metrics={'train_runtime': 966.394, 'train_samples_per_second': 1.035, 'train_steps_per_second': 1.035, 'total_flos': 1.67211744380928e+16, 'train_loss': 1.397334213256836, 'epoch': 1.0})" 590 | ] 591 | }, 592 | "execution_count": 10, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "# Trainer with LoRA configuration\n", 599 | "fine_tuning = SFTTrainer(\n", 600 | " model=base_model,\n", 601 | " train_dataset=training_data,\n", 602 | " peft_config=peft_parameters,\n", 603 | " dataset_text_field=\"text\",\n", 604 | " tokenizer=llama_tokenizer,\n", 605 | " args=train_params\n", 606 | ")\n", 607 | "\n", 608 | "# Training\n", 609 | "fine_tuning.train()" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "id": "221a1781-648a-4902-b00e-096a71750d6f", 615 | "metadata": {}, 616 | "source": [ 617 | "The output looks like this:" 618 | ] 619 | }, 620 | { 621 | "cell_type": "raw", 622 | "id": "f89b1c04-8c7d-4dd5-9961-1e020cd3a2d0", 623 | "metadata": {}, 624 | "source": [ 625 | "[250/250 07:59, Epoch 1/1]\\\n", 626 | "Step Training Loss \\\n", 627 | "50 1.954200 \\\n", 628 | "100 1.778300\\\n", 629 | "150 1.559400\\\n", 630 | "200 1.496700\\\n", 631 | "250 1.345100\\\n", 632 | "...\n", 633 | "1000 1.319700\n", 634 | "\n", 635 | "TrainOutput(global_step=1000, 
training_loss=1.397334213256836, metrics={'train_runtime': 966.394, 'train_samples_per_second': 1.035, 'train_steps_per_second': 1.035, 'total_flos': 1.67211744380928e+16, 'train_loss': 1.397334213256836, 'epoch': 1.0})" 636 | ] 637 | }, 638 | { 639 | "cell_type": "raw", 640 | "id": "38005e08-f1e8-4cde-a139-b18e74e41bb8", 641 | "metadata": {}, 642 | "source": [ 643 | "To save your model, run this code:" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 11, 649 | "id": "2c0830b9-5f74-4d6f-bb33-355332c7e499", 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | "# Save Model\n", 654 | "fine_tuning.model.save_pretrained(new_model_name)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "id": "996f3834-0c06-4907-83cf-8f8ff9348f88", 660 | "metadata": {}, 661 | "source": [ 662 | "#### Checking memory usage during training with LoRA\n", 663 | "During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:" 664 | ] 665 | }, 666 | { 667 | "cell_type": "raw", 668 | "id": "9a05e31c-b5b5-4e58-ab2a-5b8004a3bec3", 669 | "metadata": {}, 670 | "source": [ 671 | "$ rocm-smi\n", 672 | "\n", 673 | "\n", 674 | "=========================================== ROCm System Management Interface ===========================================\n", 675 | "===================================================== Concise Info =====================================================\n", 676 | "Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%\n", 677 | " Name (20 chars) (Edge) (Avg) (Mem, Compute)\n", 678 | "========================================================================================================================\n", 679 | "0 [0x0e0d : 0x00] 69.0°C 240.0W N/A, N/A 1950Mhz 1124Mhz 40.78% auto 241.0W 87% 99%\n", 680 | " 0x7448\n", 681 | "========================================================================================================================\n", 682 | "================================================= End of ROCm SMI Log ==================================================\n" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "id": "f42cf9ef-b529-4053-8f8e-f7cc80f34c33", 688 | "metadata": {}, 689 | "source": [ 690 | "To facilitate a comparison between fine-tuning with and without LoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics." 691 | ] 692 | }, 693 | { 694 | "cell_type": "markdown", 695 | "id": "80eae7a3-39e1-42b6-bb21-10440e414314", 696 | "metadata": {}, 697 | "source": [ 698 | "#### Training without LoRA configuration\n", 699 | "\n", 700 | "You may got OOM failed of full-parameter fine-tunning process refer to https://rocm.blogs.amd.com/artificial-intelligence/llama2-lora/README.html. Yes, one Radeon Pro W7900 with 48GB VRAM is not enough for these case." 
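A rough back-of-envelope estimate (an illustration, not a measurement) shows why: with an AdamW-style optimizer, full-parameter fine-tuning keeps weights, gradients and two optimizer moments resident, roughly four fp32 copies of the model.

```python
# Illustration: approximate memory for full-parameter fine-tuning of a 7B model (fp32 + AdamW)
params = 7e9
per_copy_gb = 4 * params / 1e9            # one fp32 copy of all parameters, in GB
total_gb = 4 * per_copy_gb                # weights + gradients + Adam first/second moments
print(f"~{total_gb:.0f} GB")              # ~112 GB before activations, matching the 28 GB x 4 estimate above
```

That is well beyond the 48GB available on a single Radeon Pro W7900, which is why the LoRA path is used here instead.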
701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "id": "f7312872-12a7-4472-8111-a2ff43ad05b5", 706 | "metadata": {}, 707 | "source": [ 708 | "### Step 4: Test the fine-tuned model with LoRA\n", 709 | "\n", 710 | "To test your model, run the following code:" 711 | ] 712 | }, 713 | { 714 | "cell_type": "raw", 715 | "id": "4d6af907-b55b-4e5f-9b75-519fdd0c5ee0", 716 | "metadata": {}, 717 | "source": [ 718 | "# Reload model in FP16 and merge it with LoRA weights\n", 719 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 720 | " base_model_name,\n", 721 | " low_cpu_mem_usage=True,\n", 722 | " return_dict=True,\n", 723 | " torch_dtype=torch.float16,\n", 724 | " device_map=\"auto\"\n", 725 | ")\n", 726 | "from peft import LoraConfig, PeftModel\n", 727 | "model = PeftModel.from_pretrained(base_model, new_model_name)\n", 728 | "model = model.merge_and_unload()\n", 729 | "\n", 730 | "# Reload tokenizer to save it\n", 731 | "tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 732 | "tokenizer.pad_token = tokenizer.eos_token\n", 733 | "tokenizer.padding_side = \"right\"" 734 | ] 735 | }, 736 | { 737 | "cell_type": "markdown", 738 | "id": "3da05da4-faf6-4538-8894-30a9dd0bb427", 739 | "metadata": {}, 740 | "source": [ 741 | "The output looks like this:" 742 | ] 743 | }, 744 | { 745 | "cell_type": "raw", 746 | "id": "b7789dfa-f685-4c9d-9ecb-66a88d1e2705", 747 | "metadata": {}, 748 | "source": [ 749 | " Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.34s/it]" 750 | ] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "id": "f259219a-c28d-4404-aca1-85a61816cbb5", 755 | "metadata": {}, 756 | "source": [ 757 | "Uploading the model to Hugging Face let’s you conduct subsequent tests or share your model with others (to proceed with this step, you’ll need an active Hugging Face account)." 758 | ] 759 | }, 760 | { 761 | "cell_type": "raw", 762 | "id": "3a4cb5c0-0db9-4459-bfbe-91cfb1dedce6", 763 | "metadata": {}, 764 | "source": [ 765 | "from huggingface_hub import login\n", 766 | "# You need to use your Hugging Face Access Tokens\n", 767 | "login(\"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\")\n", 768 | "# Push the model to Hugging Face. This takes minutes and time depends the model size and your\n", 769 | "# network speed.\n", 770 | "model.push_to_hub(new_model_name, use_temp_dir=False)\n", 771 | "tokenizer.push_to_hub(new_model_name, use_temp_dir=False)" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "id": "cdcfc34f-14e5-4aa2-b820-a428a993da06", 777 | "metadata": {}, 778 | "source": [ 779 | "Now you can test with the base model (original) and your fine-tuned model." 
780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "id": "dd52a8df-4841-4b0f-80ba-e868e2a76db0", 785 | "metadata": {}, 786 | "source": [ 787 | "- Base model:" 788 | ] 789 | }, 790 | { 791 | "cell_type": "raw", 792 | "id": "fdf6b408-deca-4e08-833c-c860e2d6b76c", 793 | "metadata": {}, 794 | "source": [ 795 | "# Generate text using base model\n", 796 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 797 | "text_gen = pipeline(task=\"text-generation\", model=base_model_name, tokenizer=llama_tokenizer, max_length=200)\n", 798 | "output = text_gen(f\"[INST] {query} [/INST]\")\n", 799 | "print(output[0]['generated_text'])" 800 | ] 801 | }, 802 | { 803 | "cell_type": "markdown", 804 | "id": "dfb8190f-4c40-48ed-baee-525472f965f0", 805 | "metadata": {}, 806 | "source": [ 807 | "- Fine-tuned model:" 808 | ] 809 | }, 810 | { 811 | "cell_type": "raw", 812 | "id": "f729432f-d499-4878-aac7-bfea303ca6d6", 813 | "metadata": {}, 814 | "source": [ 815 | "# Generate text using fine-tuned model\n", 816 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 817 | "text_gen = pipeline(task=\"text-generation\", model=new_model_name, tokenizer=llama_tokenizer, max_length=200)\n", 818 | "output = text_gen(f\"[INST] {query} [/INST]\")\n", 819 | "print(output[0]['generated_text'])" 820 | ] 821 | }, 822 | { 823 | "cell_type": "markdown", 824 | "id": "c9fb081a-cf75-4f02-89d4-74e12124dc36", 825 | "metadata": {}, 826 | "source": [ 827 | "You can observe the outputs of the two models based on a given query. These outputs exhibit slight differences due to the fine-tuning process altering the model weights." 828 | ] 829 | } 830 | ], 831 | "metadata": { 832 | "kernelspec": { 833 | "display_name": "Python 3 (ipykernel)", 834 | "language": "python", 835 | "name": "python3" 836 | }, 837 | "language_info": { 838 | "codemirror_mode": { 839 | "name": "ipython", 840 | "version": 3 841 | }, 842 | "file_extension": ".py", 843 | "mimetype": "text/x-python", 844 | "name": "python", 845 | "nbconvert_exporter": "python", 846 | "pygments_lexer": "ipython3", 847 | "version": "3.9.19" 848 | } 849 | }, 850 | "nbformat": 4, 851 | "nbformat_minor": 5 852 | } 853 | -------------------------------------------------------------------------------- /training/W7900_QLoRA_Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c73b0caa-096b-45fe-b26f-032128d4334f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-tune Llama 2 with QLoRA by AMD Radeon Pro W7900" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "53d94b31-35f8-4c8c-af0a-8a10aa5b4c62", 14 | "metadata": {}, 15 | "source": [ 16 | "In this blog, we show you how to QLoRA fine-tune Llama 2 on one AMD Radeon Pro W7900 GPU(48GB GDDR) with ROCm.\n", 17 | "\n", 18 | "This blog is refer to https://rocm.blogs.amd.com/artificial-intelligence/llama2-Qlora/README.html which give some technical background about Llama2, Fine-tuning, LoRa and run the LoRA finetuning by AMD MI250 GPU. Here let's jump to the steps of fine-tuning by AMD Radeon Pro W7900 GPU.\n", 19 | "\n", 20 | "## Step-by-step Llama 2 fine-tuning with QLoRA\n", 21 | "\n", 22 | "This section will guide you through the steps to fine-tune the Llama 2 model, which has 7 billion parameters, on a single AMD GPU. 
The key to this accomplishment lies in the crucial support of QLoRA, which plays an indispensable role in efficiently reducing memory requirements.\n", 23 | "\n", 24 | "Our setup:\n", 25 | "\n", 26 | "- Hardware: AMD Radeon Pro W7900\n", 27 | "- Software:\n", 28 | " - ROCm 6.0+\n", 29 | " - Pytorch 2.0.1+\n", 30 | "\n", 31 | "Libraries: transformers, accelerate, peft, trl, bitsandbytes, scipy" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "id": "4926a00e-7805-4de6-bb72-43db16ac09a2", 38 | "metadata": {}, 39 | "source": [ 40 | "### Step 0: Setup ROCm environment\n", 41 | "\n", 42 | "The easyway is to use ROCm docker image from https://hub.docker.com/r/rocm/pytorch. I use TAG rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2.\n", 43 | "\n", 44 | "$docker pull rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 45 | "\n", 46 | "And here is my docker start command as your reference.\n", 47 | "\n", 48 | "```\n", 49 | "$alias drun='docker run -it --network=host --device=/dev/kfd --device=/dev/dri/renderD128 --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G --hostname=w7900 -p 80:80 -p 8080:8080 -v /DATA:/DATA -w /DATA'\n", 50 | "\n", 51 | "$drun rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2\n", 52 | "```\n", 53 | "\n", 54 | "To run this jupyter notebook, you may install it by `pip install jupyter-lab`" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "d7b72aff-616a-442c-91a2-b9974ab95255", 60 | "metadata": {}, 61 | "source": [ 62 | "### Step 1: Getting started\n", 63 | "\n", 64 | "First, let’s confirm the availability of the GPU." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 1, 70 | "id": "4f4a7d77-c810-44ba-b70d-4661139f1cf9", 71 | "metadata": { 72 | "scrolled": true 73 | }, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "\n", 80 | "\n", 81 | "============================ ROCm System Management Interface ============================\n", 82 | "====================================== Product Info ======================================\n", 83 | "GPU[0]\t\t: Card series: \t\t0x7448\n", 84 | "GPU[0]\t\t: Card model: \t\t0x0e0d\n", 85 | "GPU[0]\t\t: Card vendor: \t\tAdvanced Micro Devices, Inc. [AMD/ATI]\n", 86 | "GPU[0]\t\t: Card SKU: \t\tD7070100\n", 87 | "==========================================================================================\n", 88 | "================================== End of ROCm SMI Log ===================================\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "!rocm-smi --showproductname" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "983ea9f6-1d12-4d99-9135-00d79e4dd1ef", 99 | "metadata": {}, 100 | "source": [ 101 | "Next, install the required libraries." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 2, 107 | "id": "41310ccd-6dd8-4222-81b4-143d79a377ec", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "!pip install -q pandas peft==0.9.0 transformers==4.31.0 trl==0.4.7 accelerate scipy" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "21e8dc50-0464-44c5-8931-0a6dfe812f79", 117 | "metadata": {}, 118 | "source": [ 119 | "#### Install bitsandbytes\n", 120 | "1. 
Install bitsandbytes using the following code.\n", 121 | "\n", 122 | "- For ROCm 6.2" 123 | ] 124 | }, 125 | { 126 | "cell_type": "raw", 127 | "id": "dda99cff-bd74-4dbe-af9a-34d53001ef20", 128 | "metadata": {}, 129 | "source": [ 130 | "# Install `bitsandbytes`\n", 131 | "git clone --recurse https://github.com/ROCm/bitsandbytes.git\n", 132 | "cd bitsandbytes\n", 133 | "git checkout rocm6.2_internal_testing\n", 134 | "make hip\n", 135 | "python setup.py install" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "id": "69a0abf0-afad-4cc7-a46c-f832f107e1d7", 141 | "metadata": {}, 142 | "source": [ 143 | "2. Check the bitsandbytes version (0.42.0)." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 3, 149 | "id": "948e866c-609e-492d-9a51-fd9760f68759", 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "bitsandbytes 0.42.0\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "%%bash\n", 162 | "pip list | grep bitsandbytes" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "id": "85eb3982-f46e-47e5-b4f7-f9fbf873a2fc", 168 | "metadata": {}, 169 | "source": [ 170 | "#### Import the required packages" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 4, 176 | "id": "26886732-b369-495f-8b6b-decdf0564219", 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stderr", 181 | "output_type": "stream", 182 | "text": [ 183 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 184 | " from .autonotebook import tqdm as notebook_tqdm\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "import torch\n", 190 | "from datasets import load_dataset\n", 191 | "from transformers import (\n", 192 | " AutoModelForCausalLM,\n", 193 | " AutoTokenizer,\n", 194 | " BitsAndBytesConfig,\n", 195 | " TrainingArguments,\n", 196 | " pipeline\n", 197 | ")\n", 198 | "from peft import LoraConfig\n", 199 | "from trl import SFTTrainer" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "441d747b-4ef6-4969-9d15-f9834a5ee6bb", 205 | "metadata": {}, 206 | "source": [ 207 | "### Step 2: Configuring the model and data\n", 208 | "You can access Meta’s official Llama-2 model from Hugging Face after making a request, which can take a couple of days. Instead of waiting, we’ll use NousResearch’s Llama-2-7b-chat-hf as our base model (it’s the same as the original, but quicker to access). I downloaded it into /DATA/NousResearch/Llama-2-7b-chat-hf/ of my machine ahead." 
209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 5, 214 | "id": "3c916dcd-fc94-4214-895a-9720ad3ec3ea", 215 | "metadata": { 216 | "scrolled": true 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "# Model and tokenizer names\n", 221 | "base_model_name = \"/DATA/NousResearch/Llama-2-7b-chat-hf/\"\n", 222 | "new_model_name = \"llama-2-7b-chat-enhanced\" #You can give your own name for fine tuned model\n", 223 | "\n", 224 | "# Tokenizer\n", 225 | "llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 226 | "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n", 227 | "llama_tokenizer.padding_side = \"right\"" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "id": "a64ff51f-97d4-455c-9525-5bfdd70d6c49", 233 | "metadata": {}, 234 | "source": [ 235 | "#### QLoRA 4-bit quantization configuration\n", 236 | "\n", 237 | "As outlined in the paper, QLoRA stores weights in 4-bits, allowing computation to occur in 16 or 32-bit precision. This means whenever a QLoRA weight tensor is used, we dequantize the tensor to 16 or 32-bit precision, and then perform a matrix multiplication. Various combinations, such as float16, bfloat16, float32, etc., can be chosen. Experimentation with different 4-bit quantization variants, including normalized float 4 (NF4), or pure float4 quantization, is possible. However, guided by theoretical considerations and empirical findings from the paper, the recommendation is to opt for NF4 quantization, as it tends to deliver better performance.\n", 238 | "\n", 239 | "In our case, we chose the following configuration:\n", 240 | "\n", 241 | "- 4-bit quantization with NF4 type\n", 242 | "\n", 243 | "- 16-bit (float16) for computation\n", 244 | "\n", 245 | "- Double quantization, which uses a second quantization after the first one to save an additional 0.3 bits per parameters\n", 246 | "\n", 247 | "Quantization parameters are controlled from the BitsandbytesConfig (see Hugging Face documentation) as follows:\n", 248 | "\n", 249 | "- Loading in 4 bits is activated through load_in_4bit\n", 250 | "\n", 251 | "- The datatype used for quantization is specified with bnb_4bit_quant_type. Note that there are two supported quantization datatypes fp4 (four-bit float) and nf4 (normal four-bit float). The latter is theoretically optimal for normally distributed weights, so we recommend using nf4.\n", 252 | "\n", 253 | "- The datatype used for the linear layer computations with bnb_4bit_compute_dtype\n", 254 | "\n", 255 | "- Nested quantization is activated through bnb_4bit_use_double_quant" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 6, 261 | "id": "2b5d1f2f-cad3-4da2-a72a-5e113d8e94fb", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "# Quantization Config\n", 266 | "quant_config = BitsAndBytesConfig(\n", 267 | " load_in_4bit=True,\n", 268 | " bnb_4bit_quant_type=\"nf4\",\n", 269 | " bnb_4bit_compute_dtype=torch.float16,\n", 270 | " bnb_4bit_use_double_quant=True\n", 271 | ")" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "id": "e8cb2ce8-4278-4c63-bdc4-36189d60c9ab", 277 | "metadata": {}, 278 | "source": [ 279 | "Load the model and set the quantization configuration." 
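Before loading, for intuition, here is a rough estimate of the weight footprint this configuration targets (an illustration, not from the original blog):

```python
# Illustration: approximate weight storage for a 7B model at different precisions
params = 7e9
print(f"fp16: ~{params * 2 / 1e9:.0f} GB")    # ~14 GB
print(f"nf4 : ~{params * 0.5 / 1e9:.1f} GB")  # ~3.5 GB, plus small per-block quantization constants
# Double quantization trims roughly another 0.3 bits per parameter from those constants.
```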
280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 7, 285 | "id": "0b38d01d-8ebe-4051-b2f2-dd32d29767fb", 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "name": "stderr", 290 | "output_type": "stream", 291 | "text": [ 292 | "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:14<00:00, 7.41s/it]\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 298 | " base_model_name,\n", 299 | " quantization_config=quant_config,\n", 300 | " device_map=\"auto\"\n", 301 | ")\n", 302 | "base_model.config.use_cache = False\n", 303 | "base_model.config.pretraining_tp = 1" 304 | ] 305 | }, 306 | { 307 | "cell_type": "raw", 308 | "id": "e2daf4e0-7aab-4b25-a6d9-bbf0b3f9c318", 309 | "metadata": {}, 310 | "source": [ 311 | "$ rocm-smi\n", 312 | "\n", 313 | "========================================== ROCm System Management Interface ==========================================\n", 314 | "==================================================== Concise Info ====================================================\n", 315 | "Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%\n", 316 | " Name (20 chars) (Edge) (Avg) (Mem, Compute)\n", 317 | "======================================================================================================================\n", 318 | "0 [0x0e0d : 0x00] 42.0°C 30.0W N/A, N/A 29Mhz 96Mhz 20.0% auto 241.0W 9% 0%\n", 319 | " 0x7448\n", 320 | "======================================================================================================================\n", 321 | "================================================ End of ROCm SMI Log =================================================\n", 322 | "\n" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "id": "60e40797-0284-4948-96e6-f0cd02595d48", 328 | "metadata": {}, 329 | "source": [ 330 | "#### Dataset configuration\n", 331 | "We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees." 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 8, 337 | "id": "235355d2-634d-4444-8497-71058a1e473a", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "name": "stdout", 342 | "output_type": "stream", 343 | "text": [ 344 | "(1000, 1)\n", 345 | "{'text': '[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase \"Deez Nuts\" and film their reactions. The video quickly went viral, and the phrase became a popular meme. \\n\\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. It has been used in internet memes, in popular music, and even in politics. 
In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts. He gained some traction in the polls and even made appearances on national news programs.\\n\\nThe Deez Nuts meme has had a significant impact on popular culture. It has become a recognizable catchphrase that people use to add humor to everyday conversations. The meme has also been used to satirize politics and other serious issues. For example, in 2016, a group of activists in the UK used the phrase \"Deez Nuts for President\" as part of a campaign to encourage young people to vote in the EU referendum. [INST] Rewrite the essay in a more casual way. Instead of sounding proffesional, sound like a college student who is forced to write the essay but refuses to do so in the propper way. Use casual words and slang when possible. [/INST] Yo, so you want me to write a 1000-word essay about Deez Nuts? Alright, fine. So, this whole thing started on Vine back in 2015. Some dude named Rodney Bullard made a video where he would ask people if they knew a rapper, and when they said no, he would hit them with the classic line: \"Deez Nuts!\" People loved it, and it became a viral meme.\\n\\nNowadays, Deez Nuts is used for all kinds of stuff. You can throw it out there to interrupt someone or just to be funny. It\\'s all over the internet, in music, and even in politics. In fact, during the 2016 US presidential election, a kid named Brady Olson registered as an independent candidate under the name Deez Nuts. He actually got some attention from the media and made appearances on TV and everything.\\n\\nThe impact of Deez Nuts on our culture is pretty huge. It\\'s become a thing that everyone knows and uses to add some humor to their everyday conversations. Plus, people have used it to make fun of politics and serious issues too. Like, in the UK, some groups of activists used the phrase \"Deez Nuts for President\" to encourage young people to vote in the EU referendum.\\n\\nThere you have it, a thousand words about Deez Nuts in a more casual tone. Can I go back to playing video games now? 
'}\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "# Dataset\n", 351 | "data_name = \"mlabonne/guanaco-llama2-1k\"\n", 352 | "training_data = load_dataset(data_name, split=\"train\")\n", 353 | "# check the data\n", 354 | "print(training_data.shape)\n", 355 | "# #11 is a QA sample in English\n", 356 | "print(training_data[11])" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 9, 362 | "id": "076c8fa2-5b76-41ae-b5f3-b85e64faa368", 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "Requirement already satisfied: tensorboardX in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (2.6.2.2)\n", 370 | "Requirement already satisfied: packaging in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (24.0)\n", 371 | "Requirement already satisfied: numpy in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (1.22.4)\n", 372 | "Requirement already satisfied: protobuf>=3.20 in /opt/conda/envs/py_3.9/lib/python3.9/site-packages (from tensorboardX) (3.20.2)\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "## tensorboardX is a dependency during training\n", 378 | "!pip install tensorboardX" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "4ce556ea-9525-46ab-bcf2-1fe2e8c319d2", 384 | "metadata": {}, 385 | "source": [ 386 | "### Step 3: Start fine-tuning\n", 387 | "To set your training parameters, use the following code:" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 10, 393 | "id": "d525b641-c645-4987-9ae6-173d6a75e292", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "# Training Params\n", 398 | "train_params = TrainingArguments(\n", 399 | " output_dir=\"./results_modified\",\n", 400 | " num_train_epochs=1,\n", 401 | " per_device_train_batch_size=4,\n", 402 | " gradient_accumulation_steps=1,\n", 403 | " optim=\"paged_adamw_32bit\",\n", 404 | " save_steps=50,\n", 405 | " logging_steps=50,\n", 406 | " learning_rate=2e-4,\n", 407 | " weight_decay=0.001,\n", 408 | " fp16=False,\n", 409 | " bf16=False,\n", 410 | " max_grad_norm=0.3,\n", 411 | " max_steps=-1,\n", 412 | " warmup_ratio=0.03,\n", 413 | " group_by_length=True,\n", 414 | " lr_scheduler_type=\"constant\",\n", 415 | " report_to=\"tensorboard\"\n", 416 | ") " 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "id": "6f35f6b7-4a2f-467c-b1fd-e53bf7dcd837", 422 | "metadata": {}, 423 | "source": [ 424 | " I got an out-of-memory (OOM) error with per_device_train_batch_size=2 on the AMD Radeon Pro W7900 with 48 GB of VRAM. You will see the VRAM usage below when running the fine-tuning." 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "id": "d0d9d33c-95b2-467c-8212-96a7810b5e3c", 430 | "metadata": {}, 431 | "source": [ 432 | "#### Training with QLoRA configuration\n", 433 | "Now you can integrate LoRA into the base model and assess its additional parameters. LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights."
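Before running the next cell, it helps to see where its reported trainable-parameter count comes from. A quick check (assuming PEFT's default LoRA target modules for Llama, `q_proj` and `v_proj`, with hidden size 4096 and 32 decoder layers in Llama-2-7b; illustrative only, not a cell from the notebook):

```python
# Expected LoRA parameter count for r=8 adapters on q_proj and v_proj
# in every decoder layer of Llama-2-7b.
hidden_size, num_layers, r, modules_per_layer = 4096, 32, 8, 2
params_per_module = r * hidden_size + hidden_size * r   # A (r x d) plus B (d x r)
trainable = params_per_module * modules_per_layer * num_layers
print(trainable)   # 4,194,304 -> ~0.062% of the 6.74B total parameters
```

The cell below prints the same figure from the actual `get_peft_model` wrapper.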
434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 11, 439 | "id": "6f3eecd4-a968-42dd-89e2-f94b4633ecb7", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "name": "stdout", 444 | "output_type": "stream", 445 | "text": [ 446 | "trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199\n" 447 | ] 448 | } 449 | ], 450 | "source": [ 451 | "from peft import get_peft_model\n", 452 | "# LoRA Config\n", 453 | "peft_parameters = LoraConfig(\n", 454 | " lora_alpha=8,\n", 455 | " lora_dropout=0.1,\n", 456 | " r=8,\n", 457 | " bias=\"none\",\n", 458 | " task_type=\"CAUSAL_LM\"\n", 459 | ")\n", 460 | "model = get_peft_model(base_model, peft_parameters)\n", 461 | "model.print_trainable_parameters()" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "id": "9ab9baef-c9a4-4438-aa88-cbcdcec199b9", 467 | "metadata": {}, 468 | "source": [ 469 | "Note that there are only 0.062% parameters added by LoRA, which is a tiny portion of the original model. This is the percentage we’ll update through fine-tuning, as follows." 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 12, 475 | "id": "d21483ad-5ad7-43fd-87d1-2d383af73f7d", 476 | "metadata": { 477 | "scrolled": true 478 | }, 479 | "outputs": [ 480 | { 481 | "name": "stderr", 482 | "output_type": "stream", 483 | "text": [ 484 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/other.py:145: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n", 485 | " warnings.warn(\n", 486 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/trl/trainer/sft_trainer.py:159: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n", 487 | " warnings.warn(\n", 488 | "You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n", 489 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", 490 | " warnings.warn(\n" 491 | ] 492 | }, 493 | { 494 | "data": { 495 | "text/html": [ 496 | "\n", 497 | "

      [250/250 19:44, Epoch 1/1]
      Step    Training Loss
        50    1.576900
       100    1.350100
       150    1.274400
       200    1.319300
       250    1.341000
" 532 | ], 533 | "text/plain": [ 534 | "" 535 | ] 536 | }, 537 | "metadata": {}, 538 | "output_type": "display_data" 539 | }, 540 | { 541 | "name": "stderr", 542 | "output_type": "stream", 543 | "text": [ 544 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 545 | " warnings.warn(\n", 546 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 547 | " warnings.warn(\n", 548 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", 549 | " warnings.warn(\n", 550 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 551 | " warnings.warn(\n", 552 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", 553 | " warnings.warn(\n", 554 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 555 | " warnings.warn(\n", 556 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", 557 | " warnings.warn(\n", 558 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 559 | " warnings.warn(\n", 560 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. 
Refer to docs for more details on the differences between the two variants.\n", 561 | " warnings.warn(\n", 562 | "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in /DATA/NousResearch/Llama-2-7b-chat-hf/ - will assume that the vocabulary was not modified.\n", 563 | " warnings.warn(\n" 564 | ] 565 | }, 566 | { 567 | "data": { 568 | "text/plain": [ 569 | "TrainOutput(global_step=250, training_loss=1.3723441772460938, metrics={'train_runtime': 1195.1931, 'train_samples_per_second': 0.837, 'train_steps_per_second': 0.209, 'total_flos': 8679674339426304.0, 'train_loss': 1.3723441772460938, 'epoch': 1.0})" 570 | ] 571 | }, 572 | "execution_count": 12, 573 | "metadata": {}, 574 | "output_type": "execute_result" 575 | } 576 | ], 577 | "source": [ 578 | "# Trainer with LoRA configuration\n", 579 | "fine_tuning = SFTTrainer(\n", 580 | " model=base_model,\n", 581 | " train_dataset=training_data,\n", 582 | " peft_config=peft_parameters,\n", 583 | " dataset_text_field=\"text\",\n", 584 | " tokenizer=llama_tokenizer,\n", 585 | " args=train_params\n", 586 | ")\n", 587 | "\n", 588 | "# Training\n", 589 | "fine_tuning.train()" 590 | ] 591 | }, 592 | { 593 | "cell_type": "markdown", 594 | "id": "221a1781-648a-4902-b00e-096a71750d6f", 595 | "metadata": {}, 596 | "source": [ 597 | "The output looks like this:" 598 | ] 599 | }, 600 | { 601 | "cell_type": "raw", 602 | "id": "f89b1c04-8c7d-4dd5-9961-1e020cd3a2d0", 603 | "metadata": {}, 604 | "source": [ 605 | "[250/250 07:59, Epoch 1/1]\\\n", 606 | "Step Training Loss \\\n", 607 | "50 1.576900 \\\n", 608 | "100 1.350100\\\n", 609 | "150 1.274400\\\n", 610 | "200 1.319300\\\n", 611 | "250 1.341000\\\n", 612 | "\n", 613 | "\n", 614 | "TrainOutput(global_step=250, training_loss=1.3723441772460938, metrics={'train_runtime': 1195.1931, 'train_samples_per_second': 0.837, 'train_steps_per_second': 0.209, 'total_flos': 8679674339426304.0, 'train_loss': 1.3723441772460938, 'epoch': 1.0})" 615 | ] 616 | }, 617 | { 618 | "cell_type": "raw", 619 | "id": "38005e08-f1e8-4cde-a139-b18e74e41bb8", 620 | "metadata": {}, 621 | "source": [ 622 | "To save your model, run this code:" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 11, 628 | "id": "2c0830b9-5f74-4d6f-bb33-355332c7e499", 629 | "metadata": {}, 630 | "outputs": [], 631 | "source": [ 632 | "# Save Model\n", 633 | "fine_tuning.model.save_pretrained(new_model_name)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "id": "996f3834-0c06-4907-83cf-8f8ff9348f88", 639 | "metadata": {}, 640 | "source": [ 641 | "#### Checking memory usage during training with LoRA\n", 642 | "During training, you can check the memory usage by running the rocm-smi command in a terminal. 
This command produces the following output:" 643 | ] 644 | }, 645 | { 646 | "cell_type": "raw", 647 | "id": "9a05e31c-b5b5-4e58-ab2a-5b8004a3bec3", 648 | "metadata": {}, 649 | "source": [ 650 | "# per_device_train_batch_size=4\n", 651 | "\n", 652 | "$ rocm-smi\n", 653 | "\n", 654 | "============================================ ROCm System Management Interface ============================================\n", 655 | "====================================================== Concise Info ======================================================\n", 656 | "Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%\n", 657 | " Name (20 chars) (Edge) (Avg) (Mem, Compute)\n", 658 | "==========================================================================================================================\n", 659 | "0 [0x0e0d : 0x00] 77.0°C 241.0W N/A, N/A 1792Mhz 1124Mhz 48.63% auto 241.0W 24% 100%\n", 660 | " 0x7448\n", 661 | "==========================================================================================================================\n", 662 | "================================================== End of ROCm SMI Log ===================================================\n" 663 | ] 664 | }, 665 | { 666 | "cell_type": "markdown", 667 | "id": "f42cf9ef-b529-4053-8f8e-f7cc80f34c33", 668 | "metadata": {}, 669 | "source": [ 670 | "To facilitate a comparison between fine-tuning with and without LoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics." 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "id": "f7312872-12a7-4472-8111-a2ff43ad05b5", 676 | "metadata": {}, 677 | "source": [ 678 | "### Step 4: Test the fine-tuned model with QLoRA\n", 679 | "\n", 680 | "To test your model, run the following code:" 681 | ] 682 | }, 683 | { 684 | "cell_type": "raw", 685 | "id": "4d6af907-b55b-4e5f-9b75-519fdd0c5ee0", 686 | "metadata": {}, 687 | "source": [ 688 | "# Reload model in FP16 and merge it with LoRA weights\n", 689 | "base_model = AutoModelForCausalLM.from_pretrained(\n", 690 | " base_model_name,\n", 691 | " low_cpu_mem_usage=True,\n", 692 | " return_dict=True,\n", 693 | " torch_dtype=torch.float16,\n", 694 | " device_map=\"auto\"\n", 695 | ")\n", 696 | "from peft import LoraConfig, PeftModel\n", 697 | "model = PeftModel.from_pretrained(base_model, new_model_name)\n", 698 | "model = model.merge_and_unload()\n", 699 | "\n", 700 | "# Reload tokenizer to save it\n", 701 | "tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", 702 | "tokenizer.pad_token = tokenizer.eos_token\n", 703 | "tokenizer.padding_side = \"right\"" 704 | ] 705 | }, 706 | { 707 | "cell_type": "markdown", 708 | "id": "3da05da4-faf6-4538-8894-30a9dd0bb427", 709 | "metadata": {}, 710 | "source": [ 711 | "The output looks like this:" 712 | ] 713 | }, 714 | { 715 | "cell_type": "raw", 716 | "id": "b7789dfa-f685-4c9d-9ecb-66a88d1e2705", 717 | "metadata": {}, 718 | "source": [ 719 | " Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.34s/it]" 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "id": "f259219a-c28d-4404-aca1-85a61816cbb5", 725 | "metadata": {}, 726 | "source": [ 727 | "Uploading the model to Hugging Face lets you conduct subsequent tests or share your model with others (to proceed with this step, you’ll need an active Hugging Face
account)." 728 | ] 729 | }, 730 | { 731 | "cell_type": "raw", 732 | "id": "3a4cb5c0-0db9-4459-bfbe-91cfb1dedce6", 733 | "metadata": {}, 734 | "source": [ 735 | "from huggingface_hub import login\n", 736 | "# You need to use your Hugging Face access token\n", 737 | "login(\"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\")\n", 738 | "# Push the model to Hugging Face. This takes a few minutes; the time depends on the model size\n", 739 | "# and your network speed.\n", 740 | "model.push_to_hub(new_model_name, use_temp_dir=False)\n", 741 | "tokenizer.push_to_hub(new_model_name, use_temp_dir=False)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "markdown", 746 | "id": "cdcfc34f-14e5-4aa2-b820-a428a993da06", 747 | "metadata": {}, 748 | "source": [ 749 | "Now you can test with both the base model (original) and your fine-tuned model." 750 | ] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "id": "dd52a8df-4841-4b0f-80ba-e868e2a76db0", 755 | "metadata": {}, 756 | "source": [ 757 | "- Base model:" 758 | ] 759 | }, 760 | { 761 | "cell_type": "raw", 762 | "id": "fdf6b408-deca-4e08-833c-c860e2d6b76c", 763 | "metadata": {}, 764 | "source": [ 765 | "# Generate text using base model\n", 766 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 767 | "text_gen = pipeline(task=\"text-generation\", model=base_model_name, tokenizer=llama_tokenizer, max_length=200)\n", 768 | "output = text_gen(f\"[INST] {query} [/INST]\")\n", 769 | "print(output[0]['generated_text'])" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "id": "dfb8190f-4c40-48ed-baee-525472f965f0", 775 | "metadata": {}, 776 | "source": [ 777 | "- Fine-tuned model:" 778 | ] 779 | }, 780 | { 781 | "cell_type": "raw", 782 | "id": "f729432f-d499-4878-aac7-bfea303ca6d6", 783 | "metadata": {}, 784 | "source": [ 785 | "# Generate text using fine-tuned model\n", 786 | "query = \"What do you think is the most important part of building an AI chatbot?\"\n", 787 | "text_gen = pipeline(task=\"text-generation\", model=new_model_name, tokenizer=llama_tokenizer, max_length=200)\n", 788 | "output = text_gen(f\"[INST] {query} [/INST]\")\n", 789 | "print(output[0]['generated_text'])" 790 | ] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "id": "c9fb081a-cf75-4f02-89d4-74e12124dc36", 795 | "metadata": {}, 796 | "source": [ 797 | "You can observe the outputs of the two models for a given query. The outputs exhibit slight differences because the fine-tuning process altered the model weights."
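For a slightly more systematic comparison than a single query, you can loop a few prompts through both models (a sketch, not part of the original notebook; it reuses `base_model_name`, `new_model_name`, and `llama_tokenizer` from above, and the second prompt is made up for illustration):

```python
# Hypothetical comparison helper: run the same prompts through the base and the
# fine-tuned model and print the generations side by side.
from transformers import pipeline

prompts = [
    "What do you think is the most important part of building an AI chatbot?",
    "Summarize what LoRA fine-tuning changes about a base model.",
]
for model_path in (base_model_name, new_model_name):
    text_gen = pipeline(task="text-generation", model=model_path,
                        tokenizer=llama_tokenizer, max_length=200)
    for query in prompts:
        print(f"===== {model_path} =====")
        print(text_gen(f"[INST] {query} [/INST]")[0]["generated_text"])
```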
798 | ] 799 | } 800 | ], 801 | "metadata": { 802 | "kernelspec": { 803 | "display_name": "Python 3 (ipykernel)", 804 | "language": "python", 805 | "name": "python3" 806 | }, 807 | "language_info": { 808 | "codemirror_mode": { 809 | "name": "ipython", 810 | "version": 3 811 | }, 812 | "file_extension": ".py", 813 | "mimetype": "text/x-python", 814 | "name": "python", 815 | "nbconvert_exporter": "python", 816 | "pygments_lexer": "ipython3", 817 | "version": "3.9.19" 818 | } 819 | }, 820 | "nbformat": 4, 821 | "nbformat_minor": 5 822 | } 823 | -------------------------------------------------------------------------------- /training/run_lora.py: -------------------------------------------------------------------------------- 1 | # https://rocm.blogs.amd.com/artificial-intelligence/llama2-lora/README.html 2 | 3 | 4 | #Step1 5 | 6 | 7 | import torch 8 | from datasets import load_dataset 9 | from transformers import ( 10 | AutoModelForCausalLM, 11 | AutoTokenizer, 12 | BitsAndBytesConfig, 13 | TrainingArguments, 14 | pipeline 15 | ) 16 | from peft import LoraConfig 17 | from trl import SFTTrainer 18 | 19 | # Step 2: Configuring the model and data 20 | # Model and tokenizer names 21 | #base_model_name = "NousResearch/Llama-2-7b-chat-hf" 22 | base_model_name = "/DATA/NousResearch/Llama-2-7b-hf/" 23 | new_model_name = "llama-2-7b-enhanced" #You can give your own name for fine tuned model 24 | 25 | # Tokenizer 26 | llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) 27 | llama_tokenizer.pad_token = llama_tokenizer.eos_token 28 | llama_tokenizer.padding_side = "right" 29 | 30 | # Model 31 | base_model = AutoModelForCausalLM.from_pretrained( 32 | base_model_name, 33 | device_map="auto" 34 | ) 35 | base_model.config.use_cache = False 36 | base_model.config.pretraining_tp = 1 37 | 38 | # Data set 39 | data_name = "mlabonne/guanaco-llama2-1k" 40 | training_data = load_dataset(data_name, split="train") 41 | # check the data 42 | print(training_data.shape) 43 | # #11 is a QA sample in English 44 | print(training_data[11]) 45 | 46 | #Step 3: Start fine-tuning 47 | # Training Params 48 | train_params = TrainingArguments( 49 | output_dir="./results_modified", 50 | num_train_epochs=1, 51 | #per_device_train_batch_size=4, 52 | per_device_train_batch_size=1, 53 | gradient_accumulation_steps=1, 54 | optim="paged_adamw_32bit", 55 | save_steps=50, 56 | logging_steps=50, 57 | learning_rate=4e-5, 58 | weight_decay=0.001, 59 | fp16=False, 60 | bf16=False, 61 | max_grad_norm=0.3, 62 | max_steps=-1, 63 | warmup_ratio=0.03, 64 | group_by_length=True, 65 | lr_scheduler_type="constant", 66 | report_to="tensorboard" 67 | ) 68 | 69 | from peft import get_peft_model 70 | # LoRA Config 71 | peft_parameters = LoraConfig( 72 | lora_alpha=8, 73 | lora_dropout=0.1, 74 | r=8, 75 | bias="none", 76 | task_type="CAUSAL_LM" 77 | ) 78 | model = get_peft_model(base_model, peft_parameters) 79 | model.print_trainable_parameters() 80 | 81 | # Trainer with LoRA configuration 82 | fine_tuning = SFTTrainer( 83 | model=base_model, 84 | train_dataset=training_data, 85 | peft_config=peft_parameters, 86 | dataset_text_field="text", 87 | tokenizer=llama_tokenizer, 88 | args=train_params 89 | ) 90 | 91 | # Training 92 | fine_tuning.train() 93 | 94 | # Save Model 95 | fine_tuning.model.save_pretrained(new_model_name) 96 | -------------------------------------------------------------------------------- /training/run_qlora_bs4.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | from datasets import load_dataset 3 | from transformers import ( 4 | AutoModelForCausalLM, 5 | AutoTokenizer, 6 | BitsAndBytesConfig, 7 | TrainingArguments, 8 | pipeline 9 | ) 10 | from peft import LoraConfig 11 | from trl import SFTTrainer 12 | 13 | 14 | # Model and tokenizer names 15 | base_model_name = "/DATA/NousResearch/Llama-2-7b-chat-hf/" 16 | new_model_name = "llama-2-7b-chat-enhanced" #You can give your own name for fine tuned model 17 | 18 | # Tokenizer 19 | llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) 20 | llama_tokenizer.pad_token = llama_tokenizer.eos_token 21 | llama_tokenizer.padding_side = "right" 22 | 23 | 24 | # Quantization Config 25 | quant_config = BitsAndBytesConfig( 26 | load_in_4bit=True, 27 | bnb_4bit_quant_type="nf4", 28 | bnb_4bit_compute_dtype=torch.float16, 29 | bnb_4bit_use_double_quant=True 30 | ) 31 | 32 | 33 | base_model = AutoModelForCausalLM.from_pretrained( 34 | base_model_name, 35 | quantization_config=quant_config, 36 | device_map="auto" 37 | ) 38 | base_model.config.use_cache = False 39 | base_model.config.pretraining_tp = 1 40 | 41 | 42 | # Dataset 43 | data_name = "mlabonne/guanaco-llama2-1k" 44 | training_data = load_dataset(data_name, split="train") 45 | # check the data 46 | print(training_data.shape) 47 | # #11 is a QA sample in English 48 | print(training_data[11]) 49 | 50 | 51 | # Training Params 52 | train_params = TrainingArguments( 53 | output_dir="./results_modified", 54 | num_train_epochs=1, 55 | per_device_train_batch_size=4, 56 | gradient_accumulation_steps=1, 57 | optim="paged_adamw_32bit", 58 | save_steps=50, 59 | logging_steps=50, 60 | learning_rate=2e-4, 61 | weight_decay=0.001, 62 | fp16=False, 63 | bf16=False, 64 | max_grad_norm=0.3, 65 | max_steps=-1, 66 | warmup_ratio=0.03, 67 | group_by_length=True, 68 | lr_scheduler_type="constant", 69 | report_to="tensorboard" 70 | ) 71 | 72 | 73 | from peft import get_peft_model 74 | # LoRA Config 75 | peft_parameters = LoraConfig( 76 | lora_alpha=8, 77 | lora_dropout=0.1, 78 | r=8, 79 | bias="none", 80 | task_type="CAUSAL_LM" 81 | ) 82 | model = get_peft_model(base_model, peft_parameters) 83 | model.print_trainable_parameters() 84 | 85 | 86 | # Trainer with LoRA configuration 87 | fine_tuning = SFTTrainer( 88 | model=base_model, 89 | train_dataset=training_data, 90 | peft_config=peft_parameters, 91 | dataset_text_field="text", 92 | tokenizer=llama_tokenizer, 93 | args=train_params 94 | ) 95 | 96 | # Training 97 | fine_tuning.train() 98 | 99 | # Save Model 100 | fine_tuning.model.save_pretrained(new_model_name) 101 | --------------------------------------------------------------------------------