├── network.png
├── README.md
├── LICENSE
├── train_STVEN
├── train_rPPGNet
├── models
│   ├── STVEN.py
│   └── rPPGNet.py
└── train_STVEN_rPPGNet

/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZitongYu/STVEN_rPPGNet/HEAD/network.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# STVEN_rPPGNet
Main code of the **ICCV 2019 paper "Remote Heart Rate Measurement from Highly Compressed Facial Videos: an End-to-end Deep Learning Solution with Video Enhancement"** [[.pdf]](https://arxiv.org/pdf/1907.11921.pdf)

Note that the specific **dataloader, data preprocessing and postprocessing** must be implemented by users for their particular datasets.

![image](https://github.com/ZitongYu/STVEN_rPPGNet/blob/master/network.png)

The code is for **research purposes only**; commercial use is not allowed.

Citation
-------
If you use STVEN or rPPGNet, please cite:

>@inproceedings{yu2019remote,
>    title={Remote Heart Rate Measurement from Highly Compressed Facial Videos: an End-to-end Deep Learning Solution with Video Enhancement},
>    author={Yu*, Zitong and Peng*, Wei and Li, Xiaobai and Hong, Xiaopeng and Zhao, Guoying},
>    booktitle= {International Conference on Computer Vision (ICCV)},
>    year = {2019}
>}
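
Data preparation (illustration)
-------
The training snippets in this repo expect facial clips of shape [3, 64, 128, 128], per-frame binary skin masks of shape [64, 64, 64] and a 64-sample smoothed ECG signal per clip. As a rough, unofficial sketch (the class name and storage layout are placeholders, not part of the original code), a PyTorch dataset could look like this:

```python
import torch
from torch.utils.data import Dataset

class FaceClipDataset(Dataset):
    """Illustrative only: yields the (clip, skin_mask, ecg) tuples used by the snippets below."""
    def __init__(self, clips, skin_masks, ecg_signals):
        # clips:       arrays of shape [3, 64, 128, 128]  (RGB, 64 frames)
        # skin_masks:  arrays of shape [64, 64, 64]        (binary skin labels)
        # ecg_signals: arrays of shape [64]                (smoothed ECG per clip)
        self.clips, self.skin_masks, self.ecg_signals = clips, skin_masks, ecg_signals

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        return (torch.as_tensor(self.clips[idx], dtype=torch.float32),
                torch.as_tensor(self.skin_masks[idx], dtype=torch.float32),
                torch.as_tensor(self.ecg_signals[idx], dtype=torch.float32))
```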

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Fisher Yu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/train_STVEN:
--------------------------------------------------------------------------------
### it is just for research purposes, and commercial use is not allowed ###

import torch
import torch.nn as nn
from models.STVEN import *

''' ###############################################################
#
#   Step 1: loss functions for STVEN
#           1.1  torch.mean(torch.abs(x_reconst - video_y_GT))   -- L1 reconstruction loss
#           1.2  psnr()                                           -- PSNR (L2) reconstruction loss
#           1.3  torch.mean(torch.abs(x_reconst - video_x_GT))   -- cycle reconstruction loss
#
''' ###############################################################

def psnr(img, img_g):
    criterionMSE = nn.MSELoss()
    mse = criterionMSE(img, img_g)
    psnr = 10 * torch.log10(1. / (mse + 10e-8))    # small constant avoids log(0)
    return psnr


''' ###############################################################
#
#   Step 2: forward the model and calculate the losses
#           input 1: facial frames (highly compressed)  --> [3, 64, 128, 128]
#           input 2: target label mask                  --> 5D vector
#           video_GroundTruth: the original video (before heavy compression)
#
#           2.1 Forward the model: translate the compressed video towards the target video
#           2.2 Calculate the L1 reconstruction loss
#           2.3 Calculate the PSNR loss
#           2.4 Calculate the cycle losses
#
''' ###############################################################

# video_1, target_label1, original_label1 and video_GroundTruth are placeholders
# that should be provided by your own dataloader.
model = STVEN_Generator()

x_reconst = model(video_1, target_label1)

L1_loss   = torch.mean(torch.abs(x_reconst - video_GroundTruth))
Loss_PSNR = psnr(x_reconst, video_GroundTruth)

x_fake = model(x_reconst, original_label1)

L1_loss_cycle   = torch.mean(torch.abs(x_fake - video_1))
Loss_PSNR_cycle = psnr(x_fake, video_1)


''' ###############################################################
#
#   Step 3: loss fusion and backpropagation
#
''' ###############################################################

loss = 100*L1_loss + Loss_PSNR + 100*L1_loss_cycle + Loss_PSNR_cycle

loss.backward()
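
# --- Example (not in the original snippet): one complete optimization step wrapping the
# --- losses above. This is only a hedged sketch; the optimizer choice and learning rate
# --- are illustrative values, not the paper's reported settings.
def stven_train_step(model, optimizer, video_1, target_label1, original_label1, video_GroundTruth):
    optimizer.zero_grad()
    x_reconst = model(video_1, target_label1)          # compressed video -> target-quality video
    x_fake = model(x_reconst, original_label1)         # cycle back towards the source label
    loss = (100 * torch.mean(torch.abs(x_reconst - video_GroundTruth)) + psnr(x_reconst, video_GroundTruth)
            + 100 * torch.mean(torch.abs(x_fake - video_1)) + psnr(x_fake, video_1))
    loss.backward()
    optimizer.step()
    return loss.item()

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)    # illustrative hyper-parameters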

--------------------------------------------------------------------------------
/train_rPPGNet:
--------------------------------------------------------------------------------
### it is just for research purposes, and commercial use is not allowed ###

import torch
import torch.nn as nn
from models.rPPGNet import *

''' ###############################################################
#
#   Step 1: two loss functions
#           1.1  nn.BCELoss()    for skin segmentation
#           1.2  Neg_Pearson()   for rPPG signal regression
#
''' ###############################################################

class Neg_Pearson(nn.Module):    # Pearson correlation is in [-1, 1]; the loss is 1 - Pearson
    def __init__(self):
        super(Neg_Pearson, self).__init__()

    def forward(self, preds, labels):
        loss = 0
        for i in range(preds.shape[0]):                   # loop over the batch
            sum_x = torch.sum(preds[i])                   # x
            sum_y = torch.sum(labels[i])                  # y
            sum_xy = torch.sum(preds[i] * labels[i])      # xy
            sum_x2 = torch.sum(torch.pow(preds[i], 2))    # x^2
            sum_y2 = torch.sum(torch.pow(labels[i], 2))   # y^2
            N = preds.shape[1]
            pearson = (N*sum_xy - sum_x*sum_y) / (torch.sqrt((N*sum_x2 - torch.pow(sum_x, 2)) * (N*sum_y2 - torch.pow(sum_y, 2))))

            # alternative kept from the original code (commented out):
            # use |pearson| when the correlation is negative
            #if (pearson>=0).data.cpu().numpy():    # torch.cuda.ByteTensor --> numpy
            #    loss += 1 - pearson
            #else:
            #    loss += 1 - torch.abs(pearson)

            loss += 1 - pearson

        loss = loss / preds.shape[0]
        return loss


criterion_Binary  = nn.BCELoss()      # binary skin segmentation
criterion_Pearson = Neg_Pearson()     # rPPG signal


''' ###############################################################
#
#   Step 2: forward the model and calculate the losses
#           inputs:         facial frames               --> [3, 64, 128, 128]
#           skin_seg_label: binary skin labels          --> [64, 64, 64]
#           ecg:            ground-truth smoothed ECG   --> [64]
#
#           2.1 Forward the model to get the predicted skin maps and rPPG signals
#           2.2 Loss between the predicted skin maps and the binary skin labels (loss_binary)
#           2.3 Losses between the predicted rPPG signals and the ground-truth smoothed ECG
#               (loss_ecg, loss_ecg1, loss_ecg2, loss_ecg3, loss_ecg4, loss_ecg_aux)
#
''' ###############################################################

# inputs, skin_seg_label and ecg are placeholders provided by your own dataloader.
model = rPPGNet()

skin_map, rPPG_aux, rPPG, rPPG_SA1, rPPG_SA2, rPPG_SA3, rPPG_SA4, x_visual6464, x_visual3232 = model(inputs)

loss_binary = criterion_Binary(skin_map, skin_seg_label)

# zero-mean, unit-variance normalization before the Pearson loss
rPPG     = (rPPG - torch.mean(rPPG)) / torch.std(rPPG)
rPPG_SA1 = (rPPG_SA1 - torch.mean(rPPG_SA1)) / torch.std(rPPG_SA1)
rPPG_SA2 = (rPPG_SA2 - torch.mean(rPPG_SA2)) / torch.std(rPPG_SA2)
rPPG_SA3 = (rPPG_SA3 - torch.mean(rPPG_SA3)) / torch.std(rPPG_SA3)
rPPG_SA4 = (rPPG_SA4 - torch.mean(rPPG_SA4)) / torch.std(rPPG_SA4)
rPPG_aux = (rPPG_aux - torch.mean(rPPG_aux)) / torch.std(rPPG_aux)

loss_ecg     = criterion_Pearson(rPPG, ecg)
loss_ecg1    = criterion_Pearson(rPPG_SA1, ecg)
loss_ecg2    = criterion_Pearson(rPPG_SA2, ecg)
loss_ecg3    = criterion_Pearson(rPPG_SA3, ecg)
loss_ecg4    = criterion_Pearson(rPPG_SA4, ecg)
loss_ecg_aux = criterion_Pearson(rPPG_aux, ecg)


''' ###############################################################
#
#   Step 3: loss fusion and backpropagation
#
''' ###############################################################

loss = 0.1*loss_binary + 0.5*(loss_ecg1 + loss_ecg2 + loss_ecg3 + loss_ecg4 + loss_ecg_aux) + loss_ecg

loss.backward()
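
# --- A quick sanity check (not in the original snippet) of the Neg_Pearson loss:
# --- a perfectly correlated prediction gives ~0 and an anti-correlated one gives ~2.
# t = torch.linspace(0, 6.28, 64).unsqueeze(0)     # one clip of 64 frames
# sig = torch.sin(5 * t)
# print(Neg_Pearson()(sig, sig).item())            # ~0.0
# print(Neg_Pearson()(-sig, sig).item())           # ~2.0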

--------------------------------------------------------------------------------
/models/STVEN.py:
--------------------------------------------------------------------------------
### it is just for research purposes, and commercial use is not allowed ###

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import math
from torch.nn.modules.utils import _triple


class SpatioTemporalConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(SpatioTemporalConv, self).__init__()

        # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)

        # decomposing the parameters into spatial and temporal components by
        # masking out the values with the defaults on the axis that
        # won't be convolved over. This is necessary to avoid unintentional
        # behavior such as padding being added twice
        spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
        spatial_stride = [1, stride[1], stride[2]]
        spatial_padding = [0, padding[1], padding[2]]

        temporal_kernel_size = [kernel_size[0], 1, 1]
        temporal_stride = [stride[0], 1, 1]
        temporal_padding = [padding[0], 0, 0]

        # compute the number of intermediary channels (M) using the formula
        # from the paper, section 3.5
        intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels)/(kernel_size[1]* kernel_size[2] * in_channels + kernel_size[0] * out_channels)))
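
        # For example (illustrative note, not in the original file): with a 3x3x3 kernel
        # and 64 -> 64 channels, the formula above gives
        #     M = floor(3*3*3*64*64 / (3*3*64 + 3*64)) = floor(110592 / 768) = 144
        # intermediary channels, so the (2+1)D block keeps roughly the same parameter
        # budget as the full 3D convolution it replaces.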

        # the spatial conv is effectively a 2D conv due to the
        # spatial_kernel_size, followed by batch_norm and ReLU
        self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
                                      stride=spatial_stride, padding=spatial_padding, bias=bias)
        self.bn = nn.BatchNorm3d(intermed_channels)
        self.relu = nn.ReLU()    ## nn.Tanh() or nn.ReLU(inplace=True)

        # the temporal conv is effectively a 1D conv, but has batch norm
        # and ReLU added inside the model constructor, not here. This is an
        # intentional design choice, to allow this module to externally act
        # identical to a standard Conv3D, so it can be reused easily in any
        # other codebase
        self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size,
                                       stride=temporal_stride, padding=temporal_padding, bias=bias)

    def forward(self, x):
        x = self.relu(self.bn(self.spatial_conv(x)))
        x = self.temporal_conv(x)
        return x


class STVEN_Generator(nn.Module):
    """Generator network."""
    def __init__(self, conv_dim=64, c_dim=5, repeat_num=4):
        super(STVEN_Generator, self).__init__()

        layers = []
        n_input_channel = 3 + c_dim    # image + label channels
        layers.append(nn.Conv3d(n_input_channel, conv_dim, kernel_size=(3,7,7), stride=(1,1,1), padding=(1,3,3), bias=False))
        layers.append(nn.InstanceNorm3d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim

        layers.append(nn.Conv3d(curr_dim, curr_dim*2, kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1), bias=False))
        layers.append(nn.InstanceNorm3d(curr_dim*2, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))
        curr_dim = curr_dim * 2

        layers.append(nn.Conv3d(curr_dim, curr_dim*4, kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1), bias=False))
        layers.append(nn.InstanceNorm3d(curr_dim*4, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))
        curr_dim = curr_dim * 4

        # Bottleneck layers.
        for i in range(repeat_num):
            layers.append(SpatioTemporalConv(curr_dim, curr_dim, [3, 3, 3], stride=(1,1,1), padding=[1,1,1]))

        # Up-sampling layers.
        layers2 = []
        layers2.append(nn.ConvTranspose3d(curr_dim, curr_dim//4, kernel_size=(4,4,4), stride=(2,2,2), padding=(1,1,1), bias=False))
        layers2.append(nn.InstanceNorm3d(curr_dim//4, affine=True, track_running_stats=True))
        layers2.append(nn.ReLU(inplace=True))
        curr_dim = curr_dim // 4

        layers3 = []
        layers3.append(nn.ConvTranspose3d(curr_dim, curr_dim//2, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1), bias=False))
        layers3.append(nn.InstanceNorm3d(curr_dim//2, affine=True, track_running_stats=True))
        layers3.append(nn.ReLU(inplace=True))
        curr_dim = curr_dim // 2

        layers4 = []
        layers4.append(nn.Conv3d(curr_dim, 3, kernel_size=(1,7,7), stride=(1,1,1), padding=(0,3,3), bias=False))
        layers4.append(nn.Tanh())

        self.down3Dmain = nn.Sequential(*layers)

        self.layers2 = nn.Sequential(*layers2)
        self.layers3 = nn.Sequential(*layers3)
        self.layers4 = nn.Sequential(*layers4)

    def forward(self, x, c):
        # Replicate the label spatially and concatenate the domain information to the video.
        c = c.view(c.size(0), c.size(1), 1, 1, 1)
        c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3), x.size(4))

        x0 = torch.cat([x, c], dim=1)
        x0 = self.down3Dmain(x0)

        x1 = self.layers2(x0)
        x2 = self.layers3(x1)
        x3 = self.layers4(x2)

        out = x3 + x    # residual connection

        return out
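
# --- A hedged usage sketch (not in the original file): STVEN takes a video clip plus a
# --- c_dim-dimensional target label and returns an enhanced clip of the same shape,
# --- thanks to the residual connection in forward(). A small clip is used here only to
# --- keep the check light; the paper's setting is [3, 64, 128, 128] frames.
if __name__ == '__main__':
    gen = STVEN_Generator(c_dim=5)
    clip = torch.randn(1, 3, 16, 64, 64)        # [B, C, T, H, W]
    label = torch.zeros(1, 5)
    label[:, 0] = 1                              # illustrative one-hot target label
    with torch.no_grad():
        out = gen(clip, label)
    print(out.shape)                             # torch.Size([1, 3, 16, 64, 64])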

--------------------------------------------------------------------------------
/train_STVEN_rPPGNet:
--------------------------------------------------------------------------------
### it is just for research purposes, and commercial use is not allowed ###

import torch
import torch.nn as nn
from models.rPPGNet import *
from models.STVEN import *

''' ###############################################################
#
#   Step 1: two loss functions for STVEN
#           1.1  torch.mean(torch.abs(x_reconst - video_y_GT))   -- L1 reconstruction loss
#           1.2  psnr()                                           -- PSNR (L2) reconstruction loss
#
#           two loss functions for rPPGNet
#           1.3  nn.BCELoss()    for skin segmentation
#           1.4  Neg_Pearson()   for rPPG signal regression
#
''' ###############################################################

def psnr(img, img_g):
    criterionMSE = nn.MSELoss()
    mse = criterionMSE(img, img_g)
    psnr = 10 * torch.log10(1. / (mse + 10e-8))    # small constant avoids log(0)
    return psnr


class Neg_Pearson(nn.Module):    # Pearson correlation is in [-1, 1]; the loss is 1 - Pearson
    def __init__(self):
        super(Neg_Pearson, self).__init__()

    def forward(self, preds, labels):
        loss = 0
        for i in range(preds.shape[0]):                   # loop over the batch
            sum_x = torch.sum(preds[i])                   # x
            sum_y = torch.sum(labels[i])                  # y
            sum_xy = torch.sum(preds[i] * labels[i])      # xy
            sum_x2 = torch.sum(torch.pow(preds[i], 2))    # x^2
            sum_y2 = torch.sum(torch.pow(labels[i], 2))   # y^2
            N = preds.shape[1]
            pearson = (N*sum_xy - sum_x*sum_y) / (torch.sqrt((N*sum_x2 - torch.pow(sum_x, 2)) * (N*sum_y2 - torch.pow(sum_y, 2))))

            # alternative kept from the original code (commented out):
            # use |pearson| when the correlation is negative
            #if (pearson>=0).data.cpu().numpy():    # torch.cuda.ByteTensor --> numpy
            #    loss += 1 - pearson
            #else:
            #    loss += 1 - torch.abs(pearson)

            loss += 1 - pearson

        loss = loss / preds.shape[0]
        return loss


criterion_Binary  = nn.BCELoss()      # binary skin segmentation
criterion_Pearson = Neg_Pearson()     # rPPG signal
criterion_reg     = nn.MSELoss()      # regression loss for the perceptual terms (MSE is an illustrative choice; the original snippet leaves it undefined)


''' ###############################################################
#
#   Step 2: forward the models and calculate the losses
#           input 1: facial frames (highly compressed)  --> [3, 64, 128, 128]
#           input 2: target label mask                  --> 5D vector
#
#           skin_seg_label:    binary skin labels          --> [64, 64, 64]
#           ecg:               ground-truth smoothed ECG   --> [64]
#           video_GroundTruth: the original video (before heavy compression)
#
#           2.1 Forward STVEN to generate the enhanced video; calculate the reconstruction and PSNR losses
#           2.2 Get the predicted skin maps and rPPG signals from the enhanced video; calculate the skin and rPPG losses
#           2.3 Get the predicted skin maps and rPPG signals from the original video; calculate the perceptual losses
#
''' ###############################################################

model_rPPGNet        = rPPGNet()           # load the pretrained model after "train_rPPGNet"; weights are fixed (not updated)
model_fixOri_rPPGNet = rPPGNet()           # load the pretrained model after "train_rPPGNet"; weights are fixed (not updated)
model_STVEN          = STVEN_Generator()   # load the pretrained model after "train_STVEN"; weights are updated via BP


######## Loss_STVEN ###########
x_reconst = model_STVEN(inputs, target_label1)

L1_loss    = torch.mean(torch.abs(x_reconst - video_GroundTruth))
Loss_PSNR  = psnr(x_reconst, video_GroundTruth)
Loss_STVEN = L1_loss + Loss_PSNR


######## Loss_rPPGNet ###########
skin_map, rPPG_aux, rPPG, rPPG_SA1, rPPG_SA2, rPPG_SA3, rPPG_SA4, x_visual6464, x_visual3232 = model_rPPGNet(x_reconst)

loss_binary = criterion_Binary(skin_map, skin_seg_label)

# zero-mean, unit-variance normalization before the Pearson loss
rPPG     = (rPPG - torch.mean(rPPG)) / torch.std(rPPG)
rPPG_SA1 = (rPPG_SA1 - torch.mean(rPPG_SA1)) / torch.std(rPPG_SA1)
rPPG_SA2 = (rPPG_SA2 - torch.mean(rPPG_SA2)) / torch.std(rPPG_SA2)
rPPG_SA3 = (rPPG_SA3 - torch.mean(rPPG_SA3)) / torch.std(rPPG_SA3)
rPPG_SA4 = (rPPG_SA4 - torch.mean(rPPG_SA4)) / torch.std(rPPG_SA4)
rPPG_aux = (rPPG_aux - torch.mean(rPPG_aux)) / torch.std(rPPG_aux)

loss_ecg     = criterion_Pearson(rPPG, ecg)
loss_ecg1    = criterion_Pearson(rPPG_SA1, ecg)
loss_ecg2    = criterion_Pearson(rPPG_SA2, ecg)
loss_ecg3    = criterion_Pearson(rPPG_SA3, ecg)
loss_ecg4    = criterion_Pearson(rPPG_SA4, ecg)
loss_ecg_aux = criterion_Pearson(rPPG_aux, ecg)


######## Loss_Perceptual ###########
with torch.no_grad():
    skin_map_GT, rPPG_aux_GT, rPPG_GT, rPPG_SA1_GT, rPPG_SA2_GT, rPPG_SA3_GT, rPPG_SA4_GT, x_visual6464_GT, x_visual3232_GT = model_fixOri_rPPGNet(video_GroundTruth)
    rPPG_GT     = (rPPG_GT - torch.mean(rPPG_GT)) / torch.std(rPPG_GT)
    rPPG_SA1_GT = (rPPG_SA1_GT - torch.mean(rPPG_SA1_GT)) / torch.std(rPPG_SA1_GT)
    rPPG_SA2_GT = (rPPG_SA2_GT - torch.mean(rPPG_SA2_GT)) / torch.std(rPPG_SA2_GT)
    rPPG_SA3_GT = (rPPG_SA3_GT - torch.mean(rPPG_SA3_GT)) / torch.std(rPPG_SA3_GT)
    rPPG_SA4_GT = (rPPG_SA4_GT - torch.mean(rPPG_SA4_GT)) / torch.std(rPPG_SA4_GT)

loss_visual6464_Per = criterion_reg(x_visual6464, x_visual6464_GT)
loss_visual3232_Per = criterion_reg(x_visual3232, x_visual3232_GT)
loss_ecg_Per    = criterion_Pearson(rPPG, rPPG_GT)
loss_ecg1_Per   = criterion_Pearson(rPPG_SA1, rPPG_SA1_GT)
loss_ecg2_Per   = criterion_Pearson(rPPG_SA2, rPPG_SA2_GT)
loss_ecg3_Per   = criterion_Pearson(rPPG_SA3, rPPG_SA3_GT)
loss_ecg4_Per   = criterion_Pearson(rPPG_SA4, rPPG_SA4_GT)
loss_binary_Per = criterion_reg(skin_map, skin_map_GT)


''' ###############################################################
#
#   Step 3: loss fusion and backpropagation
#           only STVEN is updated; rPPGNet stays fixed
#
''' ###############################################################

loss_PhysNet    = 0.1*loss_binary + 0.5*(loss_ecg1 + loss_ecg2 + loss_ecg3 + loss_ecg4 + loss_ecg_aux) + loss_ecg
loss_Perceptual = 0.1*loss_binary_Per + 0.5*(loss_ecg1_Per + loss_ecg2_Per + loss_ecg3_Per + loss_ecg4_Per) + loss_ecg_Per + loss_visual6464_Per + loss_visual3232_Per
loss = loss_PhysNet + loss_Perceptual + 0.0001*Loss_STVEN

loss.backward()

# optimizer_model_STVEN.step()    # only update STVEN, fix rPPGNet
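
# --- A hedged sketch (not in the original snippet) of the "update STVEN only" setup that the
# --- comment above refers to: freeze both rPPGNet copies and give the optimizer only the
# --- STVEN parameters. The optimizer type and learning rate are illustrative.
# for p in model_rPPGNet.parameters():
#     p.requires_grad = False
# for p in model_fixOri_rPPGNet.parameters():
#     p.requires_grad = False
# optimizer_model_STVEN = torch.optim.Adam(model_STVEN.parameters(), lr=1e-4)
# optimizer_model_STVEN.zero_grad()
# loss.backward()
# optimizer_model_STVEN.step()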

--------------------------------------------------------------------------------
/models/rPPGNet.py:
--------------------------------------------------------------------------------
import math
import torch.nn as nn
from torch.nn.modules.utils import _triple
import pdb
import torch


class SpatioTemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(SpatioTemporalConv, self).__init__()

        # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)

        # decomposing the parameters into spatial and temporal components by
        # masking out the values with the defaults on the axis that
        # won't be convolved over. This is necessary to avoid unintentional
        # behavior such as padding being added twice
        spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
        spatial_stride = [1, stride[1], stride[2]]
        spatial_padding = [0, padding[1], padding[2]]

        temporal_kernel_size = [kernel_size[0], 1, 1]
        temporal_stride = [stride[0], 1, 1]
        temporal_padding = [padding[0], 0, 0]

        # compute the number of intermediary channels (M) using the formula
        # from the paper, section 3.5
        intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels)/(kernel_size[1]* kernel_size[2] * in_channels + kernel_size[0] * out_channels)))

        # self-definition
        #intermed_channels = int((in_channels+intermed_channels)/2)

        # the spatial conv is effectively a 2D conv due to the
        # spatial_kernel_size, followed by batch_norm and ReLU
        self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
                                      stride=spatial_stride, padding=spatial_padding, bias=bias)
        self.bn = nn.BatchNorm3d(intermed_channels)
        self.relu = nn.ReLU()    ## nn.Tanh() or nn.ReLU(inplace=True)

        # the temporal conv is effectively a 1D conv, but has batch norm
        # and ReLU added inside the model constructor, not here. This is an
        # intentional design choice, to allow this module to externally act
        # identical to a standard Conv3D, so it can be reused easily in any
        # other codebase
        self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size,
                                       stride=temporal_stride, padding=temporal_padding, bias=bias)

    def forward(self, x):
        x = self.relu(self.bn(self.spatial_conv(x)))
        x = self.temporal_conv(x)
        return x

class MixA_Module(nn.Module):
    """ Spatial-Skin attention module"""
    def __init__(self):
        super(MixA_Module, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.AVGpool = nn.AdaptiveAvgPool1d(1)
        self.MAXpool = nn.AdaptiveMaxPool1d(1)

    def forward(self, x, skin):
        """
            inputs :
                x    : input feature maps  (B x C x T x W x H)
                skin : skin confidence maps (B x T x W x H)
            returns :
                out : attention-weighted feature maps
                spatial attention : per-frame W x H attention weights
        """
        m_batchsize, C, T, W, H = x.size()
        B_C_TWH = x.view(m_batchsize, C, -1)
        B_TWH_C = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        B_TWH_C_AVG = torch.sigmoid(self.AVGpool(B_TWH_C)).view(m_batchsize, T, W, H)
        B_TWH_C_MAX = torch.sigmoid(self.MAXpool(B_TWH_C)).view(m_batchsize, T, W, H)
        B_TWH_C_Fusion = B_TWH_C_AVG + B_TWH_C_MAX + skin
        Attention_weight = self.softmax(B_TWH_C_Fusion.view(m_batchsize, T, -1))
        Attention_weight = Attention_weight.view(m_batchsize, T, W, H)
        # multiply every channel by the attention mask
        output = x.clone()
        for i in range(C):
            output[:, i, :, :, :] = output[:, i, :, :, :].clone() * Attention_weight

        return output, Attention_weight
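
# --- A small illustration (not in the original file): for each (batch, frame) pair the
# --- attention weights produced by MixA_Module sum to 1 over the W x H spatial positions.
# m = MixA_Module()
# feat = torch.randn(2, 64, 8, 16, 16)     # [B, C, T, W, H] feature maps
# skin = torch.rand(2, 8, 16, 16)          # [B, T, W, H] skin confidence
# _, att = m(feat, skin)
# print(att.sum(dim=(2, 3)))               # ~1.0 everywhere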

# for open-source
# skin segmentation + PhysNet + MixA3232 + MixA1616part4
class rPPGNet(nn.Module):
    def __init__(self, frames=64):
        super(rPPGNet, self).__init__()

        self.ConvSpa1 = nn.Sequential(
            nn.Conv3d(3, 16, [1,5,5], stride=1, padding=[0,2,2]),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
        )

        self.ConvSpa3 = nn.Sequential(
            SpatioTemporalConv(16, 32, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
        )
        self.ConvSpa4 = nn.Sequential(
            SpatioTemporalConv(32, 32, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
        )

        self.ConvSpa5 = nn.Sequential(
            SpatioTemporalConv(32, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvSpa6 = nn.Sequential(
            SpatioTemporalConv(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvSpa7 = nn.Sequential(
            SpatioTemporalConv(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvSpa8 = nn.Sequential(
            SpatioTemporalConv(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
        self.ConvSpa9 = nn.Sequential(
            SpatioTemporalConv(64, 64, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )

        self.ConvSpa10 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)
        self.ConvSpa11 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)
        self.ConvPart1 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)
        self.ConvPart2 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)
        self.ConvPart3 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)
        self.ConvPart4 = nn.Conv3d(64, 1, [1,1,1], stride=1, padding=0)

        self.AvgpoolSpa = nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2))
        self.AvgpoolSkin_down = nn.AvgPool2d((2,2), stride=2)
        self.AvgpoolSpaTem = nn.AvgPool3d((2, 2, 2), stride=2)

        self.ConvSpa = nn.Conv3d(3, 16, [1,3,3], stride=1, padding=[0,1,1])

        # global average pooling of the output
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))    # attention to this value

        # skin branch
        self.skin_main = nn.Sequential(
            nn.Conv3d(32, 16, [1,3,3], stride=1, padding=[0,1,1]),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 8, [1,3,3], stride=1, padding=[0,1,1]),
            nn.BatchNorm3d(8),
            nn.ReLU(inplace=True),
        )

        self.skin_residual = nn.Sequential(
            nn.Conv3d(32, 8, [1,1,1], stride=1, padding=0),
            nn.BatchNorm3d(8),
            nn.ReLU(inplace=True),
        )

        self.skin_output = nn.Sequential(
            nn.Conv3d(8, 1, [1,3,3], stride=1, padding=[0,1,1]),
            nn.Sigmoid(),    ## binary
        )

        self.MixA_Module = MixA_Module()

    def forward(self, x):                                # x [3, 64, 128, 128]
        x_visual = x

        x = self.ConvSpa1(x)                             # x [16, 64, 128, 128]
        x = self.AvgpoolSpa(x)                           # x [16, 64, 64, 64]

        x = self.ConvSpa3(x)                             # x [32, 64, 64, 64]
        x_visual6464 = self.ConvSpa4(x)                  # x [32, 64, 64, 64]
        x = self.AvgpoolSpa(x_visual6464)                # x [32, 64, 32, 32]

        ## branch 1: skin segmentation
        x_skin_main = self.skin_main(x_visual6464)                   # x [8, 64, 64, 64]
        x_skin_residual = self.skin_residual(x_visual6464)           # x [8, 64, 64, 64]
        x_skin = self.skin_output(x_skin_main + x_skin_residual)     # x [1, 64, 64, 64]
        x_skin = x_skin[:, 0, :, :, :]                               # x [64, 64, 64]

        ## branch 2: rPPG
        x = self.ConvSpa5(x)                             # x [64, 64, 32, 32]
        x_visual3232 = self.ConvSpa6(x)                  # x [64, 64, 32, 32]
        x = self.AvgpoolSpa(x_visual3232)                # x [64, 64, 16, 16]

        x = self.ConvSpa7(x)                             # x [64, 64, 16, 16]
        x = self.ConvSpa8(x)                             # x [64, 64, 16, 16]
        x_visual1616 = self.ConvSpa9(x)                  # x [64, 64, 16, 16]

        ## SkinA1_loss
        x_skin3232 = self.AvgpoolSkin_down(x_skin)                   # x [64, 32, 32]
        x_visual3232_SA1, Attention3232 = self.MixA_Module(x_visual3232, x_skin3232)
        x_visual3232_SA1 = self.poolspa(x_visual3232_SA1)            # x [64, 64, 1, 1]
        ecg_SA1 = self.ConvSpa10(x_visual3232_SA1).squeeze(1).squeeze(-1).squeeze(-1)

        ## SkinA2_loss
        x_skin1616 = self.AvgpoolSkin_down(x_skin3232)               # x [64, 16, 16]
        x_visual1616_SA2, Attention1616 = self.MixA_Module(x_visual1616, x_skin1616)
        ## Global
        global_F = self.poolspa(x_visual1616_SA2)                    # x [64, 64, 1, 1]
        ecg_global = self.ConvSpa11(global_F).squeeze(1).squeeze(-1).squeeze(-1)

        ## Local
        Part1 = x_visual1616_SA2[:, :, :, :8, :8]
        Part1 = self.poolspa(Part1)                                  # x [64, 64, 1, 1]
        ecg_part1 = self.ConvSpa11(Part1).squeeze(1).squeeze(-1).squeeze(-1)

        Part2 = x_visual1616_SA2[:, :, :, 8:16, :8]
        Part2 = self.poolspa(Part2)                                  # x [64, 64, 1, 1]
        ecg_part2 = self.ConvPart2(Part2).squeeze(1).squeeze(-1).squeeze(-1)

        Part3 = x_visual1616_SA2[:, :, :, :8, 8:16]
        Part3 = self.poolspa(Part3)                                  # x [64, 64, 1, 1]
        ecg_part3 = self.ConvPart3(Part3).squeeze(1).squeeze(-1).squeeze(-1)

        Part4 = x_visual1616_SA2[:, :, :, 8:16, 8:16]
        Part4 = self.poolspa(Part4)                                  # x [64, 64, 1, 1]
        ecg_part4 = self.ConvPart4(Part4).squeeze(1).squeeze(-1).squeeze(-1)

        return x_skin, ecg_SA1, ecg_global, ecg_part1, ecg_part2, ecg_part3, ecg_part4, x_visual6464, x_visual3232
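
# --- A hedged smoke test (not in the original file): run a random clip through the network
# --- and check the output shapes described in the comments above.
if __name__ == '__main__':
    net = rPPGNet(frames=64)
    clip = torch.randn(1, 3, 64, 128, 128)    # [B, C, T, H, W]
    with torch.no_grad():
        skin, ecg_sa1, ecg_global, p1, p2, p3, p4, feat6464, feat3232 = net(clip)
    print(skin.shape)          # torch.Size([1, 64, 64, 64])  per-frame skin maps
    print(ecg_global.shape)    # torch.Size([1, 64])          one rPPG value per frame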

--------------------------------------------------------------------------------