├── .gitignore ├── LICENSE ├── README.md ├── examples ├── calibration_demo.py ├── demo.py └── test_mix_psda.py ├── psda ├── __init__.py ├── besseli.py ├── mix_psda.py ├── psdamodel.py ├── vmf.py └── vmf_sampler.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Niko Brummer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSDA 2 | ## Probabilistic _Spherical_ Discriminant Analysis 3 | 4 | **Update**: We have a new, more powerful generalization, **T-PSDA**. The new code repo is [here](https://github.com/bsxfan/Toroidal-PSDA), and a paper describing the new model is available [here](https://arxiv.org/abs/2210.15441). 5 | 6 | This is a Python implementation of the algorithms described in our Interspeech 2022 paper: 7 | > [Probabilistic Spherical Discriminant Analysis: An Alternative to PLDA for length-normalized embeddings](https://arxiv.org/abs/2203.14893) 8 | 9 | - Please cite this paper if you find our code useful. 10 | 11 | Probabilistic _Linear_ Discriminant Analysis (PLDA) is a trainable scoring backend that can be used for tasks such as speaker/face recognition, clustering, or speaker diarization. PLDA uses the self-conjugacy of multivariate Gaussians to obtain closed-form scoring and closed-form EM updates for learning. Some of the Gaussian assumptions of the PLDA model are violated when embeddings are length-normalized. 12 | 13 | With PSDA, we use [Von Mises-Fisher](https://en.wikipedia.org/wiki/Von_Mises%E2%80%93Fisher_distribution) (VMF) distributions instead of Gaussians, because they may give a better model for this kind of data. The VMF is also self-conjugate, so we enjoy the same benefits of closed-form scoring and EM-learning. 14 | 15 | ## Installation 16 | For now everything is implemented in numpy and scipy. 
(The EM algorithm has closed-form updates, so we don't need automatic derivatives for now). The demo code uses our [PYLLR](https://github.com/bsxfan/PYLLR) toolkit for evaluation of the accuracy and calibration. 17 | 18 | We will neaten the installation procedure later. For now, install PYLLR and then just put the directory of this toolkit in your python path. Then run demo.py to see that it works and look at the demo code to figure out how to use the toolkit for training and scoring. 19 | -------------------------------------------------------------------------------- /examples/calibration_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import randn, randint 3 | import matplotlib.pyplot as plt 4 | 5 | from psda.psdamodel import VMF, PSDA, decompose, atleast2 6 | from pyllr import quick_eval 7 | from psda.vmf_sampler import sample_uniform 8 | 9 | dim = 100 10 | w0, uniform = 100, VMF.uniform(dim) # within, between 11 | # mu = sample_uniform(dim) 12 | # w0, uniform = 100, VMF(mu,1.0) # within, between 13 | 14 | model0 = PSDA(w0, uniform) 15 | 16 | 17 | # generate some test data 18 | ns, nt = 500, 10000 19 | print(f'sampling {ns*2} test speakers') 20 | Z1 = model0.sample_speakers(ns) 21 | Z2 = model0.sample_speakers(ns) 22 | labels1 = randint(ns,size=(nt,)) 23 | labels2 = randint(ns,size=(nt,)) 24 | 25 | print(f'sampling {nt*3} test data') 26 | Enroll = model0.sample(Z1, labels1) # enrollment embeddings 27 | Test1 = model0.sample(Z1, labels1) # target test embeddings 28 | Test2 = model0.sample(Z2, labels2) # nnotar test embeddings 29 | 30 | 31 | nw = 200 32 | cllr = np.empty(nw) 33 | mincllr = np.empty(nw) 34 | print(f'scroring {nw} models') 35 | ww = np.exp(np.linspace(np.log(w0/2),np.log(w0*2),nw)) 36 | for i, w in enumerate(ww): 37 | 38 | model = PSDA(w,uniform) 39 | 40 | # compute PSDA scores 41 | E = model.prep(Enroll) 42 | T1 = model.prep(Test1) 43 | T2 = model.prep(Test2) 44 | 45 | tar = E.llr_vector(T1) 46 | non = E.llr_vector(T2) 47 | 48 | 49 | eer, cllr[i], mincllr[i] = quick_eval.tarnon_2_eer_cllr_mincllr(tar, non) 50 | print(f"{i}: w = {w:.2f}, Cllr = {cllr[i]:.2f}") 51 | 52 | plt.figure() 53 | plt.semilogx(ww,cllr,label='Cllr') 54 | plt.semilogx(ww,mincllr,label='minCllr') 55 | plt.title(f'samples from PSDA(w={w0}, between = uniform)') 56 | plt.xlabel('w for scoring model') 57 | plt.ylabel('Cllr') 58 | plt.grid() 59 | plt.legend() 60 | plt.show() 61 | 62 | -------------------------------------------------------------------------------- /examples/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import randn, randint 3 | import matplotlib.pyplot as plt 4 | 5 | from psda.psdamodel import VMF, PSDA, decompose, atleast2 6 | from pyllr import quick_eval 7 | 8 | dim = 100 9 | b, w = 50, 100 # within, between concentrations 10 | 11 | ns = 500 # number of training speakers 12 | n = 10000 # numer of training examples 13 | 14 | 15 | # set up model to sample from 16 | norm, mu = decompose(randn(dim)) 17 | model0 = PSDA(w, VMF(mu, b)) 18 | 19 | print(f'sampling {ns} training speakers') 20 | Z = model0.sample_speakers(ns) 21 | labels = randint(ns,size=(n,)) 22 | uu, labels, counts = np.unique(labels, return_inverse=True, return_counts=True) 23 | 24 | # sample training data 25 | print(f'sampling {n} training data') 26 | Xtrain = model0.sample(Z, labels) 27 | 28 | if dim == 2: 29 | plt.figure() 30 | 
plt.scatter(Xtrain[:,0],Xtrain[:,1]) 31 | plt.axis('square') 32 | plt.xlim(-1.2,1.2) 33 | plt.ylim(-1.2,1.2) 34 | plt.grid() 35 | plt.title('Embeddings') 36 | plt.show() 37 | 38 | print('training') 39 | 40 | 41 | # one hot label matrix 42 | L = np.full((n,len(counts)),False) # (n, ns) 43 | L[np.arange(n),labels] = True 44 | 45 | # these are the 1st-order stats required by the em traning 46 | means = (L.T @ Xtrain) / counts.reshape(-1,1) 47 | 48 | # filter out singleton speakers 49 | means, counts = atleast2(means, counts) 50 | 51 | # train the model! 52 | model, obj = PSDA.em(means, counts, niters=10) 53 | 54 | plt.figure() 55 | plt.plot(obj,'-*') 56 | plt.grid() 57 | plt.title('PSDA EM algorithm') 58 | plt.xlabel('iteration') 59 | plt.ylabel('marginal likelihood') 60 | plt.show() 61 | 62 | # generate some test data 63 | nt = 10000 64 | print(f'sampling {ns*2} test speakers') 65 | Z1 = model0.sample_speakers(ns) 66 | Z2 = model0.sample_speakers(ns) 67 | labels1 = randint(ns,size=(nt,)) 68 | labels2 = randint(ns,size=(nt,)) 69 | 70 | print(f'sampling {nt*3} test data') 71 | Enroll = model0.sample(Z1, labels1) # enrollment embeddings 72 | Test1 = model0.sample(Z1, labels1) # target test embeddings 73 | Test2 = model0.sample(Z2, labels2) # nnotar test embeddings 74 | 75 | print('scoring single enroll') 76 | # compute PSDA scores 77 | E = model.prep(Enroll) 78 | T1 = model.prep(Test1) 79 | T2 = model.prep(Test2) 80 | 81 | tar = E.llr_vector(T1) 82 | non = E.llr_vector(T2) 83 | 84 | # compute cosine scores 85 | tarc = (Enroll*Test1).sum(axis=-1) 86 | nonc = (Enroll*Test2).sum(axis=-1) 87 | 88 | 89 | plt.figure() 90 | plt.plot(non,nonc,'.',label='non') 91 | plt.plot(tar,tarc,'.',label='tar') 92 | plt.grid() 93 | plt.xlabel('PSDA score') 94 | plt.ylabel('cosine score') 95 | plt.legend() 96 | plt.show() 97 | 98 | 99 | 100 | # compute double-enroll PSDA scores 101 | print(f'sampling {nt} 2nd enrollments') 102 | Enroll2 = model0.sample(Z1, labels1) # 2nd enrollment embeddings 103 | 104 | print('scoring double enroll') 105 | E2 = model.prep(Enroll + Enroll2) 106 | tar2 = E2.llr_vector(T1) 107 | non2 = E2.llr_vector(T2) 108 | 109 | # compute double-enroll cosine scores 110 | E2c = decompose(Enroll + Enroll2)[1] 111 | tar2c = (E2c*Test1).sum(axis=-1) 112 | non2c = (E2c*Test2).sum(axis=-1) 113 | 114 | 115 | tar12 = np.hstack([tar,tar2]) 116 | non12 = np.hstack([non,non2]) 117 | 118 | tar12c = np.hstack([tarc,tar2c]) 119 | non12c = np.hstack([nonc,non2c]) 120 | 121 | 122 | print('evaluating') 123 | 124 | eer_p, cllr_p, mincllr_p = quick_eval.tarnon_2_eer_cllr_mincllr(tar, non) 125 | eer_p2, cllr_p2, mincllr_p2 = quick_eval.tarnon_2_eer_cllr_mincllr(tar2, non2) 126 | 127 | eer_c, cllr_c, mincllr_c = quick_eval.tarnon_2_eer_cllr_mincllr(tarc, nonc) 128 | eer_c2, cllr_c2, mincllr_c2 = quick_eval.tarnon_2_eer_cllr_mincllr(tar2c, non2c) 129 | 130 | eer_p12, cllr_p12, mincllr_p12 = quick_eval.tarnon_2_eer_cllr_mincllr(tar12, non12) 131 | eer_c12, cllr_c12, mincllr_c12 = quick_eval.tarnon_2_eer_cllr_mincllr(tar12c, non12c) 132 | 133 | 134 | print("\n\nCosine scoring, single enroll:") 135 | print(f" EER: {eer_c*100:.1f}%") 136 | print(f" Cllr: {cllr_c:.3f}") 137 | print(f" minCllr: {mincllr_c:.3f}") 138 | 139 | print("\nPSDA scoring, single enroll:") 140 | print(f" EER: {eer_p*100:.1f}%") 141 | print(f" Cllr: {cllr_p:.3f}") 142 | print(f" minCllr: {mincllr_p:.3f}") 143 | 144 | print("\nCosine scoring, double enroll:") 145 | print(f" EER: {eer_c2*100:.1f}%") 146 | print(f" Cllr: {cllr_c2:.3f}") 147 | print(f" 
minCllr: {mincllr_c2:.3f}") 148 | 149 | print("\nPSDA scoring, double enroll:") 150 | print(f" EER: {eer_p2*100:.1f}%") 151 | print(f" Cllr: {cllr_p2:.3f}") 152 | print(f" minCllr: {mincllr_p2:.3f}") 153 | 154 | print("\nCosine scoring, mixed enroll:") 155 | print(f" EER: {eer_c12*100:.1f}%") 156 | print(f" Cllr: {cllr_c12:.3f}") 157 | print(f" minCllr: {mincllr_c12:.3f}") 158 | 159 | print("\nPSDA scoring, mixed enroll:") 160 | print(f" EER: {eer_p12*100:.1f}%") 161 | print(f" Cllr: {cllr_p12:.3f}") 162 | print(f" minCllr: {mincllr_p12:.3f}") 163 | 164 | 165 | plt.figure() 166 | plt.hist(non,100,label='non') 167 | plt.hist(tar,100,label='tar',alpha=0.5) 168 | plt.title(f'PSDA LLR, EER={eer_p*100:.1f}%') 169 | plt.grid() 170 | plt.legend() 171 | plt.show() 172 | 173 | plt.figure() 174 | plt.hist(nonc,100,label='non') 175 | plt.hist(tarc,100,label='tar',alpha=0.5) 176 | plt.title(f'cosine, EER={eer_c*100:.1f}%') 177 | plt.grid() 178 | plt.legend() 179 | plt.show() 180 | 181 | -------------------------------------------------------------------------------- /examples/test_mix_psda.py: -------------------------------------------------------------------------------- 1 | 2 | import os,sys 3 | 4 | import numpy as np 5 | from numpy.random import randn, randint 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from pyllr import quick_eval 10 | from psda.mix_psda import VMF, MixPSDA, decompose, atleast2 11 | 12 | rng = np.random.default_rng() 13 | 14 | 15 | 16 | def generate_psda(dim,ns,ntrain,ntest): 17 | from psda.psda import PSDA 18 | 19 | b, w = 50, 100 # within, between concentrations 20 | 21 | # set up model to sample from 22 | norm, mu = decompose(randn(dim)) 23 | model0 = PSDA(w, VMF(mu, b)) 24 | print(f"true : 0","B =",b, "W =",w,"mu =",mu.ravel()[:6]) 25 | 26 | print(f'sampling {ns} training speakers') 27 | Z = model0.sample_speakers(ns) 28 | labels = randint(ns,size=(ntrain,)) 29 | uu, labels, counts = np.unique(labels, return_inverse=True, return_counts=True) 30 | 31 | # sample training data 32 | print(f'sampling {ntrain} training data') 33 | Xtrain = model0.sample(Z, labels) 34 | 35 | # generate some test data 36 | #nt = 10000 37 | print(f'sampling {ns*2} test speakers') 38 | Z1 = model0.sample_speakers(ns) 39 | Z2 = model0.sample_speakers(ns) 40 | labels1 = randint(ns,size=(ntest,)) 41 | labels2 = randint(ns,size=(ntest,)) 42 | 43 | print(f'sampling {ntest*3} test data') 44 | Enroll = model0.sample(Z1, labels1) # enrollment embeddings 45 | Test1 = model0.sample(Z1, labels1) # target test embeddings 46 | Test2 = model0.sample(Z2, labels2) # nnotar test embeddings 47 | 48 | return Xtrain, labels, Enroll, Test1, Test2 49 | 50 | 51 | def generate_mix_psda(m, dim,ns,ntrain,ntest): 52 | 53 | p_i = np.ones(m)/m 54 | norm, mu = decompose(randn(m,dim)) 55 | # if m>1: 56 | # mu[1] = -mu[0] 57 | 58 | if m == 1: 59 | b = np.array([50]) 60 | w = np.array([100]) 61 | else: 62 | b = rng.uniform(20,100,size=m) 63 | w = rng.uniform(50,300,size=m) 64 | 65 | print(f"true : 0","B =",b, "W =",w,"mu =",mu.ravel()[:6]) 66 | print() 67 | 68 | model0 = MixPSDA(p_i, w, VMF(mu, b)) 69 | 70 | labels = randint(ns,size=(ntrain,)) 71 | uu, labels, counts = np.unique(labels, return_inverse=True, return_counts=True) 72 | onehot = labels[:,None] == uu 73 | # filter out speakers with < 2 utterances 74 | onehot = onehot[:,onehot.sum(axis=0) > 1] 75 | labels = labels[onehot.sum(axis=1) > 0] 76 | uu, labels, counts = np.unique(labels, return_inverse=True, return_counts=True) 77 | onehot = labels[:,None] == uu 78 | 79 | 
spk2comp = np.random.choice(range(m),size=ns,replace=True) 80 | # component_labels = np.asarray([spk2comp[spk] for spk in labels]) 81 | 82 | print(f'sampling {ntrain} training data') 83 | Z = model0.sample_speakers(spk2comp) 84 | Xtrain = model0.sample(Z, spk2comp, labels) 85 | 86 | 87 | # # generate some test data 88 | print(f'sampling {ns*2} test speakers') 89 | spk2comp1 = np.random.choice(range(m),size=ns,replace=True) 90 | Z1 = model0.sample_speakers(spk2comp1) 91 | spk2comp2 = np.random.choice(range(m),size=ns,replace=True) 92 | Z2 = model0.sample_speakers(spk2comp2) 93 | 94 | labels1 = randint(ns,size=(ntest,)) 95 | labels2 = randint(ns,size=(ntest,)) 96 | Enroll = model0.sample(Z1, spk2comp1, labels1) 97 | Test1 = model0.sample(Z1, spk2comp1, labels1) 98 | Test2 = model0.sample(Z2, spk2comp2, labels2) 99 | 100 | return Xtrain, labels, Enroll, Test1, Test2 101 | 102 | 103 | 104 | 105 | if __name__ == "__main__": 106 | 107 | dim = 2 108 | ns = 500 # number of training speakers 109 | ntrain = 10000 # numer of training examples 110 | ntest = 10000 111 | m = 3 # number of VMF components 112 | 113 | # Xtrain, labels, Enroll, Test1, Test2 = generate_psda(dim,ns,ntrain,ntest) 114 | Xtrain, labels, Enroll, Test1, Test2 = generate_mix_psda(m,dim,ns,ntrain,ntest) 115 | uu, labels, counts = np.unique(labels, return_inverse=True, return_counts=True) 116 | 117 | onehot = labels[:,None] == uu 118 | onehot = onehot[:,onehot.sum(axis=0) > 1] 119 | ii = onehot.sum(axis=1) > 0 120 | onehot = onehot[ii] 121 | Xtrain = Xtrain[ii] 122 | onesoft = onehot/onehot.sum(axis=0) 123 | 124 | labels = onehot.argmax(axis=1) 125 | counts = onehot.sum(axis=0) 126 | uu = np.unique(labels) 127 | 128 | # ================================================================= 129 | 130 | # these are the 1st-order stats required by the em traning 131 | means = onesoft.T @ Xtrain 132 | 133 | p_i = np.ones(m)/m 134 | mm = np.random.randn(m,dim) 135 | bb = np.random.rand(m)*0.1 # np.ones_like(b)*1/0.11 136 | ww = np.random.rand(m)*100+300 137 | model1 = MixPSDA(p_i, ww, VMF(mm, bb)) 138 | #model1 = model0 139 | model, obj = MixPSDA.em(means, counts, niters=20, w0=None, psda_init=model1) 140 | 141 | # ================================================================= 142 | # ================================================================= 143 | 144 | E = model.prep(Enroll) 145 | T1 = model.prep(Test1) 146 | T2 = model.prep(Test2) 147 | 148 | tar = E.llr_vector(T1) 149 | non = E.llr_vector(T2) 150 | 151 | tarc = np.sum(Enroll*Test1,axis=-1) 152 | nonc = np.sum(Enroll*Test2,axis=-1) 153 | 154 | fig,axes = plt.subplots(2,1,figsize=(12,12)) 155 | axes[0].hist(non,bins=100,label='non',alpha=0.85,density=True) 156 | axes[0].hist(tar,bins=100,label='tar',alpha=0.85,density=True) 157 | axes[0].legend() 158 | 159 | eer,cllr,mincllr = quick_eval.tarnon_2_eer_cllr_mincllr(tar,non) 160 | axes[0].set_title(f"EER={eer:1.3%}, CLLR={cllr:1.3f}, minCLLR={mincllr:1.3f}") 161 | 162 | axes[1].hist(nonc,bins=100,label='nonc',alpha=0.85,density=True) 163 | axes[1].hist(tarc,bins=100,label='tarc',alpha=0.85,density=True) 164 | axes[1].legend() 165 | #axes[1].sharex(axes[0]) 166 | eer,cllr,mincllr = quick_eval.tarnon_2_eer_cllr_mincllr(tarc,nonc) 167 | axes[1].set_title(f"EER={eer:1.3%}, CLLR={cllr:1.3f}, minCLLR={mincllr:1.3f}") 168 | 169 | 170 | plt.figure() 171 | plt.plot(non,nonc,'.',label='non') 172 | plt.plot(tar,tarc,'.',label='tar') 173 | plt.xlabel("PSDA") 174 | plt.ylabel("Cosine") 175 | 176 | eer,cllr,mincllr = 
quick_eval.tarnon_2_eer_cllr_mincllr(tar,non) 177 | plt.title(f"EER={eer:1.3%}, CLLR={cllr:1.3f}, minCLLR={mincllr:1.3f}") 178 | 179 | 180 | if 'plot' in sys.argv: 181 | cmap = plt.get_cmap('Spectral') 182 | cc = [cmap(s/ns) for s in labels] 183 | 184 | plt.figure() 185 | 186 | if dim == 2: 187 | x,y = Xtrain.T 188 | plt.scatter(x,y,color=cc,marker='o') 189 | # for spk in np.unique(labels): 190 | # ii = labels==spk 191 | # for k in np.unique(component_labels[ii]): 192 | # jj = np.logical_and(ii,component_labels==k) 193 | # plt.scatter(Xtrain[jj][:,0],Xtrain[jj][:,1],color=cmap(spk/ns), marker="x^+o."[k%4]) 194 | 195 | for m in model.between.mu: 196 | plt.arrow(0,0,*m,color='r') 197 | plt.axis('square') 198 | plt.xlim(-1.2,1.2) 199 | plt.ylim(-1.2,1.2) 200 | plt.grid() 201 | plt.show() 202 | 203 | elif dim==3: 204 | 205 | cc = [cmap(s/ns) for s in labels] 206 | x,y,z = Xtrain.T 207 | 208 | fig = plt.figure() 209 | ax = fig.add_subplot(111, projection='3d') 210 | ax.scatter(x,y,z,color=cc,marker='o') 211 | 212 | for mi in model.between.mu: 213 | ax.quiver(0,0,0,*mi,color='r') 214 | 215 | 216 | 217 | plt.show() 218 | -------------------------------------------------------------------------------- /psda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bsxfan/PSDA/572d2993be59bd4da33bcaec06d9a1eb440ed548/psda/__init__.py -------------------------------------------------------------------------------- /psda/besseli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bessel-I is numerically tricky it can underflow and overflow and is much 3 | slower than typical float functions like log and exp. 4 | 5 | Bessel-I is available in a few forms in scipy.special: 6 | iv(nu,x): I_nu(x) # underflows and overflows for large nu 7 | ive(nu,x): I_nu(x) exp(-x) # underflows, better against overflow, but it 8 | still happens if x is too large 9 | ivp: derivatives for iv 10 | 11 | Bessel-I and even its lagrithm is available in Tensorflow. 12 | 13 | Bessel-I is not available in Pytorch (except for nu = 0 and 1). 14 | 15 | 16 | In this module are some tools to compute log I_nu in regions where iv or ive 17 | would underflow, or overflow. 18 | 19 | We also have some tools to create a fast, bespoke approximation to the 20 | log of the exponentially scaled normalization constant of the Von Mises-Fisher 21 | distribution, with nu = dim/2-1 and concentration kappa: 22 | 23 | log Cvmf_e(nu,kappa) = nu*log(kappa) - log(ive). 24 | 25 | The approximation form is: 26 | 27 | affine --> softplus --> affine 28 | 29 | The affine parameters are tuned for every nu. Tuning invokes scipy.special.ive, 30 | but once tuned, the approximation can run anywhere without scipy. So it can be 31 | used for example in Pytorch (with backprop) on the GPU. No Pytorch tools are 32 | included here, but sending the tuned approximation for use on any platform that 33 | has a softplus available is trivial. 34 | 35 | We found the approximation to be an order of magnitude faster than the (patched) 36 | scipy.special.ive function. 
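
A rough usage sketch of that approximation, using only names defined further
down in this module (the value of nu is a placeholder):

    logI = LogBesselI(nu)             # nu = dim/2 - 1
    fastC = fastLogCvmf_e(logI)       # tunes (a,b,c,d) once, needs scipy
    y = fastC(logk=3.0)               # approximate log Cvmf_e, scipy-free
    a, b, c, d = fastC.params         # affine -> softplus -> affine constants
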
37 | 38 | 39 | """ 40 | 41 | import numpy as np 42 | 43 | from scipy.special import gammaln, logsumexp, ive 44 | from scipy.optimize import minimize_scalar 45 | 46 | logfactorial = lambda x: gammaln(x+1) 47 | log2 = np.log(2) 48 | log2pi = np.log(2*np.pi) 49 | 50 | 51 | 52 | def k_and_logk(k = None, logk = None, compute_k = True, compute_logk = True): 53 | """ 54 | Convenience method used by all functions that have inputs in the style: 55 | 56 | k and/or logk 57 | 58 | """ 59 | assert k is not None or logk is not None, "at least one of k or logk is required" 60 | if compute_k and k is None: 61 | k = np.exp(logk) 62 | if compute_logk and logk is None: 63 | with np.errstate(divide='ignore'): 64 | logk = np.log(k) 65 | return k, logk 66 | 67 | 68 | 69 | 70 | def log_ive_raw(nu, k = None, logk = None): 71 | """ 72 | This wrapper returns: 73 | 74 | np.log(ive(nu,k)) 75 | 76 | scipy.special.ive underflows for k too small relative to nu. This cannot 77 | be fixed in ive, without changing to a logarithmic function return value. 78 | If ive underflows (returns 0), then the log throws a warning and this 79 | raw wrapper function returns -inf. 80 | 81 | If both nu and k are too large, NaN is returned quietly, 82 | e.g. ive(255,np.exp(21)). I believe this a bug. The ive function values 83 | for even larger inputs do still have floating point representations. 84 | 85 | This underflow and NaN behaviour is 'patched up' in the class LogBesselI 86 | and its methods, which provide logrithmic input and output interfaces where 87 | needed. 88 | 89 | 90 | inputs: k and/or logk. 91 | Only k is used, but it will be computed from logk if not given. 92 | 93 | 94 | """ 95 | k, logk = k_and_logk(k, logk, True, False) 96 | 97 | return np.log(ive(nu,k)) 98 | 99 | 100 | 101 | class LogBesselI: 102 | """ 103 | Callable to implement log I_nu(k), for nu >=0 and k >= 0. 104 | 105 | Unlike scipy.special.ive, this callable will not underflow if k is small 106 | relative to nu and it catches NaNs and recomputes the value using 2 terms 107 | of a series expansion for large arguments. 108 | 109 | The degree, nu is stored in the callable, while x is supplied to the call, 110 | e.g.: 111 | 112 | logI = LogBesselI(nu) 113 | y1 = logI(k1) 114 | y2 = logI(k2) 115 | 116 | I_nu(k) >= 0, so log is -inf or real 117 | I_nu(k) (and its log) is monotonic rising 118 | 119 | log I_0(0) = 0, 120 | log I_nu(0) = -inf (no warning), for nu > 0 121 | 122 | For large k, I_nu(k) --> exp(k) / sqrt(2 pi k) 123 | 124 | 125 | """ 126 | 127 | 128 | def __init__(self, nu, n=5): 129 | assert nu >= 0, 'Bessel-I is defined for nu < 0, but this code excludes that.' 130 | self.nu = nu 131 | self.n = n 132 | m = np.arange(n) 133 | self.exponent = (2*m+nu).reshape(-1,1) 134 | self.den = (logfactorial(m) + gammaln(m+1+nu)).reshape(-1,1) 135 | self.at0 = 0.0 if nu==0 else -np.inf 136 | 137 | 138 | 139 | def small_log_iv(self, k=None, logk=None): 140 | """ 141 | Short series expansion for: 142 | 143 | log iv(nu,k)) = log Inu(k) 144 | 145 | for smallish k > 0. At a fixed number of terms, accuracy depends on nu. 146 | We use this series expansion only if ive underflows, effecting an 147 | automatic decision when to invoke this expansion. We found log(ive) 148 | to be accurate up to the point (going to smaller x) where underflow 149 | still does not happen. 150 | 151 | inputs: k and/or logk. 152 | Only logk is used, but it will be computed from k if not given. 
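
        For reference, the series summed here (in the log domain, via
        logsumexp over the first n terms) is the standard Bessel-I expansion:

            I_nu(k) = sum_{m >= 0} (k/2)^(2m+nu) / ( m! Gamma(m+1+nu) )
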
153 | 154 | 155 | 156 | """ 157 | k, logk = k_and_logk(k, logk, False, True) 158 | num = self.exponent * (logk-log2) 159 | return logsumexp(num-self.den,axis=0) 160 | 161 | 162 | def large_log_ive(self, k = None, logk = None, asymptote = True): 163 | """ 164 | Evaluates linear asymptote for log ive(nu,k) for large k. 165 | 166 | log ive(nu,k)) = log Inu(k) - k --> (log2pi - logk) / 2 167 | 168 | If input flag asymptote = False, the results is refined using also the 169 | next term of a series expansion for large arguments. 170 | 171 | Example: 172 | 173 | nu = 255 174 | logI = LogBesselI(nu) 175 | for logk in (20,21): 176 | raw = log_ive_raw(nu, np.exp(logk)) 177 | s1 = logI.large_log_ive(logk,asymptote=True) 178 | s2 = logI.large_log_ive(logk,asymptote=False) 179 | print(f"logk={logk}: {raw:.5f}, {s2:.5f}, {s1:.5f}") 180 | 181 | > logk=20: -10.91901, -10.91901, -10.91894 182 | > logk=21: nan, -11.41896, -11.41894 183 | 184 | 185 | We use this call to patch up log(ive) in cases where ive returns NaN. 186 | (We assume this happens only for large k. If this is not the case, 187 | the log1p below can also NaN if k is too small relative to nu.) 188 | 189 | 190 | inputs: k and/or logk. 191 | For asymptote=True, only logk is used. 192 | For asymptote=False, both are used. 193 | 194 | 195 | 196 | """ 197 | nu = self.nu 198 | k, logk = k_and_logk(k, logk, not asymptote, True) 199 | lin_asymptote = - (log2pi + logk)/2 200 | if asymptote: 201 | return lin_asymptote 202 | return np.log1p(-(4*nu**2-1)/(8*k)) + lin_asymptote 203 | 204 | 205 | 206 | 207 | def __call__(self, k=None, logk=None, exp_scale = False): 208 | """ 209 | Evaluates log I(nu, k), so that it also works for small and large 210 | values of k. 211 | 212 | - k = 0 is valid 213 | - scipy.special.ive is used, unless it underflows or returns NaN 214 | - ive underflow happens when k is small relative to nu, this is 215 | fixed by a logsumexp low-order series epansion 216 | - ive NaN is probably a bug. It happens for large nu and k. It is 217 | fixed with a different (small) series expansion. 218 | 219 | 220 | inputs: 221 | 222 | k and/ or logk: scalars or vectors 223 | 224 | Both k and logk are used. 225 | 226 | The k-values should be non-negative: 227 | if k==0 and nu > 0 -inf is returned quietly 228 | if k==0 and nu = 0, 0 is returned 229 | 230 | 231 | exp_scale: flag default=False: log Bessel-I is returned. 232 | If true: log ive is returned instead. 
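
        A quick sketch (nu and the k-values are placeholders):

            logI = LogBesselI(nu)
            y = logI(k=np.array([1e-3, 1.0, 1e8]))  # under/overflow regions
                                                    # are handled automatically
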
233 | 234 | 235 | returns: scalar or vector: 236 | 237 | log I(nu,k), if not exp_scale 238 | log I(nu,k) - k, if exp_scale 239 | 240 | """ 241 | k, logk = k_and_logk(k, logk, True, False) 242 | if np.isscalar(k): 243 | k = np.array([k]) 244 | if logk is not None: 245 | logk = np.array([logk]) 246 | return self.__call__(k,logk,exp_scale)[0] 247 | 248 | assert np.all(k >= 0) 249 | 250 | # try ive for all k 251 | y = ive(self.nu,k) 252 | 253 | # apply logs when k=0, or y > 0 and not NaN 254 | ok = np.logical_or(k==0, y > 0) # ive gives correct answer (0 or 1) for x==0 255 | with np.errstate(divide='ignore'): # y may be 0 if x ==0 256 | y[ok] = np.log(y[ok]) 257 | if not exp_scale: y[ok] += k[ok] # undo scaling done by ive 258 | 259 | # patch overflow 260 | nan = np.isnan(y) # we assume this signals overflow 261 | knan = k[nan] 262 | log_knan = np.log(knan) if logk is None else logk[nan] 263 | y[nan] = self.large_log_ive(knan, log_knan, asymptote=False) 264 | if not exp_scale: y[nan] += knan # undo scaling done by ive 265 | 266 | # patch underflow 267 | not_ok = np.logical_not(ok) 268 | not_nan = np.logical_not(nan) 269 | uf = np.logical_and(not_ok, not_nan) 270 | logk_uf = np.log(k[uf]) if logk is None else logk[uf] 271 | y[uf] = self.small_log_iv(logk=logk_uf) 272 | if exp_scale: y[uf] -= k[uf] 273 | 274 | return y 275 | 276 | 277 | 278 | def log_iv(self, k=None, logk = None): 279 | """ 280 | Returns log I(nu, k). This is the same as __call__ with the default 281 | flag. 282 | 283 | See __call__ for more details. 284 | 285 | inputs: 286 | 287 | k and/ or logk: scalars or vectors 288 | 289 | Both k and lox are used. 290 | 291 | 292 | returns: scalar or vector log I(nu,k) 293 | 294 | """ 295 | return self(k, logk, exp_scale = False) 296 | 297 | 298 | 299 | def log_ive(self, k=None, logk=None): 300 | """ 301 | Returns exponentially scaled log Bessel-I: 302 | 303 | log [I(nu, k) exp(-k)] = log I(nu,k) - k. 304 | 305 | This inokes __call__ with exp_scaling=True. See __call__ for more 306 | details. 307 | 308 | inputs: 309 | 310 | k and/ or logk: scalars or vectors 311 | 312 | Both k and logk are used. Supply both if you have them. 313 | 314 | 315 | returns: scalar or vector log I(nu,k) - k 316 | 317 | """ 318 | 319 | return self(k, logk, exp_scale = True) 320 | 321 | 322 | def logCvmf(self, k = None, logk = None, exp_scale = False): 323 | """ 324 | log normalization constant (numerator) for Von-Mises-Fisher 325 | distribution, with nu = dim/2-1 326 | 327 | 328 | log Cvmf(k) = log [ k^nu / I_nu(k) ] 329 | 330 | 331 | VMF(x | mu, k) \propto Cvmf(k) exp[kappa*mu'x] 332 | 333 | 334 | 335 | input: k and/or logk, where k >= 0 is the concentration 336 | 337 | Both k and logk are used. Supply both if you have both. 338 | 339 | returns: function value(s) 340 | 341 | The output has the same shape as the input. 342 | 343 | 344 | Notes: 345 | 346 | Cvmf omits a factor that is dependent only on the dimension. 347 | 348 | The limit at k=0 (or logk=-np.inf) is handled in this call, 349 | but only works for a scalar input. 350 | 351 | If you need the derivative, see LogBesselIPair.logCvmf(). 
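
        Example (a sketch with placeholder values):

            logI = LogBesselI(nu)          # nu = dim/2 - 1
            y = logI.logCvmf(k=10.0)       # log [ k^nu / I_nu(k) ]
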
352 | 353 | 354 | """ 355 | nu = self.nu 356 | k, logk = k_and_logk(k, logk) 357 | if np.isscalar(k) and k == 0: 358 | return nu*log2 + gammaln(nu+1) # irrespective of exp_scaling 359 | logI = self(k, logk, exp_scale) 360 | y = nu*logk - logI 361 | return y 362 | 363 | 364 | 365 | def logCvmf_e(self, k=None, logk=None): 366 | """ 367 | log normalization constant (numerator) for Von Mises-Fisher 368 | distribution, with nu = dim/2-1 369 | 370 | 371 | log Cvmf_e(k) = log [ k^nu / (I_nu(k) exp(-k)) ] 372 | = nu*log(k) + k - log I_nu(k) 373 | 374 | VMF(x | mu, k) \propto Cvmf_e(k) exp[k*(mu'x-1)] 375 | 376 | 377 | 378 | input: k, and/or logk, where k >= 0 is the concentration 379 | 380 | Both k and logk are used. Supply both if you have both. 381 | 382 | returns: function value(s) 383 | 384 | The output has the same shape as the input. 385 | """ 386 | 387 | return self.logCvmf(k, logk, exp_scale=True) 388 | 389 | 390 | 391 | 392 | 393 | class LogBesselIPair: 394 | """ 395 | This is a callable that computes log I_nu and log I_{nu+1} and their 396 | derivatives in a single call. 397 | 398 | The degree nu is fixed within the object. 399 | 400 | To compute rho = I_{n+1} / I_nu, you need both I's and then the derivatives 401 | are almost free: 402 | 403 | d/dk I(nu,k) = I(nu-1, k) - (nu/x)I(nu,k) 404 | = (nu/k)I(nu,k) + I(nu+1,k) 405 | = (I(nu-1,k) + I(nu+1,k)) / 2 406 | 407 | so 408 | 409 | d/dk I(nu+1,k) = I(nu, z) - ((nu+1)/k) I(nu+1,k) 410 | 411 | 412 | What is rho? For a Von Mises-Fisher distribution, with concentration k, 413 | 0 <= rho(k) < 1 gives the norm of the expected value. 414 | 415 | 416 | """ 417 | def __init__(self, nu): 418 | self.nu = nu 419 | self.logI = LogBesselI(nu) 420 | self.logI1 = LogBesselI(nu+1) 421 | 422 | def __call__(self, k=None, logk=None): 423 | """ 424 | input: k and/or logk. Both are used, so supply both if you have them. 425 | 426 | returns: an IPair, from which function values and derivatives 427 | can be obtained as properties. 
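
        Example (a sketch mirroring the demo under __main__ below):

            pair = LogBesselIPair(nu)(logk=np.linspace(-6, 14, 200))
            rho, drho = pair.rho, pair.drho_dlogk   # ||E[x]|| of the VMF and its
                                                    # derivative w.r.t. log k
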
428 | 429 | """ 430 | nu = self.nu 431 | k, logk = k_and_logk(k, logk, True, True) 432 | y = self.logI(k,logk) 433 | y1 = self.logI1(k,logk) 434 | return IPair(nu, k, logk, y, y1) 435 | 436 | 437 | 438 | class IPair: 439 | def __init__(self, nu, k, logk, y, y1): 440 | self.nu = nu 441 | self.k = k 442 | self.logk = logk 443 | self.logI = y 444 | self.logI1 = y1 445 | 446 | 447 | @property 448 | def dlogI_dlogk(self): 449 | nu = self.nu 450 | logk, y, y1 = self.logk, self.logI, self.logI1 451 | return nu + np.exp(logk + y1 - y) 452 | 453 | @property 454 | def dlogI1_dlogk(self): 455 | nu = self.nu 456 | logk, y, y1 = self.logk, self.logI, self.logI1 457 | return np.exp(logk + y - y1) - (nu+1) 458 | 459 | @property 460 | def dlogI_dk(self): 461 | return self.dI_dlogk / self.k 462 | 463 | @property 464 | def dlogI1_dk(self): 465 | return self.dI1_dlogk / self.k 466 | 467 | 468 | 469 | @property 470 | def logCvmf(self): 471 | """ 472 | see documentation for BesselI.logCvmf 473 | """ 474 | nu = self.nu 475 | logk, logI = self.logk, self.logI 476 | return nu*logk - logI 477 | 478 | @property 479 | def dlogCvmf_dlogk(self): 480 | return self.nu - self.dlogI_dlogk 481 | 482 | @property 483 | def dlogCvmf_dk(self): 484 | return self.dlogCvmf_dlogk / self.k 485 | 486 | 487 | 488 | @property 489 | def log_rho(self): 490 | """ 491 | rho(k) = I_nu+1(k) / I_nu(k) 492 | """ 493 | k = self.k 494 | if np.isscalar(k) and k==0: 495 | return -np.inf 496 | return self.logI1 - self.logI 497 | 498 | @property 499 | def dlog_rho_dlogk(self): 500 | """ 501 | rho(x) = I_nu+1(x) / I_nu(x) 502 | """ 503 | return self.dlogI1_dlogk - self.dlogI_dlogk 504 | 505 | @property 506 | def dlog_rho_dk(self): 507 | return self.dlog_rho_dlogk / self.k 508 | 509 | 510 | @property 511 | def rho(self): 512 | """ 513 | rho(x) = I_nu+1(x) / I_nu(x) 514 | """ 515 | return np.exp(self.log_rho) 516 | 517 | 518 | 519 | @property 520 | def drho_dlogk(self): 521 | """ 522 | rho(x) = I_nu+1(x) / I_nu(x) 523 | """ 524 | return self.rho * self.dlog_rho_dlogk 525 | 526 | @property 527 | def drho_dk(self): 528 | """ 529 | rho(k) = I_nu+1(k) / I_nu(k) 530 | """ 531 | return self.rho * self.dlog_rho_dk 532 | 533 | 534 | 535 | 536 | def softplus(x): 537 | """ 538 | This is just one way to define a numericallly stable softplus. 539 | It implements log(1+exp(x)), but will not overflow for large x. 540 | """ 541 | return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0) 542 | 543 | 544 | 545 | def fastLogCvmf_e(logI: LogBesselI, 546 | d=np.pi, tune = True, quiet = True, err = None): 547 | 548 | """ 549 | This works: 550 | very well for large nu, 551 | ok for smaller nu 552 | and worst for nu=0 (VMF dim = 2). 553 | 554 | It tunes a pre-and-post-scaled-and-shifted softplus approximation to the 555 | function: 556 | 557 | log Cvmf_e(log k). 558 | 559 | The approximation is constrained to always obey the left limit and the 560 | right linear asymptote of logCvmf_e. 561 | 562 | 563 | The true functions for small nu are less smooth than our approximations, 564 | espically for nu=0, which has an extra bulge below the softplus elbow. 565 | For large nu, the trur function is very smooth and well-approximated by 566 | the softplus approximation. 
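
    In symbols, the approximation tuned below is

        log Cvmf_e(k)  ~=  a + b * softplus(c + d * log k),

    where a is pinned to the k -> 0 limit, b and c follow from the right
    linear asymptote once d is chosen, and only d is tuned numerically.
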
567 | 568 | 569 | inputs: 570 | 571 | logI: LogBesselI: this contains nu = dim/2-1 572 | 573 | d>0: (default = pi) the tunable softplus input scaling factor 574 | 575 | tune: boolean (default True) 576 | 577 | quiet: boolean (default True), print tuning progress if False 578 | 579 | err: an optional tuning objective 580 | 581 | 582 | returns: 583 | 584 | f: a function handle for the fast approximation, with: 585 | 586 | f.nu 587 | f.params = (a,b,c,d), the scaling and shifting constants 588 | f.slow, function handle for the slower reference implementation 589 | 590 | """ 591 | 592 | nu = logI.nu 593 | slow = logI.logCvmf_e 594 | 595 | left = nu*log2 + gammaln(nu+1) # left flat asymptote 596 | right_offs = log2pi/2 # offset for right linear asymptote 597 | right_slope = nu + 0.5 # slope for right linear asymptote 598 | 599 | a = left 600 | def bc(d): 601 | b = right_slope / d 602 | c = (right_offs - a) / b 603 | return b, c 604 | b, c = bc(d) 605 | #print(f'nu={nu}: b={b}, c = {c}, d={d}') 606 | 607 | 608 | approx = lambda logk: a + b*softplus(c + d*logk) 609 | def f(k=None, logk=None): 610 | k, logk = k_and_logk(k, logk, False, True) 611 | return approx(logk) 612 | 613 | f.nu = nu 614 | f.slow = slow 615 | if not tune: 616 | f.params = (a,b,c,d) 617 | return f 618 | 619 | target = lambda logk: slow(logk=logk) 620 | 621 | if err is None: 622 | err = lambda logk,y: (y-target(logk))**2 623 | neg_err = lambda logk: -err(logk,approx(logk)) 624 | logk = minimize_scalar(neg_err,(2.0,4.0)).x 625 | 626 | if not quiet: 627 | print(f'\ntuning softplus for nu={nu}') 628 | print(f' max abs error of {np.sqrt(-neg_err(logk))} at {logk}') 629 | 630 | def obj(d): 631 | b, c = bc(d) 632 | flogk = a + b*softplus(c + d*logk) 633 | return err(logk,flogk) 634 | d = minimize_scalar(obj,(d*0.9,d*1.1)).x 635 | if not quiet: print(f' new d={d}, local error = {obj(d)}') 636 | b, c = bc(d) 637 | 638 | # this happens anyway, f already sees the new b,c,d 639 | # f = lambda x: a + b*softplus(c + d*x) 640 | 641 | logk = minimize_scalar(neg_err,(2.0,4.0)).x 642 | if not quiet: print(f' new max abs error of {np.sqrt(-neg_err(logk))} at {logk}') 643 | 644 | f.params = (a,b,c,d) 645 | 646 | return f 647 | 648 | 649 | 650 | def fast_logrho(logI: LogBesselI, fastLogCe = None, quiet = True): 651 | """ 652 | This works ok for nu=0 and well for nu>0 653 | 654 | It tunes two softplus log Cvmf_e approximations and uses their difference 655 | to approximate log rho. 656 | 657 | 658 | inputs: 659 | 660 | logI: LogBesselI: this contains nu = dim/2-1 661 | 662 | quiet: boolean (default True), print tuning progress if False 663 | 664 | 665 | 666 | returns: 667 | 668 | f: a function handle for the fast approximation, which maps: 669 | 670 | k and/or logk to log rho(k). 671 | 672 | The approximation uses only logk. 673 | 674 | Extra info is returned in attached fields: 675 | f.nu 676 | f.C = function handle for fast logCvmf_e(nu) 677 | f.C1, function handle for fast logCvmf_e(nu+1) 678 | f.slow function handle to reference log rho 679 | 680 | 681 | 682 | Note: An altenative is to tune 683 | 684 | exp(affine->softplus->affine) 685 | 686 | to fit rho. This approximation gives a sigmoid. The true 687 | rho(log_kappa) is close to a sigmoid, but is somewhat less smooth, 688 | especially for small nu. 
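
    Example (a sketch, as in the __main__ demo at the bottom of this file):

        fast = fast_logrho(LogBesselI(nu))
        log_rho = fast(logk=3.0)       # approx log ||E[x]|| for concentration e^3
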
689 | 690 | 691 | 692 | """ 693 | nu = logI.nu 694 | logI1 = LogBesselI(nu+1) 695 | Cslow = logI.logCvmf_e 696 | Cslow1 = logI1.logCvmf_e 697 | 698 | C = fastLogCe or fastLogCvmf_e(logI, tune=nu>0, quiet=quiet) 699 | 700 | def teacher(logk): 701 | return logk + Cslow(logk=logk) - Cslow1(logk=logk) 702 | 703 | 704 | def student1(logk,c1): 705 | return logk + C(logk=logk) - c1 706 | 707 | 708 | def err1(logk,c1): 709 | return (np.exp(teacher(logk))-np.exp(student1(logk,c1)))**2 710 | 711 | 712 | C1 = fastLogCvmf_e(logI1, quiet=quiet, err=err1) 713 | 714 | 715 | # fast log rho 716 | def fast(k=None, logk=None): 717 | k, logk = k_and_logk(k, logk, False, True) 718 | return logk + C(logk=logk) - C1(logk=logk) 719 | 720 | # slow log rho 721 | def slow(k=None, logk=None): 722 | k, logk = k_and_logk(k, logk, False, True) 723 | return teacher(logk) 724 | 725 | fast.slow = slow 726 | fast.nu = nu 727 | fast.C = C 728 | fast.C1 = C1 729 | 730 | return fast 731 | 732 | 733 | 734 | 735 | if __name__ == "__main__": 736 | 737 | import matplotlib.pyplot as plt 738 | 739 | dim = 256 740 | nu = dim/2-1 741 | 742 | logk = np.linspace(-5,5,200) 743 | k = np.exp(logk) 744 | 745 | logBesselI = LogBesselI(nu,5) 746 | 747 | small = logBesselI.small_log_iv(logk=logk) 748 | 749 | 750 | with np.errstate(divide='ignore'): 751 | ref = log_ive_raw(nu, k) 752 | 753 | 754 | plt.figure() 755 | plt.semilogx(k,small,'g',label='small') 756 | plt.semilogx(k,ref,'r--',label='ref') 757 | plt.semilogx(k,ref-small,label='err') 758 | plt.legend() 759 | plt.grid() 760 | plt.show() 761 | 762 | 763 | 764 | 765 | logk = np.linspace(-6,14,200) 766 | nu = 127 767 | pair = LogBesselIPair(nu)(logk=logk) 768 | rho, drho_dlogk = pair.rho, pair.drho_dlogk 769 | plt.figure() 770 | plt.plot(logk,rho,'r',label='rho') 771 | plt.plot(logk,drho_dlogk,label='dy/dlogk') 772 | plt.grid() 773 | plt.xlabel('log k') 774 | plt.ylabel('rho') 775 | plt.title(f'nu = {nu}') 776 | 777 | fastlogrho = fast_logrho(LogBesselI(nu)) 778 | y = np.exp(fastlogrho(logk=logk)) 779 | plt.plot(logk,y,'g--',label='rho approx') 780 | 781 | 782 | plt.legend() 783 | plt.show() 784 | 785 | 786 | 787 | logk = np.linspace(-6,20,200) 788 | plt.figure() 789 | for dim in [128, 256, 512]: 790 | nu = dim/2-1 791 | logI = LogBesselI(nu) 792 | y = logI.logCvmf_e(logk=logk) 793 | plt.plot(logk,y,label=f'dim={dim}') 794 | y = (nu+0.5)*logk + log2pi/2 795 | plt.plot(logk,y,'--') 796 | plt.grid() 797 | plt.xlabel('log k') 798 | plt.ylabel('log C_nu(k) + k') 799 | plt.title('asymptotes') 800 | plt.legend() 801 | plt.show() 802 | 803 | 804 | 805 | 806 | 807 | logk = np.linspace(-5,21,200) 808 | x = np.exp(logk) 809 | plt.figure() 810 | #for dim in [100, 110, 120]: 811 | for dim in [128, 256, 512]: 812 | #for dim in [2, 3, 4]: 813 | nu = dim/2-1 814 | fast = fastLogCvmf_e(LogBesselI(nu), tune=nu>0, quiet=False) 815 | target = fast.slow 816 | y = target(x,logk) 817 | plt.plot(logk,y,label=f'dim={dim}') 818 | plt.plot(logk,fast(x,logk),'--') 819 | 820 | 821 | plt.grid() 822 | plt.xlabel('log k') 823 | plt.ylabel('log C_nu(k) + k') 824 | plt.legend() 825 | plt.show() 826 | 827 | 828 | print("\n\n") 829 | nu = 255 830 | logI = LogBesselI(nu) 831 | for logk in (20,21): 832 | k = np.exp(logk) 833 | raw = log_ive_raw(nu, k) 834 | s1 = logI.large_log_ive(k, logk,asymptote=True) 835 | s2 = logI.large_log_ive(k, logk,asymptote=False) 836 | print(f"logk={logk}: {raw:.5f}, {s2:.5f}, {s1:.5f}") 837 | -------------------------------------------------------------------------------- /psda/mix_psda.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | MixPSDA: Probabilistic Spherical Discriminant Analysis 3 | 4 | """ 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | from psda.vmf import VMF, compose, decompose, LogNormConst 9 | 10 | 11 | rng = np.random.default_rng() 12 | 13 | 14 | 15 | class MixPSDA: 16 | """ 17 | Probabilistic Spherical Discriminant Analysis Model 18 | 19 | """ 20 | 21 | def __init__(self, 22 | p_i:float, 23 | within_concentration:float, 24 | between_distr:VMF): 25 | """ 26 | model = MixPSDA(w, VMF(mu, b)) 27 | 28 | w,b > 0 29 | mu (dim, ) is a lengh-normalized speaker mean 30 | 31 | 32 | or 33 | 34 | model = MixPSDA(w, VMF(mean)) 35 | 36 | w,b > 0, 37 | mean (dim,) is speaker mean inside (not on) unit hypersphere 38 | 39 | or 40 | 41 | model = MixPSDA.em(means,counts) # see the documention for em() 42 | 43 | 44 | 45 | """ 46 | self.p_i = p_i 47 | self.w = w = within_concentration 48 | self.between = between = between_distr 49 | self.b = b = between.k 50 | self.dim = between.dim 51 | self.mu = between.mu 52 | self.bmu = between.kmu 53 | self.logC = logC = between.logC 54 | self.logCb = logC(b) 55 | self.logCw = logC(w) 56 | 57 | self.ncomp = len(p_i) 58 | assert self.mu.shape[0] == self.ncomp, "Expecting {self.ncomp} means." 59 | self.tied_w = self.w.size < self.ncomp 60 | self.tied_b = self.b.size < self.ncomp 61 | 62 | def save(self,fname): 63 | import h5py 64 | with h5py.File(fname,'w') as h5: 65 | h5["pi"] = self.p_i 66 | h5["w"] = self.w 67 | self.between.save_to_h5(h5,"between") 68 | 69 | @classmethod 70 | def load(cls,fname): 71 | import h5py 72 | with h5py.File(fname,'r') as h5: 73 | w = np.asarray(h5["w"]) 74 | p_i = np.asarray(h5["pi"]) 75 | w = np.atleast_1d(w) 76 | p_i = np.atleast_1d(p_i) 77 | between = VMF.load_from_h5(h5,"between") 78 | return cls(p_i,w,between) 79 | 80 | # def marg_llh(self, data_sum, count): 81 | # """ 82 | # Computes the marginal log-likelihood log P(X | same speaker), where 83 | # z is integrated out. We use: log P(X | z) P(z) / P(z | X) 84 | 85 | # Returns a vector of independent calculations if data_sum is a matrix 86 | # and count is a vector. 87 | # """ 88 | # post = self.zposterior(data_sum) 89 | # return self.logCw*count + self.logCb - post.logCk 90 | 91 | 92 | def sample_speakers(self, n): 93 | return self.between.sample(n) 94 | 95 | 96 | def sample(self, speakers, component_labels, labels): 97 | ww = self.w if np.isscalar(self.w) else self.w[component_labels] 98 | within = VMF(speakers, ww, self.logC) 99 | return within.sample(labels) 100 | 101 | 102 | def prep(self, X: ndarray, counts: ndarray=None): 103 | """ 104 | Does some precomputation for fast computation of a matrix of LLR 105 | scores. 106 | 107 | X: vector or matrix of observations (in rows) 108 | Each row can represent a single observation, or multiple 109 | observations. For single observation, the row must be on the 110 | unit hypersphere. For multiple observations, the row must be the 111 | sum of observations, which can be anywhere in R^d. 112 | 113 | To be clear, when doing multi-enroll trials, the enrollment 114 | embeddings must be summed, not averaged. Do not length-norm 115 | again after summing. 116 | 117 | 118 | returns: a Side that contains precomputed stuff for fast scoring. 119 | The Side provides a method for scoring against another 120 | Side. 
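
        Example (a sketch along the lines of examples/test_mix_psda.py; the
        model and the embedding matrices are placeholders):

            E = model.prep(Enroll)          # one trial side per row
            T = model.prep(Test)
            llr = E.llr_vector(T)           # or E.llr_matrix(T) for all pairs
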
121 | """ 122 | if counts is None: 123 | counts = np.ones(X.shape[0]) 124 | return Side(self,X.astype(np.float64),counts) 125 | 126 | 127 | 128 | def llr_matrix(self, enroll:ndarray, test:ndarray) -> ndarray: 129 | """ 130 | Convenience method. See MixPSDA.prep() for details. 131 | """ 132 | return self.prep(enroll).llr_matrix(self.prep(test)) 133 | 134 | def llr_vector(self, enroll:ndarray, test:ndarray) -> ndarray: 135 | """ 136 | Convenience method. See MixPSDA.prep() for details. 137 | """ 138 | return self.prep(enroll).llr_vector(self.prep(test)) 139 | 140 | 141 | @classmethod 142 | def em(cls, means: ndarray, counts: ndarray, 143 | niters = 10, w0 = 1.0, quiet = False, 144 | psda_init = None, fname_tmp=None): 145 | """ 146 | Trains a MixPSDA model from data. 147 | 148 | means: (n, dim) the means of each of the n classes available for training 149 | 150 | counts: (n,) the number of examples of each class 151 | 152 | niters: the number of EM iterations to run 153 | 154 | w0>0: (optional) initial guess for within-class concentration 155 | 156 | returns: the trained model as a MixPSDA object 157 | 158 | 159 | """ 160 | assert counts.min() > 1, "all speakers need at least 2 observations" 161 | means = means.astype(np.float64) 162 | ns, dim = means.shape 163 | assert len(counts) == ns 164 | total = counts.sum() 165 | psda = psda_init or cls.em_init(means,w0) 166 | if not quiet: 167 | print(f"em init : 0 B =",psda.b, "W =",psda.w,"mu =",psda.between.mu.ravel()[:6]) 168 | obj = [] 169 | llh0 = 0 170 | for i in range(niters): 171 | psda, llh = psda.em_iter(means,counts) 172 | if fname_tmp is not None: 173 | psda.save(fname_tmp.format(iter=i)) 174 | 175 | impr = llh - llh0; llh0 = llh 176 | if not quiet: 177 | print(f"em iter {i}: {impr}","B =",psda.b, "W =",psda.w,"mu =",psda.between.mu.ravel()[:6]) 178 | obj.append(llh) 179 | return psda, obj 180 | 181 | @classmethod 182 | def em_init(cls, means, w0=None, b0=None,ncomp=None): 183 | """ 184 | Invoked by em 185 | """ 186 | if ncomp is None: 187 | assert not (w0 is None and b0 is None) 188 | ncomp = np.atleast_1d(w0).size if b0 is None else np.atleast_1d(b0).size 189 | 190 | pi0 = np.ones(ncomp)/ncomp 191 | w0 = rng.uniform(100,1000,size=ncomp).astype(float) 192 | 193 | norms, means = decompose(means) 194 | assert all(norms < 1), "Invalid means" 195 | ii = np.random.choice(np.arange(means.shape[0]),size=len(w0),replace=False) 196 | # between1 = VMF.max_likelihood(means) 197 | b0 = np.ones_like(w0) 198 | 199 | between = VMF(means[ii], b0) 200 | return MixPSDA(pi0, w0, between) 201 | 202 | # def zposterior(self, data_sum:ndarray): 203 | def zposterior(self, means:ndarray, counts:ndarray): 204 | """ 205 | Computes the hidden variable posterior, given the sufficient statistics 206 | (sum of the observations). 207 | 208 | If mutiple sums (each sum is a vector) are supplied, multiple posteriors 209 | are computed. 210 | 211 | The posterior(s) one or more are returned in a single VMF object. 212 | """ 213 | 214 | data_sum = compose(counts, means) # s x dim 215 | 216 | w = self.w if np.isscalar(self.w) else self.w[:,None,None] 217 | # (k x 1 x 1)*(k x s x d) + (k x 1 x d) 218 | theta = w*data_sum[None,...] 
+ self.bmu[:,None,:] 219 | post = VMF(theta) 220 | 221 | m,n = counts.size, counts.sum() 222 | # y_exp = post.mean() # m x s x dim 223 | r_tilde = np.log(self.p_i) + self.logCb # m x 1 224 | r_tilde = r_tilde[:,None] + np.atleast_1d(self.logCw)[:,None]*counts # m x s 225 | r_tilde -= post.logCk # m x s 226 | 227 | # normalize over components 228 | log_r = r_tilde - np.logaddexp.reduce(r_tilde,axis=0,keepdims=True) 229 | gamma = np.exp(log_r) # m x s 230 | assert np.all(np.isfinite(gamma)), "Check gammas!" 231 | return post, log_r 232 | 233 | def llh(self, means, counts): 234 | post, log_r = self.zposterior(means,counts) # m x s x dim 235 | m,n = counts.size, counts.sum() 236 | 237 | llh = self.logCw*n + self.logCb*m - post.logCk.sum(axis=-1) 238 | llh += m*np.log(self.p_i) - log_r.sum(axis=1) 239 | assert np.allclose(llh,llh[0]), "Candidate's formula fail!" 240 | return llh[0] 241 | 242 | def em_iter(self, means, counts): 243 | """ 244 | Invoked by em 245 | 246 | m components 247 | s speakers 248 | d dimensional embeddings 249 | 250 | returns: 251 | a new updated MixPSDA 252 | marginal log-likelihood (em objective) 253 | """ 254 | 255 | p_i, w, b, mu = self.p_i, self.w, self.b, self.mu 256 | 257 | y_post, log_r = self.zposterior(means, counts) 258 | y_exp = y_post.mean() # m x s x dim 259 | gamma = np.exp(log_r) # m x s 260 | 261 | y_bar = np.sum(gamma[:,:,None]*y_exp, axis=1) # m x dim (spk s summed out) 262 | y_bar /= gamma.sum(axis=1,keepdims=True) # m x dim 263 | 264 | pi_new = gamma.sum(axis=1)/gamma.sum() # m 265 | 266 | if self.tied_b: 267 | b_new, mu_new = decompose(y_bar) 268 | b_new = np.atleast_1d(b_new) 269 | b_new = self.logC.rhoinv(b_new@pi_new) 270 | else: 271 | # between_new = VMF.max_likelihood(y_bar, logC=self.logC) 272 | b_new, mu_new = decompose(y_bar) 273 | b_new = np.atleast_1d(b_new) 274 | b_new = self.logC.rhoinv(b_new) 275 | between_new = VMF(mu_new, b_new,self.logC) 276 | 277 | if self.tied_w: 278 | warg = np.sum(gamma[:,:,None]*y_exp, axis=0) 279 | warg = np.sum(compose(counts,means)*warg)/counts.sum() 280 | else: 281 | # warg = ((gamma*counts)*(y_exp*means).sum(axis=-1)).sum(axis=-1) 282 | warg = (gamma*(y_exp*means).sum(axis=-1))@counts 283 | warg /= (gamma@counts) 284 | 285 | assert np.all(0 < warg) and np.all(warg < 1) 286 | w_new = self.logC.rhoinv(warg) 287 | 288 | newmod = MixPSDA(pi_new, w_new, between_new) 289 | llh = newmod.llh(means,counts) 290 | return newmod, llh 291 | 292 | 293 | def __repr__(self): 294 | return f"MixPSDA(dim={self.dim}, b={self.b}, w={self.w})" 295 | 296 | def atleast2(means, counts): 297 | ok = counts > 1 298 | return means[ok,:], counts[ok] 299 | 300 | 301 | 302 | class Side: 303 | """ 304 | Represents a trial side, for one or more observations. When two trial sides 305 | are scored against each other, one containing m and the other n observations 306 | an (m,n) llr score matrix is produced. 307 | 308 | """ 309 | 310 | def __init__(self, psda:MixPSDA, X: ndarray, counts: ndarray): 311 | """ 312 | This constructor is invoked by psda.prep(X), see the docs of MixPSDA. 
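
        Roughly, what gets cached below: for each mixture component i, the
        squared norm of the posterior natural parameter, ||w_i x + b_i mu_i||^2,
        and logr1, the log of the prior-weighted, component-summed single-side
        term that llr_matrix/llr_vector subtract off.
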
313 | """ 314 | self.psda = psda 315 | self.X = X 316 | self.counts = counts 317 | 318 | self.yi1norm2 = np.sum(X**2,axis=1,keepdims=True)*psda.w**2 + X@(2*psda.bmu.T*psda.w) + psda.b**2 319 | logr1 = counts[:,None]*psda.logCw + psda.logCb + np.log(psda.p_i) 320 | logr1 -= psda.logC(self.yi1norm2) 321 | self.logr1 = np.logaddexp.reduce(logr1, axis=1) 322 | 323 | 324 | def llr_matrix(self,rhs): 325 | """ 326 | Scores the one or more (m) trial sides contained in self against 327 | all (n) of the trial side(s) in rhs. Returns an (m,n) matrix of 328 | LLR scores. 329 | 330 | """ 331 | 332 | yi3norm2 = self.yi1norm2[:,None,:] + rhs.yi1norm2[None,:,:] - self.psda.b[None,None,:]**2 333 | yi3norm2 += 2*self.psda.w[None,None,:]*np.sum(self.X[:,None,:]*rhs.X[None,:,:],axis=-1,keepdims=True) 334 | 335 | logr3 = (self.counts[:,None] + rhs.counts)[:,:,None]*self.psda.logCw[None,None,:] 336 | logr3 += (self.psda.logCb + np.log(self.psda.p_i))[None,None,:] 337 | logr3 -= self.psda.logC(yi3norm2) 338 | logr3 = np.logaddexp.reduce(logr3, axis=-1) 339 | 340 | return logr3 - self.logr1[:,None] - rhs.logr1 341 | 342 | 343 | def llr_vector(self, rhs): 344 | """ 345 | Scores the n trial sides contained in self against the respective n 346 | sides in the rhs. Returns an (n,) vector of LLR scores. If one of the 347 | sides has a single trial and the other multiple trials, broadcasting 348 | will be done in the usual way. 349 | """ 350 | 351 | yi3norm2 = self.yi1norm2 + rhs.yi1norm2 - self.psda.b**2 352 | yi3norm2 += 2*np.sum(self.X*rhs.X,axis=-1,keepdims=True)*self.psda.w 353 | 354 | logr3 = (self.counts + rhs.counts)[:,None]*self.psda.logCw 355 | logr3 += (self.psda.logCb + np.log(self.psda.p_i)) 356 | logr3 -= self.psda.logC(yi3norm2) 357 | logr3 = np.logaddexp.reduce(logr3, axis=-1) 358 | 359 | return logr3 - self.logr1 - rhs.logr1 360 | -------------------------------------------------------------------------------- /psda/psdamodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | PSDA: Probabilistic Spherical Discriminant Analysis 3 | 4 | """ 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | from psda.vmf import VMF, compose, decompose 9 | 10 | 11 | 12 | class PSDA: 13 | """ 14 | Probabilistic Spherical Discriminant Analysis Model 15 | 16 | """ 17 | 18 | def __init__(self, within_concentration:float, between_distr:VMF): 19 | """ 20 | model = PSDA(w, VMF(mu, b)) 21 | 22 | w,b > 0 23 | mu (dim, ) is a lengh-normalized speaker mean 24 | 25 | or 26 | 27 | model = PSDA(w, VMF(mean)) 28 | 29 | w,b > 0, 30 | mean (dim,) is speaker mean inside (not on) unit hypersphere 31 | 32 | or 33 | 34 | model = PSDA.em(means,counts) # see the documention for em() 35 | 36 | 37 | 38 | """ 39 | self.w = w = within_concentration 40 | self.between = between = between_distr 41 | self.b = b = between.k 42 | self.dim = between.dim 43 | self.mu = between.mu 44 | self.bmu = between.kmu 45 | self.logC = logC = between.logC 46 | self.logCb = logC(b) 47 | self.logCw = logC(w) 48 | 49 | def save(self,fname): 50 | import h5py 51 | with h5py.File(fname,'w') as h5: 52 | h5["w"] = self.w 53 | self.between.save_to_h5(h5,"between") 54 | 55 | @classmethod 56 | def load(cls,fname): 57 | import h5py 58 | with h5py.File(fname,'r') as h5: 59 | w = np.asarray(h5["w"]) 60 | w = np.atleast_1d(w) 61 | between = VMF.load_from_h5(h5,"between") 62 | return cls(w,between) 63 | 64 | def zposterior(self, data_sum:ndarray): 65 | """ 66 | Computes the hidden variable posterior, given the sufficient 
statistics 67 | (sum of the observations). 68 | 69 | If multiple sums (each sum is a vector) are supplied, multiple posteriors 70 | are computed. 71 | 72 | The posteriors (one or more) are returned in a single VMF object. 73 | """ 74 | 75 | w, bmu = self.w, self.bmu 76 | return VMF(w*data_sum + bmu) 77 | 78 | def marg_llh(self, data_sum, count): 79 | """ 80 | Computes the marginal log-likelihood log P(X | same speaker), where 81 | z is integrated out. We use: log P(X | z) P(z) / P(z | X) 82 | 83 | Returns a vector of independent calculations if data_sum is a matrix 84 | and count is a vector. 85 | """ 86 | post = self.zposterior(data_sum) 87 | return self.logCw*count + self.logCb - post.logCk 88 | 89 | 90 | 91 | 92 | def sample_speakers(self, n): 93 | return self.between.sample(n) 94 | 95 | 96 | def sample(self, speakers, labels): 97 | within = VMF(speakers, self.w, self.logC) 98 | return within.sample(labels) 99 | 100 | 101 | 102 | 103 | def prep(self, X: ndarray): 104 | """ 105 | Does some precomputation for fast computation of a matrix of LLR 106 | scores. 107 | 108 | X: vector or matrix of observations (in rows) 109 | Each row can represent a single observation, or multiple 110 | observations. For single observation, the row must be on the 111 | unit hypersphere. For multiple observations, the row must be the 112 | sum of observations, which can be anywhere in R^d. 113 | 114 | To be clear, when doing multi-enroll trials, the enrollment 115 | embeddings must be summed, not averaged. Do not length-norm 116 | again after summing. 117 | 118 | 119 | returns: a Side that contains precomputed stuff for fast scoring. 120 | The Side provides a method for scoring against another 121 | Side. 122 | """ 123 | return Side(self,X.astype(np.float64)) 124 | 125 | 126 | 127 | def llr_matrix(self, enroll:ndarray, test:ndarray) -> ndarray: 128 | """ 129 | Convenience method. See PSDA.prep() for details. 130 | """ 131 | return self.prep(enroll).llr_matrix(self.prep(test)) 132 | 133 | def llr_vector(self, enroll:ndarray, test:ndarray) -> ndarray: 134 | """ 135 | Convenience method. See PSDA.prep() for details. 136 | """ 137 | return self.prep(enroll).llr_vector(self.prep(test)) 138 | 139 | @classmethod 140 | def em(cls, means: ndarray, counts:ndarray, niters = 10, w0 = 1.0, quiet = False): 141 | """ 142 | Trains a PSDA model from data.
143 | 144 | means: (n, dim) the means of each of the n classes available for training 145 | 146 | counts: (n,) the number of examples of each class 147 | 148 | niters: the number of EM iterations to run 149 | 150 | w0>0: (optional) initial guess for within-class concentration 151 | 152 | returns: the trained model as a PSDA object 153 | 154 | 155 | """ 156 | assert counts.min() > 1, "all speakers need at least 2 observations" 157 | means = means.astype(np.float64) 158 | ns, dim = means.shape 159 | assert len(counts) == ns 160 | total = counts.sum() 161 | psda = cls.em_init(means,w0) 162 | obj = [] 163 | for i in range(niters): 164 | psda, llh = psda.em_iter(means,counts,total) 165 | if not quiet: print(f"em iter {i}: {llh}") 166 | obj.append(llh) 167 | return psda, obj 168 | 169 | 170 | 171 | def em_iter(self, means, counts, total): 172 | """ 173 | Invoked by em 174 | 175 | returns: 176 | a new updated PSDA 177 | marginal log-likelihood (em objective) 178 | """ 179 | zpost = self.zposterior(compose(counts,means)) 180 | nspk = len(counts) 181 | llh = self.logCw*total + self.logCb*nspk - zpost.logCk.sum() 182 | 183 | z_exp = zpost.mean() 184 | zbar = z_exp.mean(axis=0) 185 | between = VMF.max_likelihood(zbar, self.logC) 186 | r = ((z_exp*means).sum(axis=-1)@counts) / total 187 | assert 0 < r < 1 188 | w = self.logC.rhoinv(r) 189 | return PSDA(w,between), llh 190 | 191 | @classmethod 192 | def em_init(cls,means, w0): 193 | """ 194 | Invoked by em 195 | """ 196 | norms, means = decompose(means) 197 | assert all(norms < 1), "Invalid means" 198 | between = VMF.max_likelihood(means.mean(axis=0)) 199 | return PSDA(w0, between) 200 | 201 | 202 | def __repr__(self): 203 | return f"PSDA(dim={self.dim}, b={self.b}, w={self.w})" 204 | 205 | def atleast2(means, counts): 206 | ok = counts > 1 207 | return means[ok,:], counts[ok] 208 | 209 | 210 | 211 | class Side: 212 | """ 213 | Represents a trial side, for one or more observations. When two trial sides 214 | are scored against each other, one containing m and the other n observations 215 | an (m,n) llr score matrix is produced. 216 | 217 | """ 218 | 219 | def __init__(self, psda:PSDA, X: ndarray): 220 | """ 221 | This constructor is invoked by psda.prep(X), see the docs of PSDA. 222 | """ 223 | self.logC = psda.logC 224 | self.logCb = psda.logCb 225 | self.wX = wX = psda.w*X 226 | self.wX_norm2 = (wX**2).sum(axis=-1) 227 | self.pstats = pstats = wX + psda.bmu 228 | self.pstats_norm2 = pnorm2 = (pstats**2).sum(axis=-1) 229 | pnorm = np.sqrt(pnorm2) 230 | self.num = self.logC(pnorm) 231 | 232 | 233 | def llr_matrix(self,rhs): 234 | """ 235 | Scores the one or more (m) trial sides contained in self against 236 | all (n) of the trial side(s) in rhs. Returns an (m,n) matrix of 237 | LLR scores. 238 | """ 239 | norm2 = self.pstats_norm2.reshape(-1,1) + rhs.wX_norm2 + \ 240 | 2*self.pstats @ rhs.wX.T 241 | denom = self.logC(np.sqrt(norm2)) 242 | return self.num.reshape(-1,1) + rhs.num - denom - self.logCb 243 | 244 | 245 | 246 | def llr_vector(self, rhs): 247 | """ 248 | Scores the n trial sides contained in self against the respective n 249 | sides in the rhs. Returns an (n,) vector of LLR scores. If one of the 250 | sides has a single trial and the other multiple trials, broadcasting 251 | will be done in the usual way. 
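        A sketch (illustrative only: psda is a trained PSDA model and E, T are
        matrices of length-normalized embeddings with the same number of rows,
        one single-observation trial side per row):

            side_e = psda.prep(E)                 # n enrollment sides
            side_t = psda.prep(T)                 # n test sides
            llrs = side_e.llr_vector(side_t)      # (n,) paired LLR scores
            # equivalently, via the convenience method:
            llrs = psda.llr_vector(E, T)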
252 | """ 253 | norm2 = self.pstats_norm2 + rhs.wX_norm2 + \ 254 | 2*(self.pstats * rhs.wX).sum(axis=-1) 255 | denom = self.logC(np.sqrt(norm2)) 256 | return self.num + rhs.num - denom - self.logCb 257 | -------------------------------------------------------------------------------- /psda/vmf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | #from scipy.special import gammaln 4 | from scipy.optimize import toms748 5 | 6 | 7 | from psda.vmf_sampler import rotate_to_mu, sample_vmf_canonical_mu, sample_uniform 8 | from psda.besseli import LogBesselI, fast_logrho, fastLogCvmf_e, k_and_logk 9 | 10 | 11 | # def logfactorial(x): 12 | # """ 13 | # Natural log of factorial. Invokes scipy.special.gammaln. 14 | # log x! = log gamma(x+1) 15 | # """ 16 | # return gammaln(x+1) 17 | 18 | 19 | 20 | 21 | 22 | class LogNormConst: 23 | """ 24 | This is a callable for the log normalization constant for Von Mises-Fisher 25 | distribution, computed as a function of the dimension and the concentration 26 | parameter. A term that dependens only on the dimension is omitted. 27 | 28 | The dimensionality, via nu = dim/2-1, is supplied at construction, 29 | while the concentration is passed later to the constructed callable. 30 | 31 | 32 | As a callable it returns: 33 | 34 | log C_nu(k) = log k**nu - log Inu(k) 35 | 36 | 37 | An additional method, rho(k) computes the norm of the expected 38 | value of a VMF with this normalization constant, Cnu(k). 39 | 40 | Also supplied are rhoinv(rho) that uses root finding to invert rho(k), 41 | and rhoinv_fast(rho), that does a fast approximation. 42 | 43 | """ 44 | def __init__(self,dim): 45 | self.dim = dim 46 | self.nu = nu = dim/2-1 47 | self.logI = logI = LogBesselI(nu) 48 | self.logCvmf_e = logI.logCvmf_e 49 | self.fastlogCvmf_e = fastLogCe = fastLogCvmf_e(logI) 50 | self.fastlogrho = fastlogrho = fast_logrho(logI, fastLogCe) 51 | self.logrho = fastlogrho.slow 52 | 53 | 54 | def __call__(self, k = None, logk = None, fast = False, exp_scale = False): 55 | """ 56 | Returns the log normalization constant, omitting a term dependent 57 | only on the dimensionality, nu. 58 | 59 | kappa > 0: the VMF concentration parameter 60 | 61 | 62 | """ 63 | k, logk = k_and_logk(k, logk) 64 | logCe = self.fastlogCvmf_e(k, logk) if fast else \ 65 | self.logCvmf_e(k, logk) 66 | if exp_scale: return logCe 67 | return logCe - k 68 | 69 | 70 | def rho(self, k = None, logk = None, fast = False): 71 | """ 72 | The norm of the expected value for VMF(nu,k). The expected value is 73 | monotonic rising, from rho(0) = 0 to rho(inf) = 1. The limit at 0 74 | is handled explicitly, but the one at infinity is not implemented. 75 | """ 76 | log_rho = self.fastlogrho(k, logk) if fast else \ 77 | self.logrho(k, logk) 78 | return np.exp(log_rho) 79 | 80 | 81 | def rhoinv_fast(self,rho): 82 | """ 83 | Fast, approximate inversion of rho given by Banerjee'05 84 | """ 85 | if np.isscalar(rho): 86 | return self.rhoinv_fast(np.array([rho]))[0] 87 | dim = self.dim 88 | nz = rho>0 89 | k = np.zeros_like(rho) 90 | if np.any(nz): 91 | rhonz = rho[nz] 92 | rho2 = rhonz**2 93 | k[nz] = rhonz*(dim-rho2) / (1-rho2) 94 | return k 95 | 96 | 97 | def rhoinv(self, rho, fast=False): 98 | """ 99 | Slower, more accurate inversion of rho using a root finder. 100 | Except, if fast = True, it just calls rhoinv_fast. 
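        Round-trip sketch (illustrative; the values are arbitrary and the
        recovered concentration matches only up to the root-finder tolerance):

            logC = LogNormConst(256)
            r = logC.rho(50.0)        # norm of the VMF mean, in (0, 1)
            k = logC.rhoinv(r)        # approximately 50.0 again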
101 | """ 102 | 103 | # probably more accurate than iverting the fast rho approximation 104 | if fast: return self.rhoinv_fast(rho) 105 | 106 | if not np.isscalar(rho): 107 | return np.array([self.rhoinv(ri) for ri in rho]) 108 | if rho == 0: return 0.0 109 | k0 = self.rhoinv_fast(rho) 110 | f = lambda logk: self.rho(logk = logk) - rho 111 | left = np.log(k0) 112 | fleft = f(left) 113 | if fleft == 0: return k0 114 | if fleft < 0: 115 | right = left 116 | fright = fleft 117 | while fright <=0: 118 | right = 2.3 + right 119 | fright = f(right) 120 | else: # fleft > 0 121 | right = left 122 | fright = fleft 123 | while fleft >= 0: 124 | left = left - 2.3 125 | fleft = f(left) 126 | return np.exp(toms748(f,left,right)) 127 | 128 | def decompose(x): 129 | """ 130 | If x is a vector, return: norm, x/norm 131 | 132 | If x is a matrix, do the same for every row. The norms are returned 133 | as a 1d array (not as a column). 134 | 135 | """ 136 | if x.ndim == 1: 137 | norm = np.sqrt((x**2).sum(axis=-1)) 138 | if norm == 0: return 0.0, x 139 | return norm, x/norm 140 | norm = np.sqrt((x**2).sum(axis=-1,keepdims=True)) 141 | zeros = norm == 0 142 | if np.any(zeros): 143 | norm[zeros] = 1 144 | return norm.squeeze(), x/norm 145 | 146 | def compose(norm,mu): 147 | """ 148 | Does the inverse of decompose. Returns norm.reshape(-1,1)*mu 149 | 150 | norm: scalar or vector 151 | mu: vector or matrix 152 | 153 | """ 154 | if not np.isscalar(norm): norm = norm.reshape(*(*mu.shape[:-1],1)) 155 | return norm*mu 156 | 157 | class VMF: 158 | """ 159 | Von Mises-Fisher distribution. The parameters are supplied at construction. 160 | """ 161 | def __init__(self, mu=None, k = None, logC = None): 162 | """ 163 | mu: (dim, ): mean direction, must be lengh-normalized. 164 | 165 | (n,dim): if mu is a matrix, each row gives a different distribution 166 | 167 | If k is not given, then mu is the natural parameter, which is 168 | not length-normed. Then mu will be decomposed so that its norm 169 | becomes the concentration. 170 | 171 | k>=0: scalar, concentration parameter 172 | 173 | (n, ): if k is a vector mu must be a matrix and they must agree in 174 | shape[0] 175 | 176 | If k is omitted, mu is assumed to be the unnormalized natural 177 | parameter and k is recovered from the norm of mu. 178 | 179 | logC: LogNormConst, optional. If already available, it can be 180 | supplied to save memory and compute. 181 | 182 | """ 183 | if np.isscalar(mu): # dim <- mu, k <- 0 184 | k = 0.0 185 | mu = sample_uniform(mu) 186 | if k is None: 187 | kmu = mu 188 | k, mu = decompose(mu) 189 | else: 190 | kmu = compose(k, mu) 191 | self.mu = mu 192 | self.k = k 193 | self.kmu = kmu 194 | self.dim = dim = mu.shape[-1] 195 | if logC is None: 196 | logC = LogNormConst(dim) 197 | else: 198 | assert logC.dim == dim 199 | self.logC = logC 200 | self.logCk = logC(k) 201 | self.rho = logC.rho # function to compute k -> norm of mean 202 | 203 | def save_to_h5(self,h5,path): 204 | h5[f"{path}/mu"] = self.mu 205 | h5[f"{path}/k"] = self.k 206 | 207 | @classmethod 208 | def load_from_h5(cls,h5,path): 209 | mu = np.asarray(h5[f"{path}/mu"]) 210 | k = np.asarray(h5[f"{path}/k"]) 211 | mu = np.atleast_2d(mu) 212 | k = np.atleast_1d(k) 213 | return cls(mu,k) 214 | 215 | def mean(self): 216 | """ 217 | Returns the expected value in R^d, which is inside the sphere, 218 | not on it. 
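        Sketch (illustrative, in dim=3; the returned vector points along mu and
        its norm is rho(k), which is strictly less than 1 and approaches 1 as
        k grows):

            import numpy as np
            mu = np.array([1.0, 0.0, 0.0])
            vmf = VMF(mu, 10.0)
            m = vmf.mean()              # inside the unit sphere
            np.linalg.norm(m)           # rho(10) < 1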
219 | """ 220 | r = self.rho(self.k) 221 | return compose(r,self.mu) 222 | 223 | def kmu(self): 224 | """ 225 | returns the natural parameter, which is in R^d 226 | """ 227 | return self.kmu 228 | 229 | @classmethod 230 | def uniform(cls, dim): 231 | return cls(dim) 232 | 233 | 234 | @classmethod 235 | def max_likelihood(cls, mean, logC = None): 236 | """ 237 | The returns the maximum-likelihood estimate(s) for one or more VMFs, given 238 | the sufficient stats. 239 | 240 | mean: (dim,) the empirical mean (average) of the observations. 241 | The observations are on the unit hypersphere and the mean must be 242 | inside it (with norm stricly < 1) 243 | 244 | (n,dim): do n independent ML estimates 245 | 246 | returns: a VMF object, containing one or more distributions, all of the 247 | same dimensionality 248 | 249 | 250 | """ 251 | norm, mu = decompose(mean) 252 | assert norm < 1, "The mean norm must be strictly < 1" 253 | dim = len(mean) 254 | if logC is None: 255 | logC = LogNormConst(dim) 256 | else: 257 | assert logC.dim == mean.shape[-1] 258 | k = logC.rhoinv(norm) 259 | return cls(mu,k,logC) 260 | 261 | 262 | 263 | def sample_quick_and_dirty(self, n_or_labels): 264 | """ 265 | Quick and dirty (statistically incorrect) samples, meant only for 266 | preliminary tyre-kicking. 267 | 268 | If self contains a single distribution, supply n, the number of 269 | required samples. 270 | 271 | If self contains multiple distribution, supply labels (n, ) to select 272 | for each sample the distribution to be sampled from 273 | 274 | """ 275 | 276 | 277 | if np.isscalar(n_or_labels): 278 | n = n_or_labels 279 | labels = None 280 | assert self.mu.ndim == 1 281 | else: 282 | labels = n_or_labels 283 | n = len(labels) 284 | assert self.mu.ndim == 2 285 | 286 | 287 | dim, k = self.dim, self.k 288 | mean = self.mean() 289 | if labels is not None: 290 | mean = mean[labels,:] 291 | X = np.random.randn(n,dim)/np.sqrt(k) + mean 292 | return decompose(X)[1] 293 | 294 | 295 | def sample(self, n_or_labels): 296 | """ 297 | Generate samples from the von Mises-Fisher distribution. 298 | If self contains a single distribution, supply n, the number of 299 | required samples. 300 | If self contains multiple distributions, supply labels (n, ) to select 301 | for each sample the distribution to be sampled from. 302 | Reference: 303 | o Stochastic Sampling of the Hyperspherical von Mises–Fisher Distribution 304 | Without Rejection Methods - Kurz & Hanebeck, 2015 305 | o Simulation of the von Mises-Fisher distribution - Wood, 1994 306 | """ 307 | 308 | dim, mu = self.dim, self.mu 309 | 310 | if np.isscalar(n_or_labels): # n iid samples from a single distribution 311 | n = n_or_labels 312 | assert mu.ndim == 1 313 | assert np.isscalar(self.k) 314 | if self.k==0: 315 | return sample_uniform(dim,n) 316 | X = np.vstack([sample_vmf_canonical_mu(dim,self.k) for i in range(n)]) 317 | X = rotate_to_mu(X,mu) 318 | 319 | else: # index distribution by labels 320 | labels = n_or_labels 321 | assert mu.ndim == 2 322 | if np.isscalar(self.k): # broadcast k 323 | kk = np.full((len(labels),),self.k) 324 | else: 325 | kk = self.k[labels] 326 | 327 | X = np.vstack([sample_vmf_canonical_mu(dim,k) for k in kk]) 328 | 329 | for lab in np.unique(labels): 330 | ii = labels==lab 331 | X[ii] = rotate_to_mu(X[ii],mu[lab]) 332 | 333 | return X 334 | 335 | 336 | def logpdf(self, X): 337 | """ 338 | If X is a vector, return scalar or vector, depending if self contains 339 | one or more distributions. 
340 | 341 | If X is a matrix, returns an (m,) vector or an (m,n) matrix, where m 342 | is the number of rows of X and n is the number of distributions in self. 343 | """ 344 | llh = X @ self.kmu.T 345 | return llh + self.logCk 346 | 347 | 348 | def entropy(self): 349 | return -self.logpdf(self.mean()) 350 | 351 | def kl(self, other): 352 | mean = self.mean() 353 | return self.logpdf(mean) - other.logpdf(mean) 354 | 355 | 356 | def __repr__(self): 357 | if np.isscalar(self.k): 358 | return f"VMF(mu:{self.mu.shape}, k={self.k})" 359 | return f"VMF(mean:{self.mu.shape}, k:{self.k.shape})" 360 | 361 | 362 | if __name__ == "__main__": 363 | import matplotlib.pyplot as plt 364 | 365 | # k = np.exp(np.linspace(-5,4,1000)) 366 | # dim, n = 256, 1 367 | # C1 = LogNormConst(dim,n) 368 | # y = C1(k) 369 | 370 | # thr = C1.logI.thr 371 | # y_thr = C1(thr) 372 | 373 | 374 | # plt.figure() 375 | # plt.semilogx(k,y,label='spliced compromise') 376 | # plt.semilogx(thr,y_thr,'*',label='splice location') 377 | 378 | 379 | 380 | # plt.grid() 381 | # plt.xlabel('concentration parameter') 382 | # plt.ylabel('Von Mises-Fisher log norm. const.') 383 | # plt.legend() 384 | # plt.title(f'approximating terms: {n}') 385 | # plt.show() 386 | 387 | 388 | 389 | # dim, n = 256, 5 390 | # C5 = LogNormConst(dim,n) 391 | # y = C5(k) 392 | 393 | # thr = C5.logI.thr 394 | # y_thr = C5(thr) 395 | 396 | 397 | # plt.figure() 398 | # plt.semilogx(k,y,label='spliced compromise') 399 | # plt.semilogx(thr,y_thr,'*',label='splice location') 400 | 401 | 402 | 403 | # plt.grid() 404 | # plt.xlabel('concentration parameter') 405 | # plt.ylabel('Von Mises-Fisher log norm. const.') 406 | # plt.legend() 407 | # plt.title(f'approximating terms: {n}') 408 | # plt.show() 409 | 410 | 411 | # k = np.exp(np.linspace(-5,20,20)) 412 | # dim = 256 413 | # logC = LogNormConst(dim) 414 | # rho = logC.rho(k) 415 | # plt.semilogx(k,rho) 416 | # kk = logC.rhoinv_fast(rho) 417 | # plt.semilogx(kk,rho,'--') 418 | 419 | x0 = np.array([0.9,0.9])/1.5 420 | # vmf = VMF(x0) 421 | vmf = VMF.max_likelihood(x0) 422 | 423 | X = vmf.sample(10) 424 | plt.scatter(X[:,0],X[:,1]) 425 | plt.axis('square') 426 | plt.xlim(-1.2,1.2) 427 | plt.ylim(-1.2,1.2) 428 | plt.grid() 429 | -------------------------------------------------------------------------------- /psda/vmf_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | drng = np.random.default_rng() 4 | 5 | def rotate_to_mu(X,mu): 6 | 7 | # Rotate [1,0,...,0] to mu 8 | dim = mu.size 9 | M = np.zeros((dim,dim)) 10 | M[:,0] = mu/np.linalg.norm(mu) 11 | Q,R = np.linalg.qr(M) 12 | if R[0,0] < 0: 13 | Q = -Q 14 | Q *= np.linalg.norm(mu) 15 | return X@Q.T 16 | 17 | 18 | def sample_vmf_canonical_mu(dim, k, rng=drng): 19 | """ 20 | Generate samples from the von Mises-Fisher distribution 21 | with canonical mean, mu = [1,0,...,0] 22 | 23 | Reference: 24 | Simulation of the von Mises-Fisher distribution - Wood, 1994 25 | 26 | 27 | Inputs: 28 | dim: dimensionality of containing Euclidean space 29 | 30 | k: concentration parameter 31 | 32 | returns: one sample (dim, ) 33 | 34 | """ 35 | 36 | # VM*, step 0: 37 | b = (-2*k + np.sqrt(4*k**2 + (dim-1)**2))/(dim-1) # (eqn 4) 38 | x0 = (1 - b)/(1 + b) 39 | c = k*x0 + (dim - 1)*np.log(1 - x0**2) 40 | 41 | done = False 42 | while not done: 43 | # VM*, step 1: 44 | Z = rng.beta((dim-1)/2, (dim-1)/2) 45 | W = (1.0 - (1.0 + b)*Z)/(1.0 - (1.0 - b)*Z) 46 | 47 | # VM*, step 2: 48 | logU = np.log(rng.uniform()) 49 | done = k*W 
+ (dim-1)*np.log(1-x0*W) - c >= logU 50 | 51 | # VM*, step 3: 52 | V = rng.normal(size=dim-1) 53 | V /= np.linalg.norm(V) 54 | 55 | X = np.append(W, V*np.sqrt(1 - W**2)) 56 | return X 57 | 58 | 59 | def sample_uniform(dim, n=None, rng=drng): 60 | randn = lambda *args: rng.normal(size=args) 61 | if n is None: 62 | r = randn(dim) 63 | return r/np.linalg.norm(r) 64 | R = randn(n, dim) 65 | return R/np.linalg.norm(R,axis=-1,keepdims=True) 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2022 5 | # Author: Niko Brummer 6 | # All Rights Reserved 7 | 8 | from distutils.core import setup 9 | from setuptools import find_packages 10 | 11 | 12 | setup( 13 | name='PSDA', 14 | version='1.0', 15 | packages=find_packages(), 16 | url='https://github.com/bsxfan/PSDA', 17 | install_requires=[ 18 | 'numpy', 19 | 'scipy', 20 | 'matplotlib', 21 | ], 22 | license='MIT', 23 | author='Niko Brummer', 24 | author_email='niko.brummer@gmail.com', 25 | description='Python implementation of Probabilistic Spherical Discriminant Analysis.' 26 | ) 27 | --------------------------------------------------------------------------------
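A minimal end-to-end sketch of the intended workflow (illustrative only: the synthetic-data generation, dimensions and variable names below were chosen for this sketch and are not code from the package or its examples):

    import numpy as np
    from psda.psdamodel import PSDA
    from psda.vmf import VMF

    dim, n_spk, n_per_spk = 100, 50, 8

    # a "true" model, used here only to synthesize length-normalized embeddings
    true_model = PSDA(within_concentration=50.0,
                      between_distr=VMF(np.eye(dim)[0], 5.0))
    speakers = true_model.sample_speakers(n_spk)        # (n_spk, dim) unit rows
    labels = np.repeat(np.arange(n_spk), n_per_spk)
    X = true_model.sample(speakers, labels)             # (n_spk*n_per_spk, dim)

    # sufficient statistics for EM: per-speaker means and observation counts
    counts = np.full(n_spk, n_per_spk)
    means = np.vstack([X[labels == s].mean(axis=0) for s in range(n_spk)])

    model, objective = PSDA.em(means, counts, niters=10)

    # scoring: here each trial side is a single unit-norm embedding;
    # multi-enrollment sides would instead be sums of embeddings (see PSDA.prep)
    enroll, test = X[::n_per_spk], X[1::n_per_spk]
    llr = model.llr_matrix(enroll, test)                # (n_spk, n_spk) LLRs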