├── .gitignore
├── AlexNetModel.py
├── README.md
├── ResNetModel.py
├── SketchANetModel.py
├── Tools
│   ├── GetImageMean_Std.py
│   ├── ListAllImageName.py
│   ├── SplitDataset.py
│   └── create_filelist.sh
├── Train.py
├── alexnet.py
├── filelist_data_loader.py
├── resnet.py
├── resume_train.sh
├── run_train.sh
└── sketchanet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | runs/
3 | model.py
--------------------------------------------------------------------------------
/AlexNetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | from alexnet import alexnet
7 |
8 | class AlexNetModel(nn.Module):
9 | def __init__(self, num_classes=None):
10 | super(AlexNetModel, self).__init__()
11 | self.base = alexnet(pretrained=False)
12 |
13 | planes = 4096
14 | if num_classes is not None:
15 | self.fc = nn.Linear(planes, num_classes)
16 | init.normal(self.fc.weight, std=0.001)
17 | init.constant(self.fc.bias, 0.1)
18 |
19 | def forward(self, x):
20 | feat = self.base(x)
21 |
22 | if hasattr(self, 'fc'):
23 | logits = self.fc(feat)
24 | return feat, logits
25 |
26 | return feat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sketch Classification
2 | A PyTorch implementation of sketch classification networks.
3 |
4 | ## Model Configuration
5 | - Optimizer
6 | - Adam
7 | ## Dataset
8 | The TU-Berlin sketch dataset (250 categories).
9 |
10 | | Model | input_size |
11 | | ------------ | ----------- |
12 | | (raw size)* | 1111 * 1111 |
13 | | AlexNet | 224 * 224 |
14 | | SketchANet | 225 * 225 |
15 | | ResNet18 | 224 * 224 |
16 | | ResNet34 | 224 * 224 |
17 | | ResNet50 | 224 * 224 |
18 | | DenseNet121 | 224 * 224 |
19 | | Inception_v3 | 299 * 299 |
20 |
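For the 224 * 224 rows, the raw sketches are first resized and then cropped to the network input, as in `Train.py`. Below is a minimal sketch of that pipeline; the dataset path is a placeholder, and `Train.py` additionally applies vertical flips and random rotations:

```python
import torch
from torchvision import datasets, transforms

# Resize the raw 1111 * 1111 sketches, then crop to the network's input size (224 here).
train_transform = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Placeholder path: point this at your local copy of the TU-Berlin png folder.
train_dataset = datasets.ImageFolder("/path/to/TU-Berlin/png", transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
```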
21 |
22 | ## Model Parameters
23 | | Model | lr | clip_grad_norm(max_norm) | learning rate decay | weight_decay |
24 | | ------------------------ | ---- | ------------------------ | ------------------- | --------------- |
25 | | AlexNet(pretrained) | 2e-4 | -- | 20 | 0.0005 |
26 | | AlexNet(scratch) | 2e-5 | 0.5 - 100.0 | 30 | 0.0005 |
27 | | SketchANet(DogsCats)* | 2e-5 | 0.5 - 1.0 | 30 | 0.0005 |
28 | | SketchANet(scratch) | 2e-5 | 0.5 - 100.0 | 800 | 0.0001 - 0.0003 |
29 | | ResNet18(pretrained) | 2e-4 | -- | 20 | 0.0005 |
30 | | ResNet34(pretrained) | 2e-4 | -- | 20 | 0.0001 |
31 | | ResNet50(pretrained) | 2e-4 | -- | 20 | 0.0005 |
32 | | DenseNet121(pretrained) | 2e-4 | -- | 20 | 0.0005 |
33 | | Inception_v3(pretrained) | 2e-4 | -- | 30 | 0.0005 |
34 | \* SketchANet(DogsCats) was only used to sanity-check the model, not as a TU-Berlin result.
35 |
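These hyperparameters map onto the Adam setup in `Train.py` (where the gradient-clipping call is currently commented out). A minimal, self-contained sketch using the AlexNet(scratch) row; the stand-in linear model and random batch are not part of the repo and exist only so the snippet runs:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_  # Train.py imports the older clip_grad_norm alias

# Stand-in model and batch, only so the snippet is runnable on its own.
model = nn.Linear(512, 250)
logits = model(torch.randn(8, 512))
loss = nn.CrossEntropyLoss()(logits, torch.randint(0, 250, (8,)))

# lr and weight_decay from the AlexNet(scratch) row above; betas match Train.py.
optimizer = optim.Adam(model.parameters(), lr=2e-5, betas=(0.9, 0.99), weight_decay=0.0005)

optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=100.0)  # the "clip_grad_norm(max_norm)" column
optimizer.step()
```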
36 | ## Model Result
37 | ### Train Set
38 | | Model | Prec@1 | Prec@5 |
39 | | ------------------------ | ------- | ------ |
40 | | AlexNet(pretrained) | 93.4455 | 99.787 |
41 | | AlexNet(scratch) | 99.3024 | 99.988 |
42 | | SketchANet(scratch) | 86.3166 | 98.667 |
43 | | ResNet18(pretrained) | 96.9899 | 99.954 |
44 | | ResNet34(pretrained) | 97.1048 | 99.954 |
45 | | ResNet50(pretrained) | 98.3049 | 99.988 |
46 | | DenseNet121(pretrained) | 91.4301 | 99.596 |
47 | | Inception_v3(pretrained) | 91.8802 | 99.706 |
48 |
49 |
50 | ### Test Set
51 | | Model | Prec@1 | Prec@5 |
52 | | ------------------------ | ------ | ------ |
53 | | Human | 73.1 | -- |
54 | | AlexNet [1]              | 68.6   | --     |
55 | | AlexNet [2]              | 77.29  | --     |
56 | | GoogLeNet [2]            | 80.85  | --     |
57 | | AlexNet(pretrained) | 70.850 | 90.050 |
58 | | AlexNet(scratch) | 53.850 | 78.000 |
59 | | SketchANet(scratch) | 68.700 | 88.900 |
60 | | ResNet18(pretrained) | 77.800 | 94.650 |
61 | | ResNet34(pretrained) | 79.100 | 95.050 |
62 | | ResNet50(pretrained) | 78.300 | 95.300 |
63 | | DenseNet121(pretrained) | 77.550 | 93.500 |
64 | | Inception_v3(pretrained) | 76.550 | 93.750 |
65 |
66 | 1. *Sketch-a-Net that Beats Humans*
67 | 2. *The Sketchy Database: Learning to Retrieve Badly Drawn Bunnies*
68 |
69 | Other architectures that could be added: DPN, ShuffleNetG2, SENet18.
70 | ## Tools
71 | - GetImageMean_Std
72 |
73 |   Compute the per-channel mean and standard deviation of an image dataset (see the usage sketch below).
74 |
75 | - SplitDataset
76 |
77 |   Split an image dataset into train and val folders according to the train/val record txt files.
78 |
79 | - ListAllImageName
80 |
81 |   List all image names in a dataset and write them to a txt file.
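For example, `GetImageMean_Std` is normally run directly after editing the hard-coded `rootdir` in its `__main__` block, but the function can also be imported. A minimal sketch; the dataset path is a placeholder:

```python
import sys
sys.path.append("Tools")  # make the standalone scripts importable

from GetImageMean_Std import GetImageMean

# Placeholder path: point this at your dataset root.
stats = GetImageMean("/path/to/TU-Berlin/png", size=(256, 256))
# Grayscale sketches yield (mean, std); RGB images yield (R_mean, G_mean, B_mean, R_std, G_std, B_std).
print(stats)
```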
--------------------------------------------------------------------------------
/ResNetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | from resnet import resnet18, resnet34, resnet50
7 |
8 | class ResNetModel(nn.Module):
9 | def __init__(self, num_classes=None):
10 | super(ResNetModel, self).__init__()
11 | self.base = resnet34(pretrained=True)
12 |
13 | planes = 512
14 |
15 | if num_classes is not None:
16 | self.fc = nn.Linear(planes, num_classes)
17 | init.xavier_uniform(self.fc.weight)
18 | init.constant(self.fc.bias, 0.1)
19 |
20 | def forward(self, x):
21 | # shape [N, C, H, W]
22 | feat = self.base(x)
23 | global_feat = F.avg_pool2d(feat, feat.size()[2:])
24 | # shape [N, C]
25 | global_feat = global_feat.view(global_feat.size(0), -1)
26 |
27 | if hasattr(self, 'fc'):
28 | logits = self.fc(global_feat)
29 | return global_feat, logits
30 |
31 | # return global_feat, local_feat
32 | return global_feat
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/SketchANetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | from sketchanet import sketchanet
7 |
8 | class SketchANetModel(nn.Module):
9 | def __init__(self, num_classes=None):
10 | super(SketchANetModel, self).__init__()
11 | self.base = sketchanet(pretrained=False)
12 |
13 | planes = 512
14 | if num_classes is not None:
15 | self.fc = nn.Linear(planes, num_classes)
16 | init.normal(self.fc.weight, std=0.001)
17 | init.constant(self.fc.bias, 0.1)
18 |
19 | def forward(self, x):
20 | feat = self.base(x)
21 |
22 | if hasattr(self, 'fc'):
23 | logits = self.fc(feat)
24 | return feat, logits
25 |
26 | return feat
27 |
--------------------------------------------------------------------------------
/Tools/GetImageMean_Std.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import os.path as osp
6 | from progressbar import Bar, ProgressBar
7 | from scipy.misc import imread, imresize
8 |
9 | def is_image(ext):
10 | ext = ext.lower()
11 | if ext == '.jpg':
12 | return True
13 | elif ext == '.png':
14 | return True
15 | elif ext == '.jpeg':
16 | return True
17 | elif ext == '.bmp':
18 | return True
19 | else:
20 | return False
21 |
22 | def get_all_image_names(rootdir, image_names_list=[]):
23 | for file in os.listdir(rootdir):
24 | filepath = osp.join(rootdir, file)
25 | if osp.isdir(filepath):
26 | get_all_image_names(filepath, image_names_list)
27 | elif osp.isfile(filepath):
28 | ext = osp.splitext(filepath)[1]
29 | if is_image(ext):
30 | image_names_list.append(filepath)
31 | image_names_list = sorted(image_names_list)
32 | return image_names_list
33 |
34 | def GetImageMean(rootdir, size=(256, 256)):
35 | R_channel = []
36 | G_channel = []
37 | B_channel = []
38 | image_names_list = get_all_image_names(rootdir)
39 | progress = ProgressBar(max_value= len(image_names_list))
40 | for i, name in enumerate(image_names_list):
41 | img = imread(name)
42 | img = imresize(img, size)
43 |         if img.ndim == 3:
44 | R_channel.append(img[:, :, 0])
45 | G_channel.append(img[:, :, 1])
46 | B_channel.append(img[:, :, 2])
47 | else:
48 | R_channel.append(img[:, :])
49 |
50 | progress.update(i)
51 | progress.finish()
52 |
53 | # num = len(image_names_list) * size[0] * size[1]
54 |
55 |     if img.ndim == 3:
56 | R_mean = np.mean(np.asarray(R_channel))
57 | G_mean = np.mean(np.asarray(G_channel))
58 | B_mean = np.mean(np.asarray(B_channel))
59 |
60 | R_std = np.std(np.asarray(R_channel))
61 | G_std = np.std(np.asarray(G_channel))
62 | B_std = np.std(np.asarray(B_channel))
63 | return R_mean, G_mean, B_mean, R_std, G_std, B_std
64 | else:
65 | R_mean = np.mean(np.asarray(R_channel))
66 | R_std = np.std(np.asarray(R_channel))
67 | return R_mean, R_std
68 |
69 |
70 | if __name__ == "__main__":
71 | rootdir = r"/home/bc/Work/Database/TU-Berlin sketch dataset/png"
72 | mean = GetImageMean(rootdir)
73 | print(mean)
--------------------------------------------------------------------------------
/Tools/ListAllImageName.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import os.path as osp
6 | from scipy.misc import imread, imresize
7 |
8 | def is_image(ext):
9 | ext = ext.lower()
10 | if ext == '.jpg':
11 | return True
12 | elif ext == '.png':
13 | return True
14 | elif ext == '.jpeg':
15 | return True
16 | elif ext == '.bmp':
17 | return True
18 | else:
19 | return False
20 |
21 | def get_all_image_names(rootdir, image_names_list=[]):
22 | for file in os.listdir(rootdir):
23 | filepath = osp.join(rootdir, file)
24 | if osp.isdir(filepath):
25 | get_all_image_names(filepath, image_names_list)
26 | elif osp.isfile(filepath):
27 | ext = osp.splitext(filepath)[1]
28 | if is_image(ext):
29 | image_names_list.append(osp.join(osp.split(rootdir)[1], file))
30 | image_names_list = sorted(image_names_list)
31 | return image_names_list
32 |
33 | def save_image_list(image_names_list, save_filename):
34 | f = open(save_filename, 'w')
35 | image_names_list = [line+'\n' for line in image_names_list]
36 | f.writelines(image_names_list)
37 | f.close()
38 |
39 | if __name__ == "__main__":
40 | data_root=r"/home/bc/Work/Database/TU-Berlin sketch dataset/png"
41 | save_filename=r"./train.txt"
42 | image_names_list = get_all_image_names(data_root)
43 | save_image_list(image_names_list, save_filename)
44 |
45 |
46 |
--------------------------------------------------------------------------------
/Tools/SplitDataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import sys
4 | import os
5 | import os.path as osp
6 | import shutil
7 |
8 | def copyfile(srcfile, dstfile):
9 | if not osp.isfile(srcfile):
10 |         print("%s does not exist!" % srcfile)
11 | else:
12 | fpath, fname = osp.split(dstfile)
13 | if not osp.exists(fpath):
14 | os.makedirs(fpath)
15 | shutil.copyfile(srcfile, dstfile)
16 |
17 | def movetodir(record_file, Data_root, dataset_path):
18 | with open(record_file, 'r') as f:
19 | for line in f:
20 | src_img_path = osp.join(Data_root, line.rstrip())
21 | dst_img_path = osp.join(dataset_path, line.rstrip())
22 | copyfile(src_img_path, dst_img_path)
23 |
24 | def main():
25 | Data_root = "/home/bc/Work/Database/TU-Berlin sketch dataset/png"
26 | Train_record_file = "/home/bc/Work/Database/TU-Berlin sketch dataset/png/train_list.txt"
27 | Val_record_file = "/home/bc/Work/Database/TU-Berlin sketch dataset/png/val_list.txt"
28 |
29 |
30 | Dataset_train_path = osp.join(Data_root, "../train_val/train")
31 | if not osp.exists(Dataset_train_path):
32 | os.makedirs(Dataset_train_path)
33 |
34 | Dataset_val_path = osp.join(Data_root, "../train_val/val")
35 | if not osp.exists(Dataset_val_path):
36 | os.makedirs(Dataset_val_path)
37 |
38 | movetodir(Train_record_file, Data_root, Dataset_train_path)
39 | movetodir(Val_record_file, Data_root, Dataset_val_path)
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/Tools/create_filelist.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATA=dataset/images
4 | echo "Create train.txt..."
5 | rm -f $DATA/train.txt
6 | ls $DATA/bike | sed "s:^:bike/:" | sed "s:$: 1:" >> $DATA/train.txt
7 | ls $DATA/cat | sed "s:^:cat/:" | sed "s:$: 2:" >> $DATA/train.txt
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os
4 | import shutil
5 | import time
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.distributed as dist
12 | import torch.optim as optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | from torchvision import datasets, transforms
16 | from torch.autograd import Variable
17 | from torch.nn.utils.clip_grad import clip_grad_norm
18 |
19 | from SketchANetModel import SketchANetModel
20 | from AlexNetModel import AlexNetModel
21 | from ResNetModel import ResNetModel
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch Sketch Classification Example')
24 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
25 |                     help='input batch size for training (default: 128)')
26 | parser.add_argument('--test-batch-size', type=int, default=10, metavar='N',
27 |                     help='input batch size for testing (default: 10)')
28 | parser.add_argument('--epochs', type=int, default=2000, metavar='N',
29 |                     help='number of epochs to train (default: 2000)')
30 | parser.add_argument('--weight_decay', type=float, default=0.0005,
31 |                     help='Adam weight decay')
32 | parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',
33 |                     help='learning rate (default: 2e-4)')
34 | parser.add_argument('--no-cuda', action='store_true', default=False,
35 |                     help='disables CUDA training')
36 | parser.add_argument('--seed', type=int, default=1, metavar='S',
37 |                     help='random seed (default: 1)')
38 | parser.add_argument('--log-interval', type=int, default=20, metavar='N',
39 |                     help='how many batches to wait before logging training status')
40 | parser.add_argument('--print-freq', '-p', default=15, type=int, metavar='N',
41 |                     help='print frequency (default: 15)')
42 | parser.add_argument('--classes', type=int, default=419,
43 | help='number of classes')
44 | parser.add_argument('--resume', default='', type=str,
45 | help='path to latest checkpoint (default: none)')
46 | parser.add_argument('--name', default='NetModel', type=str,
47 |                     help='name of experiment')
48 | parser.add_argument('--normalize_feature', default=False, type=bool,
49 | help='normalize_feature')
50 |
51 | best_acc = 0
52 |
53 |
54 | def to_scalar(vt):
55 | """Transform a length-1 pytorch Variable or Tensor to scalar.
56 | Suppose tx is a torch Tensor with shape tx.size() = torch.Size([1]),
57 | then npx = tx.cpu().numpy() has shape (1,), not 1."""
58 | if isinstance(vt, Variable):
59 | return vt.data.cpu().numpy().flatten()[0]
60 | if torch.is_tensor(vt):
61 | return vt.cpu().numpy().flatten()[0]
62 | raise TypeError('Input should be a variable or tensor')
63 |
64 |
65 | def main():
66 | global args, best_acc
67 | args = parser.parse_args()
68 | args.cuda = not args.no_cuda and torch.cuda.is_available()
69 | torch.manual_seed(args.seed)
70 | if args.cuda:
71 | torch.cuda.manual_seed(args.seed)
72 |
73 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
74 |
75 | ###### DataSet ######
76 | sketch_dir = r"/home/bc/Work/Database/TU-Berlin sketch dataset/png"
77 | # sketch_dir = r"/home/bc/Work/Database/Dogs_Cats/catdog/train"
78 | train_dataset = datasets.ImageFolder(
79 | sketch_dir,
80 | transform=transforms.Compose([
81 | transforms.Resize([256, 256]),
82 | transforms.CenterCrop(224),
83 | transforms.RandomHorizontalFlip(),
84 | transforms.RandomVerticalFlip(),
85 | transforms.RandomRotation(45),
86 | transforms.ToTensor(),
87 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 | std=[0.229, 0.224, 0.225])
89 | ])
90 | )
91 |
92 | train_loader = torch.utils.data.DataLoader(
93 | train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs
94 | )
95 | test_dir = r"/home/bc/Work/Database/Dogs_Cats/catdog/val"
96 | test_dataset = datasets.ImageFolder(
97 | test_dir,
98 | transform=transforms.Compose([
99 | transforms.Resize([256, 256]),
100 | transforms.CenterCrop(224),
101 | transforms.ToTensor(),
102 | #transforms.Normalize(mean=[0.485, 0.456, 0.406],
103 | # std=[0.229, 0.224, 0.225])
104 | ])
105 | )
106 | test_loader = torch.utils.data.DataLoader(
107 | test_dataset, batch_size=args.test_batch_size, shuffle=True, **kwargs
108 | )
109 | ###### Model ######
110 |
111 | # snet = SketchANetModel(num_classes=250)
112 | # snet = AlexNetModel(num_classes=250)
113 | snet = ResNetModel(num_classes=250)
114 | print(snet)
115 | if args.cuda:
116 | snet.cuda()
117 |
118 | if args.resume:
119 | if os.path.isfile(args.resume):
120 | print("=> loading checkpoint '{}'".format(args.resume))
121 | checkpoint = torch.load(args.resume)
122 | args.start_epoch = checkpoint['epoch']
123 | best_acc = checkpoint['best_prec']
124 | snet.load_state_dict(checkpoint['state_dict'])
125 | print("=> loaded checkpoint '{}' (epoch {})"
126 | .format(args.resume, checkpoint['epoch']))
127 | else:
128 | print("=> no checkpoint found at '{}'".format(args.resume))
129 |
130 | cudnn.benchmark = True
131 |
132 | ###### Criteria ######
133 | id_criterion = nn.CrossEntropyLoss()
134 | optimizer = optim.Adam(snet.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
135 |
136 | n_parameters = sum([p.data.nelement() for p in snet.parameters()])
137 | print(' + Number of params: {}'.format(n_parameters))
138 |
139 | for epoch in range(1, args.epochs + 1):
140 | adjust_learning_rate(optimizer, epoch)
141 | # train for one epoch
142 | train(train_loader, snet, id_criterion, optimizer, epoch)
143 | # evaluate on validation set
144 | # prec1 = test(test_loader, snet, id_criterion, epoch)
145 |
146 | # remember best Accuracy and save checkpoint
147 | #is_best = prec1 > best_acc
148 | is_best = True
149 | #best_acc = max(prec1, best_acc)
150 | save_checkpoint({
151 | 'epoch': epoch + 1,
152 | 'state_dict': snet.state_dict(),
153 | 'best_prec': best_acc,
154 | }, is_best)
155 |
156 | def train(train_loader, snet, id_criterion, optimizer, epoch):
157 | batch_time = AverageMeter()
158 | data_time = AverageMeter()
159 | losses = AverageMeter()
160 | top1 = AverageMeter()
161 | top5 = AverageMeter()
162 |
163 | # switch to train mode
164 | snet.train()
165 |
166 | end = time.time()
167 | for batch_indx, (input, target) in enumerate(train_loader):
168 | # measure data loading time
169 | data_time.update(time.time() - end)
170 |
171 | if args.cuda:
172 | input, target = input.cuda(), target.cuda()
173 | input, target = Variable(input), Variable(target)
174 |
175 | # compute output
176 | _, output = snet(input)
177 |
178 | # print(output.data[0])
179 | loss = id_criterion(output, target)
180 |
181 | # measure accuracy and record loss
182 | prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5))
183 | losses.update(loss.data[0], input.size(0))
184 | top1.update(prec1[0], input.size(0))
185 | top5.update(prec5[0], input.size(0))
186 |
187 | # compute gradient and do SGD step
188 | optimizer.zero_grad()
189 | loss.backward()
190 | # clip_grad_norm(snet.parameters(), 100.0)
191 | optimizer.step()
192 |
193 | # measure elapsed time
194 | batch_time.update(time.time() - end)
195 | end = time.time()
196 |
197 | if batch_indx % args.print_freq == 0:
198 | print('Epoch: [{0}][{1}/{2}]\t'
199 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
200 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
201 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
202 | 'Prec@1 {top1.val:.4f} ({top1.avg:.4f})\t'
203 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
204 | epoch, batch_indx, len(train_loader), batch_time=batch_time,
205 | data_time=data_time, loss=losses, top1=top1, top5=top5))
206 |
207 | def test(test_loader, snet, criterion, epoch):
208 | batch_time = AverageMeter()
209 | losses = AverageMeter()
210 | top1 = AverageMeter()
211 | top5 = AverageMeter()
212 |
213 | # switch to evaluate mode
214 | snet.eval()
215 |
216 | end = time.time()
217 | for batch_indx, (input, target) in enumerate(test_loader):
218 | if args.cuda:
219 | input, target = input.cuda(), target.cuda()
220 | input, target = Variable(input), Variable(target)
221 |
222 | # compute output
223 |         # the model returns (features, logits); only the logits feed the criterion
224 |         _, output = snet(input)
225 | loss = criterion(output, target)
226 |
227 | # measure accuracy and record loss
228 | prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5))
229 | losses.update(loss.data[0], input.size(0))
230 | top1.update(prec1[0], input.size(0))
231 | top5.update(prec5[0], input.size(0))
232 |
233 | # measure elapsed time
234 | batch_time.update(time.time() - end)
235 | end = time.time()
236 |
237 | if batch_indx % args.print_freq == 0:
238 | print('Test: [{0}/{1}]\t'
239 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
240 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
241 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
242 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
243 | batch_indx, len(test_loader), batch_time=batch_time, loss=losses,
244 | top1=top1, top5=top5))
245 |
246 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
247 | .format(top1=top1, top5=top5))
248 |
249 | return top1.avg
250 |
251 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
252 | """Saves checkpoint to disk"""
253 | directory = "runs/%s/" % (args.name)
254 | if not os.path.exists(directory):
255 | os.makedirs(directory)
256 | filename = directory + filename
257 | torch.save(state, filename)
258 | if is_best:
259 | shutil.copyfile(filename, 'runs/%s/' % (args.name) + 'model_best.pth.tar')
260 |
261 | class AverageMeter(object):
262 | """Computes and stores the average and current value"""
263 |
264 | def __init__(self):
265 | self.reset()
266 |
267 | def reset(self):
268 | self.val = 0
269 | self.avg = 0
270 | self.sum = 0
271 | self.count = 0
272 |
273 | def update(self, val, n=1):
274 | self.val = val
275 | self.sum += val * n
276 | self.count += n
277 | self.avg = self.sum / self.count
278 |
279 | def adjust_learning_rate(optimizer, epoch):
280 |     """Sets the learning rate to the initial LR decayed by 10 every 10 epochs."""
281 |     lr = args.lr * (0.1 ** (epoch // 10))
282 |     for param_group in optimizer.param_groups:  # state_dict() only returns a copy
283 |         param_group['lr'] = lr
284 |
285 | def accuracy(output, target, topk=(1,)):
286 | """Computes the precision@k for the specified values of k"""
287 | maxk = max(topk)
288 | batch_size = target.size(0)
289 | _, pred = output.topk(maxk, 1, True, True)
290 | pred = pred.t()
291 | correct = pred.eq(target.view(1, -1).expand_as(pred))
292 |
293 | res = []
294 | for k in topk:
295 | correct_k = correct[:k].view(-1).float().sum(0)
296 | res.append(correct_k.mul_(100.0 / batch_size))
297 | return res
298 |
299 | if __name__ == '__main__':
300 | main()
301 |
--------------------------------------------------------------------------------
/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 |
6 |
7 |
8 | __all__ = ['AlexNet', 'alexnet']
9 |
10 |
11 | model_urls = {
12 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
13 | }
14 |
15 |
16 | class AlexNet(nn.Module):
17 |
18 | def __init__(self, num_classes=1000):
19 | super(AlexNet, self).__init__()
20 | self.features = nn.Sequential(
21 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
25 | nn.ReLU(inplace=True),
26 | nn.MaxPool2d(kernel_size=3, stride=2),
27 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
32 | nn.ReLU(inplace=True),
33 | nn.MaxPool2d(kernel_size=3, stride=2),
34 | )
35 | self.classifier = nn.Sequential(
36 | nn.Dropout(),
37 | nn.Linear(256 * 6 * 6, 4096),
38 | nn.ReLU(inplace=True),
39 | nn.Dropout(),
40 | nn.Linear(4096, 4096),
41 | nn.ReLU(inplace=True),
42 | nn.Linear(4096, num_classes),
43 | )
44 |
45 | def forward(self, x):
46 | x = self.features(x)
47 | x = x.view(x.size(0), 256 * 6 * 6)
48 | x = self.classifier(x)
49 | return x
50 |
51 |
52 | def alexnet(pretrained=False, **kwargs):
53 |     r"""AlexNet model architecture from the
54 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
55 |
56 | Args:
57 | pretrained (bool): If True, returns a model pre-trained on ImageNet
58 | """
59 | model = AlexNet(**kwargs)
60 | if pretrained:
61 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
62 |
63 | new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
64 | model.classifier = new_classifier
65 | return model
--------------------------------------------------------------------------------
/filelist_data_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 |
4 | from __future__ import print_function
5 |
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 | import os
9 | import os.path as osp
10 | import sys
11 | import json
12 | import torch.utils.data
13 | import torchvision.transforms as transforms
14 |
15 |
16 | def default_image_loader(path):
17 | # return plt.imread(path)
18 | return Image.open(path).convert('RGB')
19 |
20 | class SketchImageLoader(torch.utils.data.Dataset):
21 |     """Filelist-backed sketch dataset.
22 |
23 |     Assumes each line of `filelist_filename` is "relative/image/path<space>label",
24 |     as produced by Tools/create_filelist.sh.
25 |     """
26 |     def __init__(self, base_path, filelist_filename, mode="train", transform=None, loader=default_image_loader):
27 |         self.base_path, self.mode, self.transform, self.loader = base_path, mode, transform, loader
28 |         with open(filelist_filename, 'r') as f:
29 |             self.samples = [(l.split()[0], int(l.split()[1])) for l in f if l.strip()]
30 |
31 |     def __getitem__(self, index):
32 |         path, label = self.samples[index]
33 |         img = self.loader(osp.join(self.base_path, path))
34 |         return (self.transform(img) if self.transform is not None else img), label
35 |
36 |     def __len__(self):
37 |         return len(self.samples)
38 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from torch.nn import functional as F
5 | from itertools import chain
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152']
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | "3x3 convolution with padding"
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 | def __init__(self, block, layers):
98 | self.inplanes = 64
99 | super(ResNet, self).__init__()
100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
101 | bias=False)
102 | self.bn1 = nn.BatchNorm2d(64)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | self.layer1 = self._make_layer(block, 64, layers[0])
106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
109 |
110 | for m in self.modules():
111 | if isinstance(m, nn.Conv2d):
112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
113 | m.weight.data.normal_(0, math.sqrt(2. / n))
114 | elif isinstance(m, nn.BatchNorm2d):
115 | m.weight.data.fill_(1)
116 | m.bias.data.zero_()
117 |
118 | def _make_layer(self, block, planes, blocks, stride=1):
119 | downsample = None
120 | if stride != 1 or self.inplanes != planes * block.expansion:
121 | downsample = nn.Sequential(
122 | nn.Conv2d(self.inplanes, planes * block.expansion,
123 | kernel_size=1, stride=stride, bias=False),
124 | nn.BatchNorm2d(planes * block.expansion),
125 | )
126 |
127 | layers = []
128 | layers.append(block(self.inplanes, planes, stride, downsample))
129 | self.inplanes = planes * block.expansion
130 | for i in range(1, blocks):
131 | layers.append(block(self.inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 |
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool(x)
141 |
142 | x = self.layer1(x)
143 | x = self.layer2(x)
144 | x = self.layer3(x)
145 | x = self.layer4(x)
146 |
147 | return x
148 |
149 |
150 | def remove_fc(state_dict):
151 | """Remove the fc layer parameters from state_dict."""
152 | # for key, value in state_dict.items(): python 2.7.12
153 | for key, value in list(state_dict.items()): #python 3.5.4
154 | if key.startswith('fc.'):
155 | del state_dict[key]
156 | return state_dict
157 |
158 |
159 | def resnet18(pretrained=False):
160 | """Constructs a ResNet-18 model.
161 |
162 | Args:
163 | pretrained (bool): If True, returns a model pre-trained on ImageNet
164 | """
165 | model = ResNet(BasicBlock, [2, 2, 2, 2])
166 | if pretrained:
167 | print("model_urls['resnet18']: {}".format(model_urls['resnet18']))
168 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18'])))
169 | return model
170 |
171 |
172 | def resnet34(pretrained=False):
173 | """Constructs a ResNet-34 model.
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | """
178 | model = ResNet(BasicBlock, [3, 4, 6, 3])
179 | if pretrained:
180 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
181 | return model
182 |
183 |
184 | def resnet50(pretrained=False):
185 | """Constructs a ResNet-50 model.
186 |
187 | Args:
188 | pretrained (bool): If True, returns a model pre-trained on ImageNet
189 | """
190 | model = ResNet(Bottleneck, [3, 4, 6, 3])
191 | if pretrained:
192 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
193 | return model
194 |
195 |
196 | def resnet101(pretrained=False):
197 | """Constructs a ResNet-101 model.
198 |
199 | Args:
200 | pretrained (bool): If True, returns a model pre-trained on ImageNet
201 | """
202 | model = ResNet(Bottleneck, [3, 4, 23, 3])
203 | if pretrained:
204 | model.load_state_dict(
205 | remove_fc(model_zoo.load_url(model_urls['resnet101'])))
206 | return model
207 |
208 |
209 | def resnet152(pretrained=False):
210 | """Constructs a ResNet-152 model.
211 |
212 | Args:
213 | pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | """
215 | model = ResNet(Bottleneck, [3, 8, 36, 3])
216 | if pretrained:
217 | model.load_state_dict(
218 | remove_fc(model_zoo.load_url(model_urls['resnet152'])))
219 | return model
220 |
--------------------------------------------------------------------------------
/resume_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python Train.py \
4 | --batch-size 128 \
5 | --resume ./runs/NetModel/checkpoint.pth.tar
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python Train.py \
4 | --batch-size 128
--------------------------------------------------------------------------------
/sketchanet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | model_paths ={
7 | 'sketchanet': '',
8 | }
9 |
10 | class SketchANet(nn.Module):
11 | def __init__(self, num_classes=250):
12 | super(SketchANet, self).__init__()
13 | self.conv = nn.Sequential(
14 | nn.Conv2d(3, 64, kernel_size=15, stride=3, padding=0),
15 | nn.ReLU(inplace=True),
16 | nn.MaxPool2d(kernel_size=3, stride=2),
17 | nn.Conv2d(64, 128, kernel_size=5, padding=0),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool2d(kernel_size=3, stride=2),
20 | nn.Conv2d(128, 256, kernel_size=3, padding=1),
21 | nn.ReLU(inplace=True),
22 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.MaxPool2d(kernel_size=3, stride=2),
27 | )
28 | self.classifier = nn.Sequential(
29 | nn.Linear(256 * 6 * 6, 512),
30 | nn.ReLU(inplace=True),
31 | nn.Dropout(),
32 | nn.Linear(512, 512),
33 | nn.ReLU(inplace=True),
34 | nn.Dropout(),
35 | nn.Linear(512, num_classes),
36 | )
37 |
38 | def forward(self, x):
39 | x = self.conv(x)
40 | x = x.view(x.size(0), 256 * 6 * 6)
41 | x = self.classifier(x)
42 | return x
43 |
44 | def sketchanet(pretrained=False, **kwargs):
45 | model = SketchANet(**kwargs)
46 | if pretrained:
47 | model.load_state_dict(torch.load(model_paths['sketchanet']))
48 |
49 |     new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
50 |     model.classifier = new_classifier
51 | return model
--------------------------------------------------------------------------------