├── .dockerignore
├── Dockerfile
├── pyproject.toml
├── Makefile
├── LICENSE
├── config
│   └── config.yaml
├── README.md
├── .gitignore
└── main.py

/.dockerignore:
--------------------------------------------------------------------------------
.venv
test_env
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM vllm/vllm-openai:latest

RUN curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin/:$PATH"
WORKDIR /workspace

COPY uv.lock uv.lock
COPY pyproject.toml pyproject.toml

RUN uv sync
ENTRYPOINT ["/bin/bash"]
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "workspace"
version = "0.1.0"
description = "Local GRPO training with Unsloth and vLLM in Docker"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "diffusers>=0.32.2",
    "hydra-core>=1.3.2",
    "hydra-joblib-launcher>=1.2.0",
    "omegaconf>=2.3.0",
    "pre-commit>=4.1.0",
    "unsloth>=2025.2.4",
    "vllm==0.7.2",
]
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
IMAGE_NAME = grpo_unsloth
CONTAINER_NAME = grpo_unsloth_container

.PHONY: build create start dry_run train stop clean up down

build:
	docker build -t $(IMAGE_NAME) .

create:
	docker create -it \
		--gpus=all \
		--name $(CONTAINER_NAME) \
		-v $$(pwd)/models:/models \
		-v $$(pwd):/workspace \
		-e HF_HOME=/models/cache \
		$(IMAGE_NAME)

start:
	docker start $(CONTAINER_NAME)

dry_run:
	docker exec -it $(CONTAINER_NAME) bash -c "uv run python main.py 'saving=null' 'training.max_steps=10'"

train:
	docker exec -it $(CONTAINER_NAME) bash -c "uv run python main.py"

stop:
	docker stop $(CONTAINER_NAME)

clean:
	docker rm $(CONTAINER_NAME)

# Combined targets
up: build create start dry_run

down: stop clean
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Artur Tanona

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
defaults:
  - _self_
  - override hydra/launcher: joblib
# - launcher:
#     n_jobs: 1
#     prefer: threads  # or "processes"
#     backend: multiprocessing  # use the multiprocessing backend instead of loky

model:
  name: "Qwen/Qwen2.5-3B-Instruct"
  max_seq_length: 1024
  load_in_4bit: true
  fast_inference: true
  gpu_memory_utilization: 0.5
  lora:
    rank: 64
    target_modules:
      - "q_proj"
      - "k_proj"
      - "v_proj"
      - "o_proj"
      - "gate_proj"
      - "up_proj"
      - "down_proj"
    alpha: 64
    use_gradient_checkpointing: "unsloth"
    random_state: 3407

training:
  learning_rate: 5e-6
  adam_beta1: 0.9
  adam_beta2: 0.99
  weight_decay: 0.1
  warmup_ratio: 0.1
  lr_scheduler_type: "cosine"
  optim: "adamw_8bit"
  logging_steps: 1
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 1
  num_generations: 8
  max_prompt_length: 256
  max_completion_length: 200
  max_steps: 300
  save_steps: 100
  max_grad_norm: 0.1
  report_to: "none"
  output_dir: "outputs"

saving:
  username: "your_username"  # HuggingFace username
  token: null  # HuggingFace access token; needed only when pushing to the Hub
  model_dir: "model"
  hub_model_id: "${saving.username}/model"
  save_gguf:
    enabled: false
    quantization_methods:
      - "q4_k_m"
      - "q8_0"
      - "q5_k_m"
  save_merged:
    enabled: false
    methods:
      - "merged_16bit"
      - "merged_4bit"
      - "lora"

system_prompt: |
  Respond in the following format:
  <reasoning>
  ...
  </reasoning>
  <answer>
  ...
  </answer>

generation:
  temperature: 0.8
  top_p: 0.95
  max_tokens: 1024
--------------------------------------------------------------------------------
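The `${saving.username}` interpolation above is resolved by OmegaConf at access time. A minimal sketch of inspecting the raw file (plain OmegaConf, without Hydra's defaults-list handling; run from the repository root):

```python
from omegaconf import OmegaConf

# Load config/config.yaml directly; interpolations resolve on access.
cfg = OmegaConf.load("config/config.yaml")
print(cfg.model.name)           # Qwen/Qwen2.5-3B-Instruct
print(cfg.saving.hub_model_id)  # your_username/model (from ${saving.username})
```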
/README.md:
--------------------------------------------------------------------------------
# 🚀 Local GRPO Training

This is a refactored local version of the Unsloth Colab notebook, based on the excellent work by Daniel Han and the Unsloth team.

Now you can run GRPO training locally and feel the AHA MOMENT on your own machine! ✨

## 📚 Sources

- 🔗 Original Colab notebook by Daniel Han: [LinkedIn Post](https://www.linkedin.com/posts/danielhanchen_google-colab-activity-7293333957046063104-M3lq)
- 🧠 Reasoning model guidance from [Unsloth's blog post](https://unsloth.ai/blog/r1-reasoning)
- 🎯 Reward functions from [Will's Gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)

## 🛠️ Prerequisites

- 🖥️ NVIDIA GPU
- 🔧 make (optional - see Advanced Instructions if you prefer raw Docker commands)

## 🏃‍♂️ Quick Start

```bash
make up
```

This builds the image, creates and starts the container, and runs a short 10-step dry run.

## ⚙️ Configuration

Modify `config/config.yaml` to customize settings and parameters. Any value can also be overridden on the command line using Hydra's `key=value` syntax; the Makefile's `dry_run` target does exactly this with `training.max_steps=10`. Then simply run:

```bash
make train
```
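To sanity-check the merged configuration without launching a training run, you can use Hydra's compose API (a minimal sketch, assuming it is run from the repository root):

```python
from hydra import compose, initialize

# Compose the same config that main.py receives, applying CLI-style overrides.
with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="config", overrides=["training.max_steps=10"])
    print(cfg.training.max_steps)  # 10
```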
## 🧹 Clean up

```bash
make down
```

## ⚠️ Limitations

- 🎮 Currently supports single-GPU operation only
- 💪 For multi-GPU or H100 access, please visit [runpod.io](https://runpod.io)

## 🔍 Advanced Instructions

If you prefer not to use `make`, you can run the Docker commands directly:

```bash
# 🏗️ Build the image
docker build -t grpo_unsloth .

# 📦 Create container
docker create -it \
  --gpus=all \
  --name grpo_unsloth_container \
  -v $(pwd)/models:/models \
  -v $(pwd):/workspace \
  -e HF_HOME=/models/cache \
  grpo_unsloth

# 🚀 Start container
docker start grpo_unsloth_container

# 🧪 Run a quick test (dry run)
docker exec -it grpo_unsloth_container bash -c "uv run python main.py 'saving=null' 'training.max_steps=10'"

# 🏃 Run full training
docker exec -it grpo_unsloth_container bash -c "uv run python main.py"

# ⏹️ Stop container
docker stop grpo_unsloth_container

# 🗑️ Remove container
docker rm grpo_unsloth_container
```

## 🤝 Contributing

Feel free to open issues and pull requests!

## 📄 License

This project is open-source and available under the MIT License.

[![GitHub](https://img.shields.io/github/license/ArturTanona/grpo_unsloth_docker)](https://github.com/ArturTanona/grpo_unsloth_docker/blob/main/LICENSE)
[![GitHub stars](https://img.shields.io/github/stars/ArturTanona/grpo_unsloth_docker)](https://github.com/ArturTanona/grpo_unsloth_docker/stargazers)
[![GitHub issues](https://img.shields.io/github/issues/ArturTanona/grpo_unsloth_docker)](https://github.com/ArturTanona/grpo_unsloth_docker/issues)
[![GitHub forks](https://img.shields.io/github/forks/ArturTanona/grpo_unsloth_docker)](https://github.com/ArturTanona/grpo_unsloth_docker/network/members)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#  For a library or package, you might want to ignore these files since the code is
#  intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#  According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#  However, in case of collaboration, if having platform-specific dependencies or dependencies
#  having no cross-platform support, pipenv may install dependencies that don't work, or not
#  install all needed dependencies.
#Pipfile.lock

# UV
#  Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#  This is especially recommended for binary packages to ensure reproducibility, and is more
#  commonly ignored for libraries.
#uv.lock

# poetry
#  Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#  This is especially recommended for binary packages to ensure reproducibility, and is more
#  commonly ignored for libraries.
#  https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#  Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#  pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#  in version control.
#  https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc


outputs/*
models/*
unsloth_compiled_cache/*
saved_model/*
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from typing import List
from unsloth import FastLanguageModel, PatchFastRL

PatchFastRL("GRPO", FastLanguageModel)  # patch TRL for GRPO before the imports below


from unsloth import is_bfloat16_supported  # noqa: E402
from trl import GRPOConfig, GRPOTrainer  # noqa: E402
import re  # noqa: E402
from datasets import load_dataset, Dataset  # noqa: E402
from vllm import SamplingParams  # noqa: E402

from dataclasses import dataclass, field  # noqa: E402
from omegaconf import DictConfig  # noqa: E402
import hydra  # noqa: E402


# Structured-config mirrors of config/config.yaml (not registered with Hydra's
# ConfigStore; kept as documentation of the expected schema).
@dataclass
class LoraConfig:
    rank: int = 64  # Larger rank = smarter, but slower
    target_modules: List = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    alpha: int = 64
    use_gradient_checkpointing: str = "unsloth"
    random_state: int = 3407


@dataclass
class ModelConfig:
    name: str = "Qwen/Qwen2.5-3B-Instruct"
    max_seq_length: int = 1024  # Can increase for longer reasoning traces
    load_in_4bit: bool = True
    fast_inference: bool = True
    lora: LoraConfig = field(default_factory=lambda: LoraConfig())

    gpu_memory_utilization: float = 0.5


def prepare_model(cfg: DictConfig):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=cfg.model.name,
        max_seq_length=cfg.model.max_seq_length,
        load_in_4bit=cfg.model.load_in_4bit,  # False for LoRA 16bit
        fast_inference=cfg.model.fast_inference,  # Enable vLLM fast inference
        max_lora_rank=cfg.model.lora.rank,
        gpu_memory_utilization=cfg.model.gpu_memory_utilization,  # Reduce if out of memory
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.model.lora.rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
        target_modules=list(cfg.model.lora.target_modules),  # Remove QKVO if out of memory
        lora_alpha=cfg.model.lora.alpha,
        use_gradient_checkpointing=cfg.model.lora.use_gradient_checkpointing,  # Enable long context finetuning
        random_state=cfg.model.lora.random_state,
    )
    return model, tokenizer


# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
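# For example, XML_COT_FORMAT.format(reasoning="2 + 2 = 4", answer="4") renders
# a completion in exactly the shape the reward functions below score; it is only
# needed for the 1-shot prompting variant mentioned in get_gsm8k_questions below.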
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
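# e.g. extract_xml_answer("<answer>\n72\n</answer>") -> "72", while
# extract_hash_answer handles GSM8K's label format:
# extract_hash_answer("She has 72 left. #### 72") -> "72".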
# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(
        lambda x: {  # type: ignore
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),
        }
    )  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()  # built at import time (downloads GSM8K on first run)


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(
        "-" * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{responses[0]}",
        f"\nExtracted:\n{extracted_responses[0]}",
    )
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the exact expected format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely follows the format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
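# Quick sanity check of the reward stack on a hand-written completion
# (illustrative values, computed by hand from the functions above):
#
#   s = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
#   count_xml(s)                                  == 0.5   (all four tag checks pass)
#   strict_format_reward_func([[{"content": s}]]) == [0.5]
#   int_reward_func([[{"content": s}]])           == [0.5]
#   soft_format_reward_func([[{"content": s}]])   == [0.0]  (its .*? does not cross
#                                                  newlines without re.DOTALL)
#
# With a correct extracted answer (+2.0), such a completion earns 3.5 in total.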
def strawberry_example(tokenizer, model):
    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "How many r's are in strawberry?"},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1024,
    )
    output = (
        model.fast_generate(
            [text],
            sampling_params=sampling_params,
            lora_request=None,
        )[0]
        .outputs[0]
        .text
    )

    print(output)


def strawberry_example_lora(tokenizer, model):
    text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": "How many r's are in strawberry?"},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1024,
    )
    output = (
        model.fast_generate(
            text,
            sampling_params=sampling_params,
            lora_request=model.load_lora("grpo_saved_lora"),
        )[0]
        .outputs[0]
        .text
    )

    print(output)


def save(cfg, model, tokenizer):
    if cfg.saving.save_gguf.enabled:
        for quant_method in cfg.saving.save_gguf.quantization_methods:
            model.save_pretrained_gguf(
                cfg.saving.model_dir, tokenizer, quantization_method=quant_method
            )
            if cfg.saving.token:  # Only push if a token is provided
                model.push_to_hub_gguf(
                    cfg.saving.hub_model_id,
                    tokenizer,
                    quantization_method=quant_method,
                    token=cfg.saving.token,
                )

    if cfg.saving.save_merged.enabled:
        for save_method in cfg.saving.save_merged.methods:
            model.save_pretrained_merged(
                cfg.saving.model_dir, tokenizer, save_method=save_method
            )
            if cfg.saving.token:  # Only push if a token is provided
                model.push_to_hub_merged(
                    cfg.saving.hub_model_id,
                    tokenizer,
                    save_method=save_method,
                    token=cfg.saving.token,
                )


@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: DictConfig):
    model, tokenizer = prepare_model(cfg)
    training_args = GRPOConfig(
        use_vllm=True,
        learning_rate=cfg.training.learning_rate,
        adam_beta1=cfg.training.adam_beta1,
        adam_beta2=cfg.training.adam_beta2,
        weight_decay=cfg.training.weight_decay,
        warmup_ratio=cfg.training.warmup_ratio,
        lr_scheduler_type=cfg.training.lr_scheduler_type,
        optim=cfg.training.optim,
        logging_steps=cfg.training.logging_steps,
        bf16=is_bfloat16_supported(),
        fp16=not is_bfloat16_supported(),
        per_device_train_batch_size=cfg.training.per_device_train_batch_size,
        gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
        num_generations=cfg.training.num_generations,
        max_prompt_length=cfg.training.max_prompt_length,
        max_completion_length=cfg.training.max_completion_length,
        max_steps=cfg.training.max_steps,
        save_steps=cfg.training.save_steps,
        max_grad_norm=cfg.training.max_grad_norm,
        report_to=cfg.training.report_to,
        output_dir=cfg.training.output_dir,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    strawberry_example(tokenizer=tokenizer, model=model)
    model.save_lora("grpo_saved_lora")  # strawberry_example_lora loads this adapter
    strawberry_example_lora(tokenizer=tokenizer, model=model)
    trainer.save_model("/workspace/saved_model")

    if cfg.saving is not None:
        save(cfg, model, tokenizer)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------