├── Appendix_Invariance_Aware_Loss_Functions_results_for_5-element_MRA.pdf ├── DoA.py ├── ICASSP ├── eval │ ├── perf_n4.sh │ └── perf_n5.sh └── run │ ├── cuda0_n4.sh │ └── cuda0_n5.sh ├── README.md ├── SDP ├── ProxCov.m ├── SDP.py ├── SPA.m ├── SPA_noisevar.m ├── StructCovMLE.m ├── StructCovMLE_noisevar.m └── Wasserstein.m ├── batch_sampler.py ├── data.py ├── eval ├── crlb.py ├── perf_dnn_n4.sh ├── perf_dnn_n4_other_distances.sh ├── perf_dnn_n4_random_power.sh ├── perf_dnn_n4_w_and_wo_crs.sh ├── perf_dnn_n5.sh ├── perf_dnn_n6.sh ├── perf_opt_n4.sh ├── perf_opt_n4_random_power.sh ├── perf_opt_n5.sh └── perf_opt_n6.sh ├── loss.py ├── main.py ├── models.py ├── performance.py ├── predict.py ├── requirements.txt ├── run ├── cuda0_lr_search.sh ├── cuda0_n4.sh ├── cuda0_n4_other_distances.sh ├── cuda0_n4_random_power.sh ├── cuda0_n4_w_and_wo_crs.sh ├── cuda0_n5.sh └── cuda0_n6.sh ├── train.py └── utils.py /Appendix_Invariance_Aware_Loss_Functions_results_for_5-element_MRA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kjason/SubspaceRepresentationLearning/bc3e64f1361ea95f64576f2dc63fb8d0d39f7a03/Appendix_Invariance_Aware_Loss_Functions_results_for_5-element_MRA.pdf -------------------------------------------------------------------------------- /DoA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | """ 6 | import math 7 | import numpy as np 8 | from numpy.polynomial import Polynomial 9 | import torch 10 | from utils import MRA, data_in_preprocess 11 | 12 | def coarray_and_weight_function(cov,sensor_grid): 13 | """ 14 | Create the co-array from the sample covariance matrix and the sensor grid 15 | 16 | :param cov: the N-by-N complex covariance matrix 17 | :param sensor_grid: an N-element array of nonnegative integers representing the location of each sensor in the linear array 18 | :return: Co-array and the corresponding weight function 19 | """ 20 | N = len(sensor_grid) 21 | if N != cov.size(0): 22 | raise ValueError("cov does not match with the sensor_grid") 23 | N_a = sensor_grid[-1] 24 | N_coarray = 2*N_a+1 25 | coarray = torch.zeros(N_coarray,dtype=torch.cfloat) # from -N_a to N_a 26 | weight_fn = torch.zeros(N_coarray,dtype=torch.cfloat) # from -N_a to N_a 27 | for i in range(N): 28 | for j in range(N): 29 | diff = sensor_grid[i]-sensor_grid[j] 30 | coarray[diff+N_a] += cov[i,j] 31 | weight_fn[diff+N_a] += 1 32 | return coarray, weight_fn 33 | 34 | def direct_augmentation(cov: torch.Tensor): 35 | _, MRA_sensor_grid, N_a = MRA(cov.size(0),1) 36 | ULA_M_sensor = N_a+1 37 | ULA_sensor_grid = [(i-N_a/2) for i in range(N_a+1)] 38 | coarray, weight_fn = coarray_and_weight_function(cov,MRA_sensor_grid) 39 | r = coarray / weight_fn 40 | aug_cov = torch.zeros(ULA_M_sensor,ULA_M_sensor,dtype=torch.cfloat) 41 | for i in range(ULA_M_sensor): 42 | aug_cov[:,-i-1] = r[i:i+ULA_M_sensor] 43 | return aug_cov, ULA_sensor_grid 44 | 45 | def spatial_smoothing(cov: torch.Tensor): 46 | aug_cov, ULA_sensor_grid = direct_augmentation(cov) 47 | return (1/len(ULA_sensor_grid))*torch.matmul(aug_cov,aug_cov.conj().transpose(0,1)), ULA_sensor_grid 48 | 49 | def SRP_LA(Y,cov,lam,sensor_locations,N_gridpoints): 50 | grid = [i/(N_gridpoints-1) for i in range(N_gridpoints)] 51 | p = [] 52 | for t in grid: 53 | imag = torch.tensor(sensor_locations).unsqueeze(1)*2*torch.pi*(1/lam)*torch.cos(torch.tensor(t)*torch.pi) 54 | v = 
(1/math.sqrt(len(sensor_locations)))*torch.exp(torch.complex(torch.zeros_like(imag),imag)) 55 | if cov is None: 56 | a = torch.abs(torch.matmul(v.conj().transpose(0,1),Y))**2 57 | else: 58 | a = torch.abs(torch.matmul(torch.matmul(v.conj().transpose(0,1),cov),v)) 59 | p.append(torch.sum(a)) 60 | p = torch.tensor(p) 61 | return p/torch.max(p), grid 62 | 63 | def MUSIC_LA(Y,cov,lam,sensor_locations,N_gridpoints,num_sources): 64 | eps = 1e-8 65 | if cov is None: 66 | U,_,_ = torch.linalg.svd(Y) 67 | E_n = U[:,num_sources:] 68 | else: 69 | L, Q = torch.linalg.eigh(cov) 70 | E_n = Q[:,:-num_sources] 71 | grid = [i/(N_gridpoints-1) for i in range(N_gridpoints)] 72 | p = [] 73 | for t in grid: 74 | imag = torch.tensor(sensor_locations).unsqueeze(1)*2*torch.pi*(1/lam)*torch.cos(torch.tensor(t)*torch.pi) 75 | v = (1/math.sqrt(len(sensor_locations)))*torch.exp(torch.complex(torch.zeros_like(imag),imag)) 76 | p.append(1/(torch.sum(torch.abs(torch.matmul(v.conj().transpose(0,1),E_n))**2)+eps)) 77 | p = torch.tensor(p) 78 | return p/torch.max(p), grid 79 | 80 | def RootMUSIC_ULA(Y,cov,num_sources,EnEnH): 81 | if cov is None: 82 | U,_,_ = torch.linalg.svd(Y) 83 | E_n = U[:,num_sources:] 84 | elif EnEnH is False: 85 | _, Q = np.linalg.eigh(cov.numpy()) 86 | E_n = Q[:,:-num_sources] 87 | else: 88 | _, Q = np.linalg.eigh(cov.numpy()) 89 | E_n = Q[:,num_sources:] 90 | N = E_n.shape[0] 91 | M = E_n.shape[1] 92 | tmp = torch.zeros(2*N-1,M,dtype=torch.cfloat).numpy() 93 | for i in range(M): 94 | tmp[:,i] = np.convolve(E_n[:,i],np.flip(E_n[:,i].conj())) 95 | coeff = np.sum(tmp,axis=1) 96 | r = Polynomial(coeff[::-1]).roots() 97 | rmin = r[np.abs(r)<=1] 98 | order = np.argsort(-np.abs(rmin)) 99 | signalroot = rmin[order[:num_sources]] 100 | DoAs = np.sort(np.arccos(np.angle(signalroot)/np.pi)) 101 | remaining_num_src = num_sources - DoAs.shape[0] 102 | success = not remaining_num_src > 0 103 | if not success: 104 | #print(f"Number of DoAs found is not equal to num_sources, we will guess the remaining sources are located at pi/2 rad or 90 deg (remaining_num_src={remaining_num_src})") 105 | DoAs = np.sort(np.concatenate((DoAs,np.array([np.pi/2]*remaining_num_src)))) 106 | return DoAs, success 107 | 108 | def RootMUSIC_ULA_2(Y,cov,num_sources,EnEnH): 109 | if cov is None: 110 | U,_,_ = torch.linalg.svd(Y) 111 | E_n = U[:,num_sources:].numpy() 112 | elif EnEnH is False: 113 | _, Q = np.linalg.eigh(cov.numpy()) 114 | E_n = Q[:,:-num_sources] 115 | else: 116 | _, Q = np.linalg.eigh(cov.numpy()) 117 | E_n = Q[:,num_sources:] 118 | # 1 119 | N = E_n.shape[0] 120 | M = E_n.shape[1] 121 | tmp = torch.zeros(2*N-1,M,dtype=torch.cfloat).numpy() 122 | for i in range(M): 123 | tmp[:,i] = np.convolve(E_n[:,i],np.flip(E_n[:,i].conj())) 124 | coeff = np.sum(tmp,axis=1) 125 | # 2 126 | #m = E_n.shape[0] 127 | #C = E_n @ E_n.T.conj() 128 | #coeff = np.zeros((m - 1,), dtype=np.complex_) 129 | #for i in range(1, m): 130 | # coeff[i - 1] = np.sum(np.diag(C, i)) 131 | #coeff = np.hstack((coeff[::-1], np.sum(np.diag(C)), coeff.conj())) 132 | 133 | z = Polynomial(coeff[::-1]).roots() 134 | # the root finding procedure below is borrowed from https://github.com/morriswmz/doatools.py/blob/master/doatools/estimation/music.py 135 | nz = len(z) 136 | mask = np.ones((nz,), dtype=np.bool_) 137 | for i in range(nz): 138 | absz = abs(z[i]) 139 | if absz > 1.0: 140 | mask[i] = False 141 | elif absz == 1.0: 142 | idx = -1 143 | dist = np.inf 144 | for j in range(nz): 145 | if j != i: 146 | cur_dist = abs(z[i] - z[j]) 147 | if cur_dist < dist: 148 | 
dist = cur_dist 149 | idx = j 150 | if idx < 0: 151 | raise RuntimeError('Unpaired point found on the unit circle, which is impossible.') 152 | if mask[idx] and mask[i]: # numpy bool_ values are never identical to Python's True, so `is True` checks would always fail here 153 | mask[idx] = False 154 | z = z[mask] 155 | sorted_indices = np.argsort(-np.abs(z)) 156 | z = z[sorted_indices[:num_sources]] 157 | DoAs = np.sort(np.arccos(np.angle(z)/np.pi)) 158 | remaining_num_src = num_sources - DoAs.shape[0] 159 | success = not remaining_num_src > 0 160 | if not success: 161 | #print(f"Number of DoAs found is not equal to num_sources, we will guess the remaining sources are located at pi/2 rad or 90 deg (remaining_num_src={remaining_num_src})") 162 | DoAs = np.sort(np.concatenate((DoAs,np.array([np.pi/2]*remaining_num_src)))) 163 | return DoAs, success 164 | 165 | class BasePredictor: 166 | EnEnH = False 167 | need_snapshot = False 168 | use_noise_var = False 169 | def _get_one_ULA_cov(self, cov: torch.Tensor): 170 | return NotImplemented 171 | 172 | def get_ULA_cov(self, data_in: torch.Tensor, is_snapshot: bool, noise_var: torch.Tensor = None): 173 | data_in, batch_size = data_in_preprocess(data_in) 174 | if is_snapshot is False and self.need_snapshot is True: 175 | raise ValueError("given covariance matrices but the predictor actually needs snapshots") 176 | if is_snapshot is True and self.need_snapshot is False: 177 | T_snapshots = data_in.shape[-1] 178 | data = (1/T_snapshots)*torch.matmul(data_in,data_in.conj().transpose(-2,-1)) 179 | else: 180 | data = data_in 181 | output_cov = [] 182 | for b in range(batch_size): 183 | if self.use_noise_var is True: 184 | if noise_var is None: 185 | raise ValueError("Please provide the noise_var because self.use_noise_var is True") 186 | out = self._get_one_ULA_cov(data[b,:,:],noise_var[b]) 187 | else: 188 | out = self._get_one_ULA_cov(data[b,:,:]) 189 | if isinstance(out,np.ndarray): 190 | out = torch.from_numpy(out) 191 | out, _ = data_in_preprocess(out) 192 | output_cov.append(out) 193 | output_cov = torch.cat(output_cov,0) 194 | return output_cov 195 | 196 | def get_DoA_by_rootMUSIC(self, data_in: torch.Tensor, num_sources: int, is_snapshot: bool, noise_var: torch.Tensor = None): 197 | if self.use_noise_var is True: 198 | cov = self.get_ULA_cov(data_in,is_snapshot,noise_var) 199 | else: 200 | cov = self.get_ULA_cov(data_in,is_snapshot) 201 | batch_size = cov.size(0) 202 | DoA_list = [] 203 | success_list = [] 204 | for b in range(batch_size): 205 | DoAs, success = RootMUSIC_ULA_2(None,cov[b,:,:],num_sources,self.EnEnH) 206 | DoA_list.append(torch.from_numpy(DoAs).unsqueeze(0)) 207 | success_list.append(success) 208 | DoA = torch.cat(DoA_list,0) 209 | return DoA, success_list 210 | 211 | def get_DoA(self, data_in: torch.Tensor, num_sources: int, is_snapshot: bool, noise_var: torch.Tensor = None): 212 | return self.get_DoA_by_rootMUSIC(data_in, num_sources, is_snapshot, noise_var) 213 | 214 | class CovMRA2ULA_DA(BasePredictor): 215 | def _get_one_ULA_cov(self,cov: torch.Tensor): 216 | return direct_augmentation(cov)[0] 217 | 218 | class CovMRA2ULA_SS(BasePredictor): 219 | def _get_one_ULA_cov(self,cov: torch.Tensor): 220 | return spatial_smoothing(cov)[0] -------------------------------------------------------------------------------- /ICASSP/eval/perf_n4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 |
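# Note: the checkpoint directory paths below are generated by main.py and encode the full training configuration (model, loss, learning rate mu, batch size, epochs, source counts, SNR range, etc.); adjust these paths if your training runs used different settings.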
DCRGF=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGSISDR=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SISDRFrobeniusNorm_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGSISDR_Signal=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSISDRFrobeniusNorm_mu=02_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | 7 | # SNR vs. MSE 8 | python3 performance.py --results_folder dnn_results --cov_models $DCRGF $DCRGSISDR $DCRGSISDR_Signal --DA 1 --N_sensors 4 --num_sources_list 1 2 3 4 5 6 --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 9 | 10 | # number of snapshots vs. MSE 11 | python3 performance.py --results_folder dnn_results --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --cov_models $DCRGF $DCRGSISDR $DCRGSISDR_Signal --DA 1 --N_sensors 4 --num_sources_list 1 4 6 --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /ICASSP/eval/perf_n5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRGF=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGSISDR=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=SISDRFrobeniusNorm_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGSISDR_Signal=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=SignalSISDRFrobeniusNorm_mu=02_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | 7 | # SNR vs. MSE 8 | python3 performance.py --results_folder dnn_results --cov_models $DCRGF $DCRGSISDR $DCRGSISDR_Signal --DA 1 --N_sensors 5 --num_sources_list 1 2 3 4 5 6 7 8 9 --min_sep 4 4 4 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 9 | 10 | # number of snapshots vs. 
MSE 11 | python3 performance.py --results_folder dnn_results --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --cov_models $DCRGF $DCRGSISDR $DCRGSISDR_Signal --DA 1 --N_sensors 5 --num_sources_list 1 2 3 4 5 6 7 8 9 --min_sep 4 4 4 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /ICASSP/run/cuda0_n4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 4-element MRA 4 | 5 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --save_dataset 1 6 | 7 | python3 main.py --train_L 200 --loss SISDRFrobeniusNorm --mu 0.05 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 8 | 9 | python3 main.py --train_L 200 --loss SignalSISDRFrobeniusNorm --mu 0.2 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /ICASSP/run/cuda0_n5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 5-element MRA 4 | 5 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 --model N5_M10_WRN_16_8 --N_sensors 5 --min_sep 3 3 3 3 3 3 3 3 3 --n_sources_train 1 2 3 4 5 6 7 8 9 --n_sources_val 1 2 3 4 5 6 7 8 9 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 --save_dataset 1 6 | 7 | python3 main.py --train_L 200 --loss SISDRFrobeniusNorm --mu 0.05 --model N5_M10_WRN_16_8 --N_sensors 5 --min_sep 3 3 3 3 3 3 3 3 3 --n_sources_train 1 2 3 4 5 6 7 8 9 --n_sources_val 1 2 3 4 5 6 7 8 9 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 8 | 9 | python3 main.py --train_L 200 --loss SignalSISDRFrobeniusNorm --mu 0.2 --model N5_M10_WRN_16_8 --N_sensors 5 --min_sep 3 3 3 3 3 3 3 3 3 --n_sources_train 1 2 3 4 5 6 7 8 9 --n_sources_val 1 2 3 4 5 6 7 8 9 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Subspace Representation Learning for Sparse Linear Arrays to Localize More Sources than Sensors: A Deep Learning Methodology 2 | 3 | This repository is the official implementation of the paper accepted by IEEE Transactions on Signal Processing, [Subspace Representation Learning for Sparse Linear Arrays to Localize More 
Sources than Sensors: A Deep Learning Methodology](https://ieeexplore.ieee.org/document/10899405). This repository also hosts the official implementation of the paper accepted at ICASSP 2025, [A Comparative Study of Invariance-Aware Loss Functions for Deep Learning-based Gridless Direction-of-Arrival Estimation](https://ieeexplore.ieee.org/document/10889620). 4 | 5 | - Download the paper of Subspace Representation Learning from [IEEE Xplore](https://ieeexplore.ieee.org/document/10899405) or [arXiv](https://arxiv.org/abs/2408.16605). 6 | - Download the paper of Invariance-Aware Loss Functions from [IEEE Xplore](https://ieeexplore.ieee.org/document/10889620) or [arXiv](https://www.arxiv.org/abs/2503.12386). 7 | 8 | ## Requirements 9 | 10 | To install requirements: 11 | 12 | ```setup 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | > Please ensure Python is installed before running the above setup command. The code was tested on Python 3.9.13 and 3.10.12. 17 | 18 | If you are interested in running SDP-based baselines such as SPA and StructCovMLE, you will need to install MATLAB because all SDP problems are solved by the SDPT3 solver in CVX. For the implementation of all SDP-based baselines, please see the "SDP" folder. 19 | 20 | ## Training DNN models for DoA estimation 21 | 22 | To reproduce the numerical results in the paper, you first need to train the DNN models. To train all of the models in the 4-element MRA experiment, run: 23 | 24 | ```train 25 | bash run/cuda0_n4.sh 26 | ``` 27 | 28 | To replicate the results for the 5-element and 6-element MRAs, simply run `bash run/cuda0_n5.sh` and `bash run/cuda0_n6.sh`. 29 | 30 | > The best learning rate can be found by a simple grid search using `cuda0_lr_search.sh`. See Appendix C in the paper for more details about learning rates and the empirical risk on the validation set. 31 | 32 | If you are only interested in subspace representation learning for the 5-element MRA, you can simply run: 33 | 34 | ```train_subspace 35 | python main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 36 | ``` 37 | 38 | > This will train a model with 200x10,000=2,000,000 data points per source number, leading to a dataset of 18,000,000 data points in total since the 5-element MRA can resolve up to 9 sources. The base of 10,000 can be configured via the option `base_L` of `main.py`, and the size of the validation set via the option `val_L`. Consistent rank sampling is enabled, the learning rate is 0.1, and the loss function is `SignalSubspaceDist`, which applies subspace representation learning. 39 | 40 | The above command will train a model for a perfect 5-element MRA. To train a model for imperfect arrays, run: 41 | 42 | ```train_subspace_for_imperfect_arrays 43 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --rho 1.0 --mix_rho 1 44 | ``` 45 | 46 | > Note that here we enable the option `mix_rho` to randomly select the degree of imperfections, rho, in the interval [0, 1.0], for each data point. If the option `rho` is set to 0.5, then the interval becomes [0, 0.5]. Enabling `mix_rho` thus creates a dataset that mixes arrays with different degrees of imperfection; a short sketch of this sampling rule follows. 
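The snippet below is a minimal, self-contained sketch (not part of the training pipeline) of the sampling rule that `mix_rho` enables inside `ArrayManifold.get_V` in `data.py`; the names `batch` and `rho_per_datapoint` are illustrative only.

```python
import torch

rho = 1.0        # maximum degree of array imperfections (the --rho option)
batch = 4096     # number of data points drawn at once
mix_rho = True   # the --mix_rho option

if mix_rho:
    # each data point draws its own imperfection level uniformly from [0, rho]
    rho_per_datapoint = rho * torch.rand(batch, 1)
else:
    # every data point shares the same fixed imperfection level rho
    rho_per_datapoint = rho * torch.ones(batch, 1)
```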
47 | 48 | To train a model for the gridless end-to-end approach, run: 49 | 50 | ```train_gridless_end2end_for_imperfect_arrays 51 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.2 --consistent_rank_sampling 1 52 | ``` 53 | 54 | ## Evaluation 55 | 56 | To evaluate the performance of all of the SDP-based methods on the 5-element MRA, run: 57 | 58 | ```eval_SDP 59 | bash eval/perf_opt_n5.sh 60 | ``` 61 | 62 | Performance evaluation can be customized via the options of `performance.py`. For example, you can change the total number of random trials by specifying the number of random angles and the number of random trials per angle. You can also change the minimum separation constraint, the degree of array imperfections, the SNRs, the numbers of sources, the methods you want to evaluate, etc. For instance, the following command evaluates the direct augmentation approach, spatial smoothing, Wasserstein distance minimization, and SPA, using a total of 10,000 random trials and a minimum separation of 4 degrees on the perfect 5-element MRA. You can even specify different minimum separations for different source numbers. See the options of `performance.py` for more details. 63 | 64 | ```eval_SDP_MRA5 65 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 2 3 4 5 6 7 8 9 --rho 0 --results_folder results --min_sep 4 4 4 4 4 4 4 4 4 66 | ``` 67 | 68 | To evaluate the performance of all of the DNN-based methods on the 5-element MRA, run: 69 | 70 | ```eval_DNN 71 | bash eval/perf_dnn_n5.sh 72 | ``` 73 | 74 | You can also evaluate any number of DNN-based approaches by specifying a list of paths to the models after the option `cov_models`. For example, the following command evaluates 4 different models. 75 | 76 | ```eval_DNN_single 77 | python3 performance.py --results_folder dnn_results --cov_models $DCRT $DCRGF $DCRGA $OURS 78 | ``` 79 | 80 | To evaluate methods on other MRAs, simply switch to the corresponding scripts in the `eval` folder. 81 | 82 | ## Additional numerical results 83 | 84 | Additional numerical results for the 5-element MRA (see Section V(g) of the ICASSP paper) are available in the PDF file [`Appendix_Invariance_Aware_Loss_Functions_results_for_5-element_MRA.pdf`](Appendix_Invariance_Aware_Loss_Functions_results_for_5-element_MRA.pdf). 85 | 86 | ## BibTeX 87 | 88 | If this repository contributes to your research, please cite our papers. 
89 | 90 | ``` 91 | @article{chen2025subspace, 92 | title={Subspace Representation Learning for Sparse Linear Arrays to Localize More Sources than Sensors: A Deep Learning Methodology}, 93 | author={Chen, Kuan-Lin and Rao, Bhaskar D.}, 94 | journal={IEEE Transactions on Signal Processing}, 95 | volume={73}, 96 | pages={1293-1308}, 97 | year={2025}, 98 | doi={10.1109/TSP.2025.3544170} 99 | } 100 | ``` 101 | 102 | ``` 103 | @inproceedings{chen2025comparative, 104 | title={A Comparative Study of Invariance-Aware Loss Functions for Deep Learning-based Gridless Direction-of-Arrival Estimation}, 105 | author={Chen, Kuan-Lin and Rao, Bhaskar D.}, 106 | booktitle={International Conference on Acoustics, Speech and Signal Processing}, 107 | year={2025}, 108 | organization={IEEE} 109 | } 110 | ``` -------------------------------------------------------------------------------- /SDP/ProxCov.m: -------------------------------------------------------------------------------- 1 | function T = ProxCov(Y,S,epsilon) 2 | n = size(S,1); 3 | m = size(S,2); 4 | l = size(Y,2); 5 | eI = epsilon*eye(l); 6 | cvx_begin sdp quiet 7 | variable T(m,m) hermitian toeplitz 8 | variable W(l,l) hermitian complex 9 | minimize( norm(Y*W*(Y')-S*T*(S'),'fro') ) 10 | T >= 0; 11 | W >= eI; 12 | cvx_end 13 | end -------------------------------------------------------------------------------- /SDP/SDP.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | SDP-based baselines used in the paper 7 | 8 | https://arxiv.org/abs/2408.16605 9 | """ 10 | import numpy as np 11 | import matlab.engine 12 | import torch 13 | from utils import MRA 14 | from DoA import BasePredictor 15 | import os 16 | 17 | class SDPMRA2ULA(BasePredictor): 18 | def __init__(self, N_sensors: int): 19 | self.N_sensors = N_sensors 20 | _, sensor_grid, N_a = MRA(N_sensors,1) 21 | self.M_sensors = N_a + 1 22 | self.S = np.zeros((self.N_sensors,self.M_sensors),dtype=np.complex128) # np.cfloat was removed in NumPy 2.0; np.complex128 is the same type 23 | for i in range(self.N_sensors): 24 | self.S[i,sensor_grid[i]] = 1 25 | 26 | class SDPCovMRA2ULA_Wasserstein_SDPT3(SDPMRA2ULA): 27 | def __init__(self, N_sensors: int): 28 | super().__init__(N_sensors) 29 | self.eng = matlab.engine.start_matlab() 30 | self.eng.cd(os.path.dirname(os.path.realpath(__file__)), nargout=0) 31 | 32 | def _get_one_ULA_cov(self,cov: torch.Tensor): 33 | return np.array(self.eng.Wasserstein(matlab.double(cov.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True))) 34 | 35 | class SDPCovMRA2ULA_SPA_SDPT3(SDPMRA2ULA): 36 | def __init__(self, N_sensors: int, use_noise_var: bool, remove_noise: bool = True): 37 | super().__init__(N_sensors) 38 | self.use_noise_var = use_noise_var 39 | self.remove_noise = remove_noise 40 | self.eng = matlab.engine.start_matlab() 41 | self.eng.cd(os.path.dirname(os.path.realpath(__file__)), nargout=0) 42 | 43 | def _get_one_ULA_cov(self,cov: torch.Tensor, noise_var: torch.Tensor = None): 44 | if noise_var is None and self.use_noise_var is False: 45 | R = np.array(self.eng.SPA(matlab.double(cov.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True))) 46 | elif noise_var is not None and self.use_noise_var is True: 47 | R = np.array(self.eng.SPA_noisevar(matlab.double(cov.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True),matlab.double(noise_var.tolist()))) 48 | else: 49 | raise ValueError(f"Incorrect mode: self.use_noise_var={self.use_noise_var} and noise_var={noise_var}") 50 | if 
self.remove_noise is True: 51 | L, _ = np.linalg.eigh(R) 52 | R = R - min(L) * np.eye(L.shape[0]) 53 | return R 54 | 55 | class SDPSnapshotMRA2ULA_ProxCov_SDPT3(SDPMRA2ULA): 56 | def __init__(self, N_sensors: int, epsilon: float): 57 | super().__init__(N_sensors) 58 | self.need_snapshot = True 59 | self.epsilon = epsilon 60 | self.eng = matlab.engine.start_matlab() 61 | self.eng.cd(os.path.dirname(os.path.realpath(__file__)), nargout=0) 62 | 63 | def _get_one_ULA_cov(self,Y: torch.Tensor): 64 | return np.array(self.eng.ProxCov(matlab.double(Y.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True),matlab.double(self.epsilon))) 65 | 66 | class SDPCovMRA2ULA_StructCovMLE_SDPT3(SDPMRA2ULA): 67 | def __init__(self, N_sensors: int, epsilon: float, max_iter: int, use_noise_var: bool): 68 | super().__init__(N_sensors) 69 | self.epsilon = epsilon 70 | self.max_iter = max_iter 71 | self.use_noise_var = use_noise_var 72 | self.eng = matlab.engine.start_matlab() 73 | self.eng.cd(os.path.dirname(os.path.realpath(__file__)), nargout=0) 74 | 75 | def _get_one_ULA_cov(self, cov: torch.Tensor, noise_var: torch.Tensor = None): 76 | if noise_var is None and self.use_noise_var is False: 77 | return np.array(self.eng.StructCovMLE(matlab.double(cov.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True),matlab.double(self.epsilon),matlab.int16(self.max_iter))) 78 | elif noise_var is not None and self.use_noise_var is True: 79 | return np.array(self.eng.StructCovMLE_noisevar(matlab.double(cov.tolist(),is_complex=True),matlab.double(self.S.tolist(),is_complex=True),matlab.double(self.epsilon),matlab.int16(self.max_iter),matlab.double(noise_var.tolist()))) 80 | else: 81 | raise ValueError(f"Incorrect mode: self.use_noise_var={self.use_noise_var} and noise_var={noise_var}") -------------------------------------------------------------------------------- /SDP/SPA.m: -------------------------------------------------------------------------------- 1 | function T = SPA(R_hat,S) 2 | n = size(S,1); 3 | m = size(S,2); 4 | R_hat_inv = inv(R_hat); 5 | R_hat_inv = 0.5*(R_hat_inv+R_hat_inv'); 6 | R_sqrt = sqrtm(R_hat); 7 | Z = zeros(n,m); 8 | cvx_begin sdp quiet 9 | variable T(m,m) hermitian toeplitz 10 | variable X(n,n) hermitian complex 11 | minimize( real(trace(X) + trace(R_hat_inv*S*T*(S'))) ) 12 | [X, R_sqrt, Z; R_sqrt', S*T*(S'), Z; Z', Z', T] >= 0; 13 | cvx_end 14 | end -------------------------------------------------------------------------------- /SDP/SPA_noisevar.m: -------------------------------------------------------------------------------- 1 | function T = SPA_noisevar(R_hat,S,lambda) 2 | n = size(S,1); 3 | m = size(S,2); 4 | R_hat_inv = inv(R_hat); 5 | R_hat_inv = 0.5*(R_hat_inv+R_hat_inv'); 6 | R_sqrt = sqrtm(R_hat); 7 | Z = zeros(n,m); 8 | cvx_begin sdp quiet 9 | variable T(m,m) hermitian toeplitz 10 | variable X(n,n) hermitian complex 11 | minimize( real(trace(X) + trace(R_hat_inv*S*T*(S'))) ) 12 | [X, R_sqrt, Z; R_sqrt', S*T*(S')+lambda*eye(n), Z; Z', Z', T] >= 0; 13 | cvx_end 14 | end -------------------------------------------------------------------------------- /SDP/StructCovMLE.m: -------------------------------------------------------------------------------- 1 | function V = StructCovMLE(R_hat,S,epsilon,max_iter) 2 | % """ 3 | % Created on Sun Jun 9 2024 4 | % 5 | % @author: Kuan-Lin Chen 6 | % 7 | % Implementation of the StructCovMLE approach in the following paper: 8 | % 9 | % Pote, Rohan R., and Bhaskar D. Rao. 
10 | % "Maximum likelihood-based gridless DoA estimation using structured 11 | % covariance matrix recovery and SBL with grid refinement." 12 | % IEEE Transactions on Signal Processing 71 (2023): 802-815. 13 | % 14 | % @param R_hat: the sample spatial covariance matrix received at the sparse linear array 15 | % @param S: the row-selection matrix containing only ones and zeros 16 | % @param epsilon: the threshold of relative change (one of the stopping criteria) 17 | % @param max_iter: the maximum number of iterations (the other stopping criterion) 18 | % @return V: the spatial covariance matrix estimate of the corresponding ULA 19 | % """ 20 | n = size(S,1); 21 | m = size(S,2); 22 | I = eye(n); 23 | V = eye(m); 24 | for i = 1:max_iter 25 | V_prev = V; 26 | Vs_inv = inv(S*V*(S')); 27 | Vs_inv = 0.5*(Vs_inv+Vs_inv'); 28 | Z = zeros(n,m); 29 | cvx_begin sdp quiet 30 | variable T(m,m) hermitian toeplitz 31 | variable X(n,n) hermitian complex 32 | minimize( real(trace(Vs_inv*S*T*(S')) + trace(X*R_hat)) ) 33 | [X, I, Z; I, S*T*(S'), Z; Z', Z', T] >= 0; 34 | cvx_end 35 | V = T; 36 | relative_change = norm(V-V_prev,'fro') / norm(V_prev,'fro'); 37 | if relative_change < epsilon 38 | break 39 | end 40 | end 41 | end -------------------------------------------------------------------------------- /SDP/StructCovMLE_noisevar.m: -------------------------------------------------------------------------------- 1 | function R=StructCovMLE_noisevar(S,So,epsilon,mmITER,lambda1) 2 | % Modified from https://github.com/rohanpote/GridlessDoA_StructCovMLE/blob/main/lib/StructCovMLE_MUSIC.m 3 | % Originally written by Rohan R. Pote, 2022 4 | % Pote, Rohan R., and Bhaskar D. Rao. 5 | % "Maximum likelihood-based gridless DoA estimation using structured 6 | % covariance matrix recovery and SBL with grid refinement." 7 | % IEEE Transactions on Signal Processing 71 (2023): 802-815. 
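% Argument note (inferred from how SDP.py invokes this function): S is the sample covariance matrix at the sparse array, So is the row-selection matrix, epsilon is the relative-change threshold, mmITER is the maximum number of iterations, and lambda1 is the known noise variance.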
8 | M=size(So,1); 9 | Mapt=size(So,2); 10 | x_old = [1 zeros(1,Mapt-1)]; 11 | for mmloop = 1:mmITER 12 | B0 = inv(So*toeplitz(x_old)*So'+lambda1*eye(M)); 13 | cvx_begin sdp quiet 14 | % cvx_solver sedumi 15 | variable x(1,Mapt) complex 16 | variable U(M,M) hermitian 17 | minimize(real(trace(B0*(So*toeplitz(x)*So')))+real(trace(U*S))) 18 | subject to 19 | (toeplitz(x)+toeplitz(x)')/2>=0 20 | [U eye(M); eye(M) So*((toeplitz(x)+toeplitz(x)')/2)*So'+lambda1*eye(M)]>=0 21 | cvx_end 22 | relative_change = norm(x-x_old,'fro') / norm(x_old,'fro'); 23 | if relative_change < epsilon 24 | break 25 | end 26 | x_old = x; 27 | end 28 | R = toeplitz(x); 29 | end -------------------------------------------------------------------------------- /SDP/Wasserstein.m: -------------------------------------------------------------------------------- 1 | function R0 = Wasserstein(R_hat,S) 2 | n = size(S,1); 3 | m = size(S,2); 4 | R_hat = 0.5*(R_hat+R_hat'); 5 | cvx_begin sdp quiet 6 | variable R0(m,m) hermitian toeplitz 7 | variable V(n,n) complex 8 | minimize( trace( R_hat + S*R0*(S') - V - V') ) 9 | [S*R0*(S'), V;V', R_hat] >= 0; 10 | R0 >= 0; 11 | cvx_end 12 | end -------------------------------------------------------------------------------- /batch_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sat May 11 2024 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | https://arxiv.org/abs/2408.16605 7 | """ 8 | from typing import Iterator, List 9 | import torch 10 | 11 | # consistent rank sampling, see Section IV-E in the paper 12 | class ConsistentRankBatchSampler(torch.utils.data.Sampler[List[int]]): 13 | def __init__(self,N: int, K: int, batch_size: int, drop_last: bool=False) -> None: 14 | if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0: 15 | raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}") 16 | if not isinstance(drop_last, bool): 17 | raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}") 18 | self.N = N 19 | self.K = K 20 | self.total_size = self.N * self.K 21 | self.batch_size = batch_size 22 | self.drop_last = drop_last 23 | self.samples = [(torch.randperm(self.N) + self.N * i).tolist() for i in range(self.K)] 24 | 25 | def __iter__(self) -> Iterator[List[int]]: 26 | while len(self.samples) != 0: 27 | k = torch.randint(len(self.samples),(1,)).item() 28 | if self.drop_last: 29 | yield self.samples[k][:self.batch_size] 30 | del self.samples[k][:self.batch_size] 31 | if len(self.samples[k]) < self.batch_size: 32 | del self.samples[k] 33 | else: 34 | if len(self.samples[k]) <= self.batch_size: 35 | yield self.samples[k] 36 | del self.samples[k] 37 | else: 38 | yield self.samples[k][:self.batch_size] 39 | del self.samples[k][:self.batch_size] 40 | 41 | def __len__(self) -> int: 42 | if self.drop_last: 43 | return (self.N // self.batch_size) * self.K 44 | else: 45 | return ((self.N + self.batch_size - 1) // self.batch_size) * self.K -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | """ 6 | from datetime import datetime 7 | from typing import List 8 | import numpy as np 9 | import scipy.linalg as la 10 | import h5py 11 | import os 12 | import torch 13 | from torch.utils.data.dataset import Dataset 14 | from tqdm import tqdm 15 | from utils import 
MRA, cov_normalize, dir_path 16 | 17 | def sample(a: float, b: float, min_sep: float): 18 | s = np.random.uniform(a,b) 19 | (l_a, l_b) = (a, s - min_sep) if s - min_sep > a else (None,None) 20 | (r_a, r_b) = (s + min_sep, b) if s + min_sep < b else (None,None) 21 | if l_a is None and r_a is None: 22 | return [s] 23 | elif l_a is None: 24 | return [s] + sample(r_a, r_b, min_sep) 25 | elif r_a is None: 26 | return [s] + sample(l_a, l_b, min_sep) 27 | else: 28 | return [s] + sample(l_a, l_b, min_sep) + sample(r_a, r_b, min_sep) 29 | 30 | def random_source_angles(deg_range: List[float], min_sep: float, num_sources: int): 31 | candidates = sample(deg_range[0],deg_range[1],min_sep) 32 | while len(candidates) < num_sources: 33 | #print(f"candidates ({len(candidates)}) < num_sources ({num_sources}), resample") 34 | candidates = sample(deg_range[0],deg_range[1],min_sep) 35 | return np.random.permutation(np.random.choice(a=candidates, size=num_sources, replace=False).astype(np.float32)) 36 | 37 | class ArrayManifold: 38 | @torch.no_grad() 39 | def __init__(self, d: float, lam: float, N_sensors: int,gain_bias: List[float], phase_bias_deg: List[float], 40 | position_bias: List[float],mc_mag_angle: List[float], device: str): 41 | self.d = d 42 | self.lam = lam 43 | self.N_sensors = N_sensors 44 | self.device = device 45 | # MRA and ULA 46 | MRA_sensor_locations, sensor_grid, N_a = MRA(N_sensors,d) 47 | ULA_sensor_locations = [(i-N_a/2)*d for i in range(N_a+1)] 48 | self.sensor_grid = sensor_grid 49 | self.ULA_M_sensors = len(ULA_sensor_locations) 50 | self.MRA_sensor_locations = torch.tensor(MRA_sensor_locations,device=device) 51 | self.ULA_sensor_locations = torch.tensor(ULA_sensor_locations,device=device) 52 | # imperfections 53 | if len(mc_mag_angle) != 2: 54 | raise ValueError("invalid mc_mag_angle, mc_mag_angle[0] is the magnitude and mc_mag_angle[1] is phase in degrees") 55 | if len(gain_bias) != self.ULA_M_sensors or len(phase_bias_deg) != self.ULA_M_sensors or len(position_bias) != self.ULA_M_sensors: 56 | raise ValueError("invalid gain_bias, phase_bias_deg, or position_bias, their length must be equal to M") 57 | self.gain_bias = torch.tensor(gain_bias,device=device,dtype=torch.complex64) 58 | self.phase_bias = torch.tensor(phase_bias_deg,device=device,dtype=torch.float32) * np.pi/180 59 | self.position_bias = torch.tensor(position_bias,device=device,dtype=torch.float32) * d 60 | gamma = mc_mag_angle[0]*np.exp(1j*mc_mag_angle[1]*np.pi/180) 61 | ula_gamma_vec = gamma ** np.arange(self.ULA_M_sensors) 62 | ula_gamma_vec[0] = 0 63 | self.ula_mcm = torch.from_numpy(la.toeplitz(ula_gamma_vec)).type(torch.complex64).to(device) 64 | mra_gamma_vec = ula_gamma_vec[self.sensor_grid] 65 | self.mra_mcm = torch.from_numpy(la.toeplitz(mra_gamma_vec)).type(torch.complex64).to(device) 66 | 67 | @torch.no_grad() 68 | def get_V(self, rho: float, source_angles: torch.Tensor, mix: bool, mode: str): 69 | # MRA_sensor_locations is of size N 70 | # source_angles is of size L x 1 x # of sources 71 | # V is of size L x N x # of sources 72 | if mode == 'MRA': 73 | if rho == 0: 74 | imag = 2*torch.pi*(1/self.lam)*torch.matmul(self.MRA_sensor_locations.unsqueeze(1).unsqueeze(0),torch.cos(source_angles)) 75 | V = torch.exp(torch.complex(torch.zeros_like(imag),imag)) 76 | else: 77 | if mix is True: 78 | rho = rho * torch.rand(source_angles.shape[0],1,dtype=torch.float32) 79 | else: 80 | rho = rho * torch.ones(source_angles.shape[0],1,dtype=torch.float32) 81 | e_gain = 1.0 + rho.type(torch.complex64) @ 
self.gain_bias[self.sensor_grid].unsqueeze(0) 82 | e_phase = torch.exp(1j * (rho @ self.phase_bias[self.sensor_grid].unsqueeze(0))) 83 | e_pos = rho @ self.position_bias[self.sensor_grid].unsqueeze(0) 84 | E_mc = torch.eye(self.N_sensors,dtype=torch.complex64,device=self.device).unsqueeze(0) + rho.type(torch.complex64).unsqueeze(2) * self.mra_mcm.unsqueeze(0) 85 | MRA_sensor_locations_e = self.MRA_sensor_locations.unsqueeze(0) + e_pos 86 | imag = 2*torch.pi*(1/self.lam)*torch.matmul(MRA_sensor_locations_e.unsqueeze(2),torch.cos(source_angles)) 87 | temp = e_gain.unsqueeze(2) * e_phase.unsqueeze(2) * torch.exp(torch.complex(torch.zeros_like(imag),imag)) 88 | V = torch.matmul(E_mc,temp) 89 | elif mode == 'ULA': 90 | if rho == 0: 91 | imag = 2*torch.pi*(1/self.lam)*torch.matmul(self.ULA_sensor_locations.unsqueeze(1).unsqueeze(0),torch.cos(source_angles)) 92 | V = torch.exp(torch.complex(torch.zeros_like(imag),imag)) 93 | else: 94 | if mix is True: 95 | rho = rho * torch.rand(source_angles.shape[0],1,dtype=torch.float32) 96 | else: 97 | rho = rho * torch.ones(source_angles.shape[0],1,dtype=torch.float32) 98 | e_gain = 1.0 + rho.type(torch.complex64) @ self.gain_bias.unsqueeze(0) 99 | e_phase = torch.exp(1j * (rho @ self.phase_bias.unsqueeze(0))) 100 | e_pos = rho @ self.position_bias.unsqueeze(0) 101 | E_mc = torch.eye(self.ULA_M_sensors,dtype=torch.complex64,device=self.device).unsqueeze(0) + rho.type(torch.complex64).unsqueeze(2) * self.ula_mcm.unsqueeze(0) 102 | ULA_sensor_locations_e = self.ULA_sensor_locations.unsqueeze(0) + e_pos 103 | imag = 2*torch.pi*(1/self.lam)*torch.matmul(ULA_sensor_locations_e.unsqueeze(2),torch.cos(source_angles)) 104 | temp = e_gain.unsqueeze(2) * e_phase.unsqueeze(2) * torch.exp(torch.complex(torch.zeros_like(imag),imag)) 105 | V = torch.matmul(E_mc,temp) 106 | else: 107 | raise TypeError(f"invalid mode={mode}, must be MRA or ULA") 108 | return V 109 | 110 | @torch.no_grad() 111 | def get_random_source_angles(deg_range: List[float], min_sep: float, num_sources: int, num_datapoints: int, mode: str, seed: int): 112 | filepath = os.path.join('./source_angles/',f"mode={mode}_source_angles_rg={str(deg_range)}_sep={min_sep}_nsrc={num_sources}_ndatapoints={num_datapoints}_seed={seed}.hdf5".replace(' ','')) 113 | if os.path.isfile(filepath): 114 | #print((f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] A dataset of random source angles already exists at {filepath}" 115 | #" (remove the existing dataset if you want to create a new one). 
Start loading...")) 116 | with h5py.File(filepath,'r') as file: 117 | source_angles = file["source_angles"][:] 118 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Finished loading the dataset") 119 | else: 120 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] No dataset found at {filepath}, generate a new dataset of random source angles") 121 | source_angles = np.zeros((num_datapoints,num_sources),dtype=np.float32) 122 | for i in tqdm(range(num_datapoints),leave=True): 123 | source_angles[i,:] = random_source_angles(deg_range,min_sep,num_sources) * np.pi/180 124 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Saving the dataset to path {filepath}") 125 | dir_path('./source_angles/') 126 | with h5py.File(filepath,'w') as file: 127 | file.create_dataset(name="source_angles",data=source_angles,compression='gzip') 128 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Dataset saved at {filepath}") 129 | return source_angles 130 | 131 | @torch.no_grad() 132 | def get_source_and_noise_random_base(base_L: int, num_sources: int, T_snapshots: int, M: int, seed: int, mode: str): 133 | filepath = os.path.join('./source_noise_random_base/',f"mode={mode}_sn_random_baseL={base_L}_M={M}_nsrc={num_sources}_Tsnapshots={T_snapshots}_seed={seed}.hdf5".replace(' ','')) 134 | if os.path.isfile(filepath): 135 | #print((f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] A dataset of source/noise random base already exists at {filepath}" 136 | #" (remove the existing dataset if you want to create a new one). Start loading...")) 137 | with h5py.File(filepath,'r') as file: 138 | source_base = file["source_base"][:] 139 | noise_base = file["noise_base"][:] 140 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Finished loading the dataset") 141 | else: 142 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] No dataset found at {filepath}, generate a new dataset of random source angles") 143 | source_base = torch.randn(base_L,num_sources,T_snapshots,dtype=torch.cfloat) # random source base 144 | noise_base = torch.randn(base_L,M,T_snapshots,dtype=torch.cfloat) # noise base 145 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Saving the dataset to path {filepath}") 146 | dir_path('./source_noise_random_base/') 147 | with h5py.File(filepath,'w') as file: 148 | file.create_dataset(name="source_base",data=source_base,compression='gzip') 149 | file.create_dataset(name="noise_base",data=noise_base,compression='gzip') 150 | #print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Dataset saved at {filepath}") 151 | return source_base, noise_base 152 | 153 | @torch.no_grad() 154 | def generate_batch_cov_MRA_ULA(array_manifold: ArrayManifold, rho: float, mix_rho: bool, source_base: torch.Tensor, noise_base: torch.Tensor, T_snapshots: int, p: torch.Tensor, 155 | SNR: float, source_angles: List[float], use_variance: bool, noisy_ULA: bool = False, normalization: str = 'disabled',diag_src_cov=True,return_MRA_snapshots=False,total_power_one=False): 156 | # source 157 | if total_power_one is True: 158 | p = p / torch.sum(p,1).unsqueeze(1) # the given p is a 2d tensor (L x # of sources) 159 | # source_angles is a 2d tensor (L x # of sources) and sources is a 3d tensor (L x # of sources x T) 160 | sources = torch.sqrt(p).to(torch.cfloat).unsqueeze(2)*source_base # T_snapshots complex zero-mean circularly-symmetric gaussian random vectors 161 | # noise (L x M x T) 162 | noise = 
(1/(10**(SNR/20)))*noise_base # T_snapshots complex zero-mean circularly-symmetric gaussian random vectors 163 | # source_angles is converted to a 3d tensor (L x 1 x # of sources) 164 | source_angles = source_angles.unsqueeze(1) 165 | # ULA 166 | V_ULA = array_manifold.get_V(rho=0,source_angles=source_angles,mix=False,mode='ULA') 167 | # ULA sample covariance with or without noise 168 | Y_ULA = torch.matmul(V_ULA,sources) 169 | Y_ULA = Y_ULA + noise if noisy_ULA is True else Y_ULA 170 | if diag_src_cov is True: 171 | # ULA noise-free diagonal covariance matrix or diagonal sample covariance matrix 172 | if use_variance is True: 173 | cov_ULA = torch.matmul(torch.matmul(V_ULA,torch.vmap(torch.diag)(p).to(torch.cfloat)),V_ULA.conj().transpose(-2,-1)) 174 | else: 175 | source_sample_cov = (1/T_snapshots)*torch.matmul(sources,sources.conj().transpose(-2,-1)) 176 | source_sample_cov_diag = torch.vmap(torch.diag)(torch.diagonal(source_sample_cov,dim1=-2,dim2=-1)) 177 | cov_ULA = torch.matmul(torch.matmul(V_ULA,source_sample_cov_diag),V_ULA.conj().transpose(-2,-1)) 178 | else: 179 | cov_ULA = (1/T_snapshots)*torch.matmul(Y_ULA,Y_ULA.conj().transpose(-2,-1)) 180 | # imperfect or perfect (depending on rho) ULA with holes or MRA (no zero padding) 181 | V = array_manifold.get_V(rho=rho,source_angles=source_angles,mix=mix_rho,mode='MRA') 182 | noise = noise[:,array_manifold.sensor_grid,:] 183 | Y_nopad = torch.matmul(V,sources) + noise 184 | if return_MRA_snapshots is True: 185 | return_MRA = Y_nopad 186 | else: 187 | cov_MRA = (1/T_snapshots)*torch.matmul(Y_nopad,Y_nopad.conj().transpose(-2,-1)) 188 | # normalization 189 | return_MRA = cov_normalize(cov_MRA,normalization,array_manifold.N_sensors) 190 | # normalization 191 | cov_ULA = cov_normalize(cov_ULA,normalization,array_manifold.ULA_M_sensors) 192 | return return_MRA, cov_ULA, array_manifold.MRA_sensor_locations, array_manifold.ULA_sensor_locations 193 | 194 | class CovMapDataset(Dataset): 195 | def __init__(self, 196 | mode: str, 197 | L: int, 198 | d: float, 199 | lam: float, 200 | N_sensors: int, 201 | T_snapshots: int, 202 | num_sources: List[int], 203 | snr_range: List[float], 204 | snr_uniform: bool, 205 | snr_list: List[float], 206 | snr_prob: List[float], 207 | seed: int, 208 | deg_range: List[float], 209 | min_sep: List[float], 210 | diag_src_cov: bool, 211 | use_variance: bool, 212 | gain_bias: List[float], 213 | phase_bias_deg: List[float], 214 | position_bias: List[float], 215 | mc_mag_angle: List[float], 216 | rho: float, 217 | mix_rho: bool, 218 | base_L: int = 10000, 219 | dynamic: bool = True, 220 | random_power: bool = False, 221 | power_range: List[float] = [0.1,1.0], 222 | total_power_one: bool = False, 223 | normalization: str = 'disabled', 224 | device: str = 'cpu', 225 | save_dataset: bool = False 226 | ): 227 | np.random.seed(seed) 228 | torch.manual_seed(seed) 229 | self.L = L 230 | self.d = d 231 | self.lam = lam 232 | self.N_sensors = N_sensors 233 | self.T_snapshots = T_snapshots 234 | self.num_sources = num_sources 235 | self.snr_range = snr_range 236 | self.snr_uniform = snr_uniform 237 | self.snr_list = np.array(snr_list) 238 | self.snr_prob = np.array(snr_prob) 239 | self.deg_range = deg_range 240 | self.min_sep = min_sep 241 | self.diag_src_cov = diag_src_cov 242 | self.use_variance = use_variance 243 | self.dynamic = dynamic 244 | self.random_power = random_power 245 | self.power_range = power_range 246 | self.total_power_one = total_power_one 247 | self.normalization = normalization 248 | self.device = device 249 
| self.base_L = base_L 250 | self.N_datapoints_per_nsrc = self.base_L * L 251 | self.N_datapoints = self.N_datapoints_per_nsrc * len(num_sources) 252 | self.cov_in = None 253 | self.cov_out = None 254 | self.source_number = None 255 | self.rho = rho 256 | self.mix_rho = mix_rho 257 | self.pid180 = np.pi/180 258 | self.array_manifold = ArrayManifold(d=d,lam=lam,N_sensors=N_sensors,gain_bias=gain_bias,phase_bias_deg=phase_bias_deg, 259 | position_bias=position_bias,mc_mag_angle=mc_mag_angle,device=device) 260 | dataset_folder = f'./covaug_datasets_{mode}/' 261 | 262 | path = os.path.join(dataset_folder,(f"{mode}_d={d}_lam={lam}_L={L}_N={N_sensors}_T={T_snapshots}_nsrc={str(num_sources)}_snr={str(snr_range)}_uni={int(snr_uniform)}" 263 | f"_spr={round(snr_prob[-1]/snr_prob[0],1)}_seed={seed}_rg={str(deg_range)}_sep={str(min_sep)}_rho={rho}_mix={int(mix_rho)}_dg={int(diag_src_cov)}" 264 | f"_uv={int(use_variance)}_baseL={base_L}_rp={int(random_power)}_pr={str(power_range)}_tpo={int(total_power_one)}" 265 | f"_nor={normalization}.hdf5").replace(' ','').replace('.','').replace(',','_').replace('[','').replace(']','')) 266 | 267 | if dynamic is False: 268 | if os.path.exists(path): 269 | print((f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] A dataset already exists at {path}" 270 | " (remove the existing dataset if you want to create a new one). Start loading...")) 271 | with h5py.File(path,'r') as file: 272 | self.cov_out = torch.from_numpy(file["cov_out"][:]) 273 | self.cov_in = torch.from_numpy(file["cov_in"][:]) 274 | self.source_number = torch.from_numpy(file["source_number"][:]) 275 | self.angles = torch.from_numpy(file["angles"][:]) 276 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Finished loading the dataset") 277 | else: 278 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] No dataset found at {path}, generate a new dataset for CovMap") 279 | self.angles = torch.zeros(self.N_datapoints,max(self.num_sources)) 280 | with torch.device(self.device): 281 | with tqdm(total=self.N_datapoints,leave=True) as pbar: 282 | for k in range(len(num_sources)): 283 | source_angles = torch.from_numpy(get_random_source_angles(deg_range=deg_range,min_sep=min_sep[k],num_sources=num_sources[k],num_datapoints=self.N_datapoints_per_nsrc,mode=mode,seed=seed)) 284 | self.angles[k * self.base_L * L:(k+1) * self.base_L * L,:num_sources[k]] = torch.sort(source_angles)[0] 285 | for j in range(self.L): 286 | if self.random_power is True: 287 | p = (power_range[1] - power_range[0]) * torch.rand(self.base_L,num_sources[k]) + power_range[0] 288 | p = p * p.size(1) / torch.sum(p,dim=1,keepdim=True) 289 | else: 290 | p = torch.ones(self.base_L,num_sources[k]) 291 | if snr_uniform is True: 292 | SNR = torch.rand(self.base_L,1,1) * (snr_range[1]-snr_range[0]) + snr_range[0] 293 | else: 294 | SNR = torch.from_numpy(np.random.choice(a=self.snr_list,size=self.base_L,p=self.snr_prob).astype(np.float32)).unsqueeze(1).unsqueeze(2) 295 | source_base = torch.randn(self.base_L,source_angles.shape[1],T_snapshots,dtype=torch.cfloat) # random source base 296 | noise_base = torch.randn(self.base_L,self.array_manifold.ULA_M_sensors,T_snapshots,dtype=torch.cfloat) # noise base 297 | cov_MRA, cov_ULA, _, _ = generate_batch_cov_MRA_ULA(self.array_manifold,self.rho,self.mix_rho,source_base,noise_base,T_snapshots,p,SNR,source_angles[j*self.base_L:(j+1)*self.base_L,:], 298 | use_variance,False,normalization,diag_src_cov,False,total_power_one) 299 | l = k * self.base_L * L + j * self.base_L 300 
| if self.cov_out is None: 301 | self.cov_out = torch.zeros(self.N_datapoints,cov_ULA.shape[1],cov_ULA.shape[2],dtype=torch.complex64) 302 | self.cov_out[:self.base_L,:,:] = cov_ULA 303 | else: 304 | self.cov_out[l:l+self.base_L,:,:] = cov_ULA 305 | if self.cov_in is None: 306 | self.cov_in = torch.zeros(self.N_datapoints,cov_MRA.shape[1],cov_MRA.shape[2],dtype=torch.complex64) 307 | self.cov_in[:self.base_L,:,:] = cov_MRA 308 | else: 309 | self.cov_in[l:l+self.base_L,:,:] = cov_MRA 310 | if self.source_number is None: 311 | self.source_number = torch.zeros(self.N_datapoints,dtype=torch.int16) 312 | self.source_number[:self.base_L] = num_sources[k] 313 | else: 314 | self.source_number[l:l+self.base_L] = num_sources[k] 315 | pbar.update(self.base_L) 316 | if save_dataset is True: 317 | if not os.path.isdir(dataset_folder): 318 | os.mkdir(dataset_folder) 319 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Saving the dataset to path {path}") 320 | with h5py.File(path,'w') as file: 321 | file.create_dataset(name="cov_in",data=self.cov_in.numpy(),compression='gzip') 322 | file.create_dataset(name="cov_out",data=self.cov_out.numpy(),compression='gzip') 323 | file.create_dataset(name="source_number",data=self.source_number.numpy(),compression='gzip') 324 | file.create_dataset(name="angles",data=self.angles.numpy(),compression='gzip') 325 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Dataset saved at {path}") 326 | 327 | def __len__(self): 328 | return self.N_datapoints 329 | 330 | def __getitem__(self, idx): 331 | if self.dynamic is True: 332 | s = np.random.choice(a=len(self.num_sources),size=1).item() # sample an index into the list of source counts 333 | source_number = self.num_sources[s] 334 | angles = torch.zeros(max(self.num_sources)) # build angles locally instead of overwriting the dataset attribute 335 | source_angles = torch.from_numpy(random_source_angles(self.deg_range,self.min_sep[s],source_number)*self.pid180).unsqueeze(0) 336 | angles[:source_number] = source_angles 337 | if self.random_power is True: 338 | p = (self.power_range[1] - self.power_range[0]) * torch.rand(1,source_number) + self.power_range[0] 339 | p = p * p.size(1) / torch.sum(p,dim=1,keepdim=True) 340 | else: 341 | p = torch.ones(1,source_number) 342 | source_base = torch.randn(1,source_angles.shape[1],self.T_snapshots,dtype=torch.cfloat) # random source base 343 | noise_base = torch.randn(1,self.array_manifold.ULA_M_sensors,self.T_snapshots,dtype=torch.cfloat) # noise base 344 | if self.snr_uniform is True: 345 | SNR = torch.rand(1,1,1) * (self.snr_range[1]-self.snr_range[0]) + self.snr_range[0] 346 | else: 347 | SNR = torch.from_numpy(np.random.choice(a=self.snr_list,size=1,p=self.snr_prob).astype(np.float32)).unsqueeze(1).unsqueeze(2) 348 | cov_MRA, cov_ULA, _, _ = generate_batch_cov_MRA_ULA(self.array_manifold,self.rho,self.mix_rho,source_base,noise_base,self.T_snapshots,p,SNR,source_angles, 349 | self.use_variance,False,self.normalization,self.diag_src_cov,False,self.total_power_one) 350 | cov_out = cov_ULA[0,:,:] 351 | cov_in = cov_MRA[0,:,:] 352 | else: 353 | cov_out = self.cov_out[idx,:,:] 354 | cov_in = self.cov_in[idx,:,:] 355 | source_number = self.source_number[idx] 356 | angles = self.angles[idx,:] # this lookup belongs to the pre-generated (non-dynamic) branch; the dynamic branch fills angles above 357 | return cov_in, cov_out, source_number, angles 358 | 359 | class Cov2DoADataset(Dataset): 360 | def __init__(self, 361 | mode: str, 362 | d: float, 363 | lam: float, 364 | N_sensors: int, 365 | T_snapshots: int, 366 | num_sources: int, 367 | snr_range: List[float], 368 | seed: int, 369 | deg_range: List[float], 370 | min_sep: float, 371 | L: int, 372 | base_L: int, 373 | gain_bias: 
List[float], 374 | phase_bias_deg: List[float], 375 | position_bias: List[float], 376 | mc_mag_angle: List[float], 377 | rho: float, 378 | mix_rho: bool, 379 | provide_noise_var: bool = False, 380 | random_power: bool = False, 381 | power_range: List[float] = [0.1,1.0], 382 | total_power_one: bool = False, 383 | evenly_distributed: bool = False, 384 | return_snapshots: bool = False, 385 | device: str = 'cpu', 386 | save_dataset: bool = False 387 | ): 388 | np.random.seed(seed) 389 | torch.manual_seed(seed) 390 | self.d = d 391 | self.lam = lam 392 | self.N_sensors = N_sensors 393 | self.T_snapshots = T_snapshots 394 | self.num_sources = num_sources 395 | self.snr_range = snr_range 396 | self.deg_range = deg_range 397 | self.min_sep = min_sep 398 | self.L = L 399 | self.base_L = base_L 400 | self.random_power = random_power 401 | self.power_range = power_range 402 | self.evenly_distributed = evenly_distributed 403 | if evenly_distributed is True: 404 | if self.L != 1: 405 | raise ValueError("L must be 1 because evenly_distributed is True (angles are not random)") 406 | self.device = device 407 | self.N_datapoints = self.base_L * L 408 | self.rho = rho 409 | self.mix_rho = mix_rho 410 | self.provide_noise_var = provide_noise_var # only meaningful when random_power is False 411 | if provide_noise_var is True and random_power is True: 412 | raise ValueError("provide_noise_var can only be True when random_power is False") 413 | self.array_manifold = ArrayManifold(d=d,lam=lam,N_sensors=N_sensors,gain_bias=gain_bias,phase_bias_deg=phase_bias_deg, 414 | position_bias=position_bias,mc_mag_angle=mc_mag_angle,device=device) 415 | if return_snapshots is True: 416 | self.data_in = torch.zeros(self.N_datapoints,N_sensors,T_snapshots,dtype=torch.complex64) 417 | else: 418 | self.data_in = torch.zeros(self.N_datapoints,N_sensors,N_sensors,dtype=torch.complex64) 419 | if provide_noise_var is True: 420 | self.noise_var = torch.zeros(self.N_datapoints,dtype=torch.float64) 421 | 422 | dataset_folder = f'./cov2DoA_datasets_{mode}/' 423 | 424 | path = (f"{dataset_folder}{mode}_d={d}_lam={lam}_N={N_sensors}_T={T_snapshots}_nsrc={num_sources}_snr={str(snr_range).replace(' ','')}" 425 | f"_seed={seed}_degr={str(deg_range).replace(' ','')}_sep={min_sep}_rho={rho}_mix={int(mix_rho)}_L={L}_baseL={base_L}" 426 | f"pnv={int(provide_noise_var)}_rp={int(random_power)}_tpo={int(total_power_one)}_ed={int(evenly_distributed)}_rsnap={int(return_snapshots)}.hdf5") 427 | 428 | if os.path.exists(path): 429 | tqdm.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] A dataset already exists at {path} (remove the existing dataset if you want to create a new one). 
Start loading...") 430 | with h5py.File(path,'r') as file: 431 | self.DoA = torch.from_numpy(file["DoA"][:]) 432 | self.data_in = torch.from_numpy(file["data_in"][:]) 433 | if provide_noise_var is True: 434 | self.noise_var = torch.from_numpy(file["noise_var"][:]) 435 | tqdm.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Finished loading the dataset") 436 | else: 437 | tqdm.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] No dataset found at {path}, generate a new dataset for DoA estimation") 438 | self.DoA = torch.zeros(self.N_datapoints,num_sources,dtype=torch.float64) 439 | with torch.device(self.device): 440 | if evenly_distributed is True: 441 | source_angle = np.linspace(deg_range[0],deg_range[1],num_sources+2,dtype=np.float32)[1:-1] * np.pi/180 442 | else: 443 | source_angles = get_random_source_angles(deg_range=deg_range,min_sep=min_sep,num_sources=num_sources,num_datapoints=self.L,mode=mode,seed=seed) 444 | with tqdm(total=self.N_datapoints,leave=True) as pbar: 445 | for j in range(self.L): 446 | if evenly_distributed is True: 447 | src_angle = torch.from_numpy(source_angle) 448 | else: 449 | src_angle = torch.from_numpy(source_angles[j,:]) 450 | repeat_src_angles = src_angle.unsqueeze(0).repeat(self.base_L,1) 451 | if self.random_power is True: 452 | p = (self.power_range[1] - self.power_range[0]) * torch.rand(self.base_L,num_sources) + self.power_range[0] 453 | p = p * p.size(1) / torch.sum(p,dim=1,keepdim=True) 454 | else: 455 | p = torch.ones(self.base_L,num_sources) 456 | SNR = torch.rand(self.base_L,1,1) * (snr_range[1]-snr_range[0]) + snr_range[0] 457 | source_base, noise_base = get_source_and_noise_random_base(self.base_L,num_sources,T_snapshots,self.array_manifold.ULA_M_sensors,seed,'eval') 458 | data_in, _, _, _ = generate_batch_cov_MRA_ULA(self.array_manifold,self.rho,self.mix_rho,source_base,noise_base,T_snapshots,p,SNR,repeat_src_angles, 459 | False,True,'disabled',False,return_snapshots,total_power_one) 460 | l = j * self.base_L 461 | self.DoA[l:l+self.base_L,:] = torch.sort(src_angle)[0].unsqueeze(0).repeat(self.base_L,1) 462 | self.data_in[l:l+self.base_L,:,:] = data_in 463 | if provide_noise_var is True: 464 | self.noise_var[l:l+self.base_L] = 1/(10**(SNR.squeeze()/10)) 465 | pbar.update(self.base_L) 466 | if save_dataset is True: 467 | if not os.path.isdir(dataset_folder): 468 | os.mkdir(dataset_folder) 469 | tqdm.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Saving the dataset to path {path}") 470 | with h5py.File(path,'w') as file: 471 | file.create_dataset(name="data_in",data=self.data_in.numpy(),compression='gzip') 472 | file.create_dataset(name="DoA",data=self.DoA.numpy(),compression='gzip') 473 | if provide_noise_var is True: 474 | file.create_dataset(name="noise_var",data=self.noise_var.numpy(),compression='gzip') 475 | tqdm.write(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [data.py] Dataset saved at {path}") 476 | 477 | def __len__(self): 478 | return self.N_datapoints 479 | 480 | def __getitem__(self, idx): 481 | label = self.DoA[idx,:] 482 | data = self.data_in[idx,:,:] 483 | if self.provide_noise_var is True: 484 | noise_var = self.noise_var[idx] 485 | return data, noise_var, label 486 | else: 487 | return data, label 488 | 489 | if __name__ == '__main__': 490 | from torch.utils.data import DataLoader 491 | import time 492 | 493 | d = 0.01 494 | lam = 0.02 495 | N_sensors = 5 496 | T_snapshots = 50 497 | num_sources = 5 498 | snr_range = [10,20] 499 | seed = 0 500 | deg_range = [30,150] 501 | 
min_sep = 10 502 | L = 2 503 | diag_src_cov = True 504 | use_variance = True 505 | dynamic = False 506 | provide_noise_var = True 507 | random_power = False 508 | power_range = [0.1,1.0] 509 | return_snapshots = True 510 | normalization = 'disabled' 511 | mode = 'testdryrun' 512 | base_L = 100 513 | gain_bias = [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1] 514 | phase_bias_deg = [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1] 515 | position_bias = [0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1] 516 | mc_mag_angle = [0.1,0.1] 517 | rho = 0 518 | mix_rho = False 519 | save_dataset = False 520 | 521 | DoA_dataset = Cov2DoADataset(mode,d,lam,N_sensors,T_snapshots,num_sources,snr_range,seed,deg_range,min_sep,L,base_L,gain_bias,phase_bias_deg,position_bias,mc_mag_angle,rho,mix_rho,provide_noise_var=provide_noise_var,random_power=random_power,power_range=power_range,return_snapshots=return_snapshots,device='cpu',save_dataset=save_dataset) # pass the trailing options by keyword: passing return_snapshots positionally would bind it to total_power_one 522 | dataloader = DataLoader(DoA_dataset,batch_size=512,shuffle=True,num_workers=0,pin_memory=True,drop_last=False) 523 | 524 | print(dataloader) 525 | print(len(dataloader)) 526 | print(len(DoA_dataset)) 527 | 528 | tic = time.time() 529 | for idx, (data,noise_var,label) in enumerate(dataloader): 530 | print(idx) 531 | print(data.shape) 532 | print(data[0,:,:]) 533 | print(noise_var) 534 | print(noise_var.shape) 535 | print(label.shape) 536 | toc = time.time() 537 | print(toc-tic) -------------------------------------------------------------------------------- /eval/crlb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | Implementations are based on the following two papers 7 | 8 | Wang, Mianzhi, Zhen Zhang, and Arye Nehorai. "Further results on the Cramér–Rao bound for sparse linear arrays." IEEE Transactions on Signal Processing 67, no. 6 (2019): 1493-1507. 9 | Wang, Mianzhi, and Arye Nehorai. "Coarrays, MUSIC, and the Cramér–Rao bound." IEEE Transactions on Signal Processing 65, no. 4 (2016): 933-946. 10 | """ 11 | import numpy as np 12 | import torch 13 | from utils import MRA 14 | from scipy import linalg 15 | 16 | def get_steering_matrix(N_sensors, source_angles, d, lam): 17 | source_angles = np.expand_dims(source_angles,0) 18 | # MRA 19 | MRA_sensor_locations, _, _ = MRA(N_sensors,d) 20 | p = (2*np.pi/lam)*np.expand_dims(np.array(MRA_sensor_locations,dtype=np.longdouble),1) 21 | # steering matrix 22 | imag = p @ np.cos(source_angles,dtype=np.longdouble) 23 | V = np.exp(1j*imag,dtype=np.complex128) 24 | # dV / dtheta 25 | imag_p = -p @ np.sin(source_angles,dtype=np.longdouble) 26 | dV = 1j*imag_p*V 27 | return V, dV 28 | 29 | # Wang, Mianzhi, Zhen Zhang, and Arye Nehorai. "Further results on the Cramér–Rao bound for sparse linear arrays." IEEE Transactions on Signal Processing 67, no. 6 (2019): 1493-1507.
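# Editor's sketch (not part of the original file): both CRLB routines below map
# the per-source SNR in dB to a noise variance as noise_var = 1/(10**(SNR/10)),
# scaled by num_sources when total_power_one is True, consistent with each
# source then carrying power 1/num_sources. A minimal, hypothetical call for
# the 5-element MRA with this repo's default d and lam:
#
#   angles = np.array([60.0, 90.0, 120.0]) * np.pi / 180   # DoAs in radians
#   crb = uncorrelated_CRLB(angles, N_sensors=5, d=0.01, lam=0.02,
#                           SNR=0, n_snapshots=50, total_power_one=False)
#   rmse_bound_deg = np.sqrt(np.mean(np.diag(crb))) * 180 / np.pi
#
# The diagonal of crb lower-bounds the per-angle MSE in rad^2 for any unbiased
# estimator, which is how the __main__ block below reports RMSE in degrees.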
30 | def uncorrelated_CRLB(source_angles, N_sensors, d, lam, SNR, n_snapshots, total_power_one): 31 | num_sources = len(source_angles) 32 | if total_power_one is True: 33 | noise_var = (num_sources/(10**(SNR/10))) 34 | else: 35 | noise_var = (1/(10**(SNR/10))) 36 | k = num_sources 37 | m = N_sensors 38 | p = np.ones(num_sources,dtype=np.longdouble) 39 | 40 | V, dV = get_steering_matrix(N_sensors,source_angles,d,lam) 41 | 42 | A = V 43 | DA = dV 44 | 45 | A_H = A.conj().T 46 | DA_H = DA.conj().T 47 | P = np.diag(np.ones(num_sources,dtype=np.longdouble)) 48 | R = (A * p) @ A_H + noise_var * np.eye(m) 49 | R_inv = np.linalg.inv(R.astype(np.complex128)) 50 | R_inv = 0.5 * (R_inv + R_inv.conj().T) 51 | 52 | DRD = DA_H @ R_inv @ DA 53 | DRA = DA_H @ R_inv @ A 54 | ARD = A_H @ R_inv @ DA 55 | ARA = A_H @ R_inv @ A 56 | 57 | FIM_tt = 2.0 * (DRD.conj() * (P @ ARA @ P) + DRA.conj() * (P @ ARD @ P)).real 58 | FIM_pp = (ARA.conj() * ARA).real 59 | R_inv2 = R_inv @ R_inv 60 | FIM_ss = np.trace(R_inv2).real 61 | FIM_tp = 2.0 * (DRA.conj() * (p[:, np.newaxis] * ARA)).real 62 | FIM_ts = 2.0 * (p * np.sum(DA.conj() * (R_inv2 @ A), axis=0)).real[:, np.newaxis] 63 | FIM_ps = np.sum(A.conj() * (R_inv2 @ A), axis=0).real[:, np.newaxis] 64 | FIM = np.block([ 65 | [FIM_tt, FIM_tp, FIM_ts], 66 | [FIM_tp.conj().T, FIM_pp, FIM_ps], 67 | [FIM_ts.conj().T, FIM_ps.conj().T, FIM_ss] 68 | ]) 69 | CRB = np.linalg.inv(FIM.astype(np.float64))[:k, :k] / n_snapshots 70 | return 0.5 * (CRB + CRB.T) 71 | 72 | # Wang, Mianzhi, and Arye Nehorai. "Coarrays, MUSIC, and the Cramér–Rao bound." IEEE Transactions on Signal Processing 65, no. 4 (2016): 933-946. 73 | def uncorrelated_CRLB2(source_angles, N_sensors, d, lam, SNR, n_snapshots, total_power_one): 74 | num_sources = len(source_angles) 75 | if total_power_one is True: 76 | noise_var = (num_sources/(10**(SNR/10))) 77 | else: 78 | noise_var = (1/(10**(SNR/10))) 79 | k = num_sources 80 | m = N_sensors 81 | P = np.diag(np.ones(num_sources,dtype=np.longdouble)) 82 | I = np.eye(N_sensors,dtype=np.longdouble) 83 | 84 | V, dV = get_steering_matrix(N_sensors,source_angles,d,lam) 85 | 86 | A = V 87 | DA = dV 88 | 89 | R = A @ P @ A.conj().T + noise_var * I 90 | i = np.expand_dims(I.flatten('F'),axis=1) 91 | A_d = linalg.khatri_rao(A.conj(),A) 92 | A_d_dot = linalg.khatri_rao(DA.conj(),A) + linalg.khatri_rao(A.conj(),DA) 93 | 94 | RtR_sqrt = np.linalg.inv(linalg.sqrtm(np.kron(R.T,R)).astype(np.complex128)) 95 | M_theta = RtR_sqrt @ A_d_dot @ P 96 | 97 | M_s = RtR_sqrt @ np.concatenate((A_d,i),axis=1) 98 | 99 | M_sHM_s = M_s.conj().T @ M_s 100 | M_sHM_s_inv = np.linalg.inv(M_sHM_s.astype(np.complex128)) 101 | M_sHM_s_inv = 0.5 * (M_sHM_s_inv + M_sHM_s_inv.conj().T) 102 | 103 | M_s_r = M_s @ M_sHM_s_inv @ M_s.conj().T 104 | P_M_s = np.eye(M_s_r.shape[0]) - M_s_r 105 | 106 | FIM = M_theta.conj().T @ P_M_s @ M_theta 107 | 108 | CRB = np.linalg.inv(FIM.astype(np.complex128)) / n_snapshots 109 | return 0.5 * (CRB + CRB.T).real 110 | 111 | if __name__ == '__main__': 112 | from data import random_source_angles 113 | import matplotlib.pyplot as plt 114 | np.random.seed(0) 115 | torch.manual_seed(0) 116 | # source DoA 117 | d = 0.01 118 | lam = 0.02 119 | N_sensors = 5 120 | SNR = -10 121 | deg_range = [30,150] 122 | # (N_sensors,min_sep) (4,18) (5,12 or 13) (6,9) 123 | num_sources = 9 124 | min_sep = 13 125 | # the two CRLB implementations above will match if the minimum separation is sufficiently large 126 | total_power_one = False 127 | print_results = True 128 | n_snapshots = 50 129 | M = 
10000 130 | j = 0 131 | tr_crb_vec = np.zeros(M) 132 | for i in range(M): 133 | source_angles = random_source_angles(deg_range,min_sep,num_sources) * np.pi / 180 134 | crb = uncorrelated_CRLB(source_angles=source_angles,N_sensors=N_sensors, d=d, lam=lam,SNR=SNR,n_snapshots=n_snapshots,total_power_one=total_power_one) 135 | crb2 = uncorrelated_CRLB2(source_angles=source_angles,N_sensors=N_sensors, d=d, lam=lam,SNR=SNR,n_snapshots=n_snapshots,total_power_one=total_power_one) 136 | tr_crb = np.mean(np.diag(crb)) 137 | tr_crb_vec[i] = tr_crb 138 | tr_crb2 = np.mean(np.diag(crb2)) 139 | crb_diag = np.diag(crb) 140 | crb2_diag = np.diag(crb2) 141 | diff_crb_max_ratio = np.max(np.abs(crb_diag-crb2_diag))/np.max(np.abs(crb_diag)) 142 | is_less_than_0 = np.prod(crb_diag >= 0) 143 | if is_less_than_0 == 0: 144 | print("less than 0") 145 | if diff_crb_max_ratio > 0.02: 146 | j += 1 147 | if print_results is True: 148 | print(diff_crb_max_ratio) 149 | print(source_angles) 150 | print(crb_diag) 151 | print(crb2_diag) 152 | print(tr_crb) 153 | print(tr_crb2) 154 | print('-'*10) 155 | print(f'N_sensors={N_sensors}') 156 | print(f'mean tr_crb={np.mean(tr_crb_vec)} rad^2 ( RMSE= {np.sqrt(np.mean(tr_crb_vec))*180/np.pi} deg ) | {np.log10(np.mean(tr_crb_vec))}') 157 | print(f' min tr_crb={np.min(tr_crb_vec)} rad^2 ( RMSE= {np.sqrt(np.min(tr_crb_vec))*180/np.pi} deg ) | {np.log10(np.min(tr_crb_vec))}') 158 | print(f' max tr_crb={np.max(tr_crb_vec)} rad^2 ( RMSE= {np.sqrt(np.max(tr_crb_vec))*180/np.pi} deg ) | {np.log10(np.max(tr_crb_vec))}') 159 | print(f'Number of disagreements: {j}/{M}') 160 | crb_hist,bin_edges = np.histogram(tr_crb_vec,bins=5000) 161 | plt.semilogx(bin_edges[:-1],crb_hist) 162 | plt.grid() 163 | plt.xlabel('MSE (rad^2)') 164 | plt.ylabel('Probability density') 165 | plt.tight_layout() 166 | plt.show() -------------------------------------------------------------------------------- /eval/perf_dnn_n4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRT=./checkpoint/N4_M7_toep_WRN_16_8_t=200_v=60_n=4_loss=ToepSquare_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGF=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | OURS=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 7 | 8 | MIXRHO_DCRGA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 9 | MIXRHO_OURS=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 10 | 11 | # SNR vs. 
MSE 12 | python3 performance.py --results_folder dnn_results --N_sensors 4 --num_sources_list 1 2 3 4 5 6 --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $DCRT $DCRGF $DCRGA $OURS 13 | 14 | # number of snapshots vs. MSE 15 | python3 performance.py --results_folder dnn_results --N_sensors 4 --num_sources_list 1 4 6 --min_sep 4 4 4 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $DCRT $DCRGF $DCRGA $OURS 16 | 17 | # array imperfection parameter rho vs. MSE 18 | for i in 0.0 0.1 0.2 0.5 1.0 19 | do 20 | python3 performance.py --results_folder dnn_results --N_sensors 4 --num_sources_list 1 4 6 --min_sep 4 4 4 --SNR_list 20 --rho $i --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $MIXRHO_DCRGA $MIXRHO_OURS 21 | done -------------------------------------------------------------------------------- /eval/perf_dnn_n4_other_distances.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRT=./checkpoint/N4_M7_toep_WRN_16_8_t=200_v=60_n=4_loss=ToepSquare_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGF=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | 7 | CHORDAL_PA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalChordalDistPA_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 8 | CHORDAL_OB=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalChordalDistOB_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 9 | PROJEC_PA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalProjectionDistPA_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 10 | PROJEC_OB=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalProjectionDistOB_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 11 | FUBINISTUDY_PA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalFubiniStudyDistPA_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 12 | 
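# Editor's note (not part of the original script): the _PA/_OB suffixes in these
# checkpoint names select the two mathematically equivalent implementations of
# the same subspace distance in loss.py, computed either from the principal
# angles (PA) of the subspace pair or directly from their orthonormal bases (OB).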
FUBINISTUDY_OB=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalFubiniStudyDistOB_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 13 | PROCRUSTES_PA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalProcrustesDistPA_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 14 | SPECTRAL_PA=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSpectralDistPA_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 15 | GEODESIC=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 16 | 17 | # SNR vs. MSE 18 | python3 performance.py --results_folder dnn_results --N_sensors 4 --num_sources_list 1 2 3 4 5 6 --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $DCRT $DCRGF $DCRGA $CHORDAL_PA $PROJEC_PA $FUBINISTUDY_PA $PROCRUSTES_PA $SPECTRAL_PA $GEODESIC -------------------------------------------------------------------------------- /eval/perf_dnn_n4_random_power.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRT_RP=./checkpoint/N4_M7_toep_WRN_16_8_t=200_v=60_n=4_loss=ToepSquare_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=1_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGF_RP=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=1_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGA_RP=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=1_pr=0110_tpo=0_nor=disabled_oc=1 6 | OURS_RP=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=1_pr=0110_tpo=0_nor=disabled_oc=1 7 | 8 | # SNR vs. 
MSE 9 | python3 performance.py --random_power 1 --provide_noise_var 0 --results_folder dnn_results --N_sensors 4 --num_sources_list 1 2 3 4 5 6 --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $DCRT_RP $DCRGF_RP $DCRGA_RP $OURS_RP -------------------------------------------------------------------------------- /eval/perf_dnn_n4_w_and_wo_crs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | W_CRS=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | WO_CRS=./checkpoint/N4_M7_WRN_16_8_t=200_v=60_n=4_loss=SignalSubspaceDistNoCrsGroup_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456_T=50_rg=30150_sep=333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | 6 | # SNR vs. MSE 7 | python3 performance.py --results_folder dnn_results --N_sensors 4 --num_sources_list 1 2 3 4 5 6 --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --cov_models $WO_CRS $W_CRS -------------------------------------------------------------------------------- /eval/perf_dnn_n5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRT=./checkpoint/N5_M10_toep_WRN_16_8_t=200_v=60_n=5_loss=ToepSquare_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGF=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGA=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | OURS=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 7 | 8 | MIXRHO_DCRGA=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=AffInvDist_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 9 | MIXRHO_OURS=./checkpoint/N5_M10_WRN_16_8_t=200_v=60_n=5_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 10 | 11 | END2END=./checkpoint/Branch_N5_M10_WRN_16_8_t=200_v=60_n=5_loss=BranchAngleMSE_mu=02_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=123456789_T=50_rg=30150_sep=333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 12 | 13 | # SNR vs. MSE 14 | python3 performance.py --results_folder dnn_results --cov_models $DCRT $DCRGF $DCRGA $OURS 15 | 16 | # number of snapshots vs. 
MSE 17 | python3 performance.py --results_folder dnn_results --num_sources_list 1 6 9 --min_sep 4 4 4 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --cov_models $DCRT $DCRGF $DCRGA $OURS 18 | 19 | # array imperfection parameter rho vs. MSE 20 | for i in 0.0 0.1 0.2 0.5 1.0 21 | do 22 | python3 performance.py --results_folder dnn_results --num_sources_list 1 6 9 --min_sep 4 4 4 --SNR_list 20 --rho $i --cov_models $MIXRHO_DCRGA $MIXRHO_OURS 23 | done 24 | 25 | # gridless end-to-end approach vs. ours 26 | python3 performance.py --results_folder dnn_results --cov_models $OURS $END2END -------------------------------------------------------------------------------- /eval/perf_dnn_n6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DCRT=./checkpoint/N6_M14_toep_WRN_16_8_t=200_v=60_n=6_loss=ToepSquare_mu=005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 4 | DCRGF=./checkpoint/N6_M14_WRN_16_8_t=200_v=60_n=6_loss=FrobeniusNorm_mu=001_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 5 | DCRGA=./checkpoint/N6_M14_WRN_16_8_t=200_v=60_n=6_loss=AffInvDist3_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 6 | OURS=./checkpoint/N6_M14_WRN_16_8_t=200_v=60_n=6_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=00_mix=0_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 7 | 8 | MIXRHO_DCRGA=./checkpoint/N6_M14_WRN_16_8_t=200_v=60_n=6_loss=AffInvDist3_mu=0005_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=0_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 9 | MIXRHO_OURS=./checkpoint/N6_M14_WRN_16_8_t=200_v=60_n=6_loss=SignalSubspaceDist_mu=01_mo=05_bs=4096_epoch=50_wd=00_seed=0_nsrc=12345678910111213_T=50_rg=30150_sep=3333333333333_rho=10_mix=1_snr=-1020_uni=0_spr=1_dg=1_uv=0_crs=1_dy=0_rp=0_pr=0110_tpo=0_nor=disabled_oc=1 10 | 11 | # SNR vs. MSE 12 | python3 performance.py --results_folder dnn_results --N_sensors 6 --num_sources_list 1 2 3 5 7 9 11 12 13 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 --cov_models $DCRT $DCRGF $DCRGA $OURS 13 | 14 | # number of snapshots vs. MSE 15 | python3 performance.py --results_folder dnn_results --N_sensors 6 --num_sources_list 1 7 13 --min_sep 4 4 4 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 --cov_models $DCRT $DCRGF $DCRGA $OURS 16 | 17 | # array imperfection parameter rho vs. 
MSE 18 | for i in 0.0 0.1 0.2 0.5 1.0 19 | do 20 | python3 performance.py --results_folder dnn_results --N_sensors 6 --num_sources_list 1 7 13 --min_sep 4 4 4 --SNR_list 20 --rho $i --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 --cov_models $MIXRHO_DCRGA $MIXRHO_OURS 21 | done -------------------------------------------------------------------------------- /eval/perf_opt_n4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 2 3 4 5 6 --rho 0 --results_folder results --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 3 | 4 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --rho 0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 5 | 6 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --rho 0.0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 7 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --rho 0.1 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 8 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --rho 0.2 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 9 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --rho 0.5 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 10 | python3 performance.py --N_sensors 4 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 4 6 --SNR_list 20 --rho 1.0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /eval/perf_opt_n4_random_power.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 performance.py --random_power 1 --provide_noise_var 0 --N_sensors 4 --SPA 1 
--num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 2 3 4 5 6 --rho 0 --results_folder results --min_sep 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /eval/perf_opt_n5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 2 3 4 5 6 7 8 9 --rho 0 --results_folder results --min_sep 4 4 4 4 4 4 4 4 4 3 | 4 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --rho 0 --results_folder results --min_sep 4 4 4 5 | 6 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --rho 0.0 --results_folder results --min_sep 4 4 4 7 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --rho 0.1 --results_folder results --min_sep 4 4 4 8 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --rho 0.2 --results_folder results --min_sep 4 4 4 9 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --rho 0.5 --results_folder results --min_sep 4 4 4 10 | python3 performance.py --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 6 9 --SNR_list 20 --rho 1.0 --results_folder results --min_sep 4 4 4 -------------------------------------------------------------------------------- /eval/perf_opt_n6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 2 3 5 7 9 11 12 13 --rho 0 --results_folder results --min_sep 4 4 4 4 4 4 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 3 | 4 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --T_snapshots_list 10 20 30 40 50 60 70 80 90 100 --rho 0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 5 | 6 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --rho 0.0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 
--phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 7 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --rho 0.1 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 8 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --rho 0.2 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 9 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --rho 0.5 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 10 | python3 performance.py --N_sensors 6 --DA 1 --SS 1 --Wasserstein 1 --SPA 1 --SPA_noisevar 1 --num_random_thetas 100 --trials_per_theta 100 --num_sources_list 1 7 13 --SNR_list 20 --rho 1.0 --results_folder results --min_sep 4 4 4 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | https://arxiv.org/abs/2408.16605 7 | """ 8 | import torch 9 | 10 | def AngleMSE(outputs, target_cov, source_numbers, angles): 11 | rank = source_numbers[0] 12 | error = torch.sort(outputs[:,:rank])[0] - angles[:,:rank] 13 | return torch.mean(error ** 2, dim=1) 14 | 15 | def OrderedAngleMSE(outputs, target_cov, source_numbers, angles): 16 | rank = source_numbers[0] 17 | error = outputs[:,:rank] - angles[:,:rank] 18 | return torch.mean(error ** 2, dim=1) 19 | 20 | # loss function of the gridless end-to-end approach 21 | def BranchAngleMSE(outputs, target_cov, source_numbers, angles): 22 | rank = source_numbers[0] 23 | error = torch.sort(outputs[rank-1])[0] - angles[:,:rank] 24 | return torch.mean(error ** 2, dim=1) 25 | 26 | def BranchOrderedAngleMSE(outputs, target_cov, source_numbers, angles): 27 | rank = source_numbers[0] 28 | error = outputs[rank-1] - angles[:,:rank] 29 | return torch.mean(error ** 2, dim=1) 30 | 31 | # loss function of DCR-T 32 | def ToepSquare(outputs, targets, source_numbers, angles): 33 | first_row_err = outputs[:,0,:] - targets[:,0,:] 34 | return 0.5 * torch.mean(torch.abs(first_row_err * first_row_err.conj()), dim=1) 35 | 36 | # loss function of DCR-G-Fro 37 | def FrobeniusNorm(outputs, targets, source_numbers, angles): 38 | A = outputs - targets 39 | return torch.linalg.matrix_norm(A,'fro') 40 | 41 | # 
subspace representation learning | Geodesic distance 42 | def NoiseSubspaceDist(outputs, targets, source_numbers, angles): 43 | rank = source_numbers[0] # assume consistent rank sampling is enabled 44 | m = targets.size(-1) - rank 45 | _, AQ = torch.linalg.eigh(outputs) 46 | _, BQ = torch.linalg.eigh(targets) 47 | A = AQ[:,:,-m:] 48 | B = BQ[:,:,:-rank] 49 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 50 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 51 | return torch.sqrt(torch.sum(theta[:,:m] ** 2, dim=1)) 52 | 53 | # the main loss function of the subspace representation learning approach | Geodesic distance | see Section IV in the paper 54 | def SignalSubspaceDist(outputs, targets, source_numbers, angles): 55 | rank = source_numbers[0] # assume consistent rank sampling is enabled 56 | _, AQ = torch.linalg.eigh(outputs) 57 | _, BQ = torch.linalg.eigh(targets) 58 | A = AQ[:,:,-rank:] 59 | B = BQ[:,:,-rank:] 60 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 61 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 62 | return torch.sqrt(torch.sum(theta[:,:rank] ** 2, dim=1)) 63 | 64 | # subspace representation learning | Geodesic distance | without consistent rank sampling | direct approach 65 | def SignalSubspaceDistNoCrsDirect(outputs, targets, source_numbers, angles): 66 | batch_size = outputs.size(0) 67 | l = [] 68 | _, AQ = torch.linalg.eigh(outputs) 69 | _, BQ = torch.linalg.eigh(targets) 70 | for i in range(batch_size): 71 | rank = source_numbers[i] 72 | A = AQ[:,:,-rank:] 73 | B = BQ[:,:,-rank:] 74 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 75 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 76 | l.append(torch.sqrt(torch.sum(theta[:,:rank] ** 2, dim=1))) 77 | return torch.cat(l,dim=0) 78 | 79 | # subspace representation learning | Geodesic distance | without consistent rank sampling | grouping approach 80 | def SignalSubspaceDistNoCrsGroup(outputs, targets, source_numbers, angles): 81 | max_n_src = max(source_numbers).item() 82 | l = [] 83 | _, AQ = torch.linalg.eigh(outputs) 84 | _, BQ = torch.linalg.eigh(targets) 85 | for i in range(1,max_n_src+1): 86 | x = source_numbers == i 87 | if not True in x: 88 | continue 89 | rank = source_numbers[x][0] 90 | A = AQ[x,:,-rank:] 91 | B = BQ[x,:,-rank:] 92 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 93 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 94 | l.extend(torch.sqrt(torch.sum(theta[:,:rank] ** 2, dim=1))) 95 | l = [j.reshape(1) for j in l] 96 | return torch.cat(l,dim=0) 97 | 98 | # subspace representation learning | Chordal distance (or projection Frobenius norm distance) using principal angles 99 | def SignalChordalDistPA(outputs, targets, source_numbers, angles): 100 | rank = source_numbers[0] # assume consistent rank sampling is enabled 101 | _, AQ = torch.linalg.eigh(outputs) 102 | _, BQ = torch.linalg.eigh(targets) 103 | A = AQ[:,:,-rank:] 104 | B = BQ[:,:,-rank:] 105 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 106 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 107 | return torch.sqrt(torch.sum(torch.sin(theta[:,:rank]) ** 2, dim=1)) 108 | 109 | # subspace representation learning | Chordal distance (or projection Frobenius norm distance) using orthonormal bases 110 | def SignalChordalDistOB(outputs, targets, source_numbers, angles): 111 | rank = source_numbers[0] # assume consistent rank sampling is enabled 112 | _, AQ = torch.linalg.eigh(outputs) 113 | _, BQ = 
torch.linalg.eigh(targets) 114 | A = AQ[:,:,-rank:] 115 | B = BQ[:,:,-rank:] 116 | C = A @ A.conj().transpose(-2,-1) - B @ B.conj().transpose(-2,-1) 117 | return torch.linalg.matrix_norm(C,'fro') / torch.sqrt(torch.tensor(2)) 118 | 119 | # subspace representation learning | Projection 2-norm using principal angles 120 | def SignalProjectionDistPA(outputs, targets, source_numbers, angles): 121 | rank = source_numbers[0] # assume consistent rank sampling is enabled 122 | _, AQ = torch.linalg.eigh(outputs) 123 | _, BQ = torch.linalg.eigh(targets) 124 | A = AQ[:,:,-rank:] 125 | B = BQ[:,:,-rank:] 126 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 127 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 128 | return torch.sin(theta[:,rank-1]) 129 | 130 | # subspace representation learning | Projection 2-norm using orthonormal bases 131 | def SignalProjectionDistOB(outputs, targets, source_numbers, angles): 132 | rank = source_numbers[0] # assume consistent rank sampling is enabled 133 | _, AQ = torch.linalg.eigh(outputs) 134 | _, BQ = torch.linalg.eigh(targets) 135 | A = AQ[:,:,-rank:] 136 | B = BQ[:,:,-rank:] 137 | C = A @ A.conj().transpose(-2,-1) - B @ B.conj().transpose(-2,-1) 138 | return torch.linalg.matrix_norm(C,2) 139 | 140 | # subspace representation learning | Fubini-Study distance using principal angles 141 | def SignalFubiniStudyDistPA(outputs, targets, source_numbers, angles): 142 | rank = source_numbers[0] # assume consistent rank sampling is enabled 143 | _, AQ = torch.linalg.eigh(outputs) 144 | _, BQ = torch.linalg.eigh(targets) 145 | A = AQ[:,:,-rank:] 146 | B = BQ[:,:,-rank:] 147 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 148 | C = torch.prod(S[:,:rank],dim=1) 149 | return torch.acos(-torch.nn.functional.threshold(-C,-1,-1)) 150 | 151 | # subspace representation learning | Fubini-Study distance using orthonormal bases 152 | def SignalFubiniStudyDistOB(outputs, targets, source_numbers, angles): 153 | rank = source_numbers[0] # assume consistent rank sampling is enabled 154 | _, AQ = torch.linalg.eigh(outputs) 155 | _, BQ = torch.linalg.eigh(targets) 156 | A = AQ[:,:,-rank:] 157 | B = BQ[:,:,-rank:] 158 | C = A.conj().transpose(-2,-1) @ B 159 | D = torch.abs(torch.linalg.det(C)) 160 | return torch.acos(-torch.nn.functional.threshold(-D,-1,-1)) 161 | 162 | # subspace representation learning | Procrustes distance (or chordal Frobenius norm distance) using principal angles 163 | def SignalProcrustesDistPA(outputs, targets, source_numbers, angles): 164 | rank = source_numbers[0] # assume consistent rank sampling is enabled 165 | _, AQ = torch.linalg.eigh(outputs) 166 | _, BQ = torch.linalg.eigh(targets) 167 | A = AQ[:,:,-rank:] 168 | B = BQ[:,:,-rank:] 169 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 170 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 171 | return 2 * torch.sqrt(torch.sum(torch.sin(theta[:,:rank] / 2) ** 2, dim=1)) 172 | 173 | # subspace representation learning | Procrustes distance (or chordal Frobenius norm distance) using orthonormal bases 174 | # def SignalProcrustesDistOB(outputs, targets, source_numbers, angles): 175 | # rank = source_numbers[0] # assume consistent rank sampling is enabled 176 | # _, AQ = torch.linalg.eigh(outputs) 177 | # _, BQ = torch.linalg.eigh(targets) 178 | # A = AQ[:,:,-rank:] 179 | # B = BQ[:,:,-rank:] 180 | # U, _, V = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 181 | # C = A @ U - B @ V 182 | # return torch.linalg.matrix_norm(C,'fro') 183 | 184 | # subspace 
representation learning | Spectral distance (or chordal 2-norm distance) using principal angles 185 | def SignalSpectralDistPA(outputs, targets, source_numbers, angles): 186 | rank = source_numbers[0] # assume consistent rank sampling is enabled 187 | _, AQ = torch.linalg.eigh(outputs) 188 | _, BQ = torch.linalg.eigh(targets) 189 | A = AQ[:,:,-rank:] 190 | B = BQ[:,:,-rank:] 191 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 192 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 193 | return 2 * torch.sin(theta[:,rank-1] / 2) 194 | 195 | # subspace representation learning | Spectral distance (or chordal 2-norm distance) using orthonormal bases 196 | # def SignalSpectralDistOB(outputs, targets, source_numbers, angles): 197 | # rank = source_numbers[0] # assume consistent rank sampling is enabled 198 | # _, AQ = torch.linalg.eigh(outputs) 199 | # _, BQ = torch.linalg.eigh(targets) 200 | # A = AQ[:,:,-rank:] 201 | # B = BQ[:,:,-rank:] 202 | # U, _, V = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 203 | # C = A @ U - B @ V 204 | # return torch.linalg.matrix_norm(C,2) 205 | 206 | # subspace representation learning 207 | def AvgSubspaceDist(outputs, targets, source_numbers, angles): 208 | rank = source_numbers[0] # assume consistent rank sampling is enabled 209 | m = targets.size(-1) - rank 210 | _, AQ = torch.linalg.eigh(outputs) 211 | _, BQ = torch.linalg.eigh(targets) 212 | A_s = AQ[:,:,-rank:] 213 | B_s = BQ[:,:,-rank:] 214 | A_n = AQ[:,:,:-rank] 215 | B_n = BQ[:,:,:-rank] 216 | _, S_s, _= torch.linalg.svd(A_s.conj().transpose(-2,-1) @ B_s) 217 | theta_s = torch.acos(-torch.nn.functional.threshold(-S_s,-1,-1)) 218 | _, S_n, _= torch.linalg.svd(A_n.conj().transpose(-2,-1) @ B_n) 219 | theta_n = torch.acos(-torch.nn.functional.threshold(-S_n,-1,-1)) 220 | return 0.5 * torch.sqrt(torch.sum(theta_s[:,:rank] ** 2, dim=1)) + 0.5 * torch.sqrt(torch.sum(theta_n[:,:m] ** 2, dim=1)) 221 | 222 | # subspace representation learning 223 | def L2SubspaceDist(outputs, targets, source_numbers, angles): 224 | rank = source_numbers[0] # assume consistent rank sampling is enabled 225 | m = targets.size(-1) - rank 226 | _, AQ = torch.linalg.eigh(outputs) 227 | _, BQ = torch.linalg.eigh(targets) 228 | A_s = AQ[:,:,-rank:] 229 | B_s = BQ[:,:,-rank:] 230 | A_n = AQ[:,:,:-rank] 231 | B_n = BQ[:,:,:-rank] 232 | _, S_s, _= torch.linalg.svd(A_s.conj().transpose(-2,-1) @ B_s) 233 | theta_s = torch.acos(-torch.nn.functional.threshold(-S_s,-1,-1)) 234 | _, S_n, _= torch.linalg.svd(A_n.conj().transpose(-2,-1) @ B_n) 235 | theta_n = torch.acos(-torch.nn.functional.threshold(-S_n,-1,-1)) 236 | return torch.sqrt(torch.sum(theta_s[:,:rank] ** 2, dim=1) + torch.sum(theta_n[:,:m] ** 2, dim=1)) 237 | 238 | def logm(A: torch.Tensor): 239 | lam, V = torch.linalg.eig(A) 240 | V_inv = torch.inverse(V) 241 | log_A_prime = torch.diag(lam.log()) 242 | return V @ log_A_prime @ V_inv 243 | 244 | def inv_sqrtmh(A): # modified from https://github.com/pytorch/pytorch/issues/25481 245 | """Compute sqrtm(inv(A)) where A is a symmetric or Hermitian PD matrix (or a batch of matrices)""" 246 | L, Q = torch.linalg.eigh(A) 247 | zero = torch.zeros((), device=L.device, dtype=L.dtype) 248 | threshold = L.max(-1).values * L.size(-1) * torch.finfo(L.dtype).eps 249 | L = L.where(L > threshold.unsqueeze(-1), zero) # zero out small components 250 | return (Q * (1/L.sqrt().unsqueeze(-2))) @ Q.mH 251 | 252 | # loss function of DCR-G-Aff 253 | def AffInvDist(outputs, targets, source_numbers, angles): 254 | delta = 1e-4 255 | I = 
torch.eye(outputs.size(-1),device=outputs.device).unsqueeze(0) 256 | targets = targets + delta * I 257 | targets_inv_sqrt = inv_sqrtmh(targets) 258 | A = torch.vmap(logm)(targets_inv_sqrt @ outputs @ targets_inv_sqrt) 259 | return torch.linalg.matrix_norm(A,'fro') 260 | 261 | # loss function of DCR-G-Aff for the 6-element MRA 262 | def AffInvDist3(outputs, targets, source_numbers, angles): # delta is 1e-3 263 | delta = 1e-3 264 | I = torch.eye(outputs.size(-1),device=outputs.device).unsqueeze(0) 265 | targets = targets + delta * I 266 | targets_inv_sqrt = inv_sqrtmh(targets) 267 | A = torch.vmap(logm)(targets_inv_sqrt @ outputs @ targets_inv_sqrt) 268 | return torch.linalg.matrix_norm(A,'fro') 269 | 270 | # What if only phi_1 is minimized? 271 | def SignalPhi1(outputs, targets, source_numbers, angles): 272 | rank = source_numbers[0] # assume consistent rank sampling is enabled 273 | _, AQ = torch.linalg.eigh(outputs) 274 | _, BQ = torch.linalg.eigh(targets) 275 | A = AQ[:,:,-rank:] 276 | B = BQ[:,:,-rank:] 277 | _, S, _ = torch.linalg.svd(A.conj().transpose(-2,-1) @ B) 278 | theta = torch.acos(-torch.nn.functional.threshold(-S,-1,-1)) 279 | return theta[:,0] 280 | 281 | def scale_invariant_targets(outputs, targets): 282 | targets_ri = torch.cat((targets.real.unsqueeze(-1),targets.imag.unsqueeze(-1)),-1) 283 | outputs_ri = torch.cat((outputs.real.unsqueeze(-1),outputs.imag.unsqueeze(-1)),-1) 284 | alphas = torch.sum(targets_ri * outputs_ri, dim=[-3,-2,-1], keepdim=True) / torch.sum(targets_ri * targets_ri, dim=[-3,-2,-1], keepdim=True) # this is a real number 285 | return alphas.squeeze(-1) * targets 286 | 287 | def SignalSubspaceTargets(A, source_numbers): 288 | rank = source_numbers[0] 289 | _,Q = torch.linalg.eigh(A) 290 | return Q[:,:,-rank:] @ Q[:,:,-rank:].transpose(-2,-1).conj() 291 | 292 | # ICASSP SI-Cov 293 | def SISDRFrobeniusNorm(outputs, targets, source_numbers, angles): 294 | targets = scale_invariant_targets(outputs, targets) 295 | return - 10 * torch.log10(torch.linalg.matrix_norm(targets,'fro') / FrobeniusNorm(outputs, targets, source_numbers, None) ) 296 | 297 | # ICASSP SI-Sig 298 | def SignalSISDRFrobeniusNorm(outputs, targets, source_numbers, angles): 299 | targets = SignalSubspaceTargets(targets,source_numbers) 300 | targets = scale_invariant_targets(outputs, targets) 301 | return - 10 * torch.log10(torch.linalg.matrix_norm(targets,'fro') / FrobeniusNorm(outputs, targets, source_numbers, None) ) 302 | 303 | loss_dict = { 304 | 'AngleMSE': AngleMSE, 305 | 'OrderedAngleMSE': OrderedAngleMSE, 306 | 'BranchAngleMSE': BranchAngleMSE, 307 | 'BranchOrderedAngleMSE': BranchOrderedAngleMSE, 308 | 'ToepSquare': ToepSquare, 309 | 'FrobeniusNorm': FrobeniusNorm, 310 | 'NoiseSubspaceDist': NoiseSubspaceDist, 311 | 'SignalSubspaceDist': SignalSubspaceDist, 312 | 'AvgSubspaceDist': AvgSubspaceDist, 313 | 'L2SubspaceDist': L2SubspaceDist, 314 | 'AffInvDist': AffInvDist, 315 | 'AffInvDist3': AffInvDist3, 316 | 'SignalChordalDistPA': SignalChordalDistPA, 317 | 'SignalChordalDistOB': SignalChordalDistOB, 318 | 'SignalProjectionDistPA': SignalProjectionDistPA, 319 | 'SignalProjectionDistOB': SignalProjectionDistOB, 320 | 'SignalFubiniStudyDistPA': SignalFubiniStudyDistPA, 321 | 'SignalFubiniStudyDistOB': SignalFubiniStudyDistOB, 322 | 'SignalProcrustesDistPA': SignalProcrustesDistPA, 323 | 'SignalSpectralDistPA': SignalSpectralDistPA, 324 | 'SignalSubspaceDistNoCrsDirect': SignalSubspaceDistNoCrsDirect, 325 | 'SignalSubspaceDistNoCrsGroup': SignalSubspaceDistNoCrsGroup, 326 | 
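# Editor's sketch (not part of the original file): every entry in this table
# shares the signature loss(outputs, targets, source_numbers, angles) and
# returns one scalar per batch element; main.py selects one via
# criterion = loss_dict[args.loss]. A minimal, hypothetical self-check of the
# geodesic loss on random Hermitian PSD batches (torch is imported above):
#
#   X = torch.randn(8, 10, 10, dtype=torch.cfloat)
#   Y = torch.randn(8, 10, 10, dtype=torch.cfloat)
#   outputs, targets = X @ X.mH, Y @ Y.mH      # Hermitian PSD matrices
#   n_src = torch.full((8,), 3)                # consistent rank across the batch
#   dist = SignalSubspaceDist(outputs, targets, n_src, None)  # shape (8,)
#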
'SignalPhi1': SignalPhi1, 327 | 'SISDRFrobeniusNorm': SISDRFrobeniusNorm, 328 | 'SignalSISDRFrobeniusNorm': SignalSISDRFrobeniusNorm 329 | } 330 | 331 | def is_EnEnH(loss): 332 | if 'Noise' in loss: 333 | return True 334 | else: 335 | return False -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | Modified from https://github.com/kjason/DnnNormTimeFreq4DoA/tree/main/SpeechEnhancement 7 | """ 8 | import argparse 9 | from datetime import datetime 10 | from data import CovMapDataset 11 | from train import TrainParam,TrainRegressor 12 | from utils import dir_path, check_device 13 | from models import model_dict 14 | from loss import loss_dict, is_EnEnH 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(description='Train a DNN model to estimate the covariance matrix of the corresponding ULA from a sample covariance of an MRA',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser_resume_group = parser.add_mutually_exclusive_group() 19 | parser_resume_group.add_argument('--resume', dest='resume', action='store_true', help='resume from the last checkpoint',default=True) 20 | parser_resume_group.add_argument('--no-resume', dest='noresume', action='store_true', help='start a new training or overwrite the last one',default=False) 21 | parser.add_argument('--checkpoint_folder',default='./checkpoint/', type=dir_path, help='path to the checkpoint folder') 22 | parser.add_argument('--device', default='cuda:0', type=check_device, help='specify a CUDA or CPU device, e.g., cuda:0, cuda:1, or cpu') 23 | parser.add_argument('--optimizer', default='SGD', choices=['SGD','AdamW'], help='optimizer') 24 | parser.add_argument('--mu', default=0.5, type=float, help='learning rate') 25 | parser.add_argument('--momentum', default=0.5, type=float, help='momentum') 26 | parser_nesterov_group = parser.add_mutually_exclusive_group() 27 | parser_nesterov_group.add_argument('--nesterov', dest='nesterov', action='store_true', help='enable Nesterov momentum',default=True) 28 | parser_nesterov_group.add_argument('--no-nesterov', dest='nonesterov', action='store_true', help='disable Nesterov momentum',default=False) 29 | parser.add_argument('--batch_size', default=4096, type=int, help='training batch size') 30 | parser.add_argument('--val_batch_size', default=4096, type=int, help='validation batch size') 31 | parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay') 32 | parser.add_argument('--mu_scale', default=[1.0], nargs='+', type=float, help='learning rate scaling') 33 | parser.add_argument('--mu_epoch', default=[50], nargs='+', type=int, help='epochs to scale the learning rate (the last element is the total number of epochs)') 34 | parser.add_argument('--milestone', default=[5,10,20,30,40,50,80,100,150], nargs='+', type=int, help='the model trained after these epochs will be saved') 35 | parser.add_argument('--print_every_n_batch', default=10000, type=int, help='print the training status every n batch') 36 | parser.add_argument('--seed_list', default=[0], nargs='+', type=int, help='train models with different random seeds') 37 | parser.add_argument('--model', default='N5_M10_WRN_16_8', choices=list(model_dict.keys()), help='the DNN model') 38 | parser.add_argument('--loss', default='SignalSubspaceDist', choices=list(loss_dict.keys()), help='loss function') 39 | 
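# Editor's note (illustrative, not part of the original file): with the defaults
# below, the fixed training set contains train_L*base_L = 200*10000 = 2,000,000
# MRA/ULA covariance pairs for every entry of --n_sources_train, unless
# --dynamic 1 is set, in which case training examples are generated on the fly.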
parser.add_argument('--train_L', default=200, type=int, help='train_L*base_L training datapoints for every number of sources') 40 | parser.add_argument('--val_L', default=60, type=int, help='val_L*base_L validation datapoints for every number of sources') 41 | parser.add_argument('--base_L', default=10000, type=int, help='base number of datapoints') 42 | parser.add_argument('--num_workers', default=1, type=int, help='Number of workers of the dataloader') 43 | parser.add_argument('--snr_range', default=[-10,20], nargs='+', type=float, help='SNR range') 44 | parser.add_argument('--snr_uniform', default=0, type=int, help='1 or 0. (1): uniformly sample from the snr_range, (0): use the specified snr_list and snr_prob') 45 | parser.add_argument('--snr_list', default=[i for i in range(-11,23,2)], nargs='+', type=int, help='List of SNRs for training and validation') 46 | parser.add_argument('--snr_prob_ratio', default=1, type=float, help='the ratio given by snr_prob(last)/snr_prob(first) where snr_prob increases/descreases linearly') 47 | parser.add_argument('--N_sensors', default=5, type=int, help='N-element MRA') 48 | parser.add_argument('--deg_range', default=[30,150], nargs='+', type=float, help='DoA estimation range in degrees (0 to 180)') 49 | parser.add_argument('--min_sep', default=[3,3,3,3,3,3,3,3,3], nargs='+', type=float, help='List of minimum separations in degrees for the n_sources_train/val (must be a positive number)') 50 | parser.add_argument('--T_snapshots', default=50, type=int, help='T snapshots') 51 | parser.add_argument('--n_sources_train', default=[1,2,3,4,5,6,7,8,9], nargs='+', type=int, help='Number of sources for training') 52 | parser.add_argument('--n_sources_val', default=[1,2,3,4,5,6,7,8,9], nargs='+', type=int, help='Number of sources for validation') 53 | parser.add_argument('--diag_src_cov', default=1, type=int, help='1 or 0. target is (1): the diagonal sample covariance matrix, (0): sample covariance matrix') 54 | parser.add_argument('--use_variance', default=0, type=int, help='Use the covariance (1) or diagonal sample covariance (0) for the target (only effective if diag_src_cov=1)') 55 | parser.add_argument('--dynamic', default=0, type=int, help='1 or 0. (1): dynamically generate training data, (0): generate a fixed training dataset') 56 | parser.add_argument('--consistent_rank_sampling', default=0, type=int, help='1 or 0. (1): use ConsistentRankBatchSampler, (0): use the default random sampling') 57 | parser.add_argument('--fp16', default=0, type=int, help='1 or 0. (1): use mixed precision training float16 and float32, (0): use the default float32') 58 | parser.add_argument('--onecycle', default=1, type=int, help='1 or 0. (1): use OneCycleLR, (0): use LambdaLR') 59 | parser.add_argument('--normalization', default='disabled', choices=['disabled','max','sensors'], help='how to normalize the covariance matrix') 60 | parser.add_argument('--random_power', default=0, type=int, help='1 or 0. (1): random source power, (0): equal source power') 61 | parser.add_argument('--power_range', default=[0.1,1.0], nargs='+', type=float, help='range of the random power') 62 | parser.add_argument('--total_power_one', default=0, type=int, help='1 or 0. 
    parser.add_argument('--d', default=0.01, type=float, help='sensor spacing')
    parser.add_argument('--lam', default=0.02, type=float, help='wavelength lambda')
    parser.add_argument('--gain_bias', default=[0.0,0.2,0.2,0.2,0.2,0.2,-0.2,-0.2,-0.2,-0.2], nargs='+', type=float, help='Gain bias')
    parser.add_argument('--phase_bias_deg', default=[0,-30,-30,-30,-30,-30,30,30,30,30], nargs='+', type=float, help='Phase bias in degrees')
    parser.add_argument('--position_bias', default=[0.0,-0.2,-0.2,-0.2,-0.2,-0.2,0.2,0.2,0.2,0.2], nargs='+', type=float, help='Position bias')
    parser.add_argument('--mc_mag_angle', default=[0.3,60], nargs='+', type=float, help='magnitude and phase (in degrees) of the mutual coupling coefficient')
    parser.add_argument('--rho', default=0.0, type=float, help='A number in [0,1] describing the degree of array imperfections')
    parser.add_argument('--mix_rho', default=0, type=int, help='1 or 0. (1): mix different rhos in [0,rho], (0): use the fixed given rho')
    parser.add_argument('--save_dataset', default=0, type=int, help='1 or 0. (1): save the datasets, (0): do not save')

    args = parser.parse_args()

    train_seed = 1000
    val_seed = 2000 # must be different from the train_seed

    save_dataset = bool(args.save_dataset)

    d = args.d
    lam = args.lam
    N_sensors = args.N_sensors
    T_snapshots = args.T_snapshots
    train_num_sources = args.n_sources_train
    validation_num_sources = args.n_sources_val
    snr_range = args.snr_range
    snr_uniform = bool(args.snr_uniform)
    snr_list = args.snr_list
    snr_prob_ratio = args.snr_prob_ratio
    # snr_prob is a linear ramp over snr_list normalized to sum to one: with
    # r = snr_prob_ratio, n = len(snr_list), and Z = (1+r)*n/2, the probabilities
    # p_i = 1/Z + i*(r-1)/(Z*(n-1)) satisfy sum(p) = 1 and p[-1]/p[0] = r, so the
    # last SNR in snr_list is sampled r times as often as the first
    snr_prob_inc = ((1+snr_prob_ratio) * len(snr_list))/2
    snr_prob = [1/snr_prob_inc+(i*(snr_prob_ratio-1)/(snr_prob_inc*(len(snr_list)-1))) for i in range(len(snr_list))]
    deg_range = args.deg_range
    min_sep = args.min_sep
    train_L = args.train_L
    base_L = args.base_L
    val_L = args.val_L
    diag_src_cov = bool(args.diag_src_cov)
    use_variance = bool(args.use_variance)
    dynamic = bool(args.dynamic)
    consistent_rank_sampling = bool(args.consistent_rank_sampling)
    fp16 = bool(args.fp16)
    onecycle = bool(args.onecycle)
    normalization = args.normalization
    random_power = bool(args.random_power)
    power_range = args.power_range
    total_power_one = bool(args.total_power_one)
    optimizer = args.optimizer
    gain_bias = args.gain_bias
    phase_bias_deg = args.phase_bias_deg
    position_bias = args.position_bias
    mc_mag_angle = args.mc_mag_angle
    rho = args.rho
    mix_rho = bool(args.mix_rho)

    if len(min_sep) != len(train_num_sources):
        raise ValueError(f"len(min_sep)={len(min_sep)} does not match len(n_sources_train)={len(train_num_sources)}")

    trainset = CovMapDataset(mode='train',L=train_L,d=d,lam=lam,N_sensors=N_sensors,T_snapshots=T_snapshots,num_sources=train_num_sources,
                             snr_range=snr_range,snr_uniform=snr_uniform,snr_list=snr_list,snr_prob=snr_prob,seed=train_seed,deg_range=deg_range,
                             min_sep=min_sep,diag_src_cov=diag_src_cov,use_variance=use_variance,gain_bias=gain_bias,phase_bias_deg=phase_bias_deg,
                             position_bias=position_bias,mc_mag_angle=mc_mag_angle,rho=rho,mix_rho=mix_rho,base_L=base_L,dynamic=dynamic,
random_power=random_power,power_range=power_range,total_power_one=total_power_one,normalization=normalization, 123 | device='cpu',save_dataset=save_dataset) 124 | 125 | validationset = CovMapDataset(mode='validation',L=val_L,d=d,lam=lam,N_sensors=N_sensors,T_snapshots=T_snapshots,num_sources=validation_num_sources, 126 | snr_range=snr_range,snr_uniform=snr_uniform,snr_list=snr_list,snr_prob=snr_prob,seed=val_seed,deg_range=deg_range, 127 | min_sep=min_sep,diag_src_cov=diag_src_cov,use_variance=use_variance,gain_bias=gain_bias,phase_bias_deg=phase_bias_deg, 128 | position_bias=position_bias,mc_mag_angle=mc_mag_angle,rho=rho,mix_rho=mix_rho,base_L=base_L,dynamic=False, 129 | random_power=random_power,power_range=power_range,total_power_one=total_power_one,normalization=normalization, 130 | device='cpu',save_dataset=save_dataset) 131 | 132 | criterion = loss_dict[args.loss] 133 | 134 | for seed in args.seed_list: 135 | 136 | name = (f"{args.model}_t={train_L}_v={val_L}_n={N_sensors}_loss={args.loss}_mu={args.mu}_mo={args.momentum}_bs={args.batch_size}_epoch={args.mu_epoch[-1]}" 137 | f"_wd={args.weight_decay}_seed={seed}_nsrc={str(train_num_sources)}_T={T_snapshots}_rg={str(deg_range)}_sep={str([int(s) for s in min_sep])}_rho={rho}_mix={args.mix_rho}" 138 | f"_snr={str(snr_range)}_uni={args.snr_uniform}_spr={round(snr_prob_ratio,1)}_dg={args.diag_src_cov}_uv={args.use_variance}" 139 | f"_crs={args.consistent_rank_sampling}_dy={args.dynamic}_rp={args.random_power}_pr={str(power_range)}_tpo={int(total_power_one)}_nor={normalization}" 140 | f"_oc={args.onecycle}").replace(' ','').replace('.','').replace(',','').replace('[','').replace(']','') 141 | 142 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [main.py] start the training task {name}") 143 | 144 | array_setting = {'N_sensors': N_sensors, 'd': d, 'lam': lam, 'normalization': normalization, 'model': args.model,'EnEnH': is_EnEnH(args.loss)} 145 | 146 | tp = TrainParam( 147 | mu = args.mu, 148 | mu_scale = args.mu_scale, 149 | mu_epoch = args.mu_epoch, 150 | weight_decay = args.weight_decay, 151 | momentum = args.momentum, 152 | batch_size = args.batch_size, 153 | val_batch_size = args.val_batch_size, 154 | nesterov = args.nesterov and not args.nonesterov, 155 | onecycle = onecycle, 156 | optimizer = optimizer 157 | ) 158 | 159 | r = TrainRegressor( 160 | name = name, 161 | net = model_dict[args.model], 162 | tp = tp, 163 | trainset = trainset, 164 | validationset = validationset, 165 | criterion = criterion, 166 | device = args.device, 167 | seed = seed, 168 | resume = args.resume and not args.noresume, 169 | checkpoint_folder = args.checkpoint_folder, 170 | num_workers = args.num_workers, 171 | consistent_rank_sampling = consistent_rank_sampling, 172 | milestone = args.milestone, 173 | print_every_n_batch = args.print_every_n_batch, 174 | fp16 = fp16, 175 | meta_data = array_setting 176 | ).train() 177 | 178 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [main.py] training task {name} is completed\n") -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | Modified from https://github.com/kjason/DnnNormTimeFreq4DoA/tree/main/SpeechEnhancement 7 | 8 | https://arxiv.org/abs/2408.16605 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from utils import ComplexMat2RealImagMat, 
RealImagMat2ComplexMat, RealImagMat2GramComplexMat, RealVec2HermitianMat, RealVec2HermitianToeplitzMat 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, in_planes, mid_planes, out_planes, stride=1, bias=False, bn=True): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=bias) 19 | self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=bias) 20 | if bn is True: 21 | self.bn1 = nn.BatchNorm2d(in_planes) 22 | self.bn2 = nn.BatchNorm2d(mid_planes) 23 | if stride != 1 or in_planes != out_planes: 24 | self.projection = nn.Sequential( 25 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) 26 | ) 27 | 28 | def forward(self, x): 29 | y = self.bn1(x) if hasattr(self,'bn1') else x 30 | y = F.relu(y) 31 | shortcut = self.projection(y) if hasattr(self, 'projection') else x 32 | y = self.conv1(y) 33 | y = self.bn2(y) if hasattr(self,'bn2') else y 34 | v = F.relu(y) 35 | out = self.conv2(v) + shortcut if hasattr(self,'conv2') else 0 36 | return out,v 37 | 38 | # see Fig. 2 in the paper 39 | class ResNet(nn.Module): 40 | def __init__(self, block, num_blocks, num_out_channels, num_mid_channels, M_sensor_ULA, bias, bn, out_type): 41 | super(ResNet, self).__init__() 42 | assert(len(num_blocks)==len(num_out_channels)), "size does not match between num_blocks and num_out_channels" 43 | assert(len(num_blocks)==len(num_mid_channels)), "size does not match between num_blocks and num_mid_channels" 44 | self.bias = bias 45 | self.bn = bn 46 | self.in_planes = num_out_channels[0] 47 | self.num_blocks = num_blocks 48 | self.M_sensor_ULA = M_sensor_ULA 49 | self.out_type = out_type 50 | self.expansion = nn.Conv2d(2, num_out_channels[0], kernel_size=3, stride=1, padding=1, bias=bias) 51 | self.stage = nn.ModuleList() 52 | self.stage.append(self._creat_block_seq(block, num_mid_channels[0], num_out_channels[0], num_blocks[0], stride=1)) 53 | for j in range(1,len(num_blocks)): 54 | self.stage.append(self._creat_block_seq(block, num_mid_channels[j], num_out_channels[j], num_blocks[j], stride=2)) 55 | if bn is True: 56 | self.final_bn = nn.BatchNorm2d(num_out_channels[-1]) 57 | if out_type == 'direct': 58 | self.linear = nn.Linear(num_out_channels[-1],M_sensor_ULA - 1) 59 | elif out_type == 'branch': 60 | self.linear = nn.ModuleList() 61 | for j in range(1,M_sensor_ULA): 62 | self.linear.append(nn.Linear(num_out_channels[-1],j)) 63 | elif out_type == 'hermitian': 64 | self.linear = nn.Linear(num_out_channels[-1],M_sensor_ULA**2) 65 | elif out_type == 'toep': 66 | self.linear = nn.Linear(num_out_channels[-1],2*M_sensor_ULA-1) 67 | elif out_type == 'gram': 68 | self.linear = nn.Linear(num_out_channels[-1],2*M_sensor_ULA**2) 69 | else: 70 | self.linear = nn.Linear(num_out_channels[-1],2*M_sensor_ULA**2) 71 | 72 | def _creat_block_seq(self, block, mid_planes, out_planes, num_blocks, stride): 73 | stride_seq = [stride] + [1]*(num_blocks-1) 74 | block_seq = nn.ModuleList() 75 | for stride in stride_seq: 76 | block_seq.append(block(self.in_planes, mid_planes, out_planes, stride, self.bias, self.bn)) 77 | self.in_planes = out_planes 78 | return block_seq 79 | 80 | def forward(self, x): 81 | x = ComplexMat2RealImagMat(x) 82 | out = self.expansion(x) 83 | for j in range(len(self.num_blocks)): 84 | for i in range(self.num_blocks[j]): 85 | out,_ = self.stage[j][i](out) 86 | out = self.final_bn(out) if hasattr(self,'final_bn') else out 87 | out = F.relu(out) 88 | out = 
F.avg_pool2d(out, out.size(2)) 89 | out = out.view(out.size(0), -1) 90 | if self.out_type == 'direct': 91 | out = self.linear(out) 92 | return out 93 | elif self.out_type == 'branch': 94 | return [self.linear[j](out) for j in range(self.M_sensor_ULA-1)] 95 | elif self.out_type == 'hermitian': 96 | out = self.linear(out) 97 | return RealVec2HermitianMat(out) 98 | elif self.out_type == 'toep': 99 | out = self.linear(out) 100 | return RealVec2HermitianToeplitzMat(out) 101 | elif self.out_type == 'gram': 102 | out = self.linear(out) 103 | return RealImagMat2GramComplexMat(torch.reshape(out,(-1,2,self.M_sensor_ULA,self.M_sensor_ULA))) 104 | else: 105 | out = self.linear(out) 106 | return RealImagMat2ComplexMat(torch.reshape(out,(-1,2,self.M_sensor_ULA,self.M_sensor_ULA))) 107 | 108 | # models 109 | 110 | def N4_M7_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],7,True,False,out_type='gram') 111 | def N5_M10_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],10,True,False,out_type='gram') 112 | def N6_M14_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],14,True,False,out_type='gram') 113 | 114 | def N4_M7_toep_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],7,True,False,out_type='toep') 115 | def N5_M10_toep_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],10,True,False,out_type='toep') 116 | def N6_M14_toep_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],14,True,False,out_type='toep') 117 | 118 | def N4_M7_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],7,True,False,out_type='gram') 119 | def N5_M10_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],10,True,False,out_type='gram') 120 | def N6_M14_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],14,True,False,out_type='gram') 121 | 122 | def N4_M7_toep_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],7,True,False,out_type='toep') 123 | def N5_M10_toep_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],10,True,False,out_type='toep') 124 | def N6_M14_toep_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],14,True,False,out_type='toep') 125 | 126 | def N5_M10_WRN_40_4(): return ResNet(BasicBlock,[6,6,6],[64,128,256],[64,128,256],10,True,False,out_type='gram') 127 | def N5_M10_WRN_28_10(): return ResNet(BasicBlock,[4,4,4],[160,320,640],[160,320,640],10,True,False,out_type='gram') 128 | 129 | def N5_M10_hermitian_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],10,True,False,out_type='hermitian') 130 | 131 | def Direct_N5_M10_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],10,True,False,out_type='direct') 132 | def Direct_N5_M10_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],10,True,False,out_type='direct') 133 | 134 | def Branch_N5_M10_ResNet_20(): return ResNet(BasicBlock,[3,3,3],[16,32,64],[16,32,64],10,True,False,out_type='branch') 135 | def Branch_N5_M10_WRN_16_8(): return ResNet(BasicBlock,[2,2,2],[128,256,512],[128,256,512],10,True,False,out_type='branch') 136 | 137 | model_dict = { 138 | 'N4_M7_ResNet_20': N4_M7_ResNet_20, 139 | 'N5_M10_ResNet_20': N5_M10_ResNet_20, 140 | 'N6_M14_ResNet_20': N6_M14_ResNet_20, 141 | 'N4_M7_toep_ResNet_20': N4_M7_toep_ResNet_20, 142 | 'N5_M10_toep_ResNet_20': N5_M10_toep_ResNet_20, 143 | 'N6_M14_toep_ResNet_20': N6_M14_toep_ResNet_20, 144 | 'N4_M7_WRN_16_8': N4_M7_WRN_16_8, 145 | 
'N5_M10_WRN_16_8': N5_M10_WRN_16_8, 146 | 'N6_M14_WRN_16_8': N6_M14_WRN_16_8, 147 | 'N4_M7_toep_WRN_16_8': N4_M7_toep_WRN_16_8, 148 | 'N5_M10_toep_WRN_16_8': N5_M10_toep_WRN_16_8, 149 | 'N6_M14_toep_WRN_16_8': N6_M14_toep_WRN_16_8, 150 | 'N5_M10_WRN_40_4': N5_M10_WRN_40_4, 151 | 'N5_M10_WRN_28_10': N5_M10_WRN_28_10, 152 | 'N5_M10_hermitian_WRN_16_8' : N5_M10_hermitian_WRN_16_8, 153 | 'Direct_N5_M10_ResNet_20': Direct_N5_M10_ResNet_20, 154 | 'Direct_N5_M10_WRN_16_8': Direct_N5_M10_WRN_16_8, 155 | 'Branch_N5_M10_ResNet_20': Branch_N5_M10_ResNet_20, 156 | 'Branch_N5_M10_WRN_16_8': Branch_N5_M10_WRN_16_8 157 | } -------------------------------------------------------------------------------- /performance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | """ 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import os 10 | import re 11 | import scipy.io 12 | from datetime import datetime 13 | from tqdm import tqdm 14 | from utils import dir_path, file_path, check_device 15 | from eval.crlb import uncorrelated_CRLB 16 | from data import Cov2DoADataset 17 | from DoA import CovMRA2ULA_DA, CovMRA2ULA_SS 18 | from SDP.SDP import SDPCovMRA2ULA_Wasserstein_SDPT3, SDPCovMRA2ULA_SPA_SDPT3, SDPSnapshotMRA2ULA_ProxCov_SDPT3, SDPCovMRA2ULA_StructCovMLE_SDPT3 19 | from predict import Predictor 20 | 21 | def get_name(s: str): 22 | patterns = [r'/([^;]*)_t=',r't=([^;]*)_v=',r'loss=([^;]*)_mu=',r'mu=([^;]*)_mo=', 23 | r'mo=([^;]*)_bs=',r'bs=([^;]*)_epoch=',r'epoch=([^;]*)_wd',r'wd=([^;]*)_seed=',r'spr=([^;]*)_dg=',r'sep=([^;]*)_rho=',r'nsrc=([^;]*)_T=',r'T=([^;]*)_rg=', 24 | r'rg=([^;]*)_sep=',r'rho=([^;]*)_mix=',r'mix=([^;]*)_snr=',r'snr=([^;]*)_uni=',r'uni=([^;]*)_spr=',r'dg=([^;]*)_uv=', 25 | r'uv=([^;]*)_crs=',r'crs=([^;]*)_dy=',r'rp=([^;]*)_pr=',r'_pr=([^;]*)_tpo=',r'tpo=([^;]*)_nor=',r'nor=([^;]*)_oc='] 26 | name = "" 27 | for i in patterns: 28 | x = re.findall(i,s) 29 | if len(x) != 0: 30 | name += x[0] 31 | name = name.replace('./','').replace('/','_').replace('.','_').replace(',','').replace('[','').replace(']','').replace('-','') 32 | return name 33 | 34 | def display_evaluation_status(N_sensors: int, num_sources: int, T_snapshots: int, SNR: float, crb: float, trials: int, mse: np.ndarray, bias: np.ndarray, success: np.ndarray, num_random_thetas: int, j: int, i: int, k: int, rho: float): 35 | t = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 36 | tqdm.write(f'{t} [performance.py] N: {N_sensors} | # of trials: '+f'{trials}'.rjust(6)+' | # of sources: '+f'{num_sources}'.rjust(3)+' | T_snapshots: '+f'{T_snapshots}'.rjust(4) 37 | +' | SNR (dB): '+f'{SNR}'.rjust(6)+' | rho: '+f'{rho}'.rjust(3)+' '*46+' | mean CRB: '+f'{crb*(180/np.pi)**2:.3f}'.rjust(8)+' deg^2 ('+f'{crb:.7f}'.rjust(9)+' rad^2)') 38 | for m in mse.keys(): 39 | if num_random_thetas == 1: 40 | bias_str = '| bias: '+f'{bias[m][j,i,k]:.7f}'.rjust(9) 41 | else: 42 | bias_str = '' 43 | tqdm.write(f'{t} [performance.py] '+f'{N_sensors}-MRA '.rjust(7)+m.rjust(135)+' | '.ljust(8)+'MSE: '+f'{mse[m][j,i,k]*(180/np.pi)**2:.3f}'.rjust(8)+' deg^2 ('+f'{mse[m][j,i,k]:.7f}'.rjust(9)+' rad^2) '+bias_str+' | success: '+f'{round(success[m][j,i,k])}/{trials}'.rjust(8)) 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser(description='Performance evaluation of DoA estimation methods',formatter_class=argparse.ArgumentDefaultsHelpFormatter) 47 | parser.add_argument('--results_folder', default='./results/', type=dir_path, 
help='path to the results folder')
    parser.add_argument('--resume_from', default=None, type=file_path, help='resume from the previous unfinished checkpoint file path')
    parser.add_argument('--device', default='cuda:0', type=check_device, help='specify a CUDA or CPU device, e.g., cuda:0, cuda:1, or cpu')
    parser.add_argument('--batch_size', default=256, type=int, help='evaluation batch size')
    parser.add_argument('--num_workers', default=1, type=int, help='number of workers for the dataloader')
    parser.add_argument('--seed', default=9, type=int, help='random seed')
    parser.add_argument('--num_random_thetas', default=100, type=int, help='Number of random angles for evaluation')
    parser.add_argument('--trials_per_theta', default=100, type=int, help='Number of trials per random theta (number of random source signals per theta)')
    parser.add_argument('--SNR_list', default=[i for i in range(-10,21,2)], nargs='+', type=float, help='List of SNRs for evaluation')
    parser.add_argument('--T_snapshots_list', default=[i for i in range(50,51)], nargs='+', type=int, help='List of snapshots for evaluation')
    parser.add_argument('--num_sources_list', default=[1,2,3,4,5,6,7,8,9], nargs='+', type=int, help='Number of sources for evaluation')
    parser.add_argument('--provide_noise_var', default=1, type=int, help='1 or 0. (1): provide the ground truth noise variance, (0): do not provide the noise variance')
    parser.add_argument('--random_power', default=0, type=int, help='1 or 0. (1): random source power, (0): equal source power')
    parser.add_argument('--power_range', default=[0.1,1.0], nargs='+', type=float, help='range of the random power')
    parser.add_argument('--total_power_one', default=0, type=int, help='1 or 0. (1): normalize the power of sources such that the total source power is one, (0): no normalization')
    parser.add_argument('--evenly_distributed', default=0, type=int, help='1 or 0. (1): source angles are evenly distributed, (0): randomly distributed')
    parser.add_argument('--return_snapshots', default=0, type=int, help='1 or 0. (1): return snapshots as input, (0): return covariance matrices as input')
    parser.add_argument('--d', default=0.01, type=float, help='sensor spacing')
    parser.add_argument('--lam', default=0.02, type=float, help='wavelength lambda')
    parser.add_argument('--N_sensors', default=5, type=int, help='N-element MRA')
    parser.add_argument('--deg_range', default=[30,150], nargs='+', type=int, help='DoA estimation range in degrees (0 to 180)')
    parser.add_argument('--min_sep', default=[4,4,4,4,4,4,4,4,4], nargs='+', type=float, help='List of minimum separations in degrees for the num_sources_list (must be a positive number)')
    parser.add_argument('--save_dataset', default=1, type=int, help='1 or 0. (1): save all datasets that are going to be generated, (0): do not save')
    parser.add_argument('--gain_bias', default=[0.0,0.2,0.2,0.2,0.2,0.2,-0.2,-0.2,-0.2,-0.2], nargs='+', type=float, help='Gain bias')
    parser.add_argument('--phase_bias_deg', default=[0,-30,-30,-30,-30,-30,30,30,30,30], nargs='+', type=float, help='Phase bias in degrees')
    parser.add_argument('--position_bias', default=[0.0,-0.2,-0.2,-0.2,-0.2,-0.2,0.2,0.2,0.2,0.2], nargs='+', type=float, help='Position bias')
    parser.add_argument('--mc_mag_angle', default=[0.3,60], nargs='+', type=float, help='magnitude and phase (in degrees) of the mutual coupling coefficient')
    parser.add_argument('--rho', default=0.0, type=float, help='A number in [0,1] describing the degree of array imperfections')
    parser.add_argument('--SPA', default=0, type=int, help='1 or 0. (1): evaluate the performance of SPA, (0): do not evaluate SPA')
    parser.add_argument('--SPA_noisevar', default=0, type=int, help='1 or 0. (1): evaluate the performance of SPA using the noise variance, (0): do not evaluate SPA using the noise variance')
    parser.add_argument('--Wasserstein', default=0, type=int, help='1 or 0. (1): evaluate the performance of Wasserstein, (0): do not evaluate Wasserstein')
    parser.add_argument('--ProxCov', default=0, type=int, help='1 or 0. (1): evaluate the performance of ProxCov, (0): do not evaluate ProxCov')
    parser.add_argument('--ProxCov_epsilon', default=1e-5, type=float, help='the epsilon parameter of ProxCov')
    parser.add_argument('--StructCovMLE', default=0, type=int, help='1 or 0. (1): evaluate the performance of StructCovMLE, (0): do not evaluate StructCovMLE')
    parser.add_argument('--StructCovMLE_noisevar', default=0, type=int, help='1 or 0. (1): evaluate the performance of StructCovMLE using the noise variance, (0): do not evaluate StructCovMLE using the noise variance')
    parser.add_argument('--StructCovMLE_epsilon', default=1e-3, type=float, help='the threshold of the relative change as the stopping criterion of StructCovMLE')
    parser.add_argument('--StructCovMLE_max_iter', default=100, type=int, help='the maximum number of iterations of StructCovMLE')
    parser.add_argument('--DA', default=0, type=int, help='1 or 0. (1): evaluate the performance of DA, (0): do not evaluate DA')
    parser.add_argument('--SS', default=0, type=int, help='1 or 0. (1): evaluate the performance of SS, (0): do not evaluate SS')
    parser.add_argument('--cov_models', nargs='+', type=str, help='Path to the DNN model checkpoint folder', default=None)

    args = parser.parse_args()
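    # Example invocation (illustrative only; the checkpoint path below is a
    # placeholder, and the curated evaluation commands live in eval/perf_*.sh):
    #
    #   python3 performance.py --N_sensors 5 --SPA 1 --DA 1 --SS 1 \
    #       --cov_models ./checkpoint/<trained_model_folder>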
    save_dataset = bool(args.save_dataset)
    results_folder = args.results_folder
    seed = args.seed
    d = args.d
    lam = args.lam
    N_sensors = args.N_sensors
    deg_range = args.deg_range
    min_sep = args.min_sep
    # use 8,8,8,8,10,11,11,12,13 if meaningful CRBs are needed (the minimum separations need to be sufficiently large)
    provide_noise_var = bool(args.provide_noise_var)
    random_power = bool(args.random_power)
    power_range = args.power_range
    total_power_one = bool(args.total_power_one)
    evenly_distributed = bool(args.evenly_distributed)
    return_snapshots = bool(args.return_snapshots)
    num_sources_list = args.num_sources_list
    T_snapshots_list = args.T_snapshots_list
    SNR_list = args.SNR_list
    num_random_thetas = args.num_random_thetas
    trials_per_theta = args.trials_per_theta
    device = args.device
    batch_size = args.batch_size
    ProxCov_epsilon = args.ProxCov_epsilon
    StructCovMLE_epsilon = args.StructCovMLE_epsilon
    StructCovMLE_max_iter = args.StructCovMLE_max_iter

    if len(min_sep) != len(num_sources_list):
        raise ValueError(f"len(min_sep)={len(min_sep)} does not match len(num_sources_list)={len(num_sources_list)}")

    gain_bias = args.gain_bias
    phase_bias_deg = args.phase_bias_deg
    position_bias = args.position_bias
    mc_mag_angle = args.mc_mag_angle
    rho = args.rho

    if evenly_distributed is True:
        if num_random_thetas != 1:
            raise ValueError("num_random_thetas should be 1 because evenly_distributed is True")

    # DoA predictors
    methods = {}
    if args.DA:
        DA = CovMRA2ULA_DA()
        methods.update({'DA': DA})
    if args.SS:
        SS = CovMRA2ULA_SS()
        methods.update({'SS': SS})
    if args.Wasserstein:
        Wasserstein = SDPCovMRA2ULA_Wasserstein_SDPT3(N_sensors)
        methods.update({'Wasserstein': Wasserstein})
    if args.SPA:
        SPA = SDPCovMRA2ULA_SPA_SDPT3(N_sensors,False)
        methods.update({'SPA': SPA})
    if args.SPA_noisevar:
        SPA_noisevar = SDPCovMRA2ULA_SPA_SDPT3(N_sensors,True)
        methods.update({'SPA_noisevar': SPA_noisevar})
    if args.ProxCov:
        ProxCov = SDPSnapshotMRA2ULA_ProxCov_SDPT3(N_sensors,ProxCov_epsilon)
        methods.update({'ProxCov': ProxCov})
    if args.StructCovMLE:
        StructCovMLE = SDPCovMRA2ULA_StructCovMLE_SDPT3(N_sensors,StructCovMLE_epsilon,StructCovMLE_max_iter,False)
        methods.update({'StructCovMLE': StructCovMLE})
    if args.StructCovMLE_noisevar:
        StructCovMLE_noisevar = SDPCovMRA2ULA_StructCovMLE_SDPT3(N_sensors,StructCovMLE_epsilon,StructCovMLE_max_iter,True)
        methods.update({'StructCovMLE_noisevar': StructCovMLE_noisevar})
    if args.cov_models is not None:
        for m in args.cov_models:
            name = get_name(m)
            cov_model = Predictor(m,device=device)
            if cov_model.isfunctional:
                methods.update({name: cov_model})
    if len(methods.keys()) == 0:
        raise ValueError("No method. 
Stop performance evaluation.") 162 | else: 163 | print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [performance.py] methods to be evaluated: {list(methods.keys())}") 164 | 165 | # total number of trials 166 | trials_per_src_num_and_snr = num_random_thetas*trials_per_theta 167 | total_trials = trials_per_src_num_and_snr*len(num_sources_list)*len(T_snapshots_list)*len(SNR_list) 168 | 169 | # initialize result placeholders 170 | finished = False 171 | 172 | if args.resume_from != None: 173 | checkpoint = np.load(args.resume_from,allow_pickle=True).item() 174 | # check 175 | if ( 176 | checkpoint['d'] != d or checkpoint['lam'] != lam or checkpoint['N_sensors'] != N_sensors or checkpoint['deg_range'] != deg_range or checkpoint['min_sep'] != min_sep or checkpoint['random_power'] != random_power or 177 | checkpoint['total_power_one'] != total_power_one or checkpoint['evenly_distributed'] != evenly_distributed or checkpoint['trials_per_theta'] != trials_per_theta or checkpoint['num_random_thetas'] != num_random_thetas or 178 | checkpoint['num_sources_list'] != num_sources_list or checkpoint['SNR_list'] != SNR_list or checkpoint['T_snapshots_list'] != T_snapshots_list or checkpoint['rho'] != rho or checkpoint['gain_bias'] != gain_bias or 179 | checkpoint['phase_bias_deg'] != phase_bias_deg or checkpoint['position_bias'] != position_bias or checkpoint['mc_mag_angle'] != mc_mag_angle or checkpoint['ProxCov_epsilon'] != ProxCov_epsilon or 180 | checkpoint['StructCovMLE_epsilon'] != StructCovMLE_epsilon or checkpoint['StructCovMLE_max_iter'] != StructCovMLE_max_iter 181 | ): 182 | raise ValueError("Hyperparameters from resume_from do not match with the current setting") 183 | # load 184 | prev_result_path = args.resume_from 185 | crb_accumulated = checkpoint['crb_accumulated'] 186 | crb_total = checkpoint['crb_total'] 187 | crb_mean = checkpoint['crb_mean'] 188 | method_accumulated = {} 189 | method_total = {} 190 | method_success = {} 191 | method_mse = {} 192 | method_accu_doa = {} 193 | method_bias = {} 194 | for m in methods.keys(): 195 | method_mse.update({m:checkpoint[m]}) 196 | method_success.update({m:checkpoint[f'suc_{m}']}) 197 | method_accumulated.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 198 | method_total.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 199 | method_accu_doa.update({m:checkpoint[f'adoa_{m}']}) 200 | method_bias.update({m:checkpoint[f'bias_{m}']}) 201 | t0 = checkpoint['t0'] 202 | else: 203 | prev_result_path = None 204 | crb_accumulated = np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list))) 205 | crb_total = np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list))) 206 | crb_mean = np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list))) 207 | method_accumulated = {} 208 | method_total = {} 209 | method_success = {} 210 | method_mse = {} 211 | method_accu_doa = {} 212 | method_bias = {} 213 | for m in methods.keys(): 214 | method_accumulated.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 215 | method_total.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 216 | method_success.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 217 | method_mse.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 218 | method_accu_doa.update({m:np.zeros((max(num_sources_list),len(T_snapshots_list),len(num_sources_list),len(SNR_list)))}) 219 | 
            method_bias.update({m:np.zeros((len(T_snapshots_list),len(num_sources_list),len(SNR_list)))})
        t0 = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')

    # evaluation
    with tqdm(total=total_trials) as pbar:
        for j in range(len(T_snapshots_list)):
            for i in range(len(num_sources_list)):
                for k in range(len(SNR_list)):
                    # skip if the results are available in the checkpoint
                    if crb_total[j,i,k] != 0:
                        pbar.update(n=trials_per_src_num_and_snr)
                        continue
                    # create or load a dataset
                    eval_dataset = Cov2DoADataset(mode='eval',d=d,lam=lam,N_sensors=N_sensors,T_snapshots=T_snapshots_list[j],num_sources=num_sources_list[i],snr_range=[SNR_list[k],SNR_list[k]],
                                                  seed=seed,deg_range=deg_range,min_sep=min_sep[i],L=num_random_thetas,base_L=trials_per_theta,gain_bias=gain_bias,phase_bias_deg=phase_bias_deg,
                                                  position_bias=position_bias,mc_mag_angle=mc_mag_angle,rho=rho,mix_rho=False,provide_noise_var=provide_noise_var,random_power=random_power,power_range=power_range,
                                                  total_power_one=total_power_one,evenly_distributed=evenly_distributed,return_snapshots=return_snapshots,device='cpu',save_dataset=save_dataset)
                    dataloader = torch.utils.data.DataLoader(eval_dataset,batch_size=batch_size,shuffle=False,num_workers=0,pin_memory=True,drop_last=False)

                    # evaluate each method on the given dataset
                    with torch.no_grad():
                        for idx, x in enumerate(dataloader):
                            if provide_noise_var is True:
                                data_in, noise_var, DoA_gt = x[0], x[1], x[2]
                            else:
                                data_in, DoA_gt = x[0], x[1]
                            for m in methods.keys():
                                if provide_noise_var is True:
                                    DoA_est, successes = methods[m].get_DoA(data_in,num_sources_list[i],return_snapshots,noise_var)
                                else:
                                    DoA_est, successes = methods[m].get_DoA(data_in,num_sources_list[i],return_snapshots)
                                method_accu_doa[m][:num_sources_list[i],j,i,k] += torch.sum(DoA_est[successes],dim=0).numpy()
                                method_success[m][j,i,k] += sum(successes)
                                method_total[m][j,i,k] += DoA_est.shape[1] * sum(successes)
                                method_accumulated[m][j,i,k] += torch.sum((DoA_est[successes] - DoA_gt[successes]) ** 2).numpy()
                            # use the realized size of this minibatch; do not overwrite batch_size,
                            # which every dataloader created above depends on
                            n_in_batch = DoA_gt.size(0)
                            for b in range(n_in_batch):
                                temp = np.diag(uncorrelated_CRLB(DoA_gt[b,:], N_sensors, d, lam, SNR_list[k], T_snapshots_list[j],total_power_one))
                                crb_total[j,i,k] += temp.size
                                crb_accumulated[j,i,k] += np.sum(temp)
                            pbar.update(n=n_in_batch)

                    # compute the MSE and bias
                    for m in methods.keys():
                        method_mse[m][j,i,k] = method_accumulated[m][j,i,k] / method_total[m][j,i,k]
                        method_bias[m][j,i,k] = np.sum(abs(method_accu_doa[m][:num_sources_list[i],j,i,k] / method_success[m][j,i,k] - DoA_gt[0,:].numpy())) / num_sources_list[i]
                    crb_mean[j,i,k] = crb_accumulated[j,i,k] / crb_total[j,i,k]

                    # display current results
                    display_evaluation_status(N_sensors,num_sources_list[i],T_snapshots_list[j],SNR_list[k],crb_mean[j,i,k],len(eval_dataset),method_mse,method_bias,method_success,num_random_thetas,j,i,k,rho)

                    # save settings and results
                    t = datetime.now().strftime('%m-%d_%H_%M_%S')
                    if j + 1 == len(T_snapshots_list) and i + 1 == len(num_sources_list) and k + 1 == len(SNR_list):
                        finished = True
                    result_path = os.path.join(results_folder,
                        (f"N={N_sensors}_nM={len(methods.keys())}_sep={str(min_sep)}_rg={str(deg_range)}_rp={int(random_power)}_tpo={int(total_power_one)}"
f"_ed={int(evenly_distributed)}_tpt={trials_per_theta}_nt={num_random_thetas}_nSrc={str(num_sources_list)}_SNR={str(SNR_list)}" 277 | f"_T={str(T_snapshots_list)}_rho={str(rho)}_pnv={int(provide_noise_var)}_t0={t0}_t={t}_k={k+1}" 278 | f"_fin={int(finished)}").replace(' ','').replace(',','_').replace('[','').replace(']','')) 279 | result = { 280 | 'crb_accumulated': crb_accumulated, 281 | 'crb_total': crb_total, 282 | 'crb_mean': crb_mean, 283 | 'list_of_methods': list(method_mse.keys()), 284 | 'd': d, 285 | 'lam': lam, 286 | 'N_sensors': N_sensors, 287 | 'deg_range': deg_range, 288 | 'min_sep': min_sep, 289 | 'random_power': random_power, 290 | 'total_power_one': total_power_one, 291 | 'evenly_distributed': evenly_distributed, 292 | 'trials_per_theta': trials_per_theta, 293 | 'num_random_thetas': num_random_thetas, 294 | 'num_sources_list': num_sources_list, 295 | 'SNR_list': SNR_list, 296 | 'T_snapshots_list': T_snapshots_list, 297 | 'rho': rho, 298 | 'gain_bias': gain_bias, 299 | 'phase_bias_deg': phase_bias_deg, 300 | 'position_bias': position_bias, 301 | 'mc_mag_angle': mc_mag_angle, 302 | 'ProxCov_epsilon': ProxCov_epsilon, 303 | 'StructCovMLE_epsilon': StructCovMLE_epsilon, 304 | 'StructCovMLE_max_iter': StructCovMLE_max_iter, 305 | 't0': t0, 306 | 't': t 307 | } 308 | result.update(method_mse) 309 | result.update({f'suc_{k}': v for k,v in method_success.items()}) 310 | result.update({f'adoa_{k}': v for k,v in method_accu_doa.items()}) 311 | result.update({f'bias_{k}': v for k,v in method_bias.items()}) 312 | scipy.io.savemat(result_path+'.mat', result) 313 | np.save(result_path+'.npy', result) 314 | tqdm.write(f'{t} [performance.py] Results saved at {result_path}.npy and {result_path}.mat') 315 | if prev_result_path is not None and os.path.exists(prev_result_path+'.npy'): 316 | os.remove(prev_result_path+'.npy') 317 | os.remove(prev_result_path+'.mat') 318 | tqdm.write(f'{t} [performance.py] Intermediate results at {prev_result_path}.npy and {prev_result_path}.mat were removed') 319 | prev_result_path = result_path -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | """ 6 | import os 7 | import torch 8 | from models import model_dict 9 | from DoA import BasePredictor 10 | from utils import data_in_preprocess, cov_normalize 11 | 12 | class Predictor(BasePredictor): 13 | def __init__(self, model_folder: str, device: str ='cuda:0'): 14 | meta_path = os.path.join(model_folder,'meta_data.pt') 15 | model_path = os.path.join(model_folder,'best_model.pt') 16 | if not (os.path.isfile(meta_path) and os.path.isfile(model_path)): 17 | self.isfunctional = False 18 | return 19 | else: 20 | self.isfunctional = True 21 | self.model_folder = model_folder 22 | self.device = device 23 | array_data = torch.load(meta_path) 24 | self.name = array_data['model'] 25 | self.N_sensors = array_data['N_sensors'] 26 | self.normalization = array_data['normalization'] 27 | self.EnEnH = array_data['EnEnH'] if 'EnEnH' in array_data else False 28 | 29 | pretrained_model = torch.load(model_path,map_location=device) 30 | self.net = model_dict[self.name]() 31 | self.out_type = self.net.out_type 32 | self.net.load_state_dict(pretrained_model,strict=True) 33 | self.net = self.net.to(self.device) 34 | self.net.eval() 35 | 36 | def get_ULA_cov(self, cov: torch.Tensor, is_snapshot: bool = False): 37 | cov, _ = 
data_in_preprocess(cov) 38 | cov = cov_normalize(cov,self.normalization,self.N_sensors) 39 | with torch.no_grad(): 40 | outputs = self.net(cov.to(self.device)).cpu() 41 | return outputs 42 | 43 | def get_DoA(self, data_in: torch.Tensor, num_sources: int, is_snapshot: bool = False, noise_var: torch.Tensor = None): 44 | if self.out_type == 'direct': 45 | batch_size = data_in.shape[0] 46 | out = self.get_ULA_cov(data_in,False) 47 | DoA = torch.sort(out[:,:num_sources])[0] 48 | success = [True for _ in range(batch_size)] 49 | return DoA, success 50 | elif self.out_type == 'branch': 51 | batch_size = data_in.shape[0] 52 | cov, _ = data_in_preprocess(data_in) 53 | cov = cov_normalize(cov,self.normalization,self.N_sensors) 54 | with torch.no_grad(): 55 | out = self.net(cov.to(self.device)) 56 | DoA = torch.sort(out[num_sources-1].cpu())[0] 57 | success = [True for _ in range(batch_size)] 58 | return DoA, success 59 | else: 60 | return super().get_DoA(data_in,num_sources,is_snapshot,noise_var) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.10.0 2 | matlabengine==23.2.1 3 | matplotlib==3.8.2 4 | numpy==1.26.3 5 | scipy==1.12.0 6 | torch==2.1.1 7 | tqdm==4.66.1 -------------------------------------------------------------------------------- /run/cuda0_lr_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # search the best maximum learning rate 4 | 5 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 2.0 --save_dataset 1 6 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 1.0 7 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.5 8 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.2 9 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.1 10 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.05 11 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.02 12 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.01 13 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.005 14 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.002 15 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.001 16 | 17 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.5 18 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.2 19 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.1 20 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.05 21 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.02 22 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 23 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.005 24 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.002 25 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.001 26 | 27 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.2 28 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.1 29 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.05 30 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.02 31 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.01 32 | python3 main.py --train_L 200 --loss 
AffInvDist --mu 0.005 33 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.002 34 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.001 35 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.0005 36 | 37 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 2.0 --consistent_rank_sampling 1 38 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 1.0 --consistent_rank_sampling 1 39 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.5 --consistent_rank_sampling 1 40 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.2 --consistent_rank_sampling 1 41 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 42 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.05 --consistent_rank_sampling 1 43 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.02 --consistent_rank_sampling 1 44 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.01 --consistent_rank_sampling 1 45 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.005 --consistent_rank_sampling 1 46 | 47 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 1.0 --consistent_rank_sampling 1 48 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.5 --consistent_rank_sampling 1 49 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.2 --consistent_rank_sampling 1 50 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.1 --consistent_rank_sampling 1 51 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.05 --consistent_rank_sampling 1 52 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.02 --consistent_rank_sampling 1 53 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.01 --consistent_rank_sampling 1 -------------------------------------------------------------------------------- /run/cuda0_n4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 4-element MRA 4 | 5 | python3 main.py --train_L 200 --loss ToepSquare --model N4_M7_toep_WRN_16_8 --mu 0.05 --save_dataset 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 6 | 7 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 8 | 9 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.005 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 10 | 11 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 
0.2 0.2 0.2 12 | 13 | # imperfect arrays, 4-element MRA, top 2 14 | 15 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.005 --rho 1.0 --mix_rho 1 --save_dataset 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 16 | 17 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --rho 1.0 --mix_rho 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /run/cuda0_n4_other_distances.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 4-element MRA, other distances between subspaces 4 | 5 | python3 main.py --train_L 200 --loss SignalChordalDistPA --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 --save_dataset 1 6 | python3 main.py --train_L 200 --loss SignalChordalDistOB --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 7 | python3 main.py --train_L 200 --loss SignalProjectionDistPA --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 8 | python3 main.py --train_L 200 --loss SignalProjectionDistOB --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 9 | python3 main.py --train_L 200 --loss SignalFubiniStudyDistPA --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 10 | python3 main.py --train_L 200 --loss SignalFubiniStudyDistOB --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 11 | python3 main.py --train_L 200 --loss SignalProcrustesDistPA --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 12 | 
python3 main.py --train_L 200 --loss SignalSpectralDistPA --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /run/cuda0_n4_random_power.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 4-element MRA, random powers 4 | 5 | python3 main.py --random_power 1 --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 6 | 7 | python3 main.py --random_power 1 --train_L 200 --loss ToepSquare --model N4_M7_toep_WRN_16_8 --mu 0.05 --save_dataset 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 8 | 9 | python3 main.py --random_power 1 --train_L 200 --loss FrobeniusNorm --mu 0.01 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 10 | 11 | python3 main.py --random_power 1 --train_L 200 --loss AffInvDist --mu 0.005 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /run/cuda0_n4_w_and_wo_crs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 4-element MRA, w/ consistent rank sampling vs. w/o 4 | 5 | # batch size 4096: 172.30 vs. 356.67 seconds/epoch (w crs vs. w/o crs using grouping): Best validation loss: 2.1322e-01 vs. 2.1312e-01 (w crs vs. 
w/o crs using grouping) 6 | python3 main.py --train_L 200 --val_L 60 --batch_size 4096 --val_batch_size 4096 --print_every_n_batch 10000 --loss SignalSubspaceDistNoCrsGroup --mu 0.1 --consistent_rank_sampling 0 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 7 | python3 main.py --train_L 200 --val_L 60 --batch_size 4096 --val_batch_size 4096 --print_every_n_batch 10000 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --N_sensors 4 --n_sources_train 1 2 3 4 5 6 --n_sources_val 1 2 3 4 5 6 --min_sep 3 3 3 3 3 3 --model N4_M7_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /run/cuda0_n5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 5-element MRA 4 | 5 | python3 main.py --train_L 200 --loss ToepSquare --model N5_M10_toep_WRN_16_8 --mu 0.05 --save_dataset 1 6 | 7 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 8 | 9 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.005 10 | 11 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 12 | 13 | # imperfect arrays, 5-element MRA, top 2 14 | 15 | python3 main.py --train_L 200 --loss AffInvDist --mu 0.005 --rho 1.0 --mix_rho 1 --save_dataset 1 16 | 17 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --rho 1.0 --mix_rho 1 18 | 19 | # gridless end-to-end approach 20 | python3 main.py --train_L 200 --loss BranchAngleMSE --model Branch_N5_M10_WRN_16_8 --mu 0.2 --consistent_rank_sampling 1 -------------------------------------------------------------------------------- /run/cuda0_n6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # standard assumptions, 6-element MRA 4 | 5 | python3 main.py --train_L 200 --loss ToepSquare --model N6_M14_toep_WRN_16_8 --mu 0.05 --save_dataset 1 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 6 | 7 | python3 main.py --train_L 200 --loss FrobeniusNorm --mu 0.01 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --model N6_M14_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 8 | 9 | python3 main.py --train_L 200 --loss AffInvDist3 --mu 0.005 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --model N6_M14_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 
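# note: the 6-element MRA corresponds to a 14-element virtual ULA (hence the N6_M14_* models), so the scripts train with up to 13 sources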
10 | 11 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --model N6_M14_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 12 | 13 | # imperfect arrays, 6-element MRA, top 2 14 | 15 | python3 main.py --train_L 200 --loss AffInvDist3 --mu 0.005 --rho 1.0 --mix_rho 1 --save_dataset 1 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --model N6_M14_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 16 | 17 | python3 main.py --train_L 200 --loss SignalSubspaceDist --mu 0.1 --consistent_rank_sampling 1 --rho 1.0 --mix_rho 1 --N_sensors 6 --n_sources_train 1 2 3 4 5 6 7 8 9 10 11 12 13 --n_sources_val 1 2 3 4 5 6 7 8 9 10 11 12 13 --min_sep 3 3 3 3 3 3 3 3 3 3 3 3 3 --model N6_M14_WRN_16_8 --gain_bias 0.0 0.2 0.2 0.2 0.2 0.2 0.2 0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 --phase_bias_deg 0 -30 -30 -30 -30 -30 -30 -30 30 30 30 30 30 30 --position_bias 0.0 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2 0.2 0.2 0.2 0.2 0.2 0.2 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 14 2023 3 | 4 | @author: Kuan-Lin Chen 5 | 6 | Modified from https://github.com/kjason/DnnNormTimeFreq4DoA/tree/main/SpeechEnhancement 7 | """ 8 | import sys 9 | import os 10 | import time 11 | import torch 12 | import scipy.io 13 | import math 14 | from datetime import datetime 15 | from utils import get_device_name 16 | 17 | from batch_sampler import ConsistentRankBatchSampler 18 | 19 | class TrainParam: 20 | def __init__(self, 21 | mu, 22 | mu_scale, 23 | mu_epoch, 24 | weight_decay, 25 | momentum, 26 | batch_size, 27 | val_batch_size, 28 | nesterov, 29 | onecycle, 30 | optimizer 31 | ): 32 | assert len(mu_scale)==len(mu_epoch), "the length of mu_scale and mu_epoch should be the same" 33 | self.weight_decay = weight_decay 34 | self.momentum = momentum 35 | self.batch_size = batch_size 36 | self.val_batch_size = val_batch_size 37 | self.max_epoch = mu_epoch[-1] 38 | self.mu = mu 39 | self.mu_scale = mu_scale 40 | self.mu_epoch = mu_epoch 41 | self.nesterov = nesterov 42 | self.onecycle = onecycle 43 | self.optimizer = optimizer 44 | 45 | class TrainRegressor: 46 | pin_memory = True 47 | ckpt_filename = 'train.pt' 48 | def __init__(self, 49 | name, 50 | net, 51 | tp, 52 | trainset, 53 | validationset, 54 | criterion, 55 | device, 56 | seed, 57 | resume, 58 | checkpoint_folder, 59 | num_workers, 60 | consistent_rank_sampling, 61 | milestone = [], 62 | print_every_n_batch = 1, 63 | fp16 = False, 64 | meta_data = None 65 | ): 66 | torch.manual_seed(seed) 67 | self.criterion = criterion 68 | self.device = device 69 | self.net = net().to(device) 70 | self.checkpoint_folder = checkpoint_folder 71 | self.name = name 72 | self.seed = seed 73 | self.num_workers = num_workers 74 | self.milestone = milestone 75 | self.print_every_n_batch = print_every_n_batch 76 | 
        self.consistent_rank_sampling = consistent_rank_sampling
        self.trainset = trainset
        self.validationset = validationset
        self.fp16 = fp16

        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [train_regressor.py] {get_device_name(device)}")
        self.num_parameters = self.count_parameters()
        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} [train_regressor.py] number of parameters in the model {name}: {self.num_parameters:,}")

        if self.consistent_rank_sampling is True:
            train_batch_sampler = ConsistentRankBatchSampler(N=trainset.N_datapoints_per_nsrc,K=len(trainset.num_sources),batch_size=tp.batch_size)
            val_batch_sampler = ConsistentRankBatchSampler(N=validationset.N_datapoints_per_nsrc,K=len(validationset.num_sources),batch_size=tp.batch_size)
            self.trainloader = torch.utils.data.DataLoader(trainset,batch_sampler=train_batch_sampler,num_workers=self.num_workers,pin_memory=self.pin_memory)
            self.validationloader = torch.utils.data.DataLoader(validationset,batch_sampler=val_batch_sampler,num_workers=self.num_workers,pin_memory=self.pin_memory)
        else:
            self.trainloader = torch.utils.data.DataLoader(trainset,batch_size=tp.batch_size,shuffle=True,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=False)
            self.validationloader = torch.utils.data.DataLoader(validationset,batch_size=tp.val_batch_size,shuffle=False,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=False)

        if tp.optimizer == "SGD":
            self.optimizer = torch.optim.SGD(self.net.parameters(),lr=tp.mu,momentum=tp.momentum,nesterov=tp.nesterov,weight_decay=tp.weight_decay)
        elif tp.optimizer == "AdamW":
            self.optimizer = torch.optim.AdamW(self.net.parameters(),lr=tp.mu,weight_decay=tp.weight_decay)
        else:
            raise ValueError(f"optimizer {tp.optimizer} not implemented")

        if tp.onecycle is True:
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer,max_lr=tp.mu,steps_per_epoch=len(self.trainloader),epochs=tp.max_epoch)
        else:
            self.mu_lambda = lambda i: next(tp.mu_scale[j] for j in range(len(tp.mu_epoch)) if min(tp.mu_epoch[j]//(i+1),1.0) >= 1.0) if i
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
    if N > 12:
        raise ValueError("only support N up to 12")
    else:
        spacing = spacing_list[N-2]
    sensor_grid = [sum(spacing[:i]) for i in range(N)] # sensor locations on the grid of nonnegative integers
    N_a = sensor_grid[-1]
    sensor_locations = [(i-N_a/2)*d for i in sensor_grid] # sensor locations in the space
    return sensor_locations, sensor_grid, N_a

def dir_path(path):
    if not os.path.isdir(path):
        os.mkdir(path)
    return path

def file_path(path):
    if os.path.isfile(path) or path is None:
        return path
    else:
        raise ValueError('{} is not a valid file'.format(path))

def check_device(device):
    if device == 'cpu':
        return device
    elif torch.cuda.is_available():
        count = torch.cuda.device_count()
        for i in range(count):
            if device == 'cuda:'+str(i):
                return device
        raise ValueError('{} not found in the available cuda or cpu list'.format(device))
    else:
        raise ValueError('{} is not a valid cuda or cpu device'.format(device))

def get_device_name(device):
    if device[:4] == 'cuda':
        return torch.cuda.get_device_name(int(device[-1])) # print the GPU
    else:
        return device
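
# shape contract (from the code below): a 2-D (N,L) snapshot matrix is promoted to a (1,N,L)
# batch with batch_size 1; a 3-D (batch,N,L) tensor passes through unchanged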
def data_in_preprocess(data_in: torch.Tensor):
    if len(data_in.shape) == 3:
        batch_size = data_in.size(0)
    elif len(data_in.shape) == 2:
        data_in = data_in.unsqueeze(0)
        batch_size = 1
    else:
        raise ValueError(f"len(data_in.shape)={len(data_in.shape)}, invalid data_in")
    return data_in, batch_size

def cov_normalize(cov: torch.Tensor, mode: str, N: int):
    if mode == 'max':
        n_cov = cov / torch.amax(torch.abs(cov),dim=[-2,-1],keepdim=True)
    elif mode == 'sensors':
        n_cov = cov / N
    elif mode == 'disabled':
        n_cov = cov
    else:
        raise ValueError(f'normalization={mode} is invalid')
    return n_cov

def ComplexMat2RealImagMat(cov: torch.Tensor):
    return torch.cat((cov.real.unsqueeze(1),cov.imag.unsqueeze(1)),1)

def RealImagMat2ComplexMat(cov: torch.Tensor):
    return torch.complex(cov[:,0,:,:],cov[:,1,:,:])

def RealImagMat2GramComplexMat(cov: torch.Tensor):
    c = RealImagMat2ComplexMat(cov)
    return torch.matmul(c,c.conj().transpose(-2,-1))

def HermitianMat2RealVec(cov: torch.Tensor):
    N = cov.size(-1)
    # masks for the upper triangle, with and without the diagonal; an N-by-N Hermitian matrix
    # is determined by N(N+1)/2 real parts and N(N-1)/2 imaginary parts, i.e. N^2 real numbers
    tri = torch.triu(torch.ones(N, N)) == 1
    otri = (torch.triu(torch.ones(N,N)) == 1).fill_diagonal_(False)
    def OneHMat2RealVec(c: torch.Tensor):
        return torch.cat((c.real[tri == 1],c.imag[otri == 1]),0)
    return torch.vmap(OneHMat2RealVec)(cov)

def RealVec2HermitianMat(vec: torch.Tensor):
    N = int(np.sqrt(vec.shape[1]))
    batch_size = vec.shape[0]
    H_real = torch.zeros(batch_size,N,N,dtype=vec.dtype,device=vec.device)
    H_imag = torch.zeros(batch_size,N,N,dtype=vec.dtype,device=vec.device)
    for i in range(N):
        j = int(i*N-i*(i-1)/2)
        H_real[:,i,i:] = vec[:,j:j+N-i]
        H_real[:,1+i:,i] = H_real[:,i,1+i:]
    k = (N+1)*N/2
    Nm = N - 1
    for i in range(Nm):
        j = int(k+i*Nm-i*(i-1)/2)
        H_imag[:,i,1+i:] = vec[:,j:j+Nm-i]
        H_imag[:,1+i:,i] = -H_imag[:,i,1+i:]
    H = torch.complex(H_real,H_imag)
    return H

# a Hermitian Toeplitz matrix is fully determined by its first row:
# N real parts plus N-1 imaginary parts, i.e. 2N-1 real numbers
def HermitianToeplitzMat2RealVec(cov: torch.Tensor):
    return torch.cat([cov[:,0,:].real,cov[:,0,1:].imag],1)

def RealVec2HermitianToeplitzMat(vec: torch.Tensor):
    N = int((vec.shape[-1]+1)/2)
    batch_size = vec.shape[0]
    T_real = torch.zeros(batch_size,N,N,device=vec.device)
    T_imag = torch.zeros(batch_size,N,N,device=vec.device)
    T_real[:,0,:] = vec[:,:N]
    T_real[:,1:,0] = vec[:,1:N]
    T_imag[:,0,1:] = vec[:,N:]
    T_imag[:,1:,0] = - vec[:,N:]
    # propagate the constant diagonals row by row to complete the Toeplitz structure
    for i in range(1,N):
        T_real[:,i,i:] = T_real[:,i-1,i-1:-1]
        T_real[:,i+1:,i] = T_real[:,i,i+1:]
        T_imag[:,i,i:] = T_imag[:,i-1,i-1:-1]
        T_imag[:,i+1:,i] = - T_imag[:,i,i+1:]
    T = torch.complex(T_real,T_imag)
    return T

if __name__ == '__main__':
    import time
    # round-trip self-test: vectorize a batch of Hermitian Toeplitz matrices and reconstruct them
    cov = torch.zeros(3,3,dtype=torch.complex64,device='cuda:0')
    cov[0,0] = 2
    cov[1,1] = 2
    cov[2,2] = 2
    cov[1,0] = 3+1*1j
    cov[2,1] = 3+1*1j
    cov[2,0] = 4+5*1j
    cov[0,1] = 3-1*1j
    cov[1,2] = 3-1*1j
    cov[0,2] = 4-5*1j
    print(cov)
    cov = cov.unsqueeze(0)
    cov = cov.repeat(7,1,1)
    tic = time.time()
    vec = HermitianToeplitzMat2RealVec(cov)
    cov_r = RealVec2HermitianToeplitzMat(vec)
    print(torch.all(torch.isclose(cov_r,cov)))
    err_part = torch.isclose(cov_r,cov) == False
    print(cov[err_part]-cov_r[err_part])
    toc = time.time()
    print(toc-tic)
--------------------------------------------------------------------------------