├── .gitignore ├── assets └── intro.png ├── requirements.txt ├── utils.py ├── models.py ├── README.md ├── fed_agg.py ├── fed_train_glue.py ├── fed_train_e2e.py ├── data_utils.py └── train_eval.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CERT-Lab/fedex-lora/HEAD/assets/intro.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.20.0 2 | future==0.18.3 3 | matplotlib==3.7.2 4 | nltk==3.9.1 5 | numpy==1.24.3 6 | pandas==2.0.3 7 | peft==0.12.0 8 | rouge_score==0.1.2 9 | scikit_learn 10 | scipy 11 | scikit-image 12 | torch==2.3.1 13 | tqdm==4.66.4 14 | transformers==4.44.2 15 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import json 5 | 6 | 7 | def tensor_to_list(obj): 8 | if isinstance(obj, torch.Tensor): 9 | return obj.detach().cpu().numpy().tolist() 10 | elif isinstance(obj, dict): 11 | return {k: tensor_to_list(v) for k, v in obj.items()} 12 | elif isinstance(obj, list): 13 | return [tensor_to_list(v) for v in obj] 14 | else: 15 | return obj 16 | 17 | 18 | def save_dict_to_json(data_dict, args, base_path): 19 | # Create a timestamp for the filename 20 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 21 | filename = f"compare_dict_rounds_{timestamp}.json" 22 | file_path = os.path.join(base_path, filename) 23 | 24 | # Ensure the directory exists 25 | os.makedirs(base_path, exist_ok=True) 26 | 27 | # Combine data_dict and args 28 | combined_dict = {"args": vars(args), "data": data_dict} 29 | 30 | # Convert tensors to lists 31 | json_serializable_dict = tensor_to_list(combined_dict) 32 | 33 | # Write JSON data to the file 34 | with open(file_path, "w") as json_file: 35 | json.dump(json_serializable_dict, json_file, indent=2) 36 | 37 | print(f"Data and args saved to {file_path}") 38 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW 4 | from transformers import ( 5 | AutoTokenizer, 6 | AutoModelForCausalLM, 7 | TrainingArguments, 8 | BitsAndBytesConfig, 9 | ) 10 | from datasets import load_dataset 11 | from tqdm import tqdm 12 | import numpy as np 13 | from peft import ( 14 | get_peft_model, 15 | AdaLoraModel, 16 | AdaLoraConfig, 17 | TaskType, 18 | LoraConfig, 19 | prepare_model_for_kbit_training, 20 | ) 21 | from data_utils import * 22 | import argparse 23 | from copy import deepcopy 24 | 25 | 26 | def create_peft_model(num_labels, args): 27 | 28 | model = RobertaForSequenceClassification.from_pretrained( 29 | args.model, num_labels=num_labels 30 | ) 31 | 32 | peft_config = LoraConfig( 33 | task_type=TaskType.SEQ_CLS, 34 | r=args.lora_r, 35 | lora_alpha=args.lora_alpha, 36 | lora_dropout=args.lora_dropout, 37 | use_rslora=args.rslora, 38 | target_modules=["query", "value"], 39 | ) 40 | 41 | model = 
get_peft_model(model, peft_config) 42 | 43 | return model 44 | 45 | 46 | def create_peft_FFA_model(num_labels, args): 47 | 48 | model = RobertaForSequenceClassification.from_pretrained( 49 | args.model, num_labels=num_labels 50 | ) 51 | 52 | peft_config = LoraConfig( 53 | task_type=TaskType.SEQ_CLS, 54 | r=args.lora_r, 55 | lora_alpha=args.lora_alpha, 56 | lora_dropout=args.lora_dropout, 57 | use_rslora=args.rslora, 58 | target_modules=["query", "value"], 59 | ) 60 | model = get_peft_model(model, peft_config) 61 | 62 | # Make LoRA A matrices non-trainable 63 | for name, param in model.named_parameters(): 64 | if "lora_A" in name: 65 | param.requires_grad = False 66 | 67 | return model 68 | 69 | 70 | def create_peft_gpt2_model_e2e(args): 71 | model = GPT2LMHeadModel.from_pretrained("gpt2") 72 | 73 | # Define LoRA configuration for language modeling task 74 | lora_config = LoraConfig( 75 | task_type=TaskType.CAUSAL_LM, # For language modeling 76 | inference_mode=False, 77 | r=args.lora_r, # The dimension of the low-rank update matrices 78 | lora_alpha=args.lora_alpha, # The scaling factor for LoRA layers 79 | lora_dropout=args.lora_dropout, # Dropout to apply to LoRA layers 80 | target_modules=["c_attn", "c_proj"], # Modules to apply LoRA 81 | ) 82 | 83 | # Apply LoRA to the GPT-2 model 84 | model = get_peft_model(model, lora_config) 85 | return model 86 | 87 | 88 | def create_peft_gpt2_model_e2e_ffa(args): 89 | model = GPT2LMHeadModel.from_pretrained("gpt2") 90 | 91 | # Define LoRA configuration for language modeling task 92 | lora_config = LoraConfig( 93 | task_type=TaskType.CAUSAL_LM, # For language modeling 94 | inference_mode=False, 95 | r=args.lora_r, # The dimension of the low-rank update matrices 96 | lora_alpha=args.lora_alpha, # The scaling factor for LoRA layers 97 | lora_dropout=args.lora_dropout, # Dropout to apply to LoRA layers 98 | target_modules=["c_attn", "c_proj"], # Modules to apply LoRA 99 | ) 100 | 101 | # Apply LoRA to the GPT-2 model first, then freeze the LoRA A matrices so that only B is trained (FFA) 102 | model = get_peft_model(model, lora_config) 103 | 104 | for name, param in model.named_parameters(): 105 | if "lora_A" in name: 106 | param.requires_grad = False 107 | return model 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedEx-LoRA: Exact Aggregation for Federated and Efficient Fine-Tuning of Foundation Models 2 | 3 | Code for the paper: [FedEx-LoRA: Exact Aggregation for Federated and Efficient Fine-Tuning of Foundation Models](https://arxiv.org/abs/2410.09432). Accepted to ACL Main (Oral). 4 | 5 | ## Introduction 6 | 7 | Low-Rank Adaptation (LoRA) is a popular technique for efficient fine-tuning of foundation models. However, applying LoRA in federated learning environments, where data is distributed across multiple clients, presents unique challenges. Existing methods rely on traditional federated averaging of LoRA adapters, resulting in inexact updates. To address this, we propose Federated Exact LoRA, or FedEx-LoRA, which adds a residual error term to the pretrained frozen weight matrix. Our approach achieves exact updates with minimal computational and communication overhead, preserving LoRA's efficiency. We evaluate the method on various models across arithmetic reasoning, commonsense reasoning, natural language understanding and natural language generation tasks, showing consistent performance gains over state-of-the-art methods across multiple settings. 
Through extensive analysis, we quantify that the deviations in updates from the ideal solution are significant, highlighting the need for exact aggregation. Our method's simplicity, efficiency, and broad applicability position it as a promising solution for accurate and effective federated fine-tuning of foundation models. 8 | 9 | ![FedEx-LoRA Arch](assets/intro.png) 10 | 11 | Comparison of federated LoRA methods: (a) FedIT averages the individual client low-rank adapters $A_i$ and $B_i$, resulting in inexact updates. (b) FedEx-LoRA sends the error residual $\Delta W_{res}$ along with the individual adapters $A_i$ and $B_i$, which is added to the pretrained weight matrix $W_0$, ensuring exact aggregation. Clients transmit the low-rank adapters $A_i$ and $B_i$ in both methods; the exact aggregation identity is written out below. 12 | 13 | 14 | ## Environment 15 | We recommend using a Conda environment to run the Python scripts for this project. Follow these commands to set up the environment and install the required libraries: 16 | ``` 17 | conda create -n fedex-lora python=3.10 18 | conda activate fedex-lora 19 | pip install -r requirements.txt 20 | 21 | ``` 22 | 23 | ## Natural Language Understanding 24 | 25 | ``` 26 | CUDA_VISIBLE_DEVICES={device_indices} python3 fed_train_glue.py --model=roberta-base --task=cola --agg_type=ours --num_clients=3 --lora_r=4 --rounds 50 --lr 1e-3 --local_epochs 3 27 | ``` 28 | - Task: `cola`, `mrpc`, `rte`, `stsb`, `sst2`, `qnli` 29 | - Model: `roberta-base`, `roberta-large` 30 | - LoRA rank: Set `lora_r` 31 | 32 | ## Natural Language Generation 33 | 34 | ``` 35 | CUDA_VISIBLE_DEVICES={device_indices} python3 fed_train_e2e.py --agg_type=ours --log --lora_r=4 --lr=2e-3 --num_clients=3 --local_epochs=5 36 | ``` 37 | - LoRA rank: Set `lora_r` 38 | 39 | Here is [the code](https://github.com/tuetschek/e2e-metrics) for evaluating E2E. 
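## Exact Aggregation Identity

The identity below restates, in symbols, the aggregation rule that `aggregate_models_ours` in `fed_agg.py` implements (the implementation additionally transposes the residual to match each layer's weight layout). With $N$ clients, client adapters $A_i$ and $B_i$, and scaling $s = \alpha / r$ (or $\alpha / \sqrt{r}$ with rsLoRA):

$$
\bar{A} = \frac{1}{N}\sum_{i=1}^{N} A_i, \qquad \bar{B} = \frac{1}{N}\sum_{i=1}^{N} B_i, \qquad \Delta W_{res} = \frac{1}{N}\sum_{i=1}^{N} B_i A_i \;-\; \bar{B}\bar{A}.
$$

The server returns $\bar{A}$, $\bar{B}$, and the frozen weight updated as $W_0 \leftarrow W_0 + s\,\Delta W_{res}$, so that

$$
\big(W_0 + s\,\Delta W_{res}\big) + s\,\bar{B}\bar{A} \;=\; W_0 + \frac{s}{N}\sum_{i=1}^{N} B_i A_i,
$$

i.e. the aggregated model equals the exact average of the client updates, whereas plain adapter averaging (FedIT) keeps only the $s\,\bar{B}\bar{A}$ term.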
40 | 41 | ## Citation 42 | 43 | If you use our work for your research, please cite our paper: 44 | 45 | ``` 46 | @article{singhal2024fedex, 47 | title={Fedex-lora: Exact aggregation for federated and efficient fine-tuning of foundation models}, 48 | author={Singhal, Raghav and Ponkshe, Kaustubh and Vepakomma, Praneeth}, 49 | journal={arXiv preprint arXiv:2410.09432}, 50 | year={2024} 51 | } 52 | 53 | @article{singhal2025fed, 54 | title={Fed-SB: A silver bullet for extreme communication efficiency and performance in (private) federated lora fine-tuning}, 55 | author={Singhal, Raghav and Ponkshe, Kaustubh and Vartak, Rohit and Varshney, Lav R and Vepakomma, Praneeth}, 56 | journal={arXiv preprint arXiv:2502.15436}, 57 | year={2025} 58 | } 59 | 60 | ``` 61 | -------------------------------------------------------------------------------- /fed_agg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import ( 4 | RobertaTokenizer, 5 | RobertaForSequenceClassification, 6 | AdamW, 7 | get_linear_schedule_with_warmup, 8 | ) 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | import numpy as np 12 | from peft import get_peft_model, LoraConfig, TaskType 13 | from data_utils import * 14 | from models import * 15 | from sklearn.metrics import matthews_corrcoef 16 | import numpy as np 17 | import torch.nn as nn 18 | 19 | 20 | def aggregate_models_normal(global_model, client_models): 21 | 22 | global_dict = global_model.state_dict() 23 | for k in global_dict.keys(): 24 | if "lora" in k: # Only aggregate LoRA parameters 25 | global_dict[k] = torch.stack( 26 | [client_models[i][k].float() for i in range(len(client_models))], 0 27 | ).mean(0) 28 | 29 | if "classifier" in k: 30 | global_dict[k] = torch.stack( 31 | [client_models[i][k].float() for i in range(len(client_models))], 0 32 | ).mean(0) 33 | 34 | global_model.load_state_dict(global_dict) 35 | 36 | return global_model 37 | 38 | 39 | def aggregate_models_ffa(global_model, client_models): 40 | 41 | global_dict = global_model.state_dict() 42 | for k in global_dict.keys(): 43 | if "lora_B" in k: # Only aggregate LoRA B parameters 44 | global_dict[k] = torch.stack( 45 | [client_models[i][k].float() for i in range(len(client_models))], 0 46 | ).mean(0) 47 | 48 | if "classifier" in k: 49 | global_dict[k] = torch.stack( 50 | [client_models[i][k].float() for i in range(len(client_models))], 0 51 | ).mean(0) 52 | 53 | global_model.load_state_dict(global_dict) 54 | 55 | return global_model 56 | 57 | 58 | def aggregate_models_ours(global_model, client_models, args): 59 | 60 | global_model = ( 61 | global_model.to("cuda") if torch.cuda.is_available() else global_model 62 | ) 63 | global_dict = global_model.state_dict() 64 | 65 | for k in global_dict.keys(): 66 | 67 | if "classifier" in k: 68 | global_dict[k] = torch.stack( 69 | [client_models[i][k].float() for i in range(len(client_models))], 0 70 | ).mean(0) 71 | 72 | for client_model in client_models: 73 | 74 | for k in global_dict.keys(): 75 | 76 | if "classifier" in k: 77 | client_model[k] = global_dict[k] 78 | 79 | for name, module in global_model.named_modules(): 80 | 81 | if hasattr(module, "lora_A") and hasattr(module, "lora_B"): 82 | 83 | lora_A_keys = name + ".lora_A.default.weight" 84 | lora_B_keys = name + ".lora_B.default.weight" 85 | base_layer_keys = name + ".base_layer.weight" 86 | 87 | lora_A_weights = torch.stack( 88 | [client_model[lora_A_keys].detach() for 
client_model in client_models] 89 | ) 90 | lora_B_weights = torch.stack( 91 | [client_model[lora_B_keys].detach() for client_model in client_models] 92 | ) 93 | 94 | # M shape: (d, k) 95 | M = sum( 96 | lora_B_weights[i] @ lora_A_weights[i] for i in range(len(client_models)) 97 | ) / len(client_models) 98 | 99 | lora_A_avg = lora_A_weights.mean(0) 100 | lora_B_avg = lora_B_weights.mean(0) 101 | 102 | scaling_factor = ( 103 | args.lora_alpha / np.sqrt(args.lora_r) 104 | if args.rslora 105 | else args.lora_alpha / args.lora_r 106 | ) 107 | 108 | residue = M - lora_B_avg @ lora_A_avg 109 | 110 | global_dict[name + ".lora_A.default.weight"] = lora_A_avg 111 | global_dict[name + ".lora_B.default.weight"] = lora_B_avg 112 | global_dict[name + ".base_layer.weight"] += torch.transpose( 113 | residue * scaling_factor, 1, 0 114 | ) 115 | 116 | global_model.load_state_dict(global_dict) 117 | 118 | return global_model 119 | -------------------------------------------------------------------------------- /fed_train_glue.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import ( 4 | RobertaTokenizer, 5 | RobertaForSequenceClassification, 6 | AdamW, 7 | get_linear_schedule_with_warmup, 8 | ) 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | import numpy as np 12 | from peft import get_peft_model, LoraConfig, TaskType 13 | from data_utils import * 14 | from models import * 15 | import argparse 16 | import warnings 17 | import os 18 | from datetime import datetime 19 | import numpy as np 20 | import wandb 21 | from train_eval import * 22 | from fed_agg import * 23 | import json 24 | from utils import * 25 | 26 | parser = argparse.ArgumentParser(description="Federated Learning with LoRA") 27 | 28 | parser.add_argument( 29 | "--task", type=str, default="cola", help="GLUE task to fine-tune on" 30 | ) 31 | parser.add_argument("--model", type=str, default="roberta-base", help="Model name") 32 | parser.add_argument("--lora_r", type=int, default=4, help="LoRA R value") 33 | parser.add_argument("--lora_alpha", type=int, default=8, help="LoRA alpha value") 34 | parser.add_argument( 35 | "--lora_dropout", type=float, default=0.1, help="LoRA dropout value" 36 | ) 37 | parser.add_argument("--rslora", action="store_true", help="Use RSLoRA") 38 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size") 39 | parser.add_argument( 40 | "--agg_type", type=str, default="ours", help="Type of aggregation" 41 | ) 42 | parser.add_argument("--num_clients", type=int, default=3, help="Number of clients") 43 | parser.add_argument("--rounds", type=int, default=50, help="Number of rounds") 44 | parser.add_argument( 45 | "--local_epochs", type=int, default=3, help="Number of local epochs" 46 | ) 47 | parser.add_argument("--warmup_ratio", type=float, default=0.06, help="Warmup ratio") 48 | parser.add_argument( 49 | "--max_seq_length", type=int, default=512, help="Maximum sequence length" 50 | ) 51 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") 52 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 53 | 54 | args = parser.parse_args() 55 | 56 | wandb.init(project="project_name", config=args) 57 | 58 | np.random.seed(args.seed) 59 | torch.manual_seed(args.seed) 60 | torch.cuda.manual_seed_all(args.seed) 61 | 62 | 63 | def federated_learning(task): 64 | 65 | train_data, val_data, test_data = load_and_preprocess_data(task) 66 | 67 | 
num_labels = len(set(train_data["labels"].tolist())) 68 | 69 | if args.task == "stsb": 70 | num_labels = 1 71 | 72 | client_dataloaders = create_client_dataloaders(train_data, args) 73 | val_dataloader = create_dataloader(val_data, args) 74 | 75 | max_metric_1 = 0 76 | max_metric_2 = 0 77 | 78 | if args.agg_type == "ffa": 79 | global_model = create_peft_FFA_model(num_labels, args) 80 | else: 81 | global_model = create_peft_model(num_labels, args) 82 | 83 | client_models = [] 84 | 85 | for i in range(args.num_clients): 86 | 87 | if args.agg_type == "ffa": 88 | client_model = create_peft_FFA_model(num_labels, args) 89 | else: 90 | client_model = create_peft_model(num_labels, args) 91 | 92 | client_models.append(client_model) 93 | 94 | for round in range(args.rounds): 95 | print(f"Round {round + 1}/{args.rounds}") 96 | 97 | client_model_state_dicts = [] 98 | for i in range(args.num_clients): 99 | client_model = client_models[i] 100 | client_model.load_state_dict(global_model.state_dict()) 101 | client_model_state_dict = train_client( 102 | client_model, client_dataloaders[i], args 103 | ) 104 | client_model_state_dicts.append(client_model_state_dict) 105 | 106 | if args.agg_type == "normal": 107 | global_model = aggregate_models_normal(global_model, client_model_state_dicts)  # fed_agg helpers operate on client state dicts 108 | elif args.agg_type == "ours": 109 | global_model = aggregate_models_ours(global_model, client_model_state_dicts, args) 110 | elif args.agg_type == "ffa": 111 | global_model = aggregate_models_ffa(global_model, client_model_state_dicts) 112 | 113 | max_metric_1, max_metric_2 = evaluate_global_model( 114 | global_model, val_dataloader, args, max_metric_1, max_metric_2 115 | ) 116 | 117 | 118 | # Main execution 119 | if __name__ == "__main__": 120 | task = args.task 121 | model = federated_learning(task) 122 | -------------------------------------------------------------------------------- /fed_train_e2e.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import ( 4 | RobertaTokenizer, 5 | RobertaForSequenceClassification, 6 | AdamW, 7 | get_linear_schedule_with_warmup, 8 | ) 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | import numpy as np 12 | from peft import get_peft_model, LoraConfig, TaskType 13 | from data_utils import * 14 | from models import * 15 | import argparse 16 | import warnings 17 | from sklearn.metrics import matthews_corrcoef 18 | import numpy as np 19 | import wandb 20 | from train_eval import * 21 | from fed_agg import * 22 | import warnings 23 | import json 24 | 25 | parser = argparse.ArgumentParser(description="Federated Learning with LoRA") 26 | 27 | parser.add_argument("--agg_type", type=str, default="ours", help="Type of aggregation") 28 | parser.add_argument("--rounds", type=int, default=6, help="Number of rounds") 29 | parser.add_argument("--num_clients", type=int, default=3, help="Number of clients") 30 | parser.add_argument( 31 | "--local_epochs", type=int, default=3, help="Number of local epochs" 32 | ) 33 | parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") 34 | parser.add_argument("--lora_r", type=int, default=4, help="LoRA R value") 35 | parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha value") 36 | parser.add_argument( 37 | "--lora_dropout", type=float, default=0.1, help="LoRA dropout value" 38 | ) 39 | parser.add_argument("--rslora", action="store_true", help="Use RSLoRA") 40 | parser.add_argument("--batch_size", type=int, default=8, 
help="Batch size") 41 | parser.add_argument("--warmup_ratio", type=float, default=0.06, help="Warmup ratio") 42 | parser.add_argument( 43 | "--max_seq_length", type=int, default=128, help="Maximum sequence length" 44 | ) 45 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 46 | parser.add_argument("--device", type=str, default="cuda", help="Device to train on") 47 | parser.add_argument("--idx", type=int, default=0, help="Index of the save folder") 48 | parser.add_argument("--log", action="store_true", help="Log the results") 49 | parser.add_argument("--run_dir", type=str, help="Directory to store logs") 50 | 51 | args = parser.parse_args() 52 | 53 | wandb.init(project="project_name", config=args) 54 | 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed_all(args.seed) 58 | 59 | warnings.filterwarnings("ignore") 60 | 61 | 62 | def get_next_run_number(base_dir): 63 | if not os.path.exists(base_dir): 64 | os.makedirs(base_dir) 65 | return 1 66 | 67 | existing_runs = [int(d) for d in os.listdir(base_dir) if d.isdigit()] 68 | return max(existing_runs, default=0) + 1 69 | 70 | 71 | def save_args(args, directory): 72 | args_file = os.path.join(directory, "args.json") 73 | with open(args_file, "w") as f: 74 | json.dump(vars(args), f, indent=2) 75 | 76 | 77 | def federated_learning(task): 78 | 79 | train_data, val_data, test_data, tokenizer = create_e2e_data() 80 | client_data = create_client_dataloaders_nlg(train_data, args) 81 | 82 | if args.agg_type == "ffa": 83 | global_model = create_peft_gpt2_model_e2e_ffa(args) 84 | else: 85 | global_model = create_peft_gpt2_model_e2e(args) 86 | 87 | for round in range(args.rounds): 88 | print(f"Round {round + 1}/{args.rounds}") 89 | 90 | # Train on selected clients 91 | client_models = [] 92 | for client in range(args.num_clients): 93 | 94 | if args.agg_type == "ffa": 95 | client_model = create_peft_gpt2_model_e2e_ffa(args) 96 | else: 97 | client_model = create_peft_gpt2_model_e2e(args) 98 | 99 | client_model.load_state_dict(global_model.state_dict()) 100 | client_model = train_client_e2e( 101 | client_model, client_data[client], val_data, tokenizer, args 102 | ) 103 | client_models.append(client_model) 104 | 105 | if args.agg_type == "normal": 106 | global_model = aggregate_models_normal(global_model, client_models) 107 | elif args.agg_type == "ours": 108 | global_model = aggregate_models_ours(global_model, client_models, args) 109 | elif args.agg_type == "ffa": 110 | global_model = aggregate_models_ffa(global_model, client_models) 111 | 112 | args.idx = round + 1 113 | 114 | if args.log: 115 | base_dir = "text_store_new/" + args.agg_type 116 | run_number = get_next_run_number(base_dir) 117 | run_dir = os.path.join(base_dir, str(run_number)) 118 | os.makedirs(run_dir) 119 | save_args(args, run_dir) 120 | args.run_dir = run_dir 121 | 122 | evaluate_e2e_save_text(global_model, test_data, tokenizer, args) 123 | 124 | return global_model 125 | 126 | 127 | # Main execution 128 | if __name__ == "__main__": 129 | task = "e2e" 130 | model = federated_learning(task) 131 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW 4 | from datasets import load_dataset 5 | from torch.utils.data import Dataset, DataLoader, Subset 6 | from 
transformers import ( 7 | GPT2Tokenizer, 8 | GPT2LMHeadModel, 9 | AdamW, 10 | get_linear_schedule_with_warmup, 11 | ) 12 | from tqdm import tqdm 13 | import numpy as np 14 | import pandas as pd 15 | from peft import get_peft_model, LoraConfig, TaskType 16 | 17 | 18 | def load_and_preprocess_data(task): 19 | 20 | if "mnli" in task: 21 | dataset = load_dataset("glue", "mnli") 22 | else: 23 | dataset = load_dataset("glue", task) 24 | 25 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 26 | 27 | def tokenize_function(examples): 28 | 29 | # Handle different input formats 30 | if "premise" in examples and "hypothesis" in examples: 31 | # MNLI and similar tasks 32 | return tokenizer( 33 | examples["premise"], 34 | examples["hypothesis"], 35 | truncation=True, 36 | padding="max_length", 37 | max_length=128, 38 | ) 39 | elif "question" in examples and "sentence" in examples: 40 | # QNLI and similar tasks 41 | return tokenizer( 42 | examples["question"], 43 | examples["sentence"], 44 | truncation=True, 45 | padding="max_length", 46 | max_length=128, 47 | ) 48 | elif "sentence1" in examples and "sentence2" in examples: 49 | # MRPC, STS-B 50 | return tokenizer( 51 | examples["sentence1"], 52 | examples["sentence2"], 53 | truncation=True, 54 | padding="max_length", 55 | max_length=128, 56 | ) 57 | elif "question1" in examples and "question2" in examples: 58 | # QQP 59 | return tokenizer( 60 | examples["question1"], 61 | examples["question2"], 62 | truncation=True, 63 | padding="max_length", 64 | max_length=128, 65 | ) 66 | elif "sentence" in examples: 67 | # CoLA, SST-2 68 | return tokenizer( 69 | examples["sentence"], 70 | truncation=True, 71 | padding="max_length", 72 | max_length=128, 73 | ) 74 | else: 75 | raise ValueError(f"Unexpected format for task {task}") 76 | 77 | tokenized_datasets = dataset.map(tokenize_function, batched=True) 78 | 79 | if task == "cola": 80 | tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"]) 81 | elif task == "sst2": 82 | tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"]) 83 | elif task == "mrpc": 84 | tokenized_datasets = tokenized_datasets.remove_columns( 85 | ["sentence1", "sentence2", "idx"] 86 | ) 87 | elif task == "qqp": 88 | tokenized_datasets = tokenized_datasets.remove_columns( 89 | ["question1", "question2", "idx"] 90 | ) 91 | elif task == "stsb": 92 | tokenized_datasets = tokenized_datasets.remove_columns( 93 | ["sentence1", "sentence2", "idx"] 94 | ) 95 | elif task == "qnli": 96 | tokenized_datasets = tokenized_datasets.remove_columns( 97 | ["question", "sentence", "idx"] 98 | ) 99 | elif task == "rte": 100 | tokenized_datasets = tokenized_datasets.remove_columns( 101 | ["sentence1", "sentence2", "idx"] 102 | ) 103 | elif task == "wnli": 104 | tokenized_datasets = tokenized_datasets.remove_columns( 105 | ["sentence1", "sentence2", "idx"] 106 | ) 107 | elif task == "mnli_matched" or task == "mnli_mismatched" or task == "mnli": 108 | tokenized_datasets = tokenized_datasets.remove_columns( 109 | ["premise", "hypothesis", "idx"] 110 | ) 111 | else: 112 | raise ValueError(f"Unexpected task {task}") 113 | 114 | tokenized_datasets = tokenized_datasets.rename_column("label", "labels") 115 | tokenized_datasets.set_format("torch") 116 | 117 | if ( 118 | task == "cola" 119 | or task == "sst2" 120 | or task == "mrpc" 121 | or task == "qqp" 122 | or task == "stsb" 123 | or task == "qnli" 124 | or task == "rte" 125 | or task == "wnli" 126 | ): 127 | train_dataset = tokenized_datasets["train"] 128 | val_dataset 
= tokenized_datasets["validation"] 129 | test_dataset = tokenized_datasets["test"] 130 | elif task == "mnli_matched": 131 | train_dataset = tokenized_datasets["train"] 132 | val_dataset = tokenized_datasets["validation_matched"] 133 | test_dataset = tokenized_datasets["test_matched"] 134 | elif task == "mnli_mismatched": 135 | train_dataset = tokenized_datasets["train"] 136 | val_dataset = tokenized_datasets["validation_mismatched"] 137 | test_dataset = tokenized_datasets["test_mismatched"] 138 | 139 | return train_dataset, val_dataset, test_dataset 140 | 141 | 142 | def create_dataloader(dataset, args): 143 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=False) 144 | 145 | 146 | def create_client_dataloaders_nlg(dataset, args): 147 | client_data = [[] for _ in range(args.num_clients)] 148 | for data in dataset: 149 | client_idx = np.random.randint(args.num_clients) 150 | client_data[client_idx].append(data) 151 | return client_data 152 | 153 | 154 | def create_client_dataloaders(dataset, args): 155 | client_data = [[] for _ in range(args.num_clients)] 156 | for data in dataset: 157 | client_idx = np.random.randint(args.num_clients) 158 | client_data[client_idx].append(data) 159 | return [ 160 | DataLoader(cd, batch_size=args.batch_size, shuffle=True) for cd in client_data 161 | ] 162 | 163 | 164 | def create_e2e_data(): 165 | def preprocess_function(examples): 166 | inputs = examples["meaning_representation"] 167 | targets = examples["human_reference"] 168 | 169 | # Combine the input-output pair into a single text 170 | model_inputs = [ 171 | f"{input_} -> {target} <|endoftext|>" 172 | for input_, target in zip(inputs, targets) 173 | ] 174 | only_inputs = [f"{input_} ->" for input_, target in zip(inputs, targets)] 175 | 176 | # Tokenize the combined inputs 177 | tokenized_inputs = tokenizer( 178 | model_inputs, 179 | max_length=512, 180 | padding="max_length", 181 | truncation=True, 182 | return_tensors="pt", 183 | ) 184 | tokenized_only_inputs = tokenizer( 185 | only_inputs, 186 | max_length=512, 187 | padding="max_length", 188 | truncation=True, 189 | return_tensors="pt", 190 | ) 191 | 192 | # Labels are the same as input_ids but shift them for next-token prediction 193 | tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone() 194 | 195 | # Set the labels to -100 where attention mask is 0 (this will ignore padding in loss computation) 196 | tokenized_inputs["labels"][tokenized_inputs["attention_mask"] == 0] = -100 197 | # set the labels to -100 where meaning representation input ids are present 198 | tokenized_inputs["labels"][tokenized_only_inputs["attention_mask"] == 1] = -100 199 | 200 | return tokenized_inputs 201 | 202 | dataset = load_dataset("tuetschek/e2e_nlg") 203 | from transformers import GPT2Tokenizer 204 | 205 | # Load the GPT-2 tokenizer 206 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 207 | tokenizer.pad_token = ( 208 | tokenizer.eos_token 209 | ) # GPT-2 doesn't have a pad token, so we set it to the eos token 210 | tokenized_datasets = dataset.map(preprocess_function, batched=True) 211 | return ( 212 | tokenized_datasets["train"], 213 | tokenized_datasets["validation"], 214 | tokenized_datasets["test"], 215 | tokenizer, 216 | ) 217 | -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from transformers import ( 4 | RobertaTokenizer, 5 | 
RobertaForSequenceClassification, 6 | AdamW, 7 | get_linear_schedule_with_warmup, 8 | ) 9 | from datasets import load_dataset 10 | from tqdm import tqdm 11 | import numpy as np 12 | from peft import get_peft_model, LoraConfig, TaskType 13 | from data_utils import * 14 | from models import * 15 | import argparse 16 | import warnings 17 | from sklearn.metrics import matthews_corrcoef 18 | import numpy as np 19 | import wandb 20 | from torch.cuda.amp import GradScaler, autocast 21 | from sklearn.metrics import matthews_corrcoef, f1_score, accuracy_score 22 | from scipy.stats import pearsonr, spearmanr 23 | import numpy as np 24 | from opacus import PrivacyEngine 25 | from opacus.validators.module_validator import ModuleValidator 26 | import torch 27 | from torch.utils.data import DataLoader 28 | from tqdm import tqdm 29 | from nltk.translate.bleu_score import corpus_bleu 30 | from nltk.translate.nist_score import corpus_nist 31 | from nltk.translate.meteor_score import meteor_score 32 | from rouge_score import rouge_scorer 33 | from pycocoevalcap.cider.cider import Cider 34 | import torch 35 | from datasets import load_dataset 36 | from transformers import get_linear_schedule_with_warmup 37 | from transformers import GPT2LMHeadModel 38 | from peft import get_peft_model, LoraConfig, TaskType 39 | from transformers import Trainer, TrainingArguments 40 | from data_utils import * 41 | import os 42 | from copy import deepcopy 43 | 44 | 45 | def train_client(model, dataloader, args): 46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 47 | model.to(device) 48 | 49 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 50 | total_steps = len(dataloader) * args.local_epochs 51 | num_warmup_steps = int(total_steps * args.warmup_ratio) 52 | scheduler = get_linear_schedule_with_warmup( 53 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps 54 | ) 55 | 56 | scaler = GradScaler() 57 | model.train() 58 | for epoch in range(args.local_epochs): 59 | 60 | for step, data in enumerate(tqdm(dataloader)): 61 | data = {k: v.to(device) for k, v in data.items()} 62 | 63 | with autocast(): 64 | outputs = model(**data) 65 | loss = outputs.loss 66 | 67 | wandb.log({"client_loss": loss.detach().cpu().numpy()}) 68 | 69 | scaler.scale(loss).backward() 70 | scaler.step(optimizer) 71 | scaler.update() 72 | scheduler.step() 73 | optimizer.zero_grad() 74 | 75 | return model.state_dict() 76 | 77 | 78 | def calculate_metrics(all_true_labels, all_predictions, task): 79 | if task == "cola": 80 | return accuracy_score(all_true_labels, all_predictions), matthews_corrcoef( 81 | all_true_labels, all_predictions 82 | ) 83 | elif task in ["sst2", "qnli", "rte", "wnli"]: 84 | return accuracy_score(all_true_labels, all_predictions), None 85 | elif task == "mrpc": 86 | return f1_score(all_true_labels, all_predictions), accuracy_score( 87 | all_true_labels, all_predictions 88 | ) 89 | elif task == "stsb": 90 | return ( 91 | pearsonr(all_true_labels, all_predictions)[0], 92 | spearmanr(all_true_labels, all_predictions)[0], 93 | ) 94 | elif task == "qqp": 95 | return accuracy_score(all_true_labels, all_predictions), f1_score( 96 | all_true_labels, all_predictions 97 | ) 98 | elif task in ["mnli_matched", "mnli_mismatched"]: 99 | return accuracy_score(all_true_labels, all_predictions), None 100 | else: 101 | raise ValueError(f"Unknown task: {task}") 102 | 103 | 104 | def evaluate_global_model(global_model, dataloader, args, max_metric1, max_metric2): 105 | device = torch.device("cuda" 
if torch.cuda.is_available() else "cpu") 106 | global_model.to(device) 107 | 108 | global_model.eval() 109 | eval_loss = 0 110 | all_predictions = [] 111 | all_true_labels = [] 112 | 113 | for batch in dataloader: 114 | batch = {k: v.to(device) for k, v in batch.items()} 115 | with torch.no_grad(): 116 | 117 | outputs = global_model(**batch) 118 | 119 | eval_loss += outputs.loss.detach().cpu().numpy() 120 | 121 | if args.task == "stsb": 122 | predictions = outputs.logits.squeeze().cpu().numpy() 123 | else: 124 | predictions = outputs.logits.argmax(dim=-1).cpu().numpy() 125 | all_predictions.extend(predictions) 126 | all_true_labels.extend(batch["labels"].cpu().numpy()) 127 | 128 | eval_loss /= len(dataloader) 129 | 130 | # Calculate the metrics for the specific task 131 | metric1, metric2 = calculate_metrics(all_true_labels, all_predictions, args.task) 132 | 133 | if metric1 > max_metric1: 134 | max_metric1 = metric1 135 | 136 | if metric2 is not None and metric2 > max_metric2: 137 | max_metric2 = metric2 138 | 139 | print(f"{args.task} - Eval Loss: {eval_loss:.4f}, Metric 1: {metric1:.4f}") 140 | if metric2 is not None: 141 | print(f"{args.task} - Metric 2: {metric2:.4f}") 142 | print(f"{args.task} - Max Metric 1: {max_metric1:.4f}") 143 | if max_metric2 is not None: 144 | print(f"{args.task} - Max Metric 2: {max_metric2:.4f}") 145 | 146 | wandb.log( 147 | { 148 | f"eval_loss": eval_loss, 149 | f"metric1": metric1, 150 | f"metric2": metric2 if metric2 is not None else 0, 151 | f"max_metric1": max_metric1, 152 | f"max_metric2": max_metric2 if max_metric2 is not None else 0, 153 | } 154 | ) 155 | 156 | return max_metric1, max_metric2 157 | 158 | 159 | def get_lr_scheduler(optimizer, num_warmup_steps, num_training_steps): 160 | return get_linear_schedule_with_warmup( 161 | optimizer, 162 | num_warmup_steps=num_warmup_steps, 163 | num_training_steps=num_training_steps, 164 | ) 165 | 166 | 167 | def train_client_e2e(model, train_dataset, val_dataset, tokenizer, args): 168 | num_epochs = args.local_epochs # or whatever number of epochs you want 169 | per_device_train_batch_size = args.batch_size 170 | num_training_steps = len(train_dataset) * num_epochs // per_device_train_batch_size 171 | num_warmup_steps = int(0.1 * num_training_steps) # 10% of total steps for warmup 172 | 173 | optimizer = torch.optim.AdamW(model.parameters()) 174 | 175 | # Define training arguments 176 | training_args = TrainingArguments( 177 | # Directory to save the model 178 | output_dir="./models_trained/gpt4/dump/models/gpt2-e2e-lora_gpt4", 179 | overwrite_output_dir=True, 180 | logging_dir="./models_trained/gpt4/dump/logs/gpt2-e2e-lora_gpt4", # Directory for logs 181 | per_device_train_batch_size=args.batch_size, # Adjust based on your GPU capacity 182 | per_device_eval_batch_size=args.batch_size, 183 | evaluation_strategy="epoch", # Evaluate every epoch 184 | save_strategy="epoch", 185 | num_train_epochs=num_epochs, # Number of training epochs 186 | learning_rate=args.lr, # Learning rate for LoRA parameters 187 | weight_decay=0.01, 188 | label_smoothing_factor=0.1, 189 | report_to="wandb", 190 | run_name="fed-lora", 191 | logging_steps=100, # Log every 100 steps 192 | ) 193 | 194 | # Initialize the trainer 195 | trainer = Trainer( 196 | model=model, 197 | args=training_args, 198 | train_dataset=train_dataset, 199 | eval_dataset=val_dataset, 200 | tokenizer=tokenizer, 201 | optimizers=( 202 | optimizer, 203 | get_lr_scheduler(optimizer, num_warmup_steps, num_training_steps), 204 | ), 205 | ) 206 | 207 | # Train the 
model 208 | trainer.train() 209 | return model.state_dict() 210 | 211 | 212 | def gen_and_save(model, dataloader, tokenizer, args): 213 | device = args.device 214 | model.to(device) 215 | model.eval() 216 | 217 | all_predictions = [] 218 | 219 | all_inputs = [] 220 | with torch.no_grad(): 221 | for step, batch in enumerate(tqdm(dataloader)): 222 | 223 | inputs = {k: v.to(device) for k, v in batch.items()} 224 | 225 | # Generate predictions (starting from after the MR) 226 | generated = model.generate( 227 | input_ids=inputs["input_ids"], # Input MR as prompt 228 | attention_mask=inputs["attention_mask"], 229 | max_length=inputs["input_ids"].shape[1] 230 | + 50, # Allow space for generation after MR 231 | num_return_sequences=1, 232 | no_repeat_ngram_size=4, 233 | do_sample=True, 234 | num_beams=10, 235 | penalty_alpha=0.9, 236 | pad_token_id=tokenizer.eos_token_id, # Ensure padding works correctly 237 | ) 238 | # Decode the generated predictions, excluding the input MR tokens 239 | # We slice the generated tokens to remove the input MR part 240 | 241 | input_seq = tokenizer.batch_decode( 242 | inputs["input_ids"], skip_special_tokens=True 243 | ) 244 | predictions = [ 245 | tokenizer.decode( 246 | generated[i][len(inputs["input_ids"][i]) :], 247 | skip_special_tokens=True, 248 | ) 249 | for i in range(generated.shape[0]) 250 | ] 251 | # Collect predictions and references 252 | all_inputs.extend(input_seq) 253 | all_predictions.extend(predictions) 254 | # all_references.extend(references) 255 | 256 | return all_predictions, all_inputs 257 | 258 | 259 | def process_lists(input_list, second_list, third_list): 260 | result1 = [] 261 | result2 = [] 262 | result3 = [] 263 | current_group = [] 264 | current_item = None 265 | second_list_index = 0 266 | 267 | for item in input_list: 268 | if item != current_item: 269 | if current_group: 270 | result1.append(current_group) 271 | result2.append(current_item) 272 | result3.append(third_list[second_list_index - 1]) 273 | current_item = item 274 | current_group = [second_list[second_list_index]] 275 | second_list_index += 1 276 | else: 277 | if second_list_index < len(second_list): 278 | current_group.append(second_list[second_list_index]) 279 | second_list_index += 1 280 | 281 | if current_group: 282 | result1.append(current_group) 283 | 284 | return result1, result2, result3 285 | 286 | 287 | def evaluate_e2e_save_text(model, test_data, tokenizer, args): 288 | 289 | def preprocess_function2(examples): 290 | inputs = examples["meaning_representation"] 291 | targets = examples["human_reference"] 292 | 293 | # Combine the input-output pair into a single text 294 | model_inputs = [f"{input_} ->" for input_, target in zip(inputs, targets)] 295 | 296 | # Tokenize the combined inputs 297 | tokenized_inputs = tokenizer( 298 | model_inputs, 299 | max_length=512, 300 | padding="max_length", 301 | truncation=True, 302 | return_tensors="pt", 303 | ) 304 | 305 | # Labels are the same as input_ids but shift them for next-token prediction 306 | tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone() 307 | 308 | # Set the labels to -100 where attention mask is 0 (this will ignore padding in loss computation) 309 | tokenized_inputs["labels"][tokenized_inputs["attention_mask"] == 0] = -100 310 | 311 | return tokenized_inputs 312 | 313 | tokenized_test_dataset = test_data.map(preprocess_function2, batched=True) 314 | tokenized_test_dataset = tokenized_test_dataset.remove_columns( 315 | ["meaning_representation", "human_reference"] 316 | ) 317 | 
tokenized_test_dataset.set_format( 318 | type="torch", columns=["input_ids", "attention_mask", "labels"] 319 | ) 320 | 321 | test_dataloader = create_dataloader(tokenized_test_dataset, args) 322 | all_predictions, all_inputs = gen_and_save(model, test_dataloader, tokenizer, args) 323 | all_references = test_data[0 : len(all_predictions)]["human_reference"] 324 | 325 | all_references_new, all_inputs_new, all_predictions_new = process_lists( 326 | all_inputs, all_references, all_predictions 327 | ) 328 | 329 | path_pred = args.run_dir + "/predictions.txt" 330 | path_ref = args.run_dir + "/refs_exact.txt" 331 | 332 | if not os.path.exists(args.run_dir): 333 | os.makedirs(args.run_dir) 334 | 335 | with open(path_pred, "w") as file: 336 | for item in all_predictions_new: 337 | file.write(item.strip() + "\n") 338 | 339 | with open(path_ref, "w") as file: 340 | for str_list in all_references_new: 341 | for item in str_list: 342 | file.write(item.strip() + "\n") 343 | 344 | file.write("\n") 345 | --------------------------------------------------------------------------------
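As a quick sanity check of the exact-aggregation property used in `fed_agg.py`, the following standalone sketch (toy shapes and random values, not part of the repository) shows that averaging the client products B_i A_i differs from the product of the averaged adapters, and that folding the residual into the frozen weight recovers the exact average:

```python
# Standalone illustration (toy shapes, hypothetical values): the residual
# correction used by aggregate_models_ours in fed_agg.py.
import torch

torch.manual_seed(0)
num_clients, d, k, r = 3, 16, 16, 4
lora_alpha, lora_r = 8, 4
scaling = lora_alpha / lora_r  # fed_agg.py uses lora_alpha / sqrt(lora_r) when rsLoRA is enabled

# Per-client LoRA factors: A_i has shape (r, k), B_i has shape (d, r).
A = [torch.randn(r, k) for _ in range(num_clients)]
B = [torch.randn(d, r) for _ in range(num_clients)]

# Exact average of the client updates B_i @ A_i.
M = sum(b @ a for a, b in zip(A, B)) / num_clients

# Plain adapter averaging (FedIT) keeps only avg(B) @ avg(A).
A_avg = torch.stack(A).mean(0)
B_avg = torch.stack(B).mean(0)

# FedEx-LoRA folds the error residual into the frozen base weight W0.
residue = M - B_avg @ A_avg

W0 = torch.randn(d, k)
exact_update = W0 + scaling * M
fedex_update = (W0 + scaling * residue) + scaling * (B_avg @ A_avg)
fedit_update = W0 + scaling * (B_avg @ A_avg)

print(torch.allclose(exact_update, fedex_update))        # True: aggregation is exact
print((exact_update - fedit_update).abs().max().item())  # non-zero deviation of plain averaging
```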