# Repository layout:
# ├── .gitignore
# ├── DeepEncoderClustering.py
# ├── README.md
# ├── main.py
# ├── train.py
# └── utils
#     ├── datasetscls.py
#     ├── extractfeatures.py
#     ├── metrics.py
#     ├── params.py
#     ├── transfermodels.py
#     └── utilityfn.py
#
# /.gitignore:
# --------------------------------------------------------------------------------
# *.csv
# *.txt
# *.db
# *.docx
# *.xlsx
# *.doc
# *.pdf
# *.sav
# *.R
# *.spv
# *.sps
# *.png
# *.pth
# *.npy
#
# /DeepEncoderClustering.py:
# --------------------------------------------------------------------------------
"""Deep Embedded Clustering (DEC) building blocks.

* ``autoencoder`` - symmetric fully-connected autoencoder used for pre-training.
* ``clustering``  - Student's t soft-assignment layer holding the cluster centres.
* ``DEC``         - encoder of the autoencoder followed by the clustering layer.
"""

import math
from typing import List, Optional

import torch
import torch.nn as nn

# Release cached GPU memory from an earlier run; only touch the CUDA allocator
# when CUDA is present so importing this module works on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


class autoencoder(nn.Module):
    """Symmetric fully-connected autoencoder.

    The encoder maps ``inputsize -> dims[0] -> ... -> dims[-1]`` with a ReLU
    between consecutive Linear layers and no activation on the embedding.
    The decoder mirrors it back to ``inputsize`` with no activation on the
    reconstruction.

    Args:
        inputsize: number of input features.
        dims: hidden-layer widths; ``dims[-1]`` is the embedding size.

    Raises:
        ValueError: if ``dims`` is empty.
    """

    def __init__(
            self,
            inputsize: int,
            dims: List[int]):
        super(autoencoder, self).__init__()
        if not dims:
            raise ValueError("dims must contain at least one layer width")

        self.inputsize = inputsize

        encmodules = [nn.Linear(inputsize, dims[0])]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            encmodules.append(nn.ReLU(True))
            encmodules.append(nn.Linear(d_in, d_out))
        self.encoder = nn.Sequential(*encmodules)

        decmodules = []
        for index in range(len(dims) - 1, 0, -1):
            decmodules.append(nn.Linear(dims[index], dims[index - 1]))
            decmodules.append(nn.ReLU(True))
        decmodules.append(nn.Linear(dims[0], inputsize))
        self.decoder = nn.Sequential(*decmodules)

        self.init_weights()

    def forward(self, x):
        """Return the reconstruction of ``x`` (same shape as ``x``)."""
        return self.decoder(self.encoder(x))

    def get_encoder(self):
        """Return the encoder ``nn.Sequential``."""
        return self.encoder

    def init_weights(self, seed: int = 4):
        """Glorot-uniform initialisation of every Linear layer; biases are zeroed.

        Weights are drawn from ``U(-limit, limit)`` with
        ``limit = sqrt(6 / (fan_in + fan_out))``. A private generator seeded
        once with ``seed`` keeps the initialisation reproducible without
        resetting the global torch RNG, and gives equally shaped layers
        distinct weights.
        """
        device = next(self.parameters()).device
        gen = torch.Generator(device=device).manual_seed(seed)

        def _init(m):
            if isinstance(m, nn.Linear):
                limit = math.sqrt(6 / (m.in_features + m.out_features))
                with torch.no_grad():
                    m.weight.uniform_(-limit, limit, generator=gen)
                    if m.bias is not None:
                        m.bias.zero_()

        self.encoder.apply(_init)
        self.decoder.apply(_init)


