├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── README.md ├── config.py ├── data └── make_train_test.py ├── dataloader.py ├── eval.py ├── model.py ├── scripts ├── eval.sh └── train.sh ├── tools.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/UCF 2 | data/train 3 | data/test 4 | data/train.csv 5 | data/test.csv 6 | 7 | checkpoints 8 | __pycache__ 9 | 10 | *.swp 11 | *.csv 12 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/zhouzhilong/anaconda3/envs/lstm-cnn/bin/python" 3 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Video-Classification 2 | Make video classification on UCF101 using CNN and RNN with Pytorch framework. 3 | 4 | # Environments 5 | ```bash 6 | # 1. torch >= 1.0 7 | conda create -n crnn 8 | source activate crnn # or `conda activate crnn` 9 | # GPU version 10 | conda install pytorch torchvision cudatoolkit=9.0 -c pytorch 11 | # CPU version 12 | conda install pytorch-cpu torchvision-cpu -c pytorch 13 | 14 | # 2. pip dependencies 15 | pip install pandas scikit-learn tqdm opencv-python 16 | 17 | # 3. 
def split(src_dir=default_src_dir, output_dir=default_output_dir, size=default_test_size):
    """Split the raw UCF dataset into train/test image-frame sets.

    Every frame of every video under ``src_dir`` is written as a jpg to
    ``output_dir/{train,test}/<classname>/`` and indexed in ``train.csv`` /
    ``test.csv`` (one ``classname,videoname,img_path`` row per frame).

    Args:
        src_dir: dataset root, one sub-folder per class.
        output_dir: destination for extracted frames and csv index files.
            (Bug fix: the default used to be ``default_src_dir``, contradicting
            the ``None``-fallback below and the argparse default.)
        size: fraction of each class's videos reserved for the test set.
    """
    # Fall back to defaults when a caller explicitly passes None.
    src_dir = default_src_dir if src_dir is None else src_dir
    output_dir = default_output_dir if output_dir is None else output_dir
    size = default_test_size if size is None else size
    # argparse delivers -s as a string; coerce so the arithmetic below works.
    size = float(size)

    # Create the train/test output folders.
    for folder in ['train', 'test']:
        folder_path = os.path.join(output_dir, folder)
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
            print('Folder {} is created'.format(folder_path))

    # Partition each class's videos into train and test sets.
    train_set = []
    test_set = []
    classes = os.listdir(src_dir)
    num_classes = len(classes)
    for class_index, classname in enumerate(classes):
        # All video files of this class.
        videos = os.listdir(os.path.join(src_dir, classname))
        # Shuffle so the split is random per class.
        np.random.shuffle(videos)
        # Number of videos that go to the test set.
        split_size = int(len(videos) * size)

        # Make the per-class folders in both partitions.
        for part in ['train', 'test']:
            class_dir = os.path.join(output_dir, part, classname)
            if not os.path.exists(class_dir):
                os.makedirs(class_dir)

        # Walk every video and dump its frames as jpg files.
        for i in tqdm(range(len(videos)), desc='[%d/%d]%s' % (class_index + 1, num_classes, classname)):
            video_path = os.path.join(src_dir, classname, videos[i])
            video_fd = cv2.VideoCapture(video_path)

            if not video_fd.isOpened():
                print('Skipped: {}'.format(video_path))
                continue

            # The first split_size shuffled videos form the test set.
            # (Bug fix: was `i <= split_size`, which selected split_size+1.)
            video_type = 'test' if i < split_size else 'train'

            frame_index = 0
            success, frame = video_fd.read()
            video_name = videos[i].rsplit('.')[0]
            while success:
                img_path = os.path.join(output_dir, video_type, classname, '%s_%d.jpg' % (video_name, frame_index))
                cv2.imwrite(img_path, frame)
                info = [classname, video_name, img_path]
                # Remember each frame so the csv index can be built later.
                if video_type == 'test':
                    test_set.append(info)
                else:
                    train_set.append(info)
                frame_index += 1
                success, frame = video_fd.read()

            video_fd.release()

    # Persist the index files consumed by the dataloader.
    for name, rows in [('train', train_set), ('test', test_set)]:
        with open(os.path.join(output_dir, name + '.csv'), 'w') as f:
            f.write('\n'.join([','.join(line) for line in rows]))
