├── .gitignore
├── README.md
├── classifier_test.py
├── classifier_train.py
├── data
│   ├── __init__.py
│   └── dataset.py
├── network
│   ├── __init__.py
│   └── resnet.py
├── pic
│   ├── BasicBlock_Bottleneck.png
│   ├── ResNet34_ResNet101.jpg
│   ├── accuracy.png
│   ├── loss.png
│   └── shortcut.png
├── testimg
│   ├── img_0000.png
│   ├── img_0014.png
│   ├── img_0046.png
│   ├── img_0072.png
│   ├── img_0080.png
│   └── img_0100.png
└── utils
    ├── Tester.py
    ├── Trainer.py
    ├── __init__.py
    ├── log.py
    └── visualize.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# own ignore
*.pyc
.idea
/data/images
/model

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

### Objective

Classify a hand-gesture digit dataset. The data come from `./data/images/`: 4324 training images and 484 test images, gesture classes 0-5, each image 64×64 pixels.

### Update

- 180521: Added **multi-GPU** support
  - Set the `params.gpus` variable in `classifier_train.py` and `classifier_test.py` to choose the GPU indices to use (they must match the indices reported by `nvidia-smi`), e.g. `params.gpus = [2,3]`.
  - **CPU mode**: set `params.gpus = []`.
  - Since the test phase is computationally cheap, `classifier_test.py` currently uses only the first specified GPU (`params.gpus[0]`).

### Steps

Using PyTorch, implement hand-gesture recognition on top of ResNet34 or ResNet101.

- Data preparation:
  - Training: put the image folder under `./data/`. [Download the images](https://cloud.tsinghua.edu.cn/f/787490e187714336aae2/?dl=1)
  - Testing: put a trained model under `./models/`. [Download the models](https://cloud.tsinghua.edu.cn/d/dbf0243babd443c49e21/)
- Training steps:
  - First start the `Visdom` server with `nohup python -m visdom.server &`.
  - Then run `classifier_train.py`.
  - Trained models are saved as `.pth` files under `./models/`.
  - Note: adjust the `batch_size` in the code to your GPU so that its memory does not overflow:
    - ResNet34, 1 GPU, `batch_size=120`, memory usage < 7G
    - ResNet101, 1 GPU, `batch_size=60`, memory usage < 10G
- Testing steps:
  - Edit the relevant parameters in `classifier_test.py`: `ckpt` is the checkpoint to load and `testdata_dir` is the folder of images to test. Note that `ckpt` must match the chosen `model`.
  - Then run `classifier_test.py`; the prediction for each image is printed to the console.


### Method

- Libraries used: PIL, torch, torchvision, numpy, visdom

- ResNet:

  Experiments are run with two networks, ResNet34 and ResNet101. To save parameters in the deeper network, the two use different building blocks around their "shortcut connections": ResNet34 stacks BasicBlock, while ResNet101 stacks Bottleneck. A short comparison sketch follows the two figures below.

  ![BasicBlock and Bottleneck](./pic/BasicBlock_Bottleneck.png)

  ![ResNet34 and ResNet101](./pic/ResNet34_ResNet101.jpg)
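  To make the difference concrete, here is a minimal sketch (not part of the training pipeline) that instantiates both block types from `network/resnet.py` at the same width and compares parameter counts. It only counts parameters; actually running a forward pass through the `Bottleneck` would additionally require a `downsample` module, since its output has 64 × 4 = 256 channels.

```python
# Sketch: contrast the two shortcut-connection blocks defined in network/resnet.py.
# BasicBlock uses two 3x3 convs; Bottleneck uses 1x1 -> 3x3 -> 1x1 with expansion=4.
from network.resnet import BasicBlock, Bottleneck

basic = BasicBlock(inplanes=64, planes=64)
bottleneck = Bottleneck(inplanes=64, planes=64)  # outputs 64 * 4 = 256 channels

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print('BasicBlock params:', n_params(basic))    # two 3x3 convs dominate
print('Bottleneck params:', n_params(bottleneck))  # cheaper per unit of depth
```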
### Training code walkthrough

1. Hyper-params: set the data path, model save path, initial learning rate, and other parameters.
2. Training parameters: define the training-related settings, such as the maximum number of epochs, the optimizer, the loss function, whether to use the GPU, and how often to save the model.
3. Load data: the `Hand` dataset class reads the images and labels and preprocesses them; the preprocessing happens in `__getitem__`.
4. Models: instantiate the ResNet34 or ResNet101 network from the `ResNet` class.
5. Optimizer, criterion, lr_scheduler: the optimizer is SGD, the loss function is `CrossEntropyLoss`, and the learning-rate schedule is `ReduceLROnPlateau` (see the sketch after this list).
6. Trainer: `Trainer` is the class that drives training and validation; `trainer` is its instance. Its constructor applies the settings from step 2: training data, validation data, model, GPU usage, and so on.
   `Trainer` defines the training and validation functions `train()` and `_val_one_epoch()`. `train()` loops up to the configured maximum number of epochs, calling `_train_one_epoch()` once per epoch;
   during training the loss is accumulated in `loss_meter` and the per-class predictions in `confusion_matrix`;
   `_val_one_epoch()` evaluates the current model on the validation set, storing the predictions in `val_cm` and the accuracy in `val_accuracy`;
   finally, the loss, the accuracy, and the training log are sent to `Visdom`; open `http://localhost:8097` in a browser to view them.
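As a standalone illustration of the scheduling policy in step 5 (the scheduler parameters mirror `classifier_train.py`; the stagnating loss is fabricated for the demo):

```python
# Sketch: ReduceLROnPlateau multiplies the LR by `factor` once the monitored
# quantity stops improving for `patience` epochs, then waits `cooldown` epochs.
import torch
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau

w = Variable(torch.zeros(1), requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.01, momentum=0.9, nesterov=True)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=10, cooldown=10)

for epoch in range(40):
    scheduler.step(1.0)  # a loss that never improves triggers the decay

print(optimizer.param_groups[0]['lr'])  # < 0.01: the flat loss triggered LR decay
```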
### Test code walkthrough

1. Test parameters: define the settings used at test time.
2. Models: instantiate the ResNet34 or ResNet101 network from the `ResNet` class.
3. Tester: instantiate the test class `Tester`, which mainly provides a model-loading function and a prediction function.
   `_load_ckpt()` loads the model;
   `test()` runs the prediction: it preprocesses each image (see the sketch below) and prints the predicted class.
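A minimal sketch of the per-image preprocessing performed inside `Tester.test()` (the file name is one of the bundled test images):

```python
# Sketch: the single-image pipeline used by Tester.test() - resize to 224x224,
# convert to a CHW tensor, normalize with ImageNet statistics, add a batch dim.
import torch
from PIL import Image
import torchvision.transforms.functional as tv_F

img = Image.open('./testimg/img_0000.png')
img = tv_F.to_tensor(tv_F.resize(img, (224, 224)))
img = tv_F.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img_input = torch.unsqueeze(img, 0)  # shape [1, 3, 224, 224], ready for model()
```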
### Result

- Loss

![](./pic/loss.png)

- Accuracy

![](./pic/accuracy.png)

- Predictions:

```
Processing image: img_0046.png
Prediction number: 0
Processing image: img_0000.png
Prediction number: 1
Processing image: img_0072.png
Prediction number: 2
Processing image: img_0080.png
Prediction number: 4
Processing image: img_0100.png
Prediction number: 5
Processing image: img_0014.png
Prediction number: 3
```

### Reference

- [pytorch](https://github.com/pytorch/pytorch)
- [pytorch-book](https://github.com/chenyuntc/pytorch-book)

--------------------------------------------------------------------------------
/classifier_test.py:
--------------------------------------------------------------------------------

from torch import nn
from utils import Tester
from network import resnet34, resnet101

# Set test parameters
params = Tester.TestParams()
params.gpus = [0]  # set 'params.gpus = []' to use CPU mode; if len(params.gpus) > 1, only params.gpus[0] is used for testing
params.ckpt = './models/ckpt_epoch_800_res101.pth'  # or './models/ckpt_epoch_400_res34.pth'
params.testdata_dir = './testimg/'

# models
# model = resnet34(pretrained=False, num_classes=1000)  # batch_size=120, 1 GPU memory < 7000M
# model.fc = nn.Linear(512, 6)
model = resnet101(pretrained=False, num_classes=1000)  # batch_size=60, 1 GPU memory > 9000M
model.fc = nn.Linear(512 * 4, 6)

# Test
tester = Tester(model, params)
tester.test()
--------------------------------------------------------------------------------
/classifier_train.py:
--------------------------------------------------------------------------------

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from data import Hand
from utils import Trainer
from network import resnet34, resnet101

# Hyper-params
data_root = './data/'
model_path = './models/'
batch_size = 60  # batch size per GPU in GPU mode; resnet34: batch_size=120
num_workers = 2

init_lr = 0.01
lr_decay = 0.8
momentum = 0.9
weight_decay = 0.000
nesterov = True

# Set training parameters
params = Trainer.TrainParams()
params.max_epoch = 1000
params.criterion = nn.CrossEntropyLoss()
params.gpus = [0]  # set 'params.gpus = []' to use CPU mode
params.save_dir = model_path
params.ckpt = None
params.save_freq_epoch = 100

# load data
print("Loading dataset...")
train_data = Hand(data_root, train=True)
val_data = Hand(data_root, train=False)

# scale the effective batch size with the number of GPUs
batch_size = batch_size if len(params.gpus) == 0 else batch_size * len(params.gpus)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
print('train dataset len: {}'.format(len(train_dataloader.dataset)))

val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
print('val dataset len: {}'.format(len(val_dataloader.dataset)))

# models
# model = resnet34(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=120, 1 GPU memory < 7000M
# model.fc = nn.Linear(512, 6)
model = resnet101(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=60, 1 GPU memory > 9000M
model.fc = nn.Linear(512 * 4, 6)

# optimizer
trainable_vars = [param for param in model.parameters() if param.requires_grad]
print("Training with sgd")
params.optimizer = torch.optim.SGD(trainable_vars, lr=init_lr,
                                   momentum=momentum,
                                   weight_decay=weight_decay,
                                   nesterov=nesterov)

# Train
params.lr_scheduler = ReduceLROnPlateau(params.optimizer, 'min', factor=lr_decay, patience=10, cooldown=10, verbose=True)
trainer = Trainer(model, params, train_dataloader, val_dataloader)
trainer.train()

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------

from .dataset import Hand

--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------

# -*- coding:utf-8 -*-
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T


class Hand(data.Dataset):

    def __init__(self, root, transforms=None, train=True):
        '''
        Get images and divide them into a train/val set.
        '''
        self.train = train
        self.images_root = root

        self._read_txt_file()

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if not train:
                self.transforms = T.Compose([
                    T.Resize(224),  # formerly T.Scale, deprecated in torchvision
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),  # formerly T.RandomSizedCrop
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def _read_txt_file(self):
        self.images_path = []
        self.images_labels = []

        if self.train:
            txt_file = os.path.join(self.images_root, 'images/train.txt')
        else:
            txt_file = os.path.join(self.images_root, 'images/test.txt')

        # each line holds "image_path label", separated by a single space
        with open(txt_file, 'r') as f:
            for line in f.readlines():
                item = line.strip().split(' ')
                self.images_path.append(item[0])
                self.images_labels.append(item[1])

    def __getitem__(self, index):
        '''
        Return the data of one image.
        '''
        # paths in the txt files are relative to the data root
        img_path = os.path.join(self.images_root, self.images_path[index])
        label = self.images_labels[index]
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, int(label)

    def __len__(self):
        return len(self.images_path)
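`train.txt`/`test.txt` themselves are not included in the repo (the image folder is downloaded separately). Based on how `_read_txt_file()` and `__getitem__` consume them, each line holds an image path relative to `data_root` and a class label, separated by a space. A hypothetical excerpt:

```
images/train/img_0001.png 0
images/train/img_0002.png 3
images/train/img_0003.png 5
```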
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------

from .resnet import *

--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet34', 'resnet101']

model_urls = {
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He initialization for convs; BN starts at identity scale
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet34(pretrained=False, modelpath='./models', **kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir=modelpath))
    return model


def resnet101(pretrained=False, modelpath='./models', **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir=modelpath))
    return model
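A quick sanity-check sketch for the factory functions above; the 6-class head replacement mirrors the training and test scripts:

```python
# Sketch: build resnet34, swap in the 6-class head, and check the output shape.
import torch
from torch import nn
from torch.autograd import Variable
from network import resnet34

model = resnet34(pretrained=False, num_classes=1000)
model.fc = nn.Linear(512, 6)  # 512 * BasicBlock.expansion input features

out = model(Variable(torch.randn(1, 3, 224, 224)))
print(out.size())  # torch.Size([1, 6])
```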
--------------------------------------------------------------------------------
/pic/BasicBlock_Bottleneck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/pic/BasicBlock_Bottleneck.png

--------------------------------------------------------------------------------
/pic/ResNet34_ResNet101.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/pic/ResNet34_ResNet101.jpg

--------------------------------------------------------------------------------
/pic/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/pic/accuracy.png

--------------------------------------------------------------------------------
/pic/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/pic/loss.png

--------------------------------------------------------------------------------
/pic/shortcut.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/pic/shortcut.png

--------------------------------------------------------------------------------
/testimg/img_0000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0000.png

--------------------------------------------------------------------------------
/testimg/img_0014.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0014.png

--------------------------------------------------------------------------------
/testimg/img_0046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0046.png

--------------------------------------------------------------------------------
/testimg/img_0072.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0072.png

--------------------------------------------------------------------------------
/testimg/img_0080.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0080.png

--------------------------------------------------------------------------------
/testimg/img_0100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiMeng95/pytorch_hand_classifier/afd448b1e0aaded1f8cafd66fa2d688560ac71de/testimg/img_0100.png

--------------------------------------------------------------------------------
/utils/Tester.py:
--------------------------------------------------------------------------------

from __future__ import print_function

import os
from PIL import Image
from .log import logger

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms.functional as tv_F


class TestParams(object):
    # params based on your local env
    gpus = []  # default to use CPU mode

    # loading existing checkpoint
    ckpt = './models/ckpt_epoch_800_res101.pth'  # path to the ckpt file

    testdata_dir = './testimg/'


class Tester(object):

    TestParams = TestParams

    def __init__(self, model, test_params):
        assert isinstance(test_params, TestParams)
        self.params = test_params

        # load model
        self.model = model
        ckpt = self.params.ckpt
        if ckpt is not None:
            self._load_ckpt(ckpt)
            logger.info('Load ckpt from {}'.format(ckpt))

        # set CUDA_VISIBLE_DEVICES; 1 GPU is enough for testing
        if len(self.params.gpus) > 0:
            gpu_test = str(self.params.gpus[0])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_test
            logger.info('Set CUDA_VISIBLE_DEVICES to {}...'.format(gpu_test))
            self.model = self.model.cuda()

        self.model.eval()

    def test(self):

        img_list = os.listdir(self.params.testdata_dir)

        for img_name in img_list:
            print('Processing image: ' + img_name)

            # preprocess: resize to 224x224, to tensor, ImageNet normalization
            img = Image.open(os.path.join(self.params.testdata_dir, img_name))
            img = tv_F.to_tensor(tv_F.resize(img, (224, 224)))
            img = tv_F.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            img_input = Variable(torch.unsqueeze(img, 0))
            if len(self.params.gpus) > 0:
                img_input = img_input.cuda()

            # forward, then take the argmax of the class scores
            output = self.model(img_input)
            score = F.softmax(output, dim=1)
            _, prediction = torch.max(score.data, dim=1)

            print('Prediction number: ' + str(prediction[0]))

    def _load_ckpt(self, ckpt):
        self.model.load_state_dict(torch.load(ckpt))
--------------------------------------------------------------------------------
/utils/Trainer.py:
--------------------------------------------------------------------------------

from __future__ import print_function

import os
import numpy as np

import torch as t
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable
from torchnet import meter

from .log import logger
from .visualize import Visualizer


def get_learning_rates(optimizer):
    lrs = [pg['lr'] for pg in optimizer.param_groups]
    lrs = np.asarray(lrs, dtype=np.float64)
    return lrs


class TrainParams(object):
    # required params
    max_epoch = 30

    # optimizer, criterion and learning rate scheduler
    optimizer = None
    criterion = None
    lr_scheduler = None  # should be an instance of ReduceLROnPlateau or _LRScheduler

    # params based on your local env
    gpus = []  # default to use CPU mode
    save_dir = './models/'  # default `save_dir`

    # loading existing checkpoint
    ckpt = None  # path to the ckpt file

    # saving checkpoints
    save_freq_epoch = 1  # save one ckpt per `save_freq_epoch` epochs


class Trainer(object):

    TrainParams = TrainParams

    def __init__(self, model, train_params, train_data, val_data=None):
        assert isinstance(train_params, TrainParams)
        self.params = train_params

        # Data loaders
        self.train_data = train_data
        self.val_data = val_data

        # criterion, optimizer and learning rate
        self.last_epoch = 0
        self.criterion = self.params.criterion
        self.optimizer = self.params.optimizer
        self.lr_scheduler = self.params.lr_scheduler
        logger.info('Set criterion to {}'.format(type(self.criterion)))
        logger.info('Set optimizer to {}'.format(type(self.optimizer)))
        logger.info('Set lr_scheduler to {}'.format(type(self.lr_scheduler)))

        # load model
        self.model = model
        logger.info('Set output dir to {}'.format(self.params.save_dir))
        if not os.path.isdir(self.params.save_dir):
            os.makedirs(self.params.save_dir)

        ckpt = self.params.ckpt
        if ckpt is not None:
            self._load_ckpt(ckpt)
            logger.info('Load ckpt from {}'.format(ckpt))

        # meters
        self.loss_meter = meter.AverageValueMeter()
        self.confusion_matrix = meter.ConfusionMeter(6)

        # set CUDA_VISIBLE_DEVICES
        if len(self.params.gpus) > 0:
            gpus = ','.join([str(x) for x in self.params.gpus])
            os.environ['CUDA_VISIBLE_DEVICES'] = gpus
            # after masking, the visible devices are renumbered from 0
            self.params.gpus = tuple(range(len(self.params.gpus)))
            logger.info('Set CUDA_VISIBLE_DEVICES to {}...'.format(gpus))
            self.model = nn.DataParallel(self.model, device_ids=self.params.gpus)
            self.model = self.model.cuda()

        self.model.train()

    def train(self):
        vis = Visualizer()
        best_loss = np.inf
        for epoch in range(self.last_epoch, self.params.max_epoch):

            self.loss_meter.reset()
            self.confusion_matrix.reset()

            self.last_epoch += 1
            logger.info('Start training epoch {}'.format(self.last_epoch))

            self._train_one_epoch()

            # save model (the final epoch is always saved)
            if (self.last_epoch % self.params.save_freq_epoch == 0) or (self.last_epoch == self.params.max_epoch):
                save_name = os.path.join(self.params.save_dir, 'ckpt_epoch_{}.pth'.format(self.last_epoch))
                t.save(self.model.state_dict(), save_name)

            val_cm, val_accuracy = self._val_one_epoch()

            if self.loss_meter.value()[0] < best_loss:
                logger.info('Found a better ckpt ({:.3f} -> {:.3f})'.format(best_loss, self.loss_meter.value()[0]))
                best_loss = self.loss_meter.value()[0]

            # visualize
            vis.plot('loss', self.loss_meter.value()[0])
            vis.plot('val_accuracy', val_accuracy)
            vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                epoch=epoch, loss=self.loss_meter.value()[0], val_cm=str(val_cm.value()),
                train_cm=str(self.confusion_matrix.value()), lr=get_learning_rates(self.optimizer)))

            # adjust the lr
            if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                self.lr_scheduler.step(self.loss_meter.value()[0], self.last_epoch)

    def _load_ckpt(self, ckpt):
        self.model.load_state_dict(t.load(ckpt))

    def _train_one_epoch(self):
        for step, (data, label) in enumerate(self.train_data):
            # train model
            inputs = Variable(data)
            target = Variable(label)
            if len(self.params.gpus) > 0:
                inputs = inputs.cuda()
                target = target.cuda()

            # forward
            score = self.model(inputs)
            loss = self.criterion(score, target)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update meters
            self.loss_meter.add(loss.data[0])
            self.confusion_matrix.add(score.data, target.data)

    def _val_one_epoch(self):
        self.model.eval()
        confusion_matrix = meter.ConfusionMeter(6)
        logger.info('Val on validation set...')

        for step, (data, label) in enumerate(self.val_data):

            # val model
            inputs = Variable(data, volatile=True)
            target = Variable(label.type(t.LongTensor), volatile=True)
            if len(self.params.gpus) > 0:
                inputs = inputs.cuda()
                target = target.cuda()

            score = self.model(inputs)
            confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor))

        self.model.train()
        cm_value = confusion_matrix.value()
        # overall accuracy = trace of the 6x6 confusion matrix / total samples
        accuracy = 100. * (cm_value[0][0] + cm_value[1][1]
                           + cm_value[2][2] + cm_value[3][3]
                           + cm_value[4][4] + cm_value[5][5]) / (cm_value.sum())
        return confusion_matrix, accuracy
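One caveat worth noting: in GPU mode `Trainer` saves the `nn.DataParallel`-wrapped model, so every key in the checkpoint carries a `module.` prefix, while `Tester._load_ckpt()` loads into a bare model. A hedged sketch of stripping the prefix before loading (the checkpoint path is the one shipped with the repo):

```python
# Sketch: load a checkpoint saved from an nn.DataParallel-wrapped model into a
# bare model by stripping the 'module.' prefix from the state_dict keys.
import torch
from torch import nn
from network import resnet101

model = resnet101(pretrained=False, num_classes=1000)
model.fc = nn.Linear(512 * 4, 6)

state_dict = torch.load('./models/ckpt_epoch_800_res101.pth',
                        map_location=lambda storage, loc: storage)
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)
```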
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

from .Trainer import Trainer
from .Tester import Tester
from .log import logger

--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------

import logging


def get_logger(name='root'):
    formatter = logging.Formatter(
        # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
        fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    return logger


logger = get_logger('root')
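A minimal usage sketch for the shared logger (the output format follows the `Formatter` above):

```python
# Sketch: the module-level logger is importable anywhere in the project.
from utils import logger

logger.info('training started')
# prints e.g. "2018-05-21 12:00:00 [INFO]: training started"
```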
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------

# coding:utf8
import visdom
import time
import numpy as np


class Visualizer(object):
    '''
    A thin wrapper around the basic visdom operations; the native visdom
    interface is still reachable through `self.vis.function`.
    '''

    def __init__(self, env='default', **kwargs):
        self.vis = visdom.Visdom(env=env, **kwargs)

        # per-plot point counter, i.e. the x coordinate:
        # storing ('loss', 23) means the next 'loss' point is the 23rd
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        '''
        Reconfigure visdom.
        '''
        self.vis = visdom.Visdom(env=env, **kwargs)
        return self

    def plot_many(self, d):
        '''
        Plot several values at once.
        @params d: dict (name, value), e.g. ('loss', 0.11)
        '''
        for k, v in d.items():
            self.plot(k, v)

    def img_many(self, d):
        for k, v in d.items():
            self.img(k, v)

    def plot(self, name, y, **kwargs):
        '''
        self.plot('loss', 1.00)
        '''
        x = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([x]),
                      win=name,
                      opts=dict(title=name),
                      update=None if x == 0 else 'append',
                      **kwargs
                      )
        self.index[name] = x + 1

    def img(self, name, img_, **kwargs):
        '''
        self.img('input_img', t.Tensor(64, 64))
        self.img('input_imgs', t.Tensor(3, 64, 64))
        self.img('input_imgs', t.Tensor(100, 1, 64, 64))
        self.img('input_imgs', t.Tensor(100, 3, 64, 64), nrows=10)

        Don't call self.img('input_imgs', t.Tensor(100, 64, 64), nrows=10):
        batched images must carry an explicit channel dimension.
        '''
        self.vis.images(img_.cpu().numpy(),
                        win=name,
                        opts=dict(title=name),
                        **kwargs
                        )

    def log(self, info, win='log_text'):
        '''
        self.log({'loss': 1, 'lr': 0.0001})
        '''
        self.log_text += ('[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info))
        self.vis.text(self.log_text, win)
    def __getattr__(self, name):
        return getattr(self.vis, name)

--------------------------------------------------------------------------------
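For reference, a minimal usage sketch for `Visualizer` (assumes a visdom server is already running, e.g. via `nohup python -m visdom.server &`):

```python
# Sketch: plot two points of a 'loss' curve and append a log line.
from utils.visualize import Visualizer

vis = Visualizer(env='default')
vis.plot('loss', 0.50)         # x = 0, opens the 'loss' window
vis.plot('loss', 0.42)         # x = 1, appended to the same window
vis.log('epoch:1, loss:0.42')  # accumulated in the 'log_text' window
```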