├── README.md ├── test.py ├── data_loader_3.py ├── data_loader_4.py ├── networks.py └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # RGB-D-Face-Recognition 2 | 3 | Simple RGB-D face recognition implementation. 4 | - Input: Combine RGB and Depth images into 4 channels input data 5 | - Network: ResNet 6 | 7 | Dataset we used is BU-3DFE (Binghamton University 3D Facial Expression) Database. Due to the license, we cannot provide any related data. If you need it, please request the dataset by yourself [[link]](http://www.cs.binghamton.edu/~lijun/Research/3DFE/3DFE_Analysis.html). Then, reproject the 3D face model into RGB and D images. 8 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.utils.data as Data 4 | import torchvision.transforms as Transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | 10 | import os 11 | import sys 12 | import math 13 | import argparse 14 | import numpy as np 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | from tqdm import tqdm 19 | 20 | from networks import * 21 | from data_loader_4 import CreateDataloader 22 | from data_loader_3 import CreateDataloader_3 23 | 24 | ### Parameters ### 25 | split_ratio = 0 26 | batch_size = 1 27 | cuda = True 28 | 29 | def main(args): 30 | 31 | Net = None 32 | test_loader = None 33 | class_num = 0 34 | 35 | if args.channel == 4: 36 | class_num, test_loader, _ = CreateDataloader(args.rgb_dir, args.d_dir, batch_size, split_ratio) 37 | #Net = ResNet18_(4, class_num) 38 | print('------------------------------------') 39 | print('Input Channel Size: ', args.channel) 40 | print('RGB Data Directory: ', args.rgb_dir) 41 | print('D Data Directory: ', args.d_dir) 42 | else: 43 | class_num, test_loader, _ = CreateDataloader_3(args.d_dir, batch_size, split_ratio) 44 | #Net = ResNet18_(3, class_num) 45 | print('------------------------------------') 46 | print('Input Channel Size: ', args.channel) 47 | print('Data Directory: ', args.d_dir) 48 | print('RGB Data Directory: ', args.rgb_dir) 49 | print('D Data Directory: ', args.d_dir) 50 | 51 | print('Load checkpoint: ', args.checkpoint) 52 | Net = torch.load(args.checkpoint) 53 | if cuda: 54 | Net.cuda() 55 | else: 56 | Net.cpu() 57 | loss_function = nn.CrossEntropyLoss() 58 | print('------------------------------------') 59 | 60 | ### Testing ### 61 | correct = 0 62 | total = 0 63 | for (images, labels) in tqdm(test_loader): 64 | if cuda: 65 | images, labels = images.cuda(), labels.cuda() 66 | else: 67 | images, labels = images.cpu(), labels.cpu() 68 | 69 | outputs = Net(images) 70 | _, predicted = torch.max(outputs, 1) 71 | total += labels.size(0) 72 | correct += (predicted == labels).sum() 73 | loss = loss_function(outputs, labels) 74 | 75 | val_accu = 100 * float(correct) / float(total) 76 | print('Testing Accuracy: %d %%' % (val_accu)) 77 | 78 | 79 | def parse_arguments(argv): 80 | parser = argparse.ArgumentParser() 81 | 82 | parser.add_argument('--channel', type=int, 83 | help='Input layer channel size', default=4) 84 | parser.add_argument('--rgb_dir', type=str, 85 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/test/RGB') 86 | parser.add_argument('--d_dir', type=str, 87 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/test/D') 88 | parser.add_argument('--checkpoint', type=str, 89 | help='Folder to save checkpoints.') 90 | 91 | return parser.parse_args(argv) 92 | 93 | if __name__ == '__main__': 94 | main(parse_arguments(sys.argv[1:])) 95 | -------------------------------------------------------------------------------- /data_loader_3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.utils.data as Data 4 | import torchvision.transforms as Transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | 10 | import numpy as np 11 | import os 12 | import math 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | from PIL import Image 16 | 17 | 18 | classes_n = 0 19 | 20 | class DataType(): 21 | "Stores the paths to images for a given class" 22 | def __init__(self, name, img_paths): 23 | self.class_name = name 24 | self.img_paths = img_paths 25 | 26 | class ImagePath(): 27 | def __init__(self, rgb_path, d_path): 28 | self.rgb_path = rgb_path 29 | self.d_path = d_path 30 | 31 | def get_image_paths(facedir): 32 | image_paths = [] 33 | if os.path.isdir(facedir): 34 | images = os.listdir(facedir) 35 | image_paths = [os.path.join(facedir,img) for img in images] 36 | 37 | return image_paths 38 | 39 | def get_labels(data): 40 | images = [] 41 | labels = [] 42 | for i in range(len(data)): 43 | #images += [ImagePath(data[i].rgb_paths[j], data[i].d_paths[j]) for j in range(len(data[i].rgb_paths))] 44 | images += [data[i].img_paths[j] for j in range(len(data[i].img_paths))] 45 | labels += [i] * len(data[i].img_paths) 46 | 47 | return np.array(images), np.array(labels) 48 | 49 | 50 | def load_data(data_path): 51 | dataset = [] 52 | classes = [path for path in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, path))] 53 | 54 | classes_n = len(classes) 55 | 56 | for i in range(classes_n): 57 | face_dir = os.path.join(data_path, classes[i]) 58 | 59 | # Get image pathes of this class 60 | image_paths = get_image_paths(face_dir) 61 | 62 | dataset.append(DataType(classes[i], image_paths)) 63 | 64 | 65 | train_x, train_y = get_labels(dataset) 66 | 67 | return classes_n, train_x, train_y 68 | 69 | 70 | def my_loader(path, Type): 71 | #print(path) 72 | with open(path, 'rb') as f: 73 | with Image.open(f) as img: 74 | if Type == 3: 75 | img = img.convert('RGB') 76 | elif Type == 1: 77 | img = img.convert('L') 78 | return img 79 | 80 | class MyDataset(Data.Dataset): 81 | def __init__(self, img_paths, labels, transform, loader=my_loader): 82 | self.img_paths = img_paths 83 | self.labels = labels 84 | self.transform = transform 85 | self.loader = loader 86 | 87 | def __getitem__(self, index): #return data type is tensor 88 | #rgb_path, d_path = self.img_paths[index].rgb_path, self.img_paths[index].d_path 89 | img_path = self.img_paths[index] 90 | label = self.labels[index] 91 | img = my_loader(img_path, 3) 92 | ''' 93 | rgb_img = np.array( my_loader(rgb_path, 3) ) 94 | d_img = np.array( my_loader(d_path, 1) ) 95 | d_img = np.expand_dims(d_img, axis=2) 96 | 97 | img = np.append(rgb_img, d_img, axis=2) 98 | ''' 99 | img = self.transform(img) 100 | label = torch.from_numpy(np.array(label)).type(torch.LongTensor) 101 | 102 | return img, label 103 | 104 | def __len__(self): # return the total size of the dataset 105 | return len(self.labels) 106 | 107 | 108 | 109 | ### Split dataset and creat train & valid dataloader ### 110 | def split_dataset(dataset_t, batch, split_ratio): 111 | num_train = len(dataset_t) 112 | indices = list(range(num_train)) 113 | split = int(np.floor(split_ratio * num_train)) 114 | 115 | #np.random.seed(random_seed) 116 | np.random.shuffle(indices) 117 | 118 | train_idx, valid_idx = indices[split:], indices[:split] 119 | train_sampler = SubsetRandomSampler(train_idx) 120 | valid_sampler = SubsetRandomSampler(valid_idx) 121 | 122 | train_loader = torch.utils.data.DataLoader(dataset_t, batch_size=batch, sampler=train_sampler) 123 | valid_loader = torch.utils.data.DataLoader(dataset_t, batch_size=batch, sampler=valid_sampler) 124 | 125 | return train_loader, valid_loader 126 | 127 | 128 | def CreateDataloader_3(data_path, batch, split_ratio): 129 | classes_n, train_x, train_y = load_data(data_path) 130 | 131 | transform = Transforms.Compose([ 132 | Transforms.Resize(224), 133 | Transforms.ToTensor(), 134 | ]) 135 | 136 | dataset = MyDataset(train_x, train_y, transform=transform) 137 | 138 | train_loader, valid_loader = split_dataset(dataset, batch, split_ratio) 139 | 140 | print('Number of classes: %d' % classes_n) 141 | print('Total images: %d' % len(train_x)) 142 | #print('Total images: %d (split ratio: %.1f)' % (len(train_x), split_ratio) ) 143 | #print('Training images:', len(train_loader)) 144 | #print('Validation images: ', len(valid_loader)) 145 | 146 | return classes_n, train_loader, valid_loader -------------------------------------------------------------------------------- /data_loader_4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.utils.data as Data 4 | import torchvision.transforms as Transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | 10 | import numpy as np 11 | import os 12 | import math 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | from PIL import Image 16 | 17 | 18 | #data_path = 'D:/Datasets/BU_3DFE/RGB' 19 | #data_path_d = 'D:/Datasets/BU_3DFE/D' 20 | classes_n = 0 21 | 22 | class DataType(): 23 | "Stores the paths to images for a given class" 24 | def __init__(self, name, rgb_paths, d_paths): 25 | self.class_name = name 26 | self.rgb_paths = rgb_paths 27 | self.d_paths = d_paths 28 | 29 | class ImagePath(): 30 | def __init__(self, rgb_path, d_path): 31 | self.rgb_path = rgb_path 32 | self.d_path = d_path 33 | 34 | def get_image_paths(facedir): 35 | image_paths = [] 36 | if os.path.isdir(facedir): 37 | images = os.listdir(facedir) 38 | image_paths = [os.path.join(facedir,img) for img in images] 39 | 40 | return image_paths 41 | 42 | def get_labels(data): 43 | images = [] 44 | labels = [] 45 | for i in range(len(data)): 46 | images += [ImagePath(data[i].rgb_paths[j], data[i].d_paths[j]) for j in range(len(data[i].rgb_paths))] 47 | labels += [i] * len(data[i].rgb_paths) 48 | 49 | return np.array(images), np.array(labels) 50 | 51 | 52 | def load_data(data_path, data_path_d): 53 | dataset = [] 54 | classes = [path for path in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, path))] 55 | 56 | classes_n = len(classes) 57 | 58 | for i in range(classes_n): 59 | face_dir = os.path.join(data_path, classes[i]) 60 | face_dir_d = os.path.join(data_path_d, classes[i]) 61 | 62 | # Get image pathes of this class 63 | image_paths = get_image_paths(face_dir) 64 | image_paths_d = get_image_paths(face_dir_d) 65 | 66 | dataset.append(DataType(classes[i], image_paths, image_paths_d)) 67 | 68 | 69 | train_x, train_y = get_labels(dataset) 70 | 71 | return classes_n, train_x, train_y 72 | 73 | 74 | def my_loader(path, Type): 75 | #print(path) 76 | with open(path, 'rb') as f: 77 | with Image.open(f) as img: 78 | if Type == 3: 79 | img = img.convert('RGB') 80 | elif Type == 1: 81 | img = img.convert('L') 82 | return img 83 | 84 | class MyDataset(Data.Dataset): 85 | def __init__(self, img_paths, labels, transform, loader=my_loader): 86 | self.img_paths = img_paths 87 | self.labels = labels 88 | self.transform = transform 89 | self.loader = loader 90 | 91 | def __getitem__(self, index): #return data type is tensor 92 | rgb_path, d_path = self.img_paths[index].rgb_path, self.img_paths[index].d_path 93 | label = self.labels[index] 94 | 95 | rgb_img = np.array( my_loader(rgb_path, 3) ) 96 | d_img = np.array( my_loader(d_path, 1) ) 97 | d_img = np.expand_dims(d_img, axis=2) 98 | 99 | img = np.append(rgb_img, d_img, axis=2) 100 | img = self.transform(Image.fromarray(img)) 101 | label = torch.from_numpy(np.array(label)).type(torch.LongTensor) 102 | 103 | 104 | return img, label 105 | 106 | def __len__(self): # return the total size of the dataset 107 | return len(self.labels) 108 | 109 | 110 | 111 | ### Split dataset and creat train & valid dataloader ### 112 | def split_dataset(dataset_t, batch, split_ratio): 113 | num_train = len(dataset_t) 114 | indices = list(range(num_train)) 115 | split = int(np.floor(split_ratio * num_train)) 116 | 117 | #np.random.seed(random_seed) 118 | np.random.shuffle(indices) 119 | 120 | train_idx, valid_idx = indices[split:], indices[:split] 121 | train_sampler = SubsetRandomSampler(train_idx) 122 | valid_sampler = SubsetRandomSampler(valid_idx) 123 | 124 | train_loader = torch.utils.data.DataLoader(dataset_t, batch_size=batch, sampler=train_sampler) 125 | valid_loader = torch.utils.data.DataLoader(dataset_t, batch_size=batch, sampler=valid_sampler) 126 | 127 | return train_loader, valid_loader 128 | 129 | 130 | def CreateDataloader(data_path, data_path_d, batch, split_ratio): 131 | classes_n, train_x, train_y = load_data(data_path, data_path_d) 132 | 133 | transform = Transforms.Compose([ 134 | Transforms.Resize(224), 135 | Transforms.ToTensor(), 136 | ]) 137 | 138 | dataset = MyDataset(train_x, train_y, transform=transform) 139 | 140 | train_loader, valid_loader = split_dataset(dataset, batch, split_ratio) 141 | 142 | print('Number of classes: %d' % classes_n) 143 | print('Total images: %d' % len(train_x)) 144 | #print('Total images: %d (split ratio: %.1f)' % (len(train_x), split_ratio) ) 145 | #print('Training images:', len(train_loader)) 146 | #print('Validation images: ', len(valid_loader)) 147 | 148 | return classes_n, train_loader, valid_loader -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.utils.data as Data 4 | import torchvision.transforms as Transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | import numpy as np 12 | import os 13 | import math 14 | import matplotlib 15 | 16 | ### Model ### 17 | 18 | # 3x3 convolution 19 | def conv3x3(in_channels, out_channels, stride=1): 20 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 21 | stride=stride, padding=1, bias=False) 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | class ResNet(nn.Module): 93 | def __init__(self, block, layers, input_ch, num_classes=1000): 94 | self.inplanes = 64 95 | super(ResNet, self).__init__() 96 | self.conv1 = nn.Conv2d(input_ch, 64, kernel_size=7, stride=2, padding=input_ch, bias=False) 97 | self.bn1 = nn.BatchNorm2d(64) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 100 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 101 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 102 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 103 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 104 | self.avgpool = nn.AvgPool2d(7, stride=1) 105 | #import pdb 106 | #pdb.set_trace() 107 | if input_ch == 3: 108 | self.fc = nn.Linear(512 * block.expansion, num_classes) 109 | #self.fc = nn.Linear(512 * block.expansion, 256) 110 | else: 111 | self.fc = nn.Linear(512 * block.expansion * 4, num_classes) 112 | #self.fc = nn.Linear(512 * block.expansion * 4, 256) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 117 | elif isinstance(m, nn.BatchNorm2d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | 156 | def ResNet18(input_channel, class_num): 157 | return ResNet(BasicBlock, [2, 2, 2, 2], input_channel, class_num) 158 | 159 | def ResNet50(input_channel, class_num): 160 | return ResNet(Bottleneck, [3, 4, 6, 3], input_channel, class_num) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | import os 8 | import sys 9 | import math 10 | import argparse 11 | import numpy as np 12 | import matplotlib 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | 16 | from tensorboardX import SummaryWriter 17 | import networks 18 | from data_loader_4 import CreateDataloader 19 | from data_loader_3 import CreateDataloader_3 20 | 21 | 22 | 23 | 24 | def main(args): 25 | ### Parameters ### 26 | split_ratio = args.split 27 | epochs = 500 28 | batch_size = 1 29 | lr = 0.01 30 | momentum = 0.5 31 | 32 | cuda = True 33 | log_step_percentage = 10 34 | checkpoint_path = './checkpoints' 35 | 36 | train_loader = None 37 | valid_loader = None 38 | class_num = 0 39 | 40 | if args.channel == 4: 41 | class_num, train_loader, valid_loader = CreateDataloader(args.rgb_dir, args.d_dir, batch_size, split_ratio) 42 | if split_ratio == 0: 43 | _, valid_loader, _ = CreateDataloader(args.rgb_dir_test, args.d_dir_test, batch_size, split_ratio) 44 | Net = networks.ResNet18(4, class_num) 45 | print('------------------------------------') 46 | print('Input Channel Size: ', args.channel) 47 | print('RGB Data Directory: ', args.rgb_dir) 48 | print('D Data Directory: ', args.d_dir) 49 | else: 50 | class_num, train_loader, valid_loader = CreateDataloader_3(args.d_dir, batch_size, split_ratio) 51 | if split_ratio == 0: 52 | _, valid_loader, _ = CreateDataloader_3(args.d_dir_test, batch_size, split_ratio) 53 | Net = nwtworks.ResNet18(3, class_num) 54 | print('------------------------------------') 55 | print('Input Channel Size: ', args.channel) 56 | print('Data Directory: ', args.d_dir) 57 | print('RGB Data Directory: ', args.rgb_dir) 58 | print('D Data Directory: ', args.d_dir) 59 | 60 | # checkpints path 61 | #if not os.path.exists(f'{checkpoint_path}/{args.checkpoint_name}'): 62 | #os.makedirs(f'{checkpoint_path}/{args.checkpoint_name}') 63 | try: 64 | os.stat('%s/%s' % (checkpoint_path, args.checkpoint_name)) 65 | except: 66 | os.mkdir('%s/%s' % (checkpoint_path, args.checkpoint_name)) 67 | 68 | if cuda: 69 | Net.cuda() 70 | 71 | model = Net 72 | 73 | loss_function = nn.CrossEntropyLoss() 74 | #loss_function = nn.LogSoftmax() 75 | optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-5) 76 | log_interval = len(train_loader) / log_step_percentage 77 | 78 | 79 | # Add tensorboard writer 80 | writer = SummaryWriter() 81 | 82 | ### Training loop ### 83 | 84 | print('Start!') 85 | for epoch in range(1, epochs + 1): 86 | running_loss = 0.0 87 | correct = 0 88 | total = 0 89 | loss = 0 90 | for batch_idx, (inputs, labels) in enumerate(train_loader, 0): 91 | # get the inputs 92 | if cuda: 93 | inputs, labels = inputs.cuda(), labels.cuda() 94 | 95 | # zero the parameter gradients 96 | optimizer.zero_grad() 97 | 98 | # forward + backward + optimize 99 | outputs = Net(inputs) 100 | #print(outputs) 101 | #import pdb 102 | #pdb.set_trace() 103 | _, predicted = torch.max(outputs, 1) 104 | 105 | #print(outputs.size(), predicted.size()) 106 | 107 | total += labels.size(0) 108 | correct += (predicted == labels).sum() 109 | 110 | loss = loss_function(outputs, labels) 111 | loss.backward() 112 | optimizer.step() 113 | 114 | 115 | # print statistics 116 | if batch_idx % 10000 == 0 and batch_idx != 0: 117 | n_iter = epoch*len(train_loader) + batch_idx 118 | 119 | accu = 100 * float(correct) / float(total) 120 | writer.add_scalar('data/training_accuracy', accu, n_iter) 121 | writer.add_scalar('data/loss', loss.item(), n_iter) 122 | print(f'Train Epoch: {epoch} [{batch_idx*len(inputs)}/{len(train_loader.dataset)}] \t Loss: {loss.item()} \t Accuracy: {accu:.2f}%') 123 | 124 | if epoch % 1 == 0: 125 | correct = 0 126 | total = 0 127 | #for data in val_loader: 128 | for (images, labels) in valid_loader: 129 | #images, labels = data 130 | #labels = labels.type(torch.LongTensor) 131 | if cuda: 132 | images, labels = images.cuda(), labels.cuda() 133 | 134 | outputs = Net(images) 135 | _, predicted = torch.max(outputs, 1) 136 | total += labels.size(0) 137 | correct += (predicted == labels).sum() 138 | 139 | loss = loss_function(outputs, labels) 140 | 141 | val_accu = 100 * float(correct) / float(total) 142 | print('Validation Accuracy: %d %%' % (val_accu)) 143 | writer.add_scalar('data/val_accuracy', val_accu, epoch) 144 | writer.add_scalar('data/loss', loss, epoch) 145 | 146 | if epoch == 50: 147 | lr = lr/2 148 | if epoch % 1 == 0: 149 | torch.save(Net, f'{checkpoint_path}/{args.checkpoint_name}/Net_{epoch}.pkl') 150 | print(f'Save network: {checkpoint_path}/{args.checkpoint_name}/Net_{epoch}.pkl') 151 | 152 | print('Finished Training') 153 | 154 | torch.save(Net, 'Net_final.pkl') 155 | print('Save network successfully!') 156 | 157 | def parse_arguments(argv): 158 | parser = argparse.ArgumentParser() 159 | 160 | parser.add_argument('--rgb_dir', type=str, 161 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/train/RGB') 162 | parser.add_argument('--d_dir', type=str, 163 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/train/D') 164 | parser.add_argument('--rgb_dir_test', type=str, 165 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/test/RGB') 166 | parser.add_argument('--d_dir_test', type=str, 167 | help='RGB dataset.', default='D:/Datasets/BU_1225/dataset/test/D') 168 | parser.add_argument('--channel', type=int, 169 | help='Input layer channel size', default=4) 170 | parser.add_argument('--checkpoint_name', type=str, 171 | help='Folder to save checkpoints.', default='tmp') 172 | parser.add_argument('--split', type=float, 173 | help='Split ratio', default=0) 174 | 175 | return parser.parse_args(argv) 176 | 177 | if __name__ == '__main__': 178 | main(parse_arguments(sys.argv[1:])) 179 | --------------------------------------------------------------------------------