├── requirements ├── mminstall.txt ├── readthedocs.txt ├── tests.txt ├── docs.txt └── runtime.txt ├── requirements.txt ├── docs ├── image │ ├── map_net.jpg │ └── haze_world.jpg └── dataset_prepare.md ├── mmedit ├── models │ ├── dehazers │ │ └── __init__.py │ ├── backbones │ │ ├── map_backbones │ │ │ ├── __init__.py │ │ │ ├── map_utils.py │ │ │ └── map_modules.py │ │ └── __init__.py │ ├── common │ │ ├── conv.py │ │ ├── downsample.py │ │ ├── img_normalize.py │ │ ├── upsample.py │ │ ├── __init__.py │ │ ├── flow_warp.py │ │ ├── gated_conv_module.py │ │ ├── sr_backbone_utils.py │ │ ├── linear_module.py │ │ ├── mask_conv_module.py │ │ ├── ensemble.py │ │ ├── partial_conv.py │ │ ├── separable_conv_module.py │ │ ├── aspp.py │ │ └── model_utils.py │ ├── registry.py │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── gradient_loss.py │ │ └── utils.py │ ├── builder.py │ └── base.py ├── core │ ├── evaluation │ │ ├── niqe_pris_params.npz │ │ ├── __init__.py │ │ ├── metric_utils.py │ │ ├── eval_hooks.py │ │ └── inceptions.py │ ├── optimizer │ │ ├── __init__.py │ │ └── builder.py │ ├── utils │ │ ├── __init__.py │ │ └── dist_utils.py │ ├── export │ │ ├── __init__.py │ │ └── wrappers.py │ ├── scheduler │ │ └── __init__.py │ ├── registry.py │ ├── hooks │ │ ├── __init__.py │ │ ├── visualization.py │ │ └── ema.py │ ├── __init__.py │ └── misc.py ├── datasets │ ├── samplers │ │ ├── __init__.py │ │ └── distributed_sampler.py │ ├── registry.py │ ├── __init__.py │ ├── dataset_wrappers.py │ ├── pipelines │ │ ├── compose.py │ │ ├── __init__.py │ │ ├── crop_hazeworld.py │ │ ├── normalization.py │ │ ├── random_down_sampling.py │ │ └── utils.py │ ├── base_dataset.py │ ├── base_dh_dataset.py │ ├── base_sr_dataset.py │ └── sr_folder_multiple_gt_dataset.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── cli.py │ ├── logger.py │ ├── setup_env.py │ └── misc.py ├── version.py ├── apis │ ├── __init__.py │ ├── inpainting_inference.py │ ├── restoration_inference.py │ ├── generation_inference.py │ ├── matting_inference.py │ ├── restoration_face_inference.py │ └── restoration_video_inference.py └── __init__.py ├── .readthedocs.yml ├── MANIFEST.in ├── tools ├── dist_train.sh ├── dist_test.sh ├── slurm_test.sh ├── slurm_train.sh ├── publish_model.py ├── deployment │ ├── test_torchserver.py │ ├── mmedit_handler.py │ └── mmedit2torchserve.py └── get_flops.py ├── configs └── dehazers │ ├── mapnet │ ├── mapnet_runtime.py │ └── mapnet_hazeworld.py │ └── _base_ │ ├── default_runtime.py │ ├── schedules │ ├── schedule_40k.py │ └── schedule_80k_eval.py │ └── datasets │ ├── revide.py │ └── hazeworld.py ├── setup.cfg ├── LICENSE ├── .gitignore └── README.md /requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcv-full>=1.3.17 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/runtime.txt 2 | -r requirements/tests.txt 3 | -------------------------------------------------------------------------------- /docs/image/map_net.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqixuac/MAP-Net/HEAD/docs/image/map_net.jpg -------------------------------------------------------------------------------- /docs/image/haze_world.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiaqixuac/MAP-Net/HEAD/docs/image/haze_world.jpg -------------------------------------------------------------------------------- /mmedit/models/dehazers/__init__.py: -------------------------------------------------------------------------------- 1 | from .map import MAP 2 | 3 | __all__ = [ 4 | 'MAP' 5 | ] 6 | -------------------------------------------------------------------------------- /requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | lmdb 2 | mmcv 3 | regex 4 | scikit-image 5 | titlecase 6 | torch 7 | torchvision 8 | -------------------------------------------------------------------------------- /requirements/tests.txt: -------------------------------------------------------------------------------- 1 | codecov 2 | flake8 3 | interrogate 4 | isort==5.10.1 5 | onnxruntime 6 | pytest 7 | pytest-runner 8 | yapf 9 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiaqixuac/MAP-Net/HEAD/mmedit/core/evaluation/niqe_pris_params.npz -------------------------------------------------------------------------------- /mmedit/models/backbones/map_backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .convnext import ConvNeXt 2 | from .mapnet_net import MAPNet 3 | 4 | __all__ = [ 5 | 'MAPNet' 6 | ] 7 | -------------------------------------------------------------------------------- /mmedit/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_optimizers 3 | 4 | __all__ = ['build_optimizers'] 5 | -------------------------------------------------------------------------------- /mmedit/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import sync_random_seed 3 | 4 | __all__ = ['sync_random_seed'] 5 | -------------------------------------------------------------------------------- /mmedit/core/export/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .wrappers import ONNXRuntimeEditing 3 | 4 | __all__ = ['ONNXRuntimeEditing'] 5 | -------------------------------------------------------------------------------- /mmedit/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .map_backbones import MAPNet 3 | 4 | __all__ = [ 5 | 'MAPNet' 6 | ] 7 | -------------------------------------------------------------------------------- /mmedit/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /mmedit/datasets/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmcv.utils import Registry 3 | 4 | DATASETS = Registry('dataset') 5 | PIPELINES = Registry('pipeline') 6 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: all 4 | 5 | python: 6 | version: 3.7 7 | install: 8 | - requirements: requirements/docs.txt 9 | - requirements: requirements/readthedocs.txt 10 | -------------------------------------------------------------------------------- /mmedit/core/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .lr_updater import LinearLrUpdaterHook, ReduceLrUpdaterHook 3 | 4 | __all__ = ['LinearLrUpdaterHook', 'ReduceLrUpdaterHook'] 5 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | mmcls==0.10.0 3 | myst_parser 4 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | sphinx==4.0.2 6 | sphinx-copybutton 7 | sphinx_markdown_tables 8 | -------------------------------------------------------------------------------- /mmedit/models/common/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import CONV_LAYERS 3 | from torch import nn 4 | 5 | CONV_LAYERS.register_module('Deconv', module=nn.ConvTranspose2d) 6 | # TODO: octave conv 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements/*.txt 2 | include mmedit/.mim/VERSION 3 | include mmedit/.mim/model-index.yml 4 | recursive-include mmedit/.mim/configs *.py *.yml 5 | recursive-include mmedit/.mim/tools *.sh *.py 6 | recursive-include mmedit/.mim/demo *.py 7 | -------------------------------------------------------------------------------- /mmedit/models/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import MODELS as MMCV_MODELS 3 | from mmcv.utils import Registry 4 | 5 | MODELS = Registry('model', parent=MMCV_MODELS) 6 | BACKBONES = MODELS 7 | COMPONENTS = MODELS 8 | LOSSES = MODELS 9 | -------------------------------------------------------------------------------- /mmedit/core/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | METRICS = Registry('metric') 5 | 6 | 7 | def build_metric(cfg): 8 | """Build a metric calculator.""" 9 | return build_from_cfg(cfg, METRICS) 10 | -------------------------------------------------------------------------------- /mmedit/core/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .ema import ExponentialMovingAverageHook 3 | from .visualization import MMEditVisualizationHook, VisualizationHook 4 | 5 | __all__ = [ 6 | 'VisualizationHook', 'MMEditVisualizationHook', 7 | 'ExponentialMovingAverageHook' 8 | ] 9 | -------------------------------------------------------------------------------- /mmedit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .cli import modify_args 3 | from .logger import get_root_logger 4 | from .misc import deprecated_function 5 | from .setup_env import setup_multi_processes 6 | 7 | __all__ = [ 8 | 'get_root_logger', 'setup_multi_processes', 'modify_args', 9 | 'deprecated_function' 10 | ] 11 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .eval_hooks import DistEvalIterHook, EvalIterHook 3 | from .inceptions import FID, KID, InceptionV3 4 | from .metrics import (L1Evaluation, connectivity, gradient_error, mae, mse, 5 | niqe, psnr, reorder_image, sad, ssim) 6 | 7 | __all__ = [ 8 | 'mse', 'sad', 'psnr', 'reorder_image', 'ssim', 'EvalIterHook', 9 | 'DistEvalIterHook', 'L1Evaluation', 'gradient_error', 'connectivity', 10 | 'niqe', 'mae', 'FID', 'KID', 'InceptionV3' 11 | ] 12 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 11 | python -m torch.distributed.launch \ 12 | --nnodes=$NNODES \ 13 | --node_rank=$NODE_RANK \ 14 | --master_addr=$MASTER_ADDR \ 15 | --nproc_per_node=$GPUS \ 16 | --master_port=$PORT \ 17 | $(dirname "$0")/train.py \ 18 | $CONFIG \ 19 | --seed 0 \ 20 | --launcher pytorch ${@:3} 21 | -------------------------------------------------------------------------------- /mmedit/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | 3 | __version__ = '0.16.1' 4 | 5 | 6 | def parse_version_info(version_str): 7 | ver_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | ver_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | ver_info.append(int(patch_version[0])) 14 | ver_info.append(f'rc{patch_version[1]}') 15 | return tuple(ver_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | av 2 | av==8.0.3; python_version < '3.7' 3 | einops 4 | facexlib 5 | lmdb 6 | mmcv-full>=1.3.13 # To support DCN on CPU 7 | numpy 8 | opencv-python!=4.5.5.62,!=4.5.5.64 9 | # MMCV depends on opencv-python instead of the headless variant, so we install opencv-python 10 | # Due to an upstream bug, we skip these two versions 11 | # https://github.com/opencv/opencv-python/issues/602 12 | # https://github.com/opencv/opencv/issues/21366 13 | # It seems to be fixed in https://github.com/opencv/opencv/pull/21382 14 | Pillow 15 | tensorboard 16 | torch 17 | torchvision 18 | -------------------------------------------------------------------------------- /mmedit/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmedit 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMEditing'] = f'{mmedit.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /configs/dehazers/mapnet/mapnet_runtime.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizers = dict( 3 | generator=dict( 4 | type='AdamW', 5 | lr=0.0002, 6 | betas=(0.9, 0.999), 7 | ) 8 | ) 9 | 10 | # learning policy 11 | lr_config = dict( 12 | policy='poly', 13 | warmup='linear', 14 | warmup_iters=1500, 15 | warmup_ratio=1e-6, 16 | power=1.0, 17 | min_lr=1e-7, 18 | by_epoch=False) 19 | 20 | # model training and testing settings 21 | train_cfg = None 22 | test_cfg = dict(metrics=['L1', 'PSNR', 'SSIM'], crop_border=0) 23 | 24 | visual_config = None 25 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/test.py \ 19 | $CONFIG \ 20 | $CHECKPOINT \ 21 | --launcher pytorch \ 22 | ${@:4} 23 | -------------------------------------------------------------------------------- /mmedit/utils/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
OpenMMLab. All rights reserved. 2 | import re 3 | import sys 4 | import warnings 5 | 6 | 7 | def modify_args(): 8 | for i, v in enumerate(sys.argv): 9 | if i == 0: 10 | assert v.endswith('.py') 11 | elif re.match(r'--\w+_.*', v): 12 | new_arg = v.replace('_', '-') 13 | warnings.warn( 14 | f'command line argument {v} is deprecated, ' 15 | f'please use {new_arg} instead.', 16 | category=DeprecationWarning, 17 | ) 18 | sys.argv[i] = new_arg 19 | -------------------------------------------------------------------------------- /configs/dehazers/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/default_runtime.py 2 | # yapf:disable 3 | log_config = dict( 4 | interval=100, 5 | hooks=[ 6 | dict(type='TextLoggerHook', by_epoch=False), 7 | # dict(type='TensorboardLoggerHook') 8 | # dict(type='PaviLoggerHook') # for internal services 9 | ]) 10 | # yapf:enable 11 | dist_params = dict(backend='nccl') 12 | log_level = 'INFO' 13 | load_from = None 14 | resume_from = None 15 | workflow = [('train', 1)] 16 | cudnn_benchmark = True 17 | # find_unused_parameters = True 18 | -------------------------------------------------------------------------------- /mmedit/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401, F403 3 | from .base import BaseModel 4 | from .builder import (build, build_backbone, build_component, build_loss, 5 | build_model) 6 | from .common import * # noqa: F401, F403 7 | from .dehazers import * 8 | from .losses import * # noqa: F401, F403 9 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS 10 | 11 | __all__ = [ 12 | 'BaseModel','build', 13 | 'build_backbone', 'build_component', 'build_loss', 'build_model', 14 | 'BACKBONES', 'COMPONENTS', 'LOSSES', 'MODELS', 15 | ] 16 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | addopts=tests/ 9 | 10 | [yapf] 11 | based_on_style = pep8 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | split_penalty_import_names=0 15 | SPLIT_PENALTY_AFTER_OPENING_BRACKET=888 16 | 17 | [isort] 18 | line_length = 79 19 | multi_line_output = 0 20 | extra_standard_library = setuptools 21 | known_first_party = mmedit 22 | known_third_party = 
PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,pytorch_sphinx_theme,requests,scipy,titlecase,torch,torchvision,ts 23 | no_lines_before = STDLIB,LOCALFOLDER 24 | default_section = THIRDPARTY 25 | -------------------------------------------------------------------------------- /mmedit/models/common/downsample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def pixel_unshuffle(x, scale): 3 | """Down-sample by pixel unshuffle. 4 | 5 | Args: 6 | x (Tensor): Input tensor. 7 | scale (int): Scale factor. 8 | 9 | Returns: 10 | Tensor: Output tensor. 11 | """ 12 | 13 | b, c, h, w = x.shape 14 | if h % scale != 0 or w % scale != 0: 15 | raise AssertionError( 16 | f'Invalid scale ({scale}) of pixel unshuffle for tensor ' 17 | f'with shape: {x.shape}') 18 | h = h // scale 19 | w = w // scale 20 | x = x.view(b, c, h, scale, w, scale) 21 | x = x.permute(0, 1, 3, 5, 2, 4) 22 | return x.reshape(b, -1, h, w) 23 | -------------------------------------------------------------------------------- /mmedit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_dataset import BaseDataset 3 | from .base_dh_dataset import BaseDHDataset 4 | from .base_sr_dataset import BaseSRDataset 5 | from .builder import build_dataloader, build_dataset 6 | from .dataset_wrappers import RepeatDataset 7 | from .hw_folder_multiple_gt_dataset import HWFolderMultipleGTDataset 8 | from .registry import DATASETS, PIPELINES 9 | from .sr_folder_multiple_gt_dataset import SRFolderMultipleGTDataset 10 | 11 | __all__ = [ 12 | 'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader', 13 | 'BaseDataset', 'BaseDHDataset', 'HWFolderMultipleGTDataset', 14 | 'BaseSRDataset', 'RepeatDataset', 'SRFolderMultipleGTDataset' 15 | ] 16 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | GPUS=${GPUS:-8} 11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 13 | PY_ARGS=${@:5} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 26 | -------------------------------------------------------------------------------- /configs/dehazers/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/schedules/schedule_40k.py 2 | # https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/basicvsr/basicvsr_reds4.py 3 | # # optimizer 4 | # optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 5 | # optimizer_config = dict() 6 | # # learning policy 7 | # lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 8 | # runtime settings 9 | total_iters = 40000 10 | runner = 
dict(type='IterBasedRunner', max_iters=40000) 11 | checkpoint_config = dict(by_epoch=False, interval=40000) 12 | # remove gpu_collect=True in non distributed training 13 | evaluation = dict(interval=40000, save_image=False, gpu_collect=True) 14 | -------------------------------------------------------------------------------- /configs/dehazers/_base_/schedules/schedule_80k_eval.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/schedules/schedule_80k.py 2 | # https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/basicvsr/basicvsr_reds4.py 3 | # # optimizer 4 | # optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 5 | # optimizer_config = dict() 6 | # # learning policy 7 | # lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 8 | # runtime settings 9 | total_iters = 80000 10 | runner = dict(type='IterBasedRunner', max_iters=80000) 11 | checkpoint_config = dict(by_epoch=False, interval=8000) 12 | # remove gpu_collect=True in non distributed training 13 | evaluation = dict(interval=8000, save_image=False, gpu_collect=True) 14 | -------------------------------------------------------------------------------- /mmedit/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .composition_loss import (CharbonnierCompLoss, L1CompositionLoss, 3 | MSECompositionLoss) 4 | from .gradient_loss import GradientLoss 5 | from .perceptual_loss import (PerceptualLoss, PerceptualVGG, 6 | TransferalPerceptualLoss) 7 | from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss 8 | from .utils import mask_reduce_loss, reduce_loss 9 | 10 | __all__ = [ 11 | 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', 12 | 'MSECompositionLoss', 'CharbonnierCompLoss', 13 | 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss', 14 | 'mask_reduce_loss', 'MaskedTVLoss', 'GradientLoss', 15 | 'TransferalPerceptualLoss', 16 | ] 17 | -------------------------------------------------------------------------------- /mmedit/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .evaluation import (FID, KID, DistEvalIterHook, EvalIterHook, InceptionV3, 3 | L1Evaluation, mae, mse, psnr, reorder_image, sad, 4 | ssim) 5 | from .hooks import MMEditVisualizationHook, VisualizationHook 6 | from .misc import tensor2img 7 | from .optimizer import build_optimizers 8 | from .registry import build_metric 9 | from .scheduler import LinearLrUpdaterHook, ReduceLrUpdaterHook 10 | 11 | __all__ = [ 12 | 'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook', 13 | 'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook', 14 | 'VisualizationHook', 'MMEditVisualizationHook', 'L1Evaluation', 'FID', 15 | 'KID', 'InceptionV3', 'build_metric', 'ReduceLrUpdaterHook', 'mae' 16 | ] 17 | -------------------------------------------------------------------------------- /mmedit/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .generation_inference import generation_inference 3 | from .inpainting_inference import inpainting_inference 4 | from .matting_inference import init_model, matting_inference 5 | from .restoration_face_inference import restoration_face_inference 6 | from .restoration_inference import restoration_inference 7 | from .restoration_video_inference import restoration_video_inference 8 | from .test import multi_gpu_test, single_gpu_test 9 | from .train import init_random_seed, set_random_seed, train_model 10 | from .video_interpolation_inference import video_interpolation_inference 11 | 12 | __all__ = [ 13 | 'train_model', 'set_random_seed', 'init_model', 'matting_inference', 14 | 'inpainting_inference', 'restoration_inference', 'generation_inference', 15 | 'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference', 16 | 'restoration_face_inference', 'video_interpolation_inference', 17 | 'init_random_seed' 18 | ] 19 | -------------------------------------------------------------------------------- /mmedit/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmedit". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | # root logger name: mmedit 26 | logger = get_logger(__name__.split('.')[0], log_file, log_level) 27 | return logger 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jiaqi Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /mmedit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | 4 | from .version import __version__, version_info 5 | 6 | try: 7 | from mmcv.utils import digit_version 8 | except ImportError: 9 | 10 | def digit_version(version_str): 11 | digit_ver = [] 12 | for x in version_str.split('.'): 13 | if x.isdigit(): 14 | digit_ver.append(int(x)) 15 | elif x.find('rc') != -1: 16 | patch_version = x.split('rc') 17 | digit_ver.append(int(patch_version[0]) - 1) 18 | digit_ver.append(int(patch_version[1])) 19 | return digit_ver 20 | 21 | 22 | MMCV_MIN = '1.3.13' 23 | MMCV_MAX = '1.8' 24 | 25 | mmcv_min_version = digit_version(MMCV_MIN) 26 | mmcv_max_version = digit_version(MMCV_MAX) 27 | mmcv_version = digit_version(mmcv.__version__) 28 | 29 | 30 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ 31 | f'mmcv=={mmcv.__version__} is used but incompatible. ' \ 32 | f'Please install mmcv-full>={mmcv_min_version}, <={mmcv_max_version}.' 33 | 34 | __all__ = ['__version__', 'version_info'] 35 | -------------------------------------------------------------------------------- /mmedit/models/common/img_normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ImgNormalize(nn.Conv2d): 7 | """Normalize images with the given mean and std value. 8 | 9 | Based on Conv2d layer, can work in GPU. 10 | 11 | Args: 12 | pixel_range (float): Pixel range of feature. 13 | img_mean (Tuple[float]): Image mean of each channel. 14 | img_std (Tuple[float]): Image std of each channel. 15 | sign (int): Sign of bias. Default -1. 16 | """ 17 | 18 | def __init__(self, pixel_range, img_mean, img_std, sign=-1): 19 | 20 | assert len(img_mean) == len(img_std) 21 | num_channels = len(img_mean) 22 | super().__init__(num_channels, num_channels, kernel_size=1) 23 | 24 | std = torch.Tensor(img_std) 25 | self.weight.data = torch.eye(num_channels).view( 26 | num_channels, num_channels, 1, 1) 27 | self.weight.data.div_(std.view(num_channels, 1, 1, 1)) 28 | self.bias.data = sign * pixel_range * torch.Tensor(img_mean) 29 | self.bias.data.div_(std) 30 | 31 | self.weight.requires_grad = False 32 | self.bias.requires_grad = False 33 | -------------------------------------------------------------------------------- /mmedit/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .registry import DATASETS 3 | 4 | 5 | @DATASETS.register_module() 6 | class RepeatDataset: 7 | """A wrapper of repeated dataset. 8 | 9 | The length of repeated dataset will be `times` larger than the original 10 | dataset. This is useful when the data loading time is long but the dataset 11 | is small. Using RepeatDataset can reduce the data loading time between 12 | epochs. 13 | 14 | Args: 15 | dataset (:obj:`Dataset`): The dataset to be repeated. 16 | times (int): Repeat times. 17 | """ 18 | 19 | def __init__(self, dataset, times): 20 | self.dataset = dataset 21 | self.times = times 22 | 23 | self._ori_len = len(self.dataset) 24 | 25 | def __getitem__(self, idx): 26 | """Get item at each call. 27 | 28 | Args: 29 | idx (int): Index for getting each item. 
30 | """ 31 | return self.dataset[idx % self._ori_len] 32 | 33 | def __len__(self): 34 | """Length of the dataset. 35 | 36 | Returns: 37 | int: Length of the dataset. 38 | """ 39 | return self.times * self._ori_len 40 | -------------------------------------------------------------------------------- /mmedit/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def sync_random_seed(seed=None, device='cuda'): 9 | """Make sure different ranks share the same seed. 10 | 11 | All workers must call this function, otherwise it will deadlock. 12 | This method is generally used in `DistributedSampler`, 13 | because the seed should be identical across all processes 14 | in the distributed group. 15 | Args: 16 | seed (int, Optional): The seed. Default to None. 17 | device (str): The device where the seed will be put on. 18 | Default to 'cuda'. 19 | Returns: 20 | int: Seed to be used. 21 | """ 22 | if seed is None: 23 | seed = np.random.randint(2**31) 24 | assert isinstance(seed, int) 25 | 26 | rank, world_size = get_dist_info() 27 | 28 | if world_size == 1: 29 | return seed 30 | 31 | if rank == 0: 32 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 33 | else: 34 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 35 | dist.broadcast(random_num, src=0) 36 | return random_num.item() 37 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | from packaging import version 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Process a checkpoint to be published') 12 | parser.add_argument('in_file', help='input checkpoint filename') 13 | parser.add_argument('out_file', help='output checkpoint filename') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def process_checkpoint(in_file, out_file): 19 | checkpoint = torch.load(in_file, map_location='cpu') 20 | # remove optimizer for smaller file size 21 | if 'optimizer' in checkpoint: 22 | del checkpoint['optimizer'] 23 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 24 | # add the code here. 25 | if version.parse(torch.__version__) >= version.parse('1.6'): 26 | torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) 27 | else: 28 | torch.save(checkpoint, out_file) 29 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 30 | # strip the '.pth' suffix before appending the hash; str.rstrip('.pth') would remove trailing characters, not the suffix 31 | out_file_name = out_file[:-4] if out_file.endswith('.pth') else out_file 32 | final_file = out_file_name + f'-{sha[:8]}.pth' 33 | subprocess.Popen(['mv', out_file, final_file]) 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | process_checkpoint(args.in_file, args.out_file) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /tools/deployment/test_torchserver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from argparse import ArgumentParser 3 | 4 | import cv2 5 | import numpy as np 6 | import requests 7 | from PIL import Image 8 | 9 | 10 | def parse_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument('model_name', help='The model name in the server') 13 | parser.add_argument( 14 | '--inference-addr', 15 | default='127.0.0.1:8080', 16 | help='Address and port of the inference server') 17 | parser.add_argument('--img-path', type=str, help='The input LQ image.') 18 | parser.add_argument( 19 | '--save-path', type=str, help='Path to save the generated GT image.') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def save_results(content, save_path, ori_shape): 25 | ori_len = np.prod(ori_shape) 26 | scale = int(np.sqrt(len(content) / ori_len)) 27 | target_size = [int(size * scale) for size in ori_shape[:2][::-1]] 28 | # Convert to RGB and save image 29 | img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0) 30 | img.save(save_path) 31 | 32 | 33 | def main(args): 34 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 35 | ori_shape = cv2.imread(args.img_path).shape 36 | with open(args.img_path, 'rb') as image: 37 | response = requests.post(url, image) 38 | save_results(response.content, args.save_path, ori_shape) 39 | 40 | 41 | if __name__ == '__main__': 42 | parsed_args = parse_args() 43 | main(parsed_args) 44 | -------------------------------------------------------------------------------- /configs/dehazers/mapnet/mapnet_hazeworld.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/hazeworld.py', 3 | '../_base_/default_runtime.py', './mapnet_runtime.py', 4 | '../_base_/schedules/schedule_40k.py' 5 | ] 6 | 7 | exp_name = 'mapnet_hazeworld_40k' 8 | 9 | checkpoint = 'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-tiny_3rdparty_32xb128-noema_in1k_20220301-795e9634.pth' # noqa 10 | 11 | # model settings 12 | model = dict( 13 | type='MAP', 14 | generator=dict( 15 | type='MAPNet', 16 | backbone=dict( 17 | type='ConvNeXt', 18 | arch='tiny', 19 | out_indices=[0, 1, 2, 3], 20 | drop_path_rate=0.0, 21 | layer_scale_init_value=1.0, 22 | gap_before_final_norm=False, 23 | init_cfg=dict(type='Pretrained', checkpoint=checkpoint, prefix='backbone.'), 24 | ), 25 | neck=dict( 26 | type='ProjectionHead', 27 | in_channels=[96, 192, 384, 768], 28 | out_channels=64, 29 | num_outs=4 30 | ), 31 | upsampler=dict( 32 | type='MAPUpsampler', 33 | embed_dim=32, 34 | num_feat=32, 35 | ), 36 | channels=32, 37 | num_trans_bins=32, 38 | align_depths=(1, 1, 1, 1), 39 | num_kv_frames=[1, 2, 3], 40 | ), 41 | 42 | pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'), 43 | ) 44 | 45 | data = dict( 46 | train_dataloader=dict(samples_per_gpu=2, drop_last=True), 47 | ) 48 | 49 | # runtime settings 50 | work_dir = f'./work_dirs/{exp_name}' 51 | -------------------------------------------------------------------------------- /mmedit/models/common/upsample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sr_backbone_utils import default_init_weights 6 | 7 | 8 | class PixelShufflePack(nn.Module): 9 | """Pixel Shuffle upsample layer. 10 | 11 | Args: 12 | in_channels (int): Number of input channels. 13 | out_channels (int): Number of output channels. 14 | scale_factor (int): Upsample ratio. 
15 | upsample_kernel (int): Kernel size of Conv layer to expand channels. 16 | 17 | Returns: 18 | Upsampled feature map. 19 | """ 20 | 21 | def __init__(self, in_channels, out_channels, scale_factor, 22 | upsample_kernel): 23 | super().__init__() 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.scale_factor = scale_factor 27 | self.upsample_kernel = upsample_kernel 28 | self.upsample_conv = nn.Conv2d( 29 | self.in_channels, 30 | self.out_channels * scale_factor * scale_factor, 31 | self.upsample_kernel, 32 | padding=(self.upsample_kernel - 1) // 2) 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | """Initialize weights for PixelShufflePack.""" 37 | default_init_weights(self, 1) 38 | 39 | def forward(self, x): 40 | """Forward function for PixelShufflePack. 41 | 42 | Args: 43 | x (Tensor): Input tensor with shape (n, c, h, w). 44 | 45 | Returns: 46 | Tensor: Forward results. 47 | """ 48 | x = self.upsample_conv(x) 49 | x = F.pixel_shuffle(x, self.scale_factor) 50 | return x 51 | -------------------------------------------------------------------------------- /mmedit/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv import build_from_cfg 4 | 5 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS 6 | 7 | 8 | def build(cfg, registry, default_args=None): 9 | """Build module function. 10 | 11 | Args: 12 | cfg (dict): Configuration for building modules. 13 | registry (obj): ``registry`` object. 14 | default_args (dict, optional): Default arguments. Defaults to None. 15 | """ 16 | if isinstance(cfg, list): 17 | modules = [ 18 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 19 | ] 20 | return nn.Sequential(*modules) 21 | 22 | return build_from_cfg(cfg, registry, default_args) 23 | 24 | 25 | def build_backbone(cfg): 26 | """Build backbone. 27 | 28 | Args: 29 | cfg (dict): Configuration for building backbone. 30 | """ 31 | return build(cfg, BACKBONES) 32 | 33 | 34 | def build_component(cfg): 35 | """Build component. 36 | 37 | Args: 38 | cfg (dict): Configuration for building component. 39 | """ 40 | return build(cfg, COMPONENTS) 41 | 42 | 43 | def build_loss(cfg): 44 | """Build loss. 45 | 46 | Args: 47 | cfg (dict): Configuration for building loss. 48 | """ 49 | return build(cfg, LOSSES) 50 | 51 | 52 | def build_model(cfg, train_cfg=None, test_cfg=None): 53 | """Build model. 54 | 55 | Args: 56 | cfg (dict): Configuration for building model. 57 | train_cfg (dict): Training configuration. Default: None. 58 | test_cfg (dict): Testing configuration. Default: None. 59 | """ 60 | return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 61 | -------------------------------------------------------------------------------- /mmedit/models/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .aspp import ASPP 3 | from .contextual_attention import ContextualAttentionModule 4 | from .conv import * # noqa: F401, F403 5 | from .downsample import pixel_unshuffle 6 | from .ensemble import SpatialTemporalEnsemble 7 | from .flow_warp import flow_warp 8 | from .gated_conv_module import SimpleGatedConvModule 9 | from .gca_module import GCAModule 10 | from .generation_model_utils import (GANImageBuffer, ResidualBlockWithDropout, 11 | UnetSkipConnectionBlock, 12 | generation_init_weights) 13 | from .img_normalize import ImgNormalize 14 | from .linear_module import LinearModule 15 | from .mask_conv_module import MaskConvModule 16 | from .model_utils import (extract_around_bbox, extract_bbox_patch, scale_bbox, 17 | set_requires_grad) 18 | from .partial_conv import PartialConv2d 19 | from .separable_conv_module import DepthwiseSeparableConvModule 20 | from .sr_backbone_utils import (ResidualBlockNoBN, default_init_weights, 21 | make_layer) 22 | from .upsample import PixelShufflePack 23 | 24 | __all__ = [ 25 | 'ASPP', 'PartialConv2d', 'PixelShufflePack', 'default_init_weights', 26 | 'ResidualBlockNoBN', 'make_layer', 'MaskConvModule', 'extract_bbox_patch', 27 | 'extract_around_bbox', 'set_requires_grad', 'scale_bbox', 28 | 'DepthwiseSeparableConvModule', 'ContextualAttentionModule', 'GCAModule', 29 | 'SimpleGatedConvModule', 'LinearModule', 'flow_warp', 'ImgNormalize', 30 | 'generation_init_weights', 'GANImageBuffer', 'UnetSkipConnectionBlock', 31 | 'ResidualBlockWithDropout', 'pixel_unshuffle', 'SpatialTemporalEnsemble' 32 | ] 33 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections.abc import Sequence 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..registry import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose: 11 | """Compose a data pipeline with a sequence of transforms. 12 | 13 | Args: 14 | transforms (list[dict | callable]): 15 | Either config dicts of transforms or transform objects. 16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError(f'transform must be callable or a dict, ' 29 | f'but got {type(transform)}') 30 | 31 | def __call__(self, data): 32 | """Call function. 33 | 34 | Args: 35 | data (dict): A dict containing the necessary information and 36 | data for augmentation. 37 | 38 | Returns: 39 | dict: A dict containing the processed data and information. 40 | """ 41 | for t in self.transforms: 42 | data = t(data) 43 | if data is None: 44 | return None 45 | return data 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.transforms: 50 | format_string += '\n' 51 | format_string += f' {t}' 52 | format_string += '\n)' 53 | return format_string 54 | -------------------------------------------------------------------------------- /mmedit/apis/inpainting_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from mmcv.parallel import collate, scatter 4 | 5 | from mmedit.datasets.pipelines import Compose 6 | 7 | 8 | def inpainting_inference(model, masked_img, mask): 9 | """Inference image with the model. 10 | 11 | Args: 12 | model (nn.Module): The loaded model. 13 | masked_img (str): File path of image with mask. 14 | mask (str): Mask file path. 15 | 16 | Returns: 17 | Tensor: The predicted inpainting result. 18 | """ 19 | device = next(model.parameters()).device # model device 20 | 21 | infer_pipeline = [ 22 | dict(type='LoadImageFromFile', key='masked_img'), 23 | dict(type='LoadMask', mask_mode='file', mask_config=dict()), 24 | dict(type='Pad', keys=['masked_img', 'mask'], mode='reflect'), 25 | dict( 26 | type='Normalize', 27 | keys=['masked_img'], 28 | mean=[127.5] * 3, 29 | std=[127.5] * 3, 30 | to_rgb=False), 31 | dict(type='GetMaskedImage', img_name='masked_img'), 32 | dict( 33 | type='Collect', 34 | keys=['masked_img', 'mask'], 35 | meta_keys=['masked_img_path']), 36 | dict(type='ImageToTensor', keys=['masked_img', 'mask']) 37 | ] 38 | 39 | # build the data pipeline 40 | test_pipeline = Compose(infer_pipeline) 41 | # prepare data 42 | data = dict(masked_img_path=masked_img, mask_path=mask) 43 | data = test_pipeline(data) 44 | data = collate([data], samples_per_gpu=1) 45 | if 'cuda' in str(device): 46 | data = scatter(data, [device])[0] 47 | else: 48 | data.pop('meta') 49 | # forward the model 50 | with torch.no_grad(): 51 | result = model(test_mode=True, **data) 52 | 53 | return result['fake_img'] 54 | -------------------------------------------------------------------------------- /mmedit/apis/restoration_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.parallel import collate, scatter 4 | 5 | from mmedit.datasets.pipelines import Compose 6 | 7 | 8 | def restoration_inference(model, img, ref=None): 9 | """Inference image with the model. 10 | 11 | Args: 12 | model (nn.Module): The loaded model. 13 | img (str): File path of input image. 14 | ref (str | None): File path of reference image. Default: None. 15 | 16 | Returns: 17 | Tensor: The predicted restoration result. 
18 | """ 19 | cfg = model.cfg 20 | device = next(model.parameters()).device # model device 21 | # remove gt from test_pipeline 22 | keys_to_remove = ['gt', 'gt_path'] 23 | for key in keys_to_remove: 24 | for pipeline in list(cfg.test_pipeline): 25 | if 'key' in pipeline and key == pipeline['key']: 26 | cfg.test_pipeline.remove(pipeline) 27 | if 'keys' in pipeline and key in pipeline['keys']: 28 | pipeline['keys'].remove(key) 29 | if len(pipeline['keys']) == 0: 30 | cfg.test_pipeline.remove(pipeline) 31 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']: 32 | pipeline['meta_keys'].remove(key) 33 | # build the data pipeline 34 | test_pipeline = Compose(cfg.test_pipeline) 35 | # prepare data 36 | if ref: # Ref-SR 37 | data = dict(lq_path=img, ref_path=ref) 38 | else: # SISR 39 | data = dict(lq_path=img) 40 | data = test_pipeline(data) 41 | data = collate([data], samples_per_gpu=1) 42 | if 'cuda' in str(device): 43 | data = scatter(data, [device])[0] 44 | # forward the model 45 | with torch.no_grad(): 46 | result = model(test_mode=True, **data) 47 | 48 | return result['output'] 49 | -------------------------------------------------------------------------------- /mmedit/core/optimizer/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.runner import build_optimizer 3 | 4 | 5 | def build_optimizers(model, cfgs): 6 | """Build multiple optimizers from configs. 7 | 8 | If `cfgs` contains several dicts for optimizers, then a dict for each 9 | constructed optimizers will be returned. 10 | If `cfgs` only contains one optimizer config, the constructed optimizer 11 | itself will be returned. 12 | 13 | For example, 14 | 15 | 1) Multiple optimizer configs: 16 | 17 | .. code-block:: python 18 | 19 | optimizer_cfg = dict( 20 | model1=dict(type='SGD', lr=lr), 21 | model2=dict(type='SGD', lr=lr)) 22 | 23 | The return dict is 24 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` 25 | 26 | 2) Single optimizer config: 27 | 28 | .. code-block:: python 29 | 30 | optimizer_cfg = dict(type='SGD', lr=lr) 31 | 32 | The return is ``torch.optim.Optimizer``. 33 | 34 | Args: 35 | model (:obj:`nn.Module`): The model with parameters to be optimized. 36 | cfgs (dict): The config dict of the optimizer. 37 | 38 | Returns: 39 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: 40 | The initialized optimizers. 41 | """ 42 | optimizers = {} 43 | if hasattr(model, 'module'): 44 | model = model.module 45 | # determine whether 'cfgs' has several dicts for optimizers 46 | is_dict_of_dict = True 47 | for key, cfg in cfgs.items(): 48 | if not isinstance(cfg, dict): 49 | is_dict_of_dict = False 50 | 51 | if is_dict_of_dict: 52 | for key, cfg in cfgs.items(): 53 | cfg_ = cfg.copy() 54 | module = getattr(model, key) 55 | optimizers[key] = build_optimizer(module, cfg_) 56 | return optimizers 57 | 58 | return build_optimizer(model, cfgs) 59 | -------------------------------------------------------------------------------- /mmedit/models/common/flow_warp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def flow_warp(x, 7 | flow, 8 | interpolation='bilinear', 9 | padding_mode='zeros', 10 | align_corners=True): 11 | """Warp an image or a feature map with optical flow. 12 | 13 | Args: 14 | x (Tensor): Tensor with size (n, c, h, w). 
15 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 16 | a two-channel, denoting the width and height relative offsets. 17 | Note that the values are not normalized to [-1, 1]. 18 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 19 | Default: 'bilinear'. 20 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 21 | Default: 'zeros'. 22 | align_corners (bool): Whether align corners. Default: True. 23 | 24 | Returns: 25 | Tensor: Warped image or feature map. 26 | """ 27 | if x.size()[-2:] != flow.size()[1:3]: 28 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 29 | f'flow ({flow.size()[1:3]}) are not the same.') 30 | _, _, h, w = x.size() 31 | # create mesh grid 32 | device = flow.device 33 | grid_y, grid_x = torch.meshgrid( 34 | torch.arange(0, h, device=device, dtype=x.dtype), 35 | torch.arange(0, w, device=device, dtype=x.dtype)) 36 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2 37 | grid.requires_grad = False 38 | 39 | grid_flow = grid + flow 40 | # scale grid_flow to [-1,1] 41 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 42 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 43 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 44 | output = F.grid_sample( 45 | x, 46 | grid_flow, 47 | mode=interpolation, 48 | padding_mode=padding_mode, 49 | align_corners=align_corners) 50 | return output 51 | -------------------------------------------------------------------------------- /mmedit/models/losses/gradient_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..registry import LOSSES 7 | from .pixelwise_loss import l1_loss 8 | 9 | _reduction_modes = ['none', 'mean', 'sum'] 10 | 11 | 12 | @LOSSES.register_module() 13 | class GradientLoss(nn.Module): 14 | """Gradient loss. 15 | 16 | Args: 17 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 18 | reduction (str): Specifies the reduction to apply to the output. 19 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 20 | """ 21 | 22 | def __init__(self, loss_weight=1.0, reduction='mean'): 23 | super().__init__() 24 | self.loss_weight = loss_weight 25 | self.reduction = reduction 26 | if self.reduction not in ['none', 'mean', 'sum']: 27 | raise ValueError(f'Unsupported reduction mode: {self.reduction}. ' 28 | f'Supported ones are: {_reduction_modes}') 29 | 30 | def forward(self, pred, target, weight=None): 31 | """ 32 | Args: 33 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 34 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 35 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 36 | weights. Default: None. 
37 | """ 38 | kx = torch.Tensor([[1, 0, -1], [2, 0, -2], 39 | [1, 0, -1]]).view(1, 1, 3, 3).to(target) 40 | ky = torch.Tensor([[1, 2, 1], [0, 0, 0], 41 | [-1, -2, -1]]).view(1, 1, 3, 3).to(target) 42 | 43 | pred_grad_x = F.conv2d(pred, kx, padding=1) 44 | pred_grad_y = F.conv2d(pred, ky, padding=1) 45 | target_grad_x = F.conv2d(target, kx, padding=1) 46 | target_grad_y = F.conv2d(target, ky, padding=1) 47 | 48 | loss = ( 49 | l1_loss( 50 | pred_grad_x, target_grad_x, weight, reduction=self.reduction) + 51 | l1_loss( 52 | pred_grad_y, target_grad_y, weight, reduction=self.reduction)) 53 | return loss * self.loss_weight 54 | -------------------------------------------------------------------------------- /tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn.utils import get_model_complexity_info 6 | 7 | from mmedit.models import build_model 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Get the FLOPs of an editor') 12 | parser.add_argument('config', help='train config file path') 13 | parser.add_argument( 14 | '--shape', 15 | type=int, 16 | nargs='+', 17 | default=[250, 250], 18 | help='input image size') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | 25 | args = parse_args() 26 | 27 | if len(args.shape) == 1: 28 | input_shape = (3, args.shape[0], args.shape[0]) 29 | elif len(args.shape) == 2: 30 | input_shape = (3, ) + tuple(args.shape) 31 | elif len(args.shape) in [3, 4]: # 4 for video inputs (t, c, h, w) 32 | input_shape = tuple(args.shape) 33 | else: 34 | raise ValueError('invalid input shape') 35 | 36 | cfg = Config.fromfile(args.config) 37 | model = build_model( 38 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda() 39 | model.eval() 40 | 41 | if hasattr(model, 'forward_dummy'): 42 | model.forward = model.forward_dummy 43 | else: 44 | raise NotImplementedError( 45 | 'FLOPs counter is currently not supported ' 46 | f'with {model.__class__.__name__}') 47 | 48 | flops, params = get_model_complexity_info(model, input_shape) 49 | 50 | split_line = '=' * 30 51 | print(f'{split_line}\nInput shape: {input_shape}\n' 52 | f'Flops: {flops}\nParams: {params}\n{split_line}') 53 | if len(input_shape) == 4: 54 | print('!!!If your network computes N frames in one forward pass, you ' 55 | 'may want to divide the FLOPs by N to get the average FLOPs ' 56 | 'for each frame.') 57 | print('!!!Please be cautious if you use the results in papers. ' 58 | 'You may need to check if all ops are supported and verify that the ' 59 | 'flops computation is correct.') 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/*.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/en/_build/ 69 | docs/en/_tmp/ 70 | docs/zh_cn/_build/ 71 | docs/zh_cn/_tmp/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | # custom 111 | .vscode 112 | .idea 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | work_dirs/ 117 | /data/ 118 | /data 119 | mmedit/.mim 120 | 121 | # Pytorch 122 | *.pth 123 | 124 | # onnx and tensorrt 125 | *.onnx 126 | *.trt 127 | 128 | # local history 129 | .history/** 130 | 131 | # Pytorch Server 132 | *.mar 133 | 134 | # MacOS 135 | .DS_Store 136 | 137 | # by jqxu 138 | scripts/ 139 | -------------------------------------------------------------------------------- /mmedit/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | from abc import ABCMeta, abstractmethod 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from .pipelines import Compose 8 | 9 | 10 | class BaseDataset(Dataset, metaclass=ABCMeta): 11 | """Base class for datasets. 12 | 13 | All datasets should subclass it. 14 | All subclasses should overwrite: 15 | 16 | ``load_annotations``, supporting to load information and generate 17 | image lists. 18 | 19 | Args: 20 | pipeline (list[dict | callable]): A sequence of data transforms. 21 | test_mode (bool): If True, the dataset will work in test mode. 22 | Otherwise, in train mode. 23 | """ 24 | 25 | def __init__(self, pipeline, test_mode=False): 26 | super().__init__() 27 | self.test_mode = test_mode 28 | self.pipeline = Compose(pipeline) 29 | 30 | @abstractmethod 31 | def load_annotations(self): 32 | """Abstract function for loading annotation. 33 | 34 | All subclasses should overwrite this function 35 | """ 36 | 37 | def prepare_train_data(self, idx): 38 | """Prepare training data. 39 | 40 | Args: 41 | idx (int): Index of the training batch data. 42 | 43 | Returns: 44 | dict: Returned training batch. 45 | """ 46 | results = copy.deepcopy(self.data_infos[idx]) 47 | return self.pipeline(results) 48 | 49 | def prepare_test_data(self, idx): 50 | """Prepare testing data. 51 | 52 | Args: 53 | idx (int): Index for getting each testing batch. 54 | 55 | Returns: 56 | Tensor: Returned testing batch. 57 | """ 58 | results = copy.deepcopy(self.data_infos[idx]) 59 | return self.pipeline(results) 60 | 61 | def __len__(self): 62 | """Length of the dataset. 63 | 64 | Returns: 65 | int: Length of the dataset. 
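# Minimal sketch of a BaseDataset subclass; the folder layout and dict keys are
# assumptions for illustration (the repo's real datasets extend BaseDHDataset /
# BaseSRDataset instead).
import os
import os.path as osp

from mmedit.datasets.base_dataset import BaseDataset


class ToyPairedDataset(BaseDataset):
    """Pair every image in `data_root/lq` with the same name in `data_root/gt`."""

    def __init__(self, data_root, pipeline, test_mode=False):
        super().__init__(pipeline, test_mode)
        self.data_root = data_root
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        names = sorted(os.listdir(osp.join(self.data_root, 'lq')))
        return [
            dict(lq_path=osp.join(self.data_root, 'lq', name),
                 gt_path=osp.join(self.data_root, 'gt', name))
            for name in names
        ]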
66 | """ 67 | return len(self.data_infos) 68 | 69 | def __getitem__(self, idx): 70 | """Get item at each call. 71 | 72 | Args: 73 | idx (int): Index for getting each item. 74 | """ 75 | if self.test_mode: 76 | return self.prepare_test_data(idx) 77 | 78 | return self.prepare_train_data(idx) 79 | -------------------------------------------------------------------------------- /tools/deployment/mmedit_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import random 4 | import string 5 | from io import BytesIO 6 | 7 | import PIL.Image as Image 8 | import torch 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmedit.apis import init_model, restoration_inference 12 | from mmedit.core import tensor2img 13 | 14 | 15 | class MMEditHandler(BaseHandler): 16 | 17 | def initialize(self, context): 18 | print('MMEditHandler.initialize is called') 19 | properties = context.system_properties 20 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | self.device = torch.device(self.map_location + ':' + 22 | str(properties.get('gpu_id')) if torch.cuda. 23 | is_available() else self.map_location) 24 | self.manifest = context.manifest 25 | 26 | model_dir = properties.get('model_dir') 27 | serialized_file = self.manifest['model']['serializedFile'] 28 | checkpoint = os.path.join(model_dir, serialized_file) 29 | self.config_file = os.path.join(model_dir, 'config.py') 30 | 31 | self.model = init_model(self.config_file, checkpoint, self.device) 32 | self.initialized = True 33 | 34 | def preprocess(self, data, *args, **kwargs): 35 | body = data[0].get('data') or data[0].get('body') 36 | result = Image.open(BytesIO(body)) 37 | # data preprocess is in inference. 38 | return result 39 | 40 | def inference(self, data, *args, **kwargs): 41 | # generate temp image path for restoration_inference 42 | temp_name = ''.join( 43 | random.sample(string.ascii_letters + string.digits, 18)) 44 | temp_path = f'./{temp_name}.png' 45 | data.save(temp_path) 46 | results = restoration_inference(self.model, temp_path) 47 | # delete the temp image path 48 | os.remove(temp_path) 49 | return results 50 | 51 | def postprocess(self, data): 52 | # convert torch tensor to numpy and then convert to bytes 53 | output_list = [] 54 | for data_ in data: 55 | data_np = tensor2img(data_) 56 | data_byte = data_np.tobytes() 57 | output_list.append(data_byte) 58 | 59 | return output_list 60 | -------------------------------------------------------------------------------- /mmedit/utils/setup_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | import warnings 5 | 6 | import cv2 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def setup_multi_processes(cfg): 11 | """Setup multi-processing environment variables.""" 12 | # set multi-process start method as `fork` to speed up the training 13 | if platform.system() != 'Windows': 14 | mp_start_method = cfg.get('mp_start_method', 'fork') 15 | current_method = mp.get_start_method(allow_none=True) 16 | if current_method is not None and current_method != mp_start_method: 17 | warnings.warn( 18 | f'Multi-processing start method `{mp_start_method}` is ' 19 | f'different from the previous setting `{current_method}`.' 20 | f'It will be force set to `{mp_start_method}`. 
You can change ' 21 | f'this behavior by changing `mp_start_method` in your config.') 22 | mp.set_start_method(mp_start_method, force=True) 23 | 24 | # disable opencv multithreading to avoid system being overloaded 25 | opencv_num_threads = cfg.get('opencv_num_threads', 0) 26 | cv2.setNumThreads(opencv_num_threads) 27 | 28 | # setup OMP threads 29 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa 30 | if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: 31 | omp_num_threads = 1 32 | warnings.warn( 33 | f'Setting OMP_NUM_THREADS environment variable for each process ' 34 | f'to be {omp_num_threads} in default, to avoid your system being ' 35 | f'overloaded, please further tune the variable for optimal ' 36 | f'performance in your application as needed.') 37 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) 38 | 39 | # setup MKL threads 40 | if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: 41 | mkl_num_threads = 1 42 | warnings.warn( 43 | f'Setting MKL_NUM_THREADS environment variable for each process ' 44 | f'to be {mkl_num_threads} in default, to avoid your system being ' 45 | f'overloaded, please further tune the variable for optimal ' 46 | f'performance in your application as needed.') 47 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) 48 | -------------------------------------------------------------------------------- /mmedit/apis/generation_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | 6 | from mmedit.core import tensor2img 7 | from mmedit.datasets.pipelines import Compose 8 | 9 | 10 | def generation_inference(model, img, img_unpaired=None): 11 | """Inference image with the model. 12 | 13 | Args: 14 | model (nn.Module): The loaded model. 15 | img (str): File path of input image. 16 | img_unpaired (str, optional): File path of the unpaired image. 17 | If not None, perform unpaired image generation. Default: None. 18 | 19 | Returns: 20 | np.ndarray: The predicted generation result. 
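# Sketch of the config fields `setup_multi_processes` reads (values here are
# illustrative; the function is normally called near the top of the train/test
# scripts before building dataloaders).
import mmcv
from mmedit.utils.setup_env import setup_multi_processes

cfg = mmcv.Config(
    dict(
        mp_start_method='fork',         # multiprocessing start method
        opencv_num_threads=0,           # disable OpenCV threading
        data=dict(workers_per_gpu=4)))  # >1 triggers the OMP/MKL thread capping
setup_multi_processes(cfg)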
21 | """ 22 | cfg = model.cfg 23 | device = next(model.parameters()).device # model device 24 | # build the data pipeline 25 | test_pipeline = Compose(cfg.test_pipeline) 26 | # prepare data 27 | if img_unpaired is None: 28 | data = dict(pair_path=img) 29 | else: 30 | data = dict(img_a_path=img, img_b_path=img_unpaired) 31 | data = test_pipeline(data) 32 | data = collate([data], samples_per_gpu=1) 33 | if 'cuda' in str(device): 34 | data = scatter(data, [device])[0] 35 | # forward the model 36 | with torch.no_grad(): 37 | results = model(test_mode=True, **data) 38 | # process generation shown mode 39 | if img_unpaired is None: 40 | if model.show_input: 41 | output = np.concatenate([ 42 | tensor2img(results['real_a'], min_max=(-1, 1)), 43 | tensor2img(results['fake_b'], min_max=(-1, 1)), 44 | tensor2img(results['real_b'], min_max=(-1, 1)) 45 | ], 46 | axis=1) 47 | else: 48 | output = tensor2img(results['fake_b'], min_max=(-1, 1)) 49 | else: 50 | if model.show_input: 51 | output = np.concatenate([ 52 | tensor2img(results['real_a'], min_max=(-1, 1)), 53 | tensor2img(results['fake_b'], min_max=(-1, 1)), 54 | tensor2img(results['real_b'], min_max=(-1, 1)), 55 | tensor2img(results['fake_a'], min_max=(-1, 1)) 56 | ], 57 | axis=1) 58 | else: 59 | if model.test_direction == 'a2b': 60 | output = tensor2img(results['fake_b'], min_max=(-1, 1)) 61 | else: 62 | output = tensor2img(results['fake_a'], min_max=(-1, 1)) 63 | return output 64 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def gaussian(x, sigma): 7 | """Gaussian function. 8 | 9 | Args: 10 | x (array_like): The independent variable. 11 | sigma (float): Standard deviation of the gaussian function. 12 | 13 | Return: 14 | ndarray or scalar: Gaussian value of `x`. 15 | """ 16 | return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) 17 | 18 | 19 | def dgaussian(x, sigma): 20 | """Gradient of gaussian. 21 | 22 | Args: 23 | x (array_like): The independent variable. 24 | sigma (float): Standard deviation of the gaussian function. 25 | 26 | Return: 27 | ndarray or scalar: Gradient of gaussian of `x`. 28 | """ 29 | return -x * gaussian(x, sigma) / sigma**2 30 | 31 | 32 | def gauss_filter(sigma, epsilon=1e-2): 33 | """Gradient of gaussian. 34 | 35 | Args: 36 | sigma (float): Standard deviation of the gaussian kernel. 37 | epsilon (float): Small value used when calculating kernel size. 38 | Default: 1e-2. 39 | 40 | Return: 41 | tuple[ndarray]: Gaussian filter along x and y axis. 42 | """ 43 | half_size = np.ceil( 44 | sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) 45 | size = int(2 * half_size + 1) 46 | 47 | # create filter in x axis 48 | filter_x = np.zeros((size, size)) 49 | for i in range(size): 50 | for j in range(size): 51 | filter_x[i, j] = gaussian(i - half_size, sigma) * dgaussian( 52 | j - half_size, sigma) 53 | 54 | # normalize filter 55 | norm = np.sqrt((filter_x**2).sum()) 56 | filter_x = filter_x / norm 57 | filter_y = np.transpose(filter_x) 58 | 59 | return filter_x, filter_y 60 | 61 | 62 | def gauss_gradient(img, sigma): 63 | """Gaussian gradient. 
64 | 65 | From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ 66 | submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ 67 | index.html 68 | 69 | Args: 70 | img (ndarray): Input image. 71 | sigma (float): Standard deviation of the gaussian kernel. 72 | 73 | Return: 74 | ndarray: Gaussian gradient of input `img`. 75 | """ 76 | filter_x, filter_y = gauss_filter(sigma) 77 | img_filtered_x = cv2.filter2D( 78 | img, -1, filter_x, borderType=cv2.BORDER_REPLICATE) 79 | img_filtered_y = cv2.filter2D( 80 | img, -1, filter_y, borderType=cv2.BORDER_REPLICATE) 81 | return np.sqrt(img_filtered_x**2 + img_filtered_y**2) 82 | -------------------------------------------------------------------------------- /mmedit/models/common/gated_conv_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | from mmcv.cnn import ConvModule, build_activation_layer 7 | 8 | 9 | class SimpleGatedConvModule(nn.Module): 10 | """Simple Gated Convolutional Module. 11 | 12 | This module is a simple gated convolutional module. The detailed formula 13 | is: 14 | 15 | .. math:: 16 | y = \\phi(conv1(x)) * \\sigma(conv2(x)), 17 | 18 | where `phi` is the feature activation function and `sigma` is the gate 19 | activation function. In default, the gate activation function is sigmoid. 20 | 21 | Args: 22 | in_channels (int): Same as nn.Conv2d. 23 | out_channels (int): The number of channels of the output feature. Note 24 | that `out_channels` in the conv module is doubled since this module 25 | contains two convolutions for feature and gate separately. 26 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 27 | feat_act_cfg (dict): Config dict for feature activation layer. 28 | gate_act_cfg (dict): Config dict for gate activation layer. 29 | kwargs (keyword arguments): Same as `ConvModule`. 30 | """ 31 | 32 | def __init__(self, 33 | in_channels, 34 | out_channels, 35 | kernel_size, 36 | feat_act_cfg=dict(type='ELU'), 37 | gate_act_cfg=dict(type='Sigmoid'), 38 | **kwargs): 39 | super().__init__() 40 | # the activation function should specified outside conv module 41 | kwargs_ = copy.deepcopy(kwargs) 42 | kwargs_['act_cfg'] = None 43 | self.with_feat_act = feat_act_cfg is not None 44 | self.with_gate_act = gate_act_cfg is not None 45 | 46 | self.conv = ConvModule(in_channels, out_channels * 2, kernel_size, 47 | **kwargs_) 48 | 49 | if self.with_feat_act: 50 | self.feat_act = build_activation_layer(feat_act_cfg) 51 | 52 | if self.with_gate_act: 53 | self.gate_act = build_activation_layer(gate_act_cfg) 54 | 55 | def forward(self, x): 56 | """Forward Function. 57 | 58 | Args: 59 | x (torch.Tensor): Input tensor with shape of (n, c, h, w). 60 | 61 | Returns: 62 | torch.Tensor: Output tensor with shape of (n, c, h', w'). 63 | """ 64 | x = self.conv(x) 65 | x, gate = torch.split(x, x.size(1) // 2, dim=1) 66 | if self.with_feat_act: 67 | x = self.feat_act(x) 68 | if self.with_gate_act: 69 | gate = self.gate_act(gate) 70 | x = x * gate 71 | 72 | return x 73 | -------------------------------------------------------------------------------- /mmedit/apis/matting_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
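# Shape-check sketch for SimpleGatedConvModule (channel counts are illustrative).
import torch
from mmedit.models.common.gated_conv_module import SimpleGatedConvModule

gated_conv = SimpleGatedConvModule(64, 32, 3, padding=1)
x = torch.rand(2, 64, 16, 16)
out = gated_conv(x)   # (2, 32, 16, 16): ELU(feature branch) * sigmoid(gate branch)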
2 | import mmcv 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | from mmcv.runner import load_checkpoint 6 | 7 | from mmedit.datasets.pipelines import Compose 8 | from mmedit.models import build_model 9 | 10 | 11 | def init_model(config, checkpoint=None, device='cuda:0'): 12 | """Initialize a model from config file. 13 | 14 | Args: 15 | config (str or :obj:`mmcv.Config`): Config file path or the config 16 | object. 17 | checkpoint (str, optional): Checkpoint path. If left as None, the model 18 | will not load any weights. 19 | device (str): Which device the model will deploy. Default: 'cuda:0'. 20 | 21 | Returns: 22 | nn.Module: The constructed model. 23 | """ 24 | if isinstance(config, str): 25 | config = mmcv.Config.fromfile(config) 26 | elif not isinstance(config, mmcv.Config): 27 | raise TypeError('config must be a filename or Config object, ' 28 | f'but got {type(config)}') 29 | config.model.pretrained = None 30 | config.test_cfg.metrics = None 31 | model = build_model(config.model, test_cfg=config.test_cfg) 32 | if checkpoint is not None: 33 | checkpoint = load_checkpoint(model, checkpoint) 34 | 35 | model.cfg = config # save the config in the model for convenience 36 | model.to(device) 37 | model.eval() 38 | return model 39 | 40 | 41 | def matting_inference(model, img, trimap): 42 | """Inference image(s) with the model. 43 | 44 | Args: 45 | model (nn.Module): The loaded model. 46 | img (str): Image file path. 47 | trimap (str): Trimap file path. 48 | 49 | Returns: 50 | np.ndarray: The predicted alpha matte. 51 | """ 52 | cfg = model.cfg 53 | device = next(model.parameters()).device # model device 54 | # remove alpha from test_pipeline 55 | keys_to_remove = ['alpha', 'ori_alpha'] 56 | for key in keys_to_remove: 57 | for pipeline in list(cfg.test_pipeline): 58 | if 'key' in pipeline and key == pipeline['key']: 59 | cfg.test_pipeline.remove(pipeline) 60 | if 'keys' in pipeline and key in pipeline['keys']: 61 | pipeline['keys'].remove(key) 62 | if len(pipeline['keys']) == 0: 63 | cfg.test_pipeline.remove(pipeline) 64 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']: 65 | pipeline['meta_keys'].remove(key) 66 | # build the data pipeline 67 | test_pipeline = Compose(cfg.test_pipeline) 68 | # prepare data 69 | data = dict(merged_path=img, trimap_path=trimap) 70 | data = test_pipeline(data) 71 | data = collate([data], samples_per_gpu=1) 72 | if 'cuda' in str(device): 73 | data = scatter(data, [device])[0] 74 | # forward the model 75 | with torch.no_grad(): 76 | result = model(test_mode=True, **data) 77 | 78 | return result['pred_alpha'] 79 | -------------------------------------------------------------------------------- /mmedit/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import functools 3 | import logging 4 | import re 5 | import textwrap 6 | from typing import Callable 7 | 8 | from mmcv import print_log 9 | 10 | 11 | def deprecated_function(since: str, removed_in: str, 12 | instructions: str) -> Callable: 13 | """Marks functions as deprecated. 14 | 15 | Throw a warning when a deprecated function is called, and add a note in the 16 | docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py 17 | Args: 18 | since (str): The version when the function was first deprecated. 19 | removed_in (str): The version when the function will be removed. 20 | instructions (str): The action users should take. 
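# Usage sketch for the matting API; the config, checkpoint and image paths are
# placeholders, not files shipped with this repo.
from mmedit.apis.matting_inference import init_model, matting_inference

model = init_model('configs/mattors/dim/dim_stage3.py',  # hypothetical config
                   'checkpoints/dim_stage3.pth',          # hypothetical weights
                   device='cuda:0')
pred_alpha = matting_inference(model, 'demo/merged.png', 'demo/trimap.png')
# `pred_alpha` is an HxW alpha matte as a numpy array.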
21 | Returns: 22 | Callable: A new function, which will be deprecated soon. 23 | """ # noqa: E501 24 | 25 | def decorator(function): 26 | 27 | @functools.wraps(function) 28 | def wrapper(*args, **kwargs): 29 | print_log( 30 | f"'{function.__module__}.{function.__name__}' " 31 | f'is deprecated in version {since} and will be ' 32 | f'removed in version {removed_in}. Please {instructions}.', 33 | # logger='current', 34 | level=logging.WARNING, 35 | ) 36 | return function(*args, **kwargs) 37 | 38 | indent = ' ' 39 | # Add a deprecation note to the docstring. 40 | docstring = function.__doc__ or '' 41 | # Add a note to the docstring. 42 | deprecation_note = textwrap.dedent(f"""\ 43 | .. deprecated:: {since} 44 | Deprecated and will be removed in version {removed_in}. 45 | Please {instructions}. 46 | """) 47 | # Split docstring at first occurrence of newline 48 | pattern = '\n\n' 49 | summary_and_body = re.split(pattern, docstring, 1) 50 | 51 | if len(summary_and_body) > 1: 52 | summary, body = summary_and_body 53 | body = textwrap.indent(textwrap.dedent(body), indent) 54 | summary = '\n'.join( 55 | [textwrap.dedent(string) for string in summary.split('\n')]) 56 | summary = textwrap.indent(summary, prefix=indent) 57 | # Dedent the body. We cannot do this with the presence of the 58 | # summary because the body contains leading whitespaces when the 59 | # summary does not. 60 | new_docstring_parts = [ 61 | deprecation_note, '\n\n', summary, '\n\n', body 62 | ] 63 | else: 64 | summary = summary_and_body[0] 65 | summary = '\n'.join( 66 | [textwrap.dedent(string) for string in summary.split('\n')]) 67 | summary = textwrap.indent(summary, prefix=indent) 68 | new_docstring_parts = [deprecation_note, '\n\n', summary] 69 | 70 | wrapper.__doc__ = ''.join(new_docstring_parts) 71 | 72 | return wrapper 73 | 74 | return decorator 75 | -------------------------------------------------------------------------------- /mmedit/datasets/base_dh_dataset.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmediting/blob/master/mmedit/datasets/base_sr_dataset.py 2 | import copy 3 | import os.path as osp 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | from mmcv import scandir 8 | 9 | from .base_dataset import BaseDataset 10 | 11 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 12 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF') 13 | 14 | 15 | class BaseDHDataset(BaseDataset): 16 | """Base class for dehazing datasets.""" 17 | 18 | def __init__(self, pipeline, test_mode=False): 19 | super().__init__(pipeline, test_mode) 20 | 21 | @staticmethod 22 | def scan_folder(path): 23 | """Obtain image path list (including sub-folders) from a given folder. 24 | 25 | Args: 26 | path (str | :obj:`Path`): Folder path. 27 | 28 | Returns: 29 | list[str]: image list obtained form given folder. 30 | """ 31 | 32 | if isinstance(path, (str, Path)): 33 | path = str(path) 34 | else: 35 | raise TypeError("'path' must be a str or a Path object, " 36 | f'but received {type(path)}.') 37 | 38 | images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=False)) 39 | images = [osp.join(path, v) for v in images] 40 | images.sort() 41 | assert images, f'{path} has no valid image file.' 42 | return images 43 | 44 | def __getitem__(self, idx): 45 | """Get item at each call. 46 | 47 | Args: 48 | idx (int): Index for getting each item. 
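# Sketch of applying `deprecated_function`; `old_helper` is a made-up example.
from mmedit.utils import deprecated_function


@deprecated_function(
    since='0.16.0', removed_in='0.20.0', instructions="use 'new_helper'")
def old_helper(x):
    """Return the input unchanged."""
    return x


old_helper(1)  # logs a deprecation warning, then executes normally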
49 | """ 50 | results = copy.deepcopy(self.data_infos[idx]) 51 | results['scale'] = 1 52 | return self.pipeline(results) 53 | 54 | def evaluate(self, results, logger=None): 55 | """Evaluate with different metrics. 56 | 57 | Args: 58 | results (list[tuple]): The output of forward_test() of the model. 59 | 60 | Return: 61 | dict: Evaluation results dict. 62 | """ 63 | if not isinstance(results, list): 64 | raise TypeError(f'results must be a list, but got {type(results)}') 65 | assert len(results) == len(self), ( 66 | 'The length of results is not equal to the dataset len: ' 67 | f'{len(results)} != {len(self)}') 68 | 69 | results = [res['eval_result'] for res in results] # a list of dict 70 | eval_result = defaultdict(list) # a dict of list 71 | 72 | for res in results: 73 | for metric, val in res.items(): 74 | eval_result[metric].append(val) 75 | for metric, val_list in eval_result.items(): 76 | assert len(val_list) == len(self), ( 77 | f'Length of evaluation result of {metric} is {len(val_list)}, ' 78 | f'should be {len(self)}') 79 | 80 | # average the results 81 | eval_result = { 82 | metric: sum(values) / len(self) 83 | for metric, values in eval_result.items() 84 | } 85 | 86 | return eval_result 87 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .augmentation import (BinarizeImage, ColorJitter, CopyValues, Flip, 3 | GenerateFrameIndices, 4 | GenerateFrameIndiceswithPadding, 5 | GenerateSegmentIndices, MirrorSequence, Pad, 6 | Quantize, RandomAffine, RandomJitter, 7 | RandomMaskDilation, RandomTransposeHW, Resize, 8 | TemporalReverse, UnsharpMasking) 9 | from .augmentation_hazeworld import GenerateFileIndices 10 | from .compose import Compose 11 | from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown, 12 | CropLike, FixedCrop, ModCrop, PairedRandomCrop, 13 | RandomResizedCrop) 14 | from .crop_hazeworld import PairedRandomCropWithTransmission 15 | from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor, 16 | ToTensor) 17 | from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap 18 | from .loading import (GetSpatialDiscountMask, LoadImageFromFile, 19 | LoadImageFromFileList, LoadMask, LoadPairedImageFromFile, 20 | RandomLoadResizeBg) 21 | from .matlab_like_resize import MATLABLikeResize 22 | from .matting_aug import (CompositeFg, GenerateSeg, GenerateSoftSeg, 23 | GenerateTrimap, GenerateTrimapWithDistTransform, 24 | MergeFgAndBg, PerturbBg, TransformTrimap) 25 | from .normalization import Normalize, RescaleToZeroOne 26 | from .random_degradations import (DegradationsWithShuffle, RandomBlur, 27 | RandomJPEGCompression, RandomNoise, 28 | RandomResize, RandomVideoCompression) 29 | from .random_down_sampling import RandomDownSampling 30 | 31 | __all__ = [ 32 | 'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask', 33 | 'RandomLoadResizeBg', 'Compose', 'ImageToTensor', 'ToTensor', 34 | 'GetMaskedImage', 'BinarizeImage', 'Flip', 'Pad', 'RandomAffine', 35 | 'RandomJitter', 'ColorJitter', 'RandomMaskDilation', 'RandomTransposeHW', 36 | 'Resize', 'RandomResizedCrop', 'Crop', 'CropAroundCenter', 37 | 'CropAroundUnknown', 'ModCrop', 'PairedRandomCrop', 'Normalize', 38 | 'RescaleToZeroOne', 'GenerateTrimap', 'MergeFgAndBg', 'CompositeFg', 39 | 'TemporalReverse', 'LoadImageFromFileList', 'GenerateFrameIndices', 40 | 
'GenerateFrameIndiceswithPadding', 'FixedCrop', 'LoadPairedImageFromFile', 41 | 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg', 42 | 'GetSpatialDiscountMask', 'RandomDownSampling', 43 | 'GenerateTrimapWithDistTransform', 'TransformTrimap', 44 | 'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence', 45 | 'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues', 46 | 'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise', 47 | 'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking', 48 | 'RandomVideoCompression', 49 | 'GenerateFileIndices', 'PairedRandomCropWithTransmission' 50 | ] 51 | -------------------------------------------------------------------------------- /mmedit/core/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 10 | """Convert torch Tensors into image numpy arrays. 11 | 12 | After clamping to (min, max), image values will be normalized to [0, 1]. 13 | 14 | For different tensor shapes, this function will have different behaviors: 15 | 16 | 1. 4D mini-batch Tensor of shape (N x 3/1 x H x W): 17 | Use `make_grid` to stitch images in the batch dimension, and then 18 | convert it to numpy array. 19 | 2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W): 20 | Directly change to numpy array. 21 | 22 | Note that the image channel in input tensors should be RGB order. This 23 | function will convert it to cv2 convention, i.e., (H x W x C) with BGR 24 | order. 25 | 26 | Args: 27 | tensor (Tensor | list[Tensor]): Input tensors. 28 | out_type (numpy type): Output types. If ``np.uint8``, transform outputs 29 | to uint8 type with range [0, 255]; otherwise, float type with 30 | range [0, 1]. Default: ``np.uint8``. 31 | min_max (tuple): min and max values for clamp. 32 | 33 | Returns: 34 | (Tensor | list[Tensor]): 3D ndarray of shape (H x W x C) or 2D ndarray 35 | of shape (H x W). 36 | """ 37 | if not (torch.is_tensor(tensor) or 38 | (isinstance(tensor, list) 39 | and all(torch.is_tensor(t) for t in tensor))): 40 | raise TypeError( 41 | f'tensor or list of tensors expected, got {type(tensor)}') 42 | 43 | if torch.is_tensor(tensor): 44 | tensor = [tensor] 45 | result = [] 46 | for _tensor in tensor: 47 | # Squeeze two times so that: 48 | # 1. (1, 1, h, w) -> (h, w) or 49 | # 3. (1, 3, h, w) -> (3, h, w) or 50 | # 2. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w) 51 | _tensor = _tensor.squeeze(0).squeeze(0) 52 | _tensor = _tensor.float().detach().cpu().clamp_(*min_max) 53 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 54 | n_dim = _tensor.dim() 55 | if n_dim == 4: 56 | img_np = make_grid( 57 | _tensor, nrow=int(math.sqrt(_tensor.size(0))), 58 | normalize=False).numpy() 59 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) 60 | elif n_dim == 3: 61 | img_np = _tensor.numpy() 62 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) 63 | elif n_dim == 2: 64 | img_np = _tensor.numpy() 65 | else: 66 | raise ValueError('Only support 4D, 3D or 2D tensor. ' 67 | f'But received with dimension: {n_dim}') 68 | if out_type == np.uint8: 69 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
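# Sketch of the structure BaseDHDataset.evaluate expects from forward_test
# outputs (metric names and values are illustrative):
#   results = [dict(eval_result=dict(PSNR=28.1, SSIM=0.91)),
#              dict(eval_result=dict(PSNR=30.3, SSIM=0.94))]
#   dataset.evaluate(results)  # -> {'PSNR': 29.2, 'SSIM': 0.925}
# i.e. each per-sample metric dict is averaged over the whole dataset after a
# length check against len(dataset).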
70 | img_np = (img_np * 255.0).round() 71 | img_np = img_np.astype(out_type) 72 | result.append(img_np) 73 | result = result[0] if len(result) == 1 else result 74 | return result 75 | -------------------------------------------------------------------------------- /mmedit/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from __future__ import division 3 | import math 4 | 5 | import torch 6 | from torch.utils.data import DistributedSampler as _DistributedSampler 7 | 8 | from mmedit.core.utils import sync_random_seed 9 | 10 | 11 | class DistributedSampler(_DistributedSampler): 12 | """DistributedSampler inheriting from 13 | `torch.utils.data.DistributedSampler`. 14 | 15 | In pytorch of lower versions, there is no `shuffle` argument. This child 16 | class will port one to DistributedSampler. 17 | """ 18 | 19 | def __init__(self, 20 | dataset, 21 | num_replicas=None, 22 | rank=None, 23 | shuffle=True, 24 | samples_per_gpu=1, 25 | seed=0): 26 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 27 | self.shuffle = shuffle 28 | self.samples_per_gpu = samples_per_gpu 29 | # fix the bug of the official implementation 30 | self.num_samples_per_replica = int( 31 | math.ceil( 32 | len(self.dataset) * 1.0 / self.num_replicas / samples_per_gpu)) 33 | self.num_samples = self.num_samples_per_replica * self.samples_per_gpu 34 | self.total_size = self.num_samples * self.num_replicas 35 | 36 | # In distributed sampling, different ranks should sample 37 | # non-overlapped data in the dataset. Therefore, this function 38 | # is used to make sure that each rank shuffles the data indices 39 | # in the same order based on the same seed. Then different ranks 40 | # could use different indices to select non-overlapped data from the 41 | # same data list. 42 | self.seed = sync_random_seed(seed) 43 | 44 | # to avoid padding bug when meeting too small dataset 45 | if len(dataset) < self.num_replicas * samples_per_gpu: 46 | raise ValueError( 47 | 'You may use too small dataset and our distributed ' 48 | 'sampler cannot pad your dataset correctly. We highly ' 49 | 'recommend you to use fewer GPUs to finish your work') 50 | 51 | def __iter__(self): 52 | # deterministically shuffle based on epoch 53 | if self.shuffle: 54 | g = torch.Generator() 55 | # When :attr:`shuffle=True`, this ensures all replicas 56 | # use a different random ordering for each epoch. 57 | # Otherwise, the next iteration of this sampler will 58 | # yield the same ordering. 59 | g.manual_seed(self.epoch + self.seed) 60 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 61 | else: 62 | indices = torch.arange(len(self.dataset)).tolist() 63 | 64 | # add extra samples to make it evenly divisible 65 | indices += indices[:(self.total_size - len(indices))] 66 | assert len(indices) == self.total_size 67 | 68 | # subsample 69 | indices = indices[self.rank:self.total_size:self.num_replicas] 70 | assert len(indices) == self.num_samples 71 | 72 | return iter(indices) 73 | -------------------------------------------------------------------------------- /mmedit/models/common/sr_backbone_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
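# Minimal usage sketch for `tensor2img` (shape and value range illustrative).
import torch
from mmedit.core import tensor2img

fake = torch.rand(1, 3, 32, 32)   # RGB tensor in [0, 1]
img = tensor2img(fake)            # uint8 ndarray of shape (32, 32, 3), BGR order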
2 | import torch.nn as nn 3 | from mmcv.cnn import constant_init, kaiming_init 4 | from mmcv.utils.parrots_wrapper import _BatchNorm 5 | 6 | 7 | def default_init_weights(module, scale=1): 8 | """Initialize network weights. 9 | 10 | Args: 11 | modules (nn.Module): Modules to be initialized. 12 | scale (float): Scale initialized weights, especially for residual 13 | blocks. 14 | """ 15 | for m in module.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | kaiming_init(m, a=0, mode='fan_in', bias=0) 18 | m.weight.data *= scale 19 | elif isinstance(m, nn.Linear): 20 | kaiming_init(m, a=0, mode='fan_in', bias=0) 21 | m.weight.data *= scale 22 | elif isinstance(m, _BatchNorm): 23 | constant_init(m.weight, val=1, bias=0) 24 | 25 | 26 | def make_layer(block, num_blocks, **kwarg): 27 | """Make layers by stacking the same blocks. 28 | 29 | Args: 30 | block (nn.module): nn.module class for basic block. 31 | num_blocks (int): number of blocks. 32 | 33 | Returns: 34 | nn.Sequential: Stacked blocks in nn.Sequential. 35 | """ 36 | layers = [] 37 | for _ in range(num_blocks): 38 | layers.append(block(**kwarg)) 39 | return nn.Sequential(*layers) 40 | 41 | 42 | class ResidualBlockNoBN(nn.Module): 43 | """Residual block without BN. 44 | 45 | It has a style of: 46 | 47 | :: 48 | 49 | ---Conv-ReLU-Conv-+- 50 | |________________| 51 | 52 | Args: 53 | mid_channels (int): Channel number of intermediate features. 54 | Default: 64. 55 | res_scale (float): Used to scale the residual before addition. 56 | Default: 1.0. 57 | """ 58 | 59 | def __init__(self, mid_channels=64, res_scale=1.0, groups=1): 60 | super().__init__() 61 | self.res_scale = res_scale 62 | self.conv1 = nn.Conv2d( 63 | mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups) 64 | self.conv2 = nn.Conv2d( 65 | mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups) 66 | 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | # if res_scale < 1.0, use the default initialization, as in EDSR. 70 | # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet. 71 | if res_scale == 1.0: 72 | self.init_weights() 73 | 74 | def init_weights(self): 75 | """Initialize weights for ResidualBlockNoBN. 76 | 77 | Initialization methods like `kaiming_init` are for VGG-style modules. 78 | For modules with residual paths, using smaller std is better for 79 | stability and performance. We empirically use 0.1. See more details in 80 | "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" 81 | """ 82 | 83 | for m in [self.conv1, self.conv2]: 84 | default_init_weights(m, 0.1) 85 | 86 | def forward(self, x): 87 | """Forward function. 88 | 89 | Args: 90 | x (Tensor): Input tensor with shape (n, c, h, w). 91 | 92 | Returns: 93 | Tensor: Forward results. 94 | """ 95 | 96 | identity = x 97 | out = self.conv2(self.relu(self.conv1(x))) 98 | return identity + out * self.res_scale 99 | -------------------------------------------------------------------------------- /mmedit/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class BaseModel(nn.Module, metaclass=ABCMeta): 10 | """Base model. 11 | 12 | All models should subclass it. 13 | All subclass should overwrite: 14 | 15 | ``init_weights``, supporting to initialize models. 16 | 17 | ``forward_train``, supporting to forward when training. 
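# Sketch of stacking residual blocks with `make_layer` (channel count illustrative).
import torch
from mmedit.models.common.sr_backbone_utils import ResidualBlockNoBN, make_layer

trunk = make_layer(ResidualBlockNoBN, num_blocks=4, mid_channels=64)
feat = torch.rand(1, 64, 32, 32)
out = trunk(feat)   # same shape; each block adds a residual scaled by res_scale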
18 | 19 | ``forward_test``, supporting to forward when testing. 20 | 21 | ``train_step``, supporting to train one step when training. 22 | """ 23 | 24 | @abstractmethod 25 | def init_weights(self): 26 | """Abstract method for initializing weight. 27 | 28 | All subclass should overwrite it. 29 | """ 30 | 31 | @abstractmethod 32 | def forward_train(self, imgs, labels): 33 | """Abstract method for training forward. 34 | 35 | All subclass should overwrite it. 36 | """ 37 | 38 | @abstractmethod 39 | def forward_test(self, imgs): 40 | """Abstract method for testing forward. 41 | 42 | All subclass should overwrite it. 43 | """ 44 | 45 | def forward(self, imgs, labels, test_mode, **kwargs): 46 | """Forward function for base model. 47 | 48 | Args: 49 | imgs (Tensor): Input image(s). 50 | labels (Tensor): Ground-truth label(s). 51 | test_mode (bool): Whether in test mode. 52 | kwargs (dict): Other arguments. 53 | 54 | Returns: 55 | Tensor: Forward results. 56 | """ 57 | 58 | if test_mode: 59 | return self.forward_test(imgs, **kwargs) 60 | 61 | return self.forward_train(imgs, labels, **kwargs) 62 | 63 | @abstractmethod 64 | def train_step(self, data_batch, optimizer): 65 | """Abstract method for one training step. 66 | 67 | All subclass should overwrite it. 68 | """ 69 | 70 | def val_step(self, data_batch, **kwargs): 71 | """Abstract method for one validation step. 72 | 73 | All subclass should overwrite it. 74 | """ 75 | output = self.forward_test(**data_batch, **kwargs) 76 | return output 77 | 78 | def parse_losses(self, losses): 79 | """Parse losses dict for different loss variants. 80 | 81 | Args: 82 | losses (dict): Loss dict. 83 | 84 | Returns: 85 | loss (float): Sum of the total loss. 86 | log_vars (dict): loss dict for different variants. 87 | """ 88 | log_vars = OrderedDict() 89 | for loss_name, loss_value in losses.items(): 90 | if isinstance(loss_value, torch.Tensor): 91 | log_vars[loss_name] = loss_value.mean() 92 | elif isinstance(loss_value, list): 93 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 94 | else: 95 | raise TypeError( 96 | f'{loss_name} is not a tensor or list of tensors') 97 | 98 | loss = sum(_value for _key, _value in log_vars.items() 99 | if 'loss' in _key) 100 | 101 | log_vars['loss'] = loss 102 | for name in log_vars: 103 | log_vars[name] = log_vars[name].item() 104 | 105 | return loss, log_vars 106 | -------------------------------------------------------------------------------- /mmedit/apis/restoration_face_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | from mmcv.parallel import collate, scatter 5 | 6 | from mmedit.datasets.pipelines import Compose 7 | 8 | try: 9 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 10 | has_facexlib = True 11 | except ImportError: 12 | has_facexlib = False 13 | 14 | 15 | def restoration_face_inference(model, img, upscale_factor=1, face_size=1024): 16 | """Inference image with the model. 17 | 18 | Args: 19 | model (nn.Module): The loaded model. 20 | img (str): File path of input image. 21 | upscale_factor (int, optional): The number of times the input image 22 | is upsampled. Default: 1. 23 | face_size (int, optional): The size of the cropped and aligned faces. 24 | Default: 1024. 25 | 26 | Returns: 27 | Tensor: The predicted restoration result. 
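# Sketch of the loss dict `parse_losses` consumes (names and values illustrative):
#   losses = dict(loss_pix=torch.tensor(0.5), loss_grad=torch.tensor(0.1))
#   loss, log_vars = model.parse_losses(losses)
# `loss` sums every entry whose key contains 'loss' (0.6 here) and `log_vars`
# holds python scalars for logging, including the total under the key 'loss'.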
28 | """ 29 | device = next(model.parameters()).device # model device 30 | 31 | # build the data pipeline 32 | if model.cfg.get('demo_pipeline', None): 33 | test_pipeline = model.cfg.demo_pipeline 34 | elif model.cfg.get('test_pipeline', None): 35 | test_pipeline = model.cfg.test_pipeline 36 | else: 37 | test_pipeline = model.cfg.val_pipeline 38 | 39 | # remove gt from test_pipeline 40 | keys_to_remove = ['gt', 'gt_path'] 41 | for key in keys_to_remove: 42 | for pipeline in list(test_pipeline): 43 | if 'key' in pipeline and key == pipeline['key']: 44 | test_pipeline.remove(pipeline) 45 | if 'keys' in pipeline and key in pipeline['keys']: 46 | pipeline['keys'].remove(key) 47 | if len(pipeline['keys']) == 0: 48 | test_pipeline.remove(pipeline) 49 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']: 50 | pipeline['meta_keys'].remove(key) 51 | # build the data pipeline 52 | test_pipeline = Compose(test_pipeline) 53 | 54 | # face helper for detecting and aligning faces 55 | assert has_facexlib, 'Please install FaceXLib to use the demo.' 56 | face_helper = FaceRestoreHelper( 57 | upscale_factor, 58 | face_size=face_size, 59 | crop_ratio=(1, 1), 60 | det_model='retinaface_resnet50', 61 | template_3points=True, 62 | save_ext='png', 63 | device=device) 64 | 65 | face_helper.read_image(img) 66 | # get face landmarks for each face 67 | face_helper.get_face_landmarks_5( 68 | only_center_face=False, eye_dist_threshold=None) 69 | # align and warp each face 70 | face_helper.align_warp_face() 71 | 72 | for i, img in enumerate(face_helper.cropped_faces): 73 | # prepare data 74 | data = dict(lq=img.astype(np.float32)) 75 | data = test_pipeline(data) 76 | data = collate([data], samples_per_gpu=1) 77 | if 'cuda' in str(device): 78 | data = scatter(data, [device])[0] 79 | 80 | with torch.no_grad(): 81 | output = model(test_mode=True, **data)['output'] 82 | output = torch.clamp(output, min=0, max=1) 83 | 84 | output = output.squeeze(0).permute(1, 2, 0)[:, :, [2, 1, 0]] 85 | output = output.cpu().numpy() * 255 # (0, 255) 86 | face_helper.add_restored_face(output) 87 | 88 | face_helper.get_inverse_affine(None) 89 | restored_img = face_helper.paste_faces_to_input_image(upsample_img=None) 90 | 91 | return restored_img 92 | -------------------------------------------------------------------------------- /mmedit/models/common/linear_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import build_activation_layer, kaiming_init 4 | 5 | 6 | class LinearModule(nn.Module): 7 | """A linear block that contains linear/norm/activation layers. 8 | 9 | For low level vision, we add spectral norm and padding layer. 10 | 11 | Args: 12 | in_features (int): Same as nn.Linear. 13 | out_features (int): Same as nn.Linear. 14 | bias (bool): Same as nn.Linear. 15 | act_cfg (dict): Config dict for activation layer, "relu" by default. 16 | inplace (bool): Whether to use inplace mode for activation. 17 | with_spectral_norm (bool): Whether use spectral norm in linear module. 18 | order (tuple[str]): The order of linear/activation layers. It is a 19 | sequence of "linear", "norm" and "act". Examples are 20 | ("linear", "act") and ("act", "linear"). 
21 | """ 22 | 23 | def __init__(self, 24 | in_features, 25 | out_features, 26 | bias=True, 27 | act_cfg=dict(type='ReLU'), 28 | inplace=True, 29 | with_spectral_norm=False, 30 | order=('linear', 'act')): 31 | super().__init__() 32 | assert act_cfg is None or isinstance(act_cfg, dict) 33 | self.act_cfg = act_cfg 34 | self.inplace = inplace 35 | self.with_spectral_norm = with_spectral_norm 36 | self.order = order 37 | assert isinstance(self.order, tuple) and len(self.order) == 2 38 | assert set(order) == set(['linear', 'act']) 39 | 40 | self.with_activation = act_cfg is not None 41 | self.with_bias = bias 42 | 43 | # build linear layer 44 | self.linear = nn.Linear(in_features, out_features, bias=bias) 45 | # export the attributes of self.linear to a higher level for 46 | # convenience 47 | self.in_features = self.linear.in_features 48 | self.out_features = self.linear.out_features 49 | 50 | if self.with_spectral_norm: 51 | self.linear = nn.utils.spectral_norm(self.linear) 52 | 53 | # build activation layer 54 | if self.with_activation: 55 | act_cfg_ = act_cfg.copy() 56 | act_cfg_.setdefault('inplace', inplace) 57 | self.activate = build_activation_layer(act_cfg_) 58 | 59 | # Use msra init by default 60 | self.init_weights() 61 | 62 | def init_weights(self): 63 | if self.with_activation and self.act_cfg['type'] == 'LeakyReLU': 64 | nonlinearity = 'leaky_relu' 65 | a = self.act_cfg.get('negative_slope', 0.01) 66 | else: 67 | nonlinearity = 'relu' 68 | a = 0 69 | 70 | kaiming_init(self.linear, a=a, nonlinearity=nonlinearity) 71 | 72 | def forward(self, x, activate=True): 73 | """Forward Function. 74 | 75 | Args: 76 | x (torch.Tensor): Input tensor with shape of :math:`(n, *, c)`. 77 | Same as ``torch.nn.Linear``. 78 | activate (bool, optional): Whether to use activation layer. 79 | Defaults to True. 80 | 81 | Returns: 82 | torch.Tensor: Same as ``torch.nn.Linear``. 83 | """ 84 | for layer in self.order: 85 | if layer == 'linear': 86 | x = self.linear(x) 87 | elif layer == 'act' and activate and self.with_activation: 88 | x = self.activate(x) 89 | return x 90 | -------------------------------------------------------------------------------- /mmedit/core/hooks/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmcv 5 | import torch 6 | from mmcv.runner import HOOKS, Hook 7 | from mmcv.runner.dist_utils import master_only 8 | from torchvision.utils import save_image 9 | 10 | from mmedit.utils import deprecated_function 11 | 12 | 13 | @HOOKS.register_module() 14 | class MMEditVisualizationHook(Hook): 15 | """Visualization hook. 16 | 17 | In this hook, we use the official api `save_image` in torchvision to save 18 | the visualization results. 19 | 20 | Args: 21 | output_dir (str): The file path to store visualizations. 22 | res_name_list (str): The list contains the name of results in outputs 23 | dict. The results in outputs dict must be a torch.Tensor with shape 24 | (n, c, h, w). 25 | interval (int): The interval of calling this hook. If set to -1, 26 | the visualization hook will not be called. Default: -1. 27 | filename_tmpl (str): Format string used to save images. The output file 28 | name will be formatted as this args. Default: 'iter_{}.png'. 29 | rerange (bool): Whether to rerange the output value from [-1, 1] to 30 | [0, 1]. We highly recommend users should preprocess the 31 | visualization results on their own. 
Here, we just provide a simple 32 | interface. Default: True. 33 | bgr2rgb (bool): Whether to reformat the channel dimension from BGR to 34 | RGB. The final image we will save is following RGB style. 35 | Default: True. 36 | nrow (int): The number of samples in a row. Default: 1. 37 | padding (int): The number of padding pixels between each samples. 38 | Default: 4. 39 | """ 40 | 41 | def __init__(self, 42 | output_dir, 43 | res_name_list, 44 | interval=-1, 45 | filename_tmpl='iter_{}.png', 46 | rerange=True, 47 | bgr2rgb=True, 48 | nrow=1, 49 | padding=4): 50 | assert mmcv.is_list_of(res_name_list, str) 51 | self.output_dir = output_dir 52 | self.res_name_list = res_name_list 53 | self.interval = interval 54 | self.filename_tmpl = filename_tmpl 55 | self.bgr2rgb = bgr2rgb 56 | self.rerange = rerange 57 | self.nrow = nrow 58 | self.padding = padding 59 | 60 | mmcv.mkdir_or_exist(self.output_dir) 61 | 62 | @master_only 63 | def after_train_iter(self, runner): 64 | """The behavior after each train iteration. 65 | 66 | Args: 67 | runner (object): The runner. 68 | """ 69 | if not self.every_n_iters(runner, self.interval): 70 | return 71 | results = runner.outputs['results'] 72 | 73 | filename = self.filename_tmpl.format(runner.iter + 1) 74 | 75 | img_list = [results[k] for k in self.res_name_list if k in results] 76 | img_cat = torch.cat(img_list, dim=3).detach() 77 | if self.rerange: 78 | img_cat = ((img_cat + 1) / 2) 79 | if self.bgr2rgb: 80 | img_cat = img_cat[:, [2, 1, 0], ...] 81 | img_cat = img_cat.clamp_(0, 1) 82 | save_image( 83 | img_cat, 84 | osp.join(self.output_dir, filename), 85 | nrow=self.nrow, 86 | padding=self.padding) 87 | 88 | 89 | @HOOKS.register_module() 90 | class VisualizationHook(MMEditVisualizationHook): 91 | 92 | @deprecated_function('0.16.0', '0.20.0', 'use \'MMEditVisualizationHook\'') 93 | def __init__(self, *args, **kwargs): 94 | super().__init__(*args, **kwargs) 95 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/crop_hazeworld.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmediting/blob/master/mmedit/datasets/pipelines/crop.py 2 | import numpy as np 3 | 4 | from ..registry import PIPELINES 5 | 6 | 7 | @PIPELINES.register_module() 8 | class PairedRandomCropWithTransmission: 9 | """Paried random crop. 10 | 11 | It crops a pair of lq, gt, and transmission map images with corresponding locations. 12 | It also supports accepting lq list and gt list. 13 | Required keys are "scale", "lq", "gt", and "trans", 14 | added or modified keys are "lq", "gt", and "trans". 15 | 16 | Args: 17 | gt_patch_size (int): cropped gt patch size. 18 | """ 19 | 20 | def __init__(self, gt_patch_size): 21 | self.gt_patch_size = gt_patch_size 22 | 23 | def __call__(self, results): 24 | """Call function. 25 | 26 | Args: 27 | results (dict): A dict containing the necessary information and 28 | data for augmentation. 29 | 30 | Returns: 31 | dict: A dict containing the processed data and information. 
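# Direct-instantiation sketch for the visualization hook; output_dir, interval
# and result keys are illustrative (configs normally enable it through the
# runtime's hook settings).
from mmedit.core.hooks.visualization import MMEditVisualizationHook

vis_hook = MMEditVisualizationHook(
    output_dir='work_dirs/vis',
    res_name_list=['lq', 'output', 'gt'],
    interval=5000,
    rerange=False,   # set True if the result tensors live in [-1, 1]
    bgr2rgb=True)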
32 | """ 33 | scale = results['scale'] 34 | assert scale == 1 # for dehazing setting 35 | lq_patch_size = self.gt_patch_size // scale 36 | 37 | lq_is_list = isinstance(results['lq'], list) 38 | if not lq_is_list: 39 | results['lq'] = [results['lq']] 40 | gt_is_list = isinstance(results['gt'], list) 41 | if not gt_is_list: 42 | results['gt'] = [results['gt']] 43 | tm_is_list = isinstance(results['trans'], list) 44 | if not tm_is_list: 45 | results['trans'] = [results['trans']] 46 | 47 | h_lq, w_lq, _ = results['lq'][0].shape 48 | h_gt, w_gt, _ = results['gt'][0].shape 49 | h_tm, w_tm, _ = results['trans'][0].shape 50 | 51 | if h_gt != h_lq * scale or w_gt != w_lq * scale or h_gt != h_tm or w_gt != w_tm: 52 | raise ValueError( 53 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 54 | f'multiplication of LQ ({h_lq}, {w_lq}), TM ({w_tm}, {w_tm}).') 55 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 56 | raise ValueError( 57 | f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 58 | f'({lq_patch_size}, {lq_patch_size}). Please check ' 59 | f'{results["lq_path"][0]} and {results["gt_path"][0]}.') 60 | 61 | # randomly choose top and left coordinates for lq patch 62 | top = np.random.randint(h_lq - lq_patch_size + 1) 63 | left = np.random.randint(w_lq - lq_patch_size + 1) 64 | # crop lq patch 65 | results['lq'] = [ 66 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...] 67 | for v in results['lq'] 68 | ] 69 | # crop corresponding gt patch 70 | top_gt, left_gt = int(top * scale), int(left * scale) 71 | results['gt'] = [ 72 | v[top_gt:top_gt + self.gt_patch_size, 73 | left_gt:left_gt + self.gt_patch_size, ...] for v in results['gt'] 74 | ] 75 | 76 | if not lq_is_list: 77 | results['lq'] = results['lq'][0] 78 | if not gt_is_list: 79 | results['gt'] = results['gt'][0] 80 | 81 | # crop corresponding transmission map patch 82 | results['trans'] = [ 83 | v[top_gt:top_gt + self.gt_patch_size, 84 | left_gt:left_gt + self.gt_patch_size, ...] for v in results['trans'] 85 | ] 86 | if not tm_is_list: 87 | results['trans'] = results['trans'][0] 88 | return results 89 | 90 | def __repr__(self): 91 | repr_str = self.__class__.__name__ 92 | repr_str += f'(gt_patch_size={self.gt_patch_size})' 93 | return repr_str 94 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import numpy as np 4 | 5 | from ..registry import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Normalize: 10 | """Normalize images with the given mean and std value. 11 | 12 | Required keys are the keys in attribute "keys", added or modified keys are 13 | the keys in attribute "keys" and these keys with postfix '_norm_cfg'. 14 | It also supports normalizing a list of images. 15 | 16 | Args: 17 | keys (Sequence[str]): The images to be normalized. 18 | mean (np.ndarray): Mean values of different channels. 19 | std (np.ndarray): Std values of different channels. 20 | to_rgb (bool): Whether to convert channels from BGR to RGB. 21 | """ 22 | 23 | def __init__(self, keys, mean, std, to_rgb=False, save_original=False): 24 | self.keys = keys 25 | self.mean = np.array(mean, dtype=np.float32) 26 | self.std = np.array(std, dtype=np.float32) 27 | self.to_rgb = to_rgb 28 | self.save_original = save_original 29 | 30 | def __call__(self, results): 31 | """Call function. 
32 | 33 | Args: 34 | results (dict): A dict containing the necessary information and 35 | data for augmentation. 36 | 37 | Returns: 38 | dict: A dict containing the processed data and information. 39 | """ 40 | for key in self.keys: 41 | if isinstance(results[key], list): 42 | if self.save_original: 43 | results[key + '_unnormalised'] = [ 44 | v.copy() for v in results[key] 45 | ] 46 | results[key] = [ 47 | mmcv.imnormalize(v, self.mean, self.std, self.to_rgb) 48 | for v in results[key] 49 | ] 50 | else: 51 | if self.save_original: 52 | results[key + '_unnormalised'] = results[key].copy() 53 | results[key] = mmcv.imnormalize(results[key], self.mean, 54 | self.std, self.to_rgb) 55 | 56 | results['img_norm_cfg'] = dict( 57 | mean=self.mean, std=self.std, to_rgb=self.to_rgb) 58 | return results 59 | 60 | def __repr__(self): 61 | repr_str = self.__class__.__name__ 62 | repr_str += (f'(keys={self.keys}, mean={self.mean}, std={self.std}, ' 63 | f'to_rgb={self.to_rgb})') 64 | 65 | return repr_str 66 | 67 | 68 | @PIPELINES.register_module() 69 | class RescaleToZeroOne: 70 | """Transform the images into a range between 0 and 1. 71 | 72 | Required keys are the keys in attribute "keys", added or modified keys are 73 | the keys in attribute "keys". 74 | It also supports rescaling a list of images. 75 | 76 | Args: 77 | keys (Sequence[str]): The images to be transformed. 78 | """ 79 | 80 | def __init__(self, keys): 81 | self.keys = keys 82 | 83 | def __call__(self, results): 84 | """Call function. 85 | 86 | Args: 87 | results (dict): A dict containing the necessary information and 88 | data for augmentation. 89 | 90 | Returns: 91 | dict: A dict containing the processed data and information. 92 | """ 93 | for key in self.keys: 94 | if isinstance(results[key], list): 95 | results[key] = [ 96 | v.astype(np.float32) / 255. for v in results[key] 97 | ] 98 | else: 99 | results[key] = results[key].astype(np.float32) / 255. 100 | return results 101 | 102 | def __repr__(self): 103 | return self.__class__.__name__ + f'(keys={self.keys})' 104 | -------------------------------------------------------------------------------- /mmedit/models/common/mask_conv_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import ConvModule 3 | 4 | 5 | class MaskConvModule(ConvModule): 6 | """Mask convolution module. 7 | 8 | This is a simple wrapper for mask convolution like: 'partial conv'. 9 | Convolutions in this module always need a mask as extra input. 10 | 11 | Args: 12 | in_channels (int): Same as nn.Conv2d. 13 | out_channels (int): Same as nn.Conv2d. 14 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 15 | stride (int or tuple[int]): Same as nn.Conv2d. 16 | padding (int or tuple[int]): Same as nn.Conv2d. 17 | dilation (int or tuple[int]): Same as nn.Conv2d. 18 | groups (int): Same as nn.Conv2d. 19 | bias (bool or str): If specified as `auto`, it will be decided by the 20 | norm_cfg. Bias will be set as True if norm_cfg is None, otherwise 21 | False. 22 | conv_cfg (dict): Config dict for convolution layer. 23 | norm_cfg (dict): Config dict for normalization layer. 24 | act_cfg (dict): Config dict for activation layer, "relu" by default. 25 | inplace (bool): Whether to use inplace mode for activation. 26 | with_spectral_norm (bool): Whether use spectral norm in conv module. 
27 | padding_mode (str): If the `padding_mode` has not been supported by 28 | current `Conv2d` in Pytorch, we will use our own padding layer 29 | instead. Currently, we support ['zeros', 'circular'] with official 30 | implementation and ['reflect'] with our own implementation. 31 | Default: 'zeros'. 32 | order (tuple[str]): The order of conv/norm/activation layers. It is a 33 | sequence of "conv", "norm" and "act". Examples are 34 | ("conv", "norm", "act") and ("act", "conv", "norm"). 35 | """ 36 | supported_conv_list = ['PConv'] 37 | 38 | def __init__(self, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | assert self.conv_cfg['type'] in self.supported_conv_list 41 | 42 | self.init_weights() 43 | 44 | def forward(self, 45 | x, 46 | mask=None, 47 | activate=True, 48 | norm=True, 49 | return_mask=True): 50 | """Forward function for partial conv2d. 51 | 52 | Args: 53 | input (torch.Tensor): Tensor with shape of (n, c, h, w). 54 | mask (torch.Tensor): Tensor with shape of (n, c, h, w) or 55 | (n, 1, h, w). If mask is not given, the function will 56 | work as standard conv2d. Default: None. 57 | activate (bool): Whether use activation layer. 58 | norm (bool): Whether use norm layer. 59 | return_mask (bool): If True and mask is not None, the updated 60 | mask will be returned. Default: True. 61 | 62 | Returns: 63 | Tensor or tuple: Result Tensor or 2-tuple of 64 | 65 | ``Tensor``: Results after partial conv. 66 | 67 | ``Tensor``: Updated mask will be returned if mask is given \ 68 | and `return_mask` is True. 69 | """ 70 | for layer in self.order: 71 | if layer == 'conv': 72 | if self.with_explicit_padding: 73 | x = self.padding_layer(x) 74 | mask = self.padding_layer(mask) 75 | if return_mask: 76 | x, updated_mask = self.conv( 77 | x, mask, return_mask=return_mask) 78 | else: 79 | x = self.conv(x, mask, return_mask=False) 80 | elif layer == 'norm' and norm and self.with_norm: 81 | x = self.norm(x) 82 | elif layer == 'act' and activate and self.with_activation: 83 | x = self.activate(x) 84 | 85 | if return_mask: 86 | return x, updated_mask 87 | 88 | return x 89 | -------------------------------------------------------------------------------- /mmedit/models/common/ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SpatialTemporalEnsemble(nn.Module): 7 | """Apply spatial and temporal ensemble and compute outputs. 8 | 9 | Args: 10 | is_temporal_ensemble (bool, optional): Whether to apply ensemble 11 | temporally. If True, the sequence will also be flipped temporally. 12 | If the input is an image, this argument must be set to False. 13 | Default: False. 14 | """ 15 | 16 | def __init__(self, is_temporal_ensemble=False): 17 | 18 | super().__init__() 19 | 20 | self.is_temporal_ensemble = is_temporal_ensemble 21 | 22 | def _transform(self, imgs, mode): 23 | """Apply spatial transform (flip, rotate) to the images. 24 | 25 | Args: 26 | imgs (torch.Tensor): The images to be transformed/ 27 | mode (str): The mode of transform. Supported values are 'vertical', 28 | 'horizontal', and 'transpose', corresponding to vertical flip, 29 | horizontal flip, and rotation, respectively. 30 | 31 | Returns: 32 | torch.Tensor: Output of the model with spatial ensemble applied. 
33 | """ 34 | 35 | is_single_image = False 36 | if imgs.ndim == 4: 37 | if self.is_temporal_ensemble: 38 | raise ValueError('"is_temporal_ensemble" must be False if ' 39 | 'the input is an image.') 40 | is_single_image = True 41 | imgs = imgs.unsqueeze(1) 42 | 43 | if mode == 'vertical': 44 | imgs = imgs.flip(4).clone() 45 | elif mode == 'horizontal': 46 | imgs = imgs.flip(3).clone() 47 | elif mode == 'transpose': 48 | imgs = imgs.permute(0, 1, 2, 4, 3).clone() 49 | 50 | if is_single_image: 51 | imgs = imgs.squeeze(1) 52 | 53 | return imgs 54 | 55 | def spatial_ensemble(self, imgs, model): 56 | """Apply spatial ensemble. 57 | 58 | Args: 59 | imgs (torch.Tensor): The images to be processed by the model. Its 60 | size should be either (n, t, c, h, w) or (n, c, h, w). 61 | model (nn.Module): The model to process the images. 62 | 63 | Returns: 64 | torch.Tensor: Output of the model with spatial ensemble applied. 65 | """ 66 | 67 | img_list = [imgs.cpu()] 68 | for mode in ['vertical', 'horizontal', 'transpose']: 69 | img_list.extend([self._transform(t, mode) for t in img_list]) 70 | 71 | output_list = [model(t.to(imgs.device)).cpu() for t in img_list] 72 | for i in range(len(output_list)): 73 | if i > 3: 74 | output_list[i] = self._transform(output_list[i], 'transpose') 75 | if i % 4 > 1: 76 | output_list[i] = self._transform(output_list[i], 'horizontal') 77 | if (i % 4) % 2 == 1: 78 | output_list[i] = self._transform(output_list[i], 'vertical') 79 | 80 | outputs = torch.stack(output_list, dim=0) 81 | outputs = outputs.mean(dim=0, keepdim=False) 82 | 83 | return outputs.to(imgs.device) 84 | 85 | def forward(self, imgs, model): 86 | """Apply spatial and temporal ensemble. 87 | 88 | Args: 89 | imgs (torch.Tensor): The images to be processed by the model. Its 90 | size should be either (n, t, c, h, w) or (n, c, h, w). 91 | model (nn.Module): The model to process the images. 92 | 93 | Returns: 94 | torch.Tensor: Output of the model with spatial ensemble applied. 95 | """ 96 | outputs = self.spatial_ensemble(imgs, model) 97 | if self.is_temporal_ensemble: 98 | outputs += self.spatial_ensemble(imgs.flip(1), model).flip(1) 99 | outputs *= 0.5 100 | 101 | return outputs 102 | -------------------------------------------------------------------------------- /mmedit/models/common/partial_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from mmcv.cnn import CONV_LAYERS 7 | 8 | 9 | @CONV_LAYERS.register_module(name='PConv') 10 | class PartialConv2d(nn.Conv2d): 11 | """Implementation for partial convolution. 12 | 13 | Image Inpainting for Irregular Holes Using Partial Convolutions 14 | [https://arxiv.org/abs/1804.07723] 15 | 16 | Args: 17 | multi_channel (bool): If True, the mask is multi-channel. Otherwise, 18 | the mask is single-channel. 19 | eps (float): Need to be changed for mixed precision training. 20 | For mixed precision training, you need change 1e-8 to 1e-6. 
21 | """ 22 | 23 | def __init__(self, *args, multi_channel=False, eps=1e-8, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | 26 | # whether the mask is multi-channel or not 27 | self.multi_channel = multi_channel 28 | self.eps = eps 29 | 30 | if self.multi_channel: 31 | out_channels, in_channels = self.out_channels, self.in_channels 32 | else: 33 | out_channels, in_channels = 1, 1 34 | 35 | self.register_buffer( 36 | 'weight_mask_updater', 37 | torch.ones(out_channels, in_channels, self.kernel_size[0], 38 | self.kernel_size[1])) 39 | 40 | self.mask_kernel_numel = np.prod(self.weight_mask_updater.shape[1:4]) 41 | self.mask_kernel_numel = (self.mask_kernel_numel).item() 42 | 43 | def forward(self, input, mask=None, return_mask=True): 44 | """Forward function for partial conv2d. 45 | 46 | Args: 47 | input (torch.Tensor): Tensor with shape of (n, c, h, w). 48 | mask (torch.Tensor): Tensor with shape of (n, c, h, w) or 49 | (n, 1, h, w). If mask is not given, the function will 50 | work as standard conv2d. Default: None. 51 | return_mask (bool): If True and mask is not None, the updated 52 | mask will be returned. Default: True. 53 | 54 | Returns: 55 | torch.Tensor : Results after partial conv.\ 56 | torch.Tensor : Updated mask will be returned if mask is given and \ 57 | ``return_mask`` is True. 58 | """ 59 | assert input.dim() == 4 60 | if mask is not None: 61 | assert mask.dim() == 4 62 | if self.multi_channel: 63 | assert mask.shape[1] == input.shape[1] 64 | else: 65 | assert mask.shape[1] == 1 66 | 67 | # update mask and compute mask ratio 68 | if mask is not None: 69 | with torch.no_grad(): 70 | 71 | updated_mask = F.conv2d( 72 | mask, 73 | self.weight_mask_updater, 74 | bias=None, 75 | stride=self.stride, 76 | padding=self.padding, 77 | dilation=self.dilation) 78 | mask_ratio = self.mask_kernel_numel / (updated_mask + self.eps) 79 | 80 | updated_mask = torch.clamp(updated_mask, 0, 1) 81 | mask_ratio = mask_ratio * updated_mask 82 | 83 | # standard conv2d 84 | if mask is not None: 85 | input = input * mask 86 | raw_out = super().forward(input) 87 | 88 | if mask is not None: 89 | if self.bias is None: 90 | output = raw_out * mask_ratio 91 | else: 92 | # compute new bias when mask is given 93 | bias_view = self.bias.view(1, self.out_channels, 1, 1) 94 | output = (raw_out - bias_view) * mask_ratio + bias_view 95 | output = output * updated_mask 96 | else: 97 | output = raw_out 98 | 99 | if return_mask and mask is not None: 100 | return output, updated_mask 101 | 102 | return output 103 | -------------------------------------------------------------------------------- /mmedit/datasets/base_sr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import copy 3 | import os.path as osp 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from mmcv import scandir 9 | 10 | from mmedit.core.registry import build_metric 11 | from .base_dataset import BaseDataset 12 | 13 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 14 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF') 15 | FEATURE_BASED_METRICS = ['FID', 'KID'] 16 | 17 | 18 | class BaseSRDataset(BaseDataset): 19 | """Base class for super resolution datasets.""" 20 | 21 | def __init__(self, pipeline, scale, test_mode=False): 22 | super().__init__(pipeline, test_mode) 23 | self.scale = scale 24 | 25 | @staticmethod 26 | def scan_folder(path): 27 | """Obtain image path list (including sub-folders) from a given folder. 28 | 29 | Args: 30 | path (str | :obj:`Path`): Folder path. 31 | 32 | Returns: 33 | list[str]: Image list obtained from the given folder. 34 | """ 35 | 36 | if isinstance(path, (str, Path)): 37 | path = str(path) 38 | else: 39 | raise TypeError("'path' must be a str or a Path object, " 40 | f'but received {type(path)}.') 41 | 42 | images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True)) 43 | images = [osp.join(path, v) for v in images] 44 | assert images, f'{path} has no valid image file.' 45 | return images 46 | 47 | def __getitem__(self, idx): 48 | """Get item at each call. 49 | 50 | Args: 51 | idx (int): Index for getting each item. 52 | """ 53 | results = copy.deepcopy(self.data_infos[idx]) 54 | results['scale'] = self.scale 55 | return self.pipeline(results) 56 | 57 | def evaluate(self, results, logger=None): 58 | """Evaluate with different metrics. 59 | 60 | Args: 61 | results (list[tuple]): The output of forward_test() of the model. 62 | 63 | Returns: 64 | dict: Evaluation results dict.
65 | """ 66 | if not isinstance(results, list): 67 | raise TypeError(f'results must be a list, but got {type(results)}') 68 | assert len(results) == len(self), ( 69 | 'The length of results is not equal to the dataset len: ' 70 | f'{len(results)} != {len(self)}') 71 | 72 | results = [res['eval_result'] for res in results] # a list of dict 73 | eval_result = defaultdict(list) # a dict of list 74 | 75 | for res in results: 76 | for metric, val in res.items(): 77 | eval_result[metric].append(val) 78 | for metric, val_list in eval_result.items(): 79 | assert len(val_list) == len(self), ( 80 | f'Length of evaluation result of {metric} is {len(val_list)}, ' 81 | f'should be {len(self)}') 82 | 83 | # average the results 84 | eval_result.update({ 85 | metric: sum(values) / len(self) 86 | for metric, values in eval_result.items() 87 | if metric not in ['_inception_feat'] + FEATURE_BASED_METRICS 88 | }) 89 | 90 | # evaluate feature-based metrics 91 | if '_inception_feat' in eval_result: 92 | feat1, feat2 = [], [] 93 | for f1, f2 in eval_result['_inception_feat']: 94 | feat1.append(f1) 95 | feat2.append(f2) 96 | feat1 = np.concatenate(feat1, 0) 97 | feat2 = np.concatenate(feat2, 0) 98 | 99 | for metric in FEATURE_BASED_METRICS: 100 | if metric in eval_result: 101 | metric_func = build_metric(eval_result[metric].pop()) 102 | eval_result[metric] = metric_func(feat1, feat2) 103 | 104 | # delete a redundant key for clean logging 105 | del eval_result['_inception_feat'] 106 | 107 | return eval_result 108 | -------------------------------------------------------------------------------- /tools/deployment/mmedit2torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | 6 | import mmcv 7 | 8 | try: 9 | from model_archiver.model_packaging import package_model 10 | from model_archiver.model_packaging_utils import ModelExportUtils 11 | except ImportError: 12 | package_model = None 13 | 14 | 15 | def mmedit2torchserve( 16 | config_file: str, 17 | checkpoint_file: str, 18 | output_folder: str, 19 | model_name: str, 20 | model_version: str = '1.0', 21 | force: bool = False, 22 | ): 23 | """Converts MMEditing model (config + checkpoint) to TorchServe `.mar`. 24 | 25 | Args: 26 | config_file: 27 | In MMEditing config format. 28 | The contents vary for each task repository. 29 | checkpoint_file: 30 | In MMEditing checkpoint format. 31 | The contents vary for each task repository. 32 | output_folder: 33 | Folder where `{model_name}.mar` will be created. 34 | The file created will be in TorchServe archive format. 35 | model_name: 36 | If not None, used for naming the `{model_name}.mar` file 37 | that will be created under `output_folder`. 38 | If None, `{Path(checkpoint_file).stem}` will be used. 39 | model_version: 40 | Model's version. 41 | force: 42 | If True, if there is an existing `{model_name}.mar` 43 | file under `output_folder` it will be overwritten. 
44 | """ 45 | mmcv.mkdir_or_exist(output_folder) 46 | 47 | config = mmcv.Config.fromfile(config_file) 48 | 49 | with TemporaryDirectory() as tmpdir: 50 | config.dump(f'{tmpdir}/config.py') 51 | 52 | args_ = Namespace( 53 | **{ 54 | 'model_file': f'{tmpdir}/config.py', 55 | 'serialized_file': checkpoint_file, 56 | 'handler': f'{Path(__file__).parent}/mmedit_handler.py', 57 | 'model_name': model_name or Path(checkpoint_file).stem, 58 | 'version': model_version, 59 | 'export_path': output_folder, 60 | 'force': force, 61 | 'requirements_file': None, 62 | 'extra_files': None, 63 | 'runtime': 'python', 64 | 'archive_format': 'default' 65 | }) 66 | print(args_.model_name) 67 | manifest = ModelExportUtils.generate_manifest_json(args_) 68 | package_model(args_, manifest) 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser( 73 | description='Convert MMEditing models to TorchServe `.mar` format.') 74 | parser.add_argument('config', type=str, help='config file path') 75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 76 | parser.add_argument( 77 | '--output-folder', 78 | type=str, 79 | required=True, 80 | help='Folder where `{model_name}.mar` will be created.') 81 | parser.add_argument( 82 | '--model-name', 83 | type=str, 84 | default=None, 85 | help='If not None, used for naming the `{model_name}.mar`' 86 | 'file that will be created under `output_folder`.' 87 | 'If None, `{Path(checkpoint_file).stem}` will be used.') 88 | parser.add_argument( 89 | '--model-version', 90 | type=str, 91 | default='1.0', 92 | help='Number used for versioning.') 93 | parser.add_argument( 94 | '-f', 95 | '--force', 96 | action='store_true', 97 | help='overwrite the existing `{model_name}.mar`') 98 | args_ = parser.parse_args() 99 | 100 | return args_ 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | 106 | if package_model is None: 107 | raise ImportError('`torch-model-archiver` is required.' 108 | 'Try: pip install torch-model-archiver') 109 | 110 | mmedit2torchserve(args.config, args.checkpoint, args.output_folder, 111 | args.model_name, args.model_version, args.force) 112 | -------------------------------------------------------------------------------- /mmedit/models/common/separable_conv_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule 4 | 5 | 6 | class DepthwiseSeparableConvModule(nn.Module): 7 | """Depthwise separable convolution module. 8 | 9 | See https://arxiv.org/pdf/1704.04861.pdf for details. 10 | 11 | This module can replace a ConvModule with the conv block replaced by two 12 | conv block: depthwise conv block and pointwise conv block. The depthwise 13 | conv block contains depthwise-conv/norm/activation layers. The pointwise 14 | conv block contains pointwise-conv/norm/activation layers. It should be 15 | noted that there will be norm/activation layer in the depthwise conv block 16 | if ``norm_cfg`` and ``act_cfg`` are specified. 17 | 18 | Args: 19 | in_channels (int): Same as nn.Conv2d. 20 | out_channels (int): Same as nn.Conv2d. 21 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 22 | stride (int or tuple[int]): Same as nn.Conv2d. Default: 1. 23 | padding (int or tuple[int]): Same as nn.Conv2d. Default: 0. 24 | dilation (int or tuple[int]): Same as nn.Conv2d. Default: 1. 25 | norm_cfg (dict): Default norm config for both depthwise ConvModule and 26 | pointwise ConvModule. 
Default: None. 27 | act_cfg (dict): Default activation config for both depthwise ConvModule 28 | and pointwise ConvModule. Default: dict(type='ReLU'). 29 | dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is 30 | 'default', it will be the same as ``norm_cfg``. Default: 'default'. 31 | dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is 32 | 'default', it will be the same as ``act_cfg``. Default: 'default'. 33 | pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is 34 | 'default', it will be the same as `norm_cfg`. Default: 'default'. 35 | pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is 36 | 'default', it will be the same as ``act_cfg``. Default: 'default'. 37 | kwargs (optional): Other shared arguments for depthwise and pointwise 38 | ConvModule. See ConvModule for ref. 39 | """ 40 | 41 | def __init__(self, 42 | in_channels, 43 | out_channels, 44 | kernel_size, 45 | stride=1, 46 | padding=0, 47 | dilation=1, 48 | norm_cfg=None, 49 | act_cfg=dict(type='ReLU'), 50 | dw_norm_cfg='default', 51 | dw_act_cfg='default', 52 | pw_norm_cfg='default', 53 | pw_act_cfg='default', 54 | **kwargs): 55 | super().__init__() 56 | assert 'groups' not in kwargs, 'groups should not be specified' 57 | 58 | # if norm/activation config of depthwise/pointwise ConvModule is not 59 | # specified, use default config. 60 | dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg 61 | dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg 62 | pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg 63 | pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg 64 | 65 | # depthwise convolution 66 | self.depthwise_conv = ConvModule( 67 | in_channels, 68 | in_channels, 69 | kernel_size, 70 | stride=stride, 71 | padding=padding, 72 | dilation=dilation, 73 | groups=in_channels, 74 | norm_cfg=dw_norm_cfg, 75 | act_cfg=dw_act_cfg, 76 | **kwargs) 77 | 78 | self.pointwise_conv = ConvModule( 79 | in_channels, 80 | out_channels, 81 | 1, 82 | norm_cfg=pw_norm_cfg, 83 | act_cfg=pw_act_cfg, 84 | **kwargs) 85 | 86 | def forward(self, x): 87 | """Forward function. 88 | 89 | Args: 90 | x (Tensor): Input tensor with shape (N, C, H, W). 91 | 92 | Returns: 93 | Tensor: Output tensor. 94 | """ 95 | x = self.depthwise_conv(x) 96 | x = self.pointwise_conv(x) 97 | return x 98 | -------------------------------------------------------------------------------- /mmedit/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import functools 3 | 4 | import torch.nn.functional as F 5 | 6 | 7 | def reduce_loss(loss, reduction): 8 | """Reduce loss as specified. 9 | 10 | Args: 11 | loss (Tensor): Elementwise loss tensor. 12 | reduction (str): Options are "none", "mean" and "sum". 13 | 14 | Returns: 15 | Tensor: Reduced loss tensor. 16 | """ 17 | reduction_enum = F._Reduction.get_enum(reduction) 18 | # none: 0, elementwise_mean:1, sum: 2 19 | if reduction_enum == 0: 20 | return loss 21 | if reduction_enum == 1: 22 | return loss.mean() 23 | 24 | return loss.sum() 25 | 26 | 27 | def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False): 28 | """Apply element-wise weight and reduce loss. 29 | 30 | Args: 31 | loss (Tensor): Element-wise loss. 32 | weight (Tensor): Element-wise weights. Default: None. 33 | reduction (str): Same as built-in losses of PyTorch. Options are 34 | "none", "mean" and "sum". Default: 'mean'. 
35 | sample_wise (bool): Whether calculate the loss sample-wise. This 36 | argument only takes effect when `reduction` is 'mean' and `weight` 37 | (argument of `forward()`) is not None. It will first reduces loss 38 | with 'mean' per-sample, and then it means over all the samples. 39 | Default: False. 40 | 41 | Returns: 42 | Tensor: Processed loss values. 43 | """ 44 | # if weight is specified, apply element-wise weight 45 | if weight is not None: 46 | assert weight.dim() == loss.dim() 47 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 48 | loss = loss * weight 49 | 50 | # if weight is not specified or reduction is sum, just reduce the loss 51 | if weight is None or reduction == 'sum': 52 | loss = reduce_loss(loss, reduction) 53 | # if reduction is mean, then compute mean over masked region 54 | elif reduction == 'mean': 55 | # expand weight from N1HW to NCHW 56 | if weight.size(1) == 1: 57 | weight = weight.expand_as(loss) 58 | # small value to prevent division by zero 59 | eps = 1e-12 60 | 61 | # perform sample-wise mean 62 | if sample_wise: 63 | weight = weight.sum(dim=[1, 2, 3], keepdim=True) # NCHW to N111 64 | loss = (loss / (weight + eps)).sum() / weight.size(0) 65 | # perform pixel-wise mean 66 | else: 67 | loss = loss.sum() / (weight.sum() + eps) 68 | 69 | return loss 70 | 71 | 72 | def masked_loss(loss_func): 73 | """Create a masked version of a given loss function. 74 | 75 | To use this decorator, the loss function must have the signature like 76 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 77 | element-wise loss without any reduction. This decorator will add weight 78 | and reduction arguments to the function. The decorated function will have 79 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 80 | avg_factor=None, **kwargs)`. 81 | 82 | :Example: 83 | 84 | >>> import torch 85 | >>> @masked_loss 86 | >>> def l1_loss(pred, target): 87 | >>> return (pred - target).abs() 88 | 89 | >>> pred = torch.Tensor([0, 2, 3]) 90 | >>> target = torch.Tensor([1, 1, 1]) 91 | >>> weight = torch.Tensor([1, 0, 1]) 92 | 93 | >>> l1_loss(pred, target) 94 | tensor(1.3333) 95 | >>> l1_loss(pred, target, weight) 96 | tensor(1.5000) 97 | >>> l1_loss(pred, target, reduction='none') 98 | tensor([1., 1., 2.]) 99 | >>> l1_loss(pred, target, weight, reduction='sum') 100 | tensor(3.) 101 | """ 102 | 103 | @functools.wraps(loss_func) 104 | def wrapper(pred, 105 | target, 106 | weight=None, 107 | reduction='mean', 108 | sample_wise=False, 109 | **kwargs): 110 | # get element-wise loss 111 | loss = loss_func(pred, target, **kwargs) 112 | loss = mask_reduce_loss(loss, weight, reduction, sample_wise) 113 | return loss 114 | 115 | return wrapper 116 | -------------------------------------------------------------------------------- /mmedit/models/common/aspp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from mmcv.cnn import ConvModule 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .separable_conv_module import DepthwiseSeparableConvModule 8 | 9 | 10 | class ASPPPooling(nn.Sequential): 11 | 12 | def __init__(self, in_channels, out_channels, conv_cfg, norm_cfg, act_cfg): 13 | super().__init__( 14 | nn.AdaptiveAvgPool2d(1), 15 | ConvModule( 16 | in_channels, 17 | out_channels, 18 | 1, 19 | conv_cfg=conv_cfg, 20 | norm_cfg=norm_cfg, 21 | act_cfg=act_cfg)) 22 | 23 | def forward(self, x): 24 | size = x.shape[-2:] 25 | for mod in self: 26 | x = mod(x) 27 | return F.interpolate( 28 | x, size=size, mode='bilinear', align_corners=False) 29 | 30 | 31 | class ASPP(nn.Module): 32 | """ASPP module from DeepLabV3. 33 | 34 | The code is adopted from 35 | https://github.com/pytorch/vision/blob/master/torchvision/models/ 36 | segmentation/deeplabv3.py 37 | 38 | For more information about the module: 39 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 40 | `_. 41 | 42 | Args: 43 | in_channels (int): Input channels of the module. 44 | out_channels (int): Output channels of the module. 45 | mid_channels (int): Output channels of the intermediate ASPP conv 46 | modules. 47 | dilations (Sequence[int]): Dilation rate of three ASPP conv module. 48 | Default: [12, 24, 36]. 49 | conv_cfg (dict): Config dict for convolution layer. If "None", 50 | nn.Conv2d will be applied. Default: None. 51 | norm_cfg (dict): Config dict for normalization layer. 52 | Default: dict(type='BN'). 53 | act_cfg (dict): Config dict for activation layer. 54 | Default: dict(type='ReLU'). 55 | separable_conv (bool): Whether replace normal conv with depthwise 56 | separable conv which is faster. Default: False. 57 | """ 58 | 59 | def __init__(self, 60 | in_channels, 61 | out_channels=256, 62 | mid_channels=256, 63 | dilations=(12, 24, 36), 64 | conv_cfg=None, 65 | norm_cfg=dict(type='BN'), 66 | act_cfg=dict(type='ReLU'), 67 | separable_conv=False): 68 | super().__init__() 69 | 70 | if separable_conv: 71 | conv_module = DepthwiseSeparableConvModule 72 | else: 73 | conv_module = ConvModule 74 | 75 | modules = [] 76 | modules.append( 77 | ConvModule( 78 | in_channels, 79 | mid_channels, 80 | 1, 81 | conv_cfg=conv_cfg, 82 | norm_cfg=norm_cfg, 83 | act_cfg=act_cfg)) 84 | 85 | for dilation in dilations: 86 | modules.append( 87 | conv_module( 88 | in_channels, 89 | mid_channels, 90 | 3, 91 | padding=dilation, 92 | dilation=dilation, 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | act_cfg=act_cfg)) 96 | 97 | modules.append( 98 | ASPPPooling(in_channels, mid_channels, conv_cfg, norm_cfg, 99 | act_cfg)) 100 | 101 | self.convs = nn.ModuleList(modules) 102 | 103 | self.project = nn.Sequential( 104 | ConvModule( 105 | 5 * mid_channels, 106 | out_channels, 107 | 1, 108 | conv_cfg=conv_cfg, 109 | norm_cfg=norm_cfg, 110 | act_cfg=act_cfg), nn.Dropout(0.5)) 111 | 112 | def forward(self, x): 113 | """Forward function for ASPP module. 114 | 115 | Args: 116 | x (Tensor): Input tensor with shape (N, C, H, W). 117 | 118 | Returns: 119 | Tensor: Output tensor. 
120 | """ 121 | res = [] 122 | for conv in self.convs: 123 | res.append(conv(x)) 124 | res = torch.cat(res, dim=1) 125 | return self.project(res) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [MAP-Net](https://arxiv.org/abs/2303.09757) 2 | 3 | PyTorch implementation of **MAP-Net**, from the following paper: 4 | 5 | [Video Dehazing via a Multi-Range Temporal Alignment Network with Physical Prior](https://arxiv.org/abs/2303.09757). CVPR 2023.\ 6 | Jiaqi Xu, Xiaowei Hu, Lei Zhu, Qi Dou, Jifeng Dai, Yu Qiao, and Pheng-Ann Heng 7 | 8 |

9 | <!-- MAP-Net architecture figure --> 10 | 11 |

12 | 13 | We propose **MAP-Net**, a novel video dehazing framework that effectively explores the physical haze priors and aggregates temporal information. 14 | 15 | 16 | ## Dataset 17 | 18 |

19 | <!-- HazeWorld dataset overview figure --> 20 | 21 |

22 | 23 | We construct a large-scale outdoor video dehazing benchmark dataset, **HazeWorld**, which contains video frames in various real-world scenarios. 24 | 25 | To prepare the HazeWorld dataset for experiments, please follow the [instructions](./docs/dataset_prepare.md). 26 | 27 | 28 | ## Installation 29 | 30 | This implementation is based on [MMEditing](https://github.com/open-mmlab/mmediting), 31 | which is an open-source image and video editing toolbox. 32 | The code is developed with the following environment: 33 | ``` 34 | python 3.10.9 35 | pytorch 1.12.1 36 | torchvision 0.13.1 37 | cuda 11.3 38 | ``` 39 | 40 | Below are quick steps for installation. 41 | 42 | **Step 1.** 43 | Install PyTorch following the [official instructions](https://pytorch.org/get-started/locally/). 44 | 45 | **Step 2.** 46 | Install MMCV with [MIM](https://github.com/open-mmlab/mim). 47 | 48 | ```shell 49 | pip3 install openmim 50 | mim install mmcv-full 51 | ``` 52 | 53 | **Step 3.** 54 | Install MAP-Net from source. 55 | 56 | ```shell 57 | git clone https://github.com/jiaqixuac/MAP-Net.git 58 | cd MAP-Net 59 | pip3 install -e . 60 | ``` 61 | 62 | Please refer to [MMEditing Installation](https://github.com/open-mmlab/mmediting/blob/master/docs/en/install.md) for more detailed instructions. 63 | 64 | 65 | ## Getting Started 66 | 67 | You can train MAP-Net on HazeWorld with 4 GPUs using the following command: 68 | 69 | ```shell 70 | bash tools/dist_train.sh configs/dehazers/mapnet/mapnet_hazeworld.py 4 71 | ``` 72 | 73 | 74 | ## Evaluation 75 | 76 | We mainly use [PSNR and SSIM](./mmedit/core/evaluation/metrics.py) to measure model performance. 77 | For HazeWorld, we compute the dataset-averaged video-level metrics; 78 | see the [*evaluate*](./mmedit/datasets/hw_folder_multiple_gt_dataset.py) function (a small sketch of this averaging is given after the Acknowledgement section below). 79 | 80 | You can test a trained model `xxx.pth` with 1 GPU using the following command: 81 | 82 | ```shell 83 | bash tools/dist_test.sh configs/dehazers/mapnet/mapnet_hazeworld.py xxx.pth 1 84 | ``` 85 | 86 | A model checkpoint trained on HazeWorld is available [here](https://appsrv.cse.cuhk.edu.hk/~jqxu/models/MAP-Net/mapnet_hazeworld_40k.pth) or [here](https://huggingface.co/jiaqixuac/MAP-Net/blob/main/mapnet_hazeworld_40k.pth). 87 | 88 | 89 | ## Results 90 | 91 | Demo on real-world hazy videos. 92 | 93 | https://user-images.githubusercontent.com/33066765/224627919-cdc91886-9ab3-4b51-873b-3596c4aea085.mp4 94 | 95 | For the [REVIDE](https://github.com/BookerDeWitt/REVIDE_Dataset) dataset, 96 | the visual results of MAP-Net can be downloaded [here](https://appsrv.cse.cuhk.edu.hk/~jqxu/data/MAP-Net/visual_results_MAP-Net_REVIDE.zip) or [here](https://huggingface.co/datasets/jiaqixuac/HazeWorld/blob/main/visual_results_MAP-Net_REVIDE.zip). 97 | 98 | 99 | ## Acknowledgement 100 | 101 | This repository is built using the [mmedit](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc6) 102 | and [mmseg](https://github.com/open-mmlab/mmsegmentation) toolboxes, 103 | [DAT](https://github.com/LeapLabTHU/DAT) 104 | and [STM](https://github.com/seoungwugoh/STM) repositories.
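For reference, the dataset-averaged video-level evaluation described in the Evaluation section can be summarised by the sketch below. This is a minimal illustrative example and not the actual `evaluate` implementation: the function name and the per-frame result format are hypothetical, and only the `dataset`/`folder` grouping keys are taken from the meta keys collected in the HazeWorld config.

```python
from collections import defaultdict

import numpy as np


def dataset_averaged_video_metrics(per_frame_results, metrics=('PSNR', 'SSIM')):
    """Average per-frame metrics per video, then per dataset, then overall.

    `per_frame_results` is assumed to be a list of dicts, one per evaluated
    frame, e.g. {'dataset': 'D1', 'folder': 'video_0001', 'PSNR': 28.1,
    'SSIM': 0.93}, where 'dataset' and 'folder' mirror the meta keys in
    configs/dehazers/_base_/datasets/hazeworld.py.
    """
    # group per-frame scores by (dataset, video folder)
    per_video = defaultdict(lambda: defaultdict(list))
    for res in per_frame_results:
        for metric in metrics:
            per_video[(res['dataset'], res['folder'])][metric].append(res[metric])

    # video-level scores, grouped by dataset
    per_dataset = defaultdict(lambda: defaultdict(list))
    for (dataset, _), scores in per_video.items():
        for metric, values in scores.items():
            per_dataset[dataset][metric].append(np.mean(values))

    # average over videos within each dataset, then over datasets
    return {
        metric: float(np.mean([np.mean(d[metric]) for d in per_dataset.values()]))
        for metric in metrics
    }
```

Averaging frames within each video first, and videos within each dataset next, keeps long clips and large sub-datasets from dominating the final score, which is the motivation for reporting video-level rather than frame-level averages.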
105 | 106 | 107 | ## Citation 108 | 109 | If you find this repository helpful to your research, please consider citing the following: 110 | 111 | ```bibtex 112 | @inproceedings{xu2023map, 113 | title = {Video Dehazing via a Multi-Range Temporal Alignment Network with Physical Prior}, 114 | author = {Jiaqi Xu and Xiaowei Hu and Lei Zhu and Qi Dou and Jifeng Dai and Yu Qiao and Pheng-Ann Heng}, 115 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 116 | year = {2023}, 117 | } 118 | ``` 119 | 120 | 121 | ## License 122 | 123 | This project is released under the [MIT license](./LICENSE). 124 | Please refer to the acknowledged repositories for their licenses. 125 | -------------------------------------------------------------------------------- /configs/dehazers/_base_/datasets/revide.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/datasets/ade20k.py 2 | # https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/basicvsr/basicvsr_reds4.py 3 | 4 | train_dataset_type = 'SRFolderMultipleGTDataset' 5 | test_dataset_type = 'SRFolderMultipleGTDataset' 6 | 7 | img_norm_cfg_lq = dict( 8 | # mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True, 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], to_rgb=True, 10 | ) 11 | img_norm_cfg_gt = dict( 12 | # mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True, 13 | mean=[0., 0., 0.], std=[1., 1., 1.], to_rgb=True, 14 | ) 15 | crop_size = 384 16 | num_input_frames = 3 17 | 18 | io_backend = 'disk' 19 | load_kwargs = dict() 20 | 21 | train_pipeline = [ 22 | dict(type='GenerateSegmentIndices', interval_list=[1], filename_tmpl='{:05d}.JPG'), 23 | dict(type='LoadImageFromFileList', 24 | io_backend=io_backend, 25 | key='lq', 26 | flag='unchanged', 27 | # channel_order='rgb', 28 | **load_kwargs), 29 | dict(type='LoadImageFromFileList', 30 | io_backend=io_backend, 31 | key='gt', 32 | flag='unchanged', 33 | **load_kwargs), 34 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']), 35 | dict(type='ResizeVideo', keys=['lq', 'gt'], scales=[0.25, 0.375, 0.5, 0.625, 0.75], sample=False), 36 | dict(type='Normalize', 37 | keys=['lq'], 38 | **img_norm_cfg_lq), 39 | dict(type='Normalize', 40 | keys=['gt'], 41 | **img_norm_cfg_gt), 42 | dict(type='PairedRandomCrop', gt_patch_size=crop_size), 43 | dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, 44 | direction='horizontal'), 45 | # by jqxu: do not rotation 46 | # dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'), 47 | # dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5), 48 | dict(type='FramesToTensor', keys=['lq', 'gt']), 49 | dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path', 'key']), 50 | ] 51 | test_pipeline = [ 52 | # folder-based 53 | dict(type='GenerateSegmentIndices', interval_list=[1], filename_tmpl='{:05d}.JPG'), 54 | dict(type='LoadImageFromFileList', 55 | io_backend=io_backend, 56 | key='lq', 57 | flag='unchanged', 58 | # channel_order='rgb' 59 | **load_kwargs), 60 | dict(type='LoadImageFromFileList', 61 | io_backend=io_backend, 62 | key='gt', 63 | flag='unchanged', 64 | **load_kwargs), 65 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']), 66 | dict(type='ResizeVideo', keys=['lq'], scales=0.5), 67 | dict(type='Normalize', 68 | keys=['lq'], 69 | **img_norm_cfg_lq), 70 | dict(type='Normalize', 71 | keys=['gt'], 72 | 
**img_norm_cfg_gt), 73 | dict(type='FramesToTensor', keys=['lq', 'gt']), 74 | dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'lq_path', 'key']), 75 | ] 76 | 77 | data = dict( 78 | workers_per_gpu=6, 79 | train_dataloader=dict(samples_per_gpu=4, drop_last=True), 80 | val_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2), 81 | test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2), 82 | 83 | train=dict( 84 | type='RepeatDataset', 85 | times=10000, 86 | dataset=dict( 87 | type=train_dataset_type, 88 | lq_folder='data/REVIDE/REVIDE_indoor/Train/hazy', 89 | gt_folder='data/REVIDE/REVIDE_indoor/Train/gt', 90 | ann_file='data/REVIDE/REVIDE_indoor/Train/meta_info_GT.txt', 91 | num_input_frames=num_input_frames, 92 | pipeline=train_pipeline, 93 | img_extension='.JPG', 94 | scale=1, 95 | test_mode=False)), 96 | val=dict( 97 | type=test_dataset_type, 98 | lq_folder='data/REVIDE/REVIDE_indoor/Test/hazy', 99 | gt_folder='data/REVIDE/REVIDE_indoor/Test/gt', 100 | ann_file='data/REVIDE/REVIDE_indoor/Test/meta_info_GT.txt', 101 | pipeline=test_pipeline, 102 | img_extension='.JPG', 103 | scale=1, 104 | test_mode=True), 105 | test=dict( 106 | type=test_dataset_type, 107 | lq_folder='data/REVIDE/REVIDE_indoor/Test/hazy', 108 | gt_folder='data/REVIDE/REVIDE_indoor/Test/gt', 109 | ann_file='data/REVIDE/REVIDE_indoor/Test/meta_info_GT.txt', 110 | pipeline=test_pipeline, 111 | img_extension='.JPG', 112 | scale=1, 113 | test_mode=True) 114 | ) 115 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from mmcv.runner import Hook 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class EvalIterHook(Hook): 9 | """Non-Distributed evaluation hook for iteration-based runner. 10 | 11 | This hook will regularly perform evaluation in a given interval when 12 | performing in non-distributed environment. 13 | 14 | Args: 15 | dataloader (DataLoader): A PyTorch dataloader. 16 | interval (int): Evaluation interval. Default: 1. 17 | eval_kwargs (dict): Other eval kwargs. It contains: 18 | save_image (bool): Whether to save image. 19 | save_path (str): The path to save image. 20 | """ 21 | 22 | def __init__(self, dataloader, interval=1, **eval_kwargs): 23 | if not isinstance(dataloader, DataLoader): 24 | raise TypeError('dataloader must be a pytorch DataLoader, ' 25 | f'but got { type(dataloader)}') 26 | self.dataloader = dataloader 27 | self.interval = interval 28 | self.eval_kwargs = eval_kwargs 29 | self.save_image = self.eval_kwargs.pop('save_image', False) 30 | self.save_path = self.eval_kwargs.pop('save_path', None) 31 | 32 | def after_train_iter(self, runner): 33 | """The behavior after each train iteration. 34 | 35 | Args: 36 | runner (``mmcv.runner.BaseRunner``): The runner. 37 | """ 38 | if not self.every_n_iters(runner, self.interval): 39 | return 40 | runner.log_buffer.clear() 41 | from mmedit.apis import single_gpu_test 42 | results = single_gpu_test( 43 | runner.model, 44 | self.dataloader, 45 | save_image=self.save_image, 46 | save_path=self.save_path, 47 | iteration=runner.iter) 48 | self.evaluate(runner, results) 49 | 50 | def evaluate(self, runner, results): 51 | """Evaluation function. 52 | 53 | Args: 54 | runner (``mmcv.runner.BaseRunner``): The runner. 55 | results (dict): Model forward results. 
56 | """ 57 | eval_res = self.dataloader.dataset.evaluate( 58 | results, logger=runner.logger, **self.eval_kwargs) 59 | for name, val in eval_res.items(): 60 | if isinstance(val, dict): 61 | runner.log_buffer.output.update(val) 62 | continue 63 | runner.log_buffer.output[name] = val 64 | runner.log_buffer.ready = True 65 | # call `after_val_epoch` after evaluation. 66 | # This is a hack. 67 | # Because epoch does not naturally exist In IterBasedRunner, 68 | # thus we consider the end of an evluation as the end of an epoch. 69 | # With this hack , we can support epoch based hooks. 70 | if 'iter' in runner.__class__.__name__.lower(): 71 | runner.call_hook('after_val_epoch') 72 | 73 | 74 | class DistEvalIterHook(EvalIterHook): 75 | """Distributed evaluation hook. 76 | 77 | Args: 78 | dataloader (DataLoader): A PyTorch dataloader. 79 | interval (int): Evaluation interval. Default: 1. 80 | tmpdir (str | None): Temporary directory to save the results of all 81 | processes. Default: None. 82 | gpu_collect (bool): Whether to use gpu or cpu to collect results. 83 | Default: False. 84 | eval_kwargs (dict): Other eval kwargs. It may contain: 85 | save_image (bool): Whether save image. 86 | save_path (str): The path to save image. 87 | """ 88 | 89 | def __init__(self, 90 | dataloader, 91 | interval=1, 92 | gpu_collect=False, 93 | **eval_kwargs): 94 | super().__init__(dataloader, interval, **eval_kwargs) 95 | self.gpu_collect = gpu_collect 96 | 97 | def after_train_iter(self, runner): 98 | """The behavior after each train iteration. 99 | 100 | Args: 101 | runner (``mmcv.runner.BaseRunner``): The runner. 102 | """ 103 | if not self.every_n_iters(runner, self.interval): 104 | return 105 | runner.log_buffer.clear() 106 | from mmedit.apis import multi_gpu_test 107 | results = multi_gpu_test( 108 | runner.model, 109 | self.dataloader, 110 | tmpdir=osp.join(runner.work_dir, '.eval_hook'), 111 | gpu_collect=self.gpu_collect, 112 | save_image=self.save_image, 113 | save_path=self.save_path, 114 | iteration=runner.iter) 115 | if runner.rank == 0: 116 | print('\n') 117 | self.evaluate(runner, results) 118 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_folder_multiple_gt_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os 4 | import os.path as osp 5 | 6 | import mmcv 7 | 8 | from .base_sr_dataset import BaseSRDataset 9 | from .registry import DATASETS 10 | 11 | 12 | @DATASETS.register_module() 13 | class SRFolderMultipleGTDataset(BaseSRDataset): 14 | """General dataset for video super resolution, used for recurrent networks. 15 | 16 | The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) 17 | frames. Then it applies specified transforms and finally returns a dict 18 | containing paired data and other information. 19 | 20 | This dataset takes an annotation file specifying the sequences used in 21 | training or test. If no annotation file is provided, it assumes all video 22 | sequences under the root directory is used for training or test. 23 | 24 | In the annotation file (.txt), each line contains: 25 | 26 | 1. folder name; 27 | 2. number of frames in this sequence (in the same folder) 28 | 29 | Examples: 30 | 31 | :: 32 | 33 | calendar 41 34 | city 34 35 | foliage 49 36 | walk 47 37 | 38 | Args: 39 | lq_folder (str | :obj:`Path`): Path to a lq folder. 40 | gt_folder (str | :obj:`Path`): Path to a gt folder. 
41 | pipeline (list[dict | callable]): A sequence of data transformations. 42 | scale (int): Upsampling scale ratio. 43 | ann_file (str): The path to the annotation file. If None, we assume 44 | that all sequences in the folder is used. Default: None 45 | num_input_frames (None | int): The number of frames per iteration. 46 | If None, the whole clip is extracted. If it is a positive integer, 47 | a sequence of 'num_input_frames' frames is extracted from the clip. 48 | Note that non-positive integers are not accepted. Default: None. 49 | test_mode (bool): Store `True` when building test dataset. 50 | Default: `True`. 51 | """ 52 | 53 | def __init__(self, 54 | lq_folder, 55 | gt_folder, 56 | pipeline, 57 | scale, 58 | ann_file=None, 59 | num_input_frames=None, 60 | img_extension='.png', 61 | test_mode=True): 62 | super().__init__(pipeline, scale, test_mode) 63 | 64 | self.lq_folder = str(lq_folder) 65 | self.gt_folder = str(gt_folder) 66 | self.ann_file = ann_file 67 | 68 | if num_input_frames is not None and num_input_frames <= 0: 69 | raise ValueError('"num_input_frames" must be None or positive, ' 70 | f'but got {num_input_frames}.') 71 | self.num_input_frames = num_input_frames 72 | self.img_extension = img_extension 73 | 74 | self.data_infos = self.load_annotations() 75 | 76 | def _load_annotations_from_file(self): 77 | data_infos = [] 78 | 79 | ann_list = mmcv.list_from_file(self.ann_file) 80 | for ann in ann_list: 81 | key, sequence_length = ann.strip().split(' ') 82 | if self.num_input_frames is None: 83 | num_input_frames = sequence_length 84 | else: 85 | num_input_frames = self.num_input_frames 86 | data_infos.append( 87 | dict( 88 | lq_path=self.lq_folder, 89 | gt_path=self.gt_folder, 90 | key=key, 91 | num_input_frames=int(num_input_frames), 92 | sequence_length=int(sequence_length))) 93 | 94 | return data_infos 95 | 96 | def load_annotations(self): 97 | """Load annotations for the dataset. 98 | 99 | Returns: 100 | list[dict]: Returned list of dicts for paired paths of LQ and GT. 101 | """ 102 | 103 | if self.ann_file: 104 | return self._load_annotations_from_file() 105 | 106 | sequences = sorted(glob.glob(osp.join(self.lq_folder, '*'))) 107 | data_infos = [] 108 | for sequence in sequences: 109 | sequence_length = len(glob.glob(osp.join(sequence, f'*{self.img_extension}'))) 110 | if self.num_input_frames is None: 111 | num_input_frames = sequence_length 112 | else: 113 | num_input_frames = self.num_input_frames 114 | data_infos.append( 115 | dict( 116 | lq_path=self.lq_folder, 117 | gt_path=self.gt_folder, 118 | key=sequence.replace(f'{self.lq_folder}{os.sep}', ''), 119 | num_input_frames=num_input_frames, 120 | sequence_length=sequence_length)) 121 | 122 | return data_infos 123 | -------------------------------------------------------------------------------- /mmedit/models/common/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_requires_grad(nets, requires_grad=False): 7 | """Set requires_grad for all the networks. 8 | 9 | Args: 10 | nets (nn.Module | list[nn.Module]): A list of networks or a single 11 | network. 
12 | requires_grad (bool): Whether the networks require gradients or not 13 | """ 14 | if not isinstance(nets, list): 15 | nets = [nets] 16 | for net in nets: 17 | if net is not None: 18 | for param in net.parameters(): 19 | param.requires_grad = requires_grad 20 | 21 | 22 | def extract_bbox_patch(bbox, img, channel_first=True): 23 | """Extract patch from a given bbox. 24 | 25 | Args: 26 | bbox (torch.Tensor | numpy.array): Bbox with (top, left, h, w). If 27 | `img` has batch dimension, the `bbox` must be stacked at first 28 | dimension. The shape should be (4,) or (n, 4). 29 | img (torch.Tensor | numpy.array): Image data to be extracted. If 30 | organized in batch dimension, the batch dimension must be the first 31 | order like (n, h, w, c) or (n, c, h, w). 32 | channel_first (bool): If True, the channel dimension of img is before 33 | height and width, e.g. (c, h, w). Otherwise, the img shape (samples 34 | in the batch) is like (h, w, c). 35 | 36 | Returns: 37 | (torch.Tensor | numpy.array): Extracted patches. The dimension of the \ 38 | output should be the same as `img`. 39 | """ 40 | 41 | def _extract(bbox, img): 42 | assert len(bbox) == 4 43 | t, l, h, w = bbox 44 | if channel_first: 45 | img_patch = img[..., t:t + h, l:l + w] 46 | else: 47 | img_patch = img[t:t + h, l:l + w, ...] 48 | 49 | return img_patch 50 | 51 | input_size = img.shape 52 | assert len(input_size) == 3 or len(input_size) == 4 53 | bbox_size = bbox.shape 54 | assert bbox_size == (4, ) or (len(bbox_size) == 2 55 | and bbox_size[0] == input_size[0]) 56 | 57 | # images with batch dimension 58 | if len(input_size) == 4: 59 | output_list = [] 60 | for i in range(input_size[0]): 61 | img_patch_ = _extract(bbox[i], img[i:i + 1, ...]) 62 | output_list.append(img_patch_) 63 | if isinstance(img, torch.Tensor): 64 | img_patch = torch.cat(output_list, dim=0) 65 | else: 66 | img_patch = np.concatenate(output_list, axis=0) 67 | # standardize image 68 | else: 69 | img_patch = _extract(bbox, img) 70 | 71 | return img_patch 72 | 73 | 74 | def scale_bbox(bbox, target_size): 75 | """Modify bbox to target size. 76 | 77 | The original bbox will be enlarged to the target size with the original 78 | bbox in the center of the new bbox. 79 | 80 | Args: 81 | bbox (np.ndarray | torch.Tensor): Bboxes to be modified. Bbox can 82 | be in batch or not. The shape should be (4,) or (n, 4). 83 | target_size (tuple[int]): Target size of final bbox. 84 | 85 | Returns: 86 | (np.ndarray | torch.Tensor): Modified bboxes. 
87 | """ 88 | 89 | def _mod(bbox, target_size): 90 | top_ori, left_ori, h_ori, w_ori = bbox 91 | h, w = target_size 92 | assert h >= h_ori and w >= w_ori 93 | top = int(max(0, top_ori - (h - h_ori) // 2)) 94 | left = int(max(0, left_ori - (w - w_ori) // 2)) 95 | 96 | if isinstance(bbox, torch.Tensor): 97 | bbox_new = torch.Tensor([top, left, h, w]).type_as(bbox) 98 | else: 99 | bbox_new = np.asarray([top, left, h, w]) 100 | 101 | return bbox_new 102 | 103 | if isinstance(bbox, torch.Tensor): 104 | bbox_new = torch.zeros_like(bbox) 105 | elif isinstance(bbox, np.ndarray): 106 | bbox_new = np.zeros_like(bbox) 107 | else: 108 | raise TypeError('bbox mush be torch.Tensor or numpy.ndarray' 109 | f'but got type {type(bbox)}') 110 | bbox_shape = list(bbox.shape) 111 | 112 | if len(bbox_shape) == 2: 113 | for i in range(bbox_shape[0]): 114 | bbox_new[i, :] = _mod(bbox[i], target_size) 115 | else: 116 | bbox_new = _mod(bbox, target_size) 117 | 118 | return bbox_new 119 | 120 | 121 | def extract_around_bbox(img, bbox, target_size, channel_first=True): 122 | """Extract patches around the given bbox. 123 | 124 | Args: 125 | bbox (np.ndarray | torch.Tensor): Bboxes to be modified. Bbox can 126 | be in batch or not. 127 | target_size (List(int)): Target size of final bbox. 128 | 129 | Returns: 130 | (torch.Tensor | numpy.array): Extracted patches. The dimension of the \ 131 | output should be the same as `img`. 132 | """ 133 | bbox_new = scale_bbox(bbox, target_size) 134 | img_patch = extract_bbox_patch(bbox_new, img, channel_first=channel_first) 135 | 136 | return img_patch, bbox_new 137 | -------------------------------------------------------------------------------- /mmedit/apis/restoration_video_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os.path as osp 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | 9 | from mmedit.datasets.pipelines import Compose 10 | 11 | VIDEO_EXTENSIONS = ('.mp4', '.mov') 12 | 13 | 14 | def pad_sequence(data, window_size): 15 | padding = window_size // 2 16 | 17 | data = torch.cat([ 18 | data[:, 1 + padding:1 + 2 * padding].flip(1), data, 19 | data[:, -1 - 2 * padding:-1 - padding].flip(1) 20 | ], 21 | dim=1) 22 | 23 | return data 24 | 25 | 26 | def restoration_video_inference(model, 27 | img_dir, 28 | window_size, 29 | start_idx, 30 | filename_tmpl, 31 | max_seq_len=None): 32 | """Inference image with the model. 33 | 34 | Args: 35 | model (nn.Module): The loaded model. 36 | img_dir (str): Directory of the input video. 37 | window_size (int): The window size used in sliding-window framework. 38 | This value should be set according to the settings of the network. 39 | A value smaller than 0 means using recurrent framework. 40 | start_idx (int): The index corresponds to the first frame in the 41 | sequence. 42 | filename_tmpl (str): Template for file name. 43 | max_seq_len (int | None): The maximum sequence length that the model 44 | processes. If the sequence length is larger than this number, 45 | the sequence is split into multiple segments. If it is None, 46 | the entire sequence is processed at once. 47 | 48 | Returns: 49 | Tensor: The predicted restoration result. 
50 | """ 51 | 52 | device = next(model.parameters()).device # model device 53 | 54 | # build the data pipeline 55 | if model.cfg.get('demo_pipeline', None): 56 | test_pipeline = model.cfg.demo_pipeline 57 | elif model.cfg.get('test_pipeline', None): 58 | test_pipeline = model.cfg.test_pipeline 59 | else: 60 | test_pipeline = model.cfg.val_pipeline 61 | 62 | # check if the input is a video 63 | file_extension = osp.splitext(img_dir)[1] 64 | if file_extension in VIDEO_EXTENSIONS: 65 | video_reader = mmcv.VideoReader(img_dir) 66 | # load the images 67 | data = dict(lq=[], lq_path=None, key=img_dir) 68 | for frame in video_reader: 69 | data['lq'].append(np.flip(frame, axis=2)) 70 | 71 | # remove the data loading pipeline 72 | tmp_pipeline = [] 73 | for pipeline in test_pipeline: 74 | if pipeline['type'] not in [ 75 | 'GenerateSegmentIndices', 'LoadImageFromFileList' 76 | ]: 77 | tmp_pipeline.append(pipeline) 78 | test_pipeline = tmp_pipeline 79 | else: 80 | # the first element in the pipeline must be 'GenerateSegmentIndices' 81 | if test_pipeline[0]['type'] != 'GenerateSegmentIndices': 82 | raise TypeError('The first element in the pipeline must be ' 83 | f'"GenerateSegmentIndices", but got ' 84 | f'"{test_pipeline[0]["type"]}".') 85 | 86 | # specify start_idx and filename_tmpl 87 | test_pipeline[0]['start_idx'] = start_idx 88 | test_pipeline[0]['filename_tmpl'] = filename_tmpl 89 | 90 | # prepare data 91 | sequence_length = len(glob.glob(osp.join(img_dir, '*'))) 92 | lq_folder = osp.dirname(img_dir) 93 | key = osp.basename(img_dir) 94 | data = dict( 95 | lq_path=lq_folder, 96 | gt_path='', 97 | key=key, 98 | sequence_length=sequence_length) 99 | 100 | # compose the pipeline 101 | test_pipeline = Compose(test_pipeline) 102 | data = test_pipeline(data) 103 | data = data['lq'].unsqueeze(0) # in cpu 104 | 105 | # forward the model 106 | with torch.no_grad(): 107 | if window_size > 0: # sliding window framework 108 | data = pad_sequence(data, window_size) 109 | result = [] 110 | for i in range(0, data.size(1) - 2 * (window_size // 2)): 111 | data_i = data[:, i:i + window_size].to(device) 112 | result.append(model(lq=data_i, test_mode=True)['output'].cpu()) 113 | result = torch.stack(result, dim=1) 114 | else: # recurrent framework 115 | if max_seq_len is None: 116 | result = model( 117 | lq=data.to(device), test_mode=True)['output'].cpu() 118 | else: 119 | result = [] 120 | for i in range(0, data.size(1), max_seq_len): 121 | result.append( 122 | model( 123 | lq=data[:, i:i + max_seq_len].to(device), 124 | test_mode=True)['output'].cpu()) 125 | result = torch.cat(result, dim=1) 126 | return result 127 | -------------------------------------------------------------------------------- /configs/dehazers/_base_/datasets/hazeworld.py: -------------------------------------------------------------------------------- 1 | # https://github.com/open-mmlab/mmsegmentation/blob/master/configs/_base_/datasets/ade20k.py 2 | # https://github.com/open-mmlab/mmediting/blob/master/configs/restorers/basicvsr/basicvsr_reds4.py 3 | 4 | train_dataset_type = 'HWFolderMultipleGTDataset' 5 | test_dataset_type = 'HWFolderMultipleGTDataset' # dataset & video-level evaluation 6 | 7 | img_norm_cfg_lq = dict( 8 | # mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True, 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], to_rgb=True, 10 | ) 11 | img_norm_cfg_gt = dict( 12 | # mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True, 13 | mean=[0., 0., 0.], std=[1., 1., 1.], to_rgb=True, 14 
| ) 15 | crop_size = 256 16 | num_input_frames = 5 17 | 18 | io_backend = 'disk' 19 | load_kwargs = dict() 20 | 21 | train_pipeline = [ 22 | dict(type='GenerateFileIndices', 23 | interval_list=[1], 24 | annotation_tree_json='data/HazeWorld/train/meta_info_tree_GT_train.json'), 25 | dict(type='LoadImageFromFileList', 26 | io_backend=io_backend, 27 | key='lq', 28 | flag='unchanged', 29 | # channel_order='rgb', 30 | **load_kwargs), 31 | dict(type='LoadImageFromFileList', 32 | io_backend=io_backend, 33 | key='gt', 34 | flag='unchanged', 35 | **load_kwargs), 36 | # dict(type='LoadImageFromFileList', 37 | # io_backend=io_backend, 38 | # key='trans', 39 | # flag='unchanged', 40 | # **load_kwargs), 41 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']), 42 | # dict(type='RescaleToZeroOne', keys=['lq', 'gt', 'trans']), 43 | dict(type='Normalize', 44 | keys=['lq'], 45 | **img_norm_cfg_lq), 46 | dict(type='Normalize', 47 | keys=['gt'], 48 | **img_norm_cfg_gt), 49 | dict(type='PairedRandomCrop', gt_patch_size=crop_size), 50 | # dict(type='PairedRandomCropWithTransmission', gt_patch_size=crop_size), 51 | dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, 52 | direction='horizontal'), 53 | # dict(type='Flip', keys=['lq', 'gt', 'trans'], flip_ratio=0.5, 54 | # direction='horizontal'), 55 | # dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'), 56 | # dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5), 57 | dict(type='FramesToTensor', keys=['lq', 'gt']), 58 | # dict(type='FramesToTensor', keys=['lq', 'gt', 'trans']), 59 | # dict(type='ToTensor', keys=['haze_light']) 60 | dict(type='Collect', 61 | keys=['lq', 'gt'], 62 | # keys=['lq', 'gt', 'trans', 'haze_light'], 63 | meta_keys=['lq_path', 'gt_path', 'dataset', 'folder', 'haze_beta', 'haze_light']), 64 | ] 65 | test_pipeline = [ 66 | dict(type='GenerateFileIndices', 67 | interval_list=[1], 68 | annotation_tree_json='data/HazeWorld/test/meta_info_tree_GT_test.json'), 69 | dict(type='LoadImageFromFileList', 70 | io_backend=io_backend, 71 | key='lq', 72 | flag='unchanged', 73 | # channel_order='rgb' 74 | **load_kwargs), 75 | dict(type='LoadImageFromFileList', 76 | io_backend=io_backend, 77 | key='gt', 78 | flag='unchanged', 79 | **load_kwargs), 80 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']), 81 | dict(type='Normalize', 82 | keys=['lq'], 83 | **img_norm_cfg_lq), 84 | dict(type='Normalize', 85 | keys=['gt'], 86 | **img_norm_cfg_gt), 87 | dict(type='FramesToTensor', keys=['lq', 'gt']), 88 | dict(type='Collect', 89 | keys=['lq', 'gt'], 90 | meta_keys=['lq_path', 'gt_path', 'dataset', 'folder', 'haze_beta', 'haze_light']), 91 | ] 92 | 93 | data = dict( 94 | workers_per_gpu=6, 95 | train_dataloader=dict(samples_per_gpu=4, drop_last=True), 96 | val_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2), 97 | test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=2), 98 | 99 | train=dict( 100 | type='RepeatDataset', 101 | times=10000, 102 | dataset=dict( 103 | type=train_dataset_type, 104 | lq_folder='data/HazeWorld/train/hazy', 105 | gt_folder='data/HazeWorld/train/gt', 106 | # trans_folder='data/HazeWorld/train/transmission', 107 | ann_file='data/HazeWorld/train/meta_info_GT_train.txt', 108 | num_input_frames=num_input_frames, 109 | pipeline=train_pipeline, 110 | test_mode=False)), 111 | val=dict( 112 | type=test_dataset_type, 113 | lq_folder='data/HazeWorld/test/hazy', 114 | gt_folder='data/HazeWorld/test/gt', 115 | ann_file='data/HazeWorld/test/meta_info_GT_test.txt', 116 | pipeline=test_pipeline, 117 | 
test_mode=False), 118 | test=dict( 119 | type=test_dataset_type, 120 | lq_folder='data/HazeWorld/test/hazy', 121 | gt_folder='data/HazeWorld/test/gt', 122 | ann_file='data/HazeWorld/test/meta_info_GT_test.txt', 123 | pipeline=test_pipeline, 124 | test_mode=True) 125 | ) 126 | -------------------------------------------------------------------------------- /mmedit/core/hooks/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | from copy import deepcopy 4 | from functools import partial 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.parallel import is_module_wrapper 9 | from mmcv.runner import HOOKS, Hook 10 | 11 | 12 | @HOOKS.register_module() 13 | class ExponentialMovingAverageHook(Hook): 14 | """Exponential Moving Average Hook. 15 | 16 | Exponential moving average is a trick that is widely used in the current GAN 17 | literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to 18 | maintain a model with the same architecture whose parameters are 19 | updated as a moving average of the trained weights in the original model. 20 | In general, the model with moving-averaged weights achieves better 21 | performance. 22 | 23 | Args: 24 | module_keys (str | tuple[str]): The name of the ema model. Note that these 25 | keys must end with '_ema' so that we can easily 26 | find the original model by discarding the last four characters. 27 | interp_mode (str, optional): Mode of the interpolation method. 28 | Defaults to 'lerp'. 29 | interp_cfg (dict | None, optional): Arguments for the interpolation 30 | function. Defaults to None. 31 | interval (int, optional): Interval (in iterations) between EMA updates. 32 | Default: -1. 33 | start_iter (int, optional): Start iteration for ema. If the start 34 | iteration is not reached, the weights of the ema model stay 35 | the same as those of the original one. Otherwise, its parameters are updated 36 | as a moving average of the trained weights in the original model. 37 | Default: 0. 38 | """ 39 | 40 | def __init__(self, 41 | module_keys, 42 | interp_mode='lerp', 43 | interp_cfg=None, 44 | interval=-1, 45 | start_iter=0): 46 | super().__init__() 47 | assert isinstance(module_keys, str) or mmcv.is_tuple_of( 48 | module_keys, str) 49 | self.module_keys = (module_keys, ) if isinstance(module_keys, 50 | str) else module_keys 51 | # sanity check for the format of module keys 52 | for k in self.module_keys: 53 | assert k.endswith( 54 | '_ema'), 'You should give keys that end with "_ema".' 55 | self.interp_mode = interp_mode 56 | self.interp_cfg = dict() if interp_cfg is None else deepcopy( 57 | interp_cfg) 58 | self.interval = interval 59 | self.start_iter = start_iter 60 | 61 | assert hasattr( 62 | self, interp_mode 63 | ), f'Currently, we do not support {self.interp_mode} for EMA.'
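# Note: with the default interp_mode='lerp', the update performed in
# after_train_iter is new_ema = orig + (ema - orig) * momentum, i.e. the
# EMA weights track the trained weights with momentum=0.999 by default
# (see the `lerp` staticmethod below).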
64 | self.interp_func = partial( 65 | getattr(self, interp_mode), **self.interp_cfg) 66 | 67 | @staticmethod 68 | def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True): 69 | m = momentum if trainable else momentum_nontrainable 70 | return a + (b - a) * m 71 | 72 | def every_n_iters(self, runner, n): 73 | if runner.iter < self.start_iter: 74 | return True 75 | return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False 76 | 77 | @torch.no_grad() 78 | def after_train_iter(self, runner): 79 | if not self.every_n_iters(runner, self.interval): 80 | return 81 | 82 | model = runner.model.module if is_module_wrapper( 83 | runner.model) else runner.model 84 | 85 | for key in self.module_keys: 86 | # get current ema states 87 | ema_net = getattr(model, key) 88 | states_ema = ema_net.state_dict(keep_vars=False) 89 | # get currently original states 90 | net = getattr(model, key[:-4]) 91 | states_orig = net.state_dict(keep_vars=True) 92 | 93 | for k, v in states_orig.items(): 94 | if runner.iter < self.start_iter: 95 | states_ema[k].data.copy_(v.data) 96 | else: 97 | states_ema[k] = self.interp_func( 98 | v, states_ema[k], trainable=v.requires_grad).detach() 99 | ema_net.load_state_dict(states_ema, strict=True) 100 | 101 | def before_run(self, runner): 102 | model = runner.model.module if is_module_wrapper( 103 | runner.model) else runner.model 104 | # sanity check for ema model 105 | for k in self.module_keys: 106 | if not hasattr(model, k) and not hasattr(model, k[:-4]): 107 | raise RuntimeError( 108 | f'Cannot find both {k[:-4]} and {k} network for EMA hook.') 109 | if not hasattr(model, k) and hasattr(model, k[:-4]): 110 | setattr(model, k, deepcopy(getattr(model, k[:-4]))) 111 | warnings.warn( 112 | f'We do not suggest construct and initialize EMA model {k}' 113 | ' in hook. You may explicitly define it by yourself.') 114 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/random_down_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv import imresize 7 | 8 | from ..registry import PIPELINES 9 | 10 | 11 | @PIPELINES.register_module() 12 | class RandomDownSampling: 13 | """Generate LQ image from GT (and crop), which will randomly pick a scale. 14 | 15 | Args: 16 | scale_min (float): The minimum of upsampling scale, inclusive. 17 | Default: 1.0. 18 | scale_max (float): The maximum of upsampling scale, exclusive. 19 | Default: 4.0. 20 | patch_size (int): The cropped lr patch size. 21 | Default: None, means no crop. 22 | interpolation (str): Interpolation method, accepted values are 23 | "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' 24 | backend, "nearest", "bilinear", "bicubic", "box", "lanczos", 25 | "hamming" for 'pillow' backend. 26 | Default: "bicubic". 27 | backend (str | None): The image resize backend type. Options are `cv2`, 28 | `pillow`, `None`. If backend is None, the global imread_backend 29 | specified by ``mmcv.use_backend()`` will be used. 30 | Default: "pillow". 31 | 32 | Scale will be picked in the range of [scale_min, scale_max). 
33 | """ 34 | 35 | def __init__(self, 36 | scale_min=1.0, 37 | scale_max=4.0, 38 | patch_size=None, 39 | interpolation='bicubic', 40 | backend='pillow'): 41 | assert scale_max >= scale_min 42 | self.scale_min = scale_min 43 | self.scale_max = scale_max 44 | self.patch_size = patch_size 45 | self.interpolation = interpolation 46 | self.backend = backend 47 | 48 | def __call__(self, results): 49 | """Call function. 50 | 51 | Args: 52 | results (dict): A dict containing the necessary information and 53 | data for augmentation. 'gt' is required. 54 | 55 | Returns: 56 | dict: A dict containing the processed data and information. 57 | modified 'gt', supplement 'lq' and 'scale' to keys. 58 | """ 59 | img = results['gt'] 60 | scale = np.random.uniform(self.scale_min, self.scale_max) 61 | 62 | if self.patch_size is None: 63 | h_lr = math.floor(img.shape[-3] / scale + 1e-9) 64 | w_lr = math.floor(img.shape[-2] / scale + 1e-9) 65 | img = img[:round(h_lr * scale), :round(w_lr * scale), :] 66 | img_down = resize_fn(img, (w_lr, h_lr), self.interpolation, 67 | self.backend) 68 | crop_lr, crop_hr = img_down, img 69 | else: 70 | w_lr = self.patch_size 71 | w_hr = round(w_lr * scale) 72 | x0 = np.random.randint(0, img.shape[-3] - w_hr) 73 | y0 = np.random.randint(0, img.shape[-2] - w_hr) 74 | crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :] 75 | crop_lr = resize_fn(crop_hr, w_lr, self.interpolation, 76 | self.backend) 77 | results['gt'] = crop_hr 78 | results['lq'] = crop_lr 79 | results['scale'] = scale 80 | 81 | return results 82 | 83 | def __repr__(self): 84 | repr_str = self.__class__.__name__ 85 | repr_str += (f' scale_min={self.scale_min}, ' 86 | f'scale_max={self.scale_max}, ' 87 | f'patch_size={self.patch_size}, ' 88 | f'interpolation={self.interpolation}, ' 89 | f'backend={self.backend}') 90 | 91 | return repr_str 92 | 93 | 94 | def resize_fn(img, size, interpolation='bicubic', backend='pillow'): 95 | """Resize the given image to a given size. 96 | 97 | Args: 98 | img (ndarray | torch.Tensor): The input image. 99 | size (int | tuple[int]): Target size w or (w, h). 100 | interpolation (str): Interpolation method, accepted values are 101 | "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' 102 | backend, "nearest", "bilinear", "bicubic", "box", "lanczos", 103 | "hamming" for 'pillow' backend. 104 | Default: "bicubic". 105 | backend (str | None): The image resize backend type. Options are `cv2`, 106 | `pillow`, `None`. If backend is None, the global imread_backend 107 | specified by ``mmcv.use_backend()`` will be used. 108 | Default: "pillow". 109 | 110 | Returns: 111 | ndarray | torch.Tensor: `resized_img`, whose type is same as `img`. 112 | """ 113 | if isinstance(size, int): 114 | size = (size, size) 115 | if isinstance(img, np.ndarray): 116 | return imresize( 117 | img, size, interpolation=interpolation, backend=backend) 118 | elif isinstance(img, torch.Tensor): 119 | image = imresize( 120 | img.numpy(), size, interpolation=interpolation, backend=backend) 121 | return torch.from_numpy(image) 122 | 123 | else: 124 | raise TypeError('img should got np.ndarray or torch.Tensor,' 125 | f'but got {type(img)}') 126 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
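# Utility helpers for the data pipelines: intensity-range utilities adopted
# from skimage (dtype_limits, adjust_gamma), random selection of a crop center
# inside an unknown mask (random_choose_unknown), and grid-center coordinate
# generation (make_coord).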
2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.utils import print_log 7 | 8 | _integer_types = ( 9 | np.byte, 10 | np.ubyte, # 8 bits 11 | np.short, 12 | np.ushort, # 16 bits 13 | np.intc, 14 | np.uintc, # 16 or 32 or 64 bits 15 | np.int_, 16 | np.uint, # 32 or 64 bits 17 | np.longlong, 18 | np.ulonglong) # 64 bits 19 | 20 | _integer_ranges = { 21 | t: (np.iinfo(t).min, np.iinfo(t).max) 22 | for t in _integer_types 23 | } 24 | 25 | dtype_range = { 26 | np.bool_: (False, True), 27 | np.bool8: (False, True), 28 | np.float16: (-1, 1), 29 | np.float32: (-1, 1), 30 | np.float64: (-1, 1) 31 | } 32 | dtype_range.update(_integer_ranges) 33 | 34 | 35 | def dtype_limits(image, clip_negative=False): 36 | """Return intensity limits, i.e. (min, max) tuple, of the image's dtype. 37 | 38 | This function is adopted from skimage: 39 | https://github.com/scikit-image/scikit-image/blob/ 40 | 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/util/dtype.py#L35 41 | 42 | Args: 43 | image (ndarray): Input image. 44 | clip_negative (bool, optional): If True, clip the negative range 45 | (i.e. return 0 for min intensity) even if the image dtype allows 46 | negative values. 47 | 48 | Returns 49 | tuple: Lower and upper intensity limits. 50 | """ 51 | imin, imax = dtype_range[image.dtype.type] 52 | if clip_negative: 53 | imin = 0 54 | return imin, imax 55 | 56 | 57 | def adjust_gamma(image, gamma=1, gain=1): 58 | """Performs Gamma Correction on the input image. 59 | 60 | This function is adopted from skimage: 61 | https://github.com/scikit-image/scikit-image/blob/ 62 | 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/exposure/ 63 | exposure.py#L439-L494 64 | 65 | Also known as Power Law Transform. 66 | This function transforms the input image pixelwise according to the 67 | equation ``O = I**gamma`` after scaling each pixel to the range 0 to 1. 68 | 69 | Args: 70 | image (ndarray): Input image. 71 | gamma (float, optional): Non negative real number. Defaults to 1. 72 | gain (float, optional): The constant multiplier. Defaults to 1. 73 | 74 | Returns: 75 | ndarray: Gamma corrected output image. 76 | """ 77 | if np.any(image < 0): 78 | raise ValueError('Image Correction methods work correctly only on ' 79 | 'images with non-negative values. Use ' 80 | 'skimage.exposure.rescale_intensity.') 81 | 82 | dtype = image.dtype.type 83 | 84 | if gamma < 0: 85 | raise ValueError('Gamma should be a non-negative real number.') 86 | 87 | scale = float(dtype_limits(image, True)[1] - dtype_limits(image, True)[0]) 88 | 89 | out = ((image / scale)**gamma) * scale * gain 90 | return out.astype(dtype) 91 | 92 | 93 | def random_choose_unknown(unknown, crop_size): 94 | """Randomly choose an unknown start (top-left) point for a given crop_size. 95 | 96 | Args: 97 | unknown (np.ndarray): The binary unknown mask. 98 | crop_size (tuple[int]): The given crop size. 99 | 100 | Returns: 101 | tuple[int]: The top-left point of the chosen bbox. 
102 | """ 103 | h, w = unknown.shape 104 | crop_h, crop_w = crop_size 105 | delta_h = center_h = crop_h // 2 106 | delta_w = center_w = crop_w // 2 107 | 108 | # mask out the validate area for selecting the cropping center 109 | mask = np.zeros_like(unknown) 110 | mask[delta_h:h - delta_h, delta_w:w - delta_w] = 1 111 | if np.any(unknown & mask): 112 | center_h_list, center_w_list = np.where(unknown & mask) 113 | elif np.any(unknown): 114 | center_h_list, center_w_list = np.where(unknown) 115 | else: 116 | print_log('No unknown pixels found!', level=logging.WARNING) 117 | center_h_list = [center_h] 118 | center_w_list = [center_w] 119 | num_unknowns = len(center_h_list) 120 | rand_ind = np.random.randint(num_unknowns) 121 | center_h = center_h_list[rand_ind] 122 | center_w = center_w_list[rand_ind] 123 | 124 | # make sure the top-left point is valid 125 | top = np.clip(center_h - delta_h, 0, h - crop_h) 126 | left = np.clip(center_w - delta_w, 0, w - crop_w) 127 | 128 | return top, left 129 | 130 | 131 | def make_coord(shape, ranges=None, flatten=True): 132 | """Make coordinates at grid centers. 133 | 134 | Args: 135 | shape (tuple): shape of image. 136 | ranges (tuple): range of coordinate value. Default: None. 137 | flatten (bool): flatten to (n, 2) or Not. Default: True. 138 | 139 | return: 140 | coord (Tensor): coordinates. 141 | """ 142 | coord_seqs = [] 143 | for i, n in enumerate(shape): 144 | if ranges is None: 145 | v0, v1 = -1, 1 146 | else: 147 | v0, v1 = ranges[i] 148 | r = (v1 - v0) / (2 * n) 149 | seq = v0 + r + (2 * r) * torch.arange(n).float() 150 | coord_seqs.append(seq) 151 | coord = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 152 | if flatten: 153 | coord = coord.view(-1, coord.shape[-1]) 154 | return coord 155 | -------------------------------------------------------------------------------- /mmedit/core/export/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
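# ONNX Runtime wrappers: `inference_with_session` runs a session through
# io_binding, `ONNXRuntimeMattor` / `ONNXRuntimeRestorer` reuse the evaluation
# and image-saving logic of the corresponding base models, and
# `ONNXRuntimeEditing` builds the appropriate wrapper from the model config.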
2 | import os.path as osp 3 | import warnings 4 | 5 | import numpy as np 6 | import onnxruntime as ort 7 | import torch 8 | from torch import nn 9 | 10 | from mmedit.models import BaseMattor, BasicRestorer, build_model 11 | 12 | 13 | def inference_with_session(sess, io_binding, output_names, input_tensor): 14 | device_type = input_tensor.device.type 15 | device_id = input_tensor.device.index 16 | device_id = 0 if device_id is None else device_id 17 | io_binding.bind_input( 18 | name='input', 19 | device_type=device_type, 20 | device_id=device_id, 21 | element_type=np.float32, 22 | shape=input_tensor.shape, 23 | buffer_ptr=input_tensor.data_ptr()) 24 | for name in output_names: 25 | io_binding.bind_output(name) 26 | sess.run_with_iobinding(io_binding) 27 | pred = io_binding.copy_outputs_to_cpu() 28 | return pred 29 | 30 | 31 | class ONNXRuntimeMattor(nn.Module): 32 | 33 | def __init__(self, sess, io_binding, output_names, base_model): 34 | super(ONNXRuntimeMattor, self).__init__() 35 | self.sess = sess 36 | self.io_binding = io_binding 37 | self.output_names = output_names 38 | self.base_model = base_model 39 | 40 | def forward(self, 41 | merged, 42 | trimap, 43 | meta, 44 | test_mode=False, 45 | save_image=False, 46 | save_path=None, 47 | iteration=None): 48 | input_tensor = torch.cat((merged, trimap), 1).contiguous() 49 | pred_alpha = inference_with_session(self.sess, self.io_binding, 50 | self.output_names, input_tensor)[0] 51 | 52 | pred_alpha = pred_alpha.squeeze() 53 | pred_alpha = self.base_model.restore_shape(pred_alpha, meta) 54 | eval_result = self.base_model.evaluate(pred_alpha, meta) 55 | 56 | if save_image: 57 | self.base_model.save_image(pred_alpha, meta, save_path, iteration) 58 | 59 | return {'pred_alpha': pred_alpha, 'eval_result': eval_result} 60 | 61 | 62 | class RestorerGenerator(nn.Module): 63 | 64 | def __init__(self, sess, io_binding, output_names): 65 | super(RestorerGenerator, self).__init__() 66 | self.sess = sess 67 | self.io_binding = io_binding 68 | self.output_names = output_names 69 | 70 | def forward(self, x): 71 | pred = inference_with_session(self.sess, self.io_binding, 72 | self.output_names, x)[0] 73 | pred = torch.from_numpy(pred) 74 | return pred 75 | 76 | 77 | class ONNXRuntimeRestorer(nn.Module): 78 | 79 | def __init__(self, sess, io_binding, output_names, base_model): 80 | super(ONNXRuntimeRestorer, self).__init__() 81 | self.sess = sess 82 | self.io_binding = io_binding 83 | self.output_names = output_names 84 | self.base_model = base_model 85 | restorer_generator = RestorerGenerator(self.sess, self.io_binding, 86 | self.output_names) 87 | base_model.generator = restorer_generator 88 | 89 | def forward(self, lq, gt=None, test_mode=False, **kwargs): 90 | return self.base_model(lq, gt=gt, test_mode=test_mode, **kwargs) 91 | 92 | 93 | class ONNXRuntimeEditing(nn.Module): 94 | 95 | def __init__(self, onnx_file, cfg, device_id): 96 | super(ONNXRuntimeEditing, self).__init__() 97 | ort_custom_op_path = '' 98 | try: 99 | from mmcv.ops import get_onnxruntime_op_path 100 | ort_custom_op_path = get_onnxruntime_op_path() 101 | except (ImportError, ModuleNotFoundError): 102 | warnings.warn('If input model has custom op from mmcv, \ 103 | you may have to build mmcv with ONNXRuntime from source.') 104 | session_options = ort.SessionOptions() 105 | # register custom op for onnxruntime 106 | if osp.exists(ort_custom_op_path): 107 | session_options.register_custom_ops_library(ort_custom_op_path) 108 | sess = ort.InferenceSession(onnx_file, session_options) 109 | 
providers = ['CPUExecutionProvider'] 110 | options = [{}] 111 | is_cuda_available = ort.get_device() == 'GPU' 112 | if is_cuda_available: 113 | providers.insert(0, 'CUDAExecutionProvider') 114 | options.insert(0, {'device_id': device_id}) 115 | 116 | sess.set_providers(providers, options) 117 | 118 | self.sess = sess 119 | self.device_id = device_id 120 | self.io_binding = sess.io_binding() 121 | self.output_names = [_.name for _ in sess.get_outputs()] 122 | 123 | base_model = build_model( 124 | cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) 125 | 126 | if isinstance(base_model, BaseMattor): 127 | WrapperClass = ONNXRuntimeMattor 128 | elif isinstance(base_model, BasicRestorer): 129 | WrapperClass = ONNXRuntimeRestorer 130 | self.wrapper = WrapperClass(self.sess, self.io_binding, 131 | self.output_names, base_model) 132 | 133 | def forward(self, **kwargs): 134 | return self.wrapper(**kwargs) 135 | -------------------------------------------------------------------------------- /mmedit/models/backbones/map_backbones/map_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | # from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/utils/shape_convert.py 9 | def nlc_to_nchw(x, hw_shape): 10 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 11 | 12 | Args: 13 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 14 | hw_shape (Sequence[int]): The height and width of output feature map. 15 | 16 | Returns: 17 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 18 | """ 19 | H, W = hw_shape 20 | assert len(x.shape) == 3 21 | B, L, C = x.shape 22 | assert L == H * W, 'The seq_len doesn\'t match H, W' 23 | return x.transpose(1, 2).reshape(B, C, H, W) 24 | 25 | 26 | def nchw_to_nlc(x): 27 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 28 | 29 | Args: 30 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 31 | 32 | Returns: 33 | Tensor: The output tensor of shape [N, L, C] after conversion. 34 | """ 35 | assert len(x.shape) == 4 36 | return x.flatten(2).transpose(1, 2).contiguous() 37 | 38 | 39 | def flow_warp_5d(x, 40 | flow, 41 | interpolation='bilinear', 42 | padding_mode='zeros', 43 | align_corners=True): 44 | """Modified from mmedit.models.utils.flow_warp 45 | Warp a stack of image or a feature map with flow. 46 | 47 | Args: 48 | x (Tensor): Tensor with size (n, c, d, h, w). 49 | flow (Tensor): Tensor with size (n, d, h, w, 3). The last dimension is 50 | a three-channel, denoting the width, height and z relative offsets. 51 | Note that the w, h values are not normalized to [-1, 1], 52 | and z values are normalized to [0, d]. 53 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 54 | Default: 'bilinear'. 55 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 56 | Default: 'zeros'. 57 | align_corners (bool): Whether align corners. Default: True. 58 | 59 | Returns: 60 | Tensor: Warped image or feature map. 61 | """ 62 | if x.size()[-2:] != flow.size()[2:4]: 63 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 64 | f'flow ({flow.size()[2:4]}) are not the same.') 65 | _, _, d, h, w = x.size() 66 | # create mesh grid 67 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) 68 | # TODO: assume reference point is 0.5, ... 
69 | # TODO: make it consistent with stda layer 70 | grid_z = d * 0.5 * torch.ones((h, w)) 71 | grid = torch.stack((grid_x, grid_y, grid_z), 2).type_as(x) # (h, w, 3) 72 | grid.requires_grad = False 73 | 74 | grid_flow = grid + flow 75 | # scale grid_flow to [-1,1] 76 | grid_flow_x = 2.0 * grid_flow[:, :, :, :, 0] / max(w - 1, 1) - 1.0 77 | grid_flow_y = 2.0 * grid_flow[:, :, :, :, 1] / max(h - 1, 1) - 1.0 78 | grid_flow_z = 2.0 * grid_flow[:, :, :, :, 2] / d - 1.0 79 | grid_flow = torch.stack((grid_flow_x, grid_flow_y, grid_flow_z), dim=4) 80 | output = F.grid_sample( 81 | x, 82 | grid_flow, 83 | mode=interpolation, 84 | padding_mode=padding_mode, 85 | align_corners=align_corners) 86 | return output 87 | 88 | 89 | def get_flow_from_grid(grid, ref, d=None): 90 | """ 91 | convert sampling grids [0, 1] to pixels 92 | """ 93 | _, h, w, _ = grid.shape 94 | flow = grid - ref # denormalize to flow in pixel 95 | flow[:, :, :, 0] *= h 96 | flow[:, :, :, 1] *= w 97 | if grid.shape[-1] == 3: 98 | flow[:, :, :, 2] *= d 99 | 100 | return flow 101 | 102 | 103 | def get_discrete_values(num_bins, start=0, end=1): 104 | """ 105 | get discrete values given predefined num_bins and range (start, end) 106 | """ 107 | values = torch.linspace( 108 | start, end, num_bins, requires_grad=False) 109 | return values 110 | 111 | 112 | # https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/ops/wrappers.py 113 | def resize(input, 114 | size=None, 115 | scale_factor=None, 116 | mode='nearest', 117 | align_corners=None, 118 | warning=True): 119 | if warning: 120 | if size is not None and align_corners: 121 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 122 | output_h, output_w = tuple(int(x) for x in size) 123 | if output_h > input_h or output_w > input_w: 124 | if ((output_h > 1 and output_w > 1 and input_h > 1 125 | and input_w > 1) and (output_h - 1) % (input_h - 1) 126 | and (output_w - 1) % (input_w - 1)): 127 | warnings.warn( 128 | f'When align_corners={align_corners}, ' 129 | 'the output would more aligned if ' 130 | f'input size {(input_h, input_w)} is `x+1` and ' 131 | f'out size {(output_h, output_w)} is `nx+1`') 132 | return F.interpolate(input, size, scale_factor, mode, align_corners) 133 | -------------------------------------------------------------------------------- /mmedit/models/backbones/map_backbones/map_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import ConvModule 3 | from mmcv.runner import BaseModule 4 | from mmedit.models.common import PixelShufflePack, ResidualBlockNoBN, make_layer 5 | from mmedit.models.registry import COMPONENTS 6 | 7 | 8 | @COMPONENTS.register_module() 9 | class ProjectionHead(BaseModule): 10 | """Projection head. 
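    Each selected input feature map is projected to ``out_channels`` with a
    1x1 ConvModule (a lateral convolution), and the projected features are
    returned as a tuple.

    Args:
        in_channels (list[int]): Number of channels of each input feature map.
        out_channels (int): Number of output channels of every lateral conv.
        num_outs (int): Number of output feature levels.
        start_level (int): Index of the first input feature map to project.
            Default: 0.
        end_level (int): Exclusive end index of the projected inputs; -1 means
            all inputs after ``start_level`` are used. Default: -1.
        no_norm_on_lateral (bool): Whether to skip the norm layer in the
            lateral convs. Default: False.
        conv_cfg (dict | None): Config dict for convolution layers.
        norm_cfg (dict | None): Config dict for normalization layers.
        act_cfg (dict | None): Config dict for activation layers.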
11 | """ 12 | 13 | def __init__(self, 14 | in_channels, 15 | out_channels, 16 | num_outs, 17 | start_level=0, 18 | end_level=-1, 19 | no_norm_on_lateral=False, 20 | conv_cfg=None, 21 | norm_cfg=None, 22 | act_cfg=None): 23 | super().__init__() 24 | assert isinstance(in_channels, list) 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.num_ins = len(in_channels) 28 | self.num_outs = num_outs 29 | self.no_norm_on_lateral = no_norm_on_lateral 30 | 31 | if end_level == -1: 32 | self.backbone_end_level = self.num_ins 33 | assert num_outs >= self.num_ins - start_level 34 | else: 35 | # if end_level < inputs, no extra level is allowed 36 | self.backbone_end_level = end_level 37 | assert end_level <= len(in_channels) 38 | assert num_outs == end_level - start_level 39 | self.start_level = start_level 40 | self.end_level = end_level 41 | 42 | self.lateral_convs = nn.ModuleList() 43 | 44 | for i in range(self.start_level, self.backbone_end_level): 45 | l_conv = ConvModule( 46 | in_channels[i], 47 | out_channels, 48 | 1, 49 | conv_cfg=conv_cfg, 50 | norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, 51 | act_cfg=act_cfg, 52 | inplace=False) 53 | 54 | self.lateral_convs.append(l_conv) 55 | 56 | def forward(self, inputs): 57 | assert len(inputs) == len(self.in_channels) 58 | 59 | # build laterals 60 | outs = [ 61 | lateral_conv(inputs[i + self.start_level]) 62 | for i, lateral_conv in enumerate(self.lateral_convs) 63 | ] 64 | 65 | return tuple(outs) 66 | 67 | 68 | @COMPONENTS.register_module() 69 | class MAPUpsampler(BaseModule): 70 | """The upsampler for MAP. 71 | https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py 72 | https://github.com/open-mmlab/mmediting/blob/master/mmedit/models/backbones/sr_backbones/basicvsr_net.py 73 | 74 | Args: 75 | embed_dim: feature dim from decoder 76 | num_feat: intermediate feature dim 77 | """ 78 | 79 | def __init__(self, 80 | upscale=4, 81 | embed_dim=64, 82 | num_feat=64, 83 | num_out_ch=3): 84 | super().__init__() 85 | assert upscale == 4 86 | 87 | self.conv_before_upsample = nn.Conv2d(embed_dim, num_feat, 3, 1, 1) 88 | self.upsample1 = PixelShufflePack( 89 | num_feat, num_feat, 2, upsample_kernel=3) 90 | self.upsample2 = PixelShufflePack( 91 | num_feat, num_feat, 2, upsample_kernel=3) 92 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 93 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 94 | 95 | nn.init.constant_(self.conv_last.weight, 0) 96 | nn.init.constant_(self.conv_last.bias, 0) 97 | 98 | # activation function 99 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 100 | 101 | def forward(self, x): 102 | hr = self.lrelu(self.conv_before_upsample(x)) 103 | hr = self.lrelu(self.upsample1(hr)) 104 | hr = self.lrelu(self.upsample2(hr)) 105 | hr = self.lrelu(self.conv_hr(hr)) 106 | hr = self.conv_last(hr) 107 | 108 | return hr 109 | 110 | 111 | class ResidualBlocksWithInputConv(nn.Module): 112 | """Residual blocks with a convolution in front. 113 | 114 | Args: 115 | in_channels (int): Number of input channels of the first conv. 116 | out_channels (int): Number of channels of the residual blocks. 117 | Default: 64. 118 | num_blocks (int): Number of residual blocks. Default: 30. 
119 | """ 120 | 121 | def __init__(self, in_channels, out_channels=64, num_blocks=30): 122 | super().__init__() 123 | 124 | main = [] 125 | 126 | # a convolution used to match the channels of the residual blocks 127 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True)) 128 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) 129 | 130 | # residual blocks 131 | main.append( 132 | make_layer( 133 | ResidualBlockNoBN, num_blocks, mid_channels=out_channels)) 134 | 135 | self.main = nn.Sequential(*main) 136 | 137 | def forward(self, feat): 138 | """Forward function for ResidualBlocksWithInputConv. 139 | 140 | Args: 141 | feat (Tensor): Input feature with shape (n, in_channels, h, w) 142 | 143 | Returns: 144 | Tensor: Output feature with shape (n, out_channels, h, w) 145 | """ 146 | return self.main(feat) 147 | -------------------------------------------------------------------------------- /mmedit/core/evaluation/inceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | from scipy import linalg 5 | 6 | from ..registry import METRICS 7 | from .inception_utils import load_inception 8 | 9 | 10 | class InceptionV3: 11 | """Feature extractor features using InceptionV3 model. 12 | 13 | Args: 14 | style (str): The model style to run Inception model. it must be either 15 | 'StyleGAN' or 'pytorch'. 16 | device (torch.device): device to extract feature. 17 | inception_kwargs (**kwargs): kwargs for InceptionV3. 18 | """ 19 | 20 | def __init__(self, style='StyleGAN', device='cpu', **inception_kwargs): 21 | self.inception = load_inception( 22 | style=style, **inception_kwargs).eval().to(device) 23 | self.style = style 24 | self.device = device 25 | 26 | def __call__(self, img1, img2, crop_border=0): 27 | """Extract features of real and fake images. 28 | 29 | Args: 30 | img1, img2 (np.ndarray): Images with range [0, 255] 31 | and shape (H, W, C). 32 | 33 | Returns: 34 | (tuple): Pair of features extracted from InceptionV3 model. 35 | """ 36 | return ( 37 | self.forward_inception(self.img2tensor(img1)).numpy(), 38 | self.forward_inception(self.img2tensor(img2)).numpy(), 39 | ) 40 | 41 | def img2tensor(self, img): 42 | img = np.expand_dims(img.transpose((2, 0, 1)), axis=0) 43 | if self.style == 'StyleGAN': 44 | return torch.tensor(img).to(device=self.device, dtype=torch.uint8) 45 | 46 | return torch.from_numpy(img / 255.).to( 47 | device=self.device, dtype=torch.float32) 48 | 49 | def forward_inception(self, x): 50 | if self.style == 'StyleGAN': 51 | return self.inception(x).cpu() 52 | 53 | return self.inception(x)[-1].view(x.shape[0], -1).cpu() 54 | 55 | 56 | def frechet_distance(X, Y): 57 | """Compute the frechet distance.""" 58 | 59 | muX, covX = np.mean(X, axis=0), np.cov(X, rowvar=False) 60 | muY, covY = np.mean(Y, axis=0), np.cov(Y, rowvar=False) 61 | 62 | cov_sqrt = linalg.sqrtm(covX.dot(covY)) 63 | frechet_distance = np.square(muX - muY).sum() + np.trace(covX) + np.trace( 64 | covY) - 2 * np.trace(cov_sqrt) 65 | return np.real(frechet_distance) 66 | 67 | 68 | @METRICS.register_module() 69 | class FID: 70 | """FID metric.""" 71 | 72 | def __call__(self, X, Y): 73 | """Calculate FID. 74 | 75 | Args: 76 | X (np.ndarray): Input feature X with shape (n_samples, dims). 77 | Y (np.ndarray): Input feature Y with shape (n_samples, dims). 78 | 79 | Returns: 80 | (float): fid value. 
81 | """ 82 | return frechet_distance(X, Y) 83 | 84 | 85 | def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef=1): 86 | """Create a polynomial kernel.""" 87 | Y = X if Y is None else Y 88 | if gamma is None: 89 | gamma = 1.0 / X.shape[1] 90 | K = ((X @ Y.T) * gamma + coef)**degree 91 | return K 92 | 93 | 94 | def mmd2(X, Y, unbiased=True): 95 | """Compute the Maximum Mean Discrepancy.""" 96 | XX = polynomial_kernel(X, X) 97 | YY = polynomial_kernel(Y, Y) 98 | XY = polynomial_kernel(X, Y) 99 | 100 | m = X.shape[0] 101 | if not unbiased: 102 | return (np.sum(XX) + np.sum(YY) - 2 * np.sum(XY)) / m**2 103 | 104 | trX = np.trace(XX) 105 | trY = np.trace(YY) 106 | return (np.sum(XX) - trX + np.sum(YY) - 107 | trY) / (m * (m - 1)) - 2 * np.sum(XY) / m**2 108 | 109 | 110 | @METRICS.register_module() 111 | class KID: 112 | """Implementation of `KID `. 113 | 114 | Args: 115 | num_repeats (int): The number of repetitions. Default: 100. 116 | sample_size (int): Size to sample. Default: 1000. 117 | use_unbiased_estimator (bool): Whether to use KID as an unbiased 118 | estimator. Using an unbiased estimator is desirable in the case of 119 | finite sample size, especially when the number of samples are 120 | small. Using an unbiased estimator is recommended in most cases. 121 | Default: True 122 | """ 123 | 124 | def __init__(self, 125 | num_repeats=100, 126 | sample_size=1000, 127 | use_unbiased_estimator=True): 128 | self.num_repeats = num_repeats 129 | self.sample_size = sample_size 130 | self.unbiased = use_unbiased_estimator 131 | 132 | def __call__(self, X, Y): 133 | """Calculate KID. 134 | 135 | Args: 136 | X (np.ndarray): Input feature X with shape (n_samples, dims). 137 | Y (np.ndarray): Input feature Y with shape (n_samples, dims). 138 | 139 | Returns: 140 | (dict): dict containing mean and std of KID values. 141 | """ 142 | num_samples = X.shape[0] 143 | kid = list() 144 | for i in range(self.num_repeats): 145 | X_ = X[np.random.choice( 146 | num_samples, self.sample_size, replace=False)] 147 | Y_ = Y[np.random.choice( 148 | num_samples, self.sample_size, replace=False)] 149 | kid.append(mmd2(X_, Y_, unbiased=self.unbiased)) 150 | kid = np.array(kid) 151 | return dict(KID_MEAN=kid.mean(), KID_STD=kid.std()) 152 | -------------------------------------------------------------------------------- /docs/dataset_prepare.md: -------------------------------------------------------------------------------- 1 | ## HazeWorld 2 | 3 | **HazeWorld** is a large-scale synthetic outdoor video dehazing dataset, 4 | which is built upon [Cityscapes](https://www.cityscapes-dataset.com/), 5 | [DDAD](https://github.com/TRI-ML/DDAD), 6 | [UA-DETRAC](https://detrac-db.rit.albany.edu/), 7 | [VisDrone](https://github.com/VisDrone/VisDrone-Dataset), 8 | [DAVIS](https://davischallenge.org/), 9 | and [REDS](https://seungjunnah.github.io/Datasets/reds.html). 10 | Please refer to these official dataset websites for rights of use. 11 | 12 | We use [RCVD](https://robust-cvd.github.io/) to estimate the temporally consistent video depths, which are used to synthesize the hazy videos. 13 | The fog synthesis pipeline is built on [SeeingThroughFog](https://github.com/princeton-computational-imaging/SeeingThroughFog/tree/master/tools/DatasetFoggification). 14 | 15 | ## Prepare datasets 16 | 17 | It is recommended to symlink the dataset root to `$MAP-NET/data`. 18 | If your folder structure is different, you may need to change the corresponding paths in config files. 19 | 20 | ```none 21 | MAP-Net 22 | ├── ... 
23 | ├── data 24 | │ ├── Cityscapes 25 | │ │ ├── leftImg8bit_sequence_trainvaltest 26 | │ │ │ ├── leftImg8bit_sequence 27 | │ │ │ │ ├── train 28 | │ │ │ │ ├── val 29 | │ │ │ │ ├── test 30 | │ ├── HazeWorld 31 | │ │ ├── mapping_hazeworld_cityscapes.txt 32 | │ │ ├── gt 33 | │ │ │ ├── Cityscapes 34 | │ │ │ │ ├── train 35 | │ │ │ │ ├── val 36 | │ │ │ │ ├── test 37 | │ │ │ │ ├── mapping_info_GT_train.txt 38 | │ │ │ │ ├── mapping_info_GT_test.txt 39 | │ │ │ ├── DDAD 40 | │ │ │ │ ├── train 41 | │ │ │ │ ├── val 42 | │ │ │ │ ├── ... 43 | │ │ │ ├── UA-DETRAC 44 | │ │ │ │ ├── train 45 | │ │ │ │ ├── test 46 | │ │ │ │ ├── ... 47 | │ │ │ ├── VisDrone 48 | │ │ │ │ ├── train 49 | │ │ │ │ ├── val 50 | │ │ │ │ ├── test-dev 51 | │ │ │ │ ├── ... 52 | │ │ │ ├── DAVIS 53 | │ │ │ │ ├── train 54 | │ │ │ │ ├── val 55 | │ │ │ │ ├── test-dev 56 | │ │ │ │ ├── test-challenge 57 | │ │ │ │ ├── ... 58 | │ │ │ ├── REDS 59 | │ │ │ │ ├── train 60 | │ │ │ │ ├── val 61 | │ │ │ │ ├── ... 62 | │ │ ├── hazy 63 | │ │ │ ├── ... 64 | │ │ ├── transmission 65 | │ │ │ ├── ... 66 | (symlink) 67 | │ │ ├── train 68 | │ │ │ ├── meta_info_tree_GT_train.json 69 | │ │ │ ├── meta_info_GT_train.txt 70 | │ │ │ ├── meta_info_GT_... 71 | │ │ │ ├── gt (symlink) 72 | │ │ │ │ ├── Cityscapes 73 | │ │ │ │ ├── DDAD 74 | │ │ │ │ ├── UA-DETRAC 75 | │ │ │ │ ├── VisDrone 76 | │ │ │ │ ├── DAVIS 77 | │ │ │ │ ├── REDS 78 | │ │ │ ├── hazy (symlink) 79 | │ │ │ │ ├── ... 80 | │ │ ├── test 81 | │ │ │ ├── meta_info_tree_GT_test.json 82 | │ │ │ ├── meta_info_GT_test.txt 83 | │ │ │ ├── meta_info_GT_... 84 | │ │ │ ├── gt (symlink) 85 | │ │ │ │ ├── ... 86 | │ │ │ ├── hazy (symlink) 87 | │ │ │ │ ├── ... 88 | ``` 89 | 90 | **Step 1.** 91 | Download the data from the links at the bottom. 92 | Since many hazy videos may correspond to one ground-truth video, we adopt the file structure above to save storage. 93 | 94 | **Step 2.** 95 | ~~Download the [meta files](https://appsrv.cse.cuhk.edu.hk/~jqxu/data/MAP-Net/HazeWorld_meta-files.zip) and put them into the corresponding locations (see above).~~ 96 | The meta files are provided in the HazeWorld [download link](https://appsrv.cse.cuhk.edu.hk/~jqxu/data/MAP-Net/HazeWorld.zip). 97 | 98 | **Step 3.** 99 | Symlink the **train** and **test** split using the [script](../tools/data/dehazing/hazeworld/create_symlink_hazeworld.py) and the following command: 100 | 101 | ```shell 102 | python tools/data/dehazing/hazeworld/create_symlink_hazeworld.py 103 | ``` 104 | 105 | ### Cityscapes 106 | 107 | The data could be found [here](https://www.cityscapes-dataset.com/downloads/) after registration. 108 | Download [*leftImg8bit_sequence_trainvaltest.zip (324GB)*](https://www.cityscapes-dataset.com/file-handling/?packageID=14). 109 | The used videos can be found in [mapping_hazeworld_cityscapes.txt](https://drive.google.com/file/d/13IZPyeB64lu3szOJsihSPGyUx9cK6yb8/view?usp=share_link). 110 | 111 | ```shell 112 | python tools/data/dehazing/hazeworld/preprocess_hazeworld_cityscapes.py \ 113 | --meta-file data/HazeWorld/mapping_hazeworld_cityscapes.txt \ 114 | --input-dir data/Cityscapes --work-dir data/HazeWorld 115 | ``` 116 | 117 | ### Others 118 | 119 | For others, we provide the [processed data (~100GB)](https://appsrv.cse.cuhk.edu.hk/~jqxu/data/MAP-Net/HazeWorld.zip). 120 | You can also refer to their official websites for the original data. 121 | 122 | ### Notes 123 | 124 | * We do some data processing on the original data, 125 | so the numbers and videos may not correspond to the original ones. 
126 | Here are some examples, 127 | and more details can be found [here](../tools/data/dehazing/hazeworld/preprocess_hazeworld_cityscapes.py). 128 | 129 | > We sample the frames to keep each video clip at a similar length (mostly no more than 100 images per video), 130 | > using different sampling strategies for each dataset. 131 | > 132 | > We use **cv2** to resize (shorter side to 720 pixels if the original is larger than 720, keeping the aspect ratio, 133 | > and using the default interpolation method) and save (*jpg*, default quality, lossy compression, to save storage) images; see the sketch after these notes. 134 | 135 | * Also, we manually check the data and remove some improper videos (*e.g.*, indoor or nighttime scenes). 136 |
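For reference, below is a minimal sketch of the resize-and-save convention described in the notes above. It is **not** the project's preprocessing script (see [preprocess_hazeworld_cityscapes.py](../tools/data/dehazing/hazeworld/preprocess_hazeworld_cityscapes.py) for the authoritative pipeline); the function name and paths are illustrative only.

```python
import cv2


def resize_and_save(src_path, dst_path, target_short_side=720):
    """Downscale so the shorter side is at most 720 px and save as JPEG."""
    img = cv2.imread(src_path)
    h, w = img.shape[:2]
    short_side = min(h, w)
    if short_side > target_short_side:
        scale = target_short_side / short_side
        # cv2.resize uses its default interpolation (bilinear) when none is given
        img = cv2.resize(img, (round(w * scale), round(h * scale)))
    cv2.imwrite(dst_path, img)  # default JPEG quality, lossy compression
```
--------------------------------------------------------------------------------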