class clustering(nn.Module):
    """Soft-assignment (clustering) layer of DEC.

    Holds ``n_clusters`` learnable centres of dimension ``input_shape`` and
    maps embedded points to a Student's t-distribution over the centres,
    the same kernel as in t-SNE.

    Args:
        n_clusters: number of clusters.
        input_shape: dimensionality of the embedded points and centres.
        alpha: degrees of freedom of the Student's t kernel.
        cluster_centers: optional initial centres of shape
            ``(n_clusters, input_shape)``; Xavier-uniform when omitted.
    """

    def __init__(
            self,
            n_clusters: int,
            input_shape: int,
            alpha: float = 1.0,
            cluster_centers: Optional[torch.Tensor] = None
    ) -> None:
        super(clustering, self).__init__()

        self.n_clusters = n_clusters
        self.alpha = alpha
        self.input_shape = input_shape

        if cluster_centers is None:
            initial_cluster_centers = torch.zeros(self.n_clusters, self.input_shape, dtype=torch.float32)
            nn.init.xavier_uniform_(initial_cluster_centers)
        else:
            initial_cluster_centers = cluster_centers
        self.clustcenters = nn.Parameter(initial_cluster_centers)

    def forward(self, inputs):
        """Student's t soft assignments.

        ``q_ij = (1 + ||x_i - u_j||^2 / alpha) ** (-(alpha + 1) / 2)``,
        normalised so that every row sums to 1.

        Args:
            inputs: embedded samples, shape ``(n_samples, n_features)``.

        Returns:
            Soft labels of shape ``(n_samples, n_clusters)``.
        """
        # (n_samples, 1, d) - (n_clusters, d) broadcasts to (n_samples, n_clusters, d).
        sq_dist = torch.sum((inputs.unsqueeze(1) - self.clustcenters) ** 2, dim=2)
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        q = q ** ((self.alpha + 1.0) / 2.0)
        return q / torch.sum(q, dim=1, keepdim=True)

    @staticmethod
    def target_distribution(q):
        """Auxiliary target ``p_ij ∝ q_ij^2 / f_j`` with ``f_j = sum_i q_ij``; rows sum to 1."""
        weight = q ** 2 / q.sum(0)
        return (weight.T / weight.sum(1)).T


class DEC(nn.Module):
    """Deep Embedded Clustering model.

    ``self.AE`` keeps the full autoencoder (used for pre-training), while
    ``self.model`` chains only its encoder with ``self.clustlayer``, so
    ``forward`` returns soft assignments of shape ``(n_samples, n_clusters)``.

    Args:
        dims: autoencoder hidden widths; ``dims[-1]`` is the embedding size.
        inputsize: number of input features.
        n_clusters: number of clusters.
    """

    def __init__(
            self,
            dims: List[int],
            inputsize: int,
            n_clusters: int):
        super(DEC, self).__init__()
        self.AE = autoencoder(inputsize, dims)
        self.clustlayer = clustering(n_clusters, dims[-1])

        self.model = nn.Sequential(
            self.AE.encoder,
            self.clustlayer)

    def forward(self, inputs):
        """Return the soft cluster assignments of ``inputs``."""
        return self.model(inputs)

# /README.md:
# --------------------------------------------------------------------------------
# # DEC_Pytorch_tutorial
#
# This is a PyTorch implementation of a modified version of the Deep Embedded
# Clustering (DEC) algorithm described in "Unsupervised Deep Embedding for
# Clustering Analysis" by Junyuan Xie, Ross Girshick and Ali Farhadi
# (https://arxiv.org/abs/1511.06335).
#
# Here is the link to the Medium article that explains the code:
# https://medium.com/p/bd2c9d51c80f
#
# /main.py:
# --------------------------------------------------------------------------------
"""Train a DEC model on pre-extracted feature files.

Pipeline: pre-train (or load) the autoencoder, initialise the cluster centres
with k-means on the encoder output, then refine encoder and centres with the
KL-divergence self-training loop until the fraction of changed labels between
epochs drops below ``--tol``.
"""

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans

