├── .gitignore ├── README.md ├── cifarclassify ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── bearing_loader.py │ └── caltech101_loader.py ├── modelloader │ ├── __init__.py │ ├── cifar │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── googlenet.py │ │ ├── resnet.py │ │ └── wide_resnet.py │ └── imagenet │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── densenet.py │ │ ├── googlenet.py │ │ ├── mobilenet.py │ │ ├── mobilenet_v2.py │ │ ├── peleenet.py │ │ ├── resnet.py │ │ ├── resnet_ibn_a.py │ │ ├── resnet_ibn_b.py │ │ ├── seresnet.py │ │ ├── shufflenet.py │ │ ├── squeezenet.py │ │ ├── wide_resnet.py │ │ └── xception.py └── utils │ ├── __init__.py │ ├── imagenet_utils.py │ └── numpy_utils.py ├── data ├── EnglishCockerSpaniel_simon.jpg ├── cat.jpg └── synset_words.txt ├── doc ├── develop_releated.md ├── mobilenet_implement.md └── pytorch_net_visual.md ├── requirements.txt ├── test ├── __init__.py ├── context.py └── test_tf_resnet.py ├── tfcifarclassify ├── __init__.py ├── dataloader │ ├── __init__.py │ └── cifar_input.py └── modelloader │ ├── __init__.py │ ├── cifar │ └── __init__.py │ └── imagenet │ ├── __init__.py │ └── resnet.py ├── train_cifar.py └── train_imagenet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | experiments/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cifarclassify 2 | 3 | 本仓库的开发计划见[项目下一步开发计划](https://github.com/guanfuchen/cifarclassify/issues/1) 4 | 5 | --- 6 | ## 图像分类算法 7 | 这个仓库主要实现常用的网络并在cifar10数据集上进行试验,比较分类精度。主要参考如下所示: 8 | - [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) 9 | - [pytorch-classification](https://github.com/bearpaw/pytorch-classification) 10 | - [wide-resnet.pytorch](https://github.com/meliketoy/wide-resnet.pytorch) 使用resnet训练的cifar10和cifar100模型。 11 | - 
[pytorch-playground](https://github.com/aaron-xichen/pytorch-playground) 实现常用数据集和模型。 12 | - [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) 该仓库实现了大量的常用DL模型。 13 | 14 | --- 15 | ## 学习率算法 16 | 增加学习率算法比较。 17 | 18 | [Learning-Rate](https://github.com/nathanhubens/Learning-Rate) 19 | 20 | [lr_scheduler.py](https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py) 21 | 22 | [pytorch-lr-scheduler](https://github.com/Jiaming-Liu/pytorch-lr-scheduler) 23 | 24 | --- 25 | ### 网络实现 26 | - alexnet 27 | - MobileNet [mobilenet实现](doc/mobilenet_implement.md) 28 | - resnet 29 | - densenet 30 | - squeezenet 31 | - ... 32 | 33 | --- 34 | ### 数据集实现 35 | - cifar10 36 | - ... 37 | 38 | --- 39 | ### 依赖 40 | - pytorch 41 | 42 | --- 43 | ### 用法 44 | 45 | **可视化** 46 | 47 | [visdom](https://github.com/facebookresearch/visdom) 48 | [网络结构可视化](doc/pytorch_net_visual.md) 49 | 50 | 51 | ```bash 52 | # 在tmux或者另一个终端中开启可视化服务器visdom 53 | python -m visdom.server 54 | # 然后在浏览器中查看127.0.0.1:8097 55 | ``` 56 | 57 | **训练** 58 | ```bash 59 | # 训练模型 60 | python train.py 61 | ``` 62 | 63 | **校验** 64 | ```bash 65 | # 校验模型 66 | python validate.py 67 | ``` 68 | 69 | **测试** 70 | ```bash 71 | # 测试模型 72 | python test.py 73 | ``` 74 | -------------------------------------------------------------------------------- /cifarclassify/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | -------------------------------------------------------------------------------- /cifarclassify/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | -------------------------------------------------------------------------------- /cifarclassify/dataloader/bearing_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import 
torch 3 | import os 4 | import collections 5 | import random 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy.misc as m 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | from torch.utils import data 13 | from torchvision import transforms 14 | import glob 15 | 16 | 17 | class BearingLoader(data.Dataset): 18 | def __init__(self, root, split="train", is_transform=False, is_augment=False): 19 | self.root = root 20 | self.split = split 21 | self.img_size = (224, 224) # (h, w) 22 | self.is_transform = is_transform 23 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 24 | self.n_classes = 4 25 | self.files = collections.defaultdict(list) 26 | self.joint_augment_transform = None 27 | self.is_augment = is_augment 28 | 29 | file_list = glob.glob(root + '/dataset/*.jpg') 30 | file_list.sort() 31 | file_list_len = len(file_list) 32 | split_index = int(file_list_len*0.7) 33 | if self.split == 'train': 34 | self.files[split] = file_list[:split_index] 35 | elif self.split == 'val': 36 | self.files[split] = file_list[split_index:] 37 | 38 | def __len__(self): 39 | return len(self.files[self.split]) 40 | 41 | def __getitem__(self, index): 42 | img_name = self.files[self.split][index] 43 | img_file_name = img_name[img_name.rfind('/') + 1:img_name.rfind('.')] 44 | # img_file_name = img_name[:img_name.rfind('.')] 45 | # print(img_file_name) 46 | 47 | img = Image.open(img_name) 48 | img = img.resize((self.img_size[1], self.img_size[0])) 49 | lbl = img_file_name[:img_file_name.index('_')] 50 | lbl = int(lbl) 51 | lbl -= 1 # 1-4 to 0-3 52 | # print(lbl) 53 | 54 | if self.is_augment: 55 | if self.joint_augment_transform is not None: 56 | img, lbl = self.joint_augment_transform(img, lbl) 57 | 58 | img = np.array(img, dtype=np.uint8) 59 | lbl = np.array(lbl, dtype=np.int32) 60 | 61 | if self.is_transform: 62 | img, lbl = self.transform(img, lbl) 63 | 64 | return img, lbl 65 | 66 | # 转换HWC为CHW 67 | def transform(self, img, lbl): 68 | img = img[:, :, 
::-1] 69 | img = img.astype(np.float64) 70 | img -= self.mean 71 | img = img.astype(float) / 255.0 72 | # HWC -> CHW 73 | img = img.transpose(2, 0, 1) 74 | 75 | img = torch.from_numpy(img).float() 76 | lbl = torch.from_numpy(lbl).long() 77 | return img, lbl 78 | 79 | 80 | def main(): 81 | home_path = os.path.expanduser('~') 82 | local_path = os.path.join(home_path, 'Data/Bearing') 83 | batch_size = 4 84 | dst = BearingLoader(local_path, is_transform=True, is_augment=False) 85 | trainloader = data.DataLoader(dst, batch_size=batch_size, shuffle=True) 86 | for i, (imgs, labels) in enumerate(trainloader): 87 | print(i) 88 | print(imgs.shape) 89 | print(labels.shape) 90 | # if i == 0: 91 | image_list_len = imgs.shape[0] 92 | for image_list in range(image_list_len): 93 | img = imgs[image_list, :, :, :] 94 | img = img.numpy() 95 | img = np.transpose(img, (1, 2, 0)) 96 | plt.subplot(image_list_len, 2, 2 * image_list + 1) 97 | plt.imshow(img) 98 | plt.show() 99 | if i == 0: 100 | break 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | 106 | -------------------------------------------------------------------------------- /cifarclassify/dataloader/caltech101_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # [Caltech 101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) 3 | import torch 4 | import os 5 | import collections 6 | import random 7 | 8 | import cv2 9 | import numpy as np 10 | import scipy.misc as m 11 | import matplotlib.pyplot as plt 12 | from PIL import Image 13 | from torch.utils import data 14 | from torchvision import transforms 15 | import glob 16 | 17 | 18 | class Caltech101Loader(data.Dataset): 19 | def __init__(self, root, split="train", is_transform=False, is_augment=False): 20 | self.root = root 21 | self.split = split 22 | self.img_size = (224, 224) # (h, w) 23 | self.is_transform = is_transform 24 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 
25 | self.files = [] 26 | self.joint_augment_transform = None 27 | self.is_augment = is_augment 28 | self.klassnames = ['BACKGROUND_Google', 'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'] 29 | self.n_classes = len(self.klassnames) 30 | 31 | klassname_list = glob.glob(os.path.join(root + '/images/*')) 32 | klassname_list.sort() 33 | for klassname_path in klassname_list: 34 | klassname = klassname_path[klassname_path.rfind('/')+1:] 35 | # self.klassnames.append(klassname) 36 | # print('klassname:', klassname) 37 | file_list = glob.glob(os.path.join(root + '/images/{}/*.jpg'.format(klassname))) 38 | file_list.sort() 39 | file_list_len = len(file_list) 40 | split_index = int(file_list_len * 0.7) 41 | if self.split == 'train': 42 | self.files += file_list[:split_index] 43 | elif self.split == 'val': 44 | self.files += file_list[split_index:] 45 | 
self.files.sort() 46 | # self.klassnames.sort() 47 | # print('klassnames:', self.klassnames) 48 | # print('files:', self.files) 49 | 50 | def __len__(self): 51 | return len(self.files) 52 | 53 | def __getitem__(self, index): 54 | img_name = self.files[index] 55 | # img_file_name = img_name[img_name.rfind('/') + 1:img_name.rfind('.')] 56 | # img_file_name = img_name[:img_name.rfind('.')] 57 | # print(img_file_name) 58 | 59 | img = cv2.imread(img_name) # BGR 60 | # print('img.shape:', img.shape) 61 | img = cv2.resize(img, (self.img_size[1], self.img_size[0])) 62 | # print('img.shape:', img.shape) 63 | # img = Image.open(img_name) 64 | # img = img.resize((self.img_size[1], self.img_size[0])) 65 | # lbl = img_file_name[:img_file_name.index('_')] 66 | klassname = img_name[:img_name.rfind('/')] 67 | klassname = klassname[klassname.rfind('/')+1:] 68 | # print('klassname:', klassname) 69 | lbl = self.klassnames.index(klassname) 70 | lbl = int(lbl) 71 | # print(lbl) 72 | 73 | if self.is_augment: 74 | if self.joint_augment_transform is not None: 75 | img, lbl = self.joint_augment_transform(img, lbl) 76 | 77 | img = np.array(img, dtype=np.uint8) 78 | lbl = np.array(lbl, dtype=np.int32) 79 | 80 | if self.is_transform: 81 | img, lbl = self.transform(img, lbl) 82 | 83 | return img, lbl 84 | 85 | # 转换HWC为CHW 86 | def transform(self, img, lbl): 87 | # print('img.shape:', img.shape) 88 | # img = img[:, :, ::-1] 89 | img = img.astype(np.float64) 90 | img -= self.mean 91 | img = img.astype(float) / 255.0 92 | # HWC -> CHW 93 | img = img.transpose(2, 0, 1) 94 | 95 | img = torch.from_numpy(img).float() 96 | lbl = torch.from_numpy(lbl).long() 97 | return img, lbl 98 | 99 | 100 | def main(): 101 | home_path = os.path.expanduser('~') 102 | local_path = os.path.join(home_path, 'Data/101_ObjectCategories') 103 | batch_size = 4 104 | dst = Caltech101Loader(local_path, is_transform=True, is_augment=False) 105 | trainloader = data.DataLoader(dst, batch_size=batch_size, shuffle=True) 106 | for 
i, (imgs, labels) in enumerate(trainloader): 107 | print(i) 108 | print(imgs.shape) 109 | print(labels.shape) 110 | print(labels.shape) 111 | # if i == 0: 112 | image_list_len = imgs.shape[0] 113 | # for image_list in range(image_list_len): 114 | # img = imgs[image_list, :, :, :] 115 | # img = img.numpy() 116 | # img = np.transpose(img, (1, 2, 0)) 117 | # plt.subplot(image_list_len, 2, 2 * image_list + 1) 118 | # plt.imshow(img) 119 | # plt.show() 120 | # if i == 0: 121 | # break 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | 127 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | import time 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torchvision import models 8 | import numpy as np 9 | import os 10 | from scipy import misc 11 | import numpy as np 12 | import scipy 13 | import matplotlib.pyplot as plt 14 | 15 | from cifarclassify.utils import imagenet_utils 16 | 17 | class AlexNet(nn.Module): 18 | """ 19 | :param 20 | """ 21 | def __init__(self, n_classes=10): 22 | super(AlexNet, self).__init__() 23 | # features和classifier的结构和vgg16等类似 24 | self.features = nn.Sequential( 25 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 26 | nn.ReLU(inplace=True), 27 | 
nn.MaxPool2d(kernel_size=2, stride=2), 28 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=2, stride=2), 31 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | ) 39 | self.classifier = nn.Linear(256, n_classes) 40 | 41 | def forward(self, x): 42 | """ 43 | :param x: 44 | :return: 45 | """ 46 | x = self.features(x) 47 | x = x.view(x.size(0), -1) 48 | x = self.classifier(x) 49 | return x 50 | 51 | if __name__ == '__main__': 52 | n_classes = 10 53 | model = AlexNet(n_classes=n_classes) 54 | 55 | x = Variable(torch.randn(1, 3, 32, 32)) 56 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 57 | start = time.time() 58 | pred = model(x) 59 | # print('pred.shape', pred.shape) 60 | end = time.time() 61 | print("AlexNet forward time:", end-start) 62 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/cifar/googlenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | # googlenet中的inception结构 14 | class Inception(nn.Module): 15 | """ 16 | :param 17 | """ 18 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 19 | super(Inception, self).__init__() 20 | # 1x1 conv branch 21 | # 1x1卷积分支 22 | self.b1 = nn.Sequential( 23 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 24 | nn.BatchNorm2d(n1x1), 25 | nn.ReLU(True), 26 | ) 27 | 28 | # 1x1 conv -> 3x3 conv branch 29 | # 1x1卷积加上3x3卷积 30 | self.b2 = nn.Sequential( 31 | nn.Conv2d(in_planes, n3x3red, 
kernel_size=1), 32 | nn.BatchNorm2d(n3x3red), 33 | nn.ReLU(True), 34 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(n3x3), 36 | nn.ReLU(True), 37 | ) 38 | 39 | # 1x1 conv -> 5x5 conv branch 40 | # 1x1卷积加上5x5卷积,这里所有的卷积都保证了最后输出的feature map与输入相同,这里5x5用两个3x3卷积代替 41 | self.b3 = nn.Sequential( 42 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 43 | nn.BatchNorm2d(n5x5red), 44 | nn.ReLU(True), 45 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 46 | nn.BatchNorm2d(n5x5), 47 | nn.ReLU(True), 48 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 49 | nn.BatchNorm2d(n5x5), 50 | nn.ReLU(True), 51 | ) 52 | 53 | # 3x3 pool -> 1x1 conv branch 54 | self.b4 = nn.Sequential( 55 | nn.MaxPool2d(3, stride=1, padding=1), 56 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 57 | nn.BatchNorm2d(pool_planes), 58 | nn.ReLU(True), 59 | ) 60 | 61 | def forward(self, x): 62 | """ 63 | :param x: 64 | :return: 65 | """ 66 | y1 = self.b1(x) 67 | y2 = self.b2(x) 68 | y3 = self.b3(x) 69 | y4 = self.b4(x) 70 | out = torch.cat([y1, y2, y3, y4], 1) 71 | # print('y1.data.shape:', y1.data.shape) 72 | # print('y2.data.shape:', y2.data.shape) 73 | # print('y3.data.shape:', y3.data.shape) 74 | # print('y4.data.shape:', y4.data.shape) 75 | # print('out.data.shape:', out.data.shape) 76 | return out 77 | 78 | 79 | # cifar10上的GoogLeNet 80 | class GoogLeNet(nn.Module): 81 | """ 82 | :param 83 | """ 84 | def __init__(self, n_classes=10): 85 | super(GoogLeNet, self).__init__() 86 | self.pre_layers = nn.Sequential( 87 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 88 | nn.BatchNorm2d(192), 89 | nn.ReLU(True), 90 | ) 91 | 92 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 93 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 94 | 95 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 96 | 97 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 98 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 99 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 100 | self.d4 = 
Inception(512, 112, 144, 288, 32, 64, 64) 101 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 102 | 103 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 104 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 105 | 106 | self.avgpool = nn.AvgPool2d(8, stride=1) 107 | self.linear = nn.Linear(1024, n_classes) 108 | 109 | def forward(self, x): 110 | """ 111 | :param x: 112 | :return: 113 | """ 114 | out = self.pre_layers(x) 115 | out = self.a3(out) 116 | out = self.b3(out) 117 | out = self.maxpool(out) 118 | out = self.a4(out) 119 | out = self.b4(out) 120 | out = self.c4(out) 121 | out = self.d4(out) 122 | out = self.e4(out) 123 | out = self.maxpool(out) 124 | out = self.a5(out) 125 | out = self.b5(out) 126 | out = self.avgpool(out) 127 | out = out.view(out.size(0), -1) 128 | out = self.linear(out) 129 | return out 130 | 131 | 132 | if __name__ == '__main__': 133 | n_classes = 10 134 | model = GoogLeNet(n_classes=n_classes) 135 | model.eval() 136 | x = Variable(torch.randn(1, 3, 32, 32)) 137 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 138 | # print(x.shape) 139 | start = time.time() 140 | pred = model(x) 141 | end = time.time() 142 | print("GoogLeNet forward time:", end-start) 143 | # print(pred.shape) 144 | # criterion = nn.CrossEntropyLoss() 145 | # loss = criterion(pred, y) 146 | # print(loss) 147 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, 
planes, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | 17 | self.shortcut = nn.Sequential() 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 21 | nn.BatchNorm2d(self.expansion*planes) 22 | ) 23 | 24 | def forward(self, x): 25 | out = F.relu(self.bn1(self.conv1(x))) 26 | out = self.bn2(self.conv2(out)) 27 | out += self.shortcut(x) 28 | out = F.relu(out) 29 | return out 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1): 36 | super(Bottleneck, self).__init__() 37 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride != 1 or in_planes != self.expansion*planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(self.expansion*planes) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(self.conv1(x))) 53 | out = F.relu(self.bn2(self.conv2(out))) 54 | out = self.bn3(self.conv3(out)) 55 | out += self.shortcut(x) 56 | out = F.relu(out) 57 | return out 58 | 59 | 60 | class ResNet(nn.Module): 61 | def __init__(self, block, num_blocks, num_classes=10): 62 | super(ResNet, self).__init__() 63 | self.in_planes = 64 64 | 65 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(64) 67 | 68 | # self.bn1.eval() 69 | # for bn1_parameter in self.bn1.parameters(): 70 | # 
bn1_parameter.requires_grad = False 71 | 72 | # self.bn1.weight.requires_grad = False 73 | # self.bn1.bias.requires_grad = False 74 | # self.bn1.running_var.requires_grad = False 75 | # self.bn1.running_mean.requires_grad = False 76 | 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(512*block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1]*(num_blocks-1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | out = F.relu(self.bn1(self.conv1(x))) 93 | 94 | # print('self.bn1.weight.shape', self.bn1.weight.shape) 95 | # print('self.bn1.bias.shape', self.bn1.bias.shape) 96 | # print('self.bn1.weight:', sum(self.bn1.weight)) 97 | # print('self.bn1.bias:', sum(self.bn1.bias)) 98 | 99 | # print('self.bn1.running_var.shape', self.bn1.running_var.shape) 100 | # print('self.bn1.running_mean.shape', self.bn1.running_mean.shape) 101 | # print('self.bn1.running_var:', sum(self.bn1.running_var)) 102 | # print('self.bn1.running_mean:', sum(self.bn1.running_mean)) 103 | 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = F.avg_pool2d(out, 4) 109 | out = out.view(out.size(0), -1) 110 | out = self.linear(out) 111 | return out 112 | 113 | 114 | def ResNet18(): 115 | return ResNet(BasicBlock, [2,2,2,2]) 116 | 117 | def ResNet34(): 118 | return ResNet(BasicBlock, [3,4,6,3]) 119 | 120 | def ResNet50(): 121 | return ResNet(Bottleneck, [3,4,6,3]) 122 | 123 | def ResNet101(): 124 | return ResNet(Bottleneck, 
[3,4,23,3]) 125 | 126 | def ResNet152(): 127 | return ResNet(Bottleneck, [3,8,36,3]) 128 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/cifar/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 参考代码[resnet.py](https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py) 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import torch 7 | from torch.nn.init import kaiming_normal_ 8 | import torch.nn.functional as F 9 | from torch.nn.parallel._functions import Broadcast 10 | from torch.nn.parallel import scatter, parallel_apply, gather 11 | from functools import partial 12 | from nested_dict import nested_dict 13 | import os 14 | import time 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch.autograd import Variable 20 | from collections import OrderedDict 21 | from torch.nn import init 22 | import numpy as np 23 | from scipy import misc 24 | import matplotlib.pyplot as plt 25 | 26 | from cifarclassify.utils import imagenet_utils 27 | 28 | 29 | def cast(params, dtype='float'): 30 | if isinstance(params, dict): 31 | return {k: cast(v, dtype) for k,v in params.items()} 32 | else: 33 | return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)() 34 | 35 | 36 | # conv params no*ni*k*k 37 | def conv_params(ni, no, k=1): 38 | return kaiming_normal_(torch.Tensor(no, ni, k, k)) 39 | 40 | 41 | def linear_params(ni, no): 42 | return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)} 43 | 44 | 45 | def bnparams(n): 46 | return {'weight': torch.rand(n), 47 | 'bias': torch.zeros(n), 48 | 'running_mean': torch.zeros(n), 49 | 'running_var': torch.ones(n)} 50 | 51 | 52 | def data_parallel(f, input, params, mode, device_ids, output_device=None): 53 | assert isinstance(device_ids, list) 54 | if output_device is None: 55 
| output_device = device_ids[0] 56 | 57 | if len(device_ids) == 1: 58 | return f(input, params, mode) 59 | 60 | params_all = Broadcast.apply(device_ids, *params.values()) 61 | params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())} 62 | for j in range(len(device_ids))] 63 | 64 | replicas = [partial(f, params=p, mode=mode) 65 | for p in params_replicas] 66 | inputs = scatter([input], device_ids) 67 | outputs = parallel_apply(replicas, inputs) 68 | return gather(outputs, output_device) 69 | 70 | 71 | def flatten(params): 72 | return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None} 73 | 74 | 75 | def batch_norm(x, params, base, mode): 76 | return F.batch_norm(x, weight=params[base + '.weight'], 77 | bias=params[base + '.bias'], 78 | running_mean=params[base + '.running_mean'], 79 | running_var=params[base + '.running_var'], 80 | training=mode) 81 | 82 | 83 | def print_tensor_dict(params): 84 | kmax = max(len(key) for key in params.keys()) 85 | for i, (key, v) in enumerate(params.items()): 86 | print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad) 87 | 88 | 89 | def set_requires_grad_except_bn_(params): 90 | for k, v in params.items(): 91 | if not k.endswith('running_mean') and not k.endswith('running_var'): 92 | # if not running_mean or running_var requires grad 93 | v.requires_grad = True 94 | 95 | # def wide_resnet(depth, width, num_classes): 96 | class wide_resnet(nn.Module): 97 | def __init__(self, depth, width, n_classes): 98 | super(wide_resnet, self).__init__() 99 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' # 4 is for conv0 100 | self.n = (depth - 4) // 6 101 | widths = [int(v * width) for v in (16, 32, 64)] # for normal resnet 16 32 64 102 | 103 | def gen_block_params(ni, no): 104 | return { 105 | 'conv0': conv_params(ni, no, 3), 106 | 'conv1': conv_params(no, no, 3), 107 | 'bn0': bnparams(ni), 108 | 'bn1': bnparams(no), 109 | 
'convdim': conv_params(ni, no, 1) if ni != no else None, 110 | } 111 | 112 | def gen_group_params(ni, no, count): 113 | return {'block%d' % i: gen_block_params(ni if i == 0 else no, no) for i in range(count)} 114 | 115 | # conv0+group0+goup1+group2 116 | self.flat_params = cast(flatten({ 117 | 'conv0': conv_params(3, 16, 3), # input 2 output 16 118 | 'group0': gen_group_params(16, widths[0], self.n), # input 16 output widths[0] 119 | 'group1': gen_group_params(widths[0], widths[1], self.n), 120 | 'group2': gen_group_params(widths[1], widths[2], self.n), 121 | 'bn': bnparams(widths[2]), 122 | 'fc': linear_params(widths[2], n_classes), 123 | })) 124 | 125 | # except bn requires grad 126 | set_requires_grad_except_bn_(self.flat_params) 127 | 128 | def block(self, x, params, base, mode, stride): 129 | o1 = F.relu(batch_norm(x, params, base + '.bn0', mode), inplace=True) 130 | y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1) 131 | o2 = F.relu(batch_norm(y, params, base + '.bn1', mode), inplace=True) 132 | z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1) 133 | if base + '.convdim' in params: 134 | return z + F.conv2d(o1, params[base + '.convdim'], stride=stride) 135 | else: 136 | return z + x 137 | 138 | def group(self, o, params, base, mode, stride): 139 | for i in range(self.n): 140 | o = self.block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1) 141 | return o 142 | 143 | 144 | def forward(self, input): 145 | x = F.conv2d(input, self.flat_params['conv0'], padding=1) 146 | g0 = self.group(x, self.flat_params, 'group0', self.training, 1) 147 | g1 = self.group(g0, self.flat_params, 'group1', self.training, 2) 148 | g2 = self.group(g1, self.flat_params, 'group2', self.training, 2) 149 | o = F.relu(batch_norm(g2, self.flat_params, 'bn', self.training)) 150 | # print('o.shape:', o.shape) 151 | o = F.avg_pool2d(o, 8, 1, 0) 152 | # print('o.shape:', o.shape) 153 | o = o.view(o.size(0), -1) 154 | # print('o.shape:', o.shape) 
class AlexNet(nn.Module):
    """AlexNet image classifier laid out as ``features`` + ``classifier``.

    The layer order and attribute names match the ``alexnet-owt`` checkpoint
    fetched by :meth:`load_weights`, so its parameters are loaded by name.

    :param n_classes: number of outputs of the last linear layer.
    :param pretrained: when True, download the checkpoint and load it right
        after the layers are built.
    """

    # Spatial size expected by the first linear layer (256 * 6 * 6 inputs).
    _POOLED_SIZE = (6, 6)

    def __init__(self, n_classes=1000, pretrained=False):
        super(AlexNet, self).__init__()
        self.n_classes = n_classes
        self.pretrained = pretrained

        # Convolutional feature extractor; size comments assume a 224x224 input.
        self.features = nn.Sequential(
            # (224 - 11 + 2*2) / 4 + 1 = 55
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            # (55 - 3) / 2 + 1 = 27
            nn.MaxPool2d(kernel_size=3, stride=2),
            # (27 - 5 + 2*2) / 1 + 1 = 27
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            # (27 - 3) / 2 + 1 = 13
            nn.MaxPool2d(kernel_size=3, stride=2),
            # (13 - 3 + 2*1) / 1 + 1 = 13
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # (13 - 3 + 2*1) / 1 + 1 = 13
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # (13 - 3 + 2*1) / 1 + 1 = 13
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # (13 - 3) / 2 + 1 = 6
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Brings any feature map to 6x6.  It is the identity when the map is
        # already 6x6 (the 224x224 case) and has no parameters, so the
        # state_dict keys stay exactly as the checkpoint expects.
        self.avgpool = nn.AdaptiveAvgPool2d(self._POOLED_SIZE)

        # Fully connected classifier on top of the flattened 256 x 6 x 6 features.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, n_classes),
        )

        if self.pretrained:
            self.load_weights()

    def forward(self, x):
        """Compute class scores for a batch of images.

        :param x: float tensor of shape (batch, 3, height, width).
        :return: tensor of shape (batch, n_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten per sample; the batch dimension is kept explicitly.
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def load_weights(self):
        """Download the ``alexnet-owt`` checkpoint and copy it into this model.

        When ``n_classes`` differs from 1000 the last linear layer is left at
        its current initialisation, because the checkpoint head would not fit.
        """
        pretrained_dict = model_zoo.load_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
        model_dict = self.state_dict()
        if self.n_classes != 1000:
            head_keys = {'classifier.6.weight', 'classifier.6.bias'}
            new_dict = {k: v for k, v in pretrained_dict.items() if k not in head_keys}
        else:
            new_dict = pretrained_dict
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
class SingleLayer(nn.Module):
    """Smallest DenseNet unit: BN -> ReLU -> 3x3 conv, concatenated with its input.

    :param nChannels: number of input channels.
    :param growthRate: number of feature maps the layer appends.
    """

    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return ``x`` with ``growthRate`` new feature maps appended on the channel axis.

        :param x: concatenation of the feature maps of all preceding layers.
        """
        new_features = self.conv1(F.relu(self.bn1(x)))
        return torch.cat((x, new_features), 1)


