├── requirements.txt ├── train.sh ├── modules ├── keypoint_detector.py ├── bg_motion_predictor.py ├── fg_motion_predictor.py ├── avd_network.py ├── inpainting_network.py ├── model.py ├── dense_motion.py └── util.py ├── dataset ├── Mixed_dataset_test.csv ├── copy_.py └── crop.py ├── README.md ├── reconstruction.py ├── run.py ├── predict.py ├── animate.py ├── train.py ├── config └── Mixed_data-10-8-wMaskWarp-aug.yaml ├── frames_dataset.py ├── logger.py ├── demo.py └── augmentation.py /requirements.txt: -------------------------------------------------------------------------------- 1 | cffi==1.14.6 2 | cycler==0.10.0 3 | decorator==5.1.0 4 | face-alignment==1.3.5 5 | imageio==2.9.0 6 | imageio-ffmpeg==0.4.5 7 | kiwisolver==1.3.2 8 | matplotlib==3.4.3 9 | networkx==2.6.3 10 | numpy==1.20.3 11 | pandas==1.3.3 12 | Pillow==8.3.2 13 | pycparser==2.20 14 | pyparsing==2.4.7 15 | python-dateutil==2.8.2 16 | pytz==2021.1 17 | PyWavelets==1.1.1 18 | PyYAML==5.4.1 19 | scikit-image==0.18.3 20 | scikit-learn==1.0 21 | scipy==1.7.1 22 | six==1.16.0 23 | torch==1.10.0+cu113 24 | torchvision==0.11.0+cu113 25 | tqdm==4.62.3 -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | ### ---------------------------training------------------------------- 2 | # mask warp 3 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py --config config/Mixed_data-10-8-wMaskWarp-aug.yaml 4 | 5 | ### ---------------------------generation------------------------------- 6 | ## mask warp 7 | # relative nobest 8 | CUDA_VISIBLE_DEVICES=0 python demo.py --config config/Mixed_data-10-8-wMaskWarp-aug.yaml \ 9 | --checkpoint 'path to the checkpoints' \ 10 | --result_video './ckpt/relative-nobest/wMaskWarp' \ 11 | --mode 'relative' 12 | 13 | -------------------------------------------------------------------------------- /modules/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision import models 4 | 5 | class KPDetector(nn.Module): 6 | """ 7 | Predict K*N keypoints. 8 | """ 9 | 10 | def __init__(self, num_tps, **kwargs): 11 | super(KPDetector, self).__init__() 12 | self.num_tps = num_tps 13 | self.num_kps = kwargs['num_kps'] 14 | 15 | self.fg_encoder = models.resnet18(pretrained=False) 16 | num_features = self.fg_encoder.fc.in_features 17 | self.fg_encoder.fc = nn.Linear(num_features, num_tps * self.num_kps * 2) ## 直接坐标回归出若干个点 18 | 19 | 20 | def forward(self, image): 21 | 22 | fg_kp = self.fg_encoder(image) 23 | bs, _, = fg_kp.shape 24 | fg_kp = torch.sigmoid(fg_kp) 25 | fg_kp = fg_kp * 2 - 1 26 | out = {'fg_kp': fg_kp.view(bs, self.num_tps * self.num_kps, -1)} ## bs, self.num_tps*5, 2 27 | 28 | return out 29 | -------------------------------------------------------------------------------- /modules/bg_motion_predictor.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision import models 4 | 5 | class BGMotionPredictor(nn.Module): 6 | """ 7 | Module for background estimation, return single transformation, parametrized as 3x3 matrix. 
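The source and driving frames are concatenated channel-wise and passed to a ResNet-18 whose fc head regresses the six affine parameters, initialised to the identity transform.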
The third row is [0 0 1] 8 | """ 9 | 10 | def __init__(self): 11 | super(BGMotionPredictor, self).__init__() 12 | self.bg_encoder = models.resnet18(pretrained=False) 13 | self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 14 | num_features = self.bg_encoder.fc.in_features 15 | self.bg_encoder.fc = nn.Linear(num_features, 6) 16 | self.bg_encoder.fc.weight.data.zero_() 17 | self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) 18 | 19 | def forward(self, source_image, driving_image): 20 | bs = source_image.shape[0] 21 | out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) ## bs,3,3 22 | prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1)) ## 两张图像放在一起求的仿射变换矩阵 23 | out[:, :2, :] = prediction.view(bs, 2, 3) ## homo transformation matrix bs 3 3 24 | return out 25 | -------------------------------------------------------------------------------- /modules/fg_motion_predictor.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision import models 4 | 5 | class FGMotionPredictor(nn.Module): 6 | """ 7 | Module for foreground estimation, return single transformation, parametrized as 3x3 matrix. The third row is [0 0 1] 8 | """ 9 | 10 | def __init__(self): 11 | super(FGMotionPredictor, self).__init__() 12 | self.bg_encoder = models.resnet18(pretrained=False) 13 | self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 14 | num_features = self.bg_encoder.fc.in_features 15 | self.bg_encoder.fc = nn.Linear(num_features, 8) 16 | self.bg_encoder.fc.weight.data.zero_() 17 | self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.float)) 18 | 19 | 20 | def forward(self, source_image, driving_image): 21 | bs = source_image.shape[0] 22 | out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) ## bs,3,3 23 | out = out.view(bs, -1) 24 | prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1)) ## 两张图像放在一起求的仿射变换矩阵 25 | out[:,:8] = prediction 26 | out = out.view(bs, 3, 3) 27 | return out -------------------------------------------------------------------------------- /dataset/Mixed_dataset_test.csv: -------------------------------------------------------------------------------- 1 | distance,source,driving,frame 2 | 0,normalized_westernMale,Surprise_007_7_1,0 3 | 0,normalized_westernMale,Positive_022_3_3,0 4 | 0,normalized_westernMale,Negative_018_3_1,0 5 | 0,normalized_asianFemale,Surprise_007_7_1,0 6 | 0,normalized_asianFemale,Positive_022_3_3,0 7 | 0,normalized_asianFemale,Negative_018_3_1,0 8 | 0,normalized_westernMale,Surprise_EP01_13,0 9 | 0,normalized_westernMale,Positive_EP01_01f,0 10 | 0,normalized_westernMale,Negative_EP19_06f,0 11 | 0,normalized_asianFemale,Surprise_EP01_13,0 12 | 0,normalized_asianFemale,Positive_EP01_01f,0 13 | 0,normalized_asianFemale,Negative_EP19_06f,0 14 | 0,normalized_westernMale,Surprise_s20_sur_01,0 15 | 0,normalized_westernMale,Positive_s3_po_05,0 16 | 0,normalized_westernMale,Negative_s11_ne_02,0 17 | 0,normalized_asianFemale,Surprise_s20_sur_01,0 18 | 0,normalized_asianFemale,Positive_s3_po_05,0 19 | 0,normalized_asianFemale,Negative_s11_ne_02,0 20 | 0,normalized_westernFemale,Surprise_007_7_1,0 21 | 0,normalized_westernFemale,Positive_022_3_3,0 22 | 0,normalized_westernFemale,Negative_018_3_1,0 23 | 0,normalized_asianMale,Surprise_007_7_1,0 
24 | 0,normalized_asianMale,Positive_022_3_3,0 25 | 0,normalized_asianMale,Negative_018_3_1,0 26 | 0,normalized_westernFemale,Surprise_EP01_13,0 27 | 0,normalized_westernFemale,Positive_EP01_01f,0 28 | 0,normalized_westernFemale,Negative_EP19_06f,0 29 | 0,normalized_asianMale,Surprise_EP01_13,0 30 | 0,normalized_asianMale,Positive_EP01_01f,0 31 | 0,normalized_asianMale,Negative_EP19_06f,0 32 | 0,normalized_westernFemale,Surprise_s20_sur_01,0 33 | 0,normalized_westernFemale,Positive_s3_po_05,0 34 | 0,normalized_westernFemale,Negative_s11_ne_02,0 35 | 0,normalized_asianMale,Surprise_s20_sur_01,0 36 | 0,normalized_asianMale,Positive_s3_po_05,0 37 | 0,normalized_asianMale,Negative_s11_ne_02,0 38 | -------------------------------------------------------------------------------- /modules/avd_network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class AVDNetwork(nn.Module): 7 | """ 8 | Animation via Disentanglement network 9 | """ 10 | 11 | def __init__(self, num_tps, num_kps, id_bottle_size=64, pose_bottle_size=64): 12 | super(AVDNetwork, self).__init__() 13 | input_size = num_kps * 2 * num_tps 14 | self.num_tps = num_tps 15 | self.num_kps = num_kps 16 | 17 | self.id_encoder = nn.Sequential( 18 | nn.Linear(input_size, 256), 19 | nn.BatchNorm1d(256), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(256, 512), 22 | nn.BatchNorm1d(512), 23 | nn.ReLU(inplace=True), 24 | nn.Linear(512, 1024), 25 | nn.BatchNorm1d(1024), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(1024, id_bottle_size) 28 | ) 29 | 30 | self.pose_encoder = nn.Sequential( 31 | nn.Linear(input_size, 256), 32 | nn.BatchNorm1d(256), 33 | nn.ReLU(inplace=True), 34 | nn.Linear(256, 512), 35 | nn.BatchNorm1d(512), 36 | nn.ReLU(inplace=True), 37 | nn.Linear(512, 1024), 38 | nn.BatchNorm1d(1024), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(1024, pose_bottle_size) 41 | ) 42 | 43 | self.decoder = nn.Sequential( 44 | nn.Linear(pose_bottle_size + id_bottle_size, 1024), 45 | nn.BatchNorm1d(1024), 46 | nn.ReLU(), 47 | nn.Linear(1024, 512), 48 | nn.BatchNorm1d(512), 49 | nn.ReLU(), 50 | nn.Linear(512, 256), 51 | nn.BatchNorm1d(256), 52 | nn.ReLU(), 53 | nn.Linear(256, input_size) 54 | ) 55 | 56 | def forward(self, kp_source, kp_random): 57 | 58 | bs = kp_source['fg_kp'].shape[0] 59 | 60 | # print(kp_random['fg_kp'].view(bs, -1).size()) 61 | pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1)) 62 | id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1)) 63 | 64 | rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1)) 65 | 66 | rec = {'fg_kp': rec.view(bs, self.num_tps*self.num_kps, -1)} 67 | return rec 68 | -------------------------------------------------------------------------------- /dataset/copy_.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | parent_dir = { 5 | 'CASMEII':"CASMEII/CASME2_RAW_selected_cropped", 6 | 'SAMM':"SAMM/SAMM_cropped", 7 | 'SMIC':"SMIC/SMIC_all_raw/HS_cropped" 8 | } 9 | 10 | if __name__=='__main__': 11 | os.makedirs("./Mixed_dataset/train", exist_ok=True) 12 | for dataset in ["SAMM","SMIC"]: 13 | for ID in os.listdir(os.path.join(parent_dir[dataset])): 14 | for item in os.listdir(os.path.join(parent_dir[dataset],ID)): 15 | src_folder = os.path.join(os.path.join(parent_dir[dataset],ID,item)) 16 | # target_folder = os.path.join("/data/home-ustc/xgc18/competition/MEGC2022/Code/Facial-Prior-Based-FOMM-main/data/Mixed_dataset/train",dataset+'_'+item) 17 | 
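# Flatten every cropped clip into ./Mixed_dataset/train, prefixing the new folder name with its dataset (SAMM_/SMIC_ here; the CASME II loop below also inserts the subject ID).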
target_folder = os.path.join("./Mixed_dataset/train",dataset+'_'+item) 18 | shutil.copytree(src_folder,target_folder) 19 | 20 | for dataset in ["CASMEII"]: 21 | for ID in os.listdir(os.path.join(parent_dir[dataset])): 22 | for item in os.listdir(os.path.join(parent_dir[dataset],ID)): 23 | src_folder = os.path.join(os.path.join(parent_dir[dataset],ID,item)) 24 | # target_folder = os.path.join("/data/home-ustc/xgc18/competition/MEGC2022/Code/Facial-Prior-Based-FOMM-main/data/Mixed_dataset/train",dataset+'_'+ID+'_'+item) 25 | target_folder = os.path.join("./Mixed_dataset/train",dataset+'_'+ID+'_'+item) 26 | shutil.copytree(src_folder,target_folder) 27 | 28 | os.makedirs("./Mixed_dataset/test", exist_ok=True) 29 | for fold in os.listdir(os.path.join("megc2022-synthesis/source_samples_cropped")): 30 | for item in os.listdir(os.path.join("megc2022-synthesis/source_samples_cropped", fold)): 31 | src_fold = os.path.join("megc2022-synthesis/source_samples_cropped", fold, item) 32 | tgt_fold = os.path.join("./Mixed_dataset/test", item) 33 | shutil.copytree(src_fold, tgt_fold) 34 | 35 | name2dir={ 36 | "Template_Female_Asian.jpg": "normalized_asianFemale", 37 | "Template_Female_Europe.jpg": "normalized_westernFemale", 38 | "Template_Male_Asian.jpg": "normalized_asianMale", 39 | "Template_Male_Europe.JPG": "normalized_westernMale" 40 | } 41 | 42 | for name in os.listdir(os.path.join("megc2022-synthesis/target_template_face_cropped")): 43 | os.makedirs(os.path.join("./Mixed_dataset/test",name2dir[name]), exist_ok=True) 44 | src_img = os.path.join("megc2022-synthesis/target_template_face_cropped", name) 45 | tgt_img = os.path.join("./Mixed_dataset/test",name2dir[name], name) 46 | shutil.copy(src_img, tgt_img) 47 | 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Micro Expression Generation with Thin-plate Spline Motion Model and Face Parsing 2 | 3 | ### Installation 4 | 5 | We support ```python3```.(Recommended version is Python 3.9). 6 | To install the dependencies run: 7 | 8 | ```bash 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ### YAML configs 13 | 14 | In our method, all the configurations are contained in the file ```config/Mixed_data-10-8-wMaskWarp-aug.yaml```. 15 | 16 | ## Datasets 17 | 18 | 1. Download three datasets [CASME II](http://fu.psych.ac.cn/CASME/casme2-en.php), [SMIC](https://www.oulu.fi/cmvs/node/41319), [SAMM](http://www2.docm.mmu.ac.uk/STAFF/M.Yap/dataset.php) 19 | 20 | 2. Download the test dataset `megc2022-synthesis` 21 | 22 | 3. Download the `shape_predictor_68_face_landmarks.dat` and put it in the `dataset`folder 23 | 24 | 4. Put the three training set and one test set in the `dataset`folder. The file tree is shown as follows: 25 | 26 | ``` 27 | . 28 | ├── CASMEII 29 | │   ├── CASME2-coding-20190701.xlsx 30 | │   ├── CASME2_RAW_selected 31 | ├── copy_.py 32 | ├── crop.py 33 | ├── megc2022-synthesis 34 | │   ├── source_samples 35 | │   ├── target_template_face 36 | ├── SAMM 37 | │   ├── SAMM 38 | │   ├── SAMM_Micro_FACS_Codes_v2.xlsx 39 | ├── shape_predictor_68_face_landmarks.dat 40 | └── SMIC 41 | ├── SMIC_all_raw 42 | 43 | 44 | ``` 45 | 46 | 5. Run the following code 47 | 48 | ``` 49 | cd dataset 50 | python crop.py 51 | python copy_.py 52 | mv Mixed_dataset_test.csv ./Mixed_dataset 53 | cd .. 54 | ``` 55 | 56 | the root of the preprocessed dataset is `./dataset/Mixed_dataset` 57 | 58 | 6. 
Download the [train_mask.tar.gz](https://drive.google.com/file/d/1nv5auh3hYdQK9OiiUH_8ts7LnF7a0bLW/view?usp=sharing) and unzip it, then put it in the `./dataset/Mixed_dataset/train_mask` 59 | 60 | ## Training 61 | 62 | To train a model on specific dataset run: 63 | 64 | ``` 65 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \ 66 |         --config config/Mixed_data-10-8-wMaskWarp-aug.yaml \ 67 |         --device_ids 0,1,2,3 68 | ``` 69 | 70 | A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder. 71 | 72 | ## Micro expression generation 73 | 74 | ``` 75 | CUDA_VISIBLE_DEVICES=0 python demo.py \ 76 |     --config config/Mixed_data-10-8-wMaskWarp-aug.yaml \ 77 | --checkpoint 'path to the checkpoint' \ 78 | --result_video './ckpt/relative' \ 79 | --mode 'relative' 80 | ``` 81 | Our provided model can be downloaded [here](https://drive.google.com/file/d/1zdN-mPwWANMUnPQCv1Ho41JlsRtl4iqv/view?usp=sharing) 82 | The final results are in the folder `./ckpt/relative` . 83 | 84 | # Acknowledgments 85 | 86 | The main code is based upon [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), [MRAA](https://github.com/snap-research/articulated-animation) and [TPS](https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model) 87 | 88 | Thanks for the excellent works! 89 | -------------------------------------------------------------------------------- /reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from logger import Logger, Visualizer 6 | import numpy as np 7 | import imageio 8 | 9 | 10 | def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset): 11 | png_dir = os.path.join(log_dir, 'reconstruction/png') 12 | log_dir = os.path.join(log_dir, 'reconstruction') 13 | 14 | if checkpoint is not None: 15 | Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector, 16 | bg_predictor=bg_predictor, dense_motion_network=dense_motion_network) 17 | else: 18 | raise AttributeError("Checkpoint should be specified for mode='reconstruction'.") 19 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 20 | 21 | if not os.path.exists(log_dir): 22 | os.makedirs(log_dir) 23 | 24 | if not os.path.exists(png_dir): 25 | os.makedirs(png_dir) 26 | 27 | loss_list = [] 28 | 29 | inpainting_network.eval() 30 | kp_detector.eval() 31 | dense_motion_network.eval() 32 | if bg_predictor: 33 | bg_predictor.eval() 34 | 35 | for it, x in tqdm(enumerate(dataloader)): 36 | with torch.no_grad(): 37 | predictions = [] 38 | visualizations = [] 39 | if torch.cuda.is_available(): 40 | x['video'] = x['video'].cuda() 41 | kp_source = kp_detector(x['video'][:, :, 0]) 42 | for frame_idx in range(x['video'].shape[2]): 43 | source = x['video'][:, :, 0] 44 | driving = x['video'][:, :, frame_idx] 45 | kp_driving = kp_detector(driving) 46 | bg_params = None 47 | if bg_predictor: 48 | bg_params = bg_predictor(source, driving) 49 | 50 | dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving, 51 | kp_source=kp_source, bg_param = bg_params, 52 | dropout_flag = False) 53 | out = inpainting_network(source, dense_motion) 54 | out['kp_source'] = kp_source 55 | out['kp_driving'] = kp_driving 56 | 57 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 
1])[0]) 58 | 59 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source, 60 | driving=driving, out=out) 61 | visualizations.append(visualization) 62 | loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy() 63 | 64 | loss_list.append(loss) 65 | # print(np.mean(loss_list)) 66 | predictions = np.concatenate(predictions, axis=1) 67 | imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8)) 68 | 69 | print("Reconstruction loss: %s" % np.mean(loss_list)) 70 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import os, sys 5 | import yaml 6 | from argparse import ArgumentParser 7 | from time import gmtime, strftime 8 | from shutil import copy 9 | from frames_dataset import FramesDataset 10 | 11 | from modules.inpainting_network import InpaintingNetwork 12 | from modules.keypoint_detector import KPDetector 13 | from modules.bg_motion_predictor import BGMotionPredictor 14 | from modules.fg_motion_predictor import FGMotionPredictor 15 | from modules.dense_motion import DenseMotionNetwork 16 | from modules.avd_network import AVDNetwork 17 | import torch 18 | from train import train 19 | from reconstruction import reconstruction 20 | from animate import animate 21 | import os 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | if sys.version_info[0] < 3: 27 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9") 28 | 29 | parser = ArgumentParser() 30 | parser.add_argument("--config", default="config/vox-256.yaml", help="path to config") 31 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"]) 32 | parser.add_argument("--log_dir", default='log', help="path to log into") 33 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") 34 | parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))), 35 | help="Names of the devices comma separated.") 36 | 37 | opt = parser.parse_args() 38 | with open(opt.config) as f: 39 | config = yaml.load(f, Loader=yaml.FullLoader) 40 | 41 | if opt.checkpoint is not None: 42 | log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) 43 | else: 44 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) 45 | log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) 46 | 47 | inpainting = InpaintingNetwork(**config['model_params']['generator_params'], 48 | **config['model_params']['common_params']) 49 | 50 | if torch.cuda.is_available(): 51 | cuda_device = torch.device('cuda:'+str(opt.device_ids[0])) 52 | inpainting.to(cuda_device) 53 | 54 | kp_detector = KPDetector(**config['model_params']['common_params']) 55 | dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], 56 | **config['model_params']['dense_motion_params']) 57 | 58 | if torch.cuda.is_available(): 59 | kp_detector.to(opt.device_ids[0]) 60 | dense_motion_network.to(opt.device_ids[0]) 61 | 62 | bg_predictor = None 63 | if (config['model_params']['common_params']['bg']): 64 | bg_predictor = BGMotionPredictor() 65 | if torch.cuda.is_available(): 66 | bg_predictor.to(opt.device_ids[0]) 67 | 68 | fg_predictor = None 69 | if (config['model_params']['common_params']['fg']): 70 | fg_predictor = FGMotionPredictor() 71 | if torch.cuda.is_available(): 72 | 
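# The perspective foreground predictor is only built when common_params.fg is enabled (the provided Mixed_data config sets fg: False); like the other modules it is moved to the first id in --device_ids.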
fg_predictor.to(opt.device_ids[0]) 73 | 74 | avd_network = None 75 | if opt.mode == "train_avd": 76 | avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], 77 | num_kps=config['model_params']['common_params']['num_kps'], 78 | **config['model_params']['avd_network_params']) 79 | if torch.cuda.is_available(): 80 | avd_network.to(opt.device_ids[0]) 81 | 82 | dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params']) 83 | 84 | if not os.path.exists(log_dir): 85 | os.makedirs(log_dir) 86 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): 87 | copy(opt.config, log_dir) 88 | 89 | if opt.mode == 'train': 90 | print("Training...") 91 | # train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) 92 | train(config, inpainting, kp_detector, bg_predictor, fg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) 93 | elif opt.mode == 'reconstruction': 94 | print("Reconstruction...") 95 | reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) 96 | 97 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, "stylegan-encoder") 4 | import tempfile 5 | import warnings 6 | import imageio 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import matplotlib.animation as animation 10 | from skimage.transform import resize 11 | from skimage import img_as_ubyte 12 | import torch 13 | import torchvision.transforms as transforms 14 | import dlib 15 | from cog import BasePredictor, Path, Input 16 | 17 | from demo import load_checkpoints 18 | from demo import make_animation 19 | from ffhq_dataset.face_alignment import image_align 20 | from ffhq_dataset.landmarks_detector import LandmarksDetector 21 | 22 | 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 27 | LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat") 28 | 29 | 30 | class Predictor(BasePredictor): 31 | def setup(self): 32 | 33 | self.device = torch.device("cuda:0") 34 | datasets = ["vox", "taichi", "ted", "mgif"] 35 | ( 36 | self.inpainting, 37 | self.kp_detector, 38 | self.dense_motion_network, 39 | self.avd_network, 40 | ) = ({}, {}, {}, {}) 41 | for d in datasets: 42 | ( 43 | self.inpainting[d], 44 | self.kp_detector[d], 45 | self.dense_motion_network[d], 46 | self.avd_network[d], 47 | ) = load_checkpoints( 48 | config_path=f"config/{d}-384.yaml" 49 | if d == "ted" 50 | else f"config/{d}-256.yaml", 51 | checkpoint_path=f"checkpoints/{d}.pth.tar", 52 | device=self.device, 53 | ) 54 | 55 | def predict( 56 | self, 57 | source_image: Path = Input( 58 | description="Input source image.", 59 | ), 60 | driving_video: Path = Input( 61 | description="Choose a micromotion.", 62 | ), 63 | dataset_name: str = Input( 64 | choices=["vox", "taichi", "ted", "mgif"], 65 | default="vox", 66 | description="Choose a dataset.", 67 | ), 68 | ) -> Path: 69 | 70 | predict_mode = "relative" # ['standard', 'relative', 'avd'] 71 | # find_best_frame = False 72 | 73 | pixel = 384 if dataset_name == "ted" else 256 74 | 75 | if dataset_name == "vox": 76 | # first run face alignment 77 | align_image(str(source_image), 'aligned.png') 78 | source_image = imageio.imread('aligned.png') 79 | 
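# For the non-vox datasets the source image is read as-is, without face alignment.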
else: 80 | source_image = imageio.imread(str(source_image)) 81 | reader = imageio.get_reader(str(driving_video)) 82 | fps = reader.get_meta_data()["fps"] 83 | source_image = resize(source_image, (pixel, pixel))[..., :3] 84 | 85 | driving_video = [] 86 | try: 87 | for im in reader: 88 | driving_video.append(im) 89 | except RuntimeError: 90 | pass 91 | reader.close() 92 | 93 | driving_video = [ 94 | resize(frame, (pixel, pixel))[..., :3] for frame in driving_video 95 | ] 96 | 97 | inpainting, kp_detector, dense_motion_network, avd_network = ( 98 | self.inpainting[dataset_name], 99 | self.kp_detector[dataset_name], 100 | self.dense_motion_network[dataset_name], 101 | self.avd_network[dataset_name], 102 | ) 103 | 104 | predictions = make_animation( 105 | source_image, 106 | driving_video, 107 | inpainting, 108 | kp_detector, 109 | dense_motion_network, 110 | avd_network, 111 | device="cuda:0", 112 | mode=predict_mode, 113 | ) 114 | 115 | # save resulting video 116 | out_path = Path(tempfile.mkdtemp()) / "output.mp4" 117 | imageio.mimsave( 118 | str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps 119 | ) 120 | return out_path 121 | 122 | 123 | def align_image(raw_img_path, aligned_face_path): 124 | for i, face_landmarks in enumerate(LANDMARKS_DETECTOR.get_landmarks(raw_img_path), start=1): 125 | image_align(raw_img_path, aligned_face_path, face_landmarks) 126 | -------------------------------------------------------------------------------- /animate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from frames_dataset import PairedDataset 8 | from logger import Logger, Visualizer 9 | import imageio 10 | from scipy.spatial import ConvexHull 11 | import numpy as np 12 | 13 | # from sync_batchnorm import DataParallelWithCallback 14 | from skimage import img_as_ubyte 15 | 16 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, 17 | use_relative_movement=False, use_relative_jacobian=False): 18 | if adapt_movement_scale: 19 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume 20 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume 21 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 22 | else: 23 | adapt_movement_scale = 1 24 | 25 | kp_new = {k: v for k, v in kp_driving.items()} 26 | 27 | if use_relative_movement: 28 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) 29 | kp_value_diff *= adapt_movement_scale 30 | kp_new['value'] = kp_value_diff + kp_source['value'] 31 | 32 | if use_relative_jacobian: 33 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) 34 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) 35 | 36 | return kp_new 37 | 38 | 39 | def animate(config, generator, kp_detector, checkpoint, log_dir, dataset): 40 | log_dir = os.path.join(log_dir, 'animation') 41 | png_dir = os.path.join(log_dir, 'png') 42 | animate_params = config['animate_params'] 43 | 44 | dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs']) 45 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 46 | 47 | if checkpoint is not None: 48 | Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector) 49 | else: 50 | raise AttributeError("Checkpoint should be specified for 
mode='animate'.") 51 | 52 | if not os.path.exists(log_dir): 53 | os.makedirs(log_dir) 54 | 55 | if not os.path.exists(png_dir): 56 | os.makedirs(png_dir) 57 | 58 | if torch.cuda.is_available(): 59 | generator = torch.nn.DataParallel(generator) 60 | kp_detector = torch.nn.DataParallel(kp_detector) 61 | 62 | generator.eval() 63 | kp_detector.eval() 64 | 65 | for it, x in tqdm(enumerate(dataloader)): 66 | with torch.no_grad(): 67 | predictions = [] 68 | visualizations = [] 69 | 70 | driving_video = x['driving_video'] 71 | source_frame = x['source_video'][:, :, 0, :, :] 72 | 73 | kp_source = kp_detector(source_frame, x['source_keypoint']) 74 | 75 | 76 | kp_driving_initial = kp_detector(driving_video[:, :, 0], x['driving_keypoint']) 77 | 78 | for frame_idx in range(driving_video.shape[2]): 79 | driving_frame = driving_video[:, :, frame_idx] 80 | kp_driving = kp_detector(driving_frame, x['driving_keypoint']) 81 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 82 | kp_driving_initial=kp_driving_initial, **animate_params['normalization_params']) 83 | out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm) 84 | 85 | out['kp_driving'] = kp_driving 86 | out['kp_source'] = kp_source 87 | out['kp_norm'] = kp_norm 88 | 89 | del out['sparse_deformed'] 90 | 91 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 92 | 93 | visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame, 94 | driving=driving_frame, out=out) 95 | visualization = visualization 96 | visualizations.append(visualization) 97 | 98 | predictions_ = np.concatenate(predictions, axis=1) 99 | result_name = "-".join([x['driving_name'][0], x['source_name'][0]]) 100 | imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions_).astype(np.uint8)) 101 | 102 | image_name = result_name + animate_params['format'] 103 | # imageio.mimsave(os.path.join(log_dir, image_name), visualizations) 104 | imageio.mimsave(os.path.join(log_dir, image_name), [img_as_ubyte(frame) for frame in predictions], fps = 100) 105 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm, trange 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from logger import Logger 5 | from modules.model import GeneratorFullModel 6 | from torch.optim.lr_scheduler import MultiStepLR 7 | from torch.nn.utils import clip_grad_norm_ 8 | from frames_dataset import DatasetRepeater 9 | import math 10 | 11 | def train(config, inpainting_network, kp_detector, bg_predictor, fg_predictor, dense_motion_network, checkpoint, log_dir, dataset): 12 | train_params = config['train_params'] 13 | optimizer = torch.optim.Adam( 14 | [{'params': list(inpainting_network.parameters()) + 15 | list(dense_motion_network.parameters()) + 16 | list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4) 17 | 18 | optimizer_bg_predictor = None 19 | param_bg_fg = [] 20 | if bg_predictor: 21 | print("bg_predictor created") 22 | param_bg_fg += list(bg_predictor.parameters()) 23 | 24 | if fg_predictor: 25 | print("fg_predictor created") 26 | param_bg_fg += list(fg_predictor.parameters()) 27 | 28 | optimizer_bg_predictor = torch.optim.Adam( 29 | [{'params':param_bg_fg,'initial_lr': train_params['lr_generator']}], 30 | lr=train_params['lr_generator'], betas=(0.5, 
0.999), weight_decay = 1e-4) 31 | 32 | if checkpoint is not None: 33 | start_epoch = Logger.load_cpk( 34 | checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network, 35 | kp_detector = kp_detector, bg_predictor = bg_predictor, fg_predictor = fg_predictor, 36 | optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor) 37 | print('load success:', start_epoch) 38 | start_epoch += 1 39 | else: 40 | start_epoch = 0 41 | 42 | scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1, 43 | last_epoch=start_epoch - 1) 44 | if bg_predictor or fg_predictor: 45 | scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'], 46 | gamma=0.1, last_epoch=start_epoch - 1) 47 | 48 | if 'num_repeats' in train_params or train_params['num_repeats'] != 1: 49 | dataset = DatasetRepeater(dataset, train_params['num_repeats']) 50 | print("length of the dataset is {}".format(len(dataset))) 51 | dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, 52 | num_workers=train_params['dataloader_workers'], drop_last=True) 53 | 54 | generator_full = GeneratorFullModel(kp_detector, bg_predictor, fg_predictor, dense_motion_network, inpainting_network, train_params) 55 | 56 | if torch.cuda.is_available(): 57 | generator_full = torch.nn.DataParallel(generator_full).cuda() 58 | 59 | bg_start = train_params['bg_start'] 60 | fg_start = train_params['fg_start'] 61 | 62 | with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], 63 | checkpoint_freq=train_params['checkpoint_freq']) as logger: 64 | for epoch in trange(start_epoch, train_params['num_epochs']): 65 | for _, x in tqdm(enumerate(dataloader)): 66 | if(torch.cuda.is_available()): 67 | x['driving'] = x['driving'].cuda() 68 | x['source'] = x['source'].cuda() 69 | x['source_mask'] = x['source_mask'].cuda() 70 | 71 | losses_generator, generated = generator_full(x, epoch) 72 | 73 | loss_values = [val.mean() for val in losses_generator.values()] 74 | loss = sum(loss_values) 75 | loss.backward() 76 | 77 | clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf) 78 | clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf) 79 | if bg_predictor and epoch>=bg_start: 80 | clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf) 81 | if fg_predictor and epoch>=fg_start: 82 | clip_grad_norm_(fg_predictor.parameters(), max_norm=10, norm_type = math.inf) 83 | 84 | optimizer.step() 85 | optimizer.zero_grad() 86 | if (bg_predictor and epoch>=bg_start) or (fg_predictor and epoch>=fg_start): 87 | optimizer_bg_predictor.step() 88 | optimizer_bg_predictor.zero_grad() 89 | 90 | losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} 91 | logger.log_iter(losses=losses) 92 | 93 | scheduler_optimizer.step() 94 | if bg_predictor or fg_predictor: 95 | scheduler_bg_predictor.step() 96 | 97 | model_save = { 98 | 'inpainting_network': inpainting_network, 99 | 'dense_motion_network': dense_motion_network, 100 | 'kp_detector': kp_detector, 101 | 'optimizer': optimizer, 102 | } 103 | if bg_predictor and epoch>=bg_start: 104 | model_save['bg_predictor'] = bg_predictor 105 | model_save['optimizer_bg_predictor'] = optimizer_bg_predictor 106 | 107 | if fg_predictor and epoch>=bg_start: 108 | model_save['fg_predictor'] = fg_predictor 109 | # print(model_save.keys(), x.size(), generated.size()) 110 | # print(generated.keys()) 111 | 
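# log_epoch hands model_save to the Logger, which checkpoints at the frequency given by train_params['checkpoint_freq'] and renders visualizations from inp/out.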
logger.log_epoch(epoch, model_save, inp=x, out=generated) 112 | 113 | -------------------------------------------------------------------------------- /config/Mixed_data-10-8-wMaskWarp-aug.yaml: -------------------------------------------------------------------------------- 1 | # Dataset parameters 2 | # Each dataset should contain 2 folders train and test 3 | # Each video can be represented as: 4 | # - an image of concatenated frames 5 | # - '.mp4' or '.gif' 6 | # - folder with all frames from a specific video 7 | # In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following 8 | # format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube 9 | # video id. 10 | dataset_params: 11 | # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. 12 | root_dir: ./dataset/Mixed_dataset 13 | # Image shape, needed for staked .png format. 14 | frame_shape: null 15 | # In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person. 16 | # In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False) 17 | # If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335 18 | # id_sampling: True 19 | id_sampling: False 20 | # Augmentation parameters see augmentation.py for all posible augmentations 21 | augmentation_params: 22 | group1: 23 | flip_param: 24 | horizontal_flip: True 25 | time_flip: True 26 | 27 | perspective_param: 28 | distortion_scale: 0.1 29 | p: 0.5 30 | 31 | group2: 32 | jitter_param: 33 | brightness: 0.1 34 | contrast: 0.1 35 | saturation: 0.1 36 | hue: 0.1 37 | 38 | # Defines model architecture 39 | model_params: 40 | common_params: 41 | # Number of TPS transformation 42 | num_tps: 10 43 | # Number of key points 44 | num_kps: 8 45 | # Number of channels per image 46 | num_channels: 3 47 | # Whether to estimate affine background transformation 48 | bg: True 49 | # Whether to estimate perspective foreground transformation 50 | fg: False 51 | # Whether to estimate the multi-resolution occlusion masks 52 | multi_mask: True 53 | generator_params: 54 | # Number of features mutliplier 55 | block_expansion: 64 56 | # Maximum allowed number of features 57 | max_features: 512 58 | # Number of downsampling blocks and Upsampling blocks. 59 | num_down_blocks: 3 60 | dense_motion_params: 61 | # Number of features mutliplier 62 | block_expansion: 64 63 | # Maximum allowed number of features 64 | max_features: 1024 65 | # Number of block in Unet. 66 | num_blocks: 5 67 | # Optical flow is predicted on smaller images for better performance, 68 | # scale_factor=0.25 means that 256x256 image will be resized to 64x64 69 | scale_factor: 0.25 70 | avd_network_params: 71 | # Bottleneck for identity branch 72 | id_bottle_size: 128 73 | # Bottleneck for pose branch 74 | pose_bottle_size: 128 75 | 76 | # Parameters of training 77 | train_params: 78 | # Number of training epochs 79 | num_epochs: 100 80 | # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. 81 | # Thus effectivlly with num_repeats=100 each epoch is 100 times larger. 
82 | num_repeats: 50 83 | # Drop learning rate by 10 times after this epochs 84 | epoch_milestones: [50, 90] 85 | # Initial learing rate for all modules 86 | lr_generator: 2.0e-4 87 | # batch_size: 28 88 | batch_size: 32 89 | # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, 90 | # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. 91 | scales: [1, 0.5, 0.25, 0.125] 92 | # Dataset preprocessing cpu workers 93 | dataloader_workers: 12 94 | # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs. 95 | checkpoint_freq: 5 96 | # Parameters of dropout 97 | # The first dropout_epoch training uses dropout operation 98 | dropout_epoch: 15 99 | # The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs 100 | dropout_maxp: 0.7 101 | dropout_startp: 0.0 102 | dropout_inc_epoch: 10 103 | # Estimate affine background transformation from the bg_start epoch. 104 | bg_start: 0 105 | # Estimate prespective background transformation from the fg_start epoch. 106 | fg_start: 0 107 | # Parameters of random TPS transformation for equivariance loss 108 | transform_params: 109 | # Sigma for affine part 110 | sigma_affine: 0.05 111 | # Sigma for deformation part 112 | sigma_tps: 0.005 113 | # Number of point in the deformation grid 114 | # points_tps: 5 115 | points_tps: 8 116 | loss_weights: 117 | # Weights for perceptual loss. 118 | perceptual: [10, 10, 10, 10, 10] 119 | # Weights for value equivariance. 120 | equivariance_value: 10 121 | # Weights for warp loss. 122 | warp_loss: 10 123 | # Weights for bg loss. 124 | bg: 10 125 | # Weights for warp loss. 126 | fg_warp_loss: 50 127 | # Weights for fg loss. 128 | fg: 0 129 | 130 | # Parameters of training (animation-via-disentanglement) 131 | train_avd_params: 132 | # Number of training epochs, visualization is produced after each epoch. 133 | num_epochs: 100 134 | # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. 135 | # Thus effectively with num_repeats=100 each epoch is 100 times larger. 136 | # num_repeats: 150 137 | num_repeats: 50 138 | # Batch size. 139 | batch_size: 256 140 | # Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs. 141 | checkpoint_freq: 10 142 | # Dataset preprocessing cpu workers 143 | dataloader_workers: 24 144 | # Drop learning rate 10 times after this epochs 145 | epoch_milestones: [70, 90] 146 | # Initial learning rate 147 | lr: 1.0e-3 148 | # Weights for equivariance loss. 149 | lambda_shift: 1 150 | random_scale: 0.25 151 | 152 | visualizer_params: 153 | kp_size: 5 154 | draw_border: True 155 | colormap: 'gist_rainbow' 156 | -------------------------------------------------------------------------------- /modules/inpainting_network.py: -------------------------------------------------------------------------------- 1 | from numpy import source 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 6 | # from modules.dense_motion import DenseMotionNetwork 7 | 8 | class InpaintingNetwork(nn.Module): 9 | """ 10 | Inpaint the missing regions and reconstruct the Driving image. 
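A shared encoder extracts multi-scale features from the source image; each scale is warped with the dense-motion deformation, masked by the matching occlusion map, and decoded through residual and upsampling blocks into the final prediction.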
11 | """ 12 | def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs): 13 | super(InpaintingNetwork, self).__init__() 14 | 15 | self.num_down_blocks = num_down_blocks 16 | self.multi_mask = multi_mask 17 | self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 18 | 19 | down_blocks = [] 20 | for i in range(num_down_blocks): 21 | in_features = min(max_features, block_expansion * (2 ** i)) 22 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 23 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 24 | self.down_blocks = nn.ModuleList(down_blocks) 25 | 26 | up_blocks = [] 27 | in_features = [max_features, max_features, max_features//2] 28 | out_features = [max_features//2, max_features//4, max_features//8] 29 | for i in range(num_down_blocks): 30 | up_blocks.append(UpBlock2d(in_features[i], out_features[i], kernel_size=(3, 3), padding=(1, 1))) 31 | self.up_blocks = nn.ModuleList(up_blocks) 32 | 33 | resblock = [] 34 | for i in range(num_down_blocks): 35 | resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1))) 36 | resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1))) 37 | self.resblock = nn.ModuleList(resblock) 38 | 39 | self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) 40 | self.num_channels = num_channels 41 | 42 | def deform_input(self, inp, deformation): 43 | _, h_old, w_old, _ = deformation.shape 44 | _, _, h, w = inp.shape 45 | if h_old != h or w_old != w: 46 | deformation = deformation.permute(0, 3, 1, 2) 47 | deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True) 48 | deformation = deformation.permute(0, 2, 3, 1) 49 | return F.grid_sample(inp, deformation, align_corners=True) 50 | 51 | def occlude_input(self, inp, occlusion_map): 52 | if not self.multi_mask: 53 | if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]: 54 | occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True) 55 | out = inp * occlusion_map 56 | return out 57 | 58 | def forward(self, source_image, source_mask, dense_motion): 59 | 60 | # Shared Encoder source image features 61 | out = self.first(source_image) 62 | encoder_map = [out] 63 | for i in range(len(self.down_blocks)): 64 | out = self.down_blocks[i](out) 65 | encoder_map.append(out) 66 | 67 | # masked image 68 | masked_image = source_image 69 | 70 | # Shared Encoder masked image features 71 | out_masked = self.first(masked_image) 72 | encoder_map_masked = [out_masked] 73 | for i in range(len(self.down_blocks)): 74 | out_masked = self.down_blocks[i](out_masked) 75 | encoder_map_masked.append(out_masked) 76 | 77 | output_dict = {} 78 | output_dict['contribution_maps'] = dense_motion['contribution_maps'] 79 | output_dict['deformed_source'] = dense_motion['deformed_source'] 80 | 81 | # occlusion_bg 82 | occlusion_map = dense_motion['occlusion_map'] 83 | output_dict['occlusion_map'] = occlusion_map 84 | 85 | attention_map = dense_motion['attention_map'] 86 | output_dict['attention_map'] = attention_map 87 | 88 | deformation = dense_motion['deformation'] 89 | out_masked_ij = self.deform_input(out_masked.detach(), deformation) # 这一步是为了记录, deformation 无法参与此处的梯度 90 | out_masked = self.deform_input(out_masked, deformation) 91 | 92 | out_masked_ij = self.occlude_input(out_masked_ij, occlusion_map[0].detach()) 93 | 
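# out_masked_ij (detached features, detached occlusion map) is kept for warped_encoder_maps below, which only feed the warp loss in modules/model.py.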
out_masked = self.occlude_input(out_masked, occlusion_map[0]) 94 | 95 | warped_encoder_maps = [] 96 | warped_encoder_maps.append(out_masked_ij) 97 | 98 | for i in range(self.num_down_blocks): 99 | 100 | out_masked = self.resblock[2*i](out_masked) 101 | out_masked = self.resblock[2*i+1](out_masked) 102 | out_masked = self.up_blocks[i](out_masked) 103 | 104 | encode_masked_i = encoder_map_masked[-(i+2)] 105 | encode_masked_ij = self.deform_input(encode_masked_i.detach(), deformation) 106 | encode_masked_i = self.deform_input(encode_masked_i, deformation) 107 | 108 | occlusion_ind = 0 109 | if self.multi_mask: 110 | occlusion_ind = i+1 111 | encode_masked_ij = self.occlude_input(encode_masked_ij, occlusion_map[occlusion_ind].detach()) 112 | encode_masked_i = self.occlude_input(encode_masked_i, occlusion_map[occlusion_ind]) 113 | 114 | warped_encoder_maps.append(encode_masked_ij) 115 | 116 | if(i==self.num_down_blocks-1): 117 | break 118 | 119 | out_masked = torch.cat([out_masked, encode_masked_i], 1) 120 | 121 | deformed_source = self.deform_input(source_image, deformation) 122 | output_dict["deformed"] = deformed_source 123 | output_dict["warped_encoder_maps"] = warped_encoder_maps 124 | 125 | 126 | occlusion_last = occlusion_map[-1] 127 | if not self.multi_mask: 128 | occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear', align_corners=True) 129 | 130 | out_masked = out_masked * (1 - occlusion_last) + encode_masked_i 131 | out_masked = self.final(out_masked) 132 | out_masked = torch.sigmoid(out_masked) 133 | 134 | out_masked = out_masked * (1 - occlusion_last) + deformed_source * occlusion_last 135 | output_dict["prediction"] = out_masked 136 | 137 | return output_dict 138 | 139 | def get_encode(self, driver_image, occlusion_map): 140 | out = self.first(driver_image) 141 | encoder_map = [] 142 | encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach())) 143 | for i in range(len(self.down_blocks)): 144 | out = self.down_blocks[i](out.detach()) 145 | out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach()) 146 | encoder_map.append(out_mask.detach()) 147 | 148 | return encoder_map 149 | 150 | -------------------------------------------------------------------------------- /dataset/crop.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | from xml.dom import NotFoundErr 4 | import dlib 5 | import numpy as np 6 | import cv2 7 | import os 8 | 9 | def rect_to_bb(rect): # 获得人脸矩形的坐标信息 10 | x = rect.left() 11 | y = rect.top() 12 | w = rect.right() - x 13 | h = rect.bottom() - y 14 | return (x, y, w, h) 15 | 16 | def shape_to_np(shape, dtype="int"): # 将包含68个特征的的shape转换为numpy array格式 17 | coords = np.zeros((68, 2), dtype=dtype) 18 | for i in range(0, 68): 19 | coords[i] = (shape.part(i).x, shape.part(i).y) 20 | return coords 21 | 22 | 23 | def resize(image, width=1200): # 将待检测的image进行resize 24 | r = width * 1.0 / image.shape[1] 25 | dim = (width, int(image.shape[0] * r)) 26 | resized = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) 27 | return resized 28 | 29 | def feature(image_file): 30 | # image_file = "test.jpg" 31 | # print(image_file) 32 | detector = dlib.get_frontal_face_detector() 33 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 34 | image = cv2.imread(image_file) 35 | h,w,c = image.shape 36 | # print(h,w,c) 37 | # image = resize(image, width=1200) 38 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 39 | rects = detector(gray, 1) 40 
| shapes = [] 41 | 42 | rect = dlib.rectangle(max(rects[0].left(), 0), max(0, rects[0].top()), min(w, rects[0].right()), min(h, rects[0].bottom())) 43 | # rect = dlib.rectangle(0, 0, w, h) 44 | rects = dlib.rectangles() 45 | rects.append(rect) 46 | 47 | for (i, rect) in enumerate(rects): 48 | shape = predictor(gray, rect) 49 | shape = shape_to_np(shape) 50 | # print(shape[19],shape[24],shape[27]) 51 | shapes.append(shape) 52 | 53 | left = max(0, (shape[0][0]+shape[17][0])//2) 54 | right = min(w-1, (shape[26][0]+shape[16][0])//2) 55 | up = int(max(0, min(shape[19][1], shape[24][1])-(shape[27][1] - max(shape[19][1], shape[24][1]))*0.7)) 56 | down = min(h-1, shape[8][1]) 57 | 58 | # print(left, right, up, down) 59 | 60 | return left, right, up, down 61 | 62 | 63 | if __name__ == "__main__": 64 | for fold_name in os.listdir("CASMEII/CASME2_RAW_selected"): 65 | for clip_name in os.listdir(os.path.join("CASMEII/CASME2_RAW_selected",fold_name)): 66 | clip_rel = os.path.join("CASMEII/CASME2_RAW_selected", fold_name, clip_name) 67 | print(clip_rel) 68 | names = os.listdir(clip_rel) 69 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 70 | names = list(filter(f, names)) 71 | names.sort() 72 | left, right, up, down = feature(os.path.join(clip_rel, names[0])) 73 | 74 | os.makedirs(os.path.join('CASMEII/CASME2_RAW_selected_cropped', fold_name, clip_name), exist_ok=True) 75 | for name in names: 76 | image = cv2.imread(os.path.join(clip_rel, name)) 77 | image = image[up:down+1,left:right+1,:] 78 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 79 | cv2.imwrite(os.path.join('CASMEII/CASME2_RAW_selected_cropped', fold_name, clip_name, name), image) 80 | 81 | for fold_name in os.listdir("SAMM/SAMM"): 82 | for clip_name in os.listdir(os.path.join("SAMM/SAMM",fold_name)): 83 | clip_rel = os.path.join("SAMM/SAMM", fold_name, clip_name) 84 | print(clip_rel) 85 | names = os.listdir(clip_rel) 86 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 87 | names = list(filter(f, names)) 88 | names.sort() 89 | left, right, up, down = feature(os.path.join(clip_rel, names[0])) 90 | 91 | os.makedirs(os.path.join('SAMM/SAMM_cropped', fold_name, clip_name), exist_ok=True) 92 | for name in names: 93 | image = cv2.imread(os.path.join(clip_rel, name)) 94 | image = image[up:down+1,left:right+1,:] 95 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 96 | cv2.imwrite(os.path.join('SAMM/SAMM_cropped', fold_name, clip_name, name), image) 97 | 98 | for fold_name in os.listdir("SMIC/SMIC_all_raw/HS"): 99 | for attr_name in os.listdir(os.path.join("SMIC/SMIC_all_raw/HS",fold_name, 'micro')): 100 | for clip_name in os.listdir(os.path.join("SMIC/SMIC_all_raw/HS", fold_name, 'micro', attr_name)): 101 | clip_rel = os.path.join("SMIC/SMIC_all_raw/HS", fold_name, 'micro', attr_name, clip_name) 102 | print(clip_rel) 103 | names = os.listdir(clip_rel) 104 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 105 | names = list(filter(f, names)) 106 | names.sort() 107 | left, right, up, down = feature(os.path.join(clip_rel, names[0])) 108 | 109 | os.makedirs(os.path.join('SMIC/SMIC_all_raw/HS_cropped', fold_name, clip_name), exist_ok=True) 110 | for name in names: 111 | image = cv2.imread(os.path.join(clip_rel, name)) 112 | image = image[up:down+1,left:right+1,:] 113 | # print(image.shape) 114 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 115 | cv2.imwrite(os.path.join('SMIC/SMIC_all_raw/HS_cropped', fold_name, clip_name, name), image) 116 | 117 | 
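# The MEGC2022 source samples are cropped with the same landmark-derived box; the SMIC challenge clips are only resized, and each template face is cropped individually at the end.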
# for fold_name in os.listdir("megc2022-synthesis/source_samples"): 118 | for fold_name in ['SAMM_challenge','casme2_challenge',]: 119 | for clip_name in os.listdir(os.path.join("megc2022-synthesis/source_samples",fold_name)): 120 | clip_rel = os.path.join("megc2022-synthesis/source_samples", fold_name, clip_name) 121 | print(clip_rel) 122 | names = os.listdir(clip_rel) 123 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 124 | names = list(filter(f, names)) 125 | names.sort() 126 | left, right, up, down = feature(os.path.join(clip_rel, names[0])) 127 | 128 | os.makedirs(os.path.join('megc2022-synthesis/source_samples_cropped', fold_name, clip_name), exist_ok=True) 129 | for name in names: 130 | image = cv2.imread(os.path.join(clip_rel, name)) 131 | image = image[up:down+1,left:right+1,:] 132 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 133 | 134 | cv2.imwrite(os.path.join('megc2022-synthesis/source_samples_cropped', fold_name, clip_name, name), image) 135 | 136 | 137 | for fold_name in ['Smic_challenge']: 138 | for clip_name in os.listdir(os.path.join("megc2022-synthesis/source_samples",fold_name)): 139 | clip_rel = os.path.join("megc2022-synthesis/source_samples", fold_name, clip_name) 140 | print(clip_rel) 141 | names = os.listdir(clip_rel) 142 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 143 | names = list(filter(f, names)) 144 | names.sort() 145 | 146 | os.makedirs(os.path.join('megc2022-synthesis/source_samples_cropped', fold_name, clip_name), exist_ok=True) 147 | for name in names: 148 | image = cv2.imread(os.path.join(clip_rel, name)) 149 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 150 | cv2.imwrite(os.path.join('megc2022-synthesis/source_samples_cropped', fold_name, clip_name, name), image) 151 | 152 | 153 | clip_rel = "megc2022-synthesis/target_template_face" 154 | names = os.listdir(clip_rel) 155 | f = lambda x:x.split('.')[-1].lower() in ['jpg','bmp','png'] 156 | names = list(filter(f, names)) 157 | 158 | os.makedirs(os.path.join('megc2022-synthesis/target_template_face_cropped'), exist_ok=True) 159 | for name in names: 160 | left, right, up, down = feature(os.path.join(clip_rel, name)) 161 | image = cv2.imread(os.path.join(clip_rel, name)) 162 | image = image[up:down+1,left:right+1,:] 163 | image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) 164 | 165 | cv2.imwrite(os.path.join('megc2022-synthesis/target_template_face_cropped', name), image) 166 | 167 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from modules.util import AntiAliasInterpolation2d, TPS 5 | from torchvision import models 6 | import numpy as np 7 | 8 | 9 | class Vgg19(torch.nn.Module): 10 | """ 11 | Vgg19 network for perceptual loss. See Sec 3.3. 
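Activations from five slices of a pretrained VGG-19 are compared with an L1 distance at each scale of the image pyramid.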
12 | """ 13 | def __init__(self, requires_grad=False): 14 | super(Vgg19, self).__init__() 15 | vgg_pretrained_features = models.vgg19(pretrained=True).features 16 | self.slice1 = torch.nn.Sequential() 17 | self.slice2 = torch.nn.Sequential() 18 | self.slice3 = torch.nn.Sequential() 19 | self.slice4 = torch.nn.Sequential() 20 | self.slice5 = torch.nn.Sequential() 21 | for x in range(2): 22 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(2, 7): 24 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(7, 12): 26 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(12, 21): 28 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 29 | for x in range(21, 30): 30 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 31 | 32 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 33 | requires_grad=False) 34 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), 35 | requires_grad=False) 36 | 37 | if not requires_grad: 38 | for param in self.parameters(): 39 | param.requires_grad = False 40 | 41 | def forward(self, X): 42 | X = (X - self.mean) / self.std 43 | h_relu1 = self.slice1(X) 44 | h_relu2 = self.slice2(h_relu1) 45 | h_relu3 = self.slice3(h_relu2) 46 | h_relu4 = self.slice4(h_relu3) 47 | h_relu5 = self.slice5(h_relu4) 48 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 49 | return out 50 | 51 | 52 | class ImagePyramide(torch.nn.Module): 53 | """ 54 | Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 55 | """ 56 | def __init__(self, scales, num_channels): 57 | super(ImagePyramide, self).__init__() 58 | downs = {} 59 | for scale in scales: 60 | downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) 61 | self.downs = nn.ModuleDict(downs) 62 | 63 | def forward(self, x): 64 | out_dict = {} 65 | for scale, down_module in self.downs.items(): 66 | out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) 67 | return out_dict 68 | 69 | 70 | def detach_kp(kp): 71 | return {key: value.detach() for key, value in kp.items()} 72 | 73 | 74 | class GeneratorFullModel(torch.nn.Module): 75 | """ 76 | Merge all generator related updates into single model for better multi-gpu usage 77 | """ 78 | 79 | def __init__(self, kp_extractor, bg_predictor, fg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs): 80 | super(GeneratorFullModel, self).__init__() 81 | self.kp_extractor = kp_extractor 82 | self.inpainting_network = inpainting_network 83 | self.dense_motion_network = dense_motion_network 84 | 85 | self.bg_predictor = None 86 | if bg_predictor: 87 | self.bg_predictor = bg_predictor 88 | self.bg_start = train_params['bg_start'] 89 | 90 | self.fg_predictor = None 91 | if fg_predictor: 92 | self.fg_predictor = fg_predictor 93 | self.fg_start = train_params['fg_start'] 94 | 95 | self.train_params = train_params 96 | self.scales = train_params['scales'] 97 | 98 | self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels) 99 | if torch.cuda.is_available(): 100 | self.pyramid = self.pyramid.cuda() 101 | 102 | self.loss_weights = train_params['loss_weights'] 103 | self.dropout_epoch = train_params['dropout_epoch'] 104 | self.dropout_maxp = train_params['dropout_maxp'] 105 | self.dropout_inc_epoch = train_params['dropout_inc_epoch'] 106 | self.dropout_startp =train_params['dropout_startp'] 107 | 108 | if 
sum(self.loss_weights['perceptual']) != 0: 109 | self.vgg = Vgg19() 110 | if torch.cuda.is_available(): 111 | self.vgg = self.vgg.cuda() 112 | 113 | 114 | def forward(self, x, epoch): 115 | kp_source = self.kp_extractor(x['source']) # bs KN 2 116 | kp_driving = self.kp_extractor(x['driving']) # bs KN 2 117 | bg_param = None 118 | fg_param = None 119 | if self.bg_predictor: 120 | if(epoch>=self.bg_start): 121 | bg_param = self.bg_predictor(x['source'], x['driving']) # affine matrix bs 3 3 122 | 123 | if self.fg_predictor: 124 | if(epoch>=self.fg_start): 125 | fg_param = self.fg_predictor(x['source'], x['driving']) # perspective matrix bs 3 3 126 | 127 | if(epoch>=self.dropout_epoch): 128 | dropout_flag = False 129 | dropout_p = 0 130 | else: 131 | # dropout_p will linearly increase from dropout_startp to dropout_maxp 132 | dropout_flag = True 133 | dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp) 134 | 135 | dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving, 136 | kp_source=kp_source, bg_param = bg_param, fg_param = fg_param, 137 | dropout_flag = dropout_flag, dropout_p = dropout_p) 138 | 139 | generated = self.inpainting_network(x['source'], x['source_mask'], dense_motion) 140 | generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) 141 | 142 | loss_values = {} 143 | 144 | pyramide_real = self.pyramid(x['driving']) 145 | pyramide_generated = self.pyramid(generated['prediction']) 146 | 147 | # reconstruction loss 148 | if sum(self.loss_weights['perceptual']) != 0: 149 | value_total = 0 150 | for scale in self.scales: 151 | x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) 152 | y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) 153 | 154 | for i, weight in enumerate(self.loss_weights['perceptual']): 155 | value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() 156 | value_total += self.loss_weights['perceptual'][i] * value 157 | loss_values['perceptual'] = value_total 158 | 159 | # equivariance loss 160 | if self.loss_weights['equivariance_value'] != 0: 161 | transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params']) 162 | transform_grid = transform_random.transform_frame(x['driving']) 163 | transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True) 164 | transformed_kp = self.kp_extractor(transformed_frame) 165 | 166 | generated['transformed_frame'] = transformed_frame 167 | generated['transformed_kp'] = transformed_kp 168 | 169 | warped = transform_random.warp_coordinates(transformed_kp['fg_kp']) 170 | kp_d = kp_driving['fg_kp'] 171 | value = torch.abs(kp_d - warped).mean() 172 | loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value 173 | 174 | # warp loss 175 | if self.loss_weights['warp_loss'] != 0: 176 | occlusion_map = generated['occlusion_map'] 177 | encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map) 178 | decode_map = generated['warped_encoder_maps'] 179 | value = 0 180 | for i in range(len(encode_map)): 181 | value += torch.abs(encode_map[i]-decode_map[-i-1]).mean() 182 | 183 | loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value 184 | 185 | # bg loss 186 | if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0: 187 | bg_param_reverse = self.bg_predictor(x['driving'], x['source']) 188 | value = torch.matmul(bg_param, bg_param_reverse) 189 | eye = 
torch.eye(3).view(1, 1, 3, 3).type(value.type()) 190 | value = torch.abs(eye - value).mean() 191 | loss_values['bg'] = self.loss_weights['bg'] * value 192 | 193 | # fg warp loss 194 | if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['fg_warp_loss'] != 0: 195 | loss_values['fg_warp_loss'] = \ 196 | torch.abs(x['driving_mask']*generated['prediction']-x['driving_mask']*x['driving']).mean() 197 | 198 | # fg loss 199 | if self.fg_predictor and epoch >= self.fg_start and self.loss_weights['fg'] != 0: 200 | fg_param_reverse = self.fg_predictor(x['driving'], x['source']) 201 | value_fg = torch.matmul(fg_param, fg_param_reverse) 202 | eye = torch.eye(3).view(1, 1, 3, 3).type(value_fg.type()) 203 | value_fg = torch.abs(eye - value_fg).mean() 204 | loss_values['fg'] = self.loss_weights['fg'] * value_fg 205 | # print(value_fg) 206 | 207 | return loss_values, generated 208 | -------------------------------------------------------------------------------- /frames_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io, img_as_float32 3 | from skimage.color import gray2rgb 4 | from sklearn.model_selection import train_test_split 5 | from imageio import mimread 6 | from skimage.transform import resize 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | from augmentation import AllAugmentationTransform 10 | import glob 11 | from functools import partial 12 | import pandas as pd 13 | 14 | 15 | def read_video(name, frame_shape): 16 | """ 17 | Read video which can be: 18 | - an image of concatenated frames 19 | - '.mp4' and'.gif' 20 | - folder with videos 21 | """ 22 | 23 | if os.path.isdir(name): 24 | frames = sorted(os.listdir(name)) 25 | num_frames = len(frames) 26 | video_array = np.array( 27 | [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) 28 | elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): 29 | image = io.imread(name) 30 | 31 | if len(image.shape) == 2 or image.shape[2] == 1: 32 | image = gray2rgb(image) 33 | 34 | if image.shape[2] == 4: 35 | image = image[..., :3] 36 | 37 | image = img_as_float32(image) 38 | 39 | video_array = np.moveaxis(image, 1, 0) 40 | 41 | video_array = video_array.reshape((-1,) + frame_shape) 42 | video_array = np.moveaxis(video_array, 1, 2) 43 | elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): 44 | video = mimread(name) 45 | if len(video[0].shape) == 2: 46 | video = [gray2rgb(frame) for frame in video] 47 | if frame_shape is not None: 48 | video = np.array([resize(frame, frame_shape) for frame in video]) 49 | video = np.array(video) 50 | if video.shape[-1] == 4: 51 | video = video[..., :3] 52 | video_array = img_as_float32(video) 53 | else: 54 | raise Exception("Unknown file extensions %s" % name) 55 | 56 | return video_array 57 | 58 | 59 | class FramesDataset(Dataset): 60 | """ 61 | Dataset of videos, each video can be represented as: 62 | - an image of concatenated frames 63 | - '.mp4' or '.gif' 64 | - folder with all frames 65 | """ 66 | 67 | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, 68 | random_seed=0, pairs_list=None, augmentation_params=None): 69 | self.root = root_dir 70 | self.videos = os.listdir(root_dir) 71 | self.frame_shape = frame_shape 72 | print("Frame_shape: {}".format(self.frame_shape)) 73 | self.pairs_list = pairs_list 74 | self.id_sampling = id_sampling 75 | 76 | if 
os.path.exists(os.path.join(root_dir, 'train')): 77 | assert os.path.exists(os.path.join(root_dir, 'test')) 78 | print("Use predefined train-test split.") 79 | if id_sampling: 80 | train_videos = {os.path.basename(video).split('#')[0] for video in 81 | os.listdir(os.path.join(root_dir, 'train'))} 82 | train_videos = list(train_videos) 83 | else: 84 | train_videos = os.listdir(os.path.join(root_dir, 'train')) 85 | test_videos = os.listdir(os.path.join(root_dir, 'test')) 86 | mask_video = os.listdir(os.path.join(root_dir, 'train_mask')) 87 | self.root_dir = os.path.join(root_dir, 'train' if is_train else 'test') 88 | else: 89 | print("Use random train-test split.") 90 | train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) 91 | 92 | if is_train: 93 | self.videos = train_videos 94 | else: 95 | self.videos = test_videos 96 | 97 | self.mask_video = mask_video 98 | self.is_train = is_train 99 | 100 | if self.is_train: 101 | self.transform1 = AllAugmentationTransform(**augmentation_params["group1"]) ## flip & perspective 102 | self.transform2 = AllAugmentationTransform(**augmentation_params["group2"]) ## jitter 103 | else: 104 | self.transform1 = None 105 | self.transform2 = None 106 | 107 | def __len__(self): 108 | return len(self.videos) 109 | 110 | def __getitem__(self, idx): 111 | 112 | if self.is_train and self.id_sampling: 113 | name = self.videos[idx] 114 | path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) 115 | else: 116 | name = self.videos[idx] 117 | path = os.path.join(self.root_dir, name) 118 | 119 | video_name = os.path.basename(path) 120 | if self.is_train and os.path.isdir(path): 121 | 122 | frames = os.listdir(path) 123 | num_frames = len(frames) 124 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) 125 | 126 | if self.frame_shape is not None: 127 | resize_fn = partial(resize, output_shape=self.frame_shape) 128 | else: 129 | resize_fn = img_as_float32 130 | 131 | if type(frames[0]) is bytes: 132 | video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in frame_idx] 133 | 134 | if video_name in self.mask_video: 135 | mask_array = [resize_fn(io.imread(os.path.join(self.root, 'train_mask', video_name, frames[idx].decode('utf-8').split('.')[0]+'.png'))[:,:,None]) for idx in frame_idx] 136 | else: 137 | mask_array = [np.zeros_like(img)[:,:,:1] for img in video_array] 138 | 139 | else: 140 | video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] 141 | 142 | if video_name in self.mask_video: 143 | mask_array = [resize_fn(io.imread(os.path.join(self.root, 'train_mask', video_name, frames[idx].split('.')[0]+'.png'))[:,:,None]) for idx in frame_idx] 144 | else: 145 | mask_array = [np.zeros_like(img)[:,:,:1] for img in video_array] 146 | # print("##",video_array[0].shape, mask_array[0].shape) 147 | else: 148 | 149 | video_array = read_video(path, frame_shape=self.frame_shape) 150 | 151 | num_frames = len(video_array) 152 | frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( 153 | num_frames) 154 | video_array = video_array[frame_idx] 155 | 156 | if self.transform1 is not None: 157 | video_mask_cat = [np.concatenate((video_array[i], mask_array[i]), axis=2) for i in range(2)] 158 | video_mask_cat = self.transform1(video_mask_cat) 159 | 160 | video_array = [vimg[:,:,:3] for vimg in video_mask_cat] 161 | mask_array = [vimg[:,:,3:] for vimg in video_mask_cat] 162 | 163 | if 
self.transform2: 164 | video_array = self.transform2(video_array) 165 | 166 | # print(mask_array[0].sum(), mask_array[1].sum()) 167 | 168 | out = {} 169 | if self.is_train: 170 | source = np.array(video_array[0], dtype='float32') 171 | driving = np.array(video_array[1], dtype='float32') 172 | source_mask = np.array(mask_array[0], dtype='float32') 173 | driving_mask = np.array(mask_array[1], dtype='float32') 174 | 175 | out['driving'] = driving.transpose((2, 0, 1)) 176 | out['source'] = source.transpose((2, 0, 1)) 177 | out['driving_mask'] = driving_mask.transpose((2, 0, 1)) 178 | out['source_mask'] = source_mask.transpose((2, 0, 1)) 179 | else: 180 | video = np.array(video_array, dtype='float32') 181 | out['video'] = video.transpose((3, 0, 1, 2)) 182 | 183 | out['name'] = video_name 184 | 185 | return out 186 | 187 | 188 | class DatasetRepeater(Dataset): 189 | """ 190 | Pass several times over the same dataset for better i/o performance 191 | """ 192 | 193 | def __init__(self, dataset, num_repeats=100): 194 | self.dataset = dataset 195 | self.num_repeats = num_repeats 196 | 197 | def __len__(self): 198 | return self.num_repeats * self.dataset.__len__() 199 | 200 | def __getitem__(self, idx): 201 | return self.dataset[idx % self.dataset.__len__()] 202 | 203 | class PairedDataset(Dataset): 204 | """ 205 | Dataset of pairs for animation. 206 | """ 207 | 208 | def __init__(self, initial_dataset, number_of_pairs, seed=0): 209 | self.initial_dataset = initial_dataset 210 | pairs_list = self.initial_dataset.pairs_list 211 | 212 | np.random.seed(seed) 213 | 214 | if pairs_list is None: 215 | max_idx = min(number_of_pairs, len(initial_dataset)) 216 | nx, ny = max_idx, max_idx 217 | xy = np.mgrid[:nx, :ny].reshape(2, -1).T 218 | number_of_pairs = min(xy.shape[0], number_of_pairs) 219 | self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0) 220 | else: 221 | videos = self.initial_dataset.videos 222 | name_to_index = {name: index for index, name in enumerate(videos)} 223 | pairs = pd.read_csv(pairs_list) 224 | pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))] 225 | 226 | number_of_pairs = min(pairs.shape[0], number_of_pairs) 227 | self.pairs = [] 228 | self.start_frames = [] 229 | for ind in range(number_of_pairs): 230 | self.pairs.append( 231 | (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]])) 232 | 233 | def __len__(self): 234 | return len(self.pairs) 235 | 236 | def __getitem__(self, idx): 237 | pair = self.pairs[idx] 238 | first = self.initial_dataset[pair[0]] 239 | 240 | second = self.initial_dataset[pair[1]] 241 | first = {'driving_' + key: value for key, value in first.items()} 242 | second = {'source_' + key: value for key, value in second.items()} 243 | 244 | return {**first, **second} 245 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import imageio 5 | 6 | import os 7 | from skimage.draw import circle_perimeter 8 | 9 | import matplotlib.pyplot as plt 10 | import collections 11 | 12 | 13 | class Logger: 14 | def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'): 15 | 16 | self.loss_list = [] 17 | self.cpk_dir = log_dir 18 | self.visualizations_dir = os.path.join(log_dir, 'train-vis') 19 | if not 
os.path.exists(self.visualizations_dir): 20 | os.makedirs(self.visualizations_dir) 21 | self.log_file = open(os.path.join(log_dir, log_file_name), 'a') 22 | self.zfill_num = zfill_num 23 | self.visualizer = Visualizer(**visualizer_params) 24 | self.checkpoint_freq = checkpoint_freq 25 | self.epoch = 0 26 | self.best_loss = float('inf') 27 | self.names = None 28 | 29 | def log_scores(self, loss_names): 30 | loss_mean = np.array(self.loss_list).mean(axis=0) 31 | 32 | loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) 33 | loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string 34 | 35 | print(loss_string, file=self.log_file) 36 | self.loss_list = [] 37 | self.log_file.flush() 38 | 39 | def visualize_rec(self, inp, out): 40 | image = self.visualizer.visualize(inp['driving'], inp['source'], out) 41 | imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) 42 | 43 | def save_cpk(self, emergent=False): 44 | cpk = {k: v.state_dict() for k, v in self.models.items()} 45 | cpk['epoch'] = self.epoch 46 | cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num)) 47 | if not (os.path.exists(cpk_path) and emergent): 48 | torch.save(cpk, cpk_path) 49 | 50 | @staticmethod 51 | def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None, 52 | bg_predictor=None, fg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None, 53 | optimizer_avd=None): 54 | checkpoint = torch.load(checkpoint_path) 55 | if inpainting_network is not None: 56 | inpainting_network.load_state_dict(checkpoint['inpainting_network']) 57 | if kp_detector is not None: 58 | kp_detector.load_state_dict(checkpoint['kp_detector']) 59 | if bg_predictor is not None and 'bg_predictor' in checkpoint: 60 | bg_predictor.load_state_dict(checkpoint['bg_predictor']) 61 | if fg_predictor is not None and 'fg_predictor' in checkpoint: 62 | fg_predictor.load_state_dict(checkpoint['fg_predictor']) 63 | if dense_motion_network is not None: 64 | dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) 65 | if avd_network is not None: 66 | if 'avd_network' in checkpoint: 67 | avd_network.load_state_dict(checkpoint['avd_network']) 68 | if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint: 69 | optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor']) 70 | if optimizer is not None and 'optimizer' in checkpoint: 71 | optimizer.load_state_dict(checkpoint['optimizer']) 72 | if optimizer_avd is not None: 73 | if 'optimizer_avd' in checkpoint: 74 | optimizer_avd.load_state_dict(checkpoint['optimizer_avd']) 75 | epoch = -1 76 | if 'epoch' in checkpoint: 77 | epoch = checkpoint['epoch'] 78 | return epoch 79 | 80 | def __enter__(self): 81 | return self 82 | 83 | def __exit__(self): 84 | if 'models' in self.__dict__: 85 | self.save_cpk() 86 | self.log_file.close() 87 | 88 | def log_iter(self, losses): 89 | losses = collections.OrderedDict(losses.items()) 90 | self.names = list(losses.keys()) 91 | self.loss_list.append(list(losses.values())) 92 | 93 | def log_epoch(self, epoch, models, inp, out): 94 | self.epoch = epoch 95 | print("Saving...", epoch) 96 | self.models = models 97 | if (self.epoch + 1) % self.checkpoint_freq == 0: 98 | self.save_cpk() 99 | self.log_scores(self.names) 100 | self.visualize_rec(inp, out) 101 | 102 | 103 | class Visualizer: 104 | def __init__(self, 
kp_size=5, draw_border=False, colormap='gist_rainbow'): 105 | self.kp_size = kp_size 106 | self.draw_border = draw_border 107 | self.colormap = plt.get_cmap(colormap) 108 | 109 | def draw_image_with_kp(self, image, kp_array): 110 | image = np.copy(image) 111 | spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] 112 | kp_array = spatial_size * (kp_array + 1) / 2 113 | num_kp = kp_array.shape[0] 114 | for kp_ind, kp in enumerate(kp_array): 115 | # rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) 116 | rr, cc = circle_perimeter(int(kp[1]), int(kp[0]), self.kp_size, shape=image.shape[:2]) 117 | image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] 118 | return image 119 | 120 | def create_image_column_with_kp(self, images, kp): 121 | image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) 122 | return self.create_image_column(image_array) 123 | 124 | def create_image_column(self, images): 125 | if self.draw_border: 126 | images = np.copy(images) 127 | images[:, :, [0, -1]] = (1, 1, 1) 128 | images[:, :, [0, -1]] = (1, 1, 1) 129 | return np.concatenate(list(images), axis=0) 130 | 131 | def create_image_grid(self, *args): 132 | out = [] 133 | for arg in args: 134 | if type(arg) == tuple: 135 | out.append(self.create_image_column_with_kp(arg[0], arg[1])) 136 | else: 137 | out.append(self.create_image_column(arg)) 138 | return np.concatenate(out, axis=1) 139 | 140 | def visualize(self, driving, source, out): 141 | images = [] 142 | 143 | # Source image with keypoints 144 | source = source.data.cpu() 145 | kp_source = out['kp_source']['fg_kp'].data.cpu().numpy() 146 | source = np.transpose(source, [0, 2, 3, 1]) 147 | images.append((source, kp_source)) 148 | 149 | # Equivariance visualization 150 | if 'transformed_frame' in out: 151 | transformed = out['transformed_frame'].data.cpu().numpy() 152 | transformed = np.transpose(transformed, [0, 2, 3, 1]) 153 | transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy() 154 | images.append((transformed, transformed_kp)) 155 | 156 | # Driving image with keypoints 157 | kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy() 158 | driving = driving.data.cpu().numpy() 159 | driving = np.transpose(driving, [0, 2, 3, 1]) 160 | images.append((driving, kp_driving)) 161 | 162 | # Deformed image 163 | if 'deformed' in out: 164 | deformed = out['deformed'].data.cpu().numpy() 165 | deformed = np.transpose(deformed, [0, 2, 3, 1]) 166 | images.append(deformed) 167 | 168 | # Result with and without keypoints 169 | prediction = out['prediction'].data.cpu().numpy() 170 | prediction = np.transpose(prediction, [0, 2, 3, 1]) 171 | if 'kp_norm' in out: 172 | kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy() 173 | images.append((prediction, kp_norm)) 174 | images.append(prediction) 175 | 176 | 177 | ## Occlusion map 178 | if 'occlusion_map' in out: 179 | for i in range(len(out['occlusion_map'])): 180 | occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1) 181 | occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() 182 | occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) 183 | images.append(occlusion_map) 184 | 185 | ## source mask 186 | if 'source_mask' in out: 187 | source_mask = out['source_mask'].data.cpu().repeat(1, 3, 1, 1) 188 | source_mask = F.interpolate(source_mask, size=source.shape[1:3]).numpy() 189 | source_mask = np.transpose(source_mask, [0, 2, 3, 1]) 190 | images.append(source_mask) 191 | 192 | ## attention mask 193 | 194 | 195 | ## Occlusion 
map 196 | if 'occlusion_fg' in out: 197 | for i in range(len(out['occlusion_fg'])): 198 | occlusion_fg = out['occlusion_fg'][i].data.cpu().repeat(1, 3, 1, 1) 199 | occlusion_fg = F.interpolate(occlusion_fg, size=source.shape[1:3]).numpy() 200 | occlusion_fg = np.transpose(occlusion_fg, [0, 2, 3, 1]) 201 | images.append(occlusion_fg) 202 | 203 | # Deformed images according to each individual transform 204 | if 'deformed_source' in out: 205 | full_mask = [] 206 | for i in range(out['deformed_source'].shape[1]): 207 | image = out['deformed_source'][:, i].data.cpu() 208 | # import ipdb;ipdb.set_trace() 209 | image = F.interpolate(image, size=source.shape[1:3]) 210 | mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1) 211 | mask = F.interpolate(mask, size=source.shape[1:3]) 212 | image = np.transpose(image.numpy(), (0, 2, 3, 1)) 213 | mask = np.transpose(mask.numpy(), (0, 2, 3, 1)) 214 | 215 | if i != 0: 216 | color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3] 217 | else: 218 | color = np.array((0, 0, 0)) 219 | 220 | color = color.reshape((1, 1, 1, 3)) 221 | 222 | images.append(image) 223 | if i != 0: 224 | images.append(mask * color) 225 | else: 226 | images.append(mask) 227 | 228 | full_mask.append(mask * color) 229 | 230 | images.append(sum(full_mask)) 231 | 232 | image = self.create_image_grid(*images) 233 | image = (255 * image).astype(np.uint8) 234 | return image 235 | -------------------------------------------------------------------------------- /modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | from unittest.main import main 2 | from xml.dom import NotFoundErr 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import torch 6 | from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian 7 | from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS 8 | import math 9 | 10 | class DenseMotionNetwork(nn.Module): 11 | """ 12 | Module that estimating an optical flow and multi-resolution occlusion masks 13 | from K TPS transformations and an affine transformation. 
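    As implemented in forward() below: the source image is anti-alias
    downsampled by `scale_factor`; K TPS flows are built from the K*N predicted
    keypoints, optionally joined by the affine background flow and the
    perspective foreground flow (an identity grid is substituted whenever the
    corresponding parameters are not available, e.g. before bg_start/fg_start);
    an hourglass then predicts softmax contribution maps that blend these
    candidate flows into a single dense deformation (Eq. 6 in the paper),
    together with multi-resolution occlusion masks.

    Shape sketch (illustrative only, using the __main__ example at the bottom of
    this file: num_tps=10, num_kps=8, a 256x256 input, scale_factor=0.25):
        heatmaps        : bs x (10*8 + 1 or 2) x 64 x 64
        transformations : bs x (10   + 1 or 2) x 64 x 64 x 2
        deformation     : bs x 64 x 64 x 2
        occlusion_map   : masks at 32, 64, 128 and 256 px when multi_mask=True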
14 | """ 15 | 16 | def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_kps, num_channels, 17 | scale_factor=0.25, bg = False, fg = False, multi_mask = True, kp_variance=0.01): 18 | super(DenseMotionNetwork, self).__init__() 19 | 20 | if scale_factor != 1: 21 | self.down = AntiAliasInterpolation2d(num_channels, scale_factor) 22 | self.scale_factor = scale_factor 23 | self.multi_mask = multi_mask 24 | 25 | if bg and fg: 26 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps + 2) + num_tps * num_kps + 2), 27 | max_features=max_features, num_blocks=num_blocks) 28 | else: 29 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps + 1) + num_tps * num_kps + 1), 30 | max_features=max_features, num_blocks=num_blocks) 31 | # 包含了只有一个变换和两个变换都没有,补充了一个通道的情况, 32 | # 训练的时候都是带着变换推光流的,但是去掉补充一个新的之后能有作用么?对于不变的背景倒是有用。 33 | 34 | 35 | hourglass_output_size = self.hourglass.out_channels 36 | if bg and fg: 37 | self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 2, kernel_size=(7, 7), padding=(3, 3)) 38 | else: 39 | self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3)) 40 | 41 | if multi_mask: 42 | up = [] 43 | self.up_nums = int(math.log(1/scale_factor, 2)) # 2 44 | self.occlusion_num = 4 45 | 46 | channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)] 47 | for i in range(self.up_nums): 48 | up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1)) 49 | self.up = nn.ModuleList(up) # 2 levels 50 | 51 | channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]] 52 | for i in range(self.up_nums): 53 | channel.append(hourglass_output_size[-1]//(2**(i+1))) 54 | occlusion = [] 55 | 56 | for i in range(self.occlusion_num-1): 57 | occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3))) 58 | occlusion.append(nn.Conv2d(channel[self.occlusion_num-1], 2, kernel_size=(7, 7), padding=(3, 3))) 59 | # 此处有修改 60 | self.occlusion = nn.ModuleList(occlusion) 61 | else: 62 | occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))] 63 | self.occlusion = nn.ModuleList(occlusion) 64 | 65 | self.num_tps = num_tps 66 | self.num_kps = num_kps 67 | self.bg = bg 68 | self.fg = fg 69 | self.kp_variance = kp_variance 70 | 71 | 72 | def create_heatmap_representations(self, source_image, kp_driving, kp_source): 73 | 74 | spatial_size = source_image.shape[2:] 75 | gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance) ## bs, KN, w, h 76 | gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance) ## bs, KN, w, h 77 | heatmap = gaussian_driving - gaussian_source ## bs, KN, w, h 78 | 79 | if self.bg and self.fg: 80 | zeros = torch.zeros(heatmap.shape[0], 2, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device) 81 | else: 82 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device) 83 | 84 | heatmap = torch.cat([zeros, heatmap], dim=1) ## bs, KN+1(+2), w, h 85 | 86 | return heatmap 87 | 88 | def create_transformations(self, source_image, kp_driving, kp_source, bg_param, fg_param): 89 | # K TPS transformaions 90 | bs, _, h, w = source_image.shape 91 | kp_1 = kp_driving['fg_kp'] 92 | kp_2 = kp_source['fg_kp'] 93 | kp_1 = kp_1.view(bs, -1, self.num_kps, 2) 94 | kp_2 = kp_2.view(bs, -1, 
self.num_kps, 2) 95 | trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2) 96 | driving_to_source = trans.transform_frame(source_image) # bs K h w 2 97 | 98 | identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device) 99 | identity_grid = identity_grid.view(1, 1, h, w, 2) 100 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) # bs 1 h w 2 101 | 102 | # affine background transformation 103 | if not (bg_param is None): 104 | identity_grid_bg = to_homogeneous(identity_grid) 105 | identity_grid_bg = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid_bg.unsqueeze(-1)).squeeze(-1) 106 | identity_grid_bg = from_homogeneous(identity_grid_bg) # bs 1 h w 2 107 | 108 | # perspective foreground transformation 109 | if not (fg_param is None): 110 | identity_grid_fg = to_homogeneous(identity_grid) 111 | identity_grid_fg = torch.matmul(fg_param.view(bs, 1, 1, 1, 3, 3), identity_grid_fg.unsqueeze(-1)).squeeze(-1) 112 | identity_grid_fg = from_homogeneous(identity_grid_fg) # bs 1 h w 2 113 | 114 | # transformations = torch.cat([identity_grid_bg, identity_grid_fg, driving_to_source], dim=1) # bs K+2 h w 2 115 | transformations = driving_to_source # bs K h w 2 116 | if self.fg: 117 | if not (fg_param is None): 118 | transformations = torch.cat([identity_grid_fg, transformations], dim=1) 119 | else: 120 | transformations = torch.cat([identity_grid, transformations], dim=1) 121 | # 这里是在测试的时候满足模型的size要求 122 | 123 | if self.bg: 124 | if not (bg_param is None): 125 | transformations = torch.cat([identity_grid_bg, transformations], dim=1) 126 | else: 127 | transformations = torch.cat([identity_grid, transformations], dim=1) 128 | # 这里是在测试的时候满足模型的size要求 129 | 130 | return transformations 131 | 132 | def create_deformed_source_image(self, source_image, transformations): 133 | 134 | bs, _, h, w = source_image.shape 135 | K = transformations.size(1) 136 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, K, 1, 1, 1, 1) 137 | source_repeat = source_repeat.view(bs * K, -1, h, w) 138 | transformations = transformations.view((bs * K, h, w, -1)) 139 | deformed = F.grid_sample(source_repeat, transformations, align_corners=True) 140 | deformed = deformed.view((bs, K, -1, h, w)) 141 | return deformed # bs K+2 3 h w 142 | 143 | def dropout_softmax(self, X, P): 144 | ''' 145 | Dropout for TPS transformations. Eq(7) and Eq(8) in the paper. 146 | ''' 147 | drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device) 148 | drop[..., 0] = 1 149 | drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1) 150 | 151 | maxx = X.max(1).values.unsqueeze_(1) 152 | X = X - maxx 153 | X_exp = X.exp() 154 | X[:,1:,...] 
/= (1-P)
155 |         mask_bool = (drop == 0)
156 |         X_exp = X_exp.masked_fill(mask_bool, 0)
157 |         partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
158 |         return X_exp / partition
159 | 
160 |     def forward(self, source_image, kp_driving, kp_source, bg_param = None, fg_param = None, dropout_flag=False, dropout_p = 0):
161 |         if self.scale_factor != 1:
162 |             source_image = self.down(source_image)  ## /4 downsample
163 | 
164 |         bs, _, h, w = source_image.shape
165 | 
166 |         out_dict = dict()
167 |         heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)  # bs KN+2(+1) h w
168 |         transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param, fg_param)  # bs K+2(+1) h w 2
169 | 
170 |         deformed_source = self.create_deformed_source_image(source_image, transformations)  # bs K+2(+1) 3 h w; warp the source image with every candidate transformation
171 |         out_dict['deformed_source'] = deformed_source
172 | 
173 |         deformed_source = deformed_source.view(bs,-1,h,w)
174 |         input = torch.cat([heatmap_representation, deformed_source], dim=1)  # feed the deformed images together with the keypoint heatmaps
175 |         input = input.view(bs, -1, h, w)
176 | 
177 |         prediction = self.hourglass(input, mode = 1)
178 | 
179 |         contribution_maps = self.maps(prediction[-1])  ## bs k+2 h w
180 |         if(dropout_flag):
181 |             contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
182 |         else:
183 |             contribution_maps = F.softmax(contribution_maps, dim=1)
184 |         out_dict['contribution_maps'] = contribution_maps
185 | 
186 |         # Combine the K+2 transformations
187 |         # Eq(6) in the paper
188 |         contribution_maps = contribution_maps.unsqueeze(2)  ## bs k+2 1 h w
189 |         transformations = transformations.permute(0, 1, 4, 2, 3)  # bs K+2 2 h w
190 |         deformation = (transformations * contribution_maps).sum(dim=1)  # bs 2 h w; the transformations weighted by the contribution maps
191 |         deformation = deformation.permute(0, 2, 3, 1)  # bs h w 2
192 | 
193 |         out_dict['deformation'] = deformation  # the optical flow can be derived from this deformation, but it is not the optical flow itself
194 | 
195 |         occlusion_map = []
196 |         if self.multi_mask:
197 |             for i in range(self.occlusion_num-self.up_nums):
198 |                 occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
199 |             prediction = prediction[-1]
200 |             for i in range(self.up_nums):
201 |                 prediction = self.up[i](prediction)
202 |                 occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
203 |         else:
204 |             occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
205 | 
206 |         out_dict['attention_map'] = [occlusion_map[-1][:,1:]]
207 |         occlusion_map[-1] = occlusion_map[-1][:,:1]
208 |         out_dict['occlusion_map'] = occlusion_map  # Multi-resolution Occlusion Masks
209 |         # 32x32x1 64x64x1 128x128x1 256x256x1
210 |         return out_dict
211 | 
212 | if __name__=='__main__':
213 |     model = DenseMotionNetwork(64, 5, 1024, 10, 8, 3, scale_factor=0.25, bg = True, multi_mask = True, kp_variance=0.01)
214 |     print(model)
215 | 
-------------------------------------------------------------------------------- /demo.py: --------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import sys
4 | import yaml
5 | import os
6 | from argparse import ArgumentParser
7 | from tqdm import tqdm
8 | from scipy.spatial import ConvexHull
9 | import numpy as np
10 | import imageio
11 | from skimage.transform import resize
12 | from skimage import img_as_ubyte
13 | import pandas as pd
14 | import torch
15 | from modules.inpainting_network import InpaintingNetwork
16 | from 
modules.keypoint_detector import KPDetector 17 | from modules.dense_motion import DenseMotionNetwork 18 | from modules.avd_network import AVDNetwork 19 | from modules.bg_motion_predictor import BGMotionPredictor 20 | from modules.fg_motion_predictor import FGMotionPredictor 21 | 22 | if sys.version_info[0] < 3: 23 | raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9") 24 | 25 | def relative_kp(kp_source, kp_driving, kp_driving_initial): 26 | 27 | source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume 28 | driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume 29 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 30 | 31 | kp_new = {k: v for k, v in kp_driving.items()} 32 | 33 | kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp']) 34 | kp_value_diff *= adapt_movement_scale 35 | kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp'] 36 | 37 | return kp_new 38 | 39 | def load_checkpoints(config_path, checkpoint_path, device): 40 | with open(config_path) as f: 41 | config = yaml.load(f, Loader=yaml.FullLoader) 42 | 43 | inpainting = InpaintingNetwork(**config['model_params']['generator_params'], 44 | **config['model_params']['common_params']) 45 | kp_detector = KPDetector(**config['model_params']['common_params']) 46 | dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], 47 | **config['model_params']['dense_motion_params']) 48 | avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], 49 | num_kps=config['model_params']['common_params']['num_kps'], 50 | **config['model_params']['avd_network_params']) 51 | bg_predictor = None 52 | if (config['model_params']['common_params']['bg']): 53 | print("create BGMotionPredictor") 54 | bg_predictor = BGMotionPredictor() 55 | bg_predictor.to(device) 56 | bg_predictor.eval() 57 | 58 | fg_predictor = None 59 | if (config['model_params']['common_params']['fg']): 60 | print("create FGMotionPredictor") 61 | fg_predictor = FGMotionPredictor() 62 | fg_predictor.to(device) 63 | fg_predictor.eval() 64 | 65 | kp_detector.to(device) 66 | dense_motion_network.to(device) 67 | inpainting.to(device) 68 | avd_network.to(device) 69 | 70 | checkpoint = torch.load(checkpoint_path, map_location=device) 71 | 72 | inpainting.load_state_dict(checkpoint['inpainting_network']) 73 | kp_detector.load_state_dict(checkpoint['kp_detector']) 74 | dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) 75 | if 'avd_network' in checkpoint: 76 | avd_network.load_state_dict(checkpoint['avd_network']) 77 | if 'bg_predictor' in checkpoint: 78 | bg_predictor.load_state_dict(checkpoint['bg_predictor']) 79 | if 'fg_predictor' in checkpoint: 80 | fg_predictor.load_state_dict(checkpoint['fg_predictor']) 81 | 82 | inpainting.eval() 83 | kp_detector.eval() 84 | dense_motion_network.eval() 85 | avd_network.eval() 86 | 87 | return inpainting, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor 88 | 89 | 90 | def make_animation(source_image, source_image_mask, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor, device, mode = 'relative'): 91 | assert mode in ['standard', 'relative', 'avd'] 92 | with torch.no_grad(): 93 | predictions = [] 94 | source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) 95 | source = source.to(device) 96 | source_mask = 
torch.tensor(source_image_mask[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) 97 | source_mask = source_mask.to(device) 98 | 99 | driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device) 100 | 101 | kp_source = kp_detector(source) 102 | kp_driving_initial = kp_detector(driving[:, :, 0]) 103 | 104 | for frame_idx in tqdm(range(driving.shape[2])): 105 | driving_frame = driving[:, :, frame_idx] 106 | driving_frame = driving_frame.to(device) 107 | kp_driving = kp_detector(driving_frame) 108 | if mode == 'standard': 109 | kp_norm = kp_driving 110 | elif mode=='relative': 111 | kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving, 112 | kp_driving_initial=kp_driving_initial) 113 | elif mode == 'avd': 114 | kp_norm = avd_network(kp_source, kp_driving) 115 | 116 | bg_param = None 117 | if bg_predictor!=None: 118 | bg_param = bg_predictor(source, driving_frame) 119 | # print(bg_param) 120 | 121 | fg_param = None 122 | if fg_predictor!=None: 123 | fg_param = fg_predictor(source, driving_frame) 124 | # print(fg_param) 125 | 126 | dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm, 127 | kp_source=kp_source, bg_param = bg_param, fg_param = fg_param, 128 | dropout_flag = False) 129 | out = inpainting_network(source, source_mask, dense_motion) 130 | 131 | predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) 132 | return predictions 133 | 134 | 135 | def find_best_frame(source, driving, cpu): 136 | import face_alignment 137 | 138 | def normalize_kp(kp): 139 | kp = kp - kp.mean(axis=0, keepdims=True) 140 | area = ConvexHull(kp[:, :2]).volume 141 | area = np.sqrt(area) 142 | kp[:, :2] = kp[:, :2] / area 143 | return kp 144 | 145 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, 146 | device= 'cpu' if cpu else 'cuda') 147 | kp_source = fa.get_landmarks(255 * source)[0] 148 | kp_source = normalize_kp(kp_source) 149 | norm = float('inf') 150 | frame_num = 0 151 | for i, image in tqdm(enumerate(driving)): 152 | kp_driving = fa.get_landmarks(255 * image)[0] 153 | kp_driving = normalize_kp(kp_driving) 154 | new_norm = (np.abs(kp_source - kp_driving) ** 2).sum() 155 | if new_norm < norm: 156 | norm = new_norm 157 | frame_num = i 158 | return frame_num 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = ArgumentParser() 163 | parser.add_argument("--config", required=True, help="path to config") 164 | parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore") 165 | 166 | parser.add_argument("--src_drv_csv", default='./dataset/Mixed_dataset/Mixed_dataset_test.csv', help="path to csv file") 167 | parser.add_argument("--result_video", default=None, help="path to output") 168 | 169 | parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))), 170 | help='Shape of image, that the model was trained on.') 171 | 172 | parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result") 173 | 174 | parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true", 175 | help="Generate from the frame that is the most alligned with source. 
(Only for faces, requires face_alignment lib)")
176 | 
177 |     parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
178 | 
179 |     opt = parser.parse_args()
180 | 
181 |     data = pd.read_csv(opt.src_drv_csv)
182 |     os.makedirs(opt.result_video, exist_ok=True)
183 | 
184 |     for i in range(len(data['driving'])):
185 |         src_dir = data['source'][i]
186 | 
187 |         src_img = os.listdir(os.path.join("./dataset/Mixed_dataset/test",src_dir))[0]
188 |         source_image = imageio.imread(os.path.join("./dataset/Mixed_dataset/test",src_dir,src_img))
189 | 
190 |         if os.path.exists(os.path.join("./dataset/Mixed_dataset/test_mask",src_dir)):
191 |             src_img_mask = os.listdir(os.path.join("./dataset/Mixed_dataset/test_mask",src_dir))[0]
192 |             source_image_mask = imageio.imread(os.path.join("./dataset/Mixed_dataset/test_mask",src_dir,src_img_mask))
193 |             source_image_mask = source_image_mask[:,:,None]
194 |         else:
195 |             source_image_mask = np.zeros((256,256,1))
196 | 
197 |         drv_vid = data['driving'][i]
198 |         driving_video = []
199 |         drv_vid_files = os.listdir(os.path.join("./dataset/Mixed_dataset/test",drv_vid))
200 | 
201 |         f = lambda x:x.split('.')[-1] in ['png','jpg','bmp']
202 |         drv_vid_files = list(filter(f, drv_vid_files))
203 |         drv_vid_files.sort()
204 | 
205 |         for im in drv_vid_files:
206 |             driving_video.append(imageio.imread(os.path.join("./dataset/Mixed_dataset/test",drv_vid,im)))
207 | 
208 |         if opt.cpu:
209 |             device = torch.device('cpu')
210 |         else:
211 |             device = torch.device('cuda')
212 | 
213 |         source_image = resize(source_image, opt.img_shape)[..., :3]
214 |         source_image_mask = resize(source_image_mask, opt.img_shape)[..., :1]
215 |         driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
216 |         inpainting, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor = \
217 |             load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)
218 | 
219 |         if opt.find_best_frame:
220 |             i = find_best_frame(source_image, driving_video, opt.cpu)
221 |             print("Best frame: " + str(i))
222 |             driving_forward = driving_video[i:]
223 |             driving_backward = driving_video[:(i+1)][::-1]
224 |             predictions_forward = make_animation(source_image, source_image_mask, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor, device = device, mode = opt.mode)
225 |             predictions_backward = make_animation(source_image, source_image_mask, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor, device = device, mode = opt.mode)
226 |             predictions = predictions_backward[::-1] + predictions_forward[1:]
227 |         else:
228 |             predictions = make_animation(source_image, source_image_mask, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, bg_predictor, fg_predictor, device = device, mode = opt.mode)
229 | 
230 |         print(drv_vid+'_'+src_dir+'.mp4')
231 |         imageio.mimsave(os.path.join(opt.result_video, drv_vid+'_'+src_dir+'.mp4'), [img_as_ubyte(frame) for frame in predictions], fps=100)
232 | 
233 | 
-------------------------------------------------------------------------------- /modules/util.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 | 
5 | 
6 | class TPS:
7 |     '''
8 |     TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
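    A thin-plate spline maps a coordinate z to

        T(z) = A z + t + sum_k w_k U(||z - p_k||),    U(r) = r^2 log r^2,

    where the p_k are the control points. In 'kp' mode A, t and w_k are solved
    in closed form so that the driving keypoints map onto the source keypoints
    (one TPS per keypoint group); in 'random' mode they are sampled as noise and
    used to build the transformed frame for the equivariance loss.

    Usage sketch (parameter values are illustrative only, not the training
    configuration):

        tps = TPS(mode='random', bs=frames.shape[0],
                  sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
        grid = tps.transform_frame(frames)                # bs (x gs) x H x W x 2
        warped = F.grid_sample(frames, grid, align_corners=True)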
9 |     '''
10 |     def __init__(self, mode, bs, **kwargs):
11 |         self.bs = bs
12 |         self.mode = mode
13 |         if mode == 'random':
14 |             noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
15 |             self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
16 |             self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
17 |             self.control_points = self.control_points.unsqueeze(0)
18 |             self.control_params = torch.normal(mean=0,
19 |                                                std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
20 |         elif mode == 'kp':
21 |             kp_1 = kwargs["kp_1"]
22 |             kp_2 = kwargs["kp_2"]
23 |             device = kp_1.device
24 |             kp_type = kp_1.type()
25 |             self.gs = kp_1.shape[1]
26 |             n = kp_1.shape[2]
27 |             K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)  # bs K N N
28 |             K = K**2
29 |             K = K * torch.log(K+1e-9)
30 | 
31 |             # convert the control points to homogeneous coordinates
32 |             one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
33 |             kp_1p = torch.cat([kp_1,one1], 3)  # bs K N 3
34 | 
35 |             zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
36 |             P = torch.cat([kp_1p, zero],2)  # bs K N+3 3
37 |             L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)  # bs K N+3 N
38 |             L = torch.cat([L,P],3)  # bs K N+3 N+3
39 | 
40 |             zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
41 |             Y = torch.cat([kp_2, zero], 2)  # bs K N+3 2
42 |             one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01  # bs K N+3 N+3
43 |             L = L + one  # for numerical stability
44 | 
45 |             param = torch.matmul(torch.inverse(L),Y)  # bs K N+3 2
46 |             self.theta = param[:,:,n:,:].permute(0,1,3,2)  # bs K 2 3
47 | 
48 |             self.control_points = kp_1
49 |             self.control_params = param[:,:,:n,:]
50 |         else:
51 |             raise Exception("Error TPS mode")
52 | 
53 |     def transform_frame(self, frame):
54 |         grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)  # w h 2
55 |         grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
56 |         shape = [self.bs, frame.shape[2], frame.shape[3], 2]
57 |         if self.mode == 'kp':
58 |             shape.insert(1, self.gs)
59 |         grid = self.warp_coordinates(grid).view(*shape)
60 |         return grid  # only the spatial size of frame is actually used here; bs gs w h 2
61 | 
62 |     def warp_coordinates(self, coordinates):
63 |         theta = self.theta.type(coordinates.type()).to(coordinates.device)
64 |         control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
65 |         control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
66 | 
67 |         if self.mode == 'kp':
68 |             transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
69 | 
70 |             distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
71 | 
72 |             distances = distances ** 2
73 |             result = distances.sum(-1)
74 |             result = result * torch.log(result + 1e-9)
75 |             result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
76 |             transformed = transformed.permute(0, 1, 3, 2) + result
77 | 
78 |         elif self.mode == 'random':
79 |             theta = theta.unsqueeze(1)
80 |             transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
81 |             transformed = transformed.squeeze(-1)
82 |             distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
83 |             distances = distances ** 2
84 | 
85 |             result = distances.sum(-1)
86 |             result = result * torch.log(result + 1e-9)
87 |             result = result * control_params
88 |             result = result.sum(dim=2).view(self.bs, 
coordinates.shape[1], 1)
89 |             transformed = transformed + result
90 |         else:
91 |             raise Exception("Error TPS mode")
92 | 
93 |         return transformed  # coordinates of the input points after warping; bs w h 2
94 | 
95 | 
96 | def kp2gaussian(kp, spatial_size, kp_variance):
97 |     """
98 |     Transform a keypoint into a gaussian-like representation
99 |     """
100 | 
101 |     coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)  # w h 2
102 |     number_of_leading_dimensions = len(kp.shape) - 1
103 |     shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape  # replace the single [x, y] with a w h 2 grid of coordinates
104 |     coordinate_grid = coordinate_grid.view(*shape)
105 |     repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
106 |     coordinate_grid = coordinate_grid.repeat(*repeats)  # after repeating, every w h 2 grid corresponds to one of the original [x, y]
107 | 
108 |     # Preprocess kp shape
109 |     shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
110 |     kp = kp.view(*shape)
111 | 
112 |     mean_sub = (coordinate_grid - kp)  # subtract [x, y] from every coordinate in the grid
113 | 
114 |     out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
115 | 
116 |     return out  # bs KN w h
117 | 
118 | 
119 | def make_coordinate_grid(spatial_size, type):
120 |     """
121 |     Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
122 |     """
123 |     h, w = spatial_size
124 |     x = torch.arange(w).type(type)
125 |     y = torch.arange(h).type(type)
126 | 
127 |     x = (2 * (x / (w - 1)) - 1)
128 |     y = (2 * (y / (h - 1)) - 1)
129 | 
130 |     yy = y.view(-1, 1).repeat(1, w)
131 |     xx = x.view(1, -1).repeat(h, 1)
132 | 
133 |     meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
134 | 
135 |     return meshed  ## h w 2
136 | 
137 | 
138 | class ResBlock2d(nn.Module):
139 |     """
140 |     Res block, preserves spatial resolution.
141 |     """
142 | 
143 |     def __init__(self, in_features, kernel_size, padding):
144 |         super(ResBlock2d, self).__init__()
145 |         self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
146 |                                padding=padding)
147 |         self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
148 |                                padding=padding)
149 |         self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
150 |         self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
151 | 
152 |     def forward(self, x):
153 |         out = self.norm1(x)
154 |         out = F.relu(out)
155 |         out = self.conv1(out)
156 |         out = self.norm2(out)
157 |         out = F.relu(out)
158 |         out = self.conv2(out)
159 |         out += x
160 |         return out
161 | 
162 | 
163 | class UpBlock2d(nn.Module):
164 |     """
165 |     Upsampling block for use in decoder.
166 |     """
167 | 
168 |     def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
169 |         super(UpBlock2d, self).__init__()
170 | 
171 |         self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
172 |                               padding=padding, groups=groups)
173 |         self.norm = nn.InstanceNorm2d(out_features, affine=True)
174 | 
175 |     def forward(self, x):
176 |         out = F.interpolate(x, scale_factor=2)
177 |         out = self.conv(out)
178 |         out = self.norm(out)
179 |         out = F.relu(out)
180 |         return out
181 | 
182 | 
183 | class DownBlock2d(nn.Module):
184 |     """
185 |     Downsampling block for use in encoder.
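    Conv (spatial size preserved) -> InstanceNorm -> ReLU -> 2x2 average
    pooling, so each block halves the spatial resolution; UpBlock2d above is
    its mirror, interpolating by 2x before the convolution.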
186 | """ 187 | 188 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 189 | super(DownBlock2d, self).__init__() 190 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 191 | padding=padding, groups=groups) 192 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 193 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 194 | 195 | def forward(self, x): 196 | out = self.conv(x) 197 | out = self.norm(out) 198 | out = F.relu(out) 199 | out = self.pool(out) 200 | return out 201 | 202 | 203 | class SameBlock2d(nn.Module): 204 | """ 205 | Simple block, preserve spatial resolution. 206 | """ 207 | 208 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 209 | super(SameBlock2d, self).__init__() 210 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 211 | kernel_size=kernel_size, padding=padding, groups=groups) 212 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 213 | 214 | def forward(self, x): 215 | out = self.conv(x) 216 | out = self.norm(out) 217 | out = F.relu(out) 218 | return out 219 | 220 | 221 | class Encoder(nn.Module): 222 | """ 223 | Hourglass Encoder 224 | """ 225 | 226 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 227 | super(Encoder, self).__init__() 228 | 229 | down_blocks = [] 230 | for i in range(num_blocks): 231 | down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), 232 | min(max_features, block_expansion * (2 ** (i + 1))), 233 | kernel_size=3, padding=1)) 234 | self.down_blocks = nn.ModuleList(down_blocks) 235 | 236 | def forward(self, x): 237 | outs = [x] 238 | #print('encoder:' ,outs[-1].shape) 239 | for down_block in self.down_blocks: 240 | outs.append(down_block(outs[-1])) 241 | #print('encoder:' ,outs[-1].shape) 242 | return outs 243 | 244 | 245 | class Decoder(nn.Module): 246 | """ 247 | Hourglass Decoder 248 | """ 249 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 250 | super(Decoder, self).__init__() 251 | 252 | up_blocks = [] 253 | self.out_channels = [] 254 | for i in range(num_blocks)[::-1]: 255 | in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) 256 | self.out_channels.append(in_filters) 257 | out_filters = min(max_features, block_expansion * (2 ** i)) 258 | up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) 259 | 260 | self.up_blocks = nn.ModuleList(up_blocks) 261 | self.out_channels.append(block_expansion + in_features) 262 | # self.out_filters = block_expansion + in_features 263 | 264 | def forward(self, x, mode = 0): 265 | out = x.pop() 266 | outs = [] 267 | for up_block in self.up_blocks: 268 | out = up_block(out) 269 | skip = x.pop() 270 | out = torch.cat([out, skip], dim=1) 271 | outs.append(out) 272 | if(mode == 0): 273 | return out 274 | else: 275 | return outs 276 | 277 | 278 | class Hourglass(nn.Module): 279 | """ 280 | Hourglass architecture. 
281 | """ 282 | 283 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 284 | super(Hourglass, self).__init__() 285 | self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) 286 | self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) 287 | self.out_channels = self.decoder.out_channels 288 | # self.out_filters = self.decoder.out_filters 289 | 290 | def forward(self, x, mode = 0): 291 | return self.decoder(self.encoder(x), mode) 292 | 293 | 294 | class AntiAliasInterpolation2d(nn.Module): 295 | """ 296 | Band-limited downsampling, for better preservation of the input signal. 297 | """ 298 | def __init__(self, channels, scale): 299 | super(AntiAliasInterpolation2d, self).__init__() 300 | sigma = (1 / scale - 1) / 2 301 | kernel_size = 2 * round(sigma * 4) + 1 302 | self.ka = kernel_size // 2 303 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka 304 | 305 | kernel_size = [kernel_size, kernel_size] 306 | sigma = [sigma, sigma] 307 | # The gaussian kernel is the product of the 308 | # gaussian function of each dimension. 309 | kernel = 1 310 | meshgrids = torch.meshgrid( 311 | [ 312 | torch.arange(size, dtype=torch.float32) 313 | for size in kernel_size 314 | ] 315 | ) 316 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 317 | mean = (size - 1) / 2 318 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) 319 | 320 | # Make sure sum of values in gaussian kernel equals 1. 321 | kernel = kernel / torch.sum(kernel) 322 | # Reshape to depthwise convolutional weight 323 | kernel = kernel.view(1, 1, *kernel.size()) 324 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 325 | 326 | self.register_buffer('weight', kernel) 327 | self.groups = channels 328 | self.scale = scale 329 | 330 | def forward(self, input): 331 | if self.scale == 1.0: 332 | return input 333 | 334 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) 335 | out = F.conv2d(out, weight=self.weight, groups=self.groups) 336 | out = F.interpolate(out, scale_factor=(self.scale, self.scale)) ## 337 | 338 | return out 339 | 340 | 341 | def to_homogeneous(coordinates): 342 | ones_shape = list(coordinates.shape) 343 | ones_shape[-1] = 1 344 | ones = torch.ones(ones_shape).type(coordinates.type()) 345 | 346 | return torch.cat([coordinates, ones], dim=-1) 347 | 348 | def from_homogeneous(coordinates): 349 | return coordinates[..., :2] / coordinates[..., 2:3] -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/hassony2/torch_videovision 3 | """ 4 | 5 | import numbers 6 | import cv2 7 | import random 8 | import numpy as np 9 | import PIL 10 | 11 | from skimage.transform import resize, rotate 12 | import torchvision 13 | 14 | import warnings 15 | 16 | from skimage import img_as_ubyte, img_as_float 17 | 18 | 19 | def crop_clip(clip, min_h, min_w, h, w): 20 | if isinstance(clip[0], np.ndarray): 21 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 22 | 23 | elif isinstance(clip[0], PIL.Image.Image): 24 | cropped = [ 25 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 26 | ] 27 | else: 28 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 29 | 'but got list of {0}'.format(type(clip[0]))) 30 | return cropped 31 | 32 | 33 | def pad_clip(clip, h, w): 34 | im_h, im_w = clip[0].shape[:2] 35 | pad_h = (0, 0) if h < im_h else 
((h - im_h) // 2, (h - im_h + 1) // 2) 36 | pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) 37 | 38 | return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') 39 | 40 | 41 | def resize_clip(clip, size, interpolation='bilinear'): 42 | if isinstance(clip[0], np.ndarray): 43 | if isinstance(size, numbers.Number): 44 | im_h, im_w, im_c = clip[0].shape 45 | # Min spatial dim already matches minimal size 46 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 47 | and im_h == size): 48 | return clip 49 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 50 | size = (new_w, new_h) 51 | else: 52 | size = size[1], size[0] 53 | 54 | scaled = [ 55 | resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, 56 | mode='constant', anti_aliasing=True) for img in clip 57 | ] 58 | elif isinstance(clip[0], PIL.Image.Image): 59 | if isinstance(size, numbers.Number): 60 | im_w, im_h = clip[0].size 61 | # Min spatial dim already matches minimal size 62 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 63 | and im_h == size): 64 | return clip 65 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 66 | size = (new_w, new_h) 67 | else: 68 | size = size[1], size[0] 69 | if interpolation == 'bilinear': 70 | pil_inter = PIL.Image.NEAREST 71 | else: 72 | pil_inter = PIL.Image.BILINEAR 73 | scaled = [img.resize(size, pil_inter) for img in clip] 74 | else: 75 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 76 | 'but got list of {0}'.format(type(clip[0]))) 77 | return scaled 78 | 79 | 80 | def get_resize_sizes(im_h, im_w, size): 81 | if im_w < im_h: 82 | ow = size 83 | oh = int(size * im_h / im_w) 84 | else: 85 | oh = size 86 | ow = int(size * im_w / im_h) 87 | return oh, ow 88 | 89 | 90 | class RandomFlip(object): 91 | def __init__(self, time_flip=False, horizontal_flip=False): 92 | self.time_flip = time_flip 93 | self.horizontal_flip = horizontal_flip 94 | 95 | def __call__(self, clip): 96 | if random.random() < 0.5 and self.time_flip: 97 | return clip[::-1] 98 | if random.random() < 0.5 and self.horizontal_flip: 99 | if random.random()>0.5: 100 | clip[0] = np.fliplr(clip[0]) 101 | else: 102 | clip[1] = np.fliplr(clip[1]) 103 | # return [np.fliplr(img) for img in clip] 104 | return clip 105 | 106 | return clip 107 | 108 | 109 | class RandomPerspective(object): 110 | def __init__(self, distortion_scale=0.1, p=0.5): 111 | self.distortion_scale = distortion_scale 112 | self.p = p 113 | 114 | def get_params(self, width, height, distortion_scale): 115 | 116 | half_height = height // 2 117 | half_width = width // 2 118 | topleft = [ 119 | random.randint(0, int(distortion_scale * half_width) + 1), 120 | random.randint(0, int(distortion_scale * half_height) + 1), 121 | ] 122 | topright = [ 123 | random.randint(width - int(distortion_scale * half_width) - 1, width), 124 | random.randint(0, int(distortion_scale * half_height) + 1), 125 | ] 126 | botright = [ 127 | random.randint(width - int(distortion_scale * half_width) - 1, width), 128 | random.randint(height - int(distortion_scale * half_height) - 1, height), 129 | ] 130 | botleft = [ 131 | random.randint(0, int(distortion_scale * half_width) + 1), 132 | random.randint(height - int(distortion_scale * half_height) - 1, height), 133 | ] 134 | startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] 135 | endpoints = [topleft, topright, botright, botleft] 136 | return startpoints, endpoints 137 | 138 | def get_perspective_mat(self, src_points, dst_points): 139 | src_points = 
140 |         dst_points = np.float32(dst_points)
141 | 
142 |         M = cv2.getPerspectiveTransform(src_points, dst_points)
143 | 
144 |         return M
145 | 
146 |     def __call__(self, clip):
147 |         if random.random() < self.p:
148 |             H, W, _ = clip[1].shape
149 |             startpoints, endpoints = self.get_params(W, H, self.distortion_scale)
150 |             M = self.get_perspective_mat(startpoints, endpoints)
151 |             # Warp only the driving frame; the source frame is left unchanged.
152 |             clip[1] = cv2.warpPerspective(clip[1], M, dsize=(W, H))  # cv2 expects dsize as (width, height)
153 |         return clip
154 | 
155 | 
156 | class RandomResize(object):
157 |     """Resizes a list of (H x W x C) numpy.ndarray to the final size
158 |     The larger the original image is, the longer it takes to
159 |     interpolate
160 |     Args:
161 |         interpolation (str): Can be one of 'nearest', 'bilinear',
162 |             defaults to 'nearest'
163 |         size (tuple): (width, height)
164 |     """
165 | 
166 |     def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
167 |         self.ratio = ratio
168 |         self.interpolation = interpolation
169 | 
170 |     def __call__(self, clip):
171 |         scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
172 | 
173 |         if isinstance(clip[0], np.ndarray):
174 |             im_h, im_w, im_c = clip[0].shape
175 |         elif isinstance(clip[0], PIL.Image.Image):
176 |             im_w, im_h = clip[0].size
177 | 
178 |         new_w = int(im_w * scaling_factor)
179 |         new_h = int(im_h * scaling_factor)
180 |         new_size = (new_w, new_h)
181 |         resized = resize_clip(
182 |             clip, new_size, interpolation=self.interpolation)
183 | 
184 |         return resized
185 | 
186 | 
187 | class RandomCrop(object):
188 |     """Extract random crop at the same location for a list of videos
189 |     Args:
190 |         size (sequence or int): Desired output size for the
191 |             crop in format (h, w)
192 |     """
193 | 
194 |     def __init__(self, size):
195 |         if isinstance(size, numbers.Number):
196 |             size = (size, size)
197 | 
198 |         self.size = size
199 | 
200 |     def __call__(self, clip):
201 |         """
202 |         Args:
203 |             clip (list of PIL.Image or numpy.ndarray): frames to be cropped,
204 |                 each in format (h, w, c) for numpy.ndarray
205 |         Returns:
206 |             PIL.Image or numpy.ndarray: Cropped list of frames
207 |         """
208 |         h, w = self.size
209 |         if isinstance(clip[0], np.ndarray):
210 |             im_h, im_w, im_c = clip[0].shape
211 |         elif isinstance(clip[0], PIL.Image.Image):
212 |             im_w, im_h = clip[0].size
213 |         else:
214 |             raise TypeError('Expected numpy.ndarray or PIL.Image ' +
215 |                             'but got list of {0}'.format(type(clip[0])))
216 | 
217 |         clip = pad_clip(clip, h, w)
218 |         im_h, im_w = clip.shape[1:3]
219 |         x1 = 0 if w == im_w else random.randint(0, im_w - w)
220 |         y1 = 0 if h == im_h else random.randint(0, im_h - h)
221 |         cropped = crop_clip(clip, y1, x1, h, w)
222 | 
223 |         return cropped
224 | 
225 | 
226 | class RandomRotation(object):
227 |     """Rotate entire clip randomly by a random angle within
228 |     given bounds
229 |     Args:
230 |         degrees (sequence or int): Range of degrees to select from
231 |             If degrees is a number instead of sequence like (min, max),
232 |             the range of degrees will be (-degrees, +degrees).
233 | """ 234 | 235 | def __init__(self, degrees): 236 | if isinstance(degrees, numbers.Number): 237 | if degrees < 0: 238 | raise ValueError('If degrees is a single number,' 239 | 'must be positive') 240 | degrees = (-degrees, degrees) 241 | else: 242 | if len(degrees) != 2: 243 | raise ValueError('If degrees is a sequence,' 244 | 'it must be of len 2.') 245 | 246 | self.degrees = degrees 247 | 248 | def __call__(self, clip): 249 | """ 250 | Args: 251 | img (PIL.Image or numpy.ndarray): List of videos to be cropped 252 | in format (h, w, c) in numpy.ndarray 253 | Returns: 254 | PIL.Image or numpy.ndarray: Cropped list of videos 255 | """ 256 | angle = random.uniform(self.degrees[0], self.degrees[1]) 257 | if isinstance(clip[0], np.ndarray): 258 | rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] 259 | elif isinstance(clip[0], PIL.Image.Image): 260 | rotated = [img.rotate(angle) for img in clip] 261 | else: 262 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 263 | 'but got list of {0}'.format(type(clip[0]))) 264 | 265 | return rotated 266 | 267 | 268 | class ColorJitter(object): 269 | """Randomly change the brightness, contrast and saturation and hue of the clip 270 | Args: 271 | brightness (float): How much to jitter brightness. brightness_factor 272 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 273 | contrast (float): How much to jitter contrast. contrast_factor 274 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 275 | saturation (float): How much to jitter saturation. saturation_factor 276 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 277 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 278 | [-hue, hue]. Should be >=0 and <= 0.5. 
279 | """ 280 | 281 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 282 | self.brightness = brightness 283 | self.contrast = contrast 284 | self.saturation = saturation 285 | self.hue = hue 286 | 287 | def get_params(self, brightness, contrast, saturation, hue): 288 | if brightness > 0: 289 | brightness_factor = random.uniform( 290 | max(0, 1 - brightness), 1 + brightness) 291 | else: 292 | brightness_factor = None 293 | 294 | if contrast > 0: 295 | contrast_factor = random.uniform( 296 | max(0, 1 - contrast), 1 + contrast) 297 | else: 298 | contrast_factor = None 299 | 300 | if saturation > 0: 301 | saturation_factor = random.uniform( 302 | max(0, 1 - saturation), 1 + saturation) 303 | else: 304 | saturation_factor = None 305 | 306 | if hue > 0: 307 | hue_factor = random.uniform(-hue, hue) 308 | else: 309 | hue_factor = None 310 | return brightness_factor, contrast_factor, saturation_factor, hue_factor 311 | 312 | def __call__(self, clip): 313 | """ 314 | Args: 315 | clip (list): list of PIL.Image 316 | Returns: 317 | list PIL.Image : list of transformed PIL.Image 318 | """ 319 | if isinstance(clip[0], np.ndarray): 320 | brightness, contrast, saturation, hue = self.get_params( 321 | self.brightness, self.contrast, self.saturation, self.hue) 322 | 323 | # Create img transform function sequence 324 | img_transforms = [] 325 | if brightness is not None: 326 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 327 | if saturation is not None: 328 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 329 | if hue is not None: 330 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 331 | if contrast is not None: 332 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 333 | random.shuffle(img_transforms) 334 | img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, 335 | img_as_float] 336 | 337 | with warnings.catch_warnings(): 338 | warnings.simplefilter("ignore") 339 | jittered_clip = [] 340 | for img in clip: 341 | jittered_img = img 342 | for func in img_transforms: 343 | jittered_img = func(jittered_img) 344 | jittered_clip.append(jittered_img.astype('float32')) 345 | elif isinstance(clip[0], PIL.Image.Image): 346 | brightness, contrast, saturation, hue = self.get_params( 347 | self.brightness, self.contrast, self.saturation, self.hue) 348 | 349 | # Create img transform function sequence 350 | img_transforms = [] 351 | if brightness is not None: 352 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 353 | if saturation is not None: 354 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 355 | if hue is not None: 356 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 357 | if contrast is not None: 358 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 359 | random.shuffle(img_transforms) 360 | 361 | # Apply to all videos 362 | jittered_clip = [] 363 | for img in clip: 364 | for func in img_transforms: 365 | jittered_img = func(img) 366 | jittered_clip.append(jittered_img) 367 | 368 | else: 369 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 370 | 'but got list of {0}'.format(type(clip[0]))) 371 | return jittered_clip 372 | 373 | 
374 | class AllAugmentationTransform:
375 |     def __init__(self, resize_param=None, rotation_param=None, flip_param=None, perspective_param=None, crop_param=None, jitter_param=None):
376 |         self.transforms = []
377 | 
378 |         if flip_param is not None:
379 |             self.transforms.append(RandomFlip(**flip_param))
380 | 
381 |         if perspective_param is not None:
382 |             self.transforms.append(RandomPerspective(**perspective_param))
383 | 
384 |         if rotation_param is not None:
385 |             self.transforms.append(RandomRotation(**rotation_param))
386 | 
387 |         if resize_param is not None:
388 |             self.transforms.append(RandomResize(**resize_param))
389 | 
390 |         if crop_param is not None:
391 |             self.transforms.append(RandomCrop(**crop_param))
392 | 
393 |         if jitter_param is not None:
394 |             self.transforms.append(ColorJitter(**jitter_param))
395 | 
396 |     def __call__(self, clip):
397 |         for t in self.transforms:
398 |             clip = t(clip)
399 |         return clip
400 | 
--------------------------------------------------------------------------------
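For reference, a minimal sketch of how these transforms are typically wired together. It assumes the dataset code forwards the config's augmentation parameters into AllAugmentationTransform and that a clip is a [source, driving] pair of float frames in [0, 1]; the parameter values below are placeholders, not the repository's actual settings.

import numpy as np
from augmentation import AllAugmentationTransform

# Placeholder augmentation parameters; the real values come from the training YAML config.
augmentation_params = {
    'perspective_param': {'distortion_scale': 0.1, 'p': 0.5},
    'jitter_param': {'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1},
}

transform = AllAugmentationTransform(**augmentation_params)

# A clip is a [source, driving] pair of H x W x 3 float frames in [0, 1].
clip = [np.random.rand(256, 256, 3).astype('float32') for _ in range(2)]
source, driving = transform(clip)
print(source.shape, driving.shape)  # (256, 256, 3) for each frame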