import DeepEncoderClustering
from train import pretraining, testing, training
from utils import datasetscls, metrics, params
from utils.utilityfn import getinputsize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Release cached GPU memory from an earlier run; skipped on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` would turn "False" into True)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _split_xy(filexy):
    """Return ``(x, y)`` from a generator item, both moved to ``device``.

    Items are either a bare feature tensor or an ``(x, y)`` tuple whose ``y``
    may be None (unlabelled data); ``y`` is None in both unlabelled cases.
    """
    if isinstance(filexy, tuple):
        filex, filey = filexy
    else:
        filex, filey = filexy, None
    filex = filex.to(device)
    if filey is not None:
        filey = filey.to(device)
    return filex, filey


if __name__ == "__main__":
    # setting the hyper parameters
    parser = argparse.ArgumentParser(description='train',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--datadir', default='/data/stl/fc1/')
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--input_size', default=None, type=int)
    parser.add_argument('--n_clusters', default=10, type=int)
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--pretrain_epochs', default=10, type=int)
    parser.add_argument('--update_interval', default=30, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    parser.add_argument('--save_dir', default='./save/')
    parser.add_argument('--save_model', default=True, type=_str2bool)
    parser.add_argument('--save_intermodel', default=False, type=_str2bool)

    args = parser.parse_args()
    print(args)

    # generator = datasetscls.customGenerator(args.datadir)
    generator = datasetscls.STLGenerator(args.datadir)
    if len(generator) < 1:
        raise ValueError('There should be at least one input file.')

    if not args.input_size:
        args.input_size = getinputsize(generator)

    DEC = DeepEncoderClustering.DEC(inputsize=args.input_size, dims=params.dims, n_clusters=args.n_clusters)
    DEC.to(device)

    ae_dir = os.path.join(args.save_dir, 'models', 'stl')
    os.makedirs(ae_dir, exist_ok=True)
    ae_path = os.path.join(ae_dir, 'ae_weights.pth')

    if not os.path.exists(ae_path):
        pretraining(model=DEC, dbgenerator=generator, savepath=ae_dir,
                    batch_size=args.batch_size, epochs=args.pretrain_epochs, device=device)
    else:
        # map_location lets GPU-saved weights load on a CPU-only host.
        DEC.AE.load_state_dict(torch.load(ae_path, map_location=device))

    DEC.train()  # Set model to training mode

    print('Initializing cluster centers with k-means. number of clusters %s' % args.n_clusters)
    with torch.no_grad():
        allfeatures = []
        for filexy in generator:
            filex, _ = _split_xy(filexy)
            allfeatures.append(DEC.AE.encoder(filex).cpu())

    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred_last = kmeans.fit_predict(torch.cat(allfeatures).numpy())

    clustcenters = torch.tensor(kmeans.cluster_centers_, dtype=torch.float, device=device)
    with torch.no_grad():
        DEC.clustlayer.clustcenters.copy_(clustcenters)

    criterion = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.SGD(DEC.model.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(args.epochs):
        loss = 0
        filey = None
        for filexy in generator:
            filex, filey = _split_xy(filexy)
            train_loss = training(model=DEC, optimizer=optimizer, criterion=criterion,
                                  y_pred_last=y_pred_last, x=filex, y=filey,
                                  batch_size=args.batch_size, update_interval=args.update_interval,
                                  device=device)
            loss += train_loss
        loss = np.round(loss / len(generator), 5)

        if filey is not None:
            y_pred, acty = testing(model=DEC, dbgenerator=generator, device=device)
            y_true = acty.cpu().numpy().reshape(-1)
            y_hat = y_pred.cpu().numpy()
            acc = np.round(metrics.acc(y_true, y_hat), 5)
            nmi = np.round(metrics.nmi(y_true, y_hat), 5)
            ari = np.round(metrics.ari(y_true, y_hat), 5)
            print('epoch %d: acc = %.5f, nmi = %.5f, ari = %.5f' % (epoch, acc, nmi, ari), ' ; loss=', loss)
        else:
            y_pred = testing(model=DEC, dbgenerator=generator, device=device, return_truth=False)
            y_hat = y_pred.cpu().numpy()
            # Without ground truth, agreement with the previous epoch is reported instead.
            nmi = np.round(metrics.nmi(y_pred_last, y_hat), 5)
            ari = np.round(metrics.ari(y_pred_last, y_hat), 5)
            print('epoch %d: nmi = %.5f, ari = %.5f' % (epoch, nmi, ari), ' ; loss=', loss)

        # Fraction of samples whose hard label changed since the previous epoch.
        delta_label = np.sum(y_pred_last != y_hat) / y_hat.shape[0]
        if args.tol is not None and delta_label < args.tol:
            print('delta_label ', delta_label, '< tol ', args.tol)
            print('Reached tolerance threshold. Stopping training.')
            break

        y_pred_last = y_hat

        if args.save_intermodel:
            torch.save(DEC.state_dict(), f'{args.save_dir}/models/dec_weights_%s_epoch%s.pth' % (args.n_clusters, epoch))

    if args.save_model:
        torch.save(DEC.state_dict(), f'{args.save_dir}/models/dec_weights_%s.pth' % (args.n_clusters))

# /train.py:
# --------------------------------------------------------------------------------
"""Training loops for DEC: autoencoder pre-training, KL self-training, inference."""

import os

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler

import utils.datasetscls as datasetscls
import utils.metrics as metrics


def pretraining(
        model: torch.nn.Module,
        dbgenerator: object,
        batch_size: int = 256,
        epochs: int = 10,
        savepath: str = './save/models/',
        device='cuda:0',
        savemodel: bool = True
):
    """Pre-train ``model.AE`` as a plain autoencoder with an MSE reconstruction loss.

    Args:
        model: DEC model exposing the autoencoder as ``model.AE``.
        dbgenerator: indexable/iterable of feature tensors or ``(x, y)`` tuples;
            labels are ignored.
        batch_size: mini-batch size (batches are shuffled within each file).
        epochs: number of passes over all files.
        savepath: directory that receives ``ae_weights.pth``; created if missing.
        device: device the mini-batches are moved to.
        savemodel: whether to save the autoencoder weights after training.

    Raises:
        ValueError: if ``dbgenerator`` yields no samples in an epoch.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.AE.parameters(), lr=1e-3)

    for epoch in range(epochs):
        loss = 0
        count = 0
        model.AE.train()  # Set model to training mode

        for filexy in dbgenerator:
            # Only the features matter for reconstruction.
            filex = filexy[0] if isinstance(filexy, tuple) else filexy

            dataset = datasetscls.customDataset(filex)
            dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)

            for batch in dataloader:
                x = batch.to(device)

                optimizer.zero_grad()
                # Force gradient tracking even if the caller disabled it.
                with torch.set_grad_enabled(True):
                    outputs = model.AE(x)
                    train_loss = criterion(outputs, x)
                    train_loss.backward()
                    optimizer.step()

                loss += train_loss.item()
                count += 1

        if count == 0:
            raise ValueError('dbgenerator produced no training samples')
        loss = loss / count
        print(f'epoch {epoch+1},loss = {loss:.8f}')

    if savemodel:
        os.makedirs(savepath, exist_ok=True)
        torch.save(model.AE.state_dict(), os.path.join(savepath, 'ae_weights.pth'))


def training(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        y_pred_last: np.ndarray,
        x: torch.Tensor,
        y: torch.Tensor = None,
        batch_size: int = 256,
        update_interval: int = 30,
        device: str = 'cuda:0',
        update_freq: bool = False
):
    """Run one DEC self-training pass over the samples in ``x``.

    Every ``update_interval`` mini-batches the soft assignments ``q`` of all
    of ``x`` are recomputed and the target distribution ``p`` refreshed; each
    mini-batch minimises ``criterion(log q_batch, p_batch)``.

    Args:
        model: DEC model whose ``clustlayer`` provides ``target_distribution``.
        optimizer: optimizer over the model parameters.
        criterion: loss taking log-probabilities and targets (e.g. ``KLDivLoss``).
        y_pred_last: previous hard labels, only used for progress reports.
        x: samples of one file.
        y: optional ground-truth labels of ``x``, only used for progress reports.
        batch_size: mini-batch size.
        update_interval: mini-batches between target-distribution refreshes.
        device: device the mini-batches are moved to.
        update_freq: print metrics at every target refresh (except the first).

    Returns:
        Mean mini-batch loss of this pass.

    Raises:
        ValueError: if ``x`` holds no samples.
    """
    n_samples = x.shape[0]
    if n_samples == 0:
        raise ValueError('x must contain at least one sample')
    n_batches = int(np.ceil(n_samples / batch_size))

    loss = 0
    count = 0
    for i in range(n_batches):
        if i % update_interval == 0:
            with torch.no_grad():
                q = model(x)
                p = model.clustlayer.target_distribution(q)  # update the auxiliary target distribution p
                y_pred = q.argmax(1)

            if update_freq and i != 0:
                # Separate variable: the running `loss` accumulator must not be overwritten.
                avg_loss = np.round(loss / count, 5)
                y_hat = y_pred.cpu().numpy()
                if y is not None:
                    y_true = y.cpu().numpy().reshape(-1)
                    acc = np.round(metrics.acc(y_true, y_hat), 5)
                    nmi = np.round(metrics.nmi(y_true, y_hat), 5)
                    ari = np.round(metrics.ari(y_true, y_hat), 5)
                    print('iter %d: acc = %.5f, nmi = %.5f, ari = %.5f' % (i, acc, nmi, ari), ' ; loss=', avg_loss)
                else:
                    nmi = np.round(metrics.nmi(y_pred_last, y_hat), 5)
                    ari = np.round(metrics.ari(y_pred_last, y_hat), 5)
                    print('iter %d: nmi = %.5f, ari = %.5f' % (i, nmi, ari), ' ; loss=', avg_loss)
                y_pred_last = y_hat

        start = i * batch_size
        stop = min(start + batch_size, n_samples)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            trainx = x[start:stop].to(device)
            trainy = p[start:stop].to(device)

            outputs = model(trainx)
            # KLDivLoss expects log-probabilities as input and probabilities as target.
            train_loss = criterion(outputs.log(), trainy)
            train_loss.backward()
            optimizer.step()

        loss += train_loss.item()
        count += 1

    return loss / count


def testing(
        model: torch.nn.Module,
        dbgenerator: object,
        batch_size: int = 1024,
        device: str = 'cuda:0',
        return_truth: bool = True,
):
    """Predict hard cluster labels for every sample yielded by ``dbgenerator``.

    Runs without gradient tracking and restores the model's previous
    train/eval mode afterwards.

    Args:
        model: model returning soft assignments ``(n_samples, n_clusters)``.
        dbgenerator: feature tensors or ``(x, y)`` tuples (``y`` may be None).
        batch_size: inference batch size.
        device: device the batches are moved to.
        return_truth: also return the concatenated ground-truth labels.

    Returns:
        ``(y_pred, y_true)`` if ``return_truth`` else ``y_pred``. ``y_pred``
        is the argmax of the soft assignments; ``y_true`` is cast to ``long``.

    Raises:
        ValueError: if ``return_truth`` is True but a file has no labels.
    """
    preds = []
    gtruths = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for filexy in dbgenerator:
                if isinstance(filexy, tuple):
                    filex, filey = filexy
                else:
                    filex, filey = filexy, None
                if filey is None and return_truth:
                    raise ValueError(
                        "Dataset has no ground truth to return"
                    )

                filex = filex.to(device)
                if filey is not None:
                    filey = filey.to(device)

                dataset = datasetscls.customDataset(filex, filey)
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

                for batch in dataloader:
                    if isinstance(batch, (tuple, list)) and len(batch) == 2:
                        x, y = batch
                        if return_truth:
                            gtruths.append(y)
                    else:
                        x = batch
                    preds.append(model(x.to(device)).cpu())
    finally:
        model.train(was_training)

    y_pred = torch.cat(preds).max(1)[1]
    if return_truth:
        return y_pred, torch.cat(gtruths).long()
    return y_pred

# /utils/datasetscls.py:
# --------------------------------------------------------------------------------
"""Dataset wrappers and feature-file generators."""

import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets


class customDataset(Dataset):
    """In-memory dataset over a feature tensor ``X`` and optional labels ``Y``.

    Items are ``X[idx]``, or ``(X[idx], Y[idx])`` when labels are given.

    Raises:
        ValueError: if ``Y`` is given and its length differs from ``X``.
    """

    def __init__(self, X, Y=None):
        if Y is not None and Y.shape[0] != X.shape[0]:
            raise ValueError('X and Y must have the same number of samples')
        self.X = X
        self.Y = Y
        self.n_samples = self.X.shape[0]

    def __getitem__(self, idx):
        if self.Y is not None:
            return self.X[idx], self.Y[idx]
        return self.X[idx]

    def __len__(self):
        return self.n_samples


class STLGenerator(object):
    """Indexable collection of ``*x.npy`` feature files with matching ``*y.npy`` labels.

    ``self[idx]`` returns ``(X, y)`` as float32 tensors; ``y`` is None when
    the directory tree contains no label files.
    """

    def __init__(self, datapath):
        self.filenamesx, self.filenamesy = self.getXfiles(datapath)

    def __getitem__(self, idx):
        X = torch.from_numpy(np.load(self.filenamesx[idx]).astype(np.float32))
        if self.filenamesy:
            y = torch.from_numpy(np.load(self.filenamesy[idx]).astype(np.float32))
        else:
            y = None
        return X, y

    def __len__(self):
        return len(self.filenamesx)

    @staticmethod
    def getXfiles(datapath):
        """Collect sorted feature paths and the label paths paired with them.

        Directories named ``random`` are skipped. Each ``<prefix>x.npy`` is
        paired with ``<prefix>y.npy`` by name rather than by ``os.walk``
        listing order, which is arbitrary.

        Returns:
            ``(xpaths, ypaths)``; ``ypaths`` is empty when no label files exist.

        Raises:
            ValueError: if label files exist but a feature file has no partner.
        """
        allxpath = []
        ypath_set = set()
        for root, _, files in os.walk(os.path.normpath(datapath), topdown=False):
            if os.path.basename(root) == 'random':
                continue
            for name in files:
                if name.endswith('x.npy'):
                    allxpath.append(os.path.join(root, name))
                elif name.endswith('y.npy'):
                    ypath_set.add(os.path.join(root, name))
        allxpath.sort()

        if not ypath_set:
            return allxpath, []

        allypath = []
        for xpath in allxpath:
            ypath = xpath[:-len('x.npy')] + 'y.npy'
            if ypath not in ypath_set:
                raise ValueError(f'No label file {ypath} for feature file {xpath}')
            allypath.append(ypath)
        return allxpath, allypath


class customGenerator(object):
    """Indexable collection of unlabelled ``*.npy`` feature files (sorted by path)."""

    def __init__(self, datapath):
        self.filenamesx = self.getXfiles(datapath)

    def __getitem__(self, idx):
        X = np.load(self.filenamesx[idx])
        return torch.from_numpy(X.astype(np.float32))

    def __len__(self):
        return len(self.filenamesx)

    @staticmethod
    def getXfiles(datapath):
        """Return every ``.npy`` path under ``datapath``, sorted for a deterministic order."""
        allxpath = []
        for root, _, files in os.walk(os.path.normpath(datapath), topdown=False):
            for name in files:
                if name.endswith('.npy'):
                    allxpath.append(os.path.join(root, name))
        allxpath.sort()
        return allxpath


class customImageFolder(datasets.ImageFolder):
    """``ImageFolder`` whose items also carry the image file path: ``(image, label, path)``."""

    # DataLoader calls __getitem__; append the path to ImageFolder's (image, label).
    def __getitem__(self, index):
        original_tuple = super(customImageFolder, self).__getitem__(index)
        path = self.imgs[index][0]
        return original_tuple + (path,)

# /utils/extractfeatures.py:
# --------------------------------------------------------------------------------
"""Extract features from images with a pretrained VGG16 and save them as .npy."""

import torchvision.transforms as transforms
import argparse
import os
import numpy as np
import torch
from utils.transfermodels import PytorchVGG, VGG16normalization
from utils.datasetscls import customImageFolder
from utils.utilityfn import meanstd_torchVGG16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` would turn "False" into True)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Feature extraction using Pytorch pretrained models')

