├── __init__.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   └── generate_weight_matrix.cpython-310.pyc
│   ├── view_split.sh
│   ├── gsm8k_split.sh
│   ├── glue_split.sh
│   ├── generate_weight_matrix.py
│   ├── gsm8k_split.py
│   ├── view_split.py
│   └── glue_split.py
├── finetune
│   ├── __init__.py
│   ├── bert
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── client.cpython-310.pyc
│   │   │   ├── server.cpython-310.pyc
│   │   │   ├── train.cpython-310.pyc
│   │   │   └── __init__.cpython-310.pyc
│   │   ├── client.sh
│   │   ├── train.sh
│   │   ├── train.py
│   │   ├── server.py
│   │   └── client.py
│   ├── llama
│   │   ├── __init__.py
│   │   ├── client.py
│   │   ├── gsm8k_eval.py
│   │   ├── gsm8k_train.py
│   │   └── server.py
│   └── __pycache__
│       └── __init__.cpython-310.pyc
├── __pycache__
│   └── __init__.cpython-312.pyc
├── result_processing
│   ├── errorNoniid.m
│   ├── errorBar.m
│   ├── heatmap.py
│   ├── topology.py
│   ├── ConfidenceInterval.m
│   └── avgResult.py
├── requirements.txt
├── README.md
└── LICENSE
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/finetune/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/finetune/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/finetune/llama/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/finetune/llama/client.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/view_split.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set default parameters
4 | BASE_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentrilized_dataset/gsm8k"
5 | NUM_CLIENTS=7
6 | DATASET_TYPE="gsm8k"  # or "glue"
7 |
8 | # Run the Python script
9 | python3 view_split.py \
10 | --base_path "$BASE_PATH" \
11 | --num_clients $NUM_CLIENTS \
12 | --dataset_type "$DATASET_TYPE"
--------------------------------------------------------------------------------
/utils/gsm8k_split.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set default parameters
4 | DATASET_PATH="/home/ubuntu/smyin/ConLoRA/datasets/gsm8k"
5 | OUTPUT_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentralized_dataset/gsm8k22222"
6 | NUM_CLIENTS=7
7 | SPLIT_TYPE="dirichlet"
8 | ALPHA=0.25
9 |
10 | # Run the Python script
11 | python gsm8k_split.py \
12 | --dataset_path "$DATASET_PATH" \
13 | --output_path "$OUTPUT_PATH" \
14 | --num_clients $NUM_CLIENTS \
15 | --split_type "$SPLIT_TYPE" \
16 | --alpha $ALPHA
17 |
--------------------------------------------------------------------------------
/utils/glue_split.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set default parameters
4 | DATASET_PATH="/home/ubuntu/smyin/ConLoRA/datasets/glue/qnli"
5 | OUTPUT_PATH="/home/ubuntu/smyin/ConLoRA/datasets/decentrilized_dataset/qnli22222"
6 | NUM_CLIENTS=4
7 | SPLIT_TYPE="dirichlet"
8 | ALPHA=0.25
9 | MIN_SIZE=1500
10 |
11 | # Run the Python script
12 | python3 glue_split.py \
13 | --dataset_path "$DATASET_PATH" \
14 | --output_path "$OUTPUT_PATH" \
15 | --num_clients $NUM_CLIENTS \
16 | --split_type "$SPLIT_TYPE" \
17 | --alpha $ALPHA \
18 | --min_size $MIN_SIZE
19 |
--------------------------------------------------------------------------------
/result_processing/errorNoniid.m:
--------------------------------------------------------------------------------
1 | % Define the data
2 | x = 1:256; % x-axis data (epochs 1 to 256)
3 | % Result data (fill in your own values)
4 | DLoRA_iid = [];
5 | LoRA_DFL_iid = [];
6 | DLoRA_noniid = [];
7 | LoRA_DFL_noniid = [];
8 | % Plot the curves
9 | figure;
10 | plot(x, DLoRA_iid, '-', 'LineWidth', 1.5); % ConLoRA, IID
11 | hold on;
12 | plot(x, LoRA_DFL_iid, '-', 'LineWidth', 1.5); % LoRA-DFL, IID
13 | plot(x, DLoRA_noniid, '-', 'LineWidth', 1.5); % ConLoRA, non-IID
14 | plot(x, LoRA_DFL_noniid, '-', 'LineWidth', 1.5); % LoRA-DFL, non-IID
15 | hold off;
16 |
17 | % Legend
18 | legend('ConLoRA-IID', 'LoRA-DFL-IID', 'ConLoRA-nonIID', 'LoRA-DFL-nonIID', 'Location', 'northwest');
19 |
20 | % Title and axis labels
21 | %title('Line Plot of Four Data Groups');
22 | xlabel('Epoch');
23 | ylabel('Consensus error');
24 |
25 | % Add grid
26 | grid on;
27 |
28 | % Adjust axis limits
29 | xlim([1 260]);
30 | ylim([0 22]);
31 |
--------------------------------------------------------------------------------
/result_processing/errorBar.m:
--------------------------------------------------------------------------------
1 | % Define the data
2 | network_density = [3, 4, 5, 6];
3 | error_values_group1 = [12.42, 9.84, 5.37, 4.29];
4 | error_values_group2 = [1.70, 1.38, 0.82, 0.53];
5 |
6 | % Combine the two groups into a single matrix
7 | error_values = [error_values_group1; error_values_group2]';
8 |
9 | % Draw the grouped bar chart
10 | figure;
11 | b = bar(network_density, error_values, 'grouped'); % returns the bar series objects
12 |
13 | % Axis labels
14 | xlabel('Average node connectivity');
15 | ylabel('Consensus error');
16 |
17 | % Add legend
18 | legend('LoRA-DFL', 'ConLoRA');
19 |
20 | % Show grid
21 | grid on;
22 |
23 | % Display the value on top of each bar
24 | % Get the x coordinates and heights of the bars
25 | for i = 1:length(b)
26 | xData = b(i).XEndPoints; % x coordinates of the bars in this group
27 | yData = b(i).YData; % heights of the bars
28 | for j = 1:length(xData)
29 | text(xData(j), yData(j), sprintf('%.2f', yData(j)), 'HorizontalAlignment', 'center', 'VerticalAlignment', 'bottom');
30 | end
31 | end
32 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.32.1
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | async-timeout==4.0.3
5 | attrs==23.2.0
6 | boltons
7 | brotlipy==0.7.0
8 | certifi
9 | cffi
10 | charset-normalizer
11 | cmake==3.25.0
12 | contourpy==1.1.0
13 | cryptography
14 | cycler==0.11.0
15 | datasets==2.20.0
16 | dill==0.3.8
17 | filelock==3.15.4
18 | fonttools==4.40.0
19 | frozenlist==1.4.1
20 | fsspec==2024.5.0
21 | huggingface-hub==0.23.4
22 | idna
23 | Jinja2==3.1.2
24 | joblib==1.4.2
25 | jsonpatch
26 | jsonpointer==2.1
27 | kiwisolver==1.4.4
28 | lit==15.0.7
29 | MarkupSafe==2.1.2
30 | matplotlib==3.7.1
31 | mpmath==1.2.1
32 | multidict==6.0.5
33 | multiprocess==0.70.16
34 | networkx==3.0
35 | numpy==1.24.1
36 | packaging
37 | pandas==2.0.2
38 | peft==0.11.1
39 | Pillow==9.3.0
40 | pluggy
41 | psutil==6.0.0
42 | pyarrow==16.1.0
43 | pyarrow-hotfix==0.6
44 | pycosat
45 | pycparser
46 | pyOpenSSL
47 | pyparsing==3.1.0
48 | PySocks
49 | python-dateutil==2.8.2
50 | pytz==2023.3
51 | PyYAML==6.0.1
52 | regex==2024.5.15
53 | requests==2.32.3
54 | ruamel.yaml
55 | ruamel.yaml.clib
56 | safetensors==0.4.3
57 | scikit-learn==1.5.1
58 | scipy==1.14.0
59 | seaborn==0.12.2
60 | six
61 | sympy==1.11.1
62 | threadpoolctl==3.5.0
63 | tokenizers==0.19.1
64 | toolz
65 | torch==2.0.1
66 | torchaudio==2.0.2
67 | torchsummary==1.5.1
68 | torchvision==0.15.2
69 | tqdm==4.66.4
70 | transformers==4.42.3
71 | triton==2.0.0
72 | typing_extensions==4.4.0
73 | tzdata==2023.3
74 | urllib3
75 | xxhash==3.4.1
76 | yarl==1.9.4
77 | zstandard
78 |
--------------------------------------------------------------------------------
/result_processing/heatmap.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import numpy as np
4 |
5 | # Example data: Replace with actual data
6 | ConLoRA_result = np.array([
7 | [0, 0, 0, 0],
8 | [0, 0, 0, 0],
9 | [0, 0, 0, 0],
10 | [0, 0, 0, 0]
11 | ])
12 |
13 | LoRA_result = np.array([
14 | [0, 0, 0, 0],
15 | [0, 0, 0, 0],
16 | [0, 0, 0, 0],
17 | [0, 0, 0, 0]
18 | ])
19 |
20 | # Calculate the min and max values for the common color scale
21 | min_value = min(ConLoRA_result.min(), LoRA_result.min())
22 | max_value = max(ConLoRA_result.max(), LoRA_result.max())
23 |
24 | # Create the figure
25 | plt.figure(figsize=(8, 12))
26 |
27 | # Plot the first heatmap
28 | plt.subplot(2, 1, 1)
29 | ax1 = sns.heatmap(
30 | ConLoRA_result,
31 | annot=True,
32 | fmt=".2f",
33 | cmap="Blues",
34 | cbar_kws={'label': 'Accuracy'},
35 | vmin=min_value,
36 | vmax=max_value
37 | )
38 | ax1.set_xlabel('α (Dirichlet Parameters)')
39 | ax1.set_ylabel('Average Connectivity')
40 | ax1.set_xticklabels(['0.1', '0.15', '0.2', '0.25'])
41 | ax1.set_yticklabels(['3', '4', '5', '6'])
42 |
43 | # Plot the second heatmap
44 | plt.subplot(2, 1, 2)
45 | ax2 = sns.heatmap(
46 | LoRA_result,
47 | annot=True,
48 | fmt=".2f",
49 | cmap="Blues",
50 | cbar_kws={'label': 'Accuracy'},
51 | vmin=min_value,
52 | vmax=max_value
53 | )
54 | ax2.set_xlabel('α (Dirichlet Parameters)')
55 | ax2.set_ylabel('Average Connectivity')
56 | ax2.set_xticklabels(['0.1', '0.15', '0.2', '0.25'])
57 | ax2.set_yticklabels(['3', '4', '5', '6'])
58 |
59 | # Adjust layout to prevent overlap
60 | plt.tight_layout()
61 |
62 | # Save the plot as a PNG file (save before calling show, otherwise the saved figure may be blank)
63 | plt.savefig("123.png")
64 |
65 | # Display the plot
66 | plt.show()
67 |
--------------------------------------------------------------------------------
/finetune/bert/client.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # ========================================================================
4 | # Set default parameters
5 | # ========================================================================
6 |
7 | # Path to the pre-trained model
8 | MODEL_CHECKPOINT="/home/ubuntu/smyin/models/distilbert-base-uncased"
9 |
10 | # Path to the training dataset
11 | DATASET_PATH="/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_005/client_1"
12 |
13 | # Path to the validation dataset
14 | VAL_DATASET_PATH="/home/ubuntu/smyin/dataset/glue/sst2"
15 |
16 | # Rank of LoRA layers
17 | LORA_R=4
18 |
19 | # Alpha value for LoRA
20 | LORA_ALPHA=32
21 |
22 | # Training type: "LoRA" or "ConLoRA"
23 | TRAINING_TYPE="ConLoRA"
24 |
25 | # List of target modules for LoRA
26 | TARGET_MODULES="q_lin v_lin pre_classifier classifier"
27 |
28 | # Device selection: "cuda" or "cpu"
29 | DEVICE="cuda"
30 |
31 | # Number of epochs for training
32 | NUM_EPOCHS=10
33 |
34 | # Dataset type (e.g., "sst2", "mnli", "qnli")
35 | DATASET_TYPE="sst2"
36 |
37 | # Batch size for training
38 | BATCH_SIZE=64
39 |
40 | # ========================================================================
41 | # Validate input parameters
42 | # ========================================================================
43 |
44 | # Check if the model checkpoint path exists
45 | if [ ! -d "$MODEL_CHECKPOINT" ]; then
46 | echo "Error: Model path $MODEL_CHECKPOINT does not exist!"
47 | exit 1
48 | fi
49 |
50 | # Check if the training dataset path exists
51 | if [ ! -d "$DATASET_PATH" ]; then
52 | echo "Error: Dataset path $DATASET_PATH does not exist!"
53 | exit 1
54 | fi
55 |
56 | # Check if the validation dataset path exists
57 | if [ ! -d "$VAL_DATASET_PATH" ]; then
58 | echo "Error: Validation dataset path $VAL_DATASET_PATH does not exist!"
59 | exit 1
60 | fi
61 |
62 | # ========================================================================
63 | # Run the training script
64 | # ========================================================================
65 |
66 | echo "Training is starting..."
67 |
68 | python3 client.py \
69 | --model_checkpoint "$MODEL_CHECKPOINT" \
70 | --dataset_path "$DATASET_PATH" \
71 | --val_dataset_path "$VAL_DATASET_PATH" \
72 | --lora_r "$LORA_R" \
73 | --lora_alpha "$LORA_ALPHA" \
74 | --training_type "$TRAINING_TYPE" \
75 | --target_modules "$TARGET_MODULES" \
76 | --device "$DEVICE" \
77 | --num_epochs "$NUM_EPOCHS" \
78 | --dataset_type "$DATASET_TYPE" \
79 | --batch_size "$BATCH_SIZE"
80 |
81 | echo "Training has completed!"
82 |
--------------------------------------------------------------------------------
/result_processing/topology.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | import matplotlib.pyplot as plt
4 |
5 | # Define adjacency matrices for the topologies
6 | topologies = {
7 | "T1": np.array([
8 | [0, 1, 0, 0, 0, 0, 1],
9 | [1, 0, 1, 0, 0, 0, 0],
10 | [0, 1, 0, 1, 0, 0, 0],
11 | [0, 0, 1, 0, 1, 0, 0],
12 | [0, 0, 0, 1, 0, 1, 0],
13 | [0, 0, 0, 0, 1, 0, 1],
14 | [1, 0, 0, 0, 0, 1, 0]
15 | ]),
16 | "T2": np.array([
17 | [0, 1, 1, 0, 0, 0, 1],
18 | [1, 0, 1, 1, 0, 0, 0],
19 | [1, 1, 0, 1, 0, 0, 0],
20 | [0, 1, 1, 0, 1, 0, 0],
21 | [0, 0, 0, 1, 0, 1, 1],
22 | [0, 0, 0, 0, 1, 0, 1],
23 | [1, 0, 0, 0, 1, 1, 0]
24 | ]),
25 | "T3": np.array([
26 | [0, 1, 1, 0, 0, 1, 1],
27 | [1, 0, 1, 1, 0, 0, 1],
28 | [1, 1, 0, 1, 1, 0, 0],
29 | [0, 1, 1, 0, 1, 1, 0],
30 | [0, 0, 1, 1, 0, 1, 1],
31 | [1, 0, 0, 1, 1, 0, 1],
32 | [1, 1, 0, 0, 1, 1, 0]
33 | ]),
34 | "T4": np.array([
35 | [0, 1, 1, 1, 0, 1, 1],
36 | [1, 0, 1, 1, 1, 1, 0],
37 | [1, 1, 0, 1, 1, 1, 0],
38 | [1, 1, 1, 0, 1, 1, 0],
39 | [0, 1, 1, 1, 0, 1, 1],
40 | [1, 1, 1, 1, 1, 0, 1],
41 | [1, 0, 0, 0, 1, 1, 0]
42 | ])
43 | }
44 |
45 | # Plot the topologies
46 | # First, create the base ring structure from T1
47 | num_nodes = topologies["T1"].shape[0]
48 | base_graph = nx.Graph()
49 | for i in range(num_nodes):
50 | base_graph.add_edge(i, (i + 1) % num_nodes)
51 |
52 | # Add additional edges based on T1 adjacency matrix
53 | for i in range(num_nodes):
54 | for j in range(num_nodes):
55 | if topologies["T1"][i, j] == 1 and not base_graph.has_edge(i, j):
56 | base_graph.add_edge(i, j)
57 |
58 | # Define a fixed layout for consistent node positioning
59 | fixed_pos = nx.spring_layout(base_graph, seed=42)
60 |
61 | # Draw the base graph T1
62 | plt.figure(figsize=(8, 6))
63 | nx.draw(base_graph, pos=fixed_pos, with_labels=True, node_color='skyblue', node_size=3000, font_size=30, font_weight='bold', edge_color='gray', width=6)
64 | plt.savefig("T1.png", format='png', dpi=300, bbox_inches='tight')
65 | plt.close()
66 |
67 | # Plot the other topologies based on T1
68 | for name, adjacency_matrix in topologies.items():
69 | if name == "T1":
70 | continue
71 |
72 | G = base_graph.copy()
73 |
74 | # Add additional edges based on the adjacency matrix
75 | for i in range(num_nodes):
76 | for j in range(num_nodes):
77 | if adjacency_matrix[i, j] == 1 and not G.has_edge(i, j):
78 | G.add_edge(i, j)
79 |
80 | # Draw the graph
81 | plt.figure(figsize=(8, 6))
82 | nx.draw(G, pos=fixed_pos, with_labels=True, node_color='skyblue', node_size=3000, font_size=30, font_weight='bold', edge_color='gray', width=6)
83 | plt.savefig(f"{name}.png", format='png', dpi=300, bbox_inches='tight')
84 | plt.close()
85 |
86 |
87 |
--------------------------------------------------------------------------------
/result_processing/ConfidenceInterval.m:
--------------------------------------------------------------------------------
1 | function [mean_accuracies, confidence_intervals] = calculate_confidence_intervals(accuracy_data, confidence_level)
2 | if nargin < 2
3 | confidence_level = 0.95; % default confidence level
4 | end
5 |
6 | [epochs, clients] = size(accuracy_data); % rows are epochs, columns are clients
7 | mean_accuracies = zeros(1, clients); % mean of each column
8 | confidence_intervals = zeros(clients, 2); % confidence interval of each column
9 |
10 | for client = 1:clients
11 | accuracies = accuracy_data(:, client); % data of this client (column)
12 | mean_accuracy = mean(accuracies); % column mean
13 | standard_error = std(accuracies) / sqrt(epochs); % standard error
14 | t_value = tinv(1 - (1 - confidence_level) / 2, epochs - 1); % t-distribution critical value
15 |
16 | % Compute the confidence interval
17 | margin_of_error = t_value * standard_error;
18 | confidence_interval = [mean_accuracy - margin_of_error, mean_accuracy + margin_of_error];
19 |
20 | mean_accuracies(client) = mean_accuracy;
21 | confidence_intervals(client, :) = confidence_interval;
22 | end
23 | end
24 |
25 | % Enter your own result data here
26 | acc_data_B1=[];
27 | acc_data_B2=[];
28 | acc_data_B3=[];
29 |
30 | acc_data_BA1=[];
31 | acc_data_BA2=[];
32 | acc_data_BA3=[];
33 | [mean_accuracies1, confidence_intervals1] = calculate_confidence_intervals(acc_data_B2);
34 | [mean_accuracies2, confidence_intervals2] = calculate_confidence_intervals(acc_data_BA2);
35 |
36 | function plot_accuracies_with_confidence_intervals(mean_accuracies_1, confidence_intervals_1,mean_accuracies_2, confidence_intervals_2)
37 | epochs = 1:length(mean_accuracies_1);
38 |
39 | % Extract the lower and upper bounds of the confidence intervals
40 | lower_bounds_1 = confidence_intervals_1(:, 1);
41 | upper_bounds_1 = confidence_intervals_1(:, 2);
42 |
43 | lower_bounds_2 = confidence_intervals_2(:, 1);
44 | upper_bounds_2 = confidence_intervals_2(:, 2);
45 |
46 | figure('Position', [100, 100, 750, 500]); % set the figure size
47 |
48 | % Plot the ConLoRA mean and confidence interval
49 | plot(epochs, mean_accuracies_1, 'b', 'DisplayName', 'ConLoRA', 'LineWidth', 1.5);
50 | hold on;
51 | fill([epochs, fliplr(epochs)], [upper_bounds_1', fliplr(lower_bounds_1')], 'b', 'FaceAlpha', 0.2, 'EdgeColor', 'none','DisplayName','95% CI ConLoRA');
52 |
53 | % Plot the LoRA-DFL mean and confidence interval
54 | plot(epochs, mean_accuracies_2, 'r', 'DisplayName', 'LoRA-DFL', 'LineWidth', 1.5);
55 | fill([epochs, fliplr(epochs)], [upper_bounds_2', fliplr(lower_bounds_2')], 'r', 'FaceAlpha', 0.2, 'EdgeColor', 'none','DisplayName','95% CI LoRA-DFL');
56 |
57 | % Labels, title, and legend
58 | xlabel('Epoch', 'FontSize', 12);
59 | ylabel('Accuracy', 'FontSize', 12);
60 | %title('Mean Accuracy with 95% Confidence Intervals for DLora & Lora', 'FontSize', 14);
61 | % Legend with a custom location
62 | legend('show', 'Location', 'southeast');
63 |
64 | grid on;
65 | xlim([1 260]);
66 | ylim([0.4 0.95])
67 |
68 | end
69 |
70 |
71 | plot_accuracies_with_confidence_intervals(mean_accuracies1, confidence_intervals1,mean_accuracies2, confidence_intervals2)
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DFL-LoRA
2 |
3 | DFL-LoRA is a platform for distributed fine-tuning of large models with LoRA in a fully decentralized federated learning network. In particular, it addresses an inherent error issue of decentralized federated learning by providing an option to mitigate the consensus-error amplification effect caused by LoRA: the LoRA A matrices can be frozen during training to further reduce this error.
4 |
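Freezing the A matrices of a `peft`-wrapped model can be sketched as follows (a minimal illustration, assuming the standard `peft` parameter naming in which the LoRA factors appear as `lora_A`/`lora_B`; in this repository the corresponding behavior is selected through the `--training_type` option of the training scripts):

```python
def freeze_lora_A(model):
    """Freeze every LoRA A matrix so that only the B matrices
    (and any other trainable modules) keep receiving gradients."""
    for name, param in model.named_parameters():
        if "lora_A" in name:
            param.requires_grad = False
```
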
5 | ## How to Use
6 |
7 | ### 1. Download
8 |
9 | Clone the repository to your local folder:
10 |
11 | ```bash
12 | git clone https://github.com/LuYuPromax/DFL-LoRA.git
13 | ```
14 |
15 | ### 2. Create environment
16 |
17 | Create a new Conda environment with Python 3.10.10:
18 |
19 | ```
20 | conda create -n env_name python==3.10.10
21 | ```
22 |
23 | Then install the required libraries:
24 |
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 |
29 | ### 3. Datasets
30 |
31 | Currently, DFL-LoRA supports several GLUE tasks and the GSM8K dataset. To use them, follow these steps:
32 |
33 | 1. Go to the `utils` folder.
34 |
35 | 2. Modify the paths in `glue_split.sh` and `gsm8k_split.sh` to match your local dataset storage.
36 |
37 | 3. Run the following commands to split the datasets:
38 |
39 | ##### GLUE
40 |
41 | ```
42 | chmod +x glue_split.sh
43 | ./glue_split.sh
44 | ```
45 |
46 | ##### GSM8K
47 |
48 | ```
49 | chmod +x gsm8k_split.sh
50 | ./gsm8k_split.sh
51 | ```
52 |
53 | To view the data distribution across clients, run the `view_split.sh` script:
54 |
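For example, after adjusting `BASE_PATH`, `NUM_CLIENTS`, and `DATASET_TYPE` in the script to match your setup:

```
chmod +x view_split.sh
./view_split.sh
```
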
55 | ### 4. Finetune
56 |
57 | To start the federated training process for BERT, go to the `finetune/bert` folder and run:
58 |
59 | ```
60 | chmod +x train.sh
61 | ./train.sh
62 | ```
63 |
64 | ##### Parameter Description
65 |
66 | - **`--model_checkpoint` (str)**: Path to the pre-trained model checkpoint.
67 | - **`--dataset_path_template` (str)**: Template for client dataset paths (`{i}` is replaced by client number).
68 | - **`--val_dataset_path_template` (str)**: Template for client validation dataset paths.
69 | - **`--num_clients` (int)**: Number of clients participating in federated learning.
70 | - **`--lora_r` (int)**: LoRA rank for LoRA layers.
71 | - **`--lora_alpha` (int)**: LoRA alpha for LoRA layers.
72 | - **`--target_modules` (str)**: Comma-separated list of target modules for LoRA layers.
73 | - **`--training_type` (str)**: Type of training (`LoRA` or `ConLoRA`).
74 | - **`--dataset_type` (str)**: Dataset type (`sst2`, `mnli`, or `qnli`).
75 | - **`--name` (str)**: Name used to generate the weight matrix.
76 | - **`--num_rounds` (int, default: 256)**: Number of federated learning rounds.
77 | - **`--batch_size` (int, default: 128)**: Batch size for each client.
78 | - **`--log_path` (str, default: "federated_training.log")**: Path to save the log file.
79 |
81 | Training the LLaMA model with the scripts in `finetune/llama` follows a similar procedure.
81 |
82 | ### Notes
83 | - Ensure that your datasets are correctly placed and paths are properly configured in the scripts.
84 | - You can modify the default values of parameters according to your experiment setup.
85 |
--------------------------------------------------------------------------------
/utils/generate_weight_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def generate_weight_matrix(name):
4 | """
5 | Generates a weight matrix based on the provided name.
6 |
7 | Args:
8 | name (str): The name for selecting the weight matrix configuration.
9 |
10 | Returns:
11 | np.array: The weight matrix for aggregation.
12 | """
13 | if name == "du14":
14 | A = np.array([
15 | [0, 1, 1, 0, 0, 0, 0],
16 | [1, 0, 1, 0, 0, 0, 0],
17 | [1, 1, 0, 1, 0, 0, 0],
18 | [0, 0, 1, 0, 1, 0, 0],
19 | [0, 0, 0, 1, 0, 1, 0],
20 | [0, 0, 0, 0, 1, 0, 1],
21 | [0, 0, 0, 0, 0, 1, 0]
22 | ])
23 | elif name == "link3":
24 | A = np.array([
25 | [0, 1, 0, 0, 0, 0, 1],
26 | [1, 0, 1, 0, 0, 0, 0],
27 | [0, 1, 0, 1, 0, 0, 0],
28 | [0, 0, 1, 0, 1, 0, 0],
29 | [0, 0, 0, 1, 0, 1, 0],
30 | [0, 0, 0, 0, 1, 0, 1],
31 | [1, 0, 0, 0, 0, 1, 0]
32 | ])
33 | elif name == "link4":
34 | A = np.array([
35 | [0, 1, 1, 0, 0, 0, 1],
36 | [1, 0, 1, 1, 0, 0, 0],
37 | [1, 1, 0, 1, 0, 0, 0],
38 | [0, 1, 1, 0, 1, 0, 0],
39 | [0, 0, 0, 1, 0, 1, 1],
40 | [0, 0, 0, 0, 1, 0, 1],
41 | [1, 0, 0, 0, 1, 1, 0]
42 | ])
43 | elif name == "link5":
44 | A = np.array([
45 | [0, 1, 1, 0, 0, 1, 1],
46 | [1, 0, 1, 1, 0, 0, 1],
47 | [1, 1, 0, 1, 1, 0, 0],
48 | [0, 1, 1, 0, 1, 1, 0],
49 | [0, 0, 1, 1, 0, 1, 1],
50 | [1, 0, 0, 1, 1, 0, 1],
51 | [1, 1, 0, 0, 1, 1, 0]
52 | ])
53 | elif name == "link6":
54 | A = np.array([
55 | [0, 1, 1, 1, 0, 1, 1],
56 | [1, 0, 1, 1, 1, 1, 0],
57 | [1, 1, 0, 1, 1, 1, 0],
58 | [1, 1, 1, 0, 1, 1, 0],
59 | [0, 1, 1, 1, 0, 1, 1],
60 | [1, 1, 1, 1, 1, 0, 1],
61 | [1, 0, 0, 0, 1, 1, 0]
62 | ])
63 | elif name == "link7":
64 | A = np.array([
65 | [0, 1, 1, 1, 1, 1, 1],
66 | [1, 0, 1, 1, 1, 1, 1],
67 | [1, 1, 0, 1, 1, 1, 1],
68 | [1, 1, 1, 0, 1, 1, 1],
69 | [1, 1, 1, 1, 0, 1, 1],
70 | [1, 1, 1, 1, 1, 0, 1],
71 | [1, 1, 1, 1, 1, 1, 0]
72 | ])
73 | else:
74 | raise ValueError(f"Unknown matrix name: {name}")
75 |
76 | # Calculate degree of each node and generate weight matrix
77 | degree = np.sum(A, axis=1)
78 |
79 | # Initialize weight matrix W
80 | W = np.zeros_like(A, dtype=float)
81 |
82 | # Calculate weight matrix
83 | for i in range(A.shape[0]):
84 | for j in range(A.shape[1]):
85 | if A[i, j] == 1:
86 | W[i, j] = 1 / (max(degree[i], degree[j]) + 1)
87 | elif i != j:
88 | W[i, j] = 0
89 |
90 | # Adjust diagonal elements to ensure each row sums to 1
91 | for i in range(A.shape[0]):
92 | W[i, i] = 1 - np.sum(W[i, np.arange(A.shape[0]) != i])
93 |
94 | return W
95 |
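96 | # Illustrative note: generate_weight_matrix("link3"), for instance, returns Metropolis
97 | # weights for the 7-node ring, i.e. W[i, j] = 1 / (1 + max(deg_i, deg_j)) on each edge,
98 | # with the diagonal filled so that every row sums to 1. Since W is symmetric it is also
99 | # doubly stochastic, so on a connected topology repeated neighbor averaging with W
100 | # converges to the network-wide average.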
--------------------------------------------------------------------------------
/utils/gsm8k_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_dataset
3 | import numpy as np
4 | import os
5 |
6 | class DataSplitter:
7 | def __init__(self, dataset_path, num_clients, alpha=None):
8 | """
9 | Initializes the DataSplitter class.
10 |
11 | Args:
12 | dataset_path (str): Path to the dataset.
13 | num_clients (int): Number of clients.
14 | alpha (float): Dirichlet distribution parameter for non-IID split.
15 | """
16 | self.dataset = load_dataset(dataset_path, "main")
17 | self.train_dataset = self.dataset['train']
18 | self.num_clients = num_clients
19 | self.alpha = alpha
20 | self.clients_datasets = []
21 |
22 | def iid_split(self):
23 | """Splits the dataset into IID partitions across clients."""
24 | shuffled_dataset = self.train_dataset.shuffle(seed=42)
25 | client_size = len(shuffled_dataset) // self.num_clients
26 | self.clients_datasets = []
27 |
28 | for i in range(self.num_clients):
29 | start_idx = i * client_size
30 | end_idx = (i + 1) * client_size if i != self.num_clients - 1 else len(shuffled_dataset)
31 | client_dataset = shuffled_dataset.select(range(start_idx, end_idx))
32 | self.clients_datasets.append(client_dataset)
33 |
34 | def dirichlet_split(self):
35 | """Splits the dataset into non-IID partitions using Dirichlet distribution."""
36 | total_size = len(self.train_dataset)
37 | proportions = np.random.dirichlet([self.alpha] * self.num_clients)
38 | client_sizes = (proportions * total_size).astype(int)
39 | client_sizes[-1] = total_size - client_sizes[:-1].sum()
40 |
41 | indices = np.random.permutation(total_size)
42 | current_index = 0
43 | self.clients_datasets = []
44 |
45 | for size in client_sizes:
46 | client_indices = indices[current_index: current_index + size]
47 | self.clients_datasets.append(self.train_dataset.select(client_indices.tolist()))
48 | current_index += size
49 |
50 | def compute_data_distribution(self):
51 | """Computes the distribution of data among clients."""
52 | return [len(client_dataset) for client_dataset in self.clients_datasets]
53 |
54 | def save_datasets(self, base_path):
55 | """Saves client datasets to disk."""
56 | os.makedirs(base_path, exist_ok=True)
57 | for i, client_dataset in enumerate(self.clients_datasets):
58 | client_dataset.save_to_disk(os.path.join(base_path, f"client_{i+1}"))
59 |
60 |
61 | def main():
62 | parser = argparse.ArgumentParser(description="Dataset splitter for federated learning.")
63 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.")
64 | parser.add_argument('--output_path', type=str, required=True, help="Output path for saving client datasets.")
65 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.")
66 | parser.add_argument('--split_type', type=str, choices=['iid', 'dirichlet'], required=True, help="Type of split: 'iid' or 'dirichlet'.")
67 | parser.add_argument('--alpha', type=float, default=None, help="Alpha value for Dirichlet distribution.")
68 |
69 | args = parser.parse_args()
70 |
71 | # Initialize DataSplitter
72 | splitter = DataSplitter(args.dataset_path, args.num_clients, args.alpha)
73 |
74 | # Perform split
75 | if args.split_type == 'iid':
76 | splitter.iid_split()
77 | elif args.split_type == 'dirichlet':
78 | if args.alpha is None:
79 | parser.error("--alpha must be specified for Dirichlet split.")
80 | splitter.dirichlet_split()
81 |
82 | # Save datasets
83 | splitter.save_datasets(args.output_path)
84 |
85 | # Print data distribution
86 | distribution = splitter.compute_data_distribution()
87 | print("Client data distribution:", distribution)
88 |
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
--------------------------------------------------------------------------------
/finetune/bert/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # ---------------------------------------------------
4 | # Train Script for Federated Learning with LoRA
5 | # ---------------------------------------------------
6 |
7 | # Default parameters (can be overwritten by command-line arguments)
8 | MODEL_CHECKPOINT="/home/ubuntu/smyin/models/distilbert-base-uncased"
9 | DATASET_PATH_TEMPLATE="/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_020/client_{}"
10 | VAL_DATASET_PATH_TEMPLATE="/home/ubuntu/smyin/dataset/glue/sst2"
11 | NUM_CLIENTS=7
12 | LORA_R=4
13 | LORA_ALPHA=32
14 | TARGET_MODULES="q_lin,v_lin,pre_classifier,classifier" # Comma-separated list of target modules for LoRA layers
15 | TRAINING_TYPE="ConLoRA" # Options: LoRA, ConLoRA
16 | DATASET_TYPE="sst2" # Options: sst2, mnli, qnli
17 | NAME="link3" # The name used to generate the weight matrix
18 | NUM_ROUNDS=5 # Default number of federated learning rounds
19 | BATCH_SIZE=128 # Default batch size for training
20 | LOG_PATH="federated_training.log" # Default path for logging
21 |
22 | # ---------------------------------------------------
23 | # Parsing command-line arguments to allow overrides
24 | # ---------------------------------------------------
25 | while [[ $# -gt 0 ]]; do
26 | case $1 in
27 | --model_checkpoint)
28 | MODEL_CHECKPOINT="$2"
29 | shift 2
30 | ;;
31 | --dataset_path_template)
32 | DATASET_PATH_TEMPLATE="$2"
33 | shift 2
34 | ;;
35 | --val_dataset_path_template)
36 | VAL_DATASET_PATH_TEMPLATE="$2"
37 | shift 2
38 | ;;
39 | --num_clients)
40 | NUM_CLIENTS="$2"
41 | shift 2
42 | ;;
43 | --lora_r)
44 | LORA_R="$2"
45 | shift 2
46 | ;;
47 | --lora_alpha)
48 | LORA_ALPHA="$2"
49 | shift 2
50 | ;;
51 | --target_modules)
52 | TARGET_MODULES="$2"
53 | shift 2
54 | ;;
55 | --training_type)
56 | TRAINING_TYPE="$2"
57 | shift 2
58 | ;;
59 | --dataset_type)
60 | DATASET_TYPE="$2"
61 | shift 2
62 | ;;
63 | --name)
64 | NAME="$2"
65 | shift 2
66 | ;;
67 | --num_rounds)
68 | NUM_ROUNDS="$2"
69 | shift 2
70 | ;;
71 | --batch_size)
72 | BATCH_SIZE="$2"
73 | shift 2
74 | ;;
75 | --log_path)
76 | LOG_PATH="$2"
77 | shift 2
78 | ;;
79 | *)
80 | echo "Unknown argument: $1"
81 | exit 1
82 | ;;
83 | esac
84 | done
85 |
86 | # ---------------------------------------------------
87 | # Print configuration
88 | # ---------------------------------------------------
89 | echo "Training Configuration:"
90 | echo "------------------------"
91 | echo "Model checkpoint: $MODEL_CHECKPOINT"
92 | echo "Dataset path template: $DATASET_PATH_TEMPLATE"
93 | echo "Validation dataset path: $VAL_DATASET_PATH_TEMPLATE"
94 | echo "Number of clients: $NUM_CLIENTS"
95 | echo "LoRA Rank: $LORA_R"
96 | echo "LoRA Alpha: $LORA_ALPHA"
97 | echo "Target modules: $TARGET_MODULES"
98 | echo "Training type: $TRAINING_TYPE"
99 | echo "Dataset type: $DATASET_TYPE"
100 | echo "Weight matrix name: $NAME"
101 | echo "Number of rounds: $NUM_ROUNDS"
102 | echo "Batch size: $BATCH_SIZE"
103 | echo "Log path: $LOG_PATH"
104 | echo "------------------------"
105 |
106 | # ---------------------------------------------------
107 | # Run the training script with the specified parameters
108 | # ---------------------------------------------------
109 | python train.py \
110 | --model_checkpoint "$MODEL_CHECKPOINT" \
111 | --dataset_path_template "$DATASET_PATH_TEMPLATE" \
112 | --val_dataset_path_template "$VAL_DATASET_PATH_TEMPLATE" \
113 | --num_clients "$NUM_CLIENTS" \
114 | --lora_r "$LORA_R" \
115 | --lora_alpha "$LORA_ALPHA" \
116 | --target_modules "$TARGET_MODULES" \
117 | --training_type "$TRAINING_TYPE" \
118 | --dataset_type "$DATASET_TYPE" \
119 | --name "$NAME" \
120 | --num_rounds "$NUM_ROUNDS" \
121 | --batch_size "$BATCH_SIZE" \
122 | --log_path "$LOG_PATH"
123 |
124 |
--------------------------------------------------------------------------------
/utils/view_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datasets import load_from_disk
3 | from collections import Counter
4 | import os
5 |
6 | class DataLoader:
7 | def __init__(self, base_path, num_clients, dataset_type):
8 | """
9 | Initializes the DataLoader class.
10 |
11 | Args:
12 | base_path (str): Path where the client datasets are stored.
13 | num_clients (int): Number of clients.
14 | dataset_type (str): Type of dataset ('glue' or 'gsm8k').
15 | """
16 | self.base_path = base_path
17 | self.num_clients = num_clients
18 | self.dataset_type = dataset_type.lower()
19 | self.clients_datasets = self.load_datasets()
20 |
21 | def load_datasets(self):
22 | """
23 | Loads the client datasets from disk.
24 |
25 | Returns:
26 | list: A list of datasets for each client.
27 | """
28 | clients_datasets = []
29 | for i in range(self.num_clients):
30 | client_path = os.path.join(self.base_path, f"client_{i+1}")
31 | if os.path.exists(client_path):
32 | try:
33 | client_dataset = load_from_disk(client_path)
34 | clients_datasets.append(client_dataset)
35 | except Exception as e:
36 | print(f"Error loading dataset for Client {i+1}: {e}")
37 | clients_datasets.append(None)
38 | else:
39 | print(f"Warning: Client {i+1} dataset not found at {client_path}.")
40 | clients_datasets.append(None)
41 | return clients_datasets
42 |
43 | def compute_label_distribution(self, dataset):
44 | """
45 | Computes the label distribution for the given dataset.
46 |
47 | Args:
48 | dataset (Dataset): A Hugging Face dataset.
49 |
50 | Returns:
51 | Counter: A counter of label occurrences.
52 | """
53 | try:
54 | labels = [example['label'] for example in dataset]
55 | return Counter(labels)
56 | except KeyError:
57 | raise ValueError("The dataset does not contain a 'label' field.")
58 |
59 | def view_data_distribution(self):
60 | """
61 | Views data distribution based on dataset type ('glue' or 'gsm8k').
62 | """
63 | for i in range(self.num_clients):
64 | if self.clients_datasets[i]:
65 | if self.dataset_type == 'glue':
66 | try:
67 | print(f"Label distribution for Client {i+1}:")
68 | dist = self.compute_label_distribution(self.clients_datasets[i])
69 | print(dist)
70 | except ValueError as e:
71 | print(f"Error processing Client {i+1}: {e}")
72 | elif self.dataset_type == 'gsm8k':
73 | print(f"Client {i+1} has {len(self.clients_datasets[i])} examples.")
74 | else:
75 | print(f"No dataset found for Client {i+1}.")
76 |
77 | def main():
78 | parser = argparse.ArgumentParser(description="Load and view the distribution of split datasets.")
79 | parser.add_argument('--base_path', type=str, required=True, help="Path to the directory containing client datasets.")
80 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.")
81 | parser.add_argument('--dataset_type', type=str, choices=['glue', 'gsm8k'], required=True, help="Type of dataset ('glue' or 'gsm8k').")
82 |
83 | args = parser.parse_args()
84 |
85 | # Validate base path
86 | if not os.path.exists(args.base_path):
87 | raise FileNotFoundError(f"The specified base path does not exist: {args.base_path}")
88 |
89 | # Validate number of clients
90 | available_clients = len([d for d in os.listdir(args.base_path) if os.path.isdir(os.path.join(args.base_path, d))])
91 | if args.num_clients > available_clients:
92 | raise ValueError(f"The specified number of clients ({args.num_clients}) exceeds the available clients ({available_clients}).")
93 |
94 | # Load and view data distribution
95 | loader = DataLoader(args.base_path, args.num_clients, args.dataset_type)
96 | loader.view_data_distribution()
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/finetune/llama/gsm8k_eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import re
4 | import os
5 | os.environ['NCCL_P2P_DISABLE'] = '1'
6 | os.environ['NCCL_IB_DISABLE'] = '1'
7 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
8 | from dataclasses import dataclass, field
9 | from typing import Dict, Optional
10 |
11 | import torch
12 | import transformers
13 |
14 |
15 | from peft import PeftModel
16 | from datasets import load_dataset
17 | from accelerate.utils import set_seed
18 |
19 |
20 | IGNORE_INDEX = -100
21 | DEFAULT_PAD_TOKEN = "[PAD]"
22 | DEFAULT_EOS_TOKEN = "</s>"
23 | DEFAULT_BOS_TOKEN = "<s>"
24 | DEFAULT_UNK_TOKEN = "<unk>"
25 | ANSWER_PROMPT = "The final answer is: "
26 | QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n"
27 |
28 |
29 | def smart_tokenizer_and_embedding_resize(
30 | special_tokens_dict: Dict,
31 | tokenizer: transformers.PreTrainedTokenizer,
32 | model: transformers.PreTrainedModel,
33 | ):
34 | """Resize tokenizer and embedding.
35 |
36 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
37 | """
38 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
39 | model.resize_token_embeddings(len(tokenizer))
40 |
41 | if num_new_tokens > 0:
42 | input_embeddings = model.get_input_embeddings().weight.data
43 | output_embeddings = model.get_output_embeddings().weight.data
44 |
45 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
46 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
47 |
48 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
49 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
50 |
51 |
52 | def evaluation(batch_size):
53 | output_dir = "./llama_lora_gsm8k"
54 |
55 | Modelpath='/home/ubuntu/smyin/models/Llama-3.2-1B'
56 | model = transformers.AutoModelForCausalLM.from_pretrained(Modelpath)
57 | tokenizer = transformers.AutoTokenizer.from_pretrained(Modelpath)
58 |
59 |
60 | special_tokens_dict = dict()
61 | if tokenizer.pad_token is None:
62 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
63 | if tokenizer.eos_token is None:
64 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
65 | if tokenizer.bos_token is None:
66 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
67 | if tokenizer.unk_token is None:
68 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
69 |
70 | smart_tokenizer_and_embedding_resize(
71 | special_tokens_dict=special_tokens_dict,
72 | tokenizer=tokenizer,
73 | model=model,
74 | )
75 |
76 | model = model.to('cuda')
77 |
78 | dataset = load_dataset('/home/ubuntu/smyin/dataset/gsm8k', "main")
79 | test_set = dataset['test']
80 |
81 | question = [f"{example['question']}{QUESTION_PROMPT}" for example in test_set]
82 | answer = []
83 |
84 | # get numerical answer
85 | for example in test_set['answer']:
86 | ans = example.split('####')[-1]
87 | ans = ans.replace(',', '') # handle numbers like 2,000
88 | try:
89 | ans = float(ans)
90 | except ValueError:
91 | ans = float("inf")
92 | answer.append(ans)
93 |
94 | logging.warning("Tokenizing inputs...")
95 | eval_step = math.ceil(len(question)/batch_size)
96 | logging.warning(f"Total examples: {len(question)} | eval batch size: {batch_size} | "
97 | f"eval steps: {eval_step}")
98 | question_data = []
99 | for i in range(eval_step):
100 | if i < eval_step - 1:
101 | batch = tokenizer(
102 | question[i*batch_size: (i+1)*batch_size],
103 | return_tensors="pt",
104 | padding="longest",
105 | )
106 | else:
107 | batch = tokenizer(
108 | question[i*batch_size:],
109 | return_tensors="pt",
110 | padding="longest",
111 | )
112 | batch['input_len'] = len(batch['input_ids'][0])
113 | question_data.append(batch)
114 |
115 | model.eval()
116 | gen_kwargs = {
117 | "max_new_tokens": 256,
118 | "temperature": 0.1,
119 | "top_k": 40,
120 | "top_p": 0.95,
121 | "do_sample": True,
122 | }
123 | ans_pred_list = []
124 | set_seed(42)
125 | for step, batch in enumerate(question_data):
126 | with torch.no_grad():
127 | gen_kwargs["input_ids"] = batch["input_ids"].to('cuda')
128 | gen_kwargs["attention_mask"] = batch["attention_mask"].to('cuda')
129 | generated_tokens = model.generate(**gen_kwargs)
130 |
131 | pred_tokens = generated_tokens[:, batch['input_len']:]
132 | decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True)
133 |
134 | # Extract the numbers in sentences
135 | print(decoded_pred)
136 | ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred]
137 |
138 | print("prediction", ans_pred_list)
139 | print("ground truth", answer)
140 |
141 | accuracy = compute_accuracy(answer, ans_pred_list)
142 |
143 | print(f"GSM8K test accuracy: {100*accuracy:.2f}% | ")
144 |
145 |
146 | def extract_answer_number(sentence: str) -> float:
147 | sentence = sentence.replace(',', '')
148 | pred = [s for s in re.findall(r'-?\d+\.?\d*', sentence)]
149 | if not pred:
150 | return float('inf')
151 | segment = sentence.split(ANSWER_PROMPT)
152 | if len(segment) > 1:
153 | pred_answer = segment[1]
154 | pred_answer = [s for s in re.findall(r'-?\d+\.?\d*', pred_answer)]
155 | if len(pred_answer) > 0:
156 | pred_answer = pred_answer[0]
157 | else:
158 | pred_answer = float(pred[-1])
159 | else:
160 | # use the last number as the answer
161 | pred_answer = float(pred[-1])
162 |
163 | if isinstance(pred_answer, str):
164 | try:
165 | pred_answer = float(pred_answer)
166 | except ValueError as e:
167 | pred_answer = float('inf')
168 | return pred_answer
169 |
170 |
171 | def compute_accuracy(pred: list, gold: list):
172 | acc = 0.0
173 | for p, g in zip(pred, gold):
174 | if p == g:
175 | acc += 1
176 |
177 | return acc / len(pred)
178 |
179 |
180 | if __name__ == "__main__":
181 | batch_size=128
182 | evaluation(batch_size)
183 |
184 |
185 |
--------------------------------------------------------------------------------
/utils/glue_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from datasets import load_dataset
4 | from collections import Counter
5 | import os
6 |
7 | class DataSplitter:
8 | def __init__(self, dataset_path, num_clients, alpha=None, min_size=None):
9 | """
10 | Initializes the DataSplitter class.
11 |
12 | Args:
13 | dataset_path (str): Path to the dataset.
14 | num_clients (int): Number of clients.
15 | alpha (float): Dirichlet distribution parameter for non-IID split.
16 | min_size (int): Minimum size of data for a client (optional).
17 | """
18 | self.dataset = load_dataset(dataset_path)
19 | self.train_dataset = self.dataset['train']
20 | self.num_clients = num_clients
21 | self.alpha = alpha
22 | self.min_size = min_size
23 | self.clients_datasets = []
24 |
25 | def iid_split(self):
26 | """
27 | Splits the dataset into IID partitions across clients.
28 | """
29 | shuffled_dataset = self.train_dataset.shuffle(seed=42)
30 | client_size = len(shuffled_dataset) // self.num_clients
31 | self.clients_datasets = []
32 |
33 | for i in range(self.num_clients):
34 | start_idx = i * client_size
35 | end_idx = (i + 1) * client_size if i != self.num_clients - 1 else len(shuffled_dataset)
36 | client_dataset = shuffled_dataset.select(range(start_idx, end_idx))
37 | self.clients_datasets.append(client_dataset)
38 |
39 | def dirichlet_split(self):
40 | """
41 | Splits the dataset into non-IID partitions using Dirichlet distribution.
42 | """
43 | self.clients_datasets = [[] for _ in range(self.num_clients)]
44 | labels = np.array([example['label'] for example in self.train_dataset])
45 | num_classes = len(np.unique(labels))
46 |
47 | class_indices = [np.where(labels == i)[0] for i in range(num_classes)]
48 |
49 | for c in range(num_classes):
50 | class_size = len(class_indices[c])
51 | proportions = np.random.dirichlet(np.repeat(self.alpha, self.num_clients))
52 | client_sizes = (proportions * class_size).astype(int)
53 | client_sizes[-1] = class_size - np.sum(client_sizes[:-1])
54 |
55 | current_index = 0
56 | for i, size in enumerate(client_sizes):
57 | self.clients_datasets[i].extend(class_indices[c][current_index:current_index + size].tolist())
58 | current_index += size
59 |
60 | # Ensure each client has enough data
61 | for attempt in range(1000):
62 | total_sizes = [len(client_dataset) for client_dataset in self.clients_datasets]
63 | if all(size >= self.min_size for size in total_sizes):
64 | break
65 | if any(size < self.min_size for size in total_sizes):
66 | self.clients_datasets = [[] for _ in range(self.num_clients)]
67 | for c in range(num_classes):
68 | class_size = len(class_indices[c])
69 | proportions = np.random.dirichlet(np.repeat(self.alpha, self.num_clients))
70 | client_sizes = (proportions * class_size).astype(int)
71 | client_sizes[-1] = class_size - np.sum(client_sizes[:-1])
72 | current_index = 0
73 | for i, size in enumerate(client_sizes):
74 | self.clients_datasets[i].extend(class_indices[c][current_index:current_index + size].tolist())
75 | current_index += size
76 | else:
77 | raise ValueError(f"Unable to adjust client sizes after {attempt + 1} attempts.")
78 | else:
79 | raise ValueError("Unable to generate a valid split after 1000 attempts.")
80 |
81 | # Convert indices to actual data and shuffle
82 | for i in range(self.num_clients):
83 | if len(self.clients_datasets[i]) == 0:
84 | print(f"Client {i} dataset is empty!")
85 | else:
86 | np.random.shuffle(self.clients_datasets[i])
87 | self.clients_datasets[i] = self.train_dataset.select(self.clients_datasets[i])
88 |
89 | def compute_label_distribution(self, dataset):
90 | """
91 | Computes the label distribution for the given dataset.
92 | """
93 | labels = [example['label'] for example in dataset]
94 | return Counter(labels)
95 |
96 | def label_distribution(self):
97 | """
98 | Computes the label distribution for each client dataset.
99 | """
100 | self.pct = []
101 | for i in range(self.num_clients):
102 | self.pct.append(self.compute_label_distribution(self.clients_datasets[i]))
103 |
104 | def save_datasets(self, base_path):
105 | """
106 | Saves the client datasets to disk.
107 | """
108 | os.makedirs(base_path, exist_ok=True)
109 | for i, client_dataset in enumerate(self.clients_datasets):
110 | client_dataset.save_to_disk(os.path.join(base_path, f"client_{i+1}"))
111 |
112 | def main():
113 | parser = argparse.ArgumentParser(description="Dataset splitter for federated learning.")
114 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.")
115 | parser.add_argument('--output_path', type=str, required=True, help="Output path for saving client datasets.")
116 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients.")
117 | parser.add_argument('--split_type', type=str, choices=['iid', 'dirichlet'], required=True, help="Type of split: 'iid' or 'dirichlet'.")
118 | parser.add_argument('--alpha', type=float, default=None, help="Alpha value for Dirichlet distribution.")
119 | parser.add_argument('--min_size', type=int, default=0, help="Minimum size for each client dataset.")
120 |
121 | args = parser.parse_args()
122 |
123 | splitter = DataSplitter(args.dataset_path, args.num_clients, args.alpha, args.min_size)
124 |
125 | # Perform split
126 | if args.split_type == 'iid':
127 | splitter.iid_split()
128 | elif args.split_type == 'dirichlet':
129 | if args.alpha is None:
130 | parser.error("--alpha must be specified for Dirichlet split.")
131 | splitter.dirichlet_split()
132 |
133 | # Save datasets
134 | splitter.save_datasets(args.output_path)
135 |
136 | # Print label distribution
137 | splitter.label_distribution()
138 | print("Label distribution per client:")
139 | for i, dist in enumerate(splitter.pct):
140 | print(f"Client {i+1}: {dist}")
141 |
142 | if __name__ == "__main__":
143 | main()
144 |
--------------------------------------------------------------------------------
/finetune/llama/gsm8k_train.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os
4 | os.environ['NCCL_P2P_DISABLE'] = '1'
5 | os.environ['NCCL_IB_DISABLE'] = '1'
6 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
7 | from dataclasses import dataclass, field
8 | from typing import Dict, Optional, Sequence
9 |
10 | import torch
11 | import transformers
12 | from torch.utils.data import Dataset
13 | from transformers import Trainer,TrainingArguments
14 | from peft import LoraConfig, TaskType, get_peft_model
15 | from datasets import load_dataset
16 |
17 |
18 | IGNORE_INDEX = -100
19 | DEFAULT_PAD_TOKEN = "[PAD]"
20 | DEFAULT_EOS_TOKEN = "</s>"
21 | DEFAULT_BOS_TOKEN = "<s>"
22 | DEFAULT_UNK_TOKEN = "<unk>"
23 | ANSWER_PROMPT = "The final answer is: "
24 | QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n"
25 |
26 |
27 |
28 | def smart_tokenizer_and_embedding_resize(
29 | special_tokens_dict: Dict,
30 | tokenizer: transformers.PreTrainedTokenizer,
31 | model: transformers.PreTrainedModel,
32 | ):
33 | """Resize tokenizer and embedding.
34 |
35 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
36 | """
37 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
38 | model.resize_token_embeddings(len(tokenizer))
39 |
40 | if num_new_tokens > 0:
41 | input_embeddings = model.get_input_embeddings().weight.data
42 | output_embeddings = model.get_output_embeddings().weight.data
43 |
44 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
45 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
46 |
47 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
48 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
49 |
50 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
51 | """Tokenize a list of strings."""
52 | tokenized_list = [
53 | tokenizer(
54 | text,
55 | return_tensors="pt",
56 | padding="max_length",
57 | max_length=512,
58 | truncation=True,
59 | )
60 | for text in strings
61 | ]
62 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
63 | input_ids_lens = labels_lens = [
64 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
65 | ]
66 | return dict(
67 | input_ids=input_ids,
68 | labels=labels,
69 | input_ids_lens=input_ids_lens,
70 | labels_lens=labels_lens,
71 | )
72 |
73 | class SupervisedDataset(Dataset):
74 | def __init__(self, raw_data, tokenizer):
75 | super(SupervisedDataset, self).__init__()
76 |
77 | sources = [f"{example['question']}{QUESTION_PROMPT}" for example in raw_data]
78 | targets = [f"{example['answer']}{tokenizer.eos_token}".replace("####", ANSWER_PROMPT) for example in raw_data]
79 |
80 | examples = [s + t for s, t in zip(sources, targets)]
81 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
82 | input_ids = examples_tokenized["input_ids"]
83 | labels = copy.deepcopy(input_ids)
84 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
85 | label[:source_len] = IGNORE_INDEX
86 | data_dict = dict(input_ids=input_ids, labels=labels)
87 | self.input_ids = data_dict["input_ids"]
88 | self.labels = data_dict["labels"]
89 |
90 | def __len__(self):
91 | return len(self.input_ids)
92 |
93 | def __getitem__(self, idx):
94 | return {"input_ids": self.input_ids[idx], "labels": self.labels[idx]}
95 |
96 | @dataclass
97 | class DataCollatorForSupervisedDataset(object):
98 | """Collate examples for supervised fine-tuning."""
99 |
100 | tokenizer: transformers.PreTrainedTokenizer
101 |
102 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
103 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
104 | input_ids = torch.nn.utils.rnn.pad_sequence(
105 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
106 | )
107 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
108 | return dict(
109 | input_ids=input_ids,
110 | labels=labels,
111 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
112 | )
113 |
114 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer) -> Dict:
115 | dataset = load_dataset('/home/ubuntu/smyin/datasets/gsm8k', "main")
116 | #train_set = dataset['train'].select(range(500))
117 | train_set = dataset['train']
118 | train_dataset = SupervisedDataset(raw_data=train_set, tokenizer=tokenizer)
119 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
120 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
121 |
122 |
123 | def train():
124 | output_dir = "./llama_lora_gsm8k3"
125 |
126 | model = transformers.AutoModelForCausalLM.from_pretrained('/home/ubuntu/smyin/models/Llama-3.2-1B-Instruct')
127 |
128 | peft_config = LoraConfig(
129 | task_type=TaskType.CAUSAL_LM,
130 | r=8, # rank of the low-rank decomposition
131 | lora_alpha=32,
132 | lora_dropout=0.1, # dropout probability for the LoRA layers
133 | )
134 | model = get_peft_model(model, peft_config)
135 |
136 | tokenizer = transformers.AutoTokenizer.from_pretrained('/home/ubuntu/smyin/models/Llama-3.2-1B-Instruct')
137 |
138 | special_tokens_dict = dict()
139 | if tokenizer.pad_token is None:
140 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
141 | if tokenizer.eos_token is None:
142 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
143 | if tokenizer.bos_token is None:
144 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
145 | if tokenizer.unk_token is None:
146 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
147 |
148 | smart_tokenizer_and_embedding_resize(special_tokens_dict=special_tokens_dict,tokenizer=tokenizer,model=model,)
149 |
150 | data_module = make_supervised_data_module(tokenizer=tokenizer)
151 |
152 | training_args = TrainingArguments(
153 | output_dir=output_dir,
154 | per_device_train_batch_size=5,
155 | gradient_accumulation_steps=4,
156 | num_train_epochs=5,
157 | save_steps=500,
158 | logging_steps=200,
159 | evaluation_strategy="no",
160 | )
161 |
162 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
163 | trainer.train()
164 | trainer.save_state()
165 | trainer.save_model(output_dir)
166 |
167 |
168 | if __name__ == "__main__":
169 | train()
--------------------------------------------------------------------------------
/finetune/bert/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import torch
5 | import argparse
6 |
7 | # Dynamically add the project root directory to sys.path
8 | current_dir = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
9 | parent_dir = os.path.dirname(os.path.dirname(current_dir))  # project (code) root directory
10 | sys.path.append(parent_dir)
11 |
12 | # Import project modules
13 | from utils.generate_weight_matrix import generate_weight_matrix
14 | from server import FederatedServer
15 | from client import Client
16 |
17 | def train_federated_model(model_checkpoint, dataset_path_template, val_dataset_path_template, num_clients, lora_r, lora_alpha, target_modules, training_type, dataset_type, name, num_rounds=256, batch_size=16, log_path="federated_training.log"):
18 | """
19 | Runs the federated learning training process for the given number of rounds.
20 |
21 | Args:
22 | model_checkpoint (str): Path to the pre-trained model.
23 | dataset_path_template (str): Template for the dataset paths for each client.
24 | val_dataset_path_template (str): Path template for validation dataset.
25 | num_clients (int): Number of clients participating in federated learning.
26 | lora_r (int): LoRA rank for LoRA layers.
27 | lora_alpha (int): LoRA alpha for LoRA layers.
28 | target_modules (list): List of target modules for LoRA layers.
29 | training_type (str): Type of training ('LoRA', 'ConLoRA').
30 | dataset_type (str): Type of dataset ('sst2', 'mnli', or 'qnli').
31 | name (str): The name used to generate the weight matrix.
32 | num_rounds (int): Number of federated learning rounds (default is 256).
33 | log_path (str): Path to save the log file.
34 | """
35 | # Initialize clients
36 | clients = []
37 | for i in range(num_clients):
38 | dataset_path = dataset_path_template.format(i+1)
39 | val_dataset_path = val_dataset_path_template
40 | client = Client(model_checkpoint=model_checkpoint,
41 | dataset_path=dataset_path,
42 | val_dataset_path=val_dataset_path,
43 | lora_r=lora_r,
44 | lora_alpha=lora_alpha,
45 | target_modules=target_modules,
46 | training_type=training_type,
47 | dataset_type=dataset_type,
48 | batch_size=batch_size,
49 | device='cuda' if torch.cuda.is_available() else 'cpu')
50 | clients.append(client)
51 |
52 | # Initialize the federated server with clients
53 | server = FederatedServer(clients)
54 |
55 | # Initial aggregation of parameters
56 | server.aggregate_last_2_layers_params()
57 | server.aggregate_lora_A()
58 |
59 | # Initialize accuracy and parameter difference tracking
60 | acc = [[] for _ in range(num_clients)]
61 | diff = []
62 |
63 | # Get the weight matrix A for aggregation
64 | A = generate_weight_matrix(name)
65 |
66 | # Initial evaluation of all clients
67 | for i in range(server.num_clients):
68 | accuracy = server.clients[i].evaluate()
69 | acc[i].append(accuracy)
70 | logging.info(f'Initial accuracy of client {i+1}: {accuracy}')
71 |
72 | # Aggregate LoRA parameters with DFL
73 | server.aggregate_dfl(A)
74 |
75 | # Calculate and log the initial difference in LoRA products
76 | a = server.calculate_lora_products_and_avg_diff(server.new_params)
77 | diff.append(a)
78 |
79 | # Training loop for the given number of rounds
80 | for round in range(num_rounds):
81 | logging.info(f"Round {round+1}/{num_rounds}")
82 |
83 | # Train each client for one epoch
84 | for client in server.clients:
85 | loss = client.train_one_epoch()
86 | accuracy = client.evaluate()
87 | logging.info(f"Client {server.clients.index(client) + 1} - Loss: {loss}, Accuracy: {accuracy}")
88 |
89 | # Aggregate parameters with DFL
90 | server.aggregate_dfl(A)
91 |
92 | # Calculate and log the difference in LoRA products after aggregation
93 | a = server.calculate_lora_products_and_avg_diff(server.new_params)
94 | diff.append(a)
95 | logging.info(f"Round {round+1} - Parameter Difference: {a}")
96 |
97 | # Evaluate all clients after the round
98 | for i in range(server.num_clients):
99 | accuracy = server.clients[i].evaluate()
100 | acc[i].append(accuracy)
101 | logging.info(f'Accuracy of client {i+1} after round {round+1}: {accuracy}')
102 |
103 | # Log the parameter difference and accuracies after the round
104 | logging.info(f"Parameter Difference record: {diff}")
105 | logging.info(f'Accuracies after round {round+1}: {acc}')
106 |
107 |
108 | if __name__ == "__main__":
109 | # Define argument parser
110 | parser = argparse.ArgumentParser(description="Federated Learning Training Script")
111 | parser.add_argument('--model_checkpoint', type=str, required=True, help="Path to the model checkpoint.")
112 | parser.add_argument('--dataset_path_template', type=str, required=True, help="Template for the dataset paths for each client.")
113 | parser.add_argument('--val_dataset_path_template', type=str, required=True, help="Path template for validation dataset.")
114 | parser.add_argument('--num_clients', type=int, required=True, help="Number of clients participating in federated learning.")
115 | parser.add_argument('--lora_r', type=int, required=True, help="LoRA rank for LoRA layers.")
116 | parser.add_argument('--lora_alpha', type=int, required=True, help="LoRA alpha for LoRA layers.")
117 | parser.add_argument('--target_modules', type=str, required=True, help="Comma-separated list of target modules for LoRA layers.")
118 | parser.add_argument('--training_type', type=str, choices=['LoRA', 'ConLoRA'], required=True, help="Training type ('LoRA' or 'ConLoRA').")
119 | parser.add_argument('--dataset_type', type=str, choices=['sst2', 'mnli', 'qnli'], required=True, help="Dataset type.")
120 | parser.add_argument('--name', type=str, required=True, help="The name used to generate the weight matrix.")
121 | parser.add_argument('--num_rounds', type=int, default=256, help="Number of federated learning rounds.")
122 | parser.add_argument('--batch_size', type=int, default=128, help="Batch size for each client.")
123 | parser.add_argument('--log_path', type=str, default="federated_training.log", help="Path to save the log file.")
124 |
125 | args = parser.parse_args()
126 |
127 | # Parse target_modules into list
128 | target_modules = args.target_modules.split(',')
129 |
130 | # Configure logging to the specified log file path
131 | logging.basicConfig(filename=args.log_path, level=logging.INFO,
132 | format='%(asctime)s - %(levelname)s - %(message)s')
133 |
134 | # Run federated training process
135 | train_federated_model(
136 | model_checkpoint=args.model_checkpoint,
137 | dataset_path_template=args.dataset_path_template,
138 | val_dataset_path_template=args.val_dataset_path_template,
139 | num_clients=args.num_clients,
140 | lora_r=args.lora_r,
141 | lora_alpha=args.lora_alpha,
142 | target_modules=target_modules,
143 | training_type=args.training_type,
144 | dataset_type=args.dataset_type,
145 | name=args.name,
146 | num_rounds=args.num_rounds,
147 | batch_size=args.batch_size,
148 | log_path=args.log_path # Pass log_path parameter
149 | )
150 |
--------------------------------------------------------------------------------
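For reference, the script above expects `--target_modules` as a comma-separated string and `--dataset_path_template` as a path containing a `{}` placeholder that is filled with the 1-based client index. The short sketch below shows just that argument handling; the paths and module names are invented for illustration and are not from the repository.

# Minimal sketch of the argument conventions used by finetune/bert/train.py (values are made up).
dataset_path_template = "/data/sst2_clients/client_{}"          # hypothetical --dataset_path_template
target_modules_arg = "q_lin,v_lin,pre_classifier,classifier"    # hypothetical --target_modules value

target_modules = target_modules_arg.split(',')                  # -> ['q_lin', 'v_lin', ...]
client_paths = [dataset_path_template.format(i + 1) for i in range(7)]

print(target_modules)
print(client_paths[0], client_paths[-1])   # .../client_1 ... .../client_7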
/finetune/llama/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | #os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3 | import torch
4 | from itertools import combinations
5 | import numpy as np
6 | import logging
7 | from client import Client
8 |
9 |
10 |
11 | # Absolute path of the directory containing this file
12 | current_dir = os.path.dirname(os.path.abspath(__file__))
13 |
14 | # Make sure the log directory exists
15 | log_dir = os.path.join(current_dir, 'logs')
16 | os.makedirs(log_dir, exist_ok=True)
17 |
18 | # Configure logging
19 | logging.basicConfig(filename=os.path.join(log_dir, 'LoRA.log'),
20 | level=logging.INFO,
21 | format='%(asctime)s - %(levelname)s - %(message)s')
22 |
23 |
24 | class FederatedServer:
25 |
26 | def __init__(self, clients):
27 | self.clients = clients
28 | self.num_clients = len(clients)
29 |
30 | # Run once at initialization so that every client starts from the same LoRA parameters
31 | def aggregate_lora_A(self):
32 | self.Avg_lora_A_params=self.clients[0].get_lora_A()
33 | for i in range(self.num_clients):
34 | self.clients[i].set_trainable_parameters(self.Avg_lora_A_params)
35 |
36 | def aggregate_dfl(self, A):
37 | self.new_params = []
38 |
39 | # Initialize each client's aggregated parameters to zero
40 | for i in range(self.num_clients):
41 | client_params = self.clients[i].get_lora_parameters()
42 | zero_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in client_params.items()}
43 | self.new_params.append(zero_params)
44 |
45 | # Aggregation: weighted sum of all clients' parameters according to A
46 | for i in range(self.num_clients):
47 | for j in range(self.num_clients):
48 | client_params = self.clients[j].get_lora_parameters()
49 | for name, (param, requires_grad) in client_params.items():
50 | self.new_params[i][name] = (self.new_params[i][name][0] + param * A[i][j], requires_grad)
51 |
52 | for i in range(self.num_clients):
53 | self.clients[i].set_trainable_parameters(self.new_params[i])
54 |
55 |
56 | # Extract the LoRA A/B parameters and compute their products
57 | def extract_and_multiply_lora_params(self,param_group):
58 | result = {}
59 | for param_name, (param, _) in param_group.items():
60 | if 'lora_B.default.weight' in param_name:
61 | prefix = param_name.split('lora_B.default.weight')[0]
62 | lora_A_name = prefix + 'lora_A.default.weight'
63 | lora_B_name = prefix + 'lora_B.default.weight'
64 |
65 | if lora_A_name in param_group and lora_B_name in param_group:
66 | lora_A = param_group[lora_A_name][0]
67 | lora_B = param_group[lora_B_name][0]
68 |
69 | product = torch.matmul(lora_B, lora_A)
70 | result[prefix + 'product'] = product
71 |
72 | return result
73 |
74 |
75 | # Compute the pairwise differences between all parameter groups and average them
76 | def calculate_lora_products_and_avg_diff(self,param_groups):
77 | if len(param_groups) < 2:
78 | raise ValueError("There should be at least two sets of parameters to calculate differences.")
79 |
80 | total_diff_sum = 0.0
81 | num_pairs = 0
82 |
83 | # Iterate over all pairwise combinations of parameter groups
84 | for i, j in combinations(range(len(param_groups)), 2):
85 | product_1 = self.extract_and_multiply_lora_params(param_groups[i])
86 | product_2 = self.extract_and_multiply_lora_params(param_groups[j])
87 | pair_diff_sum = 0.0
88 | # Difference between the lora_B @ lora_A products
89 | for key in product_1.keys():
90 | diff = product_1[key] - product_2[key]
91 | #pair_diff_sum += torch.sum(diff).item()
92 | pair_diff_sum += torch.norm(diff).item()
93 |
94 | # Difference between the remaining (non-LoRA) parameters
95 | for param_name in param_groups[i].keys():
96 | if 'lora_A.default.weight' in param_name or 'lora_B.default.weight' in param_name:
97 | continue  # skip lora_A and lora_B, already handled above
98 | diff = param_groups[i][param_name][0] - param_groups[j][param_name][0]
99 | #pair_diff_sum += torch.sum(diff).item()
100 | pair_diff_sum += torch.norm(diff).item()
101 |
102 |
103 | total_diff_sum += pair_diff_sum
104 |
105 | num_pairs += 1
106 |
107 | # Average difference over all pairs
108 | average_diff = total_diff_sum / num_pairs if num_pairs > 0 else 0.0
109 | return average_diff
110 |
111 | def get_A(name):
112 | if name=="du14":
113 | A = np.array([
114 | [0, 1, 1, 0, 0, 0, 0],
115 | [1, 0, 1, 0, 0, 0, 0],
116 | [1, 1, 0, 1, 0, 0, 0],
117 | [0, 0, 1, 0, 1, 0, 0],
118 | [0, 0, 0, 1, 0, 1, 0],
119 | [0, 0, 0, 0, 1, 0, 1],
120 | [0, 0, 0, 0, 0, 1, 0]
121 | ])
122 | elif name=="link3":
123 | A = np.array([
124 | [0, 1, 0, 0, 0, 0, 1],
125 | [1, 0, 1, 0, 0, 0, 0],
126 | [0, 1, 0, 1, 0, 0, 0],
127 | [0, 0, 1, 0, 1, 0, 0],
128 | [0, 0, 0, 1, 0, 1, 0],
129 | [0, 0, 0, 0, 1, 0, 1],
130 | [1, 0, 0, 0, 0, 1, 0]
131 | ])
132 | elif name=="link4":
133 | A= np.array([
134 | [0, 1, 1, 0, 0, 0, 1],
135 | [1, 0, 1, 1, 0, 0, 0],
136 | [1, 1, 0, 1, 0, 0, 0],
137 | [0, 1, 1, 0, 1, 0, 0],
138 | [0, 0, 0, 1, 0, 1, 1],
139 | [0, 0, 0, 0, 1, 0, 1],
140 | [1, 0, 0, 0, 1, 1, 0]
141 | ])
142 | elif name=="link5":
143 | A= np.array([
144 | [0, 1, 1, 0, 0, 1, 1],
145 | [1, 0, 1, 1, 0, 0, 1],
146 | [1, 1, 0, 1, 1, 0, 0],
147 | [0, 1, 1, 0, 1, 1, 0],
148 | [0, 0, 1, 1, 0, 1, 1],
149 | [1, 0, 0, 1, 1, 0, 1],
150 | [1, 1, 0, 0, 1, 1, 0]
151 | ])
152 | elif name=="link6":
153 | A= np.array([
154 | [0, 1, 1, 1, 0, 1, 1],
155 | [1, 0, 1, 1, 1, 1, 0],
156 | [1, 1, 0, 1, 1, 1, 0],
157 | [1, 1, 1, 0, 1, 1, 0],
158 | [0, 1, 1, 1, 0, 1, 1],
159 | [1, 1, 1, 1, 1, 0, 1],
160 | [1, 0, 0, 0, 1, 1, 0]
161 | ])
162 | elif name=="link7":
163 | A= np.array([
164 | [0, 1, 1, 1, 1, 1, 1],
165 | [1, 0, 1, 1, 1, 1, 1],
166 | [1, 1, 0, 1, 1, 1, 1],
167 | [1, 1, 1, 0, 1, 1, 1],
168 | [1, 1, 1, 1, 0, 1, 1],
169 | [1, 1, 1, 1, 1, 0, 1],
170 | [1, 1, 1, 1, 1, 1, 0]
171 | ])
172 |
173 | degree = np.sum(A, axis=1)
174 |
175 | # Initialize the weight matrix W
176 | W = np.zeros_like(A, dtype=float)
177 |
178 | # Off-diagonal weights: 1 / (max(deg_i, deg_j) + 1) for connected nodes
179 | for i in range(A.shape[0]):
180 | for j in range(A.shape[1]):
181 | if A[i, j] == 1:
182 | W[i, j] = 1 / (max(degree[i], degree[j]) + 1)
183 | elif i != j:
184 | W[i, j] = 0
185 |
186 | # Diagonal entries chosen so that each row sums to 1
187 | for i in range(A.shape[0]):
188 | W[i, i] = 1 - np.sum(W[i, np.arange(A.shape[0]) != i])
189 |
190 |
191 | return W
192 |
193 |
194 | if __name__=="__main__":
195 | model_checkpoint = '/home/ubuntu/smyin/models/Llama-3.2-1B'
196 | dataset_path_template = "/home/ubuntu/smyin/dataset/decentrilized_dataset/gsm8k/client_{}"
197 |
198 | clients = []
199 | num_clients = 7
200 | type = "LoRA"
201 | name="link3"
202 | batch_size=2
203 |
204 |
205 |
206 | for i in range(num_clients):
207 | dataset_path = dataset_path_template.format(i+1)
208 | #device = f"cuda:{i % 4}"  # round-robin assignment across 4 GPUs
209 | client = Client(model_path=model_checkpoint, data_path=dataset_path, type=type,batch_size=batch_size)
210 | clients.append(client)
211 |
212 | server = FederatedServer(clients)
213 |
214 | server.aggregate_lora_A()
215 |
216 | num_rounds = 3
217 | diff = []
218 |
219 | A=get_A(name)
220 |
221 | server.aggregate_dfl(A)
222 | a = server.calculate_lora_products_and_avg_diff(server.new_params)
223 | diff.append(a)
224 |
225 | for round in range(num_rounds):
226 | logging.info(f"Round {round+1}/{num_rounds}")
227 | for client in server.clients:
228 | loss = client.train_one_epoch()
229 | logging.info(f"Client {server.clients.index(client) + 1} - Loss: {loss}")
230 |
231 | server.aggregate_dfl(A)
232 | a = server.calculate_lora_products_and_avg_diff(server.new_params)
233 | diff.append(a)
234 | logging.info(f"Round {round+1} - Parameter Difference: {a}")
235 |
236 |
237 | for i in range(server.num_clients):
238 | server.clients[i].trainable_parameters_nums()
239 | logging.info(f"Parameter Difference record: {diff}")
240 |
--------------------------------------------------------------------------------
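The `get_A` helper above turns a 0/1 adjacency matrix into a mixing matrix with W[i][j] = 1 / (max(deg_i, deg_j) + 1) for connected pairs and a diagonal chosen so every row sums to 1 (a standard Metropolis-style weighting for decentralized averaging). Below is a small self-contained check of that construction on a toy 3-node graph; the topology is made up and is not one of the named "link*" graphs in the script.

import numpy as np

# Toy topology: 3 nodes, every pair connected.
A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]])

degree = A.sum(axis=1)
W = np.zeros_like(A, dtype=float)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        if A[i, j] == 1:
            W[i, j] = 1.0 / (max(degree[i], degree[j]) + 1)
for i in range(A.shape[0]):
    W[i, i] = 1 - W[i, np.arange(A.shape[0]) != i].sum()

print(W)               # every entry is 1/3 for this graph
print(W.sum(axis=1))   # each row sums to 1, so aggregate_dfl performs a convex combination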
/finetune/bert/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | import numpy as np
5 | from client import Client
6 | from typing import List, Dict
7 |
8 |
9 | class FederatedServer:
10 | def __init__(self, clients: List[Client]) -> None:
11 | """
12 | Initializes the server with a list of clients.
13 |
14 | Args:
15 | clients (list): List of Client objects.
16 | """
17 | self.clients = clients
18 | self.num_clients = len(clients)
19 |
20 | # Set up logging
21 | current_dir = os.path.dirname(os.path.abspath(__file__))
22 | log_dir = os.path.join(current_dir, 'logs')
23 | os.makedirs(log_dir, exist_ok=True)
24 | logging.basicConfig(
25 | filename=os.path.join(log_dir, 'server_init.log'),
26 | level=logging.INFO,
27 | format='%(asctime)s - %(levelname)s - %(message)s'
28 | )
29 |
30 | def clients_initialize_info(self) -> None:
31 | """
32 | Initializes each client by loading the model and dataset.
33 | """
34 | logging.info("Initializing clients...")
35 | for i, client in enumerate(self.clients):
36 | logging.info(f"Initializing client {i+1}...")
37 | client.print_trainable_parameters() # Just to verify the model parameters
38 | accuracy = client.evaluate() # Initial evaluation
39 | logging.info(f"Initial accuracy of client {i+1}: {accuracy}")
40 |
41 | def aggregate_lora_A(self) -> None:
42 | """
43 | Aggregate LoRA parameters from the first client and set them as trainable parameters
44 | for all other clients to ensure consistency.
45 | """
46 | logging.info("Aggregating LoRA parameters...")
47 | avg_lora_A_params = self.clients[0].get_lora_A()
48 |
49 | # Set the LoRA A parameters for all clients
50 | for client in self.clients:
51 | client.set_trainable_parameters(avg_lora_A_params)
52 |
53 | def aggregate_last_2_layers_params(self) -> None:
54 | """
55 | Aggregates the last 2 layers' parameters across all clients.
56 | """
57 | logging.info("Aggregating last 2 layers' parameters...")
58 | avg_params = self.clients[0].get_last_2_layers()
59 |
60 | # Set the last 2 layers' parameters for all clients
61 | for client in self.clients:
62 | client.set_trainable_parameters(avg_params)
63 |
64 | def aggregate_dfl(self, A: np.ndarray) -> None:
65 | """
66 | Aggregates LoRA parameters using a weighted average, based on a matrix A.
67 |
68 | Args:
69 | A (np.ndarray): A weight matrix used for aggregation.
70 | """
71 | logging.info("Aggregating LoRA parameters with DFL method...")
72 | self.new_params = []
73 |
74 | # Initialize each client's parameters as zero
75 | for i in range(self.num_clients):
76 | client_params = self.clients[i].get_lora_parameters()
77 | zero_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in client_params.items()}
78 | self.new_params.append(zero_params)
79 |
80 | # Perform aggregation based on weight matrix A
81 | for i in range(self.num_clients):
82 | for j in range(self.num_clients):
83 | client_params = self.clients[j].get_lora_parameters()
84 | for name, (param, requires_grad) in client_params.items():
85 | self.new_params[i][name] = (self.new_params[i][name][0] + param * A[i][j], requires_grad)
86 |
87 | # Update each client's parameters
88 | for i in range(self.num_clients):
89 | self.clients[i].set_trainable_parameters(self.new_params[i])
90 |
91 | logging.info("LoRA parameters aggregated with DFL.")
92 |
93 | def extract_and_multiply_lora_params(self, param_group: Dict[str, tuple]) -> Dict[str, torch.Tensor]:
94 | """
95 | Extracts LoRA A and LoRA B parameters and computes their product.
96 |
97 | Args:
98 | param_group (dict): Dictionary of model parameters.
99 |
100 | Returns:
101 | dict: A dictionary containing the product of LoRA A and LoRA B.
102 | """
103 | result = {}
104 | for param_name, (param, _) in param_group.items():
105 | if 'lora_B.default.weight' in param_name:
106 | prefix = param_name.split('lora_B.default.weight')[0]
107 | lora_A_name = prefix + 'lora_A.default.weight'
108 | lora_B_name = prefix + 'lora_B.default.weight'
109 |
110 | if lora_A_name in param_group and lora_B_name in param_group:
111 | lora_A = param_group[lora_A_name][0]
112 | lora_B = param_group[lora_B_name][0]
113 |
114 | product = torch.matmul(lora_B, lora_A)
115 | result[prefix + 'product'] = product
116 |
117 | return result
118 |
119 | def calculate_lora_products_and_avg_diff(self, param_groups: List[Dict[str, tuple]]) -> float:
120 | """
121 | Calculates the average difference between each client's LoRA parameter products and the products of the averaged parameters.
122 |
123 | Args:
124 | param_groups (list): A list of parameter groups from the clients.
125 |
126 | Returns:
127 | float: The average difference between LoRA parameter products.
128 | """
129 | if len(param_groups) < 2:
130 | raise ValueError("There should be at least two sets of parameters to calculate differences.")
131 |
132 | # Calculate average LoRA parameters across all clients
133 | avg_params = self.clients[0].get_lora_parameters()
134 | avg_params = {name: (torch.zeros_like(param[0]), param[1]) for name, param in avg_params.items()}
135 | for i in range(self.num_clients):
136 | client_params = self.clients[i].get_lora_parameters()
137 | for name, (param, _) in client_params.items():
138 | avg_params[name] = (avg_params[name][0] + param / self.num_clients, _)
139 |
140 | total_diff_sum = 0.0
141 | num_pairs = 0
142 |
143 | # Compare each client's LoRA products against the products of the averaged parameters
144 | for i in range(self.num_clients):
145 | product_1 = self.extract_and_multiply_lora_params(param_groups[i])
146 | product_2 = self.extract_and_multiply_lora_params(avg_params)
147 | pair_diff_sum = 0.0
148 |
149 | # Calculate the difference between the products
150 | for key in product_1.keys():
151 | diff = product_1[key] - product_2[key]
152 | pair_diff_sum += torch.norm(diff).item()
153 |
154 | total_diff_sum += pair_diff_sum
155 | num_pairs += 1
156 |
157 | # Return the average difference
158 | average_diff = total_diff_sum / num_pairs if num_pairs > 0 else 0.0
159 | return average_diff
160 |
161 |
162 | def main() -> None:
163 | """
164 | Main function to initialize clients and federated server, and perform client initialization.
165 | """
166 | # Configuration parameters
167 | model_checkpoint = '/home/ubuntu/smyin/models/distilbert-base-uncased'
168 | dataset_path_template = "/home/ubuntu/smyin/dataset/decentrilized_dataset/sst2_020/client_{}"
169 | val_dataset_path_template = "/home/ubuntu/smyin/dataset/glue/sst2"
170 | num_clients = 7
171 | lora_r = 4
172 | lora_alpha = 32
173 | target_modules = ["q_lin", "v_lin", "pre_classifier", "classifier"]  # Example modules
174 | training_type = "LoRA"
175 | dataset_type = "sst2"
176 | batch_size=128
177 |
178 | # Initialize clients
179 | clients = []
180 | for i in range(num_clients):
181 | dataset_path = dataset_path_template.format(i + 1)
182 | val_dataset_path = val_dataset_path_template
183 | client = Client(
184 | model_checkpoint=model_checkpoint,
185 | dataset_path=dataset_path,
186 | val_dataset_path=val_dataset_path,
187 | lora_r=lora_r,
188 | lora_alpha=lora_alpha,
189 | target_modules=target_modules,
190 | training_type=training_type,
191 | dataset_type=dataset_type,
192 | device='cuda' if torch.cuda.is_available() else 'cpu',
193 | batch_size=batch_size
194 | )
195 | clients.append(client)
196 |
197 | # Initialize federated server
198 | server = FederatedServer(clients)
199 |
200 | # Initialize and evaluate the clients
201 | server.clients_initialize_info()
202 |
203 |
204 | if __name__ == "__main__":
205 | main()
206 |
--------------------------------------------------------------------------------
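For orientation, `aggregate_dfl` above applies the decentralized update θ_i ← Σ_j W[i][j] · θ_j to every parameter tensor, where W is the mixing matrix. The toy sketch below applies the same weighted sum to plain tensors for three fictitious clients; the tensors and the mixing weights are made up and only illustrate the arithmetic.

import torch

# Hypothetical per-client parameter dicts (one 2x2 tensor each, filled with the client index).
client_params = [{"lora_A": torch.full((2, 2), float(c))} for c in range(3)]

# A made-up row-stochastic mixing matrix for 3 clients (rows sum to 1).
W = [[0.50, 0.25, 0.25],
     [0.25, 0.50, 0.25],
     [0.25, 0.25, 0.50]]

new_params = []
for i in range(3):
    mixed = {name: torch.zeros_like(p) for name, p in client_params[i].items()}
    for j in range(3):
        for name, p in client_params[j].items():
            mixed[name] = mixed[name] + W[i][j] * p   # same accumulation as aggregate_dfl
    new_params.append(mixed)

print(new_params[0]["lora_A"])  # 0.5*0 + 0.25*1 + 0.25*2 = 0.75 everywhere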
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/finetune/bert/client.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
6 | from datasets import load_dataset, load_from_disk
7 | from peft import get_peft_model, LoraConfig
8 | from tqdm.auto import tqdm
9 | from transformers import get_scheduler
10 | from sklearn.metrics import accuracy_score
11 |
12 |
13 | class Client:
14 | def __init__(self, model_checkpoint, dataset_path, val_dataset_path, lora_r, lora_alpha, target_modules,
15 | training_type, dataset_type, device='cpu', batch_size=128):
16 | """
17 | Initializes the Client for model training with specified configurations.
18 |
19 | Args:
20 | model_checkpoint (str): Pre-trained model checkpoint path.
21 | dataset_path (str): Path to the training dataset.
22 | val_dataset_path (str): Path to the validation dataset.
23 | lora_r (int): Rank of LoRA layers.
24 | lora_alpha (int): Alpha value for LoRA layers.
25 | target_modules (list): List of target modules for LoRA.
26 | training_type (str): Type of training ('LoRA' or 'ConLoRA').
27 | dataset_type (str): Type of dataset ('sst2', 'mnli', or 'qnli').
28 | device (str): Device for training ('cpu' or 'cuda').
29 | batch_size (int): Batch size for the DataLoader.
30 | """
31 | self.device = device
32 | self.lora_r = lora_r
33 | self.lora_alpha = lora_alpha
34 | self.training_type = training_type
35 | self.dataset_type = dataset_type
36 | self.dataset = load_from_disk(dataset_path)
37 | self.raw_val_dataset = load_dataset(val_dataset_path)
38 | self.batch_size = batch_size # Save batch size as an instance variable
39 |
40 | # Determine number of labels based on dataset type
41 | self.num_labels = self._get_num_labels()
42 |
43 | # Load tokenizer and model
44 | self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
45 | self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=self.num_labels).to(self.device)
46 |
47 | # Configure model for LoRA training
48 | self.configure_model(target_modules)
49 |
50 | # Prepare data loaders for training and validation datasets
51 | self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
52 | self.train_data = self.prepare_data(self.dataset)
53 | self.val_data = self.prepare_data(self.raw_val_dataset['validation'])
54 |
55 | # Optimizer setup
56 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
57 |
58 | def _get_num_labels(self):
59 | """
60 | Returns the number of labels based on the dataset type.
61 | """
62 | if self.dataset_type == "sst2":
63 | return 2 # SST-2 is a binary classification problem
64 | elif self.dataset_type == "mnli":
65 | return 3 # MNLI is a 3-class classification problem
66 | elif self.dataset_type == "qnli":
67 | return 2 # QNLI is a binary classification problem
68 | else:
69 | raise ValueError(f"Unsupported dataset type: {self.dataset_type}")
70 |
71 | def tokenize_function(self, examples):
72 | """
73 | Tokenizes the input text based on dataset type (sst2, mnli, qnli).
74 | Args:
75 | examples (dict): A dictionary of input examples.
76 | """
77 | if self.dataset_type == "sst2":
78 | text = examples["sentence"]
79 | self.tokenizer.truncation_side = "left"
80 | tokenized_inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
81 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long)
82 |
83 | elif self.dataset_type == "mnli":
84 | texts = (examples["premise"], examples["hypothesis"])
85 | self.tokenizer.truncation_side = "left"
86 | tokenized_inputs = self.tokenizer(*texts, return_tensors="pt", truncation=True, max_length=512, padding=True)
87 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long)
88 |
89 | elif self.dataset_type == "qnli":
90 | texts = (examples["question"], examples["sentence"])
91 | self.tokenizer.truncation_side = "left"
92 | tokenized_inputs = self.tokenizer(*texts, return_tensors="pt", truncation=True, max_length=512, padding=True)
93 | tokenized_inputs['labels'] = torch.tensor(examples['label'], dtype=torch.long)
94 |
95 | return tokenized_inputs
96 |
97 | def prepare_data(self, dataset):
98 | """
99 | Prepares the dataset for training or evaluation by tokenizing and formatting it.
100 |
101 | Args:
102 | dataset (Dataset): The dataset to be processed.
103 | """
104 | tokenized_dataset = dataset.map(self.tokenize_function, batched=True)
105 | tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
106 | return DataLoader(tokenized_dataset, batch_size=self.batch_size, collate_fn=self.data_collator)
107 |
108 | def configure_model(self, target_modules):
109 | """
110 | Configures the model for LoRA (Low-Rank Adaptation) training.
111 |
112 | Args:
113 | target_modules (list): List of target modules for LoRA.
114 | """
115 | if self.tokenizer.pad_token is None:
116 | self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
117 | self.model.resize_token_embeddings(len(self.tokenizer))
118 |
119 | if self.training_type == "LoRA":
120 | peft_config = LoraConfig(task_type="SEQ_CLS", r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=0.01, target_modules=target_modules)
121 | self.model = get_peft_model(self.model, peft_config)
122 | self._freeze_parameters()
123 |
124 | elif self.training_type == "ConLoRA":
125 | peft_config = LoraConfig(task_type="SEQ_CLS", r=self.lora_r, lora_alpha=self.lora_alpha, lora_dropout=0.01, target_modules=target_modules)
126 | self.model = get_peft_model(self.model, peft_config)
127 | self._freeze_parameters(freeze_lora_A=True)
128 |
129 | def _freeze_parameters(self, freeze_lora_A=False):
130 | """
131 | Freezes the parameters of the model based on the training type.
132 |
133 | Args:
134 | freeze_lora_A (bool): Whether to freeze 'lora_A' parameters.
135 | """
136 | for name, param in self.model.named_parameters():
137 | if "pre_classifier.modules_to_save.default.base_layer" in name or "classifier.modules_to_save.default.base_layer" in name:
138 | param.requires_grad = False
139 | if freeze_lora_A and "lora_A" in name:
140 | param.requires_grad = False
141 | print(self.model.print_trainable_parameters())
142 |
143 | def train_one_epoch(self):
144 | """
145 | Trains the model for one epoch.
146 |
147 | Returns:
148 | loss (float): The training loss of the last batch in the epoch.
149 | """
150 | self.model.train()
151 | progress_bar = tqdm(range(len(self.train_data)))
152 | lr_scheduler = get_scheduler("linear", optimizer=self.optimizer, num_warmup_steps=0, num_training_steps=len(self.train_data))
153 |
154 | for batch in self.train_data:
155 | batch = {k: v.to(self.device) for k, v in batch.items()}
156 | outputs = self.model(**batch)
157 | loss = outputs.loss
158 | loss.backward()
159 | self.optimizer.step()
160 | #lr_scheduler.step()
161 | self.optimizer.zero_grad()
162 | progress_bar.update(1)
163 |
164 | return loss.item()
165 |
166 | def evaluate(self):
167 | """
168 | Evaluates the model on the validation dataset.
169 |
170 | Returns:
171 | accuracy (float): The accuracy on the validation dataset.
172 | """
173 | self.model.eval()
174 | all_predictions = []
175 | all_labels = []
176 |
177 | for batch in self.val_data:
178 | batch = {k: v.to(self.device) for k, v in batch.items()}
179 | with torch.no_grad():
180 | outputs = self.model(**batch)
181 | logits = outputs.logits
182 | predictions = torch.argmax(logits, dim=-1)
183 | all_predictions.extend(predictions.cpu().numpy())
184 | all_labels.extend(batch["labels"].cpu().numpy())
185 |
186 | accuracy = accuracy_score(all_labels, all_predictions)
187 | return accuracy
188 |
189 | def print_trainable_parameters(self):
190 | """
191 | Prints the number of trainable and total parameters in the model.
192 | """
193 | total_params = sum(p.numel() for p in self.model.parameters())
194 | trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
195 |
196 | print(f"Total parameters: {total_params}")
197 | print(f"Trainable parameters: {trainable_params}")
198 |
199 | def num_get_trainable_parameters(self):
200 | """
201 | Prints detailed information about each trainable parameter.
202 | """
203 | trainable_params = self.get_trainable_parameters()
204 | total_trainable_params = sum(param.numel() for param, _ in trainable_params.values())
205 | print(f"Total number of trainable parameters: {total_trainable_params}")
206 |
207 | print("Details of trainable parameters:")
208 | for name, (param, requires_grad) in trainable_params.items():
209 | print(f"Layer: {name}, Parameter count: {param.numel()}, Requires Grad: {requires_grad}")
210 |
211 | def get_trainable_parameters(self):
212 | """
213 | Returns the trainable parameters of the model.
214 | """
215 | trainable_params = {}
216 | for name, param in self.model.named_parameters():
217 | if param.requires_grad:
218 | trainable_params[name] = (param.clone().detach(), param.requires_grad)
219 | return trainable_params
220 |
221 | def get_lora_parameters(self):
222 | """
223 | Returns the parameters associated with LoRA layers.
224 |
225 | LoRA parameters typically include layers with names containing 'lora_A' or 'lora_B'.
226 |
227 | Returns:
228 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values.
229 | """
230 | lora_params = {}
231 | for name, param in self.model.named_parameters():
232 | if "lora_A" in name or "lora_B" in name:
233 | # Clone the parameter and preserve the 'requires_grad' property
234 | lora_params[name] = (param.clone().detach(), param.requires_grad)
235 | return lora_params
236 |
237 | def get_all_parameters(self):
238 | """
239 | Returns all parameters of the model, including their names and whether they require gradients.
240 |
241 | Returns:
242 | dict: A dictionary with parameter names as keys and tuples of (cloned parameter, requires_grad) as values.
243 | """
244 | all_params = {}
245 | for name, param in self.model.named_parameters():
246 | all_params[name] = (param.clone().detach(), param.requires_grad)
247 | return all_params
248 |
249 | def get_last_2_layers(self):
250 | """
251 | Returns the parameters for the last two layers of the model, typically the classification layers.
252 |
253 | These layers are often named 'pre_classifier' and 'classifier'.
254 |
255 | Returns:
256 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values.
257 | """
258 | last_2_layers_params = {}
259 | for name, param in self.model.named_parameters():
260 | if "pre_classifier" in name or "classifier" in name:
261 | last_2_layers_params[name] = (param.clone().detach(), param.requires_grad)
262 | return last_2_layers_params
263 |
264 | def get_lora_A(self):
265 | """
266 | Returns the parameters for the 'lora_A' layers in the model.
267 |
268 | Returns:
269 | dict: A dictionary with parameter names as keys and their respective values (cloned and detached) as values.
270 | """
271 | lora_A_params = {}
272 | for name, param in self.model.named_parameters():
273 | if "lora_A" in name:
274 | lora_A_params[name] = (param.clone().detach(), param.requires_grad)
275 | return lora_A_params
276 |
277 | def set_trainable_parameters(self, trainable_params):
278 | """
279 | Sets the trainable parameters for the model based on the provided parameters.
280 |
281 | Args:
282 | trainable_params (dict): A dictionary with parameter names as keys and tuples of (parameter, requires_grad) as values.
283 | """
284 | for name, (param, requires_grad) in trainable_params.items():
285 | if name in dict(self.model.named_parameters()):
286 | current_param = dict(self.model.named_parameters())[name]
287 | current_param.data = param.clone().to(self.device)
288 | current_param.requires_grad = requires_grad
289 |
290 |
291 | def main():
292 | parser = argparse.ArgumentParser(description="Client model training script.")
293 | parser.add_argument('--target_modules', type=str, required=True, help="Space-separated list of target modules for LoRA.")
294 | parser.add_argument('--model_checkpoint', type=str, required=True, help="Path to the model checkpoint.")
295 | parser.add_argument('--dataset_path', type=str, required=True, help="Path to the dataset.")
296 | parser.add_argument('--lora_r', type=int, required=True, help="Rank of LoRA layers.")
297 | parser.add_argument('--lora_alpha', type=int, required=True, help="Alpha value for LoRA layers.")
298 | parser.add_argument('--training_type', type=str, choices=['LoRA', 'ConLoRA'], required=True, help="Training type.")
299 | parser.add_argument('--device', type=str, default='cpu', help="Device for training (default: cpu).")
300 | parser.add_argument('--num_epochs', type=int, default=5, help="Number of epochs to train (default: 5).")
301 | parser.add_argument('--val_dataset_path', type=str, required=True, help="Path to the validation dataset.")
302 | parser.add_argument('--dataset_type', type=str, choices=['sst2', 'mnli', 'qnli'], required=True, help="Type of dataset.")
303 | parser.add_argument('--batch_size', type=int, default=128, help="Batch size for training.")
304 |
305 | args = parser.parse_args()
306 |
307 | # Convert target_modules string to list
308 | target_modules = args.target_modules.split()
309 |
310 | client = Client(
311 | model_checkpoint=args.model_checkpoint,
312 | dataset_path=args.dataset_path,
313 | val_dataset_path=args.val_dataset_path,
314 | lora_r=args.lora_r,
315 | lora_alpha=args.lora_alpha,
316 | training_type=args.training_type,
317 | target_modules=target_modules,
318 | device=args.device,
319 | dataset_type=args.dataset_type,
320 | batch_size=args.batch_size # Pass batch_size from arguments
321 | )
322 |
323 | # Training loop
324 | for epoch in range(args.num_epochs):
325 | loss = client.train_one_epoch()
326 | accuracy = client.evaluate()
327 | print(f"Epoch {epoch + 1}/{args.num_epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
328 |
329 | client.print_trainable_parameters()
330 |
331 |
332 | if __name__ == "__main__":
333 | main()
334 |
--------------------------------------------------------------------------------
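The difference between the 'LoRA' and 'ConLoRA' modes in the client above comes down to `_freeze_parameters(freeze_lora_A=True)`: in ConLoRA every `lora_A` matrix is frozen, so only `lora_B` (plus any unfrozen head parameters) keeps training, and after `aggregate_lora_A` all clients share the same frozen A. A minimal sketch of that freezing pattern is shown below, using a made-up module whose parameter names merely mimic PEFT's `lora_A` / `lora_B` naming.

import torch.nn as nn

# Stand-in module; the attribute names are chosen to mimic PEFT's naming convention.
class ToyLoraLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_A = nn.Linear(8, 4, bias=False)
        self.lora_B = nn.Linear(4, 8, bias=False)

model = ToyLoraLayer()

# ConLoRA-style freezing: any parameter whose name contains "lora_A" stops training.
for name, param in model.named_parameters():
    if "lora_A" in name:
        param.requires_grad = False

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)   # only the lora_B weight remains trainable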
/result_processing/avgResult.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | example_data=[[0.324095771777891, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.337238920020377, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.32460519612837496, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.32745797249108505, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 
0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713], [0.324095771777891, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3271523178807947, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 
0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883], [0.324095771777891, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.32592969943963324, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 
0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3174732552215996, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.32745797249108505, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3271523178807947, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3183902190524707, 0.3272542027508915, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3182883341823739, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 
0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883], [0.324095771777891, 0.3273560876209883, 0.31818644931227713, 0.3195109526235354, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3552725420275089, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3272542027508915, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3183902190524707, 
0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3543555781966378, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3182883341823739, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.3543555781966378, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 
0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3182883341823739, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.35394803871625063, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31808456444218036, 0.31818644931227713, 0.31818644931227713, 0.3182883341823739, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 
0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.32674477840040755, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3268466632705043, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.31818644931227713, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.31808456444218036, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 
0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3182883341823739, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3272542027508915, 0.3544574630667346, 0.35455934793683136, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.35455934793683136, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3543555781966378, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3544574630667346, 0.3273560876209883, 0.3544574630667346], [0.324095771777891, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.327967396841569, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3268466632705043, 0.3273560876209883, 0.32766174223127864, 0.3272542027508915, 0.31818644931227713, 0.3182883341823739, 0.32755985736118187, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3544574630667346, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 
0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.31818644931227713, 0.3273560876209883, 0.3273560876209883]]
4 |
5 | # Convert to a NumPy array
6 | data = np.array(example_data)
7 |
8 | # Compute the mean for each epoch (averaged across clients)
9 | epoch_means = np.mean(data, axis=0)
10 | #print("Mean of each epoch:", epoch_means)
11 |
12 | # Take the mean over the last 50 rounds
13 | # This example takes the last 10 rounds because the example data only has 10 rounds; in actual use, take the last 50 rounds of your data
14 | last_50_epoch_means = epoch_means[-10:]  # replace -10 with -50 to match your data length
15 |
16 | # Compute the overall mean
17 | mean_last_50 = np.mean(last_50_epoch_means)
18 | print("Mean over the last 50 rounds:", mean_last_50)
19 |
20 | # Compute the confidence interval (95% confidence level)
21 | #confidence_level = 0.95
22 | #degrees_freedom = len(last_50_epoch_means) - 1
23 | #sample_mean = np.mean(last_50_epoch_means)
24 | #sample_standard_error = stats.sem(last_50_epoch_means)
25 | #confidence_interval = stats.t.interval(confidence_level, degrees_freedom, sample_mean, sample_standard_error)
26 |
27 | #print("最后 50 轮的平均值的95%置信区间:", confidence_interval)
--------------------------------------------------------------------------------