├── README.md ├── conf └── llama_65b_merit_v1_pv91_v91_v5_0_full.yaml ├── convert2ckpt.py ├── data ├── data_utils.py ├── flan.py └── test.py ├── general_util └── tokenization_utils.py ├── models └── llama_ds_mp_wrap.py ├── requirements.txt └── trainer_base_ds_mp.py /README.md: -------------------------------------------------------------------------------- 1 | # llama-pipeline-parallel 2 | 3 | This is an experimental repo to explore how to implement LLaMA with DeepSpeed pipeline parallelism, since the documentation is incomplete and very few projects are working on this. The repo hopes to provide a minimal prototype and training loop to implement PP training for LLaMA and keep a note of possible bugs and the corresponding solutions. 4 | 5 | We have provided a minimal template to launch hybrid training of PP and DP, and the config can be found in `conf/llama_65b_merit_v1_pv91_v91_v5_0_full.yaml`. 6 | It should be noted that the template cannot be run directly since it is extracted from another project and some parts are omitted. 7 | But you can still quickly adapt it to your own usage by removing the relevant parts of dataset and collator initialization. 8 | 9 | ## Updates 10 | 11 | 2023/07/02: Successfully enabled hybrid training of LLaMA-65B on two nodes with 16 * 80G A100 GPUs. 12 | 13 | 2023/06/25: Repo established. Some notes are added first, and the code will be released once the cleanup is done. 14 | 15 | ## Core Code Snippets 16 | 17 | ### Model initialization 18 | 19 | There are two main approaches to initialize the model and load the pre-trained weights. One is to first initialize the full model using the `from_pretrained` function of HuggingFace's `transformers` repo. 20 | In this case, you may refer to `models.llama_ds_mp_wrap.get_model` for details. 21 | The drawback of this method is that it loads the whole model on every worker, which can run out of CPU memory when the model is large. 22 | The other method is to first declare the sharded model with DeepSpeed's `LayerSpec` class so that the modules are materialized only after the pipeline parallelism partition. Then each rank only needs to load the pre-trained weights of its own partition: 23 | 24 | ```python 25 | model_or_config = transformers.AutoConfig.from_pretrained(cfg.model_name_or_path) 26 | layers = models.llama_ds_mp_wrap.get_layers_from_config(model_or_config) 27 | model_pipe = PipelineModule(layers=layers, 28 | num_stages=cfg.num_stages, 29 | loss_fn=models.llama_ds_mp_wrap.loss_fn, 30 | activation_checkpoint_interval=getattr(cfg, "activation_checkpoint_interval", 0) 31 | ) 32 | ... 33 | model.load_checkpoint(cfg.model_name_or_path, load_module_only=True, load_optimizer_states=False, load_lr_scheduler_states=False) 34 | ``` 35 | 36 | Note that the pre-trained weights should be converted from the HF format by using `convert2ckpt.py`. 37 | 38 | 39 | ### Hybrid Training of Pipeline Parallelism (PP) and Distributed Data Parallel (DP) 40 | 41 | When `dist.world_size` > `num_stages`, hybrid training is automatically enabled. The number of pipeline-parallel (PP) stages is `num_stages`, 42 | while the data-parallel (DP) degree is `dist.world_size // num_stages`. 43 | 44 | ### No Weight Tying of Word Embedding 45 | 46 | Different from traditional pre-trained language models, LLaMA does not need weight tying. So do not use `TiedLayerSpec` to wrap the `embed_tokens` and `lm_head` modules.
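For reference, below is a minimal sketch of the untied layer declaration, adapted from `get_layers_from_config` in `models/llama_ds_mp_wrap.py` (the helper name `build_llama_pipe_layers` is only for illustration); the commented lines show the `TiedLayerSpec` wrapping that should be avoided:

```python
from deepspeed.pipe import LayerSpec  # TiedLayerSpec is intentionally not used for LLaMA

from models.llama_ds_mp_wrap import (
    EmbeddingPipe,
    ParallelTransformerLayerPipe,
    LayerNormPipe,
    LMLayerPipe,
)


def build_llama_pipe_layers(model_config, activation_checkpointing: bool = False):
    """Sketch of the untied layer list passed to PipelineModule."""
    return [
        # Word embedding declared with a plain LayerSpec -- no weight tying.
        LayerSpec(EmbeddingPipe, model_config.vocab_size, model_config.hidden_size),
        # Decoder layers.
        *[LayerSpec(ParallelTransformerLayerPipe, model_config, activation_checkpointing)
          for _ in range(model_config.num_hidden_layers)],
        # Final RMSNorm.
        LayerSpec(LayerNormPipe, model_config.hidden_size, model_config.rms_norm_eps),
        # LM head, also a plain LayerSpec; LLaMA sets `tie_word_embeddings` to `false`.
        LayerSpec(LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False),
        # Do NOT tie the two ends like this, since embed_tokens and lm_head keep separate weights:
        # TiedLayerSpec("weight", EmbeddingPipe, ..., tied_weight_attr="weight"),
        # TiedLayerSpec("weight", LMLayerPipe, ..., tied_weight_attr="weight"),
    ]
```

The returned list is exactly what `PipelineModule(layers=..., num_stages=...)` consumes in the snippet above.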
47 | 48 | ### Distributed Sampler Setting 49 | 50 | When hybrid training of PP and DP is enabled, `DistributedSampler` should be carefully set for each rank w.r.t. its state (PP stage and DP group). 51 | 52 | The core code snippet is as follows: 53 | 54 | ```python 55 | dp_degree = dist.get_world_size() // cfg.num_stages 56 | 57 | if dp_degree > 1: 58 | dp_id = model.grid.get_data_parallel_id() 59 | sub_train_sampler = DistributedSampler(sub_train_dataset, num_replicas=dp_degree, rank=dp_id) 60 | else: 61 | sub_train_sampler = RandomSampler(sub_train_dataset) 62 | ``` 63 | 64 | ### Data Fetch Design of DeepSpeed and CPU Memory Reduction 65 | 66 | In DeepSpeed's design, within a specific PP group, only the first and the last ranks, i.e., `stage=0 or stage=num_stages - 1`, 67 | will fetch minibatches from the dataloader, and the other ranks never fetch data. 68 | 69 | Based on this, for the ranks where the dataloader will never be used, we can use placeholders to avoid allocating memory for the real dataset. This could be especially useful when training large models. 70 | For example, when training LLaMA-65B with `offload_optimizer=True` and `num_stages=8`, the CPU memory usage is already nearly 800GB, 71 | which will cause CPU memory OOM when you are using a large dataset. 72 | 73 | The code of the dataset placeholder is as follows: 74 | 75 | ```python 76 | def load_empty_dataset_and_collator(cfg: DictConfig): 77 | from data.test import TestDataset 78 | from data.flan import FlanCollatorOverCollator 79 | 80 | dataset = TestDataset(None, None, getattr(cfg, "total_dataset_len", -1)) 81 | collator = FlanCollatorOverCollator(collator=None, 82 | tokenizer=cfg.model_name_or_path, 83 | max_seq_length=128, 84 | decoder_only=True, 85 | return_standard_inputs=True, 86 | ) 87 | 88 | # Keep consistent with `load_and_cache_examples`. 89 | if getattr(cfg, "dist_load_data_barrier", True): 90 | dist.barrier() 91 | 92 | if dist.is_initialized(): 93 | dist.barrier() 94 | 95 | return dataset, collator 96 | 97 | 98 | if model.is_first_stage() or model.is_last_stage(): 99 | sub_train_dataset = load_and_cache_examples(cfg, tokenizer, _split="train", _file=_file) 100 | 101 | if dp_degree > 1: 102 | dp_id = model.grid.get_data_parallel_id() 103 | sub_train_sampler = DistributedSampler(sub_train_dataset, num_replicas=dp_degree, rank=dp_id) 104 | else: 105 | sub_train_sampler = RandomSampler(sub_train_dataset) 106 | sub_train_collator = hydra.utils.instantiate(cfg.collator) if "collator" in cfg and cfg.collator else None 107 | 108 | sub_train_dataloader = DataLoader(dataset=sub_train_dataset, 109 | sampler=sub_train_sampler, 110 | batch_size=cfg.train_batch_size, 111 | collate_fn=sub_train_collator, 112 | num_workers=cfg.num_workers, 113 | pin_memory=True, 114 | prefetch_factor=cfg.prefetch_factor, 115 | drop_last=True, 116 | ) 117 | else: 118 | sub_train_dataset, sub_train_collator = load_empty_dataset_and_collator(cfg) 119 | sub_train_sampler = None 120 | 121 | sub_train_dataloader = DataLoader(dataset=sub_train_dataset, 122 | batch_size=cfg.train_batch_size, 123 | collate_fn=sub_train_collator, 124 | drop_last=True, 125 | shuffle=False) 126 | 127 | ``` 128 | 129 | where `TestDataset` is an empty dataset and the collator is an arbitrary one meeting the input format. 130 | 131 | ## Known Problems and Possible Solutions 132 | 133 | ### BF16 Support 134 | Bfloat16 can be used by setting the following in the deepspeed config: 135 | ``` 136 | data_types: 137 | grad_accum_dtype: "fp32" 138 | ``` 139 | However, bfloat16 cannot be used with optimizer offload.
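For orientation, here is a hedged sketch of how these keys could sit in the template's `ds_cfg` when switching from fp16 to bf16 (the key names follow the DeepSpeed config schema, but the surrounding values are illustrative assumptions rather than a tested setup):

```yaml
# Illustrative ds_cfg fragment only; adapt batch sizes and ZeRO settings to your own run.
ds_cfg:
  train_micro_batch_size_per_gpu: ${per_gpu_train_batch_size}
  gradient_accumulation_steps: ${gradient_accumulation_steps}
  bf16:
    enabled: true              # replaces the `fp16` block
  data_types:
    grad_accum_dtype: "fp32"   # accumulate gradients in fp32
  zero_optimization:
    stage: 1
    # Note: no `offload_optimizer` here -- bf16 together with optimizer offload is not supported.
```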
Note that pipeline parallelism is designed not to support optimizer offload (see issue [\#3866](https://github.com/microsoft/DeepSpeed/issues/3866)). Nevertheless, it can still be enabled under fp16 training. 140 | 141 | ### Flash Attention 142 | 143 | I cannot enable flash attention using either the original implementation or `torch.nn.functional.scaled_dot_product_attention` from PyTorch 2.0. See the issues [here](https://github.com/HuangLK/llama-deepspeed/issues/36) and [here](https://github.com/microsoft/DeepSpeed/issues/3868). 144 | 145 | ### Torch Compile 146 | 147 | Torch compilation is not supported in the template, perhaps because my usage is incorrect. 148 | 149 | ## References & Acknowledgements 150 | 151 | 1. [llama-deepspeed](https://github.com/HuangLK/llama-deepspeed/tree/main) 152 | 2. [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning) 153 | 3. [DeepSpeed Pipeline Parallelism Tutorial](https://www.deepspeed.ai/tutorials/pipeline/) 154 | 155 | [//]: # (### Quick Notes) 156 | 157 | [//]: # () 158 | [//]: # (#### Data fetch) 159 | 160 | [//]: # () 161 | [//]: # (1. Currently most implementations use `shuffle=True` instead of pytorch's `DistributedSampler` or `RandomSampler` in the data loader. I find that in the `world_size=4` scenario, only the first rank and the last one fetch data from the data loader. This can be verified by adding print statements in the `__getitem__` method of the specific dataset. However, during real training, I find that only the batch fetched from the first rank will really be sent to the model. This is consistent with my understanding of pipeline parallelism that only one rank fetches data and the other ranks only take the outputs from the previous rank as inputs.) 162 | 163 | [//]: # (2. There is a bug when the Deepspeed hybrid engine loads a model checkpoint: there must be optimizer states in the specified dir, see [here](https://github.com/HuangLK/llama-deepspeed/issues/28).)
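As a usage note for the weight-conversion step mentioned in the model initialization section: `convert2ckpt.py` parses its arguments with `HfArgumentParser`, so a conversion run could look like the sketch below. The paths are placeholders, and setting `mp_world_size` to match the template's `num_stages` (8, matching the `llama-65b-mp8` naming) is an assumption rather than a documented requirement:

```bash
# Convert HF-format LLaMA weights into the layer-wise DeepSpeed checkpoint
# consumed by model.load_checkpoint(...).
python convert2ckpt.py \
    --model_name_or_path /path/to/llama-65b-hf \
    --output_dir /tmp/llama-65b-mp8 \
    --mp_world_size 8

# The output directory then contains:
#   global_step001/layer_00-model_00-model_states.pt   # embedding
#   global_step001/layer_XX-model_00-model_states.pt   # decoder layers, final norm, lm_head
#   global_step001/mp_rank_0*_model_states.pt          # per-rank engine states
#   latest                                             # tag file pointing to global_step001
#   plus the saved tokenizer and model config
```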
164 | -------------------------------------------------------------------------------- /conf/llama_65b_merit_v1_pv91_v91_v5_0_full.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./ 4 | 5 | aws_output_bucket: 6 | 7 | train_file: /opt/ml/input/data/train/distant_path_v9.1_fix_no_shuffle.train.0.pkl 8 | test_file: 9 | dist_load_data_barrier: False 10 | 11 | # Model 12 | model: 13 | _target_: transformers.AutoConfig.from_pretrained 14 | pad_token_id: 0 15 | 16 | 17 | get_layers: 18 | _target_: models.llama_ds_mp_wrap.get_layers_from_config 19 | activation_checkpointing: True 20 | 21 | enable_flash_attention: False 22 | 23 | # Pipeline parallelism specific 24 | num_stages: 8 25 | #activation_checkpoint_interval: 1 26 | 27 | # Data loading 28 | read_tensor_train: 29 | _target_: data.wiki_entity_path_v9_1_2.convert_examples_into_features_seq2seq 30 | max_neg_num: 3 31 | aug_num: 3 32 | max_seq_length: 512 33 | shuffle_context: True 34 | min_rep_num: 5 35 | geo_p: 0.4 36 | deduct_ratio: 1.0 37 | context_ratio: 1.0 38 | noise_sent_ratio: 0.0 39 | num_workers: 128 40 | 41 | 42 | extended_vocab: 43 | 44 | # Data collator 45 | collator: 46 | _target_: data.collators.wiki_seq2seq_collator.WikiSeq2SeqCollatorWithCausalLMCombine 47 | max_seq_length: 512 48 | tokenizer: ${model_name_or_path} 49 | causal_lm: True 50 | causal_lm_add_eos: False 51 | generative_mode: True 52 | return_standard_inputs: True 53 | use_fast: False 54 | 55 | # Dataloader 56 | num_workers: 4 57 | prefetch_factor: 2 58 | 59 | do_preprocess: False 60 | 61 | model_name_or_path: /tmp/llama-65b-mp8 62 | pretrain: 63 | 64 | exp_name: llama.65b.merit_v91_v91.seq2seq.v5.0.3aug.mp8.dp2.adamw.500steps.NA100.0702.aws 65 | exp_notes: 66 | output_dir: /tmp/${exp_name} # Fix 67 | 68 | do_train: True 69 | evaluate_during_training: False 70 | 71 | do_eval: True 72 | eval_sub_path: checkpoint-* 73 | 74 | # Training hyper-parameters 75 | per_gpu_train_batch_size: 8 76 | per_gpu_eval_batch_size: 1 77 | learning_rate: 1e-6 78 | gradient_accumulation_steps: 256 79 | weight_decay: 0.001 80 | adam_epsilon: 1e-6 81 | adam_betas: "(0.9, 0.99)" 82 | max_grad_norm: 5.0 83 | num_train_epochs: 1 84 | max_steps: -1 85 | warmup_proportion: 0 86 | warmup_steps: 50 87 | total_dataset_len: 2122936 88 | 89 | # Optimizer 90 | optimizer: 91 | use_nvlamb: 92 | bit_training: 93 | 94 | 95 | logging_steps: 1 96 | save_best: False 97 | save_steps: 50 98 | eval_steps: -1 99 | ddp_eval: True 100 | no_cuda: False 101 | seed: 42 102 | local_rank: -1 103 | fp16: True 104 | fp16_opt_level: O1 105 | fp16_bfloat16: True 106 | 107 | # Prediction config 108 | prediction_cfg: 109 | metric: "acc" 110 | measure: 1 111 | best_checkpoint: 112 | best_result: 113 | eval_forward_fn: 114 | _target_: general_util.evaluator.DiscriminatorForwardFn 115 | post_process: 116 | 117 | 118 | # Deepspeed config 119 | ds_cfg: 120 | train_micro_batch_size_per_gpu: ${per_gpu_train_batch_size} 121 | gradient_accumulation_steps: ${gradient_accumulation_steps} 122 | optimizer: 123 | type: AdamW 124 | params: 125 | lr: ${learning_rate} 126 | betas: [0.9, 0.99] 127 | eps: ${adam_epsilon} 128 | weight_decay: ${weight_decay} 129 | scheduler: 130 | type: WarmupDecayLR 131 | params: 132 | total_num_steps: 133 | warmup_max_lr: ${learning_rate} 134 | warmup_num_steps: 135 | warmup_type: linear 136 | gradient_clipping: ${max_grad_norm} 137 | fp16: 138 | enabled: true 139 | loss_scale: 0 140 | loss_scale_window: 1000 141 | initial_scale_power: 12 142 
| hysteresis: 2 143 | min_loss_scale: 1 144 | # bf16: 145 | # enabled: ${fp16} 146 | # autotuning: 147 | # enabled: true 148 | # arg_mappings: 149 | # train_micro_batch_size_per_gpu: "per_gpu_train_batch_size" 150 | # gradient_accumulation_steps: "gradient_accumulation_steps" 151 | # zero_optimization: "ds_cfg.zero_optimization" 152 | zero_optimization: 153 | stage: 1 154 | contiguous_gradients: True 155 | overlap_comm: True 156 | reduce_scatter: True 157 | reduce_bucket_size: 5e7 158 | allgather_partitions: True 159 | allgather_bucket_size: 5e7 160 | offload_optimizer: 161 | device: cpu 162 | pin_memory: True 163 | # offload_param: 164 | # device: cpu 165 | # pin_memory: True 166 | # activation_checkpointing: 167 | # partition_activations: True 168 | # cpu_checkpointing: True 169 | # contiguous_memory_optimization: False 170 | # number_checkpoints: False 171 | # synchronize_checkpoint_boundary: False 172 | # profile: False 173 | steps_per_print: 1 174 | 175 | # Lightseq config 176 | with_lightseq: False 177 | 178 | 179 | summary_helper: 180 | _target_: general_util.tensorboard_helper.WandbWriter 181 | batch_index_or_keys: 182 | outputs_index_or_keys: 183 | 184 | # Temporary variables 185 | n_gpu: 186 | device: 187 | train_batch_size: 188 | eval_batch_size: 189 | world_size: 190 | topology: 191 | -------------------------------------------------------------------------------- /convert2ckpt.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Literal 3 | from dataclasses import dataclass, field 4 | 5 | import torch 6 | import transformers 7 | from transformers.models.llama.modeling_llama import LlamaConfig 8 | 9 | from general_util.tokenization_utils import expand_special_tokenizer, PreTrainedTokenizer 10 | 11 | 12 | @dataclass 13 | class Arguments: 14 | model_name_or_path: Optional[str] = field(default="/path/to/llama-7b-hf") 15 | output_dir: str = field(default="./llama-7B-init-ckpt") 16 | mp_world_size: int = field(default=1) 17 | 18 | 19 | def write_ckpt(outpath: Path, model: torch.nn.Module, model_config: LlamaConfig, mp: int): 20 | loaded = model.state_dict() 21 | 22 | n_layers = model_config.num_hidden_layers 23 | # embedding 24 | sd = {"weight": loaded['model.embed_tokens.weight']} 25 | torch.save(sd, outpath / "layer_00-model_00-model_states.pt") 26 | # norm 27 | sd = {f"weight": loaded['model.norm.weight']} 28 | torch.save(sd, outpath / f"layer_{n_layers + 1}-model_00-model_states.pt") 29 | # lm head 30 | sd = {f"weight": loaded['lm_head.weight']} 31 | torch.save(sd, outpath / f"layer_{n_layers + 2}-model_00-model_states.pt") 32 | # decoder layers 33 | for layer_i in range(n_layers): 34 | sd = {nm.replace(f"model.layers.{layer_i}.", f""): weight for nm, weight in loaded.items() if 35 | nm.startswith(f"model.layers.{layer_i}.")} 36 | torch.save(sd, outpath / f"layer_{layer_i + 1:02d}-model_00-model_states.pt") 37 | 38 | model_state = { 39 | "dp_world_size": 1, 40 | "mp_world_size": mp, 41 | "module": None, 42 | "optimizer": None, 43 | "global_steps": 1, 44 | "skipped_steps": 1, 45 | "iteration": 1, 46 | } 47 | for rank in range(mp): 48 | torch.save(model_state, outpath / f"mp_rank_{rank:02d}_model_states.pt") 49 | 50 | 51 | def main(): 52 | parser = transformers.HfArgumentParser((Arguments,)) 53 | args, = parser.parse_args_into_dataclasses() 54 | 55 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path) 56 | model_config = 
transformers.AutoConfig.from_pretrained(args.model_name_or_path) 57 | model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 58 | 59 | original_vocab_size = model_config.vocab_size 60 | expand_special_tokenizer(tokenizer) 61 | if len(tokenizer) > original_vocab_size: 62 | print(f"expand vocab size from {original_vocab_size} to {len(tokenizer)}") 63 | model.resize_token_embeddings(len(tokenizer)) 64 | 65 | outpath = Path(args.output_dir) 66 | if outpath.exists(): 67 | print(f"{outpath} exists. Do nothing.") 68 | exit(0) 69 | 70 | print(f"create {outpath}") 71 | outpath.mkdir() 72 | steppath = outpath / "global_step001" 73 | steppath.mkdir() 74 | 75 | write_ckpt(steppath, model, model_config, args.mp_world_size) 76 | with open(outpath / "latest", "w") as fout: 77 | fout.write("global_step001") 78 | 79 | tokenizer.save_pretrained(outpath) 80 | model_config.save_pretrained(outpath) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Set, Union, Dict, Tuple 3 | 4 | from transformers import PreTrainedTokenizer 5 | from transformers import RobertaTokenizer, RobertaTokenizerFast, AlbertTokenizer, AlbertTokenizerFast, DebertaTokenizer, \ 6 | DebertaTokenizerFast, DebertaV2Tokenizer 7 | from transformers.models.bert.tokenization_bert import whitespace_tokenize 8 | 9 | from general_util.logger import get_child_logger 10 | 11 | try: 12 | from nltk import word_tokenize 13 | except: 14 | pass 15 | 16 | logger = get_child_logger(__name__) 17 | 18 | 19 | def tokenizer_get_name(_tokenizer: PreTrainedTokenizer): 20 | tokenizer_name = _tokenizer.__class__.__name__ 21 | tokenizer_name = tokenizer_name.replace('TokenizerFast', '') 22 | tokenizer_name = tokenizer_name.replace('Tokenizer', '').lower() 23 | return tokenizer_name 24 | 25 | 26 | def get_sep_tokens(_tokenizer: PreTrainedTokenizer): 27 | if _tokenizer.sep_token: 28 | return [_tokenizer.sep_token] * (_tokenizer.max_len_single_sentence - _tokenizer.max_len_sentences_pair) 29 | return [] 30 | 31 | 32 | # FIXED: This method may find a span within a single word. 33 | # def find_span(text: str, span: str, start: int = 0): 34 | # pos = text.find(span, start) 35 | # if pos == -1: 36 | # return [] 37 | # _e = pos + len(span) 38 | # return [(pos, _e)] + find_span(text, span, start=_e) 39 | 40 | 41 | def is_alphabet(char): 42 | res = ord('a') <= ord(char) <= ord('z') or ord('A') <= ord(char) <= ord('Z') 43 | res = res or (char in ['-', '\'']) # Fix the problem shown in the bad case of the method `span_chunk`. 
44 | return res 45 | 46 | 47 | def whitespace_tokenize_w_punctuation_ends(text): 48 | """ 49 | Bad case: 50 | >>> whitespace_tokenize_w_punctuation_ends("\" My name is Fangkai Jiao.\"") 51 | >>> ['"', 'My', 'name', 'is', 'Fangkai', 'Jiao.', '"'] 52 | 53 | >>> word_tokenize("\" My name is Fangkai Jiao.\"") 54 | >>> ['``', 'My', 'name', 'is', 'Fangkai', 'Jiao', '.', "''"] 55 | """ 56 | words = whitespace_tokenize(text) 57 | new_words = [] 58 | for word in words: 59 | if len(word) == 1: 60 | new_words.append(word) 61 | continue 62 | 63 | if not is_alphabet(word[0]): 64 | new_words.append(word[0]) 65 | word = word[1:] 66 | 67 | if len(word) == 1: 68 | new_words.append(word) 69 | continue 70 | 71 | if not is_alphabet(word[-1]): 72 | new_words.append(word[:-1]) 73 | new_words.append(word[-1]) 74 | else: 75 | new_words.append(word) 76 | 77 | return new_words 78 | 79 | 80 | def find_span(sentence: str, span: str, start: int = 0): 81 | span = span.strip() 82 | 83 | s = sentence.find(span, start) 84 | if s == -1: 85 | return [] 86 | 87 | e = s + len(span) 88 | 89 | a = not is_alphabet(sentence[s - 1]) if s > 0 else True 90 | b = not is_alphabet(sentence[e]) if e < len(sentence) else True 91 | if a and b: 92 | return [(s, e)] + find_span(sentence, span, start=e) 93 | else: 94 | return find_span(sentence, span, start=e) 95 | 96 | 97 | def span_chunk(text: str, span_ls: List[str], space_tokenize: bool = False) -> Tuple[List[str], List[int]]: 98 | """ 99 | Word based span indicating. 100 | The method is based on whitespace tokenization, which may lead to inconsistent with BPE or Wordpiece. 101 | 102 | FIXME: 103 | 1. The warnings are to be fixed. There is some consistency can be address through proper text normalization. 104 | 2. The `whitespace_tokenize` aims to not split the words such as "don't", 105 | but may cause the punctuations not split correctly. 106 | """ 107 | pos_ls = [] 108 | for span in span_ls: 109 | span_pos_ls = find_span(text, span) 110 | pos_ls.extend(span_pos_ls) 111 | pos_ls = sorted(pos_ls, key=lambda x: x[0]) 112 | 113 | # Unified span 114 | to_be_dropped = set() 115 | for i, pos_i in enumerate(pos_ls): 116 | for j, pos_j in enumerate(pos_ls): 117 | if i == j: 118 | continue 119 | if pos_j[0] <= pos_i[0] and pos_i[1] <= pos_j[1]: 120 | to_be_dropped.add(i) 121 | 122 | new_pos_ls = [] 123 | for pos_id, pos in enumerate(pos_ls): 124 | if pos_id not in to_be_dropped: 125 | new_pos_ls.append(pos) 126 | pos_ls = new_pos_ls 127 | 128 | # No within word span check. 
129 | for pos_id, pos in enumerate(pos_ls): 130 | if pos_id == 0: 131 | continue 132 | # assert pos[0] >= pos_ls[pos_id - 1][1], (span_ls, text[pos[0]: pos[1]], text[pos_ls[pos_id - 1][0]: pos_ls[pos_id - 1][1]]) 133 | # TODO: Think about how to fix this: 134 | # some bad cases: 135 | # - AssertionError: (['Goethe-Universität', 'Johann Wolfgang Goethe'], 'Goethe-Universität', 'Johann Wolfgang Goethe') 136 | if pos[0] < pos_ls[pos_id - 1][1]: 137 | # pos[0] = pos_ls[pos_id - 1][1] 138 | pos_ls[pos_id] = (pos_ls[pos_id - 1][1], pos[1]) 139 | 140 | text_spans = [] 141 | indicate_mask = [] 142 | last_e = 0 143 | for s, e in pos_ls: 144 | if last_e > s: 145 | logger.warning(f"Overlapped span: {text_spans[-1]}\t{text[s: e]}\t{text}") 146 | print(f"Overlapped span: {text_spans[-1]}\t{text[s: e]}\t{text}") 147 | continue 148 | if s > last_e: 149 | if space_tokenize: 150 | # text_spans.extend(whitespace_tokenize(text[last_e: s])) 151 | # text_spans.extend(whitespace_tokenize_w_punctuation_ends(text[last_e: s])) 152 | text_spans.extend(word_tokenize(text[last_e: s])) 153 | else: 154 | tmp = text[last_e: s].strip() 155 | if tmp: 156 | text_spans.append(tmp) 157 | indicate_mask = indicate_mask + [0] * (len(text_spans) - len(indicate_mask)) 158 | 159 | text_spans.append(text[s: e].strip()) 160 | indicate_mask = indicate_mask + [1] * (len(text_spans) - len(indicate_mask)) 161 | last_e = e 162 | 163 | rest = text[last_e:].strip() 164 | if rest: 165 | if space_tokenize: 166 | # text_spans.extend(whitespace_tokenize(rest)) 167 | # text_spans.extend(whitespace_tokenize_w_punctuation_ends(rest)) 168 | text_spans.extend(word_tokenize(rest)) 169 | else: 170 | text_spans.append(rest) 171 | indicate_mask = indicate_mask + [0] * (len(text_spans) - len(indicate_mask)) 172 | 173 | # recovered_text = " ".join(text_spans) 174 | # if recovered_text != text: 175 | # logger.warning(f"In consistent text during chunk:\n{recovered_text}\n{text}") 176 | # print(f"In consistent text during chunk:\n{recovered_text}\n{text}") 177 | # print(span_ls) 178 | # print("======================") 179 | 180 | return text_spans, indicate_mask 181 | 182 | 183 | def span_chunk_subword(text: str, span_ls: List[str]) -> Tuple[List[str], List[int]]: 184 | """ 185 | Using the subword tokenization algorithm, e.g., BPR or wordpiece, to tokenize the sentence first, 186 | and find the span through recovery, which may have high time complexity. 187 | """ 188 | pass 189 | 190 | 191 | def span_chunk_simple(text: str, span_ls: List[str], tokenizer: PreTrainedTokenizer): 192 | """ 193 | This version only process the entities spans and using pre-trained tokenizer to tokenize the text first 194 | to annotate the position of each span. 195 | """ 196 | pos_ls = [] 197 | for span in span_ls: 198 | span_pos_ls = find_span(text, span) 199 | pos_ls.extend(span_pos_ls) 200 | pos_ls = sorted(pos_ls, key=lambda x: x[0]) 201 | 202 | for pos_id, pos in enumerate(pos_ls): 203 | if pos_id == 0: 204 | continue 205 | # assert pos[0] >= pos_ls[pos_id - 1][1], (span_ls, text[pos[0]: pos[1]], text[pos_ls[pos_id - 1][0]: pos_ls[pos_id - 1][1]]) 206 | # There maybe bad case where a entity in a substring of another entity. 
207 | # A bad case: 208 | # AssertionError: (['Netherlands', 'history of eindhoven', 'Koninkrijk der Nederlanden', 'Constituent country of the Kingdom of the Netherlands', 'Robert van der Horst', 'Eindhoven'], 209 | # 'Netherlands', 'Constituent country of the Kingdom of the Netherlands') 210 | if pos[0] < pos_ls[pos_id - 1][1]: 211 | return None, None 212 | 213 | tokens = [] 214 | token_spans = [] 215 | last_e = 0 216 | for s, e in pos_ls: 217 | if last_e > s: 218 | print(f"Overlapped span: {text[last_e: s]}\t{text[s: e]}\t{text}") 219 | continue 220 | 221 | sub_tokens = tokenizer.tokenize(text[last_e: s]) 222 | find = False 223 | for a in range(len(sub_tokens)): 224 | if tokenizer.convert_tokens_to_string(sub_tokens[a:]).strip() == text[s: e]: 225 | find = True 226 | if a > 0: 227 | tokens.extend(sub_tokens[:a]) 228 | tk_s = len(tokens) 229 | tokens.extend(sub_tokens[a:]) 230 | tk_e = len(tokens) 231 | token_spans.append((tk_s, tk_e)) 232 | break 233 | 234 | if not find: 235 | while s - 1 >= last_e and text[s - 1] == ' ': 236 | s = s - 1 # To tokenize the space with the entity together. 237 | if s > last_e: 238 | tokens.extend(tokenizer.tokenize(text[last_e: s])) 239 | 240 | tk_s = len(tokens) 241 | tokens.extend(tokenizer.tokenize(text[s: e])) 242 | tk_e = len(tokens) 243 | token_spans.append((tk_s, tk_e)) 244 | 245 | last_e = e 246 | 247 | if last_e < len(text): 248 | tokens.extend(tokenizer.tokenize(text[last_e:])) 249 | 250 | normalized_text = tokenizer.convert_tokens_to_string(tokens) 251 | 252 | # consistency check 253 | for s, e in token_spans: 254 | ent = tokenizer.convert_tokens_to_string(tokens[s: e]).strip() 255 | if ent not in span_ls: 256 | # print(f"Warning: {ent}\t{span_ls}") 257 | print(f"Warning: missed entity span after tokenization") 258 | return None, None 259 | 260 | _re_tokens = tokenizer.tokenize(normalized_text) 261 | if tokens != _re_tokens: 262 | # print(f"Warning: \n{tokens}\n{_re_tokens}\n{text}\n{normalized_text}") 263 | # print() 264 | # print(f"Warning: inconsistent tokens") 265 | return None, None 266 | if normalized_text != text: 267 | print(f"Warning, inconsistent text: {normalized_text}\t{text}") 268 | # return None, None 269 | 270 | return normalized_text, token_spans 271 | 272 | 273 | def get_unused_tokens(_tokenizer: PreTrainedTokenizer, token_num: int = 4): 274 | if isinstance(_tokenizer, RobertaTokenizer) or isinstance(_tokenizer, RobertaTokenizerFast): 275 | _unused_token = "" 276 | _unused_tokens = [] 277 | for i in range(token_num): 278 | _unused_tokens.append(_unused_token.format(str(i))) 279 | _tokenizer.add_tokens(_unused_tokens) 280 | return _unused_tokens 281 | elif isinstance(_tokenizer, AlbertTokenizer) or isinstance(_tokenizer, AlbertTokenizerFast): 282 | _unused_token = "[unused{}]" 283 | _unused_tokens = [] 284 | for i in range(token_num): 285 | _unused_tokens.append(_unused_token.format(str(i))) 286 | _tokenizer.add_tokens(_unused_tokens) 287 | return _unused_tokens 288 | elif any([isinstance(_tokenizer, x) for x in [DebertaTokenizer, DebertaTokenizerFast, DebertaV2Tokenizer]]): 289 | _unused_token = "[unused{}]" 290 | _unused_tokens = [] 291 | for i in range(token_num): 292 | _unused_tokens.append(_unused_token.format(str(i))) 293 | _tokenizer.add_tokens(_unused_tokens) 294 | return _unused_tokens 295 | 296 | 297 | def dfs(src: List[int], vis: Set, state: List[int], ans: List[List[int]]): 298 | if len(state) == len(src): 299 | if not all(a == b for a, b in zip(src, state)): 300 | ans.append(state) 301 | 302 | for x in src: 303 | if x 
not in vis: 304 | new_vis = copy.deepcopy(vis) 305 | new_vis.add(x) 306 | new_state = copy.deepcopy(state) 307 | new_state.append(x) 308 | dfs(src, new_vis, new_state, ans) 309 | 310 | 311 | def get_all_permutation(array: List[int]): 312 | res = [] 313 | dfs(array, set(), list(), res) 314 | for state in res: 315 | assert not all(a == b for a, b in zip(state, array)) 316 | return res 317 | 318 | 319 | def recursive_find_path(node: Union[List, Dict, str], outputs: List[List[str]], res: List[str]): 320 | if isinstance(node, str): 321 | outputs.append(res + [node]) 322 | return 323 | 324 | if isinstance(node, list): 325 | for x in node: 326 | recursive_find_path(x, outputs, res) 327 | elif isinstance(node, dict): 328 | for key, value in node.items(): 329 | recursive_find_path(value, outputs, res + [key]) 330 | else: 331 | raise ValueError('Unknown type: {}'.format(type(node))) 332 | 333 | 334 | def recursive_bfs(deduction: Union[List, Dict]): 335 | res = '' 336 | 337 | queue = [deduction] 338 | while queue: 339 | node = queue.pop(0) 340 | if isinstance(node, str): 341 | res = res + ' ' + node 342 | elif isinstance(node, list): 343 | queue.extend(node) 344 | elif isinstance(node, dict): 345 | for key, value in node.items(): 346 | queue.append(value) 347 | res = res + ' ' + key 348 | else: 349 | raise ValueError('Unknown type: {}'.format(type(node))) 350 | 351 | return res.strip() 352 | 353 | 354 | def dfs_enumerate_all_assign(keys: List[str], values: List[str], relation: str, res: List[str], assign: str, 355 | key_vis: Set): 356 | if len(key_vis) == 0: 357 | res.append(assign) 358 | 359 | for key_id in key_vis: 360 | new_key_vis = copy.deepcopy(key_vis) 361 | new_key_vis.remove(key_id) 362 | for value in values: 363 | if value in keys[key_id]: 364 | continue 365 | new_assign = assign + ' ' + keys[key_id] + ' ' + relation + ' ' + value + '.' 366 | dfs_enumerate_all_assign(keys, values, relation, res, new_assign, new_key_vis) 367 | 368 | 369 | def dfs_load_assignment(assignment_list, res: List[Tuple[str, str]], cur_assign: str): 370 | for assignment in assignment_list: 371 | if assignment['flag'] is False: 372 | continue 373 | if assignment['flag'] is None: 374 | res.append((cur_assign + ' ' + assignment['deduction'], assignment['id'])) 375 | elif assignment['flag'] is True: 376 | dfs_load_assignment(assignment['assignment'], res, cur_assign + ' ' + assignment['deduction']) 377 | else: 378 | raise ValueError('Unknown flag: {}'.format(assignment['flag'])) 379 | 380 | 381 | def word_seq_to_word_char_starts(words: List[str]): 382 | """ 383 | Args: 384 | words: The input word sequence (not subwords). 
385 | """ 386 | word2char_starts = [] 387 | text = "" 388 | for word in words: 389 | if len(text) > 0: 390 | text = " " + text 391 | word2char_starts.append(len(text)) 392 | text += word 393 | return word2char_starts, text 394 | 395 | 396 | def char_to_subword_ids(text, tokenizer: PreTrainedTokenizer): 397 | subwords = tokenizer.tokenize(text) 398 | 399 | char2subword_ids = [] 400 | char_lens = 0 401 | subword_idx = 0 402 | subwords_max_num = len(subwords) 403 | while subword_idx < subwords_max_num: 404 | subword_list = [] 405 | prev_subword_idx = subword_idx 406 | subword_len = 0 407 | subword = "" 408 | while subword_idx < subwords_max_num: 409 | subword_list.append(subwords[subword_idx]) 410 | subword_idx += 1 411 | subword = tokenizer.convert_tokens_to_string(subword_list) 412 | subword_len = len(subword) 413 | if subword == tokenizer.sep_token: 414 | char_lens += 1 415 | if text[char_lens: char_lens + subword_len] == subword: 416 | break 417 | assert text[char_lens: char_lens + subword_len] == subword 418 | if subword == "": 419 | char2subword_ids.extend([prev_subword_idx] * (subword_len + 1)) 420 | else: 421 | char2subword_ids.extend([prev_subword_idx] * subword_len) 422 | 423 | char_lens += len(subword) 424 | 425 | if len(text) != len(char2subword_ids): 426 | flag = False 427 | else: 428 | flag = True 429 | 430 | return char2subword_ids, subwords, flag 431 | -------------------------------------------------------------------------------- /data/flan.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union, Tuple, Dict, Optional 3 | 4 | import hydra 5 | import torch 6 | from omegaconf import DictConfig 7 | from torch.utils.data import Dataset 8 | from transformers import AutoTokenizer, PreTrainedTokenizer 9 | 10 | from general_util.tokenization_utils import expand_special_tokenizer 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def load_flan_data_w_filter(file_path: str): 16 | logger.info(f"Loading FLAN data from {file_path}...") 17 | data = torch.load(file_path, map_location="cpu") 18 | new_data = [] 19 | cnt = 0 20 | for item in data: 21 | if item["inputs"].strip() == "": 22 | continue 23 | if item["targets"].strip() == "": 24 | cnt += 1 25 | continue 26 | new_data.append(item) 27 | logger.info(f"Removed {cnt} empty examples.") 28 | logger.info(f"Loaded {len(new_data)} examples.") 29 | return new_data 30 | 31 | 32 | # def load_gpt4all_data(): 33 | # return load_dataset("nomic-ai/gpt4all-j-prompt-generations", revision='v1.2-jazzy')["train"] 34 | 35 | 36 | class PromptDataset(Dataset): 37 | def __init__(self, file_path, tokenizer: PreTrainedTokenizer, cfg: DictConfig): 38 | self.data = hydra.utils.instantiate(cfg, file_path) 39 | self.tokenizer = tokenizer 40 | 41 | def __len__(self): 42 | return len(self.data) 43 | 44 | def __getitem__(self, idx): 45 | return { 46 | "flan": { 47 | "inputs": self.data[idx]["prompt"], 48 | "targets": self.data[idx]["response"], 49 | } 50 | } 51 | 52 | 53 | class FLANDataset(Dataset): 54 | def __init__(self, file_path: str, tokenizer: PreTrainedTokenizer): 55 | self.data = load_flan_data_w_filter(file_path) 56 | self.tokenizer = tokenizer 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, idx): 62 | return self.data[idx] 63 | 64 | 65 | class WikiPathDatasetV5WFlan(Dataset): 66 | def __init__(self, raw_data: Union[Tuple, DictConfig], flan_file: str, file_path: str, tokenizer: PreTrainedTokenizer): 67 | # print(type(raw_data)) 
68 | if isinstance(raw_data, DictConfig): 69 | raw_data = hydra.utils.instantiate(raw_data, file_path=file_path, tokenizer=tokenizer) 70 | 71 | self.examples = raw_data[0] 72 | self.flan_data = load_flan_data_w_filter(flan_file) 73 | 74 | def __len__(self): 75 | return max(len(self.examples), len(self.flan_data)) 76 | 77 | def __getitem__(self, index): 78 | example = self.examples[index % len(self.examples)] 79 | flan = self.flan_data[index % len(self.flan_data)] 80 | # example = self.examples[index] 81 | # if index >= len(self.flan_data): 82 | # flan = random.choice(self.flan_data) 83 | # else: 84 | # flan = self.flan_data[index] 85 | return { 86 | "example": example, 87 | "flan": flan, 88 | "index": index, 89 | } 90 | 91 | 92 | class WikiPathDatasetV5WithDataset(Dataset): 93 | def __init__(self, raw_data: Union[Tuple, DictConfig], extra_data: Union[PromptDataset, DictConfig], 94 | file_path: str, tokenizer: PreTrainedTokenizer, add_wiki_text: bool = False): 95 | if isinstance(raw_data, DictConfig): 96 | raw_data = hydra.utils.instantiate(raw_data, file_path=file_path, tokenizer=tokenizer) 97 | 98 | if isinstance(extra_data, DictConfig): 99 | extra_data = hydra.utils.instantiate(extra_data, tokenizer=tokenizer) 100 | 101 | self.examples = raw_data[0] 102 | self.extra_data = extra_data 103 | 104 | self.add_wiki_text = add_wiki_text 105 | if self.add_wiki_text: 106 | self.wiki_texts = raw_data[1] 107 | 108 | def __len__(self): 109 | return max(len(self.examples), len(self.extra_data)) 110 | 111 | def __getitem__(self, index): 112 | example = self.examples[index % len(self.examples)] 113 | flan = self.extra_data[index % len(self.extra_data)] 114 | res = { 115 | "example": example, 116 | "index": index, 117 | } 118 | res.update(flan) 119 | if self.add_wiki_text: 120 | res["text"] = self.wiki_texts[index % len(self.wiki_texts)] 121 | return res 122 | 123 | 124 | class FlanCollectionGroupDataset(Dataset): 125 | def __init__(self, file_path: str, tokenizer=None): 126 | super().__init__() 127 | logger.info(f"Loading FLAN data from {file_path}...") 128 | data = torch.load(file_path, map_location="cpu") 129 | self.data = [] 130 | cnt = 0 131 | for item in data: 132 | if item["inputs"].strip() == "": 133 | continue 134 | if item["targets"].strip() == "": 135 | cnt += 1 136 | continue 137 | self.data.append(item) 138 | logger.info(f"Removed {cnt} empty examples.") 139 | 140 | def __len__(self): 141 | return len(self.data) 142 | 143 | def __getitem__(self, index): 144 | return { 145 | "flan": self.data[index], 146 | } 147 | 148 | 149 | def vanilla_seq2seq_convertor(examples, tokenizer: PreTrainedTokenizer, max_seq_length, decoder_only: bool = False): 150 | inputs = [] 151 | outputs = [] 152 | for exp in examples: 153 | inputs.append(exp["inputs"]) 154 | if decoder_only: 155 | outputs.append(exp["inputs"] + " " + exp["targets"] + tokenizer.eos_token) 156 | else: 157 | outputs.append(exp["targets"]) 158 | 159 | model_inputs = tokenizer(inputs, text_target=outputs, max_length=max_seq_length, padding="longest", 160 | truncation=True, return_tensors="pt") 161 | if decoder_only: 162 | input_lens = model_inputs["input_ids"].ne(tokenizer.pad_token_id).sum(dim=1) 163 | model_inputs = tokenizer(outputs, max_length=max_seq_length, padding="longest", 164 | truncation=True, return_tensors="pt") 165 | new_input_lens = model_inputs["input_ids"].ne(tokenizer.pad_token_id).sum(dim=1) 166 | input_lens = input_lens - input_lens.eq(new_input_lens).to(input_lens.dtype) * (input_lens // 2) 167 | input_lens = 
input_lens.to(torch.long) 168 | model_inputs["input_lens"] = input_lens 169 | 170 | return model_inputs 171 | 172 | 173 | def combine_tensor_on_length(a: torch.Tensor, b: torch.Tensor, pad_id: int): 174 | max_len = max(a.size(1), b.size(1)) 175 | new_tensor = torch.zeros(a.size(0) + b.size(0), max_len, dtype=a.dtype, device=a.device).fill_(pad_id) 176 | new_tensor[:a.size(0), :a.size(1)] = a 177 | new_tensor[a.size(0):, :b.size(1)] = b 178 | return new_tensor 179 | 180 | 181 | def get_lm_labels(input_lens, input_ids, pad_token_id, ignore_index=-100): 182 | labels = input_ids.clone() 183 | 184 | label_mask = labels.ne(pad_token_id) 185 | lens_mask = torch.arange(labels.size(1))[None, :] >= input_lens[:, None] 186 | label_mask = label_mask & lens_mask 187 | 188 | labels = labels.masked_fill(~label_mask, ignore_index).contiguous() 189 | 190 | return labels 191 | 192 | 193 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 194 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 195 | """ 196 | Make causal mask used for bi-directional self-attention. 197 | """ 198 | bsz, tgt_len = input_ids_shape 199 | mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) 200 | mask_cond = torch.arange(mask.size(-1)) 201 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 202 | mask = mask.to(dtype) 203 | 204 | if past_key_values_length > 0: 205 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 206 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 207 | 208 | 209 | # Copied from transformers.models.bart.modeling_bart._expand_mask 210 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 211 | """ 212 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
213 | """ 214 | bsz, src_len = mask.size() 215 | tgt_len = tgt_len if tgt_len is not None else src_len 216 | 217 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 218 | 219 | inverted_mask = 1.0 - expanded_mask 220 | 221 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 222 | 223 | 224 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 225 | def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length): 226 | # create causal mask 227 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 228 | combined_attention_mask = None 229 | if input_shape[-1] > 1: 230 | combined_attention_mask = _make_causal_mask( 231 | input_shape, 232 | torch.float16, 233 | past_key_values_length=past_key_values_length, 234 | ) 235 | 236 | if attention_mask is not None: 237 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 238 | expanded_attn_mask = _expand_mask(attention_mask, torch.float16, tgt_len=input_shape[-1]) 239 | combined_attention_mask = ( 240 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 241 | ) 242 | 243 | return combined_attention_mask 244 | 245 | 246 | def convert_to_standard_inputs(model_inputs: Dict, tokenizer: PreTrainedTokenizer, ignored_index: int = -100): 247 | input_ids = model_inputs["input_ids"] 248 | attention_mask = model_inputs["attention_mask"] 249 | # input_lens = getattr(model_inputs, "input_lens", None) 250 | input_lens = model_inputs["input_lens"] 251 | 252 | labels = get_lm_labels(input_lens, input_ids, tokenizer.pad_token_id, ignored_index) 253 | 254 | seq_length = input_ids.size(1) 255 | position_ids = torch.arange(seq_length, dtype=torch.long) 256 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 257 | 258 | attention_mask = _prepare_decoder_attention_mask(attention_mask, input_ids.shape, 0) 259 | 260 | return input_ids, attention_mask, position_ids, labels 261 | 262 | 263 | class FlanCollatorOverCollator: 264 | def __init__(self, collator, tokenizer: str, max_seq_length: int, decoder_only: bool = False, return_standard_inputs: bool = False): 265 | self.collator = collator 266 | self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=False) 267 | expand_special_tokenizer(self.tokenizer) 268 | self.max_seq_length = max_seq_length 269 | self.decoder_only = decoder_only 270 | self.convert_to_standard_inputs = return_standard_inputs 271 | 272 | def __call__(self, batch): 273 | flan_batch = [] 274 | for item in batch: 275 | flan_batch.append(item.pop("flan")) 276 | 277 | index = torch.tensor([b["index"] for b in batch], dtype=torch.long) 278 | 279 | if self.collator is not None: 280 | model_inputs = self.collator(batch) 281 | orig_batch_size = model_inputs["input_ids"].size(0) 282 | flan_inputs = vanilla_seq2seq_convertor(flan_batch, self.tokenizer, self.max_seq_length, self.decoder_only) 283 | for k, v in flan_inputs.items(): 284 | if k == "input_lens": 285 | if "flan_input_lens" in model_inputs: 286 | model_inputs["flan_input_lens"] = torch.cat([model_inputs["flan_input_lens"], v], dim=0) 287 | else: 288 | empty_input_lens = torch.zeros(orig_batch_size, dtype=torch.long, device=v.device) 289 | model_inputs[f"flan_input_lens"] = torch.cat([empty_input_lens, v], dim=0) 290 | continue 291 | 292 | if f"flan_{k}" in model_inputs: 293 | model_inputs[f"flan_{k}"] = combine_tensor_on_length(model_inputs[f"flan_{k}"], v, 
self.tokenizer.pad_token_id) 294 | else: 295 | model_inputs[f"flan_{k}"] = v 296 | else: 297 | model_inputs = vanilla_seq2seq_convertor(flan_batch, self.tokenizer, self.max_seq_length, self.decoder_only) 298 | 299 | if self.convert_to_standard_inputs: 300 | input_ids, attention_mask, position_ids, labels = convert_to_standard_inputs(model_inputs, self.tokenizer) 301 | 302 | labels = torch.cat([labels, index.unsqueeze(1)], dim=1) 303 | 304 | return ( 305 | (input_ids, attention_mask, position_ids, index), 306 | labels, 307 | ) 308 | 309 | return model_inputs 310 | -------------------------------------------------------------------------------- /data/test.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class TestDataset(Dataset): 5 | def __init__(self, file_path, tokenizer, pseudo_dataset_len: int = -1): 6 | super().__init__() 7 | self.data = ["My name is Jiao Fangkai."] 8 | self.pseudo_dataset_len = pseudo_dataset_len 9 | 10 | def __len__(self): 11 | if self.pseudo_dataset_len > 0: 12 | return self.pseudo_dataset_len 13 | return 100000000 14 | 15 | def __getitem__(self, index): 16 | return { 17 | "flan": { 18 | "inputs": self.data[0], 19 | "targets": self.data[0], 20 | }, 21 | "index": index, 22 | } 23 | -------------------------------------------------------------------------------- /general_util/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from transformers import PreTrainedTokenizer 4 | import os 5 | from data.data_utils import tokenizer_get_name 6 | 7 | DEFAULT_PAD_TOKEN = "[PAD]" 8 | DEFAULT_EOS_TOKEN = "" 9 | DEFAULT_BOS_TOKEN = "" 10 | DEFAULT_UNK_TOKEN = "" 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def expand_special_tokenizer(tokenizer: PreTrainedTokenizer): 16 | tokenizer_name = tokenizer_get_name(tokenizer) 17 | if "llama" in tokenizer_name: 18 | special_tokens_map = {} 19 | eos_token = os.environ.get("EOS_TOKEN", None) 20 | if eos_token or (not tokenizer.eos_token): 21 | special_tokens_map["eos_token"] = eos_token if eos_token else DEFAULT_EOS_TOKEN 22 | 23 | bos_token = os.environ.get("BOS_TOKEN", None) 24 | if bos_token or (not tokenizer.bos_token): 25 | special_tokens_map["bos_token"] = bos_token if bos_token else DEFAULT_BOS_TOKEN 26 | 27 | unk_token = os.environ.get("UNK_TOKEN", None) 28 | if not tokenizer.unk_token: 29 | special_tokens_map["unk_token"] = unk_token if unk_token else DEFAULT_UNK_TOKEN 30 | 31 | pad_token = os.environ.get("PAD_TOKEN", None) 32 | if not tokenizer.pad_token: 33 | special_tokens_map["pad_token"] = pad_token if pad_token else DEFAULT_PAD_TOKEN 34 | 35 | new_tokens = tokenizer.add_special_tokens( 36 | special_tokens_dict=special_tokens_map 37 | ) 38 | # new_tokens = tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN)) 39 | # tokenizer.pad_token = tokenizer.eos_token 40 | # tokenizer.pad_token_id = tokenizer.eos_token_id 41 | # assert new_tokens == 1 42 | elif "gptneox" in tokenizer_name: 43 | special_tokens_map = {} 44 | eos_token = os.environ.get("EOS_TOKEN", None) 45 | if eos_token: 46 | special_tokens_map["eos_token"] = eos_token if eos_token else DEFAULT_EOS_TOKEN 47 | 48 | new_tokens = tokenizer.add_special_tokens( 49 | special_tokens_dict=special_tokens_map 50 | ) 51 | 52 | if not tokenizer.pad_token: 53 | tokenizer.pad_token = tokenizer.eos_token 54 | tokenizer.pad_token_id = tokenizer.eos_token_id 55 | 56 | 
logger.info(tokenizer) 57 | 58 | 59 | def is_seq2seq_tokenizer(tokenizer: PreTrainedTokenizer): 60 | tokenizer_name = tokenizer_get_name(tokenizer) 61 | return any([x in tokenizer_name for x in ["t5", "bart", "pegasus", "mbart", "marian", "blenderbot"]]) 62 | -------------------------------------------------------------------------------- /models/llama_ds_mp_wrap.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import deepspeed 4 | import torch 5 | import logging 6 | from deepspeed.pipe import TiedLayerSpec, LayerSpec 7 | from torch.nn import CrossEntropyLoss 8 | from transformers.models.llama.modeling_llama import ( 9 | LlamaForCausalLM, 10 | LlamaConfig, 11 | LlamaDecoderLayer, 12 | LlamaRMSNorm, 13 | ) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class EmbeddingPipeLayer(torch.nn.Module): 19 | def __init__(self, model: LlamaForCausalLM): 20 | super().__init__() 21 | self.embed_tokens = model.model.embed_tokens 22 | self.weight = self.embed_tokens.weight 23 | 24 | def forward(self, ipt): 25 | input_ids, attention_mask, position_ids = ipt 26 | inputs_embeds = self.embed_tokens(input_ids) 27 | return inputs_embeds, attention_mask, position_ids 28 | 29 | 30 | class LlamaPipeLayer(torch.nn.Module): 31 | def __init__(self, model: LlamaForCausalLM, layer_idx): 32 | super().__init__() 33 | self.layer = model.model.layers[layer_idx] 34 | self.gradient_checkpointing = model.model.gradient_checkpointing 35 | 36 | def forward(self, ipt): 37 | hidden_states, attention_mask, position_ids = ipt 38 | 39 | if self.gradient_checkpointing and self.training: 40 | output_attentions = False 41 | 42 | def create_custom_forward(module): 43 | def custom_forward(*inputs): 44 | # None for past_key_value 45 | return module(*inputs, output_attentions, None) 46 | 47 | return custom_forward 48 | 49 | # layer_outputs = torch.utils.checkpoint.checkpoint( 50 | # create_custom_forward(self.layer), 51 | # hidden_states, 52 | # attention_mask, 53 | # position_ids, 54 | # None, 55 | # ) 56 | # deepspeed checkpoint auto use outputs[0] if len(outputs) == 1 57 | outputs = deepspeed.checkpointing.checkpoint( 58 | create_custom_forward(self.layer), 59 | hidden_states, 60 | attention_mask, 61 | position_ids, 62 | None, 63 | ) 64 | layer_outputs = [outputs] 65 | else: 66 | layer_outputs = self.layer( 67 | hidden_states, 68 | attention_mask=attention_mask, 69 | position_ids=position_ids, 70 | # past_key_value=past_key_value, 71 | # output_attentions=output_attentions, 72 | # use_cache=use_cache, 73 | ) 74 | 75 | hidden_states = layer_outputs[0] 76 | return hidden_states, attention_mask, position_ids 77 | 78 | 79 | class FLNPipeLayer(torch.nn.Module): 80 | def __init__(self, model: LlamaForCausalLM): 81 | super().__init__() 82 | self.norm = model.model.norm 83 | 84 | def forward(self, ipt): 85 | hidden_states, attention_mask, position_ids = ipt 86 | hidden_states = self.norm(hidden_states) 87 | 88 | return hidden_states 89 | 90 | 91 | class LMPipeLayer(torch.nn.Module): 92 | def __init__(self, model: LlamaForCausalLM): 93 | super().__init__() 94 | self.lm_head = model.lm_head 95 | self.weight = self.lm_head.weight 96 | self.config = model.config 97 | 98 | def forward(self, ipt): 99 | hidden_states = ipt 100 | logits = torch.nn.functional.linear(hidden_states, self.lm_head.weight) 101 | 102 | return logits 103 | 104 | 105 | def loss_fn(outputs, labels): 106 | logits = outputs 107 | # last_rank_index = labels[:, -1] 108 | # labels = labels[:, :-1] 109 | 
shift_logits = logits[..., :-1, :].contiguous() 110 | shift_labels = labels[..., 1:].contiguous() 111 | 112 | loss_fct = CrossEntropyLoss() 113 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 114 | # print("loss", loss, loss.requires_grad) 115 | # print(os.environ["LOCAL_RANK"], index, last_rank_index) 116 | return loss 117 | 118 | 119 | def get_model(model): 120 | layers = [TiedLayerSpec("weight", EmbeddingPipeLayer, model=model, tied_weight_attr="weight"), 121 | *[LayerSpec(LlamaPipeLayer, model=model, layer_idx=idx) for idx in range(model.config.num_hidden_layers)], 122 | LayerSpec(FLNPipeLayer, model=model), 123 | TiedLayerSpec("weight", LMPipeLayer, model=model, tied_weight_attr="weight"), 124 | ] 125 | return layers 126 | 127 | 128 | class EmbeddingPipe(torch.nn.Embedding): 129 | def forward(self, args): 130 | input_ids, attention_mask, position_ids = args 131 | inputs_embeds = super().forward(input_ids) 132 | return inputs_embeds, attention_mask, position_ids 133 | 134 | 135 | class ParallelTransformerLayerPipe(LlamaDecoderLayer): 136 | def __init__(self, config: LlamaConfig, activation_checkpointing: bool = False): 137 | super().__init__(config) 138 | self.activation_checkpointing = activation_checkpointing 139 | # for name, param in self.named_parameters(): 140 | # if "norm" in name: 141 | # continue 142 | # param.data = param.data.to(dtype) 143 | 144 | def forward(self, args): 145 | if self.activation_checkpointing: 146 | return self._ckpt_forward(args) 147 | 148 | hidden_states, attention_mask, position_ids = args 149 | outputs = LlamaDecoderLayer.forward(self, 150 | hidden_states, 151 | attention_mask, 152 | position_ids, 153 | ) 154 | return outputs[0], attention_mask, position_ids 155 | 156 | def _ckpt_forward(self, args): 157 | hidden_states, attention_mask, position_ids = args 158 | 159 | def create_custom_forward(module): 160 | def custom_forward(*inputs): 161 | return LlamaDecoderLayer.forward(module, *inputs) 162 | 163 | return custom_forward 164 | 165 | # deepspeed checkpoint auto use outputs[0] if len(outputs) == 1 166 | outputs = deepspeed.checkpointing.checkpoint( 167 | create_custom_forward(self), 168 | hidden_states, 169 | attention_mask, 170 | position_ids, 171 | None, 172 | ) 173 | # layer_outputs = torch.utils.checkpoint.checkpoint( 174 | # create_custom_forward(self), 175 | # hidden_states, 176 | # attention_mask, 177 | # position_ids, 178 | # None, 179 | # ) 180 | 181 | return outputs, attention_mask, position_ids 182 | 183 | 184 | class LayerNormPipe(LlamaRMSNorm): 185 | def forward(self, args): 186 | hidden_states, attention_mask, position_ids = args 187 | last_hidden_states = super().forward(hidden_states) 188 | return last_hidden_states 189 | 190 | 191 | class LMLayerPipe(torch.nn.Linear): 192 | def forward(self, args): 193 | hidden_states = args 194 | logits = super().forward(hidden_states) 195 | return logits 196 | 197 | 198 | class LossLayer(torch.nn.Module): 199 | def forward(self, args): 200 | logits, labels = args 201 | shift_logits = logits[..., :-1, :].contiguous() 202 | shift_labels = labels[..., 1:].contiguous() 203 | 204 | loss_fct = CrossEntropyLoss() 205 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 206 | return loss 207 | 208 | 209 | def get_layers_from_config(model_config, activation_checkpointing: bool = False): 210 | """ 211 | `tie_word_embeddings` in LLaMA is set to `false`. 
212 | """ 213 | layers = [ 214 | LayerSpec(EmbeddingPipe, model_config.vocab_size, model_config.hidden_size), 215 | # TiedLayerSpec("weight", EmbeddingPipe, model_config.vocab_size, model_config.hidden_size, tied_weight_attr="weight"), 216 | *[LayerSpec(ParallelTransformerLayerPipe, model_config, activation_checkpointing) 217 | for _ in range(model_config.num_hidden_layers)], 218 | LayerSpec(LayerNormPipe, model_config.hidden_size, model_config.rms_norm_eps), 219 | LayerSpec(LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False), 220 | # TiedLayerSpec("weight", LMLayerPipe, model_config.hidden_size, model_config.vocab_size, bias=False, 221 | # tied_weight_attr="weight"), 222 | # LayerSpec(LossLayer), 223 | ] 224 | return layers 225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb 2 | nltk 3 | tensorboard 4 | sentencepiece 5 | transformers 6 | peft 7 | torch==2.0.0 8 | hydra-core 9 | fairscale 10 | deepspeed 11 | datasets -------------------------------------------------------------------------------- /trainer_base_ds_mp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Copyright 2023 Nanyang Technological University Fangkai Jiao 4 | # 5 | # Part of this code is based on the source code of Transformers 6 | # (arXiv:1910.03771) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import datetime 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | import sys 26 | from typing import Dict, Union 27 | 28 | import deepspeed 29 | import hydra 30 | import numpy as np 31 | import torch 32 | import wandb 33 | from deepspeed.pipe import PipelineModule 34 | from deepspeed.runtime.engine import DeepSpeedEngine 35 | from omegaconf import DictConfig, OmegaConf 36 | from torch import distributed as dist 37 | from torch.utils.data import (DataLoader, RandomSampler, DistributedSampler, ConcatDataset) 38 | from tqdm import tqdm, trange 39 | from transformers import (AutoTokenizer, PreTrainedTokenizer) 40 | 41 | import models.llama_ds_mp_wrap 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | torch.backends.cuda.matmul.allow_tf32 = True 46 | 47 | 48 | # Hack here to process the loading checkpoint bug. 49 | def load_checkpoint(self, 50 | load_dir, 51 | tag=None, 52 | load_module_strict=True, 53 | load_optimizer_states=True, 54 | load_lr_scheduler_states=True, 55 | load_module_only=False, 56 | custom_load_fn=None): 57 | """ 58 | Load training checkpoint 59 | 60 | Arguments: 61 | load_dir: Required. Directory to load the checkpoint from 62 | tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file 63 | load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. 
64 | load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance 65 | load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. 66 | load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. 67 | custom_load_fn: Optional. Custom model load function. 68 | 69 | Returns: 70 | A tuple of ``load_path`` and ``client_state``. 71 | *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. 72 | *``client_state``: State dictionary used for loading required training states in the client code. 73 | 74 | Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right 75 | after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and 76 | ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine 77 | before ``load_checkpoint()``. 78 | 79 | """ 80 | 81 | if tag is None: 82 | latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest" 83 | latest_path = os.path.join(load_dir, latest_tag) 84 | if os.path.isfile(latest_path): 85 | with open(latest_path, "r") as fd: 86 | tag = fd.read().strip() 87 | else: 88 | if self.load_universal_checkpoint(): 89 | raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist') 90 | else: 91 | logger.warning( 92 | f"Unable to find latest file at {latest_path}, if trying to load latest " 93 | "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." 94 | ) 95 | return None, None 96 | 97 | if self.zero_optimization_partition_weights(): 98 | # Prepare for checkpoint load by ensuring all parameters are partitioned 99 | self.optimizer.checkpoint_event_prologue() 100 | 101 | load_path, client_states = self._load_checkpoint(load_dir, 102 | tag, 103 | load_module_strict=load_module_strict, 104 | load_optimizer_states=load_optimizer_states, 105 | load_lr_scheduler_states=load_lr_scheduler_states, 106 | load_module_only=load_module_only, 107 | custom_load_fn=custom_load_fn) 108 | 109 | load_zero_checkpoint = load_optimizer_states and (self.zero_optimization() or self.bfloat16_enabled()) 110 | if load_zero_checkpoint and load_path is not None: 111 | success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) 112 | if not success: 113 | self.optimizer._restore_from_bit16_weights() 114 | 115 | if self.zero_optimization_partition_weights(): 116 | self.optimizer.checkpoint_event_epilogue() 117 | 118 | return load_path, client_states 119 | 120 | 121 | DeepSpeedEngine.load_checkpoint = load_checkpoint 122 | 123 | 124 | def set_seed(args): 125 | random.seed(args.seed) 126 | np.random.seed(args.seed) 127 | torch.manual_seed(args.seed) 128 | if args.n_gpu > 0: 129 | torch.cuda.manual_seed_all(args.seed) 130 | 131 | 132 | def initialize_dataset(cfg: DictConfig, file_path: str, tokenizer: PreTrainedTokenizer): 133 | if "_target_" in cfg: 134 | return hydra.utils.call(cfg, file_path=file_path, tokenizer=tokenizer) 135 | else: 136 | datasets = [initialize_dataset(cfg[key], file_path, tokenizer) for key in cfg.keys()] 137 | assert len(datasets) 138 | datasets = ConcatDataset(datasets) 139 | return datasets 140 | 141 | 142 | def load_and_cache_examples(cfg, tokenizer: PreTrainedTokenizer, _split="train", _file: str = None): 143 | if_barrier = False 144 | 145 | if _file is not None: 146 | 
input_file = _file 147 | if_barrier = True 148 | else: 149 | if _split == "train": 150 | input_file = cfg.train_file 151 | if_barrier = True 152 | elif _split == "dev": 153 | input_file = cfg.dev_file 154 | if cfg.ddp_eval and cfg.local_rank != -1: 155 | if_barrier = True 156 | elif _split == "test": 157 | input_file = cfg.test_file 158 | if cfg.ddp_eval and cfg.local_rank != -1: 159 | if_barrier = True 160 | else: 161 | raise RuntimeError(_split) 162 | 163 | if getattr(cfg, "dist_load_data_barrier", True) and if_barrier and cfg.local_rank not in [-1, 0]: 164 | dist.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 165 | 166 | sub_config = f"read_tensor_{_split}" 167 | if sub_config in cfg: 168 | dataset = initialize_dataset(cfg[sub_config], file_path=input_file, tokenizer=tokenizer) 169 | else: 170 | dataset = initialize_dataset(cfg.read_tensor, file_path=input_file, tokenizer=tokenizer) 171 | 172 | if getattr(cfg, "dist_load_data_barrier", True) and if_barrier and cfg.local_rank == 0: 173 | dist.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 174 | 175 | if dist.is_initialized(): 176 | dist.barrier() 177 | 178 | return dataset 179 | 180 | 181 | def load_empty_dataset_and_collator(cfg: DictConfig): 182 | from data.test import TestDataset 183 | from data.flan import FlanCollatorOverCollator 184 | 185 | dataset = TestDataset(None, None, getattr(cfg, "total_dataset_len", -1)) 186 | collator = FlanCollatorOverCollator(collator=None, 187 | tokenizer=cfg.model_name_or_path, 188 | max_seq_length=128, 189 | decoder_only=True, 190 | return_standard_inputs=True, 191 | ) 192 | 193 | # Keep consistent with `load_and_cache_examples`. 
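    # Ranks on intermediate pipeline stages never fetch real batches, but they still have to
    # join the same number of barriers as the first/last stages (which build the real dataset
    # in `load_and_cache_examples`); otherwise the ranks that do load data would wait forever.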
194 | if getattr(cfg, "dist_load_data_barrier", True): 195 | dist.barrier() 196 | 197 | if dist.is_initialized(): 198 | dist.barrier() 199 | 200 | return dataset, collator 201 | 202 | 203 | def save_model(model: Union[deepspeed.DeepSpeedEngine, deepspeed.PipelineEngine], 204 | cfg: DictConfig, output_dir: str, tokenizer: PreTrainedTokenizer = None, state_dict: Dict = None): 205 | model.save_checkpoint(output_dir) 206 | 207 | if cfg.local_rank not in [-1, 0]: 208 | dist.barrier() 209 | 210 | if cfg.local_rank in [-1, 0]: 211 | 212 | if tokenizer is not None: 213 | tokenizer.save_pretrained(output_dir) 214 | 215 | OmegaConf.save(cfg, os.path.join(output_dir, "training_config.yaml")) 216 | logger.info("Saving model checkpoint to %s", output_dir) 217 | 218 | end_dir = output_dir.split("/")[-1] 219 | 220 | os.system(f"./s5cmd sync {output_dir}/ {cfg.aws_output_bucket}/{end_dir}/") 221 | 222 | if cfg.local_rank == 0: 223 | dist.barrier() 224 | 225 | 226 | def train(cfg, model, tokenizer, continue_from_global_step=0): 227 | """ Train the model """ 228 | if cfg.local_rank in [-1, 0]: 229 | tb_helper = hydra.utils.instantiate(cfg.summary_helper) if "summary_helper" in cfg and cfg.summary_helper else None 230 | else: 231 | tb_helper = None 232 | 233 | cfg.train_batch_size = cfg.per_gpu_train_batch_size 234 | 235 | if "_target_" in cfg.train_file: 236 | files = hydra.utils.instantiate(cfg.train_file) 237 | elif cfg.train_file.startswith("hf:"): 238 | files = [cfg.train_file[3:]] 239 | elif os.path.exists(cfg.train_file): 240 | files = [cfg.train_file] 241 | else: 242 | files = list(glob.glob(cfg.train_file)) 243 | logger.info(files) 244 | 245 | dp_degree = dist.get_world_size() // cfg.num_stages 246 | 247 | if getattr(cfg, "total_dataset_len", -1) > 0: 248 | total_dataset_len = cfg.total_dataset_len 249 | else: 250 | total_dataset_len = 0 251 | for _file in tqdm(files, total=len(files)): 252 | sub_train_dataset = load_and_cache_examples(cfg, tokenizer, _split="train", _file=_file) 253 | total_dataset_len += len(sub_train_dataset) 254 | del sub_train_dataset 255 | 256 | if getattr(cfg, "do_preprocess", False): 257 | return 258 | 259 | if "extended_vocab" in cfg and cfg.extended_vocab: 260 | logger.info(f"Extended extra vocab size: {cfg.extended_vocab}") 261 | model.resize_token_embeddings(model.config.vocab_size + cfg.extended_vocab) 262 | 263 | _actual_train_batch_size = cfg.train_batch_size * cfg.gradient_accumulation_steps * dp_degree 264 | if cfg.max_steps > 0: 265 | t_total = cfg.max_steps 266 | cfg.num_train_epochs = cfg.max_steps // (total_dataset_len // _actual_train_batch_size) + 1 267 | else: 268 | t_total = total_dataset_len // _actual_train_batch_size * cfg.num_train_epochs 269 | 270 | num_warmup_steps = int(t_total * cfg.warmup_proportion) if cfg.warmup_proportion else cfg.warmup_steps 271 | 272 | ds_config = cfg.ds_cfg 273 | if "total_num_steps" in ds_config.scheduler.params: 274 | ds_config.scheduler.params.total_num_steps = t_total 275 | ds_config.scheduler.params.warmup_num_steps = num_warmup_steps 276 | ds_config = OmegaConf.to_container(ds_config, resolve=True) 277 | 278 | if torch.__version__ >= "2" and (os.environ.get("TORCH_COMPILE", False) or getattr(cfg, "compile", False)):  # use os.environ.get(); getattr on os.environ always returns the default 279 | model = torch.compile(model, mode="max-autotune") 280 | model, optimizer, _, scheduler = deepspeed.initialize(model=model, 281 | model_parameters=[p for p in model.parameters() if p.requires_grad], 282 | config=ds_config) 283 | 284 | model.load_checkpoint(cfg.model_name_or_path,
load_module_only=True, load_optimizer_states=False, load_lr_scheduler_states=False) 285 | logger.info(optimizer.optimizer) 286 | 287 | # Train! 288 | logger.info("***** Running training *****") 289 | logger.info(" Num examples = %d", total_dataset_len) 290 | logger.info(" Num Epochs = %d", cfg.num_train_epochs) 291 | logger.info(" Instantaneous batch size per GPU = %d", cfg.per_gpu_train_batch_size) 292 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", _actual_train_batch_size) 293 | logger.info(" Gradient Accumulation steps = %d", cfg.gradient_accumulation_steps) 294 | logger.info(" Total optimization steps = %d", t_total) 295 | logger.info(" Warmup steps = %d", num_warmup_steps) 296 | 297 | if continue_from_global_step > 0: 298 | logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step) 299 | model.load_checkpoint(cfg.resume) 300 | 301 | global_step = 0 302 | tr_loss, logging_loss = 0.0, 0.0 303 | # model.zero_grad() 304 | train_iterator = trange(int(cfg.num_train_epochs), desc="Epoch", disable=cfg.local_rank not in [-1, 0]) 305 | set_seed(cfg) # Added here for reproducibility (even between python 2 and 3) 306 | 307 | for epoch in train_iterator: 308 | for _file in files: 309 | if model.is_first_stage() or model.is_last_stage(): 310 | sub_train_dataset = load_and_cache_examples(cfg, tokenizer, _split="train", _file=_file) 311 | 312 | if dp_degree > 1: 313 | dp_id = model.grid.get_data_parallel_id() 314 | sub_train_sampler = DistributedSampler(sub_train_dataset, num_replicas=dp_degree, rank=dp_id) 315 | else: 316 | sub_train_sampler = RandomSampler(sub_train_dataset) 317 | sub_train_collator = hydra.utils.instantiate(cfg.collator) if "collator" in cfg and cfg.collator else None 318 | 319 | sub_train_dataloader = DataLoader(dataset=sub_train_dataset, 320 | sampler=sub_train_sampler, 321 | batch_size=cfg.train_batch_size, 322 | collate_fn=sub_train_collator, 323 | num_workers=cfg.num_workers, 324 | pin_memory=True, 325 | prefetch_factor=cfg.prefetch_factor, 326 | drop_last=True, 327 | ) 328 | else: 329 | sub_train_dataset, sub_train_collator = load_empty_dataset_and_collator(cfg) 330 | sub_train_sampler = None 331 | 332 | sub_train_dataloader = DataLoader(dataset=sub_train_dataset, 333 | batch_size=cfg.train_batch_size * dp_degree, 334 | collate_fn=sub_train_collator, 335 | drop_last=True, 336 | shuffle=False) 337 | 338 | epoch_update_steps = len(sub_train_dataloader) // cfg.gradient_accumulation_steps 339 | sub_train_dataloader = iter(deepspeed.utils.RepeatingLoader(sub_train_dataloader)) 340 | 341 | if sub_train_sampler is not None and isinstance(sub_train_sampler, DistributedSampler): 342 | sub_train_sampler.set_epoch(epoch) 343 | 344 | for _ in tqdm(range(epoch_update_steps), desc="Iteration", disable=cfg.local_rank not in [-1, 0], dynamic_ncols=True): 345 | # If training is continued from a checkpoint, fast forward 346 | # to the state of that checkpoint. 
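                    # `train_batch()` pulls `gradient_accumulation_steps` micro-batches from the data iterator
                    # per call, so the fast-forward below also advances the iterator by that many batches for
                    # every skipped global step.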
347 | if global_step < continue_from_global_step: 348 | for _ in range(cfg.gradient_accumulation_steps): 349 | next(sub_train_dataloader) 350 | global_step += 1 351 | continue 352 | 353 | model.train() 354 | loss = model.train_batch(data_iter=sub_train_dataloader) 355 | global_step += 1 356 | 357 | tr_loss += loss.item() 358 | 359 | # Log metrics 360 | log_metrics = {} 361 | if cfg.local_rank in [-1, 0] and cfg.logging_steps > 0 and global_step % cfg.logging_steps == 0: 362 | log_metrics['lr'] = scheduler.get_lr()[0] 363 | log_metrics['loss'] = (tr_loss - logging_loss) / cfg.logging_steps 364 | logging_loss = tr_loss 365 | 366 | # Save model checkpoint 367 | if cfg.save_steps > 0 and global_step % cfg.save_steps == 0: 368 | output_dir = os.path.join(cfg.output_dir, 'checkpoint-{}'.format(global_step)) 369 | if cfg.local_rank in [-1, 0] and not os.path.exists(output_dir): 370 | os.makedirs(output_dir, exist_ok=True) 371 | save_model(model, cfg, output_dir, tokenizer) 372 | 373 | if len(log_metrics) > 0 and cfg.local_rank in [-1, 0]: 374 | wandb.log(log_metrics) 375 | 376 | del log_metrics 377 | 378 | if 0 < cfg.max_steps < global_step: 379 | train_iterator.close() 380 | break 381 | 382 | if 0 < cfg.max_steps < global_step: 383 | break 384 | 385 | return global_step, tr_loss / global_step 386 | 387 | 388 | @hydra.main(config_path="conf", config_name="config", version_base="1.2") 389 | def main(cfg: DictConfig): 390 | if "LOCAL_RANK" in os.environ and os.environ["LOCAL_RANK"] not in [-1, "-1"]: 391 | cfg.local_rank = int(os.environ["LOCAL_RANK"]) 392 | 393 | if cfg.local_rank == -1 or cfg.no_cuda: 394 | device = str(torch.device("cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu")) 395 | cfg.n_gpu = torch.cuda.device_count() 396 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 397 | torch.cuda.set_device(cfg.local_rank) 398 | device = str(torch.device("cuda", cfg.local_rank)) 399 | deepspeed.init_distributed(dist_backend="nccl", timeout=datetime.timedelta(seconds=7200)) 400 | cfg.n_gpu = 1 401 | cfg.world_size = dist.get_world_size() 402 | cfg.device = device 403 | 404 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 405 | cfg.local_rank, cfg.device, cfg.n_gpu, bool(cfg.local_rank != -1), cfg.fp16) 406 | logger.warning(f"CPU cores: {os.cpu_count()}") 407 | 408 | # Set seed 409 | set_seed(cfg) 410 | 411 | use_barrier = not os.path.exists(cfg.model_name_or_path) 412 | # Load pre-trained model and tokenizer 413 | if use_barrier and cfg.local_rank not in [-1, 0]: 414 | dist.barrier() # Make sure only the first process in distributed training will download model & vocab 415 | 416 | tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path) 417 | 418 | from general_util.tokenization_utils import expand_special_tokenizer 419 | 420 | expand_special_tokenizer(tokenizer) 421 | 422 | model_or_config = hydra.utils.call(cfg.model, cfg.model_name_or_path) 423 | 424 | layers = hydra.utils.call(cfg.get_layers, model_or_config) 425 | model_pipe = PipelineModule(layers=layers, 426 | num_stages=cfg.num_stages, 427 | loss_fn=models.llama_ds_mp_wrap.loss_fn, 428 | activation_checkpoint_interval=getattr(cfg, "activation_checkpoint_interval", 0) 429 | ) 430 | logger.info(f"{model_pipe.topology}") 431 | cfg.topology = str(model_pipe.topology) 432 | 433 | if use_barrier and cfg.local_rank == 0: 434 | dist.barrier() # Make sure only the first process in distributed training will download model 
& vocab 435 | 436 | if cfg.local_rank in [-1, 0] and cfg.do_train: 437 | if not os.path.exists(cfg.output_dir): 438 | os.makedirs(cfg.output_dir) 439 | OmegaConf.save(cfg, os.path.join(cfg.output_dir, "training_config.yaml")) 440 | 441 | wandb.init( 442 | project="LLaMA-BiFLAN", 443 | name=cfg.exp_name, 444 | notes=cfg.exp_notes, 445 | config=OmegaConf.to_container(cfg, resolve=True), 446 | ) 447 | wandb.define_metric(cfg.prediction_cfg.metric, summary=("max" if cfg.prediction_cfg.measure > 0 else "min")) 448 | 449 | # Training 450 | if cfg.do_train: 451 | continue_from_global_step = 0 # If set to 0, start training from the beginning 452 | if os.path.exists(cfg.output_dir) and getattr(cfg, "resume", None): 453 | checkpoint = cfg.resume 454 | logger.info("Resuming training from the latest checkpoint: %s", checkpoint) 455 | continue_from_global_step = int(checkpoint.split('-')[-1]) 456 | 457 | global_step, tr_loss = train(cfg, model_pipe, tokenizer, continue_from_global_step) 458 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 459 | 460 | 461 | if __name__ == "__main__": 462 | os.environ["HYDRA_FULL_ERROR"] = "1" 463 | 464 | hydra_formatted_args = [] 465 | # convert the cli params added by torch.distributed.launch into Hydra format 466 | for arg in sys.argv: 467 | if arg.startswith("--"): 468 | hydra_formatted_args.append(arg[len("--"):]) 469 | else: 470 | hydra_formatted_args.append(arg) 471 | sys.argv = hydra_formatted_args 472 | print(sys.argv) 473 | main() 474 | --------------------------------------------------------------------------------