├── Filters
│   ├── EKF.py
│   ├── EKF_test.py
│   ├── KalmanFilter_test.py
│   └── Linear_KF.py
├── KNet
│   └── KalmanNet_nn.py
├── Pipelines
│   └── Pipeline_EKF.py
├── Plot.py
├── README.md
├── Simulations
│   ├── Extended_sysmdl.py
│   ├── Linear_CA
│   │   ├── data
│   │   │   └── decimated_dt1e-2_T100_r0_randnInit.pt
│   │   └── parameters.py
│   ├── Linear_canonical
│   │   ├── data
│   │   │   └── 2x2_rq020_T100.pt
│   │   └── parameters.py
│   ├── Linear_sysmdl.py
│   ├── Lorenz_Atractor
│   │   ├── data
│   │   │   └── data_gen.pt
│   │   └── parameters.py
│   ├── config.py
│   └── utils.py
├── main_linear_CA.py
├── main_linear_canonical.py
├── main_lor_DT.py
├── main_lor_DT_NLobs.py
├── main_lor_decimation.py
└── requirements.txt
/Filters/EKF.py:
--------------------------------------------------------------------------------
1 | """# **Class: Extended Kalman Filter** 2 | Theoretical non-linear Kalman filter 3 | """ 4 | import torch 5 | 6 | from Simulations.Lorenz_Atractor.parameters import getJacobian 7 | 8 | class ExtendedKalmanFilter: 9 | 10 | def __init__(self, SystemModel, args): 11 | # Device 12 | if args.use_cuda: 13 | self.device = torch.device('cuda') 14 | else: 15 | self.device = torch.device('cpu') 16 | # process model 17 | self.f = SystemModel.f 18 | self.m = SystemModel.m 19 | self.Q = SystemModel.Q.to(self.device) 20 | # observation model 21 | self.h = SystemModel.h 22 | self.n = SystemModel.n 23 | self.R = SystemModel.R.to(self.device) 24 | # sequence length (use maximum length if random length case) 25 | self.T = SystemModel.T 26 | self.T_test = SystemModel.T_test 27 | 28 | # Predict 29 | def Predict(self): 30 | # Predict the 1-st moment of x 31 | self.m1x_prior = self.f(self.m1x_posterior).to(self.device) 32 | # Compute the Jacobians 33 | self.UpdateJacobians(getJacobian(self.m1x_posterior,self.f), getJacobian(self.m1x_prior, self.h)) 34 | # Predict the 2-nd moment of x 35 | self.m2x_prior = torch.bmm(self.batched_F, self.m2x_posterior) 36 | self.m2x_prior = torch.bmm(self.m2x_prior, self.batched_F_T) + self.Q 37 | 38 | # Predict the 1-st moment of y 39 | self.m1y = self.h(self.m1x_prior) 40 | # Predict the 2-nd moment of y 41 | self.m2y = torch.bmm(self.batched_H, self.m2x_prior) 42 | self.m2y = torch.bmm(self.m2y, self.batched_H_T) + self.R 43 | 44 | # Compute the Kalman Gain 45 | def KGain(self): 46 | self.KG = torch.bmm(self.m2x_prior, self.batched_H_T) 47 | self.KG = torch.bmm(self.KG, torch.inverse(self.m2y)) 48 | 49 | #Save KalmanGain 50 | self.KG_array[:,:,:,self.i] = self.KG 51 | self.i += 1 52 | 53 | # Innovation 54 | def Innovation(self, y): 55 | self.dy = y - self.m1y 56 | 57 | # Compute Posterior 58 | def Correct(self): 59 | # Compute the 1-st posterior moment 60 | self.m1x_posterior = self.m1x_prior + torch.bmm(self.KG, self.dy) 61 | 62 | # Compute the 2-nd posterior moment 63 | self.m2x_posterior = torch.bmm(self.m2y, torch.transpose(self.KG, 1, 2)) 64 | self.m2x_posterior = self.m2x_prior - torch.bmm(self.KG, self.m2x_posterior) 65 | 66 | def Update(self, y): 67 | self.Predict() 68 | self.KGain() 69 | self.Innovation(y) 70 | self.Correct() 71 | 72 | return self.m1x_posterior, self.m2x_posterior 73 | 74 | ######################### 75 | 76 | def UpdateJacobians(self, F, H): 77 | self.batched_F = F.to(self.device) 78 | self.batched_F_T = torch.transpose(F,1,2) 79 | self.batched_H = H.to(self.device) 80 | self.batched_H_T = torch.transpose(H,1,2) 81 | 82 | def Init_batched_sequence(self, m1x_0_batch, m2x_0_batch): 83 | 84 | self.m1x_0_batch = m1x_0_batch # [batch_size, m, 1] 85 | self.m2x_0_batch = m2x_0_batch # [batch_size, m, m] 86 | 87 | ###################### 88
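# Note on the Jacobians: getJacobian is imported from
# Simulations/Lorenz_Atractor/parameters.py and linearizes f and h around the
# current estimates; from its usage in Predict() above it must return batched
# [batch_size, m, m] and [batch_size, n, m] matrices. As a rough,
# self-contained sketch of the same idea with plain autograd (an illustration
# only, not this repo's implementation; `fn` and `x` are hypothetical names,
# x of shape [batch_size, m, 1] and fn mapping [m, 1] to [m, 1]):
#
#     def batched_jacobian(fn, x):
#         # stack one [m, m] Jacobian of fn per batch element
#         jacs = [torch.autograd.functional.jacobian(fn, x[i]) for i in range(x.shape[0])]
#         return torch.stack(jacs).reshape(x.shape[0], x.shape[1], x.shape[1])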
| ### Generate Batch ### 89 | ###################### 90 | def GenerateBatch(self, y): 91 | """ 92 | input y: batch of observations [batch_size, n, T] 93 | """ 94 | y = y.to(self.device) 95 | self.batch_size = y.shape[0] # batch size 96 | T = y.shape[2] # sequence length (maximum length if randomLength=True) 97 | 98 | # Pre allocate KG array 99 | self.KG_array = torch.zeros([self.batch_size,self.m,self.n,T]).to(self.device) 100 | self.i = 0 # Index for KG_array alocation 101 | 102 | # Allocate Array for 1st and 2nd order moments (use zero padding) 103 | self.x = torch.zeros(self.batch_size, self.m, T).to(self.device) 104 | self.sigma = torch.zeros(self.batch_size, self.m, self.m, T).to(self.device) 105 | 106 | # Set 1st and 2nd order moments for t=0 107 | self.m1x_posterior = self.m1x_0_batch.to(self.device) 108 | self.m2x_posterior = self.m2x_0_batch.to(self.device) 109 | 110 | # Generate in a batched manner 111 | for t in range(0, T): 112 | yt = torch.unsqueeze(y[:, :, t],2) 113 | xt,sigmat = self.Update(yt) 114 | self.x[:, :, t] = torch.squeeze(xt,2) 115 | self.sigma[:, :, :, t] = sigmat -------------------------------------------------------------------------------- /Filters/EKF_test.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import time 4 | from Filters.EKF import ExtendedKalmanFilter 5 | 6 | 7 | def EKFTest(args, SysModel, test_input, test_target, allStates=True,\ 8 | randomInit = False,test_init=None, test_lengthMask=None): 9 | # Number of test samples 10 | N_T = test_target.size()[0] 11 | # LOSS 12 | loss_fn = nn.MSELoss(reduction='mean') 13 | # MSE [Linear] 14 | MSE_EKF_linear_arr = torch.zeros(N_T) 15 | # Allocate empty tensor for output 16 | EKF_out = torch.zeros([N_T, SysModel.m, test_input.size()[2]]) # N_T x m x T 17 | KG_array = torch.zeros([N_T, SysModel.m, SysModel.n, test_input.size()[2]]) # N_T x m x n x T 18 | 19 | if not allStates: 20 | loc = torch.tensor([True,False,False]) # for position only 21 | if SysModel.m == 2: 22 | loc = torch.tensor([True,False]) # for position only 23 | 24 | start = time.time() 25 | EKF = ExtendedKalmanFilter(SysModel, args) 26 | # Init and Forward Computation 27 | if(randomInit): 28 | EKF.Init_batched_sequence(test_init, SysModel.m2x_0.view(1,SysModel.m,SysModel.m).expand(N_T,-1,-1)) 29 | else: 30 | EKF.Init_batched_sequence(SysModel.m1x_0.view(1,SysModel.m,1).expand(N_T,-1,-1), SysModel.m2x_0.view(1,SysModel.m,SysModel.m).expand(N_T,-1,-1)) 31 | EKF.GenerateBatch(test_input) 32 | 33 | end = time.time() 34 | t = end - start 35 | 36 | KG_array = EKF.KG_array 37 | EKF_out = EKF.x 38 | 39 | # MSE loss 40 | for j in range(N_T):# cannot use batch due to different length and std computation 41 | if(allStates): 42 | if args.randomLength: 43 | MSE_EKF_linear_arr[j] = loss_fn(EKF.x[j,:,test_lengthMask[j]], test_target[j,:,test_lengthMask[j]]).item() 44 | else: 45 | MSE_EKF_linear_arr[j] = loss_fn(EKF.x[j,:,:], test_target[j,:,:]).item() 46 | else: # mask on state 47 | if args.randomLength: 48 | MSE_EKF_linear_arr[j] = loss_fn(EKF.x[j,loc,test_lengthMask[j]], test_target[j,loc,test_lengthMask[j]]).item() 49 | else: 50 | MSE_EKF_linear_arr[j] = loss_fn(EKF.x[j,loc,:], test_target[j,loc,:]).item() 51 | 52 | MSE_EKF_linear_avg = torch.mean(MSE_EKF_linear_arr) 53 | MSE_EKF_dB_avg = 10 * torch.log10(MSE_EKF_linear_avg) 54 | 55 | # Standard deviation 56 | MSE_EKF_linear_std = torch.std(MSE_EKF_linear_arr, unbiased=True) 57 | 58 | # Confidence interval 59 | EKF_std_dB = 
10 * torch.log10(MSE_EKF_linear_std + MSE_EKF_linear_avg) - MSE_EKF_dB_avg 60 | 61 | print("Extended Kalman Filter - MSE LOSS:", MSE_EKF_dB_avg, "[dB]") 62 | print("Extended Kalman Filter - STD:", EKF_std_dB, "[dB]") 63 | # Print Run Time 64 | print("Inference Time:", t) 65 | 66 | return [MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg, KG_array, EKF_out] 67 | 68 | 69 | -------------------------------------------------------------------------------- /Filters/KalmanFilter_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | from Filters.Linear_KF import KalmanFilter 5 | 6 | def KFTest(args, SysModel, test_input, test_target, allStates=True,\ 7 | randomInit = False, test_init=None, test_lengthMask=None): 8 | 9 | # LOSS 10 | loss_fn = nn.MSELoss(reduction='mean') 11 | 12 | # MSE [Linear] 13 | MSE_KF_linear_arr = torch.zeros(args.N_T) 14 | # allocate memory for KF output 15 | KF_out = torch.zeros(args.N_T, SysModel.m, args.T_test) 16 | if not allStates: 17 | loc = torch.tensor([True,False,False]) # for position only 18 | if SysModel.m == 2: 19 | loc = torch.tensor([True,False]) # for position only 20 | 21 | start = time.time() 22 | 23 | KF = KalmanFilter(SysModel, args) 24 | # Init and Forward Computation 25 | if(randomInit): 26 | KF.Init_batched_sequence(test_init, SysModel.m2x_0.view(1,SysModel.m,SysModel.m).expand(args.N_T,-1,-1)) 27 | else: 28 | KF.Init_batched_sequence(SysModel.m1x_0.view(1,SysModel.m,1).expand(args.N_T,-1,-1), SysModel.m2x_0.view(1,SysModel.m,SysModel.m).expand(args.N_T,-1,-1)) 29 | KF.GenerateBatch(test_input) 30 | 31 | end = time.time() 32 | t = end - start 33 | KF_out = KF.x 34 | # MSE loss 35 | for j in range(args.N_T):# cannot use batch due to different length and std computation 36 | if(allStates): 37 | if args.randomLength: 38 | MSE_KF_linear_arr[j] = loss_fn(KF.x[j,:,test_lengthMask[j]], test_target[j,:,test_lengthMask[j]]).item() 39 | else: 40 | MSE_KF_linear_arr[j] = loss_fn(KF.x[j,:,:], test_target[j,:,:]).item() 41 | else: # mask on state 42 | if args.randomLength: 43 | MSE_KF_linear_arr[j] = loss_fn(KF.x[j,loc,test_lengthMask[j]], test_target[j,loc,test_lengthMask[j]]).item() 44 | else: 45 | MSE_KF_linear_arr[j] = loss_fn(KF.x[j,loc,:], test_target[j,loc,:]).item() 46 | 47 | MSE_KF_linear_avg = torch.mean(MSE_KF_linear_arr) 48 | MSE_KF_dB_avg = 10 * torch.log10(MSE_KF_linear_avg) 49 | 50 | # Standard deviation 51 | MSE_KF_linear_std = torch.std(MSE_KF_linear_arr, unbiased=True) 52 | 53 | # Confidence interval 54 | KF_std_dB = 10 * torch.log10(MSE_KF_linear_std + MSE_KF_linear_avg) - MSE_KF_dB_avg 55 | 56 | print("Kalman Filter - MSE LOSS:", MSE_KF_dB_avg, "[dB]") 57 | print("Kalman Filter - STD:", KF_std_dB, "[dB]") 58 | # Print Run Time 59 | print("Inference Time:", t) 60 | return [MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, KF_out] 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /Filters/Linear_KF.py: -------------------------------------------------------------------------------- 1 | """# **Class: Kalman Filter** 2 | Theoretical Linear Kalman Filter 3 | batched version 4 | """ 5 | import torch 6 | 7 | class KalmanFilter: 8 | 9 | def __init__(self, SystemModel, args): 10 | # Device 11 | if args.use_cuda: 12 | self.device = torch.device('cuda') 13 | else: 14 | self.device = torch.device('cpu') 15 | self.F = SystemModel.F 16 | self.m = SystemModel.m 17 | self.Q = SystemModel.Q.to(self.device) 
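# The batched recursions implemented below are the standard Kalman filter
# equations, applied per batch element via torch.bmm:
#   predict: m1x_prior = F m1x_posterior,  m2x_prior = F m2x_posterior F^T + Q
#   gain:    m2y = H m2x_prior H^T + R,    KG = m2x_prior H^T m2y^{-1}
#   correct: m1x_posterior = m1x_prior + KG (y - m1y)
#            m2x_posterior = m2x_prior - KG m2y KG^T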
18 | 19 | self.H = SystemModel.H 20 | self.n = SystemModel.n 21 | self.R = SystemModel.R.to(self.device) 22 | 23 | self.T = SystemModel.T 24 | self.T_test = SystemModel.T_test 25 | 26 | # Predict 27 | 28 | def Predict(self): 29 | # Predict the 1-st moment of x 30 | self.m1x_prior = torch.bmm(self.batched_F, self.m1x_posterior).to(self.device) 31 | 32 | # Predict the 2-nd moment of x 33 | self.m2x_prior = torch.bmm(self.batched_F, self.m2x_posterior) 34 | self.m2x_prior = torch.bmm(self.m2x_prior, self.batched_F_T) + self.Q 35 | 36 | # Predict the 1-st moment of y 37 | self.m1y = torch.bmm(self.batched_H, self.m1x_prior) 38 | 39 | # Predict the 2-nd moment of y 40 | self.m2y = torch.bmm(self.batched_H, self.m2x_prior) 41 | self.m2y = torch.bmm(self.m2y, self.batched_H_T) + self.R 42 | 43 | # Compute the Kalman Gain 44 | def KGain(self): 45 | self.KG = torch.bmm(self.m2x_prior, self.batched_H_T) 46 | 47 | self.KG = torch.bmm(self.KG, torch.inverse(self.m2y)) 48 | 49 | # Innovation 50 | def Innovation(self, y): 51 | self.dy = y - self.m1y 52 | 53 | # Compute Posterior 54 | def Correct(self): 55 | # Compute the 1-st posterior moment 56 | self.m1x_posterior = self.m1x_prior + torch.bmm(self.KG, self.dy) 57 | 58 | # Compute the 2-nd posterior moment 59 | self.m2x_posterior = torch.bmm(self.m2y, torch.transpose(self.KG, 1, 2)) 60 | self.m2x_posterior = self.m2x_prior - torch.bmm(self.KG, self.m2x_posterior) 61 | 62 | def Update(self, y): 63 | self.Predict() 64 | self.KGain() 65 | self.Innovation(y) 66 | self.Correct() 67 | 68 | return self.m1x_posterior,self.m2x_posterior 69 | 70 | def Init_batched_sequence(self, m1x_0_batch, m2x_0_batch): 71 | 72 | self.m1x_0_batch = m1x_0_batch # [batch_size, m, 1] 73 | self.m2x_0_batch = m2x_0_batch # [batch_size, m, m] 74 | 75 | ###################### 76 | ### Generate Batch ### 77 | ###################### 78 | def GenerateBatch(self, y): 79 | """ 80 | input y: batch of observations [batch_size, n, T] 81 | """ 82 | y = y.to(self.device) 83 | self.batch_size = y.shape[0] # batch size 84 | T = y.shape[2] # sequence length (maximum length if randomLength=True) 85 | 86 | # Batched F and H 87 | self.batched_F = self.F.view(1,self.m,self.m).expand(self.batch_size,-1,-1).to(self.device) 88 | self.batched_F_T = torch.transpose(self.batched_F, 1, 2).to(self.device) 89 | self.batched_H = self.H.view(1,self.n,self.m).expand(self.batch_size,-1,-1).to(self.device) 90 | self.batched_H_T = torch.transpose(self.batched_H, 1, 2).to(self.device) 91 | 92 | # Allocate Array for 1st and 2nd order moments (use zero padding) 93 | self.x = torch.zeros(self.batch_size, self.m, T).to(self.device) 94 | self.sigma = torch.zeros(self.batch_size, self.m, self.m, T).to(self.device) 95 | 96 | # Set 1st and 2nd order moments for t=0 97 | self.m1x_posterior = self.m1x_0_batch.to(self.device) 98 | self.m2x_posterior = self.m2x_0_batch.to(self.device) 99 | 100 | # Generate in a batched manner 101 | for t in range(0, T): 102 | yt = torch.unsqueeze(y[:, :, t],2) 103 | xt,sigmat = self.Update(yt) 104 | self.x[:, :, t] = torch.squeeze(xt,2) 105 | self.sigma[:, :, :, t] = sigmat 106 | -------------------------------------------------------------------------------- /KNet/KalmanNet_nn.py: -------------------------------------------------------------------------------- 1 | """# **Class: KalmanNet**""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as func 6 | 7 | class KalmanNetNN(torch.nn.Module): 8 | 9 | ################### 10 | ### Constructor ### 11 | 
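# Architecture overview: KalmanNetNN keeps the model-based flow of the EKF
# (f and h are still applied explicitly in step_prior) but replaces the
# analytical Kalman gain computation with a learned one. Three GRUs track
# surrogates of the process-noise covariance Q, the state covariance Sigma,
# and the innovation covariance S; fully connected layers map their hidden
# states to the flattened m x n Kalman gain (see KGain_step).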
################### 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def NNBuild(self, SysModel, args): 16 | 17 | # Device 18 | if args.use_cuda: 19 | self.device = torch.device('cuda') 20 | else: 21 | self.device = torch.device('cpu') 22 | 23 | self.InitSystemDynamics(SysModel.f, SysModel.h, SysModel.m, SysModel.n) 24 | 25 | # Number of neurons in the 1st hidden layer 26 | #H1_KNet = (SysModel.m + SysModel.n) * (10) * 8 27 | 28 | # Number of neurons in the 2nd hidden layer 29 | #H2_KNet = (SysModel.m * SysModel.n) * 1 * (4) 30 | 31 | self.InitKGainNet(SysModel.prior_Q, SysModel.prior_Sigma, SysModel.prior_S, args) 32 | 33 | ###################################### 34 | ### Initialize Kalman Gain Network ### 35 | ###################################### 36 | def InitKGainNet(self, prior_Q, prior_Sigma, prior_S, args): 37 | 38 | self.seq_len_input = 1 # KNet calculates time-step by time-step 39 | self.batch_size = args.n_batch # Batch size 40 | 41 | self.prior_Q = prior_Q.to(self.device) 42 | self.prior_Sigma = prior_Sigma.to(self.device) 43 | self.prior_S = prior_S.to(self.device) 44 | 45 | 46 | 47 | # GRU to track Q 48 | self.d_input_Q = self.m * args.in_mult_KNet 49 | self.d_hidden_Q = self.m ** 2 50 | self.GRU_Q = nn.GRU(self.d_input_Q, self.d_hidden_Q).to(self.device) 51 | 52 | # GRU to track Sigma 53 | self.d_input_Sigma = self.d_hidden_Q + self.m * args.in_mult_KNet 54 | self.d_hidden_Sigma = self.m ** 2 55 | self.GRU_Sigma = nn.GRU(self.d_input_Sigma, self.d_hidden_Sigma).to(self.device) 56 | 57 | # GRU to track S 58 | self.d_input_S = self.n ** 2 + 2 * self.n * args.in_mult_KNet 59 | self.d_hidden_S = self.n ** 2 60 | self.GRU_S = nn.GRU(self.d_input_S, self.d_hidden_S).to(self.device) 61 | 62 | # Fully connected 1 63 | self.d_input_FC1 = self.d_hidden_Sigma 64 | self.d_output_FC1 = self.n ** 2 65 | self.FC1 = nn.Sequential( 66 | nn.Linear(self.d_input_FC1, self.d_output_FC1), 67 | nn.ReLU()).to(self.device) 68 | 69 | # Fully connected 2 70 | self.d_input_FC2 = self.d_hidden_S + self.d_hidden_Sigma 71 | self.d_output_FC2 = self.n * self.m 72 | self.d_hidden_FC2 = self.d_input_FC2 * args.out_mult_KNet 73 | self.FC2 = nn.Sequential( 74 | nn.Linear(self.d_input_FC2, self.d_hidden_FC2), 75 | nn.ReLU(), 76 | nn.Linear(self.d_hidden_FC2, self.d_output_FC2)).to(self.device) 77 | 78 | # Fully connected 3 79 | self.d_input_FC3 = self.d_hidden_S + self.d_output_FC2 80 | self.d_output_FC3 = self.m ** 2 81 | self.FC3 = nn.Sequential( 82 | nn.Linear(self.d_input_FC3, self.d_output_FC3), 83 | nn.ReLU()).to(self.device) 84 | 85 | # Fully connected 4 86 | self.d_input_FC4 = self.d_hidden_Sigma + self.d_output_FC3 87 | self.d_output_FC4 = self.d_hidden_Sigma 88 | self.FC4 = nn.Sequential( 89 | nn.Linear(self.d_input_FC4, self.d_output_FC4), 90 | nn.ReLU()).to(self.device) 91 | 92 | # Fully connected 5 93 | self.d_input_FC5 = self.m 94 | self.d_output_FC5 = self.m * args.in_mult_KNet 95 | self.FC5 = nn.Sequential( 96 | nn.Linear(self.d_input_FC5, self.d_output_FC5), 97 | nn.ReLU()).to(self.device) 98 | 99 | # Fully connected 6 100 | self.d_input_FC6 = self.m 101 | self.d_output_FC6 = self.m * args.in_mult_KNet 102 | self.FC6 = nn.Sequential( 103 | nn.Linear(self.d_input_FC6, self.d_output_FC6), 104 | nn.ReLU()).to(self.device) 105 | 106 | # Fully connected 7 107 | self.d_input_FC7 = 2 * self.n 108 | self.d_output_FC7 = 2 * self.n * args.in_mult_KNet 109 | self.FC7 = nn.Sequential( 110 | nn.Linear(self.d_input_FC7, self.d_output_FC7), 111 | nn.ReLU()).to(self.device) 112 | 113 | 
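# Dimension bookkeeping for the blocks above, worked through on a small
# example (illustrative values only: a 2x2 model, i.e. m = n = 2, with
# hypothetical multipliers in_mult_KNet = 5, out_mult_KNet = 40):
#   GRU_Q:     input m * 5 = 10,           hidden m^2 = 4
#   GRU_Sigma: input m^2 + m * 5 = 14,     hidden m^2 = 4
#   GRU_S:     input n^2 + 2 * n * 5 = 24, hidden n^2 = 4
#   FC2:       (n^2 + m^2) = 8 -> n * m = 4, the flattened Kalman gain entries.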
################################## 114 | ### Initialize System Dynamics ### 115 | ################################## 116 | def InitSystemDynamics(self, f, h, m, n): 117 | 118 | # Set State Evolution Function 119 | self.f = f 120 | self.m = m 121 | 122 | # Set Observation Function 123 | self.h = h 124 | self.n = n 125 | 126 | ########################### 127 | ### Initialize Sequence ### 128 | ########################### 129 | def InitSequence(self, M1_0, T): 130 | """ 131 | input M1_0 (torch.tensor): 1st moment of x at time 0 [batch_size, m, 1] 132 | """ 133 | self.T = T 134 | 135 | self.m1x_posterior = M1_0.to(self.device) 136 | self.m1x_posterior_previous = self.m1x_posterior 137 | self.m1x_prior_previous = self.m1x_posterior 138 | self.y_previous = self.h(self.m1x_posterior) 139 | 140 | ###################### 141 | ### Compute Priors ### 142 | ###################### 143 | def step_prior(self): 144 | # Predict the 1-st moment of x 145 | self.m1x_prior = self.f(self.m1x_posterior) 146 | 147 | # Predict the 1-st moment of y 148 | self.m1y = self.h(self.m1x_prior) 149 | 150 | ############################## 151 | ### Kalman Gain Estimation ### 152 | ############################## 153 | def step_KGain_est(self, y): 154 | # both in size [batch_size, n] 155 | obs_diff = torch.squeeze(y,2) - torch.squeeze(self.y_previous,2) 156 | obs_innov_diff = torch.squeeze(y,2) - torch.squeeze(self.m1y,2) 157 | # both in size [batch_size, m] 158 | fw_evol_diff = torch.squeeze(self.m1x_posterior,2) - torch.squeeze(self.m1x_posterior_previous,2) 159 | fw_update_diff = torch.squeeze(self.m1x_posterior,2) - torch.squeeze(self.m1x_prior_previous,2) 160 | 161 | obs_diff = func.normalize(obs_diff, p=2, dim=1, eps=1e-12, out=None) 162 | obs_innov_diff = func.normalize(obs_innov_diff, p=2, dim=1, eps=1e-12, out=None) 163 | fw_evol_diff = func.normalize(fw_evol_diff, p=2, dim=1, eps=1e-12, out=None) 164 | fw_update_diff = func.normalize(fw_update_diff, p=2, dim=1, eps=1e-12, out=None) 165 | 166 | # Kalman Gain Network Step 167 | KG = self.KGain_step(obs_diff, obs_innov_diff, fw_evol_diff, fw_update_diff) 168 | 169 | # Reshape Kalman Gain to a Matrix 170 | self.KGain = torch.reshape(KG, (self.batch_size, self.m, self.n)) 171 | 172 | ####################### 173 | ### Kalman Net Step ### 174 | ####################### 175 | def KNet_step(self, y): 176 | 177 | # Compute Priors 178 | self.step_prior() 179 | 180 | # Compute Kalman Gain 181 | self.step_KGain_est(y) 182 | 183 | # Innovation 184 | dy = y - self.m1y # [batch_size, n, 1] 185 | 186 | # Compute the 1-st posterior moment 187 | INOV = torch.bmm(self.KGain, dy) 188 | self.m1x_posterior_previous = self.m1x_posterior 189 | self.m1x_posterior = self.m1x_prior + INOV 190 | 191 | #self.state_process_posterior_0 = self.state_process_prior_0 192 | self.m1x_prior_previous = self.m1x_prior 193 | 194 | # update y_prev 195 | self.y_previous = y 196 | 197 | # return 198 | return self.m1x_posterior 199 | 200 | ######################## 201 | ### Kalman Gain Step ### 202 | ######################## 203 | def KGain_step(self, obs_diff, obs_innov_diff, fw_evol_diff, fw_update_diff): 204 | 205 | def expand_dim(x): 206 | expanded = torch.empty(self.seq_len_input, self.batch_size, x.shape[-1]).to(self.device) 207 | expanded[0, :, :] = x 208 | return expanded 209 | 210 | obs_diff = expand_dim(obs_diff) 211 | obs_innov_diff = expand_dim(obs_innov_diff) 212 | fw_evol_diff = expand_dim(fw_evol_diff) 213 | fw_update_diff = expand_dim(fw_update_diff) 214 | 215 | #################### 216 | ### 
Forward Flow ### 217 | #################### 218 | 219 | # FC 5 220 | in_FC5 = fw_update_diff 221 | out_FC5 = self.FC5(in_FC5) 222 | 223 | # Q-GRU 224 | in_Q = out_FC5 225 | out_Q, self.h_Q = self.GRU_Q(in_Q, self.h_Q) 226 | 227 | # FC 6 228 | in_FC6 = fw_evol_diff 229 | out_FC6 = self.FC6(in_FC6) 230 | 231 | # Sigma_GRU 232 | in_Sigma = torch.cat((out_Q, out_FC6), 2) 233 | out_Sigma, self.h_Sigma = self.GRU_Sigma(in_Sigma, self.h_Sigma) 234 | 235 | # FC 1 236 | in_FC1 = out_Sigma 237 | out_FC1 = self.FC1(in_FC1) 238 | 239 | # FC 7 240 | in_FC7 = torch.cat((obs_diff, obs_innov_diff), 2) 241 | out_FC7 = self.FC7(in_FC7) 242 | 243 | 244 | # S-GRU 245 | in_S = torch.cat((out_FC1, out_FC7), 2) 246 | out_S, self.h_S = self.GRU_S(in_S, self.h_S) 247 | 248 | 249 | # FC 2 250 | in_FC2 = torch.cat((out_Sigma, out_S), 2) 251 | out_FC2 = self.FC2(in_FC2) 252 | 253 | ##################### 254 | ### Backward Flow ### 255 | ##################### 256 | 257 | # FC 3 258 | in_FC3 = torch.cat((out_S, out_FC2), 2) 259 | out_FC3 = self.FC3(in_FC3) 260 | 261 | # FC 4 262 | in_FC4 = torch.cat((out_Sigma, out_FC3), 2) 263 | out_FC4 = self.FC4(in_FC4) 264 | 265 | # updating hidden state of the Sigma-GRU 266 | self.h_Sigma = out_FC4 267 | 268 | return out_FC2 269 | ############### 270 | ### Forward ### 271 | ############### 272 | def forward(self, y): 273 | y = y.to(self.device) 274 | return self.KNet_step(y) 275 | 276 | ######################### 277 | ### Init Hidden State ### 278 | ######################### 279 | def init_hidden_KNet(self): 280 | weight = next(self.parameters()).data 281 | hidden = weight.new(self.seq_len_input, self.batch_size, self.d_hidden_S).zero_() 282 | self.h_S = hidden.data 283 | self.h_S = self.prior_S.flatten().reshape(1, 1, -1).repeat(self.seq_len_input,self.batch_size, 1) # batch size expansion 284 | hidden = weight.new(self.seq_len_input, self.batch_size, self.d_hidden_Sigma).zero_() 285 | self.h_Sigma = hidden.data 286 | self.h_Sigma = self.prior_Sigma.flatten().reshape(1,1, -1).repeat(self.seq_len_input,self.batch_size, 1) # batch size expansion 287 | hidden = weight.new(self.seq_len_input, self.batch_size, self.d_hidden_Q).zero_() 288 | self.h_Q = hidden.data 289 | self.h_Q = self.prior_Q.flatten().reshape(1,1, -1).repeat(self.seq_len_input,self.batch_size, 1) # batch size expansion 290 | 291 | 292 | 293 | -------------------------------------------------------------------------------- /Pipelines/Pipeline_EKF.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the class Pipeline_EKF, 3 | which is used to train and test KalmanNet. 
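A minimal usage sketch (illustrative only; `args`, `sys_model`, the
KalmanNet instance and the data tensors are assumed to be built by the
main_* scripts):

    pipeline = Pipeline_EKF(strTime, "KNet", "KalmanNet")
    pipeline.setssModel(sys_model)
    pipeline.setModel(KalmanNet_model)
    pipeline.setTrainingParams(args)
    pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results)
    pipeline.NNTest(sys_model, test_input, test_target, path_results)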
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import random 9 | import time 10 | from Plot import Plot_extended 11 | 12 | 13 | class Pipeline_EKF: 14 | 15 | def __init__(self, Time, folderName, modelName): 16 | super().__init__() 17 | self.Time = Time 18 | self.folderName = folderName + '/' 19 | self.modelName = modelName 20 | self.modelFileName = self.folderName + "model_" + self.modelName + ".pt" 21 | self.PipelineName = self.folderName + "pipeline_" + self.modelName + ".pt" 22 | 23 | def save(self): 24 | torch.save(self, self.PipelineName) 25 | 26 | def setssModel(self, ssModel): 27 | self.ssModel = ssModel 28 | 29 | def setModel(self, model): 30 | self.model = model 31 | 32 | def setTrainingParams(self, args): 33 | self.args = args 34 | if args.use_cuda: 35 | self.device = torch.device('cuda') 36 | else: 37 | self.device = torch.device('cpu') 38 | self.N_steps = args.n_steps # Number of Training Steps 39 | self.N_B = args.n_batch # Number of Samples in Batch 40 | self.learningRate = args.lr # Learning Rate 41 | self.weightDecay = args.wd # L2 Weight Regularization - Weight Decay 42 | self.alpha = args.alpha # Composition loss factor 43 | # MSE LOSS Function 44 | self.loss_fn = nn.MSELoss(reduction='mean') 45 | 46 | # Use the optim package to define an Optimizer that will update the weights of 47 | # the model for us. Here we will use Adam; the optim package contains many other 48 | # optimization algoriths. The first argument to the Adam constructor tells the 49 | # optimizer which Tensors it should update. 50 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learningRate, weight_decay=self.weightDecay) 51 | 52 | def NNTrain(self, SysModel, cv_input, cv_target, train_input, train_target, path_results, \ 53 | MaskOnState=False, randomInit=False,cv_init=None,train_init=None,\ 54 | train_lengthMask=None,cv_lengthMask=None): 55 | 56 | self.N_E = len(train_input) 57 | self.N_CV = len(cv_input) 58 | 59 | self.MSE_cv_linear_epoch = torch.zeros([self.N_steps]) 60 | self.MSE_cv_dB_epoch = torch.zeros([self.N_steps]) 61 | 62 | self.MSE_train_linear_epoch = torch.zeros([self.N_steps]) 63 | self.MSE_train_dB_epoch = torch.zeros([self.N_steps]) 64 | 65 | if MaskOnState: 66 | mask = torch.tensor([True,False,False]) 67 | if SysModel.m == 2: 68 | mask = torch.tensor([True,False]) 69 | 70 | ############## 71 | ### Epochs ### 72 | ############## 73 | 74 | self.MSE_cv_dB_opt = 1000 75 | self.MSE_cv_idx_opt = 0 76 | 77 | for ti in range(0, self.N_steps): 78 | 79 | ############################### 80 | ### Training Sequence Batch ### 81 | ############################### 82 | self.optimizer.zero_grad() 83 | # Training Mode 84 | self.model.train() 85 | self.model.batch_size = self.N_B 86 | # Init Hidden State 87 | self.model.init_hidden_KNet() 88 | 89 | # Init Training Batch tensors 90 | y_training_batch = torch.zeros([self.N_B, SysModel.n, SysModel.T]).to(self.device) 91 | train_target_batch = torch.zeros([self.N_B, SysModel.m, SysModel.T]).to(self.device) 92 | x_out_training_batch = torch.zeros([self.N_B, SysModel.m, SysModel.T]).to(self.device) 93 | if self.args.randomLength: 94 | MSE_train_linear_LOSS = torch.zeros([self.N_B]) 95 | MSE_cv_linear_LOSS = torch.zeros([self.N_CV]) 96 | 97 | # Randomly select N_B training sequences 98 | assert self.N_B <= self.N_E # N_B must be smaller than N_E 99 | n_e = random.sample(range(self.N_E), k=self.N_B) 100 | ii = 0 101 | for index in n_e: 102 | if self.args.randomLength: 103 | y_training_batch[ii,:,train_lengthMask[index,:]] = 
train_input[index,:,train_lengthMask[index,:]] 104 | train_target_batch[ii,:,train_lengthMask[index,:]] = train_target[index,:,train_lengthMask[index,:]] 105 | else: 106 | y_training_batch[ii,:,:] = train_input[index] 107 | train_target_batch[ii,:,:] = train_target[index] 108 | ii += 1 109 | 110 | # Init Sequence 111 | if(randomInit): 112 | train_init_batch = torch.empty([self.N_B, SysModel.m,1]).to(self.device) 113 | ii = 0 114 | for index in n_e: 115 | train_init_batch[ii,:,0] = torch.squeeze(train_init[index]) 116 | ii += 1 117 | self.model.InitSequence(train_init_batch, SysModel.T) 118 | else: 119 | self.model.InitSequence(\ 120 | SysModel.m1x_0.reshape(1,SysModel.m,1).repeat(self.N_B,1,1), SysModel.T) 121 | 122 | # Forward Computation 123 | for t in range(0, SysModel.T): 124 | x_out_training_batch[:, :, t] = torch.squeeze(self.model(torch.unsqueeze(y_training_batch[:, :, t],2))) 125 | 126 | # Compute Training Loss 127 | MSE_trainbatch_linear_LOSS = 0 128 | if (self.args.CompositionLoss): 129 | y_hat = torch.zeros([self.N_B, SysModel.n, SysModel.T]) 130 | for t in range(SysModel.T): 131 | y_hat[:,:,t] = torch.squeeze(SysModel.h(torch.unsqueeze(x_out_training_batch[:,:,t]))) 132 | 133 | if(MaskOnState):### FIXME: composition loss, y_hat may have different mask with x 134 | if self.args.randomLength: 135 | jj = 0 136 | for index in n_e:# mask out the padded part when computing loss 137 | MSE_train_linear_LOSS[jj] = self.alpha * self.loss_fn(x_out_training_batch[jj,mask,train_lengthMask[index]], train_target_batch[jj,mask,train_lengthMask[index]])+(1-self.alpha)*self.loss_fn(y_hat[jj,mask,train_lengthMask[index]], y_training_batch[jj,mask,train_lengthMask[index]]) 138 | jj += 1 139 | MSE_trainbatch_linear_LOSS = torch.mean(MSE_train_linear_LOSS) 140 | else: 141 | MSE_trainbatch_linear_LOSS = self.alpha * self.loss_fn(x_out_training_batch[:,mask,:], train_target_batch[:,mask,:])+(1-self.alpha)*self.loss_fn(y_hat[:,mask,:], y_training_batch[:,mask,:]) 142 | else:# no mask on state 143 | if self.args.randomLength: 144 | jj = 0 145 | for index in n_e:# mask out the padded part when computing loss 146 | MSE_train_linear_LOSS[jj] = self.alpha * self.loss_fn(x_out_training_batch[jj,:,train_lengthMask[index]], train_target_batch[jj,:,train_lengthMask[index]])+(1-self.alpha)*self.loss_fn(y_hat[jj,:,train_lengthMask[index]], y_training_batch[jj,:,train_lengthMask[index]]) 147 | jj += 1 148 | MSE_trainbatch_linear_LOSS = torch.mean(MSE_train_linear_LOSS) 149 | else: 150 | MSE_trainbatch_linear_LOSS = self.alpha * self.loss_fn(x_out_training_batch, train_target_batch)+(1-self.alpha)*self.loss_fn(y_hat, y_training_batch) 151 | 152 | else:# no composition loss 153 | if(MaskOnState): 154 | if self.args.randomLength: 155 | jj = 0 156 | for index in n_e:# mask out the padded part when computing loss 157 | MSE_train_linear_LOSS[jj] = self.loss_fn(x_out_training_batch[jj,mask,train_lengthMask[index]], train_target_batch[jj,mask,train_lengthMask[index]]) 158 | jj += 1 159 | MSE_trainbatch_linear_LOSS = torch.mean(MSE_train_linear_LOSS) 160 | else: 161 | MSE_trainbatch_linear_LOSS = self.loss_fn(x_out_training_batch[:,mask,:], train_target_batch[:,mask,:]) 162 | else: # no mask on state 163 | if self.args.randomLength: 164 | jj = 0 165 | for index in n_e:# mask out the padded part when computing loss 166 | MSE_train_linear_LOSS[jj] = self.loss_fn(x_out_training_batch[jj,:,train_lengthMask[index]], train_target_batch[jj,:,train_lengthMask[index]]) 167 | jj += 1 168 | MSE_trainbatch_linear_LOSS = 
torch.mean(MSE_train_linear_LOSS) 169 | else: 170 | MSE_trainbatch_linear_LOSS = self.loss_fn(x_out_training_batch, train_target_batch) 171 | 172 | # dB Loss 173 | self.MSE_train_linear_epoch[ti] = MSE_trainbatch_linear_LOSS.item() 174 | self.MSE_train_dB_epoch[ti] = 10 * torch.log10(self.MSE_train_linear_epoch[ti]) 175 | 176 | ################## 177 | ### Optimizing ### 178 | ################## 179 | 180 | # Before the backward pass, use the optimizer object to zero all of the 181 | # gradients for the variables it will update (which are the learnable 182 | # weights of the model). This is because by default, gradients are 183 | # accumulated in buffers( i.e, not overwritten) whenever .backward() 184 | # is called. Checkout docs of torch.autograd.backward for more details. 185 | 186 | # Backward pass: compute gradient of the loss with respect to model 187 | # parameters 188 | MSE_trainbatch_linear_LOSS.backward(retain_graph=True) 189 | 190 | # Calling the step function on an Optimizer makes an update to its 191 | # parameters 192 | self.optimizer.step() 193 | # self.scheduler.step(self.MSE_cv_dB_epoch[ti]) 194 | 195 | ################################# 196 | ### Validation Sequence Batch ### 197 | ################################# 198 | 199 | # Cross Validation Mode 200 | self.model.eval() 201 | self.model.batch_size = self.N_CV 202 | # Init Hidden State 203 | self.model.init_hidden_KNet() 204 | with torch.no_grad(): 205 | 206 | SysModel.T_test = cv_input.size()[-1] # T_test is the maximum length of the CV sequences 207 | 208 | x_out_cv_batch = torch.empty([self.N_CV, SysModel.m, SysModel.T_test]).to(self.device) 209 | 210 | # Init Sequence 211 | if(randomInit): 212 | if(cv_init==None): 213 | self.model.InitSequence(\ 214 | SysModel.m1x_0.reshape(1,SysModel.m,1).repeat(self.N_CV,1,1), SysModel.T_test) 215 | else: 216 | self.model.InitSequence(cv_init, SysModel.T_test) 217 | else: 218 | self.model.InitSequence(\ 219 | SysModel.m1x_0.reshape(1,SysModel.m,1).repeat(self.N_CV,1,1), SysModel.T_test) 220 | 221 | for t in range(0, SysModel.T_test): 222 | x_out_cv_batch[:, :, t] = torch.squeeze(self.model(torch.unsqueeze(cv_input[:, :, t],2))) 223 | 224 | # Compute CV Loss 225 | MSE_cvbatch_linear_LOSS = 0 226 | if(MaskOnState): 227 | if self.args.randomLength: 228 | for index in range(self.N_CV): 229 | MSE_cv_linear_LOSS[index] = self.loss_fn(x_out_cv_batch[index,mask,cv_lengthMask[index]], cv_target[index,mask,cv_lengthMask[index]]) 230 | MSE_cvbatch_linear_LOSS = torch.mean(MSE_cv_linear_LOSS) 231 | else: 232 | MSE_cvbatch_linear_LOSS = self.loss_fn(x_out_cv_batch[:,mask,:], cv_target[:,mask,:]) 233 | else: 234 | if self.args.randomLength: 235 | for index in range(self.N_CV): 236 | MSE_cv_linear_LOSS[index] = self.loss_fn(x_out_cv_batch[index,:,cv_lengthMask[index]], cv_target[index,:,cv_lengthMask[index]]) 237 | MSE_cvbatch_linear_LOSS = torch.mean(MSE_cv_linear_LOSS) 238 | else: 239 | MSE_cvbatch_linear_LOSS = self.loss_fn(x_out_cv_batch, cv_target) 240 | 241 | # dB Loss 242 | self.MSE_cv_linear_epoch[ti] = MSE_cvbatch_linear_LOSS.item() 243 | self.MSE_cv_dB_epoch[ti] = 10 * torch.log10(self.MSE_cv_linear_epoch[ti]) 244 | 245 | if (self.MSE_cv_dB_epoch[ti] < self.MSE_cv_dB_opt): 246 | self.MSE_cv_dB_opt = self.MSE_cv_dB_epoch[ti] 247 | self.MSE_cv_idx_opt = ti 248 | 249 | torch.save(self.model, path_results + 'best-model.pt') 250 | 251 | ######################## 252 | ### Training Summary ### 253 | ######################## 254 | print(ti, "MSE Training :", self.MSE_train_dB_epoch[ti], "[dB]", 
"MSE Validation :", self.MSE_cv_dB_epoch[ti], 255 | "[dB]") 256 | 257 | if (ti > 1): 258 | d_train = self.MSE_train_dB_epoch[ti] - self.MSE_train_dB_epoch[ti - 1] 259 | d_cv = self.MSE_cv_dB_epoch[ti] - self.MSE_cv_dB_epoch[ti - 1] 260 | print("diff MSE Training :", d_train, "[dB]", "diff MSE Validation :", d_cv, "[dB]") 261 | 262 | print("Optimal idx:", self.MSE_cv_idx_opt, "Optimal :", self.MSE_cv_dB_opt, "[dB]") 263 | 264 | return [self.MSE_cv_linear_epoch, self.MSE_cv_dB_epoch, self.MSE_train_linear_epoch, self.MSE_train_dB_epoch] 265 | 266 | def NNTest(self, SysModel, test_input, test_target, path_results, MaskOnState=False,\ 267 | randomInit=False,test_init=None,load_model=False,load_model_path=None,\ 268 | test_lengthMask=None): 269 | # Load model 270 | if load_model: 271 | self.model = torch.load(load_model_path, map_location=self.device) 272 | else: 273 | self.model = torch.load(path_results+'best-model.pt', map_location=self.device) 274 | 275 | self.N_T = test_input.shape[0] 276 | SysModel.T_test = test_input.size()[-1] 277 | self.MSE_test_linear_arr = torch.zeros([self.N_T]) 278 | x_out_test = torch.zeros([self.N_T, SysModel.m,SysModel.T_test]).to(self.device) 279 | 280 | if MaskOnState: 281 | mask = torch.tensor([True,False,False]) 282 | if SysModel.m == 2: 283 | mask = torch.tensor([True,False]) 284 | 285 | # MSE LOSS Function 286 | loss_fn = nn.MSELoss(reduction='mean') 287 | 288 | # Test mode 289 | self.model.eval() 290 | self.model.batch_size = self.N_T 291 | # Init Hidden State 292 | self.model.init_hidden_KNet() 293 | torch.no_grad() 294 | 295 | start = time.time() 296 | 297 | if (randomInit): 298 | self.model.InitSequence(test_init, SysModel.T_test) 299 | else: 300 | self.model.InitSequence(SysModel.m1x_0.reshape(1,SysModel.m,1).repeat(self.N_T,1,1), SysModel.T_test) 301 | 302 | for t in range(0, SysModel.T_test): 303 | x_out_test[:,:, t] = torch.squeeze(self.model(torch.unsqueeze(test_input[:,:, t],2))) 304 | 305 | end = time.time() 306 | t = end - start 307 | 308 | # MSE loss 309 | for j in range(self.N_T):# cannot use batch due to different length and std computation 310 | if(MaskOnState): 311 | if self.args.randomLength: 312 | self.MSE_test_linear_arr[j] = loss_fn(x_out_test[j,mask,test_lengthMask[j]], test_target[j,mask,test_lengthMask[j]]).item() 313 | else: 314 | self.MSE_test_linear_arr[j] = loss_fn(x_out_test[j,mask,:], test_target[j,mask,:]).item() 315 | else: 316 | if self.args.randomLength: 317 | self.MSE_test_linear_arr[j] = loss_fn(x_out_test[j,:,test_lengthMask[j]], test_target[j,:,test_lengthMask[j]]).item() 318 | else: 319 | self.MSE_test_linear_arr[j] = loss_fn(x_out_test[j,:,:], test_target[j,:,:]).item() 320 | 321 | # Average 322 | self.MSE_test_linear_avg = torch.mean(self.MSE_test_linear_arr) 323 | self.MSE_test_dB_avg = 10 * torch.log10(self.MSE_test_linear_avg) 324 | 325 | # Standard deviation 326 | self.MSE_test_linear_std = torch.std(self.MSE_test_linear_arr, unbiased=True) 327 | 328 | # Confidence interval 329 | self.test_std_dB = 10 * torch.log10(self.MSE_test_linear_std + self.MSE_test_linear_avg) - self.MSE_test_dB_avg 330 | 331 | # Print MSE and std 332 | str = self.modelName + "-" + "MSE Test:" 333 | print(str, self.MSE_test_dB_avg, "[dB]") 334 | str = self.modelName + "-" + "STD Test:" 335 | print(str, self.test_std_dB, "[dB]") 336 | # Print Run Time 337 | print("Inference Time:", t) 338 | 339 | return [self.MSE_test_linear_arr, self.MSE_test_linear_avg, self.MSE_test_dB_avg, x_out_test, t] 340 | 341 | def PlotTrain_KF(self, 
MSE_KF_linear_arr, MSE_KF_dB_avg): 342 | 343 | self.Plot = Plot_extended(self.folderName, self.modelName) 344 | 345 | self.Plot.NNPlot_epochs(self.N_steps, MSE_KF_dB_avg, 346 | self.MSE_test_dB_avg, self.MSE_cv_dB_epoch, self.MSE_train_dB_epoch) 347 | 348 | self.Plot.NNPlot_Hist(MSE_KF_linear_arr, self.MSE_test_linear_arr) -------------------------------------------------------------------------------- /Plot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib as mpl 3 | mpl.rcParams['agg.path.chunksize'] = 1E4 4 | import matplotlib.pyplot as plt 5 | import matplotlib.gridspec as gridspec 6 | from mpl_toolkits.mplot3d import Axes3D 7 | import seaborn as sns 8 | import numpy as np 9 | from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset 10 | from scipy.signal import find_peaks 11 | 12 | # Legend 13 | Klegend = ["Unsupervised KalmanNet - Train", "Unsupervised KalmanNet - Validation", "Unsupervised KalmanNet - Test", "Kalman Filter"] 14 | RTSlegend = ["RTSNet - Train", "RTSNet - Validation", "RTSNet - Test", "RTS Smoother","Kalman Filter"] 15 | ERTSlegend = ["RTSNet - Train","RTSNet - Validation", "RTSNet - Test", "RTS","EKF"] 16 | error_evol = ["KNet Empirical Error","KNet Covariance Trace","KF Empirical Error","KF Covariance Trace","KNet Error Deviation","EKF Error Deviation"] 17 | # Color 18 | KColor = ['-ro','darkorange','k-', 'b-','g-'] 19 | RTSColor = ['red','darkorange','g-', 'b-'] 20 | 21 | class Plot_KF: 22 | 23 | def __init__(self, folderName, modelName): 24 | self.folderName = folderName 25 | self.modelName = modelName 26 | 27 | def NNPlot_epochs(self, N_Epochs_plt, MSE_KF_dB_avg, 28 | MSE_test_dB_avg, MSE_cv_dB_epoch, MSE_train_dB_epoch): 29 | 30 | # File Name 31 | fileName = self.folderName + 'plt_epochs_dB' 32 | 33 | fontSize = 32 34 | 35 | # Figure 36 | plt.figure(figsize = (25, 10)) 37 | 38 | # x_axis 39 | x_plt = range(0, N_Epochs_plt) 40 | 41 | # Train 42 | y_plt1 = MSE_train_dB_epoch[range(0, N_Epochs_plt)] 43 | plt.plot(x_plt, y_plt1, KColor[0], label=Klegend[0]) 44 | 45 | # CV 46 | y_plt2 = MSE_cv_dB_epoch[range(0, N_Epochs_plt)] 47 | plt.plot(x_plt, y_plt2, KColor[1], label=Klegend[1]) 48 | 49 | # Test 50 | y_plt3 = MSE_test_dB_avg * torch.ones(N_Epochs_plt) 51 | plt.plot(x_plt, y_plt3, KColor[2], label=Klegend[2]) 52 | 53 | # KF 54 | y_plt4 = MSE_KF_dB_avg * torch.ones(N_Epochs_plt) 55 | plt.plot(x_plt, y_plt4, KColor[3], label=Klegend[3]) 56 | 57 | plt.xticks(fontsize= fontSize) 58 | plt.yticks(fontsize= fontSize) 59 | plt.legend(fontsize=fontSize) 60 | plt.xlabel('Number of Training Iterations', fontsize=fontSize) 61 | plt.ylabel('MSE Loss Value [dB]', fontsize=fontSize) 62 | plt.grid(True) 63 | # plt.title(self.modelName + ":" + "MSE Loss [dB] - per Epoch", fontsize=fontSize) 64 | plt.savefig(fileName) 65 | 66 | 67 | def KFPlot(res_grid): 68 | 69 | plt.figure(figsize = (50, 20)) 70 | x_plt = [-6, 0, 6] 71 | 72 | plt.plot(x_plt, res_grid[0][:], 'xg', label='minus') 73 | plt.plot(x_plt, res_grid[1][:], 'ob', label='base') 74 | plt.plot(x_plt, res_grid[2][:], '+r', label='plus') 75 | plt.plot(x_plt, res_grid[3][:], 'oy', label='base NN') 76 | 77 | plt.legend() 78 | plt.xlabel('Noise', fontsize=16) 79 | plt.ylabel('MSE Loss Value [dB]', fontsize=16) 80 | plt.title('Change', fontsize=16) 81 | plt.savefig('plt_grid_dB') 82 | 83 | print("\ndistribution 1") 84 | print("Kalman Filter") 85 | print(res_grid[0][0], "[dB]", res_grid[1][0], "[dB]", res_grid[2][0], "[dB]") 86 | 
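# res_grid rows are indexed 0: 'minus', 1: 'base', 2: 'plus', 3: 'base NN'
# (matching the legend above) and columns by the three noise settings plotted
# at x = -6, 0, 6; the prints below report each distribution's dB losses and
# KalmanNet's gap relative to the 'base' Kalman filter.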
print(res_grid[1][0] - res_grid[0][0], "[dB]", res_grid[2][0] - res_grid[1][0], "[dB]") 87 | print("KalmanNet", res_grid[3][0], "[dB]", "KalmanNet Diff", res_grid[3][0] - res_grid[1][0], "[dB]") 88 | 89 | print("\ndistribution 2") 90 | print("Kalman Filter") 91 | print(res_grid[0][1], "[dB]", res_grid[1][1], "[dB]", res_grid[2][1], "[dB]") 92 | print(res_grid[1][1] - res_grid[0][1], "[dB]", res_grid[2][1] - res_grid[1][1], "[dB]") 93 | print("KalmanNet", res_grid[3][1], "[dB]", "KalmanNet Diff", res_grid[3][1] - res_grid[1][1], "[dB]") 94 | 95 | print("\ndistribution 3") 96 | print("Kalman Filter") 97 | print(res_grid[0][2], "[dB]", res_grid[1][2], "[dB]", res_grid[2][2], "[dB]") 98 | print(res_grid[1][2] - res_grid[0][2], "[dB]", res_grid[2][2] - res_grid[1][2], "[dB]") 99 | print("KalmanNet", res_grid[3][2], "[dB]", "KalmanNet Diff", res_grid[3][2] - res_grid[1][2], "[dB]") 100 | 101 | def NNPlot_test(MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, 102 | MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg): 103 | 104 | 105 | N_Epochs_plt = 100 106 | 107 | ############################### 108 | ### Plot per epoch [linear] ### 109 | ############################### 110 | plt.figure(figsize = (50, 20)) 111 | 112 | x_plt = range(0, N_Epochs_plt) 113 | 114 | # KNet - Test 115 | y_plt3 = MSE_test_linear_avg * torch.ones(N_Epochs_plt) 116 | plt.plot(x_plt, y_plt3, KColor[2], label=Klegend[2]) 117 | 118 | # KF 119 | y_plt4 = MSE_KF_linear_avg * torch.ones(N_Epochs_plt) 120 | plt.plot(x_plt, y_plt4, KColor[3], label=Klegend[3]) 121 | 122 | plt.legend() 123 | plt.xlabel('Number of Training Epochs', fontsize=16) 124 | plt.ylabel('MSE Loss Value [linear]', fontsize=16) 125 | plt.title('MSE Loss [linear] - per Epoch', fontsize=16) 126 | plt.savefig('plt_model_test_linear') 127 | 128 | ########################### 129 | ### Plot per epoch [dB] ### 130 | ########################### 131 | plt.figure(figsize = (50, 20)) 132 | 133 | x_plt = range(0, N_Epochs_plt) 134 | 135 | # KNet - Test 136 | y_plt3 = MSE_test_dB_avg * torch.ones(N_Epochs_plt) 137 | plt.plot(x_plt, y_plt3, KColor[2], label=Klegend[2]) 138 | 139 | # KF 140 | y_plt4 = MSE_KF_dB_avg * torch.ones(N_Epochs_plt) 141 | plt.plot(x_plt, y_plt4, KColor[3], label=Klegend[3]) 142 | 143 | plt.legend() 144 | plt.xlabel('Number of Training Epochs', fontsize=16) 145 | plt.ylabel('MSE Loss Value [dB]', fontsize=16) 146 | plt.title('MSE Loss [dB] - per Epoch', fontsize=16) 147 | plt.savefig('plt_model_test_dB') 148 | 149 | ######################## 150 | ### Linear Histogram ### 151 | ######################## 152 | plt.figure(figsize=(50, 20)) 153 | sns.distplot(MSE_test_linear_arr, hist=False, kde=True, kde_kws={'linewidth': 3}, color='g', label = 'KalmanNet') 154 | sns.distplot(MSE_KF_linear_arr, hist=False, kde=True, kde_kws={'linewidth': 3}, color= 'b', label = 'Kalman Filter') 155 | plt.title("Histogram [Linear]") 156 | plt.savefig('plt_hist_linear') 157 | 158 | fig, axes = plt.subplots(2, 1, figsize=(50, 20), sharey=True, dpi=100) 159 | sns.distplot(MSE_test_linear_arr, hist=False, kde=True, kde_kws={'linewidth': 3}, color='g', label='KalmanNet', ax=axes[0]) 160 | sns.distplot(MSE_KF_linear_arr, hist=False, kde=True, kde_kws={'linewidth': 3}, color='b', label='Kalman Filter', ax=axes[1]) 161 | plt.title("Histogram [Linear]") 162 | plt.savefig('plt_hist_linear_1') 163 | 164 | #################### 165 | ### dB Histogram ### 166 | #################### 167 | 168 | plt.figure(figsize=(50, 20)) 169 | sns.distplot(10 * 
torch.log10(MSE_test_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color='g', label = 'KalmanNet') 170 | sns.distplot(10 * torch.log10(MSE_KF_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color= 'b', label = 'Kalman Filter') 171 | plt.title("Histogram [dB]") 172 | plt.savefig('plt_hist_dB') 173 | 174 | 175 | fig, axes = plt.subplots(2, 1, figsize=(50, 20), sharey=True, dpi=100) 176 | sns.distplot(10 * torch.log10(MSE_test_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color='g', label = 'KalmanNet', ax=axes[0]) 177 | sns.distplot(10 * torch.log10(MSE_KF_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color= 'b', label = 'Kalman Filter', ax=axes[1]) 178 | plt.title("Histogram [dB]") 179 | plt.savefig('plt_hist_dB_1') 180 | 181 | print('End') 182 | 183 | class Plot_RTS(Plot_KF): 184 | 185 | def __init__(self, folderName, modelName): 186 | self.folderName = folderName 187 | self.modelName = modelName 188 | 189 | def NNPlot_epochs(self, N_MiniBatchTrain_plt, BatchSize, MSE_KF_dB_avg, MSE_RTS_dB_avg, 190 | MSE_test_dB_avg, MSE_cv_dB_epoch, MSE_train_dB_epoch): 191 | N_Epochs_plt = np.floor(N_MiniBatchTrain_plt/BatchSize).astype(int) # number of epochs 192 | 193 | # File Name 194 | fileName = self.folderName + 'plt_epochs_dB' 195 | 196 | fontSize = 32 197 | 198 | # Figure 199 | plt.figure(figsize = (25, 10)) 200 | 201 | # x_axis 202 | x_plt = range(0, N_Epochs_plt) 203 | 204 | # Train 205 | y_plt1 = MSE_train_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 206 | plt.plot(x_plt, y_plt1, KColor[0], label=RTSlegend[0]) 207 | 208 | # CV 209 | y_plt2 = MSE_cv_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 210 | plt.plot(x_plt, y_plt2, KColor[1], label=RTSlegend[1]) 211 | 212 | # Test 213 | y_plt3 = MSE_test_dB_avg * torch.ones(N_Epochs_plt) 214 | plt.plot(x_plt, y_plt3, KColor[2], label=RTSlegend[2]) 215 | 216 | # RTS 217 | y_plt4 = MSE_RTS_dB_avg * torch.ones(N_Epochs_plt) 218 | plt.plot(x_plt, y_plt4, "g", label=RTSlegend[3]) 219 | 220 | # KF 221 | y_plt5 = MSE_KF_dB_avg * torch.ones(N_Epochs_plt) 222 | plt.plot(x_plt, y_plt5, "orange", label=RTSlegend[4]) 223 | 224 | plt.legend(fontsize=fontSize) 225 | plt.xlabel('Number of Training Epochs', fontsize=fontSize) 226 | plt.ylabel('MSE Loss Value [dB]', fontsize=fontSize) 227 | plt.title(self.modelName + ":" + "MSE Loss [dB] - per Epoch", fontsize=fontSize) 228 | plt.savefig(fileName) 229 | 230 | 231 | def NNPlot_Hist(self, MSE_KF_linear_arr, MSE_RTS_data_linear_arr, MSE_RTSNet_linear_arr): 232 | 233 | fileName = self.folderName + 'plt_hist_dB' 234 | fontSize = 32 235 | #################### 236 | ### dB Histogram ### 237 | #################### 238 | plt.figure(figsize=(10, 25)) 239 | ax = sns.displot( 240 | {self.modelName: 10 * torch.log10(MSE_RTSNet_linear_arr), 241 | 'Kalman Filter': 10 * torch.log10(MSE_KF_linear_arr), 242 | 'RTS Smoother': 10 * torch.log10(MSE_RTS_data_linear_arr)}, # Use a dict to assign labels to each curve 243 | kind="kde", 244 | common_norm=False, # Normalize each distribution independently: the area under each curve equals 1. 
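# (kind="kde": each curve is a kernel density estimate over the
# per-sequence dB losses passed in the dict above)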
245 | palette=["blue", "orange", "g"], # Use palette for multiple colors 246 | linewidth= 1, 247 | ) 248 | plt.title(self.modelName + ":" +"Histogram [dB]") 249 | plt.xlabel('MSE Loss Value [dB]') 250 | plt.ylabel('Percentage') 251 | sns.move_legend(ax, "upper right") 252 | plt.grid(True) 253 | plt.tight_layout() 254 | plt.savefig(fileName) 255 | 256 | def KF_RTS_Plot_Linear(self, r, MSE_KF_RTS_dB,PlotResultName): 257 | fileName = self.folderName + PlotResultName 258 | plt.figure(figsize = (25, 10)) 259 | x_plt = 10 * torch.log10(1/r**2) 260 | 261 | plt.plot(x_plt, MSE_KF_RTS_dB[0,:], '-^',color='orange',linewidth=1, markersize=12, label=r'2x2, KF') 262 | plt.plot(x_plt, MSE_KF_RTS_dB[1,:], '--go',markerfacecolor='none',linewidth=3, markersize=12, label=r'2x2, RTS') 263 | plt.plot(x_plt, MSE_KF_RTS_dB[2,:], '-bo',linewidth=1, markersize=12, label=r'2x2, RTSNet') 264 | 265 | plt.legend(fontsize=32) 266 | plt.xlabel(r'Noise $\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=32) 267 | plt.ylabel('MSE [dB]', fontsize=32) 268 | # plt.title('Comparing Kalman Filter and RTS Smoother', fontsize=32) 269 | plt.xticks(fontsize=20) 270 | plt.yticks(fontsize=20) 271 | plt.grid(True) 272 | plt.savefig(fileName) 273 | 274 | def rotate_RTS_Plot_F(self, r, MSE_RTS_dB,rotateName): 275 | fileName = self.folderName + rotateName 276 | plt.figure(figsize = (25, 10)) 277 | x_plt = 10 * torch.log10(1/r**2) 278 | 279 | plt.plot(x_plt, MSE_RTS_dB[0,:], '-r^', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{F}_{\alpha=0^\circ}$)') 280 | plt.plot(x_plt, MSE_RTS_dB[1,:], '-gx', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{F}_{\alpha=10^\circ}$)') 281 | plt.plot(x_plt, MSE_RTS_dB[2,:], '-bo', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTSNet ($\mathbf{F}_{\alpha=10^\circ}$)') 282 | 283 | plt.legend(fontsize=16) 284 | plt.xlabel(r'Noise $\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=32) 285 | plt.ylabel('MSE [dB]', fontsize=32) 286 | # plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 287 | plt.xticks(fontsize=20) 288 | plt.yticks(fontsize=20) 289 | plt.grid(True) 290 | plt.savefig(fileName) 291 | 292 | def rotate_RTS_Plot_H(self, r, MSE_RTS_dB,rotateName): 293 | fileName = self.folderName + rotateName 294 | magnifying_glass, main_H = plt.subplots(figsize = [25, 10]) 295 | # main_H = plt.figure(figsize = [25, 10]) 296 | x_plt = 10 * torch.log10(1/r**2) 297 | NoiseFloor = -x_plt 298 | main_H.plot(x_plt, NoiseFloor, '--r', linewidth=2, markersize=12, label=r'Noise Floor') 299 | main_H.plot(x_plt, MSE_RTS_dB[0,:], '-g^', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB] , 2x2, RTS Smoother ($\mathbf{H}_{\alpha=0^\circ}$)') 300 | main_H.plot(x_plt, MSE_RTS_dB[1,:], '-yx', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{H}_{\alpha=10^\circ}$)') 301 | main_H.plot(x_plt, MSE_RTS_dB[2,:], '-bo', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTSNet ($\mathbf{H}_{\alpha=10^\circ}$)') 302 | 303 | main_H.set(xlim=(x_plt[0], x_plt[len(x_plt)-1]), ylim=(-20, 15)) 304 | main_H.legend(fontsize=20) 305 | plt.xlabel(r'$\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=20) 306 | plt.ylabel('MSE [dB]', fontsize=20) 307 | plt.xticks(fontsize=20) 308 | plt.yticks(fontsize=20) 309 | # plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 310 | plt.grid(True) 311 | 312 | ax2 = plt.axes([.15, .15, .27, .27]) 313 | x1, x2, y1, y2 = -0.2, 0.2, -5, 8 314 | 
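# ax2 is a manually placed inset (axes rectangle in figure-fraction
# coordinates [left, bottom, width, height]) that replots the same curves
# over the zoom window (x1, x2) x (y1, y2) applied just below.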
ax2.set_xlim(x1, x2) 315 | ax2.set_ylim(y1, y2) 316 | ax2.plot(x_plt, NoiseFloor, '--r', linewidth=2, markersize=12, label=r'Noise Floor') 317 | ax2.plot(x_plt, MSE_RTS_dB[0,:], '-g^', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB] , 2x2, RTS Smoother ($\mathbf{H}_{\alpha=0^\circ}$)') 318 | ax2.plot(x_plt, MSE_RTS_dB[1,:], '-yx', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{H}_{\alpha=10^\circ}$)') 319 | ax2.plot(x_plt, MSE_RTS_dB[2,:], '-bo', linewidth=2, markersize=12, label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTSNet ($\mathbf{H}_{\alpha=10^\circ}$)') 320 | ax2.grid(True) 321 | plt.savefig(fileName) 322 | 323 | def rotate_RTS_Plot_FHCompare(self, r, MSE_RTS_dB_F,MSE_RTS_dB_H,rotateName): 324 | fileName = self.folderName + rotateName 325 | plt.figure(figsize = (25, 10)) 326 | x_plt = 10 * torch.log10(1/r) 327 | 328 | plt.plot(x_plt, MSE_RTS_dB_F[0,:], '-r^', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{F}_{\alpha=0^\circ}$)') 329 | plt.plot(x_plt, MSE_RTS_dB_F[1,:], '-gx', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{F}_{\alpha=10^\circ}$)') 330 | plt.plot(x_plt, MSE_RTS_dB_F[2,:], '-bo', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTSNet ($\mathbf{F}_{\alpha=10^\circ}$)') 331 | plt.plot(x_plt, MSE_RTS_dB_H[0,:], '--r^', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{H}_{\alpha=0^\circ}$)') 332 | plt.plot(x_plt, MSE_RTS_dB_H[1,:], '--gx', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTS Smoother ($\mathbf{H}_{\alpha=10^\circ}$)') 333 | plt.plot(x_plt, MSE_RTS_dB_H[2,:], '--bo', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], 2x2, RTSNet ($\mathbf{H}_{\alpha=10^\circ}$)') 334 | 335 | plt.legend(fontsize=16) 336 | plt.xlabel(r'Noise $\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=32) 337 | plt.ylabel('MSE [dB]', fontsize=32) 338 | plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 339 | plt.grid(True) 340 | plt.savefig(fileName) 341 | 342 | def plotTraj_CA(self,test_target, RTS_out, rtsnet_out, dim, file_name): 343 | legend = ["RTSNet", "Ground Truth", "MB RTS"] 344 | font_size = 14 345 | T_test = rtsnet_out[0].size()[1] 346 | x_plt = range(0, T_test) 347 | if dim==0:#position 348 | plt.plot(x_plt, rtsnet_out[0][0,:].detach().numpy(), label=legend[0]) 349 | plt.plot(x_plt, test_target[0][0,:].detach().numpy(), label=legend[1]) 350 | plt.plot(x_plt, RTS_out[0][0,:], label=legend[2]) 351 | plt.legend(fontsize=font_size) 352 | plt.xlabel('t', fontsize=font_size) 353 | plt.ylabel('position', fontsize=font_size) 354 | plt.savefig(file_name) 355 | plt.clf() 356 | elif dim==1:#velocity 357 | plt.plot(x_plt, rtsnet_out[0][1,:].detach().numpy(), label=legend[0]) 358 | plt.plot(x_plt, test_target[0][1,:].detach().numpy(), label=legend[1]) 359 | plt.plot(x_plt, RTS_out[0][1,:], label=legend[2]) 360 | plt.legend(fontsize=font_size) 361 | plt.xlabel('t', fontsize=font_size) 362 | plt.ylabel('velocity', fontsize=font_size) 363 | plt.savefig(file_name) 364 | plt.clf() 365 | elif dim==2:#acceleration 366 | plt.plot(x_plt, rtsnet_out[0][2,:].detach().numpy(), label=legend[0]) 367 | plt.plot(x_plt, test_target[0][2,:].detach().numpy(), label=legend[1]) 368 | plt.plot(x_plt, RTS_out[0][2,:], label=legend[2]) 369 | plt.legend(fontsize=font_size) 370 | plt.xlabel('t', fontsize=font_size) 371 | plt.ylabel('acceleration', fontsize=font_size) 372 | plt.savefig(file_name) 373 | plt.clf() 374 | else: 375 | print("invalid 
dimension") 376 | 377 | class Plot_extended(Plot_RTS): 378 | def EKFPlot_Hist(self, MSE_EKF_linear_arr): 379 | fileName = self.folderName + 'plt_hist_dB' 380 | fontSize = 32 381 | #################### 382 | ### dB Histogram ### 383 | #################### 384 | plt.figure(figsize=(25, 10)) 385 | sns.distplot(10 * np.log10(MSE_EKF_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color= 'b', label = 'Extended Kalman Filter') 386 | plt.title(self.modelName + ":" +"Histogram [dB]",fontsize=fontSize) 387 | plt.legend(fontsize=fontSize) 388 | plt.savefig(fileName) 389 | 390 | def KF_RTS_Plot(self, r, MSE_KF_RTS_dB): 391 | fileName = self.folderName + 'Nonlinear_KF_RTS_Compare_dB' 392 | plt.figure(figsize = (25, 10)) 393 | x_plt = 10 * torch.log10(1/r**2) 394 | 395 | plt.plot(x_plt, MSE_KF_RTS_dB[0,:], '-gx', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], Toy Model, EKF') 396 | plt.plot(x_plt, MSE_KF_RTS_dB[1,:], '--bo', label=r'$\mathrm{\frac{q^2}{r^2}}=0$ [dB], Toy Model, Extended RTS') 397 | 398 | plt.legend(fontsize=32) 399 | plt.xlabel(r'Noise $\mathrm{\frac{1}{q^2}}$ [dB]', fontsize=32) 400 | plt.ylabel('MSE [dB]', fontsize=32) 401 | plt.title('Comparing Extended Kalman Filter and Extended RTS Smoother', fontsize=32) 402 | plt.xticks(fontsize=20) 403 | plt.yticks(fontsize=20) 404 | plt.grid(True) 405 | plt.savefig(fileName) 406 | 407 | def NNPlot_trainsteps(self, N_MiniBatchTrain_plt, MSE_EKF_dB_avg, MSE_ERTS_dB_avg, 408 | MSE_test_dB_avg, MSE_cv_dB_epoch, MSE_train_dB_epoch): 409 | N_Epochs_plt = N_MiniBatchTrain_plt 410 | 411 | # File Name 412 | fileName = self.folderName + 'plt_epochs_dB' 413 | 414 | fontSize = 32 415 | 416 | # Figure 417 | plt.figure(figsize = (25, 10)) 418 | 419 | # x_axis 420 | x_plt = range(0, N_Epochs_plt) 421 | 422 | # Train 423 | y_plt1 = MSE_train_dB_epoch[range(0, N_Epochs_plt)] 424 | plt.plot(x_plt, y_plt1, KColor[0], label=ERTSlegend[0]) 425 | 426 | # CV 427 | y_plt2 = MSE_cv_dB_epoch[range(0, N_Epochs_plt)] 428 | plt.plot(x_plt, y_plt2, KColor[1], label=ERTSlegend[1]) 429 | 430 | # Test 431 | y_plt3 = MSE_test_dB_avg * torch.ones(N_Epochs_plt) 432 | plt.plot(x_plt, y_plt3, KColor[2], label=ERTSlegend[2]) 433 | 434 | # RTS 435 | y_plt4 = MSE_ERTS_dB_avg * torch.ones(N_Epochs_plt) 436 | plt.plot(x_plt, y_plt4, KColor[3], label=ERTSlegend[3]) 437 | 438 | # EKF 439 | y_plt5 = MSE_EKF_dB_avg * torch.ones(N_Epochs_plt) 440 | plt.plot(x_plt, y_plt5, KColor[4], label=ERTSlegend[4]) 441 | 442 | plt.legend(fontsize=fontSize) 443 | plt.xlabel('Number of Training Steps', fontsize=fontSize) 444 | plt.ylabel('MSE Loss Value [dB]', fontsize=fontSize) 445 | plt.grid(True) 446 | plt.title(self.modelName + ":" + "MSE Loss [dB] - per Step", fontsize=fontSize) 447 | plt.savefig(fileName) 448 | 449 | 450 | 451 | def NNPlot_epochs(self, N_E,N_MiniBatchTrain_plt, BatchSize, MSE_EKF_dB_avg, MSE_ERTS_dB_avg, 452 | MSE_test_dB_avg, MSE_cv_dB_epoch, MSE_train_dB_epoch): 453 | N_Epochs_plt = np.floor(N_MiniBatchTrain_plt*BatchSize/N_E).astype(int) # number of epochs 454 | print(N_Epochs_plt) 455 | # File Name 456 | fileName = self.folderName + 'plt_epochs_dB' 457 | 458 | fontSize = 32 459 | 460 | # Figure 461 | plt.figure(figsize = (25, 10)) 462 | 463 | # x_axis 464 | x_plt = range(0, N_Epochs_plt) 465 | 466 | # Train 467 | y_plt1 = MSE_train_dB_epoch[np.linspace(0,N_MiniBatchTrain_plt-1,N_Epochs_plt)] 468 | plt.plot(x_plt, y_plt1, KColor[0], label=ERTSlegend[0]) 469 | 470 | # CV 471 | y_plt2 = MSE_cv_dB_epoch[np.linspace(0,N_MiniBatchTrain_plt-1,N_Epochs_plt)] 472 | 
plt.plot(x_plt, y_plt2, KColor[1], label=ERTSlegend[1]) 473 | 474 | # Test 475 | y_plt3 = MSE_test_dB_avg * torch.ones(N_Epochs_plt) 476 | plt.plot(x_plt, y_plt3, KColor[2], label=ERTSlegend[2]) 477 | 478 | # RTS 479 | y_plt4 = MSE_ERTS_dB_avg * torch.ones(N_Epochs_plt) 480 | plt.plot(x_plt, y_plt4, KColor[3], label=ERTSlegend[3]) 481 | 482 | # EKF 483 | y_plt5 = MSE_EKF_dB_avg * torch.ones(N_Epochs_plt) 484 | plt.plot(x_plt, y_plt5, KColor[4], label=ERTSlegend[4]) 485 | 486 | plt.legend(fontsize=fontSize) 487 | plt.xlabel('Number of Training Epochs', fontsize=fontSize) 488 | plt.ylabel('MSE Loss Value [dB]', fontsize=fontSize) 489 | plt.grid(True) 490 | plt.title(self.modelName + ":" + "MSE Loss [dB] - per Epoch", fontsize=fontSize) 491 | plt.savefig(fileName) 492 | 493 | def NNPlot_Hist(self, MSE_EKF_linear_arr, MSE_ERTS_data_linear_arr, MSE_RTSNet_linear_arr): 494 | 495 | fileName = self.folderName + 'plt_hist_dB' 496 | fontSize = 32 497 | #################### 498 | ### dB Histogram ### 499 | #################### 500 | plt.figure(figsize=(25, 10)) 501 | # sns.distplot(10 * torch.log10(MSE_RTSNet_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 5}, color='b', label = self.modelName) 502 | # sns.distplot(10 * torch.log10(MSE_EKF_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3}, color= 'orange', label = 'EKF') 503 | # sns.distplot(10 * torch.log10(MSE_ERTS_data_linear_arr), hist=False, kde=True, kde_kws={'linewidth': 3.2,"linestyle":'--'},color= 'g', label = 'RTS') 504 | 505 | # plt.title(self.modelName + ":" +"Histogram [dB]",fontsize=fontSize) 506 | # plt.legend(fontsize=fontSize) 507 | # plt.xlabel('MSE Loss Value [dB]', fontsize=fontSize) 508 | # plt.ylabel('Percentage', fontsize=fontSize) 509 | # plt.tick_params(labelsize=fontSize) 510 | # plt.grid(True) 511 | # plt.savefig(fileName) 512 | ax = sns.displot( 513 | {self.modelName: 10 * torch.log10(MSE_RTSNet_linear_arr), 514 | 'Kalman Filter': 10 * torch.log10(MSE_EKF_linear_arr), 515 | 'RTS Smoother': 10 * torch.log10(MSE_ERTS_data_linear_arr)}, # Use a dict to assign labels to each curve 516 | kind="kde", 517 | common_norm=False, # Normalize each distribution independently: the area under each curve equals 1. 
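# sns.displot returns a seaborn FacetGrid rather than a Matplotlib Axes, so sns.move_legend below operates on the grid object.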
518 | palette=["blue", "orange", "g"], # Use palette for multiple colors 519 | linewidth= 1, 520 | ) 521 | plt.title(self.modelName + ":" +"Histogram [dB]") 522 | plt.xlabel('MSE Loss Value [dB]') 523 | plt.ylabel('Percentage') 524 | sns.move_legend(ax, "upper right") 525 | plt.grid(True) 526 | plt.tight_layout() 527 | plt.savefig(fileName) 528 | 529 | def NNPlot_epochs_KF_RTS(self, N_MiniBatchTrain_plt, BatchSize, MSE_EKF_dB_avg, MSE_ERTS_dB_avg, 530 | MSE_KNet_test_dB_avg, MSE_KNet_cv_dB_epoch, MSE_KNet_train_dB_epoch, 531 | MSE_RTSNet_test_dB_avg, MSE_RTSNet_cv_dB_epoch, MSE_RTSNet_train_dB_epoch): 532 | N_Epochs_plt = np.floor(N_MiniBatchTrain_plt/BatchSize).astype(int) # number of epochs 533 | 534 | # File Name 535 | fileName = self.folderName + 'plt_epochs_dB' 536 | 537 | fontSize = 32 538 | 539 | # Figure 540 | plt.figure(figsize = (25, 10)) 541 | 542 | # x_axis 543 | x_plt = range(0, N_Epochs_plt) 544 | 545 | # Train KNet and RTSNet 546 | # y_plt1 = MSE_KNet_train_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 547 | # plt.plot(x_plt, y_plt1, KColor[0], label=Klegend[0]) 548 | # y_plt2 = MSE_RTSNet_train_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 549 | # plt.plot(x_plt, y_plt2, color=RTSColor[0],linestyle='-', marker='o', label=ERTSlegend[0]) 550 | 551 | # CV KNet and RTSNet 552 | y_plt3 = MSE_KNet_cv_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 553 | plt.plot(x_plt, y_plt3, color=RTSColor[0],linestyle='-', marker='o', label=Klegend[1]) 554 | y_plt4 = MSE_RTSNet_cv_dB_epoch[np.linspace(0,BatchSize*(N_Epochs_plt-1) ,N_Epochs_plt)] 555 | plt.plot(x_plt, y_plt4, color=RTSColor[1],linestyle='-', marker='o', label=ERTSlegend[1]) 556 | 557 | # Test KNet and RTSNet 558 | y_plt5 = MSE_KNet_test_dB_avg * torch.ones(N_Epochs_plt) 559 | plt.plot(x_plt, y_plt5, color=RTSColor[0],linestyle='--', label=Klegend[2]) 560 | y_plt6 = MSE_RTSNet_test_dB_avg * torch.ones(N_Epochs_plt) 561 | plt.plot(x_plt, y_plt6,color=RTSColor[1],linestyle='--', label=ERTSlegend[2]) 562 | 563 | # RTS 564 | y_plt7 = MSE_ERTS_dB_avg * torch.ones(N_Epochs_plt) 565 | plt.plot(x_plt, y_plt7, RTSColor[2], label=ERTSlegend[3]) 566 | 567 | # EKF 568 | y_plt8 = MSE_EKF_dB_avg * torch.ones(N_Epochs_plt) 569 | plt.plot(x_plt, y_plt8, RTSColor[3], label=ERTSlegend[4]) 570 | 571 | plt.legend(fontsize=fontSize) 572 | plt.xlabel('Number of Training Epochs', fontsize=fontSize) 573 | plt.ylabel('MSE Loss Value [dB]', fontsize=fontSize) 574 | plt.title(self.modelName + ":" + "MSE Loss [dB] - per Epoch", fontsize=fontSize) 575 | plt.grid(True) 576 | plt.savefig(fileName) 577 | 578 | def plotTrajectories(self,inputs, dim, titles, file_name): 579 | 580 | fig = plt.figure(figsize=(15, 10)) 581 | plt.Axes (fig, [0,0,1,1]) 582 | # plt.subplots_adjust(wspace=-0.2, hspace=-0.2) 583 | matrix_size = int(np.ceil(np.sqrt(len(inputs)))) 584 | #gs1 = gridspec.GridSpec(matrix_size,matrix_size) 585 | gs1 = gridspec.GridSpec(3,2) 586 | gs1.update(wspace=0, hspace=0) 587 | gs2 = gridspec.GridSpec(5,1) 588 | gs2.update(wspace=0, hspace=1) 589 | plt.rcParams["figure.frameon"] = False 590 | plt.rcParams["figure.constrained_layout.use"]= True 591 | i=0 592 | for title in titles: 593 | inputs_numpy = inputs[i][0].detach().numpy() 594 | # gs1.update(wspace=-0.3,hspace=-0.3) 595 | if(dim==3): 596 | plt.rcParams["figure.frameon"] = False 597 | ax = fig.add_subplot(gs1[i],projection='3d') 598 | # if(i<3): 599 | # ax = fig.add_subplot(gs1[i],projection='3d') 600 | # else: 601 | # ax = 
fig.add_subplot(gs1[i:i+2],projection='3d') 602 | 603 | y_al = 0.73 604 | if(title == "True Trajectory"): 605 | c = 'k' 606 | elif(title == "Observation"): 607 | c = 'r' 608 | elif(title == "Extended RTS"): 609 | c = 'b' 610 | y_al = 0.68 611 | elif(title == "RTSNet"): 612 | c = 'g' 613 | elif(title == "Particle Smoother"): 614 | c = 'c' 615 | elif(title == "Vanilla RNN"): 616 | c = 'm' 617 | elif(title == "KNet"): 618 | c = 'y' 619 | else: 620 | c = 'purple' 621 | y_al = 0.68 622 | 623 | ax.set_axis_off() 624 | ax.set_title(title, y=y_al, fontdict={'fontsize': 15,'fontweight' : 20,'verticalalignment': 'baseline'}) 625 | ax.plot(inputs_numpy[0,:], inputs_numpy[1,:], inputs_numpy[2,:], c, linewidth=0.5) 626 | 627 | ## Plot display 628 | #ax.set_yticklabels([]) 629 | #ax.set_xticklabels([]) 630 | #ax.set_zticklabels([]) 631 | #ax.set_xlabel('x') 632 | #ax.set_ylabel('y') 633 | #ax.set_zlabel('z') 634 | 635 | if(dim==2): 636 | ax = fig.add_subplot(matrix_size, matrix_size,i+1) 637 | ax.plot(inputs_numpy[0,:],inputs_numpy[1,:], 'b', linewidth=0.75) 638 | ax.set_xlabel('x1') 639 | ax.set_ylabel('x2') 640 | ax.set_title(title, pad=10, fontdict={'fontsize': 20,'fontweight' : 20,'verticalalignment': 'baseline'}) 641 | 642 | if(dim==4): 643 | if(title == "True Trajectory"): 644 | target_theta_sample = inputs_numpy[0,0,:] 645 | 646 | # ax = fig.add_subplot(matrix_size, matrix_size,i+1) 647 | ax = fig.add_subplot(gs2[i,:]) 648 | # print(inputs_numpy[0,0,:]) 649 | ax.plot(np.arange(np.size(inputs_numpy[0,:],axis=1)), inputs_numpy[0,0,:], 'b', linewidth=0.75) 650 | if(title != "True Trajectory"): 651 | diff = target_theta_sample - inputs_numpy[0,0,:] 652 | peaks, _ = find_peaks(diff, prominence=0.31) 653 | troughs, _ = find_peaks(-diff, prominence=0.31) 654 | for peak, trough in zip(peaks, troughs): 655 | plt.axvspan(peak, trough, color='red', alpha=.2) 656 | # zoomed in 657 | # ax.plot(np.arange(20), inputs_numpy[0,0,0:20], 'b', linewidth=0.75)inputs_numpy[0,0,:] 658 | ax.set_xlabel('time [s]') 659 | ax.set_ylabel('theta [rad]') 660 | ax.set_title(title, pad=10, fontdict={'fontsize': 20,'fontweight' : 20,'verticalalignment': 'baseline'}) 661 | 662 | i +=1 663 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0, dpi=1000) 664 | 665 | def Partial_Plot_Lor(self, r, MSE_Partial_dB): 666 | fileName = self.folderName + 'Nonlinear_Lor_Partial_J=2' 667 | magnifying_glass, main_partial = plt.subplots(figsize = [20, 15]) 668 | x_plt = 10 * torch.log10(1/r**2) 669 | NoiseFloor = -x_plt 670 | main_partial.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12, label=r'Noise Floor') 671 | main_partial.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=3, markersize=12, label=r'EKF: $\rm J_{mdl}=5$') 672 | main_partial.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=3, markersize=12, label=r'EKF: $\rm J_{mdl}=2$') 673 | main_partial.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=3, markersize=12, label=r'RTS: $\rm J_{mdl}=5$') 674 | main_partial.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12, label=r'RTS: $ \rm J_{mdl}=2$') 675 | main_partial.plot(x_plt, MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=12, label=r'RTSNet: $ \rm J_{mdl}=2$') 676 | 677 | main_partial.set(xlim=(x_plt[0], x_plt[len(x_plt)-1]), ylim=(-60, 10)) 678 | main_partial.legend(fontsize=20) 679 | plt.xlabel(r'$\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=20) 680 | plt.ylabel('MSE [dB]', fontsize=20) 681 | plt.xticks(fontsize=20) 682 | plt.yticks(fontsize=20) 683 | # plt.title('MSE vs inverse noise variance with 
inaccurate SS knowledge', fontsize=32) 684 | plt.grid(True) 685 | 686 | ax2 = plt.axes([.15, .15, .25, .25]) 687 | x1, x2, y1, y2 = 19.5, 20.5, -35, -10 688 | ax2.set_xlim(x1, x2) 689 | ax2.set_ylim(y1, y2) 690 | ax2.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12) 691 | ax2.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=3, markersize=12) 692 | ax2.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=3, markersize=12) 693 | ax2.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=3, markersize=12) 694 | ax2.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12) 695 | ax2.plot(x_plt, MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=12) 696 | ax2.grid(True) 697 | plt.savefig(fileName) 698 | 699 | 700 | fileName = self.folderName + 'Nonlinear_Pen_PartialF' 701 | magnifying_glass, main_partial = plt.subplots(figsize = [20, 15]) 702 | x_plt = 10 * torch.log10(1/r**2) 703 | NoiseFloor = -x_plt 704 | main_partial.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12, label=r'Noise Floor') 705 | main_partial.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=4, markersize=12, label=r'EKF: $\rm L=1$') 706 | main_partial.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=4, markersize=12, label=r'EKF: $\rm L=1.1$') 707 | main_partial.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=2, markersize=12, label=r'RTS: $\rm L=1$') 708 | main_partial.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12, label=r'RTS: $ \rm L=1.1$') 709 | main_partial.plot(x_plt, MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=24, label=r'RTSNet: $ \rm L=1.1$') 710 | 711 | main_partial.set(xlim=(x_plt[0], x_plt[len(x_plt)-1]), ylim=(-75, 5)) 712 | main_partial.legend(fontsize=20) 713 | plt.xlabel(r'$\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=20) 714 | plt.ylabel('MSE [dB]', fontsize=20) 715 | plt.xticks(fontsize=20) 716 | plt.yticks(fontsize=20) 717 | # plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 718 | plt.grid(True) 719 | 720 | ax2 = plt.axes([.15, .15, .25, .25]) 721 | x1, x2, y1, y2 = 19.5, 20.5, -55, -15 722 | ax2.set_xlim(x1, x2) 723 | ax2.set_ylim(y1, y2) 724 | ax2.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12) 725 | ax2.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=3, markersize=12) 726 | ax2.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=3, markersize=12) 727 | ax2.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=3, markersize=12) 728 | ax2.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12) 729 | ax2.plot(x_plt, MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=12) 730 | ax2.grid(True) 731 | plt.savefig(fileName) 732 | 733 | def Partial_Plot_H1(self, r, MSE_Partial_dB): 734 | fileName = self.folderName + 'Nonlinear_Lor_Partial_Hrot1' 735 | magnifying_glass, main_partial = plt.subplots(figsize = [20, 15]) 736 | x_plt = 10 * torch.log10(1/r**2) 737 | NoiseFloor = -x_plt 738 | main_partial.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12, label=r'Noise Floor') 739 | main_partial.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=3, markersize=12, label=r'EKF: $\Delta{\theta}=0$') 740 | main_partial.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=3, markersize=12, label=r'EKF: $\Delta{\theta}=1$') 741 | main_partial.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=3, markersize=12, label=r'RTS: $\Delta{\theta}=0$') 742 | main_partial.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12, label=r'RTS: $\Delta{\theta}=1$') 743 | main_partial.plot(x_plt, 
MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=12, label=r'RTSNet: $\Delta{\theta}=1$') 744 | 745 | main_partial.set(xlim=(x_plt[0], x_plt[len(x_plt)-1]), ylim=(-60, 10)) 746 | main_partial.legend(fontsize=20) 747 | plt.xlabel(r'$\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=20) 748 | plt.ylabel('MSE [dB]', fontsize=20) 749 | plt.xticks(fontsize=20) 750 | plt.yticks(fontsize=20) 751 | # plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 752 | plt.grid(True) 753 | 754 | ax2 = plt.axes([.15, .15, .25, .25]) 755 | x1, x2, y1, y2 = 19.5, 20.5, -35, -10 756 | ax2.set_xlim(x1, x2) 757 | ax2.set_ylim(y1, y2) 758 | ax2.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12) 759 | ax2.plot(x_plt, MSE_Partial_dB[0,:], '-yx', linewidth=3, markersize=12) 760 | ax2.plot(x_plt, MSE_Partial_dB[1,:], '--yx', linewidth=3, markersize=12) 761 | ax2.plot(x_plt, MSE_Partial_dB[2,:], '-bo', linewidth=3, markersize=12) 762 | ax2.plot(x_plt, MSE_Partial_dB[3,:], '--bo', linewidth=3, markersize=12) 763 | ax2.plot(x_plt, MSE_Partial_dB[4,:], '--g^', linewidth=3, markersize=12) 764 | ax2.grid(True) 765 | plt.savefig(fileName) 766 | 767 | def Partial_Plot_KNetRTSNet_Compare(self, r, MSE_Partial_dB): 768 | fileName = self.folderName + 'Nonlinear_Lor_Partial_Hrot1_Compare' 769 | magnifying_glass, main_partial = plt.subplots(figsize = [20, 15]) 770 | x_plt = 10 * torch.log10(1/r**2) 771 | NoiseFloor = -x_plt 772 | main_partial.plot(x_plt, NoiseFloor, '--r', linewidth=3, markersize=12, label=r'Noise Floor') 773 | main_partial.plot(x_plt, MSE_Partial_dB[0,:], '--bo', linewidth=3, markersize=12, label=r'KNet: $\Delta{\theta}=1$') 774 | main_partial.plot(x_plt, MSE_Partial_dB[1,:], '--g^', linewidth=3, markersize=12, label=r'RTSNet: $\Delta{\theta}=1$') 775 | 776 | main_partial.set(xlim=(x_plt[0], x_plt[len(x_plt)-1]), ylim=(-60, 10)) 777 | main_partial.legend(fontsize=20) 778 | plt.xlabel(r'$\mathrm{\frac{1}{r^2}}$ [dB]', fontsize=20) 779 | plt.ylabel('MSE [dB]', fontsize=20) 780 | plt.xticks(fontsize=20) 781 | plt.yticks(fontsize=20) 782 | # plt.title('MSE vs inverse noise variance with inaccurate SS knowledge', fontsize=32) 783 | plt.grid(True) 784 | plt.savefig(fileName) 785 | 786 | def error_evolution(self,MSE_Net, trace_Net,MSE_KF, trace_KF): 787 | fileName = self.folderName + 'error_evolution' 788 | fontSize = 32 789 | # Figure 790 | fig, axs = plt.subplots(2, figsize = (25, 10)) 791 | # x_axis 792 | x_plt = range(0, MSE_Net.size()[0]) 793 | ## Figure 1: Error 794 | # Net 795 | y_plt1 = MSE_Net.detach().numpy() 796 | axs[0].plot(x_plt, y_plt1, '-bo', label=error_evol[0]) 797 | y_plt2 = trace_Net.detach().numpy() 798 | axs[0].plot(x_plt, y_plt2, '--yo', label=error_evol[1]) 799 | # EKF 800 | y_plt3 = MSE_KF.detach().numpy() 801 | axs[0].plot(x_plt, y_plt3, '-ro', label=error_evol[2]) 802 | y_plt4 = trace_KF.detach().numpy() 803 | axs[0].plot(x_plt, y_plt4, '--go', label=error_evol[3]) 804 | axs[0].legend(loc="upper right") 805 | 806 | ## Figure 2: Error Deviation 807 | # Net 808 | y_plt5 = MSE_Net.detach().numpy() - trace_Net.detach().numpy() 809 | axs[1].plot(x_plt, y_plt5, '-bo', label=error_evol[4]) 810 | # EKF 811 | y_plt6 = MSE_KF.detach().numpy() - trace_KF.detach().numpy() 812 | axs[1].plot(x_plt, y_plt6, '-ro', label=error_evol[5]) 813 | axs[1].legend(loc="upper right") 814 | 815 | axs[0].set(xlabel='Timestep', ylabel='Error [dB]') 816 | axs[1].set(xlabel='Timestep', ylabel='Error Deviation[dB]') 817 | axs[0].grid(True) 818 | axs[1].grid(True) 819 | fig.savefig(fileName) 
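# ----------------------------------------------------------------------
# Usage sketch for the histogram helper above (illustration only: the
# constructor signature is assumed from the self.folderName / self.modelName
# attributes used throughout this class, and the MSE values are hypothetical).
# EKFPlot_Hist consumes per-sequence *linear* MSE values and converts them
# to dB internally via 10*log10 before plotting the KDE histogram:
#
#   plot = Plot_extended('KNet/', 'KalmanNet')
#   MSE_EKF_linear_arr = torch.rand(200) * 1e-2   # hypothetical per-sequence linear MSE
#   plot.EKFPlot_Hist(MSE_EKF_linear_arr)         # saves 'KNet/plt_hist_dB'
# ----------------------------------------------------------------------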
820 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KalmanNet 2 | 3 | ## Feb.13, 2023 Update "batched" 4 | 5 | Supports processing a batch of sequences simultaneously, which dramatically improves efficiency. 6 | 7 | ## Link to paper 8 | 9 | [KalmanNet: Neural Network Aided Kalman Filtering for Partially Known Dynamics](https://arxiv.org/abs/2107.10043) 10 | 11 | ## Running code 12 | 13 | This branch simulates architecture #2 in our paper. There are separate main files for the linear and non-linear cases. 14 | 15 | * Linear case (canonical model or constant acceleration model) 16 | 17 | ``` 18 | python3 main_linear_canonical.py 19 | python3 main_linear_CA.py 20 | ``` 21 | 22 | * Non-linear Lorenz Attractor case (discrete-time, decimation, or non-linear observation function) 23 | 24 | ``` 25 | python3 main_lor_DT.py 26 | python3 main_lor_decimation.py 27 | python3 main_lor_DT_NLobs.py 28 | ``` 29 | 30 | ## Parameter settings 31 | 32 | * Simulations/model_name/parameters.py 33 | 34 | Contains model settings: m, n, f/F, h/H, Q and R. 35 | 36 | * Simulations/config.py 37 | 38 | Contains dataset size, training parameters and network settings. 39 | 40 | * main files 41 | 42 | Set flags, paths, etc. 43 | 44 | 45 | -------------------------------------------------------------------------------- /Simulations/Extended_sysmdl.py: -------------------------------------------------------------------------------- 1 | """# **Class: System Model for Non-linear Cases** 2 | 3 | 1 Store system model parameters: 4 | state transition function f, 5 | observation function h, 6 | process noise Q, 7 | observation noise R, 8 | train&CV dataset sequence length T, 9 | test dataset sequence length T_test, 10 | state dimension m, 11 | observation dimension n, etc. 
12 | 13 | 2 Generate datasets for non-linear cases 14 | """ 15 | 16 | import torch 17 | from torch.distributions.multivariate_normal import MultivariateNormal 18 | 19 | class SystemModel: 20 | 21 | def __init__(self, f, Q, h, R, T, T_test, m, n, prior_Q=None, prior_Sigma=None, prior_S=None): 22 | 23 | #################### 24 | ### Motion Model ### 25 | #################### 26 | self.f = f 27 | self.m = m 28 | self.Q = Q 29 | ######################### 30 | ### Observation Model ### 31 | ######################### 32 | self.h = h 33 | self.n = n 34 | self.R = R 35 | ################ 36 | ### Sequence ### 37 | ################ 38 | # Assign T 39 | self.T = T 40 | self.T_test = T_test 41 | 42 | ######################### 43 | ### Covariance Priors ### 44 | ######################### 45 | if prior_Q is None: 46 | self.prior_Q = torch.eye(self.m) 47 | else: 48 | self.prior_Q = prior_Q 49 | 50 | if prior_Sigma is None: 51 | self.prior_Sigma = torch.zeros((self.m, self.m)) 52 | else: 53 | self.prior_Sigma = prior_Sigma 54 | 55 | if prior_S is None: 56 | self.prior_S = torch.eye(self.n) 57 | else: 58 | self.prior_S = prior_S 59 | 60 | ##################### 61 | ### Init Sequence ### 62 | ##################### 63 | def InitSequence(self, m1x_0, m2x_0): 64 | 65 | self.m1x_0 = m1x_0 66 | self.m2x_0 = m2x_0 67 | 68 | def Init_batched_sequence(self, m1x_0_batch, m2x_0_batch): 69 | 70 | self.m1x_0_batch = m1x_0_batch 71 | self.x_prev = m1x_0_batch 72 | self.m2x_0_batch = m2x_0_batch 73 | 74 | ######################### 75 | ### Update Covariance ### 76 | ######################### 77 | def UpdateCovariance_Matrix(self, Q, R): 78 | 79 | self.Q = Q 80 | 81 | self.R = R 82 | 83 | ######################### 84 | ### Generate Sequence ### 85 | ######################### 86 | def GenerateSequence(self, Q_gen, R_gen, T): 87 | # Pre allocate an array for current state 88 | self.x = torch.zeros(size=[self.m, T]) 89 | # Pre allocate an array for current observation 90 | self.y = torch.zeros(size=[self.n, T]) 91 | # Set x0 to be x previous 92 | self.x_prev = self.m1x_0 93 | xt = self.x_prev 94 | 95 | # Generate Sequence Iteratively 96 | for t in range(0, T): 97 | 98 | ######################## 99 | #### State Evolution ### 100 | ######################## 101 | if torch.equal(Q_gen,torch.zeros(self.m,self.m)):# No noise 102 | xt = self.f(self.x_prev) 103 | elif self.m == 1: # 1 dim noise 104 | xt = self.f(self.x_prev) 105 | eq = torch.normal(mean=0, std=Q_gen) 106 | # Additive Process Noise 107 | xt = torch.add(xt,eq) 108 | else: 109 | xt = self.f(self.x_prev) 110 | mean = torch.zeros([self.m]) 111 | distrib = MultivariateNormal(loc=mean, covariance_matrix=Q_gen) 112 | eq = distrib.rsample() 113 | eq = torch.reshape(eq[:], xt.size()) 114 | # Additive Process Noise 115 | xt = torch.add(xt,eq) 116 | 117 | ################ 118 | ### Emission ### 119 | ################ 120 | yt = self.h(xt) 121 | # Observation Noise 122 | if self.n == 1: # 1 dim noise 123 | er = torch.normal(mean=0, std=R_gen) 124 | # Additive Observation Noise 125 | yt = torch.add(yt,er) 126 | else: 127 | mean = torch.zeros([self.n]) 128 | distrib = MultivariateNormal(loc=mean, covariance_matrix=R_gen) 129 | er = distrib.rsample() 130 | er = torch.reshape(er[:], yt.size()) 131 | # Additive Observation Noise 132 | yt = torch.add(yt,er) 133 | 134 | ######################## 135 | ### Squeeze to Array ### 136 | ######################## 137 | 138 | # Save Current State to Trajectory Array 139 | self.x[:, t] = torch.squeeze(xt,1) 140 | 141 | # Save Current 
Observation to Trajectory Array 142 | self.y[:, t] = torch.squeeze(yt,1) 143 | 144 | ################################ 145 | ### Save Current to Previous ### 146 | ################################ 147 | self.x_prev = xt 148 | 149 | 150 | ###################### 151 | ### Generate Batch ### 152 | ###################### 153 | def GenerateBatch(self, args, size, T, randomInit=False): 154 | if(randomInit): 155 | # Allocate Empty Array for Random Initial Conditions 156 | self.m1x_0_rand = torch.zeros(size, self.m, 1) 157 | if args.distribution == 'uniform': 158 | ### if Uniform Distribution for random init 159 | for i in range(size): 160 | initConditions = torch.rand_like(self.m1x_0) * args.variance 161 | self.m1x_0_rand[i,:,0:1] = initConditions.view(self.m,1) 162 | 163 | elif args.distribution == 'normal': 164 | ### if Normal Distribution for random init 165 | for i in range(size): 166 | distrib = MultivariateNormal(loc=torch.squeeze(self.m1x_0), covariance_matrix=self.m2x_0) 167 | initConditions = distrib.rsample().view(self.m,1) 168 | self.m1x_0_rand[i,:,0:1] = initConditions 169 | else: 170 | raise ValueError('args.distribution not supported!') 171 | 172 | self.Init_batched_sequence(self.m1x_0_rand, self.m2x_0)### for sequence generation 173 | else: # fixed init 174 | initConditions = self.m1x_0.view(1,self.m,1).expand(size,-1,-1) 175 | self.Init_batched_sequence(initConditions, self.m2x_0)### for sequence generation 176 | 177 | if(args.randomLength): 178 | # Allocate Array for Input and Target (use zero padding) 179 | self.Input = torch.zeros(size, self.n, args.T_max) 180 | self.Target = torch.zeros(size, self.m, args.T_max) 181 | self.lengthMask = torch.zeros((size,args.T_max), dtype=torch.bool)# init with all false 182 | # Init Sequence Lengths 183 | T_tensor = torch.round((args.T_max-args.T_min)*torch.rand(size)).int()+args.T_min # Uniform distribution [100,1000] 184 | for i in range(0, size): 185 | # Generate Sequence 186 | self.GenerateSequence(self.Q, self.R, T_tensor[i].item()) 187 | # Training sequence input 188 | self.Input[i, :, 0:T_tensor[i].item()] = self.y 189 | # Training sequence output 190 | self.Target[i, :, 0:T_tensor[i].item()] = self.x 191 | # Mask for sequence length 192 | self.lengthMask[i, 0:T_tensor[i].item()] = True 193 | 194 | else: 195 | # Allocate Empty Array for Input 196 | self.Input = torch.empty(size, self.n, T) 197 | # Allocate Empty Array for Target 198 | self.Target = torch.empty(size, self.m, T) 199 | 200 | # Set x0 to be x previous 201 | self.x_prev = self.m1x_0_batch 202 | xt = self.x_prev 203 | 204 | # Generate in a batched manner 205 | for t in range(0, T): 206 | ######################## 207 | #### State Evolution ### 208 | ######################## 209 | if torch.equal(self.Q,torch.zeros(self.m,self.m)):# No noise 210 | xt = self.f(self.x_prev) 211 | elif self.m == 1: # 1 dim noise 212 | xt = self.f(self.x_prev) 213 | eq = torch.normal(mean=torch.zeros(size), std=self.Q).view(size,1,1) 214 | # Additive Process Noise 215 | xt = torch.add(xt,eq) 216 | else: 217 | xt = self.f(self.x_prev) 218 | mean = torch.zeros([size, self.m]) 219 | distrib = MultivariateNormal(loc=mean, covariance_matrix=self.Q) 220 | eq = distrib.rsample().view(size,self.m,1) 221 | # Additive Process Noise 222 | xt = torch.add(xt,eq) 223 | 224 | ################ 225 | ### Emission ### 226 | ################ 227 | # Observation Noise 228 | if torch.equal(self.R,torch.zeros(self.n,self.n)):# No noise 229 | yt = self.h(xt) 230 | elif self.n == 1: # 1 dim noise 231 | yt = self.h(xt) 232 
| er = torch.normal(mean=torch.zeros(size), std=self.R).view(size,1,1) 233 | # Additive Observation Noise 234 | yt = torch.add(yt,er) 235 | else: 236 | yt = self.h(xt) 237 | mean = torch.zeros([size,self.n]) 238 | distrib = MultivariateNormal(loc=mean, covariance_matrix=self.R) 239 | er = distrib.rsample().view(size,self.n,1) 240 | # Additive Observation Noise 241 | yt = torch.add(yt,er) 242 | 243 | ######################## 244 | ### Squeeze to Array ### 245 | ######################## 246 | 247 | # Save Current State to Trajectory Array 248 | self.Target[:, :, t] = torch.squeeze(xt,2) 249 | 250 | # Save Current Observation to Trajectory Array 251 | self.Input[:, :, t] = torch.squeeze(yt,2) 252 | 253 | ################################ 254 | ### Save Current to Previous ### 255 | ################################ 256 | self.x_prev = xt 257 | -------------------------------------------------------------------------------- /Simulations/Linear_CA/data/decimated_dt1e-2_T100_r0_randnInit.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KalmanNet/KalmanNet_TSP/828a2cf529bc84f43b37d543d916fe5858054457/Simulations/Linear_CA/data/decimated_dt1e-2_T100_r0_randnInit.pt -------------------------------------------------------------------------------- /Simulations/Linear_CA/parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the parameters for the simulations with linear kinematic model 3 | * Constant Acceleration Model (CA) 4 | # full state P, V, A 5 | # only postion P 6 | * Constant Velocity Model (CV) 7 | """ 8 | 9 | import torch 10 | 11 | m = 3 # dim of state for CA model 12 | m_cv = 2 # dim of state for CV model 13 | 14 | delta_t_gen = 1e-2 15 | 16 | ######################################################### 17 | ### state evolution matrix F and observation matrix H ### 18 | ######################################################### 19 | F_gen = torch.tensor([[1, delta_t_gen,0.5*delta_t_gen**2], 20 | [0, 1, delta_t_gen], 21 | [0, 0, 1]]).float() 22 | 23 | F_CV = torch.tensor([[1, delta_t_gen], 24 | [0, 1]]).float() 25 | 26 | # Full observation 27 | H_identity = torch.eye(3) 28 | # Observe only the postion 29 | H_onlyPos = torch.tensor([[1, 0, 0]]).float() 30 | 31 | ############################################### 32 | ### process noise Q and observation noise R ### 33 | ############################################### 34 | # Noise Parameters 35 | r2 = torch.tensor([1]).float() 36 | q2 = torch.tensor([1]).float() 37 | 38 | Q_gen = q2 * torch.tensor([[1/20*delta_t_gen**5, 1/8*delta_t_gen**4,1/6*delta_t_gen**3], 39 | [ 1/8*delta_t_gen**4, 1/3*delta_t_gen**3,1/2*delta_t_gen**2], 40 | [ 1/6*delta_t_gen**3, 1/2*delta_t_gen**2, delta_t_gen]]).float() 41 | 42 | Q_CV = q2 * torch.tensor([[1/3*delta_t_gen**3, 1/2*delta_t_gen**2], 43 | [1/2*delta_t_gen**2, delta_t_gen]]).float() 44 | 45 | R_3 = r2 * torch.eye(3) 46 | R_2 = r2 * torch.eye(2) 47 | 48 | R_onlyPos = r2 -------------------------------------------------------------------------------- /Simulations/Linear_canonical/data/2x2_rq020_T100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KalmanNet/KalmanNet_TSP/828a2cf529bc84f43b37d543d916fe5858054457/Simulations/Linear_canonical/data/2x2_rq020_T100.pt -------------------------------------------------------------------------------- /Simulations/Linear_canonical/parameters.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This file contains the parameters for the simulations with linear canonical model 3 | * Linear State Space Models with Full Information 4 | # v = 0, -10, -20 dB 5 | # scaling model dim to 5x5, 10x10, 20x20, etc 6 | # scalable trajectory length T 7 | # random initial state 8 | * Linear SS Models with Partial Information 9 | # observation model mismatch 10 | # evolution model mismatch 11 | """ 12 | 13 | import torch 14 | 15 | m = 2 # state dimension = 2, 5, 10, etc. 16 | n = 2 # observation dimension = 2, 5, 10, etc. 17 | 18 | ################################## 19 | ### Initial state and variance ### 20 | ################################## 21 | m1_0 = torch.zeros(m, 1) # initial state mean 22 | 23 | ######################################################### 24 | ### state evolution matrix F and observation matrix H ### 25 | ######################################################### 26 | # F in canonical form 27 | F = torch.eye(m) 28 | F[0] = torch.ones(1,m) 29 | 30 | if m == 2: 31 | # H = I 32 | H = torch.eye(2) 33 | else: 34 | # H in reverse canonical form 35 | H = torch.zeros(n,n) 36 | H[0] = torch.ones(1,n) 37 | for i in range(n): 38 | H[i,n-1-i] = 1 39 | 40 | ####################### 41 | ### Rotated F and H ### 42 | ####################### 43 | F_rotated = torch.zeros_like(F) 44 | H_rotated = torch.zeros_like(H) 45 | if(m==2): 46 | alpha_degree = 10 # rotation angle in degree 47 | rotate_alpha = torch.tensor([alpha_degree/180*torch.pi]) 48 | cos_alpha = torch.cos(rotate_alpha) 49 | sin_alpha = torch.sin(rotate_alpha) 50 | rotate_matrix = torch.tensor([[cos_alpha, -sin_alpha], 51 | [sin_alpha, cos_alpha]]) 52 | 53 | F_rotated = torch.mm(F,rotate_matrix) 54 | H_rotated = torch.mm(H,rotate_matrix) 55 | 56 | ############################################### 57 | ### process noise Q and observation noise R ### 58 | ############################################### 59 | # Noise variance takes the form of a diagonal matrix 60 | Q_structure = torch.eye(m) 61 | R_structure = torch.eye(n) -------------------------------------------------------------------------------- /Simulations/Linear_sysmdl.py: -------------------------------------------------------------------------------- 1 | """# **Class: System Model for Linear Cases** 2 | 3 | 1 Store system model parameters: 4 | state transition matrix F, 5 | observation matrix H, 6 | process noise covariance matrix Q, 7 | observation noise covariance matrix R, 8 | train&CV dataset sequence length T, 9 | test dataset sequence length T_test, etc. 
10 | 11 | 2 Generate dataset for linear cases 12 | """ 13 | 14 | import torch 15 | from torch.distributions.multivariate_normal import MultivariateNormal 16 | 17 | class SystemModel: 18 | 19 | def __init__(self, F, Q, H, R, T, T_test, prior_Q=None, prior_Sigma=None, prior_S=None): 20 | 21 | #################### 22 | ### Motion Model ### 23 | #################### 24 | self.F = F 25 | self.m = self.F.size()[0] 26 | self.Q = Q 27 | 28 | ######################### 29 | ### Observation Model ### 30 | ######################### 31 | self.H = H 32 | self.n = self.H.size()[0] 33 | self.R = R 34 | 35 | ################ 36 | ### Sequence ### 37 | ################ 38 | # Assign T 39 | self.T = T 40 | self.T_test = T_test 41 | 42 | ######################### 43 | ### Covariance Priors ### 44 | ######################### 45 | if prior_Q is None: 46 | self.prior_Q = torch.eye(self.m) 47 | else: 48 | self.prior_Q = prior_Q 49 | 50 | if prior_Sigma is None: 51 | self.prior_Sigma = torch.zeros((self.m, self.m)) 52 | else: 53 | self.prior_Sigma = prior_Sigma 54 | 55 | if prior_S is None: 56 | self.prior_S = torch.eye(self.n) 57 | else: 58 | self.prior_S = prior_S 59 | 60 | 61 | def f(self, x): 62 | batched_F = self.F.to(x.device).view(1,self.F.shape[0],self.F.shape[1]).expand(x.shape[0],-1,-1) 63 | return torch.bmm(batched_F, x) 64 | 65 | def h(self, x): 66 | batched_H = self.H.to(x.device).view(1,self.H.shape[0],self.H.shape[1]).expand(x.shape[0],-1,-1) 67 | return torch.bmm(batched_H, x) 68 | 69 | ##################### 70 | ### Init Sequence ### 71 | ##################### 72 | def InitSequence(self, m1x_0, m2x_0): 73 | 74 | self.m1x_0 = m1x_0 75 | self.x_prev = m1x_0 76 | self.m2x_0 = m2x_0 77 | 78 | def Init_batched_sequence(self, m1x_0_batch, m2x_0_batch): 79 | 80 | self.m1x_0_batch = m1x_0_batch 81 | self.x_prev = m1x_0_batch 82 | self.m2x_0_batch = m2x_0_batch 83 | 84 | ######################### 85 | ### Update Covariance ### 86 | ######################### 87 | def UpdateCovariance_Matrix(self, Q, R): 88 | 89 | self.Q = Q 90 | 91 | self.R = R 92 | 93 | ######################### 94 | ### Generate Sequence ### 95 | ######################### 96 | def GenerateSequence(self, Q_gen, R_gen, T): 97 | # Pre allocate an array for current state 98 | self.x = torch.zeros(size=[self.m, T]) 99 | # Pre allocate an array for current observation 100 | self.y = torch.zeros(size=[self.n, T]) 101 | # Set x0 to be x previous 102 | self.x_prev = self.m1x_0 103 | xt = self.x_prev 104 | 105 | # Generate Sequence Iteratively 106 | for t in range(0, T): 107 | 108 | ######################## 109 | #### State Evolution ### 110 | ######################## 111 | if torch.equal(Q_gen,torch.zeros(self.m,self.m)):# No noise 112 | xt = self.F.matmul(self.x_prev) 113 | elif self.m == 1: # 1 dim noise 114 | xt = self.F.matmul(self.x_prev) 115 | eq = torch.normal(mean=0, std=Q_gen) 116 | # Additive Process Noise 117 | xt = torch.add(xt,eq) 118 | else: 119 | xt = self.F.matmul(self.x_prev) 120 | mean = torch.zeros([self.m]) 121 | distrib = MultivariateNormal(loc=mean, covariance_matrix=Q_gen) 122 | eq = distrib.rsample() 123 | # eq = torch.normal(mean, self.q) 124 | eq = torch.reshape(eq[:], xt.size()) 125 | # Additive Process Noise 126 | xt = torch.add(xt,eq) 127 | 128 | ################ 129 | ### Emission ### 130 | ################ 131 | # Observation Noise 132 | if torch.equal(R_gen,torch.zeros(self.n,self.n)):# No noise 133 | yt = self.H.matmul(xt) 134 | elif self.n == 1: # 1 dim noise 135 | yt = self.H.matmul(xt) 136 | er = 
torch.normal(mean=0, std=R_gen) 137 | # Additive Observation Noise 138 | yt = torch.add(yt,er) 139 | else: 140 | yt = self.H.matmul(xt) 141 | mean = torch.zeros([self.n]) 142 | distrib = MultivariateNormal(loc=mean, covariance_matrix=R_gen) 143 | er = distrib.rsample() 144 | er = torch.reshape(er[:], yt.size()) 145 | # Additive Observation Noise 146 | yt = torch.add(yt,er) 147 | 148 | ######################## 149 | ### Squeeze to Array ### 150 | ######################## 151 | 152 | # Save Current State to Trajectory Array 153 | self.x[:, t] = torch.squeeze(xt,1) 154 | 155 | # Save Current Observation to Trajectory Array 156 | self.y[:, t] = torch.squeeze(yt,1) 157 | 158 | ################################ 159 | ### Save Current to Previous ### 160 | ################################ 161 | self.x_prev = xt 162 | 163 | ###################### 164 | ### Generate Batch ### 165 | ###################### 166 | def GenerateBatch(self, args, size, T, randomInit=False): 167 | if(randomInit): 168 | # Allocate Empty Array for Random Initial Conditions 169 | self.m1x_0_rand = torch.zeros(size, self.m, 1) 170 | if args.distribution == 'uniform': 171 | ### if Uniform Distribution for random init 172 | for i in range(size): 173 | initConditions = torch.rand_like(self.m1x_0) * args.variance 174 | self.m1x_0_rand[i,:,0:1] = initConditions.view(self.m,1) 175 | 176 | elif args.distribution == 'normal': 177 | ### if Normal Distribution for random init 178 | for i in range(size): 179 | distrib = MultivariateNormal(loc=torch.squeeze(self.m1x_0), covariance_matrix=self.m2x_0) 180 | initConditions = distrib.rsample().view(self.m,1) 181 | self.m1x_0_rand[i,:,0:1] = initConditions 182 | else: 183 | raise ValueError('args.distribution not supported!') 184 | 185 | self.Init_batched_sequence(self.m1x_0_rand, self.m2x_0)### for sequence generation 186 | else: # fixed init 187 | initConditions = self.m1x_0.view(1,self.m,1).expand(size,-1,-1) 188 | self.Init_batched_sequence(initConditions, self.m2x_0)### for sequence generation 189 | 190 | if(args.randomLength): 191 | # Allocate Array for Input and Target (use zero padding) 192 | self.Input = torch.zeros(size, self.n, args.T_max) 193 | self.Target = torch.zeros(size, self.m, args.T_max) 194 | self.lengthMask = torch.zeros((size,args.T_max), dtype=torch.bool)# init with all false 195 | # Init Sequence Lengths 196 | T_tensor = torch.round((args.T_max-args.T_min)*torch.rand(size)).int()+args.T_min # Uniform distribution [100,1000] 197 | for i in range(0, size): 198 | # Generate Sequence 199 | self.GenerateSequence(self.Q, self.R, T_tensor[i].item()) 200 | # Training sequence input 201 | self.Input[i, :, 0:T_tensor[i].item()] = self.y 202 | # Training sequence output 203 | self.Target[i, :, 0:T_tensor[i].item()] = self.x 204 | # Mask for sequence length 205 | self.lengthMask[i, 0:T_tensor[i].item()] = True 206 | 207 | else: 208 | # Allocate Empty Array for Input 209 | self.Input = torch.empty(size, self.n, T) 210 | # Allocate Empty Array for Target 211 | self.Target = torch.empty(size, self.m, T) 212 | 213 | # Set x0 to be x previous 214 | self.x_prev = self.m1x_0_batch 215 | xt = self.x_prev 216 | 217 | # Generate in a batched manner 218 | for t in range(0, T): 219 | ######################## 220 | #### State Evolution ### 221 | ######################## 222 | if torch.equal(self.Q,torch.zeros(self.m,self.m)):# No noise 223 | xt = self.f(self.x_prev) 224 | elif self.m == 1: # 1 dim noise 225 | xt = self.f(self.x_prev) 226 | eq = torch.normal(mean=torch.zeros(size), 
std=self.Q).view(size,1,1) 227 | # Additive Process Noise 228 | xt = torch.add(xt,eq) 229 | else: 230 | xt = self.f(self.x_prev) 231 | mean = torch.zeros([size, self.m]) 232 | distrib = MultivariateNormal(loc=mean, covariance_matrix=self.Q) 233 | eq = distrib.rsample().view(size,self.m,1) 234 | # Additive Process Noise 235 | xt = torch.add(xt,eq) 236 | 237 | ################ 238 | ### Emission ### 239 | ################ 240 | # Observation Noise 241 | if torch.equal(self.R,torch.zeros(self.n,self.n)):# No noise 242 | yt = self.h(xt) 243 | elif self.n == 1: # 1 dim noise 244 | yt = self.h(xt) 245 | er = torch.normal(mean=torch.zeros(size), std=self.R).view(size,1,1) 246 | # Additive Observation Noise 247 | yt = torch.add(yt,er) 248 | else: 249 | yt = self.H.matmul(xt) 250 | mean = torch.zeros([size,self.n]) 251 | distrib = MultivariateNormal(loc=mean, covariance_matrix=self.R) 252 | er = distrib.rsample().view(size,self.n,1) 253 | # Additive Observation Noise 254 | yt = torch.add(yt,er) 255 | 256 | ######################## 257 | ### Squeeze to Array ### 258 | ######################## 259 | 260 | # Save Current State to Trajectory Array 261 | self.Target[:, :, t] = torch.squeeze(xt,2) 262 | 263 | # Save Current Observation to Trajectory Array 264 | self.Input[:, :, t] = torch.squeeze(yt,2) 265 | 266 | ################################ 267 | ### Save Current to Previous ### 268 | ################################ 269 | self.x_prev = xt 270 | 271 | 272 | def sampling(self, q, r, gain): 273 | 274 | if (gain != 0): 275 | gain_q = 0.1 276 | #aq = gain * q * np.random.randn(self.m, self.m) 277 | aq = gain_q * q * torch.eye(self.m) 278 | #aq = gain_q * q * torch.tensor([[1.0, 1.0], [1.0, 1.0]]) 279 | else: 280 | aq = 0 281 | 282 | Aq = q * torch.eye(self.m) + aq 283 | Q_gen = torch.transpose(Aq, 0, 1) * Aq 284 | 285 | if (gain != 0): 286 | gain_r = 0.5 287 | #ar = gain * r * np.random.randn(self.n, self.n) 288 | ar = gain_r * r * torch.eye(self.n) 289 | #ar = gain_r * r * torch.tensor([[1.0, 1.0], [1.0, 1.0]]) 290 | 291 | else: 292 | ar = 0 293 | 294 | Ar = r * torch.eye(self.n) + ar 295 | R_gen = torch.transpose(Ar, 0, 1) * Ar 296 | 297 | return [Q_gen, R_gen] 298 | -------------------------------------------------------------------------------- /Simulations/Lorenz_Atractor/data/data_gen.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KalmanNet/KalmanNet_TSP/828a2cf529bc84f43b37d543d916fe5858054457/Simulations/Lorenz_Atractor/data/data_gen.pt -------------------------------------------------------------------------------- /Simulations/Lorenz_Atractor/parameters.py: -------------------------------------------------------------------------------- 1 | """This file contains the parameters for the Lorenz Atractor simulation. 
2 | 3 | Update 2023-02-06: f and h support batch size speed up 4 | 5 | """ 6 | 7 | 8 | import torch 9 | import math 10 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 11 | from torch import autograd 12 | 13 | ######################### 14 | ### Design Parameters ### 15 | ######################### 16 | m = 3 17 | n = 3 18 | variance = 0 19 | m1x_0 = torch.ones(m, 1) 20 | m2x_0 = 0 * 0 * torch.eye(m) 21 | 22 | ### Decimation 23 | delta_t_gen = 1e-5 24 | delta_t = 0.02 25 | ratio = delta_t_gen/delta_t 26 | 27 | ### Taylor expansion order 28 | J = 5 29 | J_mod = 2 30 | 31 | ### Angle of rotation in the 3 axes 32 | roll_deg = yaw_deg = pitch_deg = 1 33 | 34 | roll = roll_deg * (math.pi/180) 35 | yaw = yaw_deg * (math.pi/180) 36 | pitch = pitch_deg * (math.pi/180) 37 | 38 | RX = torch.tensor([ 39 | [1, 0, 0], 40 | [0, math.cos(roll), -math.sin(roll)], 41 | [0, math.sin(roll), math.cos(roll)]]) 42 | RY = torch.tensor([ 43 | [math.cos(pitch), 0, math.sin(pitch)], 44 | [0, 1, 0], 45 | [-math.sin(pitch), 0, math.cos(pitch)]]) 46 | RZ = torch.tensor([ 47 | [math.cos(yaw), -math.sin(yaw), 0], 48 | [math.sin(yaw), math.cos(yaw), 0], 49 | [0, 0, 1]]) 50 | 51 | RotMatrix = torch.mm(torch.mm(RZ, RY), RX) 52 | 53 | ### Auxiliar MultiDimensional Tensor B and C (they make A --> Differential equation matrix) 54 | C = torch.tensor([[-10, 10, 0], 55 | [ 28, -1, 0], 56 | [ 0, 0, -8/3]]).float() 57 | 58 | ###################################################### 59 | ### State evolution function f for Lorenz Atractor ### 60 | ###################################################### 61 | ### f_gen is for dataset generation 62 | def f_gen(x, jacobian=False): 63 | BX = torch.zeros([x.shape[0],m,m]).float().to(x.device) #[batch_size, m, m] 64 | BX[:,1,0] = torch.squeeze(-x[:,2,:]) 65 | BX[:,2,0] = torch.squeeze(x[:,1,:]) 66 | Const = C.to(x.device) 67 | A = torch.add(BX, Const) 68 | # Taylor Expansion for F 69 | F = torch.eye(m).to(x.device) 70 | F = F.reshape((1, m, m)).repeat(x.shape[0], 1, 1) # [batch_size, m, m] identity matrix 71 | for j in range(1,J+1): 72 | F_add = (torch.matrix_power(A*delta_t_gen, j)/math.factorial(j)) 73 | F = torch.add(F, F_add) 74 | if jacobian: 75 | return torch.bmm(F, x), F 76 | else: 77 | return torch.bmm(F, x) 78 | 79 | ### f will be fed to filters and KNet, note that the mismatch comes from delta_t 80 | def f(x, jacobian=False): 81 | BX = torch.zeros([x.shape[0],m,m]).float().to(x.device) #[batch_size, m, m] 82 | BX[:,1,0] = torch.squeeze(-x[:,2,:]) 83 | BX[:,2,0] = torch.squeeze(x[:,1,:]) 84 | Const = C.to(x.device) 85 | A = torch.add(BX, Const) 86 | # Taylor Expansion for F 87 | F = torch.eye(m).to(x.device) 88 | F = F.reshape((1, m, m)).repeat(x.shape[0], 1, 1) # [batch_size, m, m] identity matrix 89 | for j in range(1,J+1): 90 | F_add = (torch.matrix_power(A*delta_t, j)/math.factorial(j)) 91 | F = torch.add(F, F_add) 92 | if jacobian: 93 | return torch.bmm(F, x), F 94 | else: 95 | return torch.bmm(F, x) 96 | 97 | ### fInacc will be fed to filters and KNet, note that the mismatch comes from delta_t and J_mod 98 | def fInacc(x, jacobian=False): 99 | BX = torch.zeros([x.shape[0],m,m]).float().to(x.device) #[batch_size, m, m] 100 | BX[:,1,0] = torch.squeeze(-x[:,2,:]) 101 | BX[:,2,0] = torch.squeeze(x[:,1,:]) 102 | Const = C.to(x.device) 103 | A = torch.add(BX, Const) 104 | # Taylor Expansion for F 105 | F = torch.eye(m).to(x.device) 106 | F = F.reshape((1, m, m)) 107 | F = F.repeat(x.shape[0], 1, 1) # [batch_size, m, m] identity matrix 108 | for j in 
range(1,J_mod+1): 109 | F_add = (torch.matrix_power(A*delta_t, j)/math.factorial(j)) 110 | F = torch.add(F, F_add) 111 | if jacobian: 112 | return torch.bmm(F, x), F 113 | else: 114 | return torch.bmm(F, x) 115 | 116 | ### fInacc will be fed to filters and KNet, note that the mismatch comes from delta_t and rotation 117 | def fRotate(x, jacobian=False): 118 | BX = torch.zeros([x.shape[0],m,m]).float().to(x.device) #[batch_size, m, m] 119 | BX[:,1,0] = torch.squeeze(-x[:,2,:]) 120 | BX[:,2,0] = torch.squeeze(x[:,1,:]) 121 | Const = C.to(x.device) 122 | A = torch.add(BX, Const) 123 | # Taylor Expansion for F 124 | F = torch.eye(m).to(x.device) 125 | F = F.reshape((1, m, m)) 126 | F = F.repeat(x.shape[0], 1, 1) # [batch_size, m, m] identity matrix 127 | for j in range(1,J+1): 128 | F_add = (torch.matrix_power(A*delta_t, j)/math.factorial(j)) 129 | F = torch.add(F, F_add) 130 | F_rotated = torch.bmm(RotMatrix.reshape(1,m,m).repeat(x.shape[0],1,1),F) 131 | if jacobian: 132 | return torch.bmm(F_rotated, x), F_rotated 133 | else: 134 | return torch.bmm(F_rotated, x) 135 | 136 | ################################################## 137 | ### Observation function h for Lorenz Atractor ### 138 | ################################################## 139 | H_design = torch.eye(n) 140 | H_Rotate = torch.mm(RotMatrix,H_design) 141 | H_Rotate_inv = torch.inverse(H_Rotate) 142 | 143 | def h(x, jacobian=False): 144 | H = H_design.to(x.device).reshape((1, n, n)).repeat(x.shape[0], 1, 1) # [batch_size, n, n] identity matrix 145 | y = torch.bmm(H,x) 146 | if jacobian: 147 | return y, H 148 | else: 149 | return y 150 | 151 | def h_nonlinear(x): 152 | return toSpherical(x) 153 | 154 | def hRotate(x, jacobian=False): 155 | H = H_Rotate.to(x.device).reshape((1, n, n)).repeat(x.shape[0], 1, 1)# [batch_size, n, n] rotated matrix 156 | if jacobian: 157 | return torch.bmm(H,x), H 158 | else: 159 | return torch.bmm(H,x) 160 | 161 | def h_nobatch(x, jacobian=False): 162 | H = H_design.to(x.device) 163 | y = torch.matmul(H,x) 164 | if jacobian: 165 | return y, H 166 | else: 167 | return y 168 | ############################################### 169 | ### process noise Q and observation noise R ### 170 | ############################################### 171 | Q_non_diag = False 172 | R_non_diag = False 173 | 174 | Q_structure = torch.eye(m) 175 | R_structure = torch.eye(n) 176 | 177 | if(Q_non_diag): 178 | q_d = 1 179 | q_nd = 1/2 180 | Q = torch.tensor([[q_d, q_nd, q_nd],[q_nd, q_d, q_nd],[q_nd, q_nd, q_d]]) 181 | 182 | if(R_non_diag): 183 | r_d = 1 184 | r_nd = 1/2 185 | R = torch.tensor([[r_d, r_nd, r_nd],[r_nd, r_d, r_nd],[r_nd, r_nd, r_d]]) 186 | 187 | ################################## 188 | ### Utils for non-linear cases ### 189 | ################################## 190 | def getJacobian(x, g): 191 | """ 192 | Currently, pytorch does not have a built-in function to compute Jacobian matrix 193 | in a batched manner, so we have to iterate over the batch dimension. 
194 | 195 | input x (torch.tensor): [batch_size, m/n, 1] 196 | input g (function): function to be differentiated 197 | output Jac (torch.tensor): [batch_size, m, m] for f, [batch_size, n, m] for h 198 | """ 199 | # Method 1: using autograd.functional.jacobian 200 | # batch_size = x.shape[0] 201 | # Jac_x0 = torch.squeeze(autograd.functional.jacobian(g, torch.unsqueeze(x[0,:,:],0))) 202 | # Jac = torch.zeros([batch_size, Jac_x0.shape[0], Jac_x0.shape[1]]) 203 | # Jac[0,:,:] = Jac_x0 204 | # for i in range(1,batch_size): 205 | # Jac[i,:,:] = torch.squeeze(autograd.functional.jacobian(g, torch.unsqueeze(x[i,:,:],0))) 206 | # Method 2: using F, H directly 207 | _,Jac = g(x, jacobian=True) 208 | return Jac 209 | 210 | def toSpherical(cart): 211 | """ 212 | input cart (torch.tensor): [batch_size, m, 1] or [batch_size, m] 213 | output spher (torch.tensor): [batch_size, n, 1] 214 | """ 215 | rho = torch.linalg.norm(cart,dim=1).reshape(cart.shape[0], 1)# [batch_size, 1] 216 | phi = torch.atan2(cart[:, 1, ...], cart[:, 0, ...]).reshape(cart.shape[0], 1) # [batch_size, 1] 217 | phi = phi + (phi < 0).type_as(phi) * (2 * torch.pi) 218 | 219 | theta = torch.div(torch.squeeze(cart[:, 2, ...]), torch.squeeze(rho)) 220 | theta = torch.acos(theta).reshape(cart.shape[0], 1) # [batch_size, 1] 221 | 222 | spher = torch.cat([rho, theta, phi], dim=1).reshape(cart.shape[0],3,1) # [batch_size, n, 1] 223 | 224 | return spher 225 | 226 | def toCartesian(sphe): 227 | """ 228 | input sphe (torch.tensor): [batch_size, n, 1] or [batch_size, n] 229 | output cart (torch.tensor): [batch_size, n, 1] 230 | """ 231 | rho = sphe[:, 0, ...] 232 | theta = sphe[:, 1, ...] 233 | phi = sphe[:, 2, ...] 234 | 235 | x = (rho * torch.sin(theta) * torch.cos(phi)).reshape(sphe.shape[0],1) 236 | y = (rho * torch.sin(theta) * torch.sin(phi)).reshape(sphe.shape[0],1) 237 | z = (rho * torch.cos(theta)).reshape(sphe.shape[0],1) 238 | 239 | cart = torch.cat([x,y,z],dim=1).reshape(sphe.shape[0],3,1) # [batch_size, n, 1]; sphe.shape[0], since cart is not yet defined here 240 | 241 | return cart -------------------------------------------------------------------------------- /Simulations/config.py: -------------------------------------------------------------------------------- 1 | """This file contains the settings for the simulation""" 2 | import argparse 3 | 4 | def general_settings(): 5 | ### Dataset settings 6 | # Sizes 7 | parser = argparse.ArgumentParser(prog = 'KalmanNet',\ 8 | description = 'Dataset, training and network parameters') 9 | parser.add_argument('--N_E', type=int, default=1000, metavar='trainset-size', 10 | help='input training dataset size (# of sequences)') 11 | parser.add_argument('--N_CV', type=int, default=100, metavar='cvset-size', 12 | help='input cross validation dataset size (# of sequences)') 13 | parser.add_argument('--N_T', type=int, default=200, metavar='testset-size', 14 | help='input test dataset size (# of sequences)') 15 | parser.add_argument('--T', type=int, default=100, metavar='length', 16 | help='input sequence length') 17 | parser.add_argument('--T_test', type=int, default=100, metavar='test-length', 18 | help='input test sequence length') 19 | # Random length 20 | parser.add_argument('--randomLength', type=bool, default=False, metavar='rl', 21 | help='if True, random sequence length') 22 | parser.add_argument('--T_max', type=int, default=1000, metavar='maximum-length', 23 | help='if random sequence length, input max sequence length') 24 | parser.add_argument('--T_min', type=int, default=100, metavar='minimum-length', 25 | help='if random sequence length, 
input min sequence length') 26 | # Random initial state 27 | parser.add_argument('--randomInit_train', type=bool, default=False, metavar='ri_train', 28 | help='if True, random initial state for training set') 29 | parser.add_argument('--randomInit_cv', type=bool, default=False, metavar='ri_cv', 30 | help='if True, random initial state for cross validation set') 31 | parser.add_argument('--randomInit_test', type=bool, default=False, metavar='ri_test', 32 | help='if True, random initial state for test set') 33 | parser.add_argument('--variance', type=float, default=100, metavar='variance', 34 | help='input variance for the random initial state with uniform distribution') 35 | parser.add_argument('--distribution', type=str, default='normal', metavar='distribution', 36 | help='input distribution for the random initial state (uniform/normal)') 37 | 38 | 39 | ### Training settings 40 | parser.add_argument('--use_cuda', type=bool, default=False, metavar='CUDA', 41 | help='if True, use CUDA') 42 | parser.add_argument('--n_steps', type=int, default=1000, metavar='N_steps', 43 | help='number of training steps (default: 1000)') 44 | parser.add_argument('--n_batch', type=int, default=20, metavar='N_B', 45 | help='input batch size for training (default: 20)') 46 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 47 | help='learning rate (default: 1e-3)') 48 | parser.add_argument('--wd', type=float, default=1e-4, metavar='WD', 49 | help='weight decay (default: 1e-4)') 50 | parser.add_argument('--CompositionLoss', type=bool, default=False, metavar='loss', 51 | help='if True, use composition loss') 52 | parser.add_argument('--alpha', type=float, default=0.3, metavar='alpha', 53 | help='input alpha [0,1] for the composition loss') 54 | 55 | 56 | ### KalmanNet settings 57 | parser.add_argument('--in_mult_KNet', type=int, default=5, metavar='in_mult_KNet', 58 | help='input dimension multiplier for KNet') 59 | parser.add_argument('--out_mult_KNet', type=int, default=40, metavar='out_mult_KNet', 60 | help='output dimension multiplier for KNet') 61 | 62 | args = parser.parse_args() 63 | return args 64 | -------------------------------------------------------------------------------- /Simulations/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The file contains utility functions for the simulations. 
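Included here: DataGen (generate and save train/CV/test sets), DecimateData and Decimate_and_perturbate_Data (sub-sample a finely generated process), getObs (noise-free observations), and Short_Traj_Split (cut long trajectories into length-T segments plus initial states).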
3 | """ 4 | 5 | import torch 6 | 7 | def DataGen(args, SysModel_data, fileName): 8 | 9 | ################################## 10 | ### Generate Training Sequence ### 11 | ################################## 12 | SysModel_data.GenerateBatch(args, args.N_E, args.T, randomInit=args.randomInit_train) 13 | train_input = SysModel_data.Input 14 | train_target = SysModel_data.Target 15 | ### init conditions ### 16 | train_init = SysModel_data.m1x_0_batch #size: N_E x m x 1 17 | ### length mask ### 18 | if args.randomLength: 19 | train_lengthMask = SysModel_data.lengthMask 20 | 21 | #################################### 22 | ### Generate Validation Sequence ### 23 | #################################### 24 | SysModel_data.GenerateBatch(args, args.N_CV, args.T, randomInit=args.randomInit_cv) 25 | cv_input = SysModel_data.Input 26 | cv_target = SysModel_data.Target 27 | cv_init = SysModel_data.m1x_0_batch #size: N_CV x m x 1 28 | ### length mask ### 29 | if args.randomLength: 30 | cv_lengthMask = SysModel_data.lengthMask 31 | 32 | ############################## 33 | ### Generate Test Sequence ### 34 | ############################## 35 | SysModel_data.GenerateBatch(args, args.N_T, args.T_test, randomInit=args.randomInit_test) 36 | test_input = SysModel_data.Input 37 | test_target = SysModel_data.Target 38 | test_init = SysModel_data.m1x_0_batch #size: N_T x m x 1 39 | ### length mask ### 40 | if args.randomLength: 41 | test_lengthMask = SysModel_data.lengthMask 42 | 43 | ################# 44 | ### Save Data ### 45 | ################# 46 | if(args.randomLength): 47 | torch.save([train_input, train_target, cv_input, cv_target, test_input, test_target,train_init, cv_init, test_init, train_lengthMask,cv_lengthMask,test_lengthMask], fileName) 48 | else: 49 | torch.save([train_input, train_target, cv_input, cv_target, test_input, test_target,train_init, cv_init, test_init], fileName) 50 | 51 | def DecimateData(all_tensors, t_gen,t_mod, offset=0): 52 | 53 | # ratio: defines the relation between the sampling time of the true process and of the model (has to be an integer) 54 | ratio = round(t_mod/t_gen) 55 | 56 | i = 0 57 | all_tensors_out = all_tensors 58 | for tensor in all_tensors: 59 | tensor = tensor[:,(0+offset)::ratio] 60 | if(i==0): 61 | all_tensors_out = torch.cat([tensor], dim=0).view(1,all_tensors.size()[1],-1) 62 | else: 63 | all_tensors_out = torch.cat([all_tensors_out,tensor.view(1,all_tensors.size()[1],-1)], dim=0) 64 | i += 1 65 | 66 | return all_tensors_out 67 | 68 | def Decimate_and_perturbate_Data(true_process, delta_t, delta_t_mod, N_examples, h, lambda_r, offset=0): 69 | 70 | # Decimate high resolution process 71 | decimated_process = DecimateData(true_process, delta_t, delta_t_mod, offset) 72 | 73 | noise_free_obs = getObs(decimated_process,h) 74 | 75 | # Replicate for computation purposes 76 | decimated_process = torch.cat(int(N_examples)*[decimated_process]) 77 | noise_free_obs = torch.cat(int(N_examples)*[noise_free_obs]) 78 | 79 | 80 | # Observations; additive Gaussian Noise 81 | observations = noise_free_obs + torch.randn_like(decimated_process) * lambda_r 82 | 83 | return [decimated_process, observations] 84 | 85 | def getObs(sequences, h): 86 | i = 0 87 | sequences_out = torch.zeros_like(sequences) 88 | # sequences_out = torch.zeros_like(sequences) 89 | for sequence in sequences: 90 | for t in range(sequence.size()[1]): 91 | sequences_out[i,:,t] = h(sequence[:,t]) 92 | i = i+1 93 | 94 | return sequences_out 95 | 96 | def Short_Traj_Split(data_target, data_input, T):### Random Init is 
96 | def Short_Traj_Split(data_target, data_input, T):### Random Init is automatically incorporated 97 | data_target = list(torch.split(data_target,T+1,2)) # +1 to reserve for init 98 | data_input = list(torch.split(data_input,T+1,2)) # +1 to reserve for init 99 | 100 | data_target.pop()# Remove the last split, which may not fulfill length T 101 | data_input.pop()# Remove the last split, which may not fulfill length T 102 | 103 | data_target = torch.squeeze(torch.cat(list(data_target), dim=0))#Back to tensor and concat together 104 | data_input = torch.squeeze(torch.cat(list(data_input), dim=0))#Back to tensor and concat together 105 | # Split out init 106 | target = data_target[:,:,1:] 107 | input = data_input[:,:,1:] 108 | init = data_target[:,:,0] 109 | return [target, input, init] 110 | -------------------------------------------------------------------------------- /main_linear_CA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | 4 | from Simulations.Linear_sysmdl import SystemModel 5 | import Simulations.config as config 6 | import Simulations.utils as utils 7 | from Simulations.Linear_CA.parameters import F_gen,F_CV,H_identity,H_onlyPos,\ 8 | Q_gen,Q_CV,R_3,R_2,R_onlyPos,\ 9 | m,m_cv 10 | 11 | from Filters.KalmanFilter_test import KFTest 12 | 13 | from KNet.KalmanNet_nn import KalmanNetNN 14 | 15 | from Pipelines.Pipeline_EKF import Pipeline_EKF as Pipeline 16 | 17 | from Plot import Plot_extended as Plot 18 | 19 | ################ 20 | ### Get Time ### 21 | ################ 22 | today = datetime.today() 23 | now = datetime.now() 24 | strToday = today.strftime("%m.%d.%y") 25 | strNow = now.strftime("%H:%M:%S") 26 | strTime = strToday + "_" + strNow 27 | print("Current Time =", strTime) 28 | path_results = 'KNet/' 29 | 30 | print("Pipeline Start") 31 | #################################### 32 | ### Generative Parameters For CA ### 33 | #################################### 34 | args = config.general_settings() 35 | ### Dataset parameters 36 | args.N_E = 1000 37 | args.N_CV = 100 38 | args.N_T = 200 39 | offset = 0 ### Init condition of dataset 40 | args.randomInit_train = True 41 | args.randomInit_cv = True 42 | args.randomInit_test = True 43 | 44 | args.T = 100 45 | args.T_test = 100 46 | ### training parameters 47 | KnownRandInit_train = True # if true: use known random init for training, else: model is agnostic to random init 48 | KnownRandInit_cv = True 49 | KnownRandInit_test = True 50 | args.use_cuda = True # use GPU or not 51 | args.n_steps = 4000 52 | args.n_batch = 10 53 | args.lr = 1e-4 54 | args.wd = 1e-4 55 | 56 | if args.use_cuda: 57 | if torch.cuda.is_available(): 58 | device = torch.device('cuda') 59 | print("Using GPU") 60 | else: 61 | raise Exception("No GPU found, please set args.use_cuda = False") 62 | else: 63 | device = torch.device('cpu') 64 | print("Using CPU") 65 | 66 | if(args.randomInit_train or args.randomInit_cv or args.randomInit_test): 67 | std_gen = 1 68 | else: 69 | std_gen = 0 70 | 71 | if(KnownRandInit_train or KnownRandInit_cv or KnownRandInit_test): 72 | std_feed = 0 73 | else: 74 | std_feed = 1 75 | 76 | m1x_0 = torch.zeros(m) # Initial State 77 | m1x_0_cv = torch.zeros(m_cv) # Initial State for CV 78 | m2x_0 = std_feed * std_feed * torch.eye(m) # Initial Covariance for feeding to filters and KNet 79 | m2x_0_gen = std_gen * std_gen * torch.eye(m) # Initial Covariance for generating dataset 80 | m2x_0_cv = std_feed * std_feed * torch.eye(m_cv) # Initial Covariance for CV 81 |
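For readers of the std_gen/std_feed switches above, a compact annotation (mine, not repo code) of the regimes they encode:

# std_gen : spread used to *generate* x0 (1 if any randomInit_* flag is set, else 0)
# std_feed: spread *reported to* the filters/KalmanNet (0 when the realized x0 is
#           known and passed in via train_init/cv_init/test_init, else 1)
# e.g. random-but-known init: the data is drawn with P0 = I (m2x_0_gen), while the
# filters receive P0 = 0 (m2x_0) together with the exact initial state.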
82 | ############################# 83 | ### Dataset Generation ### 84 | ############################# 85 | ### PVA or P (full-state loss vs position-only loss) 86 | Loss_On_AllState = False # if false: only calculate loss on position 87 | Train_Loss_On_AllState = True # if false: only calculate training loss on position 88 | CV_model = False # if true: use CV model, else: use CA model 89 | 90 | DatafolderName = 'Simulations/Linear_CA/data/' 91 | DatafileName = 'decimated_dt1e-2_T100_r0_randnInit.pt' 92 | 93 | #################### 94 | ### System Model ### 95 | #################### 96 | # Generation model (CA) 97 | sys_model_gen = SystemModel(F_gen, Q_gen, H_onlyPos, R_onlyPos, args.T, args.T_test) 98 | sys_model_gen.InitSequence(m1x_0, m2x_0_gen)# x0 and P0 99 | 100 | # Feed model (to KF, KalmanNet) 101 | if CV_model: 102 | H_onlyPos = torch.tensor([[1, 0]]).float() 103 | sys_model = SystemModel(F_CV, Q_CV, H_onlyPos, R_onlyPos, args.T, args.T_test) 104 | sys_model.InitSequence(m1x_0_cv, m2x_0_cv)# x0 and P0 105 | else: 106 | sys_model = SystemModel(F_gen, Q_gen, H_onlyPos, R_onlyPos, args.T, args.T_test) 107 | sys_model.InitSequence(m1x_0, m2x_0)# x0 and P0 108 | 109 | print("Start Data Gen") 110 | utils.DataGen(args, sys_model_gen, DatafolderName+DatafileName) 111 | print("Load Original Data") 112 | [train_input, train_target, cv_input, cv_target, test_input, test_target,train_init,cv_init,test_init] = torch.load(DatafolderName+DatafileName, map_location=device) 113 | if CV_model:# set state as (p,v) instead of (p,v,a) 114 | train_target = train_target[:,0:m_cv,:] 115 | train_init = train_init[:,0:m_cv] 116 | cv_target = cv_target[:,0:m_cv,:] 117 | cv_init = cv_init[:,0:m_cv] 118 | test_target = test_target[:,0:m_cv,:] 119 | test_init = test_init[:,0:m_cv] 120 | 121 | print("Data Shape") 122 | print("testset state x size:",test_target.size()) 123 | print("testset observation y size:",test_input.size()) 124 | print("trainset state x size:",train_target.size()) 125 | print("trainset observation y size:",train_input.size()) 126 | print("cvset state x size:",cv_target.size()) 127 | print("cvset observation y size:",cv_input.size()) 128 | 129 | print("Compute Loss on All States (if false, loss on position only):", Loss_On_AllState) 130 | ############################## 131 | ### Evaluate Kalman Filter ### 132 | ############################## 133 | print("Evaluate Kalman Filter") 134 | if args.randomInit_test and KnownRandInit_test: 135 | [MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, KF_out] = KFTest(args, sys_model, test_input, test_target, allStates=Loss_On_AllState, randomInit = True, test_init=test_init) 136 | else: 137 | [MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, KF_out] = KFTest(args, sys_model, test_input, test_target, allStates=Loss_On_AllState) 138 | 139 | ########################## 140 | ### Evaluate KalmanNet ### 141 | ########################## 142 | # Build Neural Network 143 | KNet_model = KalmanNetNN() 144 | KNet_model.NNBuild(sys_model, args) 145 | print("Number of trainable parameters for KNet:",sum(p.numel() for p in KNet_model.parameters() if p.requires_grad)) 146 | ## Train Neural Network 147 | KNet_Pipeline = Pipeline(strTime, "KNet", "KNet") 148 | KNet_Pipeline.setssModel(sys_model) 149 | KNet_Pipeline.setModel(KNet_model) 150 | KNet_Pipeline.setTrainingParams(args) 151 | if (KnownRandInit_train): 152 | print("Train KNet with Known Random Initial State") 153 | print("Train Loss on All States (if false, loss on position only):", Train_Loss_On_AllState) 154 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch,
MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results, MaskOnState=not Train_Loss_On_AllState, randomInit = True, cv_init=cv_init,train_init=train_init) 155 | else: 156 | print("Train KNet with Unknown Initial State") 157 | print("Train Loss on All States (if false, loss on position only):", Train_Loss_On_AllState) 158 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results, MaskOnState=not Train_Loss_On_AllState) 159 | 160 | if (KnownRandInit_test): 161 | print("Test KNet with Known Random Initial State") 162 | ## Test Neural Network 163 | print("Compute Loss on All States (if false, loss on position only):", Loss_On_AllState) 164 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,KNet_out,RunTime] = KNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results,MaskOnState=not Loss_On_AllState,randomInit=True,test_init=test_init) 165 | else: 166 | print("Test KNet with Unknown Initial State") 167 | ## Test Neural Network 168 | print("Compute Loss on All States (if false, loss on position only):", Loss_On_AllState) 169 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,KNet_out,RunTime] = KNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results,MaskOnState=not Loss_On_AllState) 170 | 171 | 172 | #################### 173 | ### Plot results ### 174 | #################### 175 | PlotfolderName = "Figures/Linear_CA/" 176 | PlotfileName0 = "TrainPVA_position.png" 177 | PlotfileName1 = "TrainPVA_velocity.png" 178 | PlotfileName2 = "TrainPVA_acceleration.png" 179 | 180 | Plot = Plot(PlotfolderName, PlotfileName0) 181 | print("Plot") 182 | Plot.plotTraj_CA(test_target, KF_out, KNet_out, dim=0, file_name=PlotfolderName+PlotfileName0)#Position 183 | Plot.plotTraj_CA(test_target, KF_out, KNet_out, dim=1, file_name=PlotfolderName+PlotfileName1)#Velocity 184 | Plot.plotTraj_CA(test_target, KF_out, KNet_out, dim=2, file_name=PlotfolderName+PlotfileName2)#Acceleration -------------------------------------------------------------------------------- /main_linear_canonical.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | 5 | from Simulations.Linear_sysmdl import SystemModel 6 | from Simulations.utils import DataGen 7 | import Simulations.config as config 8 | from Simulations.Linear_canonical.parameters import F, H, Q_structure, R_structure,\ 9 | m, m1_0 10 | 11 | from Filters.KalmanFilter_test import KFTest 12 | 13 | from KNet.KalmanNet_nn import KalmanNetNN 14 | 15 | from Pipelines.Pipeline_EKF import Pipeline_EKF 16 | 17 | print("Pipeline Start") 18 | 19 | ################ 20 | ### Get Time ### 21 | ################ 22 | today = datetime.today() 23 | now = datetime.now() 24 | strToday = today.strftime("%m.%d.%y") 25 | strNow = now.strftime("%H:%M:%S") 26 | strTime = strToday + "_" + strNow 27 | print("Current Time =", strTime) 28 | path_results = 'KNet/' 29 | 30 | #################### 31 | ### Design Model ### 32 | #################### 33 | args = config.general_settings() 34 | 35 | ### dataset parameters ################################################## 36 | args.N_E = 1000 37 | args.N_CV = 100 38 | args.N_T = 200 39 | # init condition 40 | args.randomInit_train = False 41 | args.randomInit_cv = False 42 | args.randomInit_test = False 43 | if args.randomInit_train 
or args.randomInit_cv or args.randomInit_test: 44 | # you can modify initial variance 45 | args.variance = 1 46 | args.distribution = 'normal' # 'uniform' or 'normal' 47 | m2_0 = args.variance * torch.eye(m) 48 | else: 49 | # deterministic initial condition 50 | m2_0 = 0 * torch.eye(m) 51 | # sequence length 52 | args.T = 100 53 | args.T_test = 100 54 | args.randomLength = False 55 | if args.randomLength:# you can modify T_max and T_min 56 | args.T_max = 1000 57 | args.T_min = 100 58 | # set T and T_test to T_max for convenience of batch calculation 59 | args.T = args.T_max 60 | args.T_test = args.T_max 61 | else: 62 | train_lengthMask = None 63 | cv_lengthMask = None 64 | test_lengthMask = None 65 | # noise 66 | r2 = torch.tensor([1]) 67 | vdB = -20 # ratio v=q2/r2 68 | v = 10**(vdB/10) 69 | q2 = torch.mul(v,r2) 70 | print("1/r2 [dB]: ", 10 * torch.log10(1/r2[0])) 71 | print("1/q2 [dB]: ", 10 * torch.log10(1/q2[0])) 72 | 73 | ### training parameters ################################################## 74 | args.use_cuda = True # use GPU or not 75 | args.n_steps = 4000 76 | args.n_batch = 30 77 | args.lr = 1e-4 78 | args.wd = 1e-3 79 | 80 | if args.use_cuda: 81 | if torch.cuda.is_available(): 82 | device = torch.device('cuda') 83 | print("Using GPU") 84 | else: 85 | raise Exception("No GPU found, please set args.use_cuda = False") 86 | else: 87 | device = torch.device('cpu') 88 | print("Using CPU") 89 | 90 | ### True model ################################################## 91 | Q = q2 * Q_structure 92 | R = r2 * R_structure 93 | sys_model = SystemModel(F, Q, H, R, args.T, args.T_test) 94 | sys_model.InitSequence(m1_0, m2_0) 95 | print("State Evolution Matrix:",F) 96 | print("Observation Matrix:",H) 97 | 98 | ################################### 99 | ### Data Loader (Generate Data) ### 100 | ################################### 101 | dataFolderName = 'Simulations/Linear_canonical/data' + '/' 102 | dataFileName = '2x2_rq020_T100.pt' 103 | print("Start Data Gen") 104 | DataGen(args, sys_model, dataFolderName + dataFileName) 105 | print("Data Load") 106 | if args.randomLength: 107 | [train_input, train_target, cv_input, cv_target, test_input, test_target,train_init, cv_init, test_init, train_lengthMask,cv_lengthMask,test_lengthMask] = torch.load(dataFolderName + dataFileName, map_location=device) 108 | else: 109 | [train_input, train_target, cv_input, cv_target, test_input, test_target,train_init, cv_init, test_init] = torch.load(dataFolderName + dataFileName, map_location=device)# keep the inits: they are referenced below whenever the randomInit_* flags are set 110 | 111 | print("trainset size:",train_target.size()) 112 | print("cvset size:",cv_target.size()) 113 | print("testset size:",test_target.size()) 114 | 115 | ######################################## 116 | ### Evaluate Observation Noise Floor ### 117 | ######################################## 118 | loss_obs = nn.MSELoss(reduction='mean') 119 | MSE_obs_linear_arr = torch.empty(args.N_T)# MSE [Linear] 120 | for i in range(args.N_T): 121 | MSE_obs_linear_arr[i] = loss_obs(test_input[i], test_target[i]).item() 122 | MSE_obs_linear_avg = torch.mean(MSE_obs_linear_arr) 123 | MSE_obs_dB_avg = 10 * torch.log10(MSE_obs_linear_avg) 124 | 125 | # Standard deviation 126 | MSE_obs_linear_std = torch.std(MSE_obs_linear_arr, unbiased=True) 127 | 128 | # Confidence interval 129 | obs_std_dB = 10 * torch.log10(MSE_obs_linear_std + MSE_obs_linear_avg) - MSE_obs_dB_avg 130 | 131 | print("Observation Noise Floor - MSE LOSS:", MSE_obs_dB_avg, "[dB]") 132 | print("Observation Noise Floor - STD:", obs_std_dB, "[dB]") 133 |
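For clarity, the quantity printed as "STD" above (and reused verbatim in the Lorenz scripts below) is the standard deviation of the linear MSE mapped to dB relative to the mean:

$$ \sigma_{\mathrm{dB}} = 10\log_{10}\left(\mu_{\mathrm{MSE}} + \sigma_{\mathrm{MSE}}\right) - 10\log_{10}\left(\mu_{\mathrm{MSE}}\right) $$

i.e. the dB gap between the mean MSE and one standard deviation above it.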
134 | ############################## 135 | ### Evaluate Kalman Filter ### 136 | ############################## 137 | print("Evaluate Kalman Filter (true model)") 138 | if args.randomInit_test: 139 | [MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, KF_out] = KFTest(args, sys_model, test_input, test_target, randomInit = True, test_init=test_init, test_lengthMask=test_lengthMask) 140 | else: 141 | [MSE_KF_linear_arr, MSE_KF_linear_avg, MSE_KF_dB_avg, KF_out] = KFTest(args, sys_model, test_input, test_target, test_lengthMask=test_lengthMask) 142 | 143 | 144 | ########################## 145 | ### KalmanNet Pipeline ### 146 | ########################## 147 | 148 | ### KalmanNet with full info ########################################################################################## 149 | # Build Neural Network 150 | print("KalmanNet with full model info") 151 | KalmanNet_model = KalmanNetNN() 152 | KalmanNet_model.NNBuild(sys_model, args) 153 | print("Number of trainable parameters for KalmanNet:",sum(p.numel() for p in KalmanNet_model.parameters() if p.requires_grad)) 154 | ## Train Neural Network 155 | KalmanNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KalmanNet") 156 | KalmanNet_Pipeline.setssModel(sys_model) 157 | KalmanNet_Pipeline.setModel(KalmanNet_model) 158 | KalmanNet_Pipeline.setTrainingParams(args) 159 | if (args.randomInit_train or args.randomInit_cv or args.randomInit_test): 160 | if args.randomLength: 161 | ## Train Neural Network 162 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results, randomInit = True, cv_init=cv_init,train_init=train_init,train_lengthMask=train_lengthMask,cv_lengthMask=cv_lengthMask) 163 | ## Test Neural Network 164 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,knet_out,RunTime] = KalmanNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results,randomInit=True,test_init=test_init,test_lengthMask=test_lengthMask) 165 | else: 166 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results, randomInit = True, cv_init=cv_init,train_init=train_init) 167 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,knet_out,RunTime] = KalmanNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results,randomInit=True,test_init=test_init) 168 | else: 169 | if args.randomLength: 170 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results,train_lengthMask=train_lengthMask,cv_lengthMask=cv_lengthMask) 171 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,knet_out,RunTime] = KalmanNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results,test_lengthMask=test_lengthMask) 172 | else: 173 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results) 174 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,knet_out,RunTime] = KalmanNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results) 175 | KalmanNet_Pipeline.save() -------------------------------------------------------------------------------- /main_lor_DT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.pi =
torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 3 | import torch.nn as nn 4 | from Filters.EKF_test import EKFTest 5 | 6 | from Simulations.Extended_sysmdl import SystemModel 7 | from Simulations.utils import DataGen,Short_Traj_Split 8 | import Simulations.config as config 9 | 10 | from Pipelines.Pipeline_EKF import Pipeline_EKF 11 | 12 | from datetime import datetime 13 | 14 | from KNet.KalmanNet_nn import KalmanNetNN 15 | 16 | from Simulations.Lorenz_Atractor.parameters import m1x_0, m2x_0, m, n,\ 17 | f, h, hRotate, H_Rotate, H_Rotate_inv, Q_structure, R_structure 18 | 19 | print("Pipeline Start") 20 | ################ 21 | ### Get Time ### 22 | ################ 23 | today = datetime.today() 24 | now = datetime.now() 25 | strToday = today.strftime("%m.%d.%y") 26 | strNow = now.strftime("%H:%M:%S") 27 | strTime = strToday + "_" + strNow 28 | print("Current Time =", strTime) 29 | 30 | ################### 31 | ### Settings ### 32 | ################### 33 | args = config.general_settings() 34 | ### dataset parameters 35 | args.N_E = 1000 36 | args.N_CV = 100 37 | args.N_T = 200 38 | args.T = 100 39 | args.T_test = 100 40 | ### training parameters 41 | args.use_cuda = True # use GPU or not 42 | args.n_steps = 2000 43 | args.n_batch = 30 44 | args.lr = 1e-3 45 | args.wd = 1e-3 46 | 47 | if args.use_cuda: 48 | if torch.cuda.is_available(): 49 | device = torch.device('cuda') 50 | print("Using GPU") 51 | else: 52 | raise Exception("No GPU found, please set args.use_cuda = False") 53 | else: 54 | device = torch.device('cpu') 55 | print("Using CPU") 56 | 57 | offset = 0 # offset for the data 58 | chop = False # whether to chop data sequences into shorter sequences 59 | path_results = 'KNet/' 60 | DatafolderName = 'Simulations/Lorenz_Atractor/data' + '/' 61 | switch = 'partial' # 'full' or 'partial' or 'estH' 62 | 63 | # noise q and r 64 | r2 = torch.tensor([0.1]) # [100, 10, 1, 0.1, 0.01] 65 | vdB = -20 # ratio v=q2/r2 66 | v = 10**(vdB/10) 67 | q2 = torch.mul(v,r2) 68 | 69 | Q = q2[0] * Q_structure 70 | R = r2[0] * R_structure 71 | 72 | print("1/r2 [dB]: ", 10 * torch.log10(1/r2[0])) 73 | print("1/q2 [dB]: ", 10 * torch.log10(1/q2[0])) 74 | 75 | traj_resultName = ['traj_lorDT_rq1030_T100.pt'] 76 | dataFileName = ['data_lor_v20_rq1030_T100.pt'] 77 | 78 | ######################################### 79 | ### Generate and load data DT case ### 80 | ######################################### 81 | 82 | sys_model = SystemModel(f, Q, hRotate, R, args.T, args.T_test, m, n)# parameters for GT 83 | sys_model.InitSequence(m1x_0, m2x_0)# x0 and P0 84 | 85 | print("Start Data Gen") 86 | DataGen(args, sys_model, DatafolderName + dataFileName[0]) 87 | print("Data Load") 88 | print(dataFileName[0]) 89 | [train_input_long,train_target_long, cv_input, cv_target, test_input, test_target,_,_,_] = torch.load(DatafolderName + dataFileName[0], map_location=device) 90 | if chop: 91 | print("chop training data") 92 | [train_target, train_input, train_init] = Short_Traj_Split(train_target_long, train_input_long, args.T) 93 | # [cv_target, cv_input] = Short_Traj_Split(cv_target, cv_input, args.T) 94 | else: 95 | print("no chopping") 96 | train_target = train_target_long[:,:,0:args.T] 97 | train_input = train_input_long[:,:,0:args.T] 98 | # cv_target = cv_target[:,:,0:args.T] 99 | # cv_input = cv_input[:,:,0:args.T] 100 | 101 | print("trainset size:",train_target.size()) 102 | print("cvset size:",cv_target.size()) 103 | print("testset size:",test_target.size()) 104 | 105 | 106 | # Model with partial info 107 
| sys_model_partial = SystemModel(f, Q, h, R, args.T, args.T_test, m, n) 108 | sys_model_partial.InitSequence(m1x_0, m2x_0) 109 | # Model for 2nd pass 110 | sys_model_pass2 = SystemModel(f, Q, h, R, args.T, args.T_test, m, n)# parameters for GT 111 | sys_model_pass2.InitSequence(m1x_0, m2x_0)# x0 and P0 112 | 113 | ######################################## 114 | ### Evaluate Observation Noise Floor ### 115 | ######################################## 116 | N_T = len(test_input) 117 | loss_obs = nn.MSELoss(reduction='mean') 118 | MSE_obs_linear_arr = torch.empty(N_T)# MSE [Linear] 119 | 120 | for j in range(0, N_T): 121 | reversed_target = torch.matmul(H_Rotate_inv, test_input[j]) 122 | MSE_obs_linear_arr[j] = loss_obs(reversed_target, test_target[j]).item() 123 | MSE_obs_linear_avg = torch.mean(MSE_obs_linear_arr) 124 | MSE_obs_dB_avg = 10 * torch.log10(MSE_obs_linear_avg) 125 | 126 | # Standard deviation 127 | MSE_obs_linear_std = torch.std(MSE_obs_linear_arr, unbiased=True) 128 | 129 | # Confidence interval 130 | obs_std_dB = 10 * torch.log10(MSE_obs_linear_std + MSE_obs_linear_avg) - MSE_obs_dB_avg 131 | 132 | print("Observation Noise Floor(test dataset) - MSE LOSS:", MSE_obs_dB_avg, "[dB]") 133 | print("Observation Noise Floor(test dataset) - STD:", obs_std_dB, "[dB]") 134 | 135 | 136 | ######################## 137 | ### Evaluate Filters ### 138 | ######################## 139 | ### Evaluate EKF true 140 | # print("Evaluate EKF true") 141 | # [MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg, EKF_KG_array, EKF_out] = EKFTest(args, sys_model, test_input, test_target) 142 | # ### Evaluate EKF partial 143 | # print("Evaluate EKF partial") 144 | # [MSE_EKF_linear_arr_partial, MSE_EKF_linear_avg_partial, MSE_EKF_dB_avg_partial, EKF_KG_array_partial, EKF_out_partial] = EKFTest(args, sys_model_partial, test_input, test_target) 145 | 146 | # ### Save trajectories 147 | # trajfolderName = 'Filters' + '/' 148 | # DataResultName = traj_resultName[0] 149 | # EKF_sample = torch.reshape(EKF_out[0],[1,m,args.T_test]) 150 | # target_sample = torch.reshape(test_target[0,:,:],[1,m,args.T_test]) 151 | # input_sample = torch.reshape(test_input[0,:,:],[1,n,args.T_test]) 152 | # torch.save({ 153 | # 'EKF': EKF_sample, 154 | # 'ground_truth': target_sample, 155 | # 'observation': input_sample, 156 | # }, trajfolderName+DataResultName) 157 | 158 | ##################### 159 | ### Evaluate KNet ### 160 | ##################### 161 | if switch == 'full': 162 | ## KNet with full info #################################################################################### 163 | ################ 164 | ## KNet full ### 165 | ################ 166 | ## Build Neural Network 167 | print("KNet with full model info") 168 | KNet_model = KalmanNetNN() 169 | KNet_model.NNBuild(sys_model, args) 170 | # ## Train Neural Network 171 | KNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KNet") 172 | KNet_Pipeline.setssModel(sys_model) 173 | KNet_Pipeline.setModel(KNet_model) 174 | print("Number of trainable parameters for KNet:",sum(p.numel() for p in KNet_model.parameters() if p.requires_grad)) 175 | KNet_Pipeline.setTrainingParams(args) 176 | if(chop): 177 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results,randomInit=True,train_init=train_init) 178 | else: 179 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, 
train_input, train_target, path_results) 180 | ## Test Neural Network 181 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,Knet_out,RunTime] = KNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results) 182 | 183 | #################################################################################### 184 | elif switch == 'partial': 185 | ## KNet with model mismatch #################################################################################### 186 | ################### 187 | ## KNet partial ### 188 | #################### 189 | ## Build Neural Network 190 | print("KNet with observation model mismatch") 191 | KNet_model = KalmanNetNN() 192 | KNet_model.NNBuild(sys_model_partial, args) 193 | ## Train Neural Network 194 | KNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KNet") 195 | KNet_Pipeline.setssModel(sys_model_partial) 196 | KNet_Pipeline.setModel(KNet_model) 197 | KNet_Pipeline.setTrainingParams(args) 198 | if(chop): 199 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model_partial, cv_input, cv_target, train_input, train_target, path_results,randomInit=True,train_init=train_init) 200 | else: 201 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model_partial, cv_input, cv_target, train_input, train_target, path_results) 202 | ## Test Neural Network 203 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,Knet_out,RunTime] = KNet_Pipeline.NNTest(sys_model_partial, test_input, test_target, path_results) 204 | 205 | ################################################################################### 206 | elif switch == 'estH': 207 | print("True Observation matrix H:", H_Rotate) 208 | ### Least-squares estimation of H: stack the data as Y = X H^T and solve each row of H via the normal equations, h_i = (X^T X)^{-1} X^T Y_i 209 | X = torch.squeeze(train_target[:,:,0]) 210 | Y = torch.squeeze(train_input[:,:,0]) 211 | for t in range(1,args.T): 212 | X_t = torch.squeeze(train_target[:,:,t]) 213 | Y_t = torch.squeeze(train_input[:,:,t]) 214 | X = torch.cat((X,X_t),0) 215 | Y = torch.cat((Y,Y_t),0) 216 | Y_1 = torch.unsqueeze(Y[:,0],1) 217 | Y_2 = torch.unsqueeze(Y[:,1],1) 218 | Y_3 = torch.unsqueeze(Y[:,2],1) 219 | H_row1 = torch.matmul(torch.matmul(torch.inverse(torch.matmul(X.T,X)),X.T),Y_1) 220 | H_row2 = torch.matmul(torch.matmul(torch.inverse(torch.matmul(X.T,X)),X.T),Y_2) 221 | H_row3 = torch.matmul(torch.matmul(torch.inverse(torch.matmul(X.T,X)),X.T),Y_3) 222 | H_hat = torch.cat((H_row1.T,H_row2.T,H_row3.T),0) 223 | print("Estimated Observation matrix H:", H_hat) 224 | 225 | def h_hat(x, jacobian=False):# observation model built from the LS estimate 226 | H = H_hat.reshape((1, n, m)).repeat(x.shape[0], 1, 1) # [batch_size, n, m] 227 | y = torch.bmm(H,x) 228 | if jacobian: 229 | return y, H 230 | else: 231 | return y 232 | 233 | # Estimated model 234 | sys_model_esth = SystemModel(f, Q, h_hat, R, args.T, args.T_test, m, n) 235 | sys_model_esth.InitSequence(m1x_0, m2x_0) 236 | 237 | ################ 238 | ## KNet estH ### 239 | ################ 240 | print("KNet with estimated H") 241 | KNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KNetEstH_"+ dataFileName[0]) 242 | KNet_Pipeline.setssModel(sys_model_esth) 243 | KNet_model = KalmanNetNN() 244 | KNet_model.NNBuild(sys_model_esth, args) 245 | KNet_Pipeline.setModel(KNet_model) 246 | KNet_Pipeline.setTrainingParams(args) 247 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KNet_Pipeline.NNTrain(sys_model_esth, cv_input, cv_target, train_input, train_target, path_results) 248 | ## Test Neural
Network 249 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,Knet_out,RunTime] = KNet_Pipeline.NNTest(sys_model_esth, test_input, test_target, path_results) 250 | 251 | ################################################################################### 252 | else: 253 | print("Error in switch! Please try 'full' or 'partial' or 'estH'.") 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | -------------------------------------------------------------------------------- /main_lor_DT_NLobs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | 4 | from Filters.EKF_test import EKFTest 5 | 6 | from Simulations.Extended_sysmdl import SystemModel 7 | from Simulations.utils import DataGen,Short_Traj_Split 8 | import Simulations.config as config 9 | from Simulations.Lorenz_Atractor.parameters import m1x_0, m2x_0, m, n,\ 10 | f, h, h_nonlinear, Q_structure, R_structure 11 | 12 | from Pipelines.Pipeline_EKF import Pipeline_EKF 13 | 14 | from KNet.KalmanNet_nn import KalmanNetNN 15 | 16 | print("Pipeline Start") 17 | ################ 18 | ### Get Time ### 19 | ################ 20 | today = datetime.today() 21 | now = datetime.now() 22 | strToday = today.strftime("%m.%d.%y") 23 | strNow = now.strftime("%H:%M:%S") 24 | strTime = strToday + "_" + strNow 25 | print("Current Time =", strTime) 26 | 27 | ################### 28 | ### Settings ### 29 | ################### 30 | args = config.general_settings() 31 | ### dataset parameters 32 | args.N_E = 1000 33 | args.N_CV = 100 34 | args.N_T = 200 35 | args.T = 20 36 | args.T_test = 20 37 | ### settings for KalmanNet 38 | args.in_mult_KNet = 40 39 | args.out_mult_KNet = 5 40 | 41 | ### training parameters 42 | args.use_cuda = True # use GPU or not 43 | args.n_steps = 2000 44 | args.n_batch = 100 45 | args.lr = 1e-4 46 | args.wd = 1e-4 47 | args.CompositionLoss = True 48 | args.alpha = 0.5 49 | 50 | if args.use_cuda: 51 | if torch.cuda.is_available(): 52 | device = torch.device('cuda') 53 | print("Using GPU") 54 | else: 55 | raise Exception("No GPU found, please set args.use_cuda = False") 56 | else: 57 | device = torch.device('cpu') 58 | print("Using CPU") 59 | 60 | offset = 0 61 | chop = False 62 | sequential_training = False 63 | path_results = 'KNet/' 64 | DatafolderName = 'Simulations/Lorenz_Atractor/data' + '/' 65 | r2 = torch.tensor([1e-3]) # [10, 1, 0.1, 0.01, 1e-3] 66 | vdB = 0 # ratio v=q2/r2 67 | v = 10**(vdB/10) 68 | q2 = torch.mul(v,r2) 69 | 70 | Q = q2[0] * Q_structure 71 | R = r2[0] * R_structure 72 | 73 | print("1/r2 [dB]: ", 10 * torch.log10(1/r2[0])) 74 | print("1/q2 [dB]: ", 10 * torch.log10(1/q2[0])) 75 | 76 | traj_resultName = ['traj_lorDT_NLobs_rq3030_T20.pt'] 77 | dataFileName = ['data_lor_v0_rq3030_T20.pt'] 78 | 79 | ######################################### 80 | ### Generate and load data DT case ### 81 | ######################################### 82 | 83 | sys_model = SystemModel(f, Q, h_nonlinear, R, args.T, args.T_test, m, n)# parameters for GT 84 | sys_model.InitSequence(m1x_0, m2x_0)# x0 and P0 85 | ## Model with H=I 86 | sys_model_H = SystemModel(f, Q, h, R, args.T,args.T_test, m, n) 87 | sys_model_H.InitSequence(m1x_0, m2x_0) 88 | 89 | print("Start Data Gen") 90 | DataGen(args, sys_model, DatafolderName + dataFileName[0]) 91 | print("Data Load") 92 | print(dataFileName[0]) 93 | [train_input_long,train_target_long, cv_input, cv_target, test_input, test_target,_,_,_] = torch.load(DatafolderName + dataFileName[0], map_location=device) 94 | if 
chop: 95 | print("chop training data") 96 | [train_target, train_input, train_init] = Short_Traj_Split(train_target_long, train_input_long, args.T) 97 | # [cv_target, cv_input] = Short_Traj_Split(cv_target, cv_input, T) 98 | else: 99 | print("no chopping") 100 | train_target = train_target_long[:,:,0:args.T] 101 | train_input = train_input_long[:,:,0:args.T] 102 | # cv_target = cv_target[:,:,0:T] 103 | # cv_input = cv_input[:,:,0:T] 104 | 105 | print("trainset size:",train_target.size()) 106 | print("cvset size:",cv_target.size()) 107 | print("testset size:",test_target.size()) 108 | 109 | ######################## 110 | ### Evaluate Filters ### 111 | ######################## 112 | # ### Evaluate EKF full 113 | print("Evaluate EKF full") 114 | [MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg, EKF_KG_array, EKF_out] = EKFTest(args, sys_model, test_input, test_target) 115 | 116 | # ### Save trajectories 117 | # trajfolderName = 'Filters' + '/' 118 | # DataResultName = traj_resultName[0] 119 | # EKF_sample = torch.reshape(EKF_out[0],[1,m,args.T_test]) 120 | # target_sample = torch.reshape(test_target[0,:,:],[1,m,args.T_test]) 121 | # input_sample = torch.reshape(test_input[0,:,:],[1,n,args.T_test]) 122 | # torch.save({ 123 | # 'EKF': EKF_sample, 124 | # 'ground_truth': target_sample, 125 | # 'observation': input_sample, 126 | # }, trajfolderName+DataResultName) 127 | 128 | 129 | ########################## 130 | ### Evaluate KalmanNet ### 131 | ########################## 132 | ## Build Neural Network 133 | print("KalmanNet start") 134 | KalmanNet_model = KalmanNetNN() 135 | KalmanNet_model.NNBuild(sys_model, args) 136 | ## Train Neural Network 137 | KalmanNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KalmanNet") 138 | KalmanNet_Pipeline.setssModel(sys_model) 139 | KalmanNet_Pipeline.setModel(KalmanNet_model) 140 | print("Number of trainable parameters for KNet:",sum(p.numel() for p in KalmanNet_model.parameters() if p.requires_grad)) 141 | KalmanNet_Pipeline.setTrainingParams(args) 142 | if(chop): 143 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results,randomInit=True,train_init=train_init) 144 | else: 145 | print("Composition Loss:",args.CompositionLoss) 146 | [MSE_cv_linear_epoch, MSE_cv_dB_epoch, MSE_train_linear_epoch, MSE_train_dB_epoch] = KalmanNet_Pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, path_results) 147 | ## Test Neural Network 148 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg,knet_out,RunTime] = KalmanNet_Pipeline.NNTest(sys_model, test_input, test_target, path_results) 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /main_lor_decimation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | 5 | import Filters.EKF_test as EKF_test 6 | 7 | from Simulations.Extended_sysmdl import SystemModel 8 | import Simulations.config as config 9 | from Simulations.utils import Decimate_and_perturbate_Data,Short_Traj_Split 10 | from Simulations.Lorenz_Atractor.parameters import m1x_0, m2x_0, m, n,delta_t_gen,delta_t,\ 11 | f, h, h_nobatch, fInacc, Q_structure, R_structure 12 | 13 | from Pipelines.Pipeline_EKF import Pipeline_EKF 14 | from KNet.KalmanNet_nn import KalmanNetNN 15 | 16 | from Plot import Plot_extended as Plot 17 | 18 | 
print("Pipeline Start") 19 | 20 | ################ 21 | ### Get Time ### 22 | ################ 23 | today = datetime.today() 24 | now = datetime.now() 25 | strToday = today.strftime("%m.%d.%y") 26 | strNow = now.strftime("%H:%M:%S") 27 | strTime = strToday + "_" + strNow 28 | print("Current Time =", strTime) 29 | 30 | ################### 31 | ### Settings ### 32 | ################### 33 | args = config.general_settings() 34 | ### dataset parameters 35 | args.N_E = 1000 36 | args.N_CV = 10 37 | args.N_T = 10 38 | args.T = 3000 39 | args.T_test = 3000 40 | ### training parameters 41 | args.use_cuda = True # use GPU or not 42 | args.n_steps = 2000 43 | args.n_batch = 8 44 | args.lr = 1e-4 45 | args.wd = 1e-4 46 | 47 | if args.use_cuda: 48 | if torch.cuda.is_available(): 49 | device = torch.device('cuda') 50 | print("Using GPU") 51 | else: 52 | raise Exception("No GPU found, please set args.use_cuda = False") 53 | else: 54 | device = torch.device('cpu') 55 | print("Using CPU") 56 | 57 | offset = 0 # offset for the data 58 | chop = False # whether to chop the dataset sequences into smaller ones 59 | path_results = 'KNet/' 60 | DatafolderName = 'Simulations/Lorenz_Atractor/data/' 61 | DatafileName = 'decimated_r0_Ttest3000_NT10.pt' 62 | data_gen = 'data_gen.pt' 63 | data_gen_file = torch.load(DatafolderName+data_gen) 64 | [true_sequence] = data_gen_file['All Data'] 65 | 66 | r = torch.tensor([1]) 67 | lambda_q = torch.tensor([0.3873]) 68 | 69 | print("1/r2 [dB]: ", 10 * torch.log10(1/r[0]**2)) 70 | print("Search 1/q2 [dB]: ", 10 * torch.log10(1/lambda_q[0]**2)) 71 | Q = (lambda_q[0]**2) * Q_structure 72 | R = (r[0]**2) * R_structure 73 | # True Model 74 | sys_model_true = SystemModel(f, Q, h, R, args.T, args.T_test,m,n) 75 | sys_model_true.InitSequence(m1x_0, m2x_0) 76 | 77 | # Model with partial Info 78 | sys_model = SystemModel(fInacc, Q, h, R, args.T, args.T_test,m,n) 79 | sys_model.InitSequence(m1x_0, m2x_0) 80 | 81 | ############################################## 82 | ### Generate and load data Decimation case ### 83 | ############################################## 84 | ######################## 85 | print("Data Gen") 86 | ######################## 87 | [test_target, test_input] = Decimate_and_perturbate_Data(true_sequence, delta_t_gen, delta_t, args.N_T, h_nobatch, r[0], offset) 88 | [train_target_long, train_input_long] = Decimate_and_perturbate_Data(true_sequence, delta_t_gen, delta_t, args.N_E, h_nobatch, r[0], offset) 89 | [cv_target_long, cv_input_long] = Decimate_and_perturbate_Data(true_sequence, delta_t_gen, delta_t, args.N_CV, h_nobatch, r[0], offset) 90 | if chop: 91 | print("chop training data") 92 | [train_target, train_input, train_init] = Short_Traj_Split(train_target_long, train_input_long, args.T) 93 | args.N_E = train_target.size()[0] 94 | else: 95 | print("no chopping") 96 | train_target = train_target_long 97 | train_input = train_input_long 98 | # Save dataset 99 | if(chop): 100 | torch.save([train_input, train_target, train_init, cv_input_long, cv_target_long, test_input, test_target], DatafolderName+DatafileName) 101 | else: 102 | torch.save([train_input, train_target, cv_input_long, cv_target_long, test_input, test_target], DatafolderName+DatafileName) 103 | 104 | ######################### 105 | print("Data Load") 106 | ######################### 107 | [train_input, train_target, cv_input_long, cv_target_long, test_input, test_target] = torch.load(DatafolderName+DatafileName, map_location=device) 108 | if(chop): 109 | print("chop training data") 110 | [train_target, 
train_input, train_init] = Short_Traj_Split(train_target, train_input, args.T) 111 | args.N_E = train_target.size()[0] 112 | print("load dataset to device:",train_input.device) 113 | print("testset size:",test_target.size()) 114 | print("trainset size:",train_target.size()) 115 | print("cvset size:",cv_target_long.size()) 116 | 117 | ######################################## 118 | ### Evaluate Observation Noise Floor ### 119 | ######################################## 120 | args.N_T = len(test_input) 121 | loss_obs = nn.MSELoss(reduction='mean') 122 | MSE_obs_linear_arr = torch.empty(args.N_T)# MSE [Linear] 123 | for j in range(0, args.N_T): 124 | MSE_obs_linear_arr[j] = loss_obs(test_input[j], test_target[j]).item() 125 | MSE_obs_linear_avg = torch.mean(MSE_obs_linear_arr) 126 | MSE_obs_dB_avg = 10 * torch.log10(MSE_obs_linear_avg) 127 | 128 | # Standard deviation 129 | MSE_obs_linear_std = torch.std(MSE_obs_linear_arr, unbiased=True) 130 | 131 | # Confidence interval 132 | obs_std_dB = 10 * torch.log10(MSE_obs_linear_std + MSE_obs_linear_avg) - MSE_obs_dB_avg 133 | 134 | print("Observation Noise Floor(test dataset) - MSE LOSS:", MSE_obs_dB_avg, "[dB]") 135 | print("Observation Noise Floor(test dataset) - STD:", obs_std_dB, "[dB]") 136 | ################################################### 137 | args.N_E = len(train_input) 138 | MSE_obs_linear_arr = torch.empty(args.N_E)# MSE [Linear] 139 | for j in range(0, args.N_E): 140 | MSE_obs_linear_arr[j] = loss_obs(train_input[j], train_target[j]).item() 141 | MSE_obs_linear_avg = torch.mean(MSE_obs_linear_arr) 142 | MSE_obs_dB_avg = 10 * torch.log10(MSE_obs_linear_avg) 143 | 144 | # Standard deviation 145 | MSE_obs_linear_std = torch.std(MSE_obs_linear_arr, unbiased=True) 146 | 147 | # Confidence interval 148 | obs_std_dB = 10 * torch.log10(MSE_obs_linear_std + MSE_obs_linear_avg) - MSE_obs_dB_avg 149 | 150 | print("Observation Noise Floor(train dataset) - MSE LOSS:", MSE_obs_dB_avg, "[dB]") 151 | print("Observation Noise Floor(train dataset) - STD:", obs_std_dB, "[dB]") 152 | 153 | ######################## 154 | ### Evaluate Filters ### 155 | ######################## 156 | ### EKF 157 | print("Start EKF test J=5") 158 | [MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg, EKF_KG_array, EKF_out] = EKF_test.EKFTest(args, sys_model_true, test_input, test_target) 159 | print("Start EKF test J=2") 160 | [MSE_EKF_linear_arr_partial, MSE_EKF_linear_avg_partial, MSE_EKF_dB_avg_partial, EKF_KG_array_partial, EKF_out_partial] = EKF_test.EKFTest(args, sys_model, test_input, test_target) 161 | 162 | ######################################## 163 | ### KalmanNet with model mismatch ###### 164 | ######################################## 165 | ## Build Neural Network 166 | KNet_model = KalmanNetNN() 167 | KNet_model.NNBuild(sys_model, args) 168 | KNet_Pipeline = Pipeline_EKF(strTime, "KNet", "KalmanNet") 169 | KNet_Pipeline.setModel(KNet_model) 170 | KNet_Pipeline.setssModel(sys_model) 171 | print("Number of trainable parameters for KNet:",sum(p.numel() for p in KNet_model.parameters() if p.requires_grad)) 172 | # Train Neural Network 173 | KNet_Pipeline.setTrainingParams(args) 174 | if(chop): 175 | KNet_Pipeline.NNTrain(sys_model,cv_input_long,cv_target_long,train_input,train_target,path_results,\ 176 | randomInit=True,train_init=train_init) 177 | else: 178 | KNet_Pipeline.NNTrain(sys_model,cv_input_long,cv_target_long,train_input,train_target,path_results) 179 | # Test Neural Network 180 | [MSE_test_linear_arr, MSE_test_linear_avg, MSE_test_dB_avg, 
knet_out,t] = KNet_Pipeline.NNTest(sys_model,test_input,test_target,path_results) 181 | 182 | 183 | 184 | # Save trajectories 185 | trajfolderName = 'Simulations/Lorenz_Atractor' + '/' 186 | DataResultName = 'traj_lor_dec.pt' 187 | target_sample = torch.reshape(test_target[0,:,:],[1,m,args.T_test]) 188 | input_sample = torch.reshape(test_input[0,:,:],[1,n,args.T_test]) 189 | torch.save({ 190 | 'True':target_sample, 191 | 'Observation':input_sample, 192 | 'EKF J=5':EKF_out, 193 | 'EKF J=2':EKF_out_partial, 194 | 'KNet': knet_out, 195 | }, trajfolderName+DataResultName) 196 | 197 | ############# 198 | ### Plot ### 199 | ############# 200 | titles = ["True Trajectory","Observation","EKF","KNet"] 201 | input = [target_sample,input_sample,EKF_out_partial, knet_out] 202 | Net_Plot = Plot(trajfolderName,DataResultName) 203 | Net_Plot.plotTrajectories(input,3, titles,trajfolderName+"lor_dec_trajs.png") 204 | 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | numpy==1.20.1 3 | pyParticleEst==1.1.4 4 | scipy==1.8.1 5 | seaborn==0.11.1 6 | torch==1.10.1 7 | --------------------------------------------------------------------------------
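All five main scripts above follow one skeleton; the following is a minimal sketch distilled from them (not a file in the repo; the unit noise scaling and the 'data.pt' path are placeholder choices):

import torch
import Simulations.config as config
import Simulations.utils as utils
from Simulations.Linear_canonical.parameters import F, H, Q_structure, R_structure, m, m1_0
from Simulations.Linear_sysmdl import SystemModel
from KNet.KalmanNet_nn import KalmanNetNN
from Pipelines.Pipeline_EKF import Pipeline_EKF

args = config.general_settings()                # defaults from Simulations/config.py
sys_model = SystemModel(F, Q_structure, H, R_structure, args.T, args.T_test)
sys_model.InitSequence(m1_0, 0 * torch.eye(m))  # deterministic x0 with P0 = 0
utils.DataGen(args, sys_model, 'data.pt')       # simulate and save train/cv/test sets
[train_input, train_target, cv_input, cv_target, test_input, test_target, _, _, _] = torch.load('data.pt')

knet = KalmanNetNN()
knet.NNBuild(sys_model, args)                   # size the network from (m, n) and the KNet multipliers
pipeline = Pipeline_EKF("time", "KNet", "KalmanNet")
pipeline.setssModel(sys_model)
pipeline.setModel(knet)
pipeline.setTrainingParams(args)                # n_steps, n_batch, lr, wd from args
pipeline.NNTrain(sys_model, cv_input, cv_target, train_input, train_target, 'KNet/')
pipeline.NNTest(sys_model, test_input, test_target, 'KNet/')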