├── GlobalModel_generated.py ├── README.md ├── client_data_allocation.py ├── data_wiz ├── 10 │ ├── local_training_0.json │ ├── local_training_1.json │ ├── local_training_2.json │ ├── local_training_3.json │ ├── local_training_4.json │ ├── local_training_5.json │ ├── local_training_6.json │ ├── local_training_7.json │ ├── local_training_8.json │ ├── local_training_9.json │ └── x └── ss ├── download.py ├── fed_utils ├── __init__.py ├── client.py ├── client_participation_scheduling.py ├── evaluation.py ├── hehe ├── model_aggregation.py └── other.py ├── load.py ├── main.py ├── mmlu_test_14042.jsonl ├── mmlu_test_1444.jsonl ├── new.sh ├── requirements.txt ├── run_wiz.sh ├── templates ├── README.md ├── alpaca.json ├── alpaca_legacy.json ├── alpaca_short.json ├── vigogne.json └── xx └── utils ├── README.md ├── __init__.py ├── callbacks.py ├── prompter.py └── ss /GlobalModel_generated.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | 8 | from peft import ( 9 | PeftModel, 10 | LoraConfig, 11 | get_peft_model, 12 | get_peft_model_state_dict, 13 | prepare_model_for_int8_training, 14 | set_peft_model_state_dict, 15 | ) 16 | 17 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer,AutoTokenizer 18 | from utils.callbacks import Iteratorize, Stream 19 | from utils.prompter import Prompter 20 | if torch.cuda.is_available(): 21 | device = "cuda" 22 | else: 23 | device = "cpu" 24 | 25 | try: 26 | if torch.backends.mps.is_available(): 27 | device = "mps" 28 | except: 29 | pass 30 | 31 | 32 | def main( 33 | load_8bit: bool = False, 34 | base_model: str = "", 35 | lora_weights_path: str = "", 36 | lora_config_path: str= "", # provide only the file path, excluding the file name 'adapter_config.json' 37 | prompt_template: str = "", # The prompt template to use, will default to alpaca. 38 | server_name: str = "127.0.0.1", 39 | share_gradio: bool = False, 40 | ): 41 | base_model = base_model or os.environ.get("BASE_MODEL", "") 42 | assert ( 43 | base_model 44 | ), "Please specify a --base_model, e.g. 
--base_model='huggyllama/llama-7b'" 45 | 46 | prompter = Prompter(prompt_template) 47 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 48 | if not lora_weights_path.endswith(".bin"): 49 | if device == "cuda": 50 | model = LlamaForCausalLM.from_pretrained( 51 | base_model, 52 | load_in_8bit=load_8bit, 53 | torch_dtype=torch.float16, 54 | device_map="auto", 55 | ) 56 | model = PeftModel.from_pretrained( 57 | model, 58 | lora_weights_path, 59 | torch_dtype=torch.float16, 60 | ) 61 | elif device == "mps": 62 | model = LlamaForCausalLM.from_pretrained( 63 | base_model, 64 | device_map={"": device}, 65 | torch_dtype=torch.float16, 66 | ) 67 | model = PeftModel.from_pretrained( 68 | model, 69 | lora_weights_path, 70 | device_map={"": device}, 71 | torch_dtype=torch.float16, 72 | ) 73 | else: 74 | model = LlamaForCausalLM.from_pretrained( 75 | base_model, device_map={"": device}, low_cpu_mem_usage=True 76 | ) 77 | model = PeftModel.from_pretrained( 78 | model, 79 | lora_weights_path, 80 | device_map={"": device}, 81 | ) 82 | else: 83 | model = LlamaForCausalLM.from_pretrained( 84 | base_model, 85 | load_in_8bit=True, 86 | torch_dtype=torch.float16, 87 | device_map="auto", 88 | ) 89 | model = prepare_model_for_int8_training(model) 90 | config = LoraConfig.from_pretrained(lora_config_path) 91 | lora_weights = torch.load(lora_weights_path) 92 | model = PeftModel(model, config) 93 | set_peft_model_state_dict(model,lora_weights,"default") 94 | del lora_weights 95 | 96 | 97 | # unwind broken decapoda-research config 98 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 99 | model.config.bos_token_id = 1 100 | model.config.eos_token_id = 2 101 | 102 | if not load_8bit: 103 | model.half() # seems to fix bugs for some users. 104 | 105 | model.eval() 106 | 107 | 108 | def evaluate( 109 | instruction, 110 | input=None, 111 | temperature=0.1, 112 | top_p=0.75, 113 | top_k=40, 114 | num_beams=4, 115 | max_new_tokens=128, 116 | stream_output=True, 117 | **kwargs, 118 | ): 119 | prompt = prompter.generate_prompt(instruction, input) 120 | inputs = tokenizer(prompt, return_tensors="pt") 121 | input_ids = inputs["input_ids"].to(device) 122 | generation_config = GenerationConfig( 123 | temperature=temperature, 124 | top_p=top_p, 125 | top_k=top_k, 126 | num_beams=num_beams, 127 | **kwargs, 128 | ) 129 | 130 | generate_params = { 131 | "input_ids": input_ids, 132 | "generation_config": generation_config, 133 | "return_dict_in_generate": True, 134 | "output_scores": True, 135 | "max_new_tokens": max_new_tokens, 136 | } 137 | 138 | if stream_output: 139 | # Stream the reply 1 token at a time. 140 | # This is based on the trick of using 'stopping_criteria' to create an iterator, 141 | # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243. 
142 | 
143 |             def generate_with_callback(callback=None, **kwargs):
144 |                 kwargs.setdefault(
145 |                     "stopping_criteria", transformers.StoppingCriteriaList()
146 |                 )
147 |                 kwargs["stopping_criteria"].append(
148 |                     Stream(callback_func=callback)
149 |                 )
150 |                 with torch.no_grad():
151 |                     model.generate(**kwargs)
152 | 
153 |             def generate_with_streaming(**kwargs):
154 |                 return Iteratorize(
155 |                     generate_with_callback, kwargs, callback=None
156 |                 )
157 | 
158 |             with generate_with_streaming(**generate_params) as generator:
159 |                 for output in generator:
160 |                     # new_tokens = len(output) - len(input_ids[0])
161 |                     decoded_output = tokenizer.decode(output)
162 | 
163 |                     if output[-1] in [tokenizer.eos_token_id]:
164 |                         break
165 | 
166 |                     yield prompter.get_response(decoded_output)
167 |             return  # early return for stream_output
168 | 
169 |         # Without streaming
170 |         with torch.no_grad():
171 |             generation_output = model.generate(
172 |                 input_ids=input_ids,
173 |                 generation_config=generation_config,
174 |                 return_dict_in_generate=True,
175 |                 output_scores=True,
176 |                 max_new_tokens=max_new_tokens,
177 |             )
178 |         s = generation_output.sequences[0]
179 |         output = tokenizer.decode(s)
180 |         yield prompter.get_response(output)
181 | 
182 |     shepherd_UI = gr.Interface(
183 |         fn=evaluate,
184 |         inputs=[
185 |             gr.components.Textbox(
186 |                 lines=2,
187 |                 label="Instruction",
188 |                 placeholder="Tell me about alpacas.",
189 |             ),
190 |             gr.components.Textbox(lines=2, label="Input", placeholder="none"),
191 |             gr.components.Slider(
192 |                 minimum=0, maximum=1, value=0.1, label="Temperature"
193 |             ),
194 |             gr.components.Slider(
195 |                 minimum=0, maximum=1, value=0.75, label="Top p"
196 |             ),
197 |             gr.components.Slider(
198 |                 minimum=0, maximum=100, step=1, value=40, label="Top k"
199 |             ),
200 |             gr.components.Slider(
201 |                 minimum=1, maximum=4, step=1, value=4, label="Beams"
202 |             ),
203 |             gr.components.Slider(
204 |                 minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
205 |             ),
206 |             gr.components.Checkbox(label="Stream output"),
207 |         ],
208 |         outputs=[
209 |             gr.components.Textbox(
210 |                 lines=5,
211 |                 label="Output",
212 |             )
213 |         ],
214 |         title="FederatedGPT-shepherd",
215 |         description="Shepherd is an LLM that has been fine-tuned in a federated manner.",
216 |     ).queue()
217 | 
218 |     shepherd_UI.launch(server_name=server_name, share=share_gradio)
219 | 
220 | 
221 | 
222 | 
223 | 
224 | if __name__ == "__main__":
225 |     fire.Fire(main)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FLoRA: Federated Fine-Tuning Large Language Models with Heterogeneous Low-Rank Adaptations
2 | Code for the paper [FLoRA: Federated Fine-Tuning Large Language Models with Heterogeneous Low-Rank Adaptations](https://arxiv.org/pdf/2409.05976).
3 | 
4 | You can use this code to fine-tune LLMs with LoRA on the WizardLLM dataset or other datasets.
5 | The supported LoRA fine-tuning methods are FLoRA, FedIT, and Zero-Padding. You can also use heterogeneous LoRA rank settings in FLoRA and Zero-Padding.
6 | 
7 | ## Requirements
8 | Install all the packages from requirements.txt:
9 | * pip install -r requirements.txt
10 | * git clone https://github.com/EleutherAI/lm-evaluation-harness
11 | * cd lm-evaluation-harness
12 | * pip install -e .
13 | 
14 | ## Data
15 | * The training dataset of WizardLLM has already been downloaded and split into the ./data_wiz/ folder.
16 | * If you want to use your own dataset, use the same format as ./data_wiz/ (see the example below).
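Each local_training_<client_id>.json file is read by main.py as a JSON list of records with "instruction" and "output" fields. A minimal illustrative example (the field names come from the code; the values below are made up):

```json
[
  {
    "instruction": "Give three tips for staying healthy.",
    "output": "1. Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."
  }
]
```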
17 | 
18 | ## Running the experiments
19 | * To run the FLoRA algorithm (--stacking True) and FedIT (--stacking False) in a homogeneous LoRA setting:
20 | ```
21 | python main.py --global_model 'huggyllama/llama-7b' --data_path "./data_wiz" --output_dir './FloRA-llama7b-wiz-homo/' --num_communication_rounds 3 --local_num_epochs 1 --stacking True
22 | python main.py --global_model 'huggyllama/llama-7b' --data_path "./data_wiz" --output_dir './FedIT-llama7b-wiz-homo/' --num_communication_rounds 3 --local_num_epochs 1 --stacking False
23 | ```
24 | * To run the FLoRA algorithm (--stacking True) and Zero-Padding (--stacking False --zero_padding True) in a heterogeneous LoRA setting:
25 | ```
26 | python main.py --global_model 'huggyllama/llama-7b' --data_path "./data_wiz" --output_dir './FloRA-llama7b-wiz-heter/' --num_communication_rounds 3 --local_num_epochs 1 --stacking True --heter True
27 | python main.py --global_model 'huggyllama/llama-7b' --data_path "./data_wiz" --output_dir './FedIT-llama7b-wiz-heter/' --num_communication_rounds 3 --local_num_epochs 1 --stacking False --heter True --zero_padding True
28 | ```
29 | 
30 | * To evaluate with the LM Evaluation Harness, try:
31 | ```
32 | lm_eval --model_args pretrained=./FloRA-llama7b-wiz-homo/,parallelize=True,load_in_4bit=False, --tasks mmlu --num_fewshot 5 --batch_size 16 --output_path ../FloRA-llama7b-wiz-homo/
33 | ```
34 | * To evaluate on MT-Bench, please follow the instructions on its website: https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge
35 | -----
36 | 
--------------------------------------------------------------------------------
/client_data_allocation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pandas as pd
3 | import numpy as np
4 | import random
5 | import os
6 | import json
7 | import pdb
8 | 
9 | num_clients = int(sys.argv[1])
10 | diff_quantity = int(sys.argv[2])
11 | 
12 | np.random.seed(42)
13 | random.seed(42)
14 | 
15 | # Divide the entire dataset into a training set and a test set.
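# Ten examples are sampled from each category to form a held-out global test set (global_test.json);
# the remaining records become the global training pool (global_training.json), which is then
# partitioned across the clients further below.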
16 | 17 | df = pd.read_json("new-databricks-dolly-15k.json", orient='records') 18 | sorted_df = df.sort_values(by=['category']) 19 | grouped = sorted_df.groupby('category') 20 | sampled_df = grouped.apply(lambda x: x.sample(n=10)) 21 | sampled_df = sampled_df.reset_index(level=0, drop=True) 22 | remaining_df = sorted_df.drop(index=sampled_df.index) 23 | 24 | sampled_df = sampled_df.reset_index().drop('index', axis=1) 25 | remaining_df = remaining_df.reset_index().drop('index', axis=1) 26 | data_path = os.path.join("data", str(num_clients)) 27 | 28 | os.makedirs(data_path,exist_ok=True) 29 | 30 | remaining_df_dic = remaining_df.to_dict(orient='records') 31 | with open(os.path.join(data_path, "global_training.json"), 'w') as outfile: 32 | json.dump(remaining_df_dic, outfile) 33 | 34 | sampled_df_dic = sampled_df.to_dict(orient='records') 35 | with open(os.path.join(data_path, "global_test.json"), 'w') as outfile: 36 | json.dump(sampled_df_dic, outfile) 37 | 38 | # Partition the global training data into smaller subsets for each client's local training dataset 39 | 40 | if diff_quantity: 41 | min_size = 0 42 | min_require_size = 40 43 | alpha = 0.5 44 | 45 | N = len(remaining_df) 46 | net_dataidx_map = {} 47 | category_uniques = remaining_df['category'].unique().tolist() 48 | while min_size < min_require_size: 49 | 50 | idx_partition = [[] for _ in range(num_clients)] 51 | for k in range(len(category_uniques)): 52 | category_rows_k = remaining_df.loc[remaining_df['category'] == category_uniques[k]] 53 | category_rows_k_index = category_rows_k.index.values 54 | np.random.shuffle(category_rows_k_index) 55 | proportions = np.random.dirichlet(np.repeat(alpha, num_clients)) 56 | proportions = np.array([p * (len(idx_j) < N / num_clients) for p, idx_j in zip(proportions, idx_partition)]) 57 | proportions = proportions / proportions.sum() 58 | proportions = (np.cumsum(proportions) * len(category_rows_k_index)).astype(int)[:-1] 59 | idx_partition = [idx_j + idx.tolist() for idx_j, idx in 60 | zip(idx_partition, np.split(category_rows_k_index, proportions))] 61 | min_size = min([len(idx_j) for idx_j in idx_partition]) 62 | 63 | print(min_size) 64 | 65 | 66 | else: 67 | num_shards_per_clients = 2 68 | remaining_df_index = remaining_df.index.values 69 | shards = np.array_split(remaining_df_index, int(num_shards_per_clients * num_clients)) 70 | random.shuffle(shards) 71 | 72 | shards = [shards[i:i + num_shards_per_clients] for i in range(0, len(shards), num_shards_per_clients)] 73 | idx_partition = [np.concatenate(shards[n]).tolist() for n in range(num_clients)] 74 | 75 | 76 | for client_id, idx in enumerate(idx_partition): 77 | print( 78 | "\n Generating the local training dataset of Client_{}".format(client_id) 79 | ) 80 | sub_remaining_df = remaining_df.loc[idx] 81 | sub_remaining_df = sub_remaining_df.reset_index().drop('index', axis=1) 82 | sub_remaining_df_dic = sub_remaining_df.to_dict(orient='records') 83 | 84 | with open(os.path.join(data_path, "local_training_{}.json".format(client_id)), 'w') as outfile: 85 | json.dump(sub_remaining_df_dic, outfile) 86 | -------------------------------------------------------------------------------- /data_wiz/10/x: -------------------------------------------------------------------------------- 1 | xx 2 | -------------------------------------------------------------------------------- /data_wiz/ss: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | 3 | snapshot_download(repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", local_dir = 'tinyllama') 4 | #snapshot_download(repo_id="huggyllama/llama-13b", local_dir = 'llama-13b') 5 | -------------------------------------------------------------------------------- /fed_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_aggregation import FedAvg 2 | from .client_participation_scheduling import client_selection 3 | from .client import GeneralClient 4 | from .evaluation import global_evaluation 5 | from .other import other_function -------------------------------------------------------------------------------- /fed_utils/client.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import os 3 | from datasets import load_dataset 4 | import copy 5 | from collections import OrderedDict 6 | import torch 7 | from peft import ( 8 | get_peft_model_state_dict, 9 | set_peft_model_state_dict, 10 | ) 11 | 12 | class GeneralClient: 13 | def __init__(self, client_id, model, data_path, output_dir): 14 | self.client_id = client_id 15 | self.model = model 16 | self.local_data_path = os.path.join(data_path, "local_training_{}.json".format(self.client_id)) 17 | self.local_data = load_dataset("json", data_files=self.local_data_path) 18 | self.output_dir = output_dir 19 | self.local_output_dir = os.path.join(self.output_dir, "trainer_saved", "local_output_{}".format(self.client_id)) 20 | 21 | def preprare_local_dataset(self, generate_and_tokenize_prompt, local_val_set_size): 22 | if local_val_set_size > 0: 23 | local_train_val = self.local_data["train"].train_test_split( 24 | test_size=local_val_set_size, shuffle=True, seed=42 25 | ) 26 | self.local_train_dataset = ( 27 | local_train_val["train"].shuffle().map(generate_and_tokenize_prompt) 28 | ) 29 | self.local_eval_dataset = ( 30 | local_train_val["test"].shuffle().map(generate_and_tokenize_prompt) 31 | ) 32 | else: 33 | self.local_train_dataset = self.local_data["train"].shuffle().map(generate_and_tokenize_prompt) 34 | self.local_eval_dataset = None 35 | self.local_val_set_size = local_val_set_size 36 | 37 | def build_local_trainer(self, 38 | tokenizer, 39 | local_micro_batch_size, 40 | gradient_accumulation_steps, 41 | local_num_epochs, 42 | local_learning_rate, 43 | group_by_length, 44 | ddp): 45 | self.train_args = transformers.TrainingArguments( 46 | per_device_train_batch_size=local_micro_batch_size, 47 | gradient_accumulation_steps=gradient_accumulation_steps, 48 | warmup_steps=0, 49 | num_train_epochs=local_num_epochs, 50 | learning_rate=local_learning_rate, 51 | fp16=True, 52 | logging_steps=1, 53 | optim="adamw_torch", 54 | evaluation_strategy="steps" if self.local_val_set_size > 0 else "no", 55 | save_strategy="steps", 56 | eval_steps=200 if self.local_val_set_size > 0 else None, 57 | save_steps=5000000, 58 | output_dir=self.local_output_dir, 59 | save_total_limit=1, 60 | load_best_model_at_end=True if self.local_val_set_size > 0 else False, 61 | ddp_find_unused_parameters=False if ddp else None, 62 | group_by_length=group_by_length, 63 | dataloader_drop_last=False 64 | ) 65 | self.local_trainer = transformers.Trainer(model=self.model, 66 | train_dataset=self.local_train_dataset, 67 | 
eval_dataset=self.local_eval_dataset, 68 | args=self.train_args, 69 | data_collator=transformers.DataCollatorForSeq2Seq( 70 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 71 | ), 72 | ) 73 | 74 | def initiate_local_training(self): 75 | self.model.config.use_cache = False 76 | self.params_dict_old = copy.deepcopy( 77 | OrderedDict((name, param.detach()) for name, param in self.model.named_parameters() if 78 | "default" in name)) 79 | self.params_dict_new = OrderedDict((name, param.detach()) for name, param in self.model.named_parameters() if 80 | "default" in name) 81 | self.model.state_dict = ( 82 | lambda instance, *_, **__: get_peft_model_state_dict( 83 | instance, self.params_dict_new, "default" 84 | ) 85 | ).__get__(self.model, type(self.model)) 86 | 87 | def train(self): 88 | self.local_trainer.train() 89 | 90 | def terminate_local_training(self, epoch, local_dataset_len_dict, previously_selected_clients_set): 91 | 92 | local_dataset_len_dict[self.client_id] = len(self.local_train_dataset) 93 | new_adapter_weight = self.model.state_dict() 94 | single_output_dir = os.path.join(self.output_dir, str(epoch), "local_output_{}".format(self.client_id)) 95 | os.makedirs(single_output_dir, exist_ok=True) 96 | torch.save(new_adapter_weight, single_output_dir + "/pytorch_model.bin") 97 | 98 | older_adapter_weight = get_peft_model_state_dict(self.model, self.params_dict_old, "default") 99 | set_peft_model_state_dict(self.model, older_adapter_weight, "default") 100 | previously_selected_clients_set = previously_selected_clients_set | set({self.client_id}) 101 | last_client_id = self.client_id 102 | 103 | return self.model, local_dataset_len_dict, previously_selected_clients_set, last_client_id 104 | -------------------------------------------------------------------------------- /fed_utils/client_participation_scheduling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def client_selection(num_clients, client_selection_frac, client_selection_strategy, other_info=None): 5 | np.random.seed(other_info) 6 | if client_selection_strategy == "random": 7 | num_selected = max(int(client_selection_frac * num_clients), 1) 8 | selected_clients_set = set(np.random.choice(np.arange(num_clients), num_selected, replace=False)) 9 | 10 | return selected_clients_set 11 | -------------------------------------------------------------------------------- /fed_utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from tqdm import tqdm 4 | import fire 5 | import torch 6 | import datasets 7 | from transformers import GenerationConfig 8 | import json 9 | import csv 10 | from peft import set_peft_model_state_dict 11 | import numpy as np 12 | import random 13 | 14 | model_type = 'llama' 15 | datasets.utils.logging.set_verbosity_error() 16 | device_map = "auto" 17 | max_new_token: int = 32 18 | verbose: bool = False 19 | 20 | # 设置随机数种子 21 | def setup_seed(seed): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | 28 | setup_seed(1) 29 | 30 | def global_evaluation(model, tokenizer, prompter, dev_data_path): 31 | data_class = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 
'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions'] 32 | right_count_dict = dict.fromkeys(data_class, 0) 33 | total_count_dict = dict.fromkeys(data_class, 0) 34 | acc_count_dict = dict.fromkeys(data_class, 0) 35 | with open(dev_data_path, 'r') as f: 36 | test_set = json.load(f) 37 | count=0 38 | 39 | if model_type == 'llama': 40 | sampling = GenerationConfig( 41 | do_sample=True, 42 | temperature=0.2, 43 | top_p=0.6, 44 | top_k=30, 45 | num_beams=1, 46 | max_new_tokens=max_new_token, 47 | early_stopping=True, 48 | ) 49 | 50 | if model_type == 'gpt2': 51 | sampling = GenerationConfig( 52 | bos_token_id = 50256, 53 | eos_token_id = 50256, 54 | _from_model_config = True, 55 | ) 56 | 57 | for data_point in tqdm(test_set): 58 | count +=1 59 | target = data_point["output"] 60 | class_test_set = data_point["class"] 61 | 62 | tgt_ans_idx = target.replace('The answer is: ','').split('. ')[0] 63 | tgt_ans = target.replace('The answer is: ','').split('. ')[1] 64 | 65 | test_prompt = prompter.generate_prompt( 66 | data_point["instruction"], 67 | data_point["input"], 68 | 'The answer is: ', 69 | ) 70 | 71 | with torch.autocast("cuda"): 72 | inputs = tokenizer(test_prompt, return_tensors="pt") 73 | input =inputs["input_ids"].to('cuda') 74 | with torch.no_grad(): 75 | #print(tokenizer.eos_token_id, tokenizer.pad_token_id) 76 | generation_output = model.generate( 77 | input_ids=input, 78 | generation_config=sampling, 79 | return_dict_in_generate=True, 80 | output_scores=True, 81 | max_new_tokens=max_new_token, 82 | pad_token_id=tokenizer.eos_token_id 83 | ) 84 | generation_output_decoded = tokenizer.decode(generation_output.sequences[0]) 85 | # print(generation_output_decoded) 86 | split = prompter.template["response_split"] 87 | ans = generation_output_decoded.split(split)[-1].strip() 88 | if verbose: 89 | print('-------------------') 90 | print(test_prompt) 91 | print(tgt_ans) 92 | print(tgt_ans_idx) 93 | print(ans) 94 | if tgt_ans_idx+'.' in ans or tgt_ans in ans: 95 | # if tgt_ans_idx in ans or tgt_ans in ans: 96 | right_count_dict[class_test_set] += 1 97 | total_count_dict[class_test_set] += 1 98 | 99 | mean_acc = 0. 
100 | 101 | for key in acc_count_dict.keys(): 102 | tmp = right_count_dict[key]/total_count_dict[key] 103 | mean_acc += tmp 104 | acc_count_dict[key] = tmp 105 | mean_acc /= len(acc_count_dict.keys()) 106 | csv_data = [right_count_dict, total_count_dict, acc_count_dict] 107 | 108 | '''with open(os.path.join('/ai4bio-store/junbo.li/data_selection/alpaca-lora/raw_dict_mmlu',data_path.split('/')[-1].replace('.json','') + '.csv'), 'w', newline='') as file: 109 | writer = csv.DictWriter(file, fieldnames=right_count_dict.keys()) 110 | writer.writeheader() 111 | for row in csv_data: 112 | writer.writerow(row)''' 113 | if verbose: 114 | print(right_count_dict) 115 | #print(total_count_dict) 116 | print('Acc: ', acc_count_dict) 117 | print() 118 | #score = eval_usmle(model, dev_data_path, tokenizer, verbose=False) 119 | print('========== Accuracy ==========') 120 | print(mean_acc) 121 | 122 | return mean_acc 123 | 124 | #model = LlamaForCausalLM.from_pretrained( 125 | #model = AutoModelForCausalLM.from_pretrained( 126 | #tokenizer = LlamaTokenizer.from_pretrained('linhvu/decapoda-research-llama-7b-hf') 127 | #tokenizer = AutoTokenizer.from_pretrained('gpt2') 128 | #tokenizer.pad_token_id = tokenizer.eos_token_id 129 | #print(tokenizer.pad_token_id, tokenizer.eos_token_id) 130 | '''tokenizer.pad_token_id = ( 131 | 0 132 | ) 133 | tokenizer.padding_side = "left"''' 134 | # = Prompter("alpaca") 135 | 136 | '''for id in range(1, 10): 137 | single_weights = torch.load('./lora-shepherd-7b-autolora-1-4/10/0/local_output_{}/'.format(id)) 138 | set_peft_model_state_dict(model_c, single_weights, "default") 139 | for param_tensor in model_c.state_dict(): 140 | model.state_dict[param_tensor] += model_c.state_dict[param_tensor] 141 | 142 | for param_tensor in model.state_dict(): 143 | model.state_dict[param_tensor] = model.state_dict[param_tensor]/10.0''' 144 | 145 | '''with open(count_fine_path, "a") as file: 146 | file.write(str({"dataset_name": data_path.split('/')[-1], "accuracy": score})+'\n')''' -------------------------------------------------------------------------------- /fed_utils/hehe: -------------------------------------------------------------------------------- 1 | x 2 | -------------------------------------------------------------------------------- /fed_utils/model_aggregation.py: -------------------------------------------------------------------------------- 1 | from peft import ( 2 | set_peft_model_state_dict, 3 | ) 4 | import torch 5 | import os 6 | from torch.nn.functional import normalize 7 | from torch.nn import ZeroPad2d 8 | 9 | def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch, stacking, lora_r, heter, local_ranks, zero_padding, full): 10 | weights_array = normalize( 11 | torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set], 12 | dtype=torch.float32), 13 | p=1, dim=0) 14 | 15 | print("Weights:", weights_array) 16 | for k, client_id in enumerate(selected_clients_set): 17 | single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id), 18 | "pytorch_model.bin") 19 | single_weights = torch.load(single_output_dir, map_location = 'cpu') 20 | #print(single_weights) 21 | #print("y") 22 | x = 0 23 | if full: 24 | if k == 0: 25 | weighted_single_weights = single_weights 26 | for key in weighted_single_weights.keys(): 27 | weighted_single_weights[key] = weighted_single_weights[key] * (weights_array[k]) 28 | else: 29 | for key in single_weights.keys(): 30 | weighted_single_weights[key] += 
single_weights[key] * (weights_array[k]) 31 | 32 | else: 33 | if stacking: 34 | if zero_padding: 35 | max_lora = max(local_ranks) 36 | if k == 0: 37 | weighted_single_weights = single_weights 38 | for key in weighted_single_weights.keys(): 39 | if single_weights[key].shape[0] == local_ranks[client_id]: 40 | pad = ZeroPad2d(padding=(0, 0, 0, max_lora-local_ranks[client_id])) 41 | weighted_single_weights[key] = pad(weighted_single_weights[key]) * (weights_array[k]) 42 | elif single_weights[key].shape[1] == local_ranks[client_id]: 43 | pad = ZeroPad2d(padding=(0, max_lora-local_ranks[client_id], 0, 0)) 44 | weighted_single_weights[key] = pad(weighted_single_weights[key]) * (weights_array[k]) 45 | else: 46 | for key in single_weights.keys(): 47 | #print(single_weights[key].shape) 48 | if single_weights[key].shape[0] == local_ranks[client_id]: 49 | pad = ZeroPad2d(padding=(0, 0, 0, max_lora-local_ranks[client_id])) 50 | single_weights[key] = pad(single_weights[key]) * (weights_array[k]) 51 | weighted_single_weights[key] += single_weights[key] 52 | elif single_weights[key].shape[1] == local_ranks[client_id]: 53 | pad = ZeroPad2d(padding=(0, max_lora-local_ranks[client_id], 0, 0)) 54 | single_weights[key] = pad(single_weights[key]) * (weights_array[k]) 55 | #print(single_weights[key][255,32]) 56 | weighted_single_weights[key] += single_weights[key] 57 | 58 | else: 59 | if k == 0: 60 | weighted_single_weights = single_weights 61 | for key in weighted_single_weights.keys(): 62 | #weighted_single_weights[key] = weighted_single_weights[key] * (weights_array[k]) 63 | #print(weighted_single_weights[key].shape) 64 | if heter: 65 | x += 1 66 | if weighted_single_weights[key].shape[0] == local_ranks[client_id]: 67 | weighted_single_weights[key] = weighted_single_weights[key] * (weights_array[k] * 1) 68 | else: 69 | if weighted_single_weights[key].shape[0] == lora_r: 70 | weighted_single_weights[key] = weighted_single_weights[key] * (weights_array[k] * 1) 71 | 72 | else: 73 | for key in single_weights.keys(): 74 | if heter: 75 | x += 1 76 | if single_weights[key].shape[0] == local_ranks[client_id]: 77 | new = [weighted_single_weights[key], single_weights[key] * (weights_array[k]) * 1] 78 | weighted_single_weights[key] = torch.cat(new, dim=0) 79 | else: 80 | if single_weights[key].shape[0] == lora_r: 81 | new = [weighted_single_weights[key], single_weights[key] * (weights_array[k]) * 1] 82 | weighted_single_weights[key] = torch.cat(new, dim=0) 83 | 84 | if heter: 85 | if single_weights[key].shape[1] == local_ranks[client_id]: 86 | new = [weighted_single_weights[key], single_weights[key]]# * (weights_array[k])] 87 | weighted_single_weights[key] = torch.cat(new, dim=1) 88 | else: 89 | if single_weights[key].shape[1] == lora_r: 90 | new = [weighted_single_weights[key], single_weights[key]]# * (weights_array[k])] 91 | weighted_single_weights[key] = torch.cat(new, dim=1) 92 | 93 | else: 94 | if zero_padding: 95 | max_lora = max(local_ranks) 96 | if k == 0: 97 | weighted_single_weights = single_weights 98 | for key in weighted_single_weights.keys(): 99 | if single_weights[key].shape[0] == local_ranks[client_id]: 100 | pad = ZeroPad2d(padding=(0, 0, 0, max_lora-local_ranks[client_id])) 101 | weighted_single_weights[key] = pad(weighted_single_weights[key]) * (weights_array[k]) 102 | elif single_weights[key].shape[1] == local_ranks[client_id]: 103 | pad = ZeroPad2d(padding=(0, max_lora-local_ranks[client_id], 0, 0)) 104 | weighted_single_weights[key] = pad(weighted_single_weights[key]) * (weights_array[k]) 105 | 
else:
106 |                         for key in single_weights.keys():
107 |                             #print(single_weights[key].shape)
108 |                             if single_weights[key].shape[0] == local_ranks[client_id]:
109 |                                 pad = ZeroPad2d(padding=(0, 0, 0, max_lora-local_ranks[client_id]))
110 |                                 single_weights[key] = pad(single_weights[key]) * (weights_array[k])
111 |                                 weighted_single_weights[key] += single_weights[key]
112 |                             elif single_weights[key].shape[1] == local_ranks[client_id]:
113 |                                 pad = ZeroPad2d(padding=(0, max_lora-local_ranks[client_id], 0, 0))
114 |                                 single_weights[key] = pad(single_weights[key]) * (weights_array[k])
115 |                                 #print(single_weights[key][255,32])
116 |                                 weighted_single_weights[key] += single_weights[key]
117 |                 else:
118 |                     if k == 0:
119 |                         weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
120 |                                                    single_weights.keys()}
121 |                     else:
122 |                         weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
123 |                                                    for key in
124 |                                                    single_weights.keys()}
125 | 
126 | 
127 |     if stacking:
128 |         torch.save(weighted_single_weights, os.path.join(output_dir, str(epoch), "adapter_model.bin"))
129 |         return model
130 |     elif full:
131 |         torch.save(weighted_single_weights, os.path.join(output_dir, str(epoch), "pytorch_model.bin"))
132 |         model.load_state_dict(weighted_single_weights)
133 |         return model
134 |     else:
135 |         set_peft_model_state_dict(model, weighted_single_weights, "default")
136 |         return model
137 | 
--------------------------------------------------------------------------------
/fed_utils/other.py:
--------------------------------------------------------------------------------
1 | def other_function():
2 | 
3 |     return print("design the other functions you need")
4 | 
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/load.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, AutoConfig
2 | import torch
3 | 
4 | model = LlamaForCausalLM.from_pretrained(
5 |     'tinyllama',
6 |     load_in_8bit=False,
7 |     torch_dtype=torch.float32,
8 |     token="your_token",  # replace with your own Hugging Face access token
9 | )
10 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
3 | from typing import List
4 | from tqdm import tqdm
5 | import fire
6 | import torch
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, AutoConfig
8 | from peft import (
9 |     LoraConfig,
10 |     get_peft_model,
11 |     prepare_model_for_int8_training,
12 |     PeftModel,
13 |     AdaLoraConfig,
14 |     AdaLoraModel,
15 | )
16 | from fed_utils import FedAvg, client_selection, global_evaluation, GeneralClient
17 | import datasets
18 | from utils.prompter import Prompter
19 | import numpy as np
20 | import random
21 | import copy
22 | 
23 | def fl_finetune(
24 |         # model/data params
25 |         global_model: str = 'huggyllama/llama-7b',
26 |         data_path: str = './data',
27 |         output_dir: str = './fedgpt-llama7b-5-2/',
28 |         # FL hyperparams
29 |         client_selection_strategy: str = 'random',
30 |         client_selection_frac: float = 1,
31 |         num_communication_rounds: int = 5,
32 |         num_clients: int = 10,
33 |         # Local training hyperparams
34 |         local_batch_size: int = 128, # 64,
35 | 
local_micro_batch_size: int = 16, 36 | local_num_epochs: int = 3, 37 | local_learning_rate: float = 3e-4, 38 | local_val_set_size: int = 0, 39 | local_save_steps: int = 3, 40 | cutoff_len: int = 512, 41 | # LoRA hyperparams 42 | lora_r: int = 16, 43 | lora_alpha: int = 32, 44 | lora_dropout: float = 0.05, 45 | lora_target_modules: List[str] = [ 46 | "q_proj", 47 | "v_proj", 48 | ], 49 | # llm hyperparams 50 | train_on_inputs: bool = True, 51 | group_by_length: bool = False, 52 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 53 | prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca. 54 | # aggregation mode 55 | stacking: bool = False, 56 | # evaluation 57 | dev_data_path: str = './mmlu_test_1444.jsonl', 58 | # heterogeneous 59 | heter: bool = False, 60 | local_ranks: List[int] = [64, 32, 16, 16, 8, 8, 4, 4, 4, 4], 61 | zero_padding: bool = False, 62 | Adalora: bool = False, 63 | full: bool = False 64 | ): 65 | if int(os.environ.get("LOCAL_RANK", 0)) == 0: 66 | print( 67 | f"Federated Finetuning LLM-LoRA with params:\n" 68 | f"global_model: {global_model}\n" 69 | f"data_path: {data_path}\n" 70 | f"output_dir: {output_dir}\n" 71 | f"client_selection_strategy: {client_selection_strategy}\n" 72 | f"client_selection_frac: {client_selection_frac}\n" 73 | f"num_communication_rounds: {num_communication_rounds}\n" 74 | f"num_clients: {num_clients}\n" 75 | f"local_batch_size: {local_batch_size}\n" 76 | f"local_micro_batch_size: {local_micro_batch_size}\n" 77 | f"local_num_epochs: {local_num_epochs}\n" 78 | f"local_learning_rate: {local_learning_rate}\n" 79 | f"local_val_set_size: {local_val_set_size}\n" 80 | f"local_save_steps: {local_save_steps}\n" 81 | f"cutoff_len: {cutoff_len}\n" 82 | f"lora_r: {lora_r}\n" 83 | f"lora_alpha: {lora_alpha}\n" 84 | f"lora_dropout: {lora_dropout}\n" 85 | f"lora_target_modules: {lora_target_modules}\n" 86 | f"train_on_inputs: {train_on_inputs}\n" 87 | f"group_by_length: {group_by_length}\n" 88 | f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" 89 | f"prompt template: {prompt_template_name}\n" 90 | ) 91 | assert ( 92 | global_model 93 | ), "Please specify a --global_model, e.g. 
--global_modell='decapoda-research/llama-7b-hf'" 94 | 95 | data_path = os.path.join(data_path, str(num_clients)) 96 | assert (os.path.exists(data_path), "Please generate the data files for each client") 97 | 98 | # set up the global model & toknizer 99 | gradient_accumulation_steps = local_batch_size // local_micro_batch_size 100 | prompter = Prompter(prompt_template_name) 101 | device_map = "auto" 102 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 103 | ddp = world_size != 1 104 | if ddp: 105 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 106 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 107 | 108 | if global_model == 'gpt2': 109 | model = GPT2LMHeadModel.from_pretrained( 110 | global_model, 111 | load_in_8bit=False, 112 | torch_dtype=torch.float32, 113 | device_map=device_map, 114 | ) 115 | elif global_model == 'google/gemma-2b' or global_model == 'google/gemma-7b': 116 | model = AutoModelForCausalLM.from_pretrained( 117 | global_model, 118 | load_in_8bit=False, 119 | torch_dtype=torch.float32, 120 | device_map=device_map, 121 | token='your token', 122 | ) 123 | else: 124 | model = LlamaForCausalLM.from_pretrained( 125 | global_model, 126 | load_in_8bit=False, 127 | torch_dtype=torch.float32, 128 | device_map=device_map, 129 | token="your token", 130 | ) 131 | 132 | if global_model == 'gpt2': 133 | tokenizer = GPT2Tokenizer.from_pretrained(global_model) 134 | elif global_model == 'google/gemma-2b' or global_model == 'google/gemma-7b': 135 | tokenizer = AutoTokenizer.from_pretrained(global_model, token='your_token',) 136 | else: 137 | tokenizer = LlamaTokenizer.from_pretrained(global_model, token="your_token",) 138 | 139 | tokenizer.pad_token_id = ( 140 | 0 141 | ) 142 | tokenizer.padding_side = "left" 143 | 144 | def tokenize(prompt, add_eos_token=True): 145 | result = tokenizer( 146 | prompt, 147 | truncation=True, 148 | max_length=cutoff_len, 149 | padding=False, 150 | return_tensors=None, 151 | ) 152 | if ( 153 | result["input_ids"][-1] != tokenizer.eos_token_id 154 | and len(result["input_ids"]) < cutoff_len 155 | and add_eos_token 156 | ): 157 | result["input_ids"].append(tokenizer.eos_token_id) 158 | result["attention_mask"].append(1) 159 | 160 | result["labels"] = result["input_ids"].copy() 161 | 162 | return result 163 | 164 | def generate_and_tokenize_prompt(data_point): 165 | if data_path == './data/10': 166 | full_prompt = prompter.generate_prompt( 167 | data_point["instruction"], 168 | data_point["context"], 169 | data_point["response"], 170 | ) 171 | elif data_path == './data_wiz/10' or data_path == './data_mix/20': 172 | full_prompt = prompter.generate_prompt( 173 | data_point["instruction"], 174 | None, 175 | data_point["output"], 176 | ) 177 | else: 178 | full_prompt = prompter.generate_prompt( 179 | data_point["instruction"], 180 | data_point["input"], 181 | data_point["output"], 182 | ) 183 | 184 | tokenized_full_prompt = tokenize(full_prompt) 185 | if not train_on_inputs: 186 | user_prompt = prompter.generate_prompt( 187 | data_point["instruction"], data_point["context"] 188 | ) 189 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 190 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 191 | 192 | tokenized_full_prompt["labels"] = [ 193 | -100 194 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 195 | user_prompt_len: 196 | ] # could be sped up, probably 197 | return tokenized_full_prompt 198 | 199 | #model = prepare_model_for_int8_training(model) 200 | if full == False: 201 | if stacking 
== False: 202 | if zero_padding: 203 | config_ori = LoraConfig( 204 | base_model_name_or_path=global_model, 205 | r = max(local_ranks), 206 | lora_alpha = lora_alpha * max(local_ranks), 207 | target_modules = lora_target_modules, 208 | lora_dropout = lora_dropout, 209 | bias = "none", 210 | task_type = "CAUSAL_LM", 211 | ) 212 | else: 213 | config = LoraConfig( 214 | base_model_name_or_path=global_model, 215 | r = lora_r, 216 | lora_alpha = lora_alpha, 217 | target_modules = lora_target_modules, 218 | lora_dropout = lora_dropout, 219 | bias = "none", 220 | task_type = "CAUSAL_LM", 221 | ) 222 | model = get_peft_model(model, config) 223 | 224 | else: 225 | config_ori = LoraConfig( 226 | base_model_name_or_path=global_model, 227 | r = lora_r * num_clients, 228 | lora_alpha = lora_alpha * num_clients, 229 | target_modules = lora_target_modules, 230 | lora_dropout = lora_dropout, 231 | bias = "none", 232 | task_type = "CAUSAL_LM", 233 | ) 234 | 235 | if not ddp and torch.cuda.device_count() > 1: 236 | model.is_parallelizable = True 237 | model.model_parallel = True 238 | 239 | print("The process of federated instruction-tuning has started..") 240 | previously_selected_clients_set = set() 241 | last_client_id = None 242 | local_dataset_len_dict = dict() 243 | output_dir = os.path.join(output_dir, str(num_clients)) 244 | 245 | acc_list = [] 246 | 247 | for epoch in tqdm(range(num_communication_rounds)): 248 | 249 | print("\nConducting the client selection") 250 | selected_clients_set = client_selection(num_clients, client_selection_frac, client_selection_strategy, 251 | other_info=epoch) 252 | 253 | for client_id in selected_clients_set: 254 | if full == False: 255 | if Adalora: 256 | config = AdaLoraConfig( 257 | r=local_ranks[client_id], 258 | lora_alpha=2*local_ranks[client_id], 259 | target_modules=lora_target_modules, 260 | lora_dropout=lora_dropout, 261 | bias="none", 262 | task_type="CAUSAL_LM", 263 | base_model_name_or_path=global_model, 264 | ) 265 | model_client = copy.deepcopy(model) 266 | model_client = get_peft_model(model_client, config) 267 | else: 268 | if stacking: 269 | if heter: 270 | config = LoraConfig( 271 | r=local_ranks[client_id], 272 | lora_alpha=2*local_ranks[client_id], 273 | target_modules=lora_target_modules, 274 | lora_dropout=lora_dropout, 275 | bias="none", 276 | task_type="CAUSAL_LM", 277 | base_model_name_or_path=global_model, 278 | ) 279 | model_client = copy.deepcopy(model) 280 | model_client = get_peft_model(model_client, config) 281 | else: 282 | config = LoraConfig( 283 | r=lora_r, 284 | lora_alpha=lora_alpha, 285 | target_modules=lora_target_modules, 286 | lora_dropout=lora_dropout, 287 | bias="none", 288 | task_type="CAUSAL_LM", 289 | base_model_name_or_path=global_model, 290 | ) 291 | model_client = copy.deepcopy(model) 292 | model_client = get_peft_model(model_client, config) 293 | else: 294 | if heter: 295 | config = LoraConfig( 296 | r=local_ranks[client_id], 297 | lora_alpha=2*local_ranks[client_id], 298 | target_modules=lora_target_modules, 299 | lora_dropout=lora_dropout, 300 | bias="none", 301 | task_type="CAUSAL_LM", 302 | base_model_name_or_path=global_model, 303 | ) 304 | model_client = copy.deepcopy(model) 305 | model_client = get_peft_model(model_client, config) 306 | else: 307 | model_client = model 308 | 309 | else: 310 | model_client = model 311 | 312 | client = GeneralClient(client_id, model_client, data_path, output_dir) 313 | 314 | print("\nPreparing the local dataset and trainer for Client_{}".format(client_id)) 315 | 
client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size) 316 | client.build_local_trainer(tokenizer, 317 | local_micro_batch_size, 318 | gradient_accumulation_steps, 319 | local_num_epochs, 320 | local_learning_rate, 321 | group_by_length, 322 | ddp) 323 | 324 | print("Initiating the local training of Client_{}".format(client_id)) 325 | client.initiate_local_training() 326 | 327 | print("Local training starts ... ") 328 | client.train() 329 | 330 | print("\nTerminating the local training of Client_{}".format(client_id)) 331 | model_client, local_dataset_len_dict, previously_selected_clients_set, last_client_id = client.terminate_local_training( 332 | epoch, local_dataset_len_dict, previously_selected_clients_set) 333 | del client 334 | 335 | print("Collecting the weights of clients and performing aggregation") 336 | #local_dataset_len_dict = [1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00] 337 | 338 | model = FedAvg(model, 339 | selected_clients_set, 340 | output_dir, 341 | local_dataset_len_dict, 342 | epoch, 343 | stacking, 344 | lora_r, 345 | heter, 346 | local_ranks, 347 | zero_padding, 348 | full 349 | ) 350 | 351 | if full == False: 352 | if stacking: 353 | config_ori.save_pretrained( 354 | os.path.join(output_dir, str(epoch)), 355 | load_in_8bit=False, 356 | torch_dtype=torch.float16, 357 | device_map=device_map, 358 | ) 359 | model = PeftModel.from_pretrained(model, os.path.join(output_dir, str(epoch))) 360 | else: 361 | torch.save(model.state_dict(), os.path.join(output_dir, str(epoch), "adapter_model.bin")) 362 | config.save_pretrained( 363 | os.path.join(output_dir, str(epoch)), 364 | load_in_8bit=False, 365 | torch_dtype=torch.float16, 366 | device_map=device_map, 367 | ) 368 | else: 369 | config = AutoConfig.from_pretrained(global_model) 370 | tokenizer.save_pretrained(os.path.join(output_dir, str(epoch)), 371 | load_in_8bit=False, 372 | torch_dtype=torch.float32, 373 | device_map=device_map,) 374 | config.save_pretrained(os.path.join(output_dir, str(epoch)), 375 | load_in_8bit=False, 376 | torch_dtype=torch.float32, 377 | device_map=device_map,) 378 | 379 | print('save model') 380 | 381 | acc = global_evaluation(model, tokenizer, prompter, dev_data_path) 382 | print('Acc of Epoch', str(epoch), 'is:', acc) 383 | acc_list.append(acc) 384 | '''x_dir = os.path.join(output_dir, str(epoch)) 385 | current_dir = x_dir # + "/temp/" 386 | print(current_dir)''' 387 | #arc_easy,hellaswag,mmlu,truthfulqa 388 | #os.system("lm_eval --model_args pretrained=huggyllama/llama-7b,parallelize=True,load_in_4bit=False,peft={current_dir} --tasks arc_easy,hellaswag,mmlu,truthfulqa --device cuda --output_path {current_dir}".format(current_dir = current_dir)) 389 | #os.system("lm_eval --model_args pretrained={current_dir},parallelize=True,load_in_4bit=False --tasks arc_easy,hellaswag,mmlu,truthfulqa --device cuda --output_path {current_dir}".format(current_dir = os.path.join(output_dir, str(epoch)))) 390 | if stacking: 391 | model = model.merge_and_unload() 392 | model.save_pretrained(os.path.join(output_dir, str(epoch) + '/final'), 393 | load_in_8bit=False, 394 | torch_dtype=torch.float32, 395 | device_map=device_map,) 396 | 397 | if epoch < (num_communication_rounds - 1): 398 | rm_dir = os.path.join(output_dir, str(epoch)) 399 | os.system("rm -rf {xxxxx}".format(xxxxx = rm_dir)) 400 | 401 | print(acc_list) 402 | #os.system("lm_eval --model_args pretrained=huggyllama/llama-7b,parallelize=True,load_in_4bit=False,peft={current_dir} --tasks arc_challenge,mmlu --device 
cuda --output_path {current_dir}".format(current_dir = os.path.join(output_dir, str(epoch)))) 403 | filename = output_dir + 'log.txt' 404 | file = open(filename,'a') 405 | for i in range(len(acc_list)): 406 | s = str(acc_list[i]).replace('[','').replace(']','') 407 | s = s.replace("'",'').replace(',','') +'\n' 408 | file.write(s) 409 | file.close() 410 | print("Log Saved") 411 | 412 | if __name__ == "__main__": 413 | fire.Fire(fl_finetune) 414 | -------------------------------------------------------------------------------- /new.sh: -------------------------------------------------------------------------------- 1 | python main.py --global_model 'google/gemma-7b' --data_path "./data_wiz" --output_dir './nips-gemma7b-full-wiz-1-3-10/' --num_communication_rounds 1 --local_num_epochs 3 --full True 2 | python main.py --global_model 'google/gemma-7b' --data_path "./data_mmlu" --output_dir './nips-gemma7b-full-mmlu-1-3-10/' --num_communication_rounds 1 --local_num_epochs 3 --full True 3 | python main.py --global_model 'google/gemma-7b' --data_path "./data_mix" --output_dir './nips-gemma7b-full-mix-1-3-10/' --num_communication_rounds 1 --local_num_epochs 3 --full True --num_clients 20 4 | 5 | python main.py --global_model 'google/gemma-7b' --data_path "./data_wiz" --output_dir './nips-gemma7b-full-wiz-3-1-10/' --num_communication_rounds 3 --local_num_epochs 1 --full True 6 | python main.py --global_model 'google/gemma-7b' --data_path "./data_mmlu" --output_dir './nips-gemma7b-full-mmlu-3-1-10/' --num_communication_rounds 3 --local_num_epochs 1 --full True 7 | python main.py --global_model 'google/gemma-7b' --data_path "./data_mix" --output_dir './nips-gemma7b-full-mix-3-1-20/' --num_communication_rounds 3 --local_num_epochs 1 --full True --num_clients 20 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.27.2 3 | aiofiles==23.2.1 4 | aiohttp==3.9.3 5 | aiosignal==1.3.1 6 | altair==5.2.0 7 | anthropic==0.20.0 8 | anyio==4.3.0 9 | asttokens==2.0.5 10 | astunparse==1.6.3 11 | async-timeout==4.0.3 12 | attrs==23.1.0 13 | backcall==0.2.0 14 | beautifulsoup4==4.12.2 15 | blessed==1.20.0 16 | boltons==23.0.0 17 | Brotli==1.0.9 18 | certifi==2023.7.22 19 | cffi==1.15.1 20 | chardet==4.0.0 21 | charset-normalizer==2.0.4 22 | click==8.1.7 23 | colorama==0.4.6 24 | conda==23.9.0 25 | conda-build==3.27.0 26 | conda-content-trust==0.2.0 27 | conda_index==0.3.0 28 | conda-libmamba-solver==23.7.0 29 | conda-package-handling==2.2.0 30 | conda_package_streaming==0.9.0 31 | contourpy==1.2.0 32 | cryptography==41.0.3 33 | cycler==0.12.1 34 | DataProperty==1.0.1 35 | datasets==2.16.1 36 | decorator==5.1.1 37 | dill==0.3.7 38 | distro==1.9.0 39 | dnspython==2.4.2 40 | evaluate==0.4.1 41 | exceptiongroup==1.0.4 42 | executing==0.8.3 43 | expecttest==0.1.6 44 | fastapi==0.110.0 45 | ffmpy==0.3.2 46 | filelock==3.9.0 47 | fire==0.5.0 48 | fonttools==4.50.0 49 | frozenlist==1.4.1 50 | fschat==0.2.35 51 | fsspec==2023.10.0 52 | gmpy2==2.1.2 53 | gpustat==1.1.1 54 | gradio==3.50.2 55 | gradio_client==0.6.1 56 | h11==0.14.0 57 | httpcore==1.0.4 58 | httpx==0.27.0 59 | huggingface-hub==0.21.4 60 | hypothesis==6.88.4 61 | idna==3.4 62 | importlib_resources==6.3.2 63 | ipython==8.15.0 64 | jedi==0.18.1 65 | Jinja2==3.1.2 66 | joblib==1.3.2 67 | jsonlines==4.0.0 68 | jsonpatch==1.32 69 | jsonpointer==2.1 70 | jsonschema==4.21.1 71 | 
jsonschema-specifications==2023.12.1 72 | kiwisolver==1.4.5 73 | libarchive-c==2.9 74 | lm_eval==0.4.0 75 | lxml==5.1.0 76 | markdown-it-py==3.0.0 77 | markdown2==2.4.13 78 | MarkupSafe==2.1.1 79 | matplotlib==3.8.3 80 | matplotlib-inline==0.1.6 81 | mbstrdecoder==1.1.3 82 | mdurl==0.1.2 83 | mkl-fft==1.3.8 84 | mkl-random==1.2.4 85 | mkl-service==2.4.0 86 | more-itertools==8.12.0 87 | mpmath==1.3.0 88 | msgpack==1.0.8 89 | multidict==6.0.5 90 | multiprocess==0.70.15 91 | networkx==3.1 92 | nh3==0.2.15 93 | nltk==3.8.1 94 | numexpr==2.9.0 95 | numpy 96 | nvidia-ml-py==12.535.133 97 | openai==0.28.1 98 | orjson==3.9.15 99 | packaging==23.1 100 | pandas==2.2.1 101 | parso==0.8.3 102 | pathvalidate==3.2.0 103 | peft==0.3.0 104 | pexpect==4.8.0 105 | pickleshare==0.7.5 106 | Pillow==10.0.1 107 | pip==24.0 108 | pkginfo==1.9.6 109 | pluggy==1.0.0 110 | portalocker==2.8.2 111 | prompt-toolkit==3.0.36 112 | protobuf==4.25.3 113 | psutil==5.9.0 114 | ptyprocess==0.7.0 115 | pure-eval==0.2.2 116 | pyarrow==15.0.1 117 | pyarrow-hotfix==0.6 118 | pybind11==2.11.1 119 | pycosat==0.6.6 120 | pycparser==2.21 121 | pydantic==1.10.14 122 | pydub==0.25.1 123 | Pygments==2.15.1 124 | pyOpenSSL==23.2.0 125 | pyparsing==3.1.2 126 | PySocks==1.7.1 127 | pytablewriter==1.2.0 128 | python-dateutil==2.9.0.post0 129 | python-etcd==0.4.5 130 | python-multipart==0.0.9 131 | pytz==2023.3.post1 132 | PyYAML==6.0.1 133 | ray==2.9.3 134 | referencing==0.34.0 135 | regex==2023.12.25 136 | requests==2.31.0 137 | responses==0.18.0 138 | rich==13.7.1 139 | rouge-score==0.1.2 140 | rpds-py==0.18.0 141 | ruamel.yaml==0.17.21 142 | ruamel.yaml.clib==0.2.6 143 | sacrebleu==2.4.0 144 | safetensors==0.4.2 145 | scikit-learn==1.4.1.post1 146 | scipy==1.12.0 147 | semantic-version==2.10.0 148 | sentencepiece==0.1.99 149 | setuptools==68.0.0 150 | shortuuid==1.0.13 151 | six==1.16.0 152 | sniffio==1.3.1 153 | sortedcontainers==2.4.0 154 | soupsieve==2.5 155 | sqlitedict==2.1.0 156 | stack-data==0.2.0 157 | starlette==0.36.3 158 | svgwrite==1.4.3 159 | sympy==1.11.1 160 | tabledata==1.3.3 161 | tabulate==0.9.0 162 | tcolorpy==0.1.4 163 | termcolor==2.4.0 164 | threadpoolctl==3.3.0 165 | tiktoken==0.6.0 166 | tokenizers==0.19.1 167 | tomli==2.0.1 168 | toolz==0.12.0 169 | torch==2.1.1 170 | torchaudio==2.1.1 171 | torchelastic==0.2.2 172 | torchvision==0.16.1 173 | tqdm==4.65.0 174 | tqdm-multiprocess==0.0.11 175 | traitlets==5.7.1 176 | transformers==4.40.0 177 | triton==2.1.0 178 | truststore==0.8.0 179 | typepy==1.3.2 180 | types-dataclasses==0.6.6 181 | typing_extensions==4.10.0 182 | tzdata==2024.1 183 | urllib3==1.26.18 184 | uvicorn==0.28.1 185 | wavedrom==2.0.3.post3 186 | wcwidth==0.2.5 187 | websockets==11.0.3 188 | wheel==0.41.2 189 | word2number==1.1 190 | xxhash==3.4.1 191 | yarl==1.9.4 192 | zstandard==0.19.0 193 | 194 | -------------------------------------------------------------------------------- /run_wiz.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | git clone https://github.com/EleutherAI/lm-evaluation-harness 3 | cd lm-evaluation-harness 4 | pip install -e . 
5 | cd ../ 6 | pip install huggingface_hub 7 | python download.py 8 | cd tinyllama 9 | wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors?download=true 10 | cd ../ 11 | python main.py --global_model 'tinyllama' --data_path "./data_wiz" --output_dir './nips-tinyllama-full-wiz-1-1-10/' --num_communication_rounds 1 --local_num_epochs 1 --full True 12 | python main.py --global_model 'llama-13b' --data_path "./data_wiz" --output_dir './nips-llama13b-full-wiz-1-3-10/' --num_communication_rounds 1 --local_num_epochs 3 --full True 13 | python main.py --global_model 'llama-13b' --data_path "./data_wiz" --output_dir './nips-llama13b-full-wiz-3-1-10/' --num_communication_rounds 3 --local_num_epochs 1 --full True 14 | lm_eval --model_args pretrained=./nips-llama13b-full-wiz-1-3-10/10/0/,parallelize=True,load_in_4bit=False, --tasks arc_challenge --num_fewshot 25 --batch_size 16 --output_path ./nips-llama13b-full-wiz-1-3-10/10/0/ 15 | lm_eval --model_args pretrained=./nips-llama13b-full-wiz-3-1-10/10/2/,parallelize=True,load_in_4bit=False, --tasks arc_challenge --num_fewshot 25 --batch_size 16 --output_path ./nips-llama13b-full-wiz-3-1-10/10/2/ 16 | -------------------------------------------------------------------------------- /templates/README.md: -------------------------------------------------------------------------------- 1 | # Prompt templates 2 | 3 | This directory contains template styles for the prompts used to finetune LoRA models. 4 | 5 | ## Format 6 | 7 | A template is described via a JSON file with the following keys: 8 | 9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders. 10 | - `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders. 11 | - `description`: A short description of the template, with possible use cases. 12 | - `response_split`: The text to use as separator when cutting real response from the model output. 13 | 14 | No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest. 15 | 16 | ## Example template 17 | 18 | The default template, used unless otherwise specified, is `alpaca.json` 19 | 20 | ```json 21 | { 22 | "description": "Template used by Alpaca-LoRA.", 23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 25 | "response_split": "### Response:" 26 | } 27 | 28 | ``` 29 | 30 | ## Current templates 31 | 32 | ### alpaca 33 | 34 | Default template used for generic LoRA fine tunes so far. 35 | 36 | ### alpaca_legacy 37 | 38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments. 39 | 40 | ### alpaca_short 41 | 42 | A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short tempalte as well. More experiments are welcome. 43 | 44 | ### vigogne 45 | 46 | The default alpaca template, translated to french. 
This template was used to train the "Vigogne" LoRA and should be used to query it, or for further fine-tuning. 47 | -------------------------------------------------------------------------------- /templates/alpaca.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Alpaca-LoRA.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/alpaca_legacy.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Legacy template, used by Original Alpaca repository.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/alpaca_short.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "A shorter template to experiment with.", 3 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/vigogne.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "French template, used by Vigogne for finetuning.", 3 | "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n", 4 | "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n", 5 | "response_split": "### Réponse:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/xx: -------------------------------------------------------------------------------- 1 | c 2 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Directory for helper modules 2 | 3 | ## prompter.py 4 | 5 | Prompter class, a template manager. 6 | 7 | `from utils.prompter import Prompter` 8 | 9 | ## callbacks.py 10 | 11 | Helpers to support streaming generate output. 
12 | 13 | `from utils.callbacks import Iteratorize, Stream` 14 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ATP-1010/FederatedLLM/7bea15826bc0de37da35e44bc34a39274e3cc09f/utils/__init__.py -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to support streaming generate output. 3 | Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py 4 | """ 5 | 6 | import gc 7 | import traceback 8 | from queue import Queue 9 | from threading import Thread 10 | 11 | import torch 12 | import transformers 13 | 14 | 15 | class Stream(transformers.StoppingCriteria): 16 | def __init__(self, callback_func=None): 17 | self.callback_func = callback_func 18 | 19 | def __call__(self, input_ids, scores, **kwargs) -> bool: 20 | if self.callback_func is not None: 21 | self.callback_func(input_ids[0]) 22 | return False 23 | 24 | 25 | class Iteratorize: 26 | 27 | """ 28 | Transforms a function that takes a callback 29 | into a lazy iterator (generator). 30 | """ 31 | 32 | def __init__(self, func, kwargs=None, callback=None): 33 | self.mfunc = func 34 | self.c_callback = callback 35 | self.q = Queue() 36 | self.sentinel = object() 37 | self.kwargs = kwargs if kwargs is not None else {} 38 | self.stop_now = False 39 | 40 | def _callback(val): 41 | if self.stop_now: 42 | raise ValueError 43 | self.q.put(val) 44 | 45 | def gentask(): 46 | try: 47 | ret = self.mfunc(callback=_callback, **self.kwargs) 48 | except ValueError: 49 | ret = None 50 | except Exception: 51 | traceback.print_exc() 52 | ret = None 53 | 54 | self.q.put(self.sentinel) 55 | if self.c_callback: 56 | self.c_callback(ret) 57 | 58 | self.thread = Thread(target=gentask) 59 | self.thread.start() 60 | 61 | def __iter__(self): 62 | return self 63 | 64 | def __next__(self): 65 | obj = self.q.get(True, None) 66 | if obj is self.sentinel: 67 | raise StopIteration 68 | else: 69 | return obj 70 | 71 | def __enter__(self): 72 | return self 73 | 74 | def __exit__(self, exc_type, exc_val, exc_tb): 75 | self.stop_now = True 76 | -------------------------------------------------------------------------------- /utils/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 
17 | template_name = "alpaca" 18 | file_name = osp.join("templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 36 | if input: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | -------------------------------------------------------------------------------- /utils/ss: -------------------------------------------------------------------------------- 1 | c 2 | --------------------------------------------------------------------------------
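For reference, here is a minimal sketch of how the `Prompter` helper above and the `alpaca` template are typically wired together. The instruction, input, and fake model output are invented for illustration, and the snippet assumes it runs from the repository root so that `templates/alpaca.json` can be found.

```python
from utils.prompter import Prompter

# An empty template name falls back to templates/alpaca.json.
prompter = Prompter("")

# Uses `prompt_input` because an input is supplied; otherwise `prompt_no_input`.
prompt = prompter.generate_prompt(
    instruction="Summarize the text below in one sentence.",
    input="Federated learning trains a shared model across clients without centralizing their data.",
)

# The decoded model output normally still contains the prompt; `get_response`
# keeps only the text that follows the "### Response:" marker.
fake_output = prompt + "Federated learning lets clients jointly train a model while keeping data local."
print(prompter.get_response(fake_output))
```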
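Similarly, a small sketch of what `Iteratorize` from `utils/callbacks.py` does, using a toy callback-driven producer instead of a real `model.generate` call (the producer function and its arguments are made up for illustration):

```python
from utils.callbacks import Iteratorize

def produce(callback=None, n=3):
    # Stand-in for a callback-driven generation loop: emit n values.
    for i in range(n):
        callback(i)
    return n

# Iteratorize runs `produce` on a background thread and turns every value
# passed to the callback into an item of a generator; leaving the `with`
# block sets stop_now so the worker aborts at its next callback.
with Iteratorize(produce, kwargs={"n": 3}) as gen:
    for item in gen:
        print(item)  # prints 0, 1, 2
```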