├── README.md ├── losses.py ├── hyperparameters.py ├── dynanmics.py ├── train_AE_BN3_kth_1_2.py ├── utils.py ├── vision_BN3_1.py ├── ops.py └── train_all_AE_BN3_kth_1_2.py /README.md: -------------------------------------------------------------------------------- 1 | # Accurate Grid Keypoint Learning for Efficient Video Prediction 2 | 0. Dataset processing follows https://github.com/edenton/svg 3 | create the "data" folder 4 | 1. run train_AE_BN3_kth_1_2.py to detect keypoints 5 | 2. run train_all_AE_BN3_kth_1_2.py to learn keypoint dynamics 6 | 7 | 8 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Losses for the video representation model.""" 18 | 19 | import torch 20 | import torch.nn as nn 21 | import numpy as np 22 | import ops 23 | 24 | def temporal_separation_loss(cfg, coords): 25 | """Encourages keypoint to have different temporal trajectories. 26 | 27 | If two keypoints move along trajectories that are identical up to a time- 28 | invariant translation (offset), this suggest that they both represent the same 29 | object and are redundant, which we want to avoid. 30 | 31 | To measure this similarity of trajectories, we first center each trajectory by 32 | subtracting its mean. Then, we compute the pairwise distance between all 33 | trajectories at each timepoint. These distances are higher for trajectories 34 | that are less similar. To compute the loss, the distances are transformed by 35 | a Gaussian and averaged across time and across trajectories. 36 | 37 | Args: 38 | cfg: ConfigDict. 39 | coords: [time, batch, num_landmarks, 3] coordinate tensor. 40 | 41 | Returns: 42 | Separation loss. 43 | """ 44 | x = coords[Ellipsis, 0] 45 | y = coords[Ellipsis, 1] 46 | 47 | # Center trajectories: 48 | x = x - torch.mean(x, dim=0, keepdim=True) 49 | y = y - torch.mean(y, dim=0, keepdim=True) 50 | 51 | # Compute pairwise distance matrix: 52 | d = ((x[:, :, :, np.newaxis] - x[:, :, np.newaxis, :]) ** 2.0 + 53 | (y[:, :, :, np.newaxis] - y[:, :, np.newaxis, :]) ** 2.0) 54 | 55 | # Temporal mean: 56 | d = torch.mean(d, dim=0) 57 | 58 | # Apply Gaussian function such that loss falls off with distance: 59 | loss_matrix = torch.exp(-d / (2.0 * cfg.separation_loss_sigma ** 2.0)) 60 | loss_matrix = torch.mean(loss_matrix, dim=0) # Mean across batch. 61 | loss = torch.sum(loss_matrix) # Sum matrix elements. 62 | 63 | # Subtract sum of values on diagonal, which are always 1: 64 | loss = loss - cfg.num_keypoints 65 | 66 | # Normalize by maximal possible value. 
The loss is now scaled between 0 (all 67 | # keypoints are infinitely far apart) and 1 (all keypoints are at the same 68 | # location): 69 | loss = loss / (cfg.num_keypoints * (cfg.num_keypoints - 1)) 70 | 71 | 72 | return cfg.separation_loss_scale * loss 73 | 74 | def sparse_loss(weight_matrix, cfg): 75 | """L1-loss on mean heatmap activations, to encourage sparsity.""" 76 | weight_shape = weight_matrix.shape 77 | assert len(weight_shape) == 5, weight_shape 78 | 79 | heatmap_mean = torch.mean(weight_matrix, dim=(3, 4)) 80 | penalty = torch.mean(torch.abs(heatmap_mean)) 81 | 82 | return penalty * cfg.heatmap_regularization 83 | 84 | def detect_loss(weight_matrix, gama=0.01): 85 | weight_shape = weight_matrix.shape 86 | assert len(weight_shape) == 5, weight_shape 87 | weight_matrix1 = weight_matrix.view(weight_matrix.size(0), 88 | weight_matrix.size(1), weight_matrix.size(2), -1) 89 | heatmap_max = torch.max(weight_matrix1, dim=-1)[0] 90 | heatmap_mean = torch.mean(weight_matrix1, dim=-1) 91 | max_min = torch.min(heatmap_max-heatmap_mean, dim=-1)[0] 92 | incite = -torch.sum(max_min) 93 | return gama*incite 94 | -------------------------------------------------------------------------------- /hyperparameters.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
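# Illustrative sketch (not part of the original scripts): how the config returned
# by get_config() below feeds the loss functions defined in losses.py above.
# Tensor shapes follow the docstrings in losses.py.
#
#     import torch
#     import losses, hyperparameters
#
#     cfg = hyperparameters.get_config()
#     T, B, K = 8, 4, cfg.num_keypoints
#     coords = torch.rand(T, B, K, 3) * 2 - 1     # [time, batch, keypoints, (x, y, scale)]
#     heatmaps = torch.rand(T, B, K, cfg.heatmap_width, cfg.heatmap_width)
#
#     sep = losses.temporal_separation_loss(cfg, coords)   # penalizes redundant trajectories
#     spa = losses.sparse_loss(heatmaps, cfg)              # L1 penalty on mean heatmap activation
#     det = losses.detect_loss(heatmaps)                   # rewards peaked (max >> mean) heatmaps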
15 | 16 | # Lint as: python3 17 | """Hyperparameters of the structured video prediction models.""" 18 | 19 | 20 | class ConfigDict(dict): 21 | """A dictionary whose keys can be accessed as attributes.""" 22 | 23 | def __getattr__(self, name): 24 | try: 25 | return self[name] 26 | except KeyError: 27 | raise AttributeError(name) 28 | 29 | def __setattr__(self, name, value): 30 | self[name] = value 31 | 32 | def get(self, key, default=None): 33 | """Allows to specify defaults when accessing the config.""" 34 | if key not in self: 35 | return default 36 | return self[key] 37 | 38 | 39 | def get_config(): 40 | """Default values for all hyperparameters.""" 41 | 42 | cfg = ConfigDict() 43 | 44 | cfg.seed = 1 45 | 46 | # Directories: 47 | cfg.dataset = 'Human3.6m' 48 | cfg.data_root = '../Dataset' 49 | cfg.train_dir = 'train_list.txt' 50 | cfg.test_dir = 'test_list.txt' 51 | 52 | # Architecture: 53 | cfg.layers_per_scale = 2 54 | cfg.conv_layer_kwargs = _conv_layer_kwargs() 55 | cfg.dense_layer_kwargs = _dense_layer_kwargs() 56 | 57 | # Optimization: 58 | cfg.batch_size = 32 59 | cfg.test_batch_size = 8 60 | cfg.test_N = 224 61 | cfg.steps_per_epoch = 600//cfg.batch_size 62 | cfg.num_epochs = 1000 63 | cfg.learning_rate = 0.001 64 | cfg.clipnorm = 1 65 | 66 | # Image sequence parameters: 67 | cfg.observed_steps = 10 68 | cfg.predicted_steps = 10 69 | cfg.seq_len = cfg.observed_steps + cfg.predicted_steps 70 | cfg.n_eval = 50 71 | cfg.img_w = 64 72 | cfg.img_h = 64 73 | 74 | # Keypoint encoding settings: 75 | cfg.num_keypoints = 48 76 | cfg.heatmap_width = 16 77 | cfg.heatmap_regularization = 1e-2 78 | cfg.keypoint_width = 1.5 79 | cfg.num_encoder_filters = 32 80 | cfg.separation_loss_scale = 2e-2 81 | cfg.separation_loss_sigma = 2e-3 82 | cfg.reg_lambda = 1e-4 83 | 84 | # Agent settings: 85 | cfg.hidden_size = 128 86 | cfg.n_layers = 2 87 | 88 | # Dynamics: 89 | cfg.num_rnn_units = 512 90 | cfg.convgru_units = 32 91 | cfg.convprior_net_dim = 32 92 | cfg.convposterior_net_dim = 32 93 | cfg.prior_net_dim = 4 94 | cfg.posterior_net_dim = 128 95 | cfg.latent_code_size = 16 96 | cfg.kl_loss_scale = 1e-2 97 | cfg.kl_annealing_steps = 1000 98 | cfg.use_deterministic_belief = False 99 | cfg.scheduled_sampling_ramp_steps = (cfg.steps_per_epoch * int(cfg.num_epochs * 0.8)) 100 | cfg.scheduled_sampling_p_true_start_obs = 1.0 101 | cfg.scheduled_sampling_p_true_end_obs = 0.1 102 | cfg.scheduled_sampling_p_true_start_pred = 1.0 103 | cfg.scheduled_sampling_p_true_end_pred = 0.5 104 | cfg.num_samples_for_bom = 10 105 | cfg.nsample = 100 106 | 107 | return cfg 108 | 109 | 110 | def _conv_layer_kwargs(): 111 | """Returns a configDict with default conv layer hyperparameters.""" 112 | 113 | cfg = ConfigDict() 114 | 115 | cfg.kernel_size = 3 116 | cfg.padding = 1 117 | #cfg.activation = tf.nn.leaky_relu 118 | #cfg.kernel_regularizer = tf.keras.regularizers.l2(1e-4) 119 | 120 | # He-uniform initialization is suggested by this paper: 121 | # https://arxiv.org/abs/1803.01719 122 | # The paper only considers ReLU units and it might be different for leaky 123 | # ReLU, but it is a better guess than Glorot. 
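  # Illustrative note: ConfigDict supports attribute-style access, so
  # cfg.kernel_size == cfg['kernel_size'], and this kwargs dict is unpacked
  # directly into PyTorch layers elsewhere in the repo, e.g. in
  # vision_BN3_1.build_image_encoder:
  #     nn.Conv2d(in_channels, out_channels, **cfg.conv_layer_kwargs)
  # which is why only arguments accepted by nn.Conv2d (kernel_size, padding)
  # are left uncommented here.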
124 | #cfg.kernel_initializer = 'he_uniform' 125 | 126 | return cfg 127 | 128 | 129 | def _dense_layer_kwargs(): 130 | """Returns a configDict with default dense layer hyperparameters.""" 131 | 132 | cfg = ConfigDict() 133 | #cfg.activation = tf.nn.relu 134 | #cfg.kernel_initializer = 'he_uniform' 135 | 136 | return cfg 137 | -------------------------------------------------------------------------------- /dynanmics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import ops 5 | import torch.nn.functional as F 6 | 7 | %================================================================================ 8 | class convLSTMCell(nn.Module): 9 | 10 | def __init__(self, input_dim, hidden_dim, kernel_size): 11 | """ 12 | Initialize ConvLSTM cell. 13 | Parameters 14 | ---------- 15 | input_dim: int 16 | Number of channels of input tensor. 17 | hidden_dim: int 18 | Number of channels of hidden state. 19 | kernel_size: int 20 | Size of the convolutional kernel. 21 | bias: bool 22 | Whether or not to add the bias. 23 | """ 24 | 25 | super(convLSTMCell, self).__init__() 26 | 27 | self.input_dim = input_dim 28 | self.hidden_dim = hidden_dim 29 | 30 | self.kernel_size = kernel_size 31 | self.padding = kernel_size // 2 32 | 33 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 34 | out_channels=4 * self.hidden_dim, 35 | kernel_size=self.kernel_size, 36 | padding=self.padding, bias=False) 37 | 38 | def forward(self, input_tensor, cur_state): 39 | h_cur, c_cur = cur_state 40 | 41 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 42 | 43 | combined_conv = self.conv(combined) 44 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 45 | i = torch.sigmoid(cc_i) 46 | f = torch.sigmoid(cc_f) 47 | o = torch.sigmoid(cc_o) 48 | g = torch.tanh(cc_g) 49 | 50 | c_next = f * c_cur + i * g 51 | h_next = o * torch.tanh(c_next) 52 | 53 | return h_next, c_next 54 | 55 | 56 | class convlstm_rnn_p(nn.Module): 57 | def __init__(self, cfg, map_width, add_dim=0, scale_factor=1, input_dim = None): 58 | super(convlstm_rnn_p, self).__init__() 59 | if input_dim is None: 60 | self.input_dim = cfg.num_keypoints 61 | else: 62 | self.input_dim = input_dim 63 | self.hidden_size = 128//scale_factor 64 | self.batch_size = cfg.batch_size 65 | self.keypoint_width = cfg.keypoint_width 66 | self.map_width = map_width//4 67 | self.n_layers = 1 68 | self.convlayer1 = nn.Sequential(nn.Conv2d(self.input_dim, 69 | 32//scale_factor, kernel_size=3, 70 | stride=1, padding=1), 71 | nn.LeakyReLU(0.2, inplace=True)) 72 | self.convlayer2 = nn.Sequential(nn.Conv2d(32//scale_factor, 73 | 64//scale_factor, kernel_size=3, 74 | stride=1, padding=1), 75 | nn.LeakyReLU(0.2, inplace=True)) 76 | self.convlayer3 = nn.Sequential(nn.Conv2d(64//scale_factor, 77 | self.hidden_size, kernel_size=3, 78 | stride=1, padding=1), 79 | nn.LeakyReLU(0.2, inplace=True)) 80 | self.convlstm = nn.ModuleList([convLSTMCell(input_dim=self.hidden_size+add_dim, 81 | hidden_dim=self.hidden_size, 82 | kernel_size=3) for i in range(self.n_layers)]) 83 | self.hidden = self.init_hidden() 84 | self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 85 | 86 | def init_hidden(self): 87 | hidden = [] 88 | for i in range(self.n_layers): 89 | hidden.append((torch.zeros(self.batch_size, self.hidden_size, 90 | self.map_width, self.map_width).cuda(), 91 | torch.zeros(self.batch_size, self.hidden_size, 
92 | self.map_width, self.map_width).cuda())) 93 | return hidden 94 | 95 | def forward(self, gaussian_maps, latent_code=None, update_hidden=True): 96 | h_in = self.convlayer1(gaussian_maps) 97 | h_in = self.convlayer2(self.mp(h_in)) 98 | h_in = self.convlayer3(self.mp(h_in)) 99 | if not (latent_code is None): 100 | h_in = torch.cat((h_in, latent_code), 1) 101 | for i in range(self.n_layers): 102 | if update_hidden: 103 | self.hidden[i] = self.convlstm[i](h_in, self.hidden[i]) 104 | h_in = self.hidden[i][0] 105 | else: 106 | hidden = self.convlstm[i](h_in, self.hidden[i]) 107 | h_in = hidden[0] 108 | return h_in 109 | 110 | 111 | class convlstm_decoder_p(nn.Module): 112 | def __init__(self, cfg, add_dim=0, scale_factor=1): 113 | super(convlstm_decoder_p, self).__init__() 114 | self.convlayer1 = nn.Sequential(nn.Conv2d(128//scale_factor+add_dim, 115 | 128//scale_factor, kernel_size=3, 116 | stride=1, padding=1), 117 | nn.LeakyReLU(0.2, inplace=True)) 118 | self.convlayer2 = nn.Sequential(nn.Conv2d(128//scale_factor, 119 | 64//scale_factor, kernel_size=3, 120 | stride=1, padding=1), 121 | nn.LeakyReLU(0.2, inplace=True)) 122 | self.convlayer3 = nn.Sequential(nn.Conv2d(64//scale_factor, 123 | 32//scale_factor, kernel_size=3, 124 | stride=1, padding=1), 125 | nn.LeakyReLU(0.2, inplace=True)) 126 | self.adjust_channels_of_output = nn.Sequential(nn.Conv2d(32//scale_factor, cfg.num_keypoints, 127 | kernel_size=1)) 128 | self.LogSoftmax = nn.LogSoftmax(dim=2) 129 | self.up = nn.UpsamplingNearest2d(scale_factor=2) 130 | 131 | def forward(self, rnn_state, latent_code=None): 132 | if not (latent_code is None): 133 | rnn_state = torch.cat((rnn_state, latent_code), 1) 134 | h_out = self.convlayer1(rnn_state) 135 | h_out = self.convlayer2(self.up(h_out)) 136 | h_out = self.convlayer3(self.up(h_out)) 137 | gaussian_maps = self.adjust_channels_of_output(h_out) 138 | gaussian_maps_flat = gaussian_maps.view(gaussian_maps.size(0), 139 | gaussian_maps.size(1), 140 | -1) 141 | gaussian_maps_flat = self.LogSoftmax(gaussian_maps_flat) 142 | return gaussian_maps_flat 143 | 144 | 145 | class prior_net_cnn(nn.Module): 146 | def __init__(self, cfg, scale_factor=1): 147 | super(prior_net_cnn, self).__init__() 148 | self.embed = nn.Sequential(nn.Conv2d(128//scale_factor, 128//scale_factor, kernel_size=3, 149 | stride=1, padding=1), 150 | nn.LeakyReLU(0.2, inplace=True)) 151 | self.embed1 = nn.Sequential(nn.Conv2d(128//scale_factor, 1, kernel_size=3, 152 | stride=1, padding=1)) 153 | self.embed2 = nn.Sequential(nn.Conv2d(128//scale_factor, 1, kernel_size=3, 154 | stride=1, padding=1), 155 | nn.Softplus()) 156 | 157 | def forward(self, rnn_state): 158 | hidden = self.embed(rnn_state) 159 | means = self.embed1(hidden) 160 | stds = self.embed2(hidden) + 1e-4 161 | return means, stds 162 | 163 | 164 | class posterior_net_cnn(nn.Module): 165 | def __init__(self, cfg, scale_factor=1, input_dim=None): 166 | super(posterior_net_cnn, self).__init__() 167 | self.num_keypoints = cfg.num_keypoints 168 | if input_dim is None: 169 | self.input_dim = cfg.num_keypoints 170 | else: 171 | self.input_dim = input_dim 172 | self.convlayer1 = nn.Sequential(nn.Conv2d(self.input_dim, 173 | 16//scale_factor, kernel_size=3, 174 | stride=1, padding=1), 175 | nn.LeakyReLU(0.2, inplace=True)) 176 | self.convlayer2 = nn.Sequential(nn.Conv2d(16//scale_factor, 177 | 32//scale_factor, kernel_size=3, 178 | stride=1, padding=1), 179 | nn.LeakyReLU(0.2, inplace=True)) 180 | self.convlayer3 = nn.Sequential(nn.Conv2d(32//scale_factor, 181 | 
64//scale_factor, kernel_size=3, 182 | stride=1, padding=1), 183 | nn.LeakyReLU(0.2, inplace=True)) 184 | self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 185 | self.embed = nn.Sequential(nn.Conv2d(128//scale_factor+64//scale_factor, 128//scale_factor, kernel_size=3, 186 | stride=1, padding=1), 187 | nn.LeakyReLU(0.2, inplace=True)) 188 | self.embed1 = nn.Sequential(nn.Conv2d(128//scale_factor, 1, kernel_size=3, 189 | stride=1, padding=1)) 190 | self.embed2 = nn.Sequential(nn.Conv2d(128//scale_factor, 1, kernel_size=3, 191 | stride=1, padding=1), 192 | nn.Softplus()) 193 | 194 | def forward(self, rnn_state, gaussian_maps): 195 | gaussian_maps = self.convlayer1(gaussian_maps) 196 | gaussian_maps = self.mp(self.convlayer2(gaussian_maps)) 197 | gaussian_maps = self.mp(self.convlayer3(gaussian_maps)) 198 | hidden = self.embed(torch.cat((rnn_state, gaussian_maps), dim=1)) 199 | means = self.embed1(hidden) 200 | stds = self.embed2(hidden) + 1e-4 201 | return means, stds 202 | 203 | -------------------------------------------------------------------------------- /train_AE_BN3_kth_1_2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | r"""Minimal example for training a video_structure model. 18 | 19 | See README.md for installation instructions. To run on GPU device 0: 20 | 21 | CUDA_VISIBLE_DEVICES=0 python -m video_structure.train 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | 27 | from __future__ import print_function 28 | 29 | import torch 30 | import torch.optim as optim 31 | import torch.nn as nn 32 | import os, subprocess 33 | import random 34 | from torch.utils.data import DataLoader 35 | import numpy as np 36 | 37 | import utils,hyperparameters,losses,vision_BN3_1,ops 38 | 39 | ''' 40 | Builds the complete model with image encoder plus dynamics model. 41 | 42 | This architecture is meant for testing/illustration only. 43 | 44 | Model architecture: 45 | 46 | image --> keypoints --> reconstructed_image 47 | 48 | The model takes a [batch_size, timesteps, H, W, C] image sequence as input. It 49 | "observes" all frames, detects keypoints, and reconstructs the images. The 50 | dynamics model learns to predict future keypoints based on the detected 51 | keypoints. 
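Shape sketch (illustrative, for the KTH settings used in this script: grayscale
64x64 frames, 12 keypoints, 16x16 heatmaps):

    frame x[t]:  [batch, 1, 64, 64]
    keypoints, heatmaps = build_images_to_keypoints_net(x[t])
    keypoints:   [batch, num_keypoints, 2]        # (x, y) on a [-1, 1] grid
    heatmaps:    [batch, num_keypoints, 16, 16]
    recon = keypoints_to_images_net(keypoints, x[0], keypoints_0)   # same shape as x[t]

where keypoints_0 denotes the keypoints detected in the first frame.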
52 | ''' 53 | os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in subprocess.Popen( 54 | "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, 55 | stdout=subprocess.PIPE).stdout.readlines()])) 56 | 57 | cfg = hyperparameters.get_config() 58 | name = 'model=AE_BN3_kth_1_2' 59 | 60 | # do not use miu 61 | # increase detect intense 62 | # normalize max map to 1 63 | # simple skip connections 64 | # Adam optimizer 65 | # no training clamp(-0.5, 0.5) 66 | # incite loss 67 | # discrete to 64x64 grid 68 | cfg.dataset = 'kth' 69 | cfg.num_keypoints = 12 70 | 71 | 72 | load_model = False 73 | log_dir = 'logs/struc' 74 | log_dir = '%s-%s' % (log_dir, name) 75 | 76 | os.makedirs('%s/gen/' % log_dir, exist_ok=True) 77 | os.makedirs('%s/plots/' % log_dir, exist_ok=True) 78 | 79 | print("Random Seed: ", cfg.seed) 80 | random.seed(cfg.seed) 81 | np.random.seed(cfg.seed) 82 | torch.manual_seed(cfg.seed) 83 | torch.cuda.manual_seed_all(cfg.seed) 84 | dtype = torch.cuda.FloatTensor 85 | 86 | # --------- loss functions ------------------------------------ 87 | mse_criterion = nn.MSELoss(reduction='sum') 88 | 89 | def CalDelta_xy(keypoints_np): 90 | keypoints_np1 = np.zeros_like(keypoints_np) 91 | keypoints_np1[:, :, 0] = np.round((1 + keypoints_np[:, :, 0]) * (cfg.img_w-1) / 2) 92 | keypoints_np1[:, :, 1] = np.round((1 - keypoints_np[:, :, 1]) * (cfg.img_w-1) / 2) 93 | keypoints_np1[:, :, 0] = keypoints_np1[:, :, 0] / (cfg.img_w-1) * 2 - 1 94 | keypoints_np1[:, :, 1] = 1 - keypoints_np1[:, :, 1] / (cfg.img_w-1) * 2 95 | 96 | delta_xy = keypoints_np1 - keypoints_np 97 | return delta_xy 98 | 99 | if load_model: 100 | saved_model = torch.load('%s/model0.pth' % log_dir) 101 | cfg.learning_rate /= 2 102 | build_images_to_keypoints_net = saved_model['build_images_to_keypoints_net'] 103 | keypoints_to_images_net = saved_model['keypoints_to_images_net'] 104 | else: 105 | build_images_to_keypoints_net = vision_BN3_1.build_images_to_keypoints_net( 106 | cfg, [1, cfg.img_h, cfg.img_w]) 107 | keypoints_to_images_net = vision_BN3_1.build_keypoints_to_images_net( 108 | cfg, [1, cfg.img_h, cfg.img_w]) 109 | build_images_to_keypoints_net.apply(utils.init_weights) 110 | keypoints_to_images_net.apply(utils.init_weights) 111 | 112 | build_images_to_keypoints_net_optimizer = optim.Adam(build_images_to_keypoints_net.parameters(), 113 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda) 114 | keypoints_to_images_net_optimizer = optim.Adam(keypoints_to_images_net.parameters(), 115 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda) 116 | 117 | # --------- transfer to gpu ------------------------------------ 118 | build_images_to_keypoints_net.cuda() 119 | keypoints_to_images_net.cuda() 120 | 121 | 122 | # --------- load a dataset ------------------------------------ 123 | train_data, test_data = utils.load_dataset(cfg) 124 | 125 | train_loader = DataLoader(train_data, 126 | num_workers=4, 127 | batch_size=cfg.batch_size, 128 | shuffle=True, 129 | drop_last=True, 130 | pin_memory=True) 131 | test_loader = DataLoader(test_data, 132 | num_workers=1, 133 | batch_size=cfg.test_batch_size, 134 | shuffle=True, 135 | drop_last=True, 136 | pin_memory=True) 137 | 138 | def get_training_batch(): 139 | while True: 140 | for sequence in train_loader: 141 | batch = utils.normalize_data(cfg, dtype, sequence) 142 | yield batch 143 | 144 | training_batch_generator = get_training_batch() 145 | 146 | def get_testing_batch(): 147 | while True: 148 | for sequence in test_loader: 149 | batch = 
utils.normalize_data(cfg, dtype, sequence) 150 | yield batch 151 | 152 | testing_batch_generator = get_testing_batch() 153 | 154 | # --------- plotting funtions ------------------------------------ 155 | def plot(x, epoch): 156 | gen_seq = [] 157 | gen_keypointsxy = [] 158 | gt_seq = [x[i] for i in range(len(x))] 159 | 160 | observed_keypoints = [] 161 | for i in range(cfg.n_eval): 162 | keypoints, _ = build_images_to_keypoints_net(x[i]) 163 | observed_keypoints.append(keypoints.detach()) 164 | 165 | for i in range(cfg.n_eval): 166 | reconstructed_image = keypoints_to_images_net(observed_keypoints[i], x[0], observed_keypoints[0]) 167 | gen_keypointsxy.append(ops.change2xy(observed_keypoints[i])) 168 | gen_seq.append(reconstructed_image.detach()) 169 | 170 | to_plot = [] 171 | gifs = [[] for t in range(cfg.n_eval)] 172 | 173 | nrow = min(cfg.test_batch_size, 10) 174 | for i in range(nrow): 175 | # ground truth sequence 176 | row = [] 177 | for t in range(cfg.n_eval): 178 | row.append(gt_seq[t][i]) 179 | to_plot.append(row) 180 | 181 | row = [] 182 | for t in range(cfg.n_eval): 183 | row.append(ops.add_keypoints(gen_seq[t][i].clone(), gen_keypointsxy[t][i])) 184 | to_plot.append(row) 185 | 186 | for t in range(cfg.n_eval): 187 | row = [] 188 | row.append(gt_seq[t][i]) 189 | row.append(gen_seq[t][i]) 190 | gifs[t].append(row) 191 | 192 | fname = '%s/gen/sample_%d.png' % (log_dir, epoch) 193 | utils.save_tensors_image(fname, to_plot) 194 | 195 | fname = '%s/gen/sample_%d.gif' % (log_dir, epoch) 196 | utils.save_gif(fname, gifs) 197 | 198 | # --------- training funtions ------------------------------------ 199 | def train(x): 200 | build_images_to_keypoints_net.zero_grad() 201 | keypoints_to_images_net.zero_grad() 202 | 203 | mse = 0 204 | observed_keypoints = [] 205 | observed_heatmaps = [] 206 | for i in range(cfg.observed_steps + cfg.predicted_steps): 207 | keypoints, heatmaps = build_images_to_keypoints_net(x[i]) 208 | keypoints_np = keypoints.data.cpu().numpy() 209 | delta_xy = CalDelta_xy(keypoints_np) 210 | delta_xypt = torch.FloatTensor(delta_xy).cuda() 211 | keypoints = keypoints + delta_xypt 212 | observed_keypoints.append(keypoints) 213 | observed_heatmaps.append(heatmaps) 214 | 215 | for i in range(cfg.observed_steps + cfg.predicted_steps): 216 | reconstructed_image = keypoints_to_images_net(observed_keypoints[i], x[0], observed_keypoints[0]) 217 | mse += 0.5*mse_criterion(reconstructed_image, x[i]) 218 | 219 | mse /= (cfg.observed_steps + cfg.predicted_steps)*cfg.batch_size 220 | separation_loss = losses.temporal_separation_loss( 221 | cfg, torch.stack(observed_keypoints[:cfg.observed_steps])) 222 | sparse_loss = losses.sparse_loss(torch.stack(observed_heatmaps), cfg) 223 | incite = losses.detect_loss(torch.stack(observed_heatmaps)) 224 | 225 | loss = mse + incite #+ separation_loss + sparse_loss 226 | loss.backward() 227 | 228 | torch.nn.utils.clip_grad_norm_(build_images_to_keypoints_net.parameters(), cfg.clipnorm) 229 | torch.nn.utils.clip_grad_norm_(keypoints_to_images_net.parameters(), cfg.clipnorm) 230 | 231 | build_images_to_keypoints_net_optimizer.step() 232 | keypoints_to_images_net_optimizer.step() 233 | 234 | return mse.data.cpu().numpy(), separation_loss.data.cpu().numpy(), sparse_loss.data.cpu().numpy(), \ 235 | incite.data.cpu().numpy() 236 | 237 | # --------- training loop ------------------------------------ 238 | result = open(log_dir+'/'+"result.txt","w") 239 | for epoch in range(cfg.num_epochs): 240 | build_images_to_keypoints_net.train() 241 | 
keypoints_to_images_net.train() 242 | epoch_mse = 0 243 | epoch_sep = 0 244 | epoch_spa = 0 245 | epoch_inc = 0 246 | 247 | for i in range(cfg.steps_per_epoch): 248 | x = next(training_batch_generator) # sequence and the sequence class number 249 | 250 | # train frame_predictor 251 | mse, separation_loss, sparse_loss, incite = train(x) 252 | epoch_mse += mse 253 | epoch_sep += separation_loss 254 | epoch_spa += sparse_loss 255 | epoch_inc += incite 256 | 257 | print('[%02d] mse loss: %.5f | separation loss: %.5f | sparse loss: %.5f | incite loss: %.5f (%s)' % ( 258 | epoch, epoch_mse / cfg.steps_per_epoch, epoch_sep / cfg.steps_per_epoch, 259 | epoch_spa / cfg.steps_per_epoch, epoch_inc / cfg.steps_per_epoch,name)) 260 | result.write('[%02d] mse loss: %.5f | separation loss: %.5f | sparse loss: %.5f | incite loss: %.5f (%s)\n' % ( 261 | epoch, epoch_mse / cfg.steps_per_epoch, epoch_sep / cfg.steps_per_epoch, 262 | epoch_spa / cfg.steps_per_epoch, epoch_inc / cfg.steps_per_epoch,name)) 263 | 264 | # plot some stuff 265 | build_images_to_keypoints_net.eval() 266 | keypoints_to_images_net.eval() 267 | 268 | if (epoch+1)%100==0 or epoch==0: 269 | x = next(testing_batch_generator) 270 | plot(x, epoch) 271 | if (epoch + 1) % 500 == 0: 272 | cfg.learning_rate /= 4 273 | utils.set_learning_rate(build_images_to_keypoints_net_optimizer, cfg.learning_rate) 274 | utils.set_learning_rate(keypoints_to_images_net_optimizer, cfg.learning_rate) 275 | 276 | # save the model 277 | torch.save({ 278 | 'build_images_to_keypoints_net': build_images_to_keypoints_net, 279 | 'keypoints_to_images_net': keypoints_to_images_net}, 280 | '%s/model.pth' % log_dir) 281 | result.close() 282 | 283 | 284 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import socket 4 | import argparse 5 | import os 6 | import numpy as np 7 | from skimage.metrics import peak_signal_noise_ratio as psnr_metric 8 | from skimage.metrics import structural_similarity as ssim_metric 9 | from PIL import Image, ImageDraw 10 | 11 | from torchvision import datasets, transforms 12 | from torch.autograd import Variable 13 | from torch import nn 14 | import imageio 15 | import random 16 | 17 | def get_minibatches_idx(n, 18 | minibatch_size, 19 | shuffle=False, 20 | min_frame=None, 21 | trainfiles=None, 22 | del_list=None): 23 | """ 24 | Used to shuffle the dataset at each iteration. 25 | """ 26 | idx_list = np.arange(n, dtype="int32") 27 | 28 | if min_frame != None: 29 | if del_list == None: 30 | del_list = list() 31 | for i in idx_list: 32 | vid_path = trainfiles[i].split()[0] 33 | length = len([f for f in listdir(vid_path) if f.endswith('.png')]) 34 | if length < min_frame: 35 | del_list.append(i) 36 | print('[!] Discarded %d samples from training set!' 
% len(del_list)) 37 | idx_list = np.delete(idx_list, del_list) 38 | 39 | if shuffle: 40 | random.shuffle(idx_list) 41 | 42 | minibatches = [] 43 | minibatch_start = 0 44 | for i in range(n // minibatch_size): 45 | minibatches.append( 46 | idx_list[minibatch_start:minibatch_start + minibatch_size]) 47 | minibatch_start += minibatch_size 48 | 49 | if (minibatch_start != n): 50 | # Make a minibatch out of what is left 51 | minibatches.append(idx_list[minibatch_start:]) 52 | 53 | return zip(range(len(minibatches)), minibatches), del_list 54 | 55 | def load_dataset(opt): 56 | if opt.dataset == 'smmnist': 57 | from data.moving_mnist import MovingMNIST 58 | train_data = MovingMNIST( 59 | train=True, 60 | data_root=opt.data_root, 61 | seq_len=opt.observed_steps + opt.predicted_steps, 62 | image_size=opt.img_w, 63 | deterministic=False, 64 | num_digits=opt.num_digits) 65 | test_data = MovingMNIST( 66 | train=False, 67 | data_root=opt.data_root, 68 | seq_len=opt.n_eval, 69 | image_size=opt.img_w, 70 | deterministic=False, 71 | num_digits=opt.num_digits) 72 | elif opt.dataset == 'bair': 73 | from data.bair import RobotPush 74 | train_data = RobotPush( 75 | data_root=opt.data_root, 76 | train=True, 77 | seq_len=opt.observed_steps + opt.predicted_steps, 78 | image_size=opt.img_w) 79 | test_data = RobotPush( 80 | data_root=opt.data_root, 81 | train=False, 82 | seq_len=opt.n_eval, 83 | image_size=opt.img_w) 84 | elif opt.dataset == 'kth': 85 | from data.kth import KTH 86 | train_data = KTH( 87 | train=True, 88 | data_root=opt.data_root, 89 | seq_len=opt.observed_steps + opt.predicted_steps, 90 | image_size=opt.img_w) 91 | test_data = KTH( 92 | train=False, 93 | data_root=opt.data_root, 94 | seq_len=opt.n_eval, 95 | image_size=opt.img_w) 96 | elif opt.dataset == 'JIGSAWS-Suturing': 97 | from data.Suturning import Suturning 98 | train_data = Suturning( 99 | train=True, 100 | data_root=opt.data_root, 101 | seq_len=opt.observed_steps + opt.predicted_steps, 102 | image_size=opt.img_w) 103 | test_data = Suturning( 104 | train=False, 105 | data_root=opt.data_root, 106 | seq_len=opt.n_eval, 107 | image_size=opt.img_w) 108 | elif opt.dataset == 'Human3.6m': 109 | from data.human import Human 110 | train_data = Human( 111 | train=True, 112 | data_root=opt.data_root, 113 | seq_len=opt.observed_steps + opt.predicted_steps, 114 | image_size=opt.img_w) 115 | test_data = Human( 116 | train=False, 117 | data_root=opt.data_root, 118 | seq_len=opt.n_eval, 119 | image_size=opt.img_w) 120 | elif opt.dataset == 'Human3.6m_hr': 121 | from data.human_hr import Human 122 | train_data = Human( 123 | train=True, 124 | data_root=opt.data_root, 125 | seq_len=opt.observed_steps + opt.predicted_steps, 126 | image_size=opt.img_w) 127 | test_data = Human( 128 | train=False, 129 | data_root=opt.data_root, 130 | seq_len=opt.n_eval, 131 | image_size=opt.img_w) 132 | elif opt.dataset == 'kitti': 133 | from data.kitti import Kitti 134 | train_data = Kitti( 135 | train=True, 136 | data_root=opt.data_root, 137 | seq_len=opt.observed_steps + opt.predicted_steps, 138 | image_size=opt.img_w) 139 | test_data = Kitti( 140 | train=False, 141 | data_root=opt.data_root, 142 | seq_len=opt.n_eval, 143 | image_size=opt.img_w) 144 | elif opt.dataset == 'penn': 145 | from data.penn import Penn 146 | train_data = Penn( 147 | train=True, 148 | data_root=opt.data_root, 149 | seq_len=opt.observed_steps + opt.predicted_steps, 150 | image_size=opt.img_w) 151 | test_data = Penn( 152 | train=False, 153 | data_root=opt.data_root, 154 | seq_len=opt.n_eval, 155 | 
image_size=opt.img_w) 156 | 157 | return train_data, test_data 158 | 159 | def sequence_input(seq, dtype): 160 | return [Variable(x.type(dtype)) for x in seq] 161 | 162 | def change_to_video(sequence): 163 | sequence.transpose_(0, 1).transpose_(1, 2) 164 | return sequence 165 | 166 | def normalize_data(opt, dtype, sequence): 167 | 168 | sequence.transpose_(0, 1) 169 | sequence.transpose_(3, 4).transpose_(2, 3) 170 | 171 | #sequence.transpose_(0, 1) 172 | 173 | return sequence_input(sequence, dtype) 174 | 175 | 176 | def init_weights(m): 177 | classname = m.__class__.__name__ 178 | if classname.find('Linear') != -1: 179 | nn.init.xavier_normal_(m.weight) 180 | if not (m.bias is None): 181 | m.bias.data.fill_(0) 182 | 183 | elif classname.find('Conv') != -1: 184 | nn.init.kaiming_normal_(m.weight) 185 | if not (m.bias is None): 186 | m.bias.data.fill_(0) 187 | 188 | def set_learning_rate(optimizer, lr): 189 | """Sets the learning rate to the given value""" 190 | for param_group in optimizer.param_groups: 191 | param_group['lr'] = lr 192 | 193 | 194 | def is_sequence(arg): 195 | return (not hasattr(arg, "strip") and 196 | not type(arg) is np.ndarray and 197 | not hasattr(arg, "dot") and 198 | (hasattr(arg, "__getitem__") or 199 | hasattr(arg, "__iter__"))) 200 | 201 | def image_tensor(inputs, padding=1): 202 | # assert is_sequence(inputs) 203 | assert len(inputs) > 0 204 | # print(inputs) 205 | 206 | # if this is a list of lists, unpack them all and grid them up 207 | if is_sequence(inputs[0]) or (hasattr(inputs, "dim") and inputs.dim() > 4): 208 | images = [image_tensor(x) for x in inputs] 209 | if images[0].dim() == 3: 210 | c_dim = images[0].size(0) 211 | x_dim = images[0].size(1) 212 | y_dim = images[0].size(2) 213 | else: 214 | c_dim = 1 215 | x_dim = images[0].size(0) 216 | y_dim = images[0].size(1) 217 | 218 | result = torch.ones(c_dim, 219 | x_dim * len(images) + padding * (len(images)-1), 220 | y_dim) 221 | for i, image in enumerate(images): 222 | result[:, i * x_dim + i * padding : 223 | (i+1) * x_dim + i * padding, :].copy_(image) 224 | 225 | return result 226 | 227 | # if this is just a list, make a stacked image 228 | else: 229 | images = [x.data if isinstance(x, torch.autograd.Variable) else x 230 | for x in inputs] 231 | # print(images) 232 | if images[0].dim() == 3: 233 | c_dim = images[0].size(0) 234 | x_dim = images[0].size(1) 235 | y_dim = images[0].size(2) 236 | else: 237 | c_dim = 1 238 | x_dim = images[0].size(0) 239 | y_dim = images[0].size(1) 240 | 241 | result = torch.ones(c_dim, 242 | x_dim, 243 | y_dim * len(images) + padding * (len(images)-1)) 244 | for i, image in enumerate(images): 245 | result[:, :, i * y_dim + i * padding : 246 | (i+1) * y_dim + i * padding].copy_(image) 247 | return result 248 | 249 | def make_image(tensor, colored_kp=False, key_points=None): 250 | tensor = tensor.cpu().clamp(-0.5, 0.5) + 0.5 251 | if tensor.size(0) == 1: 252 | tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) 253 | tensor = tensor.transpose(0, 1).transpose(1, 2).numpy() * 255 254 | 255 | if colored_kp: 256 | radius = 1 257 | color_bar = np.array([[248,248,24],[246,221,41],[254,191,60], 258 | [213,190,39],[148,202,73],[83,204,125], 259 | [41,195,170],[1,183,202],[33,164,227], 260 | [45,142,242],[58,115,255],[71,86,247]]) 261 | for i in range(len(key_points)): 262 | x = [np.clip(key_points[i][0] - radius, 0, 63), np.clip(key_points[i][0] + radius, 0, 63)] 263 | y = [np.clip(key_points[i][1] - radius, 0, 63), np.clip(key_points[i][1] + radius, 0, 63)] 264 | 
tensor[y[0]:y[1] + 1, key_points[i][0]] = color_bar[i] 265 | tensor[key_points[i][1], x[0]:x[1] + 1] = color_bar[i] 266 | # pdb.set_trace() 267 | return Image.fromarray(np.uint8(tensor)) 268 | 269 | def save_image(filename, tensor, colored_kp=False, key_points=None): 270 | img = make_image(tensor, colored_kp, key_points) 271 | img.save(filename) 272 | 273 | def save_tensors_image(filename, inputs, padding=1): 274 | images = image_tensor(inputs, padding) 275 | return save_image(filename, images) 276 | 277 | def save_gif(filename, inputs, duration=0.15, colored_kp=False, key_points=None): 278 | images = [] 279 | m = 0 280 | for tensor in inputs: 281 | img = image_tensor(tensor, padding=0) 282 | img = img.cpu() 283 | img = img.transpose(0,1).transpose(1,2).clamp(-0.5, 0.5) 284 | img = (img.numpy()+0.5) * 255 285 | if colored_kp: 286 | img += 255 287 | radius = 1 288 | kp = key_points[m] 289 | color_bar = np.array([[248, 248, 24], [246, 221, 41], [254, 191, 60], 290 | [213, 190, 39], [148, 202, 73], [83, 204, 125], 291 | [41, 195, 170], [1, 183, 202], [33, 164, 227], 292 | [45, 142, 242], [58, 115, 255], [71, 86, 247]]) 293 | for i in range(len(kp)): 294 | x = [np.clip(kp[i][0] - radius, 0, 63), np.clip(kp[i][0] + radius, 0, 63)] 295 | y = [np.clip(kp[i][1] - radius, 0, 63), np.clip(kp[i][1] + radius, 0, 63)] 296 | img[y[0]:y[1] + 1, kp[i][0]] = color_bar[i] 297 | img[kp[i][1], x[0]:x[1] + 1] = color_bar[i] 298 | m+=1 299 | images.append(img.astype('uint8')) 300 | imageio.mimsave(filename, images, duration=duration) 301 | 302 | def draw_text_tensor(tensor, text): 303 | np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() 304 | pil = Image.fromarray(np.uint8(np_x*255)) 305 | draw = ImageDraw.Draw(pil) 306 | draw.text((4, 64), text, (0,0,0)) 307 | img = np.asarray(pil) 308 | return Variable(torch.Tensor(img / 255.)).transpose(1, 2).transpose(0, 1) 309 | 310 | def save_gif_with_text(filename, inputs, text, duration=0.25): 311 | images = [] 312 | for tensor, text in zip(inputs, text): 313 | img = image_tensor([draw_text_tensor(ti, texti) for ti, texti in zip(tensor, text)], padding=0) 314 | img = img.cpu() 315 | img = img.transpose(0,1).transpose(1,2).clamp(0,1).numpy()*255 316 | images.append(img.astype('uint8')) 317 | imageio.mimsave(filename, images, duration=duration) 318 | 319 | def merge(images, size): 320 | h, w = images.shape[1], images.shape[2] 321 | img = np.zeros((h * size[0], w * size[1])) 322 | 323 | for idx, image in enumerate(images): 324 | i = idx % size[1] 325 | j = idx // size[1] 326 | img[j * h:j * h + h, i * w:i * w + w] = image 327 | 328 | return img 329 | 330 | 331 | def transform(input_): 332 | return 2 * input_ - 1. 333 | 334 | 335 | def inverse_transform(input_): 336 | return (input_ + 1.) / 2. 337 | 338 | 339 | def imsave(images, size, path): 340 | return imageio.imwrite(path, merge(images, size)) 341 | 342 | def gauss2D_mask(center, shape, sigma=0.5): 343 | m, n = [ss - 1 for ss in shape] 344 | y, x = np.ogrid[0:m + 1, 0:n + 1] 345 | y = y - center[0] 346 | x = x - center[1] 347 | z = x * x + y * y 348 | h = np.exp(-z / (2. 
* sigma * sigma/(shape[0]**2))) 349 | sumh = h.sum() 350 | if sumh != 0: 351 | h = h / sumh 352 | return h 353 | 354 | def visualize_lm(posex, posey, image_size, num_keypoints): 355 | posey = inverse_transform(posey) * image_size 356 | posex = inverse_transform(posex) * image_size 357 | cpose = np.zeros((image_size, image_size, num_keypoints)) 358 | for j in range(num_keypoints): 359 | gmask = gauss2D_mask( 360 | (posey[j], posex[j]), (image_size, image_size), sigma=8.) 361 | cpose[:, :, j] = gmask / gmask.max() 362 | 363 | return np.amax(cpose, axis=2) 364 | 365 | def mse_metric(x1, x2): 366 | err = np.sum((x1 - x2) ** 2) 367 | err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) 368 | return err 369 | 370 | def l1_metric(x1, x2): 371 | err = np.sum(abs(x1 - x2)) 372 | err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) 373 | return err 374 | 375 | def eval_seq(gt, pred): 376 | T = len(gt) 377 | bs = gt[0].shape[0] 378 | ssim = np.zeros((bs, T)) 379 | psnr = np.zeros((bs, T)) 380 | l1 = np.zeros((bs, T)) 381 | for i in range(bs): 382 | for t in range(T): 383 | for c in range(gt[t][i].shape[0]): # calculate for each channel respectively 384 | ssim[i, t] += ssim_metric(gt[t][i][c], pred[t][i][c], data_range=1.0) 385 | psnr[i, t] += psnr_metric(gt[t][i][c], pred[t][i][c], data_range=1.0) 386 | ssim[i, t] /= gt[t][i].shape[0] 387 | psnr[i, t] /= gt[t][i].shape[0] 388 | l1[i, t] = l1_metric(gt[t][i], pred[t][i]) 389 | 390 | return l1, ssim, psnr 391 | 392 | def eval_seq1(gt, pred, LPIPSmodel): 393 | T = len(gt) 394 | bs = gt[0].shape[0] 395 | ssim = np.zeros((bs, T)) 396 | psnr = np.zeros((bs, T)) 397 | l1 = np.zeros((bs, T)) 398 | lpips = np.zeros((bs, T)) 399 | for i in range(bs): 400 | for t in range(T): 401 | for c in range(gt[t][i].shape[0]): # calculate for each channel respectively 402 | ssim[i, t] += ssim_metric(gt[t][i][c], pred[t][i][c], data_range=1.0) 403 | psnr[i, t] += psnr_metric(gt[t][i][c], pred[t][i][c], data_range=1.0) 404 | ssim[i, t] /= gt[t][i].shape[0] 405 | psnr[i, t] /= gt[t][i].shape[0] 406 | l1[i, t] = l1_metric(gt[t][i], pred[t][i]) 407 | for t in range(T): 408 | gt_im = 2*gt[t]-1 # normalize to [-1, 1] 409 | pred_im = 2*pred[t]-1 410 | if gt_im.shape[1]==1: 411 | gt_im = np.tile(gt_im, (1,3,1,1)) 412 | pred_im = np.tile(pred_im, (1,3,1,1)) 413 | gt_im = torch.FloatTensor(gt_im).cuda() 414 | pred_im = torch.FloatTensor(pred_im).cuda() 415 | lpips[:, t] = LPIPSmodel.forward(gt_im, pred_im).squeeze().data.cpu().numpy() 416 | return l1, ssim, psnr, lpips 417 | 418 | def changed_keypoints(keypoints, weight): 419 | keypoints = keypoints* weight 420 | return keypoints 421 | 422 | def changed_keypoints1(keypoints, weight): 423 | keypoints = torch.clamp(0.95*keypoints + 0.05*weight, min=0, max=1.0) 424 | return keypoints -------------------------------------------------------------------------------- /vision_BN3_1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Vision-related components of the structured video representation model. 18 | 19 | These components perform the pixels <--> keypoints transformation. 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | import torch.nn as nn 25 | #from video_structure 26 | import ops 27 | import torch.nn.functional as F 28 | 29 | 30 | 31 | 32 | class build_image_encoder(nn.Module): 33 | """Extracts feature maps from images. 34 | 35 | The encoder iteratively halves the resolution and doubles the number of 36 | filters until the size of the feature maps is output_map_width by 37 | output_map_width. 38 | 39 | Args: 40 | input_shape: Shape of the input image (without batch dimension). 41 | initial_num_filters: Number of filters to apply at the input resolution. 42 | output_map_width: Width of the output feature maps. 43 | layers_per_scale: How many additional size-preserving conv layers to apply 44 | at each map scale. 45 | **conv_layer_kwargs: Passed to layers.Conv2D. 46 | 47 | Raises: 48 | ValueError: If the width of the input image is not compatible with 49 | output_map_width, i.e. if input_width/output_map_width is not a perfect 50 | square. 51 | """ 52 | def __init__(self, input_shape, initial_num_filters=32, output_map_width=16, 53 | layers_per_scale=1, **conv_layer_kwargs): 54 | super(build_image_encoder, self).__init__() 55 | if np.log2(input_shape[1] / output_map_width) % 1: 56 | raise ValueError( 57 | 'The ratio of input width and output_map_width must be a perfect ' 58 | 'square, but got {} and {} with ratio {}'.format( 59 | input_shape[1], output_map_width, input_shape[1] / output_map_width)) 60 | total_modules = [] 61 | modules = [nn.Conv2d(input_shape[0], initial_num_filters, 62 | **conv_layer_kwargs), 63 | nn.BatchNorm2d(initial_num_filters), 64 | nn.LeakyReLU(0.2, inplace=True)] 65 | 66 | # Expand image to initial_num_filters maps: 67 | for _ in range(layers_per_scale): 68 | modules.extend([nn.Conv2d(initial_num_filters, initial_num_filters, 69 | **conv_layer_kwargs), 70 | nn.BatchNorm2d(initial_num_filters), 71 | nn.LeakyReLU(0.2, inplace=True)]) 72 | total_modules.append(nn.Sequential(*modules)) 73 | modules = [] 74 | 75 | # Apply downsampling blocks until feature map width is output_map_width: 76 | width = input_shape[2] 77 | num_filters = initial_num_filters 78 | while width > output_map_width: 79 | # Reduce resolution: 80 | modules.extend([nn.Conv2d(num_filters, num_filters*2, 81 | **conv_layer_kwargs), 82 | nn.BatchNorm2d(num_filters*2), 83 | nn.LeakyReLU(0.2, inplace=True)]) 84 | 85 | # Apply additional layers: 86 | for _ in range(layers_per_scale): 87 | modules.extend([nn.Conv2d(num_filters*2, num_filters*2, **conv_layer_kwargs), 88 | nn.BatchNorm2d(num_filters*2), 89 | nn.LeakyReLU(0.2, inplace=True)]) 90 | num_filters *= 2 91 | width //= 2 92 | total_modules.append(nn.Sequential(*modules)) 93 | modules = [] 94 | self.conv = nn.ModuleList(total_modules) 95 | self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 96 | 97 | def forward(self, input): 98 | h1 = self.conv[0](input) 99 | h2 = self.conv[1](self.mp(h1)) 100 | h3 = self.conv[2](self.mp(h2)) 101 | return h3, [h2, h1] 102 | 103 | 104 | class build_image_decoder(nn.Module): 105 | """Decodes images from feature maps. 106 | 107 | The encoder iteratively doubles the resolution and halves the number of 108 | filters until the size of the feature maps is output_width. 
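For example (illustrative, with the default configuration): an input of shape
[batch, 128, 16, 16] and output_width=64 gives num_levels=2, so the feature maps
are upsampled 16 -> 32 -> 64 and the encoder skip features are concatenated
before the convolution block at each resolution (see forward()).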
109 | 110 | Args: 111 | input_shape: Shape of the input image (without batch dimension). 112 | output_width: Width of the output image. 113 | layers_per_scale: How many additional size-preserving conv layers to apply 114 | at each map scale. 115 | **conv_layer_kwargs: Passed to layers.Conv2D. 116 | 117 | Raises: 118 | ValueError: If the width of the input feature maps is not compatible with 119 | output_width, i.e. if output_width/input_map_width is not a perfect 120 | square. 121 | """ 122 | def __init__(self, input_shape, output_width, layers_per_scale=1, **conv_layer_kwargs): 123 | super(build_image_decoder, self).__init__() 124 | self.num_levels = np.log2(output_width / input_shape[2]) 125 | if self.num_levels % 1: 126 | raise ValueError( 127 | 'The ratio of output_width and input width must be a perfect ' 128 | 'square, but got {} and {} with ratio {}'.format( 129 | output_width, input_shape[2], output_width / input_shape[2])) 130 | 131 | # Expand until we have filters_out channels: 132 | self.num_filters = input_shape[0] 133 | num_filters = input_shape[0] 134 | self.up = nn.UpsamplingNearest2d(scale_factor=2) 135 | total_modules = [] 136 | modules = [nn.Conv2d(num_filters, num_filters, **conv_layer_kwargs), 137 | nn.LeakyReLU(0.2, inplace=True)] 138 | # Expand image to initial_num_filters maps: 139 | for i in range(layers_per_scale): 140 | if i < layers_per_scale - 1: 141 | modules.extend([nn.Conv2d(num_filters, num_filters, 142 | **conv_layer_kwargs), 143 | nn.LeakyReLU(0.2, inplace=True)]) 144 | else: 145 | modules.extend([nn.Conv2d(num_filters, num_filters // 2, 146 | **conv_layer_kwargs), 147 | nn.LeakyReLU(0.2, inplace=True)]) 148 | total_modules.append(nn.Sequential(*modules)) 149 | modules = [] 150 | 151 | for i in range(int(self.num_levels)): 152 | modules.extend([nn.Conv2d(num_filters, num_filters // 2, **conv_layer_kwargs), 153 | nn.LeakyReLU(0.2, inplace=True)]) 154 | # Apply additional layers: 155 | for j in range(layers_per_scale): 156 | if j < layers_per_scale - 1: 157 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 2, **conv_layer_kwargs), 158 | nn.LeakyReLU(0.2, inplace=True)]) 159 | else: 160 | if i < layers_per_scale - 1: 161 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 4, **conv_layer_kwargs), 162 | nn.LeakyReLU(0.2, inplace=True)]) 163 | else: 164 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 2, **conv_layer_kwargs), 165 | nn.LeakyReLU(0.2, inplace=True)]) 166 | num_filters //= 2 167 | total_modules.append(nn.Sequential(*modules)) 168 | modules = [] 169 | self.out_filters = num_filters 170 | self.conv = nn.ModuleList(total_modules) 171 | 172 | def forward(self, x, skip): 173 | d1 = self.conv[0](x) 174 | up1 = self.up(d1) 175 | d2 = self.conv[1](torch.cat([up1, skip[0]], 1)) 176 | up2 = self.up(d2) 177 | d3 = self.conv[2](torch.cat([up2, skip[1]], 1)) 178 | return d3 179 | 180 | class build_images_to_keypoints_net(nn.Module): 181 | """Builds a model that encodes an image into a keypoint. 182 | 183 | The feature maps are then reduced to num_keypoints heatmaps, and 184 | the heatmaps to (x, y, scale)-keypoints. 185 | 186 | Args: 187 | cfg: ConfigDict with model hyperparamters. 188 | image_shape: Image shape tuple: (C, H, W). 189 | 190 | Returns: 191 | A tf.keras.Model object. 
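Shape sketch (illustrative): an input image of shape [batch, C, H, W] is reduced
to heatmaps of shape [batch, num_keypoints, heatmap_width, heatmap_width]; this
variant calls ops.maps_to_keypoints1, so forward() returns (x, y) keypoints of
shape [batch, num_keypoints, 2] rather than (x, y, scale) triplets.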
192 | """ 193 | def __init__(self, cfg, image_shape): 194 | super(build_images_to_keypoints_net, self).__init__() 195 | # Adjust channel number to account for add_coord_channels: 196 | encoder_input_shape = image_shape 197 | #encoder_input_shape[0] += 2 198 | # Build feature extractor: 199 | self.image_encoder = build_image_encoder( 200 | input_shape=encoder_input_shape, 201 | initial_num_filters=cfg.num_encoder_filters, 202 | output_map_width=cfg.heatmap_width, 203 | layers_per_scale=cfg.layers_per_scale, 204 | **cfg.conv_layer_kwargs) 205 | 206 | # Build final layer that maps to the desired number of heatmaps: 207 | self.features_to_keypoint_heatmaps = nn.Sequential( 208 | nn.Conv2d(cfg.img_w//cfg.heatmap_width*cfg.num_encoder_filters, 209 | cfg.num_keypoints, kernel_size=1), 210 | nn.Sigmoid()) 211 | 212 | def forward(self, image, pre_keypoints=None): 213 | #image = ops.add_coord_channels(image) 214 | encoded, _ = self.image_encoder(image) 215 | heatmaps = self.features_to_keypoint_heatmaps(encoded) 216 | if not pre_keypoints is None: 217 | pre_gaussian_maps = ops.keypoints_to_maps1(pre_keypoints, sigma=3.0) 218 | else: 219 | pre_gaussian_maps = torch.ones_like(heatmaps).cuda() 220 | #pre_gaussian_maps_np = pre_gaussian_maps.data.cpu().numpy() 221 | #heatmaps_np = heatmaps.data.cpu().numpy() 222 | #keymap_np = (heatmaps * pre_gaussian_maps).data.cpu().numpy() 223 | keypoints = ops.maps_to_keypoints1(heatmaps*pre_gaussian_maps) 224 | return keypoints, heatmaps 225 | 226 | class build_keypoints_to_images_net(nn.Module): 227 | """Builds a model to reconstructs an image from keypoints. 228 | 229 | Model architecture: 230 | 231 | (keypoints[t], image[0], keypoints[0]) --> reconstructed_image 232 | 233 | For all frames image[t] we also we also concatenate the Gaussian maps for 234 | the keypoints obtained from the initial frame image[0]. This helps the 235 | decoder "inpaint" the image regions that are occluded by objects in the first 236 | frame. 237 | 238 | Args: 239 | cfg: ConfigDict with model hyperparameters. 240 | image_shape: Image shape tuple: (C, H, W). 241 | 242 | Returns: 243 | A tf.keras.Model object. 
244 | """ 245 | def __init__(self, cfg, image_shape): 246 | super(build_keypoints_to_images_net, self).__init__() 247 | # Build encoder net to extract appearance features from the first frame: 248 | self.keypoint_width = cfg.keypoint_width 249 | self.heatmap_width = cfg.heatmap_width 250 | self.appearance_feature_extractor = build_image_encoder( 251 | input_shape=image_shape, 252 | initial_num_filters=cfg.num_encoder_filters, 253 | layers_per_scale=cfg.layers_per_scale, 254 | **cfg.conv_layer_kwargs) 255 | 256 | # Build image decoder that goes from Gaussian maps to reconstructed images: 257 | num_encoder_output_channels = ( 258 | cfg.num_encoder_filters * image_shape[1] // cfg.heatmap_width) 259 | input_shape = [num_encoder_output_channels, cfg.heatmap_width, 260 | cfg.heatmap_width] 261 | self.image_decoder = build_image_decoder( 262 | input_shape=input_shape, 263 | output_width=image_shape[1], 264 | layers_per_scale=cfg.layers_per_scale, 265 | **cfg.conv_layer_kwargs) 266 | 267 | # Build layers to adjust channel numbers for decoder input and output image: 268 | kwargs = dict(cfg.conv_layer_kwargs) 269 | kwargs['kernel_size'] = 1 270 | kwargs['padding'] = 0 271 | self.adjust_channels_of_decoder_input = nn.Sequential( 272 | nn.Conv2d(cfg.num_keypoints 273 | + cfg.num_encoder_filters*cfg.img_w//cfg.heatmap_width 274 | , num_encoder_output_channels, **kwargs), 275 | nn.LeakyReLU(0.2, inplace=True)) 276 | 277 | kwargs = dict(cfg.conv_layer_kwargs) 278 | kwargs['kernel_size'] = 1 279 | kwargs['padding'] = 0 280 | self.adjust_channels_of_output_image = nn.Sequential( 281 | nn.Conv2d(self.image_decoder.out_filters, image_shape[0], **kwargs)) 282 | 283 | def forward(self, keypoints, first_frame, first_frame_keypoints, predicted_gaussian_maps=None): 284 | # Get features and maps for first frame: 285 | # Note that we cannot use the Gaussian maps above because the 286 | # first_frame_keypoints may be different than the keypoints (i.e. obs vs 287 | # pred). 
288 | first_frame_features, skip = self.appearance_feature_extractor(first_frame) 289 | first_frame_gaussian_maps = ops.keypoints_to_maps2(first_frame_keypoints, 290 | sigma=self.keypoint_width, 291 | heatmap_width=self.heatmap_width) 292 | 293 | # Convert keypoints to pixel maps: 294 | if predicted_gaussian_maps is None: 295 | gaussian_maps = ops.keypoints_to_maps2(keypoints, 296 | sigma=self.keypoint_width, 297 | heatmap_width=self.heatmap_width) 298 | else: 299 | gaussian_maps = predicted_gaussian_maps 300 | 301 | # Reconstruct image: 302 | gaussian_maps = gaussian_maps - first_frame_gaussian_maps 303 | combined_representation = torch.cat((gaussian_maps, first_frame_features), 1) 304 | #combined_representation = ops.add_coord_channels(combined_representation) 305 | combined_representation = self.adjust_channels_of_decoder_input( 306 | combined_representation) 307 | decoded_representation = self.image_decoder(combined_representation, skip) 308 | image = self.adjust_channels_of_output_image(decoded_representation) 309 | 310 | # Add in the first frame of the sequence such that the model only needs to 311 | # predict the change from the first frame: 312 | image = image + first_frame 313 | 314 | return image 315 | 316 | class build_image_decoder_2image(nn.Module): 317 | def __init__(self, input_shape, output_width, layers_per_scale=1, **conv_layer_kwargs): 318 | super(build_image_decoder_2image, self).__init__() 319 | self.num_levels = np.log2(output_width / input_shape[2]) 320 | if self.num_levels % 1: 321 | raise ValueError( 322 | 'The ratio of output_width and input width must be a perfect ' 323 | 'square, but got {} and {} with ratio {}'.format( 324 | output_width, input_shape[2], output_width / input_shape[2])) 325 | 326 | # Expand until we have filters_out channels: 327 | self.num_filters = input_shape[0] 328 | num_filters = input_shape[0] 329 | self.up = nn.UpsamplingNearest2d(scale_factor=2) 330 | total_modules = [] 331 | modules = [nn.Conv2d(num_filters, num_filters, **conv_layer_kwargs), 332 | nn.LeakyReLU(0.2, inplace=True)] 333 | # Expand image to initial_num_filters maps: 334 | for i in range(layers_per_scale): 335 | if i < layers_per_scale - 1: 336 | modules.extend([nn.Conv2d(num_filters, num_filters, 337 | **conv_layer_kwargs), 338 | nn.BatchNorm2d(num_filters), 339 | nn.LeakyReLU(0.2, inplace=True)]) 340 | else: 341 | modules.extend([nn.Conv2d(num_filters, num_filters // 2, 342 | **conv_layer_kwargs), 343 | nn.BatchNorm2d(num_filters//2), 344 | nn.LeakyReLU(0.2, inplace=True)]) 345 | total_modules.append(nn.Sequential(*modules)) 346 | modules = [] 347 | 348 | for i in range(int(self.num_levels)): 349 | modules.extend([nn.Conv2d(num_filters, num_filters // 2, **conv_layer_kwargs), 350 | nn.BatchNorm2d(num_filters//2), 351 | nn.LeakyReLU(0.2, inplace=True)]) 352 | # Apply additional layers: 353 | for j in range(layers_per_scale): 354 | if j < layers_per_scale - 1: 355 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 2, **conv_layer_kwargs), 356 | nn.BatchNorm2d(num_filters//2), 357 | nn.LeakyReLU(0.2, inplace=True)]) 358 | else: 359 | if i < layers_per_scale - 1: 360 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 4, **conv_layer_kwargs), 361 | nn.BatchNorm2d(num_filters//4), 362 | nn.LeakyReLU(0.2, inplace=True)]) 363 | else: 364 | modules.extend([nn.Conv2d(num_filters // 2, num_filters // 2, **conv_layer_kwargs), 365 | nn.BatchNorm2d(num_filters//2), 366 | nn.LeakyReLU(0.2, inplace=True)]) 367 | num_filters //= 2 368 | 
total_modules.append(nn.Sequential(*modules)) 369 | modules = [] 370 | self.out_filters = num_filters 371 | self.conv = nn.ModuleList(total_modules) 372 | 373 | def forward(self, x, skip): 374 | d1 = self.conv[0](x) 375 | up1 = self.up(d1) 376 | d2 = self.conv[1](torch.cat([up1, skip[0]], 1)) 377 | up2 = self.up(d2) 378 | d3 = self.conv[2](torch.cat([up2, skip[1]], 1)) 379 | return d3 380 | 381 | class build_keypoints_to_images_net_2image(nn.Module): 382 | def __init__(self, cfg, image_shape): 383 | super(build_keypoints_to_images_net_2image, self).__init__() 384 | # Build encoder net to extract appearance features from the first frame: 385 | self.keypoint_width = cfg.keypoint_width 386 | self.heatmap_width = cfg.heatmap_width 387 | self.appearance_feature_extractor = build_image_encoder( 388 | input_shape=image_shape, 389 | initial_num_filters=cfg.num_encoder_filters, 390 | layers_per_scale=cfg.layers_per_scale, 391 | **cfg.conv_layer_kwargs) 392 | 393 | # Build image decoder that goes from Gaussian maps to reconstructed images: 394 | num_encoder_output_channels = ( 395 | cfg.num_encoder_filters * image_shape[1] // cfg.heatmap_width) 396 | input_shape = [num_encoder_output_channels, cfg.heatmap_width, 397 | cfg.heatmap_width] 398 | self.image_decoder = build_image_decoder_2image( 399 | input_shape=input_shape, 400 | output_width=image_shape[1], 401 | layers_per_scale=cfg.layers_per_scale, 402 | **cfg.conv_layer_kwargs) 403 | 404 | # Build layers to adjust channel numbers for decoder input and output image: 405 | kwargs = dict(cfg.conv_layer_kwargs) 406 | kwargs['kernel_size'] = 1 407 | kwargs['padding'] = 0 408 | self.adjust_channels_of_decoder_input = nn.Sequential( 409 | nn.Conv2d(cfg.num_keypoints*3 410 | + cfg.num_encoder_filters*cfg.img_w//cfg.heatmap_width*2 411 | , num_encoder_output_channels, **kwargs), 412 | nn.BatchNorm2d(num_encoder_output_channels), 413 | nn.LeakyReLU(0.2, inplace=True)) 414 | 415 | kwargs = dict(cfg.conv_layer_kwargs) 416 | kwargs['kernel_size'] = 1 417 | kwargs['padding'] = 0 418 | self.adjust_channels_of_output_image = nn.Sequential( 419 | nn.Conv2d(self.image_decoder.out_filters, image_shape[0], **kwargs)) 420 | 421 | def forward(self, keypoints, first_frame, first_frame_keypoints, sec_frame, sec_frame_keypoints): 422 | # Get features and maps for first frame: 423 | # Note that we cannot use the Gaussian maps above because the 424 | # first_frame_keypoints may be different than the keypoints (i.e. obs vs 425 | # pred). 
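    # Note (illustrative): the decoder input assembled below concatenates
    # 3 * num_keypoints Gaussian-map channels (current, first-frame and
    # second-frame keypoints) with appearance features from both reference
    # frames, matching adjust_channels_of_decoder_input above. Unlike the
    # single-frame decoder, the output is squashed with tanh instead of being
    # added to the first frame.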
426 | first_frame_features, skip = self.appearance_feature_extractor(first_frame) 427 | first_frame_gaussian_maps = ops.keypoints_to_maps2(first_frame_keypoints, 428 | sigma=self.keypoint_width, 429 | heatmap_width=self.heatmap_width) 430 | 431 | sec_frame_features, _ = self.appearance_feature_extractor(sec_frame) 432 | sec_frame_gaussian_maps = ops.keypoints_to_maps2(sec_frame_keypoints, 433 | sigma=self.keypoint_width, 434 | heatmap_width=self.heatmap_width) 435 | 436 | # Convert keypoints to pixel maps: 437 | gaussian_maps = ops.keypoints_to_maps2(keypoints, 438 | sigma=self.keypoint_width, 439 | heatmap_width=self.heatmap_width) 440 | 441 | # Reconstruct image: 442 | combined_representation = torch.cat((gaussian_maps, first_frame_gaussian_maps, 443 | sec_frame_gaussian_maps, 444 | first_frame_features, 445 | sec_frame_features), 1) 446 | #combined_representation = ops.add_coord_channels(combined_representation) 447 | combined_representation = self.adjust_channels_of_decoder_input( 448 | combined_representation) 449 | decoded_representation = self.image_decoder(combined_representation, skip) 450 | image = self.adjust_channels_of_output_image(decoded_representation) 451 | 452 | # Add in the first frame of the sequence such that the model only needs to 453 | # predict the change from the first frame: 454 | image = torch.tanh(image) 455 | 456 | return image 457 | 458 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """PyTorch ops for the structured video representation model.""" 18 | 19 | import enum 20 | import torch 21 | import numpy as np 22 | import matplotlib.pyplot as plt 23 | import torch.nn.functional as F 24 | 25 | EPSILON = 1e-20 # Constant for numerical stability. 26 | 27 | 28 | class Axis(enum.Enum): 29 | """Maps axes to image indices, assuming that 0th dimension is the batch.""" 30 | y = 2 31 | x = 3 32 | 33 | 34 | def maps_to_keypoints(heatmaps): 35 | """Turns feature-detector heatmaps into (x, y, scale) keypoints. 36 | 37 | This function takes a tensor of feature maps as input. Each map is normalized 38 | to a probability distribution and the location of the mean of the distribution 39 | (in image coordinates) is computed. This location is used as a low-dimensional 40 | representation of the heatmap (i.e. a keypoint). 41 | 42 | To model keypoint presence/absence, the mean intensity of each feature map is 43 | also computed, so that each keypoint is represented by an (x, y, scale) 44 | triplet. 45 | 46 | Args: 47 | heatmaps: [batch_size, num_keypoints, H, W] tensors. 48 | Returns: 49 | A [batch_size, num_keypoints, 3] tensor with (x, y, scale)-triplets for each 50 | keypoint. Coordinate range is [-1, 1] for x and y, and [0, 1] for scale. 
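  Example (illustrative only; the batch/keypoint/heatmap sizes are assumptions
  matching the defaults used elsewhere in this repo, and a CUDA device is
  assumed because _get_pixel_grid builds its grid on the GPU):

    heatmaps = torch.rand(8, 12, 16, 16).cuda()  # non-negative detector maps
    keypoints = maps_to_keypoints(heatmaps)      # -> [8, 12, 3] (x, y, scale)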
51 | """ 52 | 53 | # Check that maps are non-negative: 54 | map_min = torch.min(heatmaps) 55 | if map_min < 0.0: 56 | print("map_min: ", map_min.detach().cpu().numpy()) 57 | assert map_min >= 0.0 58 | 59 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 60 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 61 | map_scales = torch.mean(heatmaps, dim=(2, 3)) 62 | ''' 63 | map_scales_np = map_scales.data.cpu().numpy() 64 | heatmaps_np = heatmaps.data.cpu().numpy() 65 | map_scales_signp = torch.mean(torch.sigmoid(heatmaps), dim=(2, 3)).data.cpu().numpy() 66 | heatmaps_signp = torch.sigmoid(heatmaps).data.cpu().numpy() 67 | map_scales_max = torch.max(torch.sigmoid(heatmaps), dim = -1, keepdim=True)[0] 68 | map_scales_max = torch.max(map_scales_max, dim=-2, keepdim=True)[0] 69 | heatmaps1 = torch.sigmoid(heatmaps) - 0.9 * map_scales_max 70 | heatmaps1 = F.relu(heatmaps1)*10 71 | heatmaps1_np = heatmaps1.data.cpu().numpy() 72 | xy = change2xy(torch.stack((x_coordinates, y_coordinates), dim=-1), width=16) 73 | 74 | #''' 75 | 76 | # Normalize map scales to [0.0, 1.0] across keypoints. This removes a 77 | # degeneracy between the encoder and decoder heatmap scales and ensures that 78 | # the scales are in a reasonable range for the RNN: 79 | map_scales /= (EPSILON + torch.max(map_scales, dim=-1, keepdim=True)[0]) 80 | 81 | return torch.stack((x_coordinates, y_coordinates, map_scales), dim=-1) 82 | 83 | def maps_to_keypoints1(heatmaps): 84 | """ 85 | do not use miu 86 | """ 87 | 88 | # Check that maps are non-negative: 89 | #map_min = torch.min(heatmaps) 90 | #if map_min < 0.0: 91 | #print("map_min: ", map_min.detach().cpu().numpy()) 92 | #assert map_min >= 0.0 93 | 94 | #heatmaps_np = heatmaps.data.cpu().numpy() 95 | 96 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 97 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 98 | 99 | # Normalize map scales to [0.0, 1.0] across keypoints. This removes a 100 | # degeneracy between the encoder and decoder heatmap scales and ensures that 101 | # the scales are in a reasonable range for the RNN: 102 | 103 | 104 | return torch.stack((x_coordinates, y_coordinates), dim=-1) 105 | 106 | def maps_to_keypoints1_1(heatmaps): 107 | """ 108 | use max value as miu 109 | """ 110 | 111 | # Check that maps are non-negative: 112 | map_min = torch.min(heatmaps) 113 | if map_min < 0.0: 114 | print("map_min: ", map_min.detach().cpu().numpy()) 115 | assert map_min >= 0.0 116 | 117 | #heatmaps_np = heatmaps.data.cpu().numpy() 118 | 119 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 120 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 121 | 122 | map_scales = torch.max(heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1), dim=-1)[0] 123 | 124 | # Normalize map scales to [0.0, 1.0] across keypoints. 
This removes a 125 | # degeneracy between the encoder and decoder heatmap scales and ensures that 126 | # the scales are in a reasonable range for the RNN: 127 | 128 | 129 | return torch.stack((x_coordinates, y_coordinates, map_scales), dim=-1) 130 | 131 | def maps_to_keypoints2(heatmaps): 132 | """ 133 | do not use miu 134 | use max point as coordinate 135 | """ 136 | 137 | # Check that maps are non-negative: 138 | map_min = torch.min(heatmaps) 139 | if map_min < 0.0: 140 | print("map_min: ", map_min.detach().cpu().numpy()) 141 | assert map_min >= 0.0 142 | #heatmaps_np0 = heatmaps.data.cpu().numpy() 143 | heatmaps1 = heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1) 144 | topk = torch.topk(heatmaps1, k=1 + 1, dim=-1)[0] 145 | heatmaps = heatmaps - topk[:, :, -1, np.newaxis, np.newaxis] 146 | heatmaps = F.relu(heatmaps) 147 | #heatmaps_np = heatmaps.data.cpu().numpy() 148 | 149 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 150 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 151 | 152 | # Normalize map scales to [0.0, 1.0] across keypoints. This removes a 153 | # degeneracy between the encoder and decoder heatmap scales and ensures that 154 | # the scales are in a reasonable range for the RNN: 155 | return torch.stack((x_coordinates, y_coordinates), dim=-1) 156 | 157 | def maps_to_keypoints3(heatmaps): 158 | """ 159 | use miu 160 | use max point as coordinate 161 | """ 162 | 163 | # Check that maps are non-negative: 164 | map_min = torch.min(heatmaps) 165 | if map_min < 0.0: 166 | print("map_min: ", map_min.detach().cpu().numpy()) 167 | assert map_min >= 0.0 168 | heatmaps1 = heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1) 169 | topk = torch.topk(heatmaps1, k=1 + 1, dim=-1)[0] 170 | heatmaps = heatmaps - topk[:, :, -1, np.newaxis, np.newaxis] 171 | heatmaps = F.relu(heatmaps) 172 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 173 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 174 | map_scales = topk[:,:,0] 175 | 176 | # Normalize map scales to [0.0, 1.0] across keypoints. 
This removes a 177 | # degeneracy between the encoder and decoder heatmap scales and ensures that 178 | # the scales are in a reasonable range for the RNN: 179 | map_scales /= (EPSILON + torch.max(map_scales, dim=-1, keepdim=True)[0]) 180 | #map_scales_np = map_scales.data.cpu().numpy() 181 | 182 | return torch.stack((x_coordinates, y_coordinates, map_scales), dim=-1) 183 | 184 | def maps_to_keypoints4(heatmaps): 185 | """ 186 | consider variance 187 | """ 188 | x_coordinates = _maps_to_coordinates(heatmaps, Axis.x) 189 | y_coordinates = _maps_to_coordinates(heatmaps, Axis.y) 190 | #x_variance = _maps_to_variance(heatmaps, Axis.x, x_coordinates) 191 | #y_variance = _maps_to_variance(heatmaps, Axis.y, y_coordinates) 192 | return torch.stack((x_coordinates, y_coordinates), dim=-1) 193 | 194 | #------------------------------key areas------------------------------- 195 | def maps_to_keyareas(heatmaps): 196 | 197 | # Check that maps are non-negative: 198 | map_min = torch.min(heatmaps) 199 | if map_min < 0.0: 200 | print("map_min: ", map_min.detach().cpu().numpy()) 201 | assert map_min >= 0.0 202 | map_scales_max = torch.max(heatmaps, dim=-1, keepdim = True)[0] 203 | map_scales_max = torch.max(map_scales_max, dim=-2, keepdim=True)[0] 204 | heatmaps = heatmaps/map_scales_max 205 | #heatmaps_np = heatmaps.data.cpu().numpy() 206 | 207 | return heatmaps 208 | 209 | def maps_to_keyareas1_1(heatmaps): 210 | 211 | # Check that maps are non-negative: 212 | map_min = torch.min(heatmaps) 213 | if map_min < 0.0: 214 | print("map_min: ", map_min.detach().cpu().numpy()) 215 | assert map_min >= 0.0 216 | heatmaps = heatmaps - 0.8 217 | heatmaps = F.relu(heatmaps)*5 218 | heatmaps_np = heatmaps.data.cpu().numpy() 219 | 220 | return heatmaps 221 | 222 | def maps_to_keyareas1(heatmaps): 223 | heatmaps1 = heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1) 224 | topk = torch.topk(heatmaps1, k=48+1, dim=-1)[0] 225 | heatmaps = heatmaps - topk[:, :, -1, np.newaxis, np.newaxis] 226 | heatmaps = torch.tanh(heatmaps * 10000) 227 | heatmaps = F.relu(heatmaps) 228 | #heatmaps_np = heatmaps.data.cpu().numpy() 229 | 230 | return heatmaps 231 | 232 | def maps_to_keyareas2(heatmaps): 233 | heatmaps1 = heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1) 234 | topk = torch.topk(heatmaps1, k=1+1, dim=-1)[0] 235 | heatmaps = heatmaps - topk[:, :, -1, np.newaxis, np.newaxis] 236 | heatmaps = F.relu(heatmaps) 237 | heatmaps = torch.tanh(heatmaps * 10000) 238 | heatmaps = torch.sum(heatmaps, dim=1, keepdim=True) 239 | #heatmaps_np = heatmaps.data.cpu().numpy() 240 | 241 | return heatmaps 242 | 243 | 244 | def _maps_to_coordinates(maps, axis): 245 | """Reduces heatmaps to coordinates along one axis (x or y). 246 | 247 | Args: 248 | maps: [batch_size, num_keypoints, H, W] tensors. 249 | axis: Axis Enum. 250 | 251 | Returns: 252 | A [batch_size, num_keypoints, 2] tensor with (x, y)-coordinates. 253 | """ 254 | 255 | width = maps.shape[axis.value] 256 | grid = _get_pixel_grid(axis, width) 257 | shape = [1, 1, 1, 1] 258 | shape[axis.value] = -1 259 | grid = grid.view(shape) 260 | 261 | if axis == Axis.x: 262 | marginalize_dim = 2 263 | elif axis == Axis.y: 264 | marginalize_dim = 3 265 | 266 | # Normalize the heatmaps to a probability distribution (i.e. 
sum to 1): 267 | weights = torch.sum(maps, dim=marginalize_dim, keepdim=True) 268 | 269 | weights /= torch.sum(weights, dim=axis.value, keepdim=True) + EPSILON 270 | #weights_np = weights.data.cpu().numpy() 271 | 272 | # Compute the center of mass of the marginalized maps to obtain scalar 273 | # coordinates: 274 | coordinates = torch.sum(weights * grid, dim=axis.value, keepdim=True) 275 | coordinates = torch.squeeze(coordinates, -1) 276 | coordinates = torch.squeeze(coordinates, -1) 277 | 278 | return coordinates 279 | 280 | def _maps_to_variance(maps, axis, miu): 281 | """Reduces heatmaps to variances along one axis (x or y). 282 | """ 283 | 284 | width = maps.shape[axis.value] 285 | grid = _get_pixel_grid(axis, width) 286 | shape = [1, 1, 1, 1] 287 | shape[axis.value] = -1 288 | grid = grid.view(shape) 289 | 290 | if axis == Axis.x: 291 | marginalize_dim = 2 292 | elif axis == Axis.y: 293 | marginalize_dim = 3 294 | 295 | # Normalize the heatmaps to a probability distribution (i.e. sum to 1): 296 | weights = torch.sum(maps + EPSILON, dim=marginalize_dim, keepdim=True) 297 | 298 | weights /= torch.sum(weights, dim=axis.value, keepdim=True) 299 | miu = miu[:, :, np.newaxis, np.newaxis] 300 | var = (grid-miu)**2 301 | variance = torch.sum(weights * var, dim=axis.value, keepdim=True) 302 | variance = torch.squeeze(variance, -1) 303 | variance = torch.squeeze(variance, -1) 304 | return variance 305 | 306 | def keypoints_to_maps(keypoints, sigma=1.0, heatmap_width=16): 307 | """Turns (x, y, scale)-tuples into pixel maps with a Gaussian blob at (x, y). 308 | 309 | Args: 310 | keypoints: [batch_size, num_keypoints, 3] tensor of keypoints where the last 311 | dimension contains (x, y, scale) triplets. 312 | sigma: Std. dev. of the Gaussian blob, in units of heatmap pixels. 313 | heatmap_width: Width of output heatmaps in pixels. 314 | 315 | Returns: 316 | A [batch_size, num_keypoints, heatmap_width, heatmap_width] tensor. 
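  Example (illustrative only; the sizes and the CUDA device are assumptions
  matching how this function is used elsewhere in this repo):

    kp = torch.zeros(8, 12, 3).cuda()  # keypoints at the image centre
    kp[..., 2] = 1.0                   # full scale, so the blobs are not zeroed out
    maps = keypoints_to_maps(kp)       # default sigma=1.0, heatmap_width=16 -> [8, 12, 16, 16]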
317 | """ 318 | 319 | coordinates, map_scales = torch.split(keypoints, 2, dim=-1) 320 | 321 | def get_grid(axis): 322 | grid = _get_pixel_grid(axis, heatmap_width) 323 | shape = [1, 1, 1, 1] 324 | shape[axis.value] = -1 325 | return grid.view(shape) 326 | 327 | # Expand to [batch_size, num_keypoints, 1, 1] for broadcasting later: 328 | x_coordinates = coordinates[:, :, np.newaxis, np.newaxis, 0] 329 | y_coordinates = coordinates[:, :, np.newaxis, np.newaxis, 1] 330 | 331 | # Create two 1-D Gaussian vectors (marginals) and multiply to get a 2-d map: 332 | #sigma = torch.FloatTensor(sigma) 333 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 334 | 335 | x_vec = torch.exp(-(get_grid(Axis.x) - x_coordinates)**2/keypoint_width) 336 | y_vec = torch.exp(-(get_grid(Axis.y) - y_coordinates)**2/keypoint_width) 337 | maps = torch.mul(x_vec, y_vec) 338 | 339 | #npmaps0 = maps.data.cpu().numpy() 340 | maps = maps * map_scales[:, :, np.newaxis, np.newaxis, 0] 341 | #npmaps = maps.data.cpu().numpy() 342 | #maps = torch.sum(maps, dim=1, keepdim=True) 343 | #npmaps1 = maps.detach().cpu().numpy() 344 | 345 | return maps 346 | 347 | def keypoints_to_maps1(keypoints, sigma=1.0, heatmap_width=16): 348 | """ 349 | do not use miu 350 | """ 351 | def get_grid(axis): 352 | grid = _get_pixel_grid(axis, heatmap_width) 353 | shape = [1, 1, 1, 1] 354 | shape[axis.value] = -1 355 | return grid.view(shape) 356 | 357 | # Expand to [batch_size, num_keypoints, 1, 1] for broadcasting later: 358 | x_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 0] 359 | y_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 1] 360 | 361 | # Create two 1-D Gaussian vectors (marginals) and multiply to get a 2-d map: 362 | #sigma = torch.FloatTensor(sigma) 363 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 364 | 365 | x_vec = torch.exp(-(get_grid(Axis.x) - x_coordinates)**2/keypoint_width) 366 | y_vec = torch.exp(-(get_grid(Axis.y) - y_coordinates)**2/keypoint_width) 367 | maps = torch.mul(x_vec, y_vec) 368 | #maps_np = maps.data.cpu().numpy() 369 | 370 | return maps 371 | 372 | def keypoints_to_maps2(keypoints, sigma=1.0, heatmap_width=16): 373 | """ 374 | do not use miu 375 | re nomalize max value 376 | """ 377 | def get_grid(axis): 378 | grid = _get_pixel_grid(axis, heatmap_width) 379 | shape = [1, 1, 1, 1] 380 | shape[axis.value] = -1 381 | return grid.view(shape) 382 | 383 | # Expand to [batch_size, num_keypoints, 1, 1] for broadcasting later: 384 | x_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 0] 385 | y_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 1] 386 | 387 | # Create two 1-D Gaussian vectors (marginals) and multiply to get a 2-d map: 388 | #sigma = torch.FloatTensor(sigma) 389 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 390 | 391 | x_vec = torch.exp(-(get_grid(Axis.x) - x_coordinates)**2/keypoint_width) 392 | y_vec = torch.exp(-(get_grid(Axis.y) - y_coordinates)**2/keypoint_width) 393 | maps = torch.mul(x_vec, y_vec) 394 | maps_max = torch.max(maps.view(maps.size(0), maps.size(1), -1), dim=-1)[0] 395 | maps = maps / maps_max[:, :, np.newaxis, np.newaxis] 396 | #maps_np = maps.data.cpu().numpy() 397 | 398 | return maps 399 | 400 | def keypoints_to_edgemaps(keypoints, neighbor_link, sigma=1.0, heatmap_width=16): 401 | """ 402 | do not use miu 403 | """ 404 | def get_grid(axis): 405 | grid = _get_pixel_grid(axis, heatmap_width) 406 | shape = [1, 1, 1, 1] 407 | shape[axis.value] = -1 408 | return grid.view(shape) 409 | 410 | # Expand to [batch_size, num_keypoints, 1, 1] for 
broadcasting later: 411 | x_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 0] 412 | y_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 1] 413 | 414 | # Create two 1-D Gaussian vectors (marginals) and multiply to get a 2-d map: 415 | #sigma = torch.FloatTensor(sigma) 416 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 417 | 418 | x_vec = torch.exp(-(get_grid(Axis.x) - x_coordinates)**2/keypoint_width) 419 | y_vec = torch.exp(-(get_grid(Axis.y) - y_coordinates)**2/keypoint_width) 420 | maps = torch.mul(x_vec, y_vec) 421 | edgemaps = [] 422 | for edge in neighbor_link: 423 | edgemaps.append(maps[:, edge[0]]+maps[:, edge[1]]) 424 | edgemaps = torch.stack(edgemaps, dim = 1) 425 | #maps_np = maps.data.cpu().numpy() 426 | #edgemaps_np = edgemaps.data.cpu().numpy() 427 | 428 | return edgemaps 429 | 430 | def keypoints_to_edgemaps1(keypoints, neighbor_link, sigma=1.0, heatmap_width=16): 431 | """ 432 | do not use miu 433 | """ 434 | def get_grid(axis): 435 | grid = _get_pixel_grid(axis, heatmap_width) 436 | shape = [1, 1, 1, 1] 437 | shape[axis.value] = -1 438 | return grid.view(shape) 439 | 440 | # Expand to [batch_size, num_keypoints, 1, 1] for broadcasting later: 441 | x_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 0] 442 | y_coordinates = keypoints[:, :, np.newaxis, np.newaxis, 1] 443 | 444 | # Create two 1-D Gaussian vectors (marginals) and multiply to get a 2-d map: 445 | #sigma = torch.FloatTensor(sigma) 446 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 447 | 448 | x_vec = torch.exp(-(get_grid(Axis.x) - x_coordinates)**2/keypoint_width) 449 | y_vec = torch.exp(-(get_grid(Axis.y) - y_coordinates)**2/keypoint_width) 450 | maps = torch.mul(x_vec, y_vec) 451 | #maps_np = maps.data.cpu().numpy() 452 | edgemaps = [] 453 | d = 1.5/(heatmap_width-1)*2 454 | for edge in neighbor_link: 455 | edgemaps.append(maps[:, edge[0]] + maps[:, edge[1]]) 456 | dist = ((keypoints[:, edge[0], 0] - keypoints[:, edge[1], 0]) ** 2+ 457 | (keypoints[:, edge[0], 1] - keypoints[:, edge[1], 1]) ** 2)**0.5 458 | for j in range(keypoints.size(0)): 459 | if dist[j]>d*2: 460 | x_coordinates1 = 0.5 * (keypoints[j, edge[0], 0] + 461 | keypoints[j, edge[1], 0]) 462 | y_coordinates1 = 0.5 * (keypoints[j, edge[0], 1] + 463 | keypoints[j, edge[1], 1]) 464 | x_vec1 = torch.exp(-(get_grid(Axis.x) - x_coordinates1) ** 2 / keypoint_width) 465 | y_vec1 = torch.exp(-(get_grid(Axis.y) - y_coordinates1) ** 2 / keypoint_width) 466 | maps1 = torch.mul(x_vec1, y_vec1) 467 | edgemaps[-1][j] += maps1[0,0] 468 | ''' 469 | x_coordinates1 = 0.5*(keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 0]+ 470 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 0]) 471 | y_coordinates1 = 0.5*(keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 1]+ 472 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 1]) 473 | x_vec1 = torch.exp(-(get_grid(Axis.x) - x_coordinates1) ** 2 / keypoint_width) 474 | y_vec1 = torch.exp(-(get_grid(Axis.y) - y_coordinates1) ** 2 / keypoint_width) 475 | maps1 = torch.mul(x_vec1, y_vec1) 476 | ''' 477 | #maps1_np = maps1.data.cpu().numpy() 478 | edgemaps = torch.stack(edgemaps, dim = 1) 479 | #edgemaps_np = edgemaps.data.cpu().numpy() 480 | 481 | return edgemaps 482 | 483 | def keypoints_to_edgemaps2(keypoints, neighbor_link, sigma=1.0, heatmap_width=16): 484 | """ 485 | do not use miu 486 | rotation Gaussian 487 | """ 488 | def get_grid(axis): 489 | grid = _get_pixel_grid(axis, heatmap_width) 490 | shape = [1, 1, 1, 1] 491 | shape[axis.value] = -1 492 | return 
grid.view(shape) 493 | 494 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 495 | 496 | edgemaps = [] 497 | S = torch.FloatTensor([[1,0],[0,1/4]]).cuda() 498 | for edge in neighbor_link: 499 | x_coordinates1 = 0.5*(keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 0]+ 500 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 0]) 501 | y_coordinates1 = 0.5*(keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 1]+ 502 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 1]) 503 | theta = torch.atan((keypoints[:, edge[0], 1] - keypoints[:, edge[1], 1])/\ 504 | (keypoints[:, edge[0], 0] - keypoints[:, edge[1], 0] + EPSILON)) 505 | cos_theta = torch.cos(-theta) 506 | sin_theta = torch.sin(-theta) 507 | R1 = torch.stack([cos_theta, -sin_theta], dim=-1) 508 | R2 = torch.stack([sin_theta, cos_theta], dim=-1) 509 | R = torch.stack([R1, R2], dim=1) 510 | RT = torch.einsum('...ij->...ji', R) 511 | Sigma = torch.einsum('bij, jk, bkm->bim', R, S, RT) 512 | x_vec1 = -Sigma[:, np.newaxis, np.newaxis, np.newaxis, 0, 0]*(get_grid(Axis.x) - x_coordinates1) ** 2 / keypoint_width 513 | y_vec1 = -Sigma[:, np.newaxis, np.newaxis, np.newaxis, 1, 1]*(get_grid(Axis.y) - y_coordinates1) ** 2 / keypoint_width 514 | xy_vec1 = -2*Sigma[:, np.newaxis, np.newaxis, np.newaxis, 0, 1]*(get_grid(Axis.x) - x_coordinates1)*(get_grid(Axis.y) - y_coordinates1) / keypoint_width 515 | maps1 = torch.exp(x_vec1+y_vec1+xy_vec1) 516 | #maps1_np = maps1.data.cpu().numpy() 517 | edgemaps.append(maps1) 518 | edgemaps = torch.stack(edgemaps, dim = 1).squeeze() 519 | #edgemaps = torch.sum(edgemaps, dim=1, keepdim=True) 520 | edgemaps_max = torch.max(edgemaps.view(edgemaps.size(0),edgemaps.size(1), -1), dim=-1)[0] 521 | edgemaps = edgemaps/edgemaps_max[:, :, np.newaxis, np.newaxis] 522 | #edgemaps_np = edgemaps.data.cpu().numpy() 523 | 524 | return edgemaps 525 | 526 | def keypoints_to_edgemaps3(keypoints, neighbor_link, sigma=1.0, heatmap_width=16): 527 | """ 528 | do not use miu 529 | rotation Gaussian 530 | changable variance 531 | """ 532 | def get_grid(axis): 533 | grid = _get_pixel_grid(axis, heatmap_width) 534 | shape = [1, 1, 1, 1] 535 | shape[axis.value] = -1 536 | return grid.view(shape) 537 | 538 | keypoint_width = 2.0 * (sigma / heatmap_width) ** 2.0 539 | 540 | edgemaps = [] 541 | for edge in neighbor_link: 542 | dist = ((keypoints[:, edge[0], 0] - keypoints[:, edge[1], 0]) ** 2 + 543 | (keypoints[:, edge[0], 1] - keypoints[:, edge[1], 1]) ** 2) ** 0.5 544 | x_coordinates1 = 0.5 * (keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 0] + 545 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 0]) 546 | y_coordinates1 = 0.5 * (keypoints[:, edge[0], np.newaxis, np.newaxis, np.newaxis, 1] + 547 | keypoints[:, edge[1], np.newaxis, np.newaxis, np.newaxis, 1]) 548 | theta = torch.atan((keypoints[:, edge[0], 1] - keypoints[:, edge[1], 1])/\ 549 | (keypoints[:, edge[0], 0] - keypoints[:, edge[1], 0] + EPSILON)) 550 | cos_theta = torch.cos(-theta) 551 | sin_theta = torch.sin(-theta) 552 | R1 = torch.stack([cos_theta, -sin_theta], dim=-1) 553 | R2 = torch.stack([sin_theta, cos_theta], dim=-1) 554 | R = torch.stack([R1, R2], dim=1) 555 | RT = torch.einsum('...ij->...ji', R) 556 | S = [torch.FloatTensor([[1,0],[0,1/4/(1+dist[i]*10)]]).cuda() for i in range(keypoints.size(0))] 557 | S = torch.stack(S) 558 | Sigma = torch.einsum('bij, bjk, bkm->bim', R, S, RT) 559 | x_vec1 = -Sigma[:, np.newaxis, np.newaxis, np.newaxis, 0, 0]*(get_grid(Axis.x) - x_coordinates1) ** 2 / keypoint_width 560 | 
y_vec1 = -Sigma[:, np.newaxis, np.newaxis, np.newaxis, 1, 1]*(get_grid(Axis.y) - y_coordinates1) ** 2 / keypoint_width 561 | xy_vec1 = -2*Sigma[:, np.newaxis, np.newaxis, np.newaxis, 0, 1]*(get_grid(Axis.x) - x_coordinates1)*(get_grid(Axis.y) - y_coordinates1) / keypoint_width 562 | maps1 = torch.exp(x_vec1+y_vec1+xy_vec1) 563 | #maps1_np = maps1.data.cpu().numpy() 564 | edgemaps.append(maps1) 565 | edgemaps = torch.stack(edgemaps, dim = 1).squeeze() 566 | edgemaps = torch.sum(edgemaps, dim=1, keepdim=True) 567 | edgemaps_max = torch.max(edgemaps.view(edgemaps.size(0), -1), dim=-1)[0] 568 | edgemaps = edgemaps / edgemaps_max[:, np.newaxis, np.newaxis, np.newaxis] 569 | #edgemaps_np = edgemaps.data.cpu().numpy() 570 | 571 | return edgemaps 572 | 573 | def _get_pixel_grid(axis, width): 574 | """Returns an array of length `width` containing pixel coordinates.""" 575 | if axis == Axis.x: 576 | return torch.linspace(-1.0, 1.0, width).cuda() # Left is negative, right is positive. 577 | elif axis == Axis.y: 578 | return torch.linspace(1.0, -1.0, width).cuda() # Top is positive, bottom is negative. 579 | 580 | 581 | def change2xy(cfg, keypoints): 582 | xy = keypoints[:,:,:2].detach().cpu().numpy() 583 | xy[:,:,0] = np.round((1+xy[:,:,0])*(cfg.img_w-1)/2) 584 | xy[:,:,1] = np.round((1-xy[:,:,1])*(cfg.img_w-1)/2) 585 | return np.uint8(xy) 586 | 587 | def add_keypoints(image, key_points, radius = 1, miu=None): # image in [-0.5, 0.5] 588 | im_key = torch.ones_like(image).cuda() 589 | if not (miu is None): 590 | im_key1 = torch.zeros_like(image).cuda() 591 | 592 | for i in range(len(key_points)): 593 | x = [np.clip(key_points[i][0]-radius, 0, 63), np.clip(key_points[i][0]+radius, 0, 63)] 594 | y = [np.clip(key_points[i][1] - radius, 0, 63), np.clip(key_points[i][1] + radius, 0, 63)] 595 | im_key[:, y[0]:y[1] + 1, key_points[i][0]] = 0 596 | im_key[:, key_points[i][1], x[0]:x[1] + 1] = 0 597 | if not (miu is None): 598 | if miu[i] > 0: 599 | im_key1[:, y[0]:y[1] + 1, key_points[i][0]] = miu[i] 600 | im_key1[:, key_points[i][1], x[0]:x[1] + 1] = miu[i] 601 | 602 | if not (miu is None): 603 | image *= im_key 604 | im_key = 1 - im_key 605 | image += -0.5 * im_key 606 | image += im_key1 607 | else: 608 | #im_key_np = im_key.data.cpu().numpy() 609 | image *= im_key 610 | im_key = 1 - im_key 611 | image += 0.5 * im_key 612 | return image 613 | 614 | def show_keypoints(image, key_points, img_w, radius = 1): 615 | im_key = torch.zeros_like(image).cuda() 616 | 617 | for i in range(len(key_points)): 618 | x = [np.clip(key_points[i][0]-radius, 0, img_w-1), np.clip(key_points[i][0]+radius, 0, img_w-1)] 619 | y = [np.clip(key_points[i][1] - radius, 0, img_w-1), np.clip(key_points[i][1] + radius, 0, img_w-1)] 620 | im_key[:, y[0]:y[1] + 1, key_points[i][0]] = 1 621 | im_key[:, key_points[i][1], x[0]:x[1] + 1] = 1 622 | return im_key-0.5 623 | 624 | 625 | -------------------------------------------------------------------------------- /train_all_AE_BN3_kth_1_2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | r"""Minimal example for training a video_structure model. 18 | 19 | See README.md for installation instructions. To run on GPU device 0: 20 | 21 | CUDA_VISIBLE_DEVICES=0 python -m video_structure.train 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | 27 | from __future__ import print_function 28 | 29 | import torch 30 | import torch.optim as optim 31 | import torch.nn as nn 32 | import os, subprocess 33 | import random 34 | from torch.utils.data import DataLoader 35 | from torch.autograd import Variable 36 | import numpy as np 37 | import utils,hyperparameters,losses,ops,dynanmics 38 | import time 39 | import imageio, cv2 40 | ''' 41 | Builds the complete model with image encoder plus dynamics model. 42 | 43 | This architecture is meant for testing/illustration only. 44 | 45 | Model architecture: 46 | 47 | image --> keypoints --> reconstructed_image 48 | 49 | The model takes a [batch_size, timesteps, H, W, C] image sequence as input. It 50 | "observes" all frames, detects keypoints, and reconstructs the images. The 51 | dynamics model learns to predict future keypoints based on the detected 52 | keypoints. 53 | ''' 54 | os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in subprocess.Popen( 55 | "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, 56 | stdout=subprocess.PIPE).stdout.readlines()])) 57 | 58 | cfg = hyperparameters.get_config() 59 | name = 'model=AE_BN3_kth_1_2' 60 | 61 | # do not use miu 62 | # convlstm 63 | # pi loss 64 | # maxpooling map 65 | # auto-regressive model 66 | 67 | cfg.dataset = 'kth' 68 | cfg.observed_steps = 10 69 | cfg.predicted_steps = 10 70 | cfg.num_keypoints = 12 71 | cfg.batch_size = 32 72 | cfg.num_epochs = 1500 73 | cfg.learning_rate = 1e-3 74 | cfg.kl_loss_scale = 0.05 #---------------------------- 75 | cfg.test_N = 256 76 | cfg.test_batch_size = 32 77 | cfg.nsample = 4 78 | cfg.reso = cfg.img_w 79 | 80 | load_model = False 81 | log_dir = 'logs/struc' 82 | log_dir = '%s-%s' % (log_dir, name) 83 | 84 | os.makedirs('%s/gen1/' % log_dir, exist_ok=True) 85 | 86 | print("Random Seed: ", cfg.seed) 87 | np.random.seed(cfg.seed) 88 | random.seed(cfg.seed) 89 | torch.manual_seed(cfg.seed) 90 | torch.cuda.manual_seed_all(cfg.seed) 91 | dtype = torch.cuda.FloatTensor 92 | 93 | # --------- loss functions ------------------------------------ 94 | def CalDelta_xy(keypoints_np, width=64): 95 | keypoints_np1 = np.zeros_like(keypoints_np) 96 | keypoints_np1[:, :, 0] = np.round((1 + keypoints_np[:, :, 0]) * (width-1) / 2) 97 | keypoints_np1[:, :, 1] = np.round((1 - keypoints_np[:, :, 1]) * (width-1) / 2) 98 | keypoints_np1[:, :, 0] = keypoints_np1[:, :, 0] / (width-1) * 2 - 1 99 | keypoints_np1[:, :, 1] = 1 - keypoints_np1[:, :, 1] / (width-1) * 2 100 | 101 | delta_xy = keypoints_np1 - keypoints_np 102 | return delta_xy 103 | 104 | mse_criterion = nn.MSELoss(reduction='sum') 105 | mse_criterion_none = nn.MSELoss(reduction='none') 106 | def kl_criterion(mu1, stds1, mu2, stds2): 107 | # KL( N(mu_1, sigma2_1) || 
N(mu_2, sigma2_2)) =
108 | #   log(stds2/stds1) + (stds1**2 + (mu1 - mu2)**2) / (2*stds2**2) - 1/2
109 | #   (stds1, stds2 are standard deviations, i.e. sigma2_i = stds_i**2)
110 | kld = torch.log(stds2/stds1) + (stds1**2 + (mu1 - mu2)**2)/(2*stds2**2) - 1/2
111 | return kld.sum()
112 | 
113 | 
114 | repeat=0
115 | saved_model = torch.load('%s/model.pth' % log_dir)
116 | build_images_to_keypoints_net = saved_model['build_images_to_keypoints_net']
117 | keypoints_to_images_net = saved_model['keypoints_to_images_net']
118 | rnn_cell = dynanmics.convlstm_rnn_p(cfg, map_width=cfg.img_w, add_dim=1)
119 | keypoint_decoder = dynanmics.convlstm_decoder_p(cfg, add_dim=1)
120 | prior_net = dynanmics.prior_net_cnn(cfg)
121 | posterior_net = dynanmics.posterior_net_cnn(cfg)
122 | 
123 | rnn_cell.apply(utils.init_weights)
124 | keypoint_decoder.apply(utils.init_weights)
125 | prior_net.apply(utils.init_weights)
126 | posterior_net.apply(utils.init_weights)
127 | 
128 | 
129 | rnn_cell_optimizer = optim.Adam(rnn_cell.parameters(),
130 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda)
131 | keypoint_decoder_optimizer = optim.Adam(keypoint_decoder.parameters(),
132 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda)
133 | prior_net_optimizer = optim.Adam(prior_net.parameters(),
134 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda)
135 | posterior_net_optimizer = optim.Adam(posterior_net.parameters(),
136 | lr=cfg.learning_rate, weight_decay=cfg.reg_lambda)
137 | 
138 | # --------- transfer to gpu ------------------------------------
139 | build_images_to_keypoints_net.cuda()
140 | keypoints_to_images_net.cuda()
141 | rnn_cell.cuda()
142 | keypoint_decoder.cuda()
143 | prior_net.cuda()
144 | posterior_net.cuda()
145 | 
146 | # --------- load a dataset ------------------------------------
147 | train_data, test_data = utils.load_dataset(cfg)
148 | 
149 | train_loader = DataLoader(train_data,
150 | num_workers=4,
151 | batch_size=cfg.batch_size,
152 | shuffle=True,
153 | drop_last=True,
154 | pin_memory=True)
155 | test_loader = DataLoader(test_data,
156 | num_workers=4,
157 | batch_size=cfg.test_batch_size,
158 | shuffle=True,
159 | drop_last=True,
160 | pin_memory=True)
161 | 
162 | def get_training_batch():
163 | while True:
164 | for sequence in train_loader:
165 | batch = utils.normalize_data(cfg, dtype, sequence)
166 | yield batch
167 | 
168 | training_batch_generator = get_training_batch()
169 | 
170 | def get_testing_batch():
171 | while True:
172 | for sequence in test_loader:
173 | batch = utils.normalize_data(cfg, dtype, sequence)
174 | yield batch
175 | 
176 | testing_batch_generator = get_testing_batch()
177 | 
178 | # --------- plotting functions ------------------------------------
179 | def plot(x, epoch):
180 | gen_keypoints = []
181 | gen_seq = []
182 | gt_seq = [x[i] for i in range(len(x))]
183 | 
184 | observed_keypoints = []
185 | for i in range(cfg.observed_steps):
186 | keypoints, _ = build_images_to_keypoints_net(x[i])
187 | keypoints_np = keypoints.data.cpu().numpy()
188 | delta_xy = CalDelta_xy(keypoints_np, width=cfg.img_w)
189 | delta_xypt = torch.FloatTensor(delta_xy).cuda()
190 | keypoints = keypoints + delta_xypt
191 | observed_keypoints.append(keypoints.detach())
192 | 
193 | rnn_cell.batch_size = cfg.test_batch_size
194 | rnn_cell.hidden = rnn_cell.init_hidden()
195 | rnn_cell.batch_size = cfg.batch_size
196 | rnn_state = rnn_cell.hidden[0][0]
197 | 
198 | for i in range(cfg.n_eval):
199 | if i < cfg.observed_steps:
200 | observed_keypoints_np = observed_keypoints[i].data.cpu().numpy()
201 | observed_keypoints_np[:, :, 0] = np.round((1 + observed_keypoints_np[:, :, 0])
202 | * (cfg.img_w - 1) / 2)
203 | 
observed_keypoints_np[:, :, 1] = np.round((1 - observed_keypoints_np[:, :, 1]) 204 | * (cfg.img_w - 1) / 2) 205 | observed_keypoints_np = np.clip(observed_keypoints_np, 0, cfg.img_w - 1) 206 | observed_keypoints_id = observed_keypoints_np[:, :, 0] + \ 207 | observed_keypoints_np[:, :, 1] * cfg.img_w 208 | best_keypoints_id_flat = observed_keypoints_id.flatten().astype(int) 209 | observed_keypoints_map_np = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 210 | cfg.img_w, cfg.img_w)) 211 | observed_keypoints_map_np_flat = observed_keypoints_map_np.reshape(cfg.test_batch_size * 212 | cfg.num_keypoints, -1) 213 | observed_keypoints_map_np_flat[range(cfg.test_batch_size * cfg.num_keypoints), 214 | best_keypoints_id_flat[range(cfg.test_batch_size 215 | * cfg.num_keypoints)]] = 1 216 | observed_keypoints_map_np = observed_keypoints_map_np_flat.reshape(cfg.test_batch_size, 217 | cfg.num_keypoints, 218 | cfg.img_w, -1) 219 | observed_keypoints_map_batch = torch.FloatTensor(observed_keypoints_map_np).cuda() 220 | observed_keypoints_batch = observed_keypoints[i] 221 | mean, std = posterior_net(rnn_state, observed_keypoints_map_batch) 222 | if i == 0: 223 | observed_keypoints_batch0 = observed_keypoints_batch 224 | else: 225 | mean_prior, std_prior = prior_net(rnn_state) 226 | mean = mean_prior.detach() 227 | std = std_prior.detach() 228 | eps = Variable(std.data.new(std.size()).normal_()) 229 | eps = eps * std + mean 230 | if i < cfg.observed_steps: 231 | keypoints = observed_keypoints_batch 232 | else: 233 | sampled_keypoints_flat = keypoint_decoder(rnn_state, eps).detach() 234 | sampled_keypoints_flat = torch.exp(sampled_keypoints_flat).data.cpu().numpy() 235 | sampled_keypoints_id = np.argmax(sampled_keypoints_flat, axis=-1) 236 | keypoints = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 2)) 237 | keypoints[:, :, 0] = sampled_keypoints_id % cfg.img_w / (cfg.img_w - 1) * 2 + (-1) 238 | keypoints[:, :, 1] = 1 - sampled_keypoints_id // cfg.img_w / (cfg.img_w - 1) * 2 239 | keypoints = torch.FloatTensor(keypoints).cuda() 240 | observed_keypoints_map_np = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 241 | cfg.img_w, cfg.img_w)) 242 | observed_keypoints_map_np_flat = observed_keypoints_map_np.reshape(cfg.test_batch_size * 243 | cfg.num_keypoints, -1) 244 | sampled_keypoints_id_flat = sampled_keypoints_id.flatten().astype(int) 245 | observed_keypoints_map_np_flat[range(cfg.test_batch_size * cfg.num_keypoints), 246 | sampled_keypoints_id_flat[range(cfg.test_batch_size 247 | * cfg.num_keypoints)]] = 1 248 | observed_keypoints_map_np = observed_keypoints_map_np_flat.reshape(cfg.test_batch_size, 249 | cfg.num_keypoints, 250 | cfg.img_w, -1) 251 | observed_keypoints_map_batch = torch.FloatTensor(observed_keypoints_map_np).cuda() 252 | reconstructed_image = keypoints_to_images_net(keypoints, x[0], observed_keypoints_batch0).detach() 253 | rnn_state = rnn_cell(observed_keypoints_map_batch, eps) 254 | gen_keypoints.append(ops.change2xy(keypoints)) 255 | gen_seq.append(reconstructed_image) 256 | 257 | to_plot = [] 258 | gifs = [[] for t in range(cfg.n_eval)] 259 | 260 | nrow = min(cfg.test_batch_size, 10) 261 | for i in range(nrow): 262 | # ground truth sequence 263 | row = [] 264 | for t in range(cfg.n_eval): 265 | row.append(gt_seq[t][i]) 266 | to_plot.append(row) 267 | 268 | row = [] 269 | for t in range(cfg.n_eval): 270 | row.append(ops.add_keypoints(gen_seq[t][i].clone(), gen_keypoints[t][i])) 271 | to_plot.append(row) 272 | 273 | for t in range(cfg.n_eval): 274 | row = [] 275 | 
row.append(gt_seq[t][i]) 276 | row.append(gen_seq[t][i]) 277 | gifs[t].append(row) 278 | 279 | fname = '%s/gen1/sample_%d.png' % (log_dir, epoch) 280 | utils.save_tensors_image(fname, to_plot) 281 | 282 | fname = '%s/gen1/sample_%d.gif' % (log_dir, epoch) 283 | utils.save_gif(fname, gifs) 284 | 285 | def val(x, epoch): 286 | observed_keypoints = [] 287 | for i in range(cfg.observed_steps): 288 | keypoints, _ = build_images_to_keypoints_net(x[i]) 289 | keypoints_np = keypoints.data.cpu().numpy() 290 | delta_xy = CalDelta_xy(keypoints_np, width=cfg.img_w) 291 | delta_xypt = torch.FloatTensor(delta_xy).cuda() 292 | keypoints = keypoints + delta_xypt 293 | observed_keypoints.append(keypoints.detach()) 294 | 295 | ssim = np.zeros((cfg.test_batch_size, cfg.nsample, cfg.n_eval)) 296 | psnr = np.zeros((cfg.test_batch_size, cfg.nsample, cfg.n_eval)) 297 | all_gen = [] 298 | all_gen_keypoints = [] 299 | for s in range(cfg.nsample): 300 | gen_seq = [] 301 | gt_seq = [] 302 | all_gen.append([]) 303 | all_gen_keypoints.append([]) 304 | rnn_cell.batch_size = cfg.test_batch_size 305 | rnn_cell.hidden = rnn_cell.init_hidden() 306 | rnn_cell.batch_size = cfg.batch_size 307 | rnn_state = rnn_cell.hidden[0][0] 308 | for i in range(cfg.n_eval): 309 | if i < cfg.observed_steps: 310 | observed_keypoints_np = observed_keypoints[i].data.cpu().numpy() 311 | observed_keypoints_np[:, :, 0] = np.round((1 + observed_keypoints_np[:, :, 0]) 312 | * (cfg.reso - 1) / 2) 313 | observed_keypoints_np[:, :, 1] = np.round((1 - observed_keypoints_np[:, :, 1]) 314 | * (cfg.reso - 1) / 2) 315 | observed_keypoints_np = np.clip(observed_keypoints_np, 0, cfg.reso - 1) 316 | observed_keypoints_id = observed_keypoints_np[:, :, 0] + \ 317 | observed_keypoints_np[:, :, 1] * cfg.reso 318 | best_keypoints_id_flat = observed_keypoints_id.flatten().astype(int) 319 | observed_keypoints_map_np = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 320 | cfg.reso, cfg.reso)) 321 | observed_keypoints_map_np_flat = observed_keypoints_map_np.reshape(cfg.test_batch_size * 322 | cfg.num_keypoints, -1) 323 | observed_keypoints_map_np_flat[range(cfg.test_batch_size * cfg.num_keypoints), 324 | best_keypoints_id_flat[range(cfg.test_batch_size 325 | * cfg.num_keypoints)]] = 1 326 | observed_keypoints_map_np = observed_keypoints_map_np_flat.reshape(cfg.test_batch_size, 327 | cfg.num_keypoints, 328 | cfg.reso, -1) 329 | observed_keypoints_map_batch = torch.FloatTensor(observed_keypoints_map_np).cuda() 330 | observed_keypoints_batch = observed_keypoints[i] 331 | mean, std = posterior_net(rnn_state, observed_keypoints_map_batch) 332 | else: 333 | mean_prior, std_prior = prior_net(rnn_state) 334 | mean = mean_prior.detach() 335 | std = std_prior.detach() 336 | eps = Variable(std.data.new(std.size()).normal_()) 337 | eps = eps * std + mean 338 | if i < cfg.observed_steps: 339 | keypoints = observed_keypoints_batch 340 | else: 341 | sampled_keypoints_flat = keypoint_decoder(rnn_state, eps).detach() 342 | sampled_keypoints_flat = torch.exp(sampled_keypoints_flat).data.cpu().numpy() 343 | sampled_keypoints_id = np.argmax(sampled_keypoints_flat, axis=-1) 344 | keypoints = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 2)) 345 | keypoints[:, :, 0] = sampled_keypoints_id % cfg.reso / (cfg.reso - 1) * 2 + (-1) 346 | keypoints[:, :, 1] = 1 - sampled_keypoints_id // cfg.reso / (cfg.reso - 1) * 2 347 | keypoints = torch.FloatTensor(keypoints).cuda() 348 | observed_keypoints_map_np = np.zeros((cfg.test_batch_size, cfg.num_keypoints, 349 | cfg.reso, cfg.reso)) 350 | 
observed_keypoints_map_np_flat = observed_keypoints_map_np.reshape(cfg.test_batch_size * 351 | cfg.num_keypoints, -1) 352 | sampled_keypoints_id_flat = sampled_keypoints_id.flatten().astype(int) 353 | observed_keypoints_map_np_flat[range(cfg.test_batch_size * cfg.num_keypoints), 354 | sampled_keypoints_id_flat[range(cfg.test_batch_size 355 | * cfg.num_keypoints)]] = 1 356 | observed_keypoints_map_np = observed_keypoints_map_np_flat.reshape(cfg.test_batch_size, 357 | cfg.num_keypoints, 358 | cfg.reso, -1) 359 | observed_keypoints_map_batch = torch.FloatTensor(observed_keypoints_map_np).cuda() 360 | rnn_state = rnn_cell(observed_keypoints_map_batch, eps) 361 | reconstructed_image = keypoints_to_images_net(keypoints, x[cfg.observed_steps - 1], 362 | observed_keypoints[cfg.observed_steps - 1]).detach() 363 | reconstructed_image = torch.clamp(reconstructed_image, -0.5, 0.5) 364 | 365 | all_gen_keypoints[s].append(ops.change2xy(keypoints)) 366 | all_gen[s].append(reconstructed_image) 367 | gen_seq.append(reconstructed_image.data.cpu().numpy() + 0.5) 368 | gt_seq.append(x[i].data.cpu().numpy() + 0.5) 369 | _, ssim[:, s, :], psnr[:, s, :] = utils.eval_seq(gt_seq, gen_seq) 370 | 371 | return ssim, psnr 372 | 373 | # --------- training funtions ------------------------------------ 374 | def train(x, train_step): 375 | rnn_cell.zero_grad() 376 | keypoint_decoder.zero_grad() 377 | prior_net.zero_grad() 378 | posterior_net.zero_grad() 379 | 380 | kl_divergence = torch.FloatTensor([0]).cuda() 381 | mse = torch.FloatTensor([0]).cuda() 382 | keypoints_pi = torch.FloatTensor([0]).cuda() 383 | observed_keypoints = [] 384 | for i in range(cfg.observed_steps + cfg.predicted_steps): 385 | keypoints, _ = build_images_to_keypoints_net(x[i]) 386 | keypoints_np = keypoints.data.cpu().numpy() 387 | delta_xy = CalDelta_xy(keypoints_np, width=cfg.img_w) 388 | delta_xypt = torch.FloatTensor(delta_xy).cuda() 389 | keypoints = keypoints + delta_xypt 390 | observed_keypoints.append(keypoints.detach()) 391 | #''' 392 | rnn_cell.hidden = rnn_cell.init_hidden() 393 | rnn_state = rnn_cell.hidden[0][0] 394 | 395 | for i in range(cfg.observed_steps + cfg.predicted_steps): 396 | observed_keypoints_map_np = np.zeros((cfg.batch_size, cfg.num_keypoints, 397 | cfg.img_w, cfg.img_w)) 398 | observed_keypoints_np = observed_keypoints[i].data.cpu().numpy() 399 | observed_keypoints_np[:, :, 0] = np.round((1 + observed_keypoints_np[:, :, 0]) 400 | * (cfg.img_w - 1) / 2) 401 | observed_keypoints_np[:, :, 1] = np.round((1 - observed_keypoints_np[:, :, 1]) 402 | * (cfg.img_w - 1) / 2) 403 | observed_keypoints_np = np.clip(observed_keypoints_np, 0, cfg.img_w - 1) 404 | eye = np.eye(cfg.img_w ** 2) 405 | observed_keypoints_id = observed_keypoints_np[:, :, 0] + \ 406 | observed_keypoints_np[:, :, 1] * cfg.img_w 407 | best_keypoints_id_flat = observed_keypoints_id.flatten().astype(int) 408 | probs = eye[best_keypoints_id_flat] 409 | probs = probs.reshape(cfg.batch_size, cfg.num_keypoints, -1) 410 | probs = torch.FloatTensor(probs).cuda() 411 | observed_keypoints_map_np_flat = observed_keypoints_map_np.reshape(cfg.batch_size*cfg.num_keypoints, -1) 412 | observed_keypoints_map_np_flat[range(cfg.batch_size*cfg.num_keypoints), 413 | best_keypoints_id_flat[range(cfg.batch_size*cfg.num_keypoints)] 414 | ] = 1 415 | observed_keypoints_map_np = observed_keypoints_map_np_flat.reshape(cfg.batch_size, cfg.num_keypoints, 416 | cfg.img_w, -1) 417 | observed_keypoints_map_batch = torch.FloatTensor(observed_keypoints_map_np).cuda() 418 | mean_prior, std_prior 
= prior_net(rnn_state) 419 | mean, std = posterior_net(rnn_state, observed_keypoints_map_batch) 420 | if i>0: 421 | kl_divergence += kl_criterion(mean_prior, std_prior, mean, std) 422 | 423 | # Conduct BestOfMany 424 | sampled_latent_list = [] 425 | sample_losses = [] 426 | for j in range(cfg.num_samples_for_bom): 427 | eps = Variable(std.data.new(std.size()).normal_()) 428 | eps = eps * std + mean 429 | sampled_latent_list.append(eps) 430 | sampled_keypoints_flat = keypoint_decoder(rnn_state, eps).detach() 431 | sample_losses.append(torch.sum(-probs * sampled_keypoints_flat, dim=(1,2)).detach()) 432 | _, best_sample_ind = torch.min(torch.stack(sample_losses), dim=0) 433 | best_sample_ind = best_sample_ind.detach().cpu().numpy() 434 | best_latent = torch.stack([sampled_latent_list[best_sample_ind[j]][j] 435 | for j in range(cfg.batch_size)]) 436 | best_keypoints_flat = keypoint_decoder(rnn_state, best_latent) 437 | keypoints_pi += torch.sum(-probs * best_keypoints_flat) 438 | rnn_state = rnn_cell(observed_keypoints_map_batch, best_latent) 439 | 440 | #''' 441 | keypoints_pi /= (cfg.observed_steps+cfg.predicted_steps)*cfg.num_keypoints*cfg.batch_size 442 | kl_divergence /= (cfg.observed_steps+cfg.predicted_steps-1)*cfg.batch_size 443 | 444 | loss = keypoints_pi + cfg.kl_loss_scale*kl_divergence 445 | loss.backward() 446 | 447 | torch.nn.utils.clip_grad_norm_(rnn_cell.parameters(), cfg.clipnorm) 448 | torch.nn.utils.clip_grad_norm_(keypoint_decoder.parameters(), cfg.clipnorm) 449 | torch.nn.utils.clip_grad_norm_(prior_net.parameters(), cfg.clipnorm) 450 | torch.nn.utils.clip_grad_norm_(posterior_net.parameters(), cfg.clipnorm) 451 | 452 | rnn_cell_optimizer.step() 453 | keypoint_decoder_optimizer.step() 454 | prior_net_optimizer.step() 455 | posterior_net_optimizer.step() 456 | 457 | return kl_divergence.data.cpu().numpy(), mse.data.cpu().numpy(), \ 458 | keypoints_pi.data.cpu().numpy() 459 | 460 | 461 | # --------- training loop ------------------------------------ 462 | train_step = 0 + repeat*cfg.num_epochs*cfg.steps_per_epoch 463 | for epoch in range(cfg.num_epochs): 464 | build_images_to_keypoints_net.eval() 465 | keypoints_to_images_net.eval() 466 | rnn_cell.train() 467 | keypoint_decoder.train() 468 | prior_net.train() 469 | posterior_net.train() 470 | 471 | epoch_kl = 0 472 | epoch_mse = 0 473 | epoch_keypoints_pi = 0 474 | 475 | #cfg.steps_per_epoch = 1 476 | for i in range(cfg.steps_per_epoch): 477 | x = next(training_batch_generator) # sequence and the sequence class number 478 | 479 | # train frame_predictor 480 | kl, mse, keypoints_pi = train(x, train_step) 481 | train_step += 1 482 | epoch_kl += kl 483 | epoch_mse += mse 484 | epoch_keypoints_pi += keypoints_pi 485 | 486 | print('[%02d] kl: %.5f | mse: %.5f | future_key_pi: %.5f (%s_all)' % ( 487 | epoch, epoch_kl / cfg.steps_per_epoch, 488 | epoch_mse / cfg.steps_per_epoch, 489 | epoch_keypoints_pi / cfg.steps_per_epoch, name)) 490 | with open(log_dir+'/'+"result_all_best.txt","a") as result: 491 | if epoch==0: 492 | result.write('\n') 493 | result.write('[%02d] kl: %.5f | mse: %.5f | future_key_pi: %.5f (%s_all)\n' % ( 494 | epoch, epoch_kl / cfg.steps_per_epoch, 495 | epoch_mse / cfg.steps_per_epoch, 496 | epoch_keypoints_pi / cfg.steps_per_epoch, name)) 497 | 498 | # plot some stuff 499 | build_images_to_keypoints_net.eval() 500 | keypoints_to_images_net.eval() 501 | rnn_cell.eval() 502 | keypoint_decoder.eval() 503 | prior_net.eval() 504 | posterior_net.eval() 505 | 506 | if (epoch+1)%100==0 or epoch==0: 507 | x = 
next(testing_batch_generator) 508 | plot(x, epoch) 509 | psnr_total = np.zeros((cfg.test_N, cfg.n_eval)) 510 | ssim_total = np.zeros((cfg.test_N, cfg.n_eval)) 511 | for i in range(0, cfg.test_N, cfg.test_batch_size): 512 | # plot test 513 | test_x = next(testing_batch_generator) 514 | ssim, psnr = val(test_x, epoch) 515 | for j in range(0, cfg.test_batch_size): 516 | psnr_total[i + j, :] = psnr[j, np.argmax(np.mean(psnr[j], axis=1)), :] 517 | ssim_total[i + j, :] = ssim[j, np.argmax(np.mean(ssim[j], axis=1)), :] 518 | 519 | ssim_val = int("{:3.0f}".format(np.mean(ssim_total[:, cfg.observed_steps:]) * 1000)) 520 | psnr_val = int("{:4.0f}".format(np.mean(psnr_total[:, cfg.observed_steps:]) * 100)) 521 | # save the model 522 | torch.save({ 523 | 'build_images_to_keypoints_net': build_images_to_keypoints_net, 524 | 'keypoints_to_images_net': keypoints_to_images_net, 525 | 'rnn_cell': rnn_cell, 526 | 'keypoint_decoder': keypoint_decoder, 527 | 'prior_net': prior_net, 528 | 'posterior_net': posterior_net}, 529 | '%s/model_all_epoch%d_kl%.2f_ssim%d_psnr_%d.pth' % (log_dir, epoch, cfg.kl_loss_scale, ssim_val, psnr_val)) 530 | 531 | 532 | if (epoch + 1) % 500 == 0: 533 | cfg.learning_rate /= 4 534 | utils.set_learning_rate(rnn_cell_optimizer, cfg.learning_rate) 535 | utils.set_learning_rate(keypoint_decoder_optimizer, cfg.learning_rate) 536 | utils.set_learning_rate(prior_net_optimizer, cfg.learning_rate) 537 | utils.set_learning_rate(posterior_net_optimizer, cfg.learning_rate) 538 | 539 | 540 | 541 | 542 | --------------------------------------------------------------------------------