├── README.md ├── TFA-Net_inference ├── 1110_zzq_80_1_Raw_0_chan1.mat ├── ECGData.mat ├── GWarblet.m ├── GWarblet_complex.m ├── MSST_Y_new2.m ├── PerformDFT.m ├── RS.m ├── SET_Y2.m ├── SST2.m ├── TFA_exe │ ├── TFNet_model.py │ ├── TFNet_model_stft.py │ ├── TFNet_server.py │ ├── TFNet_server.spec │ ├── TFNet_util.py │ ├── __pycache__ │ │ ├── TFNet_model.cpython-36.pyc │ │ ├── TFNet_model_stft.cpython-36.pyc │ │ ├── TFNet_util.cpython-36.pyc │ │ ├── complexFunctions.cpython-36.pyc │ │ ├── complexLayers.cpython-36.pyc │ │ └── complexModule.cpython-36.pyc │ ├── complexFunctions.py │ ├── complexLayers.py │ ├── complexModule.py │ ├── config │ │ └── config.ini │ └── model │ │ ├── RED_Net.pth │ │ └── TFA-Net.pth ├── batdata2.mat ├── computation_complexity.m ├── data1_resfreq.mat ├── dianji_data.mat ├── exp1_simu.m ├── exp1_simu2.m ├── exp2_bat.m ├── exp2_heart.m ├── exp3_breheart.m ├── exp3_dianji.m ├── exp3_voice.m ├── get_fscoeff.m ├── ref_methods.m ├── ringtoneFrasers.mp3 ├── stft_rearrage.m ├── tight_subplot.m └── winderi.m └── TFA-Net_train ├── checkpoint ├── RED-Net_model │ └── fr │ │ └── RED_Net.pth └── TFA-Net_model │ └── fr │ └── TFA-Net.pth ├── complexFunctions.py ├── complexLayers.py ├── complexModules.py ├── complexTrain.py └── data ├── data.py ├── dataset.py ├── fr.py ├── loss.py └── noise.py /README.md: -------------------------------------------------------------------------------- 1 | # TFA-Net 2 | 3 | Code for ``TFA-Net: A Deep Learning-Based Time-Frequency Analysis Tool'' 4 | 5 | 1. Folder TFA-Net_train contains the code for TFA-Net training 6 | --complexTrain.py is the main script for TFA-Net 7 | --stft_backbone.py is the main script for a RED-Net modified to suit our needs; that is, TFA is performed on top of the STFT results. 8 | --generate_dataset.py is used to generate the test set, and model_explanation.py is used to run inference on the test set. 9 | --Env. Requirements: pytorch 1.7.1 10 | 11 | 2. 
Folder TFA-Net_inference contains the code for the experiments in the paper 12 | --The subfolder TFA_exe is a Flask inference server for the model, which must be started first. Env. Requirements: pytorch 1.7.1. 13 | * TFA_exe/config includes a configuration file that sets the data path shared with the experiment scripts and the service port used for model inference calls 14 | --The experiments included in the paper: 15 | * exp1_simu.m is the first simulation in Section IV A 16 | * exp1_simu2.m is the second simulation in Section IV A 17 | * exp2_bat.m is the TFA of bat signal in Section IV B (1) 18 | * exp2_heart.m is the TFA of ECG in Section IV B (2) 19 | * exp3_dianji.m is the TFA of micro-Doppler of reflectors in Section IV B (3) 20 | * exp3_breheart.m is the TFA of micro-Doppler of Vital Signs in Section IV B (4) 21 | * exp3_voice.m is the TFA of voice of a mammal in Section IV B (5). 22 | 23 | 24 | -------------------------------------------------------------------------------- /TFA-Net_inference/1110_zzq_80_1_Raw_0_chan1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/1110_zzq_80_1_Raw_0_chan1.mat -------------------------------------------------------------------------------- /TFA-Net_inference/ECGData.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/ECGData.mat -------------------------------------------------------------------------------- /TFA-Net_inference/GWarblet.m: -------------------------------------------------------------------------------- 1 | function [Spec,f] = GWarblet(Sig,SampFreq,Ratio,fm,N,WinLen); 2 | % Generalized Warblet Transform 3 | % 4 | % Sig : the signal to be analyzed 5 | % SampFreq : sampling frequency 6 | % Ratio : Coefficients of Fourier
series 2xm [coeff. of sine term; coeff. of cosine term] 7 | % fm : non-zero harmonic frequencies of Fourier series 1xm 8 | % N : the number of frequency bins (default : length(Sig)). 9 | % WinLen : the length of window used to locate the signal in time. 10 | 11 | % 12 | % Written by Yang Yang, March 2011. 13 | % email: z.peng@sjtu.edu.cn 14 | % Copyright (c) belongs to the authors of the papers: 15 | %) Peng Z.K , Meng G., Lang Z.Q.,Chu F.L, Zhang W.M., Yang Y., Polynomial Chirplet Transform with Application to Instantaneous Frequency Estimation, 16 | % IEEE Transactions on Measurement and Instrumentation 60(2011) 3222-3229 17 | % Yang Y., Peng Z.K., Dong X.J., Zhang W.M., Meng G.,General parameterized time-frequency transform, 18 | % IEEE Transactions on Signal Processing, 62(2014) 2751-2764 19 | % Yang Y., Peng Z.K., Zhang W.M., Meng G., Spline-kernelled chirplet transform for the analysis of signals with time-varying frequency and its application, 20 | % IEEE Transactions on Industrial Electronics, 59(2012) 1612-1621 21 | % Yang Y., Peng Z.K, Meng G., Zhang W.M., Characterize highly oscillating frequency modulation using generalized Warblet transform, 22 | % Mechanical Systems and Signal Processing, 26 (2012) 128-140 23 | % Yang Y., Peng Z.K., Zhang W.M., Meng G., Frequency-varying group delay estimation using frequency domain polynomial chirplet transform, 24 | % Mechanical Systems and Signal Processing, 46(2014) 146-162 25 | % The citation about the papers must be included in all publications or 26 | % thesises as long as this program is used by anyone. 
27 | 28 | 29 | 30 | if(nargin < 4), 31 | error('At least 4 inputs are required!'); 32 | end 33 | 34 | SigLen = length(Sig); 35 | 36 | if (nargin < 6), 37 | WinLen = SigLen / 4; 38 | end 39 | 40 | if (nargin < 5), 41 | N = SigLen; 42 | end 43 | 44 | if(N > 512), 45 | N = 512; 46 | end 47 | 48 | if size(Ratio,1)==1 49 | Ratio(2,:) = zeros(size(Ratio)); 50 | % disp('The coefficients of cosinea are zero') 51 | end 52 | 53 | if fm ==0 54 | disp('The harmonic frequencies of Fourier series cannot be zero') 55 | return 56 | end 57 | 58 | sig = Sig; 59 | 60 | RatioNum = size(Ratio,2); 61 | 62 | dt = (0:(SigLen-1))'; 63 | dt = dt/ SampFreq; 64 | 65 | sf = zeros(size(dt));%shift 66 | for k = 1:RatioNum, 67 | sf = sf - Ratio(1,k) * sin(2*pi*fm(k)*dt) + Ratio(2,k)*cos(2*pi*fm(k)*dt); 68 | end 69 | 70 | kernel = zeros(size(dt)); %rotate 71 | for k = 1:RatioNum, 72 | kernel = kernel + Ratio(1,k)/fm(k) * cos(2*pi*fm(k)*dt)+ Ratio(2,k)/fm(k) * sin(2*pi*fm(k)*dt); 73 | end 74 | rSig = hilbert(real(Sig)); %Z(t) 75 | Sig = rSig .* exp(-j*kernel); 76 | 77 | WinLen = ceil(WinLen / 2) * 2; 78 | t = linspace(-1,1,WinLen)'; 79 | WinFun = exp(log(0.005) * t.^2 ); 80 | WinFun = WinFun / norm(WinFun); 81 | Lh = (WinLen - 1)/2; 82 | 83 | Spec = zeros(N,SigLen) ; % matrix 84 | 85 | % wait = waitbar(0,'Please wait...'); 86 | for iLoop = 1:SigLen, 87 | 88 | % waitbar(iLoop/SigLen,wait); 89 | 90 | tau = -min([round(N/2)-1,Lh,iLoop-1]):min([round(N/2)-1,Lh,SigLen-iLoop]); % signal span 91 | temp = floor(iLoop + tau); 92 | 93 | rSig = Sig .* exp(j*2*pi*sf(iLoop)*dt); % shift: IF 94 | rSig = rSig(temp); 95 | 96 | temp1 = floor(Lh+1+tau); % window span 97 | rSig = rSig .* conj(WinFun(temp1)); % Z(t)* complex conjugate of window? 
98 | Spec(1:length(rSig),iLoop) = rSig; % windowed analytic signal 99 | end; 100 | 101 | Spec = fft(Spec); 102 | Spec = abs(Spec); 103 | 104 | % close(wait); 105 | 106 | Spec = Spec(1:round(end/2),:); 107 | [nLevel, SigLen] = size(Spec); 108 | 109 | f = [0:nLevel-1]/nLevel * SampFreq/2; % frequency in TF plane? 110 | % t = (0: SigLen-1)/SampFreq; % time in TF plane 111 | % 112 | % [fmax fmin] = FreqRange(sig); 113 | % fmax = fmax * SampFreq; 114 | % fmin = fmin * SampFreq; 115 | % 116 | % clf 117 | % %=====================================================% 118 | % % Plot the result % 119 | % %=====================================================% 120 | % set(gcf,'Position',[20 100 350 300]); 121 | % set(gcf,'Color','w'); 122 | % 123 | % imagesc(t,f,Spec); 124 | % axis([min(t) max(t) fmin fmax]); 125 | % ylabel('Freq / Hz'); 126 | % xlabel('Time / Sec') 127 | % Info = 'C = '; 128 | % 129 | % for i = 1:RatioNum, 130 | % Info = [Info,num2str(Ratio(i),4), ' ']; 131 | % end 132 | % 133 | % if RatioNum == 0, 134 | % Info = 'C = 0'; 135 | % end 136 | 137 | %title(['Nonlinear Chirplet[',Info,']']); 138 | -------------------------------------------------------------------------------- /TFA-Net_inference/GWarblet_complex.m: -------------------------------------------------------------------------------- 1 | function [Spec,f] = GWarblet(Sig,SampFreq,Ratio,fm,N,WinLen); 2 | % Generalized Warblet Transform 3 | % 4 | % Sig : the signal to be analyzed 5 | % SampFreq : sampling frequency 6 | % Ratio : Coefficients of Fourier series 2xm [coeff. of sine term; coeff. of cosine term] 7 | % fm : non-zero harmonic frequencies of Fourier series 1xm 8 | % N : the number of frequency bins (default : length(Sig)). 9 | % WinLen : the length of window used to locate the signal in time. 10 | 11 | % 12 | % Written by Yang Yang, March 2011. 
13 | % email: z.peng@sjtu.edu.cn 14 | % Copyright (c) belongs to the authors of the papers: 15 | %) Peng Z.K , Meng G., Lang Z.Q.,Chu F.L, Zhang W.M., Yang Y., Polynomial Chirplet Transform with Application to Instantaneous Frequency Estimation, 16 | % IEEE Transactions on Measurement and Instrumentation 60(2011) 3222-3229 17 | % Yang Y., Peng Z.K., Dong X.J., Zhang W.M., Meng G.,General parameterized time-frequency transform, 18 | % IEEE Transactions on Signal Processing, 62(2014) 2751-2764 19 | % Yang Y., Peng Z.K., Zhang W.M., Meng G., Spline-kernelled chirplet transform for the analysis of signals with time-varying frequency and its application, 20 | % IEEE Transactions on Industrial Electronics, 59(2012) 1612-1621 21 | % Yang Y., Peng Z.K, Meng G., Zhang W.M., Characterize highly oscillating frequency modulation using generalized Warblet transform, 22 | % Mechanical Systems and Signal Processing, 26 (2012) 128-140 23 | % Yang Y., Peng Z.K., Zhang W.M., Meng G., Frequency-varying group delay estimation using frequency domain polynomial chirplet transform, 24 | % Mechanical Systems and Signal Processing, 46(2014) 146-162 25 | % The citation about the papers must be included in all publications or 26 | % thesises as long as this program is used by anyone. 
27 | 28 | 29 | 30 | if(nargin < 4), 31 | error('At least 4 inputs are required!'); 32 | end 33 | 34 | SigLen = length(Sig); 35 | 36 | if (nargin < 6), 37 | WinLen = SigLen / 4; 38 | end 39 | 40 | if (nargin < 5), 41 | N = SigLen; 42 | end 43 | 44 | if(N > 512), 45 | N = 512; 46 | end 47 | 48 | if size(Ratio,1)==1 49 | Ratio(2,:) = zeros(size(Ratio)); 50 | % disp('The coefficients of cosinea are zero') 51 | end 52 | 53 | if fm ==0 54 | disp('The harmonic frequencies of Fourier series cannot be zero') 55 | return 56 | end 57 | 58 | sig = Sig; 59 | 60 | RatioNum = size(Ratio,2); 61 | 62 | dt = (0:(SigLen-1))'; 63 | dt = dt/ SampFreq; 64 | 65 | sf = zeros(size(dt));%shift 66 | for k = 1:RatioNum, 67 | sf = sf - Ratio(1,k) * sin(2*pi*fm(k)*dt) + Ratio(2,k)*cos(2*pi*fm(k)*dt); 68 | end 69 | 70 | kernel = zeros(size(dt)); %rotate 71 | for k = 1:RatioNum, 72 | kernel = kernel + Ratio(1,k)/fm(k) * cos(2*pi*fm(k)*dt)+ Ratio(2,k)/fm(k) * sin(2*pi*fm(k)*dt); 73 | end 74 | rSig = Sig; %Z(t) 75 | Sig = rSig .* exp(-j*kernel); 76 | 77 | WinLen = ceil(WinLen / 2) * 2; 78 | t = linspace(-1,1,WinLen)'; 79 | WinFun = exp(log(0.005) * t.^2 ); 80 | WinFun = WinFun / norm(WinFun); 81 | Lh = (WinLen - 1)/2; 82 | 83 | Spec = zeros(N,SigLen) ; % matrix 84 | 85 | wait = waitbar(0,'Please wait...'); 86 | for iLoop = 1:SigLen, 87 | 88 | waitbar(iLoop/SigLen,wait); 89 | 90 | tau = -min([round(N/2)-1,Lh,iLoop-1]):min([round(N/2)-1,Lh,SigLen-iLoop]); % signal span 91 | temp = floor(iLoop + tau); 92 | 93 | rSig = Sig .* exp(j*2*pi*sf(iLoop)*dt); % shift: IF 94 | rSig = rSig(temp); 95 | 96 | temp1 = floor(Lh+1+tau); % window span 97 | rSig = rSig .* conj(WinFun(temp1)); % Z(t)* complex conjugate of window? 
98 | Spec(1:length(rSig),iLoop) = rSig; % windowed analytic signal 99 | end; 100 | 101 | Spec = fft(Spec); 102 | Spec = abs(Spec); 103 | 104 | close(wait); 105 | 106 | Spec = Spec(1:round(end/2),:); 107 | [nLevel, SigLen] = size(Spec); 108 | 109 | f = [0:nLevel-1]/nLevel * SampFreq/2; % frequency in TF plane? 110 | % t = (0: SigLen-1)/SampFreq; % time in TF plane 111 | % 112 | % [fmax fmin] = FreqRange(sig); 113 | % fmax = fmax * SampFreq; 114 | % fmin = fmin * SampFreq; 115 | % 116 | % clf 117 | % %=====================================================% 118 | % % Plot the result % 119 | % %=====================================================% 120 | % set(gcf,'Position',[20 100 350 300]); 121 | % set(gcf,'Color','w'); 122 | % 123 | % imagesc(t,f,Spec); 124 | % axis([min(t) max(t) fmin fmax]); 125 | % ylabel('Freq / Hz'); 126 | % xlabel('Time / Sec') 127 | % Info = 'C = '; 128 | % 129 | % for i = 1:RatioNum, 130 | % Info = [Info,num2str(Ratio(i),4), ' ']; 131 | % end 132 | % 133 | % if RatioNum == 0, 134 | % Info = 'C = 0'; 135 | % end 136 | 137 | %title(['Nonlinear Chirplet[',Info,']']); 138 | -------------------------------------------------------------------------------- /TFA-Net_inference/MSST_Y_new2.m: -------------------------------------------------------------------------------- 1 | function [Ts,tfr,omega2] = MSST_Y_new2(x,hlength,num); 2 | % Computes the MSST (Ts) of the signal x. 3 | % Expression (31)-based Algorithm. 4 | % INPUT 5 | % x : Signal needed to be column vector. 6 | % hlength: The hlength of window function. 7 | % num : iteration number. 
8 | % OUTPUT 9 | % Ts : The SST 10 | % tfr : The STFT 11 | nfft=256; 12 | [xrow,xcol] = size(x); 13 | 14 | if (xcol~=1), 15 | error('X must be column vector'); 16 | end; 17 | 18 | if (nargin < 3), 19 | error('At least 3 parameter is required'); 20 | end 21 | % 22 | % if (nargin < 2), 23 | % hlength=round(xrow/5); 24 | % num=1; 25 | % else if (nargin < 3), 26 | % num=1; 27 | % end; 28 | 29 | hlength=hlength+1-rem(hlength,2); 30 | ht = linspace(-0.5,0.5,hlength);ht=ht'; 31 | 32 | % Gaussian window 33 | % h = exp(-pi/0.32^2*ht.^2); 34 | h = kaiser(min(hlength,xrow),10); 35 | % derivative of window 36 | % dh = -2*pi/0.32^2*ht .* h; % g' 37 | dh=winderi(h,1); 38 | [hrow,hcol]=size(h); Lh=(hrow-1)/2; 39 | 40 | N=xrow; 41 | t=1:xrow; 42 | 43 | [trow,tcol] = size(t); 44 | 45 | 46 | tfr1= zeros (N,tcol) ; 47 | tfr2= zeros (N,tcol) ; 48 | 49 | 50 | Ts= zeros (nfft,tcol) ; 51 | % Pad the signal vector x 52 | if rem(hlength,2)==1 53 | xp = [zeros((hlength-1)/2,1) ; x ; zeros((hlength-1)/2,1)]; 54 | else 55 | xp = [zeros((hlength)/2,1) ; x ; zeros((hlength-2)/2,1)]; 56 | end 57 | xin = stft_rearrage(xp,length(xp),hlength,hlength-1,1); 58 | 59 | % Compute the STFT 60 | [tfr1,fout] = PerformDFT(bsxfun(@times,h,xin),nfft,1); 61 | tfr2 = PerformDFT(bsxfun(@times,dh,xin),nfft,1); 62 | 63 | m = floor(hlength/2); 64 | inds = 0:nfft-1; 65 | ez = exp(-1i*2*pi*m*inds/nfft)'; 66 | sstout = bsxfun(@times,tfr1,ez); 67 | 68 | ft = 1:nfft; 69 | bt = 1:N; 70 | 71 | %%operator omega 72 | nb = length(bt); 73 | neta = length(ft); 74 | 75 | omega = zeros (nfft,tcol); 76 | 77 | fout=0:1/nfft:1-1/nfft; 78 | for b=1:nb 79 | fcorr=-imag(tfr2(ft,b)./tfr1(ft,b)); 80 | fcorr(~isfinite(fcorr)) = 0; 81 | omega(:,b) =1+ mod(round((fout'+fcorr)*(nfft-1)),nfft); 82 | % omega(:,b)=ft'-imag(tfr2(ft,b)./tfr1(ft,b)); 83 | end 84 | 85 | [neta,nb]=size(tfr1); 86 | 87 | if num>1 88 | for kk=1:num-1 89 | for b=1:nb 90 | for eta=1:neta 91 | k = omega(eta,b); 92 | if k>=1 && k<=neta 93 | omega2(eta,b)=omega(k,b); 94 | 
end 95 | end 96 | end 97 | omega=omega2; 98 | end 99 | else 100 | omega2=omega; 101 | end 102 | 103 | for b=1:nb%time 104 | % Reassignment step 105 | for eta=1:neta%frequency 106 | if abs(sstout(eta,b))>0.0001%you can set much lower value than this. 107 | k = omega2(eta,b); 108 | if k>=1 && k<=neta 109 | Ts(k,b) = Ts(k,b) + sstout(eta,b); 110 | end 111 | end 112 | end 113 | end 114 | %tfr=tfr/(sum(h)/2); 115 | tfr=sstout/(xrow/2); 116 | Ts=Ts/(xrow/2); 117 | end 118 | % 119 | % function [Ts_f]=SST(tfr_f,omega_f); 120 | % [tfrm,tfrn]=size(tfr_f); 121 | % Ts_f= zeros (tfrm,tfrn) ; 122 | % %mx=max(max(tfr_f)); 123 | % for b=1:tfrn%time 124 | % % Reassignment step 125 | % for eta=1:tfrm%frequency 126 | % %if abs(tfr_f(eta,b))>0.001*mx%you can set much lower value than this. 127 | % k = omega_f(eta,b); 128 | % if k>=1 && k<=tfrm 129 | % Ts_f(k,b) = Ts_f(k,b) + tfr_f(eta,b); 130 | % end 131 | % %end 132 | % end 133 | % end 134 | % end -------------------------------------------------------------------------------- /TFA-Net_inference/PerformDFT.m: -------------------------------------------------------------------------------- 1 | function [Xx,f] = PerformDFT(xin,nfft,varargin) 2 | %#codegen 3 | %COMPUTEDFT Computes DFT using FFT or Goertzel 4 | % This function is used to calculate the DFT of a signal using the FFT 5 | % or the Goertzel algorithm. 6 | % 7 | % [XX,F] = COMPUTEDFT(XIN,NFFT) where NFFT is a scalar and computes the 8 | % DFT XX using FFT. F is the frequency points at which the XX is 9 | % computed and is of length NFFT. 10 | % 11 | % [XX,F] = COMPUTEDFT(XIN,F) where F is a vector with at least two 12 | % elements computes the DFT XX using the Goertzel algorithm. 13 | % 14 | % [XX,F] = COMPUTEDFT(...,Fs) returns the frequency vector F (in hz) 15 | % where Fs is the sampling frequency 16 | % 17 | % Inputs: 18 | % XIN is the input signal 19 | % NFFT if a scalar corresponds to the number of FFT points used to 20 | % calculate the DFT using FFT. 
21 | % NFFT if a vector corresponds to the frequency points at which the DFT 22 | % is calculated using goertzel. 23 | % FS is the sampling frequency 24 | 25 | % Copyright 2006-2019 The MathWorks, Inc. 26 | 27 | % [1] Oppenheim, A.V., and R.W. Schafer, Discrete-Time Signal Processing, 28 | % Prentice-Hall, Englewood Cliffs, NJ, 1989, pp. 713-718. 29 | % [2] Mitra, S. K., Digital Signal Processing. A Computer-Based Approach. 30 | % 2nd Ed. McGraw-Hill, N.Y., 2001. 31 | 32 | 33 | narginchk(2,3); 34 | if nargin > 2 35 | Fs = varargin{1}; 36 | else 37 | Fs = 2*pi; 38 | end 39 | 40 | nx = size(xin,1); 41 | 42 | isfreqScalar = isscalar(nfft); 43 | 44 | if isfreqScalar 45 | [Xx,f] = computeDFTviaFFT(xin,nx,nfft(1),Fs); 46 | else 47 | f = nfft(:); % if nfft is a vector then it contains a list of freqs 48 | 49 | % see if we can get a uniform spacing of the freq vector 50 | [fstart, fstop, m, maxerr] = getUniformApprox(f); 51 | 52 | % see if the ratio of the maximum absolute deviation relative to the 53 | % largest absolute in the frequency vector is less than a few eps 54 | isuniform = maxerr < 3*eps(class(f)); 55 | 56 | % check if the number of steps in Goertzel ~ 1 k1 N*M is greater 57 | % than the expected number of steps in CZT ~ 20 k2 N*log2(N+M-1) 58 | % where k2/k1 is empirically found to be ~80. 59 | n = size(xin,1); 60 | islarge = m > 80*log2(nextpow2(m+n-1)); 61 | 62 | if isuniform && islarge 63 | % use CZT if uniformly spaced with a significant number of bins 64 | Xx = computeDFTviaCZT(xin,fstart,fstop,m,Fs); 65 | else 66 | % use Goertzel if small number of bins or not uniformly spaced 67 | Xx = computeDFTviaGoertzel(xin,f,Fs); 68 | end 69 | end 70 | 71 | end 72 | 73 | %------------------------------------------------------------------------- 74 | function [Xx,f] = computeDFTviaFFT(xin,nx,nfft,Fs) 75 | % Use FFT to compute raw STFT and return the F vector. 
76 | 77 | % Handle the case where NFFT is less than the segment length, i.e., "wrap" 78 | % the data as appropriate. 79 | xin_ncol = size(xin,2); 80 | xin_nchan = size(xin,3); 81 | xw = zeros(nfft,xin_ncol,xin_nchan,'like',xin); 82 | if nx > nfft 83 | for j = 1:xin_ncol*xin_nchan 84 | wrappedData = datawrap(xin(:,j),nfft); 85 | xw(:,j) = wrappedData(:); %coder size inference:- infer it as Column Vector to match LHS 86 | end 87 | else 88 | xw = xin; 89 | end 90 | 91 | Xx = fft(xw,nfft); 92 | f = psdfreqvec('npts',nfft,'Fs',Fs); 93 | 94 | end 95 | 96 | %-------------------------------------------------------------------------- 97 | function Xx = computeDFTviaGoertzel(xin,f,Fs) 98 | % Use Goertzel to compute raw DFT 99 | 100 | f = mod(f,Fs); % 0 <= f < = Fs 101 | xm = size(xin,1); % NFFT 102 | 103 | isInMATLAB = coder.target('MATLAB'); 104 | isInputComplex = ~isreal(xin); 105 | isInputSingle = isa(xin,'single'); 106 | 107 | % wavenumber in cycles/period used by the Goertzel function 108 | % (see equation 11.1 pg. 
755 of [2]) 109 | if isInputSingle 110 | k = single(f/Fs*xm); 111 | else 112 | k = f/Fs*xm; 113 | end 114 | 115 | Xx = signal.internal.goertzel.callGoertzel(xin,k,isInputComplex,isInputSingle,isInMATLAB); 116 | 117 | end 118 | 119 | %-------------------------------------------------------------------------- 120 | function Xx = computeDFTviaCZT(xin,fstart,fstop,npts,Fs) 121 | % Use CZT to compute raw DFT 122 | 123 | % start with initial complex weight 124 | Winit = exp(2i*pi*fstart/Fs); 125 | 126 | % compute the relative complex weight 127 | Wdelta = exp(2i*pi*(fstart-fstop)/((npts-1)*Fs)); 128 | 129 | % feed complex weights into chirp-z transform 130 | Xx = czt(xin, npts, Wdelta, Winit); 131 | 132 | end 133 | 134 | % LocalWords: DFT XIN NFFT Fs hz FS Oppenheim Schafer Englewood Mitra nd Graw 135 | % LocalWords: nfft STFT npts wavenumber 136 | -------------------------------------------------------------------------------- /TFA-Net_inference/RS.m: -------------------------------------------------------------------------------- 1 | function [tfr] = RS(x,hlength) 2 | % Reassignment transform 3 | % x : Signal. 4 | % hlength : Window length. 5 | 6 | % tfr : Time-Frequency Representation. 7 | 8 | % This program is distributed in the hope that it will be useful, 9 | % but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
11 | 12 | [xrow,xcol] = size(x); 13 | 14 | N=xrow; 15 | 16 | if (xcol~=1) 17 | error('X must be column vector'); 18 | end 19 | 20 | if (nargin < 2) 21 | hlength=round(xrow/8); 22 | end; 23 | 24 | t=1:N; 25 | nfft=N; 26 | ft = 1:round(nfft); 27 | 28 | [trow,tcol] = size(t); 29 | 30 | hlength=hlength+1-rem(hlength,2); 31 | ht = linspace(-0.5,0.5,hlength);ht=ht'; 32 | 33 | % Gaussian window 34 | h = exp(-pi/0.32^2*ht.^2); 35 | % derivative of window 36 | dh = -2*pi/0.32^2*ht .* h; % g' 37 | % 38 | th=h.*ht; 39 | 40 | [hrow,hcol]=size(h); Lh=(hrow-1)/2; 41 | 42 | tfr1= zeros (nfft,tcol); 43 | tfr2= zeros (nfft,tcol); 44 | tfr3= zeros (nfft,tcol); 45 | 46 | va=N/hlength; 47 | 48 | for icol=1:tcol 49 | ti= t(icol); tau=-min([round(N/2)-1,Lh,ti-1]):min([round(N/2)-1,Lh,xrow-ti]); 50 | indices= rem(nfft+tau,nfft)+1; 51 | rSig = x(ti+tau,1); 52 | tfr1(indices,icol)=rSig.*conj(h(Lh+1+tau)); 53 | tfr2(indices,icol)=rSig.*conj(dh(Lh+1+tau)); 54 | tfr3(indices,icol)=rSig.*conj(th(Lh+1+tau)); 55 | end; 56 | 57 | tfr1=fft(tfr1); 58 | tfr2=fft(tfr2); 59 | tfr3=fft(tfr3); 60 | 61 | tfr1=tfr1(1:round(nfft),:); 62 | tfr2=tfr2(1:round(nfft),:); 63 | tfr3=tfr3(1:round(nfft),:); 64 | 65 | omega1 = zeros(round(nfft),tcol); 66 | omega2= zeros(round(nfft),tcol); 67 | 68 | for b=1:N 69 | omega1(:,b) = (ft-1)'+real(va*1i*tfr2(ft,b)/2/pi./tfr1(ft,b)); 70 | end 71 | 72 | for a=1:round(nfft) 73 | omega2(a,:) = t+(hlength-1)*real(tfr3(a,t)./tfr1(a,t)); 74 | end 75 | 76 | omega1=round(omega1); 77 | omega2=round(omega2); 78 | 79 | Ts = zeros(round(nfft),tcol); 80 | 81 | % Reassignment step 82 | for b=1:N%time 83 | for eta=1:round(nfft)%frequency 84 | if abs(tfr1(eta,b))>0.0001 85 | k1 = omega1(eta,b); 86 | k2 = omega2(eta,b); 87 | if k1>=1 && k1<=round(nfft) && k2>=1 && k2<=N 88 | Ts(k1,k2) = Ts(k1,k2) + abs(tfr1(eta,b)); 89 | end 90 | end 91 | end 92 | end 93 | 94 | tfr=Ts/(xrow/2); 95 | -------------------------------------------------------------------------------- /TFA-Net_inference/SET_Y2.m: 
-------------------------------------------------------------------------------- 1 | function [Te,IF] = SET_Y2(x,hlength); 2 | % Synchroextracting Transform 3 | % x : Signal. 4 | % hlength : Window length. 5 | 6 | % IF : Synchroextracting operator representation. 7 | % Te : SET result. 8 | % tfr : STFT result 9 | % This program is distributed in the hope that it will be useful, 10 | % but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 12 | % 13 | % Written by YuGang in Shandong University at 2016.5.13. 14 | nfft=256; 15 | [xrow,xcol] = size(x); 16 | 17 | N=xrow; 18 | 19 | if (xcol~=1), 20 | error('X must be column vector'); 21 | end; 22 | 23 | if (nargin < 2), 24 | hlength=round(xrow/8); 25 | end; 26 | 27 | t=1:N; 28 | [trow,tcol] = size(t); 29 | 30 | hlength=hlength+1-rem(hlength,2); 31 | ht = linspace(-0.5,0.5,hlength);ht=ht'; 32 | 33 | % Gaussian window 34 | % h = exp(-pi/0.32^2*ht.^2); 35 | h = kaiser(min(hlength,N),10); 36 | % derivative of window 37 | % dh = -2*pi/0.32^2*ht .* h; % g' 38 | dh=winderi(h,1); 39 | [hrow,hcol]=size(h); Lh=(hrow-1)/2; 40 | 41 | N=xrow; 42 | t=1:xrow; 43 | 44 | [trow,tcol] = size(t); 45 | 46 | 47 | % Pad the signal vector x 48 | if rem(hlength,2)==1 49 | xp = [zeros((hlength-1)/2,1) ; x ; zeros((hlength-1)/2,1)]; 50 | else 51 | xp = [zeros((hlength)/2,1) ; x ; zeros((hlength-2)/2,1)]; 52 | end 53 | xin = stft_rearrage(xp,length(xp),hlength,hlength-1,1); 54 | 55 | % Compute the STFT 56 | [tfr1,fout] = PerformDFT(bsxfun(@times,h,xin),nfft,1); 57 | tfr2 = PerformDFT(bsxfun(@times,dh,xin),nfft,1); 58 | 59 | m = floor(hlength/2); 60 | inds = 0:nfft-1; 61 | ez = exp(-1i*2*pi*m*inds/nfft)'; 62 | sstout = bsxfun(@times,tfr1,ez); 63 | 64 | omega = zeros (nfft,tcol); 65 | fout=0:1/nfft:1-1/nfft; 66 | ft=1:nfft; 67 | for b=1:tcol 68 | fcorr=-imag(tfr2(ft,b)./tfr1(ft,b)); 69 | fcorr(~isfinite(fcorr)) = 0; 70 | omega(:,b) =1+ mod(round((fout'+fcorr)*(nfft-1)),nfft); 71 | % 
omega(:,b)=ft'-imag(tfr2(ft,b)./tfr1(ft,b)); 72 | end 73 | E=mean(abs(x)); 74 | IF=zeros(nfft,N); 75 | for i=1:nfft%frequency 76 | for j=1:N%time 77 | % if abs(tfr1(i,j))>0.8*E%if you are interested in weak signals, you can delete this line. 78 | %if abs(1-real(va*1i*tfr2(i,j)/2/pi./tfr1(i,j)))<0.5 79 | if abs(omega(i,j)-i)<0.5 80 | IF(i,j)=1; 81 | end 82 | % end 83 | end 84 | end 85 | 86 | Te=tfr1.*IF; 87 | 88 | % %The following code is an alternative way to estimate IF. 89 | % %In theroy, they are same. 90 | % omega = zeros(round(N/2),tcol); 91 | % for b=1:N 92 | % omega(:,b) = (ft-1)'+real(va*1i*tfr2(ft,b)/2/pi./tfr1(ft,b)); 93 | % end 94 | % for i=1:round(N/2)%frequency 95 | % for j=1:N%time 96 | % if abs(tfr1(i,j))>0.8*E%default frequency resolution is 1Hz. 97 | % if abs(omega(i,j)-i)<0.5%default frequency resolution is 1Hz. 98 | % IF(i,j)=1; 99 | % end 100 | % end 101 | % end 102 | % end 103 | -------------------------------------------------------------------------------- /TFA-Net_inference/SST2.m: -------------------------------------------------------------------------------- 1 | function [Ts] = SST2(x,hlength); 2 | % Computes the SST (Ts) of the signal x. 3 | % INPUT 4 | % x : Signal needed to be column vector. 5 | % hlength: The hlength of window function. 
6 | % OUTPUT 7 | % Ts : The SST 8 | 9 | [xrow,xcol] = size(x); 10 | nfft=256; 11 | if (xcol~=1), 12 | error('X must be column vector'); 13 | end; 14 | 15 | if (nargin < 1), 16 | error('At least 1 parameter is required'); 17 | end; 18 | 19 | if (nargin < 2), 20 | hlength=round(xrow/5); 21 | end; 22 | 23 | Siglength=xrow; 24 | hlength=hlength+1-rem(hlength,2); 25 | ht = linspace(-0.5,0.5,hlength);ht=ht'; 26 | 27 | % Gaussian window 28 | % h = exp(-pi/0.32^2*ht.^2); 29 | h = kaiser(min(hlength,Siglength),10); 30 | % derivative of window 31 | % dh = -2*pi/0.32^2*ht .* h; % g' 32 | dh=winderi(h,1); 33 | [hrow,hcol]=size(h); Lh=(hrow-1)/2; 34 | 35 | N=xrow; 36 | t=1:xrow; 37 | 38 | [trow,tcol] = size(t); 39 | 40 | 41 | tfr1= zeros (N,tcol) ; 42 | tfr2= zeros (N,tcol) ; 43 | 44 | 45 | Ts= zeros (nfft,tcol) ; 46 | % Pad the signal vector x 47 | if rem(hlength,2)==1 48 | xp = [zeros((hlength-1)/2,1) ; x ; zeros((hlength-1)/2,1)]; 49 | else 50 | xp = [zeros((hlength)/2,1) ; x ; zeros((hlength-2)/2,1)]; 51 | end 52 | xin = stft_rearrage(xp,length(xp),hlength,hlength-1,1); 53 | 54 | % Compute the STFT 55 | [tfr1,fout] = PerformDFT(bsxfun(@times,h,xin),nfft,1); 56 | tfr2 = PerformDFT(bsxfun(@times,dh,xin),nfft,1); 57 | 58 | m = floor(hlength/2); 59 | inds = 0:nfft-1; 60 | ez = exp(-1i*2*pi*m*inds/nfft)'; 61 | sstout = bsxfun(@times,tfr1,ez); 62 | 63 | ft = 1:nfft; 64 | bt = 1:N; 65 | 66 | %%operator omega 67 | nb = length(bt); 68 | neta = length(ft); 69 | 70 | omega = zeros (nfft,tcol); 71 | 72 | fout=0:1/nfft:1-1/nfft; 73 | for b=1:nb 74 | fcorr=-imag(tfr2(ft,b)./tfr1(ft,b)); 75 | fcorr(~isfinite(fcorr)) = 0; 76 | omega(:,b) =1+ mod(round((fout'+fcorr)*(nfft-1)),nfft); 77 | % omega(:,b)=ft'-imag(tfr2(ft,b)./tfr1(ft,b)); 78 | end 79 | % omega=round(omega); 80 | 81 | for b=1:nb%time 82 | % Reassignment step 83 | for eta=1:neta%frequency 84 | % if sstout(eta,b)>0.000001 85 | k = omega(eta,b); 86 | if k>=1 && k<=neta 87 | Ts(k,b) = Ts(k,b) + sstout(eta,b); 88 | end 89 | % end 90 | 
end 91 | end 92 | Ts=Ts/(xrow/2); 93 | end -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/TFNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import TFNet_util as util 3 | import numpy as np 4 | import h5py 5 | import os 6 | import scipy.io as sio 7 | 8 | class Trainer: 9 | 10 | def __init__(self,data_path): 11 | skip_path = os.path.join('model', 'TFA-Net.pth') 12 | device = torch.device('cpu' if torch.cuda.is_available() else 'cpu') 13 | self.skip_module, _, _, _, _ = util.load(skip_path, 'layer1',device) 14 | self.skip_module.cpu() 15 | self.skip_module.eval() 16 | self.data_path=data_path 17 | 18 | 19 | def inference(self): 20 | path = os.path.join(self.data_path, 'matlab_real2.h5') 21 | f = h5py.File(path, 'r') 22 | real_data2 = f['matlab_real2'][:] 23 | f.close() 24 | path = os.path.join(self.data_path, 'matlab_imag2.h5') 25 | f = h5py.File(path, 'r') 26 | imag_data2 = f['matlab_imag2'][:] 27 | f.close() 28 | path = os.path.join(self.data_path, 'bz.h5') 29 | f = h5py.File(path, 'r') 30 | bz = f['bz'][:] 31 | f.close() 32 | 33 | N=256 34 | signal_50dB2 = np.zeros([int(bz), 2, N]).astype(np.float32) 35 | signal_50dB2[:, 0,:] = (real_data2.astype(np.float32)).T 36 | signal_50dB2[:, 1,:] = (imag_data2.astype(np.float32)).T 37 | 38 | 39 | with torch.no_grad(): 40 | fr_50dB2 = self.skip_module(torch.tensor(signal_50dB2)) 41 | fr_50dB2 = fr_50dB2.cpu().data.numpy() 42 | path = os.path.join(self.data_path, 'data1_resfreq.mat') 43 | sio.savemat(path, {'data1_resfreq':fr_50dB2}) 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/TFNet_model_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import TFNet_util as util 3 | import numpy as np 4 | import h5py 5 | import os 6 | import scipy.io as sio 7 
| 8 | class Trainer: 9 | 10 | def __init__(self,data_path): 11 | skip_path = os.path.join('model', 'RED-Net.pth') 12 | device = torch.device('cpu' if torch.cuda.is_available() else 'cpu') 13 | self.skip_module, _, _, _, _ = util.load(skip_path, 'layer1',device) 14 | self.skip_module.cpu() 15 | self.skip_module.eval() 16 | self.data_path=data_path 17 | 18 | 19 | def inference(self): 20 | path = os.path.join(self.data_path, 'matlab_real2.h5') 21 | f = h5py.File(path, 'r') 22 | real_data2 = f['matlab_real2'][:] 23 | f.close() 24 | path = os.path.join(self.data_path, 'matlab_imag2.h5') 25 | f = h5py.File(path, 'r') 26 | imag_data2 = f['matlab_imag2'][:] 27 | f.close() 28 | path = os.path.join(self.data_path, 'bz.h5') 29 | f = h5py.File(path, 'r') 30 | bz = f['bz'][:] 31 | f.close() 32 | print(real_data2.shape) 33 | 34 | N=256 35 | signal_50dB2 = np.zeros([int(bz), 8, N//2,N]).astype(np.float32) 36 | signal_50dB2[:, 0:4,:,:] = (real_data2.astype(np.float32)).transpose((3,2,1,0)) 37 | signal_50dB2[:, 4:8,:,:] = (imag_data2.astype(np.float32)).transpose((3,2,1,0)) 38 | 39 | with torch.no_grad(): 40 | fr_50dB2 = self.skip_module(torch.tensor(signal_50dB2)) 41 | fr_50dB2 = fr_50dB2.cpu().data.numpy() 42 | path = os.path.join(self.data_path, 'data1_resfreq.mat') 43 | sio.savemat(path, {'data1_resfreq':fr_50dB2}) 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/TFNet_server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | from TFNet_model_stft import Trainer 3 | from TFNet_model import Trainer 4 | 5 | import os 6 | app = Flask(__name__) 7 | 8 | @app.route('/') 9 | def getLabel(): 10 | pm.inference() 11 | return "ok" 12 | 13 | if __name__ == '__main__': 14 | para_dict={} 15 | f = open("./config/config.ini", encoding="utf-8") 16 | contents=f.read().splitlines() 17 | for line in contents: 18 | tmp=line.split('=') 19 | 
para_dict[tmp[0]]=tmp[1] 20 | f.close() 21 | if not os.path.exists(para_dict['data_path']): # 如果路径不存在 22 | os.makedirs(para_dict['data_path']) 23 | pm = Trainer(para_dict['data_path']) 24 | app.run(host='127.0.0.1', port=para_dict['server_port'], debug=False) 25 | 26 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/TFNet_server.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | 3 | block_cipher = None 4 | SETUP_DIR = '.\\' 5 | 6 | a = Analysis(['TFNet_server.py', 7 | 'complexModule.py', 8 | 'complexLayers.py', 9 | 'complexFunctions.py', 10 | 'TFNet_model.py', 11 | 'TFNet_util.py'], 12 | pathex=['.\\'], 13 | binaries=[], 14 | datas=[(SETUP_DIR+'model','model'),(SETUP_DIR+'config','config')], 15 | hiddenimports=[], 16 | hookspath=[], 17 | runtime_hooks=[], 18 | excludes=[], 19 | win_no_prefer_redirects=False, 20 | win_private_assemblies=False, 21 | cipher=block_cipher, 22 | noarchive=False) 23 | pyz = PYZ(a.pure, a.zipped_data, 24 | cipher=block_cipher) 25 | exe = EXE(pyz, 26 | a.scripts, 27 | [], 28 | exclude_binaries=True, 29 | name='TFNet_server', 30 | debug=False, 31 | bootloader_ignore_signals=False, 32 | strip=False, 33 | upx=True, 34 | console=True ) 35 | coll = COLLECT(exe, 36 | a.binaries, 37 | a.zipfiles, 38 | a.datas, 39 | strip=False, 40 | upx=True, 41 | upx_exclude=[], 42 | name='TFNet') 43 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/TFNet_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import errno 4 | import complexModule 5 | 6 | def model_parameters(model): 7 | num_params = 0 8 | for param in model.parameters(): 9 | num_params += param.numel() 10 | return num_params 11 | 12 | 13 | def symlink_force(target, link_name): 14 | try: 15 | os.symlink(target, link_name) 16 | 
except OSError as e: 17 | if e.errno == errno.EEXIST: 18 | os.remove(link_name) 19 | os.symlink(target, link_name) 20 | else: 21 | raise e 22 | 23 | 24 | def save(model, optimizer, scheduler, args, epoch, module_type): 25 | checkpoint = { 26 | 'epoch': epoch, 27 | 'model': model.state_dict(), 28 | 'optimizer': optimizer.state_dict(), 29 | 'scheduler': scheduler.state_dict(), 30 | 'args': args, 31 | } 32 | if scheduler is not None: 33 | checkpoint["scheduler"] = scheduler.state_dict() 34 | if not os.path.exists(os.path.join(args.output_dir, module_type)): 35 | os.makedirs(os.path.join(args.output_dir, module_type)) 36 | cp = os.path.join(args.output_dir, module_type, 'last.pth') 37 | fn = os.path.join(args.output_dir, module_type, 'epoch_'+str(epoch)+'.pth') 38 | torch.save(checkpoint, fn) 39 | symlink_force(fn, cp) 40 | 41 | 42 | def load(fn, module_type, device = torch.device('cuda')): 43 | checkpoint = torch.load(fn, map_location=device) 44 | args = checkpoint['args'] 45 | if device == torch.device('cpu'): 46 | args.use_cuda = False 47 | if module_type == 'skip': 48 | model = complexModule.set_skip_module(args) 49 | elif module_type == 'layer1': 50 | model = complexModule.set_layer1_module(args) 51 | else: 52 | raise NotImplementedError('Module type not recognized') 53 | model.load_state_dict(checkpoint['model']) 54 | optimizer, scheduler = set_optim(args, model, module_type) 55 | if checkpoint["scheduler"] is not None: 56 | scheduler.load_state_dict(checkpoint["scheduler"]) 57 | optimizer.load_state_dict(checkpoint["optimizer"]) 58 | return model, optimizer, scheduler, args, checkpoint['epoch'] 59 | 60 | 61 | def set_optim(args, module, module_type): 62 | if module_type == 'fr': 63 | # params = list(module.parameters()) + list(adaptive.parameters()) 64 | optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fr) 65 | elif module_type == 'fc': 66 | optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fc) 67 | elif module_type == 'skip': 68 | 
optimizer = torch.optim.RMSprop(module.parameters(), lr=args.lr_fr, alpha=0.9) 69 | # optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fr) 70 | elif module_type == 'rdn': 71 | optimizer = torch.optim.RMSprop(module.parameters(), lr=args.lr_fr, alpha=0.9) 72 | # optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fr) 73 | elif module_type == 'freq': 74 | optimizer = torch.optim.RMSprop(module.parameters(), lr=args.lr_fr, alpha=0.9) 75 | # optimizer = torch.optim.Adam(module.parameters(), lr=args.lr_fr) 76 | elif module_type == 'deepfreqRes': 77 | optimizer = torch.optim.RMSprop(module.parameters(), lr=args.lr_fr, alpha=0.9) 78 | elif module_type == 'layer1': 79 | optimizer = torch.optim.RMSprop(module.parameters(), lr=args.lr_fr, alpha=0.9) 80 | 81 | else: 82 | raise(ValueError('Expected module_type to be fr or fc but got {}'.format(module_type))) 83 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=7, factor=0.5, verbose=True) 84 | return optimizer, scheduler 85 | 86 | 87 | def print_args(logger, args): 88 | message = '' 89 | for k, v in sorted(vars(args).items()): 90 | message += '\n{:>30}: {:<30}'.format(str(k), str(v)) 91 | logger.info(message) 92 | 93 | args_path = os.path.join(args.output_dir, 'run.args') 94 | with open(args_path, 'wt') as args_file: 95 | args_file.write(message) 96 | args_file.write('\n') 97 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/TFNet_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/TFNet_model.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/TFNet_model_stft.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/TFNet_model_stft.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/TFNet_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/TFNet_util.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/complexFunctions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/complexFunctions.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/complexLayers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/complexLayers.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/__pycache__/complexModule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/__pycache__/complexModule.cpython-36.pyc -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/complexFunctions.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @author: spopoff 6 | """ 7 | 8 | from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d 9 | import torch 10 | 11 | 12 | def complex_matmul(A, B): 13 | ''' 14 | Performs the matrix product between two complex matrices 15 | ''' 16 | 17 | outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag) 18 | outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real) 19 | 20 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) 21 | 22 | 23 | def complex_avg_pool2d(input, *args, **kwargs): 24 | ''' 25 | Perform complex average pooling. 26 | ''' 27 | absolute_value_real = avg_pool2d(input.real, *args, **kwargs) 28 | absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs) 29 | 30 | return absolute_value_real.type(torch.complex64) + 1j * absolute_value_imag.type(torch.complex64) 31 | 32 | 33 | def complex_relu(input): 34 | return relu(input.real).type(torch.complex64) + 1j * relu(input.imag).type(torch.complex64) 35 | 36 | 37 | def _retrieve_elements_from_indices(tensor, indices): 38 | flattened_tensor = tensor.flatten(start_dim=-2) 39 | output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices) 40 | return output 41 | 42 | 43 | def complex_max_pool2d(input, kernel_size, stride=None, padding=0, 44 | dilation=1, ceil_mode=False, return_indices=False): 45 | ''' 46 | Perform complex max pooling by selecting on the absolute value on the complex values. 
47 | ''' 48 | absolute_value, indices = max_pool2d( 49 | input.abs(), 50 | kernel_size=kernel_size, 51 | stride=stride, 52 | padding=padding, 53 | dilation=dilation, 54 | ceil_mode=ceil_mode, 55 | return_indices=True 56 | ) 57 | # performs the selection on the absolute values 58 | absolute_value = absolute_value.type(torch.complex64) 59 | # retrieve the corresonding phase value using the indices 60 | # unfortunately, the derivative for 'angle' is not implemented 61 | angle = torch.atan2(input.imag, input.real) 62 | # get only the phase values selected by max pool 63 | angle = _retrieve_elements_from_indices(angle, indices) 64 | return absolute_value \ 65 | * (torch.cos(angle).type(torch.complex64) + 1j * torch.sin(angle).type(torch.complex64)) 66 | 67 | 68 | def complex_dropout(input, p=0.5, training=True): 69 | # need to have the same dropout mask for real and imaginary part, 70 | # this not a clean solution! 71 | mask = torch.ones_like(input).type(torch.float32) 72 | mask = dropout(mask, p, training) * 1 / (1 - p) 73 | return mask * input 74 | 75 | 76 | def complex_dropout2d(input, p=0.5, training=True): 77 | # need to have the same dropout mask for real and imaginary part, 78 | # this not a clean solution! 
79 | mask = torch.ones_like(input).type(torch.float32) 80 | mask = dropout2d(mask, p, training) * 1 / (1 - p) 81 | return mask * input -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/complexModule.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from complexLayers import ComplexConv1d 4 | 5 | 6 | 7 | def set_layer1_module(args): 8 | """ 9 | Create a frequency-representation module 10 | """ 11 | net = None 12 | if args.fr_module_type == 'fr': 13 | assert args.fr_size == args.fr_inner_dim * args.fr_upsampling, \ 14 | 'The desired size of the frequency representation (fr_size) must be equal to inner_dim*upsampling' 15 | net = FrequencyRepresentationModule_TFA_Net(signal_dim=args.signal_dim, n_filters=args.fr_n_filters, 16 | inner_dim=args.fr_inner_dim, n_layers=args.fr_n_layers, 17 | upsampling=args.fr_upsampling, kernel_size=args.fr_kernel_size, 18 | kernel_out=args.fr_kernel_out) 19 | # net = FrequencyRepresentationModule_RED_Net(signal_dim=args.signal_dim, n_filters=args.fr_n_filters, 20 | # inner_dim=args.fr_inner_dim, n_layers=args.fr_n_layers, 21 | # upsampling=args.fr_upsampling, kernel_size=args.fr_kernel_size, 22 | # kernel_out=args.fr_kernel_out) 23 | 24 | else: 25 | raise NotImplementedError('Frequency representation module type not implemented') 26 | if args.use_cuda: 27 | net.cuda() 28 | return net 29 | 30 | import math 31 | class REDNet30_stft(nn.Module): 32 | def __init__(self, num_layers=15, num_features=8): 33 | super(REDNet30_stft, self).__init__() 34 | self.num_layers = num_layers 35 | 36 | conv_layers = [] 37 | deconv_layers = [] 38 | 39 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1), 40 | nn.ReLU(inplace=True))) 41 | for i in range(num_layers - 1): 42 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features*2, num_features*2, kernel_size=3, 
padding=1), 43 | nn.ReLU(inplace=True))) 44 | 45 | for i in range(num_layers - 1): 46 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features*2, num_features*2, kernel_size=3, padding=1), 47 | nn.ReLU(inplace=True))) 48 | deconv_layers.append(nn.ConvTranspose2d(num_features*2, num_features, kernel_size=3, stride=2, padding=1, output_padding=1)) 49 | 50 | self.conv_layers = nn.Sequential(*conv_layers) 51 | self.deconv_layers = nn.Sequential(*deconv_layers) 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | conv_feats = [] 58 | for i in range(self.num_layers): 59 | x = self.conv_layers[i](x) 60 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 61 | conv_feats.append(x) 62 | 63 | conv_feats_idx = 0 64 | for i in range(self.num_layers): 65 | x = self.deconv_layers[i](x) 66 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 67 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 68 | conv_feats_idx += 1 69 | x = x + conv_feat 70 | x = self.relu(x) 71 | 72 | x += residual 73 | x = self.relu(x) 74 | 75 | return x 76 | class REDNet30(nn.Module): 77 | def __init__(self, num_layers=15, num_features=8): 78 | super(REDNet30, self).__init__() 79 | self.num_layers = num_layers 80 | 81 | conv_layers = [] 82 | deconv_layers = [] 83 | 84 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1), 85 | nn.ReLU(inplace=True))) 86 | for i in range(num_layers - 1): 87 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features*2, num_features*2, kernel_size=3, padding=1), 88 | nn.ReLU(inplace=True))) 89 | 90 | for i in range(num_layers - 1): 91 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features*2, num_features*2, kernel_size=3, padding=1), 92 | nn.ReLU(inplace=True))) 93 | deconv_layers.append(nn.ConvTranspose2d(num_features*2, num_features, kernel_size=3, stride=2, padding=1, output_padding=1)) 94 | 95 | 
self.conv_layers = nn.Sequential(*conv_layers) 96 | self.deconv_layers = nn.Sequential(*deconv_layers) 97 | self.relu = nn.ReLU(inplace=True) 98 | 99 | def forward(self, x): 100 | residual = x 101 | 102 | conv_feats = [] 103 | for i in range(self.num_layers): 104 | x = self.conv_layers[i](x) 105 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 106 | conv_feats.append(x) 107 | 108 | conv_feats_idx = 0 109 | for i in range(self.num_layers): 110 | x = self.deconv_layers[i](x) 111 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 112 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 113 | conv_feats_idx += 1 114 | x = x + conv_feat 115 | x = self.relu(x) 116 | 117 | x += residual 118 | x = self.relu(x) 119 | 120 | return x 121 | 122 | class FrequencyRepresentationModule_TFA_Net(nn.Module): 123 | def __init__(self, signal_dim=50, n_filters=8, n_layers=3, inner_dim=125, 124 | kernel_size=3, upsampling=2, kernel_out=3): 125 | super().__init__() 126 | 127 | self.n_filters = n_filters 128 | self.inner = inner_dim 129 | self.n_layers = n_layers 130 | # self.dnoise1=ComplexConv1d2(1,n_filters,kernel_size=(1,5),padding=(0,5//2)) 131 | # self.dnoise2 = ComplexConv1d(1, n_filters, kernel_size=(1, 7), padding=(0, 7// 2)) 132 | # self.concat_layer=ComplexConv1d(2*n_filters,n_filters,kernel_size=(1,1)) 133 | self.in_layer1 = ComplexConv1d(1, inner_dim * n_filters, kernel_size=(1, 31), padding=(0, 31 // 2), 134 | bias=False) 135 | # self.win = ComplexConv2d(n_filters, n_filters, kernel_size=(127, 1), padding=(127 // 2, 0), 136 | # bias=False,padding_mode='circular') 137 | self.rednet=REDNet30(self.n_layers,num_features=n_filters) 138 | self.out_layer = nn.ConvTranspose2d(n_filters, 1, (3, 1), stride=(upsampling, 1), 139 | padding=(1, 0), output_padding=(1, 0), bias=False) 140 | # self.out_layer2=nn.ReLU() 141 | # self.out_layer3=nn.Conv2d(n_filters//2, 1, (3, 3),padding=(1, 1)) 142 | 143 | def forward(self, x): 144 | bsz = 
x.size(0) 145 | inp_real = x[:, 0, :].view(bsz, 1, 1, -1) 146 | inp_imag = x[:, 1, :].view(bsz, 1, 1, -1) 147 | inp = torch.cat((inp_real, inp_imag), 1) 148 | # x1=self.dnoise1(inp) 149 | # x1real,x1imag=torch.chunk(x1,2,1) 150 | # x2 = self.dnoise2(inp) 151 | # x2real, x2imag = torch.chunk(x2, 2, 1) 152 | # x=torch.cat((x1real,x2real,x1imag,x2imag),1) 153 | # x=self.concat_layer(x) 154 | x1 = self.in_layer1(inp) 155 | xreal,ximag=torch.chunk(x1,2,1) 156 | xreal = xreal.view(bsz, self.n_filters, self.inner, -1) 157 | ximag = ximag.view(bsz, self.n_filters, self.inner, -1) 158 | # x = torch.cat((xreal, ximag), 1) 159 | # x=self.win(x) 160 | # xreal,ximag=torch.chunk(x,2,1) 161 | x=torch.sqrt(torch.pow(xreal, 2) + torch.pow(ximag, 2)) 162 | x=self.rednet(x) 163 | 164 | x = self.out_layer(x).squeeze(-3).transpose(1, 2) 165 | # x=self.out_layer2(x) 166 | # x=self.out_layer3(x) 167 | return x 168 | 169 | class FrequencyRepresentationModule_RED_Net(nn.Module): 170 | def __init__(self, signal_dim=50, n_filters=8, n_layers=3, inner_dim=125, 171 | kernel_size=3, upsampling=2, kernel_out=3): 172 | super().__init__() 173 | 174 | self.n_filters = n_filters 175 | self.inner = inner_dim 176 | self.n_layers = n_layers 177 | self.in_layer=nn.Conv2d(4, n_filters, kernel_size=1) 178 | self.rednet=REDNet30_stft(self.n_layers,num_features=n_filters) 179 | self.out_layer = nn.ConvTranspose2d(n_filters, 1, (3, 1), stride=(upsampling, 1), 180 | padding=(1, 0), output_padding=(1, 0), bias=False) 181 | 182 | # self.out_layer2=nn.ReLU() 183 | # self.out_layer3=nn.Conv2d(n_filters//2, 1, (3, 3),padding=(1, 1)) 184 | 185 | def forward(self, x): 186 | bsz = x.size(0) 187 | xreal = x[:, 0, :].view(bsz, 1, 1, -1) 188 | ximag = x[:, 1, :].view(bsz, 1, 1, -1) 189 | 190 | xreal2 = x[:, 2, :].view(bsz, 1, 1, -1) 191 | ximag2 = x[:, 3, :].view(bsz, 1, 1, -1) 192 | 193 | xreal3 = x[:, 4, :].view(bsz, 1, 1, -1) 194 | ximag3 = x[:, 5, :].view(bsz, 1, 1, -1) 195 | 196 | xreal4 = x[:, 6, :].view(bsz, 1, 
1, -1) 197 | ximag4 = x[:, 7, :].view(bsz, 1, 1, -1) 198 | 199 | xreal = xreal.view(bsz, 1, self.inner, -1) 200 | ximag = ximag.view(bsz, 1, self.inner, -1) 201 | 202 | xreal2 = xreal2.view(bsz, 1, self.inner, -1) 203 | ximag2 = ximag2.view(bsz, 1, self.inner, -1) 204 | 205 | xreal3 = xreal3.view(bsz, 1, self.inner, -1) 206 | ximag3 = ximag3.view(bsz, 1, self.inner, -1) 207 | 208 | xreal4 = xreal4.view(bsz, 1, self.inner, -1) 209 | ximag4 = ximag4.view(bsz, 1, self.inner, -1) 210 | 211 | x=torch.sqrt(torch.pow(xreal, 2) + torch.pow(ximag, 2)) 212 | x2 = torch.sqrt(torch.pow(xreal2, 2) + torch.pow(ximag2, 2)) 213 | x3 = torch.sqrt(torch.pow(xreal3, 2) + torch.pow(ximag3, 2)) 214 | x4 = torch.sqrt(torch.pow(xreal4, 2) + torch.pow(ximag4, 2)) 215 | x=torch.cat((x,x2,x3,x4),1) 216 | x=self.in_layer(x) 217 | x=self.rednet(x) 218 | x = self.out_layer(x).squeeze(-3).transpose(1, 2) 219 | # x=self.out_layer2(x) 220 | # x=self.out_layer3(x) 221 | return x 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/config/config.ini: -------------------------------------------------------------------------------- 1 | data_path=../ 2 | server_port=5012 3 | -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/model/RED_Net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/model/RED_Net.pth -------------------------------------------------------------------------------- /TFA-Net_inference/TFA_exe/model/TFA-Net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/TFA_exe/model/TFA-Net.pth 
-------------------------------------------------------------------------------- /TFA-Net_inference/batdata2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/batdata2.mat -------------------------------------------------------------------------------- /TFA-Net_inference/computation_complexity.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | N=256; 5 | nfft=256; 6 | fsz=16; 7 | winlen=64; 8 | divide_times=1/2; 9 | times_times=2; 10 | x11=0.3; x22=2.7; 11 | y11=-1; y22=20; 12 | 13 | x1=0.45+5.555;x2=4.4+5.555; 14 | y1=-48.5;y2=-4; 15 | %% SIMU1 16 | fs=128; 17 | FS=fs; 18 | ts=1/fs; 19 | t = 0:ts:2-ts; 20 | Sig1 = exp(1i*2*pi*(8* t + 6 *sin(t))); % get the A(t) 21 | Sig2 = exp(1i*2*pi*(10 * t + 6 *sin(1.5*t) )); % get the A(t) 22 | Sig=Sig1+Sig2; 23 | data_reshape=Sig.'; 24 | 25 | ITER=1000; 26 | time_STFT=zeros(1,ITER); 27 | time_SST=zeros(1,ITER); 28 | time_SET=zeros(1,ITER); 29 | time_MSST=zeros(1,ITER); 30 | time_TFA=zeros(1,ITER); 31 | time_RS=zeros(1,ITER); 32 | time_WVD=zeros(1,ITER); 33 | time_G=zeros(1,ITER); 34 | for i=1:ITER 35 | i 36 | %% STFT 37 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 38 | tic 39 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 40 | aa=toc; 41 | time_STFT(i)=aa; 42 | %% SST 43 | tic 44 | spc_SST = SST2(data_reshape,winlen); 45 | aa=toc; 46 | time_SST(i)=aa; 47 | %% SET 48 | tic 49 | [spc_SET,~] = SET_Y2(data_reshape,winlen); 50 | aa=toc; 51 | time_SET(i)=aa; 52 | %% MSST 53 | tic 54 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 55 | aa=toc; 56 | time_MSST(i)=aa; 57 | %% ResFreq 58 | data_rsh=data_reshape.'; 59 | mv=max(abs(data_rsh)); 60 | ret=data_rsh/mv; 61 | bz=size(ret,1); 62 | 63 | if 
~exist('matlab_real2.h5','file')==0 64 | delete('matlab_real2.h5') 65 | end 66 | if ~exist('matlab_imag2.h5','file')==0 67 | delete('matlab_imag2.h5') 68 | end 69 | if ~exist('bz.h5','file')==0 70 | delete('bz.h5') 71 | end 72 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 73 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 74 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 75 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 76 | h5create('bz.h5','/bz',size(bz)); 77 | h5write('bz.h5','/bz',bz); 78 | tic 79 | net_flag=system('curl -s 127.0.0.1:5012/'); 80 | aa=toc; 81 | time_TFA(i)=aa; 82 | %% GWarblet 83 | tic 84 | [Spec,f] = GWarblet(data_reshape,FS,0,1,nfft,winlen); 85 | [v l] = max(Spec,[],1); 86 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 87 | WinLen =winlen*4; 88 | [Spec,f] = GWarblet(data_reshape,FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 89 | aa=toc; 90 | time_G(i)=aa; 91 | %% WVD 92 | tic 93 | Spec=wvd(data_reshape,'smoothedPseudo','NumFrequencyPoints',nfft); 94 | aa=toc; 95 | time_WVD(i)=aa; 96 | %% RS 97 | tic 98 | spc_RS = RS(data_reshape,winlen); 99 | aa=toc; 100 | time_RS(i)=aa; 101 | end 102 | 103 | time=[mean(time_STFT),mean(time_WVD),mean(time_G),mean(time_RS),mean(time_SST),mean(time_SET),mean(time_MSST),mean(time_TFA)]; 104 | save time_cost.mat time 105 | -------------------------------------------------------------------------------- /TFA-Net_inference/data1_resfreq.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/data1_resfreq.mat -------------------------------------------------------------------------------- /TFA-Net_inference/dianji_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/dianji_data.mat 
-------------------------------------------------------------------------------- /TFA-Net_inference/exp1_simu.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | N=256; 5 | nfft=256; 6 | fsz=23; 7 | winlen=64; 8 | divide_times=1/2; 9 | times_times=2; 10 | x11=0.3; x22=2.89; 11 | y11=-1; y22=20; 12 | 13 | ylim_max=30; 14 | x1=5.9;x2=4.4+5.5; 15 | y1=-49.5;y2=-10; 16 | %% SIMU1 17 | fs=100; 18 | ts=1/fs; 19 | t = 0:ts:10-ts; 20 | Sig1 = exp(1i*2*pi*(8* t + 6 *sin(t))); % get the A(t) 21 | Sig2 = exp(1i*2*pi*(10 * t + 6 *sin(1.5*t) )); % get the A(t) 22 | Sig=Sig1+Sig2; 23 | data_reshape=Sig.'; 24 | ydelta=fs/nfft; 25 | yaxis=(0:ydelta:fs-ydelta)-fs/2; 26 | if1=(8 + 6*1 *cos(1*t)) ; 27 | if2=(10 + 6*1.5 *cos(1.5*t)) ; 28 | 29 | figure; 30 | plot(t,if1) 31 | hold on 32 | plot(t,if2) 33 | ylim([min(yaxis) ylim_max]) 34 | set(gca,'ydir','reverse') 35 | xlabel({'Time / sec';'(a)'}) 36 | ylabel('Doppler / Hz') 37 | title('Ground Truth') 38 | set(gca,'FontSize',fsz); 39 | set(get(gca,'XLabel'),'FontSize',fsz); 40 | set(get(gca,'YLabel'),'FontSize',fsz); 41 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_Ground_Truth'); 42 | saveas(gcf, fname); 43 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_Ground_Truth','.pdf'); 44 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 45 | 46 | %% STFT 47 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 48 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 49 | data_reshape2=[zeros(winlen*divide_times/2,1);data_reshape;zeros(winlen*divide_times/2-1,1)]; 50 | spc_STFT2=abs(stft(data_reshape2,'Window',hamming(winlen*divide_times).','OverlapLength',winlen*divide_times-1,'FFTLength',nfft)); 51 | data_reshape3=[zeros(winlen*times_times/2,1);data_reshape;zeros(winlen*times_times/2-1,1)]; 52 | 
spc_STFT3=abs(stft(data_reshape3,'Window',hamming(winlen*times_times).','OverlapLength',winlen*times_times-1,'FFTLength',nfft)); 53 | tt=t; 54 | figure; 55 | imagesc(tt,yaxis,((abs(spc_STFT2)))); 56 | ylim([min(yaxis) ylim_max]) 57 | title(strcat('STFT, winlen=',num2str(winlen*divide_times))) 58 | xlabel({'Time / sec';'(c)'}) 59 | ylabel('Freq. / Hz') 60 | set(gca,'FontSize',fsz); 61 | set(get(gca,'XLabel'),'FontSize',fsz); 62 | set(get(gca,'YLabel'),'FontSize',fsz); 63 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 64 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 65 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 66 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 67 | axes('Position',[0.61,0.6,0.284,0.271]); 68 | imagesc(tt,yaxis,(abs(spc_STFT2))) 69 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 70 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen*divide_times)); 71 | saveas(gcf, fname); 72 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen*divide_times),'.pdf'); 73 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 74 | 75 | figure; 76 | imagesc(tt,yaxis,((abs(spc_STFT3)))); 77 | ylim([min(yaxis) ylim_max]) 78 | title(strcat('STFT, winlen=',num2str(winlen*times_times))) 79 | xlabel({'Time / sec';'(d)'}) 80 | ylabel('Freq. 
/ Hz') 81 | set(gca,'FontSize',fsz); 82 | set(get(gca,'XLabel'),'FontSize',fsz); 83 | set(get(gca,'YLabel'),'FontSize',fsz); 84 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 85 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 86 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 87 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 88 | axes('Position',[0.61,0.6,0.284,0.271]); 89 | imagesc(tt,yaxis,(abs(spc_STFT3))) 90 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 91 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen*times_times)); 92 | saveas(gcf, fname); 93 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen*times_times),'.pdf'); 94 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 95 | 96 | figure; 97 | imagesc(tt,yaxis,((abs(spc_STFT)))); 98 | ylim([min(yaxis) ylim_max]) 99 | title(strcat('STFT, winlen=',num2str(winlen))) 100 | xlabel({'Time / sec';'(e)'}) 101 | ylabel('Freq. 
/ Hz') 102 | set(gca,'FontSize',fsz); 103 | set(get(gca,'XLabel'),'FontSize',fsz); 104 | set(get(gca,'YLabel'),'FontSize',fsz); 105 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 106 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 107 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 108 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 109 | axes('Position',[0.61,0.6,0.284,0.271]); 110 | imagesc(tt,yaxis,(abs(spc_STFT))) 111 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 112 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen)); 113 | saveas(gcf, fname); 114 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_STFT_win',num2str(winlen),'.pdf'); 115 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 116 | %% SST 117 | spc_SST = SST2(data_reshape,winlen); 118 | spc_SST2 = SST2(data_reshape,winlen*divide_times); 119 | spc_SST3 = SST2(data_reshape,winlen*times_times); 120 | spc_SST=abs(spc_SST); 121 | 122 | tt=t; 123 | figure; 124 | imagesc(tt,yaxis,(fftshift(abs(spc_SST2),1))); 125 | ylim([min(yaxis) ylim_max]) 126 | xlabel({'Time / sec';'(f)'}) 127 | ylabel('Freq. 
/ Hz') 128 | title(strcat('SST, winlen=',num2str(winlen*divide_times))) 129 | set(gca,'FontSize',fsz); 130 | set(get(gca,'XLabel'),'FontSize',fsz); 131 | set(get(gca,'YLabel'),'FontSize',fsz); 132 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 133 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 134 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 135 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 136 | axes('Position',[0.61,0.6,0.284,0.271]); 137 | imagesc(tt,yaxis,(fftshift(abs(spc_SST2),1))); 138 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 139 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen*divide_times)); 140 | saveas(gcf, fname); 141 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen*divide_times),'.pdf'); 142 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 143 | 144 | figure; 145 | imagesc(tt,yaxis,(fftshift(abs(spc_SST3),1))); 146 | ylim([min(yaxis) ylim_max]) 147 | xlabel({'Time / sec';'(g)'}) 148 | ylabel('Freq. 
/ Hz') 149 | title(strcat('SST, winlen=',num2str(winlen*times_times))) 150 | set(gca,'FontSize',fsz); 151 | set(get(gca,'XLabel'),'FontSize',fsz); 152 | set(get(gca,'YLabel'),'FontSize',fsz); 153 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 154 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 155 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 156 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 157 | axes('Position',[0.61,0.6,0.284,0.271]); 158 | imagesc(tt,yaxis,(fftshift(abs(spc_SST3),1))); 159 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 160 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen*times_times)); 161 | saveas(gcf, fname); 162 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen*times_times),'.pdf'); 163 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 164 | 165 | figure; 166 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 167 | ylim([min(yaxis) ylim_max]) 168 | xlabel({'Time / sec';'(h)'}) 169 | ylabel('Freq. 
/ Hz') 170 | title(strcat('SST, winlen=',num2str(winlen))) 171 | set(gca,'FontSize',fsz); 172 | set(get(gca,'XLabel'),'FontSize',fsz); 173 | set(get(gca,'YLabel'),'FontSize',fsz); 174 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 175 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 176 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 177 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 178 | axes('Position',[0.61,0.6,0.284,0.271]); 179 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 180 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 181 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen)); 182 | saveas(gcf, fname); 183 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SST_win',num2str(winlen),'.pdf'); 184 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 185 | 186 | %% SET 187 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 188 | [spc_SET2,tfr] = SET_Y2(data_reshape,winlen*divide_times); 189 | [spc_SET3,tfr] = SET_Y2(data_reshape,winlen*times_times); 190 | spc_SET=abs(spc_SET); 191 | tt=t; 192 | 193 | figure; 194 | imagesc(tt,yaxis,(fftshift(abs(spc_SET2),1))); 195 | ylim([min(yaxis) ylim_max]) 196 | title(strcat('SET, winlen=',num2str(winlen*divide_times))) 197 | xlabel({'Time / sec';'(i)'}) 198 | ylabel('Freq. 
/ Hz') 199 | set(gca,'FontSize',fsz); 200 | set(get(gca,'XLabel'),'FontSize',fsz); 201 | set(get(gca,'YLabel'),'FontSize',fsz); 202 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 203 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 204 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 205 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 206 | axes('Position',[0.61,0.6,0.284,0.271]); 207 | imagesc(tt,yaxis,(fftshift(abs(spc_SET2),1))); 208 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 209 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen*divide_times)); 210 | saveas(gcf, fname); 211 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen*divide_times),'.pdf'); 212 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 213 | 214 | figure; 215 | imagesc(tt,yaxis,(fftshift(abs(spc_SET3),1))); 216 | ylim([min(yaxis) ylim_max]) 217 | title(strcat('SET, winlen=',num2str(winlen*times_times))) 218 | xlabel({'Time / sec';'(j)'}) 219 | ylabel('Freq. 
/ Hz') 220 | set(gca,'FontSize',fsz); 221 | set(get(gca,'XLabel'),'FontSize',fsz); 222 | set(get(gca,'YLabel'),'FontSize',fsz); 223 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 224 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 225 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 226 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 227 | axes('Position',[0.61,0.6,0.284,0.271]); 228 | imagesc(tt,yaxis,(fftshift(abs(spc_SET3),1))); 229 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 230 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen*times_times)); 231 | saveas(gcf, fname); 232 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen*times_times),'.pdf'); 233 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 234 | 235 | figure; 236 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 237 | ylim([min(yaxis) ylim_max]) 238 | title(strcat('SET, winlen=',num2str(winlen))) 239 | xlabel({'Time / sec';'(k)'}) 240 | ylabel('Freq. 
/ Hz') 241 | set(gca,'FontSize',fsz); 242 | set(get(gca,'XLabel'),'FontSize',fsz); 243 | set(get(gca,'YLabel'),'FontSize',fsz); 244 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 245 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 246 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 247 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 248 | axes('Position',[0.61,0.6,0.284,0.271]); 249 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 250 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 251 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen)); 252 | saveas(gcf, fname); 253 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_SET_win',num2str(winlen),'.pdf'); 254 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 255 | 256 | %% MSST 257 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 258 | [spc_MSST2,tfr,omega2] = MSST_Y_new2(data_reshape,winlen*divide_times,3); 259 | [spc_MSST3,tfr,omega2] = MSST_Y_new2(data_reshape,winlen*times_times,3); 260 | spc_MSST=spc_MSST1; 261 | 262 | tt=t; 263 | figure; 264 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST2),1))); 265 | ylim([min(yaxis) ylim_max]) 266 | title(strcat('MSST, winlen=',num2str(winlen*divide_times))) 267 | xlabel({'Time / sec';'(l)'}) 268 | ylabel('Freq. 
/ Hz') 269 | set(gca,'FontSize',fsz); 270 | set(get(gca,'XLabel'),'FontSize',fsz); 271 | set(get(gca,'YLabel'),'FontSize',fsz); 272 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 273 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 274 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 275 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 276 | axes('Position',[0.61,0.6,0.284,0.271]); 277 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST2),1))); 278 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 279 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen*divide_times)); 280 | saveas(gcf, fname); 281 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen*divide_times),'.pdf'); 282 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 283 | figure; 284 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST3),1))); 285 | ylim([min(yaxis) ylim_max]) 286 | title(strcat('MSST, winlen=',num2str(winlen*times_times))) 287 | xlabel({'Time / sec';'(m)'}) 288 | ylabel('Freq. 
/ Hz') 289 | set(gca,'FontSize',fsz); 290 | set(get(gca,'XLabel'),'FontSize',fsz); 291 | set(get(gca,'YLabel'),'FontSize',fsz); 292 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 293 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 294 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 295 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 296 | axes('Position',[0.61,0.6,0.284,0.271]); 297 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST3),1))); 298 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 299 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen*times_times)); 300 | saveas(gcf, fname); 301 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen*times_times),'.pdf'); 302 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 303 | 304 | figure; 305 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 306 | ylim([min(yaxis) ylim_max]) 307 | title(strcat('MSST, winlen=',num2str(winlen))) 308 | xlabel({'Time / sec';'(n)'}) 309 | ylabel('Freq. 
/ Hz') 310 | set(gca,'FontSize',fsz); 311 | set(get(gca,'XLabel'),'FontSize',fsz); 312 | set(get(gca,'YLabel'),'FontSize',fsz); 313 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 314 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 315 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 316 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 317 | axes('Position',[0.61,0.6,0.284,0.271]); 318 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 319 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 320 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen)); 321 | saveas(gcf, fname); 322 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_MSST_win',num2str(winlen),'.pdf'); 323 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 324 | 325 | %% ResFreq 326 | data_rsh=data_reshape.'; 327 | mv=max(abs(data_rsh)); 328 | noisedSig=data_rsh/mv; 329 | 330 | siglen=256; 331 | ret=[]; 332 | sp=1; 333 | ep=sp+siglen-1; 334 | overlay=16; 335 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 342 | flag=1; 343 | left_len=length(noisedSig)-sp+1; 344 | ret=[ret;noisedSig(end-siglen+1:end)]; 345 | end 346 | 347 | bz=size(ret,1); 348 | 349 | if ~exist('matlab_real2.h5','file')==0 350 | delete('matlab_real2.h5') 351 | end 352 | if ~exist('matlab_imag2.h5','file')==0 353 | delete('matlab_imag2.h5') 354 | end 355 | if ~exist('bz.h5','file')==0 356 | delete('bz.h5') 357 | end 358 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 359 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 360 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 361 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 362 | h5create('bz.h5','/bz',size(bz)); 363 | h5write('bz.h5','/bz',bz); 364 | 365 | net_flag=system('curl -s 127.0.0.1:5012/'); 366 | load data1_resfreq.mat 367 
| ret=[]; 368 | 369 | if flag==1 370 | iter=bz-1; 371 | else 372 | iter=bz; 373 | end 374 | for i=1:iter 375 | if i==1 376 | tmp=squeeze(data1_resfreq(i,:,:)); 377 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 378 | else 379 | tmp=squeeze(data1_resfreq(i,:,:)); 380 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 381 | end 382 | end 383 | if flag==1 384 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 385 | tmp=tmp(end-left_len+1:end,:); 386 | ret=[ret;tmp(overlay/2+1:end,:)]; 387 | end 388 | figure; 389 | ret=(10.^(ret/20)-1)/10; 390 | imagesc(t,yaxis,fftshift(abs(ret.'),1)) 391 | ylim([min(yaxis) ylim_max]) 392 | xlabel({'Time / sec';'(b)'}) 393 | ylabel('Freq. / Hz') 394 | title('TFA-Net') 395 | set(gca,'FontSize',fsz); 396 | set(get(gca,'XLabel'),'FontSize',fsz); 397 | set(get(gca,'YLabel'),'FontSize',fsz); 398 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 399 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 400 | line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 401 | line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 402 | axes('Position',[0.61,0.6,0.284,0.271]); 403 | imagesc(tt,yaxis,fftshift(abs(ret.'),1)); 404 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 405 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_TFA-Net'); 406 | saveas(gcf, fname); 407 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu_TFA-Net','.pdf'); 408 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 409 | 410 | -------------------------------------------------------------------------------- /TFA-Net_inference/exp1_simu2.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | N=256; 5 | nfft=256; 6 | fsz=23; 7 | winlen=32; 8 | divide_times=1/2; 9 | times_times=2; 10 | x11=1.7; x22=2.5; 11 | y11=-15; y22=15; 12 | 13 | x1=0.1;x2=1.29; 14 | 
y1=25.5;y2=62.7; 15 | 16 | ylim_low=-20; 17 | %% SIMU1 18 | fs=128; 19 | ts=1/fs; 20 | t = 0 : ts : 3-ts; 21 | 22 | c1 = 2 * pi * 10; % initial frequency of the chirp excitation 23 | c2 = 2 * pi * 2/2; % set the speed of frequency change be 1 Hz/second 24 | c3 = 2 * pi * 1/10; 25 | c4 = 2 * pi * -2/2; 26 | 27 | d1 = -2 * pi * 10; % initial frequency of the chirp excitation 28 | d2 = 2 * pi * 2/2; % set the speed of frequency change be 1 Hz/second 29 | d3 = 2 * pi * 1/10; 30 | d4 = 2 * pi * 2/2; 31 | 32 | e1 = 2 * pi * 0; % initial frequency of the chirp excitation 33 | e2 = 2 * pi * 2; % set the speed of frequency change be 1 Hz/second 34 | e3 = 2 * pi * (-0.07); 35 | e4 = 2 * pi * (-2.5); 36 | 37 | Sig1 = exp(1i*(c1 * t + c4 * t.^4 /4)); % get the A(t) 38 | Sig2 = exp(1i*(d1 * t + d4 * t.^4 /4)); % get the A(t) 39 | Sig3 = 5*exp(1i*(e1 * t + e4 *sin(2.5*pi*t) )); % get the A(t) 40 | Sig=Sig1+Sig2+Sig3; 41 | data_reshape=Sig.'; 42 | data_reshape=awgn(data_reshape,10); 43 | ydelta=fs/nfft; 44 | yaxis=(0:ydelta:fs-ydelta)-fs/2; 45 | if1=(c1+c4 * t.^3)/2/pi; 46 | if2=(d1+ d4 * t.^3)/2/pi; 47 | if3=(e1+ e4 *cos(2.5*pi*t)*2.5*pi)/2/pi ; 48 | figure 49 | plot(t,if1) 50 | hold on 51 | plot(t,if2) 52 | hold on 53 | plot(t,if3) 54 | ylim([ylim_low max(yaxis)]) 55 | xlabel({'Time / sec';'(a)'}) 56 | ylabel('Doppler / Hz') 57 | title('Ground Truth') 58 | set(gca,'FontSize',fsz); 59 | set(get(gca,'XLabel'),'FontSize',fsz); 60 | set(get(gca,'YLabel'),'FontSize',fsz); 61 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_Ground_Truth'); 62 | saveas(gcf, fname); 63 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_Ground_Truth','.pdf'); 64 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 65 | 66 | %% STFT 67 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 68 | 
spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 69 | data_reshape2=[zeros(winlen*divide_times/2,1);data_reshape;zeros(winlen*divide_times/2-1,1)]; 70 | spc_STFT2=abs(stft(data_reshape2,'Window',hamming(winlen*divide_times).','OverlapLength',winlen*divide_times-1,'FFTLength',nfft)); 71 | data_reshape3=[zeros(winlen*times_times/2,1);data_reshape;zeros(winlen*times_times/2-1,1)]; 72 | spc_STFT3=abs(stft(data_reshape3,'Window',hamming(winlen*times_times).','OverlapLength',winlen*times_times-1,'FFTLength',nfft)); 73 | tt=t; 74 | figure; 75 | imagesc(tt,yaxis,((abs(spc_STFT2)))); 76 | ylim([ylim_low max(yaxis)]) 77 | set(gca,'ydir','normal') 78 | title(strcat('STFT, winlen=',num2str(winlen*divide_times))) 79 | xlabel({'Time / sec';'(c)'}) 80 | ylabel('Freq. / Hz') 81 | set(gca,'FontSize',fsz); 82 | set(get(gca,'XLabel'),'FontSize',fsz); 83 | set(get(gca,'YLabel'),'FontSize',fsz); 84 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 85 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 86 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 87 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 88 | axes('Position',[0.2,0.62,0.285,0.25]); 89 | imagesc(tt,yaxis,(abs(spc_STFT2))) 90 | set(gca,'ydir','normal') 91 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 92 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen*divide_times)); 93 | saveas(gcf, fname); 94 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen*divide_times),'.pdf'); 95 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 96 | 97 | figure; 98 | imagesc(tt,yaxis,((abs(spc_STFT3)))); 99 | ylim([ylim_low max(yaxis)]) 100 | set(gca,'ydir','normal') 101 | title(strcat('STFT, winlen=',num2str(winlen*times_times))) 102 | 
xlabel({'Time / sec';'(d)'}) 103 | ylabel('Freq. / Hz') 104 | set(gca,'FontSize',fsz); 105 | set(get(gca,'XLabel'),'FontSize',fsz); 106 | set(get(gca,'YLabel'),'FontSize',fsz); 107 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 108 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 109 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 110 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 111 | axes('Position',[0.2,0.62,0.285,0.25]); 112 | imagesc(tt,yaxis,(abs(spc_STFT3))) 113 | set(gca,'ydir','normal') 114 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 115 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen*times_times)); 116 | saveas(gcf, fname); 117 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen*times_times),'.pdf'); 118 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 119 | 120 | figure; 121 | imagesc(tt,yaxis,((abs(spc_STFT)))); 122 | ylim([ylim_low max(yaxis)]) 123 | set(gca,'ydir','normal') 124 | title(strcat('STFT, winlen=',num2str(winlen))) 125 | xlabel({'Time / sec';'(e)'}) 126 | ylabel('Freq. 
/ Hz') 127 | set(gca,'FontSize',fsz); 128 | set(get(gca,'XLabel'),'FontSize',fsz); 129 | set(get(gca,'YLabel'),'FontSize',fsz); 130 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 131 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 132 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 133 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 134 | axes('Position',[0.2,0.62,0.285,0.25]); 135 | imagesc(tt,yaxis,(abs(spc_STFT))) 136 | set(gca,'ydir','normal') 137 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 138 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen)); 139 | saveas(gcf, fname); 140 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_STFT_win',num2str(winlen),'.pdf'); 141 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 142 | %% SST 143 | spc_SST = SST2(data_reshape,winlen); 144 | spc_SST2 = SST2(data_reshape,winlen*divide_times); 145 | spc_SST3 = SST2(data_reshape,winlen*times_times); 146 | spc_SST=abs(spc_SST); 147 | 148 | tt=t; 149 | figure; 150 | imagesc(tt,yaxis,(fftshift(abs(spc_SST2),1))); 151 | ylim([ylim_low max(yaxis)]) 152 | set(gca,'ydir','normal') 153 | xlabel({'Time / sec';'(f)'}) 154 | ylabel('Freq. 
/ Hz') 155 | title(strcat('SST, winlen=',num2str(winlen*divide_times))) 156 | set(gca,'FontSize',fsz); 157 | set(get(gca,'XLabel'),'FontSize',fsz); 158 | set(get(gca,'YLabel'),'FontSize',fsz); 159 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 160 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 161 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 162 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 163 | axes('Position',[0.2,0.62,0.285,0.25]); 164 | imagesc(tt,yaxis,(fftshift(abs(spc_SST2),1))); 165 | set(gca,'ydir','normal') 166 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 167 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen*divide_times)); 168 | saveas(gcf, fname); 169 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen*divide_times),'.pdf'); 170 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 171 | 172 | figure; 173 | imagesc(tt,yaxis,(fftshift(abs(spc_SST3),1))); 174 | ylim([ylim_low max(yaxis)]) 175 | set(gca,'ydir','normal') 176 | xlabel({'Time / sec';'(g)'}) 177 | ylabel('Freq. 
/ Hz') 178 | title(strcat('SST, winlen=',num2str(winlen*times_times))) 179 | set(gca,'FontSize',fsz); 180 | set(get(gca,'XLabel'),'FontSize',fsz); 181 | set(get(gca,'YLabel'),'FontSize',fsz); 182 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 183 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 184 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 185 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 186 | axes('Position',[0.2,0.62,0.285,0.25]); 187 | imagesc(tt,yaxis,(fftshift(abs(spc_SST3),1))); 188 | set(gca,'ydir','normal') 189 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 190 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen*times_times)); 191 | saveas(gcf, fname); 192 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen*times_times),'.pdf'); 193 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 194 | 195 | figure; 196 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 197 | ylim([ylim_low max(yaxis)]) 198 | set(gca,'ydir','normal') 199 | xlabel({'Time / sec';'(h)'}) 200 | ylabel('Freq. 
/ Hz') 201 | title(strcat('SST, winlen=',num2str(winlen))) 202 | set(gca,'FontSize',fsz); 203 | set(get(gca,'XLabel'),'FontSize',fsz); 204 | set(get(gca,'YLabel'),'FontSize',fsz); 205 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 206 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 207 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 208 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 209 | axes('Position',[0.2,0.62,0.285,0.25]); 210 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 211 | set(gca,'ydir','normal') 212 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 213 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen)); 214 | saveas(gcf, fname); 215 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SST_win',num2str(winlen),'.pdf'); 216 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 217 | 218 | %% SET 219 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 220 | [spc_SET2,tfr] = SET_Y2(data_reshape,winlen*divide_times); 221 | [spc_SET3,tfr] = SET_Y2(data_reshape,winlen*times_times); 222 | spc_SET=abs(spc_SET); 223 | 224 | tt=t; 225 | 226 | figure; 227 | imagesc(tt,yaxis,(fftshift(abs(spc_SET2),1))); 228 | ylim([ylim_low max(yaxis)]) 229 | set(gca,'ydir','normal') 230 | title(strcat('SET, winlen=',num2str(winlen*divide_times))) 231 | xlabel({'Time / sec';'(i)'}) 232 | ylabel('Freq. 
/ Hz') 233 | set(gca,'FontSize',fsz); 234 | set(get(gca,'XLabel'),'FontSize',fsz); 235 | set(get(gca,'YLabel'),'FontSize',fsz); 236 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 237 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 238 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 239 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 240 | axes('Position',[0.2,0.62,0.285,0.25]); 241 | imagesc(tt,yaxis,(fftshift(abs(spc_SET2),1))); 242 | set(gca,'ydir','normal') 243 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 244 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen*divide_times)); 245 | saveas(gcf, fname); 246 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen*divide_times),'.pdf'); 247 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 248 | 249 | figure; 250 | imagesc(tt,yaxis,(fftshift(abs(spc_SET3),1))); 251 | ylim([ylim_low max(yaxis)]) 252 | title(strcat('SET, winlen=',num2str(winlen*times_times))) 253 | set(gca,'ydir','normal') 254 | xlabel({'Time / sec';'(j)'}) 255 | ylabel('Freq. 
/ Hz') 256 | set(gca,'FontSize',fsz); 257 | set(get(gca,'XLabel'),'FontSize',fsz); 258 | set(get(gca,'YLabel'),'FontSize',fsz); 259 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 260 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 261 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 262 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 263 | axes('Position',[0.2,0.62,0.285,0.25]); 264 | imagesc(tt,yaxis,(fftshift(abs(spc_SET3),1))); 265 | set(gca,'ydir','normal') 266 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 267 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen*times_times)); 268 | saveas(gcf, fname); 269 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen*times_times),'.pdf'); 270 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 271 | 272 | figure; 273 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 274 | ylim([ylim_low max(yaxis)]) 275 | title(strcat('SET, winlen=',num2str(winlen))) 276 | set(gca,'ydir','normal') 277 | xlabel({'Time / sec';'(k)'}) 278 | ylabel('Freq. 
/ Hz') 279 | set(gca,'FontSize',fsz); 280 | set(get(gca,'XLabel'),'FontSize',fsz); 281 | set(get(gca,'YLabel'),'FontSize',fsz); 282 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 283 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 284 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 285 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 286 | axes('Position',[0.2,0.62,0.285,0.25]); 287 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 288 | set(gca,'ydir','normal') 289 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 290 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen)); 291 | saveas(gcf, fname); 292 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_SET_win',num2str(winlen),'.pdf'); 293 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 294 | 295 | %% MSST 296 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 297 | [spc_MSST2,tfr,omega2] = MSST_Y_new2(data_reshape,winlen*divide_times,3); 298 | [spc_MSST3,tfr,omega2] = MSST_Y_new2(data_reshape,winlen*times_times,3); 299 | spc_MSST=spc_MSST1; 300 | 301 | tt=t; 302 | figure; 303 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST2),1))); 304 | ylim([ylim_low max(yaxis)]) 305 | set(gca,'ydir','normal') 306 | title(strcat('MSST, winlen=',num2str(winlen*divide_times))) 307 | xlabel({'Time / sec';'(l)'}) 308 | ylabel('Freq. 
/ Hz') 309 | set(gca,'FontSize',fsz); 310 | set(get(gca,'XLabel'),'FontSize',fsz); 311 | set(get(gca,'YLabel'),'FontSize',fsz); 312 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 313 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 314 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 315 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 316 | axes('Position',[0.2,0.62,0.285,0.25]); 317 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST2),1))); 318 | set(gca,'ydir','normal') 319 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 320 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen*divide_times)); 321 | saveas(gcf, fname); 322 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen*divide_times),'.pdf'); 323 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 324 | figure; 325 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST3),1))); 326 | ylim([ylim_low max(yaxis)]) 327 | set(gca,'ydir','normal') 328 | title(strcat('MSST, winlen=',num2str(winlen*times_times))) 329 | xlabel({'Time / sec';'(m)'}) 330 | ylabel('Freq. 
/ Hz') 331 | set(gca,'FontSize',fsz); 332 | set(get(gca,'XLabel'),'FontSize',fsz); 333 | set(get(gca,'YLabel'),'FontSize',fsz); 334 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 335 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 336 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 337 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 338 | axes('Position',[0.2,0.62,0.285,0.25]); 339 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST3),1))); 340 | set(gca,'ydir','normal') 341 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 342 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen*times_times)); 343 | saveas(gcf, fname); 344 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen*times_times),'.pdf'); 345 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 346 | 347 | figure; 348 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 349 | ylim([ylim_low max(yaxis)]) 350 | set(gca,'ydir','normal') 351 | title(strcat('MSST, winlen=',num2str(winlen))) 352 | xlabel({'Time / sec';'(n)'}) 353 | ylabel('Freq. 
/ Hz') 354 | set(gca,'FontSize',fsz); 355 | set(get(gca,'XLabel'),'FontSize',fsz); 356 | set(get(gca,'YLabel'),'FontSize',fsz); 357 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 358 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 359 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 360 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 361 | axes('Position',[0.2,0.62,0.285,0.25]); 362 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 363 | set(gca,'ydir','normal') 364 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 365 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen)); 366 | saveas(gcf, fname); 367 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_MSST_win',num2str(winlen),'.pdf'); 368 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 369 | 370 | %% ResFreq 371 | data_rsh=data_reshape.'; 372 | mv=max(abs(data_rsh)); 373 | noisedSig=data_rsh/mv; 374 | 375 | siglen=256; 376 | ret=[]; 377 | sp=1; 378 | ep=sp+siglen-1; 379 | overlay=30; 380 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 387 | flag=1; 388 | left_len=length(noisedSig)-sp+1; 389 | ret=[ret;noisedSig(end-siglen+1:end)]; 390 | end 391 | 392 | bz=size(ret,1); 393 | 394 | if ~exist('matlab_real2.h5','file')==0 395 | delete('matlab_real2.h5') 396 | end 397 | if ~exist('matlab_imag2.h5','file')==0 398 | delete('matlab_imag2.h5') 399 | end 400 | if ~exist('bz.h5','file')==0 401 | delete('bz.h5') 402 | end 403 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 404 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 405 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 406 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 407 | h5create('bz.h5','/bz',size(bz)); 408 | h5write('bz.h5','/bz',bz); 409 | 410 | net_flag=system('curl -s 127.0.0.1:5012/'); 
411 | load data1_resfreq.mat 412 | ret=[]; 413 | 414 | if flag==1 415 | iter=bz-1; 416 | else 417 | iter=bz; 418 | end 419 | for i=1:iter 420 | if i==1 421 | tmp=squeeze(data1_resfreq(i,:,:)); 422 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 423 | else 424 | tmp=squeeze(data1_resfreq(i,:,:)); 425 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 426 | end 427 | end 428 | if flag==1 429 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 430 | tmp=tmp(end-left_len+1:end,:); 431 | ret=[ret;tmp(overlay/2+1:end,:)]; 432 | end 433 | figure; 434 | ret=(10.^(ret/20)-1)/10; 435 | imagesc(t,yaxis,fftshift(abs(ret.'),1)) 436 | ylim([ylim_low max(yaxis)]) 437 | set(gca,'ydir','normal') 438 | xlabel({'Time / sec';'(b)'}) 439 | ylabel('Freq. / Hz') 440 | title('TFA-Net') 441 | set(gca,'FontSize',fsz); 442 | set(get(gca,'XLabel'),'FontSize',fsz); 443 | set(get(gca,'YLabel'),'FontSize',fsz); 444 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 445 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 446 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 447 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 448 | axes('Position',[0.2,0.62,0.285,0.25]); 449 | imagesc(tt,yaxis,fftshift(abs(ret.'),1)); 450 | set(gca,'ydir','normal') 451 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 452 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_TFA-Net'); 453 | saveas(gcf, fname); 454 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp1_simu2_TFA-Net','.pdf'); 455 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 456 | 457 | -------------------------------------------------------------------------------- /TFA-Net_inference/exp2_bat.m: -------------------------------------------------------------------------------- 1 | 2 | clc 3 | clear 4 | close all 5 | nfft=256; 6 | fsz=16; 7 | clims_lim=30; 8 | 9 | 10 | 
load('batdata2.mat'); 11 | SampFreq = 1000000/7; 12 | FS=SampFreq; 13 | n=length(data); 14 | data_reshape=data; 15 | time=(1:n)/SampFreq; 16 | t=time*1000; 17 | ydelta=FS/nfft; 18 | yaxis=((0:ydelta:FS-ydelta)-FS/2)/1000; 19 | winlen=48; 20 | ylim_low=0; 21 | 22 | x11=14.7; x22=16.9; 23 | y11=0; y22=20; 24 | 25 | x1=19.1;x2=23.4; 26 | y1=23.6;y2=43; 27 | 28 | %% RS 29 | spc_RS = RS(data_reshape,winlen); 30 | spc_RS=abs(spc_RS); 31 | tt=t; 32 | figure; 33 | clims = [max(max(20*log10((abs(spc_RS)))))-clims_lim,max(max(20*log10((abs(spc_RS)))))]; 34 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 35 | % set(gca,'ydir','normal') 36 | ylim([ylim_low max(yaxis)]) 37 | title(strcat('RS, winlen=',num2str(winlen))) 38 | xlabel({'Time / sec';'(d)'}) 39 | ylabel('Freq. / kHz') 40 | set(gca,'FontSize',fsz); 41 | set(get(gca,'XLabel'),'FontSize',fsz); 42 | set(get(gca,'YLabel'),'FontSize',fsz); 43 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_RS_win',num2str(winlen)); 44 | saveas(gcf, fname); 45 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen),'.pdf'); 46 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 47 | 48 | 49 | %% GWarblet 50 | 51 | [Spec,f] = GWarblet_complex(data_reshape,FS,0,1,nfft,winlen); 52 | [v l] = max(Spec,[],1); 53 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 54 | WinLen =winlen*4; 55 | [Spec,f] = GWarblet_complex(data_reshape,FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 56 | 57 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,0,1,nfft,winlen); 58 | [v l] = max(Spec2,[],1); 59 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 60 | WinLen =winlen*4; 61 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 62 | Spec=[flipud(Spec2);Spec]; 63 | 64 | figure; 65 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 66 | 
imagesc(t,yaxis,20*log10((abs(Spec))),clims); 67 | % set(gca,'ydir','normal') 68 | ylim([ylim_low max(yaxis)]) 69 | title(strcat('GWarblet')) 70 | xlabel({'Time / sec';'(c)'}) 71 | ylabel('Freq. / kHz') 72 | set(gca,'FontSize',fsz); 73 | set(get(gca,'XLabel'),'FontSize',fsz); 74 | set(get(gca,'YLabel'),'FontSize',fsz); 75 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_GWarblet'); 76 | saveas(gcf, fname); 77 | 78 | %% WVD 79 | yaxis1=(0:ydelta:FS/2-ydelta)/1000; 80 | Spec=wvd(data_reshape,'smoothedPseudo','NumFrequencyPoints',nfft); 81 | figure; 82 | clims = [max(max(20*log10((abs(Spec)))))-40,max(max(20*log10((abs(Spec)))))]; 83 | imagesc(t,yaxis1,20*log10((abs(Spec))),clims); 84 | % set(gca,'ydir','normal') 85 | % ylim([ylim_low max(yaxis)]) 86 | title(strcat('SPWVD')) 87 | xlabel({'Time / sec';'(b)'}) 88 | ylabel('Freq. / kHz') 89 | set(gca,'FontSize',fsz); 90 | set(get(gca,'XLabel'),'FontSize',fsz); 91 | set(get(gca,'YLabel'),'FontSize',fsz); 92 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_SPWVD'); 93 | saveas(gcf, fname); 94 | 95 | %% cwt 96 | % Spec=cwt((data_reshape)); 97 | % Spec=flipud((Spec)); 98 | % figure; 99 | % clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 100 | % imagesc(t,yaxis,20*log10((abs(Spec))),clims); 101 | % % set(gca,'ydir','normal') 102 | % ylim([ylim_low max(yaxis)]) 103 | % title(strcat('CWT')) 104 | % xlabel({'Time / sec';'(c)'}) 105 | % ylabel('Freq. 
/ kHz') 106 | % set(gca,'FontSize',fsz); 107 | % set(get(gca,'XLabel'),'FontSize',fsz); 108 | % set(get(gca,'YLabel'),'FontSize',fsz); 109 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_CWT'); 110 | % saveas(gcf, fname); 111 | %% STFT 112 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 113 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 114 | tt=t; 115 | 116 | figure; 117 | clims = [max(max(20*log10((abs(spc_STFT)))))-clims_lim,max(max(20*log10((abs(spc_STFT)))))]; 118 | imagesc(tt,yaxis,20*log10((abs(spc_STFT))),clims); 119 | % set(gca,'ydir','normal') 120 | ylim([ylim_low max(yaxis)]) 121 | title(strcat('STFT, winlen=',num2str(winlen))) 122 | xlabel({'Time / sec';'(a)'}) 123 | ylabel('Freq. / kHz') 124 | set(gca,'FontSize',fsz); 125 | set(get(gca,'XLabel'),'FontSize',fsz); 126 | set(get(gca,'YLabel'),'FontSize',fsz); 127 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_STFT_win',num2str(winlen)); 128 | saveas(gcf, fname); 129 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_STFT_win',num2str(winlen),'.pdf'); 130 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 131 | %% SST 132 | spc_SST = SST2(data_reshape,winlen); 133 | spc_SST=abs(spc_SST); 134 | tt=t; 135 | figure; 136 | clims = [max(max(20*log10((abs(spc_SST)))))-clims_lim,max(max(20*log10((abs(spc_SST)))))]; 137 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 138 | % set(gca,'ydir','normal') 139 | ylim([ylim_low max(yaxis)]) 140 | xlabel({'Time / sec';'(e)'}) 141 | ylabel('Freq. 
/ kHz') 142 | title(strcat('SST, winlen=',num2str(winlen))) 143 | set(gca,'FontSize',fsz); 144 | set(get(gca,'XLabel'),'FontSize',fsz); 145 | set(get(gca,'YLabel'),'FontSize',fsz); 146 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_SST_win',num2str(winlen)); 147 | saveas(gcf, fname); 148 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SST_win',num2str(winlen),'.pdf'); 149 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 150 | 151 | %% SET 152 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 153 | spc_SET=abs(spc_SET); 154 | tt=t; 155 | figure; 156 | clims = [max(max(20*log10((abs(spc_SET)))))-clims_lim,max(max(20*log10((abs(spc_SET)))))]; 157 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 158 | % set(gca,'ydir','normal') 159 | ylim([ylim_low max(yaxis)]) 160 | title(strcat('SET, winlen=',num2str(winlen))) 161 | xlabel({'Time / sec';'(f)'}) 162 | ylabel('Freq. / kHz') 163 | set(gca,'FontSize',fsz); 164 | set(get(gca,'XLabel'),'FontSize',fsz); 165 | set(get(gca,'YLabel'),'FontSize',fsz); 166 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_SET_win',num2str(winlen)); 167 | saveas(gcf, fname); 168 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SET_win',num2str(winlen),'.pdf'); 169 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 170 | 171 | %% MSST 172 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 173 | spc_MSST=spc_MSST1; 174 | tt=t; 175 | figure; 176 | clims = [max(max(20*log10((abs(spc_MSST)))))-clims_lim,max(max(20*log10((abs(spc_MSST)))))]; 177 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 178 | % set(gca,'ydir','normal') 179 | ylim([ylim_low max(yaxis)]) 180 | title(strcat('MSST, winlen=',num2str(winlen))) 181 | xlabel({'Time / sec';'(g)'}) 182 | ylabel('Freq. 
/ kHz') 183 | set(gca,'FontSize',fsz); 184 | set(get(gca,'XLabel'),'FontSize',fsz); 185 | set(get(gca,'YLabel'),'FontSize',fsz); 186 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_MSST_win',num2str(winlen)); 187 | saveas(gcf, fname); 188 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen),'.pdf'); 189 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 190 | 191 | 192 | 193 | %% ResFreq 194 | data_rsh=data_reshape.'; 195 | mv=max(abs(data_rsh)); 196 | noisedSig=data_rsh/mv; 197 | 198 | siglen=256; 199 | ret=[]; 200 | sp=1; 201 | ep=sp+siglen-1; 202 | overlay=30; 203 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 210 | flag=1; 211 | left_len=length(noisedSig)-sp+1; 212 | ret=[ret;noisedSig(end-siglen+1:end)]; 213 | end 214 | 215 | bz=size(ret,1); 216 | 217 | if ~exist('matlab_real2.h5','file')==0 218 | delete('matlab_real2.h5') 219 | end 220 | if ~exist('matlab_imag2.h5','file')==0 221 | delete('matlab_imag2.h5') 222 | end 223 | if ~exist('bz.h5','file')==0 224 | delete('bz.h5') 225 | end 226 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 227 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 228 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 229 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 230 | h5create('bz.h5','/bz',size(bz)); 231 | h5write('bz.h5','/bz',bz); 232 | 233 | net_flag=system('curl -s 127.0.0.1:5012/'); 234 | load data1_resfreq.mat 235 | ret=[]; 236 | 237 | if flag==1 238 | iter=bz-1; 239 | else 240 | iter=bz; 241 | end 242 | for i=1:iter 243 | if i==1 244 | tmp=squeeze(data1_resfreq(i,:,:)); 245 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 246 | else 247 | tmp=squeeze(data1_resfreq(i,:,:)); 248 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 249 | end 250 | end 251 | if flag==1 252 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 253 | tmp=tmp(end-left_len+1:end,:); 254 | 
ret=[ret;tmp(overlay/2+1:end,:)]; 255 | end 256 | figure; 257 | ret=(10.^(ret/20)-1)/10; 258 | 259 | clims = [max(max(20*log10((abs(ret)))))-clims_lim,max(max(20*log10((abs(ret)))))]; 260 | imagesc(t,yaxis,(20*log10(fftshift(abs(ret.'),1))),clims) 261 | % % set(gca,'ydir','normal') 262 | ylim([ylim_low max(yaxis)]) 263 | xlabel({'Time / sec';'(i)'}) 264 | ylabel('Freq. / kHz') 265 | title('TFA-Net') 266 | set(gca,'FontSize',fsz); 267 | set(get(gca,'XLabel'),'FontSize',fsz); 268 | set(get(gca,'YLabel'),'FontSize',fsz); 269 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_bat_TFA-Net'); 270 | saveas(gcf, fname); 271 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_TFA-Net','.pdf'); 272 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); -------------------------------------------------------------------------------- /TFA-Net_inference/exp2_heart.m: -------------------------------------------------------------------------------- 1 | 2 | clc 3 | clear 4 | close all 5 | nfft=256; 6 | fsz=16; 7 | clims_lim=45; 8 | 9 | 10 | load ECGData.mat 11 | TS=1/128; 12 | sp=1600; 13 | ep=3000; 14 | data_reshape=Y(sp:ep).'; 15 | 16 | FS=1/TS; 17 | N=length(data_reshape); 18 | time=N*TS; 19 | t=sp*TS:TS:(ep+1)*TS-TS; 20 | ydelta=FS/nfft; 21 | yaxis=(0:ydelta:FS-ydelta)-FS/2; 22 | winlen=32; 23 | x11=14.7; x22=16.9; 24 | y11=0; y22=20; 25 | 26 | x1=19.1;x2=23.4; 27 | y1=23.6;y2=43; 28 | 29 | %% GT 30 | figure;plot(t,(data_reshape)); 31 | xlim([min(t) max(t)]) 32 | ylim([min(data_reshape) 2.5]) 33 | title('ECG time signal') 34 | xlabel({'Time / sec';'(a)'}) 35 | ylabel('Amp.') 36 | set(gca,'FontSize',fsz); 37 | set(get(gca,'XLabel'),'FontSize',fsz); 38 | set(get(gca,'YLabel'),'FontSize',fsz); 39 | rectangle('Position',[x11 -0.77 x22-x11 1.5],'EdgeColor','red','Linewidth',3); 40 | % rectangle('Position',[x1 0.81 x2-x1 1.8-0.81],'EdgeColor','red','Linewidth',3); 41 | % 
line([x11,x1],[y11,y1],'color','r','Linewidth',2,'LineStyle','-.') 42 | % line([x22,x1],[y11,y2],'color','r','Linewidth',2,'LineStyle','-.') 43 | axes('Position',[0.6,0.6,0.3,0.3]); 44 | 45 | plot(t,(data_reshape)); 46 | text(15.5,0.9,'R','Fontsize',13,'Color','r') 47 | text(15.31,-0.35,'Q','Fontsize',13,'Color','r') 48 | text(15.55,-0.705,'S','Fontsize',13,'Color','r') 49 | text(15.2,-0.1,'P','Fontsize',13,'Color','r') 50 | text(15.8,-0.2,'T','Fontsize',13,'Color','r') 51 | set(gca,'xlim',[x11 x22],'ylim',[-0.77 1]); 52 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_Ground_Truth'); 53 | saveas(gcf, fname); 54 | 55 | %% GWarblet 56 | [Spec,f] = GWarblet(data_reshape,FS,0,1,nfft,winlen); 57 | [v l] = max(Spec,[],1); 58 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 59 | WinLen =winlen*4; 60 | [Spec,f] = GWarblet(data_reshape,FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 61 | Spec=fftshift(Spec,1); 62 | figure; 63 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 64 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 65 | set(gca,'ydir','normal') 66 | ylim([0 max(yaxis)-20]) 67 | title(strcat('GWarblet')) 68 | xlabel({'Time / sec';'(d)'}) 69 | ylabel('Freq. 
/ Hz') 70 | set(gca,'FontSize',fsz); 71 | set(get(gca,'XLabel'),'FontSize',fsz); 72 | set(get(gca,'YLabel'),'FontSize',fsz); 73 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 74 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 75 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 76 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 77 | axes('Position',[0.6,0.6,0.3,0.3]); 78 | 79 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 80 | set(gca,'ydir','normal') 81 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 82 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_GWarblet'); 83 | saveas(gcf, fname); 84 | 85 | %% WVD 86 | Spec=wvd(data_reshape,'smoothedPseudo','NumFrequencyPoints',nfft); 87 | Spec=fftshift(Spec,1); 88 | figure; 89 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 90 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 91 | set(gca,'ydir','normal') 92 | ylim([0 max(yaxis)-20]) 93 | title(strcat('SPWVD')) 94 | xlabel({'Time / sec';'(c)'}) 95 | ylabel('Freq. 
/ Hz') 96 | set(gca,'FontSize',fsz); 97 | set(get(gca,'XLabel'),'FontSize',fsz); 98 | set(get(gca,'YLabel'),'FontSize',fsz); 99 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 100 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 101 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 102 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 103 | axes('Position',[0.6,0.6,0.3,0.3]); 104 | 105 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 106 | set(gca,'ydir','normal') 107 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 108 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SPWVD'); 109 | saveas(gcf, fname); 110 | 111 | %% cwt 112 | % Spec=cwt(data_reshape); 113 | % Spec=flipud(Spec); 114 | % figure; 115 | % clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 116 | % imagesc(t,yaxis,20*log10((abs(Spec))),clims); 117 | % set(gca,'ydir','normal') 118 | % ylim([0 max(yaxis)-20]) 119 | % title(strcat('CWT')) 120 | % xlabel({'Time / sec';'(d)'}) 121 | % ylabel('Freq. 
/ Hz') 122 | % set(gca,'FontSize',fsz); 123 | % set(get(gca,'XLabel'),'FontSize',fsz); 124 | % set(get(gca,'YLabel'),'FontSize',fsz); 125 | % rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 126 | % rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 127 | % line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 128 | % line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 129 | % axes('Position',[0.6,0.6,0.3,0.3]); 130 | % 131 | % imagesc(t,yaxis,20*log10((abs(Spec))),clims); 132 | % set(gca,'ydir','normal') 133 | % set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 134 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_CWT'); 135 | % saveas(gcf, fname); 136 | %% STFT 137 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 138 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 139 | tt=t; 140 | 141 | figure; 142 | clims = [max(max(20*log10((abs(spc_STFT)))))-clims_lim,max(max(20*log10((abs(spc_STFT)))))]; 143 | imagesc(tt,yaxis,20*log10((abs(spc_STFT))),clims); 144 | set(gca,'ydir','normal') 145 | ylim([0 max(yaxis)-20]) 146 | title(strcat('STFT, winlen=',num2str(winlen))) 147 | xlabel({'Time / sec';'(b)'}) 148 | ylabel('Freq. 
/ Hz') 149 | set(gca,'FontSize',fsz); 150 | set(get(gca,'XLabel'),'FontSize',fsz); 151 | set(get(gca,'YLabel'),'FontSize',fsz); 152 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 153 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 154 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 155 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 156 | axes('Position',[0.6,0.6,0.3,0.3]); 157 | 158 | imagesc(tt,yaxis,20*log10(abs(spc_STFT)),clims) 159 | set(gca,'ydir','normal') 160 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 161 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_STFT_win',num2str(winlen)); 162 | saveas(gcf, fname); 163 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_STFT_win',num2str(winlen),'.pdf'); 164 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 165 | %% SST 166 | spc_SST = SST2(data_reshape,winlen); 167 | spc_SST=abs(spc_SST); 168 | 169 | tt=t; 170 | figure; 171 | clims = [max(max(20*log10((abs(spc_SST)))))-70,max(max(20*log10((abs(spc_SST)))))]; 172 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 173 | set(gca,'ydir','normal') 174 | ylim([0 max(yaxis)-20]) 175 | xlabel({'Time / sec';'(e)'}) 176 | ylabel('Freq. 
/ Hz') 177 | title(strcat('SST, winlen=',num2str(winlen))) 178 | set(gca,'FontSize',fsz); 179 | set(get(gca,'XLabel'),'FontSize',fsz); 180 | set(get(gca,'YLabel'),'FontSize',fsz); 181 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 182 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 183 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 184 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 185 | axes('Position',[0.6,0.6,0.3,0.3]); 186 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 187 | set(gca,'ydir','normal') 188 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 189 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SST_win',num2str(winlen)); 190 | saveas(gcf, fname); 191 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SST_win',num2str(winlen),'.pdf'); 192 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 193 | 194 | %% SET 195 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 196 | spc_SET=abs(spc_SET); 197 | tt=t; 198 | figure; 199 | clims = [max(max(20*log10((abs(spc_SET)))))-90,max(max(20*log10((abs(spc_SET)))))]; 200 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 201 | set(gca,'ydir','normal') 202 | ylim([0 max(yaxis)-20]) 203 | title(strcat('SET, winlen=',num2str(winlen))) 204 | xlabel({'Time / sec';'(f)'}) 205 | ylabel('Freq. 
/ Hz') 206 | set(gca,'FontSize',fsz); 207 | set(get(gca,'XLabel'),'FontSize',fsz); 208 | set(get(gca,'YLabel'),'FontSize',fsz); 209 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 210 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 211 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 212 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 213 | axes('Position',[0.6,0.6,0.3,0.3]); 214 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 215 | set(gca,'ydir','normal') 216 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 217 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SET_win',num2str(winlen)); 218 | saveas(gcf, fname); 219 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SET_win',num2str(winlen),'.pdf'); 220 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 221 | 222 | %% MSST 223 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 224 | spc_MSST=spc_MSST1; 225 | tt=t; 226 | figure; 227 | clims = [max(max(20*log10((abs(spc_MSST)))))-90,max(max(20*log10((abs(spc_MSST)))))]; 228 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 229 | set(gca,'ydir','normal') 230 | ylim([0 max(yaxis)-20]) 231 | title(strcat('MSST, winlen=',num2str(winlen))) 232 | xlabel({'Time / sec';'(g)'}) 233 | ylabel('Freq. 
/ Hz') 234 | set(gca,'FontSize',fsz); 235 | set(get(gca,'XLabel'),'FontSize',fsz); 236 | set(get(gca,'YLabel'),'FontSize',fsz); 237 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 238 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 239 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 240 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 241 | axes('Position',[0.6,0.6,0.3,0.3]); 242 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 243 | set(gca,'ydir','normal') 244 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 245 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen)); 246 | saveas(gcf, fname); 247 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen),'.pdf'); 248 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 249 | 250 | %% ResFreq 251 | data_rsh=data_reshape.'; 252 | mv=max(abs(data_rsh)); 253 | noisedSig=data_rsh/mv; 254 | 255 | siglen=256; 256 | ret=[]; 257 | sp=1; 258 | ep=sp+siglen-1; 259 | overlay=16; 260 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 267 | flag=1; 268 | left_len=length(noisedSig)-sp+1; 269 | ret=[ret;noisedSig(end-siglen+1:end)]; 270 | end 271 | 272 | bz=size(ret,1); 273 | 274 | if ~exist('matlab_real2.h5','file')==0 275 | delete('matlab_real2.h5') 276 | end 277 | if ~exist('matlab_imag2.h5','file')==0 278 | delete('matlab_imag2.h5') 279 | end 280 | if ~exist('bz.h5','file')==0 281 | delete('bz.h5') 282 | end 283 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 284 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 285 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 286 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 287 | h5create('bz.h5','/bz',size(bz)); 288 | h5write('bz.h5','/bz',bz); 289 | 290 | net_flag=system('curl -s 
127.0.0.1:5012/'); 291 | load data1_resfreq.mat 292 | ret=[]; 293 | 294 | if flag==1 295 | iter=bz-1; 296 | else 297 | iter=bz; 298 | end 299 | for i=1:iter 300 | if i==1 301 | tmp=squeeze(data1_resfreq(i,:,:)); 302 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 303 | else 304 | tmp=squeeze(data1_resfreq(i,:,:)); 305 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 306 | end 307 | end 308 | if flag==1 309 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 310 | tmp=tmp(end-left_len+1:end,:); 311 | ret=[ret;tmp(overlay/2+1:end,:)]; 312 | end 313 | figure; 314 | ret=(10.^(ret/20)-1)/10; 315 | clims = [max(max(20*log10((abs(ret)))))-clims_lim,max(max(20*log10((abs(ret)))))]; 316 | imagesc(t,yaxis,20*log10(fftshift(abs(ret.'),1)),clims) 317 | set(gca,'ydir','normal') 318 | ylim([0 max(yaxis)-20]) 319 | xlabel({'Time / sec';'(h)'}) 320 | ylabel('Freq. / Hz') 321 | title('TFA-Net') 322 | set(gca,'FontSize',fsz); 323 | set(get(gca,'XLabel'),'FontSize',fsz); 324 | set(get(gca,'YLabel'),'FontSize',fsz); 325 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 326 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 327 | line([x11,x1],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 328 | line([x22,x1],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 329 | axes('Position',[0.6,0.6,0.3,0.3]); 330 | imagesc(tt,yaxis,20*log10(fftshift(abs(ret.'),1)),clims); 331 | set(gca,'ydir','normal') 332 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 333 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_TFA-Net'); 334 | saveas(gcf, fname); 335 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_TFA-Net','.pdf'); 336 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 337 | %% 338 | 339 | -------------------------------------------------------------------------------- /TFA-Net_inference/exp3_breheart.m: 
-------------------------------------------------------------------------------- 1 | 2 | clc 3 | clear 4 | close all 5 | nfft=256; 6 | fsz=16; 7 | clims_lim=50; 8 | 9 | 10 | load 1110_zzq_80_1_Raw_0_chan1.mat 11 | data_reshape=sig.'; 12 | sz=size(data_reshape); 13 | time=16.25; 14 | dt=time/sz(1); 15 | t=0:dt:time-dt; 16 | fs=1/dt; 17 | FS=fs; 18 | TS=dt; 19 | winlen=32; 20 | ydelta=fs/nfft; 21 | yaxis=(0:ydelta:fs-ydelta)-fs/2; 22 | 23 | 24 | x11=9; x22=11.7; 25 | y11=-18; y22=18; 26 | 27 | x1=0.55;x2=7; 28 | y1=11.8;y2=37.4; 29 | ylim_low=-19; 30 | 31 | %% RS 32 | spc_RS = RS(data_reshape,winlen); 33 | spc_RS=abs(spc_RS); 34 | spc_RS=flipud(spc_RS); 35 | tt=t; 36 | figure; 37 | clims = [max(max(20*log10((abs(spc_RS)))))-70,max(max(20*log10((abs(spc_RS)))))]; 38 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 39 | set(gca,'ydir','normal') 40 | ylim([ylim_low max(yaxis)]) 41 | title(strcat('RS, winlen=',num2str(winlen))) 42 | xlabel({'Time / sec';'(d)'}) 43 | ylabel('Freq. / Hz') 44 | set(gca,'FontSize',fsz); 45 | set(get(gca,'XLabel'),'FontSize',fsz); 46 | set(get(gca,'YLabel'),'FontSize',fsz); 47 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 48 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 49 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 50 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 51 | axes('Position',[0.16,0.6,0.3,0.3]); 52 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 53 | set(gca,'ydir','normal') 54 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 55 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_RS_win',num2str(winlen)); 56 | saveas(gcf, fname); 57 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen),'.pdf'); 58 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 59 | 60 | 61 | %% GWarblet 
62 | [Spec,f] = GWarblet_complex(data_reshape,FS,0,1,nfft,winlen); 63 | [v l] = max(Spec,[],1); 64 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 65 | WinLen =winlen*4; 66 | [Spec,f] = GWarblet_complex(data_reshape,FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 67 | 68 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,0,1,nfft,winlen); 69 | [v l] = max(Spec2,[],1); 70 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 71 | WinLen =winlen*4; 72 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 73 | Spec=[flipud(Spec2);Spec]; 74 | Spec=flipud(Spec); 75 | figure; 76 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 77 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 78 | set(gca,'ydir','normal') 79 | ylim([ylim_low max(yaxis)]) 80 | title(strcat('GWarblet')) 81 | xlabel({'Time / sec';'(c)'}) 82 | ylabel('Freq. / Hz') 83 | set(gca,'FontSize',fsz); 84 | set(get(gca,'XLabel'),'FontSize',fsz); 85 | set(get(gca,'YLabel'),'FontSize',fsz); 86 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 87 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 88 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 89 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 90 | axes('Position',[0.16,0.6,0.3,0.3]); 91 | 92 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 93 | set(gca,'ydir','normal') 94 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 95 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_GWarblet'); 96 | saveas(gcf, fname); 97 | 98 | %% WVD 99 | Spec=wvd(data_reshape,'smoothedPseudo'); 100 | Spec=fftshift(Spec,1); 101 | Spec=flipud(Spec); 102 | figure; 103 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 104 | imagesc(t,yaxis/2,20*log10((abs(Spec))),clims); 105 | set(gca,'ydir','normal') 106 | ylim([ylim_low max(yaxis)]) 107 | 
title(strcat('SPWVD')) 108 | xlabel({'Time / sec';'(b)'}) 109 | ylabel('Freq. / Hz') 110 | set(gca,'FontSize',fsz); 111 | set(get(gca,'XLabel'),'FontSize',fsz); 112 | set(get(gca,'YLabel'),'FontSize',fsz); 113 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 114 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 115 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 116 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 117 | axes('Position',[0.16,0.6,0.3,0.3]); 118 | 119 | imagesc(t,yaxis/2,20*log10((abs(Spec))),clims); 120 | set(gca,'ydir','normal') 121 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 122 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_SPWVD'); 123 | saveas(gcf, fname); 124 | 125 | %% cwt 126 | % Spec=cwt((data_reshape)); 127 | % Spec=([(fftshift(Spec(:,:,2),1));fftshift(flipud(Spec(:,:,1)),1)]); 128 | % Spec=flipud(Spec); 129 | % figure; 130 | % clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 131 | % imagesc(t,yaxis,20*log10((abs(Spec))),clims); 132 | % set(gca,'ydir','normal') 133 | % ylim([ylim_low max(yaxis)]) 134 | % title(strcat('CWT')) 135 | % xlabel({'Time / sec';'(c)'}) 136 | % ylabel('Freq. 
/ Hz') 137 | % set(gca,'FontSize',fsz); 138 | % set(get(gca,'XLabel'),'FontSize',fsz); 139 | % set(get(gca,'YLabel'),'FontSize',fsz); 140 | % rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 141 | % rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 142 | % line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 143 | % line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 144 | % axes('Position',[0.16,0.6,0.3,0.3]); 145 | % 146 | % imagesc(t,yaxis,20*log10((abs(Spec))),clims); 147 | % set(gca,'ydir','normal') 148 | % set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 149 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_CWT'); 150 | % saveas(gcf, fname); 151 | %% STFT 152 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 153 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 154 | spc_STFT=flipud(spc_STFT); 155 | tt=t; 156 | 157 | figure; 158 | clims = [max(max(20*log10((abs(spc_STFT)))))-clims_lim,max(max(20*log10((abs(spc_STFT)))))]; 159 | imagesc(tt,yaxis,20*log10((abs(spc_STFT))),clims); 160 | set(gca,'ydir','normal') 161 | ylim([ylim_low max(yaxis)]) 162 | title(strcat('STFT, winlen=',num2str(winlen))) 163 | xlabel({'Time / sec';'(a)'}) 164 | ylabel('Freq. 
/ Hz') 165 | set(gca,'FontSize',fsz); 166 | set(get(gca,'XLabel'),'FontSize',fsz); 167 | set(get(gca,'YLabel'),'FontSize',fsz); 168 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 169 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 170 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 171 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 172 | axes('Position',[0.16,0.6,0.3,0.3]); 173 | 174 | imagesc(tt,yaxis,20*log10(abs(spc_STFT)),clims) 175 | set(gca,'ydir','normal') 176 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 177 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_STFT_win',num2str(winlen)); 178 | saveas(gcf, fname); 179 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_STFT_win',num2str(winlen),'.pdf'); 180 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 181 | %% SST 182 | spc_SST = SST2(data_reshape,winlen); 183 | spc_SST=abs(spc_SST); 184 | spc_SST=flipud(spc_SST); 185 | tt=t; 186 | figure; 187 | clims = [max(max(20*log10((abs(spc_SST)))))-70,max(max(20*log10((abs(spc_SST)))))]; 188 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 189 | set(gca,'ydir','normal') 190 | ylim([ylim_low max(yaxis)]) 191 | xlabel({'Time / sec';'(e)'}) 192 | ylabel('Freq. 
/ Hz') 193 | title(strcat('SST, winlen=',num2str(winlen))) 194 | set(gca,'FontSize',fsz); 195 | set(get(gca,'XLabel'),'FontSize',fsz); 196 | set(get(gca,'YLabel'),'FontSize',fsz); 197 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 198 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 199 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 200 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 201 | axes('Position',[0.16,0.6,0.3,0.3]); 202 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 203 | set(gca,'ydir','normal') 204 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 205 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_SST_win',num2str(winlen)); 206 | saveas(gcf, fname); 207 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SST_win',num2str(winlen),'.pdf'); 208 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 209 | 210 | %% SET 211 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 212 | spc_SET=abs(spc_SET); 213 | spc_SET=flipud(spc_SET); 214 | tt=t; 215 | figure; 216 | clims = [max(max(20*log10((abs(spc_SET)))))-clims_lim,max(max(20*log10((abs(spc_SET)))))]; 217 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 218 | set(gca,'ydir','normal') 219 | ylim([ylim_low max(yaxis)]) 220 | title(strcat('SET, winlen=',num2str(winlen))) 221 | xlabel({'Time / sec';'(f)'}) 222 | ylabel('Freq. 
/ Hz') 223 | set(gca,'FontSize',fsz); 224 | set(get(gca,'XLabel'),'FontSize',fsz); 225 | set(get(gca,'YLabel'),'FontSize',fsz); 226 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 227 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 228 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 229 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 230 | axes('Position',[0.16,0.6,0.3,0.3]); 231 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 232 | set(gca,'ydir','normal') 233 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 234 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_SET_win',num2str(winlen)); 235 | saveas(gcf, fname); 236 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_SET_win',num2str(winlen),'.pdf'); 237 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 238 | 239 | %% MSST 240 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 241 | spc_MSST=spc_MSST1; 242 | spc_MSST=flipud(spc_MSST); 243 | tt=t; 244 | figure; 245 | clims = [max(max(20*log10((abs(spc_MSST)))))-clims_lim,max(max(20*log10((abs(spc_MSST)))))]; 246 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 247 | set(gca,'ydir','normal') 248 | ylim([ylim_low max(yaxis)]) 249 | title(strcat('MSST, winlen=',num2str(winlen))) 250 | xlabel({'Time / sec';'(g)'}) 251 | ylabel('Freq. 
/ Hz') 252 | set(gca,'FontSize',fsz); 253 | set(get(gca,'XLabel'),'FontSize',fsz); 254 | set(get(gca,'YLabel'),'FontSize',fsz); 255 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 256 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 257 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 258 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 259 | axes('Position',[0.16,0.6,0.3,0.3]); 260 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 261 | set(gca,'ydir','normal') 262 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 263 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_MSST_win',num2str(winlen)); 264 | saveas(gcf, fname); 265 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_MSST_win',num2str(winlen),'.pdf'); 266 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); 267 | 268 | 269 | 270 | %% ResFreq 271 | data_rsh=data_reshape.'; 272 | mv=max(abs(data_rsh)); 273 | noisedSig=data_rsh/mv; 274 | 275 | siglen=256; 276 | ret=[]; 277 | sp=1; 278 | ep=sp+siglen-1; 279 | overlay=30; 280 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 287 | flag=1; 288 | left_len=length(noisedSig)-sp+1; 289 | ret=[ret;noisedSig(end-siglen+1:end)]; 290 | end 291 | 292 | bz=size(ret,1); 293 | 294 | if ~exist('matlab_real2.h5','file')==0 295 | delete('matlab_real2.h5') 296 | end 297 | if ~exist('matlab_imag2.h5','file')==0 298 | delete('matlab_imag2.h5') 299 | end 300 | if ~exist('bz.h5','file')==0 301 | delete('bz.h5') 302 | end 303 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 304 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 305 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 306 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 307 | h5create('bz.h5','/bz',size(bz)); 308 | h5write('bz.h5','/bz',bz); 309 | 310 | 
net_flag=system('curl -s 127.0.0.1:5012/'); 311 | load data1_resfreq.mat 312 | ret=[]; 313 | 314 | if flag==1 315 | iter=bz-1; 316 | else 317 | iter=bz; 318 | end 319 | for i=1:iter 320 | if i==1 321 | tmp=squeeze(data1_resfreq(i,:,:)); 322 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 323 | else 324 | tmp=squeeze(data1_resfreq(i,:,:)); 325 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 326 | end 327 | end 328 | if flag==1 329 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 330 | tmp=tmp(end-left_len+1:end,:); 331 | ret=[ret;tmp(overlay/2+1:end,:)]; 332 | end 333 | figure; 334 | ret=(10.^(ret/20)-1)/10; 335 | 336 | clims = [max(max(20*log10((abs(ret)))))-clims_lim,max(max(20*log10((abs(ret)))))]; 337 | imagesc(t,yaxis,flipud(20*log10(fftshift(abs(ret.'),1))),clims) 338 | set(gca,'ydir','normal') 339 | ylim([ylim_low max(yaxis)]) 340 | xlabel({'Time / sec';'(i)'}) 341 | ylabel('Freq. / Hz') 342 | title('TFA-Net') 343 | set(gca,'FontSize',fsz); 344 | set(get(gca,'XLabel'),'FontSize',fsz); 345 | set(get(gca,'YLabel'),'FontSize',fsz); 346 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 347 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','red','Linewidth',3); 348 | line([x11,x2],[y22,y1],'color','r','Linewidth',2,'LineStyle','-.') 349 | line([x22,x2],[y22,y2],'color','r','Linewidth',2,'LineStyle','-.') 350 | axes('Position',[0.16,0.6,0.3,0.3]); 351 | imagesc(tt,yaxis,flipud(20*log10(fftshift(abs(ret.'),1))),clims); 352 | set(gca,'ydir','normal') 353 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 354 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp3_breheart_TFA-Net'); 355 | saveas(gcf, fname); 356 | % fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_heart_TFA-Net','.pdf'); 357 | % export_fig(gcf , '-eps' , '-r300' , '-painters' , fname); -------------------------------------------------------------------------------- /TFA-Net_inference/exp3_dianji.m: 
--------------------------------------------------------------------------------
%% exp3_dianji: TFA of micro-Doppler of reflectors (paper Section IV B (3)).
% Compares RS, GWarblet, SPWVD, STFT, SST, SET, MSST and TFA-Net on the same
% record. Every panel shows the full time-frequency map with two highlighted
% rectangles plus a magnified inset, and is saved with saveas under fig_dir.
% The TFA-Net panel needs the TFA_exe inference service listening on
% 127.0.0.1:5012.

clc
clear
close all

% Output location: every figure is saved as [fig_dir <name>].
fig_dir='F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\';

nfft=256;        % FFT length for STFT / GWarblet
fsz=16;          % font size for axes and axis labels
clims_lim=40;    % default dynamic range (dB) shown below each image peak

%% Data
load dianji_data.mat
% Every second sample of the first 8100. (:) forces a column vector, which the
% zero-padding before the STFT below requires.
data_reshape=dianji_data(1:2:8100);
data_reshape=data_reshape(:);
N=numel(data_reshape);
fs=4e6;
dt=1/fs;
t=(dt:dt:N*dt)*1000;   % time axis in ms
winlen=32;
corr_fs=1000;
ydelta=corr_fs/nfft;
% NOTE(review): the frequency axis is built from corr_fs, not fs; confirm this
% is the intended scaling for the 'Freq. / Hz' label.
yaxis=(0:ydelta:corr_fs-ydelta)-corr_fs/2;

%% Zoom layout shared by every panel
x11=0.28;  x22=0.43;    % magnified region, time (ms)
y11=-120;  y22=153;     % magnified region, frequency
x1=0.601;  x2=1.0125;   % second rectangle, time (ms); dashed connectors
y1=179;    y2=485;      % run from the magnified region to its left edge
ylim_low=-200;

cfg.fsz=fsz;
cfg.ylim=[ylim_low max(yaxis)];
cfg.zoom=[x11 x22 y11 y22];
cfg.frame=[x1 x2 y1 y2];
cfg.inset_pos=[0.6 0.6 0.3 0.3];   % inset axes, normalized figure units
cfg.fig_dir=fig_dir;

%% RS
spc_RS=abs(RS(data_reshape,winlen));
plot_tf_with_zoom(t,yaxis,20*log10(fftshift(spc_RS,1)),40, ...
    ['RS, winlen=' num2str(winlen)],'(d)',['exp3_dianji_RS_win' num2str(winlen)],cfg);

%% GWarblet
% One refined pass on the signal and one on its conjugate, stacked as
% [flipped conjugate-pass spectrum; signal-pass spectrum].
Spec=gwarblet_refined(data_reshape,fs,t,nfft,winlen);
Spec2=gwarblet_refined(conj(data_reshape),fs,t,nfft,winlen);
Spec=[flipud(Spec2);Spec];
plot_tf_with_zoom(t,yaxis,20*log10(abs(Spec)),clims_lim, ...
    'GWarblet','(c)','exp3_dianji_GWarblet',cfg);

%% WVD
Spec=fftshift(wvd(data_reshape,'smoothedPseudo'),1);
% The SPWVD map is drawn against yaxis/2; the y-limits still come from cfg.
plot_tf_with_zoom(t,yaxis/2,20*log10(abs(Spec)),clims_lim, ...
    'SPWVD','(b)','exp3_dianji_SPWVD',cfg);

%% CWT (disabled)
% Spec=cwt(data_reshape);
% Spec=[fftshift(Spec(:,:,2),1);fftshift(flipud(Spec(:,:,1)),1)];
% plot_tf_with_zoom(t,yaxis,20*log10(abs(Spec)),clims_lim,'CWT','(c)','exp3_dianji_CWT',cfg);

%% STFT
% Pad by winlen-1 samples in total so that, with a hop of 1
% (OverlapLength = winlen-1), there is one roughly centred frame per sample.
data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)];
spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft));
plot_tf_with_zoom(t,yaxis,20*log10(spc_STFT),clims_lim, ...
    ['STFT, winlen=' num2str(winlen)],'(a)',['exp3_dianji_STFT_win' num2str(winlen)],cfg);

%% SST
spc_SST=abs(SST2(data_reshape,winlen));
plot_tf_with_zoom(t,yaxis,20*log10(fftshift(spc_SST,1)),25, ...
    ['SST, winlen=' num2str(winlen)],'(e)',['exp3_dianji_SST_win' num2str(winlen)],cfg);

%% SET
[spc_SET,~]=SET_Y2(data_reshape,winlen);
spc_SET=abs(spc_SET);
plot_tf_with_zoom(t,yaxis,20*log10(fftshift(spc_SET,1)),15, ...
    ['SET, winlen=' num2str(winlen)],'(f)',['exp3_dianji_SET_win' num2str(winlen)],cfg);

%% MSST
[spc_MSST,~,~]=MSST_Y_new2(data_reshape,winlen,3);
plot_tf_with_zoom(t,yaxis,20*log10(fftshift(abs(spc_MSST),1)),20, ...
    ['MSST, winlen=' num2str(winlen)],'(g)',['exp3_dianji_MSST_win' num2str(winlen)],cfg);

%% TFA-Net (ResFreq)
% Split the signal into overlapping chunks of siglen samples, hand them to the
% inference service through HDF5 files, then stitch the returned maps.
siglen=256;
overlay=30;               % samples shared by consecutive chunks
hop=siglen-overlay;
% NOTE(review): the signal is sent without peak normalisation; confirm this is
% intended.
noisedSig=data_reshape.';
L=length(noisedSig);
if L<siglen
    error('exp3_dianji:shortSignal', ...
        'Signal has %d samples; TFA-Net chunking needs at least %d.',L,siglen);
end

% Full chunks start every hop samples. The stitching below keeps rows
% overlay/2+1 .. siglen-overlay/2 of each interior chunk, which tiles the
% signal without gaps or duplicates only for this hop.
chunks=[];
sp=1;
ep=sp+siglen-1;
while ep<=L
    chunks=[chunks;noisedSig(sp:ep)]; %#ok<AGROW>
    sp=sp+hop;
    ep=sp+siglen-1;
end
% If more than overlay/2 samples remain uncovered, append one last chunk
% aligned to the end of the signal; left_len is how many of its samples are new.
flag=0;
left_len=0;
if L-sp+1>overlay/2
    flag=1;
    left_len=L-sp+1;
    chunks=[chunks;noisedSig(end-siglen+1:end)];
end
bz=size(chunks,1);

% h5create refuses to create a dataset that already exists, so remove the
% exchange files left over from a previous run.
h5_files={'matlab_real2.h5','matlab_imag2.h5','bz.h5'};
for k=1:numel(h5_files)
    if exist(h5_files{k},'file')
        delete(h5_files{k});
    end
end
h5create('matlab_real2.h5','/matlab_real2',size(chunks));
h5write('matlab_real2.h5','/matlab_real2',real(chunks));
h5create('matlab_imag2.h5','/matlab_imag2',size(chunks));
h5write('matlab_imag2.h5','/matlab_imag2',imag(chunks));
h5create('bz.h5','/bz',size(bz));
h5write('bz.h5','/bz',bz);

% Trigger inference. Presumably the service writes data1_resfreq.mat before it
% replies (confirm in TFA_exe). A non-zero curl status means the request never
% completed, and loading the .mat file would show stale results.
net_flag=system('curl -s 127.0.0.1:5012/');
if net_flag~=0
    error('exp3_dianji:tfaNetUnavailable', ...
        'TFA-Net service at 127.0.0.1:5012 is unreachable (curl exit code %d). Start the TFA_exe inference service first.',net_flag);
end
S=load('data1_resfreq.mat','data1_resfreq');
data1_resfreq=S.data1_resfreq;

% Keep the non-overlapping centre of every chunk (the first chunk also keeps
% its leading edge) and, from the tail chunk, only the samples not yet covered.
iter=bz-flag;
ret=[];
for i=1:iter
    tmp=squeeze(data1_resfreq(i,:,:));
    if i==1
        ret=[ret;tmp(1:siglen-overlay/2,:)]; %#ok<AGROW>
    else
        ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; %#ok<AGROW>
    end
end
if flag==1
    tmp=squeeze(data1_resfreq(iter+1,:,:));
    tmp=tmp(end-left_len+1:end,:);
    ret=[ret;tmp(overlay/2+1:end,:)];
end
% NOTE(review): presumably maps the network's dB-like output back to linear
% magnitude; confirm against the training code.
ret=(10.^(ret/20)-1)/10;
plot_tf_with_zoom(t,yaxis,20*log10(fftshift(abs(ret.'),1)),clims_lim, ...
    'TFA-Net','(i)','exp3_dianji_TFA-Net',cfg);

%% Local functions

function Spec=gwarblet_refined(x,fs,t,nfft,winlen)
% Two-pass generalized warblet transform of x.
% The first pass (window winlen, no warping coefficients) finds the peak bin
% of every column. get_fscoeff turns that ridge into coefficients (presumably
% a Fourier-series fit; confirm), and a second pass with window 4*winlen uses
% them.
[Spec,f]=GWarblet_complex(x,fs,0,1,nfft,winlen);
[~,l]=max(Spec,[],1);
[~,a_n,b_n,fm]=get_fscoeff(f(l),length(t),t,fs);
[Spec,~]=GWarblet_complex(x,fs,[-a_n;b_n],fm(2:end),nfft,winlen*4);
end

function plot_tf_with_zoom(x,y,img_db,db_range,ttl,panel,name,cfg)
% Draw one time-frequency map with the shared highlight rectangles and a
% magnified inset of cfg.zoom, then save the figure as [cfg.fig_dir name].
%   x, y     : axis vectors passed to imagesc (time in ms, frequency)
%   img_db   : image in dB, already in display orientation
%   db_range : dynamic range (dB) shown below the image maximum
%   ttl      : axes title
%   panel    : sub-figure tag shown under the time label, e.g. '(a)'
%   name     : output file name appended to cfg.fig_dir
%   cfg      : shared layout (fsz, ylim, zoom, frame, inset_pos, fig_dir)
peak=max(img_db(:));
clims=[peak-db_range,peak];

figure;
imagesc(x,y,img_db,clims);
set(gca,'ydir','normal')
ylim(cfg.ylim)
title(ttl)
xlabel({'Time / ms';panel})
ylabel('Freq. / Hz')
set(gca,'FontSize',cfg.fsz);
set(get(gca,'XLabel'),'FontSize',cfg.fsz);
set(get(gca,'YLabel'),'FontSize',cfg.fsz);

z=cfg.zoom;    % [xmin xmax ymin ymax] of the magnified region
b=cfg.frame;   % [xmin xmax ymin ymax] of the second rectangle
rectangle('Position',[z(1) z(3) z(2)-z(1) z(4)-z(3)],'EdgeColor','red','Linewidth',3);
rectangle('Position',[b(1) b(3) b(2)-b(1) b(4)-b(3)],'EdgeColor','red','Linewidth',3);
line([z(1),b(1)],[z(4),b(4)],'color','r','Linewidth',2,'LineStyle','-.')
line([z(2),b(1)],[z(4),b(3)],'color','r','Linewidth',2,'LineStyle','-.')

axes('Position',cfg.inset_pos);
imagesc(x,y,img_db,clims);
set(gca,'ydir','normal')
set(gca,'xlim',z(1:2),'ylim',z(3:4));
saveas(gcf,[cfg.fig_dir name]);
end
--------------------------------------------------------------------------------
/TFA-Net_inference/exp3_voice.m:
-------------------------------------------------------------------------------- 1 | 2 | clc 3 | clear 4 | close all 5 | nfft=256; 6 | fsz=16; 7 | clims_lim=30; 8 | 9 | 10 | [Y, FS]=audioread('ringtoneFrasers.mp3'); 11 | TS=1/FS; 12 | N=length(Y(:,1)); 13 | data_reshape=Y(floor(N*1.4/10):1:floor(2.4*N/6),1); 14 | 15 | data_reshape=hilbert(data_reshape); 16 | M=length(data_reshape); 17 | t=0:TS:M*TS-TS; 18 | winlen=64; 19 | ydelta=FS/nfft; 20 | yaxis=((0:ydelta:FS-ydelta)-FS/2)/1000; 21 | x11=0.2; x22=0.55; 22 | y11=5; y22=19; 23 | 24 | x1=0.7;x2=1; 25 | y1=5;y2=15; 26 | ylim_low=0; 27 | 28 | 29 | 30 | %% STFT 31 | data_reshape1=[zeros(winlen/2,1);data_reshape;zeros(winlen/2-1,1)]; 32 | spc_STFT=abs(stft(data_reshape1,'Window',hamming(winlen).','OverlapLength',winlen-1,'FFTLength',nfft)); 33 | tt=t; 34 | 35 | h=figure(); 36 | set(h,'position',[100 100 1200 400]); 37 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 38 | axes(ha(1)); 39 | clims = [max(max(20*log10((abs(spc_STFT)))))-clims_lim,max(max(20*log10((abs(spc_STFT)))))]; 40 | imagesc(tt,yaxis,20*log10((abs(spc_STFT))),clims); 41 | set(gca,'ydir','normal') 42 | ylim([ylim_low max(yaxis)]) 43 | title(strcat('STFT, winlen=',num2str(winlen))) 44 | xlabel({'Time / sec';'(a)'}) 45 | ylabel('Freq. / kHz') 46 | set(gca,'FontSize',fsz); 47 | set(get(gca,'XLabel'),'FontSize',fsz); 48 | set(get(gca,'YLabel'),'FontSize',fsz); 49 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 50 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 51 | 52 | axes(ha(2)) 53 | imagesc(tt,yaxis,20*log10(abs(spc_STFT)),clims) 54 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 55 | set(gca,'ydir','normal') 56 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 57 | title(strcat('local STFT, winlen=',num2str(winlen))) 58 | xlabel({'Time / sec';'(b)'}) 59 | ylabel('Freq. 
/ kHz') 60 | set(gca,'FontSize',fsz); 61 | set(get(gca,'XLabel'),'FontSize',fsz); 62 | set(get(gca,'YLabel'),'FontSize',fsz); 63 | 64 | axes(ha(3)) 65 | imagesc(tt,yaxis,20*log10(abs(spc_STFT)),clims) 66 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 67 | set(gca,'ydir','normal') 68 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 69 | title(strcat('local STFT, winlen=',num2str(winlen))) 70 | xlabel({'Time / sec';'(c)'}) 71 | ylabel('Freq. / kHz') 72 | set(gca,'FontSize',fsz); 73 | set(get(gca,'XLabel'),'FontSize',fsz); 74 | set(get(gca,'YLabel'),'FontSize',fsz); 75 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_STFT_win',num2str(winlen)); 76 | saveas(gcf, fname); 77 | 78 | 79 | %% GWarblet 80 | [Spec,f] = GWarblet_complex(data_reshape,FS,0,1,nfft,winlen); 81 | [v l] = max(Spec,[],1); 82 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 83 | WinLen =winlen*4; 84 | [Spec,f] = GWarblet_complex(data_reshape,FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 85 | 86 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,0,1,nfft,winlen); 87 | [v l] = max(Spec2,[],1); 88 | [IF, a_n,b_n,fm] = get_fscoeff(f(l),length(t),t,FS); 89 | WinLen =winlen*4; 90 | [Spec2,f] = GWarblet_complex(conj(data_reshape),FS,[-a_n;b_n],fm(2:end),nfft,WinLen); 91 | Spec=[flipud(Spec2);Spec]; 92 | % save gwarblet.mat Spec 93 | % load gwarblet.mat 94 | h=figure(); 95 | set(h,'position',[100 100 1200 400]); 96 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 97 | axes(ha(1)) 98 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 99 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 100 | set(gca,'ydir','normal') 101 | ylim([ylim_low max(yaxis)]) 102 | title(strcat('GWarblet')) 103 | xlabel({'Time / sec';'(g)'}) 104 | ylabel('Freq. 
/ kHz') 105 | set(gca,'FontSize',fsz); 106 | set(get(gca,'XLabel'),'FontSize',fsz); 107 | set(get(gca,'YLabel'),'FontSize',fsz); 108 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 109 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 110 | 111 | axes(ha(2)) 112 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 113 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 114 | set(gca,'ydir','normal') 115 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 116 | title(strcat('local GWarblet')) 117 | xlabel({'Time / sec';'(h)'}) 118 | ylabel('Freq. / kHz') 119 | set(gca,'FontSize',fsz); 120 | set(get(gca,'XLabel'),'FontSize',fsz); 121 | set(get(gca,'YLabel'),'FontSize',fsz); 122 | 123 | axes(ha(3)) 124 | imagesc(t,yaxis,20*log10((abs(Spec))),clims); 125 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 126 | set(gca,'ydir','normal') 127 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 128 | title(strcat('local GWarblet')) 129 | xlabel({'Time / sec';'(i)'}) 130 | ylabel('Freq. / kHz') 131 | set(gca,'FontSize',fsz); 132 | set(get(gca,'XLabel'),'FontSize',fsz); 133 | set(get(gca,'YLabel'),'FontSize',fsz); 134 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_GWarblet_win',num2str(winlen)); 135 | saveas(gcf, fname); 136 | 137 | %% WVD 138 | yaxis1=((0:ydelta:FS/2-ydelta))/1000; 139 | Spec=wvd(data_reshape,'smoothedPseudo',hamming(winlen-1),'NumFrequencyPoints',nfft); 140 | 141 | h=figure(); 142 | set(h,'position',[100 100 1200 400]); 143 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 144 | axes(ha(1)) 145 | clims = [max(max(20*log10((abs(Spec)))))-clims_lim,max(max(20*log10((abs(Spec)))))]; 146 | imagesc(t,yaxis1,20*log10((abs(Spec))),clims); 147 | set(gca,'ydir','normal') 148 | ylim([ylim_low max(yaxis1)]) 149 | title(strcat('SPWVD')) 150 | xlabel({'Time / sec';'(d)'}) 151 | ylabel('Freq. 
/ kHz') 152 | set(gca,'FontSize',fsz); 153 | set(get(gca,'XLabel'),'FontSize',fsz); 154 | set(get(gca,'YLabel'),'FontSize',fsz); 155 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 156 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 157 | 158 | axes(ha(2)) 159 | imagesc(t,yaxis1,20*log10((abs(Spec))),clims); 160 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 161 | set(gca,'ydir','normal') 162 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 163 | title(strcat('local SPWVD')) 164 | xlabel({'Time / sec';'(e)'}) 165 | ylabel('Freq. / kHz') 166 | set(gca,'FontSize',fsz); 167 | set(get(gca,'XLabel'),'FontSize',fsz); 168 | set(get(gca,'YLabel'),'FontSize',fsz); 169 | 170 | axes(ha(3)) 171 | imagesc(t,yaxis1,20*log10((abs(Spec))),clims); 172 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 173 | set(gca,'ydir','normal') 174 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 175 | title(strcat('local SPWVD')) 176 | xlabel({'Time / sec';'(f)'}) 177 | ylabel('Freq. / kHz') 178 | set(gca,'FontSize',fsz); 179 | set(get(gca,'XLabel'),'FontSize',fsz); 180 | set(get(gca,'YLabel'),'FontSize',fsz); 181 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_PSWVD_win',num2str(winlen)); 182 | saveas(gcf, fname); 183 | 184 | %% SST 185 | spc_SST = SST2(data_reshape,winlen); 186 | spc_SST=abs(spc_SST); 187 | 188 | tt=t; 189 | h=figure(); 190 | set(h,'position',[100 100 1200 400]); 191 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 192 | axes(ha(1)); 193 | clims = [max(max(20*log10((abs(spc_SST)))))-clims_lim,max(max(20*log10((abs(spc_SST)))))]; 194 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 195 | set(gca,'ydir','normal') 196 | ylim([ylim_low max(yaxis)]) 197 | xlabel({'Time / sec';'(j)'}) 198 | ylabel('Freq. 
/ kHz') 199 | title(strcat('SST, winlen=',num2str(winlen))) 200 | set(gca,'FontSize',fsz); 201 | set(get(gca,'XLabel'),'FontSize',fsz); 202 | set(get(gca,'YLabel'),'FontSize',fsz); 203 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 204 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 205 | 206 | axes(ha(2)); 207 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 208 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 209 | set(gca,'ydir','normal') 210 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 211 | title(strcat('local SST, winlen=',num2str(winlen))) 212 | xlabel({'Time / sec';'(k)'}) 213 | ylabel('Freq. / kHz') 214 | set(gca,'FontSize',fsz); 215 | set(get(gca,'XLabel'),'FontSize',fsz); 216 | set(get(gca,'YLabel'),'FontSize',fsz); 217 | 218 | axes(ha(3)) 219 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SST),1)),clims); 220 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 221 | set(gca,'ydir','normal') 222 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 223 | title(strcat('local SST, winlen=',num2str(winlen))) 224 | xlabel({'Time / sec';'(m)'}) 225 | ylabel('Freq. 
/ kHz') 226 | set(gca,'FontSize',fsz); 227 | set(get(gca,'XLabel'),'FontSize',fsz); 228 | set(get(gca,'YLabel'),'FontSize',fsz); 229 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_SST_win',num2str(winlen)); 230 | saveas(gcf, fname); 231 | 232 | %% SET 233 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 234 | spc_SET=abs(spc_SET); 235 | tt=t; 236 | h=figure(); 237 | set(h,'position',[100 100 1200 400]); 238 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 239 | axes(ha(1)); 240 | clims = [max(max(20*log10((abs(spc_SET)))))-clims_lim,max(max(20*log10((abs(spc_SET)))))]; 241 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 242 | set(gca,'ydir','normal') 243 | ylim([ylim_low max(yaxis)]) 244 | title(strcat('SET, winlen=',num2str(winlen))) 245 | xlabel({'Time / sec';'(n)'}) 246 | ylabel('Freq. / kHz') 247 | set(gca,'FontSize',fsz); 248 | set(get(gca,'XLabel'),'FontSize',fsz); 249 | set(get(gca,'YLabel'),'FontSize',fsz); 250 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 251 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 252 | 253 | axes(ha(2)) 254 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 255 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 256 | set(gca,'ydir','normal') 257 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 258 | title(strcat('local SET, winlen=',num2str(winlen))) 259 | xlabel({'Time / sec';'(o)'}) 260 | ylabel('Freq. 
/ kHz') 261 | set(gca,'FontSize',fsz); 262 | set(get(gca,'XLabel'),'FontSize',fsz); 263 | set(get(gca,'YLabel'),'FontSize',fsz); 264 | 265 | axes(ha(3)) 266 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_SET),1)),clims); 267 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 268 | set(gca,'ydir','normal') 269 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 270 | title(strcat('local SET, winlen=',num2str(winlen))) 271 | xlabel({'Time / sec';'(p)'}) 272 | ylabel('Freq. / kHz') 273 | set(gca,'FontSize',fsz); 274 | set(get(gca,'XLabel'),'FontSize',fsz); 275 | set(get(gca,'YLabel'),'FontSize',fsz); 276 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_SET_win',num2str(winlen)); 277 | saveas(gcf, fname); 278 | 279 | %% MSST 280 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 281 | spc_MSST=spc_MSST1; 282 | tt=t; 283 | h=figure(); 284 | set(h,'position',[100 100 1200 400]); 285 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 286 | axes(ha(1)); 287 | clims = [max(max(20*log10((abs(spc_MSST)))))-clims_lim,max(max(20*log10((abs(spc_MSST)))))]; 288 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 289 | set(gca,'ydir','normal') 290 | ylim([ylim_low max(yaxis)]) 291 | title(strcat('MSST, winlen=',num2str(winlen))) 292 | xlabel({'Time / sec';'(q)'}) 293 | ylabel('Freq. 
/ kHz') 294 | set(gca,'FontSize',fsz); 295 | set(get(gca,'XLabel'),'FontSize',fsz); 296 | set(get(gca,'YLabel'),'FontSize',fsz); 297 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 298 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 299 | 300 | axes(ha(2)) 301 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 302 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 303 | set(gca,'ydir','normal') 304 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 305 | title(strcat('local MSST, winlen=',num2str(winlen))) 306 | xlabel({'Time / sec';'(r)'}) 307 | ylabel('Freq. / kHz') 308 | set(gca,'FontSize',fsz); 309 | set(get(gca,'XLabel'),'FontSize',fsz); 310 | set(get(gca,'YLabel'),'FontSize',fsz); 311 | 312 | axes(ha(3)) 313 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_MSST),1)),clims); 314 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 315 | set(gca,'ydir','normal') 316 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 317 | title(strcat('local MSST, winlen=',num2str(winlen))) 318 | xlabel({'Time / sec';'(s)'}) 319 | ylabel('Freq. 
/ kHz') 320 | set(gca,'FontSize',fsz); 321 | set(get(gca,'XLabel'),'FontSize',fsz); 322 | set(get(gca,'YLabel'),'FontSize',fsz); 323 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_MSST_win',num2str(winlen)); 324 | saveas(gcf, fname); 325 | 326 | %% RS 327 | data_rsh=data_reshape.'; 328 | noisedSig=data_rsh; 329 | siglen=256; 330 | ret=[]; 331 | sp=1; 332 | ep=sp+siglen-1; 333 | overlay=30; 334 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 341 | flag=1; 342 | left_len=length(noisedSig)-sp+1; 343 | ret=[ret;noisedSig(end-siglen+1:end)]; 344 | end 345 | 346 | bz=size(ret,1); 347 | spc_RS=[]; 348 | for i=1:bz 349 | spc_RS(i,:,:) = RS(ret(i,:).',winlen); 350 | end 351 | 352 | if flag==1 353 | iter=bz-1; 354 | else 355 | iter=bz; 356 | end 357 | ret=[]; 358 | for i=1:iter 359 | tmp=squeeze(spc_RS(i,:,:)).'; 360 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 361 | end 362 | if flag==1 363 | tmp=squeeze(spc_RS(iter+1,:,:)).'; 364 | tmp=tmp(end-left_len+1:end,:); 365 | ret=[ret;tmp(overlay/2+1:end,:)]; 366 | end 367 | 368 | spc_RS=ret.'; 369 | spc_RS=abs(spc_RS); 370 | 371 | tt=t; 372 | h=figure(); 373 | set(h,'position',[100 100 1200 400]); 374 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 375 | axes(ha(1)) 376 | clims = [max(max(20*log10((abs(spc_RS)))))-clims_lim,max(max(20*log10((abs(spc_RS)))))]; 377 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 378 | set(gca,'ydir','normal') 379 | ylim([ylim_low max(yaxis)]) 380 | title(strcat('RS, winlen=',num2str(winlen))) 381 | xlabel({'Time / sec';'(t)'}) 382 | ylabel('Freq. 
/ kHz') 383 | set(gca,'FontSize',fsz); 384 | set(get(gca,'XLabel'),'FontSize',fsz); 385 | set(get(gca,'YLabel'),'FontSize',fsz); 386 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 387 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 388 | 389 | axes(ha(2)) 390 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 391 | set(gca,'ydir','normal') 392 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 393 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 394 | title(strcat('RS, winlen=',num2str(winlen))) 395 | xlabel({'Time / sec';'(u)'}) 396 | ylabel('Freq. / kHz') 397 | set(gca,'FontSize',fsz); 398 | set(get(gca,'XLabel'),'FontSize',fsz); 399 | set(get(gca,'YLabel'),'FontSize',fsz); 400 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_RS_win',num2str(winlen)); 401 | saveas(gcf, fname); 402 | 403 | axes(ha(3)) 404 | imagesc(tt,yaxis,20*log10(fftshift(abs(spc_RS),1)),clims); 405 | set(gca,'ydir','normal') 406 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 407 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 408 | title(strcat('RS, winlen=',num2str(winlen))) 409 | xlabel({'Time / sec';'(v)'}) 410 | ylabel('Freq. 
/ kHz') 411 | set(gca,'FontSize',fsz); 412 | set(get(gca,'XLabel'),'FontSize',fsz); 413 | set(get(gca,'YLabel'),'FontSize',fsz); 414 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_RS_win',num2str(winlen)); 415 | saveas(gcf, fname); 416 | 417 | %% ResFreq 418 | data_rsh=data_reshape.'; 419 | mv=max(abs(data_rsh)); 420 | noisedSig=data_rsh/mv; 421 | 422 | siglen=256; 423 | ret=[]; 424 | sp=1; 425 | ep=sp+siglen-1; 426 | overlay=30; 427 | while eplength(noisedSig) && length(noisedSig)-sp+1>overlay/2 434 | flag=1; 435 | left_len=length(noisedSig)-sp+1; 436 | ret=[ret;noisedSig(end-siglen+1:end)]; 437 | end 438 | 439 | bz=size(ret,1); 440 | 441 | if ~exist('matlab_real2.h5','file')==0 442 | delete('matlab_real2.h5') 443 | end 444 | if ~exist('matlab_imag2.h5','file')==0 445 | delete('matlab_imag2.h5') 446 | end 447 | if ~exist('bz.h5','file')==0 448 | delete('bz.h5') 449 | end 450 | h5create('matlab_real2.h5','/matlab_real2',size(ret)); 451 | h5write('matlab_real2.h5','/matlab_real2',real(ret)); 452 | h5create('matlab_imag2.h5','/matlab_imag2',size(ret)); 453 | h5write('matlab_imag2.h5','/matlab_imag2',imag(ret)); 454 | h5create('bz.h5','/bz',size(bz)); 455 | h5write('bz.h5','/bz',bz); 456 | 457 | net_flag=system('curl -s 127.0.0.1:5012/'); 458 | load data1_resfreq.mat 459 | ret=[]; 460 | 461 | if flag==1 462 | iter=bz-1; 463 | else 464 | iter=bz; 465 | end 466 | for i=1:iter 467 | if i==1 468 | tmp=squeeze(data1_resfreq(i,:,:)); 469 | ret=[ret;tmp(1:siglen-overlay/2,:)]; 470 | else 471 | tmp=squeeze(data1_resfreq(i,:,:)); 472 | ret=[ret;tmp(overlay/2+1:siglen-overlay/2,:)]; 473 | end 474 | end 475 | if flag==1 476 | tmp=squeeze(data1_resfreq(iter+1,:,:)); 477 | tmp=tmp(end-left_len+1:end,:); 478 | ret=[ret;tmp(overlay/2+1:end,:)]; 479 | end 480 | h=figure(); 481 | set(h,'position',[100 100 1200 400]); 482 | ha=tight_subplot(1,3,[0.08 0.06],[.25 .08],[.06 .02]); 483 | axes(ha(1)); 484 | ret=(10.^(ret/20)-1)/10; 485 | clims 
= [max(max(20*log10((abs(ret)))))-clims_lim,max(max(20*log10((abs(ret)))))]; 486 | imagesc(t,yaxis,20*log10(fftshift(abs(ret.'),1)),clims) 487 | set(gca,'ydir','normal') 488 | ylim([ylim_low max(yaxis)]) 489 | xlabel({'Time / sec';'(w)'}) 490 | ylabel('Freq. / kHz') 491 | title('TFA-Net') 492 | set(gca,'FontSize',fsz); 493 | set(get(gca,'XLabel'),'FontSize',fsz); 494 | set(get(gca,'YLabel'),'FontSize',fsz); 495 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 496 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',3); 497 | 498 | axes(ha(2)) 499 | imagesc(t,yaxis,20*log10(fftshift(abs(ret.'),1)),clims) 500 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 501 | set(gca,'ydir','normal') 502 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 503 | title(strcat('local TFA-Net')) 504 | xlabel({'Time / sec';'(x)'}) 505 | ylabel('Freq. / kHz') 506 | set(gca,'FontSize',fsz); 507 | set(get(gca,'XLabel'),'FontSize',fsz); 508 | set(get(gca,'YLabel'),'FontSize',fsz); 509 | 510 | axes(ha(3)) 511 | imagesc(t,yaxis,20*log10(fftshift(abs(ret.'),1)),clims) 512 | rectangle('Position',[x1 y1 x2-x1 y2-y1],'EdgeColor','#D2691E','Linewidth',7); 513 | set(gca,'ydir','normal') 514 | set(gca,'xlim',[x1 x2],'ylim',[y1 y2]); 515 | title(strcat('local TFA-Net')) 516 | xlabel({'Time / sec';'(y)'}) 517 | ylabel('Freq. 
/ kHz') 518 | set(gca,'FontSize',fsz); 519 | set(get(gca,'XLabel'),'FontSize',fsz); 520 | set(get(gca,'YLabel'),'FontSize',fsz); 521 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','exp2_voice_TFA'); 522 | saveas(gcf, fname); 523 | 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | -------------------------------------------------------------------------------- /TFA-Net_inference/get_fscoeff.m: -------------------------------------------------------------------------------- 1 | function [f, a_n,b_n,Freq] = get_fscoeff(y,nfft,t,SampFreq) 2 | % To obtain parameters for GWarblet 3 | % y : the extracted IF 4 | % nfft: the number of fft 5 | 6 | % f : reconstructed fourier series 7 | % a_n : coeff of sin; 8 | % b_n : coeff of cos; 9 | yfft = fft(y,nfft)/nfft; 10 | Nq = ceil(nfft/2); 11 | Freq = SampFreq/nfft*[0:Nq-1]; 12 | 13 | % Build the fourier series 14 | f = zeros(size(y)); 15 | for j = 2:round(Nq/3) 16 | a_n(j-1) = -2*imag(yfft(j)); 17 | b_n(j-1) = 2*real(yfft(j)); 18 | f = f + b_n(j-1)*cos(2*pi*t*Freq(j)) + a_n(j-1)*sin(2*pi*t*Freq(j)); 19 | end 20 | f = abs(yfft(1))+f; -------------------------------------------------------------------------------- /TFA-Net_inference/ref_methods.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all 4 | N=256; 5 | nfft=256; 6 | 7 | % h=figure(); 8 | % set(h,'position',[100 100 1200 900]); 9 | % ha=tight_subplot(4,2,[0.09 0.05],[.1 .05],[.05 .02]); 10 | fsz=23; 11 | %% SIMU1 12 | fs=100; 13 | ts=1/fs; 14 | t = 0 : ts : 10-ts; 15 | Sig1 = exp(1i*2*pi*(8* t + 6 *sin(t) )); % get the A(t) 16 | Sig2 = exp(1i*2*pi*(10 * t + 6 *sin(1.5*t) )); % get the A(t) 17 | Sig=Sig1+Sig2; 18 | data_reshape=Sig.'; 19 | t1=t; 20 | 21 | ydelta=fs/nfft; 22 | yaxis=(0:ydelta:fs-ydelta)-fs/2; 23 | 24 | if1=(8 + 6*1 *cos(1*t)) ; 25 | if2=(10 + 6*1.5 *cos(1.5*t)) ; 26 | 27 | h=figure(); 28 | set(h,'position',[100 100 1200 400]); 
29 | ha=tight_subplot(1,2,[0.08 0.07],[.31 .1],[.07 .02]); 30 | axes(ha(1)) 31 | plot(t,if1) 32 | hold on 33 | plot(t,if2) 34 | 35 | set(gca,'ydir','reverse') 36 | ylim([min(yaxis) max(yaxis)]) 37 | xlabel({'Time / sec';'(a)'}) 38 | ylabel('Freq. / Hz') 39 | title('Ground Truth') 40 | set(gca,'FontSize',fsz); 41 | set(get(gca,'XLabel'),'FontSize',fsz); 42 | set(get(gca,'YLabel'),'FontSize',fsz); 43 | winlen=64; 44 | x11=4.4; x22=8.5; 45 | y11=6; y22=15; 46 | 47 | xx11=0.5; xx22=3; 48 | yy11=0.5; yy22=17; 49 | ylow=0; 50 | yhigh=20; 51 | ylim([ylow,yhigh]) 52 | 53 | %% STFT 54 | window_v=ones(1,64); 55 | spc_STFT=abs(spectrogram(data_reshape,winlen,winlen-1,nfft)); 56 | siglen2=size(spc_STFT,2); 57 | siglen=size(data_reshape,1); 58 | start_pos=fix((siglen-siglen2)/2); 59 | if start_pos==0 60 | start_pos=1; 61 | end 62 | tt=t(start_pos:start_pos+siglen2-1); 63 | 64 | axes(ha(2)) 65 | imagesc(tt,yaxis,(fftshift(abs(spc_STFT),1))); 66 | title(strcat('STFT, winlen=',num2str(winlen))) 67 | xlabel({'Time / sec';'(b)'}) 68 | ylabel('Freq. / Hz') 69 | set(gca,'FontSize',fsz); 70 | set(get(gca,'XLabel'),'FontSize',fsz); 71 | set(get(gca,'YLabel'),'FontSize',fsz); 72 | ylim([ylow,yhigh]) 73 | 74 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','ref_STFT_win',num2str(winlen)); 75 | saveas(gcf, fname); 76 | 77 | %% GT 78 | xx=(fftshift(abs(spc_STFT),1)); 79 | x1=(xx(:,371)); 80 | x2=(xx(:,211)); 81 | h=figure(); 82 | set(h,'position',[100 100 1200 400]); 83 | ha=tight_subplot(1,2,[0.08 0.07],[.31 .1],[.06 .02]); 84 | axes(ha(1)) 85 | axes(ha(1)) 86 | plot(yaxis,x2,'k','Linewidth',2); 87 | hold on 88 | plot([if1(241),if1(241)],[0,55],'r','Linewidth',2) 89 | plot([if2(241),if2(241)],[0,55],'r','Linewidth',2) 90 | xlabel({'Freq. 
/ Hz';'(a)'}) 91 | ylabel('Amp.') 92 | set(gca,'FontSize',fsz); 93 | set(get(gca,'XLabel'),'FontSize',fsz); 94 | set(get(gca,'YLabel'),'FontSize',fsz); 95 | xlim([-5,15]) 96 | ylim([ylow,55]) 97 | grid on 98 | axes(ha(2)) 99 | plot(yaxis,x1,'k','Linewidth',2); 100 | hold on 101 | plot([if1(401),if1(401)],[0,50],'r','Linewidth',2) 102 | plot([if2(401),if2(401)],[0,50],'r','Linewidth',2) 103 | xlabel({'Freq. / Hz';'(b)'}) 104 | ylabel('Amp.') 105 | set(gca,'FontSize',fsz); 106 | set(get(gca,'XLabel'),'FontSize',fsz); 107 | set(get(gca,'YLabel'),'FontSize',fsz); 108 | xlim([0,25]) 109 | ylim([ylow,50]) 110 | grid on 111 | 112 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','ref_GT_win',num2str(winlen)); 113 | saveas(gcf, fname); 114 | %% SST 115 | spc_SST = SST2(data_reshape,winlen); 116 | spc_SST=abs(spc_SST); 117 | siglen2=size(spc_SST,2); 118 | start_pos=fix((siglen-siglen2)/2); 119 | if start_pos==0 120 | start_pos=1; 121 | end 122 | tt=t(start_pos:start_pos+siglen2-1); 123 | 124 | h=figure(); 125 | set(h,'position',[100 100 1600 400]); 126 | ha=tight_subplot(1,3,[0.05 0.05],[.31 .1],[.05 .02]); 127 | axes(ha(1)) 128 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 129 | xlabel({'Time / sec';'(c)'}) 130 | ylabel('Freq. / Hz') 131 | title(strcat('SST, winlen=',num2str(winlen))) 132 | set(gca,'FontSize',fsz); 133 | set(get(gca,'XLabel'),'FontSize',fsz); 134 | set(get(gca,'YLabel'),'FontSize',fsz); 135 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 136 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',4); 137 | ylim([ylow,yhigh]) 138 | 139 | axes(ha(3)) 140 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 141 | xlabel({'Time / sec';'(e)'}) 142 | ylabel('Freq. 
/ Hz') 143 | title(strcat('local SST, winlen=',num2str(winlen))) 144 | set(gca,'FontSize',fsz); 145 | set(get(gca,'XLabel'),'FontSize',fsz); 146 | set(get(gca,'YLabel'),'FontSize',fsz); 147 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 148 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',7); 149 | 150 | axes(ha(2)) 151 | imagesc(tt,yaxis,(fftshift(abs(spc_SST),1))); 152 | xlabel({'Time / sec';'(d)'}) 153 | ylabel('Freq. / Hz') 154 | title(strcat('local SST, winlen=',num2str(winlen))) 155 | set(gca,'FontSize',fsz); 156 | set(get(gca,'XLabel'),'FontSize',fsz); 157 | set(get(gca,'YLabel'),'FontSize',fsz); 158 | set(gca,'xlim',[xx11 xx22],'ylim',[yy11 yy22]); 159 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',6); 160 | 161 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','ref_SST_win',num2str(winlen)); 162 | saveas(gcf, fname); 163 | %% SET 164 | [spc_SET,tfr] = SET_Y2(data_reshape,winlen); 165 | spc_SET=abs(spc_SET); 166 | siglen2=size(spc_SET,2); 167 | start_pos=fix((siglen-siglen2)/2); 168 | if start_pos==0 169 | start_pos=1; 170 | end 171 | tt=t(start_pos:start_pos+siglen2-1); 172 | h=figure(); 173 | set(h,'position',[100 100 1600 400]); 174 | ha=tight_subplot(1,3,[0.05 0.05],[.31 .1],[.05 .02]); 175 | axes(ha(1)) 176 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 177 | title(strcat('SET, winlen=',num2str(winlen))) 178 | xlabel({'Time / sec';'(f)'}) 179 | ylabel('Freq. 
/ Hz') 180 | set(gca,'FontSize',fsz); 181 | set(get(gca,'XLabel'),'FontSize',fsz); 182 | set(get(gca,'YLabel'),'FontSize',fsz); 183 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 184 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',4); 185 | ylim([ylow,yhigh]) 186 | 187 | axes(ha(3)) 188 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 189 | title(strcat('local SET, winlen=',num2str(winlen))) 190 | xlabel({'Time / sec';'(h)'}) 191 | ylabel('Freq. / Hz') 192 | set(gca,'FontSize',fsz); 193 | set(get(gca,'XLabel'),'FontSize',fsz); 194 | set(get(gca,'YLabel'),'FontSize',fsz); 195 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 196 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',6); 197 | 198 | axes(ha(2)) 199 | imagesc(tt,yaxis,(fftshift(abs(spc_SET),1))); 200 | title(strcat('local SET, winlen=',num2str(winlen))) 201 | xlabel({'Time / sec';'(g)'}) 202 | ylabel('Freq. / Hz') 203 | set(gca,'FontSize',fsz); 204 | set(get(gca,'XLabel'),'FontSize',fsz); 205 | set(get(gca,'YLabel'),'FontSize',fsz); 206 | set(gca,'xlim',[xx11 xx22],'ylim',[yy11 yy22]); 207 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',7); 208 | 209 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','ref_SET_win',num2str(winlen)); 210 | saveas(gcf, fname); 211 | %% MSST 212 | [spc_MSST1,tfr,omega2] = MSST_Y_new2(data_reshape,winlen,3); 213 | % spc_MSST1=abs(spc_MSST1); 214 | % [spc_MSST2,tfr,omega2] = MSST(conj(data_reshape),winlen,3); 215 | % spc_MSST2=abs(spc_MSST2); 216 | spc_MSST=spc_MSST1; 217 | siglen2=size(spc_MSST,2); 218 | start_pos=fix((siglen-siglen2)/2); 219 | if start_pos==0 220 | start_pos=1; 221 | end 222 | tt=t(start_pos:start_pos+siglen2-1); 223 | 224 | h=figure(); 225 | set(h,'position',[100 100 1600 400]); 226 | ha=tight_subplot(1,3,[0.05 0.05],[.31 .1],[.05 .02]); 227 | axes(ha(1)) 228 | 
imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 229 | title(strcat('MSST, winlen=',num2str(winlen))) 230 | xlabel({'Time / sec';'(i)'}) 231 | ylabel('Freq. / Hz') 232 | set(gca,'FontSize',fsz); 233 | set(get(gca,'XLabel'),'FontSize',fsz); 234 | set(get(gca,'YLabel'),'FontSize',fsz); 235 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',3); 236 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',4); 237 | ylim([ylow,yhigh]) 238 | 239 | axes(ha(3)) 240 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 241 | title(strcat('local MSST, winlen=',num2str(winlen))) 242 | xlabel({'Time / sec';'(k)'}) 243 | ylabel('Freq. / Hz') 244 | set(gca,'FontSize',fsz); 245 | set(get(gca,'XLabel'),'FontSize',fsz); 246 | set(get(gca,'YLabel'),'FontSize',fsz); 247 | set(gca,'xlim',[x11 x22],'ylim',[y11 y22]); 248 | rectangle('Position',[x11 y11 x22-x11 y22-y11],'EdgeColor','red','Linewidth',6); 249 | 250 | axes(ha(2)) 251 | imagesc(tt,yaxis,(fftshift(abs(spc_MSST),1))); 252 | title(strcat('local MSST, winlen=',num2str(winlen))) 253 | xlabel({'Time / sec';'(j)'}) 254 | ylabel('Freq. 
/ Hz') 255 | set(gca,'FontSize',fsz); 256 | set(get(gca,'XLabel'),'FontSize',fsz); 257 | set(get(gca,'YLabel'),'FontSize',fsz); 258 | set(gca,'xlim',[xx11 xx22],'ylim',[yy11 yy22]); 259 | rectangle('Position',[xx11 yy11 xx22-xx11 yy22-yy11],'EdgeColor','#D2691E','Linewidth',7); 260 | 261 | fname=strcat('F:\pycharm_proj\cResTF\TFA_Net\figures_TFA_submit_20220120\figs\','ref_MSST_win',num2str(winlen)); 262 | saveas(gcf, fname); 263 | 264 | -------------------------------------------------------------------------------- /TFA-Net_inference/ringtoneFrasers.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_inference/ringtoneFrasers.mp3 -------------------------------------------------------------------------------- /TFA-Net_inference/stft_rearrage.m: -------------------------------------------------------------------------------- 1 | function [xin,t] = stft_rearrage(x,nx,nwin,noverlap,Fs) 2 | %getSTFTColumns re-orders input signal into matrix with overlap 3 | % This function is for internal use only. It may be removed in the future. 4 | % 5 | % Copyright 2016-2019 The MathWorks, Inc. 
6 | %#codegen 7 | 8 | % Determine the number of columns of the STFT output (i.e., the S output) 9 | classCast = class(x); 10 | numChannels = size(x,2); 11 | numSample = size(x,1); 12 | ncol = fix((nx-noverlap)/(nwin-noverlap)); 13 | if ~isreal(x) 14 | xin = complex(zeros(nwin,ncol,numChannels,classCast)); 15 | else 16 | xin = zeros(nwin,ncol,numChannels,classCast); 17 | end 18 | 19 | % Determine the number of columns of the STFT output (i.e., the S output) 20 | coloffsets = (0:(ncol-1))*(nwin-noverlap); 21 | rowindices = (1:nwin)'; 22 | 23 | % Segment x into individual columns with the proper offsets for each input 24 | % channel 25 | winPerCh = bsxfun(@plus,rowindices,coloffsets); 26 | for iCh = 1:numChannels 27 | xin(:,:,iCh) = x(winPerCh+(iCh-1)*numSample); 28 | end 29 | 30 | % Return time vector whose elements are centered in each segment 31 | t = (coloffsets+(nwin/2)')/Fs; -------------------------------------------------------------------------------- /TFA-Net_inference/tight_subplot.m: -------------------------------------------------------------------------------- 1 | function ha = tight_subplot(Nh, Nw, gap, marg_h, marg_w) 2 | 3 | % tight_subplot creates "subplot" axes with adjustable gaps and margins 4 | % 5 | % ha = tight_subplot(Nh, Nw, gap, marg_h, marg_w) 6 | % 7 | % in: Nh number of axes in hight (vertical direction) 8 | % Nw number of axes in width (horizontaldirection) 9 | % gap gaps between the axes in normalized units (0...1) 10 | % or [gap_h gap_w] for different gaps in height and width 11 | % marg_h margins in height in normalized units (0...1) 12 | % or [lower upper] for different lower and upper margins 13 | % marg_w margins in width in normalized units (0...1) 14 | % or [left right] for different left and right margins 15 | % 16 | % out: ha array of handles of the axes objects 17 | % starting from upper left corner, going row-wise as in 18 | % going row-wise as in 19 | % 20 | % Example: ha = tight_subplot(3,2,[.01 .03],[.1 .01],[.01 .01]) 21 
| % for ii = 1:6; axes(ha(ii)); plot(randn(10,ii)); end 22 | % set(ha(1:4),'XTickLabel',''); set(ha,'YTickLabel','') 23 | 24 | % Pekka Kumpulainen 20.6.2010 @tut.fi 25 | % Tampere University of Technology / Automation Science and Engineering 26 | 27 | 28 | if nargin<3; gap = .02; end 29 | if nargin<4 || isempty(marg_h); marg_h = .05; end 30 | if nargin<5; marg_w = .05; end 31 | 32 | if numel(gap)==1; 33 | gap = [gap gap]; 34 | end 35 | if numel(marg_w)==1; 36 | marg_w = [marg_w marg_w]; 37 | end 38 | if numel(marg_h)==1; 39 | marg_h = [marg_h marg_h]; 40 | end 41 | 42 | axh = (1-sum(marg_h)-(Nh-1)*gap(1))/Nh; 43 | axw = (1-sum(marg_w)-(Nw-1)*gap(2))/Nw; 44 | 45 | py = 1-marg_h(2)-axh; 46 | 47 | ha = zeros(Nh*Nw,1); 48 | ii = 0; 49 | for ih = 1:Nh 50 | px = marg_w(1); 51 | 52 | for ix = 1:Nw 53 | ii = ii+1; 54 | ha(ii) = axes('Units','normalized', ... 55 | 'Position',[px py axw axh], ... 56 | 'XTickLabel','', ... 57 | 'YTickLabel',''); 58 | px = px+axw+gap(2); 59 | end 60 | py = py-axh-gap(1); 61 | end -------------------------------------------------------------------------------- /TFA-Net_inference/winderi.m: -------------------------------------------------------------------------------- 1 | function Wdt = winderi(w,Fs) 2 | %DTWIN differentiate window in time domain via cubic spline interpolation 3 | % This function is for internal use only. It may be removed in the future. 4 | % 5 | % See also DFWIN. 6 | % 7 | % Copyright 2016-2018 The MathWorks, Inc. 
8 | 9 | %#codegen 10 | 11 | % compute the piecewise polynomial representation of the window 12 | % and fetch the coefficients 13 | n = numel(w); 14 | pp = spline(1:n,w); 15 | 16 | % take the derivative of each polynomial and evaluate it over the same 17 | % samples as the original window 18 | 19 | if coder.target('MATLAB') 20 | [breaks,coefs,npieces,order,dim] = unmkpp(pp); 21 | ppd = mkpp(breaks,repmat(order-1:-1:1,dim*npieces,1).*coefs(:,1:order-1),dim); 22 | else 23 | ppd = coder.internal.ppder(pp); 24 | end 25 | 26 | Wdt = ppval(ppd,(1:n)').*(Fs/(2*pi)); 27 | 28 | % LocalWords: DFWIN 29 | -------------------------------------------------------------------------------- /TFA-Net_train/checkpoint/RED-Net_model/fr/RED_Net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_train/checkpoint/RED-Net_model/fr/RED_Net.pth -------------------------------------------------------------------------------- /TFA-Net_train/checkpoint/TFA-Net_model/fr/TFA-Net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panpp-git/TFA-Net/5f3ff33ba1cd3367eaf1603c79c25077d99035b1/TFA-Net_train/checkpoint/TFA-Net_model/fr/TFA-Net.pth -------------------------------------------------------------------------------- /TFA-Net_train/complexFunctions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @author: spopoff 6 | """ 7 | 8 | from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d 9 | import torch 10 | 11 | 12 | def complex_matmul(A, B): 13 | ''' 14 | Performs the matrix product between two complex matrices 15 | ''' 16 | 17 | outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag) 18 | outp_imag = torch.matmul(A.real, B.imag) + 
torch.matmul(A.imag, B.real) 19 | 20 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) 21 | 22 | 23 | def complex_avg_pool2d(input, *args, **kwargs): 24 | ''' 25 | Perform complex average pooling. 26 | ''' 27 | absolute_value_real = avg_pool2d(input.real, *args, **kwargs) 28 | absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs) 29 | 30 | return absolute_value_real.type(torch.complex64) + 1j * absolute_value_imag.type(torch.complex64) 31 | 32 | 33 | def complex_relu2(input): 34 | input_r,input_i=torch.chunk(input,2,1) 35 | ret_r=relu(input_r) 36 | ret_i=relu(input_i) 37 | return torch.cat((ret_r,ret_i),1) 38 | # return relu(input.real).type(torch.complex64) + 1j * relu(input.imag).type(torch.complex64) 39 | 40 | 41 | def complex_relu(input): 42 | input_r,input_i=torch.chunk(input,2,1) 43 | 44 | ret_r=relu(input_r) 45 | ret_i=relu(input_i) 46 | return torch.cat((ret_r,ret_i),1) 47 | 48 | 49 | def _retrieve_elements_from_indices(tensor, indices): 50 | flattened_tensor = tensor.flatten(start_dim=-2) 51 | output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices) 52 | return output 53 | 54 | 55 | def complex_max_pool2d(input, kernel_size, stride=None, padding=0, 56 | dilation=1, ceil_mode=False, return_indices=False): 57 | ''' 58 | Perform complex max pooling by selecting on the absolute value on the complex values. 
59 | ''' 60 | absolute_value, indices = max_pool2d( 61 | input.abs(), 62 | kernel_size=kernel_size, 63 | stride=stride, 64 | padding=padding, 65 | dilation=dilation, 66 | ceil_mode=ceil_mode, 67 | return_indices=True 68 | ) 69 | # performs the selection on the absolute values 70 | absolute_value = absolute_value.type(torch.complex64) 71 | # retrieve the corresonding phase value using the indices 72 | # unfortunately, the derivative for 'angle' is not implemented 73 | angle = torch.atan2(input.imag, input.real) 74 | # get only the phase values selected by max pool 75 | angle = _retrieve_elements_from_indices(angle, indices) 76 | return absolute_value \ 77 | * (torch.cos(angle).type(torch.complex64) + 1j * torch.sin(angle).type(torch.complex64)) 78 | 79 | 80 | def complex_dropout(input, p=0.5, training=True): 81 | # need to have the same dropout mask for real and imaginary part, 82 | # this not a clean solution! 83 | mask = torch.ones_like(input).type(torch.float32) 84 | mask = dropout(mask, p, training) * 1 / (1 - p) 85 | return mask * input 86 | 87 | 88 | def complex_dropout2d(input, p=0.5, training=True): 89 | # need to have the same dropout mask for real and imaginary part, 90 | # this not a clean solution! 
91 | mask = torch.ones_like(input).type(torch.float32) 92 | mask = dropout2d(mask, p, training) * 1 / (1 - p) 93 | return mask * input -------------------------------------------------------------------------------- /TFA-Net_train/complexModules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from complexLayers import ComplexConv1d 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def set_layer1_module(args): 8 | """ 9 | Create a frequency-representation module 10 | """ 11 | net = None 12 | if args.fr_module_type == 'fr': 13 | assert args.fr_size == args.fr_inner_dim * args.fr_upsampling, \ 14 | 'The desired size of the frequency representation (fr_size) must be equal to inner_dim*upsampling' 15 | net = FrequencyRepresentationModule_TFA_Net(signal_dim=args.signal_dim, n_filters=args.fr_n_filters, 16 | inner_dim=args.fr_inner_dim, n_layers=args.fr_n_layers, 17 | upsampling=args.fr_upsampling, kernel_size=args.fr_kernel_size, 18 | kernel_out=args.fr_kernel_out) 19 | 20 | #########uncomment for RED-Net training############################################################### 21 | # net = FrequencyRepresentationModule_RED_Net(signal_dim=args.signal_dim, n_filters=args.fr_n_filters, 22 | # inner_dim=args.fr_inner_dim, n_layers=args.fr_n_layers, 23 | # upsampling=args.fr_upsampling, kernel_size=args.fr_kernel_size, 24 | # kernel_out=args.fr_kernel_out) 25 | 26 | 27 | else: 28 | raise NotImplementedError('Frequency representation module type not implemented') 29 | if args.use_cuda: 30 | net.cuda() 31 | return net 32 | 33 | 34 | 35 | import math 36 | class REDNet30(nn.Module): 37 | def __init__(self, num_layers=15, num_features=8): 38 | super(REDNet30, self).__init__() 39 | self.num_layers = num_layers 40 | 41 | conv_layers = [] 42 | deconv_layers = [] 43 | 44 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1), 45 | 
nn.ReLU(inplace=True))) 46 | for i in range(num_layers - 1): 47 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features*2, num_features*2, kernel_size=3, padding=1), 48 | nn.ReLU(inplace=True))) 49 | 50 | for i in range(num_layers - 1): 51 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features*2, num_features*2, kernel_size=3, padding=1), 52 | nn.ReLU(inplace=True))) 53 | deconv_layers.append(nn.ConvTranspose2d(num_features*2, num_features, kernel_size=3, stride=2, padding=1, output_padding=1)) 54 | 55 | self.conv_layers = nn.Sequential(*conv_layers) 56 | self.deconv_layers = nn.Sequential(*deconv_layers) 57 | self.relu = nn.ReLU(inplace=True) 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | conv_feats = [] 63 | for i in range(self.num_layers): 64 | x = self.conv_layers[i](x) 65 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 66 | conv_feats.append(x) 67 | 68 | conv_feats_idx = 0 69 | for i in range(self.num_layers): 70 | x = self.deconv_layers[i](x) 71 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 72 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 73 | conv_feats_idx += 1 74 | x = x + conv_feat 75 | x = self.relu(x) 76 | 77 | x += residual 78 | x = self.relu(x) 79 | 80 | return x 81 | class REDNet30_stft(nn.Module): 82 | def __init__(self, num_layers=15, num_features=8): 83 | super(REDNet30_stft, self).__init__() 84 | self.num_layers = num_layers 85 | 86 | conv_layers = [] 87 | deconv_layers = [] 88 | 89 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1), 90 | nn.ReLU(inplace=True))) 91 | for i in range(num_layers - 1): 92 | conv_layers.append(nn.Sequential(nn.Conv2d(num_features*2, num_features*2, kernel_size=3, padding=1), 93 | nn.ReLU(inplace=True))) 94 | 95 | for i in range(num_layers - 1): 96 | deconv_layers.append(nn.Sequential(nn.ConvTranspose2d(num_features*2, num_features*2, kernel_size=3, padding=1), 97 | 
nn.ReLU(inplace=True))) 98 | deconv_layers.append(nn.ConvTranspose2d(num_features*2, num_features, kernel_size=3, stride=2, padding=1, output_padding=1)) 99 | 100 | self.conv_layers = nn.Sequential(*conv_layers) 101 | self.deconv_layers = nn.Sequential(*deconv_layers) 102 | self.relu = nn.ReLU(inplace=True) 103 | 104 | def forward(self, x): 105 | residual = x 106 | 107 | conv_feats = [] 108 | for i in range(self.num_layers): 109 | x = self.conv_layers[i](x) 110 | if (i + 1) % 2 == 0 and len(conv_feats) < math.ceil(self.num_layers / 2) - 1: 111 | conv_feats.append(x) 112 | 113 | conv_feats_idx = 0 114 | for i in range(self.num_layers): 115 | x = self.deconv_layers[i](x) 116 | if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats): 117 | conv_feat = conv_feats[-(conv_feats_idx + 1)] 118 | conv_feats_idx += 1 119 | x = x + conv_feat 120 | x = self.relu(x) 121 | 122 | x += residual 123 | x = self.relu(x) 124 | 125 | return x 126 | 127 | class FrequencyRepresentationModule_TFA_Net(nn.Module): 128 | def __init__(self, signal_dim=50, n_filters=8, n_layers=3, inner_dim=125, 129 | kernel_size=3, upsampling=2, kernel_out=3): 130 | super().__init__() 131 | 132 | self.n_filters = n_filters 133 | self.inner = inner_dim 134 | self.n_layers = n_layers 135 | 136 | self.in_layer1 = ComplexConv1d(1, inner_dim * n_filters, kernel_size=(1, 31), padding=(0, 31 // 2), 137 | bias=False) 138 | 139 | self.rednet=REDNet30(self.n_layers,num_features=n_filters) 140 | self.out_layer = nn.ConvTranspose2d(n_filters, 1, (3, 1), stride=(upsampling, 1), 141 | padding=(1, 0), output_padding=(1, 0), bias=False) 142 | 143 | 144 | def forward(self, x): 145 | bsz = x.size(0) 146 | inp_real = x[:, 0, :].view(bsz, 1, 1, -1) 147 | inp_imag = x[:, 1, :].view(bsz, 1, 1, -1) 148 | inp = torch.cat((inp_real, inp_imag), 1) 149 | 150 | x1 = self.in_layer1(inp) 151 | xreal,ximag=torch.chunk(x1,2,1) 152 | xreal = xreal.view(bsz, self.n_filters, self.inner, -1) 153 | ximag = ximag.view(bsz, 
self.n_filters, self.inner, -1) 154 | 155 | x=torch.sqrt(torch.pow(xreal, 2) + torch.pow(ximag, 2)) 156 | 157 | # for i in range(0,16,1): 158 | # plt.figure() 159 | # plt.ion() 160 | # plt.imshow(x[0,i,:,:].abs()/torch.max(x[0,i,:,:].abs())) 161 | # plt.xticks([]) 162 | # plt.yticks([]) 163 | 164 | x=self.rednet(x) 165 | x = self.out_layer(x).squeeze(-3).transpose(1, 2) 166 | return x 167 | class FrequencyRepresentationModule_RED_Net(nn.Module): 168 | def __init__(self, signal_dim=50, n_filters=8, n_layers=3, inner_dim=125, 169 | kernel_size=3, upsampling=2, kernel_out=3): 170 | super().__init__() 171 | 172 | self.n_filters = n_filters 173 | self.inner = inner_dim 174 | self.n_layers = n_layers 175 | self.in_layer=nn.Conv2d(4, n_filters, kernel_size=1) 176 | self.rednet=REDNet30_stft(self.n_layers,num_features=n_filters) 177 | self.out_layer = nn.ConvTranspose2d(n_filters, 1, (3, 1), stride=(upsampling, 1), 178 | padding=(1, 0), output_padding=(1, 0), bias=False) 179 | 180 | 181 | def forward(self, x): 182 | bsz = x.size(0) 183 | xreal = x[:, 0, :].view(bsz, 1, 1, -1) 184 | ximag = x[:, 1, :].view(bsz, 1, 1, -1) 185 | 186 | xreal2 = x[:, 2, :].view(bsz, 1, 1, -1) 187 | ximag2 = x[:, 3, :].view(bsz, 1, 1, -1) 188 | 189 | xreal3 = x[:, 4, :].view(bsz, 1, 1, -1) 190 | ximag3 = x[:, 5, :].view(bsz, 1, 1, -1) 191 | 192 | xreal4 = x[:, 6, :].view(bsz, 1, 1, -1) 193 | ximag4 = x[:, 7, :].view(bsz, 1, 1, -1) 194 | 195 | xreal = xreal.view(bsz, 1, self.inner, -1) 196 | ximag = ximag.view(bsz, 1, self.inner, -1) 197 | 198 | xreal2 = xreal2.view(bsz, 1, self.inner, -1) 199 | ximag2 = ximag2.view(bsz, 1, self.inner, -1) 200 | 201 | xreal3 = xreal3.view(bsz, 1, self.inner, -1) 202 | ximag3 = ximag3.view(bsz, 1, self.inner, -1) 203 | 204 | xreal4 = xreal4.view(bsz, 1, self.inner, -1) 205 | ximag4 = ximag4.view(bsz, 1, self.inner, -1) 206 | 207 | x=torch.sqrt(torch.pow(xreal, 2) + torch.pow(ximag, 2)) 208 | x2 = torch.sqrt(torch.pow(xreal2, 2) + torch.pow(ximag2, 2)) 209 | x3 = 
torch.sqrt(torch.pow(xreal3, 2) + torch.pow(ximag3, 2)) 210 | x4 = torch.sqrt(torch.pow(xreal4, 2) + torch.pow(ximag4, 2)) 211 | x=torch.cat((x,x2,x3,x4),1) 212 | x=self.in_layer(x) 213 | x=self.rednet(x) 214 | x = self.out_layer(x).squeeze(-3).transpose(1, 2) 215 | return x 216 | 217 | -------------------------------------------------------------------------------- /TFA-Net_train/complexTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import logging 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | from data import dataset 9 | import complexModules 10 | import util 11 | from data.noise import noise_torch 12 | import torch 13 | logger = logging.getLogger(__name__) 14 | torch.autograd.set_detect_anomaly(True) 15 | 16 | def train_frequency_representation(args, fr_module, fr_optimizer, fr_criterion, fr_scheduler, train_loader, val_loader, 17 | xgrid, epoch, tb_writer): 18 | """ 19 | Train the frequency-representation module for one epoch 20 | """ 21 | epoch_start_time = time.time() 22 | fr_module.train() 23 | loss_train_fr,fnr_train = 0,0 24 | 25 | for batch_idx, (clean_signal, target_fr,wgt) in enumerate(train_loader): 26 | if args.use_cuda: 27 | clean_signal, target_fr,wgt = clean_signal.cuda(), target_fr.cuda(),wgt.cuda() 28 | noisy_signal = noise_torch(clean_signal, args.snr, args.noise) 29 | abs_max = torch.max(torch.sqrt((pow(noisy_signal[:, 0, :], 2) + pow(noisy_signal[:, 1, :], 2)))) 30 | for i in range(noisy_signal.size()[0]): 31 | noisy_signal[i][0]=noisy_signal[i][0]/abs_max 32 | noisy_signal[i][1]=noisy_signal[i][1]/abs_max 33 | 34 | output_fr = fr_module(noisy_signal) 35 | 36 | 37 | loss_l2=torch.pow((output_fr - target_fr), 2)*wgt 38 | loss_fr=torch.sum(loss_l2).to(torch.float32) 39 | 40 | fr_optimizer.zero_grad() 41 | with torch.autograd.detect_anomaly(): 42 | loss_fr.backward() 43 | 44 | fr_optimizer.step() 45 | 
loss_train_fr += loss_fr.data.item() 46 | 47 | 48 | fr_module.eval() 49 | loss_val_fr, fnr_val = 0, 0 50 | for batch_idx, (noisy_signal, _, target_fr,wgt) in enumerate(val_loader): 51 | if args.use_cuda: 52 | noisy_signal, target_fr,wgt = noisy_signal.cuda(), target_fr.cuda(),wgt.cuda() 53 | 54 | abs_max = torch.max(torch.sqrt((pow(noisy_signal[:, 0, :], 2) + pow(noisy_signal[:, 1, :], 2)))) 55 | for i in range(noisy_signal.size()[0]): 56 | noisy_signal[i][0]=noisy_signal[i][0]/abs_max 57 | noisy_signal[i][1]=noisy_signal[i][1]/abs_max 58 | 59 | with torch.no_grad(): 60 | output_fr = fr_module(noisy_signal) 61 | 62 | 63 | loss_l2=torch.pow((output_fr - target_fr), 2)*wgt 64 | loss_fr=torch.sum(loss_l2).to(torch.float32) 65 | loss_val_fr += loss_fr.data.item() 66 | 67 | 68 | loss_train_fr /= args.n_training 69 | loss_val_fr /= args.n_validation 70 | 71 | 72 | tb_writer.add_scalar('fr_l2_training', loss_train_fr, epoch) 73 | tb_writer.add_scalar('fr_l2_validation', loss_val_fr, epoch) 74 | 75 | 76 | fr_scheduler.step(loss_val_fr) 77 | logger.info("Epochs: %d / %d, Time: %.1f, FR training L2 loss %.2f, FR validation L2 loss %.2f", 78 | epoch, args.n_epochs_fr, time.time() - epoch_start_time, loss_train_fr, loss_val_fr) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | 84 | # basic parameters 85 | parser.add_argument('--output_dir', type=str, default='./checkpoint/TFA-Net_model', help='output directory') 86 | parser.add_argument('--no_cuda', action='store_true', help="avoid using CUDA when available") 87 | # dataset parameters 88 | parser.add_argument('--batch_size', type=int, default=64, help='batch size used during training') 89 | parser.add_argument('--signal_dim', type=int, default=256, help='dimensionof the input signal') 90 | parser.add_argument('--fr_size', type=int, default=256, help='size of the frequency representation') 91 | parser.add_argument('--max_n_freq', type=int, default=10, 92 | help='for each signal the number of 
frequencies is uniformly drawn between 1 and max_n_freq') 93 | parser.add_argument('--min_sep', type=float, default=0.5, 94 | help='minimum separation between spikes, normalized by signal_dim') 95 | parser.add_argument('--distance', type=str, default='normal', help='distance distribution between spikes') 96 | parser.add_argument('--amplitude', type=str, default='normal_floor', help='spike amplitude distribution') 97 | parser.add_argument('--floor_amplitude', type=float, default=0.1, help='minimum amplitude of spikes') 98 | parser.add_argument('--noise', type=str, default='gaussian_blind', help='kind of noise to use') 99 | parser.add_argument('--snr', type=float, default=-5, help='snr parameter') 100 | # frequency-representation (fr) module parameters 101 | parser.add_argument('--fr_module_type', type=str, default='fr', help='type of the fr module: [fr | psnet]') 102 | parser.add_argument('--fr_n_layers', type=int, default=20, help='number of convolutional layers in the fr module') 103 | parser.add_argument('--fr_n_filters', type=int, default=16, help='number of filters per layer in the fr module') 104 | parser.add_argument('--fr_kernel_size', type=int, default=3, 105 | help='filter size in the convolutional blocks of the fr module') 106 | parser.add_argument('--fr_kernel_out', type=int, default=3, help='size of the conv transpose kernel') 107 | parser.add_argument('--fr_inner_dim', type=int, default=128, help='dimension after first linear transformation') 108 | parser.add_argument('--fr_upsampling', type=int, default=2, 109 | help='stride of the transposed convolution, upsampling * inner_dim = fr_size') 110 | 111 | # kernel parameters used to generate the ideal frequency representation 112 | parser.add_argument('--kernel_type', type=str, default='gaussian', 113 | help='type of kernel used to create the ideal frequency representation [gaussian, triangle or closest]') 114 | parser.add_argument('--triangle_slope', type=float, default=1000, 115 | help='slope of the 
triangle kernel normalized by signal_dim') 116 | parser.add_argument('--gaussian_std', type=float, default=0.12, 117 | help='std of the gaussian kernel normalized by signal_dim') 118 | # training parameters 119 | parser.add_argument('--n_training', type=int, default=30000, help='# of training data') 120 | parser.add_argument('--n_validation', type=int, default=1000, help='# of validation data') 121 | parser.add_argument('--lr_fr', type=float, default=0.001, 122 | help='initial learning rate for adam optimizer used for the frequency-representation module') 123 | parser.add_argument('--n_epochs_fr', type=int, default=410, help= 'number of epochs used to train the fr module') 124 | parser.add_argument('--save_epoch_freq', type=int, default=10, 125 | help='frequency of saving checkpoints at the end of epochs') 126 | parser.add_argument('--numpy_seed', type=int, default=100) 127 | parser.add_argument('--torch_seed', type=int, default=76) 128 | 129 | args = parser.parse_args() 130 | 131 | if torch.cuda.is_available() and not args.no_cuda: 132 | args.use_cuda = True 133 | else: 134 | args.use_cuda = False 135 | 136 | if not os.path.exists(args.output_dir): 137 | os.makedirs(args.output_dir) 138 | 139 | file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, 'run.log')) 140 | stdout_handler = logging.StreamHandler(sys.stdout) 141 | handlers = [file_handler, stdout_handler] 142 | logging.basicConfig( 143 | datefmt='%m/%d/%Y %H:%M:%S', 144 | level=logging.INFO, 145 | format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', 146 | handlers=handlers 147 | ) 148 | 149 | tb_writer = SummaryWriter(args.output_dir) 150 | util.print_args(logger, args) 151 | 152 | np.random.seed(args.numpy_seed) 153 | torch.manual_seed(args.torch_seed) 154 | 155 | train_loader = dataset.make_train_data(args) 156 | val_loader = dataset.make_eval_data(args) 157 | 158 | 159 | fr_module = complexModules.set_layer1_module(args) 160 | fr_optimizer, fr_scheduler = 
util.set_optim(args, fr_module, 'layer1') 161 | 162 | fr_criterion = torch.nn.MSELoss(reduction='sum') 163 | start_epoch = 1 164 | 165 | logger.info('[Network] Number of parameters in the frequency-representation module : %.3f M' % ( 166 | util.model_parameters(fr_module) / 1e6)) 167 | 168 | 169 | xgrid = np.linspace(-0.5, 0.5, args.fr_size, endpoint=False) 170 | for epoch in range(start_epoch, args.n_epochs_fr + 1): 171 | 172 | if epoch < args.n_epochs_fr: 173 | train_frequency_representation(args=args, fr_module=fr_module, fr_optimizer=fr_optimizer, fr_criterion=fr_criterion, 174 | fr_scheduler=fr_scheduler, train_loader=train_loader, val_loader=val_loader, 175 | xgrid=xgrid, epoch=epoch, tb_writer=tb_writer) 176 | 177 | 178 | if epoch % args.save_epoch_freq == 0 or epoch == args.n_epochs_fr: 179 | util.save(fr_module, fr_optimizer, fr_scheduler, args, epoch, args.fr_module_type) 180 | 181 | -------------------------------------------------------------------------------- /TFA-Net_train/data/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | 5 | def amplitude_generation(dim, amplitude, floor_amplitude=0.1): 6 | """ 7 | Generate the amplitude associated with each frequency. 
8 | """ 9 | if amplitude == 'uniform': 10 | return np.random.rand(*dim) * (15 - floor_amplitude) + floor_amplitude 11 | elif amplitude == 'normal': 12 | return np.abs(np.random.randn(*dim)) 13 | elif amplitude == 'normal_floor': 14 | return 5*np.abs(np.random.randn(*dim)) + floor_amplitude 15 | elif amplitude == 'alternating': 16 | return np.random.rand(*dim) * 0.5 + 20 * np.random.rand(*dim) * np.random.randint(0, 2, size=dim) 17 | 18 | def amplitude_freq_generation(dim,floor_amplitude=0.1): 19 | 20 | dotm=np.random.rand(*dim) * (6 - 0.01) + 0.01 21 | f1=(np.random.rand(*dim)-0.5) * 12 22 | f2=(np.random.rand(*dim)-0.5) * 8 23 | amp=dotm/f1 24 | f3 = (np.random.rand(*dim) - 0.5) * 300 25 | return f1,f2,amp,f3 26 | 27 | 28 | def gen_signal(num_samples, signal_dim, num_freq, min_sep, distance='normal', amplitude='normal_floor', 29 | floor_amplitude=0.1, variable_num_freq=False): 30 | s = np.zeros((num_samples, 2, signal_dim)) 31 | 32 | xgrid=np.linspace(0,1,signal_dim,endpoint=False)[:, None] 33 | f1,f2,r,f3 = amplitude_freq_generation((num_samples, num_freq), floor_amplitude) 34 | 35 | amp = amplitude_generation((num_samples, num_freq), amplitude, floor_amplitude) 36 | 37 | theta = np.random.rand(num_samples, num_freq) * 2 * np.pi 38 | start_pos = np.floor(np.random.rand(num_samples, num_freq) * (signal_dim-256)).astype(int) 39 | sig_len_max = signal_dim - start_pos 40 | sig_len = np.random.randint(256, sig_len_max+1) 41 | 42 | 43 | if variable_num_freq: 44 | nfreq = np.random.randint(1, num_freq + 1, num_samples) 45 | else: 46 | nfreq = np.ones(num_samples, dtype='int') * num_freq 47 | for n in range(num_samples): 48 | if n%100==0: 49 | print(n) 50 | # if n>=100 and n<120: 51 | # nfreq[n]=3 52 | # r[n,0:3]=1 53 | for i in range(nfreq[n]): 54 | 55 | xgrid_t = np.zeros((signal_dim, 1)).astype('complex128') 56 | xgrid_t[start_pos[n, i]:start_pos[n, i] + sig_len[n, i], 0] = xgrid[start_pos[n, i]:start_pos[n, i] + sig_len[n, i], 0] 57 | 
sin=amp[n,i]*np.exp(2j*np.pi*r[n,i]*np.cos(2*np.pi*(f1[n,i]*xgrid_t.T+f2[n,i]*np.power(xgrid_t.T,2))+theta[n,i])+2j*np.pi*f3[n,i]*xgrid_t.T) 58 | s[n, 0] = s[n, 0] + sin.real 59 | s[n, 1] = s[n, 1] + sin.imag 60 | 61 | s[n] = s[n] / (np.sqrt(np.mean(np.power(s[n], 2)))+1e-13) 62 | 63 | return s.astype('float32'), f1.astype('float32'), f2.astype('float32'),nfreq,r,theta.astype('float32'),amp.astype('float32'),start_pos.astype('int'),sig_len.astype('int'),f3.astype('float32') 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /TFA-Net_train/data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.utils.data as data_utils 4 | from data import fr, data 5 | from .noise import noise_torch 6 | 7 | 8 | def load_dataloader(num_samples, signal_dim, max_n_freq, min_sep, distance, amplitude, floor_amplitude, 9 | kernel_type, kernel_param, batch_size, xgrid): 10 | clean_signals, f1,f2, nfreq ,r,theta,amp,sigpos,siglen,f3= data.gen_signal(num_samples, signal_dim, max_n_freq, min_sep, distance=distance, 11 | amplitude=amplitude, floor_amplitude=floor_amplitude, 12 | variable_num_freq=True) 13 | frequency_representation,factor = fr.freq2fr(f1,f2, xgrid, kernel_type, kernel_param,r,nfreq,signal_dim,theta,amp,sigpos,siglen,f3) 14 | clean_signals = torch.from_numpy(clean_signals).float() 15 | frequency_representation = torch.from_numpy(frequency_representation).float() 16 | factor = torch.from_numpy(factor).float() 17 | dataset = data_utils.TensorDataset(clean_signals, frequency_representation,factor) 18 | return data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True) 19 | 20 | 21 | def load_dataloader_fixed_noise(num_samples, signal_dim, max_n_freq, min_sep, distance, amplitude, floor_amplitude, 22 | kernel_type, kernel_param, batch_size, xgrid, snr, noise): 23 | clean_signals, f1,f2, nfreq,r,theta,amp,sigpos,siglen,f3 = 
data.gen_signal(num_samples, signal_dim, max_n_freq, min_sep, distance=distance, 24 | amplitude=amplitude, floor_amplitude=floor_amplitude, 25 | variable_num_freq=True) 26 | frequency_representation,factor = fr.freq2fr(f1,f2, xgrid, kernel_type, kernel_param,r,nfreq,signal_dim,theta,amp,sigpos,siglen,f3) 27 | 28 | clean_signals = torch.from_numpy(clean_signals).float() 29 | frequency_representation = torch.from_numpy(frequency_representation).float() 30 | factor = torch.from_numpy(factor).float() 31 | noisy_signals = noise_torch(clean_signals, snr, noise) 32 | dataset = data_utils.TensorDataset(noisy_signals, clean_signals, frequency_representation,factor) 33 | return data_utils.DataLoader(dataset, batch_size=batch_size) 34 | 35 | 36 | def make_train_data(args): 37 | xgrid = np.linspace(-0.5, 0.5, args.fr_size, endpoint=False) 38 | if args.kernel_type == 'triangle': 39 | kernel_param = args.triangle_slope / args.signal_dim 40 | else: 41 | kernel_param = args.gaussian_std / args.signal_dim 42 | return load_dataloader(args.n_training, signal_dim=args.signal_dim, max_n_freq=args.max_n_freq, 43 | min_sep=args.min_sep, distance=args.distance, amplitude=args.amplitude, 44 | floor_amplitude=args.floor_amplitude, kernel_type=args.kernel_type, 45 | kernel_param=kernel_param, batch_size=args.batch_size, xgrid=xgrid) 46 | 47 | 48 | def make_eval_data(args): 49 | xgrid = np.linspace(-0.5, 0.5, args.fr_size, endpoint=False) 50 | if args.kernel_type == 'triangle': 51 | kernel_param = args.triangle_slope / args.signal_dim 52 | else: 53 | kernel_param = args.gaussian_std / args.signal_dim 54 | return load_dataloader_fixed_noise(args.n_validation, signal_dim=args.signal_dim, max_n_freq=args.max_n_freq, 55 | min_sep=args.min_sep, distance=args.distance, amplitude=args.amplitude, 56 | floor_amplitude=args.floor_amplitude, kernel_type=args.kernel_type, 57 | kernel_param=kernel_param, batch_size=args.batch_size, xgrid=xgrid, 58 | snr=args.snr, noise=args.noise) 59 | 
# --------------------------------------------------------------------------------
# /TFA-Net_train/data/fr.py
# --------------------------------------------------------------------------------
import numpy as np


def freq2fr(f1, f2, xgrid, kernel_type='gaussian', param=None, r=None, nfreq=None, sig_dim=None, theta=None,
            amp=None, sigpos=None, siglen=None, f3=None):
    """Build the target time-frequency map and per-bin loss weights for a batch of signals.

    Only ``kernel_type='gaussian'`` is implemented; it delegates to ``gaussian_kernel_simplify``
    (see there for the meaning of the remaining arguments).

    Returns:
        ``(fr, wgt_factor)`` as produced by ``gaussian_kernel_simplify``.

    Raises:
        ValueError: for any other ``kernel_type``, rather than silently returning ``None``.
    """
    if kernel_type == 'gaussian':
        return gaussian_kernel_simplify(f1, f2, xgrid, param, r, nfreq, sig_dim, theta, amp, sigpos, siglen, f3)
    raise ValueError("unsupported kernel_type %r; only 'gaussian' is implemented" % (kernel_type,))


def gaussian_kernel_simplify(f1, f2, xgrid, sigma, r, nfreq, sig_dim, theta, amp, sigpos, siglen, f3):
    """Rasterise the instantaneous-frequency track of every component into a time-frequency map.

    Component ``i`` of sample ``n`` has phase (in cycles)
    ``r*cos(2*pi*(f1*t + f2*t**2) + theta) + f3*t`` on ``t = k / sig_dim``. Its instantaneous
    frequency (the phase derivative) is wrapped into ``[0, sig_dim)`` and floored to a bin index.
    Over samples ``sigpos[n, i] : sigpos[n, i] + siglen[n, i]`` (clipped to ``sig_dim``) the map gets
    the component's dB level ``20*log10(10*amp + 1)`` at that bin and a sixth of it at the two
    neighbouring bins (wrapping circularly). Overlapping components keep the maximum.

    Args:
        f1, f2, r, theta, f3: per-component parameters, indexed ``[n, i]``.
        xgrid: frequency grid; only its length is used (the column count of the map).
        sigma: unused; accepted so the signature lines up with ``freq2fr``'s ``param``.
        nfreq: number of active components per sample.
        sig_dim: number of time samples, also the number of bins the tracks are wrapped onto.
        amp: linear amplitudes ``[n, i]``; not modified.
        sigpos, siglen: start index and length of each component's support.

    Returns:
        ``(fr, wgt_factor)``, both of shape ``(N, sig_dim, len(xgrid))`` indexed ``[n, time, bin]``.
        ``wgt_factor`` is 1 except on peak bins, where it is the largest
        ``(max_dB / component_dB) ** 2`` of the components covering that bin. ``max_dB`` is the dB
        level of the strongest active component in the whole batch.

    Note:
        Bin indices and neighbour wrap-around use ``sig_dim``, so ``len(xgrid) >= sig_dim`` is assumed
        (the training defaults use 256 for both) -- TODO confirm for other sizes.
    """
    n_samples = f1.shape[0]
    n_bins = xgrid.shape[0]
    tf_map = np.zeros((n_samples, sig_dim, n_bins))
    wgt_factor = np.ones((n_samples, sig_dim, n_bins))
    if n_samples == 0:
        return tf_map, wgt_factor

    t = np.linspace(0, 1, sig_dim, endpoint=False)
    time_idx = np.arange(sig_dim)
    # Work on a copy so the caller's linear amplitudes are not overwritten with dB values.
    amp_db = np.array(amp, copy=True)

    # dB level of the strongest active component in the batch; normalises the loss weights.
    abs_max = max(np.max(amp_db[n, :nfreq[n]]) for n in range(n_samples))
    abs_max = 20 * np.log10(10 * abs_max + 1)

    for n in range(n_samples):
        amp_db[n, :nfreq[n]] = 20 * np.log10(10 * amp_db[n, :nfreq[n]] + 1)
        for i in range(nfreq[n]):
            # Closed-form d/dt of r*cos(2*pi*(f1*t + f2*t^2) + theta) + f3*t. This replaces a
            # per-component sympy diff + lambdify that evaluated the same expression far more slowly.
            phase = 2 * np.pi * (f1[n, i] * t + f2[n, i] * t ** 2) + theta[n, i]
            inst_freq = -2 * np.pi * r[n, i] * (f1[n, i] + 2 * f2[n, i] * t) * np.sin(phase) + f3[n, i]
            wrapped = inst_freq - np.floor(inst_freq / sig_dim) * sig_dim
            # Rounding can land a tiny negative frequency exactly on sig_dim; fold it back to bin 0.
            bins = np.floor(wrapped).astype('int') % sig_dim

            start = sigpos[n, i]
            stop = start + siglen[n, i]
            rows = time_idx[start:stop]
            cols = bins[start:stop]
            level = amp_db[n, i]

            # Side lobes one bin either side (circular) at a sixth of the peak level.
            for offset in (1, -1):
                side = (cols + offset) % sig_dim
                tf_map[n, rows, side] = np.maximum(level / 6, tf_map[n, rows, side])
            tf_map[n, rows, cols] = np.maximum(tf_map[n, rows, cols], level)
            wgt_factor[n, rows, cols] = np.maximum(wgt_factor[n, rows, cols], np.power(abs_max / level, 2))

    return tf_map, wgt_factor


# --------------------------------------------------------------------------------
# /TFA-Net_train/data/loss.py
# --------------------------------------------------------------------------------
import numpy as np


def _wrapped_min_distance(pool, query):
    """Per row, the distance from ``query[row]`` to the nearest entry of ``pool[row]``.

    Distances are measured directly and with ``pool`` shifted by +1 and -1, so frequencies that wrap
    around are treated as close. Returns an array of shape ``(rows,)``.
    """
    query = query[:, None]
    direct = np.min(np.abs(pool - query), axis=1)
    shifted_up = np.min(np.abs((pool + 1) - query), axis=1)
    shifted_down = np.min(np.abs((pool - 1) - query), axis=1)
    return np.min((direct, shifted_up, shifted_down), axis=0)


def fnr(f1, f2, signal_dim):
    """False-negative rate of estimated frequencies ``f1`` against target frequencies ``f2``.

    A target counts as missed when no estimate in its row lies within ``1 / (4 * signal_dim)``
    (wrap-around distance). Target slots equal to -10 are padding and never count as misses. Each
    row's miss count is divided by its number of targets greater than -0.5, and the ratios are summed.
    """
    threshold = 1 / (4 * signal_dim)
    false_negative = np.zeros(f1.shape[0])
    nfreq = np.sum(f2 > -0.5, axis=1)
    for i in range(f2.shape[1]):
        dist_i = _wrapped_min_distance(f1, f2[:, i])
        valid_freq = f2[:, i] != -10
        false_negative += (dist_i > threshold) * valid_freq
    return np.sum(false_negative / nfreq)


def chamfer(f_estimate, f_target):
    """Symmetric wrap-around Chamfer distance between estimated and target frequencies, summed over rows.

    Slots equal to -1 are rewritten to -10 and, like existing -10 slots, contribute no distance of
    their own. The rewrite is applied to copies, so the caller's arrays are left unchanged.
    """
    f_estimate = np.array(f_estimate, copy=True)
    f_target = np.array(f_target, copy=True)
    f_estimate[f_estimate == -1] = -10.
    f_target[f_target == -1] = -10.

    dist_f_target = np.zeros(f_target.shape[0])
    for i in range(f_target.shape[1]):
        dist_i = _wrapped_min_distance(f_estimate, f_target[:, i])
        dist_f_target += dist_i * (f_target[:, i] != -10.)

    dist_f_estimate = np.zeros(f_estimate.shape[0])
    for i in range(f_estimate.shape[1]):
        dist_i = _wrapped_min_distance(f_target, f_estimate[:, i])
        dist_f_estimate += dist_i * (f_estimate[:, i] != -10.)

    return np.sum(dist_f_estimate + dist_f_target)


# --------------------------------------------------------------------------------
# /TFA-Net_train/data/noise.py
# --------------------------------------------------------------------------------
import numpy as np
import torch


def noise_torch(t, snr=1., kind='gaussian_b', n_corr=None):
    """Corrupt a batch of signals ``t`` with the noise model named by ``kind``.

    ``'gaussian'`` and ``'gaussian_blind'`` use ``snr`` (dB); ``'sparse'`` and ``'variable_sparse'``
    use ``n_corr``. The default ``kind='gaussian_b'`` is not one of these, so callers must pass
    ``kind`` explicitly.

    Raises:
        ValueError: for an unknown ``kind``, instead of silently returning ``None``.
    """
    if kind == 'gaussian':
        return gaussian_noise(t, snr)
    if kind == 'gaussian_blind':
        return gaussian_blind_noise(t, snr)
    if kind == 'sparse':
        return sparse_noise(t, n_corr)
    if kind == 'variable_sparse':
        return variable_sparse_noise(t, n_corr)
    raise ValueError("unknown noise kind %r; expected 'gaussian', 'gaussian_blind', 'sparse' or "
                     "'variable_sparse'" % (kind,))


def _scale_to_snr(flat, snr_db):
    """Scale each row of ``flat`` by ``10 ** (snr_db / 20)``; ``snr_db`` has shape ``(bsz, 1)``.

    The gain is computed in ``snr_db``'s dtype and then moved to ``flat``'s dtype and device. The
    signal therefore never leaves its device, and tensors that require grad are supported.
    """
    gain = (10.0 ** (snr_db / 20)).to(flat)
    return flat * gain


def gaussian_blind_noise(s, snr, max_snr=40):
    """Add unit-variance Gaussian noise at a random per-signal SNR.

    For each signal an SNR in dB is drawn uniformly from ``[snr, max_snr)``. The clean signal is
    scaled by ``10 ** (snr_db / 20)`` and standard-normal noise is added, so for unit-RMS inputs the
    result has that SNR. ``s`` is ``(bsz, channels, signal_dim)``; the output has the same shape.
    """
    bsz, _, signal_dim = s.size()
    flat = s.view(bsz, -1)
    # Same RNG order as before: SNR draw first, then the noise.
    snr_db = (snr + (max_snr - snr) * torch.rand(bsz))[:, None]
    noise = torch.randn(flat.size(), device=flat.device, dtype=flat.dtype)
    return (_scale_to_snr(flat, snr_db) + noise).view(bsz, -1, signal_dim)


def gaussian_noise(s, snr):
    """Add unit-variance Gaussian noise after scaling every signal by ``10 ** (snr / 20)``."""
    bsz, _, signal_dim = s.size()
    flat = s.view(bsz, -1)
    noise = torch.randn(flat.size(), device=flat.device, dtype=flat.dtype)
    snr_db = (snr * torch.ones(bsz))[:, None]
    return (_scale_to_snr(flat, snr_db) + noise).view(bsz, -1, signal_dim)


def sparse_noise(s, n_corr):
    """Return a copy of ``s`` with sparse corruption.

    Per signal, ``n_corr`` distinct time indices (drawn without replacement) receive additive noise
    ``0.5 * N(0, 1)`` on every channel.
    """
    noisy_signal = s.clone()
    corruption = 0.5 * torch.randn((s.size(0), s.size(1), n_corr), device=s.device, dtype=s.dtype)
    for i in range(s.size(0)):
        idx = torch.multinomial(torch.ones(s.size(-1)), n_corr, replacement=False)
        noisy_signal[i, :, idx] += corruption[i, :]
    return noisy_signal


def variable_sparse_noise(s, max_corr):
    """Like ``sparse_noise``, but each signal's number of corrupted indices is drawn from 1..max_corr."""
    noisy_signal = s.clone()
    corruption = 0.5 * torch.randn((s.size(0), s.size(1), max_corr), device=s.device, dtype=s.dtype)
    n_corr = np.random.randint(1, max_corr + 1, (s.size(0)))
    for i in range(s.size(0)):
        idx = torch.multinomial(torch.ones(s.size(-1)), int(n_corr[i]), replacement=False)
        noisy_signal[i, :, idx] += corruption[i, :, :n_corr[i]]
    return noisy_signal
# --------------------------------------------------------------------------------