transform(self, img): 46 | return transforms.Compose([ 47 | transforms.Resize((config.img_w, config.img_h)), 48 | transforms.ToTensor(), 49 | transforms.Normalize( 50 | mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225] 52 | ) 53 | ])(img) 54 | 55 | def _read_img_and_transform(self, img:str): 56 | return self.transform(Image.open(img).convert('RGB')) 57 | 58 | def _build_data_list(self, data_list=[]): 59 | ''' 60 | 构建数据集 61 | ''' 62 | if len(data_list) == 0: 63 | return [] 64 | 65 | data_group = {} 66 | for x in tqdm(data_list, desc='Building dataset'): 67 | # 将视频分别按照classname和videoname分组 68 | [classname, videoname] = x[0:2] 69 | if classname not in data_group: 70 | data_group[classname] = {} 71 | if videoname not in data_group[classname]: 72 | data_group[classname][videoname] = [] 73 | 74 | # 将图片数据加载到内存 75 | if self.use_mem: 76 | self.images.append(self._read_img_and_transform(x[2])) 77 | 78 | data_group[classname][videoname].append(list(x) + [len(self.images) - 1]) 79 | 80 | # 处理类别变量 81 | self.labels = list(data_group.keys()) 82 | 83 | ret_list = [] 84 | n = 0 85 | 86 | # 填充数据 87 | for classname in data_group: 88 | video_group = data_group[classname] 89 | for videoname in video_group: 90 | # 如果某个视频的帧总数没法被time_step整除,那么需要按照最后一帧进行填充 91 | video_pad_count = len(video_group[videoname]) % self.time_step 92 | video_group[videoname] += [video_group[videoname][-1]] * (self.time_step - video_pad_count) 93 | ret_list += video_group[videoname] 94 | n += len(video_group[videoname]) 95 | 96 | return ret_list 97 | 98 | def _label_one_hot(self, label): 99 | ''' 100 | 将标签转换成one-hot形式 101 | ''' 102 | if label not in self.labels: 103 | raise RuntimeError('不存在的label!') 104 | one_hot = [0] * len(self.labels) 105 | one_hot[self.labels.index(label)] = 1 106 | return one_hot 107 | 108 | def _label_category(self, label): 109 | ''' 110 | 将标签转换成整型 111 | ''' 112 | if label not in self.labels: 113 | raise RuntimeError('不存在的label!') 114 | c_label = self.labels.index(label) 115 | 
def _eval(checkpoint: str, video_path: str, labels=[])->list:
    """Inference the model and return the labels.

    Args:
        checkpoint(str): The checkpoint where the model restore from.
        video_path(str): The directory containing the videos to classify.
        labels(list): Ground-truth class names of the videos, aligned with
            the ``os.listdir`` order of ``video_path``.  Optional; used only
            to report accuracy.  (NOTE(review): assumed to be class-name
            strings matching the checkpoint's label_map — confirm.)

    Returns:
        A list of ``[video_name, class_index, class_name]`` triples.
    """
    if not os.path.exists(video_path):
        raise ValueError('Invalid path! which is: {}'.format(video_path))

    print('Loading model from {}'.format(checkpoint))
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Build model
    model = nn.Sequential(
        CNNEncoder(**config.cnn_encoder_params),
        RNNDecoder(**config.rnn_decoder_params)
    )
    model.to(device)
    model.eval()

    # Load model.  map_location lets a GPU-trained checkpoint restore on a
    # CPU-only machine (torch.load would otherwise raise on deserialization).
    ckpt = torch.load(checkpoint, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    print('Model has been loaded from {}'.format(checkpoint))

    label_map = [-1] * config.rnn_decoder_params['num_classes']
    # load label map
    if 'label_map' in ckpt:
        label_map = ckpt['label_map']

    # Do inference
    pred_labels = []
    video_names = os.listdir(video_path)
    with torch.no_grad():
        for video in tqdm(video_names, desc='Inferencing'):
            # read images from video
            images = load_imgs_from_video(os.path.join(video_path, video))
            # apply transform (transform does not touch self, so passing None
            # as the instance is safe, if inelegant)
            images = [Dataset.transform(None, img) for img in images]
            # stack to tensor, batch size = 1
            images = torch.stack(images, dim=0).unsqueeze(0)
            # do inference
            images = images.to(device)
            pred_y = model(images) # type: torch.Tensor
            pred_y = pred_y.argmax(dim=1).cpu().numpy().tolist()
            pred_labels.append([video, pred_y[0], label_map[pred_y[0]]])
            print(pred_labels[-1])

    if len(labels) > 0:
        # Bug fix: compare predicted class names against the ground truth.
        # The old code passed the whole [video, idx, name] triples to
        # accuracy_score, which cannot match plain labels.
        acc = accuracy_score(labels, [pred[2] for pred in pred_labels])
        print('Accuracy: %0.2f' % acc)

    # Save results
    pandas.DataFrame(pred_labels).to_csv('result.csv', index=False)
    print('Results has been saved to {}'.format('result.csv'))

    return pred_labels
class CNNEncoder(nn.Module):
    def __init__(self, cnn_out_dim=256, drop_prob=0.3, bn_momentum=0.01):
        '''
        Frame-level encoder: a pretrained torchvision model with its
        classifier removed, plus an FC head embedding each frame into a
        cnn_out_dim-dimensional vector.
        '''
        super(CNNEncoder, self).__init__()

        self.cnn_out_dim = cnn_out_dim
        self.drop_prob = drop_prob
        self.bn_momentum = bn_momentum

        # Use a pretrained ResNet to extract features, dropping the final
        # classification layer.
        pretrained_cnn = models.resnet152(pretrained=True)
        cnn_layers = list(pretrained_cnn.children())[:-1]

        # The ResNet minus its fc layer serves as the feature extractor.
        self.cnn = nn.Sequential(*cnn_layers)
        # Embed the extracted features into a cnn_out_dim vector.
        self.fc = nn.Sequential(
            *[
                self._build_fc(pretrained_cnn.fc.in_features, 512, True),
                nn.ReLU(),
                self._build_fc(512, 512, True),
                nn.ReLU(),
                nn.Dropout(p=self.drop_prob),
                self._build_fc(512, self.cnn_out_dim, False)
            ]
        )

    def _build_fc(self, in_features, out_features, with_bn=True):
        # Linear layer, optionally followed by BatchNorm1d.
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features, momentum=self.bn_momentum)
        ) if with_bn else nn.Linear(in_features, out_features)

    def forward(self, x_3d):
        '''
        Input is a clip of T frames, shape = (batch_size, t, h, w, 3).
        Returns per-frame embeddings of shape (batch_size, t, cnn_out_dim).
        '''
        cnn_embedding_out = []
        for t in range(x_3d.size(1)):
            # Extract features with the CNN.
            # Why no_grad() here?
            # -- The CNN is a pretrained model; this keeps backprop from the
            #    later (trainable) layers from updating it.
            with torch.no_grad():
                x = self.cnn(x_3d[:, t, :, :, :])
                x = torch.flatten(x, start_dim=1)

            # Run the trainable fc head.
            x = self.fc(x)

            cnn_embedding_out.append(x)

        # Stacked as (t, batch, feat); transpose to (batch, t, feat).
        cnn_embedding_out = torch.stack(cnn_embedding_out, dim=0).transpose(0, 1)

        return cnn_embedding_out

