├── .gitignore
├── .streamlit
│   └── config.toml
├── LICENSE
├── README.md
├── assets
│   ├── overview.png
│   ├── response_area.png
│   ├── reward_area.png
│   ├── token_area.png
│   └── use_example.png
├── requirments.txt
├── rl_logging_board.py
├── rollout_samples
│   ├── eval_poem_task
│   │   └── data.jsonl
│   └── train_poem_task
│       └── data.jsonl
└── start.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # IDE
24 | .idea/
25 | .vscode/
26 | *.swp
27 | *.swo
28 |
29 | # Environments
30 | venv/
31 | env/
32 | ENV/
33 | .env
34 | .venv
35 |
36 | # Logs and cache
37 | *.log
38 | .cache
39 | .pytest_cache/
40 | .coverage
41 | htmlcov/
42 |
43 | # System
44 | .DS_Store
45 | Thumbs.db
46 |
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | base="dark"
3 | backgroundColor="#171719"
4 | secondaryBackgroundColor="#202025"
5 | primaryColor="#AAAAAD"
6 | font="serif"
7 | textColor="#ceced2"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📖 RL Logging Board
2 |
3 | RL Logging Board is a tool for `visualizing the training process` of Reinforcement Learning from Human Feedback (RLHF). It aims to:
4 |
5 | 1. Help people intuitively `understand the RL training process`, e.g. how token probabilities rise or fall over training, how the response reward distribution shifts across steps, etc.
6 | 2. When training does not behave as expected, `locate possible problems by monitoring token-level metrics`, e.g. an overly large critic value variance, abnormally large KL on certain tokens, etc.
7 | 3. See directly, for every step, which responses sit behind high/low rewards, which helps `discover potential reward hacking patterns`.
8 | 4. (Optional) Compare the `RL and SFT models` side by side, on several aspects including responses and rewards.
9 |
10 | > **⚠️ Note:** RL Logging Board only visualizes metrics and does not include a training framework! It is not meant to replace TensorBoard or W&B; in practice we use TensorBoard together with this tool. Simple numeric metrics (e.g. reward_mean, response_length, loss) are already handled well by TensorBoard, while this tool assists with finer-grained views such as token-level metrics. We integrated it into [[this version of OpenRLHF](https://github.com/HarderThenHarder/OpenRLHF/tree/support-rl-logging-board)].
11 |
12 |
13 |
14 | ![RL Logging Board overview](assets/overview.png)
15 | *RL Logging Board overview*
16 |
17 |
18 | To use the tool, add metric-saving code to whatever training framework you use (e.g. [[OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)]), writing the metrics to local `.jsonl` files (the [[section below](#21-data-format-to-save)] gives OpenRLHF examples of where to grab each value), and then visualize them with this tool.
19 |
20 | All metrics that need to be saved are described in detail below, and some [[sample files](rollout_samples/train_poem_task/data.jsonl)] are provided for reference.
21 |
22 |
23 | ## 1. Visualization Modules
24 |
25 | > PS: If you only want to run the platform, jump straight to [[the next chapter: How to Run the Tool](#2-how-to-run-the-tool)].
26 |
27 | This section describes all of the visualization features the tool supports and highlights the parts I personally rely on most.
28 |
29 | For the examples we use a very simple "rhyming task" (the model is trained so that the verse it continues must rhyme), which keeps things easy to follow.
30 |
31 | ### 1.1 Reward Area (curve & distribution)
32 |
33 | Reward is the core metric of RL training and the one watched most frequently. We mainly look at:
34 |
35 | - **Training Curves:** The reference model's reward can optionally be stored in the saved files and then visualized. This usually requires running the reference model (init policy) on the given prompt(s) and scoring the results, so it typically happens on a dev set (or by pre-computing the init policy's results on the training set). It makes comparing metrics between the RL model and the reference model much easier and is optional (`if no reference metrics are saved, only the RL model's metrics are shown`).
36 | - **Reward Distribution (per batch):** The score distribution gives a direct view of how well RL training is converging. Since PPO is a [[Reverse KL](https://agustinus.kristia.de/blog/forward-reverse-kl/)] style optimization, the `reward distribution late in training should normally be sharper than early in training (see the right-hand histograms in the upper and lower halves of the figure below)`. If that is not the case, it is worth checking the training framework or training settings for problems.
37 | - **Reward gap distribution vs. the Reference Model:** Normally, `as training progresses, fewer and fewer samples should score below the init policy (see the left-hand histograms in the upper and lower halves of the figure below)`. We can inspect the samples that never manage to beat the init policy and analyze why they cannot improve (OOD for the training set, or OOD for the RM). A minimal sketch of computing this per-step win rate is shown right after this list.
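
Below is a minimal sketch (not part of the tool) of how such a per-step win rate against the reference model could be computed from the logged `data.jsonl`, assuming the `step` / `reward` / `ref_reward` fields described in section 2.1:

```python
# Minimal sketch, not part of the tool: per-step fraction of samples whose
# reward beats the reference model, read from the data.jsonl format of section 2.1.
import json
from collections import defaultdict

def win_rate_per_step(path: str) -> dict:
    wins, totals = defaultdict(int), defaultdict(int)
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            d = json.loads(line)
            if "ref_reward" not in d:          # ref metrics are optional
                continue
            step = int(d["step"])
            totals[step] += 1
            wins[step] += int(d["reward"] > d["ref_reward"])
    return {s: wins[s] / totals[s] for s in sorted(totals)}

# e.g. win_rate_per_step("rollout_samples/train_poem_task/data.jsonl")
```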
38 |
39 |
40 | ![Reward Area](assets/reward_area.png)
41 |
42 |
43 | ### 1.2 Response Area (Sort by Many Metrics)
44 |
45 | Instance-level monitoring matters just as much: besides the mean reward of a step, we also want the concrete metrics of every sample in that step (this helps us pick out the "bad apples" of the training run and pull the abnormal cases out for analysis).
46 |
47 | We mainly look at:
48 |
49 | - **Reward Sort:** Sorting ascending/descending lets us inspect the characteristics of the highest (or lowest) scoring samples within each batch. `For high scores, we should check whether these prompts carry hacking features` (other signals usually narrow down the step range where hacking happens); we then look for possible hacking feature(s) in the high-score samples of those step(s) and verify the hypotheses on the Reward Model side. `For low scores, we should try to explain why these prompts never get optimized` (perhaps the init policy is simply not capable enough for this type of task, or the current reward model cannot give it a higher score, etc.).
50 | - **KL Sort (by log_ratio):** `log_ratio (or KL) directly reflects how far the current model has been optimized away from the init policy`, so sorting by it shows which kinds of prompts are being optimized too aggressively at the same training step (beware of hacking appearing too early) and which are barely improving. Moreover, `since KL is added to the returns together with the reward, monitoring KL is also necessary` (e.g. an overly large negative KL can pull the training objective off course); in that case we need to find the samples with excessive KL and analyze why, which also helps rule out problems in the training framework itself. A sketch of how these quantities are derived from the logged logprobs follows this list.
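
For reference, the following is a simplified sketch of what the board itself derives from the logged `logprobs` / `ref_logprobs` of one sample when it loads a log line (see `load_log_file` in `rl_logging_board.py`):

```python
# Simplified sketch of the per-sample quantities rl_logging_board.py derives
# from the logged token logprobs (policy) and ref_logprobs (init policy).
import numpy as np

logprobs     = np.array([-1.20, -0.05, -2.30])   # policy log p(token), example values
ref_logprobs = np.array([-1.00, -0.10, -2.00])   # init policy log p(token), example values

log_ratio = logprobs - ref_logprobs              # per-token log(p / p_ref)
kl        = np.exp(log_ratio) - 1 - log_ratio    # k3 KL estimator, always >= 0
probs, ref_probs = np.exp(logprobs), np.exp(ref_logprobs)

avg_log_ratio = np.nanmean(log_ratio)            # used for the "KL Sort" view
sum_kl        = np.nansum(kl)
```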
51 |
52 |
53 |
54 | ![Response Area](assets/response_area.png)
55 |
56 |
57 | ### 1.3 Token Area (kl, critic value, dense_reward, prob)
58 |
59 | Token-level monitoring is the finest-grained monitoring we can do; it reflects what happens on every single token throughout RL training. Specifically, we mainly look at:
60 |
61 | - **token reward:** This is the actual signal used to train the policy model (= `kl_penalty + rewards (dense or sparse)`). Watching how strongly different tokens are "rewarded or penalized" at each step helps us tune the balance between the reward signals (e.g. kl_coef, or the weight of the dense signal). A minimal composition sketch follows this list.
62 | - **token value:** This is `the evaluation (state value) the critic model gives the current policy model at each token (action)`. Comparing token value with token reward directly shows which tokens the value model fits easily (small MSE) and which it fits poorly (large MSE).
63 | - **prob & ref_prob:** These show, `for every token of the current response, the probability of being chosen by the current policy and by the init policy`. Compared with log_ratio, probabilities make it more intuitive to understand what the current "policy" looks like. For low-score cases we can check their probabilities under the init policy; if the init model picks these bad cases quite "confidently", there is a good chance the SFT data contains similar abnormal examples worth tracing back.
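
As a concrete illustration of the first bullet, here is a minimal sketch of composing `token_rewards` from a KL penalty plus a sparse reward on the tail token; `kl_coef` and the exact penalty form are assumptions for illustration, not taken from any particular framework:

```python
# Minimal sketch, for illustration only: token_rewards = kl_penalty + reward,
# with a sparse scalar reward added on the last token. `kl_coef` is an assumed knob.
import numpy as np

def make_token_rewards(logprobs, ref_logprobs, reward, kl_coef=0.01):
    logprobs, ref_logprobs = np.asarray(logprobs), np.asarray(ref_logprobs)
    kl_penalty = -kl_coef * (logprobs - ref_logprobs)   # penalize drifting away from the init policy
    token_rewards = kl_penalty.copy()
    token_rewards[-1] += reward                          # sparse reward on the tail token
    return token_rewards.tolist()
```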
64 |
65 |
66 | ![Token Area](assets/token_area.png)
67 |
68 |
69 |
70 | ## 2. How to Run the Tool
71 |
72 | ### 2.1 Data Format to Save
73 |
74 | As mentioned above, the tool only does visualization and contains no training, `so the metrics it needs have to be saved from your own training framework`. Because it is decoupled from the training framework, it can in principle visualize the training process of any framework.
75 |
76 | The tool loads `.jsonl` files ([see the sample file](rollout_samples/train_poem_task/data.jsonl)); every json line must contain the following keys (the linked code uses the OpenRLHF framework as an example):
77 |
78 |
79 | | Field | Type | Description |
80 | |--------|------|------|
81 | | prompt | str | The prompt as a string (may contain pad tokens; you can filter them out later in the tool), usually obtained in [generate](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L249) |
82 | | response | str | The response as a string (may contain pad tokens; you can filter them out later in the tool), usually obtained in [generate](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L254) |
83 | | response_tokens | List[str] | The response tokens as a list of strings, used to show fine-grained per-token metrics (kl, dense signal, etc.); can be obtained via tokenizer.convert_ids_to_tokens() |
84 | | logprobs | List[float] | The policy model's log probs for the response tokens, obtained in the [forward pass](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L286) |
85 | | ref_logprobs | List[float] | The reference model's log probs for the response tokens, obtained in the [forward pass](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L289) |
86 | | values | List[float] | The values the critic model assigns to the response tokens, obtained in the [critic forward](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L293) pass |
87 | | token_rewards | List[float] | kl_penalty + reward (dense rewards added directly, sparse reward added to the tail token); these are the actual rewards used in training, obtained via [combined reward](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/models/utils.py#L74) |
88 | | reward | float | The overall reward of the response, obtained after [RM scoring](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_utils/experience_maker.py#L304) |
89 | | step | int | The training step of this sample, obtained in the [training loop](https://github.com/OpenRLHF/OpenRLHF/blob/fb170bc161dc23244eeefe6f9ec04306a61cfd24/openrlhf/trainer/ppo_trainer.py#L250C17-L250C22) |
90 | | ref_response | str, Optional | The reference model's sampled response for this prompt (optional) |
91 | | ref_reward | float, Optional | The overall reward of the reference response (optional) |
92 |
93 |
94 | ```json
95 | {
96 | "prompt": "请编写“上相生坤位,中兴运泰开。”的下一句,要求押韵。",
97 | "response": "威仪如鹤梦,道德似桃梅。",
98 | "response_tokens": ["威", "仪", "如", "鹤", "梦", ",", "道德", "似", "桃", "梅", "。", "", ""],
99 | "logprobs": [-4.847491264343262, -1.052163004875183, -3.1773641109466553, -3.155355215072632, -3.759133815765381, -0.0032821616623550653, -4.711000442504883, -0.7994625568389893, -4.159769535064697, -1.7499101161956787, -0.0008301864145323634, -2.3007127310847864e-05],
100 | "ref_logprobs": [-4.84375,-1.0546875,-3.171875,-3.15625,-3.765625,-0.0032806396484375,-4.71875,-0.80078125,-4.15625,-1.75,-0.00083160400390625,-2.3245811462402344e-05],
101 | "values": [-0.61328125,-0.01904296875,-0.373046875,-0.62890625,-0.3203125,-0.328125,-0.302734375,-0.353515625,-0.1474609375,-0.19140625,0.08642578125,-0.09765625],
102 | "token_rewards": [0.0007482529035769403,-0.0005048990133218467,0.001097822212614119,-0.00017895699420478195,-0.0012982368934899569,3.0440278919741104e-07,-0.0015499115688726306,-0.0002637386496644467,0.0007039070478640497,-1.7976761228055693e-05,-2.835178918303427e-07,-4.77368296003533e-08],
103 | "reward": 0.0,
104 | "step": 4,
105 | "(Optional)ref_response": "从 reference model 采样的结果(可选)",
106 | "(Optional)ref_reward": 0.0,
107 | },
108 | ...
109 | ```
110 |
111 | Save the data file(s) into a `separate sub-folder` under the `./rollout_samples/` directory.
112 |
113 | > **Note:** The tool reads every .jsonl file in the target folder, so during training each DP rank can write to its own file (no need to gather to the main node); the tool merges them automatically by the `step` field in the json data. A minimal per-rank logging sketch is shown below.
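
A minimal, framework-agnostic sketch of such per-rank logging (the file name pattern and the `rank` / `task_name` arguments are assumptions; the fields of each sample follow the table above):

```python
# Minimal sketch, framework-agnostic: append one json line per sample per step
# to a per-rank file under ./rollout_samples/<task_name>/. The writer itself,
# `rank` and `task_name` are assumptions; the fields follow section 2.1.
import json
import os

def log_rollout_samples(samples, step, rank=0, task_name="train_poem_task"):
    out_dir = os.path.join("rollout_samples", task_name)
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"data_rank_{rank}.jsonl")
    with open(path, "a", encoding="utf-8") as f:
        for sample in samples:                 # each `sample` is a dict of the fields above
            record = dict(sample, step=int(step))
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```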
114 |
115 | ### 2.2 Launch the Visualization Tool
116 |
117 | 1. Install the required dependencies:
118 |
119 | ```sh
120 | pip install -r requirments.txt
121 | ```
122 |
123 | 2. Run the start script:
124 |
125 | ```sh
126 | bash start.sh
127 | ```
128 |
129 | In `start.sh`, the web page port is specified via `--server.port`:
130 |
131 | ```sh
132 | streamlit run rl_logging_board.py --server.port 8901
133 | ```
134 |
135 |
136 | ![Use Example](assets/use_example.png)
137 |
138 |
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/overview.png
--------------------------------------------------------------------------------
/assets/response_area.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/response_area.png
--------------------------------------------------------------------------------
/assets/reward_area.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/reward_area.png
--------------------------------------------------------------------------------
/assets/token_area.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/token_area.png
--------------------------------------------------------------------------------
/assets/use_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HarderThenHarder/RLLoggingBoard/cf56d5d97b293eddc75e6474e6607c9ccafc7fd5/assets/use_example.png
--------------------------------------------------------------------------------
/requirments.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.40.0
2 | ujson==5.10.0
3 | plotly==5.18.0
--------------------------------------------------------------------------------
/rl_logging_board.py:
--------------------------------------------------------------------------------
1 | """
2 | ==== No Bugs in code, just some Random Unexpected FEATURES ====
3 | ┌─────────────────────────────────────────────────────────────┐
4 | │┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐│
5 | ││Esc│!1 │@2 │#3 │$4 │%5 │^6 │&7 │*8 │(9 │)0 │_- │+= │|\ │`~ ││
6 | │├───┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴───┤│
7 | ││ Tab │ Q │ W │ E │ R │ T │ Y │ U │ I │ O │ P │{[ │}] │ BS ││
8 | │├─────┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴┬──┴─────┤│
9 | ││ Ctrl │ A │ S │ D │ F │ G │ H │ J │ K │ L │: ;│" '│ Enter ││
10 | │├──────┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴─┬─┴────┬───┤│
11 | ││ Shift │ Z │ X │ C │ V │ B │ N │ M │< ,│> .│? /│Shift │Fn ││
12 | │└─────┬──┴┬──┴──┬┴───┴───┴───┴───┴───┴──┬┴───┴┬──┴┬─────┴───┘│
13 | │ │Fn │ Alt │ Space │ Alt │Win│ HHKB │
14 | │ └───┴─────┴───────────────────────┴─────┴───┘ │
15 | └─────────────────────────────────────────────────────────────┘
16 |
17 | Launch a web page to visualize the intermediate metrics logged during RL training.
18 |
19 | Author: pankeyu
20 | Date: 2023/10/31
21 | """
22 | import os
23 | import copy
24 | import traceback
25 |
26 | try:
27 | import ujson as json
28 | except:
29 | import json
30 | print('`pip install ujson` can be faster.')
31 |
32 | import numpy as np
33 | import pandas as pd
34 | import streamlit as st
35 |
36 | import plotly.express as px
37 | import plotly.graph_objects as go
38 | import plotly.figure_factory as ff
39 |
40 |
41 | st.set_page_config(
42 | page_title="RL Logging Board",
43 | page_icon="📖",
44 | layout='wide'
45 | )
46 |
47 |
48 | def load_log_file(
49 | logdir: os.PathLike,
50 | max_samples_each_step: int
51 | ):
52 | """
53 |     Parse the local log file(s) under `logdir` into st.session_state['logging_data'].
54 | 
55 |     Args:
56 |         logdir (os.PathLike): directory that contains the rollout .jsonl file(s).
57 |         max_samples_each_step (int): maximum number of samples kept for each step.
58 | """
59 | st.session_state['logging_name'] = logdir
60 | st.session_state['max_samples_each_step'] = max_samples_each_step
61 | st.session_state['logging_data'] = {}
62 | error_lines, success_lines = 0, 0
63 |
64 | all_logs = os.listdir(logdir)
65 |
66 | progress_text = f"Processing all files..."
67 | loading_files_bar = st.progress(0., text=progress_text)
68 |
69 | progress_text = f"Processing each file samples..."
70 | loading_samples_bar = st.progress(0., text=progress_text)
71 |
72 |
73 | for log_index in range(len(all_logs)):
74 |
75 | if not all_logs[log_index].endswith('.jsonl'):
76 | continue
77 |
78 | rl_log_file = os.path.join(
79 | logdir,
80 | all_logs[log_index]
81 | )
82 |
83 | mock_max_lines_num = 10000
84 |
85 | with open(rl_log_file, 'r', encoding='utf8', errors='ignore') as f:
86 | for i, line in enumerate(f):
87 | try:
88 | data = json.loads(line)
89 | data['step'] = int(data['step'])
90 | if data['step'] not in st.session_state['logging_data']:
91 | st.session_state['logging_data'][data['step']] = {
92 | 'prompt': [],
93 | 'response': [],
94 | 'ref_response': [],
95 | 'reward': [],
96 | 'ref_reward': [],
97 | 'response_tokens': [],
98 | 'logprobs': [],
99 | 'ref_logprobs': [],
100 | 'probs': [],
101 | 'ref_probs': [],
102 | 'values': [],
103 | 'token_rewards': [],
104 | 'kl': [],
105 | 'avg_kl': [],
106 | 'sum_kl': [],
107 | 'log_ratio': [],
108 | 'avg_log_ratio': [],
109 | 'sum_log_ratio': [],
110 | 'valid_reward': [],
111 | 'ref_valid_reward': [],
112 | 'response_tokens_len': [],
113 | 'ground_truth': []
114 | }
115 | elif len(st.session_state['logging_data'][data['step']]['prompt']) >= max_samples_each_step:
116 | percentage = (i + 1) / mock_max_lines_num
117 | percentage = min(percentage, 1.0)
118 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {i + 1} / {mock_max_lines_num} samples in each files...")
119 |                         continue  # this step already holds max_samples_each_step samples, skip the rest
120 | for key in st.session_state['logging_data'][data['step']]:
121 | if key in data:
122 | st.session_state['logging_data'][data['step']][key].append(data[key])
123 |
124 | if 'response_tokens' in data:
125 | st.session_state['logging_data'][data['step']]['response_tokens_len'].append(len(data['response_tokens']))
126 |
127 | if 'logprobs' in data and 'ref_logprobs' in data:
128 | logp = np.array(data['logprobs'])
129 | ref_logp = np.array(data['ref_logprobs'])
130 | log_ratio = logp - ref_logp
131 | kl = np.exp(log_ratio) - 1 - log_ratio
132 | st.session_state['logging_data'][data['step']]['log_ratio'].append(log_ratio.tolist())
133 | st.session_state['logging_data'][data['step']]['avg_log_ratio'].append(np.nanmean(log_ratio))
134 | st.session_state['logging_data'][data['step']]['sum_log_ratio'].append(np.nansum(log_ratio))
135 | st.session_state['logging_data'][data['step']]['kl'].append(kl.tolist())
136 | st.session_state['logging_data'][data['step']]['avg_kl'].append(np.nanmean(kl))
137 | st.session_state['logging_data'][data['step']]['sum_kl'].append(np.nansum(kl))
138 | st.session_state['logging_data'][data['step']]['probs'].append(np.exp(logp).tolist())
139 | st.session_state['logging_data'][data['step']]['ref_probs'].append(np.exp(ref_logp).tolist())
140 |
141 | success_lines += 1
142 |
143 | except:
144 | print(traceback.format_exc())
145 | error_lines += 1
146 |
147 | percentage = (i + 1) / mock_max_lines_num
148 | percentage = min(percentage, 1.0)
149 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {i + 1} / {mock_max_lines_num} samples...")
150 |
151 | percentage = 1.0
152 | loading_samples_bar.progress(percentage, text=f"[{int(percentage * 100)}%] Processing {(success_lines + error_lines)} / {(success_lines + error_lines)} samples...")
153 |
154 | file_percentage = (log_index + 1) / len(all_logs)
155 | loading_files_bar.progress(file_percentage, text=f"[{int(file_percentage * 100)}%] Loading {log_index + 1} / {len(all_logs)} files...")
156 |
157 | st.toast(
158 |         f'Loaded {success_lines + error_lines} sample(s), success: {success_lines}, error: {error_lines}.',
159 | icon='🎉'
160 | )
161 |
162 | if not st.session_state['logging_data']:
163 | st.warning(f'No log file(s) found in {logdir}.', icon='⚠️')
164 | st.stop()
165 |
166 | all_steps = [int(s) for s in list(st.session_state["logging_data"].keys())]
167 | all_steps.sort()
168 | st.session_state['max_step_index'] = max(all_steps)
169 | st.session_state['min_step_index'] = min(all_steps)
170 | st.session_state['step_gap'] = 1 if len(all_steps) < 2 else all_steps[1] - all_steps[0]
171 |
172 | rewards_dict = {'step': [], 'reward': [], 'ref_reward': []}
173 | for step in st.session_state['logging_data']:
174 | st.session_state['logging_data'][step]['avg_reward'] = sum(st.session_state['logging_data'][step]['reward']) / len(st.session_state['logging_data'][step]['reward'])
175 |
176 | current_step_resp_length = [len(resp) for resp in st.session_state['logging_data'][step]['response']]
177 | st.session_state['logging_data'][step]['avg_length'] = int(sum(current_step_resp_length) / len(current_step_resp_length))
178 |
179 | current_step_ref_resp_length = [len(resp) for resp in st.session_state['logging_data'][step]['ref_response']]
180 | st.session_state['logging_data'][step]['avg_ref_length'] = int(sum(current_step_ref_resp_length) / len(current_step_ref_resp_length)) if len(current_step_ref_resp_length) else 0
181 |
182 | if len(st.session_state['logging_data'][step]['ref_reward']):
183 | st.session_state['logging_data'][step]['avg_ref_reward'] = sum(st.session_state['logging_data'][step]['ref_reward']) / len(st.session_state['logging_data'][step]['ref_reward']) if len(st.session_state['logging_data'][step]['ref_reward']) else 0
184 | else:
185 | st.session_state['logging_data'][step]['avg_ref_reward'] = 0
186 | rewards_dict['step'].append(step)
187 | rewards_dict['reward'].append(st.session_state['logging_data'][step]['avg_reward'])
188 | rewards_dict['ref_reward'].append(st.session_state['logging_data'][step]['avg_ref_reward'])
189 |
190 | rewards_df = pd.DataFrame.from_dict(rewards_dict)
191 | st.session_state['reward_df'] = rewards_df.set_index('step')
192 |
193 |
194 | def plot_filled_line(
195 | x: list,
196 | y_list_list: list,
197 | data_names: list,
198 | colors: list,
199 | title=None
200 | ):
201 | """
202 |     Plot line(s) with a shaded band; the band spans mean ± std of the y values at each x.
203 | 
204 |     Args:
205 |         x (list): step indices for the x axis
206 |         y_list_list (line_num, steps, step_wise): multiple lines can be drawn; dims are (number of lines, number of steps, y values per step)
207 |         data_names (list): name of each line
208 |         colors (list): color of each line (rgb), e.g. -> ['255,171,171']
209 | """
210 | fig = go.Figure()
211 |
212 | x_rev = x[::-1]
213 | for i in range(len(y_list_list)):
214 | y_list = y_list_list[i]
215 | y_mean, y_lower, y_upper = [], [], []
216 | for y in y_list:
217 | y_arr = np.array(y)
218 | mean, std = float(y_arr.mean()), float(y_arr.std())
219 | y_mean.append(mean)
220 | y_lower.append(mean - std)
221 | y_upper.append(mean + std)
222 | # y_lower.append(min(y))
223 | # y_upper.append(max(y))
224 | y_lower = y_lower[::-1]
225 |
226 | fig.add_trace(go.Scatter(
227 | x=x + x_rev,
228 | y=y_upper + y_lower,
229 | fill='toself',
230 | fillcolor=f'rgba({colors[i]},0.1)',
231 | line_color='rgba(255,255,255,0)',
232 | showlegend=False,
233 | name=data_names[i],
234 | ))
235 | fig.add_trace(go.Scatter(
236 | x=x, y=y_mean,
237 | line_color=f'rgb({colors[i]})',
238 | name=data_names[i],
239 | ))
240 |
241 | fig.update_traces(mode='lines')
242 |
243 | if title:
244 | fig.update_layout(
245 | title=title,
246 | legend=dict(orientation="h")
247 | )
248 |
249 | return fig
250 |
251 |
252 | def init_sidebar():
253 | """
254 |     Build the sidebar.
255 | """
256 | st.sidebar.markdown(
257 | "📖 RL Logging Board
",
258 | unsafe_allow_html=True
259 | )
260 |
261 | base_root_path = st.sidebar.text_input(
262 | "Log(s) Root Path",
263 | value='./rollout_samples',
264 | )
265 |
266 | if not os.path.exists(base_root_path):
267 |         st.warning(f'Log(s) Root Path: `{base_root_path}` does not exist.', icon='⚠️')
268 | st.stop()
269 |
270 | all_log_path_in_logdir = os.listdir(base_root_path)
271 |
272 | if not all_log_path_in_logdir:
273 | st.warning('No log files found.')
274 | st.code("""Logging Dir should be like:
275 | Base Log Dir
276 | |__eval_topk_0_topp_1 (dir for evaluate logs)
277 | | |__eval.jsonl
278 | |__topk_0_topp_1 (dir for training logs, only for rl logs)
279 | |__rollout_data_rank_0_1313.jsonl
280 | ...
281 | """)
282 | st.stop()
283 |
284 | log_name = st.sidebar.selectbox(
285 | 'Choose Log Name',
286 | options=all_log_path_in_logdir,
287 | index=len(all_log_path_in_logdir) - 1
288 | )
289 |
290 | max_samples_each_step = st.sidebar.number_input(
291 | 'Max Samples Each Step',
292 |         help='A large per-step batch size may make the page laggy; set a threshold to downsample the data of each step.',
293 | value=128,
294 | max_value=10240,
295 | min_value=1
296 | )
297 |
298 | load_btn = st.sidebar.button(
299 | "Load & View",
300 | use_container_width=True
301 | )
302 |
303 | if load_btn and (
304 | 'logging_data' not in st.session_state
305 | or
306 | log_name != st.session_state['logging_name']
307 | or
308 | max_samples_each_step != st.session_state.get('max_samples_each_step', -1)
309 | ):
310 | load_log_file(
311 | os.path.join(base_root_path, log_name),
312 | max_samples_each_step
313 | )
314 |
315 | with st.sidebar.expander('🧩 module setting', expanded=True):
316 |         st.session_state['show_reward_logging'] = st.checkbox('Reward Curves', value=True)
317 |         st.session_state['var_scaling'] = st.slider('Variance Scaling', min_value=0.1, max_value=1.0, value=0.2, help='Scale the shaded band of the reward curves (scaling applied to the std).')
318 |         st.session_state['zero_shift'] = st.checkbox('Zero Shift', value=False, help='Shift the first point of every reward curve to 0 (only for comparing trends).')
319 |         st.session_state['show_response'] = st.checkbox('Response Comparison', value=True)
320 |
321 | with st.sidebar.expander('⚙️ show details setting', expanded=True):
322 |         st.session_state['use_logp_as_kl'] = st.checkbox('Use LogP as KL', value=True, help='Show the log-prob ratio instead of KL in the KL curve.')
323 | st.session_state['drop_pad'] = st.checkbox('Drop Padding Token', value=True)
324 | st.session_state['pad_token'] = st.text_input('Pad Token', value='', disabled=not st.session_state['drop_pad'])
325 | st.session_state['drop_sys_prompt'] = st.checkbox('Drop System Prompt', value=True)
326 | st.session_state['end_token_of_sys_prompt'] = st.text_input('End Token of System Prompt', value='', disabled=not st.session_state['drop_sys_prompt'])
327 | st.session_state['show_charts'] = st.checkbox('Show Charts', value=True)
328 | st.session_state['show_batch_samples'] = st.checkbox('Show Batch Samples', value=True)
329 | st.session_state['show_samples_pair'] = st.checkbox('Show Samples Pair', value=True)
330 | st.session_state['show_token_heat_map'] = st.checkbox('Show Heat Map', value=True)
331 |
332 | def plot_filled_line(
333 | x: list,
334 | y_list_list: list,
335 | data_names: list,
336 | colors: list,
337 | title=None,
338 | var_scaling=1.
339 | ):
340 | """
341 |     Plot line(s) with a shaded band; the band spans mean ± std (scaled by var_scaling) of the y values at each x.
342 | 
343 |     Args:
344 |         x (list): step indices for the x axis
345 |         y_list_list (line_num, steps, step_wise): multiple lines can be drawn; dims are (number of lines, number of steps, y values per step)
346 |         data_names (list): name of each line
347 |         colors (list): color of each line (rgb), e.g. -> ['255,171,171']
348 | """
349 | fig = go.Figure()
350 |
351 | x_rev = x[::-1]
352 | for i in range(len(y_list_list)):
353 | y_list = y_list_list[i]
354 | zero_shift_value = 0
355 | y_mean, y_lower, y_upper = [], [], []
356 |
357 | for idx, y in enumerate(y_list):
358 | y_arr = np.array(y)
359 | if idx == 0 and st.session_state['zero_shift']:
360 | zero_shift_value = np.nanmean(y_arr)
361 |
362 | y_arr = y_arr - zero_shift_value
363 | mean, std = float(np.nanmean(y_arr)), float(np.nanstd(y_arr))
364 | std *= var_scaling
365 | y_mean.append(mean)
366 | y_lower.append(mean - std)
367 | y_upper.append(mean + std)
368 | y_lower = y_lower[::-1]
369 |
370 | fig.add_trace(go.Scatter(
371 | x=x + x_rev,
372 | y=y_upper + y_lower,
373 | fill='toself',
374 | fillcolor=f'rgba({colors[i]},0.1)',
375 | line_color='rgba(255,255,255,0)',
376 | showlegend=False,
377 | name=data_names[i],
378 | ))
379 | fig.add_trace(go.Scatter(
380 | x=x, y=y_mean,
381 | line_color=f'rgb({colors[i]})',
382 | name=data_names[i],
383 | ))
384 |
385 | fig.update_traces(mode='lines')
386 |
387 | if title:
388 | fig.update_layout(
389 | title=title,
390 | legend=dict(orientation="h")
391 | )
392 |
393 | return fig
394 |
395 |
396 | def main_page():
397 | """
398 | Metrics Page.
399 | """
400 | if "logging_data" not in st.session_state:
401 | st.info("Please Press 「Load & View」Button to load log.")
402 | else:
403 | if st.session_state['show_reward_logging']:
404 | step_reward_tab, step_kl_tab, resp_len_tab = st.tabs([
405 | 'Step-Reward',
406 | 'Step-KL',
407 | 'Step-RespLen'
408 | ])
409 |
410 | with step_reward_tab:
411 | steps, reward, ref_reward, valid_reward, ref_valid_reward = [], [], [], [], []
412 | for step, value_dict in st.session_state['logging_data'].items():
413 | steps.append(step)
414 | reward.append(value_dict['reward'])
415 |
416 | if value_dict['ref_reward']:
417 | ref_reward.append(value_dict['ref_reward'])
418 |
419 | if value_dict['valid_reward']:
420 | valid_reward.append(value_dict['valid_reward'])
421 |
422 | if value_dict['ref_valid_reward']:
423 | ref_valid_reward.append(value_dict['ref_valid_reward'])
424 |
425 | all_curves = {
426 | 'ref_reward': {
427 | 'value': ref_reward,
428 | 'color': '132,201,255'
429 | },
430 | 'reward': {
431 | 'value': reward,
432 | 'color': '255,171,171'
433 | },
434 | 'ref_valid_reward': {
435 | 'value': ref_valid_reward,
436 | 'color': '132,155,200'
437 | },
438 | 'valid_reward': {
439 | 'value': valid_reward,
440 | 'color': '200,155,200'
441 | }
442 | }
443 |
444 | candidate_curves = [key for key in all_curves if all_curves[key]['value']]
445 |
446 | show_curves = st.multiselect(
447 | 'Show Rewards',
448 | candidate_curves,
449 | candidate_curves,
450 | label_visibility='collapsed'
451 | )
452 |
453 | reward_fig = plot_filled_line(
454 | x=steps,
455 | y_list_list=[all_curves[r]['value'] for r in show_curves],
456 | data_names=show_curves,
457 | colors=[all_curves[r]['color'] for r in show_curves],
458 | title='👾 Rewards Logging (Step level)',
459 | var_scaling=st.session_state['var_scaling']
460 | )
461 |
462 | st.plotly_chart(reward_fig, theme="streamlit", use_container_width=True)
463 |
464 | with step_kl_tab:
465 | steps, kl = [], []
466 |
467 | if st.session_state['use_logp_as_kl']:
468 | for step, value_dict in st.session_state['logging_data'].items():
469 | if all(value_dict['avg_log_ratio']):
470 | steps.append(step)
471 | kl.append(value_dict['avg_log_ratio'])
472 | else:
473 | for step, value_dict in st.session_state['logging_data'].items():
474 | if all(value_dict['kl']):
475 | steps.append(step)
476 | kl.append(value_dict['avg_kl'])
477 |
478 | reward_fig = plot_filled_line(
479 | x=steps,
480 | y_list_list=[kl],
481 | data_names=['KL'],
482 | colors=['255,165,0'],
483 | title='👾 KL Logging (Step level)'
484 | )
485 | st.plotly_chart(reward_fig, theme="streamlit", use_container_width=True)
486 |
487 | with resp_len_tab:
488 | steps, resp_len = [], []
489 |
490 | for step, value_dict in st.session_state['logging_data'].items():
491 | if value_dict['response_tokens_len']:
492 | steps.append(step)
493 | resp_len.append(value_dict['response_tokens_len'])
494 |
495 | resp_len_fig = plot_filled_line(
496 | x=steps,
497 | y_list_list=[resp_len],
498 | data_names=['resp_len'],
499 | colors=['255,165,0'],
500 | title='👾 Response Length Logging (Step level)'
501 | )
502 | st.plotly_chart(resp_len_fig, theme="streamlit", use_container_width=True)
503 |
504 | if st.session_state['show_response']:
505 | st.markdown('⚡️ **Each Step Response**')
506 |
507 | if st.session_state['min_step_index'] == st.session_state['max_step_index']:
508 | step_index = st.session_state['min_step_index']
509 | elif (
510 | len(st.session_state['logging_data']) > 2
511 | and
512 | list(st.session_state['logging_data'].keys())[2] - list(st.session_state['logging_data'].keys())[1] != list(st.session_state['logging_data'].keys())[1] - list(st.session_state['logging_data'].keys())[0]
513 | ):
514 | step_index = st.selectbox(
515 | f"Step Index({st.session_state['max_step_index']} total steps):",
516 | list(st.session_state['logging_data'].keys()),
517 | index=0
518 | )
519 | else:
520 | step_index = st.slider(
521 | f"Step Index({st.session_state['max_step_index']} total steps):",
522 | min_value=st.session_state['min_step_index'],
523 | max_value=st.session_state['max_step_index'],
524 | value=st.session_state['min_step_index'],
525 | step=st.session_state['step_gap']
526 | )
527 |
528 | cur_step_content_dict = st.session_state['logging_data'][step_index]
529 | cur_step_filtered_content_dict = copy.deepcopy(cur_step_content_dict)
530 |
531 | cur_step_filtered_content_dict['prompt'] = []
532 | for prompt in cur_step_content_dict['prompt']:
533 | if st.session_state['drop_pad']:
534 | prompt = prompt.replace(st.session_state['pad_token'], '').strip()
535 | if st.session_state['drop_sys_prompt']:
536 | prompt = prompt.split(st.session_state['end_token_of_sys_prompt'])[-1]
537 | cur_step_filtered_content_dict['prompt'].append(prompt)
538 |
539 | cur_step_filtered_content_dict['response'] = [c.replace(st.session_state['pad_token'], '').strip() if st.session_state['drop_pad'] else c for c in cur_step_content_dict['response']]
540 | cur_step_filtered_content_dict['reward_gap'] = [r - ref_r for r, ref_r in zip(cur_step_content_dict['reward'], cur_step_content_dict['ref_reward'])]
541 | cur_step_filtered_content_dict['valid_reward_gap'] = [r - ref_r for r, ref_r in zip(cur_step_content_dict['reward'], cur_step_content_dict['valid_reward'])]
542 |
543 | if st.session_state['show_charts']:
544 |
545 | if not cur_step_filtered_content_dict['ref_reward']:
546 | cur_step_filtered_content_dict['ref_reward'] = [0] * len(cur_step_filtered_content_dict['reward'])
547 |
548 | c1, c2, c3 = st.columns([6, 6, 6])
549 |
550 |                 with c1:  # reward distribution
551 | reward_distribution_dict = {
552 | 'sample_index': [],
553 | 'reward': [],
554 | 'tag': []
555 | }
556 | for sample_index, (reward, ref_reward) in enumerate(zip(cur_step_filtered_content_dict['reward'], cur_step_filtered_content_dict['ref_reward'])):
557 | reward_distribution_dict['sample_index'].append(sample_index)
558 | reward_distribution_dict['reward'].append(reward)
559 | reward_distribution_dict['tag'].append('Reward')
560 | reward_distribution_dict['sample_index'].append(sample_index)
561 | reward_distribution_dict['reward'].append(ref_reward)
562 | reward_distribution_dict['tag'].append('Ref Reward')
563 |
564 | reward_distribution_df = pd.DataFrame.from_dict(reward_distribution_dict)
565 | fig = px.bar(
566 | reward_distribution_df,
567 | x="sample_index",
568 | y="reward",
569 | color="tag",
570 | barmode='group',
571 | color_discrete_sequence=px.colors.diverging.Portland,
572 | title="Reward in current batch samples"
573 | )
574 | st.plotly_chart(fig, theme="streamlit", use_container_width=True)
575 |
576 |                 with c2:  # reward gap distribution
577 | reward_distribution_dict = {
578 | 'sample_index': [i for i in range(len(cur_step_filtered_content_dict['reward_gap']))],
579 | 'reward_gap': cur_step_filtered_content_dict['reward_gap']
580 | }
581 | reward_distribution_df = pd.DataFrame.from_dict(reward_distribution_dict)
582 | fig = px.bar(
583 | reward_distribution_df,
584 | x="sample_index",
585 | y="reward_gap",
586 | color="reward_gap",
587 | color_discrete_sequence=['red'],
588 | title="Reward Gap (r - ref_r) in current batch"
589 | )
590 | st.plotly_chart(fig, theme="streamlit", use_container_width=True)
591 |
592 |                 with c3:  # reward distribution / spread
593 | if cur_step_filtered_content_dict['ref_reward']:
594 | hist_data = [
595 | cur_step_filtered_content_dict['ref_reward'],
596 | cur_step_filtered_content_dict['reward'],
597 | ]
598 | group_labels = ['Ref Rewards', 'Rewards']
599 | else:
600 | hist_data = [cur_step_filtered_content_dict['reward']]
601 | group_labels = ['Rewards']
602 |
603 | fig = ff.create_distplot(
604 | hist_data,
605 | group_labels,
606 | bin_size=[.02, .02],
607 | curve_type='normal'
608 | )
609 | fig.update_layout(title="Reward Distribution in current batch")
610 | st.plotly_chart(fig, use_container_width=True)
611 |
612 | showed_keys = [
613 | 'prompt',
614 | 'response',
615 | 'reward',
616 | 'ground_truth',
617 | 'valid_reward',
618 | 'avg_log_ratio',
619 | 'sum_log_ratio',
620 | 'avg_kl',
621 | 'sum_kl',
622 | 'ref_response',
623 | 'ref_reward',
624 | 'ref_valid_reward',
625 | 'reward_gap',
626 | 'valid_reward_gap'
627 | ]
628 | candidate_keys = [k for k in showed_keys if cur_step_filtered_content_dict[k]]
629 | content_dict = dict([(k, cur_step_filtered_content_dict[k]) for k in candidate_keys])
630 | content_df = pd.DataFrame.from_dict(content_dict)
631 |
632 | if st.session_state['show_batch_samples']:
633 | st.dataframe(
634 | content_df,
635 | use_container_width=True,
636 | height=350
637 | )
638 |
639 | if st.session_state['show_samples_pair']:
640 |
641 | c1, c2, c3 = st.columns([1, 1, 4])
642 | with c1:
643 | if step_index == st.session_state['min_step_index']:
644 | delta_char = 0
645 | else:
646 | try:
647 | cur_avg_len = st.session_state['logging_data'][step_index]['avg_length']
648 | last_avg_len = st.session_state['logging_data'][step_index-st.session_state['step_gap']]['avg_length']
649 | delta_char = cur_avg_len - last_avg_len
650 | except:
651 | delta_char = 0
652 |                     st.metric(  # actor's average response length at this step; delta compares with the previous step
653 | 'Response Average Length',
654 |                         value=f"{st.session_state['logging_data'][step_index]['avg_length']} chars",
655 |                         delta=f'{delta_char} chars'
656 | )
657 |
658 |                 with c2:  # ref_model's average response length at this step; delta compares with the previous step
659 | try:
660 | delta_char = 0 if step_index == st.session_state['min_step_index'] else st.session_state['logging_data'][step_index]['avg_ref_length'] - st.session_state['logging_data'][step_index-st.session_state['step_gap']]['avg_ref_length']
661 | except:
662 | delta_char = 0
663 | st.metric(
664 | 'Ref Response Average Length',
665 |                         value=f"{st.session_state['logging_data'][step_index]['avg_ref_length']} chars",
666 |                         delta=f'{delta_char} chars'
667 | )
668 |
669 | with c3:
670 | sample_index = st.number_input(
671 | f'Sample index in current step batch: ',
672 | min_value=0,
673 | max_value=len(cur_step_filtered_content_dict['response']) - 1,
674 | value=0
675 | )
676 |
677 |                 # single-sample view: response vs. ref_response
678 | c1, c2, c3, c4 = st.columns([4, 4, 4, 2])
679 | with c1:
680 | st.markdown('Prompt', unsafe_allow_html=True)
681 | content = cur_step_filtered_content_dict["prompt"][sample_index].replace('\n', ' \n').replace('~', '~')
682 | st.markdown(
683 | f'{content}',
684 | unsafe_allow_html=True
685 | )
686 | with c2:
687 | st.markdown(':green[Response]')
688 | content = cur_step_filtered_content_dict["response"][sample_index].replace('\n', ' \n').replace('~', '~')
689 | st.markdown(
690 | f'{content}',
691 | unsafe_allow_html=True
692 | )
693 | with c3:
694 | st.markdown(':blue[Ref Response]')
695 | if (
696 | "ref_response" in cur_step_filtered_content_dict
697 | and
698 | cur_step_filtered_content_dict["ref_response"]
699 | ):
700 | content = cur_step_filtered_content_dict["ref_response"][sample_index].replace('\n', ' \n').replace('~', '~')
701 | st.markdown(
702 | f'{content}',
703 | unsafe_allow_html=True
704 | )
705 | else:
706 | st.info('No `ref_response` found in log line data.')
707 | with c4:
708 | st.markdown(':orange[Reward Gap]')
709 | reward_gap = round(cur_step_filtered_content_dict["reward_gap"][sample_index], 4) if cur_step_filtered_content_dict["reward_gap"] else 0.
710 | st.metric(
711 | ' ',
712 | value=reward_gap
713 | )
714 |
715 |                 # show more detailed token-level information
716 | if 'token_rewards' in cur_step_filtered_content_dict and cur_step_filtered_content_dict['token_rewards']:
717 |                     # check that resp_tokens and logprobs have the same length
718 | resp_token_len = len(cur_step_filtered_content_dict['response_tokens'][sample_index])
719 | logp_len = len(cur_step_filtered_content_dict['logprobs'][sample_index])
720 | if resp_token_len != logp_len:
721 | st.info(
722 |                             f'Note: `resp_tokens` (len: {resp_token_len}) does not match `logprobs` (len: {logp_len}); clipping response tokens to the logprobs length.',
723 | icon='⚠️'
724 | )
725 | cur_step_filtered_content_dict['response_tokens'][sample_index] = cur_step_filtered_content_dict['response_tokens'][sample_index][:logp_len]
726 |
727 | show_values = st.multiselect(
728 | 'Select show value(s)',
729 | ['token_reward', 'log_ratio', 'kl', 'token_value', 'logp', 'ref_logp', 'prob', 'ref_prob'],
730 | ['token_reward', 'log_ratio', 'kl', 'token_value', 'logp', 'ref_logp', 'prob', 'ref_prob']
731 | )
732 |
733 | new_dict, index_list = {}, []
734 |
735 | if st.session_state['drop_pad'] and cur_step_filtered_content_dict['response_tokens'][sample_index][-1] == st.session_state['pad_token']:
736 | first_pad_token_idx = cur_step_filtered_content_dict['response_tokens'][sample_index].index(st.session_state['pad_token'])
737 | response_tokens_without_pad_token = cur_step_filtered_content_dict['response_tokens'][sample_index][:first_pad_token_idx]
738 | else:
739 | response_tokens_without_pad_token = cur_step_filtered_content_dict['response_tokens'][sample_index]
740 |
741 | for token_idx in range(len(response_tokens_without_pad_token)):
742 | if cur_step_filtered_content_dict['response_tokens']:
743 | resp_token = cur_step_filtered_content_dict['response_tokens'][sample_index][token_idx]
744 | resp_token = f'{token_idx} - {resp_token}'
745 | if resp_token not in new_dict:
746 | new_dict[resp_token] = []
747 |
748 | if cur_step_filtered_content_dict['token_rewards']:
749 | token_reward = cur_step_filtered_content_dict['token_rewards'][sample_index][token_idx]
750 | if 'token_reward' in show_values:
751 | new_dict[resp_token].append(token_reward)
752 | if 'token_reward' not in index_list:
753 | index_list.append('token_reward')
754 |
755 | if cur_step_filtered_content_dict['log_ratio']:
756 | log_ratio = cur_step_filtered_content_dict['log_ratio'][sample_index][token_idx]
757 | if 'log_ratio' in show_values:
758 | new_dict[resp_token].append(log_ratio)
759 | if 'log_ratio' not in index_list:
760 | index_list.append('log_ratio')
761 |
762 | if cur_step_filtered_content_dict['kl']:
763 | kl = cur_step_filtered_content_dict['kl'][sample_index][token_idx]
764 | if 'kl' in show_values:
765 | new_dict[resp_token].append(kl)
766 | if 'kl' not in index_list:
767 | index_list.append('kl')
768 |
769 | if cur_step_filtered_content_dict['values']:
770 | value = cur_step_filtered_content_dict['values'][sample_index][token_idx]
771 | if 'token_value' in show_values:
772 | new_dict[resp_token].append(value)
773 | if 'token_value' not in index_list:
774 | index_list.append('token_value')
775 |
776 | if cur_step_filtered_content_dict['logprobs']:
777 | logp = cur_step_filtered_content_dict['logprobs'][sample_index][token_idx]
778 | if 'logp' in show_values:
779 | new_dict[resp_token].append(logp)
780 | if 'logp' not in index_list:
781 | index_list.append('logp')
782 |
783 | if cur_step_filtered_content_dict['ref_logprobs']:
784 | ref_logp = cur_step_filtered_content_dict['ref_logprobs'][sample_index][token_idx]
785 | if 'ref_logp' in show_values:
786 | new_dict[resp_token].append(ref_logp)
787 | if 'ref_logp' not in index_list:
788 | index_list.append('ref_logp')
789 |
790 | if cur_step_filtered_content_dict['probs']:
791 | prob = cur_step_filtered_content_dict['probs'][sample_index][token_idx]
792 | if 'prob' in show_values:
793 | new_dict[resp_token].append(prob)
794 | if 'prob' not in index_list:
795 | index_list.append('prob')
796 |
797 | if cur_step_filtered_content_dict['ref_probs']:
798 | ref_prob = cur_step_filtered_content_dict['ref_probs'][sample_index][token_idx]
799 | if 'ref_prob' in show_values:
800 | new_dict[resp_token].append(ref_prob)
801 | if 'ref_prob' not in index_list:
802 | index_list.append('ref_prob')
803 |
804 | try:
805 | token_level_df = pd.DataFrame.from_dict(new_dict)
806 | renamed_index_dict = dict((i, name) for i, name in enumerate(index_list))
807 | token_level_df.rename(
808 | index=renamed_index_dict,
809 | inplace=True
810 | )
811 |
812 | st.dataframe(
813 | token_level_df.style.background_gradient(axis=1, cmap="binary"),
814 | use_container_width=True
815 | )
816 |
817 | if st.session_state['show_token_heat_map']:
818 | fig = px.imshow(
819 | token_level_df,
820 | text_auto=True,
821 | aspect="auto",
822 | color_continuous_scale="balance",
823 | )
824 | fig.update_xaxes(side="top")
825 | st.plotly_chart(fig, theme="streamlit", use_container_width=True)
826 | except Exception as e:
827 |                     st.error(f'Error occurred: {e}.')
828 | st.write(new_dict)
829 |
830 |
831 | if __name__ == '__main__':
832 | init_sidebar()
833 | main_page()
834 |
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | streamlit run rl_logging_board.py --server.port 8901
--------------------------------------------------------------------------------