├── Merak ├── core │ ├── finetuning │ │ ├── __init__.py │ │ └── lora │ │ │ ├── __init__.py │ │ │ └── mappings.py │ ├── amp │ │ └── __init__.py │ ├── zero │ │ └── __init__.py │ ├── tensor_parallel │ │ ├── __init__.py │ │ └── mp_mapping.py │ ├── fx │ │ ├── graph_shard │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── split_graph.py │ │ ├── __init__.py │ │ └── tracer │ │ │ ├── __init__.py │ │ │ ├── tracers.py │ │ │ └── _dynamo_trace.py │ ├── checkpoint │ │ └── __init__.py │ ├── __init__.py │ ├── recompute │ │ ├── __init__.py │ │ └── utils.py │ ├── printer │ │ ├── __init__.py │ │ ├── logging.py │ │ └── see_memory.py │ ├── pipeline │ │ └── __init__.py │ └── mpu │ │ └── __init__.py ├── inference │ └── __init__.py ├── merak_args │ └── __init__.py ├── utils │ ├── __init__.py │ └── device_to_meta.py ├── __init__.py └── initialize.py ├── examples ├── torch-models │ ├── data │ │ └── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── build.py │ ├── swin_base_patch4_window7_224.yaml │ ├── README.md │ ├── run_torchvision.py │ ├── config.json │ └── run_swin.py ├── deepseek-r1 │ ├── figures │ │ ├── image-20250228171342803.png │ │ ├── image-20250303085648865.png │ │ ├── image-20250303085732272.png │ │ └── image-20250303102744861.png │ ├── lora_config.json │ ├── lora_config_Qwen.json │ ├── lora_config_gen.json │ └── models │ │ ├── DeepSeek-R1-Distill-Qwen-32B │ │ └── config.json │ │ ├── DeepSeek-R1-Distll-Qwen-14B │ │ └── config.json │ │ ├── DeepSeek-R1-Distill-Qwen-7B │ │ └── config.json │ │ └── DeepSeek-R1-Distill-Llama-8B │ │ └── config.json ├── models │ ├── llama │ │ ├── config.py │ │ └── run_llama.py │ ├── distilbert │ │ ├── config.py │ │ └── run_distilbert.py │ ├── albert │ │ ├── config.py │ │ └── run_albert.py │ ├── dinov2 │ │ ├── config.py │ │ └── run_dinov2.py │ ├── layoutlm │ │ ├── config.py │ │ └── run_layoutlm.py │ ├── electra │ │ ├── config.py │ │ └── run_electra.py │ ├── opt │ │ ├── config.py │ │ └── run_opt.py │ ├── mt5 │ │ ├── config.py │ │ └── run_mt5.py │ ├── gpt2 │ │ ├── config.py │ │ └── run_gpt.py │ ├── nezha │ │ ├── config.py │ │ └── run_nezha.py │ ├── gptj │ │ ├── config.py │ │ └── run_gptj.py │ ├── t5 │ │ ├── config.py │ │ └── run_t5.py │ ├── m2m100 │ │ ├── config.py │ │ └── run_m2m100.py │ ├── lxmert │ │ ├── config.py │ │ └── run_lxmert.py │ ├── bart │ │ ├── config.py │ │ └── run_bart.py │ ├── speech2text │ │ ├── config.py │ │ └── run_speech_to_text.py │ ├── marianmt │ │ ├── config.py │ │ └── run_marian.py │ ├── blenderbot │ │ ├── config.py │ │ └── run_blenderbot.py │ ├── bert │ │ ├── run_bert.py │ │ └── config.py │ ├── hubert │ │ └── config.py │ ├── pegasus │ │ └── run_pegasus.py │ ├── convnext │ │ └── run_convnext.py │ ├── trocr │ │ └── run_trocr.py │ ├── mobilebert │ │ └── run_mobilebert.py │ ├── speech2text2 │ │ └── run_speech_to_text2.py │ ├── plbart │ │ └── run_plbart.py │ ├── mbart │ │ └── run_mbart.py │ ├── xglm │ │ └── run_xglm.py │ ├── clip │ │ └── run_clip.py │ └── altclip │ │ └── run_altclip.py ├── bert_pretraining │ ├── config.py │ └── README.md ├── text_generation │ ├── README.md │ ├── config.py │ └── run_gpt_text_generation.py ├── lora │ └── README.md ├── image-classification │ ├── README.md │ └── run_vit.py ├── unet │ └── README.md └── README.md ├── .gitmodules ├── .gitignore ├── docs ├── fx_doc.md └── amp_doc.md └── test └── core ├── fx ├── test_convert.py ├── tracer │ └── test_symbolic_trace.py └── graph_shard │ └── test_split_graph.py ├── pipeline_parallel └── test_module.py └── checkpoint └── test_safetensor_ckpt.py /Merak/core/finetuning/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/torch-models/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader -------------------------------------------------------------------------------- /examples/torch-models/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /Merak/core/amp/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['MixedPrecisionConfig'] 2 | 3 | from .amp_engine import MixedPrecisionConfig -------------------------------------------------------------------------------- /Merak/core/zero/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['configure_zero_optimizer'] 2 | 3 | from .stage1 import configure_zero_optimizer -------------------------------------------------------------------------------- /Merak/core/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'ModuleRebuild', 3 | ] 4 | 5 | from .model_parallel import ModuleRebuild -------------------------------------------------------------------------------- /Merak/core/fx/graph_shard/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['_shard_model_transformers'] 2 | 3 | from .split_graph import _shard_model_transformers -------------------------------------------------------------------------------- /Merak/inference/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['text_generation_pipeline'] 2 | 3 | from .text_generation_inference import text_generation_pipeline -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "examples/unet/pytorch_unet"] 2 | path = examples/unet/pytorch_unet 3 | url = https://github.com/cosmic-cortex/pytorch-UNet.git 4 | -------------------------------------------------------------------------------- /Merak/core/fx/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['convert_to_sequential', 'add_inputs_to_shards'] 2 | 3 | from .convert import convert_to_sequential, add_inputs_to_shards 4 | -------------------------------------------------------------------------------- /examples/deepseek-r1/figures/image-20250228171342803.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HPDL-Group/Merak/HEAD/examples/deepseek-r1/figures/image-20250228171342803.png -------------------------------------------------------------------------------- /examples/deepseek-r1/figures/image-20250303085648865.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HPDL-Group/Merak/HEAD/examples/deepseek-r1/figures/image-20250303085648865.png -------------------------------------------------------------------------------- /examples/deepseek-r1/figures/image-20250303085732272.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HPDL-Group/Merak/HEAD/examples/deepseek-r1/figures/image-20250303085732272.png -------------------------------------------------------------------------------- /examples/deepseek-r1/figures/image-20250303102744861.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HPDL-Group/Merak/HEAD/examples/deepseek-r1/figures/image-20250303102744861.png -------------------------------------------------------------------------------- /Merak/core/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'CheckpointSaver', 'CheckpointLoader', 'rotate_checkpoints' 3 | ] 4 | 5 | from .checkpoint import ( 6 | rotate_checkpoints, CheckpointSaver, CheckpointLoader 7 | ) -------------------------------------------------------------------------------- /Merak/merak_args/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['MerakArguments', 'get_args', 'mergeargs', 'manual_set_args'] 2 | 3 | 4 | from .args import ( 5 | MerakArguments, 6 | get_args, 7 | mergeargs, 8 | manual_set_args 9 | ) -------------------------------------------------------------------------------- /examples/deepseek-r1/lora_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "r": 16, 3 | "lora_alpha": 32, 4 | "lora_dropout": 0.05, 5 | "target_modules": ["q_proj", "v_proj"], 6 | "bias": "none", 7 | "task_type": "CAUSAL_LM", 8 | "inference_mode": false 9 | } 10 | -------------------------------------------------------------------------------- /examples/deepseek-r1/lora_config_Qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "r": 32, 3 | "lora_alpha": 64, 4 | "lora_dropout": 0.05, 5 | "target_modules": ["q_proj", "v_proj"], 6 | "bias": "none", 7 | "task_type": "CAUSAL_LM", 8 | "inference_mode": false 9 | } 10 | -------------------------------------------------------------------------------- /examples/deepseek-r1/lora_config_gen.json: -------------------------------------------------------------------------------- 1 | { 2 | "r": 32, 3 | "lora_alpha": 64, 4 | "lora_dropout": 0.05, 5 | "target_modules": ["q_proj", "v_proj"], 6 | "bias": "none", 7 | "task_type": "CAUSAL_LM", 8 | "inference_mode": true 9 | } 10 | -------------------------------------------------------------------------------- /Merak/core/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'mpu', 'PipelineEngine', 3 | 'recompute', 'printer', 4 | 'PipelineModule' 5 | ] 6 | 7 | from . 
import mpu, recompute, printer 8 | from .merak_engine import PipelineEngine 9 | from .pipeline import PipelineModule 10 | -------------------------------------------------------------------------------- /Merak/core/fx/tracer/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'tf_symbolic_trace', 3 | 'LayerProxyTracer', 4 | 'dynamo_trace' 5 | ] 6 | 7 | from .tracers import LayerProxyTracer 8 | from ._symbolic_trace import tf_symbolic_trace 9 | from ._dynamo_trace import dynamo_trace 10 | -------------------------------------------------------------------------------- /Merak/core/finetuning/lora/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'LoraConfig', '_prepare_lora_config', 3 | 'mark_only_lora_as_trainable', '_find_and_replace' 4 | ] 5 | 6 | from .config import LoraConfig, _prepare_lora_config 7 | from .utils import mark_only_lora_as_trainable, _find_and_replace -------------------------------------------------------------------------------- /examples/torch-models/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | # DEPTHS: [ 2, 2, 18, 2 ] 8 | DEPTHS: [ 2, 2, 2, 2 ] 9 | # NUM_HEADS: [ 4, 8, 16, 32 ] 10 | NUM_HEADS: [ 2, 2, 2, 2 ] 11 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /Merak/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'init_empty_weights', 'WorkerInitObj', 3 | 'MegatronPretrainingRandomSampler', 'BaseParams', 4 | 'RepeatingLoader' 5 | ] 6 | 7 | from .device_to_meta import init_empty_weights 8 | from .trainer_utils import WorkerInitObj, MegatronPretrainingRandomSampler, RepeatingLoader 9 | from .parameters import BaseParams -------------------------------------------------------------------------------- /Merak/core/recompute/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'checkpoint', 3 | 'pre_checkpoint', 4 | 'get_rng_tracker', 5 | 'RNGManager', 6 | 'model_parallel_cuda_manual_seed' 7 | ] 8 | 9 | from .checkpointing import ( 10 | checkpoint, 11 | pre_checkpoint, 12 | get_rng_tracker, 13 | model_parallel_cuda_manual_seed, 14 | RNGManager 15 | ) -------------------------------------------------------------------------------- /Merak/core/printer/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'SynchronizedWallClockTimer', 3 | 'ThroughputTimer', 4 | 'LoggerFactory', 5 | 'log_dist', 6 | 'logger', 7 | 'see_memory_usage', 8 | 'set_timer_log_rank', 9 | 'AccMetric' 10 | ] 11 | 12 | from .timer import ( 13 | SynchronizedWallClockTimer, 14 | set_timer_log_rank 15 | ) 16 | from .logging import ( 17 | LoggerFactory, 18 | AccMetric, 19 | log_dist, 20 | logger 21 | ) 22 | from .see_memory import see_memory_usage -------------------------------------------------------------------------------- /Merak/core/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'InferenceSchedule', 'PipeSchedule', 'TrainSchedule', 3 | 'MergeP2PTrainSchedule', 'LastNoRecomputeTrainSchedule', 4 | 'FullCriticalPathTrainSchedule', 5 | 'PipelineModule', 6 | 'LayerPartition' 7 | ] 8 | 
9 | from .module import PipelineModule 10 | from .schedule import ( 11 | InferenceSchedule, 12 | PipeSchedule, 13 | TrainSchedule, 14 | MergeP2PTrainSchedule, 15 | LastNoRecomputeTrainSchedule, 16 | FullCriticalPathTrainSchedule 17 | ) 18 | from .layers_partition import LayerPartition -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## General 2 | # git patch 3 | *.patch 4 | # filesystem 5 | .idea 6 | .DS_Store 7 | .apex 8 | 9 | # compiled python files 10 | *.pyc 11 | 12 | # compiled python folders 13 | **/__pycache__ 14 | **/output 15 | **/imagenet_outputs 16 | **/tmp_trainer 17 | *.egg-info/ 18 | **/*_cache/ 19 | *.so 20 | **/build 21 | 22 | # vscode configurations 23 | **/.vscode 24 | 25 | # model files 26 | #*.json 27 | events.out.tfevents* 28 | model.ckpt* 29 | *.pbtxt 30 | *.pt 31 | eval 32 | weight-decay 33 | events.out.* 34 | tensorboard 35 | 36 | # .tar packages 37 | *.tar 38 | 39 | # log or output files 40 | *.log 41 | *.out 42 | *.txt 43 | *.eps 44 | *.json 45 | *.onnx 46 | 47 | # shell script 48 | *.sh 49 | 50 | -------------------------------------------------------------------------------- /examples/deepseek-r1/models/DeepSeek-R1-Distill-Qwen-32B/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Qwen2ForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 151643, 7 | "eos_token_id": 151643, 8 | "hidden_act": "silu", 9 | "hidden_size": 5120, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 27648, 12 | "max_position_embeddings": 131072, 13 | "max_window_layers": 64, 14 | "model_type": "qwen2", 15 | "num_attention_heads": 40, 16 | "num_hidden_layers": 64, 17 | "num_key_value_heads": 8, 18 | "rms_norm_eps": 1e-05, 19 | "rope_theta": 1000000.0, 20 | "sliding_window": 131072, 21 | "tie_word_embeddings": false, 22 | "torch_dtype": "bfloat16", 23 | "transformers_version": "4.43.1", 24 | "use_cache": true, 25 | "use_sliding_window": false, 26 | "vocab_size": 152064 27 | } 28 | -------------------------------------------------------------------------------- /examples/deepseek-r1/models/DeepSeek-R1-Distll-Qwen-14B/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Qwen2ForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 151643, 7 | "eos_token_id": 151643, 8 | "hidden_act": "silu", 9 | "hidden_size": 5120, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 13824, 12 | "max_position_embeddings": 131072, 13 | "max_window_layers": 48, 14 | "model_type": "qwen2", 15 | "num_attention_heads": 40, 16 | "num_hidden_layers": 48, 17 | "num_key_value_heads": 8, 18 | "rms_norm_eps": 1e-05, 19 | "rope_theta": 1000000.0, 20 | "sliding_window": 131072, 21 | "tie_word_embeddings": false, 22 | "torch_dtype": "bfloat16", 23 | "transformers_version": "4.43.1", 24 | "use_cache": true, 25 | "use_sliding_window": false, 26 | "vocab_size": 152064 27 | } 28 | -------------------------------------------------------------------------------- /examples/deepseek-r1/models/DeepSeek-R1-Distill-Qwen-7B/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Qwen2ForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 151643, 7 | "eos_token_id": 151643, 8 | "hidden_act": "silu", 9 | "hidden_size": 3584, 10 | 
"initializer_range": 0.02, 11 | "intermediate_size": 18944, 12 | "max_position_embeddings": 131072, 13 | "max_window_layers": 28, 14 | "model_type": "qwen2", 15 | "num_attention_heads": 28, 16 | "num_hidden_layers": 28, 17 | "num_key_value_heads": 4, 18 | "rms_norm_eps": 1e-06, 19 | "rope_theta": 10000, 20 | "sliding_window": 4096, 21 | "tie_word_embeddings": false, 22 | "torch_dtype": "bfloat16", 23 | "transformers_version": "4.44.0", 24 | "use_cache": false, 25 | "use_mrope": false, 26 | "use_sliding_window": false, 27 | "vocab_size": 152064 28 | } 29 | -------------------------------------------------------------------------------- /examples/models/llama/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "llama": 3 | config = {"architectures": ["LlamaForCausalLM"], 4 | "bos_token_id": 0, 5 | "eos_token_id": 1, 6 | "hidden_act": "silu", 7 | "hidden_size": 4096, 8 | "intermediate_size": 11008, 9 | "initializer_range": 0.02, 10 | "max_sequence_length": 2048, 11 | "model_type": "llama", 12 | "num_attention_heads": 32, 13 | "num_hidden_layers": 32, 14 | "pad_token_id": -1, 15 | "rms_norm_eps": 1e-06, 16 | "use_cache": True, 17 | "vocab_size": 32000, 18 | "return_dict": False, 19 | "use_cache": False, 20 | } 21 | else: 22 | raise ValueError(f"No {model_name} config") 23 | return config 24 | -------------------------------------------------------------------------------- /examples/deepseek-r1/models/DeepSeek-R1-Distill-Llama-8B/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 128000, 8 | "eos_token_id": 128001, 9 | "hidden_act": "silu", 10 | "hidden_size": 4096, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 14336, 13 | "max_position_embeddings": 131072, 14 | "mlp_bias": false, 15 | "model_type": "llama", 16 | "num_attention_heads": 3, 17 | "num_hidden_layers": 3, 18 | "num_key_value_heads": 8, 19 | "pretraining_tp": 1, 20 | "rms_norm_eps": 1e-05, 21 | "rope_scaling": { 22 | "factor": 8.0, 23 | "low_freq_factor": 1.0, 24 | "high_freq_factor": 4.0, 25 | "original_max_position_embeddings": 8192, 26 | "rope_type": "llama3" 27 | }, 28 | "rope_theta": 500000.0, 29 | "tie_word_embeddings": false, 30 | "torch_dtype": "bfloat16", 31 | "transformers_version": "4.43.0.dev0", 32 | "use_cache": true, 33 | "vocab_size": 128256 34 | } 35 | -------------------------------------------------------------------------------- /examples/models/distilbert/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "distilbert-base-cased": 3 | config = { 4 | "activation": "gelu", 5 | "architectures": [ 6 | "DistilBertForMaskedLM" 7 | ], 8 | "attention_dropout": 0.1, 9 | "dim": 768, 10 | "dropout": 0.1, 11 | "hidden_dim": 3072, 12 | "initializer_range": 0.02, 13 | "max_position_embeddings": 512, 14 | "model_type": "distilbert", 15 | "n_heads": 12, 16 | "n_layers": 6, 17 | "output_past": True, 18 | "pad_token_id": 0, 19 | "qa_dropout": 0.1, 20 | "seq_classif_dropout": 0.2, 21 | "sinusoidal_pos_embds": False, 22 | "tie_weights_": True, 23 | 'return_dict': False, 24 | "use_cache": False, 25 | "_attn_implementation": 'eager', 26 | "vocab_size": 28996 27 | } 28 | else: 29 | raise ValueError(f"No {model_name} config") 30 | return config 
-------------------------------------------------------------------------------- /examples/models/albert/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "albert-base-v1": 3 | config = { 4 | "architectures": [ 5 | "AlbertForMaskedLM" 6 | ], 7 | "attention_probs_dropout_prob": 0.1, 8 | "bos_token_id": 2, 9 | "classifier_dropout_prob": 0.1, 10 | "down_scale_factor": 1, 11 | "embedding_size": 128, 12 | "eos_token_id": 3, 13 | "gap_size": 0, 14 | "hidden_act": "gelu", 15 | "hidden_dropout_prob": 0.1, 16 | "hidden_size": 768, 17 | "initializer_range": 0.02, 18 | "inner_group_num": 1, 19 | "intermediate_size": 3072, 20 | "layer_norm_eps": 1e-12, 21 | "max_position_embeddings": 512, 22 | "model_type": "albert", 23 | "net_structure_type": 0, 24 | "num_attention_heads": 12, 25 | "num_hidden_groups": 1, 26 | "num_hidden_layers": 12, 27 | "num_memory_blocks": 0, 28 | "pad_token_id": 0, 29 | "type_vocab_size": 2, 30 | "vocab_size": 30000 31 | } 32 | return config -------------------------------------------------------------------------------- /examples/bert_pretraining/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "bert-large": 3 | config = { 4 | 'architectures': ['BertForPreTraining'], 5 | 'attention_probs_dropout_prob': 0.1, 6 | 'gradient_checkpointing': False, 7 | 'hidden_act': 'gelu', 8 | 'hidden_dropout_prob': 0.1, 9 | 'hidden_size': 1024, 10 | 'initializer_range': 0.02, 11 | 'intermediate_size': 4096, 12 | 'layer_norm_eps': 1e-12, 13 | 'max_position_embeddings': 512, 14 | 'model_type': 'bert', 15 | 'num_attention_heads': 16, 16 | 'num_hidden_layers': 24, 17 | 'pad_token_id': 0, 18 | 'position_embedding_type': 'absolute', 19 | 'type_vocab_size': 2, 20 | '_attn_implementation': 'eager', 21 | 'use_cache': True, 22 | 'return_dict': False, 23 | 'vocab_size': 30522, 24 | } 25 | else: 26 | raise ValueError(f"No {model_name} config") 27 | 28 | return config 29 | -------------------------------------------------------------------------------- /Merak/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | __all__ = [ 19 | 'MerakTrainer', 'init', 'get_grid', 'get_args', 20 | 'get_topo', 'MerakArguments', 'print_rank_0', 21 | 'init_empty_weights' 22 | ] 23 | 24 | from .initialize import init, get_grid, get_topo, print_rank_0 25 | from .merak_trainer import MerakTrainer 26 | from .merak_args import MerakArguments, get_args 27 | from .utils import init_empty_weights -------------------------------------------------------------------------------- /examples/models/dinov2/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "dinov2-base": 3 | config = { 4 | "architectures": [ 5 | "Dinov2Model" 6 | ], 7 | "attention_probs_dropout_prob": 0.0, 8 | "drop_path_rate": 0.0, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.0, 11 | "hidden_size": 768, 12 | "image_size": 518, 13 | "initializer_range": 0.02, 14 | "layer_norm_eps": 1e-06, 15 | "layerscale_value": 1.0, 16 | "mlp_ratio": 4, 17 | "model_type": "dinov2", 18 | "num_attention_heads": 12, 19 | "num_channels": 3, 20 | "num_hidden_layers": 12, 21 | "patch_size": 14, 22 | "qkv_bias": True, 23 | "torch_dtype": "float32", 24 | "transformers_version": "4.31.0.dev0", 25 | 'return_dict': False, 26 | "use_cache": True, 27 | "use_swiglu_ffn": False 28 | } 29 | else: 30 | raise ValueError(f"No {model_name} config") 31 | return config -------------------------------------------------------------------------------- /examples/models/layoutlm/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "layoutlm-base-uncased": 3 | config = { 4 | "_name_or_path": "microsoft/layoutlm-base-uncased", 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_2d_position_embeddings": 1024, 13 | "max_position_embeddings": 512, 14 | "model_type": "layoutlm", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "output_past": True, 18 | "pad_token_id": 0, 19 | "position_embedding_type": "absolute", 20 | "transformers_version": "4.4.0.dev0", 21 | "type_vocab_size": 2, 22 | "use_cache": True, 23 | 'return_dict': False, 24 | "vocab_size": 30522 25 | } 26 | else: 27 | raise ValueError(f"No {model_name} config") 28 | return config -------------------------------------------------------------------------------- /examples/models/electra/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "electra-base": 3 | config = { 4 | "architectures": [ 5 | "ElectraForMaskedLM" 6 | ], 7 | "attention_probs_dropout_prob": 0.1, 8 | "embedding_size": 768, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.1, 11 | "hidden_size": 256, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 1024, 14 | "layer_norm_eps": 1e-12, 15 | "max_position_embeddings": 512, 16 | "model_type": "electra", 17 | "num_attention_heads": 4, 18 | "num_hidden_layers": 12, 19 | "pad_token_id": 0, 20 | "position_embedding_type": "absolute", 21 | "summary_activation": "gelu", 22 | "summary_last_dropout": 0.1, 23 | "summary_type": "first", 24 | "summary_use_proj": True, 25 | "transformers_version": "4.6.0.dev0", 26 | "type_vocab_size": 2, 27 | 'return_dict': False, 28 | "use_cache": False, 29 | "vocab_size": 30522 30 | } 31 | else: 32 | raise 
ValueError(f"No {model_name} config") 33 | return config -------------------------------------------------------------------------------- /examples/models/opt/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "opt-350m": 3 | config = { 4 | "_name_or_path": "opt-350m", 5 | "activation_dropout": 0.0, 6 | "activation_function": "relu", 7 | "architectures": [ 8 | "OPTForCausalLM" 9 | ], 10 | "attention_dropout": 0.0, 11 | "bos_token_id": 2, 12 | "do_layer_norm_before": False, 13 | "dropout": 0.1, 14 | "eos_token_id": 2, 15 | "ffn_dim": 4096, 16 | "hidden_size": 1024, 17 | "init_std": 0.02, 18 | "layerdrop": 0.0, 19 | "max_position_embeddings": 2048, 20 | "model_type": "opt", 21 | "num_attention_heads": 16, 22 | "num_hidden_layers": 24, 23 | "pad_token_id": 1, 24 | "prefix": "", 25 | "torch_dtype": "float16", 26 | "transformers_version": "4.20.0.dev0", 27 | "use_cache": False, 28 | 'return_dict': False, 29 | "vocab_size": 50272, 30 | "word_embed_proj_dim": 512 31 | } 32 | else: 33 | raise ValueError(f"No {model_name} config") 34 | return config -------------------------------------------------------------------------------- /examples/models/mt5/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "mt5-base": 3 | config = { 4 | "_name_or_path": "/home/patrick/hugging_face/t5/mt5-base", 5 | "architectures": [ 6 | "MT5ForConditionalGeneration" 7 | ], 8 | "d_ff": 2048, 9 | "d_kv": 64, 10 | "d_model": 768, 11 | "decoder_start_token_id": 0, 12 | "dropout_rate": 0.1, 13 | "eos_token_id": 1, 14 | "feed_forward_proj": "gated-gelu", 15 | "initializer_factor": 1.0, 16 | "is_encoder_decoder": True, 17 | "layer_norm_epsilon": 1e-06, 18 | "model_type": "mt5", 19 | "num_decoder_layers": 12, 20 | "num_heads": 12, 21 | "num_layers": 12, 22 | "output_past": True, 23 | "pad_token_id": 0, 24 | "relative_attention_num_buckets": 32, 25 | "tie_word_embeddings": False, 26 | "tokenizer_class": "T5Tokenizer", 27 | "transformers_version": "4.10.0.dev0", 28 | "use_cache": False, 29 | "vocab_size": 250112 30 | } 31 | else: 32 | raise ValueError(f"No {model_name} config") 33 | return config -------------------------------------------------------------------------------- /examples/models/gpt2/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "gpt2": 3 | config = { 4 | 'activation_function': 'gelu', 5 | 'architectures': ['GPT2LMHeadModel'], 6 | 'attn_pdrop': 0.1, 7 | 'bos_token_id': 50256, 8 | 'embd_pdrop': 0.1, 9 | 'eos_token_id': 50256, 10 | 'initializer_range': 0.02, 11 | 'layer_norm_epsilon': 1e-05, 12 | 'model_type': 'gpt2', 13 | "n_ctx": 1024, 14 | 'n_embd': 768, 15 | 'n_head': 12, 16 | 'n_layer': 12, 17 | 'n_positions': 1024, 18 | 'resid_pdrop': 0.1, 19 | 'summary_activation': None, 20 | 'summary_first_dropout': 0.1, 21 | 'summary_proj_to_labels': True, 22 | 'summary_type': 'cls_index', 23 | 'summary_use_proj': True, 24 | 'task_specific_params': {'text-generation': {'do_sample': True, 'max_length': 50}}, 25 | 'vocab_size': 50344, 26 | 'return_dict': False, 27 | 'reorder_and_upcast_attn': False, 28 | '_attn_implementation': 'eager', 29 | 'use_cache': False 30 | } 31 | else: 32 | raise ValueError(f"No {model_name} config") 33 | return config -------------------------------------------------------------------------------- 
/examples/models/nezha/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "nezha-cn-base": 3 | config = { 4 | "_name_or_path": "nezha-cn-base", 5 | "architectures": [ 6 | "NeZhaForMaskedLM" 7 | ], 8 | "attention_probs_dropout_prob": 0.1, 9 | "bos_token_id": 2, 10 | "classifier_dropout": 0.1, 11 | "embedding_size": 128, 12 | "eos_token_id": 3, 13 | "hidden_act": "gelu", 14 | "hidden_dropout_prob": 0.1, 15 | "hidden_size": 768, 16 | "initializer_range": 0.02, 17 | "inner_group_num": 1, 18 | "intermediate_size": 3072, 19 | "layer_norm_eps": 1e-12, 20 | "max_position_embeddings": 512, 21 | "max_relative_position": 64, 22 | "model_type": "nezha", 23 | "num_attention_heads": 12, 24 | "num_hidden_groups": 1, 25 | "num_hidden_layers": 12, 26 | "pad_token_id": 0, 27 | "torch_dtype": "float32", 28 | "transformers_version": "4.20.0.dev0", 29 | "type_vocab_size": 2, 30 | "use_cache": False, 31 | "use_relative_position": True, 32 | 'return_dict': False, 33 | "vocab_size": 21128 34 | } 35 | else: 36 | raise ValueError(f"No {model_name} config") 37 | return config -------------------------------------------------------------------------------- /examples/text_generation/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | ## Merak text generation examples 20 | 21 | This demonstrates how to perform text generation using 3D parallelism in Merak. 22 | 23 | ```bash 24 | torchrun --nproc_per_node=4 \ 25 | run_gpt_text_generation.py \ 26 | --model-name gpt2 \ 27 | --cache-dir ./cache/gpt2-110M \ 28 | --output_dir ./output \ 29 | --per_device_train_batch_size 1 --gradient_accumulation_steps 1 \ 30 | --resume_from_checkpoint output/transformers_model \ 31 | --activation_checkpointing false --checkpoint_num_layers 0 \ 32 | --no_tie_modules true --seed 42 \ 33 | --split_method layer_split 34 | ``` 35 | 36 | 37 | -------------------------------------------------------------------------------- /docs/fx_doc.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | Merak supports dependency-based automatic graph sharding and config-based manual graph sharding. 20 | 21 | # Graph capture 22 | Merak employs trace tool torch.fx and torch._dynamo in Pytorch to trace model into a fx GraphModule. 23 | 1. Please ensure pytorch is installed. Installation please refer to [pytorch repository](https://github.com/pytorch/pytorch). 24 | 25 | 2. Enable graph capture with setting `trace_method='fx'` or `trace_method='dynamo'` in `Merak.MerakArguments`. 26 | 27 | # Graph sharding 28 | Merak supports dependency-based automatic graph sharding and config-based manual graph sharding. 29 | Enable graph sharding with setting `split_method='farthest_min_deps'` or `split_method='nearest_min_deps'` or `split_method='layer_split'` in `Merak.MerakArguments`. 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /Merak/core/fx/tracer/tracers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Parts of the code here are adapted from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/utils/fx.py 17 | 18 | import torch 19 | 20 | class LayerProxyTracer(torch.fx.Tracer): 21 | """Tracer with an extended set of leaf nn.Modules.""" 22 | 23 | def __init__(self, leaf_modules): 24 | super().__init__() 25 | self.leaf_modules = leaf_modules 26 | 27 | def is_manual_leaf_module(self, m): 28 | for i in self.leaf_modules: 29 | if isinstance(m, i): 30 | return True 31 | return False 32 | 33 | def is_leaf_module(self, m: torch.nn.Module, 34 | model_qualified_name: str) -> bool: 35 | return super().is_leaf_module(m, model_qualified_name) or \ 36 | self.is_manual_leaf_module(m) -------------------------------------------------------------------------------- /examples/lora/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | ## Merak lora examples 20 | 21 | This demonstrates how to perform lora using Merak. 22 | 23 | ```bash 24 | torchrun --nproc-per-node=4 run_vit.py \ 25 | --per_device_train_batch_size 128 --gradient_accumulation_steps 4 \ 26 | --cache-dir ./vit_cache \ 27 | --output_dir ./output --remove_unused_columns False \ 28 | --input_name "pixel_values" \ 29 | --activation_checkpointing false --checkpoint_num_layers 0 \ 30 | --num_train_epochs 5 --learning_rate 5e-3 --dataloader_num_workers 4 \ 31 | --evaluation_strategy='epoch' --return_logits true \ 32 | --save --save_steps 100 --seed 42 \ 33 | --resume_from_checkpoint ./output/ckpt --lora_config ./lora_config.json 34 | ``` 35 | 36 | 37 | -------------------------------------------------------------------------------- /examples/text_generation/config.py: -------------------------------------------------------------------------------- 1 | 2 | def load_config(model_name): 3 | if model_name == "gpt2": 4 | config = { 5 | "activation_function": "gelu", 6 | "architectures": [ 7 | "GPT2LMHeadModel" 8 | ], 9 | "attn_pdrop": 0.1, 10 | "bos_token_id": 50256, 11 | "embd_pdrop": 0.1, 12 | "eos_token_id": 50256, 13 | "initializer_range": 0.02, 14 | "layer_norm_epsilon": 1e-05, 15 | "model_type": "gpt2", 16 | "n_ctx": 1024, 17 | "n_embd": 768, 18 | "n_head": 12, 19 | "n_layer": 12, 20 | "n_positions": 1024, 21 | "resid_pdrop": 0.1, 22 | "summary_activation": None, 23 | "summary_first_dropout": 0.1, 24 | "summary_proj_to_labels": True, 25 | "summary_type": "cls_index", 26 | "summary_use_proj": True, 27 | "task_specific_params": { 28 | "text-generation": { 29 | "do_sample": True, 30 | "max_length": 50 31 | } 32 | }, 33 | 'return_dict': False, 34 | 'reorder_and_upcast_attn': False, 35 | 'use_cache': False, 36 | "vocab_size": 50304 37 | } 38 | else: 39 | raise ValueError(f"No {model_name} config") 40 | 41 | return config 42 | -------------------------------------------------------------------------------- /examples/models/gptj/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "gpt-j-6b": 3 | config = { 4 | "activation_function": 
"gelu_new", 5 | "architectures": [ 6 | "GPTJForCausalLM" 7 | ], 8 | "attn_pdrop": 0.0, 9 | "bos_token_id": 50256, 10 | "embd_pdrop": 0.0, 11 | "eos_token_id": 50256, 12 | "gradient_checkpointing": False, 13 | "initializer_range": 0.02, 14 | "layer_norm_epsilon": 1e-05, 15 | "model_type": "gptj", 16 | "n_embd": 4096, 17 | "n_head": 16, 18 | "n_inner": None, 19 | "n_layer": 4, #28, 20 | "n_positions": 2048, 21 | "resid_pdrop": 0.0, 22 | "rotary": True, 23 | "rotary_dim": 64, 24 | "scale_attn_weights": True, 25 | "summary_activation": None, 26 | "summary_first_dropout": 0.1, 27 | "summary_proj_to_labels": True, 28 | "summary_type": "cls_index", 29 | "summary_use_proj": True, 30 | "task_specific_params": { 31 | "text-generation": { 32 | "do_sample": True, 33 | "max_length": 50, 34 | "temperature": 1.0 35 | } 36 | }, 37 | "tie_word_embeddings": False, 38 | "tokenizer_class": "GPT2Tokenizer", 39 | "transformers_version": "4.18.0.dev0", 40 | 'return_dict': False, 41 | "use_cache": False, 42 | "vocab_size": 50400 43 | } 44 | else: 45 | raise ValueError(f"No {model_name} config") 46 | return config -------------------------------------------------------------------------------- /examples/image-classification/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | # Example for running ViT model 21 | 22 | This example show case of ViT model. Model is from `transformers`, but cannot be traced by `transformers.utils.fx`. We fit it in our Merak and can run it with 3D parallelism. 23 | 24 | ## Running ViT model with ImageNet for classification 25 | 26 | --- 27 | 28 | Run it according to following bash: 29 | 30 | ```bash 31 | torchrun --nproc_per_node=4 run_vit.py \ 32 | --per_device_train_batch_size 4 --gradient_accumulation_steps 4 \ 33 | --cache-dir ./vit \ 34 | --data-files /path/to/imagenet \ 35 | --seq_length 1024 --output_dir ./output --remove_unused_columns False \ 36 | --input_name "pixel_values" 37 | ``` 38 | 39 | Code is based on [transformers](https://github.com/huggingface/transformers/tree/master/examples/pytorch/image-classification) repository. -------------------------------------------------------------------------------- /Merak/core/recompute/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | from typing import Any 20 | 21 | def move_to_device(item: Any, device: torch.device): 22 | """ 23 | Move tensor onto device. Works on individual tensors, and tensors contained/ 24 | nested in lists, tuples, and dicts. 25 | Parameters: 26 | item: tensor to move or (possibly nested) container of tensors to move. 
27 | device: target device 28 | 29 | Returns: 30 | None 31 | """ 32 | if torch.is_tensor(item): 33 | return item.to(device) 34 | elif isinstance(item, list): 35 | return [move_to_device(v, device) for v in item] 36 | elif isinstance(item, tuple): 37 | return tuple([move_to_device(v, device) for v in item]) 38 | elif isinstance(item, dict): 39 | return {k: move_to_device(v, device) for k, v in item.items()} 40 | else: 41 | return item -------------------------------------------------------------------------------- /examples/models/t5/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "t5-base": 3 | config = { 4 | 'architectures': ['T5WithLMHeadModel'], 5 | 'd_ff': 3072, 6 | 'd_kv': 64, 7 | 'd_model': 512, 8 | 'decoder_start_token_id': 0, 9 | 'dropout_rate': 0.1, 10 | 'eos_token_id': 1, 11 | 'initializer_factor': 1.0, 12 | 'is_encoder_decoder': True, 13 | 'layer_norm_epsilon': 1e-06, 14 | 'model_type': 't5', 15 | 'n_positions': 512, 16 | 'num_heads': 16, 17 | 'num_layers': 8, 18 | 'output_past': True, 19 | 'pad_token_id': 0, 20 | 'relative_attention_num_buckets': 32, 21 | 'task_specific_params': {'summarization': {'early_stopping': True, 'length_penalty': 2.0, 'max_length': 200, 'min_length': 30, 'no_repeat_ngram_size': 3, 'num_beams': 4, 'prefix': 'summarize: '}, 22 | 'translation_en_to_de': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to German: '}, 23 | 'translation_en_to_fr': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 24 | 'prefix': 'translate English to French: '}, 25 | 'translation_en_to_ro': {'early_stopping': True, 'max_length': 300, 26 | 'num_beams': 4, 27 | 'prefix': 'translate English to Romanian: '}}, 28 | 'vocab_size': 32128, 29 | 'use_cache': False, 30 | 'return_dict': False, 31 | } 32 | return config -------------------------------------------------------------------------------- /examples/models/m2m100/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "m2m100_418M": 3 | config = { 4 | "_name_or_path": "hf_models/m2m100_418M", 5 | "activation_dropout": 0.0, 6 | "activation_function": "relu", 7 | "architectures": [ 8 | "M2M100ForConditionalGeneration" 9 | ], 10 | "attention_dropout": 0.1, 11 | "bos_token_id": 0, 12 | "d_model": 1024, 13 | "decoder_attention_heads": 16, 14 | "decoder_ffn_dim": 4096, 15 | "decoder_layerdrop": 0.05, 16 | "decoder_layers": 4, 17 | "decoder_start_token_id": 2, 18 | "dropout": 0.1, 19 | "early_stopping": True, 20 | "encoder_attention_heads": 16, 21 | "encoder_ffn_dim": 4096, 22 | "encoder_layerdrop": 0.05, 23 | "encoder_layers": 4, 24 | "eos_token_id": 2, 25 | "gradient_checkpointing": False, 26 | "init_std": 0.02, 27 | "is_encoder_decoder": True, 28 | "max_length": 200, 29 | "max_position_embeddings": 1024, 30 | "model_type": "m2m_100", 31 | "num_beams": 5, 32 | "num_hidden_layers": 4, 33 | "pad_token_id": 1, 34 | "scale_embedding": True, 35 | "transformers_version": "4.4.0.dev0", 36 | 'return_dict': False, 37 | "use_cache": False, 38 | "vocab_size": 128112 39 | } 40 | else: 41 | raise ValueError(f"No {model_name} config") 42 | return config -------------------------------------------------------------------------------- /examples/models/lxmert/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == 
"lxmert-vqa-uncased": 3 | config = { 4 | "architectures": [ 5 | "LxmertForQuestionAnswering" 6 | ], 7 | "attention_probs_dropout_prob": 0.1, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "l_layers": 9, 14 | "layer_norm_eps": 1e-12, 15 | "max_position_embeddings": 512, 16 | "model_type": "lxmert", 17 | "num_attention_heads": 12, 18 | "num_attr_labels": 400, 19 | "num_hidden_layers": { 20 | "cross_encoder": 5, 21 | "language": 9, 22 | "vision": 5 23 | }, 24 | "num_object_labels": 1600, 25 | "num_qa_labels": 3129, 26 | "r_layers": 5, 27 | "task_mask_lm": True, 28 | "task_matched": True, 29 | "task_obj_predict": True, 30 | "task_qa": True, 31 | "type_vocab_size": 2, 32 | "visual_attr_loss": True, 33 | "visual_feat_dim": 2048, 34 | "visual_feat_loss": True, 35 | "visual_loss_normalizer": 6.67, 36 | "visual_obj_loss": True, 37 | "visual_pos_dim": 4, 38 | "vocab_size": 30522, 39 | 'return_dict': False, 40 | "use_cache": False, 41 | "x_layers": 5 42 | } 43 | else: 44 | raise ValueError(f"No {model_name} config") 45 | return config -------------------------------------------------------------------------------- /examples/torch-models/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | # Running with Swin-transformer and torchvision models 20 | 21 | This script shows that the torch model, which is not from `transformers` library, but can be traced by `torch.fx`, how to run with 3D parallelism in Merak. 22 | 23 | Running with following bash: 24 | 25 | ```bash 26 | torchrun --nproc_per_node=4 run_swin.py \ 27 | --cfg ./swin_base_patch4_window7_224.yaml \ 28 | --output_dir ./output \ 29 | --per_device_train_batch_size 4 --gradient_accumulation_steps 4 \ 30 | --num_layers 16 --wall_clock_breakdown True --logging_steps 10 \ 31 | --data_path /path/to/datasets 32 | 33 | torchrun --nproc-per-node=4 run_torchvision.py \ 34 | --data_path /path/to/datasets \ 35 | --output_dir ./output \ 36 | --cfg ./swin_base_patch4_window7_224.yaml \ 37 | --num_layers 152 38 | ``` 39 | 40 | Code is based on [Swin-transformer](https://github.com/microsoft/Swin-Transformer) repository. 
41 | -------------------------------------------------------------------------------- /examples/models/bart/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "bart-large-mnli": 3 | config = { 4 | "_num_labels": 3, 5 | "activation_dropout": 0.0, 6 | "activation_function": "gelu", 7 | "add_final_layer_norm": False, 8 | "architectures": [ 9 | "BartForSequenceClassification" 10 | ], 11 | "_attn_implementation": 'eager', 12 | "attention_dropout": 0.0, 13 | "bos_token_id": 0, 14 | "classif_dropout": 0.0, 15 | "classifier_dropout": 0.0, 16 | "d_model": 1024, 17 | "decoder_attention_heads": 16, 18 | "decoder_ffn_dim": 4096, 19 | "decoder_layerdrop": 0.0, 20 | "decoder_layers": 12, 21 | "decoder_start_token_id": 2, 22 | "dropout": 0.1, 23 | "encoder_attention_heads": 16, 24 | "encoder_ffn_dim": 4096, 25 | "encoder_layerdrop": 0.0, 26 | "encoder_layers": 12, 27 | "eos_token_id": 2, 28 | "forced_eos_token_id": 2, 29 | "gradient_checkpointing": False, 30 | "id2label": { 31 | "0": "contradiction", 32 | "1": "neutral", 33 | "2": "entailment" 34 | }, 35 | "init_std": 0.02, 36 | "is_encoder_decoder": True, 37 | "label2id": { 38 | "contradiction": 0, 39 | "entailment": 2, 40 | "neutral": 1 41 | }, 42 | "max_position_embeddings": 1024, 43 | "model_type": "bart", 44 | "normalize_before": False, 45 | "num_hidden_layers": 12, 46 | "output_past": False, 47 | "pad_token_id": 1, 48 | "scale_embedding": False, 49 | "transformers_version": "4.7.0.dev0", 50 | "use_cache": False, 51 | "vocab_size": 50265 52 | } 53 | return config -------------------------------------------------------------------------------- /examples/unet/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | ## Merak examples 20 | 21 | These examples show that which model can be run with 3D parallelism in Merak. It shows that five popular models of pytorch, including GPT2, ViT, BERT, T5, Swin-transformer, running with 3D parallelism in Merak. These models show three cases of training model: 22 | 23 | 1. Model can be traced by `transformers.utils.fx` , like GPT2, T5 and BERT. 24 | 2. Model is from `transformers`, but cannot be traced by `transformers.utils.fx`, like ViT. 25 | 3. Model is not from `transformers`, but can be traced by `torch.fx`, like Swin-tranfromer. 26 | 27 | User could make sense of Merak's mechanism by these examples, and apply it to another models. Currently, the running bash is on a machine with 4 GPUs. 28 | 29 | ```bash 30 | torchrun --nproc_per_node=4 \ 31 | run_unet.py \ 32 | --train_dataset "/path/to/datasets/" \ 33 | --output_dir ./output \ 34 | --checkpoint_path ./output/ckpt \ 35 | --per_device_train_batch_size 32 --gradient_accumulation_steps 1 \ 36 | --wall_clock_breakdown true --logging_steps 1 \ 37 | --input_name x \ 38 | --trace_method dynamo --crop 128 \ 39 | ``` 40 | -------------------------------------------------------------------------------- /test/core/fx/test_convert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # Test command: 19 | # yhrun -p 3090 -N 1 -n 1 torchrun --nproc-per-node=4 test_convert.py --output_dir ./ 20 | 21 | import Merak 22 | 23 | from transformers import ( 24 | GPT2LMHeadModel, 25 | GPT2Config, 26 | HfArgumentParser, 27 | ) 28 | 29 | from Merak import MerakArguments 30 | from Merak.merak_args import mergeargs, manual_set_args 31 | from Merak.core.fx import convert_to_sequential 32 | 33 | def main(): 34 | # init dist 35 | pp = 4 36 | tp = 1 37 | dp = 1 38 | Merak.init(pp, tp, dp) 39 | 40 | # merge args 41 | hfparser = HfArgumentParser(MerakArguments) 42 | training_args = hfparser.parse_args_into_dataclasses()[0] 43 | 44 | config = GPT2Config() 45 | 46 | # create model 47 | model = GPT2LMHeadModel(config) 48 | 49 | # set args 50 | mergeargs(training_args, model.config) 51 | manual_set_args(training_args) 52 | 53 | # convert to sequential 54 | model, model_layers, input_to_shard_dic = convert_to_sequential(model, training_args) 55 | 56 | print('==model==', model) 57 | print('==model_layers==', model_layers) 58 | print('==input_to_shard_dic==', input_to_shard_dic) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() -------------------------------------------------------------------------------- /examples/bert_pretraining/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | # BERT pretraining 21 | 22 | Here is an example code of [BERT](https://arxiv.org/abs/1810.04805) pretraining with Merak. 23 | 24 | ## Getting the data 25 | 26 | This sample uses hdf5 data files. The preparation of the pre-training dataset is described in 27 | [NVIDIA Examples](https://github.com/NVIDIA/DeepLearningExamples/tree/180382499f791d38eb9e91c105d75764cd2f1cd7/PyTorch/LanguageModeling/BERT#getting-the-data) and its related [scripts](https://github.com/NVIDIA/DeepLearningExamples/tree/180382499f791d38eb9e91c105d75764cd2f1cd7/PyTorch/LanguageModeling/BERT/data). 28 | 29 | ## Training with Merak 30 | 31 | The main modification of training scipts is related to hdf5 data loading and customized loss function. 32 | 33 | --- 34 | 35 | Running according to following bash: 36 | 37 | ```bash 38 | export TOKENIZERS_PARALLELISM=false 39 | torchrun --nproc_per_node=4 run_bert.py \ 40 | --model-name bert-large \ 41 | --data-files "/path/to/hdf5_file/" \ 42 | --output_dir ./output \ 43 | --per_device_train_batch_size 4 --gradient_accumulation_steps 16 \ 44 | --logging_steps 10 \ 45 | --input_name input_ids attention_mask token_type_ids \ 46 | --dataloader_num_workers 2 47 | ``` 48 | -------------------------------------------------------------------------------- /docs/amp_doc.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | # Mixed Precision Training 20 | 21 | Merak supports automatic mixed precision (amp) training and fully fp16 training. 22 | 23 | ## Automatic mixed precision training 24 | 25 | 1. Please ensure apex is installed. Installation please refer to [apex repository](https://github.com/NVIDIA/apex). 26 | 27 | 2. 
Enable amp training with setting `half_precision_backend='amp'` or `half_precision_backend='apex'` in `Merak.MerakArguments`. Detail usage of this config can be found in [transformers trainer arguments](https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/trainer#transformers.TrainingArguments). 28 | 29 | 3. Amp in Apex training supports O0/O1/O2/O3 level, DP, and PMP for now. 30 | 31 | 32 | ## FP16 training 33 | 34 | 35 | Enable fp16 training with setting `fp16=true` in `Merak.MerakArguments`. Detail usage of this config can be found in [transformers trainer arguments](https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/trainer#transformers.TrainingArguments). 36 | 37 | More configuration of fp16 training, including `loss_scale`, `initial_scale_power`, `loss_scale_window`, `hysteresis`, and `min_loss_scale`, can be setted with `Merak.MerakArguments`. Detail usages of these configs please refer to our api [document](https://github.com/HPDL-Group/Merak/blob/main/docs/api_doc.md). -------------------------------------------------------------------------------- /examples/models/speech2text/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "s2t-medium-librispeech-asr": 3 | config = { 4 | "_name_or_path": "hf_models_fb/s2t-medium-librispeech-asr/", 5 | "activation_dropout": 0.15, 6 | "activation_function": "relu", 7 | "architectures": [ 8 | "Speech2TextForConditionalGeneration" 9 | ], 10 | "attention_dropout": 0.15, 11 | "bos_token_id": 0, 12 | "classifier_dropout": 0.0, 13 | "conv_channels": 1024, 14 | "conv_kernel_sizes": [ 15 | 5, 16 | 5 17 | ], 18 | "d_model": 512, 19 | "decoder_attention_heads": 8, 20 | "decoder_ffn_dim": 2048, 21 | "decoder_layerdrop": 0.0, 22 | "decoder_layers": 6, 23 | "decoder_start_token_id": 2, 24 | "dropout": 0.15, 25 | "early_stopping": True, 26 | "encoder_attention_heads": 8, 27 | "encoder_ffn_dim": 2048, 28 | "encoder_layerdrop": 0.0, 29 | "encoder_layers": 12, 30 | "eos_token_id": 2, 31 | "gradient_checkpointing": False, 32 | "init_std": 0.02, 33 | "input_channels": 1, 34 | "input_feat_per_channel": 80, 35 | "is_encoder_decoder": True, 36 | "max_length": 200, 37 | "max_source_positions": 6000, 38 | "max_target_positions": 1024, 39 | "model_type": "speech_to_text", 40 | "num_beams": 5, 41 | "num_conv_layers": 2, 42 | "num_hidden_layers": 12, 43 | "pad_token_id": 1, 44 | "scale_embedding": True, 45 | "transformers_version": "4.4.0.dev0", 46 | "use_cache": False, 47 | 'return_dict': False, 48 | "vocab_size": 10000 49 | } 50 | else: 51 | raise ValueError(f"No {model_name} config") 52 | return config -------------------------------------------------------------------------------- /examples/torch-models/run_torchvision.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | from Merak import MerakArguments 20 | from Merak import MerakTrainer 21 | 22 | import os 23 | from transformers import ( 24 | HfArgumentParser 25 | ) 26 | 27 | from config import get_config 28 | from data import build_loader 29 | import torchvision 30 | 31 | 32 | def parse_option(parser): 33 | group = parser.add_argument_group('Torchvision model training and evaluation script') 34 | group.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 35 | group.add_argument('--data_path', type=str, default=None, help='path to data folder', ) 36 | 37 | return parser 38 | 39 | 40 | def main(config): 41 | dataset_train, dataset_val, _, _, _ = build_loader(config) 42 | 43 | model = torchvision.models.resnet152() 44 | 45 | trainer = MerakTrainer( 46 | model=model, 47 | args=training_args, 48 | train_dataset=dataset_train, 49 | eval_dataset=dataset_val, 50 | ) 51 | trainer.train() 52 | 53 | 54 | if __name__ == '__main__': 55 | pp = 2 56 | tp = 1 57 | dp = 2 58 | Merak.init(pp, tp, dp) 59 | 60 | hfparser = HfArgumentParser(MerakArguments) 61 | parser = parse_option(hfparser) 62 | training_args, args = parser.parse_args_into_dataclasses() 63 | 64 | # using data config from swin transformer 65 | config = get_config(args) 66 | main(config) 67 | -------------------------------------------------------------------------------- /test/core/fx/tracer/test_symbolic_trace.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
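# This test symbolically traces a GPT-2 model under a 4-stage pipeline setup
# and prints the resulting traced GraphModule together with the dummy inputs
# generated for tracing.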
17 | 18 | # Test command: 19 | # yhrun -p 3090 -N 1 -n 1 torchrun --nproc-per-node=4 test_symbolic_trace.py --output_dir ./ 20 | 21 | import Merak 22 | 23 | from transformers import ( 24 | GPT2LMHeadModel, 25 | GPT2Config, 26 | HfArgumentParser, 27 | ) 28 | 29 | from Merak import MerakArguments 30 | from Merak.merak_args import mergeargs, manual_set_args 31 | from Merak.core.fx.tracer import symbolic_trace 32 | 33 | def main(): 34 | # init dist 35 | pp = 4 36 | tp = 1 37 | dp = 1 38 | Merak.init(pp, tp, dp) 39 | 40 | # merge args 41 | hfparser = HfArgumentParser(MerakArguments) 42 | training_args = hfparser.parse_args_into_dataclasses()[0] 43 | 44 | config = GPT2Config() 45 | 46 | # create model 47 | model = GPT2LMHeadModel(config) 48 | 49 | # set args 50 | mergeargs(training_args, model.config) 51 | manual_set_args(training_args) 52 | 53 | # trace model 54 | traced, dummy_inputs = symbolic_trace( 55 | model, 56 | input_names = training_args.input_names, 57 | batch_size = training_args.per_device_train_batch_size, 58 | sequence_length = training_args.seq_length, 59 | ) 60 | 61 | print(['==dummy_inputs==', dummy_inputs]) 62 | print(['==traced==', traced]) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /Merak/core/finetuning/lora/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
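# The mapping below is keyed by the Hugging Face `model_type` string; each
# entry lists the module names (typically the attention, and sometimes MLP,
# projection layers) that LoRA adapters target by default for that architecture.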
17 | 18 | # The code here are adapted from https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py 19 | 20 | 21 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { 22 | "t5": ["q", "v"], 23 | "mt5": ["q", "v"], 24 | "bart": ["q_proj", "v_proj"], 25 | "gpt2": ["c_attn"], 26 | "bloom": ["query_key_value"], 27 | "blip-2": ["q", "v", "q_proj", "v_proj"], 28 | "opt": ["q_proj", "v_proj"], 29 | "gptj": ["q_proj", "v_proj"], 30 | "gpt_neox": ["query_key_value"], 31 | "gpt_neo": ["q_proj", "v_proj"], 32 | "bert": ["query", "value"], 33 | "roberta": ["query", "value"], 34 | "xlm-roberta": ["query", "value"], 35 | "electra": ["query", "value"], 36 | "deberta-v2": ["query_proj", "value_proj"], 37 | "deberta": ["in_proj"], 38 | "layoutlm": ["query", "value"], 39 | "llama": ["q_proj", "v_proj"], 40 | "chatglm": ["query_key_value"], 41 | "gpt_bigcode": ["c_attn"], 42 | "mpt": ["Wqkv"], 43 | "RefinedWebModel": ["query_key_value"], 44 | "RefinedWeb": ["query_key_value"], 45 | "falcon": ["query_key_value"], 46 | "btlm": ["c_proj", "c_attn"], 47 | "codegen": ["qkv_proj"], 48 | "mistral": ["q_proj", "v_proj"], 49 | "mixtral": ["q_proj", "v_proj"], 50 | "stablelm": ["q_proj", "v_proj"], 51 | "phi": ["q_proj", "v_proj", "fc1", "fc2"], 52 | "gemma": ["q_proj", "v_proj"], 53 | "qwen2": ["q_proj", "v_proj"], 54 | } 55 | 56 | 57 | -------------------------------------------------------------------------------- /examples/torch-models/config.json: -------------------------------------------------------------------------------- 1 | AMP_OPT_LEVEL: '' 2 | AUG: 3 | AUTO_AUGMENT: rand-m9-mstd0.5-inc1 4 | COLOR_JITTER: 0.4 5 | CUTMIX: 1.0 6 | CUTMIX_MINMAX: null 7 | MIXUP: 0.8 8 | MIXUP_MODE: batch 9 | MIXUP_PROB: 1.0 10 | MIXUP_SWITCH_PROB: 0.5 11 | RECOUNT: 1 12 | REMODE: pixel 13 | REPROB: 0.25 14 | BASE: 15 | - '' 16 | DATA: 17 | BATCH_SIZE: 128 18 | CACHE_MODE: part 19 | DATASET: imagenet 20 | DATA_PATH: /ssd/datasets/imagenet/pytorch/ 21 | IMG_SIZE: 224 22 | INTERPOLATION: bicubic 23 | NUM_WORKERS: 2 24 | PIN_MEMORY: true 25 | ZIP_MODE: false 26 | EVAL_MODE: false 27 | LOCAL_RANK: 0 28 | MODEL: 29 | DROP_PATH_RATE: 0.5 30 | DROP_RATE: 0.0 31 | LABEL_SMOOTHING: 0.1 32 | NAME: swin_base_patch4_window7_224 33 | NUM_CLASSES: 1000 34 | PRETRAINED: '' 35 | RESUME: '' 36 | SWIN: 37 | APE: false 38 | DEPTHS: 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | EMBED_DIM: 128 44 | IN_CHANS: 3 45 | MLP_RATIO: 4.0 46 | NUM_HEADS: 47 | - 2 48 | - 2 49 | - 2 50 | - 2 51 | PATCH_NORM: true 52 | PATCH_SIZE: 4 53 | QKV_BIAS: true 54 | QK_SCALE: null 55 | WINDOW_SIZE: 7 56 | SWIN_MLP: 57 | APE: false 58 | DEPTHS: 59 | - 2 60 | - 2 61 | - 6 62 | - 2 63 | EMBED_DIM: 96 64 | IN_CHANS: 3 65 | MLP_RATIO: 4.0 66 | NUM_HEADS: 67 | - 3 68 | - 6 69 | - 12 70 | - 24 71 | PATCH_NORM: true 72 | PATCH_SIZE: 4 73 | WINDOW_SIZE: 7 74 | TYPE: swin 75 | OUTPUT: '' 76 | PRINT_FREQ: 10 77 | SAVE_FREQ: 1 78 | SEED: 0 79 | TAG: default 80 | TEST: 81 | CROP: true 82 | SEQUENTIAL: false 83 | THROUGHPUT_MODE: false 84 | TRAIN: 85 | ACCUMULATION_STEPS: 0 86 | AUTO_RESUME: true 87 | BASE_LR: 0.0005 88 | CLIP_GRAD: 5.0 89 | EPOCHS: 300 90 | LR_SCHEDULER: 91 | DECAY_EPOCHS: 30 92 | DECAY_RATE: 0.1 93 | NAME: cosine 94 | MIN_LR: 5.0e-06 95 | OPTIMIZER: 96 | BETAS: 97 | - 0.9 98 | - 0.999 99 | EPS: 1.0e-08 100 | MOMENTUM: 0.9 101 | NAME: adamw 102 | START_EPOCH: 0 103 | USE_CHECKPOINT: false 104 | WARMUP_EPOCHS: 20 105 | WARMUP_LR: 5.0e-07 106 | WEIGHT_DECAY: 0.05 107 | -------------------------------------------------------------------------------- 
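The LoRA target-module mapping defined earlier in `Merak/core/finetuning/lora/mappings.py` is keyed by a model's Hugging Face `model_type`. As a minimal sketch only — assuming the mapping can be imported directly from that module path, which this listing does not demonstrate — a lookup might look like:

```python
# Hypothetical sketch (not part of the repository): resolve the default LoRA
# target modules for a model from its Hugging Face `model_type`.
from transformers import GPT2Config

from Merak.core.finetuning.lora.mappings import (
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
)

config = GPT2Config()  # config.model_type == "gpt2"
target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(config.model_type)
print(target_modules)  # expected: ["c_attn"]
```

Using `dict.get` keeps the lookup non-fatal: an unsupported `model_type` simply returns `None`, which the caller can treat as "no default target modules known".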
/Merak/core/tensor_parallel/mp_mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from transformers.utils.fx import _generate_supported_model_class_names 19 | from typing import Callable, Optional 20 | 21 | SUPPORTED_MODEL_NAMES = [ 22 | "bert", 23 | "gpt2", 24 | "t5", 25 | "vit" 26 | ] 27 | 28 | MP_MODEL_MAPPING = { 29 | 'gpt2': { 30 | 'input_output_mapping':[(1, 3, 'col'), (1, 1,'row'), 31 | (1, 4, 'col'), (4, 1, 'row')], 32 | 'tp_attr_list':['num_heads', 'split_size'] 33 | }, 34 | 't5': { 35 | 'col_para_list':['Attention.q', 'Attention.k', 36 | 'Attention.v', 'DenseReluDense.wi'], 37 | 'row_para_list':['Attention.o', 'DenseReluDense.wo'], 38 | 'weight_change_list':[('relative_attention_bias', 1)], 39 | 'tp_attr_list':['n_heads', 'inner_dim'] 40 | }, 41 | 'bert':{ 42 | 'col_para_list':['query', 'key', 'value', 'intermediate.dense'], 43 | 'row_para_list':['output.dense'], 44 | 'tp_attr_list':['num_attention_heads','all_head_size'] 45 | }, 46 | 'vit':{ 47 | 'col_para_list':['query', 'key', 'value', 'intermediate.dense'], 48 | 'row_para_list':['output.dense'], 49 | 'tp_attr_list':['num_attention_heads','all_head_size'] 50 | }, 51 | } 52 | 53 | def get_mp_layer_lists(model_class: Callable) -> Optional[dict]: 54 | for model_name in SUPPORTED_MODEL_NAMES: 55 | if model_class.__name__ in \ 56 | _generate_supported_model_class_names(model_name): 57 | return MP_MODEL_MAPPING[model_name] 58 | return None -------------------------------------------------------------------------------- /test/core/pipeline_parallel/test_module.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
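# This test traces a GPT-2 model, shards it into sequential layers, wraps the
# layers in a 4-stage PipelineModule, and prints the resulting stage partitions.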
17 | 18 | # test command 19 | # yhrun -N 1 -n 1 -p 3090 torchrun --nproc-per-node=4 test_module.py --output_dir ./output 20 | 21 | import torch 22 | import Merak 23 | 24 | from Merak import get_topo, get_grid, MerakArguments 25 | from Merak.core.pipeline import PipelineModule 26 | from Merak.core.recompute import checkpoint as checkpoint_func 27 | from Merak.core.fx import convert_to_sequential 28 | from Merak.merak_args import mergeargs 29 | 30 | from transformers import ( 31 | HfArgumentParser, 32 | GPT2LMHeadModel, 33 | GPT2Config, 34 | ) 35 | 36 | def main(): 37 | # config_kwarg = load_config("layoutlm-base-uncased") 38 | Merak.init(4, 1, 1) 39 | 40 | # merge args 41 | hfparser = HfArgumentParser(MerakArguments) 42 | training_args = hfparser.parse_args_into_dataclasses()[0] 43 | 44 | config = GPT2Config() 45 | 46 | # create model 47 | model = GPT2LMHeadModel(config).cuda() 48 | 49 | mergeargs(training_args, model.config) 50 | 51 | model, layers, input_to_shard = convert_to_sequential(model, training_args) 52 | del model 53 | 54 | pipe_model = PipelineModule( 55 | layers=layers, 56 | args=training_args, 57 | loss_fn=torch.nn.CrossEntropyLoss(), 58 | topology=get_topo(), 59 | communicaiton_grid=get_grid(), 60 | activation_checkpoint_func=checkpoint_func, 61 | tie_dims=set(), 62 | input_to_shard_dic=input_to_shard, 63 | ) 64 | 65 | print(pipe_model.partitions()) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /examples/models/marianmt/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "opus-mt-zh-en": 3 | config = { 4 | "_name_or_path": "/tmp/Helsinki-NLP/opus-mt-zh-en", 5 | "activation_dropout": 0.0, 6 | "activation_function": "swish", 7 | "add_bias_logits": False, 8 | "add_final_layer_norm": False, 9 | "architectures": [ 10 | "MarianMTModel" 11 | ], 12 | "attention_dropout": 0.0, 13 | "bad_words_ids": [ 14 | [ 15 | 65000 16 | ] 17 | ], 18 | "bos_token_id": 0, 19 | "classif_dropout": 0.0, 20 | "classifier_dropout": 0.0, 21 | "d_model": 512, 22 | "decoder_attention_heads": 8, 23 | "decoder_ffn_dim": 2048, 24 | "decoder_layerdrop": 0.0, 25 | "decoder_layers": 6, 26 | "decoder_start_token_id": 65000, 27 | "decoder_vocab_size": 65001, 28 | "dropout": 0.1, 29 | "encoder_attention_heads": 8, 30 | "encoder_ffn_dim": 2048, 31 | "encoder_layerdrop": 0.0, 32 | "encoder_layers": 6, 33 | "eos_token_id": 0, 34 | "extra_pos_embeddings": 65001, 35 | "forced_eos_token_id": 0, 36 | "id2label": { 37 | "0": "LABEL_0", 38 | "1": "LABEL_1", 39 | "2": "LABEL_2" 40 | }, 41 | "init_std": 0.02, 42 | "is_encoder_decoder": True, 43 | "label2id": { 44 | "LABEL_0": 0, 45 | "LABEL_1": 1, 46 | "LABEL_2": 2 47 | }, 48 | "max_length": 512, 49 | "max_position_embeddings": 512, 50 | "model_type": "marian", 51 | "normalize_before": False, 52 | "normalize_embedding": False, 53 | "num_beams": 6, 54 | "num_hidden_layers": 6, 55 | "pad_token_id": 65000, 56 | "scale_embedding": True, 57 | "share_encoder_decoder_embeddings": True, 58 | "static_position_embeddings": True, 59 | "transformers_version": "4.22.0.dev0", 60 | "use_cache": False, 61 | 'return_dict': False, 62 | "vocab_size": 65001 63 | } 64 | else: 65 | raise ValueError(f"No {model_name} config") 66 | return config -------------------------------------------------------------------------------- /Merak/core/mpu/__init__.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # Parts of the code here are adapted from https://github.com/NVIDIA/Megatron-LM/blob/806422e5ec35c27b027dbb413b05e27b6590dc56/megatron/mpu/__init__.py 19 | 20 | """Model parallel utility interface.""" 21 | 22 | from .cross_entropy import vocab_parallel_cross_entropy 23 | 24 | from .initialize import destroy_model_parallel 25 | from .initialize import get_data_parallel_group 26 | from .initialize import get_data_parallel_rank 27 | from .initialize import get_data_parallel_world_size 28 | from .initialize import get_model_parallel_group 29 | from .initialize import get_model_parallel_rank, set_model_parallel_rank 30 | from .initialize import get_model_parallel_src_rank 31 | from .initialize import get_model_parallel_world_size, set_model_parallel_world_size 32 | from .initialize import get_pipe_parallel_group 33 | from .initialize import get_pipe_parallel_rank 34 | from .initialize import get_pipe_parallel_world_size 35 | from .initialize import get_pipeline_model_parallel_prev_rank 36 | from .initialize import get_pipeline_model_parallel_next_rank 37 | from .initialize import is_pipeline_first_stage, is_pipeline_last_stage 38 | 39 | from .layers import LayerNorm 40 | from .layers import ColumnParallelLinear 41 | from .layers import RowParallelLinear 42 | from .layers import VocabParallelEmbedding 43 | from .layers import RowSequenceParallel 44 | from .layers import ColumnSequenceParallel 45 | from .layers import ColParallelConv2d 46 | 47 | from .mappings import copy_to_model_parallel_region 48 | from .mappings import gather_from_model_parallel_region 49 | from .mappings import reduce_from_model_parallel_region 50 | from .mappings import scatter_to_model_parallel_region 51 | 52 | from .utils import divide 53 | from .utils import split_tensor_along_last_dim -------------------------------------------------------------------------------- /examples/models/blenderbot/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "blenderbot-400M-distill": 3 | config = { 4 | "_name_or_path": "./", 5 | "activation_dropout": 0.0, 6 | "activation_function": "gelu", 7 | "add_bias_logits": False, 8 | "add_final_layer_norm": True, 9 | "architectures": [ 10 | "BlenderbotForConditionalGeneration" 11 | ], 12 | "attention_dropout": 0.0, 13 | "bos_token_id": 1, 14 | "classif_dropout": 0.0, 15 | "classifier_dropout": 0.0, 16 | "d_model": 1280, 17 | "decoder_attention_heads": 32, 18 | "decoder_ffn_dim": 5120, 19 | "decoder_layerdrop": 0.0, 20 | "decoder_layers": 12, 21 | "decoder_start_token_id": 1, 22 | "do_blenderbot_90_layernorm": True, 23 | "dropout": 0.1, 24 | "encoder_attention_heads": 32, 25 | 
"encoder_ffn_dim": 5120, 26 | "encoder_layerdrop": 0.0, 27 | "encoder_layers": 2, 28 | "encoder_no_repeat_ngram_size": 3, 29 | "eos_token_id": 2, 30 | "extra_layer_norm": False, 31 | "extra_pos_embeddings": 0, 32 | "force_bos_token_to_be_generated": False, 33 | "forced_eos_token_id": 2, 34 | "gradient_checkpointing": False, 35 | "id2label": { 36 | "0": "LABEL_0", 37 | "1": "LABEL_1", 38 | "2": "LABEL_2" 39 | }, 40 | "init_std": 0.02, 41 | "is_encoder_decoder": True, 42 | "label2id": { 43 | "LABEL_0": 0, 44 | "LABEL_1": 1, 45 | "LABEL_2": 2 46 | }, 47 | "layernorm_variant": "prelayernorm", 48 | "length_penalty": 0.65, 49 | "max_length": 60, 50 | "max_position_embeddings": 128, 51 | "min_length": 20, 52 | "model_type": "blenderbot", 53 | "no_repeat_ngram_size": 3, 54 | "normalize_before": True, 55 | "normalize_embedding": False, 56 | "num_beams": 10, 57 | "num_hidden_layers": 2, 58 | "pad_token_id": 0, 59 | "scale_embedding": True, 60 | "static_position_embeddings": False, 61 | "transformers_version": "4.13.0.dev0", 62 | "unk_token_id": 3, 63 | "use_cache": True, 64 | "vocab_size": 8008 65 | } 66 | return config -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | # Merak Parallel Training Framework 21 | 22 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 23 | 24 | Merak is a high-performance distributed training framework designed for 3D parallel model training. Supports seamless parallelization across multiple GPUs/nodes. 25 | 26 | ## 🔥 Supported Models 27 | 28 | ### Population Models 29 | | | | | | 30 | |---|---|---|---| 31 | | ✅ **DeepseekR1** | 32 | 33 | ### Natural Language Processing 34 | | | | | | 35 | |---|---|---|---| 36 | | ✅ ALBERT | ✅ Bart | ✅ BERT | ✅ BlenderBot | 37 | | ✅ DistilBERT | ✅ Electra | ✅ GPT-2 | ✅ GPT-J | 38 | | ✅ LLaMA | ✅ MarianMT | ✅ mBART | ✅ mT5 | 39 | | ✅ Nezha | ✅ Pegasus | ✅ PLBART | ✅ T5 | 40 | | ✅ XGLM | ✅ OPT | ✅ m2m100 | ✅ LayoutLM | 41 | 42 | ### Multimodal & Vision-Language 43 | | | | | 44 | |---|---|---| 45 | | ✅ CLIP | ✅ AltCLIP | ✅ TroCR | 46 | 47 | ### Computer Vision 48 | | | | | | 49 | |---|---|---|---| 50 | | ✅ ConvNeXt | ✅ ResNet | ✅ Swin | ✅ DINOv2 | 51 | | ✅ SegFormer | ✅ MobileBERT | ✅ LXMERT | 52 | | ✅ UNet | ✅ ViT | 53 | 54 | ### Speech Processing 55 | | | | | 56 | |---|---|---| 57 | | ✅ Wav2Vec2 | ✅ Speech2Text | ✅ Speech2Text2 | 58 | | ✅ Hubert | 59 | 60 | 61 | ## ✨ Key Features 62 | - ‌**Multi-Strategy Parallelism**‌ 63 | Tensor/Data/Pipeline Parallelism hybrid training 64 | - ‌**Memory Optimization**‌ 65 | Zero Redundancy Optimizer (ZeRO) Stage 1 66 | - ‌**High-Performance Pipeline Parallelism Strategy**‌ 67 | last_no_recompute_1f1b/full_critical_path_1f1b 68 | - ‌**Other Functions**‌ 69 | Lora/Text Generation 70 | 71 | ## 📌 Key Notes 72 | 1. ‌**Compatibility Levels**‌ 73 | - ✅ Full Support: Out-of-the-box parallel strategies 74 | - ⚠️ Partial Support: Requires manual configuration or has functional constraints 75 | -------------------------------------------------------------------------------- /examples/models/bert/run_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 
3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | from config import load_config 22 | from transformers import ( 23 | set_seed, 24 | BertForMaskedLM, 25 | BertForPreTraining, 26 | BertConfig, 27 | HfArgumentParser, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | def parse_option(parser): 33 | # easy config modification 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. bert)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # init dist 40 | pp = 2 41 | tp = 2 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | hfparser = HfArgumentParser(MerakArguments) 46 | parser = parse_option(hfparser) 47 | training_args, args = parser.parse_args_into_dataclasses() 48 | 49 | # Set seed before initializing model. 50 | set_seed(training_args.seed) 51 | 52 | # set model config 53 | config_kwarg = load_config(args.model_name) 54 | config = BertConfig(**config_kwarg) 55 | 56 | 57 | with init_empty_weights(): 58 | if args.model_name == 'bert-large-uncased': 59 | model = BertForPreTraining(config) 60 | elif args.model_name == 'bert-large' or args.model_name == 'bert-base-uncased': 61 | model = BertForMaskedLM(config) 62 | 63 | # Create a fake dataset for training 64 | train_dataset = DynamicGenDataset( 65 | model.config, mode="text_only", dataset_size=1e6 66 | ) 67 | 68 | 69 | # using our distributed trainer 70 | trainer = MerakTrainer( 71 | model=model, 72 | args=training_args, 73 | train_dataset=train_dataset, 74 | ) 75 | 76 | # Training 77 | trainer.train() 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /examples/models/gpt2/run_gpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | GPT2LMHeadModel, 27 | GPT2Config, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 4 42 | tp = 1 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | # Parse training and model arguments 47 | hfparser = HfArgumentParser(MerakArguments) 48 | parser = parse_option(hfparser) 49 | training_args, args = parser.parse_args_into_dataclasses() 50 | 51 | # Set random seed for reproducibility 52 | set_seed(training_args.seed) 53 | 54 | # Load model configuration from custom config file 55 | config_kwarg = load_config(args.model_name) 56 | config = GPT2Config(**config_kwarg) 57 | 58 | # Initialize GPT-2 language modeling model 59 | with init_empty_weights(): 60 | model = GPT2LMHeadModel(config) 61 | 62 | # Create a fake dataset for training 63 | train_dataset = DynamicGenDataset( 64 | model.config, mode="text_only", dataset_size=1e6 65 | ) 66 | 67 | # Initialize trainer with model, training arguments and dataset 68 | trainer = MerakTrainer( 69 | model=model, 70 | args=training_args, 71 | train_dataset=train_dataset, 72 | ) 73 | 74 | # Start training 75 | trainer.train() 76 | 77 | 78 | if __name__ == "__main__": 79 | main() -------------------------------------------------------------------------------- /examples/models/hubert/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "hubert-base-ls960": 3 | config = { 4 | "_name_or_path": "facebook/hubert-base-ls960", 5 | "activation_dropout": 0.1, 6 | "apply_spec_augment": True, 7 | "architectures": [ 8 | "HubertModel" 9 | ], 10 | "attention_dropout": 0.1, 11 | "bos_token_id": 1, 12 | "conv_bias": False, 13 | "conv_dim": [ 14 | 512, 15 | 512, 16 | 512, 17 | 512, 18 | 512, 19 | 512, 20 | 512 21 | ], 22 | "conv_kernel": [ 23 | 10, 24 | 3, 25 | 3, 26 | 3, 27 | 3, 28 | 2, 29 | 2 30 | ], 31 | "conv_stride": [ 32 | 5, 33 | 2, 34 | 2, 35 | 2, 36 | 2, 37 | 2, 38 | 2 39 | ], 40 | "ctc_loss_reduction": "sum", 41 | "ctc_zero_infinity": False, 42 | "do_stable_layer_norm": False, 43 | "eos_token_id": 2, 44 | "feat_extract_activation": "gelu", 45 | "feat_extract_dropout": 0.0, 46 | "feat_extract_norm": "group", 47 | "feat_proj_dropout": 0.1, 48 | "final_dropout": 0.1, 49 | "gradient_checkpointing": False, 50 | "hidden_act": "gelu", 51 | "hidden_dropout": 0.1, 52 | "hidden_dropout_prob": 0.1, 53 | "hidden_size": 768, 54 | "initializer_range": 0.02, 55 | "intermediate_size": 3072, 56 | "layer_norm_eps": 1e-05, 57 | "layerdrop": 0.1, 58 | "mask_feature_length": 10, 59 | "mask_feature_prob": 0.0, 60 | "mask_time_length": 10, 61 | "mask_time_prob": 0.05, 62 | "model_type": "hubert", 63 | "num_attention_heads": 12, 64 | "num_conv_pos_embedding_groups": 16, 65 | "num_conv_pos_embeddings": 128, 66 | "num_feat_extract_layers": 7, 67 | "num_hidden_layers": 12, 68 | "pad_token_id": 0, 69 | "transformers_version": "4.10.0.dev0", 70 | "vocab_size": 32, 71 | 'return_dict': False, 72 | 
"use_cache": False, 73 | "tokenizer_class": "Wav2Vec2CTCTokenizer" 74 | } 75 | else: 76 | raise ValueError(f"No {model_name} config") 77 | return config -------------------------------------------------------------------------------- /examples/models/t5/run_t5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | T5ForConditionalGeneration, 27 | T5Config, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 2 42 | tp = 2 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | # Parse training and model arguments 47 | hfparser = HfArgumentParser(MerakArguments) 48 | parser = parse_option(hfparser) 49 | training_args, args = parser.parse_args_into_dataclasses() 50 | 51 | # Set random seed for reproducibility 52 | set_seed(training_args.seed) 53 | 54 | # Load model configuration from custom config file 55 | config_kwarg = load_config(args.model_name) 56 | config = T5Config(**config_kwarg) 57 | 58 | # Initialize language modeling model 59 | with init_empty_weights(): 60 | model = T5ForConditionalGeneration(config) 61 | 62 | # Create a fake dataset for training 63 | train_dataset = DynamicGenDataset( 64 | model.config, mode="condition", dataset_size=1e6 65 | ) 66 | 67 | # Initialize trainer with model, training arguments and dataset 68 | trainer = MerakTrainer( 69 | model=model, 70 | args=training_args, 71 | train_dataset=train_dataset, 72 | ) 73 | 74 | # Start training 75 | trainer.train() 76 | 77 | 78 | if __name__ == "__main__": 79 | main() -------------------------------------------------------------------------------- /Merak/core/fx/graph_shard/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck (eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | from typing import Any, Callable, Dict, List, Optional, Tuple 19 | 20 | 21 | def _snake_case(s: str): 22 | """ 23 | Transforms the given string ``s`` to a Python-style variable name 24 | 25 | Examples: 26 | ``mod.snake_case`` -> ``mod.snake_case`` 27 | ``mod.pascalCase``-> ``mod.pascal_case`` 28 | ``mod.ALL_CAPS`` -> ``mod.all_caps`` 29 | """ 30 | chars = [] 31 | prev_lower = False 32 | for c in s: 33 | if prev_lower and c.isupper(): 34 | chars.append('_') 35 | chars.append(c.lower()) 36 | prev_lower = c.islower() 37 | return ''.join(chars) 38 | 39 | 40 | def _get_count(param_count: Dict[str, int], node_name: str): 41 | """Identify different mutations of a given node name.""" 42 | 43 | if node_name in param_count: 44 | return param_count[node_name] 45 | elif node_name.replace(".", "_") in param_count: 46 | return param_count[node_name.replace(".", "_")] 47 | else: 48 | raise RuntimeError(f"Unable to find match between \ 49 | param {param_count} and node {node_name}") 50 | 51 | 52 | def _create_shard_to_param_count( 53 | param_count: Dict[str, int], 54 | node_name_to_shard_id: Dict[str, int] 55 | ) -> Dict[str, int]: 56 | """Utility to create a map from shard id to param count using 57 | existing state.""" 58 | 59 | shard_to_param_count: Dict[int, int] = {} 60 | for node in node_name_to_shard_id.keys(): 61 | try: 62 | count = _get_count(param_count, node) 63 | except RuntimeError: 64 | # print_rank_0(f"Unable to find match node {node_name}") 65 | continue 66 | if node_name_to_shard_id[node] in shard_to_param_count: 67 | shard_to_param_count[node_name_to_shard_id[node]] += count 68 | else: 69 | shard_to_param_count[node_name_to_shard_id[node]] = count 70 | return shard_to_param_count -------------------------------------------------------------------------------- /Merak/core/fx/tracer/_dynamo_trace.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
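# This module captures the model as a single FX graph via
# torch.export.export_for_training and then deduplicates repeated `get_attr`
# nodes so that each parameter is accessed exactly once by downstream
# graph-sharding code.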
15 | 16 | import torch 17 | 18 | from torch.fx._compatibility import compatibility 19 | from torch.fx.graph_module import GraphModule 20 | from typing import Dict, List, Union, Tuple 21 | 22 | def dynamo_trace( 23 | module: torch.nn.Module, 24 | dummy_inputs: Union[Dict[str, torch.Tensor], Tuple[torch.Tensor], List[torch.Tensor]] 25 | ) -> List[GraphModule]: 26 | if isinstance(dummy_inputs, dict): 27 | inputs = tuple(dummy_inputs.values()) 28 | elif isinstance(dummy_inputs, (tuple, list)): 29 | inputs = tuple(dummy_inputs) 30 | else: 31 | raise TypeError("Type of dummy inputs must be list, tuple or dict") 32 | 33 | try: 34 | ep = torch.export.export_for_training( 35 | module, 36 | inputs, 37 | ) 38 | except Exception as e: 39 | raise RuntimeError( 40 | "It seems that we cannot capture your model as a full graph. " 41 | "Typical reasons include graph breaks, data/shape-dependent " 42 | "control flow, or missing meta kernels for custom operators. " 43 | "You can use our manual pipeline interfaces, or try to fix the " 44 | "graph breaks, see https://pytorch.org/docs/stable/export.html" 45 | ) from e 46 | 47 | traced = ep.module() 48 | 49 | # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving 50 | # parameters relies on the invariant that parameter accesses happen once. This is not necessarily 51 | # the case (especially with custom tracers), so fix that up here. 52 | get_attr_nodes: Dict[str, torch.fx.Node] = {} 53 | for node in traced.graph.nodes: # type: ignore[union-attr] 54 | if node.op == "get_attr": 55 | get_attr_nodes.setdefault(node.target, node) 56 | 57 | if get_attr_nodes[node.target] != node: 58 | node.replace_all_uses_with(get_attr_nodes[node.target]) 59 | traced.graph.erase_node(node) # type: ignore[operator, union-attr] 60 | 61 | return traced -------------------------------------------------------------------------------- /examples/torch-models/models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | from .swin_mlp import SwinMLP 10 | 11 | 12 | def build_model(config): 13 | model_type = config.MODEL.TYPE 14 | if model_type == 'swin': 15 | model = SwinTransformer(img_size=config.DATA.IMG_SIZE, 16 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 17 | in_chans=config.MODEL.SWIN.IN_CHANS, 18 | num_classes=config.MODEL.NUM_CLASSES, 19 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 20 | depths=config.MODEL.SWIN.DEPTHS, 21 | num_heads=config.MODEL.SWIN.NUM_HEADS, 22 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 23 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 24 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 25 | qk_scale=config.MODEL.SWIN.QK_SCALE, 26 | drop_rate=config.MODEL.DROP_RATE, 27 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 28 | ape=config.MODEL.SWIN.APE, 29 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 30 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 31 | elif model_type == 'swin_mlp': 32 | model = SwinMLP(img_size=config.DATA.IMG_SIZE, 33 | patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE, 34 | in_chans=config.MODEL.SWIN_MLP.IN_CHANS, 35 | num_classes=config.MODEL.NUM_CLASSES, 36 | embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM, 37 | depths=config.MODEL.SWIN_MLP.DEPTHS, 38 | 
num_heads=config.MODEL.SWIN_MLP.NUM_HEADS, 39 | window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE, 40 | mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO, 41 | drop_rate=config.MODEL.DROP_RATE, 42 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 43 | ape=config.MODEL.SWIN_MLP.APE, 44 | patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM, 45 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 46 | else: 47 | raise NotImplementedError(f"Unkown model: {model_type}") 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /test/core/fx/graph_shard/test_split_graph.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # Test command: 19 | # yhrun -p 3090 -N 1 -n 1 torchrun --nproc-per-node=4 test_split_graph.py --output_dir ./ 20 | 21 | import Merak 22 | 23 | from transformers import ( 24 | GPT2LMHeadModel, 25 | GPT2Config, 26 | HfArgumentParser, 27 | ) 28 | 29 | from Merak import MerakArguments 30 | from Merak.merak_args import mergeargs, manual_set_args 31 | from Merak.core.fx.tracer import symbolic_trace 32 | from Merak.core.fx.graph_shard import shard_model_transformers 33 | 34 | def main(): 35 | # init dist 36 | pp = 4 37 | tp = 1 38 | dp = 1 39 | Merak.init(pp, tp, dp) 40 | 41 | # merge args 42 | hfparser = HfArgumentParser(MerakArguments) 43 | training_args = hfparser.parse_args_into_dataclasses()[0] 44 | 45 | config = GPT2Config() 46 | 47 | # create model 48 | model = GPT2LMHeadModel(config) 49 | 50 | # set args 51 | mergeargs(training_args, model.config) 52 | manual_set_args(training_args) 53 | 54 | # trace model 55 | traced, dummy_inputs = symbolic_trace( 56 | model, 57 | input_names = training_args.input_names, 58 | batch_size = training_args.per_device_train_batch_size, 59 | sequence_length = training_args.seq_length, 60 | ) 61 | 62 | # a experience users number threshold, a node has more user than this threshold 63 | # indicate the node is needed in multiple stages and could be transmitted between stages 64 | output_node_threshold = 5 65 | output_nodes_count = {} 66 | for node in traced.graph.nodes: 67 | if len(list(node.users)) > output_node_threshold: 68 | output_nodes_count[node.name] = len(list(node.users)) 69 | 70 | # split graph 71 | result, input_to_shard = shard_model_transformers( 72 | traced, model, training_args.shard_count, output_nodes_count 73 | ) 74 | 75 | print('==result==', result) 76 | print('==input_to_shard==', input_to_shard) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() -------------------------------------------------------------------------------- /examples/text_generation/run_gpt_text_generation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. 
All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # using our distributed trainer 19 | import Merak 20 | from Merak import MerakArguments, print_rank_0 21 | from Merak.inference import text_generation_pipeline 22 | from utils import create_tokenizer 23 | from config import load_config 24 | 25 | from transformers import ( 26 | set_seed, 27 | HfArgumentParser, 28 | GPT2LMHeadModel, 29 | GPT2Config, 30 | ) 31 | import torch 32 | import enum 33 | 34 | import random 35 | import numpy as np 36 | 37 | def parse_option(parser): 38 | # easy config modification 39 | parser.add_argument('--cache-dir', type=str, help='where to save cache') 40 | parser.add_argument('--model-name', type=str, help='gpt2') 41 | 42 | return parser 43 | 44 | class ReturnType(enum.Enum): 45 | TENSORS = 0 46 | NEW_TEXT = 1 47 | FULL_TEXT = 2 48 | 49 | def main(): 50 | # init dist 51 | pp = 4 52 | tp = 1 53 | dp = 1 54 | Merak.init(pp, tp, dp) 55 | 56 | torch.cuda.set_device("cuda:0") 57 | 58 | # merge args 59 | hfparser = HfArgumentParser(MerakArguments) 60 | parser = parse_option(hfparser) 61 | training_args, args = parser.parse_args_into_dataclasses() 62 | 63 | random.seed(training_args.seed) 64 | np.random.seed(training_args.seed) 65 | torch.manual_seed(training_args.seed) 66 | torch.cuda.manual_seed(training_args.seed) 67 | 68 | # Set seed before initializing model. 69 | set_seed(training_args.seed) 70 | 71 | config_kwarg = load_config(args.model_name) 72 | config = GPT2Config( 73 | **config_kwarg 74 | ) 75 | 76 | # create tokenizer 77 | tokenizer = create_tokenizer(args.cache_dir, "IDEA-CCNL/Wenzhong-GPT2-110M", config) 78 | 79 | # create model 80 | model = GPT2LMHeadModel(config) 81 | tokenizer.eod = model.config.eos_token_id 82 | 83 | # Initialize our Trainer) 84 | pipeline = text_generation_pipeline( 85 | model=model, 86 | args=training_args, 87 | tokenizer=tokenizer, 88 | ) 89 | 90 | # Training 91 | pipeline.generate_samples_interactive() 92 | 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /examples/torch-models/run_swin.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | from Merak import MerakArguments 20 | from Merak import MerakTrainer 21 | 22 | import os 23 | from transformers import ( 24 | HfArgumentParser 25 | ) 26 | 27 | from config import get_config 28 | from models import build_model 29 | from models.swin_transformer import window_reverse 30 | from timm.models.layers import DropPath 31 | from data import build_loader 32 | 33 | 34 | def parse_option(parser): 35 | group = parser.add_argument_group('Swin Transformer training and evaluation script') 36 | group.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 37 | group.add_argument('--data_path', type=str, default=None, help='path to data folder', ) 38 | 39 | return parser 40 | 41 | 42 | def main(config): 43 | dataset_train, dataset_val, _, _, _ = build_loader(config) 44 | 45 | model = build_model(config) 46 | 47 | leaf = ((window_reverse, DropPath)) 48 | 49 | trainer = MerakTrainer( 50 | model=model, 51 | args=training_args, 52 | train_dataset=dataset_train, 53 | eval_dataset=dataset_val, 54 | leaf_modules=leaf, 55 | ) 56 | trainer.train() 57 | 58 | 59 | if __name__ == '__main__': 60 | pp = 2 61 | tp = 1 62 | dp = 2 63 | Merak.init(pp, tp, dp) 64 | 65 | if tp > 1: 66 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 67 | col_para_list = ['qkv', 'fc1'] 68 | row_para_list = ['proj', 'fc2'] 69 | weight_change_list = [('relative_position_bias_table', 1)] 70 | tp_attr_list = ['num_heads'] 71 | 72 | # manully set tp attribute for swin model 73 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 74 | weight_change_list=weight_change_list, tp_attr_list=tp_attr_list) 75 | 76 | hfparser = HfArgumentParser(MerakArguments) 77 | parser = parse_option(hfparser) 78 | training_args, args = parser.parse_args_into_dataclasses() 79 | 80 | config = get_config(args) 81 | 82 | path = os.path.join("./config.json") 83 | with open(path, "w") as f: 84 | f.write(config.dump()) 85 | 86 | main(config) 87 | -------------------------------------------------------------------------------- /examples/models/bert/config.py: -------------------------------------------------------------------------------- 1 | def load_config(model_name): 2 | if model_name == "bert-large-uncased": 3 | config = { 4 | 'architectures': ['BertForMaskedLM'], 5 | 'attention_probs_dropout_prob': 0.1, 6 | 'gradient_checkpointing': False, 7 | 'hidden_act': 'gelu', 8 | 'hidden_dropout_prob': 0.1, 9 | 'hidden_size': 1024, 10 | 'initializer_range': 0.02, 11 | 'intermediate_size': 4096, 12 | 'layer_norm_eps': 1e-12, 13 | 'max_position_embeddings': 8192, 14 | 'model_type': 'bert', 15 | 'num_attention_heads': 16, 16 | 'num_hidden_layers': 4, 17 | 'pad_token_id': 0, 18 | 'position_embedding_type': 'absolute', 19 | 'type_vocab_size': 2, 20 | 'use_cache': False, 21 | 'vocab_size': 30524, 22 | 'return_dict': True, 23 | 'attn_implementation':"eager" 24 | } 25 | elif model_name == "bert-large": 26 | config = { 27 | 'architectures': ['BertForPreTraining'], 28 | 'attention_probs_dropout_prob': 0.1, 29 | 'gradient_checkpointing': False, 30 | 'hidden_act': 'gelu', 31 | 'hidden_dropout_prob': 0.1, 32 | 'hidden_size': 1024, 33 | 'initializer_range': 0.02, 34 | 'intermediate_size': 4096, 35 | 'layer_norm_eps': 1e-12, 36 | 'max_position_embeddings': 512, 37 | 'model_type': 'bert', 38 | 'num_attention_heads': 16, 39 | 'num_hidden_layers': 24, 40 | 
'pad_token_id': 0, 41 | 'position_embedding_type': 'absolute', 42 | 'type_vocab_size': 2, 43 | '_attn_implementation': 'eager', 44 | 'use_cache': True, 45 | 'return_dict': False, 46 | 'vocab_size': 30522, 47 | } 48 | elif model_name == "bert-base-uncased": 49 | config = { 50 | "architectures": [ 51 | "BertForMaskedLM" 52 | ], 53 | "attention_probs_dropout_prob": 0.1, 54 | "gradient_checkpointing": False, 55 | "hidden_act": "gelu", 56 | "hidden_dropout_prob": 0.1, 57 | "hidden_size": 768, 58 | "initializer_range": 0.02, 59 | "intermediate_size": 3072, 60 | "layer_norm_eps": 1e-12, 61 | "max_position_embeddings": 512, 62 | "model_type": "bert", 63 | "num_attention_heads": 12, 64 | "num_hidden_layers": 12, 65 | "pad_token_id": 0, 66 | "position_embedding_type": "absolute", 67 | "transformers_version": "4.6.0.dev0", 68 | "type_vocab_size": 2, 69 | "use_cache": False, 70 | "vocab_size": 30522 71 | } 72 | else: 73 | raise ValueError(f"No {model_name} config") 74 | return config -------------------------------------------------------------------------------- /examples/models/distilbert/run_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | set_seed, 26 | DistilBertForMaskedLM, 27 | DistilBertConfig, 28 | HfArgumentParser, 29 | ) 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 4 43 | tp = 1 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | hfparser = HfArgumentParser(MerakArguments) 48 | parser = parse_option(hfparser) 49 | training_args, args = parser.parse_args_into_dataclasses() 50 | 51 | if tp > 1: 52 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 53 | 54 | col_para_list = ['q_lin', 'k_lin', 'v_lin', 'lin1'] 55 | row_para_list = ['out_lin', 'lin2'] 56 | tp_attr_list = ['n_heads'] 57 | 58 | # manully set tp attribute for swin model 59 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 60 | tp_attr_list=tp_attr_list) 61 | 62 | # Set seed before initializing model. 
63 | set_seed(training_args.seed) 64 | 65 | # set model config 66 | config_kwarg = load_config(args.model_name) 67 | config = DistilBertConfig(**config_kwarg) 68 | 69 | with init_empty_weights(): 70 | model = DistilBertForMaskedLM(config) 71 | 72 | # Create a fake dataset for training 73 | train_dataset = DynamicGenDataset( 74 | model.config, mode="text_only", dataset_size=1e6 75 | ) 76 | 77 | # using our distributed trainer 78 | trainer = MerakTrainer( 79 | model=model, 80 | args=training_args, 81 | train_dataset=train_dataset, 82 | ) 83 | 84 | # Training 85 | trainer.train() 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /examples/models/opt/run_opt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | set_seed, 26 | OPTForCausalLM, 27 | OPTConfig, 28 | HfArgumentParser, 29 | ) 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 2 43 | tp = 2 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | # merge args 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 56 | row_para_list = ['out_proj', 'fc2'] 57 | tp_attr_list = ['num_heads'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 
64 | set_seed(training_args.seed) 65 | 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = OPTConfig(**config_kwarg) 70 | 71 | # meta init model 72 | with init_empty_weights(): 73 | model = OPTForCausalLM(config) 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="text_only", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | ) 86 | 87 | # Training 88 | trainer.train() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/models/m2m100/run_m2m100.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | M2M100ForConditionalGeneration, 26 | M2M100Config, 27 | HfArgumentParser, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 4 42 | tp = 1 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | # merge args 47 | hfparser = HfArgumentParser(MerakArguments) 48 | parser = parse_option(hfparser) 49 | training_args, args = parser.parse_args_into_dataclasses() 50 | 51 | if tp > 1: 52 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 53 | 54 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 55 | row_para_list = ['out_proj', 'fc2'] 56 | tp_attr_list = ['num_heads'] 57 | 58 | # manully set tp attribute for swin model 59 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 60 | tp_attr_list=tp_attr_list) 61 | 62 | # Set seed before initializing model. 
63 | set_seed(training_args.seed) 64 | 65 | # set model config 66 | config_kwarg = load_config(args.model_name) 67 | config = M2M100Config(**config_kwarg) 68 | 69 | # meta init model 70 | with init_empty_weights(): 71 | model = M2M100ForConditionalGeneration(config) 72 | 73 | # Create a fake dataset for training 74 | train_dataset = DynamicGenDataset( 75 | model.config, mode="condition", dataset_size=1e6 76 | ) 77 | 78 | # using our distributed trainer 79 | trainer = MerakTrainer( 80 | model=model, 81 | args=training_args, 82 | train_dataset=train_dataset, 83 | ) 84 | 85 | # Training 86 | trainer.train() 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /examples/models/electra/run_electra.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | ElectraForMaskedLM, 28 | ElectraConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 2 44 | tp = 2 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['query', 'key', 'value', 'intermediate.dense'] 56 | row_para_list = ['output.dense'] 57 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 
64 | set_seed(training_args.seed) 65 | 66 | # set model config 67 | config_kwarg = load_config(args.model_name) 68 | config = ElectraConfig(**config_kwarg) 69 | 70 | with init_empty_weights(): 71 | model = ElectraForMaskedLM(config) 72 | 73 | # Create a fake dataset for training 74 | train_dataset = DynamicGenDataset( 75 | model.config, mode="text_only", dataset_size=1e6 76 | ) 77 | 78 | # using our distributed trainer 79 | trainer = MerakTrainer( 80 | model=model, 81 | args=training_args, 82 | train_dataset=train_dataset, 83 | ) 84 | 85 | # Training 86 | trainer.train() 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /examples/models/marianmt/run_marian.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | DataCollatorForLanguageModeling, 25 | set_seed, 26 | MarianMTModel, 27 | MarianConfig, 28 | HfArgumentParser, 29 | ) 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 4 43 | tp = 1 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | # merge args 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 56 | row_para_list = ['out_proj', 'fc2'] 57 | tp_attr_list = ['num_heads'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 
64 | set_seed(training_args.seed) 65 | 66 | # set model config 67 | config_kwarg = load_config(args.model_name) 68 | config = MarianConfig(**config_kwarg) 69 | 70 | 71 | # meta init model 72 | with init_empty_weights(): 73 | model = MarianMTModel(config) 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="condition", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | ) 86 | 87 | # Training 88 | trainer.train() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/models/nezha/run_nezha.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | NezhaForMaskedLM, 28 | NezhaConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 2 44 | tp = 2 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | # merge args 49 | hfparser = HfArgumentParser(MerakArguments) 50 | parser = parse_option(hfparser) 51 | training_args, args = parser.parse_args_into_dataclasses() 52 | 53 | if tp > 1: 54 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 55 | 56 | col_para_list = ['query', 'key', 'value', 'intermediate.dense'] 57 | row_para_list = ['output.dense'] 58 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 59 | 60 | # manully set tp attribute for swin model 61 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 62 | tp_attr_list=tp_attr_list) 63 | 64 | # Set seed before initializing model. 
65 | set_seed(training_args.seed) 66 | 67 | 68 | # set model config 69 | config_kwarg = load_config(args.model_name) 70 | config = NezhaConfig(**config_kwarg) 71 | 72 | # meta init model 73 | with init_empty_weights(): 74 | model = NezhaForMaskedLM(config) 75 | 76 | # Create a fake dataset for training 77 | train_dataset = DynamicGenDataset( 78 | model.config, mode="text_only", dataset_size=1e6 79 | ) 80 | 81 | # using our distributed trainer 82 | trainer = MerakTrainer( 83 | model=model, 84 | args=training_args, 85 | train_dataset=train_dataset, 86 | ) 87 | 88 | # Training 89 | trainer.train() 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /examples/models/albert/run_albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | from config import load_config 22 | from transformers import ( 23 | set_seed, 24 | HfArgumentParser, 25 | AlbertForMaskedLM, 26 | AlbertConfig, 27 | ) 28 | from Merak.utils.datasets import DynamicGenDataset 29 | 30 | 31 | # Add custom command-line arguments 32 | def parse_option(parser): 33 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. 
gpt2)') 34 | return parser 35 | 36 | 37 | def main(): 38 | # Initialize Merak distributed training environment 39 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 40 | pp = 2 41 | tp = 2 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | if tp > 1: 46 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 47 | 48 | col_para_list = ['query', 'key', 'value'] 49 | row_para_list = ['attention.dense'] 50 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 51 | 52 | # manully set tp attribute for swin model 53 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 54 | tp_attr_list=tp_attr_list) 55 | 56 | 57 | # Parse training and model arguments 58 | hfparser = HfArgumentParser(MerakArguments) 59 | parser = parse_option(hfparser) 60 | training_args, args = parser.parse_args_into_dataclasses() 61 | 62 | # Set random seed for reproducibility 63 | set_seed(training_args.seed) 64 | 65 | # Load model configuration from custom config file 66 | config_kwarg = load_config(args.model_name) 67 | config = AlbertConfig(**config_kwarg) 68 | 69 | # Initialize language modeling model 70 | with init_empty_weights(): 71 | model = AlbertForMaskedLM(config) 72 | 73 | # Create a fake dataset for training 74 | train_dataset = DynamicGenDataset( 75 | model.config, mode="text_only", dataset_size=1e6 76 | ) 77 | 78 | # Initialize trainer with model, training arguments and dataset 79 | trainer = MerakTrainer( 80 | model=model, 81 | args=training_args, 82 | train_dataset=train_dataset, 83 | ) 84 | 85 | # Start training 86 | trainer.train() 87 | 88 | 89 | if __name__ == "__main__": 90 | main() -------------------------------------------------------------------------------- /examples/models/pegasus/run_pegasus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | PegasusForConditionalGeneration, 28 | PegasusConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. 
gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 4 44 | tp = 1 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | # merge args 49 | hfparser = HfArgumentParser(MerakArguments) 50 | parser = parse_option(hfparser) 51 | training_args, args = parser.parse_args_into_dataclasses() 52 | 53 | if tp > 1: 54 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 55 | 56 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 57 | row_para_list = ['out_proj', 'fc2'] 58 | tp_attr_list = ['num_heads'] 59 | 60 | # manully set tp attribute for swin model 61 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 62 | tp_attr_list=tp_attr_list) 63 | 64 | # Set seed before initializing model. 65 | set_seed(training_args.seed) 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = PegasusConfig(**config_kwarg) 70 | # config._attn_implementation = 'eager' 71 | 72 | # meta init model 73 | with init_empty_weights(): 74 | model = PegasusForConditionalGeneration(config) 75 | 76 | # Create a fake dataset for training 77 | train_dataset = DynamicGenDataset( 78 | model.config, mode="condition", dataset_size=1e6 79 | ) 80 | 81 | # using our distributed trainer 82 | trainer = MerakTrainer( 83 | model=model, 84 | args=training_args, 85 | train_dataset=train_dataset, 86 | ) 87 | 88 | # Training 89 | trainer.train() 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /test/core/checkpoint/test_safetensor_ckpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | # Test command: 19 | # yhrun -N 1 -n 1 -p 3090 torchrun --nproc-per-node=4 test_safetensor_ckpt.py --output_dir './output' --logging_steps 1 --wall_clock_breakdown true --split_method 'layer_split' 20 | 21 | # using our distributed trainer 22 | import os 23 | import Merak 24 | import torch 25 | 26 | from Merak import MerakTrainer, MerakArguments, init_empty_weights 27 | from Merak.core.checkpoint.safetensor_plugin import save_3d_parallel_model 28 | 29 | from transformers import ( 30 | default_data_collator, 31 | set_seed, 32 | HfArgumentParser, 33 | BertForMaskedLM, 34 | BertConfig, 35 | ) 36 | 37 | def parse_option(parser): 38 | # easy config modification 39 | parser.add_argument('--cache-dir', type=str, help='where to save cache') 40 | parser.add_argument('--model-name', type=str, help='gpt2') 41 | return parser 42 | 43 | 44 | def main(): 45 | 46 | # init dist 47 | pp = 1 48 | tp = 4 49 | dp = 1 50 | Merak.init(pp, tp, dp) 51 | 52 | # merge args 53 | hfparser = HfArgumentParser(MerakArguments) 54 | parser = parse_option(hfparser) 55 | training_args, args = parser.parse_args_into_dataclasses() 56 | 57 | # Set seed before initializing model. 58 | set_seed(training_args.seed) 59 | 60 | # load data 61 | config = BertConfig(num_hidden_layers=12, 62 | max_position_embeddings=512, 63 | num_attention_heads=12, 64 | vocab_size=50344, 65 | reorder_and_upcast_attn=False, use_cache=False, _attn_implementation="eager") 66 | 67 | # create model 68 | with init_empty_weights(): 69 | model = BertForMaskedLM(config) 70 | 71 | # Preprocessing the datasets. 72 | train_dataset = {k: v for k, v in enumerate(range(1000))} 73 | eval_dataset = {k: v for k, v in enumerate(range(1000))} 74 | 75 | class TestTrainer(MerakTrainer): 76 | 77 | def train(self): 78 | self.init_engine() 79 | save_3d_parallel_model(self.engine.module, self.args, 0) 80 | exit() 81 | 82 | # Initialize our Trainer) 83 | trainer = TestTrainer( 84 | model=model, 85 | args=training_args, 86 | train_dataset=train_dataset, 87 | eval_dataset=eval_dataset 88 | ) 89 | 90 | # Training 91 | trainer.train() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /examples/models/convnext/run_convnext.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | ConvNextForImageClassification, 27 | ConvNextConfig, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 2 42 | tp = 2 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | if tp > 1: 47 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 48 | 49 | col_para_list = ['query', 'key', 'value'] 50 | row_para_list = ['attention.dense'] 51 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 52 | 53 | # manully set tp attribute for swin model 54 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 55 | tp_attr_list=tp_attr_list) 56 | 57 | 58 | # Parse training and model arguments 59 | hfparser = HfArgumentParser(MerakArguments) 60 | parser = parse_option(hfparser) 61 | training_args, args = parser.parse_args_into_dataclasses() 62 | 63 | # Set random seed for reproducibility 64 | set_seed(training_args.seed) 65 | 66 | # Load model configuration from custom config file 67 | config_kwarg = load_config(args.model_name) 68 | config = ConvNextConfig(**config_kwarg) 69 | 70 | # Initialize model 71 | with init_empty_weights(): 72 | model = ConvNextForImageClassification(config) 73 | 74 | # Create a fake dataset for training 75 | train_dataset = DynamicGenDataset( 76 | model.config, mode="vision_only", dataset_size=1e6 77 | ) 78 | 79 | # Initialize trainer with model, training arguments and dataset 80 | trainer = MerakTrainer( 81 | model=model, 82 | args=training_args, 83 | train_dataset=train_dataset, 84 | ) 85 | 86 | # Start training 87 | trainer.train() 88 | 89 | 90 | if __name__ == "__main__": 91 | main() -------------------------------------------------------------------------------- /examples/models/layoutlm/run_layoutlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | LayoutLMForMaskedLM, 28 | LayoutLMConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 2 44 | tp = 2 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | # merge args 49 | hfparser = HfArgumentParser(MerakArguments) 50 | parser = parse_option(hfparser) 51 | training_args, args = parser.parse_args_into_dataclasses() 52 | 53 | if tp > 1: 54 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 55 | 56 | col_para_list = ['query', 'key', 'value', 'intermediate.dense'] 57 | row_para_list = ['output.dense'] 58 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 59 | 60 | # manully set tp attribute for swin model 61 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 62 | tp_attr_list=tp_attr_list) 63 | 64 | # Set seed before initializing model. 65 | set_seed(training_args.seed) 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = LayoutLMConfig(**config_kwarg) 70 | 71 | # meta init model 72 | with init_empty_weights(): 73 | model = LayoutLMForMaskedLM(config) 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="text_only", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | # eval_dataset=eval_dataset, 86 | ) 87 | 88 | # Training 89 | trainer.train() 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /examples/models/dinov2/run_dinov2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | from config import load_config 22 | 23 | from transformers import ( 24 | set_seed, 25 | Dinov2ForImageClassification, 26 | Dinov2Config, 27 | HfArgumentParser, 28 | ) 29 | import torch 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 2 42 | tp = 2 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | hfparser = HfArgumentParser(MerakArguments) 47 | parser = parse_option(hfparser) 48 | training_args, args = parser.parse_args_into_dataclasses() 49 | 50 | if tp > 1: 51 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 52 | 53 | col_para_list = ['query', 'key', 'value', 'fc1'] 54 | row_para_list = ['output.dense', 'fc2'] 55 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 56 | 57 | # manully set tp attribute for swin model 58 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 59 | tp_attr_list=tp_attr_list) 60 | 61 | # Set seed before initializing model. 62 | set_seed(training_args.seed) 63 | 64 | # set model config 65 | config_kwarg = load_config(args.model_name) 66 | config = Dinov2Config(**config_kwarg) 67 | 68 | # Initialize model 69 | with init_empty_weights(): 70 | model = Dinov2ForImageClassification(config) 71 | 72 | training_args.seq_length = model.config.image_size 73 | 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="vision_only", dataset_size=1e6, image_size=model.config.image_size 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | ) 86 | 87 | # Training 88 | trainer.train() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/models/trocr/run_trocr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from transformers import ( 23 | DataCollatorForLanguageModeling, 24 | set_seed, 25 | TrOCRForCausalLM, 26 | TrOCRConfig, 27 | HfArgumentParser, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 4 42 | tp = 1 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | # merge args 47 | hfparser = HfArgumentParser(MerakArguments) 48 | parser = parse_option(hfparser) 49 | training_args, args = parser.parse_args_into_dataclasses() 50 | 51 | if tp > 1: 52 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 53 | 54 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 55 | row_para_list = ['out_proj', 'fc2'] 56 | tp_attr_list = ['num_heads'] 57 | 58 | # manully set tp attribute for swin model 59 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 60 | tp_attr_list=tp_attr_list) 61 | 62 | # Set seed before initializing model. 63 | set_seed(training_args.seed) 64 | 65 | # set model config 66 | # config_kwarg = load_config(args.model_name) 67 | config_kwarg = { 68 | 'return_dict': False, 69 | 'use_cache': False 70 | } 71 | config = TrOCRConfig(**config_kwarg) 72 | config._attn_implementation = 'eager' 73 | 74 | # meta init 75 | with init_empty_weights(): 76 | model = TrOCRForCausalLM(config) 77 | 78 | # Create a fake dataset for training 79 | train_dataset = DynamicGenDataset( 80 | model.config, mode="text_only", dataset_size=1e6 81 | ) 82 | 83 | # using our distributed trainer 84 | trainer = MerakTrainer( 85 | model=model, 86 | args=training_args, 87 | train_dataset=train_dataset, 88 | ) 89 | 90 | # Training 91 | trainer.train() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /examples/models/gptj/run_gptj.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | import torch 20 | 21 | 22 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 23 | from config import load_config 24 | 25 | from transformers import ( 26 | set_seed, 27 | GPTJForCausalLM, 28 | GPTJConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 2 44 | tp = 2 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | # Set seed before initializing model. 53 | set_seed(training_args.seed) 54 | 55 | if tp > 1: 56 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 57 | 58 | # input_output_mapping=[(1, 3, 'col'), (1, 1,'row'), (1, 4, 'col'), (4, 1, 'row')] 59 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc_in'] 60 | row_para_list = ['out_proj', 'fc_out'] 61 | tp_attr_list=['num_attention_heads'] 62 | # manually set tp attributes for the GPT-J model 63 | # set_tp_layer_lists(input_output_mapping=input_output_mapping, tp_attr_list=tp_attr_list) 64 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 65 | tp_attr_list=tp_attr_list) 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = GPTJConfig(**config_kwarg) 70 | 71 | with init_empty_weights(): 72 | model = GPTJForCausalLM(config) 73 | 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="text_only", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | ) 86 | 87 | # Training 88 | trainer.train() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/models/lxmert/run_lxmert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | 23 | from config import load_config 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | LxmertForQuestionAnswering, 28 | LxmertConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 4 44 | tp = 1 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | # merege args 49 | hfparser = HfArgumentParser(MerakArguments) 50 | parser = parse_option(hfparser) 51 | training_args, args = parser.parse_args_into_dataclasses() 52 | 53 | if tp > 1: 54 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 55 | 56 | col_para_list = ['query', 'key', 'value', 'intermediate.dense', 'lang_inter.dense', 'visn_inter.dense'] 57 | row_para_list = ['output.dense', 'lang_output.dense', 'visn_output.dense'] 58 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 59 | 60 | # manully set tp attribute for swin model 61 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 62 | tp_attr_list=tp_attr_list) 63 | 64 | # Set seed before initializing model. 65 | set_seed(training_args.seed) 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = LxmertConfig(**config_kwarg) 70 | 71 | # meta init model 72 | with init_empty_weights(): 73 | model = LxmertForQuestionAnswering(config) 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="for_qa", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | ) 86 | 87 | # Training 88 | trainer.train() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /Merak/utils/device_to_meta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | # Parts of the code here are adapted from https://github.com/huggingface/accelerate/blob/v0.34.2/src/accelerate/big_modeling.py 19 | 20 | from contextlib import contextmanager 21 | 22 | import torch 23 | import torch.nn as nn 24 | from transformers import Conv1D 25 | 26 | @contextmanager 27 | def init_empty_weights(): 28 | """ 29 | A context manager under which models are initialized with all parameters on the meta device, therefore creating an 30 | empty model. Useful when just initializing the model would blow the available RAM. 31 | 32 | Note: 33 | Unlike the accelerate original, this version takes no arguments; buffers are left untouched and only 34 | parameters of supported module types are moved to the meta device. 35 | 36 | Example: 37 | 38 | ```python 39 | import torch.nn as nn 40 | from Merak import init_empty_weights 41 | 42 | # Initialize a model with 100 billion parameters in no time and without using any RAM. 43 | with init_empty_weights(): 44 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 45 | ``` 46 | 47 | 48 | 49 | Any model created under this context manager has no weights. As such you can't do something like 50 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 51 | Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not 52 | called. 53 | 54 | 55 | """ 56 | with init_on_device(torch.device("meta")) as f: 57 | yield f 58 | 59 | @contextmanager 60 | def init_on_device(device: torch.device): 61 | 62 | old_register_parameter = nn.Module.register_parameter 63 | 64 | def register_empty_parameter( 65 | module: nn.Module, 66 | name: str, 67 | param: torch.Tensor 68 | ): 69 | old_register_parameter(module, name, param) 70 | if isinstance(module, (nn.Embedding, nn.Linear, Conv1D)): 71 | if param is not None: 72 | param_cls = type(module._parameters[name]) 73 | kwargs = module._parameters[name].__dict__ 74 | kwargs["requires_grad"] = param.requires_grad 75 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 76 | 77 | try: 78 | nn.Module.register_parameter = register_empty_parameter 79 | yield 80 | finally: 81 | nn.Module.register_parameter = old_register_parameter -------------------------------------------------------------------------------- /examples/models/mobilebert/run_mobilebert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
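A minimal sketch of the meta-device initialization implemented in Merak/utils/device_to_meta.py above (assuming `init_empty_weights` is imported from the top-level `Merak` package, as the example scripts do): only parameters registered by `nn.Embedding`, `nn.Linear` and `Conv1D` modules are intercepted, so other layers keep ordinary CPU tensors.

import torch.nn as nn
from Merak import init_empty_weights

with init_empty_weights():
    proj = nn.Linear(4096, 4096)   # intercepted: weight and bias are re-registered on the meta device
    norm = nn.LayerNorm(4096)      # not intercepted: parameters stay on the CPU

print(proj.weight.device)  # meta
print(norm.weight.device)  # cpu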
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | 23 | from transformers import ( 24 | DataCollatorForLanguageModeling, 25 | set_seed, 26 | MobileBertForMaskedLM, 27 | MobileBertConfig, 28 | HfArgumentParser, 29 | ) 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 2 43 | tp = 2 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | # merge args 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['query', 'key', 'value', 'intermediate.dense'] 56 | row_para_list = ['output.dense'] 57 | tp_attr_list = ['num_attention_heads', 'all_head_size'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 64 | set_seed(training_args.seed) 65 | 66 | # set model config 67 | # config_kwarg = load_config(args.model_name) 68 | config_kwarg = { 69 | 'num_hidden_layers': 8, 70 | 'return_dict': False, 71 | 'use_cache': False 72 | } 73 | config = MobileBertConfig(**config_kwarg) 74 | config._attn_implementation = 'eager' 75 | 76 | # meta init model 77 | with init_empty_weights(): 78 | model = MobileBertForMaskedLM(config) 79 | 80 | # Create a fake dataset for training 81 | train_dataset = DynamicGenDataset( 82 | model.config, mode="text_only", dataset_size=1e6 83 | ) 84 | 85 | # using our distributed trainer 86 | trainer = MerakTrainer( 87 | model=model, 88 | args=training_args, 89 | train_dataset=train_dataset, 90 | ) 91 | 92 | # Training 93 | trainer.train() 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /examples/models/speech2text/run_speech_to_text.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | 24 | from transformers import ( 25 | DataCollatorForLanguageModeling, 26 | set_seed, 27 | Speech2TextForConditionalGeneration, 28 | Speech2TextConfig, 29 | HfArgumentParser, 30 | ) 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | 34 | # Add custom command-line arguments 35 | def parse_option(parser): 36 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 37 | return parser 38 | 39 | 40 | def main(): 41 | # Initialize Merak distributed training environment 42 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 43 | pp = 4 44 | tp = 1 45 | dp = 1 46 | Merak.init(pp, tp, dp) 47 | 48 | # merge args 49 | hfparser = HfArgumentParser(MerakArguments) 50 | parser = parse_option(hfparser) 51 | training_args, args = parser.parse_args_into_dataclasses() 52 | 53 | if tp > 1: 54 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 55 | 56 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 57 | row_para_list = ['out_proj', 'fc2'] 58 | tp_attr_list = ['num_heads'] 59 | 60 | # manully set tp attribute for swin model 61 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 62 | tp_attr_list=tp_attr_list) 63 | 64 | # Set seed before initializing model. 65 | set_seed(training_args.seed) 66 | 67 | # set model config 68 | config_kwarg = load_config(args.model_name) 69 | config = Speech2TextConfig(**config_kwarg) 70 | 71 | with init_empty_weights(): 72 | model = Speech2TextForConditionalGeneration(config) 73 | 74 | # Create a fake dataset for training 75 | train_dataset = DynamicGenDataset( 76 | model.config, mode="speech2text", dataset_size=1e6 77 | ) 78 | 79 | # using our distributed trainer 80 | trainer = MerakTrainer( 81 | model=model, 82 | args=training_args, 83 | train_dataset=train_dataset, 84 | # eval_dataset=eval_dataset, 85 | # tokenizer=tokenizer, 86 | # Data collator will default to DataCollatorWithPadding, so we change it. 87 | # data_collator=data_collator, 88 | ) 89 | 90 | # Training 91 | trainer.train() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /examples/models/speech2text2/run_speech_to_text2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | 22 | from config import load_config 23 | from transformers import ( 24 | DataCollatorForLanguageModeling, 25 | set_seed, 26 | HfArgumentParser, 27 | Speech2Text2ForCausalLM, 28 | Speech2Text2Config, 29 | ) 30 | from Merak.utils.datasets import DynamicGenDataset 31 | 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 4 43 | tp = 1 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | # merge args 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 56 | row_para_list = ['out_proj', 'fc2'] 57 | tp_attr_list = ['num_heads'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 64 | set_seed(training_args.seed) 65 | 66 | # set model config 67 | config_kwarg = load_config(args.model_name) 68 | 69 | # meta init 70 | with init_empty_weights(): 71 | model = Speech2Text2ForCausalLM(Speech2Text2Config(**config_kwarg['decoder'])) 72 | 73 | training_args.num_layers = config_kwarg['decoder']['encoder_layers'] 74 | 75 | # Create a fake dataset for training 76 | train_dataset = DynamicGenDataset( 77 | model.config, mode="condition", dataset_size=1e6 78 | ) 79 | 80 | # using our distributed trainer 81 | trainer = MerakTrainer( 82 | model=model, 83 | args=training_args, 84 | train_dataset=train_dataset, 85 | # eval_dataset=eval_dataset, 86 | # tokenizer=tokenizer, 87 | # Data collator will default to DataCollatorWithPadding, so we change it. 88 | # data_collator=data_collator, 89 | ) 90 | 91 | # Training 92 | trainer.train() 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /examples/models/llama/run_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import os 19 | import torch 20 | import Merak 21 | 22 | from Merak import MerakArguments, MerakTrainer, print_rank_0, init_empty_weights 23 | from Merak.core import mpu 24 | from config import load_config 25 | 26 | from transformers.models.llama.modeling_llama import LlamaAttention 27 | from transformers import ( 28 | default_data_collator, 29 | set_seed, 30 | HfArgumentParser, 31 | LlamaForCausalLM, 32 | LlamaConfig 33 | ) 34 | from Merak.utils.datasets import DynamicGenDataset 35 | 36 | 37 | # Add custom command-line arguments 38 | def parse_option(parser): 39 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 40 | return parser 41 | 42 | 43 | def main(): 44 | # Initialize Merak distributed training environment 45 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 46 | pp = 4 47 | tp = 2 48 | dp = 1 49 | Merak.init(pp, tp, dp) 50 | torch.backends.cuda.matmul.allow_tf32 = True 51 | 52 | # merge args 53 | hfparser = HfArgumentParser(MerakArguments) 54 | parser = parse_option(hfparser) 55 | training_args, args = parser.parse_args_into_dataclasses() 56 | 57 | # Set seed before initializing model. 58 | set_seed(training_args.seed) 59 | 60 | if tp > 1: 61 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 62 | 63 | col_para_list = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 64 | 'gate_proj', 'up_proj'] 65 | row_para_list = ['self_attn.o_proj', 'down_proj'] 66 | tp_attr_list=['num_heads', 'num_key_value_heads'] 67 | # manully set tp attribute 68 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, tp_attr_list=tp_attr_list) 69 | 70 | # set model config 71 | config_kwarg = load_config(args.model_name) 72 | config = LlamaConfig( 73 | **config_kwarg 74 | ) 75 | config._attn_implementation="eager" 76 | 77 | # meta init model 78 | with init_empty_weights(): 79 | model = LlamaForCausalLM(config) 80 | model.generation_config.pad_token_id=-100 81 | 82 | # Create a fake dataset for training 83 | train_dataset = DynamicGenDataset( 84 | model.config, mode="text_only", dataset_size=1e6 85 | ) 86 | 87 | # Initialize our Trainer 88 | trainer = MerakTrainer( 89 | model=model, 90 | args=training_args, 91 | train_dataset=train_dataset, 92 | leaf_modules=(LlamaAttention,) 93 | ) 94 | 95 | trainer.train() 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /Merak/initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import torch.distributed as dist 20 | try: 21 | import torch_ft 22 | except: 23 | pass 24 | 25 | topo = None 26 | communication_grid = None 27 | 28 | def print_rank_0(message: str): 29 | """If distributed is initialized print only on rank 0.""" 30 | if dist.is_initialized(): 31 | if dist.get_rank() == 0: 32 | print(message, flush=True) 33 | else: 34 | print(message, flush=True) 35 | 36 | def init(pp: int, tp: int, dp: int, backend: str = 'nccl'): 37 | """ 38 | Initialize the distributed communication groups, including data parallel, 39 | tensor model parallel and pipeline model parallel. Each parallel degree 40 | has its own communication group; the rank or size can be queried through the mpu API. 41 | 42 | Parameters: 43 | - pp (int) -- Parallel degree of pipeline model parallelism. 44 | - tp (int) -- Parallel degree of tensor model parallelism. 45 | - dp (int) -- Parallel degree of data parallelism. 46 | """ 47 | compile_config = torch.__config__.show().split(", ") 48 | if 'USE_NCCL=1' in compile_config or 'USE_NCCL=ON' in compile_config: 49 | backend = 'nccl' 50 | elif 'USE_MPI=1' in compile_config or 'USE_MPI=ON' in compile_config: 51 | backend = 'mpi' 52 | else: 53 | raise RuntimeError("Distributed package doesn't have NCCL/MPI built in") 54 | if not dist.is_initialized(): 55 | dist.init_process_group(backend) 56 | # we init topology and communication grid here 57 | from .core.mpu.topology import ( 58 | PipeModelDataParallelTopology, 59 | PipelineParallelGrid) 60 | global topo 61 | topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=tp, num_dp=dp) 62 | global communication_grid 63 | communication_grid = PipelineParallelGrid( 64 | topo, 65 | dist.new_group(ranks=range(dist.get_world_size())) 66 | ) 67 | 68 | 69 | # set mpu for transformers model 70 | from .core.mpu.initialize import ( 71 | set_data_parallel_group, 72 | set_model_parallel_group, 73 | set_pipe_parallel_group) 74 | 75 | set_data_parallel_group(communication_grid.get_data_parallel_group()) 76 | set_model_parallel_group(communication_grid.get_slice_parallel_group()) 77 | set_pipe_parallel_group(communication_grid.get_pipe_parallel_group()) 78 | 79 | print_rank_0(f'Pipeline Model Parallel Size: {pp} \ 80 | \nTensor Model Parallel Size: {tp} \ 81 | \nData Parallel Size: {dp} \n') 82 | 83 | def get_topo(): 84 | global topo 85 | return topo 86 | 87 | def get_grid(): 88 | global communication_grid 89 | return communication_grid -------------------------------------------------------------------------------- /examples/image-classification/run_vit.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
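A minimal usage sketch for the `init` API defined in Merak/initialize.py above. It assumes a hypothetical 8-process launch (e.g. `torchrun --nproc-per-node=8 train.py`) so that pp * tp * dp matches the world size; the helpers used are the module-level functions shown in the file.

import torch.distributed as dist
import Merak
from Merak.initialize import get_grid

# 2-stage pipeline x 2-way tensor parallel x 2-way data parallel
Merak.init(pp=2, tp=2, dp=2)

grid = get_grid()                              # PipelineParallelGrid built inside init()
dp_group = grid.get_data_parallel_group()      # communication group of this rank's data-parallel degree
print(dist.get_rank(), dist.get_world_size(dp_group))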
17 | 
18 | # using our distributed trainer
19 | import Merak
20 | from Merak import MerakArguments, MerakTrainer, print_rank_0
21 | from utils import collate_fn, prepare_dataset, compute_metrics
22 | 
23 | from transformers import (
24 |     HfArgumentParser,
25 |     ViTForImageClassification,
26 |     ViTConfig
27 | )
28 | 
29 | import torch
30 | from transformers import ViTFeatureExtractor
31 | 
32 | 
33 | def parse_option(parser):
34 | 
35 |     # easy config modification
36 |     parser.add_argument('--data-files', type=str, help='path to dataset')
37 |     parser.add_argument('--cache-dir', type=str, help='where to save cache')
38 | 
39 |     return parser
40 | 
41 | def main():
42 |     # init dist
43 |     pp = 2
44 |     tp = 1
45 |     dp = 2
46 |     Merak.init(pp, tp, dp)
47 | 
48 |     # merge args
49 |     hfparser = HfArgumentParser(MerakArguments)
50 |     parser = parse_option(hfparser)
51 |     training_args, args = parser.parse_args_into_dataclasses()
52 | 
53 |     config = ViTConfig(num_labels=1000, return_dict=False)
54 |     model = ViTForImageClassification(config)
55 | 
56 |     ds = prepare_dataset(args.data_files, args.cache_dir)
57 | 
58 |     class VitTrainer(MerakTrainer):
59 |         def create_dataloader(self):
60 |             self.train_dataloader = torch.utils.data.DataLoader(
61 |                 self.train_dataset,
62 |                 batch_sampler=self.get_train_sampler(),
63 |                 num_workers=self.args.dataloader_num_workers,
64 |                 pin_memory=self.args.dataloader_pin_memory,
65 |                 collate_fn=collate_fn,
66 |             )
67 | 
68 |         def prepare_data(self, data):
69 |             if not isinstance(data, (tuple, list)):
70 |                 if isinstance(data, dict):
71 |                     inputs_list = []
72 |                     for key, val in self.input_to_stage_dic.items():
73 |                         for i in val:
74 |                             inputs_list.append(data.pop(i))
75 |                     inputs_list += list(data.values())
76 |                     return tuple(inputs_list)
77 |                 else:
78 |                     raise NotImplementedError('only data in a tuple, list or dict is supported')
79 |             else:
80 |                 return data
81 | 
82 |     # Initialize our trainer
83 |     trainer = VitTrainer(
84 |         model=model,
85 |         args=training_args,
86 |         train_dataset=ds["train"],
87 |         eval_dataset=ds["validation"],
88 |         # Data collator will default to DataCollatorWithPadding, so we change it.
89 |         # compute_metrics=compute_metrics,
90 |     )
91 | 
92 |     # Training
93 |     trainer.train()
94 | 
95 | if __name__ == "__main__":
96 |     main()
97 | 
--------------------------------------------------------------------------------
/examples/models/plbart/run_plbart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved.
3 | #
4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com)
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
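# Two recurring patterns in this and the sibling example scripts, summarized
# here as a hedged reading of how they are used (see the code below for the
# authoritative usage):
#   * init_empty_weights(): the model is first constructed with empty (meta)
#     weights, so no real parameter memory is allocated up front.
#   * training_args.custom_split_points: a list of module names (embeddings,
#     decoder layers, lm_head) marking where the model is cut into pipeline stages.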
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from transformers import ( 23 | set_seed, 24 | PLBartForCausalLM, 25 | PLBartConfig, 26 | HfArgumentParser, 27 | ) 28 | from Merak.utils.datasets import DynamicGenDataset 29 | 30 | 31 | # Add custom command-line arguments 32 | def parse_option(parser): 33 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 34 | return parser 35 | 36 | 37 | def main(): 38 | # Initialize Merak distributed training environment 39 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | # merge args 46 | hfparser = HfArgumentParser(MerakArguments) 47 | parser = parse_option(hfparser) 48 | training_args, args = parser.parse_args_into_dataclasses() 49 | 50 | if tp > 1: 51 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 52 | 53 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 54 | row_para_list = ['out_proj', 'fc2'] 55 | tp_attr_list = ['num_heads'] 56 | 57 | # manully set tp attribute for swin model 58 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 59 | tp_attr_list=tp_attr_list) 60 | 61 | # Set seed before initializing model. 62 | set_seed(training_args.seed) 63 | 64 | # set model config 65 | # config_kwarg = load_config(args.model_name) 66 | config_kwarg = { 67 | 'return_dict': False, 68 | 'use_cache': False 69 | } 70 | config = PLBartConfig(**config_kwarg) 71 | 72 | # meta init 73 | with init_empty_weights(): 74 | model = PLBartForCausalLM(config) 75 | 76 | split_point = [] 77 | split_point.append('model.decoder.embed_positions') 78 | for i in range(model.config.decoder_layers): 79 | split_point.append(f'model.decoder.layers.{i}') 80 | split_point.append('lm_head') 81 | training_args.custom_split_points = split_point 82 | 83 | # Create a fake dataset for training 84 | train_dataset = DynamicGenDataset( 85 | model.config, mode="text_only", dataset_size=1e6 86 | ) 87 | 88 | # using our distributed trainer 89 | trainer = MerakTrainer( 90 | model=model, 91 | args=training_args, 92 | train_dataset=train_dataset, 93 | ) 94 | 95 | # Training 96 | trainer.train() 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /examples/models/bart/run_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | from config import load_config 22 | from transformers import ( 23 | set_seed, 24 | HfArgumentParser, 25 | BartForCausalLM, 26 | BartConfig, 27 | ) 28 | from Merak.utils.datasets import DynamicGenDataset 29 | 30 | 31 | # Add custom command-line arguments 32 | def parse_option(parser): 33 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 34 | return parser 35 | 36 | 37 | def main(): 38 | # Initialize Merak distributed training environment 39 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | if tp > 1: 46 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 47 | 48 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'intermediate.dense', 'fc1'] 49 | row_para_list = ['out_proj', 'fc2'] 50 | tp_attr_list = ['num_heads', 'emb_dim'] 51 | 52 | # manully set tp attribute for swin model 53 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 54 | tp_attr_list=tp_attr_list) 55 | 56 | 57 | # Parse training and model arguments 58 | hfparser = HfArgumentParser(MerakArguments) 59 | parser = parse_option(hfparser) 60 | training_args, args = parser.parse_args_into_dataclasses() 61 | 62 | # Set random seed for reproducibility 63 | set_seed(training_args.seed) 64 | 65 | # Load model configuration from custom config file 66 | config_kwarg = load_config(args.model_name) 67 | config = BartConfig(**config_kwarg) 68 | 69 | # Initialize language modeling model 70 | with init_empty_weights(): 71 | model = BartForCausalLM(config) 72 | 73 | split_point = [] 74 | split_point.append('model.decoder.embed_positions') 75 | for i in range(model.config.decoder_layers): 76 | split_point.append(f'model.decoder.layers.{i}') 77 | split_point.append('lm_head') 78 | training_args.custom_split_points = split_point 79 | 80 | model.config._attn_implementation = 'eager' 81 | 82 | # Create a fake dataset for training 83 | train_dataset = DynamicGenDataset( 84 | model.config, mode="text_only", dataset_size=1e6 85 | ) 86 | 87 | # Initialize trainer with model, training arguments and dataset 88 | trainer = MerakTrainer( 89 | model=model, 90 | args=training_args, 91 | train_dataset=train_dataset, 92 | ) 93 | 94 | # Start training 95 | trainer.train() 96 | 97 | 98 | if __name__ == "__main__": 99 | main() -------------------------------------------------------------------------------- /examples/models/mt5/run_mt5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
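# Tensor-parallel mapping note: set_tp_layer_lists() receives the submodule names
# to partition column-wise (col_para_list) and row-wise (row_para_list), plus the
# config attributes to rescale for the TP degree (tp_attr_list) and, for mT5,
# a weight_change_list entry for relative_attention_bias. A rough sketch of the
# intent, assuming Megatron-style partitioning of nn.Linear layers:
#
#     set_tp_layer_lists(col_para_list=['SelfAttention.q'],  # output features split across ranks
#                        row_para_list=['SelfAttention.o'],  # input features split across ranks
#                        tp_attr_list=['n_heads'])           # e.g. head count divided by the TP degree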
17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | MT5ForConditionalGeneration, 27 | MT5Config, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | # Add custom command-line arguments 33 | def parse_option(parser): 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # Initialize Merak distributed training environment 40 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 41 | pp = 2 42 | tp = 2 43 | dp = 1 44 | Merak.init(pp, tp, dp) 45 | 46 | if tp > 1: 47 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 48 | 49 | col_para_list = ['SelfAttention.q', 'SelfAttention.k', 'SelfAttention.v', 50 | 'EncDecAttention.q', 'EncDecAttention.k', 'EncDecAttention.v', 51 | 'DenseReluDense.wi_0', 'DenseReluDense.wi_1'] 52 | row_para_list = ['SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo'] 53 | weight_change_list = ('relative_attention_bias', 1), 54 | tp_attr_list = ['n_heads', 'inner_dim'] 55 | 56 | # manully set tp attribute for swin model 57 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 58 | weight_change_list=weight_change_list, tp_attr_list=tp_attr_list) 59 | 60 | # Parse training and model arguments 61 | hfparser = HfArgumentParser(MerakArguments) 62 | parser = parse_option(hfparser) 63 | training_args, args = parser.parse_args_into_dataclasses() 64 | 65 | # Set random seed for reproducibility 66 | set_seed(training_args.seed) 67 | 68 | # Load model configuration from custom config file 69 | config_kwarg = load_config(args.model_name) 70 | config = MT5Config(**config_kwarg) 71 | 72 | # Initialize language modeling model 73 | with init_empty_weights(): 74 | model = MT5ForConditionalGeneration(config) 75 | 76 | # Create a fake dataset for training 77 | train_dataset = DynamicGenDataset( 78 | model.config, mode="condition", dataset_size=1e6, seq_length=training_args.seq_length 79 | ) 80 | 81 | # Initialize trainer with model, training arguments and dataset 82 | trainer = MerakTrainer( 83 | model=model, 84 | args=training_args, 85 | train_dataset=train_dataset, 86 | ) 87 | 88 | # Start training 89 | trainer.train() 90 | 91 | 92 | if __name__ == "__main__": 93 | main() -------------------------------------------------------------------------------- /examples/models/mbart/run_mbart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
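# DynamicGenDataset (Merak.utils.datasets) synthesizes training batches on the
# fly; the mode argument appears to select the kind of fake data ("text_only"
# for decoder-only LM inputs, "condition" for encoder-decoder inputs,
# "multimodal" for text-plus-image inputs), judging by which models use which
# mode in these example scripts.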
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from transformers import ( 23 | set_seed, 24 | MBartForCausalLM, 25 | MBartConfig, 26 | HfArgumentParser, 27 | ) 28 | from Merak.utils.datasets import DynamicGenDataset 29 | 30 | 31 | # Add custom command-line arguments 32 | def parse_option(parser): 33 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 34 | return parser 35 | 36 | 37 | def main(): 38 | # Initialize Merak distributed training environment 39 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | # merge args 46 | hfparser = HfArgumentParser(MerakArguments) 47 | parser = parse_option(hfparser) 48 | training_args, args = parser.parse_args_into_dataclasses() 49 | 50 | if tp > 1: 51 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 52 | 53 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 54 | row_para_list = ['out_proj', 'fc2'] 55 | tp_attr_list = ['num_heads'] 56 | 57 | # manully set tp attribute for swin model 58 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 59 | tp_attr_list=tp_attr_list) 60 | 61 | # Set seed before initializing model. 62 | set_seed(training_args.seed) 63 | 64 | # set model config 65 | # config_kwarg = load_config(args.model_name) 66 | config_kwarg = { 67 | 'return_dict': False, 68 | 'use_cache': False 69 | } 70 | config = MBartConfig(**config_kwarg) 71 | config._attn_implementation = 'eager' 72 | 73 | # meta init model 74 | with init_empty_weights(): 75 | model = MBartForCausalLM(config) 76 | 77 | split_point = [] 78 | split_point.append('model.decoder.embed_positions') 79 | for i in range(model.config.decoder_layers): 80 | split_point.append(f'model.decoder.layers.{i}') 81 | split_point.append('lm_head') 82 | training_args.custom_split_points = split_point 83 | 84 | # Create a fake dataset for training 85 | train_dataset = DynamicGenDataset( 86 | model.config, mode="text_only", dataset_size=1e6 87 | ) 88 | 89 | # using our distributed trainer 90 | trainer = MerakTrainer( 91 | model=model, 92 | args=training_args, 93 | train_dataset=train_dataset, 94 | ) 95 | 96 | # Training 97 | trainer.train() 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /examples/models/xglm/run_xglm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch 19 | import Merak 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from transformers import ( 23 | DataCollatorForLanguageModeling, 24 | set_seed, 25 | XGLMForCausalLM, 26 | XGLMConfig, 27 | HfArgumentParser, 28 | ) 29 | 30 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 31 | from Merak.utils.datasets import DynamicGenDataset 32 | 33 | # Add custom command-line arguments 34 | def parse_option(parser): 35 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 36 | return parser 37 | 38 | 39 | def main(): 40 | # Initialize Merak distributed training environment 41 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 42 | pp = 4 43 | tp = 1 44 | dp = 1 45 | Merak.init(pp, tp, dp) 46 | 47 | # merge args 48 | hfparser = HfArgumentParser(MerakArguments) 49 | parser = parse_option(hfparser) 50 | training_args, args = parser.parse_args_into_dataclasses() 51 | 52 | if tp > 1: 53 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 54 | 55 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 56 | row_para_list = ['out_proj', 'fc2'] 57 | tp_attr_list = ['num_heads'] 58 | 59 | # manully set tp attribute for swin model 60 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 61 | tp_attr_list=tp_attr_list) 62 | 63 | # Set seed before initializing model. 64 | set_seed(training_args.seed) 65 | 66 | 67 | # set model config 68 | # config_kwarg = load_config(args.model_name) 69 | config_kwarg = { 70 | 'max_position_embeddings': 512, 71 | 'num_layers': 8, 72 | 'return_dict': False, 73 | 'use_cache': False 74 | } 75 | config = XGLMConfig(**config_kwarg) 76 | 77 | # meta init 78 | with init_empty_weights(): 79 | model = XGLMForCausalLM(config) 80 | 81 | model.eval() 82 | 83 | # Create a fake dataset for training 84 | eval_dataset = DynamicGenDataset( 85 | model.config, mode="text_only", dataset_size=1e6 86 | ) 87 | 88 | # using our distributed trainer 89 | trainer = MerakTrainer( 90 | model=model, 91 | args=training_args, 92 | eval_dataset=eval_dataset, 93 | leaf_modules=(AttentionMaskConverter,) 94 | # Data collator will default to DataCollatorWithPadding, so we change it. 95 | # data_collator=data_collator, 96 | ) 97 | 98 | # Training 99 | trainer.evaluation() 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /examples/models/blenderbot/run_blenderbot.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import Merak 19 | 20 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 21 | from config import load_config 22 | from transformers import ( 23 | set_seed, 24 | HfArgumentParser, 25 | BlenderbotForConditionalGeneration, 26 | BlenderbotConfig, 27 | ) 28 | 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | # Add custom command-line arguments 32 | def parse_option(parser): 33 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. gpt2)') 34 | return parser 35 | 36 | 37 | def main(): 38 | # Initialize Merak distributed training environment 39 | # pp: pipeline parallelism, tp: tensor parallelism, dp: data parallelism 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | 46 | if tp > 1: 47 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 48 | 49 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 50 | row_para_list = ['out_proj', 'fc2'] 51 | tp_attr_list = ['num_heads'] 52 | 53 | # manully set tp attribute for swin model 54 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 55 | tp_attr_list=tp_attr_list) 56 | 57 | # Parse training and model arguments 58 | hfparser = HfArgumentParser(MerakArguments) 59 | parser = parse_option(hfparser) 60 | training_args, args = parser.parse_args_into_dataclasses() 61 | 62 | # Set random seed for reproducibility 63 | set_seed(training_args.seed) 64 | 65 | # Load model configuration from custom config file 66 | config_kwarg = load_config(args.model_name) 67 | config = BlenderbotConfig(**config_kwarg) 68 | 69 | # Initialize language modeling model 70 | with init_empty_weights(): 71 | model = BlenderbotForConditionalGeneration(config) 72 | 73 | split_points = [] 74 | for i in range(config.encoder_layers): 75 | split_points.append(f"model.encoder.layers.{i}") 76 | for i in range(config.decoder_layers): 77 | split_points.append(f"model.decoder.layers.{i}") 78 | split_points.append("lm_head") 79 | training_args.custom_split_points = split_points 80 | 81 | model.config._attn_implementation = 'eager' 82 | 83 | # Create a fake dataset for training 84 | train_dataset = DynamicGenDataset( 85 | model.config, mode="condition", dataset_size=1e6 86 | ) 87 | 88 | # Initialize trainer with model, training arguments and dataset 89 | trainer = MerakTrainer( 90 | model=model, 91 | args=training_args, 92 | train_dataset=train_dataset, 93 | ) 94 | 95 | # Start training 96 | trainer.train() 97 | 98 | 99 | if __name__ == "__main__": 100 | main() -------------------------------------------------------------------------------- /Merak/core/fx/graph_shard/split_graph.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Yck (eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
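# Graph-sharding entry point. _shard_model_transformers() below dispatches on
# args.split_method: 'farthest_min_deps', 'layer_split' or 'nearest_min_deps',
# implemented in farthest_deps.py, layer_config.py and nearest_deps.py
# respectively; the names suggest heuristics that differ in where the cut is
# placed relative to a node's dependencies, but see those modules for specifics.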
17 | 
18 | import torch
19 | 
20 | from torch.fx.graph_module import GraphModule
21 | from typing import Dict, List, Set, Any, Optional, Tuple
22 | 
23 | from Merak.merak_args import get_args
24 | 
25 | from .farthest_deps import farthest_deps_split
26 | from .layer_config import layer_config_split
27 | from .nearest_deps import nearest_deps_split
28 | 
29 | 
30 | def _shard_model_transformers(
31 |     traced_graph_module: GraphModule,
32 |     model: torch.nn.Module,
33 |     shard_count=3,
34 | ) -> Tuple[List[GraphModule], Dict[str, int]]:
35 |     """Utility used to shard a traced model using torch.fx.
36 | 
37 |     This function walks the traced graph twice in an attempt to identify the
38 |     right cutpoints and then shards the model. In the first pass we calculate
39 |     the number of parameters as we walk the graph and mark nodes at
40 |     which we might want to create a new module. In the second pass we
41 |     modify the graph by inserting placeholders and output nodes to
42 |     essentially shard the graph.
43 | 
44 |     We don't support skip connections between shards. This means that all
45 |     inputs and outputs are self-contained within a given shard. A node from
46 |     shard 1 cannot be an input to a node from shard 3. We expect all inputs
47 |     to a given shard to be coming from the last node in the previous shard.
48 |     This means that we may not be able to shard the model into exactly the
49 |     `shard_count` specified by the user.
50 | 
51 |     Args:
52 |         traced_graph_module (GraphModule): Traced graph module of the model
53 |             to be sharded.
54 |         model (nn.Module): The original model.
55 |         shard_count (int): Number of shards that we want to split the model
56 |             into.
57 | 
58 |     """
59 |     args = get_args()
60 | 
61 |     # An empirical threshold on the number of users: a node with more users
62 |     # than this is needed in multiple stages and may have to be
63 |     # transmitted between stages.
64 |     output_node_threshold = 5
65 |     output_nodes_count = {}
66 |     for node in traced_graph_module.graph.nodes:
67 |         if len(list(node.users)) > output_node_threshold:
68 |             output_nodes_count[node.name] = len(list(node.users))
69 | 
70 |     if args.split_method == 'farthest_min_deps':
71 |         module_list, func_inputs = farthest_deps_split(traced_graph_module,
72 |                                                        model, shard_count,
73 |                                                        output_nodes_count)
74 |     elif args.split_method == 'layer_split':
75 |         module_list, func_inputs = layer_config_split(traced_graph_module, model)
76 |     elif args.split_method == 'nearest_min_deps':
77 |         module_list, func_inputs = nearest_deps_split(traced_graph_module, model)
78 |     else:
79 |         raise ValueError(f"Unknown split_method {args.split_method!r}; expected 'farthest_min_deps', 'layer_split' or 'nearest_min_deps'")
80 |     return module_list, func_inputs
--------------------------------------------------------------------------------
/Merak/core/printer/logging.py:
--------------------------------------------------------------------------------
1 | # https://github.com/microsoft/DeepSpeed/blob/85ce85dd5f4b18c0019a5121b06900e3a2c3933b/deepspeed/utils/logging.py
2 | 
3 | import logging
4 | import sys
5 | import torch
6 | import torch.distributed as dist
7 | from typing import List, Optional, Dict, Union
8 | 
9 | __all__ = ['LoggerFactory', 'log_dist', 'logger']
10 | 
11 | log_levels = {
12 |     "debug": logging.DEBUG,
13 |     "info": logging.INFO,
14 |     "warning": logging.WARNING,
15 |     "error": logging.ERROR,
16 |     "critical": logging.CRITICAL,
17 | }
18 | 
19 | 
20 | class LoggerFactory:
21 |     @staticmethod
22 |     def create_logger(name: Optional[str] = None, level: int = logging.INFO) -> logging.Logger:
23 |         """Create a logger
24 | 
25 |         Args:
26 |             name (str): name of the logger
27 |             level (int): logging level
28 | 
29 |         Raises:
30 |             ValueError: if name is None
31 |         """
32 | 
33 |         if name is None:
34 |             raise ValueError("name for logger cannot be None")
35 | 
36 |         formatter = logging.Formatter(
37 |             "[%(asctime)s] [%(levelname)s] "
38 |             "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
39 | 
40 |         logger_ = logging.getLogger(name)
41 |         logger_.setLevel(level)
42 |         logger_.propagate = False
43 |         ch = logging.StreamHandler(stream=sys.stdout)
44 |         ch.setLevel(level)
45 |         ch.setFormatter(formatter)
46 |         logger_.addHandler(ch)
47 |         return logger_
48 | 
49 | 
50 | logger = LoggerFactory.create_logger(name="Merak", level=logging.INFO)
51 | 
52 | 
53 | def log_dist(message: str, ranks: Optional[List[int]] = None, level: int = logging.INFO):
54 |     """Log a message when one of the following conditions is met
55 | 
56 |     + not dist.is_initialized()
57 |     + ranks == [-1], or dist.get_rank() is in ranks, when ranks is given
58 | 
59 |     Args:
60 |         message (str)
61 |         ranks (list)
62 |         level (int)
63 | 
64 |     """
65 |     should_log = not dist.is_initialized()
66 |     ranks = ranks or []
67 |     my_rank = dist.get_rank() if dist.is_initialized() else -1
68 |     if ranks and not should_log:
69 |         should_log = ranks[0] == -1
70 |         should_log = should_log or (my_rank in set(ranks))
71 |     if should_log:
72 |         final_message = "[Rank {}] {}".format(my_rank, message)
73 |         logger.log(level, final_message)
74 | 
75 | class Metric(object):
76 |     def __init__(self, name: str):
77 |         self.name = name
78 |         self.sum = torch.tensor(0.)
79 |         self.n = torch.tensor(0.)
80 | 
81 |     def update(self, val: torch.Tensor):
82 |         self.sum += val
83 |         self.n += 1
84 | 
85 |     def reset(self):
86 |         self.sum = torch.tensor(0.)
87 |         self.n = torch.tensor(0.)
88 | 
89 |     @property
90 |     def avg(self) -> torch.Tensor:
91 |         return self.sum / self.n
92 | 
93 | class AccMetric(object):
94 |     def __init__(self):
95 |         self.metrics = {}
96 | 
97 |     def update(self, key: str, val: torch.Tensor):
98 |         if key in self.metrics:
99 |             self.metrics[key].update(val)
100 |         else:
101 |             self.metrics[key] = Metric(key)
102 |             self.metrics[key].update(val)
103 | 
104 |     def reset(self):
105 |         for key in self.metrics:
106 |             self.metrics[key].reset()
107 | 
108 |     @property
109 |     def avg(self) -> Dict[str, torch.Tensor]:
110 |         avg_dict = {}
111 |         for key in self.metrics:
112 |             avg_dict[key] = self.metrics[key].avg.item()
113 |         self.reset()
114 | 
115 |         return avg_dict
116 | 
117 | 
--------------------------------------------------------------------------------
/examples/models/clip/run_clip.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved.
3 | #
4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com)
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
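# This example overrides MerakTrainer.get_loss_fn() to supply CLIP's symmetric
# contrastive loss instead of a loss computed inside the model. The same pattern
# (subclass MerakTrainer and return a loss_fn(outputs, labels) callable) should
# carry over to other custom objectives; a minimal sketch with a placeholder loss:
#
#     class MyTrainer(MerakTrainer):
#         def get_loss_fn(self):
#             def loss_fn(outputs, labels):
#                 return my_custom_loss(outputs[0], labels)  # my_custom_loss is hypothetical
#             return loss_fn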
17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | CLIPModel, 27 | CLIPConfig, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | def parse_option(parser): 33 | # easy config modification 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. bert)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # init dist 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | 46 | if tp > 1: 47 | assert tp == 1, 'Error: Model does not support tensor parallelization.' 48 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 49 | 50 | col_para_list = ['q_proj', 'k_proj', 'v_proj', 'fc1'] 51 | row_para_list = ['out_proj', 'fc2'] 52 | tp_attr_list = ['num_heads'] 53 | 54 | # manully set tp attribute for swin model 55 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 56 | tp_attr_list=tp_attr_list) 57 | 58 | hfparser = HfArgumentParser(MerakArguments) 59 | parser = parse_option(hfparser) 60 | training_args, args = parser.parse_args_into_dataclasses() 61 | 62 | # Set seed before initializing model. 63 | set_seed(training_args.seed) 64 | 65 | # set model config 66 | config_kwarg = load_config(args.model_name) 67 | config = CLIPConfig(**config_kwarg) 68 | 69 | 70 | with init_empty_weights(): 71 | model = CLIPModel(config) 72 | 73 | # Create a fake dataset for training 74 | train_dataset = DynamicGenDataset( 75 | model.config, mode="multimodal", dataset_size=1e6 76 | ) 77 | 78 | # define custom loss function 79 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor: 80 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 81 | return torch.nn.functional.cross_entropy( 82 | logits, torch.arange(len(logits), device=logits.device) 83 | ) 84 | caption_loss = contrastive_loss(similarity) 85 | image_loss = contrastive_loss(similarity.t()) 86 | return (caption_loss + image_loss) / 2.0 87 | 88 | class MyTrainer(MerakTrainer): 89 | 90 | def get_loss_fn(self): 91 | def loss_fn(outputs, labels): 92 | loss = clip_loss(outputs[1]) 93 | return loss 94 | return loss_fn 95 | 96 | 97 | # using our distributed trainer 98 | trainer = MyTrainer( 99 | model=model, 100 | args=training_args, 101 | train_dataset=train_dataset, 102 | # eval_dataset=eval_dataset, 103 | ) 104 | 105 | # Training 106 | trainer.train() 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /Merak/core/printer/see_memory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
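# Memory-reporting helper. see_memory_usage() is a no-op unless force=True and,
# by default, logs only on rank 0; figures are in MB (MA/CA report currently
# reserved memory, Max_MA the peak allocated, Max_CA the peak reserved, and
# PEAK_MA a running peak across calls). A usage sketch:
#
#     from Merak.core.printer.see_memory import see_memory_usage
#     see_memory_usage("after forward", force=True, ram=True)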
17 | 18 | import torch 19 | import gc 20 | import psutil 21 | from .logging import logger, log_dist 22 | from typing import Optional, List 23 | 24 | try: 25 | import torch_ft 26 | USE_CPU = True 27 | except ModuleNotFoundError: 28 | USE_CPU = False 29 | 30 | if USE_CPU: 31 | if hasattr(torch_ft.utils, "dsp_mem_used"): 32 | torch_memory_reserved = torch_ft.utils.dsp_mem_used 33 | else: 34 | torch_memory_reserved = None 35 | if hasattr(torch_ft.utils, "dsp_mem_used_peak"): 36 | torch_max_memory_reserved = torch_ft.utils.dsp_mem_used_peak 37 | else: 38 | torch_max_memory_reserved = None 39 | else: 40 | if hasattr(torch.cuda, "memory_reserved"): 41 | torch_memory_reserved = torch.cuda.memory_reserved 42 | else: 43 | torch_memory_reserved = torch.cuda.memory_allocated 44 | if hasattr(torch.cuda, "max_memory_reserved"): 45 | torch_max_memory_reserved = torch.cuda.max_memory_reserved 46 | else: 47 | torch_max_memory_reserved = torch.cuda.memory_cached 48 | 49 | peak_memory = 0 50 | 51 | def see_memory_usage( 52 | message: str, 53 | force: bool= False, 54 | ram: bool = False, 55 | ranks: List[int] = [0], 56 | ): 57 | if not force: 58 | return 59 | if torch.distributed.is_initialized() and \ 60 | not torch.distributed.get_rank() in ranks: 61 | # torch.cuda.empty_cache() 62 | # torch.distributed.barrier(group=group) 63 | return 64 | 65 | # python doesn't do real-time garbage collection so do it explicitly to get 66 | # the correct RAM reports 67 | gc.collect() 68 | global peak_memory 69 | if USE_CPU: 70 | max_ma = round(torch_max_memory_reserved() / (1024 * 1024),2) 71 | else: 72 | max_ma = round(torch.cuda.max_memory_allocated() / (1024 * 1024),2) 73 | peak_memory = max(peak_memory, max_ma) 74 | # Print message except when distributed but not rank 0 75 | # log_dist(message, ranks=ranks) 76 | log_dist( 77 | f"{message} MA {round(torch_memory_reserved() / (1024 * 1024),2 )} MB \ 78 | Max_MA {max_ma} MB \ 79 | CA {round(torch_memory_reserved() / (1024 * 1024),2)} MB \ 80 | Max_CA {round(torch_max_memory_reserved() / (1024 * 1024))} MB \ 81 | PEAK_MA {peak_memory} MB ", ranks=ranks) 82 | 83 | if ram: 84 | vm_stats = psutil.virtual_memory() 85 | used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2) 86 | logger.info( 87 | f'CPU Virtual Memory: used = {used_GB} GB, \ 88 | percent = {vm_stats.percent}%') 89 | 90 | # torch.cuda.empty_cache() 91 | # torch.distributed.barrier(group=group) 92 | # get the peak memory to report correct data, so reset the counter for the 93 | # next call 94 | if USE_CPU: 95 | if hasattr(torch_ft.utils, "reset_dsp_peak_mem_stats"): 96 | torch_ft.utils.reset_dsp_peak_mem_stats() 97 | else: 98 | if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ 99 | torch.cuda.reset_peak_memory_stats() -------------------------------------------------------------------------------- /examples/models/altclip/run_altclip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022, HPDL group, PDL lab, NUDT. All rights reserved. 3 | # 4 | # Maintainer: Swli (lucasleesw9@gmail.com), TXacs (txacs1993@gmail.com), Yck(eyichenke@gmail.com) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import Merak 19 | import torch 20 | 21 | from Merak import MerakArguments, MerakTrainer, init_empty_weights 22 | from config import load_config 23 | from transformers import ( 24 | set_seed, 25 | HfArgumentParser, 26 | AltCLIPModel, 27 | AltCLIPConfig, 28 | ) 29 | from Merak.utils.datasets import DynamicGenDataset 30 | 31 | 32 | def parse_option(parser): 33 | # easy config modification 34 | parser.add_argument('--model_name', type=str, help='Name of the model to load (e.g. bert)') 35 | return parser 36 | 37 | 38 | def main(): 39 | # init dist 40 | pp = 4 41 | tp = 1 42 | dp = 1 43 | Merak.init(pp, tp, dp) 44 | 45 | 46 | if tp > 1: 47 | from Merak.core.tensor_parallel.mp_attrs import set_tp_layer_lists 48 | 49 | col_para_list = ['query', 'key', 'value', 'q_proj', 'k_proj', 'v_proj', 'intermediate.dense', 'fc1'] 50 | row_para_list = ['output.dense', 'out_proj', 'fc2'] 51 | tp_attr_list = ['num_heads'] 52 | 53 | # manully set tp attribute for swin model 54 | set_tp_layer_lists(col_para_list=col_para_list, row_para_list=row_para_list, 55 | tp_attr_list=tp_attr_list) 56 | # merge args 57 | hfparser = HfArgumentParser(MerakArguments) 58 | parser = parse_option(hfparser) 59 | training_args, args = parser.parse_args_into_dataclasses() 60 | 61 | # Set seed before initializing model. 62 | set_seed(training_args.seed) 63 | 64 | # set model config 65 | config_kwarg = load_config(args.model_name) 66 | config = AltCLIPConfig(**config_kwarg) 67 | 68 | 69 | with init_empty_weights(): 70 | model = AltCLIPModel(config) 71 | 72 | model.config._attn_implementation = 'eager' 73 | 74 | # Create a fake dataset for training 75 | train_dataset = DynamicGenDataset( 76 | model.config, mode="multimodal", dataset_size=1e6 77 | ) 78 | 79 | 80 | # define custom loss function 81 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor: 82 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 83 | return torch.nn.functional.cross_entropy( 84 | logits, torch.arange(len(logits), device=logits.device) 85 | ) 86 | caption_loss = contrastive_loss(similarity) 87 | image_loss = contrastive_loss(similarity.t()) 88 | return (caption_loss + image_loss) / 2.0 89 | 90 | class MyTrainer(MerakTrainer): 91 | 92 | def get_loss_fn(self): 93 | def loss_fn(outputs, labels): 94 | loss = clip_loss(outputs[1]) 95 | return loss 96 | return loss_fn 97 | 98 | 99 | # using our distributed trainer 100 | trainer = MyTrainer( 101 | model=model, 102 | args=training_args, 103 | train_dataset=train_dataset, 104 | # eval_dataset=eval_dataset, 105 | ) 106 | 107 | # Training 108 | trainer.train() 109 | 110 | if __name__ == "__main__": 111 | main() 112 | --------------------------------------------------------------------------------
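# Condensed skeleton shared by the example scripts above (a sketch with
# placeholder model/config classes; individual examples add their own
# command-line options, TP mappings and split points):
#
#     import Merak
#     from Merak import MerakArguments, MerakTrainer, init_empty_weights
#     from Merak.utils.datasets import DynamicGenDataset
#     from transformers import HfArgumentParser, set_seed
#
#     Merak.init(pp=4, tp=1, dp=1)                      # world size == pp * tp * dp
#     training_args = HfArgumentParser(MerakArguments).parse_args_into_dataclasses()[0]
#     set_seed(training_args.seed)
#     with init_empty_weights():
#         model = SomeModelForCausalLM(SomeConfig(return_dict=False, use_cache=False))
#     train_dataset = DynamicGenDataset(model.config, mode="text_only", dataset_size=1e6)
#     MerakTrainer(model=model, args=training_args, train_dataset=train_dataset).train()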