├── trainers ├── __init__.py └── face_trainer.py ├── third_part └── PerceptualSimilarity │ ├── util │ ├── __init__.py │ ├── html.py │ ├── visualizer.py │ └── util.py │ ├── models │ ├── __init__.py │ ├── models.py │ ├── base_model.py │ ├── pretrained_networks.py │ ├── networks_basic.py │ └── dist_model.py │ └── weights │ ├── v0.0 │ ├── vgg.pth │ ├── alex.pth │ └── squeeze.pth │ └── v0.1 │ ├── vgg.pth │ ├── alex.pth │ └── squeeze.pth ├── demo_images ├── expression.mat ├── id10010#Fi21gDronE4#001686#002204_00079.jpg ├── id10068#vqs1ZvoJ4XY#010464#010913_00138.jpg ├── id10071#A2bWI0qrkd8#001409#001652_00256.jpg ├── id10132#n-lR9XWz8g4#002588#002752_00104.jpg ├── id10137#LRQHZEBlcKo#001256#001586_00380.jpg ├── id10145#RlXVj2MlnnE#007232#007344_00052.jpg ├── id10318#nXC2Ne90r2M#002586#002690_00107.jpg ├── id10593#cY3-UkuZhJM#007714#007874_00186.jpg ├── id11180#i33msR57ahw#003483#003609_00096.jpg ├── id11244#8KU6g7NB-L0#011007#011129_00104.jpg └── id11248#QdBQTHX55yI#021933#022196_00274.jpg ├── .gitmodules ├── scripts ├── download_weights.sh ├── download_demo_dataset.sh ├── inference_options.py ├── extract_kp_videos.py ├── coeff_detector.py ├── face_recon_images.py ├── face_recon_videos.py └── prepare_vox_lmdb.py ├── util ├── cudnn.py ├── logging.py ├── flow_util.py ├── init_weight.py ├── distributed.py ├── io.py ├── lpips.py ├── trainer.py ├── meters.py └── misc.py ├── DatasetHelper.md ├── requirements.txt ├── config ├── face.yaml └── face_demo.yaml ├── data ├── __init__.py ├── image_dataset.py ├── vox_video_dataset.py └── vox_dataset.py ├── train.py ├── generators ├── face_model.py └── base_function.py ├── inference.py ├── intuitive_control.py ├── README.md └── config.py /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo_images/expression.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/expression.mat -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Deep3DFaceRecon_pytorch"] 2 | path = Deep3DFaceRecon_pytorch 3 | url = https://github.com/sicxu/Deep3DFaceRecon_pytorch 4 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.1/vgg.pth 
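The six `.pth` files above are the pretrained LPIPS backbones (VGG, AlexNet and SqueezeNet, in two weight versions) used by the vendored PerceptualSimilarity code for the perceptual metrics. A minimal, hedged sketch of how they are typically loaded through the bundled `DistModel` (this mirrors `util/lpips.py` further down; the random tensors are placeholders):

```python
import torch
from third_part.PerceptualSimilarity.models import dist_model as dm

# 'net-lin' selects the learned-weights LPIPS variant; 'alex' loads alex.pth.
model = dm.DistModel()
model.initialize(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())

# LPIPS expects images in [-1, 1]; a lower distance means more similar images.
img_a = torch.rand(1, 3, 256, 256) * 2 - 1
img_b = torch.rand(1, 3, 256, 256) * 2 - 1
distance = model.forward(img_a, img_b)
```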
-------------------------------------------------------------------------------- /demo_images/id10010#Fi21gDronE4#001686#002204_00079.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10010#Fi21gDronE4#001686#002204_00079.jpg -------------------------------------------------------------------------------- /demo_images/id10068#vqs1ZvoJ4XY#010464#010913_00138.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10068#vqs1ZvoJ4XY#010464#010913_00138.jpg -------------------------------------------------------------------------------- /demo_images/id10071#A2bWI0qrkd8#001409#001652_00256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10071#A2bWI0qrkd8#001409#001652_00256.jpg -------------------------------------------------------------------------------- /demo_images/id10132#n-lR9XWz8g4#002588#002752_00104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10132#n-lR9XWz8g4#002588#002752_00104.jpg -------------------------------------------------------------------------------- /demo_images/id10137#LRQHZEBlcKo#001256#001586_00380.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10137#LRQHZEBlcKo#001256#001586_00380.jpg -------------------------------------------------------------------------------- /demo_images/id10145#RlXVj2MlnnE#007232#007344_00052.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10145#RlXVj2MlnnE#007232#007344_00052.jpg -------------------------------------------------------------------------------- /demo_images/id10318#nXC2Ne90r2M#002586#002690_00107.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10318#nXC2Ne90r2M#002586#002690_00107.jpg -------------------------------------------------------------------------------- /demo_images/id10593#cY3-UkuZhJM#007714#007874_00186.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id10593#cY3-UkuZhJM#007714#007874_00186.jpg -------------------------------------------------------------------------------- /demo_images/id11180#i33msR57ahw#003483#003609_00096.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id11180#i33msR57ahw#003483#003609_00096.jpg -------------------------------------------------------------------------------- /demo_images/id11244#8KU6g7NB-L0#011007#011129_00104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id11244#8KU6g7NB-L0#011007#011129_00104.jpg -------------------------------------------------------------------------------- /demo_images/id11248#QdBQTHX55yI#021933#022196_00274.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/demo_images/id11248#QdBQTHX55yI#021933#022196_00274.jpg -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenYurui/PIRender/HEAD/third_part/PerceptualSimilarity/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /scripts/download_weights.sh: -------------------------------------------------------------------------------- 1 | gdown https://drive.google.com/uc?id=1-0xOf6g58OmtKtEWJlU3VlnfRqPN9Uq7 2 | unzip -x ./face.zip 3 | mkdir ./result 4 | mv face ./result 5 | rm face.zip 6 | -------------------------------------------------------------------------------- /scripts/download_demo_dataset.sh: -------------------------------------------------------------------------------- 1 | gdown https://drive.google.com/uc?id=1ruuLw5-0fpm6EREexPn3I_UQPmkrBoq9 2 | unzip -x ./vox_lmdb_demo.zip 3 | mkdir ./dataset 4 | mv vox_lmdb_demo ./dataset 5 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | def create_model(opt): 4 | model = None 5 | print(opt.model) 6 | from .dist_model import * 7 | model = DistModel() 8 | model.initialize(opt, opt.batchSize) 9 | print("model [%s] was created" % (model.name())) 10 | return model 11 | 12 | -------------------------------------------------------------------------------- /util/cudnn.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | 3 | from util.distributed import master_only_print as print 4 | 5 | 6 | def init_cudnn(deterministic, benchmark): 7 | r"""Initialize the cudnn module. The two things to consider are whether to 8 | use cudnn benchmark and whether to use cudnn deterministic. If cudnn 9 | benchmark is set, then cudnn deterministic is automatically false. 10 | 11 | Args: 12 | deterministic (bool): Whether to use cudnn deterministic. 13 | benchmark (bool): Whether to use cudnn benchmark.
14 | """ 15 | cudnn.deterministic = deterministic 16 | cudnn.benchmark = benchmark 17 | print('cudnn benchmark: {}'.format(benchmark)) 18 | print('cudnn deterministic: {}'.format(deterministic)) 19 | -------------------------------------------------------------------------------- /scripts/inference_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class InferenceOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 13 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 14 | 15 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 16 | parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') 17 | parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') 18 | parser.add_argument('--save_split_files', action='store_true', help='save split files or not') 19 | parser.add_argument('--inference_batch_size', type=int, default=8) 20 | 21 | # Dropout and Batchnorm has different behavior during training and test. 22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /DatasetHelper.md: -------------------------------------------------------------------------------- 1 | ### Extract 3DMM Coefficients for Videos 2 | 3 | We provide scripts for extracting 3dmm coefficients for videos by using [DeepFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch/tree/73d491102af6731bded9ae6b3cc7466c3b2e9e48). 4 | 5 | 1. Follow the instructions of their repo to build the environment of DeepFaceRecon. 6 | 7 | 2. Copy the provided scrips to the folder `Deep3DFaceRecon_pytorch`. 8 | 9 | ```bash 10 | cp scripts/face_recon_videos.py ./Deep3DFaceRecon_pytorch 11 | cp scripts/extract_kp_videos.py ./Deep3DFaceRecon_pytorch 12 | cp scripts/coeff_detector.py ./Deep3DFaceRecon_pytorch 13 | cp scripts/inference_options.py ./Deep3DFaceRecon_pytorch/options 14 | 15 | cd Deep3DFaceRecon_pytorch 16 | ``` 17 | 18 | 3. Extract facial landmarks from videos. 19 | 20 | ```bash 21 | python extract_kp_videos.py \ 22 | --input_dir path_to_viodes \ 23 | --output_dir path_to_keypoint \ 24 | --device_ids 0,1,2,3 \ 25 | --workers 12 26 | ``` 27 | 28 | 4. 
Extract coefficients for videos 29 | 30 | ```bash 31 | python face_recon_videos.py \ 32 | --input_dir path_to_videos \ 33 | --keypoint_dir path_to_keypoint \ 34 | --output_dir output_dir \ 35 | --inference_batch_size 100 \ 36 | --name=model_name \ 37 | --epoch=20 \ 38 | --model facerecon 39 | ``` 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | backcall==0.2.0 3 | cachetools==4.2.2 4 | certifi==2021.5.30 5 | charset-normalizer==2.0.6 6 | cycler==0.10.0 7 | dataclasses==0.8 8 | decorator==4.4.2 9 | filelock==3.0.12 10 | gdown==3.13.1 11 | google-auth==1.35.0 12 | google-auth-oauthlib==0.4.6 13 | grpcio==1.40.0 14 | idna==3.2 15 | imageio==2.9.0 16 | importlib-metadata==4.8.1 17 | ipython==7.16.1 18 | ipython-genutils==0.2.0 19 | jedi==0.18.0 20 | kiwisolver==1.3.1 21 | lmdb==1.2.1 22 | Markdown==3.3.4 23 | matplotlib==3.3.4 24 | mkl-fft==1.3.0 25 | mkl-random==1.1.1 26 | mkl-service==2.3.0 27 | networkx==2.5.1 28 | numpy==1.19.2 29 | oauthlib==3.1.1 30 | olefile==0.46 31 | opencv-python==4.5.3.56 32 | parso==0.8.2 33 | pexpect==4.8.0 34 | pickleshare==0.7.5 35 | Pillow==8.3.1 36 | pip==21.2.2 37 | prompt-toolkit==3.0.20 38 | protobuf==3.18.0 39 | ptyprocess==0.7.0 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.8 42 | Pygments==2.10.0 43 | pyparsing==2.4.7 44 | PySocks==1.7.1 45 | python-dateutil==2.8.2 46 | PyWavelets==1.1.1 47 | PyYAML==5.4.1 48 | requests==2.26.0 49 | requests-oauthlib==1.3.0 50 | rsa==4.7.2 51 | scikit-image==0.17.2 52 | scipy==1.5.4 53 | setuptools==58.0.4 54 | six==1.16.0 55 | tensorboard==2.6.0 56 | tensorboard-data-server==0.6.1 57 | tensorboard-plugin-wit==1.8.0 58 | tifffile==2020.9.3 59 | torch==1.7.1 60 | torchvision==0.8.2 61 | tqdm==4.62.2 62 | traitlets==4.3.3 63 | typing-extensions==3.10.0.2 64 | urllib3==1.26.6 65 | wcwidth==0.2.5 66 | Werkzeug==2.0.1 67 | wheel==0.37.0 68 | zipp==3.5.0 69 | -------------------------------------------------------------------------------- /util/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | 4 | from util.meters import set_summary_writer 5 | from util.distributed import master_only_print as print 6 | from util.distributed import master_only 7 | 8 | def get_date_uid(): 9 | """Generate a unique id based on date. 10 | Returns: 11 | str: Return uid string, e.g. '20171122171307111552'. 
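To make the output of the extraction steps above concrete: each processed frame ends up as a flat 3DMM coefficient vector on disk (identity 80, expression 64, texture 80, euler angles 3, gamma 27, translation 3, plus 5 cropping parameters; see `scripts/coeff_detector.py` below), and the renderer only consumes a 73-dim slice of it. A hedged sketch of reading one back, with a hypothetical file name:

```python
import numpy as np

coeff = np.loadtxt('demo_3dmm_coeff.txt').astype(np.float32)  # hypothetical file, 262 values
exp = coeff[80:144]           # 64 expression coefficients
angles = coeff[224:227]       # 3 euler angles (pose)
translation = coeff[254:257]  # 3 translation values
crop = coeff[259:262]         # last 3 of the 5 cropping parameters
semantic = np.concatenate([exp, angles, translation, crop])
print(semantic.shape)         # (73,) -- matches coeff_nc: 73 in the configs
```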
12 | """ 13 | return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S")) 14 | 15 | 16 | def init_logging(opt): 17 | date_uid = get_date_uid() 18 | if opt.name is not None: 19 | logdir = os.path.join(opt.checkpoints_dir, opt.name) 20 | else: 21 | logdir = os.path.join(opt.checkpoints_dir, date_uid) 22 | opt.logdir = logdir 23 | return date_uid, logdir 24 | 25 | @master_only 26 | def make_logging_dir(logdir, date_uid): 27 | r"""Create the logging directory 28 | 29 | Args: 30 | logdir (str): Log directory name 31 | """ 32 | 33 | 34 | print('Make folder {}'.format(logdir)) 35 | os.makedirs(logdir, exist_ok=True) 36 | tensorboard_dir = os.path.join(logdir, 'tensorboard') 37 | image_dir = os.path.join(logdir, 'image') 38 | eval_dir = os.path.join(logdir, 'evaluation') 39 | os.makedirs(tensorboard_dir, exist_ok=True) 40 | os.makedirs(image_dir, exist_ok=True) 41 | os.makedirs(eval_dir, exist_ok=True) 42 | 43 | set_summary_writer(tensorboard_dir) 44 | loss_log_name = os.path.join(logdir, 'loss_log.txt') 45 | with open(loss_log_name, "a") as log_file: 46 | log_file.write('================ Training Loss (%s) ================\n' % date_uid) 47 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True): 15 | self.use_gpu = use_gpu 16 | self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.Tensor 17 | # self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | # embed() 46 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 47 | save_path = os.path.join(self.save_dir, save_filename) 48 | print('Loading network from %s'%save_path) 49 | network.load_state_dict(torch.load(save_path)) 50 | 51 | def update_learning_rate(): 52 | pass 53 | 54 | def get_image_paths(self): 55 | return self.image_paths 56 | 57 | def save_done(self, flag=False): 58 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 59 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 60 | 61 | -------------------------------------------------------------------------------- /util/flow_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def convert_flow_to_deformation(flow): 4 | r"""convert flow fields to deformations. 
5 | 6 | Args: 7 | flow (tensor): Flow field obtained by the model 8 | Returns: 9 | deformation (tensor): The deformation used for warping 10 | """ 11 | b,c,h,w = flow.shape 12 | flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) 13 | grid = make_coordinate_grid(flow) 14 | deformation = grid + flow_norm.permute(0,2,3,1) 15 | return deformation 16 | 17 | def make_coordinate_grid(flow): 18 | r"""obtain coordinate grid with the same size as the flow field. 19 | 20 | Args: 21 | flow (tensor): Flow field obtained by the model 22 | Returns: 23 | grid (tensor): The grid with the same size as the input flow 24 | """ 25 | b,c,h,w = flow.shape 26 | 27 | x = torch.arange(w).to(flow) 28 | y = torch.arange(h).to(flow) 29 | 30 | x = (2 * (x / (w - 1)) - 1) 31 | y = (2 * (y / (h - 1)) - 1) 32 | 33 | yy = y.view(-1, 1).repeat(1, w) 34 | xx = x.view(1, -1).repeat(h, 1) 35 | 36 | meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) 37 | meshed = meshed.expand(b, -1, -1, -1) 38 | return meshed 39 | 40 | 41 | def warp_image(source_image, deformation): 42 | r"""warp the input image according to the deformation 43 | 44 | Args: 45 | source_image (tensor): source images to be warped 46 | deformation (tensor): deformations used to warp the images; value in range (-1, 1) 47 | Returns: 48 | output (tensor): the warped images 49 | """ 50 | _, h_old, w_old, _ = deformation.shape 51 | _, _, h, w = source_image.shape 52 | if h_old != h or w_old != w: 53 | deformation = deformation.permute(0, 3, 1, 2) 54 | deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') 55 | deformation = deformation.permute(0, 2, 3, 1) 56 | return torch.nn.functional.grid_sample(source_image, deformation) -------------------------------------------------------------------------------- /config/face.yaml: -------------------------------------------------------------------------------- 1 | # How often do you want to log the training stats. 2 | # network_list: 3 | # gen: gen_optimizer 4 | # dis: dis_optimizer 5 | 6 | distributed: True 7 | image_to_tensorboard: True 8 | snapshot_save_iter: 40000 9 | snapshot_save_epoch: 20 10 | snapshot_save_start_iter: 20000 11 | snapshot_save_start_epoch: 10 12 | image_save_iter: 1000 13 | max_epoch: 200 14 | logging_iter: 100 15 | results_dir: ./eval_results 16 | 17 | gen_optimizer: 18 | type: adam 19 | lr: 0.0001 20 | adam_beta1: 0.5 21 | adam_beta2: 0.999 22 | lr_policy: 23 | iteration_mode: True 24 | type: step 25 | step_size: 300000 26 | gamma: 0.2 27 | 28 | trainer: 29 | type: trainers.face_trainer::FaceTrainer 30 | pretrain_warp_iteration: 200000 31 | loss_weight: 32 | weight_perceptual_warp: 2.5 33 | weight_perceptual_final: 4 34 | vgg_param_warp: 35 | network: vgg19 36 | layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] 37 | use_style_loss: False 38 | num_scales: 4 39 | vgg_param_final: 40 | network: vgg19 41 | layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] 42 | use_style_loss: True 43 | num_scales: 4 44 | style_to_perceptual: 250 45 | init: 46 | type: 'normal' 47 | gain: 0.02 48 | gen: 49 | type: generators.face_model::FaceGenerator 50 | param: 51 | mapping_net: 52 | coeff_nc: 73 53 | descriptor_nc: 256 54 | layer: 3 55 | warpping_net: 56 | encoder_layer: 5 57 | decoder_layer: 3 58 | base_nc: 32 59 | editing_net: 60 | layer: 3 61 | num_res_blocks: 2 62 | base_nc: 64 63 | common: 64 | image_nc: 3 65 | descriptor_nc: 256 66 | max_nc: 256 67 | use_spect: False 68 | 69 | 70 | # Data options.
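A quick, hedged sanity check for the flow utilities in `util/flow_util.py` above: with a zero flow field, the deformation reduces to the plain coordinate grid and `warp_image` simply resamples the image onto itself.

```python
import torch
from util.flow_util import convert_flow_to_deformation, make_coordinate_grid, warp_image

flow = torch.zeros(1, 2, 64, 64)                 # (b, 2, h, w) pixel offsets
deformation = convert_flow_to_deformation(flow)  # (b, h, w, 2), values in [-1, 1]
print(torch.equal(deformation, make_coordinate_grid(flow)))  # True for zero flow

image = torch.rand(1, 3, 64, 64)
print(warp_image(image, deformation).shape)      # torch.Size([1, 3, 64, 64])
```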
71 | data: 72 | type: data.vox_dataset::VoxDataset 73 | path: ./dataset/vox_lmdb 74 | resolution: 256 75 | semantic_radius: 13 76 | train: 77 | batch_size: 5 78 | distributed: True 79 | val: 80 | batch_size: 8 81 | distributed: True 82 | 83 | 84 | -------------------------------------------------------------------------------- /config/face_demo.yaml: -------------------------------------------------------------------------------- 1 | # How often do you want to log the training stats. 2 | # network_list: 3 | # gen: gen_optimizer 4 | # dis: dis_optimizer 5 | 6 | distributed: True 7 | image_to_tensorboard: True 8 | snapshot_save_iter: 40000 9 | snapshot_save_epoch: 20 10 | snapshot_save_start_iter: 20000 11 | snapshot_save_start_epoch: 10 12 | image_save_iter: 1000 13 | max_epoch: 200 14 | logging_iter: 100 15 | results_dir: ./eval_results 16 | 17 | gen_optimizer: 18 | type: adam 19 | lr: 0.0001 20 | adam_beta1: 0.5 21 | adam_beta2: 0.999 22 | lr_policy: 23 | iteration_mode: True 24 | type: step 25 | step_size: 300000 26 | gamma: 0.2 27 | 28 | trainer: 29 | type: trainers.face_trainer::FaceTrainer 30 | pretrain_warp_iteration: 200000 31 | loss_weight: 32 | weight_perceptual_warp: 2.5 33 | weight_perceptual_final: 4 34 | vgg_param_warp: 35 | network: vgg19 36 | layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] 37 | use_style_loss: False 38 | num_scales: 4 39 | vgg_param_final: 40 | network: vgg19 41 | layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] 42 | use_style_loss: True 43 | num_scales: 4 44 | style_to_perceptual: 250 45 | init: 46 | type: 'normal' 47 | gain: 0.02 48 | gen: 49 | type: generators.face_model::FaceGenerator 50 | param: 51 | mapping_net: 52 | coeff_nc: 73 53 | descriptor_nc: 256 54 | layer: 3 55 | warpping_net: 56 | encoder_layer: 5 57 | decoder_layer: 3 58 | base_nc: 32 59 | editing_net: 60 | layer: 3 61 | num_res_blocks: 2 62 | base_nc: 64 63 | common: 64 | image_nc: 3 65 | descriptor_nc: 256 66 | max_nc: 256 67 | use_spect: False 68 | 69 | 70 | # Data options. 
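Both configs point at implementations through `module::Class` strings (`trainers.face_trainer::FaceTrainer`, `generators.face_model::FaceGenerator`, `data.vox_dataset::VoxDataset`). They are resolved with a small importlib lookup, as `data/__init__.py` and `util/trainer.py` below do; a minimal equivalent sketch:

```python
import importlib

def resolve(spec):
    # 'package.module::ClassName' -> the class object
    module_name, class_name = spec.split('::')
    return getattr(importlib.import_module(module_name), class_name)

VoxDataset = resolve('data.vox_dataset::VoxDataset')
```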
71 | data: 72 | type: data.vox_dataset::VoxDataset 73 | path: ./dataset/vox_lmdb_demo 74 | resolution: 256 75 | semantic_radius: 13 76 | train: 77 | batch_size: 5 78 | distributed: True 79 | val: 80 | batch_size: 8 81 | distributed: True 82 | 83 | 84 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | # self.img_dir = os.path.join(self.web_dir, ) 11 | self.img_subdir = image_subdir 12 | self.img_dir = os.path.join(self.web_dir, image_subdir) 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join(link)): 43 | img(style="width:%dpx" % width, src=os.path.join(im)) 44 | br() 45 | p(txt) 46 | 47 | def save(self,file='index'): 48 | html_file = '%s/%s.html' % (self.web_dir,file) 49 | f = open(html_file, 'wt') 50 | f.write(self.doc.render()) 51 | f.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | html = HTML('web/', 'test_html') 56 | html.add_header('hello world') 57 | 58 | ims = [] 59 | txts = [] 60 | links = [] 61 | for n in range(4): 62 | ims.append('image_%d.png' % n) 63 | txts.append('text_%d' % n) 64 | links.append('image_%d.png' % n) 65 | html.add_images(ims, txts, links) 66 | html.save() 67 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch.utils.data 4 | from util.distributed import master_only_print as print 5 | 6 | def find_dataset_using_name(dataset_name): 7 | dataset_filename = dataset_name 8 | module, target = dataset_name.split('::') 9 | datasetlib = importlib.import_module(module) 10 | dataset = None 11 | for name, cls in datasetlib.__dict__.items(): 12 | if name == target: 13 | dataset = cls 14 | 15 | if dataset is None: 16 | raise ValueError("In %s.py, there should be a class " 17 | "with class name that matches %s in lowercase." 
% 18 | (dataset_filename, target)) 19 | 20 | return dataset 21 | 22 | 23 | def get_option_setter(dataset_name): 24 | dataset_class = find_dataset_using_name(dataset_name) 25 | return dataset_class.modify_commandline_options 26 | 27 | 28 | def create_dataloader(opt, is_inference): 29 | dataset = find_dataset_using_name(opt.type) 30 | instance = dataset(opt, is_inference) 31 | phase = 'val' if is_inference else 'training' 32 | batch_size = opt.val.batch_size if is_inference else opt.train.batch_size 33 | print("%s dataset [%s] of size %d was created" % 34 | (phase, opt.type, len(instance))) 35 | dataloader = torch.utils.data.DataLoader( 36 | instance, 37 | batch_size=batch_size, 38 | sampler=data_sampler(instance, shuffle=not is_inference, distributed=opt.train.distributed), 39 | drop_last=not is_inference, 40 | num_workers=getattr(opt, 'num_workers', 0), 41 | ) 42 | 43 | return dataloader 44 | 45 | 46 | def data_sampler(dataset, shuffle, distributed): 47 | if distributed: 48 | return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) 49 | if shuffle: 50 | return torch.utils.data.RandomSampler(dataset) 51 | else: 52 | return torch.utils.data.SequentialSampler(dataset) 53 | 54 | 55 | def get_dataloader(opt, is_inference=False): 56 | dataset = create_dataloader(opt, is_inference=is_inference) 57 | return dataset 58 | 59 | 60 | def get_train_val_dataloader(opt): 61 | val_dataset = create_dataloader(opt, is_inference=True) 62 | train_dataset = create_dataloader(opt, is_inference=False) 63 | return val_dataset, train_dataset 64 | -------------------------------------------------------------------------------- /util/init_weight.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | 3 | 4 | def weights_init(init_type='normal', gain=0.02, bias=None): 5 | r"""Initialize weights in the network. 6 | 7 | Args: 8 | init_type (str): The name of the initialization scheme. 9 | gain (float): The parameter that is required for the initialization 10 | scheme. 11 | bias (object): If not ``None``, specifies the initialization parameter 12 | for bias. 13 | 14 | Returns: 15 | (obj): init function to be applied. 16 | """ 17 | 18 | def init_func(m): 19 | r"""Init function 20 | 21 | Args: 22 | m: module to be weight initialized. 
23 | """ 24 | class_name = m.__class__.__name__ 25 | if hasattr(m, 'weight') and ( 26 | class_name.find('Conv') != -1 or 27 | class_name.find('Linear') != -1 or 28 | class_name.find('Embedding') != -1): 29 | if init_type == 'normal': 30 | init.normal_(m.weight.data, 0.0, gain) 31 | elif init_type == 'xavier': 32 | init.xavier_normal_(m.weight.data, gain=gain) 33 | elif init_type == 'xavier_uniform': 34 | init.xavier_uniform_(m.weight.data, gain=1.0) 35 | elif init_type == 'kaiming': 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | elif init_type == 'orthogonal': 38 | init.orthogonal_(m.weight.data, gain=gain) 39 | elif init_type == 'none': 40 | m.reset_parameters() 41 | else: 42 | raise NotImplementedError( 43 | 'initialization method [%s] is ' 44 | 'not implemented' % init_type) 45 | if hasattr(m, 'bias') and m.bias is not None: 46 | if bias is not None: 47 | bias_type = getattr(bias, 'type', 'normal') 48 | if bias_type == 'normal': 49 | bias_gain = getattr(bias, 'gain', 0.5) 50 | init.normal_(m.bias.data, 0.0, bias_gain) 51 | else: 52 | raise NotImplementedError( 53 | 'initialization method [%s] is ' 54 | 'not implemented' % bias_type) 55 | else: 56 | init.constant_(m.bias.data, 0.0) 57 | return init_func 58 | -------------------------------------------------------------------------------- /util/distributed.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | def init_dist(local_rank, backend='nccl', **kwargs): 7 | r"""Initialize distributed training""" 8 | if dist.is_available(): 9 | if dist.is_initialized(): 10 | return torch.cuda.current_device() 11 | torch.cuda.set_device(local_rank) 12 | dist.init_process_group(backend=backend, init_method='env://', **kwargs) 13 | 14 | 15 | def get_rank(): 16 | r"""Get rank of the thread.""" 17 | rank = 0 18 | if dist.is_available(): 19 | if dist.is_initialized(): 20 | rank = dist.get_rank() 21 | return rank 22 | 23 | 24 | def get_world_size(): 25 | r"""Get world size. 
How many GPUs are available in this job.""" 26 | world_size = 1 27 | if dist.is_available(): 28 | if dist.is_initialized(): 29 | world_size = dist.get_world_size() 30 | return world_size 31 | 32 | 33 | def master_only(func): 34 | r"""Apply this function only to the master GPU.""" 35 | @functools.wraps(func) 36 | def wrapper(*args, **kwargs): 37 | r"""Simple function wrapper for the master function""" 38 | if get_rank() == 0: 39 | return func(*args, **kwargs) 40 | else: 41 | return None 42 | return wrapper 43 | 44 | 45 | def is_master(): 46 | r"""check if current process is the master""" 47 | return get_rank() == 0 48 | 49 | 50 | @master_only 51 | def master_only_print(*args): 52 | r"""master-only print""" 53 | print(*args) 54 | 55 | 56 | def dist_reduce_tensor(tensor): 57 | r""" Reduce to rank 0 """ 58 | world_size = get_world_size() 59 | if world_size < 2: 60 | return tensor 61 | with torch.no_grad(): 62 | dist.reduce(tensor, dst=0) 63 | if get_rank() == 0: 64 | tensor /= world_size 65 | return tensor 66 | 67 | 68 | def dist_all_reduce_tensor(tensor): 69 | r""" Reduce to all ranks """ 70 | world_size = get_world_size() 71 | if world_size < 2: 72 | return tensor 73 | with torch.no_grad(): 74 | dist.all_reduce(tensor) 75 | tensor.div_(world_size) 76 | return tensor 77 | 78 | 79 | def dist_all_gather_tensor(tensor): 80 | r""" gather to all ranks """ 81 | world_size = get_world_size() 82 | if world_size < 2: 83 | return [tensor] 84 | tensor_list = [ 85 | torch.ones_like(tensor) for _ in range(dist.get_world_size())] 86 | with torch.no_grad(): 87 | dist.all_gather(tensor_list, tensor) 88 | return tensor_list 89 | -------------------------------------------------------------------------------- /data/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | import torchvision.transforms.functional as F 9 | 10 | 11 | 12 | class ImageDataset(): 13 | def __init__(self, opt, input_name): 14 | self.opt = opt 15 | self.IMAGEEXT = ['png', 'jpg'] 16 | self.input_image_list, self.coeff_list = self.obtain_inputs(input_name) 17 | self.index = -1 18 | # load image dataset opt 19 | self.resolution = opt.resolution 20 | self.semantic_radius = opt.semantic_radius 21 | 22 | def next_image(self): 23 | self.index += 1 24 | image_name = self.input_image_list[self.index] 25 | coeff_name = self.coeff_list[self.index] 26 | img = Image.open(image_name) 27 | input_image = self.trans_image(img) 28 | 29 | coeff_3dmm = np.loadtxt(coeff_name).astype(np.float32) 30 | coeff_3dmm = self.transform_semantic(coeff_3dmm) 31 | 32 | return { 33 | 'source_image': input_image[None], 34 | 'target_semantics': coeff_3dmm[None], 35 | 'name': os.path.splitext(os.path.basename(image_name))[0] 36 | } 37 | 38 | def obtain_inputs(self, root): 39 | filenames = list() 40 | 41 | IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'} 42 | IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE}) 43 | extensions = IMAGE_EXTENSIONS 44 | 45 | for ext in extensions: 46 | filenames += glob.glob(f'{root}/*.{ext}', recursive=True) 47 | filenames = sorted(filenames) 48 | coeffnames = sorted(glob.glob(f'{root}/*_3dmm_coeff.txt')) 49 | 50 | return filenames, coeffnames 51 | 52 | def transform_semantic(self, semantic): 53 | semantic = semantic[None].repeat(self.semantic_radius*2+1, 0) 54 | ex_coeff = semantic[:,80:144] #expression 55 | angles = 
semantic[:,224:227] #euler angles for pose 56 | translation = semantic[:,254:257] #translation 57 | crop = semantic[:,259:262] #crop param 58 | 59 | coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1) 60 | return torch.Tensor(coeff_3dmm).permute(1,0) 61 | 62 | def trans_image(self, image): 63 | image = F.resize( 64 | image, size=self.resolution, interpolation=Image.BICUBIC) 65 | image = F.to_tensor(image) 66 | image = F.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 67 | return image 68 | 69 | def __len__(self): 70 | return len(self.input_image_list) 71 | 72 | 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import data as Dataset 4 | from config import Config 5 | from util.logging import init_logging, make_logging_dir 6 | from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer 7 | from util.distributed import init_dist 8 | from util.distributed import master_only_print as print 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description='Training') 13 | parser.add_argument('--config', default='./config/face.yaml') 14 | parser.add_argument('--name', default=None) 15 | parser.add_argument('--checkpoints_dir', default='result', 16 | help='Dir for saving logs and models.') 17 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 18 | parser.add_argument('--which_iter', type=int, default=None) 19 | parser.add_argument('--no_resume', action='store_true') 20 | parser.add_argument('--local_rank', type=int, default=0) 21 | parser.add_argument('--single_gpu', action='store_true') 22 | parser.add_argument('--debug', action='store_true') 23 | 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | if __name__ == '__main__': 29 | # get training options 30 | args = parse_args() 31 | set_random_seed(args.seed) 32 | opt = Config(args.config, args, is_train=True) 33 | 34 | if not args.single_gpu: 35 | opt.local_rank = args.local_rank 36 | init_dist(opt.local_rank) 37 | opt.device = opt.local_rank 38 | 39 | # create a visualizer 40 | date_uid, logdir = init_logging(opt) 41 | opt.logdir = logdir 42 | make_logging_dir(logdir, date_uid) 43 | # create a dataset 44 | val_dataset, train_dataset = Dataset.get_train_val_dataloader(opt.data) 45 | 46 | # create a model 47 | net_G, net_G_ema, opt_G, sch_G \ 48 | = get_model_optimizer_and_scheduler(opt) 49 | 50 | trainer = get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset) 51 | 52 | current_epoch, current_iteration = trainer.load_checkpoint(opt, args.which_iter) 53 | # training flag 54 | max_epoch = opt.max_epoch 55 | 56 | if args.debug: 57 | trainer.test_everything(train_dataset, val_dataset, current_epoch, current_iteration) 58 | exit() 59 | # Start training. 
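The `net_G_ema` created above is kept up to date with the `accumulate` helper from `util/trainer.py` (shown later): an exponential moving average of the generator weights. A hedged standalone sketch:

```python
import torch.nn as nn
from util.trainer import accumulate

net = nn.Linear(4, 4)
net_ema = nn.Linear(4, 4)
accumulate(net_ema, net, decay=0)      # decay=0 copies net into net_ema
accumulate(net_ema, net, decay=0.999)  # p_ema <- 0.999 * p_ema + 0.001 * p
```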
60 | for epoch in range(current_epoch, opt.max_epoch): 61 | print('Epoch {} ...'.format(epoch)) 62 | if not args.single_gpu: 63 | train_dataset.sampler.set_epoch(current_epoch) 64 | trainer.start_of_epoch(current_epoch) 65 | for it, data in enumerate(train_dataset): 66 | data = trainer.start_of_iteration(data, current_iteration) 67 | trainer.optimize_parameters(data) 68 | current_iteration += 1 69 | trainer.end_of_iteration(data, current_epoch, current_iteration) 70 | 71 | if current_iteration >= opt.max_iter: 72 | print('Done with training!!!') 73 | break 74 | current_epoch += 1 75 | trainer.end_of_epoch(data, val_dataset, current_epoch, current_iteration) 76 | -------------------------------------------------------------------------------- /scripts/extract_kp_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import glob 5 | import argparse 6 | import face_alignment 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from itertools import cycle 11 | 12 | from torch.multiprocessing import Pool, Process, set_start_method 13 | 14 | class KeypointExtractor(): 15 | def __init__(self): 16 | self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D) 17 | 18 | def extract_keypoint(self, images, name=None): 19 | if isinstance(images, list): 20 | keypoints = [] 21 | for image in images: 22 | current_kp = self.extract_keypoint(image) 23 | if np.mean(current_kp) == -1 and keypoints: 24 | keypoints.append(keypoints[-1]) 25 | else: 26 | keypoints.append(current_kp[None]) 27 | 28 | keypoints = np.concatenate(keypoints, 0) 29 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 30 | return keypoints 31 | else: 32 | while True: 33 | try: 34 | keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] 35 | break 36 | except RuntimeError as e: 37 | if str(e).startswith('CUDA'): 38 | print("Warning: out of memory, sleep for 1s") 39 | time.sleep(1) 40 | else: 41 | print(e) 42 | break 43 | except TypeError: 44 | print('No face detected in this image') 45 | shape = [68, 2] 46 | keypoints = -1. 
* np.ones(shape) 47 | break 48 | if name is not None: 49 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 50 | return keypoints 51 | 52 | def read_video(filename): 53 | frames = [] 54 | cap = cv2.VideoCapture(filename) 55 | while cap.isOpened(): 56 | ret, frame = cap.read() 57 | if ret: 58 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 59 | frame = Image.fromarray(frame) 60 | frames.append(frame) 61 | else: 62 | break 63 | cap.release() 64 | return frames 65 | 66 | def run(data): 67 | filename, opt, device = data 68 | os.environ['CUDA_VISIBLE_DEVICES'] = device 69 | kp_extractor = KeypointExtractor() 70 | images = read_video(filename) 71 | name = filename.split('/')[-2:] 72 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) 73 | kp_extractor.extract_keypoint( 74 | images, 75 | name=os.path.join(opt.output_dir, name[-2], name[-1]) 76 | ) 77 | 78 | if __name__ == '__main__': 79 | set_start_method('spawn') 80 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 81 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 82 | parser.add_argument('--output_dir', type=str, help='the folder of the output files') 83 | parser.add_argument('--device_ids', type=str, default='0,1') 84 | parser.add_argument('--workers', type=int, default=4) 85 | 86 | opt = parser.parse_args() 87 | filenames = list() 88 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 89 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 90 | extensions = VIDEO_EXTENSIONS 91 | for ext in extensions: 92 | filenames += sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}')) 93 | print('Total number of videos:', len(filenames)) 94 | pool = Pool(opt.workers) 95 | args_list = cycle([opt]) 96 | device_ids = opt.device_ids.split(",") 97 | device_ids = cycle(device_ids) 98 | for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): 99 | pass 100 | -------------------------------------------------------------------------------- /util/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | import torch.distributed as dist 5 | import torchvision.utils 6 | 7 | from util.distributed import is_master 8 | 9 | 10 | def save_pilimage_in_jpeg(fullname, output_img): 11 | r"""Save PIL Image to JPEG. 12 | 13 | Args: 14 | fullname (str): Full save path. 15 | output_img (PIL Image): Image to be saved. 16 | """ 17 | dirname = os.path.dirname(fullname) 18 | os.makedirs(dirname, exist_ok=True) 19 | output_img.save(fullname, 'JPEG', quality=99) 20 | 21 | 22 | def save_intermediate_training_results( 23 | visualization_images, logdir, current_epoch, current_iteration): 24 | r"""Save intermediate training results for debugging purposes. 25 | 26 | Args: 27 | visualization_images (tensor): Image where pixel values are in [-1, 1]. 28 | logdir (str): Where to save the image. 29 | current_epoch (int): Current training epoch. 30 | current_iteration (int): Current training iteration.
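A hedged single-image usage of the `KeypointExtractor` defined above (this is also how `scripts/coeff_detector.py` below calls it); the image path is a placeholder:

```python
from PIL import Image
from extract_kp_videos import KeypointExtractor

extractor = KeypointExtractor()
image = Image.open('demo_images/example.jpg')  # hypothetical path
landmarks = extractor.extract_keypoint(image)  # (68, 2) array, or -1s when no face is found
```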
31 | """ 32 | visualization_images = (visualization_images + 1) / 2 33 | output_filename = os.path.join( 34 | logdir, 'images', 35 | 'epoch_{:05}iteration{:09}.jpg'.format( 36 | current_epoch, current_iteration)) 37 | print('Save output images to {}'.format(output_filename)) 38 | os.makedirs(os.path.dirname(output_filename), exist_ok=True) 39 | image_grid = torchvision.utils.make_grid( 40 | visualization_images.data, nrow=1, padding=0, normalize=False) 41 | torchvision.utils.save_image(image_grid, output_filename, nrow=1) 42 | 43 | 44 | def download_file_from_google_drive(file_id, destination): 45 | r"""Download a file from the google drive by using the file ID. 46 | 47 | Args: 48 | file_id: Google drive file ID 49 | destination: Path to save the file. 50 | 51 | Returns: 52 | 53 | """ 54 | URL = "https://docs.google.com/uc?export=download" 55 | session = requests.Session() 56 | response = session.get(URL, params={'id': file_id}, stream=True) 57 | token = get_confirm_token(response) 58 | if token: 59 | params = {'id': file_id, 'confirm': token} 60 | response = session.get(URL, params=params, stream=True) 61 | save_response_content(response, destination) 62 | 63 | 64 | def get_confirm_token(response): 65 | r"""Get confirm token 66 | 67 | Args: 68 | response: Check if the file exists. 69 | 70 | Returns: 71 | 72 | """ 73 | for key, value in response.cookies.items(): 74 | if key.startswith('download_warning'): 75 | return value 76 | return None 77 | 78 | 79 | def save_response_content(response, destination): 80 | r"""Save response content 81 | 82 | Args: 83 | response: 84 | destination: Path to save the file. 85 | 86 | Returns: 87 | 88 | """ 89 | chunk_size = 32768 90 | with open(destination, "wb") as f: 91 | for chunk in response.iter_content(chunk_size): 92 | if chunk: 93 | f.write(chunk) 94 | 95 | 96 | def get_checkpoint(checkpoint_path, url=''): 97 | r"""Get the checkpoint path. If it does not exist yet, download it from 98 | the url. 99 | 100 | Args: 101 | checkpoint_path (str): Checkpoint path. 102 | url (str): URL to download checkpoint. 103 | Returns: 104 | (str): Full checkpoint path. 
105 | """ 106 | if 'TORCH_HOME' not in os.environ: 107 | os.environ['TORCH_HOME'] = os.getcwd() 108 | save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints') 109 | os.makedirs(save_dir, exist_ok=True) 110 | full_checkpoint_path = os.path.join(save_dir, checkpoint_path) 111 | if not os.path.exists(full_checkpoint_path): 112 | os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True) 113 | if is_master(): 114 | print('Download {}'.format(url)) 115 | download_file_from_google_drive(url, full_checkpoint_path) 116 | if dist.is_available() and dist.is_initialized(): 117 | dist.barrier() 118 | return full_checkpoint_path 119 | -------------------------------------------------------------------------------- /scripts/coeff_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from os import makedirs, name 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from options.inference_options import InferenceOptions 12 | from models import create_model 13 | from util.preprocess import align_img 14 | from util.load_mats import load_lm3d 15 | from extract_kp_videos import KeypointExtractor 16 | 17 | 18 | class CoeffDetector(nn.Module): 19 | def __init__(self, opt): 20 | super().__init__() 21 | 22 | self.model = create_model(opt) 23 | self.model.setup(opt) 24 | self.model.device = 'cuda' 25 | self.model.parallelize() 26 | self.model.eval() 27 | 28 | self.lm3d_std = load_lm3d(opt.bfm_folder) 29 | 30 | def forward(self, img, lm): 31 | 32 | img, trans_params = self.image_transform(img, lm) 33 | 34 | data_input = { 35 | 'imgs': img[None], 36 | } 37 | self.model.set_input(data_input) 38 | self.model.test() 39 | pred_coeff = {key:self.model.pred_coeffs_dict[key].cpu().numpy() for key in self.model.pred_coeffs_dict} 40 | pred_coeff = np.concatenate([ 41 | pred_coeff['id'], 42 | pred_coeff['exp'], 43 | pred_coeff['tex'], 44 | pred_coeff['angle'], 45 | pred_coeff['gamma'], 46 | pred_coeff['trans'], 47 | trans_params[None], 48 | ], 1) 49 | 50 | return {'coeff_3dmm':pred_coeff, 51 | 'crop_img': Image.fromarray((img.cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8))} 52 | 53 | def image_transform(self, images, lm): 54 | """ 55 | param: 56 | images: -- PIL image 57 | lm: -- numpy array 58 | """ 59 | W,H = images.size 60 | if np.mean(lm) == -1: 61 | lm = (self.lm3d_std[:, :2]+1)/2. 
62 | lm = np.concatenate( 63 | [lm[:, :1]*W, lm[:, 1:2]*H], 1 64 | ) 65 | else: 66 | lm[:, -1] = H - 1 - lm[:, -1] 67 | 68 | trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std) 69 | img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1) 70 | trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]) 71 | trans_params = torch.tensor(trans_params.astype(np.float32)) 72 | return img, trans_params 73 | 74 | def get_data_path(root, keypoint_root): 75 | filenames = list() 76 | keypoint_filenames = list() 77 | 78 | IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'} 79 | IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE}) 80 | extensions = IMAGE_EXTENSIONS 81 | 82 | for ext in extensions: 83 | filenames += glob.glob(f'{root}/*.{ext}', recursive=True) 84 | filenames = sorted(filenames) 85 | for filename in filenames: 86 | name = os.path.splitext(os.path.basename(filename))[0] 87 | keypoint_filenames.append( 88 | os.path.join(keypoint_root, name + '.txt') 89 | ) 90 | return filenames, keypoint_filenames 91 | 92 | 93 | if __name__ == "__main__": 94 | opt = InferenceOptions().parse() 95 | coeff_detector = CoeffDetector(opt) 96 | kp_extractor = KeypointExtractor() 97 | image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir) 98 | makedirs(opt.keypoint_dir, exist_ok=True) 99 | makedirs(opt.output_dir, exist_ok=True) 100 | 101 | for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names)): 102 | image = Image.open(image_name) 103 | if not os.path.isfile(keypoint_name): 104 | lm = kp_extractor.extract_keypoint(image, keypoint_name) 105 | else: 106 | lm = np.loadtxt(keypoint_name).astype(np.float32) 107 | lm = lm.reshape([-1, 2]) 108 | predicted = coeff_detector(image, lm) 109 | name = os.path.splitext(os.path.basename(image_name))[0] 110 | np.savetxt( 111 | "{}/{}_3dmm_coeff.txt".format(opt.output_dir, name), 112 | predicted['coeff_3dmm'].reshape(-1)) 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /util/lpips.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from imageio import imread 5 | 6 | import torch 7 | 8 | from third_part.PerceptualSimilarity.models import dist_model as dm 9 | 10 | def get_image_list(flist): 11 | if isinstance(flist, list): 12 | return flist 13 | 14 | # flist: image file path, image directory path, text file flist path 15 | if isinstance(flist, str): 16 | if os.path.isdir(flist): 17 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 18 | flist.sort() 19 | return flist 20 | 21 | if os.path.isfile(flist): 22 | try: 23 | return np.genfromtxt(flist, dtype=np.str) 24 | except: 25 | return [flist] 26 | print('can not read files from %s return empty list'%flist) 27 | return [] 28 | 29 | def preprocess_path_for_deform_task(gt_path, distorted_path): 30 | distorted_image_list = sorted(get_image_list(distorted_path)) 31 | gt_list=[] 32 | distorated_list=[] 33 | 34 | for distorted_image in distorted_image_list: 35 | image = os.path.basename(distorted_image) 36 | image = image.split('_2_')[-1] 37 | image = image.split('_vis')[0] +'.jpg' 38 | gt_image = os.path.join(gt_path, image) 39 | if not os.path.isfile(gt_image): 40 | gt_image = gt_image.replace('.jpg', '.png') 41 | gt_list.append(gt_image) 42 | distorated_list.append(distorted_image) 43 | return gt_list, 
distorated_list 44 | 45 | class LPIPS(): 46 | def __init__(self, use_gpu=True): 47 | self.model = dm.DistModel() 48 | self.model.initialize(model='net-lin', net='alex', use_gpu=use_gpu) 49 | self.use_gpu=use_gpu 50 | 51 | def __call__(self, image_1, image_2): 52 | """ 53 | image_1: images with size (n, 3, w, h) with value [-1, 1] 54 | image_2: images with size (n, 3, w, h) with value [-1, 1] 55 | """ 56 | result = self.model.forward(image_1, image_2) 57 | return result 58 | 59 | def calculate_from_disk(self, gt_path, distorted_path, batch_size=64, verbose=False, for_deformation=True): 60 | # if sort: 61 | if for_deformation: 62 | files_1, files_2 = preprocess_path_for_deform_task(gt_path, distorted_path) 63 | else: 64 | files_1 = sorted(get_image_list(gt_path)) 65 | files_2 = sorted(get_image_list(distorted_path)) 66 | 67 | new_files_1, new_files_2 = [], [] 68 | for item1,item2 in zip(files_1, files_2): 69 | if os.path.isfile(item1) and os.path.isfile(item2): 70 | new_files_1.append(item1) 71 | new_files_2.append(item2) 72 | else: 73 | print(item2) 74 | imgs_1 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in new_files_1]) 75 | imgs_2 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in new_files_2]) 76 | 77 | # Bring images to shape (B, 3, H, W) 78 | imgs_1 = imgs_1.transpose((0, 3, 1, 2)) 79 | imgs_2 = imgs_2.transpose((0, 3, 1, 2)) 80 | 81 | result=[] 82 | 83 | 84 | d0 = imgs_1.shape[0] 85 | if batch_size > d0: 86 | print(('Warning: batch size is bigger than the data size. ' 87 | 'Setting batch size to data size')) 88 | batch_size = d0 89 | 90 | n_batches = d0 // batch_size 91 | n_used_imgs = n_batches * batch_size 92 | 93 | # imgs_1_arr = np.empty((n_used_imgs, self.dims)) 94 | # imgs_2_arr = np.empty((n_used_imgs, self.dims)) 95 | for i in range(n_batches): 96 | if verbose: 97 | print('\rPropagating batch %d/%d' % (i + 1, n_batches)) 98 | # end='', flush=True) 99 | start = i * batch_size 100 | end = start + batch_size 101 | 102 | img_1_batch = torch.from_numpy(imgs_1[start:end]).type(torch.FloatTensor) 103 | img_2_batch = torch.from_numpy(imgs_2[start:end]).type(torch.FloatTensor) 104 | 105 | if self.use_gpu: 106 | img_1_batch = img_1_batch.cuda() 107 | img_2_batch = img_2_batch.cuda() 108 | 109 | 110 | result.append(self.model.forward(img_1_batch, img_2_batch)) 111 | 112 | 113 | distance = np.average(result) 114 | print('lpips: %.3f'%distance) 115 | return distance -------------------------------------------------------------------------------- /util/trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import importlib 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import Adam, lr_scheduler 8 | 9 | from util.distributed import master_only_print as print 10 | from util.init_weight import weights_init 11 | 12 | def accumulate(model1, model2, decay=0.999): 13 | par1 = dict(model1.named_parameters()) 14 | par2 = dict(model2.named_parameters()) 15 | 16 | for k in par1.keys(): 17 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 18 | 19 | def set_random_seed(seed): 20 | r"""Set random seeds for everything. 21 | 22 | Args: 23 | seed (int): Random seed. 
24 | 25 | """ 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | 32 | 33 | 34 | def get_trainer(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset): 35 | module, trainer_name = opt.trainer.type.split('::') 36 | 37 | trainer_lib = importlib.import_module(module) 38 | trainer_class = getattr(trainer_lib, trainer_name) 39 | trainer = trainer_class(opt, net_G, net_G_ema, opt_G, sch_G, train_dataset) 40 | return trainer 41 | 42 | def get_model_optimizer_and_scheduler(opt): 43 | gen_module, gen_network_name = opt.gen.type.split('::') 44 | lib = importlib.import_module(gen_module) 45 | network = getattr(lib, gen_network_name) 46 | net_G = network(**opt.gen.param).to(opt.device) 47 | init_bias = getattr(opt.trainer.init, 'bias', None) 48 | net_G.apply(weights_init( 49 | opt.trainer.init.type, opt.trainer.init.gain, init_bias)) 50 | 51 | net_G_ema = network(**opt.gen.param).to(opt.device) 52 | net_G_ema.eval() 53 | accumulate(net_G_ema, net_G, 0) 54 | print('net [{}] parameter count: {:,}'.format( 55 | 'net_G', _calculate_model_size(net_G))) 56 | print('Initialize net_G weights using ' 57 | 'type: {} gain: {}'.format(opt.trainer.init.type, 58 | opt.trainer.init.gain)) 59 | 60 | 61 | opt_G = get_optimizer(opt.gen_optimizer, net_G) 62 | 63 | if opt.distributed: 64 | net_G = nn.parallel.DistributedDataParallel( 65 | net_G, 66 | device_ids=[opt.local_rank], 67 | output_device=opt.local_rank, 68 | broadcast_buffers=False, 69 | find_unused_parameters=True, 70 | ) 71 | 72 | # Scheduler 73 | sch_G = get_scheduler(opt.gen_optimizer, opt_G) 74 | return net_G, net_G_ema, opt_G, sch_G 75 | 76 | 77 | def _calculate_model_size(model): 78 | r"""Calculate number of parameters in a PyTorch network. 79 | 80 | Args: 81 | model (obj): PyTorch network. 82 | 83 | Returns: 84 | (int): Number of parameters. 85 | """ 86 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 87 | 88 | 89 | def get_scheduler(opt_opt, opt): 90 | """Return the scheduler object. 91 | 92 | Args: 93 | opt_opt (obj): Config for the specific optimization module (gen/dis). 94 | opt (obj): PyTorch optimizer object. 95 | 96 | Returns: 97 | (obj): Scheduler 98 | """ 99 | if opt_opt.lr_policy.type == 'step': 100 | scheduler = lr_scheduler.StepLR( 101 | opt, 102 | step_size=opt_opt.lr_policy.step_size, 103 | gamma=opt_opt.lr_policy.gamma) 104 | elif opt_opt.lr_policy.type == 'constant': 105 | scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1) 106 | else: 107 | raise NotImplementedError('Learning rate policy {} not implemented.'. 108 | format(opt_opt.lr_policy.type)) 109 | return scheduler 110 | 111 | 112 | def get_optimizer(opt_opt, net): 113 | return get_optimizer_for_params(opt_opt, net.parameters()) 114 | 115 | 116 | def get_optimizer_for_params(opt_opt, params): 117 | r"""Return the optimizer object. 118 | 119 | Args: 120 | opt_opt (obj): Config for the specific optimization module (gen/dis). 121 | params (obj): Parameters to be trained by the optimizer. 122 | 123 | Returns: 124 | (obj): Optimizer 125 | """ 126 | # Only the plain (non-fused) Adam optimizer is supported for now.
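    # A hedged usage sketch (values mirror the gen_optimizer block of the
    # configs; `cfg` stands for a hypothetical parsed-config object):
    #   opt_G = get_optimizer(cfg.gen_optimizer, net_G)  # Adam, lr=1e-4, betas=(0.5, 0.999)
    #   sch_G = get_scheduler(cfg.gen_optimizer, opt_G)  # StepLR, step_size=300000, gamma=0.2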
127 | if opt_opt.type == 'adam': 128 | opt = Adam(params, 129 | lr=opt_opt.lr, 130 | betas=(opt_opt.adam_beta1, opt_opt.adam_beta2)) 131 | else: 132 | raise NotImplementedError( 133 | 'Optimizer {} is not yet implemented.'.format(opt_opt.type)) 134 | return opt 135 | 136 | 137 | -------------------------------------------------------------------------------- /generators/face_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from util import flow_util 9 | from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder 10 | 11 | class FaceGenerator(nn.Module): 12 | def __init__( 13 | self, 14 | mapping_net, 15 | warpping_net, 16 | editing_net, 17 | common 18 | ): 19 | super(FaceGenerator, self).__init__() 20 | self.mapping_net = MappingNet(**mapping_net) 21 | self.warpping_net = WarpingNet(**warpping_net, **common) 22 | self.editing_net = EditingNet(**editing_net, **common) 23 | 24 | def forward( 25 | self, 26 | input_image, 27 | driving_source, 28 | stage=None 29 | ): 30 | descriptor = self.mapping_net(driving_source) 31 | output = self.warpping_net(input_image, descriptor) 32 | if stage != 'warp': 33 | output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) 34 | return output 35 | 36 | 37 | 38 | 39 | class MappingNet(nn.Module): 40 | def __init__(self, coeff_nc, descriptor_nc, layer): 41 | super(MappingNet, self).__init__() 42 | 43 | self.layer = layer 44 | nonlinearity = nn.LeakyReLU(0.1) 45 | 46 | self.first = nn.Sequential( 47 | torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) 48 | 49 | for i in range(layer): 50 | net = nn.Sequential(nonlinearity, 51 | torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) 52 | setattr(self, 'encoder' + str(i), net) 53 | 54 | self.pooling = nn.AdaptiveAvgPool1d(1) 55 | self.output_nc = descriptor_nc 56 | 57 | def forward(self, input_3dmm): 58 | out = self.first(input_3dmm) 59 | for i in range(self.layer): 60 | model = getattr(self, 'encoder' + str(i)) 61 | out = model(out) + out[:,:,3:-3] # residual over the cropped input (each dilated conv shortens the sequence by 6) 62 | out = self.pooling(out) 63 | return out 64 | 65 | class WarpingNet(nn.Module): 66 | def __init__( 67 | self, 68 | image_nc, 69 | descriptor_nc, 70 | base_nc, 71 | max_nc, 72 | encoder_layer, 73 | decoder_layer, 74 | use_spect 75 | ): 76 | super(WarpingNet, self).__init__() 77 | 78 | nonlinearity = nn.LeakyReLU(0.1) 79 | norm_layer = functools.partial(LayerNorm2d, affine=True) 80 | kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} 81 | 82 | self.descriptor_nc = descriptor_nc 83 | self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, 84 | max_nc, encoder_layer, decoder_layer, **kwargs) 85 | 86 | self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), 87 | nonlinearity, 88 | nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) 89 | 90 | self.pool = nn.AdaptiveAvgPool2d(1) 91 | 92 | def forward(self, input_image, descriptor): 93 | final_output = {} 94 | output = self.hourglass(input_image, descriptor) 95 | final_output['flow_field'] = self.flow_out(output) 96 | 97 | deformation = flow_util.convert_flow_to_deformation(final_output['flow_field']) 98 | final_output['warp_image'] = flow_util.warp_image(input_image,
deformation) 99 | return final_output 100 | 101 | 102 | class EditingNet(nn.Module): 103 | def __init__( 104 | self, 105 | image_nc, 106 | descriptor_nc, 107 | layer, 108 | base_nc, 109 | max_nc, 110 | num_res_blocks, 111 | use_spect): 112 | super(EditingNet, self).__init__() 113 | 114 | nonlinearity = nn.LeakyReLU(0.1) 115 | norm_layer = functools.partial(LayerNorm2d, affine=True) 116 | kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} 117 | self.descriptor_nc = descriptor_nc 118 | 119 | # encoder part 120 | self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) 121 | self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) 122 | 123 | def forward(self, input_image, warp_image, descriptor): 124 | x = torch.cat([input_image, warp_image], 1) 125 | x = self.encoder(x) 126 | gen_image = self.decoder(x, descriptor) 127 | return gen_image 128 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import lmdb 4 | import math 5 | import argparse 6 | import numpy as np 7 | from io import BytesIO 8 | from PIL import Image 9 | 10 | import torch 11 | import torchvision.transforms.functional as F 12 | import torchvision.transforms as transforms 13 | 14 | from util.logging import init_logging, make_logging_dir 15 | from util.distributed import init_dist 16 | from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer 17 | from util.distributed import master_only_print as print 18 | from data.vox_video_dataset import VoxVideoDataset 19 | from config import Config 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Inference') 24 | parser.add_argument('--config', default='./config/face.yaml') 25 | parser.add_argument('--name', default=None) 26 | parser.add_argument('--checkpoints_dir', default='result', 27 | help='Dir for saving logs and models.') 28 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 29 | parser.add_argument('--cross_id', action='store_true') 30 | parser.add_argument('--which_iter', type=int, default=None) 31 | parser.add_argument('--no_resume', action='store_true') 32 | parser.add_argument('--local_rank', type=int, default=0) 33 | parser.add_argument('--single_gpu', action='store_true') 34 | parser.add_argument('--output_dir', type=str) 35 | 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | def write2video(results_dir, *video_list): 41 | cat_video = None 42 | 43 | for video in video_list: 44 | video_numpy = video[:,:3,:,:].cpu().float().detach().numpy() 45 | video_numpy = (np.transpose(video_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 46 | video_numpy = video_numpy.astype(np.uint8) 47 | cat_video = np.concatenate([cat_video, video_numpy], 2) if cat_video is not None else video_numpy 48 | 49 | image_array = [] 50 | for i in range(cat_video.shape[0]): 51 | image_array.append(cat_video[i]) 52 | 53 | out_name = results_dir+'.mp4' 54 | _, height, width, layers = cat_video.shape 55 | size = (width, height) 56 | out = cv2.VideoWriter(out_name, cv2.VideoWriter_fourcc(*'mp4v'), 15, size) 57 | 58 | for i in range(len(image_array)): 59 | out.write(image_array[i][:,:,::-1]) # RGB -> BGR for OpenCV 60 | out.release() 61 | 62 | if __name__ == '__main__': 63 | args = parse_args() 64 | set_random_seed(args.seed) 65 | opt = Config(args.config, args, is_train=False) 66 | 67 | if not
args.single_gpu: 68 | opt.local_rank = args.local_rank 69 | init_dist(opt.local_rank) 70 | opt.device = torch.cuda.current_device() 71 | # create the logging directory 72 | date_uid, logdir = init_logging(opt) 73 | opt.logdir = logdir 74 | make_logging_dir(logdir, date_uid) 75 | 76 | # create a model 77 | net_G, net_G_ema, opt_G, sch_G \ 78 | = get_model_optimizer_and_scheduler(opt) 79 | 80 | trainer = get_trainer(opt, net_G, net_G_ema, \ 81 | opt_G, sch_G, None) 82 | 83 | current_epoch, current_iteration = trainer.load_checkpoint( 84 | opt, args.which_iter) 85 | net_G = trainer.net_G_ema.eval() 86 | 87 | output_dir = os.path.join( 88 | args.output_dir, 89 | 'epoch_{:05}_iteration_{:09}'.format(current_epoch, current_iteration) 90 | ) 91 | os.makedirs(output_dir, exist_ok=True) 92 | opt.data.cross_id = args.cross_id 93 | dataset = VoxVideoDataset(opt.data, is_inference=True) 94 | with torch.no_grad(): 95 | for video_index in range(len(dataset)): 96 | data = dataset.load_next_video() 97 | input_source = data['source_image'][None].cuda() 98 | name = data['video_name'] 99 | 100 | output_images, gt_images, warp_images = [], [], [] 101 | for frame_index in range(len(data['target_semantics'])): 102 | target_semantic = data['target_semantics'][frame_index][None].cuda() 103 | output_dict = net_G(input_source, target_semantic) 104 | output_images.append( 105 | output_dict['fake_image'].cpu().clamp_(-1, 1) 106 | ) 107 | warp_images.append( 108 | output_dict['warp_image'].cpu().clamp_(-1, 1) 109 | ) 110 | gt_images.append( 111 | data['target_image'][frame_index][None] 112 | ) 113 | 114 | gen_images = torch.cat(output_images, 0) 115 | gt_images = torch.cat(gt_images, 0) 116 | warp_images = torch.cat(warp_images, 0) 117 | 118 | write2video("{}/{}".format(output_dir, name), gt_images, warp_images, gen_images) 119 | print("wrote results to video {}/{}.mp4".format(output_dir, name)) 120 | 121 | -------------------------------------------------------------------------------- /data/vox_video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import random 4 | import collections 5 | import numpy as np 6 | from PIL import Image 7 | from io import BytesIO 8 | 9 | import torch 10 | 11 | from data.vox_dataset import VoxDataset 12 | from data.vox_dataset import format_for_lmdb 13 | 14 | class VoxVideoDataset(VoxDataset): 15 | def __init__(self, opt, is_inference): 16 | super(VoxVideoDataset, self).__init__(opt, is_inference) 17 | self.video_index = -1 18 | self.cross_id = opt.cross_id 19 | # whether to normalize the crop parameters for cross_id reenactment; 20 | # setting it to True generally gives better results 21 | self.norm_crop_param = True 22 | 23 | def __len__(self): 24 | return len(self.video_items) 25 | 26 | def load_next_video(self): 27 | data = {} 28 | self.video_index += 1 29 | video_item = self.video_items[self.video_index] 30 | source_video_item = self.random_video(video_item) if self.cross_id else video_item 31 | 32 | with self.env.begin(write=False) as txn: 33 | key = format_for_lmdb(source_video_item['video_name'], 0) 34 | img_bytes_1 = txn.get(key) 35 | img1 = Image.open(BytesIO(img_bytes_1)) 36 | data['source_image'] = self.transform(img1) 37 | 38 | semantics_key = format_for_lmdb(video_item['video_name'], 'coeff_3dmm') 39 | semantics_numpy = np.frombuffer(txn.get(semantics_key), dtype=np.float32) 40 | semantics_numpy = semantics_numpy.reshape((video_item['num_frame'],-1)) 41 | if self.cross_id and self.norm_crop_param: 42
| semantics_source_key = format_for_lmdb(source_video_item['video_name'], 'coeff_3dmm') 43 | semantics_source_numpy = np.frombuffer(txn.get(semantics_source_key), dtype=np.float32) 44 | semantics_source_numpy = semantics_source_numpy.reshape((source_video_item['num_frame'],-1))[0:1] 45 | crop_norm_ratio = self.find_crop_norm_ratio(semantics_source_numpy, semantics_numpy) 46 | else: 47 | crop_norm_ratio = None 48 | 49 | data['target_image'], data['target_semantics'] = [], [] 50 | for frame_index in range(video_item['num_frame']): 51 | key = format_for_lmdb(video_item['video_name'], frame_index) 52 | img_bytes_1 = txn.get(key) 53 | img1 = Image.open(BytesIO(img_bytes_1)) 54 | data['target_image'].append(self.transform(img1)) 55 | data['target_semantics'].append( 56 | self.transform_semantic(semantics_numpy, frame_index, crop_norm_ratio) 57 | ) 58 | data['video_name'] = self.obtain_name(video_item['video_name'], source_video_item['video_name']) 59 | return data 60 | 61 | def random_video(self, target_video_item): 62 | target_person_id = target_video_item['person_id'] 63 | assert len(self.person_ids) > 1 64 | source_person_id = np.random.choice(self.person_ids) 65 | while source_person_id == target_person_id: 66 | source_person_id = np.random.choice(self.person_ids) 67 | source_video_index = np.random.choice(self.idx_by_person_id[source_person_id]) 68 | source_video_item = self.video_items[source_video_index] 69 | return source_video_item 70 | 71 | def find_crop_norm_ratio(self, source_coeff, target_coeffs): 72 | alpha = 0.3 73 | exp_diff = np.mean(np.abs(target_coeffs[:,80:144] - source_coeff[:,80:144]), 1) 74 | angle_diff = np.mean(np.abs(target_coeffs[:,224:227] - source_coeff[:,224:227]), 1) 75 | index = np.argmin(alpha*exp_diff + (1-alpha)*angle_diff) # target frame closest to the source in expression and pose 76 | crop_norm_ratio = source_coeff[:,-3] / target_coeffs[index:index+1, -3] # ratio of the crop scales at that frame 77 | return crop_norm_ratio 78 | 79 | def transform_semantic(self, semantic, frame_index, crop_norm_ratio): 80 | index = self.obtain_seq_index(frame_index, semantic.shape[0]) 81 | coeff_3dmm = semantic[index,...] 82 | # id_coeff = coeff_3dmm[:,:80] #identity 83 | ex_coeff = coeff_3dmm[:,80:144] #expression 84 | # tex_coeff = coeff_3dmm[:,144:224] #texture 85 | angles = coeff_3dmm[:,224:227] #euler angles for pose 86 | # gamma = coeff_3dmm[:,227:254] #lighting 87 | translation = coeff_3dmm[:,254:257] #translation 88 | crop = coeff_3dmm[:,257:300] #crop param 89 | 90 | if self.cross_id and self.norm_crop_param: 91 | crop[:, -3] = crop[:, -3] * crop_norm_ratio 92 | 93 | coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1) 94 | return torch.Tensor(coeff_3dmm).permute(1,0) 95 | 96 | def obtain_name(self, target_name, source_name): 97 | if not self.cross_id: 98 | return target_name 99 | else: 100 | source_name = os.path.splitext(os.path.basename(source_name))[0] 101 | return source_name+'_to_'+target_name -------------------------------------------------------------------------------- /util/meters.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | from torch.utils.tensorboard.summary import hparams 6 | 7 | 8 | from util.distributed import master_only 9 | from util.distributed import master_only_print as print 10 | 11 | LOG_WRITER = None 12 | LOG_DIR = None 13 | 14 | 15 | @torch.no_grad() 16 | def sn_reshape_weight_to_matrix(weight): 17 | r"""Reshape weight to obtain the matrix form.
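The tensor is flattened to shape (weight.size(0), -1), i.e. the 2-D matrix on which the spectral-norm power iteration in get_weight_stats operates via weight_u and weight_v.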
18 | 19 | Args: 20 | weight (Parameters): pytorch layer parameter tensor. 21 | """ 22 | weight_mat = weight 23 | height = weight_mat.size(0) 24 | return weight_mat.reshape(height, -1) 25 | 26 | 27 | @torch.no_grad() 28 | def get_weight_stats(mod, cfg, loss_id): 29 | r"""Get weight statistics. 30 | 31 | Args: 32 | mod: Pytorch module 33 | cfg: Configuration object 34 | loss_id: Needed when using AMP. 35 | """ 36 | loss_scale = 1.0 37 | if cfg.trainer.amp == 'O1' or cfg.trainer.amp == 'O2': 38 | from apex import amp # AMP rescales the gradient so we have to undo it; apex is only needed on this path. 39 | loss_scale = amp._amp_state.loss_scalers[loss_id].loss_scale() 40 | if mod.weight_orig.grad is not None: 41 | grad_norm = mod.weight_orig.grad.data.norm().item() / float(loss_scale) 42 | else: 43 | grad_norm = 0. 44 | weight_norm = mod.weight_orig.data.norm().item() 45 | weight_mat = sn_reshape_weight_to_matrix(mod.weight_orig) 46 | sigma = torch.sum(mod.weight_u * torch.mv(weight_mat, mod.weight_v)) 47 | return grad_norm, weight_norm, sigma 48 | 49 | 50 | @master_only 51 | def set_summary_writer(log_dir): 52 | r"""Set summary writer 53 | 54 | Args: 55 | log_dir (str): Log directory. 56 | """ 57 | global LOG_DIR, LOG_WRITER 58 | LOG_DIR = log_dir 59 | LOG_WRITER = SummaryWriter(log_dir=log_dir) 60 | 61 | 62 | @master_only 63 | def write_summary(name, summary, step, hist=False): 64 | """Utility function for writing a summary to the log writer. 65 | """ 66 | global LOG_WRITER 67 | lw = LOG_WRITER 68 | if lw is None: 69 | raise Exception("Log writer not set.") 70 | if hist: 71 | lw.add_histogram(name, summary, step) 72 | else: 73 | lw.add_scalar(name, summary, step) 74 | 75 | 76 | @master_only 77 | def add_hparams(hparam_dict=None, metric_dict=None): 78 | r"""Add a set of hyperparameters to be compared in tensorboard. 79 | 80 | Args: 81 | hparam_dict (dictionary): Each key-value pair in the dictionary is the 82 | name of the hyper parameter and its corresponding value. 83 | The type of the value can be one of `bool`, `string`, `float`, 84 | `int`, or `None`. 85 | metric_dict (dictionary): Each key-value pair in the dictionary is the 86 | name of the metric and its corresponding value. Note that the key 87 | used here should be unique in the tensorboard record. Otherwise the 88 | value you added by `add_scalar` will be displayed in the hparam plugin. 89 | In most cases, this is unwanted. 90 | """ 91 | if type(hparam_dict) is not dict or type(metric_dict) is not dict: 92 | raise TypeError('hparam_dict and metric_dict should be dictionaries.') 93 | global LOG_WRITER 94 | lw = LOG_WRITER 95 | 96 | exp, ssi, sei = hparams(hparam_dict, metric_dict) 97 | 98 | lw.file_writer.add_summary(exp) 99 | lw.file_writer.add_summary(ssi) 100 | lw.file_writer.add_summary(sei) 101 | 102 | 103 | class Meter(object): 104 | """Meter keeps track of statistics over steps. 105 | Meters record values for purposes like printing average values. 106 | Meters can be flushed to log files (i.e. TensorBoard for now) 107 | regularly. 108 | 109 | Args: 110 | name (str): the name of meter 111 | """ 112 | 113 | @master_only 114 | def __init__(self, name): 115 | self.name = name 116 | self.values = [] 117 | 118 | @master_only 119 | def reset(self): 120 | r"""Reset the meter values""" 121 | self.values = [] 122 | 123 | @master_only 124 | def write(self, value): 125 | r"""Record the value""" 126 | self.values.append(value) 127 | 128 | @master_only 129 | def flush(self, step): 130 | r"""Write the average of the recorded values to the tensorboard. 131 | 132 | Args: 133 | step (int): Epoch or iteration number.
134 | """ 135 | if not all(math.isfinite(x) for x in self.values): 136 | print("meter {} contained a nan or inf.".format(self.name)) 137 | filtered_values = [x for x in self.values if math.isfinite(x)] 138 | if len(filtered_values) != 0: 139 | value = float(sum(filtered_values)) / float(len(filtered_values)) 140 | write_summary(self.name, value, step) 141 | self.reset() 142 | 143 | @master_only 144 | def write_image(self, img_grid, step): 145 | r"""Write an image grid to the tensorboard. 146 | 147 | Args: 148 | img_grid (tensor): Image grid to visualize. 149 | step (int): Epoch or iteration number. 150 | """ 151 | global LOG_WRITER 152 | lw = LOG_WRITER 153 | if lw is None: 154 | raise Exception("Log writer not set.") 155 | lw.add_image("Visualizations", img_grid, step) 156 | -------------------------------------------------------------------------------- /data/vox_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import random 4 | import collections 5 | import numpy as np 6 | from PIL import Image 7 | from io import BytesIO 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | 13 | def format_for_lmdb(*args): 14 | key_parts = [] 15 | for arg in args: 16 | if isinstance(arg, int): 17 | arg = str(arg).zfill(7) 18 | key_parts.append(arg) 19 | return '-'.join(key_parts).encode('utf-8') 20 | 21 | class VoxDataset(Dataset): 22 | def __init__(self, opt, is_inference): 23 | path = opt.path 24 | self.env = lmdb.open( 25 | os.path.join(path, str(opt.resolution)), 26 | max_readers=32, 27 | readonly=True, 28 | lock=False, 29 | readahead=False, 30 | meminit=False, 31 | ) 32 | 33 | if not self.env: 34 | raise IOError('Cannot open lmdb dataset', path) 35 | list_file = "test_list.txt" if is_inference else "train_list.txt" 36 | list_file = os.path.join(path, list_file) 37 | with open(list_file, 'r') as f: 38 | lines = f.readlines() 39 | videos = [line.replace('\n', '') for line in lines] 40 | 41 | self.resolution = opt.resolution 42 | self.semantic_radius = opt.semantic_radius 43 | self.video_items, self.person_ids = self.get_video_index(videos) 44 | self.idx_by_person_id = self.group_by_key(self.video_items, key='person_id') 45 | self.person_ids = self.person_ids * 100 # repeat the id list so one epoch samples each person many times 46 | 47 | self.transform = transforms.Compose( 48 | [ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 51 | ]) 52 | 53 | def get_video_index(self, videos): 54 | video_items = [] 55 | for video in videos: 56 | video_items.append(self.Video_Item(video)) 57 | 58 | person_ids = sorted(list({video.split('#')[0] for video in videos})) 59 | 60 | return video_items, person_ids 61 | 62 | def group_by_key(self, video_list, key): 63 | return_dict = collections.defaultdict(list) 64 | for index, video_item in enumerate(video_list): 65 | return_dict[video_item[key]].append(index) 66 | return return_dict 67 | 68 | def Video_Item(self, video_name): 69 | video_item = {} 70 | video_item['video_name'] = video_name 71 | video_item['person_id'] = video_name.split('#')[0] 72 | with self.env.begin(write=False) as txn: 73 | key = format_for_lmdb(video_item['video_name'], 'length') 74 | length = int(txn.get(key).decode('utf-8')) 75 | video_item['num_frame'] = length 76 | 77 | return video_item 78 | 79 | def __len__(self): 80 | return len(self.person_ids) 81 | 82 | def __getitem__(self, index): 83 | data = {} 84 | person_id = self.person_ids[index] 85 | video_item =
self.video_items[random.choices(self.idx_by_person_id[person_id], k=1)[0]] 86 | frame_source, frame_target = self.random_select_frames(video_item) 87 | 88 | with self.env.begin(write=False) as txn: 89 | key = format_for_lmdb(video_item['video_name'], frame_source) 90 | img_bytes_1 = txn.get(key) 91 | key = format_for_lmdb(video_item['video_name'], frame_target) 92 | img_bytes_2 = txn.get(key) 93 | semantics_key = format_for_lmdb(video_item['video_name'], 'coeff_3dmm') 94 | semantics_numpy = np.frombuffer(txn.get(semantics_key), dtype=np.float32) 95 | semantics_numpy = semantics_numpy.reshape((video_item['num_frame'],-1)) 96 | 97 | img1 = Image.open(BytesIO(img_bytes_1)) 98 | data['source_image'] = self.transform(img1) 99 | 100 | img2 = Image.open(BytesIO(img_bytes_2)) 101 | data['target_image'] = self.transform(img2) 102 | 103 | data['target_semantics'] = self.transform_semantic(semantics_numpy, frame_target) 104 | data['source_semantics'] = self.transform_semantic(semantics_numpy, frame_source) 105 | 106 | return data 107 | 108 | def random_select_frames(self, video_item): 109 | num_frame = video_item['num_frame'] 110 | frame_idx = random.choices(list(range(num_frame)), k=2) 111 | return frame_idx[0], frame_idx[1] 112 | 113 | def transform_semantic(self, semantic, frame_index): 114 | index = self.obtain_seq_index(frame_index, semantic.shape[0]) 115 | coeff_3dmm = semantic[index,...] 116 | # id_coeff = coeff_3dmm[:,:80] #identity 117 | ex_coeff = coeff_3dmm[:,80:144] #expression 118 | # tex_coeff = coeff_3dmm[:,144:224] #texture 119 | angles = coeff_3dmm[:,224:227] #euler angles for pose 120 | # gamma = coeff_3dmm[:,227:254] #lighting 121 | translation = coeff_3dmm[:,254:257] #translation 122 | crop = coeff_3dmm[:,257:260] #crop param 123 | 124 | coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1) 125 | return torch.Tensor(coeff_3dmm).permute(1,0) 126 | 127 | def obtain_seq_index(self, index, num_frames): 128 | seq = list(range(index-self.semantic_radius, index+self.semantic_radius+1)) 129 | seq = [ min(max(item, 0), num_frames-1) for item in seq ] 130 | return seq 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /scripts/face_recon_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from scipy.io import savemat 7 | 8 | import torch 9 | 10 | from models import create_model 11 | from options.inference_options import InferenceOptions 12 | from util.preprocess import align_img 13 | from util.load_mats import load_lm3d 14 | from util.util import tensor2im, save_image 15 | 16 | 17 | def get_data_path(root, keypoint_root): 18 | filenames = list() 19 | keypoint_filenames = list() 20 | 21 | IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'} 22 | IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE}) 23 | extensions = IMAGE_EXTENSIONS 24 | 25 | for ext in extensions: 26 | filenames += glob.glob(f'{root}/*.{ext}', recursive=True) 27 | filenames = sorted(filenames) 28 | for filename in filenames: 29 | name = os.path.splitext(os.path.basename(filename))[0] 30 | keypoint_filenames.append( 31 | os.path.join(keypoint_root, name + '.txt') 32 | ) 33 | return filenames, keypoint_filenames 34 | 35 | 36 | class ImagePathDataset(torch.utils.data.Dataset): 37 | def __init__(self, filenames, txt_filenames, bfm_folder): 38 | self.filenames 
= filenames 39 | self.txt_filenames = txt_filenames 40 | self.lm3d_std = load_lm3d(bfm_folder) 41 | 42 | def __len__(self): 43 | return len(self.filenames) 44 | 45 | def __getitem__(self, i): 46 | filename = self.filenames[i] 47 | txt_filename = self.txt_filenames[i] 48 | imgs, _, trans_params = self.read_data(filename, txt_filename) 49 | return { 50 | 'imgs':imgs, 51 | 'trans_param':trans_params, 52 | 'filename': filename 53 | } 54 | 55 | def image_transform(self, images, lm): 56 | W,H = images.size 57 | if np.mean(lm) == -1: 58 | lm = (self.lm3d_std[:, :2]+1)/2. 59 | lm = np.concatenate( 60 | [lm[:, :1]*W, lm[:, 1:2]*H], 1 61 | ) 62 | else: 63 | lm[:, -1] = H - 1 - lm[:, -1] 64 | 65 | trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std) 66 | img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1) 67 | lm = torch.tensor(lm) 68 | trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]) 69 | trans_params = torch.tensor(trans_params.astype(np.float32)) 70 | return img, lm, trans_params 71 | 72 | def read_data(self, filename, txt_filename): 73 | images = Image.open(filename).convert('RGB') 74 | lm = np.loadtxt(txt_filename).astype(np.float32) 75 | lm = lm.reshape([-1, 2]) 76 | imgs, lms, trans_params = self.image_transform(images, lm) 77 | return imgs, lms, trans_params 78 | 79 | 80 | def main(opt, model): 81 | import torch.multiprocessing 82 | torch.multiprocessing.set_sharing_strategy('file_system') # avoids running out of file descriptors with many workers 83 | filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir) 84 | 85 | dataset = ImagePathDataset(filenames, keypoint_filenames, opt.bfm_folder) 86 | dataloader = torch.utils.data.DataLoader( 87 | dataset, 88 | batch_size=opt.inference_batch_size, 89 | shuffle=False, 90 | drop_last=False, 91 | num_workers=8, 92 | ) 93 | pred_coeffs, pred_trans_params = [], [] 94 | print('number of images:', len(dataset)) 95 | for data in tqdm(dataloader): 96 | data_input = { 97 | 'imgs': data['imgs'], 98 | } 99 | 100 | model.set_input(data_input) 101 | model.test() 102 | pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict} 103 | pred_coeff = np.concatenate([ 104 | pred_coeff['id'], 105 | pred_coeff['exp'], 106 | pred_coeff['tex'], 107 | pred_coeff['angle'], 108 | pred_coeff['gamma'], 109 | pred_coeff['trans']], 1) 110 | pred_coeffs.append(pred_coeff) 111 | trans_param = data['trans_param'].cpu().numpy() 112 | pred_trans_params.append(trans_param) 113 | if opt.save_split_files: 114 | for index, filename in enumerate(data['filename']): 115 | basename = os.path.splitext(os.path.basename(filename))[0] 116 | output_path = os.path.join(opt.output_dir, basename+'.mat') 117 | savemat( 118 | output_path, 119 | {'coeff':pred_coeff[index], 120 | 'transform_params':trans_param[index]} 121 | ) 122 | # visuals = model.get_current_visuals() # get image results 123 | # for name in visuals: 124 | # images = visuals[name] 125 | # for i in range(images.shape[0]): 126 | # image_numpy = tensor2im(images[i]) 127 | # save_image(image_numpy, os.path.basename(data['filename'][i])+'.png') 128 | 129 | pred_coeffs = np.concatenate(pred_coeffs, 0) 130 | pred_trans_params = np.concatenate(pred_trans_params, 0) 131 | savemat(os.path.join(opt.output_dir, 'ffhq.mat'), {'coeff':pred_coeffs, 'transform_params':pred_trans_params}) 132 | 133 | 134 | if __name__ == '__main__': 135 | opt = InferenceOptions().parse() # get test options 136 | model = create_model(opt) 137 | model.setup(opt) 138 | model.device =
'cuda:0' 139 | model.parallelize() 140 | model.eval() 141 | 142 | main(opt, model) 143 | 144 | 145 | -------------------------------------------------------------------------------- /intuitive_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import numpy as np 5 | from scipy.io import savemat,loadmat 6 | 7 | import torch 8 | import torchvision.transforms.functional as F 9 | import torchvision.transforms as transforms 10 | 11 | from config import Config 12 | from util.logging import init_logging, make_logging_dir 13 | from util.distributed import init_dist 14 | from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer 15 | from util.distributed import master_only_print as print 16 | from data.image_dataset import ImageDataset 17 | from inference import write2video 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Intuitive control') 22 | parser.add_argument('--config', default='./config/face.yaml') 23 | parser.add_argument('--name', default=None) 24 | parser.add_argument('--checkpoints_dir', default='result', 25 | help='Dir for saving logs and models.') 26 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 27 | parser.add_argument('--which_iter', type=int, default=None) 28 | parser.add_argument('--no_resume', action='store_true') 29 | parser.add_argument('--input_name', type=str) 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | parser.add_argument('--single_gpu', action='store_true') 32 | parser.add_argument('--output_dir', type=str) 33 | 34 | args = parser.parse_args() 35 | return args 36 | 37 | def get_control(input_name): 38 | control_dict = {} 39 | control_dict['rotation_center'] = torch.tensor([0,0,0,0,0,0.45]) 40 | control_dict['rotation_left_x'] = torch.tensor([0,0,math.pi/10,0,0,0.45]) 41 | control_dict['rotation_right_x'] = torch.tensor([0,0,-math.pi/10,0,0,0.45]) 42 | 43 | control_dict['rotation_left_y'] = torch.tensor([math.pi/10,0,0,0,0,0.45]) 44 | control_dict['rotation_right_y'] = torch.tensor([-math.pi/10,0,0,0,0,0.45]) 45 | 46 | control_dict['rotation_left_z'] = torch.tensor([0,math.pi/8,0,0,0,0.45]) 47 | control_dict['rotation_right_z'] = torch.tensor([0,-math.pi/8,0,0,0,0.45]) 48 | 49 | expression = loadmat('{}/expression.mat'.format(input_name)) 50 | 51 | for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']: 52 | control_dict[item] = torch.tensor(expression[item])[0] 53 | 54 | sort_rot_control = [ 55 | 'rotation_left_x', 'rotation_center', 56 | 'rotation_right_x', 'rotation_center', 57 | 'rotation_left_y', 'rotation_center', 58 | 'rotation_right_y', 'rotation_center', 59 | 'rotation_left_z', 'rotation_center', 60 | 'rotation_right_z', 'rotation_center' 61 | ] 62 | 63 | sort_exp_control = [ 64 | 'expression_center', 'expression_mouth', 65 | 'expression_center', 'expression_eyebrow', 66 | 'expression_center', 'expression_eyes', 67 | ] 68 | return control_dict, sort_rot_control, sort_exp_control 69 | 70 | if __name__ == '__main__': 71 | args = parse_args() 72 | set_random_seed(args.seed) 73 | opt = Config(args.config, args, is_train=False) 74 | 75 | if not args.single_gpu: 76 | opt.local_rank = args.local_rank 77 | init_dist(opt.local_rank) 78 | opt.device = torch.cuda.current_device() 79 | 80 | # create the logging directory 81 | date_uid, logdir = init_logging(opt) 82 | opt.logdir = logdir 83 |
make_logging_dir(logdir, date_uid) 84 | 85 | # create a model 86 | net_G, net_G_ema, opt_G, sch_G \ 87 | = get_model_optimizer_and_scheduler(opt) 88 | 89 | trainer = get_trainer(opt, net_G, net_G_ema, \ 90 | opt_G, sch_G, None) 91 | 92 | current_epoch, current_iteration = trainer.load_checkpoint( 93 | opt, args.which_iter) 94 | net_G = trainer.net_G_ema.eval() 95 | 96 | output_dir = os.path.join( 97 | args.output_dir, 98 | 'epoch_{:05}_iteration_{:09}'.format(current_epoch, current_iteration) 99 | ) 100 | 101 | os.makedirs(output_dir, exist_ok=True) 102 | image_dataset = ImageDataset(opt.data, args.input_name) 103 | 104 | control_dict, sort_rot_control, sort_exp_control = get_control(args.input_name) 105 | for _ in range(image_dataset.__len__()): 106 | with torch.no_grad(): 107 | data = image_dataset.next_image() 108 | num = 10 109 | output_images = [] 110 | # rotation control 111 | current = control_dict['rotation_center'] 112 | for control in sort_rot_control: 113 | for i in range(num): 114 | rotation = (control_dict[control]-current)*i/(num-1)+current 115 | data['target_semantics'][:, 64:70, :] = rotation[None, :, None] 116 | output_dict = net_G(data['source_image'].cuda(), data['target_semantics'].cuda()) 117 | output_images.append( 118 | output_dict['fake_image'].cpu().clamp_(-1, 1) 119 | ) 120 | current = rotation 121 | 122 | # expression control 123 | current = data['target_semantics'][0, :64, 0] 124 | for control in sort_exp_control: 125 | for i in range(num): 126 | expression = (control_dict[control]-current)*i/(num-1)+current 127 | data['target_semantics'][:, :64, :] = expression[None, :, None] 128 | output_dict = net_G(data['source_image'].cuda(), data['target_semantics'].cuda()) 129 | output_images.append( 130 | output_dict['fake_image'].cpu().clamp_(-1, 1) 131 | ) 132 | current = expression 133 | output_images = torch.cat(output_images, 0) 134 | print('write results to file {}/{}'.format(output_dir, data['name'])) 135 | write2video('{}/{}'.format(output_dir, data['name']), output_images) 136 | 137 | -------------------------------------------------------------------------------- /scripts/face_recon_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from scipy.io import savemat 8 | 9 | import torch 10 | 11 | from models import create_model 12 | from options.inference_options import InferenceOptions 13 | from util.preprocess import align_img 14 | from util.load_mats import load_lm3d 15 | from util.util import mkdirs, tensor2im, save_image 16 | 17 | 18 | def get_data_path(root, keypoint_root): 19 | filenames = list() 20 | keypoint_filenames = list() 21 | 22 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 23 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 24 | extensions = VIDEO_EXTENSIONS 25 | 26 | for ext in extensions: 27 | filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True) 28 | filenames = sorted(filenames) 29 | keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True)) 30 | assert len(filenames) == len(keypoint_filenames) 31 | 32 | return filenames, keypoint_filenames 33 | 34 | class VideoPathDataset(torch.utils.data.Dataset): 35 | def __init__(self, filenames, txt_filenames, bfm_folder): 36 | self.filenames = filenames 37 | self.txt_filenames = txt_filenames 38 | self.lm3d_std = load_lm3d(bfm_folder) 39 | 40 | def __len__(self): 41 | 
return len(self.filenames) 42 | 43 | def __getitem__(self, index): 44 | filename = self.filenames[index] 45 | txt_filename = self.txt_filenames[index] 46 | frames = self.read_video(filename) 47 | lm = np.loadtxt(txt_filename).astype(np.float32) 48 | lm = lm.reshape([len(frames), -1, 2]) 49 | out_images, out_trans_params = list(), list() 50 | for i in range(len(frames)): 51 | out_img, _, out_trans_param \ 52 | = self.image_transform(frames[i], lm[i]) 53 | out_images.append(out_img[None]) 54 | out_trans_params.append(out_trans_param[None]) 55 | return { 56 | 'imgs': torch.cat(out_images, 0), 57 | 'trans_param': torch.cat(out_trans_params, 0), 58 | 'filename': filename 59 | } 60 | 61 | def read_video(self, filename): 62 | frames = list() 63 | cap = cv2.VideoCapture(filename) 64 | while cap.isOpened(): 65 | ret, frame = cap.read() 66 | if ret: 67 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 68 | frame = Image.fromarray(frame) 69 | frames.append(frame) 70 | else: 71 | break 72 | cap.release() 73 | return frames 74 | 75 | def image_transform(self, images, lm): 76 | W,H = images.size 77 | if np.mean(lm) == -1: 78 | lm = (self.lm3d_std[:, :2]+1)/2. 79 | lm = np.concatenate( 80 | [lm[:, :1]*W, lm[:, 1:2]*H], 1 81 | ) 82 | else: 83 | lm[:, -1] = H - 1 - lm[:, -1] 84 | 85 | trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std) 86 | img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1) 87 | lm = torch.tensor(lm) 88 | trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]) 89 | trans_params = torch.tensor(trans_params.astype(np.float32)) 90 | return img, lm, trans_params 91 | 92 | def main(opt, model): 93 | import torch.multiprocessing 94 | torch.multiprocessing.set_sharing_strategy('file_system') 95 | filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir) 96 | dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder) 97 | dataloader = torch.utils.data.DataLoader( 98 | dataset, 99 | batch_size=1, # can only be set to one here!
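# each video yields a tensor with a video-dependent number of frames, # so the default collate function cannot stack a larger batch; frames # are re-batched below in chunks of opt.inference_batch_size instead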
100 | shuffle=False, 101 | drop_last=False, 102 | num_workers=8, 103 | ) 104 | batch_size = opt.inference_batch_size 105 | for data in tqdm(dataloader): 106 | num_batch = (data['imgs'][0].shape[0] + batch_size - 1) // batch_size # ceil division avoids an empty trailing batch 107 | pred_coeffs = list() 108 | for index in range(num_batch): 109 | data_input = { 110 | 'imgs': data['imgs'][0,index*batch_size:(index+1)*batch_size], 111 | } 112 | model.set_input(data_input) 113 | model.test() 114 | pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict} 115 | pred_coeff = np.concatenate([ 116 | pred_coeff['id'], 117 | pred_coeff['exp'], 118 | pred_coeff['tex'], 119 | pred_coeff['angle'], 120 | pred_coeff['gamma'], 121 | pred_coeff['trans']], 1) 122 | pred_coeffs.append(pred_coeff) 123 | visuals = model.get_current_visuals() # get image results 124 | if False: # debug 125 | for name in visuals: 126 | images = visuals[name] 127 | for i in range(images.shape[0]): 128 | image_numpy = tensor2im(images[i]) 129 | save_image( 130 | image_numpy, 131 | os.path.join( 132 | opt.output_dir, 133 | os.path.basename(data['filename'][0])+str(i).zfill(5)+'.jpg') 134 | ) 135 | exit() 136 | 137 | pred_coeffs = np.concatenate(pred_coeffs, 0) 138 | pred_trans_params = data['trans_param'][0].cpu().numpy() 139 | name = data['filename'][0].split('/')[-2:] 140 | name[-1] = os.path.splitext(name[-1])[0] + '.mat' 141 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) 142 | savemat( 143 | os.path.join(opt.output_dir, name[-2], name[-1]), 144 | {'coeff':pred_coeffs, 'transform_params':pred_trans_params} 145 | ) 146 | 147 | if __name__ == '__main__': 148 | opt = InferenceOptions().parse() # get test options 149 | model = create_model(opt) 150 | model.setup(opt) 151 | model.device = 'cuda:0' 152 | model.parallelize() 153 | model.eval() 154 | 155 | main(opt, model) 156 | 157 | 158 | -------------------------------------------------------------------------------- /trainers/face_trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from trainers.base import BaseTrainer 6 | from util.trainer import accumulate, get_optimizer 7 | from loss.perceptual import PerceptualLoss 8 | 9 | class FaceTrainer(BaseTrainer): 10 | r"""Initialize the face model trainer. 11 | 12 | Args: 13 | cfg (obj): Global configuration. 14 | net_G (obj): Generator network. 15 | opt_G (obj): Optimizer for the generator network. 16 | sch_G (obj): Scheduler for the generator optimizer. 17 | train_data_loader (obj): Train data loader. 18 | val_data_loader (obj): Validation data loader.
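Note: training runs in two stages (see _start_of_iteration below): only the warping branch ('warp') is optimized until opt.trainer.pretrain_warp_iteration, after which the full pipeline ('gen') is trained and the generator optimizer is rebuilt by reset_trainer.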
19 | """ 20 | 21 | def __init__(self, opt, net_G, opt_G, sch_G, 22 | train_data_loader, val_data_loader=None): 23 | super(FaceTrainer, self).__init__(opt, net_G, opt_G, sch_G, train_data_loader, val_data_loader) 24 | self.accum = 0.5 ** (32 / (10 * 1000)) 25 | self.log_size = int(math.log(opt.data.resolution, 2)) 26 | 27 | def _init_loss(self, opt): 28 | self._assign_criteria( 29 | 'perceptual_warp', 30 | PerceptualLoss( 31 | network=opt.trainer.vgg_param_warp.network, 32 | layers=opt.trainer.vgg_param_warp.layers, 33 | num_scales=getattr(opt.trainer.vgg_param_warp, 'num_scales', 1), 34 | use_style_loss=getattr(opt.trainer.vgg_param_warp, 'use_style_loss', False), 35 | weight_style_to_perceptual=getattr(opt.trainer.vgg_param_warp, 'style_to_perceptual', 0) 36 | ).to('cuda'), 37 | opt.trainer.loss_weight.weight_perceptual_warp) 38 | 39 | self._assign_criteria( 40 | 'perceptual_final', 41 | PerceptualLoss( 42 | network=opt.trainer.vgg_param_final.network, 43 | layers=opt.trainer.vgg_param_final.layers, 44 | num_scales=getattr(opt.trainer.vgg_param_final, 'num_scales', 1), 45 | use_style_loss=getattr(opt.trainer.vgg_param_final, 'use_style_loss', False), 46 | weight_style_to_perceptual=getattr(opt.trainer.vgg_param_final, 'style_to_perceptual', 0) 47 | ).to('cuda'), 48 | opt.trainer.loss_weight.weight_perceptual_final) 49 | 50 | def _assign_criteria(self, name, criterion, weight): 51 | self.criteria[name] = criterion 52 | self.weights[name] = weight 53 | 54 | def optimize_parameters(self, data): 55 | self.gen_losses = {} 56 | source_image, target_image = data['source_image'], data['target_image'] 57 | source_semantic, target_semantic = data['source_semantics'], data['target_semantics'] 58 | 59 | input_image = torch.cat((source_image, target_image), 0) 60 | input_semantic = torch.cat((target_semantic, source_semantic), 0) 61 | gt_image = torch.cat((target_image, source_image), 0) 62 | 63 | output_dict = self.net_G(input_image, input_semantic, self.training_stage) 64 | 65 | if self.training_stage == 'gen': 66 | fake_img = output_dict['fake_image'] 67 | warp_img = output_dict['warp_image'] 68 | self.gen_losses["perceptual_final"] = self.criteria['perceptual_final'](fake_img, gt_image) 69 | self.gen_losses["perceptual_warp"] = self.criteria['perceptual_warp'](warp_img, gt_image) 70 | else: 71 | warp_img = output_dict['warp_image'] 72 | self.gen_losses["perceptual_warp"] = self.criteria['perceptual_warp'](warp_img, gt_image) 73 | 74 | total_loss = 0 75 | for key in self.gen_losses: 76 | self.gen_losses[key] = self.gen_losses[key] * self.weights[key] 77 | total_loss += self.gen_losses[key] 78 | 79 | self.gen_losses['total_loss'] = total_loss 80 | 81 | self.net_G.zero_grad() 82 | total_loss.backward() 83 | self.opt_G.step() 84 | 85 | accumulate(self.net_G_ema, self.net_G_module, self.accum) 86 | 87 | def _start_of_iteration(self, data, current_iteration): 88 | self.training_stage = 'gen' if current_iteration >= self.opt.trainer.pretrain_warp_iteration else 'warp' 89 | if current_iteration == self.opt.trainer.pretrain_warp_iteration: 90 | self.reset_trainer() 91 | return data 92 | 93 | def reset_trainer(self): 94 | self.opt_G = get_optimizer(self.opt.gen_optimizer, self.net_G.module) 95 | 96 | def _get_visualizations(self, data): 97 | source_image, target_image = data['source_image'], data['target_image'] 98 | source_semantic, target_semantic = data['source_semantics'], data['target_semantics'] 99 | 100 | input_image = torch.cat((source_image, target_image), 0) 101 | input_semantic = 
torch.cat((target_semantic, source_semantic), 0) 102 | with torch.no_grad(): 103 | self.net_G_ema.eval() 104 | output_dict = self.net_G_ema( 105 | input_image, input_semantic, self.training_stage 106 | ) 107 | if self.training_stage == 'gen': 108 | fake_img = torch.cat([output_dict['warp_image'], output_dict['fake_image']], 3) 109 | else: 110 | fake_img = output_dict['warp_image'] 111 | 112 | fake_source, fake_target = torch.chunk(fake_img, 2, dim=0) 113 | sample_source = torch.cat([source_image, fake_source, target_image], 3) 114 | sample_target = torch.cat([target_image, fake_target, source_image], 3) 115 | sample = torch.cat([sample_source, sample_target], 2) 116 | sample = torch.cat(torch.chunk(sample, sample.size(0), 0)[:3], 2) 117 | return sample 118 | 119 | def test(self, data_loader, output_dir, current_iteration=-1): 120 | pass 121 | 122 | def _compute_metrics(self, data, current_iteration): 123 | if self.training_stage == 'gen': 124 | source_image, target_image = data['source_image'], data['target_image'] 125 | source_semantic, target_semantic = data['source_semantics'], data['target_semantics'] 126 | 127 | input_image = torch.cat((source_image, target_image), 0) 128 | input_semantic = torch.cat((target_semantic, source_semantic), 0) 129 | gt_image = torch.cat((target_image, source_image), 0) 130 | metrics = {} 131 | with torch.no_grad(): 132 | self.net_G_ema.eval() 133 | output_dict = self.net_G_ema( 134 | input_image, input_semantic, self.training_stage 135 | ) 136 | fake_image = output_dict['fake_image'] 137 | metrics['lpips'] = self.lpips(fake_image, gt_image).mean() 138 | return metrics -------------------------------------------------------------------------------- /scripts/prepare_vox_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import lmdb 4 | import argparse 5 | import multiprocessing 6 | import numpy as np 7 | 8 | from glob import glob 9 | from io import BytesIO 10 | from tqdm import tqdm 11 | from PIL import Image 12 | from scipy.io import loadmat 13 | from torchvision.transforms import functional as trans_fn 14 | 15 | def format_for_lmdb(*args): 16 | key_parts = [] 17 | for arg in args: 18 | if isinstance(arg, int): 19 | arg = str(arg).zfill(7) 20 | key_parts.append(arg) 21 | return '-'.join(key_parts).encode('utf-8') 22 | 23 | class Resizer: 24 | def __init__(self, size, kp_root, coeff_3dmm_root, img_format): 25 | self.size = size 26 | self.kp_root = kp_root 27 | self.coeff_3dmm_root = coeff_3dmm_root 28 | self.img_format = img_format 29 | 30 | def get_resized_bytes(self, img, img_format='jpeg'): 31 | img = trans_fn.resize(img, (self.size, self.size), interpolation=Image.BICUBIC) 32 | buf = BytesIO() 33 | img.save(buf, format=img_format) 34 | img_bytes = buf.getvalue() 35 | return img_bytes 36 | 37 | def prepare(self, filename): 38 | frames = {'img':[], 'kp':None, 'coeff_3dmm':None} 39 | cap = cv2.VideoCapture(filename) 40 | while cap.isOpened(): 41 | ret, frame = cap.read() 42 | if ret: 43 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 44 | img_pil = Image.fromarray(frame) 45 | img_bytes = self.get_resized_bytes(img_pil, self.img_format) 46 | frames['img'].append(img_bytes) 47 | else: 48 | break 49 | cap.release() 50 | video_name = os.path.splitext(os.path.basename(filename))[0] 51 | keypoint_byte = get_others(self.kp_root, video_name, 'keypoint') 52 | coeff_3dmm_byte = get_others(self.coeff_3dmm_root, video_name, 'coeff_3dmm') 53 | frames['kp'] = keypoint_byte 54 | frames['coeff_3dmm'] 
= coeff_3dmm_byte 55 | return frames 56 | 57 | def __call__(self, index_filename): 58 | index, filename = index_filename 59 | result = self.prepare(filename) 60 | return index, result, filename 61 | 62 | def get_others(root, video_name, data_type): 63 | if root is None: 64 | return 65 | else: 66 | assert data_type in ('keypoint', 'coeff_3dmm') 67 | if os.path.isfile(os.path.join(root, 'train', video_name+'.mat')): 68 | file_path = os.path.join(root, 'train', video_name+'.mat') 69 | else: 70 | file_path = os.path.join(root, 'test', video_name+'.mat') 71 | 72 | if data_type == 'keypoint': 73 | return_byte = convert_kp(file_path) 74 | else: 75 | return_byte = convert_3dmm(file_path) 76 | return return_byte 77 | 78 | def convert_kp(file_path): 79 | file_mat = loadmat(file_path) 80 | kp_byte = file_mat['landmark'].tobytes() 81 | return kp_byte 82 | 83 | def convert_3dmm(file_path): 84 | file_mat = loadmat(file_path) 85 | coeff_3dmm = file_mat['coeff'] 86 | crop_param = file_mat['transform_params'] 87 | _, _, ratio, t0, t1 = np.hsplit(crop_param.astype(np.float32), 5) 88 | crop_param = np.concatenate([ratio, t0, t1], 1) 89 | coeff_3dmm_cat = np.concatenate([coeff_3dmm, crop_param], 1) 90 | coeff_3dmm_byte = coeff_3dmm_cat.tobytes() 91 | return coeff_3dmm_byte 92 | 93 | 94 | def prepare_data(path, keypoint_path, coeff_3dmm_path, out, n_worker, sizes, chunksize, img_format): 95 | filenames = list() 96 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 97 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 98 | extensions = VIDEO_EXTENSIONS 99 | for ext in extensions: 100 | filenames += glob(f'{path}/**/*.{ext}', recursive=True) 101 | train_video, test_video = [], [] 102 | for item in filenames: 103 | if "/train/" in item: 104 | train_video.append(item) 105 | else: 106 | test_video.append(item) 107 | print(len(train_video), len(test_video)) 108 | with open(os.path.join(out, 'train_list.txt'),'w') as f: 109 | for item in train_video: 110 | item = os.path.splitext(os.path.basename(item))[0] 111 | f.write(item + '\n') 112 | 113 | with open(os.path.join(out, 'test_list.txt'),'w') as f: 114 | for item in test_video: 115 | item = os.path.splitext(os.path.basename(item))[0] 116 | f.write(item + '\n') 117 | 118 | 119 | filenames = sorted(filenames) 120 | total = len(filenames) 121 | os.makedirs(out, exist_ok=True) 122 | for size in sizes: 123 | lmdb_path = os.path.join(out, str(size)) 124 | with lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) as env: 125 | with env.begin(write=True) as txn: 126 | txn.put(format_for_lmdb('length'), format_for_lmdb(total)) 127 | resizer = Resizer(size, keypoint_path, coeff_3dmm_path, img_format) 128 | with multiprocessing.Pool(n_worker) as pool: 129 | for idx, result, filename in tqdm( 130 | pool.imap_unordered(resizer, enumerate(filenames), chunksize=chunksize), 131 | total=total): 132 | filename = os.path.basename(filename) 133 | video_name = os.path.splitext(filename)[0] 134 | txn.put(format_for_lmdb(video_name, 'length'), format_for_lmdb(len(result['img']))) 135 | 136 | for frame_idx, frame in enumerate(result['img']): 137 | txn.put(format_for_lmdb(video_name, frame_idx), frame) 138 | 139 | if result['kp']: 140 | txn.put(format_for_lmdb(video_name, 'keypoint'), result['kp']) 141 | if result['coeff_3dmm']: 142 | txn.put(format_for_lmdb(video_name, 'coeff_3dmm'), result['coeff_3dmm']) 143 | 144 | 145 | if __name__ == '__main__': 146 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 
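# An illustrative invocation (the paths are placeholders, not shipped data): # python scripts/prepare_vox_lmdb.py --path ./vox/videos \ # --coeff_3dmm_path ./vox/coeff --out ./vox_lmdb --sizes 256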
147 | parser.add_argument('--path', type=str, help='a path to the input video directory') 148 | parser.add_argument('--keypoint_path', type=str, help='a path to the directory with keypoint .mat files', default=None) 149 | parser.add_argument('--coeff_3dmm_path', type=str, help='a path to the directory with 3DMM coefficient .mat files', default=None) 150 | parser.add_argument('--out', type=str, help='a path to the output directory') 151 | parser.add_argument('--sizes', type=int, nargs='+', default=(256,)) 152 | parser.add_argument('--n_worker', type=int, help='number of worker processes', default=8) 153 | parser.add_argument('--chunksize', type=int, help='approximate chunksize for each worker', default=10) 154 | parser.add_argument('--img_format', type=str, default='jpeg') 155 | args = parser.parse_args() 156 | prepare_data(**vars(args)) -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = models.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x
in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = models.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = models.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = models.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = models.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = models.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | 
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utils.""" 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from scipy.stats import truncnorm 8 | from torch._six import container_abcs, string_classes 9 | 10 | 11 | def split_labels(labels, label_lengths): 12 | r"""Split concatenated labels into their parts. 13 | 14 | Args: 15 | labels (torch.Tensor): Labels obtained through concatenation. 16 | label_lengths (OrderedDict): Containing order of labels & their lengths. 17 | 18 | Returns: 19 | 20 | """ 21 | assert isinstance(label_lengths, OrderedDict) 22 | start = 0 23 | outputs = {} 24 | for data_type, length in label_lengths.items(): 25 | end = start + length 26 | if labels.dim() == 5: 27 | outputs[data_type] = labels[:, :, start:end] 28 | elif labels.dim() == 4: 29 | outputs[data_type] = labels[:, start:end] 30 | elif labels.dim() == 3: 31 | outputs[data_type] = labels[start:end] 32 | start = end 33 | return outputs 34 | 35 | 36 | def requires_grad(model, require=True): 37 | r""" Set a model to require gradient or not. 38 | 39 | Args: 40 | model (nn.Module): Neural network model. 41 | require (bool): Whether the network requires gradient or not. 42 | 43 | Returns: 44 | 45 | """ 46 | for p in model.parameters(): 47 | p.requires_grad = require 48 | 49 | 50 | def to_device(data, device): 51 | r"""Move all tensors inside data to device. 52 | 53 | Args: 54 | data (dict, list, or tensor): Input data. 55 | device (str): 'cpu' or 'cuda'. 56 | """ 57 | assert device in ['cpu', 'cuda'] 58 | if isinstance(data, torch.Tensor): 59 | data = data.to(torch.device(device)) 60 | return data 61 | elif isinstance(data, container_abcs.Mapping): 62 | return {key: to_device(data[key], device) for key in data} 63 | elif isinstance(data, container_abcs.Sequence) and \ 64 | not isinstance(data, string_classes): 65 | return [to_device(d, device) for d in data] 66 | else: 67 | return data 68 | 69 | 70 | def to_cuda(data): 71 | r"""Move all tensors inside data to gpu. 72 | 73 | Args: 74 | data (dict, list, or tensor): Input data. 75 | """ 76 | return to_device(data, 'cuda') 77 | 78 | 79 | def to_cpu(data): 80 | r"""Move all tensors inside data to cpu. 81 | 82 | Args: 83 | data (dict, list, or tensor): Input data. 84 | """ 85 | return to_device(data, 'cpu') 86 | 87 | 88 | def to_half(data): 89 | r"""Move all floats to half. 90 | 91 | Args: 92 | data (dict, list or tensor): Input data. 93 | """ 94 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data): 95 | data = data.half() 96 | return data 97 | elif isinstance(data, container_abcs.Mapping): 98 | return {key: to_half(data[key]) for key in data} 99 | elif isinstance(data, container_abcs.Sequence) and \ 100 | not isinstance(data, string_classes): 101 | return [to_half(d) for d in data] 102 | else: 103 | return data 104 | 105 | 106 | def to_float(data): 107 | r"""Move all halfs to float. 108 | 109 | Args: 110 | data (dict, list or tensor): Input data. 
111 | """ 112 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data): 113 | data = data.float() 114 | return data 115 | elif isinstance(data, container_abcs.Mapping): 116 | return {key: to_float(data[key]) for key in data} 117 | elif isinstance(data, container_abcs.Sequence) and \ 118 | not isinstance(data, string_classes): 119 | return [to_float(d) for d in data] 120 | else: 121 | return data 122 | 123 | 124 | def get_and_setattr(cfg, name, default): 125 | r"""Get attribute with default choice. If attribute does not exist, set it 126 | using the default value. 127 | 128 | Args: 129 | cfg (obj) : Config options. 130 | name (str) : Attribute name. 131 | default (obj) : Default attribute. 132 | 133 | Returns: 134 | (obj) : Desired attribute. 135 | """ 136 | if not hasattr(cfg, name) or name not in cfg.__dict__: 137 | setattr(cfg, name, default) 138 | return getattr(cfg, name) 139 | 140 | 141 | def get_nested_attr(cfg, attr_name, default): 142 | r"""Iteratively try to get the attribute from cfg. If not found, return 143 | default. 144 | 145 | Args: 146 | cfg (obj): Config file. 147 | attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ). 148 | default (obj): Default return value for the attribute. 149 | 150 | Returns: 151 | (obj): Attribute value. 152 | """ 153 | names = attr_name.split('.') 154 | atr = cfg 155 | for name in names: 156 | if not hasattr(atr, name): 157 | return default 158 | atr = getattr(atr, name) 159 | return atr 160 | 161 | 162 | def gradient_norm(model): 163 | r"""Return the gradient norm of model. 164 | 165 | Args: 166 | model (PyTorch module): Your network. 167 | 168 | """ 169 | total_norm = 0 170 | for p in model.parameters(): 171 | if p.grad is not None: 172 | param_norm = p.grad.norm(2) 173 | total_norm += param_norm.item() ** 2 174 | return total_norm ** (1. / 2) 175 | 176 | 177 | def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'): 178 | r"""Randomly shift the input tensor. 179 | 180 | Args: 181 | x (4D tensor): The input batch of images. 182 | offset (int): The maximum offset ratio that is between [0, 1]. 183 | The maximum shift is offset * image_size for each direction. 184 | mode (str): The resample mode for 'F.grid_sample'. 185 | padding_mode (str): The padding mode for 'F.grid_sample'. 186 | 187 | Returns: 188 | x (4D tensor) : The randomly shifted image. 189 | """ 190 | assert x.dim() == 4, "Input must be a 4D tensor." 191 | batch_size = x.size(0) 192 | theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( 193 | batch_size, 1, 1) 194 | theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset 195 | grid = F.affine_grid(theta, x.size()) 196 | x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) 197 | return x 198 | 199 | 200 | def truncated_gaussian(threshold, size, seed=None, device=None): 201 | r"""Apply the truncated gaussian trick to trade diversity for quality 202 | 203 | Args: 204 | threshold (float): Truncation threshold. 205 | size (list of integer): Tensor size. 206 | seed (int): Random seed. 207 | device: 208 | """ 209 | state = None if seed is None else np.random.RandomState(seed) 210 | values = truncnorm.rvs(-threshold, threshold, 211 | size=size, random_state=state) 212 | return torch.tensor(values, device=device).float() 213 | 214 | 215 | def apply_imagenet_normalization(input): 216 | r"""Normalize using ImageNet mean and std. 217 | 218 | Args: 219 | input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1]. 
220 | 221 | Returns: 222 | Normalized inputs using the ImageNet normalization. 223 | """ 224 | # normalize the input back to [0, 1] 225 | normalized_input = (input + 1) / 2 226 | # normalize the input using the ImageNet mean and std 227 | mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 228 | std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 229 | output = (normalized_input - mean) / std 230 | return output 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Website 4 | | 5 | ArXiv 6 | | 7 | Get Started 8 | | 9 | Video 10 | 11 | 12 | 13 | 14 | # PIRenderer 15 | 16 | The source code of the ICCV 2021 paper "[PIRenderer: Controllable Portrait Image Generation via Semantic Neural Rendering](https://arxiv.org/abs/2109.08379)" 17 | 18 | The proposed **PIRenderer** can synthesize portrait images by intuitively controlling the face motions with fully disentangled 3DMM parameters. This model can be applied to tasks such as: 19 | 20 | * **Intuitive Portrait Image Editing** 21 | 22 | 23 | 24 | 25 | 26 | Intuitive Portrait Image Control 27 | 28 | 29 | 30 | 31 | 32 | Pose & Expression Alignment 33 | 34 | 35 | 36 | * **Motion Imitation** 37 | 38 | 39 | 40 | 41 | Same & Cross-identity Reenactment 42 | 43 | 44 | * **Audio-Driven Facial Reenactment** 45 | 46 | 47 | 48 | 49 | 50 | Audio-Driven Reenactment 51 |
52 | 53 | ## News 54 | 55 | * 2021.9.20 Code for PyTorch is available! 56 | 57 | 58 | 59 | ## Colab Demo 60 | 61 | Coming soon 62 | 63 | 64 | ## Get Started 65 | 66 | ### 1). Installation 67 | 68 | #### Requirements 69 | 70 | * Python 3 71 | * PyTorch 1.7.1 72 | * CUDA 10.2 73 | 74 | #### Conda Installation 75 | 76 | ```bash 77 | # 1. Create a conda virtual environment. 78 | conda create -n PIRenderer python=3.6 79 | conda activate PIRenderer 80 | conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.2 81 | 82 | # 2. Install other dependencies 83 | pip install -r requirements.txt 84 | ``` 85 | 86 | ### 2). Dataset 87 | 88 | We train our model on the [VoxCeleb](https://arxiv.org/abs/1706.08612) dataset. You can download the demo dataset for inference or prepare the dataset for training and testing. 89 | 90 | #### Download the demo dataset 91 | 92 | The demo dataset contains all 514 test videos. You can download the dataset with the following code: 93 | 94 | ```bash 95 | ./scripts/download_demo_dataset.sh 96 | ``` 97 | 98 | Or you can choose to download the resources with these links: 99 | 100 | [Google Drive](https://drive.google.com/drive/folders/16Yn2r46b4cV6ZozOH6a8SdFz_iG7BQk1?usp=sharing) & [Baidu Drive](https://pan.baidu.com/s/1e615bBHvM4Wz-2snk-86Xw) with extraction password "p9ab" 101 | 102 | Then unzip and save the files to `./dataset`. 103 | 104 | #### Prepare the dataset 105 | 106 | 1. The dataset is preprocessed following the method used in [First-Order](https://github.com/AliaksandrSiarohin/video-preprocessing). You can follow the instructions in their repo to download and crop videos for training and testing. 107 | 108 | 2. After obtaining the VoxCeleb videos, we extract 3DMM parameters using [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction). 109 | 110 | The folders are organized as follows: 111 | 112 | ``` 113 | ${DATASET_ROOT_FOLDER} 114 | └───path_to_videos 115 | └───train 116 | └───xxx.mp4 117 | └───xxx.mp4 118 | ... 119 | └───test 120 | └───xxx.mp4 121 | └───xxx.mp4 122 | ... 123 | └───path_to_3dmm_coeff 124 | └───train 125 | └───xxx.mat 126 | └───xxx.mat 127 | ... 128 | └───test 129 | └───xxx.mat 130 | └───xxx.mat 131 | ... 132 | ``` 133 | 134 | **News**: We provide scripts for extracting 3DMM coefficients from videos. Please check the [DatasetHelper](./DatasetHelper.md) for more details. 135 | 136 | 3. We save the videos and 3DMM parameters in an lmdb file. Please run the following code to do so: 137 | 138 | ```bash 139 | python scripts/prepare_vox_lmdb.py \ 140 | --path path_to_videos \ 141 | --coeff_3dmm_path path_to_3dmm_coeff \ 142 | --out path_to_output_dir 143 | ``` 144 | 145 | ### 3). Training and Inference 146 | 147 | #### Inference 148 | 149 | The trained weights can be downloaded by running the following code: 150 | 151 | ```bash 152 | ./scripts/download_weights.sh 153 | ``` 154 | 155 | Or you can choose to download the resources with these links: 156 | 157 | [Google Drive](https://drive.google.com/file/d/1-0xOf6g58OmtKtEWJlU3VlnfRqPN9Uq7/view?usp=sharing) & [Baidu Drive](https://pan.baidu.com/s/18B3xfKMXnm4tOqlFSB8ntg) with extraction password "4sy1". 158 | 159 | Then unzip and save the files to `./result/face`. 
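Before running the demos below, it can be worth sanity-checking that the weights unpacked correctly. A minimal sketch (the `*.pt` glob pattern is an assumption here; adjust it to whatever files the download script actually places under `./result/face`):

```python
import glob
import torch

# Hypothetical pattern: adjust to the actual checkpoint files under ./result/face.
ckpt_files = sorted(glob.glob('./result/face/*.pt'))
assert ckpt_files, 'No checkpoint found under ./result/face'

# Load on CPU so the check also works on machines without a GPU.
ckpt = torch.load(ckpt_files[0], map_location='cpu')
if isinstance(ckpt, dict):
    # Checkpoints are typically dicts of state_dicts and bookkeeping fields.
    print(list(ckpt.keys()))
```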
160 | 161 | **Reenactment** 162 | 163 | Run the demo for face reenactment: 164 | 165 | ```bash 166 | # same identity 167 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 inference.py \ 168 | --config ./config/face_demo.yaml \ 169 | --name face \ 170 | --no_resume \ 171 | --output_dir ./vox_result/face_reenactment 172 | 173 | # cross identity 174 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 inference.py \ 175 | --config ./config/face_demo.yaml \ 176 | --name face \ 177 | --no_resume \ 178 | --output_dir ./vox_result/face_reenactment_cross \ 179 | --cross_id 180 | ``` 181 | 182 | The output results are saved to `./vox_result/face_reenactment` and `./vox_result/face_reenactment_cross`. 183 | 184 | **Intuitive Control** 185 | 186 | Our model can generate results from intuitive control coefficients. 187 | We provide the following code for this task. Please note that you need to build the environment of [DeepFaceRecon](https://github.com/sicxu/Deep3DFaceRecon_pytorch/tree/73d491102af6731bded9ae6b3cc7466c3b2e9e48) first. 188 | 189 | ```bash 190 | # 1. Copy the provided scripts to the folder `Deep3DFaceRecon_pytorch`. 191 | cp scripts/face_recon_videos.py ./Deep3DFaceRecon_pytorch 192 | cp scripts/extract_kp_videos.py ./Deep3DFaceRecon_pytorch 193 | cp scripts/coeff_detector.py ./Deep3DFaceRecon_pytorch 194 | cp scripts/inference_options.py ./Deep3DFaceRecon_pytorch/options 195 | 196 | cd Deep3DFaceRecon_pytorch 197 | 198 | # 2. Extract the 3DMM coefficients of the demo images. 199 | python coeff_detector.py \ 200 | --input_dir ../demo_images \ 201 | --keypoint_dir ../demo_images \ 202 | --output_dir ../demo_images \ 203 | --name=model_name \ 204 | --epoch=20 \ 205 | --model facerecon 206 | 207 | # 3. Control the source image with our model. 208 | cd .. 209 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 intuitive_control.py \ 210 | --config ./config/face_demo.yaml \ 211 | --name face \ 212 | --no_resume \ 213 | --output_dir ./vox_result/face_intuitive \ 214 | --input_name ./demo_images 215 | ``` 216 | 217 | 218 | #### Train 219 | 220 | Our model can be trained with the following code: 221 | 222 | ```bash 223 | python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 train.py \ 224 | --config ./config/face.yaml \ 225 | --name face 226 | ``` 227 | 228 | 229 | ## Citation 230 | 231 | If you find this code helpful, please cite our paper: 232 | 233 | ```tex 234 | @misc{ren2021pirenderer, 235 | title={PIRenderer: Controllable Portrait Image Generation via Semantic Neural Rendering}, 236 | author={Yurui Ren and Ge Li and Yuanqi Chen and Thomas H. Li and Shan Liu}, 237 | year={2021}, 238 | eprint={2109.08379}, 239 | archivePrefix={arXiv}, 240 | primaryClass={cs.CV} 241 | } 242 | ``` 243 | 244 | ## Acknowledgement 245 | 246 | We build our project based on [imaginaire](https://github.com/NVlabs/imaginaire). Some dataset preprocessing methods are derived from [video-preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). 
247 | 248 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import os 4 | import re 5 | 6 | import yaml 7 | from util.distributed import master_only_print as print 8 | 9 | 10 | class AttrDict(dict): 11 | """Dict as attribute trick.""" 12 | 13 | def __init__(self, *args, **kwargs): 14 | super(AttrDict, self).__init__(*args, **kwargs) 15 | self.__dict__ = self 16 | for key, value in self.__dict__.items(): 17 | if isinstance(value, dict): 18 | self.__dict__[key] = AttrDict(value) 19 | elif isinstance(value, (list, tuple)): 20 | if isinstance(value[0], dict): 21 | self.__dict__[key] = [AttrDict(item) for item in value] 22 | else: 23 | self.__dict__[key] = value 24 | 25 | def yaml(self): 26 | """Convert object to yaml dict and return.""" 27 | yaml_dict = {} 28 | for key, value in self.__dict__.items(): 29 | if isinstance(value, AttrDict): 30 | yaml_dict[key] = value.yaml() 31 | elif isinstance(value, list): 32 | if isinstance(value[0], AttrDict): 33 | new_l = [] 34 | for item in value: 35 | new_l.append(item.yaml()) 36 | yaml_dict[key] = new_l 37 | else: 38 | yaml_dict[key] = value 39 | else: 40 | yaml_dict[key] = value 41 | return yaml_dict 42 | 43 | def __repr__(self): 44 | """Print all variables.""" 45 | ret_str = [] 46 | for key, value in self.__dict__.items(): 47 | if isinstance(value, AttrDict): 48 | ret_str.append('{}:'.format(key)) 49 | child_ret_str = value.__repr__().split('\n') 50 | for item in child_ret_str: 51 | ret_str.append(' ' + item) 52 | elif isinstance(value, list): 53 | if isinstance(value[0], AttrDict): 54 | ret_str.append('{}:'.format(key)) 55 | for item in value: 56 | # Treat as AttrDict above. 57 | child_ret_str = item.__repr__().split('\n') 58 | for item in child_ret_str: 59 | ret_str.append(' ' + item) 60 | else: 61 | ret_str.append('{}: {}'.format(key, value)) 62 | else: 63 | ret_str.append('{}: {}'.format(key, value)) 64 | return '\n'.join(ret_str) 65 | 66 | 67 | class Config(AttrDict): 68 | r"""Configuration class. This should include every human specifiable 69 | hyperparameter values for your training.""" 70 | 71 | def __init__(self, filename=None, args=None, verbose=False, is_train=True): 72 | super(Config, self).__init__() 73 | # Set default parameters. 74 | # Logging. 75 | 76 | large_number = 1000000000 77 | self.snapshot_save_iter = large_number 78 | self.snapshot_save_epoch = large_number 79 | self.snapshot_save_start_iter = 0 80 | self.snapshot_save_start_epoch = 0 81 | self.image_save_iter = large_number 82 | self.eval_epoch = large_number 83 | self.start_eval_epoch = large_number 84 | self.eval_epoch = large_number 85 | self.max_epoch = large_number 86 | self.max_iter = large_number 87 | self.logging_iter = 100 88 | self.image_to_tensorboard=False 89 | self.which_iter = args.which_iter 90 | self.resume = not args.no_resume 91 | 92 | 93 | self.checkpoints_dir = args.checkpoints_dir 94 | self.name = args.name 95 | self.phase = 'train' if is_train else 'test' 96 | 97 | # Networks. 98 | self.gen = AttrDict(type='generators.dummy') 99 | self.dis = AttrDict(type='discriminators.dummy') 100 | 101 | # Optimizers. 
102 | self.gen_optimizer = AttrDict(type='adam', 103 | lr=0.0001, 104 | adam_beta1=0.0, 105 | adam_beta2=0.999, 106 | eps=1e-8, 107 | lr_policy=AttrDict(iteration_mode=False, 108 | type='step', 109 | step_size=large_number, 110 | gamma=1)) 111 | self.dis_optimizer = AttrDict(type='adam', 112 | lr=0.0001, 113 | adam_beta1=0.0, 114 | adam_beta2=0.999, 115 | eps=1e-8, 116 | lr_policy=AttrDict(iteration_mode=False, 117 | type='step', 118 | step_size=large_number, 119 | gamma=1)) 120 | # Data. 121 | self.data = AttrDict(name='dummy', 122 | type='datasets.images', 123 | num_workers=0) 124 | self.test_data = AttrDict(name='dummy', 125 | type='datasets.images', 126 | num_workers=0, 127 | test=AttrDict(is_lmdb=False, 128 | roots='', 129 | batch_size=1)) 130 | self.trainer = AttrDict( 131 | model_average=False, 132 | model_average_beta=0.9999, 133 | model_average_start_iteration=1000, 134 | model_average_batch_norm_estimation_iteration=30, 135 | model_average_remove_sn=True, 136 | image_to_tensorboard=False, 137 | hparam_to_tensorboard=False, 138 | distributed_data_parallel='pytorch', 139 | delay_allreduce=True, 140 | gan_relativistic=False, 141 | gen_step=1, 142 | dis_step=1) 143 | 144 | # # Cudnn. 145 | self.cudnn = AttrDict(deterministic=False, 146 | benchmark=True) 147 | 148 | # Others. 149 | self.pretrained_weight = '' 150 | self.inference_args = AttrDict() 151 | 152 | 153 | # Update with given configurations. 154 | assert os.path.exists(filename), 'File {} does not exist.'.format(filename) 155 | loader = yaml.SafeLoader 156 | loader.add_implicit_resolver( 157 | u'tag:yaml.org,2002:float', 158 | re.compile(u'''^(?: 159 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 160 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 161 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 162 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 163 | |[-+]?\\.(?:inf|Inf|INF) 164 | |\\.(?:nan|NaN|NAN))$''', re.X), 165 | list(u'-+0123456789.')) 166 | try: 167 | with open(filename, 'r') as f: 168 | cfg_dict = yaml.load(f, Loader=loader) 169 | except EnvironmentError: 170 | print('Please check the file with name of "%s".' % filename) 171 | recursive_update(self, cfg_dict) 172 | 173 | # Put common opts in both gen and dis. 
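# For example (illustrative YAML, not actual config fields): a file containing
#   common:
#     image_nc: 3
#   gen:
#     type: generators.face_model
# makes the shared values reachable as cfg.gen.common.image_nc and
# cfg.dis.common.image_nc after the block below runs.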
174 | if 'common' in cfg_dict: 175 | self.common = AttrDict(**cfg_dict['common']) 176 | self.gen.common = self.common 177 | self.dis.common = self.common 178 | 179 | 180 | if verbose: 181 | print(' config '.center(80, '-')) 182 | print(self.__repr__()) 183 | print(''.center(80, '-')) 184 | 185 | 186 | def rsetattr(obj, attr, val): 187 | """Recursively find object and set value""" 188 | pre, _, post = attr.rpartition('.') 189 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 190 | 191 | 192 | def rgetattr(obj, attr, *args): 193 | """Recursively find object and return value""" 194 | 195 | def _getattr(obj, attr): 196 | r"""Get attribute.""" 197 | return getattr(obj, attr, *args) 198 | 199 | return functools.reduce(_getattr, [obj] + attr.split('.')) 200 | 201 | 202 | def recursive_update(d, u): 203 | """Recursively update AttrDict d with AttrDict u""" 204 | for key, value in u.items(): 205 | if isinstance(value, collections.abc.Mapping): 206 | d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) 207 | elif isinstance(value, (list, tuple)): 208 | if isinstance(value[0], dict): 209 | d.__dict__[key] = [AttrDict(item) for item in value] 210 | else: 211 | d.__dict__[key] = value 212 | else: 213 | d.__dict__[key] = value 214 | return d 215 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | from . import util 5 | from . import html 6 | # from pdb import set_trace as st 7 | import matplotlib.pyplot as plt 8 | import math 9 | # from IPython import embed 10 | 11 | def zoom_to_res(img,res=256,order=0,axis=0): 12 | # img 3xXxX 13 | from scipy.ndimage import zoom 14 | zoom_factor = res/img.shape[1] 15 | if(axis==0): 16 | return zoom(img,[1,zoom_factor,zoom_factor],order=order) 17 | elif(axis==2): 18 | return zoom(img,[zoom_factor,zoom_factor,1],order=order) 19 | 20 | class Visualizer(): 21 | def __init__(self, opt): 22 | # self.opt = opt 23 | self.display_id = opt.display_id 24 | # self.use_html = opt.is_train and not opt.no_html 25 | self.win_size = opt.display_winsize 26 | self.name = opt.name 27 | self.display_cnt = 0 # display_current_results counter 28 | self.display_cnt_high = 0 29 | self.use_html = opt.use_html 30 | 31 | if self.display_id > 0: 32 | import visdom 33 | self.vis = visdom.Visdom(port = opt.display_port) 34 | 35 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 36 | util.mkdirs([self.web_dir,]) 37 | if self.use_html: 38 | self.img_dir = os.path.join(self.web_dir, 'images') 39 | print('create web directory %s...' 
% self.web_dir) 40 | util.mkdirs([self.img_dir,]) 41 | 42 | # |visuals|: dictionary of images to display or save 43 | def display_current_results(self, visuals, epoch, nrows=None, res=256): 44 | if self.display_id > 0: # show images in the browser 45 | title = self.name 46 | if(nrows is None): 47 | nrows = int(math.ceil(len(visuals.items()) / 2.0)) 48 | images = [] 49 | idx = 0 50 | for label, image_numpy in visuals.items(): 51 | title += " | " if idx % nrows == 0 else ", " 52 | title += label 53 | img = image_numpy.transpose([2, 0, 1]) 54 | img = zoom_to_res(img,res=res,order=0) 55 | images.append(img) 56 | idx += 1 57 | if len(visuals.items()) % 2 != 0: 58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 59 | white_image = zoom_to_res(white_image,res=res,order=0) 60 | images.append(white_image) 61 | self.vis.images(images, nrow=nrows, win=self.display_id + 1, 62 | opts=dict(title=title)) 63 | 64 | if self.use_html: # save images to a html file 65 | for label, image_numpy in visuals.items(): 66 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label)) 67 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path) 68 | 69 | self.display_cnt += 1 70 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt) 71 | 72 | # update website 73 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 74 | for n in range(epoch, 0, -1): 75 | webpage.add_header('epoch [%d]' % n) 76 | if(n==epoch): 77 | high = self.display_cnt 78 | else: 79 | high = self.display_cnt_high 80 | for c in range(high-1,-1,-1): 81 | ims = [] 82 | txts = [] 83 | links = [] 84 | 85 | for label, image_numpy in visuals.items(): 86 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label) 87 | ims.append(os.path.join('images',img_path)) 88 | txts.append(label) 89 | links.append(os.path.join('images',img_path)) 90 | webpage.add_images(ims, txts, links, width=self.win_size) 91 | webpage.save() 92 | 93 | # save errors into a directory 94 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False): 95 | if not hasattr(self, 'plot_data'): 96 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 97 | self.plot_data['X'].append(epoch + counter_ratio) 98 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 99 | 100 | # embed() 101 | if(keys=='+ALL'): 102 | plot_keys = self.plot_data['legend'] 103 | else: 104 | plot_keys = keys 105 | 106 | if(to_plot): 107 | (f,ax) = plt.subplots(1,1) 108 | for (k,kname) in enumerate(plot_keys): 109 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0] 110 | x = self.plot_data['X'] 111 | y = np.array(self.plot_data['Y'])[:,kk] 112 | if(to_plot): 113 | ax.plot(x, y, 'o-', label=kname) 114 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x) 115 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y) 116 | 117 | if(to_plot): 118 | plt.legend(loc=0,fontsize='small') 119 | plt.xlabel('epoch') 120 | plt.ylabel('Value') 121 | f.savefig(os.path.join(self.web_dir,'%s.png'%name)) 122 | f.clf() 123 | plt.close() 124 | 125 | # errors: dictionary of error labels and values 126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 127 | if not hasattr(self, 'plot_data'): 128 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 129 | self.plot_data['X'].append(epoch + counter_ratio) 130 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 131 | self.vis.line( 132 | 
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 133 | Y=np.array(self.plot_data['Y']), 134 | opts={ 135 | 'title': self.name + ' loss over time', 136 | 'legend': self.plot_data['legend'], 137 | 'xlabel': 'epoch', 138 | 'ylabel': 'loss'}, 139 | win=self.display_id) 140 | 141 | # errors: same format as |errors| of plotCurrentErrors 142 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None): 143 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2) 144 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()]) 145 | 146 | print(message) 147 | if(fid is not None): 148 | fid.write('%s\n'%message) 149 | 150 | 151 | # save image to the disk 152 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256): 153 | image_dir = webpage.get_image_dir() 154 | ims = [] 155 | txts = [] 156 | links = [] 157 | 158 | for name, image_numpy, txt in zip(names, images, in_txts): 159 | image_name = '%s_%s.png' % (prefix, name) 160 | save_path = os.path.join(image_dir, image_name) 161 | if(res is not None): 162 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path) 163 | else: 164 | util.save_image(image_numpy, save_path) 165 | 166 | ims.append(os.path.join(webpage.img_subdir,image_name)) 167 | # txts.append(name) 168 | txts.append(txt) 169 | links.append(os.path.join(webpage.img_subdir,image_name)) 170 | # embed() 171 | webpage.add_images(ims, txts, links, width=self.win_size) 172 | 173 | # save image to the disk 174 | def save_images(self, webpage, images, names, image_path, title=''): 175 | image_dir = webpage.get_image_dir() 176 | # short_path = ntpath.basename(image_path) 177 | # name = os.path.splitext(short_path)[0] 178 | # name = short_path 179 | # webpage.add_header('%s, %s' % (name, title)) 180 | ims = [] 181 | txts = [] 182 | links = [] 183 | 184 | for label, image_numpy in zip(names, images): 185 | image_name = '%s.jpg' % (label,) 186 | save_path = os.path.join(image_dir, image_name) 187 | util.save_image(image_numpy, save_path) 188 | 189 | ims.append(image_name) 190 | txts.append(label) 191 | links.append(image_name) 192 | webpage.add_images(ims, txts, links, width=self.win_size) 193 | 194 | # save image to the disk 195 | # def save_images(self, webpage, visuals, image_path, short=False): 196 | # image_dir = webpage.get_image_dir() 197 | # if short: 198 | # short_path = ntpath.basename(image_path) 199 | # name = os.path.splitext(short_path)[0] 200 | # else: 201 | # name = image_path 202 | 203 | # webpage.add_header(name) 204 | # ims = [] 205 | # txts = [] 206 | # links = [] 207 | 208 | # for label, image_numpy in visuals.items(): 209 | # image_name = '%s_%s.png' % (name, label) 210 | # save_path = os.path.join(image_dir, image_name) 211 | # util.save_image(image_numpy, save_path) 212 | 213 | # ims.append(image_name) 214 | # txts.append(label) 215 | # links.append(image_name) 216 | # webpage.add_images(ims, txts, links, width=self.win_size) 217 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | sys.path.append('..') 6 | sys.path.append('.') 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | from torch.autograd import Variable 11 | import numpy as np 12 | from pdb import set_trace as st 13 | from 
skimage import color 14 | from IPython import embed 15 | from . import pretrained_networks as pn 16 | 17 | # from .PerceptualSimilarity.util import util 18 | from ..util import util 19 | 20 | # Off-the-shelf deep network 21 | class PNet(nn.Module): 22 | '''Pre-trained network with all channels equally weighted by default''' 23 | def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True): 24 | super(PNet, self).__init__() 25 | 26 | self.use_gpu = use_gpu 27 | 28 | self.pnet_type = pnet_type 29 | self.pnet_rand = pnet_rand 30 | 31 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1)) 32 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1)) 33 | 34 | if(self.pnet_type in ['vgg','vgg16']): 35 | self.net = pn.vgg16(pretrained=not self.pnet_rand,requires_grad=False) 36 | elif(self.pnet_type=='alex'): 37 | self.net = pn.alexnet(pretrained=not self.pnet_rand,requires_grad=False) 38 | elif(self.pnet_type[:-2]=='resnet'): 39 | self.net = pn.resnet(pretrained=not self.pnet_rand,requires_grad=False, num=int(self.pnet_type[-2:])) 40 | elif(self.pnet_type=='squeeze'): 41 | self.net = pn.squeezenet(pretrained=not self.pnet_rand,requires_grad=False) 42 | 43 | self.L = self.net.N_slices 44 | 45 | if(use_gpu): 46 | self.net.cuda() 47 | self.shift = self.shift.cuda() 48 | self.scale = self.scale.cuda() 49 | 50 | def forward(self, in0, in1, retPerLayer=False): 51 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 52 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 53 | 54 | outs0 = self.net.forward(in0_sc) 55 | outs1 = self.net.forward(in1_sc) 56 | 57 | if(retPerLayer): 58 | all_scores = [] 59 | for (kk,out0) in enumerate(outs0): 60 | cur_score = (1.-util.cos_sim(outs0[kk],outs1[kk])) 61 | if(kk==0): 62 | val = 1.*cur_score 63 | else: 64 | # val = val + self.lambda_feat_layers[kk]*cur_score 65 | val = val + cur_score 66 | if(retPerLayer): 67 | all_scores+=[cur_score] 68 | 69 | if(retPerLayer): 70 | return (val, all_scores) 71 | else: 72 | return val 73 | 74 | # Learned perceptual metric 75 | class PNetLin(nn.Module): 76 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, use_gpu=True, spatial=False, version='0.1'): 77 | super(PNetLin, self).__init__() 78 | 79 | self.use_gpu = use_gpu 80 | self.pnet_type = pnet_type 81 | self.pnet_tune = pnet_tune 82 | self.pnet_rand = pnet_rand 83 | self.spatial = spatial 84 | self.version = version 85 | 86 | if(self.pnet_type in ['vgg','vgg16']): 87 | net_type = pn.vgg16 88 | self.chns = [64,128,256,512,512] 89 | elif(self.pnet_type=='alex'): 90 | net_type = pn.alexnet 91 | self.chns = [64,192,384,256,256] 92 | elif(self.pnet_type=='squeeze'): 93 | net_type = pn.squeezenet 94 | self.chns = [64,128,256,384,384,512,512] 95 | 96 | if(self.pnet_tune): 97 | self.net = net_type(pretrained=not self.pnet_rand,requires_grad=True) 98 | else: 99 | self.net = [net_type(pretrained=not self.pnet_rand,requires_grad=False),] 100 | 101 | self.lin0 = NetLinLayer(self.chns[0],use_dropout=use_dropout) 102 | self.lin1 = NetLinLayer(self.chns[1],use_dropout=use_dropout) 103 | self.lin2 = NetLinLayer(self.chns[2],use_dropout=use_dropout) 104 | self.lin3 = NetLinLayer(self.chns[3],use_dropout=use_dropout) 105 | self.lin4 = NetLinLayer(self.chns[4],use_dropout=use_dropout) 106 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 107 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 108 | self.lin5 = 
NetLinLayer(self.chns[5],use_dropout=use_dropout) 109 | self.lin6 = NetLinLayer(self.chns[6],use_dropout=use_dropout) 110 | self.lins+=[self.lin5,self.lin6] 111 | 112 | self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1)) 113 | self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1)) 114 | 115 | if(use_gpu): 116 | if(self.pnet_tune): 117 | self.net.cuda() 118 | else: 119 | self.net[0].cuda() 120 | self.shift = self.shift.cuda() 121 | self.scale = self.scale.cuda() 122 | self.lin0.cuda() 123 | self.lin1.cuda() 124 | self.lin2.cuda() 125 | self.lin3.cuda() 126 | self.lin4.cuda() 127 | if(self.pnet_type=='squeeze'): 128 | self.lin5.cuda() 129 | self.lin6.cuda() 130 | 131 | def forward(self, in0, in1): 132 | in0_sc = (in0 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 133 | in1_sc = (in1 - self.shift.expand_as(in0))/self.scale.expand_as(in0) 134 | 135 | if(self.version=='0.0'): 136 | # v0.0 - original release had a bug, where input was not scaled 137 | in0_input = in0 138 | in1_input = in1 139 | else: 140 | # v0.1 141 | in0_input = in0_sc 142 | in1_input = in1_sc 143 | 144 | if(self.pnet_tune): 145 | outs0 = self.net.forward(in0_input) 146 | outs1 = self.net.forward(in1_input) 147 | else: 148 | outs0 = self.net[0].forward(in0_input) 149 | outs1 = self.net[0].forward(in1_input) 150 | 151 | feats0 = {} 152 | feats1 = {} 153 | diffs = [0]*len(outs0) 154 | 155 | for (kk,out0) in enumerate(outs0): 156 | feats0[kk] = util.normalize_tensor(outs0[kk]) 157 | feats1[kk] = util.normalize_tensor(outs1[kk]) 158 | diffs[kk] = (feats0[kk]-feats1[kk])**2 159 | 160 | if self.spatial: 161 | lin_models = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 162 | if(self.pnet_type=='squeeze'): 163 | lin_models.extend([self.lin5, self.lin6]) 164 | res = [lin_models[kk].model(diffs[kk]) for kk in range(len(diffs))] 165 | return res 166 | 167 | val = torch.mean(torch.mean(self.lin0.model(diffs[0]),dim=3),dim=2) 168 | val = val + torch.mean(torch.mean(self.lin1.model(diffs[1]),dim=3),dim=2) 169 | val = val + torch.mean(torch.mean(self.lin2.model(diffs[2]),dim=3),dim=2) 170 | val = val + torch.mean(torch.mean(self.lin3.model(diffs[3]),dim=3),dim=2) 171 | val = val + torch.mean(torch.mean(self.lin4.model(diffs[4]),dim=3),dim=2) 172 | if(self.pnet_type=='squeeze'): 173 | val = val + torch.mean(torch.mean(self.lin5.model(diffs[5]),dim=3),dim=2) 174 | val = val + torch.mean(torch.mean(self.lin6.model(diffs[6]),dim=3),dim=2) 175 | 176 | val = val.view(val.size()[0],val.size()[1],1,1) 177 | 178 | return val 179 | 180 | class Dist2LogitLayer(nn.Module): 181 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 182 | def __init__(self, chn_mid=32,use_sigmoid=True): 183 | super(Dist2LogitLayer, self).__init__() 184 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 185 | layers += [nn.LeakyReLU(0.2,True),] 186 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 187 | layers += [nn.LeakyReLU(0.2,True),] 188 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 189 | if(use_sigmoid): 190 | layers += [nn.Sigmoid(),] 191 | self.model = nn.Sequential(*layers) 192 | 193 | def forward(self,d0,d1,eps=0.1): 194 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 195 | 196 | class BCERankingLoss(nn.Module): 197 | def __init__(self, use_gpu=True, chn_mid=32): 198 | super(BCERankingLoss, self).__init__() 199 | 
self.use_gpu = use_gpu 200 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 201 | self.parameters = list(self.net.parameters()) 202 | self.loss = torch.nn.BCELoss() 203 | self.model = nn.Sequential(*[self.net]) 204 | 205 | if(self.use_gpu): 206 | self.net.cuda() 207 | 208 | def forward(self, d0, d1, judge): 209 | per = (judge+1.)/2. 210 | if(self.use_gpu): 211 | per = per.cuda() 212 | self.logit = self.net.forward(d0,d1) 213 | return self.loss(self.logit, per) 214 | 215 | class NetLinLayer(nn.Module): 216 | ''' A single linear layer which does a 1x1 conv ''' 217 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 218 | super(NetLinLayer, self).__init__() 219 | 220 | layers = [nn.Dropout(),] if(use_dropout) else [] 221 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 222 | self.model = nn.Sequential(*layers) 223 | 224 | 225 | # L2, DSSIM metrics 226 | class FakeNet(nn.Module): 227 | def __init__(self, use_gpu=True, colorspace='Lab'): 228 | super(FakeNet, self).__init__() 229 | self.use_gpu = use_gpu 230 | self.colorspace=colorspace 231 | 232 | class L2(FakeNet): 233 | 234 | def forward(self, in0, in1): 235 | assert(in0.size()[0]==1) # currently only supports batchSize 1 236 | 237 | if(self.colorspace=='RGB'): 238 | (N,C,X,Y) = in0.size() 239 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 240 | return value 241 | elif(self.colorspace=='Lab'): 242 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 243 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 244 | ret_var = Variable( torch.Tensor((value,) ) ) 245 | if(self.use_gpu): 246 | ret_var = ret_var.cuda() 247 | return ret_var 248 | 249 | class DSSIM(FakeNet): 250 | 251 | def forward(self, in0, in1): 252 | assert(in0.size()[0]==1) # currently only supports batchSize 1 253 | 254 | if(self.colorspace=='RGB'): 255 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 256 | elif(self.colorspace=='Lab'): 257 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 258 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 259 | ret_var = Variable( torch.Tensor((value,) ) ) 260 | if(self.use_gpu): 261 | ret_var = ret_var.cuda() 262 | return ret_var 263 | 264 | def print_network(net): 265 | num_params = 0 266 | for param in net.parameters(): 267 | num_params += param.numel() 268 | print('Network',net) 269 | print('Total number of parameters: %d' % num_params) 270 | -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/models/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | sys.path.append('..') 6 | sys.path.append('.') 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | import os 11 | from collections import OrderedDict 12 | from torch.autograd import Variable 13 | import itertools 14 | from .base_model import BaseModel 15 | from scipy.ndimage import zoom 16 | import fractions 17 | import functools 18 | import skimage.transform 19 | from IPython import embed 20 | 21 | from . 
import networks_basic as networks 22 | from third_part.PerceptualSimilarity.util import util 23 | # from util import util 24 | 25 | class DistModel(BaseModel): 26 | def name(self): 27 | return self.model_name 28 | 29 | def initialize(self, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'): 30 | ''' 31 | INPUTS 32 | model - ['net-lin'] for linearly calibrated network 33 | ['net'] for off-the-shelf network 34 | ['L2'] for L2 distance in Lab colorspace 35 | ['SSIM'] for ssim in RGB colorspace 36 | net - ['squeeze','alex','vgg'] 37 | model_path - if None, will look in weights/[NET_NAME].pth 38 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 39 | use_gpu - bool - whether or not to use a GPU 40 | printNet - bool - whether or not to print network architecture out 41 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 42 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 43 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 44 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 45 | is_train - bool - [True] for training mode 46 | lr - float - initial learning rate 47 | beta1 - float - initial momentum term for adam 48 | version - 0.1 for latest, 0.0 was original 49 | ''' 50 | BaseModel.initialize(self, use_gpu=use_gpu) 51 | 52 | self.model = model 53 | self.net = net 54 | self.use_gpu = use_gpu 55 | self.is_train = is_train 56 | self.spatial = spatial 57 | self.spatial_shape = spatial_shape 58 | self.spatial_order = spatial_order 59 | self.spatial_factor = spatial_factor 60 | 61 | self.model_name = '%s [%s]'%(model,net) 62 | if(self.model == 'net-lin'): # pretrained net + linear layer 63 | self.net = networks.PNetLin(use_gpu=use_gpu,pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | # model_path = './PerceptualSimilarity/weights/v%s/%s.pth'%(version,net) 70 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net))) 71 | 72 | if(not is_train): 73 | print('Loading model from: %s'%model_path) 74 | self.net.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 75 | 76 | elif(self.model=='net'): # pretrained network 77 | assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks' 78 | self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net) 79 | self.is_fake_net = True 80 | elif(self.model in ['L2','l2']): 81 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 82 | self.model_name = 'L2' 83 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 84 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 85 | self.model_name = 'SSIM' 86 | else: 87 | raise ValueError("Model [%s] not recognized." 
% self.model) 88 | 89 | self.parameters = list(self.net.parameters()) 90 | 91 | if self.is_train: # training mode 92 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 93 | self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu) 94 | self.parameters+=self.rankLoss.parameters 95 | self.lr = lr 96 | self.old_lr = lr 97 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 98 | else: # test mode 99 | self.net.eval() 100 | 101 | if(printNet): 102 | print('---------- Networks initialized -------------') 103 | networks.print_network(self.net) 104 | print('-----------------------------------------------') 105 | 106 | def forward_pair(self,in1,in2,retPerLayer=False): 107 | if(retPerLayer): 108 | return self.net.forward(in1,in2, retPerLayer=True) 109 | else: 110 | return self.net.forward(in1,in2) 111 | 112 | def forward(self, in0, in1, retNumpy=True): 113 | ''' Function computes the distance between image patches in0 and in1 114 | INPUTS 115 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 116 | retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array 117 | OUTPUT 118 | computed distances between in0 and in1 119 | ''' 120 | 121 | self.input_ref = in0 122 | self.input_p0 = in1 123 | 124 | if(self.use_gpu): 125 | self.input_ref = self.input_ref.cuda() 126 | self.input_p0 = self.input_p0.cuda() 127 | 128 | self.var_ref = Variable(self.input_ref,requires_grad=True) 129 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 130 | 131 | self.d0 = self.forward_pair(self.var_ref, self.var_p0) 132 | self.loss_total = self.d0 133 | 134 | def convert_output(d0): 135 | if(retNumpy): 136 | ans = d0.cpu().data.numpy() 137 | if not self.spatial: 138 | ans = ans.flatten() 139 | else: 140 | assert(ans.shape[0] == 1 and len(ans.shape) == 4) 141 | return ans[0,...].transpose([1, 2, 0]) # Reshape to usual numpy image format: (height, width, channels) 142 | return ans 143 | else: 144 | return d0 145 | 146 | if self.spatial: 147 | L = [convert_output(x) for x in self.d0] 148 | spatial_shape = self.spatial_shape 149 | if spatial_shape is None: 150 | if(self.spatial_factor is None): 151 | spatial_shape = (in0.size()[2],in0.size()[3]) 152 | else: 153 | spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor) 154 | 155 | L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L] 156 | 157 | L = np.mean(np.concatenate(L, 2) * len(L), 2) 158 | return L 159 | else: 160 | return convert_output(self.d0) 161 | 162 | # ***** TRAINING FUNCTIONS ***** 163 | def optimize_parameters(self): 164 | self.forward_train() 165 | self.optimizer_net.zero_grad() 166 | self.backward_train() 167 | self.optimizer_net.step() 168 | self.clamp_weights() 169 | 170 | def clamp_weights(self): 171 | for module in self.net.modules(): 172 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 173 | module.weight.data = torch.clamp(module.weight.data,min=0) 174 | 175 | def set_input(self, data): 176 | self.input_ref = data['ref'] 177 | self.input_p0 = data['p0'] 178 | self.input_p1 = data['p1'] 179 | self.input_judge = data['judge'] 180 | 181 | if(self.use_gpu): 182 | self.input_ref = self.input_ref.cuda() 183 | self.input_p0 = self.input_p0.cuda() 184 | self.input_p1 = self.input_p1.cuda() 185 | self.input_judge = self.input_judge.cuda() 186 | 187 | self.var_ref = Variable(self.input_ref,requires_grad=True) 188 | self.var_p0 = 
Variable(self.input_p0,requires_grad=True) 189 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 190 | 191 | def forward_train(self): # run forward pass 192 | self.d0 = self.forward_pair(self.var_ref, self.var_p0) 193 | self.d1 = self.forward_pair(self.var_ref, self.var_p1) 194 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 195 | 196 | # var_judge 197 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 198 | 199 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 200 | return self.loss_total 201 | 202 | def backward_train(self): 203 | torch.mean(self.loss_total).backward() 204 | 205 | def compute_accuracy(self,d0,d1,judge): 206 | ''' d0, d1 are Variables, judge is a Tensor ''' 207 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() 208 | judge_per = judge.cpu().numpy().flatten() 209 | return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) 210 | 211 | def get_current_errors(self): 212 | retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), 213 | ('acc_r', self.acc_r)]) 214 | 215 | for key in retDict.keys(): 216 | retDict[key] = np.mean(retDict[key]) 217 | 218 | return retDict 219 | 220 | def get_current_visuals(self): 221 | zoom_factor = 256/self.var_ref.data.size()[2] 222 | ref_img = util.tensor2im(self.var_ref.data) 223 | p0_img = util.tensor2im(self.var_p0.data) 224 | p1_img = util.tensor2im(self.var_p1.data) 225 | 226 | ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0) 227 | p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0) 228 | p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0) 229 | 230 | return OrderedDict([('ref', ref_img_vis), 231 | ('p0', p0_img_vis), 232 | ('p1', p1_img_vis)]) 233 | 234 | def save(self, path, label): 235 | if(self.use_gpu): 236 | self.save_network(self.net.module, path, '', label) 237 | else: 238 | self.save_network(self.net, path, '', label) 239 | self.save_network(self.rankLoss.net, path, 'rank', label) 240 | 241 | def update_learning_rate(self,nepoch_decay): 242 | lrd = self.lr / nepoch_decay 243 | lr = self.old_lr - lrd 244 | for param_group in self.optimizer_net.param_groups: 245 | param_group['lr'] = lr 246 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) 247 | self.old_lr = lr 248 | 249 | 250 | 251 | def score_2afc_dataset(data_loader,func): 252 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 253 | distance function 'func' in dataset 'data_loader' 254 | INPUTS 255 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 256 | func - callable distance function - calling d=func(in0,in1) should take 2 257 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 258 | OUTPUTS 259 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 260 | [1] - dictionary with following elements 261 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 262 | gts - N array in [0,1], preferred patch selected by human evaluators 263 | (closer to "0" for left patch p0, "1" for right patch p1, 264 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 265 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 266 | CONSTS 267 | N - number of test triplets in data_loader 268 | ''' 269 | 270 | d0s = [] 271 | d1s = [] 272 | gts = [] 273 | 274 | # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__()) 275 | for (i,data) in enumerate(data_loader.load_data()): 276 | d0s+=func(data['ref'],data['p0']).tolist() 277 | d1s+=func(data['ref'],data['p1']).tolist() 278 | gts+=data['judge'].cpu().numpy().flatten().tolist() 279 | # bar.update(i) 280 | 281 | d0s = np.array(d0s) 282 | d1s = np.array(d1s) 283 | gts = np.array(gts) 284 | scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5 285 | 286 | return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores)) -------------------------------------------------------------------------------- /third_part/PerceptualSimilarity/util/util.py: -------------------------------------------------------------------------------- 157 | if(count > 0): 158 | mean = mean / count 159 | print(name) 160 | print(mean) 161 | 162 | def grab_patch(img_in, P, yy, xx): 163 | return img_in[yy:yy+P,xx:xx+P,:] 164 | 165 | def load_image(path): 166 | if(path[-3:] == 'dng'): 167 | import rawpy 168 | with rawpy.imread(path) as raw: 169 | img = raw.postprocess() 170 | # img = plt.imread(path) 171 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 172 | import cv2 173 | return cv2.imread(path)[:,:,::-1] 174 | else: 175 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 176 | 177 | return img 178 | 179 | 180 | def resize_image(img, max_size=256): 181 | [Y, X] = img.shape[:2] 182 | 183 | # resize 184 | max_dim = max([Y, X]) 185 | zoom_factor = 1. 
* max_size / max_dim 186 | img = zoom(img, [zoom_factor, zoom_factor, 1]) 187 | 188 | return img 189 | 190 | def resize_image_zoom(img, zoom_factor=1., order=3): 191 | if(zoom_factor==1): 192 | return img 193 | else: 194 | return zoom(img, [zoom_factor, zoom_factor, 1], order=order) 195 | 196 | def save_image(image_numpy, image_path, ): 197 | image_pil = Image.fromarray(image_numpy) 198 | image_pil.save(image_path) 199 | 200 | 201 | def prep_display_image(img, dtype='uint8'): 202 | if(dtype == 'uint8'): 203 | return np.clip(img, 0, 255).astype('uint8') 204 | else: 205 | return np.clip(img, 0, 1.) 206 | 207 | 208 | def info(object, spacing=10, collapse=1): 209 | """Print methods and doc strings. 210 | Takes module, class, list, dictionary, or string.""" 211 | methodList = [ 212 | e for e in dir(object) if isinstance( 213 | getattr( 214 | object, 215 | e), 216 | collections.Callable)] 217 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 218 | print("\n".join(["%s %s" % 219 | (method.ljust(spacing), 220 | processFunc(str(getattr(object, method).__doc__))) 221 | for method in methodList])) 222 | 223 | 224 | def varname(p): 225 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 226 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 227 | if m: 228 | return m.group(1) 229 | 230 | 231 | def print_numpy(x, val=True, shp=False): 232 | x = x.astype(np.float64) 233 | if shp: 234 | print('shape,', x.shape) 235 | if val: 236 | x = x.flatten() 237 | print( 238 | 'mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % 239 | (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 240 | 241 | 242 | def mkdirs(paths): 243 | if isinstance(paths, list) and not isinstance(paths, str): 244 | for path in paths: 245 | mkdir(path) 246 | else: 247 | mkdir(paths) 248 | 249 | 250 | def mkdir(path): 251 | if not os.path.exists(path): 252 | os.makedirs(path) 253 | 254 | 255 | def rgb2lab(input): 256 | from skimage import color 257 | return color.rgb2lab(input / 255.) 
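# Usage sketch for montage() below (shapes illustrative): given a stack
#   imgs = np.zeros((64, 64, 3, 10), dtype='uint8')   # Y x X x M x N
#   grid = montage(imgs, PAD=2)
# returns one tiled image; with MM/NN unset, the grid dimensions are chosen
# so that the column/row ratio approximates RATIO.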
258 | 259 | 260 | def montage( 261 | imgs, 262 | PAD=5, 263 | RATIO=16 / 9., 264 | EXTRA_PAD=( 265 | False, 266 | False), 267 | MM=-1, 268 | NN=-1, 269 | primeDir=0, 270 | verbose=False, 271 | returnGridPos=False, 272 | backClr=np.array( 273 | (0, 274 | 0, 275 | 0))): 276 | # INPUTS 277 | # imgs YxXxMxN or YxXxN 278 | # PAD scalar number of pixels in between 279 | # RATIO scalar target ratio of cols/rows 280 | # MM scalar # rows, if specified, overrides RATIO 281 | # NN scalar # columns, if specified, overrides RATIO 282 | # primeDir scalar 0 for top-to-bottom, 1 for left-to-right 283 | # OUTPUTS 284 | # mont_imgs MM*Y x NN*X x M big image with everything montaged 285 | # def montage(imgs, PAD=5, RATIO=16/9., MM=-1, NN=-1, primeDir=0, 286 | # verbose=False, forceFloat=False): 287 | if(imgs.ndim == 3): 288 | toExp = True 289 | imgs = imgs[:, :, np.newaxis, :] 290 | else: 291 | toExp = False 292 | 293 | Y = imgs.shape[0] 294 | X = imgs.shape[1] 295 | M = imgs.shape[2] 296 | N = imgs.shape[3] 297 | 298 | PADS = np.array((PAD)) 299 | if(PADS.flatten().size == 1): 300 | PADY = PADS 301 | PADX = PADS 302 | else: 303 | PADY = PADS[0] 304 | PADX = PADS[1] 305 | 306 | if(MM == -1 and NN == -1): 307 | NN = np.ceil(np.sqrt(1.0 * N * RATIO)) 308 | MM = np.ceil(1.0 * N / NN) 309 | NN = np.ceil(1.0 * N / MM) 310 | elif(MM == -1): 311 | MM = np.ceil(1.0 * N / NN) 312 | elif(NN == -1): 313 | NN = np.ceil(1.0 * N / MM) 314 | 315 | if(primeDir == 0): # write top-to-bottom 316 | [grid_mm, grid_nn] = np.meshgrid( 317 | np.arange(MM, dtype='uint'), np.arange(NN, dtype='uint')) 318 | elif(primeDir == 1): # write left-to-right 319 | [grid_nn, grid_mm] = np.meshgrid( 320 | np.arange(NN, dtype='uint'), np.arange(MM, dtype='uint')) 321 | 322 | grid_mm = np.uint(grid_mm.flatten()[0:N]) 323 | grid_nn = np.uint(grid_nn.flatten()[0:N]) 324 | 325 | EXTRA_PADY = EXTRA_PAD[0] * PADY 326 | EXTRA_PADX = EXTRA_PAD[0] * PADX 327 | 328 | # mont_imgs = np.zeros(((Y+PAD)*MM-PAD, (X+PAD)*NN-PAD, M), dtype=use_dtype) 329 | mont_imgs = np.zeros( 330 | (np.uint( 331 | (Y + PADY) * MM - PADY + EXTRA_PADY), 332 | np.uint( 333 | (X + PADX) * NN - PADX + EXTRA_PADX), 334 | M), 335 | dtype=imgs.dtype) 336 | mont_imgs = mont_imgs + \ 337 | backClr.flatten()[np.newaxis, np.newaxis, :].astype(mont_imgs.dtype) 338 | 339 | for ii in np.random.permutation(N): 340 | # print imgs[:,:,:,ii].shape 341 | # mont_imgs[grid_mm[ii]*(Y+PAD):(grid_mm[ii]*(Y+PAD)+Y), grid_nn[ii]*(X+PAD):(grid_nn[ii]*(X+PAD)+X),:] 342 | mont_imgs[np.uint(grid_mm[ii] * 343 | (Y + 344 | PADY)):np.uint((grid_mm[ii] * 345 | (Y + 346 | PADY) + 347 | Y)), np.uint(grid_nn[ii] * 348 | (X + 349 | PADX)):np.uint((grid_nn[ii] * 350 | (X + 351 | PADX) + 352 | X)), :] = imgs[:, :, :, ii] 353 | 354 | if(M == 1): 355 | imgs = imgs.reshape(imgs.shape[0], imgs.shape[1], imgs.shape[3]) 356 | 357 | if(toExp): 358 | mont_imgs = mont_imgs[:, :, 0] 359 | 360 | if(returnGridPos): 361 | # return (mont_imgs,np.concatenate((grid_mm[:,:,np.newaxis]*(Y+PAD), 362 | # grid_nn[:,:,np.newaxis]*(X+PAD)),axis=2)) 363 | return (mont_imgs, np.concatenate( 364 | (grid_mm[:, np.newaxis] * (Y + PADY), grid_nn[:, np.newaxis] * (X + PADX)), axis=1)) 365 | # return (mont_imgs, (grid_mm,grid_nn)) 366 | else: 367 | return mont_imgs 368 | 369 | class zeroClipper(object): 370 | def __init__(self, frequency=1): 371 | self.frequency = frequency 372 | 373 | def __call__(self, module): 374 | embed() 375 | if hasattr(module, 'weight'): 376 | # module.weight.data = torch.max(module.weight.data, 0) 377 | module.weight.data = 
torch.max(module.weight.data, 0) + 100 378 | 379 | def flatten_nested_list(nested_list): 380 | # only works for list of list 381 | accum = [] 382 | for sublist in nested_list: 383 | for item in sublist: 384 | accum.append(item) 385 | return accum 386 | 387 | def read_file(in_path,list_lines=False): 388 | agg_str = '' 389 | f = open(in_path,'r') 390 | cur_line = f.readline() 391 | while(cur_line!=''): 392 | agg_str+=cur_line 393 | cur_line = f.readline() 394 | f.close() 395 | if(list_lines==False): 396 | return agg_str.replace('\n','') 397 | else: 398 | line_list = agg_str.split('\n') 399 | ret_list = [] 400 | for item in line_list: 401 | if(item!=''): 402 | ret_list.append(item) 403 | return ret_list 404 | 405 | def read_csv_file_as_text(in_path): 406 | agg_str = [] 407 | f = open(in_path,'r') 408 | cur_line = f.readline() 409 | while(cur_line!=''): 410 | agg_str.append(cur_line) 411 | cur_line = f.readline() 412 | f.close() 413 | return agg_str 414 | 415 | def random_swap(obj0,obj1): 416 | if(np.random.rand() < .5): 417 | return (obj0,obj1,0) 418 | else: 419 | return (obj1,obj0,1) 420 | 421 | def voc_ap(rec, prec, use_07_metric=False): 422 | """ ap = voc_ap(rec, prec, [use_07_metric]) 423 | Compute VOC AP given precision and recall. 424 | If use_07_metric is true, uses the 425 | VOC 07 11 point method (default:False). 426 | """ 427 | if use_07_metric: 428 | # 11 point metric 429 | ap = 0. 430 | for t in np.arange(0., 1.1, 0.1): 431 | if np.sum(rec >= t) == 0: 432 | p = 0 433 | else: 434 | p = np.max(prec[rec >= t]) 435 | ap = ap + p / 11. 436 | else: 437 | # correct AP calculation 438 | # first append sentinel values at the end 439 | mrec = np.concatenate(([0.], rec, [1.])) 440 | mpre = np.concatenate(([0.], prec, [0.])) 441 | 442 | # compute the precision envelope 443 | for i in range(mpre.size - 1, 0, -1): 444 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 445 | 446 | # to calculate area under PR curve, look for points 447 | # where X axis (recall) changes value 448 | i = np.where(mrec[1:] != mrec[:-1])[0] 449 | 450 | # and sum (\Delta recall) * prec 451 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 452 | return ap 453 | -------------------------------------------------------------------------------- /generators/base_function.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Function 8 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 9 | 10 | 11 | class LayerNorm2d(nn.Module): 12 | def __init__(self, n_out, affine=True): 13 | super(LayerNorm2d, self).__init__() 14 | self.n_out = n_out 15 | self.affine = affine 16 | 17 | if self.affine: 18 | self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) 19 | self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) 20 | 21 | def forward(self, x): 22 | normalized_shape = x.size()[1:] 23 | if self.affine: 24 | return F.layer_norm(x, normalized_shape, \ 25 | self.weight.expand(normalized_shape), 26 | self.bias.expand(normalized_shape)) 27 | 28 | else: 29 | return F.layer_norm(x, normalized_shape) 30 | 31 | class ADAINHourglass(nn.Module): 32 | def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect): 33 | super(ADAINHourglass, self).__init__() 34 | self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect) 35 | self.decoder = ADAINDecoder(pose_nc, ngf, 
--------------------------------------------------------------------------------
/generators/base_function.py:
--------------------------------------------------------------------------------
import sys
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm


class LayerNorm2d(nn.Module):
    """Layer normalization over the (C, H, W) dimensions of a 4-D tensor."""
    def __init__(self, n_out, affine=True):
        super(LayerNorm2d, self).__init__()
        self.n_out = n_out
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
            return F.layer_norm(x, normalized_shape,
                                self.weight.expand(normalized_shape),
                                self.bias.expand(normalized_shape))
        else:
            return F.layer_norm(x, normalized_shape)


class ADAINHourglass(nn.Module):
    """Encoder-decoder (hourglass) network modulated by ADAIN at every block."""
    def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers,
                 decoder_layers, nonlinearity, use_spect):
        super(ADAINHourglass, self).__init__()
        self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f,
                                    encoder_layers, nonlinearity, use_spect)
        self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers,
                                    decoder_layers, True, nonlinearity,
                                    use_spect)
        self.output_nc = self.decoder.output_nc

    def forward(self, x, z):
        return self.decoder(self.encoder(x, z), z)


class ADAINEncoder(nn.Module):
    def __init__(self, image_nc, pose_nc, ngf, img_f, layers,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoder, self).__init__()
        self.layers = layers
        self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1,
                                     padding=3)
        for i in range(layers):
            in_channels = min(ngf * (2**i), img_f)
            out_channels = min(ngf * (2**(i+1)), img_f)
            model = ADAINEncoderBlock(in_channels, out_channels, pose_nc,
                                      nonlinearity, use_spect)
            setattr(self, 'encoder' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x, z):
        # return the feature maps of every scale for the decoder's skips
        out = self.input_layer(x)
        out_list = [out]
        for i in range(self.layers):
            model = getattr(self, 'encoder' + str(i))
            out = model(out, z)
            out_list.append(out)
        return out_list


class ADAINDecoder(nn.Module):
    """Decoder that mirrors ADAINEncoder, with optional skip connections."""
    def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers,
                 skip_connect=True, nonlinearity=nn.LeakyReLU(),
                 use_spect=False):
        super(ADAINDecoder, self).__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.skip_connect = skip_connect
        use_transpose = True

        for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
            in_channels = min(ngf * (2**(i+1)), img_f)
            # concatenated skip features double the input channels
            in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
            out_channels = min(ngf * (2**i), img_f)
            model = ADAINDecoderBlock(in_channels, out_channels, out_channels,
                                      pose_nc, use_transpose, nonlinearity,
                                      use_spect)
            setattr(self, 'decoder' + str(i), model)

        self.output_nc = out_channels*2 if self.skip_connect else out_channels

    def forward(self, x, z):
        out = x.pop() if self.skip_connect else x
        for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
            model = getattr(self, 'decoder' + str(i))
            out = model(out, z)
            out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
        return out


class ADAINEncoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, feature_nc,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(ADAINEncoderBlock, self).__init__()
        kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
        self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)

        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(output_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        x = self.conv_0(self.actvn(self.norm_0(x, z)))
        x = self.conv_1(self.actvn(self.norm_1(x, z)))
        return x
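
# Minimal shape-check sketch for the ADAIN hourglass (editor's addition, not
# part of the original file; the sizes below are illustrative, not the values
# from PIRender's configs).
def _demo_adain_hourglass():
    net = ADAINHourglass(image_nc=3, pose_nc=73, ngf=32, img_f=256,
                         encoder_layers=3, decoder_layers=2,
                         nonlinearity=nn.LeakyReLU(0.1), use_spect=False)
    x = torch.randn(2, 3, 64, 64)  # source image batch
    z = torch.randn(2, 73)         # driving coefficient vector
    print(net(x, z).shape)         # (2, net.output_nc, 32, 32)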

class ADAINDecoderBlock(nn.Module):
    def __init__(self, input_nc, output_nc, hidden_nc, feature_nc,
                 use_transpose=True, nonlinearity=nn.LeakyReLU(),
                 use_spect=False):
        super(ADAINDecoderBlock, self).__init__()
        # Attributes
        self.actvn = nonlinearity
        hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc

        kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        if use_transpose:
            kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1,
                         'output_padding': 1}
        else:
            kwargs_up = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        # create conv layers
        self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
        if use_transpose:
            self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
            self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
        else:
            self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
            self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
                                        nn.Upsample(scale_factor=2))
        # define normalization layers
        self.norm_0 = ADAIN(input_nc, feature_nc)
        self.norm_1 = ADAIN(hidden_nc, feature_nc)
        self.norm_s = ADAIN(input_nc, feature_nc)

    def forward(self, x, z):
        x_s = self.shortcut(x, z)
        dx = self.conv_0(self.actvn(self.norm_0(x, z)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
        out = x_s + dx
        return out

    def shortcut(self, x, z):
        x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
        return x_s


def spectral_norm(module, use_spect=True):
    """Optionally wrap a module in spectral normalization to stabilize training."""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


class ADAIN(nn.Module):
    """Adaptive instance normalization conditioned on an external feature."""
    def __init__(self, norm_nc, feature_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        nhidden = 128
        use_bias = True

        self.mlp_shared = nn.Sequential(
            nn.Linear(feature_nc, nhidden, bias=use_bias),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
        self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)

    def forward(self, x, feature):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on the feature
        feature = feature.view(feature.size(0), -1)
        actv = self.mlp_shared(feature)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        gamma = gamma.view(*gamma.size()[:2], 1, 1)
        beta = beta.view(*beta.size()[:2], 1, 1)
        out = normalized * (1 + gamma) + beta
        return out
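
# Editor's sketch (not part of the original file): ADAIN modulates an
# instance-normalized activation with a feature-conditioned affine transform,
#     out = IN(x) * (1 + gamma(f)) + beta(f),
# where gamma and beta are per-channel vectors predicted from the feature f.
def _demo_adain():
    adain = ADAIN(norm_nc=16, feature_nc=8)
    x = torch.randn(4, 16, 32, 32)
    f = torch.randn(4, 8)
    print(adain(x, f).shape)  # unchanged: (4, 16, 32, 32)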

class FineEncoder(nn.Module):
    """Image encoder of the refinement (fine) network."""
    def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineEncoder, self).__init__()
        self.layers = layers
        self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
        for i in range(layers):
            in_channels = min(ngf*(2**i), img_f)
            out_channels = min(ngf*(2**(i+1)), img_f)
            model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            setattr(self, 'down' + str(i), model)
        self.output_nc = out_channels

    def forward(self, x):
        # collect the feature maps of every scale for the decoder's skips
        x = self.first(x)
        out = [x]
        for i in range(self.layers):
            model = getattr(self, 'down' + str(i))
            x = model(x)
            out.append(x)
        return out


class FineDecoder(nn.Module):
    """Decoder of the refinement (fine) network; emits the final image."""
    def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block,
                 norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(),
                 use_spect=False):
        super(FineDecoder, self).__init__()
        self.layers = layers
        for i in range(layers)[::-1]:
            in_channels = min(ngf*(2**(i+1)), img_f)
            out_channels = min(ngf*(2**i), img_f)
            up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
            res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
            jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)

            setattr(self, 'up' + str(i), up)
            setattr(self, 'res' + str(i), res)
            setattr(self, 'jump' + str(i), jump)

        self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')

        self.output_nc = out_channels

    def forward(self, x, z):
        out = x.pop()
        for i in range(self.layers)[::-1]:
            res_model = getattr(self, 'res' + str(i))
            up_model = getattr(self, 'up' + str(i))
            jump_model = getattr(self, 'jump' + str(i))
            out = res_model(out, z)
            out = up_model(out)
            out = jump_model(x.pop()) + out
        out_image = self.final(out)
        return out_image


class FirstBlock2d(nn.Module):
    """
    Input (stem) block of the encoder: a stride-1 7x7 convolution, so the
    spatial resolution is unchanged.
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FirstBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out
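
# Editor's sketch (not part of the original file): the fine encoder returns a
# list of feature maps, which the fine decoder consumes back-to-front via
# pop(), adding a learned Jump of each skip feature at every resolution.
def _demo_fine_networks():
    enc = FineEncoder(image_nc=3, ngf=16, img_f=128, layers=2)
    dec = FineDecoder(image_nc=3, feature_nc=8, ngf=16, img_f=128,
                      layers=2, num_block=1)
    feats = enc(torch.randn(2, 3, 64, 64))  # three maps: 64, 32, 16 px
    img = dec(feats, torch.randn(2, 8))     # tanh image at input resolution
    print(img.shape)                        # (2, 3, 64, 64)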

class DownBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(DownBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        pool = nn.AvgPool2d(kernel_size=(2, 2))  # halves the spatial resolution

        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity, pool)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)

    def forward(self, x):
        out = self.model(x)
        return out


class UpBlock2d(nn.Module):
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(UpBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)

    def forward(self, x):
        # doubles the spatial resolution before the convolution
        out = self.model(F.interpolate(x, scale_factor=2))
        return out


class FineADAINResBlocks(nn.Module):
    def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlocks, self).__init__()
        self.num_block = num_block
        for i in range(num_block):
            model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
            setattr(self, 'res' + str(i), model)

    def forward(self, x, z):
        for i in range(self.num_block):
            model = getattr(self, 'res' + str(i))
            x = model(x, z)
        return x


class Jump(nn.Module):
    """Learned skip ("jump") connection applied to an encoder feature map."""
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(Jump, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)

        if norm_layer is None:
            self.model = nn.Sequential(conv, nonlinearity)
        else:
            self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out


class FineADAINResBlock2d(nn.Module):
    """
    Residual block with ADAIN normalization.
    """
    def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d,
                 nonlinearity=nn.LeakyReLU(), use_spect=False):
        super(FineADAINResBlock2d, self).__init__()
        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}

        self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
        self.norm1 = ADAIN(input_nc, feature_nc)
        self.norm2 = ADAIN(input_nc, feature_nc)
        self.actvn = nonlinearity

    def forward(self, x, z):
        dx = self.actvn(self.norm1(self.conv1(x), z))
        # chain the residual branch through dx (applying conv2 to x here, as the
        # original did, would silently discard conv1's output)
        dx = self.norm2(self.conv2(dx), z)
        out = dx + x
        return out
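
# Editor's sketch (not part of the original file): a FineADAINResBlock2d keeps
# both the spatial size and the channel count, so any number of them can be
# stacked inside FineADAINResBlocks between upsampling stages.
def _demo_fine_res_block():
    block = FineADAINResBlock2d(input_nc=32, feature_nc=8)
    x = torch.randn(2, 32, 16, 16)
    z = torch.randn(2, 8)
    print(block(x, z).shape)  # unchanged: (2, 32, 16, 16)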

class FinalBlock2d(nn.Module):
    """
    Output layer: a 7x7 convolution followed by tanh (default) or sigmoid.
    """
    def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
        super(FinalBlock2d, self).__init__()
        kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
        conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

        if tanh_or_sigmoid == 'sigmoid':
            out_nonlinearity = nn.Sigmoid()
        else:
            out_nonlinearity = nn.Tanh()

        self.model = nn.Sequential(conv, out_nonlinearity)

    def forward(self, x):
        out = self.model(x)
        return out
--------------------------------------------------------------------------------