class RNNDecoder(nn.Module):
    def __init__(self, use_gru=True, cnn_out_dim=256, rnn_hidden_layers=3, rnn_hidden_nodes=256,
        num_classes=10, drop_prob=0.3):
        # Sequence classifier on top of the per-frame CNN embeddings.
        super(RNNDecoder, self).__init__()

        self.rnn_input_features = cnn_out_dim
        self.rnn_hidden_layers = rnn_hidden_layers
        self.rnn_hidden_nodes = rnn_hidden_nodes

        self.drop_prob = drop_prob
        self.num_classes = num_classes  # adjust the number of classes here

        # RNN configuration.
        rnn_params = {
            'input_size': self.rnn_input_features,
            'hidden_size': self.rnn_hidden_nodes,
            'num_layers': self.rnn_hidden_layers,
            'batch_first': True
        }

        # Use either an LSTM or a GRU as the recurrent layer.
        self.rnn = (nn.GRU if use_gru else nn.LSTM)(**rnn_params)

        # Linear classifier on the RNN output.
        self.fc = nn.Sequential(
            nn.Linear(self.rnn_hidden_nodes, 128),
            nn.ReLU(),
            nn.Dropout(self.drop_prob),
            nn.Linear(128, self.num_classes)
        )

    def forward(self, x_rnn):
        self.rnn.flatten_parameters()
        rnn_out, _ = self.rnn(x_rnn, None)
        # Because the rnn module was built with batch_first=True:
        # rnn_out shape: (batch, timestep, output_size)
        # h_n and h_c shape: (n_layers, batch, hidden_size)

        x = self.fc(rnn_out[:, -1, :])  # classify from the last timestep only

        return x