class Transition(nn.Module):
    """Layer between dense blocks: BN -> ReLU -> 1x1 conv -> 2x2 average pooling.

    :param nChannels: number of channels entering the transition.
    :param nOutChannels: number of channels leaving the transition.
    """

    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        """Change the channel count and halve the spatial resolution.

        :param x: feature map produced by the preceding dense block.
        """
        out = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(out, 2)


class DenseNet(nn.Module):
    """DenseNet with three dense blocks separated by two transitions.

    :param growthRate: channels added by every layer inside a dense block.
    :param depth: nominal network depth; each block gets ``(depth - 4) // 3``
        layers, halved when ``bottleneck`` is set.
    :param reduction: factor applied to the channel count at each transition.
    :param nClasses: number of outputs of the final linear layer.
    :param bottleneck: use ``Bottleneck`` layers (1x1 + 3x3) instead of
        ``SingleLayer`` (3x3 only).
    """

    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        nDenseBlocks = (depth - 4) // 3
        if bottleneck:
            nDenseBlocks //= 2

        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        # Parameter initialisation: scaled normal for convolutions, unit scale
        # and zero shift for batch norm, zero bias for the linear layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Build one dense block.

        :param nChannels: channels entering the block.
        :param growthRate: channels added by every layer of the block.
        :param nDenseBlocks: number of layers in the block.
        :param bottleneck: choose ``Bottleneck`` layers over ``SingleLayer``.
        :return: ``nn.Sequential`` holding the layers of the block.
        """
        layers = []
        for _ in range(int(nDenseBlocks)):
            layer_cls = Bottleneck if bottleneck else SingleLayer
            layers.append(layer_cls(nChannels, growthRate))
            # Every layer sees the input of the block plus all earlier outputs.
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class scores.

        :param x: float tensor of shape (batch, 3, height, width); the final
            pooling window of 8 matches a 32x32 input.
        :return: tensor of shape (batch, nClasses).
        """
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = F.avg_pool2d(F.relu(self.bn1(out)), 8)
        # Flatten per sample.  torch.squeeze would also drop the batch axis
        # when the batch holds a single image.
        out = out.view(out.size(0), -1)
        return self.fc(out)


if __name__ == '__main__':
    n_classes = 10
    model = DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=n_classes)
    model.eval()
    x = Variable(torch.randn(1, 3, 32, 32))
    # np.int was removed from NumPy; int64 is what LongTensor stores anyway.
    y = Variable(torch.LongTensor(np.ones(1, dtype=np.int64)))
    start = time.time()
    pred = model(x)
    end = time.time()
    print("DenseNet forward time:", end - start)
