├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── LICENSE ├── README.md ├── README_ja.md ├── README_zh.md ├── assets ├── YuLan-logo.jpg ├── YuLan-logo.png ├── data-pipeline.png ├── data-preview.png ├── data_distribution_for_every_phase.png ├── main.png └── training-stability.png ├── post_train ├── README.md └── img │ └── result.png └── pretrain ├── README.md ├── configuration_yulanmini.py ├── datasets ├── README.md ├── data_mix │ ├── 01_20241017_013512.json │ ├── 02_20241017_013401.json │ ├── 03_20241020_001556.json │ ├── 04_20241021_170901.json │ ├── 05_20241022_221453.json │ ├── 06_20241024_013137.json │ ├── 07_20241025_022032.json │ ├── 08_20241026_151354.json │ ├── 09_20241027_190948.json │ ├── 10_20241028_225112.json │ ├── 11_20241030_124814.json │ ├── 12_20241101_002827.json │ ├── 13_20241102_160534.json │ ├── 14_20241104_000454.json │ ├── 15_20241105_023029.json │ ├── 16_20241106_180613.json │ ├── 17_20241108_004951.json │ ├── 18_20241113_034017.json │ ├── 19_20241114_115241.json │ ├── 20_20241115_234357.json │ ├── 21_20241117_021115.json │ ├── 22_20241118_155407.json │ ├── 23_20241120_033942.json │ ├── 24_20241121_133110.json │ ├── 25_20241123_030124.json │ ├── 26_20241127_205447.json │ ├── 26_20241211_015209.json │ └── 27_20241213_051741.json ├── download_datasets_step1.sh ├── download_datasets_step3.sh └── final.pdf ├── ds2_config_adamw.json ├── modeling_yulanmini.py ├── preprocess ├── README.md ├── convert_hf_datasets_to_megatron.py ├── mix │ └── update_metadata_from_clipboard.py └── tokenize │ ├── run_tokenize.sh │ ├── split_data.py │ └── tokenize_text.py ├── scripts ├── calc_norm.py ├── convert_yulanmini_to_llama.py └── estimate_mfu.py ├── setup.sh ├── synthesis ├── README.md ├── gen_lean_reasoning.py ├── gen_qwq.py └── gen_vllm.py ├── torchrun_wrapper.sh ├── train.py ├── train.sh ├── train_utils.py ├── yulanmini-2B-final-phase25.sh ├── yulanmini-2B-s25d-decay80-1sqrt-long-28k-final-phase26.sh └── yulanmini_trainer.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Additional context** 23 | Add any other context about the problem here. 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yiwen Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README_ja.md: -------------------------------------------------------------------------------- 1 |
5 | 6 | 7 |
15 | 16 | YuLan-Miniは2.4Bパラメータの軽量な言語モデルです。1.08Tトークンのみを使用して事前トレーニングを行い、特に**数学**と**コード**の分野で、より多くのデータを使用した業界トップのモデルと同等の性能を達成しています。再現性を高めるために、関連する事前トレーニングリソースをオープンソース化します。 17 | 18 | --- 19 | 20 | ## モデルのダウンロード 🔗 21 | 22 | | Model | Context Length | SFT | 🤗 Hugging Face | Wise Model | 23 | |---------|----------------|-----|-----------------|------------| 24 | | YuLan-Mini (Recommended) | 28K | ❎ | [`YuLan-Mini`](https://huggingface.co/yulan-team/YuLan-Mini) | [`YuLan-Mini`](https://wisemodel.cn/models/yulan-team/YuLan-Mini) | 25 | | YuLan-Mini-2.4B-4K | 4K | ❎ | | | 26 | | YuLan-Mini-Instruct | Comming soon | ✅ | | | 27 | 28 | --- 29 | 30 | ## 特徴 🌟 31 | 32 |5 | 6 | 7 |
15 | 16 | YuLan-Mini 是一个 2.4B 参数量的轻量化语言模型。仅使用 1.08T Tokens 进行预训练,却达到了与使用更多数据的行业领先模型相媲美的性能,尤其是 **数学** 和 **代码** 两个领域。为方便复现,我们将开源相关预训练资源。 17 | 18 | --- 19 | 20 | ## 新闻 21 | 22 | - [2025.01.29] YuLan-Mini-Instruct-v1 发布 23 | - [2024.12.23] YuLan-Mini 及预训练资源发布 24 | 25 | ## 模型下载 🔗 26 | 27 | > YuLan-Mini 是 [YuLan 系列](https://github.com/RUC-GSAI/YuLan-Chat) 的一部分,该系列还包括更大规模和不同训练策略的模型。 28 | 29 | | 模型 | 上下文长度 | SFT | 🤗 Hugging Face | ModelScope | Wise Model | 30 | |---------|----------------|-----|-----------------|------------|------------| 31 | | YuLan-Mini | 28K | ❎ | [`Base`](https://huggingface.co/yulan-team/YuLan-Mini) | [`Base`](https://modelscope.cn/models/yulan-team/YuLan-Mini) | [`Base`](https://wisemodel.cn/models/yulan-team/YuLan-Mini) | 32 | | YuLan-Mini-Instruct | 28K | ✅ | [`Instruct`](https://huggingface.co/yulan-team/YuLan-Mini-Instruct) | | | 33 | 34 | > 中间检查点可以在[这里](#%E9%A2%84%E8%AE%AD%E7%BB%83%E8%B5%84%E6%BA%90-)找到。 35 | 36 | --- 37 | 38 | ## 能力介绍 🌟 39 | 40 |config.json
由于 Hugging Face Trainer 的实现,某些参数存储在 trainer_state.json
文件中,无法通过 Trainer 的命令行参数进行修改。因此,您需要首先更新 trainer_state.json
文件中的这些参数,特别是:
save_steps
:保存中间检查点的频率。train_batch_size
:每个 GPU 的批大小(相当于 Trainer 中的 per_device_train_batch_size
)。在稳定训练阶段,我们使用了 1008 的批大小(大约 4M 个 token)。保持相同的批大小对于训练效果同样重要。以下是一个正确配置的 config.json
文件示例:
{
110 | "best_metric": null,
111 | "best_model_checkpoint": null,
112 | "epoch": 0.0,
113 | "eval_steps": 500,
114 | "global_step": 0,
115 | "is_hyper_param_search": false,
116 | "is_local_process_zero": true,
117 | "is_world_process_zero": true,
118 | "log_history": [],
119 | "logging_steps": 3,
120 | "max_steps": 0,
121 | "num_input_tokens_seen": 0,
122 | "num_train_epochs": 0,
123 | "save_steps": 250,
124 | "stateful_callbacks": {
125 | "TrainerControl": {
126 | "args": {
127 | "should_epoch_stop": false,
128 | "should_evaluate": false,
129 | "should_log": false,
130 | "should_save": true,
131 | "should_training_stop": true
132 | },
133 | "attributes": {}
134 | }
135 | },
136 | "total_flos": 0,
137 | "train_batch_size": 3,
138 | "trial_name": null,
139 | "trial_params": null
140 | }
141 |
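在恢复训练之前,可以先核对全局批大小是否仍与稳定训练阶段保持一致(1008 条序列 × 4096 上下文 ≈ 每步 4M token)。下面是一个简单的核对示意,其中每卡批大小、GPU 数量和梯度累积步数均为假设的示例值,请按实际的 train.sh 配置填写:

```python
# 示意:核对全局批大小(以下数值均为示例,请按实际配置修改)
per_device_train_batch_size = 18   # 每卡批大小(示例值)
num_gpus = 56                      # 参与训练的 GPU 总数(示例值)
gradient_accumulation_steps = 1    # 梯度累积步数(示例值)
seq_len = 4096                     # 稳定训练阶段的上下文长度

global_batch_size = per_device_train_batch_size * num_gpus * gradient_accumulation_steps
tokens_per_step = global_batch_size * seq_len

# 稳定训练阶段的全局批大小为 1008,约合每步 4M token
assert global_batch_size == 1008, f"全局批大小 {global_batch_size} 与 1008 不一致"
print(f"每步序列数 = {global_batch_size},每步 token 数 ≈ {tokens_per_step / 1e6:.2f}M")
```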
142 | 为了确保 DeepSpeed 集成加载通用检查点,您需要在 DeepSpeed 配置 JSON 文件中启用此功能。
144 |以下是一个启用了通用检查点的 ZeRO2 配置示例:
145 |{
146 | "bf16": {
147 | "enabled": "auto"
148 | },
149 | "zero_optimization": {
150 | "stage": 2,
151 | "allgather_partitions": true,
152 | "allgather_bucket_size": 8e8,
153 | "overlap_comm": true,
154 | "reduce_scatter": true,
155 | "reduce_bucket_size": 8e8,
156 | "contiguous_gradients": true
157 | },
158 | "gradient_accumulation_steps": "auto",
159 | "gradient_clipping": "auto",
160 | "steps_per_print": 16,
161 | "train_batch_size": "auto",
162 | "train_micro_batch_size_per_gpu": "auto",
163 | "wall_clock_breakdown": false,
164 | "dump_state": true,
165 | "optimizer": {
166 | "type": "AdamW",
167 | "params": {
168 | "lr": "auto",
169 | "betas": "auto",
170 | "eps": "auto",
171 | "weight_decay": "auto"
172 | }
173 | },
174 | "checkpoint": {
175 | "load_universal": true
176 | }
177 | }
178 |
179 | 调用 trainer.train
时,包含 resume_from_checkpoint
参数以从通用检查点加载分布式优化器状态并恢复训练。
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
182 |
183 | 我们提供了一个内部训练框架供您参考,但您可以自由选择其他框架。
184 | 185 |阶段 | 194 |课程阶段 | 195 |4K 上下文 | 196 |28K 上下文 | 197 |优化器 | 198 |推理架构 | 199 |LAMBADA Acc |
200 | GSM8K Acc |
201 | HumanEval pass@1 |
202 |
---|---|---|---|---|---|---|---|---|
稳定 | 207 |5 | 208 |YuLan-Mini-Phase5 | 209 |210 | | 211 | | yulanmini |
212 | 53.85 | 213 |3.41 | 214 |12.26 | 215 |
稳定 | 218 |10 | 219 |YuLan-Mini-Phase10 | 220 |221 | | 222 | | yulanmini |
223 | 55.00 | 224 |9.57 | 225 |15.95 | 226 |
稳定 | 229 |15 | 230 |YuLan-Mini-Phase15 | 231 |232 | | 233 | | yulanmini |
234 | 55.81 | 235 |13.81 | 236 |16.99 | 237 |
稳定 | 240 |20 | 241 |YuLan-Mini-Phase20 | 242 |243 | | ✅ | 244 |yulanmini |
245 | 55.81 | 246 |21.39 | 247 |20.79 | 248 |
稳定 | 251 |25 (1T tokens) | 252 |YuLan-Mini-Before-Annealing | 253 |254 | | ✅ | 255 |yulanmini |
256 | 55.67 | 257 |29.94 | 258 |34.06 | 259 |
262 | | 263 | | 264 | | 265 | | 266 | | 267 | | 268 | | 269 | | 270 | |
退火 | 273 |26 | 274 |YuLan-Mini-4K | 275 |276 | | 277 | | llama * |
278 | 64.72 | 279 |66.65 | 280 |61.60 | 281 |
退火 | 284 |27 | 285 |286 | | YuLan-Mini | 287 |288 | | llama * |
289 | 65.67 | 290 |68.46 | 291 |64.00 | 292 |
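如果只想快速验证表中的某个中间检查点,可以直接用 transformers 加载并做少量推理。下面是一个简单示意:其中的仓库名是按上表命名推测的示例,请以表中实际链接为准;稳定阶段检查点为 yulanmini 自定义架构,需要 `trust_remote_code=True`,退火阶段检查点为 llama 架构,可省略该参数:

```python
# 示意:加载一个中间检查点做快速推理(仓库名为按上表推测的示例)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "yulan-team/YuLan-Mini-Phase20"  # 示例:稳定阶段第 20 个课程阶段的检查点
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # yulanmini 自定义架构需要;llama 架构检查点可省略
)

inputs = tokenizer("1+1=", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```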
flash_attn
and liger_kernel
libraries, we achieved 51% MFU (in comparison, Megatron only has about 41% MFU on small models of the same scale).├── train.py # 👈🏻 The main training script
21 | ├── train.sh # 👈🏻 The main training script for each curriculum phase
22 | ├── yulanmini-2B-final-phase25.sh # 👈🏻 example script for phase 25
23 | ├── yulanmini-2B-s25d-decay80-1sqrt-long-28k-final-phase26.sh # 👈🏻 example script for phase 26
24 | ├── ds2_config_adamw.json # The DeepSpeed configuration file
25 | ├── setup.sh # The setup script for the training environment
26 | ├── torchrun_wrapper.sh # The wrapper script for torchrun
27 | ├── train_utils.py # The training utility functions
28 | └── yulanmini_trainer.py # 👈🏻 The Trainer class for training
29 |
30 |
31 | trainer_state.json
Due to the implementation of Hugging Face Trainer, certain parameters are stored in the trainer_state.json
file and cannot be modified through the Trainer's command-line arguments. Therefore, you need to update these parameters in the trainer_state.json
file first, particularly:
save_steps
: The frequency of saving intermediate checkpoints.train_batch_size
: The batch size per GPU (equivalent to per_device_train_batch_size
in the Trainer). We used a batch size of 1008 (approximately 4M tokens) during the stable training stage. Maintaining this same batch size is equally important for training effectiveness.Below is an example of a properly configured trainer_state.json
file:
{
40 | "best_metric": null,
41 | "best_model_checkpoint": null,
42 | "epoch": 0.0,
43 | "eval_steps": 500,
44 | "global_step": 0,
45 | "is_hyper_param_search": false,
46 | "is_local_process_zero": true,
47 | "is_world_process_zero": true,
48 | "log_history": [],
49 | "logging_steps": 3,
50 | "max_steps": 0,
51 | "num_input_tokens_seen": 0,
52 | "num_train_epochs": 0,
53 | "save_steps": 250,
54 | "stateful_callbacks": {
55 | "TrainerControl": {
56 | "args": {
57 | "should_epoch_stop": false,
58 | "should_evaluate": false,
59 | "should_log": false,
60 | "should_save": true,
61 | "should_training_stop": true
62 | },
63 | "attributes": {}
64 | }
65 | },
66 | "total_flos": 0,
67 | "train_batch_size": 3,
68 | "trial_name": null,
69 | "trial_params": null
70 | }
71 |
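If you prefer to patch an existing checkpoint programmatically instead of editing the file by hand, the sketch below shows the idea; the checkpoint path and the new values are placeholders, and every other field should be left exactly as the Trainer wrote it:

```python
# Sketch: patch save_steps / train_batch_size in an existing trainer_state.json.
# The checkpoint path and the values below are placeholders -- adjust them to your run.
import json
from pathlib import Path

state_path = Path("output/checkpoint-1000/trainer_state.json")  # placeholder path

state = json.loads(state_path.read_text())
state["save_steps"] = 250        # how often intermediate checkpoints are saved
state["train_batch_size"] = 3    # per-GPU batch size, must match the original run

state_path.write_text(json.dumps(state, indent=4))
print(f"Patched {state_path}")
```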
72 | To ensure DeepSpeed Integration loads the Universal Checkpoint, you need to enable this feature in the DeepSpeed configuration JSON file.
74 |Here is an example of a ZeRO2 configuration with Universal Checkpointing enabled:
75 |{
76 | "bf16": {
77 | "enabled": "auto"
78 | },
79 | "zero_optimization": {
80 | "stage": 2,
81 | "allgather_partitions": true,
82 | "allgather_bucket_size": 8e8,
83 | "overlap_comm": true,
84 | "reduce_scatter": true,
85 | "reduce_bucket_size": 8e8,
86 | "contiguous_gradients": true
87 | },
88 | "gradient_accumulation_steps": "auto",
89 | "gradient_clipping": "auto",
90 | "steps_per_print": 16,
91 | "train_batch_size": "auto",
92 | "train_micro_batch_size_per_gpu": "auto",
93 | "wall_clock_breakdown": false,
94 | "dump_state": true,
95 | "optimizer": {
96 | "type": "AdamW",
97 | "params": {
98 | "lr": "auto",
99 | "betas": "auto",
100 | "eps": "auto",
101 | "weight_decay": "auto"
102 | }
103 | },
104 | "checkpoint": {
105 | "load_universal": true
106 | }
107 | }
108 |
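On the Hugging Face side, this config file is handed to the Trainer through TrainingArguments. The sketch below only shows the wiring; every value is a placeholder, and the real hyperparameters are set in train.sh / train.py:

```python
# Sketch: wiring the DeepSpeed config into the HF Trainer (all values are placeholders).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    deepspeed="ds2_config_adamw.json",  # ZeRO2 config with Universal Checkpointing enabled
    per_device_train_batch_size=3,      # placeholder per-GPU batch size
    gradient_accumulation_steps=6,      # placeholder
    bf16=True,
    save_steps=250,
    logging_steps=3,
)

# A Trainer built with these args (see train.py and yulanmini_trainer.py) can then
# resume from the Universal Checkpoint as described below.
```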
109 | When calling trainer.train
, include the resume_from_checkpoint
argument to load the distributed optimizer state from the Universal Checkpoint and resume training.
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
112 |
113 | We provide an internal training framework for your reference, but you are free to choose other frameworks.
114 | 115 |Stage | 124 |Curriculum Phase | 125 |4K Context | 126 |28K Context | 127 |Optimizer | 128 |Inference Architecture | 129 |LAMBADA Acc |
130 | GSM8K Acc |
131 | HumanEval pass@1 |
132 |
---|---|---|---|---|---|---|---|---|
Stable | 137 |5 | 138 |YuLan-Mini-Phase5 | 139 |140 | | 141 | | yulanmini |
142 | 53.85 | 143 |3.41 | 144 |12.26 | 145 |
Stable | 148 |10 | 149 |YuLan-Mini-Phase10 | 150 |151 | | 152 | | yulanmini |
153 | 55.00 | 154 |9.57 | 155 |15.95 | 156 |
Stable | 159 |15 | 160 |YuLan-Mini-Phase15 | 161 |162 | | 163 | | yulanmini |
164 | 55.81 | 165 |13.81 | 166 |16.99 | 167 |
Stable | 170 |20 | 171 |YuLan-Mini-Phase20 | 172 |173 | | ✅ | 174 |yulanmini |
175 | 55.81 | 176 |21.39 | 177 |20.79 | 178 |
Stable | 181 |25 (1T tokens) | 182 |YuLan-Mini-Before-Annealing | 183 |184 | | ✅ | 185 |yulanmini |
186 | 55.67 | 187 |29.94 | 188 |34.06 | 189 |
192 | | 193 | | 194 | | 195 | | 196 | | 197 | | 198 | | 199 | | 200 | |
Annealing | 203 |26 | 204 |YuLan-Mini-4K | 205 |206 | | 207 | | llama * |
208 | 64.72 | 209 |66.65 | 210 |61.60 | 211 |
Annealing | 214 |27 | 215 |216 | | YuLan-Mini | 217 |218 | | llama * |
219 | 65.67 | 220 |68.46 | 221 |64.00 | 222 |
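The two annealing-stage rows above use the llama inference architecture. The repository ships scripts/convert_yulanmini_to_llama.py for turning a yulanmini-architecture checkpoint into that layout; a minimal usage sketch is shown below (the checkpoint path is a placeholder, and the converted model is written next to it with a `-llama` suffix):

```python
# Sketch: convert a yulanmini-architecture checkpoint into the Llama layout.
# Equivalent CLI: python pretrain/scripts/convert_yulanmini_to_llama.py /path/to/checkpoint
import sys

sys.path.append("pretrain/scripts")  # assumes the repository root as working directory
from convert_yulanmini_to_llama import rebalance_weights2

rebalance_weights2("/path/to/YuLan-Mini-Before-Annealing", method="llama")  # placeholder path
```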
Optional[bool]: 171 | """Exit the context introduced by the 'with' keyword 172 | 173 | Args: 174 | exc_type (Optional[Type[BaseException]]): Exception type 175 | 176 | exc_val (Optional[BaseException]): Exception value 177 | 178 | exc_tb (Optional[TracebackType]): Exception traceback object 179 | 180 | Returns: 181 | Optional[bool]: Whether to silence the exception 182 | """ 183 | self.idx_writer.close() 184 | 185 | def write( 186 | self, 187 | sequence_lengths: List[int], 188 | sequence_modes: Optional[List[int]], 189 | document_indices: List[int], 190 | ) -> None: 191 | """Write the index (.idx) file 192 | 193 | Args: 194 | sequence_lengths (List[int]): The length of each sequence 195 | 196 | sequence_modes (Optional[List[int]]): The mode of each sequences 197 | 198 | document_indices (List[int]): The seqyebce indices demarcating the end of each document 199 | """ 200 | sequence_pointers = self._sequence_pointers(sequence_lengths) 201 | 202 | # the number of sequences in the dataset 203 | sequence_count = len(sequence_lengths) 204 | self.idx_writer.write(struct.pack("List[int]: 231 | """Build the sequence pointers per the sequence lengths and dtype size 232 | 233 | Args: 234 | sequence_lengths (List[int]): The length of each sequence 235 | 236 | Returns: 237 | List[int]: The pointer to the beginning of each sequence 238 | """ 239 | itemsize = DType.size(self.dtype) 240 | curr_ptr = 0 241 | list_ptr = [] 242 | for length in sequence_lengths: 243 | list_ptr.append(curr_ptr) 244 | curr_ptr += length * itemsize 245 | return list_ptr 246 | 247 | 248 | class _IndexReader(object): 249 | """Object class to read the index (.idx) file 250 | 251 | Args: 252 | idx_path (str): The path to the index file 253 | 254 | multimodal (bool): Whether the dataset is multimodal 255 | """ 256 | 257 | def __init__(self, idx_path: str, multimodal: bool) -> None: 258 | 259 | log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}") 260 | 261 | with open(idx_path, "rb") as stream: 262 | header = stream.read(9) 263 | assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}" 264 | 265 | version = struct.unpack("time elapsed: {t_end - t_beg:4f} seconds") 287 | 288 | log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers") 289 | t_beg = time.time() 290 | self.sequence_pointers = numpy.frombuffer( 291 | self.bin_buffer, 292 | dtype=numpy.int64, 293 | count=self.sequence_count, 294 | offset=offset + self.sequence_lengths.nbytes, 295 | ) 296 | t_end = time.time() 297 | log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") 298 | 299 | log_single_rank(logger, logging.INFO, f"\tExtract the document indices") 300 | t_beg = time.time() 301 | self.document_indices = numpy.frombuffer( 302 | self.bin_buffer, 303 | dtype=numpy.int64, 304 | count=self.document_count, 305 | offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes, 306 | ) 307 | t_end = time.time() 308 | log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") 309 | 310 | self.sequence_modes = None 311 | if multimodal: 312 | log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes") 313 | t_beg = time.time() 314 | self.sequence_modes = numpy.frombuffer( 315 | self.bin_buffer, 316 | dtype=numpy.int8, 317 | count=self.sequence_count, 318 | offset=offset 319 | + self.sequence_lengths.nbytes 320 | + self.sequence_pointers.nbytes 321 | + self.document_indices.nbytes, 322 | ) 323 | t_end = time.time() 324 | 
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") 325 | 326 | assert self.sequence_lengths.shape[0] == len(self) 327 | assert self.sequence_lengths.shape[0] == self.sequence_count 328 | assert self.sequence_lengths.shape[0] == self.document_indices[-1] 329 | 330 | log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}") 331 | log_single_rank( 332 | logger, 333 | logging.INFO, 334 | f"> total number of documents: {self.document_indices.shape[0] - 1}", 335 | ) 336 | 337 | def __del__(self) -> None: 338 | """Clean up the object""" 339 | if hasattr(self, "bin_buffer_mmap"): 340 | self.bin_buffer_mmap._mmap.close() 341 | del self.bin_buffer_mmap 342 | 343 | def __len__(self) -> int: 344 | """Return the length of the dataset 345 | 346 | Returns: 347 | int: The length of the dataset 348 | """ 349 | return self.sequence_count 350 | 351 | @lru_cache(maxsize=8) 352 | def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: 353 | """Return the pointer, length, and mode at the index 354 | 355 | Args: 356 | idx (int): The index into the dataset 357 | 358 | Returns: 359 | Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at the index 360 | """ 361 | return ( 362 | self.sequence_pointers[idx], 363 | self.sequence_lengths[idx], 364 | self.sequence_modes[idx] if self.sequence_modes is not None else None, 365 | ) 366 | 367 | 368 | class IndexedDatasetBuilder(object): 369 | """Builder class for the IndexedDataset class 370 | 371 | Args: 372 | bin_path (str): The path to the data (.bin) file 373 | 374 | dtype (Type[numpy.number], optional): The dtype of the index file. Defaults to numpy.int32. 375 | 376 | multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. 377 | """ 378 | 379 | def __init__( 380 | self, bin_path: str, dtype: Type[numpy.number] = numpy.int32 381 | ) -> None: 382 | self.data_file = open(bin_path, "wb") 383 | self.dtype = dtype 384 | 385 | self.sequence_lengths = [] 386 | self.document_indices = [0] 387 | 388 | def add_document( 389 | self, input_ids: List[int], token_length: int, modes: Optional[List[int]] = None 390 | ) -> None: 391 | """Add an entire document to the dataset 392 | 393 | Args: 394 | tensor (torch.Tensor): The document to add 395 | 396 | lengths (List[int]): The lengths of each item in the document 397 | 398 | modes (Optional[List[int]], optional): The modes for each item in the document. Defaults to None. 
399 | """ 400 | np_array = numpy.array(input_ids, dtype=self.dtype) 401 | self.data_file.write(np_array.tobytes(order="C")) 402 | self.sequence_lengths.extend([token_length]) 403 | self.document_indices.append(len(self.sequence_lengths)) 404 | 405 | def add_index(self, path_prefix: str) -> None: 406 | """Add an entire IndexedDataset to the dataset 407 | 408 | Args: 409 | path_prefix (str): The index (.idx) and data (.bin) prefix 410 | """ 411 | # Concatenate index 412 | index = _IndexReader(get_idx_path(path_prefix), multimodal=False) 413 | assert index.dtype == self.dtype 414 | 415 | offset = len(self.sequence_lengths) 416 | self.sequence_lengths.extend(index.sequence_lengths) 417 | self.document_indices.extend((offset + index.document_indices)[1:]) 418 | 419 | # Concatenate data 420 | with open(get_bin_path(path_prefix), "rb") as f: 421 | shutil.copyfileobj(f, self.data_file) 422 | 423 | def finalize(self, idx_path: str) -> None: 424 | """Clean up and write the index (.idx) file 425 | 426 | Args: 427 | idx_path (str): The path to the index file 428 | """ 429 | self.data_file.close() 430 | with _IndexWriter(idx_path, self.dtype) as writer: 431 | writer.write(self.sequence_lengths, None, self.document_indices) 432 | 433 | 434 | def process_partition(kwargs): 435 | 436 | partition: str = kwargs['partition'] 437 | output_prefix: str = kwargs['output_prefix'] 438 | json_keys: List[str] = kwargs['json_keys'] 439 | 440 | output_bin_files = {} 441 | output_idx_files = {} 442 | builders = {} 443 | level = 'document' 444 | for key in json_keys: 445 | output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, 446 | key, level) 447 | output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, 448 | key, level) 449 | builders[key] = IndexedDatasetBuilder( 450 | output_bin_files[key], 451 | dtype=DType.optimal_dtype(VOCAB_SIZE), 452 | ) 453 | 454 | encoded_docs = datasets.load_dataset(partition, split='train', streaming=True) 455 | for json_dict in encoded_docs: 456 | for key in json_keys: 457 | content = json_dict[key] 458 | builders[key].add_document(content, len(content)) 459 | for key in json_keys: 460 | builders[key].finalize(output_idx_files[key]) 461 | 462 | 463 | def convert_hf_dataset(dataset_path: str, json_keys: List[str], output_prefix: str, num_workers: int = 16): 464 | 465 | output_bin_files = {} 466 | output_idx_files = {} 467 | builders = {} 468 | level = 'document' 469 | 470 | dataset_names = sorted(os.listdir(dataset_path)) 471 | dataset_names = [d for d in dataset_names if not d.startswith('.')] 472 | in_ss_out_names = [{'partition': os.path.join(dataset_path, d), 'output_prefix': d, 'json_keys': json_keys} for d in dataset_names] 473 | 474 | # process the dataset in parallel 475 | with ProcessPoolExecutor(num_workers) as executor: 476 | p = executor.map(process_partition, in_ss_out_names) 477 | for _ in tqdm(p, total=len(in_ss_out_names), desc="Processing dataset"): 478 | pass 479 | 480 | # collect different subsets into the same bin file 481 | for key in json_keys: 482 | output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, 483 | key, level) 484 | output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, 485 | key, level) 486 | builders[key] = IndexedDatasetBuilder( 487 | output_bin_files[key], 488 | dtype=DType.optimal_dtype(VOCAB_SIZE), 489 | ) 490 | 491 | for name in in_ss_out_names: 492 | parition_output_prefix = name['output_prefix'] 493 | full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, 494 | key, level) 495 | 
builders[key].add_index(full_partition_output_prefix) 496 | builders[key].finalize(output_idx_files[key]) 497 | 498 | 499 | if __name__ == "__main__": 500 | convert_hf_dataset('/data/hf_dataset/myl_new_no_math/17_20241108_004951', ['input_ids'], '17_20241108_004951') 501 | -------------------------------------------------------------------------------- /pretrain/preprocess/mix/update_metadata_from_clipboard.py: -------------------------------------------------------------------------------- 1 | datasets = [] 2 | during_dataset = False 3 | DATASET_COLUMN = 'DatasetName' 4 | 5 | subsets = [] 6 | during_subset = False 7 | SUBSET_COLUMN = 'SubsetName' 8 | 9 | sfts = [] 10 | during_sft = False 11 | ISSFT_COLUMN = 'IsSFT' 12 | 13 | print(f"Paste the '{DATASET_COLUMN}' column") 14 | 15 | while True: 16 | 17 | dataset = input() 18 | 19 | if dataset == "END_OF_DATASET": 20 | datasets.append(dataset) 21 | with open("datasets.txt", "w") as f: 22 | f.write("\n".join(datasets)) 23 | print(f"'{DATASET_COLUMN}' column saved. Then you can paste the '{SUBSET_COLUMN}' column.") 24 | 25 | elif dataset == "END_OF_SUBSET": 26 | subsets.append(dataset) 27 | with open("subsets.txt", "w") as f: 28 | f.write("\n".join(subsets)) 29 | print(f"'{SUBSET_COLUMN}' column saved. Then you can paste the '{ISSFT_COLUMN}' column.") 30 | 31 | elif dataset == "END_OF_SFT": 32 | sfts.append(dataset) 33 | with open("sfts.txt", "w") as f: 34 | f.write("\n".join(sfts)) 35 | print("'{ISSFT_COLUMN}' column saved. then you can press Ctrl+C to exit") 36 | 37 | elif dataset == DATASET_COLUMN: 38 | during_dataset = True 39 | continue 40 | 41 | elif dataset == SUBSET_COLUMN: 42 | during_subset = True 43 | continue 44 | 45 | elif dataset == ISSFT_COLUMN: 46 | during_sft = True 47 | continue 48 | 49 | if during_dataset: 50 | datasets.append(dataset.strip()) 51 | 52 | if during_subset: 53 | subsets.append(dataset.strip()) 54 | 55 | if during_sft: 56 | sfts.append(dataset.strip()) 57 | -------------------------------------------------------------------------------- /pretrain/preprocess/tokenize/run_tokenize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_path=$1 4 | 5 | tokenizer_path=6 | num_file=10000 7 | num_worker=8 8 | # num_file means how many jsonl/json/parquet files to tokenize at once (to avoid memory overflow). If num_file < actual number of files, simply run the script multiple times to tokenize all files. 9 | 10 | export RAW_DATA_PREFIX="/data/raw" 11 | export INPUT_IDS_PREFIX="/data/input_ids" 12 | # target save path for tokenized data. The tokenized data will retain the same directory structure as the raw data. 13 | 14 | echo num_file=$num_file num_worker=$num_worker 15 | 16 | # check if data_path exists 17 | if [ ! -d "$data_path" ]; then 18 | echo "$data_path does not exist." 19 | exit 20 | else 21 | echo $data_path 22 | fi 23 | 24 | 25 | python tokenize/tokenize_text.py \ 26 | --tokenizer_path $tokenizer_path \ 27 | --data_path $data_path \ 28 | --model_name mini \ 29 | --num_file $num_file \ 30 | --text_key text \ 31 | --num_worker $num_worker \ 32 | --skip_exist True 33 | 34 | # split data by 0.01B tokens 35 | python tokenize/split_data.py $data_path 36 | 37 | # delete intermediate tokenization files 38 | cat datasets_to_delete.txt | xargs -I {} rm {} 39 | rm datasets_to_delete.txt 40 | 41 | # incase of missing deletion 42 | if [ -n "$(find . -type f -regex '.*part-[0-9]+\.jsonl')" ]; then 43 | find . 
-type f -regex '.*part-[0-9]+\.jsonl' 44 | echo "Please check the intermediate part-xx.jsonl files listed above and delete them manually." 45 | fi 46 | -------------------------------------------------------------------------------- /pretrain/preprocess/tokenize/split_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import pathlib 5 | import numpy as np 6 | import sys 7 | import threading 8 | from tqdm import tqdm 9 | import multiprocessing as mp 10 | import pyarrow 11 | from pyarrow import parquet as pq 12 | 13 | # split data by 0.01B tokens (this is a soft limit) 14 | MAX_TOKEN = int(0.01 * 1000 * 1000 * 1000) 15 | 16 | if len(sys.argv) >= 3: 17 | father_datasets = list(sys.argv)[2:] 18 | datasets_to_delete = sys.argv[1] 19 | else: 20 | father_datasets = list(sys.argv)[1:] 21 | datasets_to_delete = "datasets_to_delete.txt" 22 | print("father_datasets", father_datasets) 23 | print("datasets_to_delete", datasets_to_delete) 24 | 25 | metadata_columns = ["source"] 26 | 27 | # replace raw data path with input_ids path 28 | raw_data_prefix = os.environ["RAW_DATA_PREFIX"] 29 | input_ids_prefix = os.environ["INPUT_IDS_PREFIX"] 30 | 31 | father_datasets = [ 32 | i.replace(raw_data_prefix, input_ids_prefix) 33 | for i in father_datasets 34 | ] 35 | 36 | 37 | def warn(msg): 38 | print("\033[0;33m" + msg + "\033[0m") 39 | 40 | 41 | def process_file(src_folder, src_file, print_id, write_format="parquet", is_last=False, last_states=None, max_part=-1): 42 | 43 | def write_to_file(all_data, num_tokens: int, cur_idx: int, metadata: list): 44 | """Write splitted data and metadata.""" 45 | 46 | cur_idx = f"{cur_idx:04d}" 47 | if write_format == "jsonl": 48 | tgt_path = os.path.join(src_folder, "splitted_part-{}.jsonl".format(cur_idx)) 49 | print(print_id, "updating", tgt_path, num_tokens) 50 | with open(tgt_path, "w") as fout: 51 | for tmp_data in all_data: 52 | fout.write(json.dumps({"input_ids": tmp_data}, ensure_ascii=False) + "\n") 53 | 54 | elif write_format == "parquet": 55 | tgt_path = os.path.join(src_folder, "splitted_part-{}.parquet".format(cur_idx)) 56 | print(print_id, "updating", tgt_path, num_tokens) 57 | arr = pyarrow.array(all_data) 58 | pq.write_table(pyarrow.Table.from_arrays([arr], ["input_ids"]), tgt_path) 59 | 60 | tokens_num_tgt_path = os.path.join(src_folder, "splitted_part-{}-metadata.json".format(cur_idx)) 61 | with open(tokens_num_tgt_path, "w") as fout: 62 | json.dump({"total_tokens_num": num_tokens, "metadata": metadata}, fout, indent=2) 63 | 64 | def load_data_jsonl(fin): 65 | """Read one line and return as parsed json data.""" 66 | data = fin.readline().strip() 67 | if not data: 68 | return None 69 | else: 70 | json_data = json.loads(data) 71 | input_ids = json_data["input_ids"] 72 | meta_data = {col: json_data[col] for col in metadata_columns if col in json_data} 73 | meta_data["num_tokens"] = len(input_ids) 74 | return (input_ids, meta_data) 75 | 76 | all_data = [] 77 | metadata = [] 78 | num_tokens = 0 79 | cur_idx = max_part + 1 80 | if last_states is not None: 81 | all_data, num_tokens, cur_idx, metadata = last_states 82 | 83 | if src_file.endswith(".parquet"): 84 | 85 | warn("Deprecated: parquet file as intermediate format is deprecated. 
Please use jsonl format instead.") 86 | 87 | # parquet read 88 | table = pq.read_table(src_file) 89 | all_all_data = table["input_ids"].to_pylist() 90 | for ids in all_all_data: 91 | all_data.append(ids) 92 | metadata.append({"num_tokens": len(ids)}) 93 | num_tokens += len(ids) 94 | if num_tokens > MAX_TOKEN: 95 | # flush new splitted data to file 96 | write_to_file(all_data, num_tokens, cur_idx, metadata) 97 | all_data = [] 98 | metadata = [] 99 | num_tokens = 0 100 | cur_idx += 1 101 | print(print_id, "next split", cur_idx) 102 | 103 | # trailing lines 104 | if len(all_data) > 0 and is_last: 105 | write_to_file(all_data, num_tokens, cur_idx, metadata) 106 | 107 | elif src_file.endswith(".jsonl"): 108 | 109 | # jsonl read line by line 110 | with open(src_file) as fin: 111 | while True: 112 | data = load_data_jsonl(fin) 113 | if data is None: 114 | break 115 | 116 | # add data 117 | all_data.append(data[0]) 118 | metadata.append(data[1]) 119 | num_tokens += data[1]["num_tokens"] 120 | if num_tokens > MAX_TOKEN: 121 | # flush new splitted data to file 122 | write_to_file(all_data, num_tokens, cur_idx, metadata) 123 | all_data = [] 124 | metadata = [] 125 | num_tokens = 0 126 | cur_idx += 1 127 | 128 | # trailing lines of whole wo_ppl folder 129 | if len(all_data) > 0 and is_last: 130 | write_to_file(all_data, num_tokens, cur_idx, metadata) 131 | 132 | with open(datasets_to_delete, "a") as f: 133 | f.write(src_file + "\n") 134 | print(print_id, src_file, "added to delete list") 135 | 136 | # pass states to next file 137 | return all_data, num_tokens, cur_idx, metadata 138 | 139 | 140 | def do_parts(src_folder, src_files, max_part: int): 141 | """Process all parts (each part is a tokenized dataset generated by ONE thread in `tokenize_text.py`) in one folder.""" 142 | 143 | last_states = None 144 | sort_files = sorted(src_files, key=lambda x: int(x.split("-")[-1].split(".")[0])) 145 | length = len(sort_files) 146 | for idx, src_file in enumerate(sort_files): 147 | last_states = process_file(src_folder, src_file, last_states=last_states, is_last=(idx == length - 1), max_part=max_part, print_id=os.getpid()) 148 | 149 | 150 | def process_dataset(fd): 151 | datasets = os.listdir(fd) 152 | folder2file = {} 153 | for dataset_name in tqdm(datasets): 154 | raw_src_folder = os.path.join(fd, dataset_name) 155 | print("Finding intermediate results in {} ...".format(raw_src_folder)) 156 | 157 | try: 158 | for root_dir, _, files in os.walk(raw_src_folder, topdown=False): 159 | max_part = max([int(fp.split("-")[-1].split(".")[0]) for fp in files if "splitted_part" in fp and "metadata" not in fp], default=-1) 160 | for fp in files: 161 | if "sort" in fp or "splitted_part" in fp: 162 | continue 163 | if not fp.endswith(".jsonl") and not fp.endswith(".parquet"): 164 | continue 165 | if root_dir not in folder2file: 166 | folder2file[root_dir] = ([], max_part) 167 | folder2file[root_dir][0].append(os.path.join(root_dir, fp)) 168 | 169 | except FileNotFoundError as e: 170 | print("Error Dataset: {} ({})".format(dataset_name, e)) 171 | continue 172 | except NotADirectoryError as e: 173 | print("Error Dataset: {} ({})".format(dataset_name, e)) 174 | continue 175 | 176 | if len(folder2file) == 0: 177 | print("Error Dataset: {} (len(folder2file) == 0)".format(dataset_name)) 178 | continue 179 | 180 | # process all files in parallel 181 | folder_n = len(folder2file) 182 | p = mp.Pool(32) 183 | for idx, (src_folder, (src_files, max_part)) in enumerate(folder2file.items()): 184 | print(f"Splitting {idx + 1} / 
{folder_n}", src_folder, len(src_files)) 185 | p.apply_async(do_parts, args=(src_folder, src_files, max_part)) 186 | p.close() 187 | p.join() 188 | 189 | warn(f"finished {raw_src_folder}") 190 | 191 | 192 | if __name__ == "__main__": 193 | try: 194 | for fd in father_datasets: 195 | process_dataset(fd) 196 | except (Exception, KeyboardInterrupt) as e: 197 | warn(f"Early abortion. Please delete manully files in {datasets_to_delete}") 198 | raise e 199 | -------------------------------------------------------------------------------- /pretrain/preprocess/tokenize/tokenize_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing as mp 4 | import os 5 | import pathlib 6 | import random 7 | import re 8 | import signal 9 | from copy import deepcopy 10 | 11 | import numpy as np 12 | import pyarrow 13 | from pyarrow import parquet as pq 14 | from tqdm import tqdm, trange 15 | 16 | from transformers import AutoTokenizer 17 | 18 | random.seed(42) 19 | 20 | # Max line per file to tokenize. In case of OOM 21 | MAX_DATA = int(1e7) 22 | 23 | # replace raw data path with input_ids path 24 | raw_data_prefix = os.environ["RAW_DATA_PREFIX"] 25 | input_ids_prefix = os.environ["INPUT_IDS_PREFIX"] 26 | 27 | SKIP_TOKENIZATION_EXTENTIONS = {".py", ".git", ".md", ".png", ".jpg"} 28 | 29 | 30 | def get_tgt_folder(file_path, model_name): 31 | """Each jsonl or parquet file will generate a folder with the same name.""" 32 | 33 | # token id folder directory 34 | file_path = file_path.replace(raw_data_prefix, 35 | input_ids_prefix) 36 | 37 | # remove the file extension 38 | tgt_folder = file_path[:file_path.rfind(".")] 39 | tgt_folder = os.path.join(tgt_folder, "wo_ppl") 40 | if os.path.exists(tgt_folder) == True: 41 | is_exists = True 42 | else: 43 | is_exists = False 44 | pathlib.Path(tgt_folder).mkdir(parents=True, exist_ok=True) 45 | return tgt_folder, is_exists 46 | 47 | 48 | def warn(msg): 49 | print("\033[0;33m" + str(msg) + "\033[0m") 50 | 51 | 52 | def clean_fn(text: str) -> str: 53 | """Data cleaning function. Important notice: this function applies to ALL the text data.""" 54 | if not isinstance(text, str): 55 | warn(f"Type Error: {type(text)} {str(text)[:10]}...") 56 | text = str(text) 57 | 58 | text = text.replace("\u3000", " ") # remove wide space 59 | 60 | return text 61 | 62 | 63 | def tokenize_text(dataset, 64 | src_folder, 65 | file_nos, 66 | tgt_folder, 67 | idx, 68 | text_key, 69 | is_first, 70 | skip_exists: bool = False): 71 | tgt_path = os.path.join(tgt_folder, "part-{}.jsonl".format(idx)) 72 | if is_first == False: 73 | write_mode = "a" 74 | else: 75 | if skip_exists and os.path.exists(tgt_path): 76 | warn(f"skip tokenizing {tgt_path}") 77 | return 78 | write_mode = "w" 79 | 80 | batch_size = 1000 81 | with open(tgt_path, write_mode) as fout: 82 | for batch_st in tqdm(range(0, len(dataset), batch_size)): 83 | batch_data = dataset[batch_st:batch_st + batch_size] 84 | batch_file_nos = file_nos[batch_st:batch_st + batch_size] 85 | input_ids = tokenizer([clean_fn(data[text_key]) for data in batch_data], 86 | add_special_tokens=False)["input_ids"] 87 | for ipts, no in zip(input_ids, batch_file_nos): 88 | new_data = {"input_ids": ipts, "source": f"{src_folder}:{no}"} 89 | fout.write(json.dumps(new_data, ensure_ascii=False) + "\n") 90 | 91 | 92 | wanna_exit = False 93 | 94 | 95 | def interrupt_handler(signum, frame, ask=True): 96 | print("Ctrl+C pressed. 
Waiting for the current process to be finished.") 97 | global wanna_exit 98 | wanna_exit = True 99 | 100 | 101 | def start_mp(dataset, is_first, src_folder, file_nos): 102 | """dataset: List[Dict[str, str]]""" 103 | 104 | if len(dataset) == 0: 105 | warn("len(dataset) == 0") 106 | return 107 | if not isinstance(dataset, list): 108 | warn("not isinstance(dataset, list)") 109 | return 110 | try: 111 | assert args.text_key in dataset[0] 112 | text_key = args.text_key 113 | except AssertionError: 114 | warn(f"Available Keys: {dataset[0].keys()}") 115 | raise Exception("Unknown Key!") 116 | 117 | seed = random.random() 118 | def sample_seed(): 119 | return seed 120 | 121 | # shuffle again 122 | # random.shuffle(dataset, sample_seed) 123 | # random.shuffle(file_nos, sample_seed) 124 | random.shuffle(dataset) 125 | random.shuffle(file_nos) 126 | 127 | 128 | part_num = args.num_worker 129 | slice_idx = np.linspace(0, len(dataset), part_num + 1).astype("int") 130 | p = mp.Pool(part_num) 131 | for start_id in range(part_num): 132 | start, end = slice_idx[start_id], slice_idx[start_id + 1] 133 | new_lines = dataset[start:end] 134 | p.apply_async(tokenize_text, 135 | args=(new_lines, src_folder, file_nos, tgt_folder, start_id, text_key, 136 | is_first)) 137 | p.close() 138 | p.join() 139 | print("All of the child processes over!") 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--tokenizer_path", type=str) 145 | parser.add_argument("--model_name", type=str) 146 | parser.add_argument("--data_path", type=str) 147 | parser.add_argument("--num_files", type=int) 148 | parser.add_argument("--text_key", type=str) 149 | parser.add_argument("--num_worker", type=int) 150 | parser.add_argument("--skip_exist", type=bool, default=False) 151 | parser.add_argument("--skip_exists", type=bool, default=False) 152 | args = parser.parse_args() 153 | 154 | # load tokenizer 155 | kwargs = {} 156 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, **kwargs) 157 | 158 | # register signal handler 159 | signal.signal(signal.SIGINT, interrupt_handler) 160 | 161 | # start tokenization 162 | for root, _, files in os.walk(args.data_path, topdown=False): 163 | step = 0 164 | random.shuffle(files) 165 | for fp in tqdm(files): 166 | if wanna_exit: 167 | print("Tokenization done.") 168 | break 169 | 170 | file_path = os.path.join(root, fp) 171 | 172 | # check file extention 173 | skip_tokenization = False 174 | for ext in SKIP_TOKENIZATION_EXTENTIONS: 175 | if file_path.endswith(ext): 176 | skip_tokenization = True 177 | break 178 | if skip_tokenization: 179 | continue 180 | 181 | # check target folder existance 182 | tgt_folder, is_exists = get_tgt_folder(file_path, args.model_name) 183 | if is_exists and args.skip_exist: 184 | warn(f"skip {fp}") 185 | continue 186 | 187 | print("Tokenizing {}".format(file_path)) 188 | print("Target Folder: {}".format(tgt_folder)) 189 | 190 | fin = open(file_path, "r") 191 | is_jsonl = False 192 | 193 | # this is shit code 194 | if os.path.exists(file_path + "/dataset_info.json"): 195 | import datasets 196 | ds = datasets.load_from_disk(file_path, streaming=True) 197 | started = 0 198 | for i in trange(MAX_DATA, desc="Reading Data"): 199 | try: 200 | # get dataset & line number 201 | dataset = [next(ds) for _ in range(320000)] 202 | file_nos = [started + i for i in range(len(dataset))] 203 | 204 | start_mp(dataset, True, file_path, file_nos) 205 | 206 | started += len(dataset) 207 | step = step + 1 208 | if step >= 
args.num_files: 209 | break 210 | except StopIteration: 211 | break 212 | 213 | if file_path.endswith(".json") == True: 214 | try: 215 | # get dataset & line number 216 | dataset = json.load(fin) 217 | file_nos = [i for i in range(len(dataset))] 218 | 219 | start_mp(dataset, True, file_path, file_nos) 220 | step = step + 1 221 | if step >= args.num_files: 222 | break 223 | continue 224 | except json.decoder.JSONDecodeError: 225 | is_jsonl = True 226 | fin.close() 227 | # reopen for jsonl 228 | fin = open(file_path, "r") 229 | 230 | if file_path.endswith(".jsonl") == True or is_jsonl == True: 231 | is_finish = False 232 | is_first = True 233 | started = 0 234 | while True: 235 | # get dataset 236 | dataset = [] 237 | for i in trange(MAX_DATA, desc="Reading Data"): 238 | tmp_data = fin.readline() 239 | if not tmp_data: 240 | is_finish = True 241 | break 242 | try: 243 | tmp_data = json.loads(tmp_data) 244 | dataset.append(tmp_data) 245 | except json.decoder.JSONDecodeError as e: 246 | warn(str(e) + tmp_data) 247 | continue 248 | 249 | # get line number 250 | file_nos = [started + i for i in range(len(dataset))] 251 | start_mp(dataset, is_first, file_path, file_nos) 252 | is_first = False # append mode 253 | if is_finish == True: 254 | break 255 | elif file_path.endswith(".parquet"): 256 | try: 257 | # get dataset & line number 258 | table = pq.read_table(file_path) 259 | file_nos = [i for i in range(len(table))] 260 | start_mp(table.to_pylist(), True, file_path, file_nos) 261 | except pyarrow.lib.ArrowInvalid as e: 262 | warn(str(e)) 263 | continue 264 | else: 265 | continue 266 | 267 | fin.close() 268 | step = step + 1 269 | if step >= args.num_files: 270 | break 271 | -------------------------------------------------------------------------------- /pretrain/scripts/calc_norm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | import torch 5 | from safetensors import safe_open 6 | 7 | 8 | def calc_norm(model_path: str): 9 | 10 | with safe_open(f"{model_path}/model.safetensors", framework="pt") as f: 11 | for k in f.keys(): 12 | v = f.get_tensor(k) 13 | vnorm = torch.norm(v).item() 14 | vnum = torch.numel(v) 15 | print(k, vnorm, vnorm / vnum, vnorm / math.sqrt(vnum)) 16 | 17 | 18 | if __name__ == "__main__": 19 | calc_norm(sys.argv[1]) 20 | -------------------------------------------------------------------------------- /pretrain/scripts/convert_yulanmini_to_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import shutil 4 | import sys 5 | from collections import defaultdict 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig, 10 | LlamaForCausalLM) 11 | 12 | 13 | 14 | def rebalance_weights2(model_path, method): 15 | target_model_path = model_path + "-" + method 16 | shutil.copytree(model_path, target_model_path, dirs_exist_ok=True) # copy includes optimizer 17 | 18 | source_model = AutoModelForCausalLM.from_pretrained(target_model_path, trust_remote_code=True) 19 | 20 | if os.path.exists(target_model_path + "/model.safetensors"): 21 | os.remove(target_model_path + "/model.safetensors") # prepare for save_pretrained 22 | 23 | if os.path.exists(target_model_path + "/model.safetensors.index.json"): 24 | os.remove(target_model_path + "/model.safetensors.index.json") 25 | os.remove(target_model_path + "/model-00001-of-00002.safetensors") 26 | os.remove(target_model_path + 
"/model-00002-of-00002.safetensors") 27 | 28 | target_config = LlamaConfig( 29 | attention_bias=True, 30 | attention_dropout=source_model.config.attention_dropout, 31 | bos_token_id=source_model.config.bos_token_id, 32 | eos_token_id=source_model.config.eos_token_id, 33 | head_dim=source_model.config.hidden_size // source_model.config.num_attention_heads, 34 | hidden_act=source_model.config.hidden_act, 35 | hidden_size=source_model.config.hidden_size, 36 | initializer_range=source_model.config.initializer_range, 37 | intermediate_size=source_model.config.intermediate_size, 38 | max_position_embeddings=source_model.config.max_position_embeddings, 39 | mlp_bias=False, 40 | num_attention_heads=source_model.config.num_attention_heads, 41 | num_hidden_layers=source_model.config.num_hidden_layers, 42 | num_key_value_heads=source_model.config.num_key_value_heads, 43 | pretraining_tp=1, 44 | rms_norm_eps=source_model.config.rms_norm_eps, 45 | rope_scaling=None, 46 | rope_theta=source_model.config.rope_theta, 47 | tie_word_embeddings=False, 48 | torch_dtype=torch.float32, 49 | use_cache=True, 50 | vocab_size=source_model.config.vocab_size, 51 | ) 52 | 53 | state_dict = source_model.state_dict() 54 | state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] * source_model.config.scale_emb 55 | for i in range(source_model.config.num_hidden_layers): 56 | state_dict[f"model.layers.{i}.self_attn.o_proj.bias"] = torch.zeros((source_model.config.hidden_size,), dtype=state_dict[f"model.layers.{i}.mlp.down_proj.weight"].dtype) 57 | state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] * source_model.config.scale_depth / math.sqrt(source_model.config.num_hidden_layers) 58 | state_dict[f"model.layers.{i}.mlp.down_proj.weight"] = state_dict[f"model.layers.{i}.mlp.down_proj.weight"] * source_model.config.scale_depth / math.sqrt(source_model.config.num_hidden_layers) 59 | 60 | target_model = LlamaForCausalLM(target_config) 61 | # target_model = source_model 62 | target_model.load_state_dict(state_dict) 63 | 64 | target_model = target_model.to(torch.bfloat16) 65 | target_model.save_pretrained(target_model_path) 66 | print(target_model_path) 67 | 68 | if __name__ == "__main__": 69 | rebalance_weights2(sys.argv[1], method="llama") 70 | -------------------------------------------------------------------------------- /pretrain/scripts/estimate_mfu.py: -------------------------------------------------------------------------------- 1 | # Estimate Model FLOPs Utilization of YuLan-Mini stable training stage 2 | 3 | D = 25 * 40 * 10 ** 9 4 | 5 | N1 = 56 6 | t1 = 10 * 28 * 60 * 60 # 10 stages, 23 hours/stage 7 | 8 | N2 = 48 # shrink the cluster size 9 | t2 = 15 * 32 * 60 * 60 # 15 stages, 32 hours/stage 10 | 11 | T = D / (N1 * t1 + N2 * t2) 12 | print("T =", T) 13 | 14 | C = 312 * 10 ** 12 # A800 GPU chips 15 | B = 1008 # = 56 * 18 = 46 * 21 16 | s = 4096 # seq length 17 | l = 56 # num hidden layers 18 | h = 1920 # hidden size 19 | f = 4800 # intermediate size 20 | V = 99000 # vocab size 21 | 22 | E = 8 * B * s * l * h ** 2 + 6 * B * s * l * h * f + 4 * B * s ** 2 * l * h 23 | F = 3 * E + 4 * B * s ** 2 * l * h + 6 * B * s * h * V 24 | 25 | print("F =", F) 26 | 27 | MFU = F * T / B / s / C 28 | print("MFU =", MFU) 29 | -------------------------------------------------------------------------------- /pretrain/setup.sh: -------------------------------------------------------------------------------- 1 | # setup env on each node in the slurm 
job 2 | 3 | LOG_PREFIX=log/"$SLURM_JOB_NAME-$SLURM_JOB_ID" 4 | LOG_DIR=/home/u20140041/pretrain-mini/${LOG_PREFIX} 5 | echo $(date +%Y-%m-%d-%H:%M:%S) > $LOG_FILE 6 | echo Setup hostname: $(hostname) >> $LOG_FILE 7 | LOG_FILE=/home/u20140041/pretrain-mini/${LOG_PREFIX}/part0.log 8 | echo "========================" >> $LOG_FILE 9 | FILES_TO_LOG=($0 train.py train_utils.py model/modeling_miniyulan.py model/configuration_miniyulan.py torchrun_wrapper.sh) 10 | mkdir -p $LOG_DIR/artifacts 11 | for file in ${FILES_TO_LOG[@]}; do 12 | echo $file >> $LOG_FILE 13 | cat $file >> $LOG_FILE 14 | cat $file >> $LOG_DIR/artifacts/$(echo $file | tr '/' '-') 15 | echo "========================" >> $LOG_FILE 16 | done 17 | 18 | set -x 19 | 20 | source ~/.bashrc 21 | source .venv/bin/activate # venvbashrc 22 | 23 | # 传递参数 24 | FETCH_TIME=$1 # 没有默认值,需要在 submit_to_slurm.sh 中填写 25 | PER_DEVICE_TRAIN_BATCH_SIZE=$2 # 默认值为 18(对应 7 节点) 26 | DATASET_MODEL_NAME=$3 # 默认值为 myl 27 | 28 | # 计算相关环境变量 29 | NNODES=$SLURM_JOB_NUM_NODES 30 | export WORLD_SIZE=$(expr $NNODES \* 8) 31 | hostnames=$(scontrol show hostnames $SLURM_JOB_NODELIST) 32 | comma_hostnames=$(echo $hostnames | tr ' ' ',') 33 | export MASTER_ADDR=$(echo $hostnames | cut -d ' ' -f 1) # MASTER节点对应RANK 0 34 | MASTER_ADDR=$(getent ahosts $MASTER_ADDR | awk '{ print $1 }' | tail -n 1) 35 | JOB_NAME=$SLURM_JOB_NAME 36 | JOB_ID=$SLURM_JOB_ID 37 | export MASTER_PORT=$(expr 11450 + $(expr $RANDOM % 10000)) # 随机选择一个端口 38 | 39 | trap 'cleanup' SIGTERM # handle scancel gracefully 40 | 41 | # cleanup 函数:在捕获到 SIGTERM 信号时,清理所有由 pdsh 启动的远程进程 42 | cleanup() { 43 | echo "Received SIGTERM at $(date +%Y-%m-%d-%H:%M:%S), cleaning up remote processes..." 44 | pdsh -w $comma_hostnames "kill \$(ps aux | grep '$SLURM_JOB_NAME-$SLURM_JOB_ID' | grep -v grep | awk '{print \$2}')" 45 | kill $(ps aux | grep '$SLURM_JOB_NAME-$SLURM_JOB_ID' | grep -v grep | awk '{print $2}') 46 | kill $(ps aux | grep '$SLURM_JOB_NAME $SLURM_JOB_ID' | grep -v grep | awk '{print $2}') 47 | curl -H "Content-Type: application/json" -X POST https://wxpusher.zjiecode.com/api/send/message --data '{"appToken": "xxx", "content": "canceled job '$SLURM_JOB_NAME-$SLURM_JOB_ID'", "topicIds": [32270]}' 48 | exit 15 49 | } 50 | 51 | ############################### 上面没有需要更改的地方 ############################### -------------------------------------------------------------------------------- /pretrain/synthesis/README.md: -------------------------------------------------------------------------------- 1 | # Data Synthesis 2 | 3 | This directory contains the scripts and prompts for data synthesis. 4 | 5 | 6 |8 | 9 | ## Preliminary 10 | 11 | ### SGLang 12 | 13 | We primarily use the [`sglang`](https://docs.sglang.ai/start/install.html) package to generate synthetic data. 14 | 15 | Then, choose the model you want to use for data synthesis. For example, we use `DeepSeek-Prover-V1.5` and `Qwen2.5-Math-Instruct-7B` to augument the Lean theorem proving dataset. 16 | 17 | ```bash 18 | CUDA_VISIBLE_DEVICES=0,1 python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Prover-V1.5-RL --port 30000 --trust-remote-code --dp 2 19 | ``` 20 | 21 | For those who run the model on a large cluster, you can install the [`sglang_router`](https://docs.sglang.ai/router/router.html) package to optimize the data parallel scheduling efficiency. 22 | 23 | ```bash 24 | pip install sglang-router 25 | ``` 26 | 27 | ### vLLM 28 | 29 | We also use the [`vLLM`](https://docs.vllm.ai/) package to generate synthetic data (on Ascend 910B NPU). 
30 | 31 | ```bash 32 | python gen_vllm.py --input_file_path input.jsonl --output_file_path output.jsonl 33 | ``` 34 | 35 | ## Prompts 36 | 37 | We have publish the prompts we used for data synthesis in our technical report. We will organize the synthesis pipeline soon. 38 | -------------------------------------------------------------------------------- /pretrain/synthesis/gen_lean_reasoning.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import time 5 | from random import random, sample 6 | from typing import Tuple 7 | 8 | import datasets 9 | import pandas as pd 10 | import sglang as sgl 11 | from sglang import (RuntimeEndpoint, assistant, function, gen, 12 | set_default_backend, system, user) 13 | from tqdm import tqdm 14 | 15 | # 设置默认的运行时端点 16 | set_default_backend(RuntimeEndpoint("http://localhost:30000")) 17 | 18 | 19 | # Deepseek-Prover-V1 20 | @function 21 | def analyze_deepseek(s, natural_language_statement, formal_statement, state_before, state_after, tactic, explanation="", **kwargs) -> str: 22 | if os.path.exists("/home/huyiwen/monorepo/projects/stop_signal"): 23 | return None 24 | 25 | input_template = """I am a mathematician unfamiliar with Lean. Please explain the tactics used in a proof, as if you are in the process of trying to prove a theorem and haven't yet completed it. Explain the reasoning and logic behind choosing those specific tactics. 26 | 27 | **Statement:** 28 | {natural_language_statement} 29 | ```lean4 30 | {formal_statement} 31 | ``` 32 | 33 | **Current state:** 34 | ```lean4 35 | {state_before} 36 | ``` 37 | 38 | **Proof:** 39 | ```lean4 40 | {tactic} 41 | ``` 42 | """ 43 | 44 | assistant_prefix = """**Reasoning:** 45 | {explanation}""" 46 | 47 | s += user(input_template.format( 48 | natural_language_statement=natural_language_statement, 49 | formal_statement=formal_statement, 50 | state_before=state_before, 51 | tactic=tactic, 52 | )) 53 | 54 | s += assistant(assistant_prefix.format(explanation="") + gen("explanation", max_tokens=600)) 55 | return None 56 | 57 | 58 | # Lean-Github 59 | @function 60 | def analyze_github(s, state_before, tactic, state_after, **kwargs) -> str: 61 | if os.path.exists("/home/huyiwen/monorepo/projects/stop_signal"): 62 | return None 63 | 64 | input_template = """I am a mathematician unfamiliar with Lean. Please explain the tactics used in a proof, as if you are in the process of trying to prove a theorem and haven't yet completed it. Explain the reasoning and logic behind choosing those specific tactics. 
65 | 66 | **Current state:** 67 | ```lean4 68 | {state_before} 69 | ``` 70 | 71 | **Proof:** 72 | ```lean4 73 | {tactic} 74 | ``` 75 | """ 76 | 77 | assistant_prefix = """**State after:** 78 | {state_after} 79 | 80 | **Reasoning:** 81 | """ 82 | 83 | s += user(input_template.format( 84 | state_before=state_before, 85 | tactic=tactic, 86 | )) 87 | 88 | s += assistant(assistant_prefix.format(state_after=state_after) + gen("explanation", max_tokens=600)) 89 | return None 90 | 91 | 92 | # Lean-Workbook 93 | # State Before + Tactic -> State After 94 | @function 95 | def analyze_workbook_a(s, natural_language_statement, formal_statement, state_before, state_after, tactic, explanation="", **kwargs) -> str: 96 | if os.path.exists("/home/huyiwen/monorepo/projects/stop_signal"): 97 | return None 98 | 99 | input_template = """Given a Lean tactic at a intermediate step in a proof and the goal state before the tactic, predict the resulting goal state after the tactic's application and provide a detailed explanation. You do not need to consider whether the tactic is sufficient to complete the proof; simply explain why the goal state changes to your predicted state after the tactic's execution. 100 | 101 | **Statement:** 102 | {natural_language_statement} 103 | ```lean4 104 | {formal_statement} 105 | ``` 106 | 107 | **Goal state before:** 108 | ```lean4 109 | {state_before} 110 | ``` 111 | 112 | **Tactic to execute:** 113 | ```lean4 114 | {tactic} 115 | ``` 116 | """ 117 | 118 | assistant_prefix = """**State after:** 119 | {state_after} 120 | 121 | **Explanation:** 122 | """ 123 | 124 | s += user(input_template.format( 125 | natural_language_statement=natural_language_statement, 126 | formal_statement=formal_statement, 127 | state_before=state_before, 128 | tactic=tactic, 129 | )) 130 | 131 | s += assistant(assistant_prefix.format(state_after=state_after) + gen("explanation", max_tokens=600)) 132 | return None 133 | 134 | 135 | # State After + Tactic -> State Before 136 | @function 137 | def analyze_workbook_b(s, natural_language_statement, formal_statement, state_before, state_after, tactic, explanation="", **kwargs) -> str: 138 | if os.path.exists("/home/huyiwen/monorepo/projects/stop_signal"): 139 | return None 140 | 141 | input_template = """Given a tactic applied at an intermediate step of a Lean proof and the resulting goal state **after** applying the tactic, predict one possible goal state **before** the tactic was applied, and provide a detailed explanation You don't need to consider whether the tactic is sufficient to complete the proof; simply explain why the **pre-tactic goal state** would have resulted in the given post-tactic state. 
142 | 143 | **Statement:** 144 | {natural_language_statement} 145 | ```lean4 146 | {formal_statement} 147 | ``` 148 | 149 | **Tactic applied:** 150 | ```lean4 151 | {tactic} 152 | ``` 153 | 154 | **Resulting state after:** 155 | ```lean4 156 | {state_after} 157 | ``` 158 | """ 159 | 160 | assistant_prefix = """**Goal state before:** 161 | {state_before} 162 | 163 | **Explanation:** 164 | """ 165 | 166 | s += user(input_template.format( 167 | natural_language_statement=natural_language_statement, 168 | formal_statement=formal_statement, 169 | state_after=state_after, 170 | tactic=tactic, 171 | )) 172 | 173 | s += assistant(assistant_prefix.format(state_before=state_before) + gen("explanation", max_tokens=600)) 174 | return None 175 | 176 | 177 | 178 | @function 179 | def analyze_workbook(s, natural_language_statement, formal_statement, state_before, tactic, state_after, **kwargs) -> str: 180 | if os.path.exists("/home/huyiwen/monorepo/projects/stop_signal"): 181 | return None 182 | 183 | input_template = """Give the next tactic in the proof with explanatory comments. 184 | 185 | Statement: {natural_language_statement} 186 | 187 | ```lean4 188 | {formal_statement} 189 | ``` 190 | 191 | **Current state:** 192 | 193 | {state_before} 194 | """ 195 | 196 | assistant_prefix = """**Next tactic:** 197 | {tactic} 198 | /-State: 199 | {state_after}-/ 200 | 201 | **Explanatory comments:** 202 | """ 203 | 204 | s += user(input_template.format( 205 | natural_language_statement=natural_language_statement, 206 | formal_statement=formal_statement, 207 | state_before=state_before, 208 | tactic=tactic, 209 | state_after=state_after, 210 | )) 211 | 212 | s += assistant(assistant_prefix.format(tactic=tactic, state_after=state_after) + gen("explanation", max_tokens=400)) 213 | return None 214 | 215 | 216 | def analyze(lines, analyze_fn, name): 217 | 218 | # 使用batch处理多个文本 219 | states = analyze_fn.run_batch( 220 | lines, 221 | progress_bar=True, 222 | num_threads=256, 223 | temperature=0.3, 224 | top_p=0.4, 225 | ) 226 | 227 | answers = [] 228 | for line, state in zip(lines, states): 229 | # extract the explanation from the state 230 | try: 231 | line["explanation"] = state["explanation"] 232 | except Exception: 233 | line["explanation"] = "" 234 | # extract the stop reason from the state 235 | try: 236 | line["stop_reason"] = state.get_meta_info("explanation").get("finish_reason", {}).get("type", "") 237 | except: 238 | line["stop_reason"] = "" 239 | answers.append(line) 240 | 241 | print(f"/home/huyiwen/monorepo/projects/miniyulan/gen_lean/lean_explain_{name}.jsonl") 242 | with open(f"/home/huyiwen/monorepo/projects/miniyulan/gen_lean/lean_explain_{name}.jsonl", "w") as f: 243 | for line in answers: 244 | f.write(json.dumps(line) + "\n") 245 | 246 | 247 | def get_data(repo="workbook"): 248 | if repo == "workbook": # Not used 249 | lines = datasets.load_dataset("/home/huyiwen/lean-tactics/Lean-Workbook", split="train").to_list() 250 | return lines 251 | elif repo == "github": # Not used 252 | lean_github = pd.read_parquet('/home/huyiwen/lean-tactics/Lean-Github/lean-github.parquet') 253 | 254 | # dedup 255 | lean_github = lean_github.drop_duplicates(subset=['url', 'commit', 'file_path', 'start', 'end', 'tactic', 'state_before', 'state_after']) 256 | 257 | # convert string to real tuple 258 | lean_github['start'] = lean_github['start'].apply(lambda x: tuple(map(int, x[1:-1].split(',')))) 259 | lean_github['end'] = lean_github['end'].apply(lambda x: tuple(map(int, x[1:-1].split(',')))) 260 | return 
lean_github.to_dict(orient='records') 261 | elif repo == "deepseek": 262 | lines = datasets.load_dataset("/home/huyiwen/lean-tactics/DeepSeek-Prover-V1", split="train").to_list() 263 | return lines 264 | elif repo == "workbook-c": 265 | with open("/home/huyiwen/lean-tactics/Lean-Workbook/c.jsonl") as f: 266 | lines = [json.loads(line) for line in f] 267 | return lines 268 | elif repo == "workbook-a": 269 | with open("/home/huyiwen/lean-tactics/Lean-Workbook/a.jsonl") as f: 270 | lines = [json.loads(line) for line in f] 271 | return lines 272 | 273 | 274 | lines = get_data("github") 275 | analyze(lines, analyze_github, "github-" + time.strftime("%Y%m%d-%H%M%S")) 276 | -------------------------------------------------------------------------------- /pretrain/synthesis/gen_qwq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import time 6 | from copy import copy 7 | from random import random, sample 8 | from typing import Tuple 9 | 10 | import sglang as sgl 11 | from sglang import (RuntimeEndpoint, assistant, function, gen, 12 | set_default_backend, system, user) 13 | from tqdm import tqdm 14 | 15 | set_default_backend(RuntimeEndpoint("http://localhost:30000")) 16 | 17 | 18 | @function 19 | def analyze_text(s, problem: str, **kwargs) -> str: 20 | 21 | if os.path.exists("/home/huyiwen/miniyulan-ckpts/qwq_gen/stop_signal"): 22 | return "Stop signal detected." 23 | 24 | sys_prompt="You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." 25 | 26 | prompt = """Please think step by step to solve the following question, and put your final answer within \\boxed{{}}. 27 | 28 | {question}""" 29 | 30 | s += system(sys_prompt) 31 | s += user(prompt.format(question=problem)) 32 | s += assistant( gen("qwq_gen", max_tokens=16000, stop=['Human:']) ) 33 | 34 | 35 | def analyze(origin_jsonl_path): 36 | 37 | lines = [] 38 | with open(origin_jsonl_path, 'r') as file: 39 | for line in file: 40 | lines.append(json.loads(line)) 41 | # lines = lines[16:] 42 | print(len(lines)) 43 | 44 | # 使用batch处理多个文本 45 | states = analyze_text.run_batch( 46 | lines, 47 | progress_bar=True, 48 | num_threads=16, 49 | temperature=0, 50 | ) 51 | 52 | llama_classify_file = origin_jsonl_path.replace(".jsonl", f"-qwq_generated-{time.strftime('%Y%m%d%H%M%S')}.jsonl") 53 | with open(llama_classify_file, "a") as f: 54 | for line, state in zip(lines, states): 55 | obj = copy(line) 56 | 57 | try: 58 | obj["qwq_gen"] = state["qwq_gen"] 59 | except Exception as e: 60 | # print(e) 61 | obj["qwq_gen"] = "" 62 | 63 | try: 64 | obj["qwq_gen_answer"] = state["qwq_gen_answer"] 65 | except Exception as e: 66 | # print(e) 67 | obj["qwq_gen_answer"] = "" 68 | 69 | try: 70 | obj["stop_reason"] = state.get_meta_info("qwq_gen").get("finish_reason", {}).get("type", "") 71 | except Exception as e: 72 | obj["stop_reason"] = str(e) 73 | 74 | f.write(json.dumps(obj) + "\n") 75 | 76 | return True 77 | 78 | 79 | if __name__ == "__main__": 80 | analyze(sys.argv[1]) 81 | -------------------------------------------------------------------------------- /pretrain/synthesis/gen_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | from vllm import LLM, SamplingParams 5 | from datasets import Dataset 6 | from transformers import AutoTokenizer 7 | 8 | 9 | def parse_args(): 10 | parse = argparse.ArgumentParser(description="gen") 11 | 
parse.add_argument("--input_file_path", type=str, default="", help="input_path") 12 | parse.add_argument("--output_path", type=str, default="", help="output_path") 13 | parse.add_argument("--start_index", type=int, default=None) 14 | parse.add_argument("--end_index", type=int, default=None) 15 | return parse.parse_args() 16 | 17 | def main(): 18 | 19 | args = parse_args() 20 | 21 | # Load JSONL file 22 | input_file_path = args.input_file_path 23 | output_path = args.output_path 24 | start_index = args.start_index 25 | end_index = args.end_index 26 | 27 | data = [] 28 | with open(input_file_path, "r", encoding="utf-8") as file: 29 | for line in file: 30 | data.append(json.loads(line)) 31 | 32 | # faciliate data parallelism 33 | if start_index is not None and end_index is not None: 34 | data = data[start_index:end_index] 35 | elif end_index is not None: 36 | data = data[:end_index] 37 | elif start_index is not None: 38 | data = data[start_index:] 39 | 40 | template = ( 41 | "## Instruction\nPlease gain inspiration from the following content to create a high-quality problem and solution. Present your output in two distinct sections: [Problem] and [Solution].\n\n" 42 | "## Content\n{text}\n" 43 | "## Guidelines \n[Problem]: This should be **completely self-contained**, providing all the contextual information one needs to understand and solve the problem.\n\n[Solution]: Present a comprehensive, step-by-step solution that solves the problem **correctly** and educates the student, around 250-350 words long. Clearly articulate the reasoning and methods used at each step, providing insight into the problem-solving process. Take care to format any equations properly using LaTeX or appropriate notation." 44 | ) 45 | 46 | prompts = [] 47 | for item in data: 48 | prompts.append(template.format(text=item["text"]) + " Please generate only one Problem and only one Solution, and when you finish generating the solution, end with the signal '7 |
'.") 49 | 50 | stop_tokens = [" "] 51 | sampling_params = SamplingParams(temperature=0.7, top_p=1.0, max_tokens=2048, stop=stop_tokens) 52 | 53 | llm = LLM(model="/data/Qwen2.5-7B-Instruct", tensor_parallel_size=1, gpu_memory_utilization=0.95, trust_remote_code=True) 54 | outputs = llm.generate(prompts, sampling_params) 55 | 56 | generated_texts = [] 57 | for i, output in enumerate(outputs): 58 | prompt = output.prompt 59 | generated_text = output.outputs[0].text 60 | generated_texts.append({"prompt":prompt,"output":generated_text}) 61 | 62 | 63 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 64 | with open(output_path, "w", encoding="utf-8") as json_file: 65 | json.dump(generated_texts, json_file, ensure_ascii=False, indent=4) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /pretrain/torchrun_wrapper.sh: -------------------------------------------------------------------------------- 1 | # 将本脚本所有输出重定向到文件log/$SLURM_JOB_NAME-$SLURM_JOB_ID/part$SLURM_PROCID.log: 2 | cd xxx 3 | comma_hostnames=$1 4 | shift 5 | PROCID=$(expr $(echo $comma_hostnames | tr "," "\n" | grep -n `hostname` | cut -c-1) - 1) # 仅适用9个节点以内 6 | SLURM_JOB_NAME=$1 7 | shift 8 | SLURM_JOB_ID=$1 9 | shift 10 | if [ -z "$PROCID" ]; then 11 | echo "torchrun_wrapper.sh: PROCID is empty, exit" 12 | exit 1 13 | fi 14 | if [ -z "$SLURM_JOB_NAME" ]; then 15 | echo "torchrun_wrapper.sh: SLURM_JOB_NAME is empty, exit" 16 | exit 1 17 | fi 18 | if [ -z "$SLURM_JOB_ID" ]; then 19 | echo "torchrun_wrapper.sh: SLURM_JOB_ID is empty, exit" 20 | exit 1 21 | fi 22 | echo "$(date +%Y-%m-%d %H:%M:%S) torchrun_wrapper.sh: SLURM_JOB_NAME=$SLURM_JOB_NAME, SLURM_JOB_ID=$SLURM_JOB_ID, PROCID=$PROCID; hostname=`hostname`" >> log/$SLURM_JOB_NAME-$SLURM_JOB_ID/part$PROCID.log 23 | exec &>> log/$SLURM_JOB_NAME-$SLURM_JOB_ID/part$PROCID.log 24 | 25 | export NCCL_NSOCKS_PERTHREAD=4 26 | export NCCL_SOCKET_NTHREADS=2 27 | export NCCL_MIN_CHANNELS=32 28 | 29 | source ~/.bashrc 30 | 31 | module load /opt/app/spack/share/spack/modules/gcc/11.3.0 32 | module load /opt/app/spack/share/spack/modules/cuda/12.5.1 33 | module load /opt/app/spack/share/spack/modules/libaio/0.3.113-gcc_13.1.0 34 | 35 | source .venv/bin/activate # venv 36 | 37 | # export NCCL_SOCKET_IFNAME=vpapvn # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html 38 | # export NCCL_IB_DISABLE=1 # https://github.com/NVIDIA/nccl/issues/451 39 | export LDFLAGS="-L/usr/lib64" 40 | export CFLAGS="-I/usr/include" 41 | export PYTHONPATH=. 
42 | export CUTLASS_PATH=~/cutlass 43 | export LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8 # https://stackoverflow.com/questions/74367207/segmentation-fault-core-dumped-when-launching-python-in-anaconda 44 | export OPENBLAS_NUM_THREADS=24 # https://stackoverflow.com/questions/52026652/openblas-blas-thread-init-pthread-create-resource-temporarily-unavailable 45 | export OMP_NUM_THREADS=24 # https://stackoverflow.com/questions/53351194/openmp-libgomp-thread-creation-failed-resource-temporarily-unavailable-when 46 | 47 | export WANDB_MODE=disabled 48 | export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 49 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 50 | 51 | # DEBUG 52 | export TRANSFORMERS_VERBOSITY=debug 53 | export NCCL_DEBUG=DEBUG # https://stackoverflow.com/questions/61075390/pytorch-nccl-error-unhandled-system-error-nccl-version-2-4-8 54 | export NCCL_DEBUG_SUBSYS=GRAPH # https://pytorch.org/docs/stable/distributed.html 55 | # export TORCH_LOGS=+all 56 | # export TORCH_DISTRIBUTED_DEBUG=DETAIL 57 | # export TORCH_CPP_LOG_LEVEL=INFO 58 | 59 | 60 | CACHE_PATH='/fs/archive/share/yulan/data/aa_hf_cache' 61 | export TMPDIR=${CACHE_PATH}/tmp 62 | export HF_DATASETS_CACHE=${CACHE_PATH}/hf_datasets_cache 63 | export HF_HOME=${CACHE_PATH}/hf_home 64 | mkdir -p ${CACHE_PATH} 65 | mkdir -p ${TMPDIR} 66 | mkdir -p ${HF_DATASETS_CACHE} 67 | mkdir -p ${HF_HOME} 68 | 69 | # Print all environment variables 70 | env 71 | 72 | # Log the launch command 73 | echo "torchrun_wrapper.sh: SLURM_JOB_NAME=$SLURM_JOB_NAME, SLURM_JOB_ID=$SLURM_JOB_ID, PROCID=$PROCID; hostname=`hostname`" 74 | echo -e "torchrun_wrapper.sh: torchrun --node_rank $PROCID $@\n" 75 | 76 | # Set -e so the script exits immediately if any command fails 77 | set -e 78 | 79 | # Set -o pipefail so a failure in any command of a pipeline fails the whole pipeline 80 | set -o pipefail 81 | 82 | torchrun --node_rank $PROCID $@ 83 | 84 | if [ $PROCID -eq 0 ]; then 85 | curl -H "Content-Type: application/json" -X POST https://wxpusher.zjiecode.com/api/send/message --data '{"appToken": "xxx", "content": "'$SLURM_JOB_NAME-$SLURM_JOB_ID' done ", "topicIds": [32270]}' 86 | fi 87 | -------------------------------------------------------------------------------- /pretrain/train.sh: -------------------------------------------------------------------------------- 1 | source ~/.bashrc 2 | 3 | # Submit the training jobs to SLURM 4 | 5 | # Option: --time=30:00:00  maximum wall-clock time 6 | # Option: --job-name=xxx   job name 7 | # Option: --nodes=1        use one node (remember to adjust the batch size!!!) 
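#
# Usage sketch (added for clarity; argument roles are inferred from the function bodies below):
#   main_train  <stage_script.sh> <FETCH_TIME> <RUN_REASON> [PER_DEVICE_TRAIN_BATCH_SIZE=18] [NNODES=7] [MODEL_NAME=myl_new_no_math]
#   decay_train <stage_script.sh> <FETCH_TIME> <RUN_REASON> [PER_DEVICE_TRAIN_BATCH_SIZE=18] [NNODES=7] [MODEL_NAME=myl_new_no_math]
# Both helpers materialize the HF dataset for FETCH_TIME if it is missing, submit the stage script via sbatch,
# and start a log monitor; decay_train additionally suffixes the job name with FETCH_TIME.
# Concrete invocations for every phase appear at the bottom of this file.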
8 | 9 | function decay_train() { 10 | # Save the dataset and launch a decay (annealing) training run 11 | SCRIPT=$1 12 | FETCH_TIME=$2 13 | if [[ ${#FETCH_TIME} -ne 18 ]]; then 14 | echo "FETCH_TIME has an invalid format: $FETCH_TIME" 15 | exit 1 16 | fi 17 | RUN_REASON=$3 18 | if [[ ${#RUN_REASON} -lt 10 ]]; then 19 | echo "RUN_REASON must be at least 10 characters: $RUN_REASON" 20 | exit 1 21 | fi 22 | PER_DEVICE_TRAIN_BATCH_SIZE=${4:-18} 23 | NNODES=${5:-7} 24 | MODEL_NAME=${6:-"myl_new_no_math"} 25 | JOB_NAME=$(basename $SCRIPT .sh)-$FETCH_TIME 26 | if [ -d /fs/archive/share/yulan/data/aa_mini/output/${JOB_NAME} ]; then 27 | echo "A checkpoint already exists and might be overwritten: /fs/archive/share/yulan/data/aa_mini/output/${JOB_NAME}" 28 | exit 1 29 | fi 30 | echo "JOB_NAME: $JOB_NAME" 31 | echo "Please verify the total BATCH_SIZE: $PER_DEVICE_TRAIN_BATCH_SIZE * $NNODES * 8 = $((PER_DEVICE_TRAIN_BATCH_SIZE * NNODES * 8))" 32 | echo "Equivalent to a BATCH_SIZE of $((PER_DEVICE_TRAIN_BATCH_SIZE * NNODES * 8 * 4096)) tokens" 33 | if [ -d /fs/archive/share/yulan/data/aa_mini/hf_dataset/$MODEL_NAME/$FETCH_TIME ]; then 34 | echo "Dataset already exists: /fs/archive/share/yulan/data/aa_mini/hf_dataset/$MODEL_NAME/$FETCH_TIME" 35 | else 36 | python preprocess/fetch_data/distributed_save.py $FETCH_TIME $MODEL_NAME 37 | fi 38 | 39 | JOB_ID=$(sbatch --time=36:00:00 --job-name=$JOB_NAME --nodes=$NNODES $SCRIPT $FETCH_TIME $PER_DEVICE_TRAIN_BATCH_SIZE $MODEL_NAME | grep -o -P '\d+') 40 | echo "JOB_ID: $JOB_ID" 41 | if [ -z $JOB_ID ]; then 42 | echo "Job submission failed" 43 | exit 1 44 | fi 45 | mkdir -p "log/$JOB_NAME-$JOB_ID" 46 | touch "log/$JOB_NAME-$JOB_ID/reason-$RUN_REASON" 47 | 48 | sleep 5 49 | nohup new_start_monitor $JOB_NAME $JOB_ID > "log/$JOB_NAME-$JOB_ID/monitor.log" 2>&1 & 50 | LOG_FILE="log/$JOB_NAME-$JOB_ID/part0.log" 51 | squeue -o "%.6i %.35j %t %8M %.R" 52 | exit 0 53 | } 54 | 55 | 56 | function main_train() { 57 | # Save the dataset and launch a main (stable-phase) training run 58 | SCRIPT=$1 59 | FETCH_TIME=$2 60 | if [[ ${#FETCH_TIME} -ne 18 ]]; then 61 | echo "FETCH_TIME has an invalid format: $FETCH_TIME" 62 | exit 1 63 | fi 64 | RUN_REASON=$3 65 | if [[ ${#RUN_REASON} -lt 10 ]]; then 66 | echo "RUN_REASON must be at least 10 characters: $RUN_REASON" 67 | exit 1 68 | fi 69 | PER_DEVICE_TRAIN_BATCH_SIZE=${4:-18} 70 | NNODES=${5:-7} 71 | MODEL_NAME=${6:-"myl_new_no_math"} 72 | JOB_NAME=$(basename $SCRIPT .sh) 73 | if [ -d /fs/archive/share/yulan/data/aa_mini/output/${JOB_NAME} ]; then 74 | echo "A checkpoint already exists and might be overwritten: /fs/archive/share/yulan/data/aa_mini/output/${JOB_NAME}" 75 | exit 1 76 | fi 77 | echo "JOB_NAME: $JOB_NAME" 78 | echo "Please verify the total BATCH_SIZE: $PER_DEVICE_TRAIN_BATCH_SIZE * $NNODES * 8 = $((PER_DEVICE_TRAIN_BATCH_SIZE * NNODES * 8))" 79 | echo "Equivalent to a BATCH_SIZE of $((PER_DEVICE_TRAIN_BATCH_SIZE * NNODES * 8 * 4096)) tokens" 80 | if [ -d /fs/archive/share/yulan/data/aa_mini/hf_dataset/$MODEL_NAME/$FETCH_TIME ]; then 81 | echo "Dataset already exists: /fs/archive/share/yulan/data/aa_mini/hf_dataset/$MODEL_NAME/$FETCH_TIME" 82 | else 83 | python preprocess/fetch_data/distributed_save.py $FETCH_TIME $MODEL_NAME 84 | fi 85 | 86 | JOB_ID=$(sbatch --time=36:00:00 --job-name=$JOB_NAME --nodes=$NNODES $SCRIPT $FETCH_TIME $PER_DEVICE_TRAIN_BATCH_SIZE $MODEL_NAME | grep -o -P '\d+') 87 | echo "JOB_ID: $JOB_ID" 88 | if [ -z $JOB_ID ]; then 89 | echo "Job submission failed" 90 | exit 1 91 | fi 92 | mkdir -p "log/$JOB_NAME-$JOB_ID" 93 | touch "log/$JOB_NAME-$JOB_ID/reason-$RUN_REASON" 94 | 95 | sleep 5 96 | nohup new_start_monitor $JOB_NAME $JOB_ID > "log/$JOB_NAME-$JOB_ID/monitor.log" 2>&1 & 97 | LOG_FILE="log/$JOB_NAME-$JOB_ID/part0.log" 98 | squeue -o "%.6i %.35j %t %8M %.R" 99 | exit 0 100 | } 101 | 102 | 103 | # Note: Due to subsequent modifications to the training code, this launch script 
may require re-adaptation. 104 | 105 | main_train yulanmini-2B-final-phase1.sh 20241017_013512 "2B-model-phase1,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 106 | 107 | main_train yulanmini-2B-final-phase2.sh 02_20241017_013401 "2B-model-phase2,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 108 | 109 | main_train yulanmini-2B-final-phase3.sh 03_20241020_001556 "2B-model-phase3,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 110 | 111 | main_train yulanmini-2B-final-phase4.sh 04_20241021_170901 "2B-model-phase4,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 112 | 113 | main_train yulanmini-2B-final-phase5.sh 05_20241022_221453 "2B-model-phase5,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 114 | 115 | main_train yulanmini-2B-final-phase6.sh 06_20241024_013137 "2B-model-phase6,lm_head_alpha=1+deepspeed1+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 116 | 117 | main_train yulanmini-2B-final-phase7-dp2.sh 07_20241025_022032 "2B-model-phase7,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 118 | 119 | main_train yulanmini-2B-final-phase8.sh 08_20241026_151354 "2B-model-phase8,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 120 | 121 | main_train yulanmini-2B-final-phase9.sh 09_20241027_190948 "2B-model-phase9,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 122 | 123 | main_train yulanmini-2B-final-phase10.sh 10_20241028_225112 "2B-model-phase10,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 18 7 myl_new_no_math 124 | 125 | main_train yulanmini-2B-final-phase11.sh 11_20241030_124814 "2B-model-phase11,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 126 | 127 | main_train yulanmini-2B-final-phase12.sh 12_20241101_002827 "2B-model-phase12,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 128 | 129 | main_train yulanmini-2B-final-phase13.sh 13_20241102_160534 "2B-model-phase13,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 130 | 131 | main_train yulanmini-2B-final-phase14.sh 14_20241104_000454 "2B-model-phase14,lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 132 | 133 | main_train yulanmini-2B-final-phase15.sh 15_20241105_023029 "2B-model-phase15, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 134 | 135 | main_train yulanmini-2B-final-phase16.sh 16_20241106_180613 "2B-model-phase16, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 136 | 137 | main_train yulanmini-2B-final-phase17.sh 17_20241108_004951 "2B-model-phase17, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_new_no_math 138 | 139 | main_train yulanmini-2B-final-phase18-hyw.sh 18_20241113_034017 "2B-model-phase18-remake, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 140 | 141 | main_train yulanmini-2B-final-phase19-hyw.sh 19_20241114_115241 "2B-model-phase19-remake, 
lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 142 | 143 | main_train yulanmini-2B-final-phase20-remake.sh 20_20241115_234357 "2B-model-phase20-remake, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 144 | 145 | main_train yulanmini-2B-final-phase21.sh 21_20241117_021115 "2B-model-phase21, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 146 | 147 | main_train yulanmini-2B-final-phase22.sh 22_20241118_155407 "2B-model-phase22, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 148 | 149 | main_train yulanmini-2B-final-phase23.sh 23_20241120_033942 "2B-model-phase23, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 150 | 151 | main_train yulanmini-2B-final-phase24.sh 24_20241121_133110 "2B-model-phase23, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 152 | 153 | main_train yulanmini-2B-final-phase25.sh 25_20241123_030124 "2B-model-phase23, lm_head_alpha=1+deepspeed2+norm_alpha=True+rms_type=llama+emb_alpha=False, " 21 6 myl_mix890 154 | 155 | decay_train yulanmini-2B-s25d-decay80-1sqrt-long-28k-final-phase26.sh 26_20241211_015209 "decay-80B-phase26 " 26 5 myl_mix890_long_28k 156 | 157 | decay_train yulanmini-2B-s25d-decay80-1sqrt-long-28k-final-phase27.sh 27_20241213_051741 "decay-80B-phase27 " 26 5 myl_mix890_long_28k 158 | -------------------------------------------------------------------------------- /pretrain/train_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from typing import Dict, Union 5 | 6 | import datasets 7 | import torch 8 | import transformers 9 | import wandb 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from torch.utils.data import DataLoader, SequentialSampler 12 | from transformers import Trainer, TrainerCallback 13 | from transformers.trainer_utils import seed_worker 14 | from transformers.utils import is_datasets_available 15 | 16 | LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) 17 | RANK = int(os.getenv("RANK", "0")) 18 | WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) 19 | 20 | 21 | def print_rank0(*arg): 22 | if RANK == 0: 23 | print(*arg) 24 | 25 | 26 | class LogCallback(TrainerCallback): 27 | 28 | def on_log(self, args, state, control, model, logs=None, **kwargs): 29 | logs["train/global_step"] = state.global_step 30 | logs["train/epoch"] = state.epoch 31 | logs['train/total_flos'] = state.total_flos 32 | wandb.config.update({'global_step': state.global_step}, 33 | allow_val_change=True) 34 | 35 | 36 | class PyTorchProfilerCallback(TrainerCallback): 37 | 38 | def on_train_begin(self, args, state, control, logs=None, **kwargs): 39 | # only one epoch will be trained 40 | self.prof = torch.profiler.profile( 41 | activities=[ 42 | torch.profiler.ProfilerActivity.CPU, 43 | torch.profiler.ProfilerActivity.CUDA 44 | ], 45 | schedule=torch.profiler.schedule(wait=20, warmup=0, active=8), 46 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 47 | args.log_dir), 48 | record_shapes=True, 49 | profile_memory=True, 50 | #with_stack=True, 51 | with_flops=True, 52 | #with_modules=True 53 | ) 54 | 55 | def on_step_begin(self, args, state, control, logs=None, **kwargs): 56 | self.prof.step() 57 | 58 | def on_train_end(self, args, state, control, logs=None, **kwargs): 59 | self.prof.stop() 60 | 61 | 62 | 63 | 
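# get_wsd_scheduler builds a Warmup-Stable-Decay (WSD) schedule as a LambdaLR: the LR multiplier
# ramps from start_lambda to end_lambda over num_warmup_steps, stays at 1.0 for the stable fraction
# (stable_ratio of num_training_steps), and then decays toward a floor of 0.1. When
# start_global_step/end_global_step are set, that window overrides the normal schedule and
# interpolates between start_lambda and end_lambda with the shape chosen by wsd_style
# ("cos", "linear", "cos2", or "1sqrt", i.e. 1 - sqrt(progress)).
# Illustrative call (not part of this file; the window values mirror the phase-26 decay script,
# the other numbers are placeholders):
#   scheduler = get_wsd_scheduler(optimizer, num_warmup_steps=0, num_training_steps=262198,
#                                 wsd_style="1sqrt", start_lambda=1.0, end_lambda=0.0,
#                                 start_global_step=243198, end_global_step=262198)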
def get_wsd_scheduler(optimizer, 64 | num_warmup_steps, 65 | num_training_steps, 66 | last_epoch=-1, 67 | stable_ratio=1.0, 68 | start_lambda=0, 69 | end_lambda=1, 70 | start_global_step=None, 71 | end_global_step=None, 72 | wsd_style="cos"): 73 | # Note: Due to subsequent modifications to the training code, this function may require re-adaptation. 74 | 75 | if wsd_style == "cos": 76 | def lr_lambda(current_step): 77 | if start_global_step is not None and end_global_step is not None and start_global_step <= current_step <= end_global_step: 78 | return (1 - math.cos( 79 | math.pi * float(current_step - start_global_step) / 80 | float(max(1, end_global_step - start_global_step)) / 2)) * ( 81 | end_lambda - start_lambda) + start_lambda 82 | if current_step < num_warmup_steps: 83 | return (float(current_step) / float(max(1, num_warmup_steps))) * ( 84 | end_lambda - start_lambda) + start_lambda 85 | num_stable_steps = stable_ratio * num_training_steps 86 | if stable_ratio == 1.0 or current_step <= num_stable_steps: 87 | return 1.0 88 | return max( 89 | 0.1, 90 | float(num_training_steps - current_step) / 91 | float(max(1, num_training_steps - num_stable_steps)), 92 | ) 93 | 94 | elif wsd_style == "linear": 95 | def lr_lambda(current_step): 96 | if start_global_step is not None and end_global_step is not None and start_global_step <= current_step <= end_global_step: 97 | return (float(current_step - start_global_step) / 98 | float(max(1, end_global_step - start_global_step))) * ( 99 | end_lambda - start_lambda) + start_lambda 100 | if current_step < num_warmup_steps: 101 | return (float(current_step) / float(max(1, num_warmup_steps))) * ( 102 | end_lambda - start_lambda) + start_lambda 103 | num_stable_steps = stable_ratio * num_training_steps 104 | if stable_ratio == 1.0 or current_step <= num_stable_steps: 105 | return 1.0 106 | return max( 107 | 0.1, 108 | float(num_training_steps - current_step) / 109 | float(max(1, num_training_steps - num_stable_steps)), 110 | ) 111 | elif wsd_style == "cos2": 112 | 113 | def lr_lambda(current_step): 114 | if start_global_step is not None and end_global_step is not None and start_global_step <= current_step <= end_global_step: 115 | return (1 - math.cos( 116 | math.pi * float(current_step - start_global_step) / 117 | float(max(1, end_global_step - start_global_step)))) * ( 118 | end_lambda - start_lambda) / 2 + start_lambda 119 | if current_step < num_warmup_steps: 120 | return (float(current_step) / float(max(1, num_warmup_steps)) 121 | ) * (end_lambda - start_lambda) + start_lambda 122 | num_stable_steps = stable_ratio * num_training_steps 123 | if stable_ratio == 1.0 or current_step <= num_stable_steps: 124 | return 1.0 125 | return max( 126 | 0.1, 127 | float(num_training_steps - current_step) / 128 | float(max(1, num_training_steps - num_stable_steps)), 129 | ) 130 | elif wsd_style == "1sqrt": 131 | 132 | def lr_lambda(current_step): 133 | if current_step > 262000: # small hack for remaining steps 134 | current_step = 262000 135 | if start_global_step is not None and end_global_step is not None and start_global_step <= current_step <= end_global_step: 136 | return (1 - math.sqrt( 137 | (current_step - start_global_step) / 138 | (end_global_step - start_global_step))) * ( 139 | start_lambda - end_lambda) + end_lambda 140 | if current_step < num_warmup_steps: 141 | return (float(current_step) / float(max(1, num_warmup_steps)) 142 | ) * (end_lambda - start_lambda) + start_lambda 143 | num_stable_steps = stable_ratio * num_training_steps 144 | if 
stable_ratio == 1.0 or current_step <= num_stable_steps: 145 | return 1.0 146 | return max( 147 | 0.1, 148 | float(num_training_steps - current_step) / 149 | float(max(1, num_training_steps - num_stable_steps)), 150 | ) 151 | else: 152 | raise ValueError(f"Unknown wsd_style: {wsd_style}") 153 | 154 | return LambdaLR(optimizer, lr_lambda, last_epoch) 155 | 156 | -------------------------------------------------------------------------------- /pretrain/yulanmini-2B-final-phase25.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --comment=joint_project 4 | 5 | #SBATCH --job-name=xxxx 6 | 7 | #SBATCH --ntasks-per-node=1 8 | 9 | #SBATCH --gres=gpu:a800:8 10 | 11 | #SBATCH --partition=debug 12 | 13 | #SBATCH --output=log/%x-%j/part0.log 14 | 15 | #SBATCH --error=log/%x-%j/part0.log 16 | 17 | ### 前面是slurm默认值,不要更改 18 | 19 | source setup.sh 20 | 21 | ############################### 上面没有需要更改的地方 ############################### 22 | 23 | 24 | # ========== RESUME 只需要修改这里 ========== 25 | last_stage_job_name=miniyulan-2B-final-phase24 26 | STAGE=25 27 | # ======================================== 28 | 29 | CONTINUE=false 30 | if [ "$CONTINUE" = false ]; then 31 | DO_RMS_NORM=true 32 | ALLOW_0_CHECKPOINT=false 33 | UPDATE_TRAINED_STEPS_AND_EPOCHS=true 34 | elif [ "$CONTINUE" = true ]; then 35 | DO_RMS_NORM=false 36 | ALLOW_0_CHECKPOINT=true 37 | UPDATE_TRAINED_STEPS_AND_EPOCHS=false 38 | fi 39 | 40 | MODIFY_TRAINER_STATE=false 41 | 42 | # 计算上一次的最新checkpoint 43 | last_stage_latest_checkpoint=$(ls output_soft_link/$last_stage_job_name | grep checkpoint | grep -v rebalanced | grep -v rms_norm | sort -r | head -n 1) 44 | 45 | # 如果ALLOW_0_CHECKPOINT=false,检查获得的checkpoint不应该是000结尾 46 | if [ "$ALLOW_0_CHECKPOINT" = false ] && [[ "$last_stage_latest_checkpoint" == *000 ]]; then 47 | echo "last_stage_latest_checkpoint is 000, exit" 48 | exit 1 49 | fi 50 | 51 | # 如果没有rms_norm,则重新平衡权重 52 | if [ ! -d "output_soft_link/$last_stage_job_name/$last_stage_latest_checkpoint-rms_norm" ] && [ "$DO_RMS_NORM" = true ]; then 53 | python scripts/rebalance_weight.py output_soft_link/$last_stage_job_name/$last_stage_latest_checkpoint 54 | fi 55 | 56 | # dataset path 57 | # FETCH_TIME="" # 注意!现在FETCH_TIME自动从launch中传入!!!!所以在submit_to_slurm.sh中设置!!!! 
58 | DATA_PATH=hf_dataset/$DATASET_MODEL_NAME/$FETCH_TIME 59 | 60 | MODEL_PATH=output/$last_stage_job_name 61 | 62 | # model max length 63 | MODEL_MAX_LENGTH=4096 64 | 65 | # batch size 66 | # 下面的BS 节点数 GPU数 CONTEXT-SIZE 67 | # PER_DEVICE_TRAIN_BATCH_SIZE=18 68 | 69 | # gradient accumulation steps 70 | GRADIENT_ACCUMULATION_STEPS=1 71 | 72 | # learning rate 73 | LEARNING_RATE=1e-2 74 | 75 | # warmup ratio 76 | WARMUP_RATIO=0.0 # <-----第二个stage改这里 77 | 78 | # weight decay 79 | WEIGHT_DECAY=0.1 80 | 81 | # deepspeed config path 82 | DEEPSPEED_CONFIG_PATH='ds2_config_adamw_kd.json' 83 | 84 | OUTPUT_DIR=output/${JOB_NAME} 85 | mkdir -p ${OUTPUT_DIR} 86 | 87 | /usr/bin/pdsh -w $comma_hostnames bash torchrun_wrapper.sh $comma_hostnames $SLURM_JOB_NAME $SLURM_JOB_ID \ 88 | --nnodes $NNODES \ 89 | --nproc_per_node 8 \ 90 | --rdzv_backend static \ 91 | --rdzv_id $JOB_ID \ 92 | --master_addr $MASTER_ADDR \ 93 | --master_port $MASTER_PORT \ 94 | --max_restarts 3 \ 95 | train.py \ 96 | --model_name_or_path ${MODEL_PATH} \ 97 | --data_path ${DATA_PATH} \ 98 | --output_dir ${OUTPUT_DIR} \ 99 | --bf16 True \ 100 | --num_train_epochs ${STAGE} \ 101 | --model_max_length $MODEL_MAX_LENGTH \ 102 | --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ 103 | --per_device_eval_batch_size 4 \ 104 | --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ 105 | --eval_strategy "no" \ 106 | --save_strategy "steps" \ 107 | --save_steps 250 \ 108 | --save_total_limit 25 \ 109 | --learning_rate $LEARNING_RATE \ 110 | --warmup_ratio $WARMUP_RATIO \ 111 | --weight_decay $WEIGHT_DECAY \ 112 | --logging_steps 3 \ 113 | --deepspeed ${DEEPSPEED_CONFIG_PATH} \ 114 | --gradient_checkpointing True \ 115 | --deepspeed_gradient_checkpointing False \ 116 | --report_to tensorboard \ 117 | --tf32 True \ 118 | --lr_scheduler_type "linear" \ 119 | --flash_attention \ 120 | --use_wsd \ 121 | --log_dir $LOG_DIR \ 122 | --profile False \ 123 | --torch_compile \ 124 | --max_grad_norm 1 \ 125 | --hyper_param_decay_rate 0 \ 126 | --logging_dir ${LOG_DIR} \ 127 | --ddp_timeout 3600 \ 128 | --adam_beta1 0.9 \ 129 | --adam_beta2 0.95 \ 130 | --run_name $LOG_PREFIX \ 131 | --adam_epsilon 1e-15 \ 132 | --dataloader_num_workers 4 \ 133 | --dataloader_prefetch_factor 2 \ 134 | --shrink_alpha 1 \ 135 | --init_scale_o 1 \ 136 | --qk_layernorm False \ 137 | --hidden_size 1920 \ 138 | --intermediate_size 4800 \ 139 | --num_hidden_layers 56 \ 140 | --num_attention_heads 30 \ 141 | --num_key_value_heads 6 \ 142 | --model_reproduce cerebras \ 143 | --scale_emb 10 \ 144 | --tie_word_embeddings True \ 145 | --attention_bias True \ 146 | --z_loss 0.0001 \ 147 | --gradient_checkpointing_step 12 \ 148 | --use_muparam_lr True \ 149 | --initializer_range 0.00005 \ 150 | --q_proj_alpha 0.3651483716701107 \ 151 | --k_proj_alpha 0.3651483716701107 \ 152 | --v_proj_alpha 0.3651483716701107 \ 153 | --gate_up_proj_alpha 0.3651483716701107 \ 154 | --o_proj_alpha 0.03450327796711771 \ 155 | --down_proj_alpha 0.03450327796711771 \ 156 | --input_layernorm_alpha 1 \ 157 | --post_attention_layernorm_alpha 1 \ 158 | --norm_alpha 1 \ 159 | --lm_head_alpha 1 \ 160 | --dim_model_base_lr 256 \ 161 | --dim_model_base_logits 1920 \ 162 | --vi_residual_alpha 1.4 \ 163 | --wesar_weights True \ 164 | --use_norm_alpha True \ 165 | --use_emb_alpha False \ 166 | --resume_from_checkpoint $MODEL_PATH \ 167 | --add_rms_norm $DO_RMS_NORM \ 168 | --modify_trainer_state $MODIFY_TRAINER_STATE \ 169 | --update_trained_steps_and_epochs $UPDATE_TRAINED_STEPS_AND_EPOCHS \ 170 | 
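# Note on the reparameterization constants above (an observation, not taken from the script itself):
# q/k/v/gate_up use 0.3651483716701107 = sqrt(2/15), while the residual-branch outputs o_proj and
# down_proj use 0.03450327796711771 = sqrt(2/15) / sqrt(2 * 56), i.e. the same value scaled down by
# sqrt(2 * num_hidden_layers). This matches the common muP-style practice of shrinking residual-branch
# outputs with depth; the derivation actually used is presumably detailed in the technical report.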
-------------------------------------------------------------------------------- /pretrain/yulanmini-2B-s25d-decay80-1sqrt-long-28k-final-phase26.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --comment=joint_project 4 | 5 | #SBATCH --job-name=xxxx 6 | 7 | #SBATCH --ntasks-per-node=1 8 | 9 | #SBATCH --gres=gpu:a800:8 10 | 11 | #SBATCH --partition=debug 12 | 13 | #SBATCH --output=log/%x-%j/part0.log 14 | 15 | #SBATCH --error=log/%x-%j/part0.log 16 | 17 | 18 | source setup.sh 19 | 20 | # ========== RESUME 只需要修改这里 ========== 21 | last_stage_job_name=miniyulan-2B-final-phase25 22 | STAGE=26 23 | START_GLOBAL_STEP=243198 24 | DECAY_STEPS=19000 # 退火steps,注意会和batch size有关 25 | START_LAMBDA=1 26 | END_LAMBDA=0. # 从0.01降至0 27 | # ======================================== 28 | 29 | CONTINUE=false 30 | if [ "$CONTINUE" = false ]; then 31 | DO_RMS_NORM=true 32 | ALLOW_0_CHECKPOINT=false 33 | UPDATE_TRAINED_STEPS_AND_EPOCHS=true 34 | elif [ "$CONTINUE" = true ]; then 35 | DO_RMS_NORM=false 36 | ALLOW_0_CHECKPOINT=true 37 | UPDATE_TRAINED_STEPS_AND_EPOCHS=false 38 | fi 39 | 40 | MODIFY_TRAINER_STATE=false 41 | 42 | # 计算上一次的最新checkpoint 43 | last_stage_latest_checkpoint=$(ls output_soft_link/$last_stage_job_name | grep checkpoint | grep -v rebalanced | grep -v rms_norm | sort -r | head -n 1) 44 | 45 | # 如果ALLOW_0_CHECKPOINT=false,检查获得的checkpoint不应该是000结尾 46 | if [ "$ALLOW_0_CHECKPOINT" = false ] && [[ "$last_stage_latest_checkpoint" == *000 ]]; then 47 | echo "last_stage_latest_checkpoint is 000, exit" 48 | exit 1 49 | fi 50 | 51 | # 如果没有rms_norm,则重新平衡权重 52 | if [ ! -d "output_soft_link/$last_stage_job_name/$last_stage_latest_checkpoint-rms_norm" ] && [ "$DO_RMS_NORM" = true ]; then 53 | python scripts/rebalance_weight.py output_soft_link/$last_stage_job_name/$last_stage_latest_checkpoint 54 | fi 55 | 56 | # dataset path 57 | # FETCH_TIME="" # 注意!现在FETCH_TIME自动从launch中传入!!!!所以在submit_to_slurm.sh中设置!!!! 
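# Note on the decay window configured above (explanatory, not in the original script): END_GLOBAL_STEP is
# computed below as START_GLOBAL_STEP + DECAY_STEPS = 243198 + 19000 = 262198. With --wsd_style 1sqrt,
# get_wsd_scheduler in train_utils.py anneals the LR multiplier over that window from START_LAMBDA=1 down to
# END_LAMBDA=0 following 1 - sqrt(progress).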
58 | DATA_PATH=hf_dataset/$DATASET_MODEL_NAME/$FETCH_TIME 59 | 60 | MODEL_PATH=output/$last_stage_job_name 61 | 62 | # model max length 63 | MODEL_MAX_LENGTH=28672 64 | 65 | # batch size 66 | # 下面的BS 节点数 GPU数 CONTEXT-SIZE 67 | # PER_DEVICE_TRAIN_BATCH_SIZE=18 68 | 69 | # gradient accumulation steps 70 | GRADIENT_ACCUMULATION_STEPS=1 71 | 72 | # learning rate 73 | LEARNING_RATE=1e-2 74 | 75 | # warmup ratio 76 | WARMUP_RATIO=0.0 77 | END_GLOBAL_STEP=$(expr $START_GLOBAL_STEP + $DECAY_STEPS) 78 | 79 | # weight decay 80 | WEIGHT_DECAY=0.1 81 | 82 | # deepspeed config path 83 | DEEPSPEED_CONFIG_PATH='ds2_config_adamw.json' 84 | 85 | OUTPUT_DIR=output/${JOB_NAME} 86 | mkdir -p ${OUTPUT_DIR} 87 | 88 | /usr/bin/pdsh -w $comma_hostnames bash torchrun_wrapper.sh $comma_hostnames $SLURM_JOB_NAME $SLURM_JOB_ID \ 89 | --nnodes $NNODES \ 90 | --nproc_per_node 8 \ 91 | --rdzv_backend static \ 92 | --rdzv_id $JOB_ID \ 93 | --master_addr $MASTER_ADDR \ 94 | --master_port $MASTER_PORT \ 95 | --max_restarts 3 \ 96 | train.py \ 97 | --model_name_or_path ${MODEL_PATH} \ 98 | --data_path ${DATA_PATH} \ 99 | --output_dir ${OUTPUT_DIR} \ 100 | --bf16 True \ 101 | --num_train_epochs ${STAGE} \ 102 | --model_max_length $MODEL_MAX_LENGTH \ 103 | --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ 104 | --per_device_eval_batch_size 4 \ 105 | --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \ 106 | --eval_strategy "no" \ 107 | --save_strategy "steps" \ 108 | --save_steps 250 \ 109 | --save_total_limit 25 \ 110 | --learning_rate $LEARNING_RATE \ 111 | --warmup_ratio $WARMUP_RATIO \ 112 | --weight_decay $WEIGHT_DECAY \ 113 | --logging_steps 3 \ 114 | --deepspeed ${DEEPSPEED_CONFIG_PATH} \ 115 | --gradient_checkpointing True \ 116 | --deepspeed_gradient_checkpointing False \ 117 | --report_to tensorboard \ 118 | --tf32 True \ 119 | --lr_scheduler_type "linear" \ 120 | --flash_attention \ 121 | --use_wsd \ 122 | --log_dir $LOG_DIR \ 123 | --profile False \ 124 | --torch_compile \ 125 | --max_grad_norm 1 \ 126 | --hyper_param_decay_rate 0 \ 127 | --logging_dir ${LOG_DIR} \ 128 | --ddp_timeout 3600 \ 129 | --adam_beta1 0.9 \ 130 | --adam_beta2 0.95 \ 131 | --run_name $LOG_PREFIX \ 132 | --adam_epsilon 1e-15 \ 133 | --dataloader_num_workers 4 \ 134 | --dataloader_prefetch_factor 2 \ 135 | --shrink_alpha 1 \ 136 | --init_scale_o 1 \ 137 | --qk_layernorm False \ 138 | --hidden_size 1920 \ 139 | --intermediate_size 4800 \ 140 | --num_hidden_layers 56 \ 141 | --num_attention_heads 30 \ 142 | --num_key_value_heads 6 \ 143 | --model_reproduce cerebras \ 144 | --scale_emb 10 \ 145 | --tie_word_embeddings True \ 146 | --attention_bias True \ 147 | --z_loss 0.0001 \ 148 | --gradient_checkpointing_step 56 \ 149 | --use_muparam_lr True \ 150 | --initializer_range 0.00005 \ 151 | --q_proj_alpha 0.3651483716701107 \ 152 | --k_proj_alpha 0.3651483716701107 \ 153 | --v_proj_alpha 0.3651483716701107 \ 154 | --gate_up_proj_alpha 0.3651483716701107 \ 155 | --o_proj_alpha 0.03450327796711771 \ 156 | --down_proj_alpha 0.03450327796711771 \ 157 | --input_layernorm_alpha 1 \ 158 | --post_attention_layernorm_alpha 1 \ 159 | --norm_alpha 1 \ 160 | --lm_head_alpha 1 \ 161 | --dim_model_base_lr 256 \ 162 | --dim_model_base_logits 1920 \ 163 | --vi_residual_alpha 1.4 \ 164 | --wesar_weights True \ 165 | --use_norm_alpha True \ 166 | --use_emb_alpha False \ 167 | --resume_from_checkpoint $MODEL_PATH \ 168 | --add_rms_norm $DO_RMS_NORM \ 169 | --modify_trainer_state $MODIFY_TRAINER_STATE \ 170 | --update_trained_steps_and_epochs 
$UPDATE_TRAINED_STEPS_AND_EPOCHS \ 171 | --start_lambda $START_LAMBDA \ 172 | --end_lambda $END_LAMBDA \ 173 | --start_global_step $START_GLOBAL_STEP \ 174 | --end_global_step $END_GLOBAL_STEP \ 175 | --wsd_style 1sqrt \ 176 | --------------------------------------------------------------------------------