├── run_on_every_node.py ├── huggingface_utils.py ├── deepspeed_inference_actors.py ├── deepspeed_utils.py ├── deepspeed_predictor.py └── README.md /run_on_every_node.py: -------------------------------------------------------------------------------- 1 | # Download the model from our S3 mirror as it's faster 2 | 3 | import argparse 4 | import subprocess 5 | 6 | import ray 7 | import ray.util.scheduling_strategies 8 | 9 | 10 | def force_on_node(node_id: str, remote_func_or_actor_class): 11 | scheduling_strategy = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( 12 | node_id=node_id, soft=False 13 | ) 14 | options = {"scheduling_strategy": scheduling_strategy} 15 | return remote_func_or_actor_class.options(**options) 16 | 17 | 18 | def run_on_every_node(remote_func_or_actor_class, *remote_args, **remote_kwargs): 19 | refs = [] 20 | for node in ray.nodes(): 21 | if node["Alive"] and node["Resources"].get("GPU", None): 22 | refs.append( 23 | force_on_node(node["NodeID"], remote_func_or_actor_class).remote( 24 | *remote_args, **remote_kwargs 25 | ) 26 | ) 27 | return ray.get(refs) 28 | 29 | 30 | @ray.remote(num_gpus=1) 31 | def download_model(bucket_uri, path_to_save_in): 32 | subprocess.run( 33 | [ 34 | "aws", 35 | "s3", 36 | "sync", 37 | "--quiet", 38 | bucket_uri, 39 | path_to_save_in, 40 | ] 41 | ) 42 | 43 | 44 | @ray.remote(num_gpus=1) 45 | def mount_nvme(): 46 | subprocess.run( 47 | 'drive_name="${1:-/dev/nvme1n1}"; mount_path="${2:-/nvme}"; set -x; sudo file -s "$drive_name"; sudo apt install xfsprogs -y; sudo mkfs -t xfs "$drive_name"; sudo mkdir "$mount_path" && sudo mount "$drive_name" "$mount_path" && sudo chown -R ray "$mount_path"', 48 | shell=True, 49 | check=True, 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | 56 | parser.add_argument("function", type=str, help="function in this file to run") 57 | parser.add_argument("args", nargs="*", type=str, help="string args to function") 58 | args = parser.parse_args() 59 | 60 | ray.init() 61 | if args.function not in globals(): 62 | raise ValueError(f"{args.function} doesn't exist") 63 | fn = globals()[args.function] 64 | assert callable(fn) or hasattr(fn, "_function") 65 | print(f"Running {args.function}({', '.join(args.args)})") 66 | if hasattr(fn, "_function"): 67 | run_on_every_node(fn, *args.args) 68 | else: 69 | fn(*args.args) 70 | -------------------------------------------------------------------------------- /huggingface_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from collections import defaultdict 4 | from unittest.mock import patch 5 | 6 | import deepspeed 7 | import torch 8 | from filelock import FileLock 9 | from transformers import AutoModelForCausalLM 10 | from transformers.modeling_utils import dtype_byte_size 11 | from transformers.utils.hub import convert_file_size_to_int 12 | 13 | 14 | def shard_checkpoint_contiguous( 15 | state_dict, max_shard_size="10GB", weights_name: str = "pytorch_model.bin" 16 | ): 17 | """ 18 | Same as transformers.modeling_utils.shard_checkpoint, but shards each layer 19 | into its own file to mitigate https://github.com/microsoft/DeepSpeed/issues/3084. 
20 | """ 21 | max_shard_size = convert_file_size_to_int(max_shard_size) 22 | 23 | sharded_state_dicts = [] 24 | current_block = {} 25 | current_block_size = 0 26 | total_size = 0 27 | 28 | layers = defaultdict(list) 29 | saved_keys = set() 30 | for key in state_dict: 31 | if key.startswith("model.decoder.layers."): 32 | layer_key = ".".join(key.split(".")[:4]) 33 | layers[layer_key].append(key) 34 | 35 | for keys in layers.values(): 36 | for key in keys: 37 | weight = state_dict[key] 38 | weight_size = weight.numel() * dtype_byte_size(weight.dtype) 39 | 40 | current_block[key] = weight 41 | current_block_size += weight_size 42 | total_size += weight_size 43 | saved_keys.add(key) 44 | sharded_state_dicts.append(current_block) 45 | current_block = {} 46 | current_block_size = 0 47 | 48 | for key, weight in state_dict.items(): 49 | if key in saved_keys: 50 | continue 51 | weight_size = weight.numel() * dtype_byte_size(weight.dtype) 52 | 53 | # If this weight is going to tip up over the maximal size, we split. 54 | if current_block_size + weight_size > max_shard_size: 55 | sharded_state_dicts.append(current_block) 56 | current_block = {} 57 | current_block_size = 0 58 | 59 | current_block[key] = weight 60 | current_block_size += weight_size 61 | total_size += weight_size 62 | 63 | # Add the last block 64 | sharded_state_dicts.append(current_block) 65 | 66 | # If we only have one shard, we return it 67 | if len(sharded_state_dicts) == 1: 68 | return {weights_name: sharded_state_dicts[0]}, None 69 | 70 | # Otherwise, let's build the index 71 | weight_map = {} 72 | shards = {} 73 | for idx, shard in enumerate(sharded_state_dicts): 74 | shard_file = weights_name.replace( 75 | ".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin" 76 | ) 77 | shard_file = shard_file.replace( 78 | ".safetensors", 79 | f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors", 80 | ) 81 | shards[shard_file] = shard 82 | for key in shard.keys(): 83 | weight_map[key] = shard_file 84 | 85 | # Add the metadata 86 | metadata = {"total_size": total_size} 87 | index = {"metadata": metadata, "weight_map": weight_map} 88 | return shards, index 89 | 90 | 91 | def reshard_checkpoint(model_name_or_path, dtype, path_to_save_in): 92 | """ 93 | Loads a transformers model into CPU memory, reshards and saves it to mitigate 94 | https://github.com/microsoft/DeepSpeed/issues/3084. 95 | """ 96 | with FileLock(f"{path_to_save_in}.lock"): 97 | # We use a done marker file so that the other ranks do not 98 | # go through the process again. 
99 | done_marker = os.path.join(path_to_save_in, ".done") 100 | if not os.path.exists(done_marker): 101 | dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype 102 | with deepspeed.OnDevice(dtype=dtype, device="cpu"): 103 | model = AutoModelForCausalLM.from_pretrained( 104 | model_name_or_path, 105 | torch_dtype=dtype, 106 | low_cpu_mem_usage=True, 107 | ) 108 | with patch( 109 | "transformers.modeling_utils.shard_checkpoint", 110 | shard_checkpoint_contiguous, 111 | ): 112 | model.save_pretrained(path_to_save_in) 113 | with open(done_marker, "w"): 114 | pass 115 | del model 116 | gc.collect() 117 | return path_to_save_in 118 | -------------------------------------------------------------------------------- /deepspeed_inference_actors.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | import pandas as pd 6 | import ray 7 | import ray.util 8 | from ray.air import Checkpoint, ScalingConfig 9 | from ray.train.batch_predictor import BatchPredictor 10 | 11 | from deepspeed_predictor import DeepSpeedPredictor 12 | 13 | 14 | def get_parser() -> ArgumentParser: 15 | parser = ArgumentParser() 16 | 17 | parser.add_argument("--name", required=True, type=str, help="model_name") 18 | parser.add_argument( 19 | "--num_worker_groups", 20 | required=True, 21 | type=int, 22 | help="Number of prediction worker groups", 23 | ) 24 | parser.add_argument( 25 | "--num_gpus_per_worker_group", 26 | required=True, 27 | type=int, 28 | help="Number of GPUs per prediction worker group", 29 | ) 30 | parser.add_argument( 31 | "--hf_home", 32 | required=False, 33 | default=None, 34 | type=str, 35 | help="path to use as Hugging Face cache. If none, will be left as default.", 36 | ) 37 | parser.add_argument( 38 | "--checkpoint_path", 39 | required=False, 40 | default=None, 41 | type=str, 42 | help="model checkpoint path", 43 | ) 44 | parser.add_argument( 45 | "--save_mp_checkpoint_path", 46 | required=False, 47 | default=None, 48 | type=str, 49 | help="save-path to store the new model checkpoint", 50 | ) 51 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 52 | parser.add_argument( 53 | "--dtype", 54 | default="float16", 55 | type=str, 56 | choices=["float32", "float16", "int8"], 57 | help="data-type", 58 | ) 59 | parser.add_argument( 60 | "--ds_inference", action="store_true", help="enable ds-inference" 61 | ) 62 | parser.add_argument( 63 | "--use_kernel", action="store_true", help="enable kernel-injection" 64 | ) 65 | parser.add_argument( 66 | "--replace_method", 67 | required=False, 68 | default="auto", 69 | type=str, 70 | help="replace method['', 'auto']", 71 | ) 72 | parser.add_argument( 73 | "--max_tokens", 74 | default=1024, 75 | type=int, 76 | help="maximum tokens used for the text-generation KV-cache", 77 | ) 78 | parser.add_argument( 79 | "--max_new_tokens", default=50, type=int, help="maximum new tokens to generate" 80 | ) 81 | parser.add_argument( 82 | "--use_meta_tensor", 83 | action="store_true", 84 | help="use the meta tensors to initialize model", 85 | ) 86 | parser.add_argument( 87 | "--use_cache", default=True, type=bool, help="use cache for generation" 88 | ) 89 | parser.add_argument( 90 | "--reshard_checkpoint_path", 91 | required=False, 92 | default=None, 93 | type=str, 94 | help="Path to store a resharded HF checkpoint to mitigate microsoft/DeepSpeed/issues/3084. 
If not provided, will not reshard", 95 | ) 96 | return parser 97 | 98 | 99 | parser = get_parser() 100 | args = parser.parse_args() 101 | 102 | # %% 103 | runtime_env = {"working_dir": os.path.dirname(__file__)} 104 | 105 | if args.hf_home: 106 | os.environ["HF_HOME"] = args.hf_home 107 | runtime_env["env_vars"] = {"HF_HOME": os.environ["HF_HOME"]} 108 | 109 | ray.init(runtime_env=runtime_env) 110 | 111 | 112 | # %% 113 | import pandas as pd 114 | 115 | PREDICT_COLUMN = "predict" 116 | 117 | df = pd.DataFrame( 118 | ["DeepSpeed is", "Test", "Fill me", "How are you"] * 16, columns=[PREDICT_COLUMN] 119 | ) 120 | ds = ( 121 | ray.data.from_pandas(df) 122 | .repartition(args.num_gpus_per_worker_group * 2) 123 | .random_shuffle() 124 | .fully_executed() 125 | ) 126 | 127 | # %% 128 | # This is a scaling config for one worker group. 129 | group_scaling_config = ScalingConfig( 130 | use_gpu=True, 131 | num_workers=args.num_gpus_per_worker_group, 132 | trainer_resources={"CPU": 0}, 133 | ) 134 | batch_predictor = BatchPredictor.from_checkpoint( 135 | Checkpoint.from_dict({"args": args}), 136 | DeepSpeedPredictor, 137 | scaling_config=group_scaling_config, 138 | ) 139 | 140 | # %% 141 | pred = batch_predictor.predict( 142 | ds, 143 | batch_size=1, 144 | num_cpus_per_worker=0, 145 | min_scoring_workers=args.num_worker_groups, 146 | max_scoring_workers=args.num_worker_groups, 147 | # Kwargs passed to mode.generate 148 | do_sample=True, 149 | temperature=0.9, 150 | max_length=100, 151 | ) 152 | 153 | # %% 154 | print(pred.to_pandas()) 155 | -------------------------------------------------------------------------------- /deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation 2 | 3 | import argparse 4 | import gc 5 | import io 6 | import json 7 | import math 8 | import os 9 | from pathlib import Path 10 | from typing import List 11 | 12 | import deepspeed 13 | import torch 14 | from deepspeed.runtime.utils import see_memory_usage 15 | from huggingface_hub import snapshot_download 16 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 17 | 18 | 19 | class DSPipeline: 20 | """ 21 | Example helper class for comprehending DeepSpeed Meta Tensors, meant to mimic HF pipelines. 22 | The DSPipeline can run with and without meta tensors. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name="bigscience/bloom-3b", 28 | dtype=torch.float16, 29 | is_meta=True, 30 | device=-1, 31 | checkpoint_path=None, 32 | ): 33 | self.model_name = model_name 34 | self.dtype = dtype 35 | 36 | if isinstance(device, torch.device): 37 | self.device = device 38 | elif isinstance(device, str): 39 | self.device = torch.device(device) 40 | elif device < 0: 41 | self.device = torch.device("cpu") 42 | else: 43 | self.device = torch.device(f"cuda:{device}") 44 | 45 | # the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time. 
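        # Note (added): these two Hub repos already ship tensor-parallel shards together
        # with a ds_inference_config.json, so they can be consumed as-is (see the
        # handling in _generate_json below).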
46 | self.tp_presharded_models = [ 47 | "microsoft/bloom-deepspeed-inference-int8", 48 | "microsoft/bloom-deepspeed-inference-fp16", 49 | ] 50 | 51 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 52 | self.tokenizer.pad_token = self.tokenizer.eos_token 53 | 54 | if is_meta: 55 | """When meta tensors enabled, use checkpoints""" 56 | self.config = AutoConfig.from_pretrained(self.model_name) 57 | self.repo_root, self.checkpoints_json = self._generate_json(checkpoint_path) 58 | 59 | with deepspeed.OnDevice(dtype=torch.float16, device="meta"): 60 | self.model = AutoModelForCausalLM.from_config(self.config) 61 | else: 62 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name) 63 | 64 | self.model.eval() 65 | 66 | def __call__(self, inputs=["test"], **kwargs): 67 | if isinstance(inputs, str): 68 | input_list = [inputs] 69 | else: 70 | input_list = inputs 71 | 72 | outputs = self.generate_outputs(input_list, **kwargs) 73 | return outputs 74 | 75 | def _generate_json(self, checkpoint_path=None): 76 | if checkpoint_path is None: 77 | repo_root = snapshot_download( 78 | self.model_name, 79 | allow_patterns=["*"], 80 | ignore_patterns=["*.safetensors", "*.msgpack", "*.h5"], 81 | local_files_only=False, 82 | revision=None, 83 | ) 84 | else: 85 | assert os.path.exists( 86 | checkpoint_path 87 | ), f"Checkpoint path {checkpoint_path} does not exist" 88 | repo_root = checkpoint_path 89 | 90 | if os.path.exists(os.path.join(repo_root, "ds_inference_config.json")): 91 | checkpoints_json = os.path.join(repo_root, "ds_inference_config.json") 92 | elif self.model_name in self.tp_presharded_models: 93 | # tp presharded repos come with their own checkpoints config file 94 | checkpoints_json = os.path.join(repo_root, "ds_inference_config.json") 95 | else: 96 | checkpoints_json = "checkpoints.json" 97 | 98 | with io.open(checkpoints_json, "w", encoding="utf-8") as f: 99 | file_list = [ 100 | str(entry).split("/")[-1] 101 | for entry in Path(repo_root).rglob("*.[bp][it][n]") 102 | if entry.is_file() 103 | ] 104 | data = {"type": "BLOOM", "checkpoints": file_list, "version": 1.0} 105 | json.dump(data, f) 106 | 107 | return repo_root, checkpoints_json 108 | 109 | def generate_outputs(self, inputs=["test"], **generate_kwargs): 110 | input_tokens = self.tokenizer.batch_encode_plus( 111 | inputs, return_tensors="pt", padding=True 112 | ) 113 | for t in input_tokens: 114 | if torch.is_tensor(input_tokens[t]): 115 | input_tokens[t] = input_tokens[t].to(self.device) 116 | 117 | self.model.cuda().to(self.device) 118 | 119 | outputs = self.model.generate(**input_tokens, **generate_kwargs) 120 | outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 121 | 122 | return outputs 123 | 124 | 125 | def init_model( 126 | args: argparse.Namespace, world_size: int, local_rank: int 127 | ) -> DSPipeline: 128 | """Initialize the deepspeed model""" 129 | data_type = getattr(torch, args.dtype) 130 | 131 | if local_rank == 0: 132 | see_memory_usage("before init", True) 133 | 134 | pipe = DSPipeline( 135 | model_name=args.name, 136 | dtype=data_type, 137 | is_meta=args.use_meta_tensor, 138 | device=local_rank, 139 | checkpoint_path=args.checkpoint_path, 140 | ) 141 | if local_rank == 0: 142 | see_memory_usage("after init", True) 143 | if args.use_meta_tensor: 144 | ds_kwargs = dict(base_dir=pipe.repo_root, checkpoint=pipe.checkpoints_json) 145 | else: 146 | ds_kwargs = dict() 147 | 148 | gc.collect() 149 | if args.ds_inference: 150 | pipe.model = deepspeed.init_inference( 151 
| pipe.model, 152 | dtype=data_type, 153 | mp_size=world_size, 154 | replace_with_kernel_inject=args.use_kernel, 155 | replace_method=args.replace_method, 156 | max_tokens=args.max_tokens, 157 | save_mp_checkpoint_path=args.save_mp_checkpoint_path, 158 | **ds_kwargs, 159 | ) 160 | if local_rank == 0: 161 | see_memory_usage("after init_inference", True) 162 | return pipe 163 | 164 | 165 | def generate( 166 | input_sentences: List[str], pipe: DSPipeline, batch_size: int, **generate_kwargs 167 | ) -> List[str]: 168 | """Generate predictions using a DSPipeline""" 169 | if batch_size > len(input_sentences): 170 | # dynamically extend to support larger bs by repetition 171 | input_sentences *= math.ceil(batch_size / len(input_sentences)) 172 | 173 | inputs = input_sentences[:batch_size] 174 | outputs = pipe(inputs, **generate_kwargs) 175 | return outputs 176 | -------------------------------------------------------------------------------- /deepspeed_predictor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import socket 4 | from collections import defaultdict 5 | from contextlib import closing 6 | from datetime import timedelta 7 | from typing import List, Tuple 8 | 9 | import pandas as pd 10 | import ray 11 | import ray.util 12 | import torch.distributed as dist 13 | from ray.air import Checkpoint, ScalingConfig 14 | from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME 15 | from ray.train.predictor import Predictor 16 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 17 | 18 | from deepspeed_utils import generate, init_model 19 | from huggingface_utils import reshard_checkpoint 20 | 21 | 22 | def find_free_port() -> int: 23 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 24 | s.bind(("", 0)) 25 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 26 | return s.getsockname()[1] 27 | 28 | 29 | @ray.remote 30 | class PredictionWorker: 31 | def __init__(self, args: argparse.Namespace, rank: int, world_size: int): 32 | self.args = args 33 | self.rank = rank 34 | self.world_size = world_size 35 | 36 | def get_address_and_port(self) -> Tuple[str, int]: 37 | """Returns the IP address and a free port on this node.""" 38 | addr = ray.util.get_node_ip_address() 39 | port = find_free_port() 40 | 41 | return addr, port 42 | 43 | def init_distributed( 44 | self, local_rank: int, local_world_size: int, master_addr: str, master_port: str 45 | ): 46 | """Initialize torch distributed backend""" 47 | os.environ["MASTER_ADDR"] = str(master_addr) 48 | os.environ["MASTER_PORT"] = str(master_port) 49 | # Same as in Ray Train 50 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" 51 | # This is not really robust, as multiple worker groups on 52 | # one node will overlap. 
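        # Note (added): each worker simply claims GPUs 0..local_world_size-1 on its node,
        # which assumes this worker group has the node's GPUs to itself (see Known
        # issues in the README).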
53 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 54 | [str(x) for x in range(local_world_size)] 55 | ) 56 | 57 | if "NCCL_SOCKET_IFNAME" not in os.environ: 58 | os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_NCCL_SOCKET_IFNAME 59 | 60 | dist.init_process_group( 61 | backend="nccl", 62 | init_method="env://", 63 | rank=self.rank, 64 | world_size=self.world_size, 65 | timeout=timedelta(seconds=1800), 66 | ) 67 | 68 | self.local_rank = local_rank 69 | 70 | os.environ["RANK"] = str(self.rank) 71 | os.environ["LOCAL_RANK"] = str(local_rank) 72 | os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size) 73 | os.environ["WORLD_SIZE"] = str(self.world_size) 74 | 75 | def init_model(self): 76 | """Initialize model for inference""" 77 | if self.args.reshard_checkpoint_path: 78 | self.args.checkpoint_path = reshard_checkpoint( 79 | self.args.checkpoint_path or self.args.name, 80 | self.args.dtype, 81 | self.args.reshard_checkpoint_path, 82 | ) 83 | self.generator = init_model(self.args, self.world_size, self.local_rank) 84 | 85 | def generate(self, data: pd.DataFrame, column: str, **kwargs) -> List[str]: 86 | return generate( 87 | list(data[column]), self.generator, self.args.batch_size, **kwargs 88 | ) 89 | 90 | 91 | class DeepSpeedPredictor(Predictor): 92 | def __init__(self, checkpoint: Checkpoint, scaling_config: ScalingConfig) -> None: 93 | self.checkpoint = checkpoint 94 | self.scaling_config = scaling_config 95 | self.init_worker_group(scaling_config) 96 | 97 | def init_worker_group(self, scaling_config: ScalingConfig): 98 | """Create the worker group. 99 | 100 | Each worker in the group communicates with other workers through the 101 | torch distributed backend. The worker group is inelastic (a failure of 102 | one worker will destroy the entire group). Each worker in the group 103 | recieves the same input data and outputs the same generated text. 104 | """ 105 | args = self.checkpoint.to_dict()["args"] 106 | 107 | # Start a placement group for the workers. 108 | self.pg = scaling_config.as_placement_group_factory().to_placement_group() 109 | prediction_worker_cls = PredictionWorker.options( 110 | num_cpus=scaling_config.num_cpus_per_worker, 111 | num_gpus=scaling_config.num_gpus_per_worker, 112 | resources=scaling_config.additional_resources_per_worker, 113 | scheduling_strategy=PlacementGroupSchedulingStrategy( 114 | placement_group=self.pg, placement_group_capture_child_tasks=True 115 | ), 116 | ) 117 | # Create the prediction workers. 118 | self.prediction_workers = [ 119 | prediction_worker_cls.remote(args, i, scaling_config.num_workers) 120 | for i in range(scaling_config.num_workers) 121 | ] 122 | # Get the IPs and ports of the workers. 123 | self.prediction_workers_ips_ports = ray.get( 124 | [ 125 | prediction_worker.get_address_and_port.remote() 126 | for prediction_worker in self.prediction_workers 127 | ] 128 | ) 129 | # Rank 0 worker will be set as the master address for torch distributed. 130 | rank_0_ip, rank_0_port = self.prediction_workers_ips_ports[0] 131 | 132 | # Map from node ip to the workers on it 133 | ip_dict = defaultdict(list) 134 | for i, ip_port in enumerate(self.prediction_workers_ips_ports): 135 | ip_dict[ip_port[0]].append(i) 136 | 137 | # Configure local ranks and start the distributed backend on each worker. 138 | # This assumes that there cannot be a situation where 2 worker groups use the 139 | # same node. 
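        # Note (added): local_rank is a worker's position among the workers sharing its
        # node IP, and local_world_size is how many workers landed on that node;
        # together they determine which GPU each worker binds to in init_distributed.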
140 | tasks = [] 141 | for rank in range(len(self.prediction_workers)): 142 | worker = self.prediction_workers[rank] 143 | local_world_size = len(ip_dict[self.prediction_workers_ips_ports[rank][0]]) 144 | local_rank = ip_dict[self.prediction_workers_ips_ports[rank][0]].index(rank) 145 | tasks.append( 146 | worker.init_distributed.remote( 147 | local_rank, local_world_size, rank_0_ip, rank_0_port 148 | ) 149 | ) 150 | ray.get(tasks) 151 | 152 | # Initialize the model itself on each worker. 153 | ray.get([worker.init_model.remote() for worker in self.prediction_workers]) 154 | 155 | def _predict_pandas( 156 | self, 157 | data: pd.DataFrame, 158 | input_column: str = "predict", 159 | output_column: str = "output", 160 | **kwargs 161 | ) -> pd.DataFrame: 162 | data_ref = ray.put(data) 163 | prediction = ray.get( 164 | [ 165 | worker.generate.remote(data_ref, column=input_column, **kwargs) 166 | for worker in self.prediction_workers 167 | ] 168 | )[0] 169 | 170 | return pd.DataFrame(prediction, columns=[output_column]) 171 | 172 | @classmethod 173 | def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor": 174 | return cls(checkpoint=checkpoint, **kwargs) 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ray-DeepSpeed-Inference 2 | 3 | *EXPERIMENTAL AND NOT PRODUCTION READY! Many rough edges.* 4 | 5 | Based on https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation 6 | 7 | ## How to run 8 | 9 | Runs OPT-66b inference on a cluster composed of g4dn nodes (in my tests, 3 x g4dn.12xlarge, giving a total of 12 GPUs). You can also run it on 12 x g4dn.4xlarge. 10 | 11 | ```bash 12 | python run_on_every_node.py download_model "s3://large-dl-models-mirror/models--anyscale--opt-66b-resharded/main/" "~/model" 13 | 14 | python deepspeed_inference_actors.py --name "facebook/opt-66b" --checkpoint_path "~/model" --batch_size 1 --ds_inference --use_kernel --use_meta_tensor --num_worker_groups 1 --num_gpus_per_worker_group 12 15 | ``` 16 | 17 | ## How it works 18 | 19 | This repository demonstrates how to use [DeepSpeed Inference](https://www.deepspeed.ai/tutorials/inference-tutorial/) with [Ray](https://ray.io/) for scalable batch inference. The combination of these two tools allows for efficient generation of text with large language models, including models as large as OPT-66b. 20 | 21 | DeepSpeed Inference utilizes automatic model parallelism to distribute the model across multiple GPUs. Ray handles the scheduling and orchestration of the workload. 22 | 23 | There are three key parts to the code: 24 | 1. `deepspeed_inference_actors.py` (the entrypoint) generates a sample Ray Dataset and uses `ray.train.batch_predictor.BatchPredictor` with a custom `DeepSpeedPredictor`. The `BatchPredictor` spawns `num_worker_groups` `DeepSpeedPredictor` actors, each recieving a share of the data. 25 | 2. `deepspeed_predictor.py` contains the code for the `DeepSpeedPredictor`. Each `DeepSpeedPredictor` actor spawns `num_gpus_per_worker_group` worker actors (`PredictionWorker`), connected together via a `torch.distributed` backend, as required by DeepSpeed. Once initialized, the DeepSpeed model is ready for prediction. 26 | 3. `deepspeed_utils.py` contains code based on a DeepSpeed example that is used by `PredictionWorkers`. 
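
For orientation, the sketch below boils the worker-group pattern from steps 2 and 3 down to a toy example: a few Ray actors exchange the rank 0 address, rendezvous through `torch.distributed`, and then all answer the same prompts in lock-step. The `ToyWorker` name, the fixed port, the `gloo` backend, and the lack of GPU reservations are simplifications for illustration only; the real `PredictionWorker` reserves a GPU per actor, uses `nccl`, and builds the model with `deepspeed.init_inference`.

```python
# A deliberately tiny sketch of the worker-group pattern, with DeepSpeed stripped
# out. ToyWorker, the fixed port, and the "gloo" backend are illustrative choices
# only; the real PredictionWorker reserves a GPU, uses "nccl", and wraps the model
# with deepspeed.init_inference.
import os

import ray
import ray.util
import torch.distributed as dist


@ray.remote
class ToyWorker:
    def get_address(self) -> str:
        return ray.util.get_node_ip_address()

    def init_distributed(self, rank: int, world_size: int, addr: str, port: int):
        # Rendezvous through the rank 0 worker, in the spirit of init_distributed
        # in deepspeed_predictor.py (which uses nccl and finds a free port instead).
        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

    def generate(self, prompts):
        # The real worker calls the DeepSpeed engine here; every rank produces
        # the same text, so the caller keeps only rank 0's result.
        return [f"rank {dist.get_rank()} completed: {p}" for p in prompts]


if __name__ == "__main__":
    ray.init()
    world_size = 2
    workers = [ToyWorker.remote() for _ in range(world_size)]
    addr = ray.get(workers[0].get_address.remote())
    ray.get(
        [
            w.init_distributed.remote(rank, world_size, addr, 29500)
            for rank, w in enumerate(workers)
        ]
    )
    print(ray.get([w.generate.remote(["DeepSpeed is"]) for w in workers])[0])
```

In the repository, `DeepSpeedPredictor.init_worker_group` plays the role of the driver block above, and `deepspeed_inference_actors.py` asks `BatchPredictor` to create one such group per `--num_worker_groups`.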
27 | 28 | In other words, a `DeepSpeedPredictor` creates a worker group of `PredictionWorker`, which share a single model. A worker group is inelastic (if one worker fails, the entire group fails). This is similar to how Ray Train works (in fact, the logic can be implemented using Ray Train private APIs instead of `PredictionWorker`). 29 | 30 | ## Known issues 31 | 32 | 1. If there are multiple worker groups scheduled on one node, this will result in workers using the same CUDA devices and thus leading to a crash. Therefore, it's best to either use 1 GPU nodes, or make sure that the number of workers in a group divided by the number of nodes is equal to the number of GPUs on the nodes. 33 | 2. Certain models obtained from Hugging Face hub will cause exceptions due to a [bug in DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/3084). The solution is to reshard the checkpoints of those models to ensure that all layers are stored in contiguous files. The relevant code is included in `huggingface_utils.py`. 34 | 35 | ## Environment 36 | 37 | Key packages: 38 | ``` 39 | accelerate==0.17.1 40 | deepspeed==0.8.3 41 | ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl 42 | torch==2.0.0 43 | transformers==4.27.2 44 | ``` 45 | 46 | All packages: 47 | ``` 48 | absl-py==1.4.0 49 | accelerate==0.17.1 50 | adal==1.2.7 51 | aim==3.16.1 52 | aim-ui==3.16.1 53 | aimrecords==0.0.7 54 | aimrocks==0.3.1 55 | aiofiles==22.1.0 56 | aiohttp==3.8.4 57 | aiohttp-cors==0.7.0 58 | aiorwlock==1.3.0 59 | aiosignal==1.3.1 60 | aiosqlite==0.18.0 61 | ale-py==0.8.1 62 | alembic==1.10.2 63 | anyio==3.6.2 64 | anyscale @ file:///home/ray/anyscale-0.0.0.dev0.tar.gz 65 | anyscale-node-provider @ file:///home/ray/anyscale_node_provider-0.0.1.tar.gz 66 | applicationinsights==0.11.10 67 | argcomplete==1.12.3 68 | argon2-cffi==21.3.0 69 | argon2-cffi-bindings==21.2.0 70 | arrow==1.2.3 71 | asttokens==2.2.1 72 | astunparse==1.6.3 73 | async-timeout==4.0.2 74 | attrs==22.2.0 75 | autocfg==0.0.8 76 | autogluon.common==0.7.0 77 | autogluon.core==0.7.0 78 | autograd==1.5 79 | autopage==0.5.1 80 | AutoROM==0.6.0 81 | AutoROM.accept-rom-license==0.6.0 82 | awscli==1.25.6 83 | awscliv2==2.2.0 84 | ax-platform==0.3.1 85 | azure-cli-core==2.40.0 86 | azure-cli-telemetry==1.0.8 87 | azure-common==1.1.28 88 | azure-core==1.26.3 89 | azure-identity==1.10.0 90 | azure-mgmt-compute==23.1.0 91 | azure-mgmt-core==1.3.2 92 | azure-mgmt-network==19.0.0 93 | azure-mgmt-resource==20.0.0 94 | Babel==2.12.1 95 | backcall==0.2.0 96 | backoff==1.10.0 97 | backports.zoneinfo==0.2.1 98 | base58==2.0.1 99 | bayesian-optimization==1.2.0 100 | bcrypt==4.0.1 101 | beautifulsoup4==4.12.0 102 | bitsandbytes==0.37.2 103 | black==23.1.0 104 | bleach==6.0.0 105 | blessed==1.20.0 106 | blobfile==2.0.1 107 | boto3==1.26.95 108 | botocore==1.29.95 109 | botorch==0.8.3 110 | cached-property==1.5.2 111 | cachetools==5.3.0 112 | catboost==1.1.1 113 | certifi==2022.12.7 114 | cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work 115 | chardet==5.1.0 116 | charset-normalizer==3.1.0 117 | chess==1.7.0 118 | chex==0.1.6 119 | click==8.1.3 120 | cliff==4.2.0 121 | cloudpickle==2.2.1 122 | cma==2.7.0 123 | cmaes==0.9.1 124 | cmake==3.26.0 125 | cmd2==2.4.3 126 | colorama==0.4.6 127 | coloredlogs==15.0.1 128 | colorful==0.5.5 129 | colorlog==6.7.0 130 | comet-ml==3.31.9 131 | comm==0.1.2 132 | commonmark==0.9.1 133 | conda==23.1.0 134 | conda-content-trust @ 
file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work 135 | conda-package-handling @ file:///croot/conda-package-handling_1666940373510/work 136 | configobj==5.0.8 137 | ConfigSpace==0.4.18 138 | contourpy==1.0.7 139 | coolname==2.2.0 140 | cryptography @ file:///croot/cryptography_1673298753778/work 141 | cycler==0.11.0 142 | Cython==0.29.32 143 | databricks-cli==0.17.5 144 | DataProperty==0.55.0 145 | datasets==2.10.1 146 | debugpy==1.6.6 147 | decorator==5.1.1 148 | decord==0.6.0 149 | deepspeed==0.8.3 150 | defusedxml==0.7.1 151 | Deprecated==1.2.13 152 | diffusers @ git+https://github.com/huggingface/diffusers.git@7fe88613fa15d230d59482889c440c7befa17c25 153 | dill==0.3.6 154 | distlib==0.3.6 155 | dm-tree==0.1.8 156 | docker==6.0.1 157 | docker-pycreds==0.4.0 158 | docutils==0.16 159 | dopamine-rl==4.0.5 160 | dragonfly-opt==0.1.6 161 | dulwich==0.21.3 162 | einops==0.3.0 163 | entrypoints==0.4 164 | etils==1.1.1 165 | evaluate==0.4.0 166 | everett==3.1.0 167 | exceptiongroup==1.1.1 168 | executing==1.2.0 169 | executor==23.2 170 | expiringdict==1.2.2 171 | fastapi==0.95.0 172 | fasteners==0.18 173 | fastjsonschema==2.16.3 174 | filelock==3.10.0 175 | FLAML==1.1.1 176 | Flask==2.2.3 177 | flatbuffers==2.0.7 178 | flax==0.6.7 179 | fonttools==4.39.2 180 | fqdn==1.5.1 181 | freezegun==1.1.0 182 | frozenlist==1.3.3 183 | fsspec==2023.3.0 184 | ftfy==6.1.1 185 | future==0.18.3 186 | gast==0.4.0 187 | gin-config==0.5.0 188 | gitdb==4.0.10 189 | GitPython==3.1.31 190 | glfw==2.5.7 191 | gluoncv==0.10.1.post0 192 | google-api-core==2.11.0 193 | google-api-python-client==1.7.8 194 | google-auth==2.16.2 195 | google-auth-httplib2==0.1.0 196 | google-auth-oauthlib==0.4.6 197 | google-cloud-compute==1.10.1 198 | google-cloud-core==2.3.2 199 | google-cloud-resource-manager==1.9.0 200 | google-cloud-secret-manager==2.16.0 201 | google-cloud-storage==2.7.0 202 | google-crc32c==1.5.0 203 | google-oauth==1.0.1 204 | google-pasta==0.2.0 205 | google-resumable-media==2.4.1 206 | googleapis-common-protos==1.58.0 207 | gpustat==1.0.0 208 | GPy==1.10.0 209 | gpytorch==1.9.1 210 | graphviz==0.8.4 211 | greenlet==2.0.2 212 | grpc-google-iam-v1==0.12.6 213 | grpcio==1.51.3 214 | grpcio-status==1.48.2 215 | grpcio-tools==1.51.3 216 | gunicorn==20.1.0 217 | gym==0.26.2 218 | gym-notices==0.0.8 219 | Gymnasium==0.26.3 220 | gymnasium-notices==0.0.1 221 | h11==0.14.0 222 | h5py==3.7.0 223 | halo==0.0.31 224 | HEBO==0.3.2 225 | higher==0.2.1 226 | hjson==3.1.0 227 | hpbandster==0.7.4 228 | httplib2==0.21.0 229 | huggingface-hub==0.13.3 230 | humanfriendly==10.0 231 | humanize==4.6.0 232 | hyperopt==0.2.5 233 | idna==3.4 234 | imageio==2.26.1 235 | imageio-ffmpeg==0.4.5 236 | importlib-metadata==6.1.0 237 | importlib-resources==5.12.0 238 | iniconfig==2.0.0 239 | ipykernel==6.22.0 240 | ipython==8.11.0 241 | ipython-genutils==0.2.0 242 | ipywidgets==8.0.4 243 | isodate==0.6.1 244 | isoduration==20.11.0 245 | isort==5.12.0 246 | itsdangerous==2.1.2 247 | jax==0.4.6 248 | jaxlib==0.4.6 249 | jedi==0.18.2 250 | Jinja2==3.1.2 251 | jmespath==0.10.0 252 | joblib==1.2.0 253 | json5==0.9.11 254 | jsonlines==3.1.0 255 | jsonpatch==1.32 256 | jsonpointer==2.3 257 | jsonschema==4.17.3 258 | jupyter-events==0.6.3 259 | jupyter-ydoc==0.2.3 260 | jupyter_client==8.1.0 261 | jupyter_core==5.3.0 262 | jupyter_server==2.5.0 263 | jupyter_server_fileid==0.8.0 264 | jupyter_server_terminals==0.4.4 265 | jupyter_server_ydoc==0.6.1 266 | jupyterlab==3.6.1 267 | 
jupyterlab-pygments==0.2.2 268 | jupyterlab-widgets==3.0.5 269 | jupyterlab_server==2.20.0 270 | kaggle-environments==1.7.11 271 | keras==2.11.0 272 | kiwisolver==1.4.4 273 | knack==0.10.1 274 | kubernetes==26.1.0 275 | lazy_loader==0.1 276 | libclang==15.0.6.1 277 | libtorrent==2.0.7 278 | lightgbm==3.3.5 279 | lightgbm-ray==0.1.8 280 | lightning-bolts==0.4.0 281 | lightning-utilities==0.8.0 282 | linear-operator==0.3.0 283 | lit==16.0.0 284 | lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836 285 | lm-eval==0.3.0 286 | log-symbols==0.0.14 287 | lxml==4.9.2 288 | lz4==4.3.2 289 | Mako==1.2.4 290 | Markdown==3.4.1 291 | markdown-it-py==2.2.0 292 | MarkupSafe==2.1.2 293 | matplotlib==3.7.1 294 | matplotlib-inline==0.1.6 295 | mbstrdecoder==1.1.2 296 | mdurl==0.1.2 297 | minigrid==2.1.1 298 | mistune==2.0.5 299 | mlagents-envs==0.28.0 300 | mlflow==1.30.0 301 | modin==0.18.1 302 | monotonic==1.6 303 | mosaicml==0.12.1 304 | mpmath==1.3.0 305 | msal==1.18.0b1 306 | msal-extensions==1.0.0 307 | msgpack==1.0.5 308 | msrest==0.7.1 309 | msrestazure==0.6.4 310 | mujoco==2.2.0 311 | mujoco-py==2.1.2.14 312 | multidict==6.0.4 313 | multipledispatch==0.6.0 314 | multiprocess==0.70.14 315 | mxnet==1.8.0.post0 316 | mypy-extensions==1.0.0 317 | nbclassic==0.5.3 318 | nbclient==0.7.2 319 | nbconvert==7.2.10 320 | nbformat==5.8.0 321 | nest-asyncio==1.5.6 322 | netifaces==0.11.0 323 | networkx==3.0 324 | nevergrad==0.4.3.post7 325 | ninja==1.11.1 326 | nltk==3.8.1 327 | notebook==6.5.3 328 | notebook_shim==0.2.2 329 | numexpr==2.8.4 330 | numpy==1.23.5 331 | nvidia-cublas-cu11==11.10.3.66 332 | nvidia-cuda-cupti-cu11==11.7.101 333 | nvidia-cuda-nvrtc-cu11==11.7.99 334 | nvidia-cuda-runtime-cu11==11.7.99 335 | nvidia-cudnn-cu11==8.5.0.96 336 | nvidia-cufft-cu11==10.9.0.58 337 | nvidia-curand-cu11==10.2.10.91 338 | nvidia-cusolver-cu11==11.4.0.1 339 | nvidia-cusparse-cu11==11.7.4.91 340 | nvidia-ml-py==11.495.46 341 | nvidia-nccl-cu11==2.14.3 342 | nvidia-nvtx-cu11==11.7.91 343 | oauth2client==4.1.3 344 | oauthlib==3.2.2 345 | onnx==1.12.0 346 | onnxruntime==1.14.1 347 | open-spiel==1.2 348 | openai==0.27.2 349 | opencensus==0.11.2 350 | opencensus-context==0.1.3 351 | opencv-python==4.7.0.72 352 | opentelemetry-api==1.1.0 353 | opentelemetry-exporter-otlp==1.1.0 354 | opentelemetry-exporter-otlp-proto-grpc==1.1.0 355 | opentelemetry-exporter-otlp-proto-http==1.16.0 356 | opentelemetry-proto==1.1.0 357 | opentelemetry-sdk==1.1.0 358 | opentelemetry-semantic-conventions==0.20b0 359 | opt-einsum==3.3.0 360 | optax==0.1.4 361 | optuna==2.10.0 362 | orbax==0.1.5 363 | packaging==23.0 364 | pandas==1.5.3 365 | pandocfilters==1.5.0 366 | paramiko==2.12.0 367 | paramz==0.9.5 368 | parso==0.8.3 369 | pathspec==0.11.1 370 | pathtools==0.1.2 371 | pathvalidate==2.5.2 372 | patsy==0.5.3 373 | pbr==5.11.1 374 | PettingZoo==1.22.1 375 | pexpect==4.8.0 376 | pickleshare==0.7.5 377 | Pillow==9.4.0 378 | pkginfo==1.9.6 379 | pkgutil_resolve_name==1.3.10 380 | platformdirs==3.1.1 381 | plotly==5.13.1 382 | pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work 383 | portalocker==2.7.0 384 | prettytable==3.6.0 385 | prometheus-client==0.13.1 386 | prometheus-flask-exporter==0.22.3 387 | promise==2.3 388 | prompt-toolkit==3.0.38 389 | property-manager==3.0 390 | proto-plus==1.22.2 391 | protobuf==3.20.3 392 | psutil==5.9.4 393 | ptyprocess==0.7.0 394 | pure-eval==0.2.2 395 | py-cpuinfo==9.0.0 396 | py-spy==0.3.14 397 | py3nvml==0.2.7 398 | pyaml==21.10.1 399 | 
pyarrow==11.0.0 400 | pyasn1==0.4.8 401 | pyasn1-modules==0.2.8 402 | pybind11==2.6.2 403 | pycosat @ file:///croot/pycosat_1666805502580/work 404 | pycountry==22.3.5 405 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 406 | pycryptodomex==3.17 407 | pydantic==1.10.6 408 | pyDeprecate==0.3.2 409 | pygame==2.1.2 410 | pyglet==1.5.15 411 | Pygments==2.14.0 412 | PyJWT==2.6.0 413 | pymoo==0.5.0 414 | pymunk==6.2.1 415 | PyNaCl==1.5.0 416 | PyOpenGL==3.1.6 417 | pyOpenSSL==23.0.0 418 | pyparsing==3.0.9 419 | pyperclip==1.8.2 420 | pypng==0.20220715.0 421 | pyro-api==0.1.2 422 | pyro-ppl==1.8.4 423 | Pyro4==4.82 424 | pyrsistent==0.19.3 425 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 426 | pytablewriter==0.64.2 427 | pytest==7.2.2 428 | pytest-remotedata==0.3.2 429 | python-dateutil==2.8.2 430 | python-json-logger==2.0.7 431 | pytorch-lightning==2.0.0 432 | pytorch-ranger==0.1.1 433 | pytz==2022.7.1 434 | pytz-deprecation-shim==0.1.0.post0 435 | PyWavelets==1.4.1 436 | PyYAML==6.0 437 | pyzmq==25.0.2 438 | querystring-parser==1.2.4 439 | ray @ file:///home/ray/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl 440 | ray-lightning==0.3.0 441 | recsim==0.2.4 442 | redis==3.5.3 443 | regex==2022.10.31 444 | requests==2.28.2 445 | requests-oauthlib==1.3.1 446 | requests-toolbelt==0.10.1 447 | responses==0.18.0 448 | RestrictedPython==6.0 449 | rfc3339-validator==0.1.4 450 | rfc3986-validator==0.1.1 451 | rich==12.0.1 452 | rouge-score==0.1.2 453 | rsa==4.9 454 | ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work 455 | ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work 456 | s3transfer==0.6.0 457 | sacrebleu==1.5.0 458 | scikit-image==0.20.0 459 | scikit-learn==1.2.2 460 | scikit-optimize==0.9.0 461 | scipy==1.10.1 462 | segment-analytics-python==2.2.2 463 | semantic-version==2.10.0 464 | Send2Trash==1.8.0 465 | sentencepiece==0.1.96 466 | sentry-sdk==1.17.0 467 | serpent==1.41 468 | setproctitle==1.3.2 469 | shortuuid==1.0.1 470 | sigopt==7.5.0 471 | six==1.16.0 472 | smart-open==6.3.0 473 | smmap==5.0.0 474 | sniffio==1.3.0 475 | soupsieve==2.4 476 | spinners==0.0.24 477 | SQLAlchemy==1.4.47 478 | sqlitedict==2.1.0 479 | sqlparse==0.4.3 480 | stack-data==0.6.2 481 | starlette==0.26.1 482 | statsmodels==0.13.5 483 | stevedore==5.0.0 484 | SuperSuit==3.7.0 485 | sympy==1.11.1 486 | tabledata==1.3.1 487 | tabulate==0.9.0 488 | tblib==1.7.0 489 | tcolorpy==0.1.2 490 | tenacity==8.2.2 491 | tensorboard==2.12.0 492 | tensorboard-data-server==0.7.0 493 | tensorboard-plugin-wit==1.8.1 494 | tensorboardX==2.4.1 495 | tensorflow-estimator==2.11.0 496 | tensorflow-io-gcs-filesystem==0.31.0 497 | tensorflow-probability==0.19.0 498 | tensorstore==0.1.33 499 | termcolor==2.2.0 500 | terminado==0.10.1 501 | tf-slim==1.1.0 502 | tf2onnx==1.13.0 503 | threadpoolctl==3.1.0 504 | tifffile==2023.3.15 505 | tiktoken==0.1.2 506 | timm==0.4.5 507 | tinycss2==1.2.1 508 | tinyscaler==1.2.5 509 | tokenizers==0.13.2 510 | tomli==2.0.1 511 | toolz @ file:///croot/toolz_1667464077321/work 512 | torch==2.0.0 513 | torch-optimizer==0.3.0 514 | torchaudio==2.0.1 515 | torchmetrics==0.11.4 516 | torchvision==0.15.1 517 | tornado==6.2 518 | tqdm==4.65.0 519 | tqdm-multiprocess==0.0.11 520 | traitlets==5.9.0 521 | transformers==4.27.2 522 | triton==2.0.0 523 | tune-sklearn==0.4.4 524 | typeguard==2.13.3 525 | typepy==1.3.0 526 | typer==0.6.1 527 | typing_extensions==4.5.0 528 | tzdata==2022.7 529 | tzlocal==4.3 530 | ujson==5.7.0 531 | uri-template==1.2.0 532 | 
uritemplate==3.0.1 533 | urllib3==1.26.15 534 | uvicorn==0.21.1 535 | verboselogs==1.7 536 | virtualenv==20.21.0 537 | wandb==0.13.4 538 | wcwidth==0.2.6 539 | webcolors==1.12 540 | webencodings==0.5.1 541 | websocket-client==1.5.1 542 | Werkzeug==2.2.3 543 | widgetsnbextension==4.0.5 544 | wrapt==1.15.0 545 | wurlitzer==3.0.3 546 | xgboost==1.7.4 547 | xgboost-ray==0.1.15 548 | xmltodict==0.13.0 549 | xxhash==3.2.0 550 | y-py==0.5.9 551 | yacs==0.1.8 552 | yarl==1.8.2 553 | ypy-websocket==0.8.2 554 | zipp==3.15.0 555 | zoopt==0.4.1 556 | zstandard==0.20.0 557 | ``` --------------------------------------------------------------------------------