├── .gitignore ├── Dataset ├── ImageDataLoader.py ├── JigsawImageLoader.py └── produce_small_data.py ├── JigsawNetwork.py ├── JigsawTrain.py ├── README.md ├── Utils ├── Layers.py ├── TrainingUtils.py ├── convert2h5.py └── logger.py ├── permutations_1000.npy ├── run_jigsaw_training.sh └── select_permutations.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints* 2 | *.pyc 3 | JpsTraininig 4 | run.sh 5 | 6 | -------------------------------------------------------------------------------- /Dataset/ImageDataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 18 11:58:07 2017 4 | 5 | @author: Biagio Brattoli 6 | """ 7 | import os, numpy as np 8 | from time import time 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | import torch 12 | 13 | from PIL import Image 14 | from random import shuffle 15 | 16 | def load_image(path,permutations,image_transformer,augment_tile): 17 | img = Image.open(path).convert('RGB') 18 | img = image_transformer(img) 19 | 20 | a = 75/2 21 | tiles = [None] * 9 22 | for n in range(9): 23 | i = n/3 24 | j = n%3 25 | c = [a*i*2+a,a*j*2+a] 26 | tile = img.crop((c[1]-a,c[0]-a,c[1]+a+1,c[0]+a+1)) 27 | tile = augment_tile(tile) 28 | # Normalize the patches indipendently to avoid low level features shortcut 29 | #m = tile.mean() 30 | #s = tile.std() 31 | #norm = transforms.Normalize(mean=[m, m, m], 32 | #std =[s, s, s]) 33 | #tile = norm(tile) 34 | tiles[n] = tile 35 | 36 | order = np.random.randint(len(permutations)) 37 | data = [tiles[permutations[order][t]] for t in range(9)] 38 | data = torch.stack(data,0) 39 | return data, int(order) 40 | 41 | class DataLoader(): 42 | def __init__(self,data_path,txt_list,batchsize=256,classes=1000): 43 | self.batchsize = batchsize 44 | 45 | self.data_path = data_path 46 | self.names, _ = self.__dataset_info(txt_list) 47 | self.N = len(self.names) 48 | #self.N = self.N-(self.N%batchsize) 49 | 50 | self.permutations = self.__retrive_permutations(classes) 51 | 52 | self.__image_transformer = transforms.Compose([ 53 | transforms.Resize(256,Image.BILINEAR), 54 | transforms.CenterCrop(225)]) 55 | self.__augment_tile = transforms.Compose([ 56 | transforms.RandomCrop(64), 57 | transforms.Resize((75,75)), 58 | transforms.Lambda(rgb_jittering), 59 | transforms.ToTensor(), 60 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 61 | std =[0.229, 0.224, 0.225])]) 62 | 63 | def __iter__(self): 64 | self.counter = 0 65 | shuffle(self.names) 66 | return self 67 | 68 | def next(self): 69 | try: 70 | names = [self.data_path+'/'+n for n in self.names[self.counter:self.counter+self.batchsize]] 71 | except IndexError: 72 | raise StopIteration 73 | self.counter += self.batchsize 74 | batch = [load_image(n,self.permutations,self.__image_transformer,self.__augment_tile) 75 | for n in names] 76 | 77 | data, labels = zip(*batch) 78 | labels = torch.LongTensor(labels) 79 | data = torch.stack(data, 0) 80 | return data, labels, 0 81 | 82 | def __dataset_info(self,txt_labels): 83 | with open(txt_labels,'r') as f: 84 | images_list = f.readlines() 85 | 86 | file_names = [] 87 | labels = [] 88 | for row in images_list: 89 | row = row.split(' ') 90 | file_names.append(row[0]) 91 | labels.append(int(row[1])) 92 | 93 | return file_names, labels 94 | 95 | def __retrive_permutations(self,classes): 96 | all_perm = np.load('permutations_%d.npy'%(classes)) 97 | # from range [1,9] to [0,8] 98 | if all_perm.min()==1: 99 | all_perm = all_perm-1 100 | 101 | return all_perm 102 | 103 | 104 | def rgb_jittering(im): 105 | im = np.array(im,np.float32)#convert to numpy array 106 | for ch in range(3): 107 | thisRand = np.random.uniform(0.8, 1.2) 108 | im[:,:,ch] *= thisRand 109 | shiftVal = np.random.randint(0,6) 110 | if np.random.randint(2) == 1: 111 | shiftVal = -shiftVal 112 | im += shiftVal; 113 | im = im.astype(np.uint8) 114 | im = im.astype(np.float32) 115 | return im -------------------------------------------------------------------------------- /Dataset/JigsawImageLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 18 11:58:07 2017 4 | 5 | @author: Biagio Brattoli 6 | """ 7 | import numpy as np 8 | import torch 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | 14 | class DataLoader(data.Dataset): 15 | def __init__(self, data_path, txt_list, classes=1000): 16 | self.data_path = data_path 17 | self.names, _ = self.__dataset_info(txt_list) 18 | self.N = len(self.names) 19 | self.permutations = self.__retrive_permutations(classes) 20 | 21 | self.__image_transformer = transforms.Compose([ 22 | transforms.Resize(256, Image.BILINEAR), 23 | transforms.CenterCrop(255)]) 24 | self.__augment_tile = transforms.Compose([ 25 | transforms.RandomCrop(64), 26 | transforms.Resize((75, 75), Image.BILINEAR), 27 | transforms.Lambda(rgb_jittering), 28 | transforms.ToTensor(), 29 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], 30 | # std =[0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | def __getitem__(self, index): 34 | framename = self.data_path + '/' + self.names[index] 35 | 36 | img = Image.open(framename).convert('RGB') 37 | if np.random.rand() < 0.30: 38 | img = img.convert('LA').convert('RGB') 39 | 40 | if img.size[0] != 255: 41 | img = self.__image_transformer(img) 42 | 43 | s = float(img.size[0]) / 3 44 | a = s / 2 45 | tiles = [None] * 9 46 | for n in range(9): 47 | i = n / 3 48 | j = n % 3 49 | c = [a * i * 2 + a, a * j * 2 + a] 50 | c = np.array([c[1] - a, c[0] - a, c[1] + a + 1, c[0] + a + 1]).astype(int) 51 | tile = img.crop(c.tolist()) 52 | tile = self.__augment_tile(tile) 53 | # Normalize the patches indipendently to avoid low level features shortcut 54 | m, s = tile.view(3, -1).mean(dim=1).numpy(), tile.view(3, -1).std(dim=1).numpy() 55 | s[s == 0] = 1 56 | norm = transforms.Normalize(mean=m.tolist(), std=s.tolist()) 57 | tile = norm(tile) 58 | tiles[n] = tile 59 | 60 | order = np.random.randint(len(self.permutations)) 61 | data = [tiles[self.permutations[order][t]] for t in range(9)] 62 | data = torch.stack(data, 0) 63 | 64 | return data, int(order), tiles 65 | 66 | def __len__(self): 67 | return len(self.names) 68 | 69 | def __dataset_info(self, txt_labels): 70 | with open(txt_labels, 'r') as f: 71 | images_list = f.readlines() 72 | 73 | file_names = [] 74 | labels = [] 75 | for row in images_list: 76 | row = row.split(' ') 77 | file_names.append(row[0]) 78 | labels.append(int(row[1])) 79 | 80 | return file_names, labels 81 | 82 | def __retrive_permutations(self, classes): 83 | all_perm = np.load('permutations_%d.npy' % (classes)) 84 | # from range [1,9] to [0,8] 85 | if all_perm.min() == 1: 86 | all_perm = all_perm - 1 87 | 88 | return all_perm 89 | 90 | 91 | def rgb_jittering(im): 92 | im = np.array(im, 'int32') 93 | for ch in range(3): 94 | im[:, :, ch] += np.random.randint(-2, 2) 95 | im[im > 255] = 255 96 | im[im < 0] = 0 97 | return im.astype('uint8') 98 | -------------------------------------------------------------------------------- /Dataset/produce_small_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Sep 14 12:16:31 2017 4 | 5 | @author: Biagio Brattoli 6 | """ 7 | import os, sys, numpy as np 8 | import argparse 9 | from time import time 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.utils.data as data 14 | import torchvision.transforms as transforms 15 | 16 | from PIL import Image 17 | 18 | datapath = 'path-to-imagenet' 19 | 20 | trainval = 'train' 21 | #trainval = 'val' 22 | 23 | def main(): 24 | #data = DataLoader(datapath+'/ILSVRC2012_img_train', datapath+'/ilsvrc12_train.txt') 25 | data = DataLoader(datapath+'/ILSVRC2012_img_'+trainval, datapath+'/ilsvrc12_'+trainval+'.txt') 26 | loader = torch.utils.data.DataLoader(dataset=data,batch_size=1, 27 | shuffle=False,num_workers=20) 28 | 29 | count = 0 30 | for i, filename in enumerate(tqdm(loader)): 31 | count += 1 32 | 33 | 34 | class DataLoader(data.Dataset): 35 | def __init__(self,data_path,txt_list): 36 | self.data_path = data_path if data_path[-1]!='/' else data_path[:-1] 37 | self.names, _ = self.__dataset_info(txt_list) 38 | self.__image_transformer = transforms.Compose([ 39 | transforms.Resize(256,Image.BILINEAR), 40 | transforms.CenterCrop(255)]) 41 | self.save_path = self.data_path+'_255x255/' 42 | if not os.path.exists(self.save_path): 43 | os.makedirs(self.save_path) 44 | for name in self.names: 45 | if '/' in name: 46 | fold = self.save_path+name[:name.rfind('/')] 47 | if not os.path.exists(fold): 48 | os.makedirs(fold) 49 | 50 | def __getitem__(self, index): 51 | name = self.names[index] 52 | if os.path.exists(self.save_path+name): 53 | return None, None 54 | 55 | filename = self.data_path+'/'+name 56 | img = Image.open(filename).convert('RGB') 57 | img = self.__image_transformer(img) 58 | img.save(self.save_path+name) 59 | return self.names[index] 60 | 61 | 62 | def __len__(self): 63 | return len(self.names) 64 | 65 | def __dataset_info(self,txt_labels): 66 | with open(txt_labels,'r') as f: 67 | images_list = f.readlines() 68 | 69 | file_names = [] 70 | labels = [] 71 | for row in images_list: 72 | row = row.split(' ') 73 | file_names.append(row[0]) 74 | labels.append(int(row[1])) 75 | #if len(file_names)>128*10: 76 | #break 77 | 78 | return file_names, labels 79 | 80 | if __name__ == "__main__": 81 | main() -------------------------------------------------------------------------------- /JigsawNetwork.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 13 15:57:01 2017 4 | 5 | @author: Biagio Brattoli 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from torch import cat 10 | import torch.nn.init as init 11 | 12 | import sys 13 | sys.path.append('Utils') 14 | from Layers import LRN 15 | 16 | class Network(nn.Module): 17 | 18 | def __init__(self, classes=1000): 19 | super(Network, self).__init__() 20 | 21 | self.conv = nn.Sequential() 22 | self.conv.add_module('conv1_s1',nn.Conv2d(3, 96, kernel_size=11, stride=2, padding=0)) 23 | self.conv.add_module('relu1_s1',nn.ReLU(inplace=True)) 24 | self.conv.add_module('pool1_s1',nn.MaxPool2d(kernel_size=3, stride=2)) 25 | self.conv.add_module('lrn1_s1',LRN(local_size=5, alpha=0.0001, beta=0.75)) 26 | 27 | self.conv.add_module('conv2_s1',nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)) 28 | self.conv.add_module('relu2_s1',nn.ReLU(inplace=True)) 29 | self.conv.add_module('pool2_s1',nn.MaxPool2d(kernel_size=3, stride=2)) 30 | self.conv.add_module('lrn2_s1',LRN(local_size=5, alpha=0.0001, beta=0.75)) 31 | 32 | self.conv.add_module('conv3_s1',nn.Conv2d(256, 384, kernel_size=3, padding=1)) 33 | self.conv.add_module('relu3_s1',nn.ReLU(inplace=True)) 34 | 35 | self.conv.add_module('conv4_s1',nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2)) 36 | self.conv.add_module('relu4_s1',nn.ReLU(inplace=True)) 37 | 38 | self.conv.add_module('conv5_s1',nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2)) 39 | self.conv.add_module('relu5_s1',nn.ReLU(inplace=True)) 40 | self.conv.add_module('pool5_s1',nn.MaxPool2d(kernel_size=3, stride=2)) 41 | 42 | self.fc6 = nn.Sequential() 43 | self.fc6.add_module('fc6_s1',nn.Linear(256*3*3, 1024)) 44 | self.fc6.add_module('relu6_s1',nn.ReLU(inplace=True)) 45 | self.fc6.add_module('drop6_s1',nn.Dropout(p=0.5)) 46 | 47 | self.fc7 = nn.Sequential() 48 | self.fc7.add_module('fc7',nn.Linear(9*1024,4096)) 49 | self.fc7.add_module('relu7',nn.ReLU(inplace=True)) 50 | self.fc7.add_module('drop7',nn.Dropout(p=0.5)) 51 | 52 | self.classifier = nn.Sequential() 53 | self.classifier.add_module('fc8',nn.Linear(4096, classes)) 54 | 55 | #self.apply(weights_init) 56 | 57 | def load(self,checkpoint): 58 | model_dict = self.state_dict() 59 | pretrained_dict = torch.load(checkpoint) 60 | pretrained_dict = {k: v for k, v in list(pretrained_dict.items()) if k in model_dict and 'fc8' not in k} 61 | model_dict.update(pretrained_dict) 62 | self.load_state_dict(model_dict) 63 | print([k for k, v in list(pretrained_dict.items())]) 64 | 65 | def save(self,checkpoint): 66 | torch.save(self.state_dict(), checkpoint) 67 | 68 | def forward(self, x): 69 | B,T,C,H,W = x.size() 70 | x = x.transpose(0,1) 71 | 72 | x_list = [] 73 | for i in range(9): 74 | z = self.conv(x[i]) 75 | z = self.fc6(z.view(B,-1)) 76 | z = z.view([B,1,-1]) 77 | x_list.append(z) 78 | 79 | x = cat(x_list,1) 80 | x = self.fc7(x.view(B,-1)) 81 | x = self.classifier(x) 82 | 83 | return x 84 | 85 | 86 | def weights_init(model): 87 | if type(model) in [nn.Conv2d,nn.Linear]: 88 | nn.init.xavier_normal(model.weight.data) 89 | nn.init.constant(model.bias.data, 0.1) 90 | -------------------------------------------------------------------------------- /JigsawTrain.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Sep 14 12:16:31 2017 4 | 5 | @author: Biagio Brattoli 6 | """ 7 | import os, sys, numpy as np 8 | import argparse 9 | from time import time 10 | from tqdm import tqdm 11 | 12 | import tensorflow # needs to call tensorflow before torch, otherwise crush 13 | sys.path.append('Utils') 14 | from logger import Logger 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch.autograd import Variable 19 | 20 | sys.path.append('Dataset') 21 | from JigsawNetwork import Network 22 | 23 | from TrainingUtils import adjust_learning_rate, compute_accuracy 24 | 25 | 26 | parser = argparse.ArgumentParser(description='Train JigsawPuzzleSolver on Imagenet') 27 | parser.add_argument('data', type=str, help='Path to Imagenet folder') 28 | parser.add_argument('--model', default=None, type=str, help='Path to pretrained model') 29 | parser.add_argument('--classes', default=1000, type=int, help='Number of permutation to use') 30 | parser.add_argument('--gpu', default=0, type=int, help='gpu id') 31 | parser.add_argument('--epochs', default=70, type=int, help='number of total epochs for training') 32 | parser.add_argument('--iter_start', default=0, type=int, help='Starting iteration count') 33 | parser.add_argument('--batch', default=256, type=int, help='batch size') 34 | parser.add_argument('--checkpoint', default='checkpoints/', type=str, help='checkpoint folder') 35 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate for SGD optimizer') 36 | parser.add_argument('--cores', default=0, type=int, help='number of CPU core for loading') 37 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 38 | help='evaluate model on validation set, No training') 39 | args = parser.parse_args() 40 | 41 | #from ImageDataLoader import DataLoader 42 | from JigsawImageLoader import DataLoader 43 | 44 | 45 | def main(): 46 | if args.gpu is not None: 47 | print(('Using GPU %d'%args.gpu)) 48 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 49 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 50 | else: 51 | print('CPU mode') 52 | 53 | print('Process number: %d'%(os.getpid())) 54 | 55 | ## DataLoader initialize ILSVRC2012_train_processed 56 | trainpath = args.data+'/ILSVRC2012_img_train' 57 | if os.path.exists(trainpath+'_255x255'): 58 | trainpath += '_255x255' 59 | train_data = DataLoader(trainpath,args.data+'/ilsvrc12_train.txt', 60 | classes=args.classes) 61 | train_loader = torch.utils.data.DataLoader(dataset=train_data, 62 | batch_size=args.batch, 63 | shuffle=True, 64 | num_workers=args.cores) 65 | 66 | valpath = args.data+'/ILSVRC2012_img_val' 67 | if os.path.exists(valpath+'_255x255'): 68 | valpath += '_255x255' 69 | val_data = DataLoader(valpath, args.data+'/ilsvrc12_val.txt', 70 | classes=args.classes) 71 | val_loader = torch.utils.data.DataLoader(dataset=val_data, 72 | batch_size=args.batch, 73 | shuffle=True, 74 | num_workers=args.cores) 75 | N = train_data.N 76 | 77 | iter_per_epoch = train_data.N/args.batch 78 | print('Images: train %d, validation %d'%(train_data.N,val_data.N)) 79 | 80 | # Network initialize 81 | net = Network(args.classes) 82 | if args.gpu is not None: 83 | net.cuda() 84 | 85 | ############## Load from checkpoint if exists, otherwise from model ############### 86 | if os.path.exists(args.checkpoint): 87 | files = [f for f in os.listdir(args.checkpoint) if 'pth' in f] 88 | if len(files)>0: 89 | files.sort() 90 | #print files 91 | ckp = files[-1] 92 | net.load_state_dict(torch.load(args.checkpoint+'/'+ckp)) 93 | args.iter_start = int(ckp.split(".")[-3].split("_")[-1]) 94 | print('Starting from: ',ckp) 95 | else: 96 | if args.model is not None: 97 | net.load(args.model) 98 | else: 99 | if args.model is not None: 100 | net.load(args.model) 101 | 102 | criterion = nn.CrossEntropyLoss() 103 | optimizer = torch.optim.SGD(net.parameters(),lr=args.lr,momentum=0.9,weight_decay = 5e-4) 104 | 105 | logger = Logger(args.checkpoint+'/train') 106 | logger_test = Logger(args.checkpoint+'/test') 107 | 108 | ############## TESTING ############### 109 | if args.evaluate: 110 | test(net,criterion,None,val_loader,0) 111 | return 112 | 113 | ############## TRAINING ############### 114 | print(('Start training: lr %f, batch size %d, classes %d'%(args.lr,args.batch,args.classes))) 115 | print(('Checkpoint: '+args.checkpoint)) 116 | 117 | # Train the Model 118 | batch_time, net_time = [], [] 119 | steps = args.iter_start 120 | for epoch in range(int(args.iter_start/iter_per_epoch),args.epochs): 121 | if epoch%10==0 and epoch>0: 122 | test(net,criterion,logger_test,val_loader,steps) 123 | lr = adjust_learning_rate(optimizer, epoch, init_lr=args.lr, step=20, decay=0.1) 124 | 125 | end = time() 126 | for i, (images, labels, original) in enumerate(train_loader): 127 | batch_time.append(time()-end) 128 | if len(batch_time)>100: 129 | del batch_time[0] 130 | 131 | images = Variable(images) 132 | labels = Variable(labels) 133 | if args.gpu is not None: 134 | images = images.cuda() 135 | labels = labels.cuda() 136 | 137 | # Forward + Backward + Optimize 138 | optimizer.zero_grad() 139 | t = time() 140 | outputs = net(images) 141 | net_time.append(time()-t) 142 | if len(net_time)>100: 143 | del net_time[0] 144 | 145 | prec1, prec5 = compute_accuracy(outputs.cpu().data, labels.cpu().data, topk=(1, 5)) 146 | acc = prec1[0] 147 | 148 | loss = criterion(outputs, labels) 149 | loss.backward() 150 | optimizer.step() 151 | loss = float(loss.cpu().data.numpy()) 152 | 153 | if steps%20==0: 154 | print(('[%2d/%2d] %5d) [batch load % 2.3fsec, net %1.2fsec], LR %.5f, Loss: % 1.3f, Accuracy % 2.2f%%' %( 155 | epoch+1, args.epochs, steps, 156 | np.mean(batch_time), np.mean(net_time), 157 | lr, loss,acc))) 158 | 159 | if steps%20==0: 160 | logger.scalar_summary('accuracy', acc, steps) 161 | logger.scalar_summary('loss', loss, steps) 162 | 163 | original = [im[0] for im in original] 164 | imgs = np.zeros([9,75,75,3]) 165 | for ti, img in enumerate(original): 166 | img = img.numpy() 167 | imgs[ti] = np.stack([(im-im.min())/(im.max()-im.min()) 168 | for im in img],axis=2) 169 | 170 | logger.image_summary('input', imgs, steps) 171 | 172 | steps += 1 173 | 174 | if steps%1000==0: 175 | filename = '%s/jps_%03i_%06d.pth.tar'%(args.checkpoint,epoch,steps) 176 | net.save(filename) 177 | print('Saved: '+args.checkpoint) 178 | 179 | end = time() 180 | 181 | if os.path.exists(args.checkpoint+'/stop.txt'): 182 | # break without using CTRL+C 183 | break 184 | 185 | def test(net,criterion,logger,val_loader,steps): 186 | print('Evaluating network.......') 187 | accuracy = [] 188 | net.eval() 189 | for i, (images, labels, _) in enumerate(val_loader): 190 | images = Variable(images) 191 | if args.gpu is not None: 192 | images = images.cuda() 193 | 194 | # Forward + Backward + Optimize 195 | outputs = net(images) 196 | outputs = outputs.cpu().data 197 | 198 | prec1, prec5 = compute_accuracy(outputs, labels, topk=(1, 5)) 199 | accuracy.append(prec1[0]) 200 | 201 | if logger is not None: 202 | logger.scalar_summary('accuracy', np.mean(accuracy), steps) 203 | print('TESTING: %d), Accuracy %.2f%%' %(steps,np.mean(accuracy))) 204 | net.train() 205 | 206 | if __name__ == "__main__": 207 | main() 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JigsawPuzzlePytorch 2 | Pytorch implementation of the paper ["Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles"](https://arxiv.org/abs/1603.09246) by Mehdi Noroozi [GitHub](https://github.com/MehdiNoroozi/JigsawPuzzleSolver) 3 | 4 | **Partially tested** 5 | **Performances Coming Soon** 6 | 7 | # Dependencies 8 | - Tested with Python 2.7 9 | - [Pytorch](http://pytorch.org/) v0.3 10 | - [Tensorflow](https://www.tensorflow.org/) is used for logging. 11 | Remove the Logger all scripts if tensorflow is missing 12 | 13 | # Train the JigsawPuzzleSolver 14 | ## Setup Loader 15 | Two DataLoader are provided: 16 | - ImageLoader: per each iteration it loads data in image format (jpg,png ,...) 17 | - *Dataset/JigsawImageLoader.py* uses PyTorch DataLoader and iterator 18 | - *Dataset/ImageDataLoader.py* custom implementation. 19 | 20 | The default loader is *JigsawImageLoader.py*. *ImageDataLoader.py* is slightly faster when using single core. 21 | 22 | The images can be preprocessed using *_produce_small_data.py_* which resize the image to 256, keeping the aspect ratio, and crops a patch of size 255x255 in the center. 23 | 24 | ## Run Training 25 | Fill the path information in *run_jigsaw_training.sh*. 26 | IMAGENET_FOLD needs to point to the folder containing *ILSVRC2012_img_train*. 27 | 28 | ``` 29 | ./run_jigsaw_training.sh [GPU_ID] 30 | ``` 31 | or call the python script 32 | ``` 33 | python JigsawTrain.py [*path_to_imagenet*] --checkpoint [*path_checkpoints_and_logs*] --gpu [*GPU_ID*] --batch [*batch_size*] 34 | ``` 35 | By default the network uses 1000 permutations with maximum hamming distance selected using *select_permutations.py*. 36 | 37 | To change the file name loaded for the permutations, open the file *JigsawLoader.py* and change the permutation file in the method *retrive_permutations* 38 | 39 | # Details: 40 | - The input of the network should be 64x64, but I need to resize to 75x75, 41 | otherwise the output of conv5 is 2x2 instead of 3x3 like the official architecture 42 | - Jigsaw trained using the approach of the paper: SGD, LRN layers, 70 epochs 43 | - Implemented *shortcuts*: spatial jittering, normalize each patch indipendently, color jittering, 30% black&white image 44 | - The LRN layer crushes with a PyTorch version older than 0.3 45 | 46 | # ToDo 47 | - TensorboardX 48 | - LMDB DataLoader 49 | -------------------------------------------------------------------------------- /Utils/Layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | https://github.com/jiecaoyu/pytorch_imagenet/blob/master/networks/model_list/alexnet.py 4 | """ 5 | import torch.nn as nn 6 | 7 | class LRN(nn.Module): 8 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True): 9 | super(LRN, self).__init__() 10 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 11 | if ACROSS_CHANNELS: 12 | self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1), 13 | stride=1,padding=(int((local_size-1.0)/2), 0, 0)) 14 | else: 15 | self.average=nn.AvgPool2d(kernel_size=local_size, 16 | stride=1,padding=int((local_size-1.0)/2)) 17 | self.alpha = alpha 18 | self.beta = beta 19 | 20 | 21 | def forward(self, x): 22 | if self.ACROSS_CHANNELS: 23 | div = x.pow(2).unsqueeze(1) 24 | div = self.average(div).squeeze(1) 25 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 26 | else: 27 | div = x.pow(2) 28 | div = self.average(div) 29 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 30 | x = x.div(div) 31 | return x -------------------------------------------------------------------------------- /Utils/TrainingUtils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Sep 22 16:53:30 2017 4 | 5 | @author: bbrattol 6 | """ 7 | 8 | def adjust_learning_rate(optimizer, epoch, init_lr=0.1, step=30, decay=0.1): 9 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 10 | lr = init_lr * (decay ** (epoch // step)) 11 | print('Learning Rate %f'%lr) 12 | for param_group in optimizer.param_groups: 13 | param_group['lr'] = lr 14 | return lr 15 | 16 | def compute_accuracy(output, target, topk=(1,)): 17 | """Computes the precision@k for the specified values of k""" 18 | maxk = max(topk) 19 | batch_size = target.size(0) 20 | 21 | _, pred = output.topk(maxk, 1, True, True) 22 | pred = pred.t() 23 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 24 | 25 | res = [] 26 | for k in topk: 27 | correct_k = correct[:k].view(-1).float().sum(0) 28 | res.append(correct_k.mul_(100.0 / batch_size)) 29 | return res 30 | 31 | -------------------------------------------------------------------------------- /Utils/convert2h5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Sep 27 14:41:50 2017 4 | 5 | @author: bbrattol 6 | """ 7 | import argparse, sys 8 | 9 | sys.path.append('JpsTraininig') 10 | from JigsawNetwork import Network 11 | 12 | parser = argparse.ArgumentParser(description='Train JigsawPuzzleSolver on Imagenet') 13 | parser.add_argument('model', type=str, help='Path to pretrained model') 14 | parser.add_argument('classes', type=int, help='Number of permutation to use') 15 | args = parser.parse_args() 16 | 17 | 18 | def save_net(fname, net): 19 | import h5py 20 | h5f = h5py.File(fname, mode='w') 21 | for k, v in list(net.state_dict().items()): 22 | h5f.create_dataset(k, data=v.cpu().numpy()) 23 | 24 | net = Network(args.classes,groups=2) 25 | net.load(args.model) 26 | 27 | save_net(args.model[:-8]+'.h5',net) 28 | -------------------------------------------------------------------------------- /Utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*-" 3 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 4 | import tensorflow as tf 5 | import numpy as np 6 | import scipy.misc 7 | try: 8 | from io import StringIO # Python 2.7 9 | except ImportError: 10 | from io import BytesIO # Python 3.x 11 | 12 | 13 | class Logger(object): 14 | 15 | def __init__(self, log_dir): 16 | """Create a summary writer logging to log_dir.""" 17 | self.writer = tf.summary.FileWriter(log_dir) 18 | 19 | def scalar_summary(self, tag, value, step): 20 | """Log a scalar variable.""" 21 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 22 | self.writer.add_summary(summary, step) 23 | 24 | def image_summary(self, tag, images, step): 25 | """Log a list of images.""" 26 | 27 | img_summaries = [] 28 | for i, img in enumerate(images): 29 | # Write the image to a string 30 | try: 31 | s = StringIO() 32 | except: 33 | s = BytesIO() 34 | scipy.misc.toimage(img).save(s, format="png") 35 | 36 | # Create an Image object 37 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 38 | height=img.shape[0], 39 | width=img.shape[1]) 40 | # Create a Summary value 41 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 42 | 43 | # Create and write Summary 44 | summary = tf.Summary(value=img_summaries) 45 | self.writer.add_summary(summary, step) 46 | 47 | def histo_summary(self, tag, values, step, bins=1000): 48 | """Log a histogram of the tensor of values.""" 49 | 50 | # Create a histogram using numpy 51 | counts, bin_edges = np.histogram(values, bins=bins) 52 | 53 | # Fill the fields of the histogram proto 54 | hist = tf.HistogramProto() 55 | hist.min = float(np.min(values)) 56 | hist.max = float(np.max(values)) 57 | hist.num = int(np.prod(values.shape)) 58 | hist.sum = float(np.sum(values)) 59 | hist.sum_squares = float(np.sum(values**2)) 60 | 61 | # Drop the start of the first bin 62 | bin_edges = bin_edges[1:] 63 | 64 | # Add bin edges and counts 65 | for edge in bin_edges: 66 | hist.bucket_limit.append(edge) 67 | for c in counts: 68 | hist.bucket.append(c) 69 | 70 | # Create and write Summary 71 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 72 | self.writer.add_summary(summary, step) 73 | self.writer.flush() -------------------------------------------------------------------------------- /permutations_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bbrattoli/JigsawPuzzlePytorch/ec85994b9f244d08652a3975c1c7a55483cdfc05/permutations_1000.npy -------------------------------------------------------------------------------- /run_jigsaw_training.sh: -------------------------------------------------------------------------------- 1 | IMAGENET_FOLD=path_to_ILSVRC2012_img 2 | 3 | GPU=${1} # gpu used 4 | CHECKPOINTS_FOLD=${2} #path_to_output_folder 5 | 6 | #python JigsawTrain.py ${IMAGENET_FOLD} --checkpoint=${CHECKPOINTS_FOLD} \ 7 | # --classes=1000 --batch 128 --lr=0.001 --gpu=${GPU} --cores=10 8 | python JigsawTrain.py ${IMAGENET_FOLD} --classes=1000 --batch 128 --lr=0.001 --cores=10 9 | -------------------------------------------------------------------------------- /select_permutations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Sep 14 15:50:28 2017 4 | 5 | @author: bbrattol 6 | """ 7 | import argparse 8 | from tqdm import trange 9 | import numpy as np 10 | import itertools 11 | from scipy.spatial.distance import cdist 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Train network on Imagenet') 15 | parser.add_argument('--classes', default=1000, type=int, 16 | help='Number of permutations to select') 17 | parser.add_argument('--selection', default='max', type=str, 18 | help='Sample selected per iteration based on hamming distance: [max] highest; [mean] average') 19 | args = parser.parse_args() 20 | 21 | if __name__ == "__main__": 22 | outname = 'permutations/permutations_hamming_%s_%d'%(args.selection,args.classes) 23 | 24 | P_hat = np.array(list(itertools.permutations(list(range(9)), 9))) 25 | n = P_hat.shape[0] 26 | 27 | for i in trange(args.classes): 28 | if i==0: 29 | j = np.random.randint(n) 30 | P = np.array(P_hat[j]).reshape([1,-1]) 31 | else: 32 | P = np.concatenate([P,P_hat[j].reshape([1,-1])],axis=0) 33 | 34 | P_hat = np.delete(P_hat,j,axis=0) 35 | D = cdist(P,P_hat, metric='hamming').mean(axis=0).flatten() 36 | 37 | if args.selection=='max': 38 | j = D.argmax() 39 | else: 40 | m = int(D.shape[0]/2) 41 | S = D.argsort() 42 | j = S[np.random.randint(m-10,m+10)] 43 | 44 | if i%100==0: 45 | np.save(outname,P) 46 | 47 | np.save(outname,P) 48 | print('file created --> '+outname) 49 | --------------------------------------------------------------------------------