├── test.jpeg
├── output
│   ├── E135_0.6466.pth
│   └── E370_acc_0.6504.pth
├── requirements.txt
├── README.md
├── frame.py
├── video.py
├── utils
│   ├── dataset.py
│   ├── Model.py
│   └── DataAugment.py
├── eval.py
└── train.py
--------------------------------------------------------------------------------
/test.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thgpddl/mini_Xception/HEAD/test.jpeg
--------------------------------------------------------------------------------
/output/E135_0.6466.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thgpddl/mini_Xception/HEAD/output/E135_0.6466.pth
--------------------------------------------------------------------------------
/output/E370_acc_0.6504.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thgpddl/mini_Xception/HEAD/output/E370_acc_0.6504.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.19.3
visualdl==2.2.1
torch
torchvision
scikit-learn==0.24.2
opencv-python==4.5.1.48
pandas==1.1.5
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# mini_Xception
A lightweight convolutional neural network for facial expression recognition.

From the paper [Real-time Convolutional Neural Networks for Emotion and Gender Classification](https://arxiv.org/pdf/1710.07557v1.pdf).

There is also the official project: [oarriaga/face_classification](https://github.com/oarriaga/face_classification)

For a detailed walkthrough of the paper, see this blog post: [Real-time Convolutional Neural Networks for Emotion and Gender Classification--O Arriaga](https://blog.csdn.net/qq_40243750/article/details/124208527). The key point:
> The paper's implementation is Keras-based and reaches 66% accuracy on the FER2013 dataset. This repository is a PyTorch implementation and tops out at 65%. No cause for the 1%–2% gap has been found; it can only be attributed to the frameworks (results do differ between frameworks).

For my write-up of reproducing the paper, see: [Implementing the mini_Xception expression-recognition CNN in PyTorch](https://blog.csdn.net/qq_40243750/article/details/124226066?spm=1001.2014.3001.5501)

# 1. Install dependencies
Run:
> pip install -r requirements.txt

If the download is slow, add the Tsinghua mirror:
> pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple


# 2. Download the dataset
Download the train.csv and test.csv files from this link: [dataset](https://www.aliyundrive.com/s/fQz68x23mtk)

Then create a `dataset` folder in the mini_Xception root directory and put train.csv and test.csv into it.
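The loader in `utils/dataset.py` unpacks each CSV row as `(label, pixels, usage)`, where `pixels` is a string of 48×48 space-separated grayscale values. A minimal sanity check that the files parse (the column layout here is assumed to follow the standard FER2013 format):

```python
import numpy as np
import pandas as pd

row = pd.read_csv("dataset/train.csv").iloc[0]
label, pixels, usage = row  # the same three-column unpacking as utils/dataset.py
img = np.array([int(p) for p in pixels.split()], dtype=np.uint8).reshape(48, 48)
print(label, img.shape)  # e.g. 0 (48, 48)
```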

# 3. Training
Run the train.py script. The main settings at the top of the script are:
- num_epochs = 200
- log_step = 100  # interval, in steps, between info printouts
- num_workers = 16  # number of DataLoader workers
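Training curves (loss and accuracy for both splits) are written with VisualDL to a timestamped folder under `output/` (see `train.py`). Assuming that default log location, you can view them by pointing `--logdir` at the timestamped run folder, or at its parent:
> visualdl --logdir output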

# 4. Eval
Run the eval.py script. It computes the accuracy and loss on the test set, displays the confusion matrix, and saves it as an image.
![ConfusionMatrix](https://user-images.githubusercontent.com/48787805/163796143-8d134aa7-9e51-433b-9da8-61c651f4bb5d.png)



# 5. Testing
To test a single image, run the frame.py script.
For real-time webcam prediction, run the video.py script.

# 6. Debug
1. If you see "BrokenPipeError: [Errno 32] Broken pipe", set the worker count to num_workers=0.
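2. frame.py and video.py expect the Haar cascade at `utils/haarcascade_frontalface_default.xml`. If OpenCV fails to load it, one option (an assumption, not part of this repo) is to use the copy bundled with opencv-python:
```python
import cv2
detection_model_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
face_detection = cv2.CascadeClassifier(detection_model_path)
```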
--------------------------------------------------------------------------------
/frame.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
from utils.Model import mini_XCEPTION

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)


def preprocess_input(x):
    # rescale pixel values from [0, 255] to [-1, 1]
    x = x.astype('float32')
    x = x / 255.0
    x = x - 0.5
    x = x * 2.0
    return torch.tensor(x)


src = cv2.imread("test.jpeg")
img = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)

detection_model_path = 'utils/haarcascade_frontalface_default.xml'
emotion_model_path = 'output/E135_0.6466.pth'
emotion_labels = {0: 'angry', 1: 'disgust', 2: 'fear', 3: 'happy', 4: 'sad', 5: 'surprise', 6: 'neutral'}

face_detection = cv2.CascadeClassifier(detection_model_path)
model = mini_XCEPTION(num_classes=7).to(device)
model.load_state_dict(torch.load(emotion_model_path, map_location=device))
model.eval()  # switch BatchNorm to inference mode

input_size = (48, 48)

faces = face_detection.detectMultiScale(img, scaleFactor=1.1, minNeighbors=8)

with torch.no_grad():
    for face_coordinates in faces:
        x, y, w, h = face_coordinates
        gray_face = img[y:y + h, x:x + w]
        try:
            gray_face = cv2.resize(gray_face, input_size)
        except cv2.error:
            continue
        gray_face = preprocess_input(gray_face)
        inp = torch.unsqueeze(gray_face, 0)
        inp = torch.unsqueeze(inp, 0)  # shape: (1, 1, 48, 48)
        inp = inp.to(device)
        emotion_label_arg = torch.argmax(model(inp)).item()
        emotion_text = emotion_labels[emotion_label_arg]

        print("predict:", emotion_text)
        cv2.rectangle(src, (x, y), (x + w, y + h), (0, 0, 255), 1)

cv2.imshow("", src)
cv2.waitKey(0)
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
from utils.Model import mini_XCEPTION

# Uses OpenCV's Haar cascade face detector, which is not very robust

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)


def preprocess_input(x):
    # rescale pixel values from [0, 255] to [-1, 1]
    x = x.astype('float32')
    x = x / 255.0
    x = x - 0.5
    x = x * 2.0
    return torch.tensor(x)


detection_model_path = 'utils/haarcascade_frontalface_default.xml'
emotion_model_path = 'output/E135_0.6466.pth'
emotion_labels = {0: 'angry', 1: 'disgust', 2: 'fear', 3: 'happy', 4: 'sad', 5: 'surprise', 6: 'neutral'}

face_detection = cv2.CascadeClassifier(detection_model_path)
model = mini_XCEPTION(num_classes=7).to(device)
model.load_state_dict(torch.load(emotion_model_path, map_location=device))
model.eval()  # switch BatchNorm to inference mode

input_size = (48, 48)

cap = cv2.VideoCapture(0)

with torch.no_grad():
    while True:
        ret, src = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
        faces = face_detection.detectMultiScale(frame, scaleFactor=1.3, minNeighbors=5)
        for face_coordinates in faces:
            x, y, w, h = face_coordinates
            # crop the detected face from the current frame
            gray_face = frame[y:y + h, x:x + w]
            try:
                gray_face = cv2.resize(gray_face, input_size)
            except cv2.error:
                continue
            gray_face = preprocess_input(gray_face)
            inp = torch.unsqueeze(gray_face, 0)
            inp = torch.unsqueeze(inp, 0)  # shape: (1, 1, 48, 48)
            inp = inp.to(device)
            emotion_label_arg = torch.argmax(model(inp)).item()
            emotion_text = emotion_labels[emotion_label_arg]

            print("predict:", emotion_text)
            cv2.rectangle(src, (x, y), (x + w, y + h), (0, 0, 255), 1)
        cv2.imshow("", src)
        cv2.waitKey(1)

cap.release()
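
# Note: preprocess_input above maps a uint8 face crop from [0, 255] to [-1, 1]
# (0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0). A quick check of the mapping, kept
# commented out like the smoke test in utils/DataAugment.py:
#   x = np.array([0, 128, 255], dtype=np.uint8)
#   print(((x.astype('float32') / 255.0) - 0.5) * 2.0)  # approx. [-1.  0.0039  1.]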
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
# -*- encoding: utf-8 -*-
"""
@File    : dataset.py
@Contact : thgpddl@163.com

@Modify Time        @Author    @Version    @Description
------------        -------    --------    -----------
2021/12/7 10:16     thgpddl    1.0         None
"""
import os
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms

from .DataAugment import Augment, Salt_Pepper_Noise, Width_Shift_Range, Height_Shift_Range


class FER2013(Dataset):
    def __init__(self, mode, input_size):
        super(FER2013, self).__init__()
        self.data = np.array(pd.read_csv(os.path.join("dataset", mode + ".csv")))
        self.input_size = input_size
        if mode == "train":
            self.aug = Augment([Salt_Pepper_Noise(0.05),
                                Width_Shift_Range(0.1),
                                Height_Shift_Range(0.1)])

            self.transform = transforms.Compose([transforms.ToTensor(),
                                                 transforms.ColorJitter(brightness=0.2),
                                                 transforms.RandomRotation(10),
                                                 transforms.RandomHorizontalFlip(0.5)])
        else:
            self.aug = Augment()
            self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, item):
        label, img, _ = self.data[item]
        data = np.array([int(pix) for pix in img.split()], dtype=np.uint8)
        img = np.reshape(data, (48, 48))
        # use cv2.resize here: np.resize pads with zeros, which may be why
        # accuracy dropped when training at 64x64
        img = cv2.resize(img, self.input_size, interpolation=cv2.INTER_LINEAR)

        img = self.aug(img)        # custom numpy-based augmentations
        img = self.transform(img)  # torchvision augmentations
        # img = (img - 0.5) * 2
        return label, img

    def __len__(self):
        return len(self.data)
--------------------------------------------------------------------------------
/utils/Model.py:
--------------------------------------------------------------------------------
from torch import nn


class SeparableConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, padding, stride=1):
        super(SeparableConv2d, self).__init__()
        # depthwise convolution: a grouped convolution with groups == in_ch
        self.depth_conv = nn.Conv2d(in_channels=in_ch,
                                    out_channels=in_ch,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding,
                                    groups=in_ch)
        # pointwise 1x1 convolution mixes the channels
        self.point_conv = nn.Conv2d(in_channels=in_ch,
                                    out_channels=out_ch,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    groups=1)

    def forward(self, input):
        out = self.depth_conv(input)
        out = self.point_conv(out)
        return out


# residual depthwise separable convolution block
class RDWSC(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(RDWSC, self).__init__()

        self.left = nn.Sequential(SeparableConv2d(input_channels, output_channels, kernel_size=(3, 3), padding=1),
                                  nn.BatchNorm2d(output_channels),
                                  nn.ReLU(),
                                  SeparableConv2d(output_channels, output_channels, kernel_size=(3, 3), padding=1),
                                  nn.BatchNorm2d(output_channels),
                                  nn.MaxPool2d((3, 3), stride=(2, 2), padding=1))

        self.right = nn.Sequential(nn.Conv2d(input_channels, output_channels, (1, 1), stride=(2, 2)),
                                   nn.BatchNorm2d(output_channels))

    def forward(self, x):
        right = self.right(x)
        left = self.left(x)
        output = right + left
        return output


class mini_XCEPTION(nn.Module):
    def __init__(self, num_classes=7):
        super(mini_XCEPTION, self).__init__()

        self.base = nn.Sequential(nn.Conv2d(1, 8, (3, 3), (1, 1)),
                                  nn.BatchNorm2d(8),
                                  nn.ReLU(),
                                  nn.Conv2d(8, 8, (3, 3), stride=(1, 1)),
                                  nn.BatchNorm2d(8),
                                  nn.ReLU())
        self.module1 = RDWSC(input_channels=8, output_channels=16)
        self.module2 = RDWSC(input_channels=16, output_channels=32)
        self.module3 = RDWSC(input_channels=32, output_channels=64)
        self.module4 = RDWSC(input_channels=64, output_channels=128)

        # output head
        self.conv = nn.Conv2d(128, num_classes, kernel_size=(3, 3), padding=1)

    def forward(self, x):
        x = self.base(x)
        x = self.module1(x)
        x = self.module2(x)
        x = self.module3(x)
        x = self.module4(x)
        x = self.conv(x)
        x = x.mean(dim=[-2, -1])  # global average pooling
        return x
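
# A minimal shape sanity check (not part of the original paper code): with a
# 48x48 grayscale input, each RDWSC block roughly halves the spatial size and
# the head emits one logit per class.
if __name__ == "__main__":
    import torch

    net = mini_XCEPTION(num_classes=7)
    out = net(torch.randn(1, 1, 48, 48))
    print(out.shape)  # expected: torch.Size([1, 7])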
--------------------------------------------------------------------------------
/utils/DataAugment.py:
--------------------------------------------------------------------------------
import random

import cv2
import numpy as np


class Augment:
    def __init__(self, augments: list = None):
        # avoid a mutable default argument
        self.augments = augments if augments is not None else []

    def __call__(self, img: np.ndarray) -> np.ndarray:
        for ag in self.augments:
            img = ag(img)
        return img


class Salt_Pepper_Noise:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image: np.ndarray):
        """
        Add salt-and-pepper noise.
        :param image: input image
        :return: image with salt-and-pepper noise at ratio self.prob
        """
        thres = 1 - self.prob
        for i in range(image.shape[0]):
            for j in range(image.shape[1]):
                rdn = np.random.rand()
                if rdn < self.prob:
                    image[i, j] = 0    # pepper
                elif rdn > thres:
                    image[i, j] = 255  # salt
        return image


class Width_Shift_Range:
    def __init__(self, rate):
        super(Width_Shift_Range, self).__init__()
        self.rate = rate

    def __call__(self, img):
        rate = np.random.uniform(0, self.rate)
        if len(img.shape) == 2:
            h, w = img.shape
        else:
            h, w, c = img.shape
        x = int(w * rate)  # shift in pixels
        if np.random.rand() < 0.5:  # randomly shift left or right
            x = -x
        M = np.float32([[1, 0, x], [0, 1, 0]])
        shifted = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        return shifted


class Gaussian_Noise:
    def __init__(self, means, sigma, percentage):
        self.means = means
        self.sigma = sigma
        self.percentage = percentage

    def __call__(self, src):
        noise_img = src
        noise_num = int(self.percentage * src.shape[0] * src.shape[1])
        for i in range(noise_num):
            # pick a random pixel: randX is a random row, randY a random column
            randX = random.randint(0, src.shape[0] - 1)
            randY = random.randint(0, src.shape[1] - 1)
            # add a Gaussian random value to the original grey level, clamping
            # to [0, 255] before writing back into the uint8 array
            value = noise_img[randX, randY] + random.gauss(self.means, self.sigma)
            noise_img[randX, randY] = min(max(value, 0), 255)
        return noise_img


class Height_Shift_Range:
    def __init__(self, rate):
        self.rate = rate

    def __call__(self, img):
        rate = np.random.uniform(0, self.rate)
        if len(img.shape) == 2:
            h, w = img.shape
        else:
            h, w, c = img.shape
        y = int(h * rate)  # shift in pixels
        if np.random.rand() < 0.5:  # randomly shift up or down
            y = -y
        M = np.float32([[1, 0, 0], [0, 1, y]])
        shifted = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        return shifted

# smoke test:
# aug = Augment()
# img = cv2.imread("../1.jpeg", 0)
# res = aug(img)
# cv2.imshow("", res)
# cv2.waitKey(0)
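
# A usage sketch of the training-time pipeline that utils/dataset.py builds
# from these classes; the random array stands in for a 48x48 FER2013 face crop.
if __name__ == "__main__":
    aug = Augment([Salt_Pepper_Noise(0.05),
                   Width_Shift_Range(0.1),
                   Height_Shift_Range(0.1)])
    face = (np.random.rand(48, 48) * 255).astype(np.uint8)
    out = aug(face)
    print(out.shape, out.dtype)  # (48, 48) uint8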
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from utils.Model import mini_XCEPTION
from utils.dataset import FER2013


class DrawConfusionMatrix:
    def __init__(self, labels_name):
        """
        :param labels_name: list of class names
        """
        self.labels_name = labels_name
        self.num_classes = len(labels_name)
        self.matrix = np.zeros((self.num_classes, self.num_classes), dtype="float32")

    def update(self, predicts, labels):
        """
        :param predicts: 1-D prediction vector, e.g. array([0,5,1,6,3,...], dtype=int64)
        :param labels:   1-D ground-truth vector, e.g. array([0,5,0,6,2,...], dtype=int64)
        """
        for predict, label in zip(predicts, labels):
            # rows are ground-truth classes, columns are predictions,
            # matching the axis labels used in draw()
            self.matrix[label, predict] += 1

    def draw(self):
        per_sum = self.matrix.sum(axis=1)  # per-row totals, for normalization
        for i in range(self.num_classes):
            self.matrix[i] = self.matrix[i] / per_sum[i]  # convert to percentages

        plt.imshow(self.matrix, cmap=plt.cm.Blues)  # colored cells only, values added below
        plt.title("Normalized confusion matrix")
        plt.xlabel("Predict label")
        plt.ylabel("Truth label")
        plt.yticks(range(self.num_classes), self.labels_name)
        plt.xticks(range(self.num_classes), self.labels_name, rotation=45)

        for x in range(self.num_classes):
            for y in range(self.num_classes):
                value = float(format('%.2f' % self.matrix[y, x]))  # round for display
                plt.text(x, y, value, verticalalignment='center', horizontalalignment='center')

        plt.tight_layout()  # make the plot fill the figure area
        plt.colorbar()      # add the color bar before saving so it appears in the image
        plt.savefig('./ConfusionMatrix.png', bbox_inches='tight')  # keep labels fully visible
        plt.show()


def evaluate():
    drawconfusionmatrix = DrawConfusionMatrix(labels_name=['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise',
                                                           'neutral'])
    total_test_loss = 0
    total_test_acc = 0
    count = 0
    model.eval()
    with torch.no_grad():
        for index, (labels, imgs) in enumerate(test_loader):
            labels_pd = model(imgs.to(device))
            predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1)
            labels_np = labels.numpy()
            drawconfusionmatrix.update(predict_np, labels_np)
            acc = np.sum(predict_np == labels_np)
            loss = loss_fn(labels_pd, labels.to(device))
            total_test_loss += loss.item() * len(labels)  # weight batch-mean loss by batch size
            total_test_acc += acc
            count += len(labels)

    mean_test_loss = total_test_loss / count
    mean_test_acc = total_test_acc / count
    print("eval\tloss:{:.4f}\tacc:{:.4f}".format(mean_test_loss, mean_test_acc))
    drawconfusionmatrix.draw()


if __name__ == "__main__":
    num_workers = 0  # number of DataLoader workers

    batch_size = 32
    input_size = (48, 48)
    num_classes = 7

    # model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = mini_XCEPTION(num_classes=7)
    model.load_state_dict(torch.load("output/E370_acc_0.6504.pth", map_location=device))
    model.to(device)

    # data loading
    test_dataset = FER2013("test", input_size=input_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # loss
    loss_fn = torch.nn.CrossEntropyLoss()

    # run evaluation
    evaluate()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import datetime
import numpy as np
from visualdl import LogWriter
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

from utils.Model import mini_XCEPTION
from utils.dataset import FER2013

num_epochs = 200
log_step = 100  # interval, in steps, between info printouts
num_workers = 10  # number of DataLoader workers

# output folder, named with the current timestamp
base_path = 'output/{}/'.format(datetime.datetime.now().strftime("%Y-%m-%d-%H.%M.%S"))
writer = LogWriter(logdir=base_path)

batch_size = 32
input_size = (48, 48)
num_classes = 7
patience = 50

if not os.path.exists(base_path):
    os.makedirs(base_path)

# model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = mini_XCEPTION(num_classes=7)
model.to(device)

# data loading
train_dataset = FER2013("train", input_size=input_size)
test_dataset = FER2013("test", input_size=input_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# optimizer, loss and LR schedule
optimizer = torch.optim.Adam(lr=0.001, params=model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='max',
                                                       factor=0.1,
                                                       patience=int(patience / 4),
                                                       verbose=True)


def train_f():
    best_acc = 0
    step = 0
    for Epoch in range(0, num_epochs):
        total_train_loss, total_test_loss = 0, 0
        total_train_acc, total_test_acc = 0, 0
        count = 0
        end_index = len(train_loader) - 1
        model.train()
        for index, (labels, imgs) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels_pd = model(imgs)
            # track acc and loss (running means over the current epoch)
            acc = accuracy_score(np.argmax(labels_pd.cpu().detach().numpy(), axis=-1), labels)
            total_train_acc += acc
            loss = loss_fn(labels_pd, labels.to(device))
            total_train_loss += loss.item()
            count += 1
            # update the weights
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_mean_acc = total_train_acc / count
            epoch_mean_loss = total_train_loss / count

            step += 1
            writer.add_scalar(tag="train_acc", step=step, value=epoch_mean_acc)
            writer.add_scalar(tag="train_loss", step=step, value=epoch_mean_loss)

            if index % log_step == 0 or index == end_index:
                print("e:{}\titer:{}/{}\tloss:{:.4f}\tacc:{:.4f}".format(Epoch, index, end_index,
                                                                         epoch_mean_loss,
                                                                         epoch_mean_acc))
        count = 0
        model.eval()
        with torch.no_grad():
            for index, (labels, imgs) in enumerate(test_loader):
                labels_pd = model(imgs.to(device))
                acc = accuracy_score(np.argmax(labels_pd.cpu().detach().numpy(), axis=-1), labels)
                loss = loss_fn(labels_pd, labels.to(device))
                total_test_loss += loss.item()
                total_test_acc += acc
                count += 1

        mean_test_loss = total_test_loss / count
        mean_test_acc = total_test_acc / count

        scheduler.step(mean_test_acc)
        print("eval\tloss:{:.4f}\tacc:{:.4f}".format(mean_test_loss, mean_test_acc))

        writer.add_scalar(tag="test_acc", step=Epoch, value=mean_test_acc)
        writer.add_scalar(tag="test_loss", step=Epoch, value=mean_test_loss)

        if mean_test_acc > best_acc:
            torch.save(model.state_dict(), "{}/E{}_acc_{:.4f}.pth".format(base_path, Epoch, mean_test_acc))
            best_acc = mean_test_acc
            print("saved best model")


if __name__ == "__main__":
    train_f()
--------------------------------------------------------------------------------