├── Hallo_node.py ├── LICENSE ├── README.md ├── __init__.py ├── accelerate_config.yaml ├── assets ├── framework_1.jpg ├── framework_2.jpg └── wechat.jpeg ├── configs ├── inference │ └── long.yaml ├── train │ ├── stage1.yaml │ ├── stage2_long.yaml │ └── video_sr.yaml ├── unet │ ├── config.json │ └── unet.yaml └── vae │ └── config.json ├── example.png ├── hallo ├── __init__.py ├── animate │ ├── __init__.py │ ├── face_animate.py │ └── face_animate_static.py ├── basicsr │ ├── VERSION │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── data_util.py │ │ ├── gaussian_kernels.py │ │ ├── prefetch_dataloader.py │ │ ├── transforms.py │ │ └── vfhq_dataset.py │ ├── hallo_archs │ │ ├── __init__.py │ │ ├── arcface_arch.py │ │ ├── arch_util.py │ │ ├── codeformer_arch.py │ │ ├── rrdbnet_arch.py │ │ ├── vgg_arch.py │ │ └── vqgan_arch.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ ├── metrics │ │ ├── __init__.py │ │ ├── metric_util.py │ │ └── psnr_ssim.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── codeformer_temporal_model.py │ │ ├── lr_scheduler.py │ │ ├── sr_model.py │ │ └── vqgan_model.py │ ├── ops │ │ ├── __init__.py │ │ ├── dcn │ │ │ ├── __init__.py │ │ │ ├── deform_conv.py │ │ │ └── src │ │ │ │ ├── deform_conv_cuda.cpp │ │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ │ └── deform_conv_ext.cpp │ │ ├── fused_act │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ └── src │ │ │ │ ├── fused_bias_act.cpp │ │ │ │ └── fused_bias_act_kernel.cu │ │ └── upfirdn2d │ │ │ ├── __init__.py │ │ │ ├── src │ │ │ ├── upfirdn2d.cpp │ │ │ └── upfirdn2d_kernel.cu │ │ │ └── upfirdn2d.py │ ├── setup.py │ ├── train.py │ ├── utils │ │ ├── __init__.py │ │ ├── dist_util.py │ │ ├── download_util.py │ │ ├── file_client.py │ │ ├── img_util.py │ │ ├── lmdb_util.py │ │ ├── logger.py │ │ ├── matlab_functions.py │ │ ├── misc.py │ │ ├── options.py │ │ ├── realesrgan_utils.py │ │ ├── registry.py │ │ └── video_util.py │ └── version.py ├── 
datasets │ ├── __init__.py │ ├── audio_processor.py │ ├── image_processor.py │ ├── mask_image.py │ └── talk_video.py ├── facelib │ ├── detection │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── matlab_cp2tform.py │ │ ├── retinaface │ │ │ ├── retinaface.py │ │ │ ├── retinaface_net.py │ │ │ └── retinaface_utils.py │ │ └── yolov5face │ │ │ ├── __init__.py │ │ │ ├── face_detector.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── experimental.py │ │ │ ├── yolo.py │ │ │ ├── yolov5l.yaml │ │ │ └── yolov5n.yaml │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── autoanchor.py │ │ │ ├── datasets.py │ │ │ ├── extract_ckpt.py │ │ │ ├── general.py │ │ │ └── torch_utils.py │ ├── parsing │ │ ├── __init__.py │ │ ├── bisenet.py │ │ ├── parsenet.py │ │ └── resnet.py │ └── utils │ │ ├── __init__.py │ │ ├── face_restoration_helper.py │ │ ├── face_utils.py │ │ └── misc.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── audio_proj.py │ ├── face_locator.py │ ├── image_proj.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ ├── unet_3d_blocks.py │ └── wav2vec.py ├── utils │ ├── __init__.py │ ├── config.py │ └── util.py └── video_sr.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── app.py ├── data_preprocess.py ├── extract_meta_info_stage1.py ├── extract_meta_info_stage2.py ├── inference_long.py ├── train_stage1.py └── train_stage2_long.py ├── utils.py └── workflow.json /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 smthemex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_Hallo2 2 | [Hallo2](https://github.com/fudan-generative-vision/hallo2): Long-Duration and High-Resolution Audio-driven Portrait Image Animation, 3 | 4 | ## Updates: 5 | **2024/10/22** 6 | * 修复task地址绝对引用可能出现的问题。(fix bug) 7 | * Currently, only square 512 images and 2x magnification are supported(目前仅支持方形512图像和2倍放大,官方模型和方法所限) 8 | * input audio must be *.wav (输入的音频格式只能是wav,采样用的是16000,你要用高保真的自己合成就是了,别矫情.) 
9 | 10 | 1.Installation 11 | ----- 12 | In the ./ComfyUI/custom_nodes directory, run the following: 13 | ```bash 14 | git clone https://github.com/smthemex/ComfyUI_Hallo2 15 | 16 | ``` 17 | 2.Requirements 18 | ---- 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | If you are using the portable (embedded-Python) ComfyUI, run the following from your "X:\ComfyUI_windows\python_embeded" directory (便携包的comfyUI用户在python_embeded目录下用以下命令安装): 23 | ``` 24 | python -m pip install -r requirements.txt 25 | ``` 26 | Possible installation difficulties (可能会遇到的安装难题): 27 | * 2.1 audio-separator 28 | * 2.1.1 If `pip install audio-separator` fails while building the diffq wheel, make sure [visual-cpp-build-tools](https://visualstudio.microsoft.com/zh-hans/visual-cpp-build-tools ) is installed on Windows. 29 | 安装audio-separator可能会出现vs的报错,确认你安装了[visual-cpp-build-tools](https://visualstudio.microsoft.com/zh-hans/visual-cpp-build-tools ) 30 | * 2.1.2 If building diffq still fails even though visual-cpp-build-tools is installed, and you are using the ComfyUI portable package or the Akiba (秋叶) package, add the interpreter directories (e.g. X:\ComfyUI_windows\python_embeded and X:\ComfyUI_windows\python_embeded\Scripts) to the Windows Path system variable.
31 | 虽然有‘visual-cpp-build-tools’,但是还是失败(diffq),如果使用的是comfyUI便携包,或者秋叶包,请将解释器地址加入windows的系统变量里,Linux用户,你都用Linux了,就不用我教了吧,window的做法是,将X:\ComfyUI_windows\python_embeded 和X:\ComfyUI_windows\python_embeded\Scripts 2个地址加入Path系统变量里。 32 | * 2.2 ffmpeg 33 | * 2.3 If a module is missing, remove the leading `#` from its line in requirements.txt and run pip install again. 34 | * 2.4 onnx 错误 (onnx errors) 35 | 少了啥,就去掉#号,重新安装 36 | 37 | 3 Checkpoints 38 | ---- 39 | 所有模型下载地址(all checkpoints):[huggingface](https://huggingface.co/fudan-generative-ai/hallo2/tree/main) 40 | 41 | ``` 42 | ├── ComfyUI/models/Hallo/ 43 | |-- audio_separator/ 44 | | |-- download_checks.json 45 | | |-- mdx_model_data.json 46 | | |-- vr_model_data.json 47 | | `-- Kim_Vocal_2.onnx 48 | |-- face_analysis/ 49 | | `-- models/ 50 | | |-- face_landmarker_v2_with_blendshapes.task # face landmarker model from mediapipe 51 | | |-- 1k3d68.onnx 52 | | |-- 2d106det.onnx 53 | | |-- genderage.onnx 54 | | |-- glintr100.onnx 55 | | `-- scrfd_10g_bnkps.onnx 56 | |-- facelib 57 | | |-- detection_mobilenet0.25_Final.pth 58 | | |-- detection_Resnet50_Final.pth 59 | | |-- parsing_parsenet.pth 60 | | |-- yolov5l-face.pth 61 | | `-- yolov5n-face.pth 62 | |-- hallo2 63 | | |-- net_g.pth 64 | | `-- net.pth 65 | |-- motion_module/ 66 | | `-- mm_sd_v15_v2.ckpt 67 | `-- wav2vec/ 68 | `-- wav2vec2-base-960h/ 69 | |-- config.json 70 | |-- feature_extractor_config.json 71 | |-- model.safetensors 72 | |-- preprocessor_config.json 73 | |-- special_tokens_map.json 74 | |-- tokenizer_config.json 75 | `-- vocab.json 76 | ``` 77 | Normal checkpoints 78 | ``` 79 | ├── ComfyUI/models/ 80 | |-- upscale_models/ 81 | | `-- RealESRGAN_x2plus.pth 82 | |-- vae/ 83 | | `-- vae-ft-mse-840000-ema-pruned.safetensors 84 | |-- checkpoints/ 85 | | `-- v1-5-pruned-emaonly.safetensors # any SD1.5 checkpoint; if it fails to load, try another one 86 | 87 | ``` 88 | 4 Example 89 | ---- 90 | ![](https://github.com/smthemex/ComfyUI_Hallo2/blob/main/example.png) 91 | 92 | 5 Citation 93 | ------ 94 | hallo2 95 | ```
96 | @misc{cui2024hallo2, 97 | title={Hallo2: Long-Duration and High-Resolution Audio-driven Portrait Image Animation}, 98 | author={Jiahao Cui and Hui Li and Yao Yao and Hao Zhu and Hanlin Shang and Kaihui Cheng and Hang Zhou and Siyu Zhu and Jingdong Wang}, 99 | year={2024}, 100 | eprint={2410.07718}, 101 | archivePrefix={arXiv}, 102 | primaryClass={cs.CV} 103 | } 104 | ``` 105 | 106 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .Hallo_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 5 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: "no" 12 | main_training_function: main 13 | mixed_precision: "fp16" 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /assets/framework_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/assets/framework_1.jpg -------------------------------------------------------------------------------- /assets/framework_2.jpg: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/assets/framework_2.jpg -------------------------------------------------------------------------------- /assets/wechat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/assets/wechat.jpeg -------------------------------------------------------------------------------- /configs/inference/long.yaml: -------------------------------------------------------------------------------- 1 | source_image: ./examples/reference_images/1.jpg 2 | driving_audio: ./examples/driving_audios/1.wav 3 | 4 | weight_dtype: fp16 5 | 6 | data: 7 | n_motion_frames: 2 8 | n_sample_frames: 16 9 | source_image: 10 | width: 512 11 | height: 512 12 | driving_audio: 13 | sample_rate: 16000 14 | export_video: 15 | fps: 25 16 | 17 | inference_steps: 40 18 | cfg_scale: 3.5 19 | 20 | use_mask: true 21 | mask_rate: 0.25 22 | use_cut: true 23 | 24 | audio_ckpt_dir: pretrained_models/hallo2 25 | 26 | 27 | save_path: ./output_long/debug/ 28 | cache_path: ./.cache 29 | 30 | base_model_path: ./pretrained_models/stable-diffusion-v1-5 31 | 32 | motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt 33 | 34 | face_analysis: 35 | model_path: ./pretrained_models/face_analysis 36 | 37 | wav2vec: 38 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h 39 | features: all 40 | 41 | audio_separator: 42 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx 43 | 44 | vae: 45 | model_path: ./pretrained_models/sd-vae-ft-mse 46 | 47 | face_expand_ratio: 1.2 48 | pose_weight: 1.0 49 | face_weight: 1.0 50 | lip_weight: 1.0 51 | 52 | unet_additional_kwargs: 53 | use_inflated_groupnorm: true 54 | unet_use_cross_frame_attention: false 55 | unet_use_temporal_attention: false 56 | use_motion_module: true 57 | use_audio_module: true 58 | 
motion_module_resolutions: 59 | - 1 60 | - 2 61 | - 4 62 | - 8 63 | motion_module_mid_block: true 64 | motion_module_decoder_only: false 65 | motion_module_type: Vanilla 66 | motion_module_kwargs: 67 | num_attention_heads: 8 68 | num_transformer_block: 1 69 | attention_block_types: 70 | - Temporal_Self 71 | - Temporal_Self 72 | temporal_position_encoding: true 73 | temporal_position_encoding_max_len: 32 74 | temporal_attention_dim_div: 1 75 | audio_attention_dim: 768 76 | stack_enable_blocks_name: 77 | - "up" 78 | - "down" 79 | - "mid" 80 | stack_enable_blocks_depth: [0,1,2,3] 81 | 82 | 83 | enable_zero_snr: true 84 | 85 | noise_scheduler_kwargs: 86 | beta_start: 0.00085 87 | beta_end: 0.012 88 | beta_schedule: "linear" 89 | clip_sample: false 90 | steps_offset: 1 91 | ### Zero-SNR params 92 | prediction_type: "v_prediction" 93 | rescale_betas_zero_snr: True 94 | timestep_spacing: "trailing" 95 | 96 | sampler: DDIM 97 | -------------------------------------------------------------------------------- /configs/train/stage1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 8 3 | train_width: 512 4 | train_height: 512 5 | meta_paths: 6 | - "./data/HDTF_meta.json" 7 | # Margin of frame indexes between ref and tgt images 8 | sample_margin: 30 9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: "no" 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: False 15 | max_train_steps: 30000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1.0e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: "constant" 22 | 23 | # optimizer 24 | use_8bit_adam: False 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 500 32 | 33 | noise_scheduler_kwargs: 34 | num_train_timesteps: 1000 35 | beta_start: 0.00085 36 | beta_end: 0.012 37 | beta_schedule: "scaled_linear" 38 | 
steps_offset: 1 39 | clip_sample: false 40 | 41 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/" 42 | vae_model_path: "./pretrained_models/sd-vae-ft-mse" 43 | face_analysis_model_path: "./pretrained_models/face_analysis" 44 | 45 | weight_dtype: "fp16" # [fp16, fp32] 46 | uncond_ratio: 0.1 47 | noise_offset: 0.05 48 | snr_gamma: 5.0 49 | enable_zero_snr: True 50 | face_locator_pretrained: False 51 | 52 | seed: 42 53 | resume_from_checkpoint: "latest" 54 | checkpointing_steps: 500 55 | exp_name: "stage1" 56 | output_dir: "./exp_output" 57 | 58 | ref_image_paths: 59 | - "examples/reference_images/1.jpg" 60 | 61 | mask_image_paths: 62 | - "examples/masks/1.png" 63 | 64 | -------------------------------------------------------------------------------- /configs/train/stage2_long.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 4 3 | val_bs: 1 4 | train_width: 512 5 | train_height: 512 6 | fps: 25 7 | sample_rate: 16000 8 | n_motion_frames: 2 9 | n_sample_frames: 14 10 | audio_margin: 2 11 | train_meta_paths: 12 | - "./data/hdtf_split_stage2.json" 13 | 14 | wav2vec_config: 15 | audio_type: "vocals" # audio vocals 16 | model_scale: "base" # base large 17 | features: "all" # last avg all 18 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h 19 | audio_separator: 20 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx 21 | face_expand_ratio: 1.2 22 | 23 | solver: 24 | gradient_accumulation_steps: 1 25 | mixed_precision: "no" 26 | enable_xformers_memory_efficient_attention: True 27 | gradient_checkpointing: True 28 | max_train_steps: 30000 29 | max_grad_norm: 1.0 30 | # lr 31 | learning_rate: 1e-5 32 | scale_lr: False 33 | lr_warmup_steps: 1 34 | lr_scheduler: "constant" 35 | 36 | # optimizer 37 | use_8bit_adam: True 38 | adam_beta1: 0.9 39 | adam_beta2: 0.999 40 | adam_weight_decay: 1.0e-2 41 | adam_epsilon: 1.0e-8 42 | 43 | val: 44 | validation_steps: 1000 45 | 46 | 
noise_scheduler_kwargs: 47 | num_train_timesteps: 1000 48 | beta_start: 0.00085 49 | beta_end: 0.012 50 | beta_schedule: "linear" 51 | steps_offset: 1 52 | clip_sample: false 53 | 54 | unet_additional_kwargs: 55 | use_inflated_groupnorm: true 56 | unet_use_cross_frame_attention: false 57 | unet_use_temporal_attention: false 58 | use_motion_module: true 59 | use_audio_module: true 60 | motion_module_resolutions: 61 | - 1 62 | - 2 63 | - 4 64 | - 8 65 | motion_module_mid_block: true 66 | motion_module_decoder_only: false 67 | motion_module_type: Vanilla 68 | motion_module_kwargs: 69 | num_attention_heads: 8 70 | num_transformer_block: 1 71 | attention_block_types: 72 | - Temporal_Self 73 | - Temporal_Self 74 | temporal_position_encoding: true 75 | temporal_position_encoding_max_len: 32 76 | temporal_attention_dim_div: 1 77 | audio_attention_dim: 768 78 | stack_enable_blocks_name: 79 | - "up" 80 | - "down" 81 | - "mid" 82 | stack_enable_blocks_depth: [0,1,2,3] 83 | 84 | 85 | trainable_para: 86 | # - audio_modules 87 | - motion_modules 88 | 89 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/" 90 | vae_model_path: "./pretrained_models/sd-vae-ft-mse" 91 | face_analysis_model_path: "./pretrained_models/face_analysis" 92 | mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt" 93 | 94 | weight_dtype: "fp16" # [fp16, fp32] 95 | uncond_img_ratio: 0.05 96 | uncond_audio_ratio: 0.05 97 | uncond_ia_ratio: 0.05 98 | start_ratio: 0.05 99 | noise_offset: 0.05 100 | snr_gamma: 5.0 101 | enable_zero_snr: True 102 | 103 | audio_ckpt_dir: ./pretrained_models/hallo 104 | 105 | 106 | single_inference_times: 10 107 | inference_steps: 40 108 | cfg_scale: 3.5 109 | use_mask: true 110 | mask_rate: 0.25 111 | 112 | 113 | seed: 42 114 | resume_from_checkpoint: "latest" 115 | checkpointing_steps: 500 116 | 117 | exp_name: "stage2_long" 118 | output_dir: "./exp_output" 119 | 120 | ref_img_path: 121 | - "./examples/reference_images/1.jpg" 122 | audio_path: 123 | - 
"./examples/driving_audios/1.wav" 124 | 125 | 126 | -------------------------------------------------------------------------------- /configs/train/video_sr.yaml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CodeFormer_temp 3 | model_type: CodeFormerTempModel 4 | num_gpu: 8 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: VFHQ 11 | type: VFHQBlindDataset 12 | dataroot_gt: ./VFHQ/image 13 | filename_tmpl: '{}' 14 | io_backend: 15 | type: disk 16 | 17 | in_size: 512 18 | gt_size: 512 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | use_hflip: true 22 | use_corrupt: true 23 | video_length: 16 24 | 25 | # large degradation in stageII 26 | blur_kernel_size: 41 27 | use_motion_kernel: false 28 | motion_kernel_prob: 0.001 29 | kernel_list: ['iso', 'aniso'] 30 | kernel_prob: [0.5, 0.5] 31 | blur_sigma: [1, 15] 32 | downsample_range: [4, 30] 33 | noise_range: [0, 20] 34 | jpeg_range: [30, 80] 35 | 36 | latent_gt_path: ~ # without pre-calculated latent code 37 | 38 | # data loader 39 | num_worker_per_gpu: 8 40 | batch_size_per_gpu: 4 41 | dataset_enlarge_ratio: 1 42 | prefetch_mode: ~ 43 | 44 | # val: 45 | # name: CelebA-HQ-512 46 | # type: PairedImageDataset 47 | # dataroot_lq: datasets/faces/validation/lq 48 | # dataroot_gt: datasets/faces/validation/gt 49 | # io_backend: 50 | # type: disk 51 | # mean: [0.5, 0.5, 0.5] 52 | # std: [0.5, 0.5, 0.5] 53 | # scale: 1 54 | 55 | # network structures 56 | network_g: 57 | type: CodeFormer 58 | dim_embd: 512 59 | n_head: 8 60 | n_layers: 9 61 | codebook_size: 1024 62 | connect_list: ['32', '64', '128', '256'] 63 | fix_modules: ['quantize','generator'] 64 | vqgan_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN 65 | 66 | network_vqgan: # this config is needed if no pre-calculated latent 67 | type: VQAutoEncoder 68 | img_size: 512 69 | nf: 64 70 | ch_mult: [1, 2, 2, 4, 4, 8] 71 | 
quantizer: 'nearest' 72 | codebook_size: 1024 73 | model_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' 74 | 75 | # path 76 | path: 77 | pretrain_network_g: './pretrained_models/CodeFormer/codeformer.pth' 78 | param_key_g: params_ema 79 | strict_load_g: false 80 | pretrain_network_d: ~ 81 | strict_load_d: true 82 | resume_state: ~ 83 | 84 | # base_lr(4.5e-6)*bach_size(4) 85 | train: 86 | use_hq_feat_loss: true 87 | feat_loss_weight: 1.0 88 | cross_entropy_loss: true 89 | entropy_loss_weight: 0.5 90 | fidelity_weight: 0 91 | 92 | trainable_para: temp 93 | 94 | optim_g: 95 | type: Adam 96 | lr: !!float 1e-4 97 | weight_decay: 0 98 | betas: [0.9, 0.99] 99 | 100 | scheduler: 101 | type: MultiStepLR 102 | milestones: [400000, 450000] 103 | gamma: 0.5 104 | 105 | # scheduler: 106 | # type: CosineAnnealingRestartLR 107 | # periods: [500000] 108 | # restart_weights: [1] 109 | # eta_min: !!float 2e-5 # no lr reduce in official vqgan code 110 | 111 | total_iter: 500000 112 | 113 | warmup_iter: -1 # no warm up 114 | ema_decay: 0.995 115 | 116 | use_adaptive_weight: true 117 | 118 | net_g_start_iter: 0 119 | net_d_iters: 1 120 | net_d_start_iter: 0 121 | manual_seed: 0 122 | 123 | # validation settings 124 | val: 125 | val_freq: 1000 126 | save_img: true 127 | 128 | metrics: 129 | psnr: # metric name, can be arbitrary 130 | type: calculate_psnr 131 | crop_border: 4 132 | test_y_channel: false 133 | 134 | # logging settings 135 | logger: 136 | print_freq: 1 137 | save_checkpoint_freq: 1000 138 | use_tb_logger: true 139 | wandb: 140 | project: ~ 141 | resume_id: ~ 142 | 143 | # dist training settings 144 | dist_params: 145 | backend: nccl 146 | port: 29412 147 | 148 | find_unused_parameters: true 149 | -------------------------------------------------------------------------------- /configs/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": 
"0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | "CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /configs/unet/unet.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | use_audio_module: true 7 | motion_module_resolutions: 8 | - 1 9 | - 2 10 | - 4 11 | - 8 12 | motion_module_mid_block: true 13 | motion_module_decoder_only: false 14 | motion_module_type: Vanilla 15 | motion_module_kwargs: 16 | num_attention_heads: 8 17 | num_transformer_block: 1 18 | attention_block_types: 19 | - Temporal_Self 20 | - Temporal_Self 21 | temporal_position_encoding: true 22 | temporal_position_encoding_max_len: 32 23 | temporal_attention_dim_div: 1 24 | audio_attention_dim: 768 25 | stack_enable_blocks_name: 26 | - "up" 27 | - "down" 28 | - "mid" 29 | stack_enable_blocks_depth: [0,1,2,3] 30 | 31 | enable_zero_snr: true 32 | 33 | noise_scheduler_kwargs: 34 | beta_start: 0.00085 35 | beta_end: 0.012 36 | beta_schedule: "linear" 37 | clip_sample: false 38 | steps_offset: 1 39 | ### Zero-SNR params 40 | prediction_type: "v_prediction" 41 | rescale_betas_zero_snr: 
True 42 | timestep_spacing: "trailing" 43 | 44 | sampler: DDIM 45 | -------------------------------------------------------------------------------- /configs/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.21.0.dev0", 4 | "_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f481832419503d1fa18d4a4379597f04b/vae", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "force_upcast": true, 19 | "in_channels": 3, 20 | "latent_channels": 4, 21 | "layers_per_block": 2, 22 | "norm_num_groups": 32, 23 | "out_channels": 3, 24 | "sample_size": 512, 25 | "scaling_factor": 0.18215, 26 | "up_block_types": [ 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D", 30 | "UpDecoderBlock2D" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/example.png -------------------------------------------------------------------------------- /hallo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/__init__.py -------------------------------------------------------------------------------- /hallo/animate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/animate/__init__.py 
-------------------------------------------------------------------------------- /hallo/basicsr/VERSION: -------------------------------------------------------------------------------- 1 | 1.3.2 2 | -------------------------------------------------------------------------------- /hallo/basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .hallo_archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .train import * 10 | from .utils import * 11 | from .version import __gitsha__, __version__ 12 | -------------------------------------------------------------------------------- /hallo/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from .prefetch_dataloader import PrefetchDataLoader 11 | from ..utils import get_root_logger, scandir 12 | from ..utils.dist_util import get_dist_info 13 | from ..utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules -- NOTE(review): disabled (the line below is commented out), so this module does not import the *_dataset.py files itself 22 | #_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build a dataset instance from options and return it.
27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt)  # copy so the dataset constructor cannot mutate the caller's options 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys (optional: pin_memory, prefetch_mode, num_prefetch_queue): 46 | phase (str): 'train', 'val' or 'test'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed for worker_init_fn (train phase only).
Default: None. Returns: torch.utils.data.DataLoader, or PrefetchDataLoader when dataset_opt['prefetch_mode'] == 'cpu'. 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu  # scale per-GPU values by the GPU count; num_gpu == 0 is treated as 1 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False,  # switched to True below only when no sampler is given 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)  # one sample at a time, loaded in the main process 79 | else: 80 | raise ValueError(f'Wrong dataset phase: {phase}.
' "Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | 84 | prefetch_mode = dataset_opt.get('prefetch_mode') 85 | if prefetch_mode == 'cpu': # CPU prefetching via PrefetchDataLoader (background thread) 86 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 87 | logger = get_root_logger() 88 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') 89 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 90 | else: 91 | # prefetch_mode=None: Normal dataloader 92 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 93 | return torch.utils.data.DataLoader(**dataloader_args) 94 | 95 | 96 | def worker_init_fn(worker_id, num_workers, rank, seed): 97 | # Set the worker seed to num_workers * rank + worker_id + seed (distinct for every worker on every rank); only NumPy and Python `random` are reseeded here 98 | worker_seed = num_workers * rank + worker_id + seed 99 | np.random.seed(worker_seed) 100 | random.seed(worker_seed) 101 | -------------------------------------------------------------------------------- /hallo/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Supports enlarging the dataset for iteration-based training, to save the 11 | time of restarting the dataloader after each epoch. 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio; one epoch covers about len(dataset) * ratio indices across all replicas. Default: 1.
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)  # per-replica count, rounded up so every replica gets the same number 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch (replicas at the same epoch draw the same permutation, so the strided slices below do not overlap) 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices]  # fold the enlarged index range back onto valid dataset indices 37 | 38 | # subsample: replica `rank` takes every num_replicas-th index 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch  # seeds the shuffle in __iter__; call once per epoch for a new order 49 | -------------------------------------------------------------------------------- /hallo/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Maximum number of items buffered in the prefetch queue.
16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue)  # at most num_prefetch_queue items wait here (unbounded if <= 0) 21 | self.generator = generator 22 | self.daemon = True  # do not keep the interpreter alive for this producer thread 23 | self.start()  # begin prefetching immediately on construction 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None)  # end-of-stream sentinel; NOTE(review): if iterating self.generator raises, this is never put and __next__ blocks forever; a genuine None item would also end iteration early 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Maximum number of batches buffered ahead by the background PrefetchGenerator thread. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)  # wrap the normal DataLoader iterator in a background producer thread 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher: wraps a loader iterator with explicit next()/reset(). 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None  # signal exhaustion with None instead of raising StopIteration 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader)  # restart iteration over the original loader 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options.
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /hallo/basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | 4 | 5 | def mod_crop(img, scale): 6 | """Mod crop images, used during testing. 7 | 8 | Args: 9 | img (ndarray): Input image. 10 | scale (int): Scale factor. 11 | 12 | Returns: 13 | ndarray: Result image. 14 | """ 15 | img = img.copy() 16 | if img.ndim in (2, 3): 17 | h, w = img.shape[0], img.shape[1] 18 | h_remainder, w_remainder = h % scale, w % scale 19 | img = img[:h - h_remainder, :w - w_remainder, ...] 20 | else: 21 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 22 | return img 23 | 24 | 25 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 26 | """Paired random crop. 27 | 28 | It crops lists of lq and gt images with corresponding locations. 29 | 30 | Args: 31 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 32 | should have the same shape. 
If the input is an ndarray, it will 33 | be transformed to a list containing itself. 34 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 35 | should have the same shape. If the input is an ndarray, it will 36 | be transformed to a list containing itself. 37 | gt_patch_size (int): GT patch size. 38 | scale (int): Scale factor. 39 | gt_path (str): Path to ground-truth. 40 | 41 | Returns: 42 | list[ndarray] | ndarray: GT images and LQ images. If returned results 43 | only have one element, just return ndarray. 44 | """ 45 | 46 | if not isinstance(img_gts, list): 47 | img_gts = [img_gts] 48 | if not isinstance(img_lqs, list): 49 | img_lqs = [img_lqs] 50 | 51 | h_lq, w_lq, _ = img_lqs[0].shape 52 | h_gt, w_gt, _ = img_gts[0].shape 53 | lq_patch_size = gt_patch_size // scale 54 | 55 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 56 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 57 | f'multiplication of LQ ({h_lq}, {w_lq}).') 58 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 60 | f'({lq_patch_size}, {lq_patch_size}). ' 61 | f'Please remove {gt_path}.') 62 | 63 | # randomly choose top and left coordinates for lq patch 64 | top = random.randint(0, h_lq - lq_patch_size) 65 | left = random.randint(0, w_lq - lq_patch_size) 66 | 67 | # crop lq patch 68 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 69 | 70 | # crop corresponding gt patch 71 | top_gt, left_gt = int(top * scale), int(left * scale) 72 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 73 | if len(img_gts) == 1: 74 | img_gts = img_gts[0] 75 | if len(img_lqs) == 1: 76 | img_lqs = img_lqs[0] 77 | return img_gts, img_lqs 78 | 79 | 80 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 81 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 
82 | 83 | We use vertical flip and transpose for rotation implementation. 84 | All the images in the list use the same augmentation. 85 | 86 | Args: 87 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 88 | is an ndarray, it will be transformed to a list. 89 | hflip (bool): Horizontal flip. Default: True. 90 | rotation (bool): Ratotation. Default: True. 91 | flows (list[ndarray]: Flows to be augmented. If the input is an 92 | ndarray, it will be transformed to a list. 93 | Dimension is (h, w, 2). Default: None. 94 | return_status (bool): Return the status of flip and rotation. 95 | Default: False. 96 | 97 | Returns: 98 | list[ndarray] | ndarray: Augmented images and flows. If returned 99 | results only have one element, just return ndarray. 100 | 101 | """ 102 | hflip = hflip and random.random() < 0.5 103 | vflip = rotation and random.random() < 0.5 104 | rot90 = rotation and random.random() < 0.5 105 | 106 | def _augment(img): 107 | if hflip: # horizontal 108 | cv2.flip(img, 1, img) 109 | if vflip: # vertical 110 | cv2.flip(img, 0, img) 111 | if rot90: 112 | img = img.transpose(1, 0, 2) 113 | return img 114 | 115 | def _augment_flow(flow): 116 | if hflip: # horizontal 117 | cv2.flip(flow, 1, flow) 118 | flow[:, :, 0] *= -1 119 | if vflip: # vertical 120 | cv2.flip(flow, 0, flow) 121 | flow[:, :, 1] *= -1 122 | if rot90: 123 | flow = flow.transpose(1, 0, 2) 124 | flow = flow[:, :, [1, 0]] 125 | return flow 126 | 127 | if not isinstance(imgs, list): 128 | imgs = [imgs] 129 | imgs = [_augment(img) for img in imgs] 130 | if len(imgs) == 1: 131 | imgs = imgs[0] 132 | 133 | if flows is not None: 134 | if not isinstance(flows, list): 135 | flows = [flows] 136 | flows = [_augment_flow(flow) for flow in flows] 137 | if len(flows) == 1: 138 | flows = flows[0] 139 | return imgs, flows 140 | else: 141 | if return_status: 142 | return imgs, (hflip, vflip, rot90) 143 | else: 144 | return imgs 145 | 146 | 147 | def img_rotate(img, angle, center=None, 
scale=1.0): 148 | """Rotate image. 149 | 150 | Args: 151 | img (ndarray): Image to be rotated. 152 | angle (float): Rotation angle in degrees. Positive values mean 153 | counter-clockwise rotation. 154 | center (tuple[int]): Rotation center. If the center is None, 155 | initialize it as the center of the image. Default: None. 156 | scale (float): Isotropic scale factor. Default: 1.0. 157 | """ 158 | (h, w) = img.shape[:2] 159 | 160 | if center is None: 161 | center = (w // 2, h // 2) 162 | 163 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 164 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 165 | return rotated_img 166 | -------------------------------------------------------------------------------- /hallo/basicsr/hallo_archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from ..utils import get_root_logger, scandir 6 | from ..utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | #_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /hallo/basicsr/hallo_archs/rrdbnet_arch.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from ..utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. Default: 64. 16 | num_grow_ch (int): Channels for each growth. Default: 32. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization (scale 0.1 passed to default_init_weights; presumably shrinks the initial weights so the block starts close to identity; confirm in arch_util) 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x):  # dense connectivity: each conv sees x concatenated (dim 1) with all earlier outputs 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. Default: 32. 
50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | # @ARCH_REGISTRY.register()  NOTE(review): registration is disabled, so RRDBNet cannot be built by name through ARCH_REGISTRY 67 | class RRDBNet(nn.Module): 68 | """Network consisting of Residual in Residual Dense Blocks, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64. 83 | num_block (int): Block number in the trunk network. Default: 23. 84 | num_grow_ch (int): Channels for each growth. Default: 32. scale (int): Net upscaling factor; 1 and 2 pixel-unshuffle the input first, any other value takes the plain 4x path. Default: 4. 
85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2:  # forward() pixel-unshuffles by 2 (channels x4) for scale 2 and by 4 (channels x16) for scale 1 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat  # long skip connection around the RRDB trunk 115 | # upsample: two nearest-neighbour 2x steps (4x relative to feat), giving a net upscaling of `scale` for scale 1 / 2 / 4 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out -------------------------------------------------------------------------------- /hallo/basicsr/hallo_archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from ..utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = {  # layer names for each VGG variant, in the order of vgg_net.features (zipped against it in VGGFeatureExtractor) 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert a bn layer name after each conv layer name (e.g. 'conv1_1' is followed by 'bn1_1'). 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | # @ARCH_REGISTRY.register()  NOTE(review): registration is disabled, so VGGFeatureExtractor cannot be built by name through ARCH_REGISTRY 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether to use normalization 59 | in the input feature and the type of vgg network. 
Note that the pretrained 60 | weights at VGG_PRETRAIN_PATH are loaded for any vgg_type whenever that file exists, so they must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. Every name must be a layer name of vgg_type, otherwise list.index raises ValueError. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must be in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, map images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers up to the deepest requested one, to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH):  # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`; confirm against the pinned version 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)  # keep loaded storages on the CPU 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 
114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride; the original pooling layer is replaced by a new MaxPool2d 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net)  # named children let forward() match outputs against layer_name_list 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] (values look like the standard ImageNet RGB statistics; confirm the channel order callers use) 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function: run x through the truncated VGG and collect the requested layer outputs. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | dict[str, Tensor]: Features keyed by layer name, for each name in layer_name_list. 
149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | output = {} 155 | 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone()  # clone, presumably so later in-place layers (e.g. inplace ReLU) cannot overwrite the stored feature 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /hallo/basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from ..utils import get_root_logger 4 | from ..utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must contain: 19 | type (str): Loss type, looked up in LOSS_REGISTRY; the remaining keys are passed to the loss constructor. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /hallo/basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 
14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction)  # NOTE(review): F._Reduction is a private PyTorch helper and may change between releases 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights with the same number of dims as loss and size 1 or loss.size(1) along dim 1. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. With a weight and reduction 'none', the weighted but unreduced loss. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region: divide by the total weight, counting a size-1 channel weight once per loss channel (an all-zero weight divides by zero) 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 
66 | 67 | :Example: (NOTE(review): the weighted calls below use 1-D tensors, but weight_reduce_loss reads weight.size(1), so they would raise IndexError as written; weights need a dim 1) 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /hallo/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from ..utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | data (dict): Passed to the metric function as keyword arguments. opt (dict): Configuration. It must contain: 14 | type (str): Metric type, looked up in METRIC_REGISTRY; the remaining keys are also passed as keyword arguments. 15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /hallo/basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 
8 | 9 | If the input shape is (h, w), return (h, w, 1); 10 | If input_order is 'CHW', i.e. (c, h, w), return (h, w, c); 11 | If input_order is 'HWC', i.e. (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | NOTE(review): for an (h, w) input, 'CHW' is not ignored: the image is expanded to (h, w, 1) 17 | and then transposed to (w, 1, h); pass 'HWC' for 2-D images. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr: 3-channel images go through bgr2ycbcr(y_only=True); other images are returned as float32 with the same values (up to float rounding). 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round; 3-channel inputs presumably become (h, w, 1), assuming bgr2ycbcr(y_only=True) returns (h, w). 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /hallo/basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from .metric_util import reorder_image, to_y_channel 5 | from ..utils.registry import METRIC_REGISTRY 6 | 7 | 8 | # @METRIC_REGISTRY.register() 9 | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img1 (ndarray): Images with range [0, 255]. 16 | img2 (ndarray): Images with range [0, 255]. 17 | crop_border (int): Cropped pixels in each edge of an image. These 18 | pixels are not involved in the PSNR calculation. 
19 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 20 | Default: 'HWC'. 21 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 22 | 23 | Returns: 24 | float: psnr result in dB; inf when the (cropped) images are identical. 25 | """ 26 | 27 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 28 | if input_order not in ['HWC', 'CHW']: 29 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 30 | img1 = reorder_image(img1, input_order=input_order) 31 | img2 = reorder_image(img2, input_order=input_order) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | mse = np.mean((img1 - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | return 20. * np.log10(255. / np.sqrt(mse)) 47 | 48 | 49 | def _ssim(img1, img2): 50 | """Calculate SSIM (structural similarity) for one channel images. 51 | 52 | It is called by :func:`calculate_ssim`. 53 | 54 | Args: 55 | img1 (ndarray): Single-channel image of shape (h, w) with range [0, 255]. 56 | img2 (ndarray): Single-channel image of shape (h, w) with range [0, 255]. 57 | 58 | Returns: 59 | float: ssim result (mean of the SSIM map). 
60 | """ 61 | 62 | C1 = (0.01 * 255)**2  # stabilizing constants (K * L)**2 with K1=0.01, K2=0.03 and dynamic range L=255 63 | C2 = (0.03 * 255)**2 64 | 65 | img1 = img1.astype(np.float64) 66 | img2 = img2.astype(np.float64) 67 | kernel = cv2.getGaussianKernel(11, 1.5)  # 11-tap Gaussian with sigma 1.5; the outer product below gives the 11x11 window 68 | window = np.outer(kernel, kernel.transpose()) 69 | 70 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # [5:-5, 5:-5] keeps only positions where the 11x11 window lies fully inside the image 71 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 72 | mu1_sq = mu1**2 73 | mu2_sq = mu2**2 74 | mu1_mu2 = mu1 * mu2 75 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 76 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 77 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 78 | 79 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 80 | return ssim_map.mean() 81 | 82 | 83 | # @METRIC_REGISTRY.register() 84 | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 85 | """Calculate SSIM (structural similarity). 86 | 87 | Ref: 88 | Image quality assessment: From error visibility to structural similarity 89 | 90 | The results are the same as that of the official released MATLAB code in 91 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 92 | 93 | For multi-channel images, SSIM is calculated for each channel and then 94 | averaged. 95 | 96 | Args: 97 | img1 (ndarray): Images with range [0, 255]. 98 | img2 (ndarray): Images with range [0, 255]. 99 | crop_border (int): Cropped pixels in each edge of an image. These 100 | pixels are not involved in the SSIM calculation. 101 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 102 | Default: 'HWC'. 103 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 104 | 105 | Returns: 106 | float: ssim result. 107 | """ 108 | 109 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 110 | if input_order not in ['HWC', 'CHW']: 111 | raise ValueError(f'Wrong input_order {input_order}. 
Supported input_orders are ' '"HWC" and "CHW"') 112 | img1 = reorder_image(img1, input_order=input_order) 113 | img2 = reorder_image(img2, input_order=input_order) 114 | img1 = img1.astype(np.float64) 115 | img2 = img2.astype(np.float64) 116 | 117 | if crop_border != 0: 118 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 119 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 120 | 121 | if test_y_channel: 122 | img1 = to_y_channel(img1) 123 | img2 = to_y_channel(img2) 124 | 125 | ssims = [] 126 | for i in range(img1.shape[2]):  # SSIM per channel (images are (h, w, c) here after reorder_image), then averaged 127 | ssims.append(_ssim(img1[..., i], img2[..., i])) 128 | return np.array(ssims).mean() 129 | -------------------------------------------------------------------------------- /hallo/basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from ..utils import get_root_logger, scandir 6 | from ..utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules. NOTE: the import below is commented out, so model_filenames is currently unused and build_model only finds models registered by modules imported elsewhere 16 | #_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | model_type (str): Model type, looked up in MODEL_REGISTRY; the whole (copied) opt dict is passed to the model constructor. 
25 | """ 26 | opt = deepcopy(opt)  # pass a copy so the model does not share the caller's dict 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /hallo/basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The mimimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 
75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /hallo/basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/fused_act/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | try: 8 | from . import fused_act_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | fused_act_ext = load( 16 | 'fused', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 19 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class FusedLeakyReLUFunctionBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, out, negative_slope, scale): 28 | ctx.save_for_backward(out) 29 | ctx.negative_slope = negative_slope 30 | ctx.scale = scale 31 | 32 | empty = grad_output.new_empty(0) 33 | 34 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 35 | 36 | dim = [0] 37 | 38 | if grad_input.ndim > 2: 39 | dim += list(range(2, grad_input.ndim)) 40 | 41 | grad_bias = grad_input.sum(dim).detach() 42 | 43 | return grad_input, grad_bias 44 | 45 | @staticmethod 46 | def backward(ctx, gradgrad_input, gradgrad_bias): 47 | out, = ctx.saved_tensors 48 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 49 | ctx.scale) 50 | 51 | return gradgrad_out, None, None, None 52 | 53 | 54 | class FusedLeakyReLUFunction(Function): 55 | 56 | @staticmethod 57 
| def forward(ctx, input, bias, negative_slope, scale): 58 | empty = input.new_empty(0) 59 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 60 | ctx.save_for_backward(out) 61 | ctx.negative_slope = negative_slope 62 | ctx.scale = scale 63 | 64 | return out 65 | 66 | @staticmethod 67 | def backward(ctx, grad_output): 68 | out, = ctx.saved_tensors 69 | 70 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | 77 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 78 | super().__init__() 79 | 80 | self.bias = nn.Parameter(torch.zeros(channel)) 81 | self.negative_slope = negative_slope 82 | self.scale = scale 83 | 84 | def forward(self, input): 85 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 86 | 87 | 88 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 89 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 90 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int 
act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 
| int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /hallo/basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.nn import functional as F 6 | 7 | try: 8 | from . 
import upfirdn2d_ext 9 | except ImportError: 10 | import os 11 | BASICSR_JIT = os.getenv('BASICSR_JIT') 12 | if BASICSR_JIT == 'True': 13 | from torch.utils.cpp_extension import load 14 | module_path = os.path.dirname(__file__) 15 | upfirdn2d_ext = load( 16 | 'upfirdn2d', 17 | sources=[ 18 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 19 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 20 | ], 21 | ) 22 | 23 | 24 | class UpFirDn2dBackward(Function): 25 | 26 | @staticmethod 27 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 28 | 29 | up_x, up_y = up 30 | down_x, down_y = down 31 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 32 | 33 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 34 | 35 | grad_input = upfirdn2d_ext.upfirdn2d( 36 | grad_output, 37 | grad_kernel, 38 | down_x, 39 | down_y, 40 | up_x, 41 | up_y, 42 | g_pad_x0, 43 | g_pad_x1, 44 | g_pad_y0, 45 | g_pad_y1, 46 | ) 47 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 48 | 49 | ctx.save_for_backward(kernel) 50 | 51 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 52 | 53 | ctx.up_x = up_x 54 | ctx.up_y = up_y 55 | ctx.down_x = down_x 56 | ctx.down_y = down_y 57 | ctx.pad_x0 = pad_x0 58 | ctx.pad_x1 = pad_x1 59 | ctx.pad_y0 = pad_y0 60 | ctx.pad_y1 = pad_y1 61 | ctx.in_size = in_size 62 | ctx.out_size = out_size 63 | 64 | return grad_input 65 | 66 | @staticmethod 67 | def backward(ctx, gradgrad_input): 68 | kernel, = ctx.saved_tensors 69 | 70 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 71 | 72 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 73 | gradgrad_input, 74 | kernel, 75 | ctx.up_x, 76 | ctx.up_y, 77 | ctx.down_x, 78 | ctx.down_y, 79 | ctx.pad_x0, 80 | ctx.pad_x1, 81 | ctx.pad_y0, 82 | ctx.pad_y1, 83 | ) 84 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 85 | # ctx.out_size[1], ctx.in_size[3]) 86 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], 
ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 87 | 88 | return gradgrad_out, None, None, None, None, None, None, None, None 89 | 90 | 91 | class UpFirDn2d(Function): 92 | 93 | @staticmethod 94 | def forward(ctx, input, kernel, up, down, pad): 95 | up_x, up_y = up 96 | down_x, down_y = down 97 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 98 | 99 | kernel_h, kernel_w = kernel.shape 100 | batch, channel, in_h, in_w = input.shape 101 | ctx.in_size = input.shape 102 | 103 | input = input.reshape(-1, in_h, in_w, 1) 104 | 105 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 106 | 107 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 108 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 109 | ctx.out_size = (out_h, out_w) 110 | 111 | ctx.up = (up_x, up_y) 112 | ctx.down = (down_x, down_y) 113 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 114 | 115 | g_pad_x0 = kernel_w - pad_x0 - 1 116 | g_pad_y0 = kernel_h - pad_y0 - 1 117 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 118 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 119 | 120 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 121 | 122 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | if input.device.type == 'cpu': 149 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 150 | else: 
151 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 152 | 153 | return out 154 | 155 | 156 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 157 | _, channel, in_h, in_w = input.shape 158 | input = input.reshape(-1, in_h, in_w, 1) 159 | 160 | _, in_h, in_w, minor = input.shape 161 | kernel_h, kernel_w = kernel.shape 162 | 163 | out = input.view(-1, in_h, 1, in_w, 1, minor) 164 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 165 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 166 | 167 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 168 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 172 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 173 | out = F.conv2d(out, w) 174 | out = out.reshape( 175 | -1, 176 | minor, 177 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 178 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 179 | ) 180 | out = out.permute(0, 2, 3, 1) 181 | out = out[:, ::down_y, ::down_x, :] 182 | 183 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 184 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 185 | 186 | return out.view(-1, channel, out_h, out_w) 187 | -------------------------------------------------------------------------------- /hallo/basicsr/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | from utils.misc import gpu_is_available 11 | 12 | version_file = 
'./basicsr/version.py' 13 | 14 | 15 | def readme(): 16 | with open('README.md', encoding='utf-8') as f: 17 | content = f.read() 18 | return content 19 | 20 | 21 | def get_git_hash(): 22 | 23 | def _minimal_ext_cmd(cmd): 24 | # construct minimal environment 25 | env = {} 26 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 27 | v = os.environ.get(k) 28 | if v is not None: 29 | env[k] = v 30 | # LANGUAGE is used on win32 31 | env['LANGUAGE'] = 'C' 32 | env['LANG'] = 'C' 33 | env['LC_ALL'] = 'C' 34 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 35 | return out 36 | 37 | try: 38 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 39 | sha = out.strip().decode('ascii') 40 | except OSError: 41 | sha = 'unknown' 42 | 43 | return sha 44 | 45 | 46 | def get_hash(): 47 | if os.path.exists('.git'): 48 | sha = get_git_hash()[:7] 49 | elif os.path.exists(version_file): 50 | try: 51 | from version import __version__ 52 | sha = __version__.split('+')[-1] 53 | except ImportError: 54 | raise ImportError('Unable to get git version') 55 | else: 56 | sha = 'unknown' 57 | 58 | return sha 59 | 60 | 61 | def write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open('./basicsr/VERSION', 'r') as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, 'w') as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, 'r') as f: 80 | exec(compile(f.read(), version_file, 'exec')) 81 | return locals()['__version__'] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args 
= {'cxx': []} 89 | 90 | # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1': 92 | define_macros += [('WITH_CUDA', None)] 93 | extension = CUDAExtension 94 | extra_compile_args['nvcc'] = [ 95 | '-D__CUDA_NO_HALF_OPERATORS__', 96 | '-D__CUDA_NO_HALF_CONVERSIONS__', 97 | '-D__CUDA_NO_HALF2_OPERATORS__', 98 | ] 99 | sources += sources_cuda 100 | else: 101 | print(f'Compiling {name} without CUDA') 102 | extension = CppExtension 103 | 104 | return extension( 105 | name=f'{module}.{name}', 106 | sources=[os.path.join(*module.split('.'), p) for p in sources], 107 | define_macros=define_macros, 108 | extra_compile_args=extra_compile_args) 109 | 110 | 111 | def get_requirements(filename='requirements.txt'): 112 | with open(os.path.join('.', filename), 'r') as f: 113 | requires = [line.replace('\n', '') for line in f.readlines()] 114 | return requires 115 | 116 | 117 | if __name__ == '__main__': 118 | if '--cuda_ext' in sys.argv: 119 | ext_modules = [ 120 | make_cuda_ext( 121 | name='deform_conv_ext', 122 | module='ops.dcn', 123 | sources=['src/deform_conv_ext.cpp'], 124 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 125 | make_cuda_ext( 126 | name='fused_act_ext', 127 | module='ops.fused_act', 128 | sources=['src/fused_bias_act.cpp'], 129 | sources_cuda=['src/fused_bias_act_kernel.cu']), 130 | make_cuda_ext( 131 | name='upfirdn2d_ext', 132 | module='ops.upfirdn2d', 133 | sources=['src/upfirdn2d.cpp'], 134 | sources_cuda=['src/upfirdn2d_kernel.cu']), 135 | ] 136 | sys.argv.remove('--cuda_ext') 137 | else: 138 | ext_modules = [] 139 | 140 | write_version_py() 141 | setup( 142 | name='basicsr', 143 | version=get_version(), 144 | description='Open Source Image and Video Super-Resolution Toolbox', 145 | long_description=readme(), 146 | long_description_content_type='text/markdown', 147 | author='Xintao Wang', 148 | author_email='xintao.wang@outlook.com', 149 | 
keywords='computer vision, restoration, super resolution', 150 | url='https://github.com/xinntao/BasicSR', 151 | include_package_data=True, 152 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 153 | classifiers=[ 154 | 'Development Status :: 4 - Beta', 155 | 'License :: OSI Approved :: Apache Software License', 156 | 'Operating System :: OS Independent', 157 | 'Programming Language :: Python :: 3', 158 | 'Programming Language :: Python :: 3.7', 159 | 'Programming Language :: Python :: 3.8', 160 | ], 161 | license='Apache License 2.0', 162 | setup_requires=['cython', 'numpy'], 163 | install_requires=get_requirements(), 164 | ext_modules=ext_modules, 165 | cmdclass={'build_ext': BuildExtension}, 166 | zip_safe=False) 167 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 3 | from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 4 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 5 | 6 | __all__ = [ 7 | # file_client.py 8 | 'FileClient', 9 | # img_util.py 10 | 'img2tensor', 11 | 'tensor2img', 12 | 'imfrombytes', 13 | 'imwrite', 14 | 'crop_border', 15 | # logger.py 16 | 'MessageLogger', 17 | 'init_tb_logger', 18 | 'init_wandb_logger', 19 | 'get_root_logger', 20 | 'get_env_info', 21 | # misc.py 22 | 'set_random_seed', 23 | 'get_time_str', 24 | 'mkdir_and_rename', 25 | 'make_exp_dirs', 26 | 'scandir', 27 | 'check_resume', 28 | 'sizeof_fmt' 29 | ] 30 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/dist_util.py: 
-------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 
13 | Ref: 14 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 15 | Args: 16 | file_id (str): File id. 17 | save_path (str): Save path. 18 | """ 19 | 20 | session = requests.Session() 21 | URL = 'https://docs.google.com/uc?export=download' 22 | params = {'id': file_id} 23 | 24 | response = session.get(URL, params=params, stream=True) 25 | token = get_confirm_token(response) 26 | if token: 27 | params['confirm'] = token 28 | response = session.get(URL, params=params, stream=True) 29 | 30 | # get file size 31 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | print(response_file_size) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 
71 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 72 | Args: 73 | url (str): URL to be downloaded. 74 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 75 | Default: None. 76 | progress (bool): Whether to show the download progress. Default: True. 77 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 78 | Returns: 79 | str: The path to the downloaded file. 80 | """ 81 | if model_dir is None: # use the pytorch hub_dir 82 | hub_dir = get_dir() 83 | model_dir = os.path.join(hub_dir, 'checkpoints') 84 | 85 | os.makedirs(model_dir, exist_ok=True) 86 | 87 | parts = urlparse(url) 88 | filename = os.path.basename(parts.path) 89 | if file_name is not None: 90 | filename = file_name 91 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 92 | if not os.path.exists(cached_file): 93 | print(f'Downloading: "{url}" to {cached_file}\n') 94 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 95 | return cached_file -------------------------------------------------------------------------------- /hallo/basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 
24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 
84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing differnet lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 
134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 
20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 
105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | return cv2.imwrite(file_path, img, params) 152 | 153 | 154 | def crop_border(imgs, crop_border): 155 | """Crop borders of images. 156 | 157 | Args: 158 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 159 | crop_border (int): Crop border for each end of height and weight. 
160 | 161 | Returns: 162 | list[ndarray]: Cropped images. 163 | """ 164 | if crop_border == 0: 165 | return imgs 166 | else: 167 | if isinstance(imgs, list): 168 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 169 | else: 170 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 171 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | Args: 13 | opt (dict): Config. It contains the following keys: 14 | name (str): Exp name. 15 | logger (dict): Contains 'print_freq' (str) for logger interval. 16 | train (dict): Contains 'total_iter' (int) for total iters. 17 | use_tb_logger (bool): Use tensorboard logger. 18 | start_iter (int): Start iter. Default: 1. 19 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 20 | """ 21 | 22 | def __init__(self, opt, start_iter=1, tb_logger=None): 23 | self.exp_name = opt['name'] 24 | self.interval = opt['logger']['print_freq'] 25 | self.start_iter = start_iter 26 | self.max_iters = opt['train']['total_iter'] 27 | self.use_tb_logger = opt['logger']['use_tb_logger'] 28 | self.tb_logger = tb_logger 29 | self.start_time = time.time() 30 | self.logger = get_root_logger() 31 | 32 | @master_only 33 | def __call__(self, log_vars): 34 | """Format logging message. 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | time (float): Iter time. 41 | data_time (float): Data time for each iter. 
42 | """ 43 | # epoch, iter, learning rates 44 | epoch = log_vars.pop('epoch') 45 | current_iter = log_vars.pop('iter') 46 | lrs = log_vars.pop('lrs') 47 | 48 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 49 | for v in lrs: 50 | message += f'{v:.3e},' 51 | message += ')] ' 52 | 53 | # time and estimated time 54 | if 'time' in log_vars.keys(): 55 | iter_time = log_vars.pop('time') 56 | data_time = log_vars.pop('data_time') 57 | 58 | total_time = time.time() - self.start_time 59 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 60 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 61 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 62 | message += f'[eta: {eta_str}, ' 63 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 64 | 65 | # other items, especially losses 66 | for k, v in log_vars.items(): 67 | message += f'{k}: {v:.4e} ' 68 | # tensorboard logger 69 | if self.use_tb_logger: 70 | # if k.startswith('l_'): 71 | # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 72 | # else: 73 | self.tb_logger.add_scalar(k, v, current_iter) 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | tb_logger = SummaryWriter(log_dir=log_dir) 81 | return tb_logger 82 | 83 | 84 | @master_only 85 | def init_wandb_logger(opt): 86 | """We now only use wandb to sync tensorboard log.""" 87 | import wandb 88 | logger = logging.getLogger('basicsr') 89 | 90 | project = opt['logger']['wandb']['project'] 91 | resume_id = opt['logger']['wandb'].get('resume_id') 92 | if resume_id: 93 | wandb_id = resume_id 94 | resume = 'allow' 95 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 96 | else: 97 | wandb_id = wandb.util.generate_id() 98 | resume = 'never' 99 | 100 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 101 | 102 | 
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 103 | 104 | 105 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 106 | """Get the root logger. 107 | The logger will be initialized if it has not been initialized. By default a 108 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 109 | also be added. 110 | Args: 111 | logger_name (str): root logger name. Default: 'basicsr'. 112 | log_file (str | None): The log filename. If specified, a FileHandler 113 | will be added to the root logger. 114 | log_level (int): The root logger level. Note that only the process of 115 | rank 0 is affected, while other processes will set the level to 116 | "Error" and be silent most of the time. 117 | Returns: 118 | logging.Logger: The root logger. 119 | """ 120 | logger = logging.getLogger(logger_name) 121 | # if the logger has been initialized, just return it 122 | if logger_name in initialized_logger: 123 | return logger 124 | 125 | format_str = '%(asctime)s %(levelname)s: %(message)s' 126 | stream_handler = logging.StreamHandler() 127 | stream_handler.setFormatter(logging.Formatter(format_str)) 128 | logger.addHandler(stream_handler) 129 | logger.propagate = False 130 | rank, _ = get_dist_info() 131 | if rank != 0: 132 | logger.setLevel('ERROR') 133 | elif log_file is not None: 134 | logger.setLevel(log_level) 135 | # add file handler 136 | # file_handler = logging.FileHandler(log_file, 'w') 137 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 138 | file_handler.setFormatter(logging.Formatter(format_str)) 139 | file_handler.setLevel(log_level) 140 | logger.addHandler(file_handler) 141 | initialized_logger[logger_name] = True 142 | return logger 143 | 144 | 145 | def get_env_info(): 146 | """Get environment information. 147 | Currently, only log the software version. 
148 | """ 149 | import torch 150 | import torchvision 151 | 152 | from basicsr.version import __version__ 153 | msg = r""" 154 | ____ _ _____ ____ 155 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 156 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 157 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 158 | /_____/ \__,_//____//_/ \___//____//_/ |_| 159 | ______ __ __ __ __ 160 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 161 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 162 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 163 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 164 | """ 165 | msg += ('\nVersion Information: ' 166 | f'\n\tBasicSR: {__version__}' 167 | f'\n\tPyTorch: {torch.__version__}' 168 | f'\n\tTorchVision: {torchvision.__version__}') 169 | return msg -------------------------------------------------------------------------------- /hallo/basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import numpy as np 7 | from os import path as osp 8 | 9 | from .dist_util import master_only 10 | from .logger import get_root_logger 11 | 12 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",torch.__version__)[0][:3])] >= [1, 12, 0] 13 | 14 | def gpu_is_available(): 15 | if IS_HIGH_VERSION: 16 | if torch.backends.mps.is_available(): 17 | return True 18 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 19 | 20 | def get_device(gpu_id=None): 21 | if gpu_id is None: 22 | gpu_str = '' 23 | elif isinstance(gpu_id, int): 24 | gpu_str = f':{gpu_id}' 25 | else: 26 | raise TypeError('Input should be int value.') 27 | 28 | if IS_HIGH_VERSION: 29 | if torch.backends.mps.is_available(): 30 | return torch.device('mps'+gpu_str) 31 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and 
torch.backends.cudnn.is_available() else 'cpu') 32 | 33 | 34 | def set_random_seed(seed): 35 | """Set random seeds.""" 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | 42 | 43 | def get_time_str(): 44 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 45 | 46 | 47 | def mkdir_and_rename(path): 48 | """mkdirs. If path exists, rename it with timestamp and create a new one. 49 | 50 | Args: 51 | path (str): Folder path. 52 | """ 53 | if osp.exists(path): 54 | new_name = path + '_archived_' + get_time_str() 55 | print(f'Path already exists. Rename it to {new_name}', flush=True) 56 | os.rename(path, new_name) 57 | os.makedirs(path, exist_ok=True) 58 | 59 | 60 | @master_only 61 | def make_exp_dirs(opt): 62 | """Make dirs for experiments.""" 63 | path_opt = opt['path'].copy() 64 | if opt['is_train']: 65 | mkdir_and_rename(path_opt.pop('experiments_root')) 66 | else: 67 | mkdir_and_rename(path_opt.pop('results_root')) 68 | for key, path in path_opt.items(): 69 | if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): 70 | os.makedirs(path, exist_ok=True) 71 | 72 | 73 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 74 | """Scan a directory to find the interested files. 75 | 76 | Args: 77 | dir_path (str): Path of the directory. 78 | suffix (str | tuple(str), optional): File suffix that we are 79 | interested in. Default: None. 80 | recursive (bool, optional): If set to True, recursively scan the 81 | directory. Default: False. 82 | full_path (bool, optional): If set to True, include the dir_path. 83 | Default: False. 84 | 85 | Returns: 86 | A generator for all the interested files with relative pathes. 
87 | """ 88 | 89 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 90 | raise TypeError('"suffix" must be a string or tuple of strings') 91 | 92 | root = dir_path 93 | 94 | def _scandir(dir_path, suffix, recursive): 95 | for entry in os.scandir(dir_path): 96 | if not entry.name.startswith('.') and entry.is_file(): 97 | if full_path: 98 | return_path = entry.path 99 | else: 100 | return_path = osp.relpath(entry.path, root) 101 | 102 | if suffix is None: 103 | yield return_path 104 | elif return_path.endswith(suffix): 105 | yield return_path 106 | else: 107 | if recursive: 108 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 109 | else: 110 | continue 111 | 112 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 113 | 114 | 115 | def check_resume(opt, resume_iter): 116 | """Check resume states and pretrain_network paths. 117 | 118 | Args: 119 | opt (dict): Options. 120 | resume_iter (int): Resume iteration. 121 | """ 122 | logger = get_root_logger() 123 | if opt['path']['resume_state']: 124 | # get all the networks 125 | networks = [key for key in opt.keys() if key.startswith('network_')] 126 | flag_pretrain = False 127 | for network in networks: 128 | if opt['path'].get(f'pretrain_{network}') is not None: 129 | flag_pretrain = True 130 | if flag_pretrain: 131 | logger.warning('pretrain_network path will be ignored during resuming.') 132 | # set pretrained model paths 133 | for network in networks: 134 | name = f'pretrain_{network}' 135 | basename = network.replace('network_', '') 136 | if opt['path'].get('ignore_resume_networks') is None or (basename 137 | not in opt['path']['ignore_resume_networks']): 138 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 139 | logger.info(f"Set {name} to {opt['path'][name]}") 140 | 141 | 142 | def sizeof_fmt(size, suffix='B'): 143 | """Get human readable file size. 144 | 145 | Args: 146 | size (int): File size. 147 | suffix (str): Suffix. 
Default: 'B'. 148 | 149 | Return: 150 | str: Formated file siz. 151 | """ 152 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 153 | if abs(size) < 1024.0: 154 | return f'{size:3.1f} {unit}{suffix}' 155 | size /= 1024.0 156 | return f'{size:3.1f} Y{suffix}' 157 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import time 3 | from collections import OrderedDict 4 | from os import path as osp 5 | from .misc import get_time_str 6 | 7 | def ordered_yaml(): 8 | """Support OrderedDict for yaml. 9 | 10 | Returns: 11 | yaml Loader and Dumper. 12 | """ 13 | try: 14 | from yaml import CDumper as Dumper 15 | from yaml import CLoader as Loader 16 | except ImportError: 17 | from yaml import Dumper, Loader 18 | 19 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 20 | 21 | def dict_representer(dumper, data): 22 | return dumper.represent_dict(data.items()) 23 | 24 | def dict_constructor(loader, node): 25 | return OrderedDict(loader.construct_pairs(node)) 26 | 27 | Dumper.add_representer(OrderedDict, dict_representer) 28 | Loader.add_constructor(_mapping_tag, dict_constructor) 29 | return Loader, Dumper 30 | 31 | 32 | def parse(opt_path, root_path, is_train=True): 33 | """Parse option file. 34 | 35 | Args: 36 | opt_path (str): Option file path. 37 | is_train (str): Indicate whether in training or not. Default: True. 38 | 39 | Returns: 40 | (dict): Options. 
41 | """ 42 | with open(opt_path, mode='r') as f: 43 | Loader, _ = ordered_yaml() 44 | opt = yaml.load(f, Loader=Loader) 45 | 46 | opt['is_train'] = is_train 47 | 48 | # opt['name'] = f"{get_time_str()}_{opt['name']}" 49 | if opt['path'].get('resume_state', None): # Shangchen added 50 | resume_state_path = opt['path'].get('resume_state') 51 | opt['name'] = resume_state_path.split("/")[-3] 52 | else: 53 | opt['name'] = f"{get_time_str()}_{opt['name']}" 54 | 55 | 56 | # datasets 57 | for phase, dataset in opt['datasets'].items(): 58 | # for several datasets, e.g., test_1, test_2 59 | phase = phase.split('_')[0] 60 | dataset['phase'] = phase 61 | if 'scale' in opt: 62 | dataset['scale'] = opt['scale'] 63 | if dataset.get('dataroot_gt') is not None: 64 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 65 | if dataset.get('dataroot_lq') is not None: 66 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 67 | 68 | # paths 69 | for key, val in opt['path'].items(): 70 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 71 | opt['path'][key] = osp.expanduser(val) 72 | 73 | if is_train: 74 | experiments_root = osp.join(root_path, 'experiments', opt['name']) 75 | opt['path']['experiments_root'] = experiments_root 76 | opt['path']['models'] = osp.join(experiments_root, 'models') 77 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 78 | opt['path']['log'] = experiments_root 79 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 80 | 81 | else: # test 82 | results_root = osp.join(root_path, 'results', opt['name']) 83 | opt['path']['results_root'] = results_root 84 | opt['path']['log'] = results_root 85 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 86 | 87 | return opt 88 | 89 | 90 | def dict2str(opt, indent_level=1): 91 | """dict to string for printing options. 92 | 93 | Args: 94 | opt (dict): Option dict. 
95 | indent_level (int): Indent level. Default: 1. 96 | 97 | Return: 98 | (str): Option string for printing. 99 | """ 100 | msg = '\n' 101 | for k, v in opt.items(): 102 | if isinstance(v, dict): 103 | msg += ' ' * (indent_level * 2) + k + ':[' 104 | msg += dict2str(v, indent_level + 1) 105 | msg += ' ' * (indent_level * 2) + ']\n' 106 | else: 107 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 108 | return msg 109 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 
48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /hallo/basicsr/utils/video_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is modified from the Real-ESRGAN: 3 | https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py 4 | 5 | ''' 6 | import cv2 7 | import sys 8 | import numpy as np 9 | 10 | try: 11 | import ffmpeg 12 | except ImportError: 13 | import pip 14 | pip.main(['install', '--user', 'ffmpeg-python']) 15 | import ffmpeg 16 | 17 | def get_video_meta_info(video_path): 18 | ret = {} 19 | probe = ffmpeg.probe(video_path) 20 | video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] 21 | has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) 22 | ret['width'] = video_streams[0]['width'] 23 | ret['height'] = video_streams[0]['height'] 24 | ret['fps'] = eval(video_streams[0]['avg_frame_rate']) 25 | ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None 26 | ret['nb_frames'] = 
int(video_streams[0]['nb_frames']) 27 | return ret 28 | 29 | class VideoReader: 30 | def __init__(self, video_path): 31 | self.paths = [] # for image&folder type 32 | self.audio = None 33 | try: 34 | self.stream_reader = ( 35 | ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', 36 | loglevel='error').run_async( 37 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 38 | except FileNotFoundError: 39 | print('Please install ffmpeg (not ffmpeg-python) by running\n', 40 | '\t$ conda install -c conda-forge ffmpeg') 41 | sys.exit(0) 42 | 43 | meta = get_video_meta_info(video_path) 44 | self.width = meta['width'] 45 | self.height = meta['height'] 46 | self.input_fps = meta['fps'] 47 | self.audio = meta['audio'] 48 | self.nb_frames = meta['nb_frames'] 49 | 50 | self.idx = 0 51 | 52 | def get_resolution(self): 53 | return self.height, self.width 54 | 55 | def get_fps(self): 56 | if self.input_fps is not None: 57 | return self.input_fps 58 | return 24 59 | 60 | def get_audio(self): 61 | return self.audio 62 | 63 | def __len__(self): 64 | return self.nb_frames 65 | 66 | def get_frame_from_stream(self): 67 | img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel 68 | if not img_bytes: 69 | return None 70 | img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) 71 | return img 72 | 73 | def get_frame_from_list(self): 74 | if self.idx >= self.nb_frames: 75 | return None 76 | img = cv2.imread(self.paths[self.idx]) 77 | self.idx += 1 78 | return img 79 | 80 | def get_frame(self): 81 | return self.get_frame_from_stream() 82 | 83 | 84 | def close(self): 85 | self.stream_reader.stdin.close() 86 | self.stream_reader.wait() 87 | 88 | 89 | class VideoWriter: 90 | def __init__(self, video_save_path, height, width, fps, audio): 91 | if height > 2160: 92 | print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', 93 | 'We highly recommend to decrease the 
outscale(aka, -s).') 94 | if audio is not None: 95 | self.stream_writer = ( 96 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 97 | framerate=fps).output( 98 | audio, 99 | video_save_path, 100 | pix_fmt='yuv420p', 101 | vcodec='libx264', 102 | loglevel='error', 103 | acodec='copy').overwrite_output().run_async( 104 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 105 | else: 106 | self.stream_writer = ( 107 | ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', 108 | framerate=fps).output( 109 | video_save_path, pix_fmt='yuv420p', vcodec='libx264', 110 | loglevel='error').overwrite_output().run_async( 111 | pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) 112 | 113 | def write_frame(self, frame): 114 | try: 115 | frame = frame.astype(np.uint8).tobytes() 116 | self.stream_writer.stdin.write(frame) 117 | except BrokenPipeError: 118 | print('Please re-install ffmpeg and libx264 by running\n', 119 | '\t$ conda install -c conda-forge ffmpeg\n', 120 | '\t$ conda install -c conda-forge x264') 121 | sys.exit(0) 122 | 123 | def close(self): 124 | self.stream_writer.stdin.close() 125 | self.stream_writer.wait() -------------------------------------------------------------------------------- /hallo/basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sat Sep 14 04:53:53 2024 3 | __version__ = '1.3.2' 4 | __gitsha__ = '' 5 | version_info = (1, 3, 2) 6 | -------------------------------------------------------------------------------- /hallo/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/datasets/__init__.py -------------------------------------------------------------------------------- /hallo/datasets/mask_image.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module contains the code for a dataset class called FaceMaskDataset, which is used to process and 4 | load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and 5 | provides methods for data augmentation, getting items from the dataset, and determining the length of the 6 | dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch, 7 | PIL, and transformers. 8 | """ 9 | 10 | import json 11 | import random 12 | from pathlib import Path 13 | 14 | import torch 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from torchvision import transforms 18 | from transformers import CLIPImageProcessor 19 | 20 | 21 | class FaceMaskDataset(Dataset): 22 | """ 23 | FaceMaskDataset is a custom dataset for face mask images. 24 | 25 | Args: 26 | img_size (int): The size of the input images. 27 | drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1. 28 | data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"]. 29 | sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30. 30 | 31 | Attributes: 32 | img_size (int): The size of the input images. 33 | drop_ratio (float): The ratio of dropped pixels during data augmentation. 34 | data_meta_paths (list): The paths to the metadata files containing image paths and labels. 35 | sample_margin (int): The margin for sampling regions in the image. 36 | processor (CLIPImageProcessor): The image processor for preprocessing images. 37 | transform (transforms.Compose): The image augmentation transform. 
38 | """ 39 | 40 | def __init__( 41 | self, 42 | img_size, 43 | drop_ratio=0.1, 44 | data_meta_paths=None, 45 | sample_margin=30, 46 | ): 47 | super().__init__() 48 | 49 | self.img_size = img_size 50 | self.sample_margin = sample_margin 51 | 52 | vid_meta = [] 53 | for data_meta_path in data_meta_paths: 54 | with open(data_meta_path, "r", encoding="utf-8") as f: 55 | vid_meta.extend(json.load(f)) 56 | self.vid_meta = vid_meta 57 | self.length = len(self.vid_meta) 58 | 59 | self.clip_image_processor = CLIPImageProcessor() 60 | 61 | self.transform = transforms.Compose( 62 | [ 63 | transforms.Resize(self.img_size), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.5], [0.5]), 66 | ] 67 | ) 68 | 69 | self.cond_transform = transforms.Compose( 70 | [ 71 | transforms.Resize(self.img_size), 72 | transforms.ToTensor(), 73 | ] 74 | ) 75 | 76 | self.drop_ratio = drop_ratio 77 | 78 | def augmentation(self, image, transform, state=None): 79 | """ 80 | Apply data augmentation to the input image. 81 | 82 | Args: 83 | image (PIL.Image): The input image. 84 | transform (torchvision.transforms.Compose): The data augmentation transforms. 85 | state (dict, optional): The random state for reproducibility. Defaults to None. 86 | 87 | Returns: 88 | PIL.Image: The augmented image. 
89 | """ 90 | if state is not None: 91 | torch.set_rng_state(state) 92 | return transform(image) 93 | 94 | def __getitem__(self, index): 95 | video_meta = self.vid_meta[index] 96 | video_path = video_meta["image_path"] 97 | mask_path = video_meta["mask_path"] 98 | face_emb_path = video_meta["face_emb"] 99 | 100 | video_frames = sorted(Path(video_path).iterdir()) 101 | video_length = len(video_frames) 102 | 103 | margin = min(self.sample_margin, video_length) 104 | 105 | ref_img_idx = random.randint(0, video_length - 1) 106 | if ref_img_idx + margin < video_length: 107 | tgt_img_idx = random.randint( 108 | ref_img_idx + margin, video_length - 1) 109 | elif ref_img_idx - margin > 0: 110 | tgt_img_idx = random.randint(0, ref_img_idx - margin) 111 | else: 112 | tgt_img_idx = random.randint(0, video_length - 1) 113 | 114 | ref_img_pil = Image.open(video_frames[ref_img_idx]) 115 | tgt_img_pil = Image.open(video_frames[tgt_img_idx]) 116 | 117 | tgt_mask_pil = Image.open(mask_path) 118 | 119 | assert ref_img_pil is not None, "Fail to load reference image." 120 | assert tgt_img_pil is not None, "Fail to load target image." 121 | assert tgt_mask_pil is not None, "Fail to load target mask." 
122 | 123 | state = torch.get_rng_state() 124 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state) 125 | tgt_mask_img = self.augmentation( 126 | tgt_mask_pil, self.cond_transform, state) 127 | tgt_mask_img = tgt_mask_img.repeat(3, 1, 1) 128 | ref_img_vae = self.augmentation( 129 | ref_img_pil, self.transform, state) 130 | face_emb = torch.load(face_emb_path) 131 | 132 | 133 | sample = { 134 | "video_dir": video_path, 135 | "img": tgt_img, 136 | "tgt_mask": tgt_mask_img, 137 | "ref_img": ref_img_vae, 138 | "face_emb": face_emb, 139 | } 140 | 141 | return sample 142 | 143 | def __len__(self): 144 | return len(self.vid_meta) 145 | 146 | 147 | if __name__ == "__main__": 148 | data = FaceMaskDataset(img_size=(512, 512)) 149 | train_dataloader = torch.utils.data.DataLoader( 150 | data, batch_size=4, shuffle=True, num_workers=1 151 | ) 152 | for step, batch in enumerate(train_dataloader): 153 | print(batch["tgt_mask"].shape) 154 | break 155 | -------------------------------------------------------------------------------- /hallo/facelib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from copy import deepcopy 5 | 6 | from ..utils import load_file_from_url 7 | from ..utils import download_pretrained_models 8 | from ..detection.yolov5face.models.common import Conv 9 | 10 | from .retinaface.retinaface import RetinaFace 11 | from .yolov5face.face_detector import YoloDetector 12 | 13 | 14 | def init_detection_model(model_name, half=False, device='cuda'): 15 | if 'net' in model_name.lower() : 16 | model = init_retinaface_model(model_name, half, device) 17 | elif 'yolov5' in model_name.lower(): 18 | model = init_yolov5face_model(model_name, device) 19 | else: 20 | raise NotImplementedError(f'{model_name} is not implemented.') 21 | 22 | return model 23 | 24 | 25 | def init_retinaface_model(model_name, half=False, device='cuda'): 26 | 27 | if 'resnet50' 
in model_name.lower(): 28 | model = RetinaFace(network_name='resnet50', half=half) 29 | #model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' 30 | elif 'net0.25' in model_name.lower(): 31 | model = RetinaFace(network_name='mobile0.25', half=half) 32 | #model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' 33 | else: 34 | raise NotImplementedError(f'{model_name} is not implemented.') 35 | 36 | #model_path = load_file_from_url(url=model_url, model_dir='pretrained_models/facelib', progress=True, file_name=None) 37 | load_net = torch.load(model_name, map_location=lambda storage, loc: storage) 38 | # remove unnecessary 'module.' 39 | for k, v in deepcopy(load_net).items(): 40 | if k.startswith('module.'): 41 | load_net[k[7:]] = v 42 | load_net.pop(k) 43 | model.load_state_dict(load_net, strict=True) 44 | model.eval() 45 | model = model.to(device) 46 | 47 | return model 48 | 49 | 50 | def init_yolov5face_model(model_name, device='cuda'): 51 | if "yolov5l" in model_name.lower(): 52 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 53 | #model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' 54 | elif "yolov5n" in model_name.lower(): 55 | model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 56 | #model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' 57 | else: 58 | raise NotImplementedError(f'{model_name} is not implemented.') 59 | 60 | #model_path = load_file_from_url(url=model_url, model_dir='pretrained_models/facelib', progress=True, file_name=None) 61 | load_net = torch.load(model_name, map_location=lambda storage, loc: storage) 62 | model.detector.load_state_dict(load_net, strict=True) 63 | model.detector.eval() 64 | model.detector = model.detector.to(device).float() 65 | 66 | for m in 
model.detector.modules(): 67 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 68 | m.inplace = True # pytorch 1.7.0 compatibility 69 | elif isinstance(m, Conv): 70 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 71 | 72 | return model 73 | 74 | 75 | # Download from Google Drive 76 | # def init_yolov5face_model(model_name, device='cuda'): 77 | # if model_name == 'YOLOv5l': 78 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) 79 | # f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} 80 | # elif model_name == 'YOLOv5n': 81 | # model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) 82 | # f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} 83 | # else: 84 | # raise NotImplementedError(f'{model_name} is not implemented.') 85 | 86 | # model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) 87 | # if not os.path.exists(model_path): 88 | # download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') 89 | 90 | # load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 91 | # model.detector.load_state_dict(load_net, strict=True) 92 | # model.detector.eval() 93 | # model.detector = model.detector.to(device).float() 94 | 95 | # for m in model.detector.modules(): 96 | # if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 97 | # m.inplace = True # pytorch 1.7.0 compatibility 98 | # elif isinstance(m, Conv): 99 | # m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 100 | 101 | # return model -------------------------------------------------------------------------------- /hallo/facelib/detection/retinaface/retinaface_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, 
leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), 9 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 10 | 11 | 12 | def conv_bn_no_relu(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | ) 17 | 18 | 19 | def conv_bn1X1(inp, oup, stride, leaky=0): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), 22 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 23 | 24 | 25 | def conv_dw(inp, oup, stride, leaky=0.1): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 33 | ) 34 | 35 | 36 | class SSH(nn.Module): 37 | 38 | def __init__(self, in_channel, out_channel): 39 | super(SSH, self).__init__() 40 | assert out_channel % 4 == 0 41 | leaky = 0 42 | if (out_channel <= 64): 43 | leaky = 0.1 44 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 45 | 46 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 47 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 48 | 49 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 50 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 51 | 52 | def forward(self, input): 53 | conv3X3 = self.conv3X3(input) 54 | 55 | conv5X5_1 = self.conv5X5_1(input) 56 | conv5X5 = self.conv5X5_2(conv5X5_1) 57 | 58 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 59 | conv7X7 = self.conv7x7_3(conv7X7_2) 60 | 61 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class FPN(nn.Module): 67 | 68 | def __init__(self, in_channels_list, 
out_channels): 69 | super(FPN, self).__init__() 70 | leaky = 0 71 | if (out_channels <= 64): 72 | leaky = 0.1 73 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 74 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 75 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 76 | 77 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) 78 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 79 | 80 | def forward(self, input): 81 | # names = list(input.keys()) 82 | # input = list(input.values()) 83 | 84 | output1 = self.output1(input[0]) 85 | output2 = self.output2(input[1]) 86 | output3 = self.output3(input[2]) 87 | 88 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') 89 | output2 = output2 + up3 90 | output2 = self.merge2(output2) 91 | 92 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') 93 | output1 = output1 + up2 94 | output1 = self.merge1(output1) 95 | 96 | out = [output1, output2, output3] 97 | return out 98 | 99 | 100 | class MobileNetV1(nn.Module): 101 | 102 | def __init__(self): 103 | super(MobileNetV1, self).__init__() 104 | self.stage1 = nn.Sequential( 105 | conv_bn(3, 8, 2, leaky=0.1), # 3 106 | conv_dw(8, 16, 1), # 7 107 | conv_dw(16, 32, 2), # 11 108 | conv_dw(32, 32, 1), # 19 109 | conv_dw(32, 64, 2), # 27 110 | conv_dw(64, 64, 1), # 43 111 | ) 112 | self.stage2 = nn.Sequential( 113 | conv_dw(64, 128, 2), # 43 + 16 = 59 114 | conv_dw(128, 128, 1), # 59 + 32 = 91 115 | conv_dw(128, 128, 1), # 91 + 32 = 123 116 | conv_dw(128, 128, 1), # 123 + 32 = 155 117 | conv_dw(128, 128, 1), # 155 + 32 = 187 118 | conv_dw(128, 128, 1), # 187 + 32 = 219 119 | ) 120 | self.stage3 = nn.Sequential( 121 | conv_dw(128, 256, 2), # 219 +3 2 = 241 122 | conv_dw(256, 256, 1), # 241 + 64 = 301 123 | ) 124 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.fc = nn.Linear(256, 
1000) 126 | 127 | def forward(self, x): 128 | x = self.stage1(x) 129 | x = self.stage2(x) 130 | x = self.stage3(x) 131 | x = self.avg(x) 132 | # x = self.model(x) 133 | x = x.view(-1, 256) 134 | x = self.fc(x) 135 | return x 136 | 137 | 138 | class ClassHead(nn.Module): 139 | 140 | def __init__(self, inchannels=512, num_anchors=3): 141 | super(ClassHead, self).__init__() 142 | self.num_anchors = num_anchors 143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 144 | 145 | def forward(self, x): 146 | out = self.conv1x1(x) 147 | out = out.permute(0, 2, 3, 1).contiguous() 148 | 149 | return out.view(out.shape[0], -1, 2) 150 | 151 | 152 | class BboxHead(nn.Module): 153 | 154 | def __init__(self, inchannels=512, num_anchors=3): 155 | super(BboxHead, self).__init__() 156 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) 157 | 158 | def forward(self, x): 159 | out = self.conv1x1(x) 160 | out = out.permute(0, 2, 3, 1).contiguous() 161 | 162 | return out.view(out.shape[0], -1, 4) 163 | 164 | 165 | class LandmarkHead(nn.Module): 166 | 167 | def __init__(self, inchannels=512, num_anchors=3): 168 | super(LandmarkHead, self).__init__() 169 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 170 | 171 | def forward(self, x): 172 | out = self.conv1x1(x) 173 | out = out.permute(0, 2, 3, 1).contiguous() 174 | 175 | return out.view(out.shape[0], -1, 10) 176 | 177 | 178 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): 179 | classhead = nn.ModuleList() 180 | for i in range(fpn_num): 181 | classhead.append(ClassHead(inchannels, anchor_num)) 182 | return classhead 183 | 184 | 185 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): 186 | bboxhead = nn.ModuleList() 187 | for i in range(fpn_num): 188 | bboxhead.append(BboxHead(inchannels, anchor_num)) 189 | return bboxhead 190 | 191 | 192 | def make_landmark_head(fpn_num=3, 
inchannels=64, anchor_num=2): 193 | landmarkhead = nn.ModuleList() 194 | for i in range(fpn_num): 195 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 196 | return landmarkhead 197 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/facelib/detection/yolov5face/__init__.py -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/face_detector.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import copy 3 | import re 4 | import torch 5 | import numpy as np 6 | 7 | from pathlib import Path 8 | from .models.yolo import Model 9 | from .utils.datasets import letterbox 10 | from .utils.general import ( 11 | check_img_size, 12 | non_max_suppression_face, 13 | scale_coords, 14 | scale_coords_landmarks, 15 | ) 16 | 17 | # IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9) 18 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 19 | torch.__version__)[0][:3])] >= [1, 9, 0] 20 | 21 | 22 | def isListempty(inList): 23 | if isinstance(inList, list): # Is a list 24 | return all(map(isListempty, inList)) 25 | return False # Not a list 26 | 27 | class YoloDetector: 28 | def __init__( 29 | self, 30 | config_name, 31 | min_face=10, 32 | target_size=None, 33 | device='cuda', 34 | ): 35 | """ 36 | config_name: name of .yaml config with network configuration from models/ folder. 37 | min_face : minimal face size in pixels. 38 | target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. 39 | None for original resolution. 
40 | """ 41 | self._class_path = Path(__file__).parent.absolute() 42 | self.target_size = target_size 43 | self.min_face = min_face 44 | self.detector = Model(cfg=config_name) 45 | self.device = device 46 | 47 | 48 | def _preprocess(self, imgs): 49 | """ 50 | Preprocessing image before passing through the network. Resize and conversion to torch tensor. 51 | """ 52 | pp_imgs = [] 53 | for img in imgs: 54 | h0, w0 = img.shape[:2] # orig hw 55 | if self.target_size: 56 | r = self.target_size / min(h0, w0) # resize image to img_size 57 | if r < 1: 58 | img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) 59 | 60 | imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size 61 | img = letterbox(img, new_shape=imgsz)[0] 62 | pp_imgs.append(img) 63 | pp_imgs = np.array(pp_imgs) 64 | pp_imgs = pp_imgs.transpose(0, 3, 1, 2) 65 | pp_imgs = torch.from_numpy(pp_imgs).to(self.device) 66 | pp_imgs = pp_imgs.float() # uint8 to fp16/32 67 | return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 68 | 69 | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): 70 | """ 71 | Postprocessing of raw pytorch model output. 72 | Returns: 73 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 74 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
75 | """ 76 | bboxes = [[] for _ in range(len(origimgs))] 77 | landmarks = [[] for _ in range(len(origimgs))] 78 | 79 | pred = non_max_suppression_face(pred, conf_thres, iou_thres) 80 | 81 | for image_id, origimg in enumerate(origimgs): 82 | img_shape = origimg.shape 83 | image_height, image_width = img_shape[:2] 84 | gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh 85 | gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks 86 | det = pred[image_id].cpu() 87 | scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() 88 | scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() 89 | 90 | for j in range(det.size()[0]): 91 | box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() 92 | box = list( 93 | map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) 94 | ) 95 | if box[3] - box[1] < self.min_face: 96 | continue 97 | lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() 98 | lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) 99 | lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] 100 | bboxes[image_id].append(box) 101 | landmarks[image_id].append(lm) 102 | return bboxes, landmarks 103 | 104 | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): 105 | """ 106 | Get bbox coordinates and keypoints of faces on original image. 107 | Params: 108 | imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) 109 | conf_thres: confidence threshold for each prediction 110 | iou_thres: threshold for NMS (filter of intersecting bboxes) 111 | Returns: 112 | bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. 113 | points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
114 | """ 115 | # Pass input images through face detector 116 | images = imgs if isinstance(imgs, list) else [imgs] 117 | images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] 118 | origimgs = copy.deepcopy(images) 119 | 120 | images = self._preprocess(images) 121 | 122 | if IS_HIGH_VERSION: 123 | with torch.inference_mode(): # for pytorch>=1.9 124 | pred = self.detector(images)[0] 125 | else: 126 | with torch.no_grad(): # for pytorch<1.9 127 | pred = self.detector(images)[0] 128 | 129 | bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) 130 | 131 | # return bboxes, points 132 | if not isListempty(points): 133 | bboxes = np.array(bboxes).reshape(-1,4) 134 | points = np.array(points).reshape(-1,10) 135 | padding = bboxes[:,0].reshape(-1,1) 136 | return np.concatenate((bboxes, padding, points), axis=1) 137 | else: 138 | return None 139 | 140 | def __call__(self, *args): 141 | return self.predict(*args) 142 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/facelib/detection/yolov5face/models/__init__.py -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/models/experimental.py: -------------------------------------------------------------------------------- 1 | # # This file contains experimental modules 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from .common import Conv 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super().__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = 
Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class MixConv2d(nn.Module): 25 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 26 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 27 | super().__init__() 28 | groups = len(k) 29 | if equal_ch: # equal c_ per group 30 | i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices 31 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 32 | else: # equal weight.numel() per group 33 | b = [c2] + [0] * groups 34 | a = np.eye(groups + 1, groups, k=-1) 35 | a -= np.roll(a, 1, axis=1) 36 | a *= np.array(k) ** 2 37 | a[0] = 1 38 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 39 | 40 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 41 | self.bn = nn.BatchNorm2d(c2) 42 | self.act = nn.LeakyReLU(0.1, inplace=True) 43 | 44 | def forward(self, x): 45 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 46 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/models/yolov5l.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 16 | [-1, 3, C3, [128]], 17 | [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 18 | [-1, 9, C3, [256]], 19 | [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 20 | [-1, 9, C3, 
[512]], 21 | [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 22 | [-1, 1, SPP, [1024, [3,5,7]]], 23 | [-1, 3, C3, [1024, False]], # 8 24 | ] 25 | 26 | # YOLOv5 head 27 | head: 28 | [[-1, 1, Conv, [512, 1, 1]], 29 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 30 | [[-1, 5], 1, Concat, [1]], # cat backbone P4 31 | [-1, 3, C3, [512, False]], # 12 32 | 33 | [-1, 1, Conv, [256, 1, 1]], 34 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 35 | [[-1, 3], 1, Concat, [1]], # cat backbone P3 36 | [-1, 3, C3, [256, False]], # 16 (P3/8-small) 37 | 38 | [-1, 1, Conv, [256, 3, 2]], 39 | [[-1, 13], 1, Concat, [1]], # cat head P4 40 | [-1, 3, C3, [512, False]], # 19 (P4/16-medium) 41 | 42 | [-1, 1, Conv, [512, 3, 2]], 43 | [[-1, 9], 1, Concat, [1]], # cat head P5 44 | [-1, 3, C3, [1024, False]], # 22 (P5/32-large) 45 | 46 | [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 47 | ] -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/models/yolov5n.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 1 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [4,5, 8,10, 13,16] # P3/8 9 | - [23,29, 43,55, 73,105] # P4/16 10 | - [146,217, 231,300, 335,433] # P5/32 11 | 12 | # YOLOv5 backbone 13 | backbone: 14 | # [from, number, module, args] 15 | [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 16 | [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 17 | [-1, 3, ShuffleV2Block, [128, 1]], # 2 18 | [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 19 | [-1, 7, ShuffleV2Block, [256, 1]], # 4 20 | [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 21 | [-1, 3, ShuffleV2Block, [512, 1]], # 6 22 | ] 23 | 24 | # YOLOv5 head 25 | head: 26 | [[-1, 1, Conv, [128, 1, 1]], 27 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 28 | [[-1, 4], 1, Concat, [1]], # cat backbone P4 29 | [-1, 1, C3, [128, False]], # 
10 30 | 31 | [-1, 1, Conv, [128, 1, 1]], 32 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 33 | [[-1, 2], 1, Concat, [1]], # cat backbone P3 34 | [-1, 1, C3, [128, False]], # 14 (P3/8-small) 35 | 36 | [-1, 1, Conv, [128, 3, 2]], 37 | [[-1, 11], 1, Concat, [1]], # cat head P4 38 | [-1, 1, C3, [128, False]], # 17 (P4/16-medium) 39 | 40 | [-1, 1, Conv, [128, 3, 2]], 41 | [[-1, 7], 1, Concat, [1]], # cat head P5 42 | [-1, 1, C3, [128, False]], # 20 (P5/32-large) 43 | 44 | [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) 45 | ] 46 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/facelib/detection/yolov5face/utils/__init__.py -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | 4 | def check_anchor_order(m): 5 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 6 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 7 | da = a[-1] - a[0] # delta a 8 | ds = m.stride[-1] - m.stride[0] # delta s 9 | if da.sign() != ds.sign(): # same order 10 | print("Reversing anchor order") 11 | m.anchors[:] = m.anchors.flip(0) 12 | m.anchor_grid[:] = m.anchor_grid.flip(0) 13 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): 6 | # Resize image to a 
32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 7 | shape = img.shape[:2] # current shape [height, width] 8 | if isinstance(new_shape, int): 9 | new_shape = (new_shape, new_shape) 10 | 11 | # Scale ratio (new / old) 12 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 13 | if not scaleup: # only scale down, do not scale up (for better test mAP) 14 | r = min(r, 1.0) 15 | 16 | # Compute padding 17 | ratio = r, r # width, height ratios 18 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 19 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 20 | if auto: # minimum rectangle 21 | dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding 22 | elif scale_fill: # stretch 23 | dw, dh = 0.0, 0.0 24 | new_unpad = (new_shape[1], new_shape[0]) 25 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 26 | 27 | dw /= 2 # divide padding into 2 sides 28 | dh /= 2 29 | 30 | if shape[::-1] != new_unpad: # resize 31 | img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) 32 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 33 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 34 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 35 | return img, ratio, (dw, dh) 36 | -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/utils/extract_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.insert(0,'./facelib/detection/yolov5face') 4 | model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] 5 | torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') -------------------------------------------------------------------------------- /hallo/facelib/detection/yolov5face/utils/torch_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def fuse_conv_and_bn(conv, bn): 6 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 7 | fusedconv = ( 8 | nn.Conv2d( 9 | conv.in_channels, 10 | conv.out_channels, 11 | kernel_size=conv.kernel_size, 12 | stride=conv.stride, 13 | padding=conv.padding, 14 | groups=conv.groups, 15 | bias=True, 16 | ) 17 | .requires_grad_(False) 18 | .to(conv.weight.device) 19 | ) 20 | 21 | # prepare filters 22 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 23 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 24 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) 25 | 26 | # prepare spatial bias 27 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 28 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 29 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 30 | 31 | return fusedconv 32 | 33 | 34 | def copy_attr(a, b, include=(), exclude=()): 35 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 
36 | for k, v in b.__dict__.items(): 37 | if (include and k not in include) or k.startswith("_") or k in exclude: 38 | continue 39 | 40 | setattr(a, k, v) 41 | -------------------------------------------------------------------------------- /hallo/facelib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch 4 | 5 | from ..utils import load_file_from_url 6 | from .bisenet import BiSeNet 7 | from .parsenet import ParseNet 8 | 9 | 10 | def init_parsing_model(model_name='bisenet', pars_model_path="",half=False, device='cuda'): 11 | if model_name == 'bisenet': 12 | model = BiSeNet(num_class=19) 13 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' 14 | elif model_name == 'parsenet': 15 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 16 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' 17 | else: 18 | raise NotImplementedError(f'{model_name} is not implemented.') 19 | if not os.path.exists(pars_model_path): 20 | model_path = load_file_from_url(url=model_url, model_dir=f'{os.path.join(folder_paths.models_dir,"Hallo")}/facelib', progress=True, file_name=None) 21 | else: 22 | model_path=pars_model_path 23 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 24 | model.load_state_dict(load_net, strict=True) 25 | model.eval() 26 | model = model.to(device) 27 | return model 28 | -------------------------------------------------------------------------------- /hallo/facelib/parsing/bisenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnet import ResNet18 6 | 7 | 8 | class ConvBNReLU(nn.Module): 9 | 10 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): 11 | super(ConvBNReLU, self).__init__() 12 | self.conv = 
nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 13 | self.bn = nn.BatchNorm2d(out_chan) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = F.relu(self.bn(x)) 18 | return x 19 | 20 | 21 | class BiSeNetOutput(nn.Module): 22 | 23 | def __init__(self, in_chan, mid_chan, num_class): 24 | super(BiSeNetOutput, self).__init__() 25 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 26 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) 27 | 28 | def forward(self, x): 29 | feat = self.conv(x) 30 | out = self.conv_out(feat) 31 | return out, feat 32 | 33 | 34 | class AttentionRefinementModule(nn.Module): 35 | 36 | def __init__(self, in_chan, out_chan): 37 | super(AttentionRefinementModule, self).__init__() 38 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 39 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 40 | self.bn_atten = nn.BatchNorm2d(out_chan) 41 | self.sigmoid_atten = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | feat = self.conv(x) 45 | atten = F.avg_pool2d(feat, feat.size()[2:]) 46 | atten = self.conv_atten(atten) 47 | atten = self.bn_atten(atten) 48 | atten = self.sigmoid_atten(atten) 49 | out = torch.mul(feat, atten) 50 | return out 51 | 52 | 53 | class ContextPath(nn.Module): 54 | 55 | def __init__(self): 56 | super(ContextPath, self).__init__() 57 | self.resnet = ResNet18() 58 | self.arm16 = AttentionRefinementModule(256, 128) 59 | self.arm32 = AttentionRefinementModule(512, 128) 60 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 61 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 62 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 63 | 64 | def forward(self, x): 65 | feat8, feat16, feat32 = self.resnet(x) 66 | h8, w8 = feat8.size()[2:] 67 | h16, w16 = feat16.size()[2:] 68 | h32, w32 = feat32.size()[2:] 69 | 70 | avg = F.avg_pool2d(feat32, 
feat32.size()[2:]) 71 | avg = self.conv_avg(avg) 72 | avg_up = F.interpolate(avg, (h32, w32), mode='nearest') 73 | 74 | feat32_arm = self.arm32(feat32) 75 | feat32_sum = feat32_arm + avg_up 76 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') 77 | feat32_up = self.conv_head32(feat32_up) 78 | 79 | feat16_arm = self.arm16(feat16) 80 | feat16_sum = feat16_arm + feat32_up 81 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') 82 | feat16_up = self.conv_head16(feat16_up) 83 | 84 | return feat8, feat16_up, feat32_up # x8, x8, x16 85 | 86 | 87 | class FeatureFusionModule(nn.Module): 88 | 89 | def __init__(self, in_chan, out_chan): 90 | super(FeatureFusionModule, self).__init__() 91 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 92 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 93 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.sigmoid = nn.Sigmoid() 96 | 97 | def forward(self, fsp, fcp): 98 | fcat = torch.cat([fsp, fcp], dim=1) 99 | feat = self.convblk(fcat) 100 | atten = F.avg_pool2d(feat, feat.size()[2:]) 101 | atten = self.conv1(atten) 102 | atten = self.relu(atten) 103 | atten = self.conv2(atten) 104 | atten = self.sigmoid(atten) 105 | feat_atten = torch.mul(feat, atten) 106 | feat_out = feat_atten + feat 107 | return feat_out 108 | 109 | 110 | class BiSeNet(nn.Module): 111 | 112 | def __init__(self, num_class): 113 | super(BiSeNet, self).__init__() 114 | self.cp = ContextPath() 115 | self.ffm = FeatureFusionModule(256, 256) 116 | self.conv_out = BiSeNetOutput(256, 256, num_class) 117 | self.conv_out16 = BiSeNetOutput(128, 64, num_class) 118 | self.conv_out32 = BiSeNetOutput(128, 64, num_class) 119 | 120 | def forward(self, x, return_feat=False): 121 | h, w = x.size()[2:] 122 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature 123 | feat_sp = 
feat_res8 # replace spatial path feature with res3b1 feature 124 | feat_fuse = self.ffm(feat_sp, feat_cp8) 125 | 126 | out, feat = self.conv_out(feat_fuse) 127 | out16, feat16 = self.conv_out16(feat_cp8) 128 | out32, feat32 = self.conv_out32(feat_cp16) 129 | 130 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 131 | out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) 132 | out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) 133 | 134 | if return_feat: 135 | feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) 136 | feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) 137 | feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) 138 | return out, out16, out32, feat, feat16, feat32 139 | else: 140 | return out, out16, out32 141 | -------------------------------------------------------------------------------- /hallo/facelib/parsing/parsenet.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/chaofengc/PSFRGAN 2 | """ 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class NormLayer(nn.Module): 9 | """Normalization Layers. 10 | 11 | Args: 12 | channels: input channels, for batch norm and instance norm. 13 | input_size: input shape without batch size, for layer norm. 
14 | """ 15 | 16 | def __init__(self, channels, normalize_shape=None, norm_type='bn'): 17 | super(NormLayer, self).__init__() 18 | norm_type = norm_type.lower() 19 | self.norm_type = norm_type 20 | if norm_type == 'bn': 21 | self.norm = nn.BatchNorm2d(channels, affine=True) 22 | elif norm_type == 'in': 23 | self.norm = nn.InstanceNorm2d(channels, affine=False) 24 | elif norm_type == 'gn': 25 | self.norm = nn.GroupNorm(32, channels, affine=True) 26 | elif norm_type == 'pixel': 27 | self.norm = lambda x: F.normalize(x, p=2, dim=1) 28 | elif norm_type == 'layer': 29 | self.norm = nn.LayerNorm(normalize_shape) 30 | elif norm_type == 'none': 31 | self.norm = lambda x: x * 1.0 32 | else: 33 | assert 1 == 0, f'Norm type {norm_type} not support.' 34 | 35 | def forward(self, x, ref=None): 36 | if self.norm_type == 'spade': 37 | return self.norm(x, ref) 38 | else: 39 | return self.norm(x) 40 | 41 | 42 | class ReluLayer(nn.Module): 43 | """Relu Layer. 44 | 45 | Args: 46 | relu type: type of relu layer, candidates are 47 | - ReLU 48 | - LeakyReLU: default relu slope 0.2 49 | - PRelu 50 | - SELU 51 | - none: direct pass 52 | """ 53 | 54 | def __init__(self, channels, relu_type='relu'): 55 | super(ReluLayer, self).__init__() 56 | relu_type = relu_type.lower() 57 | if relu_type == 'relu': 58 | self.func = nn.ReLU(True) 59 | elif relu_type == 'leakyrelu': 60 | self.func = nn.LeakyReLU(0.2, inplace=True) 61 | elif relu_type == 'prelu': 62 | self.func = nn.PReLU(channels) 63 | elif relu_type == 'selu': 64 | self.func = nn.SELU(True) 65 | elif relu_type == 'none': 66 | self.func = lambda x: x * 1.0 67 | else: 68 | assert 1 == 0, f'Relu type {relu_type} not support.' 
69 | 70 | def forward(self, x): 71 | return self.func(x) 72 | 73 | 74 | class ConvLayer(nn.Module): 75 | 76 | def __init__(self, 77 | in_channels, 78 | out_channels, 79 | kernel_size=3, 80 | scale='none', 81 | norm_type='none', 82 | relu_type='none', 83 | use_pad=True, 84 | bias=True): 85 | super(ConvLayer, self).__init__() 86 | self.use_pad = use_pad 87 | self.norm_type = norm_type 88 | if norm_type in ['bn']: 89 | bias = False 90 | 91 | stride = 2 if scale == 'down' else 1 92 | 93 | self.scale_func = lambda x: x 94 | if scale == 'up': 95 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') 96 | 97 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2))) 98 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 99 | 100 | self.relu = ReluLayer(out_channels, relu_type) 101 | self.norm = NormLayer(out_channels, norm_type=norm_type) 102 | 103 | def forward(self, x): 104 | out = self.scale_func(x) 105 | if self.use_pad: 106 | out = self.reflection_pad(out) 107 | out = self.conv2d(out) 108 | out = self.norm(out) 109 | out = self.relu(out) 110 | return out 111 | 112 | 113 | class ResidualBlock(nn.Module): 114 | """ 115 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html 116 | """ 117 | 118 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): 119 | super(ResidualBlock, self).__init__() 120 | 121 | if scale == 'none' and c_in == c_out: 122 | self.shortcut_func = lambda x: x 123 | else: 124 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) 125 | 126 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} 127 | scale_conf = scale_config_dict[scale] 128 | 129 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) 130 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') 131 | 132 | def forward(self, 
x): 133 | identity = self.shortcut_func(x) 134 | 135 | res = self.conv1(x) 136 | res = self.conv2(res) 137 | return identity + res 138 | 139 | 140 | class ParseNet(nn.Module): 141 | 142 | def __init__(self, 143 | in_size=128, 144 | out_size=128, 145 | min_feat_size=32, 146 | base_ch=64, 147 | parsing_ch=19, 148 | res_depth=10, 149 | relu_type='LeakyReLU', 150 | norm_type='bn', 151 | ch_range=[32, 256]): 152 | super().__init__() 153 | self.res_depth = res_depth 154 | act_args = {'norm_type': norm_type, 'relu_type': relu_type} 155 | min_ch, max_ch = ch_range 156 | 157 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 158 | min_feat_size = min(in_size, min_feat_size) 159 | 160 | down_steps = int(np.log2(in_size // min_feat_size)) 161 | up_steps = int(np.log2(out_size // min_feat_size)) 162 | 163 | # =============== define encoder-body-decoder ==================== 164 | self.encoder = [] 165 | self.encoder.append(ConvLayer(3, base_ch, 3, 1)) 166 | head_ch = base_ch 167 | for i in range(down_steps): 168 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) 169 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) 170 | head_ch = head_ch * 2 171 | 172 | self.body = [] 173 | for i in range(res_depth): 174 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) 175 | 176 | self.decoder = [] 177 | for i in range(up_steps): 178 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) 179 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) 180 | head_ch = head_ch // 2 181 | 182 | self.encoder = nn.Sequential(*self.encoder) 183 | self.body = nn.Sequential(*self.body) 184 | self.decoder = nn.Sequential(*self.decoder) 185 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) 186 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) 187 | 188 | def forward(self, x): 189 | feat = self.encoder(x) 190 | x = feat + self.body(feat) 191 | x = self.decoder(x) 192 | out_img = self.out_img_conv(x) 193 | 
out_mask = self.out_mask_conv(x) 194 | return out_mask, out_img 195 | -------------------------------------------------------------------------------- /hallo/facelib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_chan, out_chan, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_chan, out_chan, stride) 15 | self.bn1 = nn.BatchNorm2d(out_chan) 16 | self.conv2 = conv3x3(out_chan, out_chan) 17 | self.bn2 = nn.BatchNorm2d(out_chan) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = None 20 | if in_chan != out_chan or stride != 1: 21 | self.downsample = nn.Sequential( 22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_chan), 24 | ) 25 | 26 | def forward(self, x): 27 | residual = self.conv1(x) 28 | residual = F.relu(self.bn1(residual)) 29 | residual = self.conv2(residual) 30 | residual = self.bn2(residual) 31 | 32 | shortcut = x 33 | if self.downsample is not None: 34 | shortcut = self.downsample(x) 35 | 36 | out = shortcut + residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 43 | for i in range(bnum - 1): 44 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class ResNet18(nn.Module): 49 | 50 | def __init__(self): 51 | super(ResNet18, self).__init__() 52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.relu(self.bn1(x)) 63 | x = self.maxpool(x) 64 | 65 | x = self.layer1(x) 66 | feat8 = self.layer2(x) # 1/8 67 | feat16 = self.layer3(feat8) # 1/16 68 | feat32 = self.layer4(feat16) # 1/32 69 | return feat8, feat16, feat32 70 | -------------------------------------------------------------------------------- /hallo/facelib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir 3 | 4 | __all__ = [ 5 | 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 6 | 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' 7 | ] 8 | -------------------------------------------------------------------------------- /hallo/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/models/__init__.py -------------------------------------------------------------------------------- /hallo/models/audio_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the implementation of an Audio Projection Model, which is designed for 3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens 4 | that can be used for various downstream applications, such as audio analysis or synthesis. 
5 | 6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which 7 | provides a foundation for building custom models. This implementation includes multiple linear 8 | layers with ReLU activation functions and a LayerNorm for normalization. 9 | 10 | Key Features: 11 | - Audio embedding input with flexible sequence length and block structure. 12 | - Multiple linear layers for feature transformation. 13 | - ReLU activation for non-linear transformation. 14 | - LayerNorm for stabilizing and speeding up training. 15 | - Rearrangement of input embeddings to match the model's expected input shape. 16 | - Customizable number of blocks, channels, and context tokens for adaptability. 17 | 18 | The module is structured to be easily integrated into larger systems or used as a standalone 19 | component for audio feature extraction and processing. 20 | 21 | Classes: 22 | - AudioProjModel: A class representing the audio projection model with configurable parameters. 23 | 24 | Functions: 25 | - (none) 26 | 27 | Dependencies: 28 | - torch: For tensor operations and neural network components. 29 | - diffusers: For the ModelMixin base class. 30 | - einops: For tensor rearrangement operations. 31 | 32 | """ 33 | 34 | import torch 35 | from diffusers import ModelMixin 36 | from einops import rearrange 37 | from torch import nn 38 | 39 | 40 | class AudioProjModel(ModelMixin): 41 | """Audio Projection Model 42 | 43 | This class defines an audio projection model that takes audio embeddings as input 44 | and produces context tokens as output. The model is based on the ModelMixin class 45 | and consists of multiple linear layers and activation functions. It can be used 46 | for various audio processing tasks. 47 | 48 | Attributes: 49 | seq_len (int): The length of the audio sequence. 50 | blocks (int): The number of blocks in the audio projection model. 51 | channels (int): The number of channels in the audio projection model. 
52 | intermediate_dim (int): The intermediate dimension of the model. 53 | context_tokens (int): The number of context tokens in the output. 54 | output_dim (int): The output dimension of the context tokens. 55 | 56 | Methods: 57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): 58 | Initializes the AudioProjModel with the given parameters. 59 | forward(self, audio_embeds): 60 | Defines the forward pass for the AudioProjModel. 61 | Parameters: 62 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 63 | Returns: 64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 65 | 66 | """ 67 | 68 | def __init__( 69 | self, 70 | seq_len=5, 71 | blocks=12, # add a new parameter blocks 72 | channels=768, # add a new parameter channels 73 | intermediate_dim=512, 74 | output_dim=768, 75 | context_tokens=32, 76 | ): 77 | super().__init__() 78 | 79 | self.seq_len = seq_len 80 | self.blocks = blocks 81 | self.channels = channels 82 | self.input_dim = ( 83 | seq_len * blocks * channels 84 | ) # update input_dim to be the product of blocks and channels. 85 | self.intermediate_dim = intermediate_dim 86 | self.context_tokens = context_tokens 87 | self.output_dim = output_dim 88 | 89 | # define multiple linear layers 90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 93 | 94 | self.norm = nn.LayerNorm(output_dim) 95 | 96 | def forward(self, audio_embeds): 97 | """ 98 | Defines the forward pass for the AudioProjModel. 99 | 100 | Parameters: 101 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 
102 | 103 | Returns: 104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 105 | """ 106 | # merge 107 | video_length = audio_embeds.shape[1] 108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 109 | batch_size, window_size, blocks, channels = audio_embeds.shape 110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 111 | 112 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 113 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 114 | 115 | context_tokens = self.proj3(audio_embeds).reshape( 116 | batch_size, self.context_tokens, self.output_dim 117 | ) 118 | 119 | context_tokens = self.norm(context_tokens) 120 | context_tokens = rearrange( 121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length 122 | ) 123 | 124 | return context_tokens 125 | -------------------------------------------------------------------------------- /hallo/models/face_locator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the FaceLocator class, which is a neural network model designed to 3 | locate and extract facial features from input images or tensors. It uses a series of 4 | convolutional layers to progressively downsample and refine the facial feature map. 5 | 6 | The FaceLocator class is part of a larger system that may involve facial recognition or 7 | similar tasks where precise location and extraction of facial features are required. 8 | 9 | Attributes: 10 | conditioning_embedding_channels (int): The number of channels in the output embedding. 11 | conditioning_channels (int): The number of input channels for the conditioning tensor. 12 | block_out_channels (Tuple[int]): A tuple of integers representing the output channels 13 | for each block in the model. 
14 | 15 | The model uses the following components: 16 | - InflatedConv3d: A convolutional layer that inflates the input to increase the depth. 17 | - zero_module: A utility function that may set certain parameters to zero for regularization 18 | or other purposes. 19 | 20 | The forward method of the FaceLocator class takes a conditioning tensor as input and 21 | produces an embedding tensor as output, which can be used for further processing or analysis. 22 | """ 23 | 24 | from typing import Tuple 25 | 26 | import torch.nn.functional as F 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from torch import nn 29 | 30 | from .motion_module import zero_module 31 | from .resnet import InflatedConv3d 32 | 33 | 34 | class FaceLocator(ModelMixin): 35 | """ 36 | The FaceLocator class is a neural network model designed to process and extract facial 37 | features from an input tensor. It consists of a series of convolutional layers that 38 | progressively downsample the input while increasing the depth of the feature map. 39 | 40 | The model is built using InflatedConv3d layers, which are designed to inflate the 41 | feature channels, allowing for more complex feature extraction. The final output is a 42 | conditioning embedding that can be used for various tasks such as facial recognition or 43 | feature-based image manipulation. 44 | 45 | Parameters: 46 | conditioning_embedding_channels (int): The number of channels in the output embedding. 47 | conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3. 48 | block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels 49 | for each block in the model. The default is (16, 32, 64, 128), which defines the 50 | progression of the network's depth. 51 | 52 | Attributes: 53 | conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process. 
54 | blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model. 55 | conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding. 56 | 57 | The forward method applies the convolutional layers to the input conditioning tensor and 58 | returns the resulting embedding tensor. 59 | """ 60 | def __init__( 61 | self, 62 | conditioning_embedding_channels: int, 63 | conditioning_channels: int = 3, 64 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 65 | ): 66 | super().__init__() 67 | self.conv_in = InflatedConv3d( 68 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 69 | ) 70 | 71 | self.blocks = nn.ModuleList([]) 72 | 73 | for i in range(len(block_out_channels) - 1): 74 | channel_in = block_out_channels[i] 75 | channel_out = block_out_channels[i + 1] 76 | self.blocks.append( 77 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 78 | ) 79 | self.blocks.append( 80 | InflatedConv3d( 81 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 82 | ) 83 | ) 84 | 85 | self.conv_out = zero_module( 86 | InflatedConv3d( 87 | block_out_channels[-1], 88 | conditioning_embedding_channels, 89 | kernel_size=3, 90 | padding=1, 91 | ) 92 | ) 93 | 94 | def forward(self, conditioning): 95 | """ 96 | Forward pass of the FaceLocator model. 97 | 98 | Args: 99 | conditioning (Tensor): The input conditioning tensor. 100 | 101 | Returns: 102 | Tensor: The output embedding tensor. 
103 | """ 104 | embedding = self.conv_in(conditioning) 105 | embedding = F.silu(embedding) 106 | 107 | for block in self.blocks: 108 | embedding = block(embedding) 109 | embedding = F.silu(embedding) 110 | 111 | embedding = self.conv_out(embedding) 112 | 113 | return embedding 114 | -------------------------------------------------------------------------------- /hallo/models/image_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | image_proj_model.py 3 | 4 | This module defines the ImageProjModel class, which is responsible for 5 | projecting image embeddings into a different dimensional space. The model 6 | leverages a linear transformation followed by a layer normalization to 7 | reshape and normalize the input image embeddings for further processing in 8 | cross-attention mechanisms or other downstream tasks. 9 | 10 | Classes: 11 | ImageProjModel 12 | 13 | Dependencies: 14 | torch 15 | diffusers.ModelMixin 16 | 17 | """ 18 | 19 | import torch 20 | from diffusers import ModelMixin 21 | 22 | 23 | class ImageProjModel(ModelMixin): 24 | """ 25 | ImageProjModel is a class that projects image embeddings into a different 26 | dimensional space. It inherits from ModelMixin, providing additional functionalities 27 | specific to image projection. 28 | 29 | Attributes: 30 | cross_attention_dim (int): The dimension of the cross attention. 31 | clip_embeddings_dim (int): The dimension of the CLIP embeddings. 32 | clip_extra_context_tokens (int): The number of extra context tokens in CLIP. 33 | 34 | Methods: 35 | forward(image_embeds): Forward pass of the ImageProjModel, which takes in image 36 | embeddings and returns the projected tokens. 
37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | cross_attention_dim=1024, 43 | clip_embeddings_dim=1024, 44 | clip_extra_context_tokens=4, 45 | ): 46 | super().__init__() 47 | 48 | self.generator = None 49 | self.cross_attention_dim = cross_attention_dim 50 | self.clip_extra_context_tokens = clip_extra_context_tokens 51 | self.proj = torch.nn.Linear( 52 | clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim 53 | ) 54 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 55 | 56 | def forward(self, image_embeds): 57 | """ 58 | Forward pass of the ImageProjModel, which takes in image embeddings and returns the 59 | projected tokens after reshaping and normalization. 60 | 61 | Args: 62 | image_embeds (torch.Tensor): The input image embeddings, with shape 63 | batch_size x num_image_tokens x clip_embeddings_dim. 64 | 65 | Returns: 66 | clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping 67 | and normalization, with shape batch_size x (clip_extra_context_tokens * 68 | cross_attention_dim). 69 | 70 | """ 71 | embeds = image_embeds 72 | clip_extra_context_tokens = self.proj(embeds).reshape( 73 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 74 | ) 75 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 76 | return clip_extra_context_tokens 77 | -------------------------------------------------------------------------------- /hallo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_Hallo2/ef42c86c6bbcdb7b0def67fe977e0e7b3a36edea/hallo/utils/__init__.py -------------------------------------------------------------------------------- /hallo/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides utility functions for configuration manipulation. 
3 | """ 4 | 5 | from typing import Dict 6 | 7 | 8 | def filter_non_none(dict_obj: Dict): 9 | """ 10 | Filters out key-value pairs from the given dictionary where the value is None. 11 | 12 | Args: 13 | dict_obj (Dict): The dictionary to be filtered. 14 | 15 | Returns: 16 | Dict: The dictionary with key-value pairs removed where the value was None. 17 | 18 | This function creates a new dictionary containing only the key-value pairs from 19 | the original dictionary where the value is not None. It then clears the original 20 | dictionary and updates it with the filtered key-value pairs. 21 | """ 22 | non_none_filter = { k: v for k, v in dict_obj.items() if v is not None } 23 | dict_obj.clear() 24 | dict_obj.update(non_none_filter) 25 | return dict_obj 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_hallo2" 3 | description = " Long-Duration and High-Resolution Audio-driven Portrait Image Animation," 4 | version = "1.0.0" 5 | license = { file = "LICENSE" } 6 | dependencies = ["audio-separator", "ffmpeg-python", "icecream", "#av==12.1.0", "#bitsandbytes==0.43.1", "#decord==0.6.0", "#diffusers==0.27.2", "#einops==0.8.0", "#accelerate==0.28.0", "#insightface==0.7.3", "#librosa==0.10.2.post1", "#lpips==0.1.4", "#mediapipe[vision]==0.10.14", "#mlflow==2.13.1", "#moviepy==1.0.3", "#numpy==1.26.4", "#omegaconf==2.3.0", "#onnx2torch==1.5.14", "#onnx==1.16.1", "#onnxruntime-gpu==1.18.0", "#opencv-contrib-python==4.9.0.80", "#opencv-python-headless==4.9.0.80", "#opencv-python==4.9.0.80", "#pillow==10.3.0", "#setuptools==70.0.0", "#tqdm==4.66.4", "#transformers==4.39.2", "#xformers==0.0.25.post1", "#isort==5.13.2", "#pylint==3.2.2", "#pre-commit==3.7.1", "#gradio==4.36.1", "#lpips"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/smthemex/ComfyUI_Hallo2" 10 | # Used by Comfy Registry 
https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "smthemex" 14 | DisplayName = "ComfyUI_Hallo2" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audio-separator 2 | ffmpeg-python 3 | icecream 4 | #av==12.1.0 5 | #bitsandbytes==0.43.1 6 | #decord==0.6.0 7 | #diffusers==0.27.2 8 | #einops==0.8.0 9 | #accelerate==0.28.0 10 | #insightface==0.7.3 11 | #librosa==0.10.2.post1 12 | #lpips==0.1.4 13 | #mediapipe[vision]==0.10.14 14 | #mlflow==2.13.1 15 | #moviepy==1.0.3 16 | #numpy==1.26.4 17 | #omegaconf==2.3.0 18 | #onnx2torch==1.5.14 19 | #onnx==1.16.1 20 | #onnxruntime-gpu==1.18.0 21 | #opencv-contrib-python==4.9.0.80 22 | #opencv-python-headless==4.9.0.80 23 | #opencv-python==4.9.0.80 24 | #pillow==10.3.0 25 | #setuptools==70.0.0 26 | #tqdm==4.66.4 27 | #transformers==4.39.2 28 | #xformers==0.0.25.post1 29 | #isort==5.13.2 30 | #pylint==3.2.2 31 | #pre-commit==3.7.1 32 | #gradio==4.36.1 33 | #lpips 34 | -------------------------------------------------------------------------------- /scripts/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is a gradio web ui. 3 | 4 | The script takes an image and an audio clip, and lets you configure all the 5 | variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc. 6 | 7 | Usage: 8 | This script can be run from the command line with the following command: 9 | 10 | python scripts/app.py 11 | """ 12 | import argparse 13 | 14 | import gradio as gr 15 | from .inference_long import inference_process 16 | 17 | 18 | def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)): 19 | """ 20 | Create a gradio interface with the configs. 
21 | """ 22 | _ = progress 23 | config = { 24 | 'source_image': image, 25 | 'driving_audio': audio, 26 | 'pose_weight': pose_weight, 27 | 'face_weight': face_weight, 28 | 'lip_weight': lip_weight, 29 | 'face_expand_ratio': face_expand_ratio, 30 | 'config': 'configs/inference/default.yaml', 31 | 'checkpoint': None, 32 | 'output': ".cache/output.mp4" 33 | } 34 | args = argparse.Namespace() 35 | for key, value in config.items(): 36 | setattr(args, key, value) 37 | return inference_process(args) 38 | 39 | app = gr.Interface( 40 | fn=predict, 41 | inputs=[ 42 | gr.Image(label="source image (no webp)", type="filepath", format="jpeg"), 43 | gr.Audio(label="source audio", type="filepath"), 44 | gr.Number(label="pose weight", value=1.0), 45 | gr.Number(label="face weight", value=1.0), 46 | gr.Number(label="lip weight", value=1.0), 47 | gr.Number(label="face expand ratio", value=1.2), 48 | ], 49 | outputs=[gr.Video()], 50 | ) 51 | app.launch() 52 | -------------------------------------------------------------------------------- /scripts/extract_meta_info_stage1.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is used to extract meta information from video directories. 4 | 5 | It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path` 6 | specifies the path to the video directory, while the `dataset_name` specifies the name 7 | of the dataset. The module then collects all the video folder paths, and for each video 8 | folder, it checks if a mask path and a face embedding path exist. If they do, it appends 9 | a dictionary containing the image path, mask path, and face embedding path to a list. 10 | 11 | Finally, the module writes the list of dictionaries to a JSON file with the filename 12 | constructed using the `dataset_name`. 
13 | 14 | Usage: 15 | python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf 16 | 17 | """ 18 | 19 | import argparse 20 | import json 21 | import os 22 | from pathlib import Path 23 | 24 | import torch 25 | 26 | 27 | def collect_video_folder_paths(root_path: Path) -> list: 28 | """ 29 | Collect all video folder paths from the root path. 30 | 31 | Args: 32 | root_path (Path): The root directory containing video folders. 33 | 34 | Returns: 35 | list: List of video folder paths. 36 | """ 37 | return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()] 38 | 39 | 40 | def construct_meta_info(frames_dir_path: Path) -> dict: 41 | """ 42 | Construct meta information for a given frames directory. 43 | 44 | Args: 45 | frames_dir_path (Path): The path to the frames directory. 46 | 47 | Returns: 48 | dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist. 49 | """ 50 | mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png" 51 | face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt" 52 | 53 | if not os.path.exists(mask_path): 54 | print(f"Mask path not found: {mask_path}") 55 | return None 56 | 57 | if torch.load(face_emb_path) is None: 58 | print(f"Face emb is None: {face_emb_path}") 59 | return None 60 | 61 | return { 62 | "image_path": str(frames_dir_path), 63 | "mask_path": mask_path, 64 | "face_emb": face_emb_path, 65 | } 66 | 67 | 68 | def main(): 69 | """ 70 | Main function to extract meta info for training. 
71 | """ 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("-r", "--root_path", type=str, 74 | required=True, help="Root path of the video directories") 75 | parser.add_argument("-n", "--dataset_name", type=str, 76 | required=True, help="Name of the dataset") 77 | parser.add_argument("--meta_info_name", type=str, 78 | help="Name of the meta information file") 79 | 80 | args = parser.parse_args() 81 | 82 | if args.meta_info_name is None: 83 | args.meta_info_name = args.dataset_name 84 | 85 | image_dir = Path(args.root_path) / "images" 86 | output_dir = Path("./data") 87 | output_dir.mkdir(exist_ok=True) 88 | 89 | # Collect all video folder paths 90 | frames_dir_paths = collect_video_folder_paths(image_dir) 91 | 92 | meta_infos = [] 93 | for frames_dir_path in frames_dir_paths: 94 | meta_info = construct_meta_info(frames_dir_path) 95 | if meta_info: 96 | meta_infos.append(meta_info) 97 | 98 | output_file = output_dir / f"{args.meta_info_name}_stage1.json" 99 | with output_file.open("w", encoding="utf-8") as f: 100 | json.dump(meta_infos, f, indent=4) 101 | 102 | print(f"Final data count: {len(meta_infos)}") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /scripts/extract_meta_info_stage2.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is used to extract meta information from video files and store them in a JSON file. 4 | 5 | The script takes in command line arguments to specify the root path of the video files, 6 | the dataset name, and the name of the meta information file. It then generates a list of 7 | dictionaries containing the meta information for each video file and writes it to a JSON 8 | file with the specified name. 
9 | 10 | The meta information includes the path to the video file, the mask path, the face mask 11 | path, the face mask union path, the face mask gaussian path, the lip mask path, the lip 12 | mask union path, the lip mask gaussian path, the separate mask border, the separate mask 13 | face, the separate mask lip, the face embedding path, the audio path, the vocals embedding 14 | base last path, the vocals embedding base all path, the vocals embedding base average 15 | path, the vocals embedding large last path, the vocals embedding large all path, and the 16 | vocals embedding large average path. 17 | 18 | The script checks if the mask path exists before adding the information to the list. 19 | 20 | Usage: 21 | python tools/extract_meta_info_stage2.py --root_path --dataset_name --meta_info_name 22 | 23 | Example: 24 | python tools/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info 25 | """ 26 | 27 | import argparse 28 | import json 29 | import os 30 | from pathlib import Path 31 | 32 | import torch 33 | from decord import VideoReader, cpu 34 | from tqdm import tqdm 35 | 36 | 37 | def get_video_paths(root_path: Path, extensions: list) -> list: 38 | """ 39 | Get a list of video paths from the root path with the specified extensions. 40 | 41 | Args: 42 | root_path (Path): The root directory containing video files. 43 | extensions (list): List of file extensions to include. 44 | 45 | Returns: 46 | list: List of video file paths. 47 | """ 48 | return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions] 49 | 50 | 51 | def file_exists(file_path: str) -> bool: 52 | """ 53 | Check if a file exists. 54 | 55 | Args: 56 | file_path (str): The path to the file. 57 | 58 | Returns: 59 | bool: True if the file exists, False otherwise. 
60 | """ 61 | return os.path.exists(file_path) 62 | 63 | 64 | def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str: 65 | """ 66 | Construct a new path by replacing the base directory and extension in the original path. 67 | 68 | Args: 69 | video_path (str): The original video path. 70 | base_dir (str): The base directory to be replaced. 71 | new_dir (str): The new directory to replace the base directory. 72 | new_ext (str): The new file extension. 73 | 74 | Returns: 75 | str: The constructed path. 76 | """ 77 | return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext) 78 | 79 | 80 | def extract_meta_info(video_path: str) -> dict: 81 | """ 82 | Extract meta information for a given video file. 83 | 84 | Args: 85 | video_path (str): The path to the video file. 86 | 87 | Returns: 88 | dict: A dictionary containing the meta information for the video. 89 | """ 90 | mask_path = construct_paths( 91 | video_path, "videos", "face_mask", ".png") 92 | sep_mask_border = construct_paths( 93 | video_path, "videos", "sep_pose_mask", ".png") 94 | sep_mask_face = construct_paths( 95 | video_path, "videos", "sep_face_mask", ".png") 96 | sep_mask_lip = construct_paths( 97 | video_path, "videos", "sep_lip_mask", ".png") 98 | face_emb_path = construct_paths( 99 | video_path, "videos", "face_emb", ".pt") 100 | audio_path = construct_paths(video_path, "videos", "audios", ".wav") 101 | vocal_emb_base_all = construct_paths( 102 | video_path, "videos", "audio_emb", ".pt") 103 | 104 | assert_flag = True 105 | 106 | if not file_exists(mask_path): 107 | print(f"Mask path not found: {mask_path}") 108 | assert_flag = False 109 | if not file_exists(sep_mask_border): 110 | print(f"Separate mask border not found: {sep_mask_border}") 111 | assert_flag = False 112 | if not file_exists(sep_mask_face): 113 | print(f"Separate mask face not found: {sep_mask_face}") 114 | assert_flag = False 115 | if not file_exists(sep_mask_lip): 116 | print(f"Separate 
mask lip not found: {sep_mask_lip}") 117 | assert_flag = False 118 | if not file_exists(face_emb_path): 119 | print(f"Face embedding path not found: {face_emb_path}") 120 | assert_flag = False 121 | if not file_exists(audio_path): 122 | print(f"Audio path not found: {audio_path}") 123 | assert_flag = False 124 | if not file_exists(vocal_emb_base_all): 125 | print(f"Vocal embedding base all not found: {vocal_emb_base_all}") 126 | assert_flag = False 127 | 128 | video_frames = VideoReader(video_path, ctx=cpu(0)) 129 | audio_emb = torch.load(vocal_emb_base_all) 130 | if abs(len(video_frames) - audio_emb.shape[0]) > 3: 131 | print(f"Frame count mismatch for video: {video_path}") 132 | assert_flag = False 133 | 134 | face_emb = torch.load(face_emb_path) 135 | if face_emb is None: 136 | print(f"Face embedding is None for video: {video_path}") 137 | assert_flag = False 138 | 139 | del video_frames, audio_emb 140 | 141 | if assert_flag: 142 | return { 143 | "video_path": str(video_path), 144 | "mask_path": mask_path, 145 | "sep_mask_border": sep_mask_border, 146 | "sep_mask_face": sep_mask_face, 147 | "sep_mask_lip": sep_mask_lip, 148 | "face_emb_path": face_emb_path, 149 | "audio_path": audio_path, 150 | "vocals_emb_base_all": vocal_emb_base_all, 151 | } 152 | return None 153 | 154 | 155 | def main(): 156 | """ 157 | Main function to extract meta info for training. 
158 | """ 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("-r", "--root_path", type=str, 161 | required=True, help="Root path of the video files") 162 | parser.add_argument("-n", "--dataset_name", type=str, 163 | required=True, help="Name of the dataset") 164 | parser.add_argument("--meta_info_name", type=str, 165 | help="Name of the meta information file") 166 | 167 | args = parser.parse_args() 168 | 169 | if args.meta_info_name is None: 170 | args.meta_info_name = args.dataset_name 171 | 172 | video_dir = Path(args.root_path) / "videos" 173 | video_paths = get_video_paths(video_dir, [".mp4"]) 174 | 175 | meta_infos = [] 176 | 177 | for video_path in tqdm(video_paths, desc="Extracting meta info"): 178 | meta_info = extract_meta_info(video_path) 179 | if meta_info: 180 | meta_infos.append(meta_info) 181 | 182 | print(f"Final data count: {len(meta_infos)}") 183 | 184 | output_file = Path(f"./data/{args.meta_info_name}_stage2.json") 185 | output_file.parent.mkdir(parents=True, exist_ok=True) 186 | 187 | with output_file.open("w", encoding="utf-8") as f: 188 | json.dump(meta_infos, f, indent=4) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | from comfy.utils import common_upscale,ProgressBar 9 | import folder_paths 10 | import logging 11 | 12 | cur_path = os.path.dirname(os.path.abspath(__file__)) 13 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 14 | 15 | 16 | def tensor2cv(tensor_image): 17 | if len(tensor_image.shape)==4:#bhwc to hwc 18 | tensor_image=tensor_image.squeeze(0) 19 | if tensor_image.is_cuda: 20 | tensor_image = tensor_image.cpu().detach() 
21 | tensor_image=tensor_image.numpy() 22 | #反归一化 23 | maxValue=tensor_image.max() 24 | tensor_image=tensor_image*255/maxValue 25 | img_cv2=np.uint8(tensor_image)#32 to uint8 26 | img_cv2=cv2.cvtColor(img_cv2,cv2.COLOR_RGB2BGR) 27 | return img_cv2 28 | 29 | def cvargb2tensor(img): 30 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 31 | img = torch.from_numpy(img.transpose((2, 0, 1))) 32 | return img.float().div(255).unsqueeze(0) # 255也可以改为256 33 | 34 | def cv2tensor(img): 35 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 36 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 37 | img = torch.from_numpy(img.transpose((2, 0, 1))) 38 | return img.float().div(255).unsqueeze(0) # 255也可以改为256 39 | 40 | def images_generator(img_list: list,): 41 | #get img size 42 | sizes = {} 43 | for image_ in img_list: 44 | if isinstance(image_,Image.Image): 45 | count = sizes.get(image_.size, 0) 46 | sizes[image_.size] = count + 1 47 | elif isinstance(image_,np.ndarray): 48 | count = sizes.get(image_.shape[:2][::-1], 0) 49 | sizes[image_.shape[:2][::-1]] = count + 1 50 | else: 51 | raise "unsupport image list,must be pil or cv2!!!" 
52 | size = max(sizes.items(), key=lambda x: x[1])[0] 53 | yield size[0], size[1] 54 | 55 | # any to tensor 56 | def load_image(img_in): 57 | if isinstance(img_in, Image.Image): 58 | img_in=img_in.convert("RGB") 59 | i = np.array(img_in, dtype=np.float32) 60 | i = torch.from_numpy(i).div_(255) 61 | if i.shape[0] != size[1] or i.shape[1] != size[0]: 62 | i = torch.from_numpy(i).movedim(-1, 0).unsqueeze(0) 63 | i = common_upscale(i, size[0], size[1], "lanczos", "center") 64 | i = i.squeeze(0).movedim(0, -1).numpy() 65 | return i 66 | elif isinstance(img_in,np.ndarray): 67 | i=cv2.cvtColor(img_in,cv2.COLOR_BGR2RGB).astype(np.float32) 68 | i = torch.from_numpy(i).div_(255) 69 | #print(i.shape) 70 | return i 71 | else: 72 | raise "unsupport image list,must be pil,cv2 or tensor!!!" 73 | 74 | total_images = len(img_list) 75 | processed_images = 0 76 | pbar = ProgressBar(total_images) 77 | images = map(load_image, img_list) 78 | try: 79 | prev_image = next(images) 80 | while True: 81 | next_image = next(images) 82 | yield prev_image 83 | processed_images += 1 84 | pbar.update_absolute(processed_images, total_images) 85 | prev_image = next_image 86 | except StopIteration: 87 | pass 88 | if prev_image is not None: 89 | yield prev_image 90 | 91 | def load_images(img_list: list,): 92 | gen = images_generator(img_list) 93 | (width, height) = next(gen) 94 | images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3))))) 95 | if len(images) == 0: 96 | raise FileNotFoundError(f"No images could be loaded .") 97 | return images 98 | 99 | def tensor2pil(tensor): 100 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 101 | image = Image.fromarray(image_np, mode='RGB') 102 | return image 103 | 104 | def pil2narry(img): 105 | narry = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0) 106 | return narry --------------------------------------------------------------------------------