├── .gitignore ├── .streamlit └── config.toml ├── LICENSE ├── README.md ├── assets ├── overview.png ├── response_area.png ├── reward_area.png ├── token_area.png └── use_example.png ├── requirments.txt ├── rl_logging_board.py ├── rollout_samples ├── eval_poem_task │ └── data.jsonl └── train_poem_task │ └── data.jsonl └── start.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # IDE 24 | .idea/ 25 | .vscode/ 26 | *.swp 27 | *.swo 28 | 29 | # 环境 30 | venv/ 31 | env/ 32 | ENV/ 33 | .env 34 | .venv 35 | 36 | # 日志和缓存 37 | *.log 38 | .cache 39 | .pytest_cache/ 40 | .coverage 41 | htmlcov/ 42 | 43 | # 系统 44 | .DS_Store 45 | Thumbs.db 46 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | backgroundColor="#171719" 4 | secondaryBackgroundColor="#202025" 5 | primaryColor="#AAAAAD" 6 | font="serif" 7 | textColor="#ceced2" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📖 RL Logging Board 2 | 3 | RL Logging Board 是一个将 Reinforcement Learning from Human Feedback(RLHF)的`训练过程进行可视化的工具`,旨在: 4 | 5 | 1. 帮助人们直观`理解 RL 训练过程`,如:token 概率会随着训练升高/降低情况、response reward 分布随着训练的变化情况等。 6 | 2. 当训练不符合预期时,`通过监控 token 粒度的指标来定位可能的问题`,如:critic value 方差过大,某些异常 token 的 kl 过大等。 7 | 3. 更直观的看到每一个 step 高/低 reward 对应的 response 分布,用于`发现 reward hacking 的潜在 pattern`。 8 | 4. 
(Optional) Directly compare the `RL and SFT models side by side`, covering responses, rewards, and several other aspects. 9 | 10 | > **⚠️ Note:** RL Logging Board only visualizes metrics; it does not include a training framework! The tool is not intended as a replacement for TensorBoard or W&B — in practice we run TensorBoard alongside it. TensorBoard already handles simple scalar metrics well (e.g., reward_mean, response_length, loss), while this tool assists with finer-grained views such as token-level metrics. Our usage of the tool is integrated into [[this branch of OpenRLHF](https://github.com/HarderThenHarder/OpenRLHF/tree/support-rl-logging-board)]. 11 | 12 | 
13 |

14 | 15 |

RL Logging Board overview

16 |
17 | 18 | To use this tool, add the corresponding metric-saving code to whatever training framework you already use (e.g., [[OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)]), dump the metrics to local `.jsonl` files (an OpenRLHF example of collecting this data is given in [[a later section](#21-保存数据格式)]), and then point this tool at them for visualization. 19 | 20 | Every metric that needs to be saved is described in detail further below, and some [[sample files](rollout_samples/train_poem_task/data.jsonl)] are provided for reference. 21 | 22 | 23 | ## 1. What Can Be Visualized? (Visualization Modules) 24 | 25 | > PS: If you only want to run the platform, jump straight to [[the next section: how to run the tool](#2-如何运行该工具)]. 26 | 27 | This part describes all the visualization features the tool supports and shares the key parts I personally use most often. 28 | 29 | A very simple "rhyming task" (the model must continue a poem with a rhyming line) is used as the running example to keep things easy to follow. 30 | 31 | ### 1.1 Reward Area(curve & distribution) 32 | 33 | Reward is the core metric of RL training and the one monitored most frequently. We mainly look at: 34 | 35 | - **Training Curves:** The reference model's reward can also be saved in the log files and visualized (optional). This usually requires running the reference model (init policy) on the given prompt(s) and scoring the results, so it typically happens on a dev set (or by pre-computing the init policy's outputs on the training set). It makes metric comparisons between the RL Model and the Reference Model much more direct, and it is optional (`if no reference metrics are saved, only the RL model's metrics are shown`). 36 | - **Reward Distribution within each batch:** The score distribution gives a direct view of how well RL training is converging. Since PPO is a [[Reverse KL](https://agustinus.kristia.de/blog/forward-reverse-kl/)]-style optimization method, `the reward distribution late in training should normally be sharper than early in training (the right-hand histograms in the upper and lower halves of the figure below)`; if that is not what you observe, it is worth checking the training framework or the training settings for problems. 37 | - **Distribution of the reward gap to the Reference Model:** Normally, `the number of samples whose reward is lower than the init policy's should keep shrinking late in training (the left-hand histograms in the upper and lower halves of the figure below)`. By inspecting the samples that never manage to beat the init policy, we can analyze why they fail to improve (are they OOD for the training set, or OOD for the RM?). A minimal offline check of both points is sketched right after this list. 38 | 39 | 
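The following is a minimal, self-contained sketch of how both checks above could be run offline against the saved `.jsonl` logs. It is not part of the tool itself, and the glob path is only an assumption based on the sample data shipped under `rollout_samples/`.

```python
import glob
import json
from collections import defaultdict

import numpy as np

# Collect per-step rewards from every per-rank .jsonl file in one log folder.
rewards, ref_rewards = defaultdict(list), defaultdict(list)
for path in glob.glob("rollout_samples/train_poem_task/*.jsonl"):  # assumed path
    with open(path, encoding="utf8") as f:
        for line in f:
            record = json.loads(line)
            rewards[record["step"]].append(record["reward"])
            if "ref_reward" in record:
                ref_rewards[record["step"]].append(record["ref_reward"])

for step in sorted(rewards):
    r = np.array(rewards[step])
    # 1) the distribution should get sharper over training, i.e. the std should shrink;
    line = f"step {step:>4d}  reward_mean={r.mean():.3f}  reward_std={r.std():.3f}"
    # 2) the fraction of samples still below the init policy's reward should shrink too.
    if len(ref_rewards[step]) == len(r):
        worse = float(np.mean(r < np.array(ref_rewards[step])))
        line += f"  worse_than_ref={worse:.1%}"
    print(line)
```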
40 | 41 |
42 | 43 | ### 1.2 Response Area(Sort by Many Metrics) 44 | 45 | Instance-level monitoring matters just as much: beyond the mean reward of a step, we also want to see the concrete metrics of every single sample in that step (this helps locate the "bad apples" of the whole training run and filter the abnormal cases out for analysis). 46 | 47 | We mainly look at: 48 | 49 | - **Sort by reward (Reward Sort):** Sorting in ascending/descending order reveals the characteristics of the highest- (or lowest-) scoring samples inside each batch. `For high scores, the key question is whether these prompts carry hacking features` (other signals usually narrow down the step range where hacking starts); we then look for candidate hacking feature(s) in the high-reward samples of those step(s) and verify the hypotheses on the Reward Model side. `For low scores, we try to explain why these prompts never get optimized` (perhaps the init policy is simply not capable of this type of task, or the current reward model cannot give this type of task a higher score). 50 | - **Sort by log_ratio (KL Sort):** `The log_ratio (or KL) directly reflects how strongly the current model has been optimized`, so sorting by log_ratio shows, at the same training step, which types of prompts are over-optimized (be careful of hacking appearing too early) and which are barely optimized (hard to improve). Besides that, `since the KL is added into the returns together with the reward, monitoring the KL itself is also necessary` (e.g., an excessively large negative KL can pull the training objective off course); in that case we locate the samples with oversized KL and analyze why it happens, which also helps rule out problems in the training framework itself. A small sorting sketch follows this list. 51 | 52 | 
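Outside the UI, the same two sort orders can be reproduced with a few lines. `step_records` below is assumed to be the list of parsed `.jsonl` records belonging to one training step, and `rank_step_samples` is a hypothetical helper, not part of the tool.

```python
import numpy as np

def rank_step_samples(step_records, by="reward", descending=True):
    """Sort one step's rollout records by reward or by mean per-token log-ratio (sketch)."""
    def sort_key(record):
        if by == "reward":
            return record["reward"]
        # mean log-ratio ~ how far the current policy has drifted from the init policy on this sample
        return float(np.mean(np.array(record["logprobs"]) - np.array(record["ref_logprobs"])))
    return sorted(step_records, key=sort_key, reverse=descending)

# e.g. inspect the highest-reward samples for hacking patterns,
# and the most-drifted samples for runaway KL:
# top_reward   = rank_step_samples(step_records, by="reward")[:5]
# most_drifted = rank_step_samples(step_records, by="log_ratio")[:5]
```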
54 | 55 |
56 | 57 | ### 1.3 Token Area(kl, critic value, dense_reward, prob) 58 | 59 | Token-level monitoring is the finest granularity we can reach: it shows how every single token evolves over the whole RL training run. Concretely, we mainly look at: 60 | 61 | - **token reward:** This is the signal the policy model is actually trained on (= `kl_penalty + rewards(dense or sparse)`). Watching how strongly different tokens are "rewarded or penalized" at each step helps us tune the balance between the reward signals (e.g., kl_coef, or the ratio of the dense signal). 62 | - **token value:** This is `the critic model's evaluation (state value) of the current policy at each token (action)`. Comparing token value directly against token reward shows which tokens the value model fits easily (small MSE) and which it fits poorly (large MSE). 63 | - **prob & ref_prob:** These show, for every token in the current response, `the probability of being chosen by the current policy and by the init policy`. Compared with the log_ratio, probabilities make it more intuitive to understand what the model's current "policy" looks like. For low-reward cases we can check their probabilities under the init policy: if the init model picks these bad cases rather "confidently", there is a good chance similar abnormal samples exist in the SFT data and are worth tracing back. A sketch that derives these token-level quantities from one logged record follows below. 64 | 65 | 
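To make these token-level quantities concrete, here is a small sketch that derives them from a single logged record (one parsed line of the `.jsonl` file). The KL approximation mirrors what `rl_logging_board.py` computes internally; the helper name itself is illustrative.

```python
import numpy as np

def token_level_view(record):
    """Per-token metrics for one logged sample (sketch)."""
    logp = np.array(record["logprobs"])
    ref_logp = np.array(record["ref_logprobs"])
    log_ratio = logp - ref_logp                 # > 0 where the policy now favors the token more than the init policy
    kl = np.exp(log_ratio) - 1 - log_ratio      # non-negative per-token KL estimate, as computed by the tool
    n = len(logp)                               # logged responses may be clipped, so align lengths defensively
    return {
        "token": record["response_tokens"][:n],
        "token_reward": record["token_rewards"][:n],
        "value": record["values"][:n],
        "log_ratio": log_ratio.tolist(),
        "kl": kl.tolist(),
        "prob": np.exp(logp).tolist(),          # probability under the current policy
        "ref_prob": np.exp(ref_logp).tolist(),  # probability under the init policy
    }
```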
66 | 67 |
68 | 69 | 70 | ## 2. 如何运行该工具? 71 | 72 | ### 2.1 保存数据格式 73 | 74 | 在之前介绍中有提到,因为这个工具本身只做可视化,本身不包含训练,`因此需要我们在别的训练框架中自行保存工具所需要的 metric`,因为与训练框架解耦,理论可以支持任何训练框架的训练过程可视化。 75 | 76 | 工具加载需要 `.jsonl` 的文件([查看示例文件](rollout_samples/train_poem_task/data.jsonl)),每一行 json 需要包含以下 key(蓝链代码以 OpenRLHF 框架来为例): 77 | 78 | 79 | | 字段名 | 类型 | 描述 | 80 | |--------|------|------| 81 | | prompt | str | prompt 的字符串形式(可带 Pad Token,后续在工具里设置自动过滤即可),通常在 [generate](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L249) 中拿到 | 82 | | response | str | response 的字符串形式(可带 Pad Token,后续在工具里设置自动过滤即可),通常在 [generate](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L254) 中拿到 | 83 | | response_tokens | List[str] | response token 的 str list,用于逐 token 展示细粒度指标(kl,dense signal 等),可通过 tokenizer.convert_ids_to_tokens() 获得 | 84 | | logprobs | List[float] | response token 对应 policy model 的 log_probs,在 [forward 过程](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L286) 中获得 | 85 | | ref_logprobs | List[float] | response token 对应 reference model 的 log_probs,在 [forward 过程](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L289) 中获得 | 86 | | values | List[float] | critic model 对 response token 打出的 value 值,在 [critic forward](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L293) 阶段中获得 | 87 | | token_rewards | List[float] | kl_penalty + reward (dense 直接相加,sparse 加在尾 token),对应训练中所使用的真实 rewards,通过 [combined reward](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/models/utils.py#L74) 获得 | 88 | | reward | float | response 的整体 reward,通过 [rm 打分](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L304) 后获得 | 89 | | step | int | 当前数据的训练步数,在 [training loop](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_trainer.py#L250C17-L250C22) 中获得 | 90 | | ref_response | str, Optional | reference model 对该 prompt 的采样结果(可选) | 91 | | ref_reward | float, Optional | reference response 的整体 reward(可选) | 92 | 93 | 94 | ```json 95 | { 96 | "prompt": "请编写“上相生坤位,中兴运泰开。”的下一句,要求押韵。", 97 | "response": "威仪如鹤梦,道德似桃梅。", 98 | "response_tokens": ["威", "仪", "如", "鹤", "梦", ",", "道德", "似", "桃", "梅", "。", "", ""], 99 | "logprobs": [-4.847491264343262, -1.052163004875183, -3.1773641109466553, -3.155355215072632, -3.759133815765381, -0.0032821616623550653, -4.711000442504883, -0.7994625568389893, -4.159769535064697, -1.7499101161956787, -0.0008301864145323634, -2.3007127310847864e-05], 100 | "ref_logprobs": [-4.84375,-1.0546875,-3.171875,-3.15625,-3.765625,-0.0032806396484375,-4.71875,-0.80078125,-4.15625,-1.75,-0.00083160400390625,-2.3245811462402344e-05], 101 | "values": [-0.61328125,-0.01904296875,-0.373046875,-0.62890625,-0.3203125,-0.328125,-0.302734375,-0.353515625,-0.1474609375,-0.19140625,0.08642578125,-0.09765625], 102 | "token_rewards": [0.0007482529035769403,-0.0005048990133218467,0.001097822212614119,-0.00017895699420478195,-0.0012982368934899569,3.0440278919741104e-07,-0.0015499115688726306,-0.0002637386496644467,0.0007039070478640497,-1.7976761228055693e-05,-2.835178918303427e-07,-4.77368296003533e-08], 103 | 
"reward": 0.0, 104 | "step": 4, 105 | "(Optional)ref_response": "从 reference model 采样的结果(可选)", 106 | "(Optional)ref_reward": 0.0, 107 | }, 108 | ... 109 | ``` 110 | 111 | 将数据文件保存到 `./rollout_samples/` 目录下一个 `单独的文件夹` 即可。 112 | 113 | > **Note:** 工具会读取目标文件夹下的所有 .jsonl 文件,因此在训练中可按 DP 存储为独立的文件(这样就不用 gather 到主节点),工具会根据 json data 中的 `step` 进行自动整合。 114 | 115 | ### 2.2 启动可视化工具 116 | 117 | 1. 安装工具所需要的依赖包: 118 | 119 | ```sh 120 | pip install -r requirments.txt 121 | ``` 122 | 123 | 2. 运行启动脚本: 124 | 125 | ```sh 126 | bash start.sh 127 | ``` 128 | 129 | 其中,`start.sh` 里通过 `--server.port` 来指定 web 页面的端口: 130 | 131 | ```sh 132 | streamlit run rl_logging_board.py --server.port 8901 133 | ``` 134 | 135 |
136 | 137 |
138 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/overview.png -------------------------------------------------------------------------------- /assets/response_area.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/response_area.png -------------------------------------------------------------------------------- /assets/reward_area.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/reward_area.png -------------------------------------------------------------------------------- /assets/token_area.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/token_area.png -------------------------------------------------------------------------------- /assets/use_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/use_example.png -------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.40.0 2 | ujson==5.10.0 3 | plotly==5.18.0 -------------------------------------------------------------------------------- /rl_logging_board.py: -------------------------------------------------------------------------------- 1 | """ 2 | ==== No Bugs in code, just some Random Unexpected FEATURES ==== 3 | ┌─────────────────────────────────────────────────────────────┐ 4 | │┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│ 5 | ││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││ 6 | │├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│ 7 | ││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││ 8 | │├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│ 9 | ││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││ 10 | │├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│ 11 | ││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? 
/│Shift │Fn ││ 12 | │└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│ 13 | │ │Fn │ Alt │ Space │ Alt │Win│ HHKB │ 14 | │ └───┴─────┴───────────────────────┴─────┴───┘ │ 15 | └─────────────────────────────────────────────────────────────┘ 16 | 17 | 启动 web 页面可视化 RL 训练过程中间 metric 的数据状态。 18 | 19 | Author: pankeyu 20 | Date: 2023/10/31 21 | """ 22 | import os 23 | import copy 24 | import traceback 25 | 26 | try: 27 | import ujson as json 28 | except: 29 | import json 30 | print('`pip install ujson` can be faster.') 31 | 32 | import numpy as np 33 | import pandas as pd 34 | import streamlit as st 35 | 36 | import plotly.express as px 37 | import plotly.graph_objects as go 38 | import plotly.figure_factory as ff 39 | 40 | 41 | st.set_page_config( 42 | page_title="RL Logging Board", 43 | page_icon="📖", 44 | layout='wide' 45 | ) 46 | 47 | 48 | def load_log_file( 49 | logdir: os.PathLike, 50 | max_samples_each_step: int 51 | ): 52 | """ 53 | 解析本地log文件。 54 | 55 | Args: 56 | logdir (os.PathLike): _description_ 57 | max_samples_each_step (int): _description_ 58 | """ 59 | st.session_state['logging_name'] = logdir 60 | st.session_state['max_samples_each_step'] = max_samples_each_step 61 | st.session_state['logging_data'] = {} 62 | error_lines, success_lines = 0, 0 63 | 64 | all_logs = os.listdir(logdir) 65 | 66 | progress_text = f"Processing all files..." 67 | loading_files_bar = st.progress(0., text=progress_text) 68 | 69 | progress_text = f"Processing each file samples..." 70 | loading_samples_bar = st.progress(0., text=progress_text) 71 | 72 | 73 | for log_index in range(len(all_logs)): 74 | 75 | if not all_logs[log_index].endswith('.jsonl'): 76 | continue 77 | 78 | rl_log_file = os.path.join( 79 | logdir, 80 | all_logs[log_index] 81 | ) 82 | 83 | mock_max_lines_num = 10000 84 | 85 | with open(rl_log_file, 'r', encoding='utf8', errors='ignore') as f: 86 | for i, line in enumerate(f): 87 | try: 88 | data = json.loads(line) 89 | data['step'] = int(data['step']) 90 | if data['step'] not in st.session_state['logging_data']: 91 | st.session_state['logging_data'][data['step']] = { 92 | 'prompt': [], 93 | 'response': [], 94 | 'ref_response': [], 95 | 'reward': [], 96 | 'ref_reward': [], 97 | 'response_tokens': [], 98 | 'logprobs': [], 99 | 'ref_logprobs': [], 100 | 'probs': [], 101 | 'ref_probs': [], 102 | 'values': [], 103 | 'token_rewards': [], 104 | 'kl': [], 105 | 'avg_kl': [], 106 | 'sum_kl': [], 107 | 'log_ratio': [], 108 | 'avg_log_ratio': [], 109 | 'sum_log_ratio': [], 110 | 'valid_reward': [], 111 | 'ref_valid_reward': [], 112 | 'response_tokens_len': [], 113 | 'ground_truth': [] 114 | } 115 | elif len(st.session_state['logging_data'][data['step']]['prompt']) >= max_samples_each_step: 116 | percentage = (i + 1) / mock_max_lines_num 117 | percentage = min(percentage, 1.0) 118 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {i + 1} / {mock_max_lines_num} samples in each files...") 119 | 120 | for key in st.session_state['logging_data'][data['step']]: 121 | if key in data: 122 | st.session_state['logging_data'][data['step']][key].append(data[key]) 123 | 124 | if 'response_tokens' in data: 125 | st.session_state['logging_data'][data['step']]['response_tokens_len'].append(len(data['response_tokens'])) 126 | 127 | if 'logprobs' in data and 'ref_logprobs' in data: 128 | logp = np.array(data['logprobs']) 129 | ref_logp = np.array(data['ref_logprobs']) 130 | log_ratio = logp - ref_logp 131 | kl = np.exp(log_ratio) - 1 - log_ratio 132 | 
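# Note: exp(r) - 1 - r with r = logprob - ref_logprob is always >= 0 and behaves like r**2 / 2 for small r,
# which makes it a low-variance per-token estimate of the KL between the policy and the reference model.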
st.session_state['logging_data'][data['step']]['log_ratio'].append(log_ratio.tolist()) 133 | st.session_state['logging_data'][data['step']]['avg_log_ratio'].append(np.nanmean(log_ratio)) 134 | st.session_state['logging_data'][data['step']]['sum_log_ratio'].append(np.nansum(log_ratio)) 135 | st.session_state['logging_data'][data['step']]['kl'].append(kl.tolist()) 136 | st.session_state['logging_data'][data['step']]['avg_kl'].append(np.nanmean(kl)) 137 | st.session_state['logging_data'][data['step']]['sum_kl'].append(np.nansum(kl)) 138 | st.session_state['logging_data'][data['step']]['probs'].append(np.exp(logp).tolist()) 139 | st.session_state['logging_data'][data['step']]['ref_probs'].append(np.exp(ref_logp).tolist()) 140 | 141 | success_lines += 1 142 | 143 | except: 144 | print(traceback.format_exc()) 145 | error_lines += 1 146 | 147 | percentage = (i + 1) / mock_max_lines_num 148 | percentage = min(percentage, 1.0) 149 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {i + 1} / {mock_max_lines_num} samples...") 150 | 151 | percentage = 1.0 152 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {(success_lines + error_lines)} / {(success_lines + error_lines)} samples...") 153 | 154 | file_percentage = (log_index + 1) / len(all_logs) 155 | loading_files_bar.progress(file_percentage, text=f"[{int(file_percentage * 100)}%] Loading {log_index + 1} / {len(all_logs)} files...") 156 | 157 | st.toast( 158 | f'Loaded {success_lines + error_lines} sample(s), sucess: {success_lines}, error: {error_lines}.', 159 | icon='🎉' 160 | ) 161 | 162 | if not st.session_state['logging_data']: 163 | st.warning(f'No log file(s) found in {logdir}.', icon='⚠️') 164 | st.stop() 165 | 166 | all_steps = [int(s) for s in list(st.session_state["logging_data"].keys())] 167 | all_steps.sort() 168 | st.session_state['max_step_index'] = max(all_steps) 169 | st.session_state['min_step_index'] = min(all_steps) 170 | st.session_state['step_gap'] = 1 if len(all_steps) < 2 else all_steps[1] - all_steps[0] 171 | 172 | rewards_dict = {'step': [], 'reward': [], 'ref_reward': []} 173 | for step in st.session_state['logging_data']: 174 | st.session_state['logging_data'][step]['avg_reward'] = sum(st.session_state['logging_data'][step]['reward']) / len(st.session_state['logging_data'][step]['reward']) 175 | 176 | current_step_resp_length = [len(resp) for resp in st.session_state['logging_data'][step]['response']] 177 | st.session_state['logging_data'][step]['avg_length'] = int(sum(current_step_resp_length) / len(current_step_resp_length)) 178 | 179 | current_step_ref_resp_length = [len(resp) for resp in st.session_state['logging_data'][step]['ref_response']] 180 | st.session_state['logging_data'][step]['avg_ref_length'] = int(sum(current_step_ref_resp_length) / len(current_step_ref_resp_length)) if len(current_step_ref_resp_length) else 0 181 | 182 | if len(st.session_state['logging_data'][step]['ref_reward']): 183 | st.session_state['logging_data'][step]['avg_ref_reward'] = sum(st.session_state['logging_data'][step]['ref_reward']) / len(st.session_state['logging_data'][step]['ref_reward']) if len(st.session_state['logging_data'][step]['ref_reward']) else 0 184 | else: 185 | st.session_state['logging_data'][step]['avg_ref_reward'] = 0 186 | rewards_dict['step'].append(step) 187 | rewards_dict['reward'].append(st.session_state['logging_data'][step]['avg_reward']) 188 | rewards_dict['ref_reward'].append(st.session_state['logging_data'][step]['avg_ref_reward']) 
189 | 190 | rewards_df = pd.DataFrame.from_dict(rewards_dict) 191 | st.session_state['reward_df'] = rewards_df.set_index('step') 192 | 193 | 194 | def plot_filled_line( 195 | x: list, 196 | y_list_list: list, 197 | data_names: list, 198 | colors: list, 199 | title=None 200 | ): 201 | """ 202 | 绘制带有阴影的折线图,阴影上下界为当前x对应的y列表中的最大、最小值。 203 | 204 | Args: 205 | x (list): step 横轴索引 206 | y_list_list (line_num, steps, step_wise): 可绘制多条直线,维度为:绘制折线条数,总共的step数,每个step对应几个y值 207 | data_names (list): 每条折线的名字列表 208 | colors (list): 每条折线的颜色列表(rgb), e.g. -> ['255,171,171'] 209 | """ 210 | fig = go.Figure() 211 | 212 | x_rev = x[::-1] 213 | for i in range(len(y_list_list)): 214 | y_list = y_list_list[i] 215 | y_mean, y_lower, y_upper = [], [], [] 216 | for y in y_list: 217 | y_arr = np.array(y) 218 | mean, std = float(y_arr.mean()), float(y_arr.std()) 219 | y_mean.append(mean) 220 | y_lower.append(mean - std) 221 | y_upper.append(mean + std) 222 | # y_lower.append(min(y)) 223 | # y_upper.append(max(y)) 224 | y_lower = y_lower[::-1] 225 | 226 | fig.add_trace(go.Scatter( 227 | x=x + x_rev, 228 | y=y_upper + y_lower, 229 | fill='toself', 230 | fillcolor=f'rgba({colors[i]},0.1)', 231 | line_color='rgba(255,255,255,0)', 232 | showlegend=False, 233 | name=data_names[i], 234 | )) 235 | fig.add_trace(go.Scatter( 236 | x=x, y=y_mean, 237 | line_color=f'rgb({colors[i]})', 238 | name=data_names[i], 239 | )) 240 | 241 | fig.update_traces(mode='lines') 242 | 243 | if title: 244 | fig.update_layout( 245 | title=title, 246 | legend=dict(orientation="h") 247 | ) 248 | 249 | return fig 250 | 251 | 252 | def init_sidebar(): 253 | """ 254 | 侧边栏实例化。 255 | """ 256 | st.sidebar.markdown( 257 | "

📖 RL Logging Board

", 258 | unsafe_allow_html=True 259 | ) 260 | 261 | base_root_path = st.sidebar.text_input( 262 | "Log(s) Root Path", 263 | value='./rollout_samples', 264 | ) 265 | 266 | if not os.path.exists(base_root_path): 267 | st.warning(f'Log(s) Root Path: `{base_root_path}` is not exists.', icon='⚠️') 268 | st.stop() 269 | 270 | all_log_path_in_logdir = os.listdir(base_root_path) 271 | 272 | if not all_log_path_in_logdir: 273 | st.warning('No log files found.') 274 | st.code("""Logging Dir should be like: 275 | Base Log Dir 276 | |__eval_topk_0_topp_1 (dir for evaluate logs) 277 | | |__eval.jsonl 278 | |__topk_0_topp_1 (dir for training logs, only for rl logs) 279 | |__rollout_data_rank_0_1313.jsonl 280 | ... 281 | """) 282 | st.stop() 283 | 284 | log_name = st.sidebar.selectbox( 285 | 'Choose Log Name', 286 | options=all_log_path_in_logdir, 287 | index=len(all_log_path_in_logdir) - 1 288 | ) 289 | 290 | max_samples_each_step = st.sidebar.number_input( 291 | 'Max Samples Each Step', 292 | help='当step batch size 过大时可能会造成平台卡顿,可设置阈值来下采样每个step的数据。', 293 | value=128, 294 | max_value=10240, 295 | min_value=1 296 | ) 297 | 298 | load_btn = st.sidebar.button( 299 | "Load & View", 300 | use_container_width=True 301 | ) 302 | 303 | if load_btn and ( 304 | 'logging_data' not in st.session_state 305 | or 306 | log_name != st.session_state['logging_name'] 307 | or 308 | max_samples_each_step != st.session_state.get('max_samples_each_step', -1) 309 | ): 310 | load_log_file( 311 | os.path.join(base_root_path, log_name), 312 | max_samples_each_step 313 | ) 314 | 315 | with st.sidebar.expander('🧩 module setting', expanded=True): 316 | st.session_state['show_reward_logging'] = st.checkbox('Reward 曲线图', value=True) 317 | st.session_state['var_scaling'] = st.slider('Variance Scaling', min_value=0.1, max_value=1.0, value=0.2, help='Reward 曲线图阴影面积调整(对方差做 scaling)。') 318 | st.session_state['zero_shift'] = st.checkbox('Zero Shift', value=False, help='是否将所有reward曲线的第一项都平移到0(仅用于对比变化趋势)。') 319 | st.session_state['show_response'] = st.checkbox('Response 对比', value=True) 320 | 321 | with st.sidebar.expander('⚙️ show details setting', expanded=True): 322 | st.session_state['use_logp_as_kl'] = st.checkbox('Use LogP as KL', value=True, help='在 Reward 曲线图中用 LogProb 替代 KL 展示。') 323 | st.session_state['drop_pad'] = st.checkbox('Drop Padding Token', value=True) 324 | st.session_state['pad_token'] = st.text_input('Pad Token', value='', disabled=not st.session_state['drop_pad']) 325 | st.session_state['drop_sys_prompt'] = st.checkbox('Drop System Prompt', value=True) 326 | st.session_state['end_token_of_sys_prompt'] = st.text_input('End Token of System Prompt', value='', disabled=not st.session_state['drop_sys_prompt']) 327 | st.session_state['show_charts'] = st.checkbox('Show Charts', value=True) 328 | st.session_state['show_batch_samples'] = st.checkbox('Show Batch Samples', value=True) 329 | st.session_state['show_samples_pair'] = st.checkbox('Show Samples Pair', value=True) 330 | st.session_state['show_token_heat_map'] = st.checkbox('Show Heat Map', value=True) 331 | 332 | def plot_filled_line( 333 | x: list, 334 | y_list_list: list, 335 | data_names: list, 336 | colors: list, 337 | title=None, 338 | var_scaling=1. 339 | ): 340 | """ 341 | 绘制带有阴影的折线图,阴影上下界为当前x对应的y列表中的最大、最小值。 342 | 343 | Args: 344 | x (list): step 横轴索引 345 | y_list_list (line_num, steps, step_wise): 可绘制多条直线,维度为:绘制折线条数,总共的step数,每个step对应几个y值 346 | data_names (list): 每条折线的名字列表 347 | colors (list): 每条折线的颜色列表(rgb), e.g. 
-> ['255,171,171'] 348 | """ 349 | fig = go.Figure() 350 | 351 | x_rev = x[::-1] 352 | for i in range(len(y_list_list)): 353 | y_list = y_list_list[i] 354 | zero_shift_value = 0 355 | y_mean, y_lower, y_upper = [], [], [] 356 | 357 | for idx, y in enumerate(y_list): 358 | y_arr = np.array(y) 359 | if idx == 0 and st.session_state['zero_shift']: 360 | zero_shift_value = np.nanmean(y_arr) 361 | 362 | y_arr = y_arr - zero_shift_value 363 | mean, std = float(np.nanmean(y_arr)), float(np.nanstd(y_arr)) 364 | std *= var_scaling 365 | y_mean.append(mean) 366 | y_lower.append(mean - std) 367 | y_upper.append(mean + std) 368 | y_lower = y_lower[::-1] 369 | 370 | fig.add_trace(go.Scatter( 371 | x=x + x_rev, 372 | y=y_upper + y_lower, 373 | fill='toself', 374 | fillcolor=f'rgba({colors[i]},0.1)', 375 | line_color='rgba(255,255,255,0)', 376 | showlegend=False, 377 | name=data_names[i], 378 | )) 379 | fig.add_trace(go.Scatter( 380 | x=x, y=y_mean, 381 | line_color=f'rgb({colors[i]})', 382 | name=data_names[i], 383 | )) 384 | 385 | fig.update_traces(mode='lines') 386 | 387 | if title: 388 | fig.update_layout( 389 | title=title, 390 | legend=dict(orientation="h") 391 | ) 392 | 393 | return fig 394 | 395 | 396 | def main_page(): 397 | """ 398 | Metrics Page. 399 | """ 400 | if "logging_data" not in st.session_state: 401 | st.info("Please Press 「Load & View」Button to load log.") 402 | else: 403 | if st.session_state['show_reward_logging']: 404 | step_reward_tab, step_kl_tab, resp_len_tab = st.tabs([ 405 | 'Step-Reward', 406 | 'Step-KL', 407 | 'Step-RespLen' 408 | ]) 409 | 410 | with step_reward_tab: 411 | steps, reward, ref_reward, valid_reward, ref_valid_reward = [], [], [], [], [] 412 | for step, value_dict in st.session_state['logging_data'].items(): 413 | steps.append(step) 414 | reward.append(value_dict['reward']) 415 | 416 | if value_dict['ref_reward']: 417 | ref_reward.append(value_dict['ref_reward']) 418 | 419 | if value_dict['valid_reward']: 420 | valid_reward.append(value_dict['valid_reward']) 421 | 422 | if value_dict['ref_valid_reward']: 423 | ref_valid_reward.append(value_dict['ref_valid_reward']) 424 | 425 | all_curves = { 426 | 'ref_reward': { 427 | 'value': ref_reward, 428 | 'color': '132,201,255' 429 | }, 430 | 'reward': { 431 | 'value': reward, 432 | 'color': '255,171,171' 433 | }, 434 | 'ref_valid_reward': { 435 | 'value': ref_valid_reward, 436 | 'color': '132,155,200' 437 | }, 438 | 'valid_reward': { 439 | 'value': valid_reward, 440 | 'color': '200,155,200' 441 | } 442 | } 443 | 444 | candidate_curves = [key for key in all_curves if all_curves[key]['value']] 445 | 446 | show_curves = st.multiselect( 447 | 'Show Rewards', 448 | candidate_curves, 449 | candidate_curves, 450 | label_visibility='collapsed' 451 | ) 452 | 453 | reward_fig = plot_filled_line( 454 | x=steps, 455 | y_list_list=[all_curves[r]['value'] for r in show_curves], 456 | data_names=show_curves, 457 | colors=[all_curves[r]['color'] for r in show_curves], 458 | title='👾 Rewards Logging (Step level)', 459 | var_scaling=st.session_state['var_scaling'] 460 | ) 461 | 462 | st.plotly_chart(reward_fig, theme="streamlit", use_container_width=True) 463 | 464 | with step_kl_tab: 465 | steps, kl = [], [] 466 | 467 | if st.session_state['use_logp_as_kl']: 468 | for step, value_dict in st.session_state['logging_data'].items(): 469 | if all(value_dict['avg_log_ratio']): 470 | steps.append(step) 471 | kl.append(value_dict['avg_log_ratio']) 472 | else: 473 | for step, value_dict in st.session_state['logging_data'].items(): 474 | if 
all(value_dict['kl']): 475 | steps.append(step) 476 | kl.append(value_dict['avg_kl']) 477 | 478 | reward_fig = plot_filled_line( 479 | x=steps, 480 | y_list_list=[kl], 481 | data_names=['KL'], 482 | colors=['255,165,0'], 483 | title='👾 KL Logging (Step level)' 484 | ) 485 | st.plotly_chart(reward_fig, theme="streamlit", use_container_width=True) 486 | 487 | with resp_len_tab: 488 | steps, resp_len = [], [] 489 | 490 | for step, value_dict in st.session_state['logging_data'].items(): 491 | if value_dict['response_tokens_len']: 492 | steps.append(step) 493 | resp_len.append(value_dict['response_tokens_len']) 494 | 495 | resp_len_fig = plot_filled_line( 496 | x=steps, 497 | y_list_list=[resp_len], 498 | data_names=['resp_len'], 499 | colors=['255,165,0'], 500 | title='👾 Response Length Logging (Step level)' 501 | ) 502 | st.plotly_chart(resp_len_fig, theme="streamlit", use_container_width=True) 503 | 504 | if st.session_state['show_response']: 505 | st.markdown('⚡️ **Each Step Response**') 506 | 507 | if st.session_state['min_step_index'] == st.session_state['max_step_index']: 508 | step_index = st.session_state['min_step_index'] 509 | elif ( 510 | len(st.session_state['logging_data']) > 2 511 | and 512 | list(st.session_state['logging_data'].keys())[2] - list(st.session_state['logging_data'].keys())[1] != list(st.session_state['logging_data'].keys())[1] - list(st.session_state['logging_data'].keys())[0] 513 | ): 514 | step_index = st.selectbox( 515 | f"Step Index({st.session_state['max_step_index']} total steps):", 516 | list(st.session_state['logging_data'].keys()), 517 | index=0 518 | ) 519 | else: 520 | step_index = st.slider( 521 | f"Step Index({st.session_state['max_step_index']} total steps):", 522 | min_value=st.session_state['min_step_index'], 523 | max_value=st.session_state['max_step_index'], 524 | value=st.session_state['min_step_index'], 525 | step=st.session_state['step_gap'] 526 | ) 527 | 528 | cur_step_content_dict = st.session_state['logging_data'][step_index] 529 | cur_step_filtered_content_dict = copy.deepcopy(cur_step_content_dict) 530 | 531 | cur_step_filtered_content_dict['prompt'] = [] 532 | for prompt in cur_step_content_dict['prompt']: 533 | if st.session_state['drop_pad']: 534 | prompt = prompt.replace(st.session_state['pad_token'], '').strip() 535 | if st.session_state['drop_sys_prompt']: 536 | prompt = prompt.split(st.session_state['end_token_of_sys_prompt'])[-1] 537 | cur_step_filtered_content_dict['prompt'].append(prompt) 538 | 539 | cur_step_filtered_content_dict['response'] = [c.replace(st.session_state['pad_token'], '').strip() if st.session_state['drop_pad'] else c for c in cur_step_content_dict['response']] 540 | cur_step_filtered_content_dict['reward_gap'] = [r - ref_r for r, ref_r in zip(cur_step_content_dict['reward'], cur_step_content_dict['ref_reward'])] 541 | cur_step_filtered_content_dict['valid_reward_gap'] = [r - ref_r for r, ref_r in zip(cur_step_content_dict['reward'], cur_step_content_dict['valid_reward'])] 542 | 543 | if st.session_state['show_charts']: 544 | 545 | if not cur_step_filtered_content_dict['ref_reward']: 546 | cur_step_filtered_content_dict['ref_reward'] = [0] * len(cur_step_filtered_content_dict['reward']) 547 | 548 | c1, c2, c3 = st.columns([6, 6, 6]) 549 | 550 | with c1: # reward 分布 551 | reward_distribution_dict = { 552 | 'sample_index': [], 553 | 'reward': [], 554 | 'tag': [] 555 | } 556 | for sample_index, (reward, ref_reward) in enumerate(zip(cur_step_filtered_content_dict['reward'], 
cur_step_filtered_content_dict['ref_reward'])): 557 | reward_distribution_dict['sample_index'].append(sample_index) 558 | reward_distribution_dict['reward'].append(reward) 559 | reward_distribution_dict['tag'].append('Reward') 560 | reward_distribution_dict['sample_index'].append(sample_index) 561 | reward_distribution_dict['reward'].append(ref_reward) 562 | reward_distribution_dict['tag'].append('Ref Reward') 563 | 564 | reward_distribution_df = pd.DataFrame.from_dict(reward_distribution_dict) 565 | fig = px.bar( 566 | reward_distribution_df, 567 | x="sample_index", 568 | y="reward", 569 | color="tag", 570 | barmode='group', 571 | color_discrete_sequence=px.colors.diverging.Portland, 572 | title="Reward in current batch samples" 573 | ) 574 | st.plotly_chart(fig, theme="streamlit", use_container_width=True) 575 | 576 | with c2: # reward gap 分布 577 | reward_distribution_dict = { 578 | 'sample_index': [i for i in range(len(cur_step_filtered_content_dict['reward_gap']))], 579 | 'reward_gap': cur_step_filtered_content_dict['reward_gap'] 580 | } 581 | reward_distribution_df = pd.DataFrame.from_dict(reward_distribution_dict) 582 | fig = px.bar( 583 | reward_distribution_df, 584 | x="sample_index", 585 | y="reward_gap", 586 | color="reward_gap", 587 | color_discrete_sequence=['red'], 588 | title="Reward Gap (r - ref_r) in current batch" 589 | ) 590 | st.plotly_chart(fig, theme="streamlit", use_container_width=True) 591 | 592 | with c3: # reward 方差分布 593 | if cur_step_filtered_content_dict['ref_reward']: 594 | hist_data = [ 595 | cur_step_filtered_content_dict['ref_reward'], 596 | cur_step_filtered_content_dict['reward'], 597 | ] 598 | group_labels = ['Ref Rewards', 'Rewards'] 599 | else: 600 | hist_data = [cur_step_filtered_content_dict['reward']] 601 | group_labels = ['Rewards'] 602 | 603 | fig = ff.create_distplot( 604 | hist_data, 605 | group_labels, 606 | bin_size=[.02, .02], 607 | curve_type='normal' 608 | ) 609 | fig.update_layout(title="Reward Distribution in current batch") 610 | st.plotly_chart(fig, use_container_width=True) 611 | 612 | showed_keys = [ 613 | 'prompt', 614 | 'response', 615 | 'reward', 616 | 'ground_truth', 617 | 'valid_reward', 618 | 'avg_log_ratio', 619 | 'sum_log_ratio', 620 | 'avg_kl', 621 | 'sum_kl', 622 | 'ref_response', 623 | 'ref_reward', 624 | 'ref_valid_reward', 625 | 'reward_gap', 626 | 'valid_reward_gap' 627 | ] 628 | candidate_keys = [k for k in showed_keys if cur_step_filtered_content_dict[k]] 629 | content_dict = dict([(k, cur_step_filtered_content_dict[k]) for k in candidate_keys]) 630 | content_df = pd.DataFrame.from_dict(content_dict) 631 | 632 | if st.session_state['show_batch_samples']: 633 | st.dataframe( 634 | content_df, 635 | use_container_width=True, 636 | height=350 637 | ) 638 | 639 | if st.session_state['show_samples_pair']: 640 | 641 | c1, c2, c3 = st.columns([1, 1, 4]) 642 | with c1: 643 | if step_index == st.session_state['min_step_index']: 644 | delta_char = 0 645 | else: 646 | try: 647 | cur_avg_len = st.session_state['logging_data'][step_index]['avg_length'] 648 | last_avg_len = st.session_state['logging_data'][step_index-st.session_state['step_gap']]['avg_length'] 649 | delta_char = cur_avg_len - last_avg_len 650 | except: 651 | delta_char = 0 652 | st.metric( # actor在当前step下的平均回复长度,delta为与上一个step的比较 653 | 'Response Average Length', 654 | value=f"{st.session_state['logging_data'][step_index]['avg_length']} 字", 655 | delta=f'{delta_char} 字' 656 | ) 657 | 658 | with c2: # ref_model在当前step下的平均回复长度,delta为与上一个step的比较 659 | try: 660 | 
delta_char = 0 if step_index == st.session_state['min_step_index'] else st.session_state['logging_data'][step_index]['avg_ref_length'] - st.session_state['logging_data'][step_index-st.session_state['step_gap']]['avg_ref_length'] 661 | except: 662 | delta_char = 0 663 | st.metric( 664 | 'Ref Response Average Length', 665 | value=f"{st.session_state['logging_data'][step_index]['avg_ref_length']} 字", 666 | delta=f'{delta_char} 字' 667 | ) 668 | 669 | with c3: 670 | sample_index = st.number_input( 671 | f'Sample index in current step batch: ', 672 | min_value=0, 673 | max_value=len(cur_step_filtered_content_dict['response']) - 1, 674 | value=0 675 | ) 676 | 677 | # 单样本展示 response - ref_response 的回复 678 | c1, c2, c3, c4 = st.columns([4, 4, 4, 2]) 679 | with c1: 680 | st.markdown('Prompt', unsafe_allow_html=True) 681 | content = cur_step_filtered_content_dict["prompt"][sample_index].replace('\n', ' \n').replace('~', '~') 682 | st.markdown( 683 | f'{content}', 684 | unsafe_allow_html=True 685 | ) 686 | with c2: 687 | st.markdown(':green[Response]') 688 | content = cur_step_filtered_content_dict["response"][sample_index].replace('\n', ' \n').replace('~', '~') 689 | st.markdown( 690 | f'{content}', 691 | unsafe_allow_html=True 692 | ) 693 | with c3: 694 | st.markdown(':blue[Ref Response]') 695 | if ( 696 | "ref_response" in cur_step_filtered_content_dict 697 | and 698 | cur_step_filtered_content_dict["ref_response"] 699 | ): 700 | content = cur_step_filtered_content_dict["ref_response"][sample_index].replace('\n', ' \n').replace('~', '~') 701 | st.markdown( 702 | f'{content}', 703 | unsafe_allow_html=True 704 | ) 705 | else: 706 | st.info('No `ref_response` found in log line data.') 707 | with c4: 708 | st.markdown(':orange[Reward Gap]') 709 | reward_gap = round(cur_step_filtered_content_dict["reward_gap"][sample_index], 4) if cur_step_filtered_content_dict["reward_gap"] else 0. 
710 | st.metric( 711 | ' ', 712 | value=reward_gap 713 | ) 714 | 715 | # 展示更详细的 token-level 的信息 716 | if 'token_rewards' in cur_step_filtered_content_dict and cur_step_filtered_content_dict['token_rewards']: 717 | # 检查 resp_tokens 的长度和 logprobs 的长度是否对齐 718 | resp_token_len = len(cur_step_filtered_content_dict['response_tokens'][sample_index]) 719 | logp_len = len(cur_step_filtered_content_dict['logprobs'][sample_index]) 720 | if resp_token_len != logp_len: 721 | st.info( 722 | f'Note: `resp_tokens` (len: {resp_token_len}) is not equal to `logprobs` (len: {logp_len}), this may caused by tokens, CLIP response tokens!', 723 | icon='⚠️' 724 | ) 725 | cur_step_filtered_content_dict['response_tokens'][sample_index] = cur_step_filtered_content_dict['response_tokens'][sample_index][:logp_len] 726 | 727 | show_values = st.multiselect( 728 | 'Select show value(s)', 729 | ['token_reward', 'log_ratio', 'kl', 'token_value', 'logp', 'ref_logp', 'prob', 'ref_prob'], 730 | ['token_reward', 'log_ratio', 'kl', 'token_value', 'logp', 'ref_logp', 'prob', 'ref_prob'] 731 | ) 732 | 733 | new_dict, index_list = {}, [] 734 | 735 | if st.session_state['drop_pad'] and cur_step_filtered_content_dict['response_tokens'][sample_index][-1] == st.session_state['pad_token']: 736 | first_pad_token_idx = cur_step_filtered_content_dict['response_tokens'][sample_index].index(st.session_state['pad_token']) 737 | response_tokens_without_pad_token = cur_step_filtered_content_dict['response_tokens'][sample_index][:first_pad_token_idx] 738 | else: 739 | response_tokens_without_pad_token = cur_step_filtered_content_dict['response_tokens'][sample_index] 740 | 741 | for token_idx in range(len(response_tokens_without_pad_token)): 742 | if cur_step_filtered_content_dict['response_tokens']: 743 | resp_token = cur_step_filtered_content_dict['response_tokens'][sample_index][token_idx] 744 | resp_token = f'{token_idx} - {resp_token}' 745 | if resp_token not in new_dict: 746 | new_dict[resp_token] = [] 747 | 748 | if cur_step_filtered_content_dict['token_rewards']: 749 | token_reward = cur_step_filtered_content_dict['token_rewards'][sample_index][token_idx] 750 | if 'token_reward' in show_values: 751 | new_dict[resp_token].append(token_reward) 752 | if 'token_reward' not in index_list: 753 | index_list.append('token_reward') 754 | 755 | if cur_step_filtered_content_dict['log_ratio']: 756 | log_ratio = cur_step_filtered_content_dict['log_ratio'][sample_index][token_idx] 757 | if 'log_ratio' in show_values: 758 | new_dict[resp_token].append(log_ratio) 759 | if 'log_ratio' not in index_list: 760 | index_list.append('log_ratio') 761 | 762 | if cur_step_filtered_content_dict['kl']: 763 | kl = cur_step_filtered_content_dict['kl'][sample_index][token_idx] 764 | if 'kl' in show_values: 765 | new_dict[resp_token].append(kl) 766 | if 'kl' not in index_list: 767 | index_list.append('kl') 768 | 769 | if cur_step_filtered_content_dict['values']: 770 | value = cur_step_filtered_content_dict['values'][sample_index][token_idx] 771 | if 'token_value' in show_values: 772 | new_dict[resp_token].append(value) 773 | if 'token_value' not in index_list: 774 | index_list.append('token_value') 775 | 776 | if cur_step_filtered_content_dict['logprobs']: 777 | logp = cur_step_filtered_content_dict['logprobs'][sample_index][token_idx] 778 | if 'logp' in show_values: 779 | new_dict[resp_token].append(logp) 780 | if 'logp' not in index_list: 781 | index_list.append('logp') 782 | 783 | if cur_step_filtered_content_dict['ref_logprobs']: 784 | ref_logp = 
cur_step_filtered_content_dict['ref_logprobs'][sample_index][token_idx] 785 | if 'ref_logp' in show_values: 786 | new_dict[resp_token].append(ref_logp) 787 | if 'ref_logp' not in index_list: 788 | index_list.append('ref_logp') 789 | 790 | if cur_step_filtered_content_dict['probs']: 791 | prob = cur_step_filtered_content_dict['probs'][sample_index][token_idx] 792 | if 'prob' in show_values: 793 | new_dict[resp_token].append(prob) 794 | if 'prob' not in index_list: 795 | index_list.append('prob') 796 | 797 | if cur_step_filtered_content_dict['ref_probs']: 798 | ref_prob = cur_step_filtered_content_dict['ref_probs'][sample_index][token_idx] 799 | if 'ref_prob' in show_values: 800 | new_dict[resp_token].append(ref_prob) 801 | if 'ref_prob' not in index_list: 802 | index_list.append('ref_prob') 803 | 804 | try: 805 | token_level_df = pd.DataFrame.from_dict(new_dict) 806 | renamed_index_dict = dict((i, name) for i, name in enumerate(index_list)) 807 | token_level_df.rename( 808 | index=renamed_index_dict, 809 | inplace=True 810 | ) 811 | 812 | st.dataframe( 813 | token_level_df.style.background_gradient(axis=1, cmap="binary"), 814 | use_container_width=True 815 | ) 816 | 817 | if st.session_state['show_token_heat_map']: 818 | fig = px.imshow( 819 | token_level_df, 820 | text_auto=True, 821 | aspect="auto", 822 | color_continuous_scale="balance", 823 | ) 824 | fig.update_xaxes(side="top") 825 | st.plotly_chart(fig, theme="streamlit", use_container_width=True) 826 | except Exception as e: 827 | st.error(f'Error occured: {e}.') 828 | st.write(new_dict) 829 | 830 | 831 | if __name__ == '__main__': 832 | init_sidebar() 833 | main_page() 834 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | streamlit run rl_logging_board.py --server.port 8901 --------------------------------------------------------------------------------