├── InternLM
├── __init__.py
├── configs
│   ├── kd_1b_to_300m.py
│   └── pretrain_300m.py
├── internlm
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   └── inference.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── communication
│   │   │   ├── __init__.py
│   │   │   ├── p2p.py
│   │   │   └── utils.py
│   │   ├── context
│   │   │   ├── __init__.py
│   │   │   ├── parallel_context.py
│   │   │   ├── process_group_initializer.py
│   │   │   └── random.py
│   │   ├── engine.py
│   │   ├── gradient_handler.py
│   │   ├── naive_amp.py
│   │   ├── scheduler
│   │   │   ├── __init__.py
│   │   │   ├── base_scheduler.py
│   │   │   ├── no_pipeline_scheduler.py
│   │   │   └── pipeline_scheduler.py
│   │   └── trainer.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── batch_sampler.py
│   │   ├── collaters.py
│   │   ├── dataset.py
│   │   ├── dummy_dataset.py
│   │   ├── packed_dataset.py
│   │   ├── single_dataset.py
│   │   └── utils.py
│   ├── initialize
│   │   ├── __init__.py
│   │   ├── initialize_tensor.py
│   │   ├── initialize_trainer.py
│   │   ├── launch.py
│   │   └── legacy
│   │   │   ├── __init__.py
│   │   │   └── launch.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── embedding.py
│   │   ├── linear.py
│   │   ├── loss.py
│   │   ├── metrics.py
│   │   ├── modeling_internlm.py
│   │   ├── modeling_vit.py
│   │   ├── multi_head_attention.py
│   │   ├── muse
│   │   │   ├── __init__.py
│   │   │   ├── modeling_taming_vqgan.py
│   │   │   └── modeling_utils.py
│   │   ├── norm.py
│   │   └── utils.py
│   ├── monitor
│   │   ├── __init__.py
│   │   ├── alert.py
│   │   ├── monitor.py
│   │   └── utils.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── beta2_scheduler.py
│   │   ├── lr_scheduler.py
│   │   ├── optimizer
│   │   │   ├── __init__.py
│   │   │   ├── hybrid_zero_optim.py
│   │   │   ├── store.py
│   │   │   └── utils.py
│   │   └── pipeline_utils.py
│   ├── train
│   │   ├── __init__.py
│   │   └── training_internlm.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── checkpoint.py
│   │   ├── common.py
│   │   ├── evaluation.py
│   │   ├── gputest.py
│   │   ├── logger.py
│   │   ├── megatron_timers.py
│   │   ├── model_checkpoint.py
│   │   ├── parallel.py
│   │   ├── registry.py
│   │   ├── simple_memory_profiler.py
│   │   ├── storage_manager.py
│   │   ├── timeout.py
│   │   └── writer.py
├── requirements
│   ├── runtime.txt
│   └── torch.txt
├── tools
│   ├── convert2hf.py
│   ├── convert2hf_vit.py
│   ├── data
│   │   ├── derain_prompt
│   │   │   ├── 000000_img.png
│   │   │   ├── 000000_label.png
│   │   │   ├── 000001_img.png
│   │   │   ├── 000001_label.png
│   │   │   ├── 000002_img.png
│   │   │   └── 000002_label.png
│   │   ├── examples
│   │   │   ├── derain_1.png
│   │   │   ├── derain_2.png
│   │   │   ├── pose_1.png
│   │   │   ├── pose_2.png
│   │   │   ├── seg_1.png
│   │   │   └── seg_2.png
│   │   ├── pose_prompt
│   │   │   ├── 000000_img.png
│   │   │   ├── 000000_label.png
│   │   │   ├── 000001_img.png
│   │   │   ├── 000001_label.png
│   │   │   ├── 000002_img.png
│   │   │   └── 000002_label.png
│   │   └── seg_prompt
│   │   │   ├── 000000_img.png
│   │   │   ├── 000000_label.png
│   │   │   ├── 000001_img.png
│   │   │   ├── 000001_label.png
│   │   │   ├── 000002_img.png
│   │   │   └── 000002_label.png
│   ├── demo.ipynb
│   ├── model_hf
│   │   ├── __init__.py
│   │   ├── modeling_internlm.py
│   │   ├── modeling_vit.py
│   │   └── muse
│   │   │   ├── __init__.py
│   │   │   ├── logging.py
│   │   │   ├── modeling_taming_vqgan.py
│   │   │   └── modeling_utils.py
│   └── utils.py
└── train.py
├── README.md
├── data_generation
├── README.md
├── generate
│   ├── generate_GoPro.py
│   ├── generate_Rain13K.py
│   ├── generate_SA-1B.py
│   ├── generate_coco-keypoint.py
│   ├── generate_hdvila_100m.py
│   ├── generate_laion.py
│   ├── img_to_token.py
│   └── token_concat.py
├── requirements.txt
└── vqgan
│   ├── __init__.py
│   ├── laion_convert.py
│   ├── load.py
│   ├── muse
│   ├── __init__.py
│   ├── logging.py
│   ├── modeling_taming_vqgan.py
│   └── modeling_utils.py
│   └── utils.py
└── figs
└── DeLVM.PNG
/InternLM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/__init__.py -------------------------------------------------------------------------------- /InternLM/configs/kd_1b_to_300m.py: -------------------------------------------------------------------------------- 1 | kd_config = dict(gt_weight=1., kd_weight=1., temperature=1) 2 | teacher_type = "INTERNLM" 3 | 4 | teacher_ckpt_folder = '/path/to/teacher' 5 | 6 | VQGAN_FOLDER = '/path/to/vqgan' 7 | T_SEQ_LEN = 2048 8 | T_HIDDEN_SIZE = 2048 9 | T_NUM_ATTENTION_HEAD = 16 10 | T_MLP_RATIO = 8 / 3 11 | T_NUM_LAYER = 22 12 | T_VOCAB_SIZE = 8192 13 | 14 | teacher = dict( 15 | checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1] 16 | num_attention_heads=T_NUM_ATTENTION_HEAD, 17 | embed_split_hidden=True, 18 | vocab_size=T_VOCAB_SIZE, 19 | embed_grad_scale=1, 20 | parallel_output=True, 21 | hidden_size=T_HIDDEN_SIZE, 22 | num_layers=T_NUM_LAYER, 23 | mlp_ratio=T_MLP_RATIO, 24 | apply_post_layer_norm=False, 25 | dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" 26 | norm_type="rmsnorm", 27 | layer_norm_epsilon=1e-5, 28 | use_flash_attn=True, 29 | num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 30 | lvm_config=dict( 31 | enable=True, 32 | embedding_cfg=dict( 33 | vq_model_path=VQGAN_FOLDER, 34 | embedding_dim=T_HIDDEN_SIZE, 35 | freeze_vq_model=True, 36 | ), 37 | ) 38 | ) 39 | 40 | ######################################################## 41 | JOB_NAME = "lvm_llama_kd" 42 | DO_ALERT = False 43 | model_type = "INTERNLM" 44 | 45 | SEQ_LEN = 2048 46 | HIDDEN_SIZE = 1024 47 | NUM_ATTENTION_HEAD = 8 48 | MLP_RATIO = 8 / 3 49 | NUM_LAYER = 22 50 | VOCAB_SIZE = 8192 51 | 52 | MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" 53 | SAVE_CKPT_FOLDER = "local:/path_to_save/" 54 | LOAD_CKPT_FOLDER = "local:/path_to_load/" 55 | 56 | CHECKPOINT_EVERY = 10000 57 | ckpt = dict( 58 | enable_save_ckpt=True, # set True to enable ckpt save. 59 | save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. 60 | # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"), 61 | # load_ckpt_folder="local:llm_ckpts/", 62 | # 'load_ckpt_info' setting guide: 63 | # 1. the 'path' indicates the ckpt path, 64 | # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" 65 | # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported. 66 | # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), 67 | checkpoint_every=CHECKPOINT_EVERY, 68 | async_upload=True, # async ckpt upload. (only works for boto3 ckpt) 69 | async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload. 70 | # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
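# Illustrative example of resuming training via the 'load_ckpt_info' guide above; the step
# subfolder in this path is a hypothetical placeholder, not a default shipped with this config:
# load_ckpt_info=dict(path="local:/path_to_save/40000/", content=("all",), ckpt_type="normal"),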
71 | oss_snapshot_freq=0, 72 | ) 73 | 74 | TRAIN_FOLDER = "/path/to/dataset" 75 | VALID_FOLDER = "/path/to/dataset" 76 | data = dict( 77 | seq_len=SEQ_LEN, 78 | # micro_num means the number of micro_batch contained in one gradient update 79 | micro_num=1, 80 | # packed_length = micro_bsz * SEQ_LEN 81 | micro_bsz=16, 82 | # defaults to the value of micro_num 83 | valid_micro_num=1, 84 | # defaults to 0, means disable evaluate 85 | valid_every=0, 86 | pack_sample_into_one=False, 87 | train_one_epoch=False, 88 | total_steps=40000, 89 | skip_batches="", 90 | rampup_batch_size="", 91 | # Datasets with less than 50 rows will be discarded 92 | min_length=50, 93 | train_folder=TRAIN_FOLDER, 94 | valid_folder=None, 95 | empty_cache_and_diag_interval=10000, 96 | diag_outlier_ratio=1.1, 97 | ) 98 | 99 | grad_scaler = dict( 100 | fp16=dict( 101 | # the initial loss scale, defaults to 2**16 102 | initial_scale=2**16, 103 | # the minimum loss scale, defaults to None 104 | min_scale=1, 105 | # the number of steps to increase loss scale when no overflow occurs 106 | growth_interval=1000, 107 | ), 108 | # the multiplication factor for increasing loss scale, defaults to 2 109 | growth_factor=2, 110 | # the multiplication factor for decreasing loss scale, defaults to 0.5 111 | backoff_factor=0.5, 112 | # the maximum loss scale, defaults to None 113 | max_scale=2**24, 114 | # the number of overflows before decreasing loss scale, defaults to 2 115 | hysteresis=2, 116 | ) 117 | 118 | hybrid_zero_optimizer = dict( 119 | # Enable low_level_optimzer overlap_communication 120 | overlap_sync_grad=True, 121 | overlap_sync_param=True, 122 | # bucket size for nccl communication params 123 | reduce_bucket_size=512 * 1024 * 1024, 124 | # grad clipping 125 | clip_grad_norm=1.0, 126 | ) 127 | 128 | loss = dict( 129 | label_smoothing=0, 130 | ) 131 | 132 | adam = dict( 133 | lr=1.5e-4, 134 | adam_beta1=0.9, 135 | adam_beta2=0.95, 136 | adam_beta2_c=0, 137 | adam_eps=1e-8, 138 | weight_decay=0.1, 139 | ) 140 | 141 | lr_scheduler = dict( 142 | total_steps=data["total_steps"], 143 | init_steps=0, # optimizer_warmup_step 144 | warmup_ratio=0.0056, 145 | eta_min=1.5e-5, 146 | last_epoch=-1, 147 | ) 148 | 149 | beta2_scheduler = dict( 150 | init_beta2=adam["adam_beta2"], 151 | c=adam["adam_beta2_c"], 152 | cur_iter=-1, 153 | ) 154 | 155 | model = dict( 156 | checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] 157 | num_attention_heads=NUM_ATTENTION_HEAD, 158 | embed_split_hidden=True, 159 | vocab_size=VOCAB_SIZE, 160 | embed_grad_scale=1, 161 | parallel_output=True, 162 | hidden_size=HIDDEN_SIZE, 163 | num_layers=NUM_LAYER, 164 | mlp_ratio=MLP_RATIO, 165 | apply_post_layer_norm=False, 166 | dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" 167 | norm_type="rmsnorm", 168 | layer_norm_epsilon=1e-5, 169 | use_flash_attn=True, 170 | num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 171 | lvm_config=dict( 172 | enable=True, 173 | embedding_cfg=dict( 174 | vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/', 175 | embedding_dim=HIDDEN_SIZE, 176 | freeze_vq_model=True, 177 | ), 178 | ) 179 | ) 180 | """ 181 | zero1 parallel: 182 | 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, 183 | so parameters will be divided within the range of dp. 184 | 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. 
185 | 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. 186 | For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. 187 | pipeline parallel (dict): 188 | 1. size: int, the size of pipeline parallel. 189 | 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. 190 | tensor parallel: tensor parallel size, usually the number of GPUs per node. 191 | """ 192 | parallel = dict( 193 | zero1=8, 194 | pipeline=dict(size=1, interleaved_overlap=True), 195 | sequence_parallel=False, 196 | ) 197 | 198 | cudnn_deterministic = False 199 | cudnn_benchmark = False 200 | 201 | monitor = dict( 202 | # feishu alert configs 203 | alert=dict( 204 | enable_feishu_alert=DO_ALERT, 205 | feishu_alert_address=None, # feishu webhook to send alert message 206 | light_monitor_address=None, # light_monitor address to send heartbeat 207 | ), 208 | ) 209 | -------------------------------------------------------------------------------- /InternLM/configs/pretrain_300m.py: -------------------------------------------------------------------------------- 1 | JOB_NAME = "lvm_llama" 2 | DO_ALERT = False 3 | model_type = "INTERNLM" 4 | 5 | SEQ_LEN = 2048 6 | HIDDEN_SIZE = 1024 7 | NUM_ATTENTION_HEAD = 8 8 | MLP_RATIO = 8 / 3 9 | NUM_LAYER = 22 10 | VOCAB_SIZE = 8192 11 | 12 | MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" 13 | SAVE_CKPT_FOLDER = "local:/path_to_save/" 14 | LOAD_CKPT_FOLDER = "local:/path_to_load/" 15 | 16 | CHECKPOINT_EVERY = 10000 17 | ckpt = dict( 18 | enable_save_ckpt=True, # set True to enable ckpt save. 19 | save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. 20 | # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"), 21 | # load_ckpt_folder="local:llm_ckpts/", 22 | # 'load_ckpt_info' setting guide: 23 | # 1. the 'path' indicates the ckpt path, 24 | # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" 25 | # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported. 26 | # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), 27 | checkpoint_every=CHECKPOINT_EVERY, 28 | async_upload=True, # async ckpt upload. (only works for boto3 ckpt) 29 | async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload. 30 | # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
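# Note: the uncommented value on the next line sets the snapshot frequency to 0, which presumably
# disables OSS snapshot checkpoints; the commented line above shows how one might re-enable them.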
31 | oss_snapshot_freq=0, 32 | ) 33 | 34 | TRAIN_FOLDER = "/path/to/dataset" 35 | VALID_FOLDER = "/path/to/dataset" 36 | data = dict( 37 | seq_len=SEQ_LEN, 38 | # micro_num means the number of micro_batch contained in one gradient update 39 | micro_num=1, 40 | # packed_length = micro_bsz * SEQ_LEN 41 | micro_bsz=16, 42 | # defaults to the value of micro_num 43 | valid_micro_num=1, 44 | # defaults to 0, means disable evaluate 45 | valid_every=0, 46 | pack_sample_into_one=False, 47 | train_one_epoch=False, 48 | total_steps=40000, 49 | skip_batches="", 50 | rampup_batch_size="", 51 | # Datasets with less than 50 rows will be discarded 52 | min_length=50, 53 | train_folder=TRAIN_FOLDER, 54 | valid_folder=None, 55 | empty_cache_and_diag_interval=10000, 56 | diag_outlier_ratio=1.1, 57 | ) 58 | 59 | grad_scaler = dict( 60 | fp16=dict( 61 | # the initial loss scale, defaults to 2**16 62 | initial_scale=2**16, 63 | # the minimum loss scale, defaults to None 64 | min_scale=1, 65 | # the number of steps to increase loss scale when no overflow occurs 66 | growth_interval=1000, 67 | ), 68 | # the multiplication factor for increasing loss scale, defaults to 2 69 | growth_factor=2, 70 | # the multiplication factor for decreasing loss scale, defaults to 0.5 71 | backoff_factor=0.5, 72 | # the maximum loss scale, defaults to None 73 | max_scale=2**24, 74 | # the number of overflows before decreasing loss scale, defaults to 2 75 | hysteresis=2, 76 | ) 77 | 78 | hybrid_zero_optimizer = dict( 79 | # Enable low_level_optimzer overlap_communication 80 | overlap_sync_grad=True, 81 | overlap_sync_param=True, 82 | # bucket size for nccl communication params 83 | reduce_bucket_size=512 * 1024 * 1024, 84 | # grad clipping 85 | clip_grad_norm=1.0, 86 | ) 87 | 88 | loss = dict( 89 | label_smoothing=0, 90 | ) 91 | 92 | adam = dict( 93 | lr=1.5e-4, 94 | adam_beta1=0.9, 95 | adam_beta2=0.95, 96 | adam_beta2_c=0, 97 | adam_eps=1e-8, 98 | weight_decay=0.1, 99 | ) 100 | 101 | lr_scheduler = dict( 102 | total_steps=data["total_steps"], 103 | init_steps=0, # optimizer_warmup_step 104 | warmup_ratio=0.0056, 105 | eta_min=1.5e-5, 106 | last_epoch=-1, 107 | ) 108 | 109 | beta2_scheduler = dict( 110 | init_beta2=adam["adam_beta2"], 111 | c=adam["adam_beta2_c"], 112 | cur_iter=-1, 113 | ) 114 | 115 | model = dict( 116 | checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] 117 | num_attention_heads=NUM_ATTENTION_HEAD, 118 | embed_split_hidden=True, 119 | vocab_size=VOCAB_SIZE, 120 | embed_grad_scale=1, 121 | parallel_output=True, 122 | hidden_size=HIDDEN_SIZE, 123 | num_layers=NUM_LAYER, 124 | mlp_ratio=MLP_RATIO, 125 | apply_post_layer_norm=False, 126 | dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" 127 | norm_type="rmsnorm", 128 | layer_norm_epsilon=1e-5, 129 | use_flash_attn=True, 130 | num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 131 | lvm_config=dict( 132 | enable=True, 133 | embedding_cfg=dict( 134 | vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/', 135 | embedding_dim=HIDDEN_SIZE, 136 | freeze_vq_model=True, 137 | ), 138 | ) 139 | ) 140 | """ 141 | zero1 parallel: 142 | 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, 143 | so parameters will be divided within the range of dp. 144 | 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. 145 | 3. 
zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. 146 | For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. 147 | pipeline parallel (dict): 148 | 1. size: int, the size of pipeline parallel. 149 | 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. 150 | tensor parallel: tensor parallel size, usually the number of GPUs per node. 151 | """ 152 | parallel = dict( 153 | zero1=8, 154 | pipeline=dict(size=1, interleaved_overlap=True), 155 | sequence_parallel=False, 156 | ) 157 | 158 | cudnn_deterministic = False 159 | cudnn_benchmark = False 160 | 161 | monitor = dict( 162 | # feishu alert configs 163 | alert=dict( 164 | enable_feishu_alert=DO_ALERT, 165 | feishu_alert_address=None, # feishu webhook to send alert message 166 | light_monitor_address=None, # light_monitor address to send heartbeat 167 | ), 168 | ) 169 | -------------------------------------------------------------------------------- /InternLM/internlm/__init__.py: -------------------------------------------------------------------------------- 1 | from .initialize.initialize_trainer import initialize_trainer, initialize_kd_trainer 2 | from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch 3 | 4 | __all__ = [ 5 | "get_default_parser", 6 | "initialize_kd_trainer", 7 | "initialize_trainer", 8 | "launch_from_slurm", 9 | "launch_from_torch", 10 | ] 11 | -------------------------------------------------------------------------------- /InternLM/internlm/apis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/internlm/apis/__init__.py -------------------------------------------------------------------------------- /InternLM/internlm/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import Engine 2 | from .naive_amp import NaiveAMPModel 3 | from .trainer import Trainer 4 | 5 | __all__ = [ 6 | "NaiveAMPModel", 7 | "Engine", 8 | "Trainer", 9 | ] 10 | -------------------------------------------------------------------------------- /InternLM/internlm/core/communication/__init__.py: -------------------------------------------------------------------------------- 1 | from .p2p import ( 2 | AsynCommunicator, 3 | recv_backward, 4 | recv_forward, 5 | send_backward, 6 | send_backward_and_recv_next_backward_async, 7 | send_backward_recv_backward, 8 | send_backward_recv_forward, 9 | send_forward, 10 | send_forward_and_recv_next_forward_async, 11 | send_forward_backward_recv_forward_backward, 12 | send_forward_recv_backward, 13 | send_forward_recv_forward, 14 | ) 15 | from .utils import recv_obj_meta, send_obj_meta 16 | 17 | __all__ = [ 18 | "send_forward", 19 | "send_forward_recv_forward", 20 | "send_forward_backward_recv_forward_backward", 21 | "send_backward", 22 | "send_backward_recv_backward", 23 | "send_backward_recv_forward", 24 | "send_forward_recv_backward", 25 | "recv_backward", 26 | "recv_forward", 27 | "send_obj_meta", 28 | "recv_obj_meta", 29 | "send_backward_and_recv_next_backward_async", 30 | "send_forward_and_recv_next_forward_async", 31 | "AsynCommunicator", 32 | ] 33 | -------------------------------------------------------------------------------- /InternLM/internlm/core/communication/utils.py: 
-------------------------------------------------------------------------------- 1 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication 2 | 3 | from typing import List, Tuple, Union 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from internlm.core.context import ParallelMode 9 | from internlm.core.context import global_context as gpc 10 | from internlm.utils.common import get_current_device 11 | 12 | TensorShape = Union[torch.Size, List[int], Tuple[int]] 13 | 14 | 15 | def send_meta_helper(obj, next_rank, tensor_kwargs): 16 | send_shape = torch.tensor(obj.size(), **tensor_kwargs) 17 | send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) 18 | dist.send(send_ndims, next_rank) 19 | dist.send(send_shape, next_rank) 20 | 21 | 22 | def send_obj_meta(obj, next_rank=None): 23 | """Sends obj meta information before sending a specific obj. 24 | Since the recipient must know the shape of the obj in p2p communications, 25 | meta information of the obj should be sent before communications. This function 26 | synchronizes with :func:`recv_obj_meta`. 27 | 28 | Args: 29 | obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. 30 | need_meta (bool, optional): If False, meta information won't be sent. 31 | next_rank (int): The rank of the next member in pipeline parallel group. 32 | 33 | Returns: 34 | bool: False 35 | """ 36 | if next_rank is None: 37 | next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) 38 | 39 | tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} 40 | if isinstance(obj, torch.Tensor): 41 | send_obj_nums = torch.tensor(1, **tensor_kwargs) 42 | dist.send(send_obj_nums, next_rank) 43 | send_meta_helper(obj, next_rank, tensor_kwargs) 44 | else: 45 | send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) 46 | dist.send(send_obj_nums, next_rank) 47 | for tensor_to_send in obj: 48 | send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) 49 | 50 | 51 | def recv_meta_helper(prev_rank, tensor_kwargs): 52 | recv_ndims = torch.empty((), **tensor_kwargs) 53 | dist.recv(recv_ndims, prev_rank) 54 | recv_shape = torch.empty(recv_ndims, **tensor_kwargs) 55 | dist.recv(recv_shape, prev_rank) 56 | return recv_shape 57 | 58 | 59 | def recv_obj_meta(prev_rank=None) -> torch.Size: 60 | """Receives obj meta information before receiving a specific obj. 61 | Since the recipient must know the shape of the obj in p2p communications, 62 | meta information of the obj should be received before communications. This function 63 | synchronizes with :func:`send_obj_meta`. 64 | 65 | Args: 66 | obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. 67 | prev_rank (int): The rank of the source of the obj. 68 | 69 | Returns: 70 | Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. 
71 | """ 72 | if prev_rank is None: 73 | prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) 74 | 75 | tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} 76 | recv_obj_nums = torch.empty((), **tensor_kwargs) 77 | dist.recv(recv_obj_nums, prev_rank) 78 | if recv_obj_nums.item() == 1: 79 | recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) 80 | obj_shape = torch.Size(recv_shape) 81 | else: 82 | obj_shape = [] 83 | for _ in range(recv_obj_nums.item()): 84 | recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) 85 | obj_shape.append(torch.Size(recv_shape)) 86 | 87 | return obj_shape 88 | 89 | 90 | def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: 91 | """Break a tensor into equal 1D chunks. 92 | 93 | Args: 94 | tensor (:class:`torch.Tensor`): Tensor to be split before communication. 95 | new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. 96 | 97 | Returns: 98 | :class:`torch.Tensor`: The split tensor 99 | """ 100 | partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR) 101 | start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR) 102 | end_index = start_index + partition_size 103 | if new_buffer: 104 | data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) 105 | data.copy_(tensor.view(-1)[start_index:end_index]) 106 | else: 107 | data = tensor.view(-1)[start_index:end_index] 108 | return data 109 | 110 | 111 | def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: 112 | """Opposite of above function, gather values from model parallel ranks. 113 | 114 | Args: 115 | tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. 116 | Returns: 117 | :class:`torch.Tensor`: The gathered tensor. 
118 | """ 119 | world_size = gpc.get_world_size(ParallelMode.TENSOR) 120 | numel = torch.numel(tensor) 121 | numel_gathered = world_size * numel 122 | gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) 123 | chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] 124 | dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR)) 125 | return gathered 126 | -------------------------------------------------------------------------------- /InternLM/internlm/core/context/__init__.py: -------------------------------------------------------------------------------- 1 | from .parallel_context import ( 2 | IS_TENSOR_PARALLEL, 3 | Config, 4 | ParallelContext, 5 | global_context, 6 | ) 7 | from .process_group_initializer import ( 8 | Initializer_Data, 9 | Initializer_Model, 10 | Initializer_Nettest, 11 | Initializer_Pipeline, 12 | Initializer_Tensor, 13 | Initializer_Zero1, 14 | ParallelMode, 15 | ProcessGroupInitializer, 16 | ) 17 | from .random import ( 18 | add_seed, 19 | get_current_mode, 20 | get_seeds, 21 | get_states, 22 | seed, 23 | set_mode, 24 | set_seed_states, 25 | sync_states, 26 | ) 27 | 28 | __all__ = [ 29 | "Config", 30 | "IS_TENSOR_PARALLEL", 31 | "global_context", 32 | "ParallelContext", 33 | "ParallelMode", 34 | "Initializer_Tensor", 35 | "Initializer_Pipeline", 36 | "Initializer_Data", 37 | "Initializer_Zero1", 38 | "Initializer_Nettest", 39 | "ProcessGroupInitializer", 40 | "Initializer_Model", 41 | "seed", 42 | "set_mode", 43 | "add_seed", 44 | "get_seeds", 45 | "get_states", 46 | "get_current_mode", 47 | "set_seed_states", 48 | "sync_states", 49 | ] 50 | -------------------------------------------------------------------------------- /InternLM/internlm/core/context/random.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context 4 | 5 | from contextlib import contextmanager 6 | 7 | import torch 8 | import torch.cuda 9 | from torch import Tensor 10 | 11 | from .process_group_initializer import ParallelMode 12 | 13 | 14 | class SeedManager: 15 | """This class is a manager of all random seeds involved in the system.""" 16 | 17 | def __init__(self): 18 | self._current_mode = None 19 | self._seeds = {} 20 | self._seed_states = {} 21 | 22 | @property 23 | def current_mode(self): 24 | return self._current_mode 25 | 26 | @property 27 | def seeds(self): 28 | return self._seeds 29 | 30 | @property 31 | def seed_states(self): 32 | return self._seed_states 33 | 34 | def set_state(self, parallel_mode: ParallelMode, state: Tensor): 35 | """Sets the state of the seed manager for `parallel_mode`.""" 36 | assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager" 37 | self._seed_states[parallel_mode] = state 38 | 39 | def set_mode(self, parallel_mode: ParallelMode): 40 | """Sets the current mode of the seed manager.""" 41 | if self.current_mode: 42 | # save state for current mode 43 | self._seed_states[self._current_mode] = torch.cuda.get_rng_state() 44 | 45 | # set new state for new mode 46 | self._current_mode = parallel_mode 47 | torch.cuda.set_rng_state(self._seed_states[parallel_mode]) 48 | 49 | def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False): 50 | """Adds a seed to the seed manager for `parallel_mode`.""" 51 | assert isinstance(parallel_mode, ParallelMode), 
"Invalid ParallelMode" 52 | if not overwrite: 53 | assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists" 54 | elif parallel_mode in self._seed_states: 55 | print(f"Warning: {parallel_mode} seed overwritten.", flush=True) 56 | 57 | current_state = torch.cuda.get_rng_state() 58 | torch.cuda.manual_seed(seed) 59 | self._seed_states[parallel_mode] = torch.cuda.get_rng_state() 60 | self._seeds[parallel_mode] = seed 61 | torch.cuda.set_rng_state(current_state) 62 | 63 | def reset(self): 64 | self._current_mode = None 65 | self._seeds = {} 66 | self._seed_states = {} 67 | 68 | 69 | _SEED_MANAGER = SeedManager() 70 | 71 | 72 | def get_seeds(): 73 | """Returns the seeds of the seed manager. 74 | Returns: 75 | dict: The seeds of the seed manager. 76 | """ 77 | return _SEED_MANAGER.seeds 78 | 79 | 80 | def get_states(copy=False): 81 | """Returns the seed states of the seed manager. 82 | Returns: 83 | dict: The seed states of the seed manager. 84 | """ 85 | states = _SEED_MANAGER.seed_states 86 | if copy: 87 | new_states = dict() 88 | for parallel_mode, state in states.items(): 89 | new_states[parallel_mode] = state.clone() 90 | return new_states 91 | else: 92 | return _SEED_MANAGER.seed_states 93 | 94 | 95 | def get_current_mode(): 96 | """Returns the current mode of the seed manager. 97 | Returns: 98 | :class:`torch.ByteTensor`: The current mode of the seed manager. 99 | """ 100 | return _SEED_MANAGER.current_mode 101 | 102 | 103 | def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False): 104 | """Adds a seed to the seed manager for `parallel_mode`.""" 105 | _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite) 106 | 107 | 108 | def set_mode(parallel_mode: ParallelMode): 109 | """Sets the current mode of the seed manager.""" 110 | _SEED_MANAGER.set_mode(parallel_mode) 111 | 112 | 113 | def set_seed_states(parallel_mode: ParallelMode, state: Tensor): 114 | """Sets the state of the seed manager for `parallel_mode`.""" 115 | _SEED_MANAGER.set_state(parallel_mode, state) 116 | 117 | 118 | def sync_states(): 119 | current_mode = get_current_mode() 120 | current_states = torch.cuda.get_rng_state() 121 | set_seed_states(current_mode, current_states) 122 | 123 | 124 | @contextmanager 125 | def seed(parallel_mode: ParallelMode): 126 | """A context for seed switch""" 127 | current_mode = _SEED_MANAGER.current_mode 128 | try: 129 | yield _SEED_MANAGER.set_mode(parallel_mode) 130 | finally: 131 | _SEED_MANAGER.set_mode(current_mode) 132 | -------------------------------------------------------------------------------- /InternLM/internlm/core/gradient_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from abc import ABC, abstractmethod 5 | from collections import defaultdict 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 10 | 11 | from internlm.core.context import global_context as gpc 12 | 13 | 14 | class BaseGradientHandler(ABC): 15 | """A basic helper class to handle all-reduce operations of gradients across different parallel groups 16 | before optimization. 17 | 18 | Args: 19 | model (Module): Model where the gradients accumulate. 20 | optimizer (Optimizer): Optimizer for updating the parameters. 
21 | """ 22 | 23 | def __init__(self, model, optimizer): 24 | self._model = model 25 | self._optimizer = optimizer 26 | 27 | @abstractmethod 28 | def handle_gradient(self): 29 | """A method to accumulate gradients across different parallel groups. Users should 30 | write their own functions or just use the functions in pre-defined subclasses. 31 | """ 32 | pass 33 | 34 | 35 | class PipelineSharedModuleGradientHandler(BaseGradientHandler): 36 | """A helper class to handle all-reduce operations in sub parallel groups. 37 | A all-reduce collective communication will be operated in 38 | :func:`handle_gradient` among all sub pipeline parallel groups. 39 | For better performance, it bucketizes the gradients of all parameters that are 40 | the same type to improve the efficiency of communication. 41 | 42 | Args: 43 | model (Module): Model where the gradients accumulate. 44 | optimizer (Optimizer): Optimizer for updating the parameters. 45 | """ 46 | 47 | def handle_gradient(self): 48 | """A method running a all-reduce operation in sub pipeline parallel groups.""" 49 | if gpc.pipeline_parallel_size > 1: 50 | # bucketize and all-reduce 51 | buckets = defaultdict(lambda: defaultdict(list)) 52 | # Pack the buckets. 53 | for param in self._model.parameters(): 54 | group = getattr(param, "pipeline_shared_module_pg", None) 55 | if ( 56 | param.requires_grad 57 | and group is not None 58 | and ( 59 | (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null()) 60 | or param.grad is not None 61 | ) 62 | ): 63 | tp = param.data.type() 64 | buckets[group][tp].append(param) 65 | 66 | # For each bucket, all-reduce and copy all-reduced grads. 67 | for group, group_buckets in buckets.items(): 68 | for tp, bucket in group_buckets.items(): 69 | grads = [ 70 | param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data 71 | for param in bucket 72 | ] 73 | coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) 74 | dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) 75 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 76 | buf.copy_(synced) 77 | -------------------------------------------------------------------------------- /InternLM/internlm/core/naive_amp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp 5 | 6 | from typing import Any 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch import Tensor, nn 11 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 12 | from torch.distributed import ReduceOp 13 | 14 | from internlm.core.context import ParallelMode 15 | from internlm.core.context.parallel_context import global_context as gpc 16 | 17 | 18 | class NaiveAMPModel(nn.Module): 19 | """ 20 | This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16. 21 | It also provides options to cast the output back to fp32 and to synchronize buffers. 22 | 23 | Args: 24 | model (torch.nn.Module): The model to be wrapped and cast into fp16. 25 | output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True. 26 | parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module. 27 | Defaults to ``ParallelMode.DATA``. 28 | sync_buffer (bool, optional): If True, the buffers are synchronized. 
Defaults to True. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | model: nn.Module, 34 | output_to_fp32: bool = True, 35 | parallel_mode: ParallelMode = ParallelMode.DATA, 36 | sync_buffer: bool = True, 37 | dtype=torch.float16, 38 | ): 39 | super().__init__() 40 | self.model = model.to(dtype) 41 | self._output_to_fp32 = output_to_fp32 42 | self._sync_buf = sync_buffer 43 | self.dtype = dtype 44 | 45 | if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: 46 | self._process_group = gpc.get_group(parallel_mode) 47 | self._world_size = gpc.get_world_size(parallel_mode) 48 | else: 49 | self._process_group = None 50 | self._world_size = 1 51 | self._sync_buf = False 52 | self._first_eval_run = False 53 | 54 | @property 55 | def sync_buffer(self): 56 | """Returns the current state of the buffer synchronization.""" 57 | return self._sync_buf 58 | 59 | @sync_buffer.setter 60 | def sync_buffer(self, state: bool): 61 | """Sets the state of the buffer synchronization.""" 62 | self._sync_buf = state 63 | 64 | def _convert_to_fp16(self, input_: Any): 65 | """Converts the input to fp16 if it is a Tensor of dtype float32.""" 66 | if isinstance(input_, Tensor) and input_.dtype == torch.float32: 67 | input_ = input_.to(self.dtype) 68 | return input_ 69 | 70 | def _convert_to_fp32(self, input_: Any): 71 | """Converts the input to fp32 if it is a Tensor of dtype float16.""" 72 | if isinstance(input_, Tensor) and input_.dtype == torch.float16: 73 | input_ = input_.float() 74 | return input_ 75 | 76 | def convert_to_fp32(self, out): 77 | """Converts the output to fp32""" 78 | if isinstance(out, Tensor): 79 | out = self._convert_to_fp32(out) 80 | elif isinstance(out, (tuple, list)): 81 | out = [self._convert_to_fp32(val) for val in out] 82 | elif isinstance(out, dict): 83 | out = {key: self._convert_to_fp32(val) for key, val in out.items()} 84 | 85 | return out 86 | 87 | def _reduce_module_buffer(self): 88 | """ 89 | All-reduces the buffers (e.g., running stats of batch normalization) across 90 | data parallel ranks so that all the ranks will produce consistent results 91 | when given the same input. 92 | """ 93 | buf_list = [] 94 | 95 | # find valid buffers 96 | for buf in self.model.buffers(): 97 | if buf is not None: 98 | buf_list.append(buf) 99 | 100 | # reduce buffers across data parallel ranks 101 | if buf_list: 102 | coalesced_buf = _flatten_dense_tensors(buf_list) 103 | coalesced_buf.div_(self._world_size) 104 | dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group) 105 | unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list) 106 | for old, new in zip(buf_list, unflattened_buf_list): 107 | old.copy_(new) 108 | 109 | def eval(self): 110 | """Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration.""" 111 | self.model.eval() 112 | self._first_eval_run = True 113 | 114 | def forward(self, *args, **kwargs): 115 | """ 116 | Performs a forward pass on the model. Buffers are synchronized before the forward pass. 117 | The inputs are converted to fp16 and the outputs are optionally converted back to fp32. 
118 | """ 119 | if (self.training or self._first_eval_run) and self._sync_buf: 120 | with torch.no_grad(): 121 | self._reduce_module_buffer() 122 | 123 | if self._first_eval_run: 124 | self._first_eval_run = False 125 | 126 | if args: 127 | args = [self._convert_to_fp16(arg) for arg in args] 128 | if kwargs: 129 | for k, v in kwargs.items(): 130 | kwargs[k] = self._convert_to_fp16(v) 131 | 132 | out = self.model(*args, **kwargs) 133 | 134 | if self._output_to_fp32: 135 | out = self.convert_to_fp32(out) 136 | return out 137 | -------------------------------------------------------------------------------- /InternLM/internlm/core/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook 2 | from .no_pipeline_scheduler import NonPipelineScheduler, KDNonPipelineScheduler 3 | from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler, KDPipelineScheduler 4 | 5 | __all__ = [ 6 | "BaseScheduler", 7 | "NonPipelineScheduler", 8 | "KDNonPipelineScheduler", 9 | "InterleavedPipelineScheduler", 10 | "PipelineScheduler", 11 | "KDPipelineScheduler", 12 | "SchedulerHook", 13 | "SchedulerMetricHook", 14 | ] 15 | -------------------------------------------------------------------------------- /InternLM/internlm/core/scheduler/base_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine 5 | 6 | from abc import ABC, abstractmethod 7 | from typing import Any, Callable, Iterable, Optional 8 | 9 | import torch 10 | 11 | from internlm.core.engine import Engine 12 | from internlm.utils.megatron_timers import megatron_timer as timer 13 | 14 | 15 | class BaseScheduler(ABC): 16 | """A basic helper class to control the process of training or evaluation. 17 | It mainly composes of forward_backward_step for gradient backward and 18 | optimizer_step for parameters update. 19 | For the convenience to enable FP16, we aggregate all codes that contain the 20 | control of FP16 in class schedule. 21 | 22 | Args: 23 | data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges 24 | them into data and label. 25 | """ 26 | 27 | def __init__(self, data_process_func: Callable = None): 28 | self.data_process_func = data_process_func 29 | 30 | @abstractmethod 31 | def pre_processing(self, engine: Engine): 32 | """To perform actions before running the schedule. 33 | 34 | Args: 35 | engine (internlm.core.Engine): InternLM engine for training and inference. 36 | """ 37 | pass 38 | 39 | def _load_micro_batch(self, data, label, offset, micro_bsz): 40 | assert isinstance(data, dict) and isinstance(label, torch.Tensor) 41 | micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()} 42 | micro_batch_label = label[offset : offset + micro_bsz] 43 | 44 | return micro_batch_data, micro_batch_label 45 | 46 | @abstractmethod 47 | def forward_backward_step( 48 | self, 49 | engine: Engine, 50 | data_iter: Iterable, 51 | forward_only: bool, 52 | return_loss: bool = True, 53 | return_output_label: bool = True, 54 | ): 55 | """The process function over a batch of dataset for training or evaluation. 56 | 57 | Args: 58 | engine (internlm.core.Engine): InternLM engine for training and inference. 
59 | data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader). 60 | forward_only (bool): If True, the process won't include backward. 61 | return_loss (bool, optional): If False, the loss won't be returned. 62 | return_output_label (bool, optional): If False, the output and label won't be returned. 63 | """ 64 | pass 65 | 66 | @staticmethod 67 | def _call_engine(engine: Engine, inputs: Any): 68 | """Calls the engine with the given inputs. 69 | 70 | Args: 71 | engine (internlm.core.Engine): InternLM engine for training and inference. 72 | inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict. 73 | """ 74 | if isinstance(inputs, torch.Tensor): 75 | return engine(inputs) 76 | elif isinstance(inputs, (list, tuple)): 77 | return engine(*inputs) 78 | elif isinstance(inputs, dict): 79 | return engine(**inputs) 80 | else: 81 | raise TypeError( 82 | f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}" 83 | ) 84 | 85 | @staticmethod 86 | def _call_engine_criterion(criterion, outputs: Any, labels: Any): 87 | """Calls the engine's criterion with the given outputs and labels. 88 | 89 | Args: 90 | engine (internlm.core.Engine): InternLM engine for training and inference. 91 | outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict. 92 | labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict. 93 | """ 94 | assert isinstance( 95 | outputs, (torch.Tensor, list, tuple, dict) 96 | ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}" 97 | if isinstance(outputs, torch.Tensor): 98 | outputs = (outputs,) 99 | if isinstance(labels, torch.Tensor): 100 | labels = (labels,) 101 | 102 | if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)): 103 | return criterion(*outputs, *labels) 104 | elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict): 105 | return criterion(*outputs, **labels) 106 | elif isinstance(outputs, dict) and isinstance(labels, dict): 107 | return criterion(**outputs, **labels) 108 | elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): 109 | raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") 110 | else: 111 | raise TypeError( 112 | f"Expected model outputs and labels to be of type torch.Tensor ' \ 113 | '(which is auto-converted to tuple), list, tuple, or dict, ' \ 114 | 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)" 115 | ) 116 | 117 | 118 | class SchedulerHook(ABC): 119 | """ 120 | Scheduler Hook. 
121 | """ 122 | 123 | @abstractmethod 124 | def before_forward(self, scheduler, inputs) -> None: 125 | """Actions before forward""" 126 | 127 | @abstractmethod 128 | def after_forward(self, scheduler, outputs) -> None: 129 | """Actions after forward""" 130 | 131 | @abstractmethod 132 | def before_criterion(self, scheduler, outputs, label) -> None: 133 | """Actions before criterion""" 134 | 135 | @abstractmethod 136 | def after_criterion(self, scheduler, loss) -> None: 137 | """Actions after criterion""" 138 | 139 | @abstractmethod 140 | def before_backward(self, scheduler, outputs, outputs_grad) -> None: 141 | """Actions before backward""" 142 | 143 | @abstractmethod 144 | def after_backward(self, scheduler, inputs_grad) -> None: 145 | """Actions after backward""" 146 | 147 | @abstractmethod 148 | def post_helper_func(self, scheduler, outputs, label) -> None: 149 | """A post helper function""" 150 | 151 | 152 | class SchedulerMetricHook(SchedulerHook): 153 | """ 154 | Scheduler Metric Hook. 155 | """ 156 | 157 | def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None: 158 | self._post_func = metric 159 | self._skip = skip 160 | 161 | def before_forward(self, scheduler, inputs) -> None: 162 | if not self._skip: 163 | timer("fwd").start() 164 | 165 | def after_forward(self, scheduler, outputs) -> None: 166 | if not self._skip: 167 | timer("fwd").stop() 168 | 169 | def before_criterion(self, scheduler, outputs, label) -> None: 170 | if not self._skip: 171 | timer("cal_loss").start() 172 | 173 | def after_criterion(self, scheduler, loss) -> None: 174 | if not self._skip: 175 | timer("cal_loss").stop() 176 | 177 | def before_backward(self, scheduler, outputs, outputs_grad) -> None: 178 | if not self._skip: 179 | timer("bwd").start() 180 | 181 | def after_backward(self, scheduler, inputs_grad) -> None: 182 | if not self._skip: 183 | timer("bwd").stop() 184 | 185 | def post_helper_func(self, scheduler, outputs, label) -> None: 186 | if self._post_func is not None: 187 | self._post_func(outputs, label) 188 | -------------------------------------------------------------------------------- /InternLM/internlm/core/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine 5 | 6 | import json 7 | from typing import Iterable, Optional 8 | 9 | from internlm.core.engine import Engine 10 | from internlm.core.scheduler import ( 11 | BaseScheduler, 12 | InterleavedPipelineScheduler, 13 | NonPipelineScheduler, 14 | PipelineScheduler, 15 | ) 16 | 17 | 18 | class TrainState: 19 | """ 20 | The TrainState class is used to record the current state of training. 21 | 22 | Args: 23 | train_dl (DataLoader): The DataLoader object used for training. 24 | """ 25 | 26 | def __init__(self, config, batch_sampler) -> None: 27 | """ 28 | Args: 29 | config (Config): internlm config 30 | batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is 31 | asynchronous and prefetched, the batch_sampler state maintained inside the 32 | dataloader are faster then the actual training progress, so we copy the 33 | batch_sampler as the anchor point of ckpt reload. 
34 | """ 35 | # The number of batches produced by the data iterator 36 | self.batch_count: int = 0 37 | # Used to store the number of samples consumed in the current epoch 38 | self.num_consumed_samples_in_epoch: int = 0 39 | # Total number of tokens consumed 40 | self.num_consumed_tokens: int = 0 41 | # Number of batches skipped due to inf or nan values 42 | self.inf_nan_skip_batches: int = 0 43 | # Records the number of updates, skipped batches and inf batches are not counted 44 | self.step_count: int = 0 45 | 46 | # Total step count 47 | self.total_steps: int = config.data.total_steps 48 | 49 | # resume tensorboard folder, need load from checkpoint or set manually. 50 | self.resume_tb_folder = config.resume_tb_folder 51 | 52 | self.tensorboard_folder = config.tensorboard_folder 53 | 54 | # learning rate 55 | self.lr = config.adam.lr 56 | 57 | # smapler state 58 | if batch_sampler: 59 | self.init_batch_sampler(batch_sampler) 60 | 61 | def init_batch_sampler(self, batch_sampler): 62 | """ 63 | Args: 64 | batch_sampler (torch.utils.data.Sampler): sampler. 65 | """ 66 | # make a copy of batch_sampler. 67 | self.batch_sampler = batch_sampler.copy() 68 | # Iterator for the batch sampler 69 | self.batch_sampler_iter = iter(self.batch_sampler) 70 | 71 | def __str__(self) -> str: 72 | """Returns a string representation of the training state in JSON format.""" 73 | info = { 74 | "batch_count": self.batch_count, 75 | "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, 76 | "num_consumed_tokens": self.num_consumed_tokens, 77 | "inf_nan_skip_batches": self.inf_nan_skip_batches, 78 | "step_count": self.step_count, 79 | } 80 | 81 | return json.dumps(info, indent=4, sort_keys=True) 82 | 83 | def load_state_dict(self, other_stuffs): 84 | """ 85 | Resumes training from a checkpoint. 86 | 87 | Args: 88 | other_stuffs (dict): Other information needed to resume training. 89 | """ 90 | self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"] 91 | self.num_consumed_tokens = other_stuffs["num_consumed_tokens"] 92 | self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"] 93 | 94 | # Because the ckpt save occurs after updating 'step_count', 95 | # there is no need to increment 'step_count' here (Does our step count start from 0 ?), 96 | # However, 'batch_count' is updating before ckpt storage, so it need to inc 1 when resume. 97 | self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward 98 | self.step_count = other_stuffs.get("step_count", self.batch_count) 99 | 100 | # resume tensorboard from older tensorboard_folder 101 | self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None) 102 | 103 | def state_dict(self): 104 | return { 105 | "batch_count": self.batch_count, 106 | "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, 107 | "num_consumed_tokens": self.num_consumed_tokens, 108 | "inf_nan_skip_batches": self.inf_nan_skip_batches, 109 | "step_count": self.step_count, 110 | "tensorboard_folder": self.tensorboard_folder, 111 | } 112 | 113 | 114 | class Trainer: 115 | """This is a class tending for easy deployments of users' training and evaluation instead of 116 | writing their own scripts. 117 | 118 | Args: 119 | engine (:class:`Engine`): Engine responsible for the process function. 120 | schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None. 
121 | """ 122 | 123 | def __init__( 124 | self, 125 | engine: Engine, 126 | schedule: Optional[BaseScheduler] = None, 127 | ): 128 | """Initializes the Trainer class. 129 | 130 | Args: 131 | engine (Engine): The engine responsible for the process function. 132 | schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None. 133 | """ 134 | self._engine = engine 135 | 136 | # build schedule 137 | if schedule is None: 138 | self._schedule = NonPipelineScheduler() 139 | else: 140 | assert isinstance( 141 | schedule, BaseScheduler 142 | ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}" 143 | self._schedule = schedule 144 | 145 | self._schedule.pre_processing(self._engine) 146 | 147 | @property 148 | def engine(self): 149 | """Returns the engine that responsible for managing the training and evaluation process.""" 150 | return self._engine 151 | 152 | @property 153 | def schedule(self): 154 | """Returns the runtime scheduler.""" 155 | return self._schedule 156 | 157 | @property 158 | def uses_pipeline(self): 159 | """Returns whether the pipeline parallel is used or not.""" 160 | return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler)) 161 | 162 | def train(self): 163 | """Sets the model to training mode.""" 164 | self._engine.train() 165 | 166 | def eval(self): 167 | """Sets the model to evaluation mode.""" 168 | self._engine.eval() 169 | 170 | def zero_grad(self): 171 | """Sets the gradient of all parameters in the model to zero.""" 172 | self._engine.zero_grad() 173 | 174 | def step(self): 175 | """Executes the parameter update step.""" 176 | return self._engine.step() 177 | 178 | def execute_schedule(self, data_iter: Iterable, **kwargs): 179 | """Runs the forward, loss computation, and backward for the model. 180 | Returns a tuple of (output, label, loss). 181 | 182 | Args: 183 | data_iter (Iterable): The data iterator. 184 | **kwargs: Additional keyword arguments. 185 | 186 | Returns: 187 | Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss). 188 | """ 189 | output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs) 190 | return output, label, loss 191 | -------------------------------------------------------------------------------- /InternLM/internlm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_sampler import get_dpsampler_dataloader 2 | from .collaters import jsonl_ds_collate_fn, packed_collate_fn 3 | from .dummy_dataset import RandomDataset 4 | from .packed_dataset import PackedDataset, PackedDatasetWithoutCuSeqlen 5 | 6 | __all__ = [ 7 | "jsonl_ds_collate_fn", 8 | "packed_collate_fn", 9 | "RandomDataset", 10 | "PackedDataset", 11 | "PackedDatasetWithoutCuSeqlen", 12 | "get_dpsampler_dataloader", 13 | ] 14 | -------------------------------------------------------------------------------- /InternLM/internlm/data/collaters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | 7 | def packed_collate_fn(batch, packed_length): 8 | 9 | """ 10 | Collate function for packed input sequences. 11 | 12 | Args: 13 | batch (List[Dict]): List of dictionaries representing each sample in batch. 14 | Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys. 15 | packed_length (int): The length of packed sequence. 
16 | 17 | Returns: 18 | Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", 19 | "cu_seqlens", "indexes", and "type_ids" keys, and the tensor of padded "labels". 20 | 21 | Raises: 22 | AssertionError: If the length of a sample is not equal to packed_length. 23 | AssertionError: If the shape of the padded "input_ids" tensor does not have the correct shape. 24 | """ 25 | 26 | xs, ys, cu_seqlens, indexes, ts = [], [], [], [], [] 27 | for b in batch: 28 | assert ( 29 | len(b["tokens"]) == packed_length 30 | ), f"length of a sample should be equal to packed_length, but got {len(b['tokens'])} and {packed_length})" 31 | assert ( 32 | len(b["labels"]) == packed_length 33 | ), f"length of a sample should be equal to packed_length, but got {len(b['labels'])} and {packed_length})" 34 | assert ( 35 | len(b["type_ids"]) == packed_length 36 | ), f"length of a sample should be equal to packed_length, but got {len(b['type_ids'])} and {packed_length})" 37 | 38 | tokens = [abs(w) for w in b["tokens"]] 39 | labels = [w if w > 0 else -100 for w in b["labels"]] 40 | 41 | xs.append(torch.LongTensor(tokens)) 42 | # The labels have been shifted here, so they are aligned with the output corresponding to the token 43 | ys.append(torch.LongTensor(labels)) 44 | ts.append(torch.LongTensor(b["type_ids"])) 45 | cu_seqlens.append(torch.IntTensor(b["cu_seqlens"])) 46 | indexes.append(torch.LongTensor(b["indexes"])) 47 | 48 | xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True) 49 | ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100) 50 | ts = torch.nn.utils.rnn.pad_sequence(ts, batch_first=True, padding_value=0) 51 | indexes = torch.stack(indexes, dim=0) 52 | if len(set(map(len, cu_seqlens))) == 1: # if has uniform length, then stack to save device transfer time 53 | cu_seqlens = torch.stack(cu_seqlens, dim=0) 54 | 55 | assert xs.shape[1] == packed_length, (xs.shape[1], packed_length) 56 | 57 | return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys 58 | 59 | 60 | def jsonl_ds_collate_fn(batch, max_length_per_sample): 61 | """ 62 | Collate function for json dataset. 63 | 64 | Args: 65 | batch (List[Dict]): List of dictionaries representing each sample in batch. 66 | Each dictionary contains "tokens". 67 | max_length_per_sample (int): The length of output sequence. 68 | 69 | Returns: 70 | Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", 71 | and the tensor of padded "labels". 
72 | 73 | """ 74 | xs, ys = [], [] 75 | for x in batch: 76 | x["tokens"] = x["tokens"][:max_length_per_sample] 77 | tokens = [abs(w) for w in x["tokens"]] 78 | labels = [w if w > 0 else -100 for w in x["tokens"]] 79 | labels = labels[1:] + [-100] 80 | xs.append(torch.as_tensor(tokens)) 81 | ys.append(torch.as_tensor(labels)) # y has been shifted 82 | xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True) 83 | ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100) 84 | 85 | xs = torch.cat([xs, xs.new_zeros(len(xs), max_length_per_sample - len(xs[0]))], dim=-1) 86 | ys = torch.cat([ys, ys.new_full((len(ys), max_length_per_sample - len(ys[0])), fill_value=-100)], dim=-1) 87 | 88 | return {"input_ids": xs}, ys 89 | -------------------------------------------------------------------------------- /InternLM/internlm/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | from torch.utils.data import ConcatDataset 5 | 6 | from internlm.data.single_dataset import JsonlDataset 7 | 8 | 9 | def get_dataset_dict(folder, split="valid") -> Dict: 10 | """ 11 | Return a dictionary of Datasets from a folder containing data files for validation. 12 | 13 | Args: 14 | folder (str): The path to the folder containing data files. 15 | split (str): The split of the data files to be used, default is "valid". 16 | 17 | Returns: 18 | A dictionary containing Datasets for each folder in the given path 19 | that contains data files with the specified split. 20 | 21 | Raises: 22 | AssertionError: If the given folder does not exist. 23 | 24 | Example: 25 | If the given folder is as follows, 26 | - data 27 | - zhihu 28 | - xxx.bin 29 | - valid.bin 30 | - baike 31 | - xxx.bin 32 | - valid.bin 33 | 34 | The returned dictionary will be, 35 | { 36 | 'zhihu': Dataset, 37 | 'baike': Dataset 38 | } 39 | """ 40 | 41 | assert os.path.exists(folder), f"folder `{folder}` not exists" 42 | data_dict = {} 43 | 44 | for root, dirs, files in os.walk(folder, followlinks=True): 45 | dirs.sort() # The order is guaranteed, and the newly added data starting with z needs to be ranked behind 46 | datasets = [] 47 | for fn in sorted(files): # Need sorted to ensure that the order is consistent 48 | if fn.endswith(".bin") and split in fn: 49 | fp = os.path.join(root, fn) 50 | ds = JsonlDataset(fp) 51 | datasets.append(ds) 52 | if datasets: 53 | ds = ConcatDataset(datasets=datasets) 54 | data_dict[os.path.basename(root)] = ds 55 | 56 | return data_dict 57 | -------------------------------------------------------------------------------- /InternLM/internlm/data/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class RandomDataset(Dataset): 9 | """ 10 | RandomDataset for generating random dataset. 11 | 12 | Args: 13 | num_samples (int): The number of samples to generate. 14 | max_len (int): The maximum length of each sample. 
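Example (values shown for illustration): each sample is built as ``[n, r] + list(range(n)) * r`` truncated to ``max_len``, so ``n=3, r=2`` yields ``{"tokens": [3, 2, 0, 1, 2, 0, 1, 2], "type_id": 0}``.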
15 | 16 | """ 17 | 18 | def __init__(self, num_samples=10000, max_len=1024) -> None: 19 | super().__init__() 20 | rng = np.random.RandomState(1999) 21 | max_num = rng.randint(1, 30, size=(num_samples,)) 22 | rep_num = rng.randint(10, 200, size=(num_samples,)) 23 | data = [] 24 | lengths = [] 25 | for n, r in zip(max_num, rep_num): 26 | d = list(range(n)) * r 27 | d = [n, r] + d 28 | d = d[:max_len] 29 | data.append(d) 30 | lengths.append(len(d)) 31 | self.data = data 32 | self.max_len = max_len 33 | self.lengths = np.array(lengths, dtype=int) 34 | 35 | def __getitem__(self, index): 36 | d = self.data[index] 37 | input_ids = np.array(d, dtype=int) 38 | return {"tokens": list(input_ids), "type_id": 0} 39 | 40 | def get_dataset_name(self): 41 | return "dummy_path/dummy_lang/dummy_ds/train.bin" 42 | 43 | def __len__(self): 44 | return len(self.data) 45 | -------------------------------------------------------------------------------- /InternLM/internlm/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | """ 5 | A .bin file corresponds to a Dataset instance here. 6 | """ 7 | 8 | import json 9 | import mmap 10 | import os 11 | import threading 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | import torch 16 | 17 | 18 | class JsonlDataset(torch.utils.data.Dataset): 19 | """ 20 | 21 | JSONL format is expected to roughly follow that of The Pile. 22 | One-line-per-document of the form: 23 | ``` 24 | { 25 | "tokens": List[int], 26 | } 27 | ``` 28 | 29 | Note that only the "tokens" key is used. 30 | """ 31 | 32 | def __init__(self, path: str, dataset_type_id: int = 0, min_length=50): 33 | self.path = path 34 | self.threadlocal = threading.local() 35 | resolved_path = Path(path).resolve() 36 | self.resolved_path = resolved_path 37 | self.meta = Path(f"{resolved_path}.meta") 38 | self.type_id = dataset_type_id 39 | 40 | # only build the cache in on the primary worker to prevent overloading nfs 41 | assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}" 42 | try: 43 | with open(self.meta, "rb") as f: 44 | meta = np.load(f) 45 | except Exception as e: 46 | print(f"Cannot load file {self.meta}...") 47 | raise e 48 | self.offsets = meta[:, 0] 49 | self.lengths = meta[:, -1] 50 | 51 | if min_length > 0: 52 | mask = self.lengths >= min_length 53 | self.old_lengths = self.lengths.copy() 54 | self.old_length = len(self.offsets) 55 | self.offsets = self.offsets[mask] 56 | self.lengths = self.lengths[mask] 57 | 58 | def __getitem__(self, idx): 59 | f = self._get_mmap() 60 | position = self.offsets[idx] 61 | f.seek(position) 62 | item = f.readline().decode("utf-8") 63 | try: 64 | item = json.loads(item) 65 | item["length"] = len(item["tokens"]) # add a length info 66 | item["type_id"] = self.type_id 67 | except Exception as err: 68 | raise json.decoder.JSONDecodeError( 69 | doc=self.path, 70 | pos=position, 71 | msg=( 72 | f"Error while loading JSONL line in file {self.path} at byte " 73 | f"{position}. 
Contents of line:\n{item}\n{err}" 74 | ), 75 | ) 76 | return item 77 | 78 | def get_dataset_name(self): 79 | return str(self.resolved_path) 80 | 81 | def _get_mmap(self): 82 | if not hasattr(self.threadlocal, "handles"): 83 | with open(self.path, "rb") as f: 84 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) 85 | self.threadlocal.handles = [f, mm] 86 | if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"): 87 | raise NotImplementedError( 88 | "Compressed files are not supported because .seek() would require " 89 | "rereading the entire file, making performance too slow." 90 | ) 91 | return self.threadlocal.handles[-1] 92 | 93 | def __setstate__(self, state): 94 | self.__dict__ = state 95 | self.threadlocal = threading.local() 96 | 97 | def __getstate__(self): 98 | d = {} 99 | for i, v in self.__dict__.items(): 100 | if i != "threadlocal": 101 | d[i] = v 102 | return d 103 | 104 | def __del__(self): 105 | if hasattr(self.threadlocal, "handles"): 106 | # cleanup files we opened on initialization 107 | while self.threadlocal.handles: 108 | self.threadlocal.handles.pop().close() 109 | 110 | @staticmethod 111 | def exists(path): 112 | return os.path.exists(path) 113 | 114 | def __len__(self): 115 | # Virtual length of the dataset depends on the epoch number if the number of documents 116 | # is not perfectly divisible by the data_subshard_count 117 | return len(self.offsets) 118 | -------------------------------------------------------------------------------- /InternLM/internlm/data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | from internlm.core.context import global_context as gpc 7 | 8 | DATASET_TYPE_IDS_MAP = {"vision": 0} 9 | 10 | 11 | def get_dataset_type_id(path): 12 | import re 13 | 14 | match_idxes = [] 15 | for key, idx in DATASET_TYPE_IDS_MAP.items(): 16 | if re.search(rf"/[z_]*{key}/", path): 17 | match_idxes.append(idx) 18 | assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}" 19 | return match_idxes[0] 20 | 21 | 22 | def unpack_data(input_ids, cu_seqlens): 23 | """ 24 | input_ids: (n, packed_length) 25 | Return: 26 | output: (batch_size, max_length) 27 | """ 28 | 29 | bsz = input_ids.shape[0] 30 | 31 | num_sequence = gpc.config.data["micro_bsz"] 32 | 33 | outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype) 34 | 35 | for i in range(bsz): 36 | output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype) 37 | cu_seqlens_slice = cu_seqlens[i] 38 | for j in range(num_sequence): 39 | seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j] 40 | output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]] 41 | outputs[i] = output 42 | 43 | if bsz == 1: 44 | outputs = outputs.squeeze(0) 45 | 46 | return outputs 47 | -------------------------------------------------------------------------------- /InternLM/internlm/initialize/__init__.py: -------------------------------------------------------------------------------- 1 | from .initialize_trainer import initialize_trainer, initialize_kd_trainer 2 | from .launch import ( 3 | get_default_parser, 4 | initialize_distributed_env, 5 | launch_from_slurm, 6 | launch_from_torch, 7 | ) 8 | 9 | __all__ = [ 10 | "get_default_parser", 11 | "initialize_trainer", 12 | 
"initialize_kd_trainer", 13 | "launch_from_slurm", 14 | "launch_from_torch", 15 | "initialize_distributed_env", 16 | ] 17 | -------------------------------------------------------------------------------- /InternLM/internlm/initialize/initialize_tensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import math 5 | 6 | from torch import Tensor, nn 7 | 8 | 9 | def scaled_init_method_normal(sigma: float = 1.0, num_layers: int = 1): 10 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 11 | std = sigma / math.sqrt(2.0 * num_layers) 12 | 13 | def init_(tensor): 14 | return nn.init.normal_(tensor, mean=0.0, std=std) 15 | 16 | return init_ 17 | 18 | 19 | def normal_(mean: float = 0.0, std: float = 1.0): 20 | r"""Return the initializer filling the input Tensor with values drawn from the normal distribution 21 | 22 | .. math:: 23 | \mathcal{N}(\text{mean}, \text{std}^2) 24 | 25 | Args: 26 | mean (float): the mean of the normal distribution. Defaults 0.0. 27 | std (float): the standard deviation of the normal distribution. Defaults 1.0. 28 | """ 29 | 30 | def initializer(tensor: Tensor): 31 | return nn.init.normal_(tensor, mean, std) 32 | 33 | return initializer 34 | 35 | 36 | def scaled_init_method_uniform(sigma: float = 1.0, num_layers: int = 1): 37 | """Init method based on p(x)=Uniform(-a, a) where std(x)=sigma/sqrt(2*num_layers).""" 38 | std = sigma / math.sqrt(2.0 * num_layers) 39 | a = math.sqrt(3.0 * std) 40 | 41 | def init_(tensor): 42 | return nn.init.uniform_(tensor, -a, a) 43 | 44 | return init_ 45 | 46 | 47 | def uniform_(mean: float = 0.0, std: float = 1.0): 48 | r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution 49 | 50 | .. math:: 51 | \mathcal{U}(mean-a, mean+a), where a satisfies \mathcal{U}_{std}=std. 52 | 53 | Args: 54 | mean (float): the mean of the uniform distribution. Defaults 0.0. 55 | std (float): the standard deviation of the uniform distribution. Defaults 1.0. 
56 | """ 57 | 58 | a = math.sqrt(3.0 * std) 59 | 60 | def initializer(tensor: Tensor): 61 | return nn.init.uniform_(tensor, mean - a, mean + a) 62 | 63 | return initializer 64 | -------------------------------------------------------------------------------- /InternLM/internlm/initialize/legacy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/internlm/initialize/legacy/__init__.py -------------------------------------------------------------------------------- /InternLM/internlm/initialize/legacy/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from internlm.initialize.launch import get_config_value 5 | from internlm.utils.logger import get_logger 6 | 7 | logger = get_logger(__file__) 8 | 9 | 10 | def auto_resume_sanity_check(ckpt_config): 11 | load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None) 12 | if load_given_ckpt is None: 13 | return True # default value is True 14 | else: 15 | return not load_given_ckpt 16 | 17 | 18 | def ckpt_info_sanity_check(ckpt_config): 19 | load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None) 20 | 21 | load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None) 22 | 23 | if load_model_only_folder is not None: 24 | assert ( 25 | load_ckpt_folder is None 26 | ), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ 27 | # and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" 28 | return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm") 29 | else: 30 | load_optimizer = get_config_value(ckpt_config, "load_optimizer", True) 31 | 32 | if isinstance(load_ckpt_folder, str): 33 | if load_optimizer: 34 | return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm") 35 | else: 36 | return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm") 37 | elif load_ckpt_folder is None: 38 | return None 39 | else: 40 | assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'" 41 | -------------------------------------------------------------------------------- /InternLM/internlm/model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .embedding import Embedding1D, RotaryEmbedding 5 | from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear 6 | from .metrics import AccPerplex 7 | from .modeling_internlm import build_model_with_cfg 8 | from .modeling_vit import build_vit_model_with_cfg 9 | from .multi_head_attention import MHA 10 | from .utils import gather_forward_split_backward 11 | 12 | __all__ = [ 13 | "Embedding1D", 14 | "FeedForward", 15 | "RotaryEmbedding", 16 | "RewardModelLinear", 17 | "ScaleColumnParallelLinear", 18 | "AccPerplex", 19 | "MHA", 20 | "gather_forward_split_backward", 21 | "build_model_with_cfg", 22 | "build_vit_model_with_cfg" 23 | ] 24 | -------------------------------------------------------------------------------- /InternLM/internlm/model/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch.nn.functional as F 5 | from 
flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss 6 | from torch import nn 7 | 8 | from internlm.core.context import ParallelMode 9 | from internlm.core.context import global_context as gpc 10 | 11 | 12 | class FlashGPTLMLoss(nn.Module): 13 | """ 14 | Loss function for flash GPT Language Model. 15 | """ 16 | 17 | def __init__(self, parallel_output=True, label_smoothing=0): 18 | super().__init__() 19 | 20 | if label_smoothing is not None: 21 | if label_smoothing != 0: 22 | if gpc.is_rank_for_log(): 23 | print(f"use label_smoothing: {label_smoothing}") 24 | else: 25 | label_smoothing = 0 26 | self.label_smoothing = label_smoothing 27 | 28 | if parallel_output: 29 | self.loss_fn = FlashCrossEntropyLoss( 30 | reduction="mean", 31 | inplace_backward=True, 32 | process_group=gpc.get_group(ParallelMode.TENSOR), 33 | label_smoothing=label_smoothing, 34 | ) # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D 35 | else: 36 | # Here, the output will gather output is set in the model, so use ordinary loss 37 | self.loss_fn = nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing) 38 | 39 | def forward(self, *args): 40 | if len(args) == 3: 41 | # residual is to match prenorm 42 | logits, _, labels = args 43 | elif len(args) == 2: 44 | # When using postnorm 45 | logits, labels = args 46 | else: 47 | raise RuntimeError(f"The number of criterion inputs are:{len(args)}") 48 | shift_logits = logits.contiguous().view(-1, logits.size(-1)) 49 | shift_labels = labels.contiguous().view(-1) 50 | loss = self.loss_fn( 51 | shift_logits, shift_labels 52 | ) # There is no need to consider the ignore_index problem here, because the loss calculation will be 53 | # calculated through the calculation range, and -100 must be outside this range, so there is no problem 54 | 55 | return loss 56 | 57 | 58 | class KLDivLoss(nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | self.temperature = gpc.config.kd_config.get('temperature', 1) 62 | self.inverse = gpc.config.kd_config.get('inverse', False) 63 | 64 | def forward(self, *args): 65 | if len(args) == 3: 66 | if self.inverse: 67 | logits_teacher, logits_student, _ = args 68 | else: 69 | logits_student, logits_teacher, _ = args 70 | else: 71 | raise RuntimeError(f"The number of criterion inputs are:{len(args)}") 72 | 73 | logits_teacher = logits_teacher.contiguous().view(-1, logits_teacher.size(-1)) 74 | logits_student = logits_student.contiguous().view(-1, logits_student.size(-1)) 75 | 76 | log_pred_student = F.log_softmax(logits_student / self.temperature, dim=1) 77 | pred_teacher = F.softmax(logits_teacher / self.temperature, dim=1) 78 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean') 79 | loss_kd *= self.temperature ** 2 80 | 81 | return loss_kd 82 | -------------------------------------------------------------------------------- /InternLM/internlm/model/muse/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_taming_vqgan import VQGANModel 19 | -------------------------------------------------------------------------------- /InternLM/internlm/model/norm.py: -------------------------------------------------------------------------------- 1 | # adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm 2 | 3 | import numbers 4 | 5 | import torch 6 | from torch.nn import init 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | def manual_rms_norm(my_input, normalized_shape, weight, eps): 11 | # layer norm should always be calculated in float32 12 | dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) 13 | variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) 14 | my_input = my_input * torch.rsqrt(variance + eps) 15 | 16 | if weight is None: 17 | return my_input 18 | 19 | # model_hf into half-precision if necessary 20 | if weight.dtype in [torch.float16, torch.bfloat16]: 21 | my_input = my_input.to(weight.dtype) 22 | 23 | return weight * my_input 24 | 25 | 26 | class RMSNormTorch(torch.nn.Module): 27 | """A custom PyTorch module for RMS normalization.""" 28 | 29 | def __init__(self, normalized_shape, eps=1e-5): 30 | super().__init__() 31 | 32 | if isinstance(normalized_shape, numbers.Integral): 33 | normalized_shape = (normalized_shape,) 34 | self.normalized_shape = torch.Size(normalized_shape) 35 | self.eps = eps 36 | self.weight = Parameter(torch.empty(*normalized_shape)) 37 | self.reset_parameters() 38 | 39 | def forward(self, _input: torch.Tensor): 40 | return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps) 41 | 42 | def reset_parameters(self): 43 | init.ones_(self.weight) 44 | 45 | def extra_repr(self): 46 | return "{normalized_shape}, eps={eps}, ".format(**self.__dict__) 47 | -------------------------------------------------------------------------------- /InternLM/internlm/model/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from flash_attn.ops.fused_dense import FusedDenseFunc 9 | from flash_attn.utils.distributed import ( 10 | all_gather_raw, 11 | all_reduce_raw, 12 | reduce_scatter_raw, 13 | ) 14 | from torch import Tensor 15 | from torch.cuda.amp import custom_bwd 16 | from torch.distributed import ProcessGroup 17 | 18 | from internlm.core.context import global_context as gpc 19 | from internlm.utils.logger import get_logger 20 | 21 | logger = get_logger(__file__) 22 | 23 | 24 | def _split(input_, parallel_mode, dim=-1): 25 | # skip if only one rank involved 26 | world_size = gpc.get_world_size(parallel_mode) 27 | if world_size == 1: 28 | return input_ 29 | 30 | # Split along last dimension. 
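# For example (illustrative shapes): with a tensor-parallel world size of 4 and dim=-1,
# a (batch, seq, 4096) activation is split into four contiguous (batch, seq, 1024) chunks
# and each rank keeps only the chunk at its own local rank index.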
31 | dim_size = input_.size(dim) 32 | assert dim_size % world_size == 0, ( 33 | f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " 34 | f"cannot split tensor evenly" 35 | ) 36 | 37 | tensor_list = torch.split(input_, dim_size // world_size, dim=dim) 38 | rank = gpc.get_local_rank(parallel_mode) 39 | output = tensor_list[rank].contiguous() 40 | 41 | return output 42 | 43 | 44 | def _gather(input_, parallel_mode, dim=-1): 45 | # skip if only one rank involved 46 | world_size = gpc.get_world_size(parallel_mode) 47 | if world_size == 1: 48 | return input_ 49 | 50 | # all gather 51 | rank = gpc.get_local_rank(parallel_mode) 52 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 53 | tensor_list[rank] = input_ 54 | group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) 55 | torch.distributed.all_gather(tensor_list, input_, group=group) 56 | 57 | # concat 58 | output = torch.cat(tensor_list, dim=dim).contiguous() 59 | 60 | return output 61 | 62 | 63 | class _GatherForwardSplitBackward(torch.autograd.Function): 64 | """Gather the input from model parallel region and concatenate. 65 | 66 | Args: 67 | input_: input matrix. 68 | parallel_mode: parallel mode. 69 | dim: dimension 70 | """ 71 | 72 | @staticmethod 73 | def symbolic(input_): 74 | return _gather(input_, parallel_mode=None) 75 | 76 | @staticmethod 77 | def forward(ctx, input_, parallel_mode, dim): 78 | ctx.mode = parallel_mode 79 | ctx.dim = dim 80 | return _gather(input_, parallel_mode, dim) 81 | 82 | @staticmethod 83 | def backward(ctx, grad_output): 84 | return _split(grad_output, ctx.mode, ctx.dim), None, None 85 | 86 | 87 | def gather_forward_split_backward(input_, parallel_mode, dim): 88 | return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) 89 | 90 | 91 | def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): 92 | assert my_input.dtype == grad_output.dtype 93 | grad_weight = torch.matmul(grad_output.t(), my_input) 94 | grad_bias = grad_output.sum(dim=0) if has_d_bias else None 95 | return grad_weight, grad_bias 96 | 97 | 98 | # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py 99 | class FusedDenseFuncTorch(FusedDenseFunc): 100 | """A custom PyTorch module extending FusedDenseFunc.""" 101 | 102 | @staticmethod 103 | @custom_bwd 104 | def backward(ctx, grad_output, *args): 105 | grad_output = grad_output.contiguous() 106 | if ctx.return_residual: 107 | (grad_input,) = args 108 | grad_input = grad_input.contiguous() 109 | process_group = ctx.process_group 110 | sequence_parallel = ctx.sequence_parallel 111 | if ctx.compute_weight_gradient: 112 | x, weight = ctx.saved_tensors 113 | if process_group is not None and sequence_parallel: 114 | total_x, handle_x = all_gather_raw(x, process_group, async_op=True) 115 | else: 116 | total_x = x 117 | else: 118 | (weight,) = ctx.saved_tensors 119 | total_x = None 120 | batch_shape = grad_output.shape[:-1] 121 | batch_dim = batch_shape.numel() 122 | grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) 123 | if ctx.needs_input_grad[0]: 124 | if not ctx.return_residual: 125 | grad_input = F.linear(grad_output, weight.t()) 126 | else: 127 | grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight) 128 | grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) 129 | if process_group is not None: 130 | reduce_fn = reduce_scatter_raw if sequence_parallel else 
all_reduce_raw 131 | grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) 132 | else: 133 | grad_input = None 134 | if ctx.needs_input_grad[1]: 135 | assert ctx.compute_weight_gradient 136 | if process_group is not None and sequence_parallel: 137 | handle_x.wait() 138 | # we remove the cuda independence, which is different from flash_attn. 139 | grad_weight, grad_bias = linear_bias_wgrad_torch( 140 | total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] 141 | ) 142 | else: 143 | grad_weight = None 144 | grad_bias = grad_output if ctx.needs_input_grad[2] else None 145 | if process_group is not None and ctx.needs_input_grad[0]: 146 | handle_grad_input.wait() 147 | return grad_input, grad_weight, grad_bias, None, None, None 148 | 149 | 150 | def fused_dense_func_torch( 151 | x: Tensor, 152 | weight: Tensor, 153 | bias: Optional[Tensor] = None, 154 | return_residual: bool = False, 155 | process_group: Optional[ProcessGroup] = None, 156 | sequence_parallel: bool = True, 157 | ): 158 | dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( 159 | x.dtype == torch.float32 and torch.is_autocast_enabled() 160 | ) 161 | if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: 162 | return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel) 163 | else: 164 | return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel) 165 | 166 | 167 | class _SplitForwardGatherBackward(torch.autograd.Function): 168 | """ 169 | Split the input and keep only the corresponding chuck to the rank. 170 | 171 | Args: 172 | input_: input matrix. 173 | parallel_mode: parallel mode. 174 | dim: dimension 175 | """ 176 | 177 | @staticmethod 178 | def symbolic(input_): 179 | return _split(input_, parallel_mode=None) 180 | 181 | @staticmethod 182 | def forward(ctx, input_, parallel_mode, dim): 183 | ctx.mode = parallel_mode 184 | ctx.dim = dim 185 | return _split(input_, parallel_mode, dim) 186 | 187 | @staticmethod 188 | def backward(ctx, grad_output): 189 | return _gather(grad_output, ctx.mode, ctx.dim), None, None 190 | 191 | 192 | def split_forward_gather_backward(input_, parallel_mode, dim): 193 | return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) 194 | 195 | 196 | def try_import_RMSNorm(): 197 | """ 198 | Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm 199 | 200 | """ 201 | try: 202 | from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm 203 | 204 | return RMSNorm 205 | except ModuleNotFoundError: 206 | logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. 
Please note this!") 207 | from internlm.model.norm import RMSNormTorch as RMSNorm 208 | 209 | return RMSNorm 210 | 211 | 212 | def try_import_LayerNorm(): 213 | """ 214 | Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm 215 | 216 | """ 217 | try: 218 | from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm 219 | 220 | return LayerNorm 221 | except ModuleNotFoundError: 222 | import torch.nn as nn 223 | 224 | return nn.LayerNorm -------------------------------------------------------------------------------- /InternLM/internlm/monitor/__init__.py: -------------------------------------------------------------------------------- 1 | from .alert import initialize_light_monitor, send_heartbeat 2 | from .monitor import initialize_monitor_manager, send_alert_message 3 | from .utils import set_env_var 4 | 5 | __all__ = [ 6 | "send_alert_message", 7 | "initialize_monitor_manager", 8 | "set_env_var", 9 | "initialize_light_monitor", 10 | "send_heartbeat", 11 | ] 12 | -------------------------------------------------------------------------------- /InternLM/internlm/monitor/alert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import re 5 | import time 6 | from typing import Dict 7 | 8 | import requests 9 | 10 | from internlm.utils.logger import get_logger 11 | 12 | logger = get_logger(__file__) 13 | 14 | 15 | def initialize_light_monitor(monitor_address: str = None): 16 | try: 17 | from uniscale_monitoring import init_monitor 18 | 19 | init_monitor(monitor_address) 20 | except Exception as e: 21 | logger.warning(f"init monitor meet error: {e}") 22 | 23 | 24 | def send_heartbeat(msg_type: str, msg: Dict): 25 | def nan2none(v): 26 | if isinstance(v, float) and math.isnan(v): 27 | return None 28 | return v 29 | 30 | try: 31 | from uniscale_monitoring import send_meta 32 | 33 | data = {} 34 | for k, v in msg.items(): 35 | if isinstance(v, Dict): 36 | for k1, v1 in v.items(): 37 | new_k = f"{k}_{k1}".split(" ")[0] 38 | new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) 39 | data[new_k] = nan2none(v1) 40 | else: 41 | new_k = k.split(" ")[0] 42 | new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k) 43 | data[new_k] = nan2none(v) 44 | 45 | if os.getenv("CLUSTER_NAME"): 46 | data.update({"cluster": os.getenv("CLUSTER_NAME")}) 47 | if msg_type == "train_metrics": 48 | data.update({"msg_type": "train_metrics"}) 49 | elif msg_type == "init_time": 50 | data.update({"msg_type": "init_time"}) 51 | elif msg_type == "stage_time": 52 | data.update({"msg_type": "stage_time"}) 53 | send_meta(data, timeout=0.1) 54 | except Exception as e: 55 | logger.warning(f"send heartbeat meet error: {e}") 56 | 57 | 58 | def send_feishu_msg_with_webhook(webhook: str, title: str, message: str): 59 | """ 60 | Use Feishu robot to send messages with the given webhook. 61 | 62 | Args: 63 | webhook (str): The webhook to be used to send message. 64 | title (str): The message title. 65 | message (str): The message body. 66 | 67 | Returns: 68 | The response from the request. Or catch the exception and return None. 69 | 70 | Raises: 71 | Exception: An exception rasied by the HTTP post request. 
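Example: a minimal sketch; the webhook URL below is a placeholder, not a real endpoint:
    send_feishu_msg_with_webhook(
        webhook="https://open.feishu.cn/open-apis/bot/v2/hook/<your-token>",
        title="Training alert",
        message="loss became NaN at step 1200",
    )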
72 | 73 | """ 74 | 75 | headers = {"Content-Type": "application/json;charset=utf-8"} 76 | msg_body = { 77 | "timestamp": int(time.time()), 78 | "msg_type": "post", 79 | "content": { 80 | "post": { 81 | "zh_cn": { 82 | "title": title, 83 | "content": [ 84 | [ 85 | { 86 | "tag": "text", 87 | "text": message, 88 | }, 89 | ], 90 | ], 91 | }, 92 | }, 93 | }, 94 | } 95 | 96 | try: 97 | res = requests.post(webhook, data=json.dumps(msg_body), headers=headers, timeout=30) 98 | res = res.json() 99 | print(f"Feishu webhook response: {res}") 100 | except Exception as err: # pylint: disable=W0703 101 | print(f"HTTP Post error: {err}") 102 | res = None 103 | 104 | return res 105 | -------------------------------------------------------------------------------- /InternLM/internlm/monitor/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | 5 | def now_time(): 6 | return datetime.now().strftime("%b%d_%H-%M-%S") 7 | 8 | 9 | def set_env_var(key, value): 10 | os.environ[str(key)] = str(value) 11 | 12 | 13 | def get_job_id(): 14 | job_id = "none" 15 | if os.getenv("SLURM_JOB_ID") is not None: 16 | job_id = os.getenv("SLURM_JOB_ID") 17 | elif os.getenv("K8S_WORKSPACE_ID") is not None: 18 | job_id = os.getenv("K8S_WORKSPACE_ID") 19 | 20 | return job_id 21 | 22 | 23 | def get_job_name(): 24 | job_name = f"unknown-{now_time()}" 25 | if os.getenv("JOB_NAME") is not None: 26 | job_name = os.getenv("JOB_NAME") 27 | 28 | return job_name 29 | 30 | 31 | def get_job_key(): 32 | return f"{get_job_id()}_{get_job_name()}" 33 | -------------------------------------------------------------------------------- /InternLM/internlm/solver/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .beta2_scheduler import Beta2Scheduler 5 | from .lr_scheduler import FineTuneCosineAnnealingWarmupLR 6 | from .optimizer import HybridZeroOptimizer 7 | 8 | __all__ = ["Beta2Scheduler", "FineTuneCosineAnnealingWarmupLR", "HybridZeroOptimizer"] 9 | -------------------------------------------------------------------------------- /InternLM/internlm/solver/beta2_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | 7 | class Beta2Scheduler: 8 | """ 9 | Beta2Scheduler 10 | """ 11 | 12 | def __init__(self, optimizer: torch.optim.Adam, init_beta2, c=0.8, cur_iter=-1): 13 | self.cur_iter = 0 if cur_iter == -1 else cur_iter 14 | self.init_beta2 = init_beta2 15 | self.c = c 16 | self.optimizer = optimizer 17 | assert isinstance( 18 | optimizer, (torch.optim.Adam, torch.optim.AdamW) 19 | ), "should use Adam optimzier, which has beta2" 20 | 21 | def step(self, cur_iter=None): 22 | if cur_iter is None: 23 | self.cur_iter += 1 24 | else: 25 | self.cur_iter = cur_iter 26 | 27 | new_beta2 = self.get_beta2() 28 | for pg in self.optimizer.param_groups: 29 | beta1, _ = pg["betas"] 30 | pg["betas"] = (beta1, new_beta2) 31 | 32 | def get_beta2(self): 33 | if self.c <= 0: 34 | return self.init_beta2 35 | scale = 1 - (1 / self.cur_iter**self.c) 36 | return max(self.init_beta2, scale) 37 | -------------------------------------------------------------------------------- /InternLM/internlm/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 
-*- 3 | 4 | import json 5 | 6 | from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | 9 | 10 | class WarmupScheduler(_LRScheduler): 11 | """Starts with a linear warmup lr schedule until it reaches N epochs then applies 12 | the specific scheduler (For tools: ReduceLROnPlateau). 13 | 14 | Args: 15 | optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. 16 | warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. 17 | after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. 18 | last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, 19 | the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. 20 | """ 21 | 22 | def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): 23 | self.warmup_epochs = int(warmup_epochs) 24 | self.after_scheduler = after_scheduler 25 | self.finished = False 26 | super().__init__(optimizer, last_epoch) 27 | 28 | def state_dict(self): 29 | state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} 30 | if isinstance(state_dict["after_scheduler"], (_LRScheduler, _CosineAnnealingLR)): 31 | state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ 32 | state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() 33 | del state_dict["after_scheduler"] 34 | else: 35 | raise NotImplementedError() 36 | return state_dict 37 | 38 | def load_state_dict(self, state_dict): 39 | # state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} 40 | for key in list(self.__dict__.keys()): 41 | if key in state_dict: 42 | self.__dict__[key] = state_dict[key] 43 | if isinstance(self.after_scheduler, (_LRScheduler, _CosineAnnealingLR)): 44 | assert type(self.after_scheduler).__name__ == state_dict["after_scheduler_type"] 45 | # state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() 46 | self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"]) 47 | # del state_dict['after_scheduler'] 48 | else: 49 | raise NotImplementedError() 50 | return state_dict 51 | 52 | def get_lr(self): 53 | if self.last_epoch >= self.warmup_epochs: 54 | if not self.finished: 55 | self.after_scheduler.base_lrs = self.base_lrs 56 | self.finished = True 57 | return self.after_scheduler.get_lr() 58 | 59 | return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] 60 | 61 | def step(self, epoch=None): 62 | if self.finished: 63 | if epoch is None: 64 | self.after_scheduler.step(None) 65 | self._last_lr = self.after_scheduler.get_last_lr() 66 | else: 67 | self.after_scheduler.step(epoch - self.warmup_epochs) 68 | self._last_lr = self.after_scheduler.get_last_lr() 69 | else: 70 | return super().step(epoch) 71 | 72 | 73 | class CosineAnnealingWarmupLR(WarmupScheduler): 74 | """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. 75 | 76 | Args: 77 | optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. 78 | total_steps (int): Number of total training steps. 79 | warmup_steps (int, optional): Number of warmup steps, defaults to 0. 80 | eta_min (int, optional): Minimum learning rate, defaults to 0. 81 | last_epoch (int, optional): The index of last epoch, defaults to -1. 
When last_epoch=-1, 82 | the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. 83 | """ 84 | 85 | def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0.0, last_epoch: int = -1): 86 | base_scheduler = _CosineAnnealingLR( 87 | optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch 88 | ) 89 | super().__init__(optimizer, warmup_steps, base_scheduler) 90 | 91 | 92 | class FineTuneCosineAnnealingWarmupLR(CosineAnnealingWarmupLR): 93 | """ 94 | FineTune Cosine Annealing Warmup LR. 95 | 96 | Args: 97 | optimizer: The optimizer object. 98 | total_steps (int): The number of total steps. 99 | init_steps (int): The number of init steps, default is 0. 100 | warmup_steps (int): The number of warm up steps, default is 0. 101 | eta_min (float): The minimum learning rate, default is 0.0. 102 | last_epoch: Last epoch, default is -1. 103 | 104 | """ 105 | 106 | def __init__( 107 | self, 108 | optimizer, 109 | total_steps: int, 110 | init_steps: int = 0, 111 | warmup_ratio: float = 0.0, 112 | eta_min: float = 0.0, 113 | last_epoch: int = -1, 114 | ): 115 | self._init_steps = init_steps 116 | self._warmup_steps = int(total_steps * warmup_ratio) 117 | # Use this value to calculate the lr of warmup, because warmup_epochs = init_steps + warmup_steps 118 | super().__init__(optimizer, total_steps, self._warmup_steps + init_steps, eta_min, last_epoch) 119 | 120 | def get_lr(self): 121 | if self.last_epoch >= self.warmup_epochs: 122 | if not self.finished: # pylint: disable=E0203 123 | # This True switch is to avoid warning when the warmup reaches the preset value switch 124 | self.after_scheduler._get_lr_called_within_step = True 125 | self.after_scheduler.base_lrs = self.base_lrs 126 | self.finished = True 127 | return self.after_scheduler.get_lr() 128 | 129 | elif self.last_epoch >= self._init_steps: 130 | return [(self.last_epoch + 1 - self._init_steps) / self._warmup_steps * lr for lr in self.base_lrs] 131 | else: 132 | return [0 for lr in self.base_lrs] 133 | 134 | def __str__(self): 135 | return json.dumps(self.state_dict(), indent=4, sort_keys=True) 136 | -------------------------------------------------------------------------------- /InternLM/internlm/solver/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff 5 | 6 | __all__ = ["HybridZeroOptimizer", "reload_zero_fp32_buff"] 7 | -------------------------------------------------------------------------------- /InternLM/internlm/solver/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from internlm.utils.logger import get_logger 5 | 6 | logger = get_logger(__file__) 7 | 8 | 9 | def partition_uniform(num_items, pipeline_parallel_size, num_chunks): 10 | assert ( 11 | num_items % num_chunks == 0 12 | ), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" 13 | 14 | parts = [[] for _ in range(pipeline_parallel_size)] 15 | partition_items = num_items // num_chunks 16 | for idx in range(num_chunks): 17 | base_idx = idx * partition_items 18 | chunk_size = partition_items // pipeline_parallel_size 19 | left = pipeline_parallel_size - partition_items % pipeline_parallel_size 20 | if chunk_size == 0: 21 | raise ValueError("Some 
nodes in Pipeline have no requests") 22 | 23 | for p in range(pipeline_parallel_size): 24 | st = base_idx 25 | base_idx += chunk_size + (p >= left) 26 | parts[p].append((st, base_idx)) 27 | 28 | indexes = [] 29 | for _parts in parts: 30 | for s, e in _parts: 31 | indexes.extend(list(range(s, e))) 32 | assert len(indexes) == len(set(indexes)), indexes # should have no duplicates 33 | assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected 34 | return parts 35 | -------------------------------------------------------------------------------- /InternLM/internlm/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .training_internlm import ( 2 | get_train_data_loader, 3 | get_validation_data_loader, 4 | initialize_llm_profile, 5 | initialize_model, 6 | initialize_teacher, 7 | initialize_optimizer, 8 | load_new_batch, 9 | load_new_batch_stop, 10 | record_current_batch_training_metrics, 11 | ) 12 | 13 | __all__ = [ 14 | "get_train_data_loader", 15 | "get_validation_data_loader", 16 | "initialize_llm_profile", 17 | "initialize_model", 18 | "initialize_teacher", 19 | "initialize_optimizer", 20 | "load_new_batch", 21 | "load_new_batch_stop", 22 | "record_current_batch_training_metrics", 23 | ] 24 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/internlm/utils/__init__.py -------------------------------------------------------------------------------- /InternLM/internlm/utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import bisect 5 | import inspect 6 | import os 7 | import random 8 | from contextlib import contextmanager 9 | from datetime import datetime 10 | from typing import Union 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import internlm 16 | 17 | CURRENT_TIME = None 18 | 19 | 20 | def parse_args(): 21 | parser = internlm.get_default_parser() 22 | args = parser.parse_args() 23 | 24 | return args 25 | 26 | 27 | def get_master_node(): 28 | import subprocess 29 | 30 | if os.getenv("SLURM_JOB_ID") is None: 31 | raise RuntimeError("get_master_node can only used in Slurm launch!") 32 | result = subprocess.check_output('scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True) 33 | result = result.decode("utf8").strip() 34 | return result 35 | 36 | 37 | def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: 38 | if torch.is_tensor(norm) and norm.device.type != "cuda": 39 | norm = norm.to(torch.cuda.current_device()) 40 | return norm 41 | 42 | 43 | def _move_tensor(element): 44 | if not torch.is_tensor(element): 45 | # we expecte the data type if a list of dictionaries 46 | for item in element: 47 | if isinstance(item, dict): 48 | for key, value in item.items(): 49 | assert not value.is_cuda, "elements are already on devices." 50 | item[key] = value.to(get_current_device()).detach() 51 | elif isinstance(item, list): 52 | for index, value in enumerate(item): 53 | assert not value.is_cuda, "elements are already on devices." 
54 | item[index] = value.to(get_current_device()).detach() 55 | elif torch.is_tensor(item): 56 | if not item.is_cuda: 57 | item = item.to(get_current_device()).detach() 58 | else: 59 | assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}" 60 | if not element.is_cuda: 61 | element = element.to(get_current_device()).detach() 62 | return element 63 | 64 | 65 | def move_to_device(data): 66 | if isinstance(data, torch.Tensor): 67 | data = data.to(get_current_device()) 68 | elif isinstance(data, (list, tuple)): 69 | data_to_return = [] 70 | for element in data: 71 | if isinstance(element, dict): 72 | data_to_return.append({k: _move_tensor(v) for k, v in element.items()}) 73 | else: 74 | data_to_return.append(_move_tensor(element)) 75 | data = data_to_return 76 | elif isinstance(data, dict): 77 | data = {k: _move_tensor(v) for k, v in data.items()} 78 | else: 79 | raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") 80 | return data 81 | 82 | 83 | def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor: 84 | if isinstance(norm, float): 85 | norm = torch.Tensor([norm]) 86 | if move_to_cuda: 87 | norm = norm.to(torch.cuda.current_device()) 88 | return norm 89 | 90 | 91 | def get_current_device() -> torch.device: 92 | """ 93 | Returns currently selected device (gpu/cpu). 94 | If cuda available, return gpu, otherwise return cpu. 95 | """ 96 | if torch.cuda.is_available(): 97 | return torch.device(f"cuda:{torch.cuda.current_device()}") 98 | else: 99 | return torch.device("cpu") 100 | 101 | 102 | def get_batch_size(data): 103 | if isinstance(data, torch.Tensor): 104 | return data.size(0) 105 | elif isinstance(data, (list, tuple)): 106 | if isinstance(data[0], dict): 107 | return data[0][list(data[0].keys())[0]].size(0) 108 | return data[0].size(0) 109 | elif isinstance(data, dict): 110 | return data[list(data.keys())[0]].size(0) 111 | 112 | 113 | def filter_kwargs(func, kwargs): 114 | sig = inspect.signature(func) 115 | return {k: v for k, v in kwargs.items() if k in sig.parameters} 116 | 117 | 118 | def launch_time(): 119 | global CURRENT_TIME 120 | if not CURRENT_TIME: 121 | CURRENT_TIME = datetime.now().strftime("%b%d_%H-%M-%S") 122 | return CURRENT_TIME 123 | 124 | 125 | def set_random_seed(seed): 126 | """Set random seed for reproducability.""" 127 | # It is recommended to use this only when inference. 128 | if seed is not None: 129 | assert seed > 0 130 | random.seed(seed) 131 | np.random.seed(seed) 132 | torch.manual_seed(seed) 133 | torch.cuda.manual_seed(seed) 134 | # if you are using multi-GPU. 135 | torch.cuda.manual_seed_all(seed) 136 | 137 | 138 | @contextmanager 139 | def conditional_context(context_manager, enable=True): 140 | if enable: 141 | with context_manager: 142 | yield 143 | else: 144 | yield 145 | 146 | 147 | class BatchSkipper: 148 | """ 149 | BatchSkipper is used to determine whether to skip the current batch_idx. 
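Example: ``skip_batches`` is a comma-separated list of indices and inclusive ranges, e.g.
``BatchSkipper("10-12,20")`` skips batch counts 10, 11, 12 and 20, so ``skipper(11)``
returns ``True`` while ``skipper(15)`` returns ``False``.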
150 | """ 151 | 152 | def __init__(self, skip_batches): 153 | if skip_batches == "": 154 | pass 155 | intervals = skip_batches.split(",") 156 | spans = [] 157 | if skip_batches != "": 158 | for interval in intervals: 159 | if "-" in interval: 160 | start, end = map(int, interval.split("-")) 161 | else: 162 | start, end = int(interval), int(interval) 163 | if spans: 164 | assert spans[-1] <= start 165 | spans.extend((start, end + 1)) 166 | self.spans = spans 167 | 168 | def __call__(self, batch_count): 169 | index = bisect.bisect_right(self.spans, batch_count) 170 | return index % 2 == 1 171 | 172 | 173 | class SingletonMeta(type): 174 | """ 175 | Singleton Meta. 176 | """ 177 | 178 | _instances = {} 179 | 180 | def __call__(cls, *args, **kwargs): 181 | if cls not in cls._instances: 182 | cls._instances[cls] = super().__call__(*args, **kwargs) 183 | else: 184 | assert ( 185 | len(args) == 0 and len(kwargs) == 0 186 | ), f"{cls.__name__} is a singleton class and a instance has been created." 187 | return cls._instances[cls] 188 | 189 | 190 | def get_megatron_flops( 191 | elapsed_time_per_iter, 192 | checkpoint=False, 193 | seq_len=2048, 194 | hidden_size=12, 195 | num_layers=32, 196 | vocab_size=12, 197 | global_batch_size=4, 198 | global_world_size=1, 199 | mlp_ratio=4, 200 | use_swiglu=True, 201 | ): 202 | """ 203 | Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf 204 | """ 205 | 206 | checkpoint_activations_factor = 4 if checkpoint else 3 207 | 208 | if use_swiglu: 209 | mlp_ratio = mlp_ratio * 3 / 2 210 | 211 | flops_per_iteration = ( 212 | checkpoint_activations_factor 213 | * ( 214 | (8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2 215 | + 4 * global_batch_size * seq_len**2 * hidden_size 216 | ) 217 | ) * num_layers + 6 * global_batch_size * seq_len * hidden_size * vocab_size 218 | 219 | tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12)) 220 | return tflops 221 | 222 | 223 | class DummyProfile: 224 | """ 225 | Dummy Profile. 
226 | """ 227 | 228 | def __init__(self, *args, **kwargs) -> None: 229 | pass 230 | 231 | def __enter__(self): 232 | return self 233 | 234 | def __exit__(self, a, b, c): 235 | pass 236 | 237 | def step(self): 238 | pass 239 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from tqdm import tqdm 6 | 7 | from internlm.core.context import ParallelMode 8 | from internlm.core.context import global_context as gpc 9 | from internlm.core.scheduler import SchedulerMetricHook 10 | from internlm.model.metrics import AccPerplex 11 | 12 | 13 | @contextmanager 14 | def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list): 15 | if not gpc.is_using_pp(): 16 | prev_data_process_func = trainer.schedule.data_process_func 17 | prev_grad_accum_size = trainer.schedule._grad_accum_size 18 | prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size 19 | prev_metric_hooks = trainer.schedule._hooks 20 | try: 21 | trainer.schedule.data_process_func = None 22 | trainer.schedule._grad_accum_size = grad_accum_size 23 | trainer.schedule._grad_accum_batch_size = grad_accum_batch_size 24 | trainer.schedule._hooks = metric_hook_list 25 | yield 26 | finally: 27 | trainer.schedule.data_process_func = prev_data_process_func 28 | trainer.schedule._grad_accum_size = prev_grad_accum_size 29 | trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size 30 | trainer.schedule._hooks = prev_metric_hooks 31 | 32 | 33 | @contextmanager 34 | def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list): 35 | if gpc.is_using_pp(): 36 | pre_data_process_func = trainer.schedule.data_process_func 37 | prev_num_microbatches = trainer.schedule.num_microbatches 38 | prev_tensor_shape = trainer.schedule.tensor_shape 39 | prev_metric_hooks = trainer.schedule._hooks 40 | try: 41 | trainer.schedule.data_process_func = None 42 | trainer.schedule.num_microbatches = num_microbatches 43 | trainer.schedule.tensor_shape = tensor_shape 44 | trainer.schedule._hooks = metric_hook_list 45 | yield 46 | finally: 47 | trainer.schedule.data_process_func = pre_data_process_func 48 | trainer.schedule.num_microbatches = prev_num_microbatches 49 | trainer.schedule.tensor_shape = prev_tensor_shape 50 | trainer.schedule._hooks = prev_metric_hooks 51 | 52 | 53 | @contextmanager 54 | def switch_sequence_parallel_mode(): 55 | prev_mode = gpc.config.parallel.sequence_parallel 56 | try: 57 | gpc.config.parallel.sequence_parallel = False 58 | yield 59 | finally: 60 | gpc.config.parallel.sequence_parallel = prev_mode 61 | 62 | 63 | def evaluate_on_val_dls( 64 | trainer, 65 | val_dls, 66 | writer, 67 | logger, 68 | step_count, 69 | update_panel: bool = False, 70 | streaming: bool = False, 71 | ): 72 | with switch_sequence_parallel_mode(): 73 | torch.cuda.empty_cache() 74 | trainer.eval() 75 | verbose = gpc.is_rank_for_log() 76 | data_cfg = gpc.config.data 77 | 78 | for val_name, val_dl in val_dls.items(): 79 | if not streaming and len(val_dl) == 0 and verbose: 80 | logger.info(f"Validation dataset: {val_name} is empty") 81 | continue 82 | 83 | val_metric = AccPerplex( 84 | device=torch.cuda.current_device(), 85 | tp_pg=gpc.get_group(ParallelMode.TENSOR), 86 | dp_pg=gpc.get_group(ParallelMode.DATA), 87 | ) 88 | 
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric) 89 | 90 | val_loss = 0 91 | val_idx = -1 92 | for val_idx, batch in tqdm( 93 | enumerate(val_dl), 94 | desc="Val.", 95 | total=len(val_dl) if not streaming else None, 96 | position=1, 97 | disable=not verbose, 98 | leave=False, 99 | ): 100 | with torch.inference_mode(): 101 | if gpc.is_using_pp(): 102 | total_val_bsz = len(batch[1]) 103 | assert total_val_bsz % data_cfg.micro_bsz == 0 104 | num_microbatches = total_val_bsz // data_cfg.micro_bsz 105 | tensor_shape = torch.Size( 106 | [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] 107 | ) 108 | 109 | with switch_evaluation_pipeline_scheduler( 110 | trainer=trainer, 111 | num_microbatches=num_microbatches, 112 | tensor_shape=tensor_shape, 113 | metric_hook_list=[val_sche_metric_hook], 114 | ): 115 | _, _, loss = trainer.execute_schedule( 116 | batch, forward_only=True, return_loss=True, return_output_label=False 117 | ) 118 | else: 119 | total_val_bsz = len(batch[1]) 120 | assert total_val_bsz % data_cfg.micro_bsz == 0 121 | grad_accum_size = total_val_bsz // data_cfg.micro_bsz 122 | grad_accum_batch_size = data_cfg.micro_bsz 123 | with switch_evaluation_no_pipeline_scheduler( 124 | trainer=trainer, 125 | grad_accum_size=grad_accum_size, 126 | grad_accum_batch_size=grad_accum_batch_size, 127 | metric_hook_list=[val_sche_metric_hook], 128 | ): 129 | _, _, loss = trainer.execute_schedule( 130 | batch, forward_only=True, return_loss=True, return_output_label=False 131 | ) 132 | if verbose: 133 | if isinstance(loss, dict): 134 | loss = sum(loss.values()) 135 | val_loss += loss.item() 136 | 137 | assert val_idx != -1 138 | dist.barrier() 139 | 140 | val_res = val_metric.get_metric() 141 | if verbose and (streaming or len(val_dl) != 0): 142 | val_loss = val_loss / (val_idx + 1 + 1e-6) 143 | infos = { 144 | "step": step_count, 145 | f"val/{val_name}_loss": val_loss, 146 | f"val/{val_name}_acc": val_res["acc"], 147 | f"val/{val_name}_plex": val_res["perplexity"], 148 | } 149 | 150 | for key, value in infos.items(): 151 | writer.add_scalar(key=key, value=value, step=step_count) 152 | 153 | if update_panel: 154 | logger.info( 155 | f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]), 156 | extra={ 157 | "step": step_count, 158 | "val_loss": val_loss, 159 | "val_acc": val_res["acc"], 160 | "val_perplexity": val_res["perplexity"], 161 | }, 162 | ) 163 | else: 164 | logger.info( 165 | f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]) 166 | ) 167 | 168 | trainer.train() 169 | torch.cuda.empty_cache() 170 | dist.barrier() 171 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import logging 5 | import os 6 | 7 | LOGGER_NAME = "internlm" 8 | LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s" 9 | LOGGER_LEVEL = "info" 10 | LOGGER_LEVEL_CHOICES = ["debug", "info", "warning", "error", "critical"] 11 | LOGGER_LEVEL_HELP = ( 12 | "The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'" 13 | ) 14 | 15 | uniscale_logger = None 16 | 17 | 18 | def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger: 19 | """Configure the logger that is used for 
uniscale framework. 20 | 21 | Args: 22 | logger_name (str): used to create or get the correspoding logger in 23 | getLogger call. It will be "internlm" by default. 24 | logging_level (str, optional): Logging level in string or logging enum. 25 | 26 | Returns: 27 | logger (logging.Logger): the created or modified logger. 28 | 29 | """ 30 | 31 | if uniscale_logger is not None: 32 | return uniscale_logger 33 | 34 | logger = logging.getLogger(logger_name) 35 | 36 | if logging_level not in LOGGER_LEVEL_CHOICES: 37 | logging_level = LOGGER_LEVEL 38 | print(LOGGER_LEVEL_HELP) 39 | 40 | logging_level = logging.getLevelName(logging_level.upper()) 41 | 42 | handler = logging.StreamHandler() 43 | handler.setLevel(logging_level) 44 | logger.setLevel(logging_level) 45 | handler.setFormatter(logging.Formatter(LOGGER_FORMAT)) 46 | logger.addHandler(handler) 47 | 48 | return logger 49 | 50 | 51 | def initialize_uniscale_logger( 52 | job_name: str = None, 53 | launch_time: str = None, 54 | file_name: str = None, 55 | name: str = LOGGER_NAME, 56 | level: str = LOGGER_LEVEL, 57 | file_path: str = None, 58 | is_std: bool = True, 59 | ): 60 | """ 61 | Initialize uniscale logger. 62 | 63 | Args: 64 | job_name (str): The name of training job, defaults to None. 65 | launch_time (str): The launch time of training job, defaults to None. 66 | file_name (str): The log file name, defaults to None. 67 | name (str): The logger name, defaults to "internlm". 68 | level (str): The log level, defaults to "info". 69 | file_path (str): The log file path, defaults to None. 70 | is_std (bool): Whether to output to console, defaults to True. 71 | 72 | Returns: 73 | Uniscale logger instance. 74 | """ 75 | 76 | try: 77 | from uniscale_monitoring import get_logger as get_uniscale_logger 78 | except ImportError: 79 | print("Failed to import module uniscale_monitoring. Use default python logger.") 80 | return None 81 | 82 | if not file_path: 83 | assert ( 84 | job_name and launch_time and file_name 85 | ), "If file_path is None, job_name, launch_time and file_name must be setted." 
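        # Compose the default log path as RUN/<job_name>/<launch_time>/logs/<file_name>.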
86 | log_file_name = file_name 87 | log_folder = os.path.join("RUN", job_name, launch_time, "logs") 88 | log_dir = os.path.join(log_folder, log_file_name) 89 | file_path = log_dir 90 | 91 | logger = get_uniscale_logger(name=name, level=level, filename=file_path, is_std=is_std) 92 | if isinstance(logger, (list, tuple)): 93 | logger = list(logger)[0] 94 | 95 | global uniscale_logger 96 | uniscale_logger = logger 97 | 98 | return logger 99 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/megatron_timers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import time 5 | 6 | import torch 7 | 8 | 9 | class _Timer: 10 | """Timer.""" 11 | 12 | def __init__(self, name): 13 | self.name_ = name 14 | self.elapsed_ = 0.0 15 | self.started_ = False 16 | self.start_time = time.time() 17 | self.stream = torch.cuda.current_stream() 18 | 19 | def start(self, reset_all=True): 20 | """Start the timer.""" 21 | # need to reset all timers in a new batch 22 | if self.name_ == "one-batch" and reset_all is True: 23 | megatron_timer.reset() 24 | 25 | assert not self.started_, "timer has already been started" 26 | self.stream.synchronize() 27 | self.start_time = time.time() 28 | self.started_ = True 29 | 30 | def stop(self): 31 | """Stop the timer.""" 32 | assert self.started_, "timer is not started" 33 | self.stream.synchronize() 34 | self.elapsed_ += time.time() - self.start_time 35 | self.started_ = False 36 | 37 | def reset(self): 38 | """Reset timer.""" 39 | self.elapsed_ = 0.0 40 | self.started_ = False 41 | 42 | def elapsed(self, reset=True): 43 | """Calculate the elapsed time.""" 44 | started_ = self.started_ 45 | # If the timing in progress, end it first. 46 | if self.started_: 47 | self.stop() 48 | # Get the elapsed time. 49 | elapsed_ = self.elapsed_ 50 | # Reset the elapsed time 51 | if reset: 52 | self.reset() 53 | # If timing was in progress, set it back. 
54 | if started_: 55 | self.start(reset_all=False) 56 | return elapsed_ 57 | 58 | 59 | class Timers: 60 | """Group of timers.""" 61 | 62 | def __init__(self): 63 | self.timers = {} 64 | self.hist = {} 65 | self.names = [] 66 | self.times = [] 67 | 68 | def __call__(self, name): 69 | if name not in self.timers: 70 | self.timers[name] = _Timer(name) 71 | return self.timers[name] 72 | 73 | def store_last_timers(self): 74 | """Store timers to two list""" 75 | self.names = [] 76 | self.times = [] 77 | for key, value in self.timers.items(): 78 | senconds = round(float(value.elapsed(reset=False)), 4) 79 | self.names.append(key) 80 | self.times.append(senconds) 81 | if key not in self.hist: 82 | self.hist[key] = [] 83 | self.hist[key].append(senconds) 84 | if len(self.hist[key]) > 10: 85 | self.hist[key].pop(0) 86 | 87 | def write(self, names, writer, iteration, normalizer=1.0, reset=False): 88 | """Write timers to a tensorboard writer""" 89 | # currently when using add_scalars, 90 | # torch.utils.add_scalars makes each timer its own run, which 91 | # polutes the runs list, so we just add each as a scalar 92 | assert normalizer > 0.0 93 | for name in names: 94 | if name in self.timers: 95 | value = self.timers[name].elapsed(reset=reset) / normalizer 96 | writer.add_scalar(f"time/{name}-time", value, iteration) 97 | 98 | def log(self, names, logger, normalizer=1.0, reset=True): 99 | """Log a group of timers.""" 100 | assert normalizer > 0.0 101 | string = "" 102 | for name in names: 103 | if name in self.timers: 104 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 105 | string += " | {}: {:.2f}".format(name, elapsed_time) 106 | if not len(string): # pylint: disable=C1802 107 | return 108 | string = "time (ms)" + string 109 | 110 | logger.info(string) 111 | return string 112 | 113 | def debug(self, names, logger, normalizer=1.0, reset=True): 114 | """Log a group of timers.""" 115 | assert normalizer > 0.0 116 | string = "" 117 | for name in names: 118 | if name in self.timers: 119 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 120 | string += " | {}: {:.2f}".format(name, elapsed_time) 121 | if not len(string): # pylint: disable=C1802 122 | return 123 | string = "time (ms)" + string 124 | 125 | logger.debug(string) 126 | return string 127 | 128 | def reset(self): 129 | for _, t in self.timers.items(): 130 | t.reset() 131 | 132 | 133 | megatron_timer = Timers() 134 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch.distributed as dist 5 | 6 | from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode 7 | from internlm.core.context import global_context as gpc 8 | 9 | 10 | def is_model_parallel_parameter(p): 11 | return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) 12 | 13 | 14 | def sync_model_param(model, parallel_mode): 15 | r"""Make sure data parameters are consistent during Data Parallel Mode. 16 | 17 | Args: 18 | model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 19 | parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked. 
20 | """ 21 | if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: 22 | for param in model.parameters(): 23 | ranks = gpc.get_ranks_in_group(parallel_mode) 24 | dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) 25 | 26 | 27 | def sync_model_param_within_tp(model): 28 | r"""This function is changed from colossalai, which is ``sync_model_param``. 29 | 30 | We modified this function to make sure it only sync parameters within tensor parallelism 31 | but they are not splitted by tensor parallelism. 32 | This function is used to make sure parameters that are not splitted by tensor parallelism 33 | are the same across each tensor parallelism. 34 | For tools, parameters like RMSNorm, LayerNorm... 35 | 36 | Args: 37 | model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 38 | """ 39 | parallel_mode = ParallelMode.TENSOR 40 | if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: 41 | for param in model.parameters(): 42 | if not is_model_parallel_parameter(param): 43 | ranks = gpc.get_ranks_in_group(parallel_mode) 44 | dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) 45 | 46 | 47 | def is_no_pp_or_last_stage(): 48 | return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE) 49 | 50 | 51 | def get_parallel_log_file_name(): 52 | if gpc.is_rank_for_log(): 53 | fn_prefix = "main_" # Indicates a rank with more output information 54 | else: 55 | fn_prefix = "" 56 | 57 | log_file_name = ( 58 | f"{fn_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_" 59 | f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}" 60 | ) 61 | return log_file_name 62 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/registry.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | class Registry: 6 | """This is a registry class used to register classes and modules so that a universal 7 | object builder can be enabled. 8 | 9 | Args: 10 | name (str): The name of the registry. 11 | """ 12 | 13 | def __init__(self, name: str): 14 | self._name = name 15 | self._registry = dict() 16 | 17 | @property 18 | def name(self): 19 | return self._name 20 | 21 | def register_module(self, module_name: str): 22 | """Registers a module represented in `module_class`. 23 | 24 | Args: 25 | module_name (str): The name of module to be registered. 26 | Returns: 27 | function: The module to be registered, so as to use it normally if via importing. 28 | Raises: 29 | AssertionError: Raises an AssertionError if the module has already been registered before. 30 | """ 31 | 32 | assert module_name not in self._registry, f"{module_name} not found in {self.name}" 33 | 34 | def decorator_wrapper(original_func): 35 | self._registry[module_name] = original_func 36 | return original_func 37 | 38 | return decorator_wrapper 39 | 40 | def get_module(self, module_name: str): 41 | """Retrieves a module with name `module_name` and returns the module if it has 42 | already been registered before. 43 | 44 | Args: 45 | module_name (str): The name of the module to be retrieved. 46 | Returns: 47 | :class:`object`: The retrieved module or None. 48 | Raises: 49 | NameError: Raises a NameError if the module to be retrieved has neither been 50 | registered directly nor as third party modules before. 
51 | """ 52 | if module_name in self._registry: 53 | return self._registry[module_name] 54 | raise NameError(f"Module {module_name} not found in the registry {self.name}") 55 | 56 | def has(self, module_name: str): 57 | """Searches for a module with name `module_name` and returns a boolean value indicating 58 | whether the module has been registered directly or as third party modules before. 59 | 60 | Args: 61 | module_name (str): The name of the module to be searched for. 62 | Returns: 63 | bool: A boolean value indicating whether the module has been registered directly or 64 | as third party modules before. 65 | """ 66 | found_flag = module_name in self._registry 67 | 68 | return found_flag 69 | 70 | 71 | MODEL_INITIALIZER = Registry("model_initializer") 72 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/timeout.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import signal 4 | import socket 5 | import traceback 6 | from functools import wraps 7 | 8 | from internlm.utils.logger import get_logger 9 | 10 | logger = get_logger(__file__) 11 | 12 | 13 | class Timeout: 14 | """Timer to execute code 15 | 16 | Adapted from https://github.com/reasoning-machines/pal 17 | 18 | Args: 19 | seconds (float): The maximum seconds to execute code 20 | error_message (str) 21 | """ 22 | 23 | def __init__(self, seconds=1, error_message="Timeout"): 24 | self.seconds = seconds 25 | self.error_message = error_message 26 | 27 | def timeout_handler(self, signum, frame): 28 | raise TimeoutError(self.error_message) 29 | 30 | def __enter__(self): 31 | signal.signal(signal.SIGALRM, self.timeout_handler) 32 | signal.alarm(self.seconds) 33 | 34 | def __exit__(self, error_type, value, traceback): 35 | signal.alarm(0) 36 | 37 | 38 | ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None) 39 | 40 | 41 | timeout_threshold_dict = { 42 | "initialize_distributed_env": 120, 43 | "nopp_forward_backward_step": 360, 44 | "initialize_model": 10, 45 | "initialize_optimizer": 20, 46 | "optim_step": 30, 47 | "get_train_data_loader": 600, 48 | "get_validation_data_loader": 60, 49 | "load_new_batch": 10, 50 | "record_current_batch_training_metrics": 10, 51 | "save_checkpoint": 1200, 52 | "interleaved_forward_backward_step": 600, 53 | "nointerleaved_forward_backward_step": 600, 54 | } 55 | 56 | if ENABLE_TIMEOUT is not None: 57 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" 58 | LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60)))) 59 | else: 60 | timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0) 61 | LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800) 62 | 63 | 64 | def try_get_gpc_rank(): 65 | try: 66 | from internlm.core.context import global_context as gpc 67 | 68 | rank = gpc.get_global_rank() 69 | except: # noqa # pylint: disable=bare-except 70 | rank = "unknown" 71 | 72 | return f"host-{socket.gethostname()}-rank-{rank}" 73 | 74 | 75 | def llm_timeout(seconds=0, func_name=None): 76 | """timeout decorator, Note that this decorator cannot be reentrant, 77 | otherwise the signal will be reset. 78 | 79 | Args: 80 | seconds (int, optional): timeout threshold. Defaults to 300. 81 | func_name (str, optional): the func who is been waited to timeout. 
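            If the function name appears in ``timeout_threshold_dict``, that per-function threshold overrides ``seconds``.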
82 | """ 83 | 84 | def decorator(func): 85 | nonlocal func_name 86 | if func_name is None: 87 | func_name = func.__name__ 88 | 89 | @wraps(func) 90 | def wrapper(*args, **kwargs): 91 | def _handle_timeout(signum, frame): 92 | raise TimeoutError 93 | 94 | nonlocal seconds 95 | seconds = timeout_threshold_dict.get(func_name, seconds) 96 | 97 | if seconds > 0: 98 | signal.signal(signal.SIGALRM, _handle_timeout) 99 | signal.alarm(seconds) 100 | 101 | try: 102 | result = func(*args, **kwargs) 103 | except TimeoutError as e: 104 | logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}") 105 | raise e 106 | finally: 107 | signal.alarm(0) 108 | 109 | return result 110 | 111 | return wrapper 112 | 113 | return decorator 114 | -------------------------------------------------------------------------------- /InternLM/internlm/utils/writer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | import sys 5 | import traceback 6 | from functools import partial 7 | 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from internlm.core.context import global_context as gpc 12 | 13 | 14 | def tb_save_run_info(writer, config_lines, global_step=0): 15 | writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step) 16 | lines = [] 17 | for line in config_lines: 18 | if line.strip().startswith("#"): 19 | continue 20 | lines.append(line) 21 | writer.add_text(tag="config", text_string="\n".join(lines), global_step=global_step) 22 | 23 | 24 | def init_tb_writer( 25 | job_name: str, 26 | launch_time: str, 27 | file_name: str, 28 | tensorboard_folder: str, 29 | resume_tb_folder: str, 30 | step_count: int, 31 | config: str, 32 | logger: logging.Logger, 33 | ): 34 | tb_log_file_name = file_name 35 | if not tensorboard_folder: 36 | tb_folder = os.path.join(job_name, launch_time, "tensorboards") 37 | else: 38 | tb_folder = tensorboard_folder 39 | 40 | if gpc.get_global_rank() == 0: 41 | # If we don't load ckpt, 'resume_tb_folder' is set as the tensorboard 42 | # dir of the last task by 'make_launch_script.sh'. 43 | # If we load ckpt, 'resume_tb_folder' will be overwritten as the 44 | # reloaded 'train_state.resume_tb_folder'.s 45 | if resume_tb_folder is not None: 46 | assert len(resume_tb_folder) > 0 and resume_tb_folder != "/" 47 | if not os.path.exists(resume_tb_folder): 48 | logger.error( 49 | f"Can't found resume_tb_folder{resume_tb_folder}, \ 50 | please make sure this folder is located at local file system." 51 | ) 52 | else: 53 | logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}... 
") 54 | os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/") 55 | os.system(f"chmod -R +w {tb_folder}/") 56 | else: 57 | logger.info(f"Login tensorboard logs to: {tb_folder}") 58 | 59 | tb_logdir = os.path.join(tb_folder, tb_log_file_name) 60 | writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3) 61 | writer.add_text(tag="job_name", text_string=job_name, global_step=step_count) 62 | writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count) 63 | 64 | torch.distributed.broadcast_object_list([tb_folder], src=0) 65 | else: 66 | objects = [None] 67 | torch.distributed.broadcast_object_list(objects, src=0) 68 | tb_folder = objects[0] 69 | tb_logdir = os.path.join(tb_folder, tb_log_file_name) 70 | writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3) 71 | 72 | if gpc.is_rank_for_log(): 73 | tb_save_run_info( 74 | writer=writer, 75 | config_lines=config, 76 | global_step=step_count, 77 | ) 78 | 79 | writer.add_text( 80 | tag=f"mapping_{tb_log_file_name}", 81 | text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} device={torch.cuda.current_device()}", 82 | global_step=step_count, 83 | ) 84 | writer.add_scaler = partial(writer.add_scalar, new_style=True) 85 | 86 | return writer, tb_logdir 87 | 88 | 89 | class Writer: 90 | """ 91 | Customed writer based on tensorboard for recording training metrics. 92 | 93 | Args: 94 | job_name (str): The name of training job, defaults to None. 95 | launch_time (str): A string representing the launch time of the training. 96 | file_name (str): The log file name, defaults to None. 97 | tensorboard_folder (str): A string representing the folder for saving tensorboard logs. 98 | resume_tb_folder (str): A string representing the folder for resuming tensorboard logs. 99 | step_count (int): An integer representing the step count of the training. 100 | config (str): A string representing the configuration of the training. 101 | logger (logging.Logger): A logging.Logger object for logging information during training. 102 | enable_tb (bool): A boolean indicating whether to enable the tensorboard writer. 
103 | 104 | """ 105 | 106 | def __init__( 107 | self, 108 | job_name: str = None, 109 | launch_time: str = None, 110 | file_name: str = None, 111 | tensorboard_folder: str = None, 112 | resume_tb_folder: str = None, 113 | step_count: int = 0, 114 | config: str = None, 115 | logger: logging.Logger = None, 116 | enable_tb: bool = True, 117 | ) -> None: 118 | self.enable_tb = enable_tb 119 | self.tb_writer, self.tb_logdir = init_tb_writer( 120 | job_name=job_name, 121 | launch_time=launch_time, 122 | file_name=file_name, 123 | tensorboard_folder=tensorboard_folder, 124 | resume_tb_folder=resume_tb_folder, 125 | step_count=step_count, 126 | config=config, 127 | logger=logger, 128 | ) 129 | 130 | def add_scalar(self, key, value, step): 131 | try: 132 | if self.enable_tb and self.tb_writer is not None: 133 | self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step) 134 | except Exception: 135 | traceback.print_exc() 136 | 137 | def add_scalars(self, key, value, step): 138 | try: 139 | assert isinstance(value, dict) 140 | if self.enable_tb and self.tb_writer is not None: 141 | self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step) 142 | except Exception: 143 | traceback.print_exc() 144 | 145 | def add_text(self, key, value, step): 146 | try: 147 | if self.enable_tb and self.tb_writer is not None: 148 | self.tb_writer.add_text(tag=key, text_string=value, global_step=step) 149 | except Exception: 150 | traceback.print_exc() 151 | -------------------------------------------------------------------------------- /InternLM/requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | transformers<4.30.0 2 | sentencepiece 3 | numpy 4 | tqdm 5 | psutil 6 | packaging 7 | pre-commit 8 | ninja 9 | gputil 10 | pytest 11 | packaging 12 | boto3 13 | botocore 14 | torch-scatter 15 | pyecharts 16 | -f https://data.pyg.org/whl/torch-1.13.1+cu117.html -------------------------------------------------------------------------------- /InternLM/requirements/torch.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==1.13.1+cu117 3 | torchvision==0.14.1+cu117 4 | torchaudio==0.13.1 -------------------------------------------------------------------------------- /InternLM/tools/convert2hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import re 6 | import tempfile 7 | import sys 8 | 9 | import torch 10 | 11 | from model_hf.modeling_internlm import InternLMConfig, InternLMForCausalLM 12 | sys.path.append('../') 13 | 14 | def convert2hf(model_config, states_tp_pps): 15 | with tempfile.TemporaryDirectory() as folder: 16 | states = merge_pp(states_tp_pps)[0] 17 | 18 | dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"] 19 | base = 10000.0 20 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 21 | 22 | current_states = {} 23 | 24 | vq_model_embed_weight = states.pop('embedding.vq_model.quantize.embedding.weight') 25 | embed_proj_weight = states.pop('embedding.embed_proj.weight') 26 | current_states["model.embed_tokens.weight"] = vq_model_embed_weight.mm(embed_proj_weight.T) 27 | current_states["model.norm.weight"] = states.pop("norm.weight") 28 | current_states["lm_head.weight"] = states.pop("head.weight") 29 | 30 | for i in range(model_config["num_layers"]): 
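            # Each InternLM block stores Q, K and V packed in a single Wqkv tensor; split it into the separate
            # q/k/v projections expected by the HF layout, and replace the checkpointed rotary inv_freq with the freshly computed one.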
31 | states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None) 32 | 33 | wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( 34 | 3, model_config["num_attention_heads"], -1, model_config["hidden_size"] 35 | ) 36 | bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1) 37 | 38 | current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape( 39 | -1, model_config["hidden_size"] 40 | ) 41 | current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1) 42 | current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape( 43 | -1, model_config["hidden_size"] 44 | ) 45 | current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1) 46 | current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape( 47 | -1, model_config["hidden_size"] 48 | ) 49 | current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1) 50 | 51 | current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop( 52 | f"blocks.{i}.mixer.out_proj.weight" 53 | ) 54 | current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias") 55 | 56 | current_states[f"model.layers.{i}.mlp.gate_proj.weight"] = states.pop(f"blocks.{i}.mlp.w1.weight") 57 | current_states[f"model.layers.{i}.mlp.down_proj.weight"] = states.pop(f"blocks.{i}.mlp.w3.weight") 58 | current_states[f"model.layers.{i}.mlp.up_proj.weight"] = states.pop(f"blocks.{i}.mlp.w2.weight") 59 | 60 | current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight") 61 | current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight") 62 | current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq 63 | 64 | config = InternLMConfig( 65 | hidden_size=model_config["hidden_size"], 66 | intermediate_size=compute_intermediate_size(model_config["hidden_size"]), 67 | num_attention_heads=model_config["num_attention_heads"], 68 | num_hidden_layers=model_config["num_layers"], 69 | rms_norm_eps=1e-06, 70 | bias=True, 71 | ) 72 | 73 | if model_config["vocab_size"] != -1: 74 | config.vocab_size = model_config["vocab_size"] 75 | 76 | config.save_pretrained(folder) 77 | torch.save(current_states, os.path.join(folder, "pytorch_model.bin")) 78 | 79 | model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16) 80 | del model.config._name_or_path 81 | 82 | return config, model 83 | 84 | 85 | def compute_intermediate_size(n): 86 | return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 87 | 88 | 89 | def merge_pp(states_tp_pp): 90 | max_tp = len(states_tp_pp) 91 | max_pp = len(states_tp_pp[0]) 92 | 93 | full_states = [] 94 | for tp in range(max_tp): 95 | layer_shift = 0 96 | 97 | tp_states = {} 98 | for pp in range(max_pp): 99 | _layer_shift = 0 100 | states = states_tp_pp[tp][pp] 101 | keys = list(states.keys()) 102 | for key in keys: 103 | match = re.search("\.\d+\.", key) 104 | if match is not None: 105 | s, e = match.span() 106 | layer_idx = int(key[s + 1: e - 1]) + layer_shift 107 | _layer_shift = max(_layer_shift, int(key[s + 1: e - 1])) 108 | name = key[:s] + f".{layer_idx}." 
+ key[e:] 109 | tp_states[name] = states[key] 110 | else: 111 | tp_states[key] = states[key] 112 | layer_shift += _layer_shift + 1 113 | full_states.append({(key[6:] if key.startswith("model.") else key): value for key, value in tp_states.items()}) 114 | return full_states 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--src_folder', type=str, default='/path/to/intermlm_model/') # internlm model folder 120 | parser.add_argument('--tgt_folder', type=str, default='/path/to/hf_model/') # hf model folder 121 | args = parser.parse_args() 122 | 123 | 124 | def load(fp): 125 | with open(fp, "rb") as f: 126 | pt_data = torch.load(f, map_location="cpu") 127 | return pt_data 128 | 129 | 130 | folder = args.src_folder 131 | target_folder = args.tgt_folder 132 | model_config = load(os.path.join(folder, "model_config.pt")) 133 | 134 | fns = list(os.listdir(folder)) 135 | 136 | model_fns = [] 137 | for fn in fns: 138 | if fn.startswith("model_t") and not fn.endswith("md5"): 139 | model_fns.append(fn) 140 | 141 | max_tp, max_pp = -1, -1 142 | for fn in model_fns: 143 | _, tp, pp = os.path.splitext(fn)[0].split("_") 144 | max_pp = max(max_pp, int(pp[2:]) + 1) 145 | max_tp = max(max_tp, int(tp[2:]) + 1) 146 | 147 | states_tp_pps = [[]] 148 | 149 | for pp in range(max_pp): 150 | model_name = f"model_tp0_pp{pp}.pt" 151 | states = load(os.path.join(folder, model_name)) 152 | states_tp_pps[0].append(states) 153 | 154 | config, model = convert2hf(model_config, states_tp_pps) 155 | 156 | os.makedirs(target_folder, exist_ok=True) 157 | model.save_pretrained(target_folder, max_shard_size="20GB") 158 | # TODO There should be a better way to add this. 159 | with open(os.path.join(target_folder, "config.json")) as fp: 160 | config_dict = json.load(fp) 161 | config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMForCausalLM" 162 | with open(os.path.join(target_folder, "config.json"), "w") as fp: 163 | json.dump(config_dict, fp, indent=2) 164 | -------------------------------------------------------------------------------- /InternLM/tools/convert2hf_vit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import tempfile 6 | import sys 7 | 8 | import torch 9 | 10 | from model_hf.modeling_vit import InternLMConfig, InternLMForCausalLM 11 | 12 | sys.path.append('../') 13 | 14 | def convert2hf(model_config, states_tp_pps): 15 | with tempfile.TemporaryDirectory() as folder: 16 | states = merge_pp(states_tp_pps)[0] 17 | 18 | dims_per_head = model_config["hidden_size"] // model_config["num_attention_heads"] 19 | base = 10000.0 20 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 21 | 22 | current_states = {} 23 | 24 | vq_model_embed_weight = states.pop('embedding.vq_model.quantize.embedding.weight') 25 | embed_proj_weight = states.pop('embedding.embed_proj.weight') 26 | current_states["model.embed_tokens.weight"] = vq_model_embed_weight.mm(embed_proj_weight.T) 27 | current_states["model.norm.weight"] = states.pop("norm.weight") 28 | current_states["model.norm.bias"] = states.pop("norm.bias") 29 | current_states["lm_head.weight"] = states.pop("head.weight") 30 | 31 | mlp_bias = False 32 | for i in range(model_config["num_layers"]): 33 | states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None) 34 | 35 | wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape( 36 | 3, model_config["num_attention_heads"], -1, 
model_config["hidden_size"] 37 | ) 38 | bqkv = states.pop(f"blocks.{i}.mixer.Wqkv.bias").reshape(3, model_config["num_attention_heads"], -1) 39 | 40 | current_states[f"model.layers.{i}.self_attn.q_proj.weight"] = wqkv[0].reshape( 41 | -1, model_config["hidden_size"] 42 | ) 43 | current_states[f"model.layers.{i}.self_attn.q_proj.bias"] = bqkv[0].reshape(-1) 44 | current_states[f"model.layers.{i}.self_attn.k_proj.weight"] = wqkv[1].reshape( 45 | -1, model_config["hidden_size"] 46 | ) 47 | current_states[f"model.layers.{i}.self_attn.k_proj.bias"] = bqkv[1].reshape(-1) 48 | current_states[f"model.layers.{i}.self_attn.v_proj.weight"] = wqkv[2].reshape( 49 | -1, model_config["hidden_size"] 50 | ) 51 | current_states[f"model.layers.{i}.self_attn.v_proj.bias"] = bqkv[2].reshape(-1) 52 | 53 | current_states[f"model.layers.{i}.self_attn.o_proj.weight"] = states.pop( 54 | f"blocks.{i}.mixer.out_proj.weight" 55 | ) 56 | current_states[f"model.layers.{i}.self_attn.o_proj.bias"] = states.pop(f"blocks.{i}.mixer.out_proj.bias") 57 | 58 | current_states[f"model.layers.{i}.mlp.fc1.weight"] = states.pop(f"blocks.{i}.mlp.fc1.weight") 59 | current_states[f"model.layers.{i}.mlp.fc2.weight"] = states.pop(f"blocks.{i}.mlp.fc2.weight") 60 | 61 | if f'blocks.{i}.mlp.fc1.bias' in states: 62 | mlp_bias = True 63 | current_states[f"model.layers.{i}.mlp.fc1.bias"] = states.pop(f"blocks.{i}.mlp.fc1.bias") 64 | current_states[f"model.layers.{i}.mlp.fc2.bias"] = states.pop(f"blocks.{i}.mlp.fc2.bias") 65 | 66 | current_states[f"model.layers.{i}.input_layernorm.weight"] = states.pop(f"blocks.{i}.norm1.weight") 67 | current_states[f"model.layers.{i}.input_layernorm.bias"] = states.pop(f"blocks.{i}.norm1.bias") 68 | current_states[f"model.layers.{i}.post_attention_layernorm.weight"] = states.pop(f"blocks.{i}.norm2.weight") 69 | current_states[f"model.layers.{i}.post_attention_layernorm.bias"] = states.pop(f"blocks.{i}.norm2.bias") 70 | current_states[f"model.layers.{i}.self_attn.rotary_emb.inv_freq"] = inv_freq 71 | 72 | config = InternLMConfig( 73 | hidden_size=model_config["hidden_size"], 74 | intermediate_size=int(model_config["hidden_size"] * model_config["mlp_ratio"]), 75 | num_attention_heads=model_config["num_attention_heads"], 76 | num_hidden_layers=model_config["num_layers"], 77 | norm_eps=1e-06, 78 | bias=True, 79 | mlp_bias=mlp_bias, 80 | ) 81 | 82 | if model_config["vocab_size"] != -1: 83 | config.vocab_size = model_config["vocab_size"] 84 | 85 | config.save_pretrained(folder) 86 | torch.save(current_states, os.path.join(folder, "pytorch_model.bin")) 87 | 88 | model = InternLMForCausalLM.from_pretrained(folder, torch_dtype=torch.float16) 89 | del model.config._name_or_path 90 | 91 | return config, model 92 | 93 | 94 | def merge_pp(states_tp_pp): 95 | max_tp = len(states_tp_pp) 96 | max_pp = len(states_tp_pp[0]) 97 | 98 | full_states = [] 99 | for tp in range(max_tp): 100 | layer_shift = 0 101 | 102 | tp_states = {} 103 | for pp in range(max_pp): 104 | _layer_shift = 0 105 | states = states_tp_pp[tp][pp] 106 | keys = list(states.keys()) 107 | for key in keys: 108 | match = re.search("\.\d+\.", key) 109 | if match is not None: 110 | s, e = match.span() 111 | layer_idx = int(key[s + 1: e - 1]) + layer_shift 112 | _layer_shift = max(_layer_shift, int(key[s + 1: e - 1])) 113 | name = key[:s] + f".{layer_idx}." 
+ key[e:] 114 | tp_states[name] = states[key] 115 | else: 116 | tp_states[key] = states[key] 117 | layer_shift += _layer_shift + 1 118 | full_states.append({(key[6:] if key.startswith("model.") else key): value for key, value in tp_states.items()}) 119 | return full_states 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--src_folder', type=str, default='/path/to/intermlm_model/') # internlm model folder 125 | parser.add_argument('--tgt_folder', type=str, default='/path/to/hf_model/') # hf model folder 126 | args = parser.parse_args() 127 | 128 | 129 | def load(fp): 130 | with open(fp, "rb") as f: 131 | pt_data = torch.load(f, map_location="cpu") 132 | return pt_data 133 | 134 | 135 | folder = args.src_folder 136 | target_folder = args.tgt_folder 137 | model_config = load(os.path.join(folder, "model_config.pt")) 138 | 139 | fns = list(os.listdir(folder)) 140 | 141 | model_fns = [] 142 | for fn in fns: 143 | if fn.startswith("model_t") and not fn.endswith("md5"): 144 | model_fns.append(fn) 145 | 146 | max_tp, max_pp = -1, -1 147 | for fn in model_fns: 148 | _, tp, pp = os.path.splitext(fn)[0].split("_") 149 | max_pp = max(max_pp, int(pp[2:]) + 1) 150 | max_tp = max(max_tp, int(tp[2:]) + 1) 151 | 152 | states_tp_pps = [[]] 153 | 154 | for pp in range(max_pp): 155 | model_name = f"model_tp0_pp{pp}.pt" 156 | states = load(os.path.join(folder, model_name)) 157 | states_tp_pps[0].append(states) 158 | 159 | config, model = convert2hf(model_config, states_tp_pps) 160 | 161 | os.makedirs(target_folder, exist_ok=True) 162 | model.save_pretrained(target_folder, max_shard_size="20GB") 163 | # TODO There should be a better way to add this. 164 | with open(os.path.join(target_folder, "config.json")) as fp: 165 | config_dict = json.load(fp) 166 | config_dict["auto_map"]["AutoModel"] = "modeling_vit.InternLMForCausalLM" 167 | with open(os.path.join(target_folder, "config.json"), "w") as fp: 168 | json.dump(config_dict, fp, indent=2) 169 | -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000000_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000000_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000000_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000000_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000001_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000001_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000001_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000001_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000002_img.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000002_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/derain_prompt/000002_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/derain_prompt/000002_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/derain_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/derain_1.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/derain_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/derain_2.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/pose_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/pose_1.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/pose_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/pose_2.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/seg_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/seg_1.png -------------------------------------------------------------------------------- /InternLM/tools/data/examples/seg_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/examples/seg_2.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000000_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000000_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000000_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000000_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000001_img.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000001_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000001_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000001_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000002_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000002_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/pose_prompt/000002_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/pose_prompt/000002_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000000_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000000_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000000_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000000_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000001_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000001_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000001_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000001_label.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000002_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000002_img.png -------------------------------------------------------------------------------- /InternLM/tools/data/seg_prompt/000002_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/data/seg_prompt/000002_label.png -------------------------------------------------------------------------------- /InternLM/tools/model_hf/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/InternLM/tools/model_hf/__init__.py -------------------------------------------------------------------------------- /InternLM/tools/model_hf/muse/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_taming_vqgan import VQGANModel 19 | -------------------------------------------------------------------------------- /InternLM/tools/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from torchvision import transforms 5 | 6 | encode_transform = transforms.Compose( 7 | [ 8 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), 9 | transforms.CenterCrop(256), 10 | transforms.ToTensor(), 11 | ] 12 | ) 13 | 14 | 15 | def convert_decode_to_pil(rec_image): 16 | rec_image = 2.0 * rec_image - 1.0 17 | rec_image = torch.clamp(rec_image, -1.0, 1.0) 18 | rec_image = (rec_image + 1.0) / 2.0 19 | rec_image *= 255.0 20 | rec_image = rec_image.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) 21 | pil_images = [Image.fromarray(image) for image in rec_image] 22 | return pil_images 23 | 24 | 25 | def patchify(imgs, p): 26 | """ 27 | imgs: (N, C, H, W) 28 | x: (N, L, patch_size**2 * C) 29 | """ 30 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 31 | 32 | in_chans = imgs.shape[1] 33 | h = w = imgs.shape[2] // p 34 | x = imgs.reshape(shape=(imgs.shape[0], in_chans, h, p, w, p)) 35 | x = torch.einsum('nchpwq->nhwpqc', x) 36 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * in_chans)) 37 | return x 38 | 39 | 40 | def unpatchify(x, p): 41 | """ 42 | x: (N, L, patch_size**2 * C) 43 | imgs: (N, C, H, W) 44 | """ 45 | # p = self.patch_embed.patch_size[0] 46 | h = w = int(x.shape[1] ** .5) 47 | assert h * w == x.shape[1] 48 | 49 | x = x.reshape(shape=(x.shape[0], h, w, p, p, -1)) 50 | x = torch.einsum('nhwpqc->nchpwq', x) 51 | imgs = x.reshape(shape=(x.shape[0], -1, h * p, h * p)) 52 | return imgs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Implementation of "[Data-efficient Large Vision Models through Sequential Autoregression](https://arxiv.org/pdf/2402.04841.pdf)" [ICML 2024]. 2 | 3 | 4 |
9 | Training general-purpose vision models on purely sequential visual data, eschewing linguistic inputs, has heralded a new frontier in visual understanding. These models are intended to not only comprehend but also seamlessly transit to out-of-domain tasks. 10 | However, current endeavors are hamstrung by an over-reliance on colossal models, exemplified by models with upwards of 3B parameters, and the necessity for an extensive corpus of visual data, often comprising a staggering 400B tokens. 11 | In this paper, we delve into the development of an efficient, autoregression-based vision model, innovatively architected to operate on a limited dataset. We meticulously demonstrate how this model achieves proficiency in a spectrum of visual tasks spanning both high-level and low-level semantic understanding during the testing phase. Our empirical evaluations underscore the model's agility in adapting to various tasks, heralding a significant reduction in the parameter footprint, and a marked decrease in training data requirements, thereby paving the way for more sustainable and accessible advancements in the field of generalist vision models. 12 | 13 | #### TODO List 14 | - [X] Code about training models. 15 | - [X] Code about inferencing models. 16 | - [X] Huggingface & InternLM ckpts. 17 | - [X] Code about data generation. 18 | 19 | 20 | #### Set up 21 | ``` 22 | based on InternLM-v0.2.1dev20231121 23 | ``` 24 | 25 | 26 | Install: `https://github.com/InternLM/InternLM/blob/v0.2.1dev20231121/doc/en/install.md` 27 | 28 | Put your training data to `/path/to/data/vision`. 29 | 30 | Training command: 31 | `torchrun --nproc_per_node 8 train.py --config ./configs/pretrain_300m.py --launcher torch` 32 | 33 | Training via KD command: 34 | `torchrun --nproc_per_node 8 train.py --config ./configs/kd_1b_to_300m.py --launcher torch` 35 | 36 | Convert model and inference example: `./tools` 37 | 38 | The corresponding huggingface ckpt can be downloaded at [LLaMA-1b-hf Onedrive](https://unisyd-my.sharepoint.com/:u:/g/personal/han_wu_sydney_edu_au/EQx8q3DvqP1CqOddm0aYN4wBBywVAOSvyB1P12ItzuNDmw?e=uOkUnP) / [LLaMA-1b-hf Baidu Disk](https://pan.baidu.com/s/12oI_TOVHtbhriM1Bu1TXmw?pwd=1234) and [LLaMA-300m-hf](https://github.com/ggjy/DeLVM/releases/download/hf-ckpt/llama_300m_hf.zip). 39 | 40 | #### Data generation 41 | Please refer to [data_generation/README.md](data_generation/README.md). 42 | 43 | 44 | ### Citation 45 | 46 | If you find this project useful in your research, please consider cite: 47 | 48 | ```bibtex 49 | @article{guo2024dataefficient, 50 | title={Data-efficient Large Vision Models through Sequential Autoregression}, 51 | author={Guo, Jianyuan and Hao, Zhiwei and Wang, Chengcheng and Tang, Yehui and Wu, Han and Hu, Han and Han, Kai and Xu, Chang}, 52 | journal={arXiv preprint arXiv:2402.04841}, 53 | year={2024} 54 | } 55 | ``` 56 | 57 | ### Acknowledgement 58 | 59 | We maily follow the directon of project [LVM](https://github.com/ytongbai/LVM). And this repo is based on [InternLM](https://github.com/InternLM/InternLM), [huggingface.co/transformers](https://github.com/huggingface/transformers), and [huggingface.co/openMUSE](https://github.com/huggingface/open-muse). 60 | 61 | ### License 62 | 63 | [](https://opensource.org/licenses/MIT) 64 | -------------------------------------------------------------------------------- /data_generation/README.md: -------------------------------------------------------------------------------- 1 | # Data generation 2 | 3 | ## Preliminary 4 | 5 | 1. 
`pip install -r data_generation/requirements.txt` 6 | 2. Download the vqgan checkpoint from [CowTransfer](https://cowtransfer.com/s/d771c6d3d8344d) or [Google Drive](https://drive.google.com/drive/folders/1CyucT_QOArUH_Au8dfzRSwseyiCGserF?usp=share_link), and move it to `./weight/vqgan-f16-8192-laion`. 7 | 8 | ## Human keypoint 9 | 10 | 1. You can generate the keypoint image refer to [mmpose](https://mmpose.readthedocs.io/en/dev-1.x/demos.html#d-human-pose-estimation-with-inferencer) , and 11 | change the inference cmd like this 12 | 13 | ```shell 14 | python inferencer_demo.py data/path \ 15 | coco/train2017/images \ 16 | --pose2d configs/body_2d_keypoint/rtmo/coco/rtmo-l_16xb16-600e_coco-640x640.py \ 17 | --pose2d-weights ./pth/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth \ 18 | --det-model demo/mmdetection_cfg/rtmdet_m_640-8xb32_coco-person.py \ 19 | --black-background \ 20 | --vis-out-dir coco/train2017/keypoints \ 21 | --skeleton-style openpose \ 22 | --disable-rebase-keypoint \ 23 | --radius 8 \ 24 | --thickness 4 \ 25 | ``` 26 | 27 | 2. Generate vq codebook by VQ-GAN 28 | 29 | ```shell 30 | python generate/generate_coco-keypoint.py \ 31 | --input_data coco/train2017/images \ 32 | --target_data coco/train2017/keypoints \ 33 | --output_path vq_token/coco-keypoints/train2017 34 | ``` 35 | 36 | ## Deblur 37 | 38 | ```shell 39 | python generate/generate_GoPro.py \ 40 | --input_data GoPro_train/input \ 41 | --target_data GoPro_train/target \ 42 | --output_path vq_token/GoPro_train 43 | ``` 44 | 45 | ## Derain 46 | 47 | Here we use Rain13K data in lmdb fromat. 48 | 49 | ```shell 50 | python generate/generate_Rain13K.py \ 51 | --input_data Rain13K_lmdb/input.lmdb \ 52 | --target_data Rain13K_lmdb/target.lmdb \ 53 | --output_path vq_token/Rain13K 54 | ``` 55 | 56 | ## Video dataset 57 | 58 | Here we use the HD-VILA-100M dataset. 59 | 60 | 1. You should download the dataset refer [hd-vila-100m](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m), 61 | and use [src/cut_videos.py](https://github.com/microsoft/XPretrain/blob/main/hd-vila-100m/src/cut_videos.py) to cut 62 | the videos to clips. 63 | 64 | 2. Generate vq codebook by VQ-GAN 65 | 66 | ```shell 67 | python generate/generate_hdvila_100m.py \ 68 | --video_info_json hdvila_100m/cut_video_results/cut_part0.jsonl \ 69 | --data_root hdvila_100m/video_clips_imgs \ 70 | --output_root vq_token/hdvila_100m 71 | ``` 72 | 73 | ## Segment mask 74 | 75 | Here we use the SA-1B dataset. 76 | 77 | 1. Download the SA-1B dataset. 78 | 79 | 2. Generate vq codebook by VQ-GAN. 
80 | 81 | ```shell 82 | python generate/generate_SA-1B.py \ 83 | --tar_root SA-1B/tar \ 84 | --img_json_root SA-1B/tmp/img_json \ 85 | --mask_root SA-1B/tmp/mask \ 86 | --output_path vq_token/SA-1B/token \ 87 | --dp_mode 88 | ``` -------------------------------------------------------------------------------- /data_generation/generate/generate_GoPro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | parent_path = pathlib.Path(__file__).absolute().parent.parent 6 | parent_path = os.path.abspath(parent_path) 7 | sys.path.append(parent_path) 8 | os.chdir(parent_path) 9 | print(f'>-------------> parent path {parent_path}') 10 | print(f'>-------------> current work dir {os.getcwd()}') 11 | 12 | import argparse 13 | import glob 14 | import multiprocessing 15 | from PIL import Image 16 | from os.path import join 17 | 18 | import torch 19 | from torch.utils.data import DataLoader 20 | from torchvision.datasets import VisionDataset 21 | 22 | from vqgan.load import encode_transform 23 | from generate.img_to_token import img_to_token 24 | 25 | CPU_COUNT = multiprocessing.cpu_count() 26 | 27 | 28 | class GoProDataset(VisionDataset): 29 | def __init__( 30 | self, 31 | root: str, 32 | target_root, 33 | transform=None, 34 | target_transform=None, 35 | transforms=None, 36 | transform_name=None 37 | ) -> None: 38 | super().__init__(root, transforms, transform, target_transform) 39 | self.target_root = target_root 40 | 41 | file_list = glob.glob(join(root, '*.png')) 42 | ids = [os.path.basename(i).split('.')[0] for i in file_list] 43 | self.ids = list(sorted(ids)) 44 | 45 | self.transform_name = transform_name 46 | 47 | def _load_image(self, id: int): 48 | path = join(self.root, f'{id}.png') 49 | return Image.open(path).convert("RGB") 50 | 51 | def _load_target(self, id: int): 52 | path = join(self.target_root, f'{id}.png') 53 | return Image.open(path).convert("RGB") 54 | 55 | def __getitem__(self, index: int): 56 | id = self.ids[index] 57 | image = self._load_image(id) 58 | target_img = self._load_target(id) 59 | 60 | images = self.transform(image) 61 | target_imgs = self.transform(target_img) 62 | 63 | data_list = [] 64 | if self.transform_name == 'six_crop_encode_transform': 65 | for _img, _target_img in zip(images, target_imgs): 66 | _data = torch.stack([_img, _target_img], dim=0) 67 | data_list.append(_data) 68 | else: 69 | _data = torch.stack([images, target_imgs], dim=0) 70 | data_list.append(_data) 71 | 72 | data = torch.cat(data_list, dim=0) 73 | 74 | return data 75 | 76 | def __len__(self) -> int: 77 | return len(self.ids) 78 | 79 | 80 | def convert_img_to_token(args, device=None): 81 | dataset = GoProDataset(args.input_data, args.target_data, transform=encode_transform, 82 | transform_name='encode_transform') 83 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_work) 84 | img_to_token(args, data_loader, args.output_path, device=device) 85 | 86 | 87 | def get_args(): 88 | parser = argparse.ArgumentParser() 89 | 90 | parser.add_argument("--input_data", type=str, default="Rain13K_lmdb/input.lmdb") 91 | parser.add_argument("--target_data", type=str, default="Rain13K_lmdb/target.lmdb") 92 | parser.add_argument("--output_path", type=str, default="vq_token/Rain13K") 93 | 94 | parser.add_argument("--num_work", type=int, default=64) 95 | parser.add_argument("--batch_size", type=int, default=16) 96 | parser.add_argument("--dp_mode", action='store_true', default=False) 
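    # Path to the pretrained VQ-GAN weights used to tokenize the image/target pairs.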
97 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 98 | args = parser.parse_args() 99 | 100 | return args 101 | 102 | 103 | if __name__ == '__main__': 104 | args = get_args() 105 | 106 | device = f'cuda:{0}' 107 | convert_img_to_token(args, device=device) 108 | -------------------------------------------------------------------------------- /data_generation/generate/generate_Rain13K.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import os 4 | import pathlib 5 | import sys 6 | 7 | import cv2 8 | import lmdb 9 | from torch.utils.data import DataLoader 10 | 11 | from generate.img_to_token import img_to_token 12 | 13 | parent_path = pathlib.Path(__file__).absolute().parent.parent 14 | parent_path = os.path.abspath(parent_path) 15 | sys.path.append(parent_path) 16 | os.chdir(parent_path) 17 | print(f'>-------------> parent path {parent_path}') 18 | print(f'>-------------> current work dir {os.getcwd()}') 19 | 20 | import numpy as np 21 | 22 | from PIL import Image 23 | from os.path import join 24 | 25 | import torch 26 | from torchvision.datasets import VisionDataset 27 | 28 | from vqgan.load import encode_transform 29 | 30 | CPU_COUNT = multiprocessing.cpu_count() 31 | 32 | 33 | class LMDBDataset(VisionDataset): 34 | def __init__( 35 | self, 36 | root: str, 37 | target_root, 38 | transform=None, 39 | target_transform=None, 40 | transforms=None, 41 | transform_name=None 42 | ) -> None: 43 | super().__init__(root, transforms, transform, target_transform) 44 | self.target_root = target_root 45 | 46 | self.img_db = lmdb.open(root).begin() 47 | self.target_db = lmdb.open(target_root).begin() 48 | 49 | with open(join(root, 'meta_info.txt'), 'rb') as f: 50 | file_list = f.readlines() 51 | 52 | ids = [i.decode().split(' ')[0].split('.')[0] for i in file_list] 53 | self.ids = list(sorted(ids)) 54 | 55 | self.transform_name = transform_name 56 | 57 | def _load_image(self, id): 58 | img_byte = self.img_db.get(id.encode()) 59 | image_buf = np.frombuffer(img_byte, dtype=np.uint8) 60 | img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) 61 | image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 62 | return image.convert("RGB") 63 | 64 | def _load_target(self, id): 65 | img_byte = self.target_db.get(id.encode()) 66 | image_buf = np.frombuffer(img_byte, dtype=np.uint8) 67 | img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) 68 | image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 69 | return image.convert("RGB") 70 | 71 | def __getitem__(self, index: int): 72 | id = self.ids[index] 73 | image = self._load_image(id) 74 | target_img = self._load_target(id) 75 | 76 | images = self.transform(image) 77 | target_imgs = self.transform(target_img) 78 | 79 | data_list = [] 80 | if self.transform_name == 'six_crop_encode_transform': 81 | for _img, _target_img in zip(images, target_imgs): 82 | _data = torch.stack([_img, _target_img], dim=0) 83 | data_list.append(_data) 84 | else: 85 | _data = torch.stack([images, target_imgs], dim=0) 86 | data_list.append(_data) 87 | 88 | data = torch.cat(data_list, dim=0) 89 | 90 | return data 91 | 92 | def __len__(self) -> int: 93 | return len(self.ids) 94 | 95 | 96 | def convert_img_to_token(args, device=None): 97 | dataset = LMDBDataset(args.input_data, args.target_data, transform=encode_transform, 98 | transform_name='encode_transform') 99 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, 
num_workers=args.num_work) 100 | img_to_token(args, data_loader, args.output_path, device=device) 101 | 102 | 103 | def get_args(): 104 | parser = argparse.ArgumentParser() 105 | 106 | parser.add_argument("--input_data", type=str, default="Rain13K_lmdb/input.lmdb") 107 | parser.add_argument("--target_data", type=str, default="Rain13K_lmdb/target.lmdb") 108 | parser.add_argument("--output_path", type=str, default="vq_token/Rain13K") 109 | 110 | parser.add_argument("--num_work", type=int, default=64) 111 | parser.add_argument("--batch_size", type=int, default=16) 112 | parser.add_argument("--dp_mode", action='store_true', default=False) 113 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 114 | args = parser.parse_args() 115 | 116 | return args 117 | 118 | 119 | if __name__ == '__main__': 120 | args = get_args() 121 | 122 | # input_root = '/home/ma-user/work/data/tmp_data/Rain13K_lmdb/input.lmdb' 123 | # target_root = '/home/ma-user/work/data/tmp_data/Rain13K_lmdb/target.lmdb' 124 | # out_root = '/home/ma-user/work/data/vq_token/Rain13K' 125 | 126 | device = f'cuda:{0}' 127 | convert_img_to_token(args, device=device) 128 | -------------------------------------------------------------------------------- /data_generation/generate/generate_SA-1B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | from mmcv import DataLoader 6 | 7 | parent_path = pathlib.Path(__file__).absolute().parent.parent 8 | parent_path = os.path.abspath(parent_path) 9 | sys.path.append(parent_path) 10 | os.chdir(parent_path) 11 | print(f'>-------------> parent path {parent_path}') 12 | print(f'>-------------> current work dir {os.getcwd()}') 13 | 14 | import glob 15 | import json 16 | import argparse 17 | import subprocess 18 | import multiprocessing 19 | import numpy as np 20 | 21 | from tqdm import tqdm 22 | from PIL import Image 23 | from os.path import join 24 | from joblib import delayed, Parallel 25 | from pycocotools import mask as mask_utils 26 | 27 | import torch 28 | from torchvision.datasets import VisionDataset 29 | 30 | from generate.img_to_token import img_to_token 31 | from vqgan.load import six_crop_encode_transform 32 | 33 | CPU_COUNT = multiprocessing.cpu_count() 34 | 35 | 36 | def convert_anns_to_mask(sam_label): 37 | # device = f'cuda:{0}' 38 | device = f'cpu' 39 | 40 | image_info = sam_label['image'] 41 | anns = sam_label['annotations'] 42 | width, height, file_name = image_info['width'], image_info['height'], image_info['file_name'] 43 | 44 | if len(anns) == 0: 45 | return 46 | 47 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 48 | mask_img = torch.zeros((height, width, 3), device=device) 49 | one_img = torch.ones((height, width, 3), device=device) 50 | 51 | for ann in sorted_anns: 52 | mask = mask_utils.decode(ann['segmentation']) 53 | mask = torch.tensor(mask, device=device) 54 | mask = torch.repeat_interleave(mask.unsqueeze(dim=2), repeats=3, dim=2) 55 | 56 | color_mask = torch.rand(3, device=device) * one_img 57 | mask_img += color_mask * mask 58 | 59 | del mask, color_mask 60 | torch.cuda.empty_cache() 61 | 62 | mask_img_npy = mask_img.cpu().numpy() 63 | mask_img_npy = (255 * mask_img_npy).astype(np.uint8) 64 | 65 | del mask_img, one_img 66 | torch.cuda.empty_cache() 67 | 68 | return mask_img_npy 69 | 70 | 71 | def convert_sam_label(json_dir, out_dir, tar_name): 72 | print(f'>----------------------: convert sam label: {tar_name} ...\n') 73 
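    # convert_anns_to_mask() above decodes each RLE annotation with pycocotools,
    # paints it in a random colour (largest masks first) and returns a uint8 array;
    # this helper rasterises every per-image SA-1B JSON into a PNG mask under out_dir.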
| 74 | os.makedirs(out_dir, exist_ok=True) 75 | 76 | def _convert(_json_path, _out_dir): 77 | data_name = os.path.basename(_json_path) 78 | out_path = join(_out_dir, data_name.replace('json', 'png')) 79 | with open(_json_path) as f: 80 | sam_label = json.load(f) 81 | 82 | mask_img = convert_anns_to_mask(sam_label) 83 | mask_img = Image.fromarray(mask_img) 84 | mask_img.save(out_path) 85 | 86 | json_list = glob.glob(join(json_dir, '*.json')) 87 | 88 | Parallel(n_jobs=CPU_COUNT)( 89 | delayed(_convert)(index, json_path, out_dir) for json_path in tqdm(json_list)) 90 | 91 | 92 | class SamDataset(VisionDataset): 93 | def __init__( 94 | self, 95 | root: str, 96 | mask_root, 97 | transform=None, 98 | target_transform=None, 99 | transforms=None 100 | ) -> None: 101 | super().__init__(root, transforms, transform, target_transform) 102 | self.mask_root = mask_root 103 | 104 | file_list = glob.glob(join(root, '*.jpg')) 105 | ids = [os.path.basename(i).split('.')[0] for i in file_list] 106 | self.ids = list(sorted(ids)) 107 | 108 | def _load_image(self, id: int): 109 | path = join(self.root, f'{id}.jpg') 110 | return Image.open(path).convert("RGB") 111 | 112 | def _load_mask(self, id: int): 113 | path = join(self.mask_root, f'{id}.png') 114 | return Image.open(path).convert("RGB") 115 | 116 | def __getitem__(self, index: int): 117 | id = self.ids[index] 118 | image = self._load_image(id) 119 | mask_img = self._load_mask(id) 120 | 121 | images = self.transform(image) 122 | mask_imgs = self.transform(mask_img) 123 | 124 | data_list = [] 125 | for _img, _mask_img in zip(images, mask_imgs): 126 | _data = torch.stack([_img, _mask_img], dim=0) 127 | data_list.append(_data) 128 | 129 | data = torch.cat(data_list, dim=0) 130 | 131 | return data 132 | 133 | def __len__(self) -> int: 134 | return len(self.ids) 135 | 136 | 137 | def convert_img_to_token(args, img_data_dir, mask_data_dir, out_dir, tar_name, device=None): 138 | print(f'>----------------------: Convert img to token: {tar_name} ...') 139 | 140 | dataset = SamDataset(img_data_dir, mask_data_dir, transform=six_crop_encode_transform([800, 800])) 141 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_work) 142 | img_to_token(args, data_loader, args.output_path, device=device) 143 | 144 | 145 | def unzip_data(tar_root, img_json_dir, tar_name): 146 | print(f'>----------------------: Unzip data: {tar_name} ...') 147 | 148 | local_tar_path = join(tar_root, f'{tar_name}.tar') 149 | os.makedirs(img_json_dir, exist_ok=True) 150 | 151 | cmd = f'tar -xf {local_tar_path} -C {img_json_dir}' 152 | subprocess.check_call(args=cmd, shell=True) 153 | 154 | 155 | def remove_tmpfile(img_json_dir, mask_dir, tar_name): 156 | print(f'>----------------------: Remove tmpfile: {tar_name} ...') 157 | 158 | tmp_files = [ 159 | img_json_dir, 160 | mask_dir 161 | ] 162 | 163 | for tmp_file in tmp_files: 164 | cmd = f'rm -rf {tmp_file}' 165 | subprocess.check_call(args=cmd, shell=True) 166 | 167 | 168 | def get_args(): 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--tar_root", type=str, default="data/SA-1B/tar") 171 | parser.add_argument("--img_json_root", type=str, default="data/SA-1B/tmp/img_json") 172 | parser.add_argument("--mask_root", type=str, default="data/SA-1B/tmp/mask") 173 | parser.add_argument("--output_path", type=str, default="vq_token/SA-1B") 174 | 175 | parser.add_argument("--num_work", type=int, default=64) 176 | parser.add_argument("--batch_size", type=int, default=32) 177 | 
parser.add_argument("--dp_mode", action='store_true', default=False) 178 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 179 | args = parser.parse_args() 180 | 181 | return args 182 | 183 | 184 | if __name__ == '__main__': 185 | args = get_args() 186 | 187 | exclusion_data = [] 188 | 189 | tar_name_list = os.listdir(args.tar_root) 190 | tar_name_list = [i.split('.')[0] for i in tar_name_list if i[-3:] == 'tar'] 191 | 192 | if os.path.exists(args.output_path): 193 | exist_token_name_list = os.listdir(args.output_path) 194 | else: 195 | exist_token_name_list = [] 196 | 197 | tar_name_list = list(set(tar_name_list) - set(exist_token_name_list)) 198 | tar_name_list = sorted(tar_name_list) 199 | 200 | for index, tar_name in enumerate(tar_name_list): 201 | 202 | if tar_name in exclusion_data: 203 | continue 204 | 205 | print(f'\n\nProcessing sam data: {tar_name} {index + 1}/{len(tar_name_list)} ...') 206 | img_json_dir = join(args.img_json_root, tar_name) 207 | mask_dir = join(args.mask_root, tar_name) 208 | out_dir = join(args.output_path, tar_name) 209 | 210 | unzip_data(args.tar_root, img_json_dir, tar_name) 211 | convert_sam_label(img_json_dir, mask_dir, tar_name) 212 | 213 | device = f'cuda:{0}' 214 | convert_img_to_token(args, img_json_dir, mask_dir, out_dir, tar_name, device=device) 215 | 216 | remove_tmpfile(img_json_dir, mask_dir, tar_name) 217 | -------------------------------------------------------------------------------- /data_generation/generate/generate_coco-keypoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | parent_path = pathlib.Path(__file__).absolute().parent.parent 6 | parent_path = os.path.abspath(parent_path) 7 | sys.path.append(parent_path) 8 | os.chdir(parent_path) 9 | print(f'>-------------> parent path {parent_path}') 10 | print(f'>-------------> current work dir {os.getcwd()}') 11 | 12 | import argparse 13 | import glob 14 | import multiprocessing 15 | 16 | from PIL import Image 17 | from os.path import join 18 | 19 | import torch 20 | from torch.utils.data import DataLoader 21 | from torchvision.datasets import VisionDataset 22 | 23 | from vqgan.load import encode_transform 24 | from generate.img_to_token import img_to_token 25 | 26 | CPU_COUNT = multiprocessing.cpu_count() 27 | 28 | 29 | class KeyPointDataset(VisionDataset): 30 | def __init__( 31 | self, 32 | root: str, 33 | target_root, 34 | transform=None, 35 | target_transform=None, 36 | transforms=None, 37 | transform_name='encode_transform' 38 | ) -> None: 39 | super().__init__(root, transforms, transform, target_transform) 40 | self.target_root = target_root 41 | 42 | file_list = glob.glob(join(root, '*.jpg')) 43 | ids = [os.path.basename(i).split('.')[0] for i in file_list] 44 | self.ids = list(sorted(ids)) 45 | self.transform_name = transform_name 46 | 47 | def _load_image(self, id: int): 48 | path = join(self.root, f'{id}.jpg') 49 | return Image.open(path).convert("RGB") 50 | 51 | def _load_target(self, id: int): 52 | path = join(self.target_root, f'{id}.jpg') 53 | return Image.open(path).convert("RGB") 54 | 55 | def __getitem__(self, index: int): 56 | id = self.ids[index] 57 | image = self._load_image(id) 58 | target_img = self._load_target(id) 59 | 60 | images = self.transform(image) 61 | target_imgs = self.transform(target_img) 62 | 63 | data_list = [] 64 | if self.transform_name == 'six_crop_encode_transform': 65 | for _img, _target_img in zip(images, 
target_imgs): 66 | _data = torch.stack([_img, _target_img], dim=0) 67 | data_list.append(_data) 68 | else: 69 | _data = torch.stack([images, target_imgs], dim=0) 70 | data_list.append(_data) 71 | 72 | data = torch.cat(data_list, dim=0) 73 | 74 | return data 75 | 76 | def __len__(self) -> int: 77 | return len(self.ids) 78 | 79 | 80 | def convert_img_to_token(args, device=None): 81 | dataset = KeyPointDataset(args.input_data, args.target_data, transform=encode_transform) 82 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_work) 83 | img_to_token(args, data_loader, args.output_path, device=device) 84 | 85 | 86 | def get_args(): 87 | parser = argparse.ArgumentParser() 88 | 89 | parser.add_argument("--input_data", type=str, default="coco-pose/GT/val2017/visual-crop/images") 90 | parser.add_argument("--target_data", type=str, default="coco-pose/GT/val2017/visual-crop/keypoints") 91 | parser.add_argument("--output_path", type=str, default="vq_token/coco-crop/val2017") 92 | 93 | parser.add_argument("--num_work", type=int, default=64) 94 | parser.add_argument("--batch_size", type=int, default=16) 95 | parser.add_argument("--dp_mode", action='store_true', default=False) 96 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 97 | args = parser.parse_args() 98 | 99 | return args 100 | 101 | 102 | if __name__ == '__main__': 103 | args = get_args() 104 | 105 | device = f'cuda:{0}' 106 | convert_img_to_token(args, device=device) 107 | -------------------------------------------------------------------------------- /data_generation/generate/generate_hdvila_100m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | parent_path = pathlib.Path(__file__).absolute().parent.parent 6 | parent_path = os.path.abspath(parent_path) 7 | sys.path.append(parent_path) 8 | os.chdir(parent_path) 9 | # print(f'>-------------> parent path {parent_path}') 10 | # print(f'>-------------> current work dir {os.getcwd()}') 11 | 12 | import argparse 13 | import glob 14 | import jsonlines 15 | 16 | from os.path import join 17 | from tqdm import tqdm 18 | from PIL import Image 19 | # from joblib import Parallel, delayed 20 | 21 | import torch 22 | from timm.data import ImageDataset 23 | from torch.utils.data import DataLoader 24 | 25 | from vqgan.load import encode_transform 26 | from vqgan.utils import init_vqgan_encoder, get_multiprocess 27 | from generate.img_to_token import data_loader_to_token, save_bin_and_meta_file 28 | 29 | 30 | class ImageDatasetNoLabel(ImageDataset): 31 | def __getitem__(self, index): 32 | img, target = self.parser[index] 33 | img = img.read() if self.load_bytes else Image.open(img).convert('RGB') 34 | self._consecutive_errors = 0 35 | if self.transform is not None: 36 | img = self.transform(img) 37 | return img 38 | 39 | 40 | def convert_img_to_token(args, data_dir, out_dir, encoder, device=None): 41 | all_data_bin_list, all_cu_seq_len_list = [], [] 42 | 43 | for sub_dir_name in os.listdir(data_dir): 44 | input_dir = os.path.join(data_dir, sub_dir_name) 45 | 46 | if not os.path.exists(input_dir) or len(glob.glob(join(input_dir, '*'))) == 0: 47 | # print('Path not exist: ', input_dir) 48 | continue 49 | 50 | dataset = ImageDatasetNoLabel(input_dir, transform=encode_transform) 51 | new_multiprocess_ctx = get_multiprocess() 52 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, 53 | 
multiprocessing_context=new_multiprocess_ctx) 54 | 55 | data_bin_list, cu_seq_len_list = data_loader_to_token(encoder, data_loader, device) 56 | all_data_bin_list.extend(data_bin_list) 57 | all_cu_seq_len_list.extend(cu_seq_len_list) 58 | 59 | save_bin_and_meta_file(out_dir, all_data_bin_list, all_cu_seq_len_list) 60 | 61 | 62 | def convert_single_gpu(video_name_list, num_work, index): 63 | work_len = len(video_name_list) // num_work 64 | start, end = index * work_len, (index + 1) * work_len 65 | if index == num_work - 1: 66 | work_video_name_list = video_name_list[start:] 67 | else: 68 | work_video_name_list = video_name_list[start: end] 69 | 70 | device = f'cuda:{index}' 71 | encoder = init_vqgan_encoder(args.model_name_or_path, device) 72 | 73 | for video_name in tqdm(work_video_name_list): 74 | data_dir = join(args.data_root, video_name) 75 | out_dir = join(args.output_root, video_name) 76 | 77 | convert_img_to_token(args, data_dir, out_dir, encoder, device) 78 | 79 | 80 | def get_args(): 81 | parser = argparse.ArgumentParser() 82 | 83 | parser.add_argument("--video_info_json", type=str, 84 | default="hdvila_100m/cut_video_results/cut_part0.jsonl") 85 | parser.add_argument("--data_root", type=str, default='hdvila_100m/video_clips_imgs') 86 | parser.add_argument("--output_root", type=str, default="vq_token/hdvila_100m_2") 87 | 88 | parser.add_argument("--num_gpu", type=int, default=-1) 89 | parser.add_argument("--batch_size", type=str, default=128) 90 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 91 | args = parser.parse_args() 92 | 93 | if args.num_gpu == -1: 94 | args.num_gpu = torch.cuda.device_count() 95 | 96 | return args 97 | 98 | 99 | if __name__ == '__main__': 100 | args = get_args() 101 | num_gpu = args.num_gpu 102 | 103 | print('Convert video info json: ', args.video_info_json) 104 | 105 | with jsonlines.open(args.video_info_json, 'r') as f: 106 | video_name_list = [l.split('.')[0] for l in f] 107 | 108 | if os.path.exists(args.output_root): 109 | exist_token_name = set(os.listdir(args.output_root)) 110 | else: 111 | exist_token_name = [] 112 | 113 | video_name_list = list(set(video_name_list) - set(exist_token_name)) 114 | 115 | for i in range(num_gpu): 116 | convert_single_gpu(video_name_list, num_gpu, i) 117 | 118 | # Parallel(n_jobs=num_gpu)(delayed(convert_single_gpu)(video_name_list, num_gpu, i) for i in range(num_gpu)) 119 | -------------------------------------------------------------------------------- /data_generation/generate/generate_laion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import sys 4 | 5 | parent_path = pathlib.Path(__file__).absolute().parent.parent 6 | parent_path = os.path.abspath(parent_path) 7 | sys.path.append(parent_path) 8 | os.chdir(parent_path) 9 | print(f'>-------------> parent path {parent_path}') 10 | print(f'>-------------> current work dir {os.getcwd()}') 11 | 12 | 13 | import argparse 14 | import subprocess 15 | import tarfile 16 | import time 17 | from multiprocessing import Pool 18 | 19 | from timm.data import ImageDataset 20 | from torch.utils.data import DataLoader 21 | from tqdm import tqdm 22 | 23 | from generate.img_to_token import img_to_token 24 | from vqgan.load import encode_transform 25 | 26 | 27 | def extract_tar(tar_info): 28 | tar_file, folder_path, extract_path, prefix = tar_info 29 | tar_file_path = os.path.join(folder_path, tar_file) 30 | target_folder = os.path.join(extract_path, 
f"{prefix}_{tar_file[:-4]}") 31 | 32 | with tarfile.open(tar_file_path, 'r') as tar: 33 | tar.extractall(target_folder) 34 | 35 | print(f"Extracted {tar_file} to {target_folder}") 36 | 37 | cmd_txt = 'yes | rm -r ' + target_folder + '/*.txt' 38 | subprocess.run(cmd_txt, shell=True) 39 | cmd_json = 'yes | rm -r ' + target_folder + '/*.json' 40 | subprocess.run(cmd_json, shell=True) 41 | cmd_tar_file = 'yes | rm -r ' + tar_file_path 42 | subprocess.run(cmd_tar_file, shell=True) 43 | 44 | 45 | def extract_all_tarfiles_parallel(folder_path, extract_path, prefix, num_processes=4): 46 | file_list = [file for file in os.listdir(folder_path) if file.endswith('.tar')] 47 | tar_info_list = [(tar_file, folder_path, extract_path, prefix) for tar_file in file_list] 48 | 49 | with Pool(num_processes) as pool: 50 | pool.map(extract_tar, tar_info_list) 51 | 52 | 53 | def list_subdir(folder_path): 54 | subdir = [f.name for f in os.scandir(folder_path) if f.is_dir()] 55 | return subdir 56 | 57 | 58 | def get_args(): 59 | parser = argparse.ArgumentParser() 60 | 61 | # unzip 62 | parser.add_argument("--folder_path", type=str, default='/cache/data/laion400m-images/part0/') 63 | parser.add_argument("--extract_path", type=str, default='/home/ma-user/work/laion400m-images/part0_jpg/') 64 | parser.add_argument("--prefix", type=str, default='laion_part0') 65 | parser.add_argument("--num_processes", type=int, default=10) 66 | # vqgan convert 67 | parser.add_argument("--data", type=str, default='/cache/laion_jpg/part0/') # folder of unziped imgs 68 | parser.add_argument("--batch_size", type=str, default=256) # folder of imgs 69 | parser.add_argument("--output", type=str, default="/cache/laion_train_convert/part0/") 70 | parser.add_argument("--model_name_or_path", type=str, default="weight/vqgan-f16-8192-laion") 71 | args = parser.parse_args() 72 | args.data = args.extract_path 73 | 74 | return args 75 | 76 | 77 | if __name__ == '__main__': 78 | args = get_args() 79 | 80 | unzip_start_time = time.time() 81 | extract_all_tarfiles_parallel(args.folder_path, args.extract_path, args.prefix, args.num_processes) 82 | print('########### unzip time: ', time.time() - unzip_start_time) 83 | 84 | dir_names = list_subdir(args.data) 85 | 86 | device = 'cuda:0' 87 | 88 | for idx, sub_dir_name in enumerate(tqdm(dir_names)): 89 | output_dir = os.path.join(args.output, sub_dir_name) 90 | 91 | dataset = ImageDataset(args.input_data, args.target_data, transform=encode_transform) 92 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_work) 93 | img_to_token(args, data_loader, args.output_path, device=device) 94 | 95 | print('Finish convert...') 96 | -------------------------------------------------------------------------------- /data_generation/generate/img_to_token.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from itertools import chain 4 | from os.path import join 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from tqdm import tqdm 10 | 11 | from vqgan.utils import init_vqgan_encoder 12 | 13 | 14 | def data_loader_to_token(encoder, data_loader, device): 15 | cu = 0 16 | data_bin_list = [] 17 | cu_seq_len_list = [] 18 | for _data in tqdm(data_loader): 19 | _data = _data.to(device) 20 | 21 | if _data.dim() == 5: 22 | data_list = list(torch.split(_data, 1, dim=0)) 23 | data_list = [i.squeeze(dim=0) for i in data_list] 24 | data = torch.cat(data_list, dim=0) 25 | else: 26 | data = _data 27 | 28 
| _, out_tokens = encoder(data) 29 | 30 | indices_list = list(torch.split(out_tokens, 2, dim=0)) 31 | 32 | for indices in indices_list: 33 | tokens = list(chain(*indices.tolist())) 34 | seq_len = len(tokens) 35 | saved_bin = str.encode(json.dumps(dict(tokens=tokens)) + "\n") 36 | 37 | data_bin_list.append(saved_bin) 38 | cu_seq_len_list.append((cu, seq_len)) 39 | cu += len(saved_bin) 40 | return data_bin_list, cu_seq_len_list 41 | 42 | 43 | def save_bin_and_meta_file(out_dir, data_bin_list, cu_seq_len_list): 44 | os.makedirs(out_dir, exist_ok=True) 45 | out_bin = join(out_dir, "train.bin") 46 | out_meta = join(out_dir, "train.bin.meta") 47 | 48 | with open(out_bin, "wb+") as bin_file: 49 | bin_file.writelines(data_bin_list) 50 | 51 | cu_seq_len_list = np.array(cu_seq_len_list, dtype=np.int64) 52 | with open(out_meta, "wb+") as meta_file: 53 | np.save(meta_file, cu_seq_len_list) 54 | 55 | 56 | def img_to_token(args, data_loader, out_dir, device=None): 57 | encoder = init_vqgan_encoder(args.model_name_or_path, device) 58 | 59 | if args.dp_mode: 60 | encoder = nn.DataParallel(encoder) 61 | 62 | data_bin_list, cu_seq_len_list = data_loader_to_token(encoder, data_loader, device) 63 | save_bin_and_meta_file(out_dir, data_bin_list, cu_seq_len_list) 64 | -------------------------------------------------------------------------------- /data_generation/generate/token_concat.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from itertools import chain 4 | from os.path import join 5 | 6 | import numpy as np 7 | from joblib import Parallel, delayed 8 | from tqdm import tqdm 9 | 10 | 11 | def img_concat(data_root, output_root, data_name, token_num_per_sentence): 12 | data_dir = join(data_root, data_name) 13 | out_dir = join(output_root, data_name) 14 | 15 | os.makedirs(out_dir, exist_ok=True) 16 | 17 | data_bin_path = os.path.join(data_dir, "train.bin") 18 | out_data_bin_path = os.path.join(out_dir, "train.bin") 19 | out_data_meta_path = os.path.join(out_dir, "train.bin.meta") 20 | 21 | with open(data_bin_path, "r") as bin_file: 22 | data_bin = bin_file.readlines() 23 | 24 | cu = 0 25 | new_data_bin = [] 26 | cu_seq_len_list = [] 27 | 28 | sentence = [] 29 | for index, data in enumerate(data_bin): 30 | data = json.loads(data)['tokens'] 31 | if index > 0 and index % token_num_per_sentence == 0: 32 | tokens = list(chain(*sentence)) 33 | seq_len = len(tokens) 34 | saved_bin = str.encode(json.dumps(dict(tokens=tokens)) + "\n") 35 | 36 | new_data_bin.append(saved_bin) 37 | cu_seq_len_list.append((cu, seq_len)) 38 | cu += len(saved_bin) 39 | sentence = [] 40 | 41 | sentence.append(data) 42 | 43 | tokens = list(chain(*sentence)) 44 | seq_len = len(tokens) 45 | saved_bin = str.encode(json.dumps(dict(tokens=tokens)) + "\n") 46 | 47 | new_data_bin.append(saved_bin) 48 | cu_seq_len_list.append((cu, seq_len)) 49 | cu += len(saved_bin) 50 | 51 | with open(out_data_bin_path, "wb+") as out_bin_file: 52 | out_bin_file.writelines(new_data_bin) 53 | np.save(out_data_meta_path, cu_seq_len_list) 54 | os.rename(f'{out_data_meta_path}.npy', out_data_meta_path) 55 | 56 | 57 | if __name__ == '__main__': 58 | token_num_per_sentence = 6 59 | file_name = 'Rain13K' 60 | data_root = '/home/ma-user/work/data/vq_token' 61 | 62 | data_dir = join(data_root, file_name) 63 | output_dir = join(data_root, f'{file_name}-sentence_{token_num_per_sentence}') 64 | 65 | # for data_name in tqdm(os.listdir(data_dir)): 66 | # img_concat(data_dir, output_dir, data_name, 
token_num_per_sentence) 67 | 68 | Parallel(n_jobs=64)(delayed(img_concat)(data_dir, output_dir, data_name, token_num_per_sentence) 69 | for data_name in tqdm(os.listdir(data_dir))) 70 | -------------------------------------------------------------------------------- /data_generation/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.6.12 2 | accelerate 3 | jsonlines 4 | nvitop 5 | multiprocess -------------------------------------------------------------------------------- /data_generation/vqgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/data_generation/vqgan/__init__.py -------------------------------------------------------------------------------- /data_generation/vqgan/laion_convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import pathlib 5 | import random 6 | import subprocess 7 | import tarfile 8 | import time 9 | from multiprocessing import Pool 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from timm.data import ImageDataset 15 | from torch.utils.data import DataLoader 16 | 17 | torch.set_grad_enabled(False) 18 | random.seed(42) 19 | 20 | from load import encode_transform, load_model 21 | 22 | # set seed 23 | random_seed = 1 24 | torch.manual_seed(random_seed) 25 | torch.cuda.manual_seed(random_seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | np.random.seed(random_seed) 29 | 30 | # args 31 | parser = argparse.ArgumentParser() 32 | # unzip 33 | parser.add_argument("--folder_path", type=str, default='/cache/data/laion400m-images/part0/') 34 | parser.add_argument("--extract_path", type=str, default='/home/ma-user/work/laion400m-images/part0_jpg/') 35 | parser.add_argument("--prefix", type=str, default='laion_part0') 36 | parser.add_argument("--num_processes", type=int, default=10) 37 | # vqgan convert 38 | parser.add_argument("--data", type=str, default='/cache/laion_jpg/part0/') # folder of unziped imgs 39 | parser.add_argument("--batch_size", type=str, default=256) # folder of imgs 40 | parser.add_argument("--type", type=str, default="internlm") 41 | parser.add_argument("--output", type=str, default="/cache/laion_train_convert/part0/") 42 | parser.add_argument("--model_name_or_path", type=str, default="/cache/ckpt/vqgan-f16-8192-laion") 43 | args = parser.parse_args() 44 | args.data = args.extract_path 45 | 46 | # unzip part 47 | unzip_start_time = time.time() 48 | 49 | 50 | def extract_tar(tar_info): 51 | tar_file, folder_path, extract_path, prefix = tar_info 52 | tar_file_path = os.path.join(folder_path, tar_file) 53 | target_folder = os.path.join(extract_path, f"{prefix}_{tar_file[:-4]}") 54 | 55 | with tarfile.open(tar_file_path, 'r') as tar: 56 | tar.extractall(target_folder) 57 | 58 | print(f"Extracted {tar_file} to {target_folder}") 59 | 60 | cmd_txt = 'yes | rm -r ' + target_folder + '/*.txt' 61 | subprocess.run(cmd_txt, shell=True) 62 | cmd_json = 'yes | rm -r ' + target_folder + '/*.json' 63 | subprocess.run(cmd_json, shell=True) 64 | cmd_tar_file = 'yes | rm -r ' + tar_file_path 65 | subprocess.run(cmd_tar_file, shell=True) 66 | 67 | 68 | def extract_all_tarfiles_parallel(folder_path, extract_path, prefix, num_processes=4): 69 | file_list = [file for file in os.listdir(folder_path) if 
file.endswith('.tar')] 70 | tar_info_list = [(tar_file, folder_path, extract_path, prefix) for tar_file in file_list] 71 | 72 | with Pool(num_processes) as pool: 73 | pool.map(extract_tar, tar_info_list) 74 | 75 | 76 | extract_all_tarfiles_parallel(args.folder_path, args.extract_path, args.prefix, args.num_processes) 77 | print('########### unzip time: ', time.time() - unzip_start_time) 78 | 79 | convert_start_time = time.time() 80 | 81 | 82 | # convert part 83 | def list_subdir(folder_path): 84 | subdir = [f.name for f in os.scandir(folder_path) if f.is_dir()] 85 | return subdir 86 | 87 | 88 | dir_names = list_subdir(args.data) 89 | 90 | print('Strating convert via vqgan...') 91 | print(args) 92 | print(len(dir_names)) 93 | 94 | vq_model = load_model(args.model_name_or_path) 95 | vq_model = vq_model.cuda().eval() 96 | 97 | 98 | class ParallelWrapper(nn.Module): 99 | def __init__(self, vq_model, func='encode'): 100 | super(ParallelWrapper, self).__init__() 101 | self.vq_model = vq_model 102 | self.func = func 103 | 104 | def forward(self, x): 105 | return getattr(self.vq_model, self.func)(x) 106 | 107 | 108 | encoder = ParallelWrapper(vq_model) 109 | encoder = nn.DataParallel(encoder) 110 | 111 | 112 | def dumps(data): 113 | seqlen = len(data) 114 | saved_bin = str.encode(json.dumps(dict(tokens=data)) + "\n") 115 | return {"bin": saved_bin, "length": seqlen} 116 | 117 | 118 | for idx, sub_dir_name in enumerate(dir_names): 119 | if idx % 10 == 0: 120 | print(idx) 121 | 122 | output_dir = os.path.join(args.output, sub_dir_name) 123 | 124 | pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) 125 | 126 | out_bin = os.path.join(output_dir, "train.bin") 127 | out_meta = os.path.join(output_dir, "train.bin.meta") 128 | 129 | pathlib.Path(out_bin).touch(exist_ok=True) 130 | pathlib.Path(out_meta).touch(exist_ok=True) 131 | 132 | dataset = ImageDataset(os.path.join(args.data, sub_dir_name), transform=encode_transform) 133 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8) 134 | 135 | from tqdm import tqdm 136 | 137 | cu = 0 138 | cu_seqlens = [] 139 | with open(out_bin, "wb+") as bin_file: 140 | for i, (imgs, _) in enumerate(tqdm(loader)): 141 | imgs = imgs.cuda() 142 | quantized_states, indices = encoder(imgs) 143 | 144 | for indices_i in indices.tolist(): 145 | token = dumps(indices_i) 146 | seqlen = token["length"] # 256 147 | token_data = token["bin"] 148 | bin_file.write(token_data) 149 | # print((cu, seqlen)) 150 | cu_seqlens.append((cu, seqlen)) 151 | cu += len(token_data) 152 | cu_seqlens = np.array(cu_seqlens, dtype=np.int64) 153 | with open(out_meta, "wb+") as meta_file: 154 | np.save(meta_file, cu_seqlens) 155 | 156 | print('########### unzip time: ', time.time() - convert_start_time) 157 | print('Finish convert...') 158 | -------------------------------------------------------------------------------- /data_generation/vqgan/load.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | from PIL import Image 7 | from torch import Tensor 8 | from torchvision import transforms 9 | from torchvision.transforms import Lambda 10 | 11 | from .muse import VQGANModel 12 | 13 | 14 | class ForwardWrapper(nn.Module): 15 | def __init__(self, vq_model, func='encode'): 16 | super(ForwardWrapper, self).__init__() 17 | self.vq_model = vq_model 18 | self.func = func 19 | 20 | def forward(self, x): 21 
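        # Dispatch to the configured VQGAN method ('encode', 'decode' or
        # 'decode_code') so the model is exposed through a plain forward() call.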
| return getattr(self.vq_model, self.func)(x) 22 | 23 | 24 | def load_model(path): 25 | # Load the pre-trained vq model from the hub 26 | vq_model = VQGANModel.from_pretrained(path) 27 | return vq_model 28 | 29 | 30 | def load_encoder(path): 31 | vq_model = load_model(path) 32 | encoder = ForwardWrapper(vq_model) 33 | return encoder 34 | 35 | 36 | def load_decoder(path): 37 | vq_model = load_model(path) 38 | decoder = ForwardWrapper(vq_model, func='decode') 39 | return decoder 40 | 41 | 42 | def load_decoder_code(path): 43 | vq_model = load_model(path) 44 | decoder = ForwardWrapper(vq_model, func='decode_code') 45 | return decoder 46 | 47 | 48 | def convert_decode_to_pil(rec_image): 49 | rec_image = 2.0 * rec_image - 1.0 50 | rec_image = torch.clamp(rec_image, -1.0, 1.0) 51 | rec_image = (rec_image + 1.0) / 2.0 52 | rec_image *= 255.0 53 | rec_image = rec_image.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) 54 | pil_images = [Image.fromarray(image) for image in rec_image] 55 | return pil_images 56 | 57 | 58 | class SixCrop(torch.nn.Module): 59 | def __init__(self, crop_size): 60 | super().__init__() 61 | self.crop_size = crop_size 62 | 63 | # def get_dimensions(self, img): 64 | # """Returns the dimensions of an image as [channels, height, width]. 65 | # 66 | # Args: 67 | # img (PIL Image or Tensor): The image to be checked. 68 | # 69 | # Returns: 70 | # List[int]: The image dimensions. 71 | # """ 72 | # if isinstance(img, torch.Tensor): 73 | # return F_t.get_dimensions(img) 74 | # 75 | # return F_pil.get_dimensions(img) 76 | 77 | def get_dimensions(self, img) -> List[int]: 78 | if hasattr(img, "getbands"): 79 | channels = len(img.getbands()) 80 | else: 81 | channels = img.channels 82 | width, height = img.size 83 | return [channels, height, width] 84 | 85 | def six_crop(self, img: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: 86 | """Crop the given image into four corners and the central crop. 87 | If the image is torch Tensor, it is expected 88 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 89 | 90 | .. Note:: 91 | This transform returns a tuple of images and there may be a 92 | mismatch in the number of inputs and targets your ``Dataset`` returns. 93 | 94 | Args: 95 | img (PIL Image or Tensor): Image to be cropped. 96 | size (sequence or int): Desired output size of the crop. If size is an 97 | int instead of sequence like (h, w), a square crop (size, size) is 98 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 99 | 100 | Returns: 101 | tuple: tuple (tl, tr, bl, br, center) 102 | Corresponding top left, top right, bottom left, bottom right and center crop. 
103 | """ 104 | # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): 105 | # _log_api_usage_once(five_crop) 106 | 107 | crop_height, crop_width = self.crop_size 108 | _, image_height, image_width = self.get_dimensions(img) 109 | 110 | # if crop_width > image_width or crop_height > image_height: 111 | # msg = "Requested crop size {} is bigger than input size {}" 112 | # raise ValueError(msg.format(self.crop_size, (image_height, image_width))) 113 | 114 | if crop_width > image_width: 115 | crop_width = image_width 116 | crop_height = image_width 117 | 118 | if crop_height > image_height: 119 | crop_width = image_height 120 | crop_height = image_height 121 | 122 | tl = F.crop(img, 0, 0, crop_height, crop_width) 123 | tr = F.crop(img, 0, image_width - crop_width, crop_height, crop_width) 124 | bl = F.crop(img, image_height - crop_height, 0, crop_height, crop_width) 125 | br = F.crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) 126 | 127 | if image_height > image_width: 128 | center_top = int(round((image_height - crop_height) / 2.0)) 129 | cl = F.crop(img, center_top, 0, crop_height, crop_width) 130 | cr = F.crop(img, center_top, image_width - crop_width, crop_height, crop_width) 131 | return tl, tr, cl, cr, bl, br 132 | else: 133 | center_left = int(round((image_width - crop_width) / 2.0)) 134 | ct = F.crop(img, 0, center_left, crop_height, crop_width) 135 | cb = F.crop(img, image_height - crop_height, center_left, crop_height, crop_width) 136 | return tl, tr, ct, bl, br, cb 137 | 138 | # center = center_crop(img, [crop_height, crop_width]) 139 | 140 | def forward(self, img): 141 | """ 142 | Args: 143 | img (PIL Image or Tensor): Image to be scaled. 144 | 145 | Returns: 146 | PIL Image or Tensor: Rescaled image. 
147 | """ 148 | return self.six_crop(img) 149 | 150 | def __repr__(self) -> str: 151 | return f"{self.__class__.__name__}(size={self.crop_size})" 152 | 153 | 154 | def six_crop_encode_transform(crop_size): 155 | t = transforms.Compose( 156 | [ 157 | SixCrop(crop_size), 158 | # transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), 159 | Lambda(lambda crops: 160 | [transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR)(crop) for crop 161 | in crops]), 162 | Lambda(lambda crops: [transforms.ToTensor()(crop) for crop in crops]), 163 | ] 164 | ) 165 | return t 166 | 167 | 168 | encode_transform = transforms.Compose( 169 | [ 170 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), 171 | transforms.CenterCrop(256), 172 | transforms.ToTensor(), 173 | ] 174 | ) 175 | 176 | encode_transform_no_crop = transforms.Compose( 177 | [ 178 | transforms.Resize([256, 256], interpolation=transforms.InterpolationMode.BILINEAR), 179 | transforms.ToTensor(), 180 | ] 181 | ) 182 | 183 | encode_transform_2 = transforms.Compose( 184 | [ 185 | transforms.RandomHorizontalFlip(), 186 | transforms.RandomVerticalFlip(), 187 | transforms.RandomRotation(180), 188 | transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR), 189 | transforms.ToTensor(), 190 | ] 191 | ) 192 | 193 | encode_transform_rain_random = transforms.Compose( 194 | [ 195 | transforms.RandomHorizontalFlip(), 196 | transforms.RandomVerticalFlip(), 197 | transforms.RandomResizedCrop(256, interpolation=transforms.InterpolationMode.BILINEAR), 198 | transforms.ToTensor(), 199 | ] 200 | ) 201 | 202 | encode_transform_rain_random_2 = transforms.Compose( 203 | [ 204 | transforms.RandomHorizontalFlip(), 205 | transforms.RandomVerticalFlip(), 206 | transforms.RandomCrop(400), 207 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), 208 | transforms.ToTensor(), 209 | ] 210 | ) 211 | 212 | if __name__ == '__main__': 213 | import numpy as np 214 | 215 | vq_model = load_model('/cache/ckpt/vqgan-f16-8192-laion') 216 | 217 | image = Image.open("ILSVRC2012_val_00040846.JPEG") 218 | pixel_values = encode_transform(image).unsqueeze(0) 219 | quantized_states, indices = vq_model.encode(pixel_values) 220 | rec_image = vq_model.decode(quantized_states) 221 | pil_images = convert_decode_to_pil(rec_image) 222 | -------------------------------------------------------------------------------- /data_generation/vqgan/muse/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | __version__ = "0.0.1" 17 | 18 | from .modeling_taming_vqgan import VQGANModel 19 | -------------------------------------------------------------------------------- /data_generation/vqgan/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import multiprocess 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | 9 | from .load import load_model 10 | 11 | 12 | def init_seed(): 13 | # set seed 14 | import torch 15 | random_seed = 1 16 | random.seed(42) 17 | torch.set_grad_enabled(False) 18 | torch.manual_seed(random_seed) 19 | torch.cuda.manual_seed(random_seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | np.random.seed(random_seed) 23 | 24 | 25 | class ParallelWrapper(nn.Module): 26 | def __init__(self, vq_model, func='encode'): 27 | super().__init__() 28 | self.vq_model = vq_model 29 | self.func = func 30 | 31 | def forward(self, x): 32 | return getattr(self.vq_model, self.func)(x) 33 | 34 | 35 | def init_vqgan_encoder(model_name_or_path, device): 36 | init_seed() 37 | vq_model = load_model(model_name_or_path) 38 | vq_model = vq_model.to(device).eval() 39 | 40 | print('vq_model device:', vq_model.device) 41 | 42 | encoder = ParallelWrapper(vq_model) 43 | 44 | return encoder 45 | 46 | 47 | def get_multiprocess(): 48 | multiprocess.set_start_method('spawn', force=True) 49 | torch.utils.data.dataloader.python_multiprocessing = multiprocess 50 | new_multiprocess_ctx = multiprocess.get_context() 51 | return new_multiprocess_ctx 52 | 53 | 54 | def dumps(data): 55 | seqlen = len(data) 56 | saved_bin = str.encode(json.dumps(dict(tokens=data)) + "\n") 57 | return {"bin": saved_bin, "length": seqlen} 58 | -------------------------------------------------------------------------------- /figs/DeLVM.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ggjy/DeLVM/940788c839667d50828e991fdb3234f44f67c441/figs/DeLVM.PNG --------------------------------------------------------------------------------
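All of the generation scripts above ultimately call `save_bin_and_meta_file()` in `img_to_token.py` (or the equivalent loop in `laion_convert.py`): every token sequence is written as one JSON line `{"tokens": [...]}` into `train.bin`, while `train.bin.meta` stores `(byte_offset, sequence_length)` rows as an int64 array via `np.save`. The snippet below is a minimal reader sketch, not part of the repository, that shows how such a pair can be loaded back; the function name `iter_token_sequences` and the example paths are illustrative only.

```python
import json

import numpy as np


def iter_token_sequences(bin_path, meta_path):
    # train.bin.meta holds an int64 array of shape (num_samples, 2): (byte_offset, seq_len).
    meta = np.load(meta_path)
    with open(bin_path, "rb") as bin_file:
        for offset, seq_len in meta:
            bin_file.seek(int(offset))                 # jump to the start of this JSON line
            sample = json.loads(bin_file.readline())   # {"tokens": [...]}
            assert len(sample["tokens"]) == int(seq_len)
            yield sample["tokens"]


# Illustrative usage against one converted shard:
# for tokens in iter_token_sequences("vq_token/Rain13K/train.bin",
#                                    "vq_token/Rain13K/train.bin.meta"):
#     print(len(tokens))  # e.g. 512 VQGAN indices for one (input, target) image pair
```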