├── requirements.txt ├── .travis.yml ├── test.py ├── LICENSE ├── README.md └── lshutils.py /requirements.txt: -------------------------------------------------------------------------------- 1 | bokeh==0.12.6 2 | numpy==1.14.0 3 | tensorflow==1.4.1 4 | scipy==1.0.0 5 | sklearn 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | os: 2 | - linux 3 | 4 | language: python 5 | 6 | python: 7 | - "3.6" 8 | 9 | script: 10 | - python test.py 11 | 12 | #install: 13 | # - pip install -r requirements.txt 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import lshutils 2 | import numpy as np 3 | 4 | if __name__=="__main__": 5 | data=lshutils.Dataset('random') 6 | max_index=10_000 7 | nnn=max_index//50 8 | inputs_=data.data[:max_index,:] 9 | 10 | lshmodel=lshutils.LSH(inputs_,hash_length=16) 11 | flymodel=lshutils.flylsh(inputs_,hash_length=16,sampling_ratio=0.1,embedding_size=20*16) 12 | 13 | lshmap=lshmodel.findmAP(nnn=nnn,n_points=100) 14 | flymap=flymodel.findmAP(nnn=nnn,n_points=100) 15 | 16 | print('LSH model mAP={:.3f}'.format(lshmap)) 17 | print('Fly model mAP={:.3f}'.format(flymap)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Jaiyam Sharma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build status](https://travis-ci.org/dataplayer12/Fly-LSH.svg?master)](https://travis-ci.org/dataplayer12) 2 | # Paper 3 | Code accompanying our [paper](https://arxiv.org/abs/1812.01844) **Improving Similarity Search with High-dimensional Locality sensitive hashing** 4 | 5 | # Summary 6 | We make three important contributions: 7 | 1. We present a new data independent approximate nearest neighbor (ANN) search algorithm inspired by the fruit fly olfactory circuit introduced by [Dasgupta et. al.](http://science.sciencemag.org/content/358/6364/793/tab-article-info). 
Named *DenseFly*, the proposed algorithm performs significantly better than several existing data independent algorithms on six benchmark datasets. (figures 2 and 3) 8 | 2. We prove several theoretical results about the original *FlyHash* as well as the proposed *DenseFly* algorithms. In particular, we show that *FlyHash* preserves rank similarity under any *Lp* norm and that *DenseFly* approximates a *SimHash* in very high dimensions at a much lower computational cost. (Lemmas 1 and 2) 9 | 3. We develop a multi-probe binning scheme for *FlyHash* and *DenseFly* algorithms, which are indispensable for practical applications of ANN algorithms. Remarkably, the proposed multi-probe binning scheme does not require additional computation over and above those used to create the high dimensional *Fly* or *DenseFly* hashes. Thus, the multi-probe versions of *FlyHash* and *DenseFly* result in a significant increase in mAP scores for a given query time. (figure 4) 10 | 11 | # Code 12 | The code for all the new algorithms described are present in one large file. Helper scripts to compare different algorithms will be added soon. 13 | -------------------------------------------------------------------------------- /lshutils.py: -------------------------------------------------------------------------------- 1 | from tensorflow.examples.tutorials.mnist.input_data import read_data_sets 2 | import tensorflow as tf 3 | import numpy as np 4 | from scipy.io import loadmat 5 | import pickle, time 6 | import os 7 | from collections import OrderedDict as odict 8 | from functools import reduce 9 | from sklearn.cluster import KMeans 10 | from bokeh.plotting import figure,output_file,output_notebook,show 11 | import bokeh 12 | 13 | class Dataset(object): 14 | def __init__(self,name,path='./datasets/'): 15 | self.path=path 16 | self.name=name.upper() 17 | if self.name=='MNIST' or self.name=='FMNIST': 18 | self.indim=784 19 | try: 20 | self.data=read_data_sets(self.path+self.name) 21 | except OSError as err: 22 | print(str(err)) 23 | raise ValueError('Try again') 24 | 25 | elif self.name=='CIFAR10': 26 | self.indim=(32,32,3) 27 | if self.name not in os.listdir(self.path): 28 | print('Data not in path') 29 | raise ValueError() 30 | elif self.name=='GLOVE': 31 | self.indim=300 32 | self.data=pickle.load(open(self.path+'glove30k.p','rb')) 33 | 34 | elif self.name=='SIFT': 35 | self.indim=128 36 | self.data=loadmat(self.path+self.name+'/siftvecs.mat')['vecs'] 37 | 38 | elif self.name=='GIST': 39 | self.indim=960 40 | self.data=loadmat(self.path+self.name+'/gistvecs.mat')['vecs'] 41 | 42 | elif self.name=='LMGIST': 43 | self.indim=512 44 | self.data=loadmat(self.path+self.name+'/LabelMe_gist.mat')['gist'] 45 | 46 | elif self.name=='RANDOM': 47 | self.indim=128 48 | self.data=np.random.random(size=(100_000,self.indim)) #np.random.randn(100_000,self.indim) 49 | 50 | def train_batches(self,batch_size=64,sub_mean=False,maxsize=-1): 51 | if self.name in ['MNIST','FMNIST']: 52 | max_=self.data.train.images.shape[0]-batch_size if maxsize==-1 else maxsize-batch_size 53 | for idx in range(0,max_,batch_size): 54 | batch_x=self.data.train.images[idx:idx+batch_size,:] 55 | batch_y=self.data.train.labels[idx:idx+batch_size] 56 | batch_y=np.eye(10)[batch_y] 57 | if sub_mean: 58 | batch_x=batch_x-batch_x.mean(axis=1)[:,None] 59 | 60 | yield batch_x,batch_y 61 | 62 | elif self.name=='CIFAR10': 63 | for batch_num in [1,2,3,4,5]: 64 | filename=self.name+'/train_batch_'+str(batch_num)+'.p' 65 | with open(filename,mode='rb') as f: 66 | 
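# assumption: each train_batch_<n>.p holds a (features, labels) tuple produced by a separate CIFAR-10 preprocessing step; note the path is not prefixed with self.path here, unlike in test_set below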
features,labels=pickle.load(f) 67 | for begin in range(0,len(features),batch_size): 68 | end=min(begin+batch_size,len(features)) 69 | yield features[begin:end],labels[begin:end] 70 | 71 | elif self.name in ['GLOVE','SIFT','LMGIST','RANDOM']: 72 | max_=self.data.shape[0]-batch_size if maxsize==-1 else maxsize-batch_size 73 | for idx in range(0,max_,batch_size): 74 | batch_x=self.data[idx:idx+batch_size,:] 75 | if sub_mean: 76 | batch_x=batch_x-batch_x.mean(axis=1)[:,None] 77 | yield batch_x,None 78 | 79 | def test_set(self,maxsize=-1,sub_mean=False): 80 | #maxsize determines how many elements of test set to return 81 | if self.name in ['MNIST','FMNIST']: 82 | test_x=self.data.test.images[:maxsize] 83 | test_y=np.eye(10)[self.data.test.labels[:maxsize]] 84 | if sub_mean: 85 | test_x=test_x-test_x.mean(axis=1)[:,None] 86 | return (test_x,test_y) 87 | 88 | elif self.name=='CIFAR10': 89 | with open(self.path+self.name+'/test_batch.p',mode='rb') as f: 90 | features,labels=pickle.load(f) 91 | test_x,test_y=features[:maxsize],labels[:maxsize] 92 | if sub_mean: 93 | test_x=test_x-test_x.mean(axis=1)[:,None] 94 | return test_x,test_y 95 | 96 | elif self.name in ['GLOVE','SIFT','LMGIST','RANDOM']: 97 | test_x=self.data[:maxsize] 98 | #test_y=np.eye(10)[self.data.test.labels[:maxsize]] 99 | if sub_mean: 100 | test_x=test_x-test_x.mean(axis=1)[:,None] 101 | return (test_x,None) 102 | 103 | class LSH(object): 104 | def __init__(self,data,hash_length): 105 | """ 106 | data: Nxd matrix 107 | hash_length: scalar 108 | sampling_ratio: fraction of input dims to sample from when producing a hash 109 | (ratio of PNs that each KC samples from) 110 | embedding_size: dimensionality of projection space, m 111 | """ 112 | self.hash_length=hash_length 113 | self.data=data-np.mean(data,axis=1)[:,None] 114 | self.weights=np.random.random((data.shape[1],hash_length)) 115 | self.hashes=(self.data@self.weights)>0 116 | self.maxl1distance=2*self.hash_length 117 | 118 | def query(self,qidx,nnn,not_olap=False): 119 | L1_distances=np.sum(np.abs(self.hashes[qidx,:]^self.hashes),axis=1) 120 | #np.sum(np.bitwise_xor(self.hashes[qidx,:],self.hashes),axis=1) 121 | nnn=min(self.hashes.shape[0],nnn) 122 | if not_olap: 123 | no_overlaps=np.sum(L1_distances==self.maxl1distance) 124 | return no_overlaps 125 | 126 | NNs=L1_distances.argsort() 127 | NNs=NNs[(NNs != qidx)][:nnn] 128 | #print(L1_distances[NNs]) #an interesting property of this hash is that the L1 distances are always even 129 | return NNs 130 | 131 | def true_nns(self,qidx,nnn): 132 | sample=self.data[qidx,:] 133 | tnns=np.sum((self.data-sample)**2,axis=1).argsort()[:nnn+1] 134 | tnns=tnns[(tnns!=qidx)] 135 | if nnn0 352 | self.maxl1distance=2*self.hash_length 353 | 354 | class LSHpar_ensemble(object): 355 | def __init__(self,data,hash_length,K): 356 | self.n_models=K 357 | def _create_model(): 358 | mymodel=LSH(data,hash_length) 359 | mymodel.create_bins() 360 | 361 | class LSHensemble(object): 362 | def __init__(self,data,hash_length,K): 363 | self.models=[LSH(data,hash_length) for _ in range(K)] 364 | self.numsamples=data.shape[0] 365 | self.firstmodel=self.models[0] 366 | self.firstmodel.create_bins() 367 | for m in self.models[1:]: 368 | m.create_bins() 369 | del m.data #remove data 370 | self.timetoindex=sum([m.timetoindex for m in self.models]) 371 | 372 | def compute_recall(self,n_points,nnn,sr): 373 | sample_indices=np.random.choice(self.numsamples,n_points) 374 | recalls=[] 375 | elapsed=[] 376 | numpredicted=[] 377 | for qidx in sample_indices: 378 | 
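# time each bin-based query against the first model and compare its candidate set with the exact nearest neighbours to estimate recall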
start=time.time() 379 | #preds=np.array([m.query_bins(qidx,sr) for m in self.models]) 380 | predicted=self.firstmodel.query_bins(qidx,sr)#reduce(np.union1d,preds) 381 | if len(predicted) 0 454 | 455 | def create_highd_bins(self,d,rounds=1): 456 | """ 457 | This function implements a relaxed binning for FlyLSH 458 | This is only one of the many possible implementations for such a scheme 459 | d: the number of bits to match between hashes for putting them in the same bin 460 | """ 461 | self.highd_bins=self.hashes[0:1,:] #initialize hashes to first point 462 | self.highd_binstopoints,self.highd_pointstobins={},{i:[] for i in range(self.hashes.shape[0])} 463 | for round in range(rounds): 464 | for hash_idx,this_hash in enumerate(self.hashes): 465 | overlap=(self.maxl1distance-((this_hash[None,:]^self.highd_bins).sum(axis=1)))>=2*d 466 | #print(overlap.shape) 467 | if overlap.any(): 468 | indices=np.flatnonzero(overlap) 469 | #indices=indices.tolist() 470 | #print(indices) 471 | self.highd_pointstobins[hash_idx].extend(indices) 472 | for idx in indices: 473 | if idx not in self.highd_binstopoints: 474 | #print(indices,idx) 475 | self.highd_binstopoints[idx]=[] 476 | self.highd_binstopoints[idx].append(hash_idx) 477 | else: 478 | self.highd_bins=np.append(self.highd_bins,this_hash[None,:],axis=0) 479 | bin_idx=self.highd_bins.shape[0]-1 480 | self.highd_pointstobins[hash_idx].append(bin_idx) 481 | self.highd_binstopoints[bin_idx]=[hash_idx] 482 | 483 | def create_lowd_bins(self): 484 | start=time.time() 485 | self.lowd_bins=np.unique(self.lowd_hashes,axis=0) 486 | #self.num_bins=self.bins.shape[0] 487 | 488 | assignment=np.zeros(self.lowd_hashes.shape[0]) 489 | for idx,_bin in enumerate(self.lowd_bins): 490 | assignment[(self.lowd_hashes==_bin).all(axis=1)]=idx 491 | self.lowd_binstopoints={bin_idx:np.flatnonzero(assignment==bin_idx) for bin_idx in range(self.lowd_bins.shape[0])} 492 | self.lowd_pointstobins={point:int(_bin) for point,_bin in enumerate(assignment)} 493 | self.timetoindex=time.time()-start 494 | 495 | def query_lowd_bins(self,qidx,search_radius=1,order=False): 496 | if not hasattr(self,'lowd_bins'): 497 | raise ValueError('low dimensional bins for model not created') 498 | query_bin=self.lowd_bins[self.lowd_pointstobins[qidx]] 499 | valid_bins=np.flatnonzero((query_bin[None,:]^self.lowd_bins).sum(axis=1)<=2*search_radius) 500 | all_points=reduce(np.union1d,np.array([self.lowd_binstopoints[idx] for idx in valid_bins])) 501 | if order: 502 | l1distances=(self.hashes[qidx,:]^self.hashes[all_points,:]).sum(axis=1) 503 | all_points=all_points[l1distances.argsort()] 504 | return all_points 505 | 506 | def query_highd_bins(self,qidx,order=False): 507 | if not hasattr(self,'highd_bins'): 508 | raise ValueError('high dimensional bins for model not created') 509 | valid_bins=self.highd_pointstobins[qidx] 510 | all_points=reduce(np.union1d,np.array([self.highd_binstopoints[idx] for idx in valid_bins])) 511 | if order: 512 | l1distances=(self.hashes[qidx,:]^self.hashes[all_points,:]).sum(axis=1) 513 | all_points=all_points[l1distances.argsort()] 514 | return all_points 515 | 516 | def compute_query_mAP(self,n_points,search_radius=1,order=False,qtype='lowd',nnn=None): 517 | sample_indices=np.random.choice(self.hashes.shape[0],n_points) 518 | average_precisions=[] 519 | elapsed=[] 520 | numpredicted=[] 521 | ms = lambda l:(np.mean(l),np.std(l)) 522 | for qidx in sample_indices: 523 | start=time.time() 524 | if qtype=='lowd': 525 | predicted=self.query_lowd_bins(qidx,search_radius,order) 526 | 
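# qtype='highd' probes the relaxed overlap bins built by create_highd_bins, rather than the exact-match low-dimensional bins used above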
elif qtype=='highd': 527 | predicted=self.query_highd_bins(qidx,order) 528 | assert len(predicted)1-sampling_ratio) #sparse projection vectors 594 | all_activations=(self.data@self.weights) 595 | threshold=0 596 | self.hashes=(all_activations>=threshold) #choose topk activations 597 | #self.dense_activations=all_activations 598 | #self.sparse_activations=self.hashes.astype(np.float32)*all_activations #elementwise product 599 | self.maxl1distance=2*self.hash_length 600 | self.lowd_hashes=all_activations.reshape((-1,hash_length,K)).sum(axis=-1) > 0 601 | 602 | class lowdflylsh(LSH): 603 | def __init__(self,data,hash_length,sampling_ratio,embedding_size): 604 | """ 605 | data: Nxd matrix 606 | hash_length: scalar 607 | sampling_ratio: fraction of input dims to sample from when producing a hash 608 | embedding_size: dimensionality of projection space, m 609 | Note that in Flylsh, the hash length and embedding_size are NOT the same 610 | whereas in usual LSH they are 611 | """ 612 | #f_bits=0.5 613 | self.hash_length=hash_length 614 | self.embedding_size=embedding_size 615 | K=embedding_size//hash_length 616 | self.data=(data-np.mean(data,axis=1)[:,None]) 617 | weights=np.random.random((data.shape[1],embedding_size)) 618 | self.weights=(weights>1-sampling_ratio) 619 | all_activations=(self.data@self.weights) 620 | self.activations=all_activations.reshape((-1,hash_length,K)).sum(axis=-1) 621 | #threshold=np.sort(self.activations,axis=1)[:,-int(f_bits*hash_length)][:,None] 622 | threshold=0 623 | self.hashes=(self.activations>=threshold) #choose topk activations 624 | self.maxl1distance=2*self.hash_length 625 | 626 | class WTAHash(flylsh): 627 | #implements Google's WTA hash 628 | def __init__(self,data,code_length,K=4): 629 | """ 630 | hash_length: code length m in the paper 631 | """ 632 | self.hash_length=code_length 633 | #K=1/wta_ratio, assuming a WTA ratio of 5% as in Fly LSH paper to make a fair comparison 634 | self.embedding_size=K*code_length 635 | self.data=data-np.mean(data,axis=1)[:,None] #this is not needed for WTAHash 636 | self.thetas=[np.random.choice(data.shape[1],K) for _ in range(code_length)] 637 | xindices=np.arange(data.shape[0],dtype=np.int32) 638 | yindices=self.data[:,self.thetas[0]].argmax(axis=1) 639 | #this line permutes the vectors with theta[0], takes the first K elements and computes 640 | #the index corresponding to max element 641 | 642 | this_hash=np.zeros((data.shape[0],K),dtype=np.bool) # a K dim binary vector for each data point 643 | this_hash[xindices,yindices]=True #set the positions corresponding to argmax to True 644 | self.hashes=this_hash[:] 645 | 646 | for t in self.thetas[1:]: 647 | this_hash=np.zeros((data.shape[0],K),dtype=np.bool) 648 | yindices=self.data[:,t].argmax(axis=1) #same as line 162 above 649 | this_hash[xindices,yindices]=True 650 | self.hashes=np.concatenate((self.hashes,this_hash),axis=1) 651 | #concatenate all m, K dimensional binary hashes, this is a 652 | #one hot encoded version of step 2 (C_X) in Algorithm 1 of the paper. 653 | #This can also be implemented exactly as shown in the paper. 
I chose this way 654 | #as it allows us to use existing functions of LSH object to find mAP 655 | #self.tokens=np.sort(self.hashes.argsort(axis=1)[:,-self.hash_length:],axis=1) 656 | self.maxl1distance=2*self.hash_length 657 | 658 | class FlyWTA(LSH): 659 | def __init__(self,data,hash_length,sampling_ratio,K): 660 | """ 661 | data: Nxd matrix 662 | hash_length: scalar 663 | sampling_ratio: fraction of input dims to sample from when producing a hash 664 | embedding_size: dimensionality of projection space, m 665 | Note that in Flylsh, the hash length and embedding_size are NOT the same 666 | whereas in usual LSH they are 667 | """ 668 | self.hash_length=hash_length 669 | self.embedding_size=K*hash_length#embedding_size 670 | #K=embedding_size//hash_length 671 | self.data=(data-np.mean(data,axis=1)[:,None]) 672 | 673 | num_projections=int(sampling_ratio*data.shape[1]) 674 | weights=np.random.random((data.shape[1],self.embedding_size)) 675 | yindices=np.arange(weights.shape[1])[None,:] 676 | xindices=weights.argsort(axis=0)[-num_projections:,:] 677 | self.weights=np.zeros_like(weights,dtype=np.bool) 678 | self.weights[xindices,yindices]= True#sparse projection vectors 679 | 680 | all_activations=(self.data@self.weights) 681 | 682 | self.thetas=[np.random.choice(all_activations.shape[1],K) for _ in range(self.hash_length)] 683 | xindices=np.arange(all_activations.shape[0],dtype=np.int32) 684 | yindices=all_activations[:,self.thetas[0]].argmax(axis=1) 685 | 686 | this_hash=np.zeros((all_activations.shape[0],K),dtype=np.bool) # a K dim binary vector for each data point 687 | this_hash[xindices,yindices]=True #set the positions corresponding to argmax to True 688 | self.hashes=this_hash[:] 689 | 690 | for t in self.thetas[1:]: 691 | this_hash=np.zeros((all_activations.shape[0],K),dtype=np.bool) 692 | yindices=all_activations[:,t].argmax(axis=1) #same as line 162 above 693 | this_hash[xindices,yindices]=True 694 | self.hashes=np.concatenate((self.hashes,this_hash),axis=1) 695 | 696 | self.maxl1distance=2*self.hash_length 697 | self.lowd_hashes=all_activations.reshape((-1,hash_length,K)).sum(axis=-1) > 0 698 | 699 | 700 | class WTAHash2(LSH): 701 | #implements Google's WTA hash 702 | def __init__(self,data,code_length,K=4): 703 | """ 704 | hash_length: code length m in the paper 705 | """ 706 | self.hash_length=code_length 707 | self.embedding_size=K*code_length 708 | self.data=data-np.mean(data,axis=1)[:,None] #this is not needed for WTAHash 709 | n_cycles= self.embedding_size//self.data.shape[1] +(self.embedding_size%self.data.shape[1]>0) 710 | 711 | self.perms=[np.random.permutation(data.shape[1]) for _ in range(n_cycles)] 712 | 713 | self.thetas=[p[idx:idx+K] for p in self.perms for idx in range(0,len(p),K)][:self.hash_length] 714 | #print(len(self.thetas)) 715 | xindices=np.arange(data.shape[0],dtype=np.int32) 716 | yindices=self.data[:,self.thetas[0]].argmax(axis=1) 717 | 718 | this_hash=np.zeros((data.shape[0],K),dtype=np.bool) # a K dim binary vector for each data point 719 | this_hash[xindices,yindices]=True #set the positions corresponding to argmax to True 720 | self.hashes=this_hash[:] 721 | 722 | for t in self.thetas[1:]: 723 | this_hash=np.zeros((data.shape[0],K),dtype=np.bool) 724 | yindices=self.data[:,t].argmax(axis=1) #same as line 162 above 725 | this_hash[xindices,yindices]=True 726 | self.hashes=np.concatenate((self.hashes,this_hash),axis=1) 727 | self.maxl1distance=2*self.hash_length 728 | 729 | class AEflylsh(LSH): 730 | #implements Fly LSH where weights are pre-specified 731 
| #The weights passed to init should be learnt from an autoencoder 732 | def __init__(self,data,hash_length,sampling_ratio,weights,local=False): 733 | """ 734 | data: Nxd matrix 735 | hash_length: scalar 736 | sampling_ratio: fraction of input dims to sample from when producing a hash 737 | embedding_size: dimensionality of projection space, m 738 | weights: weights learnt from an autoencoder 739 | weights should have the same dimensionality as projection space (m) 740 | """ 741 | #assert weights.shape[1]==embedding_size, f'Expects a {embedding_size} dim embedding from {weights.shhape[1]} dim weights' 742 | self.hash_length=hash_length 743 | self.embedding_size=weights.shape[1] 744 | self.data=(data-np.mean(data,axis=1)[:,None]) 745 | if local: 746 | self.weights=(weights>=np.sort(weights,axis=0)[-int(weights.shape[0]*sampling_ratio),:][None,:]) #sparse projection vectors 747 | else: 748 | n_weights=int(np.prod(weights.shape)*sampling_ratio) 749 | self.weights=(weights>=np.sort(weights,axis=None)[-n_weights]) #sparse projection vectors 750 | all_activations=(self.data@self.weights) 751 | threshold=np.sort(all_activations,axis=1)[:,-hash_length][:,None] 752 | self.hashes=(all_activations>=threshold) #choose top k activations 753 | self.maxl1distance=2*self.hash_length 754 | 755 | class AutoEncoder(object): 756 | def __init__(self,nodes,is_sparse=False,rho=0.5,beta=2,dropconnect=False): 757 | """ 758 | nodes: a list [in_dim,n_hidden] 759 | is_sparse: bool 760 | rho: if dropout is False, sparsity factor (fraction of weights turned on) 761 | otherwise, see below 762 | beta: weight of kl_divergence loss 763 | total_loss=reconstruction_loss+beta*kl_divergence 764 | dropout: if true, rho fraction of hidden units are dropped out 765 | """ 766 | dropout=False 767 | self.in_dim=nodes[0] 768 | self.n_hidden=nodes[1] 769 | self.epochs=5 770 | self.learn_rate=[1e-3/(2**(e//3)) for e in range(self.epochs)] 771 | self.batch_size=32 772 | self.inputs_,self.targets,self.lr=self.get_placeholders() 773 | if is_sparse: 774 | self.rho=rho 775 | self.encode_weights=tf.Variable(tf.random_uniform([self.in_dim,self.n_hidden],minval=0,maxval=10*self.rho)) 776 | self.decode_weights=tf.Variable(tf.truncated_normal([self.in_dim,self.n_hidden],stddev=0.05)) 777 | else: 778 | self.encode_weights=tf.Variable(tf.truncated_normal([self.in_dim,self.n_hidden],stddev=0.05)) 779 | self.decode_weights=self.encode_weights 780 | 781 | #biases=tf.Variable(tf.zeros([self.n_hidden])) #we don't want to use biases 782 | if dropconnect: 783 | self.encode_weights=tf.nn.dropout(self.encode_weights,keep_prob=tf.constant(rho))*rho 784 | 785 | hlayer=tf.matmul(self.inputs_,self.encode_weights) 786 | self.hlayer=tf.nn.relu(hlayer) #hlayer: relu for MNIST, sigmoid for GloVE 787 | if dropout: 788 | self.hlayer=tf.nn.dropout(self.hlayer,keep_prob=tf.constant(rho)) 789 | 790 | output=tf.matmul(self.hlayer,tf.transpose(self.decode_weights)) 791 | self.output=tf.nn.sigmoid(output) 792 | self.recon_loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.targets,logits=self.output)) #LAST CHANGE HERE 793 | normed= lambda w: (w-tf.reduce_min(w))/(tf.reduce_max(w)-tf.reduce_min(w)) 794 | #normalize things to be between 0 and 1 795 | 796 | if is_sparse: 797 | rho_hat=tf.reduce_mean(normed(self.encode_weights)) #axis=0 798 | self.kl_loss=self.find_KL_div(self.rho,rho_hat) 799 | #self.kl_loss=tf.nn.l2_loss(self.weights1) 800 | self.cost=self.recon_loss+beta*self.kl_loss 801 | else: 802 | self.cost=self.recon_loss 803 | 804 | 
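# a single Adam step minimises self.cost: the reconstruction loss, plus the beta-weighted KL sparsity penalty when is_sparse is True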
self.opt=tf.train.AdamOptimizer(self.lr).minimize(self.cost) 805 | 806 | def find_KL_div(self,rho,rho_hat): 807 | return rho*tf.log(rho)-rho*tf.log(rho_hat)+(1-rho)*tf.log(1-rho)-(1-rho)*tf.log(1-rho_hat) 808 | 809 | def get_placeholders(self): 810 | inputs_=tf.placeholder(tf.float32,[None,self.in_dim]) 811 | targets=tf.placeholder(tf.float32,[None,self.in_dim]) 812 | lr=tf.placeholder(tf.float32) 813 | return inputs_,targets,lr 814 | 815 | def train(self,data,maxsize=-1,show_recon=False): 816 | """data: a Dataset object""" 817 | with tf.Session() as sess: 818 | sess.run(tf.global_variables_initializer()) 819 | count=0 820 | for e in range(self.epochs): 821 | for batch_x,_ in data.train_batches(self.batch_size,sub_mean=True, maxsize=maxsize): 822 | count+=1 823 | feed={self.inputs_:batch_x,self.targets:batch_x,self.lr:self.learn_rate[e]} 824 | _=sess.run([self.opt],feed_dict=feed) 825 | 826 | #print(f'Epoch {e+1}/{self.epochs}, recon_loss={rl}') 827 | 828 | all_weights=self.encode_weights.eval() 829 | 830 | #all_inputs=data.data.train.images[:maxsize] if data.name in ['MNIST','FMNIST'] else data.data[:maxsize] 831 | #all_inputs=all_inputs-all_inputs.mean(axis=1)[:,None] 832 | 833 | #feed={self.inputs_:all_inputs} 834 | #average_activations=sess.run(tf.reduce_mean(self.hlayer,axis=0),feed_dict=feed) 835 | 836 | #average_activations=average_activations[None,:]/average_activations.max() 837 | #print(average_activations) 838 | #all_weights=-np.abs(np.repeat(average_activations,data.indim,axis=0)-np.maximum(0.,all_inputs).sum(axis=0)[:,None].astype(np.float32)) 839 | if show_recon: 840 | test_x,_=data.test_set(maxsize=10,sub_mean=True) 841 | feed={self.inputs_:test_x} 842 | recons=sess.run(self.output,feed_dict=feed) 843 | return (all_weights,(test_x,recons)) 844 | 845 | return all_weights 846 | 847 | class WTAAutoEncoder(AutoEncoder): 848 | def __init__(self,nodes,rho=0.1): 849 | """ 850 | nodes: a list [in_dim,n_hidden] 851 | rho: sparsity factor (fraction of top activations kept during forward pass) 852 | """ 853 | self.in_dim=nodes[0] 854 | self.n_hidden=nodes[1] 855 | self.epochs=5 856 | self.learn_rate=[1e-3/(2**(e//3)) for e in range(self.epochs)] 857 | self.batch_size=32 858 | self.inputs_,self.targets,self.lr=self.get_placeholders() 859 | self.rho=rho 860 | self.topk=int(self.n_hidden*self.rho) 861 | self.encode_weights=tf.Variable(tf.truncated_normal([self.in_dim,self.n_hidden],stddev=0.05)) 862 | self.decode_weights=tf.Variable(tf.truncated_normal([self.in_dim,self.n_hidden],stddev=0.05)) 863 | 864 | #biases=tf.Variable(tf.zeros([self.n_hidden])) #we don't want to use biases 865 | hlayer=tf.matmul(self.inputs_,self.encode_weights) 866 | hlayer=tf.nn.relu(hlayer) #hlayer: relu for MNIST, sigmoid for GloVE 867 | 868 | thresholds,_=tf.nn.top_k(hlayer,k=self.topk,sorted=True) 869 | thresholds=thresholds[:,-1] 870 | mask=(hlayer-tf.expand_dims(thresholds,1))>=0 871 | self.hlayer=hlayer*tf.cast(mask,dtype=tf.float32) 872 | 873 | self.output=tf.matmul(self.hlayer,tf.transpose(self.decode_weights)) 874 | 875 | self.cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.targets,logits=self.output)) #LAST CHANGE HERE 876 | 877 | self.opt=tf.train.AdamOptimizer(self.lr).minimize(self.cost) 878 | 879 | 880 | colors={'LSH':'red','AELSH':'black','Fly':'green','SparseAEFly':'blue',\ 881 | 'AEFly':'orange','AE_local':'red','AE_global':'green','DenseFly':'blue',\ 882 | 'WTA':'teal','ExWTA':'black','very_sparse_RP':'black','lowfly':'blue','PQ':'pink','FlyWTA':'pink'} 883 | 884 | def 
plot_results(all_results,hash_lengths=None,keys=None,name='data',location='./',metric='mAP'): 885 | if hash_lengths is None: 886 | hash_lengths=sorted(all_results.keys()) 887 | 888 | if keys is None: 889 | keys=list(all_results[hash_lengths[0]].keys()) 890 | 891 | Lk=len(keys) 892 | fmt= lambda mk:mk.join([k for k in keys]) 893 | 894 | global colors 895 | 896 | if metric=='mAP': 897 | curve_ylabel='mean Average Precision (mAP)' 898 | min_y=0 899 | mean= lambda x,n:np.mean(all_results[x][n]) 900 | stdev=lambda x,n:np.std(all_results[x][n]) 901 | elif metric=='auprc': 902 | curve_ylabel='Area under precision recall curve' 903 | min_y=0 904 | n_trials=len(all_results[hash_lengths[0]][keys[0]]) 905 | all_precisions={hl:{k:[all_results[hl][k][i][0] for i in range(n_trials)] for k in keys} for hl in hash_lengths} 906 | all_recalls={hl:{k:[all_results[hl][k][i][1]/np.max(all_results[hl][k][i][1]) for i in range(n_trials)] for k in keys} for hl in hash_lengths} 907 | auprc= lambda hl,k,i: np.sum(np.gradient(all_recalls[hl][k][i])*all_precisions[hl][k][i]) 908 | mean= lambda hl,k:np.mean([auprc(hl,k,i) for i in range(n_trials)]) 909 | stdev=lambda hl,k:np.std([auprc(hl,k,i) for i in range(n_trials)]) #np.std(np.array(all_MAPs[x][n]),axis=0) 910 | elif metric=='auroc': 911 | curve_ylabel='Area under Receiver Operating Characteristic (ROC) curve' 912 | min_y=0.5 913 | n_trials=len(all_results[hash_lengths[0]][keys[0]]) 914 | all_tprs={hl:{k:[all_results[hl][k][i][1] for i in range(n_trials)] for k in keys} for hl in hash_lengths} 915 | all_fprs={hl:{k:[all_results[hl][k][i][0]/np.max(all_results[hl][k][i][0]) for i in range(n_trials)] for k in keys} for hl in hash_lengths} 916 | 917 | auroc= lambda hl,k,i: np.sum(np.gradient(all_fprs[hl][k][i])*all_tprs[hl][k][i]) 918 | mean= lambda hl,k:np.mean([auroc(hl,k,i) for i in range(n_trials)]) 919 | stdev=lambda hl,k:np.std([auroc(hl,k,i) for i in range(n_trials)]) #np.std(np.array(all_MAPs[x][n]),axis=0) 920 | 921 | p=figure(x_range=[str(h) for h in hash_lengths],title=f'{fmt(",")} on {name}') 922 | delta=0.5/(Lk+1) 923 | deltas=[delta*i for i in range(-Lk,Lk)][1::2] 924 | assert len(deltas)==Lk, 'Bad luck' 925 | 926 | x_axes=np.sort(np.array([[x+d for d in deltas] for x in range(1,1+len(hash_lengths))]),axis=None) 927 | means=[mean(hashl,name) for name,hashl in zip(keys*len(hash_lengths),sorted(hash_lengths*Lk))] 928 | stds=[stdev(hashl,name) for name,hashl in zip(keys*len(hash_lengths),sorted(hash_lengths*Lk))] 929 | 930 | for i in range(len(hash_lengths)): 931 | for j in range(Lk): 932 | p.vbar(x=x_axes[Lk*i+j], width=delta, bottom=0, top=means[Lk*i+j] , color=colors[keys[j]],legend=keys[j]) 933 | 934 | err_xs=[[i,i] for i in x_axes] 935 | err_ys= [[m-s,m+s] for m,s in zip(means,stds)] 936 | p.y_range.bounds=(min_y,np.floor(10*max(means))/10 + 0.1) 937 | p.multi_line(err_xs, err_ys,line_width=2, color='black',legend='stdev') 938 | p.legend.location='top_left' 939 | p.legend.click_policy='hide' 940 | p.xaxis.axis_label='Hash length (k)/Code length (bits)' 941 | p.yaxis.axis_label= curve_ylabel 942 | output_file(f'{location+fmt("_")}_{name}.html') 943 | show(p) 944 | 945 | def plothlcurve(all_results,hl,name='data',location='./',metric='prc'): 946 | global colors 947 | 948 | assert hl in all_results.keys(), 'Provide a valid hash length' 949 | keys=list(all_results[hl].keys()) 950 | n_trials=len(all_results[hl][keys[0]]) 951 | 952 | if metric=='prc': 953 | all_ys={k:np.mean([all_results[hl][k][i][0] for i in range(n_trials)],axis=0) for k in keys} 954 | 
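# for precision-recall curves, result[i][0] holds the precision values and result[i][1] the raw recall counts, averaged over trials; the counts are rescaled to [0,1] just below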
all_xs={k:np.mean([all_results[hl][k][i][1] for i in range(n_trials)],axis=0) for k in keys} 955 | all_xs={k:all_xs[k]/np.max(all_xs[k]) for k in keys} 956 | title=f'Precision recall curves for {name}, hash length={hl}' 957 | xlabel='Recall' 958 | ylabel='Precision' 959 | legend_location='top_right' 960 | elif metric=='roc': 961 | all_xs={k:np.mean([all_results[hl][k][i][0] for i in range(n_trials)],axis=0) for k in keys} 962 | all_ys={k:np.mean([all_results[hl][k][i][1] for i in range(n_trials)],axis=0) for k in keys} 963 | all_xs={k:all_xs[k]/np.max(all_xs[k]) for k in keys} 964 | 965 | title=f'ROC curves for {name}, hash length={hl}' 966 | xlabel='False Positive rate' 967 | ylabel='True Positive rate' 968 | legend_location='bottom_right' 969 | auc= lambda k: np.sum(np.gradient(all_xs[k])*all_ys[k]) 970 | aucs={k:auc(k) for k in keys} 971 | 972 | p=figure(title=title) 973 | for k in keys: 974 | leg='{}({:.2f})'.format(k,0.01*np.floor(100*np.mean(aucs[k]))) 975 | p.line(all_xs[k],all_ys[k],line_width=2,color=colors[k],legend=leg) 976 | 977 | if metric=='roc': 978 | p.line(np.arange(100)/100.0,np.arange(100)/100.0,line_width=1,line_dash='dashed',legend='random (0.5)') 979 | #show random classifier line for ROC metrics 980 | 981 | p.legend.location=legend_location 982 | p.legend.click_policy='hide' 983 | p.xaxis.axis_label=xlabel 984 | p.yaxis.axis_label=ylabel 985 | 986 | output_file(f'{location}{metric}_{name}_{hl}.html') 987 | show(p) 988 | 989 | def parse_computed(foldername): 990 | allfiles=os.listdir(foldername) 991 | mnames=['LSH','Fly','WTA'] 992 | fmlname={'LSH':'LSH','Fly':'Fly','WTA':'WTA'} 993 | #mnames=['lsh','fly','WTA'] 994 | #fmlname={'lsh':'LSH','fly':'Fly','WTA':'WTA'} 995 | hash_lengths=[4,8,16,24,32,48,64,96,128,192,256] 996 | allmaps={hl:{} for hl in hash_lengths} 997 | for hl in hash_lengths: 998 | for mnm in mnames: 999 | allmaps[hl][fmlname[mnm]]=[] 1000 | possible=[f for f in allfiles if mnm+str(hl)+'_' in f] 1001 | for fnm in possible: 1002 | f=open(foldername+fnm,'r') 1003 | allmaps[hl][fmlname[mnm]].append(float(f.read())) 1004 | return allmaps 1005 | 1006 | 1007 | if __name__=='__main__': 1008 | 1009 | data=Dataset('mnist') 1010 | input_dim=784 #d 1011 | max_index=10000 1012 | sampling_ratio=0.10 1013 | nnn=200 #number of nearest neighbours to compare, 2% of max_index as in paper 1014 | hash_lengths=[2,4,8,12,16,20,24,28,32] 1015 | inputs_=data.data.train.images[:max_index] 1016 | all_MAPs={} 1017 | for hash_length in hash_lengths: #k 1018 | embedding_size= int(20*hash_length) #int(10*input_dim) #20k or 10d 1019 | all_MAPs[hash_length]={} 1020 | all_MAPs[hash_length]['Fly']=[] 1021 | all_MAPs[hash_length]['LSH']=[] 1022 | for _ in range(10): 1023 | fly_model=flylsh(inputs_,hash_length,sampling_ratio,embedding_size) 1024 | fly_mAP=fly_model.findmAP(nnn,1000) 1025 | msg='mean average precision is equal to {:.2f}'.format(fly_mAP) 1026 | #_=os.system('say "'+msg+'"') #works only on mac 1027 | all_MAPs[hash_length]['Fly'].append(fly_mAP) 1028 | 1029 | dense_model=LSH(inputs_,hash_length) 1030 | dense_mAP=dense_model.findmAP(nnn,1000) 1031 | all_MAPs[hash_length]['LSH'].append(dense_mAP) 1032 | msg='mean average precision is equal to {:.2f}'.format(dense_mAP) 1033 | #_=os.system('say "'+msg+'"') #works only on mac 1034 | print('Both models ran successfully') 1035 | print(f'{hash_length} done') 1036 | 1037 | print(all_MAPs) --------------------------------------------------------------------------------
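The `__main__` block above stops at printing the raw `all_MAPs` dictionary. As a minimal follow-on sketch (not part of the original file), the same dictionary can be passed to the `plot_results` helper defined earlier in `lshutils.py` to get the Bokeh bar chart of mean mAP with stdev whiskers; the dataset label 'MNIST' and the output location below are assumptions:

# Hypothetical follow-up to the __main__ block: visualise all_MAPs with the
# plot_results helper from this file. With these arguments it should write
# Fly_LSH_MNIST.html to the current directory and open it in a browser.
plot_results(all_MAPs, hash_lengths=hash_lengths, keys=['Fly', 'LSH'],
             name='MNIST', location='./', metric='mAP')

Because `all_MAPs` already has the `{hash_length: {model: [trial scores]}}` layout that `plot_results` expects, no reshaping is needed before plotting.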