├── .dockerignore
├── Dockerfile
├── pyproject.toml
├── Makefile
├── LICENSE
├── config
│   └── config.yaml
├── README.md
├── .gitignore
└── main.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | test_env
3 | models
4 | outputs
5 | unsloth_compiled_cache
6 | saved_model
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM vllm/vllm-openai:latest
2 |
3 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh
4 | ENV PATH="/root/.local/bin/:$PATH"
5 | WORKDIR /workspace
6 |
7 | COPY uv.lock uv.lock
8 | COPY pyproject.toml pyproject.toml
9 |
10 | RUN uv sync
11 | ENTRYPOINT ["/bin/bash"]
12 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "workspace"
3 | version = "0.1.0"
4 | description = "Local GRPO training with Unsloth, vLLM, and Hydra in Docker"
5 | readme = "README.md"
6 | requires-python = ">=3.10"
7 | dependencies = [
8 | "diffusers>=0.32.2",
9 | "hydra-core>=1.3.2",
10 | "hydra-joblib-launcher>=1.2.0",
11 | "omegaconf>=2.3.0",
12 | "pre-commit>=4.1.0",
13 | "unsloth>=2025.2.4",
14 | "vllm==0.7.2",
15 | ]
16 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | IMAGE_NAME = grpo_unsloth
2 | CONTAINER_NAME = grpo_unsloth_container
3 |
4 | .PHONY: build create start dry_run train stop clean up down
5 |
6 | build:
7 | docker build -t $(IMAGE_NAME) .
8 |
9 | create:
10 | docker create -it \
11 | --gpus=all \
12 | --name $(CONTAINER_NAME) \
13 | -v $$(pwd)/models:/models \
14 | -v $$(pwd):/workspace \
15 | -e HF_HOME=/models/cache \
16 | $(IMAGE_NAME)
17 |
18 | start:
19 | docker start $(CONTAINER_NAME)
20 |
21 | dry_run:
22 | docker exec -it $(CONTAINER_NAME) bash -c "uv run python main.py 'saving=null' 'training.max_steps=10'"
23 |
24 | train:
25 | docker exec -it $(CONTAINER_NAME) bash -c "uv run python main.py"
26 |
27 | stop:
28 | docker stop $(CONTAINER_NAME)
29 |
30 | clean:
31 | docker rm $(CONTAINER_NAME)
32 |
33 | # Combined targets
34 | up: build create start dry_run
35 |
36 | down: stop clean
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Artur Tanona
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - override hydra/launcher: joblib
4 | # - launcher:
5 | # n_jobs: 1
6 | # prefer: threads # use threads instead of processes
7 | # backend: multiprocessing # use multiprocessing backend instead of loky
8 | model:
9 | name: "Qwen/Qwen2.5-3B-Instruct"
10 | max_seq_length: 1024
11 | load_in_4bit: true
12 | fast_inference: true
13 | gpu_memory_utilization: 0.5
14 | lora:
15 | rank: 64
16 | target_modules:
17 | - "q_proj"
18 | - "k_proj"
19 | - "v_proj"
20 | - "o_proj"
21 | - "gate_proj"
22 | - "up_proj"
23 | - "down_proj"
24 | alpha: 64
25 | use_gradient_checkpointing: "unsloth"
26 | random_state: 3407
27 |
28 | training:
29 | learning_rate: 5e-6
30 | adam_beta1: 0.9
31 | adam_beta2: 0.99
32 | weight_decay: 0.1
33 | warmup_ratio: 0.1
34 | lr_scheduler_type: "cosine"
35 | optim: "adamw_8bit"
36 | logging_steps: 1
37 | per_device_train_batch_size: 1
38 | gradient_accumulation_steps: 1
39 | num_generations: 8
40 | max_prompt_length: 256
41 | max_completion_length: 200
42 | max_steps: 300
43 | save_steps: 100
44 | max_grad_norm: 0.1
45 | report_to: "none"
46 | output_dir: "outputs"
47 |
48 | saving:
49 | username: "your_username" # HuggingFace username
50 | model_dir: "model"
51 | hub_model_id: "${saving.username}/model"
52 | save_gguf:
53 | enabled: false
54 | quantization_methods:
55 | - "q4_k_m"
56 | - "q8_0"
57 | - "q5_k_m"
58 | save_merged:
59 | enabled: false
60 | methods:
61 | - "merged_16bit"
62 | - "merged_4bit"
63 | - "lora"
64 |
65 | system_prompt: |
66 |   Respond in the following format:
67 |   <reasoning>
68 |   ...
69 |   </reasoning>
70 |   <answer>
71 |   ...
72 |   </answer>
73 |
74 | generation:
75 | temperature: 0.8
76 | top_p: 0.95
77 | max_tokens: 1024
78 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🚀 Local GRPO Training
2 |
3 | This is a refactored local version of the Unsloth Colab notebook, based on the excellent work by Daniel Han and the Unsloth team.
4 |
5 | Now you can run GRPO policy training locally and experience the "aha moment" on your own machine! ✨
6 |
7 | ## 📚 Sources
8 | - 🔗 Original Colab notebook by Daniel Han: [LinkedIn Post](https://www.linkedin.com/posts/danielhanchen_google-colab-activity-7293333957046063104-M3lq)
9 | - 🧠 Reasoning model guidance from [Unsloth's blog post](https://unsloth.ai/blog/r1-reasoning)
10 | - 🎯 Reward functions from [Will's Gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb)
11 |
12 | ## 🛠️ Prerequisites
13 |
14 | - 🖥️ NVIDIA GPU and Docker with the NVIDIA Container Toolkit (a quick check is shown below)
15 | - 🔧 make (optional - see Advanced Instructions if not using make)
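To confirm that Docker can see your GPU before building anything, NVIDIA's standard smoke test is handy (this assumes the NVIDIA Container Toolkit is installed; the `ubuntu` image is just an arbitrary small image):

```bash
# Should print the familiar nvidia-smi table from inside a container
docker run --rm --gpus=all ubuntu nvidia-smi
```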
16 |
17 | ## 🏃‍♂️ Quick Start
18 |
19 | ```bash
20 | make up  # builds the image, creates and starts the container, then runs a short dry run
21 | ```
22 |
23 | ## ⚙️ Configuration
24 |
25 | Modify `config/config.yaml` to customize the model, LoRA, training, and saving parameters. Then simply run:
26 | ```bash
27 | make train
28 | ```
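Because the entry point uses Hydra, you can also override any value from `config/config.yaml` on the command line instead of editing the file. The values below are purely illustrative:

```bash
# Example: shorter run with a smaller LoRA rank
docker exec -it grpo_unsloth_container bash -c \
  "uv run python main.py 'training.max_steps=50' 'model.lora.rank=16'"
```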
29 |
30 | ## 🧹 Clean up
31 |
32 | ```bash
33 | make down
34 | ```
35 |
36 | ## ⚠️ Limitations
37 |
38 | - 🎮 Currently supports single-GPU training only
39 | - 💪 For multi-GPU or H100 access, please visit [runpod.io](https://runpod.io)
40 |
41 | ## 🔍 Advanced Instructions
42 |
43 | If you prefer not to use `make`, you can run the Docker commands directly:
44 |
45 | ```bash
46 | # 🏗️ Build the image
47 | docker build -t grpo_unsloth .
48 |
49 | # 📦 Create container
50 | docker create -it \
51 | --gpus=all \
52 | --name grpo_unsloth_container \
53 | -v $(pwd)/models:/models \
54 | -v $(pwd):/workspace \
55 | -e HF_HOME=/models/cache \
56 | grpo_unsloth
57 |
58 | # 🚀 Start container
59 | docker start grpo_unsloth_container
60 |
61 | # 🧪 Run a quick test (dry run)
62 | docker exec -it grpo_unsloth_container bash -c "uv run python main.py 'saving=null' 'training.max_steps=10'"
63 |
64 | # 🏃 Run full training
65 | docker exec -it grpo_unsloth_container bash -c "uv run python main.py"
66 |
67 | # ⏹️ Stop container
68 | docker stop grpo_unsloth_container
69 |
70 | # 🗑️ Remove container
71 | docker rm grpo_unsloth_container
72 | ```
73 |
74 | ## 🤝 Contributing
75 |
76 | Feel free to open issues and pull requests!
77 |
78 | ## 📄 License
79 |
80 | This project is open-source and available under the MIT License.
81 |
82 | [License](https://github.com/ArturTanona/grpo_unsloth_docker/blob/main/LICENSE)
83 | [Stars](https://github.com/ArturTanona/grpo_unsloth_docker/stargazers)
84 | [Issues](https://github.com/ArturTanona/grpo_unsloth_docker/issues)
85 | [Forks](https://github.com/ArturTanona/grpo_unsloth_docker/network/members)
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 |
174 | outputs/*
175 | models/*
176 | unsloth_compiled_cache/*
177 | saved_model/*
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from unsloth import FastLanguageModel, PatchFastRL
3 |
4 | PatchFastRL("GRPO", FastLanguageModel) # needed for GRPO
5 |
6 |
7 | from unsloth import is_bfloat16_supported # noqa: E402
8 | from trl import GRPOConfig, GRPOTrainer # noqa: E402
9 | import re # noqa: E402
10 | from datasets import load_dataset, Dataset # noqa: E402
11 | from vllm import SamplingParams # noqa: E402
12 |
13 | from dataclasses import dataclass # noqa: E402
14 | from omegaconf import DictConfig # noqa: E402
15 | from dataclasses import field # noqa: E402
16 | import hydra # noqa: E402
17 |
18 | max_seq_length = 1024 # Can increase for longer reasoning traces
19 | lora_rank = 64 # Larger rank = smarter, but slower
20 |
21 |
22 | @dataclass
23 | class LoraConfig:
24 | rank: int = 64
25 | target_modules: List = field(
26 | default_factory=lambda: [
27 | "q_proj",
28 | "k_proj",
29 | "v_proj",
30 | "o_proj",
31 | "gate_proj",
32 | "up_proj",
33 | "down_proj",
34 | ]
35 | )
36 | use_gradient_checkpointing: str = "unsloth"
37 | random_state: int = 3407
38 |
39 |
40 | @dataclass
41 | class ModelConfig:
42 | max_seq_length: int = 1024
43 | load_in_4bit: bool = True
44 | fast_inference: bool = True
45 | lora: LoraConfig = field(default_factory=lambda: LoraConfig())
46 |
47 | gpu_memory_utilization: float = 0.5
48 |
49 |
50 | def prepare_model(cfg: DictConfig):
51 |     model, tokenizer = FastLanguageModel.from_pretrained(
52 |         model_name=cfg.model.name,
53 |         max_seq_length=cfg.model.max_seq_length,
54 |         load_in_4bit=cfg.model.load_in_4bit,  # False for LoRA 16bit
55 |         fast_inference=cfg.model.fast_inference,  # Enable vLLM fast inference
56 |         max_lora_rank=cfg.model.lora.rank,
57 |         gpu_memory_utilization=cfg.model.gpu_memory_utilization,  # Reduce if out of memory
58 |     )
59 |
60 | model = FastLanguageModel.get_peft_model(
61 | model,
62 | r=cfg.model.lora.rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
63 | target_modules=[
64 | "q_proj",
65 | "k_proj",
66 | "v_proj",
67 | "o_proj",
68 | "gate_proj",
69 | "up_proj",
70 | "down_proj",
71 | ], # Remove QKVO if out of memory
72 |         lora_alpha=cfg.model.lora.alpha,
73 | use_gradient_checkpointing=cfg.model.lora.use_gradient_checkpointing, # Enable long context finetuning
74 | random_state=cfg.model.lora.random_state,
75 | )
76 | return model, tokenizer
77 |
78 |
79 | # Load and prep dataset
80 | SYSTEM_PROMPT = """
81 | Respond in the following format:
82 | <reasoning>
83 | ...
84 | </reasoning>
85 | <answer>
86 | ...
87 | </answer>
88 | """
89 |
90 | XML_COT_FORMAT = """\
91 | <reasoning>
92 | {reasoning}
93 | </reasoning>
94 | <answer>
95 | {answer}
96 | </answer>
97 | """
98 |
99 |
100 | def extract_xml_answer(text: str) -> str:
101 |     answer = text.split("<answer>")[-1]
102 |     answer = answer.split("</answer>")[0]
103 | return answer.strip()
104 |
105 |
106 | def extract_hash_answer(text: str) -> str | None:
107 | if "####" not in text:
108 | return None
109 | return text.split("####")[1].strip()
110 |
111 |
112 | # uncomment middle messages for 1-shot prompting
113 | def get_gsm8k_questions(split="train") -> Dataset:
114 | data = load_dataset("openai/gsm8k", "main")[split] # type: ignore
115 | data = data.map(
116 | lambda x: { # type: ignore
117 | "prompt": [
118 | {"role": "system", "content": SYSTEM_PROMPT},
119 | {"role": "user", "content": x["question"]},
120 | ],
121 | "answer": extract_hash_answer(x["answer"]),
122 | }
123 | ) # type: ignore
124 | return data # type: ignore
125 |
126 |
127 | dataset = get_gsm8k_questions()
128 |
129 |
130 | # Reward functions
131 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
132 | responses = [completion[0]["content"] for completion in completions]
133 | q = prompts[0][-1]["content"]
134 | extracted_responses = [extract_xml_answer(r) for r in responses]
135 | print(
136 | "-" * 20,
137 | f"Question:\n{q}",
138 | f"\nAnswer:\n{answer[0]}",
139 | f"\nResponse:\n{responses[0]}",
140 | f"\nExtracted:\n{extracted_responses[0]}",
141 | )
142 | return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
143 |
144 |
145 | def int_reward_func(completions, **kwargs) -> list[float]:
146 | responses = [completion[0]["content"] for completion in completions]
147 | extracted_responses = [extract_xml_answer(r) for r in responses]
148 | return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
149 |
150 |
151 | def strict_format_reward_func(completions, **kwargs) -> list[float]:
152 | """Reward function that checks if the completion has a specific format."""
153 |     pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
154 | responses = [completion[0]["content"] for completion in completions]
155 | matches = [re.match(pattern, r) for r in responses]
156 | return [0.5 if match else 0.0 for match in matches]
157 |
158 |
159 | def soft_format_reward_func(completions, **kwargs) -> list[float]:
160 | """Reward function that checks if the completion has a specific format."""
161 |     pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
162 | responses = [completion[0]["content"] for completion in completions]
163 | matches = [re.match(pattern, r) for r in responses]
164 | return [0.5 if match else 0.0 for match in matches]
165 |
166 |
167 | def count_xml(text) -> float:
168 | count = 0.0
169 |     if text.count("<reasoning>\n") == 1:
170 |         count += 0.125
171 |     if text.count("\n</reasoning>\n") == 1:
172 |         count += 0.125
173 |     if text.count("\n<answer>\n") == 1:
174 |         count += 0.125
175 |         count -= len(text.split("\n</answer>\n")[-1]) * 0.001
176 |     if text.count("\n</answer>") == 1:
177 |         count += 0.125
178 |         count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
179 | return count
180 |
181 |
182 | def xmlcount_reward_func(completions, **kwargs) -> list[float]:
183 | contents = [completion[0]["content"] for completion in completions]
184 | return [count_xml(c) for c in contents]
185 |
186 |
187 | def strawberry_example(tokenizer, model):
188 | text = tokenizer.apply_chat_template(
189 | [
190 | {"role": "user", "content": "How many r's are in strawberry?"},
191 | ],
192 | tokenize=False,
193 | add_generation_prompt=True,
194 | )
195 |
196 | sampling_params = SamplingParams(
197 | temperature=0.8,
198 | top_p=0.95,
199 | max_tokens=1024,
200 | )
201 | output = (
202 | model.fast_generate(
203 | [text],
204 | sampling_params=sampling_params,
205 | lora_request=None,
206 | )[0]
207 | .outputs[0]
208 | .text
209 | )
210 |
211 | print(output)
212 |
213 |
214 | # The LoRA example below reloads the adapter saved as "grpo_saved_lora" after training.
215 |
216 |
217 | def strawberry_example_lora(tokenizer, model):
218 | text = tokenizer.apply_chat_template(
219 | [
220 | {"role": "system", "content": SYSTEM_PROMPT},
221 | {"role": "user", "content": "How many r's are in strawberry?"},
222 | ],
223 | tokenize=False,
224 | add_generation_prompt=True,
225 | )
226 |
227 | sampling_params = SamplingParams(
228 | temperature=0.8,
229 | top_p=0.95,
230 | max_tokens=1024,
231 | )
232 | output = (
233 | model.fast_generate(
234 | text,
235 | sampling_params=sampling_params,
236 | lora_request=model.load_lora("grpo_saved_lora"),
237 | )[0]
238 | .outputs[0]
239 | .text
240 | )
241 |
242 | print(output)
243 |
244 |
245 | def save(cfg, model, tokenizer):
246 | if cfg.saving.save_gguf.enabled:
247 | for quant_method in cfg.saving.save_gguf.quantization_methods:
248 | model.save_pretrained_gguf(
249 | cfg.saving.model_dir, tokenizer, quantization_method=quant_method
250 | )
251 |             if "token" in cfg.saving and cfg.saving.token:  # Only push if a token is provided in the config
252 | model.push_to_hub_gguf(
253 | cfg.saving.hub_model_id,
254 | tokenizer,
255 | quantization_method=quant_method,
256 | token=cfg.saving.token,
257 | )
258 |
259 | if cfg.saving.save_merged.enabled:
260 | for save_method in cfg.saving.save_merged.methods:
261 | model.save_pretrained_merged(
262 | cfg.saving.model_dir, tokenizer, save_method=save_method
263 | )
264 |             if "token" in cfg.saving and cfg.saving.token:  # Only push if a token is provided in the config
265 | model.push_to_hub_merged(
266 | cfg.saving.hub_model_id,
267 | tokenizer,
268 | save_method=save_method,
269 | token=cfg.saving.token,
270 | )
271 |
272 |
273 | @hydra.main(config_path="config", config_name="config", version_base=None)
274 | def main(cfg: DictConfig):
275 | model, tokenizer = prepare_model(cfg)
276 | training_args = GRPOConfig(
277 | use_vllm=True,
278 | learning_rate=cfg.training.learning_rate,
279 | adam_beta1=cfg.training.adam_beta1,
280 | adam_beta2=cfg.training.adam_beta2,
281 | weight_decay=cfg.training.weight_decay,
282 | warmup_ratio=cfg.training.warmup_ratio,
283 | lr_scheduler_type=cfg.training.lr_scheduler_type,
284 | optim=cfg.training.optim,
285 | logging_steps=cfg.training.logging_steps,
286 | bf16=is_bfloat16_supported(),
287 | fp16=not is_bfloat16_supported(),
288 | per_device_train_batch_size=cfg.training.per_device_train_batch_size,
289 | gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
290 | num_generations=cfg.training.num_generations,
291 | max_prompt_length=cfg.training.max_prompt_length,
292 | max_completion_length=cfg.training.max_completion_length,
293 | max_steps=cfg.training.max_steps,
294 | save_steps=cfg.training.save_steps,
295 | max_grad_norm=cfg.training.max_grad_norm,
296 | report_to=cfg.training.report_to,
297 | output_dir=cfg.training.output_dir,
298 | )
299 |
300 | trainer = GRPOTrainer(
301 | model=model,
302 | processing_class=tokenizer,
303 | reward_funcs=[
304 | xmlcount_reward_func,
305 | soft_format_reward_func,
306 | strict_format_reward_func,
307 | int_reward_func,
308 | correctness_reward_func,
309 | ],
310 | args=training_args,
311 | train_dataset=dataset,
312 | )
313 |     trainer.train()
314 |     model.save_lora("grpo_saved_lora")  # adapter reloaded by strawberry_example_lora
315 |     strawberry_example(tokenizer=tokenizer, model=model)
316 |     strawberry_example_lora(tokenizer=tokenizer, model=model)
317 |     trainer.save_model("/workspace/saved_model")
318 | if cfg.saving is not None:
319 | save(cfg, model, tokenizer)
320 |
321 |
322 | if __name__ == "__main__":
323 | main()
324 |
--------------------------------------------------------------------------------