/scripts/train.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CUDA_VISIBLE_DEVICES=0 python train.py 4 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas 3 | import argparse 4 | 5 | from dataloader import Dataset 6 | import config 7 | 8 | def merge_labels_to_ckpt(ck_path:str, train_file:str): 9 | '''Merge labels to a checkpoint file. 10 | 11 | Args: 12 | ck_path(str): path to checkpoint file 13 | train_file(str): path to train set index file, eg. train.csv 14 | 15 | Return: 16 | This function will create a {ck_path}_patched.pth file. 17 | ''' 18 | # load model 19 | print('Loading checkpoint') 20 | ckpt = torch.load(ck_path) 21 | 22 | # load train files 23 | print('Loading dataset') 24 | raw_data = pandas.read_csv(train_file) 25 | train_set = Dataset(raw_data.to_numpy()) 26 | 27 | # patch file name 28 | print('Patching') 29 | patch_path = ck_path.replace('.pth', '') + '_patched.pth' 30 | 31 | ck_dict = { 'label_map': train_set.labels } 32 | names = ['epoch', 'model_state_dict', 'optimizer_state_dict'] 33 | for name in names: 34 | ck_dict[name] = ckpt[name] 35 | 36 | torch.save(ck_dict, patch_path) 37 | print('Patched checkpoint has been saved to {}'.format(patch_path)) 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser(usage='python3 tools.py -i path/to/train.csv -r path/to/checkpoint') 41 | parser.add_argument('-i', '--data_path', help='path to your dataset index file') 42 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', default=None) 43 | args = parser.parse_args() 44 | return args 45 | 46 | if __name__ == '__main__': 47 | args = parse_args() 48 | merge_labels_to_ckpt(args.restore_from, args.data_path) 49 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.utils.data import DataLoader 5 | from torchvision import models 6 | from sklearn.metrics import accuracy_score 7 | from tqdm import tqdm 8 | import pandas 9 | import json 10 | import os 11 | import argparse 12 | 13 | from model import CNNEncoder, RNNDecoder 14 | from dataloader import Dataset 15 | import config 16 | 17 | def train_on_epochs(train_loader:DataLoader, test_loader:DataLoader, restore_from:str=None): 18 | # 配置训练时环境 19 | use_cuda = torch.cuda.is_available() 20 | device = torch.device('cuda' if use_cuda else 'cpu') 21 | 22 | # 实例化计算图模型 23 | model = nn.Sequential( 24 | CNNEncoder(**config.cnn_encoder_params), 25 | RNNDecoder(**config.rnn_decoder_params) 26 | ) 27 | model.to(device) 28 | 29 | # 多GPU训练 30 | device_count = torch.cuda.device_count() 31 | if device_count > 1: 32 | print('使用{}个GPU训练'.format(device_count)) 33 | model = nn.DataParallel(model) 34 | 35 | ckpt = {} 36 | # 从断点继续训练 37 | if restore_from is not None: 38 | ckpt = torch.load(restore_from) 39 | model.load_state_dict(ckpt['model_state_dict']) 40 | print('Model is loaded from %s' % (restore_from)) 41 | 42 | # 提取网络参数,准备进行训练 43 | model_params = model.parameters() 44 | 45 | # 设定优化器 46 | optimizer = torch.optim.Adam(model_params, lr=config.learning_rate) 47 | 48 | if restore_from is not None: 49 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 50 | 51 | # 训练时数据 52 | info = { 53 | 'train_losses': [], 54 | 'train_scores': [], 55 | 'test_losses': [], 56 | 'test_scores': [] 57 | } 58 | 59 | start_ep = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0 60 | 61 | save_path = './checkpoints' 62 | if not os.path.exists(save_path): 63 | os.mkdir(save_path) 64 | 65 | # 开始训练 66 | for ep in range(start_ep, config.epoches): 67 | train_losses, train_scores = train(model, train_loader, optimizer, ep, device) 68 | test_loss, 
test_score = validation(model, test_loader, optimizer, ep, device) 69 | 70 | # 保存信息 71 | info['train_losses'].append(train_losses) 72 | info['train_scores'].append(train_scores) 73 | info['test_losses'].append(test_loss) 74 | info['test_scores'].append(test_score) 75 | 76 | # 保存模型 77 | ckpt_path = os.path.join(save_path, 'ep-%d.pth' % ep) 78 | if (ep + 1) % config.save_interval == 0: 79 | torch.save({ 80 | 'epoch': ep, 81 | 'model_state_dict': model.state_dict(), 82 | 'optimizer_state_dict': optimizer.state_dict(), 83 | 'label_map': train_loader.dataset.labels 84 | }, ckpt_path) 85 | print('Model of Epoch %3d has been saved to: %s' % (ep, ckpt_path)) 86 | 87 | with open('./train_info.json', 'w') as f: 88 | json.dump(info, f) 89 | 90 | print('训练结束') 91 | 92 | def load_data_list(file_path): 93 | return pandas.read_csv(file_path).to_numpy() 94 | 95 | def train(model:nn.Sequential, dataloader:torch.utils.data.DataLoader, optimizer:torch.optim.Optimizer, epoch, device): 96 | model.train() 97 | 98 | train_losses = [] 99 | train_scores = [] 100 | 101 | print('Size of Training Set: ', len(dataloader.dataset)) 102 | 103 | for i, (X, y) in enumerate(dataloader): 104 | X = X.to(device) 105 | y = y.to(device) 106 | 107 | # 初始化优化器参数 108 | optimizer.zero_grad() 109 | # 执行前向传播 110 | y_ = model(X) 111 | 112 | # 计算loss 113 | loss = F.cross_entropy(y_, y) 114 | # 反向传播梯度 115 | loss.backward() 116 | optimizer.step() 117 | 118 | y_ = y_.argmax(dim=1) 119 | acc = accuracy_score(y_.cpu().numpy(), y.cpu().numpy()) 120 | 121 | # 保存loss等信息 122 | train_losses.append(loss.item()) 123 | train_scores.append(acc) 124 | 125 | if (i + 1) % config.log_interval == 0: 126 | print('[Epoch %3d]Training %3d of %3d: acc = %.2f, loss = %.2f' % (epoch, i + 1, len(dataloader), acc, loss.item())) 127 | 128 | return train_losses, train_scores 129 | 130 | def validation(model:nn.Sequential, test_loader:torch.utils.data.DataLoader, optimizer:torch.optim.Optimizer, epoch:int, device:int): 131 | model.eval() 132 
| 133 | print('Size of Test Set: ', len(test_loader.dataset)) 134 | 135 | # 准备在测试集上验证模型性能 136 | test_loss = 0 137 | y_gd = [] 138 | y_pred = [] 139 | 140 | # 不需要反向传播,关闭求导 141 | with torch.no_grad(): 142 | for X, y in tqdm(test_loader, desc='Validating'): 143 | # 对测试集中的数据进行预测 144 | X, y = X.to(device), y.to(device) 145 | y_ = model(X) 146 | 147 | # 计算loss 148 | loss = F.cross_entropy(y_, y, reduction='sum') 149 | test_loss += loss.item() 150 | 151 | # 收集prediction和ground truth 152 | y_ = y_.argmax(dim=1) 153 | y_gd += y.cpu().numpy().tolist() 154 | y_pred += y_.cpu().numpy().tolist() 155 | 156 | # 计算loss 157 | test_loss /= len(test_loader) 158 | # 计算正确率 159 | test_acc = accuracy_score(y_gd, y_pred) 160 | 161 | print('[Epoch %3d]Test avg loss: %0.4f, acc: %0.2f\n' % (epoch, test_loss, test_acc)) 162 | 163 | return test_loss, test_acc 164 | 165 | def parse_args(): 166 | parser = argparse.ArgumentParser(usage='python3 train.py -i path/to/data -r path/to/checkpoint') 167 | parser.add_argument('-i', '--data_path', help='path to your datasets', default='./data') 168 | parser.add_argument('-r', '--restore_from', help='path to the checkpoint', default=None) 169 | args = parser.parse_args() 170 | return args 171 | 172 | if __name__ == "__main__": 173 | args = parse_args() 174 | data_path = args.data_path 175 | 176 | # 准备数据加载器 177 | dataloaders = {} 178 | for name in ['train', 'test']: 179 | raw_data = pandas.read_csv(os.path.join(data_path, '%s.csv' % name)) 180 | dataloaders[name] = DataLoader(Dataset(raw_data.to_numpy()), **config.dataset_params) 181 | train_on_epochs(dataloaders['train'], dataloaders['test'], args.restore_from) 182 | --------------------------------------------------------------------------------