parser.add_argument('--datadir', metavar='DDIR', type=str,
                    default='/home/elahe/NortfaceProject/codes/DEC-keras/results/clusters/Telegram_final/2021_AprilMayJun/VGG16/block5pool/normalisedfiltered/seeds/Stephane/',
                    help='path to dataset')
parser.add_argument('--premodel', '-a', type=str, metavar='PMODEL',
                    choices=['alexnet', 'vgg16'], default='vgg16',
                    help='Pre-trained model (default: vgg16)')
parser.add_argument('--workers', default=4, type=int,
                    help='number of data loading workers (default: 4)')
parser.add_argument('--batch', default=64, type=int,
                    help='mini-batch size (default: 64)')
parser.add_argument('--conv', default=5, type=int, help='')
parser.add_argument('--fc', default=None, type=int, help='')
parser.add_argument('--ave_pool', default=True, type=_str2bool,
                    help='')
parser.add_argument('--datasetnorm', default=True, type=_str2bool,
                    help='whether normalize the input data mean and std, or based on VGG16 statistics')
parser.add_argument('--split', default=False, type=_str2bool,
                    help='whether split the data into smaller files')
parser.add_argument('--splitsize', default=5120, type=int,
                    help='In case of splitting, the size of the split')
parser.add_argument('--savedir', metavar='SDIR', type=str,
                    default='./save',
                    help='path to saving directory')


