├── .gitignore ├── README.md ├── img ├── n_pair_angular_loss_5000.png ├── n_pair_angular_loss_S.png ├── n_pair_loss_5000.png └── n_pair_loss_S.png ├── requirements.txt └── src ├── models └── CNN_3.py ├── modules ├── Dataset.py ├── Loss.py └── Sampler.py ├── n_pair_train.py ├── t_SNE.py └── utils └── mnist_to_img_pytorch.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | checkpoints 4 | datasets 5 | venv 6 | logs 7 | n_plus_1_train.py 8 | Sampling.py 9 | dousa.py 10 | Models.py 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Metric Learning ([npair loss](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf) & [angular loss](https://arxiv.org/pdf/1708.01682.pdf)) on MNIST, with t-SNE visualization 2 | 3 | n_pair_loss|n_pair_angular_loss 4 | ---|--- 5 | ![](img/n_pair_loss_S.png)|![](img/n_pair_angular_loss_S.png) 6 | 7 | 8 | 9 | ## Usage 10 | Run the following from the repository root: 11 | `pip install -r requirements.txt` 12 | `python src/utils/mnist_to_img_pytorch.py` -- saves the MNIST training images as PNG files, one directory per label 13 | `python src/n_pair_train.py` -- saves the model to `checkpoints/<YYYYMMDD>_<NN>_CNN_3/checkpoint.pth.tar` and TensorBoard logs to `logs/<YYYYMMDD>_<NN>_CNN_3/` 14 | `python src/t_SNE.py` -- prints nearest-prototype accuracy and shows a t-SNE plot of the embeddings 15 | -------------------------------------------------------------------------------- /img/n_pair_angular_loss_5000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomp11/metric_learning/2a9ce38522d779e98613189db97ef539fa96ddf0/img/n_pair_angular_loss_5000.png -------------------------------------------------------------------------------- /img/n_pair_angular_loss_S.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomp11/metric_learning/2a9ce38522d779e98613189db97ef539fa96ddf0/img/n_pair_angular_loss_S.png -------------------------------------------------------------------------------- /img/n_pair_loss_5000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomp11/metric_learning/2a9ce38522d779e98613189db97ef539fa96ddf0/img/n_pair_loss_5000.png -------------------------------------------------------------------------------- /img/n_pair_loss_S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomp11/metric_learning/2a9ce38522d779e98613189db97ef539fa96ddf0/img/n_pair_loss_S.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | cycler==0.10.0 3 | filelock==3.0.12 4 | future==0.17.1 5 | grpcio==1.22.0 6 | joblib==0.13.2 7 | kiwisolver==1.1.0 8 | Markdown==3.1.1 9 | matplotlib==3.1.0 10 | numpy==1.16.4 11 | Pillow==6.0.0 12 | protobuf==3.7.1 13 | pyparsing==2.4.0 14 | python-dateutil==2.8.0 15 | scikit-learn==0.21.2 16 | scipy==1.3.0 17 | six==1.12.0 18 | sklearn==0.0 19 | tb-nightly==1.15.0a20190711 20 | torch==1.1.0 21 | torchvision==0.3.0 22 | tqdm==4.32.2 23 | typing==3.6.6 24 | typing-extensions==3.6.6 25 | Werkzeug==0.15.4 26 | -------------------------------------------------------------------------------- /src/models/CNN_3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | 5 | class CNN_3(nn.Module): 6 | def __init__(self): 7 | super(CNN_3, self).__init__() 8 | self.conv1 = nn.Conv2d(1, 32, 3) # 28x28x32 -> 26x26x32 9 | self.conv2 = nn.Conv2d(32, 64, 3) # 26x26x64 -> 24x24x64 10 | self.pool = nn.MaxPool2d(2, 
2) # 24x24x64 -> 12x12x64 11 | self.dropout1 = nn.Dropout2d() 12 | self.fc1 = nn.Linear(12 * 12 * 64, 128) 13 | 14 | 15 | def forward(self, x): 16 | x = F.relu(self.conv1(x)) 17 | x = self.pool(F.relu(self.conv2(x))) 18 | x = self.dropout1(x) 19 | x = x.view(-1, 12 * 12 * 64) 20 | x = F.relu(self.fc1(x)) 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /src/modules/Dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import os.path 4 | 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | 9 | 10 | def n_pair_dataset(data_path, transform): 11 | image_dataset = datasets.ImageFolder(data_path, transform) 12 | return image_dataset 13 | 14 | def default_image_loader(path): 15 | return Image.open(path) 16 | 17 | class N_Pair_ImageDataset(torch.utils.data.Dataset): 18 | def __init__(self, base_path, filenames_filename, n_pair_file_name, transform, 19 | loader=default_image_loader): 20 | self.base_path = base_path 21 | self.filenamelist = [] 22 | for line in open(filenames_filename): 23 | self.filenamelist.append(line.rstrip('\n')) 24 | paths = [] 25 | for line in open(n_pair_file_name): 26 | paths.append(([i for i in line.split(",")[0].split()], [i for i in line.split(",")[1].split()])) # ([anchors],[positives]) 27 | self.paths = paths 28 | self.transform = transform 29 | self.loader = loader 30 | def __getitem__(self, index): 31 | def path2img(path): 32 | img = self.loader(os.path.join(self.base_path,self.filenamelist[int(path)])) 33 | # print(img.getextrema()) 34 | 35 | return img 36 | # anchor_imgs = np.array([]) 37 | # positives_imgs = np.array([]) 38 | # for data in self.paths[index]: 39 | # print(data) 40 | # anchor_imgs.append(self.transform(path2img())) 41 | # positives_imgs.append(self.transform(path2img())) 42 | 43 | anchor_imgs = [self.transform(path2img(path)) for path in 
self.paths[index][0]] 44 | # print(anchor_imgs) 45 | positives_imgs = [self.transform(path2img(path)) for path in self.paths[index][1]] 46 | # print(anchor_imgs) 47 | anchor_imgs , positives_imgs = torch.stack(anchor_imgs), torch.stack(positives_imgs) 48 | return anchor_imgs, positives_imgs 49 | def __len__(self): 50 | return len(self.paths) 51 | 52 | 53 | class N_plus_1_ImageDataset(torch.utils.data.Dataset): 54 | def __init__(self, base_path, filenames_filename, n_plus_1_file_name, transform, 55 | loader=default_image_loader): 56 | self.base_path = base_path 57 | self.filenamelist = [] 58 | for line in open(filenames_filename): 59 | self.filenamelist.append(line.rstrip('\n')) 60 | paths = [] 61 | for line in open(n_plus_1_file_name): 62 | paths.append((line.split()[0], line.split()[1], line.split()[2:])) # (anchor,positive,[negatives]) 63 | self.paths = paths 64 | self.transform = transform 65 | self.loader = loader 66 | def __getitem__(self, index): 67 | 68 | def path2img(path): 69 | img = self.loader(os.path.join(self.base_path,self.filenamelist[int(path)])) 70 | return img 71 | 72 | anchor_img = self.transform(path2img(self.paths[index][0]))# [RGB, 224, 224] 73 | # anchor_img = torch.unsqueeze(anchor_img, 0)# [1, RGB, 224, 224] 74 | # print(anchor_img.size()) 75 | positives_img = self.transform(path2img(self.paths[index][1]))# [RGB, 224, 224] 76 | # positives_img = torch.unsqueeze(positives_img, 0)# [1, RGB, 224, 224] 77 | negatives_imgs = [self.transform(path2img(path)) for path in self.paths[index][2]] 78 | negatives_imgs = torch.stack(negatives_imgs)# [N, RGB, 224, 224] 79 | # print(torch.stack(negatives_imgs).size()) 80 | # negatives_imgs = torch.squeeze(torch.stack(negatives_imgs)) 81 | 82 | return anchor_img, positives_img, negatives_imgs 83 | def __len__(self): 84 | return len(self.paths) 85 | -------------------------------------------------------------------------------- /src/modules/Loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Angular_mc_loss(nn.Module): 8 | def __init__(self, alpha=45, in_degree=True): 9 | super(Angular_mc_loss, self).__init__() 10 | if in_degree: 11 | alpha = np.deg2rad(alpha) 12 | self.sq_tan_alpha = np.tan(alpha) ** 2 13 | 14 | def forward(self, embeddings, target, with_npair=True, lamb=2): 15 | n_pairs = self.get_n_pairs(target) 16 | n_pairs = n_pairs.cuda() 17 | f = embeddings[n_pairs[:, 0]] 18 | f_p = embeddings[n_pairs[:, 1]] 19 | # print(f, f_p) 20 | term1 = 4 * self.sq_tan_alpha * torch.matmul(f + f_p, torch.transpose(f_p, 0, 1)) 21 | term2 = 2 * (1 + self.sq_tan_alpha) * torch.sum(f * f_p, keepdim=True, dim=1) 22 | f_apn = term1 - term2 23 | mask = torch.ones_like(f_apn) - torch.eye(len(f)).cuda() 24 | f_apn = f_apn * mask 25 | loss = torch.mean(torch.logsumexp(f_apn, dim=1)) 26 | if with_npair: 27 | loss_npair = self.n_pair_mc_loss(f, f_p) 28 | # print(loss, loss_npair) 29 | loss = loss_npair + lamb*loss 30 | # Preventing overflow 31 | # with torch.no_grad(): 32 | # t = torch.max(x, dim=2)[0] # (batch_size, 1) 33 | # print(t.size()) 34 | # 35 | # x = torch.exp(x - t.unsqueeze(dim=1)) 36 | # x = torch.log(torch.exp(-t) + torch.sum(x, 2)) 37 | # loss = torch.mean(t + x) 38 | return loss 39 | 40 | @staticmethod 41 | def get_n_pairs(labels): 42 | """ 43 | Get index of n-pairs and n-negatives 44 | :param labels: label vector of mini-batch 45 | :return: A tuple of n_pairs (n, 2) 46 | """ 47 | labels = labels.cpu().data.numpy() 48 | n_pairs = [] 49 | for label in set(labels): 50 | label_mask = (labels == label) 51 | label_indices = np.where(label_mask)[0] 52 | if len(label_indices) < 2: 53 | continue 54 | anchor, positive = np.random.choice(label_indices, 2, replace=False) 55 | n_pairs.append([anchor, positive]) 56 | n_pairs = np.array(n_pairs) 57 | return 
torch.LongTensor(n_pairs) 58 | 59 | @staticmethod 60 | def n_pair_mc_loss(f, f_p): 61 | n_pairs = len(f) 62 | term1 = torch.matmul(f, torch.transpose(f_p, 0, 1)) 63 | term2 = torch.sum(f * f_p, keepdim=True, dim=1) 64 | f_apn = term1 - term2 65 | mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda() 66 | f_apn = f_apn * mask 67 | return torch.mean(torch.logsumexp(f_apn, dim=1)) 68 | 69 | class n_pair_mc_loss(nn.Module): 70 | def __init__(self): 71 | super(n_pair_mc_loss, self).__init__() 72 | 73 | def forward(self, f, f_p): 74 | n_pairs = len(f) 75 | term1 = torch.matmul(f, torch.transpose(f_p, 0, 1)) 76 | term2 = torch.sum(f * f_p, keepdim=True, dim=1) 77 | f_apn = term1 - term2 78 | mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda() 79 | f_apn = f_apn * mask 80 | return torch.mean(torch.logsumexp(f_apn, dim=1)) 81 | 82 | 83 | class N_plus_1_angularLoss(nn.Module): 84 | """ 85 | Angular loss 86 | Wang, Jian. "Deep Metric Learning with Angular Loss," CVPR, 2017 87 | https://arxiv.org/pdf/1708.01682.pdf 88 | """ 89 | 90 | def __init__(self, l2_reg=0.02, angle_bound=1., lambda_ang=2): 91 | super(my_AngularLoss, self).__init__() 92 | self.l2_reg = l2_reg 93 | self.angle_bound = angle_bound 94 | self.lambda_ang = lambda_ang 95 | self.softplus = nn.Softplus() 96 | 97 | def forward(self, anchors, positives, negatives): 98 | 99 | losses = self.angular_loss(anchors, positives, negatives, self.angle_bound) + self.l2_reg * self.l2_loss(anchors, positives) 100 | 101 | return losses 102 | 103 | @staticmethod 104 | def angular_loss(anchors, positives, negatives, angle_bound=1.): 105 | """ 106 | Calculates angular loss 107 | :param anchors: A torch.Tensor, (n, embedding_size) 108 | :param positives: A torch.Tensor, (n, embedding_size) 109 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 110 | :param angle_bound: tan^2 angle 111 | :return: A scalar 112 | """ 113 | 114 | anchors = torch.unsqueeze(anchors, dim=1) # (batch_size, 1, embedding_size) 115 | 
positives = torch.unsqueeze(positives, dim=1) # (batch_size, 1, embedding_size) 116 | batch_size = anchors.size()[0] 117 | negatives = [negatives[i*5:(i+1)*5] for i in range(batch_size)] 118 | negatives = torch.stack(negatives)# (batch_size, n-1, embedding_size) 119 | 120 | anchors, positives, negatives = anchors.cuda(), positives.cuda(), negatives.cuda() 121 | 122 | x = 4. * angle_bound * torch.matmul((anchors + positives), negatives.transpose(1, 2)) - 2. * (1. + angle_bound) * torch.matmul(anchors, positives.transpose(1, 2)) # (n, 1, n-1) 123 | 124 | print(x.size()) 125 | # Preventing overflow 126 | with torch.no_grad(): 127 | t = torch.max(x, dim=2)[0] # (batch_size, 1) 128 | print(t.size()) 129 | 130 | x = torch.exp(x - t.unsqueeze(dim=1)) 131 | x = torch.log(torch.exp(-t) + torch.sum(x, 2)) 132 | loss = torch.mean(t + x) 133 | 134 | return loss 135 | 136 | @staticmethod 137 | def l2_loss(anchors, positives): 138 | """ 139 | Calculates L2 norm regularization loss 140 | :param anchors: A torch.Tensor, (n, embedding_size) 141 | :param positives: A torch.Tensor, (n, embedding_size) 142 | :return: A scalar 143 | """ 144 | return torch.sum(anchors ** 2 + positives ** 2) / anchors.shape[0] 145 | 146 | 147 | 148 | class N_plus_1_Loss(nn.Module): 149 | """ 150 | N-Pair loss 151 | Sohn, Kihyuk. "Improved Deep Metric Learning with Multi-class N-pair Loss Objective," Advances in Neural Information 152 | Processing Systems. 2016. 
153 | http://papers.nips.cc/paper/6199-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective 154 | """ 155 | 156 | def __init__(self, l2_reg=0.02): 157 | super(NPairLoss, self).__init__() 158 | self.l2_reg = l2_reg 159 | 160 | def forward(self, anchors, positives, negatives): 161 | """ 162 | anchors (batch_size, embedding_size) 163 | positives (batch_size, embedding_size) 164 | negatives (batch_size*(n-1), embedding_size) 165 | """ 166 | batch_size = anchors.size()[0] 167 | negatives = [negatives[i*5:(i+1)*5] for i in range(batch_size)] 168 | negatives = torch.stack(negatives)# (batch_size, n-1, embedding_size) 169 | 170 | # print(anchors) 171 | anchors, positives, negatives = anchors.cuda(), positives.cuda(), negatives.cuda() 172 | losses = self.n_pair_loss(anchors, positives, negatives) \ 173 | + self.l2_reg * self.l2_loss(anchors, positives) 174 | # print(self.n_pair_loss(anchors, positives, negatives), self.l2_reg * self.l2_loss(anchors, positives)) 175 | return losses 176 | 177 | 178 | @staticmethod 179 | def n_pair_loss(anchors, positives, negatives): 180 | """ 181 | Calculates N-Pair loss 182 | :param anchors: A torch.Tensor, (n, embedding_size) 183 | :param positives: A torch.Tensor, (n, embedding_size) 184 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 185 | :return: A scalar 186 | """ 187 | anchors = torch.unsqueeze(anchors, dim=1) # (n, 1, embedding_size) 188 | positives = torch.unsqueeze(positives, dim=1) # (n, 1, embedding_size) 189 | 190 | x = torch.matmul(anchors, (negatives - positives).transpose(1, 2)) # (n, 1, n-1) 191 | x = torch.sum(torch.exp(x), 2) # (n, 1) 192 | loss = torch.mean(torch.log(1+x)) 193 | return loss 194 | 195 | @staticmethod 196 | def l2_loss(anchors, positives): 197 | """ 198 | Calculates L2 norm regularization loss 199 | :param anchors: A torch.Tensor, (n, embedding_size) 200 | :param positives: A torch.Tensor, (n, embedding_size) 201 | :return: A scalar 202 | """ 203 | return torch.sum(anchors ** 2 
+ positives ** 2) / anchors.shape[0] 204 | -------------------------------------------------------------------------------- /src/modules/Sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original source: https://github.com/adambielski/siamese-triplet 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.sampler import BatchSampler 10 | 11 | 12 | class BalancedBatchSampler(BatchSampler): 13 | """ 14 | BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples. 15 | Returns batches of size n_classes * n_samples 16 | """ 17 | 18 | def __init__(self, dataset, n_classes, n_samples): 19 | loader = DataLoader(dataset) 20 | self.labels_list = [] 21 | for _, label in loader: 22 | self.labels_list.append(label) 23 | self.labels = torch.LongTensor(self.labels_list) 24 | self.labels_set = list(set(self.labels.numpy())) 25 | self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] 26 | for label in self.labels_set} 27 | for l in self.labels_set: 28 | np.random.shuffle(self.label_to_indices[l]) 29 | self.used_label_indices_count = {label: 0 for label in self.labels_set} 30 | self.count = 0 31 | self.n_classes = n_classes 32 | self.n_samples = n_samples 33 | self.dataset = dataset 34 | self.batch_size = self.n_samples * self.n_classes 35 | 36 | def __iter__(self): 37 | self.count = 0 38 | while self.count + self.batch_size < len(self.dataset): 39 | classes = np.random.choice(self.labels_set, self.n_classes, replace=False) 40 | indices = [] 41 | for class_ in classes: 42 | indices.extend(self.label_to_indices[class_][ 43 | self.used_label_indices_count[class_]:self.used_label_indices_count[ 44 | class_] + self.n_samples]) 45 | self.used_label_indices_count[class_] += self.n_samples 46 | if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): 47 | 
np.random.shuffle(self.label_to_indices[class_]) 48 | self.used_label_indices_count[class_] = 0 49 | yield indices 50 | self.count += self.n_classes * self.n_samples 51 | 52 | def __len__(self): 53 | return len(self.dataset) // self.batch_size 54 | -------------------------------------------------------------------------------- /src/n_pair_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import glob 4 | from PIL import Image 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.model_zoo as model_zoo 8 | import torch.nn.functional as F 9 | from torchvision import datasets, transforms 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | import torchvision.models as models 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.utils.data.dataset import Subset 15 | 16 | from modules.Loss import Angular_mc_loss, Angular_mc_loss, N_plus_1_Loss, n_pair_mc_loss 17 | from modules.Sampler import BalancedBatchSampler 18 | from models.CNN_3 import CNN_3 19 | 20 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # "metric_test" 21 | 22 | traindata_path = os.path.join(base_path, "datasets", "mnist") # "metric_test\datasets\mnist" 23 | transform = transforms.Compose([ 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.1307,), (0.3081,)) 26 | ]) 27 | 28 | # val_dataset = datasets.MNIST(root=testdata_path, train=False, download=True, transform=transforms) 29 | 30 | #ImageFolderのdefaultloaderだとmnistなのに3,28,28だったのでpillowのloader使う 31 | def image_loader(path): 32 | return Image.open(path) 33 | datasets = datasets.ImageFolder(traindata_path, transform, loader=image_loader) 34 | 35 | train_size = len(datasets)*9//10 36 | val_size = len(datasets) - train_size 37 | train_dataset, val_dataset = torch.utils.data.random_split(datasets, [train_size, val_size]) 38 | # subsetはスライスでとるのでimagefolderはラベル順に取り込んでいるからラベルが偏る 39 | # 
random_splitはランダム 40 | # train_dataset = Subset(datasets, list(range(train_size))) 41 | # val_dataset = Subset(datasets, list(range(train_size, len(datasets)))) 42 | train_batch_sampler = BalancedBatchSampler(train_dataset, n_classes=10, n_samples=8) 43 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler) 44 | val_batch_sampler = BalancedBatchSampler(val_dataset, n_classes=10, n_samples=8) 45 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_sampler=val_batch_sampler) 46 | 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | model = CNN_3().to(device) 49 | criterion = Angular_mc_loss() 50 | optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9) 51 | 52 | 53 | log_base_path = os.path.join(base_path, "logs") 54 | dt = datetime.datetime.now() 55 | model_id = len(glob.glob(os.path.join(log_base_path, "{}{}{}*".format(dt.year, dt.month, dt.day)))) 56 | log_dir_name = "{}{:02}{:02}_{:02}_{}".format(dt.year, dt.month, dt.day, model_id, model.__class__.__name__) 57 | log_path = os.path.join(log_base_path, log_dir_name) 58 | writer = SummaryWriter(log_dir=log_path) 59 | 60 | epochs = 20 61 | 62 | 63 | def adjust_learning_rate(optimizer, epoch): 64 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 65 | lr = args.lr * (0.1 ** (epoch // 30)) 66 | for param_group in optimizer.param_groups: 67 | param_group['lr'] = lr 68 | 69 | def train(epoch): 70 | for batch_idx, (data, target) in enumerate(train_loader): 71 | model.train() 72 | optimizer.zero_grad() 73 | data, target = data.cuda(), target.cuda() 74 | embedded = model(data) 75 | loss = criterion(embedded, target) 76 | optimizer.zero_grad() 77 | loss.backward() 78 | optimizer.step() 79 | writer.add_scalar("loss/train_loss", loss.item(), (len(train_loader)*(epoch-1)+batch_idx)) # 675*e+i 80 | 81 | if batch_idx % 20 == 0: 82 | #validation 83 | model.eval() 84 | with torch.no_grad(): 85 | val_losses = 0.0 86 | 
for idx, (data, target) in enumerate(val_loader): 87 | data, target = data.cuda(), target.cuda() 88 | embedded = model(data) 89 | val_loss = criterion(embedded, target) 90 | val_losses += val_loss 91 | mean_val_loss = val_losses/len(val_loader) 92 | writer.add_scalar("loss/val_loss", loss.item(), (len(train_loader)*(epoch-1)+batch_idx)) 93 | print('Train Epoch: {:>3} [{:>5}/{:>5} ({:>3.0f}%)]\ttrain_loss: {:>2.4f}\tval_loss: {:>2.4f}'.format( 94 | epoch, 95 | batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), 96 | loss.item(), val_loss)) 97 | 98 | 99 | 100 | 101 | def save(epoch): 102 | checkpoint_path = os.path.join(base_path, "checkpoints") 103 | save_file = "checkpoint.pth.tar" 104 | if not os.path.exists(checkpoint_path): 105 | os.makedirs(checkpoint_path) 106 | if not os.path.exists(os.path.join(checkpoint_path, log_dir_name)): 107 | os.makedirs(os.path.join(checkpoint_path, log_dir_name)) 108 | save_path = os.path.join(checkpoint_path, log_dir_name, save_file) 109 | torch.save(model.state_dict(), save_path) 110 | 111 | 112 | if __name__ == "__main__": 113 | for epoch in range(1, epochs+1): 114 | train(epoch) 115 | save(epoch) 116 | -------------------------------------------------------------------------------- /src/t_SNE.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import numpy as np 4 | from sklearn.manifold import TSNE 5 | import matplotlib.pyplot as plt 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torchvision import datasets, transforms 12 | from sklearn.decomposition import PCA 13 | 14 | from models.CNN_3 import CNN_3 15 | 16 | 17 | 18 | def test(t_SNE=True): 19 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 20 | testdata_path = os.path.join(base_path, "datasets", "mnist_test") 21 | device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | test_set = datasets.MNIST(root=testdata_path, train=False, download=True, transform=transforms.Compose([ 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.1307,), (0.3081,)) 26 | ])) 27 | 28 | test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=5000, shuffle=False) 29 | # 1バッチ分だけ取り出し 30 | # testセット10000の内後半の5000取り出し 31 | data_iter = iter(test_loader) 32 | _, _ = data_iter.next() 33 | data, target = data_iter.next() # (5000,1,28,28), (5000) 34 | target = torch.LongTensor(target) 35 | print(data.size(), target.size()) 36 | # data = torch.stack([test_set[i][0] for i in range(5000)]) # (5000,1,28,28) 37 | # target = torch.Tensor([test_set[i][1] for i in range(5000)]) # (5000) 38 | 39 | net = CNN_3() 40 | model = net.to(device) 41 | model.load_state_dict(torch.load('./checkpoints/checkpoint.pth.tar')) 42 | classes = ["0","1","2","3","4","5","6","7","8","9"] 43 | 44 | 45 | model.eval() 46 | test_loss = 0 47 | correct = 0 48 | with torch.no_grad(): 49 | pred_categories = [] #予想ラベルたち 50 | 51 | target_array = target.numpy() 52 | master_features = [] #マスター画像の埋め込みたち 53 | for i in classes: 54 | indexes = np.where(target_array==int(i))[0] 55 | master_img = data[np.random.choice(indexes)].to(device) 56 | master_img = torch.unsqueeze(master_img, dim=0) 57 | master_img = master_img.to(device) 58 | embedded_master_img = model(master_img) 59 | master_features.append(embedded_master_img) 60 | master_features = torch.cat(master_features) # (10, 128) 61 | 62 | data = data.to(device) 63 | output = model(data) 64 | output_unbind = torch.unbind(output) 65 | for embedded_img in output_unbind: 66 | distances = torch.sum((master_features - embedded_img)**2, dim=1) #(10) 67 | pred_category = classes[distances.argmin()] 68 | pred_categories.append(int(pred_category))# 69 | pred_category = torch.LongTensor(pred_categories) 70 | # ラベルが数字だったのでtorch.Tensorにして条件文でやった。strならforぶん回す 71 | correct += (target 
== pred_category).sum() 72 | accuracy = float(correct)*100 / len(pred_categories) 73 | 74 | print('Accuracy: {}/{} ({}%)\n'.format(correct, len(pred_categories), accuracy)) 75 | 76 | if t_SNE: 77 | t_sne(output, target) 78 | 79 | 80 | def t_sne(latent_vecs, target): 81 | latent_vecs = latent_vecs.to("cpu") 82 | latent_vecs = latent_vecs.numpy() 83 | latent_vecs_reduced = TSNE(n_components=2, random_state=0).fit_transform(latent_vecs) 84 | # latent_vecs_reduced = PCA(n_components=2).fit_transform(latent_vecs) 85 | plt.scatter(latent_vecs_reduced[:, 0], latent_vecs_reduced[:, 1], 86 | c=target, cmap='jet') 87 | plt.colorbar() 88 | plt.show() 89 | 90 | if __name__ == '__main__': 91 | test() 92 | -------------------------------------------------------------------------------- /src/utils/mnist_to_img_pytorch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torchvision import datasets 5 | import numpy as np 6 | 7 | base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | dataset_path = os.path.join(base_path, "datasets") 9 | 10 | mnist_train_path = os.path.join(dataset_path, "mnist_train") 11 | metric_mnist_train_path = os.path.join(dataset_path, "metric_mnist_train") 12 | metric_mnist_val_path = os.path.join(dataset_path, "metric_mnist_val") 13 | 14 | mnist_train = datasets.MNIST(root=mnist_train_path, train=True, download=True) 15 | 16 | if not os.path.isdir(dataset_path): 17 | os.mkdir(dataset_path) 18 | if not os.path.isdir(mnist_train_path): 19 | os.mkdir(mnist_train_path) 20 | if not os.path.isdir(metric_mnist_train_path): 21 | os.mkdir(metric_mnist_train_path) 22 | if not os.path.isdir(metric_mnist_val_path): 23 | os.mkdir(metric_mnist_val_path) 24 | for i in range(10): 25 | dirname = str(i) 26 | if not os.path.isdir(os.path.join(metric_mnist_train_path, dirname)): 27 | os.mkdir(os.path.join(metric_mnist_train_path, dirname)) 28 
| if not os.path.isdir(os.path.join(metric_mnist_val_path, dirname)): 29 | os.mkdir(os.path.join(metric_mnist_val_path, dirname)) 30 | 31 | def save(data, target, savepath, index): 32 | filename = os.path.join(savepath, str(target), "train{0:04d}.png".format(index)) 33 | data.save(filename) 34 | print(os.path.join(str(target), "train{0:04d}.png".format(index))) 35 | 36 | for i in range(len(mnist_train)*9//10): 37 | data, target = mnist_train[i] 38 | save(data, target, metric_mnist_train_path, i) 39 | for i in range(len(mnist_train)*9//10, len(mnist_train)): 40 | data, target = mnist_train[i] 41 | save(data, target, metric_mnist_val_path, i) 42 | --------------------------------------------------------------------------------