├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── all_hetero_configs.json ├── experimental ├── .gitignore ├── arguments.py ├── compute_mfu.py ├── config.py ├── cost_modeling.py ├── dRAM_flops_estimation.py ├── device.py ├── evaluate.py ├── globals.py ├── graph.py ├── initialize.py ├── layer_partition.py ├── machine_amounts.json ├── main.py ├── partitioner.py ├── requirements.txt ├── scripts │ ├── run_flops_compute.sh │ ├── run_mfu.sh │ ├── run_scheduling.sh │ └── run_simulation.sh ├── simulation.py ├── sort_pipelines.py ├── stage_order.py ├── strategy.py └── utils.py ├── hexiscale ├── __init__.py ├── comm_group.py ├── fsdp_comm_utils.py ├── fsdp_hooks_utils.py ├── gen_hetero_groups.py ├── gen_p2p_lists.py ├── heterogeneous_pipeline.py ├── pipe_sequential.py ├── process_args.py ├── utils.py └── wrap_modules.py ├── llama ├── __init__.py ├── arguments.py ├── generate_hetero_scripts.py ├── llama-config │ ├── llama-13b │ │ ├── llama-13b │ │ │ └── config.json │ │ └── params.json │ ├── llama-30b │ │ ├── llama-30b │ │ │ └── config.json │ │ └── params.json │ ├── llama-70b │ │ ├── llama-70b │ │ │ └── config.json │ │ └── params.json │ └── llama-7b │ │ ├── llama-7b │ │ └── config.json │ │ └── params.json ├── llama_config_utils.py ├── load_model_parameters_utils │ ├── README.md │ ├── __init__.py │ ├── create_separate_state_dicts_llama_7b.py │ ├── inv_freq.pt │ ├── load_model_parameters.py │ ├── remap_state_dict.py │ └── shard_combine_utils.py ├── modules │ ├── Llamamodel_pipeline.py │ └── hybrid_parallel_model_dist.py ├── save_checkpoint.py └── sudo_dataset.py ├── llama_train.py ├── scripts ├── batch_run_scripts.sh ├── generate_hetero_scripts.sh ├── run_iperf3.sh └── train.sh └── third_party ├── README.md ├── __init__.py ├── _runtime_utils.py └── megatron ├── __init__.py ├── arguments.py ├── checkpointing.py ├── core ├── README.md ├── __init__.py ├── enums.py ├── package_info.py ├── parallel_state.py ├── pipeline_parallel │ ├── __init__.py │ ├── p2p_communication.py │ └── schedules.py ├── requirements.txt ├── tensor_parallel │ ├── __init__.py │ ├── cross_entropy.py │ ├── data.py │ ├── layers.py │ ├── mappings.py │ ├── mappings_group.py │ ├── random.py │ └── utils.py └── utils.py ├── data ├── Makefile ├── __init__.py ├── autoaugment.py ├── bert_dataset.py ├── biencoder_dataset_utils.py ├── blendable_dataset.py ├── data_samplers.py ├── dataset_utils.py ├── gpt_dataset.py ├── helpers.cpp ├── helpers.cpython-38-x86_64-linux-gnu.so ├── ict_dataset.py ├── image_folder.py ├── indexed_dataset.py ├── orqa_wiki_dataset.py ├── realm_dataset_utils.py ├── realm_index.py ├── t5_dataset.py ├── test │ ├── test_indexed_dataset.py │ └── test_preprocess_data.sh └── vit_dataset.py ├── dist_signal_handler.py ├── fp16_deprecated └── loss_scaler.py ├── fused_kernels ├── __init__.py ├── compat.h ├── scaled_masked_softmax.cpp ├── scaled_masked_softmax.h ├── scaled_masked_softmax_cuda.cu ├── scaled_softmax.cpp ├── scaled_softmax_cuda.cu ├── scaled_upper_triang_masked_softmax.cpp ├── scaled_upper_triang_masked_softmax.h ├── scaled_upper_triang_masked_softmax_cuda.cu ├── tests │ ├── __init__.py │ └── test_fused_kernels.py └── type_shim.h ├── global_vars.py ├── indexer.py ├── initialize.py ├── memory.py ├── microbatches.py ├── model ├── __init__.py ├── bert_model.py ├── biencoder_model.py ├── classification.py ├── distributed.py ├── enums.py ├── fused_bias_gelu.py ├── fused_layer_norm.py ├── fused_softmax.py ├── gpt_model.py ├── language_model.py ├── module.py ├── multiple_choice.py ├── realm_model.py ├── 
retro_transformer.py ├── rotary_pos_embedding.py ├── t5_model.py ├── transformer.py ├── utils.py └── vision │ ├── classification.py │ ├── dino.py │ ├── esvit_swin_backbone.py │ ├── inpainting.py │ ├── knn_monitor.py │ ├── mit_backbone.py │ ├── swin_backbone.py │ ├── utils.py │ └── vit_backbone.py ├── mpu └── tests │ ├── __init__.py │ ├── commons.py │ ├── test_cross_entropy.py │ ├── test_data.py │ ├── test_initialize.py │ ├── test_layers.py │ └── test_random.py ├── optimizer ├── __init__.py ├── clip_grads.py ├── distrib_optimizer.py ├── grad_scaler.py └── optimizer.py ├── optimizer_param_scheduler.py ├── static └── index.html ├── text_generation ├── __init__.py ├── api.py ├── beam_utils.py ├── communication.py ├── forward_step.py ├── generation.py ├── sampling.py └── tokenization.py ├── text_generation_server.py ├── timers.py ├── tokenizer ├── __init__.py ├── bert_tokenization.py ├── gpt2_tokenization.py └── tokenizer.py ├── training.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | separate_state_dicts/ 4 | *.pickle 5 | *.swp 6 | llama-scripts-logs/ 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Not well tested yet 2 | 3 | FROM nvcr.io/nvidia/pytorch:24.02-py3 as base 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | curl \ 7 | sudo \ 8 | htop \ 9 | git \ 10 | wget \ 11 | tmux \ 12 | net-tools \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | RUN pip install --no-cache-dir transformers sentencepiece 16 | 17 | RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ 18 | cd flash-attention && cd csrc/fused_dense_lib && pip install . && cd ../layer_norm && pip install . 
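# Example usage (untested sketch; the image tag and mount path below are placeholders):
#   docker build -t hexiscale:dev .
#   docker run --gpus all --ipc=host --network=host -v $(pwd):/workspace -it hexiscale:dev bash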
-------------------------------------------------------------------------------- /all_hetero_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "hetero_configs": [ 3 | [4, 2 ], 4 | [2] 5 | ], 6 | 7 | "layer_partitions": [ 8 | [2, 2], 9 | [4] 10 | ], 11 | "devices_id": [ 12 | [[0, 0, 0, 0], [0, 0]], 13 | [[0, 0]] 14 | ] 15 | } -------------------------------------------------------------------------------- /experimental/.gitignore: -------------------------------------------------------------------------------- 1 | .venv/* 2 | .vscode/* 3 | metis-5.1.0/ 4 | __pycache__ 5 | .DS_Store 6 | train_scripts/ 7 | single_train_scripts/ 8 | calculate_costs.py 9 | calculate_bubble_overhead.py 10 | -------------------------------------------------------------------------------- /experimental/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('--model-size', type=str, default='llama-7b', 7 | help='Which model to work on') 8 | parser.add_argument('--npipeline', type=int, default=1, 9 | help="How many pipelines to create") 10 | parser.add_argument('--inter_bw', type=float, default=5, 11 | help='Assumed inter-machine bandwidth') 12 | parser.add_argument('--global_bsz', type=int, default=1, 13 | help="Per pipeline global-batch size") 14 | parser.add_argument('--micro_bsz', type=int, default=1, 15 | help="Per pipeline micro-batch size") 16 | parser.add_argument('--MB', type=int, nargs="+", default=None, 17 | help="Per-pipeline micro-batch sizes; overrides the value derived from --micro_bsz") 18 | parser.add_argument('--kway', type=int, default=3, 19 | help='k-way width used in greedy stage-order path finding') 20 | parser.add_argument('--recompute', type=bool, default=True, 21 | help='If enabled, activation recompute is considered') 22 | parser.add_argument('--estimate_strategy', type=str, default='[]', 23 | help='Pipeline Strategies to be estimated') 24 | parser.add_argument('--strategy_device_ids', type=str, default=None, 25 | help='Device ids used by each estimated pipeline strategy') 26 | parser.add_argument('--estimate_layer_partition', type=str, default='[]', 27 | help='Layer partitions to be estimated') 28 | parser.add_argument('--estimate_total_layers', type=int, default=32, 29 | help='Adjust when running different total layers from standard model') 30 | parser.add_argument('--accum-iter', type=int, default=1, 31 | help='Adjust batch size for more accurate time cost estimation') 32 | parser.add_argument('--verbose', type=bool, default=False, 33 | help="If enabled, outputs will be detailed") 34 | parser.add_argument('--actual_running_time', type=float, default=None, 35 | help='If provided, MFU will be computed from it') 36 | parser.add_argument('--estimate_all', action='store_true', 37 | help='If enabled, both memory and time cost are estimated, otherwise only time cost is estimated') 38 | parser.add_argument('--machine_config_path', type=str, default=None, 39 | help='Path to the machine configuration JSON file') 40 | parser.add_argument('--not_use_tp', action='store_true', 41 | help='If enabled, tensor parallelism is not considered') 42 | parser.add_argument('--zero_3', action='store_true', 43 | help='If enabled, ZeRO-3 is considered') 44 | parser.add_argument('--log_interval', type=int, 45 | help='Log interval') 46 | parser.add_argument('--niter', type=int, 47 | help='Number of scheduling search iterations') 48 | parser.add_argument('--apply_random_strategy', action='store_true',
49 | help='If enabled, use a random strategy instead of graph partitioning') 50 | 51 | 52 | args = parser.parse_args() 53 | 54 | args.estimate_strategy = eval(args.estimate_strategy) 55 | args.estimate_layer_partition = eval(args.estimate_layer_partition) 56 | 57 | return args 58 | 59 | -------------------------------------------------------------------------------- /experimental/compute_mfu.py: -------------------------------------------------------------------------------- 1 | from cost_modeling import TimeCost, MemoryCost 2 | from globals import configs 3 | from arguments import get_args 4 | 5 | 6 | args = get_args() 7 | 8 | configs.L = args.estimate_total_layers 9 | 10 | strategy_cost = TimeCost(all_pipelines=[], configs=configs) 11 | 12 | if args.actual_running_time: 13 | print(f"Using provided running time, computed MFU: {round(strategy_cost.mfu(args.actual_running_time) * 100, 3)}%", ) 14 | -------------------------------------------------------------------------------- /experimental/config.py: -------------------------------------------------------------------------------- 1 | from initialize import initialize 2 | from pymetis import Options 3 | from dataclasses import dataclass 4 | from arguments import get_args 5 | 6 | 7 | @dataclass 8 | class Config: 9 | args = get_args() 10 | 11 | model_size = args.model_size 12 | 13 | # graph partition config 14 | niter = args.niter 15 | options = Options(contig=True) 16 | npipeline = args.npipeline 17 | param = [2, 0.2] # n, p for binomial 18 | K = initialize(param, (1, npipeline)) 19 | 20 | # pipeline 21 | kway = args.kway 22 | 23 | # network config 24 | inter_bw = args.inter_bw 25 | specs = None 26 | 27 | # utils 28 | device_machine_map = None 29 | 30 | # model config 31 | 32 | GLB_B = args.global_bsz * args.accum_iter 33 | # GLB_B = 5000 34 | GLB_MB = args.micro_bsz * args.accum_iter 35 | assert GLB_B % GLB_MB == 0 36 | N_MB = GLB_B // GLB_MB 37 | assert GLB_MB >= npipeline, "Too many pipelines" 38 | B = GLB_B // npipeline 39 | MB = GLB_MB // npipeline 40 | if args.MB is not None: 41 | MB = args.MB 42 | 43 | 44 | S = 4096 45 | 46 | if model_size == 'llama-30b': 47 | H = 6656 48 | L = 60 49 | N_attn_heads = 52 50 | P = 30 51 | elif model_size == 'llama-7b': 52 | H = 4096 53 | L = 32 54 | N_attn_heads = 32 55 | P = 7 56 | elif model_size == 'llama-13b': 57 | H = 5120 58 | L = 40 59 | N_attn_heads = 40 60 | P = 13 61 | elif model_size == 'llama-70b': 62 | H = 8192 63 | L = 80 64 | N_attn_heads = 64 65 | P = 70 66 | 67 | V = 32000 68 | B_type = 2 69 | 70 | T = GLB_B * S -------------------------------------------------------------------------------- /experimental/dRAM_flops_estimation.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | from device import Device 4 | from arguments import get_args 5 | import json 6 | 7 | """ 8 | Information Links: 9 | 1. https://www.nvidia.com/en-us/data-center/products/a10-gpu/ 10 | 2. https://www.nvidia.com/en-us/data-center/a100/ 11 | 3. https://shop.pegasus.hk/products/217290/NVIDIA-L20-GPU-Accelerator-PCIE-48GB 12 | 4. https://deepbaytech.com/images/nvidia-a800-datasheet-nvidia-a4-2521686-zhCN.pdf 13 | 5. https://www.nvidia.com/en-us/data-center/h100/ 14 | 6.
https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf 15 | """ 16 | 17 | args = get_args() 18 | with open(args.machine_config_path, 'r') as machine_config_file: 19 | machine_config = json.load(machine_config_file) 20 | 21 | # Baseline 22 | baseline_machines = machine_config['baseline_machines'] 23 | 24 | machine_specs = machine_config['machine_specs'] 25 | 26 | # sublist is in the same type, sublist has three number, indicating n_same_machine of 2,4,8 gpus 27 | machine_amounts: dict = machine_config['machine_amounts'] 28 | 29 | ngpus = [2, 4, 8] 30 | 31 | hetero_machines = [] 32 | used_gpus = 0 33 | for name, machine_amount in machine_amounts.items(): 34 | spec = machine_specs[name] 35 | for ngpu, n in machine_amount.items(): 36 | if n == 0: 37 | continue 38 | used_gpus += int(ngpu) 39 | hetero_machines.append({"name": name, "tensor_core": spec[0], "memory": spec[1], "intra_bw": spec[2], "ngpus": int(ngpu), "n_same_machine": n}) 40 | 41 | 42 | print(f"hetero machines: {hetero_machines}, ") 43 | print(f"ngpus: {used_gpus}") 44 | # assert used_gpus <= 128 45 | 46 | 47 | baseline_mem = baseline_machines['memory'] * baseline_machines['ngpus'] * baseline_machines['n_same_machine'] 48 | baseline_tensor_core = baseline_machines['tensor_core'] * baseline_machines['ngpus'] * baseline_machines['n_same_machine'] 49 | 50 | 51 | hetero_mem = 0 52 | hetero_tensor_core = 0 53 | 54 | for machines in hetero_machines: 55 | hetero_mem += machines['memory'] * machines['ngpus'] * machines['n_same_machine'] 56 | hetero_tensor_core += machines['tensor_core'] * machines['ngpus'] * machines['n_same_machine'] 57 | 58 | print(f"baseline mem: {baseline_mem}, hetero mem: {hetero_mem}, hetero-to-baseline mem ratio:\ 59 | {hetero_mem / baseline_mem}, hetero mem over baseline: {hetero_mem > baseline_mem}") 60 | print(f"baseline flops: {baseline_tensor_core}, hetero flops: {hetero_tensor_core}, hetero-to-baseline flops ratio:\ 61 | {hetero_tensor_core / baseline_tensor_core}, hetero flops over baseline: {hetero_tensor_core > baseline_tensor_core}") 62 | 63 | -------------------------------------------------------------------------------- /experimental/device.py: -------------------------------------------------------------------------------- 1 | class Device: 2 | def __init__(self, name, machine_id, tensor_core, intra_bw, memory, **kwargs): 3 | self.name = name 4 | self.machine_id = machine_id 5 | self.tensor_core = tensor_core 6 | self.intra_bw = intra_bw 7 | self.memory = memory 8 | self.kwargs = kwargs 9 | 10 | def __str__(self): 11 | return f"{self.name}, {self.kwargs['device_id']} / {self.kwargs['machine_ngpus']}, tensor_core={self.tensor_core}, memory={self.memory}" 12 | -------------------------------------------------------------------------------- /experimental/evaluate.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from cost_modeling import * 3 | from globals import * 4 | 5 | 6 | def throughput(all_pipelines: Tuple): 7 | """given pipelines, calculate the throughput""" 8 | global configs 9 | 10 | if all_pipelines is None: 11 | return 0 12 | 13 | model_time_cost = TimeCost(all_pipelines, configs) 14 | 15 | token_throughput = model_time_cost.token_throughput() 16 | 17 | return token_throughput 18 | -------------------------------------------------------------------------------- /experimental/globals.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | from initialize import * 3 | from typing import List 4 | from utils import * 5 | from device import Device 6 | from config import Config 7 | 8 | np.random.seed(40404) 9 | 10 | 11 | configs = Config() 12 | 13 | def update_configs(configs): 14 | machines = create_machines_list() 15 | 16 | devices: List[Device] = create_devices(machines) 17 | 18 | device_machine_map = create_device_machine_map(devices) 19 | tensor_cores, comm_bws, comm_bws_dict = create_specs(devices, configs.inter_bw) 20 | 21 | reverse_comm_bws = (1 / np.array(comm_bws) * 1e10).tolist() 22 | 23 | configs.specs = [tensor_cores, comm_bws, comm_bws_dict, reverse_comm_bws] 24 | configs.device_machine_map = device_machine_map 25 | configs.devices = devices 26 | 27 | print("Scheduling Input Log============================================================") 28 | print("Machines:", machines) 29 | print("Total GPUs:", len(devices)) 30 | print(f"Inter Bandwidth: {configs.inter_bw} GB/s", ) 31 | 32 | 33 | print("=" * 80) 34 | 35 | return devices 36 | 37 | update_configs(configs) 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /experimental/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | # from device import Device 4 | 5 | class Graph: 6 | def __init__(self, adj, xadj, eweights, vweights, options=None): 7 | self.adj, self.xadj, self.eweights, self.vweights, self.options = adj, xadj, eweights, vweights, options 8 | 9 | self._check() 10 | 11 | def __len__(self): 12 | return len(self.xadj) - 1 13 | 14 | def _check(self): 15 | assert len(self.adj) == len(self.eweights) 16 | assert len(self.vweights) == len(self.xadj) - 1 17 | 18 | self.adj = np.array(self.adj) 19 | self.xadj = np.array(self.xadj) 20 | self.eweights = np.array(self.eweights) 21 | self.vweights = np.array(self.vweights) 22 | 23 | 24 | def construct_graph(tensor_cores, comm_bws): 25 | """ 26 | Assume the graph is connected. I.e. every value in communication bandwidth matrix is nonzero. 
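The returned Graph stores adjacency in CSR form: the neighbours of vertex j are adj[xadj[j]:xadj[j+1]], with the matching edge weights in eweights and per-vertex weights (tensor cores) in vweights.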
27 | """ 28 | 29 | tensor_cores = np.around(tensor_cores, decimals=0).astype('int') 30 | comm_bws = np.around(comm_bws, decimals=0).astype('int') 31 | 32 | adj = np.array([[i for i in range(len(tensor_cores)) if i != j] for j in range(len(tensor_cores))]).flatten() 33 | xadj = [i for i in range(0, (comm_bws.shape[0] + 1) * comm_bws.shape[1], comm_bws.shape[1])] 34 | eweights = comm_bws.flatten() 35 | 36 | 37 | G = Graph(adj=adj, xadj=xadj, eweights=eweights, vweights=tensor_cores, options=None) 38 | 39 | return G 40 | 41 | 42 | -------------------------------------------------------------------------------- /experimental/initialize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def initialize(param, size): 4 | K = np.random.binomial(param[0], param[1], size=size) + 1 5 | K = K[0] 6 | 7 | return K 8 | 9 | 10 | -------------------------------------------------------------------------------- /experimental/layer_partition.py: -------------------------------------------------------------------------------- 1 | from cost_modeling import MemoryCost 2 | import numpy as np 3 | from globals import configs 4 | 5 | def adjust_by_memory(quota, memories, strategy): 6 | global configs 7 | 8 | memory_check = [] 9 | memory_view = [] 10 | 11 | for i in range(len(quota)): 12 | flags = [1 if i == 0 else 0, 1 if i == len(quota) - 1 else 0 , 1 if i == len(quota) - 1 else 0] 13 | mem_utils = MemoryCost(device_memory=memories[i], layers=quota[i], stage_strategy=strategy[i], configs=configs, flags=flags) 14 | 15 | memory_check.append(mem_utils.if_oom()) 16 | memory_view.append(mem_utils.overall_memory()) 17 | 18 | if sum(memory_check) : 19 | return None, None 20 | 21 | return quota, memory_view 22 | 23 | 24 | def create_layer_partition(strategy): 25 | """ 26 | Assume layers are distributed proportionally, remaining layers are not so many 27 | """ 28 | global configs 29 | 30 | memories = [sum([configs.devices[strategy[i][j]].memory for j in range(len(strategy[i]))]) 31 | for i in range(len(strategy))] 32 | 33 | ngpus = len(memories) 34 | 35 | # machines = [0, 0, 1, 1, 2] 36 | # memories = [16, 16, 8, 8, 40] 37 | 38 | memory_ranking = np.argsort(memories)[::-1] 39 | 40 | quota = [0 for _ in range(ngpus)] 41 | 42 | 43 | i = 0 44 | remainings = configs.L 45 | overload = [False for _ in range(ngpus)] 46 | rank_ptr = 0 47 | while remainings: 48 | overload[i % ngpus] = (quota[i % ngpus] + 1) / configs.L > memories[i % len(memories) ] / sum(memories) 49 | if sum(overload) != ngpus and overload[i % ngpus]: 50 | i += 1 51 | elif sum(overload) == ngpus: 52 | quota[memory_ranking[rank_ptr % len(memory_ranking)] % ngpus] += 1 53 | rank_ptr += 1 54 | remainings -= 1 55 | i += 1 56 | else: 57 | quota[i % ngpus] += 1 58 | remainings -= 1 59 | i += 1 60 | 61 | 62 | quota, memory_view = adjust_by_memory(quota, memories, strategy) 63 | 64 | return quota, memory_view 65 | -------------------------------------------------------------------------------- /experimental/machine_amounts.json: -------------------------------------------------------------------------------- 1 | { 2 | "baseline_machines": {"name": "A800", "tensor_core": 312, "memory": 80, "ngpus": 8, "n_same_machine": 4}, 3 | "machine_specs" : { 4 | "A4000": [76.7, 24, 64], 5 | "A5000": [ 111.1, 24, 64], 6 | "A6000": [154.8, 48, 64], 7 | "L40": [181, 48, 64], 8 | "A100": [312, 80, 200], 9 | "H100": [1513, 80, 128], 10 | "H100-4": [1513, 80, 400], 11 | "A10": [250, 24, 32], 12 | "L4": [242, 24, 32], 13 | 
"3080Ti": [56.9, 12, 20 ], 14 | "3080Ti-4": [56.9, 12, 20 ], 15 | "3090": [80, 24, 2], 16 | "4090": [165.2, 24, 2], 17 | "4090-4": [165.2, 24, 20], 18 | "T4S": [65, 16, 32], 19 | "T4A": [65, 8, 32] 20 | 21 | }, 22 | "machine_amounts" : { 23 | "A100": {"3":1}, 24 | "4090": {"3":1}, 25 | "3090": {"2":1} 26 | } 27 | } -------------------------------------------------------------------------------- /experimental/main.py: -------------------------------------------------------------------------------- 1 | from partitioner import * 2 | from graph import * 3 | 4 | from evaluate import * 5 | import time 6 | 7 | args = get_args() 8 | start = time.time() 9 | 10 | npipeline = configs.npipeline 11 | 12 | # construct with reverse of bandwidth 13 | G = construct_graph(configs.specs[0], configs.specs[3]) 14 | G.options = configs.options 15 | 16 | if args.apply_random_strategy: 17 | parts = np.random.randint(0, npipeline, size=len(configs.devices)) 18 | else: 19 | parts = partitioner(G, npipeline) 20 | print("Initial partition results:", parts) 21 | 22 | # reconstruct with bandwidth 23 | G = construct_graph(configs.specs[0], configs.specs[1]) 24 | G.options = configs.options 25 | 26 | next_parts = None 27 | 28 | optimal = None 29 | optimal_npipeline = None 30 | 31 | for i in range(configs.niter): 32 | 33 | if i % args.log_interval == 0: 34 | print(npipeline) 35 | print(f"{i}-th iteration",) 36 | all_pipelines, all_sub_recovered_parts = partition_pipeline(G, parts, npipeline, i) 37 | if next_parts is not None: 38 | next_all_pipelines, next_all_sub_recovered_parts = partition_pipeline(G, next_parts, npipeline, i) 39 | print("Throughput", throughput(next_all_pipelines), throughput(all_pipelines)) 40 | if throughput(next_all_pipelines) > throughput(all_pipelines): 41 | parts = next_parts 42 | all_sub_recovered_parts = next_all_sub_recovered_parts 43 | 44 | if throughput(all_pipelines) > throughput(optimal): 45 | 46 | optimal = all_pipelines 47 | optimal_npipeline = npipeline 48 | 49 | up = np.random.randint(0, 2) 50 | 51 | if up: 52 | npipeline = min(len(configs.devices) // 5, npipeline + 1) 53 | else: 54 | npipeline = max(npipeline - 1, len(configs.devices) // configs.L + 1) 55 | 56 | configs.K = initialize(configs.param, (1, npipeline)) 57 | 58 | # construct with reverse of bandwidth 59 | G = construct_graph(configs.specs[0], configs.specs[3]) 60 | G.options = configs.options 61 | 62 | if args.apply_random_strategy: 63 | parts = np.random.randint(0, npipeline, size=len(configs.devices)) 64 | else: 65 | parts = partitioner(G, npipeline) 66 | 67 | try: 68 | assert npipeline == max(parts) + 1 69 | except AssertionError: 70 | print(npipeline) 71 | exit(0) 72 | # reconstruct with bandwidth 73 | G = construct_graph(configs.specs[0], configs.specs[1]) 74 | G.options = configs.options 75 | 76 | 77 | 78 | print("Output Log============================================================") 79 | if optimal is None: 80 | print("Some machines will OOM, failed") 81 | exit(0) 82 | 83 | optimal_simulation = TimeCost(optimal, configs) 84 | MFU = round(optimal_simulation.mfu() * 100, 3) if optimal else 0 85 | 86 | 87 | 88 | print(f"Optimal Throughput: {round(throughput(optimal), 3)}", ) 89 | print(f"Optimal MFU: {MFU }%", ) 90 | print(f'Optimal time: {round(optimal_simulation.overall_cost(), 3)}', ) 91 | print(f"N-pipeline: {optimal_npipeline}") 92 | 93 | 94 | # For detailed logs: 95 | if args.verbose: 96 | print("Optimal Placement:", ) 97 | for id, pipeline in enumerate(optimal): 98 | print(f" {id}-th pipeline: {[len(stage) for 
stage in pipeline[0]]}", '\n', 99 | f" - devices: {[[configs.devices[device_id].machine_id for device_id in stage] for stage in pipeline[0]]}", '\n', 100 | f" - devices name: {[[configs.devices[device_id].name for device_id in stage] for stage in pipeline[0]]}", '\n', 101 | f" - layer partitions: {pipeline[1]}", '\n', 102 | f" - memory estimation: {pipeline[2]}", '\n', 103 | ) 104 | 105 | end = time.time() 106 | print("=" * 80) 107 | print("Consumed Time(s):", round(end-start, 3)) 108 | -------------------------------------------------------------------------------- /experimental/requirements.txt: -------------------------------------------------------------------------------- 1 | contourpy==1.2.0 2 | cycler==0.12.1 3 | fonttools==4.48.1 4 | kiwisolver==1.4.5 5 | matplotlib==3.8.2 6 | metis==0.2a5 7 | networkx==3.2.1 8 | numpy==1.26.4 9 | packaging==23.2 10 | pillow==10.2.0 11 | PyMetis==2023.1.1 12 | pyparsing==3.1.1 13 | python-dateutil==2.8.2 14 | six==1.16.0 15 | -------------------------------------------------------------------------------- /experimental/scripts/run_flops_compute.sh: -------------------------------------------------------------------------------- 1 | python3 dRAM_flops_estimation.py \ 2 | --machine_config_path machine_amounts.json -------------------------------------------------------------------------------- /experimental/scripts/run_mfu.sh: -------------------------------------------------------------------------------- 1 | python3 compute_mfu.py \ 2 | --model-size llama-70b \ 3 | --niter 1 \ 4 | --accum-iter 3 \ 5 | --global_bsz 16 \ 6 | --estimate_total_layers 80 \ 7 | --actual_running_time 82.96209 \ 8 | --machine_config_path machine_amounts.json \ 9 | -------------------------------------------------------------------------------- /experimental/scripts/run_scheduling.sh: -------------------------------------------------------------------------------- 1 | python3 main.py \ 2 | --model-size llama-7b \ 3 | --npipeline 2 \ 4 | --inter_bw 0.5 \ 5 | --global_bsz 24 \ 6 | --micro_bsz 2 \ 7 | --machine_config_path machine_amounts.json \ 8 | --log_interval 20 \ 9 | --niter 40 \ 10 | --verbose true \ 11 | # --not_use_tp \ 12 | # --apply_random_strategy \ 13 | -------------------------------------------------------------------------------- /experimental/scripts/run_simulation.sh: -------------------------------------------------------------------------------- 1 | python3 simulation.py \ 2 | --model-size llama-13b \ 3 | --machine_config_path machine_amounts.json \ 4 | --niter 1 \ 5 | --npipeline 1 \ 6 | --inter_bw 1 \ 7 | --accum-iter 1 \ 8 | --global_bsz 10 \ 9 | --micro_bsz 1 \ 10 | --estimate_total_layers 40 \ 11 | --estimate_strategy '[[1,1,1,1]]' \ 12 | --estimate_layer_partition '[[25,5,5,5],]' \ 13 | --strategy_device_ids '[[[2], [3], [5], [6]], ]' \ -------------------------------------------------------------------------------- /experimental/simulation.py: -------------------------------------------------------------------------------- 1 | from cost_modeling import TimeCost, MemoryCost 2 | from globals import configs 3 | from arguments import get_args 4 | 5 | 6 | args = get_args() 7 | 8 | configs.L = args.estimate_total_layers 9 | 10 | strategies = args.estimate_strategy 11 | layer_partitions = args.estimate_layer_partition 12 | 13 | if args.strategy_device_ids is not None: 14 | strategy_device_ids = eval(args.strategy_device_ids) 15 | 16 | print(layer_partitions, configs.L) 17 | assert sum(layer_partitions[0]) == configs.L 18 | 19 | 20 | all_pipelines = [] 21 | 
all_pipelines_mem = [] 22 | if args.strategy_device_ids is None: 23 | rank = 0 24 | for strategy, layer_partition in zip(strategies, layer_partitions): 25 | strategy_device_ids = [] 26 | pipeline_mem = [] 27 | for stage_idx, (stage_length, nlayers) in enumerate(zip(strategy, layer_partition)): 28 | stage_device_ids = [] 29 | for _ in range(stage_length): 30 | stage_device_ids.append(rank) 31 | rank += 1 32 | strategy_device_ids.append(stage_device_ids) 33 | flags = [1 if stage_idx == 0 else 0] + [1 if stage_idx == len(strategy) - 1 else 0] * 2 34 | stage_mem = MemoryCost(device_memory=None, layers=nlayers, stage_strategy=strategy, configs=configs, flags=flags) 35 | pipeline_mem.append(stage_mem) 36 | all_pipelines.append([strategy_device_ids, layer_partition, '']) 37 | all_pipelines_mem.append(pipeline_mem) 38 | else: 39 | for strategy_device_id, layer_partition in zip(strategy_device_ids, layer_partitions): 40 | all_pipelines.append([strategy_device_id, layer_partition, '']) 41 | 42 | 43 | print(all_pipelines) 44 | # exit(0) 45 | strategy_cost = TimeCost(all_pipelines=all_pipelines, configs=configs) 46 | print("Estimation Input Log============================================================") 47 | print("All pipelines (strategy, layer partition, memory estimation):", all_pipelines) 48 | print("Time Cost Estimation", "=" * 59) 49 | print(f"Batch size per pipeline: {configs.B}, Micro batch size per pipeline: {configs.MB}") 50 | print("Time Cost:", strategy_cost.overall_cost()) 51 | print(f"DP time cost: {strategy_cost.dp_cost()}") 52 | print(f"Processed tokens: {configs.T}") 53 | print("Throughput:", strategy_cost.token_throughput()) 54 | print(f"MFU: {round(strategy_cost.mfu() * 100, 3)}%", ) 55 | 56 | if args.actual_running_time: 57 | print(f"Using provided running time, computed MFU: {round(strategy_cost.mfu(args.actual_running_time) * 100, 3)}%", ) 58 | 59 | 60 | if not args.estimate_all: 61 | exit(0) 62 | print("Mem Cost Estimation", "=" * 60) 63 | 64 | for pp_id, (pipeline_mem, strategy) in enumerate(zip(all_pipelines_mem, strategies)): 65 | for stage_id, (mem, stage_length) in enumerate(zip(pipeline_mem, strategy)): 66 | print(f"For {pp_id}-th pipeline, {stage_id}-th stage (pipeline strategy: {strategy}), \n estimeated overall memory cost: {mem.overall_memory() * stage_length },") 67 | if args.verbose: 68 | print(f"parameter memory cost: {mem.param_memory()}, activation memory: {mem.activation_memory(recompute=True)}, per device memory cost: {mem.overall_memory(recompute=True)}") 69 | 70 | -------------------------------------------------------------------------------- /experimental/sort_pipelines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def sort_all_pipelines(all_pipelines): 5 | for pipeline in all_pipelines: 6 | device_ids = [stage[0] for stage in pipeline[0]] 7 | sort_idxs = np.argsort(device_ids) 8 | 9 | for i in range(1, 3): 10 | pipeline[i] = np.take(pipeline[i], sort_idxs, axis=0).tolist() -------------------------------------------------------------------------------- /experimental/stage_order.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from graph import Graph 3 | from typing import List 4 | from copy import deepcopy 5 | from globals import configs 6 | import itertools 7 | 8 | def get_stage_order(G: Graph): 9 | """ 10 | Given a partitioned and merged graph, find the best order among stages 11 | """ 12 | 13 | if G is None: 14 | 
return [0] 15 | 16 | if len(G) == 1: 17 | return [1] 18 | 19 | def optimal_path_given_start(next, kway=2): 20 | 21 | def search_path(next, distance_selection): 22 | 23 | path_cum_eweights = 0 24 | path = [] 25 | path.append(next) 26 | selection = 0 27 | 28 | while len(path) != len(G.xadj) - 1: 29 | adjacents = G.adj[G.xadj[next]: G.xadj[next + 1]].tolist() 30 | connection_weight : List = deepcopy(G.eweights[G.xadj[next]: G.xadj[next + 1]].tolist()) 31 | 32 | for node in path: 33 | if node in adjacents: 34 | connection_weight[adjacents.index(node)] = 0 35 | 36 | distance_rank = np.argsort(connection_weight)[::-1] 37 | 38 | 39 | selected_adj_node_id = distance_rank[distance_selection[selection]] 40 | offset = 1 41 | while connection_weight[selected_adj_node_id] == 0 and offset <= len(connection_weight): 42 | 43 | selected_adj_node_id = distance_rank[(distance_selection[selection] + offset) % len(distance_rank) ] 44 | offset += 1 45 | 46 | if connection_weight[selected_adj_node_id] == 0: 47 | return None 48 | 49 | selection += 1 50 | next = adjacents[selected_adj_node_id] 51 | path_cum_eweights += connection_weight[selected_adj_node_id] 52 | 53 | path.append(next) 54 | 55 | return path, path_cum_eweights 56 | 57 | kway = min(kway, len(G) - 1) 58 | 59 | possible_distance_selection = [item for item in itertools.product([i for i in range(kway)], repeat=len(G)-1)] 60 | 61 | possible_paths = [search_path(next, distance_selection) for distance_selection in possible_distance_selection] 62 | 63 | cum_weights = [possible_path[0] for possible_path in possible_paths if possible_path is not None] 64 | 65 | 66 | optimal_path = possible_paths[cum_weights.index(max(cum_weights))] 67 | 68 | return optimal_path 69 | 70 | cumulated_eweight = [] 71 | 72 | paths = [] 73 | 74 | for start in range(len(G)): 75 | next = start 76 | 77 | optimal_path = optimal_path_given_start(next, configs.kway) 78 | if optimal_path is not None: 79 | paths.append(optimal_path[0]) 80 | cumulated_eweight.append(optimal_path[1]) 81 | 82 | optimal_path = paths[np.argsort(cumulated_eweight)[0]] 83 | 84 | return optimal_path -------------------------------------------------------------------------------- /experimental/strategy.py: -------------------------------------------------------------------------------- 1 | from globals import configs 2 | import numpy as np 3 | from cost_modeling import MemoryCost, TimeCost 4 | from typing import List 5 | from arguments import get_args 6 | from copy import deepcopy 7 | 8 | def get_considered_stragegies(i): 9 | """ 10 | Do not consider stages over 3 11 | """ 12 | args = get_args() 13 | considered_strategies = {} 14 | if args.not_use_tp: 15 | 16 | considered_strategies[1] = [[1]] 17 | considered_strategies[2] = [[1, 1]] 18 | considered_strategies[3] = [ [1, 1, 1]] 19 | considered_strategies[4] = [[1, 1, 1, 1]] 20 | considered_strategies[5] = [[1, 1, 1, 1, 1]] 21 | considered_strategies[6] = [[1, 1, 1, 1, 1, 1]] 22 | considered_strategies[7] = [[1, 1, 1, 1, 1, 1, 1]] 23 | considered_strategies[8] = [[1, 1, 1, 1, 1, 1, 1, 1] ] 24 | 25 | else: 26 | considered_strategies[1] = [[1]] 27 | considered_strategies[2] = [[2], [1, 1]] 28 | considered_strategies[3] = [[2, 1], [1, 1, 1]] 29 | considered_strategies[4] = [[4], [2, 2], [2, 1, 1]] 30 | considered_strategies[5] = [[4, 1], [2, 2, 1]] 31 | considered_strategies[6] = [[4, 2], [2, 2, 2]] 32 | considered_strategies[7] = [[4, 2, 1], [2, 2, 2, 1]] 33 | considered_strategies[8] = [[8], [4, 4], [4, 2, 2], ] 34 | 35 | return considered_strategies[i] 36 | 37 | def 
refine_strategies(strategies, pp_devices): 38 | 39 | global configs 40 | 41 | ngpu_strategy = [] 42 | 43 | def tp_first(): 44 | for possible_strategies in strategies: 45 | shortest = [0] * 100 46 | for possible_strategy in possible_strategies: 47 | 48 | shortest = possible_strategy if len(possible_strategy) < len(shortest) else shortest 49 | ngpu_strategy.extend(shortest) 50 | 51 | def pp_first(): 52 | for possible_strategies in strategies: 53 | shortest = [0] * 1 54 | for possible_strategy in possible_strategies: 55 | 56 | shortest = possible_strategy if len(possible_strategy) >= len(shortest) else shortest 57 | ngpu_strategy.extend(shortest) 58 | 59 | 60 | def tp_pp_tradeoff(): 61 | global configs 62 | 63 | offsets: List = np.cumsum([sum(possible_strategy[0]) for possible_strategy in strategies]).tolist() 64 | offsets.insert(0, 0) 65 | 66 | for i in range(len(strategies)): 67 | local_optimal = None 68 | local_minimum_cost = 1e8 69 | possible_strategies = strategies[i] 70 | for stage_strategy in possible_strategies: 71 | strategy_offsets: List = np.cumsum(stage_strategy).tolist() 72 | strategy_offsets.insert(0, 0) 73 | 74 | stage = [pp_devices[offsets[i]: offsets[i + 1]][strategy_offsets[j]: strategy_offsets[j + 1]] for j in range(len(strategy_offsets) - 1)] 75 | 76 | # simulate by fake layer partition 77 | stage_layer_partition = np.round([configs.L * ndevices // sum(stage_strategy) for ndevices in stage_strategy], decimals=0) 78 | stage_layer_partition[-1] -= sum(stage_layer_partition) - configs.L 79 | assert sum(stage_layer_partition) == configs.L 80 | 81 | strategy_cost = TimeCost(all_pipelines=[[stage, stage_layer_partition]], configs=configs) 82 | 83 | # update strategy_cost to 1e8 if oom 84 | if local_minimum_cost > strategy_cost.pipeline_cost(pp_id=0): 85 | local_minimum_cost = strategy_cost.pipeline_cost(pp_id=0) 86 | local_optimal = stage_strategy 87 | 88 | ngpu_strategy.extend(local_optimal) 89 | 90 | tp_pp_tradeoff() 91 | 92 | return ngpu_strategy 93 | 94 | 95 | def gen_strategy(recovered_parts, path): 96 | global configs 97 | 98 | recovered_parts = list(sorted(recovered_parts, key=lambda ele: path[recovered_parts.index(ele)])) 99 | 100 | # recovered_parts = [[1, 3, 6, 8, 9], [2, 4, 11, 12]] 101 | strategies = [] 102 | 103 | parts_machines = [[configs.device_machine_map[gpu] for gpu in gpus] for gpus in recovered_parts] 104 | intra_counts = [[parts_machine.count(i) for i in set(parts_machine)] 105 | for parts_machine in parts_machines] 106 | 107 | for stage_counts in intra_counts: 108 | for ngpus in stage_counts: 109 | strategies.append(get_considered_stragegies(ngpus)) 110 | 111 | 112 | # like [[0, 1], [2], [3]] 113 | flattened_parts = [recovered_parts[i][j] for i in range(len(recovered_parts)) for j in range(len(recovered_parts[i]))] 114 | 115 | # like [[2, 1], [1], [1]] 116 | ngpu_strategy = refine_strategies(strategies, flattened_parts) 117 | 118 | strategy = [] 119 | i = 0 120 | start = 0 121 | while i < len(ngpu_strategy): 122 | 123 | end = start + ngpu_strategy[i] 124 | strategy.append(flattened_parts[start : end]) 125 | 126 | start = end 127 | i += 1 128 | 129 | strategy_machine = [configs.device_machine_map[stage_device[0]] for stage_device in strategy] 130 | 131 | return strategy -------------------------------------------------------------------------------- /experimental/utils.py: -------------------------------------------------------------------------------- 1 | from device import Device 2 | import json 3 | from config import get_args 4 | 5 | 6 | def 
create_machines_list(): 7 | args = get_args() 8 | # hetero experiment 9 | with open(args.machine_config_path, 'r') as machine_config_file: 10 | machine_config = json.load(machine_config_file) 11 | machine_specs = machine_config['machine_specs'] 12 | 13 | 14 | # sublist is in the same type, sublist has three number, indicating n_same_machine of 2,4,8 gpus 15 | # machine_amounts = np.random.randint(1, 3, size=(len(machine_specs), 3)) 16 | machine_amounts = machine_config['machine_amounts'] 17 | 18 | ngpus = [2, 4, 8] 19 | 20 | machines = [] 21 | for name, machine_amount in machine_amounts.items(): 22 | spec = machine_specs[name] 23 | for ngpu, n in machine_amount.items(): 24 | if n == 0: 25 | continue 26 | machines.append({"name": name, "tensor_core": spec[0], "memory": spec[1], "intra_bw": spec[2], "ngpus": int(ngpu), "n_same_machine": n}) 27 | 28 | return machines 29 | 30 | 31 | def create_specs(devices, inter_bw): 32 | 33 | tensor_cores = [] # (n, ) 34 | for device in devices: 35 | tensor_cores.append(device.tensor_core * 1024*1024*1024*1024) 36 | 37 | comm_bws = [] # (n, n-1) 38 | comm_bws_dict = {} 39 | for i in range(len(devices)): 40 | comm_bw = [] 41 | for j in range(len(devices)): 42 | if i != j: 43 | bw = inter_bw if devices[i].machine_id != devices[j].machine_id else devices[i].intra_bw 44 | bw = bw * 1024*1024*1024 45 | comm_bw.append(bw) 46 | comm_bws_dict[i, j] = bw 47 | comm_bws.append(comm_bw) 48 | 49 | return tensor_cores, comm_bws, comm_bws_dict 50 | 51 | def create_device_machine_map(devices): 52 | machine_ids = [d.machine_id for d in devices] 53 | return machine_ids 54 | 55 | 56 | def create_devices(machines): 57 | devices = [] 58 | 59 | assigned_id = 0 60 | for machine in machines: 61 | for i in range(machine['n_same_machine']): 62 | for j in range(machine['ngpus']): 63 | devices.append(Device(name=machine['name'], machine_id=assigned_id, 64 | tensor_core=machine['tensor_core'], intra_bw=machine['intra_bw'], 65 | memory=machine['memory'], device_id=j + 1, machine_ngpus=machine['ngpus'])) 66 | assigned_id += 1 67 | return devices -------------------------------------------------------------------------------- /hexiscale/__init__.py: -------------------------------------------------------------------------------- 1 | from .comm_group import * 2 | from .gen_hetero_groups import * 3 | from .gen_p2p_lists import * 4 | from .wrap_modules import * 5 | from .heterogeneous_pipeline import * 6 | from .utils import * 7 | from .pipe_sequential import PipeSequential 8 | from .process_args import * -------------------------------------------------------------------------------- /hexiscale/comm_group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CommGroup(object): 4 | def __init__(self, ranks): 5 | assert isinstance(ranks, list) or isinstance(ranks, range), 'Rank list or range should be provided to create a CommGroup!' 
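# Note: torch.distributed.new_group is a collective over the default group, so every rank must construct the same CommGroups in the same order, even ranks that are not members of a given group.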
6 | self.ranks = sorted(list(set(list(ranks)))) 7 | self.size = len(self.ranks) 8 | self.group = torch.distributed.new_group(self.ranks) 9 | def has_rank(self, rank): 10 | if rank in self.ranks: 11 | self.intra_group_id = self.ranks.index(rank) 12 | return True 13 | return False 14 | 15 | def print(self): 16 | print(self.ranks, end = ' ') 17 | -------------------------------------------------------------------------------- /hexiscale/fsdp_hooks_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | from torch.distributed.fsdp._common_utils import _FSDPState 5 | from torch.distributed.fsdp._flat_param import FlatParamHandle 6 | from torch.distributed.fsdp._runtime_utils import _post_backward_hook 7 | import functools, weakref 8 | 9 | 10 | def pre_pipeline_forward(num_microbatches, idx, model): 11 | if num_microbatches > 1 and idx == 0: 12 | delete_ddp_backward_hook(model) 13 | 14 | 15 | def post_pipeline_forward(num_microbatches, idx, model, checkpoint_list): 16 | if num_microbatches > 1: 17 | if isinstance(model, FSDP): 18 | model = model._fsdp_wrapped_module 19 | assert(len(model)==len(checkpoint_list)) 20 | for module, checkpoint in zip(model, checkpoint_list): 21 | if idx == num_microbatches - 1: 22 | delete_fsdp_post_backward_hook(module, save_acc_grad=True, release_param=False) 23 | else: 24 | delete_fsdp_post_backward_hook(module) 25 | 26 | 27 | def pre_pipeline_backward(num_microbatches, idx, model, checkpoint_list): 28 | if num_microbatches > 1: 29 | if isinstance(model, FSDP): 30 | model = model._fsdp_wrapped_module 31 | assert(len(model)==len(checkpoint_list)) 32 | if idx == num_microbatches - 1: 33 | register_ddp_backward_hook(model) 34 | for module, checkpoint in zip(model, checkpoint_list): 35 | register_fsdp_post_backward_hook(module) 36 | 37 | 38 | def _register_post_backward_hooks_handle( 39 | state: _FSDPState, 40 | handle: FlatParamHandle, 41 | ) -> None: 42 | if not torch.is_grad_enabled(): 43 | return 44 | flat_param = handle.flat_param 45 | already_registered = hasattr(flat_param, "_post_backward_hook_state") 46 | if already_registered or not flat_param.requires_grad: 47 | return 48 | # Get the `AccumulateGrad` object 49 | acc_grad = handle.acc_grad # type: ignore[union-attr] 50 | assert acc_grad is not None 51 | hook_handle = acc_grad.register_hook( 52 | functools.partial(_post_backward_hook, state, handle) 53 | ) 54 | flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined] 55 | 56 | 57 | def delete_fsdp_post_backward_hook(model, save_acc_grad=False, release_param=True): 58 | for m in model.modules(): 59 | if isinstance(m, FSDP): 60 | handles = m._handle if hasattr(m, '_handle') else m._handles 61 | 62 | 63 | if not isinstance(handles, list): 64 | handles = [handles] 65 | for handle in handles: 66 | # for handle in m._handle: 67 | flat_param = handle.flat_param 68 | if flat_param.requires_grad: 69 | if hasattr(flat_param, "_post_backward_hook_state"): 70 | 71 | if save_acc_grad: 72 | handle.acc_grad = flat_param._post_backward_hook_state[0] 73 | flat_param._post_backward_hook_state[1].remove() 74 | delattr(flat_param, "_post_backward_hook_state") # whether to reduce-scatter and release grad 75 | flat_param._post_backward_called = False 76 | if not release_param and m._is_root: 77 | m._post_backward_callback_queued = True # whether to release 
params, trades off an allgather between param memory 78 | 79 | 80 | def register_fsdp_post_backward_hook(model): 81 | for m in model.modules(): 82 | if isinstance(m, FSDP): 83 | handles = m._handle if hasattr(m, '_handle') else m._handles 84 | 85 | if not isinstance(handles, list): 86 | handles = [handles] 87 | for handle in handles: 88 | 89 | _register_post_backward_hooks_handle(m, handle) 90 | # if m._is_root: 91 | # m.training_state = TrainingState.IDLE 92 | m._post_backward_callback_queued = False # need to wait for post backward 93 | 94 | 95 | def delete_ddp_backward_hook(model): 96 | for m in model.modules(): 97 | # For DDP module, we need to disable gradient sync for accumulation, 98 | # and set sync manually before backward of the last microbatch. 99 | if isinstance(m, DDP): 100 | m.require_backward_grad_sync = False 101 | 102 | 103 | def register_ddp_backward_hook(model): 104 | for m in model.modules(): 105 | # For DDP module, we need to disable gradient sync for accumulation, 106 | # and set sync manually before backward of the last microbatch. 107 | if isinstance(m, DDP): 108 | m.require_forward_param_sync = True 109 | m.reducer.prepare_for_backward([]) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /hexiscale/gen_p2p_lists.py: -------------------------------------------------------------------------------- 1 | 2 | def generate_send_recv_lists(pipeline_groups, mainline, forward_backward=False): 3 | 4 | # initialize empty send and receive lists for each rank 5 | ranks = set(rank for group in pipeline_groups for rank in group) 6 | 7 | def initialize_lists(): 8 | SendList = {rank: [] for rank in ranks} 9 | RecvList = {rank: [] for rank in ranks} 10 | SendBoolean = {rank: [] for rank in ranks} 11 | RecvBoolean = {rank: [] for rank in ranks} 12 | return SendList, RecvList, SendBoolean, RecvBoolean 13 | 14 | forward_lists = initialize_lists() 15 | backward_lists = initialize_lists() 16 | 17 | def send_append(idx_from, idx_to, p2p_lists): 18 | SendList, SendBoolean = p2p_lists[0], p2p_lists[2] 19 | # Avoid appending duplicates 20 | if group[idx_to] not in SendList[group[idx_from]]: 21 | SendList[group[idx_from]].append(group[idx_to]) 22 | SendBoolean[group[idx_from]].append(not is_mainline) 23 | 24 | def recv_append(idx_from, idx_to, p2p_lists ): 25 | RecvList, RecvBoolean = p2p_lists[1], p2p_lists[3] 26 | # Avoid appending duplicates 27 | if group[idx_from] not in RecvList[group[idx_to]]: 28 | RecvList[group[idx_to]].append(group[idx_from]) 29 | RecvBoolean[group[idx_to]].append(not is_mainline) 30 | 31 | # fill up send and receive lists based on pipeline groups 32 | for group in pipeline_groups: 33 | is_mainline = set(group) == set(mainline) 34 | for i in range(len(group) - 1): 35 | 36 | send_append(i, i + 1, forward_lists) 37 | recv_append(i, i + 1, forward_lists) 38 | 39 | 40 | if forward_backward: 41 | for group in pipeline_groups: 42 | is_mainline = set(group) == set(mainline) 43 | for i in range(len(group) - 1): 44 | send_append(i + 1, i, backward_lists) 45 | recv_append(i + 1, i, backward_lists) 46 | return forward_lists, backward_lists 47 | else: 48 | return forward_lists, None 49 | 50 | -------------------------------------------------------------------------------- /hexiscale/pipe_sequential.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Tuple 3 | 4 | class PipeSequential(nn.Sequential): 5 | """ 6 | Pipe variant of ``nn.Sequential`` 
which supports multiple inputs. 7 | """ 8 | 9 | def forward(self, *inputs): 10 | for module in self: 11 | if isinstance(inputs, Tuple): # type: ignore[arg-type] 12 | inputs = module(*inputs) 13 | else: 14 | # Don't expand single variables (ex: lists/Tensor) 15 | inputs = module(inputs) 16 | return inputs -------------------------------------------------------------------------------- /hexiscale/process_args.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import json 4 | 5 | def update_args(args, config): 6 | 7 | 8 | assert int(args.all_hetero_configs_path is None) + int(args.hetero_configs is None) == 1, "Choose one way to start the program" 9 | 10 | if args.all_hetero_configs_path is not None: 11 | with open(args.all_hetero_configs_path, 'r') as all_hetero_configs_file: 12 | all_hetero_configs = json.load(all_hetero_configs_file) 13 | args.hetero_configs = all_hetero_configs['hetero_configs'] 14 | args.layer_partitions = all_hetero_configs['layer_partitions'] 15 | 16 | if 'pp_layouts' in all_hetero_configs.keys(): 17 | args.pp_layouts = all_hetero_configs['pp_layouts'] 18 | else: 19 | args.layer_partitions = eval(args.layer_partitions) 20 | args.hetero_configs = eval(args.hetero_configs) 21 | if args.pp_layouts is not None: 22 | args.pp_layouts = eval(args.pp_layouts) 23 | 24 | args.lr = args.learning_rate 25 | 26 | args.seq_length = 4096 27 | config.n_positions = args.seq_length 28 | args.num_hidden_layers = args.total_layer_num 29 | config.n_layer = args.num_hidden_layers 30 | 31 | def validate_args(args, config): 32 | hetero_configs = args.hetero_configs 33 | layer_partitions = args.layer_partitions 34 | total_layer_num = args.total_layer_num 35 | global_bsz_size = args.global_bsz_size 36 | chunks = args.chunks 37 | 38 | world_size = dist.get_world_size() 39 | 40 | num_heads = config.n_head 41 | 42 | for pipeline in hetero_configs: 43 | for tp_size in pipeline: 44 | assert num_heads % tp_size == 0, "Num heads must be divisible by tp size" 45 | 46 | 47 | assert world_size == sum([sum(pipeline_ranks) for pipeline_ranks in hetero_configs]), 'Wrong hetero configs' 48 | for i, layer_partition in enumerate(layer_partitions): 49 | assert total_layer_num == sum(layer_partition), f'Wrong layer partition for pipeline {i}' 50 | 51 | length_gap = len(hetero_configs) - len(global_bsz_size) 52 | if length_gap > 0: 53 | padding_global_bsz = args.global_bsz_size[-1] 54 | padding_chunks = args.chunks[-1] 55 | 56 | for _ in range(length_gap): 57 | args.global_bsz_size.append(padding_global_bsz) 58 | args.chunks.append(padding_chunks) 59 | 60 | assert len(global_bsz_size) == len(chunks) == len(hetero_configs), 'Wrong length of globl batch size or chunks, should be the same as number of pipelines' 61 | 62 | assert args.run_iter >= args.accum_iter, 'Not enough iterations for one gradient accumulation cycle' 63 | 64 | # if args.checkpoint_layers: 65 | # assert args.checkpoint_all, "Currently only support activation recompute on all layers" 66 | 67 | 68 | for hetero_config, layer_partition in zip(hetero_configs, layer_partitions): 69 | assert len(hetero_config) == len(layer_partition), "Hetero config should have the same length as layer partition" 70 | 71 | 72 | if args.pp_layouts is not None: 73 | sorted_pp_layouts = [] 74 | for pp_layout, hetero_config in zip(args.pp_layouts, hetero_configs): 75 | sorted_pp_layout = [] 76 | for stage_layout, stage_length in zip(pp_layout, hetero_config): 77 | assert 
len(stage_layout) == stage_length, "Wrong pp_layouts, each stage should contain correct number of devices based on hetero_configs" 78 | sorted_pp_layout.append(list(sorted(stage_layout))) 79 | sorted_pp_layouts.append(sorted_pp_layout) 80 | 81 | args.pp_layouts = sorted_pp_layouts 82 | 83 | -------------------------------------------------------------------------------- /hexiscale/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import torch 3 | from torch.nn.parallel import DistributedDataParallel 4 | 5 | 6 | def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]: 7 | if isinstance(model, list): 8 | return model 9 | return [model] 10 | 11 | 12 | def chunk_batch(inputs, chunks): 13 | 14 | if inputs is None: 15 | return inputs 16 | 17 | batches = [[] for _ in range(chunks)] 18 | # Actual number of chunks produced 19 | num_chunks = -1 20 | for input in inputs: 21 | if torch.is_tensor(input): 22 | # Chunk only tensors. 23 | tensors = input.chunk(chunks) 24 | 25 | # Validate number of chunks equal across all inputs. 26 | if num_chunks != -1 and num_chunks != len(tensors): 27 | raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}') 28 | num_chunks = len(tensors) 29 | 30 | for i, tensor in enumerate(tensors): 31 | batches[i].append(tensor) 32 | else: 33 | # Replicate non-tensors or tensors wrapped with 'NoChunk'. 34 | for i in range(chunks): 35 | batches[i].append(input) 36 | 37 | # Truncate to actual number of chunks 38 | batches = batches[:num_chunks] 39 | 40 | return batches 41 | 42 | 43 | def unwrap_model(model, module_instances=(DistributedDataParallel,)): 44 | return_list = True 45 | if not isinstance(model, list): 46 | model = [model] 47 | return_list = False 48 | unwrapped_model = [] 49 | for model_module in model: 50 | while isinstance(model_module, module_instances): 51 | model_module = model_module.module 52 | unwrapped_model.append(model_module) 53 | if not return_list: 54 | return unwrapped_model[0] 55 | return unwrapped_model -------------------------------------------------------------------------------- /llama/__init__.py: -------------------------------------------------------------------------------- 1 | from .arguments import add_arguments, get_hetero_groups, set_hetero_groups 2 | from .modules.hybrid_parallel_model_dist import get_hybrid_parallel_configs, construct_hybrid_parallel_model, overwrite_megatron_args 3 | from .llama_config_utils import llama_config_to_gpt2_config, config_from_checkpoint, overwrite_configs_and_args 4 | from .load_model_parameters_utils.load_model_parameters import load_model_parameters 5 | from .sudo_dataset import DatasetForLlama 6 | from .save_checkpoint import save_checkpoint 7 | -------------------------------------------------------------------------------- /llama/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def add_arguments(parser: argparse.ArgumentParser): 4 | group = parser.add_argument_group(title='hexgen arguments') 5 | 6 | # hetro parallelism arguments 7 | group.add_argument( 8 | "--local-rank", type=int, default=-1, help="Local rank.", 9 | ) 10 | parser.add_argument( 11 | "--model_size", type=str, default='llama-7b', help="Model size.", choices=['llama-7b', 'llama-13b', 'llama-30b', 'llama-70b'] 12 | ) 13 | parser.add_argument( 14 | "--overwrite_config", type=int, default=0, 
help="Whether to overwrite model config" 15 | ) 16 | group.add_argument( 17 | "--initialize_on_meta", type=int, default=1, help="Whether to initialize parameters on meta device.", choices=[0, 1] 18 | ) 19 | group.add_argument( 20 | "--hidden_size", type=int, default=768, help="Hidden size of transformer model", 21 | ) 22 | group.add_argument( 23 | "--num_hidden_layers", type=int, default=12, help="Number of layers" 24 | ) 25 | group.add_argument( 26 | "-a", 27 | "--num_attention_heads", 28 | type=int, 29 | default=12, 30 | help="Number of attention heads", 31 | ) 32 | group.add_argument( 33 | "--vocab_size", type=int, default=30522, help="Total number of vocab" 34 | ) 35 | group.add_argument( 36 | "--dropout_prob", type=float, default=0.1, help="Dropout rate." 37 | ) 38 | parser.add_argument( 39 | "--mixed_precision", type=str, default='fp16', help="Mixed precision option.", choices=['fp32', 'fp16', 'bf16'], 40 | ) 41 | parser.add_argument( 42 | "--hetero_config", type=int, nargs='+', default=0, help="Give and execute heterogeneous configuration", 43 | ) 44 | 45 | parser.add_argument( 46 | "--pp_layouts", type=str, default=None, help="Give rank layouts", 47 | ) 48 | parser.add_argument( 49 | "--hetero_configs", type=str, default=None, help="Give pipeline layouts in general", 50 | ) 51 | parser.add_argument( 52 | "--layer_partitions", type=str, default=None, help="Give layer partition layouts", 53 | ) 54 | parser.add_argument( 55 | "--default_dp_type", type=str, default=None, help="Default data parallel type", choices=["ddp","zero2","zero3"], 56 | ) 57 | parser.add_argument( 58 | "--pp_partition", type=int, nargs='+', default=0, help="Give and execute pipeline configuration", 59 | ) 60 | parser.add_argument( 61 | "--total-layer-num", type=int, default=0, help='Total transformer layers to be trained', 62 | ) 63 | parser.add_argument( 64 | "--checkpoint-layers", action='store_true', help='Whether apply activation recompute on each transformer layer' 65 | ) 66 | parser.add_argument( 67 | "--checkpoint-all", action='store_true', help='Whether apply activation recompute on embedding, prenorm, cls layers' 68 | ) 69 | parser.add_argument( 70 | "--seq-parallel", action='store_true', help='Whether apply sequence parallel' 71 | ) 72 | parser.add_argument( 73 | "--accum-iter", type=int, default=1, help='Gradient accumulation cycles', 74 | ) 75 | 76 | # utils arguments 77 | parser.add_argument( 78 | "--token", type=str, default='', help="Access token to gated models", 79 | ) 80 | 81 | # training arguments 82 | parser.add_argument( 83 | "--global_bsz_size", type=int, nargs='+', default=2, help="global_bsz_size", 84 | ) 85 | parser.add_argument( 86 | "--chunks", type=int, nargs="+", default=0, help="Each pipeline chunk num", 87 | ) 88 | parser.add_argument( 89 | "--epochs", type=int, default=1, help="Training epochs", 90 | ) 91 | parser.add_argument( 92 | "--learning_rate", type=float, default=1e-4, help="Learning rate", 93 | ) 94 | group.add_argument( 95 | "-s", "--seq_length", type=int, default=128, help="Maximum sequence len" 96 | ) 97 | parser.add_argument( 98 | "--fsdp", type=int, default=1, help="Apply FSDP", choices=[0, 1], 99 | ) 100 | parser.add_argument( 101 | "--apply_strategy", type=int, default=0, help="Apply searched strategy.", choices=[0, 1], 102 | ) 103 | 104 | parser.add_argument('--profile', action='store_true', help='Enable time profiling') 105 | parser.add_argument('--profile-mem', action='store_true', help='Enable memory profiling') 106 | parser.add_argument('--run-iter', type=int, 
default=20) 107 | parser.add_argument('--load-params', action='store_true') 108 | parser.add_argument('--optimizer-type', type=str, default='adam') 109 | parser.add_argument('--display_one_pipeline', action='store_true') 110 | parser.add_argument('--all_hetero_configs_path', type=str, default=None, help="JSON file path to hetero configs") 111 | parser.add_argument('--pipeline-type', type=str, default="Gpipe", help="JSON file path to hetero configs") 112 | parser.add_argument('--recompute-stage-output', action='store_true', help="Whether to store stage output for backward") 113 | 114 | return parser 115 | 116 | 117 | _HETERO_GROUPS = None 118 | def get_hetero_groups(): 119 | global _HETERO_GROUPS 120 | return _HETERO_GROUPS 121 | 122 | def set_hetero_groups(hetero_groups): 123 | global _HETERO_GROUPS 124 | _HETERO_GROUPS = hetero_groups 125 | -------------------------------------------------------------------------------- /llama/llama-config/llama-13b/llama-13b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /llama/llama-config/llama-13b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /llama/llama-config/llama-30b/llama-30b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /llama/llama-config/llama-30b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /llama/llama-config/llama-70b/llama-70b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /llama/llama-config/llama-70b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} 2 | -------------------------------------------------------------------------------- /llama/llama-config/llama-7b/llama-7b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /llama/llama-config/llama-7b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /llama/llama_config_utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import json 4 | import os 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | from transformers import GPT2Config, LlamaConfig 9 | 10 | def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig: 11 | """Load a LlamaConfig from a checkpoint path.""" 12 | 13 | 14 | with open(Path(checkpoint_path) / model_name / 'params.json') as f: 15 | params = json.load(f) 16 | config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None, 17 | num_attention_heads=params['n_heads'], 18 | num_hidden_layers=params['n_layers'], 19 | rms_norm_eps=params['norm_eps']) 20 | return config 21 | 22 | 23 | def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig: 24 | return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json") 25 | 26 | 27 | def config_from_checkpoint( 28 | checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta" 29 | ) -> LlamaConfig: 30 | if checkpoint_format == "meta": 31 | return config_from_meta_checkpoint(checkpoint_path, model_name) 32 | else: 33 | return config_from_hf_checkpoint(checkpoint_path, model_name) 34 | 35 | 36 | def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: 37 | return GPT2Config( 38 | vocab_size=llama_config.vocab_size, 39 | n_positions=llama_config.max_position_embeddings, 40 | # n_positions=0, 41 | n_embd=llama_config.hidden_size, 42 | n_layer=llama_config.num_hidden_layers, 43 | n_head=llama_config.num_attention_heads, 44 | n_inner=llama_config.intermediate_size, 45 | activation_function='swiglu', # Hardcode since HF calls it 'silu' 46 | # Llama doesn't have dropout, idk if it's because they only release the inference code 47 | resid_pdrop=0.0, 48 | embd_pdrop=0.0, 49 | attn_pdrop=0.0, 50 | layer_norm_epsilon=llama_config.rms_norm_eps, 51 | initializer_range=llama_config.initializer_range, 52 | bos_token_id=llama_config.bos_token_id, 53 | eos_token_id=llama_config.eos_token_id, 54 | # These are new arguments not in the original GPT2Config 55 | pad_token_id=llama_config.pad_token_id, # Idk if this does anything 56 | rms_norm=True, 57 | rotary_emb_fraction=1.0, 58 | rotary_emb_interleaved=True, 59 | tie_word_embeddings=False, 60 | qkv_proj_bias=False, 61 | out_proj_bias=False, 62 | mlp_fc1_bias=False, 63 | mlp_fc2_bias=False, 64 | ) 65 | 66 | def overwrite_configs_and_args(config, args): 67 | overwrite_config = {'use_cache': False, 68 | 'use_flash_attn': args.use_flash_attn, 69 | 'fused_bias_fc': True, 70 | 'sequence_parallel': args.seq_parallel} 71 | for key, val in overwrite_config.items(): 72 | setattr(config, key, val) 73 | 74 | if args.overwrite_config: 75 | overwrite_config = {'hidden_size': args.hidden_size, 76 | 'max_position_embeddings': args.seq_length, 77 | 'num_hidden_layers': args.num_hidden_layers, 78 | 'vocab_size': args.vocab_size} 79 | for key, val in overwrite_config.items(): 80 | setattr(config, key, val) 81 | else: 82 | args.hidden_size = config.hidden_size 83 | args.seq_length = config.max_position_embeddings 84 | args.max_position_embeddings = config.max_position_embeddings 85 | args.num_hidden_layers = config.num_hidden_layers 86 | args.vocab_size = config.vocab_size -------------------------------------------------------------------------------- /llama/load_model_parameters_utils/README.md: 
-------------------------------------------------------------------------------- 1 | ## Load Model Parameters for LlaMA 7b 2 | 3 | ### Overview 4 | This guide provides instructions on how to load model parameters for Llama-7b. Specifically, we will focus on creating separate state dictionaries for each component and layer of the model. 5 | 6 | ### Customize Parameters 7 | If you need to specify custom paths, you can manually edit the `create_separate_state_dicts_llama_7b.py` script. Locate the `save_model_components` function call and adjust the paths as needed. For example: 8 | 9 | ```python 10 | save_model_components( 11 | config_path='../llama-config/', 12 | checkpoint_name='llama-7b', 13 | checkpoint_path='/path/to/Llama-2-7b-chat-hf/', 14 | num_layers=32, 15 | save_dir='./separate_state_dicts/' 16 | ) 17 | ``` 18 | 19 | Here, your sole requirement is to specify the `checkpoint_path`, as the other parameters have been pre-defined and supplied for your convenience. You can download the model checkpoints from [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). 20 | 21 | ### Run the Script 22 | To create the separate state dictionaries for the Llama-7b model, run the following command in the terminal: 23 | 24 | ```bash 25 | python3 create_separate_state_dicts_llama_7b.py 26 | ``` 27 | 28 | This script will automatically generate and save the state dictionaries in the appropriate directory. 29 | 30 | ### Verify the Output 31 | After running the script, you should find the separate state dictionaries saved in the designated folder. Verify that all the expected files are present and correctly named. 32 | 33 | ### Modifying the Inference Script 34 | In the `llama_inference.py` file, add the following code snippet to load the parameters for Llama-7b. Adjust the paths as per your setup: 35 | 36 | ```python 37 | # Load model checkpoints with respect to hetero_config 38 | tp_ranks_whole_model = hetero_groups['tp_ranks_whole_model'] 39 | tp_group_list = hetero_groups['tp_rank_groups'] 40 | state_dicts_path = "./load_model_parameters_utils/" 41 | load_model_parameters(model, config, state_dicts_path, tp_ranks_whole_model, tp_group_list, rank) 42 | ``` 43 | -------------------------------------------------------------------------------- /llama/load_model_parameters_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexiScale/5ee1726bd761b6ba57d3315a35c2cf84ce6ca57e/llama/load_model_parameters_utils/__init__.py -------------------------------------------------------------------------------- /llama/load_model_parameters_utils/create_separate_state_dicts_llama_7b.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../site-package') 6 | from llama_config_utils import llama_config_to_gpt2_config, config_from_checkpoint, overwrite_configs_and_args 7 | from transformers import LlamaForCausalLM, LlamaTokenizer 8 | from remap_state_dict import remap_state_dict_hf_llama 9 | 10 | def load_remapped_state_dict(config, checkpoint_path): 11 | 12 | """ 13 | Loads and remaps the state dictionary of a pretrained Llama model. 14 | 15 | Parameters: 16 | - config (dict): Configuration dictionary for the Llama model. 17 | 18 | Returns: 19 | - dict: Remapped state dictionary suitable for the specific configuration. 
20 | """ 21 | 22 | state_dict = remap_state_dict_hf_llama(LlamaForCausalLM.from_pretrained(f"{checkpoint_path}").state_dict(), config) 23 | return state_dict 24 | 25 | def save_model_components(config_path, checkpoint_name, checkpoint_path, num_layers, save_dir): 26 | 27 | """ 28 | Save specific components and each transformer layer of a model's state dictionary to separate files. 29 | 30 | Args: 31 | config_path (str): Path to the configuration directory. 32 | checkpoint_name (str): Name of the model checkpoint. 33 | num_layers (int): Number of transformer layers in the model. 34 | save_dir (str): Directory path where the state dictionaries will be saved. 35 | 36 | This function performs the following steps: 37 | 1. Load the configuration and state dictionary for the model. 38 | 2. Save specific components of the state dictionary (embeddings, layer normalization, and language model head). 39 | 3. Iterate over each transformer layer and save its state dictionary separately. 40 | """ 41 | 42 | # Configuration and state dictionary loading 43 | llama_config = config_from_checkpoint(config_path, checkpoint_name) 44 | config = llama_config_to_gpt2_config(llama_config) 45 | state_dict = load_remapped_state_dict(config, checkpoint_path) 46 | 47 | # Saving specific components of the state dictionary to separate files 48 | torch.save(state_dict['transformer.embeddings.word_embeddings.weight'], f'{save_dir}/embeddings.pt') 49 | torch.save(state_dict['transformer.ln_f.weight'], f'{save_dir}/ln_f.pt') 50 | torch.save(state_dict['lm_head.weight'], f'{save_dir}/lm_head.pt') 51 | 52 | # Save the state dictionary of each transformer layer separately 53 | for idx in range(num_layers): 54 | layer_key_prefix = f'transformer.layers.{idx}' 55 | layer_state_dict = {key: value for key, value in state_dict.items() if key.startswith(layer_key_prefix + ".")} 56 | # print(layer_key_prefix, layer_state_dict.keys()) 57 | torch.save(layer_state_dict, f'{save_dir}/layer_{idx}.pt') 58 | 59 | def main(): 60 | # Generate model separate state_dicts 61 | if not os.path.exists("./separate_state_dicts"): 62 | os.mkdir("./separate_state_dicts") 63 | 64 | save_model_components( 65 | config_path='../llama-config/', 66 | checkpoint_name='llama-7b', 67 | checkpoint_path='../../../../Llama-2-7b-chat-hf/', 68 | num_layers=32, 69 | save_dir='./separate_state_dicts/' 70 | ) 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /llama/load_model_parameters_utils/inv_freq.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexiScale/5ee1726bd761b6ba57d3315a35c2cf84ce6ca57e/llama/load_model_parameters_utils/inv_freq.pt -------------------------------------------------------------------------------- /llama/load_model_parameters_utils/remap_state_dict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import re 5 | from collections import OrderedDict 6 | from pathlib import Path 7 | from typing import Dict, List, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from transformers import GPT2Config, LlamaConfig 12 | from einops import rearrange 13 | 14 | def remap_state_dict_hf_llama( 15 | state_dict: Dict[str, torch.Tensor], config: GPT2Config 16 | ) -> Dict[str, torch.Tensor]: 17 | """Convert the state_dict in Hugging Face format to standard GPT format. 
18 | 19 | This function modifies state_dict in place. 20 | """ 21 | 22 | # Embedding 23 | def key_mapping_emb(key): 24 | return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) 25 | 26 | state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) 27 | word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") 28 | # It's possible that vocab_size is padded to be a multiple of 8, for example. 29 | pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) 30 | vocab_size = ( 31 | math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple 32 | ) 33 | state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( 34 | word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) 35 | ) 36 | 37 | # LM head 38 | if getattr(config, "tie_word_embeddings"): 39 | state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] 40 | else: 41 | output_embeddings = state_dict.pop("lm_head.weight") 42 | # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings 43 | # differently. 44 | vocab_size = ( 45 | math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) 46 | * pad_vocab_size_multiple 47 | ) 48 | # It's possible that vocab_size is padded to be a multiple of 8, for example. 49 | state_dict["lm_head.weight"] = F.pad( 50 | output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) 51 | ) 52 | 53 | # MLP 54 | for l in range(config.n_layer): 55 | # Fusing weights this way based on difference in the following: 56 | # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220 57 | # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115 58 | w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight") 59 | w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight") 60 | state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) 61 | 62 | def key_mapping_mlp(key): 63 | return re.sub( 64 | r"^model.layers.(\d+).mlp.down_proj.", 65 | r"transformer.layers.\1.mlp.fc2.", 66 | key, 67 | ) 68 | 69 | state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) 70 | 71 | # LayerNorm 72 | def key_mapping_ln(key): 73 | key = re.sub(r"^model.norm.", r"transformer.ln_f.", key) 74 | key = re.sub( 75 | r"^model.layers.(\d+).input_layernorm.", 76 | r"transformer.layers.\1.norm1.", 77 | key, 78 | ) 79 | key = re.sub( 80 | r"^model.layers.(\d+).post_attention_layernorm.", 81 | r"transformer.layers.\1.norm2.", 82 | key, 83 | ) 84 | return key 85 | 86 | state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) 87 | 88 | def inv_permute(w): 89 | # Inverse of permute implemented in: 90 | # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 91 | return rearrange( 92 | w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 93 | ) 94 | 95 | # Attention 96 | for l in range(config.n_layer): 97 | Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight") 98 | Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight") 99 | Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") 100 | 101 | state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( 102 | 
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0 103 | ) 104 | # We don't store these 105 | state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) 106 | 107 | def key_mapping_attn(key): 108 | return re.sub( 109 | r"^model.layers.(\d+).self_attn.o_proj.", 110 | r"transformer.layers.\1.mixer.out_proj.", 111 | key, 112 | ) 113 | 114 | state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) 115 | return state_dict 116 | -------------------------------------------------------------------------------- /llama/modules/Llamamodel_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from flash_attn.ops.fused_dense import ColumnParallelLinear 5 | except ImportError: 6 | ColumnParallelLinear = None 7 | 8 | try: 9 | from flash_attn.ops.layer_norm import dropout_add_layer_norm 10 | except ImportError: 11 | dropout_add_layer_norm = None 12 | 13 | try: 14 | from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual 15 | except ImportError: 16 | dropout_add_layer_norm_parallel_residual = None 17 | 18 | try: 19 | from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm 20 | except ImportError: 21 | RMSNorm, dropout_add_rms_norm = None, None 22 | 23 | try: 24 | from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual 25 | except ImportError: 26 | dropout_add_rms_norm_parallel_residual = None 27 | 28 | 29 | class LlamaEmbeddings_(nn.Module): 30 | def __init__(self, model): 31 | super().__init__() 32 | model = model.transformer 33 | attrs = ['embeddings', 'process_group', 'sequence_parallel'] 34 | for key in attrs: 35 | setattr(self, key, getattr(model, key)) 36 | 37 | def label(self): 38 | return [0,0] 39 | 40 | def forward(self, input_ids, position_ids=None): 41 | 42 | embedding_kwargs = ({'combine_batch_seqlen_dim': True} 43 | if self.process_group is not None and self.sequence_parallel else {}) 44 | hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) 45 | 46 | return hidden_states 47 | 48 | class LlamaLayers_(nn.Module): 49 | def __init__(self, model, layer_idx_start, layer_idx_end): 50 | super().__init__() 51 | model = model.transformer 52 | self.layer_idx = layer_idx_start 53 | self.layers = model.layers[layer_idx_start:layer_idx_end] 54 | attrs = ['prenorm', 'parallel_block', 'process_group'] 55 | for key in attrs: 56 | setattr(self, key, getattr(model, key)) 57 | 58 | def label(self): 59 | return [1,self.layer_idx] 60 | 61 | def forward(self, hidden_states, residual=None): 62 | 63 | mixer_kwargs = ({'seqlen': hidden_states.shape[1]} 64 | if self.process_group is not None and self.sequence_parallel else {}) 65 | 66 | for layer in self.layers: 67 | if self.prenorm: 68 | if not self.parallel_block: 69 | 70 | hidden_states, residual = layer(hidden_states, residual, 71 | mixer_kwargs=mixer_kwargs) 72 | 73 | else: 74 | hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) 75 | 76 | 77 | return hidden_states 78 | 79 | class LlamaPreNorm_(nn.Module): 80 | def __init__(self, model): 81 | super().__init__() 82 | model = model.transformer 83 | self.drop_f = model.drop_f 84 | self.ln_f = model.ln_f 85 | attrs = ['fused_dropout_add_ln', 'drop_f', 'parallel_block', 'ln_f', 'prenorm', 'residual_in_fp32'] 86 | for key in attrs: 87 | setattr(self, key, getattr(model, key)) 88 | 89 | def label(self): 90 | return [2,0] 91 | 92 | def forward(self, hidden_states, residual=None): 93 | 94 | assert(residual is 
None) 95 | residual = None 96 | if self.prenorm: 97 | if not self.fused_dropout_add_ln: 98 | dropped = self.drop_f(hidden_states) 99 | if not self.parallel_block: 100 | residual = (dropped + residual) if residual is not None else dropped 101 | hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) 102 | else: 103 | # Set prenorm=False here since we don't need the residual 104 | if not self.parallel_block: 105 | fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm) 106 | else dropout_add_layer_norm) 107 | hidden_states = fused_add_norm_fn( 108 | hidden_states, residual, self.ln_f.weight, self.ln_f.bias, 109 | self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, 110 | residual_in_fp32=self.residual_in_fp32 111 | ) 112 | 113 | return hidden_states 114 | 115 | class LlamaCls_(nn.Module): 116 | def __init__(self, model): 117 | super().__init__() 118 | attrs = ['lm_head', 'config', 'project_out'] 119 | for key in attrs: 120 | setattr(self, key, getattr(model, key)) 121 | 122 | def label(self): 123 | return [3,0] 124 | 125 | def forward(self, hidden_states): 126 | 127 | if self.project_out is not None: 128 | hidden_states = self.project_out(hidden_states) 129 | 130 | lm_logits = self.lm_head(hidden_states) 131 | 132 | return lm_logits 133 | -------------------------------------------------------------------------------- /llama/save_checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from third_party.megatron import get_args 4 | 5 | 6 | def save_checkpoint(iter, model, optimizer): 7 | """Save a model checkpoint.""" 8 | args = get_args() 9 | 10 | # # Only rank zero of the data parallel writes to the disk. 11 | # model = unwrap_model(model) 12 | 13 | # # Collect rng state across data parallel ranks. 14 | # rng_state = get_rng_state() 15 | 16 | # Checkpoint name. 17 | checkpoint_name = f"{args.hetero_configs}-{args.layer_partitions}-{torch.distributed.get_rank()}" 18 | exit(0) 19 | 20 | # Collect args, model, RNG. 21 | if not torch.distributed.is_initialized() \ 22 | or mpu.get_data_modulo_expert_parallel_rank() == 0: 23 | 24 | # Arguments, iteration, and model. 25 | state_dict = {} 26 | state_dict['args'] = args 27 | state_dict['iteration'] = iteration 28 | if len(model) == 1: 29 | state_dict['model'] = model[0].state_dict_for_save_checkpoint() 30 | else: 31 | for i in range(len(model)): 32 | mpu.set_virtual_pipeline_model_parallel_rank(i) 33 | state_dict['model%d' % i] = \ 34 | model[i].state_dict_for_save_checkpoint() 35 | 36 | # Optimizer stuff. 37 | if not args.no_save_optim: 38 | if optimizer is not None: 39 | state_dict['optimizer'] = optimizer.state_dict() 40 | if opt_param_scheduler is not None: 41 | state_dict['opt_param_scheduler'] = \ 42 | opt_param_scheduler.state_dict() 43 | 44 | # RNG states. 45 | if not args.no_save_rng: 46 | state_dict["rng_state"] = rng_state 47 | 48 | # Save. 
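# Assumption: `ensure_directory_exists`, `rng_state`, `opt_param_scheduler`, and `mpu` are not
# defined or imported in this file; they are expected to be provided by Megatron's checkpointing
# and parallel-state utilities (see third_party/megatron/checkpointing.py).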
49 | ensure_directory_exists(checkpoint_name) 50 | torch.save(state_dict, checkpoint_name) 51 | 52 | # Wait so everyone is done (necessary) 53 | if torch.distributed.is_initialized(): 54 | torch.distributed.barrier() -------------------------------------------------------------------------------- /llama/sudo_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class DatasetForLlama(Dataset): 6 | def __init__(self, args): 7 | self.vocab_size = args.vocab_size 8 | self.sentence_length = args.seq_length 9 | self.dataset_size = 2560 * 16 10 | 11 | self.data_length = np.random.randint(1,self.sentence_length+1,(self.dataset_size,)) 12 | self.device = args.local_rank 13 | 14 | self.input_ids = [] 15 | for i in range(self.dataset_size): 16 | sentence = np.random.randint(0,self.vocab_size,(self.sentence_length,)) 17 | sentence[self.data_length[i]:] = 0 18 | mask = np.ones((self.sentence_length,)) 19 | mask[self.data_length[i]:] = 0 20 | self.input_ids.append(sentence) 21 | 22 | self.input_ids = np.array(self.input_ids) 23 | 24 | def __len__(self): 25 | return self.dataset_size 26 | 27 | def __getitem__(self, idx): 28 | if idx >= self.dataset_size: 29 | raise IndexError 30 | input_ids = torch.LongTensor(self.input_ids[idx]).to(self.device) 31 | return input_ids -------------------------------------------------------------------------------- /scripts/batch_run_scripts.sh: -------------------------------------------------------------------------------- 1 | model_size="llama-7b" 2 | 3 | rank=0 4 | save_logs=$1 5 | pkill python 6 | bash scripts/generate_hetero_scripts.sh 7 | 8 | 9 | for script in ./llama-scripts-logs/${model_size}-scripts/*.sh; do 10 | echo "Running "$script 11 | if [[ $save_logs == "save" ]]; 12 | then 13 | bash $script > ./llama-scripts-logs/${model_size}-scripts/${rank}.txt & 14 | rank=$((1 + $rank)) 15 | else 16 | bash $script & 17 | fi 18 | done 19 | -------------------------------------------------------------------------------- /scripts/generate_hetero_scripts.sh: -------------------------------------------------------------------------------- 1 | python3 llama/generate_hetero_scripts.py \ 2 | --retain-run-file \ 3 | --model-size llama-7b \ 4 | --current_device 5 \ 5 | --master_addr 10.60.40.17 \ 6 | --master_port 9997 \ 7 | --layer-num 32 \ 8 | --micro_batch_num 120 \ 9 | --global_batch_size 120 \ 10 | --hetero_configs "[[1] * 10] * 4" \ 11 | --layer_partitions "[[4, 4, 2, 2, 4, 4, 4, 4, 2, 2] ] * 4" \ 12 | --devices_id "[[[0], [0], [1], [1], [2], [2], [3], [3], [4], [5]], [[0], [0], [1], [1], [2], [2], [3], [3], [4], [5]], [[0], [0], [1], [1], [2], [2], [3], [3], [4], [5]], [[0], [0], [1], [1], [2], [2], [3], [3], [4], [5]]]" \ 13 | --accum-iter 1 \ 14 | --run-iter 1 \ 15 | -------------------------------------------------------------------------------- /scripts/run_iperf3.sh: -------------------------------------------------------------------------------- 1 | ip=$1 2 | role=$2 3 | 4 | if [[ $role == "send" ]]; 5 | then 6 | iperf3 -c ${ip} -t 8 -P 30 -p 9992 7 | else 8 | iperf3 -s -p 9992 9 | fi 10 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | # export NCCL_P2P_DISABLE=1 2 | # export CUDA_LAUNCH_BLOCKING=1 3 | # export NCCL_SOCKET_IFNAME=tailscale0 4 | export NCCL_ALGO=Ring 5 | GPUS_PER_NODE=8 6 | # Change for 
multinode config 7 | MASTER_ADDR=localhost 8 | MASTER_PORT=9996 9 | NNODES=1 10 | NODE_RANK=0 11 | 12 | DISTRIBUTED_ARGS=" 13 | --nproc_per_node $GPUS_PER_NODE \ 14 | --nnodes $NNODES \ 15 | --node_rank $NODE_RANK \ 16 | --master_addr $MASTER_ADDR \ 17 | --master_port $MASTER_PORT 18 | " 19 | # PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True CUDA_VISIBLE_DEVICES=0,1 torchrun $DISTRIBUTED_ARGS llama_train.py \ 20 | # PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun $DISTRIBUTED_ARGS llama_train.py \ 21 | torchrun $DISTRIBUTED_ARGS llama_train.py \ 22 | --model_size llama-7b \ 23 | --mixed_precision fp16 \ 24 | --use-flash-attn \ 25 | --total-layer-num 16 \ 26 | --hetero_configs "[[1] * 2] * 4" \ 27 | --layer_partitions "[[8, 8] ] * 4" \ 28 | --chunks 4 \ 29 | --global_bsz_size 4 \ 30 | --accum-iter 1 \ 31 | --run-iter 2 \ 32 | --fp16 \ 33 | --display_one_pipeline \ 34 | --checkpoint-layers \ 35 | --checkpoint-all \ 36 | # --recompute-stage-output \ 37 | # --default_dp_type ddp \ 38 | # --pp_layouts "[[[0, 1, 4, 6], [2, 5, 3, 7]]]" \ -------------------------------------------------------------------------------- /third_party/README.md: -------------------------------------------------------------------------------- 1 | ### Modification on torch 2 | `_pre_forward` is modified, the original path is `software/miniconda3/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py` -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- 1 | from .megatron import * -------------------------------------------------------------------------------- /third_party/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | from .global_vars import get_args, get_retro_args 6 | from .global_vars import get_current_global_batch_size 7 | from .global_vars import get_num_microbatches 8 | from .global_vars import get_signal_handler 9 | from .global_vars import update_num_microbatches 10 | from .global_vars import get_tokenizer 11 | from .global_vars import get_tensorboard_writer 12 | from .global_vars import get_adlr_autoresume 13 | from .global_vars import get_timers 14 | from .initialize import initialize_megatron 15 | 16 | from .utils import (print_rank_0, 17 | is_last_rank, 18 | print_rank_last) 19 | -------------------------------------------------------------------------------- /third_party/megatron/core/README.md: -------------------------------------------------------------------------------- 1 | Megatron Core is a library for efficient and scalable training of transformer based models. 
2 | -------------------------------------------------------------------------------- /third_party/megatron/core/__init__.py: -------------------------------------------------------------------------------- 1 | import third_party.megatron.core.parallel_state 2 | import third_party.megatron.core.tensor_parallel 3 | import third_party.megatron.core.utils 4 | 5 | # Alias parallel_state as mpu, its legacy name 6 | mpu = parallel_state 7 | 8 | __all__ = [ 9 | "parallel_state", 10 | "tensor_parallel", 11 | "utils", 12 | ] 13 | -------------------------------------------------------------------------------- /third_party/megatron/core/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import enum 4 | 5 | class ModelType(enum.Enum): 6 | encoder_or_decoder = 1 7 | encoder_and_decoder = 2 8 | -------------------------------------------------------------------------------- /third_party/megatron/core/package_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | MAJOR = 0 5 | MINOR = 1 6 | PATCH = 0 7 | PRE_RELEASE = '' 8 | 9 | # Use the following formatting: (major, minor, patch, pre-release) 10 | VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) 11 | 12 | __shortversion__ = '.'.join(map(str, VERSION[:3])) 13 | __version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) 14 | 15 | __package_name__ = 'megatron_core' 16 | __contact_names__ = 'NVIDIA' 17 | __contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email 18 | __homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage 19 | __repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core' 20 | __download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases' 21 | __description__ = 'Megatron Core - a library for efficient and scalable training of transformer based models' 22 | __license__ = 'BSD-3' 23 | __keywords__ = 'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch' 24 | -------------------------------------------------------------------------------- /third_party/megatron/core/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedules import get_forward_backward_func 2 | -------------------------------------------------------------------------------- /third_party/megatron/core/requirements.txt: -------------------------------------------------------------------------------- 1 | torch -------------------------------------------------------------------------------- /third_party/megatron/core/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import vocab_parallel_cross_entropy 2 | from .data import broadcast_data 3 | 4 | from .layers import ( 5 | ColumnParallelLinear, 6 | RowParallelLinear, 7 | VocabParallelEmbedding, 8 | set_tensor_model_parallel_attributes, 9 | set_defaults_if_not_set_tensor_model_parallel_attributes, 10 | copy_tensor_model_parallel_attributes, 11 | param_is_not_tensor_parallel_duplicate, 12 | linear_with_grad_accumulation_and_async_allreduce 13 | 14 | ) 15 | 16 | from .mappings import ( 17 | copy_to_tensor_model_parallel_region, 18 | gather_from_tensor_model_parallel_region, 19 | gather_from_sequence_parallel_region, 20 | 
scatter_to_tensor_model_parallel_region, 21 | scatter_to_sequence_parallel_region, 22 | ) 23 | 24 | from .mappings_group import ( 25 | get_tensor_model_parallel_world_size_group, 26 | get_tensor_model_parallel_rank_group, 27 | copy_to_tensor_model_parallel_region_group, 28 | gather_from_tensor_model_parallel_region_group, 29 | gather_from_sequence_parallel_region_group, 30 | reduce_from_tensor_model_parallel_region_group, 31 | scatter_to_tensor_model_parallel_region_group, 32 | scatter_to_sequence_parallel_region_group, 33 | reduce_scatter_to_sequence_parallel_region_group, 34 | ) 35 | 36 | from .random import ( 37 | checkpoint, 38 | get_cuda_rng_tracker, 39 | model_parallel_cuda_manual_seed, 40 | ) 41 | 42 | from .utils import ( 43 | split_tensor_along_last_dim, 44 | split_tensor_into_1d_equal_chunks, 45 | gather_split_1d_tensor, 46 | ) 47 | 48 | __all__ = [ 49 | # cross_entropy.py 50 | "vocab_parallel_cross_entropy", 51 | # data.py 52 | "broadcast_data", 53 | #layers.py 54 | "ColumnParallelLinear", 55 | "RowParallelLinear", 56 | "VocabParallelEmbedding", 57 | "set_tensor_model_parallel_attributes", 58 | "set_defaults_if_not_set_tensor_model_parallel_attributes", 59 | "copy_tensor_model_parallel_attributes", 60 | "param_is_not_tensor_parallel_duplicate", 61 | "linear_with_grad_accumulation_and_async_allreduce", 62 | # mappings.py 63 | "copy_to_tensor_model_parallel_region", 64 | "gather_from_tensor_model_parallel_region", 65 | "gather_from_sequence_parallel_region", 66 | # "reduce_from_tensor_model_parallel_region", 67 | "scatter_to_tensor_model_parallel_region", 68 | "scatter_to_sequence_parallel_region", 69 | # random.py 70 | "checkpoint", 71 | "get_cuda_rng_tracker", 72 | "model_parallel_cuda_manual_seed", 73 | # utils.py 74 | "split_tensor_along_last_dim", 75 | "split_tensor_into_1d_equal_chunks", 76 | "gather_split_1d_tensor", 77 | ] 78 | -------------------------------------------------------------------------------- /third_party/megatron/core/tensor_parallel/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | from third_party.megatron.core.parallel_state import ( 6 | get_tensor_model_parallel_group, 7 | get_tensor_model_parallel_rank, 8 | get_tensor_model_parallel_src_rank, 9 | ) 10 | 11 | 12 | _MAX_DATA_DIM = 5 13 | 14 | 15 | def _check_data_types(keys, data, target_dtype): 16 | """Check that all the keys have the same target data type.""" 17 | for key in keys: 18 | assert data[key].dtype == target_dtype, '{} has data type {} which '\ 19 | 'is different than {}'.format(key, data[key].dtype, target_dtype) 20 | 21 | 22 | def _build_key_size_numel_dictionaries(keys, data): 23 | """Build the size on rank 0 and broadcast.""" 24 | max_dim = _MAX_DATA_DIM 25 | sizes = [0 for _ in range(max_dim) for _ in keys] 26 | 27 | # Pack the sizes on rank zero. 28 | if get_tensor_model_parallel_rank() == 0: 29 | offset = 0 30 | for key in keys: 31 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' 32 | size = data[key].size() 33 | for i, s in enumerate(size): 34 | sizes[i + offset] = s 35 | offset += max_dim 36 | 37 | # Move to GPU and broadcast. 38 | sizes_cuda = torch.cuda.LongTensor(sizes) 39 | torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(), 40 | group=get_tensor_model_parallel_group()) 41 | 42 | # Move back to cpu and unpack. 
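# Each key occupies a fixed window of _MAX_DATA_DIM slots in the broadcast buffer; a zero
# entry marks the end of that key's shape, which is how the unpacking loop below terminates.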
43 | sizes_cpu = sizes_cuda.cpu() 44 | key_size = {} 45 | key_numel = {} 46 | total_numel = 0 47 | offset = 0 48 | for key in keys: 49 | i = 0 50 | size = [] 51 | numel = 1 52 | while sizes_cpu[offset + i] > 0: 53 | this_size = sizes_cpu[offset + i] 54 | size.append(this_size) 55 | numel *= this_size 56 | i += 1 57 | key_size[key] = size 58 | key_numel[key] = numel 59 | total_numel += numel 60 | offset += max_dim 61 | 62 | return key_size, key_numel, total_numel 63 | 64 | 65 | def broadcast_data(keys, data, datatype): 66 | """Broadcast data from rank zero of each model parallel group to the 67 | members of the same model parallel group. 68 | 69 | Arguments: 70 | keys: list of keys in the data disctionary to be broadcasted 71 | data: data dictionary of string keys and cpu tensor values. 72 | datatype: torch data type of all tensors in data associated 73 | with keys. 74 | """ 75 | # Build (key, size) and (key, number of elements) dictionaries along 76 | # with the total number of elements on all ranks. 77 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, 78 | data) 79 | 80 | # Pack on rank zero. 81 | if get_tensor_model_parallel_rank() == 0: 82 | # Check that all keys have the same data type. 83 | _check_data_types(keys, data, datatype) 84 | # Flatten the data associated with the keys 85 | flatten_data = torch.cat( 86 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() 87 | else: 88 | flatten_data = torch.empty(total_numel, 89 | device=torch.cuda.current_device(), 90 | dtype=datatype) 91 | 92 | # Broadcast 93 | torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(), 94 | group=get_tensor_model_parallel_group()) 95 | 96 | # Unpack 97 | output = {} 98 | offset = 0 99 | for key in keys: 100 | size = key_size[key] 101 | numel = key_numel[key] 102 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 103 | offset += numel 104 | 105 | return output 106 | -------------------------------------------------------------------------------- /third_party/megatron/core/tensor_parallel/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | from typing import List, Sequence 5 | 6 | from third_party.megatron.core.utils import divide 7 | from third_party.megatron.core import parallel_state 8 | 9 | def split_tensor_along_last_dim( 10 | tensor: torch.Tensor, 11 | num_partitions: int, 12 | contiguous_split_chunks: bool = False, 13 | ) -> List[torch.Tensor]: 14 | """ Split a tensor along its last dimension. 15 | 16 | Arguments: 17 | tensor: input tensor. 18 | num_partitions: number of partitions to split the tensor 19 | contiguous_split_chunks: If True, make each chunk contiguous 20 | in memory. 21 | 22 | Returns: 23 | A list of Tensors 24 | """ 25 | # Get the size and dimension. 26 | last_dim = tensor.dim() - 1 27 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 28 | # Split. 29 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 30 | # Note: torch.split does not create contiguous tensors by default. 31 | if contiguous_split_chunks: 32 | return tuple(chunk.contiguous() for chunk in tensor_list) 33 | 34 | return tensor_list 35 | 36 | def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): 37 | """ Break a tensor into equal 1D chunks across tensor parallel ranks. 38 | 39 | Returns a Tensor or View with this rank's portion of the data. 
40 | 41 | Arguments: 42 | tensor: The tensor to split 43 | 44 | Keyword Arguments: 45 | new_buffer (bool): If True, returns a new Tensor. 46 | If False, returns a view into the existing Tensor. 47 | Default is False 48 | 49 | """ 50 | partition_size = torch.numel(tensor) // \ 51 | parallel_state.get_tensor_model_parallel_world_size() 52 | start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() 53 | end_index = start_index + partition_size 54 | if new_buffer: 55 | data = torch.empty(partition_size, dtype=tensor.dtype, 56 | device=torch.cuda.current_device(), 57 | requires_grad=False) 58 | data.copy_(tensor.view(-1)[start_index:end_index]) 59 | else: 60 | data = tensor.view(-1)[start_index:end_index] 61 | return data 62 | 63 | 64 | def gather_split_1d_tensor(tensor): 65 | """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor 66 | model parallel ranks. 67 | 68 | Returns a new Tensor with the gathered data. 69 | 70 | Arguments: 71 | tensor: A Tensor or view of this rank's portion of the data. 72 | """ 73 | numel_gathered = torch.numel(tensor) * \ 74 | parallel_state.get_tensor_model_parallel_world_size() 75 | gathered = torch.empty(numel_gathered, dtype=tensor.dtype, 76 | device=torch.cuda.current_device(), 77 | requires_grad=False) 78 | # TODO: This API is experimental in pytorch (as of Feb 2022) and 79 | # this might break in future pytorch releases. We chose this API 80 | # as opposed to torch.distributed.all_gather for efficiency reasons. 81 | # This API calls directly NCCL all-gather versus the former does 82 | # internal copies and can potentially cause slow down. 83 | torch.distributed._all_gather_base(gathered, tensor, 84 | group=parallel_state.get_tensor_model_parallel_group()) 85 | return gathered 86 | 87 | 88 | class VocabUtility: 89 | """ Split the vocabulary into `world_size` chunks and return the first 90 | and last index of the vocabulary belonging to the `rank` 91 | partition: Note that indices in [fist, last) 92 | 93 | """ 94 | 95 | @staticmethod 96 | def vocab_range_from_per_partition_vocab_size( 97 | per_partition_vocab_size: int, rank, world_size: int 98 | ) -> Sequence[int]: 99 | index_f = rank * per_partition_vocab_size 100 | index_l = index_f + per_partition_vocab_size 101 | return index_f, index_l 102 | 103 | @staticmethod 104 | def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: 105 | per_partition_vocab_size = divide(global_vocab_size, world_size) 106 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 107 | per_partition_vocab_size, rank, world_size 108 | ) 109 | -------------------------------------------------------------------------------- /third_party/megatron/core/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Utility functions used throughout Megatron core""" 4 | from functools import reduce 5 | import operator 6 | 7 | import torch 8 | 9 | from third_party.megatron.core import parallel_state 10 | 11 | 12 | def ensure_divisibility(numerator, denominator): 13 | """Ensure that numerator is divisible by the denominator.""" 14 | assert numerator % denominator == 0, "{} is not divisible by {}".format( 15 | numerator, denominator 16 | ) 17 | 18 | 19 | def divide(numerator, denominator): 20 | """Ensure that numerator is divisible by the denominator and return 21 | the division value.""" 22 | ensure_divisibility(numerator, denominator) 23 | return numerator // denominator 24 | 25 | def get_attr_wrapped_model(model, attr): 26 | """Get an attribute from a wrapped model""" 27 | if isinstance(model, list): 28 | raise RuntimeError("_get_attr_wrapped_model given a list of models") 29 | 30 | while not hasattr(model, attr): 31 | if not hasattr(model, "module"): 32 | raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}") 33 | 34 | model = model.module 35 | return getattr(model, attr) 36 | 37 | def get_model_type(model): 38 | return get_attr_wrapped_model(model, 'model_type') 39 | 40 | 41 | class GlobalMemoryBuffer: 42 | """Global buffer to avoid dynamic memory allocations. 43 | Caller should ensure that buffers of the same name 44 | are not used concurrently.""" 45 | 46 | def __init__(self): 47 | self.buffer = {} 48 | 49 | def get_tensor(self, tensor_shape, dtype, name): 50 | required_len = reduce(operator.mul, tensor_shape, 1) 51 | if self.buffer.get((name, dtype), None) is None or \ 52 | self.buffer[(name, dtype)].numel() < required_len: 53 | self.buffer[(name, dtype)] = \ 54 | torch.empty(required_len, 55 | dtype=dtype, 56 | device=torch.cuda.current_device(), 57 | requires_grad=False) 58 | 59 | return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) 60 | 61 | def _kernel_make_viewless_tensor(inp, requires_grad): 62 | '''Make a viewless tensor. 63 | 64 | View tensors have the undesirable side-affect of retaining a reference 65 | to the originally-viewed tensor, even after manually setting the '.data' 66 | field. This method creates a new tensor that links to the old tensor's 67 | data, without linking the viewed tensor, referenced via the '._base' 68 | field. 69 | ''' 70 | out = torch.empty( 71 | (1,), 72 | dtype = inp.dtype, 73 | device = inp.device, 74 | requires_grad = requires_grad, 75 | ) 76 | out.data = inp.data 77 | return out 78 | 79 | class MakeViewlessTensor(torch.autograd.Function): 80 | ''' 81 | Autograd function to make a viewless tensor. 82 | 83 | This function should be used in cases where the computation graph needs 84 | to be propagated, but we only want a viewless tensor (e.g., 85 | ParallelTransformer's hidden_states). Call this function by passing 86 | 'keep_graph = True' to 'make_viewless_tensor()'. 87 | ''' 88 | @staticmethod 89 | def forward(ctx, inp, requires_grad): 90 | return _kernel_make_viewless_tensor(inp, requires_grad) 91 | @staticmethod 92 | def backward(ctx, grad_output): 93 | return grad_output, None 94 | 95 | def make_viewless_tensor(inp, requires_grad, keep_graph): 96 | ''' 97 | Entry-point for creating viewless tensors. 98 | 99 | This method should be used, rather than calling 'MakeViewlessTensor' 100 | or '_kernel_make_viewless_tensor' directly. This method acts as a 101 | switch for determining if an autograd function or a regular method 102 | should be used to create the tensor. 
103 | ''' 104 | 105 | # return tensor as-is, if not a 'view' 106 | if inp._base is None: 107 | return inp 108 | 109 | # create viewless tensor 110 | if keep_graph: 111 | return MakeViewlessTensor.apply(inp, requires_grad) 112 | else: 113 | return _kernel_make_viewless_tensor(inp, requires_grad) 114 | 115 | def assert_viewless_tensor(tensor, extra_msg = None): 116 | '''Assert that a tensor is not a view (i.e., its '._base' field is 117 | not set).''' 118 | if isinstance(tensor, list): 119 | [ assert_viewless_tensor(t) for t in tensor ] 120 | return tensor 121 | if not isinstance(tensor, torch.Tensor): 122 | return tensor 123 | assert tensor._base is None, ( 124 | "Ensure tensor._base is None before setting tensor.data or storing " 125 | "tensor to memory buffer. Otherwise, a memory leak will occur (and " 126 | "likely accumulate over iterations). %s" 127 | ) % extra_msg 128 | return tensor 129 | 130 | def safely_set_viewless_tensor_data(tensor, new_data_tensor): 131 | '''Safely set tensor's '.data' field. 132 | 133 | Check first that the tensor is viewless (i.e., '._base' not set). If not, 134 | raise an exception. 135 | ''' 136 | assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape)) 137 | tensor.data = new_data_tensor 138 | -------------------------------------------------------------------------------- /third_party/megatron/data/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 3 | LIBNAME = helpers 4 | LIBEXT = $(shell python3-config --extension-suffix) 5 | 6 | default: $(LIBNAME)$(LIBEXT) 7 | 8 | %$(LIBEXT): %.cpp 9 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ 10 | -------------------------------------------------------------------------------- /third_party/megatron/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import indexed_dataset 2 | -------------------------------------------------------------------------------- /third_party/megatron/data/blendable_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Blendable dataset.""" 4 | 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from third_party.megatron import print_rank_0 11 | 12 | class BlendableDataset(torch.utils.data.Dataset): 13 | 14 | 15 | def __init__(self, datasets, weights, size): 16 | 17 | self.datasets = datasets 18 | num_datasets = len(datasets) 19 | assert num_datasets == len(weights) 20 | 21 | self.size = size 22 | 23 | # Normalize weights. 24 | weights = np.array(weights, dtype=np.float64) 25 | sum_weights = np.sum(weights) 26 | assert sum_weights > 0.0 27 | weights /= sum_weights 28 | 29 | # Build indicies. 
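# dataset_index[i] records which source dataset sample i is drawn from, and
# dataset_sample_index[i] is that sample's position within the chosen dataset
# (both are consumed by __getitem__ below).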
30 | start_time = time.time() 31 | assert num_datasets < 255 32 | self.dataset_index = np.zeros(self.size, dtype=np.uint8) 33 | self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) 34 | 35 | from third_party.megatron.data import helpers 36 | helpers.build_blending_indices(self.dataset_index, 37 | self.dataset_sample_index, 38 | weights, num_datasets, self.size, 39 | torch.distributed.get_rank() == 0) 40 | print_rank_0('> elapsed time for building blendable dataset indices: ' 41 | '{:.2f} (sec)'.format(time.time() - start_time)) 42 | 43 | # Check size 44 | _ = self.__getitem__(self.size - 1) 45 | try: 46 | _ = self.__getitem__(self.size) 47 | raise RuntimeError('BlendedDataset size is improperly bounded') 48 | except IndexError: 49 | pass 50 | print_rank_0('> size of blendable dataset: ' 51 | '{} samples'.format(self.size)) 52 | 53 | 54 | def __len__(self): 55 | return self.size 56 | 57 | 58 | def __getitem__(self, idx): 59 | dataset_idx = self.dataset_index[idx] 60 | sample_idx = self.dataset_sample_index[idx] 61 | return { 62 | "dataset_idx" : dataset_idx, 63 | **self.datasets[dataset_idx][sample_idx], 64 | } 65 | -------------------------------------------------------------------------------- /third_party/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexiScale/5ee1726bd761b6ba57d3315a35c2cf84ce6ca57e/third_party/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /third_party/megatron/data/test/test_indexed_dataset.py: -------------------------------------------------------------------------------- 1 | # This file isn't really a formal automated test, it's just a place to 2 | # put some code used during development and manual testing of 3 | # indexed_dataset. 
4 | 5 | from third_party.megatron.data import indexed_dataset 6 | from third_party.megatron.tokenizer import build_tokenizer 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import torch 12 | 13 | script_dir = os.path.dirname(os.path.realpath(__file__)) 14 | sys.path.append(os.path.join(script_dir, "../../../")) 15 | 16 | 17 | def test_indexed_dataset(args): 18 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 19 | tokenizer = build_tokenizer(args) 20 | print(len(ds.doc_idx)) 21 | print(len(ds)) 22 | print(ds.doc_idx[-1]) 23 | if ds.supports_prefetch: 24 | # just prefetch the whole thing in test (so assume it is small) 25 | ds.prefetch(range(len(ds))) 26 | if args.count > len(ds.doc_idx) - 1: 27 | args.count = len(ds.doc_idx) - 1 28 | 29 | for i in range(args.count): 30 | start = ds.doc_idx[i] 31 | end = ds.doc_idx[i + 1] 32 | ids = ds[start:end] 33 | print(f"Document {i}:") 34 | print("--------------") 35 | for s in ids: 36 | assert len(s) > 0 37 | l = s.data.tolist() 38 | text = tokenizer.detokenize(l) 39 | print(text) 40 | print("---") 41 | 42 | 43 | def test_indexed_dataset_get(args): 44 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 45 | tokenizer = build_tokenizer(args) 46 | size = ds.sizes[0] 47 | print(f"size: {size}") 48 | full = ds.get(0) 49 | print(full) 50 | # print(tokenizer.detokenize(full.data.tolist())) 51 | print("---") 52 | end = ds.get(0, offset=size - 10) 53 | print(end) 54 | # print(tokenizer.detokenize(end.data.tolist())) 55 | 56 | start = ds.get(0, length=10) 57 | print(start) 58 | # print(tokenizer.detokenize(start.data.tolist())) 59 | 60 | part = ds.get(0, offset=2, length=8) 61 | print(part) 62 | # print(tokenizer.detokenize(part.data.tolist())) 63 | 64 | # def test_albert_dataset(args): 65 | # # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) 66 | # # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) 67 | # # ds = AlbertDataset(idataset, tokenizer) 68 | # ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, 69 | # args.epochs, args.max_num_samples, 70 | # args.masked_lm_prob, args.seq_length, 71 | # args.short_seq_prob, args.seed) 72 | # truncated = 0 73 | # total = 0 74 | # for i, s in enumerate(ds): 75 | # ids = s['text'] 76 | # tokens = ds.tokenizer.convert_ids_to_tokens(ids) 77 | # print(tokens) 78 | # if i >= args.count-1: 79 | # exit() 80 | 81 | 82 | def main(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--data', type=str, help='prefix to data files') 85 | parser.add_argument('--dataset-impl', type=str, default='infer', 86 | choices=['lazy', 'cached', 'mmap', 'infer']) 87 | parser.add_argument('--count', type=int, default=10, 88 | help='Number of samples/documents to print') 89 | 90 | group = parser.add_argument_group(title='tokenizer') 91 | group.add_argument('--tokenizer-type', type=str, required=True, 92 | choices=['BertWordPieceLowerCase', 93 | 'GPT2BPETokenizer'], 94 | help='What type of tokenizer to use.') 95 | group.add_argument('--vocab-file', type=str, default=None, 96 | help='Path to the vocab file') 97 | group.add_argument('--merge-file', type=str, default=None, 98 | help='Path to the BPE merge file (if necessary).') 99 | 100 | parser.add_argument('--epochs', type=int, default=5, 101 | help='Number of epochs to plan for') 102 | parser.add_argument('--max-num-samples', type=int, default=None, 103 | help='Maximum number of samples to plan for') 104 | parser.add_argument('--masked-lm-prob', type=float, default=0.15, 105 | 
help='probability of masking tokens') 106 | parser.add_argument('--seq-length', type=int, default=512, 107 | help='maximum sequence length') 108 | parser.add_argument('--short-seq-prob', type=float, default=0.1, 109 | help='probability of creating a short sequence') 110 | parser.add_argument('--seed', type=int, default=1234, 111 | help='random seed') 112 | args = parser.parse_args() 113 | args.rank = 0 114 | args.make_vocab_size_divisible_by = 128 115 | args.tensor_model_parallel_size = 1 116 | 117 | if args.dataset_impl == "infer": 118 | args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) 119 | 120 | # test_albert_dataset(args) 121 | test_indexed_dataset_get(args) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /third_party/megatron/data/test/test_preprocess_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMPL=cached 4 | python ../preprocess_data.py \ 5 | --input test_samples.json \ 6 | --vocab vocab.txt \ 7 | --dataset-impl ${IMPL} \ 8 | --output-prefix test_samples_${IMPL} \ 9 | --workers 1 \ 10 | --log-interval 2 11 | -------------------------------------------------------------------------------- /third_party/megatron/dist_signal_handler.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | import torch 4 | 5 | 6 | def get_world_size(): 7 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 8 | world_size = torch.distributed.get_world_size() 9 | else: 10 | world_size = 1 11 | return world_size 12 | 13 | 14 | def get_device(local_rank=None): 15 | backend = torch.distributed.get_backend() 16 | if backend == 'nccl': 17 | if local_rank is None: 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device(f'cuda:{local_rank}') 21 | elif backend == 'gloo': 22 | device = torch.device('cpu') 23 | else: 24 | raise RuntimeError 25 | return device 26 | 27 | 28 | def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): 29 | if not torch.distributed.is_available() or \ 30 | not torch.distributed.is_initialized(): 31 | return [item] 32 | 33 | device = get_device(local_rank) 34 | 35 | if group is not None: 36 | group_size = group.size() 37 | else: 38 | group_size = get_world_size() 39 | 40 | tensor = torch.tensor([item], device=device, dtype=dtype) 41 | output_tensors = [ 42 | torch.zeros(1, dtype=tensor.dtype, device=tensor.device) 43 | for _ in range(group_size) 44 | ] 45 | torch.distributed.all_gather(output_tensors, tensor, group, async_op) 46 | output = [elem.item() for elem in output_tensors] 47 | return output 48 | 49 | 50 | class DistributedSignalHandler: 51 | def __init__(self, sig=signal.SIGTERM): 52 | self.sig = sig 53 | 54 | def signals_received(self): 55 | all_received = all_gather_item( 56 | self._signal_received, dtype=torch.int32 57 | ) 58 | return all_received 59 | 60 | def __enter__(self): 61 | self._signal_received = False 62 | self.released = False 63 | self.original_handler = signal.getsignal(self.sig) 64 | 65 | def handler(signum, frame): 66 | self._signal_received = True 67 | 68 | signal.signal(self.sig, handler) 69 | 70 | return self 71 | 72 | def __exit__(self, type, value, tb): 73 | self.release() 74 | 75 | def release(self): 76 | if self.released: 77 | return False 78 | 79 | signal.signal(self.sig, self.original_handler) 80 | self.released = True 81 | return True 82 | 
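# Usage sketch (illustrative only; `num_steps`, `train_step`, and `save_and_exit` are
# placeholder names, not part of this module):
#
#     with DistributedSignalHandler(signal.SIGTERM) as handler:
#         for step in range(num_steps):
#             train_step()
#             if any(handler.signals_received()):
#                 save_and_exit()
#                 break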
-------------------------------------------------------------------------------- /third_party/megatron/fp16_deprecated/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """For backward compatibility, we need the class definitions to deserialize.""" 4 | 5 | class LossScaler: 6 | def __init__(self, scale=1): 7 | self.cur_scale = scale 8 | 9 | class DynamicLossScaler: 10 | def __init__(self, 11 | init_scale=2**32, 12 | scale_factor=2., 13 | scale_window=1000, 14 | min_scale=1, 15 | delayed_shift=1, 16 | consecutive_hysteresis=False): 17 | self.cur_scale = init_scale 18 | self.cur_iter = 0 19 | self.last_overflow_iter = -1 20 | self.scale_factor = scale_factor 21 | self.scale_window = scale_window 22 | self.min_scale = min_scale 23 | self.delayed_shift = delayed_shift 24 | self.cur_hysteresis = delayed_shift 25 | self.consecutive_hysteresis = consecutive_hysteresis 26 | 27 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import os 4 | import pathlib 5 | import subprocess 6 | 7 | from torch.utils import cpp_extension 8 | 9 | # Setting this param to a list has a problem of generating different 10 | # compilation commands (with diferent order of architectures) and 11 | # leading to recompilation of fused kernels. Set it to empty string 12 | # to avoid recompilation and assign arch flags explicity in 13 | # extra_cuda_cflags below 14 | os.environ["TORCH_CUDA_ARCH_LIST"] = "" 15 | 16 | 17 | def load(args): 18 | 19 | # Check if cuda 11 is installed for compute capability 8.0 20 | cc_flag = [] 21 | _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( 22 | cpp_extension.CUDA_HOME) 23 | if int(bare_metal_major) >= 11: 24 | cc_flag.append('-gencode') 25 | cc_flag.append('arch=compute_80,code=sm_80') 26 | if int(bare_metal_minor) >= 7: 27 | cc_flag.append('-gencode') 28 | cc_flag.append('arch=compute_90,code=sm_90') 29 | 30 | # Build path 31 | srcpath = pathlib.Path(__file__).parent.absolute() 32 | buildpath = srcpath / 'build' 33 | _create_build_dir(buildpath) 34 | 35 | # Helper function to build the kernels. 36 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags): 37 | return cpp_extension.load( 38 | name=name, 39 | sources=sources, 40 | build_directory=buildpath, 41 | extra_cflags=['-O3',], 42 | extra_cuda_cflags=['-O3', 43 | '-gencode', 'arch=compute_70,code=sm_70', 44 | '--use_fast_math'] + extra_cuda_flags + cc_flag, 45 | verbose=(args.rank == 0) 46 | ) 47 | 48 | # ============== 49 | # Fused softmax. 50 | # ============== 51 | 52 | if args.masked_softmax_fusion: 53 | extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', 54 | '-U__CUDA_NO_HALF_CONVERSIONS__', 55 | '--expt-relaxed-constexpr', 56 | '--expt-extended-lambda'] 57 | 58 | # Upper triangular softmax. 59 | sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', 60 | srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] 61 | scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( 62 | "scaled_upper_triang_masked_softmax_cuda", 63 | sources, extra_cuda_flags) 64 | 65 | # Masked softmax. 
66 | sources=[srcpath / 'scaled_masked_softmax.cpp', 67 | srcpath / 'scaled_masked_softmax_cuda.cu'] 68 | scaled_masked_softmax_cuda = _cpp_extention_load_helper( 69 | "scaled_masked_softmax_cuda", sources, extra_cuda_flags) 70 | 71 | # Softmax 72 | sources=[srcpath / 'scaled_softmax.cpp', 73 | srcpath / 'scaled_softmax_cuda.cu'] 74 | scaled_softmax_cuda = _cpp_extention_load_helper( 75 | "scaled_softmax_cuda", sources, extra_cuda_flags) 76 | 77 | 78 | def _get_cuda_bare_metal_version(cuda_dir): 79 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], 80 | universal_newlines=True) 81 | output = raw_output.split() 82 | release_idx = output.index("release") + 1 83 | release = output[release_idx].split(".") 84 | bare_metal_major = release[0] 85 | bare_metal_minor = release[1][0] 86 | 87 | return raw_output, bare_metal_major, bare_metal_minor 88 | 89 | 90 | def _create_build_dir(buildpath): 91 | try: 92 | os.mkdir(buildpath) 93 | except OSError: 94 | if not os.path.isdir(buildpath): 95 | print(f"Creation of the build directory {buildpath} failed") 96 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/compat.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | /*This code is copied fron NVIDIA apex: 4 | * https://github.com/NVIDIA/apex 5 | * with minor changes. */ 6 | 7 | 8 | 9 | #ifndef TORCH_CHECK 10 | #define TORCH_CHECK AT_CHECK 11 | #endif 12 | 13 | #ifdef VERSION_GE_1_3 14 | #define DATA_PTR data_ptr 15 | #else 16 | #define DATA_PTR data 17 | #endif 18 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/scaled_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | torch::Tensor const& mask, 14 | float scale_factor); 15 | 16 | torch::Tensor bwd_cuda( 17 | torch::Tensor const& output_grads, 18 | torch::Tensor const& softmax_results, 19 | float scale_factor); 20 | 21 | int get_batch_per_block_cuda( 22 | int query_seq_len, 23 | int key_seq_len, 24 | int batches, 25 | int attn_heads); 26 | 27 | torch::Tensor fwd( 28 | torch::Tensor const& input, 29 | torch::Tensor const& mask, 30 | float scale_factor) { 31 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 32 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 33 | (input.scalar_type() == at::ScalarType::BFloat16), 34 | "Only fp16 and bf16 are supported"); 35 | AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); 36 | 37 | return fwd_cuda(input, mask, scale_factor); 38 | } 39 | 40 | torch::Tensor bwd( 41 | torch::Tensor const& output_grads, 42 | torch::Tensor const& softmax_results, 43 | float scale_factor) { 44 | 45 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 46 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 47 | 48 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 49 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 50 | "Only fp16 and bf16 are supported"); 51 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 52 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 53 | "Only fp16 and bf16 are supported"); 54 | 55 | return bwd_cuda(output_grads, softmax_results, scale_factor); 56 | } 57 | 58 | int get_batch_per_block( 59 | int query_seq_len, 60 | int key_seq_len, 61 | int batches, 62 | int attn_heads) { 63 | return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); 64 | } 65 | 66 | } // end namespace scaled_masked_softmax 67 | } // end namespace fused_softmax 68 | } // end namespace multihead_attn 69 | 70 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 71 | m.def("forward", 72 | &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 73 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 74 | 75 | m.def("backward", 76 | &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, 77 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 78 | 79 | m.def("get_batch_per_block", 80 | &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, 81 | "Return Batch per block size." 82 | ); 83 | } 84 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/scaled_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_masked_softmax { 16 | 17 | int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ 18 | return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); 19 | } 20 | 21 | 22 | torch::Tensor fwd_cuda( 23 | torch::Tensor const& input, 24 | torch::Tensor const& mask, 25 | float scale_factor) 26 | { 27 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 28 | const int batches = input.size(0); 29 | const int pad_batches = mask.size(0); 30 | const int attn_heads = input.size(1); 31 | const int query_seq_len = input.size(2); 32 | const int key_seq_len = input.size(3); 33 | TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); 34 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 35 | TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); 36 | TORCH_INTERNAL_ASSERT(mask.size(1) == 1); 37 | TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); 38 | TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); 39 | 40 | // Output 41 | auto act_options = input.options().requires_grad(false); 42 | torch::Tensor softmax_results = 43 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 44 | 45 | // Softmax Intermediate Result Ptr 46 | void* input_ptr = static_cast(input.data_ptr()); 47 | void* mask_ptr = static_cast(mask.data_ptr()); 48 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 49 | 50 | DISPATCH_HALF_AND_BFLOAT( 51 | input.scalar_type(), 52 | "dispatch_scaled_masked_softmax_forward", 53 | dispatch_scaled_masked_softmax_forward( 54 | reinterpret_cast(softmax_results_ptr), 55 | reinterpret_cast(input_ptr), 56 | reinterpret_cast(mask_ptr), 57 | scale_factor, 58 | query_seq_len, 59 | key_seq_len, 60 | batches, 61 | attn_heads, 62 | pad_batches); 63 | ); 64 | return softmax_results; 65 | } 66 | 67 | torch::Tensor bwd_cuda( 68 | torch::Tensor const& output_grads_, 69 | torch::Tensor const& softmax_results_, 70 | float scale_factor) { 71 | 72 | auto output_grads = output_grads_.contiguous(); 73 | auto softmax_results = softmax_results_.contiguous(); 74 | 75 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 76 | const int batches = output_grads.size(0); 77 | const int attn_heads = output_grads.size(1); 78 | const int query_seq_len = output_grads.size(2); 79 | const int key_seq_len = output_grads.size(3); 80 | 81 | auto act_options = output_grads.options().requires_grad(false); 82 | torch::Tensor input_grads = 83 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 84 | 85 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 86 | void* input_grads_ptr = static_cast(input_grads.data_ptr()); 87 | 88 | //Softmax Grad 89 | DISPATCH_HALF_AND_BFLOAT( 90 | output_grads_.scalar_type(), 91 | "dispatch_scaled_masked_softmax_backward", 92 | dispatch_scaled_masked_softmax_backward( 93 | reinterpret_cast(input_grads_ptr), 94 | reinterpret_cast(output_grads_ptr), 95 | reinterpret_cast(softmax_results.data_ptr()), 96 | scale_factor, 97 | query_seq_len, 98 | key_seq_len, 99 | batches, 100 | attn_heads); 101 | ); 102 | 103 | return input_grads; 104 | } 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- 
/third_party/megatron/fused_kernels/scaled_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd( 21 | torch::Tensor const& input, 22 | float scale_factor) { 23 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 24 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 25 | (input.scalar_type() == at::ScalarType::BFloat16), 26 | "Only fp16 and bf16 are supported"); 27 | 28 | return fwd_cuda(input, scale_factor); 29 | } 30 | 31 | torch::Tensor bwd( 32 | torch::Tensor const& output_grads, 33 | torch::Tensor const& softmax_results, 34 | float scale_factor) { 35 | 36 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 37 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 38 | 39 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 40 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 41 | "Only fp16 and bf16 are supported"); 42 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 43 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 44 | "Only fp16 and bf16 are supported"); 45 | 46 | return bwd_cuda(output_grads, softmax_results, scale_factor); 47 | } 48 | 49 | } // end namespace scaled_softmax 50 | } // end namespace fused_softmax 51 | } // end namespace multihead_attn 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", 55 | &multihead_attn::fused_softmax::scaled_softmax::fwd, 56 | "Self Multihead Attention scaled, softmax -- Forward."); 57 | m.def("backward", 58 | &multihead_attn::fused_softmax::scaled_softmax::bwd, 59 | "Self Multihead Attention scaled, softmax -- Backward."); 60 | } 61 | 62 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/scaled_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 22 | const int batches = input.size(0); 23 | const int attn_heads = input.size(1); 24 | const int query_seq_len = input.size(2); 25 | const int key_seq_len = input.size(3); 26 | TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); 27 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 28 | 29 | // Output 30 | auto act_options = input.options().requires_grad(false); 31 | torch::Tensor softmax_results = 32 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 33 | 34 | // Softmax Intermediate Result Ptr 35 | void* input_ptr = static_cast(input.data_ptr()); 36 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 37 | 38 | DISPATCH_HALF_AND_BFLOAT( 39 | input.scalar_type(), 40 | "dispatch_scaled_softmax_forward", 41 | dispatch_scaled_softmax_forward( 42 | reinterpret_cast(softmax_results_ptr), 43 | reinterpret_cast(input_ptr), 44 | scale_factor, 45 | query_seq_len, 46 | key_seq_len, 47 | batches, 48 | attn_heads); 49 | ); 50 | return softmax_results; 51 | } 52 | 53 | torch::Tensor bwd_cuda( 54 | torch::Tensor const& output_grads_, 55 | torch::Tensor const& softmax_results_, 56 | float scale_factor) { 57 | 58 | auto output_grads = output_grads_.contiguous(); 59 | auto softmax_results = softmax_results_.contiguous(); 60 | 61 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 62 | const int batches = output_grads.size(0); 63 | const int attn_heads = output_grads.size(1); 64 | const int query_seq_len = output_grads.size(2); 65 | const int key_seq_len = output_grads.size(3); 66 | 67 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 68 | 69 | //Softmax Grad 70 | DISPATCH_HALF_AND_BFLOAT( 71 | output_grads_.scalar_type(), 72 | "dispatch_scaled_masked_softmax_backward", 73 | dispatch_scaled_masked_softmax_backward( 74 | reinterpret_cast(output_grads_ptr), 75 | reinterpret_cast(output_grads_ptr), 76 | reinterpret_cast(softmax_results.data_ptr()), 77 | scale_factor, 78 | query_seq_len, 79 | key_seq_len, 80 | batches, 81 | attn_heads); 82 | ); 83 | 84 | //backward pass is completely in-place 85 | return output_grads; 86 | } 87 | } 88 | } 89 | } 90 | 91 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_upper_triang_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 21 | AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); 22 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 23 | (input.scalar_type() == at::ScalarType::BFloat16), 24 | "Only fp16 and bf16 are supported"); 25 | 26 | return fwd_cuda(input, scale_factor); 27 | } 28 | 29 | torch::Tensor bwd( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor) { 33 | 34 | AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); 35 | AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); 36 | 37 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 38 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 39 | "Only fp16 and bf16 are supported"); 40 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 41 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 42 | "Only fp16 and bf16 are supported"); 43 | 44 | return bwd_cuda(output_grads, softmax_results, scale_factor); 45 | } 46 | 47 | } // end namespace scaled_upper_triang_masked_softmax 48 | } // end namespace fused_softmax 49 | } // end namespace multihead_attn 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", 53 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 54 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 55 | m.def("backward", 56 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 57 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 58 | } 59 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_upper_triang_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_upper_triang_masked_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 22 | const int attn_batches = input.size(0); 23 | const int seq_len = input.size(1); 24 | TORCH_INTERNAL_ASSERT(seq_len <= 16384); 25 | 26 | // Output 27 | auto act_options = input.options().requires_grad(false); 28 | torch::Tensor softmax_results = 29 | torch::empty({attn_batches, seq_len, seq_len}, act_options); 30 | 31 | // Softmax Intermediate Result Ptr 32 | void* input_ptr = static_cast(input.data_ptr()); 33 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 34 | 35 | DISPATCH_HALF_AND_BFLOAT( 36 | input.scalar_type(), 37 | "dispatch_scaled_upper_triang_masked_softmax_forward", 38 | dispatch_scaled_upper_triang_masked_softmax_forward( 39 | reinterpret_cast(softmax_results_ptr), 40 | reinterpret_cast(input_ptr), 41 | scale_factor, 42 | seq_len, 43 | seq_len, 44 | attn_batches); 45 | ); 46 | return softmax_results; 47 | } 48 | 49 | 50 | torch::Tensor bwd_cuda( 51 | torch::Tensor const& output_grads_, 52 | torch::Tensor const& softmax_results_, 53 | float scale_factor) { 54 | 55 | auto output_grads = output_grads_.contiguous(); 56 | auto softmax_results = softmax_results_.contiguous(); 57 | 58 | //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 59 | const int attn_batches = output_grads.size(0); 60 | const int seq_len = output_grads.size(1); 61 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); 62 | 63 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 64 | 65 | //Softmax Grad 66 | DISPATCH_HALF_AND_BFLOAT( 67 | output_grads_.scalar_type(), 68 | "dispatch_scaled_upper_triang_masked_softmax_backward", 69 | dispatch_scaled_upper_triang_masked_softmax_backward( 70 | reinterpret_cast(output_grads_ptr), 71 | reinterpret_cast(output_grads_ptr), 72 | reinterpret_cast(softmax_results.data_ptr()), 73 | scale_factor, 74 | seq_len, 75 | seq_len, 76 | attn_batches); 77 | ); 78 | 79 | //backward pass is completely in-place 80 | return output_grads; 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexiScale/5ee1726bd761b6ba57d3315a35c2cf84ce6ca57e/third_party/megatron/fused_kernels/tests/__init__.py -------------------------------------------------------------------------------- /third_party/megatron/fused_kernels/type_shim.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | 4 | #include 5 | #include "compat.h" 6 | 7 | 8 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ 9 | switch(TYPE) \ 10 | { \ 11 | case at::ScalarType::Half: \ 12 | { \ 13 | using scalar_t = at::Half; \ 14 | __VA_ARGS__; \ 15 | break; \ 16 | } \ 17 | case at::ScalarType::BFloat16: \ 18 | { \ 19 | using scalar_t = at::BFloat16; \ 20 | __VA_ARGS__; \ 21 | break; \ 22 | } \ 23 | default: \ 24 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 25 | } 26 | 27 | 28 | #define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ 29 | switch(TYPE) \ 30 | { \ 31 | case at::ScalarType::Half: \ 32 | { \ 33 | using scalar_t = at::Half; \ 34 | __VA_ARGS__; \ 35 | break; \ 36 | } \ 37 | case at::ScalarType::BFloat16: \ 38 | { \ 39 | using scalar_t = at::BFloat16; \ 40 | __VA_ARGS__; \ 41 | break; \ 42 | } \ 43 | case at::ScalarType::Float: \ 44 | { \ 45 | using scalar_t = float; \ 46 | __VA_ARGS__; \ 47 | break; \ 48 | } \ 49 | default: \ 50 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 51 | } 52 | 53 | 54 | 55 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ 56 | switch(TYPEIN) \ 57 | { \ 58 | case at::ScalarType::Float: \ 59 | { \ 60 | using scalar_t_in = float; \ 61 | switch(TYPEOUT) \ 62 | { \ 63 | case at::ScalarType::Float: \ 64 | { \ 65 | using scalar_t_out = float; \ 66 | __VA_ARGS__; \ 67 | break; \ 68 | } \ 69 | case at::ScalarType::Half: \ 70 | { \ 71 | using scalar_t_out = at::Half; \ 72 | __VA_ARGS__; \ 73 | break; \ 74 | } \ 75 | case at::ScalarType::BFloat16: \ 76 | { \ 77 | using scalar_t_out = at::BFloat16; \ 78 | __VA_ARGS__; \ 79 | break; \ 80 | } \ 81 | default: \ 82 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ 83 | } \ 84 | break; \ 85 | } \ 86 | case at::ScalarType::Half: \ 87 | { \ 88 | using scalar_t_in = at::Half; \ 89 | using scalar_t_out = at::Half; \ 90 | __VA_ARGS__; \ 91 | break; \ 92 | } \ 93 | case at::ScalarType::BFloat16: \ 94 | { \ 95 | using scalar_t_in = at::BFloat16; \ 96 | using scalar_t_out = at::BFloat16; \ 97 | __VA_ARGS__; \ 98 | break; \ 99 | } \ 100 | default: \ 101 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ 102 | } 103 | 104 | -------------------------------------------------------------------------------- /third_party/megatron/indexer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from third_party.megatron import get_args, print_rank_0 7 | from third_party.megatron.core import mpu 8 | from third_party.megatron.checkpointing import load_biencoder_checkpoint 9 | from third_party.megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset 10 | from third_party.megatron.data.orqa_wiki_dataset import get_open_retrieval_batch 11 | from third_party.megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader 12 | from third_party.megatron.data.realm_index import detach, OpenRetreivalDataStore 13 | from third_party.megatron.model.biencoder_model import get_model_provider 14 | from third_party.megatron.training import get_model 15 | 16 | 17 | class IndexBuilder(object): 18 | """ 19 | Object for taking one pass over a dataset and creating a BlockData of its 20 | embeddings 21 | """ 22 | def __init__(self): 23 | args = get_args() 24 | self.model = None 25 | self.dataloader = None 26 | self.evidence_embedder_obj = None 27 | self.biencoder_shared_query_context_model = \ 28 | args.biencoder_shared_query_context_model 29 | 30 | # need to know whether we're using a REALM checkpoint (args.load) 31 | 
# or ICT checkpoint 32 | assert not (args.load and args.ict_load) 33 | 34 | self.log_interval = args.indexer_log_interval 35 | self.batch_size = args.indexer_batch_size 36 | 37 | self.load_attributes() 38 | self.is_main_builder = mpu.get_data_parallel_rank() == 0 39 | self.num_total_builders = mpu.get_data_parallel_world_size() 40 | self.iteration = self.total_processed = 0 41 | 42 | def load_attributes(self): 43 | """ 44 | Load the necessary attributes: model, dataloader and empty BlockData 45 | """ 46 | only_context_model = True 47 | if self.biencoder_shared_query_context_model: 48 | only_context_model = False 49 | 50 | model = get_model(get_model_provider(only_context_model=\ 51 | only_context_model, biencoder_shared_query_context_model=\ 52 | self.biencoder_shared_query_context_model)) 53 | 54 | self.model = load_biencoder_checkpoint(model, 55 | only_context_model=only_context_model) 56 | 57 | assert len(self.model) == 1 58 | self.model[0].eval() 59 | 60 | self.dataset = get_open_retrieval_wiki_dataset() 61 | self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ 62 | self.batch_size)) 63 | 64 | self.evidence_embedder_obj = OpenRetreivalDataStore( \ 65 | load_from_path=False) 66 | 67 | def track_and_report_progress(self, batch_size): 68 | """ 69 | Utility function for tracking progress 70 | """ 71 | self.iteration += 1 72 | self.total_processed += batch_size * self.num_total_builders 73 | if self.is_main_builder and self.iteration % self.log_interval == 0: 74 | print('Batch {:10d} | Total {:10d}'.format(self.iteration, 75 | self.total_processed), flush=True) 76 | 77 | def build_and_save_index(self): 78 | """ 79 | Goes through one epoch of the dataloader and adds all data to this 80 | instance's BlockData. 81 | 82 | The copy of BlockData is saved as a shard, which when run in a 83 | distributed setting will be consolidated by the rank 0 process 84 | and saved as a final pickled BlockData. 
85 | """ 86 | assert len(self.model) == 1 87 | unwrapped_model = self.model[0] 88 | 89 | while not hasattr(unwrapped_model, 'embed_text'): 90 | unwrapped_model = unwrapped_model.module 91 | 92 | while True: 93 | try: 94 | # batch also has query_tokens and query_pad_data 95 | row_id, context_tokens, context_mask, context_types, \ 96 | context_pad_mask = get_open_retrieval_batch( \ 97 | self.dataloader) 98 | except (StopIteration, IndexError): 99 | break 100 | 101 | # TODO: can we add with torch.no_grad() to reduce memory usage 102 | # detach, separate fields and add to BlockData 103 | assert context_mask.dtype == torch.bool 104 | context_logits = unwrapped_model.embed_text( 105 | unwrapped_model.context_model, context_tokens, context_mask, 106 | context_types) 107 | 108 | context_logits = detach(context_logits) 109 | row_id = detach(row_id) 110 | 111 | self.evidence_embedder_obj.add_block_data(row_id, context_logits) 112 | self.track_and_report_progress(batch_size=len(row_id)) 113 | 114 | # This process signals to finalize its shard and then synchronize with 115 | # the other processes 116 | self.evidence_embedder_obj.save_shard() 117 | torch.distributed.barrier() 118 | del self.model 119 | 120 | # rank 0 process builds the final copy 121 | if self.is_main_builder: 122 | self.evidence_embedder_obj.merge_shards_and_save() 123 | # make sure that every single piece of data was embedded 124 | assert len(self.evidence_embedder_obj.embed_data) == \ 125 | len(self.dataset) 126 | self.evidence_embedder_obj.clear() 127 | 128 | # complete building the final copy 129 | torch.distributed.barrier() 130 | -------------------------------------------------------------------------------- /third_party/megatron/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | import torch 5 | 6 | 7 | # A dictionary of all the memory buffers allocated. 8 | _MEM_BUFFS = dict() 9 | 10 | 11 | def allocate_mem_buff(name, numel, dtype, track_usage): 12 | """Allocate a memory buffer.""" 13 | assert name not in _MEM_BUFFS, \ 14 | 'memory buffer {} already allocated.'.format(name) 15 | _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) 16 | return _MEM_BUFFS[name] 17 | 18 | 19 | def get_mem_buff(name): 20 | """Get the memory buffer.""" 21 | return _MEM_BUFFS[name] 22 | 23 | 24 | class MemoryBuffer: 25 | """Contiguous memory buffer. 26 | Allocate a contiguous memory of type `dtype` and size `numel`. It is 27 | used to reduce memory fragmentation. 28 | 29 | Usage: After the allocation, the `_start` index is set tot the first 30 | index of the memory. A memory chunk starting from `_start` index 31 | can be `allocated` for an input tensor, with the elements of the 32 | tensor being coppied. The buffer can be reused by resetting the 33 | `_start` index. 34 | 35 | """ 36 | def __init__(self, name, numel, dtype, track_usage): 37 | if torch.distributed.get_rank() == 0: 38 | element_size = torch.tensor([], dtype=dtype).element_size() 39 | print('> building the {} memory buffer with {} num elements ' 40 | 'and {} dtype ({:.1f} MB)...'.format( 41 | name, numel, dtype, numel*element_size/1024/1024), 42 | flush=True) 43 | self.name = name 44 | self.numel = numel 45 | self.dtype = dtype 46 | self.data = torch.empty(self.numel, 47 | dtype=self.dtype, 48 | device=torch.cuda.current_device(), 49 | requires_grad=False) 50 | 51 | # Index tracking the start of the free memory. 
52 | self._start = 0 53 | 54 | # Values used for tracking usage. 55 | self.track_usage = track_usage 56 | if self.track_usage: 57 | self.in_use_value = 0.0 58 | self.total_value = 0.0 59 | 60 | 61 | def reset(self): 62 | """Reset the buffer start index to the beginning of the buffer.""" 63 | self._start = 0 64 | 65 | 66 | def is_in_use(self): 67 | """Whether the current buffer hold on to any memory.""" 68 | return self._start > 0 69 | 70 | 71 | def numel_in_use(self): 72 | """Return number of elements in use.""" 73 | return self._start 74 | 75 | 76 | def add(self, tensor): 77 | """Allocate a chunk of memory from the buffer to tensor and copy 78 | the values.""" 79 | assert tensor.dtype == self.dtype, \ 80 | 'Input tensor type {} different from buffer type {}'.format( 81 | tensor.dtype, self.dtype) 82 | # Number of elements of the input tensor. 83 | tensor_numel = torch.numel(tensor) 84 | new_start = self._start + tensor_numel 85 | assert new_start <= self.numel, \ 86 | 'Not enough memory left in the buffer ({} > {})'.format( 87 | tensor_numel, self.numel - self._start) 88 | # New tensor is a view into the memory. 89 | new_tensor = self.data[self._start:new_start] 90 | self._start = new_start 91 | new_tensor = new_tensor.view(tensor.shape) 92 | new_tensor.copy_(tensor) 93 | # Return a pointer to the new tensor. 94 | return new_tensor 95 | 96 | 97 | def get_data(self): 98 | """Return the data currently in use.""" 99 | if self.track_usage: 100 | self.in_use_value += float(self._start) 101 | self.total_value += float(self.numel) 102 | return self.data[:self._start] 103 | 104 | 105 | def print_average_usage(self): 106 | """Print memory usage average over time. We would like this value 107 | to be as high as possible.""" 108 | assert self.track_usage, 'You need to enable track usage.' 109 | if torch.distributed.get_rank() == 0: 110 | print(' > usage of {} memory buffer: {:.2f} %'.format( 111 | self.name, self.in_use_value * 100.0 / self.total_value), 112 | flush=True) 113 | 114 | 115 | 116 | class RingMemBuffer: 117 | """A ring of memory buffers.""" 118 | 119 | def __init__(self, name, num_buffers, numel, dtype, track_usage): 120 | self.num_buffers = num_buffers 121 | self.buffers = [ 122 | allocate_mem_buff(name+' {}'.format(i), numel, dtype, track_usage) 123 | for i in range(num_buffers)] 124 | self._index = -1 125 | 126 | 127 | def get_next_buffer(self): 128 | self._index += 1 129 | self._index = self._index % self.num_buffers 130 | buff = self.buffers[self._index] 131 | assert not buff.is_in_use(), 'buffer is already in use.' 132 | return buff 133 | -------------------------------------------------------------------------------- /third_party/megatron/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm 4 | 5 | from .distributed import DistributedDataParallel 6 | from .bert_model import BertModel 7 | from .gpt_model import GPTModel 8 | from .t5_model import T5Model 9 | from .language_model import get_language_model 10 | from .module import Float16Module, MegatronModule 11 | -------------------------------------------------------------------------------- /third_party/megatron/model/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
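# Editorial usage sketch for the MemoryBuffer API defined in
# third_party/megatron/memory.py above (buffer name, sizes and tensors are
# illustrative, not taken from this repository):
#
#     buff = allocate_mem_buff('demo', numel=1 << 20,
#                              dtype=torch.float16, track_usage=False)
#     x = torch.randn(128, 1024, dtype=torch.float16, device='cuda')
#     x_view = buff.add(x)   # contiguous copy of x living inside the buffer
#     ...                    # use x_view in place of x for this iteration
#     buff.reset()           # hand the whole region back before the next one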
2 | 3 | """Classification model.""" 4 | 5 | import torch 6 | 7 | from third_party.megatron import get_args, print_rank_last 8 | from third_party.megatron.model.enums import AttnMaskType 9 | from third_party.megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids 10 | from third_party.megatron.model.language_model import get_language_model 11 | from third_party.megatron.model.utils import get_linear_layer 12 | from third_party.megatron.model.utils import init_method_normal 13 | from third_party.megatron.model.utils import scaled_init_method_normal 14 | from .module import MegatronModule 15 | 16 | 17 | class Classification(MegatronModule): 18 | 19 | def __init__(self, 20 | num_classes, 21 | num_tokentypes=2, 22 | pre_process=True, 23 | post_process=True): 24 | super(Classification, self).__init__(share_word_embeddings=False) 25 | args = get_args() 26 | 27 | self.num_classes = num_classes 28 | self.pre_process = pre_process 29 | self.post_process = post_process 30 | init_method = init_method_normal(args.init_method_std) 31 | 32 | self.language_model, self._language_model_key = get_language_model( 33 | num_tokentypes=num_tokentypes, 34 | add_pooler=True, 35 | encoder_attn_mask_type=AttnMaskType.padding, 36 | init_method=init_method, 37 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 38 | args.num_layers), 39 | pre_process=self.pre_process, 40 | post_process=self.post_process) 41 | 42 | # Multi-choice head. 43 | if self.post_process: 44 | self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) 45 | self.classification_head = get_linear_layer(args.hidden_size, 46 | self.num_classes, 47 | init_method) 48 | self._classification_head_key = 'classification_head' 49 | 50 | def set_input_tensor(self, input_tensor): 51 | """See megatron.model.transformer.set_input_tensor()""" 52 | self.language_model.set_input_tensor(input_tensor) 53 | 54 | def forward(self, model_input, attention_mask, tokentype_ids=None): 55 | 56 | extended_attention_mask = bert_extended_attention_mask(attention_mask) 57 | input_ids = model_input 58 | position_ids = bert_position_ids(input_ids) 59 | 60 | lm_output = self.language_model( 61 | input_ids, 62 | position_ids, 63 | extended_attention_mask, 64 | tokentype_ids=tokentype_ids 65 | ) 66 | 67 | if self.post_process: 68 | _, pooled_output = lm_output 69 | classification_output = self.classification_dropout(pooled_output) 70 | classification_logits = self.classification_head(classification_output) 71 | 72 | # Reshape back to separate choices. 
73 | classification_logits = classification_logits.view(-1, self.num_classes) 74 | 75 | return classification_logits 76 | return lm_output 77 | 78 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 79 | """For easy load when model is combined with other heads, 80 | add an extra key.""" 81 | 82 | state_dict_ = {} 83 | state_dict_[self._language_model_key] \ 84 | = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, 85 | keep_vars=keep_vars) 86 | if self.post_process: 87 | state_dict_[self._classification_head_key] \ 88 | = self.classification_head.state_dict(prefix=prefix, keep_vars=keep_vars) 89 | return state_dict_ 90 | 91 | def load_state_dict(self, state_dict, strict=True): 92 | """Customized load.""" 93 | 94 | self.language_model.load_state_dict( 95 | state_dict[self._language_model_key], strict=strict) 96 | if self.post_process: 97 | if self._classification_head_key in state_dict: 98 | self.classification_head.load_state_dict( 99 | state_dict[self._classification_head_key], strict=strict) 100 | else: 101 | print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 102 | 'initializing to random'.format( 103 | self._classification_head_key)) 104 | -------------------------------------------------------------------------------- /third_party/megatron/model/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import enum 4 | 5 | class LayerType(enum.Enum): 6 | encoder = 1 7 | decoder = 2 8 | 9 | class AttnType(enum.Enum): 10 | self_attn = 1 11 | cross_attn = 2 12 | 13 | class AttnMaskType(enum.Enum): 14 | padding = 1 15 | causal = 2 16 | 17 | # For backward compatibility with old model checkpoints 18 | from third_party.megatron.core.enums import ModelType 19 | -------------------------------------------------------------------------------- /third_party/megatron/model/fused_bias_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | 6 | ###### BIAS GELU FUSION/ NO AUTOGRAD ################ 7 | # 1/sqrt(2*pi)-> 0.3989423 8 | # 1/sqrt(2) -> 0.70710678 9 | # sqrt(2/pi) -> 0.79788456 10 | # this function is tanh approximation of gelu 11 | # actual gelu is: 12 | # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) 13 | 14 | @torch.jit.script 15 | def bias_gelu(bias, y): 16 | x = bias + y 17 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 18 | 19 | # gradient of tanh approximation of gelu 20 | # gradient of actual gelu is: 21 | # 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 22 | @torch.jit.script 23 | def bias_gelu_back(g, bias, y): 24 | x = bias + y 25 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 26 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 27 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 28 | return ff*g 29 | 30 | class GeLUFunction(torch.autograd.Function): 31 | @staticmethod 32 | # bias is an optional argument 33 | def forward(ctx, input, bias): 34 | ctx.save_for_backward(input, bias) 35 | return bias_gelu(bias, input) 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | input, bias = ctx.saved_tensors 40 | tmp = bias_gelu_back(grad_output, bias, input) 41 | return tmp, tmp 42 | 43 | bias_gelu_impl = GeLUFunction.apply 44 | -------------------------------------------------------------------------------- /third_party/megatron/model/fused_layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """This code is copied fron NVIDIA apex: 4 | https://github.com/NVIDIA/apex 5 | with some changes. """ 6 | 7 | import numbers 8 | import torch 9 | from torch.nn.parameter import Parameter 10 | from torch.nn import init 11 | import importlib 12 | 13 | from third_party.megatron.core.utils import make_viewless_tensor 14 | 15 | try: 16 | from apex.contrib.layer_norm.layer_norm import FastLayerNormFN 17 | HAVE_PERSIST_LAYER_NORM = True 18 | except: 19 | HAVE_PERSIST_LAYER_NORM = False 20 | 21 | # from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction 22 | 23 | 24 | global fused_layer_norm_cuda 25 | fused_layer_norm_cuda = None 26 | 27 | 28 | class MixedFusedLayerNorm(torch.nn.Module): 29 | 30 | def __init__(self, normalized_shape, eps=1e-5, 31 | no_persist_layer_norm=True, 32 | sequence_parallel=False, 33 | apply_layernorm_1p=False): 34 | super(MixedFusedLayerNorm, self).__init__() 35 | 36 | self.apply_layernorm_1p = apply_layernorm_1p 37 | 38 | global fused_layer_norm_cuda 39 | fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") 40 | 41 | # List of hiddens sizes supported in the persistent layer norm kernel 42 | # If the hidden size is not supported, fall back to the non-persistent 43 | # kernel. 
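        # NOTE: the non-persistent path in forward() relies on
        # FusedLayerNormAffineFunction, whose apex import is commented out near
        # the top of this file, so that path will raise a NameError unless the
        # import is restored.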
44 | persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096, 45 | 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 46 | 24576, 25600, 30720, 32768, 40960, 49152, 65536] 47 | if normalized_shape not in persist_ln_hidden_sizes or \ 48 | not HAVE_PERSIST_LAYER_NORM: 49 | no_persist_layer_norm = True 50 | 51 | if isinstance(normalized_shape, numbers.Integral): 52 | normalized_shape = (normalized_shape,) 53 | self.normalized_shape = torch.Size(normalized_shape) 54 | self.eps = eps 55 | self.weight = Parameter(torch.Tensor(*normalized_shape)) 56 | self.bias = Parameter(torch.Tensor(*normalized_shape)) 57 | self.reset_parameters() 58 | self.no_persist_layer_norm = no_persist_layer_norm 59 | self.sequence_parallel = sequence_parallel 60 | 61 | # set sequence parallelism flag on weight and bias parameters 62 | setattr(self.weight, 'sequence_parallel', self.sequence_parallel) 63 | setattr(self.bias, 'sequence_parallel', self.sequence_parallel) 64 | 65 | 66 | def reset_parameters(self): 67 | 68 | if self.apply_layernorm_1p: 69 | init.zeros_(self.weight) 70 | init.zeros_(self.bias) 71 | else: 72 | init.ones_(self.weight) 73 | init.zeros_(self.bias) 74 | 75 | def forward(self, input): 76 | 77 | weight = self.weight + 1 if self.apply_layernorm_1p else self.weight 78 | 79 | if self.no_persist_layer_norm: 80 | return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps) 81 | else: 82 | output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) 83 | 84 | # Apex's fast layer norm function outputs a 'view' tensor (i.e., has 85 | # a populated '_base' field). This will result in schedule.py's 86 | # deallocate_output_tensor() throwing an error, so a viewless tensor is 87 | # created to prevent this. 88 | output = make_viewless_tensor(inp = output, 89 | requires_grad = input.requires_grad, 90 | keep_graph = True) 91 | 92 | return output 93 | -------------------------------------------------------------------------------- /third_party/megatron/model/gpt_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """GPT-2 model.""" 4 | 5 | import torch 6 | 7 | from third_party.megatron import get_args 8 | from third_party.megatron.core import tensor_parallel 9 | from .module import MegatronModule 10 | 11 | from .enums import AttnMaskType 12 | from .language_model import parallel_lm_logits 13 | from .language_model import get_language_model 14 | from .utils import init_method_normal 15 | from .utils import scaled_init_method_normal 16 | 17 | 18 | def post_language_model_processing(lm_output, labels, logit_weights, 19 | parallel_output, 20 | fp16_lm_cross_entropy): 21 | 22 | # Output. 
Format [s b h] 23 | output = parallel_lm_logits( 24 | lm_output, 25 | logit_weights, 26 | parallel_output) 27 | 28 | if labels is None: 29 | # [s b h] => [b s h] 30 | return output.transpose(0,1).contiguous() 31 | else: 32 | # [b s] => [s b] 33 | labels = labels.transpose(0,1).contiguous() 34 | if fp16_lm_cross_entropy: 35 | assert output.dtype == torch.half 36 | loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) 37 | else: 38 | loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) 39 | 40 | # [s b] => [b, s] 41 | loss = loss.transpose(0,1).contiguous() 42 | return loss 43 | 44 | 45 | class GPTModel(MegatronModule): 46 | """GPT-2 Language model.""" 47 | 48 | def __init__(self, 49 | num_tokentypes=0, 50 | parallel_output=True, 51 | pre_process=True, 52 | post_process=True): 53 | args = get_args() 54 | super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights) 55 | 56 | self.parallel_output = parallel_output 57 | self.pre_process = pre_process 58 | self.post_process = post_process 59 | self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy 60 | self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights 61 | 62 | self.language_model, self._language_model_key = get_language_model( 63 | num_tokentypes=num_tokentypes, 64 | add_pooler=False, 65 | encoder_attn_mask_type=AttnMaskType.causal, 66 | init_method=init_method_normal(args.init_method_std), 67 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 68 | args.num_layers), 69 | pre_process=self.pre_process, 70 | post_process=self.post_process) 71 | 72 | if not args.untie_embeddings_and_output_weights: 73 | self.initialize_word_embeddings(init_method_normal) 74 | 75 | def set_input_tensor(self, input_tensor): 76 | """See megatron.model.transformer.set_input_tensor()""" 77 | self.language_model.set_input_tensor(input_tensor) 78 | 79 | def forward(self, input_ids, position_ids, attention_mask, 80 | ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None, 81 | labels=None, tokentype_ids=None, inference_params=None): 82 | 83 | lm_output = self.language_model( 84 | input_ids, 85 | position_ids, 86 | attention_mask, 87 | ret_input_ids=ret_input_ids, 88 | ret_position_ids=ret_position_ids, 89 | ret_attn_mask=ret_attn_mask, 90 | inference_params=inference_params) 91 | 92 | if self.post_process: 93 | return post_language_model_processing( 94 | lm_output, labels, 95 | self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(), 96 | self.parallel_output, 97 | self.fp16_lm_cross_entropy) 98 | else: 99 | return lm_output 100 | 101 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 102 | 103 | state_dict_ = {} 104 | state_dict_[self._language_model_key] \ 105 | = self.language_model.state_dict_for_save_checkpoint( 106 | prefix=prefix, keep_vars=keep_vars) 107 | # Save word_embeddings. 108 | if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: 109 | state_dict_[self._word_embeddings_for_head_key] \ 110 | = self.word_embeddings.state_dict(prefix=prefix, 111 | keep_vars=keep_vars) 112 | return state_dict_ 113 | 114 | def load_state_dict(self, state_dict, strict=True): 115 | """Customized load.""" 116 | 117 | # Load word_embeddings. 
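        # Only the final pipeline stage (post_process=True, pre_process=False)
        # keeps its own copy of the tied embedding weights for the output head,
        # so only that case has extra weights to restore here.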
118 | if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: 119 | self.word_embeddings.load_state_dict( 120 | state_dict[self._word_embeddings_for_head_key], strict=strict) 121 | if self._language_model_key in state_dict: 122 | state_dict = state_dict[self._language_model_key] 123 | self.language_model.load_state_dict(state_dict, strict=strict) 124 | -------------------------------------------------------------------------------- /third_party/megatron/model/multiple_choice.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Multiple choice model.""" 4 | 5 | import torch 6 | 7 | from third_party.megatron import get_args, print_rank_last 8 | from third_party.megatron.model.enums import AttnMaskType 9 | from third_party.megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids 10 | from third_party.megatron.model.language_model import get_language_model 11 | from third_party.megatron.model.utils import get_linear_layer 12 | from third_party.megatron.model.utils import init_method_normal 13 | from third_party.megatron.model.utils import scaled_init_method_normal 14 | from .module import MegatronModule 15 | 16 | 17 | class MultipleChoice(MegatronModule): 18 | 19 | def __init__(self, 20 | num_tokentypes=2, 21 | pre_process=True, 22 | post_process=True): 23 | super(MultipleChoice, self).__init__(share_word_embeddings=False) 24 | args = get_args() 25 | 26 | init_method = init_method_normal(args.init_method_std) 27 | self.pre_process = pre_process 28 | self.post_process = post_process 29 | 30 | self.language_model, self._language_model_key = get_language_model( 31 | num_tokentypes=num_tokentypes, 32 | add_pooler=True, 33 | encoder_attn_mask_type=AttnMaskType.padding, 34 | init_method=init_method, 35 | scaled_init_method=scaled_init_method_normal(args.init_method_std, 36 | args.num_layers), 37 | pre_process=self.pre_process, 38 | post_process=self.post_process) 39 | 40 | # Multi-choice head. 41 | if self.post_process: 42 | self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) 43 | self.multichoice_head = get_linear_layer(args.hidden_size, 1, 44 | init_method) 45 | self._multichoice_head_key = 'multichoice_head' 46 | 47 | def set_input_tensor(self, input_tensor): 48 | """See megatron.model.transformer.set_input_tensor()""" 49 | self.language_model.set_input_tensor(input_tensor) 50 | 51 | def forward(self, model_input, attention_mask, tokentype_ids=None): 52 | 53 | # [batch, choices, sequence] --> [batch * choices, sequence] --> 54 | # transformer --> [batch, choices] --> softmax 55 | 56 | # Ensure the shape is [batch-size, choices, sequence] 57 | assert len(attention_mask.shape) == 3 58 | num_choices = attention_mask.shape[1] 59 | 60 | # Reshape and treat choice dimension the same as batch. 
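        # num_choices was captured from attention_mask.shape[1] above so the
        # flattened logits can be un-flattened to [batch, num_choices] after the
        # multichoice head.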
61 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 62 | extended_attention_mask = bert_extended_attention_mask(attention_mask) 63 | 64 | input_ids = model_input 65 | # Do the same as attention_mask for input_ids, tokentype_ids 66 | assert len(input_ids.shape) == 3 67 | assert len(tokentype_ids.shape) == 3 68 | input_ids = input_ids.view(-1, input_ids.size(-1)) 69 | tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) 70 | position_ids = bert_position_ids(input_ids) 71 | 72 | lm_output = self.language_model( 73 | input_ids, 74 | position_ids, 75 | extended_attention_mask, 76 | tokentype_ids=tokentype_ids 77 | ) 78 | if self.post_process: 79 | _, pooled_output = lm_output 80 | multichoice_output = self.multichoice_dropout(pooled_output) 81 | multichoice_logits = self.multichoice_head(multichoice_output) 82 | 83 | # Reshape back to separate choices. 84 | multichoice_logits = multichoice_logits.view(-1, num_choices) 85 | 86 | return multichoice_logits 87 | return lm_output 88 | 89 | def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): 90 | """For easy load when model is combined with other heads, 91 | add an extra key.""" 92 | 93 | state_dict_ = {} 94 | state_dict_[self._language_model_key] \ 95 | = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, 96 | keep_vars=keep_vars) 97 | if self.post_process: 98 | state_dict_[self._multichoice_head_key] \ 99 | = self.multichoice_head.state_dict(prefix=prefix, keep_vars=keep_vars) 100 | return state_dict_ 101 | 102 | def load_state_dict(self, state_dict, strict=True): 103 | """Customized load.""" 104 | 105 | self.language_model.load_state_dict( 106 | state_dict[self._language_model_key], strict=strict) 107 | if self.post_process: 108 | if self._multichoice_head_key in state_dict: 109 | self.multichoice_head.load_state_dict( 110 | state_dict[self._multichoice_head_key], strict=strict) 111 | else: 112 | print_rank_last('***WARNING*** could not find {} in the checkpoint, ' 113 | 'initializing to random'.format( 114 | self._multichoice_head_key)) 115 | -------------------------------------------------------------------------------- /third_party/megatron/model/rotary_pos_embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \ 4 | # 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \ 5 | # common/megatron/rotary_pos_embedding.py 6 | 7 | import importlib.util 8 | import torch 9 | 10 | from torch import einsum, nn 11 | 12 | __all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] 13 | 14 | class RotaryEmbedding(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 18 | self.register_buffer('inv_freq', inv_freq) 19 | if importlib.util.find_spec('einops') is None: 20 | raise RuntimeError("einops is required for Rotary Embedding") 21 | 22 | def forward(self, max_seq_len, offset=0): 23 | seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset 24 | freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) 25 | # first part even vector components, second part odd vector components, 26 | # 2 * dim in dimension size 27 | emb = torch.cat((freqs, freqs), dim=-1) 28 | # emb [seq_length, .., dim] 29 | from einops import rearrange 30 | return rearrange(emb, 'n d -> n 1 1 d') 31 | 32 | 33 | def _rotate_half(x): 34 | """ 35 | change sign so 
the last dimension becomes [-odd, +even] 36 | """ 37 | from einops import rearrange 38 | x = rearrange(x, '... (j d) -> ... j d', j=2) 39 | x1, x2 = x.unbind(dim=-2) 40 | return torch.cat((-x2, x1), dim=-1) 41 | 42 | 43 | def apply_rotary_pos_emb(t, freqs): 44 | """ 45 | input tensor t is of shape [seq_length, ..., dim] 46 | rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] 47 | check https://kexue.fm/archives/8265 for detailed formulas 48 | """ 49 | rot_dim = freqs.shape[-1] 50 | # ideally t_pass is empty so rotary pos embedding is applied to all tensor t 51 | t, t_pass = t[..., :rot_dim], t[..., rot_dim:] 52 | 53 | # first part is cosine component 54 | # second part is sine component, need to change signs with _rotate_half method 55 | t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin()) 56 | return torch.cat((t, t_pass), dim=-1) 57 | -------------------------------------------------------------------------------- /third_party/megatron/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Utilities for models.""" 4 | 5 | import math 6 | 7 | import torch 8 | 9 | from third_party.megatron import get_args 10 | 11 | def init_method_normal(sigma): 12 | """Init method based on N(0, sigma).""" 13 | def init_(tensor): 14 | return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) 15 | 16 | return init_ 17 | 18 | 19 | def scaled_init_method_normal(sigma, num_layers): 20 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 21 | std = sigma / math.sqrt(2.0 * num_layers) 22 | 23 | def init_(tensor): 24 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 25 | 26 | return init_ 27 | 28 | 29 | def attention_mask_func(attention_scores, attention_mask): 30 | attention_scores.masked_fill_(attention_mask.to(torch.bool), -10000.0) 31 | return attention_scores 32 | 33 | 34 | def get_linear_layer(rows, columns, init_method): 35 | """Simple linear layer with weight initialization.""" 36 | layer = torch.nn.Linear(rows, columns) 37 | if get_args().perform_initialization: 38 | init_method(layer.weight) 39 | with torch.no_grad(): 40 | layer.bias.zero_() 41 | return layer 42 | 43 | @torch.jit.script 44 | def gelu_impl(x): 45 | """OpenAI's gelu implementation.""" 46 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 47 | (1.0 + 0.044715 * x * x))) 48 | def openai_gelu(x): 49 | return gelu_impl(x) 50 | 51 | #This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter 52 | @torch.jit.script 53 | def erf_gelu(x): 54 | return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) 55 | -------------------------------------------------------------------------------- /third_party/megatron/model/vision/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
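# Editorial usage sketch for rotary_pos_embedding.py above (shapes are
# illustrative; q and k are assumed to follow Megatron's
# [seq_len, batch, num_heads, head_dim] layout):
#
#     rope = RotaryEmbedding(dim=head_dim)
#     freqs = rope(max_seq_len=seq_len)     # -> [seq_len, 1, 1, head_dim]
#     q = apply_rotary_pos_emb(q, freqs)    # rotates all head_dim channels here
#     k = apply_rotary_pos_emb(k, freqs)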
2 | 3 | """Vision Transformer(VIT) model.""" 4 | 5 | import torch 6 | from torch.nn.init import trunc_normal_ 7 | from third_party.megatron import get_args 8 | from third_party.megatron.model.utils import get_linear_layer 9 | from third_party.megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead 10 | from third_party.megatron.model.vision.mit_backbone import mit_b3_avg 11 | from third_party.megatron.model.module import MegatronModule 12 | 13 | class VitClassificationModel(MegatronModule): 14 | """Vision Transformer Model.""" 15 | 16 | def __init__(self, num_classes, finetune=False, 17 | pre_process=True, post_process=True): 18 | super(VitClassificationModel, self).__init__() 19 | args = get_args() 20 | 21 | self.hidden_size = args.hidden_size 22 | self.num_classes = num_classes 23 | self.finetune = finetune 24 | self.pre_process = pre_process 25 | self.post_process = post_process 26 | self.backbone = VitBackbone( 27 | pre_process=self.pre_process, 28 | post_process=self.post_process, 29 | single_token_output=True 30 | ) 31 | 32 | if self.post_process: 33 | if not self.finetune: 34 | self.head = VitMlpHead(self.hidden_size, self.num_classes) 35 | else: 36 | self.head = get_linear_layer( 37 | self.hidden_size, 38 | self.num_classes, 39 | torch.nn.init.zeros_ 40 | ) 41 | 42 | def set_input_tensor(self, input_tensor): 43 | """See megatron.model.transformer.set_input_tensor()""" 44 | self.backbone.set_input_tensor(input_tensor) 45 | 46 | def forward(self, input): 47 | hidden_states = self.backbone(input) 48 | 49 | if self.post_process: 50 | hidden_states = self.head(hidden_states) 51 | 52 | return hidden_states 53 | 54 | 55 | class MitClassificationModel(MegatronModule): 56 | """Mix vision Transformer Model.""" 57 | 58 | def __init__(self, num_classes, 59 | pre_process=True, post_process=True): 60 | super(MitClassificationModel, self).__init__() 61 | args = get_args() 62 | 63 | self.hidden_size = args.hidden_size 64 | self.num_classes = num_classes 65 | 66 | self.backbone = mit_b3_avg() 67 | self.head = torch.nn.Linear(512, num_classes) 68 | self.apply(self._init_weights) 69 | 70 | def _init_weights(self, m): 71 | if isinstance(m, torch.nn.Linear): 72 | trunc_normal_(m.weight, std=.02) 73 | if isinstance(m, torch.nn.Linear) and m.bias is not None: 74 | torch.nn.init.constant_(m.bias, 0) 75 | 76 | def set_input_tensor(self, input_tensor): 77 | """See megatron.model.transformer.set_input_tensor()""" 78 | pass 79 | 80 | def forward(self, input): 81 | hidden_states = self.backbone(input) 82 | hidden_states = self.head(hidden_states) 83 | 84 | return hidden_states 85 | -------------------------------------------------------------------------------- /third_party/megatron/model/vision/knn_monitor.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from third_party.megatron import print_rank_0, get_args 4 | from third_party.megatron.core import mpu 5 | from third_party.megatron.data.vit_dataset import ClassificationTransform 6 | from third_party.megatron.data.image_folder import ImageFolder 7 | 8 | _FEATURE_BANK = None 9 | 10 | 11 | def build_data_loader(dataset, drop_last=True, shuffle=False): 12 | """Data loader. Note that batch-size is the local (per GPU) batch-size.""" 13 | # Sampler. 
14 | args = get_args() 15 | micro_batch_size = 16 16 | num_workers = args.num_workers 17 | world_size = mpu.get_data_parallel_world_size() 18 | rank = mpu.get_data_parallel_rank() 19 | sampler = torch.utils.data.distributed.DistributedSampler( 20 | dataset, num_replicas=world_size, rank=rank, 21 | drop_last=drop_last, shuffle=shuffle 22 | ) 23 | 24 | # Data loader. Note that batch size is the per GPU batch size. 25 | data_loader = torch.utils.data.DataLoader( 26 | dataset, 27 | batch_size=micro_batch_size, 28 | sampler=sampler, 29 | shuffle=False, 30 | num_workers=num_workers, 31 | drop_last=not drop_last, 32 | pin_memory=True, 33 | ) 34 | return data_loader 35 | 36 | 37 | def compute_feature_bank(model): 38 | args = get_args() 39 | global _FEATURE_BANK 40 | feature_bank = [] 41 | feature_label = [] 42 | 43 | train_ds = ImageFolder( 44 | root=args.data_path[0], 45 | transform=ClassificationTransform((args.img_h, args.img_w), train=False), 46 | data_per_class_fraction=1.0 47 | ) 48 | classes = len(train_ds.classes) 49 | dataloader = build_data_loader(train_ds) 50 | 51 | for m in model: 52 | m.eval() 53 | 54 | with torch.no_grad(): 55 | for i, batch in enumerate(dataloader): 56 | images = batch[0].cuda().contiguous() 57 | labels = batch[1].cuda().contiguous() 58 | student_feature, teacher_feature = model[0](images) 59 | feature = F.normalize(teacher_feature.float(), dim=1) 60 | feature_bank.append(feature) 61 | feature_label.append(labels) 62 | 63 | for m in model: 64 | m.train() 65 | 66 | # [N', D] 67 | feature_bank = torch.cat(feature_bank, dim=0).contiguous() 68 | feature_label = torch.cat(feature_label, dim=0).contiguous() 69 | 70 | feature_banks = [torch.zeros_like(feature_bank) 71 | for i in range(mpu.get_data_parallel_world_size())] 72 | torch.distributed.all_gather(feature_banks, 73 | feature_bank, 74 | group=mpu.get_data_parallel_group()) 75 | 76 | assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()], 77 | feature_bank)) 78 | 79 | feature_labels = [torch.zeros_like(feature_label) 80 | for i in range(mpu.get_data_parallel_world_size())] 81 | torch.distributed.all_gather(feature_labels, 82 | feature_label, 83 | group=mpu.get_data_parallel_group()) 84 | 85 | # [D, N] 86 | feature_banks = torch.cat(feature_banks, dim=0).t().contiguous() 87 | # [N] 88 | feature_labels = torch.cat(feature_labels, dim=0).contiguous() 89 | print_rank_0("feature_banks size is {}".format(feature_banks.size())) 90 | print_rank_0("feature labels size is {}".format(feature_labels.size())) 91 | 92 | _FEATURE_BANK = (feature_banks, feature_labels, classes) 93 | 94 | 95 | def get_feature_bank(): 96 | global _FEATURE_BANK 97 | assert _FEATURE_BANK is not None 98 | return _FEATURE_BANK 99 | 100 | 101 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 102 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and 103 | # https://github.com/leftthomas/SimCLR 104 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 105 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 106 | sim_matrix = torch.mm(feature, feature_bank) 107 | # [B, K] 108 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 109 | # [B, K] 110 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), 111 | dim=-1, 112 | index=sim_indices) 113 | sim_weight = (sim_weight / knn_t).exp() 114 | 115 | # counts for each class 116 | one_hot_label = torch.zeros(feature.size(0) * knn_k, 117 | classes, 118 | device=sim_labels.device) 119 
| # [B*K, C] 120 | one_hot_label = one_hot_label.scatter(dim=-1, 121 | index=sim_labels.view(-1, 1), 122 | value=1.0) 123 | # weighted score ---> [B, C] 124 | pred_scores = torch.sum( 125 | one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), 126 | dim=1) 127 | 128 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 129 | return pred_labels 130 | -------------------------------------------------------------------------------- /third_party/megatron/model/vision/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def resize(input, 7 | size=None, 8 | scale_factor=None, 9 | mode='nearest', 10 | align_corners=None, 11 | warning=True): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ((output_h > 1 and output_w > 1 and input_h > 1 18 | and input_w > 1) and (output_h - 1) % (input_h - 1) 19 | and (output_w - 1) % (input_w - 1)): 20 | warnings.warn( 21 | f'When align_corners={align_corners}, ' 22 | 'the output would more aligned if ' 23 | f'input size {(input_h, input_w)} is `x+1` and ' 24 | f'out size {(output_h, output_w)} is `nx+1`') 25 | if isinstance(size, torch.Size): 26 | size = tuple(int(x) for x in size) 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | -------------------------------------------------------------------------------- /third_party/megatron/mpu/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexiScale/5ee1726bd761b6ba57d3315a35c2cf84ce6ca57e/third_party/megatron/mpu/tests/__init__.py -------------------------------------------------------------------------------- /third_party/megatron/mpu/tests/commons.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import argparse 4 | import os 5 | import random 6 | import numpy 7 | import torch 8 | 9 | import mpu 10 | 11 | 12 | class IdentityLayer(torch.nn.Module): 13 | def __init__(self, size, scale=1.0): 14 | super(IdentityLayer, self).__init__() 15 | self.weight = torch.nn.Parameter(scale * torch.randn(size)) 16 | 17 | def forward(self): 18 | return self.weight 19 | 20 | 21 | def set_random_seed(seed): 22 | """Set random seed for reproducability.""" 23 | random.seed(seed) 24 | numpy.random.seed(seed) 25 | torch.manual_seed(seed) 26 | mpu.model_parallel_cuda_manual_seed(seed) 27 | 28 | 29 | def initialize_distributed(backend='nccl'): 30 | """Initialize torch.distributed.""" 31 | # Get local rank in case it is provided. 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--local_rank', type=int, default=None, 34 | help='local rank passed from distributed launcher') 35 | args = parser.parse_args() 36 | local_rank = args.local_rank 37 | 38 | # Get rank and world size. 39 | rank = int(os.getenv('RANK', '0')) 40 | world_size = int(os.getenv("WORLD_SIZE", '1')) 41 | 42 | print('> initializing torch.distributed with local rank: {}, ' 43 | 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) 44 | 45 | # Set the device id. 
46 | device = rank % torch.cuda.device_count() 47 | if local_rank is not None: 48 | device = local_rank 49 | torch.cuda.set_device(device) 50 | 51 | # Call the init process. 52 | init_method = 'tcp://' 53 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 54 | master_port = os.getenv('MASTER_PORT', '6000') 55 | init_method += master_ip + ':' + master_port 56 | torch.distributed.init_process_group( 57 | backend=backend, 58 | world_size=world_size, 59 | rank=rank, 60 | init_method=init_method) 61 | 62 | 63 | def print_separator(message): 64 | torch.distributed.barrier() 65 | filler_len = (78 - len(message)) // 2 66 | filler = '-' * filler_len 67 | string = '\n' + filler + ' {} '.format(message) + filler 68 | if torch.distributed.get_rank() == 0: 69 | print(string, flush=True) 70 | torch.distributed.barrier() 71 | -------------------------------------------------------------------------------- /third_party/megatron/mpu/tests/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import set_random_seed 4 | from commons import IdentityLayer 5 | from commons import print_separator 6 | from commons import initialize_distributed 7 | from mpu.cross_entropy import vocab_parallel_cross_entropy 8 | import mpu 9 | import torch.nn.functional as F 10 | import torch 11 | import random 12 | import sys 13 | sys.path.append("../..") 14 | 15 | 16 | def torch_cross_entropy(batch_size, seq_length, vocab_size, 17 | logits_scale, seed): 18 | set_random_seed(seed) 19 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 20 | scale=logits_scale).cuda() 21 | logits = identity() 22 | target = torch.cuda.LongTensor( 23 | size=(batch_size, seq_length)).random_(0, vocab_size) 24 | loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), 25 | target.view(-1), 26 | reduction='none').view_as(target).mean() 27 | loss.backward() 28 | return loss, identity.weight.grad 29 | 30 | 31 | def mpu_cross_entropy(batch_size, seq_length, vocab_size, 32 | logits_scale, seed): 33 | set_random_seed(seed) 34 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 35 | scale=logits_scale).cuda() 36 | logits = identity() 37 | logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) 38 | target = torch.cuda.LongTensor( 39 | size=(batch_size, seq_length)).random_(0, vocab_size) 40 | loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() 41 | loss.backward() 42 | return loss, identity.weight.grad 43 | 44 | 45 | def test_cross_entropy(tensor_model_parallel_size): 46 | 47 | if torch.distributed.get_rank() == 0: 48 | print('> testing cross entropy with model parallel size {} ...'. 
49 | format(tensor_model_parallel_size)) 50 | 51 | mpu.initialize_model_parallel(tensor_model_parallel_size) 52 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 53 | 54 | batch_size = 13 55 | seq_length = 17 56 | vocab_size_per_partition = 11 57 | logits_scale = 1000.0 58 | vocab_size = vocab_size_per_partition * tensor_model_parallel_size 59 | seed = 1234 60 | 61 | loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, 62 | vocab_size, logits_scale, 63 | seed) 64 | loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, 65 | vocab_size, logits_scale, 66 | seed) 67 | 68 | error = loss_torch.sub_(loss_mpu).abs().max() 69 | print(' max error in loss on global rank {}: {}'.format( 70 | torch.distributed.get_rank(), error)) 71 | assert error < 1.0e-6 72 | 73 | error = grad_torch.sub_(grad_mpu).abs().max() 74 | print(' max error in grad on global rank {}: {}'.format( 75 | torch.distributed.get_rank(), error)) 76 | assert error < 1.0e-6 77 | 78 | # Reset groups 79 | mpu.destroy_tensor_model_parallel() 80 | 81 | torch.distributed.barrier() 82 | if torch.distributed.get_rank() == 0: 83 | print('>> passed the test :-)') 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | initialize_distributed() 89 | world_size = torch.distributed.get_world_size() 90 | 91 | tensor_model_parallel_size = 1 92 | while tensor_model_parallel_size <= world_size: 93 | print_separator('test cross entropy') 94 | test_cross_entropy(tensor_model_parallel_size) 95 | tensor_model_parallel_size *= 2 96 | -------------------------------------------------------------------------------- /third_party/megatron/mpu/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | from mpu import data as data_utils 6 | import mpu 7 | import torch 8 | import functools 9 | import operator 10 | import sys 11 | sys.path.append("../..") 12 | 13 | 14 | def test_broadcast_data(tensor_model_parallel_size): 15 | 16 | if torch.distributed.get_rank() == 0: 17 | print('> testing broadcast_data with model parallel size {} ...'. 
18 | format(tensor_model_parallel_size)) 19 | 20 | mpu.initialize_model_parallel(tensor_model_parallel_size) 21 | torch.manual_seed(1234 + mpu.get_data_parallel_rank()) 22 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 23 | 24 | key_size_t = {'key1': [7, 11], 25 | 'key2': [8, 2, 1], 26 | 'key3': [13], 27 | 'key4': [5, 1, 2], 28 | 'key5': [5, 12]} 29 | keys = list(key_size_t.keys()) 30 | 31 | data = {} 32 | data_t = {} 33 | for key in key_size_t: 34 | data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) 35 | data_t[key] = data[key].clone() 36 | data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) 37 | data_t['keyX'] = data['keyX'].clone() 38 | if mpu.get_tensor_model_parallel_rank() != 0: 39 | data = None 40 | 41 | data_utils._check_data_types(keys, data_t, torch.int64) 42 | key_size, key_numel, \ 43 | total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) 44 | for key in keys: 45 | assert key_size[key] == key_size_t[key] 46 | total_numel_t = 0 47 | for key in keys: 48 | target_size = functools.reduce(operator.mul, key_size_t[key], 1) 49 | assert key_numel[key] == target_size 50 | total_numel_t += target_size 51 | assert total_numel == total_numel_t 52 | 53 | data_b = data_utils.broadcast_data(keys, data, torch.int64) 54 | for key in keys: 55 | tensor = data_t[key].cuda() 56 | assert data_b[key].sub(tensor).abs().max() == 0 57 | 58 | # Reset groups 59 | mpu.destroy_tensor_model_parallel() 60 | 61 | torch.distributed.barrier() 62 | if torch.distributed.get_rank() == 0: 63 | print('>> passed the test :-)') 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | initialize_distributed() 69 | world_size = torch.distributed.get_world_size() 70 | 71 | tensor_model_parallel_size = 1 72 | while tensor_model_parallel_size <= world_size: 73 | print_separator('test test broadcast data') 74 | test_broadcast_data(tensor_model_parallel_size) 75 | tensor_model_parallel_size *= 2 76 | -------------------------------------------------------------------------------- /third_party/megatron/mpu/tests/test_initialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | import mpu 6 | import torch 7 | import sys 8 | sys.path.append("../..") 9 | 10 | 11 | def test_initialize_model_parallel(tensor_model_parallel_size): 12 | 13 | if torch.distributed.get_rank() == 0: 14 | print('> testing initialize_model_parallel with size {} ...'.format( 15 | tensor_model_parallel_size)) 16 | tensor_model_parallel_size_ = min(tensor_model_parallel_size, 17 | torch.distributed.get_world_size()) 18 | assert not mpu.model_parallel_is_initialized() 19 | mpu.initialize_model_parallel(tensor_model_parallel_size_) 20 | assert mpu.model_parallel_is_initialized() 21 | 22 | # Checks. 23 | def check(group, world_size, rank): 24 | assert world_size == torch.distributed.get_world_size(group=group) 25 | assert rank == torch.distributed.get_rank(group=group) 26 | 27 | # Model parallel. 28 | world_size = tensor_model_parallel_size_ 29 | rank = torch.distributed.get_rank() % tensor_model_parallel_size_ 30 | assert world_size == mpu.get_tensor_model_parallel_world_size() 31 | assert rank == mpu.get_tensor_model_parallel_rank() 32 | check(mpu.get_tensor_model_parallel_group(), world_size, rank) 33 | 34 | # Data parallel. 
35 | world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ 36 | rank = torch.distributed.get_rank() // tensor_model_parallel_size 37 | assert world_size == mpu.get_data_parallel_world_size() 38 | assert rank == mpu.get_data_parallel_rank() 39 | check(mpu.get_data_parallel_group(), world_size, rank) 40 | 41 | # Reset groups 42 | mpu.destroy_model_parallel() 43 | 44 | torch.distributed.barrier() 45 | if torch.distributed.get_rank() == 0: 46 | print('>> passed the test :-)') 47 | 48 | 49 | def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): 50 | 51 | if torch.distributed.get_rank() == 0: 52 | print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( 53 | tensor_model_parallel_size_)) 54 | tensor_model_parallel_size = min(tensor_model_parallel_size_, 55 | torch.distributed.get_world_size()) 56 | assert not mpu.model_parallel_is_initialized() 57 | mpu.initialize_model_parallel(tensor_model_parallel_size) 58 | assert mpu.model_parallel_is_initialized() 59 | 60 | # Checks 61 | src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() 62 | assert mpu.get_tensor_model_parallel_src_rank() == src_rank 63 | 64 | # Reset groups 65 | mpu.destroy_model_parallel() 66 | 67 | torch.distributed.barrier() 68 | if torch.distributed.get_rank() == 0: 69 | print('>> passed the test :-)') 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | initialize_distributed() 75 | world_size = torch.distributed.get_world_size() 76 | tensor_model_parallel_size = 1 77 | while tensor_model_parallel_size <= world_size: 78 | print_separator('test initialize model parallel') 79 | test_initialize_model_parallel(tensor_model_parallel_size) 80 | print_separator('test model parallel source rank') 81 | test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) 82 | tensor_model_parallel_size *= 2 83 | -------------------------------------------------------------------------------- /third_party/megatron/optimizer/grad_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Megatron grad scaler.""" 4 | 5 | from abc import ABC 6 | from abc import abstractmethod 7 | 8 | import torch 9 | 10 | 11 | class MegatronGradScaler(ABC): 12 | 13 | def __init__(self, initial_scale): 14 | """Initialize scale value with the input initial scale.""" 15 | assert initial_scale > 0.0 16 | self._scale = torch.cuda.FloatTensor([initial_scale]) 17 | 18 | @property 19 | def scale(self): 20 | return self._scale 21 | 22 | @property 23 | def inv_scale(self): 24 | return self._scale.double().reciprocal().float() 25 | 26 | @abstractmethod 27 | def update(self, found_inf): 28 | pass 29 | 30 | @abstractmethod 31 | def state_dict(self): 32 | pass 33 | 34 | @abstractmethod 35 | def load_state_dict(self, state_dict): 36 | pass 37 | 38 | 39 | 40 | class ConstantGradScaler(MegatronGradScaler): 41 | 42 | def update(self, found_inf): 43 | pass 44 | 45 | def state_dict(self): 46 | return dict() 47 | 48 | def load_state_dict(self, state_dict): 49 | pass 50 | 51 | 52 | 53 | class DynamicGradScaler(MegatronGradScaler): 54 | 55 | def __init__(self, initial_scale, min_scale, 56 | growth_factor, backoff_factor, 57 | growth_interval, hysteresis): 58 | """"Grad scaler with dynamic scale that gets adjusted 59 | during training.""" 60 | super(DynamicGradScaler, self).__init__(initial_scale) 61 | 62 | # Lower bound on the scale. 
63 |         assert min_scale > 0.0
64 |         assert min_scale <= initial_scale
65 |         self.min_scale = torch.cuda.FloatTensor([min_scale])
66 |         # Growth and backoff factors for the scale.
67 |         assert growth_factor > 1.0
68 |         self.growth_factor = torch.cuda.FloatTensor([growth_factor])
69 |         assert backoff_factor < 1.0
70 |         assert backoff_factor > 0.0
71 |         self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
72 |         # Interval over which, if we don't see any inf/nan,
73 |         # we will scale the grad scale by the growth factor.
74 |         assert growth_interval > 0
75 |         self.growth_interval = growth_interval
76 |         # Number of inf/nans we should see before scaling down
77 |         # the grad scale by the backoff factor.
78 |         assert hysteresis > 0
79 |         self.hysteresis = hysteresis
80 | 
81 |         # Trackers.
82 |         self._growth_tracker = 0
83 |         self._hysteresis_tracker = self.hysteresis
84 | 
85 | 
86 |     def update(self, found_inf):
87 | 
88 |         # If we have an inf/nan, the growth tracker is set to 0
89 |         # and the hysteresis tracker is reduced by 1.
90 |         if found_inf:
91 |             self._growth_tracker = 0
92 |             self._hysteresis_tracker -= 1
93 |             # Now if we are out of hysteresis count, scale down the loss.
94 |             if self._hysteresis_tracker <= 0:
95 |                 self._scale = torch.max(self._scale * self.backoff_factor,
96 |                                         self.min_scale)
97 |         else:
98 |             # If there is no nan/inf, increment the growth tracker.
99 |             self._growth_tracker += 1
100 |             # If we have had enough consecutive intervals with no nan/inf:
101 |             if self._growth_tracker == self.growth_interval:
102 |                 # Reset the growth and hysteresis trackers,
103 |                 self._growth_tracker = 0
104 |                 self._hysteresis_tracker = self.hysteresis
105 |                 # and scale up the loss scale.
106 |                 self._scale = self._scale * self.growth_factor
107 | 
108 | 
109 |     def state_dict(self):
110 |         state_dict = {}
111 |         state_dict['scale'] = self._scale
112 |         state_dict['growth_tracker'] = self._growth_tracker
113 |         state_dict['hysteresis_tracker'] = self._hysteresis_tracker
114 |         return state_dict
115 | 
116 | 
117 |     def load_state_dict(self, state_dict):
118 |         self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
119 |         self._growth_tracker = state_dict['growth_tracker']
120 |         self._hysteresis_tracker = state_dict['hysteresis_tracker']
121 | 
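The growth/backoff policy of DynamicGradScaler.update() is easiest to see end to end with a small driver. The following is a minimal sketch, not code from this repository: the class name ToyDynamicScaler, its default values, and the overflow pattern fed to it are invented for illustration, and it uses plain Python floats instead of the CUDA tensors above so it runs on any machine.

# CPU-only stand-in for the DynamicGradScaler.update() policy shown above.
class ToyDynamicScaler:
    def __init__(self, initial_scale=2.0 ** 16, min_scale=1.0,
                 growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=3, hysteresis=2):
        self.scale = initial_scale
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self._growth_tracker = 0
        self._hysteresis_tracker = hysteresis

    def update(self, found_inf):
        if found_inf:
            # Overflow: reset growth progress and spend one unit of hysteresis.
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            if self._hysteresis_tracker <= 0:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
        else:
            # Clean step: after growth_interval clean steps in a row, grow the scale.
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                self.scale *= self.growth_factor


if __name__ == "__main__":
    scaler = ToyDynamicScaler()
    # Two overflows, then four clean steps (the pattern is made up for the demo).
    for step, found_inf in enumerate([True, True, False, False, False, False]):
        scaler.update(found_inf)
        print(f"step {step}: found_inf={found_inf!s:5} scale={scaler.scale:.0f}")

With these toy settings, the second consecutive overflow exhausts the hysteresis budget and halves the scale, and three clean steps (one growth interval) double it back.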
-------------------------------------------------------------------------------- /third_party/megatron/static/index.html: --------------------------------------------------------------------------------

[Only the visible text of this demo page survives in this dump: the page title "Megatron", a "Prompt Megatron" input form, and a "0 / 1000" character counter. The surrounding HTML markup and any inline style/script blocks were stripped during extraction.]
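This page is just a thin front end over the text generation server (see text_generation_server.py and the text_generation package below); the same round trip can be exercised from a script. The sketch below is a hedged example, not code from this repository: the host, port, and the PUT /api route with "prompts" and "tokens_to_generate" fields follow upstream Megatron-LM's server, so confirm them against the server code here before relying on them.

import json

import requests  # third-party HTTP client, assumed to be installed

# Assumed endpoint: upstream Megatron's text generation server listens on
# port 5000 and serves generation requests under PUT /api.
URL = "http://localhost:5000/api"
payload = {"prompts": ["Megatron is"], "tokens_to_generate": 32}

resp = requests.put(URL,
                    data=json.dumps(payload),
                    headers={"Content-Type": "application/json"})
print(resp.status_code)
print(resp.json())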
86 | 87 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /third_party/megatron/text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | from .api import ( 5 | generate, 6 | generate_and_post_process, 7 | beam_search_and_post_process) 8 | -------------------------------------------------------------------------------- /third_party/megatron/text_generation/beam_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | ## from huggingface beam search 19 | class BeamHypotheses(object): 20 | def __init__(self, num_beams, length_penalty=1.0, early_stopping=False): 21 | """ 22 | Initialize n-best list of hypotheses. 23 | """ 24 | self.length_penalty = length_penalty 25 | self.early_stopping = early_stopping 26 | self.num_beams = num_beams 27 | self.beams = [] 28 | self.worst_score = 1e9 29 | 30 | def __len__(self): 31 | """ 32 | Number of hypotheses in the list. 33 | """ 34 | return len(self.beams) 35 | 36 | def add(self, hyp, sum_logprobs, length): 37 | """ 38 | Add a new hypothesis to the list. 39 | """ 40 | score = sum_logprobs / length ** self.length_penalty 41 | if len(self) < self.num_beams or score > self.worst_score: 42 | self.beams.append((score, hyp)) 43 | if len(self) > self.num_beams: 44 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 45 | del self.beams[sorted_scores[0][1]] 46 | self.worst_score = sorted_scores[1][0] 47 | else: 48 | self.worst_score = min(score, self.worst_score) 49 | 50 | def is_done(self, best_sum_logprobs, cur_len): 51 | """ 52 | If there are enough hypotheses and that none of the hypotheses being generated 53 | can become better than the worst one in the heap, then we are done with this sentence. 54 | """ 55 | 56 | if len(self) < self.num_beams: 57 | return False 58 | elif self.early_stopping: 59 | return True 60 | else: 61 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 62 | ret = self.worst_score >= cur_score 63 | return ret 64 | 65 | -------------------------------------------------------------------------------- /third_party/megatron/text_generation/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Sampling utilities. 
4 | Part of this code is inspired by: 5 | - https://github.com/ari-holtzman/degen/blob/master/gen.py 6 | - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html 7 | """ 8 | 9 | 10 | import torch 11 | 12 | 13 | 14 | def modify_logits_for_top_k_filtering(logits, top_k): 15 | """Set the logits for none top-k values to -inf.""" 16 | 17 | filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] 18 | logits.masked_fill_(filter_, float('-Inf')) 19 | 20 | 21 | 22 | def modify_logits_for_top_p_filtering(logits, top_p): 23 | """Set the logits for none top-p values to -inf.""" 24 | 25 | # First sort and calculate cumulative sum of probabilities. 26 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 27 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 28 | 29 | # Filteration based on the cumulative sum. 30 | filter_ = cumulative_probs > top_p 31 | # This shift by 1 is weird and I cannot justify it. This existed 32 | # in the original implementation: 33 | # https://github.com/ari-holtzman/degen/blob/master/gen.py 34 | # and I guess it is needed so keeping it for now. 35 | filter_[:, 1:] = filter_[:, :-1].clone() 36 | # Make sure we at least have one token to select from. 37 | filter_[..., 0] = 0 38 | 39 | # Fill in the filtered part 40 | filter_ = filter_.scatter(1, sorted_indices, filter_) 41 | logits.masked_fill_(filter_, float('-Inf')) 42 | 43 | 44 | 45 | def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): 46 | """ Sample and generate a token. 47 | Note: logits has the dimension [b, v] where b is the batch size 48 | and v is the vocabulary size. 49 | If vocab_size is provided, we will make sure the sample that is 50 | generated is in [0, vocab-size). This will avoid out of vocabulary 51 | generations due to padding. 52 | """ 53 | 54 | # Check logits for consistency. 55 | assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' 56 | assert logits.type() == 'torch.cuda.FloatTensor', \ 57 | 'input logits should be floats.' 58 | 59 | 60 | # Greedy is just simple argmax. 61 | if top_k == 1: 62 | assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' 63 | samples = torch.argmax(logits, dim=-1) 64 | 65 | # Top-k or top-p sampling. 66 | else: 67 | # Clone so we do not modify the inputs, 68 | logits = logits.clone() 69 | # Apply temperature in place. 70 | if temperature != 1.0: 71 | logits.div_(temperature) 72 | 73 | if top_k > 1: 74 | assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' 75 | assert top_k <= logits.size(1), 'top-k is larger than logit size.' 76 | if vocab_size: 77 | assert top_k < vocab_size, 'top-k is larger than vocab size.' 78 | modify_logits_for_top_k_filtering(logits, top_k) 79 | 80 | elif top_p > 0.0: 81 | assert top_p <= 1.0, 'top-p should be in (0, 1].' 82 | modify_logits_for_top_p_filtering(logits, top_p) 83 | 84 | # After filtering, we need to recalculate the distribution. 85 | probs = logits.softmax(dim=-1) 86 | samples = torch.multinomial(probs, num_samples=1).view(-1) 87 | 88 | # If vocab size is provided, make sure the samples are in 89 | # in the range [0, vocab-size). 90 | if vocab_size: 91 | samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) 92 | 93 | return samples 94 | -------------------------------------------------------------------------------- /third_party/megatron/text_generation/tokenization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. 2 | 3 | """Tokenization utilities.""" 4 | 5 | 6 | import torch 7 | 8 | 9 | from third_party.megatron import get_tokenizer, get_args 10 | from .communication import broadcast_int_list, broadcast_tensor 11 | 12 | 13 | def detokenize_generations(tokens_gpu_tensor, 14 | lengths_gpu_tensor, 15 | return_segments): 16 | """Detokenize the generated tokens.""" 17 | 18 | tokenizer = get_tokenizer() 19 | args = get_args() 20 | prompts_plus_generations = [] 21 | if return_segments: 22 | prompts_plus_generations_segments = [] 23 | 24 | tokens = tokens_gpu_tensor.cpu().numpy().tolist() 25 | lengths = lengths_gpu_tensor.cpu().numpy().tolist() 26 | for sequence_tokens, length in zip(tokens, lengths): 27 | sequence_tokens = sequence_tokens[:length] 28 | prompts_plus_generations.append( 29 | tokenizer.detokenize(sequence_tokens)) 30 | if return_segments: 31 | words = [] 32 | for token in sequence_tokens: 33 | if args.tokenizer_type in ['SentencePieceTokenizer', 34 | 'GPTSentencePieceTokenizer']: 35 | word = tokenizer.decoder[token] 36 | elif args.tokenizer_type == 'NullTokenizer': 37 | word = str(token) 38 | else: 39 | word = tokenizer.tokenizer.decoder[token] 40 | word = bytearray( 41 | [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode( 42 | 'utf-8', errors='replace') 43 | words.append(word) 44 | prompts_plus_generations_segments.append(words) 45 | 46 | if return_segments: 47 | return tokens, prompts_plus_generations, \ 48 | prompts_plus_generations_segments 49 | 50 | return tokens, prompts_plus_generations 51 | 52 | 53 | def tokenize_prompts(prompts=None, tokens_to_generate=None, 54 | add_BOS=None, rank=0): 55 | """Tokenize prompts and make them avaiable on all ranks.""" 56 | 57 | # On all ranks set to None so we can pass them to functions 58 | sizes_list = None 59 | prompts_tokens_cuda_long_tensor = None 60 | prompts_length_cuda_long_tensor = None 61 | 62 | # On the specified rank, build the above. 63 | if torch.distributed.get_rank() == rank: 64 | assert prompts is not None 65 | assert tokens_to_generate is not None 66 | # Tensor of tokens padded and their unpadded length. 67 | prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \ 68 | _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS) 69 | # We need the sizes of these tensors for the boradcast 70 | sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size 71 | prompts_tokens_cuda_long_tensor.size(1)] # Sequence lenght 72 | 73 | # First, broadcast the sizes. 74 | sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank) 75 | 76 | # Now that we have the sizes, we can boradcast the tokens 77 | # and length tensors. 78 | sizes = sizes_tensor.tolist() 79 | prompts_tokens_cuda_long_tensor = broadcast_tensor( 80 | sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank) 81 | prompts_length_cuda_long_tensor = broadcast_tensor( 82 | sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor, 83 | rank=rank) 84 | 85 | return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor 86 | 87 | 88 | def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS): 89 | """Given a set of prompts and number of tokens to generate: 90 | - tokenize prompts 91 | - set the sequence length to be the max of length of prompts 92 | plus the number of tokens we would like to generate 93 | - pad all the sequences to this length so we can convert them 94 | into a 2D tensor. 95 | """ 96 | 97 | # Tokenize all the prompts. 
98 |     tokenizer = get_tokenizer()
99 |     if add_BOS:
100 |         prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
101 |                           for prompt in prompts]
102 |     else:
103 |         prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
104 | 
105 |     # Now we have a list of lists of tokens, where each list has a different
106 |     # size. We want to extend each list to:
107 |     # - incorporate the tokens that need to be generated
108 |     # - make all the sequences equal length.
109 |     # Get the prompts length.
110 |     prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
111 |     # Get the max prompts length.
112 |     max_prompt_len = max(prompts_length)
113 |     # Number of tokens in each sample of the batch.
114 |     samples_length = max_prompt_len + tokens_to_generate
115 |     # Now pad every list to the same size: samples_length.
116 |     for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
117 |         padding_size = samples_length - prompt_length
118 |         prompt_tokens.extend([tokenizer.eod] * padding_size)
119 | 
120 |     # Now that we are in a structured format, we can convert to tensors.
121 |     prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
122 |     prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
123 | 
124 |     return prompts_tokens_tensor, prompts_length_tensor
125 | 
-------------------------------------------------------------------------------- /third_party/megatron/tokenizer/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | 
3 | 
4 | from .tokenizer import build_tokenizer
5 | 
--------------------------------------------------------------------------------
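As a sanity check on the padding scheme in _tokenize_prompts_and_batch above, here is a minimal, self-contained sketch. The toy whitespace tokenizer, its vocabulary, and the example prompts are invented for illustration; the real code obtains its tokenizer via get_tokenizer() and builds CUDA tensors, while this version stays on the CPU so it runs anywhere.

import torch

# Hypothetical stand-in for a Megatron tokenizer: whitespace split, eod id 0.
VOCAB = {"<eod>": 0, "hello": 1, "world": 2, "megatron": 3, "scales": 4, "models": 5}
EOD = VOCAB["<eod>"]

def tokenize(text):
    return [VOCAB[word] for word in text.split()]

prompts = ["hello world", "megatron scales models"]
tokens_to_generate = 4

prompts_tokens = [tokenize(p) for p in prompts]
prompts_length = [len(t) for t in prompts_tokens]

# Pad every prompt to (max prompt length + tokens_to_generate) with the eod
# token, as _tokenize_prompts_and_batch does, so the batch becomes a
# rectangular 2D tensor.
samples_length = max(prompts_length) + tokens_to_generate
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
    prompt_tokens.extend([EOD] * (samples_length - prompt_length))

prompts_tokens_tensor = torch.tensor(prompts_tokens, dtype=torch.long)
prompts_length_tensor = torch.tensor(prompts_length, dtype=torch.long)
print(prompts_tokens_tensor.shape)   # torch.Size([2, 7])
print(prompts_length_tensor)         # tensor([2, 3])

The unpadded lengths travel alongside the padded batch so that detokenize_generations can later cut each sequence back to its true length.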