├── .gitignore
├── AlexNetModel.py
├── README.md
├── ResNetModel.py
├── SketchANetModel.py
├── Tools
│   ├── GetImageMean_Std.py
│   ├── ListAllImageName.py
│   ├── SplitDataset.py
│   └── create_filelist.sh
├── Train.py
├── alexnet.py
├── filelist_data_loader.py
├── resnet.py
├── resume_train.sh
├── run_train.sh
└── sketchanet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | runs/
3 | model.py
--------------------------------------------------------------------------------
/AlexNetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | 
6 | from alexnet import alexnet
7 | 
8 | class AlexNetModel(nn.Module):
9 |     def __init__(self, num_classes=None):
10 |         super(AlexNetModel, self).__init__()
11 |         self.base = alexnet(pretrained=False)
12 | 
13 |         planes = 4096
14 |         if num_classes is not None:
15 |             self.fc = nn.Linear(planes, num_classes)
16 |             init.normal(self.fc.weight, std=0.001)
17 |             init.constant(self.fc.bias, 0.1)
18 | 
19 |     def forward(self, x):
20 |         feat = self.base(x)
21 | 
22 |         if hasattr(self, 'fc'):
23 |             logits = self.fc(feat)
24 |             return feat, logits
25 | 
26 |         return feat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sketch Classification
2 | A PyTorch implementation of sketch classification networks.
3 | 
4 | ## Model Configuration
5 | - Optimizer
6 |   - Adam
7 | ## DataSet
8 | TU-Berlin sketch dataset
9 | 
10 | | Model        | input_size  |
11 | | ------------ | ----------- |
12 | | (raw size)*  | 1111 * 1111 |
13 | | AlexNet      | 224 * 224   |
14 | | SketchANet   | 225 * 225   |
15 | | ResNet18     | 224 * 224   |
16 | | ResNet34     | 224 * 224   |
17 | | ResNet50     | 224 * 224   |
18 | | DenseNet121  | 224 * 224   |
19 | | Inception_v3 | 299 * 299   |
20 | 
21 | 
22 | ## Model Parameters
23 | | Model                    | lr   | clip_grad_norm(max_norm) | learning rate decay | weight_decay    |
24 | | ------------------------ | ---- | ------------------------ | ------------------- | --------------- |
25 | | AlexNet(pretrained)      | 2e-4 | --                       | 20                  | 0.0005          |
26 | | AlexNet(scratch)         | 2e-5 | 0.5 - 100.0              | 30                  | 0.0005          |
27 | | SketchANet(DogsCats)*    | 2e-5 | 0.5 - 1.0                | 30                  | 0.0005          |
28 | | SketchANet(scratch)      | 2e-5 | 0.5 - 100.0              | 800                 | 0.0001 - 0.0003 |
29 | | ResNet18(pretrained)     | 2e-4 | --                       | 20                  | 0.0005          |
30 | | ResNet34(pretrained)     | 2e-4 | --                       | 20                  | 0.0001          |
31 | | ResNet50(pretrained)     | 2e-4 | --                       | 20                  | 0.0005          |
32 | | DenseNet121(pretrained)  | 2e-4 | --                       | 20                  | 0.0005          |
33 | | Inception_v3(pretrained) | 2e-4 | --                       | 30                  | 0.0005          |
34 | * *SketchANet(DogsCats) is a test-only configuration: a sanity check trained on the Dogs vs. Cats data, not a sketch-classification result.*
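
As a rough guide to how these settings are wired up, the snippet below is a minimal sketch of the optimizer used by `Train.py` (Adam with the betas from the script). Here `model` stands for whichever network is being trained, and `decay_step` / `max_norm` are the per-experiment values from the table above rather than fixed defaults:

```python
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm  # clip_grad_norm_ in newer PyTorch

optimizer = optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.99), weight_decay=0.0005)

def adjust_learning_rate(optimizer, epoch, base_lr=2e-4, decay_step=20):
    # decay the learning rate by 10x every `decay_step` epochs
    lr = base_lr * (0.1 ** (epoch // decay_step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# gradient clipping is only used for the from-scratch runs (see the max_norm column)
# clip_grad_norm(model.parameters(), max_norm=0.5)
```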
35 | 
36 | ## Model Result
37 | ### Train Set
38 | | Model                    | Prec@1  | Prec@5 |
39 | | ------------------------ | ------- | ------ |
40 | | AlexNet(pretrained)      | 93.4455 | 99.787 |
41 | | AlexNet(scratch)         | 99.3024 | 99.988 |
42 | | SketchANet(scratch)      | 86.3166 | 98.667 |
43 | | ResNet18(pretrained)     | 96.9899 | 99.954 |
44 | | ResNet34(pretrained)     | 97.1048 | 99.954 |
45 | | ResNet50(pretrained)     | 98.3049 | 99.988 |
46 | | DenseNet121(pretrained)  | 91.4301 | 99.596 |
47 | | Inception_v3(pretrained) | 91.8802 | 99.706 |
48 | 
49 | 
50 | ### Test Set
51 | | Model                    | Prec@1 | Prec@5 |
52 | | ------------------------ | ------ | ------ |
53 | | Human                    | 73.1   | --     |
54 | | AlexNet<sup>1</sup>      | 68.6   | --     |
55 | | AlexNet<sup>2</sup>      | 77.29  | --     |
56 | | GoogLeNet<sup>2</sup>    | 80.85  | --     |
57 | | AlexNet(pretrained)      | 70.850 | 90.050 |
58 | | AlexNet(scratch)         | 53.850 | 78.000 |
59 | | SketchANet(scratch)      | 68.700 | 88.900 |
60 | | ResNet18(pretrained)     | 77.800 | 94.650 |
61 | | ResNet34(pretrained)     | 79.100 | 95.050 |
62 | | ResNet50(pretrained)     | 78.300 | 95.300 |
63 | | DenseNet121(pretrained)  | 77.550 | 93.500 |
64 | | Inception_v3(pretrained) | 76.550 | 93.750 |
65 | 
66 | 1. *Sketch-a-Net that Beats Humans*
67 | 2. *The Sketchy Database: Learning to Retrieve Badly Drawn Bunnies*
68 | 
69 | Models not yet evaluated: DPN, ShuffleNetG2, SENet18.
70 | ## Tools
71 | - GetImageMean_Std
72 | 
73 |   Compute the per-channel mean and standard deviation of an image dataset.
74 | 
75 | - SplitDataset
76 | 
77 |   Split an image dataset into train and val folders according to the train/val record txt files.
78 | 
79 | - ListAllImageName
80 | 
81 |   List all image file names in a dataset and write them to a txt file.
--------------------------------------------------------------------------------
/ResNetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | 
6 | from resnet import resnet18, resnet34, resnet50
7 | 
8 | class ResNetModel(nn.Module):
9 |     def __init__(self, num_classes=None):
10 |         super(ResNetModel, self).__init__()
11 |         self.base = resnet34(pretrained=True)
12 | 
13 |         planes = 512  # 512 for resnet18/resnet34; use 2048 if the base is switched to resnet50
14 | 
15 |         if num_classes is not None:
16 |             self.fc = nn.Linear(planes, num_classes)
17 |             init.xavier_uniform(self.fc.weight)
18 |             init.constant(self.fc.bias, 0.1)
19 | 
20 |     def forward(self, x):
21 |         # shape [N, C, H, W]
22 |         feat = self.base(x)
23 |         global_feat = F.avg_pool2d(feat, feat.size()[2:])
24 |         # shape [N, C]
25 |         global_feat = global_feat.view(global_feat.size(0), -1)
26 | 
27 |         if hasattr(self, 'fc'):
28 |             logits = self.fc(global_feat)
29 |             return global_feat, logits
30 | 
31 |         # return global_feat, local_feat
32 |         return global_feat
33 | 
34 | 
35 | 
36 | 
37 | 
--------------------------------------------------------------------------------
/SketchANetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | 
6 | from sketchanet import sketchanet
7 | 
8 | class SketchANetModel(nn.Module):
9 |     def __init__(self, num_classes=None):
10 |         super(SketchANetModel, self).__init__()
11 |         self.base = sketchanet(pretrained=False)
12 | 
13 |         planes = 512
14 |         if num_classes is not None:
15 |             self.fc = nn.Linear(planes, num_classes)
16 |             init.normal(self.fc.weight, std=0.001)
17 |             init.constant(self.fc.bias, 0.1)
18 | 
19 |     def forward(self, x):
20 |         feat = self.base(x)
21 | 
22 |         if hasattr(self, 'fc'):
23 |             logits = self.fc(feat)
24 |             return feat, logits
25 | 
26 |         return feat
27 | 
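
All three wrapper modules (AlexNetModel, ResNetModel, SketchANetModel) share the same forward contract: with `num_classes` set they return a `(features, logits)` pair, otherwise only the feature vector. A minimal usage sketch, shown with plain tensors on a recent PyTorch (Train.py itself wraps inputs in `Variable`); the 224x224 input size and 250 TU-Berlin classes follow the README, and the variable names are illustrative:

```python
import torch
from ResNetModel import ResNetModel

model = ResNetModel(num_classes=250)   # downloads ImageNet weights for the resnet34 base
images = torch.randn(8, 3, 224, 224)   # dummy batch of 8 images
features, logits = model(images)       # features: [8, 512], logits: [8, 250]

embedder = ResNetModel()               # no classifier head
embeddings = embedder(images)          # features only: [8, 512]
```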
-------------------------------------------------------------------------------- /Tools/GetImageMean_Std.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os.path as osp 6 | from progressbar import Bar, ProgressBar 7 | from scipy.misc import imread, imresize 8 | 9 | def is_image(ext): 10 | ext = ext.lower() 11 | if ext == '.jpg': 12 | return True 13 | elif ext == '.png': 14 | return True 15 | elif ext == '.jpeg': 16 | return True 17 | elif ext == '.bmp': 18 | return True 19 | else: 20 | return False 21 | 22 | def get_all_image_names(rootdir, image_names_list=[]): 23 | for file in os.listdir(rootdir): 24 | filepath = osp.join(rootdir, file) 25 | if osp.isdir(filepath): 26 | get_all_image_names(filepath, image_names_list) 27 | elif osp.isfile(filepath): 28 | ext = osp.splitext(filepath)[1] 29 | if is_image(ext): 30 | image_names_list.append(filepath) 31 | image_names_list = sorted(image_names_list) 32 | return image_names_list 33 | 34 | def GetImageMean(rootdir, size=(256, 256)): 35 | R_channel = [] 36 | G_channel = [] 37 | B_channel = [] 38 | image_names_list = get_all_image_names(rootdir) 39 | progress = ProgressBar(max_value= len(image_names_list)) 40 | for i, name in enumerate(image_names_list): 41 | img = imread(name) 42 | img = imresize(img, size) 43 | if(img.shape[-1] == 3 or img.shape[-1] == 4): 44 | R_channel.append(img[:, :, 0]) 45 | G_channel.append(img[:, :, 1]) 46 | B_channel.append(img[:, :, 2]) 47 | else: 48 | R_channel.append(img[:, :]) 49 | 50 | progress.update(i) 51 | progress.finish() 52 | 53 | # num = len(image_names_list) * size[0] * size[1] 54 | 55 | if (img.shape[-1] == 3 or img.shape[-1] == 4): 56 | R_mean = np.mean(np.asarray(R_channel)) 57 | G_mean = np.mean(np.asarray(G_channel)) 58 | B_mean = np.mean(np.asarray(B_channel)) 59 | 60 | R_std = np.std(np.asarray(R_channel)) 61 | G_std = np.std(np.asarray(G_channel)) 62 | B_std = np.std(np.asarray(B_channel)) 63 | return R_mean, G_mean, B_mean, R_std, G_std, B_std 64 | else: 65 | R_mean = np.mean(np.asarray(R_channel)) 66 | R_std = np.std(np.asarray(R_channel)) 67 | return R_mean, R_std 68 | 69 | 70 | if __name__ == "__main__": 71 | rootdir = r"/home/bc/Work/Database/TU-Berlin sketch dataset/png" 72 | mean = GetImageMean(rootdir) 73 | print(mean) -------------------------------------------------------------------------------- /Tools/ListAllImageName.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os.path as osp 6 | from scipy.misc import imread, imresize 7 | 8 | def is_image(ext): 9 | ext = ext.lower() 10 | if ext == '.jpg': 11 | return True 12 | elif ext == '.png': 13 | return True 14 | elif ext == '.jpeg': 15 | return True 16 | elif ext == '.bmp': 17 | return True 18 | else: 19 | return False 20 | 21 | def get_all_image_names(rootdir, image_names_list=[]): 22 | for file in os.listdir(rootdir): 23 | filepath = osp.join(rootdir, file) 24 | if osp.isdir(filepath): 25 | get_all_image_names(filepath, image_names_list) 26 | elif osp.isfile(filepath): 27 | ext = osp.splitext(filepath)[1] 28 | if is_image(ext): 29 | image_names_list.append(osp.join(osp.split(rootdir)[1], file)) 30 | image_names_list = sorted(image_names_list) 31 | return image_names_list 32 | 33 | def save_image_list(image_names_list, save_filename): 34 | f = 
open(save_filename, 'w') 35 | image_names_list = [line+'\n' for line in image_names_list] 36 | f.writelines(image_names_list) 37 | f.close() 38 | 39 | if __name__ == "__main__": 40 | data_root=r"/home/bc/Work/Database/TU-Berlin sketch dataset/png" 41 | save_filename=r"./train.txt" 42 | image_names_list = get_all_image_names(data_root) 43 | save_image_list(image_names_list, save_filename) 44 | 45 | 46 | -------------------------------------------------------------------------------- /Tools/SplitDataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | import shutil 7 | 8 | def copyfile(srcfile, dstfile): 9 | if not osp.isfile(srcfile): 10 | print("%s not exist!"%(srcfile)) 11 | else: 12 | fpath, fname = osp.split(dstfile) 13 | if not osp.exists(fpath): 14 | os.makedirs(fpath) 15 | shutil.copyfile(srcfile, dstfile) 16 | 17 | def movetodir(record_file, Data_root, dataset_path): 18 | with open(record_file, 'r') as f: 19 | for line in f: 20 | src_img_path = osp.join(Data_root, line.rstrip()) 21 | dst_img_path = osp.join(dataset_path, line.rstrip()) 22 | copyfile(src_img_path, dst_img_path) 23 | 24 | def main(): 25 | Data_root = "/home/bc/Work/Database/TU-Berlin sketch dataset/png" 26 | Train_record_file = "/home/bc/Work/Database/TU-Berlin sketch dataset/png/train_list.txt" 27 | Val_record_file = "/home/bc/Work/Database/TU-Berlin sketch dataset/png/val_list.txt" 28 | 29 | 30 | Dataset_train_path = osp.join(Data_root, "../train_val/train") 31 | if not osp.exists(Dataset_train_path): 32 | os.makedirs(Dataset_train_path) 33 | 34 | Dataset_val_path = osp.join(Data_root, "../train_val/val") 35 | if not osp.exists(Dataset_val_path): 36 | os.makedirs(Dataset_val_path) 37 | 38 | movetodir(Train_record_file, Data_root, Dataset_train_path) 39 | movetodir(Val_record_file, Data_root, Dataset_val_path) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /Tools/create_filelist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=dataset/images 4 | echo "Create train.txt..." 
5 | rm -f $DATA/train.txt
6 | ls $DATA/bike | sed "s:^:bike/:" | sed "s:$: 1:" >> $DATA/train.txt
7 | ls $DATA/cat | sed "s:^:cat/:" | sed "s:$: 2:" >> $DATA/train.txt
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os
4 | import shutil
5 | import time
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim as optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | from torchvision import datasets, transforms
16 | from torch.autograd import Variable
17 | 
18 | from torch.nn.utils.clip_grad import clip_grad_norm
19 | from SketchANetModel import SketchANetModel
20 | from AlexNetModel import AlexNetModel
21 | from ResNetModel import ResNetModel
22 | 
23 | parser = argparse.ArgumentParser(description='PyTorch Sketch Classification Example')
24 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
25 |                     help='input batch size for training (default: 128)')
26 | parser.add_argument('--test-batch-size', type=int, default=10, metavar='N',
27 |                     help='input batch size for testing (default: 10)')
28 | parser.add_argument('--epochs', type=int, default=2000, metavar='N',
29 |                     help='number of epochs to train (default: 2000)')
30 | parser.add_argument('--weight_decay', type=float, default=0.0005,
31 |                     help='Adam weight decay')
32 | parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',
33 |                     help='learning rate (default: 2e-4)')
34 | parser.add_argument('--no-cuda', action='store_true', default=False,
35 |                     help='disables CUDA training')
36 | parser.add_argument('--seed', type=int, default=1, metavar='S',
37 |                     help='random seed (default: 1)')
38 | parser.add_argument('--log-interval', type=int, default=20, metavar='N',
39 |                     help='how many batches to wait before logging training status')
40 | parser.add_argument('--print-freq', '-p', default=15, type=int, metavar='N',
41 |                     help='print frequency (default: 15)')
42 | parser.add_argument('--classes', type=int, default=419,
43 |                     help='number of classes')
44 | parser.add_argument('--resume', default='', type=str,
45 |                     help='path to latest checkpoint (default: none)')
46 | parser.add_argument('--name', default='TripletNetModel', type=str,
47 |                     help='name of experiment')
48 | parser.add_argument('--normalize_feature', default=False, type=bool,
49 |                     help='normalize_feature')
50 | 
51 | best_acc = 0
52 | 
53 | 
54 | def to_scalar(vt):
55 |     """Transform a length-1 pytorch Variable or Tensor to scalar.
56 | Suppose tx is a torch Tensor with shape tx.size() = torch.Size([1]), 57 | then npx = tx.cpu().numpy() has shape (1,), not 1.""" 58 | if isinstance(vt, Variable): 59 | return vt.data.cpu().numpy().flatten()[0] 60 | if torch.is_tensor(vt): 61 | return vt.cpu().numpy().flatten()[0] 62 | raise TypeError('Input should be a variable or tensor') 63 | 64 | 65 | def main(): 66 | global args, best_acc 67 | args = parser.parse_args() 68 | args.cuda = not args.no_cuda and torch.cuda.is_available() 69 | torch.manual_seed(args.seed) 70 | if args.cuda: 71 | torch.cuda.manual_seed(args.seed) 72 | 73 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 74 | 75 | ###### DataSet ###### 76 | sketch_dir = r"/home/bc/Work/Database/TU-Berlin sketch dataset/png" 77 | # sketch_dir = r"/home/bc/Work/Database/Dogs_Cats/catdog/train" 78 | train_dataset = datasets.ImageFolder( 79 | sketch_dir, 80 | transform=transforms.Compose([ 81 | transforms.Resize([256, 256]), 82 | transforms.CenterCrop(224), 83 | transforms.RandomHorizontalFlip(), 84 | transforms.RandomVerticalFlip(), 85 | transforms.RandomRotation(45), 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 88 | std=[0.229, 0.224, 0.225]) 89 | ]) 90 | ) 91 | 92 | train_loader = torch.utils.data.DataLoader( 93 | train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs 94 | ) 95 | test_dir = r"/home/bc/Work/Database/Dogs_Cats/catdog/val" 96 | test_dataset = datasets.ImageFolder( 97 | test_dir, 98 | transform=transforms.Compose([ 99 | transforms.Resize([256, 256]), 100 | transforms.CenterCrop(224), 101 | transforms.ToTensor(), 102 | #transforms.Normalize(mean=[0.485, 0.456, 0.406], 103 | # std=[0.229, 0.224, 0.225]) 104 | ]) 105 | ) 106 | test_loader = torch.utils.data.DataLoader( 107 | test_dataset, batch_size=args.test_batch_size, shuffle=True, **kwargs 108 | ) 109 | ###### Model ###### 110 | 111 | # snet = SketchANetModel(num_classes=250) 112 | # snet = AlexNetModel(num_classes=250) 113 | snet = ResNetModel(num_classes=250) 114 | print(snet) 115 | if args.cuda: 116 | snet.cuda() 117 | 118 | if args.resume: 119 | if os.path.isfile(args.resume): 120 | print("=> loading checkpoint '{}'".format(args.resume)) 121 | checkpoint = torch.load(args.resume) 122 | args.start_epoch = checkpoint['epoch'] 123 | best_acc = checkpoint['best_prec'] 124 | snet.load_state_dict(checkpoint['state_dict']) 125 | print("=> loaded checkpoint '{}' (epoch {})" 126 | .format(args.resume, checkpoint['epoch'])) 127 | else: 128 | print("=> no checkpoint found at '{}'".format(args.resume)) 129 | 130 | cudnn.benchmark = True 131 | 132 | ###### Criteria ###### 133 | id_criterion = nn.CrossEntropyLoss() 134 | optimizer = optim.Adam(snet.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) 135 | 136 | n_parameters = sum([p.data.nelement() for p in snet.parameters()]) 137 | print(' + Number of params: {}'.format(n_parameters)) 138 | 139 | for epoch in range(1, args.epochs + 1): 140 | adjust_learning_rate(optimizer, epoch) 141 | # train for one epoch 142 | train(train_loader, snet, id_criterion, optimizer, epoch) 143 | # evaluate on validation set 144 | # prec1 = test(test_loader, snet, id_criterion, epoch) 145 | 146 | # remember best Accuracy and save checkpoint 147 | #is_best = prec1 > best_acc 148 | is_best = True 149 | #best_acc = max(prec1, best_acc) 150 | save_checkpoint({ 151 | 'epoch': epoch + 1, 152 | 'state_dict': snet.state_dict(), 153 | 'best_prec': best_acc, 154 | }, is_best) 155 | 156 | def 
train(train_loader, snet, id_criterion, optimizer, epoch):
157 |     batch_time = AverageMeter()
158 |     data_time = AverageMeter()
159 |     losses = AverageMeter()
160 |     top1 = AverageMeter()
161 |     top5 = AverageMeter()
162 | 
163 |     # switch to train mode
164 |     snet.train()
165 | 
166 |     end = time.time()
167 |     for batch_indx, (input, target) in enumerate(train_loader):
168 |         # measure data loading time
169 |         data_time.update(time.time() - end)
170 | 
171 |         if args.cuda:
172 |             input, target = input.cuda(), target.cuda()
173 |         input, target = Variable(input), Variable(target)
174 | 
175 |         # compute output
176 |         _, output = snet(input)
177 | 
178 |         # print(output.data[0])
179 |         loss = id_criterion(output, target)
180 | 
181 |         # measure accuracy and record loss
182 |         prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5))
183 |         losses.update(loss.data[0], input.size(0))
184 |         top1.update(prec1[0], input.size(0))
185 |         top5.update(prec5[0], input.size(0))
186 | 
187 |         # compute gradient and do SGD step
188 |         optimizer.zero_grad()
189 |         loss.backward()
190 |         # clip_grad_norm(snet.parameters(), 100.0)
191 |         optimizer.step()
192 | 
193 |         # measure elapsed time
194 |         batch_time.update(time.time() - end)
195 |         end = time.time()
196 | 
197 |         if batch_indx % args.print_freq == 0:
198 |             print('Epoch: [{0}][{1}/{2}]\t'
199 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
200 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
201 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
202 |                   'Prec@1 {top1.val:.4f} ({top1.avg:.4f})\t'
203 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
204 |                       epoch, batch_indx, len(train_loader), batch_time=batch_time,
205 |                       data_time=data_time, loss=losses, top1=top1, top5=top5))
206 | 
207 | def test(test_loader, snet, criterion, epoch):
208 |     batch_time = AverageMeter()
209 |     losses = AverageMeter()
210 |     top1 = AverageMeter()
211 |     top5 = AverageMeter()
212 | 
213 |     # switch to evaluate mode
214 |     snet.eval()
215 | 
216 |     end = time.time()
217 |     for batch_indx, (input, target) in enumerate(test_loader):
218 |         if args.cuda:
219 |             input, target = input.cuda(), target.cuda()
220 |         input, target = Variable(input), Variable(target)
221 | 
222 |         # compute output (the model returns features and logits)
223 |         _, output = snet(input)
224 | 
225 |         loss = criterion(output, target)
226 | 
227 |         # measure accuracy and record loss
228 |         prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5))
229 |         losses.update(loss.data[0], input.size(0))
230 |         top1.update(prec1[0], input.size(0))
231 |         top5.update(prec5[0], input.size(0))
232 | 
233 |         # measure elapsed time
234 |         batch_time.update(time.time() - end)
235 |         end = time.time()
236 | 
237 |         if batch_indx % args.print_freq == 0:
238 |             print('Test: [{0}/{1}]\t'
239 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
240 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
241 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
242 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
243 |                       batch_indx, len(test_loader), batch_time=batch_time, loss=losses,
244 |                       top1=top1, top5=top5))
245 | 
246 |     print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
247 |           .format(top1=top1, top5=top5))
248 | 
249 |     return top1.avg
250 | 
251 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
252 |     """Saves checkpoint to disk"""
253 |     directory = "runs/%s/" % (args.name)
254 |     if not os.path.exists(directory):
255 |         os.makedirs(directory)
256 |     filename = directory + filename
257 |     torch.save(state, filename)
258 |     if is_best:
259 |         shutil.copyfile(filename, 'runs/%s/' % (args.name) +
                        'model_best.pth.tar')
260 | 
261 | class AverageMeter(object):
262 |     """Computes and stores the average and current value"""
263 | 
264 |     def __init__(self):
265 |         self.reset()
266 | 
267 |     def reset(self):
268 |         self.val = 0
269 |         self.avg = 0
270 |         self.sum = 0
271 |         self.count = 0
272 | 
273 |     def update(self, val, n=1):
274 |         self.val = val
275 |         self.sum += val * n
276 |         self.count += n
277 |         self.avg = self.sum / self.count
278 | 
279 | def adjust_learning_rate(optimizer, epoch):
280 |     """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
281 |     lr = args.lr * (0.1**(epoch // 10))
282 |     for param_group in optimizer.param_groups:
283 |         param_group['lr'] = lr
284 | 
285 | def accuracy(output, target, topk=(1,)):
286 |     """Computes the precision@k for the specified values of k"""
287 |     maxk = max(topk)
288 |     batch_size = target.size(0)
289 |     _, pred = output.topk(maxk, 1, True, True)
290 |     pred = pred.t()
291 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
292 | 
293 |     res = []
294 |     for k in topk:
295 |         correct_k = correct[:k].view(-1).float().sum(0)
296 |         res.append(correct_k.mul_(100.0 / batch_size))
297 |     return res
298 | 
299 | if __name__ == '__main__':
300 |     main()
301 | 
--------------------------------------------------------------------------------
/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | 
4 | 
5 | 
6 | 
7 | 
8 | __all__ = ['AlexNet', 'alexnet']
9 | 
10 | 
11 | model_urls = {
12 |     'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
13 | }
14 | 
15 | 
16 | class AlexNet(nn.Module):
17 | 
18 |     def __init__(self, num_classes=1000):
19 |         super(AlexNet, self).__init__()
20 |         self.features = nn.Sequential(
21 |             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
22 |             nn.ReLU(inplace=True),
23 |             nn.MaxPool2d(kernel_size=3, stride=2),
24 |             nn.Conv2d(64, 192, kernel_size=5, padding=2),
25 |             nn.ReLU(inplace=True),
26 |             nn.MaxPool2d(kernel_size=3, stride=2),
27 |             nn.Conv2d(192, 384, kernel_size=3, padding=1),
28 |             nn.ReLU(inplace=True),
29 |             nn.Conv2d(384, 256, kernel_size=3, padding=1),
30 |             nn.ReLU(inplace=True),
31 |             nn.Conv2d(256, 256, kernel_size=3, padding=1),
32 |             nn.ReLU(inplace=True),
33 |             nn.MaxPool2d(kernel_size=3, stride=2),
34 |         )
35 |         self.classifier = nn.Sequential(
36 |             nn.Dropout(),
37 |             nn.Linear(256 * 6 * 6, 4096),
38 |             nn.ReLU(inplace=True),
39 |             nn.Dropout(),
40 |             nn.Linear(4096, 4096),
41 |             nn.ReLU(inplace=True),
42 |             nn.Linear(4096, num_classes),
43 |         )
44 | 
45 |     def forward(self, x):
46 |         x = self.features(x)
47 |         x = x.view(x.size(0), 256 * 6 * 6)
48 |         x = self.classifier(x)
49 |         return x
50 | 
51 | 
52 | def alexnet(pretrained=False, **kwargs):
53 |     r"""AlexNet model architecture from the
54 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
55 | 56 | Args: 57 | pretrained (bool): If True, returns a model pre-trained on ImageNet 58 | """ 59 | model = AlexNet(**kwargs) 60 | if pretrained: 61 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 62 | 63 | new_classifier = nn.Sequential(*list(model.classifier.children())[:-1]) 64 | model.classifier = new_classifier 65 | return model -------------------------------------------------------------------------------- /filelist_data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | from __future__ import print_function 5 | 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | import os 9 | import os.path as osp 10 | import sys 11 | import json 12 | import torch.utils.data 13 | import torchvision.transforms as transforms 14 | 15 | 16 | def default_image_loader(path): 17 | # return plt.imread(path) 18 | return Image.open(path).convert('RGB') 19 | 20 | class SketchImageLoader(torch.utils.data.Dataset): 21 | def __init__(self, base_path, filelist_filename, mode="train", transform=None, loader=default_image_loader): 22 | pass 23 | 24 | def __getitem__(self, index): 25 | pass 26 | 27 | def __len__(self): 28 | pass 29 | 30 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.nn import functional as F 5 | from itertools import chain 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | 
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, layers): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. / n)) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(planes * block.expansion), 125 | ) 126 | 127 | layers = [] 128 | layers.append(block(self.inplanes, planes, stride, downsample)) 129 | self.inplanes = planes * block.expansion 130 | for i in range(1, blocks): 131 | layers.append(block(self.inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): python 2.7.12 153 | for key, value in list(state_dict.items()): #python 3.5.4 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False): 160 | """Constructs a ResNet-18 model. 161 | 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 166 | if pretrained: 167 | print("model_urls['resnet18']: {}".format(model_urls['resnet18'])) 168 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 169 | return model 170 | 171 | 172 | def resnet34(pretrained=False): 173 | """Constructs a ResNet-34 model. 
174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 179 | if pretrained: 180 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 181 | return model 182 | 183 | 184 | def resnet50(pretrained=False): 185 | """Constructs a ResNet-50 model. 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 191 | if pretrained: 192 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 193 | return model 194 | 195 | 196 | def resnet101(pretrained=False): 197 | """Constructs a ResNet-101 model. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 203 | if pretrained: 204 | model.load_state_dict( 205 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 206 | return model 207 | 208 | 209 | def resnet152(pretrained=False): 210 | """Constructs a ResNet-152 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 216 | if pretrained: 217 | model.load_state_dict( 218 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 219 | return model 220 | -------------------------------------------------------------------------------- /resume_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python Train.py \ 4 | --batch-size 128 \ 5 | --resume ./runs/NetModel/checkpoint.pth.tar -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python Train.py \ 4 | --batch-size 128 -------------------------------------------------------------------------------- /sketchanet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | model_paths ={ 7 | 'sketchanet': '', 8 | } 9 | 10 | class SketchANet(nn.Module): 11 | def __init__(self, num_classes=250): 12 | super(SketchANet, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(3, 64, kernel_size=15, stride=3, padding=0), 15 | nn.ReLU(inplace=True), 16 | nn.MaxPool2d(kernel_size=3, stride=2), 17 | nn.Conv2d(64, 128, kernel_size=5, padding=0), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=3, stride=2), 20 | nn.Conv2d(128, 256, kernel_size=3, padding=1), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.MaxPool2d(kernel_size=3, stride=2), 27 | ) 28 | self.classifier = nn.Sequential( 29 | nn.Linear(256 * 6 * 6, 512), 30 | nn.ReLU(inplace=True), 31 | nn.Dropout(), 32 | nn.Linear(512, 512), 33 | nn.ReLU(inplace=True), 34 | nn.Dropout(), 35 | nn.Linear(512, num_classes), 36 | ) 37 | 38 | def forward(self, x): 39 | x = self.conv(x) 40 | x = x.view(x.size(0), 256 * 6 * 6) 41 | x = self.classifier(x) 42 | return x 43 | 44 | def sketchanet(pretrained=False, **kwargs): 45 | model = SketchANet(**kwargs) 46 | if pretrained: 47 | 
        model.load_state_dict(torch.load(model_paths['sketchanet']))
48 | 
49 |     new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # drop the final FC layer so forward() returns 512-d features
50 |     model.classifier = new_classifier
51 |     return model
--------------------------------------------------------------------------------
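
For completeness, a minimal sketch of restoring a saved checkpoint for evaluation. The dictionary keys match `save_checkpoint` in Train.py; the run directory depends on the `--name` flag, so the path below is only illustrative:

```python
import torch
from ResNetModel import ResNetModel

checkpoint = torch.load('runs/TripletNetModel/model_best.pth.tar', map_location='cpu')
model = ResNetModel(num_classes=250)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
print('restored epoch {} (best Prec@1 {:.3f})'.format(checkpoint['epoch'], checkpoint['best_prec']))
```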