├── LICENSE ├── MEAP-Pretain ├── c4_bin │ └── c4_0_0000000006.bin ├── convert │ ├── convert_lit_checkpoint.py │ └── convert_safetensors.py ├── data_process │ ├── gz_unzip_v1.py │ └── prepare_c4.py ├── lit_gpt │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── config.cpython-310.pyc │ │ ├── fused_cross_entropy.cpython-310.pyc │ │ ├── fused_rotary_embedding.cpython-310.pyc │ │ ├── model.cpython-310.pyc │ │ ├── packed_dataset.cpython-310.pyc │ │ ├── rmsnorm.cpython-310.pyc │ │ ├── speed_monitor.cpython-310.pyc │ │ ├── tokenizer.cpython-310.pyc │ │ └── utils.cpython-310.pyc │ ├── adapter.py │ ├── adapter_v2.py │ ├── config.py │ ├── fused_cross_entropy.py │ ├── fused_rotary_embedding.py │ ├── lora.py │ ├── model.py │ ├── packed_dataset.py │ ├── rmsnorm.py │ ├── speed_monitor.py │ ├── tokenizer.py │ └── utils.py ├── pretrained │ ├── meap_0.1b_2b.py │ ├── meap_0.1b_2b_0.05.py │ ├── meap_0.1b_2b_0.1.py │ ├── meap_0.3b_5b.py │ ├── meap_0.3b_5b_0.05.py │ ├── meap_0.3b_5b_0.1.py │ ├── meap_0.3b_5b_0.2.py │ ├── meap_0.5b_10b.py │ ├── meap_0.5b_10b_0.05.py │ ├── meap_0.5b_10b_0.1.py │ ├── meap_0.5b_10b_0.2.py │ ├── meap_1b.py │ ├── meap_1b_0.05.py │ ├── meap_1b_0.1.py │ ├── meap_1b_0.15_20b.py │ ├── meap_1b_0.15_40b.py │ ├── meap_1b_0.15_60b.py │ └── meap_1b_0.2.py ├── requirements.txt ├── run │ └── run_one_node.sh └── tokenizer │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer.model │ └── tokenizer_config.json ├── MEAP-SFT ├── MEAP-SFT.py ├── deepspeed_zero_stage2_config.json ├── requirements.txt └── template.py ├── README.md └── data └── alpaca_data.jsonl /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 scitix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MEAP-Pretain/c4_bin/c4_0_0000000006.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/c4_bin/c4_0_0000000006.bin -------------------------------------------------------------------------------- /MEAP-Pretain/convert/convert_lit_checkpoint.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import sys 4 | from functools import partial 5 | from pathlib import Path 6 | from typing import Dict, Literal, Optional, Tuple, Union 7 | from dataclasses import asdict 8 | import json 9 | import torch 10 | 11 | # support running without installing as a package 12 | wd = Path(__file__).parent.parent.resolve() 13 | sys.path.append(str(wd)) 14 | 15 | from lit_gpt import Config 16 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 17 | # from scripts.convert_hf_checkpoint import layer_template, load_param 18 | 19 | 20 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 21 | split = layer_name.split(".") 22 | number = int(split[idx]) 23 | split[idx] = "{}" 24 | from_name = ".".join(split) 25 | return from_name, number 26 | 27 | 28 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 29 | if hasattr(param, "_load_tensor"): 30 | # support tensors loaded via `lazy_load()` 31 | print(f"Loading {name!r} into RAM") 32 | param = param._load_tensor() 33 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 34 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 35 | param = param.to(dtype) 36 | return param 37 | def copy_weights_falcon( 38 | size: Literal["7b", "40b"], 39 | state_dict: Dict[str, torch.Tensor], 40 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 41 | saver: Optional[incremental_save] = None, 42 | ): 43 | weight_map = { 44 | "transformer.wte.weight": "transformer.word_embeddings.weight", 45 | "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", 46 | "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", 47 | "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", 48 | "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", 49 | "transformer.ln_f.bias": "transformer.ln_f.bias", 50 | "transformer.ln_f.weight": "transformer.ln_f.weight", 51 | "lm_head.weight": "lm_head.weight", 52 | } 53 | # the original model definition is different for each size 54 | if size == "7b": 55 | weight_map.update( 56 | { 57 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", 58 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", 59 | } 60 | ) 61 | elif size == "40b": 62 | weight_map.update( 63 | { 64 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", 65 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", 66 | "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", 67 | "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", 68 | } 69 | ) 70 | else: 71 | raise NotImplementedError 72 | 73 | for name, param in lit_weights.items(): 74 | if "transformer.h" in name: 75 | from_name, number = layer_template(name, 2) 76 | to_name = 
weight_map[from_name].format(number) 77 | else: 78 | to_name = weight_map[name] 79 | param = load_param(param, name, None) 80 | if saver is not None: 81 | param = saver.store_early(param) 82 | state_dict[to_name] = param 83 | 84 | 85 | def copy_weights_gpt_neox( 86 | state_dict: Dict[str, torch.Tensor], 87 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 88 | saver: Optional[incremental_save] = None, 89 | ) -> None: 90 | weight_map = { 91 | "transformer.wte.weight": "gpt_neox.embed_in.weight", 92 | "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", 93 | "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", 94 | "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", 95 | "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", 96 | "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", 97 | "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", 98 | "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", 99 | "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", 100 | "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", 101 | "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", 102 | "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", 103 | "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", 104 | "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", 105 | "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", 106 | "lm_head.weight": "embed_out.weight", 107 | } 108 | 109 | for name, param in lit_weights.items(): 110 | if "transformer.h" in name: 111 | from_name, number = layer_template(name, 2) 112 | to_name = weight_map[from_name].format(number) 113 | else: 114 | to_name = weight_map[name] 115 | param = load_param(param, name, None) 116 | if saver is not None: 117 | param = saver.store_early(param) 118 | state_dict[to_name] = param 119 | 120 | 121 | def copy_weights_llama( 122 | config: Config, 123 | state_dict: Dict[str, torch.Tensor], 124 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 125 | saver: Optional[incremental_save] = None, 126 | ): 127 | weight_map = { 128 | "transformer.wte.weight": "model.embed_tokens.weight", 129 | "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", 130 | "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", 131 | "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", 132 | "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", 133 | "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", 134 | "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", 135 | "transformer.ln_f.weight": "model.norm.weight", 136 | "lm_head.weight": "lm_head.weight", 137 | } 138 | for name, param in lit_weights.items(): 139 | if name.endswith(".attn.attn.weight"): 140 | from_name, number = layer_template(name, 2) 141 | q = "model.layers.{}.self_attn.q_proj.weight".format(number) 142 | k = "model.layers.{}.self_attn.k_proj.weight".format(number) 143 | v = "model.layers.{}.self_attn.v_proj.weight".format(number) 144 | qkv = load_param(param, name,None) 145 | qp, kp, vp = tensor_split(qkv, config) 146 | for to_name, param in zip((q, k, v), (qp, kp, 
vp)): 147 | if saver is not None: 148 | param = saver.store_early(param) 149 | state_dict[to_name] = param 150 | elif "transformer.h" in name: 151 | from_name, number = layer_template(name, 2) 152 | to_name = weight_map[from_name] 153 | 154 | if to_name is None: 155 | continue 156 | to_name = to_name.format(number) 157 | param = load_param(param, name,None) 158 | if saver is not None: 159 | param = saver.store_early(param) 160 | state_dict[to_name] = param 161 | 162 | else: 163 | to_name = weight_map[name] 164 | param = load_param(param, name, None) 165 | if saver is not None: 166 | param = saver.store_early(param) 167 | state_dict[to_name] = param 168 | 169 | 170 | def tensor_split( 171 | param: Union[torch.Tensor, NotYetLoadedTensor], config: Config 172 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 173 | def kstart(start, blen, klen) -> int: 174 | """returns start index of keys in batch""" 175 | return start + (blen - (klen * 2)) 176 | 177 | def vstart(start, blen, klen) -> int: 178 | """returns start index of values in batch""" 179 | return start + blen - klen 180 | 181 | def vend(start, blen) -> int: 182 | """returns last index of values in batch""" 183 | return start + blen 184 | 185 | # num observations 186 | nobs = param.shape[0] 187 | # batch length 188 | blen = nobs // config.n_query_groups 189 | # key length in batch 190 | klen = config.head_size 191 | # value length in batch 192 | vlen = config.head_size 193 | # the starting index of each new batch 194 | starts = range(0, nobs, blen) 195 | # the indices to splice on 196 | splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] 197 | 198 | qc = () 199 | kc = () 200 | vc = () 201 | 202 | for splice in splices: 203 | qs, ks, vs, ve = splice 204 | qc += (param[qs:ks, :],) 205 | kc += (param[ks:vs, :],) 206 | vc += (param[vs:ve, :],) 207 | 208 | q = torch.cat(qc) 209 | k = torch.cat(kc) 210 | v = torch.cat(vc) 211 | 212 | return q, k, v 213 | 214 | 215 | def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 216 | return lit_weights.get("model", lit_weights) 217 | 218 | 219 | def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: 220 | weight_names = {wk.split(".")[-1] for wk in lit_weights} 221 | # LoRA or QLoRA 222 | if any("lora" in wn for wn in weight_names): 223 | raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") 224 | # adapter v2. adapter_bias will only be in adapter_v2 225 | elif "adapter_bias" in weight_names: 226 | raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") 227 | # adapter. 
gating_factor is in adapter and adapter_v2 228 | elif "gating_factor" in weight_names: 229 | raise NotImplementedError("Converting models finetuned with adapter not yet supported.") 230 | 231 | 232 | def get_tinyllama_init_hf_config() -> dict: 233 | return { 234 | "architectures": ["LlamaForCausalLM"], 235 | "bos_token_id": 1, 236 | "eos_token_id": 2, 237 | "hidden_act": "silu", 238 | "hidden_size": None, 239 | "initializer_range": 0.02, 240 | "intermediate_size": None, 241 | "max_position_embeddings": None, 242 | "model_type": "llama", 243 | "num_attention_heads": None, 244 | "num_hidden_layers": None, 245 | "num_key_value_heads": None, 246 | "pretraining_tp": 1, 247 | "rms_norm_eps": None, 248 | "rope_scaling": None, 249 | "tie_word_embeddings": False, 250 | "torch_dtype": "float32", 251 | "transformers_version": "4.31.0.dev0", 252 | "use_cache": True, 253 | "vocab_size": None, 254 | } 255 | 256 | 257 | def convert_config_lit_to_hf(lit_config_dict: dict) -> dict: 258 | lit_hf_mapping = { 259 | "block_size": "max_position_embeddings", 260 | "vocab_size": "vocab_size", 261 | "n_layer": "num_hidden_layers", 262 | "n_embd": "hidden_size", 263 | "n_head": "num_attention_heads", 264 | "n_query_groups": "num_key_value_heads", 265 | "intermediate_size": "intermediate_size", 266 | "norm_eps": "rms_norm_eps", 267 | 268 | } 269 | hf_config_dict = get_tinyllama_init_hf_config() 270 | 271 | for lit_key, hf_key in lit_hf_mapping.items(): 272 | hf_config_dict[hf_key] = lit_config_dict[lit_key] 273 | return hf_config_dict 274 | 275 | 276 | @torch.inference_mode() 277 | def convert_lit_checkpoint(*, 278 | checkpoint_name: str, 279 | out_dir: Path, 280 | model_name: str, 281 | model_only: bool = True) -> None: 282 | config = Config.from_name(model_name) 283 | 284 | if "falcon" in model_name: 285 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 286 | elif config._mlp_class == "LLaMAMLP": 287 | copy_fn = partial(copy_weights_llama, config) 288 | else: 289 | copy_fn = copy_weights_gpt_neox 290 | 291 | # initialize a new empty state dict to hold our new weights 292 | sd = {} 293 | 294 | # checkpoint_name cannot be hardcoded because there exists different outputs such as 295 | # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") 296 | pth_file = out_dir / checkpoint_name 297 | bin_file = pth_file.with_suffix(".bin") 298 | 299 | with incremental_save(bin_file) as saver: 300 | with contextlib.ExitStack() as stack: 301 | lit_weights = stack.enter_context(lazy_load(pth_file)) 302 | lit_weights = maybe_unwrap_state_dict(lit_weights) 303 | check_conversion_supported(lit_weights) 304 | # Incremental save will trigger error 305 | copy_fn(sd, lit_weights, saver=None) 306 | gc.collect() 307 | saver.save(sd) 308 | 309 | # convert lit config file to hf-style 310 | if not model_only: 311 | print('Converting config file...') 312 | lit_config = asdict(config) 313 | hf_config = convert_config_lit_to_hf(lit_config) 314 | config_path = out_dir / "config.json" 315 | with open(config_path, "w") as f: 316 | json.dump(hf_config, f, indent=4) 317 | 318 | 319 | 320 | 321 | if __name__ == "__main__": 322 | from jsonargparse import CLI 323 | 324 | CLI(convert_lit_checkpoint, as_positional=False) 325 | -------------------------------------------------------------------------------- /MEAP-Pretain/convert/convert_safetensors.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, 
AutoTokenizer 2 | import torch 3 | 4 | def load_models_tokenizer(checkpoint_path): 5 | model = AutoModelForCausalLM.from_pretrained( 6 | checkpoint_path, 7 | trust_remote_code=True, 8 | #torch_device='cpu', 9 | torch_dtype=torch.float16 10 | ) 11 | # return the result 12 | return model 13 | checkpoint_path = '' 14 | model = load_models_tokenizer(checkpoint_path) 15 | out_dir = '' 16 | model.save_pretrained(out_dir, use_safetensors=True) 17 | -------------------------------------------------------------------------------- /MEAP-Pretain/data_process/gz_unzip_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import shutil 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from threading import Lock 6 | 7 | lock = Lock() 8 | 9 | def decompress_and_count_json(file_path, folder_path): 10 | output_file = file_path[:-len('.gz')] # strip the ".gz" suffix (rstrip would drop any trailing '.', 'g', 'z' characters) 11 | try: 12 | 13 | with gzip.open(file_path, 'rb') as gz_file: 14 | with open(output_file, 'wb') as out_file: 15 | shutil.copyfileobj(gz_file, out_file) 16 | 17 | json_count = count_json_files(folder_path) 18 | 19 | with lock: 20 | print(f"Extraction completed successfully: {file_path} -> {output_file}") 21 | print(f"The folder now contains {json_count} .json files.") 22 | except Exception as e: 23 | with lock: 24 | print(f"Extraction failed: {file_path}, error: {e}") 25 | 26 | def count_json_files(folder_path): 27 | return sum(1 for f in os.listdir(folder_path) if f.endswith('.json')) 28 | 29 | def decompress_all_and_track_json_multithreaded(folder_path, max_workers=4): 30 | gz_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.gz')] 31 | 32 | if not gz_files: 33 | print("No .gz files found.") 34 | return 35 | 36 | print(f"Starting extraction of {len(gz_files)} .gz files ...") 37 | 38 | 39 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 40 | futures = [executor.submit(decompress_and_count_json, file_path, folder_path) for file_path in gz_files] 41 | for future in as_completed(futures): 42 | future.result() 43 | 44 | print("All extractions completed successfully!") 45 | 46 | if __name__ == "__main__": 47 | folder_path = input("Enter the folder to decompress: ") 48 | if os.path.isdir(folder_path): 49 | decompress_all_and_track_json_multithreaded(folder_path) 50 | else: 51 | print("Invalid folder path, please try again.") 52 | 53 | -------------------------------------------------------------------------------- /MEAP-Pretain/data_process/prepare_c4.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import sys 5 | from pathlib import Path 6 | import time 7 | import numpy as np 8 | from tqdm import tqdm 9 | from multiprocessing import Process, cpu_count 10 | # support running without installing as a package 11 | wd = Path(__file__).parent.parent.resolve() 12 | sys.path.append(str(wd)) 13 | 14 | import lit_gpt.packed_dataset as packed_dataset 15 | from transformers import AutoTokenizer 16 | 17 | 18 | 19 | filename_sets = { 20 | "c4": "json_c4/c4-train*", 21 | } 22 | 23 | 24 | 25 | def prepare_full( 26 | filenames: list[Path], checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "", process_id: int = 0 27 | ) -> None: 28 | """Prepare the C4 dataset using the original tokenizer.""" 29 | import zstandard as zstd 30 | 31 | destination_path.mkdir(parents=True, exist_ok=True) 32 | print(checkpoint_dir) 33 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, legacy=False) 34 | 35 | builder =
packed_dataset.PackedDatasetBuilder( 36 | outdir=destination_path, 37 | prefix="c4_" + str(process_id), 38 | chunk_size=chunk_size, 39 | sep_token=tokenizer.eos_token_id, 40 | dtype="auto", 41 | vocab_size=tokenizer.vocab_size, 42 | ) 43 | 44 | for filepath in filenames: 45 | #filepath = source_path / name 46 | 47 | print(f"Processing {filepath}") 48 | 49 | 50 | with open(filepath, encoding="utf-8") as f: 51 | for row in tqdm(f): 52 | try: 53 | text = json.loads(row)["text"] 54 | text_ids = tokenizer.encode(text)+[tokenizer.eos_token_id] 55 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 56 | except Exception as e: 57 | print(f"Error processing row: {e}") 58 | continue 59 | #os.remove(filepath) 60 | 61 | builder.write_reminder() 62 | 63 | 64 | def prepare( 65 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 66 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 67 | destination_path: Path = Path("data/redpajama_sample"), 68 | sample: bool = True, 69 | match: str = "", 70 | ) -> None: 71 | """Prepare the C4 dataset. We assume the tokenizer has already been trained.""" 72 | # with open(checkpoint_dir / "lit_config.json") as fp: 73 | # config = Config(**json.load(fp)) 74 | sample = False 75 | 76 | #num_processes = cpu_count() 77 | num_processes = 80 78 | pattern = filename_sets["c4"] 79 | filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) 80 | print("source_path", source_path) 81 | print("pattern", pattern) 82 | print("os.path.join(source_path, pattern)", os.path.join(source_path, pattern)) 83 | print("filenames", filenames) 84 | chunked_filenames = np.array_split(filenames, num_processes) 85 | print("chunked_filenames", chunked_filenames) 86 | 87 | processes = [] 88 | start_time = time.time() 89 | chunk_size = 4097 * 1024 90 | for i, filename in enumerate(chunked_filenames): 91 | print("iter filename", filename) 92 | p = Process(target=prepare_full, args=(filename, checkpoint_dir, destination_path, chunk_size, match, i)) 93 | processes.append(p) 94 | p.start() 95 | 96 | for p in processes: 97 | p.join() 98 | end_time = time.time() 99 | elapsed_time = end_time - start_time 100 | print(f"Time taken: {elapsed_time:.2f} seconds") 101 | 102 | 103 | 104 | if __name__ == "__main__": 105 | from jsonargparse import CLI 106 | 107 | CLI(prepare) -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from lit_gpt.model import GPT 2 | from lit_gpt.config import Config 3 | from lit_gpt.tokenizer import Tokenizer 4 | from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss 5 | from lightning_utilities.core.imports import RequirementCache 6 | 7 | if not bool(RequirementCache("torch>=2.1.0dev")): 8 | raise ImportError( 9 | "Lit-GPT requires torch nightly (future torch 2.1). Please follow the installation instructions in the" 10 | " repository README.md" 11 | ) 12 | _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0") 13 | if not bool(_LIGHTNING_AVAILABLE): 14 | raise ImportError( 15 | "Lit-GPT requires Lightning nightly (future lightning 2.1).
Please run:\n" 16 | f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" 17 | ) 18 | 19 | 20 | __all__ = ["GPT", "Config", "Tokenizer"] 21 | -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/fused_cross_entropy.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/fused_cross_entropy.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/fused_rotary_embedding.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/fused_rotary_embedding.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/packed_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/packed_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/rmsnorm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/rmsnorm.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/speed_monitor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/speed_monitor.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/tokenizer.cpython-310.pyc 
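Looking back at the two conversion scripts earlier in this dump (convert/convert_lit_checkpoint.py and convert/convert_safetensors.py), the following is a hedged usage sketch of how they chain together; it is not code from the repository, and every path and the model name are placeholder assumptions.

# Illustrative sketch only: export a pretrained lit-gpt checkpoint to a HF-style folder.
from pathlib import Path
from convert_lit_checkpoint import convert_lit_checkpoint  # assumes MEAP-Pretain/convert is the working directory

convert_lit_checkpoint(
    checkpoint_name="lit_model.pth",  # placeholder checkpoint file located in out_dir
    out_dir=Path("out/meap_1b"),      # placeholder output directory
    model_name="tiny_LLaMA_1b",       # placeholder; must match a config name in lit_gpt/config.py
    model_only=False,                 # also write a HF-style config.json next to the .bin file
)
# The resulting .bin + config.json can then be loaded with AutoModelForCausalLM.from_pretrained(out_dir)
# and re-saved with save_pretrained(out_dir, use_safetensors=True), which is what convert_safetensors.py
# does once its empty checkpoint_path / out_dir variables are filled in.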
-------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/lit_gpt/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/adapter.py: -------------------------------------------------------------------------------- 1 | """Implementation of the paper: 2 | 3 | LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention 4 | https://arxiv.org/abs/2303.16199 5 | 6 | Port for Lit-GPT 7 | """ 8 | from dataclasses import dataclass 9 | from typing import Any, Dict, List, Optional, Tuple, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | from typing_extensions import Self 14 | 15 | from lit_gpt.config import Config as BaseConfig 16 | from lit_gpt.model import GPT as BaseModel 17 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 18 | from lit_gpt.model import KVCache, RoPECache, apply_rope 19 | 20 | 21 | @dataclass 22 | class Config(BaseConfig): 23 | adapter_prompt_length: int = 10 24 | adapter_start_layer: int = 2 25 | 26 | 27 | class GPT(BaseModel): 28 | """The implementation is identical to `lit_gpt.model.GPT` with the exception that 29 | the `Block` saves the layer index and passes it down to the attention layer.""" 30 | 31 | def __init__(self, config: Config) -> None: 32 | nn.Module.__init__(self) 33 | assert config.padded_vocab_size is not None 34 | self.config = config 35 | 36 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 37 | self.transformer = nn.ModuleDict( 38 | dict( 39 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 40 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 41 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 42 | ) 43 | ) 44 | 45 | self.rope_cache: Optional[RoPECache] = None 46 | self.mask_cache: Optional[torch.Tensor] = None 47 | self.kv_caches: List[KVCache] = [] 48 | self.adapter_kv_caches: List[KVCache] = [] 49 | 50 | def reset_cache(self) -> None: 51 | super().reset_cache() 52 | self.adapter_kv_caches.clear() 53 | 54 | def forward( 55 | self, 56 | idx: torch.Tensor, 57 | max_seq_length: Optional[int] = None, 58 | input_pos: Optional[torch.Tensor] = None, 59 | lm_head_chunk_size: int = 0, 60 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 61 | B, T = idx.size() 62 | use_kv_cache = input_pos is not None 63 | 64 | block_size = self.config.block_size 65 | if max_seq_length is None: 66 | max_seq_length = block_size 67 | if use_kv_cache: # not relevant otherwise 68 | assert ( 69 | max_seq_length >= T 70 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 71 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 72 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 73 | 74 | if self.rope_cache is None: 75 | self.rope_cache = self.build_rope_cache(idx) 76 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask 77 | # for the kv-cache support (only during inference), we only create it in that situation 78 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 79 | if use_kv_cache and self.mask_cache is None: 80 | self.mask_cache = self.build_mask_cache(idx) 81 | 82 | cos, sin = self.rope_cache 83 | if use_kv_cache: 84 | cos = cos.index_select(0, input_pos) 85 | sin = sin.index_select(0, input_pos) 86 | mask = self.mask_cache.index_select(2, input_pos) 87 | mask = mask[:, :, :, :max_seq_length] 88 | else: 89 | cos = cos[:T] 90 | sin = sin[:T] 91 | mask = None 92 | 93 | # forward the model itself 94 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 95 | 96 | if not use_kv_cache: 97 | for block in self.transformer.h: 98 | x, *_ = block(x, (cos, sin), max_seq_length) 99 | else: 100 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) 101 | self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)] 102 | for i, block in enumerate(self.transformer.h): 103 | x, self.kv_caches[i], self.adapter_kv_caches[i] = block( 104 | x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i] 105 | ) 106 | 107 | x = self.transformer.ln_f(x) 108 | 109 | if lm_head_chunk_size > 0: 110 | # chunk the lm head logits to reduce the peak memory used by autograd 111 | return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] 112 | return self.lm_head(x) # (b, t, vocab_size) 113 | 114 | @classmethod 115 | def from_name(cls, name: str, **kwargs: Any) -> Self: 116 | return cls(Config.from_name(name, **kwargs)) 117 | 118 | def _init_weights(self, module: nn.Module) -> None: 119 | """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" 120 | super()._init_weights(module) 121 | if isinstance(module, CausalSelfAttention): 122 | module.reset_parameters() 123 | 124 | 125 | class Block(nn.Module): 126 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 127 | we replace the attention layer where adaption is implemented.""" 128 | 129 | def __init__(self, config: Config, block_idx: int) -> None: 130 | super().__init__() 131 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 132 | self.attn = CausalSelfAttention(config, block_idx) 133 | if not config.shared_attention_norm: 134 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 135 | self.mlp = config.mlp_class(config) 136 | 137 | self.config = config 138 | 139 | def forward( 140 | self, 141 | x: torch.Tensor, 142 | rope: RoPECache, 143 | max_seq_length: int, 144 | mask: Optional[torch.Tensor] = None, 145 | input_pos: Optional[torch.Tensor] = None, 146 | kv_cache: Optional[KVCache] = None, 147 | adapter_kv_cache: Optional[KVCache] = None, 148 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 149 | n_1 = self.norm_1(x) 150 | h, new_kv_cache, new_adapter_kv_cache = self.attn( 151 | n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache 152 | ) 153 | if self.config.parallel_residual: 154 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 155 | x = x + h + self.mlp(n_2) 156 | else: 157 | if self.config.shared_attention_norm: 158 | raise NotImplementedError( 159 | "No checkpoint amongst the ones we support uses this configuration" 160 | " (non-parallel residual and shared attention norm)." 
161 | ) 162 | x = x + h 163 | x = x + self.mlp(self.norm_2(x)) 164 | return x, new_kv_cache, new_adapter_kv_cache 165 | 166 | 167 | class CausalSelfAttention(BaseCausalSelfAttention): 168 | """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention 169 | over the adaption prompt.""" 170 | 171 | def __init__(self, config: Config, block_idx: int) -> None: 172 | super().__init__(config) 173 | if block_idx >= config.adapter_start_layer: 174 | # adapter embedding layer 175 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 176 | # gate for adaption 177 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 178 | self.reset_parameters() 179 | self.block_idx = block_idx 180 | 181 | def forward( 182 | self, 183 | x: torch.Tensor, 184 | rope: RoPECache, 185 | max_seq_length: int, 186 | mask: Optional[torch.Tensor] = None, 187 | input_pos: Optional[torch.Tensor] = None, 188 | kv_cache: Optional[KVCache] = None, 189 | adapter_kv_cache: Optional[KVCache] = None, 190 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 191 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 192 | 193 | qkv = self.attn(x) 194 | 195 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 196 | q_per_kv = self.config.n_head // self.config.n_query_groups 197 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 198 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 199 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 200 | 201 | # split batched computation into three 202 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 203 | 204 | # repeat k and v if necessary 205 | if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
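# Illustrative shape note (the example numbers are assumptions, not taken from any config in this repo):
# with n_head = 32, n_query_groups = 4 and head_size = 64, q_per_kv = 32 // 4 = 8, so after the split
# q has shape (B, 4, 8, T, 64) while k and v have shape (B, 4, 1, T, 64); the expand below broadcasts
# k and v to (B, 4, 8, T, 64) so the later reshape yields 32 query heads matched by 32 key/value heads.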
206 | # for MHA this is a no-op 207 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 208 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 209 | 210 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 211 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 212 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 213 | 214 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 215 | 216 | cos, sin = rope 217 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 218 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 219 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 220 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 221 | 222 | if kv_cache is not None: 223 | cache_k, cache_v = kv_cache 224 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 225 | # check if reached token limit 226 | if input_pos[-1] >= max_seq_length: 227 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 228 | # shift 1 position to the left 229 | cache_k = torch.roll(cache_k, -1, dims=2) 230 | cache_v = torch.roll(cache_v, -1, dims=2) 231 | k = cache_k.index_copy_(2, input_pos, k) 232 | v = cache_v.index_copy_(2, input_pos, v) 233 | kv_cache = k, v 234 | 235 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 236 | 237 | if self.block_idx >= self.config.adapter_start_layer: 238 | aT = self.config.adapter_prompt_length 239 | if adapter_kv_cache is not None: 240 | ak, av = adapter_kv_cache 241 | else: 242 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 243 | aqkv = self.attn(prefix) 244 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 245 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 246 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 247 | if self.config.n_query_groups != 1: 248 | # for MHA this is a no-op 249 | ak = ak.repeat_interleave(q_per_kv, dim=2) 250 | av = av.repeat_interleave(q_per_kv, dim=2) 251 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 252 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 253 | adapter_kv_cache = (ak, av) 254 | 255 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 256 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 257 | y = y + self.gating_factor * ay 258 | 259 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 260 | 261 | # output projection 262 | y = self.proj(y) 263 | 264 | return y, kv_cache, adapter_kv_cache 265 | 266 | def reset_parameters(self) -> None: 267 | torch.nn.init.zeros_(self.gating_factor) 268 | 269 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 270 | """For compatibility with older checkpoints.""" 271 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 272 | state_dict[key] = state_dict[key].permute(0, 2, 1, 3) 273 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 274 | 275 | 276 | def mark_only_adapter_as_trainable(model: GPT) -> None: 277 | """Sets `requires_grad=False` for all non-adapter weights.""" 278 | for name, param in model.named_parameters(): 279 | param.requires_grad = adapter_filter(name, param) 280 | 281 | 282 | def adapter_filter(key: str, value: Any) -> bool: 283 | return "adapter_wte" in key or "gating_factor" in key 284 | -------------------------------------------------------------------------------- 
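The adapter module above freezes the base model and trains only the inserted prompt embeddings and their zero-initialized gates. A hedged usage sketch follows (illustrative only, not code from this repository; the config name "tiny_LLaMA_1b" is an assumption that must match an entry in lit_gpt/config.py):

from lit_gpt.adapter import GPT, Config, mark_only_adapter_as_trainable

config = Config.from_name("tiny_LLaMA_1b", adapter_prompt_length=10, adapter_start_layer=2)
model = GPT(config)
mark_only_adapter_as_trainable(model)  # only adapter_wte and gating_factor keep requires_grad=True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable adapter parameters: {trainable} / {total}")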
/MEAP-Pretain/lit_gpt/adapter_v2.py: -------------------------------------------------------------------------------- 1 | """Implementation of the paper: 2 | 3 | LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model 4 | https://arxiv.org/abs/2304.15010 5 | 6 | Port for Lit-GPT 7 | """ 8 | from dataclasses import dataclass 9 | from typing import Any, Dict, List, Optional, Tuple, Type 10 | 11 | import torch 12 | import torch.nn as nn 13 | from typing_extensions import Self 14 | 15 | import lit_gpt 16 | from lit_gpt.adapter import GPT as BaseModel 17 | from lit_gpt.adapter import Block as BaseBlock 18 | from lit_gpt.adapter import Config as BaseConfig 19 | from lit_gpt.adapter import KVCache, RoPECache 20 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 21 | from lit_gpt.model import apply_rope 22 | from lit_gpt.utils import map_old_state_dict_weights 23 | 24 | 25 | @dataclass 26 | class Config(BaseConfig): 27 | @property 28 | def mlp_class(self) -> Type: 29 | return getattr(lit_gpt.adapter_v2, self._mlp_class) 30 | 31 | 32 | def adapter_filter(key: str, value: Any) -> bool: 33 | adapter_substrings = ( 34 | # regular adapter v1 parameters 35 | "adapter_wte", 36 | "gating_factor", 37 | # adapter v2: new bias and scale used in Linear 38 | "adapter_scale", 39 | "adapter_bias", 40 | # adapter v2: Norm parameters are now trainable 41 | "norm_1", 42 | "norm_2", 43 | "ln_f", 44 | ) 45 | return any(s in key for s in adapter_substrings) 46 | 47 | 48 | class AdapterV2Linear(torch.nn.Module): 49 | def __init__(self, in_features: int, out_features: int, **kwargs) -> None: 50 | super().__init__() 51 | self.linear = torch.nn.Linear(in_features, out_features, **kwargs) 52 | self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) 53 | self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) 54 | self.reset_parameters() 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | return self.adapter_scale * (self.linear(x) + self.adapter_bias) 58 | 59 | def reset_parameters(self) -> None: 60 | nn.init.zeros_(self.adapter_bias) 61 | nn.init.ones_(self.adapter_scale) 62 | 63 | 64 | class GPT(BaseModel): 65 | def __init__(self, config: Config) -> None: 66 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 67 | nn.Module.__init__(self) 68 | assert config.padded_vocab_size is not None 69 | self.config = config 70 | 71 | self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False) 72 | self.transformer = nn.ModuleDict( 73 | dict( 74 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 75 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 76 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 77 | ) 78 | ) 79 | 80 | self.rope_cache: Optional[RoPECache] = None 81 | self.mask_cache: Optional[torch.Tensor] = None 82 | self.kv_caches: List[KVCache] = [] 83 | self.adapter_kv_caches: List[KVCache] = [] 84 | 85 | @classmethod 86 | def from_name(cls, name: str, **kwargs: Any) -> Self: 87 | return cls(Config.from_name(name, **kwargs)) 88 | 89 | def _init_weights(self, module: nn.Module) -> None: 90 | """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" 91 | super()._init_weights(module) 92 | if isinstance(module, CausalSelfAttention): 93 | module.reset_parameters() 94 | if isinstance(module, AdapterV2Linear): 95 | module.reset_parameters() 96 | 97 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 98 | """For compatibility with base checkpoints.""" 99 | mapping = {"lm_head.weight": "lm_head.linear.weight"} 100 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 101 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 102 | 103 | 104 | class Block(BaseBlock): 105 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 106 | we replace the attention layer where adaption is implemented.""" 107 | 108 | def __init__(self, config: Config, block_idx: int) -> None: 109 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 110 | nn.Module.__init__(self) 111 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 112 | self.attn = CausalSelfAttention(config, block_idx) 113 | if not config.shared_attention_norm: 114 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 115 | self.mlp = config.mlp_class(config) 116 | 117 | self.config = config 118 | 119 | 120 | class CausalSelfAttention(BaseCausalSelfAttention): 121 | def __init__(self, config: Config, block_idx: int) -> None: 122 | """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for 123 | parameter-efficient fine-tuning. 124 | 125 | *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for 126 | query, key and value for each head) we can do this in a single pass with a single weight matrix. 
127 | """ 128 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 129 | nn.Module.__init__(self) 130 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 131 | # key, query, value projections for all heads, but in a batch 132 | self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) 133 | # output projection 134 | self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) 135 | if block_idx >= config.adapter_start_layer: 136 | # adapter embedding layer 137 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 138 | # gate for adaption 139 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 140 | self.reset_parameters() 141 | self.block_idx = block_idx 142 | 143 | self.config = config 144 | 145 | def forward( 146 | self, 147 | x: torch.Tensor, 148 | rope: RoPECache, 149 | max_seq_length: int, 150 | mask: Optional[torch.Tensor] = None, 151 | input_pos: Optional[torch.Tensor] = None, 152 | kv_cache: Optional[KVCache] = None, 153 | adapter_kv_cache: Optional[KVCache] = None, 154 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 155 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 156 | 157 | qkv = self.attn(x) 158 | 159 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 160 | q_per_kv = self.config.n_head // self.config.n_query_groups 161 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 162 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 163 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 164 | 165 | # split batched computation into three 166 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 167 | 168 | # repeat k and v if necessary 169 | if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
170 | # for MHA this is a no-op 171 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 172 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 173 | 174 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 175 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 176 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 177 | 178 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 179 | 180 | cos, sin = rope 181 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 182 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 183 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 184 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 185 | 186 | if kv_cache is not None: 187 | cache_k, cache_v = kv_cache 188 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 189 | # check if reached token limit 190 | if input_pos[-1] >= max_seq_length: 191 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 192 | # shift 1 position to the left 193 | cache_k = torch.roll(cache_k, -1, dims=2) 194 | cache_v = torch.roll(cache_v, -1, dims=2) 195 | k = cache_k.index_copy_(2, input_pos, k) 196 | v = cache_v.index_copy_(2, input_pos, v) 197 | kv_cache = k, v 198 | 199 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 200 | 201 | if self.block_idx >= self.config.adapter_start_layer: 202 | aT = self.config.adapter_prompt_length 203 | if adapter_kv_cache is not None: 204 | ak, av = adapter_kv_cache 205 | else: 206 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 207 | aqkv = self.attn(prefix) 208 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 209 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 210 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 211 | if self.config.n_query_groups != 1: 212 | # for MHA this is a no-op 213 | ak = ak.repeat_interleave(q_per_kv, dim=2) 214 | av = av.repeat_interleave(q_per_kv, dim=2) 215 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 216 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 217 | adapter_kv_cache = (ak, av) 218 | 219 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 220 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 221 | y = y + self.gating_factor * ay 222 | 223 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 224 | 225 | # output projection 226 | y = self.proj(y) 227 | 228 | return y, kv_cache, adapter_kv_cache 229 | 230 | def reset_parameters(self) -> None: 231 | torch.nn.init.zeros_(self.gating_factor) 232 | 233 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 234 | """For compatibility with base checkpoints.""" 235 | mapping = { 236 | "attn.weight": "attn.linear.weight", 237 | "attn.bias": "attn.linear.bias", 238 | "proj.weight": "proj.linear.weight", 239 | "proj.bias": "proj.linear.bias", 240 | } 241 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 242 | # For compatibility with older checkpoints 243 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 244 | state_dict[key] = state_dict[key].permute(0, 2, 1, 3) 245 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 246 | 247 | 248 | class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): 249 | def __init__(self, config: Config) -> None: 250 | nn.Module.__init__(self) 251 | self.fc = 
AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 252 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 253 | 254 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 255 | """For compatibility with base checkpoints.""" 256 | mapping = { 257 | "fc.weight": "fc.linear.weight", 258 | "fc.bias": "fc.linear.bias", 259 | "proj.weight": "proj.linear.weight", 260 | "proj.bias": "proj.linear.bias", 261 | } 262 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 263 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 264 | 265 | 266 | class LLaMAMLP(lit_gpt.model.LLaMAMLP): 267 | def __init__(self, config: Config) -> None: 268 | nn.Module.__init__(self) 269 | self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 270 | self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 271 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 272 | 273 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 274 | """For compatibility with base checkpoints.""" 275 | mapping = { 276 | "fc_1.weight": "fc_1.linear.weight", 277 | "fc_1.bias": "fc_1.linear.bias", 278 | "fc_2.weight": "fc_2.linear.weight", 279 | "fc_2.bias": "fc_2.linear.bias", 280 | "proj.weight": "proj.linear.weight", 281 | "proj.bias": "proj.linear.bias", 282 | } 283 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 284 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 285 | 286 | 287 | def mark_only_adapter_v2_as_trainable(model: GPT) -> None: 288 | """Sets requires_grad=False for all non-adapter weights""" 289 | for name, param in model.named_parameters(): 290 | param.requires_grad = adapter_filter(name, param) 291 | -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/fused_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import xentropy_cuda_lib 6 | 7 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 8 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent 9 | # version of PyTorch. The following 2 lines are for backward compatibility with 10 | # older PyTorch. 11 | if "all_gather_into_tensor" not in dir(torch.distributed): 12 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 13 | 14 | 15 | class SoftmaxCrossEntropyLossFn(torch.autograd.Function): 16 | @staticmethod 17 | def forward( 18 | ctx, 19 | logits, 20 | labels, 21 | smoothing=0.0, 22 | ignored_index=-100, 23 | inplace_backward=False, 24 | process_group=None, 25 | ): 26 | """ 27 | logits: (batch, vocab_size) 28 | labels: (batch,) 29 | If process_group is not None, we're doing Tensor Parallel: each process is responsible for 30 | one part of the vocab. The loss needs to be aggregated across processes. 
31 | """ 32 | batch, vocab_size = logits.shape 33 | assert labels.shape == (batch,) 34 | world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) 35 | ctx.total_classes = world_size * vocab_size 36 | 37 | if world_size == 1: 38 | losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) 39 | losses.masked_fill_(labels == ignored_index, 0) 40 | labels_local = labels 41 | else: 42 | rank = torch.distributed.get_rank(process_group) 43 | vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size 44 | 45 | # Create a mask of valid vocab ids (1 means it needs to be masked). 46 | labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) 47 | ignored_mask = labels == ignored_index 48 | labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) 49 | 50 | # For tensor parallel cross entropy with smoothing, we want to pass in the total number 51 | # of classes so that smoothing can be applied correctly. If total_classes=-1, use the 52 | # last dimension of the input tensor. 53 | losses, lse_local = xentropy_cuda_lib.forward( 54 | logits, labels_local, smoothing, world_size * vocab_size 55 | ) 56 | assert lse_local.shape == (batch,) 57 | assert losses.shape == (batch,) 58 | losses.masked_fill_(ignored_mask, 0) 59 | # For labels == ignored_index, the loss is always 0. 60 | # If there's no smoothing, if labels are in the vocab of this partition, losses contains 61 | # lse_local - predicted logit, and 0 otherwise. 62 | # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains 63 | # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) 64 | # For labels not in the vocab of this partition, losses contains 65 | # 0.1 * (lse_local - sum logit / total_classes). 66 | 67 | lse_allgather = torch.empty( 68 | world_size, batch, dtype=lse_local.dtype, device=lse_local.device 69 | ) 70 | torch.distributed.all_gather_into_tensor( 71 | lse_allgather, lse_local.contiguous(), group=process_group 72 | ) 73 | handle_losses = torch.distributed.all_reduce( 74 | losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True 75 | ) 76 | lse = torch.logsumexp(lse_allgather, dim=0) 77 | # If there's no smoothing, the total losses are lse_local - predicted_logit, 78 | # we just have to subtract the lse_local and add the lse (global). 79 | # If there's smoothing=0.1, the total losses are 80 | # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) 81 | # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). 
82 | rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor") 83 | lse_local = lse_allgather[ 84 | rank_per_sample, torch.arange(batch, device=lse_allgather.device) 85 | ] 86 | 87 | handle_losses.wait() 88 | if smoothing == 0.0: 89 | losses += lse - lse_local 90 | else: 91 | losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( 92 | lse - lse_allgather.sum(dim=0) 93 | ) 94 | losses.masked_fill_(ignored_mask, 0) 95 | 96 | ctx.save_for_backward(logits, lse, labels_local) 97 | ctx.smoothing = smoothing 98 | ctx.ignored_index = ignored_index 99 | ctx.inplace_backward = inplace_backward 100 | return losses 101 | 102 | @staticmethod 103 | def backward(ctx, grad_loss): 104 | logits, lse, labels = ctx.saved_tensors 105 | grad_loss = grad_loss.contiguous() 106 | grad_loss.masked_fill_(labels == ctx.ignored_index, 0) 107 | grad_logits = xentropy_cuda_lib.backward( 108 | grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes 109 | ) 110 | return grad_logits, None, None, None, None, None, None 111 | 112 | 113 | class FusedCrossEntropyLoss(nn.Module): 114 | def __init__( 115 | self, 116 | ignore_index=-100, 117 | reduction="mean", 118 | label_smoothing=0.0, 119 | inplace_backward=True, 120 | process_group=None, 121 | ): 122 | super().__init__() 123 | if reduction not in ["mean", "none"]: 124 | raise NotImplementedError("Only support reduction = 'mean' or 'none'") 125 | self.ignore_index = ignore_index 126 | self.reduction = reduction 127 | self.label_smoothing = label_smoothing 128 | self.inplace_backward = inplace_backward 129 | self.process_group = process_group 130 | 131 | def forward(self, input, target): 132 | assert input.is_cuda and target.is_cuda 133 | # SoftmaxCrossEntropyLoss implicitly casts to float 134 | if len(input.shape) == 3: 135 | input = input.view(-1, input.size(-1)) 136 | target = target.view(-1) 137 | loss = SoftmaxCrossEntropyLossFn.apply( 138 | input, 139 | target, 140 | self.label_smoothing, 141 | self.ignore_index, 142 | self.inplace_backward, 143 | self.process_group, 144 | ) 145 | if self.reduction == "mean": 146 | return loss.sum() / (target != self.ignore_index).sum() 147 | else: 148 | return loss -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import math 4 | from typing import Optional, Tuple 5 | 6 | import rotary_emb 7 | import torch 8 | from einops import rearrange, repeat 9 | 10 | class ApplyRotaryEmb(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 13 | """ 14 | x: (batch_size, seqlen, nheads, headdim) 15 | cos, sin: (seqlen, rotary_dim / 2) 16 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 17 | of 1st half and 2nd half (GPT-NeoX style). 18 | rotary_dim must be <= headdim 19 | Apply rotary embedding to the first rotary_dim of x. 
20 | """ 21 | batch, seqlen, nheads, headdim = x.shape 22 | rotary_seqlen, rotary_dim = cos.shape 23 | rotary_dim *= 2 24 | assert rotary_dim <= headdim 25 | assert seqlen <= rotary_seqlen 26 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 27 | x_ro = x[..., :rotary_dim] 28 | x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) 29 | out = torch.empty_like(x) if not inplace else x 30 | out_ro = out[..., :rotary_dim] 31 | if inplace: 32 | o1, o2 = x1, x2 33 | else: 34 | o1, o2 = ( 35 | out_ro.chunk(2, dim=-1) 36 | if not interleaved 37 | else (out_ro[..., ::2], out_ro[..., 1::2]) 38 | ) 39 | rotary_emb.apply_rotary( 40 | x1, 41 | x2, 42 | rearrange(cos[:seqlen], "s d -> s 1 d"), 43 | rearrange(sin[:seqlen], "s d -> s 1 d"), 44 | o1, 45 | o2, 46 | False, 47 | ) 48 | if not inplace and rotary_dim < headdim: 49 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 50 | ctx.save_for_backward(cos, sin) 51 | ctx.interleaved = interleaved 52 | ctx.inplace = inplace 53 | return out if not inplace else x 54 | 55 | @staticmethod 56 | def backward(ctx, do): 57 | cos, sin = ctx.saved_tensors 58 | _, seqlen, _, headdim = do.shape 59 | rotary_dim = cos.shape[-1] 60 | rotary_dim *= 2 61 | inplace = ctx.inplace 62 | do_ro = do[..., :rotary_dim] 63 | do1, do2 = ( 64 | do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) 65 | ) 66 | dx = torch.empty_like(do) if not inplace else do 67 | if inplace: 68 | dx1, dx2 = do1, do2 69 | else: 70 | dx_ro = dx[..., :rotary_dim] 71 | dx1, dx2 = ( 72 | dx_ro.chunk(2, dim=-1) 73 | if not ctx.interleaved 74 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 75 | ) 76 | rotary_emb.apply_rotary( 77 | do1, 78 | do2, 79 | rearrange(cos[:seqlen], "s d -> s 1 d"), 80 | rearrange(sin[:seqlen], "s d -> s 1 d"), 81 | dx1, 82 | dx2, 83 | True, 84 | ) 85 | if not inplace and rotary_dim < headdim: 86 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 87 | return dx, None, None, None, None 88 | 89 | 90 | apply_rotary_emb_func = ApplyRotaryEmb.apply 91 | 92 | -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/model.py: -------------------------------------------------------------------------------- 1 | """Full definition of a GPT NeoX Language Model, all of it in this single file. 2 | 3 | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and 4 | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 
5 | """ 6 | import math 7 | from typing import Any, List, Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from lightning_utilities.core.imports import RequirementCache 12 | from typing_extensions import Self 13 | from flash_attn import flash_attn_func 14 | from lit_gpt.config import Config 15 | from xformers.ops import SwiGLU 16 | from .fused_rotary_embedding import apply_rotary_emb_func 17 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 18 | KVCache = Tuple[torch.Tensor, torch.Tensor] 19 | FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 20 | 21 | 22 | class GPT(nn.Module): 23 | def __init__(self, config: Config) -> None: 24 | super().__init__() 25 | assert config.padded_vocab_size is not None 26 | self.config = config 27 | 28 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 29 | self.transformer = nn.ModuleDict( 30 | dict( 31 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 32 | h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), 33 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 34 | ) 35 | ) 36 | self.rope_cache: Optional[RoPECache] = None 37 | self.mask_cache: Optional[torch.Tensor] = None 38 | self.kv_caches: List[KVCache] = [] 39 | 40 | def _init_weights(self, module: nn.Module, n_layer) -> None: 41 | """Meant to be used with `gpt.apply(gpt._init_weights)`.""" 42 | # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf 43 | if isinstance(module, nn.Embedding): 44 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 45 | # RWKV: set it to 1e-4 46 | # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4) 47 | elif isinstance(module, nn.Linear): 48 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 49 | if module.bias is not None: 50 | torch.nn.init.zeros_(module.bias) 51 | # GPT-NeoX 52 | for name, p in module.named_parameters(): 53 | if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3 54 | nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer) 55 | 56 | 57 | def reset_cache(self) -> None: 58 | self.kv_caches.clear() 59 | if self.mask_cache is not None and self.mask_cache.device.type == "xla": 60 | # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 61 | self.rope_cache = None 62 | self.mask_cache = None 63 | 64 | def forward( 65 | self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None 66 | ) -> torch.Tensor: 67 | B, T = idx.size() 68 | use_kv_cache = input_pos is not None 69 | 70 | block_size = self.config.block_size 71 | if max_seq_length is None: 72 | max_seq_length = block_size 73 | if use_kv_cache: # not relevant otherwise 74 | assert ( 75 | max_seq_length >= T 76 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 77 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 78 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 79 | 80 | if self.rope_cache is None: 81 | self.rope_cache = self.build_rope_cache(idx) 82 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask 83 | # for the kv-cache support (only during inference), we only create it in that situation 84 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 85 | if use_kv_cache and self.mask_cache is None: 86 | self.mask_cache = self.build_mask_cache(idx) 87 | 88 | cos, sin = self.rope_cache 89 | if use_kv_cache: 90 | 91 | cos = cos.index_select(0, input_pos) 92 | sin = sin.index_select(0, input_pos) 93 | mask = self.mask_cache.index_select(2, input_pos) 94 | mask = mask[:, :, :, :max_seq_length] 95 | else: 96 | cos = cos[:T] 97 | sin = sin[:T] 98 | mask = None 99 | 100 | # forward the model itself 101 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 102 | 103 | if not use_kv_cache: 104 | for block in self.transformer.h: 105 | x, *_ = block(x, (cos, sin), max_seq_length) 106 | else: 107 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2) 108 | for i, block in enumerate(self.transformer.h): 109 | x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) 110 | 111 | x = self.transformer.ln_f(x) 112 | 113 | return self.lm_head(x) # (b, t, vocab_size) 114 | 115 | @classmethod 116 | def from_name(cls, name: str, **kwargs: Any) -> Self: 117 | return cls(Config.from_name(name, **kwargs)) 118 | 119 | def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: 120 | return build_rope_cache( 121 | seq_len=self.config.block_size, 122 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 123 | dtype=torch.bfloat16, 124 | device=idx.device, 125 | condense_ratio=self.config.condense_ratio, 126 | ) 127 | 128 | def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: 129 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 130 | return torch.tril(ones).unsqueeze(0).unsqueeze(0) 131 | 132 | def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]: 133 | B = idx.size(0) 134 | heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups 135 | 136 | k_cache_shape = ( 137 | B, 138 | max_seq_length, 139 | heads, 140 | rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), 141 | ) 142 | v_cache_shape = (B, max_seq_length, heads, self.config.head_size) 143 | device = idx.device 144 | return [ 145 | (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) 146 | for _ in range(self.config.n_layer) 147 | ] 148 | 149 | 150 | class Block(nn.Module): 151 | def __init__(self, config: Config) -> None: 152 | super().__init__() 153 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 154 | self.attn = CausalSelfAttention(config) 155 | if not config.shared_attention_norm: 156 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 157 | self.mlp = config.mlp_class(config) 158 | self.config = config 159 | def forward( 160 | self, 161 | x: torch.Tensor, 162 | rope: RoPECache, 163 | max_seq_length: int, 164 | mask: Optional[torch.Tensor] = None, 165 | input_pos: Optional[torch.Tensor] = None, 166 | kv_cache: Optional[KVCache] = None, 167 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 168 | 169 | n_1 = self.norm_1(x) 170 | h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) 171 | if self.config.parallel_residual: 172 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 173 | x = x + h + 
self.mlp(n_2) 174 | else: 175 | if self.config.shared_attention_norm: 176 | raise NotImplementedError( 177 | "No checkpoint amongst the ones we support uses this configuration" 178 | " (non-parallel residual and shared attention norm)." 179 | ) 180 | 181 | x = x + h 182 | x = x + self.mlp(self.norm_2(x)) 183 | return x, new_kv_cache 184 | 185 | 186 | class CausalSelfAttention(nn.Module): 187 | def __init__(self, config: Config) -> None: 188 | super().__init__() 189 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 190 | # key, query, value projections for all heads, but in a batch 191 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 192 | # output projection 193 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 194 | 195 | self.config = config 196 | 197 | def forward( 198 | self, 199 | x: torch.Tensor, 200 | rope: RoPECache, 201 | max_seq_length: int, 202 | mask: Optional[torch.Tensor] = None, 203 | input_pos: Optional[torch.Tensor] = None, 204 | kv_cache: Optional[KVCache] = None, 205 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 206 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 207 | 208 | qkv = self.attn(x) 209 | 210 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 211 | q_per_kv = self.config.n_head // self.config.n_query_groups 212 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 213 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs) 214 | # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 215 | 216 | # split batched computation into three 217 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 218 | 219 | # repeat k and v if necessary 220 | # Peiyuan: we do not need to do this as flash attention 2 already support GQA 221 | # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
222 | # # for MHA this is a no-op 223 | # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 224 | # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 225 | 226 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 227 | k = k.reshape(B, T, -1, self.config.head_size) 228 | v = v.reshape(B, T, -1, self.config.head_size) 229 | 230 | cos, sin = rope 231 | 232 | # apply rope in fp32 significanly stabalize training 233 | # fused rope expect (batch_size, seqlen, nheads, headdim) 234 | q = apply_rotary_emb_func(q, cos, sin, False, True) 235 | k = apply_rotary_emb_func(k, cos, sin, False, True) 236 | 237 | # n_elem = int(self.config.rotary_percentage * self.config.head_size) 238 | 239 | # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 240 | # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 241 | # print( (q_roped - q).sum()) 242 | # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 243 | # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 244 | 245 | if kv_cache is not None: 246 | cache_k, cache_v = kv_cache 247 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 248 | # check if reached token limit 249 | if input_pos[-1] >= max_seq_length: 250 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 251 | # shift 1 position to the left 252 | cache_k = torch.roll(cache_k, -1, dims=1) 253 | cache_v = torch.roll(cache_v, -1, dims=1) 254 | 255 | k = cache_k.index_copy_(1, input_pos, k) 256 | v = cache_v.index_copy_(1, input_pos, v) 257 | kv_cache = k, v 258 | 259 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 260 | 261 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 262 | 263 | # output projection 264 | y = self.proj(y) 265 | 266 | return y, kv_cache 267 | 268 | def scaled_dot_product_attention( 269 | self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None 270 | ): 271 | scale = 1.0 / math.sqrt(self.config.head_size) 272 | 273 | if ( 274 | FlashAttention2Available 275 | and mask is None 276 | and q.device.type == "cuda" 277 | and q.dtype in (torch.float16, torch.bfloat16) 278 | ): 279 | from flash_attn import flash_attn_func 280 | 281 | return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True) 282 | q = q.transpose(1, 2) 283 | k = k.transpose(1, 2) 284 | v = v.transpose(1, 2) 285 | if q.size() != k.size(): 286 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 287 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 288 | y = torch.nn.functional.scaled_dot_product_attention( 289 | q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None 290 | ) 291 | return y.transpose(1, 2) 292 | 293 | 294 | class GptNeoxMLP(nn.Module): 295 | def __init__(self, config: Config) -> None: 296 | super().__init__() 297 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 298 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 299 | 300 | def forward(self, x: torch.Tensor) -> torch.Tensor: 301 | x = self.fc(x) 302 | x = torch.nn.functional.gelu(x) 303 | return self.proj(x) 304 | 305 | 306 | class LLaMAMLP(nn.Module): 307 | def __init__(self, config: Config) -> None: 308 | super().__init__() 309 | # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 310 | # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 311 | # self.proj = 
nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 312 | self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False) 313 | def forward(self, x: torch.Tensor) -> torch.Tensor: 314 | # x_fc_1 = self.fc_1(x) 315 | # x_fc_2 = self.fc_2(x) 316 | # x = torch.nn.functional.silu(x_fc_1) * x_fc_2 317 | # return self.proj(x) 318 | return self.swiglu(x) 319 | 320 | 321 | def build_rope_cache( 322 | seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1 323 | ) -> RoPECache: 324 | """Enhanced Transformer with Rotary Position Embedding. 325 | 326 | Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ 327 | transformers/rope/__init__.py. MIT License: 328 | https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 329 | """ 330 | # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 331 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) 332 | 333 | # Create position indexes `[0, 1, ..., seq_len - 1]` 334 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 335 | 336 | # Calculate the product of position index and $\theta_i$ 337 | idx_theta = torch.outer(seq_idx, theta) 338 | 339 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 340 | 341 | # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding 342 | if dtype == torch.bfloat16: 343 | return cos.bfloat16(), sin.bfloat16() 344 | # this is to mimic the behaviour of complex32, else we will get different results 345 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 346 | return cos.half(), sin.half() 347 | return cos, sin 348 | 349 | 350 | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 351 | head_size = x.size(-1) 352 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 353 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 354 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 355 | roped = (x * cos) + (rotated * sin) 356 | return roped.type_as(x) 357 | -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Very loosely inspired by indexed_dataset in Fairseq, Megatron 2 | # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py 3 | 4 | 5 | import os 6 | import random 7 | import struct 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import IterableDataset, get_worker_info 12 | 13 | dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} 14 | 15 | 16 | def code(dtype): 17 | for k in dtypes: 18 | if dtypes[k] == dtype: 19 | return k 20 | raise ValueError(dtype) 21 | 22 | 23 | HDR_MAGIC = b"LITPKDS" 24 | HDR_SIZE = 24 # bytes 25 | 26 | 27 | class PackedDataset(IterableDataset): 28 | def __init__( 29 | self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 30 | ): 31 | self._filenames = filenames 32 | self._n_chunks = n_chunks 33 | self._block_size = block_size 34 | self._seed = seed 35 | self._shuffle = shuffle 36 | self._wrap = wrap 37 | self._num_processes = num_processes 38 | self._process_rank = process_rank 39 | 40 | def __iter__(self): 41 | worker_info = get_worker_info() 42 | num_workers 
= worker_info.num_workers if worker_info is not None else 1 43 | worker_id = worker_info.id if worker_info is not None else 0 44 | num_shards = num_workers * self._num_processes 45 | shard_id = self._process_rank * num_workers + worker_id 46 | 47 | max_num_files = len(self._filenames) // num_shards * num_shards 48 | filenames = self._filenames[shard_id:max_num_files:num_shards] 49 | 50 | return PackedDatasetIterator( 51 | filenames=filenames, 52 | n_chunks=self._n_chunks, 53 | block_size=self._block_size, 54 | seed=self._seed, 55 | shuffle=self._shuffle, 56 | wrap=self._wrap, 57 | ) 58 | 59 | 60 | class PackedDatasetBuilder(object): 61 | def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None): 62 | if dtype == "auto": 63 | if vocab_size is None: 64 | raise ValueError("vocab_size cannot be None when dtype='auto'") 65 | if vocab_size is not None and vocab_size < 65500: 66 | self._dtype = np.uint16 67 | else: 68 | self._dtype = np.int32 69 | else: 70 | self._dtype = dtype 71 | self._counter = 0 72 | self._chunk_size = chunk_size 73 | self._outdir = outdir 74 | self._prefix = prefix 75 | self._sep_token = sep_token 76 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype) 77 | self._arr.fill(self._sep_token) 78 | self._idx = 0 79 | self._version = 1 80 | self._filenames = [] 81 | 82 | def _write_chunk(self): 83 | filename = f"{self._prefix}_{self._counter:010d}.bin" 84 | filename = os.path.join(self._outdir, filename) 85 | 86 | with open(filename, "wb") as f: 87 | f.write(HDR_MAGIC) 88 | f.write(struct.pack("<Q", self._version)) 89 | f.write(struct.pack("<B", code(self._dtype))) 90 | f.write(struct.pack("<Q", self._chunk_size)) 91 | f.write(self._arr.tobytes(order="C")) 92 | 93 | self._filenames.append(filename) 94 | self._counter += 1 95 | self._arr.fill(self._sep_token) 96 | self._idx = 0 97 | 98 | @property 99 | def dtype(self): 100 | return self._dtype 101 | 102 | @property 103 | def filenames(self): 104 | return self._filenames.copy() 105 | 106 | def add_array(self, arr): 107 | while self._idx + arr.shape[0] > self._chunk_size: 108 | part_len = self._chunk_size - self._idx 109 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 110 | self._write_chunk() 111 | arr = arr[part_len:] 112 | 113 | arr_len = arr.shape[0] 114 | self._arr[self._idx : self._idx + arr_len] = arr 115 | self._idx += arr_len 116 | 117 | def write_reminder(self): 118 | self._write_chunk() 119 | 120 | 121 | class PackedDatasetIterator: 122 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 123 | self._seed = seed 124 | self._shuffle = shuffle 125 | self._rng = np.random.default_rng(seed) if shuffle else None 126 | self._block_idxs = None 127 | 128 | self._wrap = wrap 129 | 130 | # TODO: instead of filenames, we could have a single text stream 131 | # (or text file) with the sequence of all files to be 132 | # fetched/loaded. 133 | self._filenames = filenames 134 | self._file_idx = 0 135 | 136 | self._n_chunks = n_chunks 137 | 138 | self._dtype = None 139 | self._block_size = block_size 140 | self._n_blocks = None 141 | 142 | self._mmaps = [] 143 | self._buffers = [] 144 | 145 | self._block_idxs = [] 146 | self._curr_idx = 0 147 | 148 | self._load_n_chunks() 149 | 150 | def _read_header(self, path): 151 | with open(path, "rb") as f: 152 | magic = f.read(len(HDR_MAGIC)) 153 | assert magic == HDR_MAGIC, "File doesn't match expected format."
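# Header layout sketch (consistent with HDR_SIZE = 24 above): the 7-byte HDR_MAGIC, an 8-byte little-endian version ("<Q"), a 1-byte dtype code ("<B"), and an 8-byte chunk size ("<Q"); the token payload then starts at byte offset HDR_SIZE, which is why the memmaps below use offset=HDR_SIZE.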
154 | version = struct.unpack("<Q", f.read(8)) 155 | assert version == (1,) 156 | (dtype_code,) = struct.unpack("<B", f.read(1)) 157 | dtype = dtypes[dtype_code] 158 | (chunk_size,) = struct.unpack("<Q", f.read(8)) 159 | return dtype, chunk_size 160 | 161 | def _close_mmaps(self): 162 | for mmap in self._mmaps: 163 | mmap._mmap.close() 164 | 165 | def _load_n_chunks(self): 166 | self._close_mmaps() 167 | self._mmaps = [] 168 | self._buffers = [] 169 | 170 | if self._n_chunks > len(self._filenames[self._file_idx :]): 171 | # if not self._wrap: 172 | # raise StopIteration 173 | self._file_idx = 0 174 | 175 | for i in range(self._n_chunks): 176 | filename = self._filenames[self._file_idx + i] 177 | if self._dtype is None: 178 | self._dtype, self._chunk_size = self._read_header(filename) 179 | self._n_blocks = self._chunk_size // self._block_size 180 | # TODO: check header matches with previous files 181 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 182 | self._mmaps.append(mmap) 183 | self._buffers.append(memoryview(mmap)) 184 | 185 | self._file_idx += self._n_chunks 186 | n_all_blocks = self._n_chunks * self._n_blocks 187 | 188 | self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) 189 | 190 | self._curr_idx = 0 191 | 192 | def __del__(self): 193 | self._close_mmaps() 194 | del self._mmaps 195 | del self._buffers 196 | 197 | def __iter__(self): 198 | return self 199 | 200 | def __next__(self): 201 | if self._curr_idx >= len(self._block_idxs): 202 | self._load_n_chunks() 203 | # TODO: trigger fetching next next n_chunks if remote 204 | block_idx = self._block_idxs[self._curr_idx] 205 | chunk_id = block_idx // self._n_blocks 206 | buffer = self._buffers[chunk_id] 207 | elem_id = (block_idx % self._n_blocks) * self._block_size 208 | offset = np.dtype(self._dtype).itemsize * elem_id 209 | arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) 210 | self._curr_idx += 1 211 | return torch.from_numpy(arr.astype(np.int64)) 212 | 213 | 214 | class CombinedDataset(IterableDataset): 215 | def __init__(self, datasets, seed, weights=None): 216 | self._seed = seed 217 | self._datasets = datasets 218 | self._weights = weights 219 | n_datasets = len(datasets) 220 | if weights is None: 221 | self._weights = [1 / n_datasets] * n_datasets 222 | 223 | def __iter__(self): 224 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights) 225 | 226 | 227 | class CombinedDatasetIterator: 228 | def __init__(self, datasets, seed, weights): 229 | self._datasets = [iter(el) for el in datasets] 230 | self._weights = weights 231 | self._rng = random.Random(seed) 232 | 233 | def __next__(self): 234 | (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) 235 | return next(dataset) 236 | -------------------------------------------------------------------------------- /MEAP-Pretain/lit_gpt/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | class Tokenizer: 9 | def __init__(self, checkpoint_dir: Path) -> None: 10 | # some checkpoints have both files, `.model` takes precedence 11 | if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): 12 | from sentencepiece import SentencePieceProcessor 13 | 14 | self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) 15 | self.backend = "sentencepiece" 16 | self.bos_id = self.processor.bos_id() 17 | self.eos_id = self.processor.eos_id() 18 | elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): 19 | from tokenizers import Tokenizer as HFTokenizer 20 | 21 | self.processor = HFTokenizer.from_file(str(vocabulary_path)) 22 | self.backend = "huggingface" 23 | with open(checkpoint_dir / "tokenizer_config.json") as fp: 24 | config = json.load(fp) 25 | bos_token = config.get("bos_token") 26 | self.bos_id =
self.token_to_id(bos_token) if bos_token is not None else None 27 | self.eos_id = self.token_to_id(config["eos_token"]) 28 | else: 29 | raise NotImplementedError 30 | 31 | @property 32 | def vocab_size(self) -> int: 33 | if self.backend == "huggingface": 34 | return self.processor.get_vocab_size(with_added_tokens=False) 35 | if self.backend == "sentencepiece": 36 | return self.processor.vocab_size() 37 | raise RuntimeError 38 | 39 | def token_to_id(self, token: str) -> int: 40 | if self.backend == "huggingface": 41 | id_ = self.processor.token_to_id(token) 42 | elif self.backend == "sentencepiece": 43 | id_ = self.processor.piece_to_id(token) 44 | else: 45 | raise RuntimeError 46 | if id_ is None: 47 | raise ValueError(f"token {token!r} not found in the collection.") 48 | return id_ 49 | 50 | def encode( 51 | self, 52 | string: str, 53 | device: Optional[torch.device] = None, 54 | bos: bool = False, 55 | eos: bool = True, 56 | max_length: int = -1, 57 | ) -> torch.Tensor: 58 | if self.backend == "huggingface": 59 | tokens = self.processor.encode(string).ids 60 | elif self.backend == "sentencepiece": 61 | tokens = self.processor.encode(string) 62 | else: 63 | raise RuntimeError 64 | if bos: 65 | bos_id = self.bos_id 66 | if bos_id is None: 67 | raise NotImplementedError("This tokenizer does not defined a bos token") 68 | tokens = [bos_id] + tokens 69 | if eos: 70 | tokens = tokens + [self.eos_id] 71 | if max_length > 0: 72 | tokens = tokens[:max_length] 73 | return torch.tensor(tokens, dtype=torch.int, device=device) 74 | 75 | def decode(self, tensor: torch.Tensor) -> str: 76 | tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() 77 | return self.processor.decode(tokens) 78 | -------------------------------------------------------------------------------- /MEAP-Pretain/pretrained/meap_0.1b_2b_0.05.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | 14 | from transformers import AutoTokenizer 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 19 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 20 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 21 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 22 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 23 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 24 | from pytorch_lightning.loggers import WandbLogger 25 | from lit_gpt import FusedCrossEntropyLoss 26 | import random 27 | 28 | model_name = "tiny_LLaMA_0.1b_mask" 29 | name = "meap_0.1b_mask" 30 | out_dir = Path("out_meap_0.1b_mask_1220_0.05") / name 31 | #num_nodes=1 32 | # Hyperparameters 33 | num_of_devices = 8 34 | global_batch_size = 256 35 | learning_rate = 4e-4 36 | micro_batch_size = 16 37 | max_step = 1907 38 | warmup_steps = 190 39 | log_step_interval = 10 40 | eval_iters = 100 41 | save_step_interval = 1900 42 | eval_step_interval = 600 43 | 44 | 45 | 
weight_decay = 5e-2 46 | beta1 = 0.9 47 | beta2 = 0.95 48 | grad_clip = 1.0 49 | decay_lr = True 50 | min_lr = 4e-5 51 | 52 | batch_size = global_batch_size // num_of_devices 53 | gradient_accumulation_steps = batch_size // micro_batch_size 54 | assert gradient_accumulation_steps > 0 55 | warmup_iters = warmup_steps * gradient_accumulation_steps 56 | 57 | checkpoint_dir = Path("../tokenizer") 58 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 59 | special_tokens = { 60 | "additional_special_tokens": ["[MASK]"], 61 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 62 | "eos_token": tokenizer.eos_token, 63 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 64 | } 65 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 66 | use_mask = True 67 | mask_ratio = 0.05 68 | max_iters = max_step * gradient_accumulation_steps 69 | lr_decay_iters = max_iters 70 | log_iter_interval = log_step_interval * gradient_accumulation_steps 71 | 72 | 73 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 74 | train_data_config = [ 75 | ("c4", 1.0), 76 | ] 77 | 78 | val_data_config = [ 79 | ("validation", 1.0), 80 | ] 81 | 82 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 83 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 84 | wandb_logger = WandbLogger(project="NTP-MASK", 85 | name="mask_0.1b_2b_1220_0.05", 86 | log_model=True, 87 | save_dir=str(out_dir), 88 | config={ 89 | 90 | "model_name": model_name, 91 | "total_params": None, 92 | 93 | "global_batch_size": global_batch_size, 94 | "micro_batch_size": micro_batch_size, 95 | "gradient_accumulation_steps": gradient_accumulation_steps, 96 | "learning_rate": learning_rate, 97 | "min_lr": min_lr, 98 | "weight_decay": weight_decay, 99 | "warmup_steps": warmup_steps, 100 | "max_step": max_step, 101 | "beta1": beta1, 102 | "beta2": beta2, 103 | "grad_clip": grad_clip, 104 | 105 | 106 | "train_data_config": train_data_config, 107 | "val_data_config": val_data_config, 108 | 109 | 110 | "num_devices": num_of_devices, 111 | "precision": None, 112 | } 113 | 114 | ) 115 | 116 | 117 | 118 | def setup( 119 | devices: int = 8, 120 | train_data_dir: Path = Path("../c4_bin"), 121 | val_data_dir: Optional[Path] = None, 122 | precision: Optional[str] = None, 123 | tpu: bool = False, 124 | resume: Union[bool, Path] = False, 125 | ) -> None: 126 | print("devices: ", devices) 127 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 128 | 129 | if devices > 1: 130 | if tpu: 131 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
132 | devices = "auto" 133 | strategy = XLAStrategy(sync_module_states=False) 134 | else: 135 | strategy = FSDPStrategy( 136 | auto_wrap_policy={Block}, 137 | activation_checkpointing_policy=None, 138 | state_dict_type="full", 139 | limit_all_gathers=True, 140 | cpu_offload=False, 141 | ) 142 | else: 143 | strategy = "auto" 144 | 145 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 146 | fabric.print(hparams) 147 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 148 | main(fabric, train_data_dir, val_data_dir, resume) 149 | 150 | 151 | def main(fabric, train_data_dir, val_data_dir, resume): 152 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 153 | 154 | if fabric.global_rank == 0: 155 | out_dir.mkdir(parents=True, exist_ok=True) 156 | 157 | config = Config.from_name(model_name) 158 | fabric.print(f"train_data_dir: {train_data_dir}") 159 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 160 | fabric.print(f"val_data_dir: {val_data_dir}") 161 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 162 | #train_data_dir=Path("./output_data/slim_star_combined ") 163 | #val_data_dir=Path("./output_data/slim_star_combined ") 164 | fabric.print(f"train_data_dir: {train_data_dir}") 165 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 166 | fabric.print(f"val_data_dir: {val_data_dir}") 167 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 168 | train_dataloader, val_dataloader = create_dataloaders( 169 | batch_size=micro_batch_size, 170 | block_size=config.block_size, 171 | fabric=fabric, 172 | train_data_dir=train_data_dir, 173 | val_data_dir=val_data_dir, 174 | seed=3407, 175 | ) 176 | if val_dataloader is None: 177 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 178 | else: 179 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 180 | 181 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 182 | 183 | fabric.print(f"Loading model with {config.__dict__}") 184 | t0 = time.perf_counter() 185 | with fabric.init_module(empty_init=False): 186 | model = GPT(config) 187 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 188 | 189 | total_params = num_parameters(model) 190 | #wandb_logger.experiment.config.update({"total_params": total_params}) 191 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 192 | fabric.print(f"Total parameters {num_parameters(model):,}") 193 | 194 | model = fabric.setup(model) 195 | optimizer = torch.optim.AdamW( 196 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 197 | ) 198 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 199 | optimizer = fabric.setup_optimizers(optimizer) 200 | 201 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 202 | 203 | if resume is True: 204 | resume = sorted(out_dir.glob("*.pth"))[-1] 205 | if resume : 206 | fabric.print(f"Resuming training from {resume}") 207 | fabric.load(resume, state) 208 | 209 | train_time = time.perf_counter() 210 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 211 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 212 | if fabric.device.type == "cuda": 213 | fabric.print(f"Memory 
used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 214 | 215 | def get_cycled_seed(curr_iter): 216 | cycle_range = 51 217 | cycled_value = 50 + (curr_iter % cycle_range) 218 | return cycled_value 219 | 220 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 221 | model = state["model"] 222 | optimizer = state["optimizer"] 223 | 224 | if val_dataloader is not None: 225 | validate(fabric, model, val_dataloader) # sanity check 226 | 227 | with torch.device("meta"): 228 | meta_model = GPT(model.config) 229 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 230 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 231 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 232 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 233 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 234 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 235 | # measured_flos run in meta. Will trigger fusedRMSNorm error 236 | #measured_flops = measure_flops(meta_model, x) 237 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 238 | del meta_model, x 239 | 240 | total_lengths = 0 241 | total_t0 = time.perf_counter() 242 | 243 | if fabric.device.type == "xla": 244 | import torch_xla.core.xla_model as xm 245 | 246 | xm.mark_step() 247 | 248 | 249 | initial_iter = state["iter_num"] 250 | curr_iter = 0 251 | # random.seed(curr_iter) 252 | 253 | loss_func = FusedCrossEntropyLoss() 254 | for train_data in train_dataloader: 255 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 256 | if resume: 257 | if curr_iter < initial_iter: 258 | curr_iter += 1 259 | continue 260 | else: 261 | resume = False 262 | curr_iter = -1 263 | fabric.barrier() 264 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 265 | if state["iter_num"] >= max_iters: 266 | break 267 | 268 | # determine and set the learning rate for this iteration 269 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 270 | for param_group in optimizer.param_groups: 271 | param_group["lr"] = lr 272 | 273 | iter_t0 = time.perf_counter() 274 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 275 | if use_mask: 276 | random.seed(get_cycled_seed(state["step_count"])) 277 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 278 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 279 | #print(f"mask_positions: {mask_positions}") 280 | bs, seq_len = input_ids.shape 281 | #input_ids[0,0] = 32000 282 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 283 | 284 | 285 | 286 | 287 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 288 | 289 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 290 | with fabric.no_backward_sync(model, enabled=is_accumulating): 291 | logits = model(input_ids) 292 | loss = loss_func(logits, targets) 293 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 294 | fabric.backward(loss / gradient_accumulation_steps) 295 | 296 | if not is_accumulating: 297 | wandb_logger.log_metrics({ 298 | "train/loss": loss.item(), 299 | "train/learning_rate": lr, 300 | "train/step": state["step_count"], 301 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() 
else 0, 302 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 303 | "performance/iter_time": (t1 - iter_t0) * 1000, 304 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 305 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 306 | }, step=state["step_count"]) 307 | 308 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 309 | optimizer.step() 310 | optimizer.zero_grad() 311 | state["step_count"] += 1 312 | elif fabric.device.type == "xla": 313 | xm.mark_step() 314 | state["iter_num"] += 1 315 | 316 | # input_id: B L 317 | total_lengths += input_ids.size(1) 318 | t1 = time.perf_counter() 319 | fabric.print( 320 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 321 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 322 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 323 | # print days as well 324 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 325 | ) 326 | 327 | monitor.on_train_batch_end( 328 | state["iter_num"] * micro_batch_size, 329 | t1 - total_t0, 330 | # this assumes that device FLOPs are the same and that all devices have the same batch size 331 | fabric.world_size, 332 | state["step_count"], 333 | flops_per_batch=estimated_flops, 334 | lengths=total_lengths, 335 | train_loss = loss.item() 336 | ) 337 | 338 | 339 | 340 | 341 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 342 | 343 | t0 = time.perf_counter() 344 | val_loss = validate(fabric, model, val_dataloader) 345 | t1 = time.perf_counter() - t0 346 | monitor.eval_end(t1) 347 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 348 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 349 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 350 | fabric.barrier() 351 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 352 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 353 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 354 | fabric.save(checkpoint_path, state) 355 | 356 | 357 | @torch.no_grad() 358 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 359 | fabric.print("Validating ...") 360 | model.eval() 361 | 362 | losses = torch.zeros(eval_iters, device=fabric.device) 363 | for k, val_data in enumerate(val_dataloader): 364 | if k >= eval_iters: 365 | break 366 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 367 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 368 | logits = model(input_ids) 369 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 370 | 371 | # loss_func = FusedCrossEntropyLoss() 372 | # loss = loss_func(logits, targets) 373 | losses[k] = loss.item() 374 | 375 | out = losses.mean() 376 | 377 | model.train() 378 | return out 379 | 380 | 381 | def create_dataloader( 382 | batch_size: int, block_size: 
int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 383 | ) -> DataLoader: 384 | datasets = [] 385 | data_config = train_data_config if split == "train" else val_data_config 386 | for prefix, _ in data_config: 387 | #print(f"data_dir: {data_dir}") 388 | #print(f"prefix: {prefix}") 389 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 390 | #print(f"filenames: {filenames}") 391 | random.seed(seed) 392 | random.shuffle(filenames) 393 | #print(f"filenames after shuffle: {filenames}") 394 | dataset = PackedDataset( 395 | filenames, 396 | # n_chunks control the buffer size. 397 | # Note that the buffer size also impacts the random shuffle 398 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 399 | n_chunks=1, 400 | block_size=block_size, 401 | shuffle=shuffle, 402 | seed=seed+fabric.global_rank, 403 | num_processes=fabric.world_size, 404 | process_rank=fabric.global_rank, 405 | ) 406 | datasets.append(dataset) 407 | 408 | if not datasets: 409 | raise RuntimeError( 410 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 411 | ) 412 | weights = [weight for _, weight in data_config] 413 | sum_weights = sum(weights) 414 | weights = [el / sum_weights for el in weights] 415 | 416 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 417 | 418 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 419 | 420 | 421 | def create_dataloaders( 422 | batch_size: int, 423 | block_size: int, 424 | fabric, 425 | train_data_dir: Path = Path("data/redpajama_sample"), 426 | val_data_dir: Optional[Path] = None, 427 | seed: int = 12345, 428 | ) -> Tuple[DataLoader, DataLoader]: 429 | # Increase by one because we need the next word as well 430 | effective_block_size = block_size + 1 431 | train_dataloader = create_dataloader( 432 | batch_size=batch_size, 433 | block_size=effective_block_size, 434 | fabric=fabric, 435 | data_dir=train_data_dir, 436 | shuffle=True, 437 | seed=seed, 438 | split="train" 439 | ) 440 | val_dataloader = ( 441 | create_dataloader( 442 | batch_size=batch_size, 443 | block_size=effective_block_size, 444 | fabric=fabric, 445 | data_dir=val_data_dir, 446 | shuffle=False, 447 | seed=seed, 448 | split="validation" 449 | ) 450 | if val_data_dir 451 | else None 452 | ) 453 | return train_dataloader, val_dataloader 454 | 455 | 456 | # learning rate decay scheduler (cosine with warmup) 457 | def get_lr(it): 458 | # 1) linear warmup for warmup_iters steps 459 | if it < warmup_iters: 460 | return learning_rate * it / warmup_iters 461 | # 2) if it > lr_decay_iters, return min learning rate 462 | if it > lr_decay_iters: 463 | return min_lr 464 | # 3) in between, use cosine decay down to min learning rate 465 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 466 | assert 0 <= decay_ratio <= 1 467 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 468 | return min_lr + coeff * (learning_rate - min_lr) 469 | 470 | 471 | if __name__ == "__main__": 472 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 473 | # torch.backends.cuda.enable_flash_sdp(False) 474 | torch.set_float32_matmul_precision("high") 475 | 476 | from jsonargparse import CLI 477 | 478 | CLI(setup) 479 | -------------------------------------------------------------------------------- /MEAP-Pretain/pretrained/meap_0.5b_10b.py: 
-------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | 14 | from transformers import AutoTokenizer 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 19 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 20 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 21 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 22 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 23 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 24 | from pytorch_lightning.loggers import WandbLogger 25 | from lit_gpt import FusedCrossEntropyLoss 26 | import random 27 | 28 | model_name = "tiny_LLaMA_0.5b_mask" 29 | name = "meap_0.5b_mask" 30 | out_dir = Path("out_meap_0.5b_mask_0.15mask_ratio") / name 31 | #num_nodes=1 32 | # Hyperparameters 33 | num_of_devices = 8 34 | global_batch_size = 256 35 | learning_rate = 4e-4 36 | micro_batch_size = 8 37 | max_step = 10000 38 | warmup_steps = 1000 39 | log_step_interval = 10 40 | eval_iters = 100 41 | save_step_interval = 5000 42 | eval_step_interval = 500 43 | 44 | 45 | weight_decay = 5e-2 46 | beta1 = 0.9 47 | beta2 = 0.95 48 | grad_clip = 1.0 49 | decay_lr = True 50 | min_lr = 4e-5 51 | 52 | batch_size = global_batch_size // num_of_devices 53 | gradient_accumulation_steps = batch_size // micro_batch_size 54 | assert gradient_accumulation_steps > 0 55 | warmup_iters = warmup_steps * gradient_accumulation_steps 56 | 57 | checkpoint_dir = Path("../tokenizer") 58 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 59 | special_tokens = { 60 | "additional_special_tokens": ["[MASK]"], 61 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 62 | "eos_token": tokenizer.eos_token, 63 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 64 | } 65 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 66 | use_mask = True 67 | mask_ratio = 0.15 68 | max_iters = max_step * gradient_accumulation_steps 69 | lr_decay_iters = max_iters 70 | log_iter_interval = log_step_interval * gradient_accumulation_steps 71 | 72 | 73 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
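# For example (hypothetical prefixes, not used in this run): [("c4", 0.7), ("slimpajama", 0.3)] would mix the two corpora roughly 70/30; the weights are re-normalized to sum to 1 in create_dataloader below.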
74 | train_data_config = [ 75 | ("c4", 1.0), 76 | ] 77 | 78 | val_data_config = [ 79 | ("validation", 1.0), 80 | ] 81 | 82 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 83 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 84 | wandb_logger = WandbLogger(project="NTP-MASK", 85 | name="meap_0.5b_mask_10b_0.15mask_ratio_1219", 86 | log_model=True, 87 | save_dir=str(out_dir), 88 | config={ 89 | 90 | "model_name": model_name, 91 | "total_params": None, 92 | 93 | 94 | "global_batch_size": global_batch_size, 95 | "micro_batch_size": micro_batch_size, 96 | "gradient_accumulation_steps": gradient_accumulation_steps, 97 | "learning_rate": learning_rate, 98 | "min_lr": min_lr, 99 | "weight_decay": weight_decay, 100 | "warmup_steps": warmup_steps, 101 | "max_step": max_step, 102 | "beta1": beta1, 103 | "beta2": beta2, 104 | "grad_clip": grad_clip, 105 | 106 | 107 | "train_data_config": train_data_config, 108 | "val_data_config": val_data_config, 109 | 110 | 111 | "num_devices": num_of_devices, 112 | "precision": None, 113 | } 114 | 115 | ) 116 | 117 | 118 | 119 | def setup( 120 | devices: int = 8, 121 | train_data_dir: Path = Path("../c4_bin"), 122 | val_data_dir: Optional[Path] = None, 123 | precision: Optional[str] = None, 124 | tpu: bool = False, 125 | resume: Union[bool, Path] = False, 126 | ) -> None: 127 | print("devices: ", devices) 128 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 129 | 130 | if devices > 1: 131 | if tpu: 132 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 133 | devices = "auto" 134 | strategy = XLAStrategy(sync_module_states=False) 135 | else: 136 | strategy = FSDPStrategy( 137 | auto_wrap_policy={Block}, 138 | activation_checkpointing_policy=None, 139 | state_dict_type="full", 140 | limit_all_gathers=True, 141 | cpu_offload=False, 142 | ) 143 | else: 144 | strategy = "auto" 145 | 146 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 147 | fabric.print(hparams) 148 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 149 | main(fabric, train_data_dir, val_data_dir, resume) 150 | 151 | 152 | def main(fabric, train_data_dir, val_data_dir, resume): 153 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 154 | 155 | if fabric.global_rank == 0: 156 | out_dir.mkdir(parents=True, exist_ok=True) 157 | 158 | config = Config.from_name(model_name) 159 | fabric.print(f"train_data_dir: {train_data_dir}") 160 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 161 | fabric.print(f"val_data_dir: {val_data_dir}") 162 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 163 | #train_data_dir=Path("./output_data/slim_star_combined ") 164 | #val_data_dir=Path("./output_data/slim_star_combined ") 165 | fabric.print(f"train_data_dir: {train_data_dir}") 166 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 167 | fabric.print(f"val_data_dir: {val_data_dir}") 168 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 169 | train_dataloader, val_dataloader = create_dataloaders( 170 | batch_size=micro_batch_size, 171 | block_size=config.block_size, 172 | fabric=fabric, 173 | train_data_dir=train_data_dir, 174 | val_data_dir=val_data_dir, 175 | seed=3407, 176 | ) 177 | if val_dataloader is None: 178 | train_dataloader = 
fabric.setup_dataloaders(train_dataloader) 179 | else: 180 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 181 | 182 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 183 | 184 | fabric.print(f"Loading model with {config.__dict__}") 185 | t0 = time.perf_counter() 186 | with fabric.init_module(empty_init=False): 187 | model = GPT(config) 188 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 189 | 190 | total_params = num_parameters(model) 191 | #wandb_logger.experiment.config.update({"total_params": total_params}) 192 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 193 | fabric.print(f"Total parameters {num_parameters(model):,}") 194 | 195 | model = fabric.setup(model) 196 | optimizer = torch.optim.AdamW( 197 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 198 | ) 199 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 200 | optimizer = fabric.setup_optimizers(optimizer) 201 | 202 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 203 | 204 | if resume is True: 205 | resume = "" 206 | if resume : 207 | resume = "" 208 | fabric.print(f"Resuming training from {resume}") 209 | fabric.load(resume, state) 210 | 211 | train_time = time.perf_counter() 212 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 213 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 214 | if fabric.device.type == "cuda": 215 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 216 | 217 | def get_cycled_seed(curr_iter): 218 | cycle_range = 51 219 | cycled_value = 50 + (curr_iter % cycle_range) 220 | return cycled_value 221 | 222 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 223 | model = state["model"] 224 | optimizer = state["optimizer"] 225 | 226 | if val_dataloader is not None: 227 | validate(fabric, model, val_dataloader) # sanity check 228 | 229 | with torch.device("meta"): 230 | meta_model = GPT(model.config) 231 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 232 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 233 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 234 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 235 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 236 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 237 | # measured_flos run in meta. Will trigger fusedRMSNorm error 238 | #measured_flops = measure_flops(meta_model, x) 239 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 240 | del meta_model, x 241 | 242 | total_lengths = 0 243 | total_t0 = time.perf_counter() 244 | 245 | if fabric.device.type == "xla": 246 | import torch_xla.core.xla_model as xm 247 | 248 | xm.mark_step() 249 | 250 | 251 | initial_iter = state["iter_num"] 252 | curr_iter = 0 253 | 254 | 255 | loss_func = FusedCrossEntropyLoss() 256 | for train_data in train_dataloader: 257 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
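# In effect, the branch below fast-forwards the dataloader by skipping `initial_iter` already-seen batches, so a resumed run replays the same data order as the original one (seed 3407 plus the per-rank offset passed to PackedDataset).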
258 | if resume: 259 | if curr_iter < initial_iter: 260 | curr_iter += 1 261 | continue 262 | else: 263 | resume = False 264 | curr_iter = -1 265 | fabric.barrier() 266 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 267 | if state["iter_num"] >= max_iters: 268 | break 269 | 270 | # determine and set the learning rate for this iteration 271 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 272 | for param_group in optimizer.param_groups: 273 | param_group["lr"] = lr 274 | 275 | iter_t0 = time.perf_counter() 276 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 277 | if use_mask: 278 | random.seed(get_cycled_seed(state["step_count"])) 279 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 280 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 281 | 282 | bs, seq_len = input_ids.shape 283 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 284 | 285 | 286 | 287 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 288 | #print(tokenizer.convert_tokens_to_ids("[MASK]") in targets) 289 | #print(f"targets shape: {targets.shape}") 290 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 291 | with fabric.no_backward_sync(model, enabled=is_accumulating): 292 | logits = model(input_ids) 293 | loss = loss_func(logits, targets) 294 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 295 | fabric.backward(loss / gradient_accumulation_steps) 296 | 297 | if not is_accumulating: 298 | wandb_logger.log_metrics({ 299 | "train/loss": loss.item(), 300 | "train/learning_rate": lr, 301 | "train/step": state["step_count"], 302 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0, 303 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 304 | "performance/iter_time": (t1 - iter_t0) * 1000, 305 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 306 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 307 | }, step=state["step_count"]) 308 | 309 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 310 | optimizer.step() 311 | optimizer.zero_grad() 312 | state["step_count"] += 1 313 | elif fabric.device.type == "xla": 314 | xm.mark_step() 315 | state["iter_num"] += 1 316 | 317 | # input_id: B L 318 | total_lengths += input_ids.size(1) 319 | t1 = time.perf_counter() 320 | fabric.print( 321 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 322 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 323 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 324 | # print days as well 325 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 326 | ) 327 | 328 | monitor.on_train_batch_end( 329 | state["iter_num"] * micro_batch_size, 330 | t1 - total_t0, 331 | # this assumes that device FLOPs are the same and that all devices have the same batch size 332 | fabric.world_size, 333 | state["step_count"], 334 | flops_per_batch=estimated_flops, 335 | lengths=total_lengths, 336 | train_loss = loss.item() 337 | ) 338 | 339 | 340 | 341 | 342 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 343 | 344 | t0 = time.perf_counter() 345 | val_loss = validate(fabric, model, val_dataloader) 346 | t1 = time.perf_counter() - t0 347 | monitor.eval_end(t1) 348 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 349 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 350 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 351 | fabric.barrier() 352 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 353 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 354 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 355 | fabric.save(checkpoint_path, state) 356 | 357 | 358 | @torch.no_grad() 359 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 360 | fabric.print("Validating ...") 361 | model.eval() 362 | 363 | losses = torch.zeros(eval_iters, device=fabric.device) 364 | for k, val_data in enumerate(val_dataloader): 365 | if k >= eval_iters: 366 | break 367 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 368 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 369 | logits = model(input_ids) 370 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 371 | 372 | # loss_func = FusedCrossEntropyLoss() 373 | # loss = loss_func(logits, targets) 374 | losses[k] = loss.item() 375 | 376 | out = losses.mean() 377 | 378 | model.train() 379 | return out 380 | 381 | 382 | def create_dataloader( 383 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 384 | ) -> DataLoader: 385 | datasets = [] 386 | data_config = train_data_config if split == "train" else val_data_config 387 | 388 | for prefix, _ in data_config: 389 | #print(f"data_dir: {data_dir}") 390 | #print(f"prefix: {prefix}") 391 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 392 | #print(f"filenames: {filenames}") 393 | random.seed(seed) 394 | random.shuffle(filenames) 395 | #print(f"filenames after shuffle: {filenames}") 396 | dataset = PackedDataset( 397 | filenames, 398 | # n_chunks control the buffer size. 399 | # Note that the buffer size also impacts the random shuffle 400 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 401 | n_chunks=1, 402 | block_size=block_size, 403 | shuffle=shuffle, 404 | seed=seed+fabric.global_rank, 405 | num_processes=fabric.world_size, 406 | process_rank=fabric.global_rank, 407 | ) 408 | datasets.append(dataset) 409 | 410 | if not datasets: 411 | raise RuntimeError( 412 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
413 | ) 414 | 415 | weights = [weight for _, weight in data_config] 416 | 417 | sum_weights = sum(weights) 418 | 419 | weights = [el / sum_weights for el in weights] 420 | 421 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 422 | 423 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 424 | 425 | 426 | def create_dataloaders( 427 | batch_size: int, 428 | block_size: int, 429 | fabric, 430 | train_data_dir: Path = Path("data/redpajama_sample"), 431 | val_data_dir: Optional[Path] = None, 432 | seed: int = 12345, 433 | ) -> Tuple[DataLoader, DataLoader]: 434 | # Increase by one because we need the next word as well 435 | effective_block_size = block_size + 1 436 | train_dataloader = create_dataloader( 437 | batch_size=batch_size, 438 | block_size=effective_block_size, 439 | fabric=fabric, 440 | data_dir=train_data_dir, 441 | shuffle=True, 442 | seed=seed, 443 | split="train" 444 | ) 445 | val_dataloader = ( 446 | create_dataloader( 447 | batch_size=batch_size, 448 | block_size=effective_block_size, 449 | fabric=fabric, 450 | data_dir=val_data_dir, 451 | shuffle=False, 452 | seed=seed, 453 | split="validation" 454 | ) 455 | if val_data_dir 456 | else None 457 | ) 458 | return train_dataloader, val_dataloader 459 | 460 | 461 | # learning rate decay scheduler (cosine with warmup) 462 | def get_lr(it): 463 | # 1) linear warmup for warmup_iters steps 464 | if it < warmup_iters: 465 | return learning_rate * it / warmup_iters 466 | # 2) if it > lr_decay_iters, return min learning rate 467 | if it > lr_decay_iters: 468 | return min_lr 469 | # 3) in between, use cosine decay down to min learning rate 470 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 471 | assert 0 <= decay_ratio <= 1 472 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 473 | return min_lr + coeff * (learning_rate - min_lr) 474 | 475 | 476 | if __name__ == "__main__": 477 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 478 | # torch.backends.cuda.enable_flash_sdp(False) 479 | torch.set_float32_matmul_precision("high") 480 | 481 | from jsonargparse import CLI 482 | 483 | CLI(setup) 484 | 485 | -------------------------------------------------------------------------------- /MEAP-Pretain/pretrained/meap_0.5b_10b_0.05.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | 14 | from transformers import AutoTokenizer 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 19 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 20 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 21 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 22 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 23 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 24 | from pytorch_lightning.loggers import 
WandbLogger 25 | from lit_gpt import FusedCrossEntropyLoss 26 | import random 27 | 28 | model_name = "tiny_LLaMA_0.5b_mask" 29 | name = "meap_0.5b_mask" 30 | out_dir = Path("out_meap_0.5b_mask_0.05mask_ratio") / name 31 | #num_nodes=1 32 | # Hyperparameters 33 | num_of_devices = 8 34 | global_batch_size = 256 35 | learning_rate = 4e-4 36 | micro_batch_size = 8 37 | max_step = 10000 38 | warmup_steps = 1000 39 | log_step_interval = 10 40 | eval_iters = 100 41 | save_step_interval = 5000 42 | eval_step_interval = 500 43 | 44 | 45 | weight_decay = 5e-2 46 | beta1 = 0.9 47 | beta2 = 0.95 48 | grad_clip = 1.0 49 | decay_lr = True 50 | min_lr = 4e-5 51 | 52 | batch_size = global_batch_size // num_of_devices 53 | gradient_accumulation_steps = batch_size // micro_batch_size 54 | assert gradient_accumulation_steps > 0 55 | warmup_iters = warmup_steps * gradient_accumulation_steps 56 | 57 | checkpoint_dir = Path("../tokenizer") 58 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 59 | special_tokens = { 60 | "additional_special_tokens": ["[MASK]"], 61 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 62 | "eos_token": tokenizer.eos_token, 63 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 64 | } 65 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 66 | use_mask = True 67 | mask_ratio = 0.05 68 | max_iters = max_step * gradient_accumulation_steps 69 | lr_decay_iters = max_iters 70 | log_iter_interval = log_step_interval * gradient_accumulation_steps 71 | 72 | 73 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 74 | train_data_config = [ 75 | ("c4", 1.0), 76 | ] 77 | 78 | val_data_config = [ 79 | ("validation", 1.0), 80 | ] 81 | 82 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 83 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 84 | wandb_logger = WandbLogger(project="NTP-MASK", 85 | name="meap_0.5b_mask_10b_0.05mask_ratio_1220", 86 | log_model=True, 87 | save_dir=str(out_dir), 88 | config={ 89 | 90 | "model_name": model_name, 91 | "total_params": None, 92 | 93 | 94 | "global_batch_size": global_batch_size, 95 | "micro_batch_size": micro_batch_size, 96 | "gradient_accumulation_steps": gradient_accumulation_steps, 97 | "learning_rate": learning_rate, 98 | "min_lr": min_lr, 99 | "weight_decay": weight_decay, 100 | "warmup_steps": warmup_steps, 101 | "max_step": max_step, 102 | "beta1": beta1, 103 | "beta2": beta2, 104 | "grad_clip": grad_clip, 105 | 106 | 107 | "train_data_config": train_data_config, 108 | "val_data_config": val_data_config, 109 | 110 | 111 | "num_devices": num_of_devices, 112 | "precision": None, 113 | } 114 | 115 | ) 116 | 117 | 118 | 119 | def setup( 120 | devices: int = 8, 121 | train_data_dir: Path = Path("../c4_bin"), 122 | val_data_dir: Optional[Path] = None, 123 | precision: Optional[str] = None, 124 | tpu: bool = False, 125 | resume: Union[bool, Path] = False, 126 | ) -> None: 127 | print("devices: ", devices) 128 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 129 | 130 | if devices > 1: 131 | if tpu: 132 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
133 | devices = "auto" 134 | strategy = XLAStrategy(sync_module_states=False) 135 | else: 136 | strategy = FSDPStrategy( 137 | auto_wrap_policy={Block}, 138 | activation_checkpointing_policy=None, 139 | state_dict_type="full", 140 | limit_all_gathers=True, 141 | cpu_offload=False, 142 | ) 143 | else: 144 | strategy = "auto" 145 | 146 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 147 | fabric.print(hparams) 148 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 149 | main(fabric, train_data_dir, val_data_dir, resume) 150 | 151 | 152 | def main(fabric, train_data_dir, val_data_dir, resume): 153 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 154 | 155 | if fabric.global_rank == 0: 156 | out_dir.mkdir(parents=True, exist_ok=True) 157 | 158 | config = Config.from_name(model_name) 159 | fabric.print(f"train_data_dir: {train_data_dir}") 160 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 161 | fabric.print(f"val_data_dir: {val_data_dir}") 162 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 163 | #train_data_dir=Path("./output_data/slim_star_combined ") 164 | #val_data_dir=Path("./output_data/slim_star_combined ") 165 | fabric.print(f"train_data_dir: {train_data_dir}") 166 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 167 | fabric.print(f"val_data_dir: {val_data_dir}") 168 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 169 | train_dataloader, val_dataloader = create_dataloaders( 170 | batch_size=micro_batch_size, 171 | block_size=config.block_size, 172 | fabric=fabric, 173 | train_data_dir=train_data_dir, 174 | val_data_dir=val_data_dir, 175 | seed=3407, 176 | ) 177 | if val_dataloader is None: 178 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 179 | else: 180 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 181 | 182 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 183 | 184 | fabric.print(f"Loading model with {config.__dict__}") 185 | t0 = time.perf_counter() 186 | with fabric.init_module(empty_init=False): 187 | model = GPT(config) 188 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 189 | 190 | total_params = num_parameters(model) 191 | #wandb_logger.experiment.config.update({"total_params": total_params}) 192 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 193 | fabric.print(f"Total parameters {num_parameters(model):,}") 194 | 195 | model = fabric.setup(model) 196 | optimizer = torch.optim.AdamW( 197 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 198 | ) 199 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 200 | optimizer = fabric.setup_optimizers(optimizer) 201 | 202 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 203 | 204 | if resume is True: 205 | resume = sorted(out_dir.glob("*.pth"))[-1] 206 | if resume : 207 | fabric.print(f"Resuming training from {resume}") 208 | fabric.load(resume, state) 209 | 210 | train_time = time.perf_counter() 211 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 212 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 213 | if fabric.device.type == "cuda": 214 | fabric.print(f"Memory 
used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 215 | 216 | def get_cycled_seed(curr_iter): 217 | cycle_range = 51 218 | cycled_value = 50 + (curr_iter % cycle_range) 219 | return cycled_value 220 | 221 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 222 | model = state["model"] 223 | optimizer = state["optimizer"] 224 | 225 | if val_dataloader is not None: 226 | validate(fabric, model, val_dataloader) # sanity check 227 | 228 | with torch.device("meta"): 229 | meta_model = GPT(model.config) 230 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 231 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 232 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 233 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 234 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 235 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 236 | # measured_flos run in meta. Will trigger fusedRMSNorm error 237 | #measured_flops = measure_flops(meta_model, x) 238 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 239 | del meta_model, x 240 | 241 | total_lengths = 0 242 | total_t0 = time.perf_counter() 243 | 244 | if fabric.device.type == "xla": 245 | import torch_xla.core.xla_model as xm 246 | 247 | xm.mark_step() 248 | 249 | 250 | initial_iter = state["iter_num"] 251 | curr_iter = 0 252 | 253 | 254 | loss_func = FusedCrossEntropyLoss() 255 | for train_data in train_dataloader: 256 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 257 | if resume: 258 | if curr_iter < initial_iter: 259 | curr_iter += 1 260 | continue 261 | else: 262 | resume = False 263 | curr_iter = -1 264 | fabric.barrier() 265 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 266 | if state["iter_num"] >= max_iters: 267 | break 268 | 269 | # determine and set the learning rate for this iteration 270 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 271 | for param_group in optimizer.param_groups: 272 | param_group["lr"] = lr 273 | 274 | iter_t0 = time.perf_counter() 275 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 276 | if use_mask: 277 | random.seed(get_cycled_seed(state["step_count"])) 278 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 279 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 280 | 281 | bs, seq_len = input_ids.shape 282 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 283 | 284 | 285 | 286 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 287 | #print(tokenizer.convert_tokens_to_ids("[MASK]") in targets) 288 | #print(f"targets shape: {targets.shape}") 289 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 290 | with fabric.no_backward_sync(model, enabled=is_accumulating): 291 | logits = model(input_ids) 292 | loss = loss_func(logits, targets) 293 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 294 | fabric.backward(loss / gradient_accumulation_steps) 295 | 296 | if not is_accumulating: 297 | wandb_logger.log_metrics({ 298 | "train/loss": loss.item(), 299 | "train/learning_rate": lr, 300 | "train/step": state["step_count"], 301 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if 
torch.cuda.is_available() else 0, 302 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 303 | "performance/iter_time": (t1 - iter_t0) * 1000, 304 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 305 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 306 | }, step=state["step_count"]) 307 | 308 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 309 | optimizer.step() 310 | optimizer.zero_grad() 311 | state["step_count"] += 1 312 | elif fabric.device.type == "xla": 313 | xm.mark_step() 314 | state["iter_num"] += 1 315 | 316 | # input_id: B L 317 | total_lengths += input_ids.size(1) 318 | t1 = time.perf_counter() 319 | fabric.print( 320 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 321 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 322 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 323 | # print days as well 324 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 325 | ) 326 | 327 | monitor.on_train_batch_end( 328 | state["iter_num"] * micro_batch_size, 329 | t1 - total_t0, 330 | # this assumes that device FLOPs are the same and that all devices have the same batch size 331 | fabric.world_size, 332 | state["step_count"], 333 | flops_per_batch=estimated_flops, 334 | lengths=total_lengths, 335 | train_loss = loss.item() 336 | ) 337 | 338 | 339 | 340 | 341 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 342 | 343 | t0 = time.perf_counter() 344 | val_loss = validate(fabric, model, val_dataloader) 345 | t1 = time.perf_counter() - t0 346 | monitor.eval_end(t1) 347 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 348 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 349 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 350 | fabric.barrier() 351 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 352 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 353 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 354 | fabric.save(checkpoint_path, state) 355 | 356 | 357 | @torch.no_grad() 358 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 359 | fabric.print("Validating ...") 360 | model.eval() 361 | 362 | losses = torch.zeros(eval_iters, device=fabric.device) 363 | for k, val_data in enumerate(val_dataloader): 364 | if k >= eval_iters: 365 | break 366 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 367 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 368 | logits = model(input_ids) 369 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 370 | 371 | # loss_func = FusedCrossEntropyLoss() 372 | # loss = loss_func(logits, targets) 373 | losses[k] = loss.item() 374 | 375 | out = losses.mean() 376 | 377 | model.train() 378 | return out 379 | 380 | 381 | def create_dataloader( 382 | 
batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 383 | ) -> DataLoader: 384 | datasets = [] 385 | data_config = train_data_config if split == "train" else val_data_config 386 | 387 | for prefix, _ in data_config: 388 | #print(f"data_dir: {data_dir}") 389 | #print(f"prefix: {prefix}") 390 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 391 | #print(f"filenames: {filenames}") 392 | random.seed(seed) 393 | random.shuffle(filenames) 394 | #print(f"filenames after shuffle: {filenames}") 395 | dataset = PackedDataset( 396 | filenames, 397 | # n_chunks control the buffer size. 398 | # Note that the buffer size also impacts the random shuffle 399 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 400 | n_chunks=1, 401 | block_size=block_size, 402 | shuffle=shuffle, 403 | seed=seed+fabric.global_rank, 404 | num_processes=fabric.world_size, 405 | process_rank=fabric.global_rank, 406 | ) 407 | datasets.append(dataset) 408 | 409 | if not datasets: 410 | raise RuntimeError( 411 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 412 | ) 413 | 414 | weights = [weight for _, weight in data_config] 415 | 416 | sum_weights = sum(weights) 417 | 418 | weights = [el / sum_weights for el in weights] 419 | 420 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 421 | 422 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 423 | 424 | 425 | def create_dataloaders( 426 | batch_size: int, 427 | block_size: int, 428 | fabric, 429 | train_data_dir: Path = Path("data/redpajama_sample"), 430 | val_data_dir: Optional[Path] = None, 431 | seed: int = 12345, 432 | ) -> Tuple[DataLoader, DataLoader]: 433 | # Increase by one because we need the next word as well 434 | effective_block_size = block_size + 1 435 | train_dataloader = create_dataloader( 436 | batch_size=batch_size, 437 | block_size=effective_block_size, 438 | fabric=fabric, 439 | data_dir=train_data_dir, 440 | shuffle=True, 441 | seed=seed, 442 | split="train" 443 | ) 444 | val_dataloader = ( 445 | create_dataloader( 446 | batch_size=batch_size, 447 | block_size=effective_block_size, 448 | fabric=fabric, 449 | data_dir=val_data_dir, 450 | shuffle=False, 451 | seed=seed, 452 | split="validation" 453 | ) 454 | if val_data_dir 455 | else None 456 | ) 457 | return train_dataloader, val_dataloader 458 | 459 | 460 | # learning rate decay scheduler (cosine with warmup) 461 | def get_lr(it): 462 | # 1) linear warmup for warmup_iters steps 463 | if it < warmup_iters: 464 | return learning_rate * it / warmup_iters 465 | # 2) if it > lr_decay_iters, return min learning rate 466 | if it > lr_decay_iters: 467 | return min_lr 468 | # 3) in between, use cosine decay down to min learning rate 469 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 470 | assert 0 <= decay_ratio <= 1 471 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 472 | return min_lr + coeff * (learning_rate - min_lr) 473 | 474 | 475 | if __name__ == "__main__": 476 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 477 | # torch.backends.cuda.enable_flash_sdp(False) 478 | torch.set_float32_matmul_precision("high") 479 | 480 | from jsonargparse import CLI 481 | 482 | CLI(setup) 483 | 484 | -------------------------------------------------------------------------------- 
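Note on the masking step in the script above (and its siblings that follow, which differ only in mask_ratio and output/W&B naming): inside train(), each optimizer step reseeds Python's random module via get_cycled_seed, picks a fixed set of column positions, and overwrites those columns of input_ids with the added "[MASK]" token, while targets stay the ordinary one-token shift of the raw batch. The loss therefore remains plain next-token cross-entropy; only the input side ever sees the mask. The snippet below is a minimal, self-contained sketch of just that step on a toy batch. The toy block_size, vocabulary range, and MASK_ID are illustrative stand-ins, not values from the repo: the real scripts take block_size from the model config and look the id up with tokenizer.convert_tokens_to_ids("[MASK]").

import random
import torch

# Toy stand-ins -- the real scripts use config.block_size, the file's mask_ratio
# (0.05 in meap_0.5b_10b_0.05.py), and the tokenizer's id for "[MASK]".
block_size = 16
mask_ratio = 0.05
MASK_ID = 32000  # hypothetical id for the added "[MASK]" special token

# The dataloader yields block_size + 1 tokens so inputs and targets can be shifted by one.
train_data = torch.randint(0, 1000, (2, block_size + 1))

input_ids = train_data[:, :block_size].clone()
targets = train_data[:, 1 : block_size + 1].clone()

# Same logic as the training loop: at least one position is masked, and the same
# columns are masked for every sequence in the micro-batch.
num_masks = max(1, int(block_size * mask_ratio))
mask_positions = random.sample(range(block_size - 1), num_masks)
input_ids[:, mask_positions] = MASK_ID

# targets are untouched: the model still predicts the true next token everywhere,
# including the token that follows each masked position.
print(mask_positions)
print(input_ids[0].tolist())
print(targets[0].tolist())

Because targets are sliced directly from the raw batch, the "[MASK]" id never appears on the label side, which is exactly what the commented-out sanity print in the training loop checks.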
/MEAP-Pretain/pretrained/meap_0.5b_10b_0.1.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import math 4 | import sys 5 | import time 6 | from pathlib import Path 7 | from typing import Optional, Tuple, Union 8 | import math 9 | import lightning as L 10 | import torch 11 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 12 | from torch.utils.data import DataLoader 13 | from functools import partial 14 | 15 | from transformers import AutoTokenizer 16 | # support running without installing as a package 17 | wd = Path(__file__).parent.parent.resolve() 18 | sys.path.append(str(wd)) 19 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 20 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 21 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 22 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 23 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 24 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 25 | from pytorch_lightning.loggers import WandbLogger 26 | from lit_gpt import FusedCrossEntropyLoss 27 | import random 28 | 29 | model_name = "tiny_LLaMA_0.5b_mask" 30 | name = "meap_0.5b_mask" 31 | out_dir = Path("out_meap_0.5b_mask_0.1mask_ratio") / name 32 | #num_nodes=1 33 | # Hyperparameters 34 | num_of_devices = 8 35 | global_batch_size = 256 36 | learning_rate = 4e-4 37 | micro_batch_size = 8 38 | max_step = 10000 39 | warmup_steps = 1000 40 | log_step_interval = 10 41 | eval_iters = 100 42 | save_step_interval = 5000 43 | eval_step_interval = 500 44 | 45 | 46 | weight_decay = 5e-2 47 | beta1 = 0.9 48 | beta2 = 0.95 49 | grad_clip = 1.0 50 | decay_lr = True 51 | min_lr = 4e-5 52 | 53 | batch_size = global_batch_size // num_of_devices 54 | gradient_accumulation_steps = batch_size // micro_batch_size 55 | assert gradient_accumulation_steps > 0 56 | warmup_iters = warmup_steps * gradient_accumulation_steps 57 | 58 | checkpoint_dir = Path("../tokenizer") 59 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 60 | special_tokens = { 61 | "additional_special_tokens": ["[MASK]"], 62 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 63 | "eos_token": tokenizer.eos_token, 64 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 65 | } 66 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 67 | use_mask = True 68 | mask_ratio = 0.1 69 | max_iters = max_step * gradient_accumulation_steps 70 | lr_decay_iters = max_iters 71 | log_iter_interval = log_step_interval * gradient_accumulation_steps 72 | 73 | 74 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
75 | train_data_config = [ 76 | ("c4", 1.0), 77 | ] 78 | 79 | val_data_config = [ 80 | ("validation", 1.0), 81 | ] 82 | 83 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 84 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 85 | wandb_logger = WandbLogger(project="NTP-MASK", 86 | name="meap_0.5b_mask_10b_0.1mask_ratio_1219", 87 | log_model=True, 88 | save_dir=str(out_dir), 89 | config={ 90 | 91 | "model_name": model_name, 92 | "total_params": None, 93 | 94 | 95 | "global_batch_size": global_batch_size, 96 | "micro_batch_size": micro_batch_size, 97 | "gradient_accumulation_steps": gradient_accumulation_steps, 98 | "learning_rate": learning_rate, 99 | "min_lr": min_lr, 100 | "weight_decay": weight_decay, 101 | "warmup_steps": warmup_steps, 102 | "max_step": max_step, 103 | "beta1": beta1, 104 | "beta2": beta2, 105 | "grad_clip": grad_clip, 106 | 107 | 108 | "train_data_config": train_data_config, 109 | "val_data_config": val_data_config, 110 | 111 | 112 | "num_devices": num_of_devices, 113 | "precision": None, 114 | } 115 | 116 | ) 117 | 118 | 119 | 120 | def setup( 121 | devices: int = 8, 122 | train_data_dir: Path = Path("../c4_bin"), 123 | val_data_dir: Optional[Path] = None, 124 | precision: Optional[str] = None, 125 | tpu: bool = False, 126 | resume: Union[bool, Path] = False, 127 | ) -> None: 128 | print("devices: ", devices) 129 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 130 | 131 | if devices > 1: 132 | if tpu: 133 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 134 | devices = "auto" 135 | strategy = XLAStrategy(sync_module_states=False) 136 | else: 137 | strategy = FSDPStrategy( 138 | auto_wrap_policy={Block}, 139 | activation_checkpointing_policy=None, 140 | state_dict_type="full", 141 | limit_all_gathers=True, 142 | cpu_offload=False, 143 | ) 144 | else: 145 | strategy = "auto" 146 | 147 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 148 | fabric.print(hparams) 149 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 150 | main(fabric, train_data_dir, val_data_dir, resume) 151 | 152 | 153 | def main(fabric, train_data_dir, val_data_dir, resume): 154 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 155 | 156 | if fabric.global_rank == 0: 157 | out_dir.mkdir(parents=True, exist_ok=True) 158 | 159 | config = Config.from_name(model_name) 160 | fabric.print(f"train_data_dir: {train_data_dir}") 161 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 162 | fabric.print(f"val_data_dir: {val_data_dir}") 163 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 164 | #train_data_dir=Path("./output_data/slim_star_combined ") 165 | #val_data_dir=Path("./output_data/slim_star_combined ") 166 | fabric.print(f"train_data_dir: {train_data_dir}") 167 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 168 | fabric.print(f"val_data_dir: {val_data_dir}") 169 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 170 | train_dataloader, val_dataloader = create_dataloaders( 171 | batch_size=micro_batch_size, 172 | block_size=config.block_size, 173 | fabric=fabric, 174 | train_data_dir=train_data_dir, 175 | val_data_dir=val_data_dir, 176 | seed=3407, 177 | ) 178 | if val_dataloader is None: 179 | train_dataloader = 
fabric.setup_dataloaders(train_dataloader) 180 | else: 181 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 182 | 183 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 184 | 185 | fabric.print(f"Loading model with {config.__dict__}") 186 | t0 = time.perf_counter() 187 | with fabric.init_module(empty_init=False): 188 | model = GPT(config) 189 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 190 | 191 | total_params = num_parameters(model) 192 | #wandb_logger.experiment.config.update({"total_params": total_params}) 193 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 194 | fabric.print(f"Total parameters {num_parameters(model):,}") 195 | 196 | model = fabric.setup(model) 197 | optimizer = torch.optim.AdamW( 198 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 199 | ) 200 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 201 | optimizer = fabric.setup_optimizers(optimizer) 202 | 203 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 204 | 205 | if resume is True: 206 | resume = sorted(out_dir.glob("*.pth"))[-1] 207 | if resume : 208 | fabric.print(f"Resuming training from {resume}") 209 | fabric.load(resume, state) 210 | 211 | train_time = time.perf_counter() 212 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 213 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 214 | if fabric.device.type == "cuda": 215 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 216 | 217 | def get_cycled_seed(curr_iter): 218 | cycle_range = 51 219 | cycled_value = 50 + (curr_iter % cycle_range) 220 | return cycled_value 221 | 222 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 223 | model = state["model"] 224 | optimizer = state["optimizer"] 225 | 226 | if val_dataloader is not None: 227 | validate(fabric, model, val_dataloader) # sanity check 228 | 229 | with torch.device("meta"): 230 | meta_model = GPT(model.config) 231 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 232 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 233 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 234 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 235 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 236 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 237 | # measured_flos run in meta. Will trigger fusedRMSNorm error 238 | #measured_flops = measure_flops(meta_model, x) 239 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 240 | del meta_model, x 241 | 242 | total_lengths = 0 243 | total_t0 = time.perf_counter() 244 | 245 | if fabric.device.type == "xla": 246 | import torch_xla.core.xla_model as xm 247 | 248 | xm.mark_step() 249 | 250 | 251 | initial_iter = state["iter_num"] 252 | curr_iter = 0 253 | 254 | 255 | loss_func = FusedCrossEntropyLoss() 256 | for train_data in train_dataloader: 257 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
258 | if resume: 259 | if curr_iter < initial_iter: 260 | curr_iter += 1 261 | continue 262 | else: 263 | resume = False 264 | curr_iter = -1 265 | fabric.barrier() 266 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 267 | if state["iter_num"] >= max_iters: 268 | break 269 | 270 | # determine and set the learning rate for this iteration 271 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 272 | for param_group in optimizer.param_groups: 273 | param_group["lr"] = lr 274 | 275 | iter_t0 = time.perf_counter() 276 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 277 | if use_mask: 278 | random.seed(get_cycled_seed(state["step_count"])) 279 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 280 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 281 | 282 | bs, seq_len = input_ids.shape 283 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 284 | 285 | 286 | 287 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 288 | #print(tokenizer.convert_tokens_to_ids("[MASK]") in targets) 289 | #print(f"targets shape: {targets.shape}") 290 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 291 | with fabric.no_backward_sync(model, enabled=is_accumulating): 292 | logits = model(input_ids) 293 | loss = loss_func(logits, targets) 294 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 295 | fabric.backward(loss / gradient_accumulation_steps) 296 | 297 | if not is_accumulating: 298 | wandb_logger.log_metrics({ 299 | "train/loss": loss.item(), 300 | "train/learning_rate": lr, 301 | "train/step": state["step_count"], 302 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0, 303 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 304 | "performance/iter_time": (t1 - iter_t0) * 1000, 305 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 306 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 307 | }, step=state["step_count"]) 308 | 309 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 310 | optimizer.step() 311 | optimizer.zero_grad() 312 | state["step_count"] += 1 313 | elif fabric.device.type == "xla": 314 | xm.mark_step() 315 | state["iter_num"] += 1 316 | 317 | # input_id: B L 318 | total_lengths += input_ids.size(1) 319 | t1 = time.perf_counter() 320 | fabric.print( 321 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 322 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 323 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 324 | # print days as well 325 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 326 | ) 327 | 328 | monitor.on_train_batch_end( 329 | state["iter_num"] * micro_batch_size, 330 | t1 - total_t0, 331 | # this assumes that device FLOPs are the same and that all devices have the same batch size 332 | fabric.world_size, 333 | state["step_count"], 334 | flops_per_batch=estimated_flops, 335 | lengths=total_lengths, 336 | train_loss = loss.item() 337 | ) 338 | 339 | 340 | 341 | 342 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 343 | 344 | t0 = time.perf_counter() 345 | val_loss = validate(fabric, model, val_dataloader) 346 | t1 = time.perf_counter() - t0 347 | monitor.eval_end(t1) 348 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 349 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 350 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 351 | fabric.barrier() 352 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 353 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 354 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 355 | fabric.save(checkpoint_path, state) 356 | 357 | 358 | @torch.no_grad() 359 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 360 | fabric.print("Validating ...") 361 | model.eval() 362 | 363 | losses = torch.zeros(eval_iters, device=fabric.device) 364 | for k, val_data in enumerate(val_dataloader): 365 | if k >= eval_iters: 366 | break 367 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 368 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 369 | logits = model(input_ids) 370 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 371 | 372 | # loss_func = FusedCrossEntropyLoss() 373 | # loss = loss_func(logits, targets) 374 | losses[k] = loss.item() 375 | 376 | out = losses.mean() 377 | 378 | model.train() 379 | return out 380 | 381 | 382 | def create_dataloader( 383 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 384 | ) -> DataLoader: 385 | datasets = [] 386 | data_config = train_data_config if split == "train" else val_data_config 387 | 388 | for prefix, _ in data_config: 389 | #print(f"data_dir: {data_dir}") 390 | #print(f"prefix: {prefix}") 391 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 392 | #print(f"filenames: {filenames}") 393 | random.seed(seed) 394 | random.shuffle(filenames) 395 | #print(f"filenames after shuffle: {filenames}") 396 | dataset = PackedDataset( 397 | filenames, 398 | # n_chunks control the buffer size. 399 | # Note that the buffer size also impacts the random shuffle 400 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 401 | n_chunks=8, 402 | block_size=block_size, 403 | shuffle=shuffle, 404 | seed=seed+fabric.global_rank, 405 | num_processes=fabric.world_size, 406 | process_rank=fabric.global_rank, 407 | ) 408 | datasets.append(dataset) 409 | 410 | if not datasets: 411 | raise RuntimeError( 412 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
413 | ) 414 | 415 | weights = [weight for _, weight in data_config] 416 | 417 | sum_weights = sum(weights) 418 | 419 | weights = [el / sum_weights for el in weights] 420 | 421 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 422 | 423 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 424 | 425 | 426 | def create_dataloaders( 427 | batch_size: int, 428 | block_size: int, 429 | fabric, 430 | train_data_dir: Path = Path("data/redpajama_sample"), 431 | val_data_dir: Optional[Path] = None, 432 | seed: int = 12345, 433 | ) -> Tuple[DataLoader, DataLoader]: 434 | # Increase by one because we need the next word as well 435 | effective_block_size = block_size + 1 436 | train_dataloader = create_dataloader( 437 | batch_size=batch_size, 438 | block_size=effective_block_size, 439 | fabric=fabric, 440 | data_dir=train_data_dir, 441 | shuffle=True, 442 | seed=seed, 443 | split="train" 444 | ) 445 | val_dataloader = ( 446 | create_dataloader( 447 | batch_size=batch_size, 448 | block_size=effective_block_size, 449 | fabric=fabric, 450 | data_dir=val_data_dir, 451 | shuffle=False, 452 | seed=seed, 453 | split="validation" 454 | ) 455 | if val_data_dir 456 | else None 457 | ) 458 | return train_dataloader, val_dataloader 459 | 460 | 461 | # learning rate decay scheduler (cosine with warmup) 462 | def get_lr(it): 463 | # 1) linear warmup for warmup_iters steps 464 | if it < warmup_iters: 465 | return learning_rate * it / warmup_iters 466 | # 2) if it > lr_decay_iters, return min learning rate 467 | if it > lr_decay_iters: 468 | return min_lr 469 | # 3) in between, use cosine decay down to min learning rate 470 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 471 | assert 0 <= decay_ratio <= 1 472 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 473 | return min_lr + coeff * (learning_rate - min_lr) 474 | 475 | 476 | if __name__ == "__main__": 477 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 478 | # torch.backends.cuda.enable_flash_sdp(False) 479 | torch.set_float32_matmul_precision("high") 480 | 481 | from jsonargparse import CLI 482 | 483 | CLI(setup) 484 | -------------------------------------------------------------------------------- /MEAP-Pretain/pretrained/meap_0.5b_10b_0.2.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import math 4 | import sys 5 | import time 6 | from pathlib import Path 7 | from typing import Optional, Tuple, Union 8 | import math 9 | import lightning as L 10 | import torch 11 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 12 | from torch.utils.data import DataLoader 13 | from functools import partial 14 | 15 | from transformers import AutoTokenizer 16 | # support running without installing as a package 17 | wd = Path(__file__).parent.parent.resolve() 18 | sys.path.append(str(wd)) 19 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 20 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 21 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 22 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 23 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 24 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 25 | from pytorch_lightning.loggers import 
WandbLogger 26 | from lit_gpt import FusedCrossEntropyLoss 27 | import random 28 | 29 | model_name = "tiny_LLaMA_0.5b_mask" 30 | name = "meap_0.5b_mask" 31 | out_dir = Path("out_meap_0.5b_mask_0.2mask_ratio") / name 32 | #num_nodes=1 33 | # Hyperparameters 34 | num_of_devices = 8 35 | global_batch_size = 256 36 | learning_rate = 4e-4 37 | micro_batch_size = 8 38 | max_step = 10000 39 | warmup_steps = 1000 40 | log_step_interval = 10 41 | eval_iters = 100 42 | save_step_interval = 5000 43 | eval_step_interval = 500 44 | 45 | 46 | weight_decay = 5e-2 47 | beta1 = 0.9 48 | beta2 = 0.95 49 | grad_clip = 1.0 50 | decay_lr = True 51 | min_lr = 4e-5 52 | 53 | batch_size = global_batch_size // num_of_devices 54 | gradient_accumulation_steps = batch_size // micro_batch_size 55 | assert gradient_accumulation_steps > 0 56 | warmup_iters = warmup_steps * gradient_accumulation_steps 57 | 58 | checkpoint_dir = Path("../tokenizer") 59 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 60 | special_tokens = { 61 | "additional_special_tokens": ["[MASK]"], 62 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 63 | "eos_token": tokenizer.eos_token, 64 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 65 | } 66 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 67 | use_mask = True 68 | mask_ratio = 0.2 69 | max_iters = max_step * gradient_accumulation_steps 70 | lr_decay_iters = max_iters 71 | log_iter_interval = log_step_interval * gradient_accumulation_steps 72 | 73 | 74 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 75 | train_data_config = [ 76 | ("c4", 1.0), 77 | ] 78 | 79 | val_data_config = [ 80 | ("validation", 1.0), 81 | ] 82 | 83 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 84 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 85 | wandb_logger = WandbLogger(project="NTP-MASK", 86 | name="meap_0.5b_mask_10b_0.2mask_ratio_0116", 87 | log_model=True, 88 | save_dir=str(out_dir), 89 | config={ 90 | 91 | "model_name": model_name, 92 | "total_params": None, 93 | 94 | 95 | "global_batch_size": global_batch_size, 96 | "micro_batch_size": micro_batch_size, 97 | "gradient_accumulation_steps": gradient_accumulation_steps, 98 | "learning_rate": learning_rate, 99 | "min_lr": min_lr, 100 | "weight_decay": weight_decay, 101 | "warmup_steps": warmup_steps, 102 | "max_step": max_step, 103 | "beta1": beta1, 104 | "beta2": beta2, 105 | "grad_clip": grad_clip, 106 | 107 | 108 | "train_data_config": train_data_config, 109 | "val_data_config": val_data_config, 110 | 111 | 112 | "num_devices": num_of_devices, 113 | "precision": None, 114 | } 115 | 116 | ) 117 | 118 | 119 | 120 | def setup( 121 | devices: int = 8, 122 | train_data_dir: Path = Path("../c4_bin"), 123 | val_data_dir: Optional[Path] = None, 124 | precision: Optional[str] = None, 125 | tpu: bool = False, 126 | resume: Union[bool, Path] = False, 127 | ) -> None: 128 | print("devices: ", devices) 129 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 130 | 131 | if devices > 1: 132 | if tpu: 133 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 
134 | devices = "auto" 135 | strategy = XLAStrategy(sync_module_states=False) 136 | else: 137 | strategy = FSDPStrategy( 138 | auto_wrap_policy={Block}, 139 | activation_checkpointing_policy=None, 140 | state_dict_type="full", 141 | limit_all_gathers=True, 142 | cpu_offload=False, 143 | ) 144 | else: 145 | strategy = "auto" 146 | 147 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 148 | fabric.print(hparams) 149 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 150 | main(fabric, train_data_dir, val_data_dir, resume) 151 | 152 | 153 | def main(fabric, train_data_dir, val_data_dir, resume): 154 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 155 | 156 | if fabric.global_rank == 0: 157 | out_dir.mkdir(parents=True, exist_ok=True) 158 | 159 | config = Config.from_name(model_name) 160 | fabric.print(f"train_data_dir: {train_data_dir}") 161 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 162 | fabric.print(f"val_data_dir: {val_data_dir}") 163 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 164 | #train_data_dir=Path("./output_data/slim_star_combined ") 165 | #val_data_dir=Path("./output_data/slim_star_combined ") 166 | fabric.print(f"train_data_dir: {train_data_dir}") 167 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 168 | fabric.print(f"val_data_dir: {val_data_dir}") 169 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 170 | train_dataloader, val_dataloader = create_dataloaders( 171 | batch_size=micro_batch_size, 172 | block_size=config.block_size, 173 | fabric=fabric, 174 | train_data_dir=train_data_dir, 175 | val_data_dir=val_data_dir, 176 | seed=3407, 177 | ) 178 | if val_dataloader is None: 179 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 180 | else: 181 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 182 | 183 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 184 | 185 | fabric.print(f"Loading model with {config.__dict__}") 186 | t0 = time.perf_counter() 187 | with fabric.init_module(empty_init=False): 188 | model = GPT(config) 189 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 190 | 191 | total_params = num_parameters(model) 192 | #wandb_logger.experiment.config.update({"total_params": total_params}) 193 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 194 | fabric.print(f"Total parameters {num_parameters(model):,}") 195 | 196 | model = fabric.setup(model) 197 | optimizer = torch.optim.AdamW( 198 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 199 | ) 200 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 201 | optimizer = fabric.setup_optimizers(optimizer) 202 | 203 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 204 | 205 | if resume is True: 206 | resume = sorted(out_dir.glob("*.pth"))[-1] 207 | if resume : 208 | fabric.print(f"Resuming training from {resume}") 209 | fabric.load(resume, state) 210 | 211 | train_time = time.perf_counter() 212 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 213 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 214 | if fabric.device.type == "cuda": 215 | fabric.print(f"Memory 
used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 216 | 217 | def get_cycled_seed(curr_iter): 218 | cycle_range = 51 219 | cycled_value = 50 + (curr_iter % cycle_range) 220 | return cycled_value 221 | 222 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 223 | model = state["model"] 224 | optimizer = state["optimizer"] 225 | 226 | if val_dataloader is not None: 227 | validate(fabric, model, val_dataloader) # sanity check 228 | 229 | with torch.device("meta"): 230 | meta_model = GPT(model.config) 231 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 232 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 233 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 234 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 235 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 236 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 237 | # measured_flos run in meta. Will trigger fusedRMSNorm error 238 | #measured_flops = measure_flops(meta_model, x) 239 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 240 | del meta_model, x 241 | 242 | total_lengths = 0 243 | total_t0 = time.perf_counter() 244 | 245 | if fabric.device.type == "xla": 246 | import torch_xla.core.xla_model as xm 247 | 248 | xm.mark_step() 249 | 250 | 251 | initial_iter = state["iter_num"] 252 | curr_iter = 0 253 | 254 | 255 | loss_func = FusedCrossEntropyLoss() 256 | for train_data in train_dataloader: 257 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 258 | if resume: 259 | if curr_iter < initial_iter: 260 | curr_iter += 1 261 | continue 262 | else: 263 | resume = False 264 | curr_iter = -1 265 | fabric.barrier() 266 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 267 | if state["iter_num"] >= max_iters: 268 | break 269 | 270 | # determine and set the learning rate for this iteration 271 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 272 | for param_group in optimizer.param_groups: 273 | param_group["lr"] = lr 274 | 275 | iter_t0 = time.perf_counter() 276 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 277 | if use_mask: 278 | random.seed(get_cycled_seed(state["step_count"])) 279 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 280 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 281 | 282 | bs, seq_len = input_ids.shape 283 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 284 | 285 | 286 | 287 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 288 | #print(tokenizer.convert_tokens_to_ids("[MASK]") in targets) 289 | #print(f"targets shape: {targets.shape}") 290 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 291 | with fabric.no_backward_sync(model, enabled=is_accumulating): 292 | logits = model(input_ids) 293 | loss = loss_func(logits, targets) 294 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 295 | fabric.backward(loss / gradient_accumulation_steps) 296 | 297 | if not is_accumulating: 298 | wandb_logger.log_metrics({ 299 | "train/loss": loss.item(), 300 | "train/learning_rate": lr, 301 | "train/step": state["step_count"], 302 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if 
torch.cuda.is_available() else 0, 303 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 304 | "performance/iter_time": (t1 - iter_t0) * 1000, 305 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 306 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 307 | }, step=state["step_count"]) 308 | 309 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 310 | optimizer.step() 311 | optimizer.zero_grad() 312 | state["step_count"] += 1 313 | elif fabric.device.type == "xla": 314 | xm.mark_step() 315 | state["iter_num"] += 1 316 | 317 | # input_id: B L 318 | total_lengths += input_ids.size(1) 319 | t1 = time.perf_counter() 320 | fabric.print( 321 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 322 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 323 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 324 | # print days as well 325 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. " 326 | ) 327 | 328 | monitor.on_train_batch_end( 329 | state["iter_num"] * micro_batch_size, 330 | t1 - total_t0, 331 | # this assumes that device FLOPs are the same and that all devices have the same batch size 332 | fabric.world_size, 333 | state["step_count"], 334 | flops_per_batch=estimated_flops, 335 | lengths=total_lengths, 336 | train_loss = loss.item() 337 | ) 338 | 339 | 340 | 341 | 342 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 343 | 344 | t0 = time.perf_counter() 345 | val_loss = validate(fabric, model, val_dataloader) 346 | t1 = time.perf_counter() - t0 347 | monitor.eval_end(t1) 348 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 349 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 350 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 351 | fabric.barrier() 352 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 353 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 354 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 355 | fabric.save(checkpoint_path, state) 356 | 357 | 358 | @torch.no_grad() 359 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 360 | fabric.print("Validating ...") 361 | model.eval() 362 | 363 | losses = torch.zeros(eval_iters, device=fabric.device) 364 | for k, val_data in enumerate(val_dataloader): 365 | if k >= eval_iters: 366 | break 367 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 368 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 369 | logits = model(input_ids) 370 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 371 | 372 | # loss_func = FusedCrossEntropyLoss() 373 | # loss = loss_func(logits, targets) 374 | losses[k] = loss.item() 375 | 376 | out = losses.mean() 377 | 378 | model.train() 379 | return out 380 | 381 | 382 | def create_dataloader( 383 | 
batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 384 | ) -> DataLoader: 385 | datasets = [] 386 | data_config = train_data_config if split == "train" else val_data_config 387 | 388 | for prefix, _ in data_config: 389 | #print(f"data_dir: {data_dir}") 390 | #print(f"prefix: {prefix}") 391 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 392 | #print(f"filenames: {filenames}") 393 | random.seed(seed) 394 | random.shuffle(filenames) 395 | #print(f"filenames after shuffle: {filenames}") 396 | dataset = PackedDataset( 397 | filenames, 398 | # n_chunks control the buffer size. 399 | # Note that the buffer size also impacts the random shuffle 400 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 401 | n_chunks=1, 402 | block_size=block_size, 403 | shuffle=shuffle, 404 | seed=seed+fabric.global_rank, 405 | num_processes=fabric.world_size, 406 | process_rank=fabric.global_rank, 407 | ) 408 | datasets.append(dataset) 409 | 410 | if not datasets: 411 | raise RuntimeError( 412 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 413 | ) 414 | 415 | weights = [weight for _, weight in data_config] 416 | 417 | sum_weights = sum(weights) 418 | 419 | weights = [el / sum_weights for el in weights] 420 | 421 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 422 | 423 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 424 | 425 | 426 | def create_dataloaders( 427 | batch_size: int, 428 | block_size: int, 429 | fabric, 430 | train_data_dir: Path = Path("data/redpajama_sample"), 431 | val_data_dir: Optional[Path] = None, 432 | seed: int = 12345, 433 | ) -> Tuple[DataLoader, DataLoader]: 434 | # Increase by one because we need the next word as well 435 | effective_block_size = block_size + 1 436 | train_dataloader = create_dataloader( 437 | batch_size=batch_size, 438 | block_size=effective_block_size, 439 | fabric=fabric, 440 | data_dir=train_data_dir, 441 | shuffle=True, 442 | seed=seed, 443 | split="train" 444 | ) 445 | val_dataloader = ( 446 | create_dataloader( 447 | batch_size=batch_size, 448 | block_size=effective_block_size, 449 | fabric=fabric, 450 | data_dir=val_data_dir, 451 | shuffle=False, 452 | seed=seed, 453 | split="validation" 454 | ) 455 | if val_data_dir 456 | else None 457 | ) 458 | return train_dataloader, val_dataloader 459 | 460 | 461 | # learning rate decay scheduler (cosine with warmup) 462 | def get_lr(it): 463 | # 1) linear warmup for warmup_iters steps 464 | if it < warmup_iters: 465 | return learning_rate * it / warmup_iters 466 | # 2) if it > lr_decay_iters, return min learning rate 467 | if it > lr_decay_iters: 468 | return min_lr 469 | # 3) in between, use cosine decay down to min learning rate 470 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 471 | assert 0 <= decay_ratio <= 1 472 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 473 | return min_lr + coeff * (learning_rate - min_lr) 474 | 475 | 476 | if __name__ == "__main__": 477 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 478 | # torch.backends.cuda.enable_flash_sdp(False) 479 | torch.set_float32_matmul_precision("high") 480 | 481 | from jsonargparse import CLI 482 | 483 | CLI(setup) 484 | -------------------------------------------------------------------------------- 
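All of the pretrained/*.py variants share the masking logic shown above: clone the token window, overwrite a sampled set of positions with the [MASK] id, and leave the next-token targets untouched. The following is a minimal, self-contained sketch of that step under placeholder assumptions (mask_token_id, the block size, and the random batch are stand-ins; the real scripts take the mask id from the bundled tokenizer via tokenizer.convert_tokens_to_ids("[MASK]")).

```python
# Minimal sketch of the MEAP masking step, assuming placeholder values for
# mask_token_id, block_size and the batch; the real scripts use the bundled
# tokenizer's [MASK] id via tokenizer.convert_tokens_to_ids("[MASK]").
import random

import torch


def meap_mask_batch(train_data: torch.Tensor, block_size: int, mask_ratio: float,
                    mask_token_id: int, seed: int):
    """Clone the input window, mask a fraction of positions, keep clean NTP targets."""
    random.seed(seed)  # the scripts cycle this seed via get_cycled_seed(step_count)
    input_ids = train_data[:, 0:block_size].clone().contiguous()
    num_masks = max(1, int(block_size * mask_ratio))
    # sample distinct positions; the final position is excluded, as in the scripts
    mask_positions = random.sample(range(0, block_size - 1), num_masks)
    input_ids[:, mask_positions] = mask_token_id  # same positions for every sequence
    targets = train_data[:, 1:block_size + 1].contiguous()  # targets stay unmasked
    return input_ids, targets


if __name__ == "__main__":
    block_size, mask_token_id = 16, 32000                # placeholder assumptions
    batch = torch.randint(0, 100, (2, block_size + 1))   # stand-in for a PackedDataset chunk
    x, y = meap_mask_batch(batch, block_size, mask_ratio=0.15,
                           mask_token_id=mask_token_id, seed=50)
    print(x.shape, y.shape)
```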
/MEAP-Pretain/pretrained/meap_1b.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | import sys 4 | import time 5 | from pathlib import Path 6 | from typing import Optional, Tuple, Union 7 | import math 8 | import lightning as L 9 | import torch 10 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 11 | from torch.utils.data import DataLoader 12 | from functools import partial 13 | 14 | from transformers import AutoTokenizer 15 | # support running without installing as a package 16 | wd = Path(__file__).parent.parent.resolve() 17 | sys.path.append(str(wd)) 18 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 19 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 20 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 21 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 22 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 23 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 24 | from pytorch_lightning.loggers import WandbLogger 25 | from lit_gpt import FusedCrossEntropyLoss 26 | import random 27 | 28 | model_name = "tiny_LLaMA_1b_mask" 29 | name = "meap_1b_mask" 30 | out_dir = Path("out_mask_1b_mask0.15") / name 31 | #num_nodes=1 32 | # Hyperparameters 33 | num_of_devices = 1 34 | global_batch_size = 256 35 | learning_rate = 4e-4 36 | micro_batch_size = 8 37 | max_step = 190000 38 | warmup_steps = 19000 39 | log_step_interval = 10 40 | eval_iters = 100 41 | save_step_interval = 20000 42 | eval_step_interval = 4766 43 | 44 | weight_decay = 5e-2 45 | beta1 = 0.9 46 | beta2 = 0.95 47 | grad_clip = 1.0 48 | decay_lr = True 49 | min_lr = 4e-5 50 | 51 | batch_size = global_batch_size // num_of_devices 52 | gradient_accumulation_steps = batch_size // micro_batch_size 53 | assert gradient_accumulation_steps > 0 54 | warmup_iters = warmup_steps * gradient_accumulation_steps 55 | 56 | checkpoint_dir = Path("../tokenizer") 57 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir,legacy=False,trust_remote_code=True) 58 | special_tokens = { 59 | "additional_special_tokens": ["[MASK]"], 60 | "pad_token": tokenizer.pad_token or tokenizer.eos_token, 61 | "eos_token": tokenizer.eos_token, 62 | "bos_token": tokenizer.bos_token or tokenizer.eos_token, 63 | } 64 | num_added_tokens = tokenizer.add_special_tokens(special_tokens) 65 | use_mask = True 66 | mask_ratio = 0.15 67 | max_iters = max_step * gradient_accumulation_steps 68 | lr_decay_iters = max_iters 69 | log_iter_interval = log_step_interval * gradient_accumulation_steps 70 | 71 | 72 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
73 | train_data_config = [ 74 | ("c4", 1.0), 75 | ] 76 | 77 | val_data_config = [ 78 | ("validation", 1.0), 79 | ] 80 | 81 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 82 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 83 | wandb_logger = WandbLogger(project="NTP-MASK", 84 | name="mask_1b_1230_mask0.15", 85 | log_model=True, 86 | save_dir=str(out_dir), 87 | config={ 88 | 89 | "model_name": model_name, 90 | "total_params": None, 91 | 92 | 93 | "global_batch_size": global_batch_size, 94 | "micro_batch_size": micro_batch_size, 95 | "gradient_accumulation_steps": gradient_accumulation_steps, 96 | "learning_rate": learning_rate, 97 | "min_lr": min_lr, 98 | "weight_decay": weight_decay, 99 | "warmup_steps": warmup_steps, 100 | "max_step": max_step, 101 | "beta1": beta1, 102 | "beta2": beta2, 103 | "grad_clip": grad_clip, 104 | 105 | 106 | "train_data_config": train_data_config, 107 | "val_data_config": val_data_config, 108 | 109 | 110 | "num_devices": num_of_devices, 111 | "precision": None, 112 | } 113 | 114 | ) 115 | 116 | 117 | 118 | def setup( 119 | devices: int = 1, 120 | train_data_dir: Path = Path("../c4_bin"), 121 | val_data_dir: Optional[Path] = None, 122 | precision: Optional[str] = None, 123 | tpu: bool = False, 124 | resume: Union[bool, Path] = False, 125 | ) -> None: 126 | print("devices: ", devices) 127 | #resume=True 128 | 129 | #print("num_nodes: ", num_nodes) 130 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 131 | 132 | if devices > 1: 133 | if tpu: 134 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 135 | devices = "auto" 136 | strategy = XLAStrategy(sync_module_states=False) 137 | else: 138 | strategy = FSDPStrategy( 139 | auto_wrap_policy={Block}, 140 | activation_checkpointing_policy=None, 141 | state_dict_type="full", 142 | limit_all_gathers=True, 143 | cpu_offload=False, 144 | ) 145 | else: 146 | strategy = "auto" 147 | 148 | fabric = L.Fabric(devices=devices,strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 149 | fabric.print(hparams) 150 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 151 | main(fabric, train_data_dir, val_data_dir, resume) 152 | 153 | 154 | def main(fabric, train_data_dir, val_data_dir, resume): 155 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 156 | 157 | if fabric.global_rank == 0: 158 | out_dir.mkdir(parents=True, exist_ok=True) 159 | 160 | config = Config.from_name(model_name) 161 | fabric.print(f"train_data_dir: {train_data_dir}") 162 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 163 | fabric.print(f"val_data_dir: {val_data_dir}") 164 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 165 | fabric.print(f"train_data_dir: {train_data_dir}") 166 | fabric.print(f"type of train_data_dir: {type(train_data_dir)}") 167 | fabric.print(f"val_data_dir: {val_data_dir}") 168 | fabric.print(f"type of val_data_dir: {type(val_data_dir)}") 169 | train_dataloader, val_dataloader = create_dataloaders( 170 | batch_size=micro_batch_size, 171 | block_size=config.block_size, 172 | fabric=fabric, 173 | train_data_dir=train_data_dir, 174 | val_data_dir=val_data_dir, 175 | seed=3407, 176 | ) 177 | if val_dataloader is None: 178 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 179 | else: 180 | train_dataloader, val_dataloader = 
fabric.setup_dataloaders(train_dataloader, val_dataloader) 181 | 182 | fabric.seed_everything(3407) # same seed for every process to init model (FSDP) 183 | 184 | fabric.print(f"Loading model with {config.__dict__}") 185 | t0 = time.perf_counter() 186 | with fabric.init_module(empty_init=False): 187 | model = GPT(config) 188 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 189 | 190 | total_params = num_parameters(model) 191 | #wandb_logger.experiment.config.update({"total_params": total_params}) 192 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 193 | fabric.print(f"Total parameters {num_parameters(model):,}") 194 | 195 | model = fabric.setup(model) 196 | optimizer = torch.optim.AdamW( 197 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 198 | ) 199 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 200 | optimizer = fabric.setup_optimizers(optimizer) 201 | 202 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 203 | 204 | if resume is True: 205 | resume="" 206 | #resume = sorted(out_dir.glob("*.pth"))[-1] 207 | if resume : 208 | resume="" 209 | fabric.print(f"Resuming training from {resume}") 210 | fabric.load(resume, state) 211 | 212 | train_time = time.perf_counter() 213 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 214 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 215 | if fabric.device.type == "cuda": 216 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 217 | 218 | def get_cycled_seed(curr_iter): 219 | cycle_range = 51 220 | cycled_value = 50 + (curr_iter % cycle_range) 221 | return cycled_value 222 | 223 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 224 | model = state["model"] 225 | optimizer = state["optimizer"] 226 | 227 | if val_dataloader is not None: 228 | validate(fabric, model, val_dataloader) # sanity check 229 | 230 | with torch.device("meta"): 231 | meta_model = GPT(model.config) 232 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 233 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 234 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 235 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 236 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 237 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 238 | # measured_flos run in meta. Will trigger fusedRMSNorm error 239 | #measured_flops = measure_flops(meta_model, x) 240 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 241 | del meta_model, x 242 | 243 | total_lengths = 0 244 | total_t0 = time.perf_counter() 245 | 246 | if fabric.device.type == "xla": 247 | import torch_xla.core.xla_model as xm 248 | 249 | xm.mark_step() 250 | 251 | 252 | initial_iter = state["iter_num"] 253 | curr_iter = 0 254 | 255 | 256 | loss_func = FusedCrossEntropyLoss() 257 | for train_data in train_dataloader: 258 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
259 | if resume: 260 | if curr_iter < initial_iter: 261 | curr_iter += 1 262 | continue 263 | else: 264 | resume = False 265 | curr_iter = -1 266 | fabric.barrier() 267 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 268 | if state["iter_num"] >= max_iters: 269 | break 270 | 271 | # determine and set the learning rate for this iteration 272 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 273 | for param_group in optimizer.param_groups: 274 | param_group["lr"] = lr 275 | 276 | iter_t0 = time.perf_counter() 277 | input_ids = train_data[:, 0 : model.config.block_size].clone().contiguous() 278 | if use_mask: 279 | random.seed(get_cycled_seed(state["step_count"])) 280 | num_masks = max(1, int(model.config.block_size * mask_ratio)) 281 | mask_positions = random.sample(range(0, model.config.block_size-1), num_masks) 282 | #print(f"mask_positions: {mask_positions}") 283 | bs, seq_len = input_ids.shape 284 | #input_ids[0,0] = 32000 285 | input_ids[:,mask_positions] = tokenizer.convert_tokens_to_ids("[MASK]") 286 | 287 | 288 | 289 | 290 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 291 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 292 | with fabric.no_backward_sync(model, enabled=is_accumulating): 293 | logits = model(input_ids) 294 | loss = loss_func(logits, targets) 295 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 296 | fabric.backward(loss / gradient_accumulation_steps) 297 | 298 | if not is_accumulating: 299 | wandb_logger.log_metrics({ 300 | "train/loss": loss.item(), 301 | "train/learning_rate": lr, 302 | "train/step": state["step_count"], 303 | "system/gpu_memory": torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0, 304 | "system/gpu_memory_reserved": torch.cuda.max_memory_reserved() / 1e9 if torch.cuda.is_available() else 0, 305 | "performance/iter_time": (t1 - iter_t0) * 1000, 306 | "performance/tokens_per_sec": (micro_batch_size * model.config.block_size) / (t1 - iter_t0), 307 | "system/estimated_tflops": (estimated_flops * fabric.world_size / 1e12) * (1 / (t1 - iter_t0)) 308 | }, step=state["step_count"]) 309 | 310 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 311 | optimizer.step() 312 | optimizer.zero_grad() 313 | state["step_count"] += 1 314 | elif fabric.device.type == "xla": 315 | xm.mark_step() 316 | state["iter_num"] += 1 317 | 318 | # input_id: B L 319 | total_lengths += input_ids.size(1) 320 | t1 = time.perf_counter() 321 | fabric.print( 322 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" 323 | f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 324 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 325 | # print days as well 326 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 327 | ) 328 | 329 | monitor.on_train_batch_end( 330 | state["iter_num"] * micro_batch_size, 331 | t1 - total_t0, 332 | # this assumes that device FLOPs are the same and that all devices have the same batch size 333 | fabric.world_size, 334 | state["step_count"], 335 | flops_per_batch=estimated_flops, 336 | lengths=total_lengths, 337 | train_loss = loss.item() 338 | ) 339 | 340 | 341 | 342 | 343 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 344 | 345 | t0 = time.perf_counter() 346 | val_loss = validate(fabric, model, val_dataloader) 347 | t1 = time.perf_counter() - t0 348 | monitor.eval_end(t1) 349 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 350 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 351 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 352 | fabric.barrier() 353 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 354 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 355 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 356 | fabric.save(checkpoint_path, state) 357 | 358 | 359 | @torch.no_grad() 360 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 361 | fabric.print("Validating ...") 362 | model.eval() 363 | 364 | losses = torch.zeros(eval_iters, device=fabric.device) 365 | for k, val_data in enumerate(val_dataloader): 366 | if k >= eval_iters: 367 | break 368 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 369 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 370 | logits = model(input_ids) 371 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 372 | 373 | # loss_func = FusedCrossEntropyLoss() 374 | # loss = loss_func(logits, targets) 375 | losses[k] = loss.item() 376 | 377 | out = losses.mean() 378 | 379 | model.train() 380 | return out 381 | 382 | 383 | def create_dataloader( 384 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 385 | ) -> DataLoader: 386 | datasets = [] 387 | data_config = train_data_config if split == "train" else val_data_config 388 | for prefix, _ in data_config: 389 | #print(f"data_dir: {data_dir}") 390 | #print(f"prefix: {prefix}") 391 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 392 | #print(f"filenames: {filenames}") 393 | random.seed(seed) 394 | random.shuffle(filenames) 395 | #print(f"filenames after shuffle: {filenames}") 396 | dataset = PackedDataset( 397 | filenames, 398 | # n_chunks control the buffer size. 399 | # Note that the buffer size also impacts the random shuffle 400 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 401 | n_chunks=1, 402 | block_size=block_size, 403 | shuffle=shuffle, 404 | seed=seed+fabric.global_rank, 405 | num_processes=fabric.world_size, 406 | process_rank=fabric.global_rank, 407 | ) 408 | datasets.append(dataset) 409 | 410 | if not datasets: 411 | raise RuntimeError( 412 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
413 | ) 414 | 415 | weights = [weight for _, weight in data_config] 416 | 417 | sum_weights = sum(weights) 418 | 419 | weights = [el / sum_weights for el in weights] 420 | 421 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 422 | 423 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 424 | 425 | 426 | def create_dataloaders( 427 | batch_size: int, 428 | block_size: int, 429 | fabric, 430 | train_data_dir: Path = Path("data/redpajama_sample"), 431 | val_data_dir: Optional[Path] = None, 432 | seed: int = 12345, 433 | ) -> Tuple[DataLoader, DataLoader]: 434 | # Increase by one because we need the next word as well 435 | effective_block_size = block_size + 1 436 | train_dataloader = create_dataloader( 437 | batch_size=batch_size, 438 | block_size=effective_block_size, 439 | fabric=fabric, 440 | data_dir=train_data_dir, 441 | shuffle=True, 442 | seed=seed, 443 | split="train" 444 | ) 445 | val_dataloader = ( 446 | create_dataloader( 447 | batch_size=batch_size, 448 | block_size=effective_block_size, 449 | fabric=fabric, 450 | data_dir=val_data_dir, 451 | shuffle=False, 452 | seed=seed, 453 | split="validation" 454 | ) 455 | if val_data_dir 456 | else None 457 | ) 458 | return train_dataloader, val_dataloader 459 | 460 | 461 | # learning rate decay scheduler (cosine with warmup) 462 | def get_lr(it): 463 | # 1) linear warmup for warmup_iters steps 464 | if it < warmup_iters: 465 | return learning_rate * it / warmup_iters 466 | # 2) if it > lr_decay_iters, return min learning rate 467 | if it > lr_decay_iters: 468 | return min_lr 469 | # 3) in between, use cosine decay down to min learning rate 470 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 471 | assert 0 <= decay_ratio <= 1 472 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 473 | return min_lr + coeff * (learning_rate - min_lr) 474 | 475 | 476 | if __name__ == "__main__": 477 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 478 | # torch.backends.cuda.enable_flash_sdp(False) 479 | torch.set_float32_matmul_precision("high") 480 | 481 | from jsonargparse import CLI 482 | 483 | CLI(setup) 484 | -------------------------------------------------------------------------------- /MEAP-Pretain/requirements.txt: -------------------------------------------------------------------------------- 1 | jsonargparse[signatures] # CLI 2 | pandas 3 | pyarrow 4 | tokenizers 5 | sentencepiece 6 | wandb 7 | zstd -------------------------------------------------------------------------------- /MEAP-Pretain/run/run_one_node.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # External parameter for the Python script 4 | PYTHON_SCRIPT=$1 5 | 6 | # Run the model with the provided Python script 7 | lightning run model --devices=1 --accelerator=cuda $PYTHON_SCRIPT 8 | -------------------------------------------------------------------------------- /MEAP-Pretain/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "", 4 | "lstrip": false, 5 | "normalized": false, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "", 11 | "lstrip": false, 12 | "normalized": false, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "unk_token": { 17 | "content": "", 18 | "lstrip": false, 19 | "normalized": false, 20 
| "rstrip": false, 21 | "single_word": false 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /MEAP-Pretain/tokenizer/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scitix/MEAP/0f5e837720d62d66e159db150589e5dfc4496d1c/MEAP-Pretain/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /MEAP-Pretain/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "add_prefix_space": true, 5 | "added_tokens_decoder": { 6 | "0": { 7 | "content": "", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false, 12 | "special": true 13 | }, 14 | "1": { 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false, 20 | "special": true 21 | }, 22 | "2": { 23 | "content": "", 24 | "lstrip": false, 25 | "normalized": false, 26 | "rstrip": false, 27 | "single_word": false, 28 | "special": true 29 | } 30 | }, 31 | "bos_token": "", 32 | "clean_up_tokenization_spaces": false, 33 | "eos_token": "", 34 | "legacy": true, 35 | "model_max_length": 1000000000000000019884624838656, 36 | "pad_token": null, 37 | "sp_model_kwargs": {}, 38 | "spaces_between_special_tokens": false, 39 | "tokenizer_class": "LlamaTokenizer", 40 | "unk_token": "", 41 | "use_default_system_prompt": false 42 | } 43 | -------------------------------------------------------------------------------- /MEAP-SFT/deepspeed_zero_stage2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false, 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 2, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "allgather_partitions": true, 39 | "allgather_bucket_size": 2e8, 40 | "overlap_comm": true, 41 | "reduce_scatter": true, 42 | "reduce_bucket_size": 2e8, 43 | "contiguous_gradients": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 100, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /MEAP-SFT/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | transformers==4.45.0 3 | accelerate==0.27.2 4 | peft==0.10.0 5 | datasets==3.1.0 6 | bitsandbytes==0.43.3 7 | triton==3.0.0 8 | sentencepiece==0.2.0 9 | trl==0.8.6 10 | pandas==1.5.3 11 | numpy==1.23.5 12 | safetensors==0.4.5 13 | tiktoken==0.8.0 14 | loguru==0.7.3 15 | deepspeed==0.16.3 16 | tensorboard==2.18.0 17 | scipy==1.14.1 18 | -------------------------------------------------------------------------------- 
/MEAP-SFT/template.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional, List, Dict, Sequence 4 | 5 | __all__ = ['Conversation', 'register_conv_template', 'get_conv_template'] 6 | 7 | 8 | @dataclass 9 | class Conversation: 10 | """A class that manages prompt templates and keeps all conversation history.""" 11 | 12 | # The name of this template 13 | name: str 14 | # The system prompt 15 | system_prompt: str 16 | # All messages. format: list of [question, answer] 17 | messages: Optional[List[Sequence[str]]] 18 | # The roles of the speakers 19 | roles: Optional[Sequence[str]] 20 | # Conversation prompt 21 | prompt: str 22 | # Separator 23 | sep: str 24 | # Stop token, default is tokenizer.eos_token 25 | stop_str: Optional[str] = "" 26 | 27 | def get_prompt( 28 | self, 29 | messages: Optional[List[Sequence[str]]] = None, 30 | system_prompt: Optional[str] = "" 31 | ) -> str: 32 | """ 33 | Returns a string containing prompt without response. 34 | """ 35 | return "".join(self._format_example(messages, system_prompt)) 36 | 37 | def get_dialog( 38 | self, 39 | messages: Optional[List[Sequence[str]]] = None, 40 | system_prompt: Optional[str] = "" 41 | ) -> List[str]: 42 | """ 43 | Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response. 44 | """ 45 | return self._format_example(messages, system_prompt) 46 | 47 | def _format_example( 48 | self, 49 | messages: Optional[List[Sequence[str]]] = None, 50 | system_prompt: Optional[str] = "" 51 | ) -> List[str]: 52 | system_prompt = system_prompt or self.system_prompt 53 | system_prompt = system_prompt + self.sep if system_prompt else "" # add separator for non-empty system prompt 54 | messages = messages or self.messages 55 | convs = [] 56 | for turn_idx, [user_query, bot_resp] in enumerate(messages): 57 | if turn_idx == 0: 58 | convs.append(system_prompt + self.prompt.format(query=user_query)) 59 | convs.append(bot_resp) 60 | else: 61 | convs.append(self.sep + self.prompt.format(query=user_query)) 62 | convs.append(bot_resp) 63 | return convs 64 | 65 | def append_message(self, query: str, answer: str): 66 | """Append a new message.""" 67 | self.messages.append([query, answer]) 68 | 69 | 70 | # A global registry for all conversation templates 71 | conv_templates: Dict[str, Conversation] = {} 72 | 73 | 74 | def register_conv_template(template: Conversation): 75 | """Register a new conversation template.""" 76 | conv_templates[template.name] = template 77 | 78 | 79 | 80 | 81 | 82 | """llama3 template 83 | source: https://huggingface.co/meta-llama 84 | Supports: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct 85 | chat template: 86 | <|begin_of_text|><|start_header_id|>system<|end_header_id|> 87 | {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> 88 | {{ user_msg_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 89 | {{ model_answer_1 }}<|eot_id|> 90 | """ 91 | register_conv_template( 92 | Conversation( 93 | name="llama3", 94 | system_prompt=( 95 | "<|start_header_id|>system<|end_header_id|>\n\n" 96 | "You are a helpful, excellent and smart assistant." 
97 | ), 98 | messages=[], 99 | roles=("user", "assistant"), 100 | prompt=( 101 | "<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|>" 102 | "<|start_header_id|>assistant<|end_header_id|>\n\n" 103 | ), 104 | sep="<|eot_id|>", 105 | stop_str="<|eot_id|>", 106 | ) 107 | ) 108 | 109 | 110 | 111 | def get_conv_template(name: str) -> Conversation: 112 | """Get a conversation template.""" 113 | return conv_templates[name] 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEAP 2 | 3 | This repository contains the official implementation of "[Mask-Enhanced Autoregressive Prediction: Pay Less Attention to Learn More](https://arxiv.org/abs/2502.07490)". 4 | 5 | **2025-05-01: MEAP is accepted to ICML 2025! ([Poster](https://icml.cc/virtual/2025/poster/46344))** 6 | 7 | *Xialie Zhuang, Zhikai Jia, Jianjin Li, Zhenyu Zhang, Li Shen, Zheng Cao, Shiwei Liu* 8 | 9 | ## 📋 Table of Contents 10 | - [MEAP-Pretrain](#meap-pretrain) 11 | - [MEAP-SFT](#meap-sft) 12 | 13 | ## Overview 14 | 15 | MEAP (Mask-Enhanced Autoregressive Prediction) is a novel training paradigm that seamlessly integrates Masked Language Modeling (MLM) into Next-Token Prediction (NTP) using a decoder-only Transformer. By masking a small fraction of input tokens during standard autoregressive training, MEAP enhances model performance on key information retrieval tasks while maintaining strong reasoning capabilities. 16 | 17 | Key Features: 18 | - Seamless integration of MLM into NTP 19 | - No additional computational overhead 20 | - Compatible with decoder-only architectures 21 | - Improved performance on information retrieval tasks 22 | 23 | ## MEAP-Pretrain 24 | 25 | ### Model Architecture 26 | 27 | The MEAP architecture extends standard decoder-only transformers by: 28 | 1. Randomly masking a portion of input tokens 29 | 2. Training the model to predict both masked tokens and next tokens 30 | 3. Maintaining the autoregressive property during inference 31 | 32 | ### Installation 33 | #### Install env 34 | 35 | ```bash 36 | conda create -n meap python=3.10 37 | conda activate meap 38 | ``` 39 | 40 | #### Install PyTorch 41 | ```bash 42 | pip install torch==2.5.0 --index-url https://download.pytorch.org/whl/cu121 43 | ``` 44 | 45 | #### Install Lightning 46 | ```bash 47 | pip install lightning==2.1.2 48 | pip install lightning-app 49 | pip install lightning-cloud==0.5.52 50 | ``` 51 | 52 | #### Install Flash-Attention 2 and other fused operators 53 | ```bash 54 | git clone https://github.com/Dao-AILab/flash-attention 55 | cd flash-attention 56 | pip install flash-attn 57 | cd csrc/rotary && pip install . 58 | cd ../layer_norm && pip install . 59 | cd ../xentropy && pip install . 60 | cd ../.. && rm -rf flash-attention 61 | ``` 62 | 63 | #### Build XFormers from Source 64 | 65 | ```bash 66 | pip3 install xformers --no-deps 67 | ``` 68 | #### Install Remaining Dependencies 69 | ```bash 70 | pip install -r requirements.txt tokenizers sentencepiece transformers 71 | ``` 72 | This installs the remaining dependencies. 73 | Building xformers/flash-attention may take 5 minutes or more. Do not worry if the process seems stagnant or if the terminal prints many warnings. 74 | 75 | Then you are ready to go 🎉! 76 | 77 | ### Data Preparation 78 | 79 | #### Download Datasets 80 | Download the c4 dataset to your chosen directory. 
81 | ```bash 82 | mkdir original_data 83 | cd original_data 84 | git lfs install 85 | git clone https://huggingface.co/datasets/allenai/c4 86 | cd .. 87 | ``` 88 | 89 | Extract the downloaded c4 files and move them to the json_c4 folder. 90 | ```bash 91 | python data_process/gz_unzip_v1.py 92 | mkdir json_c4 93 | cd original_data 94 | mv *.json ../json_c4/ 95 | ``` 96 | 97 | 98 | 99 | #### Tokenize data 100 | Use the provided scripts to tokenize the datasets and divide them into chunks. 101 | 102 | 103 | ```bash 104 | mkdir c4_bin 105 | python3 prepare_c4.py --source_path ../ --destination_path ../c4_bin --checkpoint_dir ../tokenizer 106 | cd .. 107 | ``` 108 | We have placed some sample data in the `c4_bin` folder. Please note that these files are only for testing the pipeline; they are not the complete training data. 109 | 110 | ### Train 111 | 112 | 113 | If your setup is a single node with one GPU, you can initiate pretraining with the following commands: 114 | 115 | ```bash 116 | cd MEAP-Pretrain 117 | sh run_one_node.sh ../pretrained/meap_1b.py 118 | ``` 119 | If you want to change the number of GPUs used, please also update the `--devices` parameter in `run_one_node.sh`, the `num_of_devices` parameter, and the default value of `devices` in the `setup` function in `meap_1b.py`. 120 | 121 | The default path for saving the model weights is `out_mask_1b_mask0.15`. If you want to change the save path, edit the `out_dir` parameter in `meap_1b.py`. 122 | 123 | The default value of the `n_chunks` parameter in `meap_1b.py` is 1. Increasing it can improve data-loading throughput. 124 | 125 | More training hyperparameters can also be modified in `meap_1b.py`. 126 | 127 | 128 | ### Convert to Hugging Face 129 | 130 | Convert the trained model to the HF format. Pass the name of the model config you trained with via `--model_name`, e.g. `tiny_LLaMA_1b_mask`. 131 | 132 | ```bash 133 | cd convert 134 | 135 | python3 convert_lit_checkpoint.py --checkpoint_name xxxx.pth --out_dir your_save_dir --model_name your_trained_model_name --model_only false 136 | ``` 137 | 138 | After running the script, a bin file will be stored in the `out_dir` folder. 139 | 140 | Finally, run `convert_safetensors.py` to convert the bin file to the safetensors format, where `checkpoint_path` is the path of the bin file and `out_dir` is the save path for the safetensors file. 141 | 142 | ```bash 143 | python3 convert_safetensors.py 144 | ``` 145 | ## MEAP-SFT 146 | 147 | ### Model Architecture 148 | 149 | MEAP-SFT extends standard supervised fine-tuning of decoder-only transformers by: 150 | 151 | 1. **Randomly Mask Target Text**: Randomly select positions in the target text to mask, based on the given `mask_ratio`. 152 | 2. **Align Input and Labels**: Ensure input sequences and labels are aligned in length, and truncate sequences that exceed the maximum length. 153 | 3. **Dynamically Generate Masks**: Dynamically select mask positions at each training step to improve the model's generalization ability. 154 | 155 | ### Installation 156 | 157 | ```bash 158 | conda create -n MEAP-SFT python=3.10 -y 159 | conda activate MEAP-SFT 160 | pip install -r ./MEAP-SFT/requirements.txt 161 | ``` 162 | 163 | ### Train 164 | 165 | - If you do not already have the Llama-3-8B weights, download them first. 166 | 167 | ```bash 168 | bash ./script/MEAP-SFT.sh 169 | ``` 170 | 171 | ## License 172 | 173 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
174 | 175 | ## Cite as 176 | ``` 177 | @article{zhuang2025mask, 178 | title={Mask-Enhanced Autoregressive Prediction: Pay Less Attention to Learn More}, 179 | author={Zhuang, Xialie and Jia, Zhikai and Li, Jianjin and Zhang, Zhenyu and Shen, Li and Cao, Zheng and Liu, Shiwei}, 180 | journal={arXiv preprint arXiv:2502.07490}, 181 | year={2025} 182 | } 183 | ``` 184 | 185 | ## Acknowledgments 186 | 187 | We would like to acknowledge and thank the following projects and platforms that helped make this work possible: 188 | 189 | - [Siflow](https://scitix.ai/) - The entire development process relies on the Siflow platform, provided by SCITIX (SGP) TECH PTE. LTD. 190 | 191 | - [TinyLlama](https://github.com/jzhang38/TinyLlama) - Our work builds upon insights and implementations from the TinyLlama project. 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | --------------------------------------------------------------------------------
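As a usage sketch for the conversation-template registry defined in MEAP-SFT/template.py above (assuming the snippet is run from inside MEAP-SFT so the local import resolves):

```python
# Usage sketch for MEAP-SFT/template.py; assumes it is run from inside MEAP-SFT
# so that the local import resolves.
from template import get_conv_template

conv = get_conv_template("llama3")
prompt = conv.get_prompt(
    messages=[["What does MEAP stand for?", ""]],  # empty answer: prompt only
    system_prompt="",  # falls back to the template's default system prompt
)
print(prompt)
# The printed string contains the system block, the user turn and the trailing
# assistant header, separated by <|eot_id|>.
```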