├── .gitignore ├── GRPO-Loss-Analysis.ipynb ├── GRPO-Loss-Pytorch.ipynb ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | ##### Project Specification ##### 2 | output/ 3 | outputs/ 4 | wandb/ 5 | BIG-bench/ 6 | 7 | ##### Python.gitignore ##### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | wheelhouse/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | *.whl 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | docs/build/ 83 | docs/source/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | .python-version 
100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # ruff 153 | .ruff_cache/ 154 | 155 | # mypy 156 | .mypy_cache/ 157 | .dmypy.json 158 | dmypy.json 159 | 160 | # Pyre type checker 161 | .pyre/ 162 | 163 | # pytype static type analyzer 164 | .pytype/ 165 | 166 | # Cython debug symbols 167 | cython_debug/ 168 | 169 | # PyCharm 170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 172 | # and can be added to the global gitignore or merged into this file. For a more nuclear 173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
174 | .idea/ 175 | 176 | 177 | ##### macOS.gitignore ##### 178 | # General 179 | .DS_Store 180 | .AppleDouble 181 | .LSOverride 182 | 183 | # Icon must end with two \r 184 | Icon 185 | 186 | # Thumbnails 187 | ._* 188 | 189 | # Files that might appear in the root of a volume 190 | .DocumentRevisions-V100 191 | .fseventsd 192 | .Spotlight-V100 193 | .TemporaryItems 194 | .Trashes 195 | .VolumeIcon.icns 196 | .com.apple.timemachine.donotpresent 197 | 198 | # Directories potentially created on remote AFP share 199 | .AppleDB 200 | .AppleDesktop 201 | Network Trash Folder 202 | Temporary Items 203 | .apdisk 204 | 205 | 206 | ##### Linux.gitignore ##### 207 | *~ 208 | 209 | # Temporary files which can be created if a process still has a handle open of a deleted file 210 | .fuse_hidden* 211 | 212 | # KDE directory preferences 213 | .directory 214 | 215 | # Linux trash folder which might appear on any partition or disk 216 | .Trash-* 217 | 218 | # .nfs files are created when an open file is removed but is still being accessed 219 | .nfs* 220 | 221 | 222 | ##### Windows.gitignore ##### 223 | # Windows thumbnail cache files 224 | Thumbs.db 225 | Thumbs.db:encryptable 226 | ehthumbs.db 227 | ehthumbs_vista.db 228 | 229 | # Dump file 230 | *.stackdump 231 | 232 | # Folder config file 233 | [Dd]esktop.ini 234 | 235 | # Recycle Bin used on file shares 236 | $RECYCLE.BIN/ 237 | 238 | # Windows Installer files 239 | *.cab 240 | *.msi 241 | *.msix 242 | *.msm 243 | *.msp 244 | 245 | # Windows shortcuts 246 | *.lnk 247 | 248 | 249 | ##### Archives.gitignore ##### 250 | # It's better to unpack these files and commit the raw source because 251 | # git has its own built in compression methods. 
252 | *.7z 253 | *.jar 254 | *.rar 255 | *.zip 256 | *.gz 257 | *.gzip 258 | *.tgz 259 | *.bzip 260 | *.bzip2 261 | *.bz2 262 | *.xz 263 | *.lzma 264 | *.cab 265 | *.xar 266 | 267 | # Packing-only formats 268 | *.iso 269 | *.tar 270 | 271 | # Package management formats 272 | *.dmg 273 | *.xpi 274 | *.gem 275 | *.egg 276 | *.deb 277 | *.rpm 278 | *.msi 279 | *.msm 280 | *.msp 281 | *.txz 282 | 283 | 284 | ##### Xcode.gitignore ##### 285 | # Xcode 286 | # 287 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 288 | 289 | ## User settings 290 | xcuserdata/ 291 | 292 | ## Compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 293 | *.xcscmblueprint 294 | *.xccheckout 295 | 296 | ## Compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 297 | build/ 298 | DerivedData/ 299 | *.moved-aside 300 | *.pbxuser 301 | !default.pbxuser 302 | *.mode1v3 303 | !default.mode1v3 304 | *.mode2v3 305 | !default.mode2v3 306 | *.perspectivev3 307 | !default.perspectivev3 308 | 309 | ## Gcc Patch 310 | /*.gcno 311 | 312 | 313 | ##### JetBrains.gitignore ##### 314 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 315 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 316 | 317 | # User settings 318 | .idea/* 319 | 320 | # User-specific stuff 321 | .idea/**/workspace.xml 322 | .idea/**/tasks.xml 323 | .idea/**/usage.statistics.xml 324 | .idea/**/dictionaries 325 | .idea/**/shelf 326 | 327 | # Generated files 328 | .idea/**/contentModel.xml 329 | 330 | # Sensitive or high-churn files 331 | .idea/**/dataSources/ 332 | .idea/**/dataSources.ids 333 | .idea/**/dataSources.local.xml 334 | .idea/**/sqlDataSources.xml 335 | .idea/**/dynamic.xml 336 | .idea/**/uiDesigner.xml 337 | .idea/**/dbnavigator.xml 338 | 339 | # Gradle 340 | .idea/**/gradle.xml 341 | .idea/**/libraries 342 | 343 | # Gradle and 
Maven with auto-import 344 | # When using Gradle or Maven with auto-import, you should exclude module files, 345 | # since they will be recreated, and may cause churn. Uncomment if using 346 | # auto-import. 347 | # .idea/artifacts 348 | # .idea/compiler.xml 349 | # .idea/jarRepositories.xml 350 | # .idea/modules.xml 351 | # .idea/*.iml 352 | # .idea/modules 353 | # *.iml 354 | # *.ipr 355 | 356 | # CMake 357 | cmake-build-*/ 358 | 359 | # Mongo Explorer plugin 360 | .idea/**/mongoSettings.xml 361 | 362 | # File-based project format 363 | *.iws 364 | 365 | # IntelliJ 366 | out/ 367 | 368 | # mpeltonen/sbt-idea plugin 369 | .idea_modules/ 370 | 371 | # JIRA plugin 372 | atlassian-ide-plugin.xml 373 | 374 | # Cursive Clojure plugin 375 | .idea/replstate.xml 376 | 377 | # Crashlytics plugin (for Android Studio and IntelliJ) 378 | com_crashlytics_export_strings.xml 379 | crashlytics.properties 380 | crashlytics-build.properties 381 | fabric.properties 382 | 383 | # Editor-based Rest Client 384 | .idea/httpRequests 385 | 386 | # Android studio 3.1+ serialized cache file 387 | .idea/caches/build_file_checksums.ser 388 | 389 | 390 | ##### VisualStudioCode.gitignore ##### 391 | .vscode/* 392 | # !.vscode/settings.json 393 | # !.vscode/tasks.json 394 | # !.vscode/launch.json 395 | !.vscode/extensions.json 396 | *.code-workspace 397 | 398 | # Local History for Visual Studio Code 399 | .history/ 400 | 401 | 402 | ##### Vim.gitignore ##### 403 | # Swap 404 | .*.s[a-v][a-z] 405 | !*.svg # comment out if you don't need vector files 406 | .*.sw[a-p] 407 | .s[a-rt-v][a-z] 408 | .ss[a-gi-z] 409 | .sw[a-p] 410 | 411 | # Session 412 | Session.vim 413 | Sessionx.vim 414 | 415 | # Temporary 416 | .netrwhist 417 | *~ 418 | # Auto-generated tag files 419 | tags 420 | # Persistent undo 421 | [._]*.un~ 422 | -------------------------------------------------------------------------------- /GRPO-Loss-Analysis.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d922cdc2-7da0-409a-9840-40e6b71b57ee", 6 | "metadata": {}, 7 | "source": [ 8 | "# Why is the GRPO loss negative, and why does it rise?" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "c053c97b-da27-49bc-9e98-f3fd05f74efc", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torch.nn as nn\n", 20 | "import torch.nn.functional as F" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "id": "cc33f0ff-bcc9-4b26-b57e-85cdba08a2ec", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# GRPO-KL is always >= 0 (zero only when pi == pi_ref)\n", 31 | "def grpo_kl(pi_logprob, pi_ref_logprob):\n", 32 | " return pi_ref_logprob.exp() / pi_logprob.exp()- (pi_ref_logprob - pi_logprob) - 1" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "2e9df2eb-f389-49fd-9012-c7ee200d3e18", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def grpo_advantage(rewards):\n", 43 | " epsilon = 0.00001\n", 44 | " A = (rewards - rewards.mean()) / (rewards.std() + epsilon)\n", 45 | " return A" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "259838e0-ca85-424d-8f2c-dcf166f8f9a7", 51 | "metadata": {}, 52 | "source": [ 53 | "## Why GRPO Loss is Negative" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "0e53d9b0-58f3-49fe-b62c-65ea85816511", 59 | "metadata": {}, 60 | "source": [ 61 | "1. 
we only calculate `min(ratio * Advantage) - kl`\n", 62 | "\n", 63 | "\\begin{equation}\n", 64 | "\\begin{aligned}\n", 65 | "\\mathcal{L}_{\\text{GRPO}}(\\theta) \n", 66 | "= & \\textcolor{red}{-} \\frac{1}{G} \\sum_{i=1}^G \\frac{1}{|o_i|} \\sum_{t=1}^{|o_i|} \\Biggl[ \n", 67 | " \\textcolor{blue}{\\min} \\Biggl( \n", 68 | " \\frac{\\pi_\\theta(o_{i,t} \\mid q, o_{i, 0\n", 135 | "pi_logprob = torch.tensor(0.5).log()\n", 136 | "pi_old_logprob = torch.tensor(0.5).log()\n", 137 | "pi_ref_logprob = torch.tensor(0.6).log()\n", 138 | "rewards_group = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0], dtype = torch.float32)\n", 139 | "\n", 140 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 141 | "loss.sum()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "id": "d7b7ca34-d87d-4880-b031-fef5e8b581b1", 147 | "metadata": {}, 148 | "source": [ 149 | "## loss negative -> positive" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 12, 155 | "id": "8ccc557c-34ba-4216-91bd-d7513062504e", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "[Rewards] : tensor([1., 0., 0., 0., 0., 0., 0., 0.])\n", 163 | "[Adv] : tensor([ 2.4748, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535])\n", 164 | "[Loss] : tensor([-49.4961, 7.0709, 7.0709, 7.0709, 7.0709, 7.0709, 7.0709,\n", 165 | " 7.0709])\n", 166 | "tensor(-1.9073e-06)\n", 167 | "[Rewards] : tensor([1., 1., 0., 0., 0., 0., 0., 0.])\n", 168 | "[Adv] : tensor([ 1.6202, 1.6202, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401])\n", 169 | "[Loss] : tensor([-32.4030, -32.4030, 10.8010, 10.8010, 10.8010, 10.8010, 10.8010,\n", 170 | " 10.8010])\n", 171 | "tensor(9.5367e-06)\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "# ratio * A > KL , loss < 0\n", 177 | "pi_logprob = torch.tensor(0.1).log()\n", 178 | "pi_old_logprob = torch.tensor(0.005).log()\n", 179 | "pi_ref_logprob = 
torch.tensor(0.1001).log()\n", 180 | "\n", 181 | "# one positive reward\n", 182 | "rewards_group = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0], dtype = torch.float32)\n", 183 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 184 | "print(loss.sum())\n", 185 | "\n", 186 | "# two positive reward\n", 187 | "rewards_group = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0], dtype = torch.float32)\n", 188 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 189 | "print(loss.sum())" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "96d5bb1b-3970-42df-848f-46d86d9eec26", 195 | "metadata": {}, 196 | "source": [ 197 | "# Loss Rising" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "559a9cdf-ac2d-4cfe-8ffd-44368969a468", 203 | "metadata": {}, 204 | "source": [ 205 | "## rewards change" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 13, 211 | "id": "0f0a6ac5-588e-4e66-b13d-1e593ee84fbe", 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "[Rewards] : tensor([1., 0., 0., 0., 0., 0., 0., 0.])\n", 219 | "[Adv] : tensor([ 2.4748, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535, -0.3535])\n", 220 | "[Loss] : tensor([-3.2997, 0.4714, 0.4714, 0.4714, 0.4714, 0.4714, 0.4714, 0.4714])\n", 221 | "result: 1.0 tensor(0.) 
\n", 222 | "\n", 223 | "[Rewards] : tensor([1., 1., 1., 1., 0., 0., 0., 0.])\n", 224 | "[Adv] : tensor([ 0.9354, 0.9354, 0.9354, 0.9354, -0.9354, -0.9354, -0.9354, -0.9354])\n", 225 | "[Loss] : tensor([-1.2472, -1.2472, -1.2472, -1.2472, 1.2472, 1.2472, 1.2472, 1.2472])\n", 226 | "result: 4.0 tensor(-2.3842e-07) \n", 227 | "\n", 228 | "[Rewards] : tensor([1., 1., 1., 1., 1., 1., 1., 0.])\n", 229 | "[Adv] : tensor([ 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, 0.3535, -2.4748])\n", 230 | "[Loss] : tensor([-0.4714, -0.4714, -0.4714, -0.4714, -0.4714, -0.4714, -0.4714, 3.2997])\n", 231 | "result: 7.0 tensor(2.3842e-07) \n", 232 | "\n", 233 | "[Rewards] : tensor([1., 1., 1., 1., 1., 1., 1., 1.])\n", 234 | "[Adv] : tensor([0., 0., 0., 0., 0., 0., 0., 0.])\n", 235 | "[Loss] : tensor([3.0994e-08, 3.0994e-08, 3.0994e-08, 3.0994e-08, 3.0994e-08, 3.0994e-08,\n", 236 | " 3.0994e-08, 3.0994e-08])\n", 237 | "result: 8.0 tensor(2.4796e-07) \n", 238 | "\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "pi_logprob = torch.tensor(0.4).log()\n", 244 | "pi_old_logprob = torch.tensor(0.3).log()\n", 245 | "pi_ref_logprob = torch.tensor(0.401).log()\n", 246 | "rewards_group = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0], dtype = torch.float32) \n", 247 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 248 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')\n", 249 | "\n", 250 | "rewards_group = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype = torch.float32) \n", 251 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 252 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')\n", 253 | "\n", 254 | "rewards_group = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0], dtype = torch.float32) \n", 255 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 256 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')\n", 257 | "\n", 258 | "rewards_group = 
torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype = torch.float32) \n", 259 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 260 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "id": "8b0888bb-59b1-462b-8cd6-ac755bd8853d", 266 | "metadata": {}, 267 | "source": [ 268 | "## policy change" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 37, 274 | "id": "e6702b26-71fb-4425-ad78-9d21e6e318a0", 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "[Rewards] : tensor([1., 1., 0., 0., 0., 0., 0., 0.])\n", 282 | "[Adv] : tensor([ 1.6202, 1.6202, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401])\n", 283 | "[Loss] : tensor([-1.6202, -1.6202, 0.5401, 0.5401, 0.5401, 0.5401, 0.5401, 0.5401])\n", 284 | "result: 2.0 tensor(1.1921e-07) \n", 285 | "\n", 286 | "[Rewards] : tensor([1., 1., 0., 0., 0., 0., 0., 0.])\n", 287 | "[Adv] : tensor([ 1.6202, 1.6202, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401])\n", 288 | "[Loss] : tensor([-1.6197, -1.6197, 0.5405, 0.5405, 0.5405, 0.5405, 0.5405, 0.5405])\n", 289 | "result: 2.0 tensor(0.0037) \n", 290 | "\n", 291 | "[Rewards] : tensor([1., 1., 0., 0., 0., 0., 0., 0.])\n", 292 | "[Adv] : tensor([ 1.6202, 1.6202, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401, -0.5401])\n", 293 | "[Loss] : tensor([-1.6039, -1.6039, 0.5563, 0.5563, 0.5563, 0.5563, 0.5563, 0.5563])\n", 294 | "result: 2.0 tensor(0.1297) \n", 295 | "\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "pi_logprob = torch.tensor(0.4).log()\n", 301 | "pi_old_logprob = torch.tensor(0.4).log()\n", 302 | "pi_ref_logprob = torch.tensor(0.401).log()\n", 303 | "rewards_group = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0], dtype = torch.float32) \n", 304 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 305 | "print('result:', 
rewards_group.sum().item(), loss.sum(), '\\n')\n", 306 | "\n", 307 | "\n", 308 | "pi_logprob = torch.tensor(0.3).log()\n", 309 | "pi_old_logprob = torch.tensor(0.3).log()\n", 310 | "pi_ref_logprob = torch.tensor(0.401).log()\n", 311 | "rewards_group = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0], dtype = torch.float32) \n", 312 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 313 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')\n", 314 | "\n", 315 | "pi_logprob = torch.tensor(0.1).log()\n", 316 | "pi_old_logprob = torch.tensor(0.1).log()\n", 317 | "pi_ref_logprob = torch.tensor(0.401).log()\n", 318 | "rewards_group = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0], dtype = torch.float32) \n", 319 | "loss = minimal_grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, rewards_group)\n", 320 | "print('result:', rewards_group.sum().item(), loss.sum(), '\\n')\n" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "id": "2cb418a3-c7eb-419a-85fc-1b17e1c38a79", 326 | "metadata": {}, 327 | "source": [ 328 | "## loss curve with reward rising" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 14, 334 | "id": "4dae4ad1-6221-4aec-beb0-26690c11f3e4", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "pi_logprob = torch.tensor(0.1).log()\n", 339 | "pi_old_logprob = torch.tensor(0.005).log()\n", 340 | "pi_ref_logprob = torch.tensor(0.101).log()\n", 341 | "\n", 342 | "nums = 128\n", 343 | "rewards_group = torch.zeros(nums)\n", 344 | "loss_list = []\n", 345 | "for i in range(nums):\n", 346 | " rewards_group[i] = 1.0\n", 347 | " loss = minimal_grpo_loss(pi_logprob, \n", 348 | " pi_old_logprob, \n", 349 | " pi_ref_logprob, \n", 350 | " rewards_group, \n", 351 | " is_debug=False)\n", 352 | " loss_list.append(loss.sum().item())" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 15, 358 | "id": "06727692-cd14-485f-8a7a-09d81844f69d", 359 | "metadata": {}, 360 | 
"outputs": [ 361 | { 362 | "data": { 363 | "image/png": "", 364 | "text/plain": [ 365 | "
" 366 | ] 367 | }, 368 | "metadata": {}, 369 | "output_type": "display_data" 370 | } 371 | ], 372 | "source": [ 373 | "import matplotlib.pyplot as plt\n", 374 | "plt.figure(figsize=(16, 6)) \n", 375 | "plt.plot(loss_list)\n", 376 | "plt.title('grpo loss with rewards sum')\n", 377 | "plt.xlabel('reward sum')\n", 378 | "plt.ylabel('loss')\n", 379 | "plt.grid()\n", 380 | "plt.show()" 381 | ] 382 | } 383 | ], 384 | "metadata": { 385 | "kernelspec": { 386 | "display_name": "Python 3 (ipykernel)", 387 | "language": "python", 388 | "name": "python3" 389 | }, 390 | "language_info": { 391 | "codemirror_mode": { 392 | "name": "ipython", 393 | "version": 3 394 | }, 395 | "file_extension": ".py", 396 | "mimetype": "text/x-python", 397 | "name": "python", 398 | "nbconvert_exporter": "python", 399 | "pygments_lexer": "ipython3", 400 | "version": "3.11.11" 401 | } 402 | }, 403 | "nbformat": 4, 404 | "nbformat_minor": 5 405 | } 406 | -------------------------------------------------------------------------------- /GRPO-Loss-Pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ba303ea0-7339-444f-a85a-4389873b5bf5", 6 | "metadata": {}, 7 | "source": [ 8 | "# GRPO-Loss-pytorch\n", 9 | "\n", 10 | "author: xiaodongguaAIGC\n", 11 | "\n", 12 | "git: [dhcode-cpp](https://github.com/dhcode-cpp)\n", 13 | "\n", 14 | "blog: [【手撕LLM-GRPO】你只管给Reward, 剩下的交给RL(附代码)](https://zhuanlan.zhihu.com/p/20812786520)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "37d15458-8a27-4baf-afc7-6b6187013547", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import torch\n", 25 | "import torch.nn.functional as F" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "b0531f35-3559-4c4f-9225-4fdc20a647e7", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "def grpo_kl(pi_logprob, pi_ref_logprob):\n", 36 | " 
return pi_ref_logprob.exp() / pi_logprob.exp()- (pi_ref_logprob - pi_logprob) - 1" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "id": "74242e19-4f3c-4a58-9413-cc3d6f1863db", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, advantage, input_len, len_oi):\n", 47 | " epsilon = 0.2\n", 48 | " beta = 0.01\n", 49 | "\n", 50 | " bs, seq_len = pi_logprob.shape\n", 51 | " # skip计算采样的每条采样长度\n", 52 | " len_oi = torch.tensor([len_oi] * bs, dtype = torch.long)\n", 53 | " # 设定mask, 仅对response 为 1, 算loss\n", 54 | " mask = torch.zeros(bs, seq_len)\n", 55 | " mask[:, input_len:] = 1\n", 56 | "\n", 57 | " # GRPO loss\n", 58 | " ratio = torch.exp(pi_logprob - pi_old_logprob)\n", 59 | " ratio_clip = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)\n", 60 | " advantage = advantage.unsqueeze(dim = 1) # [a, b ,c] -> [[a], [b], [c]]\n", 61 | " policy_gradient = torch.minimum(ratio * advantage , ratio_clip * advantage)\n", 62 | " kl = grpo_kl(pi_logprob, pi_ref_logprob)\n", 63 | "\n", 64 | " loss = (policy_gradient - beta * kl) * mask\n", 65 | " loss = (-1 / bs ) * (1/len_oi.unsqueeze(dim = 1)) * loss \n", 66 | " loss = loss.sum()\n", 67 | "\n", 68 | " return loss" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "id": "50164e56-71bf-4330-80e7-1e9606c47b18", 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/html": [ 80 | "
tensor(-0.4713)\n",
 81 |        "
\n" 82 | ], 83 | "text/plain": [ 84 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m-0.4713\u001b[0m\u001b[1m)\u001b[0m\n" 85 | ] 86 | }, 87 | "metadata": {}, 88 | "output_type": "display_data" 89 | } 90 | ], 91 | "source": [ 92 | "# 输出分布\n", 93 | "pi_logits = torch.randn(3, 5, 32) # batch, seq_len, vocab_size\n", 94 | "pi_ref_logits = torch.randn(3, 5, 32)\n", 95 | "pi_old_logits = torch.randn(3, 5, 32)\n", 96 | "\n", 97 | "# 获取log prob\n", 98 | "pi_logprob = F.log_softmax(pi_logits, dim = -1)\n", 99 | "pi_ref_logprob = F.log_softmax(pi_ref_logits, dim = -1)\n", 100 | "pi_old_logprob = F.log_softmax(pi_old_logits, dim = -1)\n", 101 | "\n", 102 | "# group data\n", 103 | "token_ids = torch.tensor([[11, 12, 13, 14, 15], # 输入为11,12,13, 输出为:14, 15\n", 104 | " [11, 12, 13, 15, 16],\n", 105 | " [11, 12, 13, 16, 17],])\n", 106 | "\n", 107 | "# 获取policy\n", 108 | "pi_logprob = torch.gather(pi_logprob, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)\n", 109 | "pi_ref_logprob = torch.gather(pi_ref_logprob, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)\n", 110 | "pi_old_logprob = torch.gather(pi_old_logprob, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)\n", 111 | "loss = grpo_loss(pi_logprob, pi_old_logprob, pi_ref_logprob, torch.tensor([-1, 2, 1]), 3, 2)\n", 112 | "print(loss)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "bf925213-e779-4155-b85d-fff71b4ec661", 118 | "metadata": {}, 119 | "source": [ 120 | "## TRL Implementation" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "3c4b0d89-6626-4e0e-99a2-57ce1f0f2544", 126 | "metadata": {}, 127 | "source": [ 128 | "- PPO clips the ratio\n", 129 | "- GRPO clips the ratio\n", 130 | "- TRL does \"not\" clip the ratio: it has no minibatch updates, so `exp(logprob - logprob.detach())` always equals `1`" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 8, 136 | "id": "60b46980-106c-4fb0-9ca7-5f0fe18156ef", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | 
"data": { 141 | "text/html": [ 142 | "
tensor([1.])\n",
143 |        "
\n" 144 | ], 145 | "text/plain": [ 146 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "output_type": "display_data" 151 | }, 152 | { 153 | "data": { 154 | "text/html": [ 155 | "
tensor([1.])\n",
156 |        "
\n" 157 | ], 158 | "text/plain": [ 159 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1\u001b[0m.\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 160 | ] 161 | }, 162 | "metadata": {}, 163 | "output_type": "display_data" 164 | } 165 | ], 166 | "source": [ 167 | "policy = torch.tensor([0.5])\n", 168 | "old_policy = torch.tensor([0.5])\n", 169 | "ratio = policy/old_policy\n", 170 | "print(ratio)\n", 171 | "\n", 172 | "ratio = torch.exp( policy.log() - old_policy.log())\n", 173 | "print(ratio)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 9, 179 | "id": "1109e320-5872-4ba4-af99-05dee46d4a93", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/html": [ 185 | "
tensor([0.4000])\n",
186 |        "
\n" 187 | ], 188 | "text/plain": [ 189 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.4000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 190 | ] 191 | }, 192 | "metadata": {}, 193 | "output_type": "display_data" 194 | } 195 | ], 196 | "source": [ 197 | "gradient = -0.2\n", 198 | "policy_gradient = - gradient * ( 1 / old_policy)\n", 199 | "print(policy_gradient)" 200 | ] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "Python 3 (ipykernel)", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.11.9" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 5 224 | } 225 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 dhcode95 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRPO-Loss-Pytorch 2 | 3 | 代码节选自课程《手撕LLM》关键实现代码: 4 | 5 | blog: [【手撕LLM-GRPO】你只管给Reward, 剩下的交给RL(附代码)](https://zhuanlan.zhihu.com/p/20812786520) 6 | 7 | blog: [GRPO的Loss为什么会有负值?](https://zhuanlan.zhihu.com/p/28326620566) 8 | --------------------------------------------------------------------------------