├── .gitignore ├── FCSN.py ├── FCSN_ENC.py ├── README.md ├── SD.py ├── SK.py ├── imgs ├── architecture_FCSN.PNG ├── architecture_VSLUD.PNG └── loss.PNG ├── sample_FCSN.py ├── test.py ├── train.py └── training_set_preparation ├── FeatureExtractor.py ├── extract_frame.sh ├── paper_training_set_preparation.py └── self_training_set_preparation.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | .vscode/ 3 | runs/ 4 | __pycache__/ 5 | saved_models/ 6 | nohup.out 7 | loss_record.tar 8 | -------------------------------------------------------------------------------- /FCSN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FCSN(nn.Module): 5 | def __init__(self, n_class=2): 6 | super(FCSN, self).__init__() 7 | 8 | # conv1 input shape (batch_size, Channel, H, W) -> (1,1024,1,T) 9 | self.conv1_1 = nn.Conv2d(1024, 64, (1,3), padding=(0,100)) 10 | self.sn1_1 = nn.utils.spectral_norm(self.conv1_1) 11 | self.bn1_1 = nn.BatchNorm2d(64) 12 | self.relu1_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 13 | self.conv1_2 = nn.Conv2d(64, 64, (1,3), padding=(0,1)) 14 | self.sn1_2 = nn.utils.spectral_norm(self.conv1_2) 15 | self.bn1_2 = nn.BatchNorm2d(64) 16 | self.relu1_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 17 | self.pool1 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/2 18 | 19 | # conv2 20 | self.conv2_1 = nn.Conv2d(64, 128, (1,3), padding=(0,1)) 21 | self.sn2_1 = nn.utils.spectral_norm(self.conv2_1) 22 | self.bn2_1 = nn.BatchNorm2d(128) 23 | self.relu2_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 24 | self.conv2_2 = nn.Conv2d(128, 128, (1,3), padding=(0,1)) 25 | self.sn2_2 = nn.utils.spectral_norm(self.conv2_2) 26 | self.bn2_2 = nn.BatchNorm2d(128) 27 | self.relu2_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 28 | self.pool2 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/4 29 | 30 | # conv3 31 | self.conv3_1 = nn.Conv2d(128, 256, (1,3), padding=(0,1)) 32 | self.sn3_1 = nn.utils.spectral_norm(self.conv3_1) 33 | self.bn3_1 = nn.BatchNorm2d(256) 34 | self.relu3_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 35 | self.conv3_2 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 36 | self.sn3_2 = nn.utils.spectral_norm(self.conv3_2) 37 | self.bn3_2 = nn.BatchNorm2d(256) 38 | self.relu3_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 39 | self.conv3_3 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 40 | self.sn3_3 = nn.utils.spectral_norm(self.conv3_3) 41 | self.bn3_3 = nn.BatchNorm2d(256) 42 | self.relu3_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 43 | self.pool3 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/8 44 | 45 | # conv4 46 | self.conv4_1 = nn.Conv2d(256, 512, (1,3), padding=(0,1)) 47 | self.sn4_1 = nn.utils.spectral_norm(self.conv4_1) 48 | self.bn4_1 = nn.BatchNorm2d(512) 49 | self.relu4_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 50 | self.conv4_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 51 | self.sn4_2 = nn.utils.spectral_norm(self.conv4_2) 52 | self.bn4_2 = nn.BatchNorm2d(512) 53 | self.relu4_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 54 | self.conv4_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 55 | self.sn4_3 = nn.utils.spectral_norm(self.conv4_3) 56 | self.bn4_3 = nn.BatchNorm2d(512) 57 | self.relu4_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 58 | self.pool4 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/16 59 | 60 | # conv5 61 | self.conv5_1 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 62 | self.sn5_1 = 
nn.utils.spectral_norm(self.conv5_1) 63 | self.bn5_1 = nn.BatchNorm2d(512) 64 | self.relu5_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 65 | self.conv5_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 66 | self.sn5_2 = nn.utils.spectral_norm(self.conv5_2) 67 | self.bn5_2 = nn.BatchNorm2d(512) 68 | self.relu5_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 69 | self.conv5_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 70 | self.sn5_3 = nn.utils.spectral_norm(self.conv5_3) 71 | self.bn5_3 = nn.BatchNorm2d(512) 72 | self.relu5_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 73 | self.pool5 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/32 74 | 75 | # fc6 76 | self.fc6 = nn.Conv2d(512, 4096, (1,7)) 77 | self.sn6 = nn.utils.spectral_norm(self.fc6) 78 | self.in6 = nn.InstanceNorm2d(4096) 79 | self.relu6 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 80 | self.drop6 = nn.Dropout2d(p=0.5) 81 | 82 | # fc7 83 | self.fc7 = nn.Conv2d(4096, 4096, (1,1)) 84 | self.sn7 = nn.utils.spectral_norm(self.fc7) 85 | self.in7 = nn.InstanceNorm2d(4096) 86 | self.relu7 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 87 | self.drop7 = nn.Dropout2d(p=0.5) 88 | 89 | self.score_fr = nn.Conv2d(4096, n_class, (1,1)) 90 | self.sn_score_fr = nn.utils.spectral_norm(self.score_fr) 91 | self.bn_score_fr = nn.BatchNorm2d(n_class) 92 | self.in_score_fr = nn.InstanceNorm2d(n_class) 93 | self.relu_score_fr = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 94 | self.score_pool4 = nn.Conv2d(512, n_class, (1,1)) 95 | self.sn_score_pool4 = nn.utils.spectral_norm(self.score_pool4) 96 | self.bn_score_pool4 = nn.BatchNorm2d(n_class) 97 | self.relu_bn_score_pool4 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 98 | 99 | self.upscore2 = nn.ConvTranspose2d( 100 | n_class, n_class, (1,4), stride=(1,2)) 101 | self.sn_upscore2 = nn.utils.spectral_norm(self.upscore2) 102 | self.bn_upscore2 = nn.BatchNorm2d(n_class) 103 | self.relu_upscore2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 104 | 105 | self.upscore16 = nn.ConvTranspose2d( 106 | n_class, n_class, (1,32), stride=(1,16)) 107 | self.sn_upscore16 = nn.utils.spectral_norm(self.upscore16) 108 | self.bn_upscore16 = nn.BatchNorm2d(n_class) 109 | self.relu_upscore16 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 110 | self.sigmoid_upscore16 = nn.Sigmoid() 111 | self.tanh_upscore16 = nn.Tanh() 112 | 113 | self.relu_add = nn.ReLU()#nn.LeakyReLU(0.2) 114 | 115 | def forward(self, x): 116 | # input 117 | h = x 118 | # conv1 119 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) #;print(h.shape) 120 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) #;print(h.shape) 121 | h = self.pool1(h) #;print(h.shape) 122 | # conv2 123 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) #;print(h.shape) 124 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) #;print(h.shape) 125 | h = self.pool2(h) #;print(h.shape) 126 | # conv3 127 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) #;print(h.shape) 128 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) #;print(h.shape) 129 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) #;print(h.shape) 130 | h = self.pool3(h) #;print(h.shape) 131 | # conv4 132 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) #;print(h.shape) 133 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) #;print(h.shape) 134 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) #;print(h.shape) 135 | h = self.pool4(h) #;print(h.shape) 136 | pool4 = h 137 | # conv5 138 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) #;print(h.shape) 139 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) #;print(h.shape) 140 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 
#;print(h.shape) 141 | h = self.pool5(h) #;print(h.shape) 142 | # conv6 143 | h = self.relu6(self.fc6(h)) #;print(h.shape) 144 | h = self.drop6(h) #;print(h.shape) 145 | # conv7 146 | h = self.relu7(self.fc7(h)) #;print(h.shape) 147 | h = self.drop7(h) #;print(h.shape) 148 | # conv8 149 | h = self.in_score_fr(self.score_fr(h)) # original should be bn_score_fr, in order to handle the one frame input i.e. [1,1024,1,1] input 150 | # deconv1 151 | h = self.upscore2(h) 152 | upscore2 = h 153 | # get score_pool4c to do skip connection 154 | h = self.bn_score_pool4(self.score_pool4(pool4)) 155 | h = h[:, :, :, 5:5+upscore2.size()[3]] 156 | score_pool4c = h 157 | # skip connection 158 | h = upscore2+score_pool4c 159 | # deconv2 160 | h = self.upscore16(h) 161 | h = h[:, :, :, 27:27+x.size()[3]] 162 | 163 | return h 164 | 165 | if __name__ == '__main__': 166 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 167 | 168 | model = FCSN(n_class=2) 169 | model.to(device) 170 | #model.eval() 171 | data = torch.randn(1, 1024, 1, 5, requires_grad=True).to(device) # [1,1024,1,input_frames] 172 | out = model(data) 173 | print(out.shape) #[1,2,1,input_frames] 174 | print(out) 175 | softmax = nn.Softmax(dim=1) 176 | softmax_out = softmax(out*100) 177 | print(softmax_out) 178 | print(softmax_out[:,1,:]) 179 | print(softmax_out[:,1,:]*out) 180 | -------------------------------------------------------------------------------- /FCSN_ENC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FCSN_ENC(nn.Module): 5 | 6 | def __init__(self): 7 | super(FCSN_ENC, self).__init__() 8 | # conv1 (input shape (batch_size X Channel X H X W)) 9 | self.conv1_1 = nn.Conv2d(1024, 64, (1,3), padding=(0,100)) 10 | self.sn1_1 = nn.utils.spectral_norm(self.conv1_1) 11 | self.bn1_1 = nn.BatchNorm2d(64) 12 | self.relu1_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 13 | self.conv1_2 = nn.Conv2d(64, 64, (1,3), padding=(0,1)) 14 | self.sn1_2 = nn.utils.spectral_norm(self.conv1_2) 15 | self.bn1_2 = nn.BatchNorm2d(64) 16 | self.relu1_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 17 | self.pool1 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/2 18 | 19 | # conv2 20 | self.conv2_1 = nn.Conv2d(64, 128, (1,3), padding=(0,1)) 21 | self.sn2_1 = nn.utils.spectral_norm(self.conv2_1) 22 | self.bn2_1 = nn.BatchNorm2d(128) 23 | self.relu2_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 24 | self.conv2_2 = nn.Conv2d(128, 128, (1,3), padding=(0,1)) 25 | self.sn2_2 = nn.utils.spectral_norm(self.conv2_2) 26 | self.bn2_2 = nn.BatchNorm2d(128) 27 | self.relu2_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 28 | self.pool2 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/4 29 | 30 | # conv3 31 | self.conv3_1 = nn.Conv2d(128, 256, (1,3), padding=(0,1)) 32 | self.sn3_1 = nn.utils.spectral_norm(self.conv3_1) 33 | self.bn3_1 = nn.BatchNorm2d(256) 34 | self.relu3_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 35 | self.conv3_2 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 36 | self.sn3_2 = nn.utils.spectral_norm(self.conv3_2) 37 | self.bn3_2 = nn.BatchNorm2d(256) 38 | self.relu3_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 39 | self.conv3_3 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 40 | self.sn3_3 = nn.utils.spectral_norm(self.conv3_3) 41 | self.bn3_3 = nn.BatchNorm2d(256) 42 | self.relu3_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 43 | self.pool3 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/8 44 | 45 | # conv4 46 | self.conv4_1 = 
nn.Conv2d(256, 512, (1,3), padding=(0,1)) 47 | self.sn4_1 = nn.utils.spectral_norm(self.conv4_1) 48 | self.bn4_1 = nn.BatchNorm2d(512) 49 | self.relu4_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 50 | self.conv4_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 51 | self.sn4_2 = nn.utils.spectral_norm(self.conv4_2) 52 | self.bn4_2 = nn.BatchNorm2d(512) 53 | self.relu4_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 54 | self.conv4_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 55 | self.sn4_3 = nn.utils.spectral_norm(self.conv4_3) 56 | self.bn4_3 = nn.BatchNorm2d(512) 57 | self.relu4_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 58 | self.pool4 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/16 59 | 60 | # conv5 61 | self.conv5_1 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 62 | self.sn5_1 = nn.utils.spectral_norm(self.conv5_1) 63 | self.bn5_1 = nn.BatchNorm2d(512) 64 | self.relu5_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 65 | self.conv5_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 66 | self.sn5_2 = nn.utils.spectral_norm(self.conv5_2) 67 | self.bn5_2 = nn.BatchNorm2d(512) 68 | self.relu5_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 69 | self.conv5_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 70 | self.sn5_3 = nn.utils.spectral_norm(self.conv5_3) 71 | self.bn5_3 = nn.BatchNorm2d(512) 72 | self.relu5_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 73 | self.pool5 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/32 74 | 75 | # fc6 76 | self.fc6 = nn.Conv2d(512, 4096, (1,7)) 77 | self.sn6 = nn.utils.spectral_norm(self.fc6) 78 | self.in6 = nn.InstanceNorm2d(4096) 79 | self.relu6 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 80 | self.drop6 = nn.Dropout2d(p=0.5) 81 | 82 | # fc7 83 | self.fc7 = nn.Conv2d(4096, 4096, (1,1)) 84 | self.sn7 = nn.utils.spectral_norm(self.fc7) 85 | self.in7 = nn.InstanceNorm2d(4096) 86 | self.relu7 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 87 | self.drop7 = nn.Dropout2d(p=0.5) 88 | 89 | 90 | def forward(self, x): 91 | # input 92 | h = x 93 | # conv1 94 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) #;print(h.shape) 95 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) #;print(h.shape) 96 | h = self.pool1(h) #;print(h.shape) 97 | # conv2 98 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) #;print(h.shape) 99 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) #;print(h.shape) 100 | h = self.pool2(h) #;print(h.shape) 101 | # conv3 102 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) #;print(h.shape) 103 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) #;print(h.shape) 104 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) #;print(h.shape) 105 | h = self.pool3(h) #;print(h.shape) 106 | # conv4 107 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) #;print(h.shape) 108 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) #;print(h.shape) 109 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) #;print(h.shape) 110 | h = self.pool4(h) #;print(h.shape) 111 | pool4 = h 112 | # conv5 113 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) #;print(h.shape) 114 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) #;print(h.shape) 115 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) #;print(h.shape) 116 | h = self.pool5(h) #;print(h.shape) 117 | # conv6 118 | h = self.relu6(self.fc6(h)) #;print(h.shape) 119 | h = self.drop6(h) #;print(h.shape) 120 | # conv7 121 | h = self.relu7(self.fc7(h)) #;print(h.shape) 122 | h = self.drop7(h) #;print(h.shape) 123 | 124 | return h 125 | 126 | 127 | if __name__ == '__main__': 128 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 129 | 130 | model = 
FCSN_ENC()
131 |     model.to(device)
132 |     #model.eval()
133 | 
134 |     inp = torch.randn(1, 1024, 1, 1).to(device)
135 |     out = model(inp)
136 |     print(out.shape)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-VSLUD
2 | This is an implementation of the paper [Video Summarization by Learning from Unpaired Data (CVPR 2019)](http://openaccess.thecvf.com/content_CVPR_2019/papers/Rochan_Video_Summarization_by_Learning_From_Unpaired_Data_CVPR_2019_paper.pdf).
3 | 
4 | ![](imgs/architecture_VSLUD.PNG)
5 | 
6 | The FCSN architecture in the image above is from [Video Summarization Using Fully Convolutional Sequence Networks (ECCV 2018)](http://openaccess.thecvf.com/content_ECCV_2018/papers/Mrigank_Rochan_Video_Summarization_Using_ECCV_2018_paper.pdf).
7 | 
8 | ![](imgs/architecture_FCSN.PNG)
9 | 
10 | ## Environment
11 | - Ubuntu 18.04.1 LTS
12 | - python 3.6.7
13 | - numpy 1.15.4
14 | - pytorch 1.1.0
15 | - torchvision 0.3.0
16 | - tqdm 4.32.1
17 | - tensorboardX 1.6
18 | 
19 | ## Get started
20 | ### 1. Clone the project
21 | $ cd && git clone https://github.com/pcshih/pytorch-VSLUD.git && cd pytorch-VSLUD
22 | ### 2. Create a directory for saved models
23 | $ mkdir saved_models
24 | ### 3. Download [datasets.zip](https://drive.google.com/open?id=19TPsAPi7z88I9Pi0TeCcoHJ5fcbF3Dzp) (this dataset is from [here](https://github.com/KaiyangZhou/pytorch-vsumm-reinforce/issues/23)) into the project folder and unzip it
25 | $ unzip datasets.zip
26 | ### 4. Run paper_training_set_preparation.py to create the SumMe training set
27 | $ python3 training_set_preparation/paper_training_set_preparation.py
28 | ### 5. Train
29 | $ python3 train.py
30 | ### 6. Start tensorboardX to view the loss curves
31 | $ tensorboard --logdir runs --port 6006
32 | 
33 | ## Problems
34 | ![](imgs/loss.PNG)
35 | Sorry for my poor coding; I am new to PyTorch and deep learning.
36 | 
37 | The loss curves above are not reasonable for GAN training.
38 | 
39 | "The decoder of FCSN consists of several temporal deconvolution operations which produces a vector of prediction scores with the same length as the input video. Each score indicates the likelihood of the corresponding frame being a key frame or non-key frame. Based on these scores, we select k key frames to form the predicted summary video." -- from [Video Summarization by Learning from Unpaired Data (CVPR 2019)](http://openaccess.thecvf.com/content_CVPR_2019/papers/Rochan_Video_Summarization_by_Learning_From_Unpaired_Data_CVPR_2019_paper.pdf)
40 | 
41 | I implement "we select k key frames to form the predicted summary video" with [torch.index_select(input, dim, index, out=None)](https://pytorch.org/docs/stable/torch.html).
42 | 
43 | Is [torch.index_select](https://pytorch.org/docs/stable/torch.html) differentiable during training? Is this the main reason the training collapses? (See the sketch at the end of this README.)
44 | 
45 | Please feel free to contact me via email (pcshih.cs07g@nctu.edu.tw) or open an issue if you have any suggestions.
46 | 
47 | I would be very grateful for any help.
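48 | 
49 | ### A note on the differentiability question (sketch)
50 | A minimal, self-contained experiment, assuming the [1,1024,1,T] feature shape used in this repo; the `scores` tensor below is an illustrative stand-in for the FCSN output, not code from this repo. `torch.index_select` is differentiable with respect to the *values* of its input, but its integer `index` argument carries no gradient, so S_K never receives a signal about *which* frames to pick -- this may well be why the adversarial training stalls. A soft mask (as in the `###new###` branch of `SK.py`) keeps the selection step inside the autograd graph:
51 | 
52 |     import torch
53 | 
54 |     T = 10
55 |     feat = torch.randn(1, 1024, 1, T, requires_grad=True)  # frame features
56 |     scores = torch.randn(1, 2, 1, T, requires_grad=True)   # per-frame key/non-key scores
57 | 
58 |     # (a) hard selection: topk/argmax indices are discrete, so no gradient
59 |     # can flow back into the scores that produced the selection
60 |     _, idx = scores[:, 1].view(-1).topk(3)                 # pick k=3 "key" frames
61 |     hard = torch.index_select(feat, 3, idx)
62 |     hard.sum().backward()
63 |     print(scores.grad)                                     # None
64 | 
65 |     # (b) soft selection: weight the frames by a softmax mask instead;
66 |     # now gradients do reach the scores
67 |     scores2 = scores.detach().clone().requires_grad_(True)
68 |     mask = torch.softmax(scores2, dim=1)[:, 1].view(1, 1, 1, T)
69 |     soft = feat * mask
70 |     soft.sum().backward()
71 |     print(scores2.grad is not None)                        # True
72 | 
73 | If a hard 0/1 selection is really needed, the usual workaround is a straight-through estimator or Gumbel-softmax: use the hard mask in the forward pass but backpropagate through the soft probabilities.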
--------------------------------------------------------------------------------
/SD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from FCSN_ENC import FCSN_ENC
4 | import torch.nn.functional as F  # needed by up.forward, which calls F.pad
5 | 
6 | 
7 | class double_conv(nn.Module):
8 |     '''(conv => BN => ReLU) * 2'''
9 |     def __init__(self, in_ch, out_ch):
10 |         super(double_conv, self).__init__()
11 |         self.conv = nn.Sequential(
12 |             nn.Conv2d(in_ch, out_ch, 3, padding=1),
13 |             nn.BatchNorm2d(out_ch),
14 |             nn.ReLU(inplace=True),
15 |             nn.Conv2d(out_ch, out_ch, 3, padding=1),
16 |             nn.BatchNorm2d(out_ch),
17 |             nn.ReLU(inplace=True)
18 |         )
19 | 
20 |     def forward(self, x):
21 |         x = self.conv(x)
22 |         return x
23 | 
24 | class inconv(nn.Module):
25 |     def __init__(self, in_ch, out_ch):
26 |         super(inconv, self).__init__()
27 |         self.conv = double_conv(in_ch, out_ch)
28 | 
29 |     def forward(self, x):
30 |         x = self.conv(x)
31 |         return x
32 | 
33 | class down(nn.Module):
34 |     def __init__(self, in_ch, out_ch):
35 |         super(down, self).__init__()
36 |         self.mpconv = nn.Sequential(
37 |             nn.MaxPool2d(2),
38 |             double_conv(in_ch, out_ch)
39 |         )
40 | 
41 |     def forward(self, x):
42 |         x = self.mpconv(x)
43 |         return x
44 | 
45 | class up(nn.Module):
46 |     def __init__(self, in_ch, out_ch, bilinear=True):
47 |         super(up, self).__init__()
48 | 
49 |         # it would be nice if the upsampling could be learned too,
50 |         # but my machine does not have enough memory to handle all those weights
51 |         if bilinear:
52 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
53 |         else:
54 |             self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
55 | 
56 |         self.conv = double_conv(in_ch, out_ch)
57 | 
58 |     def forward(self, x1, x2):
59 |         x1 = self.up(x1)
60 | 
61 |         # input is CHW
62 |         diffY = x2.size()[2] - x1.size()[2]
63 |         diffX = x2.size()[3] - x1.size()[3]
64 | 
65 |         x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
66 |                         diffY // 2, diffY - diffY//2))
67 | 
68 |         # for padding issues, see
69 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
70 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
71 | 
72 |         x = torch.cat([x2, x1], dim=1)
73 |         x = self.conv(x)
74 |         return x
75 | 
76 | class outconv(nn.Module):
77 |     def __init__(self, in_ch, out_ch):
78 |         super(outconv, self).__init__()
79 |         self.conv = nn.Conv2d(in_ch, out_ch, 1)
80 | 
81 |     def forward(self, x):
82 |         x = self.conv(x)
83 |         return x
84 | 
85 | class SD_test(nn.Module):
86 |     def __init__(self):
87 |         super(SD_test, self).__init__()
88 |         self.inc = inconv(1024, 64)
89 |         self.down1 = down(64, 128)
90 |         self.down2 = down(128, 256)
91 |         self.down3 = down(256, 512)
92 |         self.down4 = down(512, 512)
93 | 
94 |         self.linear = nn.Linear(512, 1)
95 |         self.sigmoid = nn.Sigmoid()
96 | 
97 | 
98 |     def forward(self, x):
99 |         h = x
100 |         x1 = self.inc(x); #print(x1.shape)
101 |         x2 = self.down1(x1); #print(x2.shape)
102 |         x3 = self.down2(x2); #print(x3.shape)
103 |         x4 = self.down3(x3); #print(x4.shape)
104 |         x5 = self.down4(x4); #print(x5.shape)
105 | 
106 |         h = nn.AvgPool2d((1,h.size()[3]), stride=(1,h.size()[3]), ceil_mode=True)(x5)
107 | 
108 |         h = h.view(1, -1)
109 | 
110 |         h = self.linear(h)
111 | 
112 |         h = self.sigmoid(h).view(-1)
113 | 
114 |         return h
115 | 
116 | 
117 | 
118 | class SD(nn.Module):
119 |     def __init__(self):
120 |         super(SD, self).__init__()
121 | 
122 |         self.FCSN_ENC = FCSN_ENC()
123 |         self.linear = nn.Linear(4096, 1)
124 |         self.sigmoid = nn.Sigmoid()
125 | self.relu = nn.ReLU(inplace=True) 126 | 127 | def forward(self, x): 128 | h = x 129 | 130 | h = self.FCSN_ENC(h); print(h.shape) 131 | 132 | h = nn.AvgPool2d((1,h.size()[3]), stride=(1,h.size()[3]), ceil_mode=True)(h) 133 | 134 | h = h.view(1, -1) 135 | 136 | h = self.linear(h) 137 | 138 | h = self.sigmoid(h).view(-1) 139 | 140 | return h 141 | 142 | if __name__ == '__main__': 143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 144 | 145 | model = SD_test() 146 | model.to(device) 147 | model.eval() 148 | 149 | inp = torch.randn(1, 1024, 1, 245, requires_grad=True).to(device) 150 | #mask = torch.randn(1, 1, 1, 2).to(device); print(mask) 151 | 152 | #inp_view = inp.view(1,3,2); print(inp_view) 153 | #mask_view = mask.view(1,1,2); print(mask_view) 154 | 155 | #print(inp_view*mask_view) 156 | 157 | #scalar = torch.randn(1); print(scalar) 158 | #print(torch.mean(scalar)) 159 | 160 | 161 | out = model(inp) 162 | #print(out.shape) -------------------------------------------------------------------------------- /SK.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | #from FCSN import FCSN 5 | 6 | import random 7 | 8 | 9 | class SK(nn.Module): 10 | def __init__(self, n_class=2): 11 | super(SK, self).__init__() 12 | # conv1 input shape (batch_size, Channel, H, W) -> (1,1024,1,T) 13 | self.conv1_1 = nn.Conv2d(1024, 64, (1,3), padding=(0,100)) 14 | self.sn1_1 = nn.utils.spectral_norm(self.conv1_1) 15 | self.bn1_1 = nn.BatchNorm2d(64) 16 | self.relu1_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 17 | self.conv1_2 = nn.Conv2d(64, 64, (1,3), padding=(0,1)) 18 | self.sn1_2 = nn.utils.spectral_norm(self.conv1_2) 19 | self.bn1_2 = nn.BatchNorm2d(64) 20 | self.relu1_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 21 | self.pool1 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/2 22 | 23 | # conv2 24 | self.conv2_1 = nn.Conv2d(64, 128, (1,3), padding=(0,1)) 25 | self.sn2_1 = nn.utils.spectral_norm(self.conv2_1) 26 | self.bn2_1 = nn.BatchNorm2d(128) 27 | self.relu2_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 28 | self.conv2_2 = nn.Conv2d(128, 128, (1,3), padding=(0,1)) 29 | self.sn2_2 = nn.utils.spectral_norm(self.conv2_2) 30 | self.bn2_2 = nn.BatchNorm2d(128) 31 | self.relu2_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 32 | self.pool2 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/4 33 | 34 | # conv3 35 | self.conv3_1 = nn.Conv2d(128, 256, (1,3), padding=(0,1)) 36 | self.sn3_1 = nn.utils.spectral_norm(self.conv3_1) 37 | self.bn3_1 = nn.BatchNorm2d(256) 38 | self.relu3_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 39 | self.conv3_2 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 40 | self.sn3_2 = nn.utils.spectral_norm(self.conv3_2) 41 | self.bn3_2 = nn.BatchNorm2d(256) 42 | self.relu3_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 43 | self.conv3_3 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 44 | self.sn3_3 = nn.utils.spectral_norm(self.conv3_3) 45 | self.bn3_3 = nn.BatchNorm2d(256) 46 | self.relu3_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 47 | self.pool3 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/8 48 | 49 | # conv4 50 | self.conv4_1 = nn.Conv2d(256, 512, (1,3), padding=(0,1)) 51 | self.sn4_1 = nn.utils.spectral_norm(self.conv4_1) 52 | self.bn4_1 = nn.BatchNorm2d(512) 53 | self.relu4_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 54 | self.conv4_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 55 | self.sn4_2 = nn.utils.spectral_norm(self.conv4_2) 56 | 
self.bn4_2 = nn.BatchNorm2d(512) 57 | self.relu4_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 58 | self.conv4_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 59 | self.sn4_3 = nn.utils.spectral_norm(self.conv4_3) 60 | self.bn4_3 = nn.BatchNorm2d(512) 61 | self.relu4_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 62 | self.pool4 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/16 63 | 64 | # conv5 65 | self.conv5_1 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 66 | self.sn5_1 = nn.utils.spectral_norm(self.conv5_1) 67 | self.bn5_1 = nn.BatchNorm2d(512) 68 | self.relu5_1 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 69 | self.conv5_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 70 | self.sn5_2 = nn.utils.spectral_norm(self.conv5_2) 71 | self.bn5_2 = nn.BatchNorm2d(512) 72 | self.relu5_2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 73 | self.conv5_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 74 | self.sn5_3 = nn.utils.spectral_norm(self.conv5_3) 75 | self.bn5_3 = nn.BatchNorm2d(512) 76 | self.relu5_3 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 77 | self.pool5 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/32 78 | 79 | # fc6 80 | self.fc6 = nn.Conv2d(512, 4096, (1,7)) 81 | self.sn6 = nn.utils.spectral_norm(self.fc6) 82 | self.in6 = nn.InstanceNorm2d(4096) 83 | self.relu6 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 84 | self.drop6 = nn.Dropout2d(p=0.5) 85 | 86 | # fc7 87 | self.fc7 = nn.Conv2d(4096, 4096, (1,1)) 88 | self.sn7 = nn.utils.spectral_norm(self.fc7) 89 | self.in7 = nn.InstanceNorm2d(4096) 90 | self.relu7 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 91 | self.drop7 = nn.Dropout2d(p=0.5) 92 | 93 | self.score_fr = nn.Conv2d(4096, n_class, (1,1)) 94 | self.sn_score_fr = nn.utils.spectral_norm(self.score_fr) 95 | self.bn_score_fr = nn.BatchNorm2d(n_class) 96 | self.in_score_fr = nn.InstanceNorm2d(n_class) 97 | self.relu_score_fr = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 98 | self.score_pool4 = nn.Conv2d(512, n_class, (1,1)) 99 | self.sn_score_pool4 = nn.utils.spectral_norm(self.score_pool4) 100 | self.bn_score_pool4 = nn.BatchNorm2d(n_class) 101 | self.relu_bn_score_pool4 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 102 | 103 | self.upscore2 = nn.ConvTranspose2d( 104 | n_class, n_class, (1,4), stride=(1,2)) 105 | self.sn_upscore2 = nn.utils.spectral_norm(self.upscore2) 106 | self.bn_upscore2 = nn.BatchNorm2d(n_class) 107 | self.relu_upscore2 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 108 | 109 | self.upscore16 = nn.ConvTranspose2d( 110 | n_class, n_class, (1,32), stride=(1,16)) 111 | self.sn_upscore16 = nn.utils.spectral_norm(self.upscore16) 112 | self.bn_upscore16 = nn.BatchNorm2d(n_class) 113 | self.relu_upscore16 = nn.ReLU(inplace=True)#nn.LeakyReLU(0.2) 114 | self.sigmoid_upscore16 = nn.Sigmoid() 115 | self.tanh_upscore16 = nn.Tanh() 116 | 117 | self.relu_add = nn.ReLU()#nn.LeakyReLU(0.2) 118 | 119 | self.softmax = nn.Softmax(dim=1) 120 | 121 | self.conv_reconstuct1 = nn.Conv2d(n_class, 1024, (1,1)) 122 | self.bn_reconstruct1 = nn.BatchNorm2d(1024) 123 | self.relu_reconstuct1 = nn.ReLU(inplace=True) 124 | 125 | self.conv_reconstuct2 = nn.Conv2d(1024, 1024, (1,1)) 126 | self.bn_reconstruct2 = nn.BatchNorm2d(1024) 127 | self.relu_reconstuct2 = nn.ReLU(inplace=True) 128 | 129 | 130 | def forward(self, x): 131 | # input 132 | h = x 133 | in_x = x 134 | # conv1 135 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) #;print(h.shape) 136 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) #;print(h.shape) 137 | h = self.pool1(h) #;print(h.shape) 138 | # conv2 139 | h = 
self.relu2_1(self.bn2_1(self.conv2_1(h))) #;print(h.shape) 140 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) #;print(h.shape) 141 | h = self.pool2(h) #;print(h.shape) 142 | # conv3 143 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) #;print(h.shape) 144 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) #;print(h.shape) 145 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) #;print(h.shape) 146 | h = self.pool3(h) #;print(h.shape) 147 | # conv4 148 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) #;print(h.shape) 149 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) #;print(h.shape) 150 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) #;print(h.shape) 151 | h = self.pool4(h) #;print(h.shape) 152 | pool4 = h 153 | # conv5 154 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) #;print(h.shape) 155 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) #;print(h.shape) 156 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) #;print(h.shape) 157 | h = self.pool5(h) #;print(h.shape) 158 | # conv6 159 | h = self.relu6(self.fc6(h)) #;print(h.shape) 160 | h = self.drop6(h) #;print(h.shape) 161 | # conv7 162 | h = self.relu7(self.fc7(h)) #;print(h.shape) 163 | h = self.drop7(h) #;print(h.shape) 164 | # conv8 165 | h = self.in_score_fr(self.score_fr(h)) # original should be bn_score_fr, in order to handle the one frame input i.e. [1,1024,1,1] input 166 | # deconv1 167 | h = self.upscore2(h) 168 | upscore2 = h 169 | # get score_pool4c to do skip connection 170 | h = self.bn_score_pool4(self.score_pool4(pool4)) 171 | h = h[:, :, :, 5:5+upscore2.size()[3]] 172 | score_pool4c = h 173 | # skip connection 174 | h = upscore2+score_pool4c 175 | # deconv2 176 | h = self.upscore16(h) 177 | h = h[:, :, :, 27:27+x.size()[3]]; #print("before softmax:", h) 178 | 179 | # h 180 | h_softmax = self.softmax(h); #print("after softmax:", h_softmax) 181 | 182 | # get simulated 0/1 vector 183 | mask = h_softmax[:,1,:].view(1,1,1,-1); #print("mask:", mask) # [1,1,1,T] use key frame score to be the mask 184 | 185 | h_mask = h*mask; #print("h_mask:", h_mask) 186 | 187 | h_reconstruct = self.relu_reconstuct1(self.bn_reconstruct1(self.conv_reconstuct1(h_mask))) # [1,1024,1,T] 188 | x_select = in_x*mask 189 | 190 | # merge with input features 191 | h_merge = h_reconstruct + x_select # [1,1024,1,T] 192 | h_merge_reconstruct = self.relu_reconstuct2(self.bn_reconstruct2(self.conv_reconstuct2(h_merge))) # [1,1024,1,T] 193 | 194 | 195 | return h_merge_reconstruct,mask,h # [1,1024,1,T],[1,1,1,T],[1,2,1,T] 196 | 197 | class SK_old(nn.Module): 198 | def __init__(self): 199 | super(SK_old, self).__init__() 200 | 201 | self.FCSN = FCSN(n_class=2) 202 | 203 | self.conv_1 = nn.Conv2d(2, 1024, (1,1)) 204 | self.batchnorm_1 = nn.BatchNorm2d(1024) 205 | self.relu_1 = nn.ReLU(inplace=True) 206 | 207 | self.tanh_h_select = nn.Tanh() 208 | self.relu_summary = nn.ReLU(inplace=True)#nn.RReLU() 209 | self.tanh_summary = nn.Tanh() 210 | self.sigmoid = nn.Sigmoid() 211 | self.softmax = nn.Softmax(dim=1) 212 | 213 | self.conv_reconstuct2 = nn.Conv2d(1024, 1024, (1,1)) 214 | self.bn_reconstruct2 = nn.BatchNorm2d(1024) 215 | self.relu_reconstuct2 = nn.ReLU(inplace=True) 216 | 217 | 218 | def forward(self, x): 219 | h = x # [1,1024,1,T] 220 | x_temp = x # [1,1024,1,T] 221 | 222 | h = self.FCSN(h); print(h) # [1,2,1,T] 223 | 224 | ###old### 225 | # values, indices = h.max(1, keepdim=True) 226 | # # 0/1 vector, we only want key(indices=1) frame 227 | # column_mask = (indices==1).view(-1).nonzero().view(-1).tolist() 228 | 229 | # # if S_K doesn't select more than one element, 
then random select two element(for the sake of diversity loss) 230 | # if len(column_mask)<2: 231 | # print("S_K does not select anything, give a random mask with 2 elements") 232 | # column_mask = random.sample(list(range(h.shape[3])), 2) 233 | 234 | # index = torch.tensor(column_mask, device=torch.device('cuda:0')) 235 | # h_select = torch.index_select(h, 3, index) 236 | # x_select = torch.index_select(x_temp, 3, index) 237 | ###old### 238 | 239 | ###new### 240 | #index_mask = self.sigmoid(h[:,1]-h[:,0]).view(1,1,1,-1) 241 | diverse_h = h*100 242 | h_softmax = self.softmax(diverse_h) # [1,2,1,T] 243 | index_mask = h_softmax[:,1,:].view(1,1,1,-1) 244 | #index_mask = self.sigmoid(h[:,1]-h[:,0]).view(1,1,1,-1) 245 | #index_mask = (indices==1).type(torch.float32) 246 | # if S_K doesn't select more than one element, then random select two element(for the sake of diversity loss) 247 | # if (len(index_mask.view(-1).nonzero().view(-1).tolist()) < 2): 248 | # print("S_K does not select anything, give a random mask with 2 elements") 249 | # index_mask = torch.zeros([1,1,1,h.shape[3]], dtype=torch.float32, device=torch.device('cuda:0')) 250 | # for idx in random.sample(list(range(h.shape[3])), 2): 251 | # index_mask[:,:,:,idx] = 1.0 252 | 253 | h_select = h*index_mask 254 | x_select = x_temp*index_mask 255 | ###new### 256 | 257 | 258 | h_select = self.relu_1(self.conv_1(h_select)) 259 | 260 | 261 | summary = x_select+h_select 262 | 263 | summary = self.relu_reconstuct2(self.bn_reconstruct2(self.conv_reconstuct2(summary))) # [5,1024,1,320] 264 | 265 | #summary = self.relu_summary(summary) 266 | 267 | #return summary,column_mask 268 | return summary,index_mask 269 | 270 | 271 | class double_conv(nn.Module): 272 | '''(conv => BN => ReLU) * 2''' 273 | def __init__(self, in_ch, out_ch): 274 | super(double_conv, self).__init__() 275 | self.conv = nn.Sequential( 276 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 277 | nn.BatchNorm2d(out_ch), 278 | nn.ReLU(inplace=True), 279 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 280 | nn.BatchNorm2d(out_ch), 281 | nn.ReLU(inplace=True) 282 | ) 283 | 284 | def forward(self, x): 285 | x = self.conv(x) 286 | return x 287 | 288 | class inconv(nn.Module): 289 | def __init__(self, in_ch, out_ch): 290 | super(inconv, self).__init__() 291 | self.conv = double_conv(in_ch, out_ch) 292 | 293 | def forward(self, x): 294 | x = self.conv(x) 295 | return x 296 | 297 | class down(nn.Module): 298 | def __init__(self, in_ch, out_ch): 299 | super(down, self).__init__() 300 | self.mpconv = nn.Sequential( 301 | nn.MaxPool2d(2), 302 | double_conv(in_ch, out_ch) 303 | ) 304 | 305 | def forward(self, x): 306 | x = self.mpconv(x) 307 | return x 308 | 309 | class up(nn.Module): 310 | def __init__(self, in_ch, out_ch, bilinear=True): 311 | super(up, self).__init__() 312 | 313 | # would be a nice idea if the upsampling could be learned too, 314 | # but my machine do not have enough memory to handle all those weights 315 | if bilinear: 316 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 317 | else: 318 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 319 | 320 | self.conv = double_conv(in_ch, out_ch) 321 | 322 | def forward(self, x1, x2): 323 | x1 = self.up(x1) 324 | 325 | # input is CHW 326 | diffY = x2.size()[2] - x1.size()[2] 327 | diffX = x2.size()[3] - x1.size()[3] 328 | 329 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 330 | diffY // 2, diffY - diffY//2)) 331 | 332 | # for padding issues, see 333 | # 
https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 334 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 335 | 336 | x = torch.cat([x2, x1], dim=1) 337 | x = self.conv(x) 338 | return x 339 | 340 | class outconv(nn.Module): 341 | def __init__(self, in_ch, out_ch): 342 | super(outconv, self).__init__() 343 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 344 | 345 | def forward(self, x): 346 | x = self.conv(x) 347 | return x 348 | 349 | class SK_test(nn.Module): 350 | def __init__(self, n_channels=1024, n_classes=2): 351 | super(SK_test, self).__init__() 352 | self.inc = inconv(n_channels, 64) 353 | self.down1 = down(64, 128) 354 | self.down2 = down(128, 256) 355 | self.down3 = down(256, 512) 356 | self.down4 = down(512, 512) 357 | self.up1 = up(1024, 256) 358 | self.up2 = up(512, 128) 359 | self.up3 = up(256, 64) 360 | self.up4 = up(128, 64) 361 | self.outc = outconv(64, n_classes) 362 | self.softmax = nn.Softmax(dim=1) 363 | 364 | 365 | self.conv_reconstuct1 = nn.Conv2d(n_classes, 1024, (1,1)) 366 | self.bn_reconstruct1 = nn.BatchNorm2d(1024) 367 | self.relu_reconstuct1 = nn.ReLU(inplace=True) 368 | 369 | self.conv_reconstuct2 = nn.Conv2d(1024, 1024, (1,1)) 370 | self.bn_reconstruct2 = nn.BatchNorm2d(1024) 371 | self.relu_reconstuct2 = nn.ReLU(inplace=True) 372 | 373 | def forward(self, x): 374 | h = x 375 | 376 | x1 = self.inc(x); #print(x1.shape) 377 | x2 = self.down1(x1); #print(x2.shape) 378 | x3 = self.down2(x2); #print(x3.shape) 379 | x4 = self.down3(x3); #print(x4.shape) 380 | x5 = self.down4(x4); #print(x5.shape) 381 | x = self.up1(x5, x4); #print(x.shape) 382 | x = self.up2(x, x3); #print(x.shape) 383 | x = self.up3(x, x2); #print(x.shape) 384 | x = self.up4(x, x1); #print(x.shape) 385 | x = self.outc(x); #print(x.shape) 386 | 387 | h_softmax = self.softmax(x) 388 | 389 | mask = h_softmax[:,1,:].view(1,1,1,-1); #print("mask:", mask) # [1,1,1,T] use key frame score to be the mask 390 | 391 | h_mask = x*mask; #print("h_mask:", h_mask) 392 | 393 | h_reconstruct = self.relu_reconstuct1(self.bn_reconstruct1(self.conv_reconstuct1(h_mask))) # [1,1024,1,T] 394 | x_select = h*mask 395 | 396 | # # merge with input features 397 | h_merge = h_reconstruct + x_select # [1,1024,1,T] 398 | h_merge_reconstruct = self.relu_reconstuct2(self.bn_reconstruct2(self.conv_reconstuct2(h_merge))) # [1,1024,1,T] 399 | 400 | 401 | return h_merge_reconstruct,mask,x # [1,1024,1,T],[1,1,1,T],[1,2,1,T] 402 | 403 | 404 | 405 | if __name__ == '__main__': 406 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 407 | 408 | model = SK_test(1024, 2) 409 | #model.eval() 410 | model.to(device) 411 | inp = torch.randn(1, 1024, 1, 100, requires_grad=True).to(device) 412 | 413 | a,b,c = model(inp) 414 | print(a.shape) 415 | print(b.shape) 416 | print(c.shape) 417 | #print(out.shape) 418 | #print(out) 419 | #print(mask) 420 | 421 | -------------------------------------------------------------------------------- /imgs/architecture_FCSN.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcshih/pytorch-VSLUD/0c352a1b81b4d3a2663642d732af1d6ba67744b7/imgs/architecture_FCSN.PNG -------------------------------------------------------------------------------- /imgs/architecture_VSLUD.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pcshih/pytorch-VSLUD/0c352a1b81b4d3a2663642d732af1d6ba67744b7/imgs/architecture_VSLUD.PNG -------------------------------------------------------------------------------- /imgs/loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcshih/pytorch-VSLUD/0c352a1b81b4d3a2663642d732af1d6ba67744b7/imgs/loss.PNG -------------------------------------------------------------------------------- /sample_FCSN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FCSN(nn.Module): 5 | 6 | def __init__(self, n_class=2): 7 | super(FCSN, self).__init__() 8 | # conv1 (input shape (batch_size X Channel X H X W)) 9 | self.conv1_1 = nn.Conv2d(1024, 64, (1,3), padding=(0,100)) 10 | self.relu1_1 = nn.ReLU(inplace=True) 11 | self.conv1_2 = nn.Conv2d(64, 64, (1,3), padding=(0,1)) 12 | self.relu1_2 = nn.ReLU(inplace=True) 13 | self.pool1 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/2 14 | 15 | # conv2 16 | self.conv2_1 = nn.Conv2d(64, 128, (1,3), padding=(0,1)) 17 | self.relu2_1 = nn.ReLU(inplace=True) 18 | self.conv2_2 = nn.Conv2d(128, 128, (1,3), padding=(0,1)) 19 | self.relu2_2 = nn.ReLU(inplace=True) 20 | self.pool2 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/4 21 | 22 | # conv3 23 | self.conv3_1 = nn.Conv2d(128, 256, (1,3), padding=(0,1)) 24 | self.relu3_1 = nn.ReLU(inplace=True) 25 | self.conv3_2 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 26 | self.relu3_2 = nn.ReLU(inplace=True) 27 | self.conv3_3 = nn.Conv2d(256, 256, (1,3), padding=(0,1)) 28 | self.relu3_3 = nn.ReLU(inplace=True) 29 | self.pool3 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/8 30 | 31 | # conv4 32 | self.conv4_1 = nn.Conv2d(256, 512, (1,3), padding=(0,1)) 33 | self.relu4_1 = nn.ReLU(inplace=True) 34 | self.conv4_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 35 | self.relu4_2 = nn.ReLU(inplace=True) 36 | self.conv4_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 37 | self.relu4_3 = nn.ReLU(inplace=True) 38 | self.pool4 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/16 39 | 40 | # conv5 41 | self.conv5_1 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 42 | self.relu5_1 = nn.ReLU(inplace=True) 43 | self.conv5_2 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 44 | self.relu5_2 = nn.ReLU(inplace=True) 45 | self.conv5_3 = nn.Conv2d(512, 512, (1,3), padding=(0,1)) 46 | self.relu5_3 = nn.ReLU(inplace=True) 47 | self.pool5 = nn.MaxPool2d((1,2), stride=(1,2), ceil_mode=True) # 1/32 48 | 49 | # fc6 50 | self.fc6 = nn.Conv2d(512, 4096, (1,7)) 51 | self.relu6 = nn.ReLU(inplace=True) 52 | self.drop6 = nn.Dropout2d() 53 | 54 | # fc7 55 | self.fc7 = nn.Conv2d(4096, 4096, (1,1)) 56 | self.relu7 = nn.ReLU(inplace=True) 57 | self.drop7 = nn.Dropout2d() 58 | 59 | self.score_fr = nn.Conv2d(4096, n_class, (1,1)) 60 | self.score_pool4 = nn.Conv2d(512, n_class, (1,1)) 61 | 62 | self.upscore2 = nn.ConvTranspose2d( 63 | n_class, n_class, (1,4), stride=(1,2), bias=False) 64 | self.upscore16 = nn.ConvTranspose2d( 65 | n_class, n_class, (1,32), stride=(1,16), bias=False) 66 | 67 | def forward(self, x): 68 | h = x 69 | h = self.relu1_1(self.conv1_1(h)) 70 | h = self.relu1_2(self.conv1_2(h)) 71 | h = self.pool1(h) 72 | 73 | h = self.relu2_1(self.conv2_1(h)) 74 | h = self.relu2_2(self.conv2_2(h)) 75 | h = self.pool2(h) 76 | 77 | h = self.relu3_1(self.conv3_1(h)) 78 | h = self.relu3_2(self.conv3_2(h)) 79 | h = self.relu3_3(self.conv3_3(h)) 80 | h = 
self.pool3(h)
81 | 
82 |         h = self.relu4_1(self.conv4_1(h))
83 |         h = self.relu4_2(self.conv4_2(h))
84 |         h = self.relu4_3(self.conv4_3(h))
85 |         h = self.pool4(h)
86 |         pool4 = h  # 1/16
87 | 
88 |         h = self.relu5_1(self.conv5_1(h))
89 |         h = self.relu5_2(self.conv5_2(h))
90 |         h = self.relu5_3(self.conv5_3(h))
91 |         h = self.pool5(h)
92 | 
93 |         h = self.relu6(self.fc6(h))
94 |         h = self.drop6(h)
95 | 
96 |         h = self.relu7(self.fc7(h))
97 |         h = self.drop7(h)
98 | 
99 |         h = self.score_fr(h)
100 |         h = self.upscore2(h)
101 |         upscore2 = h  # 1/16
102 | 
103 |         h = self.score_pool4(pool4)
104 |         #import ipdb; ipdb.set_trace()
105 |         h = h[:, :, :, 5:5 + upscore2.size()[3]]
106 |         score_pool4c = h  # 1/16
107 | 
108 |         h = upscore2 + score_pool4c
109 | 
110 |         h = self.upscore16(h)
111 |         h = h[:, :, :, 27:27 + x.size()[3]].contiguous()
112 | 
113 |         return h
114 | 
115 | if __name__ == '__main__':
116 |     model = FCSN(n_class=2)
117 |     inp = torch.randn(1, 1024, 1, 519) # inp shape (1x1024x1xNframes)
118 |     out = model(inp)
119 |     print(out.shape) # should print (1x2x1xNframes) -- two scores for each frame (key frame or non-key frame)
120 | 
121 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from SK import *
3 | from SD import *
4 | import subprocess
5 | import tqdm
6 | import cv2
7 | import numpy as np
8 | 
9 | 
10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 | 
12 | S_K = SK_test().to(device)
13 | S_D = SD_test().to(device)
14 | 
15 | PATH_model_load = "saved_models/iter_0014999.tar"
16 | checkpoint_model = torch.load(PATH_model_load)
17 | S_K.load_state_dict(checkpoint_model['S_K_state_dict'])
18 | S_D.load_state_dict(checkpoint_model['S_D_state_dict'])
19 | 
20 | S_K.eval()
21 | S_D.eval()
22 | #S_K.FCSN.eval()
23 | #S_D.FCSN_ENC.eval()
24 | 
25 | video = torch.load("datasets/test_video_frame_pool5.tar")
26 | 
27 | video_path="/media/data/PTec131b/VideoSum/testing_data/video"
28 | video_frame_path="/media/data/PTec131b/VideoSum/testing_data/video_frame"
29 | video_processed_path="/media/data/PTec131b/VideoSum/testing_data/video_processed_0029999"
30 | 
31 | 
32 | 
33 | def merge_frame(name, video_processed_path, video_frame_path, video_mask, frames_count):
34 |     """
35 |     name: video name
36 |     video_processed_path: where to store the 2fps-sampled video, the key-frame-only video, and the side-by-side concatenation of the two
37 |     video_mask: key frame mask
38 |     frames_count: total number of frames in this video
39 |     """
40 |     width = 1280
41 |     height = 720
42 |     fps = 2
43 |     fourcc = cv2.VideoWriter_fourcc(*'MP4V')
44 |     black_frame = np.zeros((height,width,3), dtype=np.uint8)
45 | 
46 |     file_name_2fps = "{}/{}_2fps.mp4".format(video_processed_path,name)
47 |     file_name_selected = "{}/{}_selected.mp4".format(video_processed_path,name)
48 |     file_name_concate = "{}/{}_concate.mp4".format(video_processed_path,name)
49 | 
50 | 
51 |     # out_2fps = cv2.VideoWriter(file_name_2fps,
52 |     #                            fourcc,
53 |     #                            fps,
54 |     #                            (width, height))
55 |     # out_selected = cv2.VideoWriter(file_name_selected,
56 |     #                                fourcc,
57 |     #                                fps,
58 |     #                                (width, height))
59 |     out_concate = cv2.VideoWriter(file_name_concate,
60 |                                   fourcc,
61 |                                   fps,
62 |                                   (width*2, height))
63 | 
64 | 
65 |     for i in range(frames_count):
66 |         # 2fps
67 |         frame_path = "{}/{}/{}_{:0>4d}.jpg".format(video_frame_path, name, name, i+1)
68 |         frame_2fps = cv2.imread(frame_path)
69 |         #out_2fps.write(frame_2fps)
70 | 
71 |         # selected+concate
72 |         if (i>=len(video_mask)):
73 |             frame = black_frame
74 |         else:
75 |             frame_path = 
"{}/{}/{}_{:0>4d}.jpg".format(video_frame_path, name, name, video_mask[i]+1) 76 | frame = cv2.imread(frame_path) 77 | #out_selected.write(frame) 78 | 79 | frame_concate = np.concatenate((frame_2fps, frame), axis=1) 80 | out_concate.write(frame_concate) 81 | 82 | #out_2fps.release() 83 | #out_selected.release() 84 | out_concate.release() 85 | 86 | 87 | 88 | 89 | 90 | def test(): 91 | tqdm_range = tqdm.trange(len(video["feature"])) 92 | 93 | for i in tqdm_range: # video["feature"] -> [[1,1024,1,A], [1,1024,1,B]...] 94 | vd = video["feature"][i].to(device) 95 | name = video["name_list"][i] 96 | frames_count = video["frame_list"][i] 97 | 98 | _,video_mask,_ = S_K(vd) 99 | 100 | print(i, video_mask.view(-1)) 101 | 102 | #subprocess.call(["./merge_original_frame.sh", 103 | # video_frame_path, 104 | # video_processed_path, 105 | # name]) 106 | 107 | # merge frames 108 | #video_mask_list = video_mask.view(-1).nonzero().view(-1).tolist(); print(video_mask_list) 109 | #merge_frame(name, video_processed_path, video_frame_path, video_mask_list, frames_count) 110 | 111 | 112 | if __name__ == '__main__': 113 | test() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.optim as optim 5 | from tensorboardX import SummaryWriter 6 | 7 | import time 8 | import tqdm 9 | import random 10 | from SK import * 11 | from SD import * 12 | 13 | random.seed(time.time()) 14 | 15 | 16 | print("loading training data...") 17 | video = torch.load("datasets/video_frame_pool5.tar") 18 | summary = torch.load("datasets/summary_frame_pool5.tar") 19 | print("loading training data ended") 20 | 21 | PATH_record = "saved_models/loss_record_3.tar" 22 | PATH_model = "saved_models" 23 | 24 | EPOCH = 1000 25 | 26 | # reconstruction error coefficient 27 | reconstruction_error_coeff = 0.5 28 | # diversity error coefficient 29 | diversity_error_coeff = 0.0 30 | 31 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | 33 | # ref: https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5 34 | def weights_init(m): 35 | ''' 36 | Usage: 37 | model = Model() 38 | model.apply(weight_init) 39 | ''' 40 | if isinstance(m, nn.Conv1d): 41 | init.normal_(m.weight.data) 42 | if m.bias is not None: 43 | init.normal_(m.bias.data) 44 | elif isinstance(m, nn.Conv2d): 45 | #init.xavier_normal_(m.weight.data) 46 | init.kaiming_normal_(m.weight.data) 47 | if m.bias is not None: 48 | init.normal_(m.bias.data) 49 | elif isinstance(m, nn.Conv3d): 50 | init.xavier_normal_(m.weight.data) 51 | if m.bias is not None: 52 | init.normal_(m.bias.data) 53 | elif isinstance(m, nn.ConvTranspose1d): 54 | init.normal_(m.weight.data) 55 | if m.bias is not None: 56 | init.normal_(m.bias.data) 57 | elif isinstance(m, nn.ConvTranspose2d): 58 | #init.xavier_normal_(m.weight.data) 59 | init.kaiming_normal_(m.weight.data) 60 | if m.bias is not None: 61 | init.normal_(m.bias.data) 62 | elif isinstance(m, nn.ConvTranspose3d): 63 | init.xavier_normal_(m.weight.data) 64 | if m.bias is not None: 65 | init.normal_(m.bias.data) 66 | elif isinstance(m, nn.BatchNorm1d): 67 | init.normal_(m.weight.data, mean=1, std=0.02) 68 | init.constant_(m.bias.data, 0) 69 | elif isinstance(m, nn.BatchNorm2d): 70 | init.normal_(m.weight.data, mean=1, std=0.02) 71 | init.constant_(m.bias.data, 0) 72 | elif isinstance(m, nn.BatchNorm3d): 73 | 
init.normal_(m.weight.data, mean=1, std=0.02) 74 | init.constant_(m.bias.data, 0) 75 | elif isinstance(m, nn.Linear): 76 | init.xavier_normal_(m.weight.data) 77 | init.normal_(m.bias.data) 78 | elif isinstance(m, nn.LSTM): 79 | for param in m.parameters(): 80 | if len(param.shape) >= 2: 81 | init.orthogonal_(param.data) 82 | else: 83 | init.normal_(param.data) 84 | elif isinstance(m, nn.LSTMCell): 85 | for param in m.parameters(): 86 | if len(param.shape) >= 2: 87 | init.orthogonal_(param.data) 88 | else: 89 | init.normal_(param.data) 90 | elif isinstance(m, nn.GRU): 91 | for param in m.parameters(): 92 | if len(param.shape) >= 2: 93 | init.orthogonal_(param.data) 94 | else: 95 | init.normal_(param.data) 96 | elif isinstance(m, nn.GRUCell): 97 | for param in m.parameters(): 98 | if len(param.shape) >= 2: 99 | init.orthogonal_(param.data) 100 | else: 101 | init.normal_(param.data) 102 | 103 | 104 | S_K = SK_test().to(device) 105 | S_D = SD_test().to(device) 106 | 107 | 108 | optimizerS_K = optim.Adam(S_K.parameters(), lr=1e-5) 109 | optimizerS_D = optim.SGD(S_D.parameters(), lr=2e-5) 110 | 111 | # Assuming optimizer uses lr = 0.05 for all groups 112 | #example 113 | # lr = 0.05 if epoch < 30 114 | # lr = 0.005 if 30 <= epoch < 60 115 | # lr = 0.0005 if 60 <= epoch < 90 116 | scheduler_S_K = optim.lr_scheduler.StepLR(optimizerS_K, step_size=20, gamma=0.8) 117 | scheduler_S_D = optim.lr_scheduler.StepLR(optimizerS_D, step_size=20, gamma=0.8) 118 | 119 | 120 | # configure training record 121 | writer = SummaryWriter() 122 | 123 | 124 | # mode=0 -> first train 125 | # mode=1 -> continue train 126 | mode = 0 127 | 128 | if mode==0: 129 | print("first train") 130 | 131 | time_list = [] 132 | S_K_iter_loss_list = [] 133 | reconstruct_iter_loss_list = [] 134 | diversity_iter_loss_list = [] 135 | S_D_real_iter_loss_list = [] 136 | S_D_fake_iter_loss_list = [] 137 | S_D_total_iter_loss_list = [] 138 | 139 | S_K.apply(weights_init) 140 | S_D.apply(weights_init) 141 | S_K.train() 142 | S_D.train() 143 | elif mode==1: 144 | print("continue train") 145 | checkpoint_loss = torch.load(PATH_record) 146 | time_list = checkpoint_loss['time_list']; #print(time_list) 147 | 148 | iteration = len(time_list)-1 149 | PATH_model_load = "{}{}{:0>7d}{}".format(PATH_model, "/iter_", iteration, ".tar"); #print(PATH_model_load) 150 | checkpoint_model = torch.load(PATH_model_load) 151 | S_K.load_state_dict(checkpoint_model['S_K_state_dict']) 152 | S_D.load_state_dict(checkpoint_model['S_D_state_dict']) 153 | optimizerS_K.load_state_dict(checkpoint_model['optimizerS_K_state_dict']) 154 | optimizerS_D.load_state_dict(checkpoint_model['optimizerS_D_state_dict']) 155 | S_K.train() 156 | S_D.train() 157 | 158 | S_K_iter_loss_list = checkpoint_loss['S_K_iter_loss_list'] 159 | reconstruct_iter_loss_list = checkpoint_loss['reconstruct_iter_loss_list'] 160 | diversity_iter_loss_list = checkpoint_loss['diversity_iter_loss_list'] 161 | S_D_real_iter_loss_list = checkpoint_loss['S_D_real_iter_loss_list'] 162 | S_D_fake_iter_loss_list = checkpoint_loss['S_D_fake_iter_loss_list'] 163 | S_D_total_iter_loss_list = checkpoint_loss['S_D_total_iter_loss_list'] 164 | 165 | 166 | # draw previous loss 167 | for idx in range(len(time_list)): 168 | writer.add_scalar("loss/S_K", S_K_iter_loss_list[idx], idx, time_list[idx]) 169 | writer.add_scalar("loss/reconstruction", reconstruct_iter_loss_list[idx], idx, time_list[idx]) 170 | writer.add_scalar("loss/diversity", diversity_iter_loss_list[idx], idx, time_list[idx]) 171 | 
writer.add_scalar("loss/S_D_real", S_D_real_iter_loss_list[idx], idx, time_list[idx]) 172 | writer.add_scalar("loss/S_D_fake", S_D_fake_iter_loss_list[idx], idx, time_list[idx]) 173 | writer.add_scalar("loss/S_D_total", S_D_total_iter_loss_list[idx], idx, time_list[idx]) 174 | else: 175 | print("please select mode 0 or 1") 176 | 177 | 178 | 179 | criterion = nn.BCELoss() 180 | 181 | 182 | for epoch in range(EPOCH): 183 | # random feature index 184 | random.shuffle(video["feature"]) 185 | random.shuffle(summary["feature"]) 186 | 187 | # decay lr 188 | scheduler_S_K.step() 189 | scheduler_S_D.step() 190 | 191 | 192 | tqdm_range = tqdm.trange(len(video["feature"])) 193 | for i in tqdm_range: # video["feature"] -> [[1,1024,1,A], [1,1024,1,B]...] 194 | tqdm_range.set_description(" Epoch: {:0>5d}, Running current iter {:0>3d} ...".format(epoch+1, i+1)) 195 | 196 | vd = video["feature"][i] 197 | sd = summary["feature"][i] 198 | 199 | ############## 200 | # update S_K # 201 | ############## 202 | S_K.zero_grad() 203 | 204 | #S_K_summary,column_mask = S_K(vd) 205 | S_K_summary,index_mask,_ = S_K(vd) 206 | output = S_D(S_K_summary) 207 | label = torch.full((1,), 1, device=device) 208 | 209 | # adv. loss 210 | errS_K = criterion(output, label) 211 | 212 | ###old reconstruct### 213 | # index = torch.tensor(column_mask, device=device) 214 | # select_vd = torch.index_select(vd, 3, index) 215 | # reconstruct_loss = torch.norm(S_K_summary-select_vd, p=2)**2 216 | # reconstruct_loss /= len(column_mask) 217 | ###old reconstruct### 218 | 219 | ###new reconstruct### 220 | #reconstruct_loss = torch.sum((S_K_summary-vd)**2 * index_mask) / torch.sum(index_mask) # [1,1024,1,S]-[1,1024,1,T] 221 | ###new reconstruct### 222 | 223 | 224 | # diversity 225 | # S_K_summary = index_mask*S_K_summary 226 | # S_K_summary_reshape = S_K_summary.view(S_K_summary.shape[1], S_K_summary.shape[3]) 227 | # norm_div = torch.norm(S_K_summary_reshape, 2, 0, True) 228 | # S_K_summary_reshape = S_K_summary_reshape/norm_div 229 | # loss_matrix = S_K_summary_reshape.transpose(1, 0).mm(S_K_summary_reshape) 230 | # diversity_loss = loss_matrix.sum() - loss_matrix.trace() 231 | # #diversity_loss = diversity_loss/len(column_mask)/(len(column_mask)-1) 232 | # diversity_loss = diversity_loss/(torch.sum(index_mask))/(torch.sum(index_mask)-1) 233 | 234 | ######################## LOSS FROM FCSN ######################### 235 | # 2D 1D conversion 236 | outputs_reconstruct = S_K_summary.view(1,1024,-1) # [1,1024,1,T] -> [1,1024,T] 237 | mask = index_mask.view(1,1,-1) # [1,1,1,T] -> [1,1,T] 238 | feature = vd.view(1,1024,-1) # [1,1024,1,T] -> [1,1024,T] 239 | 240 | #print(i, mask.view(-1)) 241 | 242 | # reconst. 
loss: changed to compute per batch, then average
243 |         feature_select = feature*mask # [1,1024,T]
244 |         outputs_reconstruct_select = outputs_reconstruct*mask # [1,1024,T]
245 |         feature_diff_1 = torch.sum((feature_select-outputs_reconstruct_select)**2, dim=1) # [1,T]
246 |         feature_diff_1 = torch.sum(feature_diff_1, dim=1) # [1]
247 | 
248 |         mask_sum = torch.sum(mask, dim=2) # [1,1]
249 |         mask_sum = torch.sum(mask_sum, dim=1) # [1]
250 | 
251 |         reconstruct_loss = torch.mean(feature_diff_1/mask_sum) # scalar
252 | 
253 | 
254 |         # diversity loss
255 |         batch_size, feat_size, frames = outputs_reconstruct.shape # [1,1024,T]
256 | 
257 |         outputs_reconstruct_norm = torch.norm(outputs_reconstruct, p=2, dim=1, keepdim=True) # [1,1,T]
258 | 
259 |         normalized_outputs_reconstruct = outputs_reconstruct/outputs_reconstruct_norm # [1,1024,T]
260 | 
261 |         normalized_outputs_reconstruct_reshape = normalized_outputs_reconstruct.permute(0, 2, 1) # [1,T,1024]
262 | 
263 |         similarity_matrix = torch.bmm(normalized_outputs_reconstruct_reshape, normalized_outputs_reconstruct) # [1,T,T]
264 | 
265 |         mask_trans = mask.permute(0,2,1) # [1,T,1]
266 |         mask_matrix = torch.bmm(mask_trans, mask) # [1,T,T]
267 |         # filter out non-key frames
268 |         similarity_matrix_filtered = similarity_matrix*mask_matrix # [1,T,T]
269 | 
270 |         diversity_loss = 0
271 |         acc_batch_size = 0
272 |         for j in range(batch_size):
273 |             batch_similarity_matrix_filtered = similarity_matrix_filtered[j,:,:] # [T,T]
274 |             batch_mask = mask[j,:,:] # [1,T]
275 |             if batch_mask.sum() < 2:
276 |                 #print("select less than 2 frames", batch_mask.sum())
277 |                 batch_diversity_loss = 0
278 |             else:
279 |                 batch_diversity_loss = (batch_similarity_matrix_filtered.sum()-batch_similarity_matrix_filtered.trace())/(batch_mask.sum()*(batch_mask.sum()-1))
280 |                 acc_batch_size += 1
281 | 
282 |             diversity_loss += batch_diversity_loss
283 | 
284 |         if acc_batch_size>0:
285 |             diversity_loss /= acc_batch_size
286 |             #print(acc_batch_size)
287 |         else:
288 |             diversity_loss = 0
289 | 
290 | 
291 |         # sparsity loss
292 |         # sigma = 0.3
293 |         # mask_mean = torch.mean(mask, dim=2) # [1,1]
294 |         # mask_mean = torch.sum(mask_mean, dim=1) # [1]
295 |         # sigma_vector = torch.ones([batch_size], device=device)*sigma # [1]
296 |         # sparsity_loss = torch.mean((sigma_vector-mask_mean)**2)
297 | 
298 |         S_K_total_loss = errS_K + reconstruction_error_coeff*reconstruct_loss + diversity_error_coeff*diversity_loss # for summe dataset beta=1
299 |         #S_K_total_loss = errS_K+reconstruct_loss
300 |         S_K_total_loss.backward()
301 | 
302 |         # print grad
303 |         #print("weight grad:", S_K.inc.conv.conv[0].weight.grad)
304 | 
305 | 
306 |         # update
307 |         optimizerS_K.step()
308 | 
309 |         ##############
310 |         # update S_D #
311 |         ##############
312 |         # update every 5 epochs
313 |         #if((epoch+1)%5==0):
314 |         S_D.zero_grad()
315 | 
316 |         # real summary #
317 |         output = S_D(sd)
318 |         label.fill_(1)
319 |         err_S_D_real = criterion(output, label)
320 |         err_S_D_real.backward()
321 | 
322 |         # fake summary #
323 |         S_K_summary,_,_ = S_K(vd)
324 |         output = S_D(S_K_summary.detach()); #print(S_K_summary)
325 |         label.fill_(0)
326 |         err_S_D_fake = criterion(output, label)
327 |         err_S_D_fake.backward()
328 | 
329 |         S_D_total_loss = err_S_D_real+err_S_D_fake
330 | 
331 |         # print grad
332 |         #print("weight grad:", S_D.inc.conv.conv[0].weight.grad)
333 | 
334 |         optimizerS_D.step()
335 |         # else:
336 |         #     err_S_D_real = -1.0
337 |         #     err_S_D_fake = -1.0
338 |         #     S_D_total_loss = err_S_D_real+err_S_D_fake
339 | 
340 |         # record
341 |         time_list.append(time.time())
342 |         S_K_iter_loss_list.append(errS_K)
343 |         
340 | # record 341 | time_list.append(time.time()) 342 | S_K_iter_loss_list.append(float(errS_K)) # store plain floats so the autograd graphs are not kept alive 343 | reconstruct_iter_loss_list.append(float(reconstruction_error_coeff*reconstruct_loss)) 344 | diversity_iter_loss_list.append(float(diversity_error_coeff*diversity_loss)) 345 | S_D_real_iter_loss_list.append(float(err_S_D_real)) 346 | S_D_fake_iter_loss_list.append(float(err_S_D_fake)) 347 | S_D_total_iter_loss_list.append(float(S_D_total_loss)) 348 | 349 | iteration = len(time_list)-1 350 | 351 | if ((iteration+1)%(150*50)==0): # save every 50 epochs (assuming 150 iterations per epoch) 352 | PATH_model_save = "{}{}{:0>7d}{}".format(PATH_model, "/iter_", iteration, "_3.tar") 353 | S_K_state_dict = S_K.state_dict() 354 | optimizerS_K_state_dict = optimizerS_K.state_dict() 355 | S_D_state_dict = S_D.state_dict() 356 | optimizerS_D_state_dict = optimizerS_D.state_dict() 357 | 358 | torch.save({ 359 | "S_K_state_dict": S_K_state_dict, 360 | "optimizerS_K_state_dict": optimizerS_K_state_dict, 361 | "S_D_state_dict": S_D_state_dict, 362 | "optimizerS_D_state_dict": optimizerS_D_state_dict 363 | }, PATH_model_save) 364 | 365 | print("model is saved to {}".format(PATH_model_save)) 366 | 367 | torch.save({ 368 | "S_K_iter_loss_list": S_K_iter_loss_list, 369 | "reconstruct_iter_loss_list": reconstruct_iter_loss_list, 370 | "diversity_iter_loss_list": diversity_iter_loss_list, 371 | "S_D_real_iter_loss_list": S_D_real_iter_loss_list, 372 | "S_D_fake_iter_loss_list": S_D_fake_iter_loss_list, 373 | "S_D_total_iter_loss_list": S_D_total_iter_loss_list, 374 | "time_list": time_list 375 | }, PATH_record) 376 | 377 | print("loss record is saved to {}".format(PATH_record)) 378 | 379 | print("key frame prob", mask.view(-1)) 380 | 381 | 382 | # send to tensorboard 383 | writer.add_scalar("loss/S_K", S_K_iter_loss_list[iteration], iteration, time_list[iteration]) # tag, Y, X -> use add_scalar when Y is a single value 384 | writer.add_scalar("loss/reconstruction", reconstruct_iter_loss_list[iteration], iteration, time_list[iteration]) 385 | writer.add_scalar("loss/diversity", diversity_iter_loss_list[iteration], iteration, time_list[iteration]) 386 | writer.add_scalar("loss/S_D_real", S_D_real_iter_loss_list[iteration], iteration, time_list[iteration]) 387 | writer.add_scalar("loss/S_D_fake", S_D_fake_iter_loss_list[iteration], iteration, time_list[iteration]) 388 | writer.add_scalar("loss/S_D_total", S_D_total_iter_loss_list[iteration], iteration, time_list[iteration]) 389 | 390 | 391 | writer.close()
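# Restoring one of the checkpoints saved above is symmetric; a minimal sketch (run it separately,
# and it assumes S_K, S_D and both optimizers have been constructed exactly as in this script):
#   checkpoint = torch.load(PATH_model_save)
#   S_K.load_state_dict(checkpoint["S_K_state_dict"])
#   optimizerS_K.load_state_dict(checkpoint["optimizerS_K_state_dict"])
#   S_D.load_state_dict(checkpoint["S_D_state_dict"])
#   optimizerS_D.load_state_dict(checkpoint["optimizerS_D_state_dict"])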
-------------------------------------------------------------------------------- /training_set_preparation/FeatureExtractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | # torchvision0.2.2 5 | #from googlenet import googlenet 6 | import time 7 | 8 | class FeatureExtractor(nn.Module): 9 | def __init__(self): 10 | # expected input format (N,C,L); C: #features, L: #frames 11 | super(FeatureExtractor, self).__init__() 12 | # GPU 13 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | 15 | # torchvision0.3.0 16 | self.googlenet = torchvision.models.googlenet(pretrained=True) 17 | # use eval mode to do feature extraction 18 | self.googlenet.eval() 19 | 20 | # we only want features, no grads 21 | for param in self.googlenet.parameters(): 22 | param.requires_grad = False 23 | 24 | # feature extractor 25 | self.model = nn.Sequential(*list(self.googlenet.children())[:-2]) 26 | 27 | self.model = nn.DataParallel(self.model) 28 | self.model.to(self.device) 29 | 30 | def forward(self, x): 31 | # move data onto the device 32 | x = x.to(self.device) 33 | 34 | h = self.model(x) 35 | 36 | h = h.view(h.size()[0],1024) 37 | h = h.transpose(1,0) 38 | 39 | return h 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 45 | 46 | net = FeatureExtractor() 47 | 48 | 49 | #net = nn.DataParallel(net) 50 | 51 | #net.to(device) 52 | 53 | #data = torch.randn((20, 3, 299, 299)) # (N,C,299,299) inceptionv3 input, otherwise (N,C,224,224) 54 | data = torch.randn((616, 3, 224, 224)) 55 | #tic = time.time() 56 | #data = data.to(device) 57 | result = net(data) 58 | print(result.requires_grad) 59 | #toc = time.time() 60 | 61 | #print(toc-tic) GPU:0.1sec CPU:15.4sec 62 | 63 | #print(net) 64 | #print(net(data).size()) 65 | #print(net(data)) -------------------------------------------------------------------------------- /training_set_preparation/extract_frame.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | # no spaces around "=" in shell assignments 4 | # "~" does not seem to expand to the user's home path here 5 | #video_path="/media/data/PTec131b/VideoSum/training_data/video" 6 | #video_frame_path="/media/data/PTec131b/VideoSum/training_data/video_frame" 7 | 8 | 9 | 10 | #video_path="/media/data/PTec131b/VideoSum/training_data/video_wang" 11 | #video_frame_path="/media/data/PTec131b/VideoSum/training_data/video_frame_wang" 12 | 13 | video_path=${1} 14 | video_frame_path=${2} 15 | 16 | skip_start_time=10 17 | skip_end_time=20 18 | 19 | for file in $(ls ${video_path}) # appending a colon here would leave an extra trailing colon!!! 20 | do 21 | #echo ${file} 22 | #echo "Extracting frames of ${file} ..." 23 | dir_name="${video_frame_path}/${file%????}" # each ? after % strips one trailing character (here the 4-char extension) 24 | 25 | # skip videos that have already been processed 26 | if [ -d ${dir_name} ] 27 | then 28 | continue 29 | else 30 | mkdir ${dir_name} 31 | fi 32 | 33 | file_path="${video_path}/${file}" 34 | # cut -c: keep only characters 13-20 35 | # awk -F:: split on ":" instead of the default whitespace 36 | duration=$(ffmpeg -i ${file_path} 2>&1 | grep "Duration" | cut -c 13-20 | awk -F: '{ print $1*3600+$2*60+$3 }') 37 | interval=$((duration-skip_start_time-skip_end_time)) # remaining duration 38 | 39 | 40 | frame_path="${dir_name}/${file%????}_%04d.jpg" 41 | # skip the head and the tail of the video 42 | #ffmpeg -i ${file_path} -ss ${skip_start_time} -t ${interval} -vf fps=2 ${frame_path} 2>&1 | grep "Input" # sample one frame every 0.5 sec (2 fps); -ss drops the first seconds 43 | # no skip 44 | ffmpeg -i ${file_path} -vf fps=2 ${frame_path} 2>&1 | grep "Input" 45 | done --------------------------------------------------------------------------------
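The preparation scripts below reshape per-frame GoogLeNet pool5 features into the [1,1024,1,T] layout that FCSN consumes. A minimal standalone sketch of that reshaping (not part of the repository; the random array stands in for a dataset[key]["features"][...] read):

import numpy as np
import torch

features = np.random.randn(240, 1024).astype(np.float32)   # [T,1024] per-frame features
f = torch.from_numpy(features)                              # [T,1024]
f = f.transpose(1, 0).view(1, 1024, 1, features.shape[0])   # [1,1024,1,T]
print(f.shape)                                              # torch.Size([1, 1024, 1, 240])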
dataset_name=="summe": 29 | # mimic random selection 30 | random.shuffle(keys) 31 | keys = keys[6:] 32 | 33 | for key in keys: 34 | attributes = collections.OrderedDict() 35 | 36 | new_key = "{}_{}".format(dataset_name, key) 37 | 38 | feature_video_cuda = torch.from_numpy(dataset[key]["features"][...]).to(device) 39 | feature_video_cuda = feature_video_cuda.transpose(1,0).view(1,1024,1,feature_video_cuda.shape[0]) 40 | attributes["video_features"] = feature_video_cuda; #print(torch.isnan(feature_video_cuda).nonzero().view(-1)) 41 | 42 | gt_summary = torch.from_numpy(dataset[key]["gtsummary"][...]).to(device) 43 | column_index = gt_summary.nonzero().view(-1) 44 | feature_summary_cuda = torch.from_numpy(dataset[key]["features"][...]).to(device) 45 | feature_summary_cuda = feature_summary_cuda.transpose(1,0).view(1,1024,1,feature_summary_cuda.shape[0]) 46 | feature_summary_cuda = torch.index_select(feature_summary_cuda, 3, column_index) 47 | attributes["summary_features"] = feature_summary_cuda; #print(torch.isnan(feature_summary_cuda).nonzero().view(-1)) 48 | 49 | dataset_reorganized[new_key] = attributes; #print(new_key, dataset_reorganized[new_key]["video_features"].shape, dataset_reorganized[new_key]["summary_features"].shape) 50 | 51 | 52 | 53 | dataset_reorganized_keys_list = list(dataset_reorganized.keys()); #print(len(dataset_reorganized_keys_list)) 54 | 55 | # randomized to mimic random selection 56 | random.shuffle(dataset_reorganized_keys_list) 57 | # 50% video, 50% summary 58 | half_index = len(dataset_reorganized_keys_list)//2 59 | 60 | 61 | video_feature_data_list = [] 62 | summary_feature_data_list = [] 63 | for idx,video_name in enumerate(dataset_reorganized_keys_list): 64 | if(idxCHW [0,255]->[0.0,1.0] 43 | transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5)) # [0.0,1.0] -> [-1.0,1.0] official 那組是從imageNet出來的 44 | ])) 45 | 46 | feature_data_list = [] # 裡面存每個video sub frames的feature i.e. [[1,1024,1,v1], [1,1024,1,v2]......] 47 | frame_list = [] # 每個影片裡面有多少個frame i.e. [v1, v2......] 48 | name_list = dataset.classes # 存每部影片的名字 i.e. [Gordon, James......] 49 | 50 | 51 | # get the frame_count of each video -> frame_list 52 | tqdm_range = tqdm.trange(len(dataset.classes)) 53 | for video_idx in tqdm_range: 54 | tqdm_range.set_description(" Extracting Features from {}".format(name_list[video_idx])) 55 | temp_list = [i for i in dataset.imgs if i[1] == video_idx] # 用來數這部影片有多少個frame 56 | frame_list.append(len(temp_list)) # 得到該video的frame數量並儲存, frame_list=[1143, 2242, ......] 
51 | # get the frame_count of each video -> frame_list 52 | tqdm_range = tqdm.trange(len(dataset.classes)) 53 | for video_idx in tqdm_range: 54 | tqdm_range.set_description(" Extracting Features from {}".format(name_list[video_idx])) 55 | temp_list = [i for i in dataset.imgs if i[1] == video_idx] # used to count how many frames this video has 56 | frame_list.append(len(temp_list)) # store this video's frame count, frame_list=[1143, 2242, ......] 57 | 58 | video_images = torch.randn((len(temp_list),3,self.resize,self.resize)) # allocate a buffer that holds every frame of this video 59 | 60 | # first write all frames of the video into video_images; googlenet extracts features from it afterwards 61 | for frame in range(frame_list[video_idx]): 62 | # dataset is indexed globally across all videos, so offset 63 | # by the total frame count of all previous videos 64 | global_idx = frame + sum(frame_list[:video_idx]) 65 | video_images[frame] = dataset[global_idx][0].view(3,self.resize,self.resize) 66 | 67 | # the GPU cannot run feature extraction on all frames of a video at once, so split them into chunks 68 | video_images_subs = torch.split(video_images, 1500, dim=0) # chunk size: the first chunk holds up to 1500 frames 69 | video_images_subs = list(video_images_subs) # convert to list [[1500,3,224,224], [rest,3,224,224]] 70 | 71 | # extract the features chunk by chunk 72 | for idx,sub in enumerate(video_images_subs): 73 | sub_gpu = sub.to(self.device); #print(sub_gpu.shape) 74 | sub_feature_data = self.net(sub_gpu) # [1024,T] 75 | 76 | if(idx == 0): 77 | cat_sub_feature = sub_feature_data 78 | #print(cat_sub_feature.shape) 79 | else: 80 | cat_sub_feature = torch.cat((cat_sub_feature,sub_feature_data),1) 81 | #print(cat_sub_feature.shape) 82 | 83 | # release gpu memory 84 | sub_gpu = sub_gpu.cpu() 85 | sub_feature_data = sub_feature_data.cpu() 86 | torch.cuda.empty_cache() 87 | 88 | 89 | cat_sub_feature = cat_sub_feature.view(1,1024,1,cat_sub_feature.size()[1]); 90 | 91 | #print(cat_sub_feature) 92 | feature_data_list.append(cat_sub_feature); #print(cat_sub_feature.requires_grad) 93 | 94 | # release gpu memory 95 | cat_sub_feature = cat_sub_feature.cpu() 96 | torch.cuda.empty_cache() 97 | 98 | 99 | torch.save({"feature":feature_data_list, "name_list":name_list, "frame_list":frame_list}, self.save_path[index]) 100 | 101 | # print save result 102 | for i,feature in enumerate(feature_data_list): 103 | print(name_list[i], frame_list[i], feature.shape) 104 | 105 | 106 | def test(self): 107 | for index in range(len(self.video_path)): 108 | arg_1 = self.video_path[index] 109 | arg_2 = self.root_path[index] 110 | subprocess.call(["./extract_frame.sh", arg_1, arg_2]) # extract the frames first 111 | 112 | 
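# The .tar written by pre_process above can be reloaded directly; a minimal sketch
# (the path is a placeholder, the keys match the torch.save call above):
#   saved = torch.load("/path/to/video_frame_pool5.tar")
#   for name, n_frames, feat in zip(saved["name_list"], saved["frame_list"], saved["feature"]):
#       print(name, n_frames, feat.shape)   # feat: [1,1024,1,T]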
"/media/data/PTec131b/VideoSum/training_data/video_frame_wang/video_frame_pool5.tar", 146 | "/media/data/PTec131b/VideoSum/training_data/summary_frame_wang/summary_frame_pool5.tar" 147 | #"/media/data/PTec131b/VideoSum/testing_data/video_frame/video_frame.tar" 148 | ] 149 | process = PreProcess(video_path_list, root_path_list, save_path_list) 150 | #process.test() 151 | process.pre_process() 152 | #a = "{}{}{:0>5d}{}".format("./saved_models", "/iter_", 5, ".tar") 153 | #print(a) --------------------------------------------------------------------------------