├── .github ├── FUNDING.yml └── ISSUE_TEMPLATE │ ├── bug---issue.md │ └── feature-request.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── images ├── Assistant.png ├── Colab.png ├── Discord button.png ├── Discord.png ├── Documentation Button.png ├── Free version button.png ├── Kaggle.png ├── Kofi button.png ├── LAION 2GPU.png ├── Merge.png ├── Run.png ├── Slim Orca 2GPUs.png ├── Terminal_Type.png ├── Where_Terminal.png ├── buy me a coffee button.png ├── documentation github button.png ├── documentation green button.png ├── documentation lighter.png ├── documentation white button.png ├── made with unsloth.png ├── ollama.png ├── peft x trl button.png ├── start free finetune button.png ├── unsloth end.png ├── unsloth loading page render.png ├── unsloth logo black text.png ├── unsloth logo only.png ├── unsloth logo white text.png ├── unsloth made with love.png ├── unsloth new logo.png └── unsloth sticker.png ├── pyproject.toml ├── tests ├── __init__.py ├── qlora │ ├── README.md │ ├── test_hf_qlora_train_and_merge.py │ └── test_unsloth_qlora_train_and_merge.py ├── saving │ └── test_unsloth_save.py ├── test_model_registry.py └── utils │ ├── __init__.py │ ├── data_utils.py │ └── hf_utils.py ├── unsloth-cli.py └── unsloth ├── __init__.py ├── _auto_install.py ├── chat_templates.py ├── dataprep ├── __init__.py ├── synthetic.py └── synthetic_configs.py ├── kernels ├── __init__.py ├── cross_entropy_loss.py ├── fast_lora.py ├── flex_attention.py ├── geglu.py ├── layernorm.py ├── moe │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── benchmark │ │ ├── benchmark_fused_moe.py │ │ └── utils.py │ ├── grouped_gemm │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── interface.py │ │ ├── kernels │ │ │ ├── __init__.py │ │ │ ├── autotuning.py │ │ │ ├── backward.py │ │ │ ├── forward.py │ │ │ └── tuning.py │ │ └── reference │ │ │ ├── __init__.py │ │ │ ├── layers │ │ │ ├── llama4_moe.py │ │ │ └── qwen3_moe.py │ │ │ ├── moe_block.py │ │ │ └── moe_ops.py │ ├── requirements.txt │ └── tests │ │ ├── __init__.py │ │ ├── common.py │ │ ├── moe_utils.py │ │ ├── run_qwen3_moe_tests.sh │ │ ├── test_grouped_gemm.py │ │ ├── test_llama4_moe.py │ │ └── test_qwen3_moe.py ├── rms_layernorm.py ├── rope_embedding.py ├── swiglu.py └── utils.py ├── models ├── __init__.py ├── _utils.py ├── cohere.py ├── dpo.py ├── gemma.py ├── gemma2.py ├── granite.py ├── llama.py ├── llama4.py ├── loader.py ├── loader_utils.py ├── mapper.py ├── mistral.py ├── qwen2.py ├── qwen3.py ├── qwen3_moe.py ├── rl.py ├── rl_replacements.py └── vision.py ├── registry ├── REGISTRY.md ├── __init__.py ├── _deepseek.py ├── _gemma.py ├── _llama.py ├── _mistral.py ├── _phi.py ├── _qwen.py └── registry.py ├── save.py ├── tokenizer_utils.py ├── trainer.py └── utils ├── __init__.py └── hf_hub.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: unsloth 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | 
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug---issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug / Issue 3 | about: Bug / Issue 4 | title: "[Bug]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | 1. Did you update? `pip install --upgrade unsloth unsloth_zoo` 11 | 2. `Colab` or `Kaggle` or local / cloud 12 | 3. Number GPUs used, use `nvidia-smi` 13 | 4. Which notebook? 14 | 5. Paste `Unsloth` printout with :sloth: sloth emoji 15 | 6. Which trainer? `SFTTrainer`, `GRPOTrainer` etc 16 | 7. **Minimal code to reproduce error Remove Hugging Face token!** 17 | 18 | You can also join our Discord: https://discord.com/invite/unsloth 19 | Have you tried visiting our Docs? https://docs.unsloth.ai/basics/errors-troubleshooting 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: New features, model support, ideas 4 | title: "[Feature]" 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | For new models, have you tried: 11 | ```python 12 | from unsloth import FastModel 13 | model, tokenizer = FastModel.from_pretrained( 14 | "microsoft/Phi-4-multimodal-instruct", 15 | trust_remote_code = True, 16 | ) 17 | from transformers import AutoModelForSequenceClassification 18 | model, tokenizer = FastModel.from_pretrained( 19 | auto_model = AutoModelForSequenceClassification, 20 | ) 21 | ``` 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | .vscode 176 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | feedback@huggingface.co. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 
69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # 🦥 Contributing to Unsloth 2 | 3 | Thank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others or just by simply spreading the word of Unsloth! 
💕 4 | 5 | - **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions. 6 | - **Fix Bugs**: Identify and resolve issues with the existing codebase. 7 | - **Submit Ideas**: Request new features or share enhancements you'd like to see. 8 | - **Develop Features**: Implement new functionality or improve existing tools via pull requests (PRs). 9 | - **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity. 10 | 11 | One of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or on social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟 12 | 13 | ## Submitting Issues 14 | If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out: 15 | 16 | ### Reporting Bugs 17 | 1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues. 18 | 2. **Details Matter**: Is this on Google Colab, Kaggle, or another platform or cloud service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful. 19 | 3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution. 20 | 21 | ## Spread the Word 22 | Your support extends beyond code: 23 | - Spread the word by writing about Unsloth in blog posts or on social media. 24 | - Share how Unsloth powers your projects. 25 | - Star our repository to show your appreciation. 26 | 27 | Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/blob/main/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone. 28 | 29 | Thank you so much for reading, and we hope you have lots of fun using Unsloth!
🦥 30 | -------------------------------------------------------------------------------- /images/Assistant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Assistant.png -------------------------------------------------------------------------------- /images/Colab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Colab.png -------------------------------------------------------------------------------- /images/Discord button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Discord button.png -------------------------------------------------------------------------------- /images/Discord.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Discord.png -------------------------------------------------------------------------------- /images/Documentation Button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Documentation Button.png -------------------------------------------------------------------------------- /images/Free version button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Free version button.png -------------------------------------------------------------------------------- /images/Kaggle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Kaggle.png -------------------------------------------------------------------------------- /images/Kofi button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Kofi button.png -------------------------------------------------------------------------------- /images/LAION 2GPU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/LAION 2GPU.png -------------------------------------------------------------------------------- /images/Merge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Merge.png -------------------------------------------------------------------------------- /images/Run.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Run.png -------------------------------------------------------------------------------- /images/Slim Orca 2GPUs.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Slim Orca 2GPUs.png -------------------------------------------------------------------------------- /images/Terminal_Type.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Terminal_Type.png -------------------------------------------------------------------------------- /images/Where_Terminal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/Where_Terminal.png -------------------------------------------------------------------------------- /images/buy me a coffee button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/buy me a coffee button.png -------------------------------------------------------------------------------- /images/documentation github button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/documentation github button.png -------------------------------------------------------------------------------- /images/documentation green button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/documentation green button.png -------------------------------------------------------------------------------- /images/documentation lighter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/documentation lighter.png -------------------------------------------------------------------------------- /images/documentation white button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/documentation white button.png -------------------------------------------------------------------------------- /images/made with unsloth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/made with unsloth.png -------------------------------------------------------------------------------- /images/ollama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/ollama.png -------------------------------------------------------------------------------- /images/peft x trl button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/peft x trl button.png -------------------------------------------------------------------------------- /images/start free finetune button.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/start free finetune button.png -------------------------------------------------------------------------------- /images/unsloth end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth end.png -------------------------------------------------------------------------------- /images/unsloth loading page render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth loading page render.png -------------------------------------------------------------------------------- /images/unsloth logo black text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth logo black text.png -------------------------------------------------------------------------------- /images/unsloth logo only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth logo only.png -------------------------------------------------------------------------------- /images/unsloth logo white text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth logo white text.png -------------------------------------------------------------------------------- /images/unsloth made with love.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth made with love.png -------------------------------------------------------------------------------- /images/unsloth new logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth new logo.png -------------------------------------------------------------------------------- /images/unsloth sticker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/images/unsloth sticker.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/tests/__init__.py -------------------------------------------------------------------------------- /tests/qlora/README.md: -------------------------------------------------------------------------------- 1 | ## QLoRA Train and Merge Tests 2 | 3 | ### Overview 4 | Tests that performing QLoRA training and merging weights to 16-bits post-training maintains same behavior as trained model. 
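For context, the core idea compared by these tests is merging LoRA adapter weights into a 4-bit quantized base model at 16-bit precision. A minimal, illustrative sketch of that dequantize-then-merge step is shown below; it is not the code the tests use, and the `peft`/`bitsandbytes` attribute names (`base_layer`, `lora_A`, `lora_B`, `scaling`) are assumptions about the standard LoRA layer layout:

```python
# Illustrative sketch only: fold a LoRA adapter into a bnb 4-bit base layer at 16-bit.
import torch
from bitsandbytes.functional import dequantize_4bit


def merge_lora_layer_to_16bit(lora_layer, adapter_name="default", dtype=torch.bfloat16):
    base = lora_layer.base_layer                      # assumed bnb.nn.Linear4bit
    lora_A = lora_layer.lora_A[adapter_name].weight   # (r, in_features)
    lora_B = lora_layer.lora_B[adapter_name].weight   # (out_features, r)
    scaling = lora_layer.scaling[adapter_name]        # lora_alpha / r

    # Dequantize the NF4 base weight to a dense tensor, then merge in fp32.
    W = dequantize_4bit(base.weight.data, base.weight.quant_state).to(torch.float32)
    W += (lora_B.to(torch.float32) @ lora_A.to(torch.float32)) * scaling

    # Replace the LoRA-wrapped layer with a plain 16-bit Linear holding the merged weight.
    merged = torch.nn.Linear(W.shape[1], W.shape[0], bias=base.bias is not None)
    merged.weight.data = W.to(dtype)
    if base.bias is not None:
        merged.bias.data = base.bias.data.to(dtype)
    return merged
```

Doing this once per LoRA layer, rather than re-quantizing the merged weight back to 4-bit as `merge_and_unload` does, is what keeps the merged 16-bit model's responses consistent with the trained adapter.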
5 | 6 | - `test_unsloth_qlora_train_and_merge.py`: Tests Unsloth QLoRA train and merge using the `FastLanguageModel.from_pretrained`, `FastLanguageModel.get_peft_model`, and `FastLanguageModel.save_pretrained_merged` APIs. 7 | - `test_hf_qlora_train_and_merge.py`: Tests Hugging Face QLoRA train and merge using the `from_pretrained`, `get_peft_model`, and `merge_and_unload` APIs. 8 | - Demonstrates that `peft`'s `merge_and_unload` results in a loss of accuracy, as it requantizes the base layer after merging the adapter weights, so the model still contains `Linear4bit` layers post-merging. 9 | - I (@jeromeku) implemented a custom merge function that replaces all `LoraLayers` with `Linear` layers whose weights are the dequantized base layer weights with the adapter weights merged in (compute done in fp32, cast to the original dtype after merging), roughly equivalent to `FastLanguageModel.save_pretrained_merged`. 10 | 11 | ### Usage 12 | Run the Unsloth test: 13 | ```bash 14 | python tests/qlora/test_unsloth_qlora_train_and_merge.py 15 | ``` 16 | Run the Hugging Face test: 17 | ```bash 18 | python tests/qlora/test_hf_qlora_train_and_merge.py 19 | ``` 20 | 21 | ### Details 22 | The tests train a QLoRA model on a single-prompt dataset: 23 | ``` 24 | QUESTION = "What day was I born?" 25 | ANSWER = "January 1, 2058" 26 | USER_MESSAGE = {"role": "user", "content": QUESTION} 27 | ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} 28 | ``` 29 | 30 | Given that the question is impossible to answer accurately without finetuning, we can only expect the model to answer it correctly if the model has been trained on the question. 31 | 32 | To verify this behavior, we compare the model's response to the question before training, after training, and after merging, and confirm that the response contains the answer after training and after merging, but not before training. 33 | 34 | ### Results 35 | 36 | For the Unsloth test, the model's behavior is as expected: 37 | - before training, the model's response does not contain the answer 38 | - after training, the model's response contains the answer 39 | - after merging, the model's response contains the answer 40 | 41 | For the Hugging Face test, the model's behavior is as expected: 42 | - before training, the model's response does not contain the answer 43 | - after training, the model's response contains the answer 44 | - after using peft's `merge_and_unload`, the model's response does not contain the answer 45 | - after using my custom merge function, the model's response contains the answer 46 | 47 | The scripts should print the training params and training logs, as well as the model responses before and after training and after merging (model responses are only printed when the answer is not contained in the response). -------------------------------------------------------------------------------- /tests/qlora/test_hf_qlora_train_and_merge.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # ruff: noqa 16 | import sys 17 | from pathlib import Path 18 | 19 | REPO_ROOT = Path(__file__).parents[2] 20 | sys.path.append(str(REPO_ROOT)) 21 | 22 | import itertools 23 | from copy import deepcopy 24 | 25 | import torch 26 | from datasets import Dataset 27 | from trl import SFTConfig 28 | from tests.utils import header_footer_context 29 | from tests.utils.data_utils import ( 30 | ANSWER, 31 | DEFAULT_MESSAGES, 32 | USER_MESSAGE, 33 | check_responses, 34 | create_dataset, 35 | describe_peft_weights, 36 | ) 37 | from tests.utils.hf_utils import ( 38 | convert_lora_to_linear, 39 | fix_llama3_tokenizer, 40 | get_peft_config, 41 | sample_responses, 42 | setup_model, 43 | setup_tokenizer, 44 | setup_trainer, 45 | ) 46 | 47 | if __name__ == "__main__": 48 | model_name = "meta-llama/Llama-3.2-1B-Instruct" 49 | dtype = torch.bfloat16 50 | max_steps = 100 51 | num_examples = 1000 52 | lora_rank = 64 53 | output_dir = "sft_test" 54 | seed = 42 55 | batch_size = 5 56 | num_generations = 5 57 | tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer]) 58 | temperature = 0.8 59 | max_new_tokens = 20 60 | 61 | peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear") 62 | model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config) 63 | 64 | prompt = tokenizer.apply_chat_template( 65 | [USER_MESSAGE], tokenize=False, add_generation_prompt=True 66 | ) 67 | with header_footer_context("Test Prompt and Answer"): 68 | print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") 69 | 70 | dataset: Dataset = create_dataset( 71 | tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES 72 | ) 73 | with header_footer_context("Dataset"): 74 | print(f"Dataset: {next(iter(dataset))}") 75 | 76 | training_args = SFTConfig( 77 | output_dir=output_dir, 78 | max_steps=max_steps, 79 | per_device_train_batch_size=batch_size, 80 | log_level="info", 81 | report_to="none", 82 | num_train_epochs=1, 83 | logging_steps=1, 84 | seed=seed, 85 | bf16=dtype == torch.bfloat16, 86 | fp16=dtype == torch.float16, 87 | save_strategy="no", 88 | ) 89 | 90 | with header_footer_context("Train Args"): 91 | print(training_args) 92 | print(peft_config) 93 | 94 | trainer = setup_trainer( 95 | model, tokenizer, dataset, training_args, peft_config=peft_config 96 | ) 97 | 98 | with header_footer_context("Model"): 99 | print(type(model.model)) 100 | 101 | generation_args = { 102 | "num_generations": num_generations, 103 | "max_new_tokens": max_new_tokens, 104 | "temperature": temperature, 105 | "skip_special_tokens": False, 106 | "dtype": dtype, 107 | } 108 | responses = sample_responses( 109 | model, 110 | tokenizer, 111 | prompt=prompt, 112 | **generation_args, 113 | ) 114 | with header_footer_context("Responses before training"): 115 | check_responses(responses, answer=ANSWER, prompt=prompt) 116 | 117 | with header_footer_context("Peft Weights before training"): 118 | for name, stats in itertools.islice(describe_peft_weights(model), 2): 119 | print(f"{name}:\n{stats}") 120 | 121 | output = trainer.train() 122 | with header_footer_context("Peft Weights after training"): 123 | for name, stats in itertools.islice(describe_peft_weights(model), 2): 124 | print(f"{name}:\n{stats}") 125 | 126 | with header_footer_context("Trainer Output"): 127 | print(output) 128 | 129 | responses = sample_responses( 130 | model, 131 | tokenizer, 132 | prompt=prompt, 133 | 
**generation_args, 134 | ) 135 | with header_footer_context("Responses after training"): 136 | check_responses(responses, answer=ANSWER, prompt=prompt) 137 | 138 | model_copy = deepcopy(model) 139 | 140 | merged_model = convert_lora_to_linear(model) 141 | 142 | responses = sample_responses( 143 | merged_model, 144 | tokenizer, 145 | prompt=prompt, 146 | **generation_args, 147 | ) 148 | with header_footer_context("Responses after custom merging to 16bit"): 149 | check_responses(responses, answer=ANSWER, prompt=prompt) 150 | 151 | merged_model_peft = model_copy.merge_and_unload() 152 | responses = sample_responses( 153 | merged_model_peft, 154 | tokenizer, 155 | prompt=prompt, 156 | **generation_args, 157 | ) 158 | with header_footer_context("Responses after peft merge_and_unload"): 159 | check_responses(responses, answer=ANSWER, prompt=prompt) 160 | -------------------------------------------------------------------------------- /tests/qlora/test_unsloth_qlora_train_and_merge.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # ruff: noqa 16 | import sys 17 | from pathlib import Path 18 | 19 | REPO_ROOT = Path(__file__).parents[2] 20 | sys.path.append(str(REPO_ROOT)) 21 | 22 | import itertools 23 | from unsloth import FastLanguageModel 24 | 25 | import torch 26 | from datasets import Dataset 27 | from trl import SFTConfig 28 | from tests.utils import header_footer_context 29 | from tests.utils.data_utils import ( 30 | DEFAULT_MESSAGES, 31 | USER_MESSAGE, 32 | ANSWER, 33 | create_dataset, 34 | describe_peft_weights, 35 | check_responses, 36 | ) 37 | from tests.utils.hf_utils import ( 38 | sample_responses, 39 | setup_trainer, 40 | ) 41 | 42 | 43 | def get_unsloth_model_and_tokenizer( 44 | model_name: str, 45 | max_seq_length: int, 46 | load_in_4bit: bool, 47 | fast_inference: bool, 48 | max_lora_rank: int = None, 49 | gpu_memory_utilization: float = 0.5, 50 | dtype: torch.dtype = torch.bfloat16, 51 | ): 52 | return FastLanguageModel.from_pretrained( 53 | model_name=model_name, 54 | max_seq_length=max_seq_length, 55 | load_in_4bit=load_in_4bit, 56 | fast_inference=fast_inference, 57 | max_lora_rank=max_lora_rank, 58 | gpu_memory_utilization=gpu_memory_utilization, 59 | dtype=dtype, 60 | ) 61 | 62 | 63 | def get_unsloth_peft_model( 64 | model, 65 | lora_rank: int, 66 | target_modules: list[str] = "all-linear", 67 | use_gradient_checkpointing: str = False, 68 | random_state: int = 42, 69 | ): 70 | return FastLanguageModel.get_peft_model( 71 | model, 72 | r=lora_rank, 73 | target_modules=target_modules, 74 | lora_alpha=lora_rank, 75 | use_gradient_checkpointing=use_gradient_checkpointing, 76 | random_state=random_state, 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | model_name = "meta-llama/Llama-3.2-1B-Instruct" 82 | dtype = torch.bfloat16 83 | max_steps = 100 84 | num_examples 
= 1000 85 | lora_rank = 64 86 | output_dir = "sft_test" 87 | seed = 42 88 | batch_size = 5 89 | num_generations = 5 90 | target_modules = [ 91 | "q_proj", 92 | "k_proj", 93 | "v_proj", 94 | "o_proj", 95 | "gate_proj", 96 | "up_proj", 97 | "down_proj", 98 | ] 99 | gradient_checkpointing = False 100 | unsloth_merged_path = "unsloth_merged_16bit" 101 | 102 | model, tokenizer = get_unsloth_model_and_tokenizer( 103 | model_name, 104 | max_seq_length=512, 105 | load_in_4bit=True, 106 | fast_inference=False, 107 | max_lora_rank=lora_rank, 108 | dtype=dtype, 109 | ) 110 | temperature = 0.8 111 | max_new_tokens = 20 112 | 113 | model = get_unsloth_peft_model( 114 | model, 115 | lora_rank=lora_rank, 116 | target_modules=target_modules, 117 | use_gradient_checkpointing=gradient_checkpointing, 118 | random_state=seed, 119 | ) 120 | 121 | prompt = tokenizer.apply_chat_template( 122 | [USER_MESSAGE], tokenize=False, add_generation_prompt=True 123 | ) 124 | 125 | with header_footer_context("Test Prompt and Answer"): 126 | print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") 127 | 128 | dataset: Dataset = create_dataset( 129 | tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES 130 | ) 131 | with header_footer_context("Dataset"): 132 | print(f"Dataset: {next(iter(dataset))}") 133 | 134 | training_args = SFTConfig( 135 | output_dir=output_dir, 136 | max_steps=max_steps, 137 | per_device_train_batch_size=batch_size, 138 | log_level="info", 139 | report_to="none", 140 | num_train_epochs=1, 141 | logging_steps=1, 142 | seed=seed, 143 | bf16=dtype == torch.bfloat16, 144 | fp16=dtype == torch.float16, 145 | save_strategy="no", 146 | ) 147 | 148 | with header_footer_context("Train Args"): 149 | print(training_args) 150 | 151 | trainer = setup_trainer(model, tokenizer, dataset, training_args) 152 | 153 | with header_footer_context("Model"): 154 | print(type(model.model)) 155 | 156 | generation_args = { 157 | "num_generations": num_generations, 158 | "max_new_tokens": max_new_tokens, 159 | "temperature": temperature, 160 | "skip_special_tokens": False, 161 | "dtype": dtype, 162 | } 163 | responses = sample_responses( 164 | model, 165 | tokenizer, 166 | prompt=prompt, 167 | **generation_args, 168 | ) 169 | with header_footer_context("Responses before training"): 170 | check_responses(responses, answer=ANSWER, prompt=prompt) 171 | with header_footer_context("Peft Weights before training"): 172 | for name, stats in itertools.islice(describe_peft_weights(model), 2): 173 | print(f"{name}:\n{stats}") 174 | 175 | output = trainer.train() 176 | with header_footer_context("Peft Weights after training"): 177 | for name, stats in itertools.islice(describe_peft_weights(model), 2): 178 | print(f"{name}:\n{stats}") 179 | 180 | with header_footer_context("Trainer Output"): 181 | print(output) 182 | 183 | responses = sample_responses( 184 | model, 185 | tokenizer, 186 | prompt=prompt, 187 | **generation_args, 188 | ) 189 | with header_footer_context("Responses after training"): 190 | check_responses(responses, answer=ANSWER, prompt=prompt) 191 | 192 | model.save_pretrained_merged( 193 | unsloth_merged_path, 194 | tokenizer, 195 | save_method="merged_16bit", 196 | ) 197 | merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer( 198 | unsloth_merged_path, 199 | max_seq_length=512, 200 | load_in_4bit=False, 201 | fast_inference=False, 202 | dtype=dtype, 203 | ) 204 | responses = sample_responses( 205 | merged_model_unsloth, 206 | tokenizer, 207 | prompt=prompt, 208 | **generation_args, 209 | ) 210 | 
with header_footer_context("Responses after unsloth merge to 16bit"): 211 | check_responses(responses, answer=ANSWER, prompt=prompt) 212 | -------------------------------------------------------------------------------- /tests/saving/test_unsloth_save.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import pytest 6 | 7 | from unsloth import FastLanguageModel, FastModel 8 | 9 | model_to_test = [ 10 | # Text Models 11 | "unsloth/tinyllama", 12 | "unsloth/tinyllama-bnb-4bit", 13 | "unsloth/Qwen2.5-0.5B-Instruct", 14 | "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit", 15 | "unsloth/Phi-4-mini-instruct", 16 | "unsloth/Phi-4-mini-instruct-bnb-4bit", 17 | "unsloth/Qwen2.5-0.5B", 18 | # Vision Models 19 | "unsloth/gemma-3-1b-it", 20 | "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", 21 | "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit" 22 | ] 23 | 24 | # Variables 25 | save_file_sizes = {} 26 | save_file_sizes["merged_16bit"] = {} 27 | save_file_sizes["merged_4bit"] = {} 28 | 29 | tokenizer_files = [ 30 | "tokenizer_config.json", 31 | "special_tokens_map.json", 32 | ] 33 | 34 | @pytest.fixture(scope="session", params=model_to_test) 35 | def loaded_model_tokenizer(request): 36 | model_name = request.param 37 | print("Loading model and tokenizer...") 38 | 39 | model, tokenizer = FastModel.from_pretrained( 40 | model_name, # use small model 41 | max_seq_length=128, 42 | dtype=None, 43 | load_in_4bit=True, 44 | ) 45 | 46 | # Apply LoRA 47 | model = FastModel.get_peft_model( 48 | model, 49 | r=16, 50 | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 51 | lora_alpha=16, 52 | use_gradient_checkpointing="unsloth", 53 | ) 54 | 55 | return model, tokenizer 56 | 57 | @pytest.fixture(scope="session") 58 | def model(loaded_model_tokenizer): 59 | return loaded_model_tokenizer[0] 60 | 61 | @pytest.fixture(scope="session") 62 | def tokenizer(loaded_model_tokenizer): 63 | return loaded_model_tokenizer[1] 64 | 65 | @pytest.fixture 66 | def temp_save_dir(): 67 | dir = tempfile.mkdtemp() 68 | print(f"Temporary directory created at: {dir}") 69 | yield dir 70 | print(f"Temporary directory deleted: {dir}") 71 | shutil.rmtree(dir) 72 | 73 | 74 | def delete_quantization_config(model): 75 | # Since merged, edit quantization_config 76 | old_config = model.config 77 | new_config = model.config.to_dict() 78 | if "quantization_config" in new_config: 79 | del new_config["quantization_config"] 80 | original_model = model 81 | new_config = type(model.config).from_dict(new_config) 82 | while hasattr(original_model, "model"): 83 | original_model = original_model.model 84 | original_model.config = new_config 85 | model.config = new_config 86 | 87 | def test_save_merged_16bit(model, tokenizer, temp_save_dir: str): 88 | save_path = os.path.join(temp_save_dir, "unsloth_merged_16bit", model.config._name_or_path.replace("/", "_")) 89 | 90 | model.save_pretrained_merged( 91 | save_path, 92 | tokenizer=tokenizer, 93 | save_method="merged_16bit" 94 | ) 95 | 96 | # Check model files 97 | assert os.path.isdir(save_path), f"Directory {save_path} does not exist." 98 | assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found." 99 | 100 | weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")] 101 | assert len(weight_files) > 0, "No weight files found in the save directory." 
102 | 103 | # Check tokenizer files 104 | for file in tokenizer_files: 105 | assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory." 106 | 107 | # Check config to see if it is 16bit by checking for quantization config 108 | config_path = os.path.join(save_path, "config.json") 109 | with open(config_path, "r") as f: 110 | config = json.load(f) 111 | 112 | assert "quantization_config" not in config, "Quantization config not found in the model config." 113 | 114 | # Store the size of the model files 115 | total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files) 116 | save_file_sizes["merged_16bit"][model.config._name_or_path] = total_size 117 | print(f"Total size of merged_16bit files: {total_size} bytes") 118 | 119 | # Test loading the model from the saved path 120 | loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained( 121 | save_path, 122 | max_seq_length=128, 123 | dtype=None, 124 | load_in_4bit=True, 125 | ) 126 | 127 | def test_save_merged_4bit(model, tokenizer, temp_save_dir: str): 128 | save_path = os.path.join(temp_save_dir, "unsloth_merged_4bit", model.config._name_or_path.replace("/", "_")) 129 | 130 | model.save_pretrained_merged( 131 | save_path, 132 | tokenizer=tokenizer, 133 | save_method="merged_4bit_forced" 134 | ) 135 | 136 | # Check model files 137 | assert os.path.isdir(save_path), f"Directory {save_path} does not exist." 138 | assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found." 139 | 140 | weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")] 141 | assert len(weight_files) > 0, "No weight files found in the save directory." 142 | 143 | # Check tokenizer files 144 | for file in tokenizer_files: 145 | assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory." 146 | 147 | # Store the size of the model files 148 | total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files) 149 | save_file_sizes["merged_4bit"][model.config._name_or_path] = total_size 150 | 151 | print(f"Total size of merged_4bit files: {total_size} bytes") 152 | 153 | assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "Merged 4bit files are larger than merged 16bit files." 154 | 155 | # Check config to see if it is 4bit 156 | config_path = os.path.join(save_path, "config.json") 157 | with open(config_path, "r") as f: 158 | config = json.load(f) 159 | 160 | assert "quantization_config" in config, "Quantization config not found in the model config." 
161 | 162 | # Test loading the model from the saved path 163 | loaded_model, loaded_tokenizer = FastModel.from_pretrained( 164 | save_path, 165 | max_seq_length=128, 166 | dtype=None, 167 | load_in_4bit=True, 168 | ) 169 | 170 | -------------------------------------------------------------------------------- /tests/test_model_registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Test model registration methods 4 | Checks that model registration methods work for respective models as well as all models 5 | The check is performed 6 | - by registering the models 7 | - checking that the instantiated models can be found on huggingface hub by querying for the model id 8 | 9 | """ 10 | 11 | from dataclasses import dataclass 12 | 13 | import pytest 14 | from huggingface_hub import ModelInfo as HfModelInfo 15 | 16 | from unsloth.registry import register_models, search_models 17 | from unsloth.registry._deepseek import register_deepseek_models 18 | from unsloth.registry._gemma import register_gemma_models 19 | from unsloth.registry._llama import register_llama_models 20 | from unsloth.registry._mistral import register_mistral_models 21 | from unsloth.registry._phi import register_phi_models 22 | from unsloth.registry._qwen import register_qwen_models 23 | from unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType 24 | from unsloth.utils.hf_hub import get_model_info 25 | 26 | MODEL_NAMES = [ 27 | "llama", 28 | "qwen", 29 | "mistral", 30 | "phi", 31 | "gemma", 32 | "deepseek", 33 | ] 34 | MODEL_REGISTRATION_METHODS = [ 35 | register_llama_models, 36 | register_qwen_models, 37 | register_mistral_models, 38 | register_phi_models, 39 | register_gemma_models, 40 | register_deepseek_models, 41 | ] 42 | 43 | 44 | @dataclass 45 | class ModelTestParam: 46 | name: str 47 | register_models: callable 48 | 49 | 50 | def _test_model_uploaded(model_ids: list[str]): 51 | missing_models = [] 52 | for _id in model_ids: 53 | model_info: HfModelInfo = get_model_info(_id) 54 | if not model_info: 55 | missing_models.append(_id) 56 | 57 | return missing_models 58 | 59 | 60 | TestParams = [ 61 | ModelTestParam(name, models) 62 | for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS) 63 | ] 64 | 65 | 66 | # Test that model registration methods register respective models 67 | @pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name) 68 | def test_model_registration(model_test_param: ModelTestParam): 69 | MODEL_REGISTRY.clear() 70 | registration_method = model_test_param.register_models 71 | registration_method() 72 | registered_models = MODEL_REGISTRY.keys() 73 | missing_models = _test_model_uploaded(registered_models) 74 | assert not missing_models, ( 75 | f"{model_test_param.name} missing following models: {missing_models}" 76 | ) 77 | 78 | 79 | def test_all_model_registration(): 80 | register_models() 81 | registered_models = MODEL_REGISTRY.keys() 82 | missing_models = _test_model_uploaded(registered_models) 83 | assert not missing_models, f"Missing following models: {missing_models}" 84 | 85 | def test_quant_type(): 86 | # Test that the quant_type is correctly set for model paths 87 | # NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH 88 | dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH]) 89 | assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models) 90 | quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH] 91 | assert all(quant_tag in 
m.model_path for m in dynamic_quant_models) -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import time 16 | from contextlib import contextmanager 17 | 18 | 19 | @contextmanager 20 | def timer(name): 21 | start = time.time() 22 | yield 23 | end = time.time() 24 | print(f"{name} took {end - start:.2f} seconds") 25 | 26 | 27 | @contextmanager 28 | def header_footer_context(title: str, char="-"): 29 | print() 30 | print(f"{char}" * 50 + f" {title} " + f"{char}" * 50) 31 | yield 32 | print(f"{char}" * (100 + len(title) + 2)) 33 | print() 34 | -------------------------------------------------------------------------------- /tests/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from datasets import Dataset 17 | 18 | QUESTION = "What day was I born?" 
19 | ANSWER = "January 1, 2058" 20 | USER_MESSAGE = {"role": "user", "content": QUESTION} 21 | ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} 22 | DTYPE = torch.bfloat16 23 | DEFAULT_MESSAGES = [[USER_MESSAGE, ASSISTANT_MESSAGE]] 24 | 25 | 26 | def create_instruction_dataset(messages: list[dict] = DEFAULT_MESSAGES): 27 | dataset = Dataset.from_dict({"messages": messages}) 28 | return dataset 29 | 30 | 31 | def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = None): 32 | dataset = create_instruction_dataset(messages) 33 | 34 | def _apply_chat_template(example): 35 | chat = tokenizer.apply_chat_template(example["messages"], tokenize=False) 36 | return {"text": chat} 37 | 38 | dataset = dataset.map(_apply_chat_template, remove_columns="messages") 39 | if num_examples is not None: 40 | if len(dataset) < num_examples: 41 | num_repeats = num_examples // len(dataset) + 1 42 | dataset = dataset.repeat(num_repeats) 43 | dataset = dataset.select(range(num_examples)) 44 | 45 | return dataset 46 | 47 | 48 | def describe_param( 49 | param: torch.Tensor, 50 | include_l1: bool = False, 51 | include_l2: bool = False, 52 | include_infinity: bool = False, 53 | as_str: bool = True, 54 | ) -> dict: 55 | """ 56 | Provide a statistical summary of a 2D weight matrix or tensor. 57 | If as_str is True, the summary is returned as a formatted string. 58 | Parameters: 59 | param: torch.Tensor 60 | include_l1 (bool): Whether to include the L1 norm (sum of absolute values). 61 | include_l2 (bool): Whether to include the L2 norm (Frobenius norm). 62 | include_infinity (bool): Whether to include the infinity norm (max absolute value). 63 | as_str (bool): Whether to return the summary as a formatted string. 64 | 65 | Returns: 66 | dict: A dictionary with the following statistics: 67 | - shape: Dimensions of the matrix. 68 | - mean: Average value. 69 | - median: Median value. 70 | - std: Standard deviation. 71 | - min: Minimum value. 72 | - max: Maximum value. 73 | - percentile_25: 25th percentile. 74 | - percentile_75: 75th percentile. 75 | Additionally, if enabled: 76 | - L1_norm: Sum of absolute values. 77 | - L2_norm: Euclidean (Frobenius) norm. 78 | - infinity_norm: Maximum absolute value. 79 | """ 80 | 81 | param = param.float() 82 | summary = { 83 | "shape": param.shape, 84 | "mean": param.mean().cpu().item(), 85 | "std": param.std().cpu().item(), 86 | "min": param.min().cpu().item(), 87 | "max": param.max().cpu().item(), 88 | "percentile_25": param.quantile(0.25).cpu().item(), 89 | "percentile_50": param.quantile(0.5).cpu().item(), 90 | "percentile_75": param.quantile(0.75).cpu().item(), 91 | } 92 | 93 | if include_l1: 94 | summary["L1_norm"] = param.abs().sum().cpu().item() 95 | if include_l2: 96 | summary["L2_norm"] = param.norm().cpu().item() 97 | if include_infinity: 98 | summary["infinity_norm"] = param.abs().max().cpu().item() 99 | 100 | return format_summary(summary) if as_str else summary 101 | 102 | 103 | def format_summary(stats: dict, precision: int = 6) -> str: 104 | """ 105 | Format the statistical summary dictionary for printing. 106 | 107 | Parameters: 108 | stats (dict): The dictionary returned by describe_param. 109 | precision (int): Number of decimal places for floating point numbers. 110 | 111 | Returns: 112 | str: A formatted string representing the summary. 
113 | """ 114 | lines = [] 115 | for key, value in stats.items(): 116 | if isinstance(value, float): 117 | formatted_value = f"{value:.{precision}f}" 118 | elif isinstance(value, (tuple, list)): 119 | # Format each element in tuples or lists (e.g., the shape) 120 | formatted_value = ", ".join(str(v) for v in value) 121 | formatted_value = ( 122 | f"({formatted_value})" 123 | if isinstance(value, tuple) 124 | else f"[{formatted_value}]" 125 | ) 126 | else: 127 | formatted_value = str(value) 128 | lines.append(f"{key}: {formatted_value}") 129 | return "\n".join(lines) 130 | 131 | 132 | def get_peft_weights(model): 133 | # ruff: noqa 134 | is_lora_weight = lambda name: any(s in name for s in ["lora_A", "lora_B"]) 135 | return { 136 | name: param for name, param in model.named_parameters() if is_lora_weight(name) 137 | } 138 | 139 | 140 | def describe_peft_weights(model): 141 | for name, param in get_peft_weights(model).items(): 142 | yield name, describe_param(param, as_str=True) 143 | 144 | 145 | def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool: 146 | for i, response in enumerate(responses, start=1): 147 | if answer in response: 148 | print(f"\u2713 response {i} contains answer") 149 | else: 150 | print(f"\u2717 response {i} does not contain answer") 151 | if prompt is not None: 152 | response = response.replace(prompt, "") 153 | print(f" -> response: {response}") 154 | -------------------------------------------------------------------------------- /tests/utils/hf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
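The `tests/utils/data_utils.py` helpers above are meant to be composed: `create_dataset` turns the default QA pair into chat-formatted text, `describe_param` summarizes a weight tensor, and `check_responses` verifies that generations contain the expected answer. A rough usage sketch, assuming the repo root is importable; the tokenizer name is an arbitrary example (any tokenizer with a chat template works), and `DEFAULT_MESSAGES` is passed explicitly because the `messages=None` default is forwarded to `create_instruction_dataset` as-is:

```python
# Rough usage sketch for the helpers in tests/utils/data_utils.py above.
import torch
from transformers import AutoTokenizer

from tests.utils.data_utils import (
    ANSWER,
    DEFAULT_MESSAGES,
    check_responses,
    create_dataset,
    describe_param,
)

# Any chat-template tokenizer works; this checkpoint is just an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# One chat-formatted training example built from the default QA pair.
dataset = create_dataset(tokenizer, messages=DEFAULT_MESSAGES)
print(dataset[0]["text"])

# Formatted summary statistics (mean, std, percentiles, optional norms) of a tensor.
print(describe_param(torch.randn(64, 64), include_l2=True))

# Check whether (mock) generations contain the expected answer string.
check_responses(["I think it was January 1, 2058."], answer=ANSWER)
```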
14 | 15 | import os 16 | from contextlib import contextmanager, nullcontext 17 | from typing import Callable, Optional 18 | 19 | import bitsandbytes as bnb 20 | import torch 21 | from bitsandbytes.functional import dequantize_4bit 22 | from peft import get_peft_model, prepare_model_for_kbit_training 23 | from peft.tuners.lora import LoraConfig, LoraLayer 24 | from transformers import ( 25 | AutoModelForCausalLM, 26 | AutoTokenizer, 27 | BitsAndBytesConfig, 28 | ) 29 | from transformers.trainer_callback import ( 30 | TrainerCallback, 31 | TrainerControl, 32 | TrainerState, 33 | TrainingArguments, 34 | ) 35 | from trl import SFTTrainer 36 | 37 | 38 | class PeftWeightCallback(TrainerCallback): 39 | def on_log( 40 | self, 41 | args: TrainingArguments, 42 | state: TrainerState, 43 | control: TrainerControl, 44 | logs, 45 | **kwargs, 46 | ): 47 | print(f"DEBUG::CALLBACK::on_log::{state.log_history}") 48 | 49 | def on_train_begin( 50 | self, 51 | args: TrainingArguments, 52 | state: TrainerState, 53 | control: TrainerControl, 54 | **kwargs, 55 | ): 56 | model = kwargs.get("model") 57 | assert model is not None 58 | print(f"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}") 59 | 60 | def on_step_end( 61 | self, 62 | args: TrainingArguments, 63 | state: TrainerState, 64 | control: TrainerControl, 65 | **kwargs, 66 | ): 67 | print(f"DEBUG::CALLBACK::on_step_end::{state.global_step}") 68 | 69 | 70 | @torch.inference_mode() 71 | def generate_responses( 72 | model, 73 | tokenizer, 74 | prompt, 75 | max_new_tokens: int = 100, 76 | temperature: float = 0.8, 77 | do_sample: bool = True, 78 | num_generations: int = 1, 79 | skip_special_tokens: bool = True, 80 | dtype: torch.dtype = None, 81 | ): 82 | inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)] 83 | keys = inputs[0].keys() 84 | batched_inputs = { 85 | key: torch.cat([input[key] for input in inputs], dim=0).to(model.device) 86 | for key in keys 87 | } 88 | 89 | if dtype is not None: 90 | inference_context = torch.autocast(device_type="cuda", dtype=dtype) 91 | else: 92 | inference_context = nullcontext() 93 | 94 | with inference_context: 95 | outputs = model.generate( 96 | **batched_inputs, 97 | max_new_tokens=max_new_tokens, 98 | do_sample=do_sample, 99 | temperature=temperature, 100 | ) 101 | 102 | responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens) 103 | return responses 104 | 105 | 106 | def sample_responses( 107 | model, 108 | tokenizer, 109 | prompt, 110 | temperature: float = 0.8, 111 | num_generations: int = 1, 112 | max_new_tokens: int = 100, 113 | skip_special_tokens: bool = True, 114 | dtype: torch.dtype = None, 115 | ): 116 | responses = generate_responses( 117 | model, 118 | tokenizer, 119 | prompt, 120 | temperature=temperature, 121 | num_generations=num_generations, 122 | max_new_tokens=max_new_tokens, 123 | skip_special_tokens=skip_special_tokens, 124 | dtype=dtype, 125 | ) 126 | return responses 127 | 128 | 129 | def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []): 130 | tokenizer = AutoTokenizer.from_pretrained(model_name) 131 | for fixup_func in fixup_funcs: 132 | tokenizer = fixup_func(tokenizer) 133 | return tokenizer 134 | 135 | 136 | def setup_model( 137 | model_name, 138 | quantize: bool = True, 139 | dtype=torch.bfloat16, 140 | peft_config=None, 141 | autocast_adapter: bool = True, 142 | ): 143 | if quantize: 144 | bnb_config = BitsAndBytesConfig( 145 | load_in_4bit=True, 146 | bnb_4bit_use_double_quant=True, 147 | bnb_4bit_quant_type="nf4", 148 | 
bnb_4bit_compute_dtype=dtype, 149 | ) 150 | else: 151 | bnb_config = None 152 | 153 | model = AutoModelForCausalLM.from_pretrained( 154 | model_name, 155 | device_map="cuda:0", 156 | attn_implementation="sdpa", 157 | quantization_config=bnb_config, 158 | torch_dtype=dtype, 159 | ) 160 | model = prepare_model_for_kbit_training(model) if quantize else model 161 | 162 | if peft_config is not None: 163 | model = get_peft_model( 164 | model, peft_config, autocast_adapter_dtype=autocast_adapter 165 | ) 166 | 167 | return model 168 | 169 | 170 | def get_peft_config( 171 | lora_rank, 172 | lora_alpha=None, 173 | lora_dropout=0.0, 174 | bias="none", 175 | target_modules="all-linear", 176 | ): 177 | lora_alpha = lora_alpha or 2 * lora_rank 178 | peft_config = LoraConfig( 179 | lora_alpha=lora_alpha, 180 | lora_dropout=lora_dropout, 181 | r=lora_rank, 182 | bias=bias, 183 | target_modules=target_modules, 184 | task_type="CAUSAL_LM", 185 | ) 186 | return peft_config 187 | 188 | 189 | def setup_trainer( 190 | model, 191 | tokenizer, 192 | dataset, 193 | train_args, 194 | peft_config=None, 195 | formatting_func=None, 196 | collator=None, 197 | ): 198 | return SFTTrainer( 199 | model=model, 200 | peft_config=peft_config, 201 | train_dataset=dataset, 202 | processing_class=tokenizer, 203 | formatting_func=formatting_func, 204 | data_collator=collator, 205 | args=train_args, 206 | ) 207 | 208 | 209 | def setup_lora( 210 | model, 211 | tokenizer, 212 | dataset, 213 | peft_config, 214 | train_args, 215 | formatting_func=None, 216 | collator=None, 217 | ): 218 | return SFTTrainer( 219 | model=model, 220 | peft_config=peft_config, 221 | train_dataset=dataset, 222 | processing_class=tokenizer, 223 | formatting_func=formatting_func, 224 | data_collator=collator, 225 | args=train_args, 226 | ) 227 | 228 | 229 | def convert_weights_back_to_dtype(model, dtype): 230 | """ 231 | SFTTrainer calls get_peft_model and prepare_model_for_kbit_training, which convert all weights to float32. 232 | This function converts the non-LoRA weights back to the original dtype.
233 | """ 234 | for name, param in model.named_parameters(): 235 | if any(s in name for s in ["norm", "embed"]): 236 | param.data = param.data.to(dtype) 237 | 238 | 239 | def fix_llama3_tokenizer(tokenizer, padding_side="right"): 240 | tokenizer.padding_side = padding_side 241 | added_vocab = tokenizer.get_added_vocab() 242 | pad_token = [w for w in added_vocab if "pad" in w] 243 | assert len(pad_token) == 1 244 | tokenizer.pad_token = pad_token[0] # Load dataset from the hub 245 | return tokenizer 246 | 247 | 248 | def replace_module( 249 | module: torch.nn.Module, 250 | target_module_type: torch.nn.Module, 251 | conversion_func: Callable, 252 | ): 253 | for child_name, child_module in module.named_children(): 254 | if isinstance(child_module, target_module_type): 255 | new_module = conversion_func(child_module) 256 | setattr(module, child_name, new_module) 257 | else: 258 | replace_module(child_module, target_module_type, conversion_func) 259 | 260 | 261 | def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"): 262 | base_layer = module.get_base_layer() 263 | weight = base_layer.weight 264 | 265 | assert isinstance(weight, bnb.nn.Params4bit) 266 | quant_state = weight.quant_state 267 | original_dtype = quant_state.dtype 268 | 269 | w_dq = dequantize_4bit(weight.data, quant_state).float() 270 | lora_delta = ( 271 | module.lora_B[adapter_name].weight 272 | @ module.lora_A[adapter_name].weight 273 | * module.scaling[adapter_name] 274 | ) 275 | w_dq += lora_delta.float() 276 | w_dq = w_dq.to(original_dtype) 277 | 278 | new_module = torch.nn.Linear( 279 | w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None 280 | ) 281 | new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False) 282 | if module.lora_bias[adapter_name]: 283 | bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias 284 | new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False) 285 | return new_module 286 | 287 | 288 | def convert_lora_to_linear(model: torch.nn.Module): 289 | replace_module(model, LoraLayer, _convert_lora_to_linear) 290 | assert not any(isinstance(module, LoraLayer) for module in model.modules()) 291 | return model 292 | -------------------------------------------------------------------------------- /unsloth/_auto_install.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
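Taken together, the `tests/utils/hf_utils.py` helpers above form a small QLoRA fine-tuning harness. The sketch below is one plausible way to wire them up, not a prescribed recipe: the checkpoint name, LoRA rank, and `SFTConfig` values are illustrative assumptions.

```python
# Rough end-to-end sketch using the helpers from tests/utils/hf_utils.py above.
import torch
from trl import SFTConfig

from tests.utils.data_utils import DEFAULT_MESSAGES, create_dataset
from tests.utils.hf_utils import (
    convert_lora_to_linear,
    get_peft_config,
    sample_responses,
    setup_model,
    setup_tokenizer,
    setup_trainer,
)

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # example checkpoint, swap as needed
tokenizer = setup_tokenizer(model_name)
peft_config = get_peft_config(lora_rank=8)  # lora_alpha defaults to 2 * rank

# 4-bit base model with LoRA adapters attached.
model = setup_model(model_name, quantize=True, dtype=torch.bfloat16, peft_config=peft_config)

dataset = create_dataset(tokenizer, messages=DEFAULT_MESSAGES)
train_args = SFTConfig(output_dir="outputs", max_steps=10, per_device_train_batch_size=1)
trainer = setup_trainer(model, tokenizer, dataset, train_args)
trainer.train()

print(sample_responses(model, tokenizer, prompt="What day was I born?", max_new_tokens=32))

# Fold the trained adapters back into plain nn.Linear layers for comparison / export.
model = convert_lora_to_linear(model)
```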
14 | 15 | try: import torch 16 | except: raise ImportError('Install torch via `pip install torch`') 17 | from packaging.version import Version as V 18 | v = V(torch.__version__) 19 | cuda = str(torch.version.cuda) 20 | is_ampere = torch.cuda.get_device_capability()[0] >= 8 21 | USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI 22 | if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8"): 23 | raise RuntimeError(f"CUDA = {cuda} not supported!") 24 | if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!") 25 | elif v <= V('2.1.1'): x = 'cu{}{}-torch211' 26 | elif v <= V('2.1.2'): x = 'cu{}{}-torch212' 27 | elif v < V('2.3.0'): x = 'cu{}{}-torch220' 28 | elif v < V('2.4.0'): x = 'cu{}{}-torch230' 29 | elif v < V('2.5.0'): x = 'cu{}{}-torch240' 30 | elif v < V('2.5.1'): x = 'cu{}{}-torch250' 31 | elif v <= V('2.5.1'): x = 'cu{}{}-torch251' 32 | elif v < V('2.7.0'): x = 'cu{}{}-torch260' 33 | elif v < V('2.8.0'): x = 'cu{}{}-torch270' 34 | else: raise RuntimeError(f"Torch = {v} too new!") 35 | x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") 36 | print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') -------------------------------------------------------------------------------- /unsloth/dataprep/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .synthetic import * 16 | -------------------------------------------------------------------------------- /unsloth/dataprep/synthetic_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
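`unsloth/_auto_install.py` above inspects the local torch / CUDA build and prints a matching `pip install` command. As a concrete illustration, a hypothetical environment with torch 2.6.0, CUDA 12.4, and an Ampere-or-newer GPU falls into the `< 2.7.0` branch and produces the following tag (a standalone reproduction of the string formatting, not a replacement for the script):

```python
# Standalone reproduction of the tag construction in unsloth/_auto_install.py above,
# using assumed values instead of querying torch.
torch_version = "2.6.0"  # would normally come from torch.__version__
cuda = "12.4"            # would normally come from torch.version.cuda
is_ampere = True         # compute capability >= 8.0

x = "cu{}{}-torch260"    # branch chosen because 2.5.1 < 2.6.0 < 2.7.0
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
assert x == "cu124-ampere-torch260"
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```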
14 | 15 | synthetic_qa_config = """\ 16 | # Master configuration file for Synthetic Data Kit 17 | 18 | # Global paths configuration 19 | paths: 20 | # Input data locations 21 | input: 22 | pdf: "{data_output_location}/pdf" 23 | html: "{data_output_location}/html" 24 | youtube: "{data_output_location}/youtube" 25 | docx: "{data_output_location}/docx" 26 | ppt: "{data_output_location}/ppt" 27 | txt: "{data_output_location}/txt" 28 | 29 | # Output locations 30 | output: 31 | parsed: "{data_output_location}/output" # Where parsed text files are saved 32 | generated: "{data_output_location}/generated" # Where generated content is saved 33 | cleaned: "{data_output_location}/cleaned" # Where cleaned content is saved 34 | final: "{data_output_location}/final" # Where final formatted content is saved 35 | 36 | # VLLM server configuration 37 | vllm: 38 | api_base: "http://localhost:8000/v1" # Base URL for VLLM API 39 | port: 8000 # Port for VLLM server 40 | model: "{model_name}" # Default model to use 41 | max_retries: 3 # Number of retries for API calls 42 | retry_delay: 1.0 # Initial delay between retries (seconds) 43 | 44 | # Ingest configuration 45 | ingest: 46 | default_format: "txt" # Default output format for parsed files 47 | youtube_captions: "auto" # Options: "auto", "manual" - caption preference 48 | 49 | # LLM generation parameters 50 | generation: 51 | temperature: {temperature} # Higher = more creative, lower = more deterministic 52 | top_p: {top_p} # Nucleus sampling parameter 53 | chunk_size: {chunk_size} # Size of text chunks for processing 54 | overlap: {overlap} # Overlap between chunks to maintain context 55 | max_tokens: {max_tokens} # Maximum tokens in LLM responses 56 | num_pairs: {default_num_pairs} # Default number of QA pairs to generate 57 | 58 | # Content cleanup parameters 59 | cleanup: 60 | threshold: {cleanup_threshold} # Default quality threshold (1-10) 61 | batch_size: {cleanup_batch_size} # Number of items per batch for rating 62 | temperature: {cleanup_temperature} # Temperature for rating (lower = more consistent) 63 | 64 | # Format conversion parameters 65 | format: 66 | default: "jsonl" # Default output format 67 | include_metadata: true # Include metadata in output files 68 | pretty_json: true # Use indentation in JSON output 69 | 70 | # Prompts for different tasks 71 | prompts: 72 | # Summary generation prompt 73 | summary: | 74 | Summarize this document in 3-5 sentences, focusing on the main topic and key concepts. 75 | 76 | # QA pair generation prompt 77 | qa_generation: | 78 | Create {num_pairs} question-answer pairs from this text for LLM training. 79 | 80 | Rules: 81 | 1. Questions must be about important facts in the text 82 | 2. Answers must be directly supported by the text 83 | 3. Return JSON format only: 84 | 85 | [ 86 | {{ 87 | "question": "Question 1?", 88 | "answer": "Answer 1." 89 | }}, 90 | {{ 91 | "question": "Question 2?", 92 | "answer": "Answer 2." 93 | }} 94 | ] 95 | 96 | Text: 97 | {text} 98 | 99 | # QA pair rating prompt 100 | qa_rating: | 101 | Rate each of these question-answer pairs for quality and return exactly this JSON format: 102 | 103 | [ 104 | {{"question": "same question text", "answer": "same answer text", "rating": n}} 105 | ] 106 | 107 | Where n is a number from 1-10. 
108 | 109 | DO NOT include any text outside of the JSON array, just return valid JSON: 110 | 111 | {pairs}""" -------------------------------------------------------------------------------- /unsloth/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .cross_entropy_loss import ( 16 | fast_cross_entropy_loss, 17 | post_patch_loss_function, 18 | patch_loss_functions, 19 | ) 20 | from .rms_layernorm import ( 21 | fast_rms_layernorm, 22 | patch_rms_layernorm, 23 | unpatch_rms_layernorm, 24 | ) 25 | from .layernorm import ( 26 | fast_layernorm, 27 | patch_layernorm, 28 | ) 29 | from .rope_embedding import fast_rope_embedding, inplace_rope_embedding 30 | from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel 31 | from .geglu import ( 32 | geglu_exact_forward_kernel, 33 | geglu_exact_backward_kernel, 34 | geglu_approx_forward_kernel, 35 | geglu_approx_backward_kernel, 36 | ) 37 | from .fast_lora import ( 38 | get_lora_parameters, 39 | get_lora_parameters_bias, 40 | apply_lora_mlp_swiglu, 41 | apply_lora_mlp_geglu_exact, 42 | apply_lora_mlp_geglu_approx, 43 | apply_lora_qkv, 44 | apply_lora_o, 45 | fast_lora_forward, 46 | ) 47 | from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora 48 | 49 | from .flex_attention import ( 50 | HAS_FLEX_ATTENTION, 51 | slow_attention_softcapping, 52 | slow_inference_attention_softcapping, 53 | create_flex_attention_causal_mask, 54 | create_flex_attention_sliding_window_mask, 55 | ) 56 | 57 | import os 58 | if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ: 59 | try: 60 | print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") 61 | except: 62 | print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") 63 | pass 64 | pass 65 | del os 66 | -------------------------------------------------------------------------------- /unsloth/kernels/flex_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
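The `synthetic_qa_config` string in `unsloth/dataprep/synthetic_configs.py` above is a `str.format`-style template: `{placeholders}` are presumably filled in by the data-prep code, while `{{` / `}}` escape the literal JSON braces in the prompts. A rough rendering sketch, assuming PyYAML is installed and `unsloth` imports cleanly; the values are arbitrary examples (the real rendering logic and defaults live in `unsloth/dataprep/synthetic.py`), and a forgiving `format_map` is used so prompt-time placeholders such as `{text}` and `{pairs}` survive untouched:

```python
# Sketch: render the synthetic_qa_config template above and parse it as YAML.
import yaml

from unsloth.dataprep.synthetic_configs import synthetic_qa_config


class KeepMissing(dict):
    """Leave unknown {placeholders} in place instead of raising KeyError."""
    def __missing__(self, key):
        return "{" + key + "}"


rendered = synthetic_qa_config.format_map(KeepMissing(
    data_output_location="data",
    model_name="unsloth/Llama-3.2-3B-Instruct",
    temperature=0.7,
    top_p=0.95,
    chunk_size=4000,
    overlap=200,
    max_tokens=512,
    default_num_pairs=25,
    cleanup_threshold=1.0,
    cleanup_batch_size=4,
    cleanup_temperature=0.3,
))

config = yaml.safe_load(rendered)
print(config["vllm"]["model"], config["generation"]["num_pairs"])
```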
14 | 15 | import torch 16 | from functools import lru_cache 17 | from transformers.models.llama.modeling_llama import logger 18 | import os 19 | 20 | torch_compile_options = { 21 | "epilogue_fusion" : True, 22 | "max_autotune" : True, 23 | "shape_padding" : True, 24 | "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1", 25 | "triton.cudagraphs" : False, 26 | } 27 | 28 | # Flex Attention supported from torch 2.5 onwards only 29 | try: 30 | from torch.nn.attention.flex_attention import ( 31 | flex_attention as _flex_attention, 32 | create_block_mask as _create_block_mask, 33 | ) 34 | _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options) 35 | HAS_FLEX_ATTENTION = False 36 | except: 37 | HAS_FLEX_ATTENTION = False 38 | pass 39 | 40 | 41 | if not HAS_FLEX_ATTENTION: 42 | 43 | # Logit softcapping 44 | @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) 45 | def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): 46 | n_heads = self.config.num_attention_heads 47 | head_dim = self.head_dim 48 | n_kv_heads = self.config.num_key_value_heads 49 | n_groups = self.num_key_value_groups 50 | 51 | # Grouped query attention 52 | K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) 53 | V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) 54 | K = K.reshape(bsz, n_heads, q_len, head_dim) 55 | V = V.reshape(bsz, n_heads, q_len, head_dim) 56 | 57 | # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e 58 | # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below 59 | # We default to using the config file itself 60 | # s = self.config.hidden_size // self.config.num_attention_heads 61 | s = self.config.query_pre_attn_scalar 62 | t = self.config.attn_logit_softcapping 63 | 64 | Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly 65 | A = torch.matmul(Q, K.transpose(2, 3)) 66 | A = t * torch.tanh(A / t) # Logit softcapping 67 | A += causal_mask[:q_len, :q_len] 68 | # Much slower in torch compile! 
69 | # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) 70 | A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) 71 | A = torch.matmul(A, V) 72 | A = A.transpose(1, 2).contiguous() 73 | A = A.reshape(bsz, q_len, n_heads*head_dim) 74 | return A 75 | pass 76 | 77 | create_flex_attention_causal_mask = None 78 | create_flex_attention_sliding_window_mask = None 79 | else: 80 | # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb 81 | # for more examples 82 | # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al 83 | import functools, math 84 | 85 | def generate_tanh_softcap(t): 86 | def tanh_softcap(x, b, h, q_idx, kv_idx): 87 | return t * torch.tanh(x / t) 88 | return tanh_softcap 89 | pass 90 | def causal_masker(b, h, q_idx, kv_idx): 91 | return q_idx >= kv_idx 92 | pass 93 | 94 | @functools.lru_cache 95 | def sliding_window_masker(size = 4096): 96 | def sliding_window(b, h, q_idx, kv_idx): 97 | causal_mask = q_idx >= kv_idx 98 | window_mask = q_idx - kv_idx <= size 99 | return causal_mask & window_mask 100 | return sliding_window 101 | pass 102 | 103 | @functools.lru_cache 104 | def create_block_mask(mask, n = 128): 105 | return _create_block_mask( 106 | mask, 1, 1, n, n, 107 | BLOCK_SIZE = 128, 108 | _compile = True, 109 | ) 110 | pass 111 | 112 | def create_flex_attention_causal_mask(max_seq_length = 8192): 113 | causal_mask = create_block_mask(causal_masker, max_seq_length) 114 | return causal_mask 115 | pass 116 | 117 | def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096): 118 | sliding_masker = sliding_window_masker(sliding_window) 119 | causal_mask = create_block_mask(sliding_masker, max_seq_length) 120 | return causal_mask 121 | pass 122 | 123 | @functools.lru_cache 124 | def flex_attention(s, t): 125 | scale = 1.0 / math.sqrt(s) 126 | score_mod = generate_tanh_softcap(t) 127 | return functools.partial( 128 | _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, 129 | ) 130 | pass 131 | 132 | def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): 133 | n_heads = self.config.num_attention_heads 134 | head_dim = self.head_dim 135 | s = self.config.query_pre_attn_scalar 136 | t = self.config.attn_logit_softcapping 137 | fx = flex_attention(s, t) 138 | A = fx(query = Q, key = K, value = V, block_mask = causal_mask) 139 | A = A.transpose(1, 2).contiguous() 140 | A = A.reshape(bsz, q_len, n_heads*head_dim) 141 | return A 142 | pass 143 | pass 144 | 145 | 146 | torch_matmul = torch.matmul 147 | torch_tanh = torch.tanh 148 | torch_nn_functional_softmax = torch.nn.functional.softmax 149 | def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): 150 | n_heads = self.config.num_attention_heads 151 | head_dim = self.head_dim 152 | n_kv_heads = self.config.num_key_value_heads 153 | n_groups = self.num_key_value_groups 154 | 155 | # Grouped query attention 156 | K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) 157 | V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) 158 | K = K.reshape(bsz, n_heads, q_len, head_dim) 159 | V = V.reshape(bsz, n_heads, q_len, head_dim) 160 | 161 | # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e 162 | # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below 163 | # We default to using the config file itself 164 | # s = self.config.hidden_size // self.config.num_attention_heads 165 | s = self.config.query_pre_attn_scalar 166 | t = self.config.attn_logit_softcapping 167 | 168 | Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly 169 | A = torch_matmul(Q, K.transpose(2, 3)) 170 | 171 | # Logit softcapping 172 | A /= t; torch_tanh(A, out = A); A *= t; 173 | A += causal_mask[:q_len, :q_len] 174 | # Much slower in torch compile! 175 | # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) 176 | A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) 177 | A = torch_matmul(A, V) 178 | A = A.transpose(1, 2).contiguous() 179 | A = A.reshape(bsz, q_len, n_heads*head_dim) 180 | return A 181 | pass 182 | -------------------------------------------------------------------------------- /unsloth/kernels/geglu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import triton 16 | import triton.language as tl 17 | import torch 18 | from .utils import ( 19 | calculate_settings, 20 | triton_tanh, 21 | torch_cuda_device, 22 | ) 23 | 24 | 25 | @triton.jit 26 | def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 27 | block_idx = tl.program_id(0) 28 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 29 | mask = offsets < n_elements 30 | 31 | # f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) 32 | # h = f * up 33 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 34 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 35 | 36 | f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) 37 | f_row = f_row.to(g_row.dtype) # Exact copy from HF 38 | h_row = f_row * g_row 39 | 40 | # Store h 41 | tl.store(h + offsets, h_row, mask = mask) 42 | pass 43 | 44 | 45 | def geglu_exact_forward_kernel(gate, up): 46 | batch, seq_len, hd = gate.shape 47 | n_elements = gate.numel() 48 | device = gate.device 49 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) 50 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 51 | with torch_cuda_device(device): 52 | _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) 53 | return out 54 | pass 55 | 56 | 57 | @triton.jit 58 | def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): 59 | """ 60 | f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) 61 | h = f * up 62 | 63 | df/de (with help of Wolfram :) 64 | df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2) 65 | 66 | Reuse via 67 | f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e 68 | """ 69 | block_idx = tl.program_id(0) 70 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 71 | mask = offsets < n_elements 72 | 73 | DW_row = tl.load(DW + offsets, mask 
= mask, other = 0)#.to(tl.float32) 74 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 75 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 76 | 77 | # Break e_row away for re-use 78 | # f = 1/2 * e * (1 + erf(1/sqrt(2) * e)) 79 | f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0) 80 | f_row = f_partial_row * e_row 81 | 82 | f_row = f_row.to(DW_row.dtype) 83 | # h = f * g 84 | h_row = f_row * g_row 85 | # df = DW * f 86 | df_row = DW_row * f_row 87 | # dg = DW * g 88 | dg_row = DW_row * g_row 89 | 90 | # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2) 91 | t = 0.3989422804014327 # 1/sqrt(2*pi) 92 | df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row) 93 | 94 | de_row = dg_row.to(tl.float32) * df_de 95 | de_row = de_row.to(DW_row.dtype) 96 | 97 | # Store derivatives in buffers 98 | tl.store(DW + offsets, h_row, mask = mask) # h = f * g 99 | tl.store(e + offsets, df_row, mask = mask) # df = DW * f 100 | tl.store(g + offsets, de_row, mask = mask) # de 101 | pass 102 | 103 | 104 | def geglu_exact_backward_kernel(DW, e, g): 105 | batch_seq_len, hd = e.shape 106 | n_elements = e.numel() 107 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 108 | with torch_cuda_device(e.device): 109 | _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) 110 | return DW, e, g 111 | pass 112 | 113 | 114 | @triton.jit 115 | def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 116 | block_idx = tl.program_id(0) 117 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 118 | mask = offsets < n_elements 119 | 120 | # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) )) 121 | # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )) 122 | # h = f * up 123 | s = 0.7978845608028654 # math.sqrt(2 / math.pi) 124 | 125 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 126 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 127 | 128 | f_row = 0.5 * e_row * ( 129 | triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \ 130 | + 1.0 131 | ) 132 | f_row = f_row.to(g_row.dtype) # Exact copy from HF 133 | h_row = f_row * g_row 134 | 135 | # Store h 136 | tl.store(h + offsets, h_row, mask = mask) 137 | pass 138 | 139 | 140 | def geglu_approx_forward_kernel(gate, up): 141 | batch, seq_len, hd = gate.shape 142 | n_elements = gate.numel() 143 | device = gate.device 144 | out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) 145 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 146 | with torch_cuda_device(device): 147 | _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) 148 | return out 149 | pass 150 | 151 | 152 | @triton.jit 153 | def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): 154 | """ 155 | f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )) 156 | h = f * up 157 | 158 | df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :)) 159 | df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] + 160 | 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \ 161 | ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) ) 162 | 163 | Notice sech^2(x) = 1 - tanh^2(x) 164 | So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ) 165 | 166 | See https://www.desmos.com/calculator/nqprfoni6x 167 | """ 168 | block_idx = tl.program_id(0) 169 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, 
BLOCK_SIZE) 170 | mask = offsets < n_elements 171 | 172 | DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32) 173 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 174 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 175 | 176 | # See https://www.desmos.com/calculator/nqprfoni6x 177 | s = 0.7978845608028654 # math.sqrt(2 / math.pi) 178 | a = s * e_row # a = sqrt(2 / pi) * x 179 | b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2 180 | T = 1.0 + triton_tanh(a + b) 181 | T2 = 0.5 * T 182 | # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b) 183 | Q2 = -T2 * (T - 2.0) * (a + 3.0 * b) 184 | df_de = T2 + Q2 # 1/2 * (T + Q) 185 | 186 | # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) )) 187 | f_row = T2 * e_row 188 | f_row = f_row.to(DW_row.dtype) 189 | # h = f * g 190 | h_row = f_row * g_row 191 | # df = DW * f 192 | df_row = DW_row * f_row 193 | # dg = DW * g 194 | dg_row = DW_row * g_row 195 | 196 | de_row = dg_row.to(tl.float32) * df_de 197 | de_row = de_row.to(DW_row.dtype) 198 | 199 | # Store derivatives in buffers 200 | tl.store(DW + offsets, h_row, mask = mask) # h = f * g 201 | tl.store(e + offsets, df_row, mask = mask) # df = DW * f 202 | tl.store(g + offsets, de_row, mask = mask) # de 203 | pass 204 | 205 | 206 | def geglu_approx_backward_kernel(DW, e, g): 207 | batch_seq_len, hd = e.shape 208 | n_elements = e.numel() 209 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 210 | with torch_cuda_device(e.device): 211 | _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) 212 | return DW, e, g 213 | pass 214 | -------------------------------------------------------------------------------- /unsloth/kernels/layernorm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import triton 17 | import triton.language as tl 18 | import torch 19 | from .utils import calculate_settings, torch_cuda_device 20 | from unsloth_zoo.patching_utils import ( 21 | patch_layernorm, 22 | ) 23 | 24 | 25 | @triton.jit 26 | def layernorm_forward( 27 | Y, Y_row_stride, 28 | X, X_row_stride, 29 | W, 30 | b, 31 | r, 32 | mu, 33 | n_cols : tl.constexpr, 34 | eps : tl.constexpr, 35 | BLOCK_SIZE : tl.constexpr 36 | ): 37 | row_idx = tl.program_id(0) 38 | col_offsets = tl.arange(0, BLOCK_SIZE) 39 | mask = col_offsets < n_cols 40 | 41 | Y += row_idx * Y_row_stride 42 | X += row_idx * X_row_stride 43 | r += row_idx 44 | mu += row_idx 45 | 46 | # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules 47 | # are in float32! 
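# This kernel computes y = (x - mean(x)) / sqrt(var(x) + eps) * W + b for one row per program.
# The per-row inverse standard deviation (r) and mean (mu) are cached so that the backward
# kernel can rebuild the normalized activations without re-reducing over X.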
48 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) 49 | W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) 50 | b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32) 51 | 52 | mean_X = tl.sum(X_row, axis = 0) / n_cols 53 | # (X[0] - mean) == -mean so we need to mask it out 54 | XX = tl.where(mask, X_row - mean_X, 0) 55 | row_var = tl.sum(XX * XX, axis = 0) / n_cols 56 | inv_var = tl.math.rsqrt(row_var + eps) 57 | tl.store (r, inv_var) 58 | tl.store (mu, mean_X) 59 | output = (XX * inv_var) * W_row + b_row 60 | tl.store(Y + col_offsets, output, mask = mask) 61 | pass 62 | 63 | 64 | @triton.jit 65 | def layernorm_backward( 66 | dY, dY_row_stride, 67 | X, X_row_stride, 68 | W, 69 | b, 70 | r, 71 | mu, 72 | n_cols : tl.constexpr, 73 | eps : tl.constexpr, 74 | BLOCK_SIZE : tl.constexpr 75 | ): 76 | # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md 77 | row_idx = tl.program_id(0) 78 | col_offsets = tl.arange(0, BLOCK_SIZE) 79 | mask = col_offsets < n_cols 80 | 81 | dY += row_idx * dY_row_stride 82 | X += row_idx * X_row_stride 83 | r += row_idx 84 | mu += row_idx 85 | 86 | # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules 87 | # are in float32! 88 | dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32) 89 | X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32) 90 | W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32) 91 | b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32) 92 | 93 | inv_var = tl.load(r) .to(tl.float32) 94 | mean = tl.load(mu).to(tl.float32) 95 | normed = (X_row - mean) * inv_var 96 | dY_W = dY_row * W_row 97 | dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols 98 | dX_row = dX_row * inv_var 99 | tl.store(dY + col_offsets, dX_row, mask = mask) 100 | pass 101 | 102 | 103 | class Fast_Layernorm(torch.autograd.Function): 104 | @staticmethod 105 | def forward(ctx, X, W, b, eps): 106 | shape = X.shape 107 | dim = shape[-1] 108 | X = X.view(-1, dim) 109 | n_rows, n_cols = X.shape 110 | BLOCK_SIZE, num_warps = calculate_settings(n_cols) 111 | device = X.device 112 | Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device) 113 | r = torch.empty(n_rows, dtype = torch.float32, device = device) 114 | mu = torch.empty(n_rows, dtype = torch.float32, device = device) 115 | 116 | with torch_cuda_device(device): 117 | layernorm_forward[(n_rows,)]( 118 | Y, Y.stride(0), 119 | X, X.stride(0), 120 | W, 121 | b, 122 | r, 123 | mu, 124 | n_cols, eps, 125 | BLOCK_SIZE = BLOCK_SIZE, 126 | num_warps = num_warps, 127 | ) 128 | ctx.eps = eps 129 | ctx.BLOCK_SIZE = BLOCK_SIZE 130 | ctx.num_warps = num_warps 131 | ctx.save_for_backward(X, W, b, r, mu) 132 | return Y.view(*shape) 133 | pass 134 | 135 | @staticmethod 136 | def backward(ctx, dY): 137 | shape = dY.shape 138 | dim = shape[-1] 139 | dY = dY.view(-1, dim) 140 | X, W, b, r, mu = ctx.saved_tensors 141 | n_rows, n_cols = dY.shape 142 | 143 | with torch_cuda_device(dY.device): 144 | layernorm_backward[(n_rows,)]( 145 | dY, dY.stride(0), 146 | X, X .stride(0), 147 | W, 148 | b, 149 | r, 150 | mu, 151 | n_cols, ctx.eps, 152 | BLOCK_SIZE = ctx.BLOCK_SIZE, 153 | num_warps = ctx.num_warps, 154 | ) 155 | dX = dY.view(*shape) 156 | return dX, None, None, None, None 157 | pass 158 | pass 159 | 160 | 161 | def fast_layernorm(layernorm, X): 162 | 
assert(layernorm.elementwise_affine is True) 163 | W = layernorm.weight 164 | bias = layernorm.bias 165 | eps = layernorm.variance_epsilon if \ 166 | hasattr(layernorm, "variance_epsilon") \ 167 | else layernorm.eps 168 | out = Fast_Layernorm.apply(X, W, bias, eps) 169 | return out 170 | pass 171 | 172 | 173 | 174 | def test_layernorm( 175 | dim = 1024, eps = 1e-5, dtype = torch.float16, 176 | bsz = 21, random_state = 3407, seqlen = 3341, 177 | ): 178 | from torch.nn import LayerNorm 179 | layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype) 180 | torch.cuda.manual_seed(random_state) 181 | torch.manual_seed(random_state) 182 | torch.nn.init.uniform_(layernorm.weight) 183 | torch.nn.init.uniform_(layernorm.bias) 184 | X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda") 185 | XX = X.clone() 186 | X .requires_grad_(True) 187 | XX.requires_grad_(True) 188 | Y = layernorm(X) 189 | YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True) 190 | Y.backward(YY) 191 | correct_grad = X.grad.clone() 192 | # from unsloth.kernels import fast_layernorm 193 | Y = fast_layernorm(layernorm, XX) 194 | Y.backward(YY) 195 | assert(torch.dist(correct_grad, XX.grad).item() <= 0.1) 196 | pass 197 | 198 | 199 | def testing_suite_layernorm(): 200 | for dim in [512, 1024, 2048]: 201 | for dtype in [torch.float16, torch.bfloat16]: 202 | with torch.autocast(device_type = "cuda", dtype = dtype): 203 | for seqlen in [3341, 2048, 349]: 204 | for random_state in [3407, 42]: 205 | test_layernorm( 206 | dim = dim, 207 | eps = 1e-5, 208 | dtype = dtype, 209 | bsz = 21, 210 | random_state = random_state, 211 | seqlen = seqlen, 212 | ) 213 | pass 214 | pass 215 | pass 216 | pass 217 | pass 218 | pass 219 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/README.md: -------------------------------------------------------------------------------- 1 | ## MoE Grouped GEMM 2 | 3 | Optimized implementation of `MoE MLP Block`. 4 | 5 | ### Background 6 | 7 | `MoE MLP` requires the following steps: 8 | - Calculate `topk_weights` and `topk_indices` 9 | - If using a grouped gemm implementation, calculate permutation indices needed to rearrange tokens grouped by expert 10 | - For each expert: 11 | - `expert_tokens`: gather the tokens assigned to the expert 12 | - `first_gemm`: `gate / up proj` @ `expert_tokens` 13 | - `silu_and_mul`: `silu` and `mul` of `first_gemm` 14 | - `second_gemm`: `silu_and_mul` @ `down proj` 15 | - `scatter_second_gemm`: scatter the `second_gemm` to the original token order 16 | - `topk_weight_mul`: `second_gemm` @ `topk_weights` 17 | - `final_output`: if `topk > 1`, `topk_weight_mul.view(num_tokens, topk, -1).sum(dim=1)` else `topk_weight_mul` 18 | 19 | One way to eliminate the loop is to use a grouped GEMM, where all expert GEMMs are computed within a single kernel, which iterates over tiles of the expert GEMMs as individual GEMMs, where each GEMM, the `A` matrix is `M' x K` and the `B` matrix is `K x N`, where `M'` is the number of tokens assigned to the expert and `B` is the weight matrix for that expert. 20 | 21 | This requires an additional permute (and subsequent copy) of the hidden states such that the tokens assigned to each expert are contiguous in memory before running the first grouped GEMM within the Expert MLP. 
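To make the bookkeeping above concrete, here is a deliberately unoptimized, torch-only sketch of the routing, token sort, and per-expert loop that the fused grouped GEMM replaces. The tiny shapes, top-1 routing, and the fact that only the gate/up projection is shown (silu-and-mul and the down projection are elided) are simplifying assumptions for illustration:

```python
# Torch-only sketch of the token permutation and per-expert loop described above.
import torch

num_tokens, hidden, num_experts = 8, 16, 4
x = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)
topk_weights, topk_idx = router_logits.softmax(dim=-1).max(dim=-1)  # top-1 routing

# Sort tokens so those routed to the same expert are contiguous (the extra permute/copy).
sort_order = torch.argsort(topk_idx)
x_by_expert = x[sort_order]
tokens_per_expert = torch.bincount(topk_idx, minlength=num_experts)
offsets = torch.cat([torch.zeros(1, dtype=torch.long), tokens_per_expert.cumsum(0)])

# The loop the grouped GEMM kernel fuses away: one (M_e x K) @ (K x N) GEMM per expert.
w_gate_up = torch.randn(num_experts, hidden, 2 * hidden)  # per-expert gate/up weights
outs = []
for e in range(num_experts):
    expert_tokens = x_by_expert[offsets[e]:offsets[e + 1]]  # tokens owned by expert e
    outs.append(expert_tokens @ w_gate_up[e])
y_by_expert = torch.cat(outs, dim=0)

# Scatter back to the original token order and apply the routing weights.
y = torch.empty_like(y_by_expert)
y[sort_order] = y_by_expert
y = y * topk_weights[:, None]
```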
22 | Additionally, after the second grouped GEMM, the hidden states must be permuted back to the original token order and multiplied by `topk_weights` to get the final output. 23 | 24 | ### Optimizations 25 | This repo implements a grouped GEMM-based MoE MLP with the following optimizations: 26 | - Eliminates the loop over experts by performing the expert GEMMs as a grouped GEMM, computing all expert GEMMs within a single fused triton kernel 27 | - Fuses the permutation of hidden states from token order (original input order) to expert order (tokens grouped by expert) within the prologue of the first grouped GEMM 28 | - Fuses the (un)permutation of hidden states from expert order back to token order in the second grouped GEMM 29 | - Fuses the mul of hidden states by expert weights within the epilogue of the second GEMM (only implemented for inference, not for training) 30 | 31 | ### Structure 32 | - `grouped_gemm/interface.py`: wrappers for the individual forward / backward kernels as well as the `torch.autograd.Function` 33 | - `grouped_gemm/kernels/forward.py`: forward kernel 34 | - `grouped_gemm/kernels/backward.py`: backward dX and dW kernels 35 | - `grouped_gemm/kernels/tuning.py`: manual tuning utils 36 | - `grouped_gemm/kernels/autotuning.py`: autotuning utils 37 | - `grouped_gemm/reference/moe_block.py`: contains `Qwen3MoeFusedGroupedGEMMBlock`, a reference implementation of the Huggingface `Qwen3SparseMOEBlock` with the fused triton kernel in place of the original HF expert computation 38 | - `grouped_gemm/reference/moe_ops.py`: supporting ops (routing, token sorting, etc.) and a reference MoE block using a torch-native grouped gemm approach. 39 | 40 | ### Tests 41 | - `grouped_gemm/tests/test_grouped_gemm.py`: unit tests for the forward and backward grouped gemm kernels as well as the wrapped grouped gemm autograd.Function. Best not to run this entire test suite at once due to the large number of parametrized unit tests. Rather, use filters to run specific 42 | sets of tests. E.g., to run forward tests with autotune turned on: `pytest -sv -k "forward and autotune" --tb=short tests/test_grouped_gemm.py`. Use the test function names and parameter ids for words to filter on. 43 | - `grouped_gemm/tests/test_qwen3_moe.py`: end-to-end test for the Qwen3 MoE block. IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as the notes in the test itself for complications when running parametrized pytest test suites with triton / autotune. TLDR: use the test script and NOT pytest to run the tests. 44 | 45 | ### Benchmarks 46 | - `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3SparseMOEBlock` or `Llama4TextMoe` against the fused implementation 47 | 48 | 49 | Running with these flags on an `H100` to bench the forward pass (run with `--help` to see all available flags): 50 | 51 | For `Qwen3-30B-A3B`: 52 | ``` 53 | python benchmark/benchmark_fused_moe.py --model qwen3 --mode forward --seqlen 1024 --permute_x --permute_y --autotune 54 | ``` 55 | 56 | For the backward bench: 57 | ``` 58 | python benchmark/benchmark_fused_moe.py --model qwen3 --mode backward --seqlen 1024 --permute_x --permute_y --autotune 59 | ``` 60 | 61 | For `Llama-4-Scout-17B-16E`: 62 | ``` 63 | python benchmark/benchmark_fused_moe.py --model llama4 --autotune --mode=forward --permute_y 64 | ``` 65 | Ditto for backwards. 66 | 67 | ### Notes 68 | - Tested and benched on `H100`; it should also run on Ampere and possibly even earlier GPU generations, though the autotuning configs will need to be adjusted.
69 | - The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`. 70 | - The kernels can be run either as autotuned (see `autotuning.py`) or with manually specified config (see `tuning.py`). Recommended to run using autotuner since the MoE block requires 2 configs for the forward (2 grouped gemms) and 4 for the backwards (dX and dW per grouped gemm, 2 grouped gemms). 71 | - Running with autotuning turned off with the default manual kernel config will result is **highly** sub-optimal performance as it is only meant for testing / debugging purposes. 72 | - I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads. 73 | - The Llama4 reference layer is still highly under-optimized as there are many low-hanging opportunities for further speedups around routing and shared expert calculation. 74 | 75 | TODO: 76 | - TMA store: implemented but not enabled currently due to non-determinism arising from triton pipelining bug. 77 | - Warp specialization: Hopper support for WS not yet enabled on triton 3.3x branch which ships with latest pytorch 2.7. 78 | - Additional optimizations: 79 | - Fused / optimized implementations of routing, token sorting, etc. 80 | - Better software pipelining within grouped gemm 81 | - Threadblock swizzling for better L2 caching 82 | - Llama4 83 | - Fused gather / topk weight merging 84 | - Custom topk, gather indices kernel 85 | - Shared expert fusion with experts calculation -------------------------------------------------------------------------------- /unsloth/kernels/moe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/kernels/moe/__init__.py -------------------------------------------------------------------------------- /unsloth/kernels/moe/benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import math 6 | import os 7 | from itertools import product 8 | 9 | import pandas as pd 10 | import torch 11 | 12 | from grouped_gemm.kernels.tuning import ( 13 | KernelConfigBackward_dW, 14 | KernelConfigBackward_dX, 15 | KernelConfigForward, 16 | KernelResult, 17 | ) 18 | 19 | SEED = 42 20 | 21 | 22 | def create_merged_results( 23 | df: pd.DataFrame, mode: str, seqlen: int, dtype: torch.dtype, autotune: bool 24 | ): 25 | kernel_result_cols = df.columns.to_list() 26 | test_config_dict = { 27 | "mode": mode, 28 | "seqlen": seqlen, 29 | "dtype": dtype, 30 | "autotune": autotune, 31 | } 32 | test_config_cols = list(test_config_dict.keys()) 33 | for col in test_config_cols: 34 | df[col] = test_config_dict[col] 35 | # Reorder columns so that test config cols are first 36 | df = df[test_config_cols + kernel_result_cols] 37 | return df 38 | 39 | 40 | def post_process_results( 41 | results: list[KernelResult], 42 | mode: str, 43 | seqlen: int, 44 | dtype: torch.dtype, 45 | autotune: bool, 46 | ): 47 | df = KernelResult.to_dataframe(results, sort_by="speedup") 48 | df = create_merged_results(df, mode, seqlen, dtype, autotune) 49 | return df 50 | 51 | 52 | def save_results( 53 | df: pd.DataFrame, 54 | results_dir: str, 55 | mode: str, 56 | seqlen: int, 57 | dtype: torch.dtype, 58 | autotune: bool, 59 | ): 60 | dt = datetime.datetime.now().strftime("%Y%m%d_%H%M") 61 | save_dir = f"{results_dir}/{mode}" 
62 | save_path = f"{save_dir}/{dt}_{seqlen}_{str(dtype).split('.')[-1]}.csv" 63 | if not os.path.exists(save_dir): 64 | os.makedirs(save_dir) 65 | print(f"Saving results to {save_path}") 66 | df.to_csv(save_path, index=False) 67 | 68 | 69 | def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool): 70 | block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1]) 71 | block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1]) 72 | block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1]) 73 | num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step=2) 74 | num_stages_range = multiples_of_range( 75 | args.num_stages[0], args.num_stages[1], step=1 76 | ) 77 | 78 | mode = args.mode 79 | kernel_configs = [] 80 | for ( 81 | block_m, 82 | block_n, 83 | block_k, 84 | num_warps, 85 | num_stages, 86 | tma_load_a, 87 | tma_load_b, 88 | ) in product( 89 | block_m_range, 90 | block_n_range, 91 | block_k_range, 92 | num_warps_range, 93 | num_stages_range, 94 | [True, False], 95 | [True, False], 96 | ): 97 | if mode == "forward": 98 | kernel_config = KernelConfigForward( 99 | BLOCK_SIZE_M=block_m, 100 | BLOCK_SIZE_N=block_n, 101 | BLOCK_SIZE_K=block_k, 102 | num_warps=num_warps, 103 | num_stages=num_stages, 104 | use_tma_load_w=tma_load_a, 105 | use_tma_load_x=tma_load_b, 106 | permute_x=permute_x, 107 | permute_y=permute_y, 108 | ) 109 | elif mode == "dW": 110 | kernel_config = KernelConfigBackward_dW( 111 | BLOCK_SIZE_M=block_m, 112 | BLOCK_SIZE_N=block_n, 113 | BLOCK_SIZE_K=block_k, 114 | num_warps=num_warps, 115 | num_stages=num_stages, 116 | use_tma_load_dy=tma_load_a, 117 | use_tma_load_x=tma_load_b, 118 | permute_x=permute_x, 119 | permute_y=permute_y, 120 | ) 121 | elif mode == "dX": 122 | kernel_config = KernelConfigBackward_dX( 123 | BLOCK_SIZE_M=block_m, 124 | BLOCK_SIZE_N=block_n, 125 | BLOCK_SIZE_K=block_k, 126 | num_warps=num_warps, 127 | num_stages=num_stages, 128 | use_tma_load_dy=tma_load_a, 129 | use_tma_load_w=tma_load_b, 130 | permute_x=permute_x, 131 | permute_y=permute_y, 132 | ) 133 | else: 134 | raise ValueError(f"Invalid mode: {mode}") 135 | kernel_configs.append(kernel_config) 136 | 137 | logging.info(f"Pruning {len(kernel_configs)} kernel configs") 138 | 139 | pruned_configs = [] 140 | for config in kernel_configs: 141 | if mode == "forward": 142 | if permute_x and config.use_tma_load_x: 143 | continue 144 | elif mode == "dW": 145 | if permute_x and config.use_tma_load_x: 146 | continue 147 | if permute_y and config.use_tma_load_dy: 148 | continue 149 | elif mode == "dX": 150 | if permute_y and config.use_tma_load_dy: 151 | continue 152 | pruned_configs.append(config) 153 | logging.info(f"After pruning, {len(pruned_configs)} kernel configs") 154 | 155 | return pruned_configs 156 | 157 | 158 | def power_of_two_range(start, end): 159 | start = math.log2(start) 160 | end = math.log2(end) 161 | return [2**i for i in range(int(start), int(end) + 1)] 162 | 163 | 164 | def multiples_of_range(start, end, step=1): 165 | return list(range(start, end + step, step)) 166 | 167 | 168 | def map_key_to_args(key, mode): 169 | pass 170 | 171 | 172 | def save_autotune_results(autotune_cache, mode, ref_time, fused_time, results_dir): 173 | device_name = torch.cuda.get_device_name().replace(" ", "_") 174 | dt = datetime.datetime.now().strftime("%Y%m%d_%H%M") 175 | save_dir = f"{results_dir}/{mode}/autotune/{dt}/{device_name}" 176 | if not os.path.exists(save_dir): 177 | 
os.makedirs(save_dir) 178 | 179 | for key, config in autotune_cache.items(): 180 | key = [ 181 | str(k) if not "torch" in str(k) else str(k.split("torch.")[-1]) for k in key 182 | ] 183 | filename = "_".join(key) 184 | save_path = f"{save_dir}/{filename}.json" 185 | print(f"Saving autotune results to {save_path}") 186 | with open(save_path, "w") as f: 187 | result = { 188 | **config.all_kwargs(), 189 | "ref_time": ref_time, 190 | "fused_time": fused_time, 191 | } 192 | json.dump(result, f) 193 | 194 | 195 | def get_autotuner(mode): 196 | if mode == "forward": 197 | from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel 198 | 199 | return _autotuned_grouped_gemm_forward_kernel 200 | elif mode == "dW": 201 | from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel 202 | 203 | return _autotuned_grouped_gemm_dW_kernel 204 | elif mode == "dX": 205 | from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dX_kernel 206 | 207 | return _autotuned_grouped_gemm_dX_kernel 208 | elif mode == "backward": 209 | from grouped_gemm.kernels.backward import ( 210 | _autotuned_grouped_gemm_dW_kernel, 211 | _autotuned_grouped_gemm_dX_kernel, 212 | ) 213 | 214 | return _autotuned_grouped_gemm_dW_kernel, _autotuned_grouped_gemm_dX_kernel 215 | else: 216 | raise ValueError(f"Invalid mode: {mode}") 217 | 218 | 219 | def postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_dir): 220 | for key, value in autotuner.cache.items(): 221 | print(f"{mode} {key}: {value.all_kwargs()}") 222 | save_autotune_results( 223 | autotuner.cache, 224 | mode=mode, 225 | ref_time=ref_time, 226 | fused_time=fused_time, 227 | results_dir=results_dir, 228 | ) 229 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/kernels/moe/grouped_gemm/__init__.py -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/kernels/moe/grouped_gemm/kernels/__init__.py -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/kernels/tuning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manual tuning utils 3 | """ 4 | 5 | from collections import OrderedDict 6 | from dataclasses import asdict, dataclass, fields 7 | from itertools import product 8 | from typing import Optional 9 | 10 | import pandas as pd 11 | import torch 12 | import triton 13 | from triton.runtime.errors import OutOfResources 14 | 15 | from grouped_gemm.kernels.autotuning import ( 16 | BOOLS, 17 | DEFAULT_K_BLOCK_SIZES, 18 | DEFAULT_M_BLOCK_SIZES, 19 | DEFAULT_N_BLOCK_SIZES, 20 | DEFAULT_NUM_STAGES, 21 | DEFAULT_NUM_WARPS, 22 | ) 23 | 24 | 25 | @dataclass 26 | class DeviceProperties: 27 | NUM_SM: int 28 | NUM_REGS: int 29 | SIZE_SMEM: int 30 | WARP_SIZE: int 31 | 32 | 33 | _DEVICE_PROPERTIES: Optional[DeviceProperties] = None 34 | 35 | 36 | def get_device_properties(): 37 | global _DEVICE_PROPERTIES 38 | if _DEVICE_PROPERTIES is None: 39 | properties = 
triton.runtime.driver.active.utils.get_device_properties( 40 | torch.cuda.current_device() 41 | ) 42 | NUM_SM = properties["multiprocessor_count"] 43 | NUM_REGS = properties["max_num_regs"] 44 | SIZE_SMEM = properties["max_shared_mem"] 45 | WARP_SIZE = properties["warpSize"] 46 | _DEVICE_PROPERTIES = DeviceProperties(NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE) 47 | return _DEVICE_PROPERTIES 48 | 49 | 50 | @dataclass 51 | class KernelConfig: 52 | BLOCK_SIZE_M: int = 32 53 | BLOCK_SIZE_N: int = 32 54 | BLOCK_SIZE_K: int = 32 55 | num_warps: int = 4 56 | num_stages: int = 2 57 | flatten: bool = True 58 | permute_x: bool = False 59 | permute_y: bool = False 60 | fuse_mul_post: bool = False 61 | use_tma_store: bool = False 62 | 63 | def to_string(self, include_tuning_params: bool = False, include_tma: bool = False): 64 | s = [] 65 | if self.permute_x: 66 | s.append("permute_x") 67 | if self.permute_y: 68 | s.append("permute_y") 69 | if include_tuning_params: 70 | s.append( 71 | f"BLOCK_SIZE_M={self.BLOCK_SIZE_M},BLOCK_SIZE_N={self.BLOCK_SIZE_N},BLOCK_SIZE_K={self.BLOCK_SIZE_K},num_warps={self.num_warps},num_stages={self.num_stages},flatten={self.flatten}" 72 | ) 73 | if include_tma: 74 | for f in fields(self): 75 | if f.name.startswith("use_tma_"): 76 | if getattr(self, f.name): 77 | s.append(f.name) 78 | return ",".join(s) 79 | 80 | 81 | @dataclass 82 | class KernelConfigForward(KernelConfig): 83 | use_tma_load_w: bool = False 84 | use_tma_load_x: bool = False 85 | 86 | 87 | @dataclass 88 | class KernelConfigBackward_dW(KernelConfig): 89 | use_tma_load_dy: bool = False 90 | use_tma_load_x: bool = False 91 | 92 | 93 | @dataclass 94 | class KernelConfigBackward_dX(KernelConfig): 95 | use_tma_load_dy: bool = False 96 | use_tma_load_w: bool = False 97 | 98 | 99 | @dataclass 100 | class KernelResult: 101 | torch_time: float 102 | triton_time: float 103 | speedup: float 104 | kernel_config: KernelConfig 105 | 106 | def to_dict(self): 107 | return OrderedDict( 108 | **asdict(self.kernel_config), 109 | torch_time=self.torch_time, 110 | triton_time=self.triton_time, 111 | speedup=self.speedup, 112 | ) 113 | 114 | @staticmethod 115 | def to_dataframe( 116 | results: list["KernelResult"], sort_by: str = "speedup", ascending: bool = False 117 | ): 118 | df = pd.DataFrame([result.to_dict() for result in results]) 119 | df = df.sort_values(by=sort_by, ascending=ascending) 120 | return df 121 | 122 | @staticmethod 123 | def to_csv( 124 | results: list["KernelResult"], 125 | sort_by: str = "speedup", 126 | ascending: bool = False, 127 | filename: str = "results.csv", 128 | ): 129 | df = KernelResult.to_dataframe(results, sort_by, ascending) 130 | df.to_csv(filename, index=False) 131 | 132 | @staticmethod 133 | def print_table( 134 | results: list["KernelResult"], 135 | sort_by: str = "speedup", 136 | ascending: bool = False, 137 | num_results: int = 10, 138 | ): 139 | df = KernelResult.to_dataframe(results, sort_by, ascending) 140 | print(df.head(num_results).to_string(index=False)) 141 | 142 | 143 | def get_kernel_configs( 144 | BLOCK_M=DEFAULT_M_BLOCK_SIZES, 145 | BLOCK_N=DEFAULT_N_BLOCK_SIZES, 146 | BLOCK_K=DEFAULT_K_BLOCK_SIZES, 147 | num_warps=DEFAULT_NUM_WARPS, 148 | num_stages=DEFAULT_NUM_STAGES, 149 | use_tma_loads=BOOLS, 150 | fuse_permute=BOOLS, 151 | ): 152 | kernel_configs_fwd = [] 153 | kernel_configs_backward_dW = [] 154 | kernel_configs_backward_dX = [] 155 | for block_m, block_n, block_k, w, s, use_tma_load, permute in product( 156 | BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, use_tma_loads, 
fuse_permute 157 | ): 158 | kernel_configs_fwd.append( 159 | KernelConfigForward( 160 | BLOCK_SIZE_M=block_m, 161 | BLOCK_SIZE_N=block_n, 162 | BLOCK_SIZE_K=block_k, 163 | num_warps=w, 164 | num_stages=s, 165 | use_tma_load_x=use_tma_load, 166 | use_tma_load_w=use_tma_load, 167 | use_tma_store=False, 168 | permute_x=permute, 169 | permute_y=permute, 170 | ) 171 | ) 172 | kernel_configs_backward_dW.append( 173 | KernelConfigBackward_dW( 174 | BLOCK_SIZE_M=block_m, 175 | BLOCK_SIZE_N=block_n, 176 | BLOCK_SIZE_K=block_k, 177 | num_warps=w, 178 | num_stages=s, 179 | use_tma_load_dy=use_tma_load, 180 | use_tma_load_x=use_tma_load, 181 | use_tma_store=False, 182 | permute_x=permute, 183 | permute_y=permute, 184 | ) 185 | ) 186 | kernel_configs_backward_dX.append( 187 | KernelConfigBackward_dX( 188 | BLOCK_SIZE_M=block_m, 189 | BLOCK_SIZE_N=block_n, 190 | BLOCK_SIZE_K=block_k, 191 | num_warps=w, 192 | num_stages=s, 193 | use_tma_load_dy=use_tma_load, 194 | use_tma_load_w=use_tma_load, 195 | use_tma_store=False, 196 | permute_x=permute, 197 | permute_y=permute, 198 | ) 199 | ) 200 | 201 | kernel_configs_fwd = prune_kernel_configs_fwd(kernel_configs_fwd) 202 | kernel_configs_backward_dW = prune_kernel_configs_backward_dW( 203 | kernel_configs_backward_dW 204 | ) 205 | kernel_configs_backward_dX = prune_kernel_configs_backward_dX( 206 | kernel_configs_backward_dX 207 | ) 208 | return kernel_configs_fwd, kernel_configs_backward_dW, kernel_configs_backward_dX 209 | 210 | 211 | def prune_kernel_configs_fwd(configs: list[KernelConfigForward]): 212 | pruned_configs = [] 213 | for config in configs: 214 | if config.use_tma_load_x and config.permute_x: 215 | continue 216 | if config.permute_x and config.permute_y: 217 | continue 218 | if config.use_tma_store and config.permute_y: 219 | continue 220 | pruned_configs.append(config) 221 | return pruned_configs 222 | 223 | 224 | def prune_kernel_configs_backward_dX(configs: list[KernelConfigBackward_dX]): 225 | pruned_configs = [] 226 | for config in configs: 227 | if config.use_tma_load_dy and config.permute_y: 228 | continue 229 | if config.permute_x and config.permute_y: 230 | continue 231 | if config.use_tma_store and config.permute_x: 232 | continue 233 | pruned_configs.append(config) 234 | return pruned_configs 235 | 236 | 237 | def prune_kernel_configs_backward_dW(configs: list[KernelConfigBackward_dW]): 238 | pruned_configs = [] 239 | for config in configs: 240 | if config.use_tma_load_dy and config.permute_y: 241 | continue 242 | if config.use_tma_load_x and config.permute_x: 243 | continue 244 | if config.permute_x and config.permute_y: 245 | continue 246 | pruned_configs.append(config) 247 | return pruned_configs 248 | 249 | 250 | class TritonTuningContext: 251 | def __init__(self, kernel_config: KernelConfig): 252 | self.kernel_config = kernel_config 253 | self.success = True 254 | 255 | def __enter__(self): 256 | # Setup code can be added here if needed 257 | return self 258 | 259 | def __exit__(self, exc_type, exc_value, traceback): 260 | if exc_type is OutOfResources: 261 | name = exc_value.name 262 | required = exc_value.required 263 | limit = exc_value.limit 264 | print( 265 | f"Kernel config {self.kernel_config} failed: {name}, required: {required}, limit: {limit}" 266 | ) 267 | self.success = False 268 | elif exc_type is not None: 269 | print( 270 | f"Error running Triton grouped GEMM for kernel config: {self.kernel_config}: {exc_value}" 271 | ) 272 | self.success = False 273 | # Return False to propagate exceptions, True to suppress them 
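        # Returning True below suppresses the exception so that one failing kernel
        # config (e.g. OutOfResources) does not abort the whole tuning sweep; the
        # failure is recorded in `self.success` for the caller to inspect.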
274 | return True 275 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/reference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/kernels/moe/grouped_gemm/reference/__init__.py -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/reference/moe_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig 3 | from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock 4 | 5 | from grouped_gemm.interface import grouped_gemm 6 | from grouped_gemm.kernels.tuning import ( 7 | KernelConfigBackward_dW, 8 | KernelConfigBackward_dX, 9 | KernelConfigForward, 10 | ) 11 | from grouped_gemm.reference.moe_ops import ( 12 | Qwen3MoeGroupedGEMMBlock, 13 | permute, 14 | unpermute, 15 | ) 16 | 17 | """ 18 | Reference implementation of MoE block using grouped gemm. 19 | 20 | This is the same as the Qwen3MoeGroupedGEMMBlock but with triton grouped gemm in place of torch-native grouped gemm implementation. 21 | 22 | NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging. 23 | """ 24 | 25 | 26 | class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock): 27 | def __init__( 28 | self, 29 | config: Qwen3MoeConfig, 30 | gate: torch.Tensor, 31 | gate_up_proj: torch.Tensor, 32 | down_proj: torch.Tensor, 33 | permute_x: bool = True, 34 | permute_y: bool = True, 35 | autotune: bool = True, 36 | kernel_config_fwd: KernelConfigForward = None, 37 | kernel_config_bwd_dW: KernelConfigBackward_dW = None, 38 | kernel_config_bwd_dX: KernelConfigBackward_dX = None, 39 | dW_only: bool = False, 40 | dX_only: bool = False, 41 | ): 42 | super().__init__(config, gate, gate_up_proj, down_proj) 43 | self.permute_x = permute_x 44 | self.permute_y = permute_y 45 | self.autotune = autotune 46 | if not autotune: 47 | assert ( 48 | kernel_config_fwd is not None 49 | and kernel_config_bwd_dW is not None 50 | and kernel_config_bwd_dX is not None 51 | ), "Kernel configs must be provided if autotune is False" 52 | self.kernel_config_fwd = kernel_config_fwd 53 | self.kernel_config_bwd_dW = kernel_config_bwd_dW 54 | self.kernel_config_bwd_dX = kernel_config_bwd_dX 55 | self.dW_only = dW_only 56 | self.dX_only = dX_only 57 | 58 | @classmethod 59 | def from_hf( 60 | cls, 61 | moe_block: Qwen3MoeSparseMoeBlock, 62 | permute_x: bool = True, 63 | permute_y: bool = True, 64 | autotune: bool = True, 65 | kernel_config_fwd: KernelConfigForward = None, 66 | kernel_config_bwd_dW: KernelConfigBackward_dW = None, 67 | kernel_config_bwd_dX: KernelConfigBackward_dX = None, 68 | dW_only: bool = False, 69 | dX_only: bool = False, 70 | ): 71 | config: Qwen3MoeConfig = moe_block.experts[0].config 72 | gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights( 73 | moe_block 74 | ) 75 | return cls( 76 | config, 77 | gate, 78 | gate_up_proj, 79 | down_proj, 80 | permute_x=permute_x, 81 | permute_y=permute_y, 82 | autotune=autotune, 83 | kernel_config_fwd=kernel_config_fwd, 84 | kernel_config_bwd_dW=kernel_config_bwd_dW, 85 | kernel_config_bwd_dX=kernel_config_bwd_dX, 86 | dW_only=dW_only, 87 | dX_only=dX_only, 88 | ) 89 | 90 | def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: 91 | batch_size, sequence_length, hidden_dim = hidden_states.shape 92 | num_tokens = batch_size * sequence_length 93 | total_tokens = num_tokens * self.top_k 94 | 95 | hidden_states = hidden_states.view(-1, hidden_dim) 96 | 97 | router_logits, routing_weights, selected_experts = self.run_router( 98 | hidden_states 99 | ) 100 | # Pre-processing 101 | # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order 102 | # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph 103 | token_counts_by_expert, gather_indices = ( 104 | self.get_token_counts_and_gather_indices(selected_experts) 105 | ) 106 | 107 | # 2. permute_x -> permutation will be fused in prologue of first grouped gemm 108 | if not self.permute_x: 109 | hidden_states = permute(hidden_states, gather_indices, self.top_k) 110 | # Start expert computation 111 | hidden_states = grouped_gemm( 112 | X=hidden_states, 113 | W=self.gate_up_proj, 114 | m_sizes=token_counts_by_expert, 115 | gather_indices=gather_indices, 116 | topk=self.top_k, 117 | permute_x=self.permute_x, 118 | permute_y=False, # output of first grouped gemm should never be permuted 119 | autotune=self.autotune, 120 | kernel_config_fwd=self.kernel_config_fwd, 121 | kernel_config_bwd_dW=self.kernel_config_bwd_dW, 122 | kernel_config_bwd_dX=self.kernel_config_bwd_dX, 123 | is_first_gemm=True, 124 | dW_only=self.dW_only, 125 | dX_only=self.dX_only, 126 | ) 127 | hidden_states = self.act_and_mul(hidden_states) 128 | hidden_states = grouped_gemm( 129 | X=hidden_states, 130 | W=self.down_proj, 131 | m_sizes=token_counts_by_expert, 132 | gather_indices=gather_indices, 133 | topk=self.top_k, 134 | permute_x=False, 135 | permute_y=self.permute_y, 136 | autotune=self.autotune, 137 | kernel_config_fwd=self.kernel_config_fwd, 138 | kernel_config_bwd_dW=self.kernel_config_bwd_dW, 139 | kernel_config_bwd_dX=self.kernel_config_bwd_dX, 140 | is_first_gemm=False, 141 | dW_only=self.dW_only, 142 | dX_only=self.dX_only, 143 | ) 144 | 145 | # Post-processing 146 | # 1. Unpermute from expert order to token order 147 | if not self.permute_y: 148 | hidden_states = unpermute(hidden_states, gather_indices) 149 | 150 | # 2. Merge topk weights 151 | hidden_states = ( 152 | hidden_states.view(num_tokens, self.top_k, hidden_dim) 153 | * routing_weights[..., None] 154 | ) 155 | hidden_states = hidden_states.sum(dim=1) 156 | 157 | hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) 158 | return hidden_states, router_logits 159 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def permute(X: torch.Tensor, gather_indices: torch.Tensor, topk: int): 7 | """ 8 | Scatters X to a new tensor with shape [total_tokens, hidden_dim] where total_tokens is num_tokens * topk, 9 | permuting the tokens according to sorted_token_idx. 10 | 11 | Helper for grouped gemm where hidden states need be ordered by expert. 
12 | X: [num_tokens, hidden_dim] 13 | sorted_token_idx: [num_tokens * topk] 14 | topk: int 15 | 16 | Returns: 17 | [total_tokens, hidden_dim] 18 | """ 19 | assert gather_indices.ndim == 1 20 | X = X.view(-1, X.shape[-1]) 21 | # Shortcut for topk == 1 22 | if topk == 1: 23 | return X[gather_indices] 24 | 25 | return X[gather_indices // topk] 26 | 27 | 28 | def unpermute(X: torch.Tensor, gather_indices: torch.Tensor): 29 | X = X.view(-1, X.shape[-1]) if X.ndim > 2 else X 30 | unpermuted = torch.empty_like(X) 31 | unpermuted.index_copy_(0, gather_indices, X) 32 | return unpermuted.view_as(X) 33 | 34 | 35 | def calculate_topk( 36 | gating_output: torch.Tensor, 37 | top_k: int, 38 | use_sigmoid: bool, 39 | renormalize: bool, 40 | pre_act: bool = True, 41 | post_act: bool = False, 42 | ): 43 | """ 44 | If post_act is True, then activation function is run AFTER topk 45 | If post_act is False, then activation function is run BEFORE topk 46 | 47 | This is to align with triton_bench implementation (post_act) whereas most models use pre_act (e.g. llama4, deepseek) 48 | """ 49 | assert pre_act ^ post_act, "only one of pre_act or post_act can be True" 50 | 51 | def _activation(gating_output: torch.Tensor): 52 | if use_sigmoid: 53 | scores = torch.sigmoid(gating_output.to(torch.float32)).to(gating_output.dtype) 54 | else: 55 | scores = F.softmax(gating_output.to(torch.float32), dim=1).to(gating_output.dtype) 56 | 57 | return scores 58 | 59 | if pre_act: 60 | scores = _activation(gating_output) 61 | else: 62 | scores = gating_output 63 | 64 | topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1) 65 | 66 | if post_act: 67 | topk_weights = _activation(topk_weights) 68 | 69 | if renormalize: 70 | topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to(gating_output.dtype) 71 | 72 | return topk_weights, topk_ids 73 | 74 | 75 | @torch.no_grad() 76 | def get_routing_indices(selected_experts, num_experts, return_scatter_indices: bool = False): 77 | """ 78 | Returns: 79 | token_counts_by_expert: [num_experts] 80 | gather_indices: [num_tokens] 81 | scatter_indices [Optional] (torch.Tensor): 82 | Indices for unpermuting gathered inputs back to token order, shape ``(bs * seqlen * top_k,)``. 
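    Example (illustrative): with num_experts=3, top_k=2 and
    selected_experts = [[2, 0], [1, 2]], the flattened expert ids are [2, 0, 1, 2],
    so token_counts_by_expert = [1, 1, 2] and gather_indices = [1, 2, 0, 3],
    i.e. token copies are reordered so all tokens routed to expert 0 come first,
    then expert 1, and so on.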
83 | """ 84 | # group tokens together by expert indices from 0 to num_experts and pass that to experts forward 85 | token_counts_by_expert = torch.histc( 86 | selected_experts.view(-1), 87 | bins=num_experts, 88 | min=0, 89 | max=num_experts, 90 | ) 91 | # token_indices_experts_sorted shape (bs*slen*top_k,) 92 | gather_indices = torch.argsort(selected_experts.view(-1), stable=True) 93 | if return_scatter_indices: 94 | scatter_indices = gather_indices.argsort() 95 | return token_counts_by_expert, gather_indices, scatter_indices 96 | else: 97 | return token_counts_by_expert, gather_indices 98 | 99 | 100 | def torch_grouped_gemm(X, W, m_sizes, transpose=True): 101 | """ 102 | X: [M, K] if forward, else [M, N] 103 | W: [E, N, K] 104 | m_sizes: [E] 105 | 106 | Returns: 107 | Y: [M, N] if forward, else [M, K] 108 | """ 109 | X = X.view(-1, X.shape[-1]) 110 | M, K = X.shape 111 | 112 | assert m_sizes.ndim == 1 113 | E = m_sizes.shape[0] 114 | 115 | assert W.ndim == 3 116 | assert W.shape[0] == E 117 | 118 | N = W.shape[1] 119 | 120 | result = torch.zeros((M, N), dtype=X.dtype, device=X.device) 121 | 122 | m_start = 0 123 | for g in range(E): 124 | m_size = m_sizes[g] 125 | if m_size > 0: 126 | m_end = m_start + m_size 127 | 128 | # Extract group input 129 | # m_size x K 130 | X_g = X[m_start:m_end] 131 | # N x K 132 | W_g = W[g] 133 | 134 | # Y_g = X_g @ W_g.T -> [m_size, N] 135 | W_g = W_g.T if transpose else W_g 136 | Y_g = X_g @ W_g 137 | 138 | result[m_start:m_end] = Y_g 139 | 140 | m_start = m_end 141 | return result 142 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | git+https://github.com/huggingface/transformers.git@main 3 | pytest 4 | pandas 5 | ruff -------------------------------------------------------------------------------- /unsloth/kernels/moe/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/kernels/moe/tests/__init__.py -------------------------------------------------------------------------------- /unsloth/kernels/moe/tests/run_qwen3_moe_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | SEQLENS=(1024) 6 | DTYPES=(bfloat16) 7 | PERMUTE_X=(false true) 8 | PERMUTE_Y=(false true) 9 | AUTOTUNE=(false true) 10 | 11 | for SEQLEN in "${SEQLENS[@]}"; do 12 | for DTYPE in "${DTYPES[@]}"; do 13 | for PX in "${PERMUTE_X[@]}"; do 14 | for PY in "${PERMUTE_Y[@]}"; do 15 | for AT in "${AUTOTUNE[@]}"; do 16 | 17 | ARGS=() 18 | [[ "$PX" == "true" ]] && ARGS+=("--permute_x") 19 | [[ "$PY" == "true" ]] && ARGS+=("--permute_y") 20 | [[ "$AT" == "true" ]] && ARGS+=("--autotune") 21 | 22 | ARGS+=(--seqlen "$SEQLEN" --dtype "$DTYPE") 23 | 24 | echo "Running with args: ${ARGS[*]}" 25 | if ! 
python -m tests.test_qwen3_moe "${ARGS[@]}"; then 26 | echo "❌ Test failed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE" >&2 27 | else 28 | echo "✅ Test passed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE" 29 | fi 30 | 31 | done 32 | done 33 | done 34 | done 35 | done 36 | -------------------------------------------------------------------------------- /unsloth/kernels/moe/tests/test_llama4_moe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from contextlib import contextmanager 4 | from functools import partial 5 | 6 | import pytest 7 | import torch 8 | from transformers import AutoConfig 9 | from transformers.models.llama4 import Llama4Config, Llama4TextConfig 10 | from transformers.models.llama4.modeling_llama4 import Llama4TextMoe 11 | 12 | from grouped_gemm.kernels.tuning import ( 13 | KernelConfigBackward_dW, 14 | KernelConfigBackward_dX, 15 | KernelConfigForward, 16 | ) 17 | from grouped_gemm.reference.layers.llama4_moe import ( 18 | Llama4GroupedGemmTextMoe, 19 | Llama4TritonTextMoe, 20 | ) 21 | 22 | TOLERANCES = { 23 | torch.bfloat16: (1e-2, 1e-2), 24 | torch.float16: (1e-3, 1e-3), 25 | torch.float: (1e-5, 1e-5), 26 | } 27 | 28 | LLAMA4_SCOUT_ID = "meta-llama/Llama-4-Scout-17B-16E" 29 | SEED = 42 30 | SEQ_LENS = [1024] 31 | DTYPES = [torch.bfloat16] 32 | # Reduce the number of autotuning configs to prevent excessive runtime 33 | NUM_AUTOTUNE_CONFIGS = 50 34 | 35 | 36 | @contextmanager 37 | def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80): 38 | print(char * num_chars) 39 | print(prelude) 40 | yield 41 | print(epilogue) 42 | print(char * num_chars) 43 | 44 | 45 | def get_text_config(model_id): 46 | config: Llama4Config = AutoConfig.from_pretrained(model_id) 47 | return config.text_config 48 | 49 | 50 | def prep_triton_kernel_traits(autotune): 51 | if not autotune: 52 | kernel_config_fwd = KernelConfigForward() 53 | kernel_config_bwd_dW = KernelConfigBackward_dW() 54 | kernel_config_bwd_dX = KernelConfigBackward_dX() 55 | else: 56 | from grouped_gemm.kernels.backward import ( 57 | _autotuned_grouped_gemm_dW_kernel, 58 | _autotuned_grouped_gemm_dX_kernel, 59 | ) 60 | from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel 61 | 62 | # Hack to reduce number of autotuning configs 63 | _autotuned_grouped_gemm_forward_kernel.configs = ( 64 | _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS] 65 | ) 66 | _autotuned_grouped_gemm_dW_kernel.configs = ( 67 | _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS] 68 | ) 69 | _autotuned_grouped_gemm_dX_kernel.configs = ( 70 | _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS] 71 | ) 72 | 73 | kernel_config_fwd = None 74 | kernel_config_bwd_dW = None 75 | kernel_config_bwd_dX = None 76 | 77 | return kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX 78 | 79 | 80 | def sparse_to_dense(t: torch.Tensor): 81 | t = t.sum(dim=0).view(-1) 82 | return t 83 | 84 | 85 | @torch.no_grad() 86 | def _check_diff( 87 | t1: torch.Tensor, 88 | t2: torch.Tensor, 89 | atol, 90 | rtol, 91 | precision=".6f", 92 | verbose=False, 93 | msg="", 94 | ): 95 | t2 = t2.view_as(t1) 96 | diff = t1.sub(t2).abs().max().item() 97 | if verbose: 98 | if msg == "": 99 | msg = "diff" 100 | print(f"{msg}: {diff:{precision}}") 101 | assert torch.allclose(t1, t2, atol=atol, rtol=rtol) 102 | 103 | 104 | def run_backwards(y: 
torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module): 105 | y.backward(grad_output) 106 | for name, param in module.named_parameters(): 107 | assert param.grad is not None, f"{name} missing grad!" 108 | 109 | 110 | def _check_grads( 111 | m1: torch.nn.Module, 112 | m2: torch.nn.Module, 113 | atol, 114 | rtol, 115 | precision=".6f", 116 | verbose=False, 117 | msg="", 118 | ): 119 | for name, param in m1.named_parameters(): 120 | _check_diff( 121 | param.grad, 122 | m2.get_parameter(name).grad, 123 | atol=atol, 124 | rtol=rtol, 125 | precision=precision, 126 | verbose=verbose, 127 | msg=f"{msg}:{name}.grad", 128 | ) 129 | 130 | 131 | @pytest.fixture 132 | def model_config(): 133 | return AutoConfig.from_pretrained(LLAMA4_SCOUT_ID).text_config 134 | 135 | 136 | @pytest.mark.parametrize( 137 | "overlap_router_shared", 138 | [False, True], 139 | ids=lambda x: "overlap_router_shared" if x else "no_overlap", 140 | ) 141 | @pytest.mark.parametrize( 142 | "permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y" 143 | ) 144 | @pytest.mark.parametrize( 145 | "permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x" 146 | ) # Llama4 does not support permute_x 147 | @pytest.mark.parametrize( 148 | "autotune", [True], ids=lambda x: "autotune" if x else "manual" 149 | ) 150 | @pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}") 151 | @pytest.mark.parametrize("dtype", DTYPES, ids=str) 152 | def test_llama4_ref( 153 | dtype: torch.dtype, 154 | seqlen, 155 | autotune: bool, 156 | permute_x: bool, 157 | permute_y: bool, 158 | overlap_router_shared: bool, 159 | model_config: Llama4TextConfig, # test fixture 160 | bs: int = 1, 161 | device="cuda", 162 | precision=".6f", 163 | verbose=False, 164 | ): 165 | torch.manual_seed( 166 | SEED 167 | ) # Should not be needed when running using pytest -- autouse fixture in conftest.py 168 | device = "cuda" 169 | hidden_dim = model_config.hidden_size 170 | atol, rtol = TOLERANCES[dtype] 171 | check_diff = partial( 172 | _check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose 173 | ) 174 | check_grads = partial( 175 | _check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose 176 | ) 177 | 178 | # Reference op -- HF 179 | llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device) 180 | 181 | # Torch grouped gemm impl 182 | llama4_gg_ref = Llama4GroupedGemmTextMoe( 183 | model_config, overlap_router_shared=overlap_router_shared 184 | ).to(dtype=dtype, device=device) 185 | llama4_gg_ref.copy_weights(llama4_ref) 186 | llama4_gg_ref.check_weights(llama4_ref) 187 | 188 | x_ref = torch.randn( 189 | bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True 190 | ) 191 | x_torch_gg = x_ref.detach().clone().requires_grad_() 192 | x_triton = x_ref.detach().clone().requires_grad_() 193 | 194 | y_ref, routing_ref = llama4_ref(x_ref) 195 | y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg) 196 | assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}" 197 | with annotated_context("Testing torch grouped gemm Llama4TextMoe"): 198 | check_diff(y_ref, y_torch_gg, msg="y_torch_gg") 199 | check_diff( 200 | sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg" 201 | ) 202 | 203 | kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = ( 204 | prep_triton_kernel_traits(autotune) 205 | ) 206 | 207 | llama4_triton = Llama4TritonTextMoe( 208 | model_config, 209 | overlap_router_shared=overlap_router_shared, 210 | 
permute_x=permute_x, 211 | permute_y=permute_y, 212 | autotune=autotune, 213 | kernel_config_fwd=kernel_config_fwd, 214 | kernel_config_bwd_dW=kernel_config_bwd_dW, 215 | kernel_config_bwd_dX=kernel_config_bwd_dX, 216 | ).to(device=device, dtype=dtype) 217 | llama4_triton.copy_weights(llama4_ref) 218 | llama4_triton.check_weights(llama4_ref) 219 | 220 | y_triton, routing_triton = llama4_triton(x_triton) 221 | with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"): 222 | check_diff(y_ref, y_triton, msg="y_triton") 223 | check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton") 224 | 225 | ref_grad = torch.randn_like(y_ref) 226 | run_backwards(y_ref, ref_grad, llama4_ref) 227 | run_backwards(y_torch_gg, ref_grad, llama4_gg_ref) 228 | with annotated_context("Testing torch group gemm Llama4TextMoe backward"): 229 | check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg") 230 | 231 | run_backwards(y_triton, ref_grad, llama4_triton) 232 | with annotated_context("Testing triton group gemm Llama4TextMoe backward"): 233 | check_grads(llama4_ref, llama4_triton, msg="triton") 234 | 235 | 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument("--seqlen", type=int, default=1024) 239 | parser.add_argument( 240 | "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16" 241 | ) 242 | args = parser.parse_args() 243 | args.dtype = getattr(torch, args.dtype) 244 | args_dict = vars(args) 245 | 246 | model_id = LLAMA4_SCOUT_ID 247 | 248 | text_config: Llama4TextConfig = get_text_config(model_id) 249 | for overlap in [False, True]: 250 | test_llama4_ref( 251 | seqlen=args.seqlen, 252 | model_config=text_config, 253 | dtype=args.dtype, 254 | autotune=True, 255 | permute_x=False, 256 | permute_y=True, 257 | overlap_router_shared=overlap, 258 | verbose=True, 259 | ) 260 | -------------------------------------------------------------------------------- /unsloth/kernels/rope_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import triton 16 | import triton.language as tl 17 | import torch 18 | from .utils import calculate_settings, torch_cuda_device 19 | ROPE_GROUP_SIZE : int = 4 20 | 21 | def _rope_embedding( 22 | Q, Q_row_stride, 23 | cos, cos_row_stride, 24 | sin, sin_row_stride, 25 | seqlen, 26 | head_dim : tl.constexpr, 27 | n_heads : tl.constexpr, 28 | BACKWARD_PASS : tl.constexpr, 29 | BLOCK_SIZE : tl.constexpr, 30 | ): 31 | """ 32 | Calculates the RoPE Embedding quickly 33 | RoPE is Q * cos + rotate_half(Q) * sin 34 | See our blog post for more info 35 | """ 36 | ROPE_GROUP_SIZE = 4 37 | row_position = tl.program_id(0) 38 | group_head_position = tl.program_id(1) 39 | col_offsets = tl.arange(0, BLOCK_SIZE) 40 | half_head_dim = head_dim // 2 41 | mask = col_offsets < half_head_dim 42 | 43 | sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \ 44 | half_head_dim*0 + col_offsets, mask = mask, other = 0) 45 | cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \ 46 | half_head_dim*0 + col_offsets, mask = mask, other = 0) 47 | 48 | if BACKWARD_PASS: 49 | # See our blog post for more info. 50 | sin1 = -sin1 51 | pass 52 | 53 | # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8 54 | head_start = group_head_position * ROPE_GROUP_SIZE 55 | head_end = min((head_start + ROPE_GROUP_SIZE), n_heads) 56 | 57 | # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238) 58 | for k in range(head_start, head_end): 59 | offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets 60 | offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim 61 | 62 | # For Gemma - sometimes RoPE must be done in float32 and not bfloat16 63 | Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype) 64 | Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype) 65 | 66 | tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask) 67 | tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask) 68 | pass 69 | pass 70 | _rope_embedding = triton.jit(_rope_embedding) 71 | _rope_embedding = triton.heuristics( 72 | { 73 | "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]), 74 | } 75 | )(_rope_embedding) 76 | 77 | 78 | class Fast_RoPE_Embedding(torch.autograd.Function): 79 | @staticmethod 80 | def forward(ctx, Q, cos, sin): 81 | cos, sin = cos.squeeze(), sin.squeeze() 82 | batch : int 83 | seq_len : int 84 | n_heads : int 85 | head_dim : int 86 | batch, seq_len, n_heads, head_dim = Q.shape 87 | Q = Q.view(batch*seq_len, n_heads*head_dim) 88 | n_rows : int 89 | n_cols : int 90 | n_rows, n_cols = Q.shape 91 | assert(seq_len <= cos.shape[0]) 92 | 93 | # [TODO] Changing blocksize to head_dim//2 seems to have 94 | # some concurrency / un-deterministic issues. 95 | BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2) 96 | 97 | # group_size = 4 # 4 or 8, too large group_size can hurt performance. 
98 | div : int 99 | mod : int 100 | div, mod = divmod(n_heads, ROPE_GROUP_SIZE) 101 | n_groups : int = div + (mod != 0) 102 | 103 | with torch_cuda_device(Q.device): 104 | _rope_embedding[(n_rows, n_groups, )]( 105 | Q, Q.stride(0), 106 | cos, cos.stride(0), 107 | sin, sin.stride(0), 108 | seq_len, 109 | head_dim, n_heads, 110 | BACKWARD_PASS = False, 111 | BLOCK_SIZE = BLOCK_SIZE, 112 | num_warps = num_warps, 113 | ) 114 | ctx.BLOCK_SIZE = BLOCK_SIZE 115 | ctx.num_warps = num_warps 116 | ctx.n_groups = n_groups 117 | ctx.cos = cos 118 | ctx.sin = sin 119 | return Q.view(batch, seq_len, n_heads, head_dim) 120 | pass 121 | 122 | @staticmethod 123 | def backward(ctx, dY): 124 | batch : int 125 | seq_len : int 126 | n_heads : int 127 | head_dim : int 128 | batch, seq_len, n_heads, head_dim = dY.shape 129 | dY = dY.reshape(batch*seq_len, n_heads*head_dim) 130 | # Must be reshape not view 131 | n_rows : int 132 | n_cols : int 133 | n_rows, n_cols = dY.shape 134 | 135 | cos = ctx.cos 136 | sin = ctx.sin 137 | 138 | with torch_cuda_device(dY.device): 139 | _rope_embedding[(n_rows, ctx.n_groups, )]( 140 | dY, dY .stride(0), 141 | cos, cos.stride(0), 142 | sin, sin.stride(0), 143 | seq_len, head_dim, n_heads, 144 | BACKWARD_PASS = True, 145 | BLOCK_SIZE = ctx.BLOCK_SIZE, 146 | num_warps = ctx.num_warps, 147 | ) 148 | dY = dY.view(batch, seq_len, n_heads, head_dim) 149 | return dY, None, None, 150 | pass 151 | pass 152 | 153 | # [TODO] Unsure why RoPE Embedding is not torch.compiling properly 154 | @torch.compiler.disable 155 | def fast_rope_embedding(Q, K, cos, sin): 156 | Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) 157 | K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) 158 | return Q, K 159 | pass 160 | 161 | 162 | class Slow_RoPE_Embedding(torch.autograd.Function): 163 | @staticmethod 164 | def forward(ctx, Q, cos, sin, position_ids): 165 | if position_ids is not None: 166 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 167 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 168 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 169 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 170 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 171 | 172 | # Q * cos + rotate_half(Q) * sin 173 | half = Q.shape[-1]//2 174 | RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1) 175 | Q *= cos 176 | Q.addcmul_(RH_Q, sin) 177 | # RH_Q *= sin 178 | # Q += RH_Q 179 | ctx.save_for_backward(cos, sin) 180 | return Q 181 | pass 182 | 183 | @staticmethod 184 | def backward(ctx, dY): 185 | cos, sin = ctx.saved_tensors 186 | # Q * cos + rotate_half.T(Q) * sin 187 | half = dY.shape[-1]//2 188 | RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1) 189 | dY *= cos 190 | dY.addcmul_(RH_dY, sin) 191 | # RH_dY *= sin 192 | # dY += RH_dY 193 | return dY, None, None, None 194 | pass 195 | pass 196 | 197 | 198 | def inplace_rope_embedding(Q, K, cos, sin, position_ids): 199 | Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids) 200 | K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids) 201 | return Q, K 202 | pass 203 | -------------------------------------------------------------------------------- /unsloth/kernels/swiglu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import triton 16 | import triton.language as tl 17 | import torch 18 | from .utils import calculate_settings, torch_cuda_device 19 | 20 | 21 | @triton.jit 22 | def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): 23 | block_idx = tl.program_id(0) 24 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 25 | mask = offsets < n_elements 26 | 27 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 28 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 29 | 30 | # f = e * sigmoid(e) 31 | f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row)) 32 | f_row = f_row.to(g_row.dtype) # Exact copy from HF 33 | # h = f * g 34 | h_row = f_row * g_row 35 | 36 | # Store h 37 | tl.store(h + offsets, h_row, mask = mask) 38 | pass 39 | 40 | 41 | def swiglu_fg_kernel(e, g): 42 | batch, seq_len, hd = e.shape 43 | n_elements = e.numel() 44 | h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device) 45 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 46 | with torch_cuda_device(e.device): 47 | _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) 48 | return h 49 | pass 50 | 51 | 52 | @triton.jit 53 | def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,): 54 | """ 55 | e = e.float() 56 | se = 1.0 / (1.0 + torch.exp(-e)) 57 | f = (se * e).to(dtype) 58 | h = f * g 59 | df = DW * f 60 | dg = DW * g 61 | de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) 62 | """ 63 | block_idx = tl.program_id(0) 64 | offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 65 | mask = offsets < n_elements 66 | 67 | DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32) 68 | e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32) 69 | g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32) 70 | 71 | # e = e.float() 72 | # se = 1.0 / (1.0 + torch.exp(-e)) 73 | se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row)) 74 | # f = (se * e).to(dtype) 75 | f_row = se_row * e_row 76 | f_row = f_row.to(DW_row.dtype) 77 | # h = f * g 78 | h_row = f_row * g_row 79 | # df = DW * f 80 | df_row = DW_row * f_row 81 | # dg = DW * g 82 | dg_row = DW_row * g_row 83 | # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype) 84 | de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row)) 85 | de_row = de_row.to(DW_row.dtype) 86 | 87 | # Store derivatives in buffers 88 | tl.store(DW + offsets, h_row, mask = mask) # h = f * g 89 | tl.store(e + offsets, df_row, mask = mask) # df = DW * f 90 | tl.store(g + offsets, de_row, mask = mask) # de 91 | pass 92 | 93 | 94 | def swiglu_DWf_DW_dfg_kernel(DW, e, g): 95 | batch_seq_len, hd = e.shape 96 | n_elements = e.numel() 97 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 98 | with torch_cuda_device(e.device): 99 | _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) 100 | return DW, e, g 101 | pass 
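# Illustrative usage sketch (not part of the library API): given the two MLP
# projections of a SwiGLU block, e.g. hypothetical gate_proj(x) -> e and
# up_proj(x) -> g, both CUDA tensors of shape [batch, seq_len, intermediate_dim]:
#   h = swiglu_fg_kernel(e, g)   # h = silu(e) * g = e * sigmoid(e) * g, same shape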
102 | -------------------------------------------------------------------------------- /unsloth/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .llama import FastLlamaModel 16 | from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel 17 | from .mistral import FastMistralModel 18 | from .qwen2 import FastQwen2Model 19 | from .qwen3 import FastQwen3Model 20 | from .qwen3_moe import FastQwen3MoeModel 21 | from .granite import FastGraniteModel 22 | from .dpo import PatchDPOTrainer, PatchKTOTrainer 23 | from ._utils import is_bfloat16_supported, is_vLLM_available, __version__ 24 | from .rl import PatchFastRL, vLLMSamplingParams -------------------------------------------------------------------------------- /unsloth/models/dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __all__ = [ 16 | "PatchDPOTrainer", 17 | "PatchKTOTrainer", 18 | ] 19 | 20 | def PatchDPOTrainer(): return 21 | 22 | def PatchKTOTrainer(): return 23 | -------------------------------------------------------------------------------- /unsloth/models/llama4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # from unsloth_studio.models import patch_llama4 16 | # patch_llama4() 17 | -------------------------------------------------------------------------------- /unsloth/models/loader_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit 16 | # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 17 | from packaging.version import Version 18 | from transformers import __version__ as transformers_version 19 | transformers_version = Version(transformers_version) 20 | SUPPORTS_FOURBIT = transformers_version >= Version("4.37") 21 | 22 | BAD_MAPPINGS = \ 23 | { 24 | "unsloth/Qwen3-32B-unsloth-bnb-4bit".lower() : "unsloth/Qwen3-32B-bnb-4bit".lower(), # 32B dynamic quant is way too big 25 | "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower() : "unsloth/Qwen3-30B-A3B".lower(), # HF loads MoEs too slowly 26 | "unsloth/Qwen3-30B-A3B-bnb-4bit".lower() : "unsloth/Qwen3-30B-A3B".lower(), # We rather do it on the fly 27 | "unsloth/Qwen3-30B-A3B-Base-unsloth-bnb-4bit".lower() : "unsloth/Qwen3-30B-A3B-Base".lower(), # HF loads MoEs too slowly 28 | "unsloth/Qwen3-30B-A3B-Base-bnb-4bit".lower() : "unsloth/Qwen3-30B-A3B-Base".lower(), # We rather do it on the fly 29 | } 30 | 31 | def __get_model_name( 32 | model_name, 33 | load_in_4bit = True, 34 | INT_TO_FLOAT_MAPPER = None, 35 | FLOAT_TO_INT_MAPPER = None, 36 | MAP_TO_UNSLOTH_16bit = None, 37 | ): 38 | model_name = str(model_name) 39 | lower_model_name = model_name.lower() 40 | 41 | if not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER: 42 | 43 | model_name = INT_TO_FLOAT_MAPPER[lower_model_name] 44 | print( 45 | f"Unsloth: Your transformers version of {transformers_version} does not support native "\ 46 | f"4bit loading.\nThe minimum required version is 4.37.\n"\ 47 | f'Try `pip install --upgrade "transformers>=4.37"`\n'\ 48 | f"to obtain the latest transformers build, then restart this session.\n"\ 49 | f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)." 50 | ) 51 | return model_name 52 | 53 | elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER: 54 | 55 | new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name] 56 | # logger.warning_once( 57 | # f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\ 58 | # f"`load_in_4bit = False`. We shall load `{new_model_name}` instead." 
59 | # ) 60 | return new_model_name 61 | 62 | elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit: 63 | 64 | new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name] 65 | return new_model_name 66 | 67 | elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER: 68 | 69 | # Support returning original full -bnb-4bit name if specified specifically 70 | # since we'll map it to the dynamic version instead 71 | if lower_model_name.endswith("-bnb-4bit"): 72 | return lower_model_name 73 | 74 | new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name] 75 | # logger.warning_once( 76 | # f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\ 77 | # f"We shall load `{new_model_name}` for 4x faster loading." 78 | # ) 79 | return new_model_name 80 | pass 81 | 82 | return None 83 | pass 84 | 85 | 86 | def _get_new_mapper(): 87 | try: 88 | import requests 89 | new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py" 90 | with requests.get(new_mapper, timeout = 3) as new_mapper: new_mapper = new_mapper.text 91 | new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER"):] 92 | new_mapper = new_mapper\ 93 | .replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")\ 94 | .replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")\ 95 | .replace("MAP_TO_UNSLOTH_16bit", "NEW_MAP_TO_UNSLOTH_16bit") 96 | 97 | exec(new_mapper, globals()) 98 | return NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit 99 | except: 100 | return {}, {}, {} 101 | pass 102 | pass 103 | 104 | 105 | def get_model_name(model_name, load_in_4bit = True): 106 | new_model_name = __get_model_name( 107 | model_name = model_name, 108 | load_in_4bit = load_in_4bit, 109 | INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER, 110 | FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER, 111 | MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit, 112 | ) 113 | # In the rare case, we convert bad model names to other names 114 | # For eg too large dynamic quants or MoEs 115 | if new_model_name is not None and type(new_model_name) is str and \ 116 | new_model_name.lower() in BAD_MAPPINGS: 117 | new_model_name = BAD_MAPPINGS[new_model_name.lower()] 118 | 119 | if new_model_name is None and model_name.count("/") == 1 and model_name[0].isalnum(): 120 | # Try checking if a new Unsloth version allows it! 121 | NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = _get_new_mapper() 122 | upgraded_model_name = __get_model_name( 123 | model_name = model_name, 124 | load_in_4bit = load_in_4bit, 125 | INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER, 126 | FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER, 127 | MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit, 128 | ) 129 | if upgraded_model_name is not None: 130 | raise NotImplementedError( 131 | f"Unsloth: {model_name} is not supported in your current Unsloth version! 
Please update Unsloth via:\n\n"\ 132 | 'pip uninstall unsloth unsloth_zoo -y\n'\ 133 | 'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'\ 134 | 'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'\ 135 | ) 136 | pass 137 | pass 138 | return new_model_name if new_model_name is not None else model_name 139 | pass 140 | -------------------------------------------------------------------------------- /unsloth/models/qwen2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .llama import * 16 | from .llama import ( 17 | LlamaRotaryEmbedding, 18 | LlamaLinearScalingRotaryEmbedding, 19 | ) 20 | from transformers.models.qwen2.modeling_qwen2 import ( 21 | Qwen2Attention, 22 | Qwen2DecoderLayer, 23 | Qwen2Model, 24 | Qwen2ForCausalLM, 25 | ) 26 | # For Pytorch 2.1.1 27 | try: 28 | from transformers.models.qwen2.modeling_qwen2 import ( 29 | Qwen2SdpaAttention, 30 | Qwen2FlashAttention2, 31 | ) 32 | except: 33 | Qwen2SdpaAttention = Qwen2Attention 34 | Qwen2FlashAttention2 = Qwen2Attention 35 | pass 36 | 37 | 38 | class FastQwen2Model(FastLlamaModel): 39 | 40 | @staticmethod 41 | def pre_patch(): 42 | init_name, function = patch_linear_scaling( 43 | model_name = "qwen2", 44 | rope_module = LlamaRotaryEmbedding, 45 | scaled_rope_module = LlamaLinearScalingRotaryEmbedding, 46 | attention_module = Qwen2Attention, 47 | ) 48 | if init_name is not None: 49 | exec(function, globals()) 50 | Qwen2Attention.__init__ = eval(init_name) 51 | pass 52 | Qwen2Attention .forward = LlamaAttention_fast_forward 53 | Qwen2SdpaAttention .forward = LlamaAttention_fast_forward 54 | Qwen2FlashAttention2.forward = LlamaAttention_fast_forward 55 | Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward 56 | Qwen2Model .forward = LlamaModel_fast_forward 57 | Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) 58 | PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward 59 | fix_prepare_inputs_for_generation(Qwen2ForCausalLM) 60 | 61 | # Solves https://github.com/unslothai/unsloth/issues/168 62 | # Static KV Cache was introduced in 4.38.0, causing training to be much slower. 63 | # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. 
64 | # https://github.com/huggingface/transformers/pull/27931 65 | # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py 66 | import transformers.models.qwen2.modeling_qwen2 67 | transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding 68 | return 69 | pass 70 | 71 | 72 | @staticmethod 73 | def from_pretrained( 74 | model_name = "Qwen/Qwen2-7B", 75 | max_seq_length = 4096, 76 | dtype = None, 77 | load_in_4bit = True, 78 | token = None, 79 | device_map = "sequential", 80 | rope_scaling = None, # Qwen2 does not support RoPE scaling 81 | fix_tokenizer = True, 82 | model_patcher = None, 83 | tokenizer_name = None, 84 | trust_remote_code = False, 85 | **kwargs, 86 | ): 87 | return FastLlamaModel.from_pretrained( 88 | model_name = model_name, 89 | max_seq_length = max_seq_length, 90 | dtype = dtype, 91 | load_in_4bit = load_in_4bit, 92 | token = token, 93 | device_map = device_map, 94 | rope_scaling = rope_scaling, 95 | fix_tokenizer = fix_tokenizer, 96 | model_patcher = FastQwen2Model, 97 | tokenizer_name = tokenizer_name, 98 | trust_remote_code = trust_remote_code, 99 | **kwargs, 100 | ) 101 | pass 102 | pass 103 | -------------------------------------------------------------------------------- /unsloth/registry/REGISTRY.md: -------------------------------------------------------------------------------- 1 | ## Model Registry 2 | 3 | ### Structure 4 | ``` 5 | unsloth 6 | -registry 7 | __init__.py 8 | registry.py 9 | _llama.py 10 | _mistral.py 11 | _phi.py 12 | ... 13 | ``` 14 | 15 | Each model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`). 16 | 17 | Within each model registration file, a high-level `ModelMeta` is created for each model version, with the following structure: 18 | ```python 19 | @dataclass 20 | class ModelMeta: 21 | org: str 22 | base_name: str 23 | model_version: str 24 | model_info_cls: type[ModelInfo] 25 | model_sizes: list[str] = field(default_factory=list) 26 | instruct_tags: list[str] = field(default_factory=list) 27 | quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list) 28 | is_multimodal: bool = False 29 | ``` 30 | 31 | Each model then instantiates a global `ModelMeta` for its specific model version, defining how the model path (e.g. `unsloth/Llama-3.1-8B-Instruct`) is constructed since each model type has a different naming convention. 32 | ```python 33 | LlamaMeta_3_1 = ModelMeta( 34 | org="meta-llama", 35 | base_name="Llama", 36 | instruct_tags=[None, "Instruct"], 37 | model_version="3.1", 38 | model_sizes=["8"], 39 | model_info_cls=LlamaModelInfo, 40 | is_multimodal=False, 41 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 42 | ) 43 | ``` 44 | 45 | `LlamaModelInfo` is a subclass of `ModelInfo` that defines the model path for each model size and quant type. 46 | ```python 47 | class LlamaModelInfo(ModelInfo): 48 | @classmethod 49 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 50 | key = f"{base_name}-{version}-{size}B" 51 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 52 | ``` 53 | 54 | Once these constructs are defined, the model is registered by writing a register_xx_models function. 
55 | ```python 56 | def register_llama_3_1_models(include_original_model: bool = False): 57 | global _IS_LLAMA_3_1_REGISTERED 58 | if _IS_LLAMA_3_1_REGISTERED: 59 | return 60 | _register_models(LlamaMeta_3_1, include_original_model=include_original_model) 61 | _IS_LLAMA_3_1_REGISTERED = True 62 | ``` 63 | 64 | `_register_models` is a helper function that registers the model with the registry. The global `_IS_XX_REGISTERED` is used to prevent duplicate registration. 65 | 66 | Once a model is registered, registry.registry.MODEL_REGISTRY is updated with the model info and can be searched with `registry.search_models`. 67 | 68 | ### Tests 69 | 70 | The `tests/test_model_registry.py` file contains tests for the model registry. 71 | 72 | Also, each model registration file is an executable module that checks that all registered models are available on `huggingface_hub`. 73 | ```python 74 | python unsloth.registry._llama.py 75 | ``` 76 | 77 | Prints the following (abridged) output: 78 | ```bash 79 | ✓ unsloth/Llama-3.1-8B 80 | ✓ unsloth/Llama-3.1-8B-bnb-4bit 81 | ✓ unsloth/Llama-3.1-8B-unsloth-bnb-4bit 82 | ✓ meta-llama/Llama-3.1-8B 83 | ✓ unsloth/Llama-3.1-8B-Instruct 84 | ✓ unsloth/Llama-3.1-8B-Instruct-bnb-4bit 85 | ✓ unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit 86 | ✓ meta-llama/Llama-3.1-8B-Instruct 87 | ✓ unsloth/Llama-3.2-1B 88 | ✓ unsloth/Llama-3.2-1B-bnb-4bit 89 | ✓ unsloth/Llama-3.2-1B-unsloth-bnb-4bit 90 | ✓ meta-llama/Llama-3.2-1B 91 | ... 92 | ``` 93 | 94 | ### TODO 95 | - Model Collections 96 | - [x] Gemma3 97 | - [ ] Llama3.1 98 | - [x] Llama3.2 99 | - [x] MistralSmall 100 | - [x] Qwen2.5 101 | - [x] Qwen2.5-VL 102 | - [ ] Qwen2.5 Coder 103 | - [x] QwenQwQ-32B 104 | - [x] Deepseek v3 105 | - [x] Deepseek R1 106 | - [x] Phi-4 107 | - [ ] Unsloth 4-bit Dynamic Quants 108 | - [ ] Vision/multimodal models 109 | - Sync model uploads with registry 110 | - Add utility methods for tracking model stats -------------------------------------------------------------------------------- /unsloth/registry/__init__.py: -------------------------------------------------------------------------------- 1 | from ._deepseek import register_deepseek_models as _register_deepseek_models 2 | from ._gemma import register_gemma_models as _register_gemma_models 3 | from ._llama import register_llama_models as _register_llama_models 4 | from ._mistral import register_mistral_models as _register_mistral_models 5 | from ._phi import register_phi_models as _register_phi_models 6 | from ._qwen import register_qwen_models as _register_qwen_models 7 | from .registry import MODEL_REGISTRY, ModelInfo, QuantType 8 | 9 | _ARE_MODELS_REGISTERED = False 10 | 11 | def register_models(): 12 | global _ARE_MODELS_REGISTERED 13 | 14 | if _ARE_MODELS_REGISTERED: 15 | return 16 | _register_deepseek_models() 17 | _register_gemma_models() 18 | _register_llama_models() 19 | _register_mistral_models() 20 | _register_phi_models() 21 | _register_qwen_models() 22 | 23 | _ARE_MODELS_REGISTERED = True 24 | 25 | def search_models(org: str = None, base_name: str = None, version: str = None, size: str = None, quant_types: list[QuantType] = None, search_pattern: str = None) -> list[ModelInfo]: 26 | """ 27 | Get model info from the registry. 28 | 29 | See registry.ModelInfo for more fields. 30 | 31 | If search_pattern is provided, the full model path will be matched against the pattern, where the model path is the model_id on huggingface hub. 
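    Example (illustrative; the exact hits depend on which models have been registered):
        >>> infos = search_models(search_pattern="Llama-3.2-1B")
        >>> [info.model_path for info in infos]
        # e.g. ['unsloth/Llama-3.2-1B', 'unsloth/Llama-3.2-1B-bnb-4bit', ...]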
32 | 33 | """ 34 | if not _ARE_MODELS_REGISTERED: 35 | register_models() 36 | 37 | model_infos = MODEL_REGISTRY.values() 38 | if org: 39 | model_infos = [model_info for model_info in model_infos if model_info.org == org] 40 | if base_name: 41 | model_infos = [model_info for model_info in model_infos if model_info.base_name == base_name] 42 | if version: 43 | model_infos = [model_info for model_info in model_infos if model_info.version == version] 44 | if size: 45 | model_infos = [model_info for model_info in model_infos if model_info.size == size] 46 | if quant_types: 47 | model_infos = [model_info for model_info in model_infos if any(model_info.quant_type == quant_type for quant_type in quant_types)] 48 | if search_pattern: 49 | model_infos = [model_info for model_info in model_infos if search_pattern in model_info.model_path] 50 | 51 | return model_infos -------------------------------------------------------------------------------- /unsloth/registry/_deepseek.py: -------------------------------------------------------------------------------- 1 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 2 | 3 | _IS_DEEPSEEK_V3_REGISTERED = False 4 | _IS_DEEPSEEK_V3_0324_REGISTERED = False 5 | _IS_DEEPSEEK_R1_REGISTERED = False 6 | _IS_DEEPSEEK_R1_ZERO_REGISTERED = False 7 | _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = False 8 | _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = False 9 | 10 | class DeepseekV3ModelInfo(ModelInfo): 11 | @classmethod 12 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 13 | key = f"{base_name}-V{version}" 14 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 15 | 16 | class DeepseekR1ModelInfo(ModelInfo): 17 | @classmethod 18 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 19 | key = f"{base_name}-{version}" if version else base_name 20 | if size: 21 | key = f"{key}-{size}B" 22 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 23 | 24 | # Deepseek V3 Model Meta 25 | DeepseekV3Meta = ModelMeta( 26 | org="deepseek-ai", 27 | base_name="DeepSeek", 28 | instruct_tags=[None], 29 | model_version="3", 30 | model_sizes=[""], 31 | model_info_cls=DeepseekV3ModelInfo, 32 | is_multimodal=False, 33 | quant_types=[QuantType.NONE, QuantType.BF16], 34 | ) 35 | 36 | DeepseekV3_0324Meta = ModelMeta( 37 | org="deepseek-ai", 38 | base_name="DeepSeek", 39 | instruct_tags=[None], 40 | model_version="3-0324", 41 | model_sizes=[""], 42 | model_info_cls=DeepseekV3ModelInfo, 43 | is_multimodal=False, 44 | quant_types=[QuantType.NONE, QuantType.GGUF], 45 | ) 46 | 47 | DeepseekR1Meta = ModelMeta( 48 | org="deepseek-ai", 49 | base_name="DeepSeek-R1", 50 | instruct_tags=[None], 51 | model_version="", 52 | model_sizes=[""], 53 | model_info_cls=DeepseekR1ModelInfo, 54 | is_multimodal=False, 55 | quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF], 56 | ) 57 | 58 | DeepseekR1ZeroMeta = ModelMeta( 59 | org="deepseek-ai", 60 | base_name="DeepSeek-R1", 61 | instruct_tags=[None], 62 | model_version="Zero", 63 | model_sizes=[""], 64 | model_info_cls=DeepseekR1ModelInfo, 65 | is_multimodal=False, 66 | quant_types=[QuantType.NONE, QuantType.GGUF], 67 | ) 68 | 69 | DeepseekR1DistillLlamaMeta = ModelMeta( 70 | org="deepseek-ai", 71 | base_name="DeepSeek-R1-Distill", 72 | instruct_tags=[None], 73 | model_version="Llama", 74 | model_sizes=["8", "70"], 75 | model_info_cls=DeepseekR1ModelInfo, 76 | 
is_multimodal=False, 77 | quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]}, 78 | ) 79 | 80 | # Deepseek R1 Distill Qwen Model Meta 81 | DeepseekR1DistillQwenMeta = ModelMeta( 82 | org="deepseek-ai", 83 | base_name="DeepSeek-R1-Distill", 84 | instruct_tags=[None], 85 | model_version="Qwen", 86 | model_sizes=["1.5", "7", "14", "32"], 87 | model_info_cls=DeepseekR1ModelInfo, 88 | is_multimodal=False, 89 | quant_types={ 90 | "1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], 91 | "7": [QuantType.UNSLOTH, QuantType.BNB], 92 | "14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], 93 | "32": [QuantType.GGUF, QuantType.BNB], 94 | }, 95 | ) 96 | 97 | def register_deepseek_v3_models(include_original_model: bool = False): 98 | global _IS_DEEPSEEK_V3_REGISTERED 99 | if _IS_DEEPSEEK_V3_REGISTERED: 100 | return 101 | _register_models(DeepseekV3Meta, include_original_model=include_original_model) 102 | _IS_DEEPSEEK_V3_REGISTERED = True 103 | 104 | def register_deepseek_v3_0324_models(include_original_model: bool = False): 105 | global _IS_DEEPSEEK_V3_0324_REGISTERED 106 | if _IS_DEEPSEEK_V3_0324_REGISTERED: 107 | return 108 | _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) 109 | _IS_DEEPSEEK_V3_0324_REGISTERED = True 110 | 111 | def register_deepseek_r1_models(include_original_model: bool = False): 112 | global _IS_DEEPSEEK_R1_REGISTERED 113 | if _IS_DEEPSEEK_R1_REGISTERED: 114 | return 115 | _register_models(DeepseekR1Meta, include_original_model=include_original_model) 116 | _IS_DEEPSEEK_R1_REGISTERED = True 117 | 118 | def register_deepseek_r1_zero_models(include_original_model: bool = False): 119 | global _IS_DEEPSEEK_R1_ZERO_REGISTERED 120 | if _IS_DEEPSEEK_R1_ZERO_REGISTERED: 121 | return 122 | _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) 123 | _IS_DEEPSEEK_R1_ZERO_REGISTERED = True 124 | 125 | def register_deepseek_r1_distill_llama_models(include_original_model: bool = False): 126 | global _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED 127 | if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED: 128 | return 129 | _register_models(DeepseekR1DistillLlamaMeta, include_original_model=include_original_model) 130 | _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True 131 | 132 | def register_deepseek_r1_distill_qwen_models(include_original_model: bool = False): 133 | global _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED 134 | if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED: 135 | return 136 | _register_models(DeepseekR1DistillQwenMeta, include_original_model=include_original_model) 137 | _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True 138 | 139 | def register_deepseek_models(include_original_model: bool = False): 140 | register_deepseek_v3_models(include_original_model=include_original_model) 141 | register_deepseek_v3_0324_models(include_original_model=include_original_model) 142 | register_deepseek_r1_models(include_original_model=include_original_model) 143 | register_deepseek_r1_zero_models(include_original_model=include_original_model) 144 | register_deepseek_r1_distill_llama_models(include_original_model=include_original_model) 145 | register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model) 146 | 147 | def _list_deepseek_r1_distill_models(): 148 | from unsloth.utils.hf_hub import ModelInfo as HfModelInfo 149 | from unsloth.utils.hf_hub import list_models 150 | models: list[HfModelInfo] = list_models(author="unsloth", search="Distill", limit=1000) 151 | distill_models = [] 152 | for 
model in models: 153 | model_id = model.id 154 | model_name = model_id.split("/")[-1] 155 | # parse out only the version 156 | version = model_name.removeprefix("DeepSeek-R1-Distill-") 157 | distill_models.append(version) 158 | 159 | return distill_models 160 | 161 | 162 | register_deepseek_models(include_original_model=True) 163 | 164 | if __name__ == "__main__": 165 | from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info 166 | MODEL_REGISTRY.clear() 167 | 168 | register_deepseek_models(include_original_model=True) 169 | 170 | for model_id, model_info in MODEL_REGISTRY.items(): 171 | model_info = _check_model_info(model_id) 172 | if model_info is None: 173 | print(f"\u2718 {model_id}") 174 | else: 175 | print(f"\u2713 {model_id}") 176 | # distill_models = _list_deepseek_r1_distill_models() 177 | # for model in sorted(distill_models): 178 | # if "qwen" in model.lower(): 179 | # print(model) -------------------------------------------------------------------------------- /unsloth/registry/_gemma.py: -------------------------------------------------------------------------------- 1 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 2 | 3 | _IS_GEMMA_3_BASE_REGISTERED = False 4 | _IS_GEMMA_3_INSTRUCT_REGISTERED = False 5 | 6 | class GemmaModelInfo(ModelInfo): 7 | @classmethod 8 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 9 | key = f"{base_name}-{version}-{size}B" 10 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 11 | 12 | # Gemma3 Base Model Meta 13 | GemmaMeta3Base = ModelMeta( 14 | org="google", 15 | base_name="gemma", 16 | instruct_tags=["pt"], # pt = base 17 | model_version="3", 18 | model_sizes=["1", "4", "12", "27"], 19 | model_info_cls=GemmaModelInfo, 20 | is_multimodal=True, 21 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 22 | ) 23 | 24 | # Gemma3 Instruct Model Meta 25 | GemmaMeta3Instruct = ModelMeta( 26 | org="google", 27 | base_name="gemma", 28 | instruct_tags=["it"], # it = instruction tuned 29 | model_version="3", 30 | model_sizes=["1", "4", "12", "27"], 31 | model_info_cls=GemmaModelInfo, 32 | is_multimodal=True, 33 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], 34 | ) 35 | 36 | def register_gemma_3_base_models(include_original_model: bool = False): 37 | global _IS_GEMMA_3_BASE_REGISTERED 38 | if _IS_GEMMA_3_BASE_REGISTERED: 39 | return 40 | _register_models(GemmaMeta3Base, include_original_model=include_original_model) 41 | _IS_GEMMA_3_BASE_REGISTERED = True 42 | 43 | def register_gemma_3_instruct_models(include_original_model: bool = False): 44 | global _IS_GEMMA_3_INSTRUCT_REGISTERED 45 | if _IS_GEMMA_3_INSTRUCT_REGISTERED: 46 | return 47 | _register_models(GemmaMeta3Instruct, include_original_model=include_original_model) 48 | _IS_GEMMA_3_INSTRUCT_REGISTERED = True 49 | 50 | def register_gemma_models(include_original_model: bool = False): 51 | register_gemma_3_base_models(include_original_model=include_original_model) 52 | register_gemma_3_instruct_models(include_original_model=include_original_model) 53 | 54 | 55 | if __name__ == "__main__": 56 | from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info 57 | MODEL_REGISTRY.clear() 58 | 59 | register_gemma_models(include_original_model=True) 60 | 61 | for model_id, model_info in MODEL_REGISTRY.items(): 62 | model_info = _check_model_info(model_id) 63 | if model_info is None: 64 | print(f"\u2718 {model_id}") 65 | 
else: 66 | print(f"\u2713 {model_id}") 67 | -------------------------------------------------------------------------------- /unsloth/registry/_llama.py: -------------------------------------------------------------------------------- 1 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 2 | 3 | _IS_LLAMA_3_1_REGISTERED = False 4 | _IS_LLAMA_3_2_REGISTERED = False 5 | _IS_LLAMA_3_2_VISION_REGISTERED = False 6 | 7 | 8 | class LlamaModelInfo(ModelInfo): 9 | @classmethod 10 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 11 | key = f"{base_name}-{version}-{size}B" 12 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 13 | 14 | 15 | class LlamaVisionModelInfo(ModelInfo): 16 | @classmethod 17 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 18 | key = f"{base_name}-{version}-{size}B-Vision" 19 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 20 | 21 | 22 | # Llama 3.1 23 | LlamaMeta_3_1 = ModelMeta( 24 | org="meta-llama", 25 | base_name="Llama", 26 | instruct_tags=[None, "Instruct"], 27 | model_version="3.1", 28 | model_sizes=["8"], 29 | model_info_cls=LlamaModelInfo, 30 | is_multimodal=False, 31 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 32 | ) 33 | 34 | # Llama 3.2 Base Models 35 | LlamaMeta_3_2_Base = ModelMeta( 36 | org="meta-llama", 37 | base_name="Llama", 38 | instruct_tags=[None], 39 | model_version="3.2", 40 | model_sizes=["1", "3"], 41 | model_info_cls=LlamaModelInfo, 42 | is_multimodal=False, 43 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 44 | ) 45 | 46 | # Llama 3.2 Instruction Tuned Models 47 | LlamaMeta_3_2_Instruct = ModelMeta( 48 | org="meta-llama", 49 | base_name="Llama", 50 | instruct_tags=["Instruct"], 51 | model_version="3.2", 52 | model_sizes=["1", "3"], 53 | model_info_cls=LlamaModelInfo, 54 | is_multimodal=False, 55 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], 56 | ) 57 | 58 | # Llama 3.2 Vision 59 | LlamaMeta_3_2_Vision = ModelMeta( 60 | org="meta-llama", 61 | base_name="Llama", 62 | instruct_tags=[None, "Instruct"], 63 | model_version="3.2", 64 | model_sizes=["11", "90"], 65 | model_info_cls=LlamaVisionModelInfo, 66 | is_multimodal=True, 67 | quant_types={ 68 | "11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 69 | "90": [QuantType.NONE], 70 | }, 71 | ) 72 | 73 | 74 | def register_llama_3_1_models(include_original_model: bool = False): 75 | global _IS_LLAMA_3_1_REGISTERED 76 | if _IS_LLAMA_3_1_REGISTERED: 77 | return 78 | _register_models(LlamaMeta_3_1, include_original_model=include_original_model) 79 | _IS_LLAMA_3_1_REGISTERED = True 80 | 81 | def register_llama_3_2_models(include_original_model: bool = False): 82 | global _IS_LLAMA_3_2_REGISTERED 83 | if _IS_LLAMA_3_2_REGISTERED: 84 | return 85 | _register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model) 86 | _register_models(LlamaMeta_3_2_Instruct, include_original_model=include_original_model) 87 | _IS_LLAMA_3_2_REGISTERED = True 88 | 89 | def register_llama_3_2_vision_models(include_original_model: bool = False): 90 | global _IS_LLAMA_3_2_VISION_REGISTERED 91 | if _IS_LLAMA_3_2_VISION_REGISTERED: 92 | return 93 | _register_models(LlamaMeta_3_2_Vision, include_original_model=include_original_model) 94 | _IS_LLAMA_3_2_VISION_REGISTERED = True 95 | 96 | 97 | def register_llama_models(include_original_model: 
bool = False): 98 | register_llama_3_1_models(include_original_model=include_original_model) 99 | register_llama_3_2_models(include_original_model=include_original_model) 100 | register_llama_3_2_vision_models(include_original_model=include_original_model) 101 | 102 | if __name__ == "__main__": 103 | from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info 104 | MODEL_REGISTRY.clear() 105 | 106 | register_llama_models(include_original_model=True) 107 | 108 | for model_id, model_info in MODEL_REGISTRY.items(): 109 | model_info = _check_model_info(model_id) 110 | if model_info is None: 111 | print(f"\u2718 {model_id}") 112 | else: 113 | print(f"\u2713 {model_id}") 114 | -------------------------------------------------------------------------------- /unsloth/registry/_mistral.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 4 | 5 | _IS_MISTRAL_SMALL_REGISTERED = False 6 | 7 | _MISTRAL_SMALL_03_25_VERSION = "2503" 8 | _MISTRAL_SMALL_01_25_VERSION = "2501" 9 | _MISTRAL_SMALL_09_24_VERSION = "2409" # Not uploaded to unsloth 10 | 11 | class MistralSmallModelInfo(ModelInfo): 12 | @classmethod 13 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 14 | if version == _MISTRAL_SMALL_03_25_VERSION: 15 | key = f"{base_name}-3.1-{size}B-{instruct_tag}" 16 | else: 17 | key = f"{base_name}-{size}B-{instruct_tag}" 18 | key += f"-{version}" 19 | key = cls.append_quant_type(key, quant_type) 20 | 21 | return key 22 | 23 | 24 | MistralSmall_2503_Base_Meta = ModelMeta( 25 | org="mistralai", 26 | base_name="Mistral-Small", 27 | instruct_tags=["Base"], 28 | model_version=_MISTRAL_SMALL_03_25_VERSION, 29 | model_sizes=["24"], 30 | model_info_cls=MistralSmallModelInfo, 31 | is_multimodal=False, 32 | quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB], 33 | ) 34 | 35 | MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta) 36 | MistralSmall_2503_Instruct_Meta.instruct_tags = ["Instruct"] 37 | MistralSmall_2503_Instruct_Meta.quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF] 38 | 39 | MistralSmall_2501_Base_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta) 40 | MistralSmall_2501_Base_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION 41 | 42 | MistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta) 43 | MistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION 44 | 45 | def register_mistral_small_models(include_original_model: bool = False): 46 | global _IS_MISTRAL_SMALL_REGISTERED 47 | if _IS_MISTRAL_SMALL_REGISTERED: 48 | return 49 | _register_models(MistralSmall_2503_Base_Meta, include_original_model=include_original_model) 50 | _register_models(MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model) 51 | _register_models(MistralSmall_2501_Base_Meta, include_original_model=include_original_model) 52 | _register_models(MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model) 53 | 54 | _IS_MISTRAL_SMALL_REGISTERED = True 55 | 56 | def register_mistral_models(include_original_model: bool = False): 57 | register_mistral_small_models(include_original_model=include_original_model) 58 | 59 | if __name__ == "__main__": 60 | from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info 61 | MODEL_REGISTRY.clear() 62 | 63 | 
register_mistral_models(include_original_model=True) 64 | 65 | for model_id, model_info in MODEL_REGISTRY.items(): 66 | model_info = _check_model_info(model_id) 67 | if model_info is None: 68 | print(f"\u2718 {model_id}") 69 | else: 70 | print(f"\u2713 {model_id}") -------------------------------------------------------------------------------- /unsloth/registry/_phi.py: -------------------------------------------------------------------------------- 1 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 2 | 3 | _IS_PHI_4_REGISTERED = False 4 | _IS_PHI_4_INSTRUCT_REGISTERED = False 5 | 6 | class PhiModelInfo(ModelInfo): 7 | @classmethod 8 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 9 | key = f"{base_name}-{version}" 10 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 11 | 12 | # Phi Model Meta 13 | PhiMeta4 = ModelMeta( 14 | org="microsoft", 15 | base_name="phi", 16 | instruct_tags=[None], 17 | model_version="4", 18 | model_sizes=["1"], # Assuming only one size 19 | model_info_cls=PhiModelInfo, 20 | is_multimodal=False, 21 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 22 | ) 23 | 24 | # Phi Instruct Model Meta 25 | PhiInstructMeta4 = ModelMeta( 26 | org="microsoft", 27 | base_name="phi", 28 | instruct_tags=["mini-instruct"], 29 | model_version="4", 30 | model_sizes=["1"], # Assuming only one size 31 | model_info_cls=PhiModelInfo, 32 | is_multimodal=False, 33 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], 34 | ) 35 | 36 | def register_phi_4_models(include_original_model: bool = False): 37 | global _IS_PHI_4_REGISTERED 38 | if _IS_PHI_4_REGISTERED: 39 | return 40 | _register_models(PhiMeta4, include_original_model=include_original_model) 41 | _IS_PHI_4_REGISTERED = True 42 | 43 | def register_phi_4_instruct_models(include_original_model: bool = False): 44 | global _IS_PHI_4_INSTRUCT_REGISTERED 45 | if _IS_PHI_4_INSTRUCT_REGISTERED: 46 | return 47 | _register_models(PhiInstructMeta4, include_original_model=include_original_model) 48 | _IS_PHI_4_INSTRUCT_REGISTERED = True 49 | 50 | def register_phi_models(include_original_model: bool = False): 51 | register_phi_4_models(include_original_model=include_original_model) 52 | register_phi_4_instruct_models(include_original_model=include_original_model) 53 | 54 | if __name__ == "__main__": 55 | from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info 56 | MODEL_REGISTRY.clear() 57 | 58 | register_phi_models(include_original_model=True) 59 | 60 | for model_id, model_info in MODEL_REGISTRY.items(): 61 | model_info = _check_model_info(model_id) 62 | if model_info is None: 63 | print(f"\u2718 {model_id}") 64 | else: 65 | print(f"\u2713 {model_id}") -------------------------------------------------------------------------------- /unsloth/registry/_qwen.py: -------------------------------------------------------------------------------- 1 | from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models 2 | 3 | _IS_QWEN_2_5_REGISTERED = False 4 | _IS_QWEN_2_5_VL_REGISTERED = False 5 | _IS_QWEN_QWQ_REGISTERED = False 6 | class QwenModelInfo(ModelInfo): 7 | @classmethod 8 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 9 | key = f"{base_name}{version}-{size}B" 10 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 11 | 12 | 13 | class QwenVLModelInfo(ModelInfo): 14 
| @classmethod 15 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 16 | key = f"{base_name}{version}-VL-{size}B" 17 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 18 | 19 | class QwenQwQModelInfo(ModelInfo): 20 | @classmethod 21 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 22 | key = f"{base_name}-{size}B" 23 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 24 | 25 | class QwenQVQPreviewModelInfo(ModelInfo): 26 | @classmethod 27 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): 28 | key = f"{base_name}-{size}B-Preview" 29 | return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) 30 | 31 | # Qwen2.5 Model Meta 32 | Qwen_2_5_Meta = ModelMeta( 33 | org="Qwen", 34 | base_name="Qwen", 35 | instruct_tags=[None, "Instruct"], 36 | model_version="2.5", 37 | model_sizes=["3", "7"], 38 | model_info_cls=QwenModelInfo, 39 | is_multimodal=False, 40 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 41 | ) 42 | 43 | # Qwen2.5 VL Model Meta 44 | Qwen_2_5_VLMeta = ModelMeta( 45 | org="Qwen", 46 | base_name="Qwen", 47 | instruct_tags=["Instruct"], # No base, only instruction tuned 48 | model_version="2.5", 49 | model_sizes=["3", "7", "32", "72"], 50 | model_info_cls=QwenVLModelInfo, 51 | is_multimodal=True, 52 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], 53 | ) 54 | 55 | # Qwen QwQ Model Meta 56 | QwenQwQMeta = ModelMeta( 57 | org="Qwen", 58 | base_name="QwQ", 59 | instruct_tags=[None], 60 | model_version="", 61 | model_sizes=["32"], 62 | model_info_cls=QwenQwQModelInfo, 63 | is_multimodal=False, 64 | quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], 65 | ) 66 | 67 | # Qwen QVQ Preview Model Meta 68 | QwenQVQPreviewMeta = ModelMeta( 69 | org="Qwen", 70 | base_name="QVQ", 71 | instruct_tags=[None], 72 | model_version="", 73 | model_sizes=["72"], 74 | model_info_cls=QwenQVQPreviewModelInfo, 75 | is_multimodal=True, 76 | quant_types=[QuantType.NONE, QuantType.BNB], 77 | ) 78 | 79 | def register_qwen_2_5_models(include_original_model: bool = False): 80 | global _IS_QWEN_2_5_REGISTERED 81 | if _IS_QWEN_2_5_REGISTERED: 82 | return 83 | _register_models(Qwen_2_5_Meta, include_original_model=include_original_model) 84 | _IS_QWEN_2_5_REGISTERED = True 85 | 86 | def register_qwen_2_5_vl_models(include_original_model: bool = False): 87 | global _IS_QWEN_2_5_VL_REGISTERED 88 | if _IS_QWEN_2_5_VL_REGISTERED: 89 | return 90 | _register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model) 91 | _IS_QWEN_2_5_VL_REGISTERED = True 92 | 93 | def register_qwen_qwq_models(include_original_model: bool = False): 94 | global _IS_QWEN_QWQ_REGISTERED 95 | if _IS_QWEN_QWQ_REGISTERED: 96 | return 97 | _register_models(QwenQwQMeta, include_original_model=include_original_model) 98 | _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model) 99 | _IS_QWEN_QWQ_REGISTERED = True 100 | 101 | def register_qwen_models(include_original_model: bool = False): 102 | register_qwen_2_5_models(include_original_model=include_original_model) 103 | register_qwen_2_5_vl_models(include_original_model=include_original_model) 104 | register_qwen_qwq_models(include_original_model=include_original_model) 105 | 106 | if __name__ == "__main__": 107 | from unsloth.registry.registry import MODEL_REGISTRY, 
_check_model_info 108 | MODEL_REGISTRY.clear() 109 | 110 | register_qwen_models(include_original_model=True) 111 | 112 | for model_id, model_info in MODEL_REGISTRY.items(): 113 | model_info = _check_model_info(model_id) 114 | if model_info is None: 115 | print(f"\u2718 {model_id}") 116 | else: 117 | print(f"\u2713 {model_id}") 118 | -------------------------------------------------------------------------------- /unsloth/registry/registry.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, field 3 | from enum import Enum 4 | 5 | 6 | class QuantType(Enum): 7 | BNB = "bnb" 8 | UNSLOTH = "unsloth" # dynamic 4-bit quantization 9 | GGUF = "GGUF" 10 | NONE = "none" 11 | BF16 = "bf16" # only for Deepseek V3 12 | 13 | # Tags for Hugging Face model paths 14 | BNB_QUANTIZED_TAG = "bnb-4bit" 15 | UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG 16 | GGUF_TAG = "GGUF" 17 | BF16_TAG = "bf16" 18 | 19 | QUANT_TAG_MAP = { 20 | QuantType.BNB: BNB_QUANTIZED_TAG, 21 | QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG, 22 | QuantType.GGUF: GGUF_TAG, 23 | QuantType.NONE: None, 24 | QuantType.BF16: BF16_TAG, 25 | } 26 | 27 | # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH 28 | @dataclass 29 | class ModelInfo: 30 | org: str 31 | base_name: str 32 | version: str 33 | size: int 34 | name: str = None # full model name, constructed from base_name, version, and size unless provided 35 | is_multimodal: bool = False 36 | instruct_tag: str = None 37 | quant_type: QuantType = None 38 | description: str = None 39 | 40 | def __post_init__(self): 41 | self.name = self.name or self.construct_model_name( 42 | self.base_name, 43 | self.version, 44 | self.size, 45 | self.quant_type, 46 | self.instruct_tag, 47 | ) 48 | 49 | @staticmethod 50 | def append_instruct_tag(key: str, instruct_tag: str = None): 51 | if instruct_tag: 52 | key = "-".join([key, instruct_tag]) 53 | return key 54 | 55 | @staticmethod 56 | def append_quant_type( 57 | key: str, quant_type: QuantType = None 58 | ): 59 | if quant_type != QuantType.NONE: 60 | key = "-".join([key, QUANT_TAG_MAP[quant_type]]) 61 | return key 62 | 63 | @classmethod 64 | def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag, key=""): 65 | key = cls.append_instruct_tag(key, instruct_tag) 66 | key = cls.append_quant_type(key, quant_type) 67 | return key 68 | 69 | @property 70 | def model_path( 71 | self, 72 | ) -> str: 73 | return f"{self.org}/{self.name}" 74 | 75 | 76 | @dataclass 77 | class ModelMeta: 78 | org: str 79 | base_name: str 80 | model_version: str 81 | model_info_cls: type[ModelInfo] 82 | model_sizes: list[str] = field(default_factory=list) 83 | instruct_tags: list[str] = field(default_factory=list) 84 | quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list) 85 | is_multimodal: bool = False 86 | 87 | 88 | MODEL_REGISTRY: dict[str, ModelInfo] = {} 89 | 90 | 91 | def register_model( 92 | model_info_cls: ModelInfo, 93 | org: str, 94 | base_name: str, 95 | version: str, 96 | size: int, 97 | instruct_tag: str = None, 98 | quant_type: QuantType = None, 99 | is_multimodal: bool = False, 100 | name: str = None, 101 | ): 102 | name = name or model_info_cls.construct_model_name( 103 | base_name=base_name, 104 | version=version, 105 | size=size, 106 | quant_type=quant_type, 107 | instruct_tag=instruct_tag, 108 | ) 109 | key = f"{org}/{name}" 110 | 111 | if key in 
MODEL_REGISTRY: 112 | raise ValueError(f"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}") 113 | 114 | MODEL_REGISTRY[key] = model_info_cls( 115 | org=org, 116 | base_name=base_name, 117 | version=version, 118 | size=size, 119 | is_multimodal=is_multimodal, 120 | instruct_tag=instruct_tag, 121 | quant_type=quant_type, 122 | name=name, 123 | ) 124 | 125 | 126 | def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): 127 | from huggingface_hub import HfApi 128 | from huggingface_hub import ModelInfo as HfModelInfo 129 | from huggingface_hub.utils import RepositoryNotFoundError 130 | 131 | api = HfApi() 132 | 133 | try: 134 | model_info: HfModelInfo = api.model_info(model_id, expand=properties) 135 | except Exception as e: 136 | if isinstance(e, RepositoryNotFoundError): 137 | warnings.warn(f"{model_id} not found on Hugging Face") 138 | model_info = None 139 | else: 140 | raise e 141 | return model_info 142 | 143 | 144 | def _register_models(model_meta: ModelMeta, include_original_model: bool = False): 145 | org = model_meta.org 146 | base_name = model_meta.base_name 147 | instruct_tags = model_meta.instruct_tags 148 | model_version = model_meta.model_version 149 | model_sizes = model_meta.model_sizes 150 | is_multimodal = model_meta.is_multimodal 151 | quant_types = model_meta.quant_types 152 | model_info_cls = model_meta.model_info_cls 153 | 154 | for size in model_sizes: 155 | for instruct_tag in instruct_tags: 156 | # Handle quant types per model size 157 | if isinstance(quant_types, dict): 158 | _quant_types = quant_types[size] 159 | else: 160 | _quant_types = quant_types 161 | for quant_type in _quant_types: 162 | # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH 163 | _org = "unsloth" # unsloth models -- these are all quantized versions of the original model 164 | register_model( 165 | model_info_cls=model_info_cls, 166 | org=_org, 167 | base_name=base_name, 168 | version=model_version, 169 | size=size, 170 | instruct_tag=instruct_tag, 171 | quant_type=quant_type, 172 | is_multimodal=is_multimodal, 173 | ) 174 | # include original model from releasing organization 175 | if include_original_model: 176 | register_model( 177 | model_info_cls=model_info_cls, 178 | org=org, 179 | base_name=base_name, 180 | version=model_version, 181 | size=size, 182 | instruct_tag=instruct_tag, 183 | quant_type=QuantType.NONE, 184 | is_multimodal=is_multimodal, 185 | ) 186 | -------------------------------------------------------------------------------- /unsloth/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass, field 17 | from typing import Optional 18 | from functools import wraps 19 | 20 | import trl 21 | import inspect 22 | from trl import SFTTrainer 23 | from . import is_bfloat16_supported 24 | from unsloth_zoo.training_utils import ( 25 | unsloth_train as _unsloth_train, 26 | ) 27 | from unsloth_zoo.vision_utils import ( 28 | UnslothVisionDataCollator, 29 | ) 30 | from packaging.version import Version 31 | import dataclasses 32 | 33 | __all__ = [ 34 | "UnslothTrainingArguments", 35 | "UnslothTrainer", 36 | "unsloth_train", 37 | "_patch_trl_trainer", 38 | "UnslothVisionDataCollator", 39 | ] 40 | 41 | # Unsloth gradient accumulation fix: 42 | from transformers import __version__ as transformers_version 43 | if Version(transformers_version) > Version("4.45.2"): 44 | def unsloth_train(trainer, *args, **kwargs): 45 | return trainer.train(*args, **kwargs) 46 | pass 47 | else: 48 | def unsloth_train(trainer, *args, **kwargs): 49 | if len(args) != 0 or len(kwargs) != 0: 50 | raise RuntimeError( 51 | "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\ 52 | "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\ 53 | '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`' 54 | ) 55 | print( 56 | "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\ 57 | "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\ 58 | '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`' 59 | ) 60 | return _unsloth_train(trainer) 61 | pass 62 | pass 63 | 64 | try: 65 | from trl import SFTConfig as TrainingArguments 66 | except: 67 | from transformers import TrainingArguments 68 | pass 69 | @dataclass 70 | class UnslothTrainingArguments(TrainingArguments): 71 | embedding_learning_rate : Optional[float] = field( 72 | default = None, 73 | metadata = {"help" : "Different learning rates for embeddings and lm_head."} 74 | ) 75 | pass 76 | 77 | 78 | def _create_unsloth_optimizer( 79 | model, 80 | optimizer_cls, 81 | optimizer_kwargs, 82 | embedding_lr = 5e-5, 83 | ): 84 | lr = optimizer_kwargs["lr"] 85 | weight_decay = optimizer_kwargs.get("weight_decay", 0.0) 86 | 87 | param_groups = \ 88 | { 89 | "non_embeddings" : {}, 90 | "embeddings" : {}, 91 | } 92 | 93 | for name, param in model.named_parameters(): 94 | if not param.requires_grad: continue 95 | if name.endswith("modules_to_save.default.weight"): 96 | partial_name = name[:-len(".modules_to_save.default.weight")] 97 | partial_name = partial_name[partial_name.rfind(".")+1:] 98 | print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.") 99 | param_groups["embeddings"] [name] = param 100 | else: 101 | param_groups["non_embeddings"][name] = param 102 | pass 103 | pass 104 | 105 | optimizer_grouped_parameters = [ 106 | { 107 | "params" : list(param_groups["non_embeddings"].values()), 108 | "weight_decay" : weight_decay, 109 | "lr" : lr, 110 | }, 111 | { 112 | "params" : list(param_groups["embeddings"].values()), 113 | "weight_decay" : weight_decay, 114 | "lr" : embedding_lr, 115 | }, 116 | ] 117 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 118 | return optimizer 119 | pass 120 | 121 | 122 | class UnslothTrainer(SFTTrainer): 123 | def create_optimizer(self): 124 | embedding_learning_rate = getattr(self.args, 
"embedding_learning_rate", None) 125 | if embedding_learning_rate is None: return super().create_optimizer() 126 | 127 | if self.optimizer is None: 128 | optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args) 129 | self.optimizer = _create_unsloth_optimizer( 130 | self.model, 131 | optimizer_cls, 132 | optimizer_kwargs, 133 | embedding_learning_rate, 134 | ) 135 | pass 136 | return self.optimizer 137 | pass 138 | pass 139 | 140 | # From `trl>=0.13.0`, they changed how to pass several params to the trainer 141 | # We need to patch to make the transition smooth 142 | def _backwards_compatible_trainer(trainer_class, config_class): 143 | original_init = trainer_class.__init__ 144 | 145 | @wraps(original_init) 146 | def new_init(self, *args, **kwargs): 147 | # All Trainer tokenizer are now called processing_class 148 | trainer_params = set(inspect.signature(original_init).parameters.keys()) 149 | 150 | if "processing_class" in trainer_params and "tokenizer" in kwargs: 151 | kwargs["processing_class"] = kwargs.pop("tokenizer") 152 | pass 153 | 154 | if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")): 155 | training_args = kwargs.pop("args", None) 156 | 157 | # Get parameters that Trainer.__init__ actually expects 158 | trainer_params.remove('self') 159 | trainer_params.remove('args') 160 | 161 | # Get fields that should be passed to Config init 162 | config_fields = { 163 | field.name: field for field in dataclasses.fields(config_class) 164 | if field.init 165 | } 166 | 167 | # Create config dict with valid fields from training_args 168 | config_dict = { 169 | name: getattr(training_args, name) 170 | for name in config_fields 171 | if hasattr(training_args, name) 172 | } 173 | 174 | # Get parameters that exist in Config but not in TrainingArguments 175 | from transformers import TrainingArguments 176 | moved_params = \ 177 | set(inspect.signature(config_class) .parameters.keys()) - \ 178 | set(inspect.signature(TrainingArguments).parameters.keys()) 179 | 180 | # Separate kwargs into trainer kwargs and config kwargs 181 | trainer_kwargs = {} 182 | additional_config_kwargs = {} 183 | 184 | for key, value in kwargs.items(): 185 | if key in trainer_params: trainer_kwargs[key] = value 186 | elif key in moved_params or key in config_fields: 187 | additional_config_kwargs[key] = value 188 | else: 189 | additional_config_kwargs[key] = value 190 | pass 191 | pass 192 | 193 | # Update config_dict with additional kwargs 194 | config_dict.update(additional_config_kwargs) 195 | 196 | # Create Config with all the collected parameters 197 | # Reinitialising config class with parameters (that were none initially but populated on first init) 198 | # causes the 2nd init to fail as there are mutual exclusive checks on pairs of parameters. 
199 | # Refer: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_config.py#L499-L502 for example 200 | # So we only create config class if the previous init was not TrainingArguments 201 | if not isinstance(training_args, TrainingArguments): 202 | config = config_class(**config_dict) 203 | else: 204 | config = training_args 205 | 206 | # Reconstruct kwargs for Trainer 207 | kwargs = trainer_kwargs 208 | kwargs["args"] = config 209 | pass 210 | original_init(self, *args, **kwargs) 211 | pass 212 | return new_init 213 | pass 214 | 215 | 216 | def _patch_trl_trainer(): 217 | import trl 218 | if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return 219 | if Version(trl.__version__) <= Version("0.11.0"): return 220 | 221 | import trl.trainer 222 | trl_classes = dir(trl.trainer) 223 | trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) 224 | trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config")) 225 | trl_classes = list(trl_trainers & trl_configs) 226 | 227 | for x in trl_classes: 228 | try: exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals()) 229 | except: continue 230 | pass 231 | 232 | trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True 233 | pass 234 | -------------------------------------------------------------------------------- /unsloth/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unslothai/unsloth/c5a2a36e47e4179eca7d2557063a5a3e82611b36/unsloth/utils/__init__.py -------------------------------------------------------------------------------- /unsloth/utils/hf_hub.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi, ModelInfo 2 | 3 | _HFAPI: HfApi = None 4 | 5 | POPULARITY_PROPERTIES = [ 6 | "downloads", 7 | "downloadsAllTime", 8 | "trendingScore", 9 | "likes", 10 | ] 11 | THOUSAND = 1000 12 | MILLION = 1000000 13 | BILLION = 1000000000 14 | 15 | 16 | def formatted_int(value: int) -> str: 17 | if value < THOUSAND: 18 | return str(value) 19 | elif value < MILLION: 20 | return f"{float(value) / 1000:,.1f}K" 21 | elif value < BILLION: 22 | return f"{float(value) // 1000000:,.1f}M" 23 | 24 | 25 | def get_model_info( 26 | model_id: str, properties: list[str] = ["safetensors", "lastModified"] 27 | ) -> ModelInfo: 28 | """ 29 | Get the model info for a specific model. 30 | 31 | properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/model_info 32 | Default properties: ["safetensors", "lastModified"], only retrieves minimal information. 33 | Set to None to retrieve the full model information. 34 | """ 35 | global _HFAPI 36 | if _HFAPI is None: 37 | _HFAPI = HfApi() 38 | try: 39 | model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties) 40 | except Exception as e: 41 | print(f"Error getting model info for {model_id}: {e}") 42 | model_info = None 43 | return model_info 44 | 45 | 46 | def list_models( 47 | properties: list[str] = None, 48 | full: bool = False, 49 | sort: str = "downloads", 50 | author: str = "unsloth", 51 | search: str = None, 52 | limit: int = 10, 53 | ) -> list[ModelInfo]: 54 | """ 55 | Retrieve model information from the Hugging Face Hub. 
56 | 57 | properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/list_models 58 | full: bool = Whether to retrieve the full model information, if True properties will be ignored. 59 | sort: str = The sort order. 60 | author: str = The author of the model. 61 | search: str = The search query for filtering models. 62 | 63 | """ 64 | global _HFAPI 65 | if _HFAPI is None: 66 | _HFAPI = HfApi() 67 | if full: 68 | properties = None 69 | 70 | models: list[ModelInfo] = _HFAPI.list_models( 71 | author=author, 72 | search=search, 73 | sort=sort, 74 | limit=limit, 75 | expand=properties, 76 | full=full, 77 | ) 78 | return models 79 | --------------------------------------------------------------------------------
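A minimal usage sketch tying together the registry and the Hub helpers above (illustrative only; which models are returned and their download counts depend on what is registered and currently published):

```python
from unsloth.registry import search_models, QuantType
from unsloth.utils.hf_hub import get_model_info, formatted_int

# Find the Unsloth dynamic 4-bit Llama 3.1 uploads in the local registry,
# then ask the Hugging Face Hub for their download counts.
for info in search_models(version="3.1", quant_types=[QuantType.UNSLOTH]):
    # info.model_path is "<org>/<name>", e.g. "unsloth/Llama-3.1-8B-unsloth-bnb-4bit"
    hub_info = get_model_info(info.model_path, properties=["downloads"])
    downloads = getattr(hub_info, "downloads", None) if hub_info else None
    print(info.model_path, formatted_int(downloads) if downloads is not None else "n/a")
```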