├── .gitignore ├── .gitmodules ├── README.md ├── arguments └── __init__.py ├── assets └── main.png ├── auerror.py ├── data └── .gitkeep ├── data_utils ├── deepspeech_features │ ├── README.md │ ├── deepspeech_features.py │ ├── deepspeech_store.py │ ├── extract_ds_features.py │ ├── extract_wav.py │ └── fea_win.py ├── easyportrait │ ├── create_teeth_mask.py │ ├── local_configs │ │ ├── __base__ │ │ │ ├── datasets │ │ │ │ ├── easyportrait_1024x1024.py │ │ │ │ ├── easyportrait_384x384.py │ │ │ │ └── easyportrait_512x512.py │ │ │ ├── default_runtime.py │ │ │ ├── models │ │ │ │ ├── bisenetv2.py │ │ │ │ ├── fcn_resnet50.py │ │ │ │ ├── fpn_resnet50.py │ │ │ │ ├── lraspp.py │ │ │ │ └── segformer.py │ │ │ └── schedules │ │ │ │ ├── schedule_10k_adamw.py │ │ │ │ ├── schedule_160k_adamw.py │ │ │ │ ├── schedule_20k_adamw.py │ │ │ │ ├── schedule_40k_adamw.py │ │ │ │ └── schedule_80k_adamw.py │ │ └── easyportrait_experiments_v2 │ │ │ ├── bisenet-fp │ │ │ └── bisenetv2-fp.py │ │ │ ├── bisenet-ps │ │ │ └── bisenetv2-ps.py │ │ │ ├── danet-fp │ │ │ └── danet-fp.py │ │ │ ├── danet-ps │ │ │ └── danet-ps.py │ │ │ ├── deeplab-fp │ │ │ └── deeplabv3-fp.py │ │ │ ├── deeplab-ps │ │ │ └── deeplabv3-ps.py │ │ │ ├── fastscnn-fp │ │ │ └── fastscnn-fp.py │ │ │ ├── fastscnn-ps │ │ │ └── fastscnn-ps.py │ │ │ ├── fcn-fp │ │ │ └── fcn-fp.py │ │ │ ├── fcn-ps │ │ │ └── fcn-ps.py │ │ │ ├── fpn-fp │ │ │ └── fpn-fp.py │ │ │ ├── fpn-ps │ │ │ └── fpn-ps.py │ │ │ ├── segformer-fp │ │ │ └── segformer-fp.py │ │ │ └── segformer-ps │ │ │ └── segformer-ps.py │ └── mmseg │ │ ├── .mim │ │ ├── configs │ │ └── tools │ │ ├── __init__.py │ │ ├── apis │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── test.py │ │ └── train.py │ │ ├── core │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── class_names.py │ │ │ ├── eval_hooks.py │ │ │ └── metrics.py │ │ ├── hook │ │ │ ├── __init__.py │ │ │ └── wandblogger_hook.py │ │ ├── optimizers │ │ │ ├── __init__.py │ │ │ └── layer_decay_optimizer_constructor.py │ │ ├── seg │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ └── sampler │ │ │ │ ├── __init__.py │ │ │ │ ├── base_pixel_sampler.py │ │ │ │ └── ohem_pixel_sampler.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── dist_util.py │ │ │ └── misc.py │ │ ├── datasets │ │ ├── __init__.py │ │ ├── ade.py │ │ ├── builder.py │ │ ├── chase_db1.py │ │ ├── cityscapes.py │ │ ├── coco_stuff.py │ │ ├── custom.py │ │ ├── dark_zurich.py │ │ ├── dataset_wrappers.py │ │ ├── drive.py │ │ ├── easy_portrait.py │ │ ├── easy_portrait_face_parsing.py │ │ ├── easy_portrait_portrait_segmentation.py │ │ ├── face.py │ │ ├── hrf.py │ │ ├── imagenets.py │ │ ├── isaid.py │ │ ├── isprs.py │ │ ├── lapa.py │ │ ├── loveda.py │ │ ├── night_driving.py │ │ ├── pascal_context.py │ │ ├── pipelines │ │ │ ├── __init__.py │ │ │ ├── compose.py │ │ │ ├── formating.py │ │ │ ├── formatting.py │ │ │ ├── loading.py │ │ │ ├── test_time_aug.py │ │ │ └── transforms.py │ │ ├── potsdam.py │ │ ├── samplers │ │ │ ├── __init__.py │ │ │ └── distributed_sampler.py │ │ ├── stare.py │ │ └── voc.py │ │ ├── models │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── beit.py │ │ │ ├── bisenetv1.py │ │ │ ├── bisenetv2.py │ │ │ ├── cgnet.py │ │ │ ├── erfnet.py │ │ │ ├── fast_scnn.py │ │ │ ├── hrnet.py │ │ │ ├── icnet.py │ │ │ ├── mae.py │ │ │ ├── mit.py │ │ │ ├── mobilenet_v2.py │ │ │ ├── mobilenet_v3.py │ │ │ ├── mscan.py │ │ │ ├── resnest.py │ │ │ ├── resnet.py │ │ │ ├── resnext.py │ │ │ ├── stdc.py │ │ │ ├── swin.py │ │ │ ├── timm_backbone.py │ │ │ ├── twins.py │ │ │ 
├── unet.py │ │ │ └── vit.py │ │ ├── builder.py │ │ ├── decode_heads │ │ │ ├── __init__.py │ │ │ ├── ann_head.py │ │ │ ├── apc_head.py │ │ │ ├── aspp_head.py │ │ │ ├── cascade_decode_head.py │ │ │ ├── cc_head.py │ │ │ ├── da_head.py │ │ │ ├── decode_head.py │ │ │ ├── dm_head.py │ │ │ ├── dnl_head.py │ │ │ ├── dpt_head.py │ │ │ ├── ema_head.py │ │ │ ├── enc_head.py │ │ │ ├── fcn_head.py │ │ │ ├── fpn_head.py │ │ │ ├── gc_head.py │ │ │ ├── ham_head.py │ │ │ ├── isa_head.py │ │ │ ├── knet_head.py │ │ │ ├── lraspp_head.py │ │ │ ├── nl_head.py │ │ │ ├── ocr_head.py │ │ │ ├── point_head.py │ │ │ ├── psa_head.py │ │ │ ├── psp_head.py │ │ │ ├── segformer_head.py │ │ │ ├── segmenter_mask_head.py │ │ │ ├── sep_aspp_head.py │ │ │ ├── sep_fcn_head.py │ │ │ ├── setr_mla_head.py │ │ │ ├── setr_up_head.py │ │ │ ├── stdc_head.py │ │ │ └── uper_head.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ ├── cross_entropy_loss.py │ │ │ ├── dice_loss.py │ │ │ ├── focal_loss.py │ │ │ ├── lovasz_loss.py │ │ │ ├── tversky_loss.py │ │ │ └── utils.py │ │ ├── necks │ │ │ ├── __init__.py │ │ │ ├── featurepyramid.py │ │ │ ├── fpn.py │ │ │ ├── ic_neck.py │ │ │ ├── jpu.py │ │ │ ├── mla_neck.py │ │ │ └── multilevel_neck.py │ │ ├── segmentors │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cascade_encoder_decoder.py │ │ │ └── encoder_decoder.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── embed.py │ │ │ ├── inverted_residual.py │ │ │ ├── make_divisible.py │ │ │ ├── res_layer.py │ │ │ ├── se_layer.py │ │ │ ├── self_attention_block.py │ │ │ ├── shape_convert.py │ │ │ └── up_conv_block.py │ │ ├── ops │ │ ├── __init__.py │ │ ├── encoding.py │ │ └── wrappers.py │ │ ├── utils │ │ ├── __init__.py │ │ ├── collect_env.py │ │ ├── logger.py │ │ ├── misc.py │ │ ├── set_env.py │ │ └── util_distribution.py │ │ └── version.py ├── face_parsing │ ├── logger.py │ ├── model.py │ ├── resnet.py │ └── test.py ├── face_tracking │ ├── 3DMM │ │ └── .gitkeep │ ├── __init__.py │ ├── convert_BFM.py │ ├── data_loader.py │ ├── face_tracker.py │ ├── facemodel.py │ ├── geo_transform.py │ ├── render_3dmm.py │ ├── render_land.py │ └── util.py ├── hubert.py ├── process.py ├── wav2mel.py ├── wav2mel_hparams.py └── wav2vec.py ├── encoding.py ├── environment.yml ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── metrics.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py └── motion_net.py ├── scripts ├── prepare.sh └── train_xx.sh ├── synthesize_fuse.py ├── train_face.py ├── train_fuse.py ├── train_mouth.py └── utils ├── audio_utils.py ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── loss_utils.py ├── sh_utils.py └── system_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | *.mp4 6 | *.pth 7 | 8 | data_utils/face_tracking/3DMM/* 9 | data_utils/face_parsing/79999_iter.pth 10 | 11 | *.pyc 12 | .vscode 13 | output* 14 | build 15 | gridencoder/gridencoder.egg-info 16 | diff_rasterization/diff_rast.egg-info 17 | diff_rasterization/dist 18 | tensorboard_3d 19 | screenshots 20 | 21 | data/* 22 | !*.gitkeep -------------------------------------------------------------------------------- /.gitmodules: 
-------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/ashawkey/diff-gaussian-rasterization.git 7 | -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/assets/main.png -------------------------------------------------------------------------------- /auerror.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The AUE evaluation is based on OpenFace (https://github.com/TadasBaltrusaitis/OpenFace). 3 | First, use OpenFace's FeatureExtraction to process the reconstructed video ("A_generated.mp4" for example) 4 | and the corresponding GT ("A_GT.mp4" for example) respectively. 5 | Then, run "python auerror.py A_generated A_GT" 6 | 7 | Default directory structure is: 8 | 9 | |--- auerror.py 10 | |--- OpenFace_2.2.0_win_x64 11 | |--- processed 12 | |--- ... 13 | 14 | ''' 15 | 16 | 17 | import pandas as pd 18 | import os 19 | import sys 20 | import numpy as np 21 | 22 | AUitems = [' AU01_r',' AU02_r', ' AU04_r', ' AU05_r', ' AU06_r', ' AU07_r', ' AU09_r', ' AU10_r', ' AU12_r', ' AU14_r', ' AU15_r', ' AU17_r', ' AU20_r', ' AU23_r', ' AU25_r', ' AU26_r', ' AU45_r'] 23 | 24 | df_1 = pd.read_csv(os.path.join('./OpenFace_2.2.0_win_x64/processed', sys.argv[1]+'.csv'))[AUitems] 25 | df_2 = pd.read_csv(os.path.join('./OpenFace_2.2.0_win_x64/processed', sys.argv[2]+'.csv'))[AUitems] 26 | 27 | error = (df_1-df_2)**2 28 | print(error.mean().sum()) 29 | 30 | 31 | AUitems_lower = [' AU10_r', ' AU12_r', ' AU14_r', ' AU15_r', ' AU17_r', ' AU20_r', ' AU23_r', ' AU25_r', ' AU26_r'] 32 | AUitems_upper = [' AU01_r',' AU02_r', ' AU04_r', ' AU05_r', ' AU06_r', ' AU07_r', ' AU09_r', ' AU45_r'] 33 | 34 | df_1 = pd.read_csv(os.path.join('./OpenFace_2.2.0_win_x64/processed', sys.argv[1]+'.csv'))[AUitems] 35 | df_2 = pd.read_csv(os.path.join('./OpenFace_2.2.0_win_x64/processed', sys.argv[2]+'.csv'))[AUitems] 36 | 37 | error_l = (df_1[AUitems_lower]-df_2[AUitems_lower])**2 38 | error_u = (df_1[AUitems_upper]-df_2[AUitems_upper])**2 39 | 40 | print('l:', error_l.mean().sum(), 'u', error_u.mean().sum()) -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/data/.gitkeep -------------------------------------------------------------------------------- /data_utils/deepspeech_features/README.md: -------------------------------------------------------------------------------- 1 | # Routines for DeepSpeech features processing 2 | Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
3 | 4 | ## Installation 5 | 6 | ``` 7 | pip3 install -r requirements.txt 8 | ``` 9 | 10 | ## Usage 11 | 12 | Generate wav files: 13 | ``` 14 | python3 extract_wav.py --in-video=<path-to-video-file-or-directory> 15 | ``` 16 | 17 | Generate files with DeepSpeech features: 18 | ``` 19 | python3 extract_ds_features.py --input=<path-to-wav-file-or-directory> 20 | ``` 21 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/extract_wav.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for extracting audio (16-bit, mono, 16000 Hz) from video file. 3 | """ 4 | 5 | import os 6 | import argparse 7 | import subprocess 8 | 9 | 10 | def parse_args(): 11 | """ 12 | Create python script parameters. 13 | 14 | Returns 15 | ------- 16 | ArgumentParser 17 | Resulted args. 18 | """ 19 | parser = argparse.ArgumentParser( 20 | description="Extract audio from video file", 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument( 23 | "--in-video", 24 | type=str, 25 | required=True, 26 | help="path to input video file or directory") 27 | parser.add_argument( 28 | "--out-audio", 29 | type=str, 30 | help="path to output audio file") 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def extract_audio(in_video, 37 | out_audio): 38 | """ 39 | Extract audio from a video file. 40 | 41 | Parameters 42 | ---------- 43 | in_video : str 44 | Path to input video file. 45 | out_audio : str 46 | Path to output audio file. 47 | """ 48 | if not out_audio: 49 | file_stem, _ = os.path.splitext(in_video) 50 | out_audio = file_stem + ".wav" 51 | # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" 52 | # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" 53 | # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" 54 | command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" 55 | subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) 56 | 57 | 58 | def main(): 59 | """ 60 | Main body of script.
61 | """ 62 | args = parse_args() 63 | in_video = os.path.expanduser(args.in_video) 64 | if not os.path.exists(in_video): 65 | raise Exception("Input file/directory doesn't exist: {}".format(in_video)) 66 | if os.path.isfile(in_video): 67 | extract_audio( 68 | in_video=in_video, 69 | out_audio=args.out_audio) 70 | else: 71 | video_file_paths = [] 72 | for file_name in os.listdir(in_video): 73 | if not os.path.isfile(os.path.join(in_video, file_name)): 74 | continue 75 | _, file_ext = os.path.splitext(file_name) 76 | if file_ext.lower() in (".mp4", ".mkv", ".avi"): 77 | video_file_path = os.path.join(in_video, file_name) 78 | video_file_paths.append(video_file_path) 79 | video_file_paths = sorted(video_file_paths) 80 | for video_file_path in video_file_paths: 81 | extract_audio( 82 | in_video=video_file_path, 83 | out_audio="") 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /data_utils/deepspeech_features/fea_win.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | net_output = np.load('french.ds.npy').reshape(-1, 29) 4 | win_size = 16 5 | zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) 6 | net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) 7 | windows = [] 8 | for window_index in range(0, net_output.shape[0] - win_size, 2): 9 | windows.append(net_output[window_index:window_index + win_size]) 10 | print(np.array(windows).shape) 11 | np.save('aud_french.npy', np.array(windows)) 12 | -------------------------------------------------------------------------------- /data_utils/easyportrait/create_teeth_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from argparse import ArgumentParser 3 | 4 | from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot 5 | 6 | import os 7 | import glob 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | def main(): 12 | parser = ArgumentParser() 13 | parser.add_argument('datset', help='Image file') 14 | parser.add_argument('--config', default="./data_utils/easyportrait/local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py", help='Config file') 15 | parser.add_argument('--checkpoint', default="./data_utils/easyportrait/fpn-fp-512.pth", help='Checkpoint file') 16 | 17 | args = parser.parse_args() 18 | 19 | # build the model from a config file and a checkpoint file 20 | model = init_segmentor(args.config, args.checkpoint, device='cuda:0') 21 | 22 | # test a single image 23 | dataset_path = os.path.join(args.datset, 'ori_imgs') 24 | out_path = os.path.join(args.datset, 'teeth_mask') 25 | os.makedirs(out_path, exist_ok=True) 26 | 27 | for file in tqdm(glob.glob(os.path.join(dataset_path, '*.jpg'))): 28 | result = inference_segmentor(model, file) 29 | result[0][result[0]!=7] = 0 30 | np.save(file.replace('jpg', 'npy').replace('ori_imgs', 'teeth_mask'), result[0].astype(np.bool_)) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_1024x1024.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'EasyPortraitDataset' 3 | data_root = 'path/to/data/EasyPortrait' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations'), 10 | dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), 11 | dict(type='Resize', img_scale=(1024, 1024)), 12 | 13 | # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 14 | dict(type='RandomFlip', prob=0.0), 15 | dict(type='PhotoMetricDistortion', 16 | brightness_delta=16, 17 | contrast_range=(0.5, 1.0), 18 | saturation_range=(0.5, 1.0), 19 | hue_delta=9), 20 | dict(type='Normalize', **img_norm_cfg), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | 25 | test_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict( 28 | type='MultiScaleFlipAug', 29 | img_scale=(1024, 1024), 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='Normalize', **img_norm_cfg), 34 | dict(type='ImageToTensor', keys=['img']), 35 | dict(type='Collect', keys=['img']), 36 | ]) 37 | ] 38 | 39 | data = dict( 40 | samples_per_gpu=4, 41 | workers_per_gpu=4, 42 | train=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | img_dir='images/train', 46 | ann_dir='annotations/train', 47 | pipeline=train_pipeline), 48 | val=dict( 49 | type=dataset_type, 50 | data_root=data_root, 51 | img_dir='images/val', 52 | ann_dir='annotations/val', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir='images/test', 58 | ann_dir='annotations/test', 59 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_384x384.py: 
-------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'EasyPortraitDataset' 3 | data_root = 'path/to/data/EasyPortrait' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations'), 10 | dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), 11 | dict(type='Resize', img_scale=(384, 384)), 12 | 13 | # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 14 | dict(type='RandomFlip', prob=0.0), 15 | dict(type='PhotoMetricDistortion', 16 | brightness_delta=16, 17 | contrast_range=(0.5, 1.0), 18 | saturation_range=(0.5, 1.0), 19 | hue_delta=9), 20 | dict(type='Normalize', **img_norm_cfg), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | 25 | test_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict( 28 | type='MultiScaleFlipAug', 29 | img_scale=(384, 384), 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='Normalize', **img_norm_cfg), 34 | dict(type='ImageToTensor', keys=['img']), 35 | dict(type='Collect', keys=['img']), 36 | ]) 37 | ] 38 | 39 | data = dict( 40 | samples_per_gpu=4, 41 | workers_per_gpu=4, 42 | train=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | img_dir='images/train', 46 | ann_dir='annotations/train', 47 | pipeline=train_pipeline), 48 | val=dict( 49 | type=dataset_type, 50 | data_root=data_root, 51 | img_dir='images/val', 52 | ann_dir='annotations/val', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir='images/test', 58 | ann_dir='annotations/test', 59 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/datasets/easyportrait_512x512.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'EasyPortraitDataset' 3 | data_root = 'path/to/data/EasyPortrait' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations'), 10 | dict(type='Pad', size=(1920, 1920), pad_val=0, seg_pad_val=255), 11 | dict(type='Resize', img_scale=(512, 512)), 12 | 13 | # We don't use RandomFlip, but need it in the code to fix error: https://github.com/open-mmlab/mmsegmentation/issues/231 14 | dict(type='RandomFlip', prob=0.0), 15 | dict(type='PhotoMetricDistortion', 16 | brightness_delta=16, 17 | contrast_range=(0.5, 1.0), 18 | saturation_range=(0.5, 1.0), 19 | hue_delta=9), 20 | dict(type='Normalize', **img_norm_cfg), 21 | dict(type='DefaultFormatBundle'), 22 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 23 | ] 24 | 25 | test_pipeline = [ 26 | dict(type='LoadImageFromFile'), 27 | dict( 28 | type='MultiScaleFlipAug', 29 | img_scale=(512, 512), 30 | flip=False, 31 | transforms=[ 32 | dict(type='Resize', keep_ratio=True), 33 | dict(type='Normalize', **img_norm_cfg), 34 | dict(type='ImageToTensor', keys=['img']), 35 | dict(type='Collect', keys=['img']), 36 | ]) 37 | ] 38 | 39 | data = dict( 40 | samples_per_gpu=4, 41 | workers_per_gpu=4, 42 | train=dict( 43 | type=dataset_type, 44 | data_root=data_root, 45 | img_dir='images/train', 
46 | ann_dir='annotations/train', 47 | pipeline=train_pipeline), 48 | val=dict( 49 | type=dataset_type, 50 | data_root=data_root, 51 | img_dir='images/val', 52 | ann_dir='annotations/val', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir='images/test', 58 | ann_dir='annotations/test', 59 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/models/bisenetv2.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='BiSeNetV2', 8 | detail_channels=(64, 64, 128), 9 | semantic_channels=(16, 32, 64, 128), 10 | semantic_expansion_ratio=6, 11 | bga_channels=128, 12 | out_indices=(0, 1, 2, 3, 4), 13 | init_cfg=None, 14 | align_corners=False), 15 | decode_head=dict( 16 | type='FCNHead', 17 | in_channels=128, 18 | in_index=0, 19 | channels=1024, 20 | num_convs=1, 21 | concat_input=False, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | loss_decode=dict( 27 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | auxiliary_head=[ 29 | dict( 30 | type='FCNHead', 31 | in_channels=16, 32 | channels=16, 33 | num_convs=2, 34 | num_classes=19, 35 | in_index=1, 36 | norm_cfg=norm_cfg, 37 | concat_input=False, 38 | align_corners=False, 39 | loss_decode=dict( 40 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 41 | dict( 42 | type='FCNHead', 43 | in_channels=32, 44 | channels=64, 45 | num_convs=2, 46 | num_classes=19, 47 | in_index=2, 48 | norm_cfg=norm_cfg, 49 | concat_input=False, 50 | align_corners=False, 51 | loss_decode=dict( 52 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 53 | dict( 54 | type='FCNHead', 55 | in_channels=64, 56 | channels=256, 57 | num_convs=2, 58 | num_classes=19, 59 | in_index=3, 60 | norm_cfg=norm_cfg, 61 | concat_input=False, 62 | align_corners=False, 63 | loss_decode=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 65 | dict( 66 | type='FCNHead', 67 | in_channels=128, 68 | channels=1024, 69 | num_convs=2, 70 | num_classes=19, 71 | in_index=4, 72 | norm_cfg=norm_cfg, 73 | concat_input=False, 74 | align_corners=False, 75 | loss_decode=dict( 76 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 77 | ], 78 | # model training and testing settings 79 | train_cfg=dict(), 80 | test_cfg=dict(mode='whole')) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/models/fcn_resnet50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 
| type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 2, 4), 12 | strides=(1, 2, 1, 1), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | decode_head=dict( 18 | type='FCNHead', 19 | in_channels=2048, 20 | in_index=3, 21 | channels=512, 22 | num_convs=2, 23 | concat_input=True, 24 | dropout_ratio=0.1, 25 | num_classes=19, 26 | norm_cfg=norm_cfg, 27 | align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | auxiliary_head=dict( 31 | type='FCNHead', 32 | in_channels=1024, 33 | in_index=2, 34 | channels=256, 35 | num_convs=1, 36 | concat_input=False, 37 | dropout_ratio=0.1, 38 | num_classes=19, 39 | norm_cfg=norm_cfg, 40 | align_corners=False, 41 | loss_decode=dict( 42 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 43 | # model training and testing settings 44 | train_cfg=dict(), 45 | test_cfg=dict(mode='whole')) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/models/fpn_resnet50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[256, 512, 1024, 2048], 20 | out_channels=256, 21 | num_outs=4), 22 | decode_head=dict( 23 | type='FPNHead', 24 | in_channels=[256, 256, 256, 256], 25 | in_index=[0, 1, 2, 3], 26 | feature_strides=[4, 8, 16, 32], 27 | channels=128, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | # model training and testing settings 35 | train_cfg=dict(), 36 | test_cfg=dict(mode='whole')) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/models/lraspp.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | backbone=dict( 6 | type='MobileNetV3', 7 | arch='large', 8 | out_indices=(1, 3, 16), 9 | norm_cfg=norm_cfg), 10 | decode_head=dict( 11 | type='LRASPPHead', 12 | in_channels=(16, 24, 960), 13 | in_index=(0, 1, 2), 14 | channels=128, 15 | input_transform='multiple_select', 16 | dropout_ratio=0.1, 17 | num_classes=19, 18 | norm_cfg=norm_cfg, 19 | act_cfg=dict(type='ReLU'), 20 | align_corners=False, 21 | loss_decode=dict( 22 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 23 | # model training and testing settings 24 | train_cfg=dict(), 25 | test_cfg=dict(mode='whole')) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/models/segformer.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = 
dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='MixVisionTransformer', 8 | in_channels=3, 9 | embed_dims=32, 10 | num_stages=4, 11 | num_layers=[2, 2, 2, 2], 12 | num_heads=[1, 2, 5, 8], 13 | patch_sizes=[7, 3, 3, 3], 14 | sr_ratios=[8, 4, 2, 1], 15 | out_indices=(0, 1, 2, 3), 16 | mlp_ratio=4, 17 | qkv_bias=True, 18 | drop_rate=0.0, 19 | attn_drop_rate=0.0, 20 | drop_path_rate=0.1), 21 | decode_head=dict( 22 | type='SegformerHead', 23 | in_channels=[32, 64, 160, 256], 24 | in_index=[0, 1, 2, 3], 25 | channels=256, 26 | dropout_ratio=0.1, 27 | num_classes=19, 28 | norm_cfg=norm_cfg, 29 | align_corners=False, 30 | loss_decode=dict( 31 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 32 | # model training and testing settings 33 | train_cfg=dict(), 34 | test_cfg=dict(mode='whole')) -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/schedules/schedule_10k_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | 5 | # learning policy 6 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 7 | 8 | # runtime settings 9 | runner = dict(type='IterBasedRunner', max_iters=10000) 10 | checkpoint_config = dict(by_epoch=False, interval=2000) 11 | evaluation = dict(interval=2000, metric='mIoU') -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/schedules/schedule_160k_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/schedules/schedule_20k_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | 5 | # learning policy 6 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 7 | 8 | # runtime settings 9 | runner = dict(type='IterBasedRunner', max_iters=20000) 10 | checkpoint_config = dict(by_epoch=False, interval=2000) 11 | evaluation = dict(interval=2000, metric='mIoU') -------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/schedules/schedule_40k_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 
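The model, dataset, runtime, and schedule fragments under `__base__` above are not used directly; a top-level experiment config (such as the `fpn-fp.py` referenced by `create_teeth_mask.py`) composes them via mmcv's `_base_` inheritance. A minimal sketch of such a composition — the exact base files and overrides used in the shipped `fpn-fp.py` are assumptions here:

```
# Hypothetical experiment config in the style of
# local_configs/easyportrait_experiments_v2/fpn-fp/fpn-fp.py (illustrative only).
# mmcv.Config loads every file listed in _base_, merges the resulting dicts,
# and lets keys defined in this file override the inherited ones.
_base_ = [
    '../../__base__/models/fpn_resnet50.py',            # FPN decode head on ResNet-50
    '../../__base__/datasets/easyportrait_512x512.py',  # EasyPortrait pipeline at 512x512
    '../../__base__/default_runtime.py',                 # logging / workflow settings
    '../../__base__/schedules/schedule_20k_adamw.py',    # AdamW + poly LR, 20k iterations
]

# Override the generic 19-class head to match EasyPortraitFPDataset's 8 classes
# (illustrative; the real config may differ).
model = dict(decode_head=dict(num_classes=8))
```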
-------------------------------------------------------------------------------- /data_utils/easyportrait/local_configs/__base__/schedules/schedule_80k_adamw.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/.mim/configs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/data_utils/easyportrait/mmseg/.mim/configs -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/.mim/tools: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/data_utils/easyportrait/mmseg/.mim/tools -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | from packaging.version import parse 6 | 7 | from .version import __version__, version_info 8 | 9 | MMCV_MIN = '1.3.13' 10 | MMCV_MAX = '1.8.0' 11 | 12 | 13 | def digit_version(version_str: str, length: int = 4): 14 | """Convert a version string into a tuple of integers. 15 | 16 | This method is usually used for comparing two versions. For pre-release 17 | versions: alpha < beta < rc. 18 | 19 | Args: 20 | version_str (str): The version string. 21 | length (int): The maximum number of version levels. Default: 4. 22 | 23 | Returns: 24 | tuple[int]: The version info in digits (integers). 25 | """ 26 | version = parse(version_str) 27 | assert version.release, f'failed to parse version {version_str}' 28 | release = list(version.release) 29 | release = release[:length] 30 | if len(release) < length: 31 | release = release + [0] * (length - len(release)) 32 | if version.is_prerelease: 33 | mapping = {'a': -3, 'b': -2, 'rc': -1} 34 | val = -4 35 | # version.pre can be None 36 | if version.pre: 37 | if version.pre[0] not in mapping: 38 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 39 | 'version checking may go wrong') 40 | else: 41 | val = mapping[version.pre[0]] 42 | release.extend([val, version.pre[-1]]) 43 | else: 44 | release.extend([val, 0]) 45 | 46 | elif version.is_postrelease: 47 | release.extend([1, version.post]) 48 | else: 49 | release.extend([0, 0]) 50 | return tuple(release) 51 | 52 | 53 | mmcv_min_version = digit_version(MMCV_MIN) 54 | mmcv_max_version = digit_version(MMCV_MAX) 55 | mmcv_version = digit_version(mmcv.__version__) 56 | 57 | 58 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ 59 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 60 | f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' 
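# Worked example of the version tuples produced by digit_version() above, which
# is what makes the MMCV range check work for pre-releases:
#   digit_version('1.7.0')    -> (1, 7, 0, 0, 0, 0)
#   digit_version('1.7.0rc1') -> (1, 7, 0, 0, -1, 1)   # 'rc' maps to -1, so 1.7.0rc1 < 1.7.0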
61 | 62 | __all__ = ['__version__', 'version_info', 'digit_version'] 63 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import (get_root_logger, init_random_seed, set_random_seed, 5 | train_segmentor) 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 9 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 10 | 'show_result_pyplot', 'init_random_seed' 11 | ] 12 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import (OPTIMIZER_BUILDERS, build_optimizer, 3 | build_optimizer_constructor) 4 | from .evaluation import * # noqa: F401, F403 5 | from .hook import * # noqa: F401, F403 6 | from .optimizers import * # noqa: F401, F403 7 | from .seg import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | 10 | __all__ = [ 11 | 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' 12 | ] 13 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS 5 | from mmcv.utils import Registry, build_from_cfg 6 | 7 | OPTIMIZER_BUILDERS = Registry( 8 | 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) 9 | 10 | 11 | def build_optimizer_constructor(cfg): 12 | constructor_type = cfg.get('type') 13 | if constructor_type in OPTIMIZER_BUILDERS: 14 | return build_from_cfg(cfg, OPTIMIZER_BUILDERS) 15 | elif constructor_type in MMCV_OPTIMIZER_BUILDERS: 16 | return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) 17 | else: 18 | raise KeyError(f'{constructor_type} is not registered ' 19 | 'in the optimizer builder registry.') 20 | 21 | 22 | def build_optimizer(model, cfg): 23 | optimizer_cfg = copy.deepcopy(cfg) 24 | constructor_type = optimizer_cfg.pop('constructor', 25 | 'DefaultOptimizerConstructor') 26 | paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) 27 | optim_constructor = build_optimizer_constructor( 28 | dict( 29 | type=constructor_type, 30 | optimizer_cfg=optimizer_cfg, 31 | paramwise_cfg=paramwise_cfg)) 32 | optimizer = optim_constructor(model) 33 | return optimizer 34 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .class_names import get_classes, get_palette 3 | from .eval_hooks import DistEvalHook, EvalHook 4 | from .metrics import (eval_metrics, intersect_and_union, mean_dice, 5 | mean_fscore, mean_iou, pre_eval_to_metrics) 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', 10 | 'intersect_and_union' 11 | ] 12 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/hook/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .wandblogger_hook import MMSegWandbHook 3 | 4 | __all__ = ['MMSegWandbHook'] 5 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .layer_decay_optimizer_constructor import ( 3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) 4 | 5 | __all__ = [ 6 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' 7 | ] 8 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_pixel_sampler 3 | from .sampler import BasePixelSampler, OHEMPixelSampler 4 | 5 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | PIXEL_SAMPLERS = Registry('pixel sampler') 5 | 6 | 7 | def build_pixel_sampler(cfg, **default_args): 8 | """Build pixel sampler for segmentation map.""" 9 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 10 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_pixel_sampler import BasePixelSampler 3 | from .ohem_pixel_sampler import OHEMPixelSampler 4 | 5 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | if not isinstance(self.context.loss_decode, nn.ModuleList): 67 | losses_decode = [self.context.loss_decode] 68 | else: 69 | losses_decode = self.context.loss_decode 70 | losses = 0.0 71 | for loss_module in losses_decode: 72 | losses += loss_module( 73 | seg_logit, 74 | seg_label, 75 | weight=None, 76 | ignore_index=self.context.ignore_index, 77 | reduction_override='none') 78 | 79 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 80 | _, sort_indices = losses[valid_mask].sort(descending=True) 81 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 
82 | 83 | seg_weight[valid_mask] = valid_seg_weight 84 | 85 | return seg_weight 86 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_util import check_dist_init, sync_random_seed 3 | from .misc import add_prefix 4 | 5 | __all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed'] 6 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def check_dist_init(): 9 | return dist.is_available() and dist.is_initialized() 10 | 11 | 12 | def sync_random_seed(seed=None, device='cuda'): 13 | """Make sure different ranks share the same seed. All workers must call 14 | this function, otherwise it will deadlock. This method is generally used in 15 | `DistributedSampler`, because the seed should be identical across all 16 | processes in the distributed group. 17 | 18 | In distributed sampling, different ranks should sample non-overlapped 19 | data in the dataset. Therefore, this function is used to make sure that 20 | each rank shuffles the data indices in the same order based 21 | on the same seed. Then different ranks could use different indices 22 | to select non-overlapped data from the same data list. 23 | 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | 32 | if seed is None: 33 | seed = np.random.randint(2**31) 34 | assert isinstance(seed, int) 35 | 36 | rank, world_size = get_dist_info() 37 | 38 | if world_size == 1: 39 | return seed 40 | 41 | if rank == 0: 42 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 43 | else: 44 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 45 | dist.broadcast(random_num, src=0) 46 | return random_num.item() 47 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def add_prefix(inputs, prefix): 3 | """Add prefix for dict. 4 | 5 | Args: 6 | inputs (dict): The input dict with str keys. 7 | prefix (str): The prefix to add. 8 | 9 | Returns: 10 | 11 | dict: The dict with keys updated with ``prefix``. 12 | """ 13 | 14 | outputs = dict() 15 | for name, value in inputs.items(): 16 | outputs[f'{prefix}.{name}'] = value 17 | 18 | return outputs 19 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
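# Each dataset module imported below registers its class in the DATASETS registry
# via @DATASETS.register_module(), so the string names used in the dataset configs
# (e.g. dataset_type = 'EasyPortraitDataset' in easyportrait_512x512.py) are
# resolved to these classes by build_dataset() at runtime.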
2 | from .ade import ADE20KDataset 3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 4 | from .chase_db1 import ChaseDB1Dataset 5 | from .cityscapes import CityscapesDataset 6 | from .coco_stuff import COCOStuffDataset 7 | from .custom import CustomDataset 8 | from .dark_zurich import DarkZurichDataset 9 | from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, 10 | RepeatDataset) 11 | from .drive import DRIVEDataset 12 | from .face import FaceOccludedDataset 13 | from .hrf import HRFDataset 14 | from .imagenets import (ImageNetSDataset, LoadImageNetSAnnotations, 15 | LoadImageNetSImageFromFile) 16 | from .isaid import iSAIDDataset 17 | from .isprs import ISPRSDataset 18 | from .loveda import LoveDADataset 19 | from .night_driving import NightDrivingDataset 20 | from .pascal_context import PascalContextDataset, PascalContextDataset59 21 | from .potsdam import PotsdamDataset 22 | from .stare import STAREDataset 23 | from .voc import PascalVOCDataset 24 | from .easy_portrait import EasyPortraitDataset 25 | from .lapa import LaPaDataset 26 | from .easy_portrait_face_parsing import EasyPortraitFPDataset, EasyPortraitFPDatasetCross 27 | from .easy_portrait_portrait_segmentation import EasyPortraitPSDataset 28 | 29 | __all__ = [ 30 | 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 31 | 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', 32 | 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 33 | 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 34 | 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', 35 | 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', 36 | 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'FaceOccludedDataset', 37 | 'ImageNetSDataset', 'LoadImageNetSAnnotations', 38 | 'LoadImageNetSImageFromFile', 'EasyPortraitDataset', 'LaPaDataset', 39 | 'EasyPortraitFPDataset', 'EasyPortraitPSDataset', 'EasyPortraitFPDatasetCross', 40 | ] 41 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class ChaseDB1Dataset(CustomDataset): 9 | """Chase_db1 dataset. 10 | 11 | In segmentation map annotation for Chase_db1, 0 stands for background, 12 | which is included in 2 categories. ``reduce_zero_label`` is fixed to False. 13 | The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_1stHO.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(ChaseDB1Dataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_1stHO.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/dark_zurich.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .builder import DATASETS 3 | from .cityscapes import CityscapesDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class DarkZurichDataset(CityscapesDataset): 8 | """DarkZurichDataset dataset.""" 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__( 12 | img_suffix='_rgb_anon.png', 13 | seg_map_suffix='_gt_labelTrainIds.png', 14 | **kwargs) 15 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class DRIVEDataset(CustomDataset): 9 | """DRIVE dataset. 10 | 11 | In segmentation map annotation for DRIVE, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_manual1.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(DRIVEDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_manual1.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/easy_portrait.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class EasyPortraitDataset(CustomDataset): 13 | """EasyPortrait dataset. 14 | 15 | In segmentation map annotation for EasyPortrait, 0 stands for background, 16 | which is included in 9 categories. ``reduce_zero_label`` is fixed to False. 17 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 18 | '.png'. 19 | """ 20 | 21 | CLASSES = ('background', 'person', 'skin', 22 | 'left brow', 'right brow', 'left eye', 23 | 'right eye', 'lips', 'teeth') 24 | 25 | PALETTE = [[0, 0, 0], [223, 87, 188], [160, 221, 255], 26 | [130, 106, 237], [200, 121, 255], [255, 183, 255], 27 | [0, 144, 193], [113, 137, 255], [230, 232, 230]] 28 | 29 | def __init__(self, **kwargs): 30 | super(EasyPortraitDataset, self).__init__( 31 | img_suffix='.jpg', 32 | seg_map_suffix='.png', 33 | reduce_zero_label=False, 34 | **kwargs) 35 | assert self.file_client.exists(self.img_dir) -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/easy_portrait_face_parsing.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class EasyPortraitFPDataset(CustomDataset): 13 | """EasyPortraitFPDataset dataset. 14 | 15 | In segmentation map annotation for EasyPortrait, 0 stands for background, 16 | which is included in 9 categories. ``reduce_zero_label`` is fixed to False. 17 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 18 | '.png'. 
19 | """ 20 | 21 | CLASSES = ('background', 'skin', 22 | 'left brow', 'right brow', 'left eye', 23 | 'right eye', 'lips', 'teeth') 24 | 25 | PALETTE = [[0, 0, 0], [160, 221, 255], 26 | [130, 106, 237], [200, 121, 255], [255, 183, 255], 27 | [0, 144, 193], [113, 137, 255], [230, 232, 230]] 28 | 29 | def __init__(self, **kwargs): 30 | super(EasyPortraitFPDataset, self).__init__( 31 | img_suffix='.jpg', 32 | seg_map_suffix='.png', 33 | reduce_zero_label=False, 34 | **kwargs) 35 | assert self.file_client.exists(self.img_dir) 36 | 37 | @DATASETS.register_module() 38 | class EasyPortraitFPDatasetCross(CustomDataset): 39 | """EasyPortraitFPDatasetCross dataset. 40 | 41 | In segmentation map annotation for EasyPortrait, 0 stands for background, 42 | which is included in 9 categories. ``reduce_zero_label`` is fixed to False. 43 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 44 | '.png'. 45 | """ 46 | 47 | CLASSES = ('background', 'left brow', 'right brow', 'left eye', 'right eye', 'lips') 48 | PALETTE = [[0, 0, 0], [160, 221, 255], 49 | [130, 106, 237], [200, 121, 255], [255, 183, 255], 50 | [0, 144, 193]] 51 | 52 | def __init__(self, **kwargs): 53 | super(EasyPortraitFPDatasetCross, self).__init__( 54 | img_suffix='.jpg', 55 | seg_map_suffix='.png', 56 | reduce_zero_label=False, 57 | **kwargs) 58 | assert self.file_client.exists(self.img_dir) -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/easy_portrait_portrait_segmentation.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class EasyPortraitPSDataset(CustomDataset): 13 | """EasyPortrait dataset. 14 | 15 | In segmentation map annotation for EasyPortrait, 0 stands for background, 16 | which is included in 9 categories. ``reduce_zero_label`` is fixed to False. 17 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 18 | '.png'. 19 | """ 20 | 21 | CLASSES = ('background', 'person') 22 | 23 | PALETTE = [[0, 0, 0], [160, 221, 255]] 24 | 25 | def __init__(self, **kwargs): 26 | super(EasyPortraitPSDataset, self).__init__( 27 | img_suffix='.jpg', 28 | seg_map_suffix='.png', 29 | reduce_zero_label=False, 30 | **kwargs) 31 | assert self.file_client.exists(self.img_dir) -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/face.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class FaceOccludedDataset(CustomDataset): 10 | """Face Occluded dataset. 11 | 12 | Args: 13 | split (str): Split txt file for Pascal VOC. 
14 | """ 15 | 16 | CLASSES = ('background', 'face') 17 | 18 | PALETTE = [[0, 0, 0], [128, 0, 0]] 19 | 20 | def __init__(self, split, **kwargs): 21 | super(FaceOccludedDataset, self).__init__( 22 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 23 | assert osp.exists(self.img_dir) and self.split is not None 24 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/hrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class HRFDataset(CustomDataset): 9 | """HRF dataset. 10 | 11 | In segmentation map annotation for HRF, 0 stands for background, which is 12 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 13 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(HRFDataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/isaid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | import mmcv 4 | from mmcv.utils import print_log 5 | 6 | from ..utils import get_root_logger 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class iSAIDDataset(CustomDataset): 13 | """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images 14 | In segmentation map annotation for iSAID dataset, which is included 15 | in 16 categories. ``reduce_zero_label`` is fixed to False. The 16 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 17 | '_manual1.png'. 18 | """ 19 | 20 | CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond', 21 | 'tennis_court', 'basketball_court', 'Ground_Track_Field', 22 | 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', 23 | 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', 24 | 'Harbor') 25 | 26 | PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], 27 | [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], 28 | [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], 29 | [0, 127, 191], [0, 127, 255], [0, 100, 155]] 30 | 31 | def __init__(self, **kwargs): 32 | super(iSAIDDataset, self).__init__( 33 | img_suffix='.png', 34 | seg_map_suffix='.png', 35 | ignore_index=255, 36 | **kwargs) 37 | assert self.file_client.exists(self.img_dir) 38 | 39 | def load_annotations(self, 40 | img_dir, 41 | img_suffix, 42 | ann_dir, 43 | seg_map_suffix=None, 44 | split=None): 45 | """Load annotation from directory. 46 | 47 | Args: 48 | img_dir (str): Path to image directory 49 | img_suffix (str): Suffix of images. 50 | ann_dir (str|None): Path to annotation directory. 51 | seg_map_suffix (str|None): Suffix of segmentation maps. 52 | split (str|None): Split txt file. If split is specified, only file 53 | with suffix in the splits will be loaded. Otherwise, all images 54 | in img_dir/ann_dir will be loaded. 
Default: None 55 | 56 | Returns: 57 | list[dict]: All image info of dataset. 58 | """ 59 | 60 | img_infos = [] 61 | if split is not None: 62 | with open(split) as f: 63 | for line in f: 64 | name = line.strip() 65 | img_info = dict(filename=name + img_suffix) 66 | if ann_dir is not None: 67 | ann_name = name + '_instance_color_RGB' 68 | seg_map = ann_name + seg_map_suffix 69 | img_info['ann'] = dict(seg_map=seg_map) 70 | img_infos.append(img_info) 71 | else: 72 | for img in mmcv.scandir(img_dir, img_suffix, recursive=True): 73 | img_info = dict(filename=img) 74 | if ann_dir is not None: 75 | seg_img = img 76 | seg_map = seg_img.replace( 77 | img_suffix, '_instance_color_RGB' + seg_map_suffix) 78 | img_info['ann'] = dict(seg_map=seg_map) 79 | img_infos.append(img_info) 80 | 81 | print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) 82 | return img_infos 83 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/isprs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | from .custom import CustomDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class ISPRSDataset(CustomDataset): 8 | """ISPRS dataset. 9 | 10 | In segmentation map annotation for ISPRS, 0 is the ignore index. 11 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 12 | ``seg_map_suffix`` are both fixed to '.png'. 13 | """ 14 | CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', 15 | 'car', 'clutter') 16 | 17 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], 18 | [255, 255, 0], [255, 0, 0]] 19 | 20 | def __init__(self, **kwargs): 21 | super(ISPRSDataset, self).__init__( 22 | img_suffix='.png', 23 | seg_map_suffix='.png', 24 | reduce_zero_label=True, 25 | **kwargs) 26 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/lapa.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .builder import DATASETS 8 | from .custom import CustomDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class LaPaDataset(CustomDataset): 13 | """LaPa dataset. 14 | 15 | In segmentation map annotation for LaPa, 0 stands for background, 16 | which is included in 11 categories. ``reduce_zero_label`` is fixed to False. 17 | The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to 18 | '.png'. 19 | """ 20 | 21 | CLASSES = ('background', 'skin', 'left eyebrow', 22 | 'right eyebrow', 'left eye', 'right eye', 23 | 'nose', 'upper lip', 'inner mouth', 'lower lip', 'hair') 24 | 25 | PALETTE = [[0, 0, 0], [0, 153, 255], [102, 255, 153], 26 | [0, 204, 153], [255, 255, 102], [255, 255, 204], 27 | [255, 153, 0], [255, 102, 255], [102, 0, 51], 28 | [255, 204, 255], [255, 0, 102]] 29 | 30 | def __init__(self, **kwargs): 31 | super(LaPaDataset, self).__init__( 32 | img_suffix='.jpg', 33 | seg_map_suffix='.png', 34 | reduce_zero_label=False, 35 | **kwargs) 36 | assert self.file_client.exists(self.img_dir) -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/loveda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
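The dataset classes above are only registered with the DATASETS registry; in normal use they are instantiated from a config dict rather than constructed directly. A minimal sketch, assuming a hypothetical data/LaPa layout on disk (the paths and pipeline are illustrative, not taken from a config in this repository):

from mmseg.datasets import build_dataset

# Hypothetical directory layout; LaPaDataset fixes img_suffix='.jpg' and
# seg_map_suffix='.png', so only the directories need to be given.
lapa_cfg = dict(
    type='LaPaDataset',
    data_root='data/LaPa',
    img_dir='train/images',
    ann_dir='train/labels',
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
    ])

lapa_train = build_dataset(lapa_cfg)  # resolved by name through DATASETS
print(len(lapa_train), lapa_train.CLASSES)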
2 | import os.path as osp 3 | 4 | import mmcv 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from .builder import DATASETS 9 | from .custom import CustomDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class LoveDADataset(CustomDataset): 14 | """LoveDA dataset. 15 | 16 | In segmentation map annotation for LoveDA, 0 is the ignore index. 17 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 18 | ``seg_map_suffix`` are both fixed to '.png'. 19 | """ 20 | CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', 21 | 'agricultural') 22 | 23 | PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], 24 | [159, 129, 183], [0, 255, 0], [255, 195, 128]] 25 | 26 | def __init__(self, **kwargs): 27 | super(LoveDADataset, self).__init__( 28 | img_suffix='.png', 29 | seg_map_suffix='.png', 30 | reduce_zero_label=True, 31 | **kwargs) 32 | 33 | def results2img(self, results, imgfile_prefix, indices=None): 34 | """Write the segmentation results to images. 35 | 36 | Args: 37 | results (list[ndarray]): Testing results of the 38 | dataset. 39 | imgfile_prefix (str): The filename prefix of the png files. 40 | If the prefix is "somepath/xxx", 41 | the png files will be named "somepath/xxx.png". 42 | indices (list[int], optional): Indices of input results, if not 43 | set, all the indices of the dataset will be used. 44 | Default: None. 45 | 46 | Returns: 47 | list[str: str]: result txt files which contains corresponding 48 | semantic segmentation images. 49 | """ 50 | 51 | mmcv.mkdir_or_exist(imgfile_prefix) 52 | result_files = [] 53 | for result, idx in zip(results, indices): 54 | 55 | filename = self.img_infos[idx]['filename'] 56 | basename = osp.splitext(osp.basename(filename))[0] 57 | 58 | png_filename = osp.join(imgfile_prefix, f'{basename}.png') 59 | 60 | # The index range of official requirement is from 0 to 6. 61 | output = Image.fromarray(result.astype(np.uint8)) 62 | output.save(png_filename) 63 | result_files.append(png_filename) 64 | 65 | return result_files 66 | 67 | def format_results(self, results, imgfile_prefix, indices=None): 68 | """Format the results into dir (standard format for LoveDA evaluation). 69 | 70 | Args: 71 | results (list): Testing results of the dataset. 72 | imgfile_prefix (str): The prefix of images files. It 73 | includes the file path and the prefix of filename, e.g., 74 | "a/b/prefix". 75 | indices (list[int], optional): Indices of input results, 76 | if not set, all the indices of the dataset will be used. 77 | Default: None. 78 | 79 | Returns: 80 | tuple: (result_files, tmp_dir), result_files is a list containing 81 | the image paths, tmp_dir is the temporal directory created 82 | for saving json/png files when img_prefix is not specified. 83 | """ 84 | if indices is None: 85 | indices = list(range(len(self))) 86 | 87 | assert isinstance(results, list), 'results must be a list.' 88 | assert isinstance(indices, list), 'indices must be a list.' 89 | 90 | result_files = self.results2img(results, imgfile_prefix, indices) 91 | 92 | return result_files 93 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/night_driving.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
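format_results above only materialises predictions as PNG files named after the corresponding input images. A usage sketch, assuming loveda_test is an already-built LoveDADataset for the test split and using all-zero maps as stand-ins for real predictions:

import numpy as np

# One fake HxW label map per test image, in dataset order.
fake_results = [np.zeros((1024, 1024), dtype=np.uint8)
                for _ in range(len(loveda_test))]
png_files = loveda_test.format_results(
    fake_results, imgfile_prefix='work_dirs/loveda_preds')
# png_files holds one path per image ('work_dirs/loveda_preds/<basename>.png'),
# the layout expected for LoveDA evaluation.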
2 | from .builder import DATASETS 3 | from .cityscapes import CityscapesDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class NightDrivingDataset(CityscapesDataset): 8 | """NightDrivingDataset dataset.""" 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__( 12 | img_suffix='_leftImg8bit.png', 13 | seg_map_suffix='_gtCoarse_labelTrainIds.png', 14 | **kwargs) 15 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .compose import Compose 3 | from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, 4 | Transpose, to_tensor) 5 | from .loading import LoadAnnotations, LoadImageFromFile 6 | from .test_time_aug import MultiScaleFlipAug 7 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, 8 | PhotoMetricDistortion, RandomCrop, RandomCutOut, 9 | RandomFlip, RandomMosaic, RandomRotate, Rerange, 10 | Resize, RGB2Gray, SegRescale) 11 | 12 | __all__ = [ 13 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 14 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 15 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 16 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 17 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', 18 | 'RandomMosaic' 19 | ] 20 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import collections 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..builder import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose(object): 11 | """Compose multiple transforms sequentially. 12 | 13 | Args: 14 | transforms (Sequence[dict | callable]): Sequence of transform object or 15 | config dict to be composed. 16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, collections.abc.Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError('transform must be callable or a dict') 29 | 30 | def __call__(self, data): 31 | """Call function to apply transforms sequentially. 32 | 33 | Args: 34 | data (dict): A result dict contains the data to transform. 35 | 36 | Returns: 37 | dict: Transformed data. 38 | """ 39 | 40 | for t in self.transforms: 41 | data = t(data) 42 | if data is None: 43 | return None 44 | return data 45 | 46 | def __repr__(self): 47 | format_string = self.__class__.__name__ + '(' 48 | for t in self.transforms: 49 | format_string += '\n' 50 | format_string += f' {t}' 51 | format_string += '\n)' 52 | return format_string 53 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
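Compose above accepts both config dicts (resolved through the PIPELINES registry) and plain callables in the same list. A small self-contained sketch on a synthetic image; the transform choice is arbitrary:

import numpy as np
from mmseg.datasets.pipelines import Compose

def add_shape_meta(results):              # plain callables are used as-is
    results['img_shape'] = results['img'].shape
    return results

pipeline = Compose([
    dict(type='Normalize',                # dicts go through PIPELINES
         mean=[123.675, 116.28, 103.53],
         std=[58.395, 57.12, 57.375],
         to_rgb=False),
    add_shape_meta,
    dict(type='ImageToTensor', keys=['img']),
])

sample = dict(img=np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
out = pipeline(sample)
print(out['img'].shape, out['img_shape'])  # torch.Size([3, 64, 64]) (64, 64, 3)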
2 | # flake8: noqa 3 | import warnings 4 | 5 | from .formatting import * 6 | 7 | warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be ' 8 | 'deprecated in 2021, please replace it with ' 9 | 'mmseg.datasets.pipelines.formatting.') 10 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/potsdam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | from .custom import CustomDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class PotsdamDataset(CustomDataset): 8 | """ISPRS Potsdam dataset. 9 | 10 | In segmentation map annotation for Potsdam dataset, 0 is the ignore index. 11 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 12 | ``seg_map_suffix`` are both fixed to '.png'. 13 | """ 14 | CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', 15 | 'car', 'clutter') 16 | 17 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], 18 | [255, 255, 0], [255, 0, 0]] 19 | 20 | def __init__(self, **kwargs): 21 | super(PotsdamDataset, self).__init__( 22 | img_suffix='.png', 23 | seg_map_suffix='.png', 24 | reduce_zero_label=True, 25 | **kwargs) 26 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from __future__ import division 3 | from typing import Iterator, Optional 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import DistributedSampler as _DistributedSampler 8 | 9 | from mmseg.core.utils import sync_random_seed 10 | from mmseg.utils import get_device 11 | 12 | 13 | class DistributedSampler(_DistributedSampler): 14 | """DistributedSampler inheriting from 15 | `torch.utils.data.DistributedSampler`. 16 | 17 | Args: 18 | datasets (Dataset): the dataset will be loaded. 19 | num_replicas (int, optional): Number of processes participating in 20 | distributed training. By default, world_size is retrieved from the 21 | current distributed group. 22 | rank (int, optional): Rank of the current process within num_replicas. 23 | By default, rank is retrieved from the current distributed group. 24 | shuffle (bool): If True (default), sampler will shuffle the indices. 25 | seed (int): random seed used to shuffle the sampler if 26 | :attr:`shuffle=True`. This number should be identical across all 27 | processes in the distributed group. Default: ``0``. 28 | """ 29 | 30 | def __init__(self, 31 | dataset: Dataset, 32 | num_replicas: Optional[int] = None, 33 | rank: Optional[int] = None, 34 | shuffle: bool = True, 35 | seed=0) -> None: 36 | super().__init__( 37 | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 38 | 39 | # In distributed sampling, different ranks should sample 40 | # non-overlapped data in the dataset. 
Therefore, this function 41 | # is used to make sure that each rank shuffles the data indices 42 | # in the same order based on the same seed. Then different ranks 43 | # could use different indices to select non-overlapped data from the 44 | # same data list. 45 | device = get_device() 46 | self.seed = sync_random_seed(seed, device) 47 | 48 | def __iter__(self) -> Iterator: 49 | """ 50 | Yields: 51 | Iterator: iterator of indices for rank. 52 | """ 53 | # deterministically shuffle based on epoch 54 | if self.shuffle: 55 | g = torch.Generator() 56 | # When :attr:`shuffle=True`, this ensures all replicas 57 | # use a different random ordering for each epoch. 58 | # Otherwise, the next iteration of this sampler will 59 | # yield the same ordering. 60 | g.manual_seed(self.epoch + self.seed) 61 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 62 | else: 63 | indices = torch.arange(len(self.dataset)).tolist() 64 | 65 | # add extra samples to make it evenly divisible 66 | indices += indices[:(self.total_size - len(indices))] 67 | assert len(indices) == self.total_size 68 | 69 | # subsample 70 | indices = indices[self.rank:self.total_size:self.num_replicas] 71 | assert len(indices) == self.num_samples 72 | 73 | return iter(indices) 74 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/stare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class STAREDataset(CustomDataset): 10 | """STARE dataset. 11 | 12 | In segmentation map annotation for STARE, 0 stands for background, which is 13 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 14 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 15 | '.ah.png'. 16 | """ 17 | 18 | CLASSES = ('background', 'vessel') 19 | 20 | PALETTE = [[120, 120, 120], [6, 230, 230]] 21 | 22 | def __init__(self, **kwargs): 23 | super(STAREDataset, self).__init__( 24 | img_suffix='.png', 25 | seg_map_suffix='.ah.png', 26 | reduce_zero_label=False, 27 | **kwargs) 28 | assert osp.exists(self.img_dir) 29 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class PascalVOCDataset(CustomDataset): 10 | """Pascal VOC dataset. 11 | 12 | Args: 13 | split (str): Split txt file for Pascal VOC. 
14 | """ 15 | 16 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 17 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 18 | 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 19 | 'train', 'tvmonitor') 20 | 21 | PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], 22 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], 23 | [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], 24 | [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], 25 | [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] 26 | 27 | def __init__(self, split, **kwargs): 28 | super(PascalVOCDataset, self).__init__( 29 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 30 | assert osp.exists(self.img_dir) and self.split is not None 31 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, 4 | build_head, build_loss, build_segmentor) 5 | from .decode_heads import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .necks import * # noqa: F401,F403 8 | from .segmentors import * # noqa: F401,F403 9 | 10 | __all__ = [ 11 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', 12 | 'build_head', 'build_loss', 'build_segmentor' 13 | ] 14 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .beit import BEiT 3 | from .bisenetv1 import BiSeNetV1 4 | from .bisenetv2 import BiSeNetV2 5 | from .cgnet import CGNet 6 | from .erfnet import ERFNet 7 | from .fast_scnn import FastSCNN 8 | from .hrnet import HRNet 9 | from .icnet import ICNet 10 | from .mae import MAE 11 | from .mit import MixVisionTransformer 12 | from .mobilenet_v2 import MobileNetV2 13 | from .mobilenet_v3 import MobileNetV3 14 | from .mscan import MSCAN 15 | from .resnest import ResNeSt 16 | from .resnet import ResNet, ResNetV1c, ResNetV1d 17 | from .resnext import ResNeXt 18 | from .stdc import STDCContextPathNet, STDCNet 19 | from .swin import SwinTransformer 20 | from .timm_backbone import TIMMBackbone 21 | from .twins import PCPVT, SVT 22 | from .unet import UNet 23 | from .vit import VisionTransformer 24 | 25 | __all__ = [ 26 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 27 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 28 | 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 29 | 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', 30 | 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'MSCAN' 31 | ] -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/backbones/timm_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | try: 3 | import timm 4 | except ImportError: 5 | timm = None 6 | 7 | from mmcv.cnn.bricks.registry import NORM_LAYERS 8 | from mmcv.runner import BaseModule 9 | 10 | from ..builder import BACKBONES 11 | 12 | 13 | @BACKBONES.register_module() 14 | class TIMMBackbone(BaseModule): 15 | """Wrapper to use backbones from timm library. More details can be found in 16 | `timm `_ . 17 | 18 | Args: 19 | model_name (str): Name of timm model to instantiate. 20 | pretrained (bool): Load pretrained weights if True. 21 | checkpoint_path (str): Path of checkpoint to load after 22 | model is initialized. 23 | in_channels (int): Number of input image channels. Default: 3. 24 | init_cfg (dict, optional): Initialization config dict 25 | **kwargs: Other timm & model specific arguments. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_name, 31 | features_only=True, 32 | pretrained=True, 33 | checkpoint_path='', 34 | in_channels=3, 35 | init_cfg=None, 36 | **kwargs, 37 | ): 38 | if timm is None: 39 | raise RuntimeError('timm is not installed') 40 | super(TIMMBackbone, self).__init__(init_cfg) 41 | if 'norm_layer' in kwargs: 42 | kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) 43 | self.timm_model = timm.create_model( 44 | model_name=model_name, 45 | features_only=features_only, 46 | pretrained=pretrained, 47 | in_chans=in_channels, 48 | checkpoint_path=checkpoint_path, 49 | **kwargs, 50 | ) 51 | 52 | # Make unused parameters None 53 | self.timm_model.global_pool = None 54 | self.timm_model.fc = None 55 | self.timm_model.classifier = None 56 | 57 | # Hack to use pretrained weights from timm 58 | if pretrained or checkpoint_path: 59 | self._is_init = True 60 | 61 | def forward(self, x): 62 | features = self.timm_model(x) 63 | return features 64 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
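TIMMBackbone above is registered like any in-tree backbone, so a timm model can be selected purely from config. A sketch, assuming the optional timm dependency is installed; resnet18 is an arbitrary example model:

import torch
from mmseg.models import build_backbone

timm_cfg = dict(
    type='TIMMBackbone',
    model_name='resnet18',
    features_only=True,   # return the multi-scale feature maps
    pretrained=False,     # avoid a weight download in this sketch
    in_channels=3)

backbone = build_backbone(timm_cfg)
feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])  # one entry per feature level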
2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry 7 | 8 | MODELS = Registry('models', parent=MMCV_MODELS) 9 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 10 | 11 | BACKBONES = MODELS 12 | NECKS = MODELS 13 | HEADS = MODELS 14 | LOSSES = MODELS 15 | SEGMENTORS = MODELS 16 | 17 | 18 | def build_backbone(cfg): 19 | """Build backbone.""" 20 | return BACKBONES.build(cfg) 21 | 22 | 23 | def build_neck(cfg): 24 | """Build neck.""" 25 | return NECKS.build(cfg) 26 | 27 | 28 | def build_head(cfg): 29 | """Build head.""" 30 | return HEADS.build(cfg) 31 | 32 | 33 | def build_loss(cfg): 34 | """Build loss.""" 35 | return LOSSES.build(cfg) 36 | 37 | 38 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 39 | """Build segmentor.""" 40 | if train_cfg is not None or test_cfg is not None: 41 | warnings.warn( 42 | 'train_cfg and test_cfg is deprecated, ' 43 | 'please specify them in model', UserWarning) 44 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 45 | 'train_cfg specified in both outer field and model field ' 46 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 47 | 'test_cfg specified in both outer field and model field ' 48 | return SEGMENTORS.build( 49 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 50 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ann_head import ANNHead 3 | from .apc_head import APCHead 4 | from .aspp_head import ASPPHead 5 | from .cc_head import CCHead 6 | from .da_head import DAHead 7 | from .dm_head import DMHead 8 | from .dnl_head import DNLHead 9 | from .dpt_head import DPTHead 10 | from .ema_head import EMAHead 11 | from .enc_head import EncHead 12 | from .fcn_head import FCNHead 13 | from .fpn_head import FPNHead 14 | from .gc_head import GCHead 15 | from .ham_head import LightHamHead 16 | from .isa_head import ISAHead 17 | from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator 18 | from .lraspp_head import LRASPPHead 19 | from .nl_head import NLHead 20 | from .ocr_head import OCRHead 21 | from .point_head import PointHead 22 | from .psa_head import PSAHead 23 | from .psp_head import PSPHead 24 | from .segformer_head import SegformerHead 25 | from .segmenter_mask_head import SegmenterMaskTransformerHead 26 | from .sep_aspp_head import DepthwiseSeparableASPPHead 27 | from .sep_fcn_head import DepthwiseSeparableFCNHead 28 | from .setr_mla_head import SETRMLAHead 29 | from .setr_up_head import SETRUPHead 30 | from .stdc_head import STDCHead 31 | from .uper_head import UPerHead 32 | 33 | __all__ = [ 34 | 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 35 | 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 36 | 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 37 | 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 38 | 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 39 | 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 40 | 'KernelUpdateHead', 'KernelUpdator', 'LightHamHead' 41 | ] -------------------------------------------------------------------------------- 
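build_segmentor above assembles a backbone and one of the decode heads registered above into a complete model. A minimal sketch of a config it accepts; the ResNet-18 backbone, channel sizes and class count are illustrative and not taken from any config in this repository:

import torch
from mmseg.models import build_segmentor

model_cfg = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='ResNetV1c',
        depth=18,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        norm_cfg=dict(type='BN', requires_grad=True)),
    decode_head=dict(
        type='FCNHead',
        in_channels=512,      # stage-4 width of ResNet-18
        in_index=3,
        channels=128,
        num_convs=1,
        num_classes=9,        # e.g. the 9 EasyPortrait classes
        norm_cfg=dict(type='BN', requires_grad=True),
        loss_decode=dict(type='CrossEntropyLoss', loss_weight=1.0)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

model = build_segmentor(model_cfg).eval()
with torch.no_grad():
    logits = model.forward_dummy(torch.randn(1, 3, 128, 128))
print(logits.shape)  # (1, 9, 128, 128): per-class logits at input resolution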
/data_utils/easyportrait/mmseg/models/decode_heads/cascade_decode_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from .decode_head import BaseDecodeHead 5 | 6 | 7 | class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): 8 | """Base class for cascade decode head used in 9 | :class:`CascadeEncoderDecoder.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) 13 | 14 | @abstractmethod 15 | def forward(self, inputs, prev_output): 16 | """Placeholder of forward function.""" 17 | pass 18 | 19 | def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, 20 | train_cfg): 21 | """Forward function for training. 22 | Args: 23 | inputs (list[Tensor]): List of multi-level img features. 24 | prev_output (Tensor): The output of previous decode head. 25 | img_metas (list[dict]): List of image info dict where each dict 26 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 27 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 28 | For details on the values of these keys see 29 | `mmseg/datasets/pipelines/formatting.py:Collect`. 30 | gt_semantic_seg (Tensor): Semantic segmentation masks 31 | used if the architecture supports semantic segmentation task. 32 | train_cfg (dict): The training config. 33 | 34 | Returns: 35 | dict[str, Tensor]: a dictionary of loss components 36 | """ 37 | seg_logits = self.forward(inputs, prev_output) 38 | losses = self.losses(seg_logits, gt_semantic_seg) 39 | 40 | return losses 41 | 42 | def forward_test(self, inputs, prev_output, img_metas, test_cfg): 43 | """Forward function for testing. 44 | 45 | Args: 46 | inputs (list[Tensor]): List of multi-level img features. 47 | prev_output (Tensor): The output of previous decode head. 48 | img_metas (list[dict]): List of image info dict where each dict 49 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 50 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 51 | For details on the values of these keys see 52 | `mmseg/datasets/pipelines/formatting.py:Collect`. 53 | test_cfg (dict): The testing config. 54 | 55 | Returns: 56 | Tensor: Output segmentation map. 57 | """ 58 | return self.forward(inputs, prev_output) 59 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/cc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | try: 8 | from mmcv.ops import CrissCrossAttention 9 | except ModuleNotFoundError: 10 | CrissCrossAttention = None 11 | 12 | 13 | @HEADS.register_module() 14 | class CCHead(FCNHead): 15 | """CCNet: Criss-Cross Attention for Semantic Segmentation. 16 | 17 | This head is the implementation of `CCNet 18 | `_. 19 | 20 | Args: 21 | recurrence (int): Number of recurrence of Criss Cross Attention 22 | module. Default: 2. 
23 | """ 24 | 25 | def __init__(self, recurrence=2, **kwargs): 26 | if CrissCrossAttention is None: 27 | raise RuntimeError('Please install mmcv-full for ' 28 | 'CrissCrossAttention ops') 29 | super(CCHead, self).__init__(num_convs=2, **kwargs) 30 | self.recurrence = recurrence 31 | self.cca = CrissCrossAttention(self.channels) 32 | 33 | def forward(self, inputs): 34 | """Forward function.""" 35 | x = self._transform_inputs(inputs) 36 | output = self.convs[0](x) 37 | for _ in range(self.recurrence): 38 | output = self.cca(output) 39 | output = self.convs[1](output) 40 | if self.concat_input: 41 | output = self.conv_cat(torch.cat([x, output], dim=1)) 42 | output = self.cls_seg(output) 43 | return output 44 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class FCNHead(BaseDecodeHead): 12 | """Fully Convolution Networks for Semantic Segmentation. 13 | 14 | This head is implemented of `FCNNet `_. 15 | 16 | Args: 17 | num_convs (int): Number of convs in the head. Default: 2. 18 | kernel_size (int): The kernel size for convs in the head. Default: 3. 19 | concat_input (bool): Whether concat the input and output of convs 20 | before classification layer. 21 | dilation (int): The dilation rate for convs in the head. Default: 1. 22 | """ 23 | 24 | def __init__(self, 25 | num_convs=2, 26 | kernel_size=3, 27 | concat_input=True, 28 | dilation=1, 29 | **kwargs): 30 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 31 | self.num_convs = num_convs 32 | self.concat_input = concat_input 33 | self.kernel_size = kernel_size 34 | super(FCNHead, self).__init__(**kwargs) 35 | if num_convs == 0: 36 | assert self.in_channels == self.channels 37 | 38 | conv_padding = (kernel_size // 2) * dilation 39 | convs = [] 40 | for i in range(num_convs): 41 | _in_channels = self.in_channels if i == 0 else self.channels 42 | convs.append( 43 | ConvModule( 44 | _in_channels, 45 | self.channels, 46 | kernel_size=kernel_size, 47 | padding=conv_padding, 48 | dilation=dilation, 49 | conv_cfg=self.conv_cfg, 50 | norm_cfg=self.norm_cfg, 51 | act_cfg=self.act_cfg)) 52 | 53 | if len(convs) == 0: 54 | self.convs = nn.Identity() 55 | else: 56 | self.convs = nn.Sequential(*convs) 57 | if self.concat_input: 58 | self.conv_cat = ConvModule( 59 | self.in_channels + self.channels, 60 | self.channels, 61 | kernel_size=kernel_size, 62 | padding=kernel_size // 2, 63 | conv_cfg=self.conv_cfg, 64 | norm_cfg=self.norm_cfg, 65 | act_cfg=self.act_cfg) 66 | 67 | def _forward_feature(self, inputs): 68 | """Forward function for feature maps before classifying each pixel with 69 | ``self.cls_seg`` fc. 70 | 71 | Args: 72 | inputs (list[Tensor]): List of multi-level img features. 73 | 74 | Returns: 75 | feats (Tensor): A tensor of shape (batch_size, self.channels, 76 | H, W) which is feature map for last layer of decoder head. 
77 | """ 78 | x = self._transform_inputs(inputs) 79 | feats = self.convs(x) 80 | if self.concat_input: 81 | feats = self.conv_cat(torch.cat([x, feats], dim=1)) 82 | return feats 83 | 84 | def forward(self, inputs): 85 | """Forward function.""" 86 | output = self._forward_feature(inputs) 87 | output = self.cls_seg(output) 88 | return output 89 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import Upsample, resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FPNHead(BaseDecodeHead): 13 | """Panoptic Feature Pyramid Networks. 14 | 15 | This head is the implementation of `Semantic FPN 16 | `_. 17 | 18 | Args: 19 | feature_strides (tuple[int]): The strides for input feature maps. 20 | stack_lateral. All strides suppose to be power of 2. The first 21 | one is of largest resolution. 22 | """ 23 | 24 | def __init__(self, feature_strides, **kwargs): 25 | super(FPNHead, self).__init__( 26 | input_transform='multiple_select', **kwargs) 27 | assert len(feature_strides) == len(self.in_channels) 28 | assert min(feature_strides) == feature_strides[0] 29 | self.feature_strides = feature_strides 30 | 31 | self.scale_heads = nn.ModuleList() 32 | for i in range(len(feature_strides)): 33 | head_length = max( 34 | 1, 35 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 36 | scale_head = [] 37 | for k in range(head_length): 38 | scale_head.append( 39 | ConvModule( 40 | self.in_channels[i] if k == 0 else self.channels, 41 | self.channels, 42 | 3, 43 | padding=1, 44 | conv_cfg=self.conv_cfg, 45 | norm_cfg=self.norm_cfg, 46 | act_cfg=self.act_cfg)) 47 | if feature_strides[i] != feature_strides[0]: 48 | scale_head.append( 49 | Upsample( 50 | scale_factor=2, 51 | mode='bilinear', 52 | align_corners=self.align_corners)) 53 | self.scale_heads.append(nn.Sequential(*scale_head)) 54 | 55 | def forward(self, inputs): 56 | 57 | x = self._transform_inputs(inputs) 58 | 59 | output = self.scale_heads[0](x[0]) 60 | for i in range(1, len(self.feature_strides)): 61 | # non inplace 62 | output = output + resize( 63 | self.scale_heads[i](x[i]), 64 | size=output.shape[2:], 65 | mode='bilinear', 66 | align_corners=self.align_corners) 67 | 68 | output = self.cls_seg(output) 69 | return output 70 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/gc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import ContextBlock 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class GCHead(FCNHead): 11 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 12 | 13 | This head is the implementation of `GCNet 14 | `_. 15 | 16 | Args: 17 | ratio (float): Multiplier of channels ratio. Default: 1/4. 18 | pooling_type (str): The pooling type of context aggregation. 19 | Options are 'att', 'avg'. Default: 'avg'. 20 | fusion_types (tuple[str]): The fusion type for feature fusion. 21 | Options are 'channel_add', 'channel_mul'. 
Default: ('channel_add',) 22 | """ 23 | 24 | def __init__(self, 25 | ratio=1 / 4., 26 | pooling_type='att', 27 | fusion_types=('channel_add', ), 28 | **kwargs): 29 | super(GCHead, self).__init__(num_convs=2, **kwargs) 30 | self.ratio = ratio 31 | self.pooling_type = pooling_type 32 | self.fusion_types = fusion_types 33 | self.gc_block = ContextBlock( 34 | in_channels=self.channels, 35 | ratio=self.ratio, 36 | pooling_type=self.pooling_type, 37 | fusion_types=self.fusion_types) 38 | 39 | def forward(self, inputs): 40 | """Forward function.""" 41 | x = self._transform_inputs(inputs) 42 | output = self.convs[0](x) 43 | output = self.gc_block(output) 44 | output = self.convs[1](output) 45 | if self.concat_input: 46 | output = self.conv_cat(torch.cat([x, output], dim=1)) 47 | output = self.cls_seg(output) 48 | return output 49 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/lraspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv import is_tuple_of 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | @HEADS.register_module() 13 | class LRASPPHead(BaseDecodeHead): 14 | """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. 15 | 16 | This head is the improved implementation of `Searching for MobileNetV3 17 | `_. 18 | 19 | Args: 20 | branch_channels (tuple[int]): The number of output channels in every 21 | each branch. Default: (32, 64). 22 | """ 23 | 24 | def __init__(self, branch_channels=(32, 64), **kwargs): 25 | super(LRASPPHead, self).__init__(**kwargs) 26 | if self.input_transform != 'multiple_select': 27 | raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' 28 | f'must be \'multiple_select\'. 
But received ' 29 | f'\'{self.input_transform}\'') 30 | assert is_tuple_of(branch_channels, int) 31 | assert len(branch_channels) == len(self.in_channels) - 1 32 | self.branch_channels = branch_channels 33 | 34 | self.convs = nn.Sequential() 35 | self.conv_ups = nn.Sequential() 36 | for i in range(len(branch_channels)): 37 | self.convs.add_module( 38 | f'conv{i}', 39 | nn.Conv2d( 40 | self.in_channels[i], branch_channels[i], 1, bias=False)) 41 | self.conv_ups.add_module( 42 | f'conv_up{i}', 43 | ConvModule( 44 | self.channels + branch_channels[i], 45 | self.channels, 46 | 1, 47 | norm_cfg=self.norm_cfg, 48 | act_cfg=self.act_cfg, 49 | bias=False)) 50 | 51 | self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) 52 | 53 | self.aspp_conv = ConvModule( 54 | self.in_channels[-1], 55 | self.channels, 56 | 1, 57 | norm_cfg=self.norm_cfg, 58 | act_cfg=self.act_cfg, 59 | bias=False) 60 | self.image_pool = nn.Sequential( 61 | nn.AvgPool2d(kernel_size=49, stride=(16, 20)), 62 | ConvModule( 63 | self.in_channels[2], 64 | self.channels, 65 | 1, 66 | act_cfg=dict(type='Sigmoid'), 67 | bias=False)) 68 | 69 | def forward(self, inputs): 70 | """Forward function.""" 71 | inputs = self._transform_inputs(inputs) 72 | 73 | x = inputs[-1] 74 | 75 | x = self.aspp_conv(x) * resize( 76 | self.image_pool(x), 77 | size=x.size()[2:], 78 | mode='bilinear', 79 | align_corners=self.align_corners) 80 | x = self.conv_up_input(x) 81 | 82 | for i in range(len(self.branch_channels) - 1, -1, -1): 83 | x = resize( 84 | x, 85 | size=inputs[i].size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | x = torch.cat([x, self.convs[i](inputs[i])], 1) 89 | x = self.conv_ups[i](x) 90 | 91 | return self.cls_seg(x) 92 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/nl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import NonLocal2d 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class NLHead(FCNHead): 11 | """Non-local Neural Networks. 12 | 13 | This head is the implementation of `NLNet 14 | `_. 15 | 16 | Args: 17 | reduction (int): Reduction factor of projection transform. Default: 2. 18 | use_scale (bool): Whether to scale pairwise_weight by 19 | sqrt(1/inter_channels). Default: True. 20 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 21 | 'dot_product'. Default: 'embedded_gaussian.'. 
22 | """ 23 | 24 | def __init__(self, 25 | reduction=2, 26 | use_scale=True, 27 | mode='embedded_gaussian', 28 | **kwargs): 29 | super(NLHead, self).__init__(num_convs=2, **kwargs) 30 | self.reduction = reduction 31 | self.use_scale = use_scale 32 | self.mode = mode 33 | self.nl_block = NonLocal2d( 34 | in_channels=self.channels, 35 | reduction=self.reduction, 36 | use_scale=self.use_scale, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | mode=self.mode) 40 | 41 | def forward(self, inputs): 42 | """Forward function.""" 43 | x = self._transform_inputs(inputs) 44 | output = self.convs[0](x) 45 | output = self.nl_block(output) 46 | output = self.convs[1](output) 47 | if self.concat_input: 48 | output = self.conv_cat(torch.cat([x, output], dim=1)) 49 | output = self.cls_seg(output) 50 | return output 51 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/sep_aspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | 6 | from mmseg.ops import resize 7 | from ..builder import HEADS 8 | from .aspp_head import ASPPHead, ASPPModule 9 | 10 | 11 | class DepthwiseSeparableASPPModule(ASPPModule): 12 | """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable 13 | conv.""" 14 | 15 | def __init__(self, **kwargs): 16 | super(DepthwiseSeparableASPPModule, self).__init__(**kwargs) 17 | for i, dilation in enumerate(self.dilations): 18 | if dilation > 1: 19 | self[i] = DepthwiseSeparableConvModule( 20 | self.in_channels, 21 | self.channels, 22 | 3, 23 | dilation=dilation, 24 | padding=dilation, 25 | norm_cfg=self.norm_cfg, 26 | act_cfg=self.act_cfg) 27 | 28 | 29 | @HEADS.register_module() 30 | class DepthwiseSeparableASPPHead(ASPPHead): 31 | """Encoder-Decoder with Atrous Separable Convolution for Semantic Image 32 | Segmentation. 33 | 34 | This head is the implementation of `DeepLabV3+ 35 | `_. 36 | 37 | Args: 38 | c1_in_channels (int): The input channels of c1 decoder. If is 0, 39 | the no decoder will be used. 40 | c1_channels (int): The intermediate channels of c1 decoder. 
41 | """ 42 | 43 | def __init__(self, c1_in_channels, c1_channels, **kwargs): 44 | super(DepthwiseSeparableASPPHead, self).__init__(**kwargs) 45 | assert c1_in_channels >= 0 46 | self.aspp_modules = DepthwiseSeparableASPPModule( 47 | dilations=self.dilations, 48 | in_channels=self.in_channels, 49 | channels=self.channels, 50 | conv_cfg=self.conv_cfg, 51 | norm_cfg=self.norm_cfg, 52 | act_cfg=self.act_cfg) 53 | if c1_in_channels > 0: 54 | self.c1_bottleneck = ConvModule( 55 | c1_in_channels, 56 | c1_channels, 57 | 1, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg) 61 | else: 62 | self.c1_bottleneck = None 63 | self.sep_bottleneck = nn.Sequential( 64 | DepthwiseSeparableConvModule( 65 | self.channels + c1_channels, 66 | self.channels, 67 | 3, 68 | padding=1, 69 | norm_cfg=self.norm_cfg, 70 | act_cfg=self.act_cfg), 71 | DepthwiseSeparableConvModule( 72 | self.channels, 73 | self.channels, 74 | 3, 75 | padding=1, 76 | norm_cfg=self.norm_cfg, 77 | act_cfg=self.act_cfg)) 78 | 79 | def forward(self, inputs): 80 | """Forward function.""" 81 | x = self._transform_inputs(inputs) 82 | aspp_outs = [ 83 | resize( 84 | self.image_pool(x), 85 | size=x.size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | ] 89 | aspp_outs.extend(self.aspp_modules(x)) 90 | aspp_outs = torch.cat(aspp_outs, dim=1) 91 | output = self.bottleneck(aspp_outs) 92 | if self.c1_bottleneck is not None: 93 | c1_output = self.c1_bottleneck(inputs[0]) 94 | output = resize( 95 | input=output, 96 | size=c1_output.shape[2:], 97 | mode='bilinear', 98 | align_corners=self.align_corners) 99 | output = torch.cat([output, c1_output], dim=1) 100 | output = self.sep_bottleneck(output) 101 | output = self.cls_seg(output) 102 | return output 103 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/sep_fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import DepthwiseSeparableConvModule 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @HEADS.register_module() 9 | class DepthwiseSeparableFCNHead(FCNHead): 10 | """Depthwise-Separable Fully Convolutional Network for Semantic 11 | Segmentation. 12 | 13 | This head is implemented according to `Fast-SCNN: Fast Semantic 14 | Segmentation Network `_. 15 | 16 | Args: 17 | in_channels(int): Number of output channels of FFM. 18 | channels(int): Number of middle-stage channels in the decode head. 19 | concat_input(bool): Whether to concatenate original decode input into 20 | the result of several consecutive convolution layers. 21 | Default: True. 22 | num_classes(int): Used to determine the dimension of 23 | final prediction tensor. 24 | in_index(int): Correspond with 'out_indices' in FastSCNN backbone. 25 | norm_cfg (dict | None): Config of norm layers. 26 | align_corners (bool): align_corners argument of F.interpolate. 27 | Default: False. 28 | loss_decode(dict): Config of loss type and some 29 | relevant additional options. 30 | dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is 31 | 'default', it will be the same as `act_cfg`. Default: None. 
32 | """ 33 | 34 | def __init__(self, dw_act_cfg=None, **kwargs): 35 | super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) 36 | self.convs[0] = DepthwiseSeparableConvModule( 37 | self.in_channels, 38 | self.channels, 39 | kernel_size=self.kernel_size, 40 | padding=self.kernel_size // 2, 41 | norm_cfg=self.norm_cfg, 42 | dw_act_cfg=dw_act_cfg) 43 | 44 | for i in range(1, self.num_convs): 45 | self.convs[i] = DepthwiseSeparableConvModule( 46 | self.channels, 47 | self.channels, 48 | kernel_size=self.kernel_size, 49 | padding=self.kernel_size // 2, 50 | norm_cfg=self.norm_cfg, 51 | dw_act_cfg=dw_act_cfg) 52 | 53 | if self.concat_input: 54 | self.conv_cat = DepthwiseSeparableConvModule( 55 | self.in_channels + self.channels, 56 | self.channels, 57 | kernel_size=self.kernel_size, 58 | padding=self.kernel_size // 2, 59 | norm_cfg=self.norm_cfg, 60 | dw_act_cfg=dw_act_cfg) 61 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/setr_mla_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import Upsample 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class SETRMLAHead(BaseDecodeHead): 13 | """Multi level feature aggretation head of SETR. 14 | 15 | MLA head of `SETR `_. 16 | 17 | Args: 18 | mlahead_channels (int): Channels of conv-conv-4x of multi-level feature 19 | aggregation. Default: 128. 20 | up_scale (int): The scale factor of interpolate. Default:4. 21 | """ 22 | 23 | def __init__(self, mla_channels=128, up_scale=4, **kwargs): 24 | super(SETRMLAHead, self).__init__( 25 | input_transform='multiple_select', **kwargs) 26 | self.mla_channels = mla_channels 27 | 28 | num_inputs = len(self.in_channels) 29 | 30 | # Refer to self.cls_seg settings of BaseDecodeHead 31 | assert self.channels == num_inputs * mla_channels 32 | 33 | self.up_convs = nn.ModuleList() 34 | for i in range(num_inputs): 35 | self.up_convs.append( 36 | nn.Sequential( 37 | ConvModule( 38 | in_channels=self.in_channels[i], 39 | out_channels=mla_channels, 40 | kernel_size=3, 41 | padding=1, 42 | norm_cfg=self.norm_cfg, 43 | act_cfg=self.act_cfg), 44 | ConvModule( 45 | in_channels=mla_channels, 46 | out_channels=mla_channels, 47 | kernel_size=3, 48 | padding=1, 49 | norm_cfg=self.norm_cfg, 50 | act_cfg=self.act_cfg), 51 | Upsample( 52 | scale_factor=up_scale, 53 | mode='bilinear', 54 | align_corners=self.align_corners))) 55 | 56 | def forward(self, inputs): 57 | inputs = self._transform_inputs(inputs) 58 | outs = [] 59 | for x, up_conv in zip(inputs, self.up_convs): 60 | outs.append(up_conv(x)) 61 | out = torch.cat(outs, dim=1) 62 | out = self.cls_seg(out) 63 | return out 64 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/setr_up_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | 5 | from mmseg.ops import Upsample 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class SETRUPHead(BaseDecodeHead): 12 | """Naive upsampling head and Progressive upsampling head of SETR. 13 | 14 | Naive or PUP head of `SETR `_. 15 | 16 | Args: 17 | norm_layer (dict): Config dict for input normalization. 18 | Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). 19 | num_convs (int): Number of decoder convolutions. Default: 1. 20 | up_scale (int): The scale factor of interpolate. Default:4. 21 | kernel_size (int): The kernel size of convolution when decoding 22 | feature information from backbone. Default: 3. 23 | init_cfg (dict | list[dict] | None): Initialization config dict. 24 | Default: dict( 25 | type='Constant', val=1.0, bias=0, layer='LayerNorm'). 26 | """ 27 | 28 | def __init__(self, 29 | norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), 30 | num_convs=1, 31 | up_scale=4, 32 | kernel_size=3, 33 | init_cfg=[ 34 | dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), 35 | dict( 36 | type='Normal', 37 | std=0.01, 38 | override=dict(name='conv_seg')) 39 | ], 40 | **kwargs): 41 | 42 | assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' 43 | 44 | super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) 45 | 46 | assert isinstance(self.in_channels, int) 47 | 48 | _, self.norm = build_norm_layer(norm_layer, self.in_channels) 49 | 50 | self.up_convs = nn.ModuleList() 51 | in_channels = self.in_channels 52 | out_channels = self.channels 53 | for _ in range(num_convs): 54 | self.up_convs.append( 55 | nn.Sequential( 56 | ConvModule( 57 | in_channels=in_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=1, 61 | padding=int(kernel_size - 1) // 2, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg), 64 | Upsample( 65 | scale_factor=up_scale, 66 | mode='bilinear', 67 | align_corners=self.align_corners))) 68 | in_channels = out_channels 69 | 70 | def forward(self, x): 71 | x = self._transform_inputs(x) 72 | 73 | n, c, h, w = x.shape 74 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 75 | x = self.norm(x) 76 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 77 | 78 | for up_conv in self.up_convs: 79 | x = up_conv(x) 80 | out = self.cls_seg(x) 81 | return out 82 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/decode_heads/stdc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class STDCHead(FCNHead): 11 | """This head is the implementation of `Rethinking BiSeNet For Real-time 12 | Semantic Segmentation `_. 13 | 14 | Args: 15 | boundary_threshold (float): The threshold of calculating boundary. 16 | Default: 0.1. 17 | """ 18 | 19 | def __init__(self, boundary_threshold=0.1, **kwargs): 20 | super(STDCHead, self).__init__(**kwargs) 21 | self.boundary_threshold = boundary_threshold 22 | # Using register buffer to make laplacian kernel on the same 23 | # device of `seg_label`. 
24 | self.register_buffer( 25 | 'laplacian_kernel', 26 | torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], 27 | dtype=torch.float32, 28 | requires_grad=False).reshape((1, 1, 3, 3))) 29 | self.fusion_kernel = torch.nn.Parameter( 30 | torch.tensor([[6. / 10], [3. / 10], [1. / 10]], 31 | dtype=torch.float32).reshape(1, 3, 1, 1), 32 | requires_grad=False) 33 | 34 | def losses(self, seg_logit, seg_label): 35 | """Compute Detail Aggregation Loss.""" 36 | # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv 37 | # parameters. However, it is a constant in original repo and other 38 | # codebase because it would not be added into computation graph 39 | # after threshold operation. 40 | seg_label = seg_label.to(self.laplacian_kernel) 41 | boundary_targets = F.conv2d( 42 | seg_label, self.laplacian_kernel, padding=1) 43 | boundary_targets = boundary_targets.clamp(min=0) 44 | boundary_targets[boundary_targets > self.boundary_threshold] = 1 45 | boundary_targets[boundary_targets <= self.boundary_threshold] = 0 46 | 47 | boundary_targets_x2 = F.conv2d( 48 | seg_label, self.laplacian_kernel, stride=2, padding=1) 49 | boundary_targets_x2 = boundary_targets_x2.clamp(min=0) 50 | 51 | boundary_targets_x4 = F.conv2d( 52 | seg_label, self.laplacian_kernel, stride=4, padding=1) 53 | boundary_targets_x4 = boundary_targets_x4.clamp(min=0) 54 | 55 | boundary_targets_x4_up = F.interpolate( 56 | boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') 57 | boundary_targets_x2_up = F.interpolate( 58 | boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') 59 | 60 | boundary_targets_x2_up[ 61 | boundary_targets_x2_up > self.boundary_threshold] = 1 62 | boundary_targets_x2_up[ 63 | boundary_targets_x2_up <= self.boundary_threshold] = 0 64 | 65 | boundary_targets_x4_up[ 66 | boundary_targets_x4_up > self.boundary_threshold] = 1 67 | boundary_targets_x4_up[ 68 | boundary_targets_x4_up <= self.boundary_threshold] = 0 69 | 70 | boundary_targets_pyramids = torch.stack( 71 | (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), 72 | dim=1) 73 | 74 | boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) 75 | boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids, 76 | self.fusion_kernel) 77 | 78 | boundary_targets_pyramid[ 79 | boundary_targets_pyramid > self.boundary_threshold] = 1 80 | boundary_targets_pyramid[ 81 | boundary_targets_pyramid <= self.boundary_threshold] = 0 82 | 83 | loss = super(STDCHead, self).losses(seg_logit, 84 | boundary_targets_pyramid.long()) 85 | return loss 86 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
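The Laplacian step in STDCHead.losses above converts a segmentation label map into binary boundary targets before the multi-scale fusion. A self-contained toy illustration of that first step, independent of the head itself:

import torch
import torch.nn.functional as F

laplacian = torch.tensor([-1., -1., -1., -1., 8., -1., -1., -1., -1.]).reshape(1, 1, 3, 3)

# A 1x1x6x6 'label map' with a 3x3 foreground square in the middle.
label = torch.zeros(1, 1, 6, 6)
label[..., 2:5, 2:5] = 1.0

boundary = F.conv2d(label, laplacian, padding=1).clamp(min=0)
boundary = (boundary > 0.1).float()   # same 0.1 threshold as boundary_threshold
print(boundary[0, 0].int())
# The outline of the square becomes 1; its filled centre and the background
# cancel to 0, which is exactly what the head then fuses across scales.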
2 | from .accuracy import Accuracy, accuracy 3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 4 | cross_entropy, mask_cross_entropy) 5 | from .dice_loss import DiceLoss 6 | from .focal_loss import FocalLoss 7 | from .lovasz_loss import LovaszLoss 8 | from .tversky_loss import TverskyLoss 9 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 10 | 11 | __all__ = [ 12 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 13 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 14 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 15 | 'FocalLoss', 'TverskyLoss' 16 | ] 17 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 11 | target (torch.Tensor): The target of each prediction, shape (N, , ...) 12 | ignore_index (int | None): The label index to be ignored. Default: None 13 | topk (int | tuple[int], optional): If the predictions in ``topk`` 14 | matches the target, the predictions will be regarded as 15 | correct ones. Defaults to 1. 16 | thresh (float, optional): If not None, predictions with scores under 17 | this threshold are considered incorrect. Default to None. 18 | 19 | Returns: 20 | float | tuple[float]: If the input ``topk`` is a single integer, 21 | the function will return a single float as accuracy. If 22 | ``topk`` is a tuple containing multiple integers, the 23 | function will return a tuple containing accuracies of 24 | each ``topk`` number. 25 | """ 26 | assert isinstance(topk, (int, tuple)) 27 | if isinstance(topk, int): 28 | topk = (topk, ) 29 | return_single = True 30 | else: 31 | return_single = False 32 | 33 | maxk = max(topk) 34 | if pred.size(0) == 0: 35 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 36 | return accu[0] if return_single else accu 37 | assert pred.ndim == target.ndim + 1 38 | assert pred.size(0) == target.size(0) 39 | assert maxk <= pred.size(1), \ 40 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 41 | pred_value, pred_label = pred.topk(maxk, dim=1) 42 | # transpose to shape (maxk, N, ...) 
43 | pred_label = pred_label.transpose(0, 1) 44 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 45 | if thresh is not None: 46 | # Only prediction values larger than thresh are counted as correct 47 | correct = correct & (pred_value > thresh).t() 48 | if ignore_index is not None: 49 | correct = correct[:, target != ignore_index] 50 | res = [] 51 | eps = torch.finfo(torch.float32).eps 52 | for k in topk: 53 | # Avoid causing ZeroDivisionError when all pixels 54 | # of an image are ignored 55 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps 56 | if ignore_index is not None: 57 | total_num = target[target != ignore_index].numel() + eps 58 | else: 59 | total_num = target.numel() + eps 60 | res.append(correct_k.mul_(100.0 / total_num)) 61 | return res[0] if return_single else res 62 | 63 | 64 | class Accuracy(nn.Module): 65 | """Accuracy calculation module.""" 66 | 67 | def __init__(self, topk=(1, ), thresh=None, ignore_index=None): 68 | """Module to calculate the accuracy. 69 | 70 | Args: 71 | topk (tuple, optional): The criterion used to calculate the 72 | accuracy. Defaults to (1,). 73 | thresh (float, optional): If not None, predictions with scores 74 | under this threshold are considered incorrect. Default to None. 75 | """ 76 | super().__init__() 77 | self.topk = topk 78 | self.thresh = thresh 79 | self.ignore_index = ignore_index 80 | 81 | def forward(self, pred, target): 82 | """Forward function to calculate accuracy. 83 | 84 | Args: 85 | pred (torch.Tensor): Prediction of models. 86 | target (torch.Tensor): Target for each prediction. 87 | 88 | Returns: 89 | tuple[float]: The accuracies under different topk criterions. 90 | """ 91 | return accuracy(pred, target, self.topk, self.thresh, 92 | self.ignore_index) 93 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .featurepyramid import Feature2Pyramid 3 | from .fpn import FPN 4 | from .ic_neck import ICNeck 5 | from .jpu import JPU 6 | from .mla_neck import MLANeck 7 | from .multilevel_neck import MultiLevelNeck 8 | 9 | __all__ = [ 10 | 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' 11 | ] 12 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/necks/featurepyramid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import build_norm_layer 4 | 5 | from ..builder import NECKS 6 | 7 | 8 | @NECKS.register_module() 9 | class Feature2Pyramid(nn.Module): 10 | """Feature2Pyramid. 11 | 12 | A neck structure connect ViT backbone and decoder_heads. 13 | 14 | Args: 15 | embed_dims (int): Embedding dimension. 16 | rescales (list[float]): Different sampling multiples were 17 | used to obtain pyramid features. Default: [4, 2, 1, 0.5]. 18 | norm_cfg (dict): Config dict for normalization layer. 19 | Default: dict(type='SyncBN', requires_grad=True). 
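    Example (minimal sketch; ``embed_dim`` and shapes are illustrative, and
    ``BN`` replaces the ``SyncBN`` default so it runs without distributed
    training):
        >>> import torch
        >>> neck = Feature2Pyramid(embed_dim=256, norm_cfg=dict(type='BN'))
        >>> feats = [torch.rand(1, 256, 16, 16) for _ in range(4)]
        >>> outs = neck(feats)
        >>> [o.shape[-1] for o in outs]  # rescaled by 4, 2, 1 and 0.5
        [64, 32, 16, 8]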
20 | """ 21 | 22 | def __init__(self, 23 | embed_dim, 24 | rescales=[4, 2, 1, 0.5], 25 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 26 | super(Feature2Pyramid, self).__init__() 27 | self.rescales = rescales 28 | self.upsample_4x = None 29 | for k in self.rescales: 30 | if k == 4: 31 | self.upsample_4x = nn.Sequential( 32 | nn.ConvTranspose2d( 33 | embed_dim, embed_dim, kernel_size=2, stride=2), 34 | build_norm_layer(norm_cfg, embed_dim)[1], 35 | nn.GELU(), 36 | nn.ConvTranspose2d( 37 | embed_dim, embed_dim, kernel_size=2, stride=2), 38 | ) 39 | elif k == 2: 40 | self.upsample_2x = nn.Sequential( 41 | nn.ConvTranspose2d( 42 | embed_dim, embed_dim, kernel_size=2, stride=2)) 43 | elif k == 1: 44 | self.identity = nn.Identity() 45 | elif k == 0.5: 46 | self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) 47 | elif k == 0.25: 48 | self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) 49 | else: 50 | raise KeyError(f'invalid {k} for feature2pyramid') 51 | 52 | def forward(self, inputs): 53 | assert len(inputs) == len(self.rescales) 54 | outputs = [] 55 | if self.upsample_4x is not None: 56 | ops = [ 57 | self.upsample_4x, self.upsample_2x, self.identity, 58 | self.downsample_2x 59 | ] 60 | else: 61 | ops = [ 62 | self.upsample_2x, self.identity, self.downsample_2x, 63 | self.downsample_4x 64 | ] 65 | for i in range(len(inputs)): 66 | outputs.append(ops[i](inputs[i])) 67 | return tuple(outputs) 68 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/necks/multilevel_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, xavier_init 4 | 5 | from mmseg.ops import resize 6 | from ..builder import NECKS 7 | 8 | 9 | @NECKS.register_module() 10 | class MultiLevelNeck(nn.Module): 11 | """MultiLevelNeck. 12 | 13 | A neck structure connect vit backbone and decoder_heads. 14 | 15 | Args: 16 | in_channels (List[int]): Number of input channels per scale. 17 | out_channels (int): Number of output channels (used at each scale). 18 | scales (List[float]): Scale factors for each input feature map. 19 | Default: [0.5, 1, 2, 4] 20 | norm_cfg (dict): Config dict for normalization layer. Default: None. 21 | act_cfg (dict): Config dict for activation layer in ConvModule. 22 | Default: None. 
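    Example (minimal sketch; channel counts and shapes are illustrative):
        >>> import torch
        >>> neck = MultiLevelNeck([256, 256, 256, 256], out_channels=256)
        >>> feats = [torch.rand(1, 256, 32, 32) for _ in range(4)]
        >>> outs = neck(feats)
        >>> [o.shape[-1] for o in outs]  # scaled by 0.5, 1, 2 and 4
        [16, 32, 64, 128]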
23 | """ 24 | 25 | def __init__(self, 26 | in_channels, 27 | out_channels, 28 | scales=[0.5, 1, 2, 4], 29 | norm_cfg=None, 30 | act_cfg=None): 31 | super(MultiLevelNeck, self).__init__() 32 | assert isinstance(in_channels, list) 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.scales = scales 36 | self.num_outs = len(scales) 37 | self.lateral_convs = nn.ModuleList() 38 | self.convs = nn.ModuleList() 39 | for in_channel in in_channels: 40 | self.lateral_convs.append( 41 | ConvModule( 42 | in_channel, 43 | out_channels, 44 | kernel_size=1, 45 | norm_cfg=norm_cfg, 46 | act_cfg=act_cfg)) 47 | for _ in range(self.num_outs): 48 | self.convs.append( 49 | ConvModule( 50 | out_channels, 51 | out_channels, 52 | kernel_size=3, 53 | padding=1, 54 | stride=1, 55 | norm_cfg=norm_cfg, 56 | act_cfg=act_cfg)) 57 | 58 | # default init_weights for conv(msra) and norm in ConvModule 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | xavier_init(m, distribution='uniform') 63 | 64 | def forward(self, inputs): 65 | assert len(inputs) == len(self.in_channels) 66 | inputs = [ 67 | lateral_conv(inputs[i]) 68 | for i, lateral_conv in enumerate(self.lateral_convs) 69 | ] 70 | # for len(inputs) not equal to self.num_outs 71 | if len(inputs) == 1: 72 | inputs = [inputs[0] for _ in range(self.num_outs)] 73 | outs = [] 74 | for i in range(self.num_outs): 75 | x_resize = resize( 76 | inputs[i], scale_factor=self.scales[i], mode='bilinear') 77 | outs.append(self.convs[i](x_resize)) 78 | return tuple(outs) 79 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseSegmentor 3 | from .cascade_encoder_decoder import CascadeEncoderDecoder 4 | from .encoder_decoder import EncoderDecoder 5 | 6 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] 7 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/segmentors/cascade_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from torch import nn 3 | 4 | from mmseg.core import add_prefix 5 | from mmseg.ops import resize 6 | from .. import builder 7 | from ..builder import SEGMENTORS 8 | from .encoder_decoder import EncoderDecoder 9 | 10 | 11 | @SEGMENTORS.register_module() 12 | class CascadeEncoderDecoder(EncoderDecoder): 13 | """Cascade Encoder Decoder segmentors. 14 | 15 | CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of 16 | CascadeEncoderDecoder are cascaded. The output of previous decoder_head 17 | will be the input of next decoder_head. 
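    Note:
        ``decode_head`` should be a list of head configs whose length equals
        ``num_stages``; stage ``i`` takes the backbone features together with
        the output of stage ``i - 1`` (see ``encode_decode`` below).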
18 | """ 19 | 20 | def __init__(self, 21 | num_stages, 22 | backbone, 23 | decode_head, 24 | neck=None, 25 | auxiliary_head=None, 26 | train_cfg=None, 27 | test_cfg=None, 28 | pretrained=None, 29 | init_cfg=None): 30 | self.num_stages = num_stages 31 | super(CascadeEncoderDecoder, self).__init__( 32 | backbone=backbone, 33 | decode_head=decode_head, 34 | neck=neck, 35 | auxiliary_head=auxiliary_head, 36 | train_cfg=train_cfg, 37 | test_cfg=test_cfg, 38 | pretrained=pretrained, 39 | init_cfg=init_cfg) 40 | 41 | def _init_decode_head(self, decode_head): 42 | """Initialize ``decode_head``""" 43 | assert isinstance(decode_head, list) 44 | assert len(decode_head) == self.num_stages 45 | self.decode_head = nn.ModuleList() 46 | for i in range(self.num_stages): 47 | self.decode_head.append(builder.build_head(decode_head[i])) 48 | self.align_corners = self.decode_head[-1].align_corners 49 | self.num_classes = self.decode_head[-1].num_classes 50 | self.out_channels = self.decode_head[-1].out_channels 51 | 52 | def encode_decode(self, img, img_metas): 53 | """Encode images with backbone and decode into a semantic segmentation 54 | map of the same size as input.""" 55 | x = self.extract_feat(img) 56 | out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) 57 | for i in range(1, self.num_stages): 58 | out = self.decode_head[i].forward_test(x, out, img_metas, 59 | self.test_cfg) 60 | out = resize( 61 | input=out, 62 | size=img.shape[2:], 63 | mode='bilinear', 64 | align_corners=self.align_corners) 65 | return out 66 | 67 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 68 | """Run forward function and calculate loss for decode head in 69 | training.""" 70 | losses = dict() 71 | 72 | loss_decode = self.decode_head[0].forward_train( 73 | x, img_metas, gt_semantic_seg, self.train_cfg) 74 | 75 | losses.update(add_prefix(loss_decode, 'decode_0')) 76 | 77 | for i in range(1, self.num_stages): 78 | # forward test again, maybe unnecessary for most methods. 79 | if i == 1: 80 | prev_outputs = self.decode_head[0].forward_test( 81 | x, img_metas, self.test_cfg) 82 | else: 83 | prev_outputs = self.decode_head[i - 1].forward_test( 84 | x, prev_outputs, img_metas, self.test_cfg) 85 | loss_decode = self.decode_head[i].forward_train( 86 | x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) 87 | losses.update(add_prefix(loss_decode, f'decode_{i}')) 88 | 89 | return losses 90 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .embed import PatchEmbed 3 | from .inverted_residual import InvertedResidual, InvertedResidualV3 4 | from .make_divisible import make_divisible 5 | from .res_layer import ResLayer 6 | from .se_layer import SELayer 7 | from .self_attention_block import SelfAttentionBlock 8 | from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, 9 | nlc_to_nchw) 10 | from .up_conv_block import UpConvBlock 11 | 12 | __all__ = [ 13 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 14 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', 15 | 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' 16 | ] 17 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divisible by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by divisor. It can 8 | be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, means that the minimum value equal to the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than (1-min_ratio). 26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import build_conv_layer, build_norm_layer 3 | from mmcv.runner import Sequential 4 | from torch import nn as nn 5 | 6 | 7 | class ResLayer(Sequential): 8 | """ResLayer to build ResNet style backbone. 9 | 10 | Args: 11 | block (nn.Module): block used to build ResLayer. 12 | inplanes (int): inplanes of block. 13 | planes (int): planes of block. 14 | num_blocks (int): number of blocks. 15 | stride (int): stride of the first block. Default: 1 16 | avg_down (bool): Use AvgPool instead of stride conv when 17 | downsampling in the bottleneck. Default: False 18 | conv_cfg (dict): dictionary to construct and config conv layer. 19 | Default: None 20 | norm_cfg (dict): dictionary to construct and config norm layer. 21 | Default: dict(type='BN') 22 | multi_grid (int | None): Multi grid dilation rates of last 23 | stage. 
Default: None 24 | contract_dilation (bool): Whether contract first dilation of each layer 25 | Default: False 26 | """ 27 | 28 | def __init__(self, 29 | block, 30 | inplanes, 31 | planes, 32 | num_blocks, 33 | stride=1, 34 | dilation=1, 35 | avg_down=False, 36 | conv_cfg=None, 37 | norm_cfg=dict(type='BN'), 38 | multi_grid=None, 39 | contract_dilation=False, 40 | **kwargs): 41 | self.block = block 42 | 43 | downsample = None 44 | if stride != 1 or inplanes != planes * block.expansion: 45 | downsample = [] 46 | conv_stride = stride 47 | if avg_down: 48 | conv_stride = 1 49 | downsample.append( 50 | nn.AvgPool2d( 51 | kernel_size=stride, 52 | stride=stride, 53 | ceil_mode=True, 54 | count_include_pad=False)) 55 | downsample.extend([ 56 | build_conv_layer( 57 | conv_cfg, 58 | inplanes, 59 | planes * block.expansion, 60 | kernel_size=1, 61 | stride=conv_stride, 62 | bias=False), 63 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 64 | ]) 65 | downsample = nn.Sequential(*downsample) 66 | 67 | layers = [] 68 | if multi_grid is None: 69 | if dilation > 1 and contract_dilation: 70 | first_dilation = dilation // 2 71 | else: 72 | first_dilation = dilation 73 | else: 74 | first_dilation = multi_grid[0] 75 | layers.append( 76 | block( 77 | inplanes=inplanes, 78 | planes=planes, 79 | stride=stride, 80 | dilation=first_dilation, 81 | downsample=downsample, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | **kwargs)) 85 | inplanes = planes * block.expansion 86 | for i in range(1, num_blocks): 87 | layers.append( 88 | block( 89 | inplanes=inplanes, 90 | planes=planes, 91 | stride=1, 92 | dilation=dilation if multi_grid is None else multi_grid[i], 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | **kwargs)) 96 | super(ResLayer, self).__init__(*layers) 97 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/utils/se_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from .make_divisible import make_divisible 7 | 8 | 9 | class SELayer(nn.Module): 10 | """Squeeze-and-Excitation Module. 11 | 12 | Args: 13 | channels (int): The input (and output) channels of the SE layer. 14 | ratio (int): Squeeze ratio in SELayer, the intermediate channel will be 15 | ``int(channels/ratio)``. Default: 16. 16 | conv_cfg (None or dict): Config dict for convolution layer. 17 | Default: None, which means using conv2d. 18 | act_cfg (dict or Sequence[dict]): Config dict for activation layer. 19 | If act_cfg is a dict, two activation layers will be configured 20 | by this dict. If act_cfg is a sequence of dicts, the first 21 | activation layer will be configured by the first dict and the 22 | second activation layer will be configured by the second dict. 23 | Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, 24 | divisor=6.0)). 
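    Example (minimal sketch; shapes are illustrative):
        >>> import torch
        >>> se = SELayer(channels=64)
        >>> x = torch.rand(2, 64, 32, 32)
        >>> se(x).shape  # input rescaled channel-wise, shape preserved
        torch.Size([2, 64, 32, 32])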
25 | """ 26 | 27 | def __init__(self, 28 | channels, 29 | ratio=16, 30 | conv_cfg=None, 31 | act_cfg=(dict(type='ReLU'), 32 | dict(type='HSigmoid', bias=3.0, divisor=6.0))): 33 | super(SELayer, self).__init__() 34 | if isinstance(act_cfg, dict): 35 | act_cfg = (act_cfg, act_cfg) 36 | assert len(act_cfg) == 2 37 | assert mmcv.is_tuple_of(act_cfg, dict) 38 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 39 | self.conv1 = ConvModule( 40 | in_channels=channels, 41 | out_channels=make_divisible(channels // ratio, 8), 42 | kernel_size=1, 43 | stride=1, 44 | conv_cfg=conv_cfg, 45 | act_cfg=act_cfg[0]) 46 | self.conv2 = ConvModule( 47 | in_channels=make_divisible(channels // ratio, 8), 48 | out_channels=channels, 49 | kernel_size=1, 50 | stride=1, 51 | conv_cfg=conv_cfg, 52 | act_cfg=act_cfg[1]) 53 | 54 | def forward(self, x): 55 | out = self.global_avgpool(x) 56 | out = self.conv1(out) 57 | out = self.conv2(out) 58 | return x * out 59 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/models/utils/shape_convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def nlc_to_nchw(x, hw_shape): 3 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 4 | 5 | Args: 6 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 7 | hw_shape (Sequence[int]): The height and width of output feature map. 8 | 9 | Returns: 10 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 11 | """ 12 | H, W = hw_shape 13 | assert len(x.shape) == 3 14 | B, L, C = x.shape 15 | assert L == H * W, 'The seq_len doesn\'t match H, W' 16 | return x.transpose(1, 2).reshape(B, C, H, W) 17 | 18 | 19 | def nchw_to_nlc(x): 20 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 21 | 22 | Args: 23 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 24 | 25 | Returns: 26 | Tensor: The output tensor of shape [N, L, C] after conversion. 27 | """ 28 | assert len(x.shape) == 4 29 | return x.flatten(2).transpose(1, 2).contiguous() 30 | 31 | 32 | def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): 33 | """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the 34 | reshaped tensor as the input of `module`, and the convert the output of 35 | `module`, whose shape is. 36 | 37 | [N, L, C], to [N, C, H, W]. 38 | 39 | Args: 40 | module (Callable): A callable object the takes a tensor 41 | with shape [N, L, C] as input. 42 | x (Tensor): The input tensor of shape [N, C, H, W]. 43 | contiguous: 44 | contiguous (Bool): Whether to make the tensor contiguous 45 | after each shape transform. 46 | 47 | Returns: 48 | Tensor: The output tensor of shape [N, C, H, W]. 49 | 50 | Example: 51 | >>> import torch 52 | >>> import torch.nn as nn 53 | >>> norm = nn.LayerNorm(4) 54 | >>> feature_map = torch.rand(4, 4, 5, 5) 55 | >>> output = nchw2nlc2nchw(norm, feature_map) 56 | """ 57 | B, C, H, W = x.shape 58 | if not contiguous: 59 | x = x.flatten(2).transpose(1, 2) 60 | x = module(x, **kwargs) 61 | x = x.transpose(1, 2).reshape(B, C, H, W) 62 | else: 63 | x = x.flatten(2).transpose(1, 2).contiguous() 64 | x = module(x, **kwargs) 65 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 66 | return x 67 | 68 | 69 | def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): 70 | """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. 
Use the 71 | reshaped tensor as the input of `module`, and convert the output of 72 | `module`, whose shape is. 73 | 74 | [N, C, H, W], to [N, L, C]. 75 | 76 | Args: 77 | module (Callable): A callable object the takes a tensor 78 | with shape [N, C, H, W] as input. 79 | x (Tensor): The input tensor of shape [N, L, C]. 80 | hw_shape: (Sequence[int]): The height and width of the 81 | feature map with shape [N, C, H, W]. 82 | contiguous (Bool): Whether to make the tensor contiguous 83 | after each shape transform. 84 | 85 | Returns: 86 | Tensor: The output tensor of shape [N, L, C]. 87 | 88 | Example: 89 | >>> import torch 90 | >>> import torch.nn as nn 91 | >>> conv = nn.Conv2d(16, 16, 3, 1, 1) 92 | >>> feature_map = torch.rand(4, 25, 16) 93 | >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) 94 | """ 95 | H, W = hw_shape 96 | assert len(x.shape) == 3 97 | B, L, C = x.shape 98 | assert L == H * W, 'The seq_len doesn\'t match H, W' 99 | if not contiguous: 100 | x = x.transpose(1, 2).reshape(B, C, H, W) 101 | x = module(x, **kwargs) 102 | x = x.flatten(2).transpose(1, 2) 103 | else: 104 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 105 | x = module(x, **kwargs) 106 | x = x.flatten(2).transpose(1, 2).contiguous() 107 | return x 108 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .encoding import Encoding 3 | from .wrappers import Upsample, resize 4 | 5 | __all__ = ['Upsample', 'resize', 'Encoding'] 6 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Encoding(nn.Module): 8 | """Encoding Layer: a learnable residual encoder. 9 | 10 | Input is of shape (batch_size, channels, height, width). 11 | Output is of shape (batch_size, num_codes, channels). 12 | 13 | Args: 14 | channels: dimension of the features or feature channels 15 | num_codes: number of code words 16 | """ 17 | 18 | def __init__(self, channels, num_codes): 19 | super(Encoding, self).__init__() 20 | # init codewords and smoothing factor 21 | self.channels, self.num_codes = channels, num_codes 22 | std = 1. 
/ ((num_codes * channels)**0.5) 23 | # [num_codes, channels] 24 | self.codewords = nn.Parameter( 25 | torch.empty(num_codes, channels, 26 | dtype=torch.float).uniform_(-std, std), 27 | requires_grad=True) 28 | # [num_codes] 29 | self.scale = nn.Parameter( 30 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 31 | requires_grad=True) 32 | 33 | @staticmethod 34 | def scaled_l2(x, codewords, scale): 35 | num_codes, channels = codewords.size() 36 | batch_size = x.size(0) 37 | reshaped_scale = scale.view((1, 1, num_codes)) 38 | expanded_x = x.unsqueeze(2).expand( 39 | (batch_size, x.size(1), num_codes, channels)) 40 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 41 | 42 | scaled_l2_norm = reshaped_scale * ( 43 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 44 | return scaled_l2_norm 45 | 46 | @staticmethod 47 | def aggregate(assignment_weights, x, codewords): 48 | num_codes, channels = codewords.size() 49 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 50 | batch_size = x.size(0) 51 | 52 | expanded_x = x.unsqueeze(2).expand( 53 | (batch_size, x.size(1), num_codes, channels)) 54 | encoded_feat = (assignment_weights.unsqueeze(3) * 55 | (expanded_x - reshaped_codewords)).sum(dim=1) 56 | return encoded_feat 57 | 58 | def forward(self, x): 59 | assert x.dim() == 4 and x.size(1) == self.channels 60 | # [batch_size, channels, height, width] 61 | batch_size = x.size(0) 62 | # [batch_size, height x width, channels] 63 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 64 | # assignment_weights: [batch_size, channels, num_codes] 65 | assignment_weights = F.softmax( 66 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 67 | # aggregate 68 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 69 | return encoded_feat 70 | 71 | def __repr__(self): 72 | repr_str = self.__class__.__name__ 73 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 74 | f'x{self.channels})' 75 | return repr_str 76 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super(Upsample, self).__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | from .misc import find_latest_checkpoint 5 | from .set_env import setup_multi_processes 6 | from .util_distribution import build_ddp, build_dp, get_device 7 | 8 | __all__ = [ 9 | 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 10 | 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device' 11 | ] 12 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmseg 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. 
If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmseg". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | 26 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os.path as osp 4 | import warnings 5 | 6 | 7 | def find_latest_checkpoint(path, suffix='pth'): 8 | """This function is for finding the latest checkpoint. 9 | 10 | It will be used when automatically resume, modified from 11 | https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py 12 | 13 | Args: 14 | path (str): The path to find checkpoints. 15 | suffix (str): File extension for the checkpoint. Defaults to pth. 16 | 17 | Returns: 18 | latest_path(str | None): File path of the latest checkpoint. 19 | """ 20 | if not osp.exists(path): 21 | warnings.warn("The path of the checkpoints doesn't exist.") 22 | return None 23 | if osp.exists(osp.join(path, f'latest.{suffix}')): 24 | return osp.join(path, f'latest.{suffix}') 25 | 26 | checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) 27 | if len(checkpoints) == 0: 28 | warnings.warn('The are no checkpoints in the path') 29 | return None 30 | latest = -1 31 | latest_path = '' 32 | for checkpoint in checkpoints: 33 | if len(checkpoint) < len(latest_path): 34 | continue 35 | # `count` is iteration number, as checkpoints are saved as 36 | # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number. 37 | count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) 38 | if count > latest: 39 | latest = count 40 | latest_path = checkpoint 41 | return latest_path 42 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | 5 | import cv2 6 | import torch.multiprocessing as mp 7 | 8 | from ..utils import get_root_logger 9 | 10 | 11 | def setup_multi_processes(cfg): 12 | """Setup multi-processing environment variables.""" 13 | logger = get_root_logger() 14 | 15 | # set multi-process start method 16 | if platform.system() != 'Windows': 17 | mp_start_method = cfg.get('mp_start_method', None) 18 | current_method = mp.get_start_method(allow_none=True) 19 | if mp_start_method in ('fork', 'spawn', 'forkserver'): 20 | logger.info( 21 | f'Multi-processing start method `{mp_start_method}` is ' 22 | f'different from the previous setting `{current_method}`.' 
23 | f'It will be force set to `{mp_start_method}`.') 24 | mp.set_start_method(mp_start_method, force=True) 25 | else: 26 | logger.info( 27 | f'Multi-processing start method is `{mp_start_method}`') 28 | 29 | # disable opencv multithreading to avoid system being overloaded 30 | opencv_num_threads = cfg.get('opencv_num_threads', None) 31 | if isinstance(opencv_num_threads, int): 32 | logger.info(f'OpenCV num_threads is `{opencv_num_threads}`') 33 | cv2.setNumThreads(opencv_num_threads) 34 | else: 35 | logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}') 36 | 37 | if cfg.data.workers_per_gpu > 1: 38 | # setup OMP threads 39 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa 40 | omp_num_threads = cfg.get('omp_num_threads', None) 41 | if 'OMP_NUM_THREADS' not in os.environ: 42 | if isinstance(omp_num_threads, int): 43 | logger.info(f'OMP num threads is {omp_num_threads}') 44 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) 45 | else: 46 | logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }') 47 | 48 | # setup MKL threads 49 | if 'MKL_NUM_THREADS' not in os.environ: 50 | mkl_num_threads = cfg.get('mkl_num_threads', None) 51 | if isinstance(mkl_num_threads, int): 52 | logger.info(f'MKL num threads is {mkl_num_threads}') 53 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) 54 | else: 55 | logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') 56 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/utils/util_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch 4 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 5 | 6 | from mmseg import digit_version 7 | 8 | dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} 9 | 10 | ddp_factory = {'cuda': MMDistributedDataParallel} 11 | 12 | 13 | def build_dp(model, device='cuda', dim=0, *args, **kwargs): 14 | """build DataParallel module by device type. 15 | 16 | if device is cuda, return a MMDataParallel module; if device is mlu, 17 | return a MLUDataParallel module. 18 | 19 | Args: 20 | model (:class:`nn.Module`): module to be parallelized. 21 | device (str): device type, cuda, cpu or mlu. Defaults to cuda. 22 | dim (int): Dimension used to scatter the data. Defaults to 0. 23 | 24 | Returns: 25 | :class:`nn.Module`: parallelized module. 26 | """ 27 | if device == 'cuda': 28 | model = model.cuda() 29 | elif device == 'mlu': 30 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 31 | 'Please use MMCV >= 1.5.0 for MLU training!' 32 | from mmcv.device.mlu import MLUDataParallel 33 | dp_factory['mlu'] = MLUDataParallel 34 | model = model.mlu() 35 | 36 | return dp_factory[device](model, dim=dim, *args, **kwargs) 37 | 38 | 39 | def build_ddp(model, device='cuda', *args, **kwargs): 40 | """Build DistributedDataParallel module by device type. 41 | 42 | If device is cuda, return a MMDistributedDataParallel module; 43 | if device is mlu, return a MLUDistributedDataParallel module. 44 | 45 | Args: 46 | model (:class:`nn.Module`): module to be parallelized. 47 | device (str): device type, mlu or cuda. 48 | 49 | Returns: 50 | :class:`nn.Module`: parallelized module. 51 | 52 | References: 53 | .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. 
54 | DistributedDataParallel.html 55 | """ 56 | assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' 57 | if device == 'cuda': 58 | model = model.cuda() 59 | elif device == 'mlu': 60 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 61 | 'Please use MMCV >= 1.5.0 for MLU training!' 62 | from mmcv.device.mlu import MLUDistributedDataParallel 63 | ddp_factory['mlu'] = MLUDistributedDataParallel 64 | model = model.mlu() 65 | 66 | return ddp_factory[device](model, *args, **kwargs) 67 | 68 | 69 | def is_mlu_available(): 70 | """Returns a bool indicating if MLU is currently available.""" 71 | return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() 72 | 73 | 74 | def get_device(): 75 | """Returns an available device, cpu, cuda or mlu.""" 76 | is_device_available = { 77 | 'cuda': torch.cuda.is_available(), 78 | 'mlu': is_mlu_available() 79 | } 80 | device_list = [k for k, v in is_device_available.items() if v] 81 | return device_list[0] if len(device_list) == 1 else 'cpu' 82 | -------------------------------------------------------------------------------- /data_utils/easyportrait/mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.30.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /data_utils/face_parsing/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import sys 8 | import logging 9 | 10 | import torch.distributed as dist 11 | 12 | 13 | def setup_logger(logpth): 14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 15 | logfile = osp.join(logpth, logfile) 16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 17 | log_level = logging.INFO 18 | if dist.is_initialized() and not dist.get_rank()==0: 19 | log_level = logging.ERROR 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 21 | logging.root.addHandler(logging.StreamHandler()) 22 | 23 | 24 | -------------------------------------------------------------------------------- /data_utils/face_parsing/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = 
nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /data_utils/face_tracking/3DMM/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/data_utils/face_tracking/3DMM/.gitkeep -------------------------------------------------------------------------------- /data_utils/face_tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fictionarry/TalkingGaussian/9bd34f732e30cd94c0a9a087a125d2dda228e258/data_utils/face_tracking/__init__.py -------------------------------------------------------------------------------- 
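The `get_params` split at the end of resnet.py above exists so that BatchNorm parameters and biases can be excluded from weight decay. A minimal sketch of how such a split is typically consumed by an optimizer (hyperparameter values are illustrative and not taken from this repository; it assumes the `Resnet18` class above is importable):

import torch
from resnet import Resnet18  # i.e. data_utils/face_parsing/resnet.py

net = Resnet18()
wd_params, nowd_params = net.get_params()
optimizer = torch.optim.SGD(
    [
        {'params': wd_params, 'weight_decay': 5e-4},   # conv / linear weights
        {'params': nowd_params, 'weight_decay': 0.0},  # BN parameters and biases
    ],
    lr=1e-2,
    momentum=0.9)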
/data_utils/face_tracking/convert_BFM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | original_BFM = loadmat("3DMM/01_MorphableModel.mat") 5 | sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] 6 | 7 | shapePC = original_BFM["shapePC"] 8 | shapeEV = original_BFM["shapeEV"] 9 | shapeMU = original_BFM["shapeMU"] 10 | texPC = original_BFM["texPC"] 11 | texEV = original_BFM["texEV"] 12 | texMU = original_BFM["texMU"] 13 | 14 | b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 15 | mu_shape = shapeMU.reshape(-1, 3) 16 | 17 | b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 18 | mu_tex = texMU.reshape(-1, 3) 19 | 20 | b_shape = b_shape[:, sub_inds, :].reshape(199, -1) 21 | mu_shape = mu_shape[sub_inds, :].reshape(-1) 22 | b_tex = b_tex[:, sub_inds, :].reshape(199, -1) 23 | mu_tex = mu_tex[sub_inds, :].reshape(-1) 24 | 25 | exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() 26 | np.save( 27 | "3DMM/3DMM_info.npy", 28 | { 29 | "mu_shape": mu_shape, 30 | "b_shape": b_shape, 31 | "sig_shape": shapeEV.reshape(-1), 32 | "mu_exp": exp_info["mu_exp"], 33 | "b_exp": exp_info["base_exp"], 34 | "sig_exp": exp_info["sig_exp"], 35 | "mu_tex": mu_tex, 36 | "b_tex": b_tex, 37 | "sig_tex": texEV.reshape(-1), 38 | }, 39 | ) 40 | -------------------------------------------------------------------------------- /data_utils/face_tracking/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def load_dir(path, start, end): 7 | lmss = [] 8 | imgs_paths = [] 9 | for i in range(start, end): 10 | if os.path.isfile(os.path.join(path, str(i) + ".lms")): 11 | lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) 12 | lmss.append(lms) 13 | imgs_paths.append(os.path.join(path, str(i) + ".jpg")) 14 | lmss = np.stack(lmss) 15 | lmss = torch.as_tensor(lmss).cuda() 16 | return lmss, imgs_paths 17 | -------------------------------------------------------------------------------- /data_utils/face_tracking/geo_transform.py: -------------------------------------------------------------------------------- 1 | """This module contains functions for geometry transform and camera projection""" 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def euler2rot(euler_angle): 8 | batch_size = euler_angle.shape[0] 9 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 10 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 11 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 12 | one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) 13 | zero = torch.zeros( 14 | (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device 15 | ) 16 | rot_x = torch.cat( 17 | ( 18 | torch.cat((one, zero, zero), 1), 19 | torch.cat((zero, theta.cos(), theta.sin()), 1), 20 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 21 | ), 22 | 2, 23 | ) 24 | rot_y = torch.cat( 25 | ( 26 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 27 | torch.cat((zero, one, zero), 1), 28 | torch.cat((phi.sin(), zero, phi.cos()), 1), 29 | ), 30 | 2, 31 | ) 32 | rot_z = torch.cat( 33 | ( 34 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 35 | torch.cat((psi.sin(), psi.cos(), zero), 1), 36 | torch.cat((zero, zero, one), 1), 37 | ), 38 | 2, 39 | ) 40 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 41 | 42 | 43 | def rot_trans_geo(geometry, 
rot, trans): 44 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) 45 | return rott_geo.permute(0, 2, 1) 46 | 47 | 48 | def euler_trans_geo(geometry, euler, trans): 49 | rot = euler2rot(euler) 50 | return rot_trans_geo(geometry, rot, trans) 51 | 52 | 53 | def proj_geo(rott_geo, camera_para): 54 | fx = camera_para[:, 0] 55 | fy = camera_para[:, 0] 56 | cx = camera_para[:, 1] 57 | cy = camera_para[:, 2] 58 | 59 | X = rott_geo[:, :, 0] 60 | Y = rott_geo[:, :, 1] 61 | Z = rott_geo[:, :, 2] 62 | 63 | fxX = fx[:, None] * X 64 | fyY = fy[:, None] * Y 65 | 66 | proj_x = -fxX / Z + cx[:, None] 67 | proj_y = fyY / Z + cy[:, None] 68 | 69 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 70 | -------------------------------------------------------------------------------- /data_utils/face_tracking/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def compute_tri_normal(geometry, tris): 7 | tri_1 = tris[:, 0] 8 | tri_2 = tris[:, 1] 9 | tri_3 = tris[:, 2] 10 | vert_1 = torch.index_select(geometry, 1, tri_1) 11 | vert_2 = torch.index_select(geometry, 1, tri_2) 12 | vert_3 = torch.index_select(geometry, 1, tri_3) 13 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) 14 | normal = nn.functional.normalize(nnorm) 15 | return normal 16 | 17 | 18 | def euler2rot(euler_angle): 19 | batch_size = euler_angle.shape[0] 20 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 21 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 22 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 23 | one = torch.ones(batch_size, 1, 1).to(euler_angle.device) 24 | zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) 25 | rot_x = torch.cat( 26 | ( 27 | torch.cat((one, zero, zero), 1), 28 | torch.cat((zero, theta.cos(), theta.sin()), 1), 29 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 30 | ), 31 | 2, 32 | ) 33 | rot_y = torch.cat( 34 | ( 35 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 36 | torch.cat((zero, one, zero), 1), 37 | torch.cat((phi.sin(), zero, phi.cos()), 1), 38 | ), 39 | 2, 40 | ) 41 | rot_z = torch.cat( 42 | ( 43 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 44 | torch.cat((psi.sin(), psi.cos(), zero), 1), 45 | torch.cat((zero, zero, one), 1), 46 | ), 47 | 2, 48 | ) 49 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 50 | 51 | 52 | def rot_trans_pts(geometry, rot, trans): 53 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] 54 | return rott_geo.permute(0, 2, 1) 55 | 56 | 57 | def cal_lap_loss(tensor_list, weight_list): 58 | lap_kernel = ( 59 | torch.Tensor((-0.5, 1.0, -0.5)) 60 | .unsqueeze(0) 61 | .unsqueeze(0) 62 | .float() 63 | .to(tensor_list[0].device) 64 | ) 65 | loss_lap = 0 66 | for i in range(len(tensor_list)): 67 | in_tensor = tensor_list[i] 68 | in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) 69 | out_tensor = F.conv1d(in_tensor, lap_kernel) 70 | loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] 71 | return loss_lap 72 | 73 | 74 | def proj_pts(rott_geo, focal_length, cxy): 75 | cx, cy = cxy[0], cxy[1] 76 | X = rott_geo[:, :, 0] 77 | Y = rott_geo[:, :, 1] 78 | Z = rott_geo[:, :, 2] 79 | fxX = focal_length * X 80 | fyY = focal_length * Y 81 | proj_x = -fxX / Z + cx 82 | proj_y = fyY / Z + cy 83 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 84 | 85 | 86 | def forward_rott(geometry, euler_angle, trans): 87 | rot = euler2rot(euler_angle) 88 | rott_geo = 
rot_trans_pts(geometry, rot, trans) 89 | return rott_geo 90 | 91 | 92 | def forward_transform(geometry, euler_angle, trans, focal_length, cxy): 93 | rot = euler2rot(euler_angle) 94 | rott_geo = rot_trans_pts(geometry, rot, trans) 95 | proj_geo = proj_pts(rott_geo, focal_length, cxy) 96 | return proj_geo 97 | 98 | 99 | def cal_lan_loss(proj_lan, gt_lan): 100 | return torch.mean((proj_lan - gt_lan) ** 2) 101 | 102 | 103 | def cal_col_loss(pred_img, gt_img, img_mask): 104 | pred_img = pred_img.float() 105 | # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 106 | loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 107 | loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) 108 | loss = torch.mean(loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /data_utils/hubert.py: -------------------------------------------------------------------------------- 1 | from transformers import Wav2Vec2Processor, HubertModel 2 | import soundfile as sf 3 | import numpy as np 4 | import torch 5 | 6 | print("Loading the Wav2Vec2 Processor...") 7 | wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") 8 | print("Loading the HuBERT Model...") 9 | hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") 10 | 11 | 12 | def get_hubert_from_16k_wav(wav_16k_name): 13 | speech_16k, _ = sf.read(wav_16k_name) 14 | hubert = get_hubert_from_16k_speech(speech_16k) 15 | return hubert 16 | 17 | @torch.no_grad() 18 | def get_hubert_from_16k_speech(speech, device="cuda:0"): 19 | global hubert_model 20 | hubert_model = hubert_model.to(device) 21 | if speech.ndim ==2: 22 | speech = speech[:, 0] # [T, 2] ==> [T,] 23 | input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] 24 | input_values_all = input_values_all.to(device) 25 | # For long audio sequence, due to the memory limitation, we cannot process them in one run 26 | # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 27 | # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. 
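    # (Illustrative numbers: at 16 kHz each frame covers a 400-sample / 25 ms
    # window hopped by 320 samples / 20 ms, i.e. roughly 50 feature frames
    # per second.)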
28 | # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 29 | # We have the equation to calculate out time step: T = floor((t-k)/s) 30 | # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip 31 | # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N 32 | kernel = 400 33 | stride = 320 34 | clip_length = stride * 1000 35 | num_iter = input_values_all.shape[1] // clip_length 36 | expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride 37 | res_lst = [] 38 | for i in range(num_iter): 39 | if i == 0: 40 | start_idx = 0 41 | end_idx = clip_length - stride + kernel 42 | else: 43 | start_idx = clip_length * i 44 | end_idx = start_idx + (clip_length - stride + kernel) 45 | input_values = input_values_all[:, start_idx: end_idx] 46 | hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] 47 | res_lst.append(hidden_states[0]) 48 | if num_iter > 0: 49 | input_values = input_values_all[:, clip_length * num_iter:] 50 | else: 51 | input_values = input_values_all 52 | # if input_values.shape[1] != 0: 53 | if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it 54 | hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] 55 | res_lst.append(hidden_states[0]) 56 | else: 57 | print("skip the latest ", input_values.shape[1]) 58 | ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] 59 | # assert ret.shape[0] == expected_T 60 | assert abs(ret.shape[0] - expected_T) <= 1 61 | if ret.shape[0] < expected_T: 62 | ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) 63 | else: 64 | ret = ret[:expected_T] 65 | return ret 66 | 67 | def make_even_first_dim(tensor): 68 | size = list(tensor.size()) 69 | if size[0] % 2 == 1: 70 | size[0] -= 1 71 | return tensor[:size[0]] 72 | return tensor 73 | 74 | import soundfile as sf 75 | import numpy as np 76 | import torch 77 | from argparse import ArgumentParser 78 | import librosa 79 | 80 | parser = ArgumentParser() 81 | parser.add_argument('--wav', type=str, help='') 82 | args = parser.parse_args() 83 | 84 | wav_name = args.wav 85 | 86 | speech_16k, sr = librosa.load(wav_name, sr=16000) 87 | # speech_16k = librosa.resample(speech, orig_sr=sr, target_sr=16000) 88 | # print("SR: {} to {}".format(sr, 16000)) 89 | # print(speech.shape, speech_16k.shape) 90 | 91 | hubert_hidden = get_hubert_from_16k_speech(speech_16k) 92 | hubert_hidden = make_even_first_dim(hubert_hidden).reshape(-1, 2, 1024) 93 | np.save(wav_name.replace('.wav', '_hu.npy'), hubert_hidden.detach().numpy()) 94 | print(hubert_hidden.detach().numpy().shape) -------------------------------------------------------------------------------- /data_utils/wav2mel_hparams.py: -------------------------------------------------------------------------------- 1 | class HParams: 2 | def __init__(self, **kwargs): 3 | self.data = {} 4 | 5 | for key, value in kwargs.items(): 6 | self.data[key] = value 7 | 8 | def __getattr__(self, key): 9 | if key not in self.data: 10 | raise AttributeError("'HParams' object has no attribute %s" % key) 11 | return self.data[key] 12 | 13 | def set_hparam(self, key, value): 14 | self.data[key] = value 15 | 16 | # Default hyperparameters 17 | hparams = HParams( 18 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 19 | # network 20 | rescale=True, # Whether to rescale audio prior to preprocessing 21 | 
rescaling_max=0.9, # Rescaling value 22 | 23 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 24 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 25 | # Does not work if n_ffit is not multiple of hop_size!! 26 | use_lws=False, 27 | 28 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 29 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 30 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 31 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 32 | 33 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 34 | 35 | # Mel and Linear spectrograms normalization/scaling and clipping 36 | signal_normalization=True, 37 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 38 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 39 | symmetric_mels=True, 40 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 41 | # faster and cleaner convergence) 42 | max_abs_value=4., 43 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 44 | # be too big to avoid gradient explosion, 45 | # not too small for fast convergence) 46 | # Contribution by @begeekmyfriend 47 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 48 | # levels. Also allows for better G&L phase reconstruction) 49 | preemphasize=True, # whether to apply filter 50 | preemphasis=0.97, # filter coefficient. 51 | 52 | # Limits 53 | min_level_db=-100, 54 | ref_level_db=20, 55 | fmin=65, 56 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 57 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 58 | fmax=6000, # To be increased/reduced depending on data. 59 | 60 | ###################### Our training parameters ################################# 61 | img_size=96, 62 | fps=25, 63 | 64 | batch_size=16, 65 | initial_learning_rate=1e-4, 66 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 67 | num_workers=16, 68 | checkpoint_interval=3000, 69 | eval_interval=3000, 70 | save_optimizer_state=True, 71 | 72 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 73 | syncnet_batch_size=64, 74 | syncnet_lr=1e-4, 75 | syncnet_eval_interval=10000, 76 | syncnet_checkpoint_interval=10000, 77 | 78 | disc_wt=0.07, 79 | disc_initial_learning_rate=1e-4, 80 | ) 81 | -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder(nn.Module): 6 | def __init__(self, input_dim, max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | 16 | self.output_dim = 0 17 | if self.include_input: 18 | self.output_dim += self.input_dim 19 | 20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 21 | 22 | if log_sampling: 23 | self.freq_bands = 2. 
** torch.linspace(0., max_freq_log2, N_freqs) 24 | else: 25 | self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs) 26 | 27 | self.freq_bands = self.freq_bands.numpy().tolist() 28 | 29 | def forward(self, input, **kwargs): 30 | 31 | out = [] 32 | if self.include_input: 33 | out.append(input) 34 | 35 | for i in range(len(self.freq_bands)): 36 | freq = self.freq_bands[i] 37 | for p_fn in self.periodic_fns: 38 | out.append(p_fn(input * freq)) 39 | 40 | out = torch.cat(out, dim=-1) 41 | 42 | 43 | return out 44 | 45 | def get_encoder(encoding, input_dim=3, 46 | multires=6, 47 | degree=4, 48 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 49 | **kwargs): 50 | 51 | if encoding == 'None': 52 | return lambda x, **kwargs: x, input_dim 53 | 54 | elif encoding == 'frequency': 55 | #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 56 | from freqencoder import FreqEncoder 57 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 58 | 59 | elif encoding == 'sphere_harmonics': 60 | from shencoder import SHEncoder 61 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 62 | 63 | elif encoding == 'hashgrid': 64 | from gridencoder import GridEncoder 65 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 66 | 67 | elif encoding == 'tiledgrid': 68 | from gridencoder import GridEncoder 69 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 70 | 71 | elif encoding == 'ash': 72 | from ashencoder import AshEncoder 73 | encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) 74 | 75 | else: 76 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 77 | 78 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: talking_gaussian 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.3 8 | - plyfile=0.8.1 9 | - python=3.7.13 10 | - pip=22.3.1 11 | - pytorch=1.12.1 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | - tqdm 15 | - ffmpeg 16 | - openh264 17 | - pip: 18 | - ./submodules/diff-gaussian-rasterization 19 | - ./submodules/simple-knn 20 | - ./gridencoder 21 | - numpy 22 | - pillow 23 | - scipy 24 | - tensorboard 25 | - opencv-python 26 | - tensorboardX 27 | 28 | - pandas 29 | - tqdm 30 | - matplotlib 31 | - PyMCubes==0.1.4 32 | - rich 33 | - packaging 34 | - scikit-learn 35 | 36 | - face_alignment 37 | - python_speech_features 38 | - numba 39 | - resampy 40 | - pyaudio 41 | - soundfile 42 | - configargparse 43 | 44 | - lpips 45 | - imageio-ffmpeg 46 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, 
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | 17 | #endif -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = 
models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import gc 13 | import os 14 | import random 15 | import json 16 | from utils.system_utils import searchForMaxIteration 17 | from scene.dataset_readers import sceneLoadTypeCallbacks 18 | from scene.gaussian_model import GaussianModel 19 | from scene.motion_net import MotionNetwork, MouthMotionNetwork 20 | from arguments import ModelParams 21 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 22 | 23 | class Scene: 24 | 25 | gaussians : GaussianModel 26 | 27 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 28 | """b 29 | :param path: Path to colmap scene main folder. 
30 | """ 31 | self.model_path = args.model_path 32 | self.loaded_iter = None 33 | self.gaussians = gaussians 34 | 35 | if load_iteration: 36 | if load_iteration == -1: 37 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 38 | else: 39 | self.loaded_iter = load_iteration 40 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 41 | 42 | self.train_cameras = {} 43 | self.test_cameras = {} 44 | 45 | if os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 46 | print("Found transforms_train.json file, assuming Blender data set!") 47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, args=args) 48 | else: 49 | assert False, "Could not recognize scene type!" 50 | 51 | if not self.loaded_iter: 52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 53 | dest_file.write(src_file.read()) 54 | json_cams = [] 55 | camlist = [] 56 | if scene_info.test_cameras: 57 | camlist.extend(scene_info.test_cameras) 58 | if scene_info.train_cameras: 59 | camlist.extend(scene_info.train_cameras) 60 | for id, cam in enumerate(camlist): 61 | json_cams.append(camera_to_JSON(id, cam)) 62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 63 | json.dump(json_cams, file) 64 | 65 | if shuffle: 66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 68 | 69 | self.cameras_extent = scene_info.nerf_normalization["radius"] 70 | 71 | for resolution_scale in resolution_scales: 72 | print("Loading Training Cameras") 73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 74 | print("Loading Test Cameras") 75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 76 | 77 | if self.loaded_iter: 78 | self.gaussians.load_ply(os.path.join(self.model_path, 79 | "point_cloud", 80 | "iteration_" + str(self.loaded_iter), 81 | "point_cloud.ply")) 82 | else: 83 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 84 | 85 | gc.collect() 86 | 87 | 88 | def save(self, iteration): 89 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 90 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 91 | 92 | def getTrainCameras(self, scale=1.0): 93 | return self.train_cameras[scale] 94 | 95 | def getTestCameras(self, scale=1.0): 96 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, background, talking_dict, 19 | image_name, image_path, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | self.image_path = image_path 32 | self.talking_dict = talking_dict 33 | 34 | try: 35 | self.data_device = torch.device(data_device) 36 | except Exception as e: 37 | print(e) 38 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 39 | self.data_device = torch.device("cuda") 40 | 41 | if image is not None: 42 | self.original_image = image.clamp(0, 255).to(self.data_device) 43 | self.image_width = self.original_image.shape[2] 44 | self.image_height = self.original_image.shape[1] 45 | else: 46 | self.original_image = None 47 | 48 | if background is not None: 49 | self.background = background.clamp(0, 255).to(self.data_device) 50 | else: 51 | self.background = None 52 | 53 | # for key in self.mask.keys(): 54 | # self.mask[key] = torch.as_tensor(self.mask[key], device=self.data_device) 55 | 56 | if gt_alpha_mask is not None: 57 | self.original_image *= gt_alpha_mask.to(self.data_device) 58 | # else: 59 | # self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 60 | 61 | self.zfar = 100.0 62 | self.znear = 0.01 63 | 64 | self.trans = trans 65 | self.scale = scale 66 | 67 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 68 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 69 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 70 | self.camera_center = self.world_view_transform.inverse()[3, :3] 71 | 72 | class MiniCam: 73 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 74 | self.image_width = width 75 | self.image_height = height 76 | self.FoVy = fovy 77 | self.FoVx = fovx 78 | self.znear = znear 79 | self.zfar = zfar 80 | self.world_view_transform = world_view_transform 81 | self.full_proj_transform = full_proj_transform 82 | view_inv = torch.inverse(self.world_view_transform) 83 | self.camera_center = view_inv[3][:3] 84 | 85 | -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth 2 | 3 | mkdir data_utils/face_tracking/3DMM 4 | 5 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy 6 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy 7 | wget 
https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj 8 | wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy -------------------------------------------------------------------------------- /scripts/train_xx.sh: -------------------------------------------------------------------------------- 1 | dataset=$1 2 | workspace=$2 3 | gpu_id=$3 4 | audio_extractor='deepspeech' # deepspeech, esperanto, hubert 5 | 6 | export CUDA_VISIBLE_DEVICES=$gpu_id 7 | 8 | python train_mouth.py -s $dataset -m $workspace --audio_extractor $audio_extractor 9 | python train_face.py -s $dataset -m $workspace --init_num 2000 --densify_grad_threshold 0.0005 --audio_extractor $audio_extractor 10 | python train_fuse.py -s $dataset -m $workspace --opacity_lr 0.001 --audio_extractor $audio_extractor 11 | 12 | # # Parallel. Ensure that you have aleast 2 GPUs, and over N x 64GB memory for about N x 5k frames (IMPORTANT! Otherwise the computer will crash). 13 | # CUDA_VISIBLE_DEVICES=$gpu_id python train_mouth.py -s $dataset -m $workspace --audio_extractor $audio_extractor & 14 | # CUDA_VISIBLE_DEVICES=$((gpu_id+1)) python train_face.py -s $dataset -m $workspace --init_num 2000 --densify_grad_threshold 0.0005 --audio_extractor $audio_extractor 15 | # CUDA_VISIBLE_DEVICES=$gpu_id python train_fuse.py -s $dataset -m $workspace --opacity_lr 0.001 --audio_extractor $audio_extractor 16 | 17 | python synthesize_fuse.py -s $dataset -m $workspace --eval --audio_extractor $audio_extractor 18 | python metrics.py $workspace/test/ours_None/renders/out.mp4 $workspace/test/ours_None/gt/out.mp4 -------------------------------------------------------------------------------- /utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_audio_features(features, att_mode, index): 4 | if att_mode == 0: 5 | return features[[index]] 6 | elif att_mode == 1: 7 | left = index - 8 8 | pad_left = 0 9 | if left < 0: 10 | pad_left = -left 11 | left = 0 12 | auds = features[left:index] 13 | if pad_left > 0: 14 | # pad may be longer than auds, so do not use zeros_like 15 | auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0) 16 | return auds 17 | elif att_mode == 2: 18 | left = index - 4 19 | right = index + 4 20 | pad_left = 0 21 | pad_right = 0 22 | if left < 0: 23 | pad_left = -left 24 | left = 0 25 | if right > features.shape[0]: 26 | pad_right = right - features.shape[0] 27 | right = features.shape[0] 28 | auds = features[left:right] 29 | if pad_left > 0: 30 | auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0) 31 | if pad_right > 0: 32 | auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16] 33 | return auds 34 | else: 35 | raise NotImplementedError(f'wrong att_mode: {att_mode}') -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import torch 14 | import numpy as np 15 | import os 16 | from PIL import Image 17 | from utils.general_utils import PILtoTorch 18 | from utils.graphics_utils import fov2focal 19 | 20 | WARNED = False 21 | 22 | def loadCam(args, id, cam_info, resolution_scale): 23 | 24 | if cam_info.image is not None: 25 | image_rgb = PILtoTorch(cam_info.image).type("torch.ByteTensor") 26 | gt_image = image_rgb[:3, ...] 27 | else: 28 | gt_image = None 29 | 30 | if cam_info.background is not None: 31 | background = PILtoTorch(cam_info.background)[:3, ...].type("torch.ByteTensor") 32 | else: 33 | background = None 34 | 35 | loaded_mask = None 36 | 37 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 38 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 39 | image=gt_image, gt_alpha_mask=loaded_mask, background=background, talking_dict=cam_info.talking_dict, 40 | image_name=cam_info.image_name, image_path=cam_info.image_path, uid=id, data_device=args.data_device) 41 | 42 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 43 | camera_list = [] 44 | 45 | for id, c in enumerate(cam_infos): 46 | camera_list.append(loadCam(args, id, c, resolution_scale)) 47 | 48 | return camera_list 49 | 50 | def camera_to_JSON(id, camera : Camera): 51 | Rt = np.zeros((4, 4)) 52 | Rt[:3, :3] = camera.R.transpose() 53 | Rt[:3, 3] = camera.T 54 | Rt[3, 3] = 1.0 55 | 56 | W2C = np.linalg.inv(Rt) 57 | pos = W2C[:3, 3] 58 | rot = W2C[:3, :3] 59 | serializable_array_2d = [x.tolist() for x in rot] 60 | camera_entry = { 61 | 'id' : id, 62 | 'img_name' : camera.image_name, 63 | 'width' : camera.width, 64 | 'height' : camera.height, 65 | 'position': pos.tolist(), 66 | 'rotation': serializable_array_2d, 67 | 'fy' : fov2focal(camera.FovY, camera.height), 68 | 'fx' : fov2focal(camera.FovX, camera.width) 69 | } 70 | return camera_entry 71 | 72 | 73 | def loadCamOnTheFly(camera): 74 | image_path = camera.image_path 75 | image = Image.open(image_path) 76 | image = np.array(image.convert("RGB")) 77 | 78 | bg_img = PILtoTorch(np.array(Image.open(os.path.join("/".join(image_path.split("/")[:-2]), 'bc.jpg')).convert("RGB"))).to(camera.data_device) 79 | torso_img_path = image_path.replace("gt_imgs", "torso_imgs").replace("jpg", "png") 80 | torso_img = PILtoTorch(np.array(Image.open(torso_img_path).convert("RGBA")) * 1.0).to(camera.data_device) 81 | bg = torso_img[:3] * torso_img[3:] / 255 + bg_img * (1.0 - torso_img[3:] / 255) 82 | 83 | teeth_mask_path = image_path.replace("gt_imgs", "teeth_mask").replace("jpg", "npy") 84 | teeth_mask = torch.as_tensor(np.load(teeth_mask_path)).to(camera.data_device) 85 | 86 | mask_path = image_path.replace("gt_imgs", "parsing").replace("jpg", "png") 87 | mask = PILtoTorch(np.array(Image.open(mask_path).convert("RGB")) * 1.0).to(camera.data_device) 88 | camera.talking_dict['face_mask'] = (mask[2] > 254) * (mask[0] == 0) * (mask[1] == 0) ^ teeth_mask 89 | camera.talking_dict['hair_mask'] = (mask[0] < 1) * (mask[1] < 1) * (mask[2] < 1) 90 | camera.talking_dict['mouth_mask'] = (mask[0] == 100) * (mask[1] == 100) * (mask[2] == 100) + teeth_mask 91 | 92 | camera.original_image = PILtoTorch(image).type("torch.ByteTensor").clamp(0, 255).to(camera.data_device) 93 | camera.background = bg.type("torch.ByteTensor").clamp(0, 255).to(camera.data_device) 94 | camera.image_width = camera.original_image.shape[2] 95 | camera.image_height = camera.original_image.shape[1] 96 | 97 | return camera 98 | 
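A note on the mask conventions above: loadCamOnTheFly derives the face, hair, and mouth masks from the color-coded parsing images plus the teeth mask. The standalone sketch below re-expresses the same rules with NumPy only, which can be handy for sanity-checking preprocessed frames; the file paths are hypothetical placeholders, and the color convention (pure blue = face, black = hair/background, RGB(100, 100, 100) = mouth) is assumed to match the preprocessing output used by this repo.

import numpy as np
from PIL import Image

def region_masks(parsing_png, teeth_npy):
    # Color-coded parsing map -> boolean region masks (mirrors loadCamOnTheFly above).
    rgb = np.array(Image.open(parsing_png).convert("RGB")).astype(np.int32)
    teeth = np.load(teeth_npy).astype(bool)
    face = ((rgb[..., 2] > 254) & (rgb[..., 0] == 0) & (rgb[..., 1] == 0)) ^ teeth
    hair = (rgb[..., 0] < 1) & (rgb[..., 1] < 1) & (rgb[..., 2] < 1)
    mouth = ((rgb[..., 0] == 100) & (rgb[..., 1] == 100) & (rgb[..., 2] == 100)) | teeth
    return face, hair, mouth

if __name__ == "__main__":
    face, hair, mouth = region_masks("path/to/parsing/0.png", "path/to/teeth_mask/0.npy")
    print(face.mean(), hair.mean(), mouth.mean())  # fraction of pixels in each region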
-------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def patchify(input, patch_size): 18 | patches = F.unfold(input, kernel_size=patch_size, stride=patch_size).permute(0,2,1).view(-1, 3, patch_size, patch_size) 19 | return patches 20 | 21 | def l1_loss(network_output, gt): 22 | return torch.abs((network_output - gt)).mean() 23 | 24 | def l2_loss(network_output, gt): 25 | return ((network_output - gt) ** 2).mean() 26 | 27 | def gaussian(window_size, sigma): 28 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 29 | return gauss / gauss.sum() 30 | 31 | def create_window(window_size, channel): 32 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 33 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 34 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 35 | return window 36 | 37 | def ssim(img1, img2, window_size=11, size_average=True): 38 | channel = img1.size(-3) 39 | window = create_window(window_size, channel) 40 | 41 | if img1.is_cuda: 42 | window = window.cuda(img1.get_device()) 43 | window = window.type_as(img1) 44 | 45 | return _ssim(img1, img2, window, window_size, channel, size_average) 46 | 47 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 48 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 49 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 50 | 51 | mu1_sq = mu1.pow(2) 52 | mu2_sq = mu2.pow(2) 53 | mu1_mu2 = mu1 * mu2 54 | 55 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 56 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 57 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 58 | 59 | C1 = 0.01 ** 2 60 | C2 = 0.03 ** 2 61 | 62 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 63 | 64 | if size_average: 65 | return ssim_map.mean() 66 | else: 67 | return ssim_map.mean(1).mean(1).mean(1) 68 | 69 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | --------------------------------------------------------------------------------
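As a closing illustration of how the small utilities above fit together: 3DGS-style pipelines usually optimize a weighted mix of the L1 and D-SSIM terms from utils/loss_utils.py and report PSNR from utils/image_utils.py. The sketch below shows that combination; the weight lambda_dssim = 0.2 is the common 3D Gaussian Splatting default and is assumed here, not read from this repository's training scripts.

import torch
from utils.loss_utils import l1_loss, ssim
from utils.image_utils import psnr

def photometric_loss(render, gt, lambda_dssim=0.2):
    # render, gt: [3, H, W] float tensors in [0, 1]; lambda_dssim is an assumed default.
    l1 = l1_loss(render, gt)
    d_ssim = 1.0 - ssim(render.unsqueeze(0), gt.unsqueeze(0))
    return (1.0 - lambda_dssim) * l1 + lambda_dssim * d_ssim

# Evaluation-style PSNR between a rendered frame and its ground truth:
# psnr(render.unsqueeze(0), gt.unsqueeze(0)) returns a [1, 1] tensor in dB.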