def _write_paths(filename, image_paths):
    """Write one image path per line to ``filename``."""
    with open(filename, 'w') as f:
        f.writelines("%s\n" % item for item in image_paths)


def _extract_all(model, dataset, dataloader, opts):
    """Extract features for the whole dataset into ``featuresx.npy`` + ``paths.txt``."""
    image_paths = []
    features = None
    # discard the label information in the dataloader
    for i, (input_tensor, _, path) in enumerate(dataloader):
        image_paths.extend(path)
        aux = model(input_tensor.to(device)).cpu().numpy().astype('float32')

        if features is None:
            features = np.zeros((len(dataset), aux.shape[1]), dtype='float32')

        # The final batch may be smaller than opts.batch.
        start = i * opts.batch
        features[start:start + aux.shape[0]] = aux

        if (i % 100) == 0 and i != 0:
            print(f'{i} batch have been computed')

    np.save(opts.savedir + '/featuresx', features)
    _write_paths(opts.savedir + '/paths.txt', image_paths)
    return features, image_paths


def _save_split(features, image_paths, count, opts):
    """Save one split as ``featuresx_<k>.npy`` + ``paths_<k>.txt``, k = ceil(count / splitsize)."""
    split_id = int(np.ceil(count / opts.splitsize))
    np.save(opts.savedir + '/featuresx_%s' % split_id, features)
    _write_paths(opts.savedir + '/paths_%s.txt' % split_id, image_paths)


