├── __init__.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── generate_weight_matrix.cpython-310.pyc ├── view_split.sh ├── gsm8k_split.sh ├── glue_split.sh ├── generate_weight_matrix.py ├── gsm8k_split.py ├── view_split.py └── glue_split.py ├── finetune ├── __init__.py ├── bert │ ├── __init__.py │ ├── __pycache__ │ │ ├── client.cpython-310.pyc │ │ ├── server.cpython-310.pyc │ │ ├── train.cpython-310.pyc │ │ └── __init__.cpython-310.pyc │ ├── client.sh │ ├── train.sh │ ├── train.py │ ├── server.py │ └── client.py ├── llama │ ├── __init__.py │ ├── client.py │ ├── gsm8k_eval.py │ ├── gsm8k_train.py │ └── server.py └── __pycache__ │ └── __init__.cpython-310.pyc ├── __pycache__ └── __init__.cpython-312.pyc ├── result_processing ├── errorNoniid.m ├── errorBar.m ├── heatmap.py ├── topology.py ├── ConfidenceInterval.m └── avgResult.py ├── requirements.txt ├── README.md └── LICENSE /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetune/bert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetune/llama/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetune/llama/client.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/finetune/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/bert/__pycache__/client.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/finetune/bert/__pycache__/client.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/bert/__pycache__/server.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/finetune/bert/__pycache__/server.cpython-310.pyc 
-------------------------------------------------------------------------------- /finetune/bert/__pycache__/train.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/finetune/bert/__pycache__/train.cpython-310.pyc -------------------------------------------------------------------------------- /finetune/bert/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/finetune/bert/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/generate_weight_matrix.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuYuPromax/DFL-LORA/HEAD/utils/__pycache__/generate_weight_matrix.cpython-310.pyc -------------------------------------------------------------------------------- /utils/view_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default parameters 4 | BASE_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentrilized_dataset/gsm8k" 5 | NUM_CLIENTS=7 6 | DATASET_TYPE="gsm8k" # or "glue" 7 | 8 | # Run the Python script 9 | python3 view_split.py \ 10 | --base_path "$BASE_PATH" \ 11 | --num_clients $NUM_CLIENTS \ 12 | --dataset_type "$DATASET_TYPE" -------------------------------------------------------------------------------- /utils/gsm8k_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default parameters 4 | DATASET_PATH="/home/ubuntu/smyin/ConLoRA/datasets/gsm8k" 5 | OUTPUT_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentralized_dataset/gsm8k22222" 6 | NUM_CLIENTS=7 7 | SPLIT_TYPE="dirichlet" 8 | ALPHA=0.25 9 | 10 | # Run the Python script 11 | python gsm8k_split.py \ 12 | --dataset_path "$DATASET_PATH" \ 13 | --output_path "$OUTPUT_PATH" \ 14 | --num_clients $NUM_CLIENTS \ 15 | --split_type "$SPLIT_TYPE" \ 16 | --alpha $ALPHA 17 | -------------------------------------------------------------------------------- /utils/glue_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set default parameters 4 | DATASET_PATH="/home/ubuntu/smyin/ConLoRA/datasets/glue/qnli" 5 | OUTPUT_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentrilized_dataset/qnli22222" 6 | NUM_CLIENTS=4 7 | SPLIT_TYPE="dirichlet" 8 | ALPHA=0.25 9 | MIN_SIZE=1500 10 | 11 | # Run the Python script 12 | python3 glue_split.py \ 13 | --dataset_path "$DATASET_PATH" \ 14 | --output_path "$OUTPUT_PATH" \ 15 | --num_clients $NUM_CLIENTS \ 16 | --split_type "$SPLIT_TYPE" \ 17 | --alpha $ALPHA \ 18 | --min_size $MIN_SIZE 19 | -------------------------------------------------------------------------------- /result_processing/errorNoniid.m: -------------------------------------------------------------------------------- 1 | % Define the data 2 | x = 1:256; % x-axis data (epochs) 3 | % Result data (fill in your own values) 4 | DLoRA_iid = []; 5 | LoRA_DFL_iid = []; 6 | DLoRA_noniid = []; 7 | LoRA_DFL_noniid = []; 8 | % Plot the line chart 9 | figure; 10 | plot(x, DLoRA_iid, '-', 'LineWidth', 1.5); % First data group 11 | hold on; 12 | plot(x, LoRA_DFL_iid, '-', 'LineWidth', 1.5); % Second data group 13 | plot(x, DLoRA_noniid, '-', 'LineWidth', 1.5); % Third data group 14 | plot(x, LoRA_DFL_noniid, '-', 'LineWidth', 1.5); % Fourth data group 15 | hold off; 16 | 17 | % Set the legend 18 | legend('ConLoRA-IID', 'LoRA-DFL-IID', 'ConLoRA-nonIID', 
'LoRA-DFL-nonIID',Location='northwest'); 19 | 20 | % 设置标题和轴标签 21 | %title('Line Plot of Four Data Groups'); 22 | xlabel('Epoch'); 23 | ylabel('Consensus error'); 24 | 25 | % 添加网格 26 | grid on; 27 | 28 | % 调整轴范围 29 | xlim([1 260]); 30 | ylim([0 22]); 31 | -------------------------------------------------------------------------------- /result_processing/errorBar.m: -------------------------------------------------------------------------------- 1 | % 定义数据 2 | network_density = [3, 4, 5, 6]; 3 | error_values_group1 = [12.42, 9.84, 5.37, 4.29]; 4 | error_values_group2 = [1.70, 1.38, 0.82, 0.53]; 5 | 6 | % 将两组数据组合成一个矩阵 7 | error_values = [error_values_group1; error_values_group2]'; 8 | 9 | % 绘制分组柱状图 10 | figure; 11 | b = bar(network_density, error_values, 'grouped'); % 返回柱状图对象 12 | 13 | % 设置轴标签 14 | xlabel('Average node connectivity'); 15 | ylabel('Consensus error'); 16 | 17 | % 添加图例 18 | legend('LoRA-DFL', 'ConLoRA'); 19 | 20 | % 显示网格 21 | grid on; 22 | 23 | % 在每个柱子上显示数值 24 | % 获取柱子的x坐标和高度 25 | for i = 1:length(b) 26 | xData = b(i).XEndPoints; % 获取每组柱子的x坐标 27 | yData = b(i).YData; % 获取柱子的y值 28 | for j = 1:length(xData) 29 | text(xData(j), yData(j), sprintf('%.2f', yData(j)), 'HorizontalAlignment', 'center', 'VerticalAlignment', 'bottom'); 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.32.1 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | boltons 7 | brotlipy==0.7.0 8 | certifi 9 | cffi 10 | charset-normalizer 11 | cmake==3.25.0 12 | contourpy==1.1.0 13 | cryptography 14 | cycler==0.11.0 15 | datasets==2.20.0 16 | dill==0.3.8 17 | filelock==3.15.4 18 | fonttools==4.40.0 19 | frozenlist==1.4.1 20 | fsspec==2024.5.0 21 | huggingface-hub==0.23.4 22 | idna 23 | Jinja2==3.1.2 24 | joblib==1.4.2 25 | jsonpatch 26 | jsonpointer==2.1 27 | kiwisolver==1.4.4 28 | lit==15.0.7 29 | MarkupSafe==2.1.2 30 | matplotlib==3.7.1 31 | mpmath==1.2.1 32 | multidict==6.0.5 33 | multiprocess==0.70.16 34 | networkx==3.0 35 | numpy==1.24.1 36 | packaging 37 | pandas==2.0.2 38 | peft==0.11.1 39 | Pillow==9.3.0 40 | pluggy 41 | psutil==6.0.0 42 | pyarrow==16.1.0 43 | pyarrow-hotfix==0.6 44 | pycosat 45 | pycparser 46 | pyOpenSSL 47 | pyparsing==3.1.0 48 | PySocks 49 | python-dateutil==2.8.2 50 | pytz==2023.3 51 | PyYAML==6.0.1 52 | regex==2024.5.15 53 | requests==2.32.3 54 | ruamel.yaml 55 | ruamel.yaml.clib 56 | safetensors==0.4.3 57 | scikit-learn==1.5.1 58 | scipy==1.14.0 59 | seaborn==0.12.2 60 | six 61 | sympy==1.11.1 62 | threadpoolctl==3.5.0 63 | tokenizers==0.19.1 64 | toolz 65 | torch==2.0.1 66 | torchaudio==2.0.2 67 | torchsummary==1.5.1 68 | torchvision==0.15.2 69 | tqdm==4.66.4 70 | transformers==4.42.3 71 | triton==2.0.0 72 | typing_extensions==4.4.0 73 | tzdata==2023.3 74 | urllib3 75 | xxhash==3.4.1 76 | yarl==1.9.4 77 | zstandard 78 | -------------------------------------------------------------------------------- /result_processing/heatmap.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | 5 | # Example data: Replace with actual data 6 | ConLoRA_result = np.array([ 7 | [0, 0, 0, 0], 8 | [0, 0, 0, 0], 9 | [0, 0, 0, 0], 10 | [0, 0, 0, 0] 11 | ]) 12 | 13 | LoRA_result = np.array([ 14 | [0, 0, 0, 0], 15 | [0, 0, 0, 0], 16 | [0, 0, 0, 0], 17 | [0, 0, 0, 0] 18 | ]) 19 | 20 | # 
Calculate the min and max values for the common color scale 21 | min_value = min(ConLoRA_result.min(), LoRA_result.min()) 22 | max_value = max(ConLoRA_result.max(), LoRA_result.max()) 23 | 24 | # Create the figure 25 | plt.figure(figsize=(8, 12)) 26 | 27 | # Plot the first heatmap 28 | plt.subplot(2, 1, 1) 29 | ax1 = sns.heatmap( 30 | ConLoRA_result, 31 | annot=True, 32 | fmt=".2f", 33 | cmap="Blues", 34 | cbar_kws={'label': 'Accuracy'}, 35 | vmin=min_value, 36 | vmax=max_value 37 | ) 38 | ax1.set_xlabel('α (Dirichlet Parameters)') 39 | ax1.set_ylabel('Average Connectivity') 40 | ax1.set_xticklabels(['0.1', '0.15', '0.2', '0.25']) 41 | ax1.set_yticklabels(['3', '4', '5', '6']) 42 | 43 | # Plot the second heatmap 44 | plt.subplot(2, 1, 2) 45 | ax2 = sns.heatmap( 46 | LoRA_result, 47 | annot=True, 48 | fmt=".2f", 49 | cmap="Blues", 50 | cbar_kws={'label': 'Accuracy'}, 51 | vmin=min_value, 52 | vmax=max_value 53 | ) 54 | ax2.set_xlabel('α (Dirichlet Parameters)') 55 | ax2.set_ylabel('Average Connectivity') 56 | ax2.set_xticklabels(['0.1', '0.15', '0.2', '0.25']) 57 | ax2.set_yticklabels(['3', '4', '5', '6']) 58 | 59 | # Adjust layout to prevent overlap 60 | plt.tight_layout() 61 | 62 | # Display the plot 63 | plt.show() 64 | 65 | # Save the plot as a PNG file 66 | plt.savefig("123.png") 67 | -------------------------------------------------------------------------------- /finetune/bert/client.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ======================================================================== 4 | # Set default parameters 5 | # ======================================================================== 6 | 7 | # Path to the pre-trained model 8 | MODEL_CHECKPOINT="/home/ubuntu/smyin/models/distilbert-base-uncased" 9 | 10 | # Path to the training dataset 11 | DATASET_PATH="/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_005/client_1" 12 | 13 | # Path to the validation dataset 14 | VAL_DATASET_PATH="/home/ubuntu/smyin/dataset/glue/sst2" 15 | 16 | # Rank of LoRA layers 17 | LORA_R=4 18 | 19 | # Alpha value for LoRA 20 | LORA_ALPHA=32 21 | 22 | # Training type: "LoRA" or "ConLoRA" 23 | TRAINING_TYPE="ConLoRA" 24 | 25 | # List of target modules for LoRA 26 | TARGET_MODULES="q_lin v_lin pre_classifier classifier" 27 | 28 | # Device selection: "cuda" or "cpu" 29 | DEVICE="cuda" 30 | 31 | # Number of epochs for training 32 | NUM_EPOCHS=10 33 | 34 | # Dataset type (e.g., "sst2", "mnli", "qnli") 35 | DATASET_TYPE="sst2" 36 | 37 | # Number of epochs for training 38 | BATCH_SIZE=64 39 | 40 | # ======================================================================== 41 | # Validate input parameters 42 | # ======================================================================== 43 | 44 | # Check if the model checkpoint path exists 45 | if [ ! -d "$MODEL_CHECKPOINT" ]; then 46 | echo "Error: Model path $MODEL_CHECKPOINT does not exist!" 47 | exit 1 48 | fi 49 | 50 | # Check if the training dataset path exists 51 | if [ ! -d "$DATASET_PATH" ]; then 52 | echo "Error: Dataset path $DATASET_PATH does not exist!" 53 | exit 1 54 | fi 55 | 56 | # Check if the validation dataset path exists 57 | if [ ! -d "$VAL_DATASET_PATH" ]; then 58 | echo "Error: Validation dataset path $VAL_DATASET_PATH does not exist!" 
59 | exit 1 60 | fi 61 | 62 | # ======================================================================== 63 | # Run the training script 64 | # ======================================================================== 65 | 66 | echo "Training is starting..." 67 | 68 | python3 client.py \ 69 | --model_checkpoint "$MODEL_CHECKPOINT" \ 70 | --dataset_path "$DATASET_PATH" \ 71 | --val_dataset_path "$VAL_DATASET_PATH" \ 72 | --lora_r "$LORA_R" \ 73 | --lora_alpha "$LORA_ALPHA" \ 74 | --training_type "$TRAINING_TYPE" \ 75 | --target_modules "$TARGET_MODULES" \ 76 | --device "$DEVICE" \ 77 | --num_epochs "$NUM_EPOCHS" \ 78 | --dataset_type "$DATASET_TYPE" \ 79 | --batch_size "$BATCH_SIZE" 80 | 81 | echo "Training has completed!" 82 | -------------------------------------------------------------------------------- /result_processing/topology.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | import matplotlib.pyplot as plt 4 | 5 | # Define adjacency matrices for the topologies 6 | topologies = { 7 | "T1": np.array([ 8 | [0, 1, 0, 0, 0, 0, 1], 9 | [1, 0, 1, 0, 0, 0, 0], 10 | [0, 1, 0, 1, 0, 0, 0], 11 | [0, 0, 1, 0, 1, 0, 0], 12 | [0, 0, 0, 1, 0, 1, 0], 13 | [0, 0, 0, 0, 1, 0, 1], 14 | [1, 0, 0, 0, 0, 1, 0] 15 | ]), 16 | "T2": np.array([ 17 | [0, 1, 1, 0, 0, 0, 1], 18 | [1, 0, 1, 1, 0, 0, 0], 19 | [1, 1, 0, 1, 0, 0, 0], 20 | [0, 1, 1, 0, 1, 0, 0], 21 | [0, 0, 0, 1, 0, 1, 1], 22 | [0, 0, 0, 0, 1, 0, 1], 23 | [1, 0, 0, 0, 1, 1, 0] 24 | ]), 25 | "T3": np.array([ 26 | [0, 1, 1, 0, 0, 1, 1], 27 | [1, 0, 1, 1, 0, 0, 1], 28 | [1, 1, 0, 1, 1, 0, 0], 29 | [0, 1, 1, 0, 1, 1, 0], 30 | [0, 0, 1, 1, 0, 1, 1], 31 | [1, 0, 0, 1, 1, 0, 1], 32 | [1, 1, 0, 0, 1, 1, 0] 33 | ]), 34 | "T4": np.array([ 35 | [0, 1, 1, 1, 0, 1, 1], 36 | [1, 0, 1, 1, 1, 1, 0], 37 | [1, 1, 0, 1, 1, 1, 0], 38 | [1, 1, 1, 0, 1, 1, 0], 39 | [0, 1, 1, 1, 0, 1, 1], 40 | [1, 1, 1, 1, 1, 0, 1], 41 | [1, 0, 0, 0, 1, 1, 0] 42 | ]) 43 | } 44 | 45 | # Plot the topologies 46 | # First, create the base ring structure from T1 47 | num_nodes = topologies["T1"].shape[0] 48 | base_graph = nx.Graph() 49 | for i in range(num_nodes): 50 | base_graph.add_edge(i, (i + 1) % num_nodes) 51 | 52 | # Add additional edges based on T1 adjacency matrix 53 | for i in range(num_nodes): 54 | for j in range(num_nodes): 55 | if topologies["T1"][i, j] == 1 and not base_graph.has_edge(i, j): 56 | base_graph.add_edge(i, j) 57 | 58 | # Define a fixed layout for consistent node positioning 59 | fixed_pos = nx.spring_layout(base_graph, seed=42) 60 | 61 | # Draw the base graph T1 62 | plt.figure(figsize=(8, 6)) 63 | nx.draw(base_graph, pos=fixed_pos, with_labels=True, node_color='skyblue', node_size=3000, font_size=30, font_weight='bold', edge_color='gray', width=6) 64 | plt.savefig("T1.png", format='png', dpi=300, bbox_inches='tight') 65 | plt.close() 66 | 67 | # Plot the other topologies based on T1 68 | for name, adjacency_matrix in topologies.items(): 69 | if name == "T1": 70 | continue 71 | 72 | G = base_graph.copy() 73 | 74 | # Add additional edges based on the adjacency matrix 75 | for i in range(num_nodes): 76 | for j in range(num_nodes): 77 | if adjacency_matrix[i, j] == 1 and not G.has_edge(i, j): 78 | G.add_edge(i, j) 79 | 80 | # Draw the graph 81 | plt.figure(figsize=(8, 6)) 82 | nx.draw(G, pos=fixed_pos, with_labels=True, node_color='skyblue', node_size=3000, font_size=30, font_weight='bold', edge_color='gray', width=6) 83 | plt.savefig(f"{name}.png", format='png', dpi=300, 
bbox_inches='tight') 84 | plt.close() 85 | 86 | 87 | -------------------------------------------------------------------------------- /result_processing/ConfidenceInterval.m: -------------------------------------------------------------------------------- 1 | function [mean_accuracies, confidence_intervals] = calculate_confidence_intervals(accuracy_data, confidence_level) 2 | if nargin < 2 3 | confidence_level = 0.95; % 默认置信水平 4 | end 5 | 6 | [epochs, clients] = size(accuracy_data); % epochs代表行,clients代表列 7 | mean_accuracies = zeros(1, clients); % 每一列的平均值 8 | confidence_intervals = zeros(clients, 2); % 每列的置信区间 9 | 10 | for client = 1:clients 11 | accuracies = accuracy_data(:, client); % 取出每一列的数据 12 | mean_accuracy = mean(accuracies); % 计算每一列的平均值 13 | standard_error = std(accuracies) / sqrt(epochs); % 计算标准误差 14 | t_value = tinv(1 - (1 - confidence_level) / 2, epochs - 1); % 查找t分布临界值 15 | 16 | % 计算置信区间 17 | margin_of_error = t_value * standard_error; 18 | confidence_interval = [mean_accuracy - margin_of_error, mean_accuracy + margin_of_error]; 19 | 20 | mean_accuracies(client) = mean_accuracy; 21 | confidence_intervals(client, :) = confidence_interval; 22 | end 23 | end 24 | 25 | %此处需要输入自己的结果数据 26 | acc_data_B1=[]; 27 | acc_data_B2=[]; 28 | acc_data_B3=[]; 29 | 30 | acc_data_BA1=[]; 31 | acc_data_BA2=[]; 32 | acc_data_BA3=[]; 33 | [mean_accuracies1, confidence_intervals1] = calculate_confidence_intervals(acc_data_B2); 34 | [mean_accuracies2, confidence_intervals2] = calculate_confidence_intervals(acc_data_BA2); 35 | 36 | function plot_accuracies_with_confidence_intervals(mean_accuracies_1, confidence_intervals_1,mean_accuracies_2, confidence_intervals_2) 37 | epochs = 1:length(mean_accuracies_1); 38 | 39 | % 提取置信区间的上下界 40 | lower_bounds_1 = confidence_intervals_1(:, 1); 41 | upper_bounds_1 = confidence_intervals_1(:, 2); 42 | 43 | lower_bounds_2 = confidence_intervals_2(:, 1); 44 | upper_bounds_2 = confidence_intervals_2(:, 2); 45 | 46 | figure('Position', [100, 100, 750, 500]); % 设置图的大小 47 | 48 | % 绘制 DLora 的均值和置信区间 49 | plot(epochs, mean_accuracies_1, 'b', 'DisplayName', 'ConLoRA', 'LineWidth', 1.5); 50 | hold on; 51 | fill([epochs, fliplr(epochs)], [upper_bounds_1', fliplr(lower_bounds_1')], 'b', 'FaceAlpha', 0.2, 'EdgeColor', 'none','DisplayName','95% CI ConLoRA'); 52 | 53 | % 绘制 Lora 的均值和置信区间 54 | plot(epochs, mean_accuracies_2, 'r', 'DisplayName', 'LoRA-DFL', 'LineWidth', 1.5); 55 | fill([epochs, fliplr(epochs)], [upper_bounds_2', fliplr(lower_bounds_2')], 'r', 'FaceAlpha', 0.2, 'EdgeColor', 'none','DisplayName','95% CI LoRA-DFL'); 56 | 57 | % 设置标签、标题和图例 58 | xlabel('Epoch', 'FontSize', 12); 59 | ylabel('Accuracy', 'FontSize', 12); 60 | %title('Mean Accuracy with 95% Confidence Intervals for DLora & Lora', 'FontSize', 14); 61 | % 设置图例,并允许自定义位置 62 | legend('show', 'Location', 'southeast'); 63 | 64 | grid on; 65 | xlim([1 260]); 66 | ylim([0.4 0.95]) 67 | 68 | end 69 | 70 | 71 | plot_accuracies_with_confidence_intervals(mean_accuracies1, confidence_intervals1,mean_accuracies2, confidence_intervals2) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFL-LoRA 2 | 3 | DFL-LoRA is a platform for distributed fine-tuning of large models in a fully decentralized federated learning scenario. This platform allows you to perform fine-tuning of large models within a completely decentralized network. 
In particular, DFL-LoRA addresses an inherent source of error in decentralized federated learning: it offers the option of freezing the LoRA A matrices during training, which mitigates the consensus-error amplification caused by LoRA. 4 | 5 | ## How to Use 6 | 7 | ### 1. Download 8 | 9 | Clone the repository to your local folder: 10 | 11 | ```bash 12 | git clone https://github.com/LuYuPromax/DFL-LoRA.git 13 | ``` 14 | 15 | ### 2. Create environment 16 | 17 | Create a new Conda environment with Python 3.10.10: 18 | 19 | ``` 20 | conda create -n env_name python==3.10.10 21 | ``` 22 | 23 | Then install the required libraries: 24 | 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ### Datasets 30 | 31 | Currently, DFL-LoRA supports several GLUE datasets and the GSM8K dataset. To use them, follow these steps: 32 | 33 | 1. Go to the `utils` folder. 34 | 35 | 2. Modify the paths in `glue_split.sh` and `gsm8k_split.sh` to match your local dataset storage. 36 | 37 | 3. Run the following commands to split the datasets: 38 | 39 | ##### GLUE 40 | 41 | ``` 42 | chmod +x glue_split.sh 43 | ./glue_split.sh 44 | ``` 45 | 46 | ##### GSM8K 47 | 48 | ``` 49 | chmod +x gsm8k_split.sh 50 | ./gsm8k_split.sh 51 | ``` 52 | 53 | To view the data distribution across the different clients, run the `view_split.sh` script. 54 | 55 | ### Finetune 56 | 57 | To start the federated training process, run: 58 | 59 | ``` 60 | chmod +x train.sh 61 | ./train.sh 62 | ``` 63 | 64 | **Parameter Description** 65 | 66 | - **`--model_checkpoint` (str)**: Path to the pre-trained model checkpoint. 67 | - **`--dataset_path_template` (str)**: Template for client dataset paths (`{}` is replaced by the client number). 68 | - **`--val_dataset_path_template` (str)**: Path to the validation dataset (shared by all clients). 69 | - **`--num_clients` (int)**: Number of clients participating in federated learning. 70 | - **`--lora_r` (int)**: Rank of the LoRA layers. 71 | - **`--lora_alpha` (int)**: Alpha value of the LoRA layers. 72 | - **`--target_modules` (str)**: Comma-separated list of target modules for LoRA layers. 73 | - **`--training_type` (str)**: Type of training (`LoRA` or `ConLoRA`). 74 | - **`--dataset_type` (str)**: Dataset type (`sst2`, `mnli`, or `qnli`). 75 | - **`--name` (str)**: Name used to generate the weight matrix. 76 | - **`--num_rounds` (int, default: 256)**: Number of federated learning rounds. 77 | - **`--batch_size` (int, default: 128)**: Batch size for each client. 78 | - **`--log_path` (str, default: "federated_training.log")**: Path to save the log file. 79 | 80 | Training the LLaMA model follows the same pattern; see the scripts under `finetune/llama`. 81 | 82 | ### Notes 83 | - Ensure that your datasets are correctly placed and paths are properly configured in the scripts. 84 | - You can modify the default values of parameters according to your experiment setup. 85 | -------------------------------------------------------------------------------- /utils/generate_weight_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def generate_weight_matrix(name): 4 | """ 5 | Generates a weight matrix based on the provided name. 6 | 7 | Args: 8 | name (str): The name for selecting the weight matrix configuration. 9 | 10 | Returns: 11 | np.array: The weight matrix for aggregation. 
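    Example (illustrative): for the built-in 7-node ring topology "link3" every node has
    degree 2, so each neighbour weight is 1/(max(2, 2) + 1) = 1/3 and the diagonal entry
    is also 1/3; by construction every row of W sums to 1.

        >>> W = generate_weight_matrix("link3")
        >>> W.shape
        (7, 7)
        >>> np.allclose(W.sum(axis=1), 1.0)
        True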
12 | """ 13 | if name == "du14": 14 | A = np.array([ 15 | [0, 1, 1, 0, 0, 0, 0], 16 | [1, 0, 1, 0, 0, 0, 0], 17 | [1, 1, 0, 1, 0, 0, 0], 18 | [0, 0, 1, 0, 1, 0, 0], 19 | [0, 0, 0, 1, 0, 1, 0], 20 | [0, 0, 0, 0, 1, 0, 1], 21 | [0, 0, 0, 0, 0, 1, 0] 22 | ]) 23 | elif name == "link3": 24 | A = np.array([ 25 | [0, 1, 0, 0, 0, 0, 1], 26 | [1, 0, 1, 0, 0, 0, 0], 27 | [0, 1, 0, 1, 0, 0, 0], 28 | [0, 0, 1, 0, 1, 0, 0], 29 | [0, 0, 0, 1, 0, 1, 0], 30 | [0, 0, 0, 0, 1, 0, 1], 31 | [1, 0, 0, 0, 0, 1, 0] 32 | ]) 33 | elif name == "link4": 34 | A = np.array([ 35 | [0, 1, 1, 0, 0, 0, 1], 36 | [1, 0, 1, 1, 0, 0, 0], 37 | [1, 1, 0, 1, 0, 0, 0], 38 | [0, 1, 1, 0, 1, 0, 0], 39 | [0, 0, 0, 1, 0, 1, 1], 40 | [0, 0, 0, 0, 1, 0, 1], 41 | [1, 0, 0, 0, 1, 1, 0] 42 | ]) 43 | elif name == "link5": 44 | A = np.array([ 45 | [0, 1, 1, 0, 0, 1, 1], 46 | [1, 0, 1, 1, 0, 0, 1], 47 | [1, 1, 0, 1, 1, 0, 0], 48 | [0, 1, 1, 0, 1, 1, 0], 49 | [0, 0, 1, 1, 0, 1, 1], 50 | [1, 0, 0, 1, 1, 0, 1], 51 | [1, 1, 0, 0, 1, 1, 0] 52 | ]) 53 | elif name == "link6": 54 | A = np.array([ 55 | [0, 1, 1, 1, 0, 1, 1], 56 | [1, 0, 1, 1, 1, 1, 0], 57 | [1, 1, 0, 1, 1, 1, 0], 58 | [1, 1, 1, 0, 1, 1, 0], 59 | [0, 1, 1, 1, 0, 1, 1], 60 | [1, 1, 1, 1, 1, 0, 1], 61 | [1, 0, 0, 0, 1, 1, 0] 62 | ]) 63 | elif name == "link7": 64 | A = np.array([ 65 | [0, 1, 1, 1, 1, 1, 1], 66 | [1, 0, 1, 1, 1, 1, 1], 67 | [1, 1, 0, 1, 1, 1, 1], 68 | [1, 1, 1, 0, 1, 1, 1], 69 | [1, 1, 1, 1, 0, 1, 1], 70 | [1, 1, 1, 1, 1, 0, 1], 71 | [1, 1, 1, 1, 1, 1, 0] 72 | ]) 73 | else: 74 | raise ValueError(f"Unknown matrix name: {name}") 75 | 76 | # Calculate degree of each node and generate weight matrix 77 | degree = np.sum(A, axis=1) 78 | 79 | # Initialize weight matrix W 80 | W = np.zeros_like(A, dtype=float) 81 | 82 | # Calculate weight matrix 83 | for i in range(A.shape[0]): 84 | for j in range(A.shape[1]): 85 | if A[i, j] == 1: 86 | W[i, j] = 1 / (max(degree[i], degree[j]) + 1) 87 | elif i != j: 88 | W[i, j] = 0 89 | 90 | # Adjust diagonal elements to ensure each row sums to 1 91 | for i in range(A.shape[0]): 92 | W[i, i] = 1 - np.sum(W[i, np.arange(A.shape[0]) != i]) 93 | 94 | return W 95 | -------------------------------------------------------------------------------- /utils/gsm8k_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | import numpy as np 4 | import os 5 | 6 | class DataSplitter: 7 | def __init__(self, dataset_path, num_clients, alpha=None): 8 | """ 9 | Initializes the DataSplitter class. 10 | 11 | Args: 12 | dataset_path (str): Path to the dataset. 13 | num_clients (int): Number of clients. 14 | alpha (float): Dirichlet distribution parameter for non-IID split. 
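        Example (illustrative; the dataset path is a placeholder for a local copy of GSM8K):
            splitter = DataSplitter("/path/to/gsm8k", num_clients=7, alpha=0.25)
            splitter.dirichlet_split()
            print(splitter.compute_data_distribution())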
15 | """ 16 | self.dataset = load_dataset(dataset_path, "main") 17 | self.train_dataset = self.dataset['train'] 18 | self.num_clients = num_clients 19 | self.alpha = alpha 20 | self.clients_datasets = [] 21 | 22 | def iid_split(self): 23 | """Splits the dataset into IID partitions across clients.""" 24 | shuffled_dataset = self.train_dataset.shuffle(seed=42) 25 | client_size = len(shuffled_dataset) // self.num_clients 26 | self.clients_datasets = [] 27 | 28 | for i in range(self.num_clients): 29 | start_idx = i * client_size 30 | end_idx = (i + 1) * client_size if i != self.num_clients - 1 else len(shuffled_dataset) 31 | client_dataset = shuffled_dataset.select(range(start_idx, end_idx)) 32 | self.clients_datasets.append(client_dataset) 33 | 34 | def dirichlet_split(self): 35 | """Splits the dataset into non-IID partitions using Dirichlet distribution.""" 36 | total_size = len(self.train_dataset) 37 | proportions = np.random.dirichlet([self.alpha] * self.num_clients) 38 | client_sizes = (proportions * total_size).astype(int) 39 | client_sizes[-1] = total_size - client_sizes[:-1].sum() 40 | 41 | indices = np.random.permutation(total_size) 42 | current_index = 0 43 | self.clients_datasets = [] 44 | 45 | for size in client_sizes: 46 | client_indices = indices[current_index: current_index + size] 47 | self.clients_datasets.append(self.train_dataset.select(client_indices.tolist())) 48 | current_index += size 49 | 50 | def compute_data_distribution(self): 51 | """Computes the distribution of data among clients.""" 52 | return [len(client_dataset) for client_dataset in self.clients_datasets] 53 | 54 | def save_datasets(self, base_path): 55 | """Saves client datasets to disk.""" 56 | os.makedirs(base_path, exist_ok=True) 57 | for i, client_dataset in enumerate(self.clients_datasets): 58 | client_dataset.save_to_disk(os.path.join(base_path, f"client_{i+1}")) 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser(description="Dataset splitter for federated learning.") 63 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.") 64 | parser.add_argument('--output_path', type=str, required=True, help="Output path for saving client datasets.") 65 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.") 66 | parser.add_argument('--split_type', type=str, choices=['iid', 'dirichlet'], required=True, help="Type of split: 'iid' or 'dirichlet'.") 67 | parser.add_argument('--alpha', type=float, default=None, help="Alpha value for Dirichlet distribution.") 68 | 69 | args = parser.parse_args() 70 | 71 | # Initialize DataSplitter 72 | splitter = DataSplitter(args.dataset_path, args.num_clients, args.alpha) 73 | 74 | # Perform split 75 | if args.split_type == 'iid': 76 | splitter.iid_split() 77 | elif args.split_type == 'dirichlet': 78 | if args.alpha is None: 79 | parser.error("--alpha must be specified for Dirichlet split.") 80 | splitter.dirichlet_split() 81 | 82 | # Save datasets 83 | splitter.save_datasets(args.output_path) 84 | 85 | # Print data distribution 86 | distribution = splitter.compute_data_distribution() 87 | print("Client data distribution:", distribution) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /finetune/bert/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # --------------------------------------------------- 4 | # Train Script for Federated Learning 
with LoRA 5 | # --------------------------------------------------- 6 | 7 | # Default parameters (can be overwritten by command-line arguments) 8 | MODEL_CHECKPOINT="/home/ubuntu/smyin/models/distilbert-base-uncased" 9 | DATASET_PATH_TEMPLATE="/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_020/client_{}" 10 | VAL_DATASET_PATH_TEMPLATE="/home/ubuntu/smyin/dataset/glue/sst2" 11 | NUM_CLIENTS=7 12 | LORA_R=4 13 | LORA_ALPHA=32 14 | TARGET_MODULES="q_lin,v_lin,pre_classifier,classifer" # Comma-separated list of target modules for LoRA layers 15 | TRAINING_TYPE="ConLoRA" # Options: LoRA, ConLoRA 16 | DATASET_TYPE="sst2" # Options: sst2, mnli, qnli 17 | NAME="link3" # The name used to generate the weight matrix 18 | NUM_ROUNDS=5 # Default number of federated learning rounds 19 | BATCH_SIZE=128 # Default batch size for training 20 | LOG_PATH="federated_training.log" # Default path for logging 21 | 22 | # --------------------------------------------------- 23 | # Parsing command-line arguments to allow overrides 24 | # --------------------------------------------------- 25 | while [[ $# -gt 0 ]]; do 26 | case $1 in 27 | --model_checkpoint) 28 | MODEL_CHECKPOINT="$2" 29 | shift 2 30 | ;; 31 | --dataset_path_template) 32 | DATASET_PATH_TEMPLATE="$2" 33 | shift 2 34 | ;; 35 | --val_dataset_path_template) 36 | VAL_DATASET_PATH_TEMPLATE="$2" 37 | shift 2 38 | ;; 39 | --num_clients) 40 | NUM_CLIENTS="$2" 41 | shift 2 42 | ;; 43 | --lora_r) 44 | LORA_R="$2" 45 | shift 2 46 | ;; 47 | --lora_alpha) 48 | LORA_ALPHA="$2" 49 | shift 2 50 | ;; 51 | --target_modules) 52 | TARGET_MODULES="$2" 53 | shift 2 54 | ;; 55 | --training_type) 56 | TRAINING_TYPE="$2" 57 | shift 2 58 | ;; 59 | --dataset_type) 60 | DATASET_TYPE="$2" 61 | shift 2 62 | ;; 63 | --name) 64 | NAME="$2" 65 | shift 2 66 | ;; 67 | --num_rounds) 68 | NUM_ROUNDS="$2" 69 | shift 2 70 | ;; 71 | --batch_size) 72 | BATCH_SIZE="$2" 73 | shift 2 74 | ;; 75 | --log_path) 76 | LOG_PATH="$2" 77 | shift 2 78 | ;; 79 | *) 80 | echo "Unknown argument: $1" 81 | exit 1 82 | ;; 83 | esac 84 | done 85 | 86 | # --------------------------------------------------- 87 | # Print configuration 88 | # --------------------------------------------------- 89 | echo "Training Configuration:" 90 | echo "------------------------" 91 | echo "Model checkpoint: $MODEL_CHECKPOINT" 92 | echo "Dataset path template: $DATASET_PATH_TEMPLATE" 93 | echo "Validation dataset path: $VAL_DATASET_PATH_TEMPLATE" 94 | echo "Number of clients: $NUM_CLIENTS" 95 | echo "LoRA Rank: $LORA_R" 96 | echo "LoRA Alpha: $LORA_ALPHA" 97 | echo "Target modules: $TARGET_MODULES" 98 | echo "Training type: $TRAINING_TYPE" 99 | echo "Dataset type: $DATASET_TYPE" 100 | echo "Weight matrix name: $NAME" 101 | echo "Number of rounds: $NUM_ROUNDS" 102 | echo "Batch size: $BATCH_SIZE" 103 | echo "Log path: $LOG_PATH" 104 | echo "------------------------" 105 | 106 | # --------------------------------------------------- 107 | # Run the training script with the specified parameters 108 | # --------------------------------------------------- 109 | python train.py \ 110 | --model_checkpoint "$MODEL_CHECKPOINT" \ 111 | --dataset_path_template "$DATASET_PATH_TEMPLATE" \ 112 | --val_dataset_path_template "$VAL_DATASET_PATH_TEMPLATE" \ 113 | --num_clients "$NUM_CLIENTS" \ 114 | --lora_r "$LORA_R" \ 115 | --lora_alpha "$LORA_ALPHA" \ 116 | --target_modules "$TARGET_MODULES" \ 117 | --training_type "$TRAINING_TYPE" \ 118 | --dataset_type "$DATASET_TYPE" \ 119 | --name "$NAME" \ 120 | --num_rounds "$NUM_ROUNDS" \ 121 | 
--batch_size "$BATCH_SIZE" \ 122 | --log_path "$LOG_PATH" 123 | 124 | -------------------------------------------------------------------------------- /utils/view_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_from_disk 3 | from collections import Counter 4 | import os 5 | 6 | class DataLoader: 7 | def __init__(self, base_path, num_clients, dataset_type): 8 | """ 9 | Initializes the DataLoader class. 10 | 11 | Args: 12 | base_path (str): Path where the client datasets are stored. 13 | num_clients (int): Number of clients. 14 | dataset_type (str): Type of dataset ('glue' or 'gsm8k'). 15 | """ 16 | self.base_path = base_path 17 | self.num_clients = num_clients 18 | self.dataset_type = dataset_type.lower() 19 | self.clients_datasets = self.load_datasets() 20 | 21 | def load_datasets(self): 22 | """ 23 | Loads the client datasets from disk. 24 | 25 | Returns: 26 | list: A list of datasets for each client. 27 | """ 28 | clients_datasets = [] 29 | for i in range(self.num_clients): 30 | client_path = os.path.join(self.base_path, f"client_{i+1}") 31 | if os.path.exists(client_path): 32 | try: 33 | client_dataset = load_from_disk(client_path) 34 | clients_datasets.append(client_dataset) 35 | except Exception as e: 36 | print(f"Error loading dataset for Client {i+1}: {e}") 37 | clients_datasets.append(None) 38 | else: 39 | print(f"Warning: Client {i+1} dataset not found at {client_path}.") 40 | clients_datasets.append(None) 41 | return clients_datasets 42 | 43 | def compute_label_distribution(self, dataset): 44 | """ 45 | Computes the label distribution for the given dataset. 46 | 47 | Args: 48 | dataset (Dataset): A Hugging Face dataset. 49 | 50 | Returns: 51 | Counter: A counter of label occurrences. 52 | """ 53 | try: 54 | labels = [example['label'] for example in dataset] 55 | return Counter(labels) 56 | except KeyError: 57 | raise ValueError("The dataset does not contain a 'label' field.") 58 | 59 | def view_data_distribution(self): 60 | """ 61 | Views data distribution based on dataset type ('glue' or 'gsm8k'). 
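        For 'glue' datasets this prints a Counter of label occurrences per client;
        for 'gsm8k' it prints only the number of examples per client, since GSM8K
        has no 'label' field.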
62 | """ 63 | for i in range(self.num_clients): 64 | if self.clients_datasets[i]: 65 | if self.dataset_type == 'glue': 66 | try: 67 | print(f"Label distribution for Client {i+1}:") 68 | dist = self.compute_label_distribution(self.clients_datasets[i]) 69 | print(dist) 70 | except ValueError as e: 71 | print(f"Error processing Client {i+1}: {e}") 72 | elif self.dataset_type == 'gsm8k': 73 | print(f"Client {i+1} has {len(self.clients_datasets[i])} examples.") 74 | else: 75 | print(f"No dataset found for Client {i+1}.") 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser(description="Load and view the distribution of split datasets.") 79 | parser.add_argument('--base_path', type=str, required=True, help="Path to the directory containing client datasets.") 80 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.") 81 | parser.add_argument('--dataset_type', type=str, choices=['glue', 'gsm8k'], required=True, help="Type of dataset ('glue' or 'gsm8k').") 82 | 83 | args = parser.parse_args() 84 | 85 | # Validate base path 86 | if not os.path.exists(args.base_path): 87 | raise FileNotFoundError(f"The specified base path does not exist: {args.base_path}") 88 | 89 | # Validate number of clients 90 | available_clients = len([d for d in os.listdir(args.base_path) if os.path.isdir(os.path.join(args.base_path, d))]) 91 | if args.num_clients > available_clients: 92 | raise ValueError(f"The specified number of clients ({args.num_clients}) exceeds the available clients ({available_clients}).") 93 | 94 | # Load and view data distribution 95 | loader = DataLoader(args.base_path, args.num_clients, args.dataset_type) 96 | loader.view_data_distribution() 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /finetune/llama/gsm8k_eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import re 4 | import os 5 | os.environ['NCCL_P2P_DISABLE'] = '1' 6 | os.environ['NCCL_IB_DISABLE'] = '1' 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 8 | from dataclasses import dataclass, field 9 | from typing import Dict, Optional 10 | 11 | import torch 12 | import transformers 13 | 14 | 15 | from peft import PeftModel 16 | from datasets import load_dataset 17 | from accelerate.utils import set_seed 18 | 19 | 20 | IGNORE_INDEX = -100 21 | DEFAULT_PAD_TOKEN = "[PAD]" 22 | DEFAULT_EOS_TOKEN = "" 23 | DEFAULT_BOS_TOKEN = "" 24 | DEFAULT_UNK_TOKEN = "" 25 | ANSWER_PROMPT = "The final answer is: " 26 | QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n" 27 | 28 | 29 | def smart_tokenizer_and_embedding_resize( 30 | special_tokens_dict: Dict, 31 | tokenizer: transformers.PreTrainedTokenizer, 32 | model: transformers.PreTrainedModel, 33 | ): 34 | """Resize tokenizer and embedding. 35 | 36 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
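    Newly added special tokens are initialized to the mean of the pre-existing
    input and output embeddings.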
37 | """ 38 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 39 | model.resize_token_embeddings(len(tokenizer)) 40 | 41 | if num_new_tokens > 0: 42 | input_embeddings = model.get_input_embeddings().weight.data 43 | output_embeddings = model.get_output_embeddings().weight.data 44 | 45 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 46 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 47 | 48 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 49 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 50 | 51 | 52 | def evaluation(batch_size): 53 | output_dir = "./llama_lora_gsm8k" 54 | 55 | Modelpath='/home/ubuntu/smyin/models/Llama-3.2-1B' 56 | model = transformers.AutoModelForCausalLM.from_pretrained(Modelpath) 57 | tokenizer = transformers.AutoTokenizer.from_pretrained(Modelpath) 58 | 59 | 60 | special_tokens_dict = dict() 61 | if tokenizer.pad_token is None: 62 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 63 | if tokenizer.eos_token is None: 64 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 65 | if tokenizer.bos_token is None: 66 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 67 | if tokenizer.unk_token is None: 68 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 69 | 70 | smart_tokenizer_and_embedding_resize( 71 | special_tokens_dict=special_tokens_dict, 72 | tokenizer=tokenizer, 73 | model=model, 74 | ) 75 | 76 | model = model.to('cuda') 77 | 78 | dataset = load_dataset('/home/ubuntu/smyin/dataset/gsm8k', "main") 79 | test_set = dataset['test'] 80 | 81 | question = [f"{example['question']}{QUESTION_PROMPT}" for example in test_set] 82 | answer = [] 83 | 84 | # get numerical answer 85 | for example in test_set['answer']: 86 | ans = example.split('####')[-1] 87 | ans = ans.replace(',', '') # handle numbers like 2,000 88 | try: 89 | ans = float(ans) 90 | except ValueError: 91 | ans = float("inf") 92 | answer.append(ans) 93 | 94 | logging.warning("Tokenizing inputs...") 95 | eval_step = math.ceil(len(question)/batch_size) 96 | logging.warning(f"Total example: {len(question)} | eval batch size: {batch_size}" 97 | f"eval steps: {eval_step}") 98 | question_data = [] 99 | for i in range(eval_step): 100 | if i < eval_step - 1: 101 | batch = tokenizer( 102 | question[i*batch_size: (i+1)*batch_size], 103 | return_tensors="pt", 104 | padding="longest", 105 | ) 106 | else: 107 | batch = tokenizer( 108 | question[i*batch_size:], 109 | return_tensors="pt", 110 | padding="longest", 111 | ) 112 | batch['input_len'] = len(batch['input_ids'][0]) 113 | question_data.append(batch) 114 | 115 | model.eval() 116 | gen_kwargs = { 117 | "max_new_tokens": 256, 118 | "temperature": 0.1, 119 | "top_k": 40, 120 | "top_p": 0.95, 121 | "do_sample": True, 122 | } 123 | ans_pred_list = [] 124 | set_seed(42) 125 | for step, batch in enumerate(question_data): 126 | with torch.no_grad(): 127 | gen_kwargs["input_ids"] = batch["input_ids"].to('cuda') 128 | gen_kwargs["attention_mask"] = batch["attention_mask"].to('cuda') 129 | generated_tokens = model.generate(**gen_kwargs) 130 | 131 | pred_tokens = generated_tokens[:, batch['input_len']:] 132 | decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True) 133 | 134 | # Extract the numbers in sentences 135 | print(decoded_pred) 136 | ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred] 137 | 138 | print("prediction", ans_pred_list) 139 | print("ground truth", answer) 140 | 141 | 
accuracy = compute_accuracy(answer, ans_pred_list) 142 | 143 | print(f"GSM8K test accuracy: {100*accuracy:.2f}% | ") 144 | 145 | 146 | def extract_answer_number(sentence: str) -> float: 147 | sentence = sentence.replace(',', '') 148 | pred = [s for s in re.findall(r'-?\d+\.?\d*', sentence)] 149 | if not pred: 150 | return float('inf') 151 | segment = sentence.split(ANSWER_PROMPT) 152 | if len(segment) > 1: 153 | pred_answer = segment[1] 154 | pred_answer = [s for s in re.findall(r'-?\d+\.?\d*', pred_answer)] 155 | if len(pred_answer) > 0: 156 | pred_answer = pred_answer[0] 157 | else: 158 | pred_answer = float(pred[-1]) 159 | else: 160 | # use the last number as the answer 161 | pred_answer = float(pred[-1]) 162 | 163 | if isinstance(pred_answer, str): 164 | try: 165 | pred_answer = float(pred_answer) 166 | except ValueError as e: 167 | pred_answer = float('inf') 168 | return pred_answer 169 | 170 | 171 | def compute_accuracy(pred: list, gold: list): 172 | acc = 0.0 173 | for p, g in zip(pred, gold): 174 | if p == g: 175 | acc += 1 176 | 177 | return acc / len(pred) 178 | 179 | 180 | if __name__ == "__main__": 181 | batch_size=128 182 | evaluation(batch_size) 183 | 184 | 185 | -------------------------------------------------------------------------------- /utils/glue_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from datasets import load_dataset 4 | from collections import Counter 5 | import os 6 | 7 | class DataSplitter: 8 | def __init__(self, dataset_path, num_clients, alpha=None, min_size=None): 9 | """ 10 | Initializes the DataSplitter class. 11 | 12 | Args: 13 | dataset_path (str): Path to the dataset. 14 | num_clients (int): Number of clients. 15 | alpha (float): Dirichlet distribution parameter for non-IID split. 16 | min_size (int): Minimum size of data for a client (optional). 17 | """ 18 | self.dataset = load_dataset(dataset_path) 19 | self.train_dataset = self.dataset['train'] 20 | self.num_clients = num_clients 21 | self.alpha = alpha 22 | self.min_size = min_size 23 | self.clients_datasets = [] 24 | 25 | def iid_split(self): 26 | """ 27 | Splits the dataset into IID partitions across clients. 28 | """ 29 | shuffled_dataset = self.train_dataset.shuffle(seed=42) 30 | client_size = len(shuffled_dataset) // self.num_clients 31 | self.clients_datasets = [] 32 | 33 | for i in range(self.num_clients): 34 | start_idx = i * client_size 35 | end_idx = (i + 1) * client_size if i != self.num_clients - 1 else len(shuffled_dataset) 36 | client_dataset = shuffled_dataset.select(range(start_idx, end_idx)) 37 | self.clients_datasets.append(client_dataset) 38 | 39 | def dirichlet_split(self): 40 | """ 41 | Splits the dataset into non-IID partitions using Dirichlet distribution. 
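        For each class, the per-client shares are drawn from Dirichlet(alpha, ..., alpha),
        so smaller alpha values produce more skewed (more non-IID) label distributions.
        The split is re-drawn (up to 1000 attempts) until every client holds at least
        `min_size` examples.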
42 | """ 43 | self.clients_datasets = [[] for _ in range(self.num_clients)] 44 | labels = np.array([example['label'] for example in self.train_dataset]) 45 | num_classes = len(np.unique(labels)) 46 | 47 | class_indices = [np.where(labels == i)[0] for i in range(num_classes)] 48 | 49 | for c in range(num_classes): 50 | class_size = len(class_indices[c]) 51 | proportions = np.random.dirichlet(np.repeat(self.alpha, self.num_clients)) 52 | client_sizes = (proportions * class_size).astype(int) 53 | client_sizes[-1] = class_size - np.sum(client_sizes[:-1]) 54 | 55 | current_index = 0 56 | for i, size in enumerate(client_sizes): 57 | self.clients_datasets[i].extend(class_indices[c][current_index:current_index + size].tolist()) 58 | current_index += size 59 | 60 | # Ensure each client has enough data 61 | for attempt in range(1000): 62 | total_sizes = [len(client_dataset) for client_dataset in self.clients_datasets] 63 | if all(size >= self.min_size for size in total_sizes): 64 | break 65 | if any(size < self.min_size for size in total_sizes): 66 | self.clients_datasets = [[] for _ in range(self.num_clients)] 67 | for c in range(num_classes): 68 | class_size = len(class_indices[c]) 69 | proportions = np.random.dirichlet(np.repeat(self.alpha, self.num_clients)) 70 | client_sizes = (proportions * class_size).astype(int) 71 | client_sizes[-1] = class_size - np.sum(client_sizes[:-1]) 72 | current_index = 0 73 | for i, size in enumerate(client_sizes): 74 | self.clients_datasets[i].extend(class_indices[c][current_index:current_index + size].tolist()) 75 | current_index += size 76 | else: 77 | raise ValueError(f"Unable to adjust client sizes after {attempt + 1} attempts.") 78 | else: 79 | raise ValueError("Unable to generate a valid split after 1000 attempts.") 80 | 81 | # Convert indices to actual data and shuffle 82 | for i in range(self.num_clients): 83 | if len(self.clients_datasets[i]) == 0: 84 | print(f"Client {i} dataset is empty!") 85 | else: 86 | np.random.shuffle(self.clients_datasets[i]) 87 | self.clients_datasets[i] = self.train_dataset.select(self.clients_datasets[i]) 88 | 89 | def compute_label_distribution(self, dataset): 90 | """ 91 | Computes the label distribution for the given dataset. 92 | """ 93 | labels = [example['label'] for example in dataset] 94 | return Counter(labels) 95 | 96 | def label_distribution(self): 97 | """ 98 | Computes the label distribution for each client dataset. 99 | """ 100 | self.pct = [] 101 | for i in range(self.num_clients): 102 | self.pct.append(self.compute_label_distribution(self.clients_datasets[i])) 103 | 104 | def save_datasets(self, base_path): 105 | """ 106 | Saves the client datasets to disk. 
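        Each client's partition is written to <base_path>/client_{i+1} with
        Dataset.save_to_disk, which is the directory layout expected by view_split.py.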
107 | """ 108 | os.makedirs(base_path, exist_ok=True) 109 | for i, client_dataset in enumerate(self.clients_datasets): 110 | client_dataset.save_to_disk(os.path.join(base_path, f"client_{i+1}")) 111 | 112 | def main(): 113 | parser = argparse.ArgumentParser(description="Dataset splitter for federated learning.") 114 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.") 115 | parser.add_argument('--output_path', type=str, required=True, help="Output path for saving client datasets.") 116 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.") 117 | parser.add_argument('--split_type', type=str, choices=['iid', 'dirichlet'], required=True, help="Type of split: 'iid' or 'dirichlet'.") 118 | parser.add_argument('--alpha', type=float, default=None, help="Alpha value for Dirichlet distribution.") 119 | parser.add_argument('--min_size', type=int, default=0, help="Minimum size for each client dataset.") 120 | 121 | args = parser.parse_args() 122 | 123 | splitter = DataSplitter(args.dataset_path, args.num_clients, args.alpha, args.min_size) 124 | 125 | # Perform split 126 | if args.split_type == 'iid': 127 | splitter.iid_split() 128 | elif args.split_type == 'dirichlet': 129 | if args.alpha is None: 130 | parser.error("--alpha must be specified for Dirichlet split.") 131 | splitter.dirichlet_split() 132 | 133 | # Save datasets 134 | splitter.save_datasets(args.output_path) 135 | 136 | # Print label distribution 137 | splitter.label_distribution() 138 | print("Label distribution per client:") 139 | for i, dist in enumerate(splitter.pct): 140 | print(f"Client {i+1}: {dist}") 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /finetune/llama/gsm8k_train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | os.environ['NCCL_P2P_DISABLE'] = '1' 5 | os.environ['NCCL_IB_DISABLE'] = '1' 6 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 7 | from dataclasses import dataclass, field 8 | from typing import Dict, Optional, Sequence 9 | 10 | import torch 11 | import transformers 12 | from torch.utils.data import Dataset 13 | from transformers import Trainer,TrainingArguments 14 | from peft import LoraConfig, TaskType, get_peft_model 15 | from datasets import load_dataset 16 | 17 | 18 | IGNORE_INDEX = -100 19 | DEFAULT_PAD_TOKEN = "[PAD]" 20 | DEFAULT_EOS_TOKEN = "" 21 | DEFAULT_BOS_TOKEN = "" 22 | DEFAULT_UNK_TOKEN = "" 23 | ANSWER_PROMPT = "The final answer is: " 24 | QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n" 25 | 26 | 27 | 28 | def smart_tokenizer_and_embedding_resize( 29 | special_tokens_dict: Dict, 30 | tokenizer: transformers.PreTrainedTokenizer, 31 | model: transformers.PreTrainedModel, 32 | ): 33 | """Resize tokenizer and embedding. 34 | 35 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
36 | """ 37 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 38 | model.resize_token_embeddings(len(tokenizer)) 39 | 40 | if num_new_tokens > 0: 41 | input_embeddings = model.get_input_embeddings().weight.data 42 | output_embeddings = model.get_output_embeddings().weight.data 43 | 44 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 45 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 46 | 47 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 48 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 49 | 50 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 51 | """Tokenize a list of strings.""" 52 | tokenized_list = [ 53 | tokenizer( 54 | text, 55 | return_tensors="pt", 56 | padding="max_length", 57 | max_length=512, 58 | truncation=True, 59 | ) 60 | for text in strings 61 | ] 62 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 63 | input_ids_lens = labels_lens = [ 64 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 65 | ] 66 | return dict( 67 | input_ids=input_ids, 68 | labels=labels, 69 | input_ids_lens=input_ids_lens, 70 | labels_lens=labels_lens, 71 | ) 72 | 73 | class SupervisedDataset(Dataset): 74 | def __init__(self, raw_data, tokenizer): 75 | super(SupervisedDataset, self).__init__() 76 | 77 | sources = [f"{example['question']}{QUESTION_PROMPT}" for example in raw_data] 78 | targets = [f"{example['answer']}{tokenizer.eos_token}".replace("####", ANSWER_PROMPT) for example in raw_data] 79 | 80 | examples = [s + t for s, t in zip(sources, targets)] 81 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 82 | input_ids = examples_tokenized["input_ids"] 83 | labels = copy.deepcopy(input_ids) 84 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 85 | label[:source_len] = IGNORE_INDEX 86 | data_dict = dict(input_ids=input_ids, labels=labels) 87 | self.input_ids = data_dict["input_ids"] 88 | self.labels = data_dict["labels"] 89 | 90 | def __len__(self): 91 | return len(self.input_ids) 92 | 93 | def __getitem__(self, idx): 94 | return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]} 95 | 96 | @dataclass 97 | class DataCollatorForSupervisedDataset(object): 98 | """Collate examples for supervised fine-tuning.""" 99 | 100 | tokenizer: transformers.PreTrainedTokenizer 101 | 102 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 103 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 104 | input_ids = torch.nn.utils.rnn.pad_sequence( 105 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 106 | ) 107 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 108 | return dict( 109 | input_ids=input_ids, 110 | labels=labels, 111 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 112 | ) 113 | 114 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer) -> Dict: 115 | dataset = load_dataset('/home/ubuntu/smyin/datasets/gsm8k', "main") 116 | #train_set = dataset['train'].select(range(500)) 117 | train_set = dataset['train'] 118 | train_dataset = SupervisedDataset(raw_data=train_set, tokenizer=tokenizer) 119 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 120 | return 
dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 121 | 122 | 123 | def train(): 124 | output_dir = "./llama_lora_gsm8k3" 125 | 126 | model = transformers.AutoModelForCausalLM.from_pretrained('/home/ubuntu/smyin/models/Llama-3.2-1B-Instruct') 127 | 128 | peft_config = LoraConfig( 129 | task_type=TaskType.CAUSAL_LM, 130 | r=8, # 矩阵低秩近似的秩 131 | lora_alpha=32, 132 | lora_dropout=0.1, # LoRA 的 dropout 概率 133 | ) 134 | model = get_peft_model(model, peft_config) 135 | 136 | tokenizer = transformers.AutoTokenizer.from_pretrained('/home/ubuntu/smyin/models/Llama-3.2-1B-Instruct') 137 | 138 | special_tokens_dict = dict() 139 | if tokenizer.pad_token is None: 140 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 141 | if tokenizer.eos_token is None: 142 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 143 | if tokenizer.bos_token is None: 144 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 145 | if tokenizer.unk_token is None: 146 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 147 | 148 | smart_tokenizer_and_embedding_resize(special_tokens_dict=special_tokens_dict,tokenizer=tokenizer,model=model,) 149 | 150 | data_module = make_supervised_data_module(tokenizer=tokenizer) 151 | 152 | training_args = TrainingArguments( 153 | output_dir=output_dir, 154 | per_device_train_batch_size=5, 155 | gradient_accumulation_steps=4, 156 | num_train_epochs=5, 157 | save_steps=500, 158 | logging_steps=200, 159 | evaluation_strategy="no", 160 | ) 161 | 162 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 163 | trainer.train() 164 | trainer.save_state() 165 | trainer.save_model(output_dir) 166 | 167 | 168 | if __name__ == "__main__": 169 | train() -------------------------------------------------------------------------------- /finetune/bert/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import torch 5 | import argparse 6 | 7 | # 动态添加项目根目录到 sys.path 8 | current_dir = os.path.dirname(os.path.abspath(__file__)) # 当前文件所在目录 9 | parent_dir = os.path.dirname(os.path.dirname(current_dir)) # code 根目录 10 | sys.path.append(parent_dir) 11 | 12 | # 导入模块 13 | from utils.generate_weight_matrix import generate_weight_matrix 14 | from server import FederatedServer 15 | from client import Client 16 | 17 | def train_federated_model(model_checkpoint, dataset_path_template, val_dataset_path_template, num_clients, lora_r, lora_alpha, target_modules, training_type, dataset_type, name, num_rounds=256, batch_size=16, log_path="federated_training.log"): 18 | """ 19 | Runs the federated learning training process for the given number of rounds. 20 | 21 | Args: 22 | model_checkpoint (str): Path to the pre-trained model. 23 | dataset_path_template (str): Template for the dataset paths for each client. 24 | val_dataset_path_template (str): Path template for validation dataset. 25 | num_clients (int): Number of clients participating in federated learning. 26 | lora_r (int): LoRA rank for LoRA layers. 27 | lora_alpha (int): LoRA alpha for LoRA layers. 28 | target_modules (list): List of target modules for LoRA layers. 29 | training_type (str): Type of training ('LoRA', 'ConLoRA'). 30 | dataset_type (str): Type of dataset ('sst2', 'mnli', or 'qnli'). 31 | name (str): The name used to generate the weight matrix. 32 | num_rounds (int): Number of federated learning rounds (default is 256). 33 | log_path (str): Path to save the log file. 
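        batch_size (int): Per-client training batch size (default 16).

    Flow (as implemented below): one Client is built per data shard; the last-2-layer
    parameters and the LoRA A matrices are synchronized once; then in every round each
    client trains one local epoch, LoRA parameters are mixed with the topology weight
    matrix via aggregate_dfl, and per-client accuracy plus the average difference
    between clients' LoRA products are logged.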
34 | """ 35 | # Initialize clients 36 | clients = [] 37 | for i in range(num_clients): 38 | dataset_path = dataset_path_template.format(i+1) 39 | val_dataset_path = val_dataset_path_template 40 | client = Client(model_checkpoint=model_checkpoint, 41 | dataset_path=dataset_path, 42 | val_dataset_path=val_dataset_path, 43 | lora_r=lora_r, 44 | lora_alpha=lora_alpha, 45 | target_modules=target_modules, 46 | training_type=training_type, 47 | dataset_type=dataset_type, 48 | batch_size=batch_size, 49 | device='cuda' if torch.cuda.is_available() else 'cpu') 50 | clients.append(client) 51 | 52 | # Initialize the federated server with clients 53 | server = FederatedServer(clients) 54 | 55 | # Initial aggregation of parameters 56 | server.aggregate_last_2_layers_params() 57 | server.aggregate_lora_A() 58 | 59 | # Initialize accuracy and parameter difference tracking 60 | acc = [[] for _ in range(num_clients)] 61 | diff = [] 62 | 63 | # Get the weight matrix A for aggregation 64 | A = generate_weight_matrix(name) 65 | 66 | # Initial evaluation of all clients 67 | for i in range(server.num_clients): 68 | accuracy = server.clients[i].evaluate() 69 | acc[i].append(accuracy) 70 | logging.info(f'Initial accuracy of client {i+1}: {accuracy}') 71 | 72 | # Aggregate LoRA parameters with DFL 73 | server.aggregate_dfl(A) 74 | 75 | # Calculate and log the initial difference in LoRA products 76 | a = server.calculate_lora_products_and_avg_diff(server.new_params) 77 | diff.append(a) 78 | 79 | # Training loop for the given number of rounds 80 | for round in range(num_rounds): 81 | logging.info(f"Round {round+1}/{num_rounds}") 82 | 83 | # Train each client for one epoch 84 | for client in server.clients: 85 | loss = client.train_one_epoch() 86 | accuracy = client.evaluate() 87 | logging.info(f"Client {server.clients.index(client) + 1} - Loss: {loss}, Accuracy: {accuracy}") 88 | 89 | # Aggregate parameters with DFL 90 | server.aggregate_dfl(A) 91 | 92 | # Calculate and log the difference in LoRA products after aggregation 93 | a = server.calculate_lora_products_and_avg_diff(server.new_params) 94 | diff.append(a) 95 | logging.info(f"Round {round+1} - Parameter Difference: {a}") 96 | 97 | # Evaluate all clients after the round 98 | for i in range(server.num_clients): 99 | accuracy = server.clients[i].evaluate() 100 | acc[i].append(accuracy) 101 | logging.info(f'Accuracy of client {i+1} after round {round+1}: {accuracy}') 102 | 103 | # Log the parameter difference and accuracies after the round 104 | logging.info(f"Parameter Difference record: {diff}") 105 | logging.info(f'Accuracies after round {round+1}: {acc}') 106 | 107 | 108 | if __name__ == "__main__": 109 | # Define argument parser 110 | parser = argparse.ArgumentParser(description="Federated Learning Training Script") 111 | parser.add_argument('--model_checkpoint', type=str, required=True, help="Path to the model checkpoint.") 112 | parser.add_argument('--dataset_path_template', type=str, required=True, help="Template for the dataset paths for each client.") 113 | parser.add_argument('--val_dataset_path_template', type=str, required=True, help="Path template for validation dataset.") 114 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients participating in federated learning.") 115 | parser.add_argument('--lora_r', type=int, required=True, help="LoRA rank for LoRA layers.") 116 | parser.add_argument('--lora_alpha', type=int, required=True, help="LoRA alpha for LoRA layers.") 117 | parser.add_argument('--target_modules', 
type=str, required=True, help="Comma-separated list of target modules for LoRA layers.") 118 | parser.add_argument('--training_type', type=str, choices=['LoRA', 'ConLoRA'], required=True, help="Training type ('LoRA' or 'ConLoRA').") 119 | parser.add_argument('--dataset_type', type=str, choices=['sst2', 'mnli', 'qnli'], required=True, help="Dataset type.") 120 | parser.add_argument('--name', type=str, required=True, help="The name used to generate the weight matrix.") 121 | parser.add_argument('--num_rounds', type=int, default=256, help="Number of federated learning rounds.") 122 | parser.add_argument('--batch_size', type=int, default=128, help="Batch size for each client.") 123 | parser.add_argument('--log_path', type=str, default="federated_training.log", help="Path to save the log file.") 124 | 125 | args = parser.parse_args() 126 | 127 | # Parse target_modules into list 128 | target_modules = args.target_modules.split(',') 129 | 130 | # Configure logging to the specified log file path 131 | logging.basicConfig(filename=args.log_path, level=logging.INFO, 132 | format='%(asctime)s - %(levelname)s - %(message)s') 133 | 134 | # Run federated training process 135 | train_federated_model( 136 | model_checkpoint=args.model_checkpoint, 137 | dataset_path_template=args.dataset_path_template, 138 | val_dataset_path_template=args.val_dataset_path_template, 139 | num_clients=args.num_clients, 140 | lora_r=args.lora_r, 141 | lora_alpha=args.lora_alpha, 142 | target_modules=target_modules, 143 | training_type=args.training_type, 144 | dataset_type=args.dataset_type, 145 | name=args.name, 146 | num_rounds=args.num_rounds, 147 | batch_size=args.batch_size, 148 | log_path=args.log_path # Pass log_path parameter 149 | ) 150 | -------------------------------------------------------------------------------- /finetune/llama/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ['CUDA_VISIBLE_DEVICES'] = '0' 3 | import torch 4 | from itertools import combinations 5 | import numpy as np 6 | import logging 7 | from client import Client 8 | 9 | 10 | 11 | # 获取当前文件的绝对路径 12 | current_dir = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | # 确保日志目录存在 15 | log_dir = os.path.join(current_dir, 'logs') 16 | os.makedirs(log_dir, exist_ok=True) 17 | 18 | # 设置日志记录 19 | logging.basicConfig(filename=os.path.join(log_dir, 'LoRA.log'), 20 | level=logging.INFO, 21 | format='%(asctime)s - %(levelname)s - %(message)s') 22 | 23 | 24 | class FederatedServer: 25 | 26 | def __init__(self, clients): 27 | self.clients = clients 28 | self.num_clients=len(clients) 29 | 30 | #初始时为了使所有节点的LoRA参数一致 31 | def aggregate_lora_A(self): 32 | self.Avg_lora_A_params=self.clients[0].get_lora_A() 33 | for i in range(self.num_clients): 34 | self.clients[i].set_trainable_parameters(self.Avg_lora_A_params) 35 | 36 | def aggregate_dfl(self, A): 37 | self.new_params = [] 38 | 39 | # 初始化每个客户端的参数为零 40 | for i in range(self.num_clients): 41 | client_params = self.clients[i].get_lora_parameters() 42 | zero_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in client_params.items()} 43 | self.new_params.append(zero_params) 44 | 45 | # 聚合操作 46 | for i in range(self.num_clients): 47 | for j in range(self.num_clients): 48 | client_params = self.clients[j].get_lora_parameters() 49 | for name, (param, requires_grad) in client_params.items(): 50 | self.new_params[i][name] = (self.new_params[i][name][0] + param * A[i][j], requires_grad) 51 | 52 | for i in 
range(self.num_clients): 53 | self.clients[i].set_trainable_parameters(self.new_params[i]) 54 | 55 | 56 | # Extract LoRA A/B parameters and compute their products 57 | def extract_and_multiply_lora_params(self,param_group): 58 | result = {} 59 | for param_name, (param, _) in param_group.items(): 60 | if 'lora_B.default.weight' in param_name: 61 | prefix = param_name.split('lora_B.default.weight')[0] 62 | lora_A_name = prefix + 'lora_A.default.weight' 63 | lora_B_name = prefix + 'lora_B.default.weight' 64 | 65 | if lora_A_name in param_group and lora_B_name in param_group: 66 | lora_A = param_group[lora_A_name][0] 67 | lora_B = param_group[lora_B_name][0] 68 | 69 | product = torch.matmul(lora_B, lora_A) 70 | result[prefix + 'product'] = product 71 | 72 | return result 73 | 74 | 75 | # Compute the pairwise differences between all parameter groups and average them 76 | def calculate_lora_products_and_avg_diff(self,param_groups): 77 | if len(param_groups) < 2: 78 | raise ValueError("There should be at least two sets of parameters to calculate differences.") 79 | 80 | total_diff_sum = 0.0 81 | num_pairs = 0 82 | 83 | # Iterate over all pairwise combinations of parameter groups 84 | for i, j in combinations(range(len(param_groups)), 2): 85 | product_1 = self.extract_and_multiply_lora_params(param_groups[i]) 86 | product_2 = self.extract_and_multiply_lora_params(param_groups[j]) 87 | pair_diff_sum=0.0 88 | # Difference between the lora_B @ lora_A products 89 | for key in product_1.keys(): 90 | diff = product_1[key] - product_2[key] 91 | #pair_diff_sum += torch.sum(diff).item() 92 | pair_diff_sum += torch.norm(diff).item() 93 | 94 | # Difference between the remaining (non-LoRA) parameters 95 | for param_name in param_groups[i].keys(): 96 | if 'lora_A.default.weight' in param_name or 'lora_B.default.weight' in param_name: 97 | continue # skip lora_A and lora_B, already handled above 98 | diff = param_groups[i][param_name][0] - param_groups[j][param_name][0] 99 | #pair_diff_sum += torch.sum(diff).item() 100 | pair_diff_sum += torch.norm(diff).item() 101 | 102 | 103 | total_diff_sum += pair_diff_sum 104 | 105 | num_pairs += 1 106 | 107 | # Average difference over all pairs 108 | average_diff = total_diff_sum / num_pairs if num_pairs > 0 else 0.0 109 | return average_diff 110 | 111 | def get_A(name): 112 | if name=="du14": 113 | A = np.array([ 114 | [0, 1, 1, 0, 0, 0, 0], 115 | [1, 0, 1, 0, 0, 0, 0], 116 | [1, 1, 0, 1, 0, 0, 0], 117 | [0, 0, 1, 0, 1, 0, 0], 118 | [0, 0, 0, 1, 0, 1, 0], 119 | [0, 0, 0, 0, 1, 0, 1], 120 | [0, 0, 0, 0, 0, 1, 0] 121 | ]) 122 | elif name=="link3": 123 | A = np.array([ 124 | [0, 1, 0, 0, 0, 0, 1], 125 | [1, 0, 1, 0, 0, 0, 0], 126 | [0, 1, 0, 1, 0, 0, 0], 127 | [0, 0, 1, 0, 1, 0, 0], 128 | [0, 0, 0, 1, 0, 1, 0], 129 | [0, 0, 0, 0, 1, 0, 1], 130 | [1, 0, 0, 0, 0, 1, 0] 131 | ]) 132 | elif name=="link4": 133 | A= np.array([ 134 | [0, 1, 1, 0, 0, 0, 1], 135 | [1, 0, 1, 1, 0, 0, 0], 136 | [1, 1, 0, 1, 0, 0, 0], 137 | [0, 1, 1, 0, 1, 0, 0], 138 | [0, 0, 0, 1, 0, 1, 1], 139 | [0, 0, 0, 0, 1, 0, 1], 140 | [1, 0, 0, 0, 1, 1, 0] 141 | ]) 142 | elif name=="link5": 143 | A= np.array([ 144 | [0, 1, 1, 0, 0, 1, 1], 145 | [1, 0, 1, 1, 0, 0, 1], 146 | [1, 1, 0, 1, 1, 0, 0], 147 | [0, 1, 1, 0, 1, 1, 0], 148 | [0, 0, 1, 1, 0, 1, 1], 149 | [1, 0, 0, 1, 1, 0, 1], 150 | [1, 1, 0, 0, 1, 1, 0] 151 | ]) 152 | elif name=="link6": 153 | A= np.array([ 154 | [0, 1, 1, 1, 0, 1, 1], 155 | [1, 0, 1, 1, 1, 1, 0], 156 | [1, 1, 0, 1, 1, 1, 0], 157 | [1, 1, 1, 0, 1, 1, 0], 158 | [0, 1, 1, 1, 0, 1, 1], 159 | [1, 1, 1, 1, 1, 0, 1], 160 | [1, 0, 0, 0, 1, 1, 0] 161 | ]) 162 | elif name=="link7": 163 | A= np.array([ 164 | [0, 1, 1, 1, 1, 1, 1], 165 | [1, 0, 1, 1, 1, 1, 1], 166 | [1, 1, 0, 1, 1, 1, 1], 167 | [1, 1, 1, 0, 1, 1, 1], 168 | [1, 1, 1, 1, 0, 1, 1], 169 | [1, 1, 
1, 1, 1, 0, 1], 170 | [1, 1, 1, 1, 1, 1, 0] 171 | ]) 172 | 173 | degree = np.sum(A, axis=1) 174 | 175 | # Initialize the weight matrix W 176 | W = np.zeros_like(A, dtype=float) 177 | 178 | # Compute the off-diagonal mixing weights (Metropolis rule: 1 / (max(d_i, d_j) + 1) for each edge) 179 | for i in range(A.shape[0]): 180 | for j in range(A.shape[1]): 181 | if A[i, j] == 1: 182 | W[i, j] = 1 / (max(degree[i], degree[j]) + 1) 183 | elif i != j: 184 | W[i, j] = 0 185 | 186 | # Set the diagonal entries so that each row sums to 1 187 | for i in range(A.shape[0]): 188 | W[i, i] = 1 - np.sum(W[i, np.arange(A.shape[0]) != i]) 189 | 190 | 191 | return W 192 | 193 | 194 | if __name__=="__main__": 195 | model_checkpoint = '/home/ubuntu/smyin/models/Llama-3.2-1B' 196 | dataset_path_template = "/home/ubuntu/smyin/dataset/decentrilized_dataset/gsm8k/client_{}" 197 | 198 | clients = [] 199 | num_clients = 7 200 | type = "LoRA" 201 | name="link3" 202 | batch_size=2 203 | 204 | 205 | 206 | for i in range(num_clients): 207 | dataset_path = dataset_path_template.format(i+1) 208 | #device = f"cuda:{i % 4}" # round-robin assignment across 4 GPUs 209 | client = Client(model_path=model_checkpoint, data_path=dataset_path, type=type,batch_size=batch_size) 210 | clients.append(client) 211 | 212 | server = FederatedServer(clients) 213 | 214 | server.aggregate_lora_A() 215 | 216 | num_rounds = 3 217 | diff = [] 218 | 219 | A=get_A(name) 220 | 221 | server.aggregate_dfl(A) 222 | a = server.calculate_lora_products_and_avg_diff(server.new_params) 223 | diff.append(a) 224 | 225 | for round in range(num_rounds): 226 | logging.info(f"Round {round+1}/{num_rounds}") 227 | for client in server.clients: 228 | loss = client.train_one_epoch() 229 | logging.info(f"Client {server.clients.index(client) + 1} - Loss: {loss}") 230 | 231 | server.aggregate_dfl(A) 232 | a = server.calculate_lora_products_and_avg_diff(server.new_params) 233 | diff.append(a) 234 | logging.info(f"Round {round+1} - Parameter Difference: {a}") 235 | 236 | 237 | for i in range(server.num_clients): 238 | server.clients[i].trainable_parameters_nums() 239 | logging.info(f"Parameter Difference record: {diff}") 240 | -------------------------------------------------------------------------------- /finetune/bert/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import numpy as np 5 | from client import Client 6 | from typing import List, Dict 7 | 8 | 9 | class FederatedServer: 10 | def __init__(self, clients: List[Client]) -> None: 11 | """ 12 | Initializes the server with a list of clients. 13 | 14 | Args: 15 | clients (list): List of Client objects. 16 | """ 17 | self.clients = clients 18 | self.num_clients = len(clients) 19 | 20 | # Set up logging 21 | current_dir = os.path.dirname(os.path.abspath(__file__)) 22 | log_dir = os.path.join(current_dir, 'logs') 23 | os.makedirs(log_dir, exist_ok=True) 24 | logging.basicConfig( 25 | filename=os.path.join(log_dir, 'server_init.log'), 26 | level=logging.INFO, 27 | format='%(asctime)s - %(levelname)s - %(message)s' 28 | ) 29 | 30 | def clients_initialize_info(self) -> None: 31 | """ 32 | Initializes each client by loading the model and dataset. 
33 | """ 34 | logging.info("Initializing clients...") 35 | for i, client in enumerate(self.clients): 36 | logging.info(f"Initializing client {i+1}...") 37 | client.print_trainable_parameters() # Just to verify the model parameters 38 | accuracy = client.evaluate() # Initial evaluation 39 | logging.info(f"Initial accuracy of client {i+1}: {accuracy}") 40 | 41 | def aggregate_lora_A(self) -> None: 42 | """ 43 | Aggregate LoRA parameters from the first client and set them as trainable parameters 44 | for all other clients to ensure consistency. 45 | """ 46 | logging.info("Aggregating LoRA parameters...") 47 | avg_lora_A_params = self.clients[0].get_lora_A() 48 | 49 | # Set the LoRA A parameters for all clients 50 | for client in self.clients: 51 | client.set_trainable_parameters(avg_lora_A_params) 52 | 53 | def aggregate_last_2_layers_params(self) -> None: 54 | """ 55 | Aggregates the last 2 layers' parameters across all clients. 56 | """ 57 | logging.info("Aggregating last 2 layers' parameters...") 58 | avg_params = self.clients[0].get_last_2_layers() 59 | 60 | # Set the last 2 layers' parameters for all clients 61 | for client in self.clients: 62 | client.set_trainable_parameters(avg_params) 63 | 64 | def aggregate_dfl(self, A: np.array) -> None: 65 | """ 66 | Aggregates LoRA parameters using a weighted average, based on a matrix A. 67 | 68 | Args: 69 | A (np.array): A weight matrix used for aggregation. 70 | """ 71 | logging.info("Aggregating LoRA parameters with DFL method...") 72 | self.new_params = [] 73 | 74 | # Initialize each client's parameters as zero 75 | for i in range(self.num_clients): 76 | client_params = self.clients[i].get_lora_parameters() 77 | zero_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in client_params.items()} 78 | self.new_params.append(zero_params) 79 | 80 | # Perform aggregation based on weight matrix A 81 | for i in range(self.num_clients): 82 | for j in range(self.num_clients): 83 | client_params = self.clients[j].get_lora_parameters() 84 | for name, (param, requires_grad) in client_params.items(): 85 | self.new_params[i][name] = (self.new_params[i][name][0] + param * A[i][j], requires_grad) 86 | 87 | # Update each client's parameters 88 | for i in range(self.num_clients): 89 | self.clients[i].set_trainable_parameters(self.new_params[i]) 90 | 91 | logging.info("LoRA parameters aggregated with DFL.") 92 | 93 | def extract_and_multiply_lora_params(self, param_group: Dict[str, tuple]) -> Dict[str, torch.Tensor]: 94 | """ 95 | Extracts LoRA A and LoRA B parameters and computes their product. 96 | 97 | Args: 98 | param_group (dict): Dictionary of model parameters. 99 | 100 | Returns: 101 | dict: A dictionary containing the product of LoRA A and LoRA B. 102 | """ 103 | result = {} 104 | for param_name, (param, _) in param_group.items(): 105 | if 'lora_B.default.weight' in param_name: 106 | prefix = param_name.split('lora_B.default.weight')[0] 107 | lora_A_name = prefix + 'lora_A.default.weight' 108 | lora_B_name = prefix + 'lora_B.default.weight' 109 | 110 | if lora_A_name in param_group and lora_B_name in param_group: 111 | lora_A = param_group[lora_A_name][0] 112 | lora_B = param_group[lora_B_name][0] 113 | 114 | product = torch.matmul(lora_B, lora_A) 115 | result[prefix + 'product'] = product 116 | 117 | return result 118 | 119 | def calculate_lora_products_and_avg_diff(self, param_groups: List[Dict[str, tuple]]) -> float: 120 | """ 121 | Calculates the average difference between the LoRA parameter products for all clients. 
122 | 123 | Args: 124 | param_groups (list): A list of parameter groups from the clients. 125 | 126 | Returns: 127 | float: The average difference between LoRA parameter products. 128 | """ 129 | if len(param_groups) < 2: 130 | raise ValueError("There should be at least two sets of parameters to calculate differences.") 131 | 132 | # Calculate average LoRA parameters across all clients 133 | avg_params = self.clients[0].get_lora_parameters() 134 | avg_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in avg_params.items()} 135 | for i in range(self.num_clients): 136 | client_params = self.clients[i].get_lora_parameters() 137 | for name, (param, _) in client_params.items(): 138 | avg_params[name] = (avg_params[name][0] + param / self.num_clients, _) 139 | 140 | total_diff_sum = 0.0 141 | num_pairs = 0 142 | 143 | # Compute product differences for each pair of clients 144 | for i in range(self.num_clients): 145 | product_1 = self.extract_and_multiply_lora_params(param_groups[i]) 146 | product_2 = self.extract_and_multiply_lora_params(avg_params) 147 | pair_diff_sum = 0.0 148 | 149 | # Calculate the difference between the products 150 | for key in product_1.keys(): 151 | diff = product_1[key] - product_2[key] 152 | pair_diff_sum += torch.norm(diff).item() 153 | 154 | total_diff_sum += pair_diff_sum 155 | num_pairs += 1 156 | 157 | # Return the average difference 158 | average_diff = total_diff_sum / num_pairs if num_pairs > 0 else 0.0 159 | return average_diff 160 | 161 | 162 | def main() -> None: 163 | """ 164 | Main function to initialize clients and federated server, and perform client initialization. 165 | """ 166 | # Configuration parameters 167 | model_checkpoint = '/home/ubuntu/smyin/models/distilbert-base-uncased' 168 | dataset_path_template = "/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_020/client_{}" 169 | val_dataset_path_template = "/home/ubuntu/smyin/dataset/glue/sst2" 170 | num_clients = 7 171 | lora_r = 4 172 | lora_alpha = 32 173 | target_modules = ["q_lin", "v_lin", "pre_classifier", "pre_classifier"] # Example modules 174 | training_type = "LoRA" 175 | dataset_type = "sst2" 176 | batch_size=128 177 | 178 | # Initialize clients 179 | clients = [] 180 | for i in range(num_clients): 181 | dataset_path = dataset_path_template.format(i + 1) 182 | val_dataset_path = val_dataset_path_template 183 | client = Client( 184 | model_checkpoint=model_checkpoint, 185 | dataset_path=dataset_path, 186 | val_dataset_path=val_dataset_path, 187 | lora_r=lora_r, 188 | lora_alpha=lora_alpha, 189 | target_modules=target_modules, 190 | training_type=training_type, 191 | dataset_type=dataset_type, 192 | device='cuda' if torch.cuda.is_available() else 'cpu', 193 | batch_size=batch_size 194 | ) 195 | clients.append(client) 196 | 197 | # Initialize federated server 198 | server = FederatedServer(clients) 199 | 200 | # Initialize and evaluate the clients 201 | server.clients_initialize_info() 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /finetune/bert/client.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding 6 | from datasets import load_dataset, load_from_disk 7 | from peft import get_peft_model, LoraConfig 8 | from tqdm.auto import tqdm 9 | from transformers import get_scheduler 10 | from sklearn.metrics import accuracy_score 11 | 12 | 13 | class Client: 14 | def __init__(self, model_checkpoint, dataset_path, val_dataset_path, lora_r, lora_alpha, target_modules, 15 | training_type, dataset_type, device='cpu', batch_size=128): 16 | """ 17 | Initializes the Client for model training with specified configurations. 18 | 19 | Args: 20 | model_checkpoint (str): Pre-trained model checkpoint path. 21 | dataset_path (str): Path to the training dataset. 22 | val_dataset_path (str): Path to the validation dataset. 23 | lora_r (int): Rank of LoRA layers. 24 | lora_alpha (int): Alpha value for LoRA layers. 25 | target_modules (list): List of target modules for LoRA. 26 | training_type (str): Type of training ('LoRA' or 'ConLoRA'). 27 | dataset_type (str): Type of dataset ('sst2', 'mnli', or 'qnli'). 28 | device (str): Device for training ('cpu' or 'cuda'). 29 | batch_size (int): Batch size for the DataLoader. 30 | """ 31 | self.device = device 32 | self.lora_r = lora_r 33 | self.lora_alpha = lora_alpha 34 | self.training_type = training_type 35 | self.dataset_type = dataset_type 36 | self.dataset = load_from_disk(dataset_path) 37 | self.raw_val_dataset = load_dataset(val_dataset_path) 38 | self.batch_size = batch_size # Save batch size as an instance variable 39 | 40 | # Determine number of labels based on dataset type 41 | self.num_labels = self._get_num_labels() 42 | 43 | # Load tokenizer and model 44 | self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True) 45 | self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=self.num_labels).to(self.device) 46 | 47 | # Configure model for LoRA training 48 | self.configure_model(target_modules) 49 | 50 | # Prepare data loaders for training and validation datasets 51 | self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer) 52 | self.train_data = self.prepare_data(self.dataset) 53 | self.val_data = self.prepare_data(self.raw_val_dataset['validation']) 54 | 55 | # Optimizer setup 56 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 57 | 58 | def _get_num_labels(self): 59 | """ 60 | Returns the number of labels based on the dataset type. 
61 | """ 62 | if self.dataset_type == "sst2": 63 | return 2 # SST-2 is a binary classification problem 64 | elif self.dataset_type == "mnli": 65 | return 3 # MNLI is a 3-class classification problem 66 | elif self.dataset_type == "qnli": 67 | return 2 # QNLI is a binary classification problem 68 | else: 69 | raise ValueError(f"Unsupported dataset type: {self.dataset_type}") 70 | 71 | def tokenize_function(self, examples): 72 | """ 73 | Tokenizes the input text based on dataset type (sst2, mnli, qnli). 74 | Args: 75 | examples (dict): A dictionary of input examples. 76 | """ 77 | if self.dataset_type == "sst2": 78 | text = examples["sentence"] 79 | self.tokenizer.truncation_side = "left" 80 | tokenized_inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True) 81 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long) 82 | 83 | elif self.dataset_type == "mnli": 84 | texts = (examples["premise"], examples["hypothesis"]) 85 | self.tokenizer.truncation_side = "left" 86 | tokenized_inputs = self.tokenizer(*texts, return_tensors="pt", truncation=True, max_length=512, padding=True) 87 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long) 88 | 89 | elif self.dataset_type == "qnli": 90 | texts = (examples["question"], examples["sentence"]) 91 | self.tokenizer.truncation_side = "left" 92 | tokenized_inputs = self.tokenizer(*texts, return_tensors="pt", truncation=True, max_length=512, padding=True) 93 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long) 94 | 95 | return tokenized_inputs 96 | 97 | def prepare_data(self, dataset): 98 | """ 99 | Prepares the dataset for training or evaluation by tokenizing and formatting it. 100 | 101 | Args: 102 | dataset (Dataset): The dataset to be processed. 103 | """ 104 | tokenized_dataset = dataset.map(self.tokenize_function, batched=True) 105 | tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) 106 | return DataLoader(tokenized_dataset, batch_size=self.batch_size, collate_fn=self.data_collator) 107 | 108 | def configure_model(self, target_modules): 109 | """ 110 | Configures the model for LoRA (Low-Rank Adaptation) training. 111 | 112 | Args: 113 | target_modules (list): List of target modules for LoRA. 114 | """ 115 | if self.tokenizer.pad_token is None: 116 | self.tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 117 | self.model.resize_token_embeddings(len(self.tokenizer)) 118 | 119 | if self.training_type == "LoRA": 120 | peft_config = LoraConfig(task_type="SEQ_CLS", r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=0.01, target_modules=target_modules) 121 | self.model = get_peft_model(self.model, peft_config) 122 | self._freeze_parameters() 123 | 124 | elif self.training_type == "ConLoRA": 125 | peft_config = LoraConfig(task_type="SEQ_CLS", r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=0.01, target_modules=target_modules) 126 | self.model = get_peft_model(self.model, peft_config) 127 | self._freeze_parameters(freeze_lora_A=True) 128 | 129 | def _freeze_parameters(self, freeze_lora_A=False): 130 | """ 131 | Freezes the parameters of the model based on the training type. 132 | 133 | Args: 134 | freeze_lora_A (bool): Whether to freeze 'lora_A' parameters. 
135 | """ 136 | for name, param in self.model.named_parameters(): 137 | if "pre_classifier.modules_to_save.default.base_layer" in name or "classifier.modules_to_save.default.base_layer" in name: 138 | param.requires_grad = False 139 | if freeze_lora_A and "lora_A" in name: 140 | param.requires_grad = False 141 | print(self.model.print_trainable_parameters()) 142 | 143 | def train_one_epoch(self): 144 | """ 145 | Trains the model for one epoch. 146 | 147 | Returns: 148 | loss (float): The loss value for the epoch. 149 | """ 150 | self.model.train() 151 | progress_bar = tqdm(range(len(self.train_data))) 152 | lr_scheduler = get_scheduler("linear", optimizer=self.optimizer, num_warmup_steps=0, num_training_steps=len(self.train_data)) 153 | 154 | for batch in self.train_data: 155 | batch = {k: v.to(self.device) for k, v in batch.items()} 156 | outputs = self.model(**batch) 157 | loss = outputs.loss 158 | loss.backward() 159 | self.optimizer.step() 160 | #lr_scheduler.step() 161 | self.optimizer.zero_grad() 162 | progress_bar.update(1) 163 | 164 | return loss.item() 165 | 166 | def evaluate(self): 167 | """ 168 | Evaluates the model on the validation dataset. 169 | 170 | Returns: 171 | accuracy (float): The accuracy on the validation dataset. 172 | """ 173 | self.model.eval() 174 | all_predictions = [] 175 | all_labels = [] 176 | 177 | for batch in self.val_data: 178 | batch = {k: v.to(self.device) for k, v in batch.items()} 179 | with torch.no_grad(): 180 | outputs = self.model(**batch) 181 | logits = outputs.logits 182 | predictions = torch.argmax(logits, dim=-1) 183 | all_predictions.extend(predictions.cpu().numpy()) 184 | all_labels.extend(batch["labels"].cpu().numpy()) 185 | 186 | accuracy = accuracy_score(all_labels, all_predictions) 187 | return accuracy 188 | 189 | def print_trainable_parameters(self): 190 | """ 191 | Prints the number of trainable and total parameters in the model. 192 | """ 193 | total_params = sum(p.numel() for p in self.model.parameters()) 194 | trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 195 | 196 | print(f"Total parameters: {total_params}") 197 | print(f"Trainable parameters: {trainable_params}") 198 | 199 | def num_get_trainable_parameters(self): 200 | """ 201 | Prints detailed information about each trainable parameter. 202 | """ 203 | trainable_params = self.get_trainable_parameters() 204 | total_trainable_params = sum(param.numel() for param, _ in trainable_params.values()) 205 | print(f"Total number of trainable parameters: {total_trainable_params}") 206 | 207 | print("Details of trainable parameters:") 208 | for name, (param, requires_grad) in trainable_params.items(): 209 | print(f"Layer: {name}, Parameter count: {param.numel()}, Requires Grad: {requires_grad}") 210 | 211 | def get_trainable_parameters(self): 212 | """ 213 | Returns the trainable parameters of the model. 214 | """ 215 | trainable_params = {} 216 | for name, param in self.model.named_parameters(): 217 | if param.requires_grad: 218 | trainable_params[name] = (param.clone().detach(), param.requires_grad) 219 | return trainable_params 220 | 221 | def get_lora_parameters(self): 222 | """ 223 | Returns the parameters associated with LoRA layers. 224 | 225 | LoRA parameters typically include layers with names containing 'lora_A' or 'lora_B'. 226 | 227 | Returns: 228 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values. 
229 | """ 230 | lora_params = {} 231 | for name, param in self.model.named_parameters(): 232 | if "lora_A" in name or "lora_B" in name: 233 | # Clone the parameter and preserve the 'requires_grad' property 234 | lora_params[name] = (param.clone().detach(), param.requires_grad) 235 | return lora_params 236 | 237 | def get_all_parameters(self): 238 | """ 239 | Returns all parameters of the model, including their names and whether they require gradients. 240 | 241 | Returns: 242 | dict: A dictionary with parameter names as keys and tuples of (cloned parameter, requires_grad) as values. 243 | """ 244 | all_params = {} 245 | for name, param in self.model.named_parameters(): 246 | all_params[name] = (param.clone().detach(), param.requires_grad) 247 | return all_params 248 | 249 | def get_last_2_layers(self): 250 | """ 251 | Returns the parameters for the last two layers of the model, typically the classification layers. 252 | 253 | These layers are often named 'pre_classifier' and 'classifier'. 254 | 255 | Returns: 256 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values. 257 | """ 258 | last_2_layers_params = {} 259 | for name, param in self.model.named_parameters(): 260 | if "pre_classifier" in name or "classifier" in name: 261 | last_2_layers_params[name] = (param.clone().detach(), param.requires_grad) 262 | return last_2_layers_params 263 | 264 | def get_lora_A(self): 265 | """ 266 | Returns the parameters for the 'lora_A' layers in the model. 267 | 268 | Returns: 269 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values. 270 | """ 271 | lora_A_params = {} 272 | for name, param in self.model.named_parameters(): 273 | if "lora_A" in name: 274 | lora_A_params[name] = (param.clone().detach(), param.requires_grad) 275 | return lora_A_params 276 | 277 | def set_trainable_parameters(self, trainable_params): 278 | """ 279 | Sets the trainable parameters for the model based on the provided parameters. 280 | 281 | Args: 282 | trainable_params (dict): A dictionary with parameter names as keys and tuples of (parameter, requires_grad) as values. 
283 | """ 284 | for name, (param, requires_grad) in trainable_params.items(): 285 | if name in dict(self.model.named_parameters()): 286 | current_param = dict(self.model.named_parameters())[name] 287 | current_param.data = param.clone().to(self.device) 288 | current_param.requires_grad = requires_grad 289 | 290 | 291 | def main(): 292 | parser = argparse.ArgumentParser(description="Client model training script.") 293 | parser.add_argument('--target_modules', type=str, required=True, help="List of target modules for LoRA.") 294 | parser.add_argument('--model_checkpoint', type=str, required=True, help="Path to the model checkpoint.") 295 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.") 296 | parser.add_argument('--lora_r', type=int, required=True, help="Rank of LoRA layers.") 297 | parser.add_argument('--lora_alpha', type=int, required=True, help="Alpha value for LoRA layers.") 298 | parser.add_argument('--training_type', type=str, choices=['LoRA', 'ConLoRA'], required=True, help="Training type.") 299 | parser.add_argument('--device', type=str, default='cpu', help="Device for training (default: cpu).") 300 | parser.add_argument('--num_epochs', type=int, default=5, help="Number of epochs to train (default: 5).") 301 | parser.add_argument('--val_dataset_path', type=str, required=True, help="Path to the validation dataset.") 302 | parser.add_argument('--dataset_type', type=str, choices=['sst2', 'mnli', 'qnli'], required=True, help="Type of dataset.") 303 | parser.add_argument('--batch_size', type=int, default=128, help="Batch size for training.") 304 | 305 | args = parser.parse_args() 306 | 307 | # Convert target_modules string to list 308 | target_modules = args.target_modules.split() 309 | 310 | client = Client( 311 | model_checkpoint=args.model_checkpoint, 312 | dataset_path=args.dataset_path, 313 | val_dataset_path=args.val_dataset_path, 314 | lora_r=args.lora_r, 315 | lora_alpha=args.lora_alpha, 316 | training_type=args.training_type, 317 | target_modules=target_modules, 318 | device=args.device, 319 | dataset_type=args.dataset_type, 320 | batch_size=args.batch_size # Pass batch_size from arguments 321 | ) 322 | 323 | # Training loop 324 | for epoch in range(args.num_epochs): 325 | loss = client.train_one_epoch() 326 | accuracy = client.evaluate() 327 | print(f"Epoch {epoch + 1}/{args.num_epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}") 328 | 329 | client.print_trainable_parameters() 330 | 331 | 332 | if __name__ == "__main__": 333 | main() 334 | -------------------------------------------------------------------------------- /result_processing/avgResult.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | example_data=[[0.324095771777891, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.337238920020377, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.32460519612837496, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 
0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.32745797249108505, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713], [0.324095771777891, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3271523178807947, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 
0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883], [0.324095771777891, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.32592969943963324, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3174732552215996, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 
0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.32745797249108505, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3271523178807947, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3183902190524707, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3182883341823739, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 
0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883], [0.324095771777891, 0.3273560876209883, 0.31818644931227713, 0.3195109526235354, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3552725420275089, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3183902190524707, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3543555781966378, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3182883341823739, 0.31818644931227713, 0.3544574630667346, 
0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.3543555781966378, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.35394803871625063, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31808456444218036, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 
0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.32674477840040755, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3268466632705043, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.31808456444218036, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3182883341823739, 
0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3272542027508915, 0.3544574630667346, 0.35455934793683136, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.35455934793683136, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3543555781966378, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.327967396841569, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3268466632705043, 0.3273560876209883, 0.32766174223127864, 0.3272542027508915, 0.31818644931227713, 0.3182883341823739, 0.32755985736118187, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883]]
4 | 
5 | # Convert to a NumPy array
6 | data = np.array(example_data)
7 | 
8 | # Compute the mean for each epoch (axis 0 averages across the result lists)
9 | epoch_means = np.mean(data, axis=0)
10 | #print("Mean for each epoch:", epoch_means)
11 | 
12 | # Take the mean over the last 50 rounds
13 | # This example takes the last 10 rounds because the sample data only has 10 rounds; in actual use, replace this with the last 50 rounds of your data
14 | last_50_epoch_means = epoch_means[-10:]  # replace -10 with -50 to match your data length
15 | 
16 | # Compute the mean
17 | mean_last_50 = np.mean(last_50_epoch_means)
18 | print("Mean over the last 50 rounds:", mean_last_50)
19 | 
20 | # Compute the confidence interval (95% confidence level)
21 | #confidence_level = 0.95
22 | #degrees_freedom = len(last_50_epoch_means) - 1
23 | #sample_mean = np.mean(last_50_epoch_means)
24 | #sample_standard_error = stats.sem(last_50_epoch_means)
25 | #confidence_interval = stats.t.interval(confidence_level, degrees_freedom, sample_mean, sample_standard_error)
26 | 
27 | #print("95% confidence interval of the mean over the last 50 rounds:", confidence_interval)
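28 | 
29 | # A minimal sketch for enabling the interval above; it assumes `from scipy import stats`
30 | # and `import numpy as np` are imported near the top of this file:
31 | #
32 | #   confidence_interval = stats.t.interval(0.95, df=len(last_50_epoch_means) - 1,
33 | #                                          loc=np.mean(last_50_epoch_means),
34 | #                                          scale=stats.sem(last_50_epoch_means))
35 | #   print("95% confidence interval of the mean over the last 50 rounds:", confidence_interval)
--------------------------------------------------------------------------------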