├── .gitignore ├── README.md ├── assets ├── .DS_Store ├── CGA.png ├── SDA.png ├── archi.png └── teaser.png ├── benchmark └── README.md ├── configs ├── base │ ├── base_rgb_weather.py │ ├── base_rgb_weather_320.py │ └── base_rgb_weather_384.py ├── noise │ ├── all_in_one_L.py │ └── all_in_one_S.py └── weather │ ├── all_in_one_L.py │ └── all_in_one_S.py ├── data ├── all-in-one-test │ └── all_in_one.json └── all-in-one-train │ ├── all_in_one.json │ └── description_v2.json ├── get_flops.py ├── mmedit ├── __init__.py ├── apis │ ├── __init__.py │ ├── dmci_runner.py │ ├── generation_inference.py │ ├── inpainting_inference.py │ ├── matting_inference.py │ ├── restoration_face_inference.py │ ├── restoration_inference.py │ ├── restoration_video_inference.py │ ├── test.py │ ├── test_stage1.py │ ├── train_epoch.py │ ├── train_iter.py │ ├── train_s1.py │ └── video_interpolation_inference.py ├── core │ ├── __init__.py │ ├── distributed_wrapper.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── eval_hooks.py │ │ ├── eval_hooks_s1.py │ │ ├── metric_utils.py │ │ └── metrics.py │ ├── export │ │ ├── __init__.py │ │ └── wrappers.py │ ├── hooks │ │ ├── __init__.py │ │ ├── ema.py │ │ └── visualization.py │ ├── mask.py │ ├── misc.py │ ├── optimizer │ │ ├── __init__.py │ │ └── builder.py │ ├── scheduler │ │ ├── __init__.py │ │ └── lr_updater.py │ └── utils │ │ ├── __init__.py │ │ └── dist_utils.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── base_generation_dataset.py │ ├── base_matting_dataset.py │ ├── base_sr_dataset.py │ ├── base_vfi_dataset.py │ ├── builder.py │ ├── comp1k_dataset.py │ ├── dataset_wrappers.py │ ├── generation_paired_dataset.py │ ├── generation_unpaired_dataset.py │ ├── img_inpainting_dataset.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── augmentation.py │ │ ├── blur_kernels.py │ │ ├── compose.py │ │ ├── crop.py │ │ ├── formating.py │ │ ├── generate_assistant.py │ │ ├── loading.py │ │ ├── matlab_like_resize.py │ │ ├── matting_aug.py │ │ ├── normalization.py │ │ ├── random_degradations.py │ │ ├── random_down_sampling.py │ │ └── utils.py │ ├── registry.py │ ├── samplers │ │ ├── __init__.py │ │ └── distributed_sampler.py │ ├── sr_annotation_dataset.py │ ├── sr_facial_landmark_dataset.py │ ├── sr_folder_dataset.py │ ├── sr_folder_gt_dataset.py │ ├── sr_folder_multiple_gt_dataset.py │ ├── sr_folder_ref_dataset.py │ ├── sr_folder_video_dataset.py │ ├── sr_lmdb_dataset.py │ ├── sr_reds_dataset.py │ ├── sr_reds_multiple_gt_dataset.py │ ├── sr_reds_online_gt_dataset.py │ ├── sr_test_multiple_gt_dataset.py │ ├── sr_vid4_dataset.py │ ├── sr_vimeo90k_dataset.py │ ├── sr_vimeo90k_multiple_gt_dataset.py │ ├── vfi_vimeo90k_7frames_dataset.py │ └── vfi_vimeo90k_dataset.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── encoder_decoders │ │ │ ├── __init__.py │ │ │ ├── aot_encoder_decoder.py │ │ │ ├── decoders │ │ │ │ ├── __init__.py │ │ │ │ ├── aot_decoder.py │ │ │ │ ├── deepfill_decoder.py │ │ │ │ ├── fba_decoder.py │ │ │ │ ├── gl_decoder.py │ │ │ │ ├── indexnet_decoder.py │ │ │ │ ├── pconv_decoder.py │ │ │ │ ├── plain_decoder.py │ │ │ │ └── resnet_dec.py │ │ │ ├── encoders │ │ │ │ ├── __init__.py │ │ │ │ ├── aot_encoder.py │ │ │ │ ├── deepfill_encoder.py │ │ │ │ ├── fba_encoder.py │ │ │ │ ├── gl_encoder.py │ │ │ │ ├── indexnet_encoder.py │ │ │ │ ├── pconv_encoder.py │ │ │ │ ├── resnet.py │ │ │ │ ├── resnet_enc.py │ │ │ │ └── vgg.py │ │ │ ├── gl_encoder_decoder.py │ │ │ ├── necks │ │ │ │ ├── __init__.py │ │ │ │ ├── aot_neck.py │ │ │ │ ├── contextual_attention_neck.py │ │ │ │ 
└── gl_dilation.py │ │ │ ├── pconv_encoder_decoder.py │ │ │ ├── simple_encoder_decoder.py │ │ │ └── two_stage_encoder_decoder.py │ │ ├── generation_backbones │ │ │ ├── __init__.py │ │ │ ├── resnet_generator.py │ │ │ └── unet_generator.py │ │ ├── sr_backbones │ │ │ ├── __init__.py │ │ │ ├── airnet.py │ │ │ ├── attention_block.py │ │ │ ├── basicvsr_net.py │ │ │ ├── basicvsr_pp.py │ │ │ ├── convmod.py │ │ │ ├── deform_conv.py │ │ │ ├── dic_net.py │ │ │ ├── dmci_former.py │ │ │ ├── dmci_restormer.py │ │ │ ├── duf.py │ │ │ ├── edsr.py │ │ │ ├── edvr_net.py │ │ │ ├── glean_styleganv2.py │ │ │ ├── iconvsr.py │ │ │ ├── liif_net.py │ │ │ ├── mapping_net.py │ │ │ ├── moco.py │ │ │ ├── network_art.py │ │ │ ├── ops.py │ │ │ ├── rdn.py │ │ │ ├── real_basicvsr_net.py │ │ │ ├── restormer_net.py │ │ │ ├── rrdb_net.py │ │ │ ├── sr_resnet.py │ │ │ ├── srcnn.py │ │ │ ├── tdan_net.py │ │ │ ├── tof.py │ │ │ ├── ttsr_net.py │ │ │ └── uformer.py │ │ └── vfi_backbones │ │ │ ├── __init__.py │ │ │ ├── cain_net.py │ │ │ ├── flavr_net.py │ │ │ └── tof_vfi_net.py │ ├── base.py │ ├── builder.py │ ├── common │ │ ├── __init__.py │ │ ├── aspp.py │ │ ├── common_model_dual.py │ │ ├── contextual_attention.py │ │ ├── conv.py │ │ ├── downsample.py │ │ ├── ensemble.py │ │ ├── entropy_models.py │ │ ├── flow_warp.py │ │ ├── gated_conv_module.py │ │ ├── gca_module.py │ │ ├── generation_model_utils.py │ │ ├── img_normalize.py │ │ ├── layers.py │ │ ├── linear_module.py │ │ ├── mask_conv_module.py │ │ ├── model_utils.py │ │ ├── partial_conv.py │ │ ├── separable_conv_module.py │ │ ├── sr_backbone_utils.py │ │ ├── upsample.py │ │ └── video_net.py │ ├── components │ │ ├── __init__.py │ │ ├── discriminators │ │ │ ├── __init__.py │ │ │ ├── deepfill_disc.py │ │ │ ├── discriminator_arch.py │ │ │ ├── gl_disc.py │ │ │ ├── light_cnn.py │ │ │ ├── modified_vgg.py │ │ │ ├── multi_layer_disc.py │ │ │ ├── patch_disc.py │ │ │ ├── smpatch_disc.py │ │ │ ├── ttsr_disc.py │ │ │ └── unet_disc.py │ │ ├── refiners │ │ │ ├── __init__.py │ │ │ ├── deepfill_refiner.py │ │ │ ├── mlp_refiner.py │ │ │ └── plain_refiner.py │ │ └── stylegan2 │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── generator_discriminator.py │ │ │ ├── generator_discriminatorxl.py │ │ │ └── modules.py │ ├── extractors │ │ ├── __init__.py │ │ ├── feedback_hour_glass.py │ │ └── lte.py │ ├── inpaintors │ │ ├── __init__.py │ │ ├── aot_inpaintor.py │ │ ├── deepfillv1.py │ │ ├── gl_inpaintor.py │ │ ├── one_stage.py │ │ ├── pconv_inpaintor.py │ │ └── two_stage.py │ ├── losses │ │ ├── RDPix_loss.py │ │ ├── RD_loss.py │ │ ├── __init__.py │ │ ├── composition_loss.py │ │ ├── feature_loss.py │ │ ├── gan_loss.py │ │ ├── gradient_loss.py │ │ ├── lambdaD_loss.py │ │ ├── moco_loss.py │ │ ├── perceptual_loss.py │ │ ├── pixelwise_loss.py │ │ ├── utils.py │ │ └── vq_loss.py │ ├── mattors │ │ ├── __init__.py │ │ ├── base_mattor.py │ │ ├── dim.py │ │ ├── gca.py │ │ ├── indexnet.py │ │ └── utils.py │ ├── registry.py │ ├── restorers │ │ ├── __init__.py │ │ ├── basic_restorer.py │ │ ├── basicvsr.py │ │ ├── dic.py │ │ ├── dmci.py │ │ ├── edvr.py │ │ ├── esrgan.py │ │ ├── glean.py │ │ ├── liif.py │ │ ├── real_basicvsr.py │ │ ├── real_esrgan.py │ │ ├── simple_baseline.py │ │ ├── srgan.py │ │ ├── tdan.py │ │ └── ttsr.py │ ├── synthesizers │ │ ├── __init__.py │ │ ├── cycle_gan.py │ │ └── pix2pix.py │ ├── transformers │ │ ├── __init__.py │ │ └── search_transformer.py │ ├── video_interpolators │ │ ├── __init__.py │ │ ├── basic_interpolator.py │ │ └── cain.py │ └── vqgan │ │ ├── __init__.py │ │ └── femasr_model.py ├── 
utils │ ├── __init__.py │ ├── cli.py │ ├── collect_env.py │ ├── common.py │ ├── functional.py │ ├── logger.py │ ├── setup_env.py │ └── stream_helper.py └── version.py ├── scripts ├── test_script_clean.sh ├── test_script_noise.sh ├── test_script_weather.sh └── train.sh └── tools ├── data ├── generation │ ├── README.md │ ├── README_zh-CN.md │ ├── paired-pix2pix │ │ ├── README.md │ │ └── README_zh-CN.md │ └── unpaired-cyclegan │ │ ├── README.md │ │ └── README_zh-CN.md ├── inpainting │ ├── README.md │ ├── README_zh-CN.md │ ├── celeba-hq │ │ ├── README.md │ │ └── README_zh-CN.md │ ├── paris-street-view │ │ ├── README.md │ │ └── README_zh-CN.md │ └── places365 │ │ ├── README.md │ │ └── README_zh-CN.md ├── matting │ ├── README.md │ ├── README_zh-CN.md │ ├── bgm │ │ └── preprocess_bgm_dataset.py │ └── comp1k │ │ ├── README.md │ │ ├── README_zh-CN.md │ │ ├── check_extended_fg.py │ │ ├── extend_fg.py │ │ ├── filter_comp1k_anno.py │ │ └── preprocess_comp1k_dataset.py ├── super-resolution │ ├── README.md │ ├── README_zh-CN.md │ ├── df2k_ost │ │ ├── README.md │ │ ├── README_zh-CN.md │ │ └── preprocess_df2k_ost_dataset.py │ ├── div2k │ │ ├── README.md │ │ ├── README_zh-CN.md │ │ └── preprocess_div2k_dataset.py │ ├── reds │ │ ├── README.md │ │ ├── README_zh-CN.md │ │ ├── crop_sub_images.py │ │ └── preprocess_reds_dataset.py │ ├── vid4 │ │ ├── README.md │ │ └── README_zh-CN.md │ └── vimeo90k │ │ ├── README.md │ │ ├── README_zh-CN.md │ │ └── preprocess_vimeo90k_dataset.py └── video-interpolation │ ├── README.md │ ├── README_zh-CN.md │ └── vimeo90k-triplet │ ├── README.md │ └── README_zh-CN.md ├── deploy_test.py ├── deployment ├── mmedit2torchserve.py ├── mmedit_handler.py └── test_torchserver.py ├── dist_test.sh ├── dist_test_s1.sh ├── dist_train.sh ├── dist_train_iter.sh ├── dist_train_s1.sh ├── evaluate_comp1k.py ├── onnx2tensorrt.py ├── publish_model.py ├── pytorch2onnx.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py ├── test_stage1.py ├── train.py ├── train_iter.py └── train_s1.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | *.pyd 4 | .vscode/ 5 | *.bin 6 | *.png 7 | *.so 8 | *.csv 9 | *.onnx 10 | *.npz 11 | *.yuv 12 | __pycache__ 13 | pretrain_model 14 | workdirs/ 15 | *.pth 16 | *.zip 17 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaM1/All-in-one/8a79eecea4ffb694d976f66aa254229dc4ba18e2/assets/.DS_Store -------------------------------------------------------------------------------- /assets/CGA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaM1/All-in-one/8a79eecea4ffb694d976f66aa254229dc4ba18e2/assets/CGA.png -------------------------------------------------------------------------------- /assets/SDA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaM1/All-in-one/8a79eecea4ffb694d976f66aa254229dc4ba18e2/assets/SDA.png -------------------------------------------------------------------------------- /assets/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaM1/All-in-one/8a79eecea4ffb694d976f66aa254229dc4ba18e2/assets/archi.png -------------------------------------------------------------------------------- /assets/teaser.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaM1/All-in-one/8a79eecea4ffb694d976f66aa254229dc4ba18e2/assets/teaser.png -------------------------------------------------------------------------------- /configs/weather/all_in_one_L.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../base/base_rgb_weather_384.py' 3 | ] 4 | 5 | 6 | lambda_cfg = [128,2048] 7 | training_scheduling = 'normal' 8 | 9 | 10 | 11 | exp_name = 'all_in_one_L_weather' 12 | dim=48 13 | model = dict( 14 | type='DMCI', 15 | generator=dict(type='DMCI_ED_restormer_dual', dim=dim, num_blocks=[2,3,3,4], heads=[4,4,4,4], ffn_expansion_factor=2.66, 16 | bias=False, LayerNorm_type='BiasFree', N_ch=288, z_ch=4*dim,N_ch_new=True,sigmoid=False, 17 | intraEncoder="Restor_Encoderv2_pos3", version_enc="v4", 18 | intraDecoder="Restor_Decoderv2_pos3", version_dec="v4",), 19 | RD_loss=dict(type='RDLoss', loss_weight=1.0), 20 | ) 21 | resume_from = f'./workdirs/{exp_name}/iter_150000.pth' 22 | 23 | 24 | total_iters = 450000 25 | lr_config = dict( 26 | policy='CosineRestart', 27 | by_epoch=False, 28 | periods=[total_iters], 29 | restart_weights=[1], 30 | min_lr=1e-7) 31 | 32 | 33 | checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False) 34 | evaluation = dict(interval=10000, save_image=False, gpu_collect=True) 35 | dist_params = dict(backend='nccl') 36 | log_level = 'INFO' 37 | work_dir = f'./workdirs/{exp_name}' 38 | load_from = None 39 | workflow = [('train', 1)] 40 | -------------------------------------------------------------------------------- /configs/weather/all_in_one_S.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../base/base_rgb_weather_384.py' 3 | ] 4 | 5 | 6 | lambda_cfg = [128,2048] 7 | training_scheduling = 'normal' 8 | 9 | 10 | 11 | exp_name = 'all_in_one_S_weather' 12 | dim=32 13 | model = dict( 14 | type='DMCI', 15 | generator=dict(type='DMCI_ED_restormer_dual', dim=32, num_blocks=[2,3,3,4], heads=[4,4,4,4], ffn_expansion_factor=2.66, 16 | bias=False, LayerNorm_type='BiasFree', N_ch=256, z_ch=4*dim,N_ch_new=True,sigmoid=False, 17 | intraEncoder="Restor_Encoderv2_pos3", version_enc="v4", 18 | intraDecoder="Restor_Decoderv2_pos3", version_dec="v4",), 19 | RD_loss=dict(type='RDLoss', loss_weight=1.0), 20 | ) 21 | resume_from = f'./workdirs/{exp_name}/iter_150000.pth' 22 | 23 | 24 | total_iters = 450000 25 | lr_config = dict( 26 | policy='CosineRestart', 27 | by_epoch=False, 28 | periods=[total_iters], 29 | restart_weights=[1], 30 | min_lr=1e-7) 31 | 32 | 33 | checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False) 34 | evaluation = dict(interval=10000, save_image=False, gpu_collect=True) 35 | dist_params = dict(backend='nccl') 36 | log_level = 'INFO' 37 | work_dir = f'./workdirs/{exp_name}' 38 | load_from = None 39 | workflow = [('train', 1)] 40 | -------------------------------------------------------------------------------- /get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
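# Example invocation (a sketch based on the argparse setup below; the config
# path is one of this repo's configs, but any config exposing `model`,
# `train_cfg` and `test_cfg` fields should work the same way):
#
#   python get_flops.py configs/weather/all_in_one_L.py --shape 256 256
#
# Note: the script calls .cuda() on the built model, so a GPU is assumed.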
2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn.utils import get_model_complexity_info 6 | 7 | from mmedit.models import build_model 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Train a editor') 12 | parser.add_argument('config', help='train config file path') 13 | parser.add_argument( 14 | '--shape', 15 | type=int, 16 | nargs='+', 17 | default=[256,256], 18 | help='input image size') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | 25 | args = parse_args() 26 | 27 | if len(args.shape) == 1: 28 | input_shape = (3, args.shape[0], args.shape[0]) 29 | elif len(args.shape) == 2: 30 | input_shape = (3, ) + tuple(args.shape) 31 | elif len(args.shape) in [3, 4]: # 4 for video inputs (t, c, h, w) 32 | input_shape = tuple(args.shape) 33 | else: 34 | raise ValueError('invalid input shape') 35 | 36 | cfg = Config.fromfile(args.config) 37 | model = build_model( 38 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda() 39 | model.eval() 40 | model=model.generator 41 | 42 | # from ptflops import get_model_complexity_info 43 | 44 | # macs, params = get_model_complexity_info(model, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=True) 45 | # print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 46 | # print('{:<30} {:<8}'.format('Number of parameters: ', params)) 47 | 48 | 49 | # if hasattr(model, 'forward_dummy'): 50 | # model.forward = model.forward_dummy 51 | # else: 52 | # raise NotImplementedError( 53 | # 'FLOPs counter is currently not currently supported ' 54 | # f'with {model.__class__.__name__}') 55 | # breakpoint() 56 | flops, params = get_model_complexity_info(model, input_shape) 57 | 58 | split_line = '=' * 30 59 | print(f'{split_line}\nInput shape: {input_shape}\nFlops: {flops}\nParams: {params}\n{split_line}') 60 | if len(input_shape) == 4: 61 | print('!!!If your network computes N frames in one forward pass, you ' 62 | 'may want to divide the FLOPs by N to get the average FLOPs ' 63 | 'for each frame.') 64 | print('!!!Please be cautious if you use the results in papers. ' 65 | 'You may need to check if all ops are supported and verify that the ' 66 | 'flops computation is correct.') 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /mmedit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | 4 | from .version import __version__, version_info 5 | 6 | try: 7 | from mmcv.utils import digit_version 8 | except ImportError: 9 | 10 | def digit_version(version_str): 11 | digit_ver = [] 12 | for x in version_str.split('.'): 13 | if x.isdigit(): 14 | digit_ver.append(int(x)) 15 | elif x.find('rc') != -1: 16 | patch_version = x.split('rc') 17 | digit_ver.append(int(patch_version[0]) - 1) 18 | digit_ver.append(int(patch_version[1])) 19 | return digit_ver 20 | 21 | 22 | MMCV_MIN = '1.3.13' 23 | MMCV_MAX = '1.6' 24 | 25 | mmcv_min_version = digit_version(MMCV_MIN) 26 | mmcv_max_version = digit_version(MMCV_MAX) 27 | mmcv_version = digit_version(mmcv.__version__) 28 | 29 | 30 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 31 | f'mmcv=={mmcv.__version__} is used but incompatible. ' \ 32 | f'Please install mmcv-full>={mmcv_min_version}, <={mmcv_max_version}.' 
33 | 34 | __all__ = ['__version__', 'version_info'] 35 | -------------------------------------------------------------------------------- /mmedit/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generation_inference import generation_inference 3 | from .inpainting_inference import inpainting_inference 4 | from .matting_inference import init_model, matting_inference 5 | from .restoration_face_inference import restoration_face_inference 6 | from .restoration_inference import restoration_inference 7 | from .restoration_video_inference import restoration_video_inference 8 | from .test import multi_gpu_test, single_gpu_test 9 | from .test_stage1 import multi_gpu_test_s1, single_gpu_test_s1 10 | from .train_epoch import init_random_seed, set_random_seed, train_model 11 | from .train_iter import train_model_iter 12 | from .train_s1 import train_model_s1 13 | from .video_interpolation_inference import video_interpolation_inference 14 | 15 | __all__ = [ 16 | 'train_model', 'set_random_seed', 'init_model', 'matting_inference', 17 | 'inpainting_inference', 'restoration_inference', 'generation_inference', 18 | 'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference', 19 | 'restoration_face_inference', 'video_interpolation_inference', 20 | 'init_random_seed','train_model_iter','multi_gpu_test_s1', 'single_gpu_test_s1', 21 | 'train_model_s1' 22 | ] 23 | -------------------------------------------------------------------------------- /mmedit/apis/generation_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | 6 | from mmedit.core import tensor2img 7 | from mmedit.datasets.pipelines import Compose 8 | 9 | 10 | def generation_inference(model, img, img_unpaired=None): 11 | """Inference image with the model. 12 | 13 | Args: 14 | model (nn.Module): The loaded model. 15 | img (str): File path of input image. 16 | img_unpaired (str, optional): File path of the unpaired image. 17 | If not None, perform unpaired image generation. Default: None. 18 | 19 | Returns: 20 | np.ndarray: The predicted generation result. 
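    Example (illustrative only; the config/checkpoint paths are placeholders,
    not files shipped with this repo):

    >>> from mmedit.apis import init_model, generation_inference
    >>> model = init_model('pix2pix_config.py', 'pix2pix_ckpt.pth')
    >>> fake_b = generation_inference(model, 'demo/paired_ab.jpg')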
21 | """ 22 | cfg = model.cfg 23 | device = next(model.parameters()).device # model device 24 | # build the data pipeline 25 | test_pipeline = Compose(cfg.test_pipeline) 26 | # prepare data 27 | if img_unpaired is None: 28 | data = dict(pair_path=img) 29 | else: 30 | data = dict(img_a_path=img, img_b_path=img_unpaired) 31 | data = test_pipeline(data) 32 | data = collate([data], samples_per_gpu=1) 33 | if 'cuda' in str(device): 34 | data = scatter(data, [device])[0] 35 | # forward the model 36 | with torch.no_grad(): 37 | results = model(test_mode=True, **data) 38 | # process generation shown mode 39 | if img_unpaired is None: 40 | if model.show_input: 41 | output = np.concatenate([ 42 | tensor2img(results['real_a'], min_max=(-1, 1)), 43 | tensor2img(results['fake_b'], min_max=(-1, 1)), 44 | tensor2img(results['real_b'], min_max=(-1, 1)) 45 | ], 46 | axis=1) 47 | else: 48 | output = tensor2img(results['fake_b'], min_max=(-1, 1)) 49 | else: 50 | if model.show_input: 51 | output = np.concatenate([ 52 | tensor2img(results['real_a'], min_max=(-1, 1)), 53 | tensor2img(results['fake_b'], min_max=(-1, 1)), 54 | tensor2img(results['real_b'], min_max=(-1, 1)), 55 | tensor2img(results['fake_a'], min_max=(-1, 1)) 56 | ], 57 | axis=1) 58 | else: 59 | if model.test_direction == 'a2b': 60 | output = tensor2img(results['fake_b'], min_max=(-1, 1)) 61 | else: 62 | output = tensor2img(results['fake_a'], min_max=(-1, 1)) 63 | return output 64 | -------------------------------------------------------------------------------- /mmedit/apis/inpainting_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.parallel import collate, scatter 4 | 5 | from mmedit.datasets.pipelines import Compose 6 | 7 | 8 | def inpainting_inference(model, masked_img, mask): 9 | """Inference image with the model. 10 | 11 | Args: 12 | model (nn.Module): The loaded model. 13 | masked_img (str): File path of image with mask. 14 | mask (str): Mask file path. 15 | 16 | Returns: 17 | Tensor: The predicted inpainting result. 18 | """ 19 | device = next(model.parameters()).device # model device 20 | 21 | infer_pipeline = [ 22 | dict(type='LoadImageFromFile', key='masked_img'), 23 | dict(type='LoadMask', mask_mode='file', mask_config=dict()), 24 | dict(type='Pad', keys=['masked_img', 'mask'], mode='reflect'), 25 | dict( 26 | type='Normalize', 27 | keys=['masked_img'], 28 | mean=[127.5] * 3, 29 | std=[127.5] * 3, 30 | to_rgb=False), 31 | dict(type='GetMaskedImage', img_name='masked_img'), 32 | dict( 33 | type='Collect', 34 | keys=['masked_img', 'mask'], 35 | meta_keys=['masked_img_path']), 36 | dict(type='ImageToTensor', keys=['masked_img', 'mask']) 37 | ] 38 | 39 | # build the data pipeline 40 | test_pipeline = Compose(infer_pipeline) 41 | # prepare data 42 | data = dict(masked_img_path=masked_img, mask_path=mask) 43 | data = test_pipeline(data) 44 | data = collate([data], samples_per_gpu=1) 45 | if 'cuda' in str(device): 46 | data = scatter(data, [device])[0] 47 | else: 48 | data.pop('meta') 49 | # forward the model 50 | with torch.no_grad(): 51 | result = model(test_mode=True, **data) 52 | 53 | return result['fake_img'] 54 | -------------------------------------------------------------------------------- /mmedit/apis/matting_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
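# Example usage (a sketch; config, checkpoint and image paths are placeholders):
#
#   from mmedit.apis import init_model, matting_inference
#   model = init_model('gca_config.py', 'gca_ckpt.pth', device='cuda:0')
#   pred_alpha = matting_inference(model, 'demo/merged.png', 'demo/trimap.png')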
2 | import mmcv 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | from mmcv.runner import load_checkpoint 6 | 7 | from mmedit.datasets.pipelines import Compose 8 | from mmedit.models import build_model 9 | 10 | 11 | def init_model(config, checkpoint=None, device='cuda:0'): 12 | """Initialize a model from config file. 13 | 14 | Args: 15 | config (str or :obj:`mmcv.Config`): Config file path or the config 16 | object. 17 | checkpoint (str, optional): Checkpoint path. If left as None, the model 18 | will not load any weights. 19 | device (str): Which device the model will deploy. Default: 'cuda:0'. 20 | 21 | Returns: 22 | nn.Module: The constructed model. 23 | """ 24 | if isinstance(config, str): 25 | config = mmcv.Config.fromfile(config) 26 | elif not isinstance(config, mmcv.Config): 27 | raise TypeError('config must be a filename or Config object, ' 28 | f'but got {type(config)}') 29 | config.model.pretrained = None 30 | config.test_cfg.metrics = None 31 | model = build_model(config.model, test_cfg=config.test_cfg) 32 | if checkpoint is not None: 33 | checkpoint = load_checkpoint(model, checkpoint) 34 | 35 | model.cfg = config # save the config in the model for convenience 36 | model.to(device) 37 | model.eval() 38 | return model 39 | 40 | 41 | def matting_inference(model, img, trimap): 42 | """Inference image(s) with the model. 43 | 44 | Args: 45 | model (nn.Module): The loaded model. 46 | img (str): Image file path. 47 | trimap (str): Trimap file path. 48 | 49 | Returns: 50 | np.ndarray: The predicted alpha matte. 51 | """ 52 | cfg = model.cfg 53 | device = next(model.parameters()).device # model device 54 | # remove alpha from test_pipeline 55 | keys_to_remove = ['alpha', 'ori_alpha'] 56 | for key in keys_to_remove: 57 | for pipeline in list(cfg.test_pipeline): 58 | if 'key' in pipeline and key == pipeline['key']: 59 | cfg.test_pipeline.remove(pipeline) 60 | if 'keys' in pipeline and key in pipeline['keys']: 61 | pipeline['keys'].remove(key) 62 | if len(pipeline['keys']) == 0: 63 | cfg.test_pipeline.remove(pipeline) 64 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']: 65 | pipeline['meta_keys'].remove(key) 66 | # build the data pipeline 67 | test_pipeline = Compose(cfg.test_pipeline) 68 | # prepare data 69 | data = dict(merged_path=img, trimap_path=trimap) 70 | data = test_pipeline(data) 71 | data = collate([data], samples_per_gpu=1) 72 | if 'cuda' in str(device): 73 | data = scatter(data, [device])[0] 74 | # forward the model 75 | with torch.no_grad(): 76 | result = model(test_mode=True, **data) 77 | 78 | return result['pred_alpha'] 79 | -------------------------------------------------------------------------------- /mmedit/apis/restoration_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.parallel import collate, scatter 4 | 5 | from mmedit.datasets.pipelines import Compose 6 | 7 | 8 | def restoration_inference(model, img, ref=None): 9 | """Inference image with the model. 10 | 11 | Args: 12 | model (nn.Module): The loaded model. 13 | img (str): File path of input image. 14 | ref (str | None): File path of reference image. Default: None. 15 | 16 | Returns: 17 | Tensor: The predicted restoration result. 
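    Example (illustrative; the checkpoint name is a placeholder, and it is
    assumed the restorer follows the standard ``result['output']`` interface):

    >>> from mmedit.apis import init_model, restoration_inference
    >>> model = init_model('configs/weather/all_in_one_S.py', 'some_ckpt.pth')
    >>> output = restoration_inference(model, 'demo/degraded.png')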
18 | """ 19 | cfg = model.cfg 20 | device = next(model.parameters()).device # model device 21 | # remove gt from test_pipeline 22 | keys_to_remove = ['gt', 'gt_path'] 23 | for key in keys_to_remove: 24 | for pipeline in list(cfg.test_pipeline): 25 | if 'key' in pipeline and key == pipeline['key']: 26 | cfg.test_pipeline.remove(pipeline) 27 | if 'keys' in pipeline and key in pipeline['keys']: 28 | pipeline['keys'].remove(key) 29 | if len(pipeline['keys']) == 0: 30 | cfg.test_pipeline.remove(pipeline) 31 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']: 32 | pipeline['meta_keys'].remove(key) 33 | # build the data pipeline 34 | test_pipeline = Compose(cfg.test_pipeline) 35 | # prepare data 36 | if ref: # Ref-SR 37 | data = dict(lq_path=img, ref_path=ref) 38 | else: # SISR 39 | data = dict(lq_path=img) 40 | data = test_pipeline(data) 41 | data = collate([data], samples_per_gpu=1) 42 | if 'cuda' in str(device): 43 | data = scatter(data, [device])[0] 44 | # forward the model 45 | with torch.no_grad(): 46 | result = model(test_mode=True, **data) 47 | 48 | return result['output'] 49 | -------------------------------------------------------------------------------- /mmedit/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .evaluation import (DistEvalIterHook, EvalIterHook, L1Evaluation, DistEvalIterHook_iter, DistEvalIterHook_s1, EvalIterHook_s1, mae, 3 | mse, psnr, reorder_image, sad, ssim) 4 | from .hooks import VisualizationHook 5 | from .misc import tensor2img 6 | from .optimizer import build_optimizers 7 | from .scheduler import LinearLrUpdaterHook, ReduceLrUpdaterHook 8 | 9 | __all__ = [ 10 | 'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook', 11 | 'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook', 12 | 'VisualizationHook', 'L1Evaluation', 'ReduceLrUpdaterHook', 'mae','DistEvalIterHook_iter', 13 | 'DistEvalIterHook_s1', 'EvalIterHook_s1' 14 | ] 15 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .eval_hooks import DistEvalIterHook, EvalIterHook,DistEvalIterHook_iter 3 | from .metrics import (L1Evaluation, connectivity, gradient_error, mae, mse, 4 | niqe, psnr, reorder_image, sad, ssim) 5 | from .eval_hooks_s1 import DistEvalIterHook_s1, EvalIterHook_s1 6 | __all__ = [ 7 | 'mse', 'sad', 'psnr', 'reorder_image', 'ssim', 'EvalIterHook', 8 | 'DistEvalIterHook', 'L1Evaluation', 'gradient_error', 'connectivity', 9 | 'niqe', 'mae','DistEvalIterHook_iter', 'DistEvalIterHook_s1', 'EvalIterHook_s1' 10 | ] 11 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def gaussian(x, sigma): 7 | """Gaussian function. 8 | 9 | Args: 10 | x (array_like): The independent variable. 11 | sigma (float): Standard deviation of the gaussian function. 12 | 13 | Return: 14 | ndarray or scalar: Gaussian value of `x`. 15 | """ 16 | return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) 17 | 18 | 19 | def dgaussian(x, sigma): 20 | """Gradient of gaussian. 
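    Analytically this is d/dx of ``gaussian(x, sigma)``, i.e.
    ``-x / sigma**2 * gaussian(x, sigma)``, which is exactly what the return
    statement below evaluates.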
21 | 22 | Args: 23 | x (array_like): The independent variable. 24 | sigma (float): Standard deviation of the gaussian function. 25 | 26 | Return: 27 | ndarray or scalar: Gradient of gaussian of `x`. 28 | """ 29 | return -x * gaussian(x, sigma) / sigma**2 30 | 31 | 32 | def gauss_filter(sigma, epsilon=1e-2): 33 | """Gradient of gaussian. 34 | 35 | Args: 36 | sigma (float): Standard deviation of the gaussian kernel. 37 | epsilon (float): Small value used when calculating kernel size. 38 | Default: 1e-2. 39 | 40 | Return: 41 | tuple[ndarray]: Gaussian filter along x and y axis. 42 | """ 43 | half_size = np.ceil( 44 | sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) 45 | size = int(2 * half_size + 1) 46 | 47 | # create filter in x axis 48 | filter_x = np.zeros((size, size)) 49 | for i in range(size): 50 | for j in range(size): 51 | filter_x[i, j] = gaussian(i - half_size, sigma) * dgaussian( 52 | j - half_size, sigma) 53 | 54 | # normalize filter 55 | norm = np.sqrt((filter_x**2).sum()) 56 | filter_x = filter_x / norm 57 | filter_y = np.transpose(filter_x) 58 | 59 | return filter_x, filter_y 60 | 61 | 62 | def gauss_gradient(img, sigma): 63 | """Gaussian gradient. 64 | 65 | From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ 66 | submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ 67 | index.html 68 | 69 | Args: 70 | img (ndarray): Input image. 71 | sigma (float): Standard deviation of the gaussian kernel. 72 | 73 | Return: 74 | ndarray: Gaussian gradient of input `img`. 75 | """ 76 | filter_x, filter_y = gauss_filter(sigma) 77 | img_filtered_x = cv2.filter2D( 78 | img, -1, filter_x, borderType=cv2.BORDER_REPLICATE) 79 | img_filtered_y = cv2.filter2D( 80 | img, -1, filter_y, borderType=cv2.BORDER_REPLICATE) 81 | return np.sqrt(img_filtered_x**2 + img_filtered_y**2) 82 | -------------------------------------------------------------------------------- /mmedit/core/export/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .wrappers import ONNXRuntimeEditing 3 | 4 | __all__ = ['ONNXRuntimeEditing'] 5 | -------------------------------------------------------------------------------- /mmedit/core/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ema import ExponentialMovingAverageHook 3 | from .visualization import VisualizationHook 4 | 5 | __all__ = ['VisualizationHook', 'ExponentialMovingAverageHook'] 6 | -------------------------------------------------------------------------------- /mmedit/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_optimizers 3 | 4 | __all__ = ['build_optimizers'] 5 | -------------------------------------------------------------------------------- /mmedit/core/optimizer/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.runner import build_optimizer 3 | 4 | 5 | def build_optimizers(model, cfgs): 6 | """Build multiple optimizers from configs. 7 | 8 | If `cfgs` contains several dicts for optimizers, then a dict for each 9 | constructed optimizers will be returned. 
10 | If `cfgs` only contains one optimizer config, the constructed optimizer 11 | itself will be returned. 12 | 13 | For example, 14 | 15 | 1) Multiple optimizer configs: 16 | 17 | .. code-block:: python 18 | 19 | optimizer_cfg = dict( 20 | model1=dict(type='SGD', lr=lr), 21 | model2=dict(type='SGD', lr=lr)) 22 | 23 | The return dict is 24 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` 25 | 26 | 2) Single optimizer config: 27 | 28 | .. code-block:: python 29 | 30 | optimizer_cfg = dict(type='SGD', lr=lr) 31 | 32 | The return is ``torch.optim.Optimizer``. 33 | 34 | Args: 35 | model (:obj:`nn.Module`): The model with parameters to be optimized. 36 | cfgs (dict): The config dict of the optimizer. 37 | 38 | Returns: 39 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: 40 | The initialized optimizers. 41 | """ 42 | optimizers = {} 43 | if hasattr(model, 'module'): 44 | model = model.module 45 | # determine whether 'cfgs' has several dicts for optimizers 46 | is_dict_of_dict = True 47 | for key, cfg in cfgs.items(): 48 | if not isinstance(cfg, dict): 49 | is_dict_of_dict = False 50 | 51 | if is_dict_of_dict: 52 | for key, cfg in cfgs.items(): 53 | cfg_ = cfg.copy() 54 | module = getattr(model, key) 55 | optimizers[key] = build_optimizer(module, cfg_) 56 | return optimizers 57 | 58 | return build_optimizer(model, cfgs) 59 | -------------------------------------------------------------------------------- /mmedit/core/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .lr_updater import LinearLrUpdaterHook, ReduceLrUpdaterHook 3 | 4 | __all__ = ['LinearLrUpdaterHook', 'ReduceLrUpdaterHook'] 5 | -------------------------------------------------------------------------------- /mmedit/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import sync_random_seed 3 | 4 | __all__ = ['sync_random_seed'] 5 | -------------------------------------------------------------------------------- /mmedit/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def sync_random_seed(seed=None, device='cuda'): 9 | """Make sure different ranks share the same seed. 10 | All workers must call this function, otherwise it will deadlock. 11 | This method is generally used in `DistributedSampler`, 12 | because the seed should be identical across all processes 13 | in the distributed group. 14 | Args: 15 | seed (int, Optional): The seed. Default to None. 16 | device (str): The device where the seed will be put on. 17 | Default to 'cuda'. 18 | Returns: 19 | int: Seed to be used. 
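    Example (sketch; assumes the default process group is already initialized
    whenever world_size > 1):

    >>> seed = sync_random_seed()  # rank 0 draws the seed, other ranks receive it
    >>> # the returned int can then be passed to a seeding helper such as
    >>> # mmedit.apis.set_random_seed(seed)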
20 | """ 21 | if seed is None: 22 | seed = np.random.randint(2**31) 23 | assert isinstance(seed, int) 24 | 25 | rank, world_size = get_dist_info() 26 | 27 | if world_size == 1: 28 | return seed 29 | 30 | if rank == 0: 31 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 32 | else: 33 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 34 | dist.broadcast(random_num, src=0) 35 | return random_num.item() 36 | -------------------------------------------------------------------------------- /mmedit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_dataset import BaseDataset 3 | from .base_generation_dataset import BaseGenerationDataset 4 | from .base_matting_dataset import BaseMattingDataset 5 | from .base_sr_dataset import BaseSRDataset 6 | from .base_vfi_dataset import BaseVFIDataset 7 | from .builder import build_dataloader, build_dataset 8 | from .comp1k_dataset import AdobeComp1kDataset 9 | from .dataset_wrappers import RepeatDataset 10 | from .generation_paired_dataset import GenerationPairedDataset 11 | from .generation_unpaired_dataset import GenerationUnpairedDataset 12 | from .img_inpainting_dataset import ImgInpaintingDataset 13 | from .registry import DATASETS, PIPELINES 14 | from .sr_annotation_dataset import SRAnnotationDataset 15 | from .sr_facial_landmark_dataset import SRFacialLandmarkDataset 16 | from .sr_folder_dataset import SRFolderDataset 17 | from .sr_folder_gt_dataset import SRFolderGTDataset,GTListDataset 18 | from .sr_folder_multiple_gt_dataset import SRFolderMultipleGTDataset 19 | from .sr_folder_ref_dataset import SRFolderRefDataset 20 | from .sr_folder_video_dataset import SRFolderVideoDataset 21 | from .sr_lmdb_dataset import SRLmdbDataset 22 | from .sr_reds_dataset import SRREDSDataset 23 | from .sr_reds_multiple_gt_dataset import SRREDSMultipleGTDataset 24 | from .sr_test_multiple_gt_dataset import SRTestMultipleGTDataset 25 | from .sr_vid4_dataset import SRVid4Dataset 26 | from .sr_vimeo90k_dataset import SRVimeo90KDataset 27 | from .sr_vimeo90k_multiple_gt_dataset import SRVimeo90KMultipleGTDataset 28 | from .vfi_vimeo90k_7frames_dataset import VFIVimeo90K7FramesDataset 29 | from .vfi_vimeo90k_dataset import VFIVimeo90KDataset 30 | 31 | __all__ = [ 32 | 'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader', 33 | 'BaseDataset', 'BaseMattingDataset', 'ImgInpaintingDataset', 34 | 'AdobeComp1kDataset', 'SRLmdbDataset', 'SRFolderDataset', 35 | 'SRAnnotationDataset', 'BaseSRDataset', 'RepeatDataset', 'SRREDSDataset', 36 | 'SRVimeo90KDataset', 'BaseGenerationDataset', 'GenerationPairedDataset', 37 | 'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset', 38 | 'SRREDSMultipleGTDataset', 'SRVimeo90KMultipleGTDataset', 39 | 'SRTestMultipleGTDataset', 'SRFolderRefDataset', 'SRFacialLandmarkDataset', 40 | 'SRFolderMultipleGTDataset', 'SRFolderVideoDataset', 'BaseVFIDataset', 41 | 'VFIVimeo90KDataset', 'VFIVimeo90K7FramesDataset','GTListDataset' 42 | ] 43 | -------------------------------------------------------------------------------- /mmedit/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
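# Minimal subclass sketch (illustrative, not a class shipped in this repo):
# a subclass only has to fill ``self.data_infos`` inside ``load_annotations``;
# ``__getitem__`` then routes every index through ``self.pipeline``.
#
#   class MyListDataset(BaseDataset):
#       def __init__(self, ann_file, pipeline, test_mode=False):
#           super().__init__(pipeline, test_mode)
#           self.ann_file = ann_file
#           self.data_infos = self.load_annotations()
#
#       def load_annotations(self):
#           with open(self.ann_file) as f:
#               return [dict(gt_path=line.strip()) for line in f if line.strip()]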
2 | import copy 3 | from abc import ABCMeta, abstractmethod 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from .pipelines import Compose 8 | 9 | 10 | class BaseDataset(Dataset, metaclass=ABCMeta): 11 | """Base class for datasets. 12 | 13 | All datasets should subclass it. 14 | All subclasses should overwrite: 15 | 16 | ``load_annotations``, supporting to load information and generate 17 | image lists. 18 | 19 | Args: 20 | pipeline (list[dict | callable]): A sequence of data transforms. 21 | test_mode (bool): If True, the dataset will work in test mode. 22 | Otherwise, in train mode. 23 | """ 24 | 25 | def __init__(self, pipeline, test_mode=False): 26 | super().__init__() 27 | self.test_mode = test_mode 28 | self.pipeline = Compose(pipeline) 29 | 30 | @abstractmethod 31 | def load_annotations(self): 32 | """Abstract function for loading annotation. 33 | 34 | All subclasses should overwrite this function 35 | """ 36 | 37 | def prepare_train_data(self, idx): 38 | """Prepare training data. 39 | 40 | Args: 41 | idx (int): Index of the training batch data. 42 | 43 | Returns: 44 | dict: Returned training batch. 45 | """ 46 | results = copy.deepcopy(self.data_infos[idx]) 47 | return self.pipeline(results) 48 | 49 | def prepare_test_data(self, idx): 50 | """Prepare testing data. 51 | 52 | Args: 53 | idx (int): Index for getting each testing batch. 54 | 55 | Returns: 56 | Tensor: Returned testing batch. 57 | """ 58 | results = copy.deepcopy(self.data_infos[idx]) 59 | return self.pipeline(results) 60 | 61 | def __len__(self): 62 | """Length of the dataset. 63 | 64 | Returns: 65 | int: Length of the dataset. 66 | """ 67 | return len(self.data_infos) 68 | 69 | def __getitem__(self, idx): 70 | """Get item at each call. 71 | 72 | Args: 73 | idx (int): Index for getting each item. 74 | """ 75 | if self.test_mode: 76 | return self.prepare_test_data(idx) 77 | 78 | return self.prepare_train_data(idx) 79 | -------------------------------------------------------------------------------- /mmedit/datasets/base_generation_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | from pathlib import Path 4 | 5 | from mmcv import scandir 6 | 7 | from .base_dataset import BaseDataset 8 | 9 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 10 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF') 11 | 12 | 13 | class BaseGenerationDataset(BaseDataset): 14 | """Base class for generation datasets.""" 15 | 16 | @staticmethod 17 | def scan_folder(path): 18 | """Obtain image path list (including sub-folders) from a given folder. 19 | 20 | Args: 21 | path (str | :obj:`Path`): Folder path. 22 | 23 | Returns: 24 | list[str]: Image list obtained from the given folder. 25 | """ 26 | 27 | if isinstance(path, (str, Path)): 28 | path = str(path) 29 | else: 30 | raise TypeError("'path' must be a str or a Path object, " 31 | f'but received {type(path)}.') 32 | 33 | images = scandir(path, suffix=IMG_EXTENSIONS, recursive=True) 34 | images = [osp.join(path, v) for v in images] 35 | assert images, f'{path} has no valid image file.' 36 | return images 37 | 38 | def evaluate(self, results, logger=None): 39 | """Evaluating with saving generated images. (needs no metrics) 40 | 41 | Args: 42 | results (list[tuple]): The output of forward_test() of the model. 43 | 44 | Return: 45 | dict: Evaluation results dict. 
46 | """ 47 | if not isinstance(results, list): 48 | raise TypeError(f'results must be a list, but got {type(results)}') 49 | assert len(results) == len(self), ( 50 | 'The length of results is not equal to the dataset len: ' 51 | f'{len(results)} != {len(self)}') 52 | 53 | results = [res['saved_flag'] for res in results] 54 | saved_num = 0 55 | for flag in results: 56 | if flag: 57 | saved_num += 1 58 | 59 | # make a dict to show 60 | eval_result = {'val_saved_number': saved_num} 61 | 62 | return eval_result 63 | -------------------------------------------------------------------------------- /mmedit/datasets/base_matting_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import defaultdict 3 | 4 | from .base_dataset import BaseDataset 5 | from .registry import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class BaseMattingDataset(BaseDataset): 10 | """Base image matting dataset. 11 | """ 12 | 13 | def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False): 14 | super().__init__(pipeline, test_mode) 15 | self.ann_file = str(ann_file) 16 | self.data_prefix = str(data_prefix) 17 | self.data_infos = self.load_annotations() 18 | 19 | def evaluate(self, results, logger=None): 20 | """Evaluating with different metrics. 21 | 22 | Args: 23 | results (list[tuple]): The output of forward_test() of the model. 24 | 25 | Return: 26 | dict: Evaluation results dict. 27 | """ 28 | if not isinstance(results, list): 29 | raise TypeError(f'results must be a list, but got {type(results)}') 30 | assert len(results) == len(self), ( 31 | 'The length of results is not equal to the ' 32 | f'dataset len: {len(results)} != {len(self)}') 33 | 34 | results = [res['eval_result'] for res in results] # a list of dict 35 | 36 | eval_result = defaultdict(list) # a dict of list 37 | for res in results: 38 | for metric, val in res.items(): 39 | eval_result[metric].append(val) 40 | for metric, val_list in eval_result.items(): 41 | assert len(val_list) == len(self), ( 42 | f'Length of evaluation result of {metric} is {len(val_list)}, ' 43 | f'should be {len(self)}') 44 | 45 | # average the results 46 | eval_result = { 47 | metric: sum(values) / len(self) 48 | for metric, values in eval_result.items() 49 | } 50 | 51 | return eval_result 52 | -------------------------------------------------------------------------------- /mmedit/datasets/base_sr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import os.path as osp 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | from mmcv import scandir 8 | 9 | from .base_dataset import BaseDataset 10 | 11 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 12 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF') 13 | 14 | 15 | class BaseSRDataset(BaseDataset): 16 | """Base class for super resolution datasets. 17 | """ 18 | 19 | def __init__(self, pipeline, scale, test_mode=False): 20 | super().__init__(pipeline, test_mode) 21 | self.scale = scale 22 | 23 | @staticmethod 24 | def scan_folder(path): 25 | """Obtain image path list (including sub-folders) from a given folder. 26 | 27 | Args: 28 | path (str | :obj:`Path`): Folder path. 29 | 30 | Returns: 31 | list[str]: image list obtained form given folder. 
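        Example (illustrative folder path):

        >>> paths = BaseSRDataset.scan_folder('data/all-in-one-train/gt')
        >>> # each entry is the folder path joined with a discovered image name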
32 | """ 33 | 34 | if isinstance(path, (str, Path)): 35 | path = str(path) 36 | else: 37 | raise TypeError("'path' must be a str or a Path object, " 38 | f'but received {type(path)}.') 39 | 40 | images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True)) 41 | images = [osp.join(path, v) for v in images] 42 | assert images, f'{path} has no valid image file.' 43 | return images 44 | 45 | def __getitem__(self, idx): 46 | """Get item at each call. 47 | 48 | Args: 49 | idx (int): Index for getting each item. 50 | """ 51 | results = copy.deepcopy(self.data_infos[idx]) 52 | results['scale'] = self.scale 53 | return self.pipeline(results) 54 | 55 | def evaluate(self, results, logger=None): 56 | """Evaluate with different metrics. 57 | 58 | Args: 59 | results (list[tuple]): The output of forward_test() of the model. 60 | 61 | Return: 62 | dict: Evaluation results dict. 63 | """ 64 | if not isinstance(results, list): 65 | raise TypeError(f'results must be a list, but got {type(results)}') 66 | assert len(results) == len(self), ( 67 | 'The length of results is not equal to the dataset len: ' 68 | f'{len(results)} != {len(self)}') 69 | 70 | results = [res['eval_result'] for res in results] # a list of dict 71 | eval_result = defaultdict(list) # a dict of list 72 | 73 | for res in results: 74 | for metric, val in res.items(): 75 | eval_result[metric].append(val) 76 | for metric, val_list in eval_result.items(): 77 | assert len(val_list) == len(self), ( 78 | f'Length of evaluation result of {metric} is {len(val_list)}, ' 79 | f'should be {len(self)}') 80 | 81 | # average the results 82 | eval_result = { 83 | metric: sum(values) / len(self) 84 | for metric, values in eval_result.items() 85 | } 86 | 87 | return eval_result 88 | -------------------------------------------------------------------------------- /mmedit/datasets/base_vfi_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | from collections import defaultdict 4 | 5 | from .base_dataset import BaseDataset 6 | 7 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 8 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF') 9 | 10 | 11 | class BaseVFIDataset(BaseDataset): 12 | """Base class for video frame interpolation datasets. 13 | """ 14 | 15 | def __init__(self, pipeline, folder, ann_file, test_mode=False): 16 | super().__init__(pipeline, test_mode) 17 | self.folder = str(folder) 18 | self.ann_file = str(ann_file) 19 | 20 | def __getitem__(self, idx): 21 | """Get item at each call. 22 | 23 | Args: 24 | idx (int): Index for getting each item. 25 | """ 26 | results = copy.deepcopy(self.data_infos[idx]) 27 | results['folder'] = self.folder 28 | results['ann_file'] = self.ann_file 29 | return self.pipeline(results) 30 | 31 | def evaluate(self, results, logger=None): 32 | """Evaluate with different metrics. 33 | 34 | Args: 35 | results (list[tuple]): The output of forward_test() of the model. 36 | 37 | Return: 38 | dict: Evaluation results dict. 
39 | """ 40 | if not isinstance(results, list): 41 | raise TypeError(f'results must be a list, but got {type(results)}') 42 | assert len(results) == len(self), ( 43 | 'The length of results is not equal to the dataset len: ' 44 | f'{len(results)} != {len(self)}') 45 | 46 | results = [res['eval_result'] for res in results] # a list of dict 47 | eval_result = defaultdict(list) # a dict of list 48 | 49 | for res in results: 50 | for metric, val in res.items(): 51 | eval_result[metric].append(val) 52 | for metric, val_list in eval_result.items(): 53 | assert len(val_list) == len(self), ( 54 | f'Length of evaluation result of {metric} is {len(val_list)}, ' 55 | f'should be {len(self)}') 56 | 57 | # average the results 58 | eval_result = { 59 | metric: sum(values) / len(self) 60 | for metric, values in eval_result.items() 61 | } 62 | 63 | return eval_result 64 | 65 | def load_annotations(self): 66 | """Abstract function for loading annotation. 67 | 68 | All subclasses should overwrite this function 69 | """ 70 | pass 71 | -------------------------------------------------------------------------------- /mmedit/datasets/comp1k_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmcv 5 | 6 | from .base_matting_dataset import BaseMattingDataset 7 | from .registry import DATASETS 8 | 9 | 10 | @DATASETS.register_module() 11 | class AdobeComp1kDataset(BaseMattingDataset): 12 | """Adobe composition-1k dataset. 13 | 14 | The dataset loads (alpha, fg, bg) data and apply specified transforms to 15 | the data. You could specify whether composite merged image online or load 16 | composited merged image in pipeline. 17 | 18 | Example for online comp-1k dataset: 19 | 20 | :: 21 | 22 | [ 23 | { 24 | "alpha": 'alpha/000.png', 25 | "fg": 'fg/000.png', 26 | "bg": 'bg/000.png' 27 | }, 28 | { 29 | "alpha": 'alpha/001.png', 30 | "fg": 'fg/001.png', 31 | "bg": 'bg/001.png' 32 | }, 33 | ] 34 | 35 | Example for offline comp-1k dataset: 36 | 37 | :: 38 | 39 | [ 40 | { 41 | "alpha": 'alpha/000.png', 42 | "merged": 'merged/000.png', 43 | "fg": 'fg/000.png', 44 | "bg": 'bg/000.png' 45 | }, 46 | { 47 | "alpha": 'alpha/001.png', 48 | "merged": 'merged/001.png', 49 | "fg": 'fg/001.png', 50 | "bg": 'bg/001.png' 51 | }, 52 | ] 53 | 54 | """ 55 | 56 | def load_annotations(self): 57 | """Load annotations for Adobe Composition-1k dataset. 58 | 59 | It loads image paths from json file. 60 | 61 | Returns: 62 | dict: Loaded dict. 63 | """ 64 | data_infos = mmcv.load(self.ann_file) 65 | 66 | for data_info in data_infos: 67 | for key in data_info: 68 | data_info[key] = osp.join(self.data_prefix, data_info[key]) 69 | 70 | return data_infos 71 | -------------------------------------------------------------------------------- /mmedit/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .registry import DATASETS 3 | 4 | 5 | @DATASETS.register_module() 6 | class RepeatDataset: 7 | """A wrapper of repeated dataset. 8 | 9 | The length of repeated dataset will be `times` larger than the original 10 | dataset. This is useful when the data loading time is long but the dataset 11 | is small. Using RepeatDataset can reduce the data loading time between 12 | epochs. 13 | 14 | Args: 15 | dataset (:obj:`Dataset`): The dataset to be repeated. 16 | times (int): Repeat times. 
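    Example (sketch, with ``base_dataset`` being any already-constructed
    dataset instance):

    >>> repeated = RepeatDataset(base_dataset, times=100)
    >>> len(repeated) == 100 * len(base_dataset)
    True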
17 | """ 18 | 19 | def __init__(self, dataset, times): 20 | self.dataset = dataset 21 | self.times = times 22 | 23 | self._ori_len = len(self.dataset) 24 | 25 | def __getitem__(self, idx): 26 | """Get item at each call. 27 | 28 | Args: 29 | idx (int): Index for getting each item. 30 | """ 31 | return self.dataset[idx % self._ori_len] 32 | 33 | def __len__(self): 34 | """Length of the dataset. 35 | 36 | Returns: 37 | int: Length of the dataset. 38 | """ 39 | return self.times * self._ori_len 40 | -------------------------------------------------------------------------------- /mmedit/datasets/generation_paired_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .base_generation_dataset import BaseGenerationDataset 5 | from .registry import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class GenerationPairedDataset(BaseGenerationDataset): 10 | """General paired image folder dataset for image generation. 11 | 12 | It assumes that the training directory is '/path/to/data/train'. 13 | During test time, the directory is '/path/to/data/test'. '/path/to/data' 14 | can be initialized by args 'dataroot'. Each sample contains a pair of 15 | images concatenated in the w dimension (A|B). 16 | 17 | Args: 18 | dataroot (str | :obj:`Path`): Path to the folder root of paired images. 19 | pipeline (List[dict | callable]): A sequence of data transformations. 20 | test_mode (bool): Store `True` when building test dataset. 21 | Default: `False`. 22 | """ 23 | 24 | def __init__(self, dataroot, pipeline, test_mode=False): 25 | super().__init__(pipeline, test_mode) 26 | phase = 'test' if test_mode else 'train' 27 | self.dataroot = osp.join(str(dataroot), phase) 28 | self.data_infos = self.load_annotations() 29 | 30 | def load_annotations(self): 31 | """Load paired image paths. 32 | 33 | Returns: 34 | list[dict]: List that contains paired image paths. 35 | """ 36 | data_infos = [] 37 | pair_paths = sorted(self.scan_folder(self.dataroot)) 38 | for pair_path in pair_paths: 39 | data_infos.append(dict(pair_path=pair_path)) 40 | 41 | return data_infos 42 | -------------------------------------------------------------------------------- /mmedit/datasets/img_inpainting_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from pathlib import Path 3 | 4 | from .base_dataset import BaseDataset 5 | from .registry import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class ImgInpaintingDataset(BaseDataset): 10 | """Image dataset for inpainting. 11 | """ 12 | 13 | def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False): 14 | super().__init__(pipeline, test_mode) 15 | self.ann_file = str(ann_file) 16 | self.data_prefix = str(data_prefix) 17 | self.data_infos = self.load_annotations() 18 | 19 | def load_annotations(self): 20 | """Load annotations for dataset. 21 | 22 | Returns: 23 | list[dict]: Contain dataset annotations. 
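        Example annotation file (one image path per line, relative to
        ``data_prefix``; any extra space-separated fields after the path are
        ignored by the parsing below):

        ::

            images/0001.png
            images/0002.png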
24 | """ 25 | with open(self.ann_file, 'r') as f: 26 | img_infos = [] 27 | for idx, line in enumerate(f): 28 | line = line.strip() 29 | _info = dict() 30 | line_split = line.split(' ') 31 | _info = dict( 32 | gt_img_path=Path(self.data_prefix).joinpath( 33 | line_split[0]).as_posix(), 34 | gt_img_idx=idx) 35 | img_infos.append(_info) 36 | 37 | return img_infos 38 | 39 | def evaluate(self, outputs, logger=None, **kwargs): 40 | metric_keys = outputs[0]['eval_result'].keys() 41 | stats = {} 42 | for key in metric_keys: 43 | val = sum([x['eval_result'][key] for x in outputs]) 44 | val /= self.__len__() 45 | stats[key] = val 46 | 47 | return stats 48 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .augmentation import (BinarizeImage, ColorJitter, CopyValues, Flip, 3 | GenerateFrameIndices, 4 | GenerateFrameIndiceswithPadding, 5 | GenerateSegmentIndices, MirrorSequence, Pad, 6 | Quantize, RandomAffine, RandomJitter, 7 | RandomMaskDilation, RandomTransposeHW, Resize,MyTransposeHW, 8 | TemporalReverse, UnsharpMasking) 9 | from .compose import Compose 10 | from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown, 11 | CropLike, FixedCrop, ModCrop, PairedRandomCrop, 12 | RandomResizedCrop) 13 | from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,ImageToTensor_trans, 14 | ToTensor) 15 | from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap 16 | from .loading import (GetSpatialDiscountMask, LoadImageFromFile, 17 | LoadImageFromFileList, LoadMask, LoadPairedImageFromFile, 18 | RandomLoadResizeBg) 19 | from .matlab_like_resize import MATLABLikeResize 20 | from .matting_aug import (CompositeFg, GenerateSeg, GenerateSoftSeg, 21 | GenerateTrimap, GenerateTrimapWithDistTransform, 22 | MergeFgAndBg, PerturbBg, TransformTrimap) 23 | from .normalization import Normalize, RescaleToZeroOne 24 | from .random_degradations import (DegradationsWithShuffle, RandomBlur, 25 | RandomJPEGCompression, RandomNoise, 26 | RandomResize, RandomVideoCompression) 27 | from .random_down_sampling import RandomDownSampling 28 | 29 | __all__ = [ 30 | 'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask','MyTransposeHW', 31 | 'RandomLoadResizeBg', 'Compose', 'ImageToTensor', 'ToTensor', 32 | 'GetMaskedImage', 'BinarizeImage', 'Flip', 'Pad', 'RandomAffine', 33 | 'RandomJitter', 'ColorJitter', 'RandomMaskDilation', 'RandomTransposeHW', 34 | 'Resize', 'RandomResizedCrop', 'Crop', 'CropAroundCenter', 35 | 'CropAroundUnknown', 'ModCrop', 'PairedRandomCrop', 'Normalize', 36 | 'RescaleToZeroOne', 'GenerateTrimap', 'MergeFgAndBg', 'CompositeFg', 37 | 'TemporalReverse', 'LoadImageFromFileList', 'GenerateFrameIndices', 38 | 'GenerateFrameIndiceswithPadding', 'FixedCrop', 'LoadPairedImageFromFile', 39 | 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg', 40 | 'GetSpatialDiscountMask', 'RandomDownSampling', 41 | 'GenerateTrimapWithDistTransform', 'TransformTrimap', 42 | 'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence', 43 | 'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues', 44 | 'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise', 45 | 'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking', 46 | 'RandomVideoCompression','ImageToTensor_trans' 47 | ] 48 | 
-------------------------------------------------------------------------------- /mmedit/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections.abc import Sequence 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..registry import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose: 11 | """Compose a data pipeline with a sequence of transforms. 12 | 13 | Args: 14 | transforms (list[dict | callable]): 15 | Either config dicts of transforms or transform objects. 16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError(f'transform must be callable or a dict, ' 29 | f'but got {type(transform)}') 30 | 31 | def __call__(self, data): 32 | """Call function. 33 | 34 | Args: 35 | data (dict): A dict containing the necessary information and 36 | data for augmentation. 37 | 38 | Returns: 39 | dict: A dict containing the processed data and information. 40 | """ 41 | for t in self.transforms: 42 | data = t(data) 43 | if data is None: 44 | return None 45 | return data 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.transforms: 50 | format_string += '\n' 51 | format_string += f' {t}' 52 | format_string += '\n)' 53 | return format_string 54 | -------------------------------------------------------------------------------- /mmedit/datasets/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry 3 | 4 | DATASETS = Registry('dataset') 5 | PIPELINES = Registry('pipeline') 6 | -------------------------------------------------------------------------------- /mmedit/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_annotation_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .base_sr_dataset import BaseSRDataset 5 | from .registry import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class SRAnnotationDataset(BaseSRDataset): 10 | """General paired image dataset with an annotation file for image 11 | restoration. 12 | 13 | The dataset loads lq (Low Quality) and gt (Ground-Truth) image pairs, 14 | applies specified transforms and finally returns a dict containing paired 15 | data and other information. 16 | 17 | This is the "annotation file mode": 18 | Each line in the annotation file contains the image names and 19 | image shape (usually for gt), separated by a white space. 20 | 21 | Example of an annotation file: 22 | 23 | :: 24 | 25 | 0001_s001.png (480,480,3) 26 | 0001_s002.png (480,480,3) 27 | 28 | Args: 29 | lq_folder (str | :obj:`Path`): Path to a lq folder. 
30 | gt_folder (str | :obj:`Path`): Path to a gt folder. 31 | ann_file (str | :obj:`Path`): Path to the annotation file. 32 | pipeline (list[dict | callable]): A sequence of data transformations. 33 | scale (int): Upsampling scale ratio. 34 | test_mode (bool): Store `True` when building test dataset. 35 | Default: `False`. 36 | filename_tmpl (str): Template for each filename. Note that the 37 | template excludes the file extension. Default: '{}'. 38 | """ 39 | 40 | def __init__(self, 41 | lq_folder, 42 | gt_folder, 43 | ann_file, 44 | pipeline, 45 | scale, 46 | test_mode=False, 47 | filename_tmpl='{}'): 48 | super().__init__(pipeline, scale, test_mode) 49 | self.lq_folder = str(lq_folder) 50 | self.gt_folder = str(gt_folder) 51 | self.ann_file = str(ann_file) 52 | self.filename_tmpl = filename_tmpl 53 | self.data_infos = self.load_annotations() 54 | 55 | def load_annotations(self): 56 | """Load annotations for SR dataset. 57 | 58 | It loads the LQ and GT image path from the annotation file. 59 | Each line in the annotation file contains the image names and 60 | image shape (usually for gt), separated by a white space. 61 | 62 | Returns: 63 | list[dict]: A list of dicts for paired paths of LQ and GT. 64 | """ 65 | data_infos = [] 66 | with open(self.ann_file, 'r') as fin: 67 | for line in fin: 68 | gt_name = line.split(' ')[0] 69 | basename, ext = osp.splitext(osp.basename(gt_name)) 70 | lq_name = f'{self.filename_tmpl.format(basename)}{ext}' 71 | data_infos.append( 72 | dict( 73 | lq_path=osp.join(self.lq_folder, lq_name), 74 | gt_path=osp.join(self.gt_folder, gt_name))) 75 | return data_infos 76 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_facial_landmark_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from .base_sr_dataset import BaseSRDataset 7 | from .registry import DATASETS 8 | 9 | 10 | @DATASETS.register_module() 11 | class SRFacialLandmarkDataset(BaseSRDataset): 12 | """Facial image and landmark dataset with an annotation file for image 13 | restoration. 14 | 15 | The dataset loads gt (Ground-Truth) image, shape of image, face box, and 16 | landmark. Applies specified transforms and finally returns a dict 17 | containing paired data and other information. 18 | 19 | This is the "annotation file mode": 20 | Each dict in the annotation list contains the image names, image shape, 21 | face box, and landmark. 22 | 23 | Annotation file is a `npy` file, which contains a list of dict. 24 | Example of an annotation file: 25 | 26 | :: 27 | 28 | dict1(file=*, bbox=*, shape=*, landmark=*) 29 | dict2(file=*, bbox=*, shape=*, landmark=*) 30 | 31 | Args: 32 | gt_folder (str | :obj:`Path`): Path to a gt folder. 33 | ann_file (str | :obj:`Path`): Path to the annotation file. 34 | pipeline (list[dict | callable]): A sequence of data transformations. 35 | scale (int): Upsampling scale ratio. 36 | test_mode (bool): Store `True` when building test dataset. 37 | Default: `False`. 38 | """ 39 | 40 | def __init__(self, gt_folder, ann_file, pipeline, scale, test_mode=False): 41 | super().__init__(pipeline, scale, test_mode) 42 | self.gt_folder = str(gt_folder) 43 | self.ann_file = str(ann_file) 44 | self.data_infos = self.load_annotations() 45 | 46 | def load_annotations(self): 47 | """Load annotations for SR dataset. 48 | 49 | Annotation file is a `npy` file, which contains a list of dict. 
50 | 51 | It loads the GT image path and landmark from the annotation file. 52 | Each dict in the annotation file contains the image names, image 53 | shape (usually for gt), bbox and landmark. 54 | 55 | Returns: 56 | list[dict]: A list of dicts for GT path and landmark. 57 | Contains: gt_path, bbox, shape, landmark. 58 | """ 59 | data_infos = np.load(self.ann_file, allow_pickle=True) 60 | for data_info in data_infos: 61 | data_info['gt_path'] = osp.join(self.gt_folder, 62 | data_info['gt_path']) 63 | 64 | return data_infos 65 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_test_multiple_gt_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os 4 | import os.path as osp 5 | import warnings 6 | 7 | from .base_sr_dataset import BaseSRDataset 8 | from .registry import DATASETS 9 | 10 | 11 | @DATASETS.register_module() 12 | class SRTestMultipleGTDataset(BaseSRDataset): 13 | """Test dataset for video super resolution for recurrent networks. 14 | 15 | It assumes all video sequences under the root directory are used for testing. 16 | 17 | The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) 18 | frames. Then it applies specified transforms and finally returns a dict 19 | containing paired data and other information. 20 | 21 | Args: 22 | lq_folder (str | :obj:`Path`): Path to a lq folder. 23 | gt_folder (str | :obj:`Path`): Path to a gt folder. 24 | pipeline (list[dict | callable]): A sequence of data transformations. 25 | scale (int): Upsampling scale ratio. 26 | test_mode (bool): Store `True` when building test dataset. 27 | Default: `True`. 28 | """ 29 | 30 | def __init__(self, lq_folder, gt_folder, pipeline, scale, test_mode=True): 31 | super().__init__(pipeline, scale, test_mode) 32 | 33 | warnings.warn('"SRTestMultipleGTDataset" has been deprecated and ' 34 | 'will be removed in a future release. Please use ' 35 | '"SRFolderMultipleGTDataset" instead. For details, see ' 36 | 'https://github.com/open-mmlab/mmediting/pull/355') 37 | 38 | self.lq_folder = str(lq_folder) 39 | self.gt_folder = str(gt_folder) 40 | self.data_infos = self.load_annotations() 41 | 42 | def load_annotations(self): 43 | """Load annotations for the test dataset. 44 | 45 | Returns: 46 | list[dict]: A list of dicts for paired paths and other information. 47 | """ 48 | 49 | sequences = sorted(glob.glob(osp.join(self.lq_folder, '*'))) 50 | 51 | data_infos = [] 52 | for sequence in sequences: 53 | sequence_length = len(glob.glob(osp.join(sequence, '*.png'))) 54 | data_infos.append( 55 | dict( 56 | lq_path=self.lq_folder, 57 | gt_path=self.gt_folder, 58 | key=sequence.replace(f'{self.lq_folder}{os.sep}', ''), 59 | sequence_length=int(sequence_length))) 60 | 61 | return data_infos 62 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_vimeo90k_multiple_gt_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import os.path as osp 4 | 5 | from .base_sr_dataset import BaseSRDataset 6 | from .registry import DATASETS 7 | 8 | 9 | @DATASETS.register_module() 10 | class SRVimeo90KMultipleGTDataset(BaseSRDataset): 11 | """Vimeo90K dataset for video super resolution for recurrent networks. 12 | 13 | The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) 14 | frames.
Then it applies specified transforms and finally returns a dict 15 | containing paired data and other information. 16 | 17 | It reads Vimeo90K keys from the txt file. Each line contains: 18 | 19 | 1. video frame folder 20 | 2. image shape 21 | 22 | Examples: 23 | 24 | :: 25 | 26 | 00001/0266 (256,448,3) 27 | 00001/0268 (256,448,3) 28 | 29 | Args: 30 | lq_folder (str | :obj:`Path`): Path to a lq folder. 31 | gt_folder (str | :obj:`Path`): Path to a gt folder. 32 | ann_file (str | :obj:`Path`): Path to the annotation file. 33 | pipeline (list[dict | callable]): A sequence of data transformations. 34 | scale (int): Upsampling scale ratio. 35 | num_input_frames (int): Number of frames in each training sequence. 36 | Default: 7. 37 | test_mode (bool): Store `True` when building test dataset. 38 | Default: `False`. 39 | """ 40 | 41 | def __init__(self, 42 | lq_folder, 43 | gt_folder, 44 | ann_file, 45 | pipeline, 46 | scale, 47 | num_input_frames=7, 48 | test_mode=False): 49 | super().__init__(pipeline, scale, test_mode) 50 | 51 | self.lq_folder = str(lq_folder) 52 | self.gt_folder = str(gt_folder) 53 | self.ann_file = str(ann_file) 54 | self.num_input_frames = num_input_frames 55 | 56 | self.data_infos = self.load_annotations() 57 | 58 | def load_annotations(self): 59 | """Load annotations for Vimeo-90K dataset. 60 | 61 | Returns: 62 | list[dict]: A list of dicts for paired paths and other information. 63 | """ 64 | # get keys 65 | with open(self.ann_file, 'r') as fin: 66 | keys = [line.strip().split(' ')[0] for line in fin] 67 | 68 | data_infos = [] 69 | for key in keys: 70 | key = key.replace('/', os.sep) 71 | lq_paths = [ 72 | osp.join(self.lq_folder, key, f'im{i}.png') 73 | for i in range(1, self.num_input_frames + 1) 74 | ] 75 | gt_paths = [ 76 | osp.join(self.gt_folder, key, f'im{i}.png') 77 | for i in range(1, self.num_input_frames + 1) 78 | ] 79 | 80 | data_infos.append( 81 | dict(lq_path=lq_paths, gt_path=gt_paths, key=key)) 82 | 83 | return data_infos 84 | -------------------------------------------------------------------------------- /mmedit/datasets/vfi_vimeo90k_7frames_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import os.path as osp 4 | 5 | from .base_vfi_dataset import BaseVFIDataset 6 | from .registry import DATASETS 7 | 8 | 9 | @DATASETS.register_module() 10 | class VFIVimeo90K7FramesDataset(BaseVFIDataset): 11 | """Utilize Vimeo90K dataset (7 frames) for video frame interpolation. 12 | 13 | Load 7 GT (Ground-Truth) frames from the dataset, predict several frame(s) 14 | from other frames. 15 | Then it applies specified transforms and finally returns a dict 16 | containing paired data and other information. 17 | 18 | It reads Vimeo90K keys from the txt file. Each line contains: 19 | 20 | 1. video frame folder 21 | 2. number of frames 22 | 3. image shape 23 | 24 | Examples: 25 | 26 | :: 27 | 28 | 00001/0266 7 (256,448,3) 29 | 00001/0268 7 (256,448,3) 30 | 31 | Note: Only `video frame folder` is required information. 32 | 33 | Args: 34 | folder (str | :obj:`Path`): Path to image folder. 35 | ann_file (str | :obj:`Path`): Path to the annotation file. 36 | pipeline (list[dict | callable]): A sequence of data transformations. 37 | input_frames (list[int]): Index of input frames. 38 | target_frames (list[int]): Index of target frames. 39 | test_mode (bool): Store `True` when building test dataset. 40 | Default: `False`. 
41 | """ 42 | 43 | def __init__(self, 44 | folder, 45 | ann_file, 46 | pipeline, 47 | input_frames, 48 | target_frames, 49 | test_mode=False): 50 | super().__init__( 51 | pipeline=pipeline, 52 | folder=folder, 53 | ann_file=ann_file, 54 | test_mode=test_mode) 55 | 56 | self.input_frames = input_frames 57 | self.target_frames = target_frames 58 | 59 | self.data_infos = self.load_annotations() 60 | 61 | def load_annotations(self): 62 | """Load annoations for Vimeo-90K dataset. 63 | 64 | Returns: 65 | list[dict]: A list of dicts for paired paths and other information. 66 | """ 67 | # get keys 68 | with open(self.ann_file, 'r') as fin: 69 | keys = [line.strip().split(' ')[0] for line in fin] 70 | 71 | data_infos = [] 72 | for key in keys: 73 | key = key.replace('/', os.sep) 74 | inputs_path = [ 75 | osp.join(self.folder, key, f'im{i}.png') 76 | for i in self.input_frames 77 | ] 78 | target_path = [ 79 | osp.join(self.folder, key, f'im{i}.png') 80 | for i in self.target_frames 81 | ] 82 | 83 | data_infos.append( 84 | dict( 85 | inputs_path=inputs_path, target_path=target_path, key=key)) 86 | 87 | return data_infos 88 | -------------------------------------------------------------------------------- /mmedit/datasets/vfi_vimeo90k_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import os.path as osp 4 | 5 | from .base_vfi_dataset import BaseVFIDataset 6 | from .registry import DATASETS 7 | 8 | 9 | @DATASETS.register_module() 10 | class VFIVimeo90KDataset(BaseVFIDataset): 11 | """Vimeo90K dataset for video frame interpolation. 12 | 13 | The dataset loads two input frames and a center GT (Ground-Truth) frame. 14 | Then it applies specified transforms and finally returns a dict containing 15 | paired data and other information. 16 | 17 | It reads Vimeo90K keys from the txt file. 18 | Each line contains: 19 | 20 | Examples: 21 | 22 | :: 23 | 24 | 00001/0389 25 | 00001/0402 26 | 27 | Args: 28 | pipeline (list[dict | callable]): A sequence of data transformations. 29 | folder (str | :obj:`Path`): Path to the folder. 30 | ann_file (str | :obj:`Path`): Path to the annotation file. 31 | test_mode (bool): Store `True` when building test dataset. 32 | Default: `False`. 33 | """ 34 | 35 | def __init__(self, pipeline, folder, ann_file, test_mode=False): 36 | super().__init__(pipeline, folder, ann_file, test_mode) 37 | self.data_infos = self.load_annotations() 38 | 39 | def load_annotations(self): 40 | """Load annotations for VimeoK dataset. 41 | 42 | Returns: 43 | list[dict]: A list of dicts for paired paths and other information. 44 | """ 45 | # get keys 46 | with open(self.ann_file, 'r') as f: 47 | keys = f.read().split('\n') 48 | keys = [ 49 | k.strip() for k in keys if (k.strip() is not None and k != '') 50 | ] 51 | 52 | data_infos = [] 53 | for key in keys: 54 | key = key.replace('/', os.sep) 55 | key_folder = osp.join(self.folder, key) 56 | inputs_path = [ 57 | osp.join(key_folder, 'im1.png'), 58 | osp.join(key_folder, 'im3.png') 59 | ] 60 | target_path = osp.join(key_folder, 'im2.png') 61 | data_infos.append( 62 | dict( 63 | inputs_path=inputs_path, target_path=target_path, key=key)) 64 | 65 | return data_infos 66 | -------------------------------------------------------------------------------- /mmedit/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .backbones import * # noqa: F401, F403 3 | from .base import BaseModel 4 | from .builder import (build, build_backbone, build_component, build_loss, 5 | build_model) 6 | from .common import * # noqa: F401, F403 7 | from .components import * # noqa: F401, F403 8 | from .extractors import LTE, FeedbackHourglass 9 | from .inpaintors import (AOTInpaintor, DeepFillv1Inpaintor, GLInpaintor, 10 | OneStageInpaintor, PConvInpaintor, TwoStageInpaintor) 11 | from .losses import * # noqa: F401, F403 12 | from .mattors import DIM, GCA, BaseMattor, IndexNet 13 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS 14 | from .restorers import ESRGAN, SRGAN, BasicRestorer 15 | from .synthesizers import CycleGAN, Pix2Pix 16 | from .transformers import SearchTransformer 17 | from .video_interpolators import CAIN, BasicInterpolator 18 | from .vqgan import FeMaSRModel_HR 19 | __all__ = [ 20 | 'AOTInpaintor', 'BaseModel', 'BasicRestorer', 'OneStageInpaintor', 'build', 21 | 'build_backbone', 'build_component', 'build_loss', 'build_model', 22 | 'BACKBONES', 'COMPONENTS', 'LOSSES', 'BaseMattor', 'DIM', 'MODELS', 23 | 'GLInpaintor', 'PConvInpaintor', 'SRGAN', 'ESRGAN', 'GCA', 24 | 'TwoStageInpaintor', 'IndexNet', 'DeepFillv1Inpaintor', 'Pix2Pix', 25 | 'CycleGAN', 'SearchTransformer', 'LTE', 'FeedbackHourglass', 26 | 'BasicInterpolator', 'CAIN','FeMaSRModel_HR' 27 | ] 28 | -------------------------------------------------------------------------------- /mmedit/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .encoder_decoders import (VGG16, ContextualAttentionNeck, DeepFillDecoder, 3 | DeepFillEncoder, DeepFillEncoderDecoder, 4 | DepthwiseIndexBlock, FBADecoder, 5 | FBAResnetDilated, GLDecoder, GLDilationNeck, 6 | GLEncoder, GLEncoderDecoder, HolisticIndexBlock, 7 | IndexedUpsample, IndexNetDecoder, 8 | IndexNetEncoder, PConvDecoder, PConvEncoder, 9 | PConvEncoderDecoder, PlainDecoder, 10 | ResGCADecoder, ResGCAEncoder, ResNetDec, 11 | ResNetEnc, ResShortcutDec, ResShortcutEnc, 12 | SimpleEncoderDecoder) 13 | from .generation_backbones import ResnetGenerator, UnetGenerator 14 | from .sr_backbones import (EDSR, LIIFEDSR, LIIFRDN, RDN, SRCNN, BasicVSRNet, 15 | BasicVSRPlusPlus, DICNet, EDVRNet, GLEANStyleGANv2, 16 | IconVSR, MSRResNet, RealBasicVSRNet, RRDBNet, 17 | TDANNet, TOFlow, TTSRNet) 18 | from .vfi_backbones import CAINNet, FLAVRNet, TOFlowVFINet 19 | __all__ = [ 20 | 'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder', 21 | 'GLEncoderDecoder', 'GLEncoder', 'GLDecoder', 'GLDilationNeck', 22 | 'PConvEncoderDecoder', 'PConvEncoder', 'PConvDecoder', 'ResNetEnc', 23 | 'ResNetDec', 'ResShortcutEnc', 'ResShortcutDec', 'RRDBNet', 24 | 'DeepFillEncoder', 'HolisticIndexBlock', 'DepthwiseIndexBlock', 25 | 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'RDN', 'DICNet', 26 | 'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder', 27 | 'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN', 28 | 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder', 29 | 'BasicVSRNet', 'IconVSR', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet', 30 | 'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus', 'RealBasicVSRNet', 'CAINNet', 31 | 'TOFlowVFINet', 'FLAVRNet' 32 | ] 33 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .aot_encoder_decoder import AOTEncoderDecoder 3 | from .decoders import (DeepFillDecoder, FBADecoder, GLDecoder, IndexedUpsample, 4 | IndexNetDecoder, PConvDecoder, PlainDecoder, 5 | ResGCADecoder, ResNetDec, ResShortcutDec) 6 | from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock, 7 | FBAResnetDilated, GLEncoder, HolisticIndexBlock, 8 | IndexNetEncoder, PConvEncoder, ResGCAEncoder, ResNetEnc, 9 | ResShortcutEnc) 10 | from .gl_encoder_decoder import GLEncoderDecoder 11 | from .necks import ContextualAttentionNeck, GLDilationNeck 12 | from .pconv_encoder_decoder import PConvEncoderDecoder 13 | from .simple_encoder_decoder import SimpleEncoderDecoder 14 | from .two_stage_encoder_decoder import DeepFillEncoderDecoder 15 | 16 | __all__ = [ 17 | 'GLEncoderDecoder', 'SimpleEncoderDecoder', 'VGG16', 'GLEncoder', 18 | 'PlainDecoder', 'GLDecoder', 'GLDilationNeck', 'PConvEncoderDecoder', 19 | 'PConvEncoder', 'PConvDecoder', 'ResNetEnc', 'ResNetDec', 'ResShortcutEnc', 20 | 'ResShortcutDec', 'HolisticIndexBlock', 'DepthwiseIndexBlock', 21 | 'DeepFillEncoder', 'DeepFillEncoderDecoder', 'DeepFillDecoder', 22 | 'ContextualAttentionNeck', 'IndexedUpsample', 'IndexNetEncoder', 23 | 'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder', 'FBAResnetDilated', 24 | 'FBADecoder', 'AOTEncoderDecoder' 25 | ] 26 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/aot_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmedit.models.builder import build_component 3 | from mmedit.models.registry import BACKBONES 4 | from .gl_encoder_decoder import GLEncoderDecoder 5 | 6 | 7 | @BACKBONES.register_module() 8 | class AOTEncoderDecoder(GLEncoderDecoder): 9 | """Encoder-Decoder used in AOT-GAN model. 10 | 11 | This implementation follows: 12 | Aggregated Contextual Transformations for High-Resolution Image Inpainting 13 | The architecture of the encoder-decoder is: 14 | (conv2d x 3) --> (dilated conv2d x 8) --> (conv2d or deconv2d x 3) 15 | 16 | Args: 17 | encoder (dict): Config dict to encoder. 18 | decoder (dict): Config dict to build decoder. 19 | dilation_neck (dict): Config dict to build dilation neck. 20 | """ 21 | 22 | def __init__(self, 23 | encoder=dict(type='AOTEncoder'), 24 | decoder=dict(type='AOTDecoder'), 25 | dilation_neck=dict(type='AOTBlockNeck')): 26 | super().__init__() 27 | self.encoder = build_component(encoder) 28 | self.decoder = build_component(decoder) 29 | self.dilation_neck = build_component(dilation_neck) 30 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .aot_decoder import AOTDecoder 3 | from .deepfill_decoder import DeepFillDecoder 4 | from .fba_decoder import FBADecoder 5 | from .gl_decoder import GLDecoder 6 | from .indexnet_decoder import IndexedUpsample, IndexNetDecoder 7 | from .pconv_decoder import PConvDecoder 8 | from .plain_decoder import PlainDecoder 9 | from .resnet_dec import ResGCADecoder, ResNetDec, ResShortcutDec 10 | 11 | __all__ = [ 12 | 'GLDecoder', 'PlainDecoder', 'PConvDecoder', 'ResNetDec', 'ResShortcutDec', 13 | 'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder', 14 | 'FBADecoder', 'AOTDecoder' 15 | ] 16 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/decoders/aot_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmedit.models.registry import COMPONENTS 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class AOTDecoder(nn.Module): 11 | """Decoder used in AOT-GAN model. 12 | 13 | This implementation follows: 14 | Aggregated Contextual Transformations for High-Resolution Image Inpainting 15 | 16 | Args: 17 | in_channels (int, optional): Channel number of input feature. 18 | Default: 256. 19 | mid_channels (int, optional): Channel number of middle feature. 20 | Default: 128. 21 | out_channels (int, optional): Channel number of output feature. 22 | Default 3. 23 | act_cfg (dict, optional): Config dict for activation layer, 24 | "relu" by default. 25 | """ 26 | 27 | def __init__(self, 28 | in_channels=256, 29 | mid_channels=128, 30 | out_channels=3, 31 | act_cfg=dict(type='ReLU')): 32 | super().__init__() 33 | 34 | self.decoder = nn.ModuleList([ 35 | ConvModule( 36 | in_channels, 37 | mid_channels, 38 | kernel_size=3, 39 | stride=1, 40 | padding=1, 41 | act_cfg=act_cfg), 42 | ConvModule( 43 | mid_channels, 44 | mid_channels // 2, 45 | kernel_size=3, 46 | stride=1, 47 | padding=1, 48 | act_cfg=act_cfg), 49 | ConvModule( 50 | mid_channels // 2, 51 | out_channels, 52 | kernel_size=3, 53 | stride=1, 54 | padding=1, 55 | act_cfg=None) 56 | ]) 57 | self.output_act = nn.Tanh() 58 | 59 | def forward(self, x): 60 | """Forward Function. 61 | 62 | Args: 63 | x (Tensor): Input tensor with shape of (n, c, h, w). 64 | 65 | Returns: 66 | Tensor: Output tensor with shape of (n, c, h', w'). 67 | """ 68 | for i in range(0, len(self.decoder)): 69 | if i <= 1: 70 | x = F.interpolate( 71 | x, scale_factor=2, mode='bilinear', align_corners=True) 72 | x = self.decoder[i](x) 73 | 74 | return self.output_act(x) 75 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .aot_encoder import AOTEncoder 3 | from .deepfill_encoder import DeepFillEncoder 4 | from .fba_encoder import FBAResnetDilated 5 | from .gl_encoder import GLEncoder 6 | from .indexnet_encoder import (DepthwiseIndexBlock, HolisticIndexBlock, 7 | IndexNetEncoder) 8 | from .pconv_encoder import PConvEncoder 9 | from .resnet_enc import ResGCAEncoder, ResNetEnc, ResShortcutEnc 10 | from .vgg import VGG16 11 | 12 | __all__ = [ 13 | 'GLEncoder', 'VGG16', 'ResNetEnc', 'HolisticIndexBlock', 14 | 'DepthwiseIndexBlock', 'ResShortcutEnc', 'PConvEncoder', 'DeepFillEncoder', 15 | 'IndexNetEncoder', 'ResGCAEncoder', 'FBAResnetDilated', 'AOTEncoder' 16 | ] 17 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/encoders/aot_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmedit.models.registry import COMPONENTS 6 | 7 | 8 | @COMPONENTS.register_module() 9 | class AOTEncoder(nn.Module): 10 | """Encoder used in AOT-GAN model. 11 | 12 | This implementation follows: 13 | Aggregated Contextual Transformations for High-Resolution Image Inpainting 14 | 15 | Args: 16 | in_channels (int, optional): Channel number of input feature. 17 | Default: 4. 18 | mid_channels (int, optional): Channel number of middle feature. 19 | Default: 64. 20 | out_channels (int, optional): Channel number of output feature. 21 | Default: 256. 22 | act_cfg (dict, optional): Config dict for activation layer, 23 | "relu" by default. 24 | """ 25 | 26 | def __init__(self, 27 | in_channels=4, 28 | mid_channels=64, 29 | out_channels=256, 30 | act_cfg=dict(type='ReLU')): 31 | super().__init__() 32 | self.encoder = nn.Sequential( 33 | nn.ReflectionPad2d(3), 34 | ConvModule( 35 | in_channels, 36 | mid_channels, 37 | kernel_size=7, 38 | stride=1, 39 | act_cfg=act_cfg), 40 | ConvModule( 41 | mid_channels, 42 | mid_channels * 2, 43 | kernel_size=4, 44 | stride=2, 45 | padding=1, 46 | act_cfg=act_cfg), 47 | ConvModule( 48 | mid_channels * 2, 49 | out_channels, 50 | kernel_size=4, 51 | stride=2, 52 | padding=1, 53 | act_cfg=act_cfg)) 54 | 55 | def forward(self, x): 56 | """Forward Function. 57 | 58 | Args: 59 | x (Tensor): Input tensor with shape of (n, c, h, w). 60 | 61 | Returns: 62 | Tensor: Output tensor with shape of (n, c, h', w'). 63 | """ 64 | return self.encoder(x) 65 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/encoders/deepfill_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmedit.models.common import SimpleGatedConvModule 6 | from mmedit.models.registry import COMPONENTS 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class DeepFillEncoder(nn.Module): 11 | """Encoder used in DeepFill model. 12 | 13 | This implementation follows: 14 | Generative Image Inpainting with Contextual Attention 15 | 16 | Args: 17 | in_channels (int): The number of input channels. Default: 5. 18 | conv_type (str): The type of conv module. In DeepFillv1 model, the 19 | `conv_type` should be 'conv'. In DeepFillv2 model, the `conv_type` 20 | should be 'gated_conv'. 21 | norm_cfg (dict): Config dict to build norm layer. Default: None. 
22 | act_cfg (dict): Config dict for activation layer, "elu" by default. 23 | encoder_type (str): Type of the encoder. Should be one of ['stage1', 24 | 'stage2_conv', 'stage2_attention']. Default: 'stage1'. 25 | channel_factor (float): The scale factor for channel size. 26 | Default: 1. 27 | kwargs (keyword arguments). 28 | """ 29 | _conv_type = dict(conv=ConvModule, gated_conv=SimpleGatedConvModule) 30 | 31 | def __init__(self, 32 | in_channels=5, 33 | conv_type='conv', 34 | norm_cfg=None, 35 | act_cfg=dict(type='ELU'), 36 | encoder_type='stage1', 37 | channel_factor=1., 38 | **kwargs): 39 | super().__init__() 40 | conv_module = self._conv_type[conv_type] 41 | channel_list_dict = dict( 42 | stage1=[32, 64, 64, 128, 128, 128], 43 | stage2_conv=[32, 32, 64, 64, 128, 128], 44 | stage2_attention=[32, 32, 64, 128, 128, 128]) 45 | channel_list = channel_list_dict[encoder_type] 46 | channel_list = [int(x * channel_factor) for x in channel_list] 47 | kernel_size_list = [5, 3, 3, 3, 3, 3] 48 | stride_list = [1, 2, 1, 2, 1, 1] 49 | for i in range(6): 50 | ks = kernel_size_list[i] 51 | padding = (ks - 1) // 2 52 | self.add_module( 53 | f'enc{i + 1}', 54 | conv_module( 55 | in_channels, 56 | channel_list[i], 57 | kernel_size=ks, 58 | stride=stride_list[i], 59 | padding=padding, 60 | norm_cfg=norm_cfg, 61 | act_cfg=act_cfg, 62 | **kwargs)) 63 | in_channels = channel_list[i] 64 | 65 | def forward(self, x): 66 | """Forward Function. 67 | 68 | Args: 69 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 70 | 71 | Returns: 72 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 73 | """ 74 | for i in range(6): 75 | x = getattr(self, f'enc{i + 1}')(x) 76 | outputs = dict(out=x) 77 | return outputs 78 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmedit.models.registry import COMPONENTS 3 | from .resnet import ResNet 4 | 5 | 6 | @COMPONENTS.register_module() 7 | class FBAResnetDilated(ResNet): 8 | """ResNet-based encoder for FBA image matting.""" 9 | 10 | def forward(self, x): 11 | """Forward function. 12 | 13 | Args: 14 | x (Tensor): Input tensor with shape (N, C, H, W). 15 | 16 | Returns: 17 | Tensor: Output tensor. 18 | """ 19 | # x: (merged_t, trimap_t, two_channel_trimap,merged) 20 | # t refers to transformed. 21 | two_channel_trimap = x[:, 9:11] 22 | merged = x[:, 11:14] 23 | x = x[:, 0:11, ...] 24 | conv_out = [x] 25 | if self.deep_stem: 26 | x = self.stem(x) 27 | else: 28 | x = self.conv1(x) 29 | x = self.norm1(x) 30 | x = self.activate(x) 31 | conv_out.append(x) 32 | x = self.maxpool(x) 33 | x = self.layer1(x) 34 | conv_out.append(x) 35 | x = self.layer2(x) 36 | conv_out.append(x) 37 | x = self.layer3(x) 38 | conv_out.append(x) 39 | x = self.layer4(x) 40 | conv_out.append(x) 41 | return { 42 | 'conv_out': conv_out, 43 | 'merged': merged, 44 | 'two_channel_trimap': two_channel_trimap 45 | } 46 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/encoders/gl_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmedit.models.registry import COMPONENTS 6 | 7 | 8 | @COMPONENTS.register_module() 9 | class GLEncoder(nn.Module): 10 | """Encoder used in Global&Local model. 11 | 12 | This implementation follows: 13 | Globally and locally Consistent Image Completion 14 | 15 | Args: 16 | norm_cfg (dict): Config dict to build norm layer. 17 | act_cfg (dict): Config dict for activation layer, "relu" by default. 18 | """ 19 | 20 | def __init__(self, norm_cfg=None, act_cfg=dict(type='ReLU')): 21 | super().__init__() 22 | 23 | channel_list = [64, 128, 128, 256, 256, 256] 24 | kernel_size_list = [5, 3, 3, 3, 3, 3] 25 | stride_list = [1, 2, 1, 2, 1, 1] 26 | in_channels = 4 27 | for i in range(6): 28 | ks = kernel_size_list[i] 29 | padding = (ks - 1) // 2 30 | self.add_module( 31 | f'enc{i + 1}', 32 | ConvModule( 33 | in_channels, 34 | channel_list[i], 35 | kernel_size=ks, 36 | stride=stride_list[i], 37 | padding=padding, 38 | norm_cfg=norm_cfg, 39 | act_cfg=act_cfg)) 40 | in_channels = channel_list[i] 41 | 42 | def forward(self, x): 43 | """Forward Function. 44 | 45 | Args: 46 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 47 | 48 | Returns: 49 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 50 | """ 51 | for i in range(6): 52 | x = getattr(self, f'enc{i + 1}')(x) 53 | return x 54 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/gl_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.runner import auto_fp16, load_checkpoint 4 | 5 | from mmedit.models.builder import build_component 6 | from mmedit.models.registry import BACKBONES 7 | from mmedit.utils import get_root_logger 8 | 9 | 10 | @BACKBONES.register_module() 11 | class GLEncoderDecoder(nn.Module): 12 | """Encoder-Decoder used in Global&Local model. 13 | 14 | This implementation follows: 15 | Globally and locally Consistent Image Completion 16 | 17 | The architecture of the encoder-decoder is:\ 18 | (conv2d x 6) --> (dilated conv2d x 4) --> (conv2d or deconv2d x 7) 19 | 20 | Args: 21 | encoder (dict): Config dict to encoder. 22 | decoder (dict): Config dict to build decoder. 23 | dilation_neck (dict): Config dict to build dilation neck. 24 | """ 25 | 26 | def __init__(self, 27 | encoder=dict(type='GLEncoder'), 28 | decoder=dict(type='GLDecoder'), 29 | dilation_neck=dict(type='GLDilationNeck')): 30 | super().__init__() 31 | self.encoder = build_component(encoder) 32 | self.decoder = build_component(decoder) 33 | self.dilation_neck = build_component(dilation_neck) 34 | 35 | # support fp16 36 | self.fp16_enabled = False 37 | 38 | @auto_fp16() 39 | def forward(self, x): 40 | """Forward Function. 41 | 42 | Args: 43 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 44 | 45 | Returns: 46 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 47 | """ 48 | x = self.encoder(x) 49 | if isinstance(x, dict): 50 | x = x['out'] 51 | x = self.dilation_neck(x) 52 | x = self.decoder(x) 53 | 54 | return x 55 | 56 | def init_weights(self, pretrained=None): 57 | """Init weights for models. 58 | 59 | Args: 60 | pretrained (str, optional): Path for pretrained weights. If given 61 | None, pretrained weights will not be loaded. Defaults to None. 
62 | """ 63 | if isinstance(pretrained, str): 64 | logger = get_root_logger() 65 | load_checkpoint(self, pretrained, strict=False, logger=logger) 66 | elif pretrained is None: 67 | # Here, we just use the default initialization in `ConvModule`. 68 | pass 69 | else: 70 | raise TypeError('pretrained must be a str or None') 71 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .aot_neck import AOTBlock, AOTBlockNeck 3 | from .contextual_attention_neck import ContextualAttentionNeck 4 | from .gl_dilation import GLDilationNeck 5 | 6 | __all__ = [ 7 | 'GLDilationNeck', 'ContextualAttentionNeck', 'AOTBlockNeck', 'AOTBlock' 8 | ] 9 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/necks/contextual_attention_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmedit.models.common import SimpleGatedConvModule 6 | from mmedit.models.common.contextual_attention import ContextualAttentionModule 7 | from mmedit.models.registry import COMPONENTS 8 | 9 | 10 | @COMPONENTS.register_module() 11 | class ContextualAttentionNeck(nn.Module): 12 | """Neck with contextual attention module. 13 | 14 | Args: 15 | in_channels (int): The number of input channels. 16 | conv_type (str): The type of conv module. In DeepFillv1 model, the 17 | `conv_type` should be 'conv'. In DeepFillv2 model, the `conv_type` 18 | should be 'gated_conv'. 19 | conv_cfg (dict | None): Config of conv module. Default: None. 20 | norm_cfg (dict | None): Config of norm module. Default: None. 21 | act_cfg (dict | None): Config of activation layer. Default: 22 | dict(type='ELU'). 23 | contextual_attention_args (dict): Config of contextual attention 24 | module. Default: dict(softmax_scale=10.). 25 | kwargs (keyword arguments). 26 | """ 27 | _conv_type = dict(conv=ConvModule, gated_conv=SimpleGatedConvModule) 28 | 29 | def __init__(self, 30 | in_channels, 31 | conv_type='conv', 32 | conv_cfg=None, 33 | norm_cfg=None, 34 | act_cfg=dict(type='ELU'), 35 | contextual_attention_args=dict(softmax_scale=10.), 36 | **kwargs): 37 | super().__init__() 38 | self.contextual_attention = ContextualAttentionModule( 39 | **contextual_attention_args) 40 | conv_module = self._conv_type[conv_type] 41 | self.conv1 = conv_module( 42 | in_channels, 43 | in_channels, 44 | 3, 45 | padding=1, 46 | conv_cfg=conv_cfg, 47 | norm_cfg=norm_cfg, 48 | act_cfg=act_cfg, 49 | **kwargs) 50 | self.conv2 = conv_module( 51 | in_channels, 52 | in_channels, 53 | 3, 54 | padding=1, 55 | conv_cfg=conv_cfg, 56 | norm_cfg=norm_cfg, 57 | act_cfg=act_cfg, 58 | **kwargs) 59 | 60 | def forward(self, x, mask): 61 | """Forward Function. 62 | 63 | Args: 64 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 65 | mask (torch.Tensor): Input tensor with shape of (n, 1, h, w). 66 | 67 | Returns: 68 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 
69 | """ 70 | x, offset = self.contextual_attention(x, x, mask) 71 | x = self.conv1(x) 72 | x = self.conv2(x) 73 | 74 | return x, offset 75 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/necks/gl_dilation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | from mmedit.models.common import SimpleGatedConvModule 6 | from mmedit.models.registry import COMPONENTS 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class GLDilationNeck(nn.Module): 11 | """Dilation Backbone used in Global&Local model. 12 | 13 | This implementation follows: 14 | Globally and locally Consistent Image Completion 15 | 16 | Args: 17 | in_channels (int): Channel number of input feature. 18 | conv_type (str): The type of conv module. In DeepFillv1 model, the 19 | `conv_type` should be 'conv'. In DeepFillv2 model, the `conv_type` 20 | should be 'gated_conv'. 21 | norm_cfg (dict): Config dict to build norm layer. 22 | act_cfg (dict): Config dict for activation layer, "relu" by default. 23 | kwargs (keyword arguments). 24 | """ 25 | _conv_type = dict(conv=ConvModule, gated_conv=SimpleGatedConvModule) 26 | 27 | def __init__(self, 28 | in_channels=256, 29 | conv_type='conv', 30 | norm_cfg=None, 31 | act_cfg=dict(type='ReLU'), 32 | **kwargs): 33 | super().__init__() 34 | conv_module = self._conv_type[conv_type] 35 | dilation_convs_ = [] 36 | for i in range(4): 37 | dilation_ = int(2**(i + 1)) 38 | dilation_convs_.append( 39 | conv_module( 40 | in_channels, 41 | in_channels, 42 | kernel_size=3, 43 | padding=dilation_, 44 | dilation=dilation_, 45 | stride=1, 46 | norm_cfg=norm_cfg, 47 | act_cfg=act_cfg, 48 | **kwargs)) 49 | self.dilation_convs = nn.Sequential(*dilation_convs_) 50 | 51 | def forward(self, x): 52 | """Forward Function. 53 | 54 | Args: 55 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 56 | 57 | Returns: 58 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 59 | """ 60 | x = self.dilation_convs(x) 61 | return x 62 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/pconv_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.runner import auto_fp16, load_checkpoint 4 | 5 | from mmedit.models.builder import build_component 6 | from mmedit.models.registry import BACKBONES 7 | from mmedit.utils import get_root_logger 8 | 9 | 10 | @BACKBONES.register_module() 11 | class PConvEncoderDecoder(nn.Module): 12 | """Encoder-Decoder with partial conv module. 13 | 14 | Args: 15 | encoder (dict): Config of the encoder. 16 | decoder (dict): Config of the decoder. 17 | """ 18 | 19 | def __init__(self, encoder, decoder): 20 | super().__init__() 21 | self.encoder = build_component(encoder) 22 | self.decoder = build_component(decoder) 23 | 24 | # support fp16 25 | self.fp16_enabled = False 26 | 27 | @auto_fp16() 28 | def forward(self, x, mask_in): 29 | """Forward Function. 30 | 31 | Args: 32 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 33 | mask_in (torch.Tensor): Input tensor with shape of (n, c, h, w). 34 | 35 | Returns: 36 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 
37 | """ 38 | enc_outputs = self.encoder(x, mask_in) 39 | x, final_mask = self.decoder(enc_outputs) 40 | 41 | return x, final_mask 42 | 43 | def init_weights(self, pretrained=None): 44 | """Init weights for models. 45 | 46 | Args: 47 | pretrained (str, optional): Path for pretrained weights. If given 48 | None, pretrained weights will not be loaded. Defaults to None. 49 | """ 50 | if isinstance(pretrained, str): 51 | logger = get_root_logger() 52 | load_checkpoint(self, pretrained, strict=False, logger=logger) 53 | elif pretrained is None: 54 | # Here, we just use the default initialization in `ConvModule`. 55 | pass 56 | else: 57 | raise TypeError('pretrained must be a str or None') 58 | -------------------------------------------------------------------------------- /mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | from mmedit.models.builder import build_component 5 | from mmedit.models.registry import BACKBONES 6 | 7 | 8 | @BACKBONES.register_module() 9 | class SimpleEncoderDecoder(nn.Module): 10 | """Simple encoder-decoder model from matting. 11 | 12 | Args: 13 | encoder (dict): Config of the encoder. 14 | decoder (dict): Config of the decoder. 15 | """ 16 | 17 | def __init__(self, encoder, decoder): 18 | super().__init__() 19 | 20 | self.encoder = build_component(encoder) 21 | if hasattr(self.encoder, 'out_channels'): 22 | decoder['in_channels'] = self.encoder.out_channels 23 | self.decoder = build_component(decoder) 24 | 25 | def init_weights(self, pretrained=None): 26 | self.encoder.init_weights(pretrained) 27 | self.decoder.init_weights() 28 | 29 | def forward(self, *args, **kwargs): 30 | """Forward function. 31 | 32 | Returns: 33 | Tensor: The output tensor of the decoder. 34 | """ 35 | out = self.encoder(*args, **kwargs) 36 | out = self.decoder(out) 37 | return out 38 | -------------------------------------------------------------------------------- /mmedit/models/backbones/generation_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .resnet_generator import ResnetGenerator 3 | from .unet_generator import UnetGenerator 4 | 5 | __all__ = ['UnetGenerator', 'ResnetGenerator'] 6 | -------------------------------------------------------------------------------- /mmedit/models/backbones/sr_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .basicvsr_net import BasicVSRNet 3 | from .basicvsr_pp import BasicVSRPlusPlus 4 | from .dic_net import DICNet 5 | from .edsr import EDSR 6 | from .edvr_net import EDVRNet 7 | from .glean_styleganv2 import GLEANStyleGANv2 8 | from .iconvsr import IconVSR 9 | from .liif_net import LIIFEDSR, LIIFRDN 10 | from .rdn import RDN 11 | from .real_basicvsr_net import RealBasicVSRNet 12 | from .rrdb_net import RRDBNet 13 | from .sr_resnet import MSRResNet 14 | from .srcnn import SRCNN 15 | from .tdan_net import TDANNet 16 | from .tof import TOFlow 17 | from .ttsr_net import TTSRNet 18 | from .attention_block import CBAMBlock, CBAMBlock_noBN 19 | from .restormer_net import Restormer 20 | from .airnet import AirNet 21 | from .dmci_former import Restor_Encoderv2_pos3 22 | from .dmci_restormer import DMCI_ED_restormer_dual 23 | __all__ = [ 24 | 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'DICNet', 25 | 'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet', 26 | 'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus', 'RealBasicVSRNet', 27 | 'Restor_Encoderv2_pos3', 'DMCI_ED_restormer_dual', 28 | 'CBAMBlock', 'CBAMBlock_noBN', 'Restormer', 'AirNet', 29 | 30 | ] 31 | -------------------------------------------------------------------------------- /mmedit/models/backbones/sr_backbones/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.utils import _pair 6 | 7 | from mmcv.ops import modulated_deform_conv2d 8 | 9 | 10 | class DCN_layer(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 12 | groups=1, deformable_groups=1, bias=True, extra_offset_mask=True): 13 | super(DCN_layer, self).__init__() 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | self.kernel_size = _pair(kernel_size) 17 | self.stride = stride 18 | self.padding = padding 19 | self.dilation = dilation 20 | self.groups = groups 21 | self.deformable_groups = deformable_groups 22 | self.with_bias = bias 23 | 24 | self.weight = nn.Parameter( 25 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 26 | 27 | self.extra_offset_mask = extra_offset_mask 28 | self.conv_offset_mask = nn.Conv2d( 29 | self.in_channels * 2, 30 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 31 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 32 | bias=True 33 | ) 34 | 35 | if bias: 36 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 37 | else: 38 | self.register_parameter('bias', None) 39 | 40 | self.init_offset() 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | n = self.in_channels 45 | for k in self.kernel_size: 46 | n *= k 47 | stdv = 1.
/ math.sqrt(n) 48 | self.weight.data.uniform_(-stdv, stdv) 49 | if self.bias is not None: 50 | self.bias.data.zero_() 51 | 52 | def init_offset(self): 53 | self.conv_offset_mask.weight.data.zero_() 54 | self.conv_offset_mask.bias.data.zero_() 55 | 56 | def forward(self, input_feat, inter): 57 | feat_degradation = torch.cat([input_feat, inter], dim=1) 58 | 59 | out = self.conv_offset_mask(feat_degradation) 60 | o1, o2, mask = torch.chunk(out, 3, dim=1) 61 | offset = torch.cat((o1, o2), dim=1) 62 | mask = torch.sigmoid(mask) 63 | 64 | return modulated_deform_conv2d(input_feat.contiguous(), offset, mask, self.weight, self.bias, self.stride, 65 | self.padding, self.dilation, self.groups, self.deformable_groups) 66 | -------------------------------------------------------------------------------- /mmedit/models/backbones/sr_backbones/duf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DynamicUpsamplingFilter(nn.Module): 9 | """Dynamic upsampling filter used in DUF. 10 | 11 | Ref: https://github.com/yhjo09/VSR-DUF. 12 | It only supports input with 3 channels. And it applies the same filters 13 | to 3 channels. 14 | 15 | Args: 16 | filter_size (tuple): Filter size of generated filters. 17 | The shape is (kh, kw). Default: (5, 5). 18 | """ 19 | 20 | def __init__(self, filter_size=(5, 5)): 21 | super().__init__() 22 | if not isinstance(filter_size, tuple): 23 | raise TypeError('The type of filter_size must be tuple, ' 24 | f'but got type{filter_size}') 25 | if len(filter_size) != 2: 26 | raise ValueError('The length of filter size must be 2, ' 27 | f'but got {len(filter_size)}.') 28 | # generate a local expansion filter, similar to im2col 29 | self.filter_size = filter_size 30 | filter_prod = np.prod(filter_size) 31 | expansion_filter = torch.eye(int(filter_prod)).view( 32 | filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw) 33 | self.expansion_filter = expansion_filter.repeat( 34 | 3, 1, 1, 1) # repeat for all the 3 channels 35 | 36 | def forward(self, x, filters): 37 | """Forward function for DynamicUpsamplingFilter. 38 | 39 | Args: 40 | x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w). 41 | filters (Tensor): Generated dynamic filters. 42 | The shape is (n, filter_prod, upsampling_square, h, w). 43 | filter_prod: prod of filter kernel size, e.g., 1*5*5=25. 
44 | upsampling_square: similar to pixel shuffle, 45 | upsampling_square = upsampling * upsampling 46 | e.g., for x 4 upsampling, upsampling_square= 4*4 = 16 47 | 48 | Returns: 49 | Tensor: Filtered image with shape (n, 3*upsampling, h, w) 50 | """ 51 | n, filter_prod, upsampling_square, h, w = filters.size() 52 | kh, kw = self.filter_size 53 | expanded_input = F.conv2d( 54 | x, 55 | self.expansion_filter.to(x), 56 | padding=(kh // 2, kw // 2), 57 | groups=3) # (n, 3*filter_prod, h, w) 58 | expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute( 59 | 0, 3, 4, 1, 2) # (n, h, w, 3, filter_prod) 60 | filters = filters.permute( 61 | 0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square] 62 | out = torch.matmul(expanded_input, 63 | filters) # (n, h, w, 3, upsampling_square) 64 | return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w) 65 | -------------------------------------------------------------------------------- /mmedit/models/backbones/sr_backbones/mapping_net.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | # from torch_utils import misc 7 | # import dnnlib 8 | # import legacy 9 | 10 | from .dmci_styleganxl_cascade_v2 import Mapping_control 11 | # from mmedit.models.backbones.sr_backbones.dmci_styleganxl_cascade_v2 import Mapping_control 12 | 13 | 14 | class Mapping_control_z(nn.Module): 15 | def __init__(self, out_res=8, out_dim=512, inplace=True, **kwargs): 16 | # # re_list=[8 16 32 64 128 256] 17 | super().__init__() 18 | self.out_res=out_res 19 | self.down_feas=Mapping_control(out_res=out_res, out_dim=out_dim, inplace=inplace, **kwargs) 20 | 21 | self.to_z0 = nn.Sequential( # 256 8 22 | nn.Conv2d(256, 256, 3, 2, 1, bias=True),nn.LeakyReLU(inplace=inplace), # 256 4 23 | nn.Conv2d(256, 256, 3, 2, 1, bias=True),nn.LeakyReLU(inplace=inplace), # 256 2 24 | nn.AdaptiveAvgPool2d((1,1))# 256 x 1 x 1 25 | ) # 64 x 2 x 2 26 | 27 | self.to_z1 = nn.Linear(256, 128) # 1x 256 -> 1 x 64 28 | 29 | def forward(self, latent): 30 | latent,control_list=self.down_feas(latent) 31 | latent=latent.view(latent.shape[0],-1,self.out_res,self.out_res) 32 | z0= self.to_z0(latent) 33 | z_noise=self.to_z1(z0.view(-1,256)) 34 | 35 | return z_noise,control_list 36 | 37 | 38 | # class Mapping_control_vec(nn.Module): 39 | # def __init__(self, input_res=256, tar_dim=128, base_dim=16, max_dim=512, inplace=True, **kwargs): 40 | # # # re_list=[8 16 32 64 128 256] 41 | # super().__init__() 42 | # self.input_layer=nn.Sequential(nn.Conv2d(3, base_dim, 3, 1, 1, bias=True),nn.LeakyReLU(inplace=inplace)) 43 | # self.log_size = int(np.log2(input_res)) 44 | # pre_ch=base_dim 45 | # for i in range(1,self.log_size+1): 46 | # input_res=input_res//2 47 | # cur_ch=min(base_dim*i,max_dim) 48 | # layer=nn.Sequential(nn.Conv2d(pre_ch, cur_ch, 3, 2, 1, bias=True),nn.AdaptiveAvgPool2d(input_res), nn.LeakyReLU(inplace=inplace)) 49 | # setattr(self,f"down_layer{i}",layer) 50 | # pre_ch=cur_ch 51 | 52 | # self.to_z0 = nn.AdaptiveAvgPool2d(1) 53 | # # flatten input.view(input.size(0), -1) 54 | # self.to_z1 = nn.Linear(max_dim, tar_dim) 55 | 56 | # def forward(self, img): 57 | # latent = self.input_layer(img) 58 | # for i in range(1,self.log_size+1): 59 | # latent=getattr(self,f"down_layer{i}")(latent) 60 | # z0= self.to_z0(latent) 61 | # z0_flatten=z0.view(z0.size(0), -1) 62 | # vec=self.to_z1(z0_flatten) 63 | # return vec 64 | 65 | 
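# ---------------------------------------------------------------------------
# Illustrative shape check (editor's addition, not part of the original
# source): DynamicUpsamplingFilter from duf.py above expects a 3-channel input
# and per-pixel filters. With a 5x5 filter (filter_prod = 25) and 4x
# upsampling (upsampling_square = 16), the shapes work out as follows.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    from mmedit.models.backbones.sr_backbones.duf import DynamicUpsamplingFilter

    duf = DynamicUpsamplingFilter(filter_size=(5, 5))
    x = torch.rand(1, 3, 8, 8)             # (n, 3, h, w)
    filters = torch.rand(1, 25, 16, 8, 8)  # (n, filter_prod, upsampling_square, h, w)
    out = duf(x, filters)
    assert out.shape == (1, 3 * 16, 8, 8)  # (n, 3 * upsampling_square, h, w)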
-------------------------------------------------------------------------------- /mmedit/models/backbones/vfi_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .cain_net import CAINNet 3 | from .flavr_net import FLAVRNet 4 | from .tof_vfi_net import TOFlowVFINet 5 | 6 | __all__ = ['CAINNet', 'TOFlowVFINet', 'FLAVRNet'] 7 | -------------------------------------------------------------------------------- /mmedit/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv import build_from_cfg 4 | 5 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS 6 | 7 | 8 | def build(cfg, registry, default_args=None): 9 | """Build module function. 10 | 11 | Args: 12 | cfg (dict): Configuration for building modules. 13 | registry (obj): ``registry`` object. 14 | default_args (dict, optional): Default arguments. Defaults to None. 15 | """ 16 | if isinstance(cfg, list): 17 | modules = [ 18 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 19 | ] 20 | return nn.Sequential(*modules) 21 | 22 | return build_from_cfg(cfg, registry, default_args) 23 | 24 | 25 | def build_backbone(cfg): 26 | """Build backbone. 27 | 28 | Args: 29 | cfg (dict): Configuration for building backbone. 30 | """ 31 | return build(cfg, BACKBONES) 32 | 33 | 34 | def build_component(cfg): 35 | """Build component. 36 | 37 | Args: 38 | cfg (dict): Configuration for building component. 39 | """ 40 | return build(cfg, COMPONENTS) 41 | 42 | 43 | def build_loss(cfg): 44 | """Build loss. 45 | 46 | Args: 47 | cfg (dict): Configuration for building loss. 48 | """ 49 | return build(cfg, LOSSES) 50 | 51 | 52 | def build_model(cfg, train_cfg=None, test_cfg=None): 53 | """Build model. 54 | 55 | Args: 56 | cfg (dict): Configuration for building model. 57 | train_cfg (dict): Training configuration. Default: None. 58 | test_cfg (dict): Testing configuration. Default: None. 59 | """ 60 | return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 61 | -------------------------------------------------------------------------------- /mmedit/models/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .aspp import ASPP 3 | from .contextual_attention import ContextualAttentionModule 4 | from .conv import * # noqa: F401, F403 5 | from .downsample import pixel_unshuffle 6 | from .ensemble import SpatialTemporalEnsemble 7 | from .flow_warp import flow_warp 8 | from .gated_conv_module import SimpleGatedConvModule 9 | from .gca_module import GCAModule 10 | from .generation_model_utils import (GANImageBuffer, ResidualBlockWithDropout, 11 | UnetSkipConnectionBlock, 12 | generation_init_weights) 13 | from .img_normalize import ImgNormalize 14 | from .linear_module import LinearModule 15 | from .mask_conv_module import MaskConvModule 16 | from .model_utils import (extract_around_bbox, extract_bbox_patch, scale_bbox, 17 | set_requires_grad) 18 | from .partial_conv import PartialConv2d 19 | from .separable_conv_module import DepthwiseSeparableConvModule 20 | from .sr_backbone_utils import (ResidualBlockNoBN, default_init_weights, 21 | make_layer) 22 | from .upsample import PixelShufflePack 23 | from .common_model_dual import CompressionModel_dual 24 | from .layers import conv3x3, conv1x1, DepthConvBlock2, DepthConvBlock3, DepthConvBlock4,ResidualBlockUpsample, ResidualBlockWithStride2,ResidualBlock,ResidualBlockWithStride2_dynamic,Dynamic_conv2d,ConvFFN2 25 | from .video_net import UNet 26 | __all__ = [ 27 | 'ASPP', 'PartialConv2d', 'PixelShufflePack', 'default_init_weights', 28 | 'ResidualBlockNoBN', 'make_layer', 'MaskConvModule', 'extract_bbox_patch', 29 | 'extract_around_bbox', 'set_requires_grad', 'scale_bbox', 30 | 'DepthwiseSeparableConvModule', 'ContextualAttentionModule', 'GCAModule', 31 | 'SimpleGatedConvModule', 'LinearModule', 'flow_warp', 'ImgNormalize', 32 | 'generation_init_weights', 'GANImageBuffer', 'UnetSkipConnectionBlock', 33 | 'ResidualBlockWithDropout', 'pixel_unshuffle', 'SpatialTemporalEnsemble','ResidualBlock','conv3x3', 'conv1x1', 'DepthConvBlock2', 'DepthConvBlock3', 'DepthConvBlock4','ResidualBlockUpsample', 'ResidualBlockWithStride2','UNet', 34 | 'Dynamic_conv2d','ConvFFN2', 35 | 'ResidualBlockWithStride2_dynamic','CompressionModel_dual' 36 | ] 37 | -------------------------------------------------------------------------------- /mmedit/models/common/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import CONV_LAYERS 3 | from torch import nn 4 | 5 | CONV_LAYERS.register_module('Deconv', module=nn.ConvTranspose2d) 6 | # TODO: octave conv 7 | -------------------------------------------------------------------------------- /mmedit/models/common/downsample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def pixel_unshuffle(x, scale): 3 | """Down-sample by pixel unshuffle. 4 | 5 | Args: 6 | x (Tensor): Input tensor. 7 | scale (int): Scale factor. 8 | 9 | Returns: 10 | Tensor: Output tensor. 
11 | """ 12 | 13 | b, c, h, w = x.shape 14 | if h % scale != 0 or w % scale != 0: 15 | raise AssertionError( 16 | f'Invalid scale ({scale}) of pixel unshuffle for tensor ' 17 | f'with shape: {x.shape}') 18 | h = int(h / scale) 19 | w = int(w / scale) 20 | x = x.view(b, c, h, scale, w, scale) 21 | x = x.permute(0, 1, 3, 5, 2, 4) 22 | return x.reshape(b, -1, h, w) 23 | -------------------------------------------------------------------------------- /mmedit/models/common/flow_warp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def flow_warp(x, 7 | flow, 8 | interpolation='bilinear', 9 | padding_mode='zeros', 10 | align_corners=True): 11 | """Warp an image or a feature map with optical flow. 12 | 13 | Args: 14 | x (Tensor): Tensor with size (n, c, h, w). 15 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 16 | a two-channel, denoting the width and height relative offsets. 17 | Note that the values are not normalized to [-1, 1]. 18 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 19 | Default: 'bilinear'. 20 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 21 | Default: 'zeros'. 22 | align_corners (bool): Whether align corners. Default: True. 23 | 24 | Returns: 25 | Tensor: Warped image or feature map. 26 | """ 27 | if x.size()[-2:] != flow.size()[1:3]: 28 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 29 | f'flow ({flow.size()[1:3]}) are not the same.') 30 | _, _, h, w = x.size() 31 | # create mesh grid 32 | device = flow.device 33 | grid_y, grid_x = torch.meshgrid( 34 | torch.arange(0, h, device=device, dtype=x.dtype), 35 | torch.arange(0, w, device=device, dtype=x.dtype)) 36 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2 37 | grid.requires_grad = False 38 | 39 | grid_flow = grid + flow 40 | # scale grid_flow to [-1,1] 41 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 42 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 43 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 44 | output = F.grid_sample( 45 | x, 46 | grid_flow, 47 | mode=interpolation, 48 | padding_mode=padding_mode, 49 | align_corners=align_corners) 50 | return output 51 | -------------------------------------------------------------------------------- /mmedit/models/common/gated_conv_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | from mmcv.cnn import ConvModule, build_activation_layer 7 | 8 | 9 | class SimpleGatedConvModule(nn.Module): 10 | """Simple Gated Convolutional Module. 11 | 12 | This module is a simple gated convolutional module. The detailed formula 13 | is: 14 | 15 | .. math:: 16 | y = \\phi(conv1(x)) * \\sigma(conv2(x)), 17 | 18 | where `phi` is the feature activation function and `sigma` is the gate 19 | activation function. In default, the gate activation function is sigmoid. 20 | 21 | Args: 22 | in_channels (int): Same as nn.Conv2d. 23 | out_channels (int): The number of channels of the output feature. Note 24 | that `out_channels` in the conv module is doubled since this module 25 | contains two convolutions for feature and gate separately. 26 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 
27 | feat_act_cfg (dict): Config dict for feature activation layer. 28 | gate_act_cfg (dict): Config dict for gate activation layer. 29 | kwargs (keyword arguments): Same as `ConvModule`. 30 | """ 31 | 32 | def __init__(self, 33 | in_channels, 34 | out_channels, 35 | kernel_size, 36 | feat_act_cfg=dict(type='ELU'), 37 | gate_act_cfg=dict(type='Sigmoid'), 38 | **kwargs): 39 | super().__init__() 40 | # the activation function should specified outside conv module 41 | kwargs_ = copy.deepcopy(kwargs) 42 | kwargs_['act_cfg'] = None 43 | self.with_feat_act = feat_act_cfg is not None 44 | self.with_gate_act = gate_act_cfg is not None 45 | 46 | self.conv = ConvModule(in_channels, out_channels * 2, kernel_size, 47 | **kwargs_) 48 | 49 | if self.with_feat_act: 50 | self.feat_act = build_activation_layer(feat_act_cfg) 51 | 52 | if self.with_gate_act: 53 | self.gate_act = build_activation_layer(gate_act_cfg) 54 | 55 | def forward(self, x): 56 | """Forward Function. 57 | 58 | Args: 59 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 60 | 61 | Returns: 62 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 63 | """ 64 | x = self.conv(x) 65 | x, gate = torch.split(x, x.size(1) // 2, dim=1) 66 | if self.with_feat_act: 67 | x = self.feat_act(x) 68 | if self.with_gate_act: 69 | gate = self.gate_act(gate) 70 | x = x * gate 71 | 72 | return x 73 | -------------------------------------------------------------------------------- /mmedit/models/common/img_normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ImgNormalize(nn.Conv2d): 7 | """Normalize images with the given mean and std value. 8 | 9 | Based on Conv2d layer, can work in GPU. 10 | 11 | Args: 12 | pixel_range (float): Pixel range of feature. 13 | img_mean (Tuple[float]): Image mean of each channel. 14 | img_std (Tuple[float]): Image std of each channel. 15 | sign (int): Sign of bias. Default -1. 16 | """ 17 | 18 | def __init__(self, pixel_range, img_mean, img_std, sign=-1): 19 | 20 | assert len(img_mean) == len(img_std) 21 | num_channels = len(img_mean) 22 | super().__init__(num_channels, num_channels, kernel_size=1) 23 | 24 | std = torch.Tensor(img_std) 25 | self.weight.data = torch.eye(num_channels).view( 26 | num_channels, num_channels, 1, 1) 27 | self.weight.data.div_(std.view(num_channels, 1, 1, 1)) 28 | self.bias.data = sign * pixel_range * torch.Tensor(img_mean) 29 | self.bias.data.div_(std) 30 | 31 | self.weight.requires_grad = False 32 | self.bias.requires_grad = False 33 | -------------------------------------------------------------------------------- /mmedit/models/common/upsample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sr_backbone_utils import default_init_weights 6 | 7 | 8 | class PixelShufflePack(nn.Module): 9 | """ Pixel Shuffle upsample layer. 10 | 11 | Args: 12 | in_channels (int): Number of input channels. 13 | out_channels (int): Number of output channels. 14 | scale_factor (int): Upsample ratio. 15 | upsample_kernel (int): Kernel size of Conv layer to expand channels. 16 | 17 | Returns: 18 | Upsampled feature map. 
19 | """ 20 | 21 | def __init__(self, in_channels, out_channels, scale_factor, 22 | upsample_kernel): 23 | super().__init__() 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.scale_factor = scale_factor 27 | self.upsample_kernel = upsample_kernel 28 | self.upsample_conv = nn.Conv2d( 29 | self.in_channels, 30 | self.out_channels * scale_factor * scale_factor, 31 | self.upsample_kernel, 32 | padding=(self.upsample_kernel - 1) // 2) 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | """Initialize weights for PixelShufflePack. 37 | """ 38 | default_init_weights(self, 1) 39 | 40 | def forward(self, x): 41 | """Forward function for PixelShufflePack. 42 | 43 | Args: 44 | x (Tensor): Input tensor with shape (n, c, h, w). 45 | 46 | Returns: 47 | Tensor: Forward results. 48 | """ 49 | x = self.upsample_conv(x) 50 | x = F.pixel_shuffle(x, self.scale_factor) 51 | return x 52 | -------------------------------------------------------------------------------- /mmedit/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG, 3 | MultiLayerDiscriminator, PatchDiscriminator, 4 | UNetDiscriminatorWithSpectralNorm) 5 | from .refiners import DeepFillRefiner, PlainRefiner 6 | from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator,SynthesisNetworkv2 7 | 8 | __all__ = [ 9 | 'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', 10 | 'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator', 11 | 'StyleGAN2Discriminator', 'StyleGANv2Generator','SynthesisNetworkv2', 12 | 'UNetDiscriminatorWithSpectralNorm' 13 | ] 14 | -------------------------------------------------------------------------------- /mmedit/models/components/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .deepfill_disc import DeepFillv1Discriminators 3 | from .gl_disc import GLDiscs 4 | from .light_cnn import LightCNN 5 | from .modified_vgg import ModifiedVGG,DiscriminatorBlock 6 | from .multi_layer_disc import MultiLayerDiscriminator 7 | from .patch_disc import PatchDiscriminator 8 | from .smpatch_disc import SoftMaskPatchDiscriminator 9 | from .ttsr_disc import TTSRDiscriminator 10 | from .unet_disc import UNetDiscriminatorWithSpectralNorm 11 | from .discriminator_arch import UNetDiscriminatorSN 12 | __all__ = [ 13 | 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', 'TTSRDiscriminator', 14 | 'DeepFillv1Discriminators', 'PatchDiscriminator', 'LightCNN', 15 | 'UNetDiscriminatorWithSpectralNorm', 'SoftMaskPatchDiscriminator','DiscriminatorBlock', 16 | 'UNetDiscriminatorSN' 17 | ] 18 | -------------------------------------------------------------------------------- /mmedit/models/components/discriminators/deepfill_disc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import normal_init 4 | from mmcv.runner import load_checkpoint 5 | 6 | from mmedit.models.builder import build_component 7 | from mmedit.models.registry import COMPONENTS 8 | from mmedit.utils import get_root_logger 9 | 10 | 11 | @COMPONENTS.register_module() 12 | class DeepFillv1Discriminators(nn.Module): 13 | """Discriminators used in DeepFillv1 model. 
14 | 15 | In DeepFillv1 model, the discriminators are independent without any 16 | concatenation like Global&Local model. Thus, we call this model 17 | `DeepFillv1Discriminators`. There exist a global discriminator and a local 18 | discriminator with global and local input respectively. 19 | 20 | The details can be found in: 21 | Generative Image Inpainting with Contextual Attention. 22 | 23 | Args: 24 | global_disc_cfg (dict): Config dict for global discriminator. 25 | local_disc_cfg (dict): Config dict for local discriminator. 26 | """ 27 | 28 | def __init__(self, global_disc_cfg, local_disc_cfg): 29 | super().__init__() 30 | self.global_disc = build_component(global_disc_cfg) 31 | self.local_disc = build_component(local_disc_cfg) 32 | 33 | def forward(self, x): 34 | """Forward function. 35 | 36 | Args: 37 | x (tuple[torch.Tensor]): Contains global image and the local image 38 | patch. 39 | 40 | Returns: 41 | tuple[torch.Tensor]: Contains the prediction from discriminators \ 42 | in global image and local image patch. 43 | """ 44 | global_img, local_img = x 45 | 46 | global_pred = self.global_disc(global_img) 47 | local_pred = self.local_disc(local_img) 48 | 49 | return global_pred, local_pred 50 | 51 | def init_weights(self, pretrained=None): 52 | """Init weights for models. 53 | 54 | Args: 55 | pretrained (str, optional): Path for pretrained weights. If given 56 | None, pretrained weights will not be loaded. Defaults to None. 57 | """ 58 | if isinstance(pretrained, str): 59 | logger = get_root_logger() 60 | load_checkpoint(self, pretrained, strict=False, logger=logger) 61 | elif pretrained is None: 62 | for m in self.modules(): 63 | if isinstance(m, nn.Linear): 64 | normal_init(m, 0, std=0.02) 65 | elif isinstance(m, nn.Conv2d): 66 | normal_init(m, 0.0, std=0.02) 67 | else: 68 | raise TypeError('pretrained must be a str or None but got' 69 | f'{type(pretrained)} instead.') 70 | -------------------------------------------------------------------------------- /mmedit/models/components/discriminators/gl_disc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.runner import load_checkpoint 5 | 6 | from mmedit.models.registry import COMPONENTS 7 | from mmedit.utils import get_root_logger 8 | from .multi_layer_disc import MultiLayerDiscriminator 9 | 10 | 11 | @COMPONENTS.register_module() 12 | class GLDiscs(nn.Module): 13 | """Discriminators in Global&Local 14 | 15 | This discriminator contains a local discriminator and a global 16 | discriminator as described in the original paper: 17 | Globally and locally Consistent Image Completion 18 | 19 | Args: 20 | global_disc_cfg (dict): Config dict to build global discriminator. 21 | local_disc_cfg (dict): Config dict to build local discriminator. 22 | """ 23 | 24 | def __init__(self, global_disc_cfg, local_disc_cfg): 25 | super().__init__() 26 | self.global_disc = MultiLayerDiscriminator(**global_disc_cfg) 27 | self.local_disc = MultiLayerDiscriminator(**local_disc_cfg) 28 | 29 | self.fc = nn.Linear(2048, 1, bias=True) 30 | 31 | def forward(self, x): 32 | """Forward function. 33 | 34 | Args: 35 | x (tuple[torch.Tensor]): Contains global image and the local image 36 | patch. 37 | 38 | Returns: 39 | tuple[torch.Tensor]: Contains the prediction from discriminators \ 40 | in global image and local image patch. 
41 | """ 42 | g_img, l_img = x 43 | g_pred = self.global_disc(g_img) 44 | l_pred = self.local_disc(l_img) 45 | 46 | pred = self.fc(torch.cat([g_pred, l_pred], dim=1)) 47 | 48 | return pred 49 | 50 | def init_weights(self, pretrained=None): 51 | """Init weights for models. 52 | 53 | Args: 54 | pretrained (str, optional): Path for pretrained weights. If given 55 | None, pretrained weights will not be loaded. Defaults to None. 56 | """ 57 | if isinstance(pretrained, str): 58 | logger = get_root_logger() 59 | load_checkpoint(self, pretrained, strict=False, logger=logger) 60 | elif pretrained is None: 61 | for m in self.modules(): 62 | # Here, we only initialize the module with fc layer since the 63 | # conv and norm layers has been initialized in `ConvModule`. 64 | if isinstance(m, nn.Linear): 65 | nn.init.normal_(m.weight.data, 0.0, 0.02) 66 | nn.init.constant_(m.bias.data, 0.0) 67 | else: 68 | raise TypeError('pretrained must be a str or None') 69 | -------------------------------------------------------------------------------- /mmedit/models/components/discriminators/ttsr_disc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.runner import load_checkpoint 4 | 5 | from mmedit.models.registry import COMPONENTS 6 | from mmedit.utils import get_root_logger 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class TTSRDiscriminator(nn.Module): 11 | """A discriminator for TTSR. 12 | 13 | Args: 14 | in_channels (int): Channel number of inputs. Default: 3. 15 | in_size (int): Size of input image. Default: 160. 16 | """ 17 | 18 | def __init__(self, in_channels=3, in_size=160): 19 | super().__init__() 20 | 21 | self.body = nn.Sequential( 22 | nn.Conv2d(in_channels, 32, 3, 1, 1), nn.LeakyReLU(0.2), 23 | nn.Conv2d(32, 32, 3, 2, 1), nn.LeakyReLU(0.2), 24 | nn.Conv2d(32, 64, 3, 1, 1), nn.LeakyReLU(0.2), 25 | nn.Conv2d(64, 64, 3, 2, 1), nn.LeakyReLU(0.2), 26 | nn.Conv2d(64, 128, 3, 1, 1), nn.LeakyReLU(0.2), 27 | nn.Conv2d(128, 128, 3, 2, 1), nn.LeakyReLU(0.2), 28 | nn.Conv2d(128, 256, 3, 1, 1), nn.LeakyReLU(0.2), 29 | nn.Conv2d(256, 256, 3, 2, 1), nn.LeakyReLU(0.2), 30 | nn.Conv2d(256, 512, 3, 1, 1), nn.LeakyReLU(0.2), 31 | nn.Conv2d(512, 512, 3, 2, 1), nn.LeakyReLU(0.2)) 32 | 33 | self.last = nn.Sequential( 34 | nn.Linear(in_size // 32 * in_size // 32 * 512, 1024), 35 | nn.LeakyReLU(0.2), nn.Linear(1024, 1)) 36 | 37 | def forward(self, x): 38 | """Forward function. 39 | 40 | Args: 41 | x (Tensor): Input tensor with shape (n, c, h, w). 42 | 43 | Returns: 44 | Tensor: Forward results. 45 | """ 46 | 47 | x = self.body(x) 48 | x = x.view(x.size(0), -1) 49 | x = self.last(x) 50 | 51 | return x 52 | 53 | def init_weights(self, pretrained=None, strict=True): 54 | """Init weights for models. 55 | 56 | Args: 57 | pretrained (str, optional): Path for pretrained weights. If given 58 | None, pretrained weights will not be loaded. Defaults to None. 59 | strict (boo, optional): Whether strictly load the pretrained model. 60 | Defaults to True. 61 | """ 62 | if isinstance(pretrained, str): 63 | logger = get_root_logger() 64 | load_checkpoint(self, pretrained, strict=strict, logger=logger) 65 | elif pretrained is not None: 66 | raise TypeError(f'"pretrained" must be a str or None. 
' 67 | f'But received {type(pretrained)}.') 68 | -------------------------------------------------------------------------------- /mmedit/models/components/refiners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .deepfill_refiner import DeepFillRefiner 3 | from .mlp_refiner import MLPRefiner 4 | from .plain_refiner import PlainRefiner 5 | 6 | __all__ = ['PlainRefiner', 'DeepFillRefiner', 'MLPRefiner'] 7 | -------------------------------------------------------------------------------- /mmedit/models/components/refiners/deepfill_refiner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mmedit.models.builder import build_component 7 | from mmedit.models.registry import COMPONENTS 8 | 9 | 10 | @COMPONENTS.register_module() 11 | class DeepFillRefiner(nn.Module): 12 | """Refiner used in DeepFill model. 13 | 14 | This implementation follows: 15 | Generative Image Inpainting with Contextual Attention. 16 | 17 | Args: 18 | encoder_attention (dict): Config dict for encoder used in branch 19 | with contextual attention module. 20 | encoder_conv (dict): Config dict for encoder used in branch with 21 | just convolutional operation. 22 | dilation_neck (dict): Config dict for dilation neck in branch with 23 | just convolutional operation. 24 | contextual_attention (dict): Config dict for contextual attention 25 | neck. 26 | decoder (dict): Config dict for decoder used to fuse and decode 27 | features. 28 | """ 29 | 30 | def __init__(self, 31 | encoder_attention=dict( 32 | type='DeepFillEncoder', encoder_type='stage2_attention'), 33 | encoder_conv=dict( 34 | type='DeepFillEncoder', encoder_type='stage2_conv'), 35 | dilation_neck=dict( 36 | type='GLDilationNeck', 37 | in_channels=128, 38 | act_cfg=dict(type='ELU')), 39 | contextual_attention=dict( 40 | type='ContextualAttentionNeck', in_channels=128), 41 | decoder=dict(type='DeepFillDecoder', in_channels=256)): 42 | super().__init__() 43 | self.encoder_attention = build_component(encoder_attention) 44 | self.encoder_conv = build_component(encoder_conv) 45 | self.contextual_attention_neck = build_component(contextual_attention) 46 | self.dilation_neck = build_component(dilation_neck) 47 | self.decoder = build_component(decoder) 48 | 49 | def forward(self, x, mask): 50 | """Forward Function. 51 | 52 | Args: 53 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 54 | mask (torch.Tensor): Input tensor with shape of (n, 1, h, w). 55 | 56 | Returns: 57 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 
58 | """ 59 | # conv branch 60 | encoder_dict = self.encoder_conv(x) 61 | conv_x = self.dilation_neck(encoder_dict['out']) 62 | 63 | # contextual attention branch 64 | attention_x = self.encoder_attention(x)['out'] 65 | h_x, w_x = attention_x.shape[-2:] 66 | # resale mask to a smaller size 67 | resized_mask = F.interpolate(mask, size=(h_x, w_x)) 68 | attention_x, offset = self.contextual_attention_neck( 69 | attention_x, resized_mask) 70 | 71 | # concat two branches 72 | x = torch.cat([conv_x, attention_x], dim=1) 73 | x = self.decoder(dict(out=x)) 74 | 75 | return x, offset 76 | -------------------------------------------------------------------------------- /mmedit/models/components/refiners/mlp_refiner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.runner import load_checkpoint 4 | 5 | from mmedit.models.registry import COMPONENTS 6 | from mmedit.utils import get_root_logger 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class MLPRefiner(nn.Module): 11 | """Multilayer perceptrons (MLPs), refiner used in LIIF. 12 | 13 | Args: 14 | in_dim (int): Input dimension. 15 | out_dim (int): Output dimension. 16 | hidden_list (list[int]): List of hidden dimensions. 17 | """ 18 | 19 | def __init__(self, in_dim, out_dim, hidden_list): 20 | super().__init__() 21 | layers = [] 22 | lastv = in_dim 23 | for hidden in hidden_list: 24 | layers.append(nn.Linear(lastv, hidden)) 25 | layers.append(nn.ReLU()) 26 | lastv = hidden 27 | layers.append(nn.Linear(lastv, out_dim)) 28 | self.layers = nn.Sequential(*layers) 29 | 30 | def forward(self, x): 31 | """Forward function. 32 | 33 | Args: 34 | x (Tensor): The input of MLP. 35 | 36 | Returns: 37 | Tensor: The output of MLP. 38 | """ 39 | shape = x.shape[:-1] 40 | x = self.layers(x.view(-1, x.shape[-1])) 41 | return x.view(*shape, -1) 42 | 43 | def init_weights(self, pretrained=None, strict=True): 44 | """Init weights for models. 45 | 46 | Args: 47 | pretrained (str, optional): Path for pretrained weights. If given 48 | None, pretrained weights will not be loaded. Defaults to None. 49 | strict (boo, optional): Whether strictly load the pretrained model. 50 | Defaults to True. 51 | """ 52 | if isinstance(pretrained, str): 53 | logger = get_root_logger() 54 | load_checkpoint(self, pretrained, strict=strict, logger=logger) 55 | elif pretrained is None: 56 | pass 57 | else: 58 | raise TypeError(f'"pretrained" must be a str or None. ' 59 | f'But received {type(pretrained)}.') 60 | -------------------------------------------------------------------------------- /mmedit/models/components/refiners/plain_refiner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn.utils.weight_init import xavier_init 5 | 6 | from mmedit.models.registry import COMPONENTS 7 | 8 | 9 | @COMPONENTS.register_module() 10 | class PlainRefiner(nn.Module): 11 | """Simple refiner from Deep Image Matting. 12 | 13 | Args: 14 | conv_channels (int): Number of channels produced by the three main 15 | convolutional layer. 16 | loss_refine (dict): Config of the loss of the refiner. Default: None. 17 | pretrained (str): Name of pretrained model. Default: None. 
18 | """ 19 | 20 | def __init__(self, conv_channels=64, pretrained=None): 21 | super().__init__() 22 | 23 | assert pretrained is None, 'pretrained not supported yet' 24 | 25 | self.refine_conv1 = nn.Conv2d( 26 | 4, conv_channels, kernel_size=3, padding=1) 27 | self.refine_conv2 = nn.Conv2d( 28 | conv_channels, conv_channels, kernel_size=3, padding=1) 29 | self.refine_conv3 = nn.Conv2d( 30 | conv_channels, conv_channels, kernel_size=3, padding=1) 31 | self.refine_pred = nn.Conv2d( 32 | conv_channels, 1, kernel_size=3, padding=1) 33 | 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | def init_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | xavier_init(m) 40 | 41 | def forward(self, x, raw_alpha): 42 | """Forward function. 43 | 44 | Args: 45 | x (Tensor): The input feature map of refiner. 46 | raw_alpha (Tensor): The raw predicted alpha matte. 47 | 48 | Returns: 49 | Tensor: The refined alpha matte. 50 | """ 51 | out = self.relu(self.refine_conv1(x)) 52 | out = self.relu(self.refine_conv2(out)) 53 | out = self.relu(self.refine_conv3(out)) 54 | raw_refine = self.refine_pred(out) 55 | pred_refine = torch.sigmoid(raw_alpha + raw_refine) 56 | return pred_refine 57 | -------------------------------------------------------------------------------- /mmedit/models/components/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import (StyleGAN2Discriminator, 3 | StyleGANv2Generator) 4 | from .generator_discriminatorxl import SynthesisNetworkv2 5 | __all__ = ['StyleGANv2Generator', 'StyleGAN2Discriminator','SynthesisNetworkv2'] 6 | -------------------------------------------------------------------------------- /mmedit/models/extractors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .feedback_hour_glass import (FeedbackHourglass, Hourglass, 3 | reduce_to_five_heatmaps) 4 | from .lte import LTE 5 | 6 | __all__ = ['LTE', 'Hourglass', 'FeedbackHourglass', 'reduce_to_five_heatmaps'] 7 | -------------------------------------------------------------------------------- /mmedit/models/inpaintors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .aot_inpaintor import AOTInpaintor 3 | from .deepfillv1 import DeepFillv1Inpaintor 4 | from .gl_inpaintor import GLInpaintor 5 | from .one_stage import OneStageInpaintor 6 | from .pconv_inpaintor import PConvInpaintor 7 | from .two_stage import TwoStageInpaintor 8 | 9 | __all__ = [ 10 | 'OneStageInpaintor', 'GLInpaintor', 'PConvInpaintor', 'TwoStageInpaintor', 11 | 'DeepFillv1Inpaintor', 'AOTInpaintor' 12 | ] 13 | -------------------------------------------------------------------------------- /mmedit/models/losses/RDPix_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..registry import LOSSES 9 | from mmedit.utils import yuv_444_to_420 10 | 11 | def get_mse(self, x, x_hat, yuv420): 12 | _, _, H, W = x.size() 13 | pixel_num = H * W 14 | if yuv420: 15 | org_y, org_u, org_v = yuv_444_to_420(x) 16 | rec_y, rec_u, rec_v = yuv_444_to_420(x_hat) 17 | mse_y = self.mse(reduction='none')(org_y, rec_y) 18 | mse_u = self.mse(reduction='none')(org_u, rec_u) 19 | mse_v = self.mse(reduction='none')(org_v, rec_v) 20 | mse_y = torch.sum(mse_y, dim=(1, 2, 3)) / pixel_num 21 | mse_u = torch.sum(mse_u, dim=(1, 2, 3)) / pixel_num * 4 22 | mse_v = torch.sum(mse_v, dim=(1, 2, 3)) / pixel_num * 4 23 | mse = (4 * mse_y + mse_u + mse_v) / 6 * 3 # rgb is sum, not average MSE 24 | else: 25 | mse = self.mse(x, x_hat) 26 | mse = torch.sum(mse, dim=(1, 2, 3)) / pixel_num 27 | return mse 28 | 29 | 30 | @LOSSES.register_module() 31 | class Rate_PixLoss(nn.Module): 32 | def __init__(self,loss_weight: float = 1.0,yuv420=True): 33 | super().__init__() 34 | self.loss_weight = loss_weight 35 | self.mse=nn.MSELoss(reduction='none') 36 | self.yuv420=yuv420 37 | 38 | def forward(self, pred, target, lambda_w=None, **kwargs): 39 | _, _, H, W = pred.size() 40 | pixel_num = H * W 41 | if self.yuv420: 42 | org_y, org_u, org_v = yuv_444_to_420(pred) 43 | rec_y, rec_u, rec_v = yuv_444_to_420(target) 44 | mse_y = self.mse(org_y, rec_y) 45 | mse_u = self.mse(org_u, rec_u) 46 | mse_v = self.mse(org_v, rec_v) 47 | mse_y = torch.sum(mse_y, dim=(1, 2, 3)) / pixel_num 48 | mse_u = torch.sum(mse_u, dim=(1, 2, 3)) / pixel_num * 4 49 | mse_v = torch.sum(mse_v, dim=(1, 2, 3)) / pixel_num * 4 50 | mse = (4 * mse_y + mse_u + mse_v) / 6 * 3 # rgb is sum, not average MSE 51 | else: 52 | mse = self.mse(pred, target) 53 | mse = torch.sum(mse, dim=(1, 2, 3)) / pixel_num 54 | mse_loss_lambda = torch.sum(mse*lambda_w) / len(lambda_w) 55 | return self.loss_weight*mse_loss_lambda 56 | 57 | -------------------------------------------------------------------------------- /mmedit/models/losses/RD_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
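# RDLoss (registered below) selects a rate-distortion cost by name through
# get_loss_func. Passing loss_type='total_rdc_mse', for example, evaluates
# costs = lmbdas * rd['mse'] + rd['bpp'], and the returned loss is
# sum(costs) / number of rates, scaled by loss_weight; the dict of
# costs/losses/loss is returned alongside it.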
2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..registry import LOSSES 9 | from .utils import masked_loss 10 | 11 | 12 | def get_loss_func(loss_func): 13 | # pylint: disable=possibly-unused-variable 14 | 15 | def get_final_loss(costs, rate_number): 16 | losses = costs 17 | loss = torch.sum(losses) / rate_number 18 | return { 19 | 'costs': costs, 20 | 'losses': losses, 21 | 'loss': loss, 22 | } 23 | 24 | def loss_me_mse(rd, lmbdas): 25 | costs = lmbdas * rd['me_mse'] 26 | return get_final_loss(costs, len(lmbdas)) 27 | 28 | def loss_me_rdc_mse(rd, lmbdas): 29 | costs = lmbdas * rd['me_mse'] + rd['bpp_mv_y'] + rd['bpp_mv_z'] 30 | return get_final_loss(costs, len(lmbdas)) 31 | 32 | def loss_recon_mse(rd, lmbdas): 33 | costs = lmbdas * rd['mse'] 34 | return get_final_loss(costs, len(lmbdas)) 35 | 36 | def loss_recon_rdc_mse(rd, lmbdas): 37 | costs = lmbdas * rd['mse'] + rd['bpp_y'] + rd['bpp_z'] 38 | return get_final_loss(costs, len(lmbdas)) 39 | 40 | def loss_total_rdc_mse(rd, lmbdas): 41 | costs = lmbdas * rd['mse'] + rd['bpp'] 42 | return get_final_loss(costs, len(lmbdas)) 43 | 44 | def loss_total_rdc_ms_ssim(rd, lmbdas): 45 | costs = lmbdas / 17 * (rd['1-ssim']) + rd['bpp'] 46 | return get_final_loss(costs, len(lmbdas)) 47 | 48 | loss_func_name = f'loss_{loss_func}' 49 | assert loss_func_name in locals() 50 | return locals()[loss_func_name] 51 | 52 | 53 | @LOSSES.register_module() 54 | class RDLoss(nn.Module): 55 | def __init__(self,loss_weight: float = 1.0): 56 | super().__init__() 57 | self.loss_weight = loss_weight 58 | 59 | def forward(self, rd, lmbdas, loss_type): 60 | loss_func = get_loss_func(loss_type) 61 | loss_dict=loss_func(rd, lmbdas) 62 | return self.loss_weight * loss_dict["loss"], loss_dict 63 | -------------------------------------------------------------------------------- /mmedit/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .composition_loss import (CharbonnierCompLoss, L1CompositionLoss, 3 | MSECompositionLoss) 4 | from .feature_loss import LightCNNFeatureLoss 5 | from .gan_loss import DiscShiftLoss, GANLoss, GaussianBlur, GradientPenaltyLoss 6 | from .gradient_loss import GradientLoss 7 | from .perceptual_loss import (PerceptualLoss, PerceptualVGG, 8 | TransferalPerceptualLoss) 9 | from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss 10 | from .utils import mask_reduce_loss, reduce_loss 11 | from .RD_loss import RDLoss 12 | from .RDPix_loss import Rate_PixLoss 13 | from .moco_loss import MocoLoss 14 | from .vq_loss import LPIPSLoss,Codebook_SemanticLoss 15 | from .lambdaD_loss import LambdaDLoss 16 | __all__ = [ 17 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', 18 | 'MSECompositionLoss', 'CharbonnierCompLoss', 'GANLoss', 'GaussianBlur', 19 | 'GradientPenaltyLoss', 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss', 20 | 'mask_reduce_loss', 'DiscShiftLoss', 'MaskedTVLoss', 'GradientLoss', 21 | 'TransferalPerceptualLoss', 'LightCNNFeatureLoss','RDLoss','Rate_PixLoss', 22 | 'MocoLoss','LPIPSLoss','Codebook_SemanticLoss','LambdaDLoss' 23 | ] 24 | -------------------------------------------------------------------------------- /mmedit/models/losses/lambdaD_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
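# LambdaDLoss mirrors RD_loss.py but keeps distortion and rate separate: the
# cost function chosen by loss_type returns a (distortion, bpp) pair, e.g.
# 'total_rdc_mse' gives (rd['mse'], rd['bpp']), and forward() returns
# loss_weight * distortion together with the bpp term unchanged, leaving the
# weighting between the two to the caller.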
2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..registry import LOSSES 9 | from .utils import masked_loss 10 | 11 | 12 | def get_loss_func(loss_func): 13 | # pylint: disable=possibly-unused-variable 14 | 15 | def loss_me_mse(rd): 16 | return rd['me_mse'],0 17 | 18 | def loss_me_rdc_mse(rd): 19 | return rd['me_mse'],(rd['bpp_mv_y'] + rd['bpp_mv_z']) 20 | 21 | 22 | def loss_recon_mse(rd): 23 | return rd['mse'],0 24 | 25 | def loss_recon_rdc_mse(rd): 26 | return rd['mse'],(rd['bpp_y'] + rd['bpp_z']) 27 | 28 | 29 | def loss_total_rdc_mse(rd): 30 | return rd['mse'], rd['bpp'] 31 | 32 | 33 | def loss_total_rdc_ms_ssim(rd): 34 | return 17 * (rd['1-ssim']), rd['bpp'] 35 | 36 | loss_func_name = f'loss_{loss_func}' 37 | assert loss_func_name in locals() 38 | return locals()[loss_func_name] 39 | 40 | 41 | @LOSSES.register_module() 42 | class LambdaDLoss(nn.Module): 43 | def __init__(self,loss_weight: float = 1.0): 44 | super().__init__() 45 | self.loss_weight = loss_weight 46 | 47 | def forward(self, rd, loss_type): 48 | loss_func = get_loss_func(loss_type) 49 | distortion_loss, bpp_loss =loss_func(rd) 50 | return self.loss_weight * distortion_loss, bpp_loss 51 | -------------------------------------------------------------------------------- /mmedit/models/losses/vq_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # import pyiqa 6 | from ..registry import LOSSES 7 | from .utils import masked_loss 8 | from mmedit.utils import ycbcr2rgb 9 | _reduction_modes = ['none', 'mean', 'sum'] 10 | 11 | @LOSSES.register_module() 12 | class LPIPSLoss(nn.Module): 13 | def __init__(self, loss_weight = 1.0, yuv_input=False): 14 | super(LPIPSLoss, self).__init__() 15 | self.model = pyiqa.create_metric('lpips-vgg', as_loss=True) 16 | self.loss_weight = loss_weight 17 | self.yuv_input=yuv_input 18 | 19 | 20 | def forward(self, x, gt): 21 | if self.yuv_input: 22 | x=ycbcr2rgb(x) 23 | gt=ycbcr2rgb(gt) 24 | 25 | return self.model(x, gt) * self.loss_weight, None 26 | 27 | 28 | 29 | @LOSSES.register_module() 30 | class Codebook_SemanticLoss(nn.Module): 31 | def __init__(self, loss_weight = 1.0): 32 | super(Codebook_SemanticLoss, self).__init__() 33 | self.loss_weight = loss_weight 34 | 35 | def forward(self, l_codebook): 36 | return l_codebook * self.loss_weight 37 | 38 | -------------------------------------------------------------------------------- /mmedit/models/mattors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_mattor import BaseMattor 3 | from .dim import DIM 4 | from .gca import GCA 5 | from .indexnet import IndexNet 6 | from .utils import get_unknown_tensor 7 | 8 | __all__ = ['BaseMattor', 'DIM', 'IndexNet', 'GCA', 'get_unknown_tensor'] 9 | -------------------------------------------------------------------------------- /mmedit/models/mattors/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def get_unknown_tensor(trimap, meta): 6 | """Get 1-channel unknown area tensor from the 3 or 1-channel trimap tensor. 7 | 8 | Args: 9 | trimap (Tensor): Tensor with shape (N, 3, H, W) or (N, 1, H, W). 
10 | 11 | Returns: 12 | Tensor: Unknown area mask of shape (N, 1, H, W). 13 | """ 14 | if trimap.shape[1] == 3: 15 | # The three channels correspond to (bg mask, unknown mask, fg mask) 16 | # respectively. 17 | weight = trimap[:, 1:2, :, :].float() 18 | elif 'to_onehot' in meta[0]: 19 | # key 'to_onehot' is added by pipeline `FormatTrimap` 20 | # 0 for bg, 1 for unknown, 2 for fg 21 | weight = trimap.eq(1).float() 22 | else: 23 | # trimap is simply processed by pipeline `RescaleToZeroOne` 24 | # 0 for bg, 128/255 for unknown, 1 for fg 25 | weight = trimap.eq(128 / 255).float() 26 | return weight 27 | 28 | 29 | def fba_fusion(alpha, img, F, B): 30 | """Postprocess the predicted. 31 | 32 | This class is adopted from 33 | https://github.com/MarcoForte/FBA_Matting. 34 | 35 | Args: 36 | alpha (Tensor): Tensor with shape (N, 1, H, W). 37 | img (Tensor): Tensor with shape (N, 3, H, W). 38 | F (Tensor): Tensor with shape (N, 3, H, W). 39 | B (Tensor): Tensor with shape (N, 3, H, W). 40 | 41 | Returns: 42 | alpha (Tensor): Tensor with shape (N, 1, H, W). 43 | F (Tensor): Tensor with shape (N, 3, H, W). 44 | B (Tensor): Tensor with shape (N, 3, H, W). 45 | """ 46 | F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B)) 47 | B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * 48 | (1 - alpha) * F) 49 | 50 | F = torch.clamp(F, 0, 1) 51 | B = torch.clamp(B, 0, 1) 52 | la = 0.1 53 | alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / ( 54 | torch.sum((F - B) * (F - B), 1, keepdim=True) + la) 55 | alpha = torch.clamp(alpha, 0, 1) 56 | return alpha, F, B 57 | -------------------------------------------------------------------------------- /mmedit/models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import MODELS as MMCV_MODELS 3 | from mmcv.utils import Registry 4 | 5 | MODELS = Registry('model', parent=MMCV_MODELS) 6 | BACKBONES = MODELS 7 | COMPONENTS = MODELS 8 | LOSSES = MODELS 9 | -------------------------------------------------------------------------------- /mmedit/models/restorers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .basic_restorer import BasicRestorer 3 | from .basicvsr import BasicVSR 4 | from .dic import DIC 5 | from .edvr import EDVR 6 | from .esrgan import ESRGAN 7 | from .glean import GLEAN 8 | from .liif import LIIF 9 | from .real_basicvsr import RealBasicVSR 10 | from .real_esrgan import RealESRGAN 11 | from .srgan import SRGAN 12 | from .tdan import TDAN 13 | from .ttsr import TTSR 14 | from .dmci import DMCI 15 | 16 | __all__ = [ 17 | 'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR', 18 | 'GLEAN', 'TDAN', 'DIC', 'RealESRGAN', 'RealBasicVSR','DMCI' 19 | ] 20 | -------------------------------------------------------------------------------- /mmedit/models/restorers/glean.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numbers 3 | import os.path as osp 4 | 5 | import mmcv 6 | 7 | from mmedit.core import tensor2img 8 | from ..registry import MODELS 9 | from .srgan import SRGAN 10 | 11 | 12 | @MODELS.register_module() 13 | class GLEAN(SRGAN): 14 | """GLEAN model for single image super-resolution. 
15 | 16 | This model is identical to SRGAN except that the output images are 17 | transformed from [-1, 1] to [0, 1]. 18 | 19 | Paper: 20 | GLEAN: Generative Latent Bank for Large-Factor Image Super-Resolution. 21 | CVPR, 2021. 22 | 23 | """ 24 | 25 | def init_weights(self, pretrained=None): 26 | """Init weights for models. 27 | 28 | Args: 29 | pretrained (str, optional): Path for pretrained weights. If given 30 | None, pretrained weights will not be loaded. Defaults to None. 31 | """ 32 | self.generator.init_weights(pretrained=pretrained) 33 | 34 | def forward_test(self, 35 | lq, 36 | gt=None, 37 | meta=None, 38 | save_image=False, 39 | save_path=None, 40 | iteration=None): 41 | """Testing forward function. 42 | 43 | Args: 44 | lq (Tensor): LQ Tensor with shape (n, c, h, w). 45 | gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None. 46 | save_image (bool): Whether to save image. Default: False. 47 | save_path (str): Path to save image. Default: None. 48 | iteration (int): Iteration for the saving image name. 49 | Default: None. 50 | 51 | Returns: 52 | dict: Output results. 53 | """ 54 | output = self.generator(lq) 55 | 56 | # normalize from [-1, 1] to [0, 1] 57 | output = (output + 1) / 2.0 58 | 59 | if self.test_cfg is not None and self.test_cfg.get('metrics', None): 60 | assert gt is not None, ( 61 | 'evaluation with metrics must have gt images.') 62 | gt = (gt + 1) / 2.0 # normalize from [-1, 1] to [0, 1] 63 | results = dict(eval_result=self.evaluate(output, gt)) 64 | else: 65 | results = dict(lq=lq.cpu(), output=output.cpu()) 66 | if gt is not None: 67 | results['gt'] = gt.cpu() 68 | 69 | # save image 70 | if save_image: 71 | lq_path = meta[0]['lq_path'] 72 | folder_name = osp.splitext(osp.basename(lq_path))[0] 73 | if isinstance(iteration, numbers.Number): 74 | save_path = osp.join(save_path, folder_name, 75 | f'{folder_name}-{iteration + 1:06d}.png') 76 | elif iteration is None: 77 | save_path = osp.join(save_path, f'{folder_name}.png') 78 | else: 79 | raise ValueError('iteration should be number or None, ' 80 | f'but got {type(iteration)}') 81 | mmcv.imwrite(tensor2img(output), save_path) 82 | 83 | return results 84 | -------------------------------------------------------------------------------- /mmedit/models/synthesizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .cycle_gan import CycleGAN 3 | from .pix2pix import Pix2Pix 4 | 5 | __all__ = ['Pix2Pix', 'CycleGAN'] 6 | -------------------------------------------------------------------------------- /mmedit/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .search_transformer import SearchTransformer 3 | 4 | __all__ = ['SearchTransformer'] 5 | -------------------------------------------------------------------------------- /mmedit/models/video_interpolators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .basic_interpolator import BasicInterpolator 3 | from .cain import CAIN 4 | 5 | __all__ = ['BasicInterpolator', 'CAIN'] 6 | -------------------------------------------------------------------------------- /mmedit/models/vqgan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .femasr_model import FeMaSRModel_HR 3 | 4 | __all__ = ['FeMaSRModel_HR'] 5 | -------------------------------------------------------------------------------- /mmedit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .cli import modify_args 3 | from .logger import get_root_logger 4 | from .setup_env import setup_multi_processes 5 | from .stream_helper import encode_i, decode_i, get_downsampled_shape, filesize, \ 6 | get_state_dict, get_padding_size 7 | from .common import avg_per_rate, generate_str, str2bool,get_training_lambdas,get_clip_grad_norm_func 8 | from .functional import rgb2ycbcr,yuv_444_to_420,ycbcr2rgb 9 | 10 | __all__ = ['get_root_logger', 'setup_multi_processes', 'modify_args','str2bool', 11 | 'encode_i', 'decode_i', 'get_downsampled_shape', 'filesize', 'get_state_dict', 'get_clip_grad_norm_func', 12 | 'avg_per_rate', 'generate_str', 'rgb2ycbcr','get_padding_size','yuv_444_to_420','ycbcr2rgb','get_training_lambdas' 13 | ] 14 | -------------------------------------------------------------------------------- /mmedit/utils/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import re 3 | import sys 4 | import warnings 5 | 6 | 7 | def modify_args(): 8 | for i, v in enumerate(sys.argv): 9 | if i == 0: 10 | assert v.endswith('.py') 11 | elif re.match(r'--\w+_.*', v): 12 | new_arg = v.replace('_', '-') 13 | warnings.warn( 14 | f'command line argument {v} is deprecated, ' 15 | f'please use {new_arg} instead.', 16 | category=DeprecationWarning, 17 | ) 18 | sys.argv[i] = new_arg 19 | -------------------------------------------------------------------------------- /mmedit/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmedit 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMEditing'] = f'{mmedit.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /mmedit/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmedit". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 
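# parse_version_info (defined below) splits the version string into a tuple,
# keeping any release-candidate suffix as a string, e.g.
#   parse_version_info('0.14.0')    -> (0, 14, 0)
#   parse_version_info('0.14.0rc1') -> (0, 14, 0, 'rc1')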
24 | """ 25 | # root logger name: mmedit 26 | logger = get_logger(__name__.split('.')[0], log_file, log_level) 27 | return logger 28 | -------------------------------------------------------------------------------- /mmedit/utils/setup_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | import warnings 5 | 6 | import cv2 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def setup_multi_processes(cfg): 11 | """Setup multi-processing environment variables.""" 12 | # set multi-process start method as `fork` to speed up the training 13 | if platform.system() != 'Windows': 14 | mp_start_method = cfg.get('mp_start_method', 'fork') 15 | current_method = mp.get_start_method(allow_none=True) 16 | if current_method is not None and current_method != mp_start_method: 17 | warnings.warn( 18 | f'Multi-processing start method `{mp_start_method}` is ' 19 | f'different from the previous setting `{current_method}`.' 20 | f'It will be force set to `{mp_start_method}`. You can change ' 21 | f'this behavior by changing `mp_start_method` in your config.') 22 | mp.set_start_method(mp_start_method, force=True) 23 | 24 | # disable opencv multithreading to avoid system being overloaded 25 | opencv_num_threads = cfg.get('opencv_num_threads', 0) 26 | cv2.setNumThreads(opencv_num_threads) 27 | 28 | # setup OMP threads 29 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa 30 | if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: 31 | omp_num_threads = 1 32 | warnings.warn( 33 | f'Setting OMP_NUM_THREADS environment variable for each process ' 34 | f'to be {omp_num_threads} in default, to avoid your system being ' 35 | f'overloaded, please further tune the variable for optimal ' 36 | f'performance in your application as needed.') 37 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) 38 | 39 | # setup MKL threads 40 | if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: 41 | mkl_num_threads = 1 42 | warnings.warn( 43 | f'Setting MKL_NUM_THREADS environment variable for each process ' 44 | f'to be {mkl_num_threads} in default, to avoid your system being ' 45 | f'overloaded, please further tune the variable for optimal ' 46 | f'performance in your application as needed.') 47 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) 48 | -------------------------------------------------------------------------------- /mmedit/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | 3 | __version__ = '0.14.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | ver_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | ver_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | ver_info.append(int(patch_version[0])) 14 | ver_info.append(f'rc{patch_version[1]}') 15 | return tuple(ver_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /scripts/test_script_clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | 5 | 6 | 7 | ./tools/dist_test.sh configs/noise/all_in_one_L.py ckpt/all_in_one_L_noise.pth 1 \ 8 | " --lq_folder data/all-in-one-test --gt_folder data/all-in-one-test \ 9 | --rate_num 8 --cal_msssim True --noise_type 0 \ 10 | --save-path ./results/all_in_one_L/Kodak_24_sgm0 \ 11 | --json_path ./results/all_in_one_L/Kodak_24_sgm0.json " 12 | 13 | 14 | 15 | 16 | ./tools/dist_test.sh configs/noise/all_in_one_S.py ckpt/all_in_one_S_noise.pth 1 \ 17 | " --lq_folder data/all-in-one-test --gt_folder data/all-in-one-test \ 18 | --rate_num 8 --cal_msssim True --noise_type 0 \ 19 | --save-path ./results/all_in_one_S/Kodak_24_sgm0 \ 20 | --json_path ./results/all_in_one_S/Kodak_24_sgm0.json " 21 | 22 | -------------------------------------------------------------------------------- /scripts/test_script_noise.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # ./tools/dist_test.sh configs/noise/all_in_one_L.py ckpt/all_in_one_L_noise.pth 1 \ 5 | # " --lq_folder data/all-in-one-test --gt_folder data/all-in-one-test \ 6 | # --rate_num 8 --cal_msssim True --noise_type 15 \ 7 | # --save-path ./results/all_in_one_L/Kodak_24_sgm15 \ 8 | # --json_path ./results/all_in_one_L/Kodak_24_sgm15.json " 9 | 10 | 11 | 12 | 13 | # ./tools/dist_test.sh configs/noise/all_in_one_S.py ckpt/all_in_one_S_noise.pth 1 \ 14 | # " --lq_folder data/all-in-one-test --gt_folder data/all-in-one-test \ 15 | # --rate_num 8 --cal_msssim True --noise_type 15 \ 16 | # --save-path ./results/all_in_one_S/Kodak_24_sgm15 \ 17 | # --json_path ./results/all_in_one_S/Kodak_24_sgm15.json " 18 | 19 | 20 | pip install pytorch_msssim scipy mmengine timm==0.6.7 21 | cd /data/zenghuimin/code/codec-dev-release/ 22 | 23 | ./tools/dist_test.sh configs/noise/all_in_one_S.py all_in_one_S_noise.pth 1 \ 24 | " --lq_folder data/all-in-one-test --gt_folder data/all-in-one-test \ 25 | --rate_num 8 --cal_msssim True --noise_type 15 \ 26 | --save-path ./results/all_in_one_S/Kodak_24_sgm15 \ 27 | --json_path ./results/all_in_one_S/Kodak_24_sgm15.json " 28 | 29 | -------------------------------------------------------------------------------- /scripts/test_script_weather.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ./tools/dist_test.sh configs/weather/all_in_one_L.py \ 4 | ./ckpt/all_in_one_L_weather.pth 1 \ 5 | " --specified_key SOTS_outdoor --gt_folder data/all-in-one-test \ 6 | --save-path ./results/all_in_one_L/SOTS_outdoor \ 7 | --json_path ./results/all_in_one_L/SOTS_outdoor.json \ 8 | --rate_num 8 --cal_msssim True " 9 | 10 | 11 | 12 | 13 | ./tools/dist_test.sh configs/weather/all_in_one_S.py \ 14 | ./ckpt/all_in_one_S_weather.pth 1 \ 15 | " --specified_key SOTS_outdoor --gt_folder data/all-in-one-test \ 16 | --save-path ./results/all_in_one_S/SOTS_outdoor \ 17 | 
--json_path ./results/all_in_one_S/SOTS_outdoor.json \ 18 | --rate_num 8 --cal_msssim True " 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ./tools/dist_train_iter.sh configs/weather/all_in_one_L.py 4 '--train_dataset data/all-in-one-train --test_dataset data/all-in-one-test --exp_name all_in_one_L_weather ' 4 | 5 | ./tools/dist_train_iter.sh configs/weather/all_in_one_S.py 4 '--train_dataset data/all-in-one-train --test_dataset data/all-in-one-test --exp_name all_in_one_S_weather ' 6 | 7 | 8 | ./tools/dist_train_iter.sh configs/noise/all_in_one_L.py 4 '--train_dataset data/all-in-one-train --test_dataset data/all-in-one-test --exp_name all_in_one_L_noise ' 9 | 10 | ./tools/dist_train_iter.sh configs/noise/all_in_one_S.py 4 '--train_dataset data/all-in-one-train --test_dataset data/all-in-one-test --exp_name all_in_one_S_noise ' 11 | 12 | -------------------------------------------------------------------------------- /tools/data/generation/README.md: -------------------------------------------------------------------------------- 1 | # Generation Datasets 2 | 3 | It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files. 4 | 5 | MMEditing supported generation datasets: 6 | 7 | - [Paired Dataset for Pix2pix](paired-pix2pix/README.md) \[ [Homepage](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/) \] 8 | - [Unpaired Dataset for CycleGAN](unpaired-cyclegan/README.md) \[ [Homepage](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) \] 9 | -------------------------------------------------------------------------------- /tools/data/generation/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 图像生成数据集 2 | 3 | 建议将数据集软链接到 `$MMEDITING/data` 。如果您的文件夹结构不同,您可能需要更改配置文件中的相应路径。 4 | 5 | MMEditing 支持的生成数据集: 6 | 7 | * [Pix2Pix 的配对数据集](paired-pix2pix/README.md) \[ [主页](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/) \] 8 | * [CycleGAN 的未配对数据集](unpaired-cyclegan/README.md) \[ [主页](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) \] 9 | -------------------------------------------------------------------------------- /tools/data/generation/paired-pix2pix/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Paired Dataset for Pix2pix 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{isola2017image, 7 | title={Image-to-image translation with conditional adversarial networks}, 8 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 10 | pages={1125--1134}, 11 | year={2017} 12 | } 13 | ``` 14 | 15 | You can download paired datasets from [here](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/). 16 | Then, you need to unzip and move corresponding datasets to follow the folder structure shown above. The datasets have been well-prepared by the original authors. 
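As a minimal sketch, fetching and unpacking one paired set could look like the commands below; the archive name `facades.tar.gz` is an assumption about how the files are named on the download page, so substitute the dataset you actually need.

```bash
# Fetch one paired dataset and unpack it so that it ends up under
# data/paired/<dataset-name>.
mkdir -p data/paired
wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
tar -zxvf facades.tar.gz -C data/paired
```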
17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── paired 25 | │ │ ├── facades 26 | │ │ ├── maps 27 | | | ├── edges2shoes 28 | | | | ├── train 29 | | | | ├── test 30 | ``` 31 | -------------------------------------------------------------------------------- /tools/data/generation/paired-pix2pix/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 为 Pix2pix 准备配对数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{isola2017image, 7 | title={Image-to-image translation with conditional adversarial networks}, 8 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 10 | pages={1125--1134}, 11 | year={2017} 12 | } 13 | ``` 14 | 15 | 您可以从[此处](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/)下载配对数据集。然后,您需要解压缩并移动相应的数据集以遵循如下所示的文件夹结构。数据集已经由原作者准备好了。 16 | 17 | ```text 18 | mmediting 19 | ├── mmedit 20 | ├── tools 21 | ├── configs 22 | ├── data 23 | │ ├── paired 24 | │ │ ├── facades 25 | │ │ ├── maps 26 | | | ├── edges2shoes 27 | | | | ├── train 28 | | | | ├── test 29 | ``` 30 | -------------------------------------------------------------------------------- /tools/data/generation/unpaired-cyclegan/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Unpaired Dataset for CycleGAN 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{zhu2017unpaired, 7 | title={Unpaired image-to-image translation using cycle-consistent adversarial networks}, 8 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE international conference on computer vision}, 10 | pages={2223--2232}, 11 | year={2017} 12 | } 13 | ``` 14 | 15 | You can download unpaired datasets from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/). 16 | Then, you need to unzip and move corresponding datasets to follow the folder structure shown above. The datasets have been well-prepared by the original authors. 
17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── unpaired 25 | │ │ ├── facades 26 | | | ├── horse2zebra 27 | | | ├── summer2winter_yosemite 28 | | | | ├── trainA 29 | | | | ├── trainB 30 | | | | ├── testA 31 | | | | ├── testB 32 | ``` 33 | -------------------------------------------------------------------------------- /tools/data/generation/unpaired-cyclegan/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 为 CycleGAN 准备未配对数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{zhu2017unpaired, 7 | title={Unpaired image-to-image translation using cycle-consistent adversarial networks}, 8 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE international conference on computer vision}, 10 | pages={2223--2232}, 11 | year={2017} 12 | } 13 | ``` 14 | 15 | 您可以从[此处](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)下载未配对的数据集。然后,您需要解压缩并移动相应的数据集以遵循如上所示的文件夹结构。数据集已经由原作者准备好了。 16 | 17 | ```text 18 | mmediting 19 | ├── mmedit 20 | ├── tools 21 | ├── configs 22 | ├── data 23 | │ ├── unpaired 24 | │ │ ├── facades 25 | | | ├── horse2zebra 26 | | | ├── summer2winter_yosemite 27 | | | | ├── trainA 28 | | | | ├── trainB 29 | | | | ├── testA 30 | | | | ├── testB 31 | ``` 32 | -------------------------------------------------------------------------------- /tools/data/inpainting/README.md: -------------------------------------------------------------------------------- 1 | # Inpainting Datasets 2 | 3 | It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files. 4 | 5 | MMEditing supported inpainting datasets: 6 | 7 | - [Paris Street View](paris-street-view/README.md) \[ [Homepage](https://github.com/pathak22/context-encoder/issues/24) \] 8 | - [CelebA-HQ](celeba-hq/README.md) \[ [Homepage](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training) \] 9 | - [Places365](places365/README.md) \[ [Homepage](http://places2.csail.mit.edu/) \] 10 | 11 | As we only need images for inpainting task, further preparation is not necessary and the folder structure can be different from the example. You can utilize the information provided by the original dataset like `Place365` (e.g. `meta`). Also, you can easily scan the data set and list all of the images to a specific `txt` file. Here is an example for the `Places365_val.txt` from Places365 and we will only use the image name information in inpainting. 
12 | 13 | ``` 14 | Places365_val_00000001.jpg 165 15 | Places365_val_00000002.jpg 358 16 | Places365_val_00000003.jpg 93 17 | Places365_val_00000004.jpg 164 18 | Places365_val_00000005.jpg 289 19 | Places365_val_00000006.jpg 106 20 | Places365_val_00000007.jpg 81 21 | Places365_val_00000008.jpg 121 22 | Places365_val_00000009.jpg 150 23 | Places365_val_00000010.jpg 302 24 | Places365_val_00000011.jpg 42 25 | ``` 26 | -------------------------------------------------------------------------------- /tools/data/inpainting/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 图像补全数据集 2 | 3 | 建议将数据集软链接到 `$MMEDITING/data` 。如果您的文件夹结构不同,您可能需要更改配置文件中的相应路径。 4 | 5 | MMEditing 支持的补全数据集: 6 | 7 | * [Paris Street View](paris-street-view/README.md) \[ [主页](https://github.com/pathak22/context-encoder/issues/24) \] 8 | * [CelebA-HQ](celeba-hq/README.md) \[ [主页](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training) \] 9 | * [Places365](places365/README.md) \[ [主页](http://places2.csail.mit.edu/) \] 10 | 11 | 由于在图像补全任务中,我们只需要使用图像,因此我们不需要对数据集进行额外的预处理操作,文件目录的结构也可以和本例有所不同。您可以利用原始数据集提供的信息,如 `Place365` (例如 `meta`)。或者,您可以直接遍历数据集文件夹,并将所有图像文件的路径罗列在一个文本文件中。下面的例子节选自 Places365 数据集中的 `Places365_val.txt`,针对图像补全任务,我们只需要使用其中的文件名信息。 12 | 13 | ``` 14 | Places365_val_00000001.jpg 165 15 | Places365_val_00000002.jpg 358 16 | Places365_val_00000003.jpg 93 17 | Places365_val_00000004.jpg 164 18 | Places365_val_00000005.jpg 289 19 | Places365_val_00000006.jpg 106 20 | Places365_val_00000007.jpg 81 21 | Places365_val_00000008.jpg 121 22 | Places365_val_00000009.jpg 150 23 | Places365_val_00000010.jpg 302 24 | Places365_val_00000011.jpg 42 25 | ``` 26 | -------------------------------------------------------------------------------- /tools/data/inpainting/celeba-hq/README.md: -------------------------------------------------------------------------------- 1 | # Preparing CelebA-HQ Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @article{karras2017progressive, 7 | title={Progressive growing of gans for improved quality, stability, and variation}, 8 | author={Karras, Tero and Aila, Timo and Laine, Samuli and Lehtinen, Jaakko}, 9 | journal={arXiv preprint arXiv:1710.10196}, 10 | year={2017} 11 | } 12 | ``` 13 | 14 | Follow the instructions [here](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training) to prepare the dataset. 
15 | 16 | ```text 17 | mmediting 18 | ├── mmedit 19 | ├── tools 20 | ├── configs 21 | ├── data 22 | │ ├── celeba-hq 23 | │ │ ├── train 24 | | | ├── val 25 | 26 | ``` 27 | -------------------------------------------------------------------------------- /tools/data/inpainting/celeba-hq/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 CelebA-HQ 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @article{karras2017progressive, 7 | title={Progressive growing of gans for improved quality, stability, and variation}, 8 | author={Karras, Tero and Aila, Timo and Laine, Samuli and Lehtinen, Jaakko}, 9 | journal={arXiv preprint arXiv:1710.10196}, 10 | year={2017} 11 | } 12 | ``` 13 | 14 | 请按照[此处](https://github.com/tkarras/progressive_growing_of_gans#preparing-datasets-for-training)准备数据集。 15 | 16 | ```text 17 | mmediting 18 | ├── mmedit 19 | ├── tools 20 | ├── configs 21 | ├── data 22 | │ ├── celeba-hq 23 | │ │ ├── train 24 | | | ├── val 25 | 26 | ``` 27 | -------------------------------------------------------------------------------- /tools/data/inpainting/paris-street-view/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Paris Street View Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{pathak2016context, 7 | title={Context encoders: Feature learning by inpainting}, 8 | author={Pathak, Deepak and Krahenbuhl, Philipp and Donahue, Jeff and Darrell, Trevor and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 10 | pages={2536--2544}, 11 | year={2016} 12 | } 13 | ``` 14 | 15 | Obtain the dataset [here](https://github.com/pathak22/context-encoder/issues/24). 16 | 17 | ```text 18 | mmediting 19 | ├── mmedit 20 | ├── tools 21 | ├── configs 22 | ├── data 23 | │ ├── paris_street_view 24 | │ │ ├── train 25 | | | ├── val 26 | 27 | ``` 28 | -------------------------------------------------------------------------------- /tools/data/inpainting/paris-street-view/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 Paris Street View 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{pathak2016context, 7 | title={Context encoders: Feature learning by inpainting}, 8 | author={Pathak, Deepak and Krahenbuhl, Philipp and Donahue, Jeff and Darrell, Trevor and Efros, Alexei A}, 9 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 10 | pages={2536--2544}, 11 | year={2016} 12 | } 13 | ``` 14 | 15 | 请从[此处](https://github.com/pathak22/context-encoder/issues/24)获取数据集。 16 | 17 | ```text 18 | mmediting 19 | ├── mmedit 20 | ├── tools 21 | ├── configs 22 | ├── data 23 | │ ├── paris_street_view 24 | │ │ ├── train 25 | | | ├── val 26 | 27 | ``` 28 | -------------------------------------------------------------------------------- /tools/data/inpainting/places365/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Places365 Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @article{zhou2017places, 7 | title={Places: A 10 million Image Database for Scene Recognition}, 8 | author={Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio}, 9 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 10 | year={2017}, 11 | publisher={IEEE} 12 | } 13 | 14 | ``` 15 | 16 | Prepare the data from [Places365](http://places2.csail.mit.edu/download.html). 
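As a rough sketch (every `/path/to/...` below is a placeholder for wherever you extracted the official archives; only the target layout shown next is assumed by the configs):

```shell
mkdir -p data/places/meta
# link the extracted image sets and copy the official meta lists into place
ln -s /path/to/places365/train_images data/places/train_set
ln -s /path/to/places365/val_images data/places/test_set
cp /path/to/places365/Places365_train.txt data/places/meta/
cp /path/to/places365/Places365_val.txt data/places/meta/
```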
17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── places 25 | │ │ ├── test_set 26 | │ │ ├── train_set 27 | | | ├── meta 28 | | | | ├── Places365_train.txt 29 | | | | ├── Places365_val.txt 30 | ``` 31 | -------------------------------------------------------------------------------- /tools/data/inpainting/places365/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 Places365 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @article{zhou2017places, 7 | title={Places: A 10 million Image Database for Scene Recognition}, 8 | author={Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio}, 9 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 10 | year={2017}, 11 | publisher={IEEE} 12 | } 13 | 14 | ``` 15 | 16 | 请从 [Places365](http://places2.csail.mit.edu/download.html) 下载并准备数据。 17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── places 25 | │ │ ├── test_set 26 | │ │ ├── train_set 27 | | | ├── meta 28 | | | | ├── Places365_train.txt 29 | | | | ├── Places365_val.txt 30 | ``` 31 | -------------------------------------------------------------------------------- /tools/data/matting/README.md: -------------------------------------------------------------------------------- 1 | # Matting Datasets 2 | 3 | It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files. 4 | 5 | MMEditing supported matting datasets: 6 | 7 | - [Composition-1k](comp1k/README.md) \[ [Homepage](https://sites.google.com/view/deepimagematting) \] 8 | -------------------------------------------------------------------------------- /tools/data/matting/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 抠图数据集 2 | 3 | 建议将数据集软链接到 `$MMEDITING/data` 。如果您的文件夹结构不同,您可能需要更改配置文件中的相应路径。 4 | 5 | MMEditing 支持的抠图数据集: 6 | 7 | * [Composition-1k](comp1k/README.md) \[ [Homepage](https://sites.google.com/view/deepimagematting) \] 8 | -------------------------------------------------------------------------------- /tools/data/matting/comp1k/check_extended_fg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) OpenMMLab. All rights reserved. 
3 | # This script checks the alpha-foreground difference between 4 | # the extended fg and the original fg 5 | 6 | import glob 7 | import os 8 | import os.path as osp 9 | 10 | import cv2 11 | import numpy as np 12 | 13 | folder = 'data/adobe_composition-1k/Training_set/Adobe-licensed images' 14 | folder = osp.join(*folder.split('/')) 15 | imgs = [ 16 | os.path.splitext(os.path.basename(x))[0] 17 | for x in glob.glob(osp.join(folder, 'fg', '*.jpg')) 18 | ] 19 | 20 | print('max,avg,img') 21 | for name in imgs: 22 | alpha = cv2.imread( 23 | osp.join(folder, 'alpha', f'{name}.jpg'), cv2.IMREAD_GRAYSCALE).astype( 24 | np.float32)[..., None] / 255 25 | fg = cv2.imread(osp.join(folder, 'fg', f'{name}.jpg')).astype(np.float32) 26 | xt = cv2.imread(osp.join(folder, 'fg_extended', 27 | f'{name}.jpg')).astype(np.float32) 28 | diff = np.abs((fg - xt) * alpha) 29 | print(f'{diff.max()},{diff.mean()},"{name}"', flush=True) 30 | -------------------------------------------------------------------------------- /tools/data/matting/comp1k/filter_comp1k_anno.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | 7 | 8 | def generate_json(comp1k_json_path, target_list_path, save_json_path): 9 | data_infos = mmcv.load(comp1k_json_path) 10 | targets = mmcv.list_from_file(target_list_path) 11 | new_data_infos = [] 12 | for data_info in data_infos: 13 | for target in targets: 14 | if data_info['alpha_path'].endswith(target): 15 | new_data_infos.append(data_info) 16 | break 17 | 18 | mmcv.dump(new_data_infos, save_json_path) 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser( 23 | description='Filter composition-1k annotation file') 24 | parser.add_argument( 25 | 'comp1k_json_path', 26 | help='Path to the composition-1k dataset annotation file') 27 | parser.add_argument( 28 | 'target_list_path', 29 | help='Path to the file name list that need to filter out') 30 | parser.add_argument( 31 | 'save_json_path', help='Path to save the result json file') 32 | return parser.parse_args() 33 | 34 | 35 | def main(): 36 | args = parse_args() 37 | 38 | if not osp.exists(args.comp1k_json_path): 39 | raise FileNotFoundError(f'{args.comp1k_json_path} does not exist!') 40 | if not osp.exists(args.target_list_path): 41 | raise FileNotFoundError(f'{args.target_list_path} does not exist!') 42 | 43 | generate_json(args.comp1k_json_path, args.target_list_path, 44 | args.save_json_path) 45 | 46 | print('Done!') 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /tools/data/super-resolution/README.md: -------------------------------------------------------------------------------- 1 | # Super-Resolution Datasets 2 | 3 | It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files.
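For instance, a minimal sketch of the symlink setup run from the repository root (`/mnt/datasets/DIV2K` is a placeholder for wherever the data actually lives):

```shell
mkdir -p data
# the source path is a placeholder; point it at your local copy of the dataset
ln -s /mnt/datasets/DIV2K ./data/DIV2K
```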
4 | 5 | MMEditing supported super-resolution datasets: 6 | 7 | - Image Super-Resolution 8 | - [DIV2K](div2k/README.md) \[ [Homepage](https://data.vision.ee.ethz.ch/cvl/DIV2K/) \] 9 | - Video Super-Resolution 10 | - [REDS](reds/README.md) \[ [Homepage](https://seungjunnah.github.io/Datasets/reds.html) \] 11 | - [Vimeo90K](vimeo90k/README.md) \[ [Homepage](http://toflow.csail.mit.edu) \] 12 | -------------------------------------------------------------------------------- /tools/data/super-resolution/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 超分辨率数据集 2 | 3 | 建议将数据集的根目录链接到 `$MMEDITING/data` 下,如果您的文件目录结构不一致,那么可能需要在配置文件中修改对应的文件路径。 4 | 5 | MMEditing 支持下列超分辨率数据集: 6 | 7 | - 图像超分辨率 8 | - [DIV2K](div2k/README.md) \[ [Homepage](https://data.vision.ee.ethz.ch/cvl/DIV2K/) \] 9 | - 视频超分辨率 10 | - [REDS](reds/README.md) \[ [Homepage](https://seungjunnah.github.io/Datasets/reds.html) \] 11 | - [Vimeo90K](vimeo90k/README.md) \[ [Homepage](http://toflow.csail.mit.edu) \] 12 | -------------------------------------------------------------------------------- /tools/data/super-resolution/df2k_ost/README.md: -------------------------------------------------------------------------------- 1 | # Preparing DF2K_OST Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{wang2021real, 7 | title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data}, 8 | author={Wang, Xintao and Xie, Liangbin and Dong, Chao and Shan, Ying}, 9 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 10 | pages={1905--1914}, 11 | year={2021} 12 | } 13 | ``` 14 | 15 | - The DIV2K dataset can be downloaded from [here](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (We use the training set only). 16 | - The Flickr2K dataset can be downloaded [here](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (We use the training set only). 17 | - The OST dataset can be downloaded [here](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip) (We use the training set only). 18 | 19 | Please first put all the images into the `GT` folder (naming does not need to be in order): 20 | ```text 21 | mmediting 22 | ├── mmedit 23 | ├── tools 24 | ├── configs 25 | ├── data 26 | │ ├── df2k_ost 27 | │ │ ├── GT 28 | │ │ │ ├── 0001.png 29 | │ │ │ ├── 0002.png 30 | │ │ │ ├── ... 31 | ... 32 | ``` 33 | 34 | ## Crop sub-images 35 | 36 | For faster IO, we recommend to crop the images to sub-images. We provide such a script: 37 | 38 | ```shell 39 | python tools/data/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost 40 | ``` 41 | 42 | The generated data is stored under `df2k_ost` and the data structure is as follows, where `_sub` indicates the sub-images. 43 | 44 | ```text 45 | mmediting 46 | ├── mmedit 47 | ├── tools 48 | ├── configs 49 | ├── data 50 | │ ├── df2k_ost 51 | │ │ ├── GT 52 | │ │ ├── GT_sub 53 | ... 
54 | ``` 55 | 56 | ## Prepare LMDB dataset for DF2K_OST 57 | 58 | If you want to use LMDB datasets for faster IO speed, you can make LMDB files by: 59 | 60 | ```shell 61 | python tools/data/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost --make-lmdb 62 | ``` 63 | -------------------------------------------------------------------------------- /tools/data/super-resolution/df2k_ost/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 DF2K_OST 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @inproceedings{wang2021real, 7 | title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data}, 8 | author={Wang, Xintao and Xie, Liangbin and Dong, Chao and Shan, Ying}, 9 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 10 | pages={1905--1914}, 11 | year={2021} 12 | } 13 | ``` 14 | 15 | - DIV2K 数据集可以在 [这里](https://data.vision.ee.ethz.ch/cvl/DIV2K/) 下载 (我们只使用训练集)。 16 | - Flickr2K 数据集可以在 [这里](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) 下载 (我们只使用训练集)。 17 | - OST 数据集可以在 [这里](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip) 下载 (我们只使用训练集)。 18 | 19 | 请先将所有图片放入 `GT` 文件夹(命名不需要按顺序): 20 | ```text 21 | mmediting 22 | ├── mmedit 23 | ├── tools 24 | ├── configs 25 | ├── data 26 | │ ├── df2k_ost 27 | │ │ ├── GT 28 | │ │ │ ├── 0001.png 29 | │ │ │ ├── 0002.png 30 | │ │ │ ├── ... 31 | ... 32 | ``` 33 | 34 | ## 裁剪子图像 35 | 36 | 为了更快的 IO,我们建议将图像裁剪为子图像。 我们提供了这样一个脚本: 37 | 38 | ```shell 39 | python tools/data/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost 40 | ``` 41 | 42 | 生成的数据存放在 `df2k_ost` 下,数据结构如下,其中 `_sub` 表示子图像。 43 | 44 | ```text 45 | mmediting 46 | ├── mmedit 47 | ├── tools 48 | ├── configs 49 | ├── data 50 | │ ├── df2k_ost 51 | │ │ ├── GT 52 | │ │ ├── GT_sub 53 | ... 54 | ``` 55 | 56 | ## Prepare LMDB dataset for DF2K_OST 57 | 58 | 如果你想使用 LMDB 数据集来获得更快的 IO 速度,你可以通过以下方式制作 LMDB 文件: 59 | 60 | ```shell 61 | python tools/data/super-resolution/df2k_ost/preprocess_df2k_ost_dataset.py --data-root ./data/df2k_ost --make-lmdb 62 | ``` 63 | -------------------------------------------------------------------------------- /tools/data/super-resolution/div2k/README.md: -------------------------------------------------------------------------------- 1 | # Preparing DIV2K Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @InProceedings{Agustsson_2017_CVPR_Workshops, 7 | author = {Agustsson, Eirikur and Timofte, Radu}, 8 | title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study}, 9 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 10 | month = {July}, 11 | year = {2017} 12 | } 13 | ``` 14 | 15 | - Training dataset: [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/). 16 | - Validation dataset: Set5 and Set14. 17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── DIV2K 25 | │ │ ├── DIV2K_train_HR 26 | │ │ ├── DIV2K_train_LR_bicubic 27 | │ │ │ ├── X2 28 | │ │ │ ├── X3 29 | │ │ │ ├── X4 30 | │ │ ├── DIV2K_valid_HR 31 | │ │ ├── DIV2K_valid_LR_bicubic 32 | │ │ │ ├── X2 33 | │ │ │ ├── X3 34 | │ │ │ ├── X4 35 | │ ├── val_set5 36 | │ │ ├── Set5_bicLRx2 37 | │ │ ├── Set5_bicLRx3 38 | │ │ ├── Set5_bicLRx4 39 | │ ├── val_set14 40 | │ │ ├── Set14_bicLRx2 41 | │ │ ├── Set14_bicLRx3 42 | │ │ ├── Set14_bicLRx4 43 | ``` 44 | 45 | ## Crop sub-images 46 | 47 | For faster IO, we recommend to crop the DIV2K images to sub-images. 
We provide such a script: 48 | 49 | ```shell 50 | python tools/data/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K 51 | ``` 52 | 53 | The generated data is stored under `DIV2K` and the data structure is as follows, where `_sub` indicates the sub-images. 54 | 55 | ```text 56 | mmediting 57 | ├── mmedit 58 | ├── tools 59 | ├── configs 60 | ├── data 61 | │ ├── DIV2K 62 | │ │ ├── DIV2K_train_HR 63 | │ │ ├── DIV2K_train_HR_sub 64 | │ │ ├── DIV2K_train_LR_bicubic 65 | │ │ │ ├── X2 66 | │ │ │ ├── X3 67 | │ │ │ ├── X4 68 | │ │ │ ├── X2_sub 69 | │ │ │ ├── X3_sub 70 | │ │ │ ├── X4_sub 71 | │ │ ├── DIV2K_valid_HR 72 | │ │ ├── ... 73 | ... 74 | ``` 75 | 76 | ## Prepare annotation list 77 | 78 | If you use the annotation mode for the dataset, you first need to prepare a specific `txt` file. 79 | 80 | Each line in the annotation file contains the image names and image shape (usually for the ground-truth images), separated by a white space. 81 | 82 | Example of an annotation file: 83 | 84 | ```text 85 | 0001_s001.png (480,480,3) 86 | 0001_s002.png (480,480,3) 87 | ``` 88 | 89 | ## Prepare LMDB dataset for DIV2K 90 | 91 | If you want to use LMDB datasets for faster IO speed, you can make LMDB files by: 92 | 93 | ```shell 94 | python tools/data/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K --make-lmdb 95 | ``` 96 | -------------------------------------------------------------------------------- /tools/data/super-resolution/div2k/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 DIV2K 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @InProceedings{Agustsson_2017_CVPR_Workshops, 7 | author = {Agustsson, Eirikur and Timofte, Radu}, 8 | title = {NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study}, 9 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 10 | month = {July}, 11 | year = {2017} 12 | } 13 | ``` 14 | 15 | - 训练集: [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/). 16 | - 验证集: Set5 and Set14. 17 | 18 | ```text 19 | mmediting 20 | ├── mmedit 21 | ├── tools 22 | ├── configs 23 | ├── data 24 | │ ├── DIV2K 25 | │ │ ├── DIV2K_train_HR 26 | │ │ ├── DIV2K_train_LR_bicubic 27 | │ │ │ ├── X2 28 | │ │ │ ├── X3 29 | │ │ │ ├── X4 30 | │ │ ├── DIV2K_valid_HR 31 | │ │ ├── DIV2K_valid_LR_bicubic 32 | │ │ │ ├── X2 33 | │ │ │ ├── X3 34 | │ │ │ ├── X4 35 | │ ├── val_set5 36 | │ │ ├── Set5_bicLRx2 37 | │ │ ├── Set5_bicLRx3 38 | │ │ ├── Set5_bicLRx4 39 | │ ├── val_set14 40 | │ │ ├── Set14_bicLRx2 41 | │ │ ├── Set14_bicLRx3 42 | │ │ ├── Set14_bicLRx4 43 | ``` 44 | 45 | ## 裁剪子图 46 | 47 | 为了加快 IO,建议将 DIV2K 中的图片裁剪为一系列子图,为此,我们提供了一个脚本: 48 | 49 | ```shell 50 | python tools/data/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K 51 | ``` 52 | 53 | 生成的数据保存在 `DIV2K` 目录下,其文件结构如下所示,其中 `_sub` 表示子图: 54 | 55 | ```text 56 | mmediting 57 | ├── mmedit 58 | ├── tools 59 | ├── configs 60 | ├── data 61 | │ ├── DIV2K 62 | │ │ ├── DIV2K_train_HR 63 | │ │ ├── DIV2K_train_HR_sub 64 | │ │ ├── DIV2K_train_LR_bicubic 65 | │ │ │ ├── X2 66 | │ │ │ ├── X3 67 | │ │ │ ├── X4 68 | │ │ │ ├── X2_sub 69 | │ │ │ ├── X3_sub 70 | │ │ │ ├── X4_sub 71 | │ │ ├── DIV2K_valid_HR 72 | │ │ ├── ... 73 | ... 
74 | ``` 75 | 76 | ## 准备标注列表文件 77 | 78 | 如果您想使用`标注模式`来处理数据集,需要先准备一个 `txt` 格式的标注文件。 79 | 80 | 标注文件中的每一行包含了图片名以及图片尺寸(这些通常是 ground-truth 图片),这两个字段用空格间隔开。 81 | 82 | 标注文件示例: 83 | 84 | ```text 85 | 0001_s001.png (480,480,3) 86 | 0001_s002.png (480,480,3) 87 | ``` 88 | 89 | ## 准备 LMDB 格式的 DIV2K 数据集 90 | 91 | 如果您想使用 `LMDB` 以获得更快的 IO 速度,可以通过以下脚本来构建 LMDB 文件 92 | 93 | ```shell 94 | python tools/data/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K --make-lmdb 95 | ``` 96 | -------------------------------------------------------------------------------- /tools/data/super-resolution/reds/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 REDS 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @InProceedings{Nah_2019_CVPR_Workshops_REDS, 7 | author = {Nah, Seungjun and Baik, Sungyong and Hong, Seokil and Moon, Gyeongsik and Son, Sanghyun and Timofte, Radu and Lee, Kyoung Mu}, 8 | title = {NTIRE 2019 Challenge on Video Deblurring and Super-Resolution: Dataset and Study}, 9 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 10 | month = {June}, 11 | year = {2019} 12 | } 13 | ``` 14 | 15 | - 训练集: [REDS 数据集](https://seungjunnah.github.io/Datasets/reds.html). 16 | - 验证集: [REDS 数据集](https://seungjunnah.github.io/Datasets/reds.html) 和 Vid4. 17 | 18 | 请注意,我们合并了 REDS 的训练集和验证集,以便在 REDS4 划分(在 `EDVR` 中会使用到)和官方验证集划分之间切换。 19 | 20 | 原始验证集的名称被修改了(clip 000 到 029),以避免与训练集发生冲突(总共 240 个 clip)。具体而言,验证集中的 clips 被改名为 240、241、... 269。 21 | 22 | 可通过运行以下命令来准备 REDS 数据集: 23 | 24 | ```shell 25 | python tools/data/super-resolution/reds/preprocess_reds_dataset.py ./data/REDS 26 | ``` 27 | 28 | ```text 29 | mmediting 30 | ├── mmedit 31 | ├── tools 32 | ├── configs 33 | ├── data 34 | │ ├── REDS 35 | │ │ ├── train_sharp 36 | │ │ │ ├── 000 37 | │ │ │ ├── 001 38 | │ │ │ ├── ... 39 | │ │ ├── train_sharp_bicubic 40 | │ │ │ ├── 000 41 | │ │ │ ├── 001 42 | │ │ │ ├── ... 43 | │ ├── REDS4 44 | │ │ ├── GT 45 | │ │ ├── sharp_bicubic 46 | ``` 47 | 48 | ## 准备 LMDB 格式的 REDS 数据集 49 | 50 | 如果您想使用 `LMDB` 以获得更快的 IO 速度,可以通过以下脚本来构建 LMDB 文件: 51 | 52 | ```shell 53 | python tools/data/super-resolution/reds/preprocess_reds_dataset.py --root-path ./data/REDS --make-lmdb 54 | ``` 55 | 56 | ## 裁剪为子图 57 | 58 | MMEditing 支持将 REDS 图像裁剪为子图像以加快 IO。我们提供了这样一个脚本: 59 | 60 | ```shell 61 | python tools/data/super-resolution/reds/crop_sub_images.py --data-root ./data/REDS -scales 4 62 | ``` 63 | 64 | 生成的数据存储在 `REDS` 下,数据结构如下,其中`_sub`表示子图像。 65 | 66 | ```text 67 | mmediting 68 | ├── mmedit 69 | ├── tools 70 | ├── configs 71 | ├── data 72 | │ ├── REDS 73 | │ │ ├── train_sharp 74 | │ │ │ ├── 000 75 | │ │ │ ├── 001 76 | │ │ │ ├── ... 77 | │ │ ├── train_sharp_sub 78 | │ │ │ ├── 000_s001 79 | │ │ │ ├── 000_s002 80 | │ │ │ ├── ... 81 | │ │ │ ├── 001_s001 82 | │ │ │ ├── ... 83 | │ │ ├── train_sharp_bicubic 84 | │ │ │ ├── X4 85 | │ │ │ │ ├── 000 86 | │ │ │ │ ├── 001 87 | │ │ │ │ ├── ... 88 | │ │ │ ├── X4_sub 89 | │ │ │ ├── 000_s001 90 | │ │ │ ├── 000_s002 91 | │ │ │ ├── ... 92 | │ │ │ ├── 001_s001 93 | │ │ │ ├── ... 
94 | ``` 95 | 96 | 请注意,默认情况下,`preprocess_reds_dataset.py` 不会为裁剪后的数据集制作 lmdb 和注释文件。您可能需要为此类操作稍微修改脚本。 97 | -------------------------------------------------------------------------------- /tools/data/super-resolution/vid4/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Vid4 Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={On Bayesian adaptive video super resolution}, 8 | author={Liu, Ce and Sun, Deqing}, 9 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 10 | volume={36}, 11 | number={2}, 12 | pages={346--360}, 13 | year={2013}, 14 | publisher={IEEE} 15 | } 16 | ``` 17 | 18 | The Vid4 dataset can be downloaded from [here](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view?usp=sharing). There are two degradations in the dataset. 19 | 20 | 1. BIx4 contains images downsampled by bicubic interpolation 21 | 2. BDx4 contains images blurred by Gaussian kernel with σ=1.6, followed by a subsampling every four pixels. 22 | -------------------------------------------------------------------------------- /tools/data/super-resolution/vid4/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 Vid4 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={On Bayesian adaptive video super resolution}, 8 | author={Liu, Ce and Sun, Deqing}, 9 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 10 | volume={36}, 11 | number={2}, 12 | pages={346--360}, 13 | year={2013}, 14 | publisher={IEEE} 15 | } 16 | ``` 17 | 18 | 可以从 [此处](https://drive.google.com/file/d/1ZuvNNLgR85TV_whJoHM7uVb-XW1y70DW/view?usp=sharing) 下载 Vid4 数据集, 19 | 其中包含了由两种下采样方法得到的图片: 20 | 21 | 1. BIx4 包含了由双线性插值下采样得到的图片 22 | 2. BDx4 包含了由 `σ=1.6` 的高斯核模糊,然后每4个像素进行一次采样得到的图片 23 | -------------------------------------------------------------------------------- /tools/data/super-resolution/vimeo90k/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Vimeo90K Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={Video Enhancement with Task-Oriented Flow}, 8 | author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T}, 9 | journal={International Journal of Computer Vision (IJCV)}, 10 | volume={127}, 11 | number={8}, 12 | pages={1106--1125}, 13 | year={2019}, 14 | publisher={Springer} 15 | } 16 | ``` 17 | 18 | The training and test datasets can be download from [here](http://toflow.csail.mit.edu/). 19 | 20 | The Vimeo90K dataset has a `clip/sequence/img` folder structure: 21 | ```text 22 | ├── GT/LQ 23 | │ ├── 00001 24 | │ │ ├── 0001 25 | │ │ │ ├── im1.png 26 | │ │ │ ├── im2.png 27 | │ │ │ ├── ... 28 | │ │ ├── 0002 29 | │ │ ├── 0003 30 | │ │ ├── ... 31 | │ ├── 00002 32 | │ ├── ... 33 | ``` 34 | 35 | 36 | 37 | 38 | ## Prepare the annotation files for Vimeo90K dataset 39 | 40 | To prepare the annotation file for training, you need to download the official training list path for Vimeo90K from the official website, and run the following command: 41 | 42 | ```shell 43 | python tools/data/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt 44 | ``` 45 | 46 | The annotation file for test is generated similarly. 
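For example, assuming the official test list was saved as `official_test_list.txt` (the actual filename depends on what you downloaded):

```shell
python tools/data/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_test_list.txt
```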
47 | 48 | 49 | 50 | ## Prepare LMDB dataset for Vimeo90K 51 | 52 | If you want to use LMDB datasets for faster IO speed, you can make LMDB files by: 53 | 54 | ```shell 55 | python tools/data/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt --gt-path ./data/Vimeo90K/GT --lq-path ./data/Vimeo90K/LQ --make-lmdb 56 | ``` 57 | -------------------------------------------------------------------------------- /tools/data/super-resolution/vimeo90k/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 Vimeo90K 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={Video Enhancement with Task-Oriented Flow}, 8 | author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T}, 9 | journal={International Journal of Computer Vision (IJCV)}, 10 | volume={127}, 11 | number={8}, 12 | pages={1106--1125}, 13 | year={2019}, 14 | publisher={Springer} 15 | } 16 | ``` 17 | 18 | 训练集和测试集可以从 [此处](http://toflow.csail.mit.edu/) 下载。 19 | 20 | Vimeo90K 数据集包含了如下所示的 `clip/sequence/img` 目录结构: 21 | 22 | ```text 23 | ├── GT/LQ 24 | │ ├── 00001 25 | │ │ ├── 0001 26 | │ │ │ ├── im1.png 27 | │ │ │ ├── im2.png 28 | │ │ │ ├── ... 29 | │ │ ├── 0002 30 | │ │ ├── 0003 31 | │ │ ├── ... 32 | │ ├── 00002 33 | │ ├── ... 34 | ``` 35 | 36 | ## 准备 Vimeo90K 数据集的标注文件 37 | 38 | 为了准备好训练所需的标注文件,请先从 Vimeo90K 数据集官网下载训练路径列表,随后执行如下命令: 39 | 40 | ```shell 41 | python tools/data/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt 42 | ``` 43 | 44 | 测试集的标注文件可通过类似方式生成. 45 | 46 | ## 准备 LMDB 格式的 Vimeo90K 数据集 47 | 48 | 如果您想使用 `LMDB` 以获得更快的 IO 速度,可以通过以下脚本来构建 LMDB 文件 49 | 50 | ```shell 51 | python tools/data/super-resolution/vimeo90k/preprocess_vimeo90k_dataset.py ./data/Vimeo90K/official_train_list.txt --gt-path ./data/Vimeo90K/GT --lq-path ./data/Vimeo90K/LQ --make-lmdb 52 | ``` 53 | -------------------------------------------------------------------------------- /tools/data/video-interpolation/README.md: -------------------------------------------------------------------------------- 1 | # Video Frame Interpolation Datasets 2 | 3 | It is recommended to symlink the dataset root to `$MMEDITING/data`. If your folder structure is different, you may need to change the corresponding paths in config files. 
4 | 5 | MMEditing supported video frame interpolation datasets: 6 | 7 | - [Vimeo90K-triplet](vimeo90k-triplet/README.md) \[ [Homepage](http://toflow.csail.mit.edu) \] 8 | -------------------------------------------------------------------------------- /tools/data/video-interpolation/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 视频插帧数据集 2 | 3 | 建议将数据集的根目录链接到 `$MMEDITING/data` 下,如果您的文件目录结构不一致,那么可能需要在配置文件中修改对应的文件路径。 4 | 5 | MMEditing 支持下列视频插帧数据集: 6 | 7 | - [Vimeo90K-triplet](vimeo90k-triplet/README.md) \[ [Homepage](http://toflow.csail.mit.edu) \] 8 | -------------------------------------------------------------------------------- /tools/data/video-interpolation/vimeo90k-triplet/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Vimeo90K-triplet Dataset 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={Video Enhancement with Task-Oriented Flow}, 8 | author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T}, 9 | journal={International Journal of Computer Vision (IJCV)}, 10 | volume={127}, 11 | number={8}, 12 | pages={1106--1125}, 13 | year={2019}, 14 | publisher={Springer} 15 | } 16 | ``` 17 | 18 | The training and test datasets can be download from [here](http://toflow.csail.mit.edu/). 19 | 20 | The Vimeo90K-triplet dataset has a `clip/sequence/img` folder structure: 21 | 22 | ```text 23 | ├── tri_testlist.txt 24 | ├── tri_trainlist.txt 25 | ├── sequences 26 | │ ├── 00001 27 | │ │ ├── 0001 28 | │ │ │ ├── im1.png 29 | │ │ │ ├── im2.png 30 | │ │ │ └── im3.png 31 | │ │ ├── 0002 32 | │ │ ├── 0003 33 | │ │ ├── ... 34 | │ ├── 00002 35 | │ ├── ... 36 | ``` 37 | -------------------------------------------------------------------------------- /tools/data/video-interpolation/vimeo90k-triplet/README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 准备 Vimeo90K-triplet 数据集 2 | 3 | 4 | 5 | ```bibtex 6 | @article{xue2019video, 7 | title={Video Enhancement with Task-Oriented Flow}, 8 | author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and Freeman, William T}, 9 | journal={International Journal of Computer Vision (IJCV)}, 10 | volume={127}, 11 | number={8}, 12 | pages={1106--1125}, 13 | year={2019}, 14 | publisher={Springer} 15 | } 16 | ``` 17 | 18 | 训练集和测试集可以从 [此处](http://toflow.csail.mit.edu/) 下载。 19 | 20 | Vimeo90K-triplet 数据集包含了如下所示的 `clip/sequence/img` 目录结构: 21 | 22 | ```text 23 | ├── tri_testlist.txt 24 | ├── tri_trainlist.txt 25 | ├── sequences 26 | │ ├── 00001 27 | │ │ ├── 0001 28 | │ │ │ ├── im1.png 29 | │ │ │ ├── im2.png 30 | │ │ │ └── im3.png 31 | │ │ ├── 0002 32 | │ │ ├── 0003 33 | │ │ ├── ... 34 | │ ├── 00002 35 | │ ├── ... 36 | ``` 37 | -------------------------------------------------------------------------------- /tools/deployment/mmedit_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import os 3 | import random 4 | import string 5 | from io import BytesIO 6 | 7 | import PIL.Image as Image 8 | import torch 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmedit.apis import init_model, restoration_inference 12 | from mmedit.core import tensor2img 13 | 14 | 15 | class MMEditHandler(BaseHandler): 16 | 17 | def initialize(self, context): 18 | print('MMEditHandler.initialize is called') 19 | properties = context.system_properties 20 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | self.device = torch.device(self.map_location + ':' + 22 | str(properties.get('gpu_id')) if torch.cuda. 23 | is_available() else self.map_location) 24 | self.manifest = context.manifest 25 | 26 | model_dir = properties.get('model_dir') 27 | serialized_file = self.manifest['model']['serializedFile'] 28 | checkpoint = os.path.join(model_dir, serialized_file) 29 | self.config_file = os.path.join(model_dir, 'config.py') 30 | 31 | self.model = init_model(self.config_file, checkpoint, self.device) 32 | self.initialized = True 33 | 34 | def preprocess(self, data, *args, **kwargs): 35 | body = data[0].get('data') or data[0].get('body') 36 | result = Image.open(BytesIO(body)) 37 | # data preprocess is in inference. 38 | return result 39 | 40 | def inference(self, data, *args, **kwargs): 41 | # generate temp image path for restoration_inference 42 | temp_name = ''.join( 43 | random.sample(string.ascii_letters + string.digits, 18)) 44 | temp_path = f'./{temp_name}.png' 45 | data.save(temp_path) 46 | results = restoration_inference(self.model, temp_path) 47 | # delete the temp image path 48 | os.remove(temp_path) 49 | return results 50 | 51 | def postprocess(self, data): 52 | # convert torch tensor to numpy and then convert to bytes 53 | output_list = [] 54 | for data_ in data: 55 | data_np = tensor2img(data_) 56 | data_byte = data_np.tobytes() 57 | output_list.append(data_byte) 58 | 59 | return output_list 60 | -------------------------------------------------------------------------------- /tools/deployment/test_torchserver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from argparse import ArgumentParser 3 | 4 | import cv2 5 | import numpy as np 6 | import requests 7 | from PIL import Image 8 | 9 | 10 | def parse_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument('model_name', help='The model name in the server') 13 | parser.add_argument( 14 | '--inference-addr', 15 | default='127.0.0.1:8080', 16 | help='Address and port of the inference server') 17 | parser.add_argument('--img-path', type=str, help='The input LQ image.') 18 | parser.add_argument( 19 | '--save-path', type=str, help='Path to save the generated GT image.') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def save_results(content, save_path, ori_shape): 25 | ori_len = np.prod(ori_shape) 26 | scale = int(np.sqrt(len(content) / ori_len)) 27 | target_size = [int(size * scale) for size in ori_shape[:2][::-1]] 28 | # Convert to RGB and save image 29 | img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0) 30 | img.save(save_path) 31 | 32 | 33 | def main(args): 34 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 35 | ori_shape = cv2.imread(args.img_path).shape 36 | with open(args.img_path, 'rb') as image: 37 | response = requests.post(url, image) 38 | save_results(response.content, args.save_path, ori_shape) 39 | 40 | 41 | if __name__ == '__main__': 42 | parsed_args = parse_args() 43 | main(parsed_args) 44 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | OTHER_CONFIG=$4 7 | NNODES=${NNODES:-1} 8 | NODE_RANK=${NODE_RANK:-0} 9 | PORT=${PORT:-29001} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | 12 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 13 | python -m torch.distributed.launch \ 14 | --nnodes=$NNODES \ 15 | --node_rank=$NODE_RANK \ 16 | --master_addr=$MASTER_ADDR \ 17 | --nproc_per_node=$GPUS \ 18 | --master_port=$PORT \ 19 | $(dirname "$0")/test.py \ 20 | $CONFIG \ 21 | $CHECKPOINT \ 22 | $OTHER_CONFIG \ 23 | --launcher pytorch \ 24 | ${@:4} 25 | -------------------------------------------------------------------------------- /tools/dist_test_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # CONFIG=$1 4 | # CHECKPOINT=$2 5 | # GPUS=$3 6 | # OTHER_CONFIG=$4 7 | # NNODES=${NNODES:-1} 8 | # NODE_RANK=${NODE_RANK:-0} 9 | # PORT=${PORT:-29001} 10 | # MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | 12 | # PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 13 | # python -m torch.distributed.launch \ 14 | # --nnodes=$NNODES \ 15 | # --node_rank=$NODE_RANK \ 16 | # --master_addr=$MASTER_ADDR \ 17 | # --nproc_per_node=$GPUS \ 18 | # --master_port=$PORT \ 19 | # $(dirname "$0")/test_stage1.py \ 20 | # $CONFIG \ 21 | # $CHECKPOINT \ 22 | # $OTHER_CONFIG \ 23 | # --launcher pytorch \ 24 | # ${@:4} 25 | CONFIG=$1 26 | CHECKPOINT=$2 27 | GPUS=$3 28 | NNODES=${NNODES:-1} 29 | NODE_RANK=${NODE_RANK:-0} 30 | PORT=${PORT:-29500} 31 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 32 | 33 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 34 | python -m torch.distributed.launch \ 35 | --nnodes=$NNODES \ 36 | --node_rank=$NODE_RANK \ 37 | --master_addr=$MASTER_ADDR \ 38 | --nproc_per_node=$GPUS \ 39 | --master_port=$PORT \ 40 | $(dirname "$0")/test_stage1.py \ 41 | $CONFIG \ 42 | $CHECKPOINT \ 43 | --launcher pytorch \ 44 | ${@:4} 
-------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | OTHER_CONFIG=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29501} 9 | # PORT=${PORT:-145622} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | 12 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 13 | python -m torch.distributed.launch \ 14 | --nnodes=$NNODES \ 15 | --node_rank=$NODE_RANK \ 16 | --master_addr=$MASTER_ADDR \ 17 | --nproc_per_node=$GPUS \ 18 | --master_port=$PORT \ 19 | $(dirname "$0")/train.py \ 20 | $CONFIG \ 21 | $OTHER_CONFIG \ 22 | --seed 0 \ 23 | --launcher pytorch ${@:3} 24 | -------------------------------------------------------------------------------- /tools/dist_train_iter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | OTHER_CONFIG=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/train_iter.py \ 19 | $CONFIG \ 20 | $OTHER_CONFIG \ 21 | --seed 0 \ 22 | --launcher pytorch ${@:3} 23 | -------------------------------------------------------------------------------- /tools/dist_train_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | OTHER_CONFIG=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/train_s1.py \ 19 | $CONFIG \ 20 | $OTHER_CONFIG \ 21 | --seed 0 \ 22 | --launcher pytorch ${@:3} 23 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | from packaging import version 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Process a checkpoint to be published') 12 | parser.add_argument('in_file', help='input checkpoint filename') 13 | parser.add_argument('out_file', help='output checkpoint filename') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def process_checkpoint(in_file, out_file): 19 | checkpoint = torch.load(in_file, map_location='cpu') 20 | # remove optimizer for smaller file size 21 | if 'optimizer' in checkpoint: 22 | del checkpoint['optimizer'] 23 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 24 | # add the code here. 
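    # For example, a hedged sketch (kept commented out; the key names below are
    # only illustrative) of dropping meta entries that may embed local
    # environment details before publishing:
    # for key in ('env_info', 'config', 'hook_msgs'):
    #     checkpoint.get('meta', {}).pop(key, None)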
25 | if version.parse(torch.__version__) >= version.parse('1.6'): 26 | torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) 27 | else: 28 | torch.save(checkpoint, out_file) 29 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 30 | final_file = (out_file[:-4] if out_file.endswith('.pth') else out_file) + f'-{sha[:8]}.pth' 31 | subprocess.Popen(['mv', out_file, final_file]) 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | process_checkpoint(args.in_file, args.out_file) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | GPUS=${GPUS:-8} 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 13 | PY_ARGS=${@:5} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 26 | --------------------------------------------------------------------------------