def _extract_split(model, dataset, dataloader, opts):
    """Extract features in chunks of ``opts.splitsize`` rows, one file pair per chunk."""
    image_paths = []
    features = None
    j = 0  # batch index inside the current split
    count = 0
    saved = True
    # discard the label information in the dataloader
    for i, (input_tensor, _, path) in enumerate(dataloader):
        saved = False
        image_paths.extend(path)
        out = model(input_tensor.to(device)).cpu().numpy().astype('float32')

        if j == 0:
            # A full split, or whatever remains of the dataset if that is smaller.
            rows = min(opts.splitsize, len(dataset) - i * opts.batch)
            features = np.zeros((rows, out.shape[1]), dtype='float32')

        start = j * opts.batch
        features[start:start + out.shape[0]] = out

        count += opts.batch
        j += 1

        if count % opts.splitsize == 0:
            print(f'{count} samples are computed')
            _save_split(features, image_paths, count, opts)
            j = 0
            image_paths = []
            saved = True

    # saving the final (partial) split
    if not saved:
        _save_split(features, image_paths, count, opts)

    return features, image_paths


def main():
    """Extract VGG16 features for every image under ``--datadir`` and save them.

    Without ``--split`` writes ``featuresx.npy`` and ``paths.txt``; with it,
    writes ``featuresx_<k>.npy`` / ``paths_<k>.txt`` chunks of ``--splitsize``
    rows. Returns the last features array written and its image paths.

    Raises:
        ValueError: if ``--split`` is set and ``--splitsize`` is not a
            multiple of ``--batch`` (the chunk bookkeeping requires it).
    """
    global args
    args = parser.parse_args()

    if args.split and args.splitsize % args.batch != 0:
        raise ValueError('--splitsize must be a multiple of --batch')

    if args.datasetnorm:
        # mean, std = meanstd_torchVGG16(args)
        mean = [0.553, 0.524, 0.506]  # [0.549, 0.525, 0.515]
        std = [0.248, 0.254, 0.243]  # [0.249, 0.243, 0.243]
        tra = VGG16normalization(mean, std)
    else:
        tra = VGG16normalization()

    # load the data
    dataset = customImageFolder(args.datadir, transform=transforms.Compose(tra))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch,
                                             num_workers=args.workers,
                                             pin_memory=True)

    os.makedirs(args.savedir, exist_ok=True)

    model = PytorchVGG(conv=args.conv, fc=args.fc, av_pool=args.ave_pool)
    model.eval()
    model.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if args.split:
            return _extract_split(model, dataset, dataloader, args)
        return _extract_all(model, dataset, dataloader, args)


