├── .gitignore ├── README.md ├── cgmm.py ├── run-offline-cgmm-mvdr.py └── run-online-cgmm-mvdr.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | .pytest_cache/ 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule.* 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | env.bak/ 88 | venv.bak/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | 104 | # End of https://www.gitignore.io/api/python 105 | 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
CGMM-based mask estimation, after the paper "Online MVDR Beamformer Based on
Complex Gaussian Mixture Model with Spatial Prior for Noise Robust ASR".

A circularly symmetric complex Gaussian mixture is used (mean = 0 and
pseudo-covariance = 0).  One model instance handles one frequency bin.

# M: channel number
# K: CGMM cluster number (usually is 2)
# T: frame number
'''

import numpy as np


class CGMM:
    """Offline CGMM: maximum-likelihood (ML) estimation with EM.

    Cluster k models a frame y_t as a zero-mean complex Gaussian with
    covariance Phi[k, t] * R[k], i.e. a per-frame scalar variance times a
    per-cluster spatial covariance matrix.
    """

    # Lower bound for Phi.  An all-zero frame gives Phi == 0, which would turn
    # into log(0) in the E step and a division by zero in the M step.
    _PHI_FLOOR = np.finfo(float).tiny

    def __init__(self, Y, K=2, openAssert=False):
        """
        Arguments:
            Y: (M, T) complex array, T observations of dimension M
            K: number of clusters (number of sound sources + 1 background noise)
            openAssert: enable extra internal consistency assertions
        """
        self._openAssert = openAssert
        self._K = K
        self._Y = Y
        self._M, self._T = Y.shape
        M, T = self._M, self._T
        # Declare the parameter shapes and types
        self._Phi = np.zeros([K, T], dtype=complex)      # (K,T): per-frame signal variance of each cluster
        self._R = np.zeros([K, M, M], dtype=complex)     # (K,M,M): spatial covariance of each cluster
        self._invR = np.zeros([K, M, M], dtype=complex)  # (K,M,M): inverse spatial covariance of each cluster
        self._alpha = np.zeros([K, ])                    # (K,): mixture weights
        self._posterior = np.zeros([K, T])               # (K,T): posterior probabilities
        self._steerVec = np.zeros([K, ])                 # allocated only; never updated in this module

        self._initParam()

    def _initParam(self):
        """Initialise R, invR, Phi and alpha from the data held in self._Y."""
        K, Y = self._K, self._Y
        M, T = self._M, self._T
        sampleCov = np.matmul(Y, Y.conj().T) / T  # (M,M)

        if K == 2:
            self._R[0, ...] = 1e-6 * np.eye(M).astype(complex)  # indicates noise cluster
            self._R[1, ...] = sampleCov                         # indicates speech cluster
        else:
            # WARNING: Bad performance.
            rand_scale = 1e-3 * np.random.rand(K).astype(complex)  # (K,)
            rand_eye = np.einsum('k,ij->kij', rand_scale, np.eye(M).astype(complex))
            for k in range(1, K):
                self._R[k, ...] = sampleCov
            self._R += rand_eye

        self._invR = np.linalg.inv(self._R)
        self._Phi = self._calPhi(Y, self._invR)
        self._alpha = np.ones([K, ]) / K

    def getR(self):
        """Return a copy of the spatial covariances, shape (K,M,M)."""
        return np.copy(self._R)

    def getPost(self):
        """Return a copy of the posteriors of the last processed data, shape (K,T)."""
        return np.copy(self._posterior)

    def getPhi(self):
        """Return a copy of the per-frame variances, shape (K,T), complex dtype."""
        return np.copy(self._Phi)

    def getMixWeights(self):
        """Return a copy of the mixture weights, shape (K,)."""
        return np.copy(self._alpha)

    def _calPhi(self, Y, invR):
        """
        Per-frame variance Phi[k,t] = y_t^H invR[k] y_t / M.

        Arguments:
            Y: (M, T)
            invR: (K, M, M)
        Return:
            Phi: (K, T).  The quadratic form of a Hermitian matrix is real, so
            only the real part is kept (floored at _PHI_FLOOR); the complex
            dtype is retained so that getPhi() keeps returning complex arrays.
        """
        quad = np.einsum('mt,kmn,nt->kt', Y.conj(), invR, Y) / Y.shape[0]
        return np.maximum(quad.real, self._PHI_FLOOR).astype(complex)

    @staticmethod
    def _calWeightedCov(Y, weights):
        """
        Weighted outer-product sums sum_t weights[k,t] * y_t y_t^H.

        Arguments:
            Y: (M, T)
            weights: (K, T), real
        Return:
            (K, M, M)
        """
        return np.einsum('kt,mt,nt->kmn', weights, Y, Y.conj())

    def _calLogGaussianProb(self, Y, Phi, R, invR):
        """
        Log-density of every frame under one cluster.

        Arguments: (M: num_mics, T: num_frames)
            Y: (M, T), T observations with M-dim
            Phi: (T,), the signal variance for each time t
            R, invR: (M,M), the spatial (inverse) covariance matrix
        Return:
            logProb: (T,), real log-probabilities
                -M*log(pi*Phi) - log|det R| - y^H invR y / Phi
        """
        M = Y.shape[0]
        R = (R + np.transpose(np.conj(R))) / 2  # enforce Hermitian symmetry
        quad = np.einsum('mt,mn,nt->t', Y.conj(), invR, Y)
        if self._openAssert:
            # The quadratic form of a Hermitian matrix must be (numerically) real
            assert np.all(np.abs(quad.imag) <= 1e-6 * (1 + np.abs(quad.real)))

        phi = np.maximum(np.real(Phi), self._PHI_FLOOR)
        # slogdet avoids the under/overflow of det() followed by log()
        _, logDet = np.linalg.slogdet(R)
        # The exponent is evaluated explicitly.  It equals M only when Phi was
        # computed from this very invR, which no longer holds once R has been
        # re-estimated in the M step.
        exponent = np.maximum(quad.real, 0.0) / phi
        return -M * np.log(phi * np.pi) - logDet - exponent

    def _calPosterior(self, Y, Phi, R, invR, alpha):
        """
        E step: posterior probability of each cluster for each frame.

        Return:
            post: (K,T), every column sums to 1
        """
        log_alpha = np.log(alpha)  # (K,)
        log_post = np.stack([
            log_alpha[k] + self._calLogGaussianProb(Y, Phi[k, :], R[k, ...], invR[k, ...])
            for k in range(self._K)
        ])
        # Normalise in the log domain so that exp() can neither overflow nor
        # underflow to an all-zero column
        log_post = log_post - np.max(log_post, axis=0, keepdims=True)
        post = np.exp(log_post)
        post = post / np.sum(post, axis=0, keepdims=True)
        if self._openAssert:
            assert np.allclose(np.sum(post, axis=0), 1)
        return post

    def run(self, itr_num=10):
        """
        Maximum Likelihood (ML) with EM algorithm
            itr_num: iteration number
        Return:
            post: (K,T), posterior probabilities for T observations (K-dim)
        """
        T, Y = self._T, self._Y
        R, invR, Phi, alpha = self._R, self._invR, self._Phi, self._alpha

        for _ in range(itr_num):
            # ===== E Step
            post = self._calPosterior(Y, Phi, R, invR, alpha)  # (K,T)
            post_sum = np.sum(post, axis=1)                    # (K,)

            # ===== M Step
            # Update Phi (uses the covariance of the previous iteration)
            Phi = self._calPhi(Y, invR)  # (K,T)
            # Update R
            R = self._calWeightedCov(Y, post / np.real(Phi))  # (K,M,M)
            R = R / post_sum[:, None, None]
            invR = np.linalg.inv(R)
            # Update alpha. It is not updated in paper. Can comment below line.
            alpha = post_sum / T

        # Compute post after all iterations
        post = self._calPosterior(Y, Phi, R, invR, alpha)
        self._R, self._invR, self._Phi, self._alpha, self._posterior = R, invR, Phi, alpha, post
        return post


class PriorCGMM(CGMM):
    """Online CGMM with spatial prior: maximum-a-posteriori (MAP) estimation with EM.

    The constructor fits an offline CGMM on the initialisation data; every
    later call to run() processes one new chunk, using the spatial
    covariances of the previous chunk as prior.
    """

    def __init__(self, Y, K=2, openAssert=False):
        """
        Arguments:
            Y: (M, T) initialisation data
            K: number of clusters
            openAssert: enable extra internal consistency assertions
        """
        CGMM.__init__(self, Y, K, openAssert)
        CGMM.run(self, itr_num=3)
        # Init hyper-parameters
        # See https://en.wikipedia.org/wiki/Conjugate_prior for conjugate prior
        self._Eta = self._T  # Controls the ratio of previous v.s. new data
        # We use lambda (see definition in paper) instead of the usual
        # hyper-parameters of the inverse-Wishart distribution
        self._Lambda = np.sum(self._posterior, axis=1)  # (K,)
        assert len(self._Lambda) == K

    def run(self, Y, itr_num=3):
        """
        Maximum A Posteriori (MAP) with EM algorithm on one new chunk
            Y: (M, T) new data; T may differ from earlier chunks
            itr_num: iteration number
        Return:
            post: (K,T), posterior probabilities for T observations (K-dim)
        """
        M, T = Y.shape
        assert M == self._M, 'channel number of the new chunk does not match the model'
        self._Y = Y  # set the new data as current Y
        self._T = T  # set the new frame number as current T

        alpha = self._alpha
        # Prior covariance: the estimate of the previous chunk.  It is kept
        # fixed over the EM iterations so that the new chunk is counted once.
        R_prior = self._R
        R, invR = R_prior, self._invR
        # The stored Phi belongs to the previous data (and may have another
        # length), so it is recomputed for the new frames before the first E step
        Phi = self._calPhi(Y, invR)  # (K,T)
        tmpConst = (self._Eta + M + 1) / 2

        for _ in range(itr_num):
            # ===== E Step
            post = self._calPosterior(Y, Phi, R, invR, alpha)  # (K,T)
            post_sum = np.sum(post, axis=1)                    # (K,)

            # ===== M Step (alpha is not updated in the online mode)
            # Update Phi
            Phi = self._calPhi(Y, invR)  # (K,T)
            # Update R, MAP update
            lambda_next = self._Lambda + post_sum  # (K,)
            numerator = self._Lambda + tmpConst    # (K,)
            denominator = lambda_next + tmpConst   # (K,)
            newCov = self._calWeightedCov(Y, post / np.real(Phi))  # (K,M,M)
            priorInfo = (numerator / denominator)[:, None, None] * R_prior
            newInfo = newCov / denominator[:, None, None]
            R = priorInfo + newInfo
            invR = np.linalg.inv(R)

        # Compute post after all iterations
        post = self._calPosterior(Y, Phi, R, invR, alpha)
        post_sum = np.sum(post, axis=1)  # (K,)
        self._R, self._invR, self._Phi, self._alpha, self._posterior = R, invR, Phi, alpha, post

        # Update hyper-parameters
        self._Eta = self._Eta + T
        self._Lambda = self._Lambda + post_sum  # (K,)

        return post
'''
Pseudo-code of using online CGMM with MVDR
'''
import numpy as np
from cgmm import PriorCGMM
import matplotlib.pyplot as plt
# import ...

# ===== Read file, do stft
...
stft_mat = ...
stft_mat_for_init = ...
chunk_size = ...  # number of frames handed to the model per online update
# M: channel number
# K: CGMM cluster number (usually is 2)
# T: frame number
M, valid_n_fft, T = stft_mat.shape

# ===== Online CGMM MVDR
# Use stft_mat_for_init to initialize PriorCGMM (one model per frequency bin)
cgmmEngine = [PriorCGMM(stft_mat_for_init[:, i, :]) for i in range(valid_n_fft)]

# For each chunk, do MAP estimation to simulate online update.
# Frames beyond the last full chunk (T % chunk_size) are not processed.
chunk_num = T // chunk_size
mask_chunks = []  # per-chunk posteriors, each (valid_n_fft, K, chunk_size)
stft_chunks = []  # per-chunk beamformer outputs, each (valid_n_fft, chunk_size)
for c in range(chunk_num):
    # ==== Online CGMM
    offset = chunk_size * c
    for i, engine in enumerate(cgmmEngine):
        engine.run(stft_mat[:, i, offset:offset + chunk_size])
    # Get the spatial covariance matrix
    R = np.array([engine.getR() for engine in cgmmEngine])  # (valid_n_fft, K, M, M)
    # Get the posterior results
    postArray = np.array([engine.getPost() for engine in cgmmEngine])  # (valid_n_fft, K, chunk_size)
    mask_chunks.append(postArray)
    # === MVDR
    Rv, Rx = R[:, 0, :, :], R[:, 1, :, :]
    # Do MVDR by using Rv and Rx
    stft_out_online = ...  # (valid_n_fft, chunk_size)
    stft_chunks.append(stft_out_online)

# Join the chunks once, along the frame axis.  Collecting first and
# concatenating afterwards also covers the first chunk (c == 0).
mask_results = np.concatenate(mask_chunks, axis=2)  # (valid_n_fft, K, chunk_num*chunk_size)
stft_out = np.concatenate(stft_chunks, axis=1)      # (valid_n_fft, chunk_num*chunk_size)

# OLA back to wav form
wav_out = ...

# ========== Plotting Area
plt.imshow(mask_results[:, 1, :])  # plot the cluster index 1, as it represents speech cluster
plt.title('speech mask')
plt.show()