├── .gitignore ├── DATA_LICENSE.txt ├── LICENSE ├── LICENSE.txt ├── MODEL_LICENSE.txt ├── OpenBA ├── __init__.py ├── configuration_openba.py ├── modeling_openba.py └── tokenization_openba.py ├── README.md ├── README_ZH.md ├── assets ├── bachelor.png ├── data.png ├── downstream.png └── training_process.png ├── convert_megatron_to_hf_ckp.py ├── convert_megatron_to_hf_ckp.sh ├── evaluation ├── CMMLU │ ├── data │ │ └── 5shot │ │ │ └── agronomy.json │ ├── logs │ │ ├── OpenBT5-0shot │ │ └── OpenBT5-5shot │ ├── main.py │ ├── make_data.py │ ├── readme.md │ ├── scripts │ │ ├── eval_fewshot.sh │ │ └── eval_zeroshot.sh │ └── template.py └── MMLU │ ├── data │ └── 5shot │ │ └── abstract_algebra_test.json │ ├── logs │ ├── OpenBT5-0shot │ └── OpenBT5-5shot │ ├── main.py │ ├── make_data.py │ ├── readme.md │ ├── scripts │ ├── eval_fewshot.sh │ └── eval_zeroshot.sh │ └── template.py ├── gradio_chat_demo.py ├── gradio_code_demo.py └── training ├── .coveragerc ├── .gitignore ├── megatron ├── __init__.py ├── arguments.py ├── checkpointing.py ├── core │ ├── __init__.py │ ├── enums.py │ ├── parallel_state.py │ ├── pipeline_parallel │ │ ├── __init__.py │ │ ├── p2p_communication.py │ │ └── schedules.py │ ├── tensor_parallel │ │ ├── __init__.py │ │ ├── cross_entropy.py │ │ ├── data.py │ │ ├── layers.py │ │ ├── mappings.py │ │ ├── random.py │ │ └── utils.py │ └── utils.py ├── data │ ├── Makefile │ ├── __init__.py │ ├── autoaugment.py │ ├── bert_dataset.py │ ├── biencoder_dataset_utils.py │ ├── blendable_dataset.py │ ├── data_samplers.py │ ├── dataset_utils.py │ ├── gpt_dataset.py │ ├── helpers.cpp │ ├── ict_dataset.py │ ├── image_folder.py │ ├── indexed_dataset.py │ ├── orqa_wiki_dataset.py │ ├── realm_dataset_utils.py │ ├── realm_index.py │ ├── t5_dataset.py │ ├── test │ │ ├── test_indexed_dataset.py │ │ └── test_preprocess_data.sh │ └── vit_dataset.py ├── dist_signal_handler.py ├── fp16_deprecated │ └── loss_scaler.py ├── fused_kernels │ ├── __init__.py │ ├── compat.h │ ├── fused_weight_gradient_dense.cpp │ ├── fused_weight_gradient_dense.cu │ ├── layer_norm_cuda.cpp │ ├── layer_norm_cuda_kernel.cu │ ├── scaled_masked_softmax.cpp │ ├── scaled_masked_softmax.h │ ├── scaled_masked_softmax_cuda.cu │ ├── scaled_softmax.cpp │ ├── scaled_softmax_cuda.cu │ ├── scaled_upper_triang_masked_softmax.cpp │ ├── scaled_upper_triang_masked_softmax.h │ ├── scaled_upper_triang_masked_softmax_cuda.cu │ ├── tests │ │ ├── __init__.py │ │ └── test_fused_kernels.py │ └── type_shim.h ├── global_vars.py ├── indexer.py ├── initialize.py ├── memory.py ├── microbatches.py ├── model │ ├── __init__.py │ ├── bert_model.py │ ├── biencoder_model.py │ ├── classification.py │ ├── distributed.py │ ├── enums.py │ ├── fused_bias_gelu.py │ ├── fused_layer_norm.py │ ├── fused_softmax.py │ ├── gpt_model.py │ ├── language_model.py │ ├── module.py │ ├── multiple_choice.py │ ├── realm_model.py │ ├── retro_transformer.py │ ├── rotary_embedding_torch.py │ ├── t5_model.py │ ├── transformer.py │ ├── utils.py │ └── vision │ │ ├── classification.py │ │ ├── dino.py │ │ ├── esvit_swin_backbone.py │ │ ├── inpainting.py │ │ ├── knn_monitor.py │ │ ├── mit_backbone.py │ │ ├── swin_backbone.py │ │ ├── utils.py │ │ └── vit_backbone.py ├── mpu │ └── tests │ │ ├── __init__.py │ │ ├── commons.py │ │ ├── test_cross_entropy.py │ │ ├── test_data.py │ │ ├── test_initialize.py │ │ ├── test_layers.py │ │ └── test_random.py ├── optimizer │ ├── __init__.py │ ├── clip_grads.py │ ├── distrib_optimizer.py │ ├── grad_scaler.py │ └── optimizer.py ├── optimizer_param_scheduler.py ├── 
static │ └── index.html ├── text_generation │ ├── __init__.py │ ├── api.py │ ├── beam_utils.py │ ├── communication.py │ ├── forward_step.py │ ├── generation.py │ ├── sampling.py │ └── tokenization.py ├── text_generation_server.py ├── timers.py ├── tokenizer │ ├── __init__.py │ ├── bert_tokenization.py │ ├── gpt2_tokenization.py │ └── tokenizer.py ├── training.py └── utils.py ├── pretrain_t5.py ├── scripts ├── data_process_flan.sh ├── data_process_span_corr.sh ├── run_flan.sh ├── run_pretrain.sh └── run_stretch.sh └── tools ├── checkpoint_split_megatron.py ├── linter.py ├── merge_datasets.py ├── preprocess_data_chat.py ├── preprocess_data_finetune.py └── preprocess_data_pretrain.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | .DS_Store 155 | assets/.DS_Store 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /OpenBA/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_openba import OpenBAForConditionalGeneration 2 | from .configuration_openba import OpenBAConfig 3 | from .tokenization_openba import OpenBATokenizer 4 | 5 | 6 | __all__ = [ 7 | "OpenBAForConditionalGeneration", 8 | "OpenBAConfig", 9 | "OpenBATokenizer", 10 | ] -------------------------------------------------------------------------------- /OpenBA/configuration_openba.py: -------------------------------------------------------------------------------- 1 | from transformers.utils import logging 2 | from transformers.configuration_utils import PretrainedConfig 3 | 4 | 5 | logger = logging.get_logger(__name__) 6 | 7 | 8 | class OpenBAConfig(PretrainedConfig): 9 | model_type = "openba" 10 | keys_to_ignore_at_inference = ["past_key_values"] 11 | attribute_map = { 12 | "hidden_size": "hidden_size", 13 | "num_attention_heads": "num_heads", 14 | "num_hidden_layers": "num_layers" 15 | } 16 | 17 | def __init__( 18 | self, 19 | vocab_size=32128, 20 | hidden_size=512, 21 | kv_channels=64, 22 | ffn_hidden_size=2048, 23 | num_layers=12, 24 | num_decoder_layers=None, 25 | hidden_dropout=0.1, 26 | attention_dropout=0.1, 27 | num_heads=8, 28 | is_encoder_decoder=True, 29 | use_cache=True, 30 | initializer_factor=1.0, 31 | pad_token_id=0, 32 | eos_token_id=1, 33 | decoder_start_token_id=0, 34 | add_qkv_bias=False, 35 | add_ffn_bias=False, 36 | add_lm_head_bias=False, 37 | max_seq_length=1024, 38 | decoder_max_seq_length=256, 39 | **kwargs, 40 | ): 41 | self.vocab_size = vocab_size 42 | self.hidden_size = hidden_size 43 | self.kv_channels = kv_channels 44 | self.ffn_hidden_size = ffn_hidden_size 45 | self.num_layers = num_layers 46 | self.num_decoder_layers = ( 47 | num_decoder_layers if num_decoder_layers is not None else self.num_layers 48 | ) # default = symmetry 49 | self.hidden_dropout = hidden_dropout 50 | self.attention_dropout = attention_dropout 51 | self.initializer_factor = initializer_factor 52 | self.num_heads = num_heads 53 | self.add_qkv_bias = add_qkv_bias 54 | self.add_ffn_bias = add_ffn_bias 55 | self.add_lm_head_bias = add_lm_head_bias 56 | self.max_seq_length = max_seq_length 57 | self.decoder_max_seq_length = 
decoder_max_seq_length 58 | self.use_cache = use_cache 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | eos_token_id=eos_token_id, 63 | decoder_start_token_id=decoder_start_token_id, 64 | is_encoder_decoder=is_encoder_decoder, 65 | **kwargs, 66 | ) -------------------------------------------------------------------------------- /assets/bachelor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/assets/bachelor.png -------------------------------------------------------------------------------- /assets/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/assets/data.png -------------------------------------------------------------------------------- /assets/downstream.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/assets/downstream.png -------------------------------------------------------------------------------- /assets/training_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/assets/training_process.png -------------------------------------------------------------------------------- /convert_megatron_to_hf_ckp.sh: -------------------------------------------------------------------------------- 1 | python convert_megatron_to_hf_ckp.py \ 2 | --convert_checkpoint_from_megatron_to_transformers \ 3 | --load_path /data/checkpoint/14b_main_long_final/iter_0020000/ \ 4 | --save_path /opt/dyy/hf_model_stretch \ 5 | --tokenizer_name "OpenBA/OpenBA-LM" \ 6 | --print-checkpoint-structure \ 7 | -------------------------------------------------------------------------------- /evaluation/CMMLU/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | import os 5 | import glob 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | import template 9 | import numpy as np 10 | import os 11 | 12 | 13 | def post_process_ABCD(out_list, test_inputs): 14 | golds = [i['data']['ans'] for i in test_inputs] 15 | encs = [i[1] for i in out_list] 16 | decs = [i[2] for i in out_list] 17 | preds = [i[0] for i in out_list] 18 | assert len(preds) == len(decs) == len(golds) == len(decs), str(len(preds), len(decs), len(golds), len(decs)) 19 | data2write = [json.dumps({'enc':enc, 'dec':dec, 'gold':gold, 'pred':pred}) + '\n' for enc, dec, gold, pred in zip(encs, decs, golds, preds)] 20 | right = sum([1 for i, j in zip(golds, preds) if i == j]) 21 | cnt = len(golds) 22 | ABCD_rate = [sum([1 for i in preds if i == j]) / cnt for j in ["A", "B", "C", "D"]] 23 | return ABCD_rate, cnt, right, data2write 24 | 25 | def solve_ABCD(model, tokenizer, input_text, decoder_input_text, args): 26 | input_ids = tokenizer(input_text, return_tensors='pt', max_length = args.max_length, truncation=True).input_ids 27 | decoder_input_ids= tokenizer(decoder_input_text, return_tensors='pt', max_length = args.decoder_max_length - 1, truncation=True).input_ids 28 | decoder_input_ids = model._shift_right(decoder_input_ids) 29 | with torch.no_grad(): 30 | logits = model(input_ids=input_ids.cuda(), \ 31 
| decoder_input_ids=decoder_input_ids.cuda()).logits[:,-1,:].contiguous()[0] 32 | 33 | probs = ( 34 | torch.tensor( 35 | [ 36 | logits[tokenizer.convert_tokens_to_ids("A")], 37 | logits[tokenizer.convert_tokens_to_ids("B")], 38 | logits[tokenizer.convert_tokens_to_ids("C")], 39 | logits[tokenizer.convert_tokens_to_ids("D")], 40 | ] 41 | ).detach().cpu().numpy() 42 | ) 43 | 44 | pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] 45 | 46 | return pred 47 | 48 | def main(args): 49 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 50 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path, trust_remote_code=True).half() 51 | model.cuda().eval() 52 | 53 | if not os.path.exists(args.output_folder): os.mkdir(args.output_folder) 54 | 55 | json_input_paths = glob.glob(os.path.join(args.input_folder, '*.json')) 56 | json_file_names = [os.path.basename(i) for i in json_input_paths] 57 | json_output_paths = [os.path.join(args.output_folder, i) for i in json_file_names] 58 | 59 | make_input = getattr(template, args.template_type) 60 | solve = solve_ABCD 61 | post_process = post_process_ABCD 62 | 63 | all_cnt, all_right = 0, 0 64 | for file_name, input_path, output_path in zip(json_file_names, json_input_paths, json_output_paths): 65 | out_list = [] 66 | with open(input_path, 'r') as fr, open(output_path, 'w') as fw: 67 | test_inputs = json.load(fr) 68 | for idx in range(0, len(test_inputs)): 69 | input_data = test_inputs[idx] 70 | candidates, decoder_input_text = make_input(file_name, input_data) 71 | input_text = template.choose_longest_input(candidates, args.max_length, tokenizer, args.add_prefix) 72 | if args.add_prefix: 73 | input_text = f"<{args.ptoken}> " + input_text + " " 74 | decoder_input_text = f" " + decoder_input_text 75 | 76 | pred = solve(model, tokenizer, input_text, decoder_input_text, args) 77 | out_list.append((pred, input_text, decoder_input_text)) 78 | ABCD_rate, cnt, right, data2write = post_process(out_list, test_inputs) 79 | print(file_name.replace('_test.json', '').replace('_', ' '), ': \n', f"acc: {(right / cnt)*100:.2f}%") 80 | for label, rate in zip(['A', 'B', 'C', 'D'], ABCD_rate): print(f"{label}: {rate*100:.2f}%", end = '|') 81 | print('\n' + '-' * 30) 82 | all_cnt, all_right = all_cnt + cnt, all_right + right 83 | for line in data2write:fw.write(line) 84 | 85 | 86 | print(f"all acc: {(all_right / all_cnt)*100:.2f}%", ) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--input-folder", type=str, default="/public/home/ljt/LLM/wpz/LEO_mmlu/4shot",) 91 | parser.add_argument("--output-folder", type=str, default="tmp",) 92 | parser.add_argument("--model-path", type=str, default="/public/home/ljt/LEO/checkpoint/14b_flan_new/iter_0003000_hf",) 93 | parser.add_argument("--max-length", type=int, default=512,) 94 | parser.add_argument("--decoder-max-length", type=int, default=128,) 95 | parser.add_argument("--padding", type=str, default="longest",) 96 | parser.add_argument("--add-prefix", action='store_true') 97 | parser.add_argument("--template-type", type = str, default="make_ABCD_input_0_shot") 98 | parser.add_argument("--ptoken", type = str, default='S') 99 | args = parser.parse_args() 100 | 101 | for arg in vars(args): 102 | print(f"{arg}: {getattr(args, arg)}") 103 | 104 | main(args) 105 | -------------------------------------------------------------------------------- /evaluation/CMMLU/make_data.py: 
-------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import random 5 | random.seed(42) 6 | 7 | def csv_to_list(filepath): 8 | """ 9 | Convert a CSV file to a list of questions. 10 | """ 11 | questions = [] 12 | with open(filepath, 'r', encoding='utf-8') as file: 13 | reader = csv.reader(file) 14 | for row in reader: 15 | questions.append({ 16 | 'question': row[1], 17 | 'res1': row[2], 18 | 'res2': row[3], 19 | 'res3': row[4], 20 | 'res4': row[5], 21 | 'ans': row[6] 22 | }) 23 | return questions[1:] 24 | 25 | def generate_json(val_dir, test_dir, output_dir): 26 | 27 | 28 | # Process test directory 29 | for filename in os.listdir(test_dir): 30 | test_filepath = os.path.join(test_dir, filename) 31 | test_data = csv_to_list(test_filepath) 32 | val_filepath = os.path.join(val_dir, filename.replace('test', 'val')) 33 | demo_list = csv_to_list(val_filepath) 34 | demo_list = random.sample(demo_list, 5) 35 | print(val_filepath, len(demo_list)) 36 | final_data = [] 37 | for item in test_data: 38 | entry = { 39 | 'demo': demo_list, 40 | 'data': item 41 | } 42 | final_data.append(entry) 43 | 44 | # Save to JSON 45 | json_filename = filename.replace('.csv', '.json') 46 | json_filepath = os.path.join(output_dir, json_filename) 47 | with open(json_filepath, 'w', encoding='utf-8') as file: 48 | json.dump(final_data, file, ensure_ascii=False, indent=2) 49 | 50 | if __name__ == "__main__": 51 | val_dir = 'path_to_cmmlu_dev_folder' 52 | test_dir = 'path_to_cmmlu_test_folder' 53 | output_dir = './data/5shot' # 替换为你的目标输出文件夹 54 | 55 | # Ensure output directory exists 56 | if not os.path.exists(output_dir): 57 | os.makedirs(output_dir) 58 | 59 | generate_json(val_dir, test_dir, output_dir) 60 | -------------------------------------------------------------------------------- /evaluation/CMMLU/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # CMMLU evaluation 3 | 4 | Here we provide scripts for inference of CMMLU with OpenBA. 5 | `make_data.py` is the script to construct `./data/5shot`, which is the cmmlu dataset in json format. 
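For reference, here is a minimal sketch of what a single entry in those JSON files looks like. The field names follow `csv_to_list`/`generate_json` in `make_data.py`; the concrete values are placeholders, not real CMMLU items.

```python
# Illustrative only: one element of the list stored in, e.g., data/5shot/agronomy.json.
example_entry = {
    "demo": [   # five examples sampled from the dev split, used to build 5-shot prompts
        {"question": "...", "res1": "...", "res2": "...", "res3": "...", "res4": "...", "ans": "A"},
        # ... four more demo entries ...
    ],
    "data": {   # the test question that the model is scored on
        "question": "...", "res1": "...", "res2": "...", "res3": "...", "res4": "...", "ans": "C"
    },
}
```

`main.py` reads every such file from `--input-folder`, builds prompts with the selected template, and reports per-subject and overall accuracy.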
6 | 7 | ```bash 8 | mkdir ./output 9 | bash scripts/eval_fewshot.sh # for few shot 10 | bash scripts/eval_zeroshot.sh # for zero shot 11 | ``` 12 | -------------------------------------------------------------------------------- /evaluation/CMMLU/scripts/eval_fewshot.sh: -------------------------------------------------------------------------------- 1 | name=OpenBT5-5shot 2 | input_folder="./data/5shot" 3 | current_model_path="/public/home/ljt/LLM/wpz/hf_models/OpenBT5-Flan" 4 | current_template="make_ABCD_input_5_shot" 5 | current_output_folder="./output/${name}" 6 | log_name="logs/${name}" 7 | max_length=1024 8 | decoder_max_length=256 9 | 10 | export CUDA_VISIBLE_DEVICES=1 11 | nohup python -u main.py \ 12 | --model-path $current_model_path \ 13 | --max-length $max_length \ 14 | --input-folder $input_folder \ 15 | --output-folder $current_output_folder \ 16 | --template-type $current_template \ 17 | --decoder-max-length $decoder_max_length \ 18 | --add-prefix \ 19 | --ptoken S > $log_name 2>&1 & 20 | -------------------------------------------------------------------------------- /evaluation/CMMLU/scripts/eval_zeroshot.sh: -------------------------------------------------------------------------------- 1 | name=OpenBT5-0shot 2 | input_folder="./data/5shot" 3 | current_model_path="/public/home/ljt/LLM/wpz/hf_models/OpenBT5-Flan" 4 | current_template="make_ABCD_input_0_shot" 5 | current_output_folder="./output/${name}" 6 | log_name="logs/${name}" 7 | max_length=1024 8 | decoder_max_length=256 9 | 10 | export CUDA_VISIBLE_DEVICES=0 11 | nohup python -u main.py \ 12 | --model-path $current_model_path \ 13 | --max-length $max_length \ 14 | --input-folder $input_folder \ 15 | --output-folder $current_output_folder \ 16 | --template-type $current_template \ 17 | --decoder-max-length $decoder_max_length \ 18 | --add-prefix \ 19 | --ptoken S > $log_name 2>&1 & 20 | -------------------------------------------------------------------------------- /evaluation/CMMLU/template.py: -------------------------------------------------------------------------------- 1 | name_en2zh = { 2 | "agronomy": "农学", 3 | "anatomy": "解剖学", 4 | "ancient_chinese": "古汉语", 5 | "arts": "艺术学", 6 | "astronomy": "天文学", 7 | "business_ethics": "商业伦理", 8 | "chinese_civil_service_exam": "中国公务员考试", 9 | "chinese_driving_rule": "中国驾驶规则", 10 | "chinese_food_culture": "中国饮食文化", 11 | "chinese_foreign_policy": "中国外交政策", 12 | "chinese_history":"中国历史", 13 | "chinese_literature": "中国文学", 14 | "chinese_teacher_qualification": "中国教师资格", 15 | "clinical_knowledge": "临床知识", 16 | "college_actuarial_science":"大学精算学", 17 | "college_education":"大学教育学", 18 | "college_engineering_hydrology": "大学工程水文学", 19 | "college_law": "大学法律", 20 | "college_mathematics": "大学数学", 21 | "college_medical_statistics":"大学医学统计", 22 | "college_medicine": "大学医学", 23 | "computer_science": "计算机科学", 24 | "computer_security": "计算机安全", 25 | "conceptual_physics": "概念物理学", 26 | "construction_project_management": "建设工程管理", 27 | "economics": "经济学", 28 | "education": "教育学", 29 | "electrical_engineering": "电气工程", 30 | "elementary_chinese":"小学语文", 31 | "elementary_commonsense":"小学常识", 32 | "elementary_information_and_technology": "小学信息技术", 33 | "elementary_mathematics": "初等数学", 34 | "ethnology": "民族学", 35 | "food_science": "食品科学", 36 | "genetics": "遗传学", 37 | "global_facts": "全球事实", 38 | "high_school_biology": "高中生物", 39 | "high_school_chemistry": "高中化学", 40 | "high_school_geography": "高中地理", 41 | "high_school_mathematics": "高中数学", 42 | "high_school_physics": "高中物理学", 43 | 
"high_school_politics": "高中政治", 44 | "human_sexuality": "人类性行为", 45 | "international_law": "国际法学", 46 | "journalism": "新闻学", 47 | "jurisprudence": "法理学", 48 | "legal_and_moral_basis": "法律与道德基础", 49 | "logical": "逻辑学", 50 | "machine_learning": "机器学习", 51 | "management": "管理学", 52 | "marketing": "市场营销", 53 | "marxist_theory": "马克思主义理论", 54 | "modern_chinese": "现代汉语", 55 | "nutrition": "营养学", 56 | "philosophy": "哲学", 57 | "professional_accounting": "专业会计", 58 | "professional_law": "专业法学", 59 | "professional_medicine": "专业医学", 60 | "professional_psychology": "专业心理学", 61 | "public_relations": "公共关系", 62 | "security_study":"安全研究", 63 | "sociology": "社会学", 64 | "sports_science": "体育学", 65 | "traditional_chinese_medicine": "中医中药", 66 | "virology": "病毒学", 67 | "world_history":"世界历史", 68 | "world_religions": "世界宗教", 69 | } 70 | 71 | 72 | def make_ABCD_input_0_shot(subject, data): 73 | demo = data['data'] 74 | ASK_TEMPLATE = "以下是关于({:})的单项选择题,请直接给出正确答案的选项。题目:{:} A. {:} B. {:} C. {:} D. {:} " 75 | ANS_TEMPLATE = "答案是:{:}" 76 | input_text = ASK_TEMPLATE.format(name_en2zh[subject.split('.')[0]], demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) 77 | decoder_input_text = ANS_TEMPLATE.format('') 78 | return [input_text], decoder_input_text 79 | 80 | def make_ABCD_input_5_shot(subject, data): 81 | demo = data['data'] 82 | ASK_TEMPLATE = "以下是关于({:})的单项选择题,请直接给出正确答案的选项。题目:{:} A. {:} B. {:} C. {:} D. {:} " 83 | ANS_TEMPLATE = "答案是:{:}" 84 | input_text = ASK_TEMPLATE.format(name_en2zh[subject.split('.')[0]], demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) 85 | decoder_input_text = ANS_TEMPLATE.format('') 86 | demos = data["demo"] 87 | fs_input_text = "" 88 | input_texts = [input_text] 89 | for demo in demos: 90 | fs_input_text += ASK_TEMPLATE.format(name_en2zh[subject.split('.')[0]], demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) + \ 91 | ANS_TEMPLATE.format(demo[f"ans"]) + '\n ' 92 | input_texts.append(fs_input_text + input_text) 93 | return input_texts, decoder_input_text 94 | 95 | def choose_longest_input(cand, max_length, tokenizer, add_s): 96 | idx = len(cand) - 1 97 | while idx >= 0: 98 | length = len(tokenizer(cand[idx])["input_ids"]) 99 | if add_s: length += 2 100 | if length <= max_length: 101 | return cand[idx] 102 | idx -= 1 103 | return cand[0] -------------------------------------------------------------------------------- /evaluation/MMLU/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | import os 5 | import glob 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | import template 9 | import numpy as np 10 | import os 11 | 12 | 13 | def post_process_ABCD(out_list, test_inputs): 14 | golds = [i['data']['ans'] for i in test_inputs] 15 | encs = [i[1] for i in out_list] 16 | decs = [i[2] for i in out_list] 17 | preds = [i[0] for i in out_list] 18 | assert len(preds) == len(decs) == len(golds) == len(decs), str(len(preds), len(decs), len(golds), len(decs)) 19 | data2write = [json.dumps({'enc':enc, 'dec':dec, 'gold':gold, 'pred':pred}) + '\n' for enc, dec, gold, pred in zip(encs, decs, golds, preds)] 20 | right = sum([1 for i, j in zip(golds, preds) if i == j]) 21 | cnt = len(golds) 22 | ABCD_rate = [sum([1 for i in preds if i == j]) / cnt for j in ["A", "B", "C", "D"]] 23 | return ABCD_rate, cnt, right, data2write 24 | 25 | def solve_ABCD(model, tokenizer, input_text, decoder_input_text, 
args): 26 | input_ids = tokenizer(input_text, return_tensors='pt', max_length = args.max_length, truncation=True).input_ids 27 | decoder_input_ids= tokenizer(decoder_input_text, return_tensors='pt', max_length = args.decoder_max_length - 1, truncation=True).input_ids 28 | decoder_input_ids = model._shift_right(decoder_input_ids) 29 | with torch.no_grad(): 30 | logits = model(input_ids=input_ids.cuda(), \ 31 | decoder_input_ids=decoder_input_ids.cuda()).logits[:,-1,:].contiguous()[0] 32 | 33 | probs = ( 34 | torch.tensor( 35 | [ 36 | logits[tokenizer.convert_tokens_to_ids("A")], 37 | logits[tokenizer.convert_tokens_to_ids("B")], 38 | logits[tokenizer.convert_tokens_to_ids("C")], 39 | logits[tokenizer.convert_tokens_to_ids("D")], 40 | ] 41 | ).detach().cpu().numpy() 42 | ) 43 | 44 | pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] 45 | 46 | return pred 47 | 48 | def main(args): 49 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 50 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path, trust_remote_code=True).half() 51 | model.cuda().eval() 52 | 53 | if not os.path.exists(args.output_folder): os.mkdir(args.output_folder) 54 | 55 | json_input_paths = glob.glob(os.path.join(args.input_folder, '*.json')) 56 | json_file_names = [os.path.basename(i) for i in json_input_paths] 57 | json_output_paths = [os.path.join(args.output_folder, i) for i in json_file_names] 58 | 59 | make_input = getattr(template, args.template_type) 60 | solve = solve_ABCD 61 | post_process = post_process_ABCD 62 | 63 | all_cnt, all_right = 0, 0 64 | for file_name, input_path, output_path in zip(json_file_names, json_input_paths, json_output_paths): 65 | out_list = [] 66 | with open(input_path, 'r') as fr, open(output_path, 'w') as fw: 67 | test_inputs = json.load(fr) 68 | for idx in range(0, len(test_inputs)): 69 | input_data = test_inputs[idx] 70 | candidates, decoder_input_text = make_input(file_name, input_data) 71 | input_text = template.choose_longest_input(candidates, args.max_length, tokenizer, args.add_prefix) 72 | if args.add_prefix: 73 | input_text = f"<{args.ptoken}> " + input_text + " " 74 | decoder_input_text = f" " + decoder_input_text 75 | 76 | pred = solve(model, tokenizer, input_text, decoder_input_text, args) 77 | out_list.append((pred, input_text, decoder_input_text)) 78 | ABCD_rate, cnt, right, data2write = post_process(out_list, test_inputs) 79 | print(file_name.replace('_test.json', '').replace('_', ' '), ': \n', f"acc: {(right / cnt)*100:.2f}%") 80 | for label, rate in zip(['A', 'B', 'C', 'D'], ABCD_rate): print(f"{label}: {rate*100:.2f}%", end = '|') 81 | print('\n' + '-' * 30) 82 | all_cnt, all_right = all_cnt + cnt, all_right + right 83 | for line in data2write:fw.write(line) 84 | 85 | 86 | print(f"all acc: {(all_right / all_cnt)*100:.2f}%", ) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--input-folder", type=str, default="/public/home/ljt/LLM/wpz/LEO_mmlu/4shot",) 91 | parser.add_argument("--output-folder", type=str, default="tmp",) 92 | parser.add_argument("--model-path", type=str, default="/public/home/ljt/LEO/checkpoint/14b_flan_new/iter_0003000_hf",) 93 | parser.add_argument("--max-length", type=int, default=512,) 94 | parser.add_argument("--decoder-max-length", type=int, default=128,) 95 | parser.add_argument("--padding", type=str, default="longest",) 96 | parser.add_argument("--add-prefix", action='store_true') 97 | parser.add_argument("--template-type", type = str, 
default="make_ABCD_input_0_shot") 98 | parser.add_argument("--ptoken", type = str, default='S') 99 | args = parser.parse_args() 100 | 101 | for arg in vars(args): 102 | print(f"{arg}: {getattr(args, arg)}") 103 | 104 | main(args) 105 | -------------------------------------------------------------------------------- /evaluation/MMLU/make_data.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import random 5 | random.seed(42) 6 | 7 | def csv_to_list(filepath): 8 | """ 9 | Convert a CSV file to a list of questions. 10 | """ 11 | questions = [] 12 | with open(filepath, 'r', encoding='utf-8') as file: 13 | reader = csv.reader(file) 14 | for row in reader: 15 | questions.append({ 16 | 'question': row[0], 17 | 'res1': row[1], 18 | 'res2': row[2], 19 | 'res3': row[3], 20 | 'res4': row[4], 21 | 'ans': row[5] 22 | }) 23 | return questions 24 | 25 | def generate_json(val_dir, test_dir, output_dir): 26 | 27 | for filename in os.listdir(test_dir): 28 | test_filepath = os.path.join(test_dir, filename) 29 | test_data = csv_to_list(test_filepath) 30 | val_filepath = os.path.join(val_dir, filename.replace('test', 'val')) 31 | demo_list = csv_to_list(val_filepath) 32 | demo_list = random.sample(demo_list, 5) 33 | final_data = [] 34 | for item in test_data: 35 | entry = { 36 | 'demo': demo_list, 37 | 'data': item 38 | } 39 | final_data.append(entry) 40 | 41 | json_filename = filename.replace('.csv', '.json') 42 | json_filepath = os.path.join(output_dir, json_filename) 43 | with open(json_filepath, 'w', encoding='utf-8') as file: 44 | json.dump(final_data, file, ensure_ascii=False, indent=2) 45 | 46 | if __name__ == "__main__": 47 | val_dir = 'path_to_cmmlu_dev_folder' 48 | test_dir = 'path_to_cmmlu_test_folder' 49 | output_dir = './data/5shot' # 替换为你的目标输出文件夹 50 | 51 | # Ensure output directory exists 52 | if not os.path.exists(output_dir): 53 | os.makedirs(output_dir) 54 | 55 | generate_json(val_dir, test_dir, output_dir) 56 | -------------------------------------------------------------------------------- /evaluation/MMLU/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # MMLU evaluation 3 | 4 | Here we provide scripts for inference of MMLU with OpenBA. 5 | 6 | `make_data.py` is the script to construct `./data/5shot`, which is the MMLU dataset in json format. 
7 | 8 | ```bash 9 | mkdir ./output 10 | bash scripts/eval_fewshot.sh # for few shot 11 | bash scripts/eval_zeroshot.sh # for zero shot 12 | ``` 13 | -------------------------------------------------------------------------------- /evaluation/MMLU/scripts/eval_fewshot.sh: -------------------------------------------------------------------------------- 1 | name=OpenBT5-5shot 2 | input_folder="./data/5shot" 3 | current_model_path="/public/home/ljt/LLM/wpz/hf_models/OpenBT5-Flan" 4 | current_template="make_ABCD_input_5_shot" 5 | current_output_folder="./output/${name}" 6 | log_name="logs/${name}" 7 | max_length=1024 8 | decoder_max_length=256 9 | 10 | export CUDA_VISIBLE_DEVICES=2 11 | nohup python -u main.py \ 12 | --model-path $current_model_path \ 13 | --max-length $max_length \ 14 | --input-folder $input_folder \ 15 | --output-folder $current_output_folder \ 16 | --template-type $current_template \ 17 | --decoder-max-length $decoder_max_length \ 18 | --add-prefix \ 19 | --ptoken S > $log_name 2>&1 & 20 | -------------------------------------------------------------------------------- /evaluation/MMLU/scripts/eval_zeroshot.sh: -------------------------------------------------------------------------------- 1 | name=OpenBT5-0shot 2 | input_folder="./data/5shot" 3 | current_model_path="/public/home/ljt/LLM/wpz/hf_models/OpenBT5-Flan" 4 | current_template="make_ABCD_input_0_shot" 5 | current_output_folder="./output/${name}" 6 | log_name="logs/${name}" 7 | max_length=1024 8 | decoder_max_length=256 9 | 10 | export CUDA_VISIBLE_DEVICES=3 11 | nohup python -u main.py \ 12 | --model-path $current_model_path \ 13 | --max-length $max_length \ 14 | --input-folder $input_folder \ 15 | --output-folder $current_output_folder \ 16 | --template-type $current_template \ 17 | --decoder-max-length $decoder_max_length \ 18 | --add-prefix \ 19 | --ptoken S > $log_name 2>&1 & 20 | -------------------------------------------------------------------------------- /evaluation/MMLU/template.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def make_ABCD_input_0_shot(subject, data): 4 | demo = data['data'] 5 | ASK_TEMPLATE = "Question: {:} Options: A. {:} B. {:} C. {:} D. {:} Answer:" 6 | ANS_TEMPLATE = "" 7 | input_text = ASK_TEMPLATE.format(demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) 8 | decoder_input_text = ANS_TEMPLATE 9 | return [input_text], decoder_input_text 10 | 11 | def make_ABCD_input_5_shot(subject, data): 12 | demo = data['data'] 13 | ASK_TEMPLATE = "Question: {:} Options: A. {:} B. {:} C. {:} D. 
{:} Answer:" 14 | ANS_TEMPLATE = "{:}" 15 | input_text = ASK_TEMPLATE.format(demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) # origin 16 | decoder_input_text = ANS_TEMPLATE 17 | demos = data["demo"] 18 | fs_input_text = "" 19 | input_texts = [input_text] 20 | for demo in demos: 21 | fs_input_text += ASK_TEMPLATE.format(demo["question"], demo["res1"], demo["res2"], demo["res3"], demo["res4"]) + \ 22 | ANS_TEMPLATE.format(demo[f"ans"]) + '\n ' # origin 23 | input_texts.append(fs_input_text + input_text) 24 | return input_texts, decoder_input_text 25 | 26 | 27 | def choose_longest_input(cand, max_length, tokenizer, add_s): 28 | idx = len(cand) - 1 29 | while idx >= 0: 30 | length = len(tokenizer(cand[idx])["input_ids"]) 31 | if add_s: length += 2 32 | if length <= max_length: 33 | return cand[idx] 34 | idx -= 1 35 | return cand[0] 36 | 37 | 38 | -------------------------------------------------------------------------------- /gradio_chat_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import random 3 | import time 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | import re 6 | import torch 7 | import os 8 | 9 | tokenizer = AutoTokenizer.from_pretrained('OpenBA/OpenBA-Flan', trust_remote_code=True) 10 | model = AutoModelForSeq2SeqLM.from_pretrained('OpenBA/OpenBA-Flan', trust_remote_code=True).half().cuda() 11 | model.eval() 12 | 13 | def case_insensitive_replace(input_str, from_str, to_str): 14 | pattern = re.compile(re.escape(from_str), re.IGNORECASE) 15 | return pattern.sub(to_str, input_str) 16 | 17 | 18 | def history2input(chat_history, message): 19 | input_text = "" 20 | for i, j in chat_history: 21 | input_text += f"Human: {i} Assistant: {j} " 22 | return input_text + f"Human: {message} Assistant: " 23 | 24 | def gpu_respond(message, top_p, temp, chat_history): 25 | input_text = history2input(chat_history, message) 26 | print("input:", input_text) 27 | bot_message = generate(input_text, top_p, temp) 28 | print("message:", bot_message) 29 | print('-' * 30) 30 | chat_history.append((message, bot_message)) 31 | return "", chat_history 32 | 33 | def generate(input_text, top_p = 0.7, temp = 0.95): 34 | inputs = tokenizer(" " + input_text + " ", return_tensors='pt') 35 | for k in inputs: 36 | inputs[k] = inputs[k].cuda() 37 | 38 | outputs = model.generate( 39 | **inputs, 40 | do_sample=True, 41 | max_new_tokens=512, 42 | temperature = temp, 43 | top_p = top_p, 44 | ) 45 | 46 | response = tokenizer.decode(outputs[0], skip_special_tokens=True) 47 | return response 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | with gr.Blocks() as demo: 53 | chatbot = gr.Chatbot() 54 | msg = gr.Textbox() 55 | clear = gr.ClearButton([msg, chatbot]) 56 | top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top P") 57 | temp = gr.Slider(minimum=0.01, maximum=1.0, value=0.95, label="Temperature") 58 | 59 | msg.submit(gpu_respond, [msg, top_p, temp, chatbot], [msg, chatbot]) 60 | 61 | demo.queue(concurrency_count=3) 62 | demo.launch(share=True) 63 | -------------------------------------------------------------------------------- /gradio_code_demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import random 3 | import time 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | import re 6 | import torch 7 | import os 8 | 9 | 10 | tokenizer = AutoTokenizer.from_pretrained("OpenBA/OpenBA-Code", 
trust_remote_code=True) 11 | model = AutoModelForSeq2SeqLM.from_pretrained("OpenBA/OpenBA-Code", trust_remote_code=True).half().cuda() 12 | model.eval() 13 | 14 | def case_insensitive_replace(input_str, from_str, to_str): 15 | pattern = re.compile(re.escape(from_str), re.IGNORECASE) 16 | return pattern.sub(to_str, input_str) 17 | 18 | 19 | def history2input(chat_history, message): 20 | return message 21 | 22 | def gpu_respond(message, top_p, temp, chat_history): 23 | input_text = history2input(chat_history, message) 24 | print("input:", input_text) 25 | bot_message = generate(input_text, top_p, temp) 26 | print("message:", bot_message) 27 | print('-' * 30) 28 | chat_history.append((message, bot_message)) 29 | return "", chat_history 30 | 31 | def generate(input_text, top_p=0.7, temp=0.95): 32 | inputs = tokenizer(" " + input_text + " ", return_tensors='pt') 33 | for k in inputs: 34 | inputs[k] = inputs[k].cuda() 35 | 36 | outputs = model.generate( 37 | **inputs, 38 | do_sample=True, 39 | max_new_tokens=1024, 40 | temperature = temp, 41 | top_p = top_p, 42 | ) 43 | response = tokenizer.decode(outputs[0][1:], spaces_between_special_tokens=False) + '\n' 44 | return response 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | with gr.Blocks() as demo: 50 | chatbot = gr.Chatbot() 51 | msg = gr.Textbox() 52 | clear = gr.ClearButton([msg, chatbot]) 53 | top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top P") 54 | temp = gr.Slider(minimum=0.01, maximum=1.0, value=0.95, label="Temperature") 55 | 56 | msg.submit(gpu_respond, [msg, top_p, temp, chatbot], [msg, chatbot]) 57 | 58 | demo.queue(concurrency_count=3) 59 | demo.launch(share=True) 60 | -------------------------------------------------------------------------------- /training/.coveragerc: -------------------------------------------------------------------------------- 1 | [html] 2 | directory = coverage 3 | 4 | [run] 5 | data_file = .coverage_$LOCAL_RANK 6 | -------------------------------------------------------------------------------- /training/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.so 3 | build 4 | .coverage_* 5 | *.egg-info 6 | *~ 7 | 8 | 9 | demo.json 10 | checkpoint/ 11 | job-output/ 12 | -------------------------------------------------------------------------------- /training/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | import torch 4 | 5 | from .global_vars import get_args, get_retro_args 6 | from .global_vars import get_current_global_batch_size 7 | from .global_vars import get_num_microbatches 8 | from .global_vars import get_signal_handler 9 | from .global_vars import update_num_microbatches 10 | from .global_vars import get_tokenizer 11 | from .global_vars import get_tensorboard_writer 12 | from .global_vars import get_adlr_autoresume 13 | from .global_vars import get_timers 14 | from .initialize import initialize_megatron 15 | 16 | from .utils import (print_rank_0, 17 | is_last_rank, 18 | print_rank_last) 19 | -------------------------------------------------------------------------------- /training/megatron/core/__init__.py: -------------------------------------------------------------------------------- 1 | import megatron.core.parallel_state 2 | import megatron.core.tensor_parallel 3 | import megatron.core.utils 4 | 5 | # Alias parallel_state as mpu, its legacy name 6 | mpu = parallel_state 7 | 8 | __all__ = [ 9 | "parallel_state", 10 | "tensor_parallel", 11 | "utils", 12 | ] 13 | -------------------------------------------------------------------------------- /training/megatron/core/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import enum 4 | 5 | class ModelType(enum.Enum): 6 | encoder_or_decoder = 1 7 | encoder_and_decoder = 2 8 | -------------------------------------------------------------------------------- /training/megatron/core/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedules import get_forward_backward_func 2 | -------------------------------------------------------------------------------- /training/megatron/core/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import vocab_parallel_cross_entropy 2 | from .data import broadcast_data 3 | 4 | from .layers import ( 5 | ColumnParallelLinear, 6 | RowParallelLinear, 7 | VocabParallelEmbedding, 8 | set_tensor_model_parallel_attributes, 9 | set_defaults_if_not_set_tensor_model_parallel_attributes, 10 | copy_tensor_model_parallel_attributes, 11 | param_is_not_tensor_parallel_duplicate, 12 | linear_with_grad_accumulation_and_async_allreduce 13 | 14 | ) 15 | 16 | from .mappings import ( 17 | copy_to_tensor_model_parallel_region, 18 | gather_from_tensor_model_parallel_region, 19 | gather_from_sequence_parallel_region, 20 | scatter_to_tensor_model_parallel_region, 21 | scatter_to_sequence_parallel_region, 22 | ) 23 | 24 | from .random import ( 25 | checkpoint, 26 | get_cuda_rng_tracker, 27 | model_parallel_cuda_manual_seed, 28 | ) 29 | 30 | from .utils import ( 31 | split_tensor_along_last_dim, 32 | split_tensor_into_1d_equal_chunks, 33 | gather_split_1d_tensor, 34 | ) 35 | 36 | __all__ = [ 37 | # cross_entropy.py 38 | "vocab_parallel_cross_entropy", 39 | # data.py 40 | "broadcast_data", 41 | #layers.py 42 | "ColumnParallelLinear", 43 | "RowParallelLinear", 44 | "VocabParallelEmbedding", 45 | "set_tensor_model_parallel_attributes", 46 | "set_defaults_if_not_set_tensor_model_parallel_attributes", 47 | "copy_tensor_model_parallel_attributes", 48 | "param_is_not_tensor_parallel_duplicate", 49 | "linear_with_grad_accumulation_and_async_allreduce", 50 | # mappings.py 51 | "copy_to_tensor_model_parallel_region", 52 | "gather_from_tensor_model_parallel_region", 53 
| "gather_from_sequence_parallel_region", 54 | # "reduce_from_tensor_model_parallel_region", 55 | "scatter_to_tensor_model_parallel_region", 56 | "scatter_to_sequence_parallel_region", 57 | # random.py 58 | "checkpoint", 59 | "get_cuda_rng_tracker", 60 | "model_parallel_cuda_manual_seed", 61 | # utils.py 62 | "split_tensor_along_last_dim", 63 | "split_tensor_into_1d_equal_chunks", 64 | "gather_split_1d_tensor", 65 | ] 66 | -------------------------------------------------------------------------------- /training/megatron/core/tensor_parallel/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | from megatron.core.parallel_state import ( 6 | get_tensor_model_parallel_group, 7 | get_tensor_model_parallel_rank, 8 | get_tensor_model_parallel_src_rank, 9 | ) 10 | 11 | 12 | _MAX_DATA_DIM = 5 13 | 14 | 15 | def _check_data_types(keys, data, target_dtype): 16 | """Check that all the keys have the same target data type.""" 17 | for key in keys: 18 | assert data[key].dtype == target_dtype, '{} has data type {} which '\ 19 | 'is different than {}'.format(key, data[key].dtype, target_dtype) 20 | 21 | 22 | def _build_key_size_numel_dictionaries(keys, data): 23 | """Build the size on rank 0 and broadcast.""" 24 | max_dim = _MAX_DATA_DIM 25 | sizes = [0 for _ in range(max_dim) for _ in keys] 26 | 27 | # Pack the sizes on rank zero. 28 | if get_tensor_model_parallel_rank() == 0: 29 | offset = 0 30 | for key in keys: 31 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' 32 | size = data[key].size() 33 | for i, s in enumerate(size): 34 | sizes[i + offset] = s 35 | offset += max_dim 36 | 37 | # Move to GPU and broadcast. 38 | sizes_cuda = torch.cuda.LongTensor(sizes) 39 | torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(), 40 | group=get_tensor_model_parallel_group()) 41 | 42 | # Move back to cpu and unpack. 43 | sizes_cpu = sizes_cuda.cpu() 44 | key_size = {} 45 | key_numel = {} 46 | total_numel = 0 47 | offset = 0 48 | for key in keys: 49 | i = 0 50 | size = [] 51 | numel = 1 52 | while sizes_cpu[offset + i] > 0: 53 | this_size = sizes_cpu[offset + i] 54 | size.append(this_size) 55 | numel *= this_size 56 | i += 1 57 | key_size[key] = size 58 | key_numel[key] = numel 59 | total_numel += numel 60 | offset += max_dim 61 | 62 | return key_size, key_numel, total_numel 63 | 64 | 65 | def broadcast_data(keys, data, datatype): 66 | """Broadcast data from rank zero of each model parallel group to the 67 | members of the same model parallel group. 68 | 69 | Arguments: 70 | keys: list of keys in the data disctionary to be broadcasted 71 | data: data dictionary of string keys and cpu tensor values. 72 | datatype: torch data type of all tensors in data associated 73 | with keys. 74 | """ 75 | # Build (key, size) and (key, number of elements) dictionaries along 76 | # with the total number of elements on all ranks. 77 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, 78 | data) 79 | 80 | # Pack on rank zero. 81 | if get_tensor_model_parallel_rank() == 0: 82 | # Check that all keys have the same data type. 
83 | _check_data_types(keys, data, datatype) 84 | # Flatten the data associated with the keys 85 | flatten_data = torch.cat( 86 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() 87 | else: 88 | flatten_data = torch.empty(total_numel, 89 | device=torch.cuda.current_device(), 90 | dtype=datatype) 91 | 92 | # Broadcast 93 | torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(), 94 | group=get_tensor_model_parallel_group()) 95 | 96 | # Unpack 97 | output = {} 98 | offset = 0 99 | for key in keys: 100 | size = key_size[key] 101 | numel = key_numel[key] 102 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 103 | offset += numel 104 | 105 | return output 106 | -------------------------------------------------------------------------------- /training/megatron/core/tensor_parallel/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | from typing import List, Sequence 5 | 6 | from megatron.core.utils import divide 7 | from megatron.core import parallel_state 8 | 9 | def split_tensor_along_last_dim( 10 | tensor: torch.Tensor, 11 | num_partitions: int, 12 | contiguous_split_chunks: bool = False, 13 | ) -> List[torch.Tensor]: 14 | """ Split a tensor along its last dimension. 15 | 16 | Arguments: 17 | tensor: input tensor. 18 | num_partitions: number of partitions to split the tensor 19 | contiguous_split_chunks: If True, make each chunk contiguous 20 | in memory. 21 | 22 | Returns: 23 | A list of Tensors 24 | """ 25 | # Get the size and dimension. 26 | last_dim = tensor.dim() - 1 27 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 28 | # Split. 29 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 30 | # Note: torch.split does not create contiguous tensors by default. 31 | if contiguous_split_chunks: 32 | return tuple(chunk.contiguous() for chunk in tensor_list) 33 | 34 | return tensor_list 35 | 36 | def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): 37 | """ Break a tensor into equal 1D chunks across tensor parallel ranks. 38 | 39 | Returns a Tensor or View with this rank's portion of the data. 40 | 41 | Arguments: 42 | tensor: The tensor to split 43 | 44 | Keyword Arguments: 45 | new_buffer (bool): If True, returns a new Tensor. 46 | If False, returns a view into the existing Tensor. 47 | Default is False 48 | 49 | """ 50 | partition_size = torch.numel(tensor) // \ 51 | parallel_state.get_tensor_model_parallel_world_size() 52 | start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() 53 | end_index = start_index + partition_size 54 | if new_buffer: 55 | data = torch.empty(partition_size, dtype=tensor.dtype, 56 | device=torch.cuda.current_device(), 57 | requires_grad=False) 58 | data.copy_(tensor.view(-1)[start_index:end_index]) 59 | else: 60 | data = tensor.view(-1)[start_index:end_index] 61 | return data 62 | 63 | 64 | def gather_split_1d_tensor(tensor): 65 | """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor 66 | model parallel ranks. 67 | 68 | Returns a new Tensor with the gathered data. 69 | 70 | Arguments: 71 | tensor: A Tensor or view of this rank's portion of the data. 
72 | """ 73 | numel_gathered = torch.numel(tensor) * \ 74 | parallel_state.get_tensor_model_parallel_world_size() 75 | gathered = torch.empty(numel_gathered, dtype=tensor.dtype, 76 | device=torch.cuda.current_device(), 77 | requires_grad=False) 78 | # TODO: This API is experimental in pytorch (as of Feb 2022) and 79 | # this might break in future pytorch releases. We chose this API 80 | # as opposed to torch.distributed.all_gather for efficiency reasons. 81 | # This API calls directly NCCL all-gather versus the former does 82 | # internal copies and can potentially cause slow down. 83 | torch.distributed._all_gather_base(gathered, tensor, 84 | group=parallel_state.get_tensor_model_parallel_group()) 85 | return gathered 86 | 87 | 88 | class VocabUtility: 89 | """ Split the vocabulary into `world_size` chunks and return the first 90 | and last index of the vocabulary belonging to the `rank` 91 | partition: Note that indices in [fist, last) 92 | 93 | """ 94 | 95 | @staticmethod 96 | def vocab_range_from_per_partition_vocab_size( 97 | per_partition_vocab_size: int, rank, world_size: int 98 | ) -> Sequence[int]: 99 | index_f = rank * per_partition_vocab_size 100 | index_l = index_f + per_partition_vocab_size 101 | return index_f, index_l 102 | 103 | @staticmethod 104 | def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: 105 | per_partition_vocab_size = divide(global_vocab_size, world_size) 106 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 107 | per_partition_vocab_size, rank, world_size 108 | ) 109 | -------------------------------------------------------------------------------- /training/megatron/core/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used throughout Megatron core""" 2 | from functools import reduce 3 | import operator 4 | 5 | import torch 6 | 7 | from megatron.core import parallel_state 8 | 9 | 10 | def ensure_divisibility(numerator, denominator): 11 | """Ensure that numerator is divisible by the denominator.""" 12 | assert numerator % denominator == 0, "{} is not divisible by {}".format( 13 | numerator, denominator 14 | ) 15 | 16 | 17 | def divide(numerator, denominator): 18 | """Ensure that numerator is divisible by the denominator and return 19 | the division value.""" 20 | ensure_divisibility(numerator, denominator) 21 | return numerator // denominator 22 | 23 | def get_attr_wrapped_model(model, attr): 24 | """Get an attribute from a wrapped model""" 25 | if isinstance(model, list): 26 | raise RuntimeError("_get_attr_wrapped_model given a list of models") 27 | 28 | while not hasattr(model, attr): 29 | if not hasattr(model, "module"): 30 | raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}") 31 | 32 | model = model.module 33 | return getattr(model, attr) 34 | 35 | def get_model_type(model): 36 | return get_attr_wrapped_model(model, 'model_type') 37 | 38 | 39 | class GlobalMemoryBuffer: 40 | """Global buffer to avoid dynamic memory allocations. 
41 | Caller should ensure that buffers of the same name 42 | are not used concurrently.""" 43 | 44 | def __init__(self): 45 | self.buffer = {} 46 | 47 | def get_tensor(self, tensor_shape, dtype, name): 48 | required_len = reduce(operator.mul, tensor_shape, 1) 49 | if self.buffer.get((name, dtype), None) is None or \ 50 | self.buffer[(name, dtype)].numel() < required_len: 51 | self.buffer[(name, dtype)] = \ 52 | torch.empty(required_len, 53 | dtype=dtype, 54 | device=torch.cuda.current_device(), 55 | requires_grad=False) 56 | 57 | return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) 58 | 59 | def _kernel_make_viewless_tensor(inp, requires_grad): 60 | '''Make a viewless tensor. 61 | 62 | View tensors have the undesirable side-affect of retaining a reference 63 | to the originally-viewed tensor, even after manually setting the '.data' 64 | field. This method creates a new tensor that links to the old tensor's 65 | data, without linking the viewed tensor, referenced via the '._base' 66 | field. 67 | ''' 68 | out = torch.empty( 69 | (1,), 70 | dtype = inp.dtype, 71 | device = inp.device, 72 | requires_grad = requires_grad, 73 | ) 74 | out.data = inp.data 75 | return out 76 | 77 | class MakeViewlessTensor(torch.autograd.Function): 78 | ''' 79 | Autograd function to make a viewless tensor. 80 | 81 | This function should be used in cases where the computation graph needs 82 | to be propagated, but we only want a viewless tensor (e.g., 83 | ParallelTransformer's hidden_states). Call this function by passing 84 | 'keep_graph = True' to 'make_viewless_tensor()'. 85 | ''' 86 | @staticmethod 87 | def forward(ctx, inp, requires_grad): 88 | return _kernel_make_viewless_tensor(inp, requires_grad) 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | return grad_output, None 92 | 93 | def make_viewless_tensor(inp, requires_grad, keep_graph): 94 | ''' 95 | Entry-point for creating viewless tensors. 96 | 97 | This method should be used, rather than calling 'MakeViewlessTensor' 98 | or '_kernel_make_viewless_tensor' directly. This method acts as a 99 | switch for determining if an autograd function or a regular method 100 | should be used to create the tensor. 101 | ''' 102 | 103 | # return tensor as-is, if not a 'view' 104 | if inp._base is None: 105 | return inp 106 | 107 | # create viewless tensor 108 | if keep_graph: 109 | return MakeViewlessTensor.apply(inp, requires_grad) 110 | else: 111 | return _kernel_make_viewless_tensor(inp, requires_grad) 112 | 113 | def assert_viewless_tensor(tensor, extra_msg = None): 114 | '''Assert that a tensor is not a view (i.e., its '._base' field is 115 | not set).''' 116 | if isinstance(tensor, list): 117 | [ assert_viewless_tensor(t) for t in tensor ] 118 | return tensor 119 | if not isinstance(tensor, torch.Tensor): 120 | return tensor 121 | assert tensor._base is None, ( 122 | "Ensure tensor._base is None before setting tensor.data or storing " 123 | "tensor to memory buffer. Otherwise, a memory leak will occur (and " 124 | "likely accumulate over iterations). %s" 125 | ) % extra_msg 126 | return tensor 127 | 128 | def safely_set_viewless_tensor_data(tensor, new_data_tensor): 129 | '''Safely set tensor's '.data' field. 130 | 131 | Check first that the tensor is viewless (i.e., '._base' not set). If not, 132 | raise an exception. 133 | ''' 134 | assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." 
% ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape)) 135 | tensor.data = new_data_tensor 136 | -------------------------------------------------------------------------------- /training/megatron/data/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 3 | LIBNAME = helpers 4 | LIBEXT = $(shell python3-config --extension-suffix) 5 | 6 | default: $(LIBNAME)$(LIBEXT) 7 | 8 | %$(LIBEXT): %.cpp 9 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ 10 | -------------------------------------------------------------------------------- /training/megatron/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import indexed_dataset 2 | -------------------------------------------------------------------------------- /training/megatron/data/blendable_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Blendable dataset.""" 4 | 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from megatron import print_rank_0 11 | 12 | class BlendableDataset(torch.utils.data.Dataset): 13 | 14 | 15 | def __init__(self, datasets, weights): 16 | 17 | self.datasets = datasets 18 | num_datasets = len(datasets) 19 | assert num_datasets == len(weights) 20 | 21 | self.size = 0 22 | for dataset in self.datasets: 23 | self.size += len(dataset) 24 | 25 | # Normalize weights. 26 | weights = np.array(weights, dtype=np.float64) 27 | sum_weights = np.sum(weights) 28 | assert sum_weights > 0.0 29 | weights /= sum_weights 30 | 31 | # Build indecies. 32 | start_time = time.time() 33 | assert num_datasets < 255 34 | self.dataset_index = np.zeros(self.size, dtype=np.uint8) 35 | self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) 36 | 37 | from megatron.data import helpers 38 | helpers.build_blending_indices(self.dataset_index, 39 | self.dataset_sample_index, 40 | weights, num_datasets, self.size, 41 | torch.distributed.get_rank() == 0) 42 | print_rank_0('> elapsed time for building blendable dataset indices: ' 43 | '{:.2f} (sec)'.format(time.time() - start_time)) 44 | 45 | 46 | def __len__(self): 47 | return self.size 48 | 49 | 50 | def __getitem__(self, idx): 51 | dataset_idx = self.dataset_index[idx] 52 | sample_idx = self.dataset_sample_index[idx] 53 | return { 54 | "dataset_idx" : dataset_idx, 55 | **self.datasets[dataset_idx][sample_idx], 56 | } 57 | -------------------------------------------------------------------------------- /training/megatron/data/test/test_indexed_dataset.py: -------------------------------------------------------------------------------- 1 | # This file isn't really a formal automated test, it's just a place to 2 | # put some code used during development and manual testing of 3 | # indexed_dataset. 
4 | 5 | from megatron.data import indexed_dataset 6 | from megatron.tokenizer import build_tokenizer 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import torch 12 | 13 | script_dir = os.path.dirname(os.path.realpath(__file__)) 14 | sys.path.append(os.path.join(script_dir, "../../../")) 15 | 16 | 17 | def test_indexed_dataset(args): 18 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 19 | tokenizer = build_tokenizer(args) 20 | print(len(ds.doc_idx)) 21 | print(len(ds)) 22 | print(ds.doc_idx[-1]) 23 | if ds.supports_prefetch: 24 | # just prefetch the whole thing in test (so assume it is small) 25 | ds.prefetch(range(len(ds))) 26 | if args.count > len(ds.doc_idx) - 1: 27 | args.count = len(ds.doc_idx) - 1 28 | 29 | for i in range(args.count): 30 | start = ds.doc_idx[i] 31 | end = ds.doc_idx[i + 1] 32 | ids = ds[start:end] 33 | print(f"Document {i}:") 34 | print("--------------") 35 | for s in ids: 36 | assert len(s) > 0 37 | l = s.data.tolist() 38 | text = tokenizer.detokenize(l) 39 | print(text) 40 | print("---") 41 | 42 | 43 | def test_indexed_dataset_get(args): 44 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 45 | tokenizer = build_tokenizer(args) 46 | size = ds.sizes[0] 47 | print(f"size: {size}") 48 | full = ds.get(0) 49 | print(full) 50 | # print(tokenizer.detokenize(full.data.tolist())) 51 | print("---") 52 | end = ds.get(0, offset=size - 10) 53 | print(end) 54 | # print(tokenizer.detokenize(end.data.tolist())) 55 | 56 | start = ds.get(0, length=10) 57 | print(start) 58 | # print(tokenizer.detokenize(start.data.tolist())) 59 | 60 | part = ds.get(0, offset=2, length=8) 61 | print(part) 62 | # print(tokenizer.detokenize(part.data.tolist())) 63 | 64 | # def test_albert_dataset(args): 65 | # # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) 66 | # # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) 67 | # # ds = AlbertDataset(idataset, tokenizer) 68 | # ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, 69 | # args.epochs, args.max_num_samples, 70 | # args.masked_lm_prob, args.seq_length, 71 | # args.short_seq_prob, args.seed) 72 | # truncated = 0 73 | # total = 0 74 | # for i, s in enumerate(ds): 75 | # ids = s['text'] 76 | # tokens = ds.tokenizer.convert_ids_to_tokens(ids) 77 | # print(tokens) 78 | # if i >= args.count-1: 79 | # exit() 80 | 81 | 82 | def main(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--data', type=str, help='prefix to data files') 85 | parser.add_argument('--dataset-impl', type=str, default='infer', 86 | choices=['lazy', 'cached', 'mmap', 'infer']) 87 | parser.add_argument('--count', type=int, default=10, 88 | help='Number of samples/documents to print') 89 | 90 | group = parser.add_argument_group(title='tokenizer') 91 | group.add_argument('--tokenizer-type', type=str, required=True, 92 | choices=['BertWordPieceLowerCase', 93 | 'GPT2BPETokenizer'], 94 | help='What type of tokenizer to use.') 95 | group.add_argument('--vocab-file', type=str, default=None, 96 | help='Path to the vocab file') 97 | group.add_argument('--merge-file', type=str, default=None, 98 | help='Path to the BPE merge file (if necessary).') 99 | 100 | parser.add_argument('--epochs', type=int, default=5, 101 | help='Number of epochs to plan for') 102 | parser.add_argument('--max-num-samples', type=int, default=None, 103 | help='Maximum number of samples to plan for') 104 | parser.add_argument('--masked-lm-prob', type=float, default=0.15, 105 | help='probability of masking 
tokens') 106 | parser.add_argument('--seq-length', type=int, default=512, 107 | help='maximum sequence length') 108 | parser.add_argument('--short-seq-prob', type=float, default=0.1, 109 | help='probability of creating a short sequence') 110 | parser.add_argument('--seed', type=int, default=1234, 111 | help='random seed') 112 | args = parser.parse_args() 113 | args.rank = 0 114 | args.make_vocab_size_divisible_by = 128 115 | args.tensor_model_parallel_size = 1 116 | 117 | if args.dataset_impl == "infer": 118 | args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) 119 | 120 | # test_albert_dataset(args) 121 | test_indexed_dataset_get(args) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /training/megatron/data/test/test_preprocess_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMPL=cached 4 | python ../preprocess_data.py \ 5 | --input test_samples.json \ 6 | --vocab vocab.txt \ 7 | --dataset-impl ${IMPL} \ 8 | --output-prefix test_samples_${IMPL} \ 9 | --workers 1 \ 10 | --log-interval 2 11 | -------------------------------------------------------------------------------- /training/megatron/dist_signal_handler.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | import torch 4 | 5 | 6 | def get_world_size(): 7 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 8 | world_size = torch.distributed.get_world_size() 9 | else: 10 | world_size = 1 11 | return world_size 12 | 13 | 14 | def get_device(local_rank=None): 15 | backend = torch.distributed.get_backend() 16 | if backend == 'nccl': 17 | if local_rank is None: 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device(f'cuda:{local_rank}') 21 | elif backend == 'gloo': 22 | device = torch.device('cpu') 23 | else: 24 | raise RuntimeError 25 | return device 26 | 27 | 28 | def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): 29 | if not torch.distributed.is_available() or \ 30 | not torch.distributed.is_initialized(): 31 | return [item] 32 | 33 | device = get_device(local_rank) 34 | 35 | if group is not None: 36 | group_size = group.size() 37 | else: 38 | group_size = get_world_size() 39 | 40 | tensor = torch.tensor([item], device=device, dtype=dtype) 41 | output_tensors = [ 42 | torch.zeros(1, dtype=tensor.dtype, device=tensor.device) 43 | for _ in range(group_size) 44 | ] 45 | torch.distributed.all_gather(output_tensors, tensor, group, async_op) 46 | output = [elem.item() for elem in output_tensors] 47 | return output 48 | 49 | 50 | class DistributedSignalHandler: 51 | def __init__(self, sig=signal.SIGTERM): 52 | self.sig = sig 53 | 54 | def signals_received(self): 55 | all_received = all_gather_item( 56 | self._signal_received, dtype=torch.int32 57 | ) 58 | return all_received 59 | 60 | def __enter__(self): 61 | self._signal_received = False 62 | self.released = False 63 | self.original_handler = signal.getsignal(self.sig) 64 | 65 | def handler(signum, frame): 66 | self._signal_received = True 67 | 68 | signal.signal(self.sig, handler) 69 | 70 | return self 71 | 72 | def __exit__(self, type, value, tb): 73 | self.release() 74 | 75 | def release(self): 76 | if self.released: 77 | return False 78 | 79 | signal.signal(self.sig, self.original_handler) 80 | self.released = True 81 | return True 82 | 
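A minimal usage sketch for the DistributedSignalHandler above; train_step, save_checkpoint_and_exit, and max_iters are hypothetical placeholders, and torch.distributed is assumed to be initialized by the training script:

import signal

from megatron.dist_signal_handler import DistributedSignalHandler

with DistributedSignalHandler(signal.SIGTERM) as handler:
    for iteration in range(max_iters):
        train_step()
        # signals_received() all-gathers the local flag from every rank, so
        # all workers agree on whether any one of them received SIGTERM.
        if any(handler.signals_received()):
            save_checkpoint_and_exit()
            break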
-------------------------------------------------------------------------------- /training/megatron/fp16_deprecated/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """For backward compatibility, we need the class definitions to deserialize.""" 4 | 5 | class LossScaler: 6 | def __init__(self, scale=1): 7 | self.cur_scale = scale 8 | 9 | class DynamicLossScaler: 10 | def __init__(self, 11 | init_scale=2**32, 12 | scale_factor=2., 13 | scale_window=1000, 14 | min_scale=1, 15 | delayed_shift=1, 16 | consecutive_hysteresis=False): 17 | self.cur_scale = init_scale 18 | self.cur_iter = 0 19 | self.last_overflow_iter = -1 20 | self.scale_factor = scale_factor 21 | self.scale_window = scale_window 22 | self.min_scale = min_scale 23 | self.delayed_shift = delayed_shift 24 | self.cur_hysteresis = delayed_shift 25 | self.consecutive_hysteresis = consecutive_hysteresis 26 | 27 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import os 4 | import pathlib 5 | import subprocess 6 | 7 | from torch.utils import cpp_extension 8 | 9 | # Setting this param to a list has a problem of generating different 10 | # compilation commands (with diferent order of architectures) and 11 | # leading to recompilation of fused kernels. Set it to empty string 12 | # to avoid recompilation and assign arch flags explicity in 13 | # extra_cuda_cflags below 14 | os.environ["TORCH_CUDA_ARCH_LIST"] = "" 15 | 16 | 17 | def load(args): 18 | 19 | # Check if cuda 11 is installed for compute capability 8.0 20 | cc_flag = [] 21 | _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( 22 | cpp_extension.CUDA_HOME) 23 | if int(bare_metal_major) >= 11: 24 | cc_flag.append('-gencode') 25 | cc_flag.append('arch=compute_80,code=sm_80') 26 | # if int(bare_metal_minor) >= 7: 27 | # cc_flag.append('-gencode') 28 | # cc_flag.append('arch=compute_90,code=sm_90') 29 | 30 | # Build path 31 | srcpath = pathlib.Path(__file__).parent.absolute() 32 | buildpath = srcpath / 'build' 33 | _create_build_dir(buildpath) 34 | 35 | # Helper function to build the kernels. 36 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags): 37 | return cpp_extension.load( 38 | name=name, 39 | sources=sources, 40 | build_directory=buildpath, 41 | extra_cflags=['-O3',], 42 | extra_cuda_cflags=['-O3', 43 | '-gencode', 'arch=compute_70,code=sm_70', 44 | '--use_fast_math'] + extra_cuda_flags + cc_flag, 45 | verbose=(args.rank == 0) 46 | ) 47 | 48 | # ============== 49 | # Fused softmax. 50 | # ============== 51 | 52 | if args.masked_softmax_fusion: 53 | extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', 54 | '-U__CUDA_NO_HALF_CONVERSIONS__', 55 | '--expt-relaxed-constexpr', 56 | '--expt-extended-lambda'] 57 | 58 | # Upper triangular softmax. 59 | sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', 60 | srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] 61 | scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( 62 | "scaled_upper_triang_masked_softmax_cuda", 63 | sources, extra_cuda_flags) 64 | 65 | # Masked softmax. 
66 | sources=[srcpath / 'scaled_masked_softmax.cpp', 67 | srcpath / 'scaled_masked_softmax_cuda.cu'] 68 | scaled_masked_softmax_cuda = _cpp_extention_load_helper( 69 | "scaled_masked_softmax_cuda", sources, extra_cuda_flags) 70 | 71 | # Softmax 72 | sources=[srcpath / 'scaled_softmax.cpp', 73 | srcpath / 'scaled_softmax_cuda.cu'] 74 | scaled_softmax_cuda = _cpp_extention_load_helper( 75 | "scaled_softmax_cuda", sources, extra_cuda_flags) 76 | 77 | # ================================= 78 | # Mixed precision fused layer norm. 79 | # ================================= 80 | 81 | extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__', 82 | '-U__CUDA_NO_HALF_CONVERSIONS__'] 83 | 84 | extra_cuda_flags = ['-maxrregcount=50'] 85 | sources=[srcpath / 'layer_norm_cuda.cpp', 86 | srcpath / 'layer_norm_cuda_kernel.cu'] 87 | fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( 88 | "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags) 89 | 90 | # ================================= 91 | # Fused gradient accumulation to weight gradient computation of linear layer 92 | # ================================= 93 | 94 | if args.gradient_accumulation_fusion: 95 | sources=[srcpath / 'fused_weight_gradient_dense.cpp', 96 | srcpath / 'fused_weight_gradient_dense.cu'] 97 | fused_dense_cuda = _cpp_extention_load_helper( 98 | "fused_dense_cuda", sources, extra_hopper_flags) 99 | 100 | 101 | def _get_cuda_bare_metal_version(cuda_dir): 102 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 103 | universal_newlines=True) 104 | output = raw_output.split() 105 | release_idx = output.index("release") + 1 106 | release = output[release_idx].split(".") 107 | bare_metal_major = release[0] 108 | bare_metal_minor = release[1][0] 109 | 110 | return raw_output, bare_metal_major, bare_metal_minor 111 | 112 | 113 | def _create_build_dir(buildpath): 114 | try: 115 | os.mkdir(buildpath) 116 | except OSError: 117 | if not os.path.isdir(buildpath): 118 | print(f"Creation of the build directory {buildpath} failed") 119 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/compat.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | /*This code is copied fron NVIDIA apex: 4 | * https://github.com/NVIDIA/apex 5 | * with minor changes. 
*/ 6 | 7 | 8 | 9 | #ifndef TORCH_CHECK 10 | #define TORCH_CHECK AT_CHECK 11 | #endif 12 | 13 | #ifdef VERSION_GE_1_3 14 | #define DATA_PTR data_ptr 15 | #else 16 | #define DATA_PTR data 17 | #endif 18 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/fused_weight_gradient_dense.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "type_shim.h" 8 | 9 | 10 | template 11 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 12 | 13 | void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) { 14 | at::Tensor input_2d, d_output_2d; 15 | // input tensor: collapse to the first dim 16 | auto in_sizes = input.sizes(); 17 | if (input.dim() > 2) { 18 | input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); 19 | } else { 20 | input_2d = input; 21 | } 22 | // d_output tensor: collapse to the first dim 23 | auto d_out_sizes = d_output.sizes(); 24 | if (d_output.dim() > 2) { 25 | d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); 26 | } else { 27 | d_output_2d = d_output; 28 | } 29 | 30 | int hidden_dim = input_2d.size(0); 31 | int in_dim = input_2d.size(1); 32 | int out_dim = d_weight.size(0); 33 | 34 | DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32", 35 | int result = wgrad_gemm_accum_fp32_cuda( 36 | input_2d.data_ptr(), 37 | d_output_2d.data_ptr(), 38 | d_weight.data_ptr(), 39 | in_dim, 40 | hidden_dim, 41 | out_dim); 42 | ); 43 | } 44 | 45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 46 | m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32"); 47 | } 48 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/fused_weight_gradient_dense.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | /* Includes, cuda */ 10 | #include 11 | #include 12 | 13 | 14 | // BF16 Tensor core wrapper around cublas GEMMEx 15 | cublasStatus_t gemmex_wrapper( 16 | cublasHandle_t handle, 17 | cublasOperation_t transa, 18 | cublasOperation_t transb, 19 | int m, 20 | int n, 21 | int k, 22 | const float* alpha, 23 | at::BFloat16* A, 24 | int lda, 25 | at::BFloat16* B, 26 | int ldb, 27 | const float* beta, 28 | float* C, 29 | int ldc) { 30 | return cublasGemmEx( 31 | handle, 32 | transa, 33 | transb, 34 | m, 35 | n, 36 | k, 37 | alpha, 38 | A, 39 | CUDA_R_16BF, 40 | lda, 41 | B, 42 | CUDA_R_16BF, 43 | ldb, 44 | beta, 45 | C, 46 | CUDA_R_32F, 47 | ldc, 48 | CUDA_R_32F, 49 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 50 | } 51 | 52 | // FP16 Tensor core wrapper around cublas GEMMEx 53 | cublasStatus_t gemmex_wrapper( 54 | cublasHandle_t handle, 55 | cublasOperation_t transa, 56 | cublasOperation_t transb, 57 | int m, 58 | int n, 59 | int k, 60 | const float* alpha, 61 | at::Half* A, 62 | int lda, 63 | at::Half* B, 64 | int ldb, 65 | const float* beta, 66 | float* C, 67 | int ldc) { 68 | return cublasGemmEx( 69 | handle, 70 | transa, 71 | transb, 72 | m, 73 | n, 74 | k, 75 | alpha, 76 | A, 77 | CUDA_R_16F, 78 | lda, 79 | B, 80 | CUDA_R_16F, 81 | ldb, 82 | beta, 83 | C, 84 | CUDA_R_32F, 85 | ldc, 86 | CUDA_R_32F, 87 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 88 | } 89 | 90 | // FP32 Tensor core wrapper around 
cublas GEMMEx 91 | cublasStatus_t gemmex_wrapper( 92 | cublasHandle_t handle, 93 | cublasOperation_t transa, 94 | cublasOperation_t transb, 95 | int m, 96 | int n, 97 | int k, 98 | const float* alpha, 99 | float* A, 100 | int lda, 101 | float* B, 102 | int ldb, 103 | const float* beta, 104 | float* C, 105 | int ldc) { 106 | return cublasGemmEx( 107 | handle, 108 | transa, 109 | transb, 110 | m, 111 | n, 112 | k, 113 | alpha, 114 | A, 115 | CUDA_R_32F, 116 | lda, 117 | B, 118 | CUDA_R_32F, 119 | ldb, 120 | beta, 121 | C, 122 | CUDA_R_32F, 123 | ldc, 124 | CUDA_R_32F, 125 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 126 | } 127 | 128 | template 129 | int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { 130 | cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); 131 | cudaStream_t stream; 132 | cublasGetStream(handle, &stream); 133 | const float alpha = 1.0; 134 | const float beta = 1.0; 135 | int status = 1; 136 | 137 | status = gemmex_wrapper( 138 | handle, 139 | CUBLAS_OP_N, 140 | CUBLAS_OP_T, 141 | in_dim, 142 | out_dim, 143 | hidden_dim, 144 | &alpha, 145 | input, 146 | in_dim, 147 | d_output, 148 | out_dim, 149 | &beta, 150 | d_weight, 151 | in_dim); 152 | return status; 153 | } 154 | 155 | template int wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 156 | template int wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 157 | template int wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); 158 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/layer_norm_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | /*This code is copied fron NVIDIA apex: 4 | * https://github.com/NVIDIA/apex 5 | * with minor changes. 
*/ 6 | 7 | #include 8 | #include 9 | #include 10 | #include "compat.h" 11 | 12 | namespace { 13 | 14 | void compute_n1_n2( 15 | at::Tensor input, 16 | at::IntArrayRef normalized_shape, 17 | int& n1, 18 | int& n2) { 19 | int idiff = input.ndimension() - normalized_shape.size(); 20 | n2 = 1; 21 | for (int i = 0; i < (int)normalized_shape.size(); ++i) { 22 | assert( input.sizes()[i+idiff] == normalized_shape[i] ); 23 | n2 *= normalized_shape[i]; 24 | } 25 | n1 = 1; 26 | for (int i = 0; i < idiff; ++i) { 27 | n1 *= input.sizes()[i]; 28 | } 29 | } 30 | 31 | void check_args( 32 | at::IntArrayRef normalized_shape, 33 | at::Tensor gamma, 34 | at::Tensor beta 35 | ) 36 | { 37 | TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); 38 | TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); 39 | } 40 | 41 | void check_args( 42 | at::Tensor input, 43 | at::IntArrayRef normalized_shape, 44 | int& n1, 45 | int& n2 46 | ) 47 | { 48 | int64_t normalized_ndim = normalized_shape.size(); 49 | 50 | if (normalized_ndim < 1) { 51 | std::stringstream ss; 52 | ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " 53 | << "containing at least one element, but got normalized_shape=" 54 | << normalized_shape; 55 | throw std::runtime_error(ss.str()); 56 | } 57 | 58 | auto input_shape = input.sizes(); 59 | auto input_ndim = input.dim(); 60 | 61 | if (input_ndim < normalized_ndim || 62 | !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { 63 | std::stringstream ss; 64 | ss << "Given normalized_shape=" << normalized_shape 65 | << ", expected input with shape [*"; 66 | for (auto size : normalized_shape) { 67 | ss << ", " << size; 68 | } 69 | ss << "], but got input of size" << input_shape; 70 | throw std::runtime_error(ss.str()); 71 | } 72 | 73 | compute_n1_n2(input,normalized_shape,n1,n2); 74 | } 75 | 76 | 77 | void check_args( 78 | at::Tensor input, 79 | at::IntArrayRef normalized_shape, 80 | at::Tensor gamma, 81 | at::Tensor beta, 82 | int& n1, 83 | int& n2 84 | ) 85 | { 86 | check_args(input,normalized_shape,n1,n2); 87 | check_args(normalized_shape,gamma,beta); 88 | } 89 | } 90 | 91 | void cuda_layer_norm( 92 | at::Tensor* output, 93 | at::Tensor* mean, 94 | at::Tensor* invvar, 95 | at::Tensor* input, 96 | int n1, 97 | int n2, 98 | at::IntArrayRef normalized_shape, 99 | at::Tensor* gamma, 100 | at::Tensor* beta, 101 | double epsilon); 102 | 103 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 104 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 105 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 106 | 107 | std::vector layer_norm_affine( 108 | at::Tensor input, 109 | at::IntArrayRef normalized_shape, 110 | at::Tensor gamma, 111 | at::Tensor beta, 112 | double epsilon) { 113 | 114 | CHECK_INPUT(input); 115 | CHECK_INPUT(gamma); 116 | CHECK_INPUT(beta); 117 | int n1, n2; 118 | check_args(input, normalized_shape, gamma, beta, n1, n2); 119 | 120 | at::Tensor output = at::empty_like( 121 | input, gamma.options().dtype(gamma.scalar_type())); 122 | at::Tensor mean = at::empty( 123 | {n1}, input.options().dtype(at::ScalarType::Float)); 124 | at::Tensor invvar = at::empty_like(mean); 125 | 126 | cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, 127 | normalized_shape, &gamma, &beta, epsilon); 128 | 129 | return {output, mean, invvar}; 130 | 131 | } 132 | 133 | 134 | void cuda_layer_norm_gradient( 135 | at::Tensor* dout, 136 | at::Tensor* mean, 137 | at::Tensor* 
invvar, 138 | at::Tensor* input, 139 | int n1, 140 | int n2, 141 | at::IntArrayRef normalized_shape, 142 | at::Tensor* gamma, 143 | at::Tensor* beta, 144 | double epsilon, 145 | at::Tensor* grad_input, 146 | at::Tensor* grad_gamma, 147 | at::Tensor* grad_beta 148 | ); 149 | 150 | std::vector layer_norm_gradient_affine( 151 | at::Tensor dout, 152 | at::Tensor mean, 153 | at::Tensor invvar, 154 | at::Tensor input, 155 | at::IntArrayRef normalized_shape, 156 | at::Tensor gamma, 157 | at::Tensor beta, 158 | double epsilon) { 159 | 160 | CHECK_INPUT(dout); 161 | CHECK_INPUT(mean); 162 | CHECK_INPUT(invvar); 163 | CHECK_INPUT(input); 164 | CHECK_INPUT(gamma); 165 | CHECK_INPUT(beta); 166 | int n1, n2; 167 | check_args(input, normalized_shape, gamma, beta, n1, n2); 168 | 169 | at::Tensor grad_input = at::empty_like(input); 170 | at::Tensor grad_gamma = at::empty_like(gamma); 171 | at::Tensor grad_beta = at::empty_like(beta); 172 | 173 | cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, 174 | normalized_shape, &gamma, &beta, epsilon, 175 | &grad_input, &grad_gamma, &grad_beta); 176 | 177 | return {grad_input, grad_gamma, grad_beta}; 178 | 179 | } 180 | 181 | 182 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 183 | m.def("forward_affine", &layer_norm_affine, 184 | "LayerNorm forward (CUDA)"); 185 | m.def("backward_affine", &layer_norm_gradient_affine, 186 | "LayerNorm backward (CUDA)"); 187 | } 188 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | torch::Tensor const& mask, 14 | float scale_factor); 15 | 16 | torch::Tensor bwd_cuda( 17 | torch::Tensor const& output_grads, 18 | torch::Tensor const& softmax_results, 19 | float scale_factor); 20 | 21 | int get_batch_per_block_cuda( 22 | int query_seq_len, 23 | int key_seq_len, 24 | int batches, 25 | int attn_heads); 26 | 27 | torch::Tensor fwd( 28 | torch::Tensor const& input, 29 | torch::Tensor const& mask, 30 | float scale_factor) { 31 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 32 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 33 | (input.scalar_type() == at::ScalarType::BFloat16), 34 | "Only fp16 and bf16 are supported"); 35 | AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); 36 | 37 | return fwd_cuda(input, mask, scale_factor); 38 | } 39 | 40 | torch::Tensor bwd( 41 | torch::Tensor const& output_grads, 42 | torch::Tensor const& softmax_results, 43 | float scale_factor) { 44 | 45 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 46 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 47 | 48 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 49 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 50 | "Only fp16 and bf16 are supported"); 51 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 52 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 53 | "Only fp16 and bf16 are supported"); 54 | 55 | return bwd_cuda(output_grads, softmax_results, scale_factor); 56 | } 57 | 58 | int get_batch_per_block( 59 | int query_seq_len, 60 | int key_seq_len, 61 | int batches, 62 | int 
attn_heads) { 63 | return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); 64 | } 65 | 66 | } // end namespace scaled_masked_softmax 67 | } // end namespace fused_softmax 68 | } // end namespace multihead_attn 69 | 70 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 71 | m.def("forward", 72 | &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 73 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 74 | 75 | m.def("backward", 76 | &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, 77 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 78 | 79 | m.def("get_batch_per_block", 80 | &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, 81 | "Return Batch per block size." 82 | ); 83 | } 84 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_masked_softmax { 16 | 17 | int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ 18 | return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); 19 | } 20 | 21 | 22 | torch::Tensor fwd_cuda( 23 | torch::Tensor const& input, 24 | torch::Tensor const& mask, 25 | float scale_factor) 26 | { 27 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 28 | const int batches = input.size(0); 29 | const int pad_batches = mask.size(0); 30 | const int attn_heads = input.size(1); 31 | const int query_seq_len = input.size(2); 32 | const int key_seq_len = input.size(3); 33 | TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); 34 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 35 | TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); 36 | TORCH_INTERNAL_ASSERT(mask.size(1) == 1); 37 | TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); 38 | TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); 39 | 40 | // Output 41 | auto act_options = input.options().requires_grad(false); 42 | torch::Tensor softmax_results = 43 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 44 | 45 | // Softmax Intermediate Result Ptr 46 | void* input_ptr = static_cast(input.data_ptr()); 47 | void* mask_ptr = static_cast(mask.data_ptr()); 48 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 49 | 50 | DISPATCH_HALF_AND_BFLOAT( 51 | input.scalar_type(), 52 | "dispatch_scaled_masked_softmax_forward", 53 | dispatch_scaled_masked_softmax_forward( 54 | reinterpret_cast(softmax_results_ptr), 55 | reinterpret_cast(input_ptr), 56 | reinterpret_cast(mask_ptr), 57 | scale_factor, 58 | query_seq_len, 59 | key_seq_len, 60 | batches, 61 | attn_heads, 62 | pad_batches); 63 | ); 64 | return softmax_results; 65 | } 66 | 67 | torch::Tensor bwd_cuda( 68 | torch::Tensor const& output_grads_, 69 | torch::Tensor const& softmax_results_, 70 | float scale_factor) { 71 | 72 | auto output_grads = output_grads_.contiguous(); 73 | auto softmax_results = softmax_results_.contiguous(); 74 | 75 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 76 | const int batches = 
output_grads.size(0); 77 | const int attn_heads = output_grads.size(1); 78 | const int query_seq_len = output_grads.size(2); 79 | const int key_seq_len = output_grads.size(3); 80 | 81 | auto act_options = output_grads.options().requires_grad(false); 82 | torch::Tensor input_grads = 83 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 84 | 85 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 86 | void* input_grads_ptr = static_cast(input_grads.data_ptr()); 87 | 88 | //Softmax Grad 89 | DISPATCH_HALF_AND_BFLOAT( 90 | output_grads_.scalar_type(), 91 | "dispatch_scaled_masked_softmax_backward", 92 | dispatch_scaled_masked_softmax_backward( 93 | reinterpret_cast(input_grads_ptr), 94 | reinterpret_cast(output_grads_ptr), 95 | reinterpret_cast(softmax_results.data_ptr()), 96 | scale_factor, 97 | query_seq_len, 98 | key_seq_len, 99 | batches, 100 | attn_heads); 101 | ); 102 | 103 | return input_grads; 104 | } 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd( 21 | torch::Tensor const& input, 22 | float scale_factor) { 23 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 24 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 25 | (input.scalar_type() == at::ScalarType::BFloat16), 26 | "Only fp16 and bf16 are supported"); 27 | 28 | return fwd_cuda(input, scale_factor); 29 | } 30 | 31 | torch::Tensor bwd( 32 | torch::Tensor const& output_grads, 33 | torch::Tensor const& softmax_results, 34 | float scale_factor) { 35 | 36 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 37 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 38 | 39 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 40 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 41 | "Only fp16 and bf16 are supported"); 42 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 43 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 44 | "Only fp16 and bf16 are supported"); 45 | 46 | return bwd_cuda(output_grads, softmax_results, scale_factor); 47 | } 48 | 49 | } // end namespace scaled_softmax 50 | } // end namespace fused_softmax 51 | } // end namespace multihead_attn 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", 55 | &multihead_attn::fused_softmax::scaled_softmax::fwd, 56 | "Self Multihead Attention scaled, softmax -- Forward."); 57 | m.def("backward", 58 | &multihead_attn::fused_softmax::scaled_softmax::bwd, 59 | "Self Multihead Attention scaled, softmax -- Backward."); 60 | } 61 | 62 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 22 | const int batches = input.size(0); 23 | const int attn_heads = input.size(1); 24 | const int query_seq_len = input.size(2); 25 | const int key_seq_len = input.size(3); 26 | TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); 27 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 28 | 29 | // Output 30 | auto act_options = input.options().requires_grad(false); 31 | torch::Tensor softmax_results = 32 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 33 | 34 | // Softmax Intermediate Result Ptr 35 | void* input_ptr = static_cast(input.data_ptr()); 36 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 37 | 38 | DISPATCH_HALF_AND_BFLOAT( 39 | input.scalar_type(), 40 | "dispatch_scaled_softmax_forward", 41 | dispatch_scaled_softmax_forward( 42 | reinterpret_cast(softmax_results_ptr), 43 | reinterpret_cast(input_ptr), 44 | scale_factor, 45 | query_seq_len, 46 | key_seq_len, 47 | batches, 48 | attn_heads); 49 | ); 50 | return softmax_results; 51 | } 52 | 53 | torch::Tensor bwd_cuda( 54 | torch::Tensor const& output_grads_, 55 | torch::Tensor const& softmax_results_, 56 | float scale_factor) { 57 | 58 | auto output_grads = output_grads_.contiguous(); 59 | auto softmax_results = softmax_results_.contiguous(); 60 | 61 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 62 | const int batches = output_grads.size(0); 63 | const int attn_heads = output_grads.size(1); 64 | const int query_seq_len = output_grads.size(2); 65 | const int key_seq_len = output_grads.size(3); 66 | 67 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 68 | 69 | //Softmax Grad 70 | DISPATCH_HALF_AND_BFLOAT( 71 | output_grads_.scalar_type(), 72 | "dispatch_scaled_masked_softmax_backward", 73 | dispatch_scaled_masked_softmax_backward( 74 | reinterpret_cast(output_grads_ptr), 75 | reinterpret_cast(output_grads_ptr), 76 | reinterpret_cast(softmax_results.data_ptr()), 77 | scale_factor, 78 | query_seq_len, 79 | key_seq_len, 80 | batches, 81 | attn_heads); 82 | ); 83 | 84 | //backward pass is completely in-place 85 | return output_grads; 86 | } 87 | } 88 | } 89 | } 90 | 91 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_upper_triang_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 21 | AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); 22 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 23 | (input.scalar_type() == at::ScalarType::BFloat16), 24 | "Only fp16 and bf16 are supported"); 25 | 26 | return fwd_cuda(input, scale_factor); 27 | } 28 | 29 | torch::Tensor bwd( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor) { 33 | 34 | AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); 35 | AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); 36 | 37 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 38 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 39 | "Only fp16 and bf16 are supported"); 40 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 41 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 42 | "Only fp16 and bf16 are supported"); 43 | 44 | return bwd_cuda(output_grads, softmax_results, scale_factor); 45 | } 46 | 47 | } // end namespace scaled_upper_triang_masked_softmax 48 | } // end namespace fused_softmax 49 | } // end namespace multihead_attn 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", 53 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 54 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 55 | m.def("backward", 56 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 57 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 58 | } 59 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_upper_triang_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_upper_triang_masked_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 22 | const int attn_batches = input.size(0); 23 | const int seq_len = input.size(1); 24 | TORCH_INTERNAL_ASSERT(seq_len <= 2048); 25 | 26 | // Output 27 | auto act_options = input.options().requires_grad(false); 28 | torch::Tensor softmax_results = 29 | torch::empty({attn_batches, seq_len, seq_len}, act_options); 30 | 31 | // Softmax Intermediate Result Ptr 32 | void* input_ptr = static_cast(input.data_ptr()); 33 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 34 | 35 | DISPATCH_HALF_AND_BFLOAT( 36 | input.scalar_type(), 37 | "dispatch_scaled_upper_triang_masked_softmax_forward", 38 | dispatch_scaled_upper_triang_masked_softmax_forward( 39 | reinterpret_cast(softmax_results_ptr), 40 | reinterpret_cast(input_ptr), 41 | scale_factor, 42 | seq_len, 43 | seq_len, 44 | attn_batches); 45 | ); 46 | return softmax_results; 47 | } 48 | 49 | 50 | torch::Tensor bwd_cuda( 51 | torch::Tensor const& output_grads_, 52 | torch::Tensor const& softmax_results_, 53 | float scale_factor) { 54 | 55 | auto output_grads = output_grads_.contiguous(); 56 | auto softmax_results = softmax_results_.contiguous(); 57 | 58 | //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 59 | const int attn_batches = output_grads.size(0); 60 | const int seq_len = output_grads.size(1); 61 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); 62 | 63 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 64 | 65 | //Softmax Grad 66 | DISPATCH_HALF_AND_BFLOAT( 67 | output_grads_.scalar_type(), 68 | "dispatch_scaled_upper_triang_masked_softmax_backward", 69 | dispatch_scaled_upper_triang_masked_softmax_backward( 70 | reinterpret_cast(output_grads_ptr), 71 | reinterpret_cast(output_grads_ptr), 72 | reinterpret_cast(softmax_results.data_ptr()), 73 | scale_factor, 74 | seq_len, 75 | seq_len, 76 | attn_batches); 77 | ); 78 | 79 | //backward pass is completely in-place 80 | return output_grads; 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /training/megatron/fused_kernels/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/training/megatron/fused_kernels/tests/__init__.py -------------------------------------------------------------------------------- /training/megatron/fused_kernels/type_shim.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | 4 | #include 5 | #include "compat.h" 6 | 7 | 8 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ 9 | switch(TYPE) \ 10 | { \ 11 | case at::ScalarType::Half: \ 12 | { \ 13 | using scalar_t = at::Half; \ 14 | __VA_ARGS__; \ 15 | break; \ 16 | } \ 17 | case at::ScalarType::BFloat16: \ 18 | { \ 19 | using scalar_t = at::BFloat16; \ 20 | __VA_ARGS__; \ 21 | break; \ 22 | } \ 23 | default: \ 24 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 25 | } 26 | 27 | 28 | #define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ 29 | switch(TYPE) \ 30 | { \ 31 | case at::ScalarType::Half: \ 32 | { \ 33 | using scalar_t = at::Half; \ 34 | __VA_ARGS__; \ 35 | break; \ 36 | } \ 37 | case at::ScalarType::BFloat16: \ 38 | { \ 39 | using scalar_t = at::BFloat16; \ 40 | __VA_ARGS__; \ 41 | break; \ 42 | } \ 43 | case at::ScalarType::Float: \ 44 | { \ 45 | using scalar_t = float; \ 46 | __VA_ARGS__; \ 47 | break; \ 48 | } \ 49 | default: \ 50 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 51 | } 52 | 53 | 54 | 55 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ 56 | switch(TYPEIN) \ 57 | { \ 58 | case at::ScalarType::Float: \ 59 | { \ 60 | using scalar_t_in = float; \ 61 | switch(TYPEOUT) \ 62 | { \ 63 | case at::ScalarType::Float: \ 64 | { \ 65 | using scalar_t_out = float; \ 66 | __VA_ARGS__; \ 67 | break; \ 68 | } \ 69 | case at::ScalarType::Half: \ 70 | { \ 71 | using scalar_t_out = at::Half; \ 72 | __VA_ARGS__; \ 73 | break; \ 74 | } \ 75 | case at::ScalarType::BFloat16: \ 76 | { \ 77 | using scalar_t_out = at::BFloat16; \ 78 | __VA_ARGS__; \ 79 | break; \ 80 | } \ 81 | default: \ 82 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ 83 | } \ 84 | break; \ 85 | } \ 86 | case at::ScalarType::Half: \ 87 | { \ 88 | using scalar_t_in = at::Half; \ 89 | using scalar_t_out = at::Half; \ 90 | __VA_ARGS__; \ 91 | break; \ 92 | } \ 93 | case at::ScalarType::BFloat16: \ 94 | { \ 95 | using scalar_t_in = at::BFloat16; \ 96 | using scalar_t_out = at::BFloat16; \ 97 | __VA_ARGS__; \ 98 | break; \ 99 | } \ 100 | default: \ 101 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ 102 | } 103 | 104 | -------------------------------------------------------------------------------- /training/megatron/indexer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from megatron import get_args, print_rank_0 7 | from megatron.core import mpu 8 | from megatron.checkpointing import load_biencoder_checkpoint 9 | from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset 10 | from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch 11 | from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader 12 | from megatron.data.realm_index import detach, OpenRetreivalDataStore 13 | from megatron.model.biencoder_model import get_model_provider 14 | from megatron.training import get_model 15 | 16 | 17 | class IndexBuilder(object): 18 | """ 19 | Object for taking one pass over a dataset and creating a BlockData of its 20 | embeddings 21 | """ 22 | def __init__(self): 23 | args = get_args() 24 | self.model = None 25 | self.dataloader = None 26 | self.evidence_embedder_obj = None 27 | self.biencoder_shared_query_context_model = \ 28 | args.biencoder_shared_query_context_model 29 | 30 | # need to know whether we're using a REALM checkpoint (args.load) 31 | # or ICT checkpoint 32 | assert not (args.load and args.ict_load) 33 | 34 | self.log_interval = 
args.indexer_log_interval 35 | self.batch_size = args.indexer_batch_size 36 | 37 | self.load_attributes() 38 | self.is_main_builder = mpu.get_data_parallel_rank() == 0 39 | self.num_total_builders = mpu.get_data_parallel_world_size() 40 | self.iteration = self.total_processed = 0 41 | 42 | def load_attributes(self): 43 | """ 44 | Load the necessary attributes: model, dataloader and empty BlockData 45 | """ 46 | only_context_model = True 47 | if self.biencoder_shared_query_context_model: 48 | only_context_model = False 49 | 50 | model = get_model(get_model_provider(only_context_model=\ 51 | only_context_model, biencoder_shared_query_context_model=\ 52 | self.biencoder_shared_query_context_model)) 53 | 54 | self.model = load_biencoder_checkpoint(model, 55 | only_context_model=only_context_model) 56 | 57 | assert len(self.model) == 1 58 | self.model[0].eval() 59 | 60 | self.dataset = get_open_retrieval_wiki_dataset() 61 | self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ 62 | self.batch_size)) 63 | 64 | self.evidence_embedder_obj = OpenRetreivalDataStore( \ 65 | load_from_path=False) 66 | 67 | def track_and_report_progress(self, batch_size): 68 | """ 69 | Utility function for tracking progress 70 | """ 71 | self.iteration += 1 72 | self.total_processed += batch_size * self.num_total_builders 73 | if self.is_main_builder and self.iteration % self.log_interval == 0: 74 | print('Batch {:10d} | Total {:10d}'.format(self.iteration, 75 | self.total_processed), flush=True) 76 | 77 | def build_and_save_index(self): 78 | """ 79 | Goes through one epoch of the dataloader and adds all data to this 80 | instance's BlockData. 81 | 82 | The copy of BlockData is saved as a shard, which when run in a 83 | distributed setting will be consolidated by the rank 0 process 84 | and saved as a final pickled BlockData. 
85 | """ 86 | assert len(self.model) == 1 87 | unwrapped_model = self.model[0] 88 | 89 | while not hasattr(unwrapped_model, 'embed_text'): 90 | unwrapped_model = unwrapped_model.module 91 | 92 | while True: 93 | try: 94 | # batch also has query_tokens and query_pad_data 95 | row_id, context_tokens, context_mask, context_types, \ 96 | context_pad_mask = get_open_retrieval_batch( \ 97 | self.dataloader) 98 | except (StopIteration, IndexError): 99 | break 100 | 101 | # TODO: can we add with torch.no_grad() to reduce memory usage 102 | # detach, separate fields and add to BlockData 103 | assert context_mask.dtype == torch.bool 104 | context_logits = unwrapped_model.embed_text( 105 | unwrapped_model.context_model, context_tokens, context_mask, 106 | context_types) 107 | 108 | context_logits = detach(context_logits) 109 | row_id = detach(row_id) 110 | 111 | self.evidence_embedder_obj.add_block_data(row_id, context_logits) 112 | self.track_and_report_progress(batch_size=len(row_id)) 113 | 114 | # This process signals to finalize its shard and then synchronize with 115 | # the other processes 116 | self.evidence_embedder_obj.save_shard() 117 | torch.distributed.barrier() 118 | del self.model 119 | 120 | # rank 0 process builds the final copy 121 | if self.is_main_builder: 122 | self.evidence_embedder_obj.merge_shards_and_save() 123 | # make sure that every single piece of data was embedded 124 | assert len(self.evidence_embedder_obj.embed_data) == \ 125 | len(self.dataset) 126 | self.evidence_embedder_obj.clear() 127 | 128 | # complete building the final copy 129 | torch.distributed.barrier() 130 | -------------------------------------------------------------------------------- /training/megatron/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | import torch 5 | 6 | 7 | # A dictionary of all the memory buffers allocated. 8 | _MEM_BUFFS = dict() 9 | 10 | 11 | def allocate_mem_buff(name, numel, dtype, track_usage): 12 | """Allocate a memory buffer.""" 13 | assert name not in _MEM_BUFFS, \ 14 | 'memory buffer {} already allocated.'.format(name) 15 | _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) 16 | return _MEM_BUFFS[name] 17 | 18 | 19 | def get_mem_buff(name): 20 | """Get the memory buffer.""" 21 | return _MEM_BUFFS[name] 22 | 23 | 24 | class MemoryBuffer: 25 | """Contiguous memory buffer. 26 | Allocate a contiguous memory of type `dtype` and size `numel`. It is 27 | used to reduce memory fragmentation. 28 | 29 | Usage: After the allocation, the `_start` index is set tot the first 30 | index of the memory. A memory chunk starting from `_start` index 31 | can be `allocated` for an input tensor, with the elements of the 32 | tensor being coppied. The buffer can be reused by resetting the 33 | `_start` index. 34 | 35 | """ 36 | def __init__(self, name, numel, dtype, track_usage): 37 | if torch.distributed.get_rank() == 0: 38 | element_size = torch.tensor([], dtype=dtype).element_size() 39 | print('> building the {} memory buffer with {} num elements ' 40 | 'and {} dtype ({:.1f} MB)...'.format( 41 | name, numel, dtype, numel*element_size/1024/1024), 42 | flush=True) 43 | self.name = name 44 | self.numel = numel 45 | self.dtype = dtype 46 | self.data = torch.empty(self.numel, 47 | dtype=self.dtype, 48 | device=torch.cuda.current_device(), 49 | requires_grad=False) 50 | 51 | # Index tracking the start of the free memory. 
52 | self._start = 0 53 | 54 | # Values used for tracking usage. 55 | self.track_usage = track_usage 56 | if self.track_usage: 57 | self.in_use_value = 0.0 58 | self.total_value = 0.0 59 | 60 | 61 | def reset(self): 62 | """Reset the buffer start index to the beginning of the buffer.""" 63 | self._start = 0 64 | 65 | 66 | def is_in_use(self): 67 | """Whether the current buffer hold on to any memory.""" 68 | return self._start > 0 69 | 70 | 71 | def numel_in_use(self): 72 | """Return number of elements in use.""" 73 | return self._start 74 | 75 | 76 | def add(self, tensor): 77 | """Allocate a chunk of memory from the buffer to tensor and copy 78 | the values.""" 79 | assert tensor.dtype == self.dtype, \ 80 | 'Input tensor type {} different from buffer type {}'.format( 81 | tensor.dtype, self.dtype) 82 | # Number of elements of the input tensor. 83 | tensor_numel = torch.numel(tensor) 84 | new_start = self._start + tensor_numel 85 | assert new_start <= self.numel, \ 86 | 'Not enough memory left in the buffer ({} > {})'.format( 87 | tensor_numel, self.numel - self._start) 88 | # New tensor is a view into the memory. 89 | new_tensor = self.data[self._start:new_start] 90 | self._start = new_start 91 | new_tensor = new_tensor.view(tensor.shape) 92 | new_tensor.copy_(tensor) 93 | # Return a pointer to the new tensor. 94 | return new_tensor 95 | 96 | 97 | def get_data(self): 98 | """Return the data currently in use.""" 99 | if self.track_usage: 100 | self.in_use_value += float(self._start) 101 | self.total_value += float(self.numel) 102 | return self.data[:self._start] 103 | 104 | 105 | def print_average_usage(self): 106 | """Print memory usage average over time. We would like this value 107 | to be as high as possible.""" 108 | assert self.track_usage, 'You need to enable track usage.' 109 | if torch.distributed.get_rank() == 0: 110 | print(' > usage of {} memory buffer: {:.2f} %'.format( 111 | self.name, self.in_use_value * 100.0 / self.total_value), 112 | flush=True) 113 | 114 | 115 | 116 | class RingMemBuffer: 117 | """A ring of memory buffers.""" 118 | 119 | def __init__(self, name, num_buffers, numel, dtype, track_usage): 120 | self.num_buffers = num_buffers 121 | self.buffers = [ 122 | allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage) 123 | for i in range(num_buffers)] 124 | self._index = -1 125 | 126 | 127 | def get_next_buffer(self): 128 | self._index += 1 129 | self._index = self._index % self.num_buffers 130 | buff = self.buffers[self._index] 131 | assert not buff.is_in_use(), 'buffer is already in use.' 132 | return buff 133 | -------------------------------------------------------------------------------- /training/megatron/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm 4 | 5 | from .distributed import DistributedDataParallel 6 | from .bert_model import BertModel 7 | from .gpt_model import GPTModel 8 | from .t5_model import T5Model 9 | from .language_model import get_language_model 10 | from .module import Float16Module 11 | -------------------------------------------------------------------------------- /training/megatron/model/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
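A short sketch of the MemoryBuffer usage pattern described in memory.py above; the buffer name, sizes, and dtype are illustrative, and a CUDA device plus an initialized torch.distributed process group are assumed (the constructor logs its size from rank 0):

import torch

from megatron.memory import allocate_mem_buff

# One contiguous fp16 buffer allocated up front; chunks are carved out of it
# instead of calling torch.empty for every activation tensor.
buff = allocate_mem_buff('activations', numel=1 << 20, dtype=torch.half, track_usage=True)
chunk = buff.add(torch.randn(4, 1024, dtype=torch.half, device='cuda'))  # copied into the buffer
# ... use `chunk` like any other tensor during the iteration ...
buff.reset()  # reclaim the whole buffer for the next iteration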
2 | 3 | """Classification model.""" 4 | 5 | import torch 6 | 7 | from megatron import get_args, print_rank_last 8 | from megatron.model.enums import AttnMaskType 9 | from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids 10 | from megatron.model.language_model import get_language_model 11 | from megatron.model.utils import get_linear_layer 12 | from megatron.model.utils import init_method_normal 13 | from megatron.model.utils import scaled_init_method_normal 14 | from .module import MegatronModule 15 | 16 | 17 | class Classification(MegatronModule): 18 | 19 | def __init__(self, 20 | num_classes, 21 | num_tokentypes=2, 22 | pre_process=True, 23 | post_process=True): 24 | super(Classification, self).__init__(share_word_embeddings=False) 25 | args = get_args() 26 | 27 | self.num_classes = num_classes 28 | self.pre_process = pre_process 29 | self.post_process = post_process 30 | init_method = init_method_normal(args.init_method_std) 31 | 32 | self.language_model, self._language_model_key = get_language_model( 33 | num_tokentypes=num_tokentypes, 34 | add_pooler=True, 35 | encoder_attn_mask_type=AttnMaskType.padding, 36 | init_method=init_method, 37 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 38 | args.num_layers), 39 | pre_process=self.pre_process, 40 | post_process=self.post_process) 41 | 42 | # Multi-choice head. 43 | if self.post_process: 44 | self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) 45 | self.classification_head = get_linear_layer(args.hidden_size, 46 | self.num_classes, 47 | init_method) 48 | self._classification_head_key = 'classification_head' 49 | 50 | def set_input_tensor(self, input_tensor): 51 | """See megatron.model.transformer.set_input_tensor()""" 52 | self.language_model.set_input_tensor(input_tensor) 53 | 54 | def forward(self, model_input, attention_mask, tokentype_ids=None): 55 | 56 | extended_attention_mask = bert_extended_attention_mask(attention_mask) 57 | input_ids = model_input 58 | position_ids = bert_position_ids(input_ids) 59 | 60 | lm_output = self.language_model( 61 | input_ids, 62 | position_ids, 63 | extended_attention_mask, 64 | tokentype_ids=tokentype_ids 65 | ) 66 | 67 | if self.post_process: 68 | _, pooled_output = lm_output 69 | classification_output = self.classification_dropout(pooled_output) 70 | classification_logits = self.classification_head(classification_output) 71 | 72 | # Reshape back to separate choices. 
73 | classification_logits = classification_logits.view(-1, self.num_classes) 74 | 75 | return classification_logits 76 | return lm_output 77 | 78 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 79 | """For easy load when model is combined with other heads, 80 | add an extra key.""" 81 | 82 | state_dict_ = {} 83 | state_dict_[self._language_model_key] \ 84 | = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, 85 | keep_vars=keep_vars) 86 | if self.post_process: 87 | state_dict_[self._classification_head_key] \ 88 | = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars) 89 | return state_dict_ 90 | 91 | def load_state_dict(self, state_dict, strict=True): 92 | """Customized load.""" 93 | 94 | self.language_model.load_state_dict( 95 | state_dict[self._language_model_key], strict=strict) 96 | if self.post_process: 97 | if self._classification_head_key in state_dict: 98 | self.classification_head.load_state_dict( 99 | state_dict[self._classification_head_key], strict=strict) 100 | else: 101 | print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 102 | 'initializing to random'.format( 103 | self._classification_head_key)) 104 | -------------------------------------------------------------------------------- /training/megatron/model/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import enum 4 | 5 | class LayerType(enum.Enum): 6 | encoder = 1 7 | decoder = 2 8 | 9 | class AttnType(enum.Enum): 10 | self_attn = 1 11 | cross_attn = 2 12 | 13 | class AttnMaskType(enum.Enum): 14 | padding = 1 15 | causal = 2 16 | -------------------------------------------------------------------------------- /training/megatron/model/fused_bias_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | 6 | ###### BIAS GELU FUSION/ NO AUTOGRAD ################ 7 | # 1/sqrt(2*pi)-> 0.3989423 8 | # 1/sqrt(2) -> 0.70710678 9 | # sqrt(2/pi) -> 0.79788456 10 | # this function is tanh approximation of gelu 11 | # actual gelu is: 12 | # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) 13 | 14 | @torch.jit.script 15 | def bias_gelu(bias, y): 16 | x = bias + y 17 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 18 | 19 | # gradient of tanh approximation of gelu 20 | # gradient of actual gelu is: 21 | # 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 22 | @torch.jit.script 23 | def bias_gelu_back(g, bias, y): 24 | x = bias + y 25 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 26 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 27 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 28 | return ff*g 29 | 30 | class GeLUFunction(torch.autograd.Function): 31 | @staticmethod 32 | # bias is an optional argument 33 | def forward(ctx, input, bias): 34 | ctx.save_for_backward(input, bias) 35 | return bias_gelu(bias, input) 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | input, bias = ctx.saved_tensors 40 | tmp = bias_gelu_back(grad_output, bias, input) 41 | return tmp, tmp 42 | 43 | bias_gelu_impl = GeLUFunction.apply 44 | -------------------------------------------------------------------------------- /training/megatron/model/fused_layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """This code is copied fron NVIDIA apex: 4 | https://github.com/NVIDIA/apex 5 | with some changes. """ 6 | 7 | import numbers 8 | import torch 9 | from torch.nn.parameter import Parameter 10 | from torch.nn import init 11 | import importlib 12 | 13 | from megatron.core.utils import make_viewless_tensor 14 | 15 | try: 16 | from apex.contrib.layer_norm.layer_norm import FastLayerNormFN 17 | HAVE_PERSIST_LAYER_NORM = True 18 | except: 19 | HAVE_PERSIST_LAYER_NORM = False 20 | 21 | global fused_mix_prec_layer_norm_cuda 22 | fused_mix_prec_layer_norm_cuda = None 23 | 24 | 25 | class FusedLayerNormAffineFunction(torch.autograd.Function): 26 | 27 | @staticmethod 28 | def forward(ctx, input, weight, bias, normalized_shape, eps): 29 | 30 | ctx.normalized_shape = normalized_shape 31 | ctx.eps = eps 32 | input_ = input.contiguous() 33 | weight_ = weight.contiguous() 34 | bias_ = bias.contiguous() 35 | output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( 36 | input_, ctx.normalized_shape, weight_, bias_, ctx.eps) 37 | ctx.save_for_backward(input_, weight_, bias_, mean, invvar) 38 | 39 | return output 40 | 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output): 44 | 45 | input_, weight_, bias_, mean, invvar = ctx.saved_tensors 46 | grad_input = grad_weight = grad_bias = None 47 | grad_input, grad_weight, grad_bias \ 48 | = fused_mix_prec_layer_norm_cuda.backward_affine( 49 | grad_output.contiguous(), mean, invvar, 50 | input_, ctx.normalized_shape, 51 | weight_, bias_, ctx.eps) 52 | 53 | return grad_input, grad_weight, grad_bias, None, None 54 | 55 | 56 | 57 | class MixedFusedLayerNorm(torch.nn.Module): 58 | 59 | def __init__(self, normalized_shape, eps=1e-5, 60 | no_persist_layer_norm=True, 61 | sequence_parallel=False): 62 | super(MixedFusedLayerNorm, self).__init__() 63 | 64 | global fused_mix_prec_layer_norm_cuda 65 | fused_mix_prec_layer_norm_cuda = importlib.import_module( 66 | "fused_mix_prec_layer_norm_cuda") 67 | 68 | # List of hiddens sizes supported in the persistent layer norm kernel 69 | # If the hidden size is not supported, fall back to the non-persistent 70 | # kernel. 
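# The persistent (apex FastLayerNormFN) path is only taken when the import above
# succeeded and the hidden size appears in the list below; in every other case
# no_persist_layer_norm is forced to True and forward() falls back to
# FusedLayerNormAffineFunction / fused_mix_prec_layer_norm_cuda defined above.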
71 | persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, 72 | 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 73 | 24576, 25600, 30720, 32768, 40960, 49152, 65536] 74 | if normalized_shape not in persist_ln_hidden_sizes or \ 75 | not HAVE_PERSIST_LAYER_NORM: 76 | no_persist_layer_norm = True 77 | 78 | if isinstance(normalized_shape, numbers.Integral): 79 | normalized_shape = (normalized_shape,) 80 | self.normalized_shape = torch.Size(normalized_shape) 81 | self.eps = eps 82 | self.weight = Parameter(torch.Tensor(*normalized_shape)) 83 | self.bias = Parameter(torch.Tensor(*normalized_shape)) 84 | self.reset_parameters() 85 | self.no_persist_layer_norm = no_persist_layer_norm 86 | self.sequence_parallel = sequence_parallel 87 | 88 | # set sequence parallelism flag on weight and bias parameters 89 | setattr(self.weight, 'sequence_parallel', self.sequence_parallel) 90 | setattr(self.bias, 'sequence_parallel', self.sequence_parallel) 91 | 92 | 93 | def reset_parameters(self): 94 | 95 | init.ones_(self.weight) 96 | init.zeros_(self.bias) 97 | 98 | 99 | def forward(self, input): 100 | 101 | if self.no_persist_layer_norm: 102 | return FusedLayerNormAffineFunction.apply( 103 | input, self.weight, self.bias, self.normalized_shape, self.eps) 104 | else: 105 | output = FastLayerNormFN.apply( 106 | input, self.weight, self.bias, self.eps) 107 | 108 | # Apex's fast layer norm function outputs a 'view' tensor (i.e., has 109 | # a populated '_base' field). This will result in schedule.py's 110 | # deallocate_output_tensor() throwing an error, so a viewless tensor is 111 | # created to prevent this. 112 | output = make_viewless_tensor(inp = output, 113 | requires_grad = input.requires_grad, 114 | keep_graph = True) 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /training/megatron/model/gpt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """GPT-2 model.""" 4 | 5 | import torch 6 | 7 | from megatron import get_args 8 | from megatron.core import tensor_parallel 9 | from .module import MegatronModule 10 | 11 | from .enums import AttnMaskType 12 | from .language_model import parallel_lm_logits 13 | from .language_model import get_language_model 14 | from .utils import init_method_normal 15 | from .utils import scaled_init_method_normal 16 | 17 | 18 | def post_language_model_processing(lm_output, labels, logit_weights, 19 | parallel_output, 20 | fp16_lm_cross_entropy): 21 | 22 | # Output. 
Format [s b h] 23 | output = parallel_lm_logits( 24 | lm_output, 25 | logit_weights, 26 | parallel_output) 27 | 28 | if labels is None: 29 | # [s b h] => [b s h] 30 | return output.transpose(0,1).contiguous() 31 | else: 32 | # [b s] => [s b] 33 | labels = labels.transpose(0,1).contiguous() 34 | if fp16_lm_cross_entropy: 35 | assert output.dtype == torch.half 36 | loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) 37 | else: 38 | loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) 39 | 40 | # [s b] => [b, s] 41 | loss = loss.transpose(0,1).contiguous() 42 | return loss 43 | 44 | 45 | class GPTModel(MegatronModule): 46 | """GPT-2 Language model.""" 47 | 48 | def __init__(self, 49 | num_tokentypes=0, 50 | parallel_output=True, 51 | pre_process=True, 52 | post_process=True): 53 | super(GPTModel, self).__init__() 54 | args = get_args() 55 | 56 | self.parallel_output = parallel_output 57 | self.pre_process = pre_process 58 | self.post_process = post_process 59 | self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy 60 | 61 | self.language_model, self._language_model_key = get_language_model( 62 | num_tokentypes=num_tokentypes, 63 | add_pooler=False, 64 | encoder_attn_mask_type=AttnMaskType.causal, 65 | init_method=init_method_normal(args.init_method_std), 66 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 67 | args.num_layers), 68 | pre_process=self.pre_process, 69 | post_process=self.post_process) 70 | 71 | self.initialize_word_embeddings(init_method_normal) 72 | 73 | def set_input_tensor(self, input_tensor): 74 | """See megatron.model.transformer.set_input_tensor()""" 75 | self.language_model.set_input_tensor(input_tensor) 76 | 77 | def forward(self, input_ids, position_ids, attention_mask, 78 | ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None, 79 | labels=None, tokentype_ids=None, inference_params=None): 80 | 81 | lm_output = self.language_model( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | ret_input_ids=ret_input_ids, 86 | ret_position_ids=ret_position_ids, 87 | ret_attn_mask=ret_attn_mask, 88 | inference_params=inference_params) 89 | 90 | if self.post_process: 91 | return post_language_model_processing( 92 | lm_output, labels, 93 | self.word_embeddings_weight(), 94 | self.parallel_output, 95 | self.fp16_lm_cross_entropy) 96 | else: 97 | return lm_output 98 | 99 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 100 | 101 | state_dict_ = {} 102 | state_dict_[self._language_model_key] \ 103 | = self.language_model.state_dict_for_save_checkpoint( 104 | prefix=prefix, keep_vars=keep_vars) 105 | # Save word_embeddings. 106 | if self.post_process and not self.pre_process: 107 | state_dict_[self._word_embeddings_for_head_key] \ 108 | = self.word_embeddings.state_dict(prefix=prefix, 109 | keep_vars=keep_vars) 110 | return state_dict_ 111 | 112 | def load_state_dict(self, state_dict, strict=True): 113 | """Customized load.""" 114 | 115 | # Load word_embeddings. 
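# With pipeline parallelism the last stage (post_process=True, pre_process=False)
# does not own the input embedding table; initialize_word_embeddings() gave it a
# separate, weight-tied copy that word_embeddings_weight() feeds to the output
# logits, and that copy is what is saved and restored under
# _word_embeddings_for_head_key here.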
116 | if self.post_process and not self.pre_process: 117 | self.word_embeddings.load_state_dict( 118 | state_dict[self._word_embeddings_for_head_key], strict=strict) 119 | if self._language_model_key in state_dict: 120 | state_dict = state_dict[self._language_model_key] 121 | self.language_model.load_state_dict(state_dict, strict=strict) 122 | -------------------------------------------------------------------------------- /training/megatron/model/multiple_choice.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Multiple choice model.""" 4 | 5 | import torch 6 | 7 | from megatron import get_args, print_rank_last 8 | from megatron.model.enums import AttnMaskType 9 | from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids 10 | from megatron.model.language_model import get_language_model 11 | from megatron.model.utils import get_linear_layer 12 | from megatron.model.utils import init_method_normal 13 | from megatron.model.utils import scaled_init_method_normal 14 | from .module import MegatronModule 15 | 16 | 17 | class MultipleChoice(MegatronModule): 18 | 19 | def __init__(self, 20 | num_tokentypes=2, 21 | pre_process=True, 22 | post_process=True): 23 | super(MultipleChoice, self).__init__(share_word_embeddings=False) 24 | args = get_args() 25 | 26 | init_method = init_method_normal(args.init_method_std) 27 | self.pre_process = pre_process 28 | self.post_process = post_process 29 | 30 | self.language_model, self._language_model_key = get_language_model( 31 | num_tokentypes=num_tokentypes, 32 | add_pooler=True, 33 | encoder_attn_mask_type=AttnMaskType.padding, 34 | init_method=init_method, 35 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 36 | args.num_layers), 37 | pre_process=self.pre_process, 38 | post_process=self.post_process) 39 | 40 | # Multi-choice head. 41 | if self.post_process: 42 | self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) 43 | self.multichoice_head = get_linear_layer(args.hidden_size, 1, 44 | init_method) 45 | self._multichoice_head_key = 'multichoice_head' 46 | 47 | def set_input_tensor(self, input_tensor): 48 | """See megatron.model.transformer.set_input_tensor()""" 49 | self.language_model.set_input_tensor(input_tensor) 50 | 51 | def forward(self, model_input, attention_mask, tokentype_ids=None): 52 | 53 | # [batch, choices, sequence] --> [batch * choices, sequence] --> 54 | # transformer --> [batch, choices] --> softmax 55 | 56 | # Ensure the shape is [batch-size, choices, sequence] 57 | assert len(attention_mask.shape) == 3 58 | num_choices = attention_mask.shape[1] 59 | 60 | # Reshape and treat choice dimension the same as batch. 
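# For example, with micro-batch 8, 4 answer choices and sequence length 512:
# attention_mask / input_ids / tokentype_ids go from [8, 4, 512] to [32, 512]
# below, and the per-choice logits are reshaped back to [8, 4] further down.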
61 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 62 | extended_attention_mask = bert_extended_attention_mask(attention_mask) 63 | 64 | input_ids = model_input 65 | # Do the same as attention_mask for input_ids, tokentype_ids 66 | assert len(input_ids.shape) == 3 67 | assert len(tokentype_ids.shape) == 3 68 | input_ids = input_ids.view(-1, input_ids.size(-1)) 69 | tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) 70 | position_ids = bert_position_ids(input_ids) 71 | 72 | lm_output = self.language_model( 73 | input_ids, 74 | position_ids, 75 | extended_attention_mask, 76 | tokentype_ids=tokentype_ids 77 | ) 78 | if self.post_process: 79 | _, pooled_output = lm_output 80 | multichoice_output = self.multichoice_dropout(pooled_output) 81 | multichoice_logits = self.multichoice_head(multichoice_output) 82 | 83 | # Reshape back to separate choices. 84 | multichoice_logits = multichoice_logits.view(-1, num_choices) 85 | 86 | return multichoice_logits 87 | return lm_output 88 | 89 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 90 | """For easy load when model is combined with other heads, 91 | add an extra key.""" 92 | 93 | state_dict_ = {} 94 | state_dict_[self._language_model_key] \ 95 | = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, 96 | keep_vars=keep_vars) 97 | if self.post_process: 98 | state_dict_[self._multichoice_head_key] \ 99 | = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars) 100 | return state_dict_ 101 | 102 | def load_state_dict(self, state_dict, strict=True): 103 | """Customized load.""" 104 | 105 | self.language_model.load_state_dict( 106 | state_dict[self._language_model_key], strict=strict) 107 | if self.post_process: 108 | if self._multichoice_head_key in state_dict: 109 | self.multichoice_head.load_state_dict( 110 | state_dict[self._multichoice_head_key], strict=strict) 111 | else: 112 | print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 113 | 'initializing to random'.format( 114 | self._multichoice_head_key)) 115 | -------------------------------------------------------------------------------- /training/megatron/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Utilities for models.""" 4 | 5 | import math 6 | 7 | import torch 8 | 9 | from megatron import get_args 10 | 11 | def init_method_normal(sigma): 12 | """Init method based on N(0, sigma).""" 13 | def init_(tensor): 14 | return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) 15 | 16 | return init_ 17 | 18 | 19 | def scaled_init_method_normal(sigma, num_layers): 20 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 21 | std = sigma / math.sqrt(2.0 * num_layers) 22 | 23 | def init_(tensor): 24 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 25 | 26 | return init_ 27 | 28 | 29 | def attention_mask_func(attention_scores, attention_mask): 30 | attention_scores.masked_fill_(attention_mask, -10000.0) 31 | return attention_scores 32 | 33 | 34 | def get_linear_layer(rows, columns, init_method): 35 | """Simple linear layer with weight initialization.""" 36 | layer = torch.nn.Linear(rows, columns) 37 | if get_args().perform_initialization: 38 | init_method(layer.weight) 39 | with torch.no_grad(): 40 | layer.bias.zero_() 41 | return layer 42 | 43 | @torch.jit.script 44 | def gelu_impl(x): 45 | """OpenAI's gelu implementation.""" 46 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 47 | (1.0 + 0.044715 * x * x))) 48 | def openai_gelu(x): 49 | return gelu_impl(x) 50 | 51 | #This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter 52 | @torch.jit.script 53 | def erf_gelu(x): 54 | return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) 55 | -------------------------------------------------------------------------------- /training/megatron/model/vision/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Vision Transformer(VIT) model.""" 4 | 5 | import torch 6 | from torch.nn.init import trunc_normal_ 7 | from megatron import get_args 8 | from megatron.model.utils import get_linear_layer 9 | from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead 10 | from megatron.model.vision.mit_backbone import mit_b3_avg 11 | from megatron.model.module import MegatronModule 12 | 13 | class VitClassificationModel(MegatronModule): 14 | """Vision Transformer Model.""" 15 | 16 | def __init__(self, num_classes, finetune=False, 17 | pre_process=True, post_process=True): 18 | super(VitClassificationModel, self).__init__() 19 | args = get_args() 20 | 21 | self.hidden_size = args.hidden_size 22 | self.num_classes = num_classes 23 | self.finetune = finetune 24 | self.pre_process = pre_process 25 | self.post_process = post_process 26 | self.backbone = VitBackbone( 27 | pre_process=self.pre_process, 28 | post_process=self.post_process, 29 | single_token_output=True 30 | ) 31 | 32 | if self.post_process: 33 | if not self.finetune: 34 | self.head = VitMlpHead(self.hidden_size, self.num_classes) 35 | else: 36 | self.head = get_linear_layer( 37 | self.hidden_size, 38 | self.num_classes, 39 | torch.nn.init.zeros_ 40 | ) 41 | 42 | def set_input_tensor(self, input_tensor): 43 | """See megatron.model.transformer.set_input_tensor()""" 44 | self.backbone.set_input_tensor(input_tensor) 45 | 46 | def forward(self, input): 47 | hidden_states = self.backbone(input) 48 | 49 | if self.post_process: 50 | hidden_states = self.head(hidden_states) 51 | 52 | return hidden_states 53 | 54 | 55 | class MitClassificationModel(MegatronModule): 56 | """Mix vision Transformer Model.""" 57 | 58 | def __init__(self, num_classes, 59 | pre_process=True, post_process=True): 60 | super(MitClassificationModel, self).__init__() 61 | args = get_args() 62 | 63 | self.hidden_size = args.hidden_size 64 | self.num_classes = num_classes 65 | 66 | self.backbone = mit_b3_avg() 67 | self.head = torch.nn.Linear(512, num_classes) 68 | self.apply(self._init_weights) 69 | 70 | def _init_weights(self, m): 71 | if isinstance(m, torch.nn.Linear): 72 | trunc_normal_(m.weight, std=.02) 73 | if isinstance(m, torch.nn.Linear) and m.bias is not None: 74 | torch.nn.init.constant_(m.bias, 0) 75 | 76 | def set_input_tensor(self, input_tensor): 77 | """See megatron.model.transformer.set_input_tensor()""" 78 | pass 79 | 80 | def forward(self, input): 81 | hidden_states = self.backbone(input) 82 | hidden_states = self.head(hidden_states) 83 | 84 | return hidden_states 85 | -------------------------------------------------------------------------------- /training/megatron/model/vision/inpainting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | import apex 8 | import einops 9 | import torch 10 | import torch.nn.functional as F 11 | from megatron import get_args, print_rank_0 12 | from megatron.model.utils import get_linear_layer 13 | from megatron.model.vision.vit_backbone import VitBackbone 14 | from megatron.model.module import MegatronModule 15 | from megatron.model.vision.mit_backbone import mit_b3 16 | from megatron.model.vision.utils import resize 17 | 18 | 19 | class VitInpaintingModel(MegatronModule): 20 | 21 | def __init__(self, pre_process=True, post_process=True): 22 | super(VitInpaintingModel, self).__init__() 23 | args = get_args() 24 | 25 | self.pre_process = pre_process 26 | self.post_process = post_process 27 | self.hidden_size = args.hidden_size 28 | self.backbone = VitBackbone( 29 | pre_process=self.pre_process, 30 | post_process=self.post_process, 31 | class_token=False, 32 | ) 33 | self.patch_dim = args.patch_dim 34 | self.img_h = args.img_h 35 | self.img_w = args.img_w 36 | self.seq_length = args.seq_length 37 | # full mask 38 | 39 | if self.post_process: 40 | self.linear_decoder = get_linear_layer( 41 | self.hidden_size, 42 | self.backbone.flatten_dim, 43 | torch.nn.init.zeros_ 44 | ) 45 | 46 | def set_input_tensor(self, input_tensor): 47 | self.backbone.set_input_tensor(input_tensor) 48 | 49 | def forward(self, input): 50 | 51 | hidden_states = self.backbone(input) 52 | 53 | if not self.post_process: 54 | return hidden_states 55 | decoded_output = self.linear_decoder(hidden_states) 56 | output = einops.rearrange( 57 | decoded_output, 58 | "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", 59 | p1=self.patch_dim, 60 | p2=self.patch_dim, 61 | h=self.img_h//self.patch_dim, 62 | w=self.img_w//self.patch_dim, 63 | ) 64 | 65 | return output 66 | 67 | 68 | class MLP(torch.nn.Module): 69 | """ 70 | Linear Embedding 71 | """ 72 | def __init__(self, input_dim=2048, embed_dim=768): 73 | super().__init__() 74 | self.proj = torch.nn.Linear(input_dim, embed_dim) 75 | 76 | def forward(self, x): 77 | x = x.flatten(2).transpose(1, 2) 78 | x = self.proj(x) 79 | return x 80 | 81 | 82 | class MitInpaintingModel(MegatronModule): 83 | """Mix vision Transformer Model.""" 84 | 85 | def __init__(self, pre_process=True, post_process=True): 86 | super(MitInpaintingModel, self).__init__() 87 | self.pre_process = pre_process 88 | self.post_process = post_process 89 | 90 | args = get_args() 91 | self.patch_dim = args.patch_dim 92 | self.img_h = args.img_h 93 | self.img_w = args.img_w 94 | self.flatten_dim = self.patch_dim * self.patch_dim * 3 95 | self.backbone = mit_b3() 96 | 97 | self.in_channels = [64, 128, 320, 512] 98 | self.embedding_dim = 768 99 | 100 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 101 | 102 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) 103 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) 104 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) 105 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) 106 | 107 | self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False) 108 | self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim) 109 | self.dropout = torch.nn.Dropout2d(0.1) 110 | 111 | self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1) 112 | 113 | def set_input_tensor(self, input_tensor): 114 | """See megatron.model.transformer.set_input_tensor()""" 115 | pass 116 | 117 |
def forward(self, input): 118 | c1, c2, c3, c4 = self.backbone(input) 119 | 120 | n, _, h, w = c4.shape 121 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 122 | _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False) 123 | 124 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 125 | _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False) 126 | 127 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 128 | _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False) 129 | 130 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 131 | 132 | _c = torch.cat([_c4, _c3, _c2, _c1], dim=1) 133 | _c = self.conv_fuse(_c) 134 | 135 | x = self.norm(_c) 136 | x = F.relu(x, inplace=True) 137 | x = self.dropout(x) 138 | 139 | x = self.linear_pred(x) 140 | 141 | output = einops.rearrange( 142 | x, 143 | "b (c p1 p2) h w -> b c (h p1) (w p2)", 144 | p1=self.patch_dim, 145 | p2=self.patch_dim, 146 | h=self.img_h//self.patch_dim, 147 | w=self.img_w//self.patch_dim, 148 | ) 149 | 150 | return output 151 | -------------------------------------------------------------------------------- /training/megatron/model/vision/knn_monitor.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from megatron import print_rank_0, get_args 4 | from megatron.core import mpu 5 | from megatron.data.vit_dataset import ClassificationTransform 6 | from megatron.data.image_folder import ImageFolder 7 | 8 | _FEATURE_BANK = None 9 | 10 | 11 | def build_data_loader(dataset, drop_last=True, shuffle=False): 12 | """Data loader. Note that batch-size is the local (per GPU) batch-size.""" 13 | # Sampler. 14 | args = get_args() 15 | micro_batch_size = 16 16 | num_workers = args.num_workers 17 | world_size = mpu.get_data_parallel_world_size() 18 | rank = mpu.get_data_parallel_rank() 19 | sampler = torch.utils.data.distributed.DistributedSampler( 20 | dataset, num_replicas=world_size, rank=rank, 21 | drop_last=drop_last, shuffle=shuffle 22 | ) 23 | 24 | # Data loader. Note that batch size is the per GPU batch size. 
25 | data_loader = torch.utils.data.DataLoader( 26 | dataset, 27 | batch_size=micro_batch_size, 28 | sampler=sampler, 29 | shuffle=False, 30 | num_workers=num_workers, 31 | drop_last=not drop_last, 32 | pin_memory=True, 33 | ) 34 | return data_loader 35 | 36 | 37 | def compute_feature_bank(model): 38 | args = get_args() 39 | global _FEATURE_BANK 40 | feature_bank = [] 41 | feature_label = [] 42 | 43 | train_ds = ImageFolder( 44 | root=args.data_path[0], 45 | transform=ClassificationTransform((args.img_h, args.img_w), train=False), 46 | data_per_class_fraction=1.0 47 | ) 48 | classes = len(train_ds.classes) 49 | dataloader = build_data_loader(train_ds) 50 | 51 | for m in model: 52 | m.eval() 53 | 54 | with torch.no_grad(): 55 | for i, batch in enumerate(dataloader): 56 | images = batch[0].cuda().contiguous() 57 | labels = batch[1].cuda().contiguous() 58 | student_feature, teacher_feature = model[0](images) 59 | feature = F.normalize(teacher_feature.float(), dim=1) 60 | feature_bank.append(feature) 61 | feature_label.append(labels) 62 | 63 | for m in model: 64 | m.train() 65 | 66 | # [N', D] 67 | feature_bank = torch.cat(feature_bank, dim=0).contiguous() 68 | feature_label = torch.cat(feature_label, dim=0).contiguous() 69 | 70 | feature_banks = [torch.zeros_like(feature_bank) 71 | for i in range(mpu.get_data_parallel_world_size())] 72 | torch.distributed.all_gather(feature_banks, 73 | feature_bank, 74 | group=mpu.get_data_parallel_group()) 75 | 76 | assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()], 77 | feature_bank)) 78 | 79 | feature_labels = [torch.zeros_like(feature_label) 80 | for i in range(mpu.get_data_parallel_world_size())] 81 | torch.distributed.all_gather(feature_labels, 82 | feature_label, 83 | group=mpu.get_data_parallel_group()) 84 | 85 | # [D, N] 86 | feature_banks = torch.cat(feature_banks, dim=0).t().contiguous() 87 | # [N] 88 | feature_labels = torch.cat(feature_labels, dim=0).contiguous() 89 | print_rank_0("feature_banks size is {}".format(feature_banks.size())) 90 | print_rank_0("feature labels size is {}".format(feature_labels.size())) 91 | 92 | _FEATURE_BANK = (feature_banks, feature_labels, classes) 93 | 94 | 95 | def get_feature_bank(): 96 | global _FEATURE_BANK 97 | assert _FEATURE_BANK is not None 98 | return _FEATURE_BANK 99 | 100 | 101 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 102 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and 103 | # https://github.com/leftthomas/SimCLR 104 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 105 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 106 | sim_matrix = torch.mm(feature, feature_bank) 107 | # [B, K] 108 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 109 | # [B, K] 110 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), 111 | dim=-1, 112 | index=sim_indices) 113 | sim_weight = (sim_weight / knn_t).exp() 114 | 115 | # counts for each class 116 | one_hot_label = torch.zeros(feature.size(0) * knn_k, 117 | classes, 118 | device=sim_labels.device) 119 | # [B*K, C] 120 | one_hot_label = one_hot_label.scatter(dim=-1, 121 | index=sim_labels.view(-1, 1), 122 | value=1.0) 123 | # weighted score ---> [B, C] 124 | pred_scores = torch.sum( 125 | one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), 126 | dim=1) 127 | 128 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 129 | return pred_labels 130 | 
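An illustrative sketch, not part of the repository, of how knn_predict above behaves on its own (it assumes the function is already in scope, since importing the whole module also pulls in megatron). Shapes follow the comments in the function: [B, D] queries, [D, N] feature bank, [N] labels; the knn_k and knn_t values are arbitrary placeholders.

import torch
import torch.nn.functional as F

B, D, N, C = 4, 128, 1000, 10                    # queries, feature dim, bank size, classes
feature = F.normalize(torch.randn(B, D), dim=1)  # [B, D] L2-normalized query features
feature_bank = F.normalize(torch.randn(N, D), dim=1).t().contiguous()  # [D, N] stored bank
feature_labels = torch.randint(0, C, (N,))       # [N] class label of each bank entry

pred_labels = knn_predict(feature, feature_bank, feature_labels,
                          classes=C, knn_k=20, knn_t=0.07)
print(pred_labels[:, 0])                         # top-1 predicted class per query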
-------------------------------------------------------------------------------- /training/megatron/model/vision/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def resize(input, 7 | size=None, 8 | scale_factor=None, 9 | mode='nearest', 10 | align_corners=None, 11 | warning=True): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > input_w: 17 | if ((output_h > 1 and output_w > 1 and input_h > 1 18 | and input_w > 1) and (output_h - 1) % (input_h - 1) 19 | and (output_w - 1) % (input_w - 1)): 20 | warnings.warn( 21 | f'When align_corners={align_corners}, ' 22 | 'the output would be more aligned if ' 23 | f'input size {(input_h, input_w)} is `x+1` and ' 24 | f'out size {(output_h, output_w)} is `nx+1`') 25 | if isinstance(size, torch.Size): 26 | size = tuple(int(x) for x in size) 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | -------------------------------------------------------------------------------- /training/megatron/mpu/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLG/OpenBA/ef4716b7e588f17096043eef773557be41f2d7ed/training/megatron/mpu/tests/__init__.py -------------------------------------------------------------------------------- /training/megatron/mpu/tests/commons.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import argparse 4 | import os 5 | import random 6 | import numpy 7 | import torch 8 | 9 | import mpu 10 | 11 | 12 | class IdentityLayer(torch.nn.Module): 13 | def __init__(self, size, scale=1.0): 14 | super(IdentityLayer, self).__init__() 15 | self.weight = torch.nn.Parameter(scale * torch.randn(size)) 16 | 17 | def forward(self): 18 | return self.weight 19 | 20 | 21 | def set_random_seed(seed): 22 | """Set random seed for reproducibility.""" 23 | random.seed(seed) 24 | numpy.random.seed(seed) 25 | torch.manual_seed(seed) 26 | mpu.model_parallel_cuda_manual_seed(seed) 27 | 28 | 29 | def initialize_distributed(backend='nccl'): 30 | """Initialize torch.distributed.""" 31 | # Get local rank in case it is provided. 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--local_rank', type=int, default=None, 34 | help='local rank passed from distributed launcher') 35 | args = parser.parse_args() 36 | local_rank = args.local_rank 37 | 38 | # Get rank and world size. 39 | rank = int(os.getenv('RANK', '0')) 40 | world_size = int(os.getenv("WORLD_SIZE", '1')) 41 | 42 | print('> initializing torch.distributed with local rank: {}, ' 43 | 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) 44 | 45 | # Set the device id. 46 | device = rank % torch.cuda.device_count() 47 | if local_rank is not None: 48 | device = local_rank 49 | torch.cuda.set_device(device) 50 | 51 | # Call the init process.
52 | init_method = 'tcp://' 53 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 54 | master_port = os.getenv('MASTER_PORT', '6000') 55 | init_method += master_ip + ':' + master_port 56 | torch.distributed.init_process_group( 57 | backend=backend, 58 | world_size=world_size, 59 | rank=rank, 60 | init_method=init_method) 61 | 62 | 63 | def print_separator(message): 64 | torch.distributed.barrier() 65 | filler_len = (78 - len(message)) // 2 66 | filler = '-' * filler_len 67 | string = '\n' + filler + ' {} '.format(message) + filler 68 | if torch.distributed.get_rank() == 0: 69 | print(string, flush=True) 70 | torch.distributed.barrier() 71 | -------------------------------------------------------------------------------- /training/megatron/mpu/tests/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import set_random_seed 4 | from commons import IdentityLayer 5 | from commons import print_separator 6 | from commons import initialize_distributed 7 | from mpu.cross_entropy import vocab_parallel_cross_entropy 8 | import mpu 9 | import torch.nn.functional as F 10 | import torch 11 | import random 12 | import sys 13 | sys.path.append("../..") 14 | 15 | 16 | def torch_cross_entropy(batch_size, seq_length, vocab_size, 17 | logits_scale, seed): 18 | set_random_seed(seed) 19 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 20 | scale=logits_scale).cuda() 21 | logits = identity() 22 | target = torch.cuda.LongTensor( 23 | size=(batch_size, seq_length)).random_(0, vocab_size) 24 | loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), 25 | target.view(-1), 26 | reduction='none').view_as(target).mean() 27 | loss.backward() 28 | return loss, identity.weight.grad 29 | 30 | 31 | def mpu_cross_entropy(batch_size, seq_length, vocab_size, 32 | logits_scale, seed): 33 | set_random_seed(seed) 34 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 35 | scale=logits_scale).cuda() 36 | logits = identity() 37 | logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) 38 | target = torch.cuda.LongTensor( 39 | size=(batch_size, seq_length)).random_(0, vocab_size) 40 | loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() 41 | loss.backward() 42 | return loss, identity.weight.grad 43 | 44 | 45 | def test_cross_entropy(tensor_model_parallel_size): 46 | 47 | if torch.distributed.get_rank() == 0: 48 | print('> testing cross entropy with model parallel size {} ...'. 
49 | format(tensor_model_parallel_size)) 50 | 51 | mpu.initialize_model_parallel(tensor_model_parallel_size) 52 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 53 | 54 | batch_size = 13 55 | seq_length = 17 56 | vocab_size_per_partition = 11 57 | logits_scale = 1000.0 58 | vocab_size = vocab_size_per_partition * tensor_model_parallel_size 59 | seed = 1234 60 | 61 | loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, 62 | vocab_size, logits_scale, 63 | seed) 64 | loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, 65 | vocab_size, logits_scale, 66 | seed) 67 | 68 | error = loss_torch.sub_(loss_mpu).abs().max() 69 | print(' max error in loss on global rank {}: {}'.format( 70 | torch.distributed.get_rank(), error)) 71 | assert error < 1.0e-6 72 | 73 | error = grad_torch.sub_(grad_mpu).abs().max() 74 | print(' max error in grad on global rank {}: {}'.format( 75 | torch.distributed.get_rank(), error)) 76 | assert error < 1.0e-6 77 | 78 | # Reset groups 79 | mpu.destroy_tensor_model_parallel() 80 | 81 | torch.distributed.barrier() 82 | if torch.distributed.get_rank() == 0: 83 | print('>> passed the test :-)') 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | initialize_distributed() 89 | world_size = torch.distributed.get_world_size() 90 | 91 | tensor_model_parallel_size = 1 92 | while tensor_model_parallel_size <= world_size: 93 | print_separator('test cross entropy') 94 | test_cross_entropy(tensor_model_parallel_size) 95 | tensor_model_parallel_size *= 2 96 | -------------------------------------------------------------------------------- /training/megatron/mpu/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | from mpu import data as data_utils 6 | import mpu 7 | import torch 8 | import functools 9 | import operator 10 | import sys 11 | sys.path.append("../..") 12 | 13 | 14 | def test_broadcast_data(tensor_model_parallel_size): 15 | 16 | if torch.distributed.get_rank() == 0: 17 | print('> testing broadcast_data with model parallel size {} ...'. 
18 | format(tensor_model_parallel_size)) 19 | 20 | mpu.initialize_model_parallel(tensor_model_parallel_size) 21 | torch.manual_seed(1234 + mpu.get_data_parallel_rank()) 22 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 23 | 24 | key_size_t = {'key1': [7, 11], 25 | 'key2': [8, 2, 1], 26 | 'key3': [13], 27 | 'key4': [5, 1, 2], 28 | 'key5': [5, 12]} 29 | keys = list(key_size_t.keys()) 30 | 31 | data = {} 32 | data_t = {} 33 | for key in key_size_t: 34 | data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) 35 | data_t[key] = data[key].clone() 36 | data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) 37 | data_t['keyX'] = data['keyX'].clone() 38 | if mpu.get_tensor_model_parallel_rank() != 0: 39 | data = None 40 | 41 | data_utils._check_data_types(keys, data_t, torch.int64) 42 | key_size, key_numel, \ 43 | total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) 44 | for key in keys: 45 | assert key_size[key] == key_size_t[key] 46 | total_numel_t = 0 47 | for key in keys: 48 | target_size = functools.reduce(operator.mul, key_size_t[key], 1) 49 | assert key_numel[key] == target_size 50 | total_numel_t += target_size 51 | assert total_numel == total_numel_t 52 | 53 | data_b = data_utils.broadcast_data(keys, data, torch.int64) 54 | for key in keys: 55 | tensor = data_t[key].cuda() 56 | assert data_b[key].sub(tensor).abs().max() == 0 57 | 58 | # Reset groups 59 | mpu.destroy_tensor_model_parallel() 60 | 61 | torch.distributed.barrier() 62 | if torch.distributed.get_rank() == 0: 63 | print('>> passed the test :-)') 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | initialize_distributed() 69 | world_size = torch.distributed.get_world_size() 70 | 71 | tensor_model_parallel_size = 1 72 | while tensor_model_parallel_size <= world_size: 73 | print_separator('test test broadcast data') 74 | test_broadcast_data(tensor_model_parallel_size) 75 | tensor_model_parallel_size *= 2 76 | -------------------------------------------------------------------------------- /training/megatron/mpu/tests/test_initialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | import mpu 6 | import torch 7 | import sys 8 | sys.path.append("../..") 9 | 10 | 11 | def test_initialize_model_parallel(tensor_model_parallel_size): 12 | 13 | if torch.distributed.get_rank() == 0: 14 | print('> testing initialize_model_parallel with size {} ...'.format( 15 | tensor_model_parallel_size)) 16 | tensor_model_parallel_size_ = min(tensor_model_parallel_size, 17 | torch.distributed.get_world_size()) 18 | assert not mpu.model_parallel_is_initialized() 19 | mpu.initialize_model_parallel(tensor_model_parallel_size_) 20 | assert mpu.model_parallel_is_initialized() 21 | 22 | # Checks. 23 | def check(group, world_size, rank): 24 | assert world_size == torch.distributed.get_world_size(group=group) 25 | assert rank == torch.distributed.get_rank(group=group) 26 | 27 | # Model parallel. 28 | world_size = tensor_model_parallel_size_ 29 | rank = torch.distributed.get_rank() % tensor_model_parallel_size_ 30 | assert world_size == mpu.get_tensor_model_parallel_world_size() 31 | assert rank == mpu.get_tensor_model_parallel_rank() 32 | check(mpu.get_tensor_model_parallel_group(), world_size, rank) 33 | 34 | # Data parallel. 
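# Concrete example: with 8 GPUs and tensor_model_parallel_size_=2 there are 4
# data-parallel replicas; global rank 5 has tensor-parallel rank 5 % 2 = 1 and
# data-parallel rank 5 // 2 = 2, which is what the checks below assert.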
35 | world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ 36 | rank = torch.distributed.get_rank() // tensor_model_parallel_size 37 | assert world_size == mpu.get_data_parallel_world_size() 38 | assert rank == mpu.get_data_parallel_rank() 39 | check(mpu.get_data_parallel_group(), world_size, rank) 40 | 41 | # Reset groups 42 | mpu.destroy_model_parallel() 43 | 44 | torch.distributed.barrier() 45 | if torch.distributed.get_rank() == 0: 46 | print('>> passed the test :-)') 47 | 48 | 49 | def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): 50 | 51 | if torch.distributed.get_rank() == 0: 52 | print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( 53 | tensor_model_parallel_size_)) 54 | tensor_model_parallel_size = min(tensor_model_parallel_size_, 55 | torch.distributed.get_world_size()) 56 | assert not mpu.model_parallel_is_initialized() 57 | mpu.initialize_model_parallel(tensor_model_parallel_size) 58 | assert mpu.model_parallel_is_initialized() 59 | 60 | # Checks 61 | src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() 62 | assert mpu.get_tensor_model_parallel_src_rank() == src_rank 63 | 64 | # Reset groups 65 | mpu.destroy_model_parallel() 66 | 67 | torch.distributed.barrier() 68 | if torch.distributed.get_rank() == 0: 69 | print('>> passed the test :-)') 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | initialize_distributed() 75 | world_size = torch.distributed.get_world_size() 76 | tensor_model_parallel_size = 1 77 | while tensor_model_parallel_size <= world_size: 78 | print_separator('test initialize model parallel') 79 | test_initialize_model_parallel(tensor_model_parallel_size) 80 | print_separator('test model parallel source rank') 81 | test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) 82 | tensor_model_parallel_size *= 2 83 | -------------------------------------------------------------------------------- /training/megatron/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from apex.optimizers import FusedAdam as Adam 4 | from apex.optimizers import FusedSGD as SGD 5 | 6 | from megatron import get_args 7 | 8 | from .distrib_optimizer import DistributedOptimizer 9 | from .grad_scaler import ConstantGradScaler, DynamicGradScaler 10 | from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer 11 | 12 | 13 | def get_param_groups(modules, 14 | no_weight_decay_cond, 15 | scale_lr_cond, 16 | lr_mult): 17 | """creates param groups based on weight decay condition (regularized vs non regularized) 18 | and learning rate scale condition (args.lr vs lr_mult * args.lr) 19 | scale_lr_cond is used during finetuning where head of the network requires a scaled 20 | version of the base learning rate. 
21 | """ 22 | wd_no_scale_lr = [] 23 | wd_scale_lr = [] 24 | no_wd_no_scale_lr = [] 25 | no_wd_scale_lr = [] 26 | for module in modules: 27 | for name, param in module.named_parameters(): 28 | if not param.requires_grad: 29 | continue 30 | 31 | if no_weight_decay_cond is not None: 32 | no_wd = no_weight_decay_cond(name, param) 33 | else: 34 | # do not regularize biases nor Norm parameters 35 | no_wd = name.endswith(".bias") or len(param.shape) == 1 36 | 37 | if scale_lr_cond is not None: 38 | scale_lr = scale_lr_cond(name, param) 39 | else: 40 | scale_lr = False 41 | 42 | if not no_wd and not scale_lr: 43 | wd_no_scale_lr.append(param) 44 | elif not no_wd and scale_lr: 45 | wd_scale_lr.append(param) 46 | elif no_wd and not scale_lr: 47 | no_wd_no_scale_lr.append(param) 48 | else: 49 | no_wd_scale_lr.append(param) 50 | 51 | param_groups = [] 52 | if len(wd_no_scale_lr): 53 | param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0}) 54 | if len(wd_scale_lr): 55 | param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult}) 56 | if len(no_wd_no_scale_lr): 57 | param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0}) 58 | if len(no_wd_scale_lr): 59 | param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult}) 60 | 61 | return param_groups 62 | 63 | def get_megatron_optimizer(model, 64 | no_weight_decay_cond=None, 65 | scale_lr_cond=None, 66 | lr_mult=1.0): 67 | args = get_args() 68 | 69 | # Base optimizer. 70 | param_groups = get_param_groups(model, 71 | no_weight_decay_cond, 72 | scale_lr_cond, 73 | lr_mult) 74 | 75 | if args.optimizer == 'adam': 76 | optimizer = Adam(param_groups, 77 | lr=args.lr, 78 | weight_decay=args.weight_decay, 79 | betas=(args.adam_beta1, args.adam_beta2), 80 | eps=args.adam_eps) 81 | elif args.optimizer == 'sgd': 82 | optimizer = SGD(param_groups, 83 | lr=args.lr, 84 | weight_decay=args.weight_decay, 85 | momentum=args.sgd_momentum) 86 | else: 87 | raise Exception('{} optimizer is not supported.'.format( 88 | args.optimizer)) 89 | 90 | # Determine whether the params have main-grad field. 91 | params_have_main_grad = False 92 | if args.DDP_impl == 'local': 93 | params_have_main_grad = True 94 | 95 | # Mixed precision optimizer. 96 | # - Note: both the Float16Optimizer and the DistributedOptimizer inherit 97 | # from the MixedPrecisionOptimizer, which manages any optimizer where 98 | # the model params and main params are distinct. 99 | if args.fp16 or args.bf16 or args.use_distributed_optimizer: 100 | 101 | # Grad scaler: 102 | # if loss-scale is provided, instantiate the constant scaler. 103 | # if we are using fp16 and loss-scale is not present, use a 104 | # dynamic scaler. 105 | # otherwise we are running in bf16 with no loss-scale so 106 | # leave it as None. 107 | grad_scaler = None 108 | 109 | # Constant loss scale. 110 | if args.loss_scale: 111 | grad_scaler = ConstantGradScaler(args.loss_scale) 112 | 113 | # Dynamic loss scale. 114 | else: 115 | if args.fp16: 116 | grad_scaler = DynamicGradScaler( 117 | initial_scale=args.initial_loss_scale, 118 | min_scale=args.min_loss_scale, 119 | growth_factor=2.0, 120 | backoff_factor=0.5, 121 | growth_interval=args.loss_scale_window, 122 | hysteresis=args.hysteresis) 123 | 124 | # Megatron optimizer. 
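# args.use_distributed_optimizer selects DistributedOptimizer, which shards the
# optimizer state and fp32 main-parameter copies across the data-parallel group;
# otherwise Float16OptimizerWithFloat16Params keeps a full fp32 main copy on every
# rank. Both wrap the Adam/SGD instance built above.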
125 | opt_ty = DistributedOptimizer \ 126 | if args.use_distributed_optimizer else \ 127 | Float16OptimizerWithFloat16Params 128 | return opt_ty(optimizer, 129 | args.clip_grad, 130 | args.log_num_zeros_in_grad, 131 | params_have_main_grad, 132 | args.use_contiguous_buffers_in_local_ddp, 133 | args.fp16, 134 | args.bf16, 135 | args.params_dtype, 136 | grad_scaler, 137 | model) 138 | 139 | # FP32. 140 | return FP32Optimizer(optimizer, args.clip_grad, 141 | args.log_num_zeros_in_grad, 142 | params_have_main_grad, 143 | args.use_contiguous_buffers_in_local_ddp, 144 | model) 145 | -------------------------------------------------------------------------------- /training/megatron/optimizer/clip_grads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Gradient clipping.""" 4 | 5 | import torch 6 | try: 7 | from torch._six import inf 8 | except: 9 | from torch import inf 10 | 11 | from apex.multi_tensor_apply import multi_tensor_applier 12 | import amp_C 13 | 14 | from megatron.model.module import param_is_not_shared 15 | from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate 16 | 17 | 18 | def clip_grad_norm_fp32(parameters, grads_for_norm, 19 | max_norm, norm_type=2, 20 | model_parallel_group=None): 21 | """Clips gradient norm of an iterable of parameters whose gradients 22 | are in fp32. 23 | 24 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and 25 | added functionality to handle model parallel parameters. Note that 26 | the gradients are modified in place. 27 | 28 | Arguments: 29 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 30 | single Tensor that will have gradients normalized 31 | grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single 32 | Tensor that will be used for calculating the grad norm. 33 | max_norm (float or int): max norm of the gradients 34 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 35 | infinity norm. 36 | model_parallel_group (group): given the nature of the distributed 37 | optimizer, this is passed as an argument. 38 | 39 | Returns: 40 | Total norm of the parameters (viewed as a single vector). 41 | """ 42 | 43 | if isinstance(parameters, torch.Tensor): 44 | parameters = [parameters] 45 | if isinstance(grads_for_norm, torch.Tensor): 46 | grads_for_norm = [grads_for_norm] 47 | 48 | # Grads. 49 | grads = [] 50 | for param in parameters: 51 | if param.grad is not None: 52 | assert param.grad.type() == 'torch.cuda.FloatTensor' 53 | grads.append(param.grad.detach()) 54 | 55 | # Norm parameters. 56 | max_norm = float(max_norm) 57 | norm_type = float(norm_type) 58 | total_norm = 0.0 59 | 60 | # Calculate norm. 61 | if norm_type == inf: 62 | total_norm = max(grad.abs().max() for grad in grads_for_norm) 63 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 64 | # Take max across all model-parallel GPUs. 65 | torch.distributed.all_reduce(total_norm_cuda, 66 | op=torch.distributed.ReduceOp.MAX, 67 | group=model_parallel_group) 68 | total_norm = total_norm_cuda[0].item() 69 | 70 | else: 71 | if norm_type == 2.0: 72 | dummy_overflow_buf = torch.cuda.IntTensor([0]) 73 | # Use apex's multi-tensor applier for efficiency reasons. 74 | # Multi-tensor applier takes a function and a list of list 75 | # and performs the operation on that list all in one kernel. 
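# multi_tensor_l2norm computes the global L2 norm over all listed gradients in
# fused chunks and returns (total_norm, per-tensor norms); the trailing False
# argument below disables the per-tensor norms, which are unused here.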
76 | if grads_for_norm: 77 | grad_norm, _ = multi_tensor_applier( 78 | amp_C.multi_tensor_l2norm, 79 | dummy_overflow_buf, 80 | [grads_for_norm], 81 | False # no per-parameter norm 82 | ) 83 | else: 84 | grad_norm = torch.cuda.FloatTensor([0]) 85 | # Since we will be summing across data parallel groups, 86 | # we need the pow(norm-type). 87 | total_norm = grad_norm ** norm_type 88 | 89 | else: 90 | for grad in grads_for_norm: 91 | grad_norm = torch.norm(grad, norm_type) 92 | total_norm += grad_norm ** norm_type 93 | 94 | # Sum across all model-parallel GPUs. 95 | torch.distributed.all_reduce(total_norm, 96 | op=torch.distributed.ReduceOp.SUM, 97 | group=model_parallel_group) 98 | total_norm = total_norm.item() ** (1.0 / norm_type) 99 | 100 | # Scale. 101 | clip_coeff = max_norm / (total_norm + 1.0e-6) 102 | if clip_coeff < 1.0: 103 | dummy_overflow_buf = torch.cuda.IntTensor([0]) 104 | multi_tensor_applier(amp_C.multi_tensor_scale, 105 | dummy_overflow_buf, 106 | [grads, grads], 107 | clip_coeff) 108 | 109 | return total_norm 110 | 111 | 112 | def count_zeros_fp32(parameters, model_parallel_group): 113 | 114 | if isinstance(parameters, torch.Tensor): 115 | parameters = [parameters] 116 | 117 | # Filter parameters based on: 118 | # - grad should not be none 119 | # - parameter should not be shared 120 | # - should not be a replica due to tensor model parallelism 121 | total_num_zeros = torch.cuda.FloatTensor([0.0]) 122 | for param in parameters: 123 | grad_not_none = param.grad is not None 124 | is_not_shared = param_is_not_shared(param) 125 | is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) 126 | if grad_not_none and is_not_shared and is_not_tp_duplicate: 127 | grad = param.grad.detach() 128 | num_zeros = grad.numel() - torch.count_nonzero(grad) 129 | total_num_zeros = num_zeros + total_num_zeros 130 | 131 | # Sum across all model-parallel GPUs. 132 | torch.distributed.all_reduce(total_num_zeros, 133 | op=torch.distributed.ReduceOp.SUM, 134 | group=model_parallel_group) 135 | 136 | total_num_zeros = total_num_zeros.item() 137 | 138 | return total_num_zeros 139 | -------------------------------------------------------------------------------- /training/megatron/optimizer/grad_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Megatron grad scaler.""" 4 | 5 | from abc import ABC 6 | from abc import abstractmethod 7 | 8 | import torch 9 | 10 | 11 | class MegatronGradScaler(ABC): 12 | 13 | def __init__(self, initial_scale): 14 | """Initialize scale value with the input initial scale.""" 15 | assert initial_scale > 0.0 16 | self._scale = torch.cuda.FloatTensor([initial_scale]) 17 | 18 | @property 19 | def scale(self): 20 | return self._scale 21 | 22 | @property 23 | def inv_scale(self): 24 | return self._scale.double().reciprocal().float() 25 | 26 | @abstractmethod 27 | def update(self, found_inf): 28 | pass 29 | 30 | @abstractmethod 31 | def state_dict(self): 32 | pass 33 | 34 | @abstractmethod 35 | def load_state_dict(self, state_dict): 36 | pass 37 | 38 | 39 | 40 | class ConstantGradScaler(MegatronGradScaler): 41 | 42 | def update(self, found_inf): 43 | pass 44 | 45 | def state_dict(self): 46 | return dict() 47 | 48 | def load_state_dict(self, state_dict): 49 | pass 50 | 51 | 52 | 53 | class DynamicGradScaler(MegatronGradScaler): 54 | 55 | def __init__(self, initial_scale, min_scale, 56 | growth_factor, backoff_factor, 57 | growth_interval, hysteresis): 58 | """"Grad scaler with dynamic scale that gets adjusted 59 | during training.""" 60 | super(DynamicGradScaler, self).__init__(initial_scale) 61 | 62 | # Lower bound on the scale. 63 | assert min_scale > 0.0 64 | assert min_scale <= initial_scale 65 | self.min_scale = torch.cuda.FloatTensor([min_scale]) 66 | # Growth and backoff factors for the scale. 67 | assert growth_factor > 1.0 68 | self.growth_factor = torch.cuda.FloatTensor([growth_factor]) 69 | assert backoff_factor < 1.0 70 | assert backoff_factor > 0.0 71 | self.backoff_factor = torch.cuda.FloatTensor([backoff_factor]) 72 | # Interval over which if we don't see any inf/nan, 73 | # we will scale the grad scale by the growth factor. 74 | assert growth_interval > 0 75 | self.growth_interval = growth_interval 76 | # Number of inf/nans we should see before scaling down 77 | # the grad scale by the backoff factor. 78 | assert hysteresis > 0 79 | self.hysteresis = hysteresis 80 | 81 | # Trackers. 82 | self._growth_tracker = 0 83 | self._hysteresis_tracker = self.hysteresis 84 | 85 | 86 | def update(self, found_inf): 87 | 88 | # If we have an inf/nan, growth tracker is set to 0 89 | # and hysterisis tracker is reduced by 1. 90 | if found_inf: 91 | self._growth_tracker = 0 92 | self._hysteresis_tracker -= 1 93 | # Now if we are out of hysteresis count, scale down the loss. 94 | if self._hysteresis_tracker <= 0: 95 | self._scale = torch.max(self._scale * self.backoff_factor, 96 | self.min_scale) 97 | else: 98 | # If there is no nan/inf, increment the growth tracker. 99 | self._growth_tracker += 1 100 | # If we have had enough consequitive intervals with no nan/inf: 101 | if self._growth_tracker == self.growth_interval: 102 | # Reset the tracker and hysteresis trackers, 103 | self._growth_tracker = 0 104 | self._hysteresis_tracker = self.hysteresis 105 | # and scale up the loss scale. 
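# Example with growth_factor=2.0, backoff_factor=0.5, growth_interval=1000 and
# hysteresis=2: the scale doubles after every 1000 consecutive overflow-free
# steps, while two overflows (without a completed growth interval in between to
# reset the hysteresis counter) halve it, never dropping below min_scale.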
106 | self._scale = self._scale * self.growth_factor 107 | 108 | 109 | def state_dict(self): 110 | state_dict = {} 111 | state_dict['scale'] = self._scale 112 | state_dict['growth_tracker'] = self._growth_tracker 113 | state_dict['hysteresis_tracker'] = self._hysteresis_tracker 114 | return state_dict 115 | 116 | 117 | def load_state_dict(self, state_dict): 118 | self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) 119 | self._growth_tracker = state_dict['growth_tracker'] 120 | self._hysteresis_tracker = state_dict['hysteresis_tracker'] 121 | -------------------------------------------------------------------------------- /training/megatron/static/index.html: -------------------------------------------------------------------------------- [Static HTML front-end for the Megatron text-generation server; the markup itself was not preserved in this dump. The recoverable text nodes show a page titled "Megatron" with a prompt text box, a 0 / 1000 character counter, and a "Prompt Megatron" control.] -------------------------------------------------------------------------------- /training/megatron/text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | from .api import ( 5 | generate, 6 | generate_and_post_process, 7 | beam_search_and_post_process) 8 | -------------------------------------------------------------------------------- /training/megatron/text_generation/beam_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | ## from huggingface beam search 19 | class BeamHypotheses(object): 20 | def __init__(self, num_beams, length_penalty=1.0, early_stopping=False): 21 | """ 22 | Initialize n-best list of hypotheses. 23 | """ 24 | self.length_penalty = length_penalty 25 | self.early_stopping = early_stopping 26 | self.num_beams = num_beams 27 | self.beams = [] 28 | self.worst_score = 1e9 29 | 30 | def __len__(self): 31 | """ 32 | Number of hypotheses in the list. 33 | """ 34 | return len(self.beams) 35 | 36 | def add(self, hyp, sum_logprobs, length): 37 | """ 38 | Add a new hypothesis to the list. 39 | """ 40 | score = sum_logprobs / length ** self.length_penalty 41 | if len(self) < self.num_beams or score > self.worst_score: 42 | self.beams.append((score, hyp)) 43 | if len(self) > self.num_beams: 44 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 45 | del self.beams[sorted_scores[0][1]] 46 | self.worst_score = sorted_scores[1][0] 47 | else: 48 | self.worst_score = min(score, self.worst_score) 49 | 50 | def is_done(self, best_sum_logprobs, cur_len): 51 | """ 52 | If there are enough hypotheses and none of the hypotheses being generated 53 | can become better than the worst one in the heap, then we are done with this sentence. 54 | """ 55 | 56 | if len(self) < self.num_beams: 57 | return False 58 | elif self.early_stopping: 59 | return True 60 | else: 61 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 62 | ret = self.worst_score >= cur_score 63 | return ret 64 | 65 | -------------------------------------------------------------------------------- /training/megatron/text_generation/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Sampling utilities.
4 | Part of this code is inspired by: 5 | - https://github.com/ari-holtzman/degen/blob/master/gen.py 6 | - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html 7 | """ 8 | 9 | 10 | import torch 11 | 12 | 13 | 14 | def modify_logits_for_top_k_filtering(logits, top_k): 15 | """Set the logits for non-top-k values to -inf.""" 16 | 17 | filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] 18 | logits.masked_fill_(filter_, float('-Inf')) 19 | 20 | 21 | 22 | def modify_logits_for_top_p_filtering(logits, top_p): 23 | """Set the logits for non-top-p values to -inf.""" 24 | 25 | # First sort and calculate cumulative sum of probabilities. 26 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 27 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 28 | 29 | # Filtering based on the cumulative sum. 30 | filter_ = cumulative_probs > top_p 31 | # This shift by 1 is weird and I cannot justify it. This existed 32 | # in the original implementation: 33 | # https://github.com/ari-holtzman/degen/blob/master/gen.py 34 | # and I guess it is needed so keeping it for now. 35 | filter_[:, 1:] = filter_[:, :-1].clone() 36 | # Make sure we at least have one token to select from. 37 | filter_[..., 0] = 0 38 | 39 | # Fill in the filtered part 40 | filter_ = filter_.scatter(1, sorted_indices, filter_) 41 | logits.masked_fill_(filter_, float('-Inf')) 42 | 43 | 44 | 45 | def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): 46 | """ Sample and generate a token. 47 | Note: logits has the dimension [b, v] where b is the batch size 48 | and v is the vocabulary size. 49 | If vocab_size is provided, we will make sure the sample that is 50 | generated is in [0, vocab-size). This will avoid out-of-vocabulary 51 | generations due to padding. 52 | """ 53 | 54 | # Check logits for consistency. 55 | assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' 56 | # assert logits.type() == 'torch.cuda.FloatTensor', \ 57 | # 'input logits should be floats.' 58 | 59 | 60 | # Greedy is just simple argmax. 61 | if top_k == 1: 62 | assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' 63 | samples = torch.argmax(logits, dim=-1) 64 | 65 | # Top-k or top-p sampling. 66 | else: 67 | # Clone so we do not modify the inputs. 68 | logits = logits.clone() 69 | # Apply temperature in place. 70 | if temperature != 1.0: 71 | logits.div_(temperature) 72 | 73 | if top_k > 1: 74 | assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' 75 | assert top_k <= logits.size(1), 'top-k is larger than logit size.' 76 | if vocab_size: 77 | assert top_k < vocab_size, 'top-k is larger than vocab size.' 78 | modify_logits_for_top_k_filtering(logits, top_k) 79 | 80 | elif top_p > 0.0: 81 | assert top_p <= 1.0, 'top-p should be in (0, 1].' 82 | modify_logits_for_top_p_filtering(logits, top_p) 83 | 84 | # After filtering, we need to recalculate the distribution. 85 | probs = logits.softmax(dim=-1) 86 | samples = torch.multinomial(probs, num_samples=1).view(-1) 87 | 88 | # If vocab size is provided, make sure the samples are 89 | # in the range [0, vocab-size). 90 | if vocab_size: 91 | samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) 92 | 93 | return samples 94 | -------------------------------------------------------------------------------- /training/megatron/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION.
All rights reserved. 2 | 3 | 4 | from .tokenizer import build_tokenizer 5 | -------------------------------------------------------------------------------- /training/pretrain_t5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Pretrain T5""" 4 | 5 | from functools import partial 6 | 7 | import torch 8 | 9 | from megatron import ( 10 | get_args, 11 | get_timers, 12 | print_rank_0 13 | ) 14 | from megatron.core import tensor_parallel 15 | from megatron.core.enums import ModelType 16 | from megatron.data.dataset_utils import build_train_valid_test_datasets 17 | from megatron.model import T5Model 18 | from megatron.training import pretrain 19 | from megatron.utils import average_losses_across_data_parallel_group 20 | from megatron import get_tokenizer 21 | 22 | 23 | def model_provider(pre_process=True, post_process=True, 24 | add_encoder=True, add_decoder=True): 25 | """Build the model.""" 26 | 27 | print_rank_0('building T5 model ...') 28 | model = T5Model(num_tokentypes=0, 29 | parallel_output=True, 30 | pre_process=pre_process, 31 | post_process=post_process, 32 | add_encoder=add_encoder, 33 | add_decoder=add_decoder) 34 | return model 35 | 36 | 37 | def get_batch(data_iterator): 38 | """Build the batch.""" 39 | 40 | keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 41 | 'enc_mask', 'dec_mask', 'enc_dec_mask'] 42 | datatype = torch.int64 43 | 44 | # Broadcast data. 45 | if data_iterator is not None: 46 | data = next(data_iterator) 47 | else: 48 | data = None 49 | 50 | data_b = tensor_parallel.broadcast_data(keys, data, datatype) 51 | 52 | # Unpack. 53 | tokens_enc = data_b['text_enc'].long() 54 | tokens_dec = data_b['text_dec'].long() 55 | labels = data_b['labels'].long() 56 | loss_mask = data_b['loss_mask'].float() 57 | 58 | enc_mask = (data_b['enc_mask'] < 0.5) 59 | dec_mask = (data_b['dec_mask'] < 0.5) 60 | enc_dec_mask = (data_b['enc_dec_mask'] < 0.5) 61 | 62 | return tokens_enc, tokens_dec, loss_mask, labels, \ 63 | enc_mask, dec_mask, enc_dec_mask 64 | 65 | 66 | def loss_func(loss_mask, output_tensor): 67 | lm_loss_ = output_tensor.float() 68 | lm_loss = torch.sum( 69 | lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() 70 | 71 | loss = lm_loss 72 | averaged_losses = average_losses_across_data_parallel_group([lm_loss]) 73 | 74 | return loss, {'lm loss': averaged_losses[0]} 75 | 76 | 77 | def forward_step(data_iterator, model): 78 | """Forward step.""" 79 | args = get_args() 80 | timers = get_timers() 81 | 82 | # Get the batch. 
83 | timers('batch generator', log_level=2).start() 84 | tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \ 85 | = get_batch(data_iterator) 86 | timers('batch generator').stop() 87 | 88 | # Forward model lm_labels 89 | # print_rank_0('e' + str(tokens_enc.shape)) 90 | # print_rank_0('d' + str(tokens_dec.shape)) 91 | # from megatron import get_tokenizer 92 | # tokenizer = get_tokenizer() 93 | # import pdb; pdb.set_trace() 94 | output_tensor = model(tokens_enc, 95 | tokens_dec, 96 | enc_mask, 97 | dec_mask, 98 | enc_dec_mask, 99 | tokentype_ids=None, 100 | lm_labels=lm_labels) 101 | 102 | return output_tensor, partial(loss_func, loss_mask) 103 | 104 | 105 | def train_valid_test_datasets_provider(train_val_test_num_samples): 106 | """Build train, valid, and test datasets.""" 107 | args = get_args() 108 | 109 | print_rank_0('> building train, validation, and test datasets for T5 ...') 110 | train_ds, valid_ds, test_ds = build_train_valid_test_datasets( 111 | train_data_prefix=args.train_data_path, 112 | valid_data_prefix=args.valid_data_path, 113 | data_impl=args.data_impl, 114 | train_samples=train_val_test_num_samples[0], 115 | valid_samples=train_val_test_num_samples[1], 116 | ul2_type=args.ul2_type, 117 | max_seq_length=args.encoder_seq_length, 118 | max_seq_length_dec=args.decoder_seq_length, 119 | seed=args.seed, 120 | skip_warmup=(not args.mmap_warmup) 121 | ) 122 | print_rank_0("> finished creating T5 datasets ...") 123 | 124 | return train_ds, valid_ds, test_ds 125 | 126 | if __name__ == "__main__": 127 | 128 | pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder, 129 | forward_step, args_defaults={'tokenizer_type': 'SentencePieceTokenizer'}) 130 | -------------------------------------------------------------------------------- /training/scripts/data_process_flan.sh: -------------------------------------------------------------------------------- 1 | python tools/preprocess_data_finetune.py \ 2 | --json-file /data/flan.json \ 3 | --input-column input \ 4 | --target-column target \ 5 | --tokenizer-model /data/tokenizer/multilingual-spiece.model \ 6 | --vocab_extra_ids 100 \ 7 | --output-prefix /data/flan \ 8 | --dataset-impl mmap \ 9 | --workers 32 \ 10 | --log-interval 10 \ 11 | --chunk-size 8 \ 12 | -------------------------------------------------------------------------------- /training/scripts/data_process_span_corr.sh: -------------------------------------------------------------------------------- 1 | python tools/preprocess_data_pretrain.py \ 2 | --json-file /data/pretrain_data.jsonl \ 3 | --json-key text \ 4 | --group-size 568 \ 5 | \ 6 | --tokenizer-model /data/tokenizer/multilingual-spiece.model \ 7 | --vocab_extra_ids 100 \ 8 | \ 9 | --output-prefix /data/pretrain_data \ 10 | --dataset-impl mmap \ 11 | --batch-size 1000 \ 12 | --workers 16 \ 13 | --chunk-size 1 \ 14 | --log-interval 10 \ 15 | -------------------------------------------------------------------------------- /training/scripts/run_flan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | # export OMP_NUM_THREADS=24 4 | 5 | # Change for multinode config 6 | MASTER_ADDR=wxhd00 7 | MASTER_PORT=17099 8 | NNODES=4 9 | NODE_RANK=0 10 | GPUS_PER_NODE=8 11 | 12 | LOAD_PATH="/data/checkpoint/14b_main_stretch" 13 | CHECKPOINT_PATH="/data/checkpoint/14b_flan_final" 14 | TRAIN_DATA_PATH="/data/all_instruct/biflan/biflan_multitask" 15 | 
VALID_DATA_PATH="/data/all_instruct/biflan/biflan_multitask" 16 | TOKENIZER_PATH="/data/tokenizer/multilingual-spiece.model" 17 | TESNSORBOARD_PATH=$CHECKPOINT_PATH/tensorboard 18 | 19 | mkdir -p ${TESNSORBOARD_PATH} 20 | 21 | DISTRIBUTED_ARGS=" 22 | --nproc_per_node $GPUS_PER_NODE \ 23 | --nnodes $NNODES \ 24 | --node_rank $NODE_RANK \ 25 | --master_addr $MASTER_ADDR \ 26 | --master_port $MASTER_PORT 27 | " 28 | 29 | T5_ARGS=" 30 | --seed 322 31 | --tensor-model-parallel-size 4 \ 32 | --encoder-num-layers 12 \ 33 | --decoder-num-layers 36 \ 34 | --hidden-size 4096 \ 35 | --num-attention-heads 40 \ 36 | --kv-channels 128 \ 37 | --ffn-hidden-size 16384 \ 38 | --encoder-seq-length 1024 \ 39 | --decoder-seq-length 256 \ 40 | --max-position-embeddings 2048 \ 41 | --micro-batch-size 16 \ 42 | --global-batch-size 1024 \ 43 | --lr 0.000007 \ 44 | --train-iters 100000 \ 45 | --lr-decay-iters 50000 \ 46 | --lr-decay-style constant \ 47 | --min-lr 0.000001 \ 48 | --weight-decay 0.01 \ 49 | --lr-warmup-iters 500 \ 50 | --adam-beta1 0.9 \ 51 | --adam-beta2 0.999 \ 52 | --clip-grad 1.0 \ 53 | --fp16 \ 54 | --vocab-extra-ids 100 \ 55 | --ul2-type sample \ 56 | --pos-emb-type rotary \ 57 | --mlp-type SwiGLU \ 58 | --use-distributed-optimizer \ 59 | --no-query-key-layer-scaling \ 60 | --recompute-activations \ 61 | --attention-softmax-in-fp32 \ 62 | --override-opt_param-scheduler \ 63 | " 64 | 65 | DATA_ARGS=" 66 | --train-data-path $TRAIN_DATA_PATH \ 67 | --valid-data-path $VALID_DATA_PATH \ 68 | --tokenizer-model $TOKENIZER_PATH \ 69 | --data-impl mmap \ 70 | --num-workers 32 \ 71 | " 72 | 73 | OUTPUT_ARGS=" 74 | --log-interval 1 \ 75 | --save-interval 1500 \ 76 | --eval-interval 500000 \ 77 | --eval-iters 3 \ 78 | --tensorboard-dir $TESNSORBOARD_PATH \ 79 | " 80 | 81 | torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ 82 | $T5_ARGS \ 83 | $DATA_ARGS \ 84 | $OUTPUT_ARGS \ 85 | --distributed-backend nccl \ 86 | --save $CHECKPOINT_PATH \ 87 | --load $LOAD_PATH | tee -a $CHECKPOINT_PATH/${NODE_RANK}.log 88 | -------------------------------------------------------------------------------- /training/scripts/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | # export OMP_NUM_THREADS=24 4 | 5 | # Change for multinode config 6 | MASTER_ADDR=wxhd00 7 | MASTER_PORT=17099 8 | NNODES=4 9 | NODE_RANK=0 10 | GPUS_PER_NODE=8 11 | 12 | LOAD_PATH="/data/checkpoint/14b_main" 13 | CHECKPOINT_PATH="/data/checkpoint/14b_main" 14 | TRAIN_DATA_PATH="/data/en_zh/all_data" 15 | VALID_DATA_PATH="/data/en_zh/val_spancorr" 16 | TOKENIZER_PATH="/data/tokenizer/multilingual-spiece.model" 17 | TESNSORBOARD_PATH=$CHECKPOINT_PATH/tensorboard 18 | 19 | mkdir -p ${TESNSORBOARD_PATH} 20 | 21 | DISTRIBUTED_ARGS=" 22 | --nproc_per_node $GPUS_PER_NODE \ 23 | --nnodes $NNODES \ 24 | --node_rank $NODE_RANK \ 25 | --master_addr $MASTER_ADDR \ 26 | --master_port $MASTER_PORT 27 | " 28 | 29 | T5_ARGS=" 30 | --tensor-model-parallel-size 4 \ 31 | --encoder-num-layers 12 \ 32 | --decoder-num-layers 36 \ 33 | --hidden-size 4096 \ 34 | --num-attention-heads 40 \ 35 | --kv-channels 128 \ 36 | --ffn-hidden-size 16384 \ 37 | --encoder-seq-length 570 \ 38 | --decoder-seq-length 381 \ 39 | --max-position-embeddings 768 \ 40 | --micro-batch-size 16 \ 41 | --global-batch-size 4096 \ 42 | --lr 0.0001 \ 43 | --train-iters 200000 \ 44 | --lr-decay-iters 100000 \ 45 | --lr-decay-style cosine \ 46 | --min-lr 0.00001 \ 47 | --weight-decay 0.1 \ 48 | 
--lr-warmup-iters 2000 \ 49 | --adam-beta1 0.9 \ 50 | --adam-beta2 0.95 \ 51 | --adam-eps 1e-8 \ 52 | --clip-grad 1.0 \ 53 | --fp16 \ 54 | --vocab-extra-ids 100 \ 55 | --ul2-type sample \ 56 | --pos-emb-type rotary \ 57 | --mlp-type SwiGLU \ 58 | --use-distributed-optimizer \ 59 | --no-query-key-layer-scaling \ 60 | --recompute-activations \ 61 | --attention-softmax-in-fp32 \ 62 | " 63 | 64 | DATA_ARGS=" 65 | --train-data-path $TRAIN_DATA_PATH \ 66 | --valid-data-path $VALID_DATA_PATH \ 67 | --tokenizer-model $TOKENIZER_PATH \ 68 | --data-impl mmap \ 69 | --num-workers 32 \ 70 | " 71 | 72 | OUTPUT_ARGS=" 73 | --log-interval 1 \ 74 | --save-interval 500 \ 75 | --eval-interval 500 \ 76 | --eval-iters 3 \ 77 | --tensorboard-dir $TESNSORBOARD_PATH \ 78 | " 79 | 80 | torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ 81 | $T5_ARGS \ 82 | $DATA_ARGS \ 83 | $OUTPUT_ARGS \ 84 | --distributed-backend nccl \ 85 | --save $CHECKPOINT_PATH \ 86 | --load $LOAD_PATH | tee -a $CHECKPOINT_PATH/${NODE_RANK}.log 87 | -------------------------------------------------------------------------------- /training/scripts/run_stretch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | # export OMP_NUM_THREADS=24 4 | 5 | # Change for multinode config 6 | MASTER_ADDR=wxhd00 7 | MASTER_PORT=12399 8 | NNODES=4 9 | NODE_RANK=0 10 | GPUS_PER_NODE=8 11 | 12 | LOAD_PATH="/data/checkpoint/14b_main" 13 | CHECKPOINT_PATH="/data/checkpoint/14b_main_stretch" 14 | TRAIN_DATA_PATH="/data/en_zh/all_data_stretch_2048" 15 | VALID_DATA_PATH="/data/en_zh/all_data_stretch_2048" 16 | TOKENIZER_PATH="/data/tokenizer/multilingual-spiece.model" 17 | TESNSORBOARD_PATH=$CHECKPOINT_PATH/tensorboard 18 | 19 | mkdir -p ${TESNSORBOARD_PATH} 20 | 21 | DISTRIBUTED_ARGS=" 22 | --nproc_per_node $GPUS_PER_NODE \ 23 | --nnodes $NNODES \ 24 | --node_rank $NODE_RANK \ 25 | --master_addr $MASTER_ADDR \ 26 | --master_port $MASTER_PORT 27 | " 28 | 29 | T5_ARGS=" 30 | --tensor-model-parallel-size 4 \ 31 | --encoder-num-layers 12 \ 32 | --decoder-num-layers 36 \ 33 | --hidden-size 4096 \ 34 | --num-attention-heads 40 \ 35 | --kv-channels 128 \ 36 | --ffn-hidden-size 16384 \ 37 | --encoder-seq-length 1027 \ 38 | --decoder-seq-length 1025 \ 39 | --max-position-embeddings 2038 \ 40 | --micro-batch-size 4 \ 41 | --global-batch-size 1024 \ 42 | --lr 0.00004 \ 43 | --train-iters 100000 \ 44 | --lr-decay-iters 25000 \ 45 | --lr-decay-style cosine \ 46 | --min-lr 0.00001 \ 47 | --weight-decay 0.1 \ 48 | --lr-warmup-iters 0 \ 49 | --adam-beta1 0.9 \ 50 | --adam-beta2 0.95 \ 51 | --adam-eps 1e-8 \ 52 | --clip-grad 1.0 \ 53 | --fp16 \ 54 | --vocab-extra-ids 100 \ 55 | --ul2-type sample \ 56 | --pos-emb-type rotary \ 57 | --mlp-type SwiGLU \ 58 | --use-distributed-optimizer \ 59 | --no-query-key-layer-scaling \ 60 | --attention-softmax-in-fp32 \ 61 | --finetune \ 62 | " 63 | 64 | DATA_ARGS=" 65 | --train-data-path $TRAIN_DATA_PATH \ 66 | --valid-data-path $VALID_DATA_PATH \ 67 | --tokenizer-model $TOKENIZER_PATH \ 68 | --data-impl mmap \ 69 | --num-workers 32 \ 70 | " 71 | 72 | OUTPUT_ARGS=" 73 | --log-interval 1 \ 74 | --save-interval 1000 \ 75 | --eval-interval 1000 \ 76 | --eval-iters 3 \ 77 | --tensorboard-dir $TESNSORBOARD_PATH \ 78 | " 79 | 80 | torchrun $DISTRIBUTED_ARGS pretrain_t5.py \ 81 | $T5_ARGS \ 82 | $DATA_ARGS \ 83 | $OUTPUT_ARGS \ 84 | --distributed-backend nccl \ 85 | --save $CHECKPOINT_PATH \ 86 | --load $LOAD_PATH | tee -a $CHECKPOINT_PATH/${NODE_RANK}.log 87 | 
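The three launch scripts form a staged recipe: run_pretrain.sh trains on the span-corruption data at short sequence lengths, run_stretch.sh continues from that checkpoint with stretched sequence lengths, and run_flan.sh fine-tunes the stretched checkpoint on the biflan instruction data. As a rough, illustrative cross-check of how the stages compare (not part of the repository), the Python sketch below multiplies each stage's --global-batch-size by its encoder plus decoder sequence lengths; it counts padded sequence positions per optimizer step, not real tokens.

# Back-of-the-envelope sequence-position budget per optimizer step,
# using the values visible in the three launch scripts above.
stages = {
    "run_pretrain.sh": dict(global_batch=4096, enc_len=570, dec_len=381),
    "run_stretch.sh": dict(global_batch=1024, enc_len=1027, dec_len=1025),
    "run_flan.sh": dict(global_batch=1024, enc_len=1024, dec_len=256),
}
for name, cfg in stages.items():
    positions = cfg["global_batch"] * (cfg["enc_len"] + cfg["dec_len"])
    print(f"{name}: {positions / 1e6:.2f}M padded positions per step")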
-------------------------------------------------------------------------------- /training/tools/linter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pathlib 4 | import subprocess 5 | 6 | 7 | def recursively_lint_files(): 8 | """Recursively lint all python files in chosen subdirectories of megatron-lm""" 9 | 10 | try: 11 | import autopep8 12 | except ModuleNotFoundError: 13 | print("Please first install autopep8 via `pip install autopep8`") 14 | return 15 | 16 | # get all python file paths from top level directory 17 | file_dir = str(pathlib.Path(__file__).parent.absolute()) 18 | working_dir = osp.join(file_dir, os.pardir) 19 | all_py_paths = set(os.path.join(working_dir, fname) 20 | for fname in os.listdir(working_dir) if ".py" in fname) 21 | 22 | # get all python file paths from chosen subdirectories 23 | check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks'] 24 | for sub_dir in check_dirs: 25 | for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)): 26 | all_py_paths.update(set(osp.join(path, fname) for fname in fnames if ".py" in fname)) 27 | 28 | print("Linting the following: ") 29 | for py_path in all_py_paths: 30 | print(py_path) 31 | command = ['autopep8', '--max-line-length', '100', '--aggressive', '--in-place', py_path] 32 | subprocess.check_call(command) 33 | 34 | 35 | if __name__ == "__main__": 36 | recursively_lint_files() 37 | -------------------------------------------------------------------------------- /training/tools/merge_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 6 | 7 | from megatron.data import indexed_dataset 8 | 9 | 10 | def main(args): 11 | 12 | prefixes = set() 13 | for basename in os.listdir(args.input): 14 | prefix, ext = os.path.splitext(basename) 15 | 16 | if prefix in prefixes: 17 | continue 18 | 19 | if not os.path.isfile(os.path.join(args.input, basename)): 20 | continue 21 | 22 | ext_pair = '.bin' if ext == '.idx' else '.idx' 23 | assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \ 24 | f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}' 25 | 26 | prefixes.add(prefix) 27 | 28 | builder = None 29 | for prefix in sorted(prefixes): 30 | if builder is None: 31 | dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer') 32 | 33 | if isinstance(dataset, indexed_dataset.MMapIndexedDataset): 34 | builder = indexed_dataset.MMapIndexedDatasetBuilder(args.output_prefix + '.bin', dtype=dataset._index.dtype) 35 | else: 36 | builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin') 37 | 38 | del dataset 39 | 40 | builder.merge_file_(os.path.join(args.input, prefix)) 41 | 42 | builder.finalize(args.output_prefix + '.idx') 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | 48 | group = parser.add_argument_group(title='input data') 49 | group.add_argument('--input', type=str, required=True, 50 | help='Path to directory containing all document files to merge') 51 | 52 | group = parser.add_argument_group(title='output data') 53 | group.add_argument('--output-prefix', type=str, required=True, 54 | help='Path to binary output file without suffix') 55 | 56 | args = parser.parse_args() 57 | 58 | assert os.path.isdir(args.input), \
f'ERROR: {args.input} is not a directory or does not exist' 60 | 61 | assert os.path.isdir(os.path.dirname(args.output_prefix)), \ 62 | f'ERROR: {os.path.dirname(args.output_prefix)} is not a directory or does not exist' 63 | 64 | main(args) 65 | 66 | -------------------------------------------------------------------------------- /training/tools/preprocess_data_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Processing data for pretraining.""" 4 | 5 | import argparse 6 | import json 7 | import multiprocessing 8 | import os 9 | import sys 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 11 | import time 12 | 13 | import torch 14 | try: 15 | import nltk 16 | nltk_available = True 17 | except ImportError: 18 | nltk_available = False 19 | 20 | from megatron.tokenizer import build_tokenizer 21 | from megatron.data import indexed_dataset 22 | 23 | 24 | class Encoder(object): 25 | def __init__(self, args): 26 | self.args = args 27 | 28 | def initializer(self): 29 | # Use Encoder class as a container for global data 30 | Encoder.tokenizer = build_tokenizer(self.args) 31 | # Encoder.insturct_column = self.args.insturct_column 32 | Encoder.input_column = self.args.input_column 33 | Encoder.target_column = self.args.target_column 34 | Encoder.task_prefix = self.args.task_prefix 35 | 36 | def encode(self, json_line): 37 | data = json.loads(json_line) 38 | tot_len = len(json_line) 39 | # instruct_text = data[Encoder.insturct_column] 40 | instruct_text = "" 41 | source_text = instruct_text + data[Encoder.input_column] + " summarize:" 42 | target_text = Encoder.task_prefix + data[Encoder.target_column] 43 | 44 | source_sentence_ids = Encoder.tokenizer.tokenize(source_text) 45 | source_sentence_ids.append(Encoder.tokenizer.eos_id) 46 | target_sentence_ids = Encoder.tokenizer.tokenize(target_text) 47 | target_sentence_ids.append(Encoder.tokenizer.eos_id) 48 | 49 | return (source_sentence_ids, target_sentence_ids), tot_len 50 | 51 | def get_args(): 52 | parser = argparse.ArgumentParser() 53 | group = parser.add_argument_group(title='input data') 54 | group.add_argument('--json-file', type=str, required=True, help='Path to input JSON') 55 | group.add_argument('--insturct-column', type=str, default='text') 56 | group.add_argument('--input-column', type=str, default='text') 57 | group.add_argument('--target-column', type=str, default='answer') 58 | group.add_argument('--batch-size', type=int, default=1000) 59 | group.add_argument('--task-prefix', type=str, default="") 60 | 61 | group = parser.add_argument_group(title='tokenizer') 62 | group.add_argument('--tokenizer-model', type=str, required=True) 63 | group.add_argument('--vocab_extra_ids', type=int, default=0) 64 | 65 | group = parser.add_argument_group(title='output data') 66 | group.add_argument('--output-prefix', type=str, required=True, help='Path to binary output file without suffix') 67 | group.add_argument('--dataset-impl', type=str, default='mmap', choices=['lazy', 'cached', 'mmap']) 68 | 69 | group = parser.add_argument_group(title='runtime') 70 | group.add_argument('--workers', type=int, required=True, help='Number of worker processes to launch') 71 | group.add_argument('--chunk-size', type=int, required=True, help='Chunk size assigned to each worker process') 72 | group.add_argument('--log-interval', type=int, default=100, help='Interval between progress updates') 73 | args = 
parser.parse_args() 74 | # args.keep_empty = False 75 | 76 | # some default/dummy values for the tokenizer 77 | args.rank = 0 78 | args.make_vocab_size_divisible_by = 128 79 | args.tensor_model_parallel_size = 1 80 | args.tokenizer_type = "SentencePieceTokenizer" 81 | # args.vocab_extra_ids = 0 82 | 83 | return args 84 | 85 | def main(): 86 | args = get_args() 87 | startup_start = time.time() 88 | 89 | print("Opening", args.json_file) 90 | fin = open(args.json_file, 'r', encoding='utf-8') 91 | 92 | encoder = Encoder(args) 93 | tokenizer = build_tokenizer(args) 94 | pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) 95 | 96 | encoded_docs_batch = pool.imap(encoder.encode, fin, args.chunk_size) 97 | #encoded_docs = map(encoder.encode, fin) 98 | 99 | print(f"Vocab size: {tokenizer.vocab_size}") 100 | print(f"Output prefix: {args.output_prefix}") 101 | output_bin_file = "{}_multitask.bin".format(args.output_prefix) 102 | output_idx_file = "{}_multitask.idx".format(args.output_prefix) 103 | builder = indexed_dataset.make_builder(output_bin_file, 104 | impl=args.dataset_impl, 105 | vocab_size=tokenizer.vocab_size) 106 | 107 | startup_end = time.time() 108 | proc_start = time.time() 109 | total_docs_processed, total_sentence_created, total_bytes_processed = 0, 0, 0 110 | print("Time to startup:", startup_end - startup_start) 111 | 112 | for doc, bytes_processed in encoded_docs_batch: 113 | source, target = doc 114 | builder.add_item( 115 | source_tensor=torch.IntTensor(source), 116 | target_tensor=torch.IntTensor(target), 117 | task="multi-task", 118 | ) 119 | total_docs_processed += 1 120 | total_sentence_created += 1 121 | total_bytes_processed += bytes_processed 122 | 123 | if total_docs_processed % args.log_interval == 0: 124 | current = time.time() 125 | elapsed = current - proc_start 126 | mbs = total_bytes_processed / elapsed / 1024 / 1024 127 | print(f"{total_docs_processed} docs, {total_docs_processed / elapsed:.0f} docs/s | ", 128 | f"{total_sentence_created} sents, {total_sentence_created / elapsed:.0f} sents/s | ", 129 | f"{mbs:.2f} MB/s | ") 130 | print("Done! Now finalizing.") 131 | 132 | builder.finalize(output_idx_file) 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /training/tools/preprocess_data_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Processing data for pretraining.""" 4 | 5 | import argparse 6 | import json 7 | import multiprocessing 8 | import os 9 | import sys 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 11 | import time 12 | 13 | import torch 14 | try: 15 | import nltk 16 | nltk_available = True 17 | except ImportError: 18 | nltk_available = False 19 | 20 | from megatron.tokenizer import build_tokenizer 21 | from megatron.data import indexed_dataset 22 | 23 | 24 | class Encoder(object): 25 | def __init__(self, args): 26 | self.args = args 27 | 28 | def initializer(self): 29 | # Use Encoder class as a container for global data 30 | Encoder.tokenizer = build_tokenizer(self.args) 31 | Encoder.group_size = self.args.group_size 32 | 33 | def encode(self, json_lines): 34 | concat_sentence_ids = [] 35 | sentence_ids = None 36 | tot_len = 0 37 | # concatenate text 38 | for json_line in json_lines: 39 | data = json.loads(json_line) 40 | text_column = self.args.json_key 41 | text = data[text_column] 42 | sentence_ids = Encoder.tokenizer.tokenize(text) 43 | if len(sentence_ids) > 0: 44 | sentence_ids.append(Encoder.tokenizer.eos_id) 45 | concat_sentence_ids.extend(sentence_ids) 46 | tot_len += len(json_line) 47 | # group text 48 | group_size = Encoder.group_size 49 | total_length = len(concat_sentence_ids) // group_size * group_size 50 | group_text = [concat_sentence_ids[i : i + group_size] for i in range(0, total_length, group_size)] 51 | return group_text, tot_len 52 | 53 | def get_args(): 54 | parser = argparse.ArgumentParser() 55 | group = parser.add_argument_group(title='input data') 56 | group.add_argument('--json-file', type=str, required=True, help='Path to input JSON') 57 | group.add_argument('--json-key', type=str, default='text') 58 | group.add_argument("--group-size", type=int, required=True) 59 | 60 | group = parser.add_argument_group(title='tokenizer') 61 | group.add_argument('--tokenizer-model', type=str, required=True) 62 | group.add_argument('--vocab_extra_ids', type=int, default=0) 63 | 64 | group = parser.add_argument_group(title='output data') 65 | group.add_argument('--output-prefix', type=str, required=True, help='Path to binary output file without suffix') 66 | group.add_argument('--dataset-impl', type=str, default='mmap', choices=['lazy', 'cached', 'mmap']) 67 | 68 | group = parser.add_argument_group(title='runtime') 69 | group.add_argument('--workers', type=int, required=True, help='Number of worker processes to launch') 70 | group.add_argument('--batch-size', type=int, required=True) 71 | group.add_argument('--chunk-size', type=int, required=True, help='Chunk size assigned to each worker process') 72 | group.add_argument('--log-interval', type=int, default=100, help='Interval between progress updates') 73 | args = parser.parse_args() 74 | # args.keep_empty = False 75 | 76 | # some default/dummy values for the tokenizer 77 | args.rank = 0 78 | args.make_vocab_size_divisible_by = 128 79 | args.tensor_model_parallel_size = 1 80 | args.tokenizer_type = "SentencePieceTokenizer" 81 | # args.vocab_extra_ids = 0 82 | 83 | return args 84 | 85 | def main(): 86 | args = get_args() 87 | startup_start = time.time() 88 | 89 | print("Opening", args.json_file) 90 | fin = open(args.json_file, 'r', encoding='utf-8') 91 | 92 | encoder = Encoder(args) 93 | tokenizer = build_tokenizer(args) 94 | pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) 95 | 96 | def group_iter(line_iter, batch_size): 97 | group_text = [] 98 | for line in 
line_iter: 99 | group_text.append(line) 100 | if len(group_text) == batch_size: 101 | yield group_text 102 | group_text.clear() 103 | yield group_text 104 | 105 | encoded_docs_batch = pool.imap(encoder.encode, group_iter(fin, args.batch_size), args.chunk_size) 106 | #encoded_docs = map(encoder.encode, fin) 107 | 108 | print(f"Vocab size: {tokenizer.vocab_size}") 109 | print(f"Output prefix: {args.output_prefix}") 110 | output_bin_file = "{}_spancorr.bin".format(args.output_prefix) 111 | output_idx_file = "{}_spancorr.idx".format(args.output_prefix) 112 | builder = indexed_dataset.make_builder(output_bin_file, 113 | impl=args.dataset_impl, 114 | vocab_size=tokenizer.vocab_size) 115 | 116 | startup_end = time.time() 117 | proc_start = time.time() 118 | total_docs_processed, total_sentence_created, total_bytes_processed = 0, 0, 0 119 | print("Time to startup:", startup_end - startup_start) 120 | 121 | for doc, bytes_processed in encoded_docs_batch: 122 | for sentence in doc: 123 | if len(sentence) != args.group_size: 124 | print(f" > warning sentence length = {len(sentence)} != {args.group_size}") 125 | builder.add_item( 126 | source_tensor=torch.IntTensor(sentence), 127 | target_tensor=torch.IntTensor([]), 128 | task="span-corruption" 129 | ) 130 | total_docs_processed += args.batch_size 131 | total_sentence_created += len(doc) 132 | total_bytes_processed += bytes_processed 133 | if total_docs_processed % args.log_interval == 0: 134 | current = time.time() 135 | elapsed = current - proc_start 136 | mbs = total_bytes_processed / elapsed / 1024 / 1024 137 | print(f"{total_docs_processed} docs, {total_docs_processed / elapsed:.0f} docs/s | ", 138 | f"{total_sentence_created} sents, {total_sentence_created / elapsed:.0f} sents/s | ", 139 | f"{mbs:.2f} MB/s | ") 140 | print("Done! Now finalizing.") 141 | 142 | builder.finalize(output_idx_file) 143 | 144 | if __name__ == '__main__': 145 | main() 146 | --------------------------------------------------------------------------------
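For reference, here is a minimal, illustrative sketch (not part of the repository) of the JSON-lines inputs that the two preprocessing tools above expect. The field names match the column arguments used in scripts/data_process_span_corr.sh and scripts/data_process_flan.sh; the file names and example strings are made up.

import json

# preprocess_data_pretrain.py reads one JSON object per line and takes the
# --json-key field ("text" in data_process_span_corr.sh); documents are
# tokenized, concatenated, and chopped into --group-size chunks.
with open("pretrain_data.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps({"text": "An example pretraining document."}) + "\n")
    f.write(json.dumps({"text": "Another document, in any language."}) + "\n")

# preprocess_data_finetune.py reads the --input-column / --target-column
# fields ("input" / "target" in data_process_flan.sh) and stores one
# (source, target) pair per example in the indexed dataset.
with open("flan.json", "w", encoding="utf-8") as f:
    f.write(json.dumps({"input": "Translate to French: Hello", "target": "Bonjour"}) + "\n")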