├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── benchmark ├── hexgen_documents │ ├── README.md │ ├── _llama_worker.py │ └── scripts │ │ ├── run_cross_node.sh │ │ └── run_llama.sh ├── petals_documents │ ├── README.md │ ├── _petals.py │ ├── config_scripts │ │ ├── cached_ip.sh │ │ ├── config_petals.sh │ │ ├── config_tgi.sh │ │ ├── petals_coordinator.sh │ │ ├── petals_model.sh │ │ ├── start_coo.sh │ │ ├── start_coo_head.sh │ │ └── start_model.sh │ └── scripts │ │ └── run_petals.sh ├── send_request │ ├── README.md │ ├── request.py │ └── single_request.py ├── tgi_documents │ ├── README.md │ ├── _tgi_worker.py │ └── scripts │ │ └── run_tgi.sh └── utils │ ├── _base.py │ ├── _base_rank_based.py │ └── _utils.py ├── hexgen ├── hexgen_core │ ├── README.md │ ├── __init__.py │ ├── gen_comm_groups.py │ ├── gen_hetero_groups.py │ ├── gen_p2p_lists.py │ ├── gen_parallel_groups.py │ ├── generation.py │ ├── heterogeneous_pipeline.py │ ├── init_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── falcon.py │ │ ├── gpt.py │ │ ├── gpt_neox.py │ │ ├── gptj.py │ │ ├── llama.py │ │ ├── opt.py │ │ └── vit.py │ ├── modules │ │ ├── __init__.py │ │ ├── block.py │ │ ├── embedding.py │ │ ├── mha.py │ │ └── mlp.py │ └── utils.py └── llama │ ├── README.md │ ├── arguments.py │ ├── llama-config │ ├── llama-13b │ │ ├── llama-13b │ │ │ └── config.json │ │ └── params.json │ ├── llama-30b │ │ ├── llama-30b │ │ │ └── config.json │ │ └── params.json │ ├── llama-70b │ │ ├── llama-70b │ │ │ └── config.json │ │ └── params.json │ └── llama-7b │ │ ├── llama-7b │ │ └── config.json │ │ └── params.json │ ├── llama_config_utils.py │ ├── llama_inference.py │ ├── load_model_parameters_utils │ ├── README.md │ ├── create_separate_state_dicts_llama_7b.py │ ├── inv_freq.pt │ ├── load_model_parameters.py │ └── remap_state_dict.py │ ├── modules │ ├── Llamamodel_pipeline.py │ ├── Llamamodel_tensor_parallel.py │ └── hybrid_parallel_model_dist.py │ └── scripts │ ├── run_llama_inference.sh │ ├── run_llama_p0.sh │ ├── run_llama_p1.sh │ ├── run_llama_p2.sh │ ├── run_llama_p3.sh │ ├── run_llama_p4.sh │ └── run_llama_p5.sh ├── requirements.txt ├── scripts ├── run_head.sh └── run_worker.sh └── third_party ├── megatron ├── megatron │ ├── __init__.py │ ├── arguments.py │ ├── checkpointing.py │ ├── core │ │ ├── README.md │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── package_info.py │ │ ├── parallel_state.py │ │ ├── pipeline_parallel │ │ │ ├── __init__.py │ │ │ ├── p2p_communication.py │ │ │ └── schedules.py │ │ ├── requirements.txt │ │ ├── tensor_parallel │ │ │ ├── __init__.py │ │ │ ├── cross_entropy.py │ │ │ ├── data.py │ │ │ ├── layers.py │ │ │ ├── mappings.py │ │ │ ├── mappings_group.py │ │ │ ├── random.py │ │ │ └── utils.py │ │ └── utils.py │ ├── data │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── autoaugment.py │ │ ├── bert_dataset.py │ │ ├── biencoder_dataset_utils.py │ │ ├── blendable_dataset.py │ │ ├── data_samplers.py │ │ ├── dataset_utils.py │ │ ├── gpt_dataset.py │ │ ├── helpers.cpp │ │ ├── helpers.cpython-38-x86_64-linux-gnu.so │ │ ├── ict_dataset.py │ │ ├── image_folder.py │ │ ├── indexed_dataset.py │ │ ├── orqa_wiki_dataset.py │ │ ├── realm_dataset_utils.py │ │ ├── realm_index.py │ │ ├── t5_dataset.py │ │ ├── test │ │ │ ├── test_indexed_dataset.py │ │ │ └── test_preprocess_data.sh │ │ └── vit_dataset.py │ ├── dist_signal_handler.py │ ├── fp16_deprecated │ │ └── loss_scaler.py │ ├── fused_kernels │ │ ├── __init__.py │ │ ├── compat.h │ │ ├── scaled_masked_softmax.cpp │ │ ├── scaled_masked_softmax.h │ │ ├── 
scaled_masked_softmax_cuda.cu │ │ ├── scaled_softmax.cpp │ │ ├── scaled_softmax_cuda.cu │ │ ├── scaled_upper_triang_masked_softmax.cpp │ │ ├── scaled_upper_triang_masked_softmax.h │ │ ├── scaled_upper_triang_masked_softmax_cuda.cu │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_fused_kernels.py │ │ └── type_shim.h │ ├── global_vars.py │ ├── indexer.py │ ├── initialize.py │ ├── memory.py │ ├── microbatches.py │ ├── model │ │ ├── __init__.py │ │ ├── bert_model.py │ │ ├── biencoder_model.py │ │ ├── classification.py │ │ ├── distributed.py │ │ ├── enums.py │ │ ├── fused_bias_gelu.py │ │ ├── fused_layer_norm.py │ │ ├── fused_softmax.py │ │ ├── gpt_model.py │ │ ├── language_model.py │ │ ├── module.py │ │ ├── multiple_choice.py │ │ ├── realm_model.py │ │ ├── retro_transformer.py │ │ ├── rotary_pos_embedding.py │ │ ├── t5_model.py │ │ ├── transformer.py │ │ ├── utils.py │ │ └── vision │ │ │ ├── classification.py │ │ │ ├── dino.py │ │ │ ├── esvit_swin_backbone.py │ │ │ ├── inpainting.py │ │ │ ├── knn_monitor.py │ │ │ ├── mit_backbone.py │ │ │ ├── swin_backbone.py │ │ │ ├── utils.py │ │ │ └── vit_backbone.py │ ├── mpu │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── commons.py │ │ │ ├── test_cross_entropy.py │ │ │ ├── test_data.py │ │ │ ├── test_initialize.py │ │ │ ├── test_layers.py │ │ │ └── test_random.py │ ├── optimizer │ │ ├── __init__.py │ │ ├── clip_grads.py │ │ ├── distrib_optimizer.py │ │ ├── grad_scaler.py │ │ └── optimizer.py │ ├── optimizer_param_scheduler.py │ ├── static │ │ └── index.html │ ├── text_generation │ │ ├── __init__.py │ │ ├── api.py │ │ ├── beam_utils.py │ │ ├── communication.py │ │ ├── forward_step.py │ │ ├── generation.py │ │ ├── sampling.py │ │ └── tokenization.py │ ├── text_generation_server.py │ ├── timers.py │ ├── tokenizer │ │ ├── __init__.py │ │ ├── bert_tokenization.py │ │ ├── gpt2_tokenization.py │ │ └── tokenizer.py │ ├── training.py │ └── utils.py └── megatron_layers │ ├── __init__.py │ └── transformer.py └── ocf ├── LICENSE ├── README.md ├── docs ├── docs │ ├── advanced │ │ └── internals.md │ ├── guide │ │ ├── getting-started.md │ │ └── inference.md │ ├── images │ │ └── overview.png │ └── index.md ├── mkdocs.yml └── requirements.txt ├── examples ├── .gitignore ├── apis │ ├── .gitignore │ ├── conn.py │ ├── inference.py │ ├── inference_devnet.py │ ├── inference_example.py │ └── sd_inference_example.py └── worker │ ├── _base.py │ ├── inference.py │ └── utils.py └── src ├── benchmark └── inference.js ├── ocf-cli ├── .gitignore ├── Makefile ├── ocf_cli │ ├── __init__.py │ ├── bin │ │ ├── __init__.py │ │ └── ocf.py │ └── lib │ │ ├── core │ │ ├── base.py │ │ ├── config.py │ │ ├── host_worker.py │ │ ├── inference_worker.py │ │ └── utils.py │ │ ├── pod │ │ ├── config.py │ │ ├── manager.py │ │ ├── pod.py │ │ └── utils.py │ │ └── pprint │ │ ├── _base.py │ │ ├── nodes.py │ │ └── service.py └── pyproject.toml └── ocf-core ├── .gitignore ├── Dockerfile ├── Makefile ├── bin ├── core │ ├── cmd │ │ ├── cluster.go │ │ ├── config.go │ │ ├── init.go │ │ ├── root.go │ │ ├── start.go │ │ ├── update.go │ │ └── utility.go │ └── main.go └── netctl │ └── main.go ├── config ├── cfg.yaml └── cfg_standalone.yaml ├── go.mod ├── go.sum └── internal ├── cluster ├── baremetal.go ├── cluster.go ├── kubenetes.go ├── network │ └── edgevpn.go └── slurm.go ├── common ├── constants.go ├── logger.go ├── process │ ├── manager.go │ └── process.go ├── requests │ ├── broadcast.go │ ├── client.go │ ├── proxy.go │ └── weaver.go ├── secrets.go ├── structs │ ├── cluster.go │ ├── matching.go │ ├── request.go 
│ ├── summary.go │ └── workload.go ├── utils.go └── version.go ├── daemon ├── clock.go └── tcmd.go ├── database ├── access │ ├── client.go │ └── node.go ├── client.go ├── config.go ├── context.go ├── ent.go ├── enttest │ └── enttest.go ├── generate.go ├── hook │ └── hook.go ├── migrate │ ├── migrate.go │ └── schema.go ├── mutation.go ├── node.go ├── node │ ├── node.go │ └── where.go ├── node_create.go ├── node_delete.go ├── node_query.go ├── node_update.go ├── predicate │ └── predicate.go ├── runtime.go ├── runtime │ └── runtime.go ├── schema │ └── node.go └── tx.go ├── pkgs └── weaver │ └── README.md ├── profiler └── storage.go ├── protocol ├── README.md ├── p2p │ ├── bootstrap.go │ ├── dht.go │ ├── discovery.go │ ├── handler.go │ ├── host.go │ ├── key.go │ ├── remote.go │ └── structs.go ├── remote │ └── client.go └── rpc │ └── handler.go └── server ├── auth └── authentication.go ├── cors.go ├── forward.go ├── http_handler.go ├── queue └── queue.go ├── request.go ├── rpc.go ├── server.go ├── throttle.go ├── vacuum.go ├── welcome.go └── worker.go /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: subsystem 2 | subsystem: go-install 3 | export PATH=${PATH}:/usr/local/go/bin && cd ./third_party/ocf/src/ocf-core && git init && make build 4 | 5 | .PHONY: go-install 6 | go-install: 7 | ls /usr/local/go/bin || wget -c https://dl.google.com/go/go1.20.linux-amd64.tar.gz -O - | sudo tar -xz -C /usr/local 8 | 9 | .PHONY: requirements 10 | requirements: 11 | pip install -r requirements.txt 12 | 13 | .PHONY: flash-attn 14 | flash-attn: 15 | pip3 install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 16 | pip install flash-attn==2.0.8 17 | -git clone https://github.com/Dao-AILab/flash-attention.git 18 | cd flash-attention && git submodule update --init csrc/cutlass && cd csrc/fused_dense_lib && pip install . \ 19 | && cd ../xentropy && pip install . && cd ../rotary && pip install . && cd ../layer_norm && pip install . 
20 | 21 | .PHONY: hexgen 22 | hexgen: subsystem requirements flash-attn 23 | 24 | .PHONY: hexgen-head 25 | hexgen: subsystem requirements 26 | 27 | .PHONY: clean 28 | clean: 29 | -rm edit $(objects) 30 | -------------------------------------------------------------------------------- /benchmark/hexgen_documents/_llama_worker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from loguru import logger 4 | import sys 5 | sys.path.insert(0, "..") 6 | sys.path.insert(0, '../utils') 7 | sys.path.insert(0, '../..') 8 | sys.path.insert(0, '../../hexgen') 9 | sys.path.insert(0, '../../hexgen/hexgen_core') 10 | sys.path.insert(0, '../../hexgen/llama') 11 | sys.path.insert(0, '../../hexgen/llama/modules') 12 | sys.path.insert(0, '../../hexgen/llama/llama-config') 13 | sys.path.insert(0, '../../third_party/megatron') 14 | from _base_rank_based import InferenceWorker 15 | from llama.arguments import add_arguments, clear_kv_cache 16 | from llama.llama_inference import inference, create_model, set_seed 17 | from megatron.initialize import initialize_megatron 18 | from megatron import get_args 19 | from threading import Thread 20 | from multiprocessing import Process 21 | 22 | class LlamaWorker(InferenceWorker): 23 | def __init__(self, model_name, head_node, args, ) -> None: 24 | self.head_node = head_node 25 | 26 | self.args = args 27 | self.rank = args.rank 28 | self.world_size = args.world_size 29 | self.model, self.tokenizer, self.pp_groups = create_model(args) 30 | 31 | super().__init__(model_name, head_node, self.rank, self.rank, args=args) 32 | 33 | async def handle_requests(self, msg): 34 | 35 | model_msg = self.parse_msg(msg) 36 | 37 | if self.rank == 0: 38 | threads = [] 39 | for rank in range(self.rank + 1, self.world_size): 40 | threads.append(Thread(target=self.send_request, args=(msg, rank))) 41 | 42 | for t in threads: 43 | t.start() 44 | 45 | print(f"On {self.rank}, Start inference") 46 | outputs, infer_time = inference(self.model, self.tokenizer, self.pp_groups, model_msg, self.args) 47 | 48 | else: 49 | print(f"On {self.rank}, Start inference") 50 | outputs, infer_time = inference(self.model, self.tokenizer, self.pp_groups, model_msg, self.args) 51 | 52 | clear_kv_cache() 53 | return outputs, infer_time 54 | 55 | def get_rank(self): 56 | return self.rank 57 | 58 | 59 | if __name__=="__main__": 60 | 61 | initialize_megatron(extra_args_provider=add_arguments) 62 | args = get_args() 63 | 64 | model_name = args.model_name 65 | head_node = args.head_node 66 | 67 | set_seed() 68 | 69 | logger.info(f"Creating Decentralized-LLM-inference Worker, {args.rank}, with world size of {args.world_size}") 70 | worker = LlamaWorker(model_name=model_name, head_node=head_node, args=args) 71 | worker.start() 72 | 73 | -------------------------------------------------------------------------------- /benchmark/hexgen_documents/scripts/run_cross_node.sh: -------------------------------------------------------------------------------- 1 | export PORT=9991 2 | export DEVICES=0,1,2,3 3 | 4 | export NUM_NODES=2 5 | export NUM_GPUS_PER_NODE=2 6 | export NCCL_IB_DISABLE=0 7 | export NCCL_IB_HCA=mlx5_2,mlx5_5 8 | # Modify the master IP below before execution 9 | export MASTER_ADDR='xxx.xxx.xxx.xx' 10 | export NODE_RANK=0 11 | # export NODE_RANK=1 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICES python3 -m torch.distributed.launch --nnodes=$NUM_NODES --nproc_per_node=$NUM_GPUS_PER_NODE --master_addr=$MASTER_ADDR --master_port=$PORT 
--node_rank=$NODE_RANK _llama_worker.py \ 14 | --model_size llama-7b \ 15 | --use-flash-attn \ 16 | --hetero_config 1 2 1 \ 17 | --pp_partition 8 16 8 \ 18 | --model_name "Llama-2-7b-chat-hf" \ 19 | # Modify the IP below before execution 20 | --head_node 'http://xxx.xxx.xx.xxx:xxxx' \ 21 | --group_id 0 22 | -------------------------------------------------------------------------------- /benchmark/hexgen_documents/scripts/run_llama.sh: -------------------------------------------------------------------------------- 1 | export PORT=9991 2 | export DEVICES=0,1,2,3 3 | 4 | CUDA_VISIBLE_DEVICES=$DEVICES python3 -m torch.distributed.launch --nproc_per_node=4 --master_port $PORT _llama_worker.py \ 5 | --model_size llama-7b \ 6 | --use-flash-attn \ 7 | --hetero_config 1 2 1 \ 8 | --pp_partition 8 16 8 \ 9 | --model_name "Llama-2-7b-chat-hf" \ 10 | # Modify the IP below before execution 11 | --head_node 'http://xxx.xxx.xx.xxx:xxxx' \ 12 | --group_id 0 \ 13 | -------------------------------------------------------------------------------- /benchmark/petals_documents/_petals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import argparse 4 | from loguru import logger 5 | from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, LlamaForCausalLM, AutoModelForSequenceClassification 6 | from transformers import TextGenerationPipeline, TextClassificationPipeline 7 | import sys 8 | sys.path.insert(0, "..") 9 | sys.path.insert(0, "../utils") 10 | from utils._base import InferenceWorker 11 | from transformers import AutoTokenizer 12 | from petals import AutoDistributedModelForCausalLM 13 | 14 | pipeline_mapping = { 15 | 'text-generation': TextGenerationPipeline, 16 | 'text-classification': TextClassificationPipeline, 17 | } 18 | 19 | model_mapping = { 20 | 'text-generation': AutoModelForCausalLM, 21 | 'text-classification': AutoModelForSequenceClassification, 22 | } 23 | 24 | dtype_mapping = { 25 | 'float32': torch.float32, 26 | 'float16': torch.float16, 27 | 'bfloat16': torch.bfloat16, 28 | } 29 | 30 | class PetalsWorker(InferenceWorker): 31 | def __init__(self, model_name, init_peers, token, id) -> None: 32 | super().__init__(f"{model_name}_{id}") 33 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) 34 | if init_peers == 'default': 35 | self.model = AutoDistributedModelForCausalLM.from_pretrained(model_name, token=token) 36 | else: 37 | self.model = AutoDistributedModelForCausalLM.from_pretrained(model_name, initial_peers=[init_peers], token=token) 38 | 39 | async def handle_requests(self, msg): 40 | prompts = msg.get('prompt', '') 41 | max_new_tokens = msg.get('max_new_tokens', 128) 42 | temperature = msg.get('temperature', 0.9) 43 | top_k = msg.get('top_k', 50) 44 | top_p = msg.get('top_p', 0.9) 45 | # if prompt is str: 46 | if isinstance(prompts, str): 47 | prompts = [prompts] 48 | 49 | print(prompts) 50 | outputs = [] 51 | for prompt in prompts: 52 | inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] 53 | torch.cuda.synchronize() 54 | start = time.time() 55 | output = self.model.generate(inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p) 56 | output = self.tokenizer.decode(output[0]) 57 | end = time.time() 58 | outputs.append(output) 59 | return outputs[0], end - start 60 | 61 | if __name__=="__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--model-name", type=str, default="openlm-research/open_llama_7b") 64 
| parser.add_argument("--init-peers", type=str, default="default") 65 | parser.add_argument("--token", type=str, default="") 66 | parser.add_argument("--id", type=int, default=0) 67 | args = parser.parse_args() 68 | logger.info(f"args: {args}") 69 | worker = PetalsWorker(args.model_name, args.init_peers, args.token, args.id) 70 | worker.start() -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/cached_ip.sh: -------------------------------------------------------------------------------- 1 | # 216.153.54.97 2 | 3 | /ip4/192.168.99.2/tcp/9992/p2p/QmaPtUnaY39LxQ8bXM4h1APeegQJo1BJzrkXWvVMCpShnp 4 | /ip4/192.168.99.6/tcp/8993/p2p/QmaJswA9DvHUqqR7NgKdy3mMGpbvdFEFQQ66i2334U14aT 5 | /ip4/192.168.99.3/tcp/8994/p2p/Qmcf7XHY4NLpu64EBYtZhXk23qbNMJp9y7mHuvbYXpjphk 6 | /ip4/192.168.99.5/tcp/8995/p2p/QmTJKNy3JJvbe38ZALADiPf6vgQrLg1SDk4oCVx5kS7ctB 7 | /ip4/192.168.99.4/tcp/8996/p2p/QmegVKihARNJ4nu3dn2TkqTVRupPvJJNDNGiafv7X9Sd2R 8 | /ip4/192.168.99.9/tcp/8997/p2p/QmbMtDUSUNPZsXCZ8NjGLuM4gfRsXmfUPuPpzHxJ97GpEf 9 | /ip4/192.168.99.8/tcp/8998/p2p/QmRz7PumMXX9y4Z8QcydyJ9hZaMwfuh2o4AD2nCyJ8iKfe 10 | /ip4/192.168.99.7/tcp/8999/p2p/QmbBedokDkvZmoYXx3Ghv4jSq4wLRAv2i948ZnbJKMU2M5 11 | 12 | # 216.153.61.47 13 | 14 | /ip4/192.168.99.3/tcp/8992/p2p/QmUYnCiXkmgnkUXZ7M5sd5W4GFwaMeSttaWxAgiQZEt9J2 15 | /ip4/192.168.99.11/tcp/28993/p2p/QmdbszYkVRQCeP7YRjeiTjkYRuh49mS4ZDA3wv7NmNQ5Po 16 | /ip4/192.168.99.12/tcp/7994/p2p/QmNciSkDBW7Gna3PcCBQrxNGu9tnTTX4Bbfa6mufnq2SaN 17 | /ip4/192.168.99.13/tcp/7995/p2p/QmYdY2RdKdHsAm9gSpZ9p4jo7J5zPPvAaz3PEg7ftQnZEq 18 | /ip4/192.168.99.14/tcp/7996/p2p/QmcSqJTJ7pRoXPkC7FEMffcuxdNQ83HygURBPh1KfbFkhJ 19 | /ip4/192.168.99.15/tcp/7997/p2p/QmZVX5HiUiWLn4M8785e6bbjw5HSVejDsPEhyPyxF8wjr5 20 | /ip4/192.168.99.16/tcp/7998/p2p/QmPqm33TfjDZSBQ7ttWifb9Ni9HGy84gHrMUWzyWiQdeMk 21 | /ip4/192.168.99.17/tcp/7999/p2p/QmaWKeuPRiH6gZR8FSHS2udwDxwEBpLXkyBdgwZUkMzasd 22 | 23 | # 216.153.51.193 24 | 25 | /ip4/192.168.99.2/tcp/9992/p2p/QmVsGuziLyBADJGhxDgJMSDVx6hB9LJ8V48A8CKYwCEWPM 26 | /ip4/192.168.99.3/tcp/8993/p2p/QmetsqAsyw1LDub9yB9UseLLTkRnPUPd8MeAcY4YjRNrQU 27 | /ip4/192.168.99.4/tcp/8994/p2p/QmPp1FN2cw2eNQr1Fe4vAAXCYeGr48DJ5LZvqkKGKcwwoF 28 | /ip4/192.168.99.5/tcp/8995/p2p/QmYjQQBABPH4Pa3qLa6xiKqYwNJMgcYRgSz86fLnWpxXxz 29 | /ip4/192.168.99.6/tcp/8996/p2p/QmbSmDzao2fUngx2RvMtRPa9iFgqTaRDFmfmy5koNpvynr 30 | /ip4/192.168.99.7/tcp/8997/p2p/QmdG1pjLXEtNGnLrmmieNxoVwYBtgzwid9rBHwNTgLcjWf 31 | 32 | 33 | # 216.153.54.93 34 | 35 | /ip4/192.168.99.10/tcp/8992/p2p/QmbzGaWoEDk3v1MdZRMPvQvVidMduGZuLzb8yK6WQmgx4S 36 | /ip4/192.168.99.14/tcp/7993/p2p/QmfDUqypQLw9H2DJzj8LpCWiX9dqJUSMLZWBUV5tPqBmDL 37 | /ip4/192.168.99.16/tcp/7994/p2p/Qmc7S5ThjBEguEUY53rzEDjGxmNipwLzcfsZHaV2cH1yKQ 38 | /ip4/192.168.99.11/tcp/7995/p2p/QmQwUsM8LwidAEsyfw6FjdCHroZnW5A5daNLUxXqiTKLcr 39 | /ip4/192.168.99.15/tcp/7996/p2p/QmWQgKRwVE6cNPvcCjkXNCTj7VrJ6oWG3f8Q29vRwe74Na 40 | /ip4/192.168.99.13/tcp/7997/p2p/QmbHFyZqvK7jPwtMVCFWWSh45GoEkAokPxvgrZZ3gTq2cK 41 | /ip4/192.168.99.12/tcp/7998/p2p/QmZBZ5adiim1jM7yk4dKaWbyh2YuEeW7VGRj4o6Rgis6Pc 42 | /ip4/192.168.99.17/tcp/7999/p2p/QmZtKycUXxKC6LLzbspypM7A1aZe6Bbq2Ja8E6DMDQq6G8 43 | 44 | 45 | i=5 46 | INITIAL_PEERS=/ip4/192.168.99.7/tcp/8997/p2p/QmdG1pjLXEtNGnLrmmieNxoVwYBtgzwid9rBHwNTgLcjWf 47 | BLOCKS=14 48 | SESSION_NAME="model_$i" 49 | tmux new-session -d -s $SESSION_NAME 50 | tmux send-keys -t $SESSION_NAME "cd FastGen-enhance-spec_decoding/" C-m 51 | tmux send-keys -t $SESSION_NAME "export CUDA_DEVICE=$i" C-m 52 | tmux send-keys -t 
$SESSION_NAME "export INITIAL_PEERS=$INITIAL_PEERS" C-m 53 | tmux send-keys -t $SESSION_NAME "export MODEL_PORT=\`expr \$CUDA_DEVICE + 59551\`" C-m 54 | tmux send-keys -t $SESSION_NAME "bash petals_model.sh \$MODEL_PORT \$CUDA_DEVICE \$INITIAL_PEERS $BLOCKS" C-m 55 | 56 | # 观察一下model是不是正确启动了对应的层 57 | tmux attach -t model_$i 58 | # 观察--initial_peers 59 | tmux attach -t session_1 -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/config_petals.sh: -------------------------------------------------------------------------------- 1 | # nvidia 12.1 install 2 | sudo su 3 | dpkg -l | grep -iE "Cuda|nvidia" | awk {'print $2'} | xargs apt-get -y remove 4 | dpkg -l | grep -iE "Cuda|nvidia" | awk {'print $2'} | xargs apt-get -y purge 5 | exit 6 | 7 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin 8 | sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 9 | wget https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda-repo-ubuntu2004-12-1-local_12.1.1-530.30.02-1_amd64.deb 10 | sudo dpkg -i cuda-repo-ubuntu2004-12-1-local_12.1.1-530.30.02-1_amd64.deb 11 | sudo cp /var/cuda-repo-ubuntu2004-12-1-local/cuda-*-keyring.gpg /usr/share/keyrings/ 12 | sudo apt-get update 13 | sudo apt-get -y install cuda 14 | 15 | sudo reboot 16 | 17 | # update python to 3.11 18 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 19 | sh Miniconda3-latest-Linux-x86_64.sh -b -p ${HOME}/software/miniconda3 20 | rm -f Miniconda3-latest-Linux-x86_64.sh 21 | echo "export PATH=${HOME}/software/miniconda3/bin:\$PATH" >> ~/.bashrc 22 | source ~/.bashrc 23 | conda --version 24 | 25 | 26 | # install petals 27 | pip install git+https://github.com/bigscience-workshop/petals 28 | 29 | # for more RAM 30 | sudo rm -rf ~/.cache/pip 31 | 32 | # install nvidia-docker 33 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ 34 | sudo apt-key add - 35 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) 36 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ 37 | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 38 | sudo apt-get update 39 | 40 | # for docker gpus not found 41 | sudo apt-get install -y nvidia-docker2 42 | sudo systemctl restart docker 43 | 44 | # go install 45 | wget -c https://dl.google.com/go/go1.20.linux-amd64.tar.gz -O - | sudo tar -xz -C /usr/local 46 | export PATH=$PATH:/usr/local/go/bin 47 | # source ~/.profile 48 | 49 | # ocf start 50 | cd ocf-enhance-bind_wallet/src/ocf-core 51 | git init 52 | make build 53 | build/core start --config config/cfg.yaml 54 | 55 | 56 | # scp ocf and FastGen 57 | 58 | # start coordinator first 59 | export ID=1 60 | export COOR_PORT=`expr $ID + 10005` 61 | cd FastGen-enhance-spec_decoding/ 62 | bash petals_coordinator.sh $ID $COOR_PORT 63 | 64 | # start coordinator after 65 | export ID=3 66 | tmux new -s pt_$ID 67 | export COOR_PORT=`expr $ID + 10005` 68 | cd FastGen-enhance-spec_decoding/ 69 | bash petals_coordinator.sh $ID $COOR_PORT /ip4/192.168.99.2/tcp/9991/p2p/QmY7XSYgdnRy4nJU27SJ7VJx7HAYWNGhSm9ZamepCucZMJ 70 | 71 | 72 | # start model 73 | cd FastGen-enhance-spec_decoding/ 74 | export CUDA_DEVICE=0 75 | export INITIAL_PEERS=/ip4/192.168.99.2/tcp/10006/p2p/QmWGUaPYJrGm44HGWHBfqRqgWnHKoS8mSytJhBbe3MbpCx 76 | export MODEL_PORT=`expr $CUDA_DEVICE + 17551` 77 | bash petals_model.sh $MODEL_PORT $CUDA_DEVICE $INITIAL_PEERS 10 78 | 79 | # in case fail and ports already allocated 80 | sudo docker stop $(sudo docker ps -a -q) 81 | -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/config_tgi.sh: -------------------------------------------------------------------------------- 1 | # nvidia 12.1 install 2 | sudo su 3 | dpkg -l | grep -iE "Cuda|nvidia" | awk {'print $2'} | xargs apt-get -y remove 4 | dpkg -l | grep -iE "Cuda|nvidia" | awk {'print $2'} | xargs apt-get -y purge 5 | exit 6 | 7 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin 8 | sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 9 | wget https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda-repo-ubuntu2004-12-1-local_12.1.1-530.30.02-1_amd64.deb 10 | sudo dpkg -i cuda-repo-ubuntu2004-12-1-local_12.1.1-530.30.02-1_amd64.deb 11 | sudo cp /var/cuda-repo-ubuntu2004-12-1-local/cuda-*-keyring.gpg /usr/share/keyrings/ 12 | sudo apt-get update 13 | sudo apt-get -y install cuda 14 | 15 | sudo reboot 16 | 17 | 18 | # install nvidia-docker 19 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ 20 | sudo apt-key add - 21 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) 22 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ 23 | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 24 | sudo apt-get update 25 | 26 | # for docker gpus not found 27 | sudo apt-get install -y nvidia-docker2 28 | sudo systemctl restart docker 29 | 30 | # go install 31 | wget -c https://dl.google.com/go/go1.20.linux-amd64.tar.gz -O - | sudo tar -xz -C /usr/local 32 | export PATH=$PATH:/usr/local/go/bin 33 | # source ~/.profile 34 | 35 | # ocf start 36 | cd ocf-enhance-bind_wallet/src/ocf-core 37 | git init 38 | make build 39 | build/core start --config config/cfg.yaml 40 | 41 | 42 | model=meta-llama/Llama-2-70b-chat-hf 43 | volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run 44 | token=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 45 | 46 | sudo docker run --gpus '"device=0,1,2,3"' --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 9090:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model --num-shard 4 --dtype float16 --cuda-memory-fraction 0.45 -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/petals_coordinator.sh: -------------------------------------------------------------------------------- 1 | VAR=$1 2 | PORT_COO=$2 3 | MASTER_PORT=9991 4 | 5 | if [[ $VAR -eq 1 ]] 6 | then 7 | # the first coordinator 8 | sudo docker run -p $PORT_COO:$PORT_COO --ipc host --gpus all --volume petals-cache:/cache \ 9 | --rm learningathome/petals:main python -m petals.cli.run_dht --host_maddrs /ip4/0.0.0.0/tcp/$PORT_COO --identity_path bootstrap1.id 10 | else 11 | # petals coordinator 12 | INITIAL_PEERS=$3 13 | sudo docker run -p $PORT_COO:$PORT_COO --ipc host --gpus all --volume petals-cache:/cache \ 14 | --rm learningathome/petals:main python -m petals.cli.run_dht --host_maddrs /ip4/0.0.0.0/tcp/$PORT_COO --identity_path bootstrap1.id --initial_peers $INITIAL_PEERS 15 | fi 16 | 17 | 18 | -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/petals_model.sh: -------------------------------------------------------------------------------- 1 | # petals model 2 | MODEL_PORT=$1 3 | DEVICE=$2 4 | INITIAL_PEERS=$3 5 | BLOCKS=$4 6 | VISIBLE_DEVICE='"device='$DEVICE',"' 7 | sudo docker run -p $MODEL_PORT:$MODEL_PORT --ipc host --gpus=''$VISIBLE_DEVICE'' --volume petals-cache:/cache \ 8 | --rm learningathome/petals:main python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf \ 9 | --initial_peers $INITIAL_PEERS --token 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 --num_blocks $BLOCKS --quant_type none 10 | -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/start_coo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | INITIAL_PEERS=$1 4 | 5 | 6 | for i in {2..8}; do 7 | SESSION_NAME="session_$i" 8 | tmux new-session -d -s $SESSION_NAME 9 | tmux send-keys -t $SESSION_NAME "export ID=$i" C-m 10 | tmux send-keys -t $SESSION_NAME "export COOR_PORT=\`expr \$ID + 8991\`" C-m 11 | tmux send-keys -t $SESSION_NAME "cd FastGen-enhance-spec_decoding/" C-m 12 | tmux send-keys -t $SESSION_NAME "bash petals_coordinator.sh \$ID \$COOR_PORT $INITIAL_PEERS" C-m 13 | done 14 | 15 | echo "7 tmux sessions have been initialized." 
16 | -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/start_coo_head.sh: -------------------------------------------------------------------------------- 1 | # 2 | i=1 3 | SESSION_NAME="session_$i" 4 | tmux new-session -d -s $SESSION_NAME 5 | tmux send-keys -t $SESSION_NAME "export ID=$i" C-m 6 | tmux send-keys -t $SESSION_NAME "export COOR_PORT=\`expr \$ID + 9991\`" C-m 7 | tmux send-keys -t $SESSION_NAME "cd FastGen-enhance-spec_decoding/" C-m 8 | tmux send-keys -t $SESSION_NAME "bash petals_coordinator.sh \$ID \$COOR_PORT" C-m 9 | tmux attach -t session_$i -------------------------------------------------------------------------------- /benchmark/petals_documents/config_scripts/start_model.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # tune i, peers, and blocks manually each time. 4 | 5 | # select which gpu, for example, from 0-7 if your machine has 8 GPUs 6 | i=3 7 | # corresponding coordinator's addr 8 | # change this in real settings 9 | INITIAL_PEERS=/ip4/xxx.xxx.xxx/tcp/8995/p2p/QmbWW8jCpqLsAFjYMVLyhYWoMLet3iV8JZkJUmzK1b1eF4 10 | # how many layers you want to serve on this gpu 11 | BLOCKS=10 12 | SESSION_NAME="model_$i" 13 | tmux new-session -d -s $SESSION_NAME 14 | tmux send-keys -t $SESSION_NAME "cd FastGen-enhance-spec_decoding/" C-m 15 | tmux send-keys -t $SESSION_NAME "export CUDA_DEVICE=$i" C-m 16 | tmux send-keys -t $SESSION_NAME "export INITIAL_PEERS=$INITIAL_PEERS" C-m 17 | tmux send-keys -t $SESSION_NAME "export MODEL_PORT=\`expr \$CUDA_DEVICE + 59551\`" C-m 18 | tmux send-keys -t $SESSION_NAME "bash petals_model.sh \$MODEL_PORT \$CUDA_DEVICE \$INITIAL_PEERS $BLOCKS" C-m 19 | 20 | tmux attach -t model_3 21 | tmux attach -t session_3 -------------------------------------------------------------------------------- /benchmark/petals_documents/scripts/run_petals.sh: -------------------------------------------------------------------------------- 1 | export INITIAL_PEERS=$1 2 | 3 | python3 _petals.py --model-name meta-llama/Llama-2-70b-chat-hf --token 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 --init-peers $INITIAL_PEERS -------------------------------------------------------------------------------- /benchmark/send_request/README.md: -------------------------------------------------------------------------------- 1 | ## Send Request 2 | 3 | Once you have started the service on the head coordinator and the worker coordinators, you can send requests to them. 4 | 5 | First, modify the input in `single_request.py`; make sure you add the suffix `_0` to the model name so that rank 0 is addressed correctly. An example is 6 | ```python 7 | data = { 8 | 'model_name': 'Llama-2-70b-chat-hf_0', 9 | 'params': { 10 | 'prompt': "Do you like your self? ", 11 | 'max_new_tokens': 128, 12 | 'temperature': 0.2, 13 | 'top_p': 0.9, 14 | 'top_k': 40, 15 | } 16 | } 17 | ``` 18 | 19 | 20 | By running the following command, you will see the answer to the prompt, the pure inference time, and the overall time, returned as a Python tuple. 21 | 22 | ```bash 23 | python3 single_request.py 24 | ``` 25 | 26 | The two functions in `request.py` are provided for retrieving answers and checking node status, respectively. 
27 | -------------------------------------------------------------------------------- /benchmark/send_request/request.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import time 3 | 4 | async def request_head_node(data, head_node, task_id=0): 5 | start = time.time() 6 | 7 | # this large timeout is useful when tasks are crowded on the coordinator 8 | timeout = aiohttp.ClientTimeout(total=60 * 60) 9 | async with aiohttp.ClientSession(timeout=timeout) as session: 10 | endpoint = f"{head_node}/api/v1/request/inference" 11 | resp = await session.post(endpoint, json=data) 12 | result = await resp.json() 13 | if 'error' in result: 14 | return None, None, None 15 | 16 | prompt_resp, infer_time = eval(result['data']) 17 | print(f"##### task {task_id} has finished inference #####") 18 | 19 | 20 | return prompt_resp, infer_time, time.time() - start 21 | 22 | async def check_status(head_node): 23 | async with aiohttp.ClientSession() as session: 24 | endpoint = f"{head_node}/api/v1/status/peers" 25 | resp = await session.get(endpoint) 26 | result = await resp.json() 27 | return result 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /benchmark/send_request/single_request.py: -------------------------------------------------------------------------------- 1 | from request import * 2 | from datetime import datetime 3 | import asyncio 4 | 5 | # Modify the IP below before execution 6 | head_node = "http://xxx.xxx.xx.xxx:xxxx" 7 | 8 | process_time = [] 9 | res_list = [] 10 | 11 | start = datetime.now() 12 | 13 | data = { 14 | # align with the name specified in worker 15 | 'model_name': 'Llama-2-70b-chat-hf_0', 16 | 'params': { 17 | 'prompt': "Do you like your self? ", 18 | 'max_new_tokens': 128, 19 | 'temperature': 0.2, 20 | 'top_p': 0.9, 21 | 'top_k': 40, 22 | } 23 | } 24 | 25 | res = asyncio.run(request_head_node(data, head_node=head_node)) 26 | 27 | end = datetime.now() 28 | 29 | res_list.append(res) 30 | process_time.append(end - start) # each element is a timedelta 31 | 32 | print(res) 33 | print(process_time) 34 | 35 | print("=" * 40) 36 | 37 | status = asyncio.run(check_status(head_node)) 38 | print(status) 39 | -------------------------------------------------------------------------------- /benchmark/tgi_documents/README.md: -------------------------------------------------------------------------------- 1 | ## Configure TGI step by step 2 | 3 | 1. Install nvidia-docker 4 | 5 | ```bash 6 | # install nvidia-docker 7 | curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ 8 | sudo apt-key add - 9 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 10 | curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ 11 | sudo tee /etc/apt/sources.list.d/nvidia-docker.list 12 | sudo apt-get update 13 | 14 | # If the above commands don't work, run the following to fix the "docker GPUs not found" problem 15 | sudo apt-get install -y nvidia-docker2 16 | sudo systemctl restart docker 17 | ``` 18 | 19 | 2. Run the Docker command; tune the quantization, data type, number of shards (how many GPUs to use), CUDA memory fraction, and port (if you plan to start multiple instances on one machine). Refer to `launcher.md` in TGI's repo for more details. The following command is an example. 
20 | 21 | ```bash 22 | model=meta-llama/Llama-2-70b-chat-hf 23 | volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run 24 | token= 25 | 26 | docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model --num-shard 4 --dtype float16 --cuda-memory-fraction 0.45 27 | ``` 28 | 29 | 3. Run the worker, this step assumes you have started work coordinator. 30 | 31 | ```bash 32 | bash scripts/run_tgi.sh 33 | ``` 34 | -------------------------------------------------------------------------------- /benchmark/tgi_documents/_tgi_worker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import argparse 4 | from loguru import logger 5 | import sys 6 | sys.path.insert(0, "..") 7 | sys.path.insert(0, "../utils") 8 | from utils._base import InferenceWorker 9 | from text_generation import Client 10 | 11 | 12 | class TGIWorker(InferenceWorker): 13 | def __init__(self, model_name, tgi_addr) -> None: 14 | super().__init__(f"{model_name}") 15 | 16 | self.tgi_addr = tgi_addr 17 | 18 | async def handle_requests(self, msg): 19 | prompts = msg.get('prompt', '') 20 | max_new_tokens = msg.get('max_new_tokens', 128) 21 | temperature = msg.get('temperature', 0.9) 22 | top_k = msg.get('top_k', 50) 23 | top_p = msg.get('top_p', 0.9) 24 | 25 | client = Client(self.tgi_addr) 26 | 27 | torch.cuda.synchronize() 28 | start = time.time() 29 | outputs = client.generate(prompts, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p).generated_text 30 | end = time.time() 31 | 32 | return outputs, end - start 33 | 34 | if __name__=="__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--model-name", type=str, default="tgi_0") 37 | parser.add_argument("--tgi-addr", type=str, default="http://127.0.0.1:8080") 38 | args = parser.parse_args() 39 | logger.info(f"args: {args}") 40 | worker = TGIWorker(args.model_name, args.tgi_addr) 41 | worker.start() -------------------------------------------------------------------------------- /benchmark/tgi_documents/scripts/run_tgi.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=$1 python3 _tgi_worker.py --tgi-addr "http://127.0.0.1:8080" -------------------------------------------------------------------------------- /benchmark/utils/_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import signal 3 | import asyncio 4 | from loguru import logger 5 | from nats.aio.client import Client as NATS 6 | 7 | from _utils import get_visible_gpus_specs 8 | 9 | async def shutdown(signal, loop, nc, model_name, connection_notice): 10 | """Cleanup tasks tied to the service's shutdown.""" 11 | logger.info(f"Gracefully shutting down {model_name} worker...") 12 | tasks = [t for t in asyncio.all_tasks() if t is not 13 | asyncio.current_task()] 14 | [task.cancel() for task in tasks] 15 | await asyncio.gather(*tasks) 16 | connection_notice['status'] = 'disconnected' 17 | await nc.publish("worker:status", bytes(f"{json.dumps(connection_notice)}", encoding='utf-8')) 18 | await nc.close() 19 | loop.stop() 20 | 21 | class InferenceWorker(): 22 | def __init__(self, model_name) -> None: 23 | self.model_name = model_name 24 | # todo: get gpu specs from nvml 25 | self.nc = NATS() 26 | self.connection_notice = {} 27 | 28 | async def 
run(self, loop): 29 | await self.nc.connect("nats://localhost:8094") 30 | await self.nc.subscribe(f"inference:{self.model_name}", "workers", self.process_request) 31 | self.connection_notice = { 32 | 'service': f'inference:{self.model_name}', 33 | 'gpus': get_visible_gpus_specs(), 34 | 'client_id': self.nc.client_id, 35 | 'status': 'connected' 36 | } 37 | await self.nc.publish("worker:status", bytes(f"{json.dumps(self.connection_notice)}", encoding='utf-8')) 38 | 39 | async def process_request(self, msg): 40 | processed_msg = json.loads(msg.data.decode()) 41 | result = await self.handle_requests(processed_msg['params']) 42 | await self.reply(msg, result) 43 | 44 | async def handle_requests(self, msg): 45 | raise NotImplementedError 46 | 47 | async def reply(self, msg, data): 48 | data = json.dumps(data) 49 | await self.nc.publish(msg.reply, bytes(data, encoding='utf-8')) 50 | 51 | def start(self): 52 | # atexit.register(exit_handler, self.nc, self.model_name, self.connection_notice) 53 | logger.info(f"Starting {self.model_name} worker...") 54 | loop = asyncio.get_event_loop() 55 | signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT, signal.SIGABRT, signal.SIGTSTP) 56 | for s in signals: 57 | loop.add_signal_handler( 58 | s, lambda s=s: asyncio.create_task(shutdown(s, loop, self.nc, self.model_name, self.connection_notice))) 59 | loop.run_until_complete(self.run(loop)) 60 | loop.run_forever() 61 | loop.close() -------------------------------------------------------------------------------- /benchmark/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from loguru import logger 3 | 4 | def get_visible_gpus_specs(): 5 | # https://github.com/gpuopenanalytics/pynvml/issues/28 6 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 7 | gpus = [] 8 | try: 9 | from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetName 10 | nvmlInit() 11 | if "CUDA_VISIBLE_DEVICES" in os.environ: 12 | ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","))) 13 | else: 14 | deviceCount = nvmlDeviceGetCount() 15 | ids = range(deviceCount) 16 | for i in ids: 17 | handle = nvmlDeviceGetHandleByIndex(i) 18 | meminfo = nvmlDeviceGetMemoryInfo(handle) 19 | gpus.append({ 20 | 'name': nvmlDeviceGetName(handle), 21 | 'memory': meminfo.total, 22 | 'memory_free': meminfo.free, 23 | 'memory_used': meminfo.used, 24 | }) 25 | except Exception as e: 26 | logger.info(f"No GPU found: {e}") 27 | return gpus -------------------------------------------------------------------------------- /hexgen/hexgen_core/README.md: -------------------------------------------------------------------------------- 1 | ## Features of HexGen 2 | HexGen stands out for its exceptional handling of large-scale transformer models, offering flexibility and efficiency through its innovative features. 3 | 4 | ### Tensor Model Parallelism with Integrated Heterogeneous Communication 5 | HexGen implements asymmetric tensor model parallelism in a heterogeneous environment, efficiently grouping GPUs for optimized computation. It incorporates the generation of heterogeneous communication groups, allowing for coordinated peer-to-peer communication and data management. A leader GPU node is selected within each tensor parallelism group to manage the broadcast operation of activations, ensuring efficient data distribution and reducing computational overhead. 
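As a rough sketch only (this is not the actual logic in `gen_hetero_groups.py`; the function and variable names below are illustrative assumptions), a heterogeneous layout such as `--hetero_config 1 2 1` can be viewed as partitioning consecutive world ranks into tensor model parallelism groups of those sizes, with the first rank of each group taken here as its leader:

```python
# Sketch: split consecutive world ranks into TP groups of the sizes given by
# hetero_config, and pick one leader per group (assumed here: the first rank).
def build_tp_groups(hetero_config):
    groups, next_rank = [], 0
    for tp_degree in hetero_config:
        groups.append(list(range(next_rank, next_rank + tp_degree)))
        next_rank += tp_degree
    leaders = [group[0] for group in groups]
    return groups, leaders

groups, leaders = build_tp_groups([1, 2, 1])  # e.g. --hetero_config 1 2 1
print(groups)   # [[0], [1, 2], [3]]
print(leaders)  # [0, 1, 3]
```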
6 | 7 | ### Pipeline Parallelism with Enhanced Communication Dynamics 8 | The system aligns pipeline stages with the corresponding tensor parallelism groups, facilitating concurrent processing and improved throughput. In each pipeline stage, a leader GPU node is chosen to handle peer-to-peer communication between stages, streamlining the data transfer process. This integration of pipeline parallelism with advanced communication dynamics ensures smooth and efficient processing across different stages of the model. 9 | 10 | ### Fast Decoding Using Flash Attention 11 | Incorporating Flash Attention, HexGen significantly enhances its decoding capabilities. This integration brings state-of-the-art efficiency to attention mechanism computations within transformer models, leading to faster and more effective processing. 12 | 13 | ### Core File Overview 14 | 15 | - HexGen integrates the `models`, `modules`, and `generation.py` scripts from [Flash Attention](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn) to utilize FlashAttention-2, enhancing the efficiency of attention mechanism computations. 16 | 17 | - `gen_hetero_groups.py`, `gen_comm_groups.py`, and `gen_parallel_groups.py` files are crucial for automatically configuring various tensor model parallel and pipeline parallel groups. This setup paves the way for advanced asymmetric hybrid parallelism strategies. 18 | 19 | - `gen_p2p_lists.py` and `heterogeneous_pipeline.py` scripts are designed to establish and manage the peer-to-peer communication (pipeline parallel communication) essential for asymmetric pipeline configurations. 20 | 21 | -------------------------------------------------------------------------------- /hexgen/hexgen_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_comm_groups import * 2 | from .gen_hetero_groups import * 3 | from .gen_p2p_lists import * 4 | from .gen_parallel_groups import * 5 | from .generation import * 6 | from .heterogeneous_pipeline import * 7 | from .init_utils import * 8 | from .utils import * 9 | -------------------------------------------------------------------------------- /hexgen/hexgen_core/gen_comm_groups.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CommGroup(object): 4 | def __init__(self, ranks): 5 | assert isinstance(ranks, list) or isinstance(ranks, range), 'Rank list or range should be provided to create a CommGroup!' 
6 | self.ranks = sorted(list(set(list(ranks)))) 7 | self.size = len(self.ranks) 8 | self.group = torch.distributed.new_group(self.ranks) 9 | def has_rank(self, rank): 10 | if rank in self.ranks: 11 | self.intra_group_id = self.ranks.index(rank) 12 | return True 13 | return False 14 | def allgather(self, input): 15 | return gather_from_tensor_model_parallel_region_group(input, self.group) 16 | def print(self): 17 | print(self.ranks, end = ' ') 18 | -------------------------------------------------------------------------------- /hexgen/hexgen_core/gen_p2p_lists.py: -------------------------------------------------------------------------------- 1 | def generate_send_recv_lists(pipeline_groups, mainline): 2 | # initialize empty send and receive lists for each rank 3 | ranks = set(rank for group in pipeline_groups for rank in group) 4 | SendList = {rank: [] for rank in ranks} 5 | RecvList = {rank: [] for rank in ranks} 6 | SendBoolean = {rank: [] for rank in ranks} 7 | RecvBoolean = {rank: [] for rank in ranks} 8 | 9 | # fill up send and receive lists based on pipeline groups 10 | for group in pipeline_groups: 11 | is_mainline = set(group) == set(mainline) 12 | for i in range(len(group) - 1): 13 | # Avoid appending duplicates 14 | if group[i+1] not in SendList[group[i]]: 15 | SendList[group[i]].append(group[i+1]) 16 | SendBoolean[group[i]].append(not is_mainline) 17 | if group[i] not in RecvList[group[i+1]]: 18 | RecvList[group[i+1]].append(group[i]) 19 | RecvBoolean[group[i+1]].append(not is_mainline) 20 | 21 | return SendList, RecvList, SendBoolean, RecvBoolean 22 | 23 | # pipeline_groups = [[0,2,4], [1,3,5], [1,3,6], [1,3,7]] 24 | # send, recv, send_bool, recv_bool = generate_send_recv_lists(pipeline_groups) 25 | # print("Send List:", send) 26 | # print("Recv List:", recv) 27 | # print("Send Boolean:", send_bool) 28 | # print("Recv Boolean:", recv_bool) 29 | 30 | -------------------------------------------------------------------------------- /hexgen/hexgen_core/gen_parallel_groups.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | def param_init_fn(module): 5 | module.to_empty(device=torch.device("cuda")) 6 | for m in module.modules(): 7 | if callable(getattr(m, 'reset_parameters', None)): 8 | m.reset_parameters() 9 | 10 | def param_init_fn_(module: nn.Module): 11 | for submodule in module.modules(): 12 | # Handle parameters 13 | for param_name, param in submodule.named_parameters(recurse=False): 14 | if param.is_meta: 15 | materialized_param = nn.Parameter( 16 | torch.empty_like(param, device=torch.device("cuda")) 17 | ) 18 | nn.init.uniform_(materialized_param) 19 | setattr(submodule, param_name, materialized_param) 20 | # Handle buffers 21 | for buffer_name, buffer in submodule.named_buffers(recurse=False): 22 | if buffer.is_meta: 23 | materialized_buffer = torch.empty_like(buffer, device=torch.device("cuda")) 24 | # No need to apply nn.init.uniform_ unless you specifically want to for buffers. 
25 | setattr(submodule, buffer_name, materialized_buffer) 26 | 27 | def wrap_modules_data_parallel(module_list, dp_types, dp_groups, module_types, pp_devices=None, mixed_precision=torch.bfloat16, default_process_group=None, wrap_block_name=None): 28 | assert len(module_list) == len(dp_types) 29 | assert len(module_list) == len(dp_groups) 30 | 31 | process_group = default_process_group if default_process_group is not None else dp_groups[0] 32 | pp_on = True if process_group.size < torch.distributed.get_world_size() else False 33 | 34 | if pp_devices is not None: 35 | assert len(pp_devices) == len(module_list) 36 | for i in range(len(module_list)): 37 | pp_device = None if pp_devices is None else pp_devices[i] 38 | param_init_fn_(module_list[i]) 39 | module_list[i].process_group = process_group.group 40 | module_list[i] = module_list[i].to(pp_device) 41 | return module_list 42 | -------------------------------------------------------------------------------- /hexgen/hexgen_core/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/hexgen/hexgen_core/models/__init__.py -------------------------------------------------------------------------------- /hexgen/hexgen_core/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/hexgen/hexgen_core/modules/__init__.py -------------------------------------------------------------------------------- /hexgen/llama/README.md: -------------------------------------------------------------------------------- 1 | ## Init Inference Tasks 2 | 3 | To initiate an independent inference process without involving the coordinator, execute the following command: 4 | 5 | ```bash 6 | bash scripts/run_llama_inference.sh 7 | ``` 8 | 9 | HexGen supports multi-process scenarios without relying on `torch.distributed.launch` for initialization. This is achieved by manually starting HexGen on each process across different machines. For instance, in a setup with 6 processes—4 on one machine and 2 on another—specific environment variables are exported for automatic detection by HexGen. The setup can be executed as follows: 10 | 11 | ```bash 12 | # on machine A 13 | bash scripts/run_llama_p0.sh 14 | bash scripts/run_llama_p1.sh 15 | bash scripts/run_llama_p2.sh 16 | bash scripts/run_llama_p3.sh 17 | # on machine B 18 | bash scripts/run_llama_p4.sh 19 | bash scripts/run_llama_p5.sh 20 | ``` 21 | 22 | Exercise caution with the `CUDA_VISIBLE_DEVICES` setting, as handling 6 processes on a single machine differs from managing them across multiple machines. 23 | 24 | You have the flexibility to customize various inputs to tailor your inference task according to your specific requirements. 
The `model_msg` object can be adjusted with different parameters, as shown in the example below: 25 | 26 | ```python 27 | model_msg = { 28 | 'prompt': "Do you like yourself ?", # Define your own prompt here 29 | 'max_new_tokens': 128, # Set the maximum number of new tokens 30 | 'temperature': 0.2, # Adjust the randomness in response generation 31 | 'top_k': 20, # Specify the number of highest probability vocabulary tokens to keep for top-k sampling 32 | 'top_p': 0.9, # Set the cumulative probability threshold for top-p (nucleus) sampling 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /hexgen/llama/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def add_arguments(parser): 4 | group = parser.add_argument_group(title='hexgen arguments') 5 | 6 | # hetro parallelism arguments 7 | group.add_argument( 8 | "--local-rank", type=int, default=-1, help="Local rank.", 9 | ) 10 | parser.add_argument( 11 | "--model_size", type=str, default='llama-7b', help="Model size.", choices=['llama-7b', 'llama-13b', 'llama-30b', 'llama-70b'] 12 | ) 13 | parser.add_argument( 14 | "--overwrite_config", type=int, default=0, help="Whether to overwrite model config" 15 | ) 16 | group.add_argument( 17 | "--initialize_on_meta", type=int, default=1, help="Whether to initialize parameters on meta device.", choices=[0, 1] 18 | ) 19 | group.add_argument( 20 | "--hidden_size", type=int, default=768, help="Hidden size of transformer model", 21 | ) 22 | group.add_argument( 23 | "--num_hidden_layers", type=int, default=12, help="Number of layers" 24 | ) 25 | group.add_argument( 26 | "-a", 27 | "--num_attention_heads", 28 | type=int, 29 | default=12, 30 | help="Number of attention heads", 31 | ) 32 | group.add_argument( 33 | "--vocab_size", type=int, default=30522, help="Total number of vocab" 34 | ) 35 | group.add_argument( 36 | "--dropout_prob", type=float, default=0.1, help="Dropout rate." 
37 | ) 38 | parser.add_argument( 39 | "--mixed_precision", type=str, default='fp16', help="Mixed precision option.", choices=['fp32', 'fp16', 'bf16'], 40 | ) 41 | parser.add_argument( 42 | "--hetero_config", type=int, nargs='+', default=0, help="Give and execute heterogeneous configuration", 43 | ) 44 | parser.add_argument( 45 | "--pp_partition", type=int, nargs='+', default=0, help="Give and execute pipeline configuration", 46 | ) 47 | 48 | # coordinator arguments 49 | parser.add_argument( 50 | "--model_name", type=str, default="Llama-2-7b-chat-hf", help="Assign the desired name for a worker" 51 | ) 52 | # Modify the IP below before execution 53 | parser.add_argument( 54 | "--head_node", type=str, default='http://xxx.xxx.xx.xxx:xxxx', help="Head node of coordinator" 55 | ) 56 | parser.add_argument( 57 | "--priority", type=int, default=0, help="To be implemented", 58 | ) 59 | parser.add_argument( 60 | "--group_id", type=int, default=0, help="To differentiate workers on a single node", 61 | ) 62 | return parser 63 | 64 | 65 | _KV_CACHE_DICT = None 66 | 67 | def get_kv_cache(): 68 | global _KV_CACHE_DICT 69 | return _KV_CACHE_DICT 70 | 71 | def set_kv_cache(kv_cache_dict): 72 | global _KV_CACHE_DICT 73 | _KV_CACHE_DICT = kv_cache_dict 74 | 75 | def clear_kv_cache(): 76 | global _KV_CACHE_DICT 77 | _KV_CACHE_DICT = None 78 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-13b/llama-13b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-13b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-30b/llama-30b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-30b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-70b/llama-70b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-70b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} 2 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-7b/llama-7b/config.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, 
"multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000} 2 | -------------------------------------------------------------------------------- /hexgen/llama/llama-config/llama-7b/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": -1} -------------------------------------------------------------------------------- /hexgen/llama/load_model_parameters_utils/README.md: -------------------------------------------------------------------------------- 1 | ## Load Model Parameters for LlaMA 7b 2 | 3 | ### Overview 4 | This guide provides instructions on how to load model parameters for Llama-7b. It works very similarly for other version of Llama models. Here, we will focus on creating separate state dictionaries for each component and layer of the model. 5 | 6 | ### Customize Parameters 7 | If you need to specify custom paths, you can manually edit the `create_separate_state_dicts_llama_7b.py` script. Locate the `save_model_components` function call and adjust the paths as needed. For example: 8 | 9 | ```python 10 | save_model_components( 11 | config_path='../llama-config/', 12 | checkpoint_name='llama-7b', 13 | checkpoint_path='/path/to/Llama-2-7b-chat-hf/', 14 | num_layers=32, 15 | save_dir='./separate_state_dicts/' 16 | ) 17 | ``` 18 | 19 | Here, your sole requirement is to specify the `checkpoint_path`, as the other parameters have been pre-defined and supplied for your convenience. You can download the model checkpoints from [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). 20 | 21 | ### Run the Script 22 | To create the separate state dictionaries for the Llama-7b model, run the following command in the terminal: 23 | 24 | ```bash 25 | python3 create_separate_state_dicts_llama_7b.py 26 | ``` 27 | 28 | This script will automatically generate and save the state dictionaries in the appropriate directory. 29 | 30 | ### Verify the Output 31 | After running the script, you should find the separate state dictionaries saved in the designated folder. Verify that all the expected files are present and correctly named. 32 | 33 | ### Modifying the Inference Script 34 | In the `llama_inference.py` file, add the following code snippet to load the parameters for Llama-7b. Adjust the paths as per your setup: 35 | 36 | ```python 37 | # Load model checkpoints with respect to hetero_config 38 | tp_ranks_whole_model = hetero_groups['tp_ranks_whole_model'] 39 | tp_group_list = hetero_groups['tp_rank_groups'] 40 | state_dicts_path = "./load_model_parameters_utils/" 41 | load_model_parameters(model, config, state_dicts_path, tp_ranks_whole_model, tp_group_list, rank) 42 | ``` 43 | -------------------------------------------------------------------------------- /hexgen/llama/load_model_parameters_utils/create_separate_state_dicts_llama_7b.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | sys.path.insert(0, '..') 5 | sys.path.insert(0, '../site-package') 6 | from llama_config_utils import llama_config_to_gpt2_config, config_from_checkpoint, overwrite_configs_and_args 7 | from transformers import LlamaForCausalLM, LlamaTokenizer 8 | from remap_state_dict import remap_state_dict_hf_llama 9 | 10 | def load_remapped_state_dict(config, checkpoint_path): 11 | 12 | """ 13 | Loads and remaps the state dictionary of a pretrained Llama model. 
14 | 15 | Parameters: 16 | - config (dict): Configuration dictionary for the Llama model. 17 | 18 | Returns: 19 | - dict: Remapped state dictionary suitable for the specific configuration. 20 | """ 21 | 22 | state_dict = remap_state_dict_hf_llama(LlamaForCausalLM.from_pretrained(f"{checkpoint_path}").state_dict(), config) 23 | return state_dict 24 | 25 | def save_model_components(config_path, checkpoint_name, checkpoint_path, num_layers, save_dir): 26 | 27 | """ 28 | Save specific components and each transformer layer of a model's state dictionary to separate files. 29 | 30 | Args: 31 | config_path (str): Path to the configuration directory. 32 | checkpoint_name (str): Name of the model checkpoint. 33 | num_layers (int): Number of transformer layers in the model. 34 | save_dir (str): Directory path where the state dictionaries will be saved. 35 | 36 | This function performs the following steps: 37 | 1. Load the configuration and state dictionary for the model. 38 | 2. Save specific components of the state dictionary (embeddings, layer normalization, and language model head). 39 | 3. Iterate over each transformer layer and save its state dictionary separately. 40 | """ 41 | 42 | # Configuration and state dictionary loading 43 | llama_config = config_from_checkpoint(config_path, checkpoint_name) 44 | config = llama_config_to_gpt2_config(llama_config) 45 | state_dict = load_remapped_state_dict(config, checkpoint_path) 46 | 47 | # Saving specific components of the state dictionary to separate files 48 | torch.save(state_dict['transformer.embeddings.word_embeddings.weight'], f'{save_dir}/embeddings.pt') 49 | torch.save(state_dict['transformer.ln_f.weight'], f'{save_dir}/ln_f.pt') 50 | torch.save(state_dict['lm_head.weight'], f'{save_dir}/lm_head.pt') 51 | 52 | # Save the state dictionary of each transformer layer separately 53 | for idx in range(num_layers): 54 | layer_key_prefix = f'transformer.layers.{idx}' 55 | layer_state_dict = {key: value for key, value in state_dict.items() if key.startswith(layer_key_prefix)} 56 | torch.save(layer_state_dict, f'{save_dir}/layer_{idx}.pt') 57 | 58 | def main(): 59 | # Generate model separate state_dicts 60 | if not os.path.exists("./separate_state_dicts"): 61 | os.mkdir("./separate_state_dicts") 62 | 63 | save_model_components( 64 | config_path='../llama-config/', 65 | checkpoint_name='llama-7b', 66 | checkpoint_path='../../../../Llama-2-7b-chat-hf/', 67 | num_layers=32, 68 | save_dir='./separate_state_dicts/' 69 | ) 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /hexgen/llama/load_model_parameters_utils/inv_freq.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/hexgen/llama/load_model_parameters_utils/inv_freq.pt -------------------------------------------------------------------------------- /hexgen/llama/modules/Llamamodel_tensor_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import Tensor, device 4 | from typing import Tuple 5 | import sys 6 | sys.path.insert(0, '..') 7 | sys.path.insert(0, '../site-package') 8 | from megatron.model.utils import init_method_normal, scaled_init_method_normal 9 | from megatron import get_args 10 | from megatron.model.enums import AttnMaskType, AttnType 11 | from megatron.model import 
MegatronModule 12 | from megatron.core import mpu, tensor_parallel 13 | import torch.nn.functional as F 14 | 15 | class LlamaParallelMLP(MegatronModule): 16 | def __init__(self, init_method, 17 | output_layer_init_method, 18 | act_func = 'silu', 19 | bias = False, 20 | dropout_prob = 0.0, 21 | hidden_size=None, 22 | intermediate_size=None, 23 | tp_group = None, 24 | ): 25 | super().__init__() 26 | args = get_args() 27 | self.bias = bias 28 | 29 | hidden_size = args.hidden_size 30 | intermediate_size = int(8 * hidden_size / 3) 31 | intermediate_size = args.multiple_of * ((intermediate_size + args.multiple_of - 1) // args.multiple_of) 32 | 33 | self.w1 = tensor_parallel.ColumnParallelLinear( 34 | hidden_size, 35 | intermediate_size, 36 | gather_output=False, 37 | init_method=init_method, 38 | bias=bias, 39 | tp_group=tp_group 40 | ) 41 | 42 | self.w2 = tensor_parallel.RowParallelLinear( 43 | intermediate_size, 44 | hidden_size, 45 | input_is_parallel=True, 46 | init_method=output_layer_init_method, 47 | bias=bias, 48 | tp_group=tp_group 49 | ) 50 | 51 | self.w3 = tensor_parallel.ColumnParallelLinear( 52 | hidden_size, 53 | intermediate_size, 54 | gather_output=False, 55 | init_method=init_method, 56 | bias=bias, 57 | tp_group=tp_group 58 | ) 59 | 60 | assert act_func == 'silu' 61 | self.activation_func = F.silu 62 | 63 | def forward(self, hidden_states): 64 | return self.w2(self.activation_func(self.w1(hidden_states)[0]) * self.w3(hidden_states)[0])[0] 65 | 66 | class LlamaMLP_tp(nn.Module): 67 | def __init__(self, config, tp_group = None): 68 | super().__init__() 69 | args=get_args() 70 | init_method = init_method_normal(args.init_method_std) 71 | scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) 72 | self.tp_group = tp_group.group if tp_group is not None else None 73 | self.mlp = LlamaParallelMLP(init_method, scaled_init_method, tp_group = self.tp_group) 74 | 75 | def forward(self, hidden_states): 76 | hidden_states = self.mlp(hidden_states) 77 | return hidden_states 78 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 9996 llama_inference.py \ 2 | --model_size llama-7b \ 3 | --use-flash-attn \ 4 | --hetero_config 1 2 1 \ 5 | --pp_partition 8 16 8 \ 6 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p0.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=0 8 | 9 | CUDA_VISIBLE_DEVICES=0 python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p1.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=1 8 | 9 | CUDA_VISIBLE_DEVICES=1 
python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p2.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=2 8 | 9 | CUDA_VISIBLE_DEVICES=2 python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p3.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=3 8 | 9 | CUDA_VISIBLE_DEVICES=3 python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p4.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=4 8 | 9 | CUDA_VISIBLE_DEVICES=0 python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /hexgen/llama/scripts/run_llama_p5.sh: -------------------------------------------------------------------------------- 1 | export NCCL_IB_DISABLE=0 2 | export NCCL_IB_HCA=mlx5_2,mlx5_5 3 | # Modify the master IP below before execution 4 | export MASTER_ADDR='xxx.xxx.xxx.xx' 5 | export MASTER_PORT=9991 6 | export WORLD_SIZE=6 7 | export RANK=5 8 | 9 | CUDA_VISIBLE_DEVICES=1 python3 llama_inference.py \ 10 | --model_size llama-7b \ 11 | --use-flash-attn \ 12 | --hetero_config 4 2 \ 13 | --pp_partition 20 12 \ 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ninja 2 | packaging 3 | h5py 4 | tqdm 5 | six 6 | regex 7 | transformers 8 | pynvml 9 | nats-py 10 | loguru 11 | aiohttp 12 | pandas 13 | sentencepiece 14 | protobuf 15 | -------------------------------------------------------------------------------- /scripts/run_head.sh: -------------------------------------------------------------------------------- 1 | # run coordinator 2 | SESSION_NAME="coordinator" 3 | tmux new-session -d -s $SESSION_NAME 4 | tmux send-keys -t $SESSION_NAME "cd ./third_party/ocf/src/ocf-core/" C-m 5 | tmux send-keys -t $SESSION_NAME "build/core start --config config/cfg_standalone.yaml" C-m -------------------------------------------------------------------------------- /scripts/run_worker.sh: -------------------------------------------------------------------------------- 1 | # run 
coordinator 2 | SESSION_NAME="coordinator" 3 | tmux new-session -d -s $SESSION_NAME 4 | tmux send-keys -t $SESSION_NAME "cd ./third_party/ocf/src/ocf-core/" C-m 5 | tmux send-keys -t $SESSION_NAME "build/core start --config config/cfg.yaml" C-m 6 | 7 | 8 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | from .global_vars import get_args, get_retro_args 6 | from .global_vars import get_current_global_batch_size 7 | from .global_vars import get_num_microbatches 8 | from .global_vars import get_signal_handler 9 | from .global_vars import update_num_microbatches 10 | from .global_vars import get_tokenizer 11 | from .global_vars import get_tensorboard_writer 12 | from .global_vars import get_adlr_autoresume 13 | from .global_vars import get_timers 14 | from .initialize import initialize_megatron 15 | 16 | from .utils import (print_rank_0, 17 | is_last_rank, 18 | print_rank_last) 19 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/README.md: -------------------------------------------------------------------------------- 1 | Megatron Core is a library for efficient and scalable training of transformer based models. 2 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/__init__.py: -------------------------------------------------------------------------------- 1 | import megatron.core.parallel_state 2 | import megatron.core.tensor_parallel 3 | import megatron.core.utils 4 | 5 | # Alias parallel_state as mpu, its legacy name 6 | mpu = parallel_state 7 | 8 | __all__ = [ 9 | "parallel_state", 10 | "tensor_parallel", 11 | "utils", 12 | ] 13 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import enum 4 | 5 | class ModelType(enum.Enum): 6 | encoder_or_decoder = 1 7 | encoder_and_decoder = 2 8 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/package_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | 4 | MAJOR = 0 5 | MINOR = 1 6 | PATCH = 0 7 | PRE_RELEASE = '' 8 | 9 | # Use the following formatting: (major, minor, patch, pre-release) 10 | VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE) 11 | 12 | __shortversion__ = '.'.join(map(str, VERSION[:3])) 13 | __version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:]) 14 | 15 | __package_name__ = 'megatron_core' 16 | __contact_names__ = 'NVIDIA' 17 | __contact_emails__ = 'nemo-toolkit@nvidia.com' # use NeMo Email 18 | __homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/' # use NeMo homepage 19 | __repository_url__ = 'https://github.com/NVIDIA/Megatron-LM/megatron/core' 20 | __download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases' 21 | __description__ = 'Megatron Core - a library for efficient and scalable training of transformer based models' 22 | __license__ = 'BSD-3' 23 | __keywords__ = 'deep learning, machine learning, gpu, NLP, NLU, language, transformer, nvidia, pytorch, torch' 24 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedules import get_forward_backward_func 2 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/requirements.txt: -------------------------------------------------------------------------------- 1 | torch -------------------------------------------------------------------------------- /third_party/megatron/megatron/core/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import vocab_parallel_cross_entropy 2 | from .data import broadcast_data 3 | 4 | from .layers import ( 5 | ColumnParallelLinear, 6 | RowParallelLinear, 7 | VocabParallelEmbedding, 8 | set_tensor_model_parallel_attributes, 9 | set_defaults_if_not_set_tensor_model_parallel_attributes, 10 | copy_tensor_model_parallel_attributes, 11 | param_is_not_tensor_parallel_duplicate, 12 | linear_with_grad_accumulation_and_async_allreduce 13 | 14 | ) 15 | 16 | from .mappings import ( 17 | copy_to_tensor_model_parallel_region, 18 | gather_from_tensor_model_parallel_region, 19 | gather_from_sequence_parallel_region, 20 | scatter_to_tensor_model_parallel_region, 21 | scatter_to_sequence_parallel_region, 22 | ) 23 | 24 | from .mappings_group import ( 25 | get_tensor_model_parallel_world_size_group, 26 | get_tensor_model_parallel_rank_group, 27 | copy_to_tensor_model_parallel_region_group, 28 | gather_from_tensor_model_parallel_region_group, 29 | gather_from_sequence_parallel_region_group, 30 | reduce_from_tensor_model_parallel_region_group, 31 | scatter_to_tensor_model_parallel_region_group, 32 | scatter_to_sequence_parallel_region_group, 33 | reduce_scatter_to_sequence_parallel_region_group, 34 | ) 35 | 36 | from .random import ( 37 | checkpoint, 38 | get_cuda_rng_tracker, 39 | model_parallel_cuda_manual_seed, 40 | ) 41 | 42 | from .utils import ( 43 | split_tensor_along_last_dim, 44 | split_tensor_into_1d_equal_chunks, 45 | gather_split_1d_tensor, 46 | ) 47 | 48 | __all__ = [ 49 | # cross_entropy.py 50 | "vocab_parallel_cross_entropy", 51 | # data.py 52 | "broadcast_data", 53 | #layers.py 54 | "ColumnParallelLinear", 55 | "RowParallelLinear", 56 | "VocabParallelEmbedding", 57 | "set_tensor_model_parallel_attributes", 58 | 
"set_defaults_if_not_set_tensor_model_parallel_attributes", 59 | "copy_tensor_model_parallel_attributes", 60 | "param_is_not_tensor_parallel_duplicate", 61 | "linear_with_grad_accumulation_and_async_allreduce", 62 | # mappings.py 63 | "copy_to_tensor_model_parallel_region", 64 | "gather_from_tensor_model_parallel_region", 65 | "gather_from_sequence_parallel_region", 66 | # "reduce_from_tensor_model_parallel_region", 67 | "scatter_to_tensor_model_parallel_region", 68 | "scatter_to_sequence_parallel_region", 69 | # random.py 70 | "checkpoint", 71 | "get_cuda_rng_tracker", 72 | "model_parallel_cuda_manual_seed", 73 | # utils.py 74 | "split_tensor_along_last_dim", 75 | "split_tensor_into_1d_equal_chunks", 76 | "gather_split_1d_tensor", 77 | ] 78 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/data/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 3 | LIBNAME = helpers 4 | LIBEXT = $(shell python3-config --extension-suffix) 5 | 6 | default: $(LIBNAME)$(LIBEXT) 7 | 8 | %$(LIBEXT): %.cpp 9 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ 10 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import indexed_dataset 2 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/data/blendable_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Blendable dataset.""" 4 | 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from megatron import print_rank_0 11 | 12 | class BlendableDataset(torch.utils.data.Dataset): 13 | 14 | 15 | def __init__(self, datasets, weights, size): 16 | 17 | self.datasets = datasets 18 | num_datasets = len(datasets) 19 | assert num_datasets == len(weights) 20 | 21 | self.size = size 22 | 23 | # Normalize weights. 24 | weights = np.array(weights, dtype=np.float64) 25 | sum_weights = np.sum(weights) 26 | assert sum_weights > 0.0 27 | weights /= sum_weights 28 | 29 | # Build indicies. 
30 | start_time = time.time() 31 | assert num_datasets < 255 32 | self.dataset_index = np.zeros(self.size, dtype=np.uint8) 33 | self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) 34 | 35 | from megatron.data import helpers 36 | helpers.build_blending_indices(self.dataset_index, 37 | self.dataset_sample_index, 38 | weights, num_datasets, self.size, 39 | torch.distributed.get_rank() == 0) 40 | print_rank_0('> elapsed time for building blendable dataset indices: ' 41 | '{:.2f} (sec)'.format(time.time() - start_time)) 42 | 43 | # Check size 44 | _ = self.__getitem__(self.size - 1) 45 | try: 46 | _ = self.__getitem__(self.size) 47 | raise RuntimeError('BlendedDataset size is improperly bounded') 48 | except IndexError: 49 | pass 50 | print_rank_0('> size of blendable dataset: ' 51 | '{} samples'.format(self.size)) 52 | 53 | 54 | def __len__(self): 55 | return self.size 56 | 57 | 58 | def __getitem__(self, idx): 59 | dataset_idx = self.dataset_index[idx] 60 | sample_idx = self.dataset_sample_index[idx] 61 | return { 62 | "dataset_idx" : dataset_idx, 63 | **self.datasets[dataset_idx][sample_idx], 64 | } 65 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/megatron/megatron/data/helpers.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /third_party/megatron/megatron/data/test/test_preprocess_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMPL=cached 4 | python ../preprocess_data.py \ 5 | --input test_samples.json \ 6 | --vocab vocab.txt \ 7 | --dataset-impl ${IMPL} \ 8 | --output-prefix test_samples_${IMPL} \ 9 | --workers 1 \ 10 | --log-interval 2 11 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/dist_signal_handler.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | import torch 4 | 5 | 6 | def get_world_size(): 7 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 8 | world_size = torch.distributed.get_world_size() 9 | else: 10 | world_size = 1 11 | return world_size 12 | 13 | 14 | def get_device(local_rank=None): 15 | backend = torch.distributed.get_backend() 16 | if backend == 'nccl': 17 | if local_rank is None: 18 | device = torch.device('cuda') 19 | else: 20 | device = torch.device(f'cuda:{local_rank}') 21 | elif backend == 'gloo': 22 | device = torch.device('cpu') 23 | else: 24 | raise RuntimeError 25 | return device 26 | 27 | 28 | def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None): 29 | if not torch.distributed.is_available() or \ 30 | not torch.distributed.is_initialized(): 31 | return [item] 32 | 33 | device = get_device(local_rank) 34 | 35 | if group is not None: 36 | group_size = group.size() 37 | else: 38 | group_size = get_world_size() 39 | 40 | tensor = torch.tensor([item], device=device, dtype=dtype) 41 | output_tensors = [ 42 | torch.zeros(1, dtype=tensor.dtype, device=tensor.device) 43 | for _ in range(group_size) 44 | ] 45 | torch.distributed.all_gather(output_tensors, tensor, group, async_op) 46 | output = [elem.item() for elem in output_tensors] 47 | return output 
48 | 49 | 50 | class DistributedSignalHandler: 51 | def __init__(self, sig=signal.SIGTERM): 52 | self.sig = sig 53 | 54 | def signals_received(self): 55 | all_received = all_gather_item( 56 | self._signal_received, dtype=torch.int32 57 | ) 58 | return all_received 59 | 60 | def __enter__(self): 61 | self._signal_received = False 62 | self.released = False 63 | self.original_handler = signal.getsignal(self.sig) 64 | 65 | def handler(signum, frame): 66 | self._signal_received = True 67 | 68 | signal.signal(self.sig, handler) 69 | 70 | return self 71 | 72 | def __exit__(self, type, value, tb): 73 | self.release() 74 | 75 | def release(self): 76 | if self.released: 77 | return False 78 | 79 | signal.signal(self.sig, self.original_handler) 80 | self.released = True 81 | return True 82 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fp16_deprecated/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """For backward compatibility, we need the class definitions to deserialize.""" 4 | 5 | class LossScaler: 6 | def __init__(self, scale=1): 7 | self.cur_scale = scale 8 | 9 | class DynamicLossScaler: 10 | def __init__(self, 11 | init_scale=2**32, 12 | scale_factor=2., 13 | scale_window=1000, 14 | min_scale=1, 15 | delayed_shift=1, 16 | consecutive_hysteresis=False): 17 | self.cur_scale = init_scale 18 | self.cur_iter = 0 19 | self.last_overflow_iter = -1 20 | self.scale_factor = scale_factor 21 | self.scale_window = scale_window 22 | self.min_scale = min_scale 23 | self.delayed_shift = delayed_shift 24 | self.cur_hysteresis = delayed_shift 25 | self.consecutive_hysteresis = consecutive_hysteresis 26 | 27 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/compat.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | /*This code is copied fron NVIDIA apex: 4 | * https://github.com/NVIDIA/apex 5 | * with minor changes. */ 6 | 7 | 8 | 9 | #ifndef TORCH_CHECK 10 | #define TORCH_CHECK AT_CHECK 11 | #endif 12 | 13 | #ifdef VERSION_GE_1_3 14 | #define DATA_PTR data_ptr 15 | #else 16 | #define DATA_PTR data 17 | #endif 18 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/scaled_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | torch::Tensor const& mask, 14 | float scale_factor); 15 | 16 | torch::Tensor bwd_cuda( 17 | torch::Tensor const& output_grads, 18 | torch::Tensor const& softmax_results, 19 | float scale_factor); 20 | 21 | int get_batch_per_block_cuda( 22 | int query_seq_len, 23 | int key_seq_len, 24 | int batches, 25 | int attn_heads); 26 | 27 | torch::Tensor fwd( 28 | torch::Tensor const& input, 29 | torch::Tensor const& mask, 30 | float scale_factor) { 31 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 32 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 33 | (input.scalar_type() == at::ScalarType::BFloat16), 34 | "Only fp16 and bf16 are supported"); 35 | AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); 36 | 37 | return fwd_cuda(input, mask, scale_factor); 38 | } 39 | 40 | torch::Tensor bwd( 41 | torch::Tensor const& output_grads, 42 | torch::Tensor const& softmax_results, 43 | float scale_factor) { 44 | 45 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 46 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 47 | 48 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 49 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 50 | "Only fp16 and bf16 are supported"); 51 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 52 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 53 | "Only fp16 and bf16 are supported"); 54 | 55 | return bwd_cuda(output_grads, softmax_results, scale_factor); 56 | } 57 | 58 | int get_batch_per_block( 59 | int query_seq_len, 60 | int key_seq_len, 61 | int batches, 62 | int attn_heads) { 63 | return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); 64 | } 65 | 66 | } // end namespace scaled_masked_softmax 67 | } // end namespace fused_softmax 68 | } // end namespace multihead_attn 69 | 70 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 71 | m.def("forward", 72 | &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 73 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 74 | 75 | m.def("backward", 76 | &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, 77 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 78 | 79 | m.def("get_batch_per_block", 80 | &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, 81 | "Return Batch per block size." 82 | ); 83 | } 84 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/scaled_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd( 21 | torch::Tensor const& input, 22 | float scale_factor) { 23 | AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); 24 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 25 | (input.scalar_type() == at::ScalarType::BFloat16), 26 | "Only fp16 and bf16 are supported"); 27 | 28 | return fwd_cuda(input, scale_factor); 29 | } 30 | 31 | torch::Tensor bwd( 32 | torch::Tensor const& output_grads, 33 | torch::Tensor const& softmax_results, 34 | float scale_factor) { 35 | 36 | AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); 37 | AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); 38 | 39 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 40 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 41 | "Only fp16 and bf16 are supported"); 42 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 43 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 44 | "Only fp16 and bf16 are supported"); 45 | 46 | return bwd_cuda(output_grads, softmax_results, scale_factor); 47 | } 48 | 49 | } // end namespace scaled_softmax 50 | } // end namespace fused_softmax 51 | } // end namespace multihead_attn 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", 55 | &multihead_attn::fused_softmax::scaled_softmax::fwd, 56 | "Self Multihead Attention scaled, softmax -- Forward."); 57 | m.def("backward", 58 | &multihead_attn::fused_softmax::scaled_softmax::bwd, 59 | "Self Multihead Attention scaled, softmax -- Backward."); 60 | } 61 | 62 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/scaled_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 22 | const int batches = input.size(0); 23 | const int attn_heads = input.size(1); 24 | const int query_seq_len = input.size(2); 25 | const int key_seq_len = input.size(3); 26 | TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); 27 | TORCH_INTERNAL_ASSERT(query_seq_len > 1); 28 | 29 | // Output 30 | auto act_options = input.options().requires_grad(false); 31 | torch::Tensor softmax_results = 32 | torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); 33 | 34 | // Softmax Intermediate Result Ptr 35 | void* input_ptr = static_cast(input.data_ptr()); 36 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 37 | 38 | DISPATCH_HALF_AND_BFLOAT( 39 | input.scalar_type(), 40 | "dispatch_scaled_softmax_forward", 41 | dispatch_scaled_softmax_forward( 42 | reinterpret_cast(softmax_results_ptr), 43 | reinterpret_cast(input_ptr), 44 | scale_factor, 45 | query_seq_len, 46 | key_seq_len, 47 | batches, 48 | attn_heads); 49 | ); 50 | return softmax_results; 51 | } 52 | 53 | torch::Tensor bwd_cuda( 54 | torch::Tensor const& output_grads_, 55 | torch::Tensor const& softmax_results_, 56 | float scale_factor) { 57 | 58 | auto output_grads = output_grads_.contiguous(); 59 | auto softmax_results = softmax_results_.contiguous(); 60 | 61 | //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] 62 | const int batches = output_grads.size(0); 63 | const int attn_heads = output_grads.size(1); 64 | const int query_seq_len = output_grads.size(2); 65 | const int key_seq_len = output_grads.size(3); 66 | 67 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 68 | 69 | //Softmax Grad 70 | DISPATCH_HALF_AND_BFLOAT( 71 | output_grads_.scalar_type(), 72 | "dispatch_scaled_masked_softmax_backward", 73 | dispatch_scaled_masked_softmax_backward( 74 | reinterpret_cast(output_grads_ptr), 75 | reinterpret_cast(output_grads_ptr), 76 | reinterpret_cast(softmax_results.data_ptr()), 77 | scale_factor, 78 | query_seq_len, 79 | key_seq_len, 80 | batches, 81 | attn_heads); 82 | ); 83 | 84 | //backward pass is completely in-place 85 | return output_grads; 86 | } 87 | } 88 | } 89 | } 90 | 91 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace multihead_attn { 8 | namespace fused_softmax { 9 | namespace scaled_upper_triang_masked_softmax { 10 | 11 | torch::Tensor fwd_cuda( 12 | torch::Tensor const& input, 13 | float scale_factor); 14 | 15 | torch::Tensor bwd_cuda( 16 | torch::Tensor const& output_grads, 17 | torch::Tensor const& softmax_results, 18 | float scale_factor); 19 | 20 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 21 | AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); 22 | AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || 23 | (input.scalar_type() == at::ScalarType::BFloat16), 24 | "Only fp16 and bf16 are supported"); 25 | 26 | return fwd_cuda(input, scale_factor); 27 | } 28 | 29 | torch::Tensor bwd( 30 | torch::Tensor const& output_grads, 31 | torch::Tensor const& softmax_results, 32 | float scale_factor) { 33 | 34 | AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); 35 | AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); 36 | 37 | AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || 38 | (output_grads.scalar_type() == at::ScalarType::BFloat16), 39 | "Only fp16 and bf16 are supported"); 40 | AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || 41 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 42 | "Only fp16 and bf16 are supported"); 43 | 44 | return bwd_cuda(output_grads, softmax_results, scale_factor); 45 | } 46 | 47 | } // end namespace scaled_upper_triang_masked_softmax 48 | } // end namespace fused_softmax 49 | } // end namespace multihead_attn 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", 53 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 54 | "Self Multihead Attention scaled, time masked softmax -- Forward."); 55 | m.def("backward", 56 | &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 57 | "Self Multihead Attention scaled, time masked softmax -- Backward."); 58 | } 59 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "scaled_upper_triang_masked_softmax.h" 11 | #include "type_shim.h" 12 | 13 | namespace multihead_attn { 14 | namespace fused_softmax { 15 | namespace scaled_upper_triang_masked_softmax { 16 | 17 | torch::Tensor fwd_cuda( 18 | torch::Tensor const& input, 19 | float scale_factor) 20 | { 21 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 22 | const int attn_batches = input.size(0); 23 | const int seq_len = input.size(1); 24 | TORCH_INTERNAL_ASSERT(seq_len <= 16384); 25 | 26 | // Output 27 | auto act_options = input.options().requires_grad(false); 28 | torch::Tensor softmax_results = 29 | torch::empty({attn_batches, seq_len, seq_len}, act_options); 30 | 31 | // Softmax Intermediate Result Ptr 32 | void* input_ptr = static_cast(input.data_ptr()); 33 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 34 | 35 | DISPATCH_HALF_AND_BFLOAT( 36 | input.scalar_type(), 37 | "dispatch_scaled_upper_triang_masked_softmax_forward", 38 | dispatch_scaled_upper_triang_masked_softmax_forward( 39 | reinterpret_cast(softmax_results_ptr), 40 | reinterpret_cast(input_ptr), 41 | scale_factor, 42 | seq_len, 43 | seq_len, 44 | attn_batches); 45 | ); 46 | return softmax_results; 47 | } 48 | 49 | 50 | torch::Tensor bwd_cuda( 51 | torch::Tensor const& output_grads_, 52 | torch::Tensor const& softmax_results_, 53 | float scale_factor) { 54 | 55 | auto output_grads = output_grads_.contiguous(); 56 | auto softmax_results = softmax_results_.contiguous(); 57 | 58 | //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 59 | const int attn_batches = output_grads.size(0); 60 | const int seq_len = output_grads.size(1); 61 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); 62 | 63 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 64 | 65 | //Softmax Grad 66 | DISPATCH_HALF_AND_BFLOAT( 67 | output_grads_.scalar_type(), 68 | "dispatch_scaled_upper_triang_masked_softmax_backward", 69 | dispatch_scaled_upper_triang_masked_softmax_backward( 70 | reinterpret_cast(output_grads_ptr), 71 | reinterpret_cast(output_grads_ptr), 72 | reinterpret_cast(softmax_results.data_ptr()), 73 | scale_factor, 74 | seq_len, 75 | seq_len, 76 | attn_batches); 77 | ); 78 | 79 | //backward pass is completely in-place 80 | return output_grads; 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/megatron/megatron/fused_kernels/tests/__init__.py -------------------------------------------------------------------------------- /third_party/megatron/megatron/fused_kernels/type_shim.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ 2 | 3 | 4 | #include 5 | #include "compat.h" 6 | 7 | 8 | #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ 9 | switch(TYPE) \ 10 | { \ 11 | case at::ScalarType::Half: \ 12 | { \ 13 | using scalar_t = at::Half; \ 14 | __VA_ARGS__; \ 15 | break; \ 16 | } \ 17 | case at::ScalarType::BFloat16: \ 18 | { \ 19 | using scalar_t = at::BFloat16; \ 20 | __VA_ARGS__; \ 21 | break; \ 22 | } \ 23 | default: \ 24 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 25 | } 26 | 27 | 28 | #define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ 29 | switch(TYPE) \ 30 | { \ 31 | case at::ScalarType::Half: \ 32 | { \ 33 | using scalar_t = at::Half; \ 34 | __VA_ARGS__; \ 35 | break; \ 36 | } \ 37 | case at::ScalarType::BFloat16: \ 38 | { \ 39 | using scalar_t = at::BFloat16; \ 40 | __VA_ARGS__; \ 41 | break; \ 42 | } \ 43 | case at::ScalarType::Float: \ 44 | { \ 45 | using scalar_t = float; \ 46 | __VA_ARGS__; \ 47 | break; \ 48 | } \ 49 | default: \ 50 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 51 | } 52 | 53 | 54 | 55 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ 56 | switch(TYPEIN) \ 57 | { \ 58 | case at::ScalarType::Float: \ 59 | { \ 60 | using scalar_t_in = float; \ 61 | switch(TYPEOUT) \ 62 | { \ 63 | case at::ScalarType::Float: \ 64 | { \ 65 | using scalar_t_out = float; \ 66 | __VA_ARGS__; \ 67 | break; \ 68 | } \ 69 | case at::ScalarType::Half: \ 70 | { \ 71 | using scalar_t_out = at::Half; \ 72 | __VA_ARGS__; \ 73 | break; \ 74 | } \ 75 | case at::ScalarType::BFloat16: \ 76 | { \ 77 | using scalar_t_out = at::BFloat16; \ 78 | __VA_ARGS__; \ 79 | break; \ 80 | } \ 81 | default: \ 82 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ 83 | } \ 84 | break; \ 85 | } \ 86 | case at::ScalarType::Half: \ 87 | { \ 88 | using scalar_t_in = at::Half; \ 89 | using scalar_t_out = at::Half; \ 90 | __VA_ARGS__; \ 91 | break; \ 92 | } \ 93 | case at::ScalarType::BFloat16: \ 94 | { \ 95 | using scalar_t_in = at::BFloat16; \ 96 | using scalar_t_out = at::BFloat16; \ 97 | __VA_ARGS__; \ 98 | break; \ 99 | } \ 100 | default: \ 101 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ 102 | } 103 | 104 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm 4 | 5 | from .distributed import DistributedDataParallel 6 | from .bert_model import BertModel 7 | from .gpt_model import GPTModel 8 | from .t5_model import T5Model 9 | from .language_model import get_language_model 10 | from .module import Float16Module, MegatronModule 11 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | import enum 4 | 5 | class LayerType(enum.Enum): 6 | encoder = 1 7 | decoder = 2 8 | 9 | class AttnType(enum.Enum): 10 | self_attn = 1 11 | cross_attn = 2 12 | 13 | class AttnMaskType(enum.Enum): 14 | padding = 1 15 | causal = 2 16 | 17 | # For backward compatibility with old model checkpoints 18 | from megatron.core.enums import ModelType 19 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/fused_bias_gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import torch 4 | 5 | 6 | ###### BIAS GELU FUSION/ NO AUTOGRAD ################ 7 | # 1/sqrt(2*pi)-> 0.3989423 8 | # 1/sqrt(2) -> 0.70710678 9 | # sqrt(2/pi) -> 0.79788456 10 | # this function is tanh approximation of gelu 11 | # actual gelu is: 12 | # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) 13 | 14 | @torch.jit.script 15 | def bias_gelu(bias, y): 16 | x = bias + y 17 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 18 | 19 | # gradient of tanh approximation of gelu 20 | # gradient of actual gelu is: 21 | # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 22 | @torch.jit.script 23 | def bias_gelu_back(g, bias, y): 24 | x = bias + y 25 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 26 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 27 | ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) 28 | return ff*g 29 | 30 | class GeLUFunction(torch.autograd.Function): 31 | @staticmethod 32 | # bias is an optional argument 33 | def forward(ctx, input, bias): 34 | ctx.save_for_backward(input, bias) 35 | return bias_gelu(bias, input) 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | input, bias = ctx.saved_tensors 40 | tmp = bias_gelu_back(grad_output, bias, input) 41 | return tmp, tmp 42 | 43 | bias_gelu_impl = GeLUFunction.apply 44 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/rotary_pos_embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \ 4 | # 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \ 5 | # common/megatron/rotary_pos_embedding.py 6 | 7 | import importlib.util 8 | import torch 9 | 10 | from torch import einsum, nn 11 | 12 | __all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] 13 | 14 | class RotaryEmbedding(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 18 | self.register_buffer('inv_freq', inv_freq) 19 | if importlib.util.find_spec('einops') is None: 20 | raise RuntimeError("einops is required for Rotary Embedding") 21 | 22 | def forward(self, max_seq_len, offset=0): 23 | seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset 24 | freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) 25 | # first part even vector components, second part odd vector components, 26 | # 2 * dim in dimension size 27 | emb = torch.cat((freqs, freqs), dim=-1) 28 | # emb [seq_length, .., dim] 29 | from einops import rearrange 30 | return rearrange(emb, 'n d -> n 1 1 d') 31 | 32 | 33 | def _rotate_half(x): 34 | """ 35 | change sign so the last dimension becomes [-odd, 
+even] 36 | """ 37 | from einops import rearrange 38 | x = rearrange(x, '... (j d) -> ... j d', j=2) 39 | x1, x2 = x.unbind(dim=-2) 40 | return torch.cat((-x2, x1), dim=-1) 41 | 42 | 43 | def apply_rotary_pos_emb(t, freqs): 44 | """ 45 | input tensor t is of shape [seq_length, ..., dim] 46 | rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] 47 | check https://kexue.fm/archives/8265 for detailed formulas 48 | """ 49 | rot_dim = freqs.shape[-1] 50 | # ideally t_pass is empty so rotary pos embedding is applied to all tensor t 51 | t, t_pass = t[..., :rot_dim], t[..., rot_dim:] 52 | 53 | # first part is cosine component 54 | # second part is sine component, need to change signs with _rotate_half method 55 | t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin()) 56 | return torch.cat((t, t_pass), dim=-1) 57 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | """Utilities for models.""" 4 | 5 | import math 6 | 7 | import torch 8 | 9 | from megatron import get_args 10 | 11 | def init_method_normal(sigma): 12 | """Init method based on N(0, sigma).""" 13 | def init_(tensor): 14 | return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) 15 | 16 | return init_ 17 | 18 | 19 | def scaled_init_method_normal(sigma, num_layers): 20 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 21 | std = sigma / math.sqrt(2.0 * num_layers) 22 | 23 | def init_(tensor): 24 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 25 | 26 | return init_ 27 | 28 | 29 | def attention_mask_func(attention_scores, attention_mask): 30 | attention_scores.masked_fill_(attention_mask.to(torch.bool), -10000.0) 31 | return attention_scores 32 | 33 | 34 | def get_linear_layer(rows, columns, init_method): 35 | """Simple linear layer with weight initialization.""" 36 | layer = torch.nn.Linear(rows, columns) 37 | if get_args().perform_initialization: 38 | init_method(layer.weight) 39 | with torch.no_grad(): 40 | layer.bias.zero_() 41 | return layer 42 | 43 | @torch.jit.script 44 | def gelu_impl(x): 45 | """OpenAI's gelu implementation.""" 46 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 47 | (1.0 + 0.044715 * x * x))) 48 | def openai_gelu(x): 49 | return gelu_impl(x) 50 | 51 | #This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter 52 | @torch.jit.script 53 | def erf_gelu(x): 54 | return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) 55 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/vision/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | """Vision Transformer(VIT) model.""" 4 | 5 | import torch 6 | from torch.nn.init import trunc_normal_ 7 | from megatron import get_args 8 | from megatron.model.utils import get_linear_layer 9 | from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead 10 | from megatron.model.vision.mit_backbone import mit_b3_avg 11 | from megatron.model.module import MegatronModule 12 | 13 | class VitClassificationModel(MegatronModule): 14 | """Vision Transformer Model.""" 15 | 16 | def __init__(self, num_classes, finetune=False, 17 | pre_process=True, post_process=True): 18 | super(VitClassificationModel, self).__init__() 19 | args = get_args() 20 | 21 | self.hidden_size = args.hidden_size 22 | self.num_classes = num_classes 23 | self.finetune = finetune 24 | self.pre_process = pre_process 25 | self.post_process = post_process 26 | self.backbone = VitBackbone( 27 | pre_process=self.pre_process, 28 | post_process=self.post_process, 29 | single_token_output=True 30 | ) 31 | 32 | if self.post_process: 33 | if not self.finetune: 34 | self.head = VitMlpHead(self.hidden_size, self.num_classes) 35 | else: 36 | self.head = get_linear_layer( 37 | self.hidden_size, 38 | self.num_classes, 39 | torch.nn.init.zeros_ 40 | ) 41 | 42 | def set_input_tensor(self, input_tensor): 43 | """See megatron.model.transformer.set_input_tensor()""" 44 | self.backbone.set_input_tensor(input_tensor) 45 | 46 | def forward(self, input): 47 | hidden_states = self.backbone(input) 48 | 49 | if self.post_process: 50 | hidden_states = self.head(hidden_states) 51 | 52 | return hidden_states 53 | 54 | 55 | class MitClassificationModel(MegatronModule): 56 | """Mix vision Transformer Model.""" 57 | 58 | def __init__(self, num_classes, 59 | pre_process=True, post_process=True): 60 | super(MitClassificationModel, self).__init__() 61 | args = get_args() 62 | 63 | self.hidden_size = args.hidden_size 64 | self.num_classes = num_classes 65 | 66 | self.backbone = mit_b3_avg() 67 | self.head = torch.nn.Linear(512, num_classes) 68 | self.apply(self._init_weights) 69 | 70 | def _init_weights(self, m): 71 | if isinstance(m, torch.nn.Linear): 72 | trunc_normal_(m.weight, std=.02) 73 | if isinstance(m, torch.nn.Linear) and m.bias is not None: 74 | torch.nn.init.constant_(m.bias, 0) 75 | 76 | def set_input_tensor(self, input_tensor): 77 | """See megatron.model.transformer.set_input_tensor()""" 78 | pass 79 | 80 | def forward(self, input): 81 | hidden_states = self.backbone(input) 82 | hidden_states = self.head(hidden_states) 83 | 84 | return hidden_states 85 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/model/vision/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def resize(input, 7 | size=None, 8 | scale_factor=None, 9 | mode='nearest', 10 | align_corners=None, 11 | warning=True): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ((output_h > 1 and output_w > 1 and input_h > 1 18 | and input_w > 1) and (output_h - 1) % (input_h - 1) 19 | and (output_w - 1) % (input_w - 1)): 20 | warnings.warn( 21 | f'When align_corners={align_corners}, ' 22 | 'the output would more aligned if ' 23 | f'input size {(input_h, input_w)} is `x+1` and ' 24 | f'out size {(output_h, 
output_w)} is `nx+1`') 25 | if isinstance(size, torch.Size): 26 | size = tuple(int(x) for x in size) 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/mpu/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/megatron/megatron/mpu/tests/__init__.py -------------------------------------------------------------------------------- /third_party/megatron/megatron/mpu/tests/commons.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | import argparse 4 | import os 5 | import random 6 | import numpy 7 | import torch 8 | 9 | import mpu 10 | 11 | 12 | class IdentityLayer(torch.nn.Module): 13 | def __init__(self, size, scale=1.0): 14 | super(IdentityLayer, self).__init__() 15 | self.weight = torch.nn.Parameter(scale * torch.randn(size)) 16 | 17 | def forward(self): 18 | return self.weight 19 | 20 | 21 | def set_random_seed(seed): 22 | """Set random seed for reproducability.""" 23 | random.seed(seed) 24 | numpy.random.seed(seed) 25 | torch.manual_seed(seed) 26 | mpu.model_parallel_cuda_manual_seed(seed) 27 | 28 | 29 | def initialize_distributed(backend='nccl'): 30 | """Initialize torch.distributed.""" 31 | # Get local rank in case it is provided. 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--local_rank', type=int, default=None, 34 | help='local rank passed from distributed launcher') 35 | args = parser.parse_args() 36 | local_rank = args.local_rank 37 | 38 | # Get rank and world size. 39 | rank = int(os.getenv('RANK', '0')) 40 | world_size = int(os.getenv("WORLD_SIZE", '1')) 41 | 42 | print('> initializing torch.distributed with local rank: {}, ' 43 | 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) 44 | 45 | # Set the device id. 46 | device = rank % torch.cuda.device_count() 47 | if local_rank is not None: 48 | device = local_rank 49 | torch.cuda.set_device(device) 50 | 51 | # Call the init process. 52 | init_method = 'tcp://' 53 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 54 | master_port = os.getenv('MASTER_PORT', '6000') 55 | init_method += master_ip + ':' + master_port 56 | torch.distributed.init_process_group( 57 | backend=backend, 58 | world_size=world_size, 59 | rank=rank, 60 | init_method=init_method) 61 | 62 | 63 | def print_separator(message): 64 | torch.distributed.barrier() 65 | filler_len = (78 - len(message)) // 2 66 | filler = '-' * filler_len 67 | string = '\n' + filler + ' {} '.format(message) + filler 68 | if torch.distributed.get_rank() == 0: 69 | print(string, flush=True) 70 | torch.distributed.barrier() 71 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/mpu/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | from mpu import data as data_utils 6 | import mpu 7 | import torch 8 | import functools 9 | import operator 10 | import sys 11 | sys.path.append("../..") 12 | 13 | 14 | def test_broadcast_data(tensor_model_parallel_size): 15 | 16 | if torch.distributed.get_rank() == 0: 17 | print('> testing broadcast_data with model parallel size {} ...'. 18 | format(tensor_model_parallel_size)) 19 | 20 | mpu.initialize_model_parallel(tensor_model_parallel_size) 21 | torch.manual_seed(1234 + mpu.get_data_parallel_rank()) 22 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 23 | 24 | key_size_t = {'key1': [7, 11], 25 | 'key2': [8, 2, 1], 26 | 'key3': [13], 27 | 'key4': [5, 1, 2], 28 | 'key5': [5, 12]} 29 | keys = list(key_size_t.keys()) 30 | 31 | data = {} 32 | data_t = {} 33 | for key in key_size_t: 34 | data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) 35 | data_t[key] = data[key].clone() 36 | data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) 37 | data_t['keyX'] = data['keyX'].clone() 38 | if mpu.get_tensor_model_parallel_rank() != 0: 39 | data = None 40 | 41 | data_utils._check_data_types(keys, data_t, torch.int64) 42 | key_size, key_numel, \ 43 | total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) 44 | for key in keys: 45 | assert key_size[key] == key_size_t[key] 46 | total_numel_t = 0 47 | for key in keys: 48 | target_size = functools.reduce(operator.mul, key_size_t[key], 1) 49 | assert key_numel[key] == target_size 50 | total_numel_t += target_size 51 | assert total_numel == total_numel_t 52 | 53 | data_b = data_utils.broadcast_data(keys, data, torch.int64) 54 | for key in keys: 55 | tensor = data_t[key].cuda() 56 | assert data_b[key].sub(tensor).abs().max() == 0 57 | 58 | # Reset groups 59 | mpu.destroy_tensor_model_parallel() 60 | 61 | torch.distributed.barrier() 62 | if torch.distributed.get_rank() == 0: 63 | print('>> passed the test :-)') 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | initialize_distributed() 69 | world_size = torch.distributed.get_world_size() 70 | 71 | tensor_model_parallel_size = 1 72 | while tensor_model_parallel_size <= world_size: 73 | print_separator('test test broadcast data') 74 | test_broadcast_data(tensor_model_parallel_size) 75 | tensor_model_parallel_size *= 2 76 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/mpu/tests/test_initialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | from commons import print_separator 4 | from commons import initialize_distributed 5 | import mpu 6 | import torch 7 | import sys 8 | sys.path.append("../..") 9 | 10 | 11 | def test_initialize_model_parallel(tensor_model_parallel_size): 12 | 13 | if torch.distributed.get_rank() == 0: 14 | print('> testing initialize_model_parallel with size {} ...'.format( 15 | tensor_model_parallel_size)) 16 | tensor_model_parallel_size_ = min(tensor_model_parallel_size, 17 | torch.distributed.get_world_size()) 18 | assert not mpu.model_parallel_is_initialized() 19 | mpu.initialize_model_parallel(tensor_model_parallel_size_) 20 | assert mpu.model_parallel_is_initialized() 21 | 22 | # Checks. 
23 | def check(group, world_size, rank): 24 | assert world_size == torch.distributed.get_world_size(group=group) 25 | assert rank == torch.distributed.get_rank(group=group) 26 | 27 | # Model parallel. 28 | world_size = tensor_model_parallel_size_ 29 | rank = torch.distributed.get_rank() % tensor_model_parallel_size_ 30 | assert world_size == mpu.get_tensor_model_parallel_world_size() 31 | assert rank == mpu.get_tensor_model_parallel_rank() 32 | check(mpu.get_tensor_model_parallel_group(), world_size, rank) 33 | 34 | # Data parallel. 35 | world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ 36 | rank = torch.distributed.get_rank() // tensor_model_parallel_size 37 | assert world_size == mpu.get_data_parallel_world_size() 38 | assert rank == mpu.get_data_parallel_rank() 39 | check(mpu.get_data_parallel_group(), world_size, rank) 40 | 41 | # Reset groups 42 | mpu.destroy_model_parallel() 43 | 44 | torch.distributed.barrier() 45 | if torch.distributed.get_rank() == 0: 46 | print('>> passed the test :-)') 47 | 48 | 49 | def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): 50 | 51 | if torch.distributed.get_rank() == 0: 52 | print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( 53 | tensor_model_parallel_size_)) 54 | tensor_model_parallel_size = min(tensor_model_parallel_size_, 55 | torch.distributed.get_world_size()) 56 | assert not mpu.model_parallel_is_initialized() 57 | mpu.initialize_model_parallel(tensor_model_parallel_size) 58 | assert mpu.model_parallel_is_initialized() 59 | 60 | # Checks 61 | src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() 62 | assert mpu.get_tensor_model_parallel_src_rank() == src_rank 63 | 64 | # Reset groups 65 | mpu.destroy_model_parallel() 66 | 67 | torch.distributed.barrier() 68 | if torch.distributed.get_rank() == 0: 69 | print('>> passed the test :-)') 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | initialize_distributed() 75 | world_size = torch.distributed.get_world_size() 76 | tensor_model_parallel_size = 1 77 | while tensor_model_parallel_size <= world_size: 78 | print_separator('test initialize model parallel') 79 | test_initialize_model_parallel(tensor_model_parallel_size) 80 | print_separator('test model parallel source rank') 81 | test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) 82 | tensor_model_parallel_size *= 2 83 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Megatron 9 | 71 | 72 | 73 |
[static/index.html continues here: the page's HTML/JavaScript markup was stripped during extraction. What survives indicates a small demo page titled "Megatron" with a "Prompt Megatron" text area, a 0 / 1000 character counter, and a submit button.]
86 | 87 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | from .api import ( 5 | generate, 6 | generate_and_post_process, 7 | beam_search_and_post_process) 8 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/text_generation/beam_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | ## from huggingface beam search 19 | class BeamHypotheses(object): 20 | def __init__(self, num_beams, length_penalty=1.0, early_stopping=False): 21 | """ 22 | Initialize n-best list of hypotheses. 23 | """ 24 | self.length_penalty = length_penalty 25 | self.early_stopping = early_stopping 26 | self.num_beams = num_beams 27 | self.beams = [] 28 | self.worst_score = 1e9 29 | 30 | def __len__(self): 31 | """ 32 | Number of hypotheses in the list. 33 | """ 34 | return len(self.beams) 35 | 36 | def add(self, hyp, sum_logprobs, length): 37 | """ 38 | Add a new hypothesis to the list. 39 | """ 40 | score = sum_logprobs / length ** self.length_penalty 41 | if len(self) < self.num_beams or score > self.worst_score: 42 | self.beams.append((score, hyp)) 43 | if len(self) > self.num_beams: 44 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) 45 | del self.beams[sorted_scores[0][1]] 46 | self.worst_score = sorted_scores[1][0] 47 | else: 48 | self.worst_score = min(score, self.worst_score) 49 | 50 | def is_done(self, best_sum_logprobs, cur_len): 51 | """ 52 | If there are enough hypotheses and that none of the hypotheses being generated 53 | can become better than the worst one in the heap, then we are done with this sentence. 54 | """ 55 | 56 | if len(self) < self.num_beams: 57 | return False 58 | elif self.early_stopping: 59 | return True 60 | else: 61 | cur_score = best_sum_logprobs / cur_len ** self.length_penalty 62 | ret = self.worst_score >= cur_score 63 | return ret 64 | 65 | -------------------------------------------------------------------------------- /third_party/megatron/megatron/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | 4 | from .tokenizer import build_tokenizer 5 | -------------------------------------------------------------------------------- /third_party/megatron/megatron_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import ParallelMLP, ParallelAttention #, ParallelTransformerLayer, ParallelTransformer 2 | -------------------------------------------------------------------------------- /third_party/ocf/README.md: -------------------------------------------------------------------------------- 1 | -- Open Compute Framework -- 2 | 3 | Documentation: https://ocf.Anonymous.org 4 | 5 | All the links in this repo that contain `Anonymous` should be replaced when running in a real deployment. Likewise, none of the social media links that contain `Anonymous` are real. -------------------------------------------------------------------------------- /third_party/ocf/docs/docs/advanced/internals.md: -------------------------------------------------------------------------------- 1 | # Internals 2 | 3 | ![The overview of OCF Architecture](../images/overview.png) 4 | 5 | Open Compute Framework is built on top of a concept named **Distributed Node Table**. 6 | 7 | ## Distributed Node Table 8 | 9 | In layman's terms, it is a table that stores the information of all the nodes in the network. The table is distributed across all the nodes in the network, and each node stores a full copy of the table. 10 | 11 | Each node in the network manages its workers, which are the computing resources that the node can leverage. Whenever a worker connects to or disconnects from the node, the node broadcasts that information, including the node's underlying hardware, the service it provides, and other metadata, to the network. All other nodes receive the broadcast and update their local node table. 12 | 13 | Once a node receives a user-issued query, it looks up the node table to find the "best" node to handle the query. Defining "best" precisely is an open problem that we are still working on. For now, we use a simple heuristic: **we randomly pick a node that claims to provide the service the user requested.** 14 | 15 | With the distributed node table, we can achieve the following goals: 16 | 17 | * **Decentralized**: There is no single point of failure. The node table is distributed across all the nodes in the network. If you like, you can host your own network and invite your friends to join. 18 | * **Dynamic**: The node table is dynamic. It is updated whenever a node joins or leaves the network. 19 | * **Efficient**: The node table is distributed across all the nodes in the network. Each node stores a full copy of the table, so that it can look up the table locally. This avoids the overhead of querying a centralized node table or a cascading node table. 20 | 21 | It also comes with some drawbacks. Since we only broadcast the join/leave of workers, a newly joined node may not have a full view of the network. Such a node therefore needs to synchronize itself with the network, which is something we are still working on. 22 | 23 | ## Message Bus between Node and Workers 24 | 25 | Each node manages a set of workers. The messaging between the node and the workers is done through a message bus. The message bus is implemented using [NATS](https://nats.io/), which is a lightweight and high-performance messaging system.
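A minimal sketch (not part of the OCF codebase) of how this request/reply pattern looks from the node's side, assuming the `nats-py` client, the node-local NATS port 8094, and the `inference:<model>` subject naming used by the example workers under `examples/worker/`; the production node side is implemented in Go inside `ocf-core`:

```python
import asyncio
import json

from nats.aio.client import Client as NATS


async def ask_worker(model_name: str, params: dict) -> dict:
    # Connect to the node-local NATS server (port borrowed from the example
    # workers in this repo; adjust for a real deployment).
    nc = NATS()
    await nc.connect("nats://localhost:8094")
    # Publish the query on the model's subject; one worker in the "workers"
    # queue group picks it up and replies on msg.reply.
    msg = await nc.request(
        f"inference:{model_name}",
        json.dumps({"params": params}).encode(),
        timeout=30,
    )
    await nc.close()
    return json.loads(msg.data.decode())


if __name__ == "__main__":
    # Assumes a worker such as examples/worker/inference.py is subscribed
    # to "inference:test".
    print(asyncio.run(ask_worker("test", {"prompt": "Hello!"})))
```

The queue group is what gives the "query queue" semantics listed below: many workers may subscribe to the same subject, but each query is delivered to exactly one of them.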
26 | 27 | The message bus is used for managing the following tasks: 28 | 29 | * **Node Status**: When a worker joins the network, it will send a message to the node to announce its presence. The node will then broadcast the message to the network. 30 | * **Query Queue**: When a node receives a query, it will put the query into the message bus. The worker will then pick up the query from the message bus and process it. 31 | 32 | ## CAP Principles -------------------------------------------------------------------------------- /third_party/ocf/docs/docs/guide/getting-started.md: -------------------------------------------------------------------------------- 1 | # Quick Start 2 | 3 | ## Inspect Network Status 4 | 5 | We provide a simple command-line tool for inspecting the network. You can use it to check the status of the whole network or of your own node. 6 | 7 | ```bash 8 | $ pip install -U ocf-cli 9 | ``` 10 | 11 | To check the services provided by the network, run: 12 | 13 | ```bash 14 | ocf list service 15 | ``` 16 | 17 | To check the nodes in the network, run: 18 | 19 | ```bash 20 | ocf list node 21 | ``` 22 | 23 | ## Join the Network 24 | 25 | ## Host Your Own Network -------------------------------------------------------------------------------- /third_party/ocf/docs/docs/guide/inference.md: -------------------------------------------------------------------------------- 1 | ## Inference Service 2 | 3 | To use the inference service, you can first check the status of the service by running: 4 | 5 | ```bash 6 | ocf list service 7 | ``` 8 | 9 | Then you can run the following script to test the inference service: 10 | 11 | ```python 12 | import json 13 | import requests 14 | 15 | URL = "https://inference.Anonymous.dev/api/v1/request/inference" 16 | 17 | def inference(): 18 | resp = requests.post(URL, json={ 19 | 'model_name': 'meta-llama/Llama-2-70b-chat-hf', 20 | 'params': { 21 | 'prompt': ": tell me about computer science?\n: ", 22 | 'max_tokens': 32, 23 | 'temperature': 0.7, 24 | 'top_p': 1.0, 25 | 'top_k': 40, 26 | } 27 | }) 28 | resp = json.loads(resp.json()['data']) 29 | print(resp) 30 | return resp 31 | 32 | if __name__ == "__main__": 33 | inference() 34 | ``` 35 | -------------------------------------------------------------------------------- /third_party/ocf/docs/docs/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/ocf/docs/docs/images/overview.png -------------------------------------------------------------------------------- /third_party/ocf/docs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome 2 | 3 | **Open Compute Framework** is a framework for decentralized computing. 4 | 5 | ## Why Decentralized Computing? 6 | 7 | In many cases, a single individual or organization won't have enough resources to run a large-scale computing task. We were facing two main challenges in the past: 8 | 9 | * Running LLM inference at a large scale is prohibitively expensive, especially when we need to run many different models on a large benchmark dataset. 10 | * We were hosting a generic benchmark and inviting participants, which exhibits a bursty workload, so we had to pay for the idle time when the benchmark was not running.
11 | 12 | We believe that decentralized computing can help us solve these problems, in the following ways: 13 | 14 | * We can leverage the computing resources from the community, and run the benchmark at a large scale, such that we avoid the cost of running the benchmark on our own. Think about the SETI@home project. 15 | * We avoid single point of failure, as the computing resources are distributed across the globe. 16 | * We avoid the cost of idle time, as we can bring up idle resources to run the benchmark when needed. 17 | 18 | ## How Does It Work? 19 | 20 | The framework is built on top of [LibP2P](https://libp2p.io/), which connects the computing resources in a peer-to-peer network. Each request will be routed to a peer that is able to handle the request. We aim to make the routing as efficient as possible. 21 | 22 | ## Demo 23 | 24 | We run a public, free instance of OCF as the inference API. [Status Page](https://ocfstatus.Anonymous.dev). More details coming soon! 25 | -------------------------------------------------------------------------------- /third_party/ocf/docs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Open Compute Framework 2 | theme: 3 | name: material 4 | features: 5 | - navigation.sections 6 | - navigation.tabs 7 | - navigation.tabs.sticky 8 | 9 | repo_url: https://github.com/ocf 10 | repo_name: ocf 11 | nav: 12 | - Home: index.md 13 | - User Guide: 14 | - Getting Started: guide/getting-started.md 15 | - Inference: guide/inference.md 16 | - Advanced: 17 | - Internals: advanced/internals.md 18 | - Status: https://ocfstatus.Anonymous.dev 19 | extra: 20 | social: 21 | - icon: fontawesome/brands/discord 22 | link: https://discord.gg/3BD3RzK2K2 23 | - icon: fontawesome/brands/twitter 24 | link: https://twitter.com/AnonymousHQ 25 | - icon: fontawesome/brands/github 26 | link: https://github.com/Anonymous-org 27 | - icon: fontawesome/brands/docker 28 | link: https://hub.docker.com/u/Anonymous 29 | -------------------------------------------------------------------------------- /third_party/ocf/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-minify-plugin 3 | mkdocs-material -------------------------------------------------------------------------------- /third_party/ocf/examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/.gitignore: -------------------------------------------------------------------------------- 1 | constant.py 2 | *.pyc -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/conn.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from constant import RELAY_URL, HOST_ID 3 | 4 | def get_conn(): 5 | url = f"{RELAY_URL}/api/v1/proxy/{HOST_ID}/api/v1/status/connections" 6 | resp = requests.get(url) 7 | return resp.text 8 | 9 | def get_global_view(): 10 | url = f"{RELAY_URL}/api/v1/status/table" 11 | resp = requests.get(url) 12 | return resp.text 13 | 14 | if __name__ == "__main__": 15 | # print(get_conn()) 16 | print(get_global_view()) -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/inference.py: -------------------------------------------------------------------------------- 1 
| import requests 2 | from constant import RELAY_URL, HOST_ID 3 | from multiprocessing import Pool 4 | 5 | def inference(i): 6 | url = f"{RELAY_URL}/api/v1/proxy/{HOST_ID}/api/v1/request/inference" 7 | resp = requests.post(url, json={ 8 | 'model_name': 'openlm-research/open_llama_7b', 9 | 'params': { 10 | 'prompt': "Hello!" 11 | } 12 | }) 13 | return resp.text 14 | 15 | def global_inference(i): 16 | url = f"{RELAY_URL}/api/v1/request/inference" 17 | resp = requests.post(url, json={ 18 | 'model_name': 'meta-llama/Llama-2-70b-chat-hf', 19 | 'params': { 20 | 'prompt': "Alan Turing was a " 21 | } 22 | }) 23 | print(resp.json()) 24 | return resp.text 25 | 26 | if __name__ == "__main__": 27 | with Pool(1) as p: 28 | p.map(global_inference, range(1)) 29 | -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/inference_devnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | 4 | URL = "http://140.238.214.47:8092/api/v1/request/inference" 5 | 6 | def inference(): 7 | resp = requests.post(URL, json={ 8 | 'model_name': 'microsoft/deberta-large-mnli', 9 | 'params': { 10 | 'prompt': "tell me about computer science?", 11 | 'max_tokens': 32, 12 | 'temperature': 0.7, 13 | 'top_p': 1.0, 14 | 'top_k': 40, 15 | } 16 | }) 17 | resp = resp.text 18 | print(resp) 19 | return resp 20 | 21 | if __name__ == "__main__": 22 | inference() -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/inference_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | 4 | URL = "https://api.Anonymous.dev/inference" 5 | 6 | def inference(): 7 | resp = requests.post(URL, json={ 8 | 'model': 'microsoft/deberta-large-mnli', 9 | 'params': { 10 | 'prompt': "tell me about computer science?", 11 | 'max_tokens': 32, 12 | 'temperature': 0.7, 13 | 'top_p': 1.0, 14 | 'top_k': 40, 15 | } 16 | }) 17 | resp = resp.json() 18 | print(resp) 19 | return resp 20 | 21 | if __name__ == "__main__": 22 | inference() -------------------------------------------------------------------------------- /third_party/ocf/examples/apis/sd_inference_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | 4 | URL = "https://inference.Anonymous.dev/api/v1/request/inference" 5 | 6 | def inference(): 7 | resp = requests.post(URL, json={ 8 | 'model_name': 'stabilityai/stable-diffusion-xl-base-0.9', 9 | 'params': { 10 | 'prompt': "An astronaut is running on mars", 11 | } 12 | }) 13 | resp = resp.json() 14 | return resp 15 | 16 | if __name__ == "__main__": 17 | inference() -------------------------------------------------------------------------------- /third_party/ocf/examples/worker/_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import asyncio 3 | from loguru import logger 4 | from nats.aio.client import Client as NATS 5 | from utils import get_visible_gpus_specs 6 | 7 | class InferenceWorker(): 8 | def __init__(self, model_name) -> None: 9 | self.model_name = model_name 10 | # todo: get gpu specs from nvml 11 | 12 | async def run(self, loop): 13 | self.nc = NATS() 14 | await self.nc.connect("nats://localhost:8094") 15 | await self.nc.subscribe(f"inference:{self.model_name}", "workers", self.process_request) 16 | connection_notice = { 17 | 'service': 
f'inference:{self.model_name}', 18 | 'gpus': get_visible_gpus_specs(), 19 | 'client_id': self.nc.client_id, 20 | 'status': 'connected' 21 | } 22 | await self.nc.publish("worker:status", bytes(f"{json.dumps(connection_notice)}", encoding='utf-8')) 23 | 24 | async def process_request(self, msg): 25 | processed_msg = json.loads(msg.data.decode()) 26 | result = await self.handle_requests(processed_msg['params']) 27 | await self.reply(msg, result) 28 | 29 | async def handle_requests(self, msg): 30 | raise NotImplementedError 31 | 32 | async def reply(self, msg, data): 33 | data = json.dumps(data) 34 | await self.nc.publish(msg.reply, bytes(data, encoding='utf-8')) 35 | 36 | def start(self): 37 | logger.info(f"Starting {self.model_name} worker...") 38 | 39 | loop = asyncio.get_event_loop() 40 | loop.run_until_complete(self.run(loop)) 41 | loop.run_forever() 42 | loop.close() -------------------------------------------------------------------------------- /third_party/ocf/examples/worker/inference.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from _base import InferenceWorker 3 | class HFWorker(InferenceWorker): 4 | def __init__(self, model_name) -> None: 5 | super().__init__(model_name) 6 | 7 | async def handle_requests(self, msg): 8 | logger.info(f"Processing request {msg}") 9 | return {"result": "hello world"} 10 | 11 | if __name__=="__main__": 12 | worker = HFWorker("test") 13 | worker.start() -------------------------------------------------------------------------------- /third_party/ocf/examples/worker/utils.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | def get_visible_gpus_specs(): 3 | gpus = [] 4 | try: 5 | from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetName 6 | nvmlInit() 7 | deviceCount = nvmlDeviceGetCount() 8 | for i in range(deviceCount): 9 | handle = nvmlDeviceGetHandleByIndex(i) 10 | meminfo = nvmlDeviceGetMemoryInfo(handle) 11 | gpus.append({ 12 | 'name': nvmlDeviceGetName(handle), 13 | 'memory': meminfo.total, 14 | 'memory_free': meminfo.free, 15 | 'memory_used': meminfo.used, 16 | }) 17 | except Exception as e: 18 | logger.info(f"No GPU found: {e}") 19 | return gpus -------------------------------------------------------------------------------- /third_party/ocf/src/benchmark/inference.js: -------------------------------------------------------------------------------- 1 | import http from 'k6/http'; 2 | 3 | export const options = { 4 | stages: [ 5 | { duration: '30s', target: 20 }, 6 | { duration: '1m30s', target: 10 }, 7 | { duration: '20s', target: 0 }, 8 | ], 9 | }; 10 | 11 | export default function () { 12 | const url = 'https://inference.Anonymous.dev/api/v1/request/inference'; 13 | const payload = JSON.stringify({ 14 | model_name: 'mosaicml/mpt-7b-chat', 15 | params: { 16 | 'prompt': "I'm feeling happy today", 17 | } 18 | }); 19 | 20 | const params = { 21 | headers: { 22 | 'Content-Type': 'application/json', 23 | }, 24 | }; 25 | http.post(url, payload, params); 26 | } -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/.gitignore: -------------------------------------------------------------------------------- 1 | ocf_cli.egg-info 2 | dist/ 3 | build/ 4 | *.pyc -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/Makefile: 
-------------------------------------------------------------------------------- 1 | install: 2 | python setup.py install 3 | 4 | format: 5 | autoflake -i **/*.py 6 | isort -i ocf_cli/**/*.py 7 | yapf -i **/*.py 8 | 9 | clean: 10 | rm -rf build 11 | rm -rf dist 12 | rm -rf ocf_cli.egg-info 13 | 14 | build: 15 | python -m build --wheel 16 | 17 | test: 18 | PYTHONPATH=./ python3 tests/server.py 19 | 20 | publish-test: 21 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 22 | 23 | publish: 24 | twine upload dist/* 25 | 26 | install-local: 27 | pip install -e . 28 | 29 | install-test: 30 | pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ ocf_cli -U 31 | 32 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/ocf/src/ocf-cli/ocf_cli/__init__.py -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/ocf/src/ocf-cli/ocf_cli/bin/__init__.py -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/bin/ocf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import time 4 | import sched 5 | import typer 6 | from loguru import logger 7 | import netifaces as ni 8 | from typing import Optional 9 | from netifaces import AF_INET 10 | from ocf_cli.lib.core.utils import gpu_measure 11 | from ocf_cli.lib.core.config import read_config 12 | from ocf_cli.lib.pprint.nodes import pprint_nodes 13 | from ocf_cli.lib.pprint.service import pprint_service 14 | 15 | app = typer.Typer() 16 | home_dir = os.path.expanduser("~") 17 | default_ocf_home = os.path.join(home_dir, ".config", "ocf") 18 | config_path = os.path.join(default_ocf_home, "cli.json") 19 | config = read_config(config_path) 20 | 21 | @app.command() 22 | def list(entity: str): 23 | if entity == "node": 24 | pprint_nodes() 25 | elif entity == "service": 26 | pprint_service() 27 | else: 28 | print("[ERROR] Unknown: ", entity) 29 | 30 | @app.command() 31 | def main(): 32 | print("OCF CLI") 33 | 34 | @app.command() 35 | def join(host: str='localhost', 36 | nic_name: str ='access', 37 | report_interval: int=10, 38 | working_dir: str = "." 
39 | ): 40 | print("> Joining OCF network") 41 | ip_addr = ni.ifaddresses(nic_name)[AF_INET][0]['addr'] 42 | gpu_stats = gpu_measure() 43 | if gpu_stats is not None and 'gpu' in gpu_stats: 44 | print(">> GPU found, joining OCF network") 45 | total_gpus = len(gpu_stats) 46 | # s = sched.scheduler(time.time, time.sleep) 47 | 48 | # s.enter(report_interval, 1, clock_watch, (s, tom_client, ip_addr, str(idx))) 49 | # s.run(blocking=True) 50 | else: 51 | logger.error("No GPU found, exiting...") 52 | 53 | if __name__ == "__main__": 54 | app() -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/core/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import signal 3 | import asyncio 4 | from loguru import logger 5 | from nats.aio.client import Client as NATS 6 | from utils import get_visible_gpus_specs 7 | import atexit 8 | 9 | async def shutdown(signal, loop, nc, model_name, connection_notice): 10 | """Cleanup tasks tied to the service's shutdown.""" 11 | logger.info(f"Gracefully shutting down {model_name} worker...") 12 | tasks = [t for t in asyncio.all_tasks() if t is not 13 | asyncio.current_task()] 14 | [task.cancel() for task in tasks] 15 | await asyncio.gather(*tasks) 16 | connection_notice['status'] = 'disconnected' 17 | await nc.publish("worker:status", bytes(f"{json.dumps(connection_notice)}", encoding='utf-8')) 18 | await nc.close() 19 | loop.stop() 20 | 21 | class BaseWorker(): 22 | def __init__(self, service_name) -> None: 23 | self.service_name = service_name 24 | self.nc = NATS() 25 | self.connection_notice = {} 26 | 27 | async def run(self, loop): 28 | await self.nc.connect("nats://localhost:8094") 29 | await self.nc.subscribe(self.service_name+f".{self.nc.client_id}", "workers", self.process_request) 30 | self.connection_notice = self.get_connection_notice() 31 | await self.nc.publish("worker:status", bytes(f"{json.dumps(self.connection_notice)}", encoding='utf-8')) 32 | 33 | async def process_request(self, msg): 34 | processed_msg = json.loads(msg.data.decode()) 35 | result = await self.handle_requests(processed_msg['params']) 36 | await self.reply(msg, result) 37 | 38 | def get_connection_notice(self): 39 | return { 40 | 'service': f'inference:{self.model_name}', 41 | 'gpus': get_visible_gpus_specs(), 42 | 'client_id': self.nc.client_id, 43 | 'status': 'connected' 44 | } 45 | 46 | async def handle_requests(self, msgs): 47 | raise NotImplementedError 48 | 49 | async def reply(self, msg, data): 50 | data = json.dumps(data) 51 | await self.nc.publish(msg.reply, bytes(data, encoding='utf-8')) 52 | 53 | def start(self): 54 | logger.info(f"Starting {self.model_name} worker...") 55 | loop = asyncio.get_event_loop() 56 | signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT, signal.SIGABRT, signal.SIGTSTP) 57 | for s in signals: 58 | loop.add_signal_handler( 59 | s, lambda s=s: asyncio.create_task(shutdown(s, loop, self.nc, self.model_name, self.connection_notice))) 60 | loop.run_until_complete(self.run(loop)) 61 | loop.run_forever() 62 | loop.close() -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/core/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Dict 4 | home_dir = os.path.expanduser("~") 5 | 6 | default_ocf_home = os.path.join(home_dir, ".config", "ocf") 7 | 8 | 
default_config = { 9 | "data_dir": os.path.join(default_ocf_home, "data"), 10 | "home_dir":os.path.join(default_ocf_home, "home"), 11 | "last_used_port": 8092 12 | } 13 | 14 | def write_config(config: Dict, path: str): 15 | with open(path, "w+") as f: 16 | json.dump(config, f) 17 | 18 | def read_config(path: str) -> Dict: 19 | if os.path.exists(path): 20 | with open(path, "r") as f: 21 | config = json.load(f) 22 | return config 23 | else: 24 | return default_config -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/core/host_worker.py: -------------------------------------------------------------------------------- 1 | from ocf_cli.lib.core.base import BaseWorker 2 | from ocf_cli.lib.core.utils import get_visible_gpus_specs 3 | 4 | """ 5 | Host worker is the meta-worker that will always connect to the ocf-node, and manages the start/stop of other workers. 6 | It also periodically report the status of the workers to the ocf-node. 7 | """ 8 | 9 | class HostWorker(BaseWorker): 10 | def __init__(self) -> None: 11 | self.service_name = "host" 12 | super().__init__(self.service_name) 13 | 14 | def get_connection_notice(self): 15 | notice = { 16 | 'service': f"{self.service_name}.{self.nc.client_id}", 17 | 'gpus': get_visible_gpus_specs(), 18 | 'client_id': self.nc.client_id, 19 | 'status': 'connected', 20 | 'offering': [] 21 | } 22 | return notice 23 | 24 | async def handle_requests(self, msgs): 25 | # i.e., only the last message is processed 26 | msgs = msgs[-1] 27 | print(msgs) -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/core/inference_worker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import signal 3 | import asyncio 4 | from loguru import logger 5 | from nats.aio.client import Client as NATS 6 | from utils import get_visible_gpus_specs 7 | import atexit 8 | 9 | async def shutdown(signal, loop, nc, model_name, connection_notice): 10 | """Cleanup tasks tied to the service's shutdown.""" 11 | logger.info(f"Gracefully shutting down {model_name} worker...") 12 | tasks = [t for t in asyncio.all_tasks() if t is not 13 | asyncio.current_task()] 14 | [task.cancel() for task in tasks] 15 | await asyncio.gather(*tasks) 16 | connection_notice['status'] = 'disconnected' 17 | await nc.publish("worker:status", bytes(f"{json.dumps(connection_notice)}", encoding='utf-8')) 18 | await nc.close() 19 | loop.stop() 20 | 21 | class InferenceWorker(): 22 | def __init__(self, model_name) -> None: 23 | self.model_name = model_name 24 | # todo: get gpu specs from nvml 25 | self.nc = NATS() 26 | self.connection_notice = {} 27 | 28 | async def run(self, loop): 29 | await self.nc.connect("nats://localhost:8094") 30 | await self.nc.subscribe(f"inference:{self.model_name}", "workers", self.process_request) 31 | self.connection_notice = { 32 | 'service': f'inference:{self.model_name}', 33 | 'gpus': get_visible_gpus_specs(), 34 | 'client_id': self.nc.client_id, 35 | 'status': 'connected' 36 | } 37 | await self.nc.publish("worker:status", bytes(f"{json.dumps(self.connection_notice)}", encoding='utf-8')) 38 | 39 | async def process_request(self, msg): 40 | processed_msg = json.loads(msg.data.decode()) 41 | result = await self.handle_requests(processed_msg['params']) 42 | await self.reply(msg, result) 43 | 44 | async def handle_requests(self, msg): 45 | raise NotImplementedError 46 | 47 | async def reply(self, msg, data): 
48 | data = json.dumps(data) 49 | await self.nc.publish(msg.reply, bytes(data, encoding='utf-8')) 50 | 51 | def start(self): 52 | # atexit.register(exit_handler, self.nc, self.model_name, self.connection_notice) 53 | logger.info(f"Starting {self.model_name} worker...") 54 | loop = asyncio.get_event_loop() 55 | signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT, signal.SIGQUIT, signal.SIGABRT, signal.SIGTSTP) 56 | for s in signals: 57 | loop.add_signal_handler( 58 | s, lambda s=s: asyncio.create_task(shutdown(s, loop, self.nc, self.model_name, self.connection_notice))) 59 | loop.run_until_complete(self.run(loop)) 60 | loop.run_forever() 61 | loop.close() -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/core/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from loguru import logger 3 | 4 | import traceback 5 | from typing import Union 6 | import pynvml 7 | 8 | 9 | def gpu_measure() -> Union[dict, None]: 10 | try: 11 | pynvml.nvmlInit() 12 | metrics = {"gpu": []} 13 | deviceCount = pynvml.nvmlDeviceGetCount() 14 | for i in range(deviceCount): 15 | handle = pynvml.nvmlDeviceGetHandleByIndex(i) 16 | name = pynvml.nvmlDeviceGetName(handle) 17 | mem = pynvml.nvmlDeviceGetMemoryInfo(handle) 18 | power = pynvml.nvmlDeviceGetPowerUsage(handle) 19 | utilitization = pynvml.nvmlDeviceGetUtilizationRates(handle) 20 | try: 21 | name = name.decode("utf-8") 22 | except Exception as e: 23 | pass 24 | metrics["gpu"].append( 25 | { 26 | "product_name": name, 27 | "fb_memory_usage": { 28 | "total": mem.total / 1024 / 1024, 29 | "used": mem.used / 1024 / 1024, 30 | "free": mem.free / 1024 / 1024, 31 | }, 32 | "utilization": utilitization.gpu, 33 | "power_readings": {"power_draw": power / 1000}, 34 | } 35 | ) 36 | except pynvml.NVMLError as error: 37 | traceback.print_exc() 38 | print(error) 39 | metrics = None 40 | finally: 41 | pynvml.nvmlShutdown() 42 | return metrics 43 | 44 | def get_visible_gpus_specs(): 45 | # https://github.com/gpuopenanalytics/pynvml/issues/28 46 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 47 | gpus = [] 48 | try: 49 | from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetName 50 | nvmlInit() 51 | if "CUDA_VISIBLE_DEVICES" in os.environ: 52 | ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","))) 53 | else: 54 | deviceCount = nvmlDeviceGetCount() 55 | ids = range(deviceCount) 56 | for i in ids: 57 | handle = nvmlDeviceGetHandleByIndex(i) 58 | meminfo = nvmlDeviceGetMemoryInfo(handle) 59 | gpus.append({ 60 | 'name': nvmlDeviceGetName(handle), 61 | 'memory': meminfo.total, 62 | 'memory_free': meminfo.free, 63 | 'memory_used': meminfo.used, 64 | }) 65 | except Exception as e: 66 | logger.info(f"No GPU found: {e}") 67 | return gpus -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/pod/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DEBUG = bool(os.getenv("TOMCLIENT_DEBUG", default="")) 4 | CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | COLOR_MAIN = "#fdca40" 7 | COLOR_ACCENT = "#3aaed8" 8 | COLOR_ERROR = "#f64740" 9 | STATUS_COLORS = { 10 | "running": "#f79824", 11 | "paused": "#f6efee", 12 | "success": "#4aad52", 13 | "failed": COLOR_ERROR, 14 | } 15 | 16 | ICON_POD = "⚡️" 17 | ICON_INFO = "🧲" 18 | ICON_STATUS = 
"•" 19 | ICON_KILLED = "💀" 20 | 21 | FAILFAST_DELAY = 2 22 | DATETIME_FORMAT = "%H:%M:%S %Y/%m/%d" 23 | TRUNCATE_LENGTH = 36 24 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/pod/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shlex 4 | import signal 5 | import time 6 | from functools import wraps 7 | from pathlib import Path 8 | 9 | import click 10 | import psutil 11 | from ocf_cli.lib.pod import config 12 | 13 | logging.disable(level=logging.CRITICAL) 14 | 15 | if config.DEBUG: 16 | logging.disable(logging.NOTSET) 17 | logging.basicConfig(level=logging.DEBUG) 18 | 19 | 20 | logger = logging.getLogger(__package__) 21 | 22 | 23 | def allow_missing(func): 24 | @wraps(func) 25 | def wrapper(*args, **kwargs): 26 | try: 27 | return func(*args, **kwargs) 28 | except FileNotFoundError: 29 | pass 30 | 31 | return wrapper 32 | 33 | 34 | def shlex_join_backport(split_command): 35 | """Return a shell-escaped string""" 36 | return " ".join(shlex.quote(arg) for arg in split_command) 37 | 38 | 39 | class timed(object): 40 | def __init__(self): 41 | self._start = time.time() 42 | self._end = None 43 | 44 | @property 45 | def elapsed(self): 46 | if self._end is not None: 47 | return self._end - self._start 48 | 49 | def __enter__(self): 50 | return self 51 | 52 | def __exit__(self, *args): 53 | self._end = time.time() 54 | return self 55 | 56 | 57 | def wait_created( 58 | path: Path, interval: float = 0.1, timeout: float = config.FAILFAST_DELAY 59 | ): 60 | start = time.time() 61 | while not path.exists() and time.time() - start < timeout: 62 | time.sleep(interval) 63 | return path.exists() 64 | 65 | 66 | def validate_signal(ctx, param, value): 67 | try: 68 | signal_code = int(value) 69 | except ValueError: 70 | raise click.BadParameter("Signal should be a valid integer value") 71 | 72 | try: 73 | return signal.Signals(signal_code) 74 | except ValueError: 75 | raise click.BadParameter(f"{signal_code} is not a valid signal code") 76 | 77 | 78 | def kill_proc_tree(pid, sig=signal.SIGKILL, include_parent=True): 79 | if pid == os.getpid(): 80 | raise ValueError("Would not kill myself") 81 | 82 | parent = psutil.Process(pid) 83 | children = parent.children(recursive=True) 84 | if include_parent: 85 | children.append(parent) 86 | for p in children: 87 | try: 88 | p.send_signal(sig) 89 | logger.debug(f"Sent {sig} to {p.pid} process") 90 | except psutil.NoSuchProcess: 91 | pass 92 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/pprint/_base.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | 3 | console = Console() 4 | 5 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/pprint/nodes.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from rich.pretty import pprint 3 | from rich.table import Table 4 | from ocf_cli.lib.pprint._base import console 5 | 6 | RELAY_URL = "https://inference.Anonymous.dev" 7 | 8 | def pprint_nodes(): 9 | url = f"{RELAY_URL}/api/v1/status/table" 10 | resp = requests.get(url) 11 | resp = resp.json() 12 | table = Table(title="Connected Nodes") 13 | table.add_column("Node") 14 | table.add_column("Worker") 15 | table.add_column("Service") 16 | 
table.add_column("Status") 17 | table.add_column("GPU Device") 18 | table.add_column("GPU Memory (used / total)") 19 | for node in resp['nodes']: 20 | # make it 1x NVIDIA GeForce RTX 3090... etc. 21 | gpus_specs = {} 22 | for gpu in node["gpus"]: 23 | if gpu["name"] not in gpus_specs: 24 | gpus_specs[gpu["name"]] = 0 25 | gpus_specs[gpu["name"]] += 1 26 | gpu_specs_str = "" 27 | for gpu_name, gpu_count in gpus_specs.items(): 28 | gpu_specs_str += f"{gpu_count}x {gpu_name}, " 29 | gpu_specs_str = gpu_specs_str[:-2] 30 | 31 | used_memory = 0 32 | total_memory = 0 33 | for gpu in node["gpus"]: 34 | used_memory += gpu["memory_used"] 35 | total_memory += gpu["memory"] 36 | memory_str = f"{used_memory/1024/1024:.2f} / {total_memory/1024/1024:.2f} MB" 37 | table.add_row( 38 | node["peer_id"], 39 | str(node["client_id"]), 40 | node["service"], 41 | node["status"], 42 | gpu_specs_str, 43 | memory_str, 44 | ) 45 | console.print(table) -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/ocf_cli/lib/pprint/service.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from rich.pretty import pprint 3 | from rich.table import Table 4 | from ocf_cli.lib.pprint._base import console 5 | 6 | RELAY_URL = "https://inference.Anonymous.dev" 7 | 8 | def pprint_service(): 9 | url = f"{RELAY_URL}/api/v1/status/table" 10 | resp = requests.get(url) 11 | resp = resp.json() 12 | table = Table(title="Service") 13 | table.add_column("Service") 14 | table.add_column("Providers") 15 | services = {} 16 | for node in resp['nodes']: 17 | if node['service'] not in services: 18 | services[node['service']] = {'providers': []} 19 | services[node['service']]['providers'].append(node) 20 | for service in services: 21 | gpu_specs = {} 22 | for node in services[service]['providers']: 23 | for gpu in node["gpus"]: 24 | if gpu["name"] not in gpu_specs: 25 | gpu_specs[gpu["name"]] = 0 26 | gpu_specs[gpu["name"]] += 1 27 | gpu_specs_str = "" 28 | for gpu_name, gpu_count in gpu_specs.items(): 29 | gpu_specs_str += f"{gpu_count}x {gpu_name}, " 30 | gpu_specs_str = gpu_specs_str[:-2] 31 | table.add_row( 32 | service, 33 | gpu_specs_str, 34 | ) 35 | console.print(table) -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-cli/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ocf_cli" 7 | version = "0.0.2" 8 | authors = [ 9 | { name="Anonymous author", email="Anonymous email" }, 10 | ] 11 | description = "OCF CLI" 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | dependencies = [ 20 | "typer", 21 | "requests", 22 | "rich", 23 | "loguru", 24 | "huggingface-hub", 25 | "pynvml", 26 | "loguru", 27 | "nats-py", 28 | "click", 29 | "psutil", 30 | "humanize", 31 | "netifaces", 32 | ] 33 | 34 | [project.scripts] 35 | ocf = 'ocf_cli.bin.ocf:app' 36 | 37 | [tool.setuptools] 38 | packages = ["ocf_cli"] 39 | 40 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | build/ 
-------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:alpine 2 | 3 | ENV GIN_MODE=release 4 | ENV PORT=8092 5 | 6 | COPY . /app 7 | 8 | WORKDIR /app 9 | 10 | RUN apk add --no-cache git make bash && make build 11 | 12 | EXPOSE $PORT 13 | 14 | ENTRYPOINT ["/app/build/core", "start", "--config", "/app/config/cfg.yaml"] -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/cluster.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "ocfcore/internal/cluster" 5 | "ocfcore/internal/common" 6 | "ocfcore/internal/common/structs" 7 | 8 | "github.com/spf13/cobra" 9 | "github.com/spf13/viper" 10 | ) 11 | 12 | var acquireCmd = &cobra.Command{ 13 | Use: "add", 14 | Short: "Add a machine to the cluster effectively", 15 | Run: func(cmd *cobra.Command, args []string) { 16 | slurmClient := cluster.NewSlurmClusterClient() 17 | slurmClient.AcquireMachine(structs.AcquireMachinePayload{ 18 | Script: viper.GetString("acquire_machine.script"), 19 | Params: make(map[string]string, 0), 20 | }) 21 | }, 22 | } 23 | 24 | var clusterCmd = &cobra.Command{ 25 | Use: "cluster", 26 | Short: "Manage the cluster status", 27 | Long: `Manage the cluster status, currently supporting slurm, kubernetes and baremetal`, 28 | Run: func(cmd *cobra.Command, args []string) { 29 | err := cmd.Help() 30 | if err != nil { 31 | common.Logger.Error("Could not print help", "error", err) 32 | } 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/config.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | type P2PConfig struct { 4 | Port string `json:"port", yaml:"port"` 5 | } 6 | 7 | type VaccumConfig struct { 8 | Interval int `json:"interval", yaml:"interval"` 9 | } 10 | 11 | type QueueConfig struct { 12 | Port string `json:"port", yaml:"port"` 13 | } 14 | 15 | type Config struct { 16 | Path string `json:"path", yaml:"path"` 17 | Port string `json:"port", yaml:"port"` 18 | Name string `json:"name", yaml:"name"` 19 | P2p P2PConfig `json:"p2p", yaml:"p2p"` 20 | Vacuum VaccumConfig `json:"vacuum", yaml:"vacuum"` 21 | Queue QueueConfig `json:"queue", yaml:"queue"` 22 | Account AccountConfig `json:"account", yaml:"account"` 23 | } 24 | 25 | type AccountConfig struct { 26 | Wallet string `json:"wallet", yaml:"wallet"` 27 | } 28 | 29 | var defaultConfig = Config{ 30 | Path: "", 31 | Port: "8092", 32 | Name: "relay", 33 | P2p: P2PConfig{Port: "8093"}, 34 | Vacuum: VaccumConfig{Interval: 10}, 35 | Queue: QueueConfig{Port: "8094"}, 36 | Account: AccountConfig{Wallet: ""}, 37 | } 38 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/init.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var iniocfcored = &cobra.Command{ 8 | Use: "init", 9 | Short: "Initialize the system, create the database and the config file", 10 | Run: func(cmd *cobra.Command, args []string) { 11 | 12 | }} 13 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/start.go: 
-------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "ocfcore/internal/daemon" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | var startocfcore = &cobra.Command{ 10 | Use: "start", 11 | Short: "Start listening for incoming connections", 12 | Run: func(cmd *cobra.Command, args []string) { 13 | daemon.Start() 14 | }} 15 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/update.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "ocfcore/internal/common" 8 | 9 | "github.com/minio/selfupdate" 10 | "github.com/spf13/cobra" 11 | ) 12 | 13 | var updateCmd = &cobra.Command{ 14 | Use: "update", 15 | Short: "Update the ocf binary to the latest version", 16 | Run: func(cmd *cobra.Command, args []string) { 17 | updateURL := "https://cdn.xzyao.dev/ocfcore" 18 | resp, err := http.Get(updateURL) 19 | if err != nil { 20 | common.Logger.Error("Error while checking for updates: ", err) 21 | } 22 | defer resp.Body.Close() 23 | err = selfupdate.Apply(resp.Body, selfupdate.Options{}) 24 | if err != nil { 25 | if rerr := selfupdate.RollbackError(err); rerr != nil { 26 | common.Logger.Info("Failed to rollback from bad update: ", rerr) 27 | } 28 | } 29 | common.Logger.Info("Successfully updated") 30 | fmt.Printf("ocfcore version %s", common.JSONVersion.Version) 31 | fmt.Printf(" (commit: %s)", common.JSONVersion.Commit) 32 | fmt.Printf(" (built at: %s)", common.JSONVersion.Date) 33 | fmt.Println() 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/cmd/utility.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "ocfcore/internal/common" 6 | 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | var versionCmd = &cobra.Command{ 11 | Use: "version", 12 | Short: "Print the version of ocfcore", 13 | Run: func(cmd *cobra.Command, args []string) { 14 | fmt.Printf("ocfcore version %s", common.JSONVersion.Version) 15 | fmt.Printf(" (commit: %s)", common.JSONVersion.Commit) 16 | fmt.Printf(" (built at: %s)", common.JSONVersion.Date) 17 | fmt.Println() 18 | }, 19 | } 20 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/core/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "ocfcore/bin/core/cmd" 5 | "ocfcore/internal/common" 6 | ) 7 | 8 | var ( 9 | // Populated during build 10 | version = "dev" 11 | commitHash = "?" 
12 | buildDate = "" 13 | authUrl = "" 14 | authClientId = "" 15 | authSecret = "" 16 | sentryDSN = "" 17 | ) 18 | 19 | func main() { 20 | common.JSONVersion.Version = version 21 | common.JSONVersion.Commit = commitHash 22 | common.JSONVersion.Date = buildDate 23 | common.BuildSecret.AuthClientID = authClientId 24 | common.BuildSecret.AuthURL = authUrl 25 | common.BuildSecret.AuthSecret = authSecret 26 | common.BuildSecret.SentryDSN = sentryDSN 27 | cmd.Execute() 28 | } 29 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/bin/netctl/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/config/cfg.yaml: -------------------------------------------------------------------------------- 1 | name: relay 2 | p2p: {} 3 | path: "" 4 | port: "8092" 5 | queue: 6 | port: "8094" 7 | vacuum: 8 | interval: 10 9 | bootstrap: 10 | addrs: 11 | - "/ip4/198.176.96.165/tcp/43905/p2p/QmVZjanvTGSJXST3s6JGuPaNQbzW8JjQqVVfF5m8VkkqSj" 12 | # mode: "standalone" 13 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/config/cfg_standalone.yaml: -------------------------------------------------------------------------------- 1 | name: relay 2 | p2p: {} 3 | path: "" 4 | port: "8092" 5 | queue: 6 | port: "8094" 7 | vacuum: 8 | interval: 10 9 | bootstrap: 10 | addrs: 11 | # - "/ip4/147.189.197.175/tcp/43905/p2p/QmVTMfAUK1qMZGvrV6rAB16naJJhCYMNiyjFQqamS8RZGn" 12 | mode: "standalone" -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/cluster/baremetal.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/cluster/cluster.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import "ocfcore/internal/common/structs" 4 | 5 | type ClusterManager interface { 6 | AcquireMachine(payload structs.AcquireMachinePayload) 7 | Execute(command string) 8 | } 9 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/cluster/kubenetes.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/cluster/network/edgevpn.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | func StartNetworkWeaver() { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/cluster/slurm.go: -------------------------------------------------------------------------------- 1 | package cluster 2 | 3 | import ( 4 | "os/exec" 5 | "ocfcore/internal/common" 6 | "ocfcore/internal/common/structs" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | type SlurmCluster struct { 12 | ConnectedMachine []structs.WarmedMachine 13 | } 14 | 15 | var slurmClusterClient *SlurmCluster 16 | 17 | func NewSlurmClusterClient() *SlurmCluster { 18 | if slurmClusterClient == nil { 19 | slurmClusterClient = &SlurmCluster{} 20 | } 21 | return 
slurmClusterClient 22 | } 23 | 24 | func (s *SlurmCluster) AcquireMachine(payload structs.AcquireMachinePayload) { 25 | common.Logger.Info("Acquiring machine", "payload", payload) 26 | output, err := s.execute(payload.Script) 27 | if err != nil { 28 | common.Logger.Error("Could not acquire machine", "error", err) 29 | return 30 | } 31 | outputString := strings.Split(output, " ") 32 | machineID := outputString[len(outputString)-1] 33 | common.Logger.Info("Machine ID", "machineID", machineID) 34 | s.ConnectedMachine = append(s.ConnectedMachine, structs.WarmedMachine{ 35 | MachineID: machineID, 36 | Status: "REQUESTING", 37 | StartedAt: time.Now().Unix(), 38 | Life: time.Hour * 4, 39 | }) 40 | } 41 | 42 | func (s *SlurmCluster) execute(command string) (string, error) { 43 | // execute the command 44 | prg := "sbatch" 45 | args := []string{command} 46 | cmd := exec.Command(prg, args...) 47 | stdout, err := cmd.Output() 48 | if err != nil { 49 | return "", err 50 | } 51 | return string(stdout), nil 52 | } 53 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/constants.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "os" 5 | "path" 6 | 7 | "github.com/mitchellh/go-homedir" 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | func GetocfcorePath() string { 12 | home, err := homedir.Dir() 13 | if err != nil { 14 | Logger.Error("Could not get home directory", "error", err) 15 | home = "." 16 | } 17 | ocfcorePath := path.Join(home, ".ocfcore") 18 | if _, err := os.Stat(ocfcorePath); os.IsNotExist(err) { 19 | err := os.MkdirAll(ocfcorePath, 0755) 20 | if err != nil { 21 | Logger.Error("Could not create ocfcore directory", "error", err) 22 | return "." 
23 | } 24 | } 25 | return ocfcorePath 26 | } 27 | 28 | func GetDBPath() string { 29 | if viper.Get("database.path") != nil { 30 | return viper.GetString("database.path") 31 | } 32 | home, err := homedir.Dir() 33 | if err != nil { 34 | return "./ocfcore.db" 35 | } 36 | 37 | dbPath := path.Join(home, ".ocfcore", "ocfcore.db") 38 | return dbPath 39 | } 40 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/logger.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "github.com/spf13/viper" 5 | "go.uber.org/zap" 6 | "go.uber.org/zap/zapcore" 7 | ) 8 | 9 | var Logger *zap.SugaredLogger 10 | 11 | func init() { 12 | config := zap.NewDevelopmentConfig() 13 | if viper.Get("loglevel") != nil { 14 | // if it is not set, by default will be 0 - info 15 | config.Level.SetLevel(zapcore.Level(viper.GetInt("log_level"))) 16 | } 17 | config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder 18 | 19 | zapLogger, err := config.Build() 20 | // trunk-ignore(golangci-lint/errcheck) 21 | defer zapLogger.Sync() 22 | if err != nil { 23 | panic(err) 24 | } 25 | Logger = zapLogger.Sugar() 26 | } 27 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/process/manager.go: -------------------------------------------------------------------------------- 1 | package process 2 | 3 | type ProcessManager struct { 4 | processes []*Process 5 | } 6 | 7 | var pm *ProcessManager 8 | 9 | func NewProcessManager() *ProcessManager { 10 | if pm == nil { 11 | pm = &ProcessManager{} 12 | } 13 | return pm 14 | } 15 | 16 | func (pm *ProcessManager) StartProcess(command string, envs string, args []string) { 17 | process := NewProcess(command, envs, args...) 
18 | process = process.Start() 19 | pm.processes = append(pm.processes, process) 20 | } 21 | 22 | func (pm *ProcessManager) StopAllProcesses() { 23 | for _, process := range pm.processes { 24 | process.Kill() 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/requests/broadcast.go: -------------------------------------------------------------------------------- 1 | package requests 2 | 3 | import ( 4 | "ocfcore/internal/common/structs" 5 | "ocfcore/internal/protocol/p2p" 6 | ) 7 | 8 | // functions for massively broadcasting messages to all peers 9 | 10 | func BroadcastNodeStatus(nodeStatus structs.NodeStatus) { 11 | node := p2p.GetP2PNode() 12 | dnt := p2p.GetNodeTable() 13 | for _, peer := range dnt.Peers { 14 | if peer.PeerID != node.ID().String() { 15 | // we don't need to update local node table as it is already updated 16 | UpdateRemoteNodeTable(peer.PeerID, nodeStatus) 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/requests/client.go: -------------------------------------------------------------------------------- 1 | package requests 2 | 3 | import ( 4 | "github.com/sethgrid/pester" 5 | ) 6 | 7 | var client *pester.Client 8 | 9 | func NewHTTPClient() *pester.Client { 10 | if client == nil { 11 | client = pester.New() 12 | client.MaxRetries = 1 13 | client.Concurrency = 1 14 | client.Backoff = pester.ExponentialJitterBackoff 15 | } 16 | return client 17 | } 18 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/requests/proxy.go: -------------------------------------------------------------------------------- 1 | package requests 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "ocfcore/internal/common/structs" 8 | "strings" 9 | 10 | "github.com/spf13/viper" 11 | ) 12 | 13 | func ForwardInferenceRequest(peerId string, req structs.InferenceStruct) (string, error) { 14 | remoteAddr := fmt.Sprintf("http://localhost:%s/api/v1/proxy/%s/api/v1/request/_inference", viper.GetString("port"), peerId) 15 | reqString, err := json.Marshal(req) 16 | if err != nil { 17 | return "", err 18 | } 19 | payload := strings.NewReader(string(reqString)) 20 | resp, err := NewHTTPClient().Post(remoteAddr, "application/json", payload) 21 | if err != nil { 22 | return "", err 23 | } 24 | b, err := io.ReadAll(resp.Body) 25 | if err != nil { 26 | return "", err 27 | } 28 | return string(b), nil 29 | } 30 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/requests/weaver.go: -------------------------------------------------------------------------------- 1 | package requests 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "ocfcore/internal/common" 8 | "ocfcore/internal/common/structs" 9 | "strings" 10 | 11 | ns "github.com/nats-io/nats-server/v2/server" 12 | "github.com/spf13/viper" 13 | "golang.org/x/exp/slices" 14 | ) 15 | 16 | var blackListedPeers []string 17 | 18 | func CheckPeerStatus(peerId string) error { 19 | if !slices.Contains(blackListedPeers, peerId) { 20 | peerAddr := fmt.Sprintf("http://localhost:%s/api/v1/proxy/%s/api/v1/status/health?peer=0", viper.GetString("port"), peerId) 21 | resp, err := NewHTTPClient().Get(peerAddr) 22 | if err != nil { 23 | common.Logger.Error("Error while checking peer status", "error", err) 24 | 
return err 25 | } 26 | b, err := io.ReadAll(resp.Body) 27 | if err != nil { 28 | common.Logger.Error("Error while reading response body", "error", err) 29 | return err 30 | } 31 | if string(b) == "ERROR: protocol not supported" { 32 | blackListedPeers = append(blackListedPeers, peerId) 33 | return fmt.Errorf("peer %s is not ocfcore", peerId) 34 | } else { 35 | fmt.Println(string(b)) 36 | } 37 | return nil 38 | 39 | } 40 | return fmt.Errorf("peer is blacklisted") 41 | } 42 | 43 | func ReadProvidedService(peerId string) ([]string, error) { 44 | remoteAddr := fmt.Sprintf("http://localhost:%s/api/v1/proxy/%s/api/v1/status/connections", viper.GetString("port"), peerId) 45 | resp, err := NewHTTPClient().Get(remoteAddr) 46 | if err != nil { 47 | common.Logger.Debug("Error while reading provided service", "error", err) 48 | return nil, err 49 | } 50 | b, err := io.ReadAll(resp.Body) 51 | if err != nil { 52 | common.Logger.Error("Error while reading response body", "error", err) 53 | return nil, err 54 | } 55 | var conns ns.Connz 56 | json.Unmarshal(b, &conns) 57 | var providedService []string 58 | for _, conn := range conns.Conns { 59 | providedService = append(providedService, conn.Subs...) 60 | } 61 | return providedService, nil 62 | } 63 | 64 | func UpdateRemoteNodeTable(peerId string, nodeStatus structs.NodeStatus) error { 65 | remoteAddr := fmt.Sprintf("http://localhost:%s/api/v1/proxy/%s/api/v1/status/table", viper.GetString("port"), peerId) 66 | reqString, err := json.Marshal(nodeStatus) 67 | if err != nil { 68 | return err 69 | } 70 | payload := strings.NewReader(string(reqString)) 71 | resp, err := NewHTTPClient().Post(remoteAddr, "application/json", payload) 72 | if err != nil { 73 | return err 74 | } 75 | b, err := io.ReadAll(resp.Body) 76 | if err != nil { 77 | return err 78 | } 79 | fmt.Println(string(b)) 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/secrets.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | type buildSecret struct { 4 | AuthClientID string `json:"AUTH_CLIENT_ID"` 5 | AuthURL string `json:"AUTH_URL"` 6 | AuthSecret string `json:"AUTH_CLIENT_SECRET"` 7 | SentryDSN string `json:"SENTRY_DSN"` 8 | } 9 | 10 | var BuildSecret buildSecret 11 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/structs/cluster.go: -------------------------------------------------------------------------------- 1 | package structs 2 | 3 | import ( 4 | "ocfcore/internal/protocol/p2p" 5 | "time" 6 | ) 7 | 8 | // Cluster is a lower-level interface from workers 9 | 10 | type AcquireMachinePayload struct { 11 | Script string `json:"script"` 12 | Params map[string]string `json:"params"` 13 | } 14 | 15 | type WarmedMachine struct { 16 | MachineID string `json:"machine_id"` 17 | Status string `json:"status"` 18 | Life time.Duration `json:"life"` 19 | StartedAt int64 `json:"started_at"` 20 | } 21 | 22 | type NodeStatus struct { 23 | PeerID string `json:"peer_id"` 24 | ClientID int `json:"client_id"` 25 | Status string `json:"status"` 26 | Specs []p2p.GPUSpec `json:"gpus"` 27 | Service string `json:"service"` 28 | } 29 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/structs/matching.go: -------------------------------------------------------------------------------- 1 | package 
structs 2 | 3 | type MatchingWorkerStatus struct { 4 | Accelerator string `json:"accelerator"` 5 | Status string `json:"status"` 6 | } 7 | 8 | type ExpectedRuntime struct { 9 | WorkerID string `json:"worker_id"` 10 | Runtime float64 `json:"runtime"` 11 | } 12 | 13 | type MatchingModelStatus struct { 14 | Workers []ExpectedRuntime `json:"expectations"` 15 | } 16 | 17 | type MatchingStatus struct { 18 | Workers map[string]MatchingWorkerStatus `json:"workers"` 19 | Models map[string]MatchingModelStatus `json:"models"` 20 | Timestamp int64 `json:"timestamp"` 21 | } 22 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/structs/request.go: -------------------------------------------------------------------------------- 1 | package structs 2 | 3 | type InferenceStruct struct { 4 | UniqueModelName string `json:"model_name"` 5 | Params map[string]interface{} `json:"params"` 6 | } 7 | 8 | type GenericStruct struct { 9 | JobTypeID string `json:"job_type_id"` 10 | Params map[string]interface{} `json:"params"` 11 | } 12 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/structs/summary.go: -------------------------------------------------------------------------------- 1 | package structs 2 | 3 | type CardStatus struct { 4 | CardID string `json:"card_id"` 5 | Status string `json:"status"` 6 | Serving string `json:"serving"` 7 | PowerUsage float64 `json:"power_usage"` 8 | GPUUtilization float64 `json:"gpu_utilization"` 9 | UsedMemory float64 `json:"used_memory"` 10 | AvailableMemory float64 `json:"available_memory"` 11 | LastUpdated int64 `json:"last_updated"` 12 | GPUSpecifier string `json:"gpu_specifier"` 13 | } 14 | 15 | type StatusSummary struct { 16 | Status map[string]CardStatus `json:"status"` 17 | } 18 | 19 | type CardMetrics struct { 20 | PowerUsage float64 `json:"power_usage"` 21 | GPUUtilization float64 `json:"gpu_utilization"` 22 | UsedMemory float64 `json:"used_memory"` 23 | AvailableMemory float64 `json:"available_memory"` 24 | } 25 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/structs/workload.go: -------------------------------------------------------------------------------- 1 | package structs 2 | 3 | type AvailableWorkload struct { 4 | Name string `json:"name"` 5 | Modes []string `json:"modes"` 6 | } 7 | 8 | type NatsConnections struct { 9 | ServerID string `json:"server_id"` 10 | } 11 | 12 | type LoadWorkLoadInstruction struct { 13 | Workload string `json:"workload"` 14 | Mode string `json:"mode"` 15 | BootstrapConfig map[string]string `json:"bootstrap_config"` 16 | } 17 | 18 | type ProvisionModelsPlan struct { 19 | Instructions []LoadWorkLoadInstruction `json:"instructions"` 20 | } 21 | 22 | // WorkloadInstructionsHub maps workerID to LoadWorkLoadInstruction 23 | type WorkloadInstructionsHub struct { 24 | Instructions map[string]ProvisionModelsPlan 25 | } 26 | 27 | type WorkloadTableRow struct { 28 | WorkloadID string `json:"workload_id"` 29 | Providers []string `json:"providers"` 30 | } 31 | 32 | type WorkloadTable struct { 33 | Workloads []WorkloadTableRow `json:"workloads"` 34 | } 35 | 36 | func (wt WorkloadTable) Add(workloadID string, provider string) *WorkloadTable { 37 | for _, workload := range wt.Workloads { 38 | if workload.WorkloadID == workloadID { 39 | workload.Providers = append(workload.Providers, provider) 40 | return &wt 41 | } 42 
| } 43 | row := WorkloadTableRow{WorkloadID: workloadID, Providers: []string{provider}} 44 | wt.Workloads = append(wt.Workloads, row) 45 | return &wt 46 | } 47 | 48 | func (wt WorkloadTable) Find(workloadID string) *WorkloadTableRow { 49 | for _, workload := range wt.Workloads { 50 | if workload.WorkloadID == workloadID { 51 | return &workload 52 | } 53 | } 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/utils.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | func ContainsString(slice []string, item string) bool { 4 | for _, v := range slice { 5 | if v == item { 6 | return true 7 | } 8 | } 9 | return false 10 | } 11 | 12 | func RemoveString(slice []string, item string) []string { 13 | for idx, v := range slice { 14 | if v == item { 15 | return append(slice[:idx], slice[idx+1:]...) 16 | } 17 | } 18 | return slice 19 | } 20 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/common/version.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | type jsonVersion struct { 4 | Version string `json:"version"` 5 | Commit string `json:"commit"` 6 | Date string `json:"date"` 7 | } 8 | 9 | var JSONVersion jsonVersion 10 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/daemon/clock.go: -------------------------------------------------------------------------------- 1 | package daemon 2 | 3 | import ( 4 | "ocfcore/internal/common" 5 | "ocfcore/internal/server" 6 | "time" 7 | 8 | "github.com/go-co-op/gocron" 9 | "github.com/spf13/viper" 10 | ) 11 | 12 | var firstRun = true 13 | 14 | func StartTicker() { 15 | s := gocron.NewScheduler(time.UTC) 16 | 17 | _, err := s.Every(viper.GetInt("vacuum.interval")).Seconds().Do(func() { 18 | if firstRun { 19 | // skip the first run to wait until the server is ready 20 | firstRun = false 21 | return 22 | } 23 | server.DisconnectionDetection() 24 | // todo: disable this for now 25 | // todo: in future this will be managed more passively - each node monitors its own worker periodically and broadcast the status to the peers 26 | // server.UpdateGlobalWorkloadTable() 27 | }) 28 | if err != nil { 29 | common.Logger.Error(err) 30 | } 31 | s.StartAsync() 32 | } 33 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/daemon/tcmd.go: -------------------------------------------------------------------------------- 1 | package daemon 2 | 3 | import ( 4 | "ocfcore/internal/common" 5 | "ocfcore/internal/server" 6 | 7 | "github.com/getsentry/sentry-go" 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | func Start() { 12 | err := sentry.Init(sentry.ClientOptions{ 13 | Dsn: common.BuildSecret.SentryDSN, 14 | // Set TracesSampleRate to 1.0 to capture 100% 15 | // of transactions for performance monitoring. 
16 | 		// We recommend adjusting this value in production,
17 | 		TracesSampleRate: 1.0,
18 | 	})
19 | 	if err != nil {
20 | 		common.Logger.Errorf("sentry.Init: %s", err)
21 | 	}
22 | 	common.Logger.Info("Wallet: ", viper.Get("wallet.account"))
23 | 	StartTicker()
24 | 	server.StartServer()
25 | }
26 | 
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/database/access/client.go:
--------------------------------------------------------------------------------
1 | package access
2 | 
3 | import (
4 | 	"context"
5 | 	"log"
6 | 	"ocfcore/internal/database"
7 | )
8 | 
9 | type DBClient struct {
10 | 	database.Client
11 | }
12 | 
13 | var dbClient *DBClient
14 | 
15 | func NewDBClient() *DBClient {
16 | 	if dbClient == nil {
17 | 		c, err := database.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
18 | 		if err != nil {
19 | 			log.Fatalf("failed opening connection to sqlite: %v", err)
20 | 		}
21 | 		// Run the auto migration tool.
22 | 		if err := c.Schema.Create(context.Background()); err != nil {
23 | 			log.Fatalf("failed creating schema resources: %v", err)
24 | 		}
25 | 		// assign to the package-level client; a short variable declaration
26 | 		// here would shadow dbClient and leave it nil
27 | 		dbClient = &DBClient{Client: *c}
28 | 	}
29 | 	return dbClient
30 | }
31 | 
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/database/access/node.go:
--------------------------------------------------------------------------------
1 | package access
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/database/config.go:
--------------------------------------------------------------------------------
1 | // Code generated by ent, DO NOT EDIT.
2 | 
3 | package database
4 | 
5 | import (
6 | 	"entgo.io/ent"
7 | 	"entgo.io/ent/dialect"
8 | )
9 | 
10 | // Option function to configure the client.
11 | type Option func(*config)
12 | 
13 | // Config is the configuration for the client and its builder.
14 | type config struct {
15 | 	// driver used for executing database requests.
16 | 	driver dialect.Driver
17 | 	// debug enable a debug logging.
18 | 	debug bool
19 | 	// log used for logging on debug mode.
20 | 	log func(...any)
21 | 	// hooks to execute on mutations.
22 | 	hooks *hooks
23 | }
24 | 
25 | // hooks per client, for fast access.
26 | type hooks struct {
27 | 	Node []ent.Hook
28 | }
29 | 
30 | // Options applies the options on the config object.
31 | func (c *config) options(opts ...Option) {
32 | 	for _, opt := range opts {
33 | 		opt(c)
34 | 	}
35 | 	if c.debug {
36 | 		c.driver = dialect.Debug(c.driver, c.log)
37 | 	}
38 | }
39 | 
40 | // Debug enables debug logging on the ent.Driver.
41 | func Debug() Option {
42 | 	return func(c *config) {
43 | 		c.debug = true
44 | 	}
45 | }
46 | 
47 | // Log sets the logging function for debug mode.
48 | func Log(fn func(...any)) Option {
49 | 	return func(c *config) {
50 | 		c.log = fn
51 | 	}
52 | }
53 | 
54 | // Driver configures the client driver.
55 | func Driver(driver dialect.Driver) Option {
56 | 	return func(c *config) {
57 | 		c.driver = driver
58 | 	}
59 | }
60 | 
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/database/context.go:
--------------------------------------------------------------------------------
1 | // Code generated by ent, DO NOT EDIT.
2 | 
3 | package database
4 | 
5 | import (
6 | 	"context"
7 | )
8 | 
9 | type clientCtxKey struct{}
10 | 
11 | // FromContext returns a Client stored inside a context, or nil if there isn't one.
12 | func FromContext(ctx context.Context) *Client { 13 | c, _ := ctx.Value(clientCtxKey{}).(*Client) 14 | return c 15 | } 16 | 17 | // NewContext returns a new context with the given Client attached. 18 | func NewContext(parent context.Context, c *Client) context.Context { 19 | return context.WithValue(parent, clientCtxKey{}, c) 20 | } 21 | 22 | type txCtxKey struct{} 23 | 24 | // TxFromContext returns a Tx stored inside a context, or nil if there isn't one. 25 | func TxFromContext(ctx context.Context) *Tx { 26 | tx, _ := ctx.Value(txCtxKey{}).(*Tx) 27 | return tx 28 | } 29 | 30 | // NewTxContext returns a new context with the given Tx attached. 31 | func NewTxContext(parent context.Context, tx *Tx) context.Context { 32 | return context.WithValue(parent, txCtxKey{}, tx) 33 | } 34 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/enttest/enttest.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package enttest 4 | 5 | import ( 6 | "context" 7 | "ocfcore/internal/database" 8 | // required by schema hooks. 9 | _ "ocfcore/internal/database/runtime" 10 | 11 | "ocfcore/internal/database/migrate" 12 | 13 | "entgo.io/ent/dialect/sql/schema" 14 | ) 15 | 16 | type ( 17 | // TestingT is the interface that is shared between 18 | // testing.T and testing.B and used by enttest. 19 | TestingT interface { 20 | FailNow() 21 | Error(...any) 22 | } 23 | 24 | // Option configures client creation. 25 | Option func(*options) 26 | 27 | options struct { 28 | opts []database.Option 29 | migrateOpts []schema.MigrateOption 30 | } 31 | ) 32 | 33 | // WithOptions forwards options to client creation. 34 | func WithOptions(opts ...database.Option) Option { 35 | return func(o *options) { 36 | o.opts = append(o.opts, opts...) 37 | } 38 | } 39 | 40 | // WithMigrateOptions forwards options to auto migration. 41 | func WithMigrateOptions(opts ...schema.MigrateOption) Option { 42 | return func(o *options) { 43 | o.migrateOpts = append(o.migrateOpts, opts...) 44 | } 45 | } 46 | 47 | func newOptions(opts []Option) *options { 48 | o := &options{} 49 | for _, opt := range opts { 50 | opt(o) 51 | } 52 | return o 53 | } 54 | 55 | // Open calls database.Open and auto-run migration. 56 | func Open(t TestingT, driverName, dataSourceName string, opts ...Option) *database.Client { 57 | o := newOptions(opts) 58 | c, err := database.Open(driverName, dataSourceName, o.opts...) 59 | if err != nil { 60 | t.Error(err) 61 | t.FailNow() 62 | } 63 | migrateSchema(t, c, o) 64 | return c 65 | } 66 | 67 | // NewClient calls database.NewClient and auto-run migration. 68 | func NewClient(t TestingT, opts ...Option) *database.Client { 69 | o := newOptions(opts) 70 | c := database.NewClient(o.opts...) 
71 | migrateSchema(t, c, o) 72 | return c 73 | } 74 | func migrateSchema(t TestingT, c *database.Client, o *options) { 75 | tables, err := schema.CopyTables(migrate.Tables) 76 | if err != nil { 77 | t.Error(err) 78 | t.FailNow() 79 | } 80 | if err := migrate.Create(context.Background(), c.Schema, tables, o.migrateOpts...); err != nil { 81 | t.Error(err) 82 | t.FailNow() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/generate.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema 4 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/migrate/migrate.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "io" 9 | 10 | "entgo.io/ent/dialect" 11 | "entgo.io/ent/dialect/sql/schema" 12 | ) 13 | 14 | var ( 15 | // WithGlobalUniqueID sets the universal ids options to the migration. 16 | // If this option is enabled, ent migration will allocate a 1<<32 range 17 | // for the ids of each entity (table). 18 | // Note that this option cannot be applied on tables that already exist. 19 | WithGlobalUniqueID = schema.WithGlobalUniqueID 20 | // WithDropColumn sets the drop column option to the migration. 21 | // If this option is enabled, ent migration will drop old columns 22 | // that were used for both fields and edges. This defaults to false. 23 | WithDropColumn = schema.WithDropColumn 24 | // WithDropIndex sets the drop index option to the migration. 25 | // If this option is enabled, ent migration will drop old indexes 26 | // that were defined in the schema. This defaults to false. 27 | // Note that unique constraints are defined using `UNIQUE INDEX`, 28 | // and therefore, it's recommended to enable this option to get more 29 | // flexibility in the schema changes. 30 | WithDropIndex = schema.WithDropIndex 31 | // WithForeignKeys enables creating foreign-key in schema DDL. This defaults to true. 32 | WithForeignKeys = schema.WithForeignKeys 33 | ) 34 | 35 | // Schema is the API for creating, migrating and dropping a schema. 36 | type Schema struct { 37 | drv dialect.Driver 38 | } 39 | 40 | // NewSchema creates a new schema client. 41 | func NewSchema(drv dialect.Driver) *Schema { return &Schema{drv: drv} } 42 | 43 | // Create creates all schema resources. 44 | func (s *Schema) Create(ctx context.Context, opts ...schema.MigrateOption) error { 45 | return Create(ctx, s, Tables, opts...) 46 | } 47 | 48 | // Create creates all table resources using the given schema driver. 49 | func Create(ctx context.Context, s *Schema, tables []*schema.Table, opts ...schema.MigrateOption) error { 50 | migrate, err := schema.NewMigrate(s.drv, opts...) 51 | if err != nil { 52 | return fmt.Errorf("ent/migrate: %w", err) 53 | } 54 | return migrate.Create(ctx, tables...) 55 | } 56 | 57 | // WriteTo writes the schema changes to w instead of running them against the database. 
58 | // 59 | // if err := client.Schema.WriteTo(context.Background(), os.Stdout); err != nil { 60 | // log.Fatal(err) 61 | // } 62 | func (s *Schema) WriteTo(ctx context.Context, w io.Writer, opts ...schema.MigrateOption) error { 63 | return Create(ctx, &Schema{drv: &schema.WriteDriver{Writer: w, Driver: s.drv}}, Tables, opts...) 64 | } 65 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/migrate/schema.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package migrate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql/schema" 7 | "entgo.io/ent/schema/field" 8 | ) 9 | 10 | var ( 11 | // NodesColumns holds the columns for the "nodes" table. 12 | NodesColumns = []*schema.Column{ 13 | {Name: "id", Type: field.TypeInt, Increment: true}, 14 | {Name: "peer_id", Type: field.TypeString, Default: "unknown"}, 15 | {Name: "status", Type: field.TypeString, Default: "unknown"}, 16 | } 17 | // NodesTable holds the schema information for the "nodes" table. 18 | NodesTable = &schema.Table{ 19 | Name: "nodes", 20 | Columns: NodesColumns, 21 | PrimaryKey: []*schema.Column{NodesColumns[0]}, 22 | } 23 | // Tables holds all the tables in the schema. 24 | Tables = []*schema.Table{ 25 | NodesTable, 26 | } 27 | ) 28 | 29 | func init() { 30 | } 31 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/node/node.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package node 4 | 5 | const ( 6 | // Label holds the string label denoting the node type in the database. 7 | Label = "node" 8 | // FieldID holds the string denoting the id field in the database. 9 | FieldID = "id" 10 | // FieldPeerId holds the string denoting the peerid field in the database. 11 | FieldPeerId = "peer_id" 12 | // FieldStatus holds the string denoting the status field in the database. 13 | FieldStatus = "status" 14 | // Table holds the table name of the node in the database. 15 | Table = "nodes" 16 | ) 17 | 18 | // Columns holds all SQL columns for node fields. 19 | var Columns = []string{ 20 | FieldID, 21 | FieldPeerId, 22 | FieldStatus, 23 | } 24 | 25 | // ValidColumn reports if the column name is valid (part of the table columns). 26 | func ValidColumn(column string) bool { 27 | for i := range Columns { 28 | if column == Columns[i] { 29 | return true 30 | } 31 | } 32 | return false 33 | } 34 | 35 | var ( 36 | // DefaultPeerId holds the default value on creation for the "peerId" field. 37 | DefaultPeerId string 38 | // DefaultStatus holds the default value on creation for the "status" field. 39 | DefaultStatus string 40 | ) 41 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/node_delete.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package database 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "ocfcore/internal/database/node" 9 | "ocfcore/internal/database/predicate" 10 | 11 | "entgo.io/ent/dialect/sql" 12 | "entgo.io/ent/dialect/sql/sqlgraph" 13 | "entgo.io/ent/schema/field" 14 | ) 15 | 16 | // NodeDelete is the builder for deleting a Node entity. 
17 | type NodeDelete struct { 18 | config 19 | hooks []Hook 20 | mutation *NodeMutation 21 | } 22 | 23 | // Where appends a list predicates to the NodeDelete builder. 24 | func (nd *NodeDelete) Where(ps ...predicate.Node) *NodeDelete { 25 | nd.mutation.Where(ps...) 26 | return nd 27 | } 28 | 29 | // Exec executes the deletion query and returns how many vertices were deleted. 30 | func (nd *NodeDelete) Exec(ctx context.Context) (int, error) { 31 | var ( 32 | err error 33 | affected int 34 | ) 35 | if len(nd.hooks) == 0 { 36 | affected, err = nd.sqlExec(ctx) 37 | } else { 38 | var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { 39 | mutation, ok := m.(*NodeMutation) 40 | if !ok { 41 | return nil, fmt.Errorf("unexpected mutation type %T", m) 42 | } 43 | nd.mutation = mutation 44 | affected, err = nd.sqlExec(ctx) 45 | mutation.done = true 46 | return affected, err 47 | }) 48 | for i := len(nd.hooks) - 1; i >= 0; i-- { 49 | if nd.hooks[i] == nil { 50 | return 0, fmt.Errorf("database: uninitialized hook (forgotten import database/runtime?)") 51 | } 52 | mut = nd.hooks[i](mut) 53 | } 54 | if _, err := mut.Mutate(ctx, nd.mutation); err != nil { 55 | return 0, err 56 | } 57 | } 58 | return affected, err 59 | } 60 | 61 | // ExecX is like Exec, but panics if an error occurs. 62 | func (nd *NodeDelete) ExecX(ctx context.Context) int { 63 | n, err := nd.Exec(ctx) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return n 68 | } 69 | 70 | func (nd *NodeDelete) sqlExec(ctx context.Context) (int, error) { 71 | _spec := &sqlgraph.DeleteSpec{ 72 | Node: &sqlgraph.NodeSpec{ 73 | Table: node.Table, 74 | ID: &sqlgraph.FieldSpec{ 75 | Type: field.TypeInt, 76 | Column: node.FieldID, 77 | }, 78 | }, 79 | } 80 | if ps := nd.mutation.predicates; len(ps) > 0 { 81 | _spec.Predicate = func(selector *sql.Selector) { 82 | for i := range ps { 83 | ps[i](selector) 84 | } 85 | } 86 | } 87 | affected, err := sqlgraph.DeleteNodes(ctx, nd.driver, _spec) 88 | if err != nil && sqlgraph.IsConstraintError(err) { 89 | err = &ConstraintError{msg: err.Error(), wrap: err} 90 | } 91 | return affected, err 92 | } 93 | 94 | // NodeDeleteOne is the builder for deleting a single Node entity. 95 | type NodeDeleteOne struct { 96 | nd *NodeDelete 97 | } 98 | 99 | // Exec executes the deletion query. 100 | func (ndo *NodeDeleteOne) Exec(ctx context.Context) error { 101 | n, err := ndo.nd.Exec(ctx) 102 | switch { 103 | case err != nil: 104 | return err 105 | case n == 0: 106 | return &NotFoundError{node.Label} 107 | default: 108 | return nil 109 | } 110 | } 111 | 112 | // ExecX is like Exec, but panics if an error occurs. 113 | func (ndo *NodeDeleteOne) ExecX(ctx context.Context) { 114 | ndo.nd.ExecX(ctx) 115 | } 116 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/predicate/predicate.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package predicate 4 | 5 | import ( 6 | "entgo.io/ent/dialect/sql" 7 | ) 8 | 9 | // Node is the predicate function for node builders. 10 | type Node func(*sql.Selector) 11 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 
2 | 3 | package database 4 | 5 | import ( 6 | "ocfcore/internal/database/node" 7 | "ocfcore/internal/database/schema" 8 | ) 9 | 10 | // The init function reads all schema descriptors with runtime code 11 | // (default values, validators, hooks and policies) and stitches it 12 | // to their package variables. 13 | func init() { 14 | nodeFields := schema.Node{}.Fields() 15 | _ = nodeFields 16 | // nodeDescPeerId is the schema descriptor for peerId field. 17 | nodeDescPeerId := nodeFields[0].Descriptor() 18 | // node.DefaultPeerId holds the default value on creation for the peerId field. 19 | node.DefaultPeerId = nodeDescPeerId.Default.(string) 20 | // nodeDescStatus is the schema descriptor for status field. 21 | nodeDescStatus := nodeFields[1].Descriptor() 22 | // node.DefaultStatus holds the default value on creation for the status field. 23 | node.DefaultStatus = nodeDescStatus.Default.(string) 24 | } 25 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/runtime/runtime.go: -------------------------------------------------------------------------------- 1 | // Code generated by ent, DO NOT EDIT. 2 | 3 | package runtime 4 | 5 | // The schema-stitching logic is generated in ocfcore/internal/database/runtime.go 6 | 7 | const ( 8 | Version = "v0.11.4" // Version of ent codegen. 9 | Sum = "h1:grwVY0fp31BZ6oEo3YrXenAuv8VJmEw7F/Bi6WqeH3Q=" // Sum of ent codegen. 10 | ) 11 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/database/schema/node.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "entgo.io/ent" 5 | "entgo.io/ent/schema/field" 6 | ) 7 | 8 | // Node holds the schema definition for the Node entity. 9 | type Node struct { 10 | ent.Schema 11 | } 12 | 13 | // Fields of the ocfcore node. 14 | func (Node) Fields() []ent.Field { 15 | return []ent.Field{ 16 | field.String("peerId"). 17 | Default("unknown"), 18 | field.String("status").Default("unknown"), 19 | } 20 | } 21 | 22 | // Edges of the Node. 
23 | func (Node) Edges() []ent.Edge { 24 | return nil 25 | } 26 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/pkgs/weaver/README.md: -------------------------------------------------------------------------------- 1 | # Network Weaver 2 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/profiler/storage.go: -------------------------------------------------------------------------------- 1 | package profiler 2 | 3 | import ( 4 | "fmt" 5 | "ocfcore/internal/common" 6 | "ocfcore/internal/common/structs" 7 | "time" 8 | 9 | "github.com/nakabonne/tstorage" 10 | ) 11 | 12 | type TStorage struct { 13 | Client tstorage.Storage 14 | } 15 | 16 | var StorageClient TStorage 17 | 18 | func NewStorageClient() TStorage { 19 | if StorageClient.Client == nil { 20 | StorageClient.Client, _ = tstorage.NewStorage( 21 | tstorage.WithTimestampPrecision(tstorage.Seconds), 22 | tstorage.WithDataPath("./data"), 23 | ) 24 | } 25 | return StorageClient 26 | } 27 | 28 | func AddPoint(host string, metric string, timestamp int64, value float64) { 29 | labels := []tstorage.Label{ 30 | {Name: "host", Value: host}, 31 | } 32 | err := NewStorageClient().Client.InsertRows([]tstorage.Row{ 33 | { 34 | Metric: metric, 35 | Labels: labels, 36 | DataPoint: tstorage.DataPoint{Timestamp: timestamp, Value: value}, 37 | }, 38 | }) 39 | if err != nil { 40 | common.Logger.Error(err) 41 | } 42 | } 43 | 44 | func QueryPoints(start int64, end int64, metric string, host string) []*tstorage.DataPoint { 45 | labels := []tstorage.Label{ 46 | {Name: "host", Value: host}, 47 | } 48 | rows, _ := NewStorageClient().Client.Select( 49 | metric, 50 | labels, 51 | start, 52 | end, 53 | ) 54 | return rows 55 | } 56 | 57 | func AggregateAverageUtilization(host string, duration time.Duration) float64 { 58 | // current timestamp 59 | end := time.Now().Unix() 60 | // start timestamp = current timestamp - duration 61 | start := end - int64(duration.Seconds()) 62 | labels := []tstorage.Label{ 63 | {Name: "host", Value: host}, 64 | } 65 | rows, _ := NewStorageClient().Client.Select( 66 | "GPU Utilization", 67 | labels, 68 | start, 69 | end, 70 | ) 71 | var sum float64 72 | for _, row := range rows { 73 | sum += row.Value 74 | } 75 | return sum / float64(len(rows)) 76 | } 77 | 78 | func QueryCardSummary(host string) (structs.CardMetrics, error) { 79 | var metrics structs.CardMetrics 80 | end := time.Now().Unix() 81 | start := end - int64(30*time.Second.Seconds()) 82 | labels := []tstorage.Label{ 83 | {Name: "host", Value: host}, 84 | } 85 | var err error 86 | keywords := []string{"GPU Utilization", "Power Usage", "Used Memory", "Available Memory"} 87 | for _, keyword := range keywords { 88 | rows, _ := NewStorageClient().Client.Select( 89 | keyword, 90 | labels, 91 | start, 92 | end, 93 | ) 94 | if len(rows) > 0 { 95 | if keyword == "GPU Utilization" { 96 | metrics.GPUUtilization = rows[len(rows)-1].Value 97 | } 98 | if keyword == "Power Usage" { 99 | metrics.PowerUsage = rows[len(rows)-1].Value 100 | } 101 | if keyword == "Used Memory" { 102 | metrics.UsedMemory = rows[len(rows)-1].Value 103 | } 104 | if keyword == "Available Memory" { 105 | metrics.AvailableMemory = rows[len(rows)-1].Value 106 | } 107 | } else { 108 | err = fmt.Errorf("no data for %s", keyword) 109 | } 110 | } 111 | return metrics, err 112 | } 113 | -------------------------------------------------------------------------------- 
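The profiler storage above keys every sample by metric name plus a "host" label. Below is a minimal usage sketch; it assumes the code lives inside the ocf-core module (the profiler package is internal) and uses a made-up host label, so treat it as illustrative rather than part of the repository.

package main

import (
	"fmt"
	"time"

	"ocfcore/internal/profiler"
)

func main() {
	host := "worker-0" // hypothetical host label
	now := time.Now().Unix()

	// Record one GPU utilization sample for this host.
	profiler.AddPoint(host, "GPU Utilization", now, 87.5)

	// Read back the last minute of samples for the same metric/host pair.
	for _, p := range profiler.QueryPoints(now-60, now, "GPU Utilization", host) {
		fmt.Println(p.Timestamp, p.Value)
	}

	// Average utilization over the last 30 seconds.
	fmt.Println(profiler.AggregateAverageUtilization(host, 30*time.Second))
}

--------------------------------------------------------------------------------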
/third_party/ocf/src/ocf-core/internal/protocol/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Relaxed-System-Lab/HexGen/949414d1be18e504752508c0559cb966ddf999f4/third_party/ocf/src/ocf-core/internal/protocol/README.md -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/bootstrap.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "github.com/multiformats/go-multiaddr" 5 | "github.com/spf13/viper" 6 | ) 7 | 8 | const defaultBootstrapPeerAddr = "/ip4/206.189.249.2/tcp/43905/p2p/QmbY2bk4JGkD6yoW9DriYsFqHqqSjZh7AyyuXeYYKFDXba" 9 | 10 | func getDefaultBootstrapPeers() []multiaddr.Multiaddr { 11 | var DefaultBootstrapPeers []multiaddr.Multiaddr 12 | bootstrapAddrs := viper.GetStringSlice("bootstrap.addrs") 13 | if bootstrapAddrs == nil { 14 | bootstrapAddrs = append(bootstrapAddrs, defaultBootstrapPeerAddr) 15 | } 16 | for _, s := range bootstrapAddrs { 17 | ma, err := multiaddr.NewMultiaddr(s) 18 | if err != nil { 19 | panic(err) 20 | } 21 | DefaultBootstrapPeers = append(DefaultBootstrapPeers, ma) 22 | } 23 | return DefaultBootstrapPeers 24 | } 25 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/dht.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "context" 5 | "ocfcore/internal/common" 6 | 7 | dht "github.com/libp2p/go-libp2p-kad-dht" 8 | "github.com/libp2p/go-libp2p/core/host" 9 | "github.com/libp2p/go-libp2p/core/peer" 10 | "github.com/multiformats/go-multiaddr" 11 | ) 12 | 13 | func NewDHT(ctx context.Context, host host.Host, bootstrapPeers []multiaddr.Multiaddr) (*dht.IpfsDHT, error) { 14 | var options []dht.Option 15 | 16 | if len(bootstrapPeers) == 0 { 17 | options = append(options, dht.Mode(dht.ModeServer)) 18 | } 19 | 20 | kdht, err := dht.New(ctx, host, options...) 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | if err = kdht.Bootstrap(ctx); err != nil { 26 | return nil, err 27 | } 28 | 29 | for _, peerAddr := range bootstrapPeers { 30 | peerinfo, _ := peer.AddrInfoFromP2pAddr(peerAddr) 31 | go func() { 32 | if err := host.Connect(ctx, *peerinfo); err != nil { 33 | common.Logger.Warn("Error while connecting to node %q: %-v", peerinfo, err) 34 | } 35 | }() 36 | } 37 | 38 | return kdht, nil 39 | } 40 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/discovery.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "context" 5 | "ocfcore/internal/common" 6 | "sync" 7 | "time" 8 | 9 | dht "github.com/libp2p/go-libp2p-kad-dht" 10 | "github.com/libp2p/go-libp2p/core/host" 11 | "github.com/libp2p/go-libp2p/core/network" 12 | "github.com/libp2p/go-libp2p/core/peerstore" 13 | routing "github.com/libp2p/go-libp2p/p2p/discovery/routing" 14 | ) 15 | 16 | var discoverLock sync.Mutex 17 | 18 | // DiscoverNew is a function that keeps updating DNT with the latest information about the network. 
19 | func Discover(ctx context.Context, h host.Host, dht *dht.IpfsDHT, rendezvous string) { 20 | GetNodeTable().Update(Peer{ 21 | PeerID: GetP2PNode().ID().String(), 22 | Status: CONNECTED, 23 | }) 24 | var disconnected []string 25 | discoverLock.Lock() 26 | defer discoverLock.Unlock() 27 | var routingDiscovery = routing.NewRoutingDiscovery(dht) 28 | routingDiscovery.Advertise(ctx, rendezvous) 29 | ticker := time.NewTicker(time.Second * 1) 30 | defer ticker.Stop() 31 | for { 32 | select { 33 | case <-ctx.Done(): 34 | return 35 | case <-ticker.C: 36 | // cleaning disconnected peers 37 | dntPeers := GetNodeTable().Peers 38 | for _, p := range dntPeers { 39 | if p.PeerID != h.ID().String() { 40 | storedPeers := h.Peerstore().Peers() 41 | for _, sp := range storedPeers { 42 | if p.PeerID == sp.String() && h.Network().Connectedness(sp) == network.NotConnected { 43 | disconnected = append(disconnected, p.PeerID) 44 | break 45 | } 46 | } 47 | } 48 | } 49 | GetNodeTable().RemoveDisconnectedPeers(disconnected) 50 | disconnected = []string{} 51 | peers, err := routingDiscovery.FindPeers(ctx, rendezvous) 52 | if err != nil { 53 | common.Logger.Error(err) 54 | } 55 | for p := range peers { 56 | if p.ID == h.ID() { 57 | continue 58 | } 59 | if h.Network().Connectedness(p.ID) != network.Connected { 60 | _, err := h.Network().DialPeer(ctx, p.ID) 61 | if err != nil { 62 | continue 63 | } 64 | } 65 | if h.Network().Connectedness(p.ID) == network.Connected { 66 | h.Peerstore().AddAddrs(p.ID, p.Addrs, peerstore.PermanentAddrTTL) 67 | GetNodeTable().Update(Peer{ 68 | PeerID: p.ID.String(), 69 | Status: CONNECTED, 70 | }) 71 | } 72 | } 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/handler.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "ocfcore/internal/common" 7 | 8 | gostream "github.com/libp2p/go-libp2p-gostream" 9 | p2phttp "github.com/libp2p/go-libp2p-http" 10 | dht "github.com/libp2p/go-libp2p-kad-dht" 11 | "github.com/multiformats/go-multiaddr" 12 | "github.com/spf13/viper" 13 | ) 14 | 15 | func P2PListener() net.Listener { 16 | ctx := context.Background() 17 | host := GetP2PNode() 18 | var dhtc *dht.IpfsDHT 19 | var err error 20 | if viper.GetString("bootstrap.mode") == "standalone" { 21 | common.Logger.Info("standalone mode") 22 | dhtc, err = NewDHT(ctx, host, []multiaddr.Multiaddr{}) 23 | } else { 24 | dhtc, err = NewDHT(ctx, host, getDefaultBootstrapPeers()) 25 | } 26 | if err != nil { 27 | panic(err) 28 | } 29 | dhtc.Bootstrap(ctx) 30 | go Discover(ctx, host, dhtc, common.JSONVersion.Version+"/"+viper.GetString("bootstrap.rendezvous")) 31 | common.Logger.Info("ocfcore peer ID: ", host.ID()) 32 | common.Logger.Info("ocfcore peer Addr: ", host.Addrs()) 33 | listener, _ := gostream.Listen(host, p2phttp.DefaultP2PProtocol) 34 | return listener 35 | } 36 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/host.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "fmt" 7 | mrand "math/rand" 8 | "ocfcore/internal/common" 9 | "strconv" 10 | "sync" 11 | "time" 12 | 13 | "github.com/libp2p/go-libp2p" 14 | "github.com/libp2p/go-libp2p/core/crypto" 15 | "github.com/libp2p/go-libp2p/core/host" 16 | 
"github.com/libp2p/go-libp2p/p2p/net/connmgr" 17 | "github.com/libp2p/go-libp2p/p2p/security/noise" 18 | libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" 19 | "github.com/spf13/viper" 20 | ) 21 | 22 | var P2PNode *host.Host 23 | var once sync.Once 24 | 25 | func GetP2PNode() host.Host { 26 | once.Do(func() { 27 | ctx := context.Background() 28 | var err error 29 | seed := viper.GetString("seed") 30 | // try to parse the seed as int64 31 | seedInt, err := strconv.ParseInt(seed, 10, 64) 32 | if err != nil { 33 | panic(err) 34 | } 35 | host, err := newHost(ctx, seedInt) 36 | P2PNode = &host 37 | if err != nil { 38 | panic(err) 39 | } 40 | }) 41 | return *P2PNode 42 | } 43 | 44 | func newHost(ctx context.Context, seed int64) (host.Host, error) { 45 | connmgr, err := connmgr.NewConnManager( 46 | 100, // Lowwater 47 | 400, // HighWater, 48 | connmgr.WithGracePeriod(time.Minute), 49 | ) 50 | if err != nil { 51 | common.Logger.Error("Error while creating connection manager: %v", err) 52 | } 53 | var priv crypto.PrivKey 54 | fmt.Println("seed: ", seed) 55 | // try to load the private key from file 56 | if seed == 0 { 57 | // try to load from the file 58 | priv = loadKeyFromFile() 59 | if priv == nil { 60 | r := rand.Reader 61 | priv, _, err = crypto.GenerateKeyPairWithReader(crypto.RSA, 2048, r) 62 | if err != nil { 63 | return nil, err 64 | } 65 | } 66 | } else { 67 | r := mrand.New(mrand.NewSource(seed)) 68 | priv, _, err = crypto.GenerateKeyPairWithReader(crypto.RSA, 2048, r) 69 | if err != nil { 70 | return nil, err 71 | } 72 | } 73 | // persist private key 74 | writeKeyToFile(priv) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | return libp2p.New( 80 | libp2p.DefaultTransports, 81 | libp2p.Identity(priv), 82 | libp2p.ConnectionManager(connmgr), 83 | libp2p.NATPortMap(), 84 | libp2p.ListenAddrStrings( 85 | "/ip4/0.0.0.0/tcp/43905", 86 | "/ip4/0.0.0.0/udp/59820/quic", 87 | ), 88 | libp2p.Security(libp2ptls.ID, libp2ptls.New), 89 | libp2p.Security(noise.ID, noise.New), 90 | libp2p.EnableNATService(), 91 | libp2p.EnableRelay(), 92 | libp2p.EnableHolePunching(), 93 | libp2p.ForceReachabilityPublic(), 94 | ) 95 | } 96 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/key.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "ocfcore/internal/common" 8 | 9 | "github.com/libp2p/go-libp2p/core/crypto" 10 | "github.com/mitchellh/go-homedir" 11 | ) 12 | 13 | func writeKeyToFile(priv crypto.PrivKey) { 14 | keyData, err := crypto.MarshalPrivateKey(priv) 15 | if err != nil { 16 | common.Logger.Error("Error while marshalling private key: ", err) 17 | } 18 | home, err := homedir.Dir() 19 | if err != nil { 20 | fmt.Println(err) 21 | os.Exit(1) 22 | } 23 | keyPath := path.Join(home, ".tom", "keys", "id") 24 | err = os.MkdirAll(path.Dir(keyPath), os.ModePerm) 25 | if err != nil { 26 | common.Logger.Error("Could not create keys directory", "error", err) 27 | os.Exit(1) 28 | } 29 | err = os.WriteFile(keyPath, keyData, 0600) 30 | if err != nil { 31 | common.Logger.Error("Could not write key to file", err) 32 | os.Exit(1) 33 | } 34 | } 35 | 36 | func loadKeyFromFile() crypto.PrivKey { 37 | home, err := homedir.Dir() 38 | if err != nil { 39 | fmt.Println(err) 40 | return nil 41 | } 42 | keyPath := path.Join(home, ".tom", "keys", "id") 43 | keyData, err := os.ReadFile(keyPath) 44 | if err != nil { 45 | return nil 46 | } 
47 | priv, err := crypto.UnmarshalPrivateKey(keyData) 48 | if err != nil { 49 | common.Logger.Error("Error while unmarshalling private key: ", err) 50 | return nil 51 | } 52 | return priv 53 | } 54 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/p2p/remote.go: -------------------------------------------------------------------------------- 1 | package p2p 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "ocfcore/internal/common" 7 | "ocfcore/internal/protocol/remote" 8 | 9 | "github.com/spf13/viper" 10 | ) 11 | 12 | func BroadcastPeerOffering(peer Peer) { 13 | dnt := GetNodeTable() 14 | common.Logger.Info("Broadcasting peer offering", "peer", peer) 15 | for _, remote := range dnt.Peers { 16 | if peer.PeerID != remote.PeerID { 17 | UpdateRemoteNodeTable(remote.PeerID, peer) 18 | } 19 | } 20 | } 21 | 22 | func UpdateRemoteNodeTable(peerId string, peer Peer) error { 23 | peer.Owner = viper.GetString("wallet.account") 24 | remoteAddr := fmt.Sprintf("http://localhost:%s/api/v1/proxy/%s/api/v1/status/peers", viper.GetString("port"), peerId) 25 | reqString, err := json.Marshal(peer) 26 | if err != nil { 27 | return err 28 | } 29 | _, err = remote.HTTPPost(remoteAddr, reqString) 30 | if err != nil { 31 | common.Logger.Info("Error while updating remote node table", "error", err) 32 | } 33 | return err 34 | } 35 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/remote/client.go: -------------------------------------------------------------------------------- 1 | package remote 2 | 3 | import ( 4 | "io" 5 | "strings" 6 | 7 | "github.com/sethgrid/pester" 8 | ) 9 | 10 | var client *pester.Client 11 | 12 | func NewHTTPClient() *pester.Client { 13 | if client == nil { 14 | client = pester.New() 15 | client.MaxRetries = 1 16 | client.Concurrency = 1 17 | client.Backoff = pester.ExponentialJitterBackoff 18 | } 19 | return client 20 | } 21 | 22 | func HTTPPost(remoteAddr string, req []byte) (string, error) { 23 | payload := strings.NewReader(string(req)) 24 | resp, err := NewHTTPClient().Post(remoteAddr, "application/json", payload) 25 | if err != nil { 26 | return "nil", err 27 | } 28 | b, err := io.ReadAll(resp.Body) 29 | if err != nil { 30 | return string(b), err 31 | } 32 | return string(b), nil 33 | } 34 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/protocol/rpc/handler.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/auth/authentication.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "ocfcore/internal/common" 5 | 6 | "github.com/authorizerdev/authorizer-go" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func AuthorizeMiddleware() gin.HandlerFunc { 11 | return func(c *gin.Context) { 12 | defaultHeaders := map[string]string{} 13 | authorizerClient, err := authorizer.NewAuthorizerClient(common.BuildSecret.AuthClientID, common.BuildSecret.AuthURL, "", defaultHeaders) 14 | if err != nil { 15 | // unauthorized 16 | c.AbortWithStatusJSON(401, "unauthorized - unable to create authorizer client") 17 | return 18 | } 19 | profile, err := authorizerClient.GetProfile(map[string]string{ 20 | "Authorization": 
c.Request.Header.Get("Authorization"), 21 | }) 22 | if err != nil { 23 | // unauthorized 24 | c.AbortWithStatusJSON(401, "unauthorized - unable to get profile") 25 | common.Logger.Error(err) 26 | return 27 | } 28 | common.Logger.Info(profile.Roles) 29 | c.Next() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/cors.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net/http" 5 | "ocfcore/internal/common" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func beforeResponse() gin.HandlerFunc { 11 | return func(c *gin.Context) { 12 | c.Writer.Header().Set("tom-version", common.JSONVersion.Commit) 13 | // if not set 14 | if c.Writer.Header().Get("Access-Control-Allow-Origin") != "*" { 15 | c.Writer.Header().Set("Access-Control-Allow-Origin", "*") 16 | c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") 17 | c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS") 18 | c.Writer.Header().Set("Access-Control-Allow-Headers", "authorization, origin, content-type, accept") 19 | } 20 | if c.Request.Method == "OPTIONS" { 21 | c.Writer.WriteHeader(http.StatusOK) 22 | } 23 | } 24 | } 25 | 26 | func rewriteHeader() func(r *http.Response) error { 27 | return func(r *http.Response) error { 28 | r.Header.Del("Access-Control-Allow-Origin") 29 | r.Header.Del("Access-Control-Allow-Credentials") 30 | r.Header.Del("Access-Control-Allow-Methods") 31 | r.Header.Del("Access-Control-Allow-Headers") 32 | return nil 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/forward.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | func ErrorHandler(res http.ResponseWriter, req *http.Request, err error) { 9 | res.Write([]byte(fmt.Sprintf("ERROR: %s", err.Error()))) 10 | } 11 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/request.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | "ocfcore/internal/common/requests" 6 | "ocfcore/internal/common/structs" 7 | "ocfcore/internal/protocol/p2p" 8 | "ocfcore/internal/server/queue" 9 | 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | var rrIndex int 14 | 15 | type InferenceResponse struct { 16 | Msg string `json:"msg"` 17 | Data string `json:"data"` 18 | } 19 | 20 | func InferenceRequest(c *gin.Context) { 21 | var request structs.InferenceStruct 22 | err := c.BindJSON(&request) 23 | if err != nil { 24 | c.JSON(400, gin.H{"error": err.Error()}) 25 | return 26 | } 27 | jsonRequest, err := json.Marshal(request) 28 | if err != nil { 29 | c.JSON(500, gin.H{"error": err.Error()}) 30 | return 31 | } 32 | topic := "inference:" + request.UniqueModelName 33 | msg, err := queue.Publish(topic, jsonRequest) 34 | if err != nil { 35 | c.JSON(500, gin.H{"error": err.Error()}) 36 | return 37 | } 38 | // wait until the inference is done 39 | c.JSON(200, gin.H{"message": "ok", "data": string(msg.Data)}) 40 | } 41 | 42 | // AutoInferenceRequest is a function that handles the inference request, but dispatches it to the correct worker 43 | // todo: we should have the ability to "cleverly" dispatch the inference 
request to the "fastest" worker
44 | func AutoInferenceRequest(c *gin.Context) {
45 | 	var request structs.InferenceStruct
46 | 	err := c.BindJSON(&request)
47 | 	if err != nil {
48 | 		c.JSON(400, gin.H{"error": err.Error()})
49 | 		return
50 | 	}
51 | 	// find workers
52 | 	table := p2p.GetNodeTable()
53 | 	topic := "inference:" + request.UniqueModelName
54 | 	providers := table.FindProviders(topic)
55 | 	if len(providers) == 0 {
56 | 		c.JSON(500, gin.H{"error": "no worker available"})
57 | 		return
58 | 	}
59 | 
60 | 	// pick a worker by round robin method
61 | 	var scapegoat p2p.Peer
62 | 	if rrIndex < len(providers) {
63 | 		scapegoat = providers[rrIndex]
64 | 	} else {
65 | 		// in case someone leaves
66 | 		rrIndex = len(providers) - 1
67 | 		scapegoat = providers[rrIndex]
68 | 	}
69 | 	rrIndex = (rrIndex + 1) % len(providers)
70 | 	// now forward request to scapegoat
71 | 	res, err := requests.ForwardInferenceRequest(scapegoat.PeerID, request)
72 | 	if err != nil {
73 | 		c.JSON(500, gin.H{"error": err.Error()})
74 | 		return
75 | 	}
76 | 	var response InferenceResponse
77 | 	err = json.Unmarshal([]byte(res), &response)
78 | 	if err != nil {
79 | 		c.JSON(500, gin.H{"error": err.Error()})
80 | 		return
81 | 	}
82 | 	c.JSON(200, gin.H{"message": "ok", "data": response.Data})
83 | }
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/server/rpc.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	rpc "github.com/ethereum/go-ethereum/rpc"
5 | )
6 | 
7 | func NewRPCServer() *rpc.Server {
8 | 	workerService := new(WorkerService)
9 | 	rpcServer := rpc.NewServer()
10 | 	err := rpcServer.RegisterName("worker", workerService)
11 | 	if err != nil {
12 | 		panic(err)
13 | 	}
14 | 	return rpcServer
15 | }
16 | 
--------------------------------------------------------------------------------
/third_party/ocf/src/ocf-core/internal/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"net/http"
5 | 	"ocfcore/internal/common"
6 | 	"ocfcore/internal/protocol/p2p"
7 | 	"ocfcore/internal/server/auth"
8 | 	"ocfcore/internal/server/queue"
9 | 	"sync"
10 | 
11 | 	"github.com/gin-gonic/gin"
12 | 	"github.com/spf13/viper"
13 | )
14 | 
15 | func StartServer() {
16 | 	PrintWelcomeMessage()
17 | 	var wg sync.WaitGroup
18 | 	gin.SetMode(gin.ReleaseMode)
19 | 	r := gin.Default()
20 | 	r.Use(beforeResponse())
21 | 	r.Use(gin.Recovery())
22 | 	v1 := r.Group("/api/v1")
23 | 	{
24 | 		ocfcoreStatus := v1.Group("/status")
25 | 		{
26 | 			ocfcoreStatus.GET("/health", healthStatusCheck)
27 | 			ocfcoreStatus.GET("/worker/:workerId/:metric", GetWorkerStatus)
28 | 			ocfcoreStatus.GET("/workers", GetWorkerHub)
29 | 			ocfcoreStatus.GET("/matchmaking", matchmakingStatus)
30 | 			ocfcoreStatus.GET("/summary", GetSummary)
31 | 			ocfcoreStatus.GET("/connections", GetConnections)
32 | 			ocfcoreStatus.GET("/table", GetWorkloadTable)
33 | 			ocfcoreStatus.POST("/peers", UpdatePeers)
34 | 			ocfcoreStatus.GET("/peers", GetPeersInfo)
35 | 		}
36 | 		ocfcoreWs := v1.Group("/ws")
37 | 		{
38 | 			ocfcoreWs.GET("",
39 | 				gin.WrapH(NewRPCServer().WebsocketHandler([]string{"*"})))
40 | 		}
41 | 		ocfcoreProxy := v1.Group("/proxy")
42 | 		{
43 | 			ocfcoreProxy.PATCH("/:peerId/*path", ForwardHandler)
44 | 			ocfcoreProxy.POST("/:peerId/*path", ForwardHandler)
45 | 			ocfcoreProxy.GET("/:peerId/*path", ForwardHandler)
46 | 		}
47 | 		ocfcoreThrottle := v1.Group("/controller")
48 | 		{
49 | 			ocfcoreThrottle.Use(auth.AuthorizeMiddleware())
50 | 
ocfcoreThrottle.POST("/instructions/:workerId", LoadWorkload) 51 | ocfcoreThrottle.GET("/instructions/:workerId", GetWorkloadInstructions) 52 | ocfcoreThrottle.POST("/cluster/nodes", AddClusterNode) 53 | } 54 | ocfcoreRequest := v1.Group("/request") 55 | { 56 | ocfcoreRequest.POST("/inference", AutoInferenceRequest) 57 | ocfcoreRequest.POST("/_inference", InferenceRequest) 58 | } 59 | } 60 | p2plistener := p2p.P2PListener() 61 | go func() { 62 | err := http.Serve(p2plistener, r) 63 | if err != nil { 64 | common.Logger.Error("http.Serve: %s", err) 65 | } 66 | }() 67 | queue.StartQueueServer() 68 | wg.Wait() 69 | err := r.Run("0.0.0.0:" + viper.GetString("port")) 70 | if err != nil { 71 | panic(err) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/throttle.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "ocfcore/internal/common/structs" 5 | ) 6 | 7 | var instructionsHub structs.WorkloadInstructionsHub 8 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/vacuum.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | func DisconnectionDetection() { 4 | // List all connections 5 | // queue.RemoveDisconnectedNode() 6 | } 7 | -------------------------------------------------------------------------------- /third_party/ocf/src/ocf-core/internal/server/welcome.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/common-nighthawk/go-figure" 7 | ) 8 | 9 | func PrintWelcomeMessage() { 10 | myFigure := figure.NewFigure("Open Compute", "isometric1", true) 11 | myFigure.Print() 12 | fmt.Println(">> Join Discord for Discussion: https://discord.gg/3BD3RzK2K2") 13 | fmt.Println(">> Documentation: https://ocf.Anonymous.org") 14 | } 15 | --------------------------------------------------------------------------------
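The route table registered in server.go exposes the inference entry point at POST /api/v1/request/inference, with a JSON body mirroring structs.InferenceStruct (model_name plus params). A minimal client sketch follows; the port and model name are placeholders (the real port comes from the "port" value in cfg.yaml and the model name from whichever workloads are registered), so this is an illustration rather than part of the repository.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Body shape follows structs.InferenceStruct: model_name + params.
	body, err := json.Marshal(map[string]interface{}{
		"model_name": "llama-7b", // placeholder model name
		"params": map[string]interface{}{
			"prompt": "Hello, world",
		},
	})
	if err != nil {
		panic(err)
	}

	// Replace 8092 with the "port" value from cfg.yaml.
	resp, err := http.Post("http://localhost:8092/api/v1/request/inference",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	b, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(b))
}

--------------------------------------------------------------------------------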