if __name__ == '__main__':
    main()

# /utils/metrics.py:
# --------------------------------------------------------------------------------
"""Clustering evaluation metrics: NMI, ARI and Hungarian-matched accuracy."""

import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

nmi = normalized_mutual_info_score
ari = adjusted_rand_score


def acc(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster/label matching.

    Args:
        y_true: true integer labels, any shape with ``n_samples`` elements.
        y_pred: predicted cluster ids, same number of elements.

    Returns:
        Accuracy in [0, 1].

    Raises:
        ValueError: if the inputs differ in size or are empty.
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    if y_pred.size != y_true.size:
        raise ValueError('y_true and y_pred must contain the same number of labels')
    if y_pred.size == 0:
        raise ValueError('cannot compute the accuracy of an empty labelling')

    D = max(y_pred.max(), y_true.max()) + 1
    # w[i, j] = number of samples predicted as cluster i whose true label is j.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)

    # Maximise matched counts by minimising (max - w).
    ind = linear_assignment(w.max() - w)
    return w[ind[:, 0], ind[:, 1]].sum() * 1.0 / y_pred.size


def linear_assignment(X):
    """Solve the linear assignment problem (minimum-cost matching) for cost matrix ``X``.

    Backed by :func:`scipy.optimize.linear_sum_assignment`; ``X`` need not
    be square.

    Returns:
        ``(k, 2)`` int array of ``(row, col)`` pairs sorted by row; shape
        ``(0, 2)`` when ``X`` has a zero-length dimension.
    """
    rows, cols = linear_sum_assignment(np.atleast_2d(X))
    return np.column_stack((rows, cols)).astype(int).reshape(-1, 2)

# /utils/params.py:
# --------------------------------------------------------------------------------
dims = [500, 500, 2000, 10]

# /utils/transfermodels.py:
# --------------------------------------------------------------------------------
# NOTE(review): this file continues beyond the current view. Its visible prefix
# is kept below unchanged (commented out only so the preceding code parses).
#
# import torchvision.models as models
# import torch.nn as nn
# import copy
# import torchvision.transforms as transforms
#
#
# class PytorchVGG(nn.Module):
#     def __init__(
#             self,
#             fc=1,
#             conv=None,
#             av_pool=False
#     ):
#         super(PytorchVGG, self).__init__()
#
#         self.basemodel = models.vgg16(pretrained=True)
#         self.initiate()
#
#         self.conv = conv
#         self.fc = fc
#         self.av_pool = av_pool
#
#         if self.conv:
#             if self.conv==3:
#                 self.features = list(self.basemodel.features.children()) [:17]
#                 self.s = 200704
#                 if self.av_pool:
#                     self.features = self.features[:-1]
#                     self.features.append(nn.AvgPool2d(6, stride=6, padding=3))
#                     self.s = 25600
#
#             elif self.conv==4:
#                 self.features = list(self.basemodel.features.children()) [:24]
#                 self.s = 100352
#                 if self.av_pool:
#                     self.features = self.features[:-1]
#                     self.features.append(nn.AvgPool2d(4, stride=4, padding=0))
#                     self.s = 25088
#
#             elif self.conv==5:
#                 self.features = list(self.basemodel.features.children())
#                 self.s = 25088
#                 if self.av_pool:
#                     self.features = self.features[:-1]
#                     self.features.append(nn.AvgPool2d(2, stride=2,
padding=0)) 47 | self.s = 25088 48 | else: 49 | raise Exception(f"The Conv layer={self.conv} should be 3, 4, or 5") 50 | 51 | self.myVGG = nn.Sequential(*self.features) 52 | 53 | elif self.fc: 54 | self.myVGG = copy.deepcopy(self.basemodel) 55 | if self.fc == 1: 56 | self.myVGG.classifier = self.myVGG.classifier[:3] 57 | self.s = 4096 58 | elif self.fc == 2: 59 | self.myVGG.classifier = self.myVGG.classifier[:-1] 60 | self.s = 4096 61 | else: 62 | raise Exception(f"at least one of these arguments, fc or conv, should be specified.") 63 | 64 | 65 | def forward( 66 | self, 67 | x 68 | ): 69 | x = self.myVGG(x) 70 | if len(x.shape)>2: 71 | x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3)) 72 | return x 73 | 74 | def initiate( 75 | self 76 | ): 77 | self.basemodel.eval() 78 | for param in self.basemodel.parameters(): 79 | param.requires_grad = False 80 | 81 | 82 | def VGG16normalization( 83 | mean=None, 84 | std=None 85 | ): 86 | 87 | if not mean: 88 | mean = [0.485, 0.456, 0.406] 89 | std = [0.229, 0.224, 0.225] 90 | normalize = transforms.Normalize(mean, std) 91 | 92 | tra = [transforms.Resize(224), 93 | transforms.CenterCrop(224), 94 | transforms.ToTensor(), 95 | normalize 96 | ] 97 | return tra 98 | 99 | -------------------------------------------------------------------------------- /utils/utilityfn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import datasets 4 | import torchvision.transforms as transforms 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def meanstd_torchVGG16(args, verbose=False): 10 | tra = [transforms.Resize(224), 11 | transforms.CenterCrop(224), 12 | transforms.ToTensor() 13 | ] 14 | dataset = datasets.ImageFolder(args.datadir, transform=transforms.Compose(tra)) 15 | dataloader = torch.utils.data.DataLoader(dataset, 16 | batch_size=args.batch, 17 | num_workers=args.workers, 18 | pin_memory=True) 19 
| 20 | mean = 0. 21 | std = 0. 22 | nb_samples = 0. 23 | for _, (input_tensor, _) in enumerate(dataloader): 24 | input_tensor_ = input_tensor.detach().clone().to(device) 25 | input_tensor_ = input_tensor_.view(args.batch, input_tensor_.size()[1], -1) 26 | mean += input_tensor_.mean(2).sum(0) 27 | std += input_tensor_.std(2).sum(0) 28 | 29 | nb_samples += args.batch 30 | 31 | mean /= nb_samples 32 | std /= nb_samples 33 | 34 | if verbose: 35 | print(f'mean and std for the this dataset is {mean}, and {std}') 36 | return torch.round(mean,decimals=3), torch.round(std,decimals=3) 37 | 38 | 39 | def getinputsize(generator): 40 | for _, filexy in enumerate(generator): 41 | filex = filexy 42 | if (isinstance(filexy, tuple) or isinstance(filexy, list)) and len(filexy) == 2: 43 | filex, _ = filexy 44 | 45 | try: 46 | assert len(filex.shape) ==2 47 | except AssertionError: 48 | print("") 49 | raise 50 | 51 | inputsize=filex.shape[1] 52 | 53 | return inputsize 54 | 55 | def readseeds(datapath): 56 | features = np.load(datapath + 'featuresx.npy') 57 | features = torch.from_numpy(features.astype(np.float32)) 58 | 59 | pathfile = open(datapath+'/featuresy.txt', "r") 60 | pathlist = pathfile.readlines() 61 | pathlist = [path[:-1] for path in pathlist] 62 | pathfile.close() 63 | dictlabel = dict(zip(list(set(pathlist)), range(len(set(pathlist))))) 64 | labels = np.array([dictlabel[label] for label in pathlist]) 65 | 66 | return features, labels 67 | 68 | 69 | 70 | 71 | --------------------------------------------------------------------------------