├── LICENSE
├── README.md
├── imgs
│   ├── Net.png
│   ├── acc_rafdb.png
│   ├── arrangement.png
│   └── perception.png
└── src
    ├── Networks.py
    ├── __pycache__
    │   ├── Networks.cpython-38.pyc
    │   └── image_utils.cpython-38.pyc
    ├── dataset.py
    ├── distributed.py
    ├── image_utils.py
    ├── plot_confusion_matrix.py
    ├── test_raf-db.py
    └── train_raf-db.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Jiawei Shi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learning to Amend Facial Expression Representation via De-albino and Affinity
Jiawei Shi and Songhao Zhu
Nanjing University of Posts and Telecommunications
Nanjing, China
{1319055608, zhush}@njupt.edu.cn


## Abstract
Facial Expression Recognition (FER) is a classification task over face variants, so facial expressions share affinity features that have received little attention in the FER literature. Convolution padding, despite helping capture edge information, erodes the feature map at the same time. After multiple layers of padded convolution, the eroded output, which we call the albino feature, clearly weakens the representation of the expression. To tackle these problems, we propose a novel architecture named Amending Representation Module (ARM). ARM is a substitute for the pooling layer and can, in principle, be embedded at the back end of any network to deal with padding erosion. ARM enhances facial expression representation from two directions: 1) reducing the weight of eroded features to offset the side effect of padding, and 2) decomposing facial features to simplify representation learning. To deal with data imbalance, we also design a minimal random resampling (MRR) scheme to suppress overfitting. Experiments on public benchmarks show that ARM boosts FER performance remarkably: validation accuracies are **90.42%** on RAF-DB, **65.2%** on AffectNet, and **58.71%** on SFEW, exceeding current state-of-the-art methods. The paper is available on [arXiv.org](https://arxiv.org/abs/2103.10189).
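The padding-erosion effect described above can be made concrete with a few lines of PyTorch. The snippet below is a minimal illustration and is not part of this repository: it stacks zero-padded averaging convolutions on a constant image and shows that border activations shrink while the interior stays unchanged.

```python
import torch
from torch import nn

# Four zero-padded 3x3 convolutions with averaging kernels (all weights = 1/9).
layers = nn.Sequential(*[nn.Conv2d(1, 1, 3, padding=1, bias=False) for _ in range(4)])
for conv in layers:
    nn.init.constant_(conv.weight, 1.0 / 9.0)

with torch.no_grad():
    out = layers(torch.ones(1, 1, 16, 16))  # constant input image

print(out[0, 0, 8, 8].item())  # interior value: stays at 1.0
print(out[0, 0, 0, 0].item())  # corner value: much smaller, eroded by the zero padding
```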

## Amend-Representation-Module

![image](https://github.com/sunmusik/Amend-Representation-Module/blob/master/imgs/Net.png)

Overview of the Amend Representation Module (ARM). The ARM, composed of three blocks, replaces the pooling layer
of the CNN. The solid arrows indicate the processing flow of one feature map, and the dotted arrows refer to the auxiliary flow of
a batch. Note that the relationship between the two channels requires the de-albino kernel to be single-channel
and unique (one kernel shared by all channels).

## Train
- Requirements

  Torch 1.7.1, APEX 0.1, and torchvision 0.8.2.

  For APEX 0.1 (Linux):

      git clone https://github.com/NVIDIA/apex
      cd apex
      pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

- Data Preparation

  Download the [RAF-DB](http://www.whdeng.cn/RAF/model1.html#dataset) dataset, and make sure it has the following structure:

```
- datasets/raf-basic/
    EmoLabel/
        list_patition_label.txt
    Image/aligned/
        train_00001_aligned.jpg
        test_0001_aligned.jpg
        ...
```
- Training
```
python src/train_raf-db.py
```


- Testing

```
python src/test_raf-db.py --checkpoint *.pth
```

- Testing and Confusion Matrix

```
python src/test_raf-db.py --checkpoint *.pth --plot_cm
```

A minimal sketch for calling the model directly from Python is shown below.

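- Inference from Python (sketch)

  This is a usage sketch rather than part of the original scripts: `checkpoint.pth` and `face.jpg` are placeholder paths, and the snippet reuses `Networks.py` together with the test-time transforms from `src/test_raf-db.py`. The label order follows the comment in `src/dataset.py`.

```python
import cv2
import torch
from torchvision import transforms

import Networks  # run from within src/, or add src/ to PYTHONPATH

model = Networks.ResNet18_ARM___RAF()
checkpoint = torch.load("checkpoint.pth", map_location="cpu")  # placeholder path
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
model.eval()  # move model and inputs to .cuda() as in test_raf-db.py if a GPU is available

preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

image = cv2.imread("face.jpg")[:, :, ::-1].copy()  # BGR -> RGB, as in dataset.py
inputs = preprocess(image).unsqueeze(0)            # add a batch dimension

with torch.no_grad():
    logits, _ = model(inputs)                      # the model returns (logits, alpha)

labels = ['Surprise', 'Fear', 'Disgust', 'Happiness', 'Sadness', 'Anger', 'Neutral']
print(labels[logits.argmax(1).item()])
```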

## Result
- Confusion Matrix on RAF-DB


## Citation
If you use the sample code or part of it in your research, please cite the following:

```
@ARTICLE{2021arXiv210310189S,
       author = {{Shi}, Jiawei and {Zhu}, Songhao},
        title = "{Learning to Amend Facial Expression Representation via De-albino and Affinity}",
      journal = {arXiv e-prints},
     keywords = {Computer Science - Computer Vision and Pattern Recognition},
         year = 2021,
        month = mar,
          eid = {arXiv:2103.10189},
        pages = {arXiv:2103.10189},
archivePrefix = {arXiv},
       eprint = {2103.10189},
 primaryClass = {cs.CV},
       adsurl = {https://ui.adsabs.harvard.edu/abs/2021arXiv210310189S},
      adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
```

## License
ARM is available under the MIT license. See the LICENSE file for more info.

--------------------------------------------------------------------------------
/imgs/Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/imgs/Net.png

--------------------------------------------------------------------------------
/imgs/acc_rafdb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/imgs/acc_rafdb.png

--------------------------------------------------------------------------------
/imgs/arrangement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/imgs/arrangement.png

--------------------------------------------------------------------------------
/imgs/perception.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/imgs/perception.png

--------------------------------------------------------------------------------
/src/Networks.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models


class ResNet18(nn.Module):
    def __init__(self, pretrained=False, num_classes=7, drop_rate=0):
        super(ResNet18, self).__init__()
        self.drop_rate = drop_rate
        resnet = models.resnet18(pretrained)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.features(x)
        if self.drop_rate > 0:
            # functional dropout so it is only active in training mode
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = x.view(x.size(0), -1)
        out = self.fc(x)

        return out, out  # second value only mirrors the (logits, alpha) interface of the ARM model


class ResNet18_ARM___RAF(nn.Module):
    def __init__(self, pretrained=True, num_classes=7, drop_rate=0):
        super(ResNet18_ARM___RAF, self).__init__()
        self.drop_rate = drop_rate
        resnet = models.resnet18(pretrained)
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # backbone up to (but excluding) avgpool
        self.arrangement = nn.PixelShuffle(16)
        self.arm = Amend_raf()
        self.fc = nn.Linear(121, num_classes)

    def forward(self, x):
        x = self.features(x)

        x = self.arrangement(x)

        x, alpha = self.arm(x)

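        # Shape flow for a 224x224 input (descriptive note):
        #   backbone   : (B, 512, 7, 7)
        #   arrangement: PixelShuffle(16) -> (B, 2, 112, 112)
        #   arm        : 32x32 de-albino conv, stride 8 -> (B, 2, 11, 11),
        #                then channel mean + alpha * global mean -> (B, 1, 11, 11)
        #   flatten    : 11 * 11 = 121 features, matching nn.Linear(121, num_classes)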
        if self.drop_rate > 0:
            # functional dropout so it is only active in training mode
            x = F.dropout(x, p=self.drop_rate, training=self.training)

        x = x.view(x.size(0), -1)
        out = self.fc(x)

        return out, alpha


class Amend_raf(nn.Module):  # default ("moren")
    def __init__(self, inplace=2):
        super(Amend_raf, self).__init__()
        # single shared 1-channel de-albino kernel; "inplace" is the channel count after PixelShuffle
        self.de_albino = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=32, stride=8, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(inplace)
        self.alpha = nn.Parameter(torch.tensor([1.0]))

    def forward(self, x):
        # Apply the shared de-albino kernel to every channel independently,
        # then reassemble the channels (device-agnostic, no locals() tricks).
        channels = []
        for i in range(x.size(1)):
            xi = torch.unsqueeze(x[:, i], 1)
            channels.append(self.de_albino(xi))
        mask = torch.cat(channels, 1)

        x = self.bn(mask)
        xmax, _ = torch.max(x, 1, keepdim=True)  # computed but not used below
        global_mean = x.mean(dim=[0, 1])
        xmean = torch.mean(x, 1, keepdim=True)
        xmin, _ = torch.min(x, 1, keepdim=True)  # computed but not used below
        x = xmean + self.alpha * global_mean

        return x, self.alpha


if __name__ == '__main__':
    model = ResNet18_ARM___RAF()
    input = torch.randn(64, 3, 224, 224)
    out, alpha = model(input)
    print(out.size())

--------------------------------------------------------------------------------
/src/__pycache__/Networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/src/__pycache__/Networks.cpython-38.pyc

--------------------------------------------------------------------------------
/src/__pycache__/image_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SimKarras/Amend-Representation-Module/fb8a1552325bfe31090ae1c51a515869bcd01a85/src/__pycache__/image_utils.cpython-38.pyc

--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import cv2
import pandas as pd
import os
import image_utils
import random


class RafDataSet(data.Dataset):
    def __init__(self, raf_path, phase, transform=None, basic_aug=False):
        self.phase = phase
        self.transform = transform
        self.raf_path = raf_path

        NAME_COLUMN = 0
        LABEL_COLUMN = 1
        df = pd.read_csv(os.path.join(self.raf_path, 'EmoLabel/list_patition_label.txt'), sep=' ', header=None)
        if phase == 'train':
            dataset = df[df[NAME_COLUMN].str.startswith('train')]
        else:
            dataset = df[df[NAME_COLUMN].str.startswith('test')]
        file_names = dataset.iloc[:, NAME_COLUMN].values
        # labels are 1-based in the annotation file; after the shift:
        # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
        self.label = dataset.iloc[:, LABEL_COLUMN].values - 1

        self.file_paths = []
        # use raf-db aligned images for training/testing
        for f in file_names:
            f = f.split(".")[0]
            f = f + "_aligned.jpg"
            path = os.path.join(self.raf_path, 'Image/aligned', f)
            self.file_paths.append(path)

        self.basic_aug = basic_aug
        self.aug_func = [image_utils.flip_image, image_utils.add_gaussian_noise]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        image = cv2.imread(path)
        image = image[:, :, ::-1]  # BGR to RGB
        label = self.label[idx]
        if self.phase == 'train':
            if self.basic_aug and random.uniform(0, 1) > 0.5:
                index = random.randint(0, 1)
                image = self.aug_func[index](image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label, idx

--------------------------------------------------------------------------------
/src/distributed.py:
--------------------------------------------------------------------------------
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor


def gather_grad(params):
    world_size = get_world_size()

    if world_size == 1:
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses

--------------------------------------------------------------------------------
/src/image_utils.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np

def add_gaussian_noise(image_array, mean=0.0, var=30):
    std = var**0.5
    noisy_img = image_array + np.random.normal(mean, std, image_array.shape)
    noisy_img_clipped = np.clip(noisy_img, 0, 255).astype(np.uint8)
    return noisy_img_clipped

def flip_image(image_array):
    return cv2.flip(image_array, 1)

def color2gray(image_array):
    gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    gray_img_3d = image_array.copy()
    gray_img_3d[:, :, 0] = gray
    gray_img_3d[:, :, 1] = gray
    gray_img_3d[:, :, 2] = gray
    return gray_img_3d

--------------------------------------------------------------------------------
/src/plot_confusion_matrix.py:
--------------------------------------------------------------------------------
import itertools
import matplotlib.pyplot as plt  # plotting library
import numpy as np
import os

def plot_confusion_matrix(cm, labels_name, title, acc):
    cm = cm / cm.sum(axis=1)[:, np.newaxis]  # normalize each row (per class)
    thresh = cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, "{:0.2f}".format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.imshow(cm, interpolation='nearest')  # display the matrix as an image
    plt.title(title)  # figure title
    plt.colorbar()
    num_class = np.array(range(len(labels_name)))  # tick positions for the class labels
    plt.xticks(num_class, labels_name, rotation=90)  # class labels on the x-axis
    plt.yticks(num_class, labels_name)  # class labels on the y-axis
    plt.ylabel('Target')
    plt.xlabel('Prediction')
    plt.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues'))
    plt.tight_layout()
    plt.savefig(os.path.join('./Confusion_matrix/raf-db', "acc" + str(acc) + ".png"), format='png')
    plt.show()

--------------------------------------------------------------------------------
/src/test_raf-db.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import torch.utils.data as data
from torchvision import transforms
import torch
import argparse
import Networks
from sklearn.metrics import confusion_matrix
from plot_confusion_matrix import plot_confusion_matrix
from dataset import RafDataSet


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--raf_path', type=str, default='./datasets/raf-basic/', help='Raf-DB dataset path.')
    parser.add_argument('-c', '--checkpoint', type=str, default=None, help='Pytorch checkpoint file path')
    parser.add_argument('-b', '--batch_size', type=int, default=64, help='Batch size.')
    parser.add_argument('--workers', default=4, type=int, help='Number of data loading workers (default: 4)')
    parser.add_argument('-p', '--plot_cm', action="store_true", help="Plot the confusion matrix.")
    return parser.parse_args()


def test():
    args = parse_args()
    model = Networks.ResNet18_ARM___RAF()

    print("Loading pretrained weights...", args.checkpoint)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    data_transforms_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    test_dataset = RafDataSet(args.raf_path, phase='test', transform=data_transforms_test)
    test_size = test_dataset.__len__()
    print('Test set size:', test_size)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.workers,
                                              shuffle=False,
                                              pin_memory=True)

    model = model.cuda()

    pre_labels = []
    gt_labels = []
    with torch.no_grad():
        bingo_cnt = 0
        model.eval()
        for batch_i, (imgs, targets, _) in enumerate(test_loader):
            outputs, _ = model(imgs.cuda())
            targets = targets.cuda()
            _, predicts = torch.max(outputs, 1)
            correct_or_not = torch.eq(predicts, targets)
            pre_labels += predicts.cpu().tolist()
            gt_labels += targets.cpu().tolist()
            bingo_cnt += correct_or_not.sum().cpu()

        acc = bingo_cnt.float() / float(test_size)
        acc = np.around(acc.numpy(), 4)
        print(f"Test accuracy: {acc:.4f}.")

    if args.plot_cm:
        cm = confusion_matrix(gt_labels, pre_labels)
        cm = np.array(cm)
        labels_name = ['SU', 'FE', 'DI', 'HA', 'SA', 'AN', "NE"]  # x- and y-axis labels
        plot_confusion_matrix(cm, labels_name, 'RAF-DB', acc)


if __name__ == "__main__":
    test()

--------------------------------------------------------------------------------
/src/train_raf-db.py:
--------------------------------------------------------------------------------
import warnings
warnings.filterwarnings("ignore")
from apex import amp
import numpy as np
import torch.utils.data as data
from torchvision import transforms
import os, torch
import argparse
import Networks
from dataset import RafDataSet


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--raf_path', type=str, default='./datasets/raf-basic/', help='Raf-DB dataset path.')
    parser.add_argument('-c', '--checkpoint', type=str, default=None, help='Pytorch checkpoint file path')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size.')
    parser.add_argument('--val_batch_size', type=int, default=64, help='Batch size for validation.')
    parser.add_argument('--optimizer', type=str, default="adam", help='Optimizer, adam or sgd.')
    parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate for sgd.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for sgd')
    parser.add_argument('--workers', default=4, type=int, help='Number of data loading workers (default: 4)')
    parser.add_argument('--epochs', type=int, default=70, help='Total training epochs.')
    parser.add_argument('--wandb', action='store_true')
    return parser.parse_args()


def run_training():
    args = parse_args()
    if args.wandb:
        import wandb
        wandb.init(project='raf-db')

    model = Networks.ResNet18_ARM___RAF()
    # print(model)
    print("batch_size:", args.batch_size)

    if args.checkpoint:
        print("Loading pretrained weights...", args.checkpoint)
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    data_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(scale=(0.02, 0.1))])

    train_dataset = RafDataSet(args.raf_path, phase='train', transform=data_transforms, basic_aug=True)

    print('Train set size:', train_dataset.__len__())
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               shuffle=True,
                                               pin_memory=True)

    data_transforms_val = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    val_dataset = RafDataSet(args.raf_path, phase='test', transform=data_transforms_val)
    val_num = val_dataset.__len__()
    print('Validation set size:', val_num)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.val_batch_size,
                                             num_workers=args.workers,
                                             shuffle=False,
                                             pin_memory=True)

    params = model.parameters()
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, weight_decay=1e-4)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, args.lr, momentum=args.momentum, weight_decay=1e-4)
        if args.wandb:
            config = wandb.config
            config.learning_rate = args.lr
    else:
        raise ValueError("Optimizer not supported.")
    print(optimizer)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    model = model.cuda()
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
    CE_criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0
    for i in range(1, args.epochs + 1):
        train_loss = 0.0
        correct_sum = 0
        iter_cnt = 0
        model.train()
        for batch_i, (imgs, targets, indexes) in enumerate(train_loader):

            iter_cnt += 1
            optimizer.zero_grad()
            imgs = imgs.cuda()
            outputs, alpha = model(imgs)
            targets = targets.cuda()

            CE_loss = CE_criterion(outputs, targets)
            loss = CE_loss
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            train_loss += loss.item()  # .item() so the autograd graph is not kept alive across iterations
            _, predicts = torch.max(outputs, 1)
            correct_num = torch.eq(predicts, targets).sum()
            correct_sum += correct_num

        train_acc = correct_sum.float() / float(train_dataset.__len__())
        train_loss = train_loss/iter_cnt
        print('[Epoch %d] Training accuracy: %.4f. Loss: %.3f LR: %.6f' %
              (i, train_acc, train_loss, optimizer.param_groups[0]["lr"]))
        scheduler.step()

        with torch.no_grad():
            val_loss = 0.0
            iter_cnt = 0
            bingo_cnt = 0
            model.eval()
            for batch_i, (imgs, targets, _) in enumerate(val_loader):
                outputs, _ = model(imgs.cuda())
                targets = targets.cuda()

                CE_loss = CE_criterion(outputs, targets)
                loss = CE_loss

                val_loss += loss
                iter_cnt += 1
                _, predicts = torch.max(outputs, 1)
                correct_or_not = torch.eq(predicts, targets)
                bingo_cnt += correct_or_not.sum().cpu()

            val_loss = val_loss/iter_cnt
            val_acc = bingo_cnt.float()/float(val_num)
            val_acc = np.around(val_acc.numpy(), 4)
Loss:%.3f" % (i, val_acc, val_loss)) 146 | 147 | if args.wandb: 148 | wandb.log( 149 | { 150 | "train_loss": train_loss, 151 | "train_acc": train_acc, 152 | "val_loss": val_loss, 153 | "val_acc": val_acc, 154 | } 155 | ) 156 | 157 | if val_acc > 0.92 and val_acc > best_acc: 158 | torch.save({'iter': i, 159 | 'model_state_dict': model.state_dict(), 160 | 'optimizer_state_dict': optimizer.state_dict(), }, 161 | os.path.join('models/RAF-DB', "epoch" + str(i) + "_acc" + str(val_acc) + ".pth")) 162 | print('Model saved.') 163 | if val_acc > best_acc: 164 | best_acc = val_acc 165 | print("best_acc:" + str(best_acc)) 166 | 167 | 168 | if __name__ == "__main__": 169 | run_training() 170 | --------------------------------------------------------------------------------