class Inception(nn.Module):
    """Inception module with four parallel branches concatenated on channels.

    All branches keep the spatial size of the input, so the output has
    ``n1x1 + n3x3 + n5x5 + pool_planes`` channels at the input resolution.

    :param in_planes: input channels.
    :param n1x1: output channels of the 1x1 branch.
    :param n3x3red: channels of the 1x1 reduction before the 3x3 conv.
    :param n3x3: output channels of the 3x3 branch.
    :param n5x5red: channels of the 1x1 reduction before the "5x5" branch.
    :param n5x5: output channels of the "5x5" branch.
    :param pool_planes: output channels of the pooling branch.
    """

    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        # Branch 1: 1x1 conv.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, n1x1, kernel_size=1),
            nn.BatchNorm2d(n1x1),
            nn.ReLU(True),
        )

        # Branch 2: 1x1 reduction followed by a 3x3 conv.
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, n3x3red, kernel_size=1),
            nn.BatchNorm2d(n3x3red),
            nn.ReLU(True),
            nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(n3x3),
            nn.ReLU(True),
        )

        # Branch 3: 1x1 reduction followed by two stacked 3x3 convs that
        # stand in for a single 5x5 conv.
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, n5x5red, kernel_size=1),
            nn.BatchNorm2d(n5x5red),
            nn.ReLU(True),
            nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
            nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
        )

        # Branch 4: 3x3 max pooling followed by a 1x1 conv.
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate their outputs on dim 1."""
        branches = (self.b1, self.b2, self.b3, self.b4)
        return torch.cat([branch(x) for branch in branches], 1)


class GoogLeNet(nn.Module):
    """GoogLeNet for ImageNet-sized inputs, built from :class:`Inception` modules.

    :param n_classes: number of outputs of the final linear layer.
    """

    def __init__(self, n_classes=1000):
        super(GoogLeNet, self).__init__()
        # Stem: reduces the spatial resolution by a factor of 8.
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        # Shared by both down-sampling steps in forward(); it has no parameters.
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        # Global average pooling.  Equal to the former fixed 7x7 window on the
        # 7x7 map a 224x224 input produces, and also valid for other sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(1024, n_classes)

    def forward(self, x):
        """Compute class scores.

        :param x: float tensor of shape (batch, 3, height, width).
        :return: tensor of shape (batch, n_classes).
        """
        out = self.pre_layers(x)
        out = self.b3(self.a3(out))
        out = self.maxpool(out)
        for stage in (self.a4, self.b4, self.c4, self.d4, self.e4):
            out = stage(out)
        out = self.maxpool(out)
        out = self.b5(self.a5(out))
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.linear(out)


def main():
    """Smoke test: time one forward pass of GoogLeNet on a random 224x224 image."""
    n_classes = 1000
    model = GoogLeNet(n_classes=n_classes)
    model.eval()
    x = Variable(torch.randn(1, 3, 224, 224))
    # Dummy label kept for experiments with a loss; np.int was removed from
    # NumPy, and int64 is the dtype LongTensor stores.
    y = Variable(torch.LongTensor(np.ones(1, dtype=np.int64)))
    start = time.time()
    pred = model(x)
    end = time.time()
    print("GoogLeNet forward time:", end - start)


if __name__ == '__main__':
    main()
class MobileNet(nn.Module):
    """MobileNet (v1): one standard conv followed by 13 depthwise-separable convs.

    :param n_classes: number of outputs of the final linear layer.
    """

    # Location where a locally cached checkpoint is looked up by default.
    DEFAULT_CHECKPOINT = '~/.torch/models/mobilenet_sgd_rmsprop_69.526.tar'

    def __init__(self, n_classes=1000):
        super(MobileNet, self).__init__()
        # NOTE: init_weights() maps checkpoint tensors by position, so the
        # order in which these attributes are created must not change.
        self.conv1_bn = mobilenet_conv_bn_relu(3, 32, 2)
        self.conv2_dw = mobilenet_conv_dw_relu(32, 64, 1)
        self.conv3_dw = mobilenet_conv_dw_relu(64, 128, 2)
        self.conv4_dw = mobilenet_conv_dw_relu(128, 128, 1)
        self.conv5_dw = mobilenet_conv_dw_relu(128, 256, 2)
        self.conv6_dw = mobilenet_conv_dw_relu(256, 256, 1)
        self.conv7_dw = mobilenet_conv_dw_relu(256, 512, 2)
        self.conv8_dw = mobilenet_conv_dw_relu(512, 512, 1)
        self.conv9_dw = mobilenet_conv_dw_relu(512, 512, 1)
        self.conv10_dw = mobilenet_conv_dw_relu(512, 512, 1)
        self.conv11_dw = mobilenet_conv_dw_relu(512, 512, 1)
        self.conv12_dw = mobilenet_conv_dw_relu(512, 512, 1)
        self.conv13_dw = mobilenet_conv_dw_relu(512, 1024, 2)
        self.conv14_dw = mobilenet_conv_dw_relu(1024, 1024, 1)
        self.avg_pool = nn.AvgPool2d(7)
        self.fc = nn.Linear(1024, n_classes)
        self.init_weights(pretrained=True)

    def init_weights(self, pretrained=False, checkpoint_path=None):
        """Load weights from a local checkpoint, if requested and available.

        The checkpoint's parameter names are not assumed to match this
        model's names, so tensors are assigned by position.

        :param pretrained: load the checkpoint only when True.
        :param checkpoint_path: checkpoint file holding a ``'state_dict'``
            entry; defaults to ``DEFAULT_CHECKPOINT`` in the home directory.
        """
        if not pretrained:
            return
        if checkpoint_path is None:
            checkpoint_path = os.path.expanduser(self.DEFAULT_CHECKPOINT)
        if not os.path.exists(checkpoint_path):
            # Best effort: keep the current initialisation when nothing is cached.
            return

        model_checkpoint = torch.load(checkpoint_path, map_location='cpu')
        pretrained_dict = model_checkpoint['state_dict']
        model_dict = self.state_dict()

        # dict views cannot be indexed on Python 3, so materialise them.
        model_keys = list(model_dict.keys())
        pretrained_values = list(pretrained_dict.values())
        if len(model_keys) != len(pretrained_values):
            # Only one side carries the BatchNorm ``num_batches_tracked``
            # counters; dropping them keeps the positional mapping aligned.
            counter = 'num_batches_tracked'
            model_keys = [k for k in model_keys if not k.endswith(counter)]
            pretrained_values = [v for k, v in pretrained_dict.items()
                                 if not k.endswith(counter)]

        for key, value in zip(model_keys, pretrained_values):
            # A classifier head with a different class count cannot be
            # reused; keep the freshly initialised one in that case.
            if key.startswith('fc.') and model_dict[key].size() != value.size():
                continue
            model_dict[key] = value
        self.load_state_dict(model_dict)

    def forward(self, x):
        """Compute class scores.

        :param x: float tensor of shape (batch, 3, height, width); the fixed
            7x7 pooling window matches a 224x224 input.
        :return: tensor of shape (batch, n_classes).
        """
        x = self.conv1_bn(x)
        x = self.conv2_dw(x)
        x = self.conv3_dw(x)
        x = self.conv4_dw(x)
        x = self.conv5_dw(x)
        x = self.conv6_dw(x)
        x = self.conv7_dw(x)
        x = self.conv8_dw(x)
        x = self.conv9_dw(x)
        x = self.conv10_dw(x)
        x = self.conv11_dw(x)
        x = self.conv12_dw(x)
        x = self.conv13_dw(x)
        x = self.conv14_dw(x)
        x = self.avg_pool(x)
        # Flatten per sample so that a pooled map larger than 1x1 fails
        # loudly in the linear layer instead of silently inflating the batch.
        x = x.view(x.size(0), -1)
        return self.fc(x)
def conv_bn(inp, oup, stride):
    """Return a 3x3 conv -> batch norm -> ReLU stack.

    :param inp: input channels.
    :param oup: output channels.
    :param stride: stride of the convolution.
    :return: ``nn.Sequential`` with the three layers.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    """Return a 1x1 conv -> batch norm -> ReLU stack.

    :param inp: input channels.
    :param oup: output channels.
    :return: ``nn.Sequential`` with the three layers.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


class InvertedResidual(nn.Module):
    """Inverted residual block: 1x1 expansion, 3x3 depthwise conv, linear 1x1 projection.

    :param inp: input channels.
    :param oup: output channels.
    :param stride: stride of the depthwise conv, 1 or 2.
    :param expand_ratio: width multiplier of the expanded representation.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # The shortcut is only valid when input and output shapes agree.
        self.use_res_connect = self.stride == 1 and inp == oup

        hidden = inp * expand_ratio
        self.conv = nn.Sequential(
            # pointwise expansion
            nn.Conv2d(inp, hidden, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # depthwise
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # pointwise projection, no activation
            nn.Conv2d(hidden, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        """Apply the block, adding the input back when the shortcut is enabled."""
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out


class MobileNetV2(nn.Module):
    """MobileNetV2 classifier.

    :param n_classes: number of outputs of the final linear layer.
    :param input_size: side length of the square input; must be a multiple
        of 32 because the feature extractor down-samples by that factor.
    :param width_mult: multiplier applied to the channel counts.
    """

    def __init__(self, n_classes=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        # Settings of the inverted residual stages.
        self.interverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (stride of first block)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        assert input_size % 32 == 0
        input_channel = int(32 * width_mult)
        self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280

        # First layer.
        layers = [conv_bn(3, input_channel, 2)]
        # Inverted residual stages; only the first block of a stage strides.
        for t, c, n, s in self.interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(InvertedResidual(input_channel, output_channel, stride, t))
                input_channel = output_channel
        # Last layers.
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        # Floor division keeps the kernel size an int; true division yields a
        # float on Python 3, which the pooling layer does not accept.
        layers.append(nn.AvgPool2d(input_size // 32))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self.last_channel, n_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Compute class scores.

        :param x: float tensor of shape (batch, 3, input_size, input_size).
        :return: tensor of shape (batch, n_classes).
        """
        x = self.features(x)
        x = x.view(-1, self.last_channel)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
class StemBlock(nn.Module):
    """PeleeNet stem: a strided conv followed by two parallel down-sampling paths.

    :param inp: input channels.
    :param num_init_features: channels produced by the stem.
    """

    def __init__(self, inp=3, num_init_features=32):
        super(StemBlock, self).__init__()
        half_features = int(num_init_features / 2)

        self.stem_1 = Conv_BN_Relu(inp, num_init_features, 3, 2, 1)
        self.stem_2a = Conv_BN_Relu(num_init_features, half_features, 1, 1, 0)
        self.stem_2b = Conv_BN_Relu(half_features, num_init_features, 3, 2, 1)
        self.stem_2c = nn.MaxPool2d(kernel_size=2, stride=2)
        self.stem_3 = Conv_BN_Relu(num_init_features * 2, num_init_features, 1, 1, 0)

    def forward(self, x):
        """Apply the stem to ``x``."""
        stem_1_out = self.stem_1(x)

        # Path a/b: 1x1 conv then strided 3x3 conv.  Path c: max pooling.
        conv_path = self.stem_2b(self.stem_2a(stem_1_out))
        pool_path = self.stem_2c(stem_1_out)

        # Both paths are merged on the channel axis and fused by a 1x1 conv.
        return self.stem_3(torch.cat((conv_path, pool_path), 1))


class DenseBlock(nn.Module):
    """Two-way dense layer: a one-conv and a two-conv 3x3 branch, both concatenated with the input.

    :param inp: input channels.
    :param inter_channel: channels of the 1x1 bottleneck in each branch.
    :param growth_rate: channels contributed by each branch.
    """

    def __init__(self, inp, inter_channel, growth_rate):
        super(DenseBlock, self).__init__()

        self.cb1_a = Conv_BN_Relu(inp, inter_channel, 1, 1, 0)
        self.cb1_b = Conv_BN_Relu(inter_channel, growth_rate, 3, 1, 1)

        self.cb2_a = Conv_BN_Relu(inp, inter_channel, 1, 1, 0)
        self.cb2_b = Conv_BN_Relu(inter_channel, growth_rate, 3, 1, 1)
        self.cb2_c = Conv_BN_Relu(growth_rate, growth_rate, 3, 1, 1)

    def forward(self, x):
        """Return ``x`` with the outputs of both branches appended on the channel axis."""
        short_branch = self.cb1_b(self.cb1_a(x))
        long_branch = self.cb2_c(self.cb2_b(self.cb2_a(x)))
        return torch.cat((x, short_branch, long_branch), 1)


class TransitionBlock(nn.Module):
    """1x1 conv between stages, optionally followed by 2x2 average pooling.

    :param inp: input channels.
    :param oup: output channels.
    :param with_pooling: append the average pooling layer when True.
    """

    def __init__(self, inp, oup, with_pooling=True):
        super(TransitionBlock, self).__init__()
        if with_pooling:
            self.tb = nn.Sequential(Conv_BN_Relu(inp, oup, 1, 1, 0),
                                    nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.tb = Conv_BN_Relu(inp, oup, 1, 1, 0)

    def forward(self, x):
        """Apply the transition to ``x``."""
        return self.tb(x)


class PeleeNet(nn.Module):
    """PeleeNet classifier: a stem followed by dense stages and a linear head.

    :param n_classes: number of outputs of the final linear layer.
    :param num_init_features: channels produced by the stem.
    :param growthRate: channels added by every dense layer (split evenly
        between its two branches).
    :param nDenseBlocks: number of dense layers in each stage.
    :param bottleneck_width: per-stage multiplier of the bottleneck width.
    :param pretrained: accepted for interface compatibility; this class does
        not load any weights.
    """

    def __init__(self, n_classes=1000, num_init_features=32, growthRate=32,
                 nDenseBlocks=(3, 4, 8, 6), bottleneck_width=(1, 2, 4, 4), pretrained=False):
        super(PeleeNet, self).__init__()

        self.stage = nn.Sequential()
        self.n_classes = n_classes
        self.num_init_features = num_init_features

        inter_channel = []
        total_filter = []
        dense_inp = []

        self.half_growth_rate = int(growthRate / 2)

        # Stem.
        self.stage.add_module('stage_0', StemBlock(3, num_init_features))

        last_stage = len(nDenseBlocks) - 1
        for i, b_w in enumerate(bottleneck_width):
            # The bottleneck width differs per stage and is rounded down to a
            # multiple of 4.
            inter_channel.append(int(self.half_growth_rate * b_w / 4) * 4)

            if i == 0:
                total_filter.append(num_init_features + growthRate * nDenseBlocks[i])
                dense_inp.append(self.num_init_features)
            else:
                total_filter.append(total_filter[i - 1] + growthRate * nDenseBlocks[i])
                dense_inp.append(total_filter[i - 1])

            # The last stage is not followed by a pooling layer.
            with_pooling = i != last_stage

            self.stage.add_module(
                'stage_{}'.format(i + 1),
                self._make_dense_transition(dense_inp[i], total_filter[i], inter_channel[i],
                                            nDenseBlocks[i], with_pooling=with_pooling))

        # Classifier head.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(total_filter[last_stage], self.n_classes)
        )

        self._initialize_weights()

    def _make_dense_transition(self, dense_inp, total_filter, inter_channel, ndenseblocks, with_pooling=True):
        """Build one stage: ``ndenseblocks`` dense layers followed by a transition.

        :param dense_inp: channels entering the stage.
        :param total_filter: channels leaving the transition.
        :param inter_channel: bottleneck width of the dense layers.
        :param ndenseblocks: number of dense layers in the stage.
        :param with_pooling: whether the transition down-samples.
        :return: ``nn.Sequential`` holding the stage.
        """
        layers = []

        for _ in range(ndenseblocks):
            layers.append(DenseBlock(dense_inp, inter_channel, self.half_growth_rate))
            # Each dense layer appends two branches of half_growth_rate channels.
            dense_inp += self.half_growth_rate * 2

        # Transition layer without compression.
        layers.append(TransitionBlock(dense_inp, total_filter, with_pooling))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class scores.

        :param x: float tensor of shape (batch, 3, height, width); the 7x7
            pooling window matches a 224x224 input.
        :return: tensor of shape (batch, n_classes).
        """
        x = self.stage(x)

        # Average pooling over the final feature map, then flatten per sample.
        x = F.avg_pool2d(x, kernel_size=7)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1): 26 | """ 3x3卷积(padding) 27 | :param in_planes: 28 | :param out_planes: 29 | :param stride: 30 | :return: 31 | """ 32 | return nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | """ 37 | BasicBlock 38 | """ 39 | expansion = 1 # 最后一层是前一层的expansion倍 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = conv3x3(in_planes=inplanes, out_planes=planes, stride=stride) 44 | self.bn1 = nn.BatchNorm2d(num_features=planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(in_planes=planes, out_planes=planes) 47 | self.bn2 = nn.BatchNorm2d(num_features=planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | """ 53 | :param x: 54 | :return: 55 | """ 56 | residual = x 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | residual = self.downsample(x) 66 | 67 | out += residual 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | """ 75 | Bottleneck 76 | """ 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None): 80 | super(Bottleneck, self).__init__() 81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | 84 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 85 | padding=1, bias=False) 86 | 87 | self.bn2 = nn.BatchNorm2d(planes) 88 | 89 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 91 | 92 | self.relu = nn.ReLU(inplace=True) 93 | 
class ResNet(nn.Module):
    """ResNet template assembled from a residual block class and stage depths."""

    def __init__(self, block, layers, n_classes=1000):
        """
        :param block: residual block class (BasicBlock or Bottleneck)
        :param layers: number of blocks in each of the four stages
        :param n_classes: number of outputs of the final linear layer
        """
        super(ResNet, self).__init__()
        self.n_classes = n_classes
        self.inplanes = 64

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (width, stride) per stage; only the first stage keeps stride 1.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, stride) in enumerate(stage_specs):
            stage = self._make_layer(block=block, planes=width,
                                     blocks=layers[index], stride=stride)
            setattr(self, 'layer{}'.format(index + 1), stage)

        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, self.n_classes)

        # Initialise every convolution and batch-norm layer.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one stage.

        :param block: residual block class
        :param planes: base width of the stage
        :param blocks: number of blocks in the stage
        :param stride: stride used by the first block of the stage
        :return: ``nn.Sequential`` holding the stage
        """
        out_channels = planes * block.expansion

        # A projection shortcut is needed when the first block changes the
        # spatial size (stride != 1) or the channel count.
        shortcut = None
        if not (stride == 1 and self.inplanes == out_channels):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        # Only the first block receives the stride and the shortcut module.
        stack = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_channels
        stack.extend(block(self.inplanes, planes) for _ in range(blocks - 1))

        return nn.Sequential(*stack)

    def forward(self, x):
        """
        :param x: input image batch
        :return: output of the final linear layer
        """
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for step in pipeline:
            x = step(x)

        # Flatten everything except the batch dimension before the classifier.
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def load_weights(self, url):
        """Download a checkpoint and copy it into this model.

        :param url: key into ``model_urls`` selecting the checkpoint
        """
        checkpoint = model_zoo.load_url(model_urls[url])
        merged = self.state_dict()
        if self.n_classes == 1000:
            merged.update(checkpoint)
        else:
            # With a different class count the checkpoint's classifier
            # tensors are skipped and the model keeps its own fc layer.
            head_keys = ('fc.weight', 'fc.bias')
            merged.update((key, value) for key, value in checkpoint.items()
                          if key not in head_keys)
        self.load_state_dict(merged)
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model.

    :param pretrained: if True, load the pre-trained ImageNet checkpoint
    :param kwargs: forwarded to the ``ResNet`` constructor
    :return: the constructed model
    """
    model = ResNet(BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    if pretrained:
        # ResNet-34 is built from BasicBlock, so it needs its own checkpoint;
        # the Bottleneck-based 'resnet50' weights do not match its state dict.
        # NOTE(review): assumes model_urls (defined earlier in this file) has
        # a 'resnet34' entry - confirm.
        model.load_weights('resnet34')
    return model


def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model.

    :param pretrained: if True, load the pre-trained ImageNet checkpoint
    :param kwargs: forwarded to the ``ResNet`` constructor
    :return: the constructed model
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # Go through load_weights like the other factories. The previous
        # inline merge updated the checkpoint dict with model.state_dict(),
        # which replaced every pretrained tensor with the random init.
        model.load_weights('resnet50')
    return model


def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model.

    :param pretrained: if True, load the pre-trained ImageNet checkpoint
    :param kwargs: forwarded to the ``ResNet`` constructor
    :return: the constructed model
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_weights('resnet101')
    return model
267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | """ 270 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 271 | if pretrained: 272 | model.load_weights('resnet152') 273 | return model 274 | 275 | 276 | if __name__ == '__main__': 277 | model = torchvision.models.resnet152(pretrained=True) 278 | # model = torchvision.models.resnet34(pretrained=True) 279 | # model = resnet34(pretrained=True) 280 | model.eval() 281 | 282 | input_data = misc.imread('../../../data/cat.jpg') 283 | # 按照imagenet的图像格式预处理 284 | input_data = imagenet_utils.imagenet_preprocess(input_data) 285 | 286 | x = Variable(torch.FloatTensor(torch.from_numpy(input_data))) 287 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 288 | # print(x.shape) 289 | start = time.time() 290 | pred = model(x) 291 | end = time.time() 292 | print("resnet152 forward time:", end-start) 293 | 294 | imagenet_utils.get_imagenet_label(pred) 295 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torch.utils.model_zoo as model_zoo 7 | import os 8 | import time 9 | from torch import nn 10 | from torch.autograd import Variable 11 | import os 12 | from scipy import misc 13 | import numpy as np 14 | import scipy 15 | import matplotlib.pyplot as plt 16 | import torchvision 17 | 18 | from cifarclassify.utils import imagenet_utils 19 | 20 | __all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a', 21 | 'resnet152_ibn_a'] 22 | 23 | model_urls = { 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | } 28 | 29 | 30 | def 
class IBN(nn.Module):
    """Normalise the first half of the channels with InstanceNorm and the
    remaining channels with BatchNorm, then concatenate the two halves.
    """

    def __init__(self, planes):
        """
        :param planes: total number of channels of the tensors this layer
            receives; ``planes // 2`` go to InstanceNorm, the rest to BatchNorm
        """
        super(IBN, self).__init__()
        self.half = planes // 2
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        """
        :param x: input tensor; it is split along dimension 1
        :return: tensor with the same channel count as ``x``
        """
        # Split into exactly two sections. Splitting by a chunk size of
        # ``half`` yields three chunks when the channel count is odd and
        # hands the BatchNorm branch a chunk of the wrong width.
        in_part, bn_part = torch.split(x, [self.half, x.size(1) - self.half], 1)
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        return torch.cat((out1, out2), 1)
self.expansion, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | residual = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | residual = self.downsample(x) 120 | 121 | out += residual 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000): 130 | scale = 64 131 | self.inplanes = scale 132 | super(ResNet, self).__init__() 133 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = nn.BatchNorm2d(scale) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, scale, layers[0]) 139 | self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2) 140 | self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2) 141 | self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2) 142 | self.avgpool = nn.AvgPool2d(7) 143 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 148 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 149 | elif isinstance(m, nn.BatchNorm2d): 150 | m.weight.data.fill_(1) 151 | m.bias.data.zero_() 152 | elif isinstance(m, nn.InstanceNorm2d): 153 | m.weight.data.fill_(1) 154 | m.bias.data.zero_() 155 | 156 | def _make_layer(self, block, planes, blocks, stride=1): 157 | downsample = None 158 | if stride != 1 or self.inplanes != planes * block.expansion: 159 | downsample = nn.Sequential( 160 | nn.Conv2d(self.inplanes, planes * block.expansion, 161 | kernel_size=1, stride=stride, bias=False), 162 | nn.BatchNorm2d(planes * block.expansion), 163 | ) 164 | 165 | layers = [] 166 | ibn = True 167 | if planes == 512: 168 | ibn = False 169 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 170 | self.inplanes = planes * block.expansion 171 | for i in range(1, blocks): 172 | layers.append(block(self.inplanes, planes, ibn)) 173 | 174 | return nn.Sequential(*layers) 175 | 176 | def forward(self, x): 177 | x = self.conv1(x) 178 | x = self.bn1(x) 179 | x = self.relu(x) 180 | x = self.maxpool(x) 181 | 182 | x = self.layer1(x) 183 | x = self.layer2(x) 184 | x = self.layer3(x) 185 | x = self.layer4(x) 186 | 187 | x = self.avgpool(x) 188 | x = x.view(x.size(0), -1) 189 | x = self.fc(x) 190 | 191 | return x 192 | 193 | 194 | def resnet50_ibn_a(pretrained=False, **kwargs): 195 | """Constructs a ResNet-50 model. 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 202 | return model 203 | 204 | 205 | def resnet101_ibn_a(pretrained=False, **kwargs): 206 | """Constructs a ResNet-101 model. 
207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 213 | return model 214 | 215 | 216 | def resnet152_ibn_a(pretrained=False, **kwargs): 217 | """Constructs a ResNet-152 model. 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | """ 221 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 222 | if pretrained: 223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 224 | return model 225 | 226 | if __name__ == '__main__': 227 | # import torch._utils 228 | # 229 | # try: 230 | # torch._utils._rebuild_tensor_v2 231 | # except AttributeError: 232 | # def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 233 | # tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 234 | # tensor.requires_grad = requires_grad 235 | # tensor._backward_hooks = backward_hooks 236 | # return tensor 237 | # 238 | # 239 | # torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 240 | 241 | model = resnet50_ibn_a() 242 | model.eval() 243 | model_checkpoint_path = os.path.expanduser('~/.torch/models/resnet50_ibn_a.pth.tar') 244 | if os.path.exists(model_checkpoint_path): 245 | model_checkpoint = torch.load(model_checkpoint_path, map_location='cpu') 246 | model_dict = model.state_dict() 247 | pretrained_dict = model_checkpoint['state_dict'] 248 | # new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 249 | new_dict = {} 250 | for k, v in pretrained_dict.items(): 251 | new_k = k[k.find('.')+1:] 252 | new_v = v 253 | new_dict[new_k] = new_v 254 | 255 | # print(model_dict.keys()) 256 | # print(pretrained_dict.keys()) 257 | # print(new_dict) 258 | model_dict.update(new_dict) 259 | model.load_state_dict(model_dict) 260 | 261 | input_data = 
misc.imread('../../../data/cat.jpg') 262 | # 按照imagenet的图像格式预处理 263 | input_data = imagenet_utils.imagenet_preprocess(input_data) 264 | 265 | x = Variable(torch.FloatTensor(torch.from_numpy(input_data))) 266 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 267 | # print(x.shape) 268 | start = time.time() 269 | pred = model(x) 270 | end = time.time() 271 | print("resnet50_ibn_a forward time:", end-start) 272 | 273 | imagenet_utils.get_imagenet_label(pred) 274 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/resnet_ibn_b.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | import os 7 | 8 | __all__ = ['ResNet', 'resnet50_ibn_b', 'resnet101_ibn_b', 9 | 'resnet152_ibn_b'] 10 | 11 | model_urls = { 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = 
class ResNet(nn.Module):
    """ResNet variant whose stem uses InstanceNorm and whose first two stages
    request an ``IN=True`` block at the end of the stage.
    """

    def __init__(self, block, layers, num_classes=1000):
        """
        :param block: residual block class; it must accept an ``IN`` keyword
        :param layers: number of blocks in each of the four stages
        :param num_classes: number of outputs of the final linear layer
        """
        super(ResNet, self).__init__()
        scale = 64
        self.inplanes = scale
        self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.InstanceNorm2d(scale, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, scale, layers[0], stride=1, IN=True)
        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, IN=True)
        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(scale * 8 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, IN=False):
        """Stack ``blocks`` residual blocks into one stage.

        The first block receives ``stride`` and the projection shortcut; the
        last block receives the ``IN`` flag. When the stage has a single
        block, that block receives all of them.

        :param block: residual block class
        :param planes: base width of the stage
        :param blocks: number of blocks in the stage (at least 1)
        :param stride: stride used by the first block
        :param IN: flag forwarded to the last block of the stage
        :return: ``nn.Sequential`` holding the stage
        :raises ValueError: if ``blocks`` is smaller than 1
        """
        if blocks < 1:
            raise ValueError('blocks must be >= 1, got {}'.format(blocks))

        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        last = blocks - 1
        for index in range(blocks):
            # Only the final block is told about IN. Previously the first and
            # last blocks were appended unconditionally, so blocks == 1
            # produced a two-block stage.
            extra = {'IN': IN} if index == last else {}
            if index == 0:
                layers.append(block(self.inplanes, planes, stride, downsample, **extra))
                self.inplanes = planes * block.expansion
            else:
                layers.append(block(self.inplanes, planes, **extra))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input image batch
        :return: output of the final linear layer
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension before the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
167 | Args: 168 | pretrained (bool): If True, returns a model pre-trained on ImageNet 169 | """ 170 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 171 | if pretrained: 172 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 173 | return model 174 | 175 | 176 | def resnet101_ibn_b(pretrained=False, **kwargs): 177 | """Constructs a ResNet-101 model. 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 184 | return model 185 | 186 | 187 | def resnet152_ibn_b(pretrained=False, **kwargs): 188 | """Constructs a ResNet-152 model. 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 193 | if pretrained: 194 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 195 | return model 196 | 197 | if __name__ == '__main__': 198 | model = resnet50_ibn_b() 199 | model_weight_path = os.path.expanduser('~/.torch/models/resnet50_ibn_b.pth.tar') 200 | if os.path.exists(model_weight_path): 201 | model_weight = torch.load(model_weight_path, map_location='cpu') 202 | model.load_state_dict(model_weight['state_dict']) 203 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/seresnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import torch 4 | from torch import nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch.autograd import Variable 7 | import os 8 | from scipy import misc 9 | import numpy as np 10 | import scipy 11 | import matplotlib.pyplot as plt 12 | import torchvision 13 | import torch.nn as nn 14 | import math 15 | 16 | from cifarclassify.utils import imagenet_utils 17 | 18 | # 
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck whose output is rescaled per channel by a
    squeeze-and-excitation gate before the shortcut is added.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        :param inplanes: channels of the incoming tensor
        :param planes: width of the two inner convolutions; the block emits
            ``planes * 4`` channels
        :param stride: stride of the 3x3 convolution
        :param downsample: optional module applied to the block input before
            it is added to the gated branch
        """
        super(Bottleneck, self).__init__()
        wide = planes * 4
        squeezed = planes // 4

        # Main branch.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)

        # Squeeze-and-excitation gate: global pool, reduce, expand, sigmoid.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Conv2d(wide, squeezed, kernel_size=1, bias=False)
        self.conv_up = nn.Conv2d(squeezed, wide, kernel_size=1, bias=False)
        self.sig = nn.Sigmoid()

        # Shortcut.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        :param x: input tensor
        :return: ReLU of (gate * main branch + shortcut)
        """
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.relu(self.bn2(self.conv2(features)))
        features = self.bn3(self.conv3(features))

        gate = self.conv_down(self.global_pool(features))
        gate = self.conv_up(self.relu(gate))
        gate = self.sig(gate)

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(gate * features + shortcut)
padding=3, 83 | bias=False) 84 | self.bn1 = nn.BatchNorm2d(64) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 87 | self.layer1 = self._make_layer(block, 64, layers[0]) 88 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 89 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 90 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 91 | self.avgpool = nn.AvgPool2d(7) 92 | self.fc = nn.Linear(512 * block.expansion, num_classes) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | nn.Conv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | nn.BatchNorm2d(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride, downsample)) 113 | self.inplanes = planes * block.expansion 114 | for i in range(1, blocks): 115 | layers.append(block(self.inplanes, planes)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.conv1(x) 121 | x = self.bn1(x) 122 | x = self.relu(x) 123 | x = self.maxpool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.avgpool(x) 131 | x = x.view(x.size(0), -1) 132 | x = self.fc(x) 133 | 134 | return x 135 | 136 | 137 | class Selayer(nn.Module): 138 | 139 | def __init__(self, inplanes): 140 | super(Selayer, self).__init__() 141 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 142 | self.conv1 = nn.Conv2d(inplanes, inplanes // 
class SEResNeXt(nn.Module):
    """ResNeXt-style backbone assembled from a block class that takes a
    ``cardinality`` argument.
    """

    def __init__(self, block, layers, cardinality=32, num_classes=1000):
        """
        :param block: residual block class, called as
            ``block(inplanes, planes, cardinality, stride, downsample)``
        :param layers: number of blocks in each of the four stages
        :param cardinality: value forwarded to every block
        :param num_classes: number of outputs of the final linear layer
        """
        super(SEResNeXt, self).__init__()
        self.cardinality = cardinality
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Adaptive pooling reduces any spatial size to 1x1 before the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one stage.

        :param block: residual block class
        :param planes: base width of the stage
        :param blocks: number of blocks in the stage
        :param stride: stride used by the first block
        :return: ``nn.Sequential`` holding the stage
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # Only the first block receives the stride and the shortcut module.
        layers = [block(self.inplanes, planes, self.cardinality, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: input image batch
        :return: output of the final linear layer
        """
        # The per-stage shape prints that used to live here were debugging
        # leftovers; they wrote to stdout on every forward pass.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension before the classifier.
        x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x
326 | Args: 327 | num_classes = 1000 (default) 328 | """ 329 | model = SEResNeXt(BottleneckX, [3, 8, 36, 3], **kwargs) 330 | return model 331 | 332 | 333 | if __name__ == '__main__': 334 | # model = torchvision.models.resnet152(pretrained=True) 335 | # model = torchvision.models.resnet34(pretrained=True) 336 | model = se_resnext50() 337 | model.eval() 338 | 339 | # input_data = misc.imread('../../../data/cat.jpg') 340 | # 按照imagenet的图像格式预处理 341 | # input_data = imagenet_utils.imagenet_preprocess(input_data) 342 | 343 | # x = Variable(torch.FloatTensor(torch.from_numpy(input_data))) 344 | # x = Variable(torch.randn(1, 3, 224, 224)) 345 | x = Variable(torch.randn(1, 3, 96, 96)) 346 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 347 | # print(x.shape) 348 | start = time.time() 349 | pred = model(x) 350 | end = time.time() 351 | print("forward time:", end-start) 352 | 353 | imagenet_utils.get_imagenet_label(pred) 354 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/shufflenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # !!!code is from [ShuffleNet-1g8-Pytorch](https://github.com/ericsun99/ShuffleNet-1g8-Pytorch)!!! 
3 | import os 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from collections import OrderedDict 11 | from torch.nn import init 12 | import numpy as np 13 | from scipy import misc 14 | import matplotlib.pyplot as plt 15 | 16 | from cifarclassify.utils import imagenet_utils 17 | 18 | 19 | def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True, groups=1): 20 | """3x3 convolution with padding 21 | """ 22 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups) 23 | 24 | 25 | def conv1x1(in_channels, out_channels, groups=1): 26 | """1x1 convolution with padding 27 | - Normal pointwise convolution When groups == 1 28 | - Grouped pointwise convolution when groups > 1 29 | 1x1卷积,groups==1那么正常卷积,否则grouped卷积 30 | """ 31 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, stride=1) 32 | 33 | 34 | def channel_shuffle(x, groups): 35 | batchsize, num_channels, height, width = x.data.size() 36 | 37 | channels_per_group = num_channels // groups # 每一group的通道数 38 | 39 | # reshape 40 | x = x.view(batchsize, groups, channels_per_group, height, width) # reshape为不同的groups 41 | 42 | # transpose 43 | # - contiguous() required if transpose() is used before view(). 
44 | # See https://github.com/pytorch/pytorch/issues/764 45 | x = torch.transpose(x, 1, 2).contiguous() 46 | 47 | # flatten 48 | x = x.view(batchsize, -1, height, width) 49 | 50 | return x 51 | 52 | 53 | class ShuffleUnit(nn.Module): 54 | def __init__(self, in_channels, out_channels, groups=3, grouped_conv=True, combine='add'): 55 | """ 56 | :param in_channels: ShuffleUnit输入通道数 57 | :param out_channels: ShuffleUnit输出通道数 58 | :param groups: ShuffleUnit分组groups 59 | :param grouped_conv: 是否在第一个1x1卷积使用groued卷积 60 | :param combine: combine使用element wise add or concat 61 | """ 62 | 63 | super(ShuffleUnit, self).__init__() 64 | self.in_channels = in_channels 65 | self.out_channels = out_channels 66 | self.grouped_conv = grouped_conv 67 | self.combine = combine 68 | self.groups = groups 69 | self.bottleneck_channels = self.out_channels // 4 # 输出通道数的1/4为ShuffleUnit的bottleneck通道数 70 | 71 | # define the type of ShuffleUnit 72 | # 不同stirde的ShuffleUnit使用不同的combine with path connection,stride=2的使用concat,stride=1的使用element wise add 73 | if self.combine == 'add': 74 | # ShuffleUnit Figure 2b 75 | self.depthwise_stride = 1 76 | self._combine_func = self._add 77 | elif self.combine == 'concat': 78 | # ShuffleUnit Figure 2c 79 | self.depthwise_stride = 2 80 | self._combine_func = self._concat 81 | # ensure output of concat has the same channels as 82 | # original output channels. 83 | self.out_channels -= self.in_channels # 确保concat输出和原始输出有相同的通道数,因此将out_channels-in_channels作为残差快的输出即可 84 | else: 85 | raise ValueError("Cannot combine tensors with \"{}\"" \ 86 | "Only \"add\" and \"concat\" are" \ 87 | "supported".format(self.combine)) 88 | 89 | # Use a 1x1 grouped or non-grouped convolution to reduce input channels 90 | # to bottleneck channels, as in a ResNet bottleneck module. 91 | # NOTE: Do not use group convolution for the first conv1x1 in Stage 2. 
92 | self.first_1x1_groups = self.groups if grouped_conv else 1 # 在stage2的第一个conv1x1不使用group卷积 93 | 94 | self.g_conv_1x1_compress = self._make_grouped_conv1x1(self.in_channels, self.bottleneck_channels, self.first_1x1_groups, batch_norm=True, relu=True) 95 | 96 | # 3x3 depthwise convolution followed by batch normalization 97 | # 3x3 deptheise convolution with BN 98 | self.depthwise_conv3x3 = conv3x3(self.bottleneck_channels, self.bottleneck_channels, stride=self.depthwise_stride, groups=self.bottleneck_channels) 99 | self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels) 100 | 101 | # Use 1x1 grouped convolution to expand from 102 | # bottleneck_channels to out_channels 103 | self.g_conv_1x1_expand = self._make_grouped_conv1x1(self.bottleneck_channels, self.out_channels, self.groups, batch_norm=True, relu=False) 104 | 105 | @staticmethod 106 | def _add(x, out): 107 | # residual connection 108 | # 残差add连接,用于stride=1的ShuffleUnit 109 | return x + out 110 | 111 | @staticmethod 112 | def _concat(x, out): 113 | # concatenate along channel axis 114 | # concat连接,用于stride=2的ShuffleUnit 115 | return torch.cat((x, out), 1) 116 | 117 | def _make_grouped_conv1x1(self, in_channels, out_channels, groups, batch_norm=True, relu=False): 118 | 119 | modules = OrderedDict() 120 | 121 | conv = conv1x1(in_channels, out_channels, groups=groups) 122 | modules['conv1x1'] = conv 123 | 124 | # 是否在1x1卷积中增加BN 125 | if batch_norm: 126 | modules['batch_norm'] = nn.BatchNorm2d(out_channels) 127 | # 是否在1x1卷积中增加ReLU 128 | if relu: 129 | modules['relu'] = nn.ReLU() 130 | 131 | if len(modules) > 1: 132 | return nn.Sequential(modules) 133 | else: 134 | return conv 135 | 136 | def forward(self, x): 137 | # save for combining later with output 138 | residual = x 139 | 140 | # 如果是concat path connection那么使用平均池化操作,否则直接是x作为残差 141 | if self.combine == 'concat': 142 | residual = F.avg_pool2d(residual, kernel_size=3, stride=2, padding=1) 143 | 144 | # 1x1 compress to out_channels//4的卷积 145 | out = 
self.g_conv_1x1_compress(x) 146 | out = channel_shuffle(out, self.groups) # channel shuffle操作,将out按照groups shuffle 147 | 148 | out = self.depthwise_conv3x3(out) 149 | out = self.bn_after_depthwise(out) 150 | 151 | # 1x1 expand to out_channels的卷积 152 | out = self.g_conv_1x1_expand(out) 153 | 154 | # 最后和残差组合 155 | out = self._combine_func(residual, out) 156 | return F.relu(out) 157 | 158 | 159 | class ShuffleNet(nn.Module): 160 | """ShuffleNet implementation. 161 | """ 162 | 163 | def __init__(self, groups=3, in_channels=3, n_classes=1000): 164 | """ShuffleNet constructor. 165 | Arguments: 166 | groups (int, optional): number of groups to be used in grouped 167 | 1x1 convolutions in each ShuffleUnit. Default is 3 for best 168 | performance according to original paper. 169 | 在ShuffeleUnit中使用的1x1卷积groups数量 170 | in_channels (int, optional): number of channels in the input tensor. 171 | Default is 3 for RGB image inputs. 172 | n_classes (int, optional): number of classes to predict. Default 173 | is 1000 for ImageNet. 174 | """ 175 | super(ShuffleNet, self).__init__() 176 | 177 | self.groups = groups 178 | self.stage_repeats = [3, 7, 3] # stage 重复次数 179 | self.in_channels = in_channels 180 | self.n_classes = n_classes 181 | 182 | # index 0 is invalid and should never be called. 183 | # only used for indexing convenience. 
184 | # 总攻有3个stage,包括第一个Conv1和MaxPool 185 | if groups == 1: 186 | self.stage_out_channels = [-1, 24, 144, 288, 567] 187 | elif groups == 2: 188 | self.stage_out_channels = [-1, 24, 200, 400, 800] 189 | elif groups == 3: 190 | self.stage_out_channels = [-1, 24, 240, 480, 960] 191 | elif groups == 4: 192 | self.stage_out_channels = [-1, 24, 272, 544, 1088] 193 | elif groups == 8: 194 | self.stage_out_channels = [-1, 24, 384, 768, 1536] 195 | else: 196 | raise ValueError( 197 | """{} groups is not supported for 198 | 1x1 Grouped Convolutions""".format(groups)) 199 | 200 | # Stage 1 always has 24 output channels 201 | # conv1+maxpool组成了stage1,2,3的第一组操作 202 | self.conv1 = conv3x3(self.in_channels, self.stage_out_channels[1], stride=2) 203 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 204 | 205 | # Stage 2 206 | # stage2构建 207 | self.stage2 = self._make_stage(2) 208 | # Stage 3 209 | # stage3构建 210 | self.stage3 = self._make_stage(3) 211 | # Stage 4 212 | # stage4构建 213 | self.stage4 = self._make_stage(4) 214 | 215 | # Global pooling: 216 | # Undefined as PyTorch's functional API can be used for on-the-fly 217 | # shape inference if input size is not ImageNet's 224x224 218 | 219 | # Fully-connected classification layer 220 | num_inputs = self.stage_out_channels[-1] 221 | self.fc = nn.Linear(num_inputs, self.n_classes) 222 | 223 | def _make_stage(self, stage): 224 | modules = OrderedDict() 225 | stage_name = "ShuffleUnit_Stage{}".format(stage) # 增加module name for convience 226 | 227 | # First ShuffleUnit in the stage 228 | # 1. non-grouped 1x1 convolution (i.e. pointwise convolution) 229 | # is used in Stage 2. Group convolutions used everywhere else. 230 | grouped_conv = stage > 2 # 第二个stage不使用grouped conv 231 | 232 | # 2. concatenation unit is always used. 
233 | # 总是使用convcat单元在每一个stage的第一个 234 | first_module = ShuffleUnit( 235 | self.stage_out_channels[stage - 1], 236 | self.stage_out_channels[stage], 237 | groups=self.groups, 238 | grouped_conv=grouped_conv, 239 | combine='concat' 240 | ) 241 | modules[stage_name + "_0"] = first_module 242 | 243 | # add more ShuffleUnits depending on pre-defined number of repeats 244 | for i in range(self.stage_repeats[stage - 2]): 245 | name = stage_name + "_{}".format(i + 1) 246 | module = ShuffleUnit( 247 | self.stage_out_channels[stage], 248 | self.stage_out_channels[stage], 249 | groups=self.groups, 250 | grouped_conv=True, 251 | combine='add' 252 | ) 253 | modules[name] = module 254 | 255 | return nn.Sequential(modules) 256 | 257 | def forward(self, x): 258 | x = self.conv1(x) 259 | x = self.maxpool(x) 260 | 261 | x = self.stage2(x) 262 | x = self.stage3(x) 263 | x = self.stage4(x) 264 | 265 | # global average pooling layer 266 | x = F.avg_pool2d(x, x.data.size()[-2:]) 267 | 268 | # flatten for input to fully-connected layer 269 | x = x.view(x.size(0), -1) 270 | x = self.fc(x) 271 | 272 | # return F.log_softmax(x, dim=1) 273 | return x 274 | 275 | if __name__ == '__main__': 276 | 277 | image_height, image_width, image_channel = (224, 224, 3) 278 | input = misc.imread('../../../data/cat.jpg') 279 | # 按照imagenet的图像格式预处理 280 | input = imagenet_utils.imagenet_preprocess(input) 281 | 282 | n_classes = 1000 283 | model = ShuffleNet(n_classes=n_classes, groups=8) 284 | model.eval() 285 | # 训练模型为gpu模型 286 | model_checkpoint_path = os.path.expanduser('~/.torch/models/ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar') 287 | if os.path.exists(model_checkpoint_path): 288 | pretrained_dict = torch.load(model_checkpoint_path, map_location='cpu') 289 | model_dict = model.state_dict() 290 | # new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 291 | new_dict = {} 292 | for k, v in pretrained_dict.items(): 293 | new_k = k[k.find('.')+1:] 294 | new_v = v 295 | 
new_dict[new_k] = new_v 296 | 297 | # print(model_dict.keys()[:5]) 298 | # print(pretrained_dict.keys()[:5]) 299 | # print(new_dict.keys()[:5]) 300 | model_dict.update(new_dict) 301 | model.load_state_dict(model_dict) 302 | 303 | x = Variable(torch.randn(1, image_channel, image_height, image_width)) 304 | x = Variable(torch.FloatTensor(torch.from_numpy(input))) 305 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 306 | # print(x.shape) 307 | start = time.time() 308 | pred = model(x) 309 | end = time.time() 310 | print("ShuffleNet forward time:", end-start) 311 | 312 | imagenet_utils.get_imagenet_label(pred) 313 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/squeezenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | # code from https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py 8 | 9 | model_urls = { 10 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 11 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 12 | } 13 | 14 | 15 | class Fire(nn.Module): 16 | 17 | # 输入是inplanes,输出是expand1x1_planes+expand3x3_planes 18 | # 参数是inplanes*squeeze_planes+squeeze_planes*expand1x1_planes+squeeze_planes*expand3x3_planes*9 19 | def __init__(self, inplanes, squeeze_planes, 20 | expand1x1_planes, expand3x3_planes): 21 | super(Fire, self).__init__() 22 | self.inplanes = inplanes 23 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 24 | self.squeeze_activation = nn.ReLU(inplace=True) 25 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 26 | kernel_size=1) 27 | self.expand1x1_activation = nn.ReLU(inplace=True) 28 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 29 | 
kernel_size=3, padding=1) 30 | self.expand3x3_activation = nn.ReLU(inplace=True) 31 | 32 | def forward(self, x): 33 | x = self.squeeze_activation(self.squeeze(x)) 34 | # 通道1上表示channel上concat 35 | return torch.cat([ 36 | self.expand1x1_activation(self.expand1x1(x)), 37 | self.expand3x3_activation(self.expand3x3(x)) 38 | ], 1) 39 | 40 | 41 | class SqueezeNet(nn.Module): 42 | 43 | def __init__(self, version=1.0, num_classes=1000): 44 | super(SqueezeNet, self).__init__() 45 | if version not in [1.0, 1.1]: 46 | raise ValueError("Unsupported SqueezeNet version {version}:" 47 | "1.0 or 1.1 expected".format(version=version)) 48 | self.num_classes = num_classes 49 | if version == 1.0: 50 | self.features = nn.Sequential( 51 | nn.Conv2d(3, 96, kernel_size=7, stride=2), 52 | nn.ReLU(inplace=True), 53 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 54 | Fire(96, 16, 64, 64), 55 | Fire(128, 16, 64, 64), 56 | Fire(128, 32, 128, 128), 57 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 58 | Fire(256, 32, 128, 128), 59 | Fire(256, 48, 192, 192), 60 | Fire(384, 48, 192, 192), 61 | Fire(384, 64, 256, 256), 62 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 63 | Fire(512, 64, 256, 256), 64 | ) 65 | else: 66 | self.features = nn.Sequential( 67 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 68 | nn.ReLU(inplace=True), 69 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 70 | Fire(64, 16, 64, 64), 71 | Fire(128, 16, 64, 64), 72 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 73 | Fire(128, 32, 128, 128), 74 | Fire(256, 32, 128, 128), 75 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 76 | Fire(256, 48, 192, 192), 77 | Fire(384, 48, 192, 192), 78 | Fire(384, 64, 256, 256), 79 | Fire(512, 64, 256, 256), 80 | ) 81 | # Final convolution is initialized differently form the rest 82 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 83 | self.classifier = nn.Sequential( 84 | nn.Dropout(p=0.5), 85 | final_conv, 86 | 
nn.ReLU(inplace=True), 87 | nn.AdaptiveAvgPool2d((1, 1)) 88 | ) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | if m is final_conv: 93 | init.normal_(m.weight, mean=0.0, std=0.01) 94 | else: 95 | init.kaiming_uniform_(m.weight) 96 | if m.bias is not None: 97 | init.constant_(m.bias, 0) 98 | 99 | def forward(self, x): 100 | x = self.features(x) 101 | x = self.classifier(x) 102 | return x.view(x.size(0), self.num_classes) 103 | 104 | 105 | def squeezenet1_0(pretrained=False, **kwargs): 106 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 107 | accuracy with 50x fewer parameters and <0.5MB model size" 108 | `_ paper. 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | """ 112 | model = SqueezeNet(version=1.0, **kwargs) 113 | if pretrained: 114 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) 115 | return model 116 | 117 | 118 | def squeezenet1_1(pretrained=False, **kwargs): 119 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 120 | `_. 121 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 122 | than SqueezeNet 1.0, without sacrificing accuracy. 
123 | Args: 124 | pretrained (bool): If True, returns a model pre-trained on ImageNet 125 | """ 126 | model = SqueezeNet(version=1.1, **kwargs) 127 | if pretrained: 128 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) 129 | return model 130 | 131 | if __name__ == '__main__': 132 | n_classes = 1000 133 | model = squeezenet1_1(pretrained=True) 134 | model.eval() 135 | 136 | from scipy import misc 137 | from torch.autograd import Variable 138 | import time 139 | from cifarclassify.utils import imagenet_utils 140 | import numpy as np 141 | input_data = misc.imread('../../../data/EnglishCockerSpaniel_simon.jpg') 142 | # 按照imagenet的图像格式预处理 143 | input_data = imagenet_utils.imagenet_preprocess(input_data) 144 | 145 | # x = Variable(torch.randn(1, 3, 224, 224)) 146 | x = Variable(torch.FloatTensor(torch.from_numpy(input_data))) 147 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 148 | # print(x.shape) 149 | start = time.time() 150 | pred = model(x) 151 | end = time.time() 152 | print("squeezenet1_1 forward time:", end-start) 153 | 154 | imagenet_utils.get_imagenet_label(pred) 155 | -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 参考代码[resnet.py](https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py) 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import torch 7 | from torch.nn.init import kaiming_normal_ 8 | import torch.nn.functional as F 9 | from torch.nn.parallel._functions import Broadcast 10 | from torch.nn.parallel import scatter, parallel_apply, gather 11 | from functools import partial 12 | from nested_dict import nested_dict 13 | import os 14 | import time 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch.autograd import Variable 20 | from 
collections import OrderedDict 21 | from torch.nn import init 22 | import numpy as np 23 | from scipy import misc 24 | import matplotlib.pyplot as plt 25 | 26 | from cifarclassify.utils import imagenet_utils 27 | 28 | 29 | def cast(params, dtype='float'): 30 | if isinstance(params, dict): 31 | return {k: cast(v, dtype) for k,v in params.items()} 32 | else: 33 | return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)() 34 | 35 | 36 | # conv params no*ni*k*k 37 | def conv_params(ni, no, k=1): 38 | return kaiming_normal_(torch.Tensor(no, ni, k, k)) 39 | 40 | 41 | def linear_params(ni, no): 42 | return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)} 43 | 44 | 45 | def bnparams(n): 46 | return {'weight': torch.rand(n), 47 | 'bias': torch.zeros(n), 48 | 'running_mean': torch.zeros(n), 49 | 'running_var': torch.ones(n)} 50 | 51 | 52 | def data_parallel(f, input, params, mode, device_ids, output_device=None): 53 | assert isinstance(device_ids, list) 54 | if output_device is None: 55 | output_device = device_ids[0] 56 | 57 | if len(device_ids) == 1: 58 | return f(input, params, mode) 59 | 60 | params_all = Broadcast.apply(device_ids, *params.values()) 61 | params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())} 62 | for j in range(len(device_ids))] 63 | 64 | replicas = [partial(f, params=p, mode=mode) 65 | for p in params_replicas] 66 | inputs = scatter([input], device_ids) 67 | outputs = parallel_apply(replicas, inputs) 68 | return gather(outputs, output_device) 69 | 70 | 71 | def flatten(params): 72 | return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None} 73 | 74 | 75 | def batch_norm(x, params, base, mode): 76 | return F.batch_norm(x, weight=params[base + '.weight'], 77 | bias=params[base + '.bias'], 78 | running_mean=params[base + '.running_mean'], 79 | running_var=params[base + '.running_var'], 80 | training=mode) 81 | 82 | 83 | def 
print_tensor_dict(params): 84 | kmax = max(len(key) for key in params.keys()) 85 | for i, (key, v) in enumerate(params.items()): 86 | print(str(i).ljust(5), key.ljust(kmax + 3), str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad) 87 | 88 | 89 | def set_requires_grad_except_bn_(params): 90 | for k, v in params.items(): 91 | if not k.endswith('running_mean') and not k.endswith('running_var'): 92 | # if not running_mean or running_var requires grad 93 | v.requires_grad = True 94 | 95 | # def wide_resnet(depth, width, num_classes): 96 | class wide_resnet(nn.Module): 97 | def __init__(self, depth, width, n_classes): 98 | super(wide_resnet, self).__init__() 99 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' # 4 is for conv0 100 | self.n = (depth - 4) // 6 101 | widths = [int(v * width) for v in (64, 128, 256, 512)] # for normal resnet 16 32 64 102 | 103 | def gen_block_params(ni, no): 104 | return { 105 | 'conv0': conv_params(ni, no, 3), 106 | 'conv1': conv_params(no, no, 3), 107 | 'bn0': bnparams(ni), 108 | 'bn1': bnparams(no), 109 | 'convdim': conv_params(ni, no, 1) if ni != no else None, 110 | } 111 | 112 | def gen_group_params(ni, no, count): 113 | return {'block%d' % i: gen_block_params(ni if i == 0 else no, no) for i in range(count)} 114 | 115 | # conv0+group0+goup1+group2 116 | self.flat_params = cast(flatten({ 117 | 'conv0': conv_params(3, 16, 3), # input 2 output 16 118 | 'group0': gen_group_params(16, widths[0], self.n), # input 16 output widths[0] 119 | 'group1': gen_group_params(widths[0], widths[1], self.n), 120 | 'group2': gen_group_params(widths[1], widths[2], self.n), 121 | 'group3': gen_group_params(widths[2], widths[3], self.n), 122 | 'bn': bnparams(widths[3]), 123 | 'fc': linear_params(widths[3], n_classes), 124 | })) 125 | 126 | # except bn requires grad 127 | set_requires_grad_except_bn_(self.flat_params) 128 | 129 | def block(self, x, params, base, mode, stride): 130 | o1 = F.relu(batch_norm(x, params, base + '.bn0', mode), 
inplace=True) 131 | y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1) 132 | o2 = F.relu(batch_norm(y, params, base + '.bn1', mode), inplace=True) 133 | z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1) 134 | if base + '.convdim' in params: 135 | return z + F.conv2d(o1, params[base + '.convdim'], stride=stride) 136 | else: 137 | return z + x 138 | 139 | def group(self, o, params, base, mode, stride): 140 | for i in range(self.n): 141 | o = self.block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1) 142 | return o 143 | 144 | 145 | def forward(self, input): 146 | x = F.conv2d(input, self.flat_params['conv0'], padding=1, stride=2) 147 | # print('x.shape:', x.shape) 148 | g0 = self.group(x, self.flat_params, 'group0', self.training, 2) 149 | # print('g0.shape:', g0.shape) 150 | g1 = self.group(g0, self.flat_params, 'group1', self.training, 2) 151 | # print('g1.shape:', g1.shape) 152 | g2 = self.group(g1, self.flat_params, 'group2', self.training, 2) 153 | # print('g2.shape:', g2.shape) 154 | g3 = self.group(g2, self.flat_params, 'group3', self.training, 2) 155 | o = F.relu(batch_norm(g3, self.flat_params, 'bn', self.training)) 156 | # print('o.shape:', o.shape) 157 | o = F.avg_pool2d(o, 7, 1, 0) 158 | # print('o.shape:', o.shape) 159 | o = o.view(o.size(0), -1) 160 | # print('o.shape:', o.shape) 161 | o = F.linear(o, self.flat_params['fc.weight'], self.flat_params['fc.bias']) 162 | return o 163 | # return f, flat_params 164 | # return f 165 | 166 | def wide_resnet_28_10(n_classes): 167 | depth = 28 168 | width = 10 169 | n_classes = n_classes 170 | return wide_resnet(depth, width, n_classes) 171 | 172 | def wide_resnet_16_4(n_classes): 173 | depth = 16 174 | width = 4 175 | n_classes = n_classes 176 | return wide_resnet(depth, width, n_classes) 177 | 178 | if __name__ == '__main__': 179 | 180 | image_height, image_width, image_channel = (224, 224, 3) 181 | input = misc.imread('../../../data/cat.jpg') 182 | # 
按照imagenet的图像格式预处理 183 | input = imagenet_utils.imagenet_preprocess(input) 184 | 185 | n_classes = 1000 186 | # model = wide_resnet_28_10(n_classes) 187 | model = wide_resnet_16_4(n_classes) 188 | # model.eval() 189 | 190 | x = Variable(torch.randn(1, image_channel, image_height, image_width)) 191 | # x = Variable(torch.FloatTensor(torch.from_numpy(input))) 192 | y = Variable(torch.LongTensor(np.ones(1, dtype=np.int))) 193 | # print(x.shape) 194 | start = time.time() 195 | pred = model(x) 196 | end = time.time() 197 | print("Wide ResNet forward time:", end-start) 198 | 199 | imagenet_utils.get_imagenet_label(pred) -------------------------------------------------------------------------------- /cifarclassify/modelloader/imagenet/xception.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # code is from https://github.com/tstandley/Xception-PyTorch/blob/master/xception.py 3 | 4 | """ 5 | Creates an Xception Model as defined in: 6 | Francois Chollet 7 | Xception: Deep Learning with Depthwise Separable Convolutions 8 | https://arxiv.org/pdf/1610.02357.pdf 9 | This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: 10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 11 | REMEMBER to set your image size to 3x299x299 for both test and validation 12 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 13 | std=[0.5, 0.5, 0.5]) 14 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 15 | """ 16 | import math 17 | import os 18 | import time 19 | 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.utils.model_zoo as model_zoo 23 | from torch.autograd import Variable 24 | from torch.nn import init 25 | import torch 26 | from scipy import misc 27 | import numpy as np 28 | 29 | from cifarclassify.utils import imagenet_utils 30 | 31 | __all__ = ['xception'] 32 | 33 | # model_urls = { 34 | # 'xception': 'https://www.dropbox.com/s/1hplpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1' 35 | # } 36 | 37 | 38 | class SeparableConv2d(nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 40 | super(SeparableConv2d, self).__init__() 41 | 42 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, 43 | bias=bias) 44 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) 48 | x = self.pointwise(x) 49 | return x 50 | 51 | 52 | class Block(nn.Module): 53 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 54 | super(Block, self).__init__() 55 | 56 | if out_filters != in_filters or strides != 1: 57 | self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) 58 | self.skipbn = nn.BatchNorm2d(out_filters) 59 | else: 60 | self.skip = None 61 | 62 | self.relu = nn.ReLU(inplace=True) 63 | rep = [] 64 | 65 | filters = in_filters 66 | if grow_first: 67 | rep.append(self.relu) 68 | 
rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 69 | rep.append(nn.BatchNorm2d(out_filters)) 70 | filters = out_filters 71 | 72 | for i in range(reps - 1): 73 | rep.append(self.relu) 74 | rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) 75 | rep.append(nn.BatchNorm2d(filters)) 76 | 77 | if not grow_first: 78 | rep.append(self.relu) 79 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 80 | rep.append(nn.BatchNorm2d(out_filters)) 81 | 82 | if not start_with_relu: 83 | rep = rep[1:] 84 | else: 85 | rep[0] = nn.ReLU(inplace=False) 86 | 87 | if strides != 1: 88 | rep.append(nn.MaxPool2d(3, strides, 1)) 89 | self.rep = nn.Sequential(*rep) 90 | 91 | def forward(self, inp): 92 | x = self.rep(inp) 93 | 94 | if self.skip is not None: 95 | skip = self.skip(inp) 96 | skip = self.skipbn(skip) 97 | else: 98 | skip = inp 99 | 100 | x += skip 101 | return x 102 | 103 | 104 | class Xception(nn.Module): 105 | """ 106 | Xception optimized for the ImageNet dataset, as specified in 107 | https://arxiv.org/pdf/1610.02357.pdf 108 | """ 109 | 110 | def __init__(self, n_classes=1000): 111 | """ Constructor 112 | Args: 113 | num_classes: number of classes 114 | """ 115 | super(Xception, self).__init__() 116 | 117 | self.n_classes = n_classes 118 | 119 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) 120 | self.bn1 = nn.BatchNorm2d(32) 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 124 | self.bn2 = nn.BatchNorm2d(64) 125 | # do relu here 126 | 127 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) 128 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) 129 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) 130 | 131 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 132 | self.block5 = Block(728, 728, 3, 1, 
start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 134 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 135 | 136 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 137 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 138 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 139 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 140 | 141 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) 142 | 143 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 144 | self.bn3 = nn.BatchNorm2d(1536) 145 | 146 | # do relu here 147 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) 148 | self.bn4 = nn.BatchNorm2d(2048) 149 | 150 | self.fc = nn.Linear(2048, n_classes) 151 | 152 | # ------- init weights -------- 153 | for m in self.modules(): 154 | if isinstance(m, nn.Conv2d): 155 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 156 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 157 | elif isinstance(m, nn.BatchNorm2d): 158 | m.weight.data.fill_(1) 159 | m.bias.data.zero_() 160 | # ----------------------------- 161 | 162 | # self.init_weights(pretrained=True) 163 | 164 | def forward(self, x): 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | x = self.relu(x) 168 | 169 | x = self.conv2(x) 170 | x = self.bn2(x) 171 | x = self.relu(x) 172 | 173 | x = self.block1(x) 174 | x = self.block2(x) 175 | x = self.block3(x) 176 | x = self.block4(x) 177 | x = self.block5(x) 178 | x = self.block6(x) 179 | x = self.block7(x) 180 | x = self.block8(x) 181 | x = self.block9(x) 182 | x = self.block10(x) 183 | x = self.block11(x) 184 | x = self.block12(x) 185 | 186 | x = self.conv3(x) 187 | x = self.bn3(x) 188 | x = self.relu(x) 189 | 190 | x = self.conv4(x) 191 | x = self.bn4(x) 192 | x = self.relu(x) 193 | 194 | x = F.adaptive_avg_pool2d(x, (1, 1)) 195 | x = x.view(x.size(0), -1) 196 | x = self.fc(x) 197 | 198 | return x 199 | 200 | def init_weights(self, pretrained=False): 201 | model_checkpoint_path = os.path.expanduser('~/.torch/models/xception-c0a72b38.pth.tar') 202 | if os.path.exists(model_checkpoint_path): 203 | pretrained_dict = torch.load(model_checkpoint_path, map_location='cpu') 204 | # pretrained_dict = model_checkpoint['state_dict'] 205 | 206 | model_dict = self.state_dict() 207 | 208 | # print(model_dict.keys()) 209 | # print(pretrained_dict.keys()) 210 | model_dict_keys = model_dict.keys() 211 | 212 | new_dict = {} 213 | for dict_index, (k, v) in enumerate(pretrained_dict.items()): 214 | # print(dict_index) 215 | # print(k) 216 | # new_k = model_dict_keys[dict_index] 217 | new_k = k 218 | if new_k in model_dict_keys: 219 | new_v = v 220 | new_dict[new_k] = new_v 221 | model_dict.update(new_dict) 222 | self.load_state_dict(model_dict) 223 | 224 | 225 | def xception(pretrained=False, **kwargs): 226 | """ 227 | Construct Xception. 
if __name__ == '__main__':
    # Smoke test: classify one sample image with (optionally pretrained)
    # Xception and print the forward-pass time and the top-5 predictions.
    n_classes = 1000
    model = xception(n_classes=n_classes, pretrained=True)
    model.eval()
    input_data = misc.imread('../../../data/cat.jpg')
    # Preprocess the image following the ImageNet input convention.
    input_data = imagenet_utils.imagenet_preprocess(input_data, height=299, width=299)

    x = Variable(torch.FloatTensor(torch.from_numpy(input_data)))
    # The dummy label tensor that used to be built here was never used and
    # depended on the `np.int` alias, which newer NumPy releases removed.

    start = time.time()
    pred = model(x)
    end = time.time()
    # The message previously said "MobileNet" (copy-paste from another model).
    print("Xception forward time:", end - start)

    imagenet_utils.get_imagenet_label(pred)
def softmax(x, axis=None):
    """ get the softmax value of x

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the quotient from becoming ``nan``) for large inputs.

    :param x: array-like of scores (logits)
    :param axis: axis along which to normalise; ``None`` (the default)
        normalises over all elements, which is the original behaviour
    :return: numpy array with the same shape as ``x`` whose entries sum to
        1 along ``axis`` (over the whole array when ``axis`` is ``None``)
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max raises on empty input; an empty softmax is simply empty.
        return np.exp(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    softmax_x = exp_x / np.sum(exp_x, axis=axis, keepdims=True)
    return softmax_x
def image_crop_resize(img, input_height, input_width):
    """ resize the image so it covers the target size, then center-crop it

    The image is scaled with its aspect ratio preserved until it is at
    least ``input_height`` x ``input_width``, then the central
    ``input_height`` x ``input_width`` window is returned. Previously the
    crop was hard-coded to 224x224, so callers asking for another size
    (e.g. 299x299) silently received 224x224. The ``(height, width)`` order
    handed to ``imresize`` was also swapped, which only went unnoticed
    because the targets used so far were square.

    NOTE(review): ``scipy.misc.imresize`` was removed in SciPy 1.3; this
    keeps the call the module already relies on, so an older SciPy (with
    PIL) is required.

    :param img: image array indexed as (height, width, channel)
    :param input_height: height of the returned image
    :param input_width: width of the returned image
    :return: the resized and center-cropped image
    """
    aspect = img.shape[1] / float(img.shape[0])
    target_aspect = input_width / float(input_height)
    # print("Orginal aspect ratio: " + str(aspect))
    if aspect >= target_aspect:
        # Relatively wider than the target: fit the height, overflow the width.
        new_height = input_height
        # max() guards against int() truncating just below the target width.
        new_width = max(input_width, int(aspect * input_height))
    else:
        # Relatively taller than the target: fit the width, overflow the height.
        new_width = input_width
        new_height = max(input_height, int(input_width / aspect))

    # imresize expects the output size as (height, width).
    img_scaled = misc.imresize(img, (new_height, new_width))

    # crop_center takes the crop width first, then the crop height.
    img_center = crop_center(img_scaled, input_width, input_height)
    return img_center
man-eater, man-eating shark, Carcharodon carcharias 4 | n01491361 tiger shark, Galeocerdo cuvieri 5 | n01494475 hammerhead, hammerhead shark 6 | n01496331 electric ray, crampfish, numbfish, torpedo 7 | n01498041 stingray 8 | n01514668 cock 9 | n01514859 hen 10 | n01518878 ostrich, Struthio camelus 11 | n01530575 brambling, Fringilla montifringilla 12 | n01531178 goldfinch, Carduelis carduelis 13 | n01532829 house finch, linnet, Carpodacus mexicanus 14 | n01534433 junco, snowbird 15 | n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | n01558993 robin, American robin, Turdus migratorius 17 | n01560419 bulbul 18 | n01580077 jay 19 | n01582220 magpie 20 | n01592084 chickadee 21 | n01601694 water ouzel, dipper 22 | n01608432 kite 23 | n01614925 bald eagle, American eagle, Haliaeetus leucocephalus 24 | n01616318 vulture 25 | n01622779 great grey owl, great gray owl, Strix nebulosa 26 | n01629819 European fire salamander, Salamandra salamandra 27 | n01630670 common newt, Triturus vulgaris 28 | n01631663 eft 29 | n01632458 spotted salamander, Ambystoma maculatum 30 | n01632777 axolotl, mud puppy, Ambystoma mexicanum 31 | n01641577 bullfrog, Rana catesbeiana 32 | n01644373 tree frog, tree-frog 33 | n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | n01664065 loggerhead, loggerhead turtle, Caretta caretta 35 | n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | n01667114 mud turtle 37 | n01667778 terrapin 38 | n01669191 box turtle, box tortoise 39 | n01675722 banded gecko 40 | n01677366 common iguana, iguana, Iguana iguana 41 | n01682714 American chameleon, anole, Anolis carolinensis 42 | n01685808 whiptail, whiptail lizard 43 | n01687978 agama 44 | n01688243 frilled lizard, Chlamydosaurus kingi 45 | n01689811 alligator lizard 46 | n01692333 Gila monster, Heloderma suspectum 47 | n01693334 green lizard, Lacerta viridis 48 | n01694178 African chameleon, Chamaeleo chamaeleon 49 | n01695060 
Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | n01697457 African crocodile, Nile crocodile, Crocodylus niloticus 51 | n01698640 American alligator, Alligator mississipiensis 52 | n01704323 triceratops 53 | n01728572 thunder snake, worm snake, Carphophis amoenus 54 | n01728920 ringneck snake, ring-necked snake, ring snake 55 | n01729322 hognose snake, puff adder, sand viper 56 | n01729977 green snake, grass snake 57 | n01734418 king snake, kingsnake 58 | n01735189 garter snake, grass snake 59 | n01737021 water snake 60 | n01739381 vine snake 61 | n01740131 night snake, Hypsiglena torquata 62 | n01742172 boa constrictor, Constrictor constrictor 63 | n01744401 rock python, rock snake, Python sebae 64 | n01748264 Indian cobra, Naja naja 65 | n01749939 green mamba 66 | n01751748 sea snake 67 | n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes 70 | n01768244 trilobite 71 | n01770081 harvestman, daddy longlegs, Phalangium opilio 72 | n01770393 scorpion 73 | n01773157 black and gold garden spider, Argiope aurantia 74 | n01773549 barn spider, Araneus cavaticus 75 | n01773797 garden spider, Aranea diademata 76 | n01774384 black widow, Latrodectus mactans 77 | n01774750 tarantula 78 | n01775062 wolf spider, hunting spider 79 | n01776313 tick 80 | n01784675 centipede 81 | n01795545 black grouse 82 | n01796340 ptarmigan 83 | n01797886 ruffed grouse, partridge, Bonasa umbellus 84 | n01798484 prairie chicken, prairie grouse, prairie fowl 85 | n01806143 peacock 86 | n01806567 quail 87 | n01807496 partridge 88 | n01817953 African grey, African gray, Psittacus erithacus 89 | n01818515 macaw 90 | n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | n01820546 lorikeet 92 | n01824575 coucal 93 | n01828970 bee eater 94 | n01829413 hornbill 95 | n01833805 hummingbird 
96 | n01843065 jacamar 97 | n01843383 toucan 98 | n01847000 drake 99 | n01855032 red-breasted merganser, Mergus serrator 100 | n01855672 goose 101 | n01860187 black swan, Cygnus atratus 102 | n01871265 tusker 103 | n01872401 echidna, spiny anteater, anteater 104 | n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | n01877812 wallaby, brush kangaroo 106 | n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | n01883070 wombat 108 | n01910747 jellyfish 109 | n01914609 sea anemone, anemone 110 | n01917289 brain coral 111 | n01924916 flatworm, platyhelminth 112 | n01930112 nematode, nematode worm, roundworm 113 | n01943899 conch 114 | n01944390 snail 115 | n01945685 slug 116 | n01950731 sea slug, nudibranch 117 | n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | n01968897 chambered nautilus, pearly nautilus, nautilus 119 | n01978287 Dungeness crab, Cancer magister 120 | n01978455 rock crab, Cancer irroratus 121 | n01980166 fiddler crab 122 | n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | n01985128 crayfish, crawfish, crawdad, crawdaddy 126 | n01986214 hermit crab 127 | n01990800 isopod 128 | n02002556 white stork, Ciconia ciconia 129 | n02002724 black stork, Ciconia nigra 130 | n02006656 spoonbill 131 | n02007558 flamingo 132 | n02009229 little blue heron, Egretta caerulea 133 | n02009912 American egret, great white heron, Egretta albus 134 | n02011460 bittern 135 | n02012849 crane 136 | n02013706 limpkin, Aramus pictus 137 | n02017213 European gallinule, Porphyrio porphyrio 138 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana 139 | n02018795 bustard 140 | n02025239 ruddy turnstone, Arenaria interpres 141 | 
n02027492 red-backed sandpiper, dunlin, Erolia alpina 142 | n02028035 redshank, Tringa totanus 143 | n02033041 dowitcher 144 | n02037110 oystercatcher, oyster catcher 145 | n02051845 pelican 146 | n02056570 king penguin, Aptenodytes patagonica 147 | n02058221 albatross, mollymawk 148 | n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | n02074367 dugong, Dugong dugon 151 | n02077923 sea lion 152 | n02085620 Chihuahua 153 | n02085782 Japanese spaniel 154 | n02085936 Maltese dog, Maltese terrier, Maltese 155 | n02086079 Pekinese, Pekingese, Peke 156 | n02086240 Shih-Tzu 157 | n02086646 Blenheim spaniel 158 | n02086910 papillon 159 | n02087046 toy terrier 160 | n02087394 Rhodesian ridgeback 161 | n02088094 Afghan hound, Afghan 162 | n02088238 basset, basset hound 163 | n02088364 beagle 164 | n02088466 bloodhound, sleuthhound 165 | n02088632 bluetick 166 | n02089078 black-and-tan coonhound 167 | n02089867 Walker hound, Walker foxhound 168 | n02089973 English foxhound 169 | n02090379 redbone 170 | n02090622 borzoi, Russian wolfhound 171 | n02090721 Irish wolfhound 172 | n02091032 Italian greyhound 173 | n02091134 whippet 174 | n02091244 Ibizan hound, Ibizan Podenco 175 | n02091467 Norwegian elkhound, elkhound 176 | n02091635 otterhound, otter hound 177 | n02091831 Saluki, gazelle hound 178 | n02092002 Scottish deerhound, deerhound 179 | n02092339 Weimaraner 180 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 181 | n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | n02093647 Bedlington terrier 183 | n02093754 Border terrier 184 | n02093859 Kerry blue terrier 185 | n02093991 Irish terrier 186 | n02094114 Norfolk terrier 187 | n02094258 Norwich terrier 188 | n02094433 Yorkshire terrier 189 | n02095314 wire-haired fox terrier 190 | n02095570 Lakeland terrier 191 | n02095889 
Sealyham terrier, Sealyham 192 | n02096051 Airedale, Airedale terrier 193 | n02096177 cairn, cairn terrier 194 | n02096294 Australian terrier 195 | n02096437 Dandie Dinmont, Dandie Dinmont terrier 196 | n02096585 Boston bull, Boston terrier 197 | n02097047 miniature schnauzer 198 | n02097130 giant schnauzer 199 | n02097209 standard schnauzer 200 | n02097298 Scotch terrier, Scottish terrier, Scottie 201 | n02097474 Tibetan terrier, chrysanthemum dog 202 | n02097658 silky terrier, Sydney silky 203 | n02098105 soft-coated wheaten terrier 204 | n02098286 West Highland white terrier 205 | n02098413 Lhasa, Lhasa apso 206 | n02099267 flat-coated retriever 207 | n02099429 curly-coated retriever 208 | n02099601 golden retriever 209 | n02099712 Labrador retriever 210 | n02099849 Chesapeake Bay retriever 211 | n02100236 German short-haired pointer 212 | n02100583 vizsla, Hungarian pointer 213 | n02100735 English setter 214 | n02100877 Irish setter, red setter 215 | n02101006 Gordon setter 216 | n02101388 Brittany spaniel 217 | n02101556 clumber, clumber spaniel 218 | n02102040 English springer, English springer spaniel 219 | n02102177 Welsh springer spaniel 220 | n02102318 cocker spaniel, English cocker spaniel, cocker 221 | n02102480 Sussex spaniel 222 | n02102973 Irish water spaniel 223 | n02104029 kuvasz 224 | n02104365 schipperke 225 | n02105056 groenendael 226 | n02105162 malinois 227 | n02105251 briard 228 | n02105412 kelpie 229 | n02105505 komondor 230 | n02105641 Old English sheepdog, bobtail 231 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 232 | n02106030 collie 233 | n02106166 Border collie 234 | n02106382 Bouvier des Flandres, Bouviers des Flandres 235 | n02106550 Rottweiler 236 | n02106662 German shepherd, German shepherd dog, German police dog, alsatian 237 | n02107142 Doberman, Doberman pinscher 238 | n02107312 miniature pinscher 239 | n02107574 Greater Swiss Mountain dog 240 | n02107683 Bernese mountain dog 241 | n02107908 Appenzeller 242 | 
n02108000 EntleBucher 243 | n02108089 boxer 244 | n02108422 bull mastiff 245 | n02108551 Tibetan mastiff 246 | n02108915 French bulldog 247 | n02109047 Great Dane 248 | n02109525 Saint Bernard, St Bernard 249 | n02109961 Eskimo dog, husky 250 | n02110063 malamute, malemute, Alaskan malamute 251 | n02110185 Siberian husky 252 | n02110341 dalmatian, coach dog, carriage dog 253 | n02110627 affenpinscher, monkey pinscher, monkey dog 254 | n02110806 basenji 255 | n02110958 pug, pug-dog 256 | n02111129 Leonberg 257 | n02111277 Newfoundland, Newfoundland dog 258 | n02111500 Great Pyrenees 259 | n02111889 Samoyed, Samoyede 260 | n02112018 Pomeranian 261 | n02112137 chow, chow chow 262 | n02112350 keeshond 263 | n02112706 Brabancon griffon 264 | n02113023 Pembroke, Pembroke Welsh corgi 265 | n02113186 Cardigan, Cardigan Welsh corgi 266 | n02113624 toy poodle 267 | n02113712 miniature poodle 268 | n02113799 standard poodle 269 | n02113978 Mexican hairless 270 | n02114367 timber wolf, grey wolf, gray wolf, Canis lupus 271 | n02114548 white wolf, Arctic wolf, Canis lupus tundrarum 272 | n02114712 red wolf, maned wolf, Canis rufus, Canis niger 273 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 274 | n02115641 dingo, warrigal, warragal, Canis dingo 275 | n02115913 dhole, Cuon alpinus 276 | n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | n02117135 hyena, hyaena 278 | n02119022 red fox, Vulpes vulpes 279 | n02119789 kit fox, Vulpes macrotis 280 | n02120079 Arctic fox, white fox, Alopex lagopus 281 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 282 | n02123045 tabby, tabby cat 283 | n02123159 tiger cat 284 | n02123394 Persian cat 285 | n02123597 Siamese cat, Siamese 286 | n02124075 Egyptian cat 287 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | n02127052 lynx, catamount 289 | n02128385 leopard, Panthera pardus 290 | n02128757 snow leopard, ounce, Panthera uncia 291 | n02128925 jaguar, 
panther, Panthera onca, Felis onca 292 | n02129165 lion, king of beasts, Panthera leo 293 | n02129604 tiger, Panthera tigris 294 | n02130308 cheetah, chetah, Acinonyx jubatus 295 | n02132136 brown bear, bruin, Ursus arctos 296 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 297 | n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | n02134418 sloth bear, Melursus ursinus, Ursus ursinus 299 | n02137549 mongoose 300 | n02138441 meerkat, mierkat 301 | n02165105 tiger beetle 302 | n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | n02167151 ground beetle, carabid beetle 304 | n02168699 long-horned beetle, longicorn, longicorn beetle 305 | n02169497 leaf beetle, chrysomelid 306 | n02172182 dung beetle 307 | n02174001 rhinoceros beetle 308 | n02177972 weevil 309 | n02190166 fly 310 | n02206856 bee 311 | n02219486 ant, emmet, pismire 312 | n02226429 grasshopper, hopper 313 | n02229544 cricket 314 | n02231487 walking stick, walkingstick, stick insect 315 | n02233338 cockroach, roach 316 | n02236044 mantis, mantid 317 | n02256656 cicada, cicala 318 | n02259212 leafhopper 319 | n02264363 lacewing, lacewing fly 320 | n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | n02268853 damselfly 322 | n02276258 admiral 323 | n02277742 ringlet, ringlet butterfly 324 | n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | n02280649 cabbage butterfly 326 | n02281406 sulphur butterfly, sulfur butterfly 327 | n02281787 lycaenid, lycaenid butterfly 328 | n02317335 starfish, sea star 329 | n02319095 sea urchin 330 | n02321529 sea cucumber, holothurian 331 | n02325366 wood rabbit, cottontail, cottontail rabbit 332 | n02326432 hare 333 | n02328150 Angora, Angora rabbit 334 | n02342885 hamster 335 | n02346627 porcupine, hedgehog 336 | n02356798 fox squirrel, eastern fox squirrel, Sciurus niger 337 | 
n02361337 marmot 338 | n02363005 beaver 339 | n02364673 guinea pig, Cavia cobaya 340 | n02389026 sorrel 341 | n02391049 zebra 342 | n02395406 hog, pig, grunter, squealer, Sus scrofa 343 | n02396427 wild boar, boar, Sus scrofa 344 | n02397096 warthog 345 | n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | n02403003 ox 347 | n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | n02410509 bison 349 | n02412080 ram, tup 350 | n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | n02417914 ibex, Capra ibex 352 | n02422106 hartebeest 353 | n02422699 impala, Aepyceros melampus 354 | n02423022 gazelle 355 | n02437312 Arabian camel, dromedary, Camelus dromedarius 356 | n02437616 llama 357 | n02441942 weasel 358 | n02442845 mink 359 | n02443114 polecat, fitch, foulmart, foumart, Mustela putorius 360 | n02443484 black-footed ferret, ferret, Mustela nigripes 361 | n02444819 otter 362 | n02445715 skunk, polecat, wood pussy 363 | n02447366 badger 364 | n02454379 armadillo 365 | n02457408 three-toed sloth, ai, Bradypus tridactylus 366 | n02480495 orangutan, orang, orangutang, Pongo pygmaeus 367 | n02480855 gorilla, Gorilla gorilla 368 | n02481823 chimpanzee, chimp, Pan troglodytes 369 | n02483362 gibbon, Hylobates lar 370 | n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | n02484975 guenon, guenon monkey 372 | n02486261 patas, hussar monkey, Erythrocebus patas 373 | n02486410 baboon 374 | n02487347 macaque 375 | n02488291 langur 376 | n02488702 colobus, colobus monkey 377 | n02489166 proboscis monkey, Nasalis larvatus 378 | n02490219 marmoset 379 | n02492035 capuchin, ringtail, Cebus capucinus 380 | n02492660 howler monkey, howler 381 | n02493509 titi, titi monkey 382 | n02493793 spider monkey, Ateles geoffroyi 383 | n02494079 squirrel monkey, Saimiri sciureus 384 | n02497673 Madagascar cat, ring-tailed lemur, Lemur catta 385 | n02500267 indri, indris, Indri 
indri, Indri brevicaudatus 386 | n02504013 Indian elephant, Elephas maximus 387 | n02504458 African elephant, Loxodonta africana 388 | n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | n02514041 barracouta, snoek 391 | n02526121 eel 392 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | n02606052 rock beauty, Holocanthus tricolor 394 | n02607072 anemone fish 395 | n02640242 sturgeon 396 | n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus 397 | n02643566 lionfish 398 | n02655020 puffer, pufferfish, blowfish, globefish 399 | n02666196 abacus 400 | n02667093 abaya 401 | n02669723 academic gown, academic robe, judge's robe 402 | n02672831 accordion, piano accordion, squeeze box 403 | n02676566 acoustic guitar 404 | n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier 405 | n02690373 airliner 406 | n02692877 airship, dirigible 407 | n02699494 altar 408 | n02701002 ambulance 409 | n02704792 amphibian, amphibious vehicle 410 | n02708093 analog clock 411 | n02727426 apiary, bee house 412 | n02730930 apron 413 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | n02749479 assault rifle, assault gun 415 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack 416 | n02776631 bakery, bakeshop, bakehouse 417 | n02777292 balance beam, beam 418 | n02782093 balloon 419 | n02783161 ballpoint, ballpoint pen, ballpen, Biro 420 | n02786058 Band Aid 421 | n02787622 banjo 422 | n02788148 bannister, banister, balustrade, balusters, handrail 423 | n02790996 barbell 424 | n02791124 barber chair 425 | n02791270 barbershop 426 | n02793495 barn 427 | n02794156 barometer 428 | n02795169 barrel, cask 429 | n02797295 barrow, garden cart, lawn cart, wheelbarrow 430 | n02799071 baseball 431 | n02802426 basketball 432 | n02804414 
bassinet 433 | n02804610 bassoon 434 | n02807133 bathing cap, swimming cap 435 | n02808304 bath towel 436 | n02808440 bathtub, bathing tub, bath, tub 437 | n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | n02814860 beacon, lighthouse, beacon light, pharos 439 | n02815834 beaker 440 | n02817516 bearskin, busby, shako 441 | n02823428 beer bottle 442 | n02823750 beer glass 443 | n02825657 bell cote, bell cot 444 | n02834397 bib 445 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 446 | n02837789 bikini, two-piece 447 | n02840245 binder, ring-binder 448 | n02841315 binoculars, field glasses, opera glasses 449 | n02843684 birdhouse 450 | n02859443 boathouse 451 | n02860847 bobsled, bobsleigh, bob 452 | n02865351 bolo tie, bolo, bola tie, bola 453 | n02869837 bonnet, poke bonnet 454 | n02870880 bookcase 455 | n02871525 bookshop, bookstore, bookstall 456 | n02877765 bottlecap 457 | n02879718 bow 458 | n02883205 bow tie, bow-tie, bowtie 459 | n02892201 brass, memorial tablet, plaque 460 | n02892767 brassiere, bra, bandeau 461 | n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | n02895154 breastplate, aegis, egis 463 | n02906734 broom 464 | n02909870 bucket, pail 465 | n02910353 buckle 466 | n02916936 bulletproof vest 467 | n02917067 bullet train, bullet 468 | n02927161 butcher shop, meat market 469 | n02930766 cab, hack, taxi, taxicab 470 | n02939185 caldron, cauldron 471 | n02948072 candle, taper, wax light 472 | n02950826 cannon 473 | n02951358 canoe 474 | n02951585 can opener, tin opener 475 | n02963159 cardigan 476 | n02965783 car mirror 477 | n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig 478 | n02966687 carpenter's kit, tool kit 479 | n02971356 carton 480 | n02974003 car wheel 481 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | n02978881 cassette 483 | n02979186 cassette 
player 484 | n02980441 castle 485 | n02981792 catamaran 486 | n02988304 CD player 487 | n02992211 cello, violoncello 488 | n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | n02999410 chain 490 | n03000134 chainlink fence 491 | n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | n03000684 chain saw, chainsaw 493 | n03014705 chest 494 | n03016953 chiffonier, commode 495 | n03017168 chime, bell, gong 496 | n03018349 china cabinet, china closet 497 | n03026506 Christmas stocking 498 | n03028079 church, church building 499 | n03032252 cinema, movie theater, movie theatre, movie house, picture palace 500 | n03041632 cleaver, meat cleaver, chopper 501 | n03042490 cliff dwelling 502 | n03045698 cloak 503 | n03047690 clog, geta, patten, sabot 504 | n03062245 cocktail shaker 505 | n03063599 coffee mug 506 | n03063689 coffeepot 507 | n03065424 coil, spiral, volute, whorl, helix 508 | n03075370 combination lock 509 | n03085013 computer keyboard, keypad 510 | n03089624 confectionery, confectionary, candy store 511 | n03095699 container ship, containership, container vessel 512 | n03100240 convertible 513 | n03109150 corkscrew, bottle screw 514 | n03110669 cornet, horn, trumpet, trump 515 | n03124043 cowboy boot 516 | n03124170 cowboy hat, ten-gallon hat 517 | n03125729 cradle 518 | n03126707 crane 519 | n03127747 crash helmet 520 | n03127925 crate 521 | n03131574 crib, cot 522 | n03133878 Crock Pot 523 | n03134739 croquet ball 524 | n03141823 crutch 525 | n03146219 cuirass 526 | n03160309 dam, dike, dyke 527 | n03179701 desk 528 | n03180011 desktop computer 529 | n03187595 dial telephone, dial phone 530 | n03188531 diaper, nappy, napkin 531 | n03196217 digital clock 532 | n03197337 digital watch 533 | n03201208 dining table, board 534 | n03207743 dishrag, dishcloth 535 | n03207941 dishwasher, dish washer, dishwashing machine 536 | n03208938 disk brake, disc brake 537 | n03216828 dock, dockage, docking 
facility 538 | n03218198 dogsled, dog sled, dog sleigh 539 | n03220513 dome 540 | n03223299 doormat, welcome mat 541 | n03240683 drilling platform, offshore rig 542 | n03249569 drum, membranophone, tympan 543 | n03250847 drumstick 544 | n03255030 dumbbell 545 | n03259280 Dutch oven 546 | n03271574 electric fan, blower 547 | n03272010 electric guitar 548 | n03272562 electric locomotive 549 | n03290653 entertainment center 550 | n03291819 envelope 551 | n03297495 espresso maker 552 | n03314780 face powder 553 | n03325584 feather boa, boa 554 | n03337140 file, file cabinet, filing cabinet 555 | n03344393 fireboat 556 | n03345487 fire engine, fire truck 557 | n03347037 fire screen, fireguard 558 | n03355925 flagpole, flagstaff 559 | n03372029 flute, transverse flute 560 | n03376595 folding chair 561 | n03379051 football helmet 562 | n03384352 forklift 563 | n03388043 fountain 564 | n03388183 fountain pen 565 | n03388549 four-poster 566 | n03393912 freight car 567 | n03394916 French horn, horn 568 | n03400231 frying pan, frypan, skillet 569 | n03404251 fur coat 570 | n03417042 garbage truck, dustcart 571 | n03424325 gasmask, respirator, gas helmet 572 | n03425413 gas pump, gasoline pump, petrol pump, island dispenser 573 | n03443371 goblet 574 | n03444034 go-kart 575 | n03445777 golf ball 576 | n03445924 golfcart, golf cart 577 | n03447447 gondola 578 | n03447721 gong, tam-tam 579 | n03450230 gown 580 | n03452741 grand piano, grand 581 | n03457902 greenhouse, nursery, glasshouse 582 | n03459775 grille, radiator grille 583 | n03461385 grocery store, grocery, food market, market 584 | n03467068 guillotine 585 | n03476684 hair slide 586 | n03476991 hair spray 587 | n03478589 half track 588 | n03481172 hammer 589 | n03482405 hamper 590 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | n03485407 hand-held computer, hand-held microcomputer 592 | n03485794 handkerchief, hankie, hanky, hankey 593 | n03492542 hard disc, hard disk, fixed disk 594 | 
n03494278 harmonica, mouth organ, harp, mouth harp 595 | n03495258 harp 596 | n03496892 harvester, reaper 597 | n03498962 hatchet 598 | n03527444 holster 599 | n03529860 home theater, home theatre 600 | n03530642 honeycomb 601 | n03532672 hook, claw 602 | n03534580 hoopskirt, crinoline 603 | n03535780 horizontal bar, high bar 604 | n03538406 horse cart, horse-cart 605 | n03544143 hourglass 606 | n03584254 iPod 607 | n03584829 iron, smoothing iron 608 | n03590841 jack-o'-lantern 609 | n03594734 jean, blue jean, denim 610 | n03594945 jeep, landrover 611 | n03595614 jersey, T-shirt, tee shirt 612 | n03598930 jigsaw puzzle 613 | n03599486 jinrikisha, ricksha, rickshaw 614 | n03602883 joystick 615 | n03617480 kimono 616 | n03623198 knee pad 617 | n03627232 knot 618 | n03630383 lab coat, laboratory coat 619 | n03633091 ladle 620 | n03637318 lampshade, lamp shade 621 | n03642806 laptop, laptop computer 622 | n03649909 lawn mower, mower 623 | n03657121 lens cap, lens cover 624 | n03658185 letter opener, paper knife, paperknife 625 | n03661043 library 626 | n03662601 lifeboat 627 | n03666591 lighter, light, igniter, ignitor 628 | n03670208 limousine, limo 629 | n03673027 liner, ocean liner 630 | n03676483 lipstick, lip rouge 631 | n03680355 Loafer 632 | n03690938 lotion 633 | n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | n03692522 loupe, jeweler's loupe 635 | n03697007 lumbermill, sawmill 636 | n03706229 magnetic compass 637 | n03709823 mailbag, postbag 638 | n03710193 mailbox, letter box 639 | n03710637 maillot 640 | n03710721 maillot, tank suit 641 | n03717622 manhole cover 642 | n03720891 maraca 643 | n03721384 marimba, xylophone 644 | n03724870 mask 645 | n03729826 matchstick 646 | n03733131 maypole 647 | n03733281 maze, labyrinth 648 | n03733805 measuring cup 649 | n03742115 medicine chest, medicine cabinet 650 | n03743016 megalith, megalithic structure 651 | n03759954 microphone, mike 652 | n03761084 microwave, microwave oven 
653 | n03763968 military uniform 654 | n03764736 milk can 655 | n03769881 minibus 656 | n03770439 miniskirt, mini 657 | n03770679 minivan 658 | n03773504 missile 659 | n03775071 mitten 660 | n03775546 mixing bowl 661 | n03776460 mobile home, manufactured home 662 | n03777568 Model T 663 | n03777754 modem 664 | n03781244 monastery 665 | n03782006 monitor 666 | n03785016 moped 667 | n03786901 mortar 668 | n03787032 mortarboard 669 | n03788195 mosque 670 | n03788365 mosquito net 671 | n03791053 motor scooter, scooter 672 | n03792782 mountain bike, all-terrain bike, off-roader 673 | n03792972 mountain tent 674 | n03793489 mouse, computer mouse 675 | n03794056 mousetrap 676 | n03796401 moving van 677 | n03803284 muzzle 678 | n03804744 nail 679 | n03814639 neck brace 680 | n03814906 necklace 681 | n03825788 nipple 682 | n03832673 notebook, notebook computer 683 | n03837869 obelisk 684 | n03838899 oboe, hautboy, hautbois 685 | n03840681 ocarina, sweet potato 686 | n03841143 odometer, hodometer, mileometer, milometer 687 | n03843555 oil filter 688 | n03854065 organ, pipe organ 689 | n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | n03866082 overskirt 691 | n03868242 oxcart 692 | n03868863 oxygen mask 693 | n03871628 packet 694 | n03873416 paddle, boat paddle 695 | n03874293 paddlewheel, paddle wheel 696 | n03874599 padlock 697 | n03876231 paintbrush 698 | n03877472 pajama, pyjama, pj's, jammies 699 | n03877845 palace 700 | n03884397 panpipe, pandean pipe, syrinx 701 | n03887697 paper towel 702 | n03888257 parachute, chute 703 | n03888605 parallel bars, bars 704 | n03891251 park bench 705 | n03891332 parking meter 706 | n03895866 passenger car, coach, carriage 707 | n03899768 patio, terrace 708 | n03902125 pay-phone, pay-station 709 | n03903868 pedestal, plinth, footstall 710 | n03908618 pencil box, pencil case 711 | n03908714 pencil sharpener 712 | n03916031 perfume, essence 713 | n03920288 Petri dish 714 | n03924679 photocopier 715 | n03929660 pick, 
plectrum, plectron 716 | n03929855 pickelhaube 717 | n03930313 picket fence, paling 718 | n03930630 pickup, pickup truck 719 | n03933933 pier 720 | n03935335 piggy bank, penny bank 721 | n03937543 pill bottle 722 | n03938244 pillow 723 | n03942813 ping-pong ball 724 | n03944341 pinwheel 725 | n03947888 pirate, pirate ship 726 | n03950228 pitcher, ewer 727 | n03954731 plane, carpenter's plane, woodworking plane 728 | n03956157 planetarium 729 | n03958227 plastic bag 730 | n03961711 plate rack 731 | n03967562 plow, plough 732 | n03970156 plunger, plumber's helper 733 | n03976467 Polaroid camera, Polaroid Land camera 734 | n03976657 pole 735 | n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | n03980874 poncho 737 | n03982430 pool table, billiard table, snooker table 738 | n03983396 pop bottle, soda bottle 739 | n03991062 pot, flowerpot 740 | n03992509 potter's wheel 741 | n03995372 power drill 742 | n03998194 prayer rug, prayer mat 743 | n04004767 printer 744 | n04005630 prison, prison house 745 | n04008634 projectile, missile 746 | n04009552 projector 747 | n04019541 puck, hockey puck 748 | n04023962 punching bag, punch bag, punching ball, punchball 749 | n04026417 purse 750 | n04033901 quill, quill pen 751 | n04033995 quilt, comforter, comfort, puff 752 | n04037443 racer, race car, racing car 753 | n04039381 racket, racquet 754 | n04040759 radiator 755 | n04041544 radio, wireless 756 | n04044716 radio telescope, radio reflector 757 | n04049303 rain barrel 758 | n04065272 recreational vehicle, RV, R.V. 
759 | n04067472 reel 760 | n04069434 reflex camera 761 | n04070727 refrigerator, icebox 762 | n04074963 remote control, remote 763 | n04081281 restaurant, eating house, eating place, eatery 764 | n04086273 revolver, six-gun, six-shooter 765 | n04090263 rifle 766 | n04099969 rocking chair, rocker 767 | n04111531 rotisserie 768 | n04116512 rubber eraser, rubber, pencil eraser 769 | n04118538 rugby ball 770 | n04118776 rule, ruler 771 | n04120489 running shoe 772 | n04125021 safe 773 | n04127249 safety pin 774 | n04131690 saltshaker, salt shaker 775 | n04133789 sandal 776 | n04136333 sarong 777 | n04141076 sax, saxophone 778 | n04141327 scabbard 779 | n04141975 scale, weighing machine 780 | n04146614 school bus 781 | n04147183 schooner 782 | n04149813 scoreboard 783 | n04152593 screen, CRT screen 784 | n04153751 screw 785 | n04154565 screwdriver 786 | n04162706 seat belt, seatbelt 787 | n04179913 sewing machine 788 | n04192698 shield, buckler 789 | n04200800 shoe shop, shoe-shop, shoe store 790 | n04201297 shoji 791 | n04204238 shopping basket 792 | n04204347 shopping cart 793 | n04208210 shovel 794 | n04209133 shower cap 795 | n04209239 shower curtain 796 | n04228054 ski 797 | n04229816 ski mask 798 | n04235860 sleeping bag 799 | n04238763 slide rule, slipstick 800 | n04239074 sliding door 801 | n04243546 slot, one-armed bandit 802 | n04251144 snorkel 803 | n04252077 snowmobile 804 | n04252225 snowplow, snowplough 805 | n04254120 soap dispenser 806 | n04254680 soccer ball 807 | n04254777 sock 808 | n04258138 solar dish, solar collector, solar furnace 809 | n04259630 sombrero 810 | n04263257 soup bowl 811 | n04264628 space bar 812 | n04265275 space heater 813 | n04266014 space shuttle 814 | n04270147 spatula 815 | n04273569 speedboat 816 | n04275548 spider web, spider's web 817 | n04277352 spindle 818 | n04285008 sports car, sport car 819 | n04286575 spotlight, spot 820 | n04296562 stage 821 | n04310018 steam locomotive 822 | n04311004 steel arch bridge 823 | 
n04311174 steel drum 824 | n04317175 stethoscope 825 | n04325704 stole 826 | n04326547 stone wall 827 | n04328186 stopwatch, stop watch 828 | n04330267 stove 829 | n04332243 strainer 830 | n04335435 streetcar, tram, tramcar, trolley, trolley car 831 | n04336792 stretcher 832 | n04344873 studio couch, day bed 833 | n04346328 stupa, tope 834 | n04347754 submarine, pigboat, sub, U-boat 835 | n04350905 suit, suit of clothes 836 | n04355338 sundial 837 | n04355933 sunglass 838 | n04356056 sunglasses, dark glasses, shades 839 | n04357314 sunscreen, sunblock, sun blocker 840 | n04366367 suspension bridge 841 | n04367480 swab, swob, mop 842 | n04370456 sweatshirt 843 | n04371430 swimming trunks, bathing trunks 844 | n04371774 swing 845 | n04372370 switch, electric switch, electrical switch 846 | n04376876 syringe 847 | n04380533 table lamp 848 | n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle 849 | n04392985 tape player 850 | n04398044 teapot 851 | n04399382 teddy, teddy bear 852 | n04404412 television, television system 853 | n04409515 tennis ball 854 | n04417672 thatch, thatched roof 855 | n04418357 theater curtain, theatre curtain 856 | n04423845 thimble 857 | n04428191 thresher, thrasher, threshing machine 858 | n04429376 throne 859 | n04435653 tile roof 860 | n04442312 toaster 861 | n04443257 tobacco shop, tobacconist shop, tobacconist 862 | n04447861 toilet seat 863 | n04456115 torch 864 | n04458633 totem pole 865 | n04461696 tow truck, tow car, wrecker 866 | n04462240 toyshop 867 | n04465501 tractor 868 | n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | n04476259 tray 870 | n04479046 trench coat 871 | n04482393 tricycle, trike, velocipede 872 | n04483307 trimaran 873 | n04485082 tripod 874 | n04486054 triumphal arch 875 | n04487081 trolleybus, trolley coach, trackless trolley 876 | n04487394 trombone 877 | n04493381 tub, vat 878 | n04501370 turnstile 879 | n04505470 typewriter keyboard 880 | 
n04507155 umbrella 881 | n04509417 unicycle, monocycle 882 | n04515003 upright, upright piano 883 | n04517823 vacuum, vacuum cleaner 884 | n04522168 vase 885 | n04523525 vault 886 | n04525038 velvet 887 | n04525305 vending machine 888 | n04532106 vestment 889 | n04532670 viaduct 890 | n04536866 violin, fiddle 891 | n04540053 volleyball 892 | n04542943 waffle iron 893 | n04548280 wall clock 894 | n04548362 wallet, billfold, notecase, pocketbook 895 | n04550184 wardrobe, closet, press 896 | n04552348 warplane, military plane 897 | n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | n04554684 washer, automatic washer, washing machine 899 | n04557648 water bottle 900 | n04560804 water jug 901 | n04562935 water tower 902 | n04579145 whiskey jug 903 | n04579432 whistle 904 | n04584207 wig 905 | n04589890 window screen 906 | n04590129 window shade 907 | n04591157 Windsor tie 908 | n04591713 wine bottle 909 | n04592741 wing 910 | n04596742 wok 911 | n04597913 wooden spoon 912 | n04599235 wool, woolen, woollen 913 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence 914 | n04606251 wreck 915 | n04612504 yawl 916 | n04613696 yurt 917 | n06359193 web site, website, internet site, site 918 | n06596364 comic book 919 | n06785654 crossword puzzle, crossword 920 | n06794110 street sign 921 | n06874185 traffic light, traffic signal, stoplight 922 | n07248320 book jacket, dust cover, dust jacket, dust wrapper 923 | n07565083 menu 924 | n07579787 plate 925 | n07583066 guacamole 926 | n07584110 consomme 927 | n07590611 hot pot, hotpot 928 | n07613480 trifle 929 | n07614500 ice cream, icecream 930 | n07615774 ice lolly, lolly, lollipop, popsicle 931 | n07684084 French loaf 932 | n07693725 bagel, beigel 933 | n07695742 pretzel 934 | n07697313 cheeseburger 935 | n07697537 hotdog, hot dog, red hot 936 | n07711569 mashed potato 937 | n07714571 head cabbage 938 | n07714990 broccoli 939 | n07715103 cauliflower 940 | n07716358 zucchini, courgette 941 | 
n07716906 spaghetti squash 942 | n07717410 acorn squash 943 | n07717556 butternut squash 944 | n07718472 cucumber, cuke 945 | n07718747 artichoke, globe artichoke 946 | n07720875 bell pepper 947 | n07730033 cardoon 948 | n07734744 mushroom 949 | n07742313 Granny Smith 950 | n07745940 strawberry 951 | n07747607 orange 952 | n07749582 lemon 953 | n07753113 fig 954 | n07753275 pineapple, ananas 955 | n07753592 banana 956 | n07754684 jackfruit, jak, jack 957 | n07760859 custard apple 958 | n07768694 pomegranate 959 | n07802026 hay 960 | n07831146 carbonara 961 | n07836838 chocolate sauce, chocolate syrup 962 | n07860988 dough 963 | n07871810 meat loaf, meatloaf 964 | n07873807 pizza, pizza pie 965 | n07875152 potpie 966 | n07880968 burrito 967 | n07892512 red wine 968 | n07920052 espresso 969 | n07930864 cup 970 | n07932039 eggnog 971 | n09193705 alp 972 | n09229709 bubble 973 | n09246464 cliff, drop, drop-off 974 | n09256479 coral reef 975 | n09288635 geyser 976 | n09332890 lakeside, lakeshore 977 | n09399592 promontory, headland, head, foreland 978 | n09421951 sandbar, sand bar 979 | n09428293 seashore, coast, seacoast, sea-coast 980 | n09468604 valley, vale 981 | n09472597 volcano 982 | n09835506 ballplayer, baseball player 983 | n10148035 groom, bridegroom 984 | n10565667 scuba diver 985 | n11879895 rapeseed 986 | n11939491 daisy 987 | n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | n12144580 corn 989 | n12267677 acorn 990 | n12620546 hip, rose hip, rosehip 991 | n12768682 buckeye, horse chestnut, conker 992 | n12985857 coral fungus 993 | n12998815 agaric 994 | n13037406 gyromitra 995 | n13040303 stinkhorn, carrion fungus 996 | n13044778 earthstar 997 | n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | n13054560 bolete 999 | n13133613 ear, spike, capitulum 1000 | n15075141 toilet tissue, toilet paper, bathroom tissue 1001 | 
-------------------------------------------------------------------------------- /doc/develop_releated.md: -------------------------------------------------------------------------------- 1 | # 分类网络相关收集 2 | 3 | --- 4 | ## 开发相关 5 | 6 | ### 微调 7 | 8 | 增加常用网络的微调,直接将pytorch vision中的模型转换到新的数据集模型中。 9 | 10 | [Almost any Image Classification Problem using PyTorch](https://medium.com/@14prakash/almost-any-image-classification-problem-using-pytorch-i-am-in-love-with-pytorch-26c7aa979ec4) 11 | 12 | [pytorch_classifiers](https://github.com/Prakashvanapalli/pytorch_classifiers) 13 | 14 | [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) 15 | 16 | 17 | --- 18 | ## 论文资料 19 | - CondenseNet: An Efficient DenseNet using Learned Group Convolutions [论文](https://arxiv.org/abs/1711.09224) [代码](https://github.com/ShichenLiu/CondenseNet) 20 | 一种高效的分类网络,提升前向推断速度。 21 | - Squeeze-and-Excitation Networks [论文](https://arxiv.org/abs/1709.01507) 22 | ImageNet2017的冠军模型 [代码](https://github.com/moskomule/senet.pytorch) 23 | - imagenet类别,[imagenet1000_clsid_to_human.txt](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) 24 | - ... -------------------------------------------------------------------------------- /doc/mobilenet_implement.md: -------------------------------------------------------------------------------- 1 | # MobileNet网络实现 2 | 3 | 参考如下: 4 | - [MobileNet-Caffe](https://github.com/shicai/MobileNet-Caffe) 5 | - [pytorch-mobilenet](https://github.com/marvis/pytorch-mobilenet) 6 | - [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) 7 | - [MobileNetV2阅读笔记](https://zhuanlan.zhihu.com/p/33052910) 8 | - ... 
-------------------------------------------------------------------------------- /doc/pytorch_net_visual.md: -------------------------------------------------------------------------------- 1 | # 网络结构可视化 2 | 3 | 将pytorch模型转换为caffe,然后利用netscope可视化,参考如下: 4 | - [nn_tools](https://github.com/hahnyuan/nn_tools) 5 | - [netscope](http://ethereon.github.io/netscope/#/editor) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | tensorflow 3 | matplotlib==1.3.1 4 | numpy==1.14.3 5 | scikit_image==0.13.1 6 | torch==0.4.0 7 | torchvision==0.2.1 8 | scipy==1.1.0 9 | skimage==0.0 10 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /test/context.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | 4 | # 增加上下文package 5 | import os 6 | import sys 7 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | import cifarclassify, tfcifarclassify 10 | -------------------------------------------------------------------------------- /test/test_tf_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # [Unit tests](https://www.tensorflow.org/api_guides/python/test) 3 | 4 | import tensorflow as tf 5 | import os 6 | import time 7 | 8 | from context import tfcifarclassify 9 | from tfcifarclassify.modelloader.imagenet import resnet 10 | from tfcifarclassify.dataloader import cifar_input 11 | 12 | FLAGS = tf.app.flags.FLAGS 13 | tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.') 14 | 
tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
tf.app.flags.DEFINE_string('train_data_path', os.path.expanduser('~/Data/cifar-10-batches-bin/data_batch_1.bin'), 'Filepattern for training data.')
tf.app.flags.DEFINE_string('eval_data_path', os.path.expanduser('~/Data/cifar-10-batches-bin/data_batch_1.bin'), 'Filepattern for eval data')
tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
tf.app.flags.DEFINE_string('train_dir', '', 'Directory to keep training outputs.')
tf.app.flags.DEFINE_string('eval_dir', '', 'Directory to keep eval outputs.')
tf.app.flags.DEFINE_integer('eval_batch_count', 50, 'Number of batches to eval.')
tf.app.flags.DEFINE_bool('eval_once', False, 'Whether evaluate the model only once.')
tf.app.flags.DEFINE_string('log_root', '/tmp/{}'.format(time.time()), 'Directory to keep the checkpoints. Should be a parent directory of FLAGS.train_dir/eval_dir.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of gpus used for training. (0 or 1)')


def main(_):
    """Build the CIFAR input pipeline and a ResNet, then run its train op.

    The device and dataset are taken from the ``--num_gpus`` and
    ``--dataset`` flags; with the default flag values this is the CPU and
    CIFAR-10. The loop runs ``model.train_op`` until the monitored session
    reports that it should stop.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.

    Raises:
        ValueError: if ``--num_gpus`` is not 0 or 1, or ``--dataset`` is not
            'cifar10' or 'cifar100'.
    """
    # Honour --num_gpus instead of always pinning the graph to the CPU.
    if FLAGS.num_gpus == 0:
        dev = '/cpu:0'
    elif FLAGS.num_gpus == 1:
        dev = '/gpu:0'
    else:
        raise ValueError('Only support 0 or 1 gpu.')

    # Honour --dataset; the class count has to match the one-hot labels that
    # cifar_input.build_input produces for the same dataset name.
    if FLAGS.dataset == 'cifar10':
        num_classes = 10
    elif FLAGS.dataset == 'cifar100':
        num_classes = 100
    else:
        raise ValueError('Not supported dataset %s' % FLAGS.dataset)

    batch_size = 1
    hps = resnet.HParams(batch_size=batch_size,
                         num_classes=num_classes,
                         min_lrn_rate=0.0001,
                         lrn_rate=0.1,
                         num_residual_units=5,
                         use_bottleneck=False,
                         weight_decay_rate=0.0002,
                         relu_leakiness=0.1,
                         optimizer='mom')

    with tf.device(dev):
        images, labels = cifar_input.build_input(
            FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
        # NOTE: ResNet.build_graph only creates `train_op` when mode is
        # 'train', so the loop below requires --mode=train.
        model = resnet.ResNet(hps, images, labels, FLAGS.mode)
        model.build_graph()

        # Fraction of the batch whose arg-max prediction equals the label.
        truth = tf.argmax(model.labels, axis=1)
        predictions = tf.argmax(model.predictions, axis=1)
        precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))

        logging_hook = tf.train.LoggingTensorHook(
            tensors={'step': model.global_step,
                     'loss': model.cost,
                     'precision': precision},
            every_n_iter=100)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.log_root,
                hooks=[logging_hook],
                save_summaries_steps=0,
                config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(model.train_op)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
def build_input(dataset, data_path, batch_size, mode):
    """Build CIFAR image and labels.

    Reads fixed-length binary records, decodes one label byte and a
    depth-major 32x32x3 image from each, applies per-image standardization
    (plus pad / random-crop / random-flip when training) and batches the
    examples through a TensorFlow queue.

    Args:
        dataset: Either 'cifar10' or 'cifar100'.
        data_path: Filename (glob pattern) for data.
        batch_size: Input batch size.
        mode: 'train' enables augmentation and a shuffling queue; any other
            value uses the un-augmented FIFO path.

    Returns:
        images: Batches of images. [batch_size, image_size, image_size, 3]
        labels: Batches of one-hot labels. [batch_size, num_classes]

    Raises:
        ValueError: when the specified dataset is not supported, or when
            ``data_path`` matches no files.
    """
    image_size = 32
    depth = 3

    # Validate the dataset name before touching TensorFlow so a typo fails
    # fast with a readable message.
    if dataset == 'cifar10':
        label_bytes = 1
        label_offset = 0
        num_classes = 10
    elif dataset == 'cifar100':
        label_bytes = 1
        # The label is read from the second byte of each record
        # (presumably skipping the coarse label -- confirm against the
        # CIFAR-100 binary layout).
        label_offset = 1
        num_classes = 100
    else:
        # The original passed `dataset` as a second exception argument, so
        # the '%s' placeholder was never filled in.
        raise ValueError('Not supported dataset %s' % dataset)

    image_bytes = image_size * image_size * depth
    record_bytes = label_bytes + label_offset + image_bytes

    data_files = tf.gfile.Glob(data_path)
    if not data_files:
        # Report the offending pattern instead of leaving it to the generic
        # error of string_input_producer.
        raise ValueError('No data files found for pattern %s' % data_path)
    file_queue = tf.train.string_input_producer(data_files, shuffle=True)

    # Read examples from files in the filename queue.
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    _, value = reader.read(file_queue)

    # Convert these examples to dense labels and processed images.
    record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes])
    label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32)
    # Convert from string to [depth * height * width] to [depth, height, width].
    depth_major = tf.reshape(
        tf.slice(record, [label_offset + label_bytes], [image_bytes]),
        [depth, image_size, image_size])
    # Convert from [depth, height, width] to [height, width, depth].
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    example_shapes = [[image_size, image_size, depth], [1]]
    example_dtypes = [tf.float32, tf.int32]

    if mode == 'train':
        # Pad by 4 pixels in total per dimension, then take a random
        # 32x32 crop and a random horizontal flip.
        image = tf.image.resize_image_with_crop_or_pad(
            image, image_size + 4, image_size + 4)
        image = tf.random_crop(image, [image_size, image_size, depth])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.per_image_standardization(image)

        example_queue = tf.RandomShuffleQueue(
            capacity=16 * batch_size,
            min_after_dequeue=8 * batch_size,
            dtypes=example_dtypes,
            shapes=example_shapes)
        num_threads = 16
    else:
        image = tf.image.resize_image_with_crop_or_pad(
            image, image_size, image_size)
        image = tf.image.per_image_standardization(image)

        example_queue = tf.FIFOQueue(
            3 * batch_size,
            dtypes=example_dtypes,
            shapes=example_shapes)
        num_threads = 1

    example_enqueue_op = example_queue.enqueue([image, label])
    tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
        example_queue, [example_enqueue_op] * num_threads))

    # Read 'batch' labels + images from the example queue.
    images, labels = example_queue.dequeue_many(batch_size)
    labels = tf.reshape(labels, [batch_size, 1])
    indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
    # Scatter 1.0 at (row, label) to turn the integer labels into one-hot rows.
    labels = tf.sparse_to_dense(
        tf.concat(values=[indices, labels], axis=1),
        [batch_size, num_classes], 1.0, 0.0)

    # Static-shape sanity checks on the batched tensors.
    assert len(images.get_shape()) == 4
    assert images.get_shape()[0] == batch_size
    assert images.get_shape()[-1] == 3
    assert len(labels.get_shape()) == 2
    assert labels.get_shape()[0] == batch_size
    assert labels.get_shape()[1] == num_classes

    # Display the training images in the visualizer.
    tf.summary.image('images', images)
    return images, labels
# Hyper Parameters
HParams = namedtuple('HParams', 'batch_size, num_classes, min_lrn_rate, lrn_rate, num_residual_units, use_bottleneck, weight_decay_rate, relu_leakiness, optimizer')


class ResNet(object):
    """Pre-activation style ResNet graph builder for TensorFlow 1.x.

    Construct with hyper-parameters and input tensors, then call
    ``build_graph()``. Afterwards the instance exposes ``global_step``,
    ``predictions``, ``cost``, ``summaries`` and, in 'train' mode,
    ``lrn_rate`` and ``train_op``.
    """

    def __init__(self, hps, images, labels, mode):
        """ResNet constructor.

        Args:
            hps: Hyperparameters (an ``HParams`` tuple).
            images: Batches of images. [batch_size, image_size, image_size, 3]
            labels: Batches of labels. [batch_size, num_classes]
            mode: One of 'train' and 'eval'.
        """
        self.hps = hps
        self._images = images
        self.labels = labels
        self.mode = mode

        # Moving-average update ops collected by _batch_norm; they are
        # grouped with the optimizer step in _build_train_op.
        self._extra_train_ops = []

    def build_graph(self):
        """Build a whole graph for the model."""
        self.global_step = tf.train.get_or_create_global_step()
        self._build_model()
        if self.mode == 'train':
            self._build_train_op()
        self.summaries = tf.summary.merge_all()

    def _stride_arr(self, stride):
        """Map a stride scalar to the stride array for tf.nn.conv2d."""
        return [1, stride, stride, 1]

    def _build_model(self):
        """Build the core model within the graph.

        Sets ``self.predictions`` (softmax over the logits) and
        ``self.cost`` (mean cross entropy plus L2 weight decay).
        """
        with tf.variable_scope('init'):
            x = self._images
            x = self._conv('init_conv', x, filter_size=3, in_filters=3,
                           out_filters=16, strides=self._stride_arr(1))

        strides = [1, 2, 2]
        # Per stage, for its first unit: True takes the shortcut after the
        # shared BN+ReLU, False takes it from the raw unit input.
        activate_before_residual = [True, False, False]

        # check whether use bottleneck
        if self.hps.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            filters = [16, 16, 32, 64]

        # Three stages. The first unit of each stage ('unit_<s>_0') changes
        # the filter count / stride; the remaining units keep the shape.
        for stage in six.moves.range(3):
            stage_id = stage + 1
            with tf.variable_scope('unit_%d_0' % stage_id):
                x = res_func(x, filters[stage], filters[stage + 1],
                             self._stride_arr(strides[stage]),
                             activate_before_residual[stage])
            for i in six.moves.range(1, self.hps.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (stage_id, i)):
                    x = res_func(x, filters[stage + 1], filters[stage + 1],
                                 self._stride_arr(1), False)

        with tf.variable_scope('unit_last'):
            x = self._batch_norm('final_bn', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)

        with tf.variable_scope('logit'):
            logits = self._fully_connected(x, self.hps.num_classes)
            self.predictions = tf.nn.softmax(logits)

        with tf.variable_scope('costs'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=self.labels)
            self.cost = tf.reduce_mean(xent, name='xent')
            self.cost += self._decay()

            tf.summary.scalar('cost', self.cost)

    def _build_train_op(self):
        """Build training specific ops for the graph.

        Sets ``self.lrn_rate`` and ``self.train_op``.

        Raises:
            ValueError: if ``hps.optimizer`` is neither 'sgd' nor 'mom'.
        """
        # Validate first: previously an unknown name fell through both
        # branches and crashed later on the unbound `optimizer` local.
        optimizer_name = self.hps.optimizer
        if optimizer_name not in ('sgd', 'mom'):
            raise ValueError(
                "Unsupported optimizer %r, expected 'sgd' or 'mom'"
                % (optimizer_name,))

        self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
        tf.summary.scalar('learning_rate', self.lrn_rate)

        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(self.cost, trainable_variables)

        if optimizer_name == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
        else:  # 'mom'
            optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=self.global_step, name='train_step')

        # Run the batch-norm moving-average updates together with the step.
        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)

    # TODO(xpan): Consider batch_norm in contrib/layers/python/layers/layers.py
    def _batch_norm(self, name, x):
        """Batch normalization.

        In 'train' mode the batch moments over axes [0, 1, 2] are used and
        moving-average update ops are appended to ``_extra_train_ops``;
        otherwise the stored moving mean / variance are used.
        """
        with tf.variable_scope(name):
            params_shape = [x.get_shape()[-1]]

            beta = tf.get_variable(
                'beta', params_shape, tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32))
            gamma = tf.get_variable(
                'gamma', params_shape, tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32))

            if self.mode == 'train':
                mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

                moving_mean = tf.get_variable(
                    'moving_mean', params_shape, tf.float32,
                    initializer=tf.constant_initializer(0.0, tf.float32),
                    trainable=False)
                moving_variance = tf.get_variable(
                    'moving_variance', params_shape, tf.float32,
                    initializer=tf.constant_initializer(1.0, tf.float32),
                    trainable=False)

                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_mean, mean, 0.9))
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_variance, variance, 0.9))
            else:
                mean = tf.get_variable(
                    'moving_mean', params_shape, tf.float32,
                    initializer=tf.constant_initializer(0.0, tf.float32),
                    trainable=False)
                variance = tf.get_variable(
                    'moving_variance', params_shape, tf.float32,
                    initializer=tf.constant_initializer(1.0, tf.float32),
                    trainable=False)
                tf.summary.histogram(mean.op.name, mean)
                tf.summary.histogram(variance.op.name, variance)
            # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
            y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 0.001)
            y.set_shape(x.get_shape())
            return y

    def _residual(self, x, in_filter, out_filter, stride, activate_before_residual=False):
        """Residual unit with 2 sub layers.

        Args:
            x: Input tensor.
            in_filter: Number of input channels.
            out_filter: Number of output channels.
            stride: Stride array for the first convolution (and for the
                shortcut average pool when the channel count changes).
            activate_before_residual: If True the shortcut branches off
                after the initial BN+ReLU, otherwise before it.

        Returns:
            The unit output: the two-conv branch plus the shortcut.
        """
        if activate_before_residual:
            with tf.variable_scope('shared_activation'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_only_activation'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            # when in_filter is not equal to out_filter, average-pool the
            # shortcut and zero-pad its channel axis on both sides
            if in_filter != out_filter:
                channel_pad = (out_filter - in_filter) // 2
                orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
                orig_x = tf.pad(
                    orig_x,
                    [[0, 0], [0, 0], [0, 0], [channel_pad, channel_pad]])
            x += orig_x

        tf.logging.debug('image after unit %s', x.get_shape())
        return x

    def _bottleneck_residual(self, x, in_filter, out_filter, stride, activate_before_residual=False):
        """Bottleneck residual unit with 3 sub layers.

        Uses 1x1 -> 3x3 -> 1x1 convolutions, with the two inner layers at a
        quarter of ``out_filter`` channels. When the channel count changes
        the shortcut is projected with a 1x1 convolution.

        Args:
            x: Input tensor.
            in_filter: Number of input channels.
            out_filter: Number of output channels.
            stride: Stride array for the first convolution and the
                projection shortcut.
            activate_before_residual: If True the shortcut branches off
                after the initial BN+ReLU, otherwise before it.

        Returns:
            The unit output: the three-conv branch plus the shortcut.
        """
        # Floor division keeps the channel count an int; with '/' it becomes
        # a float under Python 3 and is then used as a variable shape dim.
        mid_filter = out_filter // 4

        if activate_before_residual:
            with tf.variable_scope('common_bn_relu'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_bn_relu'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 1, in_filter, mid_filter, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, mid_filter, mid_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub3'):
            x = self._batch_norm('bn3', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv3', x, 1, mid_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                orig_x = self._conv('project', orig_x, 1, in_filter, out_filter, stride)
            x += orig_x

        tf.logging.info('image after unit %s', x.get_shape())
        return x

    def _decay(self):
        """L2 weight decay loss.

        Sums ``tf.nn.l2_loss`` over every trainable variable whose op name
        contains 'DW' (at an index greater than 0) and scales the sum by
        ``hps.weight_decay_rate``.
        """
        costs = []
        for var in tf.trainable_variables():
            # 'DW' is the variable name given to conv kernels and the final
            # fully connected weights in this class.
            if var.op.name.find(r'DW') > 0:
                costs.append(tf.nn.l2_loss(var))

        return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """Convolution.

        Creates a 'DW' kernel of shape
        [filter_size, filter_size, in_filters, out_filters] initialised from
        a normal distribution with stddev sqrt(2 / (filter_size^2 *
        out_filters)) and applies it with 'SAME' padding.
        """
        with tf.variable_scope(name):
            n = filter_size * filter_size * out_filters
            kernel = tf.get_variable(
                'DW', [filter_size, filter_size, in_filters, out_filters],
                tf.float32,
                initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / n)))
            return tf.nn.conv2d(x, kernel, strides, padding='SAME')

    def _relu(self, x, leakiness=0.0):
        """Relu, with optional leaky support."""
        # if x<0 then leakiness * x otherwise x
        return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

    def _fully_connected(self, x, out_dim):
        """FullyConnected layer for final output.

        Flattens ``x`` to [hps.batch_size, -1] and returns ``x * w + b``.
        """
        x = tf.reshape(x, [self.hps.batch_size, -1])
        w = tf.get_variable(
            'DW', [x.get_shape()[1], out_dim],
            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
        b = tf.get_variable('biases', [out_dim], initializer=tf.constant_initializer())
        return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(self, x):
        """Average a rank-4 tensor over axes 1 and 2."""
        assert x.get_shape().ndims == 4
        return tf.reduce_mean(x, [1, 2])
def _resume_epoch(checkpoint_path):
    """Return the epoch encoded in a ``<name>_<epoch>.<ext>`` checkpoint file name."""
    # Only look at the file name so '_' or '.' in a directory cannot confuse the parse.
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    try:
        return int(stem.rsplit('_', 1)[1])
    except (IndexError, ValueError):
        raise ValueError('cannot read the epoch from checkpoint name {!r}; expected <name>_<epoch>.<ext>'.format(checkpoint_path))


def _plot_point(vis, win, x, y, xlabel, ylabel):
    """Append the point(s) ``(x, y)`` to visdom line window ``win``."""
    xs = np.ones(1) * x
    win_res = vis.line(X=xs, Y=y, win=win, update='append', name=win)
    if win_res != win:
        # The append was not acknowledged with the window id (presumably the
        # window does not exist yet), so draw it again with title and labels.
        vis.line(X=xs, Y=y, win=win, name=win, opts=dict(title=win, xlabel=xlabel, ylabel=ylabel))


def _accuracy(model, loader, use_cuda):
    """Return the fraction of samples from ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for imgs, labels in loader:
            if use_cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            preds = model(imgs).max(1)[1]
            correct += int((preds == labels).sum().item())
            total += labels.size(0)
    return float(correct) / total if total else 0.0


def train(args):
    """
    Train a CIFAR-10 classifier and evaluate it on the test split.

    Args:
        args: namespace providing ``structure``, ``vis``, ``batch_size``,
            ``resume_model_state_dict``, ``cuda``, ``lr``, ``val_interval``,
            ``save_model`` and ``save_epoch``.

    Raises:
        SystemExit: ``args.structure`` is not 'AlexNet' or 'ResNet18'.
        ValueError: the resume checkpoint name does not end in ``_<epoch>``.
    """
    # Fail fast (non-zero exit status) before downloading data or opening visdom.
    if args.structure not in ('AlexNet', 'ResNet18'):
        raise SystemExit('not valid model name: {!r}'.format(args.structure))

    vis = None
    if args.vis:
        vis = visdom.Visdom()
        vis.close()

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    data_root = os.path.expanduser('~/Data')
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)

    valset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
    # Evaluation results do not depend on the batch size in eval mode, so
    # reuse the training batch size instead of one image per forward pass.
    valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=False)

    if args.structure == 'AlexNet':
        model = AlexNet(n_classes=10)
    else:
        model = ResNet18()

    start_epoch = 0
    if args.resume_model_state_dict != '':
        start_epoch = _resume_epoch(args.resume_model_state_dict)
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
        # model.cuda() below moves the weights when requested.
        model.load_state_dict(torch.load(args.resume_model_state_dict, map_location='cpu'))

    if args.cuda:
        model.cuda()

    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params, lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # NOTE: the schedule counts epochs run in this process; it is not
    # fast-forwarded to start_epoch when resuming.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250, 350], gamma=0.1)

    iters_per_epoch = len(trainloader)

    for epoch in range(start_epoch, 20000, 1):
        print('epoch:', epoch)
        model.train()

        if vis is not None:
            # Plot the learning rate the optimizer will really use this epoch.
            current_lr = np.array([group['lr'] for group in optimizer.param_groups])
            _plot_point(vis, 'lr step', epoch, current_lr, 'epoch', 'lr')

        for i, (imgs, labels) in enumerate(trainloader):
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            if vis is not None:
                # Global iteration index; len(trainloader) counts the last partial batch.
                _plot_point(vis, 'loss iterations', epoch * iters_per_epoch + i, np.array([loss.item()]), 'iteration', 'loss')

            loss.backward()
            optimizer.step()

        # Step the schedule after the epoch's optimizer updates, as
        # PyTorch >= 1.1 requires; stepping first skips the first schedule value.
        scheduler.step()

        if args.val_interval > 0 and epoch % args.val_interval == 0:
            val_acc = _accuracy(model, valloader, args.cuda)
            print('val_acc: {:.4f}'.format(val_acc))
            if vis is not None:
                # x is the epoch itself; validation already runs only on
                # multiples of val_interval.
                _plot_point(vis, 'acc epoch', epoch, np.array([val_acc]), 'epoch', 'acc')

        # Honour --save_model/--save_epoch; the name matches what _resume_epoch parses.
        if args.save_model and args.save_epoch > 0 and epoch % args.save_epoch == 0 and epoch != 0:
            torch.save(model.state_dict(), '{}_cifar10_{}.pt'.format(args.structure, epoch))
def _str2bool(value):
    """
    Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string as True, so ``--cuda False``
    used to enable CUDA. The empty string stays False as before.

    Raises:
        argparse.ArgumentTypeError: the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """
    Parse the CIFAR training options.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Boolean options accept
        ``--flag``, ``--flag True`` and ``--flag False``.
    """
    parser = argparse.ArgumentParser(description='training parameter setting')
    parser.add_argument('--structure', type=str, default='ResNet18', help='net structure to train [ AlexNet ResNet18 ]')
    parser.add_argument('--resume_model', type=str, default='', help='resume model path [ AlexNet_cifar10_0.pkl ]')
    parser.add_argument('--resume_model_state_dict', type=str, default='', help='resume model state dict path [ AlexNet_cifar10_0.pt ]')
    parser.add_argument('--save_model', type=_str2bool, nargs='?', const=True, default=False, help='save model [ False ]')
    parser.add_argument('--save_epoch', type=int, default=1, help='save model after epoch [ 1 ]')
    parser.add_argument('--init_vgg16', type=_str2bool, nargs='?', const=True, default=False, help='init model using vgg16 weights [ False ]')
    parser.add_argument('--dataset_path', type=str, default='', help='train dataset path [ /home/cgf/Data/CamVid ]')
    parser.add_argument('--data_augment', type=_str2bool, nargs='?', const=True, default=False, help='enlarge the training data [ False ]')
    parser.add_argument('--batch_size', type=int, default=128, help='train dataset batch size [ 128 ]')
    parser.add_argument('--val_interval', type=int, default=1, help='val dataset interval unit epoch [ 1 ]')
    parser.add_argument('--lr', type=float, default=1e-1, help='train learning rate [ 0.1 ]')
    parser.add_argument('--vis', type=_str2bool, nargs='?', const=True, default=False, help='visualize the training results [ False ]')
    parser.add_argument('--cuda', type=_str2bool, nargs='?', const=True, default=False, help='use cuda [ False ]')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print(args)
    train(args)
def _resume_epoch(checkpoint_path):
    """Return the epoch encoded in a ``<name>_<epoch>.<ext>`` checkpoint file name."""
    # Only look at the file name so '_' or '.' in a directory cannot confuse the parse.
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    try:
        return int(stem.rsplit('_', 1)[1])
    except (IndexError, ValueError):
        raise ValueError('cannot read the epoch from checkpoint name {!r}; expected <name>_<epoch>.<ext>'.format(checkpoint_path))


def _plot_point(vis, win, x, y, xlabel, ylabel):
    """Append the point(s) ``(x, y)`` to visdom line window ``win``."""
    xs = np.ones(1) * x
    win_res = vis.line(X=xs, Y=y, win=win, update='append', name=win)
    if win_res != win:
        # The append was not acknowledged with the window id (presumably the
        # window does not exist yet), so draw it again with title and labels.
        vis.line(X=xs, Y=y, win=win, name=win, opts=dict(title=win, xlabel=xlabel, ylabel=ylabel))


def _accuracy(model, loader, use_cuda):
    """Return the fraction of samples from ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for imgs, labels in loader:
            if use_cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()
            preds = model(imgs).max(1)[1]
            correct += int((preds == labels).sum().item())
            total += labels.size(0)
    return float(correct) / total if total else 0.0


def train(args):
    """
    Train an ImageNet-style classifier on the Bearing or Caltech101 dataset.

    Args:
        args: namespace providing ``dataset``, ``dataset_path``,
            ``structure``, ``init_vgg16``, ``vis``, ``batch_size``,
            ``resume_model_state_dict``, ``cuda``, ``lr``, ``val_interval``,
            ``save_model`` and ``save_epoch``.

    Raises:
        SystemExit: ``args.dataset`` or ``args.structure`` is not supported.
        ValueError: the resume checkpoint name does not end in ``_<epoch>``.
    """
    # Fail fast (non-zero exit status) before touching data or visdom.
    if args.dataset not in ('Bearing', 'Caltech101'):
        raise SystemExit('{} dataset does not implement'.format(args.dataset))
    if args.structure not in ('AlexNet', 'resnet18', 'resnet50', 'GoogLeNet', 'PeleeNet'):
        raise SystemExit('not valid model name: {!r}'.format(args.structure))

    vis = None
    if args.vis:
        vis = visdom.Visdom()
        vis.close()

    local_path = os.path.expanduser(args.dataset_path)
    if args.dataset == 'Bearing':
        trainset = BearingLoader(local_path, is_transform=True, is_augment=False, split='train')
        valset = BearingLoader(local_path, is_transform=True, is_augment=False, split='val')
    else:
        trainset = Caltech101Loader(local_path, is_transform=True, is_augment=False, split='train')
        valset = Caltech101Loader(local_path, is_transform=True, is_augment=False, split='val')

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
    # Evaluation results do not depend on the batch size in eval mode, so
    # reuse the training batch size instead of one image per forward pass.
    valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=False)

    n_classes = trainset.n_classes
    if args.structure == 'AlexNet':
        model = AlexNet(n_classes=n_classes, pretrained=args.init_vgg16)
    elif args.structure == 'resnet18':
        model = resnet18(n_classes=n_classes, pretrained=args.init_vgg16)
    elif args.structure == 'resnet50':
        model = resnet50(n_classes=n_classes, pretrained=args.init_vgg16)
    elif args.structure == 'GoogLeNet':
        model = GoogLeNet(n_classes=n_classes)
    else:
        model = PeleeNet(n_classes=n_classes)

    start_epoch = 0
    if args.resume_model_state_dict != '':
        start_epoch = _resume_epoch(args.resume_model_state_dict)
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
        # model.cuda() below moves the weights when requested.
        model.load_state_dict(torch.load(args.resume_model_state_dict, map_location='cpu'))

    if args.cuda:
        model.cuda()

    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params, lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # NOTE: the schedule counts epochs run in this process; it is not
    # fast-forwarded to start_epoch when resuming.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250, 350], gamma=0.1)

    iters_per_epoch = len(trainloader)

    for epoch in range(start_epoch, 20000, 1):
        print('epoch:', epoch)
        model.train()

        if vis is not None:
            # Plot the learning rate the optimizer will really use this epoch.
            current_lr = np.array([group['lr'] for group in optimizer.param_groups])
            _plot_point(vis, 'lr step', epoch, current_lr, 'epoch', 'lr')

        for i, (imgs, labels) in enumerate(trainloader):
            if args.cuda:
                imgs = imgs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            if vis is not None:
                # Global iteration index; len(trainloader) counts the last partial batch.
                _plot_point(vis, 'loss iterations', epoch * iters_per_epoch + i, np.array([loss.item()]), 'iteration', 'loss')

            loss.backward()
            optimizer.step()

        # Step the schedule after the epoch's optimizer updates, as
        # PyTorch >= 1.1 requires; stepping first skips the first schedule value.
        scheduler.step()

        if args.val_interval > 0 and epoch % args.val_interval == 0:
            val_acc = _accuracy(model, valloader, args.cuda)
            print('val_acc: {:.4f}'.format(val_acc))
            if vis is not None:
                # x is the epoch itself; validation already runs only on
                # multiples of val_interval.
                _plot_point(vis, 'acc epoch', epoch, np.array([val_acc]), 'epoch', 'acc')

        # Honour --save_model/--save_epoch; the name matches what _resume_epoch parses.
        if args.save_model and args.save_epoch > 0 and epoch % args.save_epoch == 0 and epoch != 0:
            torch.save(model.state_dict(), '{}_{}_{}.pt'.format(args.structure, args.dataset, epoch))
def _str2bool(value):
    """
    Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string as True, so ``--cuda False``
    used to enable CUDA. The empty string stays False as before.

    Raises:
        argparse.ArgumentTypeError: the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """
    Parse the ImageNet-style training options.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Boolean options accept
        ``--flag``, ``--flag True`` and ``--flag False``.
    """
    parser = argparse.ArgumentParser(description='training parameter setting')
    parser.add_argument('--structure', type=str, default='resnet18', help='net structure to train [ AlexNet resnet18 resnet50 GoogLeNet PeleeNet ]')
    parser.add_argument('--resume_model', type=str, default='', help='resume model path [ AlexNet_cifar10_0.pkl ]')
    parser.add_argument('--resume_model_state_dict', type=str, default='', help='resume model state dict path [ AlexNet_cifar10_0.pt ]')
    parser.add_argument('--save_model', type=_str2bool, nargs='?', const=True, default=False, help='save model [ False ]')
    parser.add_argument('--save_epoch', type=int, default=1, help='save model after epoch [ 1 ]')
    parser.add_argument('--init_vgg16', type=_str2bool, nargs='?', const=True, default=False, help='init model using vgg16 weights [ False ]')
    parser.add_argument('--dataset', type=str, default='Caltech101', help='train dataset name [ Caltech101 Bearing ]')
    parser.add_argument('--dataset_path', type=str, default='~/Data/101_ObjectCategories', help='train dataset path [ ~/Data/101_ObjectCategories ~/Data/Bearing ]')
    parser.add_argument('--data_augment', type=_str2bool, nargs='?', const=True, default=False, help='enlarge the training data [ False ]')
    parser.add_argument('--batch_size', type=int, default=128, help='train dataset batch size [ 128 ]')
    parser.add_argument('--val_interval', type=int, default=1, help='val dataset interval unit epoch [ 1 ]')
    parser.add_argument('--lr', type=float, default=1e-1, help='train learning rate [ 0.1 ]')
    parser.add_argument('--vis', type=_str2bool, nargs='?', const=True, default=False, help='visualize the training results [ False ]')
    parser.add_argument('--cuda', type=_str2bool, nargs='?', const=True, default=False, help='use cuda [ False ]')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print(args)
    train(args)