├── README.md
├── data
│   ├── img_37.jpg
│   └── testImages
│       ├── images
│       │   ├── 000001.jpg
│       │   ├── 000002.jpg
│       │   ├── 000003.jpg
│       │   ├── 000004.jpg
│       │   ├── 000005.jpg
│       │   ├── 000006.jpg
│       │   ├── 000007.jpg
│       │   ├── 000008.jpg
│       │   ├── 000009.jpg
│       │   └── 000010.jpg
│       └── test_images.txt
├── imgLoader
│   ├── cifar10.py
│   ├── loader.py
│   ├── myImageFloder.py
│   └── transform.py
├── nets
│   ├── __init__.py
│   ├── lenet.py
│   ├── lenetseq.py
│   ├── test.py
│   └── utils.py
├── train
│   └── main.py
└── transferLearning
    ├── resnet.py
    └── vgg16.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-tutorials

## Pre-Processing
- imgLoader/transform.py

## Load the images
- imgLoader/cifar10.py
- imgLoader/myImageFloder.py, imgLoader/loader.py (see the quick example below)

## ConvNet
- nets/*

## Train
- train/*

## TransferLearning
- transferLearning/*
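## Quick example

A minimal usage sketch of the custom loader listed above (run from `imgLoader/`; the paths match the tree, the batch size is only illustrative):

```python
import torch
import torchvision.transforms as transforms
import myImageFloder as myFloder

mytransform = transforms.Compose([transforms.ToTensor()])

# wrap the annotation-file dataset in a standard DataLoader
imgLoader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root="../data/testImages/images",
                           label="../data/testImages/test_images.txt",
                           transform=mytransform),
    batch_size=2, shuffle=False, num_workers=2)

for imgs, labels in imgLoader:
    print(imgs.size(), labels.size())   # e.g. (2, 3, H, W) and (2, 10)
    break
```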
--------------------------------------------------------------------------------
/data/img_37.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/img_37.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000001.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000002.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000003.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000004.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000005.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000006.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000007.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000008.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000009.jpg
--------------------------------------------------------------------------------
/data/testImages/images/000010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/data/testImages/images/000010.jpg
--------------------------------------------------------------------------------
/data/testImages/test_images.txt:
--------------------------------------------------------------------------------
lefteye_x lefteye_y righteye_x righteye_y nose_x nose_y leftmouth_x leftmouth_y rightmouth_x rightmouth_y
000001.jpg 69 109 106 113 77 142 73 152 108 154
000002.jpg 69 110 107 112 81 135 70 151 108 153
000003.jpg 76 112 104 106 108 128 74 156 98 158
000004.jpg 72 113 108 108 101 138 71 155 101 151
000005.jpg 66 114 112 112 86 119 71 147 104 150
000006.jpg 71 111 106 110 94 131 74 154 102 153
000007.jpg 70 112 108 111 85 135 72 152 104 152
000008.jpg 71 110 106 111 84 137 73 155 104 153
000009.jpg 68 113 110 111 97 139 66 152 109 150
000010.jpg 68 111 108 112 89 136 70 151 107 151
--------------------------------------------------------------------------------
/imgLoader/cifar10.py:
--------------------------------------------------------------------------------
"""
Created on Sat Apr 8 09:47:27 2017

@author: tfygg
"""

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2

print("\n\t\t torchvision.datasets.CIFAR10 test \n \t\t\t\t\t\t ----by tfygg")

# torchvision.datasets.CIFAR10
cifarSet = torchvision.datasets.CIFAR10(root="../data/cifar/", train=True, download=True)
print(cifarSet[0])
img, label = cifarSet[0]
print(img)
print(label)
print(img.format, img.size, img.mode)
img.show()

--------------------------------------------------------------------------------
/imgLoader/loader.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2

import myImageFloder as myFloder

print("\n\t\t torch.utils.data.DataLoader test \n \t\t\t\t\t\t ----by tfygg")

mytransform = transforms.Compose([
    transforms.ToTensor()
])

# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root="../data/cifar/", train=True, download=True, transform=mytransform)
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size=10, shuffle=False, num_workers=2)
print(len(cifarSet))
print(len(cifarLoader))

for i, data in enumerate(cifarLoader, 0):
    # data = [images, labels]; take the first image of the batch
    print(data[0][0])
    # PIL
    img = transforms.ToPILImage()(data[0][0])
    img.show()
    break

# torch.utils.data.DataLoader
imgLoader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root="../data/testImages/images", label="../data/testImages/test_images.txt", transform=mytransform),
    batch_size=2, shuffle=False, num_workers=2)

for i, data in enumerate(imgLoader, 0):
    print(data[0][0])
    # opencv
    img2 = data[0][0].numpy()*255
    img2 = img2.astype('uint8')
    img2 = np.transpose(img2, (1, 2, 0))
    img2 = img2[:, :, ::-1]  # RGB -> BGR
    cv2.imshow('img2', img2)
    cv2.waitKey()
    break

--------------------------------------------------------------------------------
/imgLoader/myImageFloder.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""

import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
    return Image.open(path).convert('RGB')

class myImageFloder(data.Dataset):
    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        fh = open(label)
        c = 0
        imgs = []
        class_names = []
        for line in fh.readlines():
            if c == 0:
                class_names = [n.strip() for n in line.rstrip().split(' ')]
            else:
                cls = line.split()
                fn = cls.pop(0)
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple([float(v) for v in cls])))
            c = c + 1
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)

    def getName(self):
        return self.classes


def testmyImageFloder():
    dataloader = myImageFloder('../data/testImages/images',
                               '../data/testImages/test_images.txt')
    print('dataloader.getName', dataloader.getName())

    for index, (img, label) in enumerate(dataloader):
        img.show()
        print('label', label)


if __name__ == "__main__":
    testmyImageFloder()
--------------------------------------------------------------------------------
/imgLoader/transform.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils
from PIL import Image
import numpy as np
import cv2

print("\n\t\t torchvision.transforms test \n \t\t\t\t\t\t ----by tfygg")

img_path = "../data/img_37.jpg"

# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
])

## opencv
img = cv2.imread(img_path)
print("img = ", img)
cv2.imshow("img", img)
cv2.waitKey()

img1 = transform1(img)
print("img1 = ", img1)
img_1 = img1.numpy()*255
img_1 = img_1.astype('uint8')
img_1 = np.transpose(img_1, (1, 2, 0))
cv2.imshow('img_1', img_1)
cv2.waitKey()

## PIL
img = Image.open(img_path).convert('RGB')
img2 = transform1(img)
utils.save_image(img2, 'test.jpg')
print("img2 = ", img2)
img_2 = transforms.ToPILImage()(img2).convert('RGB')
print("img_2 = ", img_2)
img_2.show()

# transforms.Normalize()
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # range [0, 255] -> [-1.0, 1.0]
])
transform3 = transforms.Compose([
    # inverse of the step above: Normalize computes (x - mean) / std, so this maps y back to (y + 1) / 2
    transforms.Normalize(mean=(-1, -1, -1), std=(2, 2, 2))  # range [-1.0, 1.0] -> [0.0, 1.0]
])

img = Image.open(img_path).convert('RGB')
img3 = transform2(img)
print("img3 = ", img3)
img_3 = transforms.ToPILImage()(transform3(img3)).convert('RGB')
print("img_3 = ", img_3)
img_3.show()


# transforms.CenterCrop()
transform4 = transforms.Compose([
    transforms.ToTensor(),    # range [0, 255] -> [0.0, 1.0]
    transforms.ToPILImage(),
    transforms.CenterCrop((300, 300)),
])

img = Image.open(img_path).convert('RGB')
img4 = transform4(img)
img4.show()

# transforms.RandomCrop()
transform5 = transforms.Compose([
    transforms.ToTensor(),    # range [0, 255] -> [0.0, 1.0]
    transforms.ToPILImage(),
    transforms.RandomCrop((100, 100)),
])

img = Image.open(img_path).convert('RGB')
img5 = transform5(img)
img5.show()

--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tfygg/pytorch-tutorials/b98262678898578802c07f4291fd2be6af1dcace/nets/__init__.py
--------------------------------------------------------------------------------
/nets/lenet.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""

import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out
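A quick shape check for LeNet (a sketch, not part of the repo; run it from nets/ the same way test.py is run): fc1 expects 16*5*5 = 400 flattened features, which corresponds to 3x32x32 (CIFAR-10 sized) inputs.

import torch
from torch.autograd import Variable
import lenet

net = lenet.LeNet()
x = Variable(torch.randn(1, 3, 32, 32))  # dummy batch with one RGB 32x32 image
out = net(x)
print(out.size())                        # expected: torch.Size([1, 10])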
--------------------------------------------------------------------------------
/nets/lenetseq.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""
import torch
import torch.nn as nn

class LeNetSeq(nn.Module):
    def __init__(self):
        super(LeNetSeq, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out
--------------------------------------------------------------------------------
/nets/test.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.11

@author: tfygg
"""

import lenet
import lenetseq
import utils

# Net
net = lenet.LeNet()
print(net)

for index, param in enumerate(net.parameters()):
    print(list(param.data))
    print(type(param.data), param.size())
    print(index, "-->", param)


print(net.state_dict())
print(net.state_dict().keys())

for key in net.state_dict():
    print(key, 'corresponds to', list(net.state_dict()[key]))


# NetSeq
netSeq = lenetseq.LeNetSeq()
print(netSeq)

utils.initNetParams(netSeq)

for key in netSeq.state_dict():
    print(key, 'corresponds to', list(netSeq.state_dict()[key]))

--------------------------------------------------------------------------------
/nets/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.init as init

def initNetParams(net):
    '''Init net parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform(m.weight)
            if m.bias is not None:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant(m.bias, 0)

"""
w = torch.Tensor(3, 5)
print(w)
print("uniform :", init.uniform(w))
print("normal :", init.normal(w))
print("xavier_uniform :", init.xavier_uniform(w))
print("xavier_normal :", init.xavier_normal(w))
print("kaiming_uniform :", init.kaiming_uniform(w))
print("kaiming_normal :", init.kaiming_normal(w))
"""

--------------------------------------------------------------------------------
/train/main.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.17

@author: tfygg
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import os
import sys
sys.path.append('../nets')
import argparse

import lenet
from torch.autograd import Variable

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 8)')
parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                    help='input batch size for testing (default: 1)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data/cifar', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data/cifar', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Net
net = lenet.LeNet()
print(net)

criterion = nn.CrossEntropyLoss()

if args.cuda:
    net.cuda()
    criterion.cuda()

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)

def train(epoch):
    net.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data[0]))

def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    for data, target in testloader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = net(data)
        test_loss += criterion(output, target).data[0]  # accumulate the batch losses

        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(testloader)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)

--------------------------------------------------------------------------------
/transferLearning/resnet.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.17

@author: tfygg
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy
import os
import argparse

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 8)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.RandomCrop(224, padding=4),
    transforms.Scale(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Scale(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data/cifar', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data/cifar', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ConvNet: freeze the pretrained backbone and train only the new final fc layer
model_ft = models.resnet18(pretrained=True)
print(model_ft)

for param in model_ft.parameters():
    param.requires_grad = False

num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10)
print(model_ft)

criterion = nn.CrossEntropyLoss()

if args.cuda:
    model_ft.cuda()
    criterion.cuda()

optimizer = optim.SGD(model_ft.fc.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
def train(epoch):
    model_ft.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model_ft(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data[0]))

def test(epoch):
    model_ft.eval()
    test_loss = 0
    correct = 0
    for data, target in testloader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model_ft(data)
        test_loss += criterion(output, target).data[0]  # accumulate the batch losses

        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(testloader)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)

--------------------------------------------------------------------------------
/transferLearning/vgg16.py:
--------------------------------------------------------------------------------
"""
Created on 2017.6.17

@author: tfygg
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy
import os
import argparse

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for training (default: 8)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.RandomCrop(224, padding=4),
    transforms.Scale(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Scale(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data/cifar', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data/cifar', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ConvNet: keep the pretrained VGG16 features, replace the classifier head for 10 classes
model_ft = models.vgg16(pretrained=True)
print(model_ft)
cls = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 10),
)
model_ft.classifier = cls
print(model_ft)

criterion = nn.CrossEntropyLoss()

if args.cuda:
    model_ft.cuda()
    criterion.cuda()

optimizer = optim.SGD(model_ft.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)

def train(epoch):
    model_ft.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model_ft(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data[0]))

def test(epoch):
    model_ft.eval()
    test_loss = 0
    correct = 0
    for data, target in testloader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model_ft(data)
        test_loss += criterion(output, target).data[0]  # accumulate the batch losses

        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    test_loss /= len(testloader)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
--------------------------------------------------------------------------------
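Both transfer-learning scripts stop after the epoch loop; a minimal sketch (not part of the original code, checkpoint filename purely illustrative) of persisting and restoring the fine-tuned weights with the standard state_dict API:

torch.save(model_ft.state_dict(), 'vgg16_cifar10.pth')       # after training finishes

model_ft.load_state_dict(torch.load('vgg16_cifar10.pth'))    # later, to reuse the model
model_ft.eval()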