├── Framework.jpg ├── README.md ├── test.py ├── tripletloss.py ├── train.py ├── functions.py └── model.py /Framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Marsrocky/GaitFi/HEAD/Framework.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GaitFi: Robust Device-Free Human Identification via WiFi and Vision Multimodal Learning [[link]](https://doi.org/10.48550/arXiv.2208.14326) 2 | 3 | ## Introduction 4 | As an important biomarker for human identification, human gait can be collected at a distance by passive sensors without subject cooperation, which plays an essential role in crime prevention, security detection and other human identification applications. At present, most research performs gait recognition with cameras and computer vision techniques. However, vision-based methods are unreliable under poor illumination, leading to degraded performance. In this paper, we propose a novel multimodal gait recognition method, namely GaitFi, which leverages WiFi signals and videos for human identification. In GaitFi, Channel State Information (CSI) that reflects the multi-path propagation of WiFi is collected to capture human gaits, while videos are captured by cameras. To learn robust gait information, we propose a Lightweight Residual Convolution Network (LRCN) as the backbone network, and further propose the two-stream GaitFi by integrating WiFi and vision features for the gait retrieval task. GaitFi is trained with a triplet loss and a classification loss on different levels of features. Extensive experiments conducted in the real world demonstrate that GaitFi outperforms state-of-the-art gait recognition methods based on a single WiFi or camera modality, achieving 94.2% accuracy on the human identification task over 12 subjects. 
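In code terms, the objective described above is a weighted sum of a classification loss on the logits and a triplet loss on the fused embedding, as implemented in `functions.py` and `train.py`. The sketch below only illustrates that combination; the tensor shapes are placeholders chosen to match the default arguments (64-dim vision embedding + 64-dim WiFi embedding, 12 subjects), and `hidden`/`output` stand in for the two values returned by `CRNN.forward`.

```python
import torch
import torch.nn.functional as F
from tripletloss import TripletLoss

# Illustrative batch: 32 samples, 128-dim fused embeddings, 12 subjects.
hidden = torch.randn(32, 128)          # fused WiFi + vision embedding (first output of CRNN.forward)
output = torch.randn(32, 12)           # classification logits over the 12 subjects
label = torch.randint(0, 12, (32,))    # subject IDs

metric_loss = TripletLoss(margin=0.3)  # margin used in train.py
alpha = 0.001                          # default --alpha in train.py

# Cross-entropy on the logits plus weighted triplet loss on the embeddings,
# the same combination used in functions.train().
loss = F.cross_entropy(output, label) + metric_loss(hidden, label) * alpha
```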
5 | 6 | ![framework](https://github.com/Marsrocky/GaitFi/blob/main/Framework.jpg) 7 | 8 | ## Requirements 9 | 10 | ``` 11 | scipy - 1.5.4 12 | numpy - 1.21.5 13 | torchvision - 0.11.2 14 | pytorch - 1.7.0 15 | ``` 16 | 17 | 18 | 19 | ## Training 20 | Train using vision modality only: `python train.py --input_type image` 21 | 22 | Train using WiFi modality only: `python train.py --input_type mat` 23 | 24 | Train using both vision and WiFi modality: `python train.py --input_type both` 25 | 26 | ## Testing 27 | Copy the model saved to `./save_models/` during training into `./best_models/` (test.py loads `./best_models/crnn_best_<input_type>.pt`) 28 | 29 | Test using vision modality only: `python test.py --input_type image` 30 | 31 | Test using WiFi modality only: `python test.py --input_type mat` 32 | 33 | Test using both vision and WiFi modality: `python test.py --input_type both` 34 | 35 | 36 | 37 | ## Model 38 | 39 | GaitFi has the following components: 40 | 41 | - ***class*** **CNN** : LRCN block 42 | - ***class*** **RNN** : LSTM block 43 | - ***class*** **CRNN**: Fusion mechanism of WiFi and vision modalities 44 | 45 | ## Reference 46 | 47 | ``` 48 | @ARTICLE{9887951, 49 | author={Deng, Lang and Yang, Jianfei and Yuan, Shenghai and Zou, Han and Lu, Chris Xiaoxuan and Xie, Lihua}, 50 | journal={IEEE Internet of Things Journal}, 51 | title={GaitFi: Robust Device-Free Human Identification via WiFi and Vision Multimodal Learning}, 52 | year={2022}, 53 | publisher={IEEE}, 54 | doi={10.1109/JIOT.2022.3203559}} 55 | ``` 56 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.utils.data as data 6 | from functions import labels2cat, Dataset_CRNN, validation, acc_calculate 7 | from sklearn.preprocessing import OneHotEncoder, LabelEncoder 8 | import random 9 | import argparse 10 | import warnings 11 | import json 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | if __name__ == '__main__': 16 | # set parameters 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--train_image_path', type=str, default='./datasets/Train/image') 19 | parser.add_argument('--train_mat_path', type=str, default='./datasets/Train/mat') 20 | parser.add_argument('--test_image_path', type=str, default='./datasets/Test/image') 21 | parser.add_argument('--test_mat_path', type=str, default='./datasets/Test/mat') 22 | parser.add_argument('--img_x', type=int, default=64) 23 | parser.add_argument('--img_y', type=int, default=64) 24 | parser.add_argument('--k', type=int, default=12) 25 | parser.add_argument('--batch_size', type=int, default=32) 26 | parser.add_argument('--n_frames', type=int, default=32) 27 | parser.add_argument('--num_workers', type=int, default=0) 28 | parser.add_argument('--input_type', type=str, default='both', choices=['image', 'mat', 'both']) 29 | parser.add_argument('--seed', type=int, default=233) 30 | args = parser.parse_args() 31 | 32 | args.load_model_path = f'./best_models/crnn_best_{args.input_type}.pt' 33 | 34 | random.seed(args.seed) 35 | np.random.seed(args.seed) 36 | torch.manual_seed(args.seed) 37 | torch.cuda.manual_seed(args.seed) 38 | torch.backends.cudnn.deterministic = True 39 | torch.backends.cudnn.benchmark = False 40 | 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use CPU or GPU 42 | 43 | # convert labels -> category 44 | action_names = os.listdir(args.train_image_path) 45 | 46 | le = 
LabelEncoder() 47 | le.fit(action_names) 48 | 49 | # show how many classes 50 | print('labels:{}'.format(list(le.classes_))) 51 | 52 | # convert category -> 1-hot 53 | action_category = le.transform(action_names).reshape(-1, 1) 54 | enc = OneHotEncoder() 55 | enc.fit(action_category) 56 | 57 | train_actions = [] 58 | train_all_names = [] 59 | test_actions = [] 60 | test_all_names = [] 61 | for action in action_names: 62 | for f_name in os.listdir(f'{args.train_image_path}/{action}'): 63 | train_actions.append(action) 64 | train_all_names.append(f'{action}/{f_name}') 65 | 66 | for f_name in os.listdir(f'{args.test_image_path}/{action}'): 67 | test_actions.append(action) 68 | test_all_names.append(f'{action}/{f_name}') 69 | 70 | train_list = train_all_names 71 | train_label = labels2cat(le, train_actions) 72 | test_list = test_all_names # all video file names 73 | test_label = labels2cat(le, test_actions) # all video labels 74 | 75 | transform = transforms.Compose([transforms.Resize([args.img_x, args.img_y]), transforms.ToTensor(), 76 | transforms.Normalize(mean=[0.5], std=[0.5])]) # 串联多个图片变换 77 | 78 | train_set = Dataset_CRNN(args.train_image_path, args.train_mat_path, 79 | train_list, train_label, args.n_frames, transform=transform, input_type=args.input_type) 80 | test_set = Dataset_CRNN(args.test_image_path, args.test_mat_path, 81 | test_list, test_label, args.n_frames, transform=transform, input_type=args.input_type) 82 | 83 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, 84 | shuffle=False, num_workers=args.num_workers) 85 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, 86 | shuffle=False, num_workers=args.num_workers) 87 | 88 | # Create model 89 | model = torch.load(args.load_model_path) 90 | 91 | gallery_feat, gallery_label, prob_feat, prob_label = validation(model, device, train_loader, test_loader) 92 | 93 | gallery_feat = torch.cat(gallery_feat) 94 | gallery_label = torch.cat(gallery_label) 95 | prob_feat = torch.cat(prob_feat) 96 | prob_label = torch.cat(prob_label) 97 | 98 | test_correct, test_total = acc_calculate(gallery_feat, gallery_label, prob_feat, prob_label) 99 | 100 | with open(f'./outputs/saved_outputs_{args.input_type}.json', 'w', encoding='utf-8') as f: 101 | f.write(json.dumps({ 102 | 'gallery_feat': gallery_feat.detach().cpu().numpy().tolist(), 103 | 'gallery_label': gallery_label.detach().cpu().numpy().tolist(), 104 | 'prob_feat': prob_feat.detach().cpu().numpy().tolist(), 105 | 'prob_label': prob_label.detach().cpu().numpy().tolist() 106 | })) 107 | 108 | test_acc = test_correct / test_total * 100 109 | print('test_acc:{:.3f}%'.format(test_acc)) 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /tripletloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist.addmm_(1, -2, x, y.t()) 29 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 30 | return dist 31 | 32 | 33 | def hard_example_mining(dist_mat, labels, mask=None, return_inds=False): 34 | """For each anchor, find the hardest positive and negative sample. 35 | Args: 36 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 37 | labels: pytorch LongTensor, with shape [N] 38 | mask: pytorch Tensor, with shape [N, N] 39 | return_inds: whether to return the indices. Save time if `False`(?) 40 | Returns: 41 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 42 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 43 | p_inds: pytorch LongTensor, with shape [N]; 44 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 45 | n_inds: pytorch LongTensor, with shape [N]; 46 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 47 | NOTE: Only consider the case in which all labels have same num of samples, 48 | thus we can cope with all anchors in parallel. 49 | """ 50 | 51 | assert len(dist_mat.size()) == 2 52 | assert dist_mat.size(0) == dist_mat.size(1) 53 | N = dist_mat.size(0) 54 | 55 | # shape [N, N] 56 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()).float() 57 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()).float() 58 | 59 | # `dist_ap` means distance(anchor, positive) 60 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 61 | if mask is None: 62 | mask = torch.ones_like(dist_mat) 63 | 64 | aux_mat = torch.zeros_like(dist_mat) 65 | aux_mat[mask==0] -= 10 66 | dist_mat = dist_mat + aux_mat 67 | 68 | dist_ap, relative_p_inds = torch.max( 69 | (dist_mat * is_pos).contiguous().view(N, -1), 1, keepdim=True) 70 | 71 | 72 | 73 | # `dist_an` means distance(anchor, negative) 74 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 75 | # dist_mat[dist_mat == 0] += 10000 # 处理非法值。归一化后的最大距离为2 76 | aux_mat = torch.zeros_like(dist_mat) 77 | aux_mat[mask==0] += 10000 78 | dist_mat = dist_mat + aux_mat 79 | dist_an, relative_n_inds = torch.min( 80 | (dist_mat * is_neg).contiguous().view(N, -1), 1, keepdim=True) 81 | # shape [N] 82 | 83 | 84 | 85 | dist_ap = dist_ap.squeeze(1) 86 | dist_an = dist_an.squeeze(1) 87 | 88 | if return_inds: 89 | # shape [N, N] 90 | ind = (labels.new().resize_as_(labels) 91 | .copy_(torch.arange(0, N).long()) 92 | .unsqueeze(0).expand(N, N)) 93 | # shape [N, 1] 94 | p_inds = torch.gather( 95 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 96 | n_inds = torch.gather( 97 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 98 | # shape [N] 99 | p_inds = p_inds.squeeze(1) 100 | n_inds = n_inds.squeeze(1) 101 | return dist_ap, dist_an, p_inds, n_inds 102 | 103 | return dist_ap, dist_an 104 | 105 | 106 | class TripletLoss(object): 107 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
108 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 109 | Loss for Person Re-Identification'.""" 110 | 111 | def __init__(self, margin=None): 112 | self.margin = margin 113 | if margin is not None: 114 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 115 | else: 116 | self.ranking_loss = nn.SoftMarginLoss() 117 | 118 | def __call__(self, global_feat, labels, mask=None, normalize_feature=False): 119 | """ 120 | 121 | :param global_feat: 122 | :param labels: 123 | :param mask: [N, N] 可见性mask。不可见的mask将不会被选择。若全部不可见,则对结果*0 124 | :param normalize_feature: 125 | :return: 126 | """ 127 | if normalize_feature: 128 | global_feat = normalize(global_feat, axis=-1) 129 | dist_mat = euclidean_dist(global_feat, global_feat) 130 | dist_ap, dist_an = hard_example_mining( 131 | dist_mat, labels, mask=mask) 132 | y = dist_an.new().resize_as_(dist_an).fill_(1) 133 | if self.margin is not None: 134 | loss = self.ranking_loss(dist_an, dist_ap, y) 135 | else: 136 | loss = self.ranking_loss(dist_an - dist_ap, y) 137 | return loss#, dist_ap, dist_an -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.utils.data as data 6 | from functions import labels2cat, Dataset_CRNN, train 7 | from model import CRNN 8 | from sklearn.preprocessing import OneHotEncoder, LabelEncoder 9 | from tripletloss import TripletLoss 10 | import random 11 | import argparse 12 | import warnings 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | if __name__ == '__main__': 17 | # set parameters 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--train_image_path', type=str, default='./datasets/Train/image') 20 | parser.add_argument('--train_mat_path', type=str, default='./datasets/Train/mat') 21 | parser.add_argument('--test_image_path', type=str, default='./datasets/Test/image') 22 | parser.add_argument('--test_mat_path', type=str, default='./datasets/Test/mat') 23 | parser.add_argument('--save_model_path', type=str, default='./save_models/') 24 | parser.add_argument('--CNN_fc_hidden1', type=int, default=64) 25 | parser.add_argument('--CNN_fc_hidden2', type=int, default=64) 26 | parser.add_argument('--CNN_embed_dim', type=int, default=64) 27 | parser.add_argument('--img_x', type=int, default=64) 28 | parser.add_argument('--img_y', type=int, default=64) 29 | parser.add_argument('--dropout_p', type=float, default=0.4) 30 | parser.add_argument('--RNN_hidden_layers', type=int, default=1) 31 | parser.add_argument('--RNN_hidden_nodes', type=int, default=64) 32 | parser.add_argument('--RNN_FC_dim', type=int, default=64) 33 | parser.add_argument('--k', type=int, default=12) 34 | parser.add_argument('--epochs', type=int, default=30) 35 | parser.add_argument('--batch_size', type=int, default=32) 36 | parser.add_argument('--learning_rate', type=float, default=0.001) 37 | parser.add_argument('--alpha', type=float, default=0.001) 38 | parser.add_argument('--n_frames', type=int, default=32) 39 | parser.add_argument('--num_workers', type=int, default=0) 40 | parser.add_argument('--input_type', type=str, default='both', choices=['image', 'mat', 'both']) 41 | parser.add_argument('--seed', type=int, default=233) 42 | args = parser.parse_args() 43 | 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | 
torch.cuda.manual_seed(args.seed) 48 | torch.backends.cudnn.deterministic = True 49 | torch.backends.cudnn.benchmark = False 50 | 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use CPU or GPU 52 | 53 | # convert labels -> category 54 | action_names = os.listdir(args.train_image_path) 55 | 56 | le = LabelEncoder() 57 | le.fit(action_names) 58 | 59 | # show how many classes 60 | print('labels:{}'.format(list(le.classes_))) 61 | 62 | # convert category -> 1-hot 63 | action_category = le.transform(action_names).reshape(-1, 1) 64 | enc = OneHotEncoder() 65 | enc.fit(action_category) 66 | 67 | train_actions = [] 68 | train_all_names = [] 69 | test_actions = [] 70 | test_all_names = [] 71 | for action in action_names: 72 | for f_name in os.listdir(f'{args.train_image_path}/{action}'): 73 | train_actions.append(action) 74 | train_all_names.append(f'{action}/{f_name}') 75 | 76 | for f_name in os.listdir(f'{args.test_image_path}/{action}'): 77 | test_actions.append(action) 78 | test_all_names.append(f'{action}/{f_name}') 79 | 80 | train_list = train_all_names 81 | train_label = labels2cat(le, train_actions) 82 | test_list = test_all_names # all video file names 83 | test_label = labels2cat(le, test_actions) # all video labels 84 | 85 | transform = transforms.Compose([transforms.Resize([args.img_x, args.img_y]), transforms.ToTensor(), 86 | transforms.Normalize(mean=[0.5], std=[0.5])]) # 串联多个图片变换 87 | 88 | train_set = Dataset_CRNN(args.train_image_path, args.train_mat_path, 89 | train_list, train_label, args.n_frames, transform=transform, input_type=args.input_type) 90 | test_set = Dataset_CRNN(args.test_image_path, args.test_mat_path, 91 | test_list, test_label, args.n_frames, transform=transform, input_type=args.input_type) 92 | 93 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, 94 | shuffle=True, num_workers=args.num_workers) 95 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, 96 | shuffle=False, num_workers=args.num_workers) 97 | 98 | # import mat shape 99 | mat_x, mat_y = 114, 500 100 | 101 | # Create model 102 | model = CRNN(args.img_x, args.img_y, mat_x, mat_y, args.CNN_fc_hidden1, 103 | args.CNN_fc_hidden2, args.CNN_embed_dim, args.RNN_hidden_layers, 104 | args.RNN_hidden_nodes, args.RNN_FC_dim, args.dropout_p, args.k, args.input_type).to(device) 105 | 106 | metric_loss = TripletLoss(margin=0.3) 107 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 108 | 109 | # start training 110 | best_valid_acc = 0.0 111 | best_test_acc = 0.0 112 | for epoch in range(args.epochs): 113 | # train, test model 114 | # model.train() 115 | train_loss = train(model, device, train_loader, optimizer, metric_loss, args.alpha) 116 | print('Epoch:{} train_loss:{:.6f}'.format(epoch + 1, train_loss)) 117 | 118 | # save Pytorch models of best record 119 | torch.save(model, os.path.join(args.save_model_path, 120 | 'crnn_best_{}.pt'.format(args.input_type))) # save best model 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils import data 6 | from sklearn.metrics import accuracy_score, f1_score 7 | import torch 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | import scipy.io as scio 11 | 12 | 13 | def labels2cat(label_encoder, list): 14 | return 
label_encoder.transform(list) 15 | 16 | 17 | def labels2onehot(OneHotEncoder, label_encoder, list): 18 | return OneHotEncoder.transform(label_encoder.transform(list).reshape(-1, 1)).toarray() 19 | 20 | 21 | def onehot2labels(label_encoder, y_onehot): 22 | return label_encoder.inverse_transform(np.where(y_onehot == 1)[1]).tolist() 23 | 24 | 25 | def cat2labels(label_encoder, y_cat): 26 | return label_encoder.inverse_transform(y_cat).tolist() 27 | 28 | 29 | class Dataset_CRNN(data.Dataset): 30 | "Characterizes a dataset for PyTorch" 31 | def __init__(self, image_path, mat_path, folders, labels, n_frames, transform=None, input_type='image'): 32 | "Initialization" 33 | self.image_path = image_path 34 | self.mat_path = mat_path 35 | self.labels = labels 36 | self.folders = folders 37 | self.transform = transform 38 | self.n_frames = n_frames 39 | self.input_type = input_type 40 | 41 | def __len__(self): 42 | "Denotes the total number of samples" 43 | return len(self.folders) 44 | 45 | def read_images(self, image_path, selected_folder, use_transform): 46 | names = os.listdir(f'{image_path}/{selected_folder}') 47 | assert len(names) > 0, f'please remove the dir {image_path}/{selected_folder} where exists {len(names)} images.' 48 | 49 | if len(names) > self.n_frames: 50 | names = random.sample(names, self.n_frames) 51 | else: 52 | names += [names[-1]] * (self.n_frames - len(names)) 53 | names = sorted(names, key=lambda info: (int(info[0:-4]), info[-4:])) 54 | 55 | images = [] 56 | for name in names: 57 | image = Image.open(f'{image_path}/{selected_folder}/{name}') 58 | if use_transform is not None: 59 | image = use_transform(image) 60 | images.append(image) 61 | 62 | images = torch.stack(images, dim=0) 63 | return images 64 | 65 | def read_mat(self, mat_path, selected_folder): 66 | mat = scio.loadmat(f'{mat_path}/{selected_folder}.mat')['CSIamp'] 67 | 68 | # normalize 69 | mat = (mat - 42.3199) / 4.9802 70 | 71 | # sampling: 2000 -> 500 72 | mat = mat[:, 100::3] 73 | mat = mat.reshape(3, 114, 500) 74 | 75 | # x = np.expand_dims(x, axis=0) 76 | mat = torch.FloatTensor(mat) 77 | mat = torch.tensor(mat, dtype=torch.float32) 78 | return mat 79 | 80 | def __getitem__(self, index): 81 | # Select sample 82 | folder = self.folders[index] 83 | 84 | # Load data 85 | if self.input_type == 'image': 86 | image = self.read_images(self.image_path, folder, self.transform) # (input) spatial images 87 | mat = torch.tensor(1) 88 | elif self.input_type == 'mat': 89 | image = torch.tensor(1) 90 | mat = self.read_mat(self.mat_path, folder) # (input) spatial mat 91 | elif self.input_type == 'both': 92 | image = self.read_images(self.image_path, folder, self.transform) 93 | mat = self.read_mat(self.mat_path, folder) 94 | 95 | label = torch.LongTensor([self.labels[index]]) 96 | return image, mat, label 97 | 98 | 99 | def train(model, device, train_loader, optimizer, metric_loss, alpha): 100 | model.train() 101 | 102 | # set model as training mode 103 | N_count = 0 # counting total trained sample in one epoch 104 | epoch_loss = 0.0 105 | gallery_feat, gallery_label = [], [] 106 | for batch_idx, (image, mat, label) in enumerate(train_loader): 107 | # distribute data to device 108 | image, mat, label = image.to(device), mat.to(device), label.to(device).view(-1, ) 109 | 110 | N_count += image.size(0) 111 | 112 | optimizer.zero_grad() 113 | hidden, output = model(image, mat) # output has dim = (batch, number of classes) 114 | 115 | loss = F.cross_entropy(output, label) + metric_loss(hidden, label) * alpha 116 | # loss = 
F.cross_entropy(output, label) 117 | # loss = metric_loss(hidden, label) 118 | epoch_loss += loss.item() 119 | 120 | loss.backward() 121 | optimizer.step() 122 | 123 | ave_loss = epoch_loss / N_count 124 | return ave_loss 125 | 126 | 127 | def validation(model, device, train_loader, test_loader): 128 | # set model as testing mode 129 | model.eval() 130 | 131 | gallery_feat, gallery_label = [], [] 132 | prob_feat, prob_label = [], [] 133 | for image, mat, label in train_loader: 134 | # distribute data to device 135 | image, mat = image.to(device), mat.to(device) 136 | hidden, output = model(image, mat) 137 | gallery_feat.append(hidden) 138 | gallery_label.append(label) 139 | 140 | for image, mat, label in test_loader: 141 | # distribute data to device 142 | image, mat = image.to(device), mat.to(device) 143 | hidden, output = model(image, mat) 144 | prob_feat.append(hidden) 145 | prob_label.append(label) 146 | 147 | # return correct, total 148 | return gallery_feat, gallery_label, prob_feat, prob_label 149 | 150 | def acc_calculate(gallery_feat, gallery_label, prob_feat, prob_label): 151 | gallery_feat = gallery_feat 152 | gallery_label = gallery_label.detach().cpu().numpy() 153 | prob_feat = prob_feat 154 | prob_label = prob_label.detach().cpu().numpy() 155 | m, n = prob_feat.shape[0], gallery_feat.shape[0] 156 | dist = torch.pow(prob_feat, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 157 | torch.pow(gallery_feat, 2).sum(dim=1, keepdim=True).expand(n, m).t() 158 | dist.addmm_(1, -2, prob_feat, gallery_feat.t()) 159 | dist = dist.cpu().detach().numpy() 160 | index = dist.argmin(axis=1) 161 | pred = np.array([gallery_label[i] for i in index]) 162 | assert pred.shape == prob_label.shape 163 | total = len(pred) 164 | correct = np.sum((pred == prob_label).astype(np.float)) 165 | return correct, total 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # size 7 | def conv2D_output_size(img_size, padding, kernel_size, stride): 8 | # compute output shape of conv2D 9 | outshape = (np.floor((img_size[0] + 2 * padding[0] - (kernel_size[0] - 1) - 1) / stride[0] + 1).astype(int), 10 | np.floor((img_size[1] + 2 * padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1).astype(int)) 11 | return outshape 12 | 13 | 14 | class CRNN(nn.Module): 15 | def __init__(self, img_x, img_y, mat_x, mat_y, fc_hidden1, fc_hidden2, CNN_embed_dim, 16 | h_RNN_layers, h_RNN, h_FC_dim, drop_p, num_classes, input_type): 17 | super().__init__() 18 | assert input_type in ['mat', 'image', 'both'], 'please choose the right type as: mat, image, both' 19 | self.input_type = input_type 20 | if input_type == 'image' or input_type == 'both': 21 | self.image_CNN = CNN(img_x, img_y, 1, fc_hidden1, fc_hidden2, drop_p, CNN_embed_dim, input_type='image', fc_in_dim=256) 22 | self.image_RNN = RNN(CNN_embed_dim, h_RNN_layers, h_RNN, h_FC_dim, drop_p, num_classes) #h_RNN, h_FC_dim 23 | 24 | if input_type == 'mat' or input_type == 'both': 25 | self.mat_CNN = CNN(mat_x, mat_y, 3, fc_hidden1, fc_hidden2, drop_p, CNN_embed_dim, input_type='mat', fc_in_dim=512) 26 | 27 | self.fc = nn.Linear(2 * h_FC_dim if input_type == 'both' else h_FC_dim, num_classes) 28 | 29 | def forward(self, image, mat): 30 | if self.input_type == 'image' or self.input_type == 'both': 31 | cnn_emb = 
self.image_CNN(image) 32 | rnn_emb = self.image_RNN(cnn_emb) 33 | if self.input_type == 'mat' or self.input_type == 'both': 34 | mat_emb = self.mat_CNN(mat) 35 | 36 | if self.input_type == 'both': 37 | # concatenate rnn_emb with mat 38 | hidden = torch.cat((rnn_emb, mat_emb), dim=1) 39 | # hidden = rnn_emb + mat_emb 40 | elif self.input_type == 'image': 41 | hidden = rnn_emb 42 | elif self.input_type == 'mat': 43 | hidden = mat_emb 44 | output = self.fc(hidden) 45 | return hidden, output 46 | 47 | #一个残差模块 48 | class Block(nn.Module): 49 | def __init__(self, in_channel, out_channel, strides=1, same_shape=True): 50 | super(Block, self).__init__() 51 | self.same_shape = same_shape 52 | if not same_shape: 53 | strides = 2 54 | self.strides = strides 55 | self.block = nn.Sequential( 56 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1, bias=False), 57 | nn.BatchNorm2d(out_channel), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False), 60 | nn.BatchNorm2d(out_channel) 61 | ) 62 | if not same_shape: 63 | self.conv3 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, bias=False) 64 | self.bn3 = nn.BatchNorm2d(out_channel) 65 | 66 | def forward(self, x): 67 | out = self.block(x) 68 | if not self.same_shape: 69 | x = self.bn3(self.conv3(x)) 70 | return F.relu(out + x) 71 | 72 | 73 | class CNN(nn.Module): 74 | def __init__(self, high, wide, in_channel, fc_hidden1, fc_hidden2, drop_p, CNN_embed_dim, input_type, fc_in_dim): #CNN_embed_dim参数设置中为64 75 | super().__init__() 76 | self.high = high 77 | self.wide = wide 78 | self.input_type = input_type 79 | self.CNN_embed_dim = CNN_embed_dim 80 | 81 | # CNN architechtures 82 | self.ch1, self.ch2, self.ch3 = 8, 16, 32 83 | self.k1, self.k2, self.k3 = (3, 3), (3, 3), (3, 3) # 2d kernal size 84 | self.s1, self.s2, self.s3 = (2, 2), (2, 2), (2, 2) # 2d strides 85 | self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0) # 2d padding 86 | 87 | # conv2D output shapes 88 | self.conv1_outshape = conv2D_output_size((self.high, self.wide), self.pd1, self.k1, 89 | self.s1) # Conv1 output shape 90 | self.conv2_outshape = conv2D_output_size(self.conv1_outshape, self.pd2, self.k2, self.s2) 91 | self.conv3_outshape = conv2D_output_size(self.conv2_outshape, self.pd3, self.k3, self.s3) 92 | 93 | # fully connected layer hidden nodes 94 | self.fc_hidden1, self.fc_hidden2 = fc_hidden1, fc_hidden2 95 | self.drop_p = drop_p 96 | 97 | self.conv1 = nn.Sequential( 98 | nn.Conv2d(in_channels=in_channel, out_channels=self.ch1, kernel_size=self.k1, stride=self.s1, 99 | padding=self.pd1), 100 | nn.BatchNorm2d(self.ch1, momentum=0.01), 101 | nn.ReLU(inplace=True), 102 | # 103 | ) 104 | self.layer1 = self._make_layer(self.ch1, self.ch1, 2, stride=2) # res 105 | 106 | self.conv2 = nn.Sequential( 107 | nn.Conv2d(in_channels=self.ch1, out_channels=self.ch2, kernel_size=self.k2, stride=self.s2, 108 | padding=self.pd2), 109 | nn.BatchNorm2d(self.ch2, momentum=0.01), 110 | nn.ReLU(inplace=True), 111 | 112 | ) 113 | self.layer2 = self._make_layer(self.ch2, self.ch2, 2, stride=2) # res 114 | 115 | self.conv3 = nn.Sequential( 116 | nn.Conv2d(in_channels=self.ch2, out_channels=self.ch3, kernel_size=self.k3, stride=self.s3, 117 | padding=self.pd3), 118 | nn.BatchNorm2d(self.ch3, momentum=0.01), 119 | nn.ReLU(inplace=True), 120 | ) 121 | self.layer3 = self._make_layer(self.ch3, self.ch3, 2, stride=2) # res 122 | 123 | self.fc1 = nn.Linear(fc_in_dim, self.CNN_embed_dim) 124 | 125 | def 
_make_layer(self, in_channel, out_channel, block_num, stride=1): 126 | layers = [] 127 | if stride != 1: 128 | layers.append(Block(in_channel, out_channel, stride, same_shape=False)) 129 | else: 130 | layers.append(Block(in_channel, out_channel, stride)) 131 | 132 | for i in range(1, block_num): 133 | layers.append(Block(out_channel, out_channel)) 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x_2d): 137 | if self.input_type == 'image': 138 | cnn_embed_seq = [] 139 | for t in range(x_2d.size(1)): 140 | # CNNs 141 | x = x_2d[:, t, :, :, :] 142 | x = self.conv1(x) 143 | x = self.layer1(x) 144 | x = self.conv2(x) 145 | x = self.layer2(x) 146 | x = x.view(x.size(0), -1) # flatten the output of conv 147 | x = F.dropout(x, p=self.drop_p, training=self.training) 148 | 149 | # FC layers 150 | x = F.relu(self.fc1(x)) 151 | cnn_embed_seq.append(x) 152 | 153 | # swap time and sample dim such that (sample dim, time dim, CNN latent dim) 154 | output = torch.stack(cnn_embed_seq, dim=0).transpose_(0, 1) 155 | else: 156 | x = self.conv1(x_2d) 157 | x = self.layer1(x) 158 | x = self.conv2(x) 159 | x = self.layer2(x) 160 | x = self.conv3(x) 161 | x = self.layer3(x) 162 | x = x.view(x.size(0), -1) # flatten the output of conv 163 | x = F.dropout(x, p=self.drop_p, training=self.training) 164 | 165 | # FC layers 166 | output = F.relu(self.fc1(x)) 167 | return output 168 | 169 | 170 | class RNN(nn.Module): 171 | def __init__(self, CNN_embed_dim, h_RNN_layers, h_RNN, h_FC_dim, drop_p, num_classes): 172 | super().__init__() 173 | 174 | self.RNN_input_size = CNN_embed_dim 175 | self.h_RNN_layers = h_RNN_layers # RNN hidden layers 176 | self.h_RNN = h_RNN # RNN hidden nodes 177 | self.h_FC_dim = h_FC_dim 178 | self.drop_p = drop_p 179 | self.num_classes = num_classes 180 | 181 | self.LSTM = nn.LSTM( 182 | input_size=self.RNN_input_size, 183 | hidden_size=self.h_RNN, 184 | num_layers=h_RNN_layers, 185 | batch_first=True 186 | ) 187 | 188 | self.fc1 = nn.Linear(self.h_RNN, self.h_FC_dim) 189 | 190 | def forward(self, x_RNN): 191 | rnn_out, (_, _) = self.LSTM(x_RNN, None) 192 | 193 | # FC layers 194 | x = self.fc1(rnn_out[:, -1, :]) # choose RNN_out at the last time step 195 | x = F.relu(x) 196 | x = F.dropout(x, p=self.drop_p, training=self.training) 197 | return x 198 | 199 | 200 | 201 | --------------------------------------------------------------------------------
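A note on data layout: `train.py`, `test.py` and `Dataset_CRNN` assume per-subject folders of numerically named image frames plus one CSI `.mat` file (containing a `CSIamp` variable) per sample. The snippet below is an optional sanity check, not part of the repository; the directory names in the comments are inferred from the loaders above.

```python
import os

# Inferred layout (from train.py / functions.Dataset_CRNN):
#   datasets/Train/image/<subject>/<sample>/<idx>.png   frames (read_images sorts on the integer prefix)
#   datasets/Train/mat/<subject>/<sample>.mat           CSI file with a 'CSIamp' variable
# and the same structure under datasets/Test/.

def check_split(image_root, mat_root):
    """Return the CSI .mat paths that are missing for existing image sample folders."""
    missing = []
    for subject in sorted(os.listdir(image_root)):
        for sample in sorted(os.listdir(os.path.join(image_root, subject))):
            mat_file = os.path.join(mat_root, subject, f'{sample}.mat')
            if not os.path.isfile(mat_file):
                missing.append(mat_file)
    return missing

if __name__ == '__main__':
    for split in ('Train', 'Test'):
        missing = check_split(f'./datasets/{split}/image', f'./datasets/{split}/mat')
        print(f'{split}: {len(missing)} image samples without a matching CSI file')
```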
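The `mask` argument of `TripletLoss.__call__` is documented only by a Chinese comment in the source; it is an [N, N] visibility mask, and pairs marked 0 are pushed out of both the hard-positive and the hard-negative search inside `hard_example_mining` (by the -10 / +10000 offsets). A minimal usage sketch with made-up embeddings and labels:

```python
import torch
from tripletloss import TripletLoss

# Toy batch: 8 embeddings, 2 samples per subject.
feat = torch.randn(8, 128)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

# Optional [N, N] visibility mask: entries set to 0 are excluded from hard-example mining.
mask = torch.ones(8, 8)
mask[0, 5] = 0   # e.g. ignore the pair (sample 0, sample 5)
mask[5, 0] = 0

criterion = TripletLoss(margin=0.3)                 # same margin as train.py
loss = criterion(feat, labels, mask=mask, normalize_feature=True)
print(loss.item())
```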
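The hard-coded `fc_in_dim` values in `CRNN` (256 for the image branch, 512 for the CSI branch) follow from the default input sizes: 64×64 frames and 3×114×500 CSI tensors. They can be checked by tracing the shapes with `conv2D_output_size`; the sketch below is only that arithmetic, and relies on the fact that within each `layerN` only the first residual block down-samples while the second preserves the shape.

```python
from model import conv2D_output_size

def branch_flat_dim(h, w, n_stages, out_channels):
    """Trace (h, w) through n_stages of [convN: k=3, s=2, p=0] + [first Block of layerN: k=3, s=2, p=1]."""
    for _ in range(n_stages):
        h, w = conv2D_output_size((h, w), (0, 0), (3, 3), (2, 2))  # plain conv, no padding
        h, w = conv2D_output_size((h, w), (1, 1), (3, 3), (2, 2))  # down-sampling residual block
    return out_channels * h * w

# Image branch: conv1/layer1 and conv2/layer2 only, 16 output channels -> 256 (fc_in_dim=256).
print(branch_flat_dim(64, 64, n_stages=2, out_channels=16))
# CSI branch: conv1..conv3 with layer1..layer3, 32 output channels -> 512 (fc_in_dim=512).
print(branch_flat_dim(114, 500, n_stages=3, out_channels=32))
```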
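Both `tripletloss.euclidean_dist` and `functions.acc_calculate` use the old positional form `dist.addmm_(1, -2, a, b.t())`, and `acc_calculate` casts with `np.float`. These run under the versions pinned in the README (pytorch 1.7.0, numpy 1.21.5) but rely on long-deprecated conventions that newer releases reject (`np.float` was removed in NumPy 1.24). If you move to a newer stack, an equivalent pairwise-distance computation is sketched below; this is a rewrite for illustration, not code from the repository.

```python
import torch

def euclidean_dist(x, y):
    """Pairwise Euclidean distances between x [m, d] and y [n, d], without the deprecated addmm_ overload."""
    xx = x.pow(2).sum(dim=1, keepdim=True)       # [m, 1]
    yy = y.pow(2).sum(dim=1, keepdim=True).t()   # [1, n]
    dist = xx + yy - 2.0 * (x @ y.t())           # same quantity as dist.addmm_(1, -2, x, y.t())
    return dist.clamp(min=1e-12).sqrt()          # clamp for numerical stability, as in tripletloss.py

# In acc_calculate, `(pred == prob_label).astype(np.float)` can simply use the builtin
# float (or np.float64) on newer NumPy versions.
```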
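At test time the pipeline in `test.py` is gallery/probe retrieval: training-set embeddings form the gallery, test-set embeddings are the probes, and each probe takes the label of its nearest gallery embedding. The condensed sketch below assumes a model saved by `train.py` and data loaders built exactly as in `test.py`; it also wraps feature extraction in `torch.no_grad()`, which the original `validation()` omits but which avoids building a graph during evaluation.

```python
import torch

def evaluate(model_path, train_loader, test_loader):
    """Gallery/probe retrieval accuracy (in %), mirroring test.py's validation + acc_calculate."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load(model_path, map_location=device)   # whole pickled model, as saved by train.py
    model.eval()

    @torch.no_grad()
    def extract(loader):
        feats, labels = [], []
        for image, mat, label in loader:
            hidden, _ = model(image.to(device), mat.to(device))
            feats.append(hidden.cpu())
            labels.append(label.view(-1))
        return torch.cat(feats), torch.cat(labels)

    gallery_feat, gallery_label = extract(train_loader)   # gallery = training-set embeddings
    probe_feat, probe_label = extract(test_loader)        # probes  = test-set embeddings

    dist = torch.cdist(probe_feat, gallery_feat)           # [n_probe, n_gallery] pairwise distances
    pred = gallery_label[dist.argmin(dim=1)]               # label of the closest gallery embedding
    return (pred == probe_label).float().mean().item() * 100

# Usage (loaders built as in test.py):
#   acc = evaluate('./best_models/crnn_best_both.pt', train_loader, test_loader)
#   print(f'test_acc:{acc:.3f}%')
```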