├── store └── test.txt ├── data └── FASHION_MNIST │ └── test.txt ├── matlab ├── training_result │ └── test.txt ├── DATA │ ├── channel_model_trial_50_K_50_N_4_PL_3.mat │ ├── channel_model_trial_50_K_50_N_64_PL_3.mat │ ├── channel_model_trial_50_K_100_N_1_PL_3_loc.mat │ └── channel_model_trial_50_K_100_N_1_PL_3_single.mat ├── Setup_Init.m ├── Single.m ├── single_relay_channel.m ├── single_relay_channel_loc.m ├── plot_figure.m ├── cell_channel_model.m ├── main_cmp.m ├── plot_Pr.m ├── Xu.m └── AM.m ├── LICENSE ├── Nets.py ├── plot_result.py ├── AirComp.py ├── plot_Pr.py ├── learning_flow.py ├── train_script.py ├── README.md └── main.py /store/test.txt: -------------------------------------------------------------------------------- 1 | test -------------------------------------------------------------------------------- /data/FASHION_MNIST/test.txt: -------------------------------------------------------------------------------- 1 | test -------------------------------------------------------------------------------- /matlab/training_result/test.txt: -------------------------------------------------------------------------------- 1 | test -------------------------------------------------------------------------------- /matlab/DATA/channel_model_trial_50_K_50_N_4_PL_3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhlinup/Relay-FL/HEAD/matlab/DATA/channel_model_trial_50_K_50_N_4_PL_3.mat -------------------------------------------------------------------------------- /matlab/DATA/channel_model_trial_50_K_50_N_64_PL_3.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhlinup/Relay-FL/HEAD/matlab/DATA/channel_model_trial_50_K_50_N_64_PL_3.mat -------------------------------------------------------------------------------- /matlab/DATA/channel_model_trial_50_K_100_N_1_PL_3_loc.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhlinup/Relay-FL/HEAD/matlab/DATA/channel_model_trial_50_K_100_N_1_PL_3_loc.mat -------------------------------------------------------------------------------- /matlab/DATA/channel_model_trial_50_K_100_N_1_PL_3_single.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhlinup/Relay-FL/HEAD/matlab/DATA/channel_model_trial_50_K_100_N_1_PL_3_single.mat -------------------------------------------------------------------------------- /matlab/Setup_Init.m: -------------------------------------------------------------------------------- 1 | function setup = Setup_Init(K, N, h_k, f_n, g_kn, P_r) 2 | 3 | %-------------------------------------------------------------------------- 4 | %System model Parameters 5 | setup.K = K; 6 | setup.N = N; 7 | 8 | setup.D = ones(setup.K, 1) / setup.K; 9 | setup.rho = ones(setup.K, 1) / setup.K; 10 | 11 | setup.P_K = ones(setup.K, 1) * P_r; 12 | setup.P_N = ones(setup.N, 1) * 0.1; 13 | 14 | setup.P_0 = 0.1; 15 | setup.P_r = P_r; 16 | 17 | setup.SNR = 100; 18 | setup.sigma_0 = power(10, -setup.SNR / 10); 19 | setup.noise_N = ones(setup.N, 1) * setup.sigma_0; 20 | 21 | setup.sigma = power(10, -setup.SNR / 10); 22 | 23 | setup.J_max = 100; 24 | setup.threshold = 1e-4; 25 | 26 | setup.h_k = h_k; 27 | setup.f_n = f_n; 28 | setup.g_kn = g_kn; 29 | 30 | -------------------------------------------------------------------------------- /matlab/Single.m: -------------------------------------------------------------------------------- 1 | function [w, true_w, ave_mse, mse, MMSE, tx_scaling, rx_scaling_opt] = Single(setup, d, signal) 2 | 3 | g = signal; 4 | g_mean = mean(signal, 2); 5 | global_g_mean = setup.rho.' * g_mean; 6 | 7 | g_var = var(signal, 0, 2); 8 | global_g_var = setup.rho.' 
* g_var; 9 | 10 | var_mean_sqrt = sqrt(global_g_var); 11 | 12 | rx_scaling = 1 / sqrt(setup.P_0) * setup.rho ./ abs(setup.h_k); 13 | 14 | rx_scaling_opt = max(rx_scaling); 15 | 16 | tx_scaling = setup.rho ./ setup.h_k / rx_scaling_opt; 17 | 18 | noise_1 = (randn(1, d) + 1j * randn(1, d)) / sqrt(2) * sqrt(setup.sigma); 19 | 20 | x_signal = repmat(tx_scaling, 1, d) .* ((signal - global_g_mean) / var_mean_sqrt); 21 | y = setup.h_k.' * x_signal + noise_1; 22 | 23 | w = real(y * rx_scaling_opt * var_mean_sqrt + global_g_mean); 24 | true_w = setup.rho.' * signal; 25 | ave_mse = norm(true_w - w)^2 / norm(true_w)^2; 26 | 27 | mse = norm(true_w - w)^2 / d; 28 | 29 | rho_hat = tx_scaling .* setup.h_k * rx_scaling_opt; 30 | MMSE = norm(rho_hat - setup.rho)^2 + setup.sigma * rx_scaling_opt^2; 31 | 32 | aa = 1; -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zehong Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /matlab/single_relay_channel.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | K = 100; 5 | N = 1; 6 | 7 | f_c = 915 * 10^6; % carrier bandwidth 8 | G_A = 4.11; % antenna gain 9 | PL = 3; % path loss value 10 | light = 3 * 10^8; % speed of light 11 | 12 | trial = 50; 13 | 14 | d_AP = [0, 0]; 15 | 16 | channel_U = zeros(K, trial); 17 | channel_R = zeros(N, trial); 18 | channel_UR = zeros(K, N, trial); 19 | 20 | for i = 1 : trial 21 | 22 | dx_U = unifrnd(80, 120, K, 1); 23 | dy_U = unifrnd(-60, 60, K, 1); 24 | 25 | dx_R = 50; 26 | dy_R = 0; 27 | 28 | dis_U = sqrt((dx_U - d_AP(1)).^2 + (dy_U - d_AP(2)).^2); 29 | dis_R = sqrt((dx_R - d_AP(1)).^2 + (dy_R - d_AP(2)).^2); 30 | dis_UR = zeros(K, N); 31 | 32 | for k = 1 : K 33 | for n = 1 : N 34 | dis_UR(k, n) = sqrt((dx_U(k) - dx_R(n))^2 + (dy_U(k) - dy_R(n))^2); 35 | end 36 | end 37 | 38 | PL_U = G_A * (light ./ (4 * pi * f_c .* dis_U)).^PL; 39 | PL_R = G_A * (light ./ (4 * pi * f_c .* dis_R)).^PL; 40 | PL_UR = G_A * (light ./ (4 * pi * f_c .* dis_UR)).^PL; 41 | 42 | g_rayl_U = (randn(K, 1) + 1j * randn(K, 1)) / sqrt(2); %Rayleigh fading component for K 43 | g_rayl_R = (randn(N, 1) + 1j * randn(N, 1)) / sqrt(2); 44 | g_rayl_UR = (randn(K, N) + 1j * randn(K, N)) / sqrt(2); 45 | 46 | channel_U(:, i) = g_rayl_U .* sqrt(PL_U); 47 | channel_R(:, i) = g_rayl_R .* sqrt(PL_R); 48 | channel_UR(:, :, i) = g_rayl_UR .* sqrt(PL_UR); 49 | end 50 | 51 | filename=['DATA/channel_model_trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_PL_' num2str(PL) '_single.mat']; 52 | save(filename) 
-------------------------------------------------------------------------------- /Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | 6 | # import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, dim_in, dim_hidden, dim_out): 13 | super(MLP, self).__init__() 14 | self.layer_input = nn.Linear(dim_in, dim_hidden) 15 | self.bn = nn.BatchNorm1d(dim_hidden) 16 | self.dropout = nn.Dropout() 17 | self.relu = nn.ReLU() 18 | self.layer_hidden = nn.Linear(dim_hidden, dim_out) 19 | 20 | def forward(self, x): 21 | x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) 22 | x = self.layer_input(x) 23 | x = self.bn(x) 24 | x = self.relu(x) 25 | x = self.layer_hidden(x) 26 | return x 27 | 28 | 29 | class CNNMnist(nn.Module): 30 | def __init__(self, num_channels, num_classes, batch_norm=False): 31 | super(CNNMnist, self).__init__() 32 | self.conv1 = nn.Conv2d(num_channels, 10, kernel_size=5) 33 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 34 | if batch_norm: 35 | self.conv2_norm = nn.BatchNorm2d(20) 36 | else: 37 | self.conv2_norm = nn.Dropout2d() 38 | self.fc1 = nn.Linear(320, 50) 39 | self.fc2 = nn.Linear(50, num_classes) 40 | 41 | def forward(self, x): 42 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 43 | x = F.relu(F.max_pool2d(self.conv2_norm(self.conv2(x)), 2)) 44 | x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3]) 45 | x = F.relu(self.fc1(x)) 46 | x = F.dropout(x, training=self.training) 47 | x = self.fc2(x) 48 | return x 49 | -------------------------------------------------------------------------------- /matlab/single_relay_channel_loc.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | K = 100; 5 | N = 8; 6 | 7 | f_c = 915 * 10^6; % carrier bandwidth 8 | G_A = 4.11; % antenna gain 9 | PL = 3; % path loss 
value 10 | light = 3 * 10^8; % speed of light 11 | 12 | trial = 50; 13 | 14 | d_AP = [0, 0]; 15 | 16 | channel_U = zeros(K, trial); 17 | channel_R = zeros(N, trial); 18 | channel_UR = zeros(K, N, trial); 19 | 20 | for i = 1 : trial 21 | 22 | dx_U = unifrnd(80, 120, K, 1); 23 | dy_U = unifrnd(-60, 60, K, 1); 24 | 25 | dx_R = 10: 10: 80; 26 | dx_R = dx_R'; 27 | dy_R = zeros(N, 1); 28 | 29 | dis_U = sqrt((dx_U - d_AP(1)).^2 + (dy_U - d_AP(2)).^2); 30 | dis_R = sqrt((dx_R - d_AP(1)).^2 + (dy_R - d_AP(2)).^2); 31 | dis_UR = zeros(K, N); 32 | 33 | for k = 1 : K 34 | for n = 1 : N 35 | dis_UR(k, n) = sqrt((dx_U(k) - dx_R(n))^2 + (dy_U(k) - dy_R(n))^2); 36 | end 37 | end 38 | 39 | PL_U = G_A * (light ./ (4 * pi * f_c .* dis_U)).^PL; 40 | PL_R = G_A * (light ./ (4 * pi * f_c .* dis_R)).^PL; 41 | PL_UR = G_A * (light ./ (4 * pi * f_c .* dis_UR)).^PL; 42 | 43 | g_rayl_U = (randn(K, 1) + 1j * randn(K, 1)) / sqrt(2); %Rayleigh fading component for K 44 | g_rayl_R = (randn(N, 1) + 1j * randn(N, 1)) / sqrt(2); 45 | g_rayl_UR = (randn(K, N) + 1j * randn(K, N)) / sqrt(2); 46 | 47 | channel_U(:, i) = g_rayl_U .* sqrt(PL_U); 48 | channel_R(:, i) = g_rayl_R .* sqrt(PL_R); 49 | channel_UR(:, :, i) = g_rayl_UR .* sqrt(PL_UR); 50 | end 51 | 52 | filename=['DATA/channel_model_trial_' num2str(trial) '_K_' num2str(K) '_N_1_PL_' num2str(PL) '_loc.mat']; 53 | save(filename) -------------------------------------------------------------------------------- /matlab/plot_figure.m: -------------------------------------------------------------------------------- 1 | clear 2 | 3 | trial = 30; 4 | K = 20; 5 | N = 1; 6 | B = 0; 7 | E = 1; 8 | lr = 0.05; 9 | 10 | filename=['training_result/cmp_time_trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_B_' num2str(B) '_E_' num2str(E) '.mat']; 11 | 12 | load(filename); 13 | 14 | index1 = 0 : length(test_accuracy1) - 1; 15 | index2 = 0 : 2: length(test_accuracy1) - 1; 16 | 17 | linesize=1.5; 18 | MarkerSize=8; 19 | LineWidth=1.5; 20 | 21 | figure 22 | 
23 | hold on 24 | plot(index1, test_accuracy1, 'k--', 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerIndices', 1: 10: length(index1)); 25 | plot(index2, test_accuracy2, 'r-o', 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', 'r', 'MarkerIndices', 1: 100: length(index2)); 26 | plot(index1, test_accuracy3, '-^', 'Color', [0.4940 0.1840 0.5560], 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', [0.4940 0.1840 0.5560], 'MarkerIndices', 1: 100: length(index1)); 27 | plot(index2, test_accuracy5, '-p', 'Color', [0.4660 0.6740 0.1880], 'LineWidth', LineWidth, 'MarkerSize', 2 + MarkerSize, 'MarkerFaceColor', [0.4660 0.6740 0.1880], 'MarkerIndices', 1: 100: length(index2)); 28 | 29 | set(get(gca, 'Children'), 'linewidth', 1.5) 30 | set(gca, 'XTick', 0: 200: length(index1)) 31 | % set(gca, 'XLim', [K_set(1), K_set(end)]) 32 | set(gca, 'YTick', 0: 0.1: 0.9) 33 | axis([index1(1) index1(end) 0 0.9]) 34 | 35 | grid on 36 | box on 37 | hl = legend('Error-free channel', 'Proposed scheme', 'FL without relays [29]', 'Relay-assisted scheme in [22]'); 38 | set(hl,'Interpreter', 'latex', 'fontsize', 12, 'location', 'southeast') 39 | xlabel('Number of Transmission Blocks', 'Interpreter', 'latex', 'fontsize', 14); 40 | ylabel('Test Accuracy','Interpreter', 'latex', 'fontsize', 14); -------------------------------------------------------------------------------- /matlab/cell_channel_model.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | 4 | K = 50; 5 | N = 64; 6 | 7 | f_c = 915 * 10^6; % carrier bandwidth 8 | G_A = 4.11; % antenna gain 9 | PL = 3; % path loss value 10 | light = 3 * 10^8; % speed of light 11 | 12 | trial = 50; 13 | 14 | d_AP = [0, 0]; 15 | 16 | channel_U = zeros(K, trial); 17 | channel_R = zeros(N, trial); 18 | channel_UR = zeros(K, N, trial); 19 | 20 | r_min = 0; 21 | r_max = 120; 22 | 23 | A = 2 / (r_max * r_max - r_min * r_min); 24 | 25 | relay_r = ones(N, 1) * 50; 
26 | relay_theta = zeros(N, 1); 27 | for n = 1 : N 28 | relay_theta(n) = (n - 1) * 2 * pi / N; 29 | end 30 | 31 | ini = 1; 32 | relay_theta2 = []; 33 | for i = 1 : log2(64) + 1 34 | relay_theta2 = [relay_theta2, relay_theta(1 : 64 / ini : 64)]; 35 | ini = ini * 2; 36 | end 37 | 38 | dx_R = relay_r .* cos(relay_theta); 39 | dy_R = relay_r .* sin(relay_theta); 40 | 41 | rng(1) 42 | user_theta_set = unifrnd(0, 2 * pi, K, 1, trial); 43 | user_r_set = sqrt(2 .* unifrnd(0, 1, K, 1, trial) / A + r_min * r_min); 44 | 45 | dx_U = user_r_set .* cos(user_theta_set); 46 | dy_U = user_r_set .* sin(user_theta_set); 47 | 48 | for i = 1 : trial 49 | 50 | user_theta = user_theta_set(:, :, 1); 51 | user_r = user_r_set(:, :, 1); 52 | 53 | dx_U = user_r .* cos(user_theta); 54 | dy_U = user_r .* sin(user_theta); 55 | 56 | dis_U = sqrt((dx_U - d_AP(1)).^2 + (dy_U - d_AP(2)).^2); 57 | dis_R = sqrt((dx_R - d_AP(1)).^2 + (dy_R - d_AP(2)).^2); 58 | dis_UR = zeros(K, N); 59 | 60 | for k = 1 : K 61 | for n = 1 : N 62 | dis_UR(k, n) = sqrt((dx_U(k) - dx_R(n))^2 + (dy_U(k) - dy_R(n))^2); 63 | end 64 | end 65 | 66 | PL_U = G_A * (light ./ (4 * pi * f_c .* dis_U)).^PL; 67 | PL_R = G_A * (light ./ (4 * pi * f_c .* dis_R)).^PL; 68 | PL_UR = G_A * (light ./ (4 * pi * f_c .* dis_UR)).^PL; 69 | 70 | g_rayl_U = (randn(K, 1) + 1j * randn(K, 1)) / sqrt(2); %Rayleigh fading component for K 71 | g_rayl_R = (randn(64, 1) + 1j * randn(64, 1)) / sqrt(2); 72 | g_rayl_UR = (randn(K, 64) + 1j * randn(K, 64)) / sqrt(2); 73 | 74 | channel_U(:, i) = g_rayl_U .* sqrt(PL_U); 75 | channel_R(:, i) = g_rayl_R(1 : 64 / N : 64) .* sqrt(PL_R); 76 | channel_UR(:, :, i) = g_rayl_UR(:, 1 : 64 / N : 64) .* sqrt(PL_UR); 77 | end 78 | 79 | filename=['DATA/channel_model_trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_PL_' num2str(PL) '.mat']; 80 | save(filename) 81 | -------------------------------------------------------------------------------- /matlab/main_cmp.m: 
-------------------------------------------------------------------------------- 1 | clear 2 | clc 3 | tic 4 | 5 | load('DATA/channel_model_trial_50_K_100_N_1_PL_3_single'); % single-relay 6 | % load('DATA/channel_model_trial_50_K_100_N_1_PL_3_loc'); % single relay location 7 | % load('DATA/channel_model_trial_50_K_50_N_4_PL_3'); % single-cell network 8 | % load('DATA/channel_model_trial_50_K_50_N_64_PL_3'); % single-cell network with 64 relays 9 | 10 | K = 20; 11 | N = 1; 12 | trial = 50; 13 | PL = 3; 14 | 15 | V_set = [0.01 0.1 0.3 0.5 1.0]; 16 | 17 | V_length = length(V_set); 18 | 19 | Proposed_nmse = zeros(V_length, trial); 20 | Single_nmse = zeros(V_length, trial); 21 | Xu_nmse = zeros(V_length, trial); 22 | 23 | Proposed_mse = zeros(V_length, trial); 24 | Single_mse = zeros(V_length, trial); 25 | Xu_mse = zeros(V_length, trial); 26 | 27 | Proposed_mmse = zeros(V_length, trial); 28 | Single_mmse = zeros(V_length, trial); 29 | Xu_mmse = zeros(V_length, trial); 30 | 31 | Proposed_a_k1 = zeros(V_length, K, trial); 32 | Proposed_a_k2 = zeros(V_length, K, trial); 33 | Proposed_b_n = zeros(V_length, N, trial); 34 | Proposed_c_1 = zeros(V_length, trial); 35 | Proposed_c_2 = zeros(V_length, trial); 36 | 37 | Proposed_ite = zeros(V_length, trial); 38 | 39 | Single_a_k1 = zeros(V_length, K, trial); 40 | Single_c_1 = zeros(V_length, trial); 41 | 42 | Xu_a_k1 = zeros(V_length, K, trial); 43 | Xu_b_n = zeros(V_length, N, trial); 44 | Xu_eta = zeros(V_length, trial); 45 | 46 | d = 100000; 47 | signal = normrnd(0, 1, [K, d]); 48 | grad = mean(signal, 1); 49 | 50 | filename=['DATA/trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_PL_' num2str(PL) '_Pr_source.mat']; 51 | 52 | for V_idx = 1 : V_length 53 | 54 | P_r = V_set(V_idx); 55 | 56 | parfor iter = 1 : trial % parallel computing 57 | % for iter = 1 : trial 58 | fprintf('%d-th trial\n', iter); 59 | 60 | h_k = channel_U(1 : K, iter); 61 | f_n = channel_R(1 : N, iter); 62 | g_kn = channel_UR(1 : K, 1 : N, iter); 63 
| 64 | setup = Setup_Init(K, N, h_k, f_n, g_kn, P_r); 65 | 66 | tic; 67 | t_start = cputime; 68 | 69 | [w1, true_w1, Single_nmse(V_idx, iter), Single_mse(V_idx, iter), Single_mmse(V_idx, iter), Single_a_k1(V_idx, :, iter), Single_c_1(V_idx, iter)] = Single(setup, d, signal); 70 | [w3, true_w3, Xu_nmse(V_idx, iter), Xu_mse(V_idx, iter), Xu_mmse(V_idx, iter), Xu_a_k1(V_idx, :, iter), Xu_b_n(V_idx, :, iter), Xu_eta(V_idx, iter)] = Xu(setup, d, signal); 71 | [w4, true_w4, Proposed_ite(V_idx, iter), Proposed_nmse(V_idx, iter), Proposed_mse(V_idx, iter), Proposed_mmse(V_idx, iter), Proposed_a_k1(V_idx, :, iter), Proposed_a_k2(V_idx, :, iter), Proposed_b_n(V_idx, :, iter), Proposed_c_1(V_idx, iter), Proposed_c_2(V_idx, iter)] = AM(setup, d, signal); 72 | 73 | t_end = cputime; 74 | toc_end = toc; 75 | end 76 | save(filename); 77 | end -------------------------------------------------------------------------------- /plot_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import copy 5 | import scipy.io as sio 6 | 7 | if __name__ == '__main__': 8 | trial = 50 9 | K = 20 10 | N = 1 11 | SNR = 100 12 | B = 0 13 | E = 1 14 | lr = 0.05 15 | PL = 3.0 16 | P_r = 0.1 17 | iid = 1 18 | noniid_level = 2 19 | loc = 50 20 | 21 | kappa = 0.4 22 | 23 | filename = 'store/trial_{}_K_{}_N_{}_B_{}_E_{}_lr_{}_SNR_{}_PL_{}_Pr_{}.npz'.format(trial, K, N, B, E, lr, SNR, PL, 24 | P_r) 25 | 26 | print(filename) 27 | 28 | nmse = np.zeros(5) 29 | 30 | a = np.load(filename, allow_pickle=1) 31 | 32 | result_CNN_set = a['arr_1'] 33 | result_MSE_set = a['arr_2'] 34 | result_NMSE_set = a['arr_2'] 35 | 36 | nmse1 = a['arr_3'] 37 | nmse2 = a['arr_4'] 38 | nmse4 = a['arr_6'] 39 | 40 | for i in range(trial): 41 | if i == 0: 42 | res_CNN = copy.deepcopy(result_CNN_set[0]) 43 | else: 44 | for item in res_CNN.keys(): 45 | res_CNN[item] += copy.deepcopy(result_CNN_set[i][item]) 46 | 47 | 
for item in res_CNN.keys(): 48 | res_CNN[item] = copy.deepcopy(res_CNN[item] / trial) 49 | 50 | test_accuracy1 = res_CNN['accuracy_test1'] 51 | test_accuracy2 = res_CNN['accuracy_test2'] 52 | test_accuracy3 = res_CNN['accuracy_test3'] 53 | test_accuracy5 = res_CNN['accuracy_test5'] 54 | 55 | nmse[1] = 10 * np.log10(np.mean(nmse1[~np.isnan(nmse1)])) 56 | nmse[2] = 10 * np.log10(np.mean(nmse2[~np.isnan(nmse2)])) 57 | nmse[4] = 10 * np.log10(np.mean(nmse4[~np.isnan(nmse4)])) 58 | 59 | matfile = 'matlab/training_result/cmp_time_trial_{}_K_{}_N_{}_B_{}_E_{}.mat'.format(trial, K, N, B, E) 60 | sio.savemat(matfile, mdict={'test_accuracy1': test_accuracy1[0: 1001], 'test_accuracy2': test_accuracy2[0: 501], 61 | 'test_accuracy3': test_accuracy3[0: 1001], 'test_accuracy5': test_accuracy5[0: 501]}) 62 | matfile2 = 'matlab/training_result/cmp_time_trial_{}_K_{}_N_{}_B_{}_E_{}_NMSE.mat'.format(trial, K, N, B, E) 63 | sio.savemat(matfile2, mdict={'nmse': nmse}) 64 | 65 | plt.plot(np.arange(0, len(test_accuracy1)), test_accuracy1, 'k--', label=r'Error-Free Channel') 66 | plt.plot(np.arange(0, 2 * len(test_accuracy2), 2), test_accuracy2, '-o', markersize=6, markevery=100, 67 | label=r'Proposed Scheme') 68 | plt.plot(np.arange(0, len(test_accuracy3)), test_accuracy3, '-*', markersize=8, markevery=100, 69 | label=r'Conventional') 70 | plt.plot(np.arange(0, 2 * len(test_accuracy5), 2), test_accuracy5, '->', markersize=6, markevery=100, 71 | label=r'Existing Scheme') 72 | 73 | plt.legend() 74 | plt.xlim([0, 1000]) 75 | plt.ylim([0, 0.9]) 76 | plt.xlabel('Transmission Time Slot') 77 | plt.ylabel('Test Accuracy') 78 | plt.grid() 79 | plt.show() 80 | -------------------------------------------------------------------------------- /matlab/plot_Pr.m: -------------------------------------------------------------------------------- 1 | clear 2 | 3 | trial = 28; 4 | K = 20; 5 | N = 1; 6 | B = 0; 7 | E = 1; 8 | 9 | Pr_set = [0.01, 0.1, 0.3, 0.5, 1]; 10 | 11 | 12 | 
filename=['training_result/cmp_Pr_trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_B_' num2str(B) '_E_' num2str(E) '.mat']; 13 | 14 | load(filename); 15 | 16 | filename=['training_result/cmp_Pr_trial_' num2str(trial) '_K_' num2str(K) '_N_' num2str(N) '_B_' num2str(B) '_E_' num2str(E) '_NMSE.mat']; 17 | 18 | load(filename); 19 | 20 | linesize=1.5; 21 | MarkerSize=8; 22 | LineWidth=1.5; 23 | 24 | figure 25 | 26 | hold on 27 | plot(Pr_set, test_accuracy(1, 1 : end), 'k--', 'LineWidth', LineWidth, 'MarkerSize', MarkerSize); 28 | plot(Pr_set, test_accuracy(2, 1 : end), 'r-o', 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', 'r'); 29 | plot(Pr_set, ones(length(Pr_set), 1) * mean(test_accuracy(3, 1: end)), '-^', 'Color', [0.4940 0.1840 0.5560], 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', [0.4940 0.1840 0.5560]); 30 | plot(Pr_set, test_accuracy(5, 1 : end), '-p', 'Color', [0.4660 0.6740 0.1880], 'LineWidth', LineWidth, 'MarkerSize', 2 + MarkerSize, 'MarkerFaceColor', [0.4660 0.6740 0.1880]); 31 | 32 | set(get(gca, 'Children'), 'linewidth', 1.5) 33 | set(gca, 'XTick', [0.01, 0.1, 0.3, 0.5, 1]) 34 | % xticklabels({'0.01', '0.1', '0.3', '0.5', '1'}) 35 | % set(gca, 'XLim', [Pr_set(1), Pr_set(end)]) 36 | set(gca, 'YTick', 0: 0.1: 0.9) 37 | axis([0 Pr_set(end) 0 0.9]) 38 | axis([0.01 1 0.4 0.9]) 39 | 40 | grid on 41 | box on 42 | hl = legend('Error-free channel', 'Proposed scheme', 'Conventional scheme', 'Existing scheme [26]'); 43 | set(hl,'Interpreter', 'latex', 'fontsize', 12, 'location', 'southeast') 44 | xlabel('Maximum Relay Transmit Power $P_r$ (W)', 'Interpreter', 'latex', 'fontsize', 14); 45 | ylabel('Test Accuracy','Interpreter', 'latex', 'fontsize', 14); 46 | 47 | figure 48 | 49 | hold on 50 | plot(Pr_set, nmse(2, 1 : end), 'r-o', 'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', 'r'); 51 | plot(Pr_set, ones(length(Pr_set), 1) * mean(nmse(3, 1: end)), '-^', 'Color', [0.4940 0.1840 0.5560], 
'LineWidth', LineWidth, 'MarkerSize', MarkerSize, 'MarkerFaceColor', [0.4940 0.1840 0.5560]); 52 | plot(Pr_set, nmse(5, 1 : end), '-p', 'Color', [0.4660 0.6740 0.1880], 'LineWidth', LineWidth, 'MarkerSize', 2 + MarkerSize, 'MarkerFaceColor', [0.4660 0.6740 0.1880]); 53 | 54 | set(get(gca, 'Children'), 'linewidth', 1.5) 55 | set(gca, 'XTick', [0.01, 0.1, 0.3, 0.5, 1]) 56 | % xticklabels({'0.01', '0.1', '0.3', '0.5', '1'}) 57 | set(gca, 'XLim', [Pr_set(1), Pr_set(end)]) 58 | % set(gca, 'YTick', -10: 5: 10) 59 | % axis([0 Pr_set(end) 0 0.9]) 60 | axis([0.01 1 -10 10]) 61 | 62 | grid on 63 | box on 64 | hl = legend('Proposed scheme', 'Conventional scheme', 'Existing scheme [26]'); 65 | set(hl,'Interpreter', 'latex', 'fontsize', 12, 'location', 'southeast') 66 | xlabel('Maximum Relay Transmit Power $P_r$ (W)', 'Interpreter', 'latex', 'fontsize', 14); 67 | ylabel('Average NMSE (dB)','Interpreter', 'latex', 'fontsize', 14); -------------------------------------------------------------------------------- /AirComp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def AM(setup, d, signal, a_k1, a_k2, b_n, c_1, c_2): 5 | rho = setup.rho 6 | 7 | g_mean = np.mean(signal, axis=1) 8 | global_g_mean = rho.T @ g_mean 9 | 10 | g_var = np.var(signal, axis=1) 11 | global_g_var = rho.T @ g_var 12 | 13 | var_mean_sqrt = global_g_var ** 0.5 14 | 15 | noise_1 = (np.random.randn(d) + 1j * np.random.randn(d)) / np.sqrt(2) * np.sqrt(setup.sigma) 16 | noise_2 = (np.random.randn(d) + 1j * np.random.randn(d)) / np.sqrt(2) * np.sqrt(setup.sigma) 17 | noise_N = (np.random.randn(setup.N, d) + 1j * np.random.randn(setup.N, d)) / np.sqrt(2) * np.sqrt(setup.sigma) 18 | 19 | x_signal_1 = np.tile(a_k1, (d, 1)).T * (signal - np.tile(global_g_mean, (d, 1)).T) / var_mean_sqrt 20 | x_signal_2 = np.tile(a_k2, (d, 1)).T * (signal - np.tile(global_g_mean, (d, 1)).T) / var_mean_sqrt 21 | 22 | r_n = setup.g_kn.T @ x_signal_1 + noise_N 23 | 24 
| y_1 = setup.h_k.T @ x_signal_1 + noise_1 25 | y_2 = setup.h_k.T @ x_signal_2 + setup.f_n.T @ (np.tile(b_n, (d, 1)).T * r_n) + noise_2 26 | 27 | w = np.real((y_1 * c_1 + y_2 * c_2) * var_mean_sqrt + global_g_mean) 28 | true_w = rho.T @ signal 29 | avg_mse = np.linalg.norm((true_w - w)) ** 2 / np.linalg.norm(true_w) ** 2 30 | mse2 = np.linalg.norm((true_w - w)) ** 2 / d 31 | return w, true_w, avg_mse, mse2 32 | 33 | 34 | def Single(setup, d, signal, a_k1, c_1): 35 | rho = setup.rho 36 | 37 | g_mean = np.mean(signal, axis=1) 38 | global_g_mean = rho.T @ g_mean 39 | 40 | g_var = np.var(signal, axis=1) 41 | global_g_var = rho.T @ g_var 42 | 43 | var_mean_sqrt = global_g_var ** 0.5 44 | 45 | noise_1 = (np.random.randn(d) + 1j * np.random.randn(d)) / np.sqrt(2) * np.sqrt(setup.sigma) 46 | 47 | x_signal = np.tile(a_k1, (d, 1)).T * (signal - np.tile(global_g_mean, (d, 1)).T) / var_mean_sqrt 48 | 49 | y = setup.h_k.T @ x_signal + noise_1 50 | 51 | w = np.real(y * c_1 * var_mean_sqrt + global_g_mean) 52 | true_w = rho.T @ signal 53 | avg_mse = np.linalg.norm((true_w - w)) ** 2 / np.linalg.norm(true_w) ** 2 54 | mse2 = np.linalg.norm((true_w - w)) ** 2 / d 55 | return w, true_w, avg_mse, mse2 56 | 57 | 58 | def Xu(setup, d, signal, a_k1, b_n, c_2): 59 | rho = setup.rho 60 | 61 | g_mean = np.mean(signal, axis=1) 62 | global_g_mean = rho.T @ g_mean 63 | 64 | g_var = np.var(signal, axis=1) 65 | global_g_var = rho.T @ g_var 66 | 67 | var_mean_sqrt = global_g_var ** 0.5 68 | 69 | # noise_1 = (np.random.randn(d) + 1j * np.random.randn(d)) / np.sqrt(2) * np.sqrt(setup.sigma) 70 | noise_2 = (np.random.randn(d) + 1j * np.random.randn(d)) / np.sqrt(2) * np.sqrt(setup.sigma) 71 | noise_N = (np.random.randn(setup.N, d) + 1j * np.random.randn(setup.N, d)) / np.sqrt(2) * np.sqrt(setup.sigma) 72 | 73 | x_signal_1 = np.tile(a_k1, (d, 1)).T * (signal - np.tile(global_g_mean, (d, 1)).T) / var_mean_sqrt 74 | 75 | r_n = setup.g_kn.T @ x_signal_1 + noise_N 76 | 77 | y = setup.f_n.T @ 
(np.tile(b_n, (d, 1)).T * r_n) + noise_2 78 | 79 | w = np.real(y / c_2 * var_mean_sqrt + global_g_mean) 80 | true_w = rho.T @ signal 81 | avg_mse = np.linalg.norm((true_w - w)) ** 2 / np.linalg.norm(true_w) ** 2 82 | mse2 = np.linalg.norm((true_w - w)) ** 2 / d 83 | return w, true_w, avg_mse, mse2 84 | -------------------------------------------------------------------------------- /plot_Pr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import copy 6 | import scipy.io as sio 7 | 8 | if __name__ == '__main__': 9 | d = 21921 10 | trial = 50 11 | K = 20 12 | N = 1 13 | SNR = 100 14 | B = 0 15 | E = 1 16 | lr = 0.05 17 | PL = 3.0 18 | P_r = 0.1 19 | iid = 1 20 | noniid_level = 2 21 | Pr_set = [0.01, 0.1, 0.3, 0.5, 1] 22 | 23 | test_accuracy = np.zeros([5, len(Pr_set)]) 24 | training_loss = np.zeros([5, len(Pr_set)]) 25 | nmse = np.zeros([5, len(Pr_set)]) 26 | 27 | for i in range(len(Pr_set)): 28 | P_r = Pr_set[i] 29 | filename = 'store/trial_{}_K_{}_N_{}_B_{}_E_{}_lr_{}_SNR_{}_PL_{}_Pr_{}.npz'.format(trial, K, N, B, E, lr, SNR, 30 | PL, P_r) 31 | 32 | a = np.load(filename, allow_pickle=True) 33 | res = a['arr_1'] 34 | nmse1 = a['arr_3'] 35 | nmse2 = a['arr_4'] 36 | nmse4 = a['arr_6'] 37 | 38 | nmse[1, i] = 10 * np.log10(np.mean(nmse1[~np.isnan(nmse1)])) 39 | nmse[2, i] = 10 * np.log10(np.mean(nmse2[~np.isnan(nmse2)])) 40 | nmse[4, i] = 10 * np.log10(np.mean(nmse4[~np.isnan(nmse4)])) 41 | 42 | res_CNN = {} 43 | for iter in range(trial): 44 | if iter == 0: 45 | res_CNN = copy.deepcopy(res[0]) 46 | else: 47 | for item in res_CNN.keys(): 48 | res_CNN[item] += copy.deepcopy(res[iter][item]) 49 | 50 | for item in res_CNN.keys(): 51 | res_CNN[item] = copy.deepcopy(res_CNN[item] / trial) 52 | 53 | test_accuracy[0, i] = res_CNN['accuracy_test1'][1000] 54 | test_accuracy[1, i] = res_CNN['accuracy_test2'][500] 55 | test_accuracy[2, i] = 
res_CNN['accuracy_test3'][1000] 56 | test_accuracy[4, i] = res_CNN['accuracy_test5'][500] 57 | 58 | print(test_accuracy) 59 | print(training_loss) 60 | 61 | matfile = 'matlab/training_result/cmp_Pr_trial_{}_K_{}_N_{}_B_{}_E_{}.mat'.format(trial, K, N, B, E) 62 | sio.savemat(matfile, mdict={'test_accuracy': test_accuracy}) 63 | 64 | matfile2 = 'matlab/training_result/cmp_Pr_trial_{}_K_{}_N_{}_B_{}_E_{}_NMSE.mat'.format(trial, K, N, B, E) 65 | sio.savemat(matfile2, mdict={'nmse': nmse}) 66 | 67 | plt.plot(Pr_set, test_accuracy[0], 'r-') 68 | plt.plot(Pr_set, test_accuracy[1], 'b-') 69 | plt.plot(Pr_set, test_accuracy[2], 'g-') 70 | plt.plot(Pr_set, test_accuracy[4], 'y-') 71 | plt.legend(labels=['Error-Free', 'Proposed', 'Conventional', 'Existing Scheme'], loc='lower center', 72 | fontsize='x-large') 73 | plt.xlim([0.01, 1]) 74 | plt.xticks(Pr_set) 75 | # plt.ylim([0.2, 0.9]) 76 | # plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9]) 77 | plt.xlabel('Maximum relay transmit power $P_r$') 78 | plt.ylabel('Test Accuracy') 79 | plt.grid() 80 | 81 | plt.figure() 82 | 83 | plt.plot(Pr_set, nmse[1], 'b-') 84 | plt.plot(Pr_set, nmse[2], 'g-') 85 | plt.plot(Pr_set, nmse[4], 'y-') 86 | plt.legend(labels=['Proposed', 'Conventional', 'Existing Scheme'], loc='lower center', 87 | fontsize='x-large') 88 | plt.xlim([0.01, 1]) 89 | plt.xticks(Pr_set) 90 | # plt.ylim([0.2, 0.9]) 91 | # plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9]) 92 | plt.xlabel('Maximum relay transmit power $P_r$') 93 | plt.ylabel('Average NMSE') 94 | 95 | plt.show() 96 | -------------------------------------------------------------------------------- /learning_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import torch 4 | import train_script 5 | import AirComp 6 | 7 | 8 | def FedAvg_grad(w_glob, grad, device): 9 | ind = 0 10 | w_return = copy.deepcopy(w_glob) 11 | 12 | for item in w_return.keys(): 13 | a = np.array(w_return[item].size()) 14 | if 
len(a): 15 | b = np.prod(a) 16 | w_return[item] = copy.deepcopy(w_return[item]) + torch.from_numpy( 17 | np.reshape(grad[ind: ind + b], a)).float().to(device) 18 | ind = ind + b 19 | return w_return 20 | 21 | 22 | def learning_iter(setup, d, net_glob, w_glob, train_images, train_labels, test_images, test_labels, 23 | trans_mode, a_k1, a_k2, b_n, c_1, c_2): 24 | loss_train = [] 25 | mse_train = [] 26 | mse2_train = [] 27 | accuracy_test = [] 28 | loss_test_set = [] 29 | 30 | net_glob.eval() 31 | acc_test, loss_test = train_script.test_model(net_glob, setup, test_images, test_labels) 32 | accuracy_test.append(acc_test) 33 | 34 | net_glob.train() 35 | 36 | if trans_mode == 1 or trans_mode == 3: 37 | epochs = setup.epochs * 2 38 | else: 39 | epochs = setup.epochs 40 | 41 | setup.lr = setup.init_lr 42 | for iter in range(epochs): 43 | if iter > 1 and iter % setup.step == 0: 44 | setup.lr = max(setup.lr * setup.gamma, setup.low_lr) 45 | 46 | gradient_store_per_iter = np.zeros([setup.K, d]) 47 | 48 | loss_locals = [] 49 | ind = 0 50 | for idx in range(setup.K): 51 | if setup.local_bs == 0: 52 | size = int(setup.size[idx]) 53 | else: 54 | size = min(int(setup.size[idx]), setup.local_bs) 55 | 56 | w, loss, gradient = train_script.local_update(setup, d, copy.deepcopy(net_glob).to(setup.device), 57 | train_images, train_labels, idx, size) 58 | 59 | loss_locals.append(copy.deepcopy(loss)) 60 | 61 | copy_g = copy.deepcopy(w) 62 | copy_g[np.isnan(copy_g)] = 0 63 | 64 | gradient_store_per_iter[ind, :] = copy_g 65 | ind = ind + 1 66 | 67 | if trans_mode == 1: 68 | grad = np.average(copy.deepcopy(gradient_store_per_iter), axis=0, weights=setup.rho) 69 | mse = 0 70 | mse2 = 0 71 | 72 | elif trans_mode == 2: 73 | grad, _, mse, mse2 = AirComp.AM(setup, d, copy.deepcopy(gradient_store_per_iter), 74 | a_k1, a_k2, b_n, c_1, c_2) 75 | 76 | elif trans_mode == 3: 77 | grad, _, mse, mse2 = AirComp.Single(setup, d, copy.deepcopy(gradient_store_per_iter), a_k1, c_1) 78 | 79 | elif trans_mode 
== 5: 80 | grad, _, mse, mse2 = AirComp.Xu(setup, d, copy.deepcopy(gradient_store_per_iter), a_k1, b_n, c_2) 81 | 82 | # if setup.verbose: 83 | # print(10 * np.log10(mse)) 84 | # print(10 * np.log10(mse2)) 85 | # print(np.mean(np.abs(gradient_store_per_iter) ** 2)) 86 | # print(np.mean(np.abs(grad) ** 2)) 87 | 88 | w_glob = copy.deepcopy(FedAvg_grad(w_glob, grad, setup.device)) 89 | net_glob.load_state_dict(w_glob) 90 | # loss 91 | loss_avg = sum(loss_locals) / len(loss_locals) 92 | if setup.verbose: 93 | print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg)) 94 | 95 | loss_train.append(loss_avg) 96 | mse_train.append(mse) 97 | mse2_train.append(mse2) 98 | 99 | acc_test, loss_test = train_script.test_model(net_glob, setup, test_images, test_labels) 100 | accuracy_test.append(acc_test) 101 | loss_test_set.append(loss_test) 102 | net_glob.train() 103 | 104 | return loss_train, accuracy_test, loss_test_set, mse_train, mse2_train 105 | -------------------------------------------------------------------------------- /matlab/Xu.m: -------------------------------------------------------------------------------- 1 | function [w, true_w, ave_mse, mse, MMSE, a_k1, b_n, eta] = Xu(setup, d, signal) 2 | 3 | g_mean = mean(signal, 2); 4 | global_g_mean = setup.rho.' * g_mean; 5 | 6 | g_var = var(signal, 0, 2); 7 | global_g_var = setup.rho.' * g_var; 8 | 9 | var_mean_sqrt = sqrt(global_g_var); 10 | 11 | h_k = setup.h_k; 12 | f_n = setup.f_n; 13 | g_kn = setup.g_kn; 14 | 15 | rho = setup.rho; 16 | P_0 = setup.P_0; 17 | P_r = setup.P_r; 18 | sigma = setup.sigma; 19 | 20 | K = setup.K; 21 | N = setup.N; 22 | 23 | rx_scaling = sqrt(P_0) * abs(h_k) ./ rho; 24 | 25 | eta = min(rx_scaling); 26 | 27 | a_k1 = eta * rho ./ h_k; 28 | 29 | b_n = zeros(N, 1); 30 | for n = 1 : N 31 | b_n(n) = sqrt(P_r / (transpose(abs(g_kn(:, n)).^2) * abs(a_k1).^2 + sigma)); 32 | end 33 | 34 | obj = norm(g_kn * (f_n .* b_n) .* a_k1 / eta - rho)^2 ... 
35 | + (1 + transpose(abs(f_n).^2) * abs(b_n).^2) * sigma / abs(eta)^2; 36 | 37 | obj_pre = 1e6; 38 | threshold = setup.threshold; 39 | 40 | scaling_factor = 1e2; 41 | scaling_factor2 = 1e-2; 42 | 43 | obj_vec = [obj]; 44 | 45 | while (obj_pre - obj) / obj_pre > threshold 46 | 47 | obj_pre = obj; 48 | 49 | a_angle = - angle(g_kn * (f_n .* b_n)); 50 | 51 | hat_Pr = P_r ./ (abs(b_n).^2) - sigma; 52 | 53 | cvx_begin quiet 54 | variable a_vec1(K); 55 | minimize(square_pos(norm(scaling_factor * abs(g_kn * (f_n .* b_n)) .* a_vec1 / eta - scaling_factor * rho))); 56 | subject to 57 | for k = 1 : K 58 | a_vec1(k) <= sqrt(P_0); 59 | end 60 | for n = 1 : N 61 | transpose(scaling_factor * abs(g_kn(:, n)).^2) * power(a_vec1, 2) <= scaling_factor * hat_Pr(n); 62 | end 63 | cvx_end 64 | cvx_status; 65 | 66 | if strcmp(cvx_status, 'Failed') ~= 1 67 | a_k1 = a_vec1 .* exp(1j * a_angle); 68 | else 69 | cvx_begin quiet 70 | variable a_vec1(K); 71 | minimize(square_pos(norm(abs(g_kn * (f_n .* b_n)) .* a_vec1 / eta - rho))); 72 | subject to 73 | for k = 1 : K 74 | a_vec1(k) <= sqrt(P_0); 75 | end 76 | for n = 1 : N 77 | transpose(abs(g_kn(:, n)).^2) * power(a_vec1, 2) <= hat_Pr(n); 78 | end 79 | cvx_end 80 | cvx_status 81 | if strcmp(cvx_status, 'Failed') ~= 1 82 | a_k1 = a_vec1 .* exp(1j * a_angle); 83 | else 84 | cvx_begin quiet 85 | variable a_vec1(K); 86 | minimize(square_pos(norm(scaling_factor2 * abs(g_kn * (f_n .* b_n)) .* a_vec1 / eta - scaling_factor2 * rho))); 87 | subject to 88 | for k = 1 : K 89 | a_vec1(k) <= sqrt(P_0); 90 | end 91 | for n = 1 : N 92 | transpose(scaling_factor2 * abs(g_kn(:, n)).^2) * power(a_vec1, 2) <= scaling_factor2 * hat_Pr(n); 93 | end 94 | cvx_end 95 | cvx_status 96 | if strcmp(cvx_status, 'Failed') ~= 1 97 | a_k1 = a_vec1 .* exp(1j * a_angle); 98 | end 99 | end 100 | end 101 | 102 | bar_Pr = zeros(N, 1); 103 | a1_g = zeros(K, N); 104 | for k = 1 : K 105 | a1_g(k, :) = a_k1(k) * transpose(g_kn(k, :)); 106 | end 107 | tmp2 = 0; 108 | 109 | 110 | 
tmp1 = (rho .* a_k1).' * g_kn; 111 | for k = 1 : K 112 | tmp2 = tmp2 + abs(a_k1(k))^2 * g_kn(k, :)' * g_kn(k, :); 113 | end 114 | 115 | b_vec = eta * ((scaling_factor^4 * (tmp2 + sigma * eye(N)) * diag(f_n)) \ (scaling_factor^4 * tmp1)'); 116 | 117 | for n = 1 : N 118 | tmp3 = 0; 119 | 120 | for k = 1 : K 121 | tmp3 = tmp3 + abs(g_kn(k, n))^2 * abs(a_k1(k))^2; 122 | end 123 | hat_bn = b_vec(n); 124 | bar_Pr(n) = P_r / (tmp3 + sigma); 125 | 126 | if abs(hat_bn) >= sqrt(bar_Pr(n)) 127 | b_n(n) = sqrt(bar_Pr(n)) * hat_bn / norm(hat_bn); 128 | else 129 | b_n(n) = hat_bn; 130 | end 131 | end 132 | 133 | eta = (b_n' * diag(f_n)' * (tmp2 + sigma * eye(N)) * diag(f_n) * b_n + sigma) / (transpose(rho .* a_k1) * g_kn * diag(f_n) * b_n); 134 | 135 | obj = norm(g_kn * (f_n .* b_n) .* a_k1 / eta - rho)^2 ... 136 | + (1 + transpose(abs(f_n).^2) * abs(b_n).^2) * sigma / abs(eta)^2; 137 | 138 | obj_vec = [obj_vec, obj]; 139 | end 140 | 141 | noise_2 = (randn(1, d) + 1j * randn(1, d)) / sqrt(2) * sqrt(sigma); 142 | noise_N = (randn(N, d) + 1j * randn(N, d)) / sqrt(2) * sqrt(sigma); 143 | 144 | x_signal_1 = repmat(a_k1, 1, d) .* ((signal - global_g_mean) / var_mean_sqrt); 145 | 146 | r_n= transpose(g_kn) * x_signal_1 + noise_N; 147 | 148 | y = setup.f_n.' * (b_n .* r_n) + noise_2; 149 | 150 | w = real(y / eta * var_mean_sqrt + global_g_mean); 151 | true_w = rho.' * signal; 152 | ave_mse = norm(true_w - w)^2 / norm(true_w)^2; 153 | 154 | mse = norm(true_w - w)^2 / d; 155 | 156 | rho_hat = g_kn * (f_n .* b_n) .* a_k1 / eta; 157 | MMSE = norm(g_kn * (f_n .* b_n) .* a_k1 / eta - rho)^2 ... 
158 | + (1 + transpose(abs(f_n).^2) * abs(b_n).^2) * sigma / abs(eta)^2; -------------------------------------------------------------------------------- /matlab/AM.m: -------------------------------------------------------------------------------- 1 | function [w, true_w, ite, ave_mse, mse, MMSE, a_k1, a_k2, b_n, c_1, c_2] = AM(setup, d, signal) 2 | 3 | g_mean = mean(signal, 2); 4 | global_g_mean = setup.rho.' * g_mean; 5 | 6 | g_var = var(signal, 0, 2); 7 | global_g_var = setup.rho.' * g_var; 8 | 9 | var_mean_sqrt = sqrt(global_g_var); 10 | 11 | h_k = setup.h_k; 12 | f_n = setup.f_n; 13 | g_kn = setup.g_kn; 14 | 15 | rho = setup.rho; 16 | P_0 = setup.P_0; 17 | P_r = setup.P_r; 18 | sigma = setup.sigma; 19 | 20 | K = setup.K; 21 | N = setup.N; 22 | 23 | rx_scaling = sqrt(2) / sqrt(P_0) * rho ./ abs(h_k) / 2; 24 | 25 | c_1 = max(rx_scaling); 26 | c_2 = c_1; 27 | 28 | a_k1 = rho ./ h_k / c_1 / 2; 29 | a_k2 = a_k1; 30 | 31 | b_n = zeros(N, 1); 32 | for n = 1 : N 33 | b_n(n) = sqrt(P_r / (transpose(abs(g_kn(:, n)).^2) * abs(a_k1).^2 + sigma)); 34 | end 35 | 36 | ga_m = abs(g_kn).^2; 37 | 38 | theta = c_1 * h_k + c_2 * g_kn * (f_n .* b_n); 39 | phi = c_2 * h_k; 40 | 41 | rho_hat = theta .* a_k1 + phi .* a_k2; 42 | obj = norm(theta .* a_k1 + phi .* a_k2 - rho)^2 ... 43 | + (abs(c_1)^2 + abs(c_2)^2 * transpose(abs(f_n).^2) * abs(b_n).^2 ... 
44 | + abs(c_2)^2) * sigma; 45 | 46 | J_max = setup.J_max; 47 | threshold = setup.threshold; 48 | obj_vec = [obj]; 49 | 50 | scaling_factor = 1e2; 51 | scaling_factor2 = 1e-2; 52 | 53 | for j = 1 : J_max 54 | obj_pre = obj; 55 | 56 | hat_Pr = P_r ./ abs(b_n).^2 - sigma; 57 | 58 | theta = c_1 * h_k + c_2 * g_kn * (f_n .* b_n); 59 | phi = c_2 * h_k; 60 | 61 | cvx_begin quiet 62 | variable a_vec1(K) complex; 63 | variable a_vec2(K) complex; 64 | minimize(square_pos(norm(scaling_factor * theta .* a_vec1 + scaling_factor * phi .* a_vec2 - scaling_factor * rho))); 65 | subject to 66 | for k = 1 : K 67 | pow_abs(a_vec1(k), 2) <= P_0 / 2; 68 | pow_abs(a_vec2(k), 2) <= P_0 / 2; 69 | end 70 | for n = 1 : N 71 | transpose(scaling_factor * abs(g_kn(:, n)).^2) * pow_abs(a_vec1, 2) <= scaling_factor * hat_Pr(n); 72 | end 73 | cvx_end 74 | cvx_status; 75 | 76 | if strcmp(cvx_status, 'Failed') ~= 1 77 | a_k1 = a_vec1; 78 | a_k2 = a_vec2; 79 | else 80 | cvx_begin quiet 81 | variable a_vec1(K) complex; 82 | variable a_vec2(K) complex; 83 | minimize(square_pos(norm(theta .* a_vec1 + phi .* a_vec2 - rho))); 84 | subject to 85 | for k = 1 : K 86 | pow_abs(a_vec1(k), 2) <= P_0 / 2; 87 | pow_abs(a_vec2(k), 2) <= P_0 / 2; 88 | end 89 | for n = 1 : N 90 | transpose(abs(g_kn(:, n)).^2) * pow_abs(a_vec1, 2) <= hat_Pr(n); 91 | end 92 | cvx_end 93 | cvx_status 94 | if strcmp(cvx_status, 'Failed') ~= 1 95 | a_k1 = a_vec1; 96 | a_k2 = a_vec2; 97 | else 98 | cvx_begin quiet 99 | variable a_vec1(K) complex; 100 | variable a_vec2(K) complex; 101 | minimize(square_pos(norm(scaling_factor2 * theta .* a_vec1 + scaling_factor2 * phi .* a_vec2 - scaling_factor2 * rho))); 102 | subject to 103 | for k = 1 : K 104 | pow_abs(a_vec1(k), 2) <= P_0 / 2; 105 | pow_abs(a_vec2(k), 2) <= P_0 / 2; 106 | end 107 | for n = 1 : N 108 | transpose(scaling_factor2 * abs(g_kn(:, n)).^2) * pow_abs(a_vec1, 2) <= scaling_factor2 * hat_Pr(n); 109 | end 110 | cvx_end 111 | cvx_status 112 | if strcmp(cvx_status, 'Failed') ~= 1 
113 | a_k1 = a_vec1; 114 | a_k2 = a_vec2; 115 | end 116 | end 117 | end 118 | 119 | cons_cm = rho - c_1 * h_k .* a_k1 - c_2 * h_k .* a_k2; 120 | bar_Pr = zeros(N, 1); 121 | a1_g = zeros(K, N); 122 | for k = 1 : K 123 | a1_g(k, :) = a_k1(k) * transpose(g_kn(k, :)); 124 | end 125 | tmp1 = 0; 126 | tmp2 = 0; 127 | for k = 1 : K 128 | tmp1 = tmp1 + cons_cm(k) * ctranspose(a1_g(k, :)); 129 | tmp2 = tmp2 + abs(a_k1(k))^2 * (g_kn(k, :)' * g_kn(k, :)); 130 | end 131 | 132 | b_vec = ((tmp2 + sigma * eye(N)) * diag(f_n)* c_2) \ (tmp1); 133 | 134 | for n = 1 : N 135 | tmp3 = 0; 136 | 137 | for k = 1 : K 138 | tmp3 = tmp3 + abs(g_kn(k, n))^2 * abs(a_k1(k))^2; 139 | end 140 | hat_bn = b_vec(n); 141 | bar_Pr(n) = P_r / (tmp3 + sigma); 142 | 143 | if abs(hat_bn) > sqrt(bar_Pr(n)) 144 | b_n(n) = sqrt(bar_Pr(n)) * hat_bn / norm(hat_bn); 145 | else 146 | b_n(n) = hat_bn; 147 | end 148 | end 149 | 150 | 151 | tmp1 = 0; 152 | tmp2 = 0; 153 | tmp3 = g_kn * (f_n .* b_n) .* a_k1; 154 | 155 | for k = 1 : K 156 | tmp1 = tmp1 + (rho(k) - c_2 * (tmp3(k) + h_k(k) * a_k2(k))) * conj(h_k(k) * a_k1(k)); 157 | tmp2 = tmp2 + abs(h_k(k))^2 * abs(a_k1(k))^2; 158 | end 159 | 160 | c_1 = tmp1 / (tmp2 + sigma); 161 | 162 | 163 | tmp1 = 0; 164 | tmp2 = 0; 165 | tmp4 = sigma * transpose(abs(f_n).^2) * abs(b_n).^2; 166 | 167 | for k = 1 : K 168 | tmp1 = tmp1 + conj(tmp3(k) + h_k(k) * a_k2(k)) * (rho(k) - c_1 * h_k(k) * a_k1(k)); 169 | tmp2 = tmp2 + abs(tmp3(k) + h_k(k) * a_k2(k))^2; 170 | end 171 | 172 | c_2 = tmp1 / (tmp2 + tmp4 + sigma); 173 | 174 | 175 | theta = c_1 * h_k + c_2 * g_kn * (f_n .* b_n); 176 | phi = c_2 * h_k; 177 | 178 | obj = norm(theta .* a_k1 + phi .* a_k2 - rho)^2 ... 179 | + (abs(c_1)^2 + abs(c_2)^2 * transpose(abs(f_n).^2) * abs(b_n).^2 ... 
180 | + abs(c_2)^2) * sigma; 181 | 182 | 183 | if abs(obj - obj_pre) / abs(obj) <= threshold 184 | break 185 | end 186 | 187 | obj_vec = [obj_vec, obj]; 188 | end 189 | 190 | ite = j; 191 | 192 | noise_1 = (randn(1, d) + 1j * randn(1, d)) / sqrt(2) * sqrt(sigma); 193 | noise_2 = (randn(1, d) + 1j * randn(1, d)) / sqrt(2) * sqrt(sigma); 194 | 195 | noise_N = (randn(N, d) + 1j * randn(N, d)) / sqrt(2) * sqrt(sigma); 196 | 197 | x_signal_1 = repmat(a_k1, 1, d) .* ((signal - global_g_mean) / var_mean_sqrt); 198 | x_signal_2 = repmat(a_k2, 1, d) .* ((signal - global_g_mean) / var_mean_sqrt); 199 | 200 | r_n= transpose(g_kn) * x_signal_1 + noise_N; 201 | 202 | y_1 = setup.h_k.' * x_signal_1 + noise_1; 203 | y_2 = setup.h_k.' * x_signal_2 + setup.f_n.' * (b_n .* r_n) + noise_2; 204 | 205 | w = real((y_1 * c_1 + y_2 * c_2) * var_mean_sqrt + global_g_mean); 206 | true_w = rho.' * signal; 207 | ave_mse = norm(true_w - w)^2 / norm(true_w)^2; 208 | 209 | mse = norm(true_w - w)^2 / d; 210 | 211 | theta = c_1 * h_k + c_2 * g_kn * (f_n .* b_n); 212 | phi = c_2 * h_k; 213 | 214 | rho_hat = theta .* a_k1 + phi .* a_k2; 215 | MMSE = norm(rho_hat - rho)^2 ... 216 | + (abs(c_1)^2 + abs(c_2)^2 * transpose(abs(f_n).^2) * abs(b_n).^2 ... 
217 | + abs(c_2)^2) * sigma; 218 | 219 | 220 | -------------------------------------------------------------------------------- /train_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | np.set_printoptions(precision=6, threshold=1e3) 6 | import torch 7 | 8 | from torchvision import datasets, transforms 9 | import copy 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader 12 | 13 | 14 | def mnist_iid(dataset, K, M): 15 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 16 | 17 | for i in range(M): 18 | dict_users[i] = set(np.random.choice(all_idxs, int(K[i]), replace=False)) 19 | all_idxs = list(set(all_idxs) - dict_users[i]) 20 | return dict_users 21 | 22 | 23 | def load_fmnist_iid(K): 24 | transform = transforms.Compose([transforms.ToTensor(), 25 | transforms.Normalize((0.1307,), (0.3081,)) 26 | ]) 27 | dataset_train = datasets.FashionMNIST('./data/FASHION_MNIST/', download=True, train=True, transform=transform) 28 | dataset_test = datasets.FashionMNIST('./data/FASHION_MNIST/', download=True, train=False, transform=transform) 29 | 30 | loader = DataLoader(dataset_train, batch_size=len(dataset_train), shuffle=False) 31 | images, labels = next(enumerate(loader))[1] 32 | images, labels = images.numpy(), labels.numpy() 33 | D_k = int(len(labels) / K) 34 | 35 | train_images = [] 36 | train_labels = [] 37 | dict_users = {i: np.array([], dtype='int64') for i in range(K)} 38 | all_idxs = np.arange(len(labels)) 39 | 40 | D = np.zeros(K) 41 | for i in range(K): 42 | dict_users[i] = set(np.random.choice(all_idxs, int(D_k), replace=False)) 43 | all_idxs = list(set(all_idxs) - dict_users[i]) 44 | train_images.append(images[list(dict_users[i])]) 45 | train_labels.append(labels[list(dict_users[i])]) 46 | D[i] = len(dict_users[i]) 47 | 48 | test_loader = DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=True) 49 | test_images, 
test_labels = next(enumerate(test_loader))[1] 50 | 51 | return train_images, train_labels, test_images.numpy(), test_labels.numpy(), D 52 | 53 | 54 | def load_fmnist_noniid(K, NUM_SHARDS): 55 | transform = transforms.Compose([transforms.ToTensor(), 56 | transforms.Normalize((0.1307,), (0.3081,)) 57 | ]) 58 | dataset_train = datasets.FashionMNIST('./data/FASHION_MNIST/', download=True, train=True, transform=transform) 59 | dataset_test = datasets.FashionMNIST('./data/FASHION_MNIST/', download=True, train=False, transform=transform) 60 | 61 | loader = DataLoader(dataset_train, batch_size=len(dataset_train), shuffle=False) 62 | images, labels = next(enumerate(loader))[1] 63 | images, labels = images.numpy(), labels.numpy() 64 | 65 | train_images = [] 66 | train_labels = [] 67 | 68 | # PART = 10 69 | PART = 1 70 | 71 | num_shards = K * NUM_SHARDS * PART 72 | num_imgs = int(len(images) / num_shards) 73 | idx_shard = [i for i in range(num_shards)] 74 | dict_users = {i: np.array([], dtype='int64') for i in range(K)} 75 | all_idxs = np.arange(len(labels)) 76 | 77 | # sort labels 78 | idxs_labels = np.vstack((all_idxs, labels)) 79 | idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()] 80 | all_idxs = idxs_labels[0, :] 81 | 82 | idx_shard = idx_shard[::PART] 83 | 84 | D = np.zeros(K) 85 | for i in range(K): 86 | rand_set = set(np.random.choice(idx_shard, NUM_SHARDS, replace=False)) 87 | idx_shard = list(set(idx_shard) - rand_set) 88 | for rand in rand_set: 89 | dict_users[i] = np.concatenate((dict_users[i], all_idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0) 90 | train_images.append(images[dict_users[i]]) 91 | train_labels.append(labels[dict_users[i]]) 92 | D[i] = len(dict_users[i]) 93 | 94 | test_loader = DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=True) 95 | test_images, test_labels = next(enumerate(test_loader))[1] 96 | 97 | return train_images, train_labels, test_images.numpy(), test_labels.numpy(), D 98 | 99 | 100 | def local_update(setup, 
d, model1, train_images, train_labels, idx, batch_size): 101 | initital_weight = copy.deepcopy(model1.state_dict()) 102 | 103 | model = copy.deepcopy(model1) 104 | model.train() 105 | 106 | loss_function = nn.CrossEntropyLoss() 107 | 108 | optimizer = torch.optim.SGD(model.parameters(), lr=setup.lr, momentum=setup.momentum) 109 | 110 | # optimizer = torch.optim.Adam(model.parameters(), lr=setup.lr) 111 | 112 | epoch_loss = [] 113 | images = np.array_split(train_images[idx], len(train_images[idx]) // batch_size) 114 | labels = np.array_split(train_labels[idx], len(train_labels[idx]) // batch_size) 115 | 116 | for epoch in range(setup.local_ep): 117 | batch_loss = [] 118 | for b_idx in range(len(images)): 119 | model.zero_grad() 120 | 121 | log_probs = model(torch.tensor(images[b_idx].copy(), device=setup.device)) 122 | local_loss = loss_function(log_probs, torch.tensor(labels[b_idx].copy(), device=setup.device)) 123 | 124 | local_loss.backward() 125 | optimizer.step() 126 | if setup.verbose == 2: 127 | print('User: {}, Epoch: {}, Batch No: {}/{} Loss: {:.6f}'.format(idx, 128 | epoch, b_idx + 1, len(images), 129 | local_loss.item())) 130 | batch_loss.append(local_loss.item()) 131 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 132 | 133 | copyw = copy.deepcopy(model.state_dict()) 134 | gradient2 = np.array([[]]) 135 | w2 = np.array([[]]) 136 | for item in copyw.keys(): 137 | gradient2 = np.hstack((gradient2, np.reshape((initital_weight[item] - copyw[item]).cpu().numpy(), 138 | [1, -1]) / setup.lr)) 139 | 140 | w2 = np.hstack((w2, np.reshape((copyw[item] - initital_weight[item]).cpu().numpy(), 141 | [1, -1]))) 142 | 143 | return w2, sum(epoch_loss) / len(epoch_loss), gradient2 144 | 145 | 146 | def test_model(model, setup, test_images, test_labels): 147 | model.eval() 148 | loss, total, correct = 0.0, 0.0, 0.0 149 | 150 | images = torch.tensor(test_images).to(setup.device) 151 | labels = torch.tensor(test_labels).to(setup.device) 152 | outputs = 
model(images).to(setup.device) 153 | loss_function = nn.CrossEntropyLoss() 154 | batch_loss = loss_function(outputs, labels) 155 | loss += batch_loss.item() 156 | _, pred_labels = torch.max(outputs, 1) 157 | pred_labels = pred_labels.view(-1) 158 | 159 | correct += torch.sum(torch.eq(pred_labels, labels)).item() 160 | total += len(labels) 161 | accuracy = correct / total 162 | 163 | if setup.verbose: 164 | print('Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format( 165 | loss, int(correct), int(total), 100.0 * accuracy)) 166 | return accuracy, loss 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Relay-FL 2 | This is the simulation code package for the following paper: 3 | 4 | Zehong Lin, Hang Liu, and Ying-Jun Angela Zhang, “Relay-Assisted Cooperative Federated Learning,” IEEE Transactions on Wireless Communications, DOI: 10.1109/TWC.2022.3155596. [[ArXiv Version](https://arxiv.org/abs/2107.09518)] 5 | 6 | The package, written on Python 3 and Matlab, reproduces the numerical results of the proposed algorithm in the above paper. 7 | 8 | 9 | ## Abstract of Article: 10 | 11 | > Federated learning (FL) has recently emerged as a promising technology to enable artificial intelligence (AI) at the network edge, where distributed mobile devices collaboratively train a shared AI model under the coordination of an edge server. To significantly improve the communication efficiency of FL, over-the-air computation allows a large number of mobile devices to concurrently upload their local models by exploiting the superposition property of wireless multi-access channels. Due to wireless channel fading, the model aggregation error at the edge server is dominated by the weakest channel among all devices, causing severe straggler issues. 
In this paper, we propose a relay-assisted cooperative FL scheme to effectively address the straggler issue. In particular, we deploy multiple half-duplex relays to cooperatively assist the devices in uploading the local model updates to the edge server. The nature of the over-the-air computation poses system objectives and constraints that are distinct from those in traditional relay communication systems. Moreover, the strong coupling between the design variables renders the optimization of such a system challenging. To tackle the issue, we propose an alternating-optimization-based algorithm to optimize the transceiver and relay operation with low complexity. Then, we analyze the model aggregation error in a single-relay case and show that our relay-assisted scheme achieves a smaller error than the one without relays provided that the relay transmit power and the relay channel gains are sufficiently large. The analysis provides critical insights on relay deployment in the implementation of cooperative FL. Extensive numerical results show that our design achieves faster convergence compared with state-of-the-art schemes. 12 | 13 | ## Referencing 14 | 15 | If you use this code in any way for research that results in publications, please cite our original article listed above. 16 | 17 | ## Dependencies 18 | This package is written in Matlab and Python 3. It requires the following libraries: 19 | * Matlab and CVX 20 | * Python >= 3.5 21 | * torch 22 | * torchvision 23 | * numpy and scipy 24 | * CUDA (if GPU is used) 25 | 26 | ## Documentation (please also see each file for more details): 27 | 28 | * __data/__: Store the Fashion-MNIST dataset. On the first run, the dataset is downloaded automatically from the Internet.
29 | * __store/__: Store output files (\*.npz) 30 | * __matlab/__: Data and code to be used in Matlab 31 | * __DATA/__: Store files (\*.mat) for channel models and optimization results in Matlab 32 | * __training_result/__: Store training results (\*.mat) to be plotted for presentation 33 | * __main_cmp.m__: Initialize the simulation system and optimize the variables 34 | * __Setup_Init.m__: Specify and initialize the system parameters 35 | * __AM.m__: Alternating minimization algorithm proposed in the paper 36 | * __Single.m__: Conventional over-the-air model aggregation scheme 37 | * __Xu.m__: Existing relay-assisted scheme in Ref. [23] 38 | * __single_relay_channel.m__: Construct the channel model for the single-relay case 39 | * __single_relay_channel_loc.m__: Construct the channel model for the single-relay case with varying relay location 40 | * __cell_channel_model.m__: Construct the channel model for the multi-relay case in a single cell 41 | * __plot_figure.m__: Plot the figure with varying transmission blocks from the training results stored in training_result/ 42 | * __plot_Pr.m__: Plot the figure with varying P_r from the training results stored in training_result/ 43 | * __main.py__: Initialize the simulation system, train the learning model, and store the results in store/ as an npz file 44 | * __initial()__: Build the argument parser and read the user-input parameters 45 | * __learning_flow.py__: Run the FL training rounds and testing for a given model-aggregation scheme 46 | * __learning_iter()__: Given the learning model, compute the local gradients/model changes, aggregate them with the selected scheme, update the global model, and perform testing, on top of train_script.py 47 | * __FedAvg_grad()__: Given the aggregated model changes and the current model, update the global model by eq. (5) 48 | * __Nets.py__: 49 | * __CNNMnist()__: Specify the convolutional neural network structure used for learning 50 | * __MLP()__: Specify the multilayer perceptron structure
used for learning 51 | * __AirComp.py__: 52 | * __AM()__: Given the local model changes, perform relay-assisted over-the-air model aggregation; see Section II-C 53 | * __Single()__: Given the local model changes, perform conventional over-the-air model aggregation; see Section II-B 54 | * __Xu()__: Given the local model changes, perform relay-assisted over-the-air model aggregation with the scheme proposed in Ref. [23] 55 | * __train_script.py__: 56 | * __load_fmnist_iid()__: Download (if needed) and load the Fashion-MNIST data, and distribute it to the local devices 57 | * __load_fmnist_noniid()__: Download (if needed) and load the Fashion-MNIST data, and distribute it to the local devices following a non-iid distribution 58 | * __local_update()__: Given a learning model and the distributed training data, compute the local gradients/model changes 59 | * __test_model()__: Given a learning model, compute the accuracy/loss on the test images 60 | * __plot_result.py__: Plot the figure with varying transmission blocks from the output files in store/, and process and store the training results in matlab/training_result/ 61 | * __plot_Pr.py__: Plot the figure with varying P_r from the output files in store/, and process and store the training results in matlab/training_result/ 62 | 63 | 64 | ## How to Use 65 | 1. Use the channel-model code in **matlab/** to obtain the channel coefficients. 66 | 67 | 2. The main file for optimization in Matlab is **matlab/main_cmp.m**, which optimizes the variables of the proposed relay-assisted scheme and the benchmark schemes. 68 | 69 | Run **matlab/main_cmp.m**; its optimization results are then used for FL. 70 | 71 | 3. The main file for FL is **main.py**.
It accepts the following command-line parameters (see also the function **initial()** in main.py): 72 | 73 | | Parameter Name | Meaning| Default Value| Type/Range | 74 | | ---------- | -----------|-----------|-----------| 75 | | K | total number of devices |20 |int | 76 | | N | total number of relays |1 |int | 77 | | PL | path loss exponent |3.0 |float | 78 | | trial | total number of Monte Carlo trials |50 |int | 79 | | SNR | negative noise variance in dB, i.e., sigma^2 = 10^(-SNR/10) |100 |float | 80 | | P_r | relay transmit power budget (main.py overrides it with the value selected by V_idx) |0.1 |float | 81 | | verbose | output level: 0 = no messages, 1 = important messages, 2 = detailed per-batch messages |1 |0, 1, 2 | 82 | | seed | random seed |1 |int | 83 | | gpu | GPU index used for learning (if available; -1 forces the CPU) |0 |int | 84 | | local_ep | number of local epochs, E |1 |int | 85 | | local_bs | local batch size, B, 0 for full batch |0 |int | 86 | | lr | learning rate, lambda |0.05 |float | 87 | | low_lr | learning rate lower bound, bar_lambda |1e-5 |float | 88 | | gamma | learning rate decrease ratio, gamma |0.9 |float | 89 | | step | learning rate decrease step, bar_T |50 |int | 90 | | momentum | SGD momentum, used only for multiple local updates |0.99 |float | 91 | | epochs | number of training rounds, T |500 |int | 92 | | iid | 1 for iid, 0 for non-iid |1 |0, 1 | 93 | | noniid_level | number of classes at each device for non-iid |2 |2, 4, 6, 8, 10 | 94 | | V_idx | index of the relay power budget P_r in [0.01, 0.1, 0.3, 0.5, 1] used by main.py |0 |0, 1, 2, 3, 4 | 95 | 96 | 97 | Here is an example of running the scripts in a Linux terminal: 98 | > python main.py --gpu=0 --trial=50 --V_idx 0 99 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, Callable 2 | 3 | import numpy as np 4 | import argparse 5 | import math 6 | import time 7 | import torch 8 | import copy 9 | import learning_flow 10 | import train_script 11 | from Nets import CNNMnist, MLP 12 | import scipy.io as sio 13 | 14 | 15 |
def initial(argv=None):
    """Parse the command-line options that configure one simulation run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as the script always did;
            passing a list allows programmatic use (tests, notebooks).

    Returns:
        argparse.Namespace whose attribute names match the option names
        (``K``, ``N``, ..., ``noniid_level``, ``V_idx``).
    """
    setup = argparse.ArgumentParser()

    # network parameters
    setup.add_argument('--K', type=int, default=20, help='total # of devices')
    setup.add_argument('--N', type=int, default=1, help='# of relays')
    setup.add_argument('--PL', type=float, default=3.0, help='path loss exponent')

    # simulation parameters
    setup.add_argument('--trial', type=int, default=50, help='# of Monte Carlo Trials')
    setup.add_argument('--SNR', type=float, default=100, help='-noise variance in dB')
    setup.add_argument('--P_r', type=float, default=0.1, help='relay transmit power budget 0.1W')
    setup.add_argument('--verbose', type=int, default=1, help=r'whether output or not')
    setup.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')

    # learning parameters
    setup.add_argument('--gpu', type=int, default=0, help=r'Use which gpu')
    setup.add_argument('--local_ep', type=int, default=1, help="the number of local epochs, E")
    setup.add_argument('--local_bs', type=int, default=0, help="0 for no effect, local bath size, B")
    setup.add_argument('--lr', type=float, default=0.05, help="learning rate, lambda")
    setup.add_argument('--low_lr', type=float, default=1e-5, help="learning rate lower bound, bar_lambda")
    setup.add_argument('--gamma', type=float, default=0.9, help="learning rate decrease ratio, gamma")
    setup.add_argument('--step', type=int, default=50, help="learning rate decrease step, bar_T")
    setup.add_argument('--momentum', type=float, default=0.99,
                       help="SGD momentum, used only for multiple local updates")
    setup.add_argument('--epochs', type=int, default=500, help="rounds of training, T")
    setup.add_argument('--iid', type=int, default=1, help="1 for iid, 0 for non-iid")
    setup.add_argument('--noniid_level', type=int, default=2, help="number of classes at each device for non-iid")
    setup.add_argument('--V_idx', type=int, default=0, help="Variable index")
    return setup.parse_args(argv)


if __name__ == '__main__':
    setup = initial()
    np.random.seed(setup.seed)
    torch.manual_seed(setup.seed)

    # learning_iter() decays setup.lr in place and resets it from init_lr at
    # the start of every scheme, so the initial value is kept here.
    setup.init_lr = copy.deepcopy(setup.lr)

    print(setup)

    # Channels and optimized transceiver/relay variables, presumably produced
    # by matlab/main_cmp.m.
    # NOTE(review): the file name is fixed (trial 50, K 20, N 1, PL 3) and does
    # not follow --trial/--K/--N/--PL; confirm it matches the chosen options.
    data = sio.loadmat('matlab/DATA/trial_50_K_20_N_1_PL_3_Pr.mat')

    # V_idx selects the relay power budget and overrides --P_r.
    Pr_set = [0.01, 0.1, 0.3, 0.5, 1]
    V_idx = setup.V_idx
    setup.P_r = Pr_set[V_idx]

    store_filename = 'store/trial_{}_K_{}_N_{}_B_{}_E_{}_lr_{}_SNR_{}_PL_{}_Pr_{}.npz'.format(
        setup.trial, setup.K, setup.N, setup.local_bs, setup.local_ep, setup.lr,
        setup.SNR, setup.PL, setup.P_r)
    print(store_filename)

    setup.sigma = np.power(10, -setup.SNR / 10)

    channel_U = data['channel_U']
    channel_R = data['channel_R']
    channel_UR = data['channel_UR']

    Proposed_a_k1 = data['Proposed_a_k1']
    Proposed_a_k2 = data['Proposed_a_k2']
    Proposed_b_n = data['Proposed_b_n']
    Proposed_c_1 = data['Proposed_c_1']
    Proposed_c_2 = data['Proposed_c_2']

    Single_a_k1 = data['Single_a_k1']
    Single_c_1 = data['Single_c_1']

    Xu_a_k1 = data['Xu_a_k1']
    Xu_b_n = data['Xu_b_n']
    Xu_eta = data['Xu_eta']

    # Per-round MSE records per scheme. learning_iter() runs 2 * epochs rounds
    # for modes 1 (error-free) and 3 (conventional). Slot 4 is never filled but
    # is kept so the positional layout of the saved .npz stays unchanged.
    MSE_1 = np.zeros([setup.trial, 2 * setup.epochs])
    MSE_2 = np.zeros([setup.trial, setup.epochs])
    MSE_3 = np.zeros([setup.trial, 2 * setup.epochs])
    MSE_4 = np.zeros([setup.trial, setup.epochs])
    MSE_5 = np.zeros([setup.trial, setup.epochs])

    MSE2_1 = np.zeros([setup.trial, 2 * setup.epochs])
    MSE2_2 = np.zeros([setup.trial, setup.epochs])
    MSE2_3 = np.zeros([setup.trial, 2 * setup.epochs])
    MSE2_4 = np.zeros([setup.trial, setup.epochs])
    MSE2_5 = np.zeros([setup.trial, setup.epochs])

    result_store = []

    print(torch.__version__)

    print(torch.version.cuda)
    print(torch.backends.cudnn.version())
    setup.device = torch.device(
        'cuda:{}'.format(setup.gpu) if torch.cuda.is_available() and setup.gpu != -1 else 'cpu')
    print(setup.device)

    for i in range(setup.trial):
        print('This is the {0}-th trial'.format(i))

        setup.h_k = channel_U[: setup.K, i]
        setup.f_n = channel_R[: setup.N, i]
        setup.g_kn = channel_UR[: setup.K, : setup.N, i]

        p_a_k1 = Proposed_a_k1[V_idx, : setup.K, i]
        p_a_k2 = Proposed_a_k2[V_idx, : setup.K, i]
        p_b_n = Proposed_b_n[V_idx, : setup.N, i]
        p_c_1 = Proposed_c_1[V_idx, i]
        p_c_2 = Proposed_c_2[V_idx, i]

        s_a_k1 = Single_a_k1[V_idx, : setup.K, i]
        s_c_1 = Single_c_1[V_idx, i]

        x_a_k1 = Xu_a_k1[V_idx, : setup.K, i]
        x_b_n = Xu_b_n[V_idx, : setup.N, i]
        x_eta = Xu_eta[V_idx, i]

        # Switches for the schemes to simulate in this trial.
        Error_free = 1
        Proposed = 1
        Single_slot = 1
        Xu_scheme = 1

        result = {}

        if setup.iid:
            train_images, train_labels, test_images, test_labels, size = train_script.load_fmnist_iid(setup.K)
        else:
            # The option is --noniid_level, so argparse stores it as
            # setup.noniid_level (setup.non_iid_level raised AttributeError).
            train_images, train_labels, test_images, test_labels, size = train_script.load_fmnist_noniid(
                setup.K, setup.noniid_level)
        net_glob = CNNMnist(num_classes=10, num_channels=1, batch_norm=True).to(setup.device)
        # net_glob = MLP(784, 64, 10).to(setup.device)

        setup.size = size
        setup.rho = np.ones(setup.K, dtype=float) * (setup.size / np.sum(setup.size))
        if setup.verbose:
            print(net_glob)
        w_glob = net_glob.state_dict()
        w_0 = copy.deepcopy(w_glob)  # initial weights, restored before each scheme
        # d = total number of scalars over all state_dict entries (the length of
        # the flattened update vectors built by train_script.local_update).
        d = 0
        for item in w_glob.keys():
            d = d + int(np.prod(w_glob[item].shape))
        print('Total Number of Parameters={}'.format(d))

        net_glob.load_state_dict(w_glob)

        if Error_free:
            print('Error_Free Channel is running')
            loss_train1, accuracy_test1, loss_test1, mse_1, mse2_1 = learning_flow.learning_iter(
                setup, d, net_glob, w_glob, train_images, train_labels, test_images, test_labels,
                1, None, None, None, None, None)
            result['loss_train1'] = np.asarray(loss_train1)
            result['accuracy_test1'] = np.asarray(accuracy_test1)
            result['loss_test1'] = np.asarray(loss_test1)
            print('result {}'.format(result['accuracy_test1'][len(result['accuracy_test1']) - 1]))
            MSE_1[i, :] = mse_1
            MSE2_1[i, :] = mse2_1

        if Proposed:
            print('Proposed Scheme is running')

            w_glob = copy.deepcopy(w_0)
            net_glob.load_state_dict(w_glob)

            loss_train2, accuracy_test2, loss_test2, mse_2, mse2_2 = learning_flow.learning_iter(
                setup, d, net_glob, w_glob, train_images, train_labels, test_images, test_labels,
                2, p_a_k1, p_a_k2, p_b_n, p_c_1, p_c_2)

            result['loss_train2'] = np.asarray(loss_train2)
            result['accuracy_test2'] = np.asarray(accuracy_test2)
            result['loss_test2'] = np.asarray(loss_test2)
            print('result {}'.format(result['accuracy_test2'][len(result['accuracy_test2']) - 1]))
            MSE_2[i, :] = mse_2
            MSE2_2[i, :] = mse2_2

        if Single_slot:
            print('Conventional Scheme is running')

            w_glob = copy.deepcopy(w_0)
            net_glob.load_state_dict(w_glob)

            loss_train3, accuracy_test3, loss_test3, mse_3, mse2_3 = learning_flow.learning_iter(
                setup, d, net_glob, w_glob, train_images, train_labels, test_images, test_labels,
                3, s_a_k1, None, None, s_c_1, None)
            result['loss_train3'] = np.asarray(loss_train3)
            result['accuracy_test3'] = np.asarray(accuracy_test3)
            result['loss_test3'] = np.asarray(loss_test3)
            print('result {}'.format(result['accuracy_test3'][len(result['accuracy_test3']) - 1]))
            MSE_3[i, :] = mse_3
            MSE2_3[i, :] = mse2_3

        if Xu_scheme:
            print('Existing Scheme is running')

            w_glob = copy.deepcopy(w_0)
            net_glob.load_state_dict(w_glob)

            loss_train5, accuracy_test5, loss_test5, mse_5, mse2_5 = learning_flow.learning_iter(
                setup, d, net_glob, w_glob, train_images, train_labels, test_images, test_labels,
                5, x_a_k1, None, x_b_n, None, x_eta)
            result['loss_train5'] = np.asarray(loss_train5)
            result['accuracy_test5'] = np.asarray(accuracy_test5)
            result['loss_test5'] = np.asarray(loss_test5)
            print('result {}'.format(result['accuracy_test5'][len(result['accuracy_test5']) - 1]))
            MSE_5[i, :] = mse_5
            MSE2_5[i, :] = mse2_5

        result_store.append(result)
        # Checkpoint after every trial so partial results survive an interruption.
        np.savez(store_filename, vars(setup), result_store, MSE_1, MSE_2, MSE_3, MSE_4, MSE_5, MSE2_1, MSE2_2, MSE2_3,
                 MSE2_4, MSE2_5)

    np.savez(store_filename, vars(setup), result_store, MSE_1, MSE_2, MSE_3, MSE_4, MSE_5, MSE2_1, MSE2_2, MSE2_3,
             MSE2_4, MSE2_5)
# --------------------------------------------------------------------------------