├── FedSpecNet ├── exp │ ├── __init__.py │ └── exp_basic.py ├── utils │ ├── __init__.py │ ├── masking.py │ ├── metrics.py │ ├── tools.py │ └── timefeatures.py ├── layers │ ├── __init__.py │ ├── Transformer_EncDec.py │ ├── Embed.py │ └── SelfAttention_Family.py ├── models │ ├── __init__.py │ └── iTransformer.py ├── data_provider │ ├── __init__.py │ ├── data_factory.py │ ├── m4.py │ └── uea.py ├── _Part │ └── part_evaluate.py ├── _Support │ ├── support_SSA.py │ ├── support_SWT.py │ └── support_VMD.py ├── FFT.py ├── support_SSA.py ├── Center.py ├── run.py └── Subsequence_number_experiment.py ├── FedSpecNet_v2 ├── exp │ ├── __init__.py │ └── exp_basic.py ├── layers │ ├── __init__.py │ ├── Transformer_EncDec.py │ ├── Embed.py │ └── SelfAttention_Family.py ├── models │ ├── __init__.py │ ├── LSTM.py │ └── iTransformer.py ├── utils │ ├── __init__.py │ ├── masking.py │ ├── metrics.py │ ├── tools.py │ └── timefeatures.py ├── data_provider │ ├── __init__.py │ ├── data_factory.py │ ├── m4.py │ └── uea.py ├── model.pth ├── _Support │ ├── support_ASSA.py │ ├── support_SSA.py │ ├── support_SWT.py │ └── support_VMD.py ├── server.crt ├── FFT.py ├── server.key ├── support_SSA.py ├── run.py ├── Center.py └── Subsequence_number_experiment.py └── README.md /FedSpecNet/exp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /FedSpecNet_v2/exp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet_v2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet_v2/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet_v2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /FedSpecNet/data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FedSpecNet_v2/data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FedSpecNet_v2/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fairy-09/FedSpecNet/HEAD/FedSpecNet_v2/model.pth -------------------------------------------------------------------------------- /FedSpecNet_v2/models/LSTM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class LSTM(nn.Module): 6 | # def __init__(self, input_size=2, hidden_size=4, output_size=1, num_layer=1): 7 | def __init__(self, input_size=1, hidden_size=4, output_size=1, num_layer=1,bidirectional=True): 8 | super(LSTM, self).__init__() ##调用父类的构造函数进行初始化 9 | 
self.layer1 = nn.LSTM(input_size, hidden_size, num_layer,bidirectional=bidirectional) ##布尔值,指示是否使用双向LSTM 10 | self.layer2 = nn.Linear(hidden_size * 2 if bidirectional else hidden_size, output_size) 11 | ##第二层做了修改的 12 | ##第一层为LSTM,第二层为全连接层(线性),用于从LSTM的输出生成最终输出 13 | 14 | def forward(self, x): 15 | x, _ = self.layer1(x) 16 | x = torch.relu(x) 17 | s, b, h = x.size() 18 | x = x.view(s * b, h) 19 | x = self.layer2(x) 20 | x = x.view(s, b, -1) 21 | return x[:,-1,:] -------------------------------------------------------------------------------- /FedSpecNet_v2/_Support/support_ASSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def SSA(series, level): 4 | windowLen = level 5 | seriesLen = len(series) 6 | K = seriesLen - windowLen + 1 7 | X = np.zeros((windowLen, K)) 8 | for i in range(K): 9 | X[:, i] = series[i:i + windowLen] 10 | U, sigma, VT = np.linalg.svd(X) 11 | 12 | rec = np.zeros((level, seriesLen)) 13 | mask = np.full((seriesLen,), 1 / level) 14 | for j in range(level - 1): 15 | mask[j] *= level / (j + 1) 16 | mask[-j - 1] *= level / (j + 1) 17 | for i in range(level): 18 | A = np.matmul(U[:, i:i + 1], VT[i:i + 1, :]) 19 | A *= sigma[i] 20 | x = np.zeros((seriesLen,)) 21 | for j in range(level): 22 | temp_x = np.pad(A[j, :], (j, level - 1 - j), 'constant', constant_values=0.) 
23 | x += temp_x 24 | x *= mask 25 | rec[i, :] = x 26 | return rec 27 | -------------------------------------------------------------------------------- /FedSpecNet/utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /FedSpecNet_v2/utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = 
indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /FedSpecNet/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import r2_score 3 | 4 | 5 | def RSE(pred, true): 6 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 7 | 8 | 9 | def CORR(pred, true): 10 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 11 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 12 | return (u / d).mean(-1) 13 | 14 | 15 | def MAE(pred, true): 16 | return np.mean(np.abs(pred - true)) 17 | 18 | 19 | def MSE(pred, true): 20 | return np.mean((pred - true) ** 2) 21 | 22 | 23 | def RMSE(pred, true): 24 | return np.sqrt(MSE(pred, true)) 25 | 26 | 27 | def MAPE(pred, true): 28 | return np.mean(np.abs((pred - true) / true)) 29 | 30 | 31 | def MSPE(pred, true): 32 | return np.mean(np.square((pred - true) / true)) 33 | 34 | 35 | def metric(pred, true): 36 | mae = MAE(pred, true) 37 | mse = MSE(pred, true) 38 | rmse = RMSE(pred, true) 39 | mape = MAPE(pred, true) 40 | mspe = MSPE(pred, true) 41 | R2 = r2_score(true,pred) 42 | 43 | 44 | return mae, mse, rmse, mape, mspe,R2 45 | -------------------------------------------------------------------------------- /FedSpecNet_v2/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import r2_score 3 | 4 | 5 | def RSE(pred, true): 6 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 7 | 8 | 9 | def CORR(pred, true): 10 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 11 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 12 | return (u / d).mean(-1) 13 | 14 | 15 | def MAE(pred, 
true): 16 | return np.mean(np.abs(pred - true)) 17 | 18 | 19 | def MSE(pred, true): 20 | return np.mean((pred - true) ** 2) 21 | 22 | 23 | def RMSE(pred, true): 24 | return np.sqrt(MSE(pred, true)) 25 | 26 | 27 | def MAPE(pred, true): 28 | return np.mean(np.abs((pred - true) / true)) 29 | 30 | 31 | def MSPE(pred, true): 32 | return np.mean(np.square((pred - true) / true)) 33 | 34 | 35 | def metric(pred, true): 36 | mae = MAE(pred, true) 37 | mse = MSE(pred, true) 38 | rmse = RMSE(pred, true) 39 | mape = MAPE(pred, true) 40 | mspe = MSPE(pred, true) 41 | R2 = r2_score(true,pred) 42 | 43 | 44 | return mae, mse, rmse, mape, mspe,R2 45 | -------------------------------------------------------------------------------- /FedSpecNet_v2/exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models import iTransformer 4 | 5 | 6 | 7 | class Exp_Basic(object): 8 | def __init__(self, args): 9 | self.args = args 10 | self.model_dict = { 11 | 'iTransformer': iTransformer 12 | } 13 | self.device = self._acquire_device() 14 | self.model = self._build_model().to(self.device) 15 | 16 | 17 | def _build_model(self): 18 | raise NotImplementedError 19 | return None 20 | 21 | def _acquire_device(self): 22 | if self.args.use_gpu: 23 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 24 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 25 | device = torch.device('cuda:0') 26 | # print('Use GPU: cuda:{}'.format(self.args.gpu)) 27 | else: 28 | device = torch.device('cpu') 29 | print('Use CPU') 30 | return device 31 | 32 | def _get_data(self): 33 | pass 34 | 35 | def vali(self): 36 | pass 37 | 38 | def train(self): 39 | pass 40 | 41 | def test(self): 42 | pass 43 | -------------------------------------------------------------------------------- /FedSpecNet/exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from 
models import iTransformer 4 | 5 | 6 | 7 | class Exp_Basic(object): 8 | def __init__(self, args): 9 | self.args = args 10 | self.model_dict = { 11 | 'iTransformer': iTransformer 12 | } 13 | self.device = self._acquire_device() 14 | self.model = self._build_model().to(self.device) 15 | 16 | 17 | def _build_model(self): 18 | raise NotImplementedError 19 | return None 20 | 21 | def _acquire_device(self): 22 | if self.args.use_gpu: 23 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 24 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 25 | device = torch.device('cuda:{}'.format(self.args.gpu)) 26 | # print('Use GPU: cuda:{}'.format(self.args.gpu)) 27 | else: 28 | device = torch.device('cpu') 29 | print('Use CPU') 30 | return device 31 | 32 | def _get_data(self): 33 | pass 34 | 35 | def vali(self): 36 | pass 37 | 38 | def train(self): 39 | pass 40 | 41 | def test(self): 42 | pass 43 | -------------------------------------------------------------------------------- /FedSpecNet/_Part/part_evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from sklearn.metrics import mean_squared_error 4 | from sklearn.metrics import mean_absolute_error 5 | from sklearn.metrics import r2_score 6 | 7 | 8 | def MAPE1(true,predict): 9 | 10 | L1 = int(len(true)) 11 | L2 = int(len(predict)) 12 | 13 | if L1 == L2: 14 | 15 | SUM = 0.0 16 | for i in range(L1): 17 | if true[i] == 0: 18 | SUM = abs(predict[i]) + SUM 19 | else: 20 | SUM = abs((true[i] - predict[i]) / true[i]) + SUM 21 | per_SUM = SUM * 100.0 22 | mape = per_SUM / L1 23 | return mape 24 | else: 25 | print("error") 26 | 27 | 28 | def RMSE1(true_data, predict_data): 29 | testY = true_data 30 | testPredict = predict_data 31 | rmse = math.sqrt( mean_squared_error(testY[:], testPredict[:])) 32 | return rmse 33 | 34 | 35 | def MAE1(true_data, predict_data): 36 | testY = true_data 37 | testPredict = predict_data 38 | mae=mean_absolute_error(testY[:], 
testPredict[:]) 39 | return mae 40 | 41 | 42 | def R2(y_true, y_predict): 43 | score = r2_score(y_true, y_predict) 44 | return score 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /FedSpecNet_v2/server.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDazCCAlOgAwIBAgIUZdpewb8uZnh5UXlqiFrqyqlCdLgwDQYJKoZIhvcNAQEL 3 | BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM 4 | GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTA3MDEwMjE5MjFaFw0zNTA2 5 | MjkwMjE5MjFaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw 6 | HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB 7 | AQUAA4IBDwAwggEKAoIBAQCj7XUa5wYGI8A0VBiq16Pi3yIeMXguQog4txiWV/QH 8 | Q7AFMbrXbniW3IvbrlWhO9rZESHv9Et81/zpSBqPDSM0QQk//XzKvfqBfiQS8ilG 9 | DSbVLTyecfQwW8FbgJ23zWPwDrJfR9D8Q73Ch8MclECKiZWkVkbJ9jvQmRTFKVUO 10 | hjZ4f1tDCzyjHjG+E/Muq+0Hyxz/CZewSumf4RvS1AAWa0RLG2/NXm4v/fSGtbFK 11 | qq8ITL/Ry7ZxoGHm8CtiJvb2FDS4FWvPNm1w/R4uTWb19vNs1/QkgW8CLaGVNjaD 12 | p2fO8WakrjhLee/FucttIIPd0F5/nMoEwM58cB5LXG1fAgMBAAGjUzBRMB0GA1Ud 13 | DgQWBBS6Lbemsc1Ss8/O8C93cuRROoB+tDAfBgNVHSMEGDAWgBS6Lbemsc1Ss8/O 14 | 8C93cuRROoB+tDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAb 15 | OiKJSEcD4ygkLAEywX3pB46N/ET9GRZd6Z5Iu76i4vBIQqhDq9P8Jjaw5e+5n4rz 16 | I3eK5HLPwFv8wUOrZdAqBZYs+qxpAFkuPBKcql4sgDrVhcDdt6H0NR5Yys6lWG+/ 17 | tikjlNQJjfqQi6mgL9/LjRcZFP6HiTR2aOjDIR24LckyOUrNaOKByAzMW0L4lQod 18 | WMe45enw8S0saVYufihBMb5QhmFfRSpB0ybNEBlJgJL9B6OcmHM12QpifoA93d3Q 19 | /FBgHLZFEFriars2S69ajWEGV52akaRlH7QNRNQHqUinFU5cRmy8aJX+BfOSKPv3 20 | 6DaJWXiWyJXWVYKJd57F 21 | -----END CERTIFICATE----- 22 | -------------------------------------------------------------------------------- /FedSpecNet/_Support/support_SSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | import 
pandas as pd 4 | 5 | 6 | def SSA(series, level): 7 | 8 | series = series - np.mean(series) 9 | 10 | 11 | windowLen = level 12 | seriesLen = len(series) 13 | K = seriesLen - windowLen + 1 14 | X = np.zeros((windowLen, K)) 15 | for i in range(K): 16 | X[:, i] = series[i:i + windowLen] 17 | 18 | 19 | U, sigma, VT = np.linalg.svd(X, full_matrices=False) 20 | 21 | for i in range(VT.shape[0]): 22 | VT[i, :] *= sigma[i] 23 | A = VT 24 | 25 | 26 | rec = np.zeros((windowLen, seriesLen)) 27 | for i in range(windowLen): 28 | for j in range(windowLen - 1): 29 | for m in range(j + 1): 30 | rec[i, j] += A[i, j - m] * U[m, i] 31 | rec[i, j] /= (j + 1) 32 | for j in range(windowLen - 1, seriesLen - windowLen + 1): 33 | for m in range(windowLen): 34 | rec[i, j] += A[i, j - m] * U[m, i] 35 | rec[i, j] /= windowLen 36 | for j in range(seriesLen - windowLen + 1, seriesLen): 37 | for m in range(j - seriesLen + windowLen, windowLen): 38 | rec[i, j] += A[i, j - m] * U[m, i] 39 | rec[i, j] /= (seriesLen - j) 40 | 41 | 42 | res = pd.DataFrame(rec.T, columns=[f'rec_{i + 1}' for i in range(windowLen)]) 43 | 44 | res.to_csv('output.csv', index=False) 45 | 46 | return rec 47 | 48 | -------------------------------------------------------------------------------- /FedSpecNet_v2/_Support/support_SSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | 6 | def SSA(series, level): 7 | 8 | series = series - np.mean(series) 9 | 10 | 11 | windowLen = level 12 | seriesLen = len(series) 13 | K = seriesLen - windowLen + 1 14 | X = np.zeros((windowLen, K)) 15 | for i in range(K): 16 | X[:, i] = series[i:i + windowLen] 17 | 18 | 19 | U, sigma, VT = np.linalg.svd(X, full_matrices=False) 20 | 21 | for i in range(VT.shape[0]): 22 | VT[i, :] *= sigma[i] 23 | A = VT 24 | 25 | 26 | rec = np.zeros((windowLen, seriesLen)) 27 | for i in range(windowLen): 28 | for j in range(windowLen - 1): 29 
| for m in range(j + 1): 30 | rec[i, j] += A[i, j - m] * U[m, i] 31 | rec[i, j] /= (j + 1) 32 | for j in range(windowLen - 1, seriesLen - windowLen + 1): 33 | for m in range(windowLen): 34 | rec[i, j] += A[i, j - m] * U[m, i] 35 | rec[i, j] /= windowLen 36 | for j in range(seriesLen - windowLen + 1, seriesLen): 37 | for m in range(j - seriesLen + windowLen, windowLen): 38 | rec[i, j] += A[i, j - m] * U[m, i] 39 | rec[i, j] /= (seriesLen - j) 40 | 41 | 42 | res = pd.DataFrame(rec.T, columns=[f'rec_{i + 1}' for i in range(windowLen)]) 43 | 44 | res.to_csv('output.csv', index=False) 45 | 46 | return rec 47 | 48 | -------------------------------------------------------------------------------- /FedSpecNet/FFT.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | np.random.seed(0) 6 | timestamps = ["6/26/2018 6:00:00 PM", "6/26/2018 6:05:00 PM", "6/26/2018 6:10:00 PM"] 7 | for i in range(997): 8 | ts = datetime.datetime.strptime(timestamps[-1], "%m/%d/%Y %I:%M:%S %p") + datetime.timedelta(minutes=5) 9 | timestamps.append(ts.strftime("%m/%d/%Y %I:%M:%S %p")) 10 | 11 | data = pd.read_csv('/home/liujian/project/Time-Series-Library-main/dataset-new/house1_5min_KWh.csv', parse_dates=['date']) 12 | data.set_index('date', inplace=True) 13 | values = data['value'].values[:8400] 14 | datetime_objects = [datetime.datetime.strptime(ts, "%m/%d/%Y %I:%M:%S %p") for ts in timestamps] 15 | time_deltas = [(dt - datetime_objects[0]).total_seconds() / 60 for dt in datetime_objects] 16 | 17 | Fs = 1 / (time_deltas[1] - time_deltas[0]) 18 | T = 1 / Fs 19 | Y = np.fft.fft(values) 20 | Y[0]=0 21 | f = np.fft.fftfreq(len(values), d=T) 22 | 23 | plt.figure(figsize=(10, 5)) 24 | plt.subplots_adjust(left=0.05, bottom=0.125, right=0.99, top=0.94) 25 | plt.title("House1_frequency_spectrum",fontsize=20) 26 | positive_f = f[f > 0] 27 | positive_Y = 
Y[:len(positive_f)] 28 | amplitudes = np.abs(positive_Y) 29 | plt.plot(positive_f, amplitudes) 30 | plt.xlabel('Frequency (Hz)',fontsize=12) 31 | plt.ylabel('Amplitude',fontsize=12) 32 | 33 | plt.show() 34 | 35 | 36 | -------------------------------------------------------------------------------- /FedSpecNet_v2/FFT.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | np.random.seed(0) 6 | timestamps = ["6/26/2018 6:00:00 PM", "6/26/2018 6:05:00 PM", "6/26/2018 6:10:00 PM"] 7 | for i in range(997): 8 | ts = datetime.datetime.strptime(timestamps[-1], "%m/%d/%Y %I:%M:%S %p") + datetime.timedelta(minutes=5) 9 | timestamps.append(ts.strftime("%m/%d/%Y %I:%M:%S %p")) 10 | 11 | data = pd.read_csv('/home/liujian/project/Time-Series-Library-main/dataset-new/house1_5min_KWh.csv', parse_dates=['date']) 12 | data.set_index('date', inplace=True) 13 | values = data['value'].values[:8400] 14 | datetime_objects = [datetime.datetime.strptime(ts, "%m/%d/%Y %I:%M:%S %p") for ts in timestamps] 15 | time_deltas = [(dt - datetime_objects[0]).total_seconds() / 60 for dt in datetime_objects] 16 | 17 | Fs = 1 / (time_deltas[1] - time_deltas[0]) 18 | T = 1 / Fs 19 | Y = np.fft.fft(values) 20 | Y[0]=0 21 | f = np.fft.fftfreq(len(values), d=T) 22 | 23 | plt.figure(figsize=(10, 5)) 24 | plt.subplots_adjust(left=0.05, bottom=0.125, right=0.99, top=0.94) 25 | plt.title("House1_frequency_spectrum",fontsize=20) 26 | positive_f = f[f > 0] 27 | positive_Y = Y[:len(positive_f)] 28 | amplitudes = np.abs(positive_Y) 29 | plt.plot(positive_f, amplitudes) 30 | plt.xlabel('Frequency (Hz)',fontsize=12) 31 | plt.ylabel('Amplitude',fontsize=12) 32 | 33 | plt.show() 34 | 35 | 36 | -------------------------------------------------------------------------------- /FedSpecNet_v2/server.key: -------------------------------------------------------------------------------- 1 | 
-----BEGIN PRIVATE KEY----- 2 | MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCj7XUa5wYGI8A0 3 | VBiq16Pi3yIeMXguQog4txiWV/QHQ7AFMbrXbniW3IvbrlWhO9rZESHv9Et81/zp 4 | SBqPDSM0QQk//XzKvfqBfiQS8ilGDSbVLTyecfQwW8FbgJ23zWPwDrJfR9D8Q73C 5 | h8MclECKiZWkVkbJ9jvQmRTFKVUOhjZ4f1tDCzyjHjG+E/Muq+0Hyxz/CZewSumf 6 | 4RvS1AAWa0RLG2/NXm4v/fSGtbFKqq8ITL/Ry7ZxoGHm8CtiJvb2FDS4FWvPNm1w 7 | /R4uTWb19vNs1/QkgW8CLaGVNjaDp2fO8WakrjhLee/FucttIIPd0F5/nMoEwM58 8 | cB5LXG1fAgMBAAECggEAAY71H2JBu9U1dmJGScKluh7POPztgUenHu2UBU83Ob5D 9 | kJtczjnLBHTofra7StC8AEYyYASrHPTrYIP/ilq1z1mOLAEvZ3u9W9dSZW085Ivh 10 | 8CbKtO6IaCZuBOvqYyxgIVO8WKsqHddEvZzkzWTAwLT9ACjdPzra590oIWSrZA3Y 11 | 87U2Mll4D2qgxkyL2D5UlzvO4W9vDT9+UC3/khrrtNqZxqVszGk+xE+9KiW08q6/ 12 | AEYCQyXmL9QjIOtjZDVhZWJ+RClDwnx2YkQUQXHRrI4090Vh8L1bpZTq3pPT+4sM 13 | NM1wMfO5tBbr7irzTEmsy+3gZIWmv71lHXQ9SQtL4QKBgQDnXsZvVeVoMawJXmDb 14 | 5oZCH8XSJXBpkz+pBYfIjGHs4Dg97jzdjABqfa5QB26TXlIJMYxI4QcUsYuvTNxE 15 | 7jbqq91Cx55wAuL14XDRv9tXYRDNCpm3MZ9XUx2evaZ2nuosZvPcUQD5GdLDrgPN 16 | oTq4Dt4/rNvz+qf1/iGaBFUpdQKBgQC1YMHsB0TFcav+g2SSnRvG1HKKW9EsWKcF 17 | Qch72fdRXgbSMk8vl96gSdEwNIe+NBVGhoIQaNYVfxs7DPOvNT6XZueSno8fCnRL 18 | FOoeLX7J4/59qpxwb2AsO7R+AGcEoejFckWGMdJAIYI5Q0tDh3yiS7H39Uncqwam 19 | fH960/oNAwKBgQDJxJzH+uBUPP0KLoPJQP9UKuEYog2mBANAItKG0eWT7PUfDOOH 20 | UNAMDg41PEXxvg9MdSkhZRwHr81g0mZEtnitrbMGyY4hoGLMig0Y3XcqfDtqlBP4 21 | 7g1G2fS6uiiwyWTt09pWB04R7bMfcmFesXvDhzPJ07T8z1op67Th22VpAQKBgQCg 22 | PUI1rtpAAUPtT9GLgcdnaoti1uk+X3f3Of3QBWns9b/a9d1lc6uYOn7YMqB2Fndx 23 | XiYML0JrrWa2TaP/9287vQr7Sp+w0cCaEHkhfhoUbRuJlDAvWQZDLeAwrVRWEGCg 24 | B2uKuftA8xmkU2Jr34fpriwlnwvJK0Nt1HGfZyTzVwKBgQDKH9D/BQ8PhSHS6MkD 25 | +m/O5v9HBEZj30jf0KR4GJvCPvrqk0OMCjJepMSnqEW0K7bBmvMLH6LVMuT+5/F+ 26 | 2aR7a8rsKng7lwnxrs/E7L0GzmZxX0h+BRu+UuiX/XzbBeRdik6+dHeEJlKclsDY 27 | dOTkETjPWqwe1kahvSe0Ae1hdQ== 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /FedSpecNet/support_SSA.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | 4 | #data:(9608,1) 5 | def SSA(series, level): ## 输入时间序列数据,嵌入窗口长度 6 | # series = 0 7 | # series = series - np.mean(series) # 中心化(非必须) 8 | original_mean = np.mean(series) # 保存原始数据的平均值### 9 | series = series - original_mean # 中心化 10 | 11 | # step1 嵌入 12 | windowLen = level # 嵌入窗口长度4 13 | seriesLen = len(series) # 序列长度9608 14 | K = seriesLen - windowLen + 1 #K:(9605,) 嵌入后矩阵列数 15 | series = series.flatten() 16 | X = np.zeros((windowLen, K)) #X:(4,9605) 初始化零矩阵 17 | for i in range(K): 18 | X[:, i] = series[i:i + windowLen] 19 | ##将长度为 `windowLen` 的滑动窗口在原始时间序列`series` 上进行滑动,提取子序列并将其作为矩阵X的列 20 | 21 | # step2: svd奇异值分解, U和sigma已经按升序排序 22 | U, sigma, VT = np.linalg.svd(X, full_matrices=False) #U:(4,4) 23 | 24 | for i in range(VT.shape[0]): 25 | VT[i, :] *= sigma[i] 26 | A = VT #A:(4,9605),VT:(4,9605) ## A=VT*Σ 27 | 28 | # 重构时间序列 ## 这部分看论文 29 | rec = np.zeros((windowLen, seriesLen)) ## 初始化零矩阵,用于存储重构的时间序列 30 | for i in range(windowLen): 31 | for j in range(windowLen - 1): 32 | for m in range(j + 1): 33 | rec[i, j] += A[i, j - m] * U[m, i] 34 | rec[i, j] /= (j + 1) ## 过去值的加权平均 35 | for j in range(windowLen - 1, seriesLen - windowLen + 1): 36 | for m in range(windowLen): 37 | rec[i, j] += A[i, j - m] * U[m, i] 38 | rec[i, j] /= windowLen ## 固定窗口长度内值的加权平均 39 | for j in range(seriesLen - windowLen + 1, seriesLen): 40 | for m in range(j - seriesLen + windowLen, windowLen): 41 | rec[i, j] += A[i, j - m] * U[m, i] 42 | rec[i, j] /= (seriesLen - j) ## 过去值的加权平均 43 | # for i in range(windowLen): 44 | # rec[i, :] += original_mean 45 | return rec 46 | 47 | ''' 48 | 原始时间序列 -> 中心化 -> 轨迹矩阵(嵌入)-> SVD -> 重构 -> 重构时间序列 49 | ''' 50 | # rrr = np.sum(rec, axis=0) # 选择重构的部分,这里选了全部 51 | # 52 | # plt.figure() 53 | # for i in range(10): 54 | # ax = plt.subplot(5, 2, i + 1) 55 | # ax.plot(rec[i, :]) 56 | # 57 | # plt.figure(2) 58 | # plt.plot(series) 59 | # 
plt.show() -------------------------------------------------------------------------------- /FedSpecNet_v2/support_SSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import matplotlib.pyplot as plt 3 | 4 | #data:(9608,1) 5 | def SSA(series, level): ## 输入时间序列数据,嵌入窗口长度 6 | # series = 0 7 | # series = series - np.mean(series) # 中心化(非必须) 8 | original_mean = np.mean(series) # 保存原始数据的平均值### 9 | series = series - original_mean # 中心化 10 | 11 | # step1 嵌入 12 | windowLen = level # 嵌入窗口长度4 13 | seriesLen = len(series) # 序列长度9608 14 | K = seriesLen - windowLen + 1 #K:(9605,) 嵌入后矩阵列数 15 | series = series.flatten() 16 | X = np.zeros((windowLen, K)) #X:(4,9605) 初始化零矩阵 17 | for i in range(K): 18 | X[:, i] = series[i:i + windowLen] 19 | ##将长度为 `windowLen` 的滑动窗口在原始时间序列`series` 上进行滑动,提取子序列并将其作为矩阵X的列 20 | 21 | # step2: svd奇异值分解, U和sigma已经按升序排序 22 | U, sigma, VT = np.linalg.svd(X, full_matrices=False) #U:(4,4) 23 | 24 | for i in range(VT.shape[0]): 25 | VT[i, :] *= sigma[i] 26 | A = VT #A:(4,9605),VT:(4,9605) ## A=VT*Σ 27 | 28 | # 重构时间序列 ## 这部分看论文 29 | rec = np.zeros((windowLen, seriesLen)) ## 初始化零矩阵,用于存储重构的时间序列 30 | for i in range(windowLen): 31 | for j in range(windowLen - 1): 32 | for m in range(j + 1): 33 | rec[i, j] += A[i, j - m] * U[m, i] 34 | rec[i, j] /= (j + 1) ## 过去值的加权平均 35 | for j in range(windowLen - 1, seriesLen - windowLen + 1): 36 | for m in range(windowLen): 37 | rec[i, j] += A[i, j - m] * U[m, i] 38 | rec[i, j] /= windowLen ## 固定窗口长度内值的加权平均 39 | for j in range(seriesLen - windowLen + 1, seriesLen): 40 | for m in range(j - seriesLen + windowLen, windowLen): 41 | rec[i, j] += A[i, j - m] * U[m, i] 42 | rec[i, j] /= (seriesLen - j) ## 过去值的加权平均 43 | # for i in range(windowLen): 44 | # rec[i, :] += original_mean 45 | return rec 46 | 47 | ''' 48 | 原始时间序列 -> 中心化 -> 轨迹矩阵(嵌入)-> SVD -> 重构 -> 重构时间序列 49 | ''' 50 | # rrr = np.sum(rec, axis=0) # 选择重构的部分,这里选了全部 51 | # 52 | # plt.figure() 53 | # for i in range(10): 
54 | # ax = plt.subplot(5, 2, i + 1) 55 | # ax.plot(rec[i, :]) 56 | # 57 | # plt.figure(2) 58 | # plt.plot(series) 59 | # plt.show() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔥 FedSpecNet 2 | FedSpecNet is an advanced method for household energy consumption forecasting that combines federated learning, representation learning, and deep learning. 3 | 4 | # 📚 Requirements 5 | The dependencies are listed in 'requirements.txt'. 6 | Key requirements: 7 | python==3.7.8 8 | torch==1.7.1 9 | torchvision==0.8.2 10 | torchaudio==0.7.2 11 | pandas 12 | matplotlib 13 | scikit-learn 14 | 15 | # 📒 Usage 16 | 1. Run 'Center.py' on the server, and run 'iT_SSAx.py' on each client x. 17 | 2. Adjust the 'host' and 'port' settings in these scripts to match your deployment. 18 | 19 | # 🌟 Citation 20 | 21 | If you find our repo useful for your research, please consider giving it a 🌟 and citing our previous works below, which form the foundation of this work.
22 | 23 | ``` 24 | @article{JIN2022101442, 25 | title = {Highly accurate energy consumption forecasting model based on parallel {LSTM} neural networks}, 26 | journal = {Advanced Engineering Informatics}, 27 | volume = {51}, 28 | pages = {101442}, 29 | year = {2022}, 30 | issn = {1474-0346}, 31 | doi = {10.1016/j.aei.2021.101442}, 32 | url = {https://www.sciencedirect.com/science/article/pii/S1474034621001944}, 33 | author = {Ning Jin and Fan Yang and Yuchang Mo and Yongkang Zeng and Xiaokang Zhou and Ke Yan and Xiang Ma} 34 | } 35 | 36 | @article{YANG2023101846, 37 | title = {Multiple households energy consumption forecasting using consistent modeling with privacy preservation}, 38 | journal = {Advanced Engineering Informatics}, 39 | volume = {55}, 40 | pages = {101846}, 41 | year = {2023}, 42 | issn = {1474-0346}, 43 | doi = {10.1016/j.aei.2022.101846}, 44 | url = {https://www.sciencedirect.com/science/article/pii/S1474034622003044}, 45 | author = {Fan Yang and Ke Yan and Ning Jin and Yang Du} 46 | } 47 | 48 | @article{LIU2024114894, 49 | title = {Household energy consumption forecasting based on adaptive signal decomposition enhanced {iTransformer} network}, 50 | journal = {Energy and Buildings}, 51 | volume = {324}, 52 | pages = {114894}, 53 | year = {2024}, 54 | issn = {0378-7788}, 55 | doi = {10.1016/j.enbuild.2024.114894}, 56 | url = {https://www.sciencedirect.com/science/article/pii/S0378778824010107}, 57 | author = {Jian Liu and Fan Yang and Ke Yan and Liangliang Jiang} 58 | } 59 | ``` 60 | 61 | # 🔐 License 62 | The source code is free for research and educational use only. Commercial use requires prior formal permission.
63 | -------------------------------------------------------------------------------- /FedSpecNet/_Support/support_SWT.py: -------------------------------------------------------------------------------- 1 | import pywt 2 | from matplotlib import pyplot as plt 3 | 4 | def swt_decom(data, wavefunc, lv): 5 | 6 | coeffs_list = [] 7 | data_to_decom = data 8 | A = None 9 | 10 | for i in range(lv): 11 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 12 | data_to_decom = A 13 | coeffs_list.insert(0, D) 14 | coeffs_list.insert(0, A) 15 | 16 | return coeffs_list 17 | 18 | 19 | 20 | def swt_decom_high(data, wavefunc, lv): 21 | coeffs_list = [] 22 | data_to_decom = data 23 | D = None 24 | 25 | for i in range(lv): 26 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 27 | data_to_decom = D 28 | coeffs_list.insert(0, A) 29 | coeffs_list.insert(0, D) 30 | 31 | return coeffs_list 32 | 33 | 34 | 35 | def iswt_decom(data, wavefunc): 36 | 37 | y = data[0] 38 | for i in range(len(data) - 1): 39 | y = pywt.iswt([(y, data[i+1])], wavefunc) 40 | return y 41 | 42 | 43 | 44 | def iswt_decom_high(data, wavefunc): 45 | y = data[0] 46 | for i in range(len(data) - 1): 47 | y = pywt.iswt([(data[i + 1], y)], wavefunc) 48 | return y 49 | 50 | 51 | 52 | def swpt_decom_V1(data, wavefunc, lv): 53 | data_to_decom = data 54 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 55 | coeffs_list_low = swt_decom(A, wavefunc, lv - 1) 56 | coeffs_list_high = swt_decom_high(D, wavefunc, lv - 1) 57 | coeffs_list = coeffs_list_low + coeffs_list_high 58 | return coeffs_list 59 | 60 | 61 | def iswpt_decom_V1(data, wavefunc): 62 | coeffs_list_low = data[:(len(data))//2] 63 | coeffs_list_high = data[(len(data))//2:]  # second half = high-band list appended by swpt_decom_V1 (was a copy of the low half) 64 | coeffs_low = iswt_decom(coeffs_list_low, wavefunc) 65 | coeffs_high = iswt_decom_high(coeffs_list_high, wavefunc) 66 | original = iswt_decom([coeffs_low, coeffs_high], wavefunc) 67 | return original 68 | 69 | 70 | def swpt_decom_V2(data, wavefunc, lv): 71 | 
data_to_decom = data 72 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 73 | coeffs_list_low = swt_decom(A, wavefunc, lv - 1) 74 | coeffs_list_high = swt_decom(D, wavefunc, lv - 1) 75 | coeffs_list = coeffs_list_low + coeffs_list_high 76 | return coeffs_list 77 | 78 | 79 | def iswpt_decom_V2(data, wavefunc): 80 | coeffs_list_low = data[:(len(data))//2] 81 | coeffs_list_high = data[(len(data))//2:]  # second half = high-band list appended by swpt_decom_V2 (was a copy of the low half) 82 | coeffs_low = iswt_decom(coeffs_list_low, wavefunc) 83 | coeffs_high = iswt_decom(coeffs_list_high, wavefunc) 84 | original = iswt_decom([coeffs_low, coeffs_high], wavefunc) 85 | return original 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /FedSpecNet_v2/_Support/support_SWT.py: -------------------------------------------------------------------------------- 1 | import pywt 2 | from matplotlib import pyplot as plt 3 | 4 | def swt_decom(data, wavefunc, lv): 5 | 6 | coeffs_list = [] 7 | data_to_decom = data 8 | A = None 9 | 10 | for i in range(lv): 11 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 12 | data_to_decom = A 13 | coeffs_list.insert(0, D) 14 | coeffs_list.insert(0, A) 15 | 16 | return coeffs_list 17 | 18 | 19 | 20 | def swt_decom_high(data, wavefunc, lv): 21 | coeffs_list = [] 22 | data_to_decom = data 23 | D = None 24 | 25 | for i in range(lv): 26 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 27 | data_to_decom = D 28 | coeffs_list.insert(0, A) 29 | coeffs_list.insert(0, D) 30 | 31 | return coeffs_list 32 | 33 | 34 | 35 | def iswt_decom(data, wavefunc): 36 | 37 | y = data[0] 38 | for i in range(len(data) - 1): 39 | y = pywt.iswt([(y, data[i+1])], wavefunc) 40 | return y 41 | 42 | 43 | 44 | def iswt_decom_high(data, wavefunc): 45 | y = data[0] 46 | for i in range(len(data) - 1): 47 | y = pywt.iswt([(data[i + 1], y)], wavefunc) 48 | return y 49 | 50 | 51 | 52 | def swpt_decom_V1(data, wavefunc, lv): 53 | data_to_decom = data 54 | [(A, D)] = 
pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 55 | coeffs_list_low = swt_decom(A, wavefunc, lv - 1) 56 | coeffs_list_high = swt_decom_high(D, wavefunc, lv - 1) 57 | coeffs_list = coeffs_list_low + coeffs_list_high 58 | return coeffs_list 59 | 60 | 61 | def iswpt_decom_V1(data, wavefunc): 62 | coeffs_list_low = data[:(len(data))//2] 63 | coeffs_list_high = data[(len(data))//2:]  # second half = high-band list appended by swpt_decom_V1 (was a copy of the low half) 64 | coeffs_low = iswt_decom(coeffs_list_low, wavefunc) 65 | coeffs_high = iswt_decom_high(coeffs_list_high, wavefunc) 66 | original = iswt_decom([coeffs_low, coeffs_high], wavefunc) 67 | return original 68 | 69 | 70 | def swpt_decom_V2(data, wavefunc, lv): 71 | data_to_decom = data 72 | [(A, D)] = pywt.swt(data_to_decom, wavefunc, level=1, axis=0) 73 | coeffs_list_low = swt_decom(A, wavefunc, lv - 1) 74 | coeffs_list_high = swt_decom(D, wavefunc, lv - 1) 75 | coeffs_list = coeffs_list_low + coeffs_list_high 76 | return coeffs_list 77 | 78 | 79 | def iswpt_decom_V2(data, wavefunc): 80 | coeffs_list_low = data[:(len(data))//2] 81 | coeffs_list_high = data[(len(data))//2:]  # second half = high-band list appended by swpt_decom_V2 (was a copy of the low half) 82 | coeffs_low = iswt_decom(coeffs_list_low, wavefunc) 83 | coeffs_high = iswt_decom(coeffs_list_high, wavefunc) 84 | original = iswt_decom([coeffs_low, coeffs_high], wavefunc) 85 | return original 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /FedSpecNet/data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ 2 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader,Dataset_UK_DALE_minute 3 | from data_provider.uea import collate_fn 4 | from torch.utils.data import DataLoader 5 | 6 | data_dict = { 7 | 'UK-DALE': Dataset_UK_DALE_minute, 8 | 'ETTh1': 
Dataset_ETT_hour, 9 | 'ETTh2': Dataset_ETT_hour, 10 | 'ETTm1': Dataset_ETT_minute, 11 | 'ETTm2': Dataset_ETT_minute, 
12 | 'custom': Dataset_Custom, 13 | 'm4': Dataset_M4, 14 | 'PSM': PSMSegLoader, 15 | 'MSL': MSLSegLoader, 16 | 'SMAP': SMAPSegLoader, 17 | 'SMD': SMDSegLoader, 18 | 'SWAT': SWATSegLoader, 19 | 'UEA': UEAloader 20 | } 21 | 22 | 23 | def data_provider(args, flag): 24 | Data = data_dict[args.data] 25 | timeenc = 0 if args.embed != 'timeF' else 1 26 | 27 | if flag == 'test': 28 | shuffle_flag = False 29 | drop_last = True 30 | if args.task_name == 'anomaly_detection' or args.task_name == 'classification': 31 | batch_size = args.batch_size 32 | else: 33 | batch_size = 1 # bsz=1 for evaluation 34 | freq = args.freq 35 | else: 36 | # interval = args.interval ####### 37 | shuffle_flag = False #True 38 | drop_last = True 39 | batch_size = args.batch_size # bsz for train and valid 40 | freq = args.freq 41 | 42 | if args.task_name == 'anomaly_detection': 43 | drop_last = False 44 | data_set = Data( 45 | root_path=args.root_path, 46 | win_size=args.seq_len, 47 | flag=flag, 48 | ) 49 | print(flag, len(data_set)) 50 | data_loader = DataLoader( 51 | data_set, 52 | batch_size=batch_size, 53 | shuffle=shuffle_flag, 54 | num_workers=args.num_workers, 55 | drop_last=drop_last) 56 | return data_set, data_loader 57 | elif args.task_name == 'classification': 58 | drop_last = False 59 | data_set = Data( 60 | root_path=args.root_path, 61 | flag=flag, 62 | ) 63 | 64 | data_loader = DataLoader( 65 | data_set, 66 | batch_size=batch_size, 67 | shuffle=shuffle_flag, 68 | num_workers=args.num_workers, 69 | drop_last=drop_last, 70 | collate_fn=lambda x: collate_fn(x, max_len=args.seq_len) 71 | ) 72 | return data_set, data_loader 73 | else: 74 | if args.data == 'm4': 75 | drop_last = False 76 | data_set = Data( 77 | model_id = args.model_id, 78 | interval=args.interval, 79 | subsequence_num = args.subsequence_num, 80 | decomposition_method = args.decomposition_method, 81 | root_path=args.root_path, 82 | data_path=args.data_path, 83 | flag=flag, 84 | size=[args.seq_len, args.label_len, 
args.pred_len], 85 | features=args.features, 86 | target=args.target, 87 | timeenc=timeenc, 88 | freq=freq, 89 | seasonal_patterns=args.seasonal_patterns 90 | ) 91 | data_loader = DataLoader( 92 | data_set, 93 | batch_size=batch_size, 94 | shuffle=shuffle_flag, 95 | num_workers=args.num_workers, 96 | drop_last=drop_last) 97 | return data_set, data_loader 98 | -------------------------------------------------------------------------------- /FedSpecNet_v2/data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ 2 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader,Dataset_UK_DALE_minute 3 | from data_provider.uea import collate_fn 4 | from torch.utils.data import DataLoader 5 | 6 | data_dict = { 7 | 'UK-DALE': Dataset_UK_DALE_minute, 8 | 'ETTh1': Dataset_ETT_hour, 9 | 'ETTh2': Dataset_ETT_hour, 10 | 'ETTm1': Dataset_ETT_minute, 11 | 'ETTm2': Dataset_ETT_minute, 12 | 'custom': Dataset_Custom, 13 | 'm4': Dataset_M4, 14 | 'PSM': PSMSegLoader, 15 | 'MSL': MSLSegLoader, 16 | 'SMAP': SMAPSegLoader, 17 | 'SMD': SMDSegLoader, 18 | 'SWAT': SWATSegLoader, 19 | 'UEA': UEAloader 20 | } 21 | 22 | 23 | def data_provider(args, flag): 24 | Data = data_dict[args.data] 25 | timeenc = 0 if args.embed != 'timeF' else 1 26 | 27 | if flag == 'test': 28 | shuffle_flag = False 29 | drop_last = True 30 | if args.task_name == 'anomaly_detection' or args.task_name == 'classification': 31 | batch_size = args.batch_size 32 | else: 33 | batch_size = 1 # bsz=1 for evaluation 34 | freq = args.freq 35 | else: 36 | # interval = args.interval ####### 37 | shuffle_flag = False #True 38 | drop_last = True 39 | batch_size = args.batch_size # bsz for train and valid 40 | freq = args.freq 41 | 42 | if args.task_name == 'anomaly_detection': 43 | drop_last = False 44 | data_set = Data( 45 | 
root_path=args.root_path, 46 | win_size=args.seq_len, 47 | flag=flag, 48 | ) 49 | print(flag, len(data_set)) 50 | data_loader = DataLoader( 51 | data_set, 52 | batch_size=batch_size, 53 | shuffle=shuffle_flag, 54 | num_workers=args.num_workers, 55 | drop_last=drop_last) 56 | return data_set, data_loader 57 | elif args.task_name == 'classification': 58 | drop_last = False 59 | data_set = Data( 60 | root_path=args.root_path, 61 | flag=flag, 62 | ) 63 | 64 | data_loader = DataLoader( 65 | data_set, 66 | batch_size=batch_size, 67 | shuffle=shuffle_flag, 68 | num_workers=args.num_workers, 69 | drop_last=drop_last, 70 | collate_fn=lambda x: collate_fn(x, max_len=args.seq_len) 71 | ) 72 | return data_set, data_loader 73 | else: 74 | if args.data == 'm4': 75 | drop_last = False 76 | data_set = Data( 77 | model_id = args.model_id, 78 | interval=args.interval, 79 | subsequence_num = args.subsequence_num, 80 | decomposition_method = args.decomposition_method, 81 | root_path=args.root_path, 82 | data_path=args.data_path, 83 | flag=flag, 84 | size=[args.seq_len, args.label_len, args.pred_len], 85 | features=args.features, 86 | target=args.target, 87 | timeenc=timeenc, 88 | freq=freq, 89 | seasonal_patterns=args.seasonal_patterns 90 | ) 91 | data_loader = DataLoader( 92 | data_set, 93 | batch_size=batch_size, 94 | shuffle=shuffle_flag, 95 | num_workers=args.num_workers, 96 | drop_last=drop_last) 97 | return data_set, data_loader 98 | -------------------------------------------------------------------------------- /FedSpecNet/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | plt.switch_backend('agg') 9 | 10 | 11 | def adjust_learning_rate(optimizer, epoch, args): 12 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 
1) // 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = { 17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 18 | 10: 5e-7, 15: 1e-7, 20: 5e-8 19 | } 20 | if epoch in lr_adjust.keys(): 21 | lr = lr_adjust[epoch] 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | # print('Updating learning rate to {}'.format(lr)) 25 | 26 | 27 | class EarlyStopping: 28 | def __init__(self, patience=7, verbose=False, delta=0): 29 | self.patience = patience 30 | self.verbose = verbose 31 | self.counter = 0 32 | self.best_score = None 33 | self.early_stop = False 34 | self.val_loss_min = np.inf  # np.Inf alias was removed in NumPy 2.0; np.inf works on all versions 35 | self.delta = delta 36 | 37 | def __call__(self, val_loss, model, path): 38 | score = -val_loss 39 | if self.best_score is None: 40 | self.best_score = score 41 | self.save_checkpoint(val_loss, model, path) 42 | elif score < self.best_score + self.delta: 43 | self.counter += 1 44 | # print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | self.save_checkpoint(val_loss, model, path) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model, path): 53 | # if self.verbose: 54 | # print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 56 | self.val_loss_min = val_loss 57 | 58 | 59 | class dotdict(dict): 60 | """dot.notation access to dictionary attributes""" 61 | __getattr__ = dict.get 62 | __setattr__ = dict.__setitem__ 63 | __delattr__ = dict.__delitem__ 64 | 65 | 66 | class StandardScaler(): 67 | def __init__(self, mean, std): 68 | self.mean = mean 69 | self.std = std 70 | 71 | def transform(self, data): 72 | return (data - self.mean) / self.std 73 | 74 | def inverse_transform(self, data): 75 | return (data * self.std) + self.mean 76 | 77 | 78 | def visual(true, preds=None, name='./pic/test.pdf'): 79 | """ 80 | Results visualization 81 | """ 82 | plt.figure() 83 | plt.plot(true, label='GroundTruth', linewidth=2) 84 | if preds is not None: 85 | plt.plot(preds, label='Prediction', linewidth=2) 86 | plt.legend() 87 | plt.savefig(name, bbox_inches='tight') 88 | plt.close()  # release the figure opened above; pyplot otherwise keeps one open figure per call 89 | 90 | def adjustment(gt, pred): 91 | anomaly_state = False 92 | for i in range(len(gt)): 93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 94 | anomaly_state = True 95 | for j in range(i, 0, -1): 96 | if gt[j] == 0: 97 | break 98 | else: 99 | if pred[j] == 0: 100 | pred[j] = 1 101 | for j in range(i, len(gt)): 102 | if gt[j] == 0: 103 | break 104 | else: 105 | if pred[j] == 0: 106 | pred[j] = 1 107 | elif gt[i] == 0: 108 | anomaly_state = False 109 | if anomaly_state: 110 | pred[i] = 1 111 | return gt, pred 112 | 113 | 114 | def cal_accuracy(y_pred, y_true): 115 | return np.mean(y_pred == y_true) -------------------------------------------------------------------------------- /FedSpecNet_v2/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | plt.switch_backend('agg') 9 | 10 | 11 | def 
adjust_learning_rate(optimizer, epoch, args): 12 | # lr = args.learning_rate * (0.2 ** 
(epoch // 2)) 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = { 17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 18 | 10: 5e-7, 15: 1e-7, 20: 5e-8 19 | } 20 | if epoch in lr_adjust.keys(): 21 | lr = lr_adjust[epoch] 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | # print('Updating learning rate to {}'.format(lr)) 25 | 26 | 27 | class EarlyStopping: 28 | def __init__(self, patience=7, verbose=False, delta=0): 29 | self.patience = patience 30 | self.verbose = verbose 31 | self.counter = 0 32 | self.best_score = None 33 | self.early_stop = False 34 | self.val_loss_min = np.inf  # np.Inf alias was removed in NumPy 2.0; np.inf works on all versions 35 | self.delta = delta 36 | 37 | def __call__(self, val_loss, model, path): 38 | score = -val_loss 39 | if self.best_score is None: 40 | self.best_score = score 41 | self.save_checkpoint(val_loss, model, path) 42 | elif score < self.best_score + self.delta: 43 | self.counter += 1 44 | # print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | self.save_checkpoint(val_loss, model, path) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model, path): 53 | # if self.verbose: 54 | # print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 56 | self.val_loss_min = val_loss 57 | 58 | 59 | class dotdict(dict): 60 | """dot.notation access to dictionary attributes""" 61 | __getattr__ = dict.get 62 | __setattr__ = dict.__setitem__ 63 | __delattr__ = dict.__delitem__ 64 | 65 | 66 | class StandardScaler(): 67 | def __init__(self, mean, std): 68 | self.mean = mean 69 | self.std = std 70 | 71 | def transform(self, data): 72 | return (data - self.mean) / self.std 73 | 74 | def inverse_transform(self, data): 75 | return (data * self.std) + self.mean 76 | 77 | 78 | def visual(true, preds=None, name='./pic/test.pdf'): 79 | """ 80 | Results visualization 81 | """ 82 | plt.figure() 83 | plt.plot(true, label='GroundTruth', linewidth=2) 84 | if preds is not None: 85 | plt.plot(preds, label='Prediction', linewidth=2) 86 | plt.legend() 87 | plt.savefig(name, bbox_inches='tight') 88 | plt.close()  # release the figure opened above; pyplot otherwise keeps one open figure per call 89 | 90 | def adjustment(gt, pred): 91 | anomaly_state = False 92 | for i in range(len(gt)): 93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 94 | anomaly_state = True 95 | for j in range(i, 0, -1): 96 | if gt[j] == 0: 97 | break 98 | else: 99 | if pred[j] == 0: 100 | pred[j] = 1 101 | for j in range(i, len(gt)): 102 | if gt[j] == 0: 103 | break 104 | else: 105 | if pred[j] == 0: 106 | pred[j] = 1 107 | elif gt[i] == 0: 108 | anomaly_state = False 109 | if anomaly_state: 110 | pred[i] = 1 111 | return gt, pred 112 | 113 | 114 | def cal_accuracy(y_pred, y_true): 115 | return np.mean(y_pred == y_true) -------------------------------------------------------------------------------- /FedSpecNet/utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 
5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Second of minute encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Day of week encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | 
"""Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | 
W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /FedSpecNet_v2/utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 
14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Second of minute encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Day of week encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded 
as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /FedSpecNet/data_provider/m4.py: 
-------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Dataset 17 | """ 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from dataclasses import dataclass 22 | from glob import glob 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import patoolib 27 | from tqdm import tqdm 28 | import logging 29 | import os 30 | import pathlib 31 | import sys 32 | from urllib import request 33 | 34 | 35 | def url_file_name(url: str) -> str: 36 | """ 37 | Extract file name from url. 38 | 39 | :param url: URL to extract file name from. 40 | :return: File name. 41 | """ 42 | return url.split('/')[-1] if len(url) > 0 else '' 43 | 44 | 45 | def download(url: str, file_path: str) -> None: 46 | """ 47 | Download a file to the given path. 48 | 49 | :param url: URL to download 50 | :param file_path: Where to download the content. 
51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if training is True, test part otherwise. 
87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info(info_file_path: str = '../dataset/m4/M4-info.csv') -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | :param info_file_path: Path to M4-info.csv; the default matches M4Dataset.load's default dataset_file directory. 136 | :return: Pandas DataFrame of M4Info. 137 | """ 138 | return pd.read_csv(info_file_path)  # previously read the undefined global INFO_FILE_PATH (NameError on every call) 139 | -------------------------------------------------------------------------------- /FedSpecNet_v2/data_provider/m4.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. 
The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). 
The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Dataset 17 | """ 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from dataclasses import dataclass 22 | from glob import glob 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import patoolib 27 | from tqdm import tqdm 28 | import logging 29 | import os 30 | import pathlib 31 | import sys 32 | from urllib import request 33 | 34 | 35 | def url_file_name(url: str) -> str: 36 | """ 37 | Extract file name from url. 38 | 39 | :param url: URL to extract file name from. 40 | :return: File name. 41 | """ 42 | return url.split('/')[-1] if len(url) > 0 else '' 43 | 44 | 45 | def download(url: str, file_path: str) -> None: 46 | """ 47 | Download a file to the given path. 48 | 49 | :param url: URL to download 50 | :param file_path: Where to download the content. 
51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if training is True, test part otherwise. 
87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info() -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | 136 | :return: Pandas DataFrame of M4Info. 
137 | """ 138 | return pd.read_csv(INFO_FILE_PATH) 139 | -------------------------------------------------------------------------------- /FedSpecNet/models/iTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Paper link: https://arxiv.org/abs/2310.06625 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | # self.interval = configs.interval ######### 18 | self.task_name = configs.task_name 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | # Encoder 25 | self.encoder = Encoder( 26 | [ 27 | EncoderLayer( 28 | AttentionLayer( 29 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 30 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | # Decoder 40 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 41 | self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True) 42 | 43 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 44 | means = x_enc.mean(1, keepdim=True).detach() 45 | x_enc = x_enc - means 46 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 47 | x_enc /= stdev 48 | 49 | _, _, N = 
x_enc.shape 50 | 51 | # Embedding 52 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 53 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 54 | 55 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 56 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 57 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 58 | return dec_out 59 | 60 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 61 | # Normalization from Non-stationary Transformer 62 | means = x_enc.mean(1, keepdim=True).detach() 63 | x_enc = x_enc - means 64 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 65 | x_enc /= stdev 66 | 67 | _, L, N = x_enc.shape 68 | 69 | # Embedding 70 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 71 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 72 | 73 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 74 | # De-Normalization from Non-stationary Transformer 75 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 76 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 77 | return dec_out 78 | 79 | def anomaly_detection(self, x_enc): 80 | # Normalization from Non-stationary Transformer 81 | means = x_enc.mean(1, keepdim=True).detach() 82 | x_enc = x_enc - means 83 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 84 | x_enc /= stdev 85 | 86 | _, L, N = x_enc.shape 87 | 88 | # Embedding 89 | enc_out = self.enc_embedding(x_enc, None) 90 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 91 | 92 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 93 | # De-Normalization from Non-stationary Transformer 94 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 95 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 96 | return dec_out 97 | 98 | def classification(self, x_enc, x_mark_enc): 99 | # Embedding 100 | enc_out = 
self.enc_embedding(x_enc, None) 101 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 102 | 103 | # Output 104 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 105 | output = self.dropout(output) 106 | output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model) 107 | output = self.projection(output) # (batch_size, num_classes) 108 | return output 109 | 110 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 111 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 112 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 113 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 114 | return None 115 | -------------------------------------------------------------------------------- /FedSpecNet_v2/models/iTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Paper link: https://arxiv.org/abs/2310.06625 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | # self.interval = configs.interval ######### 18 | self.task_name = configs.task_name 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | # Encoder 25 | self.encoder = Encoder( 26 | [ 27 | EncoderLayer( 28 | AttentionLayer( 29 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 30 | output_attention=configs.output_attention), 
configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | # Decoder 40 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 41 | self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True) 42 | 43 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 44 | means = x_enc.mean(1, keepdim=True).detach() 45 | x_enc = x_enc - means 46 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 47 | x_enc /= stdev 48 | 49 | _, _, N = x_enc.shape 50 | 51 | # Embedding 52 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 53 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 54 | 55 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 56 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 57 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 58 | return dec_out 59 | 60 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 61 | # Normalization from Non-stationary Transformer 62 | means = x_enc.mean(1, keepdim=True).detach() 63 | x_enc = x_enc - means 64 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 65 | x_enc /= stdev 66 | 67 | _, L, N = x_enc.shape 68 | 69 | # Embedding 70 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 71 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 72 | 73 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 74 | # De-Normalization from Non-stationary Transformer 75 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 76 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 77 | return dec_out 78 | 79 | def anomaly_detection(self, x_enc): 80 | # Normalization from Non-stationary Transformer 81 | means = 
x_enc.mean(1, keepdim=True).detach() 82 | x_enc = x_enc - means 83 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 84 | x_enc /= stdev 85 | 86 | _, L, N = x_enc.shape 87 | 88 | # Embedding 89 | enc_out = self.enc_embedding(x_enc, None) 90 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 91 | 92 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 93 | # De-Normalization from Non-stationary Transformer 94 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 95 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 96 | return dec_out 97 | 98 | def classification(self, x_enc, x_mark_enc): 99 | # Embedding 100 | enc_out = self.enc_embedding(x_enc, None) 101 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 102 | 103 | # Output 104 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 105 | output = self.dropout(output) 106 | output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model) 107 | output = self.projection(output) # (batch_size, num_classes) 108 | return output 109 | 110 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 111 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 112 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 113 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 114 | return None 115 | -------------------------------------------------------------------------------- /FedSpecNet/data_provider/uea.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | 7 | def collate_fn(data, max_len=None): 8 | """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create 9 | Args: 10 | data: len(batch_size) list of tuples (X, y). 
11 | - X: torch tensor of shape (seq_length, feat_dim); variable seq_length. 12 | - y: torch tensor of shape (num_labels,) : class indices or numerical targets 13 | (for classification or regression, respectively). num_labels > 1 for multi-task models 14 | max_len: global fixed sequence length. Used for architectures requiring fixed length input, 15 | where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s 16 | Returns: 17 | X: (batch_size, padded_length, feat_dim) torch tensor of masked features (input) 18 | targets: (batch_size, padded_length, feat_dim) torch tensor of unmasked features (output) 19 | target_masks: (batch_size, padded_length, feat_dim) boolean torch tensor 20 | 0 indicates masked values to be predicted, 1 indicates unaffected/"active" feature values 21 | padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding 22 | """ 23 | 24 | batch_size = len(data) 25 | features, labels = zip(*data) 26 | 27 | # Stack and pad features and masks (convert 2D to 3D tensors, i.e. 
add batch dimension) 28 | lengths = [X.shape[0] for X in features] # original sequence length for each time series 29 | if max_len is None: 30 | max_len = max(lengths) 31 | 32 | X = torch.zeros(batch_size, max_len, features[0].shape[-1]) # (batch_size, padded_length, feat_dim) 33 | for i in range(batch_size): 34 | end = min(lengths[i], max_len) 35 | X[i, :end, :] = features[i][:end, :] 36 | 37 | targets = torch.stack(labels, dim=0) # (batch_size, num_labels) 38 | 39 | padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16), 40 | max_len=max_len) # (batch_size, padded_length) boolean tensor, "1" means keep 41 | 42 | return X, targets, padding_masks 43 | 44 | 45 | def padding_mask(lengths, max_len=None): 46 | """ 47 | Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths, 48 | where 1 means keep element at this position (time step) 49 | """ 50 | batch_size = lengths.numel() 51 | max_len = max_len or lengths.max_val() # trick works because of overloading of 'or' operator for non-boolean types 52 | return (torch.arange(0, max_len, device=lengths.device) 53 | .type_as(lengths) 54 | .repeat(batch_size, 1) 55 | .lt(lengths.unsqueeze(1))) 56 | 57 | 58 | class Normalizer(object): 59 | """ 60 | Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization. 61 | """ 62 | 63 | def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None): 64 | """ 65 | Args: 66 | norm_type: choose from: 67 | "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps) 68 | "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. 
across only its own rows) 69 | mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values 70 | """ 71 | 72 | self.norm_type = norm_type 73 | self.mean = mean 74 | self.std = std 75 | self.min_val = min_val 76 | self.max_val = max_val 77 | 78 | def normalize(self, df): 79 | """ 80 | Args: 81 | df: input dataframe 82 | Returns: 83 | df: normalized dataframe 84 | """ 85 | if self.norm_type == "standardization": 86 | if self.mean is None: 87 | self.mean = df.mean() 88 | self.std = df.std() 89 | return (df - self.mean) / (self.std + np.finfo(float).eps) 90 | 91 | elif self.norm_type == "minmax": 92 | if self.max_val is None: 93 | self.max_val = df.max() 94 | self.min_val = df.min() 95 | return (df - self.min_val) / (self.max_val - self.min_val + np.finfo(float).eps) 96 | 97 | elif self.norm_type == "per_sample_std": 98 | grouped = df.groupby(by=df.index) 99 | return (df - grouped.transform('mean')) / grouped.transform('std') 100 | 101 | elif self.norm_type == "per_sample_minmax": 102 | grouped = df.groupby(by=df.index) 103 | min_vals = grouped.transform('min') 104 | return (df - min_vals) / (grouped.transform('max') - min_vals + np.finfo(float).eps) 105 | 106 | else: 107 | raise (NameError(f'Normalize method "{self.norm_type}" not implemented')) 108 | 109 | 110 | def interpolate_missing(y): 111 | """ 112 | Replaces NaN values in pd.Series `y` using linear interpolation 113 | """ 114 | if y.isna().any(): 115 | y = y.interpolate(method='linear', limit_direction='both') 116 | return y 117 | 118 | 119 | def subsample(y, limit=256, factor=2): 120 | """ 121 | If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor 122 | """ 123 | if len(y) > limit: 124 | return y[::factor].reset_index(drop=True) 125 | return y 126 | -------------------------------------------------------------------------------- /FedSpecNet_v2/data_provider/uea.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | 7 | def collate_fn(data, max_len=None): 8 | """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create 9 | Args: 10 | data: len(batch_size) list of tuples (X, y). 11 | - X: torch tensor of shape (seq_length, feat_dim); variable seq_length. 12 | - y: torch tensor of shape (num_labels,) : class indices or numerical targets 13 | (for classification or regression, respectively). num_labels > 1 for multi-task models 14 | max_len: global fixed sequence length. Used for architectures requiring fixed length input, 15 | where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s 16 | Returns: 17 | X: (batch_size, padded_length, feat_dim) torch tensor of masked features (input) 18 | targets: (batch_size, padded_length, feat_dim) torch tensor of unmasked features (output) 19 | target_masks: (batch_size, padded_length, feat_dim) boolean torch tensor 20 | 0 indicates masked values to be predicted, 1 indicates unaffected/"active" feature values 21 | padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding 22 | """ 23 | 24 | batch_size = len(data) 25 | features, labels = zip(*data) 26 | 27 | # Stack and pad features and masks (convert 2D to 3D tensors, i.e. 
add batch dimension) 28 | lengths = [X.shape[0] for X in features] # original sequence length for each time series 29 | if max_len is None: 30 | max_len = max(lengths) 31 | 32 | X = torch.zeros(batch_size, max_len, features[0].shape[-1]) # (batch_size, padded_length, feat_dim) 33 | for i in range(batch_size): 34 | end = min(lengths[i], max_len) 35 | X[i, :end, :] = features[i][:end, :] 36 | 37 | targets = torch.stack(labels, dim=0) # (batch_size, num_labels) 38 | 39 | padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16), 40 | max_len=max_len) # (batch_size, padded_length) boolean tensor, "1" means keep 41 | 42 | return X, targets, padding_masks 43 | 44 | 45 | def padding_mask(lengths, max_len=None): 46 | """ 47 | Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths, 48 | where 1 means keep element at this position (time step) 49 | """ 50 | batch_size = lengths.numel() 51 | max_len = max_len or lengths.max_val() # trick works because of overloading of 'or' operator for non-boolean types 52 | return (torch.arange(0, max_len, device=lengths.device) 53 | .type_as(lengths) 54 | .repeat(batch_size, 1) 55 | .lt(lengths.unsqueeze(1))) 56 | 57 | 58 | class Normalizer(object): 59 | """ 60 | Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization. 61 | """ 62 | 63 | def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None): 64 | """ 65 | Args: 66 | norm_type: choose from: 67 | "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps) 68 | "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. 
across only its own rows) 69 | mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values 70 | """ 71 | 72 | self.norm_type = norm_type 73 | self.mean = mean 74 | self.std = std 75 | self.min_val = min_val 76 | self.max_val = max_val 77 | 78 | def normalize(self, df): 79 | """ 80 | Args: 81 | df: input dataframe 82 | Returns: 83 | df: normalized dataframe 84 | """ 85 | if self.norm_type == "standardization": 86 | if self.mean is None: 87 | self.mean = df.mean() 88 | self.std = df.std() 89 | return (df - self.mean) / (self.std + np.finfo(float).eps) 90 | 91 | elif self.norm_type == "minmax": 92 | if self.max_val is None: 93 | self.max_val = df.max() 94 | self.min_val = df.min() 95 | return (df - self.min_val) / (self.max_val - self.min_val + np.finfo(float).eps) 96 | 97 | elif self.norm_type == "per_sample_std": 98 | grouped = df.groupby(by=df.index) 99 | return (df - grouped.transform('mean')) / grouped.transform('std') 100 | 101 | elif self.norm_type == "per_sample_minmax": 102 | grouped = df.groupby(by=df.index) 103 | min_vals = grouped.transform('min') 104 | return (df - min_vals) / (grouped.transform('max') - min_vals + np.finfo(float).eps) 105 | 106 | else: 107 | raise (NameError(f'Normalize method "{self.norm_type}" not implemented')) 108 | 109 | 110 | def interpolate_missing(y): 111 | """ 112 | Replaces NaN values in pd.Series `y` using linear interpolation 113 | """ 114 | if y.isna().any(): 115 | y = y.interpolate(method='linear', limit_direction='both') 116 | return y 117 | 118 | 119 | def subsample(y, limit=256, factor=2): 120 | """ 121 | If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor 122 | """ 123 | if len(y) > limit: 124 | return y[::factor].reset_index(drop=True) 125 | return y 126 | -------------------------------------------------------------------------------- /FedSpecNet/layers/Transformer_EncDec.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, c_in): 9 | super(ConvLayer, self).__init__() 10 | self.downConv = nn.Conv1d(in_channels=c_in, 11 | out_channels=c_in, 12 | kernel_size=3, 13 | padding=2, 14 | padding_mode='circular').to(device) 15 | self.norm = nn.BatchNorm1d(c_in).to(device) 16 | self.activation = nn.ELU() 17 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1).to(device) 18 | 19 | def forward(self, x): 20 | x = self.downConv(x.permute(0, 2, 1)) 21 | x = self.norm(x) 22 | x = self.activation(x) 23 | x = self.maxPool(x) 24 | x = x.transpose(1, 2) 25 | return x 26 | 27 | 28 | class EncoderLayer(nn.Module): 29 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 30 | super(EncoderLayer, self).__init__() 31 | d_ff = d_ff or 4 * d_model 32 | self.attention = attention 33 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1).to(device) 34 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1).to(device) 35 | self.norm1 = nn.LayerNorm(d_model).to(device) 36 | self.norm2 = nn.LayerNorm(d_model).to(device) 37 | self.dropout = nn.Dropout(dropout) 38 | self.activation = F.relu if activation == "relu" else F.gelu 39 | 40 | def forward(self, x, attn_mask=None, tau=None, delta=None): 41 | new_x, attn = self.attention( #x同时作为qkv 42 | x, x, x, 43 | attn_mask=attn_mask, 44 | tau=tau, delta=delta 45 | ) 46 | x = x + self.dropout(new_x) 47 | 48 | y = x = self.norm1(x) 49 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 50 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 51 | 52 | return self.norm2(x + y), attn 53 | 54 | 55 | class Encoder(nn.Module): 56 | def __init__(self, attn_layers, conv_layers=None, 
norm_layer=None): 57 | super(Encoder, self).__init__() 58 | self.attn_layers = nn.ModuleList([layer.to(device) for layer in attn_layers]) 59 | self.conv_layers = nn.ModuleList([layer.to(device) for layer in conv_layers]) if conv_layers is not None else None 60 | self.norm = norm_layer.to(device) if norm_layer is not None else None 61 | 62 | def forward(self, x, attn_mask=None, tau=None, delta=None): 63 | x = x.to(device) # Ensure x is on the correct device 64 | attns = [] 65 | if self.conv_layers is not None: 66 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 67 | delta = delta if i == 0 else None 68 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 69 | x = conv_layer(x) 70 | attns.append(attn) 71 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 72 | attns.append(attn) 73 | else: 74 | for attn_layer in self.attn_layers: 75 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 76 | attns.append(attn) 77 | 78 | if self.norm is not None: 79 | x = self.norm(x.to(device)) 80 | 81 | return x, attns 82 | 83 | class DecoderLayer(nn.Module): 84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 85 | dropout=0.1, activation="relu"): 86 | super(DecoderLayer, self).__init__() 87 | d_ff = d_ff or 4 * d_model 88 | self.self_attention = self_attention.to(device) 89 | self.cross_attention = cross_attention.to(device) 90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1).to(device) 91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1).to(device) 92 | self.norm1 = nn.LayerNorm(d_model).to(device) 93 | self.norm2 = nn.LayerNorm(d_model).to(device) 94 | self.norm3 = nn.LayerNorm(d_model).to(device) 95 | self.dropout = nn.Dropout(dropout) 96 | self.activation = F.relu if activation == "relu" else F.gelu 97 | 98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 99 | x = x.to(device) 100 | 
cross = cross.to(device) 101 | 102 | x = x + self.dropout(self.self_attention( 103 | x, x, x, 104 | attn_mask=x_mask, 105 | tau=tau, delta=None 106 | )[0]) 107 | x = self.norm1(x) 108 | 109 | x = x + self.dropout(self.cross_attention( 110 | x, cross, cross, 111 | attn_mask=cross_mask, 112 | tau=tau, delta=delta 113 | )[0]) 114 | 115 | y = x = self.norm2(x) 116 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 117 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 118 | 119 | return self.norm3(x + y) 120 | 121 | 122 | class Decoder(nn.Module): 123 | def __init__(self, layers, norm_layer=None, projection=None): 124 | super(Decoder, self).__init__() 125 | self.layers = nn.ModuleList([layer.to(device) for layer in layers]) 126 | self.norm = norm_layer.to(device) if norm_layer is not None else None 127 | self.projection = projection.to(device) if projection is not None else None 128 | 129 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 130 | x = x.to(device) 131 | cross = cross.to(device) 132 | 133 | for layer in self.layers: 134 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 135 | 136 | if self.norm is not None: 137 | x = self.norm(x) 138 | 139 | if self.projection is not None: 140 | x = self.projection(x) 141 | return x 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /FedSpecNet_v2/layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, c_in): 9 | super(ConvLayer, self).__init__() 10 | self.downConv = nn.Conv1d(in_channels=c_in, 11 | out_channels=c_in, 12 | kernel_size=3, 13 | padding=2, 14 | padding_mode='circular').to(device) 15 | 
self.norm = nn.BatchNorm1d(c_in).to(device) 16 | self.activation = nn.ELU() 17 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1).to(device) 18 | 19 | def forward(self, x): 20 | x = self.downConv(x.permute(0, 2, 1)) 21 | x = self.norm(x) 22 | x = self.activation(x) 23 | x = self.maxPool(x) 24 | x = x.transpose(1, 2) 25 | return x 26 | 27 | 28 | class EncoderLayer(nn.Module): 29 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 30 | super(EncoderLayer, self).__init__() 31 | d_ff = d_ff or 4 * d_model 32 | self.attention = attention 33 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1).to(device) 34 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1).to(device) 35 | self.norm1 = nn.LayerNorm(d_model).to(device) 36 | self.norm2 = nn.LayerNorm(d_model).to(device) 37 | self.dropout = nn.Dropout(dropout) 38 | self.activation = F.relu if activation == "relu" else F.gelu 39 | 40 | def forward(self, x, attn_mask=None, tau=None, delta=None): 41 | new_x, attn = self.attention( #x同时作为qkv 42 | x, x, x, 43 | attn_mask=attn_mask, 44 | tau=tau, delta=delta 45 | ) 46 | x = x + self.dropout(new_x) 47 | 48 | y = x = self.norm1(x) 49 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 50 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 51 | 52 | return self.norm2(x + y), attn 53 | 54 | 55 | class Encoder(nn.Module): 56 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 57 | super(Encoder, self).__init__() 58 | self.attn_layers = nn.ModuleList([layer.to(device) for layer in attn_layers]) 59 | self.conv_layers = nn.ModuleList([layer.to(device) for layer in conv_layers]) if conv_layers is not None else None 60 | self.norm = norm_layer.to(device) if norm_layer is not None else None 61 | 62 | def forward(self, x, attn_mask=None, tau=None, delta=None): 63 | x = x.to(device) # Ensure x is on the correct device 64 | attns = [] 65 | if self.conv_layers 
is not None: 66 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 67 | delta = delta if i == 0 else None 68 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 69 | x = conv_layer(x) 70 | attns.append(attn) 71 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 72 | attns.append(attn) 73 | else: 74 | for attn_layer in self.attn_layers: 75 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 76 | attns.append(attn) 77 | 78 | if self.norm is not None: 79 | x = self.norm(x.to(device)) 80 | 81 | return x, attns 82 | 83 | class DecoderLayer(nn.Module): 84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 85 | dropout=0.1, activation="relu"): 86 | super(DecoderLayer, self).__init__() 87 | d_ff = d_ff or 4 * d_model 88 | self.self_attention = self_attention.to(device) 89 | self.cross_attention = cross_attention.to(device) 90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1).to(device) 91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1).to(device) 92 | self.norm1 = nn.LayerNorm(d_model).to(device) 93 | self.norm2 = nn.LayerNorm(d_model).to(device) 94 | self.norm3 = nn.LayerNorm(d_model).to(device) 95 | self.dropout = nn.Dropout(dropout) 96 | self.activation = F.relu if activation == "relu" else F.gelu 97 | 98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 99 | x = x.to(device) 100 | cross = cross.to(device) 101 | 102 | x = x + self.dropout(self.self_attention( 103 | x, x, x, 104 | attn_mask=x_mask, 105 | tau=tau, delta=None 106 | )[0]) 107 | x = self.norm1(x) 108 | 109 | x = x + self.dropout(self.cross_attention( 110 | x, cross, cross, 111 | attn_mask=cross_mask, 112 | tau=tau, delta=delta 113 | )[0]) 114 | 115 | y = x = self.norm2(x) 116 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 117 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 118 | 119 | return 
self.norm3(x + y) 120 | 121 | 122 | class Decoder(nn.Module): 123 | def __init__(self, layers, norm_layer=None, projection=None): 124 | super(Decoder, self).__init__() 125 | self.layers = nn.ModuleList([layer.to(device) for layer in layers]) 126 | self.norm = norm_layer.to(device) if norm_layer is not None else None 127 | self.projection = projection.to(device) if projection is not None else None 128 | 129 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 130 | x = x.to(device) 131 | cross = cross.to(device) 132 | 133 | for layer in self.layers: 134 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 135 | 136 | if self.norm is not None: 137 | x = self.norm(x) 138 | 139 | if self.projection is not None: 140 | x = self.projection(x) 141 | return x 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /FedSpecNet/_Support/support_VMD.py: -------------------------------------------------------------------------------- 1 | def VMD(signal, K): 2 | # --------------------- 3 | # signal - the time domain signal (1D) to be decomposed 4 | # alpha - the balancing parameter of the data-fidelity constraint 5 | # tau - time-step of the dual ascent ( pick 0 for noise-slack ) 6 | # K - the number of modes to be recovered 7 | # DC - true if the first mode is put and kept at DC (0-freq) 8 | # init - 0 = all omegas start at 0 9 | # 1 = all omegas start uniformly distributed 10 | # 2 = all omegas initialized randomly 11 | # tol - tolerance of convergence criterion; typically around 1e-6 12 | # 13 | # Output: 14 | # ------- 15 | # u - the collection of decomposed modes 16 | # u_hat - spectra of the modes 17 | # omega - estimated mode center-frequencies 18 | # 19 | 20 | alpha = 2000.0 21 | tau = 0 22 | DC = 0 23 | init = 1 24 | tol = 1e-7 25 | 26 | import numpy as np 27 | import math 28 | import matplotlib.pyplot as plt 29 | # Period and sampling frequency 
of input signal 30 | save_T=len(signal) 31 | fs=1/float(save_T) 32 | 33 | # extend the signal by mirroring 34 | T=save_T 35 | # print(T) 36 | f_mirror=np.zeros(2*T) 37 | # print(f_mirror) 38 | f_mirror[0:int(T/2)]=signal[int(-(T/2)-1)::-1] 39 | # print(f_mirror) 40 | f_mirror[int(T/2):int(3*T/2)]= signal 41 | # print(f_mirror) 42 | f_mirror[int(3*T/2):int(2*T)]=signal[-1:int(-T/2-1):-1] 43 | # print(f_mirror) 44 | f=f_mirror 45 | # print('f_mirror') 46 | # print(f_mirror) 47 | # print('-------') 48 | 49 | # Time Domain 0 to T (of mirrored signal) 50 | T=float(len(f)) 51 | # print(T) 52 | t=np.linspace(1/float(T),1,int(T),endpoint=True) 53 | # print(t) 54 | 55 | # Spectral Domain discretization 56 | freqs=t-0.5-1/T 57 | # print(freqs) 58 | # print('-----') 59 | # Maximum number of iterations (if not converged yet, then it won't anyway) 60 | N=500 61 | 62 | # For future generalizations: individual alpha for each mode 63 | Alpha=alpha*np.ones(K,dtype=complex) 64 | # print(Alpha.shape) 65 | # print(Alpha) 66 | # print('-----') 67 | 68 | # Construct and center f_hat 69 | f_hat=np.fft.fftshift(np.fft.fft(f)) 70 | # print('f_hat') 71 | # print(f_hat.shape) 72 | # print(f_hat) 73 | # print('-----') 74 | f_hat_plus=f_hat 75 | f_hat_plus[0:int(T/2)]=0 76 | # print('f_hat_plus') 77 | # print(f_hat_plus.shape) 78 | # print(f_hat_plus) 79 | # print('-----') 80 | # matrix keeping track of every iterant // could be discarded for mem 81 | u_hat_plus=np.zeros((N,len(freqs),K),dtype=complex) 82 | # print('u_hat_plus') 83 | # print(u_hat_plus.shape) 84 | # print(u_hat_plus) 85 | # print('-----') 86 | 87 | 88 | # Initialization of omega_k 89 | omega_plus=np.zeros((N,K),dtype=complex) 90 | # print('omega_plus') 91 | # print(omega_plus.shape) 92 | # print(omega_plus) 93 | 94 | if (init==1): 95 | for i in range(1,K+1): 96 | omega_plus[0,i-1]=(0.5/K)*(i-1) 97 | elif (init==2): 98 | omega_plus[0,:]=np.sort(math.exp(math.log(fs))+(math.log(0.5)-math.log(fs))*np.random.rand(1,K)) 99 | else: 
100 | omega_plus[0,:]=0 101 | 102 | if (DC): 103 | omega_plus[0,0]=0 104 | 105 | # print('omega_plus') 106 | # print(omega_plus.shape) 107 | # print(omega_plus) 108 | 109 | # start with empty dual variables 110 | lamda_hat=np.zeros((N,len(freqs)),dtype=complex) 111 | 112 | # other inits 113 | uDiff=tol+2.2204e-16 #updata step 114 | # print('uDiff') 115 | # print(uDiff) 116 | # print('----') 117 | n=1 #loop counter 118 | sum_uk=0 #accumulator 119 | 120 | T=int(T) 121 | 122 | 123 | # ----------- Main loop for iterative updates 124 | 125 | while uDiff > tol and n tol and n= '1.5.0' else 2 36 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 37 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv1d): 40 | nn.init.kaiming_normal_( 41 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 42 | 43 | def forward(self, x): 44 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 45 | return x 46 | 47 | 48 | class FixedEmbedding(nn.Module): 49 | def __init__(self, c_in, d_model): 50 | super(FixedEmbedding, self).__init__() 51 | 52 | w = torch.zeros(c_in, d_model).float() 53 | w.require_grad = False 54 | 55 | position = torch.arange(0, c_in).float().unsqueeze(1) 56 | div_term = (torch.arange(0, d_model, 2).float() 57 | * -(math.log(10000.0) / d_model)).exp() 58 | 59 | w[:, 0::2] = torch.sin(position * div_term) 60 | w[:, 1::2] = torch.cos(position * div_term) 61 | 62 | self.emb = nn.Embedding(c_in, d_model) 63 | self.emb.weight = nn.Parameter(w, requires_grad=False) 64 | 65 | def forward(self, x): 66 | return self.emb(x).detach() 67 | 68 | 69 | class TemporalEmbedding(nn.Module): 70 | def __init__(self, d_model, embed_type='fixed', freq='h'): 71 | super(TemporalEmbedding, self).__init__() 72 | 73 | minute_size = 4 74 | hour_size = 24 75 | weekday_size = 7 76 | day_size = 32 77 | month_size = 13 78 | 79 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 80 | if 
freq == 't': 81 | self.minute_embed = Embed(minute_size, d_model) 82 | self.hour_embed = Embed(hour_size, d_model) 83 | self.weekday_embed = Embed(weekday_size, d_model) 84 | self.day_embed = Embed(day_size, d_model) 85 | self.month_embed = Embed(month_size, d_model) 86 | 87 | def forward(self, x): 88 | x = x.long() 89 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 90 | self, 'minute_embed') else 0. 91 | hour_x = self.hour_embed(x[:, :, 3]) 92 | weekday_x = self.weekday_embed(x[:, :, 2]) 93 | day_x = self.day_embed(x[:, :, 1]) 94 | month_x = self.month_embed(x[:, :, 0]) 95 | 96 | return hour_x + weekday_x + day_x + month_x + minute_x 97 | 98 | 99 | class TimeFeatureEmbedding(nn.Module): 100 | def __init__(self, d_model, embed_type='timeF', freq='h'): 101 | super(TimeFeatureEmbedding, self).__init__() 102 | 103 | freq_map = {'h': 4, 't': 5, 's': 6, 104 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 105 | d_inp = freq_map[freq] 106 | self.embed = nn.Linear(d_inp, d_model, bias=False) 107 | 108 | def forward(self, x): 109 | return self.embed(x) 110 | 111 | 112 | class DataEmbedding(nn.Module): 113 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 114 | super(DataEmbedding, self).__init__() 115 | 116 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 117 | self.position_embedding = PositionalEmbedding(d_model=d_model) 118 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 119 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 120 | d_model=d_model, embed_type=embed_type, freq=freq) 121 | self.dropout = nn.Dropout(p=dropout) 122 | 123 | def forward(self, x, x_mark): 124 | if x_mark is None: 125 | x = self.value_embedding(x) + self.position_embedding(x) 126 | else: 127 | x = self.value_embedding( 128 | x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 129 | return self.dropout(x) 130 | 131 | 132 | class DataEmbedding_inverted(nn.Module): 133 | def 
__init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")): 134 | super(DataEmbedding_inverted, self).__init__() 135 | self.device = device 136 | self.value_embedding = nn.Linear(c_in, d_model) 137 | self.value_embedding = self.value_embedding.to(self.device) 138 | self.dropout = nn.Dropout(p=dropout) 139 | 140 | def forward(self, x, x_mark): 141 | x = x.to(self.device) 142 | x_mark = x_mark.to(self.device) 143 | 144 | x = x.permute(0, 2, 1) 145 | # x: [Batch Variate Time] 146 | if x_mark is None: 147 | x = self.value_embedding(x) 148 | else: 149 | device = x.device # 确保它们在同一设备上 150 | x = x.to(device) 151 | x_mark = x_mark.to(device) 152 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 153 | 154 | # x: [Batch Variate d_model] 155 | return self.dropout(x) 156 | 157 | 158 | class DataEmbedding_wo_pos(nn.Module): 159 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 160 | super(DataEmbedding_wo_pos, self).__init__() 161 | 162 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 163 | self.position_embedding = PositionalEmbedding(d_model=d_model) 164 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 165 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 166 | d_model=d_model, embed_type=embed_type, freq=freq) 167 | self.dropout = nn.Dropout(p=dropout) 168 | 169 | def forward(self, x, x_mark): 170 | if x_mark is None: 171 | x = self.value_embedding(x) 172 | else: 173 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 174 | return self.dropout(x) 175 | 176 | 177 | class PatchEmbedding(nn.Module): 178 | def __init__(self, d_model, patch_len, stride, padding, dropout): 179 | super(PatchEmbedding, self).__init__() 180 | # Patching 181 | self.patch_len = patch_len 182 | self.stride = stride 183 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 184 | 
185 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 186 | self.value_embedding = nn.Linear(patch_len, d_model, bias=False) 187 | 188 | # Positional embedding 189 | self.position_embedding = PositionalEmbedding(d_model) 190 | 191 | # Residual dropout 192 | self.dropout = nn.Dropout(dropout) 193 | 194 | def forward(self, x): 195 | # do patching 196 | n_vars = x.shape[1] 197 | x = self.padding_patch_layer(x) 198 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 199 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 200 | # Input encoding 201 | x = self.value_embedding(x) + self.position_embedding(x) 202 | return self.dropout(x), n_vars 203 | -------------------------------------------------------------------------------- /FedSpecNet_v2/layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | import math 6 | 7 | # 设置默认设备为 cuda:1 8 | torch.cuda.set_device(0) 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | class PositionalEmbedding(nn.Module): 12 | def __init__(self, d_model, max_len=5000): 13 | super(PositionalEmbedding, self).__init__() 14 | # Compute the positional encodings once in log space. 
15 | pe = torch.zeros(max_len, d_model).float() 16 | pe.require_grad = False 17 | 18 | position = torch.arange(0, max_len).float().unsqueeze(1) 19 | div_term = (torch.arange(0, d_model, 2).float() 20 | * -(math.log(10000.0) / d_model)).exp() 21 | 22 | pe[:, 0::2] = torch.sin(position * div_term) 23 | pe[:, 1::2] = torch.cos(position * div_term) 24 | 25 | pe = pe.unsqueeze(0) 26 | self.register_buffer('pe', pe) 27 | 28 | def forward(self, x): 29 | return self.pe[:, :x.size(1)] 30 | 31 | 32 | class TokenEmbedding(nn.Module): 33 | def __init__(self, c_in, d_model): 34 | super(TokenEmbedding, self).__init__() 35 | padding = 1 if torch.__version__ >= '1.5.0' else 2 36 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 37 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv1d): 40 | nn.init.kaiming_normal_( 41 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 42 | 43 | def forward(self, x): 44 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 45 | return x 46 | 47 | 48 | class FixedEmbedding(nn.Module): 49 | def __init__(self, c_in, d_model): 50 | super(FixedEmbedding, self).__init__() 51 | 52 | w = torch.zeros(c_in, d_model).float() 53 | w.require_grad = False 54 | 55 | position = torch.arange(0, c_in).float().unsqueeze(1) 56 | div_term = (torch.arange(0, d_model, 2).float() 57 | * -(math.log(10000.0) / d_model)).exp() 58 | 59 | w[:, 0::2] = torch.sin(position * div_term) 60 | w[:, 1::2] = torch.cos(position * div_term) 61 | 62 | self.emb = nn.Embedding(c_in, d_model) 63 | self.emb.weight = nn.Parameter(w, requires_grad=False) 64 | 65 | def forward(self, x): 66 | return self.emb(x).detach() 67 | 68 | 69 | class TemporalEmbedding(nn.Module): 70 | def __init__(self, d_model, embed_type='fixed', freq='h'): 71 | super(TemporalEmbedding, self).__init__() 72 | 73 | minute_size = 4 74 | hour_size = 24 75 | weekday_size = 7 76 | day_size = 32 77 | month_size = 13 78 | 79 | 
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 80 | if freq == 't': 81 | self.minute_embed = Embed(minute_size, d_model) 82 | self.hour_embed = Embed(hour_size, d_model) 83 | self.weekday_embed = Embed(weekday_size, d_model) 84 | self.day_embed = Embed(day_size, d_model) 85 | self.month_embed = Embed(month_size, d_model) 86 | 87 | def forward(self, x): 88 | x = x.long() 89 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 90 | self, 'minute_embed') else 0. 91 | hour_x = self.hour_embed(x[:, :, 3]) 92 | weekday_x = self.weekday_embed(x[:, :, 2]) 93 | day_x = self.day_embed(x[:, :, 1]) 94 | month_x = self.month_embed(x[:, :, 0]) 95 | 96 | return hour_x + weekday_x + day_x + month_x + minute_x 97 | 98 | 99 | class TimeFeatureEmbedding(nn.Module): 100 | def __init__(self, d_model, embed_type='timeF', freq='h'): 101 | super(TimeFeatureEmbedding, self).__init__() 102 | 103 | freq_map = {'h': 4, 't': 5, 's': 6, 104 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 105 | d_inp = freq_map[freq] 106 | self.embed = nn.Linear(d_inp, d_model, bias=False) 107 | 108 | def forward(self, x): 109 | return self.embed(x) 110 | 111 | 112 | class DataEmbedding(nn.Module): 113 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 114 | super(DataEmbedding, self).__init__() 115 | 116 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 117 | self.position_embedding = PositionalEmbedding(d_model=d_model) 118 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 119 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 120 | d_model=d_model, embed_type=embed_type, freq=freq) 121 | self.dropout = nn.Dropout(p=dropout) 122 | 123 | def forward(self, x, x_mark): 124 | if x_mark is None: 125 | x = self.value_embedding(x) + self.position_embedding(x) 126 | else: 127 | x = self.value_embedding( 128 | x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 129 | return self.dropout(x) 130 | 
131 | 132 | class DataEmbedding_inverted(nn.Module): 133 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): 134 | super(DataEmbedding_inverted, self).__init__() 135 | self.device = device 136 | self.value_embedding = nn.Linear(c_in, d_model) 137 | self.value_embedding = self.value_embedding.to(self.device) 138 | self.dropout = nn.Dropout(p=dropout) 139 | 140 | def forward(self, x, x_mark): 141 | x = x.to(self.device) 142 | x_mark = x_mark.to(self.device) 143 | 144 | x = x.permute(0, 2, 1) 145 | # x: [Batch Variate Time] 146 | if x_mark is None: 147 | x = self.value_embedding(x) 148 | else: 149 | device = x.device # 确保它们在同一设备上 150 | x = x.to(device) 151 | x_mark = x_mark.to(device) 152 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 153 | 154 | # x: [Batch Variate d_model] 155 | return self.dropout(x) 156 | 157 | 158 | class DataEmbedding_wo_pos(nn.Module): 159 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 160 | super(DataEmbedding_wo_pos, self).__init__() 161 | 162 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 163 | self.position_embedding = PositionalEmbedding(d_model=d_model) 164 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 165 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 166 | d_model=d_model, embed_type=embed_type, freq=freq) 167 | self.dropout = nn.Dropout(p=dropout) 168 | 169 | def forward(self, x, x_mark): 170 | if x_mark is None: 171 | x = self.value_embedding(x) 172 | else: 173 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 174 | return self.dropout(x) 175 | 176 | 177 | class PatchEmbedding(nn.Module): 178 | def __init__(self, d_model, patch_len, stride, padding, dropout): 179 | super(PatchEmbedding, self).__init__() 180 | # Patching 181 | self.patch_len = patch_len 182 | self.stride = stride 183 | 
self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 184 | 185 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 186 | self.value_embedding = nn.Linear(patch_len, d_model, bias=False) 187 | 188 | # Positional embedding 189 | self.position_embedding = PositionalEmbedding(d_model) 190 | 191 | # Residual dropout 192 | self.dropout = nn.Dropout(dropout) 193 | 194 | def forward(self, x): 195 | # do patching 196 | n_vars = x.shape[1] 197 | x = self.padding_patch_layer(x) 198 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 199 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 200 | # Input encoding 201 | x = self.value_embedding(x) + self.position_embedding(x) 202 | return self.dropout(x), n_vars 203 | -------------------------------------------------------------------------------- /FedSpecNet/Center.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import pickle 4 | import socket 5 | import time 6 | import struct 7 | 8 | def log(info): 9 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' ' + str(info)) 10 | 11 | Epoch = 10 12 | 13 | def m1(*args): 14 | result = args[0].clone() # 使用 clone() 而不是 deepcopy 15 | for i in range(1, len(args)): 16 | result += args[i] 17 | 18 | result /= len(args) 19 | for i in range(len(args)): 20 | args[i].copy_(result) # 原地赋值,避免新建张量 21 | 22 | 23 | def m2(*args): 24 | result = args[0].clone() 25 | for i in range(1, len(args)): 26 | result += args[i] 27 | 28 | result /= len(args) 29 | for i in range(len(args)): 30 | args[i].copy_(result) 31 | 32 | 33 | def socket_udp_server(): 34 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 35 | host = '127.0.0.1' 36 | port = 7002 37 | s.bind((host, port)) 38 | s.listen(5) 39 | print('waiting for connecting') 40 | 41 | for cnt in range(1, Epoch + 1): 42 | log(f"第{cnt}轮开始接收并计时") 43 | connected_socks = [] 44 | res = [] 45 | 46 | while len(connected_socks) < 
5: ##number 测试的时候修改这个数量 47 | try: 48 | s.settimeout(100) # 设置较长的超时时间等待连接 49 | sock, addr = s.accept() 50 | log(f'Connection accepted from {addr}') 51 | 52 | data_length_bytes = sock.recv(4) 53 | if not data_length_bytes: 54 | raise ValueError("未接收到数据长度信息") 55 | data_length = int.from_bytes(data_length_bytes, byteorder='big') 56 | 57 | received_data = b'' 58 | while len(received_data) < data_length: 59 | packet = sock.recv(data_length - len(received_data)) 60 | if not packet: 61 | raise ConnectionError("连接中断") 62 | received_data += packet 63 | 64 | tmp = pickle.loads(received_data) 65 | log('Received data: ...') 66 | if tmp['num'] == cnt: 67 | connected_socks.append(sock) 68 | res.append(tmp['model']) 69 | 70 | except socket.timeout: 71 | log("接收数据超时。") 72 | break # 跳出循环,处理已接收的数据 73 | except Exception as e: 74 | log(f"接收数据时发生异常: {e}") 75 | 76 | if res: 77 | # 数据处理逻辑 78 | log(f"第{cnt}轮接收完毕,接收来自{len(res)}个节点的参数") 79 | # 假设m1和m2函数的数据处理逻辑如前所述 80 | log("开始融合处理操作......") 81 | 82 | for m, n in zip(res[0].values(), res[1].values()): 83 | log(f"融合处理参数维度: {m.size()} 和 {n.size()}") 84 | if len(m.size()) == 1: 85 | m1(m, n) 86 | elif len(m.size()) == 2: 87 | m2(m, n) 88 | 89 | data = {} 90 | data['num'] = cnt 91 | data['model'] = res[0] 92 | log('第%d轮融合完毕,下发......' 
% cnt) 93 | data = pickle.dumps(data) 94 | 95 | 96 | 97 | # 发送ACK确认给所有节点 98 | for sock in connected_socks: 99 | try: 100 | # # 先发送'ACK'确认信号 101 | # ack_message = 'ACK'.encode() 102 | # sock.sendall(ack_message) 103 | 104 | # 发送融合后的数据长度 105 | data_length = len(data) 106 | packed_length = struct.pack('!I', data_length) 107 | sock.sendall(packed_length) 108 | 109 | # 发送实际融合后的数据 110 | sock.sendall(data) 111 | except Exception as e: 112 | log(f"发送数据时出错: {e}") 113 | finally: 114 | # 确保每次发送完毕后关闭连接 115 | sock.close() 116 | log('Data sent and connections closed.') 117 | 118 | # 清理资源,准备下一轮 119 | connected_socks.clear() 120 | 121 | s.close() 122 | log('All epochs completed, server shutdown.') 123 | 124 | def main(): 125 | socket_udp_server() 126 | 127 | if __name__ == '__main__': 128 | main() 129 | 130 | 131 | # import pickle 132 | # import socket 133 | # import time 134 | # import struct 135 | 136 | # def log(info): 137 | # print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' ' + str(info)) 138 | 139 | # Epoch = 10 140 | 141 | # def m1(*args): 142 | # import copy 143 | # result = copy.deepcopy(args[0]) 144 | # for i in range(1, len(args)): 145 | # for j in range(len(result)): 146 | # result[j] += args[i][j] 147 | 148 | # for i in range(len(args)): 149 | # for j in range(len(args[0])): 150 | # args[i][j] = result[j] / len(args) 151 | 152 | # def m2(*args): 153 | # import copy 154 | # result = copy.deepcopy(args[0]) 155 | # for i in range(1, len(args)): 156 | # # print(args[i]) 157 | # for j in range(len(result)): 158 | # for k in range(len(result[0])): 159 | # result[j][k] += args[i][j][k] 160 | 161 | # for i in range(len(args)): 162 | # for j in range(len(result)): 163 | # for k in range(len(result[0])): 164 | # args[i][j][k] = result[j][k] / len(args) 165 | 166 | 167 | # def socket_udp_server(): 168 | # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 169 | # host = '127.0.0.1' 170 | # port = 7002 171 | # s.bind((host, port)) 172 | # s.listen(5) 173 | # 
print('waiting for connecting') 174 | 175 | # for cnt in range(1, Epoch + 1): 176 | # log(f"第{cnt}轮开始接收并计时") 177 | # connected_socks = [] 178 | # res = [] 179 | 180 | # while len(connected_socks) < 2: ##number 测试的时候修改这个数量 181 | # try: 182 | # s.settimeout(100) # 设置较长的超时时间等待连接 183 | # sock, addr = s.accept() 184 | # log(f'Connection accepted from {addr}') 185 | 186 | # data_length_bytes = sock.recv(4) 187 | # if not data_length_bytes: 188 | # raise ValueError("未接收到数据长度信息") 189 | # data_length = int.from_bytes(data_length_bytes, byteorder='big') 190 | 191 | # received_data = b'' 192 | # while len(received_data) < data_length: 193 | # packet = sock.recv(data_length - len(received_data)) 194 | # if not packet: 195 | # raise ConnectionError("连接中断") 196 | # received_data += packet 197 | 198 | # tmp = pickle.loads(received_data) 199 | # log('Received data: ...') 200 | # if tmp['num'] == cnt: 201 | # connected_socks.append(sock) 202 | # res.append(tmp['model']) 203 | 204 | # except socket.timeout: 205 | # log("接收数据超时。") 206 | # break # 跳出循环,处理已接收的数据 207 | # except Exception as e: 208 | # log(f"接收数据时发生异常: {e}") 209 | 210 | # if res: 211 | # # 数据处理逻辑 212 | # log(f"第{cnt}轮接收完毕,接收来自{len(res)}个节点的参数") 213 | # # 假设m1和m2函数的数据处理逻辑如前所述 214 | 215 | # log("开始融合处理操作......") 216 | 217 | # for m, n in zip(res[0].values(), res[1].values()): 218 | # if len(m.size()) == 1: 219 | # m1(m, n) 220 | # elif len(m.size()) == 2: 221 | # m2(m, n) 222 | 223 | # data = {} 224 | # data['num'] = cnt 225 | # data['model'] = res[0] 226 | # log('第%d轮融合完毕,下发......' 
% cnt) 227 | # data = pickle.dumps(data) 228 | 229 | 230 | 231 | # # 发送ACK确认给所有节点 232 | # for sock in connected_socks: 233 | # try: 234 | # # # 先发送'ACK'确认信号 235 | # # ack_message = 'ACK'.encode() 236 | # # sock.sendall(ack_message) 237 | 238 | # # 发送融合后的数据长度 239 | # data_length = len(data) 240 | # packed_length = struct.pack('!I', data_length) 241 | # sock.sendall(packed_length) 242 | 243 | # # 发送实际融合后的数据 244 | # sock.sendall(data) 245 | # except Exception as e: 246 | # log(f"发送数据时出错: {e}") 247 | # finally: 248 | # # 确保每次发送完毕后关闭连接 249 | # sock.close() 250 | # log('Data sent and connections closed.') 251 | 252 | # # 清理资源,准备下一轮 253 | # connected_socks.clear() 254 | 255 | # s.close() 256 | # log('All epochs completed, server shutdown.') 257 | 258 | # def main(): 259 | # socket_udp_server() 260 | 261 | # if __name__ == '__main__': 262 | # main() 263 | -------------------------------------------------------------------------------- /FedSpecNet/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast 4 | import random 5 | import numpy as np 6 | 7 | if __name__ == '__main__': 8 | fix_seed = 2021 9 | random.seed(fix_seed) 10 | torch.manual_seed(fix_seed) 11 | np.random.seed(fix_seed) 12 | 13 | parser = argparse.ArgumentParser(description='TimesNet') 14 | # iTransformer 15 | parser.add_argument('--decomposition_method', type=str, required=True, default="ssa", help='ssa/vmf/swt/ewt/emd') 16 | parser.add_argument('--subsequence_num', type=int, required=True, default=4, help='ssa subsequence num') 17 | parser.add_argument('--interval', type=int, required=True, default=1, help='time step(1,2,4,6,12,means 5min,10min,20min,30min,60min)') 18 | parser.add_argument('--task_name', type=str, required=True, default='long_term_forecast', 19 | help='task name, options:[long_term_forecast, short_term_forecast, imputation, classification, 
anomaly_detection]') 20 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 21 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 22 | parser.add_argument('--model', type=str, required=True, default='iTransformer', 23 | help='model name, options: [Autoformer, Transformer, TimesNet,iTransformer]') 24 | 25 | # data loader 26 | parser.add_argument('--data', type=str, required=True, default='UK-DALE', help='dataset type') 27 | parser.add_argument('--root_path', type=str, default='./dataset/UK-DALE/', help='root path of the data file') 28 | parser.add_argument('--data_path', type=str, default='house1_5min_KWh.csv', help='data file') 29 | parser.add_argument('--features', type=str, default='M', 30 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') 31 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') 32 | parser.add_argument('--freq', type=str, default='t', 33 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 34 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 35 | 36 | # forecasting task 37 | parser.add_argument('--seq_len', type=int, default=8, help='input sequence length') 38 | parser.add_argument('--label_len', type=int, default=1, help='start token length') 39 | parser.add_argument('--pred_len', type=int, default=1, help='prediction sequence length') 40 | parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4') 41 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False) 42 | 43 | # inputation task 44 | parser.add_argument('--mask_rate', type=float, default=0.25, 
help='mask ratio') 45 | 46 | # anomaly detection task 47 | parser.add_argument('--anomaly_ratio', type=float, default=0.25, help='prior anomaly ratio (%)') 48 | 49 | # model define 50 | parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock') 51 | parser.add_argument('--num_kernels', type=int, default=6, help='for Inception') 52 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') #! 53 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')#! 54 | parser.add_argument('--c_out', type=int, default=7, help='output size')#! 55 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model')#! 56 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads') #8 57 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') #! 58 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') 59 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')#! 
60 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') 61 | parser.add_argument('--factor', type=int, default=1, help='attn factor') 62 | parser.add_argument('--distil', action='store_false', 63 | help='whether to use distilling in encoder, using this argument means not using distilling', 64 | default=True) 65 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout') 66 | parser.add_argument('--embed', type=str, default='timeF', 67 | help='time features encoding, options:[timeF, fixed, learned]') #timeF表示使用固定的频率对数据进行处理,这里的频率是h --freq为h 68 | parser.add_argument('--activation', type=str, default='gelu', help='activation') 69 | parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') 70 | 71 | # optimization 72 | parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers') 73 | parser.add_argument('--itr', type=int, default=1, help='experiments times') 74 | parser.add_argument('--train_epochs', type=int, default=10 , help='train epochs') #10 75 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data') 76 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience') 77 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')#0.0001 78 | parser.add_argument('--des', type=str, default='test', help='exp description') 79 | parser.add_argument('--loss', type=str, default='MSE', help='loss function') 80 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate') 81 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False) 82 | 83 | # GPUS 84 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') 85 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 86 | parser.add_argument('--use_multi_gpu', 
action='store_true', help='use multiple gpus', default=False) 87 | parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus') 88 | 89 | # de-stationary projector params 90 | parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128], 91 | help='hidden layer dimensions of projector (List)') 92 | parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector') 93 | 94 | args = parser.parse_args() 95 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False 96 | if args.use_gpu and args.use_multi_gpu: 97 | args.devices = args.devices.replace(' ', '') 98 | device_ids = args.devices.split(',') 99 | args.device_ids = [int(id_) for id_ in device_ids] 100 | args.gpu = args.device_ids[0] 101 | 102 | print('Args in experiment:') 103 | print(args) 104 | 105 | if args.task_name == 'long_term_forecast': 106 | Exp = Exp_Long_Term_Forecast 107 | if args.is_training: 108 | for ii in range(args.itr): 109 | # setting record of experiments 110 | setting = '{}_{}_{}_{}_ft{}_dm{}_sl{}_ll{}_pl{}_{}_sm{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 111 | args.task_name, 112 | args.model_id, 113 | args.model, 114 | args.data, 115 | args.features, 116 | args.decomposition_method, 117 | args.seq_len, 118 | args.label_len, 119 | args.pred_len, 120 | args.interval, 121 | args.subsequence_num, 122 | args.d_model, 123 | args.n_heads, 124 | args.e_layers, 125 | args.d_layers, 126 | args.d_ff, 127 | args.factor, 128 | args.embed, 129 | args.distil, 130 | args.des, ii) 131 | 132 | exp = Exp(args) # set experiments 133 | # import ipdb; ipdb.set_trace() 134 | exp.train(setting) 135 | exp.test(setting) 136 | print("\n ") 137 | torch.cuda.empty_cache() 138 | print(args.model_id," DONE!!!") 139 | 140 | else: 141 | ii = 0 142 | setting = '{}_{}_{}_{}_ft{}_dm{}_sl{}_ll{}_pl{}_{}_sm{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 143 | args.task_name, 144 | 
args.model_id, 145 | args.model, 146 | args.data, 147 | args.features, 148 | args.decomposition_method, 149 | args.seq_len, 150 | args.label_len, 151 | args.pred_len, 152 | args.interval, 153 | args.subsequence_num, 154 | args.d_model, 155 | args.n_heads, 156 | args.e_layers, 157 | args.d_layers, 158 | args.d_ff, 159 | args.factor, 160 | args.embed, 161 | args.distil, 162 | args.des, ii) 163 | 164 | exp = Exp(args) # set experiments 165 | exp.test(setting, test=1) 166 | torch.cuda.empty_cache() 167 | -------------------------------------------------------------------------------- /FedSpecNet_v2/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast 4 | import random 5 | import numpy as np 6 | 7 | if __name__ == '__main__': 8 | fix_seed = 2021 9 | random.seed(fix_seed) 10 | torch.manual_seed(fix_seed) 11 | np.random.seed(fix_seed) 12 | 13 | parser = argparse.ArgumentParser(description='TimesNet') 14 | # iTransformer 15 | parser.add_argument('--decomposition_method', type=str, required=True, default="ssa", help='ssa/vmf/swt/ewt/emd') 16 | parser.add_argument('--subsequence_num', type=int, required=True, default=4, help='ssa subsequence num') 17 | parser.add_argument('--interval', type=int, required=True, default=1, help='time step(1,2,4,6,12,means 5min,10min,20min,30min,60min)') 18 | parser.add_argument('--task_name', type=str, required=True, default='long_term_forecast', 19 | help='task name, options:[long_term_forecast, short_term_forecast, imputation, classification, anomaly_detection]') 20 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 21 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 22 | parser.add_argument('--model', type=str, required=True, default='iTransformer', 23 | help='model name, options: [Autoformer, Transformer, 
TimesNet,iTransformer]') 24 | 25 | # data loader 26 | parser.add_argument('--data', type=str, required=True, default='UK-DALE', help='dataset type') 27 | parser.add_argument('--root_path', type=str, default='./dataset/UK-DALE/', help='root path of the data file') 28 | parser.add_argument('--data_path', type=str, default='house1_5min_KWh.csv', help='data file') 29 | parser.add_argument('--features', type=str, default='M', 30 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') 31 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') 32 | parser.add_argument('--freq', type=str, default='t', 33 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 34 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 35 | 36 | # forecasting task 37 | parser.add_argument('--seq_len', type=int, default=8, help='input sequence length') 38 | parser.add_argument('--label_len', type=int, default=1, help='start token length') 39 | parser.add_argument('--pred_len', type=int, default=1, help='prediction sequence length') 40 | parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4') 41 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False) 42 | 43 | # inputation task 44 | parser.add_argument('--mask_rate', type=float, default=0.25, help='mask ratio') 45 | 46 | # anomaly detection task 47 | parser.add_argument('--anomaly_ratio', type=float, default=0.25, help='prior anomaly ratio (%)') 48 | 49 | # model define 50 | parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock') 51 | parser.add_argument('--num_kernels', type=int, default=6, help='for Inception') 52 | 
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') #! 53 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')#! 54 | parser.add_argument('--c_out', type=int, default=7, help='output size')#! 55 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model')#! 56 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads') #8 57 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') #! 58 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') 59 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')#! 60 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') 61 | parser.add_argument('--factor', type=int, default=1, help='attn factor') 62 | parser.add_argument('--distil', action='store_false', 63 | help='whether to use distilling in encoder, using this argument means not using distilling', 64 | default=True) 65 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout') 66 | parser.add_argument('--embed', type=str, default='timeF', 67 | help='time features encoding, options:[timeF, fixed, learned]') #timeF表示使用固定的频率对数据进行处理,这里的频率是h --freq为h 68 | parser.add_argument('--activation', type=str, default='gelu', help='activation') 69 | parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') 70 | 71 | # optimization 72 | parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers') 73 | parser.add_argument('--itr', type=int, default=1, help='experiments times') 74 | parser.add_argument('--train_epochs', type=int, default=10 , help='train epochs') #10 75 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data') 76 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience') 77 | 
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')#0.0001 78 | parser.add_argument('--des', type=str, default='test', help='exp description') 79 | parser.add_argument('--loss', type=str, default='MSE', help='loss function') 80 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate') 81 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False) 82 | 83 | # GPUS 84 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') 85 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 86 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) 87 | parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus') 88 | 89 | # de-stationary projector params 90 | parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128], 91 | help='hidden layer dimensions of projector (List)') 92 | parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector') 93 | 94 | args = parser.parse_args() 95 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False 96 | if args.use_gpu and args.use_multi_gpu: 97 | args.devices = args.devices.replace(' ', '') 98 | device_ids = args.devices.split(',') 99 | args.device_ids = [int(id_) for id_ in device_ids] 100 | args.gpu = args.device_ids[0] 101 | 102 | print('Args in experiment:') 103 | print(args) 104 | 105 | if args.task_name == 'long_term_forecast': 106 | Exp = Exp_Long_Term_Forecast 107 | if args.is_training: 108 | for ii in range(args.itr): 109 | # setting record of experiments 110 | setting = '{}_{}_{}_{}_ft{}_dm{}_sl{}_ll{}_pl{}_{}_sm{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 111 | args.task_name, 112 | args.model_id, 113 | args.model, 114 | args.data, 115 | args.features, 116 | args.decomposition_method, 
117 | args.seq_len, 118 | args.label_len, 119 | args.pred_len, 120 | args.interval, 121 | args.subsequence_num, 122 | args.d_model, 123 | args.n_heads, 124 | args.e_layers, 125 | args.d_layers, 126 | args.d_ff, 127 | args.factor, 128 | args.embed, 129 | args.distil, 130 | args.des, ii) 131 | 132 | exp = Exp(args) # set experiments 133 | # import ipdb; ipdb.set_trace() 134 | exp.train(setting) 135 | exp.test(setting) 136 | print("\n ") 137 | torch.cuda.empty_cache() 138 | print(args.model_id," DONE!!!") 139 | 140 | else: 141 | ii = 0 142 | setting = '{}_{}_{}_{}_ft{}_dm{}_sl{}_ll{}_pl{}_{}_sm{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 143 | args.task_name, 144 | args.model_id, 145 | args.model, 146 | args.data, 147 | args.features, 148 | args.decomposition_method, 149 | args.seq_len, 150 | args.label_len, 151 | args.pred_len, 152 | args.interval, 153 | args.subsequence_num, 154 | args.d_model, 155 | args.n_heads, 156 | args.e_layers, 157 | args.d_layers, 158 | args.d_ff, 159 | args.factor, 160 | args.embed, 161 | args.distil, 162 | args.des, ii) 163 | 164 | exp = Exp(args) # set experiments 165 | exp.test(setting, test=1) 166 | torch.cuda.empty_cache() 167 | -------------------------------------------------------------------------------- /FedSpecNet_v2/Center.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import pickle 4 | import socket 5 | import time 6 | import struct 7 | 8 | import ssl 9 | 10 | def log(info): 11 | print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' ' + str(info)) 12 | 13 | Epoch = 10 14 | 15 | def m1(args): 16 | result = args[0].clone() # 使用 clone() 而不是 deepcopy 17 | for i in range(1, len(args)): 18 | result += args[i] 19 | 20 | result /= len(args) 21 | for i in range(len(args)): 22 | args[i].copy_(result) # 原地赋值,避免新建张量 23 | 24 | 25 | def m2(args): 26 | result = args[0].clone() 27 | for i in range(1, len(args)): 28 | result += args[i] 29 | 30 | result /= len(args) 31 | for 
i in range(len(args)): 32 | args[i].copy_(result) 33 | return result 34 | 35 | 36 | def socket_udp_server(): 37 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 38 | host = '127.0.0.1' 39 | port = 7002 40 | s.bind((host, port)) 41 | s.listen(5) 42 | print('waiting for connecting') 43 | 44 | # 加载证书和密钥 45 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 46 | context.load_cert_chain(certfile="server.crt", keyfile="server.key") 47 | 48 | for cnt in range(1, Epoch + 1): 49 | log(f"第{cnt}轮开始接收并计时") 50 | connected_socks = [] 51 | res = [] 52 | 53 | while len(connected_socks) < 5: ##number 测试的时候修改这个数量 54 | try: 55 | s.settimeout(100) # 设置较长的超时时间等待连接 56 | sock, addr = s.accept() 57 | log(f'Connection accepted from {addr}') 58 | 59 | # socket包装和通信 60 | try: 61 | ssl_sock = context.wrap_socket(sock, server_side=True) 62 | data_length_bytes = ssl_sock.recv(4) 63 | if not data_length_bytes: 64 | raise ValueError("未接收到数据长度信息") 65 | data_length = int.from_bytes(data_length_bytes, byteorder='big') 66 | received_data = b'' 67 | while len(received_data) < data_length: 68 | packet = ssl_sock.recv(data_length - len(received_data)) 69 | if not packet: 70 | raise ConnectionError("连接中断") 71 | received_data += packet 72 | tmp = pickle.loads(received_data) 73 | log('Received data: ...') 74 | print(tmp['num']) 75 | print(cnt) 76 | if tmp['num'] == cnt: 77 | connected_socks.append(ssl_sock) 78 | print(len(connected_socks)) 79 | res.append(tmp['model']) 80 | except Exception as e: 81 | print(f"error: {e}") 82 | 83 | except socket.timeout: 84 | log("接收数据超时。") 85 | break # 跳出循环,处理已接收的数据 86 | except Exception as e: 87 | log(f"接收数据时发生异常: {e}") 88 | 89 | if res: 90 | # 数据处理逻辑 91 | log(f"第{cnt}轮接收完毕,接收来自{len(res)}个节点的参数") 92 | # 假设m1和m2函数的数据处理逻辑如前所述 93 | log("开始融合处理操作......") 94 | 95 | for item in zip(*[d.values() for d in res]): 96 | logstr = ", ".join(str(obj.size()) for obj in item) 97 | log(f"融合数据尺寸: {logstr}") 98 | if len(item[0].size()) == 1: 99 | m1(item) 100 | elif 
len(item[0].size()) == 2: 101 | m2(item) 102 | 103 | data = {} 104 | data['num'] = cnt 105 | data['model'] = res[0] 106 | log('第%d轮融合完毕,下发......' % cnt) 107 | data = pickle.dumps(data) 108 | 109 | 110 | 111 | # 发送ACK确认给所有节点 112 | for ssl_sock in connected_socks: 113 | try: 114 | # # 先发送'ACK'确认信号 115 | # ack_message = 'ACK'.encode() 116 | # sock.sendall(ack_message) 117 | 118 | # 发送融合后的数据长度 119 | data_length = len(data) 120 | packed_length = struct.pack('!I', data_length) 121 | ssl_sock.sendall(packed_length) 122 | 123 | # 发送实际融合后的数据 124 | ssl_sock.sendall(data) 125 | except Exception as e: 126 | log(f"发送数据时出错: {e}") 127 | finally: 128 | # 确保每次发送完毕后关闭连接 129 | ssl_sock.close() 130 | log('Data sent and connections closed.') 131 | 132 | # 清理资源,准备下一轮 133 | connected_socks.clear() 134 | 135 | s.close() 136 | log('All epochs completed, server shutdown.') 137 | 138 | def main(): 139 | socket_udp_server() 140 | 141 | if __name__ == '__main__': 142 | main() 143 | 144 | 145 | # import pickle 146 | # import socket 147 | # import time 148 | # import struct 149 | 150 | # def log(info): 151 | # print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' ' + str(info)) 152 | 153 | # Epoch = 10 154 | 155 | # def m1(*args): 156 | # import copy 157 | # result = copy.deepcopy(args[0]) 158 | # for i in range(1, len(args)): 159 | # for j in range(len(result)): 160 | # result[j] += args[i][j] 161 | 162 | # for i in range(len(args)): 163 | # for j in range(len(args[0])): 164 | # args[i][j] = result[j] / len(args) 165 | 166 | # def m2(*args): 167 | # import copy 168 | # result = copy.deepcopy(args[0]) 169 | # for i in range(1, len(args)): 170 | # # print(args[i]) 171 | # for j in range(len(result)): 172 | # for k in range(len(result[0])): 173 | # result[j][k] += args[i][j][k] 174 | 175 | # for i in range(len(args)): 176 | # for j in range(len(result)): 177 | # for k in range(len(result[0])): 178 | # args[i][j][k] = result[j][k] / len(args) 179 | 180 | 181 | # def socket_udp_server(): 
182 | # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 183 | # host = '127.0.0.1' 184 | # port = 7002 185 | # s.bind((host, port)) 186 | # s.listen(5) 187 | # print('waiting for connecting') 188 | 189 | # for cnt in range(1, Epoch + 1): 190 | # log(f"第{cnt}轮开始接收并计时") 191 | # connected_socks = [] 192 | # res = [] 193 | 194 | # while len(connected_socks) < 2: ##number 测试的时候修改这个数量 195 | # try: 196 | # s.settimeout(100) # 设置较长的超时时间等待连接 197 | # sock, addr = s.accept() 198 | # log(f'Connection accepted from {addr}') 199 | 200 | # data_length_bytes = sock.recv(4) 201 | # if not data_length_bytes: 202 | # raise ValueError("未接收到数据长度信息") 203 | # data_length = int.from_bytes(data_length_bytes, byteorder='big') 204 | 205 | # received_data = b'' 206 | # while len(received_data) < data_length: 207 | # packet = sock.recv(data_length - len(received_data)) 208 | # if not packet: 209 | # raise ConnectionError("连接中断") 210 | # received_data += packet 211 | 212 | # tmp = pickle.loads(received_data) 213 | # log('Received data: ...') 214 | # if tmp['num'] == cnt: 215 | # connected_socks.append(sock) 216 | # res.append(tmp['model']) 217 | 218 | # except socket.timeout: 219 | # log("接收数据超时。") 220 | # break # 跳出循环,处理已接收的数据 221 | # except Exception as e: 222 | # log(f"接收数据时发生异常: {e}") 223 | 224 | # if res: 225 | # # 数据处理逻辑 226 | # log(f"第{cnt}轮接收完毕,接收来自{len(res)}个节点的参数") 227 | # # 假设m1和m2函数的数据处理逻辑如前所述 228 | 229 | # log("开始融合处理操作......") 230 | 231 | # for m, n in zip(res[0].values(), res[1].values()): 232 | # if len(m.size()) == 1: 233 | # m1(m, n) 234 | # elif len(m.size()) == 2: 235 | # m2(m, n) 236 | 237 | # data = {} 238 | # data['num'] = cnt 239 | # data['model'] = res[0] 240 | # log('第%d轮融合完毕,下发......' 
% cnt) 241 | # data = pickle.dumps(data) 242 | 243 | 244 | 245 | # # 发送ACK确认给所有节点 246 | # for sock in connected_socks: 247 | # try: 248 | # # # 先发送'ACK'确认信号 249 | # # ack_message = 'ACK'.encode() 250 | # # sock.sendall(ack_message) 251 | 252 | # # 发送融合后的数据长度 253 | # data_length = len(data) 254 | # packed_length = struct.pack('!I', data_length) 255 | # sock.sendall(packed_length) 256 | 257 | # # 发送实际融合后的数据 258 | # sock.sendall(data) 259 | # except Exception as e: 260 | # log(f"发送数据时出错: {e}") 261 | # finally: 262 | # # 确保每次发送完毕后关闭连接 263 | # sock.close() 264 | # log('Data sent and connections closed.') 265 | 266 | # # 清理资源,准备下一轮 267 | # connected_socks.clear() 268 | 269 | # s.close() 270 | # log('All epochs completed, server shutdown.') 271 | 272 | # def main(): 273 | # socket_udp_server() 274 | 275 | # if __name__ == '__main__': 276 | # main() 277 | -------------------------------------------------------------------------------- /FedSpecNet/layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | from reformer_pytorch import LSHSelfAttention 7 | from einops import rearrange, repeat 8 | 9 | 10 | class DSAttention(nn.Module): 11 | '''De-stationary Attention''' 12 | 13 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 14 | super(DSAttention, self).__init__() 15 | self.scale = scale 16 | self.mask_flag = mask_flag 17 | self.output_attention = output_attention 18 | self.dropout = nn.Dropout(attention_dropout) 19 | 20 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 21 | B, L, H, E = queries.shape 22 | _, S, _, D = values.shape 23 | scale = self.scale or 1. 
/ sqrt(E) 24 | 25 | tau = 1.0 if tau is None else tau.unsqueeze( 26 | 1).unsqueeze(1) # B x 1 x 1 x 1 27 | delta = 0.0 if delta is None else delta.unsqueeze( 28 | 1).unsqueeze(1) # B x 1 x 1 x S 29 | 30 | # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors 31 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta 32 | 33 | if self.mask_flag: 34 | if attn_mask is None: 35 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 36 | 37 | scores.masked_fill_(attn_mask.mask, -np.inf) 38 | 39 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 40 | V = torch.einsum("bhls,bshd->blhd", A, values) 41 | 42 | if self.output_attention: 43 | return V.contiguous(), A 44 | else: 45 | return V.contiguous(), None 46 | 47 | 48 | class FullAttention(nn.Module): 49 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 50 | super(FullAttention, self).__init__() 51 | self.scale = scale 52 | self.mask_flag = mask_flag 53 | self.output_attention = output_attention 54 | self.dropout = nn.Dropout(attention_dropout) 55 | 56 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 57 | B, L, H, E = queries.shape 58 | _, S, _, D = values.shape 59 | scale = self.scale or 1. 
/ sqrt(E) 60 | 61 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 62 | 63 | if self.mask_flag: 64 | if attn_mask is None: 65 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 66 | 67 | scores.masked_fill_(attn_mask.mask, -np.inf) 68 | 69 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 70 | V = torch.einsum("bhls,bshd->blhd", A, values) 71 | if self.output_attention: 72 | return V.contiguous(), A 73 | else: 74 | return V.contiguous(), None 75 | 76 | 77 | class ProbAttention(nn.Module): 78 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 79 | super(ProbAttention, self).__init__() 80 | self.factor = factor 81 | self.scale = scale 82 | self.mask_flag = mask_flag 83 | self.output_attention = output_attention 84 | self.dropout = nn.Dropout(attention_dropout) 85 | 86 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 87 | # Q [B, H, L, D] 88 | B, H, L_K, E = K.shape 89 | _, _, L_Q, _ = Q.shape 90 | 91 | # calculate the sampled Q_K 92 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 93 | # real U = U_part(factor*ln(L_k))*L_q 94 | index_sample = torch.randint(L_K, (L_Q, sample_k)) 95 | K_sample = K_expand[:, :, torch.arange( 96 | L_Q).unsqueeze(1), index_sample, :] 97 | Q_K_sample = torch.matmul( 98 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 99 | 100 | # find the Top_k query with sparisty measurement 101 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 102 | M_top = M.topk(n_top, sorted=False)[1] 103 | 104 | # use the reduced Q to calculate Q_K 105 | Q_reduce = Q[torch.arange(B)[:, None, None], 106 | torch.arange(H)[None, :, None], 107 | M_top, :] # factor*ln(L_q) 108 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 109 | 110 | return Q_K, M_top 111 | 112 | def _get_initial_context(self, V, L_Q): 113 | B, H, L_V, D = V.shape 114 | if not self.mask_flag: 115 | # V_sum = V.sum(dim=-2) 116 | V_sum = V.mean(dim=-2) 
117 | contex = V_sum.unsqueeze(-2).expand(B, H, 118 | L_Q, V_sum.shape[-1]).clone() 119 | else: # use mask 120 | # requires that L_Q == L_V, i.e. for self-attention only 121 | assert (L_Q == L_V) 122 | contex = V.cumsum(dim=-2) 123 | return contex 124 | 125 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 126 | B, H, L_V, D = V.shape 127 | 128 | if self.mask_flag: 129 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 130 | scores.masked_fill_(attn_mask.mask, -np.inf) 131 | 132 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 133 | 134 | context_in[torch.arange(B)[:, None, None], 135 | torch.arange(H)[None, :, None], 136 | index, :] = torch.matmul(attn, V).type_as(context_in) 137 | if self.output_attention: 138 | attns = (torch.ones([B, H, L_V, L_V]) / 139 | L_V).type_as(attn).to(attn.device) 140 | attns[torch.arange(B)[:, None, None], torch.arange(H)[ 141 | None, :, None], index, :] = attn 142 | return context_in, attns 143 | else: 144 | return context_in, None 145 | 146 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 147 | B, L_Q, H, D = queries.shape 148 | _, L_K, _, _ = keys.shape 149 | 150 | queries = queries.transpose(2, 1) 151 | keys = keys.transpose(2, 1) 152 | values = values.transpose(2, 1) 153 | 154 | U_part = self.factor * \ 155 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 156 | u = self.factor * \ 157 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 158 | 159 | U_part = U_part if U_part < L_K else L_K 160 | u = u if u < L_Q else L_Q 161 | 162 | scores_top, index = self._prob_QK( 163 | queries, keys, sample_k=U_part, n_top=u) 164 | 165 | # add scale factor 166 | scale = self.scale or 1. 
/ sqrt(D) 167 | if scale is not None: 168 | scores_top = scores_top * scale 169 | # get the context 170 | context = self._get_initial_context(values, L_Q) 171 | # update the context with selected top_k queries 172 | context, attn = self._update_context( 173 | context, values, scores_top, index, L_Q, attn_mask) 174 | 175 | return context.contiguous(), attn 176 | 177 | 178 | class AttentionLayer(nn.Module): 179 | def __init__(self, attention, d_model, n_heads, d_keys=None, 180 | d_values=None, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")): 181 | super(AttentionLayer, self).__init__() 182 | 183 | d_keys = d_keys or (d_model // n_heads) 184 | d_values = d_values or (d_model // n_heads) 185 | 186 | self.inner_attention = attention 187 | self.query_projection = nn.Linear(d_model, d_keys * n_heads).to(device) 188 | self.key_projection = nn.Linear(d_model, d_keys * n_heads).to(device) 189 | self.value_projection = nn.Linear(d_model, d_values * n_heads).to(device) 190 | self.out_projection = nn.Linear(d_values * n_heads, d_model).to(device) 191 | self.n_heads = n_heads 192 | self.device = device 193 | 194 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 195 | queries = queries.to(self.device) 196 | keys = keys.to(self.device) 197 | values = values.to(self.device) 198 | 199 | B, L, _ = queries.shape #B-batch 200 | _, S, _ = keys.shape 201 | H = self.n_heads #8 202 | 203 | queries = self.query_projection(queries).view(B, L, H, -1) 204 | keys = self.key_projection(keys).view(B, S, H, -1) 205 | values = self.value_projection(values).view(B, S, H, -1) 206 | 207 | out, attn = self.inner_attention( 208 | queries, 209 | keys, 210 | values, 211 | attn_mask, 212 | tau=tau, 213 | delta=delta 214 | ) 215 | out = out.view(B, L, -1) 216 | 217 | return self.out_projection(out), attn 218 | 219 | 220 | class ReformerLayer(nn.Module): 221 | def __init__(self, attention, d_model, n_heads, d_keys=None, 222 | d_values=None, causal=False, 
bucket_size=4, n_hashes=4): 223 | super().__init__() 224 | self.bucket_size = bucket_size 225 | self.attn = LSHSelfAttention( 226 | dim=d_model, 227 | heads=n_heads, 228 | bucket_size=bucket_size, 229 | n_hashes=n_hashes, 230 | causal=causal 231 | ) 232 | 233 | def fit_length(self, queries): 234 | # inside reformer: assert N % (bucket_size * 2) == 0 235 | B, N, C = queries.shape 236 | if N % (self.bucket_size * 2) == 0: 237 | return queries 238 | else: 239 | # fill the time series 240 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) 241 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) 242 | 243 | def forward(self, queries, keys, values, attn_mask, tau, delta): 244 | # in Reformer: defalut queries=keys 245 | B, N, C = queries.shape 246 | queries = self.attn(self.fit_length(queries))[:, :N, :] 247 | return queries, None 248 | 249 | 250 | class TwoStageAttentionLayer(nn.Module): 251 | ''' 252 | The Two Stage Attention (TSA) Layer 253 | input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model] 254 | ''' 255 | 256 | def __init__(self, configs, 257 | seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1): 258 | super(TwoStageAttentionLayer, self).__init__() 259 | d_ff = d_ff or 4 * d_model 260 | self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 261 | output_attention=configs.output_attention), d_model, n_heads) 262 | self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 263 | output_attention=configs.output_attention), d_model, n_heads) 264 | self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 265 | output_attention=configs.output_attention), d_model, n_heads) 266 | self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) 267 | 268 | self.dropout = nn.Dropout(dropout) 269 | 270 | self.norm1 = nn.LayerNorm(d_model) 271 | self.norm2 = 
nn.LayerNorm(d_model) 272 | self.norm3 = nn.LayerNorm(d_model) 273 | self.norm4 = nn.LayerNorm(d_model) 274 | 275 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), 276 | nn.GELU(), 277 | nn.Linear(d_ff, d_model)) 278 | self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff), 279 | nn.GELU(), 280 | nn.Linear(d_ff, d_model)) 281 | 282 | def forward(self, x, attn_mask=None, tau=None, delta=None): 283 | # Cross Time Stage: Directly apply MSA to each dimension 284 | batch = x.shape[0] 285 | time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') 286 | time_enc, attn = self.time_attention( 287 | time_in, time_in, time_in, attn_mask=None, tau=None, delta=None 288 | ) 289 | dim_in = time_in + self.dropout(time_enc) 290 | dim_in = self.norm1(dim_in) 291 | dim_in = dim_in + self.dropout(self.MLP1(dim_in)) 292 | dim_in = self.norm2(dim_in) 293 | 294 | # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection 295 | dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch) 296 | batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch) 297 | dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None) 298 | dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None) 299 | dim_enc = dim_send + self.dropout(dim_receive) 300 | dim_enc = self.norm3(dim_enc) 301 | dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) 302 | dim_enc = self.norm4(dim_enc) 303 | 304 | final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch) 305 | 306 | return final_out 307 | -------------------------------------------------------------------------------- /FedSpecNet_v2/layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | from reformer_pytorch import LSHSelfAttention 7 | from einops import rearrange, repeat 8 | 9 | 10 | class DSAttention(nn.Module): 11 | '''De-stationary Attention''' 12 | 13 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 14 | super(DSAttention, self).__init__() 15 | self.scale = scale 16 | self.mask_flag = mask_flag 17 | self.output_attention = output_attention 18 | self.dropout = nn.Dropout(attention_dropout) 19 | 20 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 21 | B, L, H, E = queries.shape 22 | _, S, _, D = values.shape 23 | scale = self.scale or 1. / sqrt(E) 24 | 25 | tau = 1.0 if tau is None else tau.unsqueeze( 26 | 1).unsqueeze(1) # B x 1 x 1 x 1 27 | delta = 0.0 if delta is None else delta.unsqueeze( 28 | 1).unsqueeze(1) # B x 1 x 1 x S 29 | 30 | # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors 31 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta 32 | 33 | if self.mask_flag: 34 | if attn_mask is None: 35 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 36 | 37 | scores.masked_fill_(attn_mask.mask, -np.inf) 38 | 39 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 40 | V = torch.einsum("bhls,bshd->blhd", A, values) 41 | 42 | if self.output_attention: 43 | return V.contiguous(), A 44 | else: 45 | return V.contiguous(), None 46 | 47 | 48 | class FullAttention(nn.Module): 49 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 50 | super(FullAttention, self).__init__() 51 | self.scale = scale 52 | self.mask_flag = mask_flag 53 | self.output_attention = output_attention 54 | self.dropout = nn.Dropout(attention_dropout) 55 | 56 | def forward(self, queries, keys, values, attn_mask, tau=None, 
delta=None): 57 | B, L, H, E = queries.shape 58 | _, S, _, D = values.shape 59 | scale = self.scale or 1. / sqrt(E) 60 | 61 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 62 | 63 | if self.mask_flag: 64 | if attn_mask is None: 65 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 66 | 67 | scores.masked_fill_(attn_mask.mask, -np.inf) 68 | 69 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 70 | V = torch.einsum("bhls,bshd->blhd", A, values) 71 | if self.output_attention: 72 | return V.contiguous(), A 73 | else: 74 | return V.contiguous(), None 75 | 76 | 77 | class ProbAttention(nn.Module): 78 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 79 | super(ProbAttention, self).__init__() 80 | self.factor = factor 81 | self.scale = scale 82 | self.mask_flag = mask_flag 83 | self.output_attention = output_attention 84 | self.dropout = nn.Dropout(attention_dropout) 85 | 86 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 87 | # Q [B, H, L, D] 88 | B, H, L_K, E = K.shape 89 | _, _, L_Q, _ = Q.shape 90 | 91 | # calculate the sampled Q_K 92 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 93 | # real U = U_part(factor*ln(L_k))*L_q 94 | index_sample = torch.randint(L_K, (L_Q, sample_k)) 95 | K_sample = K_expand[:, :, torch.arange( 96 | L_Q).unsqueeze(1), index_sample, :] 97 | Q_K_sample = torch.matmul( 98 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 99 | 100 | # find the Top_k query with sparisty measurement 101 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 102 | M_top = M.topk(n_top, sorted=False)[1] 103 | 104 | # use the reduced Q to calculate Q_K 105 | Q_reduce = Q[torch.arange(B)[:, None, None], 106 | torch.arange(H)[None, :, None], 107 | M_top, :] # factor*ln(L_q) 108 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 109 | 110 | return Q_K, M_top 111 | 112 | def _get_initial_context(self, V, L_Q): 113 | B, H, 
L_V, D = V.shape 114 | if not self.mask_flag: 115 | # V_sum = V.sum(dim=-2) 116 | V_sum = V.mean(dim=-2) 117 | contex = V_sum.unsqueeze(-2).expand(B, H, 118 | L_Q, V_sum.shape[-1]).clone() 119 | else: # use mask 120 | # requires that L_Q == L_V, i.e. for self-attention only 121 | assert (L_Q == L_V) 122 | contex = V.cumsum(dim=-2) 123 | return contex 124 | 125 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 126 | B, H, L_V, D = V.shape 127 | 128 | if self.mask_flag: 129 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 130 | scores.masked_fill_(attn_mask.mask, -np.inf) 131 | 132 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 133 | 134 | context_in[torch.arange(B)[:, None, None], 135 | torch.arange(H)[None, :, None], 136 | index, :] = torch.matmul(attn, V).type_as(context_in) 137 | if self.output_attention: 138 | attns = (torch.ones([B, H, L_V, L_V]) / 139 | L_V).type_as(attn).to(attn.device) 140 | attns[torch.arange(B)[:, None, None], torch.arange(H)[ 141 | None, :, None], index, :] = attn 142 | return context_in, attns 143 | else: 144 | return context_in, None 145 | 146 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 147 | B, L_Q, H, D = queries.shape 148 | _, L_K, _, _ = keys.shape 149 | 150 | queries = queries.transpose(2, 1) 151 | keys = keys.transpose(2, 1) 152 | values = values.transpose(2, 1) 153 | 154 | U_part = self.factor * \ 155 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 156 | u = self.factor * \ 157 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 158 | 159 | U_part = U_part if U_part < L_K else L_K 160 | u = u if u < L_Q else L_Q 161 | 162 | scores_top, index = self._prob_QK( 163 | queries, keys, sample_k=U_part, n_top=u) 164 | 165 | # add scale factor 166 | scale = self.scale or 1. 
/ sqrt(D) 167 | if scale is not None: 168 | scores_top = scores_top * scale 169 | # get the context 170 | context = self._get_initial_context(values, L_Q) 171 | # update the context with selected top_k queries 172 | context, attn = self._update_context( 173 | context, values, scores_top, index, L_Q, attn_mask) 174 | 175 | return context.contiguous(), attn 176 | 177 | 178 | class AttentionLayer(nn.Module): 179 | def __init__(self, attention, d_model, n_heads, d_keys=None, 180 | d_values=None, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")): 181 | super(AttentionLayer, self).__init__() 182 | 183 | d_keys = d_keys or (d_model // n_heads) 184 | d_values = d_values or (d_model // n_heads) 185 | 186 | self.inner_attention = attention 187 | self.query_projection = nn.Linear(d_model, d_keys * n_heads).to(device) 188 | self.key_projection = nn.Linear(d_model, d_keys * n_heads).to(device) 189 | self.value_projection = nn.Linear(d_model, d_values * n_heads).to(device) 190 | self.out_projection = nn.Linear(d_values * n_heads, d_model).to(device) 191 | self.n_heads = n_heads 192 | self.device = device 193 | 194 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 195 | queries = queries.to(self.device) 196 | keys = keys.to(self.device) 197 | values = values.to(self.device) 198 | 199 | B, L, _ = queries.shape #B-batch 200 | _, S, _ = keys.shape 201 | H = self.n_heads #8 202 | 203 | queries = self.query_projection(queries).view(B, L, H, -1) 204 | keys = self.key_projection(keys).view(B, S, H, -1) 205 | values = self.value_projection(values).view(B, S, H, -1) 206 | 207 | out, attn = self.inner_attention( 208 | queries, 209 | keys, 210 | values, 211 | attn_mask, 212 | tau=tau, 213 | delta=delta 214 | ) 215 | out = out.view(B, L, -1) 216 | 217 | return self.out_projection(out), attn 218 | 219 | 220 | class ReformerLayer(nn.Module): 221 | def __init__(self, attention, d_model, n_heads, d_keys=None, 222 | d_values=None, causal=False, 
bucket_size=4, n_hashes=4): 223 | super().__init__() 224 | self.bucket_size = bucket_size 225 | self.attn = LSHSelfAttention( 226 | dim=d_model, 227 | heads=n_heads, 228 | bucket_size=bucket_size, 229 | n_hashes=n_hashes, 230 | causal=causal 231 | ) 232 | 233 | def fit_length(self, queries): 234 | # inside reformer: assert N % (bucket_size * 2) == 0 235 | B, N, C = queries.shape 236 | if N % (self.bucket_size * 2) == 0: 237 | return queries 238 | else: 239 | # fill the time series 240 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) 241 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) 242 | 243 | def forward(self, queries, keys, values, attn_mask, tau, delta): 244 | # in Reformer: defalut queries=keys 245 | B, N, C = queries.shape 246 | queries = self.attn(self.fit_length(queries))[:, :N, :] 247 | return queries, None 248 | 249 | 250 | class TwoStageAttentionLayer(nn.Module): 251 | ''' 252 | The Two Stage Attention (TSA) Layer 253 | input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model] 254 | ''' 255 | 256 | def __init__(self, configs, 257 | seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1): 258 | super(TwoStageAttentionLayer, self).__init__() 259 | d_ff = d_ff or 4 * d_model 260 | self.time_attention = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 261 | output_attention=configs.output_attention), d_model, n_heads) 262 | self.dim_sender = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 263 | output_attention=configs.output_attention), d_model, n_heads) 264 | self.dim_receiver = AttentionLayer(FullAttention(False, configs.factor, attention_dropout=configs.dropout, 265 | output_attention=configs.output_attention), d_model, n_heads) 266 | self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) 267 | 268 | self.dropout = nn.Dropout(dropout) 269 | 270 | self.norm1 = nn.LayerNorm(d_model) 271 | self.norm2 = 
nn.LayerNorm(d_model) 272 | self.norm3 = nn.LayerNorm(d_model) 273 | self.norm4 = nn.LayerNorm(d_model) 274 | 275 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), 276 | nn.GELU(), 277 | nn.Linear(d_ff, d_model)) 278 | self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff), 279 | nn.GELU(), 280 | nn.Linear(d_ff, d_model)) 281 | 282 | def forward(self, x, attn_mask=None, tau=None, delta=None): 283 | # Cross Time Stage: Directly apply MSA to each dimension 284 | batch = x.shape[0] 285 | time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') 286 | time_enc, attn = self.time_attention( 287 | time_in, time_in, time_in, attn_mask=None, tau=None, delta=None 288 | ) 289 | dim_in = time_in + self.dropout(time_enc) 290 | dim_in = self.norm1(dim_in) 291 | dim_in = dim_in + self.dropout(self.MLP1(dim_in)) 292 | dim_in = self.norm2(dim_in) 293 | 294 | # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection 295 | dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch) 296 | batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch) 297 | dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None) 298 | dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None) 299 | dim_enc = dim_send + self.dropout(dim_receive) 300 | dim_enc = self.norm3(dim_enc) 301 | dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) 302 | dim_enc = self.norm4(dim_enc) 303 | 304 | final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b=batch) 305 | 306 | return final_out 307 | -------------------------------------------------------------------------------- /FedSpecNet/Subsequence_number_experiment.py: -------------------------------------------------------------------------------- 1 | import pandas 
as pd 2 | import matplotlib.pyplot as plt 3 | from matplotlib.ticker import FormatStrFormatter 4 | 5 | 6 | def plot_subsequence_result(data_path): 7 | data = pd.read_csv(data_path) 8 | x = data['Number-of-Subsequence'][:30] 9 | y_house1_mae = data['MAE'][:30] 10 | y_house2_mae = data['MAE-2'][:30] 11 | y_house3_mae = data['MAE-3'][:30] 12 | y_house4_mae = data['MAE-4'][:30] 13 | y_house5_mae = data['MAE-5'][:30] 14 | 15 | y_house1_mae1 = data['MAPE'][:30] 16 | y_house1_mae2 = data['MAPE-2'][:30] 17 | y_house1_mae3 = data['MAPE-3'][:30] 18 | y_house1_mae4 = data['MAPE-4'][:30] 19 | y_house1_mae5 = data['MAPE-5'][:30] 20 | 21 | y_house2_mae1 = data['RMSE'][:30] 22 | y_house2_mae2 = data['RMSE-2'][:30] 23 | y_house2_mae3 = data['RMSE-3'][:30] 24 | y_house2_mae4 = data['RMSE-4'][:30] 25 | y_house2_mae5 = data['RMSE-5'][:30] 26 | 27 | y_house3_mae1 = data['R2'][:30] 28 | y_house3_mae1 = [max(0, value) for value in y_house3_mae1] 29 | y_house3_mae2 = data['R2-2'][:30] 30 | y_house3_mae2 = [max(0, value) for value in y_house3_mae2] 31 | y_house3_mae3 = data['R2-3'][:30] 32 | y_house3_mae3 = [max(0, value) for value in y_house3_mae3] 33 | y_house3_mae4 = data['R2-4'][:30] 34 | y_house3_mae4 = [max(0, value) for value in y_house3_mae4] 35 | y_house3_mae5 = data['R2-5'][:30] 36 | y_house3_mae5 = [max(0, value) for value in y_house3_mae5] 37 | 38 | plt.figure(figsize=(20, 12)) 39 | 40 | ahead_itf_idx = [14, 8, 28, 20, 18] # 41 | end_itf_idx = [16, 10, 30, 22, 20] 42 | 43 | plt.plot(x, y_house1_mae, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 44 | plt.plot(x, y_house2_mae, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 45 | plt.plot(x, y_house3_mae, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 46 | plt.plot(x, y_house4_mae, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 47 | plt.plot(x, y_house5_mae, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 
48 | y_range = plt.ylim() 49 | y = y_range[1] * 0.95 50 | 51 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 52 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 53 | 54 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 55 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 56 | 57 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 58 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 59 | 60 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 61 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 62 | 63 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 64 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 65 | 66 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 67 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 68 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 69 | 70 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 71 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 72 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 73 | 74 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 75 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 76 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 77 | 78 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 79 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 80 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 81 | 82 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 83 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 84 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 85 | plt.xlabel('Number-of-Subsequence', fontsize=18) 86 | 
plt.legend(loc='upper right',prop={'size': 10}) 87 | plt.grid(False) 88 | plt.ylabel('MAE(W)', fontsize=18) 89 | plt.tight_layout() 90 | plt.show() 91 | 92 | plt.figure(figsize=(20, 12)) 93 | plt.plot(x, y_house1_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 94 | plt.plot(x, y_house1_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 95 | plt.plot(x, y_house1_mae3, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 96 | plt.plot(x, y_house1_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 97 | plt.plot(x, y_house1_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 98 | y_range = plt.ylim() 99 | y = y_range[1] * 0.95 100 | 101 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 102 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 103 | 104 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 105 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 106 | 107 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 108 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 109 | 110 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 111 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 112 | 113 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 114 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 115 | 116 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 117 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 118 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 119 | 120 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 121 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 122 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 123 | 124 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], 
color='brown', alpha=0.5) 125 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 126 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 127 | 128 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 129 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 130 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 131 | 132 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 133 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 134 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 135 | plt.xlabel('Number-of-Subsequence', fontsize=18) 136 | plt.legend(loc='upper right',prop={'size': 10}) 137 | plt.grid(False) 138 | plt.ylabel('MAPE(%)', fontsize=18) 139 | plt.tight_layout() 140 | plt.show() 141 | 142 | plt.figure(figsize=(20, 12)) 143 | plt.plot(x, y_house2_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 144 | plt.plot(x, y_house2_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 145 | plt.plot(x, y_house2_mae3, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 146 | plt.plot(x, y_house2_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 147 | plt.plot(x, y_house2_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 148 | y_range = plt.ylim() 149 | y = y_range[1] * 0.95 150 | 151 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 152 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 153 | 154 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 155 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 156 | 157 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 158 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 159 | 160 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 161 
| plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 162 | 163 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 164 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 165 | 166 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 167 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 168 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 169 | 170 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 171 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 172 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 173 | 174 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 175 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 176 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 177 | 178 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 179 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 180 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 181 | 182 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 183 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 184 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 185 | plt.xlabel('Number-of-Subsequence', fontsize=18) 186 | plt.legend(loc='upper right',prop={'size': 10}) 187 | plt.grid(False) 188 | plt.ylabel('RMSE(W)', fontsize=18) 189 | plt.tight_layout() 190 | plt.show() 191 | 192 | plt.figure(figsize=(20, 12)) 193 | plt.plot(x, y_house3_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 194 | plt.plot(x, y_house3_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 195 | plt.plot(x, y_house3_mae3, label='house3', color='brown', linestyle='-', marker='h', 
linewidth=2.5) 196 | plt.plot(x, y_house3_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 197 | plt.plot(x, y_house3_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 198 | y_range = plt.ylim() 199 | y = y_range[1] * 0.95 200 | 201 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 202 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 203 | 204 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 205 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 206 | 207 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 208 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 209 | 210 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 211 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 212 | 213 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 214 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 215 | 216 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 217 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 218 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 219 | 220 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 221 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 222 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 223 | 224 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 225 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 226 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 227 | 228 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 229 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 230 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 231 | 232 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], 
color='blue', alpha=0.5) 233 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 234 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 235 | plt.xlabel('Number-of-Subsequence', fontsize=18) 236 | plt.legend(loc='upper right',prop={'size': 10}) 237 | plt.grid(False) 238 | plt.ylabel('R2', fontsize=18) 239 | plt.tight_layout() 240 | plt.show() 241 | -------------------------------------------------------------------------------- /FedSpecNet_v2/Subsequence_number_experiment.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | from matplotlib.ticker import FormatStrFormatter 4 | 5 | 6 | def plot_subsequence_result(data_path): 7 | data = pd.read_csv(data_path) 8 | x = data['Number-of-Subsequence'][:30] 9 | y_house1_mae = data['MAE'][:30] 10 | y_house2_mae = data['MAE-2'][:30] 11 | y_house3_mae = data['MAE-3'][:30] 12 | y_house4_mae = data['MAE-4'][:30] 13 | y_house5_mae = data['MAE-5'][:30] 14 | 15 | y_house1_mae1 = data['MAPE'][:30] 16 | y_house1_mae2 = data['MAPE-2'][:30] 17 | y_house1_mae3 = data['MAPE-3'][:30] 18 | y_house1_mae4 = data['MAPE-4'][:30] 19 | y_house1_mae5 = data['MAPE-5'][:30] 20 | 21 | y_house2_mae1 = data['RMSE'][:30] 22 | y_house2_mae2 = data['RMSE-2'][:30] 23 | y_house2_mae3 = data['RMSE-3'][:30] 24 | y_house2_mae4 = data['RMSE-4'][:30] 25 | y_house2_mae5 = data['RMSE-5'][:30] 26 | 27 | y_house3_mae1 = data['R2'][:30] 28 | y_house3_mae1 = [max(0, value) for value in y_house3_mae1] 29 | y_house3_mae2 = data['R2-2'][:30] 30 | y_house3_mae2 = [max(0, value) for value in y_house3_mae2] 31 | y_house3_mae3 = data['R2-3'][:30] 32 | y_house3_mae3 = [max(0, value) for value in y_house3_mae3] 33 | y_house3_mae4 = data['R2-4'][:30] 34 | y_house3_mae4 = [max(0, value) for value in y_house3_mae4] 35 | y_house3_mae5 = data['R2-5'][:30] 36 | y_house3_mae5 = [max(0, value) for value in y_house3_mae5] 37 | 38 | 
plt.figure(figsize=(20, 12)) 39 | 40 | ahead_itf_idx = [14, 8, 28, 20, 18] # 41 | end_itf_idx = [16, 10, 30, 22, 20] 42 | 43 | plt.plot(x, y_house1_mae, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 44 | plt.plot(x, y_house2_mae, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 45 | plt.plot(x, y_house3_mae, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 46 | plt.plot(x, y_house4_mae, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 47 | plt.plot(x, y_house5_mae, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 48 | y_range = plt.ylim() 49 | y = y_range[1] * 0.95 50 | 51 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 52 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 53 | 54 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 55 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 56 | 57 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 58 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 59 | 60 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 61 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 62 | 63 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 64 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 65 | 66 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 67 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 68 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 69 | 70 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 71 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 72 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 73 | 74 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 75 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 76 | 
horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 77 | 78 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 79 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 80 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 81 | 82 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 83 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 84 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 85 | plt.xlabel('Number-of-Subsequence', fontsize=18) 86 | plt.legend(loc='upper right',prop={'size': 10}) 87 | plt.grid(False) 88 | plt.ylabel('MAE(W)', fontsize=18) 89 | plt.tight_layout() 90 | plt.show() 91 | 92 | plt.figure(figsize=(20, 12)) 93 | plt.plot(x, y_house1_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 94 | plt.plot(x, y_house1_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 95 | plt.plot(x, y_house1_mae3, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 96 | plt.plot(x, y_house1_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 97 | plt.plot(x, y_house1_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 98 | y_range = plt.ylim() 99 | y = y_range[1] * 0.95 100 | 101 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 102 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 103 | 104 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 105 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 106 | 107 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 108 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 109 | 110 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 111 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 112 | 113 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 
114 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 115 | 116 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 117 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 118 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 119 | 120 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 121 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 122 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 123 | 124 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 125 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 126 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 127 | 128 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 129 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 130 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 131 | 132 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 133 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 134 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 135 | plt.xlabel('Number-of-Subsequence', fontsize=18) 136 | plt.legend(loc='upper right',prop={'size': 10}) 137 | plt.grid(False) 138 | plt.ylabel('MAPE(%)', fontsize=18) 139 | plt.tight_layout() 140 | plt.show() 141 | 142 | plt.figure(figsize=(20, 12)) 143 | plt.plot(x, y_house2_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 144 | plt.plot(x, y_house2_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 145 | plt.plot(x, y_house2_mae3, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 146 | plt.plot(x, y_house2_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 147 | plt.plot(x, 
y_house2_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 148 | y_range = plt.ylim() 149 | y = y_range[1] * 0.95 150 | 151 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 152 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 153 | 154 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 155 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 156 | 157 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 158 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 159 | 160 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 161 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 162 | 163 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 164 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 165 | 166 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 167 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 168 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 169 | 170 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 171 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 172 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 173 | 174 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 175 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 176 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 177 | 178 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 179 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 180 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 181 | 182 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 183 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 184 | horizontalalignment='center', 
verticalalignment='top', fontsize=10, color='black') 185 | plt.xlabel('Number-of-Subsequence', fontsize=18) 186 | plt.legend(loc='upper right',prop={'size': 10}) 187 | plt.grid(False) 188 | plt.ylabel('RMSE(W)', fontsize=18) 189 | plt.tight_layout() 190 | plt.show() 191 | 192 | plt.figure(figsize=(20, 12)) 193 | plt.plot(x, y_house3_mae1, label='house1', color='orange', linestyle='-', marker='o', linewidth=2.5) 194 | plt.plot(x, y_house3_mae2, label='house2', color='red', linestyle='-', marker='s', linewidth=2.5) 195 | plt.plot(x, y_house3_mae3, label='house3', color='brown', linestyle='-', marker='h', linewidth=2.5) 196 | plt.plot(x, y_house3_mae4, label='house4', color='green', linestyle='-', marker='d', linewidth=2.5) 197 | plt.plot(x, y_house3_mae5, label='house5', color='blue', linestyle='-', marker='d', linewidth=2.5) 198 | y_range = plt.ylim() 199 | y = y_range[1] * 0.95 200 | 201 | plt.axvline(x=ahead_itf_idx[0], color='k', linestyle='--') 202 | plt.axvline(x=end_itf_idx[0], color='k', linestyle='--') 203 | 204 | plt.axvline(x=ahead_itf_idx[1], color='k', linestyle='--') 205 | plt.axvline(x=end_itf_idx[1], color='k', linestyle='--') 206 | 207 | plt.axvline(x=ahead_itf_idx[2], color='k', linestyle='--') 208 | plt.axvline(x=end_itf_idx[2], color='k', linestyle='--') 209 | 210 | plt.axvline(x=ahead_itf_idx[3], color='k', linestyle='--') 211 | plt.axvline(x=end_itf_idx[3], color='k', linestyle='--') 212 | 213 | plt.axvline(x=ahead_itf_idx[4], color='k', linestyle='--') 214 | plt.axvline(x=end_itf_idx[4], color='k', linestyle='--') 215 | 216 | plt.axvspan(ahead_itf_idx[0], end_itf_idx[0], color='orange', alpha=0.5) 217 | plt.text((ahead_itf_idx[0] + end_itf_idx[0]) / 2, y, 'house1_range', 218 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 219 | 220 | plt.axvspan(ahead_itf_idx[1], end_itf_idx[1], color='red', alpha=0.5) 221 | plt.text((ahead_itf_idx[1] + end_itf_idx[1]) / 2, y, 'house2_range', 222 | 
horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 223 | 224 | plt.axvspan(ahead_itf_idx[2], end_itf_idx[2], color='brown', alpha=0.5) 225 | plt.text((ahead_itf_idx[2] + end_itf_idx[2]) / 2, y, 'house3_range', 226 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 227 | 228 | plt.axvspan(ahead_itf_idx[3], end_itf_idx[3], color='green', alpha=0.5) 229 | plt.text((ahead_itf_idx[3] + end_itf_idx[3]) / 2, y, 'house4_range', 230 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 231 | 232 | plt.axvspan(ahead_itf_idx[4], end_itf_idx[4], color='blue', alpha=0.5) 233 | plt.text((ahead_itf_idx[4] + end_itf_idx[4]) / 2, y, 'house5_range', 234 | horizontalalignment='center', verticalalignment='top', fontsize=10, color='black') 235 | plt.xlabel('Number-of-Subsequence', fontsize=18) 236 | plt.legend(loc='upper right',prop={'size': 10}) 237 | plt.grid(False) 238 | plt.ylabel('R2', fontsize=18) 239 | plt.tight_layout() 240 | plt.show() 241 | --------------------------------------------------------------------------------