├── requirements
├── mminstall.txt
├── readthedocs.txt
├── tests.txt
├── docs.txt
└── runtime.txt
├── requirements.txt
├── metrics
├── utils_image.py
└── cal-iqa.py
├── mmedit
├── core
│   ├── evaluation
│   │   ├── niqe_pris_params.npz
│   │   ├── __init__.py
│   │   ├── metric_utils.py
│   │   ├── eval_hooks.py
│   │   └── inceptions.py
│   ├── optimizer
│   │   ├── __init__.py
│   │   └── builder.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── dist_utils.py
│   ├── export
│   │   ├── __init__.py
│   │   └── wrappers.py
│   ├── scheduler
│   │   └── __init__.py
│   ├── registry.py
│   ├── hooks
│   │   ├── __init__.py
│   │   ├── visualization.py
│   │   └── ema.py
│   ├── __init__.py
│   ├── misc.py
│   └── distributed_wrapper.py
├── models
│   ├── backbones
│   │   ├── __init__.py
│   │   └── derain_backbones
│   │   │   ├── __init__.py
│   │   │   └── derain_net.py
│   ├── derainers
│   │   ├── __init__.py
│   │   └── derainer.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── pixelwise_loss.py
│   ├── registry.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── flow_warp.py
│   │   └── sr_backbone_utils.py
│   ├── __init__.py
│   ├── builder.py
│   └── base.py
├── datasets
│   ├── samplers
│   │   ├── __init__.py
│   │   └── distributed_sampler.py
│   ├── registry.py
│   ├── __init__.py
│   ├── dataset_wrappers.py
│   ├── pipelines
│   │   ├── compose.py
│   │   ├── __init__.py
│   │   ├── normalization.py
│   │   ├── random_down_sampling.py
│   │   ├── utils.py
│   │   ├── generate_assistant.py
│   │   └── formating.py
│   ├── sr_folder_gt_dataset.py
│   ├── base_dataset.py
│   ├── base_sr_dataset.py
│   ├── sr_folder_multiple_gt_dataset.py
│   └── builder.py
├── utils
│   ├── __init__.py
│   ├── collect_env.py
│   ├── cli.py
│   ├── logger.py
│   ├── setup_env.py
│   └── misc.py
├── version.py
├── apis
│   ├── __init__.py
│   ├── restoration_inference.py
│   ├── restoration_face_inference.py
│   ├── restoration_video_inference.py
│   └── test.py
└── __init__.py
├── tools
├── dist_train.sh
├── dist_test.sh
├── slurm_test.sh
├── slurm_train.sh
├── deployment
│   ├── test_torchserver.py
│   ├── mmedit_handler.py
│   └── mmedit2torchserve.py
├── test.py
└── train.py
├── setup.cfg
├── modules
└── DeformableAlignment.py
├── README.md
└── configs
└── derainers
└── ViMPNet
└── ViMPNet.py

/requirements/mminstall.txt:
--------------------------------------------------------------------------------
mmcv-full>=1.3.17
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
-r requirements/runtime.txt
-r requirements/tests.txt
--------------------------------------------------------------------------------

/metrics/utils_image.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyHongtaoWu/ViMP-Net/HEAD/metrics/utils_image.py
--------------------------------------------------------------------------------

/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
lmdb
mmcv
regex
scikit-image
titlecase
torch
torchvision
--------------------------------------------------------------------------------

/requirements/tests.txt:
--------------------------------------------------------------------------------
codecov
flake8
interrogate
isort==5.10.1
onnxruntime
pytest
pytest-runner
yapf
--------------------------------------------------------------------------------

/mmedit/core/evaluation/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyHongtaoWu/ViMP-Net/HEAD/mmedit/core/evaluation/niqe_pris_params.npz
--------------------------------------------------------------------------------
/mmedit/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizers

__all__ = ['build_optimizers']
--------------------------------------------------------------------------------

/mmedit/core/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import sync_random_seed

__all__ = ['sync_random_seed']
--------------------------------------------------------------------------------

/mmedit/core/export/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .wrappers import ONNXRuntimeEditing

__all__ = ['ONNXRuntimeEditing']
--------------------------------------------------------------------------------

/mmedit/models/backbones/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

from .derain_backbones import (ViMPNet)

__all__ = [
    'ViMPNet'
]
--------------------------------------------------------------------------------

/mmedit/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .distributed_sampler import DistributedSampler

__all__ = ['DistributedSampler']
--------------------------------------------------------------------------------

/mmedit/models/backbones/derain_backbones/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

from .ViMPNet import ViMPNet

__all__ = [
    'ViMPNet'
]
--------------------------------------------------------------------------------

/mmedit/datasets/registry.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
--------------------------------------------------------------------------------

/mmedit/core/scheduler/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_updater import LinearLrUpdaterHook, ReduceLrUpdaterHook

__all__ = ['LinearLrUpdaterHook', 'ReduceLrUpdaterHook']
--------------------------------------------------------------------------------

/mmedit/models/derainers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .derainer import Derainer
from .DrainNet import DrainNet


__all__ = [
    'Derainer', 'DrainNet'
]
--------------------------------------------------------------------------------
/mmedit/models/losses/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss

__all__ = [
    'L1Loss', 'MSELoss', 'CharbonnierLoss'
]
--------------------------------------------------------------------------------

/requirements/docs.txt:
--------------------------------------------------------------------------------
docutils==0.16.0
mmcls==0.10.0
myst_parser
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==4.0.2
sphinx-copybutton
sphinx_markdown_tables
--------------------------------------------------------------------------------

/mmedit/models/registry.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('model', parent=MMCV_MODELS)
BACKBONES = MODELS
COMPONENTS = MODELS
LOSSES = MODELS
--------------------------------------------------------------------------------

/mmedit/core/registry.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg

METRICS = Registry('metric')


def build_metric(cfg):
    """Build a metric calculator."""
    return build_from_cfg(cfg, METRICS)
--------------------------------------------------------------------------------

/mmedit/core/hooks/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExponentialMovingAverageHook
from .visualization import MMEditVisualizationHook, VisualizationHook

__all__ = [
    'VisualizationHook', 'MMEditVisualizationHook',
    'ExponentialMovingAverageHook'
]
--------------------------------------------------------------------------------

/mmedit/models/common/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

from .flow_warp import flow_warp
from .sr_backbone_utils import (ResidualBlockNoBN,
                                make_layer)




__all__ = [
    'flow_warp', 'ResidualBlockNoBN', 'make_layer'
]
--------------------------------------------------------------------------------

/mmedit/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .cli import modify_args
from .logger import get_root_logger
from .misc import deprecated_function
from .setup_env import setup_multi_processes

__all__ = [
    'get_root_logger', 'setup_multi_processes', 'modify_args',
    'deprecated_function'
]
--------------------------------------------------------------------------------

/mmedit/datasets/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import build_dataloader, build_dataset
from .registry import DATASETS, PIPELINES
from .sr_folder_multiple_gt_dataset import SRFolderMultipleGTDataset

__all__ = [
    'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader',
    'BaseDataset', 'SRFolderMultipleGTDataset'
]
--------------------------------------------------------------------------------
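Everything above is registry plumbing: classes are registered by decorator, then instantiated from config dicts. A minimal sketch of that round trip with mmcv's `Registry` (the `ToyDataset` class and its fields are made up for illustration, not part of this repo):

```python
from mmcv.utils import Registry, build_from_cfg

DATASETS = Registry('dataset')


@DATASETS.register_module()
class ToyDataset:
    """Hypothetical dataset used only to illustrate the registry round trip."""

    def __init__(self, ann_file, test_mode=False):
        self.ann_file = ann_file
        self.test_mode = test_mode


# A config dict names the class via `type`; the remaining keys become kwargs.
cfg = dict(type='ToyDataset', ann_file='data/meta.txt', test_mode=True)
dataset = build_from_cfg(cfg, DATASETS)
assert isinstance(dataset, ToyDataset) and dataset.test_mode
```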
/mmedit/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import DistEvalIterHook, EvalIterHook
from .inceptions import FID, KID, InceptionV3
from .metrics import (L1Evaluation, connectivity, gradient_error, mae, mse,
                      niqe, psnr, reorder_image, sad, ssim)

__all__ = [
    'mse', 'sad', 'psnr', 'reorder_image', 'ssim', 'EvalIterHook',
    'DistEvalIterHook', 'L1Evaluation', 'gradient_error', 'connectivity',
    'niqe', 'mae', 'FID', 'KID', 'InceptionV3'
]
--------------------------------------------------------------------------------

/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-18600}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --seed 0 \
    --launcher pytorch ${@:3}
--------------------------------------------------------------------------------

/requirements/runtime.txt:
--------------------------------------------------------------------------------
av
av==8.0.3; python_version < '3.7'
facexlib
lmdb
mmcv-full>=1.3.13  # To support DCN on CPU
numpy
opencv-python!=4.5.5.62,!=4.5.5.64
# MMCV depends on opencv-python instead of headless, thus we install opencv-python
# Due to a bug from upstream, we skip these two versions
# https://github.com/opencv/opencv-python/issues/602
# https://github.com/opencv/opencv/issues/21366
# It seems to be fixed in https://github.com/opencv/opencv/pull/21382
Pillow
tensorboard
torch
torchvision
--------------------------------------------------------------------------------

/mmedit/version.py:
--------------------------------------------------------------------------------
# Copyright (c) Open-MMLab. All rights reserved.

__version__ = '0.16.1'


def parse_version_info(version_str):
    ver_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            ver_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            ver_info.append(int(patch_version[0]))
            ver_info.append(f'rc{patch_version[1]}')
    return tuple(ver_info)


version_info = parse_version_info(__version__)
--------------------------------------------------------------------------------

/mmedit/utils/collect_env.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash

import mmedit


def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMEditing'] = f'{mmedit.__version__}+{get_git_hash()[:7]}'

    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print('{}: {}'.format(name, val))
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/test.py \
    $CONFIG \
    $CHECKPOINT \
    --launcher pytorch \
    ${@:4}
--------------------------------------------------------------------------------

/mmedit/utils/cli.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import re
import sys
import warnings


def modify_args():
    for i, v in enumerate(sys.argv):
        if i == 0:
            assert v.endswith('.py')
        elif re.match(r'--\w+_.*', v):
            new_arg = v.replace('_', '-')
            warnings.warn(
                f'command line argument {v} is deprecated, '
                f'please use {new_arg} instead.',
                category=DeprecationWarning,
            )
            sys.argv[i] = new_arg
--------------------------------------------------------------------------------

/mmedit/models/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # noqa: F401, F403
from .base import BaseModel
from .builder import (build, build_backbone, build_component, build_loss,
                      build_model)
from .common import *  # noqa: F401, F403
from .losses import *  # noqa: F401, F403
from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS
from .derainers import Derainer


__all__ = [
    'BaseModel', 'Derainer', 'build',
    'build_backbone', 'build_component', 'build_loss', 'build_model',
    'BACKBONES', 'COMPONENTS', 'LOSSES', 'MODELS',
]
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------

/setup.cfg:
--------------------------------------------------------------------------------
[bdist_wheel]
universal=1

[aliases]
test=pytest

[tool:pytest]
addopts=tests/

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true
split_penalty_import_names=0
SPLIT_PENALTY_AFTER_OPENING_BRACKET=888

[isort]
line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = mmedit
known_third_party = PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,pytorch_sphinx_theme,requests,scipy,titlecase,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
--------------------------------------------------------------------------------

/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
export MASTER_PORT=$((12000 + $RANDOM % 20000))

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------

/mmedit/core/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import (FID, KID, DistEvalIterHook, EvalIterHook, InceptionV3,
                         L1Evaluation, mae, mse, psnr, reorder_image, sad,
                         ssim)
from .hooks import MMEditVisualizationHook, VisualizationHook
from .misc import tensor2img
from .optimizer import build_optimizers
from .registry import build_metric
from .scheduler import LinearLrUpdaterHook, ReduceLrUpdaterHook

__all__ = [
    'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
    'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
    'VisualizationHook', 'MMEditVisualizationHook', 'L1Evaluation', 'FID',
    'KID', 'InceptionV3', 'build_metric', 'ReduceLrUpdaterHook', 'mae'
]
--------------------------------------------------------------------------------
/mmedit/apis/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# from .generation_inference import generation_inference
# from .inpainting_inference import inpainting_inference
# from .matting_inference import init_model, matting_inference
from .restoration_face_inference import restoration_face_inference
from .restoration_inference import restoration_inference
from .restoration_video_inference import restoration_video_inference
from .test import multi_gpu_test, single_gpu_test
from .train import init_random_seed, set_random_seed, train_model
# from .video_interpolation_inference import video_interpolation_inference

__all__ = [
    'train_model', 'set_random_seed', 'init_model',
    'restoration_inference',
    'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference',
    'restoration_face_inference', 'init_random_seed'
]
--------------------------------------------------------------------------------

/mmedit/utils/logger.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmedit".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    # root logger name: mmedit
    logger = get_logger(__name__.split('.')[0], log_file, log_level)
    return logger
--------------------------------------------------------------------------------

/mmedit/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv

from .version import __version__, version_info

try:
    from mmcv.utils import digit_version
except ImportError:

    def digit_version(version_str):
        digit_ver = []
        for x in version_str.split('.'):
            if x.isdigit():
                digit_ver.append(int(x))
            elif x.find('rc') != -1:
                patch_version = x.split('rc')
                digit_ver.append(int(patch_version[0]) - 1)
                digit_ver.append(int(patch_version[1]))
        return digit_ver


MMCV_MIN = '1.3.13'
MMCV_MAX = '1.8'

mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
    f'mmcv=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv-full>={mmcv_min_version}, <{mmcv_max_version}.'

__all__ = ['__version__', 'version_info']
--------------------------------------------------------------------------------
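The fallback `digit_version` above encodes a release candidate one step below its final release, so plain list comparison orders versions correctly. A standalone sketch mirroring that logic (it does not import mmcv):

```python
def digit_version(version_str):
    # Same scheme as the fallback above: '1.3.0rc1' -> [1, 3, -1, 1],
    # which sorts below the final release '1.3.0' -> [1, 3, 0].
    digit_ver = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_ver.append(int(x))
        elif 'rc' in x:
            major, rc = x.split('rc')
            digit_ver.append(int(major) - 1)
            digit_ver.append(int(rc))
    return digit_ver


assert digit_version('1.3.0rc1') < digit_version('1.3.0')
assert digit_version('1.3.13') <= digit_version('1.7.2') < digit_version('1.8')
```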
/mmedit/datasets/dataset_wrappers.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import DATASETS


@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return self.times * self._ori_len
--------------------------------------------------------------------------------

/mmedit/core/utils/dist_utils.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
--------------------------------------------------------------------------------
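To make the index arithmetic in `RepeatDataset` above concrete, a minimal sketch with a plain list standing in for a dataset (illustrative only; assumes the `mmedit` package is importable):

```python
from mmedit.datasets.dataset_wrappers import RepeatDataset

base = ['a', 'b', 'c']  # any object with __len__ and __getitem__ works
repeated = RepeatDataset(base, times=3)

assert len(repeated) == 9
# Indices wrap around the original length via `idx % len(base)`.
assert repeated[4] == base[4 % 3] == 'b'
```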
21 | """ 22 | if seed is None: 23 | seed = np.random.randint(2**31) 24 | assert isinstance(seed, int) 25 | 26 | rank, world_size = get_dist_info() 27 | 28 | if world_size == 1: 29 | return seed 30 | 31 | if rank == 0: 32 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 33 | else: 34 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 35 | dist.broadcast(random_num, src=0) 36 | return random_num.item() 37 | -------------------------------------------------------------------------------- /tools/deployment/test_torchserver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser 3 | 4 | import cv2 5 | import numpy as np 6 | import requests 7 | from PIL import Image 8 | 9 | 10 | def parse_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument('model_name', help='The model name in the server') 13 | parser.add_argument( 14 | '--inference-addr', 15 | default='127.0.0.1:8080', 16 | help='Address and port of the inference server') 17 | parser.add_argument('--img-path', type=str, help='The input LQ image.') 18 | parser.add_argument( 19 | '--save-path', type=str, help='Path to save the generated GT image.') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def save_results(content, save_path, ori_shape): 25 | ori_len = np.prod(ori_shape) 26 | scale = int(np.sqrt(len(content) / ori_len)) 27 | target_size = [int(size * scale) for size in ori_shape[:2][::-1]] 28 | # Convert to RGB and save image 29 | img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0) 30 | img.save(save_path) 31 | 32 | 33 | def main(args): 34 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 35 | ori_shape = cv2.imread(args.img_path).shape 36 | with open(args.img_path, 'rb') as image: 37 | response = requests.post(url, image) 38 | save_results(response.content, args.save_path, ori_shape) 39 | 40 | 41 | if __name__ == '__main__': 42 | parsed_args = parse_args() 43 | main(parsed_args) 44 | -------------------------------------------------------------------------------- /mmedit/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv import build_from_cfg 4 | 5 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS 6 | 7 | 8 | def build(cfg, registry, default_args=None): 9 | """Build module function. 10 | 11 | Args: 12 | cfg (dict): Configuration for building modules. 13 | registry (obj): ``registry`` object. 14 | default_args (dict, optional): Default arguments. Defaults to None. 15 | """ 16 | if isinstance(cfg, list): 17 | modules = [ 18 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 19 | ] 20 | return nn.Sequential(*modules) 21 | 22 | return build_from_cfg(cfg, registry, default_args) 23 | 24 | 25 | def build_backbone(cfg): 26 | """Build backbone. 27 | 28 | Args: 29 | cfg (dict): Configuration for building backbone. 30 | """ 31 | return build(cfg, BACKBONES) 32 | 33 | 34 | def build_component(cfg): 35 | """Build component. 36 | 37 | Args: 38 | cfg (dict): Configuration for building component. 39 | """ 40 | return build(cfg, COMPONENTS) 41 | 42 | 43 | def build_loss(cfg): 44 | """Build loss. 45 | 46 | Args: 47 | cfg (dict): Configuration for building loss. 
48 | """ 49 | return build(cfg, LOSSES) 50 | 51 | 52 | def build_model(cfg, train_cfg=None, test_cfg=None): 53 | """Build model. 54 | 55 | Args: 56 | cfg (dict): Configuration for building model. 57 | train_cfg (dict): Training configuration. Default: None. 58 | test_cfg (dict): Testing configuration. Default: None. 59 | """ 60 | return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 61 | -------------------------------------------------------------------------------- /metrics/cal-iqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import os 5 | import torch 6 | import requests 7 | import numpy as np 8 | from os import path as osp 9 | from collections import OrderedDict 10 | from torch.utils.data import DataLoader 11 | import utils_image as util 12 | from rich.progress import track 13 | from natsort import natsorted 14 | import lpips 15 | 16 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 17 | return torch.Tensor((image / factor - cent) 18 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 19 | 20 | 21 | path ="./result/" 22 | gt_path = "./testgt/" 23 | folders = os.listdir(path) 24 | print(path) 25 | 26 | 27 | 28 | psnr = [] 29 | ssim = [] 30 | lpips_ = [] 31 | loss_fn_alex = lpips.LPIPS(net='alex').cuda() 32 | 33 | for folder in folders: 34 | print(folder) 35 | imgs = natsorted(glob.glob(osp.join(path, folder, '*.png'))) 36 | imgs_gt = natsorted(glob.glob(osp.join(gt_path, folder, '*.png'))) 37 | psnr_folder = [] 38 | ssim_folder = [] 39 | lpips_folder = [] 40 | 41 | for i in track(range(len(imgs))): 42 | output = cv2.imread(imgs[i]) 43 | gt = cv2.imread(imgs_gt[i]) 44 | if output.shape != gt.shape: 45 | print(output.shape, gt.shape) 46 | psnr_folder.append(util.calculate_psnr(output, gt)) 47 | ssim_folder.append(util.calculate_ssim(output, gt)) 48 | lpips_folder.append(loss_fn_alex(im2tensor(output).cuda(), im2tensor(gt).cuda()).item()) 49 | 50 | psnr.append(np.mean(psnr_folder)) 51 | ssim.append(np.mean(ssim_folder)) 52 | lpips_.append(np.mean(lpips_folder)) 53 | 54 | print('psnr: ', np.mean(psnr)) 55 | print('ssim: ', np.mean(ssim)) 56 | print('lpips: ', np.mean(lpips_)) 57 | 58 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections.abc import Sequence 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..registry import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose: 11 | """Compose a data pipeline with a sequence of transforms. 12 | 13 | Args: 14 | transforms (list[dict | callable]): 15 | Either config dicts of transforms or transform objects. 16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError(f'transform must be callable or a dict, ' 29 | f'but got {type(transform)}') 30 | 31 | def __call__(self, data): 32 | """Call function. 33 | 34 | Args: 35 | data (dict): A dict containing the necessary information and 36 | data for augmentation. 
/mmedit/apis/restoration_inference.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import collate, scatter

from mmedit.datasets.pipelines import Compose


def restoration_inference(model, img, ref=None):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.
        ref (str | None): File path of reference image. Default: None.

    Returns:
        Tensor: The predicted restoration result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove gt from test_pipeline
    keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    if ref:  # Ref-SR
        data = dict(lq_path=img, ref_path=ref)
    else:  # SISR
        data = dict(lq_path=img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if 'cuda' in str(device):
        data = scatter(data, [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['output']
--------------------------------------------------------------------------------

/mmedit/core/optimizer/builder.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer


def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If `cfgs` contains several dicts for optimizers, then a dict of the
    constructed optimizers will be returned.
    If `cfgs` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    optimizers = {}
    if hasattr(model, 'module'):
        model = model.module
    # determine whether 'cfgs' has several dicts for optimizers
    is_dict_of_dict = True
    for key, cfg in cfgs.items():
        if not isinstance(cfg, dict):
            is_dict_of_dict = False

    if is_dict_of_dict:
        for key, cfg in cfgs.items():
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)
--------------------------------------------------------------------------------
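A runnable sketch of both branches documented above, assuming the `mmedit` package is importable and a model whose submodule names match the config keys (the `generator`/`discriminator` names are illustrative):

```python
import torch
import torch.nn as nn

from mmedit.core.optimizer import build_optimizers


class ToyGAN(nn.Module):
    """Hypothetical two-component model for the multi-optimizer branch."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(4, 4)
        self.discriminator = nn.Linear(4, 1)


model = ToyGAN()

# One config per submodule -> a dict of optimizers keyed by submodule name.
optims = build_optimizers(
    model,
    dict(
        generator=dict(type='Adam', lr=1e-4),
        discriminator=dict(type='Adam', lr=4e-4)))
assert set(optims) == {'generator', 'discriminator'}

# A single flat config -> a single optimizer over the whole model.
optim = build_optimizers(model, dict(type='SGD', lr=0.01, momentum=0.9))
assert isinstance(optim, torch.optim.SGD)
```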
/mmedit/models/common/flow_warp.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F


def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension holds
            two channels, denoting the relative offsets in width and height.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether align corners. Default: True.

    Returns:
        Tensor: Warped image or feature map.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # create mesh grid
    device = flow.device
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h, device=device, dtype=x.dtype),
        torch.arange(0, w, device=device, dtype=x.dtype))
    grid = torch.stack((grid_x, grid_y), 2)  # h, w, 2
    grid.requires_grad = False

    grid_flow = grid + flow
    # scale grid_flow to [-1,1]
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    output = F.grid_sample(
        x,
        grid_flow,
        mode=interpolation,
        padding_mode=padding_mode,
        align_corners=align_corners)
    return output
--------------------------------------------------------------------------------
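Two quick sanity checks on `flow_warp` above, assuming the `mmedit` package is importable: an all-zero flow is the identity, and a uniform flow of +1 pixel in x makes each output pixel sample from one pixel to its right.

```python
import torch

from mmedit.models.common import flow_warp

x = torch.rand(1, 3, 8, 8)
zero_flow = torch.zeros(1, 8, 8, 2)  # (n, h, w, 2), offsets in pixels

warped = flow_warp(x, zero_flow)
assert torch.allclose(warped, x, atol=1e-6)

# output[..., y, x'] samples input[..., y, x' + 1]; the last column falls
# outside the image and is zero-padded, so we compare interior columns only.
shift_flow = torch.zeros(1, 8, 8, 2)
shift_flow[..., 0] = 1.0
shifted = flow_warp(x, shift_flow)
assert torch.allclose(shifted[..., :, :-1], x[..., :, 1:], atol=1e-5)
```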
/mmedit/datasets/sr_folder_gt_dataset.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRFolderGTDataset(BaseSRDataset):
    """General ground-truth image folder dataset for image restoration.

    The dataset loads gt (Ground-Truth) images only,
    applies specified transforms and finally returns a dict containing paired
    data and other information.

    This is the "gt folder mode", which only needs to specify the gt
    folder path, with the folder containing the corresponding images.
    Image lists will be generated automatically.

    For example, we have a folder with the following structure:

    ::

        data_root
        ├── gt
        │   ├── 0001.png
        │   ├── 0002.png

    then, you need to set:

    .. code-block:: python

        gt_folder = data_root/gt

    Args:
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        pipeline (List[dict | callable]): A sequence of data transformations.
        scale (int | tuple): Upsampling scale or upsampling scale range.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
        filename_tmpl (str): Template for each filename. Default: '{}'.
    """

    def __init__(self,
                 gt_folder,
                 pipeline,
                 scale,
                 test_mode=False,
                 filename_tmpl='{}'):
        super().__init__(pipeline, scale, test_mode)
        self.gt_folder = str(gt_folder)
        self.filename_tmpl = filename_tmpl
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for SR dataset.

        It loads the GT image paths from the folder.

        Returns:
            list[dict]: A list of dicts for paths of GT.
        """
        data_infos = []
        gt_paths = self.scan_folder(self.gt_folder)
        for gt_path in gt_paths:
            data_infos.append(dict(gt_path=gt_path))
        return data_infos
--------------------------------------------------------------------------------

/mmedit/datasets/base_dataset.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import ABCMeta, abstractmethod

from torch.utils.data import Dataset

from .pipelines import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base class for datasets.

    All datasets should subclass it.
    All subclasses should overwrite:

    ``load_annotations``, which supports loading annotation information and
    generating image lists.

    Args:
        pipeline (list[dict | callable]): A sequence of data transforms.
        test_mode (bool): If True, the dataset will work in test mode.
            Otherwise, in train mode.
    """

    def __init__(self, pipeline, test_mode=False):
        super().__init__()
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)

    @abstractmethod
    def load_annotations(self):
        """Abstract function for loading annotations.

        All subclasses should overwrite this function.
        """

    def prepare_train_data(self, idx):
        """Prepare training data.

        Args:
            idx (int): Index of the training batch data.

        Returns:
            dict: Returned training batch.
        """
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def prepare_test_data(self, idx):
        """Prepare testing data.

        Args:
            idx (int): Index for getting each testing batch.

        Returns:
            Tensor: Returned testing batch.
        """
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.data_infos)

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)

        return self.prepare_train_data(idx)
--------------------------------------------------------------------------------
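Subclassing `BaseDataset` above only requires implementing `load_annotations` and populating `self.data_infos`; a minimal hedged sketch (the paths are made up, and an identity lambda stands in for a real pipeline):

```python
from mmedit.datasets.base_dataset import BaseDataset


class ToyGTDataset(BaseDataset):
    """Hypothetical dataset: each sample is a dict holding a gt image path."""

    def __init__(self, gt_paths, pipeline, test_mode=False):
        super().__init__(pipeline, test_mode)
        self.gt_paths = gt_paths
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        return [dict(gt_path=p) for p in self.gt_paths]


dataset = ToyGTDataset(
    gt_paths=['data/gt/0001.png', 'data/gt/0002.png'],
    pipeline=[lambda results: results],  # identity pipeline for the sketch
    test_mode=True)
assert len(dataset) == 2
assert dataset[0]['gt_path'] == 'data/gt/0001.png'
```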
74 | """ 75 | if self.test_mode: 76 | return self.prepare_test_data(idx) 77 | 78 | return self.prepare_train_data(idx) 79 | -------------------------------------------------------------------------------- /tools/deployment/mmedit_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import random 4 | import string 5 | from io import BytesIO 6 | 7 | import PIL.Image as Image 8 | import torch 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmedit.apis import init_model, restoration_inference 12 | from mmedit.core import tensor2img 13 | 14 | 15 | class MMEditHandler(BaseHandler): 16 | 17 | def initialize(self, context): 18 | print('MMEditHandler.initialize is called') 19 | properties = context.system_properties 20 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | self.device = torch.device(self.map_location + ':' + 22 | str(properties.get('gpu_id')) if torch.cuda. 23 | is_available() else self.map_location) 24 | self.manifest = context.manifest 25 | 26 | model_dir = properties.get('model_dir') 27 | serialized_file = self.manifest['model']['serializedFile'] 28 | checkpoint = os.path.join(model_dir, serialized_file) 29 | self.config_file = os.path.join(model_dir, 'config.py') 30 | 31 | self.model = init_model(self.config_file, checkpoint, self.device) 32 | self.initialized = True 33 | 34 | def preprocess(self, data, *args, **kwargs): 35 | body = data[0].get('data') or data[0].get('body') 36 | result = Image.open(BytesIO(body)) 37 | # data preprocess is in inference. 38 | return result 39 | 40 | def inference(self, data, *args, **kwargs): 41 | # generate temp image path for restoration_inference 42 | temp_name = ''.join( 43 | random.sample(string.ascii_letters + string.digits, 18)) 44 | temp_path = f'./{temp_name}.png' 45 | data.save(temp_path) 46 | results = restoration_inference(self.model, temp_path) 47 | # delete the temp image path 48 | os.remove(temp_path) 49 | return results 50 | 51 | def postprocess(self, data): 52 | # convert torch tensor to numpy and then convert to bytes 53 | output_list = [] 54 | for data_ in data: 55 | data_np = tensor2img(data_) 56 | data_byte = data_np.tobytes() 57 | output_list.append(data_byte) 58 | 59 | return output_list 60 | -------------------------------------------------------------------------------- /mmedit/utils/setup_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | import warnings 5 | 6 | import cv2 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def setup_multi_processes(cfg): 11 | """Setup multi-processing environment variables.""" 12 | # set multi-process start method as `fork` to speed up the training 13 | if platform.system() != 'Windows': 14 | mp_start_method = cfg.get('mp_start_method', 'fork') 15 | current_method = mp.get_start_method(allow_none=True) 16 | if current_method is not None and current_method != mp_start_method: 17 | warnings.warn( 18 | f'Multi-processing start method `{mp_start_method}` is ' 19 | f'different from the previous setting `{current_method}`.' 20 | f'It will be force set to `{mp_start_method}`. 
/mmedit/core/evaluation/metric_utils.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np


def gaussian(x, sigma):
    """Gaussian function.

    Args:
        x (array_like): The independent variable.
        sigma (float): Standard deviation of the gaussian function.

    Return:
        ndarray or scalar: Gaussian value of `x`.
    """
    return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))


def dgaussian(x, sigma):
    """Gradient of gaussian.

    Args:
        x (array_like): The independent variable.
        sigma (float): Standard deviation of the gaussian function.

    Return:
        ndarray or scalar: Gradient of gaussian of `x`.
    """
    return -x * gaussian(x, sigma) / sigma**2


def gauss_filter(sigma, epsilon=1e-2):
    """Build Gaussian gradient filters.

    Args:
        sigma (float): Standard deviation of the gaussian kernel.
        epsilon (float): Small value used when calculating kernel size.
            Default: 1e-2.

    Return:
        tuple[ndarray]: Gaussian filter along x and y axis.
    """
    half_size = np.ceil(
        sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
    size = int(2 * half_size + 1)

    # create filter in x axis
    filter_x = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            filter_x[i, j] = gaussian(i - half_size, sigma) * dgaussian(
                j - half_size, sigma)

    # normalize filter
    norm = np.sqrt((filter_x**2).sum())
    filter_x = filter_x / norm
    filter_y = np.transpose(filter_x)

    return filter_x, filter_y


def gauss_gradient(img, sigma):
    """Gaussian gradient.

    From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/
    submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/
    index.html

    Args:
        img (ndarray): Input image.
        sigma (float): Standard deviation of the gaussian kernel.

    Return:
        ndarray: Gaussian gradient of input `img`.
    """
    filter_x, filter_y = gauss_filter(sigma)
    img_filtered_x = cv2.filter2D(
        img, -1, filter_x, borderType=cv2.BORDER_REPLICATE)
    img_filtered_y = cv2.filter2D(
        img, -1, filter_y, borderType=cv2.BORDER_REPLICATE)
    return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
--------------------------------------------------------------------------------
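A small sketch of `gauss_gradient` above on a synthetic step edge, assuming the `mmedit` package is importable: the gradient-magnitude response should be strongest at the edge column and vanish in flat regions.

```python
import numpy as np

from mmedit.core.evaluation.metric_utils import gauss_gradient

img = np.zeros((32, 32), dtype=np.float32)
img[:, 16:] = 1.0  # vertical step edge at column 16

grad = gauss_gradient(img, sigma=1.4)
assert grad.shape == img.shape
# Response at the edge dominates the response in a flat region.
assert grad[:, 16].mean() > grad[:, 2].mean()
```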
/mmedit/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .augmentation import (BinarizeImage, ColorJitter, CopyValues, Flip,
                           GenerateFrameIndices,
                           GenerateFrameIndiceswithPadding,
                           GenerateSegmentIndices, MirrorSequence, Pad,
                           Quantize, RandomAffine, RandomJitter,
                           RandomMaskDilation, RandomTransposeHW, Resize,
                           TemporalReverse, UnsharpMasking)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
                   CropLike, FixedCrop, ModCrop, PairedRandomCrop,
                   RandomResizedCrop)
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
                        ToTensor)
from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
                      LoadImageFromFileList, LoadMask, LoadPairedImageFromFile,
                      RandomLoadResizeBg)
from .matlab_like_resize import MATLABLikeResize
from .matting_aug import (CompositeFg, GenerateSeg, GenerateSoftSeg,
                          GenerateTrimap, GenerateTrimapWithDistTransform,
                          MergeFgAndBg, PerturbBg, TransformTrimap)
from .normalization import Normalize, RescaleToZeroOne
from .random_degradations import (DegradationsWithShuffle, RandomBlur,
                                  RandomJPEGCompression, RandomNoise,
                                  RandomResize, RandomVideoCompression)
from .random_down_sampling import RandomDownSampling

__all__ = [
    'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
    'RandomLoadResizeBg', 'Compose', 'ImageToTensor', 'ToTensor',
    'GetMaskedImage', 'BinarizeImage', 'Flip', 'Pad', 'RandomAffine',
    'RandomJitter', 'ColorJitter', 'RandomMaskDilation', 'RandomTransposeHW',
    'Resize', 'RandomResizedCrop', 'Crop', 'CropAroundCenter',
    'CropAroundUnknown', 'ModCrop', 'PairedRandomCrop', 'Normalize',
    'RescaleToZeroOne', 'GenerateTrimap', 'MergeFgAndBg', 'CompositeFg',
    'TemporalReverse', 'LoadImageFromFileList', 'GenerateFrameIndices',
    'GenerateFrameIndiceswithPadding', 'FixedCrop', 'LoadPairedImageFromFile',
    'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg',
    'GetSpatialDiscountMask', 'RandomDownSampling',
    'GenerateTrimapWithDistTransform', 'TransformTrimap',
    'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
    'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues',
    'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise',
    'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking',
    'RandomVideoCompression'
]
--------------------------------------------------------------------------------

/mmedit/utils/misc.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import logging
import re
import textwrap
from typing import Callable

from mmcv import print_log


def deprecated_function(since: str, removed_in: str,
                        instructions: str) -> Callable:
    """Marks functions as deprecated.

    Throw a warning when a deprecated function is called, and add a note in the
    docstring. Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py

    Args:
        since (str): The version when the function was first deprecated.
        removed_in (str): The version when the function will be removed.
        instructions (str): The action users should take.

    Returns:
        Callable: A new function, which will be deprecated soon.
    """  # noqa: E501

    def decorator(function):

        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            print_log(
                f"'{function.__module__}.{function.__name__}' "
                f'is deprecated in version {since} and will be '
                f'removed in version {removed_in}. Please {instructions}.',
                # logger='current',
                level=logging.WARNING,
            )
            return function(*args, **kwargs)

        indent = '    '
        # Add a deprecation note to the docstring.
        docstring = function.__doc__ or ''
        # Add a note to the docstring.
        deprecation_note = textwrap.dedent(f"""\
            .. deprecated:: {since}
                Deprecated and will be removed in version {removed_in}.
                Please {instructions}.
            """)
        # Split docstring at first occurrence of newline
        pattern = '\n\n'
        summary_and_body = re.split(pattern, docstring, 1)

        if len(summary_and_body) > 1:
            summary, body = summary_and_body
            body = textwrap.indent(textwrap.dedent(body), indent)
            summary = '\n'.join(
                [textwrap.dedent(string) for string in summary.split('\n')])
            summary = textwrap.indent(summary, prefix=indent)
            # Dedent the body. We cannot do this with the presence of the
            # summary because the body contains leading whitespaces when the
            # summary does not.
            new_docstring_parts = [
                deprecation_note, '\n\n', summary, '\n\n', body
            ]
        else:
            summary = summary_and_body[0]
            summary = '\n'.join(
                [textwrap.dedent(string) for string in summary.split('\n')])
            summary = textwrap.indent(summary, prefix=indent)
            new_docstring_parts = [deprecation_note, '\n\n', summary]

        wrapper.__doc__ = ''.join(new_docstring_parts)

        return wrapper

    return decorator
--------------------------------------------------------------------------------
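A hedged sketch of the decorator above in use (the function and version numbers are made up): calling the wrapped function emits a WARNING through mmcv's `print_log`, and the docstring gains the versioned note.

```python
from mmedit.utils.misc import deprecated_function


@deprecated_function(
    since='0.16.1',
    removed_in='0.17.0',
    instructions="use 'new_resize' instead")
def old_resize(img, scale):
    """Resize an image by a scale factor."""
    return img


out = old_resize('img', 2)  # logs a deprecation WARNING, then runs as usual
assert '.. deprecated:: 0.16.1' in old_resize.__doc__
```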
/mmedit/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import math

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmedit.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    `torch.utils.data.DistributedSampler`.

    In older versions of PyTorch, `DistributedSampler` has no `shuffle`
    argument; this child class ports one to it.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 samples_per_gpu=1,
                 seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.samples_per_gpu = samples_per_gpu
        # fix the bug of the official implementation
        self.num_samples_per_replica = int(
            math.ceil(
                len(self.dataset) * 1.0 / self.num_replicas / samples_per_gpu))
        self.num_samples = self.num_samples_per_replica * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

        # to avoid a padding bug when meeting a too small dataset
        if len(dataset) < self.num_replicas * samples_per_gpu:
            raise ValueError(
                'You may be using too small a dataset, and our distributed '
                'sampler cannot pad your dataset correctly. We highly '
                'recommend you use fewer GPUs to finish your work')

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
--------------------------------------------------------------------------------
17 | """ 18 | 19 | def __init__(self, 20 | dataset, 21 | num_replicas=None, 22 | rank=None, 23 | shuffle=True, 24 | samples_per_gpu=1, 25 | seed=0): 26 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 27 | self.shuffle = shuffle 28 | self.samples_per_gpu = samples_per_gpu 29 | # fix the bug of the official implementation 30 | self.num_samples_per_replica = int( 31 | math.ceil( 32 | len(self.dataset) * 1.0 / self.num_replicas / samples_per_gpu)) 33 | self.num_samples = self.num_samples_per_replica * self.samples_per_gpu 34 | self.total_size = self.num_samples * self.num_replicas 35 | 36 | # In distributed sampling, different ranks should sample 37 | # non-overlapped data in the dataset. Therefore, this function 38 | # is used to make sure that each rank shuffles the data indices 39 | # in the same order based on the same seed. Then different ranks 40 | # could use different indices to select non-overlapped data from the 41 | # same data list. 42 | self.seed = sync_random_seed(seed) 43 | 44 | # to avoid padding bug when meeting too small dataset 45 | if len(dataset) < self.num_replicas * samples_per_gpu: 46 | raise ValueError( 47 | 'You may use too small dataset and our distributed ' 48 | 'sampler cannot pad your dataset correctly. We highly ' 49 | 'recommend you to use fewer GPUs to finish your work') 50 | 51 | def __iter__(self): 52 | # deterministically shuffle based on epoch 53 | if self.shuffle: 54 | g = torch.Generator() 55 | # When :attr:`shuffle=True`, this ensures all replicas 56 | # use a different random ordering for each epoch. 57 | # Otherwise, the next iteration of this sampler will 58 | # yield the same ordering. 59 | g.manual_seed(self.epoch + self.seed) 60 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 61 | else: 62 | indices = torch.arange(len(self.dataset)).tolist() 63 | 64 | # add extra samples to make it evenly divisible 65 | indices += indices[:(self.total_size - len(indices))] 66 | assert len(indices) == self.total_size 67 | 68 | # subsample 69 | indices = indices[self.rank:self.total_size:self.num_replicas] 70 | assert len(indices) == self.num_samples 71 | 72 | return iter(indices) 73 | -------------------------------------------------------------------------------- /mmedit/core/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 10 | """Convert torch Tensors into image numpy arrays. 11 | 12 | After clamping to (min, max), image values will be normalized to [0, 1]. 13 | 14 | For different tensor shapes, this function will have different behaviors: 15 | 16 | 1. 4D mini-batch Tensor of shape (N x 3/1 x H x W): 17 | Use `make_grid` to stitch images in the batch dimension, and then 18 | convert it to numpy array. 19 | 2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W): 20 | Directly change to numpy array. 21 | 22 | Note that the image channel in input tensors should be RGB order. This 23 | function will convert it to cv2 convention, i.e., (H x W x C) with BGR 24 | order. 25 | 26 | Args: 27 | tensor (Tensor | list[Tensor]): Input tensors. 28 | out_type (numpy type): Output types. If ``np.uint8``, transform outputs 29 | to uint8 type with range [0, 255]; otherwise, float type with 30 | range [0, 1]. Default: ``np.uint8``. 
31 |         min_max (tuple): Min and max values for clamping.
32 | 
33 |     Returns:
34 |         (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D
35 |             ndarray of shape (H x W).
36 |     """
37 |     if not (torch.is_tensor(tensor) or
38 |             (isinstance(tensor, list)
39 |              and all(torch.is_tensor(t) for t in tensor))):
40 |         raise TypeError(
41 |             f'tensor or list of tensors expected, got {type(tensor)}')
42 | 
43 |     if torch.is_tensor(tensor):
44 |         tensor = [tensor]
45 |     result = []
46 |     for _tensor in tensor:
47 |         # Squeeze two times so that:
48 |         # 1. (1, 1, h, w) -> (h, w)
49 |         # 2. (1, 3, h, w) -> (3, h, w)
50 |         # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
51 |         _tensor = _tensor.squeeze(0).squeeze(0)
52 | 
53 | 
54 |         _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
55 |         _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
56 |         n_dim = _tensor.dim()
57 |         if n_dim == 4:
58 |             img_np = make_grid(
59 |                 _tensor, nrow=int(math.sqrt(_tensor.size(0))),
60 |                 normalize=False).numpy()
61 |             img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
62 |         elif n_dim == 3:
63 |             img_np = _tensor.numpy()
64 |             img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
65 |         elif n_dim == 2:
66 |             img_np = _tensor.numpy()
67 |         else:
68 |             raise ValueError('Only support 4D, 3D or 2D tensor. '
69 |                              f'But received with dimension: {n_dim}')
70 |         if out_type == np.uint8:
71 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
72 |             img_np = (img_np * 255.0).round()
73 |         img_np = img_np.astype(out_type)
74 |         result.append(img_np)
75 |     result = result[0] if len(result) == 1 else result
76 |     return result
77 | 
--------------------------------------------------------------------------------
/mmedit/models/common/sr_backbone_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.cnn import constant_init, kaiming_init
4 | from mmcv.utils.parrots_wrapper import _BatchNorm
5 | 
6 | 
7 | def default_init_weights(module, scale=1):
8 |     """Initialize network weights.
9 | 
10 |     Args:
11 |         module (nn.Module): Module whose submodules are to be initialized.
12 |         scale (float): Scale for the initialized weights, especially for
13 |             residual blocks.
14 |     """
15 |     for m in module.modules():
16 |         if isinstance(m, nn.Conv2d):
17 |             kaiming_init(m, a=0, mode='fan_in', bias=0)
18 |             m.weight.data *= scale
19 |         elif isinstance(m, nn.Linear):
20 |             kaiming_init(m, a=0, mode='fan_in', bias=0)
21 |             m.weight.data *= scale
22 |         elif isinstance(m, _BatchNorm):
23 |             constant_init(m.weight, val=1, bias=0)
24 | 
25 | 
26 | def make_layer(block, num_blocks, **kwargs):
27 |     """Make layers by stacking the same blocks.
28 | 
29 |     Args:
30 |         block (nn.Module): nn.Module class for the basic block.
31 |         num_blocks (int): Number of blocks.
32 | 
33 |     Returns:
34 |         nn.Sequential: Stacked blocks in nn.Sequential.
35 |     """
36 |     layers = []
37 |     for _ in range(num_blocks):
38 |         layers.append(block(**kwargs))
39 |     return nn.Sequential(*layers)
40 | 
41 | 
42 | class ResidualBlockNoBN(nn.Module):
43 |     """Residual block without BN.
44 | 
45 |     It has a style of:
46 | 
47 |     ::
48 | 
49 |         ---Conv-ReLU-Conv-+-
50 |          |________________|
51 | 
52 |     Args:
53 |         mid_channels (int): Channel number of intermediate features.
54 |             Default: 64.
55 |         res_scale (float): Used to scale the residual before addition. Default: 1.0.
56 |         groups (int): Number of blocked connections from input channels to output channels. Default: 1.
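
    Example (a shape-only sketch; the sizes below are illustrative)::

        >>> import torch
        >>> block = ResidualBlockNoBN(mid_channels=64)
        >>> out = block(torch.rand(1, 64, 32, 32))
        >>> out.shape  # the residual output keeps the input shape
        torch.Size([1, 64, 32, 32])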
57 |     """
58 | 
59 |     def __init__(self, mid_channels=64, res_scale=1.0, groups=1):
60 |         super().__init__()
61 |         self.res_scale = res_scale
62 |         self.conv1 = nn.Conv2d(
63 |             mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups)
64 |         self.conv2 = nn.Conv2d(
65 |             mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups)
66 | 
67 |         self.relu = nn.ReLU(inplace=True)
68 | 
69 |         # if res_scale < 1.0, use the default initialization, as in EDSR.
70 |         # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
71 |         if res_scale == 1.0:
72 |             self.init_weights()
73 | 
74 |     def init_weights(self):
75 |         """Initialize weights for ResidualBlockNoBN.
76 | 
77 |         Initialization methods like `kaiming_init` are for VGG-style modules.
78 |         For modules with residual paths, using a smaller std is better for
79 |         stability and performance. We empirically use 0.1. See more details in
80 |         "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks".
81 |         """
82 | 
83 |         for m in [self.conv1, self.conv2]:
84 |             default_init_weights(m, 0.1)
85 | 
86 |     def forward(self, x):
87 |         """Forward function.
88 | 
89 |         Args:
90 |             x (Tensor): Input tensor with shape (n, c, h, w).
91 | 
92 |         Returns:
93 |             Tensor: Forward results.
94 |         """
95 | 
96 |         identity = x
97 |         out = self.conv2(self.relu(self.conv1(x)))
98 |         return identity + out * self.res_scale
99 | 
--------------------------------------------------------------------------------
/mmedit/models/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from collections import OrderedDict
4 | 
5 | import torch
6 | import torch.nn as nn
7 | 
8 | 
9 | class BaseModel(nn.Module, metaclass=ABCMeta):
10 |     """Base model.
11 | 
12 |     All models should subclass it.
13 |     All subclasses should overwrite:
14 | 
15 |     ``init_weights``, which initializes the model weights.
16 | 
17 |     ``forward_train``, the forward pass used during training.
18 | 
19 |     ``forward_test``, the forward pass used during testing.
20 | 
21 |     ``train_step``, which runs one training step.
22 |     """
23 | 
24 |     @abstractmethod
25 |     def init_weights(self):
26 |         """Abstract method for initializing weights.
27 | 
28 |         All subclasses should overwrite it.
29 |         """
30 | 
31 |     @abstractmethod
32 |     def forward_train(self, imgs, labels):
33 |         """Abstract method for the training forward pass.
34 | 
35 |         All subclasses should overwrite it.
36 |         """
37 | 
38 |     @abstractmethod
39 |     def forward_test(self, imgs):
40 |         """Abstract method for the testing forward pass.
41 | 
42 |         All subclasses should overwrite it.
43 |         """
44 | 
45 |     def forward(self, imgs, labels, test_mode, **kwargs):
46 |         """Forward function for the base model.
47 | 
48 |         Args:
49 |             imgs (Tensor): Input image(s).
50 |             labels (Tensor): Ground-truth label(s).
51 |             test_mode (bool): Whether in test mode.
52 |             kwargs (dict): Other arguments.
53 | 
54 |         Returns:
55 |             Tensor: Forward results.
56 |         """
57 | 
58 |         if test_mode:
59 |             return self.forward_test(imgs, **kwargs)
60 | 
61 |         return self.forward_train(imgs, labels, **kwargs)
62 | 
63 |     @abstractmethod
64 |     def train_step(self, data_batch, optimizer):
65 |         """Abstract method for one training step.
66 | 
67 |         All subclasses should overwrite it.
68 |         """
69 | 
70 |     def val_step(self, data_batch, **kwargs):
71 |         """One validation step.
72 | 
73 |         By default this forwards in test mode; subclasses may overwrite it.
74 |         """
75 |         output = self.forward_test(**data_batch, **kwargs)
76 |         return output
77 | 
78 |     def parse_losses(self, losses):
79 |         """Parse the losses dict for different loss variants.
80 | 
81 |         Args:
82 |             losses (dict): Loss dict.
83 | 
84 |         Returns:
85 |             loss (Tensor): Sum of all loss terms.
86 |             log_vars (dict): Scalar losses for logging, keyed by variant.
87 |         """
88 |         log_vars = OrderedDict()
89 |         for loss_name, loss_value in losses.items():
90 |             if isinstance(loss_value, torch.Tensor):
91 |                 log_vars[loss_name] = loss_value.mean()
92 |             elif isinstance(loss_value, list):
93 |                 log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
94 |             else:
95 |                 raise TypeError(
96 |                     f'{loss_name} is not a tensor or list of tensors')
97 | 
98 |         loss = sum(_value for _key, _value in log_vars.items()
99 |                    if 'loss' in _key)
100 | 
101 |         log_vars['loss'] = loss
102 |         for name in log_vars:
103 |             log_vars[name] = log_vars[name].item()
104 | 
105 |         return loss, log_vars
106 | 
--------------------------------------------------------------------------------
/modules/DeformableAlignment.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on:
3 | BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from mmcv.cnn import constant_init
9 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
10 | 
11 | 
12 | 
13 | 
14 | 
15 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
16 |     """Second-order deformable alignment module.
17 | 
18 |     Args:
19 |         in_channels (int): Same as nn.Conv2d.
20 |         out_channels (int): Same as nn.Conv2d.
21 |         kernel_size (int or tuple[int]): Same as nn.Conv2d.
22 |         stride (int or tuple[int]): Same as nn.Conv2d.
23 |         padding (int or tuple[int]): Same as nn.Conv2d.
24 |         dilation (int or tuple[int]): Same as nn.Conv2d.
25 |         groups (int): Same as nn.Conv2d.
26 |         bias (bool or str): If specified as `auto`, it will be decided by the
27 |             norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
28 |             False.
29 |         max_residue_magnitude (int): The maximum magnitude of the offset
30 |             residue (Eq. 6 in the paper). Default: 10.
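
    Example (a shape sketch only; the channel sizes are inferred from
    ``conv_offset`` below, not taken from the paper)::

        >>> align = SecondOrderDeformableAlignment(
        ...     64, 64, 3, padding=1, deform_groups=16)
        >>> x = torch.rand(1, 64, 32, 32)
        >>> extra_feat = torch.rand(1, 3 * 64, 32, 32)
        >>> flow_1 = flow_2 = torch.rand(1, 2, 32, 32)
        >>> streak_mask = torch.rand(1, 1, 32, 32)
        >>> out = align(x, extra_feat, flow_1, flow_2, streak_mask)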
31 |     """
32 | 
33 |     def __init__(self, *args, **kwargs):
34 |         self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
35 | 
36 |         super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
37 | 
38 |         self.conv_offset = nn.Sequential(
39 |             nn.Conv2d(3 * self.out_channels + 5, self.out_channels, 3, 1, 1),
40 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
41 |             nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
42 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
43 |             nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
44 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
45 |             nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
46 |         )
47 | 
48 |         self.init_offset()
49 | 
50 |     def init_offset(self):
51 |         constant_init(self.conv_offset[-1], val=0, bias=0)
52 | 
53 |     def forward(self, x, extra_feat, flow_1, flow_2, streakmask_n1):
54 | 
55 |         extra_feat = torch.cat([extra_feat, flow_1, flow_2, streakmask_n1], dim=1)
56 |         out = self.conv_offset(extra_feat)
57 |         o1, o2, mask = torch.chunk(out, 3, dim=1)
58 | 
59 |         # offset
60 |         offset = self.max_residue_magnitude * torch.tanh(
61 |             torch.cat((o1, o2), dim=1))
62 |         offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
63 |         offset_1 = offset_1 + flow_1.flip(1).repeat(1,
64 |                                                     offset_1.size(1) // 2, 1,
65 |                                                     1)
66 |         offset_2 = offset_2 + flow_2.flip(1).repeat(1,
67 |                                                     offset_2.size(1) // 2, 1,
68 |                                                     1)
69 |         offset = torch.cat([offset_1, offset_2], dim=1)
70 | 
71 |         # mask
72 |         mask = torch.sigmoid(mask)
73 | 
74 |         return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
75 |                                        self.stride, self.padding,
76 |                                        self.dilation, self.groups,
77 |                                        self.deform_groups)
78 | 
--------------------------------------------------------------------------------
/mmedit/apis/restoration_face_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | from mmcv.parallel import collate, scatter
5 | 
6 | from mmedit.datasets.pipelines import Compose
7 | 
8 | try:
9 |     from facexlib.utils.face_restoration_helper import FaceRestoreHelper
10 |     has_facexlib = True
11 | except ImportError:
12 |     has_facexlib = False
13 | 
14 | 
15 | def restoration_face_inference(model, img, upscale_factor=1, face_size=1024):
16 |     """Run face restoration inference on an image with the model.
17 | 
18 |     Args:
19 |         model (nn.Module): The loaded model.
20 |         img (str): File path of the input image.
21 |         upscale_factor (int, optional): The number of times the input image
22 |             is upsampled. Default: 1.
23 |         face_size (int, optional): The size of the cropped and aligned faces.
24 |             Default: 1024.
25 | 
26 |     Returns:
27 |         np.ndarray: The predicted restoration result.
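
    Example (hypothetical paths; assumes ``mmedit.apis.init_model`` and a
    face restoration config, which this deraining repo does not ship)::

        >>> from mmedit.apis import init_model
        >>> model = init_model('face_config.py', 'face_ckpt.pth')
        >>> restored = restoration_face_inference(model, 'face.png',
        ...                                       upscale_factor=2)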
28 |     """
29 |     device = next(model.parameters()).device  # model device
30 | 
31 |     # build the data pipeline
32 |     if model.cfg.get('demo_pipeline', None):
33 |         test_pipeline = model.cfg.demo_pipeline
34 |     elif model.cfg.get('test_pipeline', None):
35 |         test_pipeline = model.cfg.test_pipeline
36 |     else:
37 |         test_pipeline = model.cfg.val_pipeline
38 | 
39 |     # remove gt from test_pipeline
40 |     keys_to_remove = ['gt', 'gt_path']
41 |     for key in keys_to_remove:
42 |         for pipeline in list(test_pipeline):
43 |             if 'key' in pipeline and key == pipeline['key']:
44 |                 test_pipeline.remove(pipeline)
45 |             if 'keys' in pipeline and key in pipeline['keys']:
46 |                 pipeline['keys'].remove(key)
47 |                 if len(pipeline['keys']) == 0:
48 |                     test_pipeline.remove(pipeline)
49 |             if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
50 |                 pipeline['meta_keys'].remove(key)
51 |     # compose the data pipeline
52 |     test_pipeline = Compose(test_pipeline)
53 | 
54 |     # face helper for detecting and aligning faces
55 |     assert has_facexlib, 'Please install FaceXLib to use the demo.'
56 |     face_helper = FaceRestoreHelper(
57 |         upscale_factor,
58 |         face_size=face_size,
59 |         crop_ratio=(1, 1),
60 |         det_model='retinaface_resnet50',
61 |         template_3points=True,
62 |         save_ext='png',
63 |         device=device)
64 | 
65 |     face_helper.read_image(img)
66 |     # get face landmarks for each face
67 |     face_helper.get_face_landmarks_5(
68 |         only_center_face=False, eye_dist_threshold=None)
69 |     # align and warp each face
70 |     face_helper.align_warp_face()
71 | 
72 |     for i, img in enumerate(face_helper.cropped_faces):
73 |         # prepare data
74 |         data = dict(lq=img.astype(np.float32))
75 |         data = test_pipeline(data)
76 |         data = collate([data], samples_per_gpu=1)
77 |         if 'cuda' in str(device):
78 |             data = scatter(data, [device])[0]
79 | 
80 |         with torch.no_grad():
81 |             output = model(test_mode=True, **data)['output']
82 |             output = torch.clamp(output, min=0, max=1)
83 | 
84 |         output = output.squeeze(0).permute(1, 2, 0)[:, :, [2, 1, 0]]
85 |         output = output.cpu().numpy() * 255  # (0, 255)
86 |         face_helper.add_restored_face(output)
87 | 
88 |     face_helper.get_inverse_affine(None)
89 |     restored_img = face_helper.paste_faces_to_input_image(upsample_img=None)
90 | 
91 |     return restored_img
92 | 
--------------------------------------------------------------------------------
/mmedit/core/hooks/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | 
4 | import mmcv
5 | import torch
6 | from mmcv.runner import HOOKS, Hook
7 | from mmcv.runner.dist_utils import master_only
8 | from torchvision.utils import save_image
9 | 
10 | from mmedit.utils import deprecated_function
11 | 
12 | 
13 | @HOOKS.register_module()
14 | class MMEditVisualizationHook(Hook):
15 |     """Visualization hook.
16 | 
17 |     In this hook, we use the official api `save_image` in torchvision to save
18 |     the visualization results.
19 | 
20 |     Args:
21 |         output_dir (str): The file path to store visualizations.
22 |         res_name_list (list[str]): Names of the results in the outputs dict
23 |             to visualize. The results in the outputs dict must be
24 |             torch.Tensors with shape (n, c, h, w).
25 |         interval (int): The interval of calling this hook. If set to -1,
26 |             the visualization hook will not be called. Default: -1.
27 |         filename_tmpl (str): Format string used to save images. The output
28 |             file name will be formatted with this template. Default: 'iter_{}.png'.
29 |         rerange (bool): Whether to rerange the output values from [-1, 1] to
30 |             [0, 1]. We highly recommend that users preprocess the
31 |             visualization results on their own. Here, we just provide a simple
32 |             interface. Default: True.
33 |         bgr2rgb (bool): Whether to reorder the channel dimension from BGR to
34 |             RGB. The saved image will be in RGB order.
35 |             Default: True.
36 |         nrow (int): The number of samples in a row. Default: 1.
37 |         padding (int): The number of padding pixels between samples.
38 |             Default: 4.
39 |     """
40 | 
41 |     def __init__(self,
42 |                  output_dir,
43 |                  res_name_list,
44 |                  interval=-1,
45 |                  filename_tmpl='iter_{}.png',
46 |                  rerange=True,
47 |                  bgr2rgb=True,
48 |                  nrow=1,
49 |                  padding=4):
50 |         assert mmcv.is_list_of(res_name_list, str)
51 |         self.output_dir = output_dir
52 |         self.res_name_list = res_name_list
53 |         self.interval = interval
54 |         self.filename_tmpl = filename_tmpl
55 |         self.bgr2rgb = bgr2rgb
56 |         self.rerange = rerange
57 |         self.nrow = nrow
58 |         self.padding = padding
59 | 
60 |         mmcv.mkdir_or_exist(self.output_dir)
61 | 
62 |     @master_only
63 |     def after_train_iter(self, runner):
64 |         """The behavior after each train iteration.
65 | 
66 |         Args:
67 |             runner (object): The runner.
68 |         """
69 |         if not self.every_n_iters(runner, self.interval):
70 |             return
71 |         results = runner.outputs['results']
72 | 
73 |         filename = self.filename_tmpl.format(runner.iter + 1)
74 | 
75 |         img_list = [results[k] for k in self.res_name_list if k in results]
76 |         img_cat = torch.cat(img_list, dim=3).detach()
77 |         if self.rerange:
78 |             img_cat = ((img_cat + 1) / 2)
79 |         if self.bgr2rgb:
80 |             img_cat = img_cat[:, [2, 1, 0], ...]
81 |         img_cat = img_cat.clamp_(0, 1)
82 |         save_image(
83 |             img_cat,
84 |             osp.join(self.output_dir, filename),
85 |             nrow=self.nrow,
86 |             padding=self.padding)
87 | 
88 | 
89 | @HOOKS.register_module()
90 | class VisualizationHook(MMEditVisualizationHook):
91 | 
92 |     @deprecated_function('0.16.0', '0.20.0', 'use \'MMEditVisualizationHook\'')
93 |     def __init__(self, *args, **kwargs):
94 |         super().__init__(*args, **kwargs)
95 | 
--------------------------------------------------------------------------------
/mmedit/datasets/pipelines/normalization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import numpy as np
4 | 
5 | from ..registry import PIPELINES
6 | 
7 | 
8 | @PIPELINES.register_module()
9 | class Normalize:
10 |     """Normalize images with the given mean and std values.
11 | 
12 |     Required keys are the keys in the attribute "keys". Added or modified
13 |     keys are the keys in the attribute "keys" and 'img_norm_cfg'.
14 |     It also supports normalizing a list of images.
15 | 
16 |     Args:
17 |         keys (Sequence[str]): The images to be normalized.
18 |         mean (np.ndarray): Mean values of different channels.
19 |         std (np.ndarray): Std values of different channels.
20 |         to_rgb (bool): Whether to convert channels from BGR to RGB.
21 |         save_original (bool): Whether to keep a copy of the unnormalized image under the key with postfix '_unnormalised'. Default: False.
22 |     """
23 |     def __init__(self, keys, mean, std, to_rgb=False, save_original=False):
24 |         self.keys = keys
25 |         self.mean = np.array(mean, dtype=np.float32)
26 |         self.std = np.array(std, dtype=np.float32)
27 |         self.to_rgb = to_rgb
28 |         self.save_original = save_original
29 | 
30 |     def __call__(self, results):
31 |         """Call function.
32 | 
33 |         Args:
34 |             results (dict): A dict containing the necessary information and
35 |                 data for augmentation.
36 | 
37 |         Returns:
38 |             dict: A dict containing the processed data and information.
39 |         """
40 |         for key in self.keys:
41 |             if isinstance(results[key], list):
42 |                 if self.save_original:
43 |                     results[key + '_unnormalised'] = [
44 |                         v.copy() for v in results[key]
45 |                     ]
46 |                 results[key] = [
47 |                     mmcv.imnormalize(v, self.mean, self.std, self.to_rgb)
48 |                     for v in results[key]
49 |                 ]
50 |             else:
51 |                 if self.save_original:
52 |                     results[key + '_unnormalised'] = results[key].copy()
53 |                 results[key] = mmcv.imnormalize(results[key], self.mean,
54 |                                                 self.std, self.to_rgb)
55 | 
56 |         results['img_norm_cfg'] = dict(
57 |             mean=self.mean, std=self.std, to_rgb=self.to_rgb)
58 |         return results
59 | 
60 |     def __repr__(self):
61 |         repr_str = self.__class__.__name__
62 |         repr_str += (f'(keys={self.keys}, mean={self.mean}, std={self.std}, '
63 |                      f'to_rgb={self.to_rgb})')
64 | 
65 |         return repr_str
66 | 
67 | 
68 | @PIPELINES.register_module()
69 | class RescaleToZeroOne:
70 |     """Transform the images into a range between 0 and 1.
71 | 
72 |     Required keys are the keys in the attribute "keys". Added or modified
73 |     keys are the keys in the attribute "keys".
74 |     It also supports rescaling a list of images.
75 | 
76 |     Args:
77 |         keys (Sequence[str]): The images to be transformed.
78 |     """
79 | 
80 |     def __init__(self, keys):
81 |         self.keys = keys
82 | 
83 |     def __call__(self, results):
84 |         """Call function.
85 | 
86 |         Args:
87 |             results (dict): A dict containing the necessary information and
88 |                 data for augmentation.
89 | 
90 |         Returns:
91 |             dict: A dict containing the processed data and information.
92 |         """
93 |         for key in self.keys:
94 |             if isinstance(results[key], list):
95 |                 results[key] = [
96 |                     v.astype(np.float32) / 255. for v in results[key]
97 |                 ]
98 |             else:
99 |                 results[key] = results[key].astype(np.float32) / 255.
100 |         return results
101 | 
102 |     def __repr__(self):
103 |         return self.__class__.__name__ + f'(keys={self.keys})'
104 | 
--------------------------------------------------------------------------------
/mmedit/datasets/base_sr_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | import os.path as osp
4 | from collections import defaultdict
5 | from pathlib import Path
6 | 
7 | import numpy as np
8 | from mmcv import scandir
9 | 
10 | from mmedit.core.registry import build_metric
11 | from .base_dataset import BaseDataset
12 | 
13 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
14 |                   '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
15 | FEATURE_BASED_METRICS = ['FID', 'KID']
16 | 
17 | 
18 | class BaseSRDataset(BaseDataset):
19 |     """Base class for super resolution datasets."""
20 | 
21 |     def __init__(self, pipeline, scale, test_mode=False):
22 |         super().__init__(pipeline, test_mode)
23 |         self.scale = scale
24 | 
25 |     @staticmethod
26 |     def scan_folder(path):
27 |         """Obtain an image path list (including sub-folders) from a given folder.
28 | 
29 |         Args:
30 |             path (str | :obj:`Path`): Folder path.
31 | 
32 |         Returns:
33 |             list[str]: Image list obtained from the given folder.
34 |         """
35 | 
36 |         if isinstance(path, (str, Path)):
37 |             path = str(path)
38 |         else:
39 |             raise TypeError("'path' must be a str or a Path object, "
40 |                             f'but received {type(path)}.')
41 | 
42 |         images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
43 |         images = [osp.join(path, v) for v in images]
44 |         assert images, f'{path} has no valid image file.'
45 |         return images
46 | 
47 |     def __getitem__(self, idx):
48 |         """Get item at each call.
49 | 
50 |         Args:
51 |             idx (int): Index for getting each item.
52 |         """
53 |         results = copy.deepcopy(self.data_infos[idx])
54 |         results['scale'] = self.scale
55 |         return self.pipeline(results)
56 | 
57 |     def evaluate(self, results, logger=None):
58 |         """Evaluate with different metrics.
59 | 
60 |         Args:
61 |             results (list[tuple]): The output of forward_test() of the model.
62 | 
63 |         Returns:
64 |             dict: Evaluation results dict.
65 |         """
66 |         if not isinstance(results, list):
67 |             raise TypeError(f'results must be a list, but got {type(results)}')
68 |         assert len(results) == len(self), (
69 |             'The length of results is not equal to the dataset len: '
70 |             f'{len(results)} != {len(self)}')
71 | 
72 |         results = [res['eval_result'] for res in results]  # a list of dict
73 |         eval_result = defaultdict(list)  # a dict of list
74 | 
75 |         for res in results:
76 |             for metric, val in res.items():
77 |                 eval_result[metric].append(val)
78 |         for metric, val_list in eval_result.items():
79 |             assert len(val_list) == len(self), (
80 |                 f'Length of evaluation result of {metric} is {len(val_list)}, '
81 |                 f'should be {len(self)}')
82 | 
83 |         # average the results
84 |         eval_result.update({
85 |             metric: sum(values) / len(self)
86 |             for metric, values in eval_result.items()
87 |             if metric not in ['_inception_feat'] + FEATURE_BASED_METRICS
88 |         })
89 | 
90 |         # evaluate feature-based metrics
91 |         if '_inception_feat' in eval_result:
92 |             feat1, feat2 = [], []
93 |             for f1, f2 in eval_result['_inception_feat']:
94 |                 feat1.append(f1)
95 |                 feat2.append(f2)
96 |             feat1 = np.concatenate(feat1, 0)
97 |             feat2 = np.concatenate(feat2, 0)
98 | 
99 |             for metric in FEATURE_BASED_METRICS:
100 |                 if metric in eval_result:
101 |                     metric_func = build_metric(eval_result[metric].pop())
102 |                     eval_result[metric] = metric_func(feat1, feat2)
103 | 
104 |             # delete a redundant key for clean logging
105 |             del eval_result['_inception_feat']
106 | 
107 |         return eval_result
108 | 
--------------------------------------------------------------------------------
/tools/deployment/mmedit2torchserve.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from argparse import ArgumentParser, Namespace
3 | from pathlib import Path
4 | from tempfile import TemporaryDirectory
5 | 
6 | import mmcv
7 | 
8 | try:
9 |     from model_archiver.model_packaging import package_model
10 |     from model_archiver.model_packaging_utils import ModelExportUtils
11 | except ImportError:
12 |     package_model = None
13 | 
14 | 
15 | def mmedit2torchserve(
16 |     config_file: str,
17 |     checkpoint_file: str,
18 |     output_folder: str,
19 |     model_name: str,
20 |     model_version: str = '1.0',
21 |     force: bool = False,
22 | ):
23 |     """Convert an MMEditing model (config + checkpoint) to a TorchServe `.mar`.
24 | 
25 |     Args:
26 |         config_file:
27 |             In MMEditing config format.
28 |             The contents vary for each task repository.
29 |         checkpoint_file:
30 |             In MMEditing checkpoint format.
31 |             The contents vary for each task repository.
32 |         output_folder:
33 |             Folder where `{model_name}.mar` will be created.
34 |             The file created will be in TorchServe archive format.
35 |         model_name:
36 |             If not None, used for naming the `{model_name}.mar` file
37 |             that will be created under `output_folder`.
38 |             If None, `{Path(checkpoint_file).stem}` will be used.
39 |         model_version:
40 |             Model's version.
41 |         force:
42 |             If True, an existing `{model_name}.mar` file under
43 |             `output_folder` will be overwritten.
44 |     """
45 |     mmcv.mkdir_or_exist(output_folder)
46 | 
47 |     config = mmcv.Config.fromfile(config_file)
48 | 
49 |     with TemporaryDirectory() as tmpdir:
50 |         config.dump(f'{tmpdir}/config.py')
51 | 
52 |         args_ = Namespace(
53 |             **{
54 |                 'model_file': f'{tmpdir}/config.py',
55 |                 'serialized_file': checkpoint_file,
56 |                 'handler': f'{Path(__file__).parent}/mmedit_handler.py',
57 |                 'model_name': model_name or Path(checkpoint_file).stem,
58 |                 'version': model_version,
59 |                 'export_path': output_folder,
60 |                 'force': force,
61 |                 'requirements_file': None,
62 |                 'extra_files': None,
63 |                 'runtime': 'python',
64 |                 'archive_format': 'default'
65 |             })
66 |         print(args_.model_name)
67 |         manifest = ModelExportUtils.generate_manifest_json(args_)
68 |         package_model(args_, manifest)
69 | 
70 | 
71 | def parse_args():
72 |     parser = ArgumentParser(
73 |         description='Convert MMEditing models to TorchServe `.mar` format.')
74 |     parser.add_argument('config', type=str, help='config file path')
75 |     parser.add_argument('checkpoint', type=str, help='checkpoint file path')
76 |     parser.add_argument(
77 |         '--output-folder',
78 |         type=str,
79 |         required=True,
80 |         help='Folder where `{model_name}.mar` will be created.')
81 |     parser.add_argument(
82 |         '--model-name',
83 |         type=str,
84 |         default=None,
85 |         help='If not None, used for naming the `{model_name}.mar` '
86 |         'file that will be created under `output_folder`. '
87 |         'If None, `{Path(checkpoint_file).stem}` will be used.')
88 |     parser.add_argument(
89 |         '--model-version',
90 |         type=str,
91 |         default='1.0',
92 |         help='Number used for versioning.')
93 |     parser.add_argument(
94 |         '-f',
95 |         '--force',
96 |         action='store_true',
97 |         help='overwrite the existing `{model_name}.mar`')
98 |     args_ = parser.parse_args()
99 | 
100 |     return args_
101 | 
102 | 
103 | if __name__ == '__main__':
104 |     args = parse_args()
105 | 
106 |     if package_model is None:
107 |         raise ImportError('`torch-model-archiver` is required. '
108 |                           'Try: pip install torch-model-archiver')
109 | 
110 |     mmedit2torchserve(args.config, args.checkpoint, args.output_folder,
111 |                       args.model_name, args.model_version, args.force)
112 | 
--------------------------------------------------------------------------------
/mmedit/models/losses/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import functools
3 | 
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def reduce_loss(loss, reduction):
8 |     """Reduce loss as specified.
9 | 
10 |     Args:
11 |         loss (Tensor): Elementwise loss tensor.
12 |         reduction (str): Options are "none", "mean" and "sum".
13 | 
14 |     Returns:
15 |         Tensor: Reduced loss tensor.
16 |     """
17 |     reduction_enum = F._Reduction.get_enum(reduction)
18 |     # none: 0, elementwise_mean: 1, sum: 2
19 |     if reduction_enum == 0:
20 |         return loss
21 |     if reduction_enum == 1:
22 |         return loss.mean()
23 | 
24 |     return loss.sum()
25 | 
26 | 
27 | def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False):
28 |     """Apply element-wise weight and reduce loss.
29 | 
30 |     Args:
31 |         loss (Tensor): Element-wise loss.
32 |         weight (Tensor): Element-wise weights. Default: None.
33 |         reduction (str): Same as built-in losses of PyTorch. Options are
34 |             "none", "mean" and "sum". Default: 'mean'.
35 |         sample_wise (bool): Whether to calculate the loss sample-wise.
This
36 |             argument only takes effect when `reduction` is 'mean' and `weight`
37 |             (argument of `forward()`) is not None. It first reduces the loss
38 |             per-sample with 'mean', and then averages over all samples.
39 |             Default: False.
40 | 
41 |     Returns:
42 |         Tensor: Processed loss values.
43 |     """
44 |     # if weight is specified, apply element-wise weight
45 |     if weight is not None:
46 |         assert weight.dim() == loss.dim()
47 |         assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
48 |         loss = loss * weight
49 | 
50 |     # if weight is not specified or reduction is sum, just reduce the loss
51 |     if weight is None or reduction == 'sum':
52 |         loss = reduce_loss(loss, reduction)
53 |     # if reduction is mean, then compute mean over masked region
54 |     elif reduction == 'mean':
55 |         # expand weight from N1HW to NCHW
56 |         if weight.size(1) == 1:
57 |             weight = weight.expand_as(loss)
58 |         # small value to prevent division by zero
59 |         eps = 1e-12
60 | 
61 |         # perform sample-wise mean
62 |         if sample_wise:
63 |             weight = weight.sum(dim=[1, 2, 3], keepdim=True)  # NCHW to N111
64 |             loss = (loss / (weight + eps)).sum() / weight.size(0)
65 |         # perform pixel-wise mean
66 |         else:
67 |             loss = loss.sum() / (weight.sum() + eps)
68 | 
69 |     return loss
70 | 
71 | 
72 | def masked_loss(loss_func):
73 |     """Create a masked version of a given loss function.
74 | 
75 |     To use this decorator, the loss function must have the signature like
76 |     `loss_func(pred, target, **kwargs)`. The function only needs to compute
77 |     element-wise loss without any reduction. This decorator will add weight
78 |     and reduction arguments to the function. The decorated function will have
79 |     the signature like `loss_func(pred, target, weight=None, reduction='mean',
80 |     sample_wise=False, **kwargs)`.
81 | 
82 |     :Example:
83 | 
84 |     >>> import torch
85 |     >>> @masked_loss
86 |     >>> def l1_loss(pred, target):
87 |     >>>     return (pred - target).abs()
88 | 
89 |     >>> pred = torch.Tensor([0, 2, 3])
90 |     >>> target = torch.Tensor([1, 1, 1])
91 |     >>> weight = torch.Tensor([1, 0, 1])
92 | 
93 |     >>> l1_loss(pred, target)
94 |     tensor(1.3333)
95 |     >>> l1_loss(pred, target, weight)
96 |     tensor(1.5000)
97 |     >>> l1_loss(pred, target, reduction='none')
98 |     tensor([1., 1., 2.])
99 |     >>> l1_loss(pred, target, weight, reduction='sum')
100 |     tensor(3.)
101 |     """
102 | 
103 |     @functools.wraps(loss_func)
104 |     def wrapper(pred,
105 |                 target,
106 |                 weight=None,
107 |                 reduction='mean',
108 |                 sample_wise=False,
109 |                 **kwargs):
110 |         # get element-wise loss
111 |         loss = loss_func(pred, target, **kwargs)
112 |         loss = mask_reduce_loss(loss, weight, reduction, sample_wise)
113 |         return loss
114 | 
115 |     return wrapper
116 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos (ACM MM 2023)
2 | Hongtao Wu, Yijun Yang, Haoyu Chen, Jingjing Ren, Lei Zhu
3 | 
4 | This repo is the official PyTorch implementation of [Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos](https://dl.acm.org/doi/abs/10.1145/3581783.3612001).
5 | It introduces the first dataset and approach for joint rain streak and raindrop removal in videos.
6 | 
7 | 
8 | 
9 | 
10 | > **Abstract:** *Videos captured in rainy weather are unavoidably corrupted by both rain streaks and raindrops in driving scenarios, and it is desirable and challenging to recover background details obscured by rain streaks and raindrops. However, existing video rain removal methods often address either video rain streak removal or video raindrop removal, and thus suffer from degraded performance when dealing with both simultaneously. The bottleneck is the lack of a video dataset in which each video frame contains both rain streaks and raindrops. To address this issue, in this work we generate a synthetic dataset, namely VRDS, with 102 rainy videos from diverse scenarios, where each video frame has the corresponding rain streak map, raindrop mask, and the underlying rain-free clean image (ground truth). Moreover, we devise a mask-guided progressive video deraining network (ViMP-Net) to remove both rain streaks and raindrops from each video frame. Specifically, we develop an intensity-guided alignment block to predict the rain streak intensity map and remove the rain streaks of the input rainy video at the first stage. Then, we predict a raindrop mask and pass it into a devised mask-guided dual transformer block to learn inter-frame and intra-frame transformer features, which are then fed into a decoder for further eliminating raindrops. Experimental results demonstrate that our ViMP-Net outperforms state-of-the-art methods on our synthetic dataset and real-world rainy videos.*
11 | 
12 | 
13 | 
14 | ## Our Dataset
15 | Our VRDS dataset can be downloaded [here](https://hkustgz-my.sharepoint.com/:f:/g/personal/hwu375_connect_hkust-gz_edu_cn/EmI_nfrnMyNAohEwNtnq50MB22RWxp-x_mtp264aVzOxlA?e=CjP3kO).
16 | 
17 | ## Installation
18 | 
19 | This implementation is based on [MMEditing](https://github.com/open-mmlab/mmediting),
20 | which is an open-source image and video editing toolbox.
21 | 
22 | 
23 | Below are quick steps for installation.
24 | 
25 | **Step 1.**
26 | Install PyTorch following the [official instructions](https://pytorch.org/get-started/locally/).
27 | 
28 | **Step 2.**
29 | Install MMCV with [MIM](https://github.com/open-mmlab/mim).
30 | 
31 | ```shell
32 | pip3 install openmim
33 | mim install mmcv-full
34 | ```
35 | 
36 | **Step 3.**
37 | Install ViMP-Net from source.
38 | 
39 | ```shell
40 | git clone https://github.com/TonyHongtaoWu/ViMP-Net.git
41 | cd ViMP-Net
42 | pip3 install -e .
43 | ```
44 | 
45 | Please refer to [MMEditing Installation](https://github.com/open-mmlab/mmediting/blob/master/docs/en/install.md) for more detailed instructions.
46 | 
47 | 
48 | ## Training and Testing
49 | You may need to adjust the dataset paths and the dataloader before starting.
50 | 
51 | You can train ViMP-Net on the VRDS dataset with 4 GPUs using the command below:
52 | 
53 | ```shell
54 | CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/derainers/ViMPNet/ViMPNet.py 4
55 | ```
56 | 
57 | You can use the following command with 1 GPU to test your trained model `model.pth`:
58 | 
59 | ```shell
60 | CUDA_VISIBLE_DEVICES=0 ./tools/dist_test.sh configs/derainers/ViMPNet/ViMPNet.py "model.pth" 1 --save-path './save_path/'
61 | ```
62 | 
63 | You can find a model checkpoint trained on the VRDS dataset [here](https://drive.google.com/drive/folders/1Iu_sxlN3nUpi99QUxWAnRP1a0mNNm2JU?usp=sharing).
64 | 
65 | 
66 | 
67 | ## Our Results
68 | The visual results of ViMP-Net can be downloaded from [Google Drive](https://drive.google.com/file/d/1yEFbQbh45hWOu2g4HR9-SUvZZpyJJd7l/view?usp=sharing) and [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/hwu375_connect_hkust-gz_edu_cn/EVM_XI3KcE9DgQaE9hbXvLQBjhnMP0rvQnSVcnOFcsMyTA?e=7tE2Kk).
69 | 
70 | 
71 | ## Contact
72 | Should you have any questions or suggestions, please contact hwu375@connect.hkust-gz.edu.cn.
73 | 
74 | ## Acknowledgement
75 | This code is based on [MMEditing](https://github.com/open-mmlab/mmagic) and [FuseFormer](https://github.com/ruiliu-ai/FuseFormer).
76 | 
77 | ## Citation
78 | If you find this repository helpful to your research, please consider citing the following:
79 | ```
80 | @inproceedings{wu2023mask,
81 |   title={Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos},
82 |   author={Wu, Hongtao and Yang, Yijun and Chen, Haoyu and Ren, Jingjing and Zhu, Lei},
83 |   booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
84 |   pages={7216--7225},
85 |   year={2023}
86 | }
87 | ```
88 | 
--------------------------------------------------------------------------------
/mmedit/core/evaluation/eval_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | 
4 | from mmcv.runner import Hook
5 | from torch.utils.data import DataLoader
6 | 
7 | 
8 | class EvalIterHook(Hook):
9 |     """Non-distributed evaluation hook for iteration-based runners.
10 | 
11 |     This hook will regularly perform evaluation at a given interval when
12 |     running in a non-distributed environment.
13 | 
14 |     Args:
15 |         dataloader (DataLoader): A PyTorch dataloader.
16 |         interval (int): Evaluation interval. Default: 1.
17 |         eval_kwargs (dict): Other eval kwargs. It contains:
18 |             save_image (bool): Whether to save images.
19 |             save_path (str): The path to save images.
20 |     """
21 | 
22 |     def __init__(self, dataloader, interval=1, **eval_kwargs):
23 |         if not isinstance(dataloader, DataLoader):
24 |             raise TypeError('dataloader must be a pytorch DataLoader, '
25 |                             f'but got {type(dataloader)}')
26 |         self.dataloader = dataloader
27 |         self.interval = interval
28 |         self.eval_kwargs = eval_kwargs
29 |         self.save_image = self.eval_kwargs.pop('save_image', False)
30 |         self.save_path = self.eval_kwargs.pop('save_path', None)
31 | 
32 |     def after_train_iter(self, runner):
33 |         """The behavior after each train iteration.
34 | 
35 |         Args:
36 |             runner (``mmcv.runner.BaseRunner``): The runner.
37 |         """
38 |         if not self.every_n_iters(runner, self.interval):
39 |             return
40 |         runner.log_buffer.clear()
41 |         from mmedit.apis import single_gpu_test
42 |         results = single_gpu_test(
43 |             runner.model,
44 |             self.dataloader,
45 |             save_image=self.save_image,
46 |             save_path=self.save_path,
47 |             iteration=runner.iter)
48 |         self.evaluate(runner, results)
49 | 
50 |     def evaluate(self, runner, results):
51 |         """Evaluation function.
52 | 
53 |         Args:
54 |             runner (``mmcv.runner.BaseRunner``): The runner.
55 |             results (dict): Model forward results.
56 |         """
57 |         eval_res = self.dataloader.dataset.evaluate(
58 |             results, logger=runner.logger, **self.eval_kwargs)
59 |         for name, val in eval_res.items():
60 |             if isinstance(val, dict):
61 |                 runner.log_buffer.output.update(val)
62 |                 continue
63 |             runner.log_buffer.output[name] = val
64 |         runner.log_buffer.ready = True
65 |         # Call `after_val_epoch` after evaluation. This is a hack:
66 |         # epochs do not naturally exist in IterBasedRunner, so we
67 |         # consider the end of an evaluation as the end of an epoch.
68 |         # With this hack, we can support epoch-based hooks.
69 | 
70 |         if 'iter' in runner.__class__.__name__.lower():
71 |             runner.call_hook('after_val_epoch')
72 | 
73 | 
74 | class DistEvalIterHook(EvalIterHook):
75 |     """Distributed evaluation hook.
76 | 
77 |     Args:
78 |         dataloader (DataLoader): A PyTorch dataloader.
79 |         interval (int): Evaluation interval. Default: 1.
80 |         tmpdir (str | None): Temporary directory to save the results of all
81 |             processes. Default: None.
82 |         gpu_collect (bool): Whether to use gpu or cpu to collect results.
83 |             Default: False.
84 |         eval_kwargs (dict): Other eval kwargs. It may contain:
85 |             save_image (bool): Whether to save images.
86 |             save_path (str): The path to save images.
87 |     """
88 | 
89 |     def __init__(self,
90 |                  dataloader,
91 |                  interval=1,
92 |                  gpu_collect=False,
93 |                  **eval_kwargs):
94 |         super().__init__(dataloader, interval, **eval_kwargs)
95 |         self.gpu_collect = gpu_collect
96 | 
97 |     def after_train_iter(self, runner):
98 |         """The behavior after each train iteration.
99 | 
100 |         Args:
101 |             runner (``mmcv.runner.BaseRunner``): The runner.
102 |         """
103 |         if not self.every_n_iters(runner, self.interval):
104 |             return
105 |         runner.log_buffer.clear()
106 |         from mmedit.apis import multi_gpu_test
107 |         results = multi_gpu_test(
108 |             runner.model,
109 |             self.dataloader,
110 |             tmpdir=osp.join(runner.work_dir, '.eval_hook'),
111 |             gpu_collect=self.gpu_collect,
112 |             save_image=self.save_image,
113 |             save_path=self.save_path,
114 |             iteration=runner.iter)
115 |         if runner.rank == 0:
116 |             print('\n')
117 |             self.evaluate(runner, results)
118 | 
--------------------------------------------------------------------------------
/configs/derainers/ViMPNet/ViMPNet.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 |     type='DrainNet',
4 |     generator=dict(
5 |         type='ViMPNet',
6 |         mid_channels=64,
7 |         num_blocks=9,
8 |         spynet_pretrained='https://download.openmmlab.com/mmediting/restorers/'
9 |         'basicvsr/spynet_20210409-c6c1bd09.pth'),
10 |     pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
11 | )
12 | # model training and testing settings
13 | train_cfg = dict(fix_iter=5000)
14 | test_cfg = dict(metrics=['PSNR'], crop_border=0)
15 | 
16 | train_dataset_type = 'SRFolderMultipleGTDataset'
17 | val_dataset_type = 'SRFolderMultipleGTDataset'
18 | 
19 | train_pipeline = [
20 |     dict(
21 |         type='GenerateSegmentIndices',
22 |         interval_list=[1],
23 |         filename_tmpl='{:08d}.png',
24 |         start_idx=0),
25 |     dict(
26 |         type='LoadImageFromFileList',
27 |         io_backend='disk',
28 |         key='lq',
29 |         channel_order='rgb'),
30 |     dict(
31 |         type='LoadImageFromFileList',
32 |         io_backend='disk',
33 |         key='gt',
34 |         channel_order='rgb'),
35 |     dict(
36 |         type='LoadImageFromFileList',
37 |         io_backend='disk',
38 |         key='mask',
39 |         flag='grayscale',
40 |         channel_order='rgb'),
41 |     dict(
42 |         type='LoadImageFromFileList',
43 |         io_backend='disk',
44 |         key='drop',
45 |         channel_order='rgb'),
46 |     dict(
47 |         type='LoadImageFromFileList',
48 |         io_backend='disk',
49 |         key='streak',
50 |         flag='grayscale',
51 |         channel_order='rgb'),
52 | 
53 |     dict(type='RescaleToZeroOne', keys=['lq', 'gt', 'mask', 'drop', 'streak']),
54 |     dict(type='PairedRandomCrop', gt_patch_size=256),
55 |     dict(
56 |         type='Flip', keys=['lq', 'gt', 'mask', 'drop', 'streak'], flip_ratio=0.5,
57 |         direction='horizontal'),
58 |     dict(type='Flip', keys=['lq', 'gt', 'mask', 'drop', 'streak'], flip_ratio=0.5, direction='vertical'),
59 |     dict(type='RandomTransposeHW', keys=['lq', 'gt', 'mask', 'drop', 'streak'], transpose_ratio=0.5),
60 |     dict(type='FramesToTensor', keys=['lq', 'gt', 'mask', 'drop', 'streak']),
61 |     dict(type='Collect', keys=['lq', 'gt', 'mask', 'drop', 'streak'], meta_keys=['lq_path', 'gt_path', 'mask_path', 'drop_path', 'streak_path'])
62 | ]
63 | 
64 | test_pipeline = [
65 |     dict(type='GenerateSegmentIndices', interval_list=[1], filename_tmpl='{:08d}.png'),
66 |     dict(
67 |         type='LoadImageFromFileList',
68 |         io_backend='disk',
69 |         key='lq',
70 |         channel_order='rgb'),
71 |     dict(
72 |         type='LoadImageFromFileList',
73 |         io_backend='disk',
74 |         key='gt',
75 |         channel_order='rgb'),
76 |     dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
77 |     dict(type='FramesToTensor', keys=['lq', 'gt']),
78 |     dict(
79 |         type='Collect',
80 |         keys=['lq', 'gt'],
81 |         meta_keys=['lq_path', 'gt_path', 'key'])
82 | ]
83 | 
84 | data = dict(
85 |     workers_per_gpu=5,
86 |     train_dataloader=dict(samples_per_gpu=1, drop_last=True),
87 |     val_dataloader=dict(samples_per_gpu=1),
88 |     test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=1),
89 | 
90 | 
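    # A validation split is not configured here. A hedged sketch of what one
    # could look like (the folder paths below are placeholders, not data
    # shipped with the repo):
    # val=dict(
    #     type=val_dataset_type,
    #     lq_folder='../data/vallq',
    #     gt_folder='../data/valgt',
    #     pipeline=test_pipeline,
    #     scale=1,
    #     test_mode=True),
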
    # train
91 |     train=dict(
92 |         type='RepeatDataset',
93 |         times=1000,
94 |         dataset=dict(
95 |             type=train_dataset_type,
96 |             lq_folder='../data/lq',
97 |             gt_folder='../data/gt',
98 |             mask_folder='../data/01drop',
99 |             drop_folder='../data/drop',
100 |             streak_folder='../data/mstreak',
101 |             pipeline=train_pipeline,
102 |             scale=1,
103 |             num_input_frames=5,
104 |             test_mode=False)),
105 | 
106 |     # test
107 |     test=dict(
108 |         type=val_dataset_type,
109 |         lq_folder='../data/testlq',
110 |         gt_folder='../data/testgt',
111 |         pipeline=test_pipeline,
112 |         scale=1,
113 |         test_mode=True),
114 | )
115 | 
116 | # optimizer
117 | optimizers = dict(
118 |     generator=dict(
119 |         type='Adam',
120 |         lr=1e-4,
121 |         betas=(0.9, 0.99),
122 |         paramwise_cfg=dict(custom_keys={'spynet': dict(lr_mult=0.25)})))
123 | 
124 | # learning policy
125 | total_iters = 500000
126 | lr_config = dict(
127 |     policy='CosineRestart',
128 |     by_epoch=False,
129 |     periods=[500000],
130 |     restart_weights=[1],
131 |     min_lr=1e-9)
132 | 
133 | checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False)
134 | 
135 | log_config = dict(
136 |     interval=1000,
137 |     hooks=[
138 |         dict(type='TextLoggerHook', by_epoch=False),
139 |     ])
140 | visual_config = None
141 | 
142 | # runtime settings
143 | dist_params = dict(backend='nccl')
144 | log_level = 'INFO'
145 | 
146 | work_dir = '../ViMPNettrain/'
147 | load_from = None
148 | resume_from = None
149 | workflow = [('train', 1)]
150 | find_unused_parameters = True
151 | 
--------------------------------------------------------------------------------
/mmedit/apis/restoration_video_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import glob
3 | import os.path as osp
4 | 
5 | import mmcv
6 | import numpy as np
7 | import torch
8 | 
9 | from mmedit.datasets.pipelines import Compose
10 | 
11 | VIDEO_EXTENSIONS = ('.mp4', '.mov')
12 | 
13 | 
14 | def pad_sequence(data, window_size):
15 |     padding = window_size // 2
16 | 
17 |     data = torch.cat([
18 |         data[:, 1 + padding:1 + 2 * padding].flip(1), data,
19 |         data[:, -1 - 2 * padding:-1 - padding].flip(1)
20 |     ],
21 |                      dim=1)
22 | 
23 |     return data
24 | 
25 | 
26 | def restoration_video_inference(model,
27 |                                 img_dir,
28 |                                 window_size,
29 |                                 start_idx,
30 |                                 filename_tmpl,
31 |                                 max_seq_len=None):
32 |     """Inference a video sequence with the model.
33 | 
34 |     Args:
35 |         model (nn.Module): The loaded model.
36 |         img_dir (str): Directory of the input video.
37 |         window_size (int): The window size used in the sliding-window framework.
38 |             This value should be set according to the settings of the network.
39 |             A value smaller than 0 means using a recurrent framework.
40 |         start_idx (int): The index corresponding to the first frame in the
41 |             sequence.
42 |         filename_tmpl (str): Template for the file name.
43 |         max_seq_len (int | None): The maximum sequence length that the model
44 |             processes. If the sequence length is larger than this number,
45 |             the sequence is split into multiple segments. If it is None,
46 |             the entire sequence is processed at once.
47 | 
48 |     Returns:
49 |         Tensor: The predicted restoration result.
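
    Example (hypothetical paths; a sketch of recurrent-mode usage, i.e.
    ``window_size < 0``, assuming ``mmedit.apis.init_model`` is available)::

        >>> from mmedit.apis import init_model
        >>> model = init_model('configs/derainers/ViMPNet/ViMPNet.py',
        ...                    'model.pth')
        >>> output = restoration_video_inference(
        ...     model, './frames/seq_001', window_size=-1, start_idx=0,
        ...     filename_tmpl='{:08d}.png')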
50 |     """
51 | 
52 |     device = next(model.parameters()).device  # model device
53 | 
54 |     # build the data pipeline
55 |     if model.cfg.get('demo_pipeline', None):
56 |         test_pipeline = model.cfg.demo_pipeline
57 |     elif model.cfg.get('test_pipeline', None):
58 |         test_pipeline = model.cfg.test_pipeline
59 |     else:
60 |         test_pipeline = model.cfg.val_pipeline
61 | 
62 |     # check if the input is a video
63 |     file_extension = osp.splitext(img_dir)[1]
64 |     if file_extension in VIDEO_EXTENSIONS:
65 |         video_reader = mmcv.VideoReader(img_dir)
66 |         # load the images
67 |         data = dict(lq=[], lq_path=None, key=img_dir)
68 |         for frame in video_reader:
69 |             data['lq'].append(np.flip(frame, axis=2))
70 | 
71 |         # remove the data loading pipeline
72 |         tmp_pipeline = []
73 |         for pipeline in test_pipeline:
74 |             if pipeline['type'] not in [
75 |                     'GenerateSegmentIndices', 'LoadImageFromFileList'
76 |             ]:
77 |                 tmp_pipeline.append(pipeline)
78 |         test_pipeline = tmp_pipeline
79 |     else:
80 |         # the first element in the pipeline must be 'GenerateSegmentIndices'
81 |         if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
82 |             raise TypeError('The first element in the pipeline must be '
83 |                             f'"GenerateSegmentIndices", but got '
84 |                             f'"{test_pipeline[0]["type"]}".')
85 | 
86 |         # specify start_idx and filename_tmpl
87 |         test_pipeline[0]['start_idx'] = start_idx
88 |         test_pipeline[0]['filename_tmpl'] = filename_tmpl
89 | 
90 |         # prepare data
91 |         sequence_length = len(glob.glob(osp.join(img_dir, '*')))
92 |         lq_folder = osp.dirname(img_dir)
93 |         key = osp.basename(img_dir)
94 |         data = dict(
95 |             lq_path=lq_folder,
96 |             gt_path='',
97 |             key=key,
98 |             sequence_length=sequence_length)
99 | 
100 |     # compose the pipeline
101 |     test_pipeline = Compose(test_pipeline)
102 |     data = test_pipeline(data)
103 |     data = data['lq'].unsqueeze(0)  # in cpu
104 | 
105 |     # forward the model
106 |     with torch.no_grad():
107 |         if window_size > 0:  # sliding window framework
108 |             data = pad_sequence(data, window_size)
109 |             result = []
110 |             for i in range(0, data.size(1) - 2 * (window_size // 2)):
111 |                 data_i = data[:, i:i + window_size].to(device)
112 |                 result.append(model(lq=data_i, test_mode=True)['output'].cpu())
113 |             result = torch.stack(result, dim=1)
114 |         else:  # recurrent framework
115 |             if max_seq_len is None:
116 |                 result = model(
117 |                     lq=data.to(device), test_mode=True)['output'].cpu()
118 |             else:
119 |                 result = []
120 |                 for i in range(0, data.size(1), max_seq_len):
121 |                     result.append(
122 |                         model(
123 |                             lq=data[:, i:i + max_seq_len].to(device),
124 |                             test_mode=True)['output'].cpu())
125 |                 result = torch.cat(result, dim=1)
126 |     return result
127 | 
--------------------------------------------------------------------------------
/mmedit/core/hooks/ema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 | from copy import deepcopy
4 | from functools import partial
5 | 
6 | import mmcv
7 | import torch
8 | from mmcv.parallel import is_module_wrapper
9 | from mmcv.runner import HOOKS, Hook
10 | 
11 | 
12 | @HOOKS.register_module()
13 | class ExponentialMovingAverageHook(Hook):
14 |     """Exponential Moving Average Hook.
15 | 
16 |     Exponential moving average is a trick that is widely used in the current
17 |     GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to
18 |     maintain a model with the same architecture whose parameters are
19 |     updated as a moving average of the trained weights in the original model.
20 |     In general, the model with moving averaged weights achieves better
21 |     performance.
22 | 
23 |     Args:
24 |         module_keys (str | tuple[str]): The name(s) of the ema model. Note
25 |             that these keys must end with '_ema', so that the original model
26 |             can be found by discarding the last four characters.
27 |         interp_mode (str, optional): Mode of the interpolation method.
28 |             Defaults to 'lerp'.
29 |         interp_cfg (dict | None, optional): Set arguments of the interpolation
30 |             function. Defaults to None.
31 |         interval (int, optional): Interval (in iterations) between EMA
32 |             updates. Default: -1.
33 |         start_iter (int, optional): Start iteration for ema. If the start
34 |             iteration is not reached, the weights of the ema model will stay
35 |             the same as the original one. Otherwise, its parameters are updated
36 |             as a moving average of the trained weights in the original model.
37 |             Default: 0.
38 |     """
39 | 
40 |     def __init__(self,
41 |                  module_keys,
42 |                  interp_mode='lerp',
43 |                  interp_cfg=None,
44 |                  interval=-1,
45 |                  start_iter=0):
46 |         super().__init__()
47 |         assert isinstance(module_keys, str) or mmcv.is_tuple_of(
48 |             module_keys, str)
49 |         self.module_keys = (module_keys, ) if isinstance(module_keys,
50 |                                                          str) else module_keys
51 |         # sanity check for the format of module keys
52 |         for k in self.module_keys:
53 |             assert k.endswith(
54 |                 '_ema'), 'You should give keys that end with "_ema".'
55 |         self.interp_mode = interp_mode
56 |         self.interp_cfg = dict() if interp_cfg is None else deepcopy(
57 |             interp_cfg)
58 |         self.interval = interval
59 |         self.start_iter = start_iter
60 | 
61 |         assert hasattr(
62 |             self, interp_mode
63 |         ), f'Currently, we do not support {self.interp_mode} for EMA.'
64 |         self.interp_func = partial(
65 |             getattr(self, interp_mode), **self.interp_cfg)
66 | 
67 |     @staticmethod
68 |     def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
69 |         m = momentum if trainable else momentum_nontrainable
70 |         return a + (b - a) * m
71 | 
72 |     def every_n_iters(self, runner, n):
73 |         if runner.iter < self.start_iter:
74 |             return True
75 |         return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False
76 | 
77 |     @torch.no_grad()
78 |     def after_train_iter(self, runner):
79 |         if not self.every_n_iters(runner, self.interval):
80 |             return
81 | 
82 |         model = runner.model.module if is_module_wrapper(
83 |             runner.model) else runner.model
84 | 
85 |         for key in self.module_keys:
86 |             # get the current ema states
87 |             ema_net = getattr(model, key)
88 |             states_ema = ema_net.state_dict(keep_vars=False)
89 |             # get the current original states
90 |             net = getattr(model, key[:-4])
91 |             states_orig = net.state_dict(keep_vars=True)
92 | 
93 |             for k, v in states_orig.items():
94 |                 if runner.iter < self.start_iter:
95 |                     states_ema[k].data.copy_(v.data)
96 |                 else:
97 |                     states_ema[k] = self.interp_func(
98 |                         v, states_ema[k], trainable=v.requires_grad).detach()
99 |             ema_net.load_state_dict(states_ema, strict=True)
100 | 
101 |     def before_run(self, runner):
102 |         model = runner.model.module if is_module_wrapper(
103 |             runner.model) else runner.model
104 |         # sanity check for the ema model
105 |         for k in self.module_keys:
106 |             if not hasattr(model, k) and not hasattr(model, k[:-4]):
107 |                 raise RuntimeError(
108 |                     f'Cannot find either {k[:-4]} or {k} network for the EMA hook.')
109 |             if not hasattr(model, k) and hasattr(model, k[:-4]):
110 |                 setattr(model, k, deepcopy(getattr(model, k[:-4])))
111 |                 warnings.warn(
112 |                     f'We do not suggest constructing and initializing the '
113 |                     f'EMA model {k} in a hook.
You may explicitly define it by yourself.') 114 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/random_down_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv import imresize 7 | 8 | from ..registry import PIPELINES 9 | 10 | 11 | @PIPELINES.register_module() 12 | class RandomDownSampling: 13 | """Generate LQ image from GT (and crop), which will randomly pick a scale. 14 | 15 | Args: 16 | scale_min (float): The minimum of upsampling scale, inclusive. 17 | Default: 1.0. 18 | scale_max (float): The maximum of upsampling scale, exclusive. 19 | Default: 4.0. 20 | patch_size (int): The cropped lr patch size. 21 | Default: None, means no crop. 22 | interpolation (str): Interpolation method, accepted values are 23 | "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' 24 | backend, "nearest", "bilinear", "bicubic", "box", "lanczos", 25 | "hamming" for 'pillow' backend. 26 | Default: "bicubic". 27 | backend (str | None): The image resize backend type. Options are `cv2`, 28 | `pillow`, `None`. If backend is None, the global imread_backend 29 | specified by ``mmcv.use_backend()`` will be used. 30 | Default: "pillow". 31 | 32 | Scale will be picked in the range of [scale_min, scale_max). 33 | """ 34 | 35 | def __init__(self, 36 | scale_min=1.0, 37 | scale_max=4.0, 38 | patch_size=None, 39 | interpolation='bicubic', 40 | backend='pillow'): 41 | assert scale_max >= scale_min 42 | self.scale_min = scale_min 43 | self.scale_max = scale_max 44 | self.patch_size = patch_size 45 | self.interpolation = interpolation 46 | self.backend = backend 47 | 48 | def __call__(self, results): 49 | """Call function. 50 | 51 | Args: 52 | results (dict): A dict containing the necessary information and 53 | data for augmentation. 'gt' is required. 54 | 55 | Returns: 56 | dict: A dict containing the processed data and information. 57 | modified 'gt', supplement 'lq' and 'scale' to keys. 58 | """ 59 | img = results['gt'] 60 | scale = np.random.uniform(self.scale_min, self.scale_max) 61 | 62 | if self.patch_size is None: 63 | h_lr = math.floor(img.shape[-3] / scale + 1e-9) 64 | w_lr = math.floor(img.shape[-2] / scale + 1e-9) 65 | img = img[:round(h_lr * scale), :round(w_lr * scale), :] 66 | img_down = resize_fn(img, (w_lr, h_lr), self.interpolation, 67 | self.backend) 68 | crop_lr, crop_hr = img_down, img 69 | else: 70 | w_lr = self.patch_size 71 | w_hr = round(w_lr * scale) 72 | x0 = np.random.randint(0, img.shape[-3] - w_hr) 73 | y0 = np.random.randint(0, img.shape[-2] - w_hr) 74 | crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :] 75 | crop_lr = resize_fn(crop_hr, w_lr, self.interpolation, 76 | self.backend) 77 | results['gt'] = crop_hr 78 | results['lq'] = crop_lr 79 | results['scale'] = scale 80 | 81 | return results 82 | 83 | def __repr__(self): 84 | repr_str = self.__class__.__name__ 85 | repr_str += (f' scale_min={self.scale_min}, ' 86 | f'scale_max={self.scale_max}, ' 87 | f'patch_size={self.patch_size}, ' 88 | f'interpolation={self.interpolation}, ' 89 | f'backend={self.backend}') 90 | 91 | return repr_str 92 | 93 | 94 | def resize_fn(img, size, interpolation='bicubic', backend='pillow'): 95 | """Resize the given image to a given size. 96 | 97 | Args: 98 | img (ndarray | torch.Tensor): The input image. 99 | size (int | tuple[int]): Target size w or (w, h). 
100 |         interpolation (str): Interpolation method, accepted values are
101 |             "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
102 |             backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
103 |             "hamming" for 'pillow' backend.
104 |             Default: "bicubic".
105 |         backend (str | None): The image resize backend type. Options are `cv2`,
106 |             `pillow`, `None`. If backend is None, the global imread_backend
107 |             specified by ``mmcv.use_backend()`` will be used.
108 |             Default: "pillow".
109 | 
110 |     Returns:
111 |         ndarray | torch.Tensor: `resized_img`, whose type is the same as `img`.
112 |     """
113 |     if isinstance(size, int):
114 |         size = (size, size)
115 |     if isinstance(img, np.ndarray):
116 |         return imresize(
117 |             img, size, interpolation=interpolation, backend=backend)
118 |     elif isinstance(img, torch.Tensor):
119 |         image = imresize(
120 |             img.numpy(), size, interpolation=interpolation, backend=backend)
121 |         return torch.from_numpy(image)
122 | 
123 |     else:
124 |         raise TypeError('img should be np.ndarray or torch.Tensor, '
125 |                         f'but got {type(img)}')
126 | 
--------------------------------------------------------------------------------
/mmedit/datasets/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 | 
4 | import numpy as np
5 | import torch
6 | from mmcv.utils import print_log
7 | 
8 | _integer_types = (
9 |     np.byte,
10 |     np.ubyte,  # 8 bits
11 |     np.short,
12 |     np.ushort,  # 16 bits
13 |     np.intc,
14 |     np.uintc,  # 16 or 32 or 64 bits
15 |     np.int_,
16 |     np.uint,  # 32 or 64 bits
17 |     np.longlong,
18 |     np.ulonglong)  # 64 bits
19 | 
20 | _integer_ranges = {
21 |     t: (np.iinfo(t).min, np.iinfo(t).max)
22 |     for t in _integer_types
23 | }
24 | 
25 | dtype_range = {
26 |     np.bool_: (False, True),
27 |     np.bool8: (False, True),
28 |     np.float16: (-1, 1),
29 |     np.float32: (-1, 1),
30 |     np.float64: (-1, 1)
31 | }
32 | dtype_range.update(_integer_ranges)
33 | 
34 | 
35 | def dtype_limits(image, clip_negative=False):
36 |     """Return intensity limits, i.e. (min, max) tuple, of the image's dtype.
37 | 
38 |     This function is adapted from skimage:
39 |     https://github.com/scikit-image/scikit-image/blob/
40 |     7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/util/dtype.py#L35
41 | 
42 |     Args:
43 |         image (ndarray): Input image.
44 |         clip_negative (bool, optional): If True, clip the negative range
45 |             (i.e. return 0 for min intensity) even if the image dtype allows
46 |             negative values.
47 | 
48 |     Returns:
49 |         tuple: Lower and upper intensity limits.
50 |     """
51 |     imin, imax = dtype_range[image.dtype.type]
52 |     if clip_negative:
53 |         imin = 0
54 |     return imin, imax
55 | 
56 | 
57 | def adjust_gamma(image, gamma=1, gain=1):
58 |     """Perform gamma correction on the input image.
59 | 
60 |     This function is adapted from skimage:
61 |     https://github.com/scikit-image/scikit-image/blob/
62 |     7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/exposure/
63 |     exposure.py#L439-L494
64 | 
65 |     Also known as Power Law Transform.
66 |     This function transforms the input image pixelwise according to the
67 |     equation ``O = I**gamma`` after scaling each pixel to the range 0 to 1.
68 | 
69 |     Args:
70 |         image (ndarray): Input image.
71 |         gamma (float, optional): Non-negative real number. Defaults to 1.
72 |         gain (float, optional): The constant multiplier. Defaults to 1.
73 | 
74 |     Returns:
75 |         ndarray: Gamma corrected output image. 
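
    Example (an illustrative doctest; for uint8 input the full dtype range
    is scaled to [0, 1] before the power law is applied, so gamma > 1
    darkens the mid-tones):

        >>> import numpy as np
        >>> img = np.array([[0, 64, 128, 255]], dtype=np.uint8)
        >>> adjust_gamma(img, gamma=2)
        array([[  0,  16,  64, 255]], dtype=uint8)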
76 | """ 77 | if np.any(image < 0): 78 | raise ValueError('Image Correction methods work correctly only on ' 79 | 'images with non-negative values. Use ' 80 | 'skimage.exposure.rescale_intensity.') 81 | 82 | dtype = image.dtype.type 83 | 84 | if gamma < 0: 85 | raise ValueError('Gamma should be a non-negative real number.') 86 | 87 | scale = float(dtype_limits(image, True)[1] - dtype_limits(image, True)[0]) 88 | 89 | out = ((image / scale)**gamma) * scale * gain 90 | return out.astype(dtype) 91 | 92 | 93 | def random_choose_unknown(unknown, crop_size): 94 | """Randomly choose an unknown start (top-left) point for a given crop_size. 95 | 96 | Args: 97 | unknown (np.ndarray): The binary unknown mask. 98 | crop_size (tuple[int]): The given crop size. 99 | 100 | Returns: 101 | tuple[int]: The top-left point of the chosen bbox. 102 | """ 103 | h, w = unknown.shape 104 | crop_h, crop_w = crop_size 105 | delta_h = center_h = crop_h // 2 106 | delta_w = center_w = crop_w // 2 107 | 108 | # mask out the validate area for selecting the cropping center 109 | mask = np.zeros_like(unknown) 110 | mask[delta_h:h - delta_h, delta_w:w - delta_w] = 1 111 | if np.any(unknown & mask): 112 | center_h_list, center_w_list = np.where(unknown & mask) 113 | elif np.any(unknown): 114 | center_h_list, center_w_list = np.where(unknown) 115 | else: 116 | print_log('No unknown pixels found!', level=logging.WARNING) 117 | center_h_list = [center_h] 118 | center_w_list = [center_w] 119 | num_unknowns = len(center_h_list) 120 | rand_ind = np.random.randint(num_unknowns) 121 | center_h = center_h_list[rand_ind] 122 | center_w = center_w_list[rand_ind] 123 | 124 | # make sure the top-left point is valid 125 | top = np.clip(center_h - delta_h, 0, h - crop_h) 126 | left = np.clip(center_w - delta_w, 0, w - crop_w) 127 | 128 | return top, left 129 | 130 | 131 | def make_coord(shape, ranges=None, flatten=True): 132 | """Make coordinates at grid centers. 133 | 134 | Args: 135 | shape (tuple): shape of image. 136 | ranges (tuple): range of coordinate value. Default: None. 137 | flatten (bool): flatten to (n, 2) or Not. Default: True. 138 | 139 | return: 140 | coord (Tensor): coordinates. 141 | """ 142 | coord_seqs = [] 143 | for i, n in enumerate(shape): 144 | if ranges is None: 145 | v0, v1 = -1, 1 146 | else: 147 | v0, v1 = ranges[i] 148 | r = (v1 - v0) / (2 * n) 149 | seq = v0 + r + (2 * r) * torch.arange(n).float() 150 | coord_seqs.append(seq) 151 | coord = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 152 | if flatten: 153 | coord = coord.view(-1, coord.shape[-1]) 154 | return coord 155 | -------------------------------------------------------------------------------- /mmedit/datasets/sr_folder_multiple_gt_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os 4 | import os.path as osp 5 | 6 | import mmcv 7 | 8 | from .base_sr_dataset import BaseSRDataset 9 | from .registry import DATASETS 10 | 11 | 12 | @DATASETS.register_module() 13 | class SRFolderMultipleGTDataset(BaseSRDataset): 14 | """General dataset for video super resolution, used for recurrent networks. 15 | 16 | The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth) 17 | frames. Then it applies specified transforms and finally returns a dict 18 | containing paired data and other information. 19 | 20 | This dataset takes an annotation file specifying the sequences used in 21 | training or test. 
If no annotation file is provided, it assumes that all video
22 |     sequences under the root directory are used for training or testing.
23 | 
24 |     In the annotation file (.txt), each line contains:
25 | 
26 |         1. folder name;
27 |         2. number of frames in this sequence (in the same folder)
28 | 
29 |     Examples:
30 | 
31 |     ::
32 | 
33 |         calendar 41
34 |         city 34
35 |         foliage 49
36 |         walk 47
37 | 
38 |     Args:
39 |         lq_folder (str | :obj:`Path`): Path to a lq folder.
40 |         gt_folder (str | :obj:`Path`): Path to a gt folder.
41 |         pipeline (list[dict | callable]): A sequence of data transformations.
42 |         scale (int): Upsampling scale ratio.
43 |         ann_file (str): The path to the annotation file. If None, we assume
44 |             that all sequences in the folder are used. Default: None
45 |         num_input_frames (None | int): The number of frames per iteration.
46 |             If None, the whole clip is extracted. If it is a positive integer,
47 |             a sequence of 'num_input_frames' frames is extracted from the clip.
48 |             Note that non-positive integers are not accepted. Default: None.
49 |         test_mode (bool): Store `True` when building test dataset.
50 |             Default: `True`.
51 |     """
52 | 
53 |     def __init__(self,
54 |                  lq_folder,
55 |                  gt_folder,
56 |                  mask_folder,
57 |                  drop_folder,
58 |                  streak_folder,
59 |                  pipeline,
60 |                  scale,
61 |                  ann_file=None,
62 |                  num_input_frames=None,
63 |                  test_mode=True):
64 |         super().__init__(pipeline, scale, test_mode)
65 | 
66 |         self.lq_folder = str(lq_folder)
67 |         self.gt_folder = str(gt_folder)
68 |         self.mask_folder = str(mask_folder)
69 |         self.drop_folder = str(drop_folder)
70 |         self.streak_folder = str(streak_folder)
71 | 
72 |         self.ann_file = ann_file
73 |         if num_input_frames is not None and num_input_frames <= 0:
74 |             raise ValueError('"num_input_frames" must be None or positive, '
75 |                              f'but got {num_input_frames}.')
76 |         self.num_input_frames = num_input_frames
77 | 
78 |         self.data_infos = self.load_annotations()
79 | 
80 |     def _load_annotations_from_file(self):
81 |         data_infos = []
82 | 
83 |         ann_list = mmcv.list_from_file(self.ann_file)
84 |         for ann in ann_list:
85 |             key, sequence_length = ann.strip().split(' ')
86 |             if self.num_input_frames is None:
87 |                 num_input_frames = sequence_length
88 |             else:
89 |                 num_input_frames = self.num_input_frames
90 |             data_infos.append(
91 |                 dict(
92 |                     lq_path=self.lq_folder,
93 |                     gt_path=self.gt_folder,
94 |                     mask_path=self.mask_folder,
95 |                     drop_path=self.drop_folder,
96 |                     streak_path=self.streak_folder,
97 |                     key=key,
98 |                     num_input_frames=int(num_input_frames),
99 |                     sequence_length=int(sequence_length)))
100 | 
101 |         return data_infos
102 | 
103 |     def load_annotations(self):
104 |         """Load annotations for the dataset.
105 | 
106 |         Returns:
107 |             list[dict]: Returned list of dicts for paired paths of LQ and GT. 
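
        Example of a single returned entry (the folder names below are
        placeholders, not paths shipped with this repo):

            dict(
                lq_path='data/test/lq',
                gt_path='data/test/gt',
                mask_path='data/test/mask',
                drop_path='data/test/drop',
                streak_path='data/test/streak',
                key='calendar',
                num_input_frames=41,
                sequence_length=41)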
108 | """ 109 | 110 | if self.ann_file: 111 | return self._load_annotations_from_file() 112 | 113 | sequences = sorted(glob.glob(osp.join(self.lq_folder, '*'))) 114 | data_infos = [] 115 | for sequence in sequences: 116 | sequence_length = len(glob.glob(osp.join(sequence, '*.png'))) 117 | if self.num_input_frames is None: 118 | num_input_frames = sequence_length 119 | else: 120 | num_input_frames = self.num_input_frames 121 | data_infos.append( 122 | dict( 123 | lq_path=self.lq_folder, 124 | gt_path=self.gt_folder, 125 | mask_path=self.mask_folder, 126 | drop_path=self.drop_folder, 127 | streak_path=self.streak_folder, 128 | key=sequence.replace(f'{self.lq_folder}{os.sep}', ''), 129 | num_input_frames=num_input_frames, 130 | sequence_length=sequence_length)) 131 | 132 | return data_infos 133 | -------------------------------------------------------------------------------- /mmedit/core/export/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | import warnings 4 | 5 | import numpy as np 6 | import onnxruntime as ort 7 | import torch 8 | from torch import nn 9 | 10 | from mmedit.models import BaseMattor, BasicRestorer, build_model 11 | 12 | 13 | def inference_with_session(sess, io_binding, output_names, input_tensor): 14 | device_type = input_tensor.device.type 15 | device_id = input_tensor.device.index 16 | device_id = 0 if device_id is None else device_id 17 | io_binding.bind_input( 18 | name='input', 19 | device_type=device_type, 20 | device_id=device_id, 21 | element_type=np.float32, 22 | shape=input_tensor.shape, 23 | buffer_ptr=input_tensor.data_ptr()) 24 | for name in output_names: 25 | io_binding.bind_output(name) 26 | sess.run_with_iobinding(io_binding) 27 | pred = io_binding.copy_outputs_to_cpu() 28 | return pred 29 | 30 | 31 | class ONNXRuntimeMattor(nn.Module): 32 | 33 | def __init__(self, sess, io_binding, output_names, base_model): 34 | super(ONNXRuntimeMattor, self).__init__() 35 | self.sess = sess 36 | self.io_binding = io_binding 37 | self.output_names = output_names 38 | self.base_model = base_model 39 | 40 | def forward(self, 41 | merged, 42 | trimap, 43 | meta, 44 | test_mode=False, 45 | save_image=False, 46 | save_path=None, 47 | iteration=None): 48 | input_tensor = torch.cat((merged, trimap), 1).contiguous() 49 | pred_alpha = inference_with_session(self.sess, self.io_binding, 50 | self.output_names, input_tensor)[0] 51 | 52 | pred_alpha = pred_alpha.squeeze() 53 | pred_alpha = self.base_model.restore_shape(pred_alpha, meta) 54 | eval_result = self.base_model.evaluate(pred_alpha, meta) 55 | 56 | if save_image: 57 | self.base_model.save_image(pred_alpha, meta, save_path, iteration) 58 | 59 | return {'pred_alpha': pred_alpha, 'eval_result': eval_result} 60 | 61 | 62 | class RestorerGenerator(nn.Module): 63 | 64 | def __init__(self, sess, io_binding, output_names): 65 | super(RestorerGenerator, self).__init__() 66 | self.sess = sess 67 | self.io_binding = io_binding 68 | self.output_names = output_names 69 | 70 | def forward(self, x): 71 | pred = inference_with_session(self.sess, self.io_binding, 72 | self.output_names, x)[0] 73 | pred = torch.from_numpy(pred) 74 | return pred 75 | 76 | 77 | class ONNXRuntimeRestorer(nn.Module): 78 | 79 | def __init__(self, sess, io_binding, output_names, base_model): 80 | super(ONNXRuntimeRestorer, self).__init__() 81 | self.sess = sess 82 | self.io_binding = io_binding 83 | self.output_names = output_names 84 | 
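        # Keep a handle to the original restorer, then replace its
        # generator with an onnxruntime-backed stand-in below, so the
        # usual pre/post-processing and evaluation logic in the base
        # model runs unchanged while the network forward pass is
        # delegated to the ONNX session.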
self.base_model = base_model
85 |         restorer_generator = RestorerGenerator(self.sess, self.io_binding,
86 |                                                self.output_names)
87 |         base_model.generator = restorer_generator
88 | 
89 |     def forward(self, lq, gt=None, test_mode=False, **kwargs):
90 |         return self.base_model(lq, gt=gt, test_mode=test_mode, **kwargs)
91 | 
92 | 
93 | class ONNXRuntimeEditing(nn.Module):
94 | 
95 |     def __init__(self, onnx_file, cfg, device_id):
96 |         super(ONNXRuntimeEditing, self).__init__()
97 |         ort_custom_op_path = ''
98 |         try:
99 |             from mmcv.ops import get_onnxruntime_op_path
100 |             ort_custom_op_path = get_onnxruntime_op_path()
101 |         except (ImportError, ModuleNotFoundError):
102 |             warnings.warn('If input model has custom op from mmcv, you may '
103 |                           'have to build mmcv with ONNXRuntime from source.')
104 |         session_options = ort.SessionOptions()
105 |         # register custom op for onnxruntime
106 |         if osp.exists(ort_custom_op_path):
107 |             session_options.register_custom_ops_library(ort_custom_op_path)
108 |         sess = ort.InferenceSession(onnx_file, session_options)
109 |         providers = ['CPUExecutionProvider']
110 |         options = [{}]
111 |         is_cuda_available = ort.get_device() == 'GPU'
112 |         if is_cuda_available:
113 |             providers.insert(0, 'CUDAExecutionProvider')
114 |             options.insert(0, {'device_id': device_id})
115 | 
116 |         sess.set_providers(providers, options)
117 | 
118 |         self.sess = sess
119 |         self.device_id = device_id
120 |         self.io_binding = sess.io_binding()
121 |         self.output_names = [_.name for _ in sess.get_outputs()]
122 | 
123 |         base_model = build_model(
124 |             cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
125 | 
126 |         if isinstance(base_model, BaseMattor):
127 |             WrapperClass = ONNXRuntimeMattor
128 |         elif isinstance(base_model, BasicRestorer):
129 |             WrapperClass = ONNXRuntimeRestorer
130 |         self.wrapper = WrapperClass(self.sess, self.io_binding,
131 |                                     self.output_names, base_model)
132 | 
133 |     def forward(self, **kwargs):
134 |         return self.wrapper(**kwargs)
135 | 
--------------------------------------------------------------------------------
/mmedit/core/evaluation/inceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | from scipy import linalg
5 | 
6 | from ..registry import METRICS
7 | from .inception_utils import load_inception
8 | 
9 | 
10 | class InceptionV3:
11 |     """Feature extractor using the InceptionV3 model.
12 | 
13 |     Args:
14 |         style (str): The model style to run the Inception model. It must
15 |             be either 'StyleGAN' or 'pytorch'.
16 |         device (torch.device): Device used to extract features.
17 |         inception_kwargs (**kwargs): kwargs for InceptionV3.
18 |     """
19 | 
20 |     def __init__(self, style='StyleGAN', device='cpu', **inception_kwargs):
21 |         self.inception = load_inception(
22 |             style=style, **inception_kwargs).eval().to(device)
23 |         self.style = style
24 |         self.device = device
25 | 
26 |     def __call__(self, img1, img2, crop_border=0):
27 |         """Extract features of real and fake images.
28 | 
29 |         Args:
30 |             img1, img2 (np.ndarray): Images with range [0, 255]
31 |                 and shape (H, W, C).
32 | 
33 |         Returns:
34 |             (tuple): Pair of features extracted from InceptionV3 model. 
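
        Example (a minimal sketch; assumes the pretrained Inception
        weights can be loaded, and uses a random image as a stand-in):

            >>> extractor = InceptionV3(style='StyleGAN', device='cpu')
            >>> img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
            >>> feat1, feat2 = extractor(img, img)
            >>> feat1.shape[0]  # one feature row per input image
            1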
35 |         """
36 |         return (
37 |             self.forward_inception(self.img2tensor(img1)).numpy(),
38 |             self.forward_inception(self.img2tensor(img2)).numpy(),
39 |         )
40 | 
41 |     def img2tensor(self, img):
42 |         img = np.expand_dims(img.transpose((2, 0, 1)), axis=0)
43 |         if self.style == 'StyleGAN':
44 |             return torch.tensor(img).to(device=self.device, dtype=torch.uint8)
45 | 
46 |         return torch.from_numpy(img / 255.).to(
47 |             device=self.device, dtype=torch.float32)
48 | 
49 |     def forward_inception(self, x):
50 |         if self.style == 'StyleGAN':
51 |             return self.inception(x).cpu()
52 | 
53 |         return self.inception(x)[-1].view(x.shape[0], -1).cpu()
54 | 
55 | 
56 | def frechet_distance(X, Y):
57 |     """Compute the Fréchet distance."""
58 | 
59 |     muX, covX = np.mean(X, axis=0), np.cov(X, rowvar=False)
60 |     muY, covY = np.mean(Y, axis=0), np.cov(Y, rowvar=False)
61 | 
62 |     cov_sqrt = linalg.sqrtm(covX.dot(covY))
63 |     frechet_distance = np.square(muX - muY).sum() + np.trace(covX) + np.trace(
64 |         covY) - 2 * np.trace(cov_sqrt)
65 |     return np.real(frechet_distance)
66 | 
67 | 
68 | @METRICS.register_module()
69 | class FID:
70 |     """FID metric."""
71 | 
72 |     def __call__(self, X, Y):
73 |         """Calculate FID.
74 | 
75 |         Args:
76 |             X (np.ndarray): Input feature X with shape (n_samples, dims).
77 |             Y (np.ndarray): Input feature Y with shape (n_samples, dims).
78 | 
79 |         Returns:
80 |             (float): fid value.
81 |         """
82 |         return frechet_distance(X, Y)
83 | 
84 | 
85 | def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef=1):
86 |     """Create a polynomial kernel."""
87 |     Y = X if Y is None else Y
88 |     if gamma is None:
89 |         gamma = 1.0 / X.shape[1]
90 |     K = ((X @ Y.T) * gamma + coef)**degree
91 |     return K
92 | 
93 | 
94 | def mmd2(X, Y, unbiased=True):
95 |     """Compute the (squared) Maximum Mean Discrepancy."""
96 |     XX = polynomial_kernel(X, X)
97 |     YY = polynomial_kernel(Y, Y)
98 |     XY = polynomial_kernel(X, Y)
99 | 
100 |     m = X.shape[0]
101 |     if not unbiased:
102 |         return (np.sum(XX) + np.sum(YY) - 2 * np.sum(XY)) / m**2
103 | 
104 |     trX = np.trace(XX)
105 |     trY = np.trace(YY)
106 |     return (np.sum(XX) - trX + np.sum(YY) -
107 |             trY) / (m * (m - 1)) - 2 * np.sum(XY) / m**2
108 | 
109 | 
110 | @METRICS.register_module()
111 | class KID:
112 |     """Implementation of KID (Kernel Inception Distance).
113 | 
114 |     Args:
115 |         num_repeats (int): The number of repetitions. Default: 100.
116 |         sample_size (int): Size to sample. Default: 1000.
117 |         use_unbiased_estimator (bool): Whether to use the unbiased MMD
118 |             estimator for KID. An unbiased estimator is desirable for
119 |             finite sample sizes, especially when the number of samples
120 |             is small, and is recommended in most cases.
121 |             Default: True
122 |     """
123 | 
124 |     def __init__(self,
125 |                  num_repeats=100,
126 |                  sample_size=1000,
127 |                  use_unbiased_estimator=True):
128 |         self.num_repeats = num_repeats
129 |         self.sample_size = sample_size
130 |         self.unbiased = use_unbiased_estimator
131 | 
132 |     def __call__(self, X, Y):
133 |         """Calculate KID.
134 | 
135 |         Args:
136 |             X (np.ndarray): Input feature X with shape (n_samples, dims).
137 |             Y (np.ndarray): Input feature Y with shape (n_samples, dims).
138 | 
139 |         Returns:
140 |             (dict): dict containing mean and std of KID values. 
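
        Example (a sketch with random features as stand-ins; in practice
        X and Y come from the InceptionV3 extractor above):

            >>> kid = KID(num_repeats=10, sample_size=100)
            >>> X = np.random.rand(500, 2048)
            >>> Y = np.random.rand(500, 2048)
            >>> result = kid(X, Y)
            >>> sorted(result.keys())
            ['KID_MEAN', 'KID_STD']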
141 | """ 142 | num_samples = X.shape[0] 143 | kid = list() 144 | for i in range(self.num_repeats): 145 | X_ = X[np.random.choice( 146 | num_samples, self.sample_size, replace=False)] 147 | Y_ = Y[np.random.choice( 148 | num_samples, self.sample_size, replace=False)] 149 | kid.append(mmd2(X_, Y_, unbiased=self.unbiased)) 150 | kid = np.array(kid) 151 | return dict(KID_MEAN=kid.mean(), KID_STD=kid.std()) 152 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import Config, DictAction 8 | from mmcv.parallel import MMDataParallel 9 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint 10 | 11 | from mmedit.apis import multi_gpu_test, set_random_seed, single_gpu_test 12 | from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper 13 | from mmedit.datasets import build_dataloader, build_dataset 14 | from mmedit.models import build_model 15 | from mmedit.utils import setup_multi_processes 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='mmediting tester') 20 | parser.add_argument('config', help='test config file path') 21 | parser.add_argument('checkpoint', help='checkpoint file') 22 | parser.add_argument('--seed', type=int, default=None, help='random seed') 23 | parser.add_argument( 24 | '--deterministic', 25 | action='store_true', 26 | help='whether to set deterministic options for CUDNN backend.') 27 | parser.add_argument('--out', help='output result pickle file') 28 | parser.add_argument( 29 | '--gpu-collect', 30 | action='store_true', 31 | help='whether to use gpu to collect results') 32 | parser.add_argument( 33 | '--save-path', 34 | default=None, 35 | type=str, 36 | help='path to store images and if not given, will not save image') 37 | parser.add_argument('--tmpdir', help='tmp dir for writing some results') 38 | parser.add_argument( 39 | '--cfg-options', 40 | nargs='+', 41 | action=DictAction, 42 | help='override some settings in the used config, the key-value pair ' 43 | 'in xxx=yyy format will be merged into config file. If the value to ' 44 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 45 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 46 | 'Note that the quotation marks are necessary and that no white space ' 47 | 'is allowed.') 48 | parser.add_argument( 49 | '--launcher', 50 | choices=['none', 'pytorch', 'slurm', 'mpi'], 51 | default='none', 52 | help='job launcher') 53 | parser.add_argument('--local_rank', type=int, default=0) 54 | args = parser.parse_args() 55 | if 'LOCAL_RANK' not in os.environ: 56 | os.environ['LOCAL_RANK'] = str(args.local_rank) 57 | return args 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | 63 | cfg = Config.fromfile(args.config) 64 | 65 | if args.cfg_options is not None: 66 | cfg.merge_from_dict(args.cfg_options) 67 | 68 | # set multi-process settings 69 | setup_multi_processes(cfg) 70 | 71 | # set cudnn_benchmark 72 | if cfg.get('cudnn_benchmark', False): 73 | torch.backends.cudnn.benchmark = True 74 | cfg.model.pretrained = None 75 | 76 | # init distributed env first, since logger depends on the dist info. 
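    # Typical mmediting-style invocations (config taken from this repo;
    # checkpoint and result paths are placeholders):
    #   single GPU:
    #     python tools/test.py configs/derainers/ViMPNet/ViMPNet.py \
    #         work_dirs/vimpnet/latest.pth --save-path results/
    #   multiple GPUs:
    #     bash tools/dist_test.sh configs/derainers/ViMPNet/ViMPNet.py \
    #         work_dirs/vimpnet/latest.pth 8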
77 |     if args.launcher == 'none':
78 |         distributed = False
79 |     else:
80 |         distributed = True
81 |         init_dist(args.launcher, **cfg.dist_params)
82 | 
83 |     rank, _ = get_dist_info()
84 | 
85 |     # set random seeds
86 |     if args.seed is not None:
87 |         if rank == 0:
88 |             print('set random seed to', args.seed)
89 |         set_random_seed(args.seed, deterministic=args.deterministic)
90 | 
91 |     # build the dataloader
92 |     # TODO: support multiple images per gpu (only minor changes are needed)
93 |     dataset = build_dataset(cfg.data.test)
94 | 
95 |     loader_cfg = {
96 |         **dict((k, cfg.data[k]) for k in ['workers_per_gpu'] if k in cfg.data),
97 |         **dict(
98 |             samples_per_gpu=1,
99 |             drop_last=False,
100 |             shuffle=False,
101 |             dist=distributed),
102 |         **cfg.data.get('test_dataloader', {})
103 |     }
104 | 
105 |     data_loader = build_dataloader(dataset, **loader_cfg)
106 | 
107 |     # build the model and load checkpoint
108 |     model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
109 | 
110 |     args.save_image = args.save_path is not None
111 |     empty_cache = cfg.get('empty_cache', False)
112 |     if not distributed:
113 |         _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
114 |         model = MMDataParallel(model, device_ids=[0])
115 |         outputs = single_gpu_test(
116 |             model,
117 |             data_loader,
118 |             save_path=args.save_path,
119 |             save_image=args.save_image)
120 |     else:
121 |         find_unused_parameters = cfg.get('find_unused_parameters', False)
122 |         model = DistributedDataParallelWrapper(
123 |             model,
124 |             device_ids=[torch.cuda.current_device()],
125 |             broadcast_buffers=False,
126 |             find_unused_parameters=find_unused_parameters)
127 | 
128 |         device_id = torch.cuda.current_device()
129 |         _ = load_checkpoint(
130 |             model,
131 |             args.checkpoint,
132 |             map_location=lambda storage, loc: storage.cuda(device_id))
133 |         outputs = multi_gpu_test(
134 |             model,
135 |             data_loader,
136 |             args.tmpdir,
137 |             args.gpu_collect,
138 |             save_path=args.save_path,
139 |             save_image=args.save_image,
140 |             empty_cache=empty_cache)
141 | 
142 |     if rank == 0 and 'eval_result' in outputs[0]:
143 |         print('')
144 |         # print metrics
145 |         stats = dataset.evaluate(outputs)
146 |         for stat in stats:
147 |             print('Eval-{}: {}'.format(stat, stats[stat]))
148 | 
149 |         # save result pickle
150 |         if args.out:
151 |             print('writing results to {}'.format(args.out))
152 |             mmcv.dump(outputs, args.out)
153 | 
154 | 
155 | if __name__ == '__main__':
156 |     main()
157 | 
--------------------------------------------------------------------------------
/mmedit/core/distributed_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
5 | from mmcv.parallel.scatter_gather import scatter_kwargs
6 | from torch.cuda._utils import _get_device_index
7 | 
8 | 
9 | @MODULE_WRAPPERS.register_module()
10 | class DistributedDataParallelWrapper(nn.Module):
11 |     """A DistributedDataParallel wrapper for models in MMediting.
12 | 
13 |     In MMediting, there is a need to wrap different modules in the models
14 |     with separate DistributedDataParallel. Otherwise, it will cause
15 |     errors for GAN training.
16 |     More specifically, a GAN model usually has two sub-modules:
17 |     generator and discriminator. 
If we wrap both of them in one
18 |     standard DistributedDataParallel, it will cause errors during training,
19 |     because when we update the parameters of the generator (or discriminator),
20 |     the parameters of the discriminator (or generator) are not updated, which
21 |     is not allowed for DistributedDataParallel.
22 |     So we design this wrapper to separately wrap DistributedDataParallel
23 |     for generator and discriminator.
24 | 
25 |     In this wrapper, we perform two operations:
26 |     1. Wrap the modules in the models with separate MMDistributedDataParallel.
27 |        Note that only modules with parameters will be wrapped.
28 |     2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
29 | 
30 |     Note that the arguments of this wrapper are the same as those in
31 |     `torch.nn.parallel.distributed.DistributedDataParallel`.
32 | 
33 |     Args:
34 |         module (nn.Module): Module that needs to be wrapped.
35 |         device_ids (list[int | `torch.device`]): Same as that in
36 |             `torch.nn.parallel.distributed.DistributedDataParallel`.
37 |         dim (int, optional): Same as that in the official scatter function in
38 |             pytorch. Defaults to 0.
39 |         broadcast_buffers (bool): Same as that in
40 |             `torch.nn.parallel.distributed.DistributedDataParallel`.
41 |             Defaults to False.
42 |         find_unused_parameters (bool, optional): Same as that in
43 |             `torch.nn.parallel.distributed.DistributedDataParallel`.
44 |             Traverse the autograd graph of all tensors contained in returned
45 |             value of the wrapped module’s forward function. Defaults to False.
46 |         kwargs (dict): Other arguments used in
47 |             `torch.nn.parallel.distributed.DistributedDataParallel`.
48 |     """
49 | 
50 |     def __init__(self,
51 |                  module,
52 |                  device_ids,
53 |                  dim=0,
54 |                  broadcast_buffers=False,
55 |                  find_unused_parameters=False,
56 |                  **kwargs):
57 |         super().__init__()
58 |         assert len(device_ids) == 1, (
59 |             'Currently, DistributedDataParallelWrapper only supports a '
60 |             'single CUDA device for each process. '
61 |             f'The length of device_ids must be 1, but got {len(device_ids)}.')
62 |         self.module = module
63 |         self.dim = dim
64 |         self.to_ddp(
65 |             device_ids=device_ids,
66 |             dim=dim,
67 |             broadcast_buffers=broadcast_buffers,
68 |             find_unused_parameters=find_unused_parameters,
69 |             **kwargs)
70 |         self.output_device = _get_device_index(device_ids[0], True)
71 | 
72 |     def to_ddp(self, device_ids, dim, broadcast_buffers,
73 |                find_unused_parameters, **kwargs):
74 |         """Wrap models with separate MMDistributedDataParallel.
75 | 
76 |         It only wraps the modules with parameters.
77 |         """
78 |         for name, module in self.module._modules.items():
79 |             if next(module.parameters(), None) is None:
80 |                 module = module.cuda()
81 |             elif all(not p.requires_grad for p in module.parameters()):
82 |                 module = module.cuda()
83 |             else:
84 |                 module = MMDistributedDataParallel(
85 |                     module.cuda(),
86 |                     device_ids=device_ids,
87 |                     dim=dim,
88 |                     broadcast_buffers=broadcast_buffers,
89 |                     find_unused_parameters=find_unused_parameters,
90 |                     **kwargs)
91 |             self.module._modules[name] = module
92 | 
93 |     def scatter(self, inputs, kwargs, device_ids):
94 |         """Scatter function.
95 | 
96 |         Args:
97 |             inputs (Tensor): Input Tensor.
98 |             kwargs (dict): Args for
99 |                 ``mmcv.parallel.scatter_gather.scatter_kwargs``.
100 |             device_ids (int): Device id.
101 |         """
102 |         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
103 | 
104 |     def forward(self, *inputs, **kwargs):
105 |         """Forward function.
106 | 
107 |         Args:
108 |             inputs (tuple): Input data. 
109 | kwargs (dict): Args for 110 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 111 | """ 112 | inputs, kwargs = self.scatter(inputs, kwargs, 113 | [torch.cuda.current_device()]) 114 | return self.module(*inputs[0], **kwargs[0]) 115 | 116 | def train_step(self, *inputs, **kwargs): 117 | """Train step function. 118 | 119 | Args: 120 | inputs (Tensor): Input Tensor. 121 | kwargs (dict): Args for 122 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 123 | """ 124 | inputs, kwargs = self.scatter(inputs, kwargs, 125 | [torch.cuda.current_device()]) 126 | output = self.module.train_step(*inputs[0], **kwargs[0]) 127 | return output 128 | 129 | def val_step(self, *inputs, **kwargs): 130 | """Validation step function. 131 | 132 | Args: 133 | inputs (tuple): Input data. 134 | kwargs (dict): Args for ``scatter_kwargs``. 135 | """ 136 | inputs, kwargs = self.scatter(inputs, kwargs, 137 | [torch.cuda.current_device()]) 138 | output = self.module.val_step(*inputs[0], **kwargs[0]) 139 | return output 140 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import time 7 | 8 | import mmcv 9 | import torch 10 | import torch.distributed as dist 11 | from mmcv import Config, DictAction 12 | from mmcv.runner import init_dist 13 | 14 | from mmedit import __version__ 15 | from mmedit.apis import init_random_seed, set_random_seed, train_model 16 | from mmedit.datasets import build_dataset 17 | from mmedit.models import build_model 18 | from mmedit.utils import collect_env, get_root_logger, setup_multi_processes 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train an editor') 23 | parser.add_argument('config', help='train config file path') 24 | parser.add_argument('--work-dir', help='the dir to save logs and models') 25 | parser.add_argument( 26 | '--resume-from', help='the checkpoint file to resume from') 27 | parser.add_argument( 28 | '--no-validate', 29 | action='store_true', 30 | help='whether not to evaluate the checkpoint during training') 31 | parser.add_argument( 32 | '--gpus', 33 | type=int, 34 | default=1, 35 | help='number of gpus to use ' 36 | '(only applicable to non-distributed training)') 37 | parser.add_argument('--seed', type=int, default=None, help='random seed') 38 | parser.add_argument( 39 | '--diff_seed', 40 | action='store_true', 41 | help='Whether or not set different seeds for different ranks') 42 | parser.add_argument( 43 | '--deterministic', 44 | action='store_true', 45 | help='whether to set deterministic options for CUDNN backend.') 46 | parser.add_argument( 47 | '--cfg-options', 48 | nargs='+', 49 | action=DictAction, 50 | help='override some settings in the used config, the key-value pair ' 51 | 'in xxx=yyy format will be merged into config file. If the value to ' 52 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 53 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 54 | 'Note that the quotation marks are necessary and that no white space ' 55 | 'is allowed.') 56 | parser.add_argument( 57 | '--launcher', 58 | choices=['none', 'pytorch', 'slurm', 'mpi'], 59 | default='none', 60 | help='job launcher') 61 | parser.add_argument('--local_rank', type=int, default=0) 62 | parser.add_argument( 63 | '--autoscale-lr', 64 | action='store_true', 65 | help='automatically scale lr with the number of gpus') 66 | args = parser.parse_args() 67 | if 'LOCAL_RANK' not in os.environ: 68 | os.environ['LOCAL_RANK'] = str(args.local_rank) 69 | 70 | return args 71 | 72 | 73 | def main(): 74 | args = parse_args() 75 | 76 | cfg = Config.fromfile(args.config) 77 | 78 | if args.cfg_options is not None: 79 | cfg.merge_from_dict(args.cfg_options) 80 | 81 | # set multi-process settings 82 | setup_multi_processes(cfg) 83 | 84 | # set cudnn_benchmark 85 | if cfg.get('cudnn_benchmark', False): 86 | torch.backends.cudnn.benchmark = True 87 | # update configs according to CLI args 88 | if args.work_dir is not None: 89 | cfg.work_dir = args.work_dir 90 | if args.resume_from is not None: 91 | cfg.resume_from = args.resume_from 92 | cfg.gpus = args.gpus 93 | 94 | if args.autoscale_lr: 95 | # apply the linear scaling rule (https://arxiv.org/abs/1706.02677) 96 | cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8 97 | 98 | # init distributed env first, since logger depends on the dist info. 99 | if args.launcher == 'none': 100 | distributed = False 101 | else: 102 | distributed = True 103 | init_dist(args.launcher, **cfg.dist_params) 104 | 105 | # create work_dir 106 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 107 | # dump config 108 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 109 | # init the logger before other steps 110 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 111 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 112 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 113 | 114 | # log env info 115 | env_info_dict = collect_env.collect_env() 116 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 117 | dash_line = '-' * 60 + '\n' 118 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 119 | dash_line) 120 | 121 | # log some basic info 122 | logger.info('Distributed training: {}'.format(distributed)) 123 | logger.info('mmedit Version: {}'.format(__version__)) 124 | logger.info('Config:\n{}'.format(cfg.text)) 125 | 126 | # set random seeds 127 | seed = init_random_seed(args.seed) 128 | seed = seed + dist.get_rank() if args.diff_seed else seed 129 | logger.info('Set random seed to {}, deterministic: {}'.format( 130 | seed, args.deterministic)) 131 | set_random_seed(seed, deterministic=args.deterministic) 132 | cfg.seed = seed 133 | 134 | model = build_model( 135 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 136 | 137 | datasets = [build_dataset(cfg.data.train)] 138 | if len(cfg.workflow) == 2: 139 | val_dataset = copy.deepcopy(cfg.data.val) 140 | val_dataset.pipeline = cfg.data.train.pipeline 141 | datasets.append(build_dataset(val_dataset)) 142 | if cfg.checkpoint_config is not None: 143 | # save version, config file content and class names in 144 | # checkpoints as meta data 145 | cfg.checkpoint_config.meta = dict( 146 | mmedit_version=__version__, 147 | config=cfg.text, 148 | ) 149 | 150 | # meta information 151 | meta = dict() 152 | if cfg.get('exp_name', None) is None: 153 | cfg['exp_name'] = osp.splitext(osp.basename(cfg.work_dir))[0] 154 | 
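    # Typical launch commands (config taken from this repo; other paths
    # are placeholders):
    #   python tools/train.py configs/derainers/ViMPNet/ViMPNet.py
    #   bash tools/dist_train.sh configs/derainers/ViMPNet/ViMPNet.py 8
    # With --autoscale-lr, the base lr is multiplied by gpus / 8, e.g. a
    # base lr of 1e-4 on 4 GPUs becomes 1e-4 * 4 / 8 = 5e-5.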
meta['exp_name'] = cfg.exp_name
155 |     meta['mmedit Version'] = __version__
156 |     meta['seed'] = seed
157 |     meta['env_info'] = env_info
158 | 
159 |     # start training
160 |     train_model(
161 |         model,
162 |         datasets,
163 |         cfg,
164 |         distributed=distributed,
165 |         validate=(not args.no_validate),
166 |         timestamp=timestamp,
167 |         meta=meta)
168 | 
169 | 
170 | 
171 | if __name__ == '__main__':
172 |     main()
173 | 
--------------------------------------------------------------------------------
/mmedit/datasets/pipelines/generate_assistant.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | 
5 | from ..registry import PIPELINES
6 | from .utils import make_coord
7 | 
8 | 
9 | @PIPELINES.register_module()
10 | class GenerateHeatmap:
11 |     """Generate heatmap from keypoint.
12 | 
13 |     Args:
14 |         keypoint (str): Key of keypoint in dict.
15 |         ori_size (int | Tuple[int]): Original image size of keypoint.
16 |         target_size (int | Tuple[int]): Target size of heatmap.
17 |         sigma (float): Sigma parameter of heatmap. Default: 1.0
18 |     """
19 | 
20 |     def __init__(self, keypoint, ori_size, target_size, sigma=1.0):
21 |         if isinstance(ori_size, int):
22 |             ori_size = (ori_size, ori_size)
23 |         else:
24 |             ori_size = ori_size[:2]
25 |         if isinstance(target_size, int):
26 |             target_size = (target_size, target_size)
27 |         else:
28 |             target_size = target_size[:2]
29 |         self.size_ratio = (target_size[0] / ori_size[0],
30 |                            target_size[1] / ori_size[1])
31 |         self.keypoint = keypoint
32 |         self.sigma = sigma
33 |         self.target_size = target_size
34 |         self.ori_size = ori_size
35 | 
36 |     def __call__(self, results):
37 |         """Call function.
38 | 
39 |         Args:
40 |             results (dict): A dict containing the necessary information and
41 |                 data for augmentation. Require keypoint.
42 | 
43 |         Returns:
44 |             dict: A dict containing the processed data and information.
45 |                 Add 'heatmap'.
46 |         """
47 |         keypoint_list = [(keypoint[0] * self.size_ratio[0],
48 |                           keypoint[1] * self.size_ratio[1])
49 |                          for keypoint in results[self.keypoint]]
50 |         heatmap_list = [
51 |             self._generate_one_heatmap(keypoint) for keypoint in keypoint_list
52 |         ]
53 |         results['heatmap'] = np.stack(heatmap_list, axis=2)
54 |         return results
55 | 
56 |     def _generate_one_heatmap(self, keypoint):
57 |         """Generate one heatmap.
58 | 
59 |         Args:
60 |             keypoint (Tuple[float]): Location of a keypoint.
61 | 
62 |         Returns:
63 |             np.ndarray: A heatmap of the keypoint.
64 |         """
65 |         w, h = self.target_size
66 | 
67 |         x_range = np.arange(start=0, stop=w, dtype=int)
68 |         y_range = np.arange(start=0, stop=h, dtype=int)
69 |         grid_x, grid_y = np.meshgrid(x_range, y_range)
70 |         dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2
71 |         exponent = dist2 / 2.0 / self.sigma / self.sigma
72 |         heatmap = np.exp(-exponent)
73 |         return heatmap
74 | 
75 |     def __repr__(self):
76 |         return (f'{self.__class__.__name__}, '
77 |                 f'keypoint={self.keypoint}, '
78 |                 f'ori_size={self.ori_size}, '
79 |                 f'target_size={self.target_size}, '
80 |                 f'sigma={self.sigma}')
81 | 
82 | 
83 | @PIPELINES.register_module()
84 | class GenerateCoordinateAndCell:
85 |     """Generate coordinate and cell.
86 | 
87 |     Generate coordinates from the desired size of the SR image.
88 |     Train or val:
89 |         1. Generate coordinate from GT.
90 |         2. Reshape GT image to (HgWg, 3) and transpose to (3, HgWg).
91 |            where `Hg` and `Wg` represent the height and width of GT. 
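    For example, ``make_coord((2, 2))`` (defined in
    mmedit/datasets/pipelines/utils.py) returns the four cell centers of
    a 2 x 2 grid over [-1, 1]: (-0.5, -0.5), (-0.5, 0.5), (0.5, -0.5)
    and (0.5, 0.5).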
92 | Test: 93 | Generate coordinate from LQ and scale or target_size. 94 | Then generate cell from coordinate. 95 | 96 | Args: 97 | sample_quantity (int): The quantity of samples in coordinates. 98 | To ensure that the GT tensors in a batch have the same dimensions. 99 | Default: None. 100 | scale (float): Scale of upsampling. 101 | Default: None. 102 | target_size (tuple[int]): Size of target image. 103 | Default: None. 104 | 105 | The priority of getting 'size of target image' is: 106 | 1, results['gt'].shape[-2:] 107 | 2, results['lq'].shape[-2:] * scale 108 | 3, target_size 109 | """ 110 | 111 | def __init__(self, sample_quantity=None, scale=None, target_size=None): 112 | self.sample_quantity = sample_quantity 113 | self.scale = scale 114 | self.target_size = target_size 115 | 116 | def __call__(self, results): 117 | """Call function. 118 | 119 | Args: 120 | results (dict): A dict containing the necessary information and 121 | data for augmentation. 122 | Require either in results: 123 | 1. 'lq' (tensor), whose shape is similar as (3, H, W). 124 | 2. 'gt' (tensor), whose shape is similar as (3, H, W). 125 | 3. None, the premise is 126 | self.target_size and len(self.target_size) >= 2. 127 | 128 | Returns: 129 | dict: A dict containing the processed data and information. 130 | Reshape 'gt' to (-1, 3) and transpose to (3, -1) if 'gt' 131 | in results. 132 | Add 'coord' and 'cell'. 133 | """ 134 | # generate hr_coord (and hr_rgb) 135 | if 'gt' in results: 136 | crop_hr = results['gt'] 137 | self.target_size = crop_hr.shape 138 | hr_rgb = crop_hr.contiguous().view(3, -1).permute(1, 0) 139 | results['gt'] = hr_rgb 140 | elif self.scale is not None and 'lq' in results: 141 | _, h_lr, w_lr = results['lq'].shape 142 | self.target_size = (round(h_lr * self.scale), 143 | round(w_lr * self.scale)) 144 | else: 145 | assert self.target_size is not None 146 | assert len(self.target_size) >= 2 147 | hr_coord = make_coord(self.target_size[-2:]) 148 | 149 | if self.sample_quantity is not None and 'gt' in results: 150 | sample_lst = np.random.choice( 151 | len(hr_coord), self.sample_quantity, replace=False) 152 | hr_coord = hr_coord[sample_lst] 153 | results['gt'] = results['gt'][sample_lst] 154 | 155 | # Preparations for cell decoding 156 | cell = torch.ones_like(hr_coord) 157 | cell[:, 0] *= 2 / self.target_size[-2] 158 | cell[:, 1] *= 2 / self.target_size[-1] 159 | 160 | results['coord'] = hr_coord 161 | results['cell'] = cell 162 | 163 | return results 164 | 165 | def __repr__(self): 166 | repr_str = self.__class__.__name__ 167 | repr_str += (f'sample_quantity={self.sample_quantity}, ' 168 | f'scale={self.scale}, target_size={self.target_size}') 169 | return repr_str 170 | -------------------------------------------------------------------------------- /mmedit/datasets/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
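# A minimal usage sketch of the builders defined below. The dataset
# config here is hypothetical (real configs live under configs/):
#
#   dataset = build_dataset(
#       dict(type='RepeatDataset', times=100,
#            dataset=dict(type='SRFolderGTDataset', ...)))
#   data_loader = build_dataloader(
#       dataset, samples_per_gpu=1, workers_per_gpu=4,
#       dist=False, shuffle=True)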
2 | import copy 3 | import platform 4 | import random 5 | from functools import partial 6 | 7 | import numpy as np 8 | import torch 9 | from mmcv.parallel import collate 10 | from mmcv.runner import get_dist_info 11 | from mmcv.utils import build_from_cfg 12 | from packaging import version 13 | from torch.utils.data import ConcatDataset, DataLoader 14 | 15 | from .dataset_wrappers import RepeatDataset 16 | from .registry import DATASETS 17 | from .samplers import DistributedSampler 18 | 19 | if platform.system() != 'Windows': 20 | # https://github.com/pytorch/pytorch/issues/973 21 | import resource 22 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 23 | base_soft_limit = rlimit[0] 24 | hard_limit = rlimit[1] 25 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 26 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 27 | 28 | 29 | def _concat_dataset(cfg, default_args=None): 30 | """Concat datasets with different ann_file but the same type. 31 | 32 | Args: 33 | cfg (dict): The config of dataset. 34 | default_args (dict, optional): Default initialization arguments. 35 | Default: None. 36 | 37 | Returns: 38 | Dataset: The concatenated dataset. 39 | """ 40 | ann_files = cfg['ann_file'] 41 | 42 | datasets = [] 43 | num_dset = len(ann_files) 44 | for i in range(num_dset): 45 | data_cfg = copy.deepcopy(cfg) 46 | data_cfg['ann_file'] = ann_files[i] 47 | datasets.append(build_dataset(data_cfg, default_args)) 48 | 49 | return ConcatDataset(datasets) 50 | 51 | 52 | def build_dataset(cfg, default_args=None): 53 | """Build a dataset from config dict. 54 | 55 | It supports a variety of dataset config. If ``cfg`` is a Sequential (list 56 | or dict), it will be a concatenated dataset of the datasets specified by 57 | the Sequential. If it is a ``RepeatDataset``, then it will repeat the 58 | dataset ``cfg['dataset']`` for ``cfg['times']`` times. If the ``ann_file`` 59 | of the dataset is a Sequential, then it will build a concatenated dataset 60 | with the same dataset type but different ``ann_file``. 61 | 62 | Args: 63 | cfg (dict): Config dict. It should at least contain the key "type". 64 | default_args (dict, optional): Default initialization arguments. 65 | Default: None. 66 | 67 | Returns: 68 | Dataset: The constructed dataset. 69 | """ 70 | if isinstance(cfg, (list, tuple)): 71 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 72 | elif cfg['type'] == 'RepeatDataset': 73 | dataset = RepeatDataset( 74 | build_dataset(cfg['dataset'], default_args), cfg['times']) 75 | elif isinstance(cfg.get('ann_file'), (list, tuple)): 76 | dataset = _concat_dataset(cfg, default_args) 77 | else: 78 | dataset = build_from_cfg(cfg, DATASETS, default_args) 79 | 80 | return dataset 81 | 82 | 83 | def build_dataloader(dataset, 84 | samples_per_gpu, 85 | workers_per_gpu, 86 | num_gpus=1, 87 | dist=True, 88 | shuffle=True, 89 | seed=None, 90 | drop_last=False, 91 | pin_memory=True, 92 | persistent_workers=True, 93 | **kwargs): 94 | """Build PyTorch DataLoader. 95 | 96 | In distributed training, each GPU/process has a dataloader. 97 | In non-distributed training, there is only one dataloader for all GPUs. 98 | 99 | Args: 100 | dataset (:obj:`Dataset`): A PyTorch dataset. 101 | samples_per_gpu (int): Number of samples on each GPU, i.e., 102 | batch size of each GPU. 103 | workers_per_gpu (int): How many subprocesses to use for data 104 | loading for each GPU. 105 | num_gpus (int): Number of GPUs. Only used in non-distributed 106 | training. Default: 1. 
107 | dist (bool): Distributed training/test or not. Default: True. 108 | shuffle (bool): Whether to shuffle the data at every epoch. 109 | Default: True. 110 | seed (int | None): Seed to be used. Default: None. 111 | drop_last (bool): Whether to drop the last incomplete batch in epoch. 112 | Default: False 113 | pin_memory (bool): Whether to use pin_memory in DataLoader. 114 | Default: True 115 | persistent_workers (bool): If True, the data loader will not shutdown 116 | the worker processes after a dataset has been consumed once. 117 | This allows to maintain the workers Dataset instances alive. 118 | The argument also has effect in PyTorch>=1.7.0. 119 | Default: True 120 | kwargs (dict, optional): Any keyword argument to be used to initialize 121 | DataLoader. 122 | 123 | Returns: 124 | DataLoader: A PyTorch dataloader. 125 | """ 126 | rank, world_size = get_dist_info() 127 | if dist: 128 | sampler = DistributedSampler( 129 | dataset, 130 | world_size, 131 | rank, 132 | shuffle=shuffle, 133 | samples_per_gpu=samples_per_gpu, 134 | seed=seed) 135 | shuffle = False 136 | batch_size = samples_per_gpu 137 | num_workers = workers_per_gpu 138 | else: 139 | sampler = None 140 | batch_size = num_gpus * samples_per_gpu 141 | num_workers = num_gpus * workers_per_gpu 142 | 143 | init_fn = partial( 144 | worker_init_fn, num_workers=num_workers, rank=rank, 145 | seed=seed) if seed is not None else None 146 | 147 | if version.parse(torch.__version__) >= version.parse('1.7.0'): 148 | kwargs['persistent_workers'] = persistent_workers 149 | 150 | data_loader = DataLoader( 151 | dataset, 152 | batch_size=batch_size, 153 | sampler=sampler, 154 | num_workers=num_workers, 155 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 156 | pin_memory=pin_memory, 157 | shuffle=shuffle, 158 | worker_init_fn=init_fn, 159 | drop_last=drop_last, 160 | **kwargs) 161 | 162 | return data_loader 163 | 164 | 165 | def worker_init_fn(worker_id, num_workers, rank, seed): 166 | """Function to initialize each worker. 167 | 168 | The seed of each worker equals to 169 | ``num_worker * rank + worker_id + user_seed``. 170 | 171 | Args: 172 | worker_id (int): Id for each worker. 173 | num_workers (int): Number of workers. 174 | rank (int): Rank in distributed training. 175 | seed (int): Random seed. 
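
    For example, with ``num_workers=4``, ``rank=1`` and ``seed=10``,
    worker 2 is seeded with ``4 * 1 + 2 + 10 = 16``, so every worker of
    every rank gets a distinct but reproducible seed.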
176 |     """
177 | 
178 |     worker_seed = num_workers * rank + worker_id + seed
179 |     np.random.seed(worker_seed)
180 |     random.seed(worker_seed)
181 |     torch.manual_seed(worker_seed)
182 | 
--------------------------------------------------------------------------------
/mmedit/models/derainers/derainer.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import os.path as osp
3 | import torch
4 | import mmcv
5 | from mmcv.runner import auto_fp16
6 | import torch.nn as nn
7 | from mmedit.core import psnr, ssim, tensor2img
8 | from ..base import BaseModel
9 | from ..builder import build_backbone, build_loss
10 | from ..registry import MODELS
11 | 
12 | 
13 | 
14 | @MODELS.register_module()
15 | class Derainer(BaseModel):
16 | 
17 |     allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}
18 | 
19 |     def __init__(self,
20 |                  generator,
21 |                  pixel_loss,
22 | 
23 |                  train_cfg=None,
24 |                  test_cfg=None,
25 |                  pretrained=None):
26 |         super().__init__()
27 | 
28 |         self.train_cfg = train_cfg
29 |         self.test_cfg = test_cfg
30 | 
31 |         self.fp16_enabled = False
32 | 
33 |         # generator
34 |         self.generator = build_backbone(generator)
35 | 
36 |         self.init_weights(pretrained)
37 | 
38 |         # loss
39 |         self.pixel_loss = build_loss(pixel_loss)
40 |         self.mask_loss = nn.BCELoss()
41 |         self.streak_loss = nn.MSELoss()
42 |         self.l1_loss = nn.L1Loss()
43 | 
44 | 
45 |     def init_weights(self, pretrained=None):
46 |         """Init weights for models.
47 | 
48 |         Args:
49 |             pretrained (str, optional): Path for pretrained weights. If given
50 |                 None, pretrained weights will not be loaded. Defaults to None.
51 |         """
52 |         self.generator.init_weights(pretrained)
53 | 
54 |     @auto_fp16(apply_to=('lq', ))
55 |     def forward(self, lq, gt=None, mask=None, drop=None, streak=None, test_mode=False, **kwargs):
56 |         """Forward function.
57 | 
58 |         Args:
59 |             lq (Tensor): Input lq images.
60 |             gt (Tensor): Ground-truth image. Default: None.
61 |             test_mode (bool): Whether in test mode or not. Default: False.
62 |             kwargs (dict): Other arguments.
63 |         """
64 | 
65 |         if test_mode:
66 |             return self.forward_test(lq, gt, **kwargs)
67 | 
68 |         return self.forward_train(lq, gt, mask, drop, streak)
69 | 
70 |     def forward_train(self, lq, gt, mask, drop, streak):
71 |         """Training forward function.
72 | 
73 |         Args:
74 |             lq (Tensor): LQ Tensor with shape (n, c, h, w).
75 |             gt (Tensor): GT Tensor with shape (n, c, h, w).
76 | 
77 |         Returns:
78 |             dict: Dict containing the losses and intermediate results.
79 |         """
80 |         losses = dict()
81 | 
82 |         output, outputmask, streakmask, finalresult = self.generator(lq)
83 | 
84 |         loss_pix = self.pixel_loss(output, drop)
85 |         loss_mask = 0.8 * self.mask_loss(outputmask, mask)
86 |         loss_pix2 = 1.2 * self.pixel_loss(finalresult, gt)
87 |         hole_loss = self.l1_loss(finalresult * mask, gt * mask)
88 |         hole_loss = 0.5 * hole_loss / (torch.mean(mask) + 1e-8)
89 |         loss_streak = 0.5 * self.streak_loss(streakmask, streak)
90 | 
91 | 
92 |         losses['loss_pix'] = loss_pix
93 |         losses['loss_mask'] = loss_mask
94 |         losses['loss_reconstruct'] = loss_pix2
95 |         losses['loss_hole'] = hole_loss
96 |         losses['loss_streak'] = loss_streak
97 | 
98 | 
99 |         outputs = dict(
100 |             losses=losses,
101 |             num_samples=len(gt.data),
102 |             results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
103 |         return outputs
104 | 
105 |     def evaluate(self, output, gt):
106 |         """Evaluation function.
107 | 
108 |         Args:
109 |             output (Tensor): Model output with shape (n, c, h, w).
110 |             gt (Tensor): GT Tensor with shape (n, c, h, w).
111 | 
112 |         Returns:
113 |             dict: Evaluation results. 
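
        Example of a returned dict when ``test_cfg.metrics`` is
        ``['PSNR', 'SSIM']`` (values are illustrative only):

            {'PSNR': 31.27, 'SSIM': 0.918}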
114 | """ 115 | crop_border = self.test_cfg.crop_border 116 | 117 | output = tensor2img(output) 118 | gt = tensor2img(gt) 119 | 120 | eval_result = dict() 121 | for metric in self.test_cfg.metrics: 122 | eval_result[metric] = self.allowed_metrics[metric](output, gt, 123 | crop_border) 124 | return eval_result 125 | 126 | def forward_test(self, 127 | lq, 128 | gt=None, 129 | meta=None, 130 | save_image=False, 131 | save_path=None, 132 | iteration=None): 133 | """Testing forward function. 134 | 135 | Args: 136 | lq (Tensor): LQ Tensor with shape (n, c, h, w). 137 | gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None. 138 | save_image (bool): Whether to save image. Default: False. 139 | save_path (str): Path to save image. Default: None. 140 | iteration (int): Iteration for the saving image name. 141 | Default: None. 142 | 143 | Returns: 144 | dict: Output results. 145 | """ 146 | output = self.generator(lq) 147 | if self.test_cfg is not None and self.test_cfg.get('metrics', None): 148 | assert gt is not None, ( 149 | 'evaluation with metrics must have gt images.') 150 | results = dict(eval_result=self.evaluate(output, gt)) 151 | else: 152 | results = dict(lq=lq.cpu(), output=output.cpu()) 153 | if gt is not None: 154 | results['gt'] = gt.cpu() 155 | 156 | # save image 157 | if save_image: 158 | lq_path = meta[0]['lq_path'] 159 | folder_name = osp.splitext(osp.basename(lq_path))[0] 160 | if isinstance(iteration, numbers.Number): 161 | save_path = osp.join(save_path, folder_name, 162 | f'{folder_name}-{iteration + 1:06d}.png') 163 | elif iteration is None: 164 | save_path = osp.join(save_path, f'{folder_name}.png') 165 | else: 166 | raise ValueError('iteration should be number or None, ' 167 | f'but got {type(iteration)}') 168 | mmcv.imwrite(tensor2img(output), save_path) 169 | 170 | return results 171 | 172 | def forward_dummy(self, img): 173 | """Used for computing network FLOPs. 174 | 175 | Args: 176 | img (Tensor): Input image. 177 | 178 | Returns: 179 | Tensor: Output image. 180 | """ 181 | out = self.generator(img) 182 | return out 183 | 184 | def train_step(self, data_batch, optimizer): 185 | """Train step. 186 | 187 | Args: 188 | data_batch (dict): A batch of data. 189 | optimizer (obj): Optimizer. 190 | 191 | Returns: 192 | dict: Returned output. 193 | """ 194 | outputs = self(**data_batch, test_mode=False) 195 | loss, log_vars = self.parse_losses(outputs.pop('losses')) 196 | 197 | # optimize 198 | optimizer['generator'].zero_grad() 199 | loss.backward() 200 | optimizer['generator'].step() 201 | 202 | outputs.update({'log_vars': log_vars}) 203 | return outputs 204 | 205 | def val_step(self, data_batch, **kwargs): 206 | """Validation step. 207 | 208 | Args: 209 | data_batch (dict): A batch of data. 210 | kwargs (dict): Other arguments for ``val_step``. 211 | 212 | Returns: 213 | dict: Returned output. 214 | """ 215 | output = self.forward_test(**data_batch, **kwargs) 216 | return output 217 | 218 | -------------------------------------------------------------------------------- /mmedit/models/losses/pixelwise_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..registry import LOSSES 7 | from .utils import masked_loss 8 | 9 | _reduction_modes = ['none', 'mean', 'sum'] 10 | 11 | 12 | @masked_loss 13 | def l1_loss(pred, target): 14 | """L1 loss. 
15 | 
16 |     Args:
17 |         pred (Tensor): Prediction Tensor with shape (n, c, h, w).
18 |         target (Tensor): Target Tensor with shape (n, c, h, w).
19 | 
20 |     Returns:
21 |         Tensor: Calculated L1 loss.
22 |     """
23 |     return F.l1_loss(pred, target, reduction='none')
24 | 
25 | 
26 | @masked_loss
27 | def mse_loss(pred, target):
28 |     """MSE loss.
29 | 
30 |     Args:
31 |         pred (Tensor): Prediction Tensor with shape (n, c, h, w).
32 |         target (Tensor): Target Tensor with shape (n, c, h, w).
33 | 
34 |     Returns:
35 |         Tensor: Calculated MSE loss.
36 |     """
37 |     return F.mse_loss(pred, target, reduction='none')
38 | 
39 | 
40 | @masked_loss
41 | def charbonnier_loss(pred, target, eps=1e-12):
42 |     """Charbonnier loss.
43 | 
44 |     Args:
45 |         pred (Tensor): Prediction Tensor with shape (n, c, h, w).
46 |         target (Tensor): Target Tensor with shape (n, c, h, w).
47 | 
48 |     Returns:
49 |         Tensor: Calculated Charbonnier loss.
50 |     """
51 |     return torch.sqrt((pred - target)**2 + eps)
52 | 
53 | 
54 | @LOSSES.register_module()
55 | class L1Loss(nn.Module):
56 |     """L1 (mean absolute error, MAE) loss.
57 | 
58 |     Args:
59 |         loss_weight (float): Loss weight for L1 loss. Default: 1.0.
60 |         reduction (str): Specifies the reduction to apply to the output.
61 |             Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
62 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
63 |             argument only takes effect when `reduction` is 'mean' and `weight`
64 |             (argument of `forward()`) is not None. If True, the loss is first
65 |             reduced with 'mean' per sample and then averaged over all samples.
66 |             Default: False.
67 |     """
68 | 
69 |     def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
70 |         super().__init__()
71 |         if reduction not in ['none', 'mean', 'sum']:
72 |             raise ValueError(f'Unsupported reduction mode: {reduction}. '
73 |                              f'Supported ones are: {_reduction_modes}')
74 | 
75 |         self.loss_weight = loss_weight
76 |         self.reduction = reduction
77 |         self.sample_wise = sample_wise
78 | 
79 |     def forward(self, pred, target, weight=None, **kwargs):
80 |         """Forward Function.
81 | 
82 |         Args:
83 |             pred (Tensor): of shape (N, C, H, W). Predicted tensor.
84 |             target (Tensor): of shape (N, C, H, W). Ground truth tensor.
85 |             weight (Tensor, optional): of shape (N, C, H, W). Element-wise
86 |                 weights. Default: None.
87 |         """
88 |         return self.loss_weight * l1_loss(
89 |             pred,
90 |             target,
91 |             weight,
92 |             reduction=self.reduction,
93 |             sample_wise=self.sample_wise)
94 | 
95 | 
96 | @LOSSES.register_module()
97 | class MSELoss(nn.Module):
98 |     """MSE (L2) loss.
99 | 
100 |     Args:
101 |         loss_weight (float): Loss weight for MSE loss. Default: 1.0.
102 |         reduction (str): Specifies the reduction to apply to the output.
103 |             Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
104 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
105 |             argument only takes effect when `reduction` is 'mean' and `weight`
106 |             (argument of `forward()`) is not None. If True, the loss is first
107 |             reduced with 'mean' per sample and then averaged over all samples.
108 |             Default: False.
109 |     """
110 | 
111 |     def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
112 |         super().__init__()
113 |         if reduction not in ['none', 'mean', 'sum']:
114 |             raise ValueError(f'Unsupported reduction mode: {reduction}. 
' 115 | f'Supported ones are: {_reduction_modes}') 116 | 117 | self.loss_weight = loss_weight 118 | self.reduction = reduction 119 | self.sample_wise = sample_wise 120 | 121 | def forward(self, pred, target, weight=None, **kwargs): 122 | """Forward Function. 123 | 124 | Args: 125 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 126 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 127 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 128 | weights. Default: None. 129 | """ 130 | return self.loss_weight * mse_loss( 131 | pred, 132 | target, 133 | weight, 134 | reduction=self.reduction, 135 | sample_wise=self.sample_wise) 136 | 137 | 138 | @LOSSES.register_module() 139 | class CharbonnierLoss(nn.Module): 140 | """Charbonnier loss (one variant of robust L1 loss, a differentiable 141 | variant of L1 loss). 142 | 143 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate 144 | Super-Resolution". 145 | 146 | Args: 147 | loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0. 148 | reduction (str): Specifies the reduction to apply to the output. 149 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 150 | sample_wise (bool): Whether to calculate the loss sample-wise. This 151 | argument only takes effect when `reduction` is 'mean' and `weight` 152 | (argument of `forward()`) is not None. The loss is first reduced 153 | with 'mean' per sample and then averaged over all samples. 154 | Default: False. 155 | eps (float): A value used to control the curvature near zero. 156 | Default: 1e-12. 157 | """ 158 | 159 | def __init__(self, 160 | loss_weight=1.0, 161 | reduction='mean', 162 | sample_wise=False, 163 | eps=1e-12): 164 | super().__init__() 165 | if reduction not in ['none', 'mean', 'sum']: 166 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 167 | f'Supported ones are: {_reduction_modes}') 168 | 169 | self.loss_weight = loss_weight 170 | self.reduction = reduction 171 | self.sample_wise = sample_wise 172 | self.eps = eps 173 | 174 | def forward(self, pred, target, weight=None, **kwargs): 175 | """Forward Function. 176 | 177 | Args: 178 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 179 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 180 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 181 | weights. Default: None. 182 | """ 183 | return self.loss_weight * charbonnier_loss( 184 | pred, 185 | target, 186 | weight, 187 | eps=self.eps, 188 | reduction=self.reduction, 189 | sample_wise=self.sample_wise) 190 | 191 | 192 | @LOSSES.register_module() 193 | class MaskedTVLoss(L1Loss): 194 | """Masked TV loss. 195 | 196 | Args: 197 | loss_weight (float, optional): Loss weight. Defaults to 1.0. 198 | """ 199 | 200 | def __init__(self, loss_weight=1.0): 201 | super().__init__(loss_weight=loss_weight) 202 | 203 | def forward(self, pred, mask=None): 204 | """Forward function. 205 | 206 | Args: 207 | pred (torch.Tensor): Tensor with shape of (n, c, h, w). 208 | mask (torch.Tensor, optional): Tensor with shape of (n, 1, h, w). 209 | Defaults to None. 
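        Example (illustrative; an all-ones mask makes this an unmasked
        total-variation penalty)::

            >>> tv_loss = MaskedTVLoss(loss_weight=1.0)
            >>> pred = torch.rand(1, 3, 16, 16)
            >>> mask = torch.ones(1, 1, 16, 16)
            >>> loss = tv_loss(pred, mask)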
210 | 211 | Returns: 212 | torch.Tensor: Calculated masked TV loss. 213 | """ 214 | # treat `mask` as an optional element-wise weight so that the 215 | # documented `mask=None` default actually works 216 | y_weight = None if mask is None else mask[:, :, :-1, :] 217 | x_weight = None if mask is None else mask[:, :, :, :-1] 218 | y_diff = super().forward( 219 | pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) 220 | x_diff = super().forward( 221 | pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) 222 | 223 | loss = x_diff + y_diff 224 | 225 | return loss 226 | -------------------------------------------------------------------------------- /mmedit/models/backbones/derain_backbones/derain_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import ConvModule 6 | from mmcv.runner import load_checkpoint 7 | 8 | from mmedit.models.common import (ResidualBlockNoBN, 9 | flow_warp, make_layer) 10 | from mmedit.models.registry import BACKBONES 11 | from mmedit.utils import get_root_logger 12 | 13 | 14 | 15 | 16 | class ResidualBlocksWithInputConv(nn.Module): 17 | """Residual blocks with a convolution in front. 18 | 19 | Args: 20 | in_channels (int): Number of input channels of the first conv. 21 | out_channels (int): Number of channels of the residual blocks. 22 | Default: 64. 23 | num_blocks (int): Number of residual blocks. Default: 30. 24 | """ 25 | 26 | def __init__(self, in_channels, out_channels=64, num_blocks=30): 27 | super().__init__() 28 | 29 | main = [] 30 | 31 | # a convolution used to match the channels of the residual blocks 32 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True)) 33 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) 34 | 35 | # residual blocks 36 | main.append( 37 | make_layer( 38 | ResidualBlockNoBN, num_blocks, mid_channels=out_channels)) 39 | 40 | self.main = nn.Sequential(*main) 41 | 42 | def forward(self, feat): 43 | """Forward function for ResidualBlocksWithInputConv. 44 | 45 | Args: 46 | feat (Tensor): Input feature with shape (n, in_channels, h, w) 47 | 48 | Returns: 49 | Tensor: Output feature with shape (n, out_channels, h, w) 50 | """ 51 | return self.main(feat) 52 | 53 | 54 | class SPyNet(nn.Module): 55 | """SPyNet network structure. 56 | 57 | The differences from the SPyNet in [tof.py] are that 58 | 1. more SPyNetBasicModules are used in this version, and 59 | 2. no batch normalization is used in this version. 60 | 61 | Paper: 62 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 63 | 64 | Args: 65 | pretrained (str): path for pre-trained SPyNet. Default: None. 66 | """ 67 | 68 | def __init__(self, pretrained): 69 | super().__init__() 70 | 71 | self.basic_module = nn.ModuleList( 72 | [SPyNetBasicModule() for _ in range(6)]) 73 | 74 | if isinstance(pretrained, str): 75 | logger = get_root_logger() 76 | load_checkpoint(self, pretrained, strict=True, logger=logger) 77 | elif pretrained is not None: 78 | raise TypeError('[pretrained] should be str or None, ' 79 | f'but got {type(pretrained)}.') 80 | 81 | self.register_buffer( 82 | 'mean', 83 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 84 | self.register_buffer( 85 | 'std', 86 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 87 | 88 | def compute_flow(self, ref, supp): 89 | """Compute flow from ref to supp. 90 | 91 | Note that in this function, the images are already resized to a 92 | multiple of 32. 93 | 94 | Args: 95 | ref (Tensor): Reference image with shape of (n, 3, h, w). 96 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 
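        Example (illustrative; ``compute_flow`` expects sizes that are
        already a multiple of 32, which ``forward`` below guarantees by
        interpolation)::

            >>> spynet = SPyNet(pretrained=None)
            >>> ref = torch.rand(1, 3, 64, 64)
            >>> supp = torch.rand(1, 3, 64, 64)
            >>> flow = spynet.compute_flow(ref, supp)  # shape (1, 2, 64, 64)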
97 | 98 | Returns: 99 | Tensor: Estimated optical flow: (n, 2, h, w). 100 | """ 101 | n, _, h, w = ref.size() 102 | 103 | # normalize the input images 104 | ref = [(ref - self.mean) / self.std] 105 | supp = [(supp - self.mean) / self.std] 106 | 107 | # generate downsampled frames 108 | for level in range(5): 109 | ref.append( 110 | F.avg_pool2d( 111 | input=ref[-1], 112 | kernel_size=2, 113 | stride=2, 114 | count_include_pad=False)) 115 | supp.append( 116 | F.avg_pool2d( 117 | input=supp[-1], 118 | kernel_size=2, 119 | stride=2, 120 | count_include_pad=False)) 121 | ref = ref[::-1] 122 | supp = supp[::-1] 123 | 124 | # flow computation 125 | flow = ref[0].new_zeros(n, 2, h // 32, w // 32) 126 | for level in range(len(ref)): 127 | if level == 0: 128 | flow_up = flow 129 | else: 130 | flow_up = F.interpolate( 131 | input=flow, 132 | scale_factor=2, 133 | mode='bilinear', 134 | align_corners=True) * 2.0 135 | 136 | # add the residue to the upsampled flow 137 | flow = flow_up + self.basic_module[level]( 138 | torch.cat([ 139 | ref[level], 140 | flow_warp( 141 | supp[level], 142 | flow_up.permute(0, 2, 3, 1), 143 | padding_mode='border'), flow_up 144 | ], 1)) 145 | 146 | return flow 147 | 148 | def forward(self, ref, supp): 149 | """Forward function of SPyNet. 150 | 151 | This function computes the optical flow from ref to supp. 152 | 153 | Args: 154 | ref (Tensor): Reference image with shape of (n, 3, h, w). 155 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 156 | 157 | Returns: 158 | Tensor: Estimated optical flow: (n, 2, h, w). 159 | """ 160 | 161 | # upsize to a multiple of 32 162 | h, w = ref.shape[2:4] 163 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) 164 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) 165 | ref = F.interpolate( 166 | input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False) 167 | supp = F.interpolate( 168 | input=supp, 169 | size=(h_up, w_up), 170 | mode='bilinear', 171 | align_corners=False) 172 | 173 | # compute flow, and resize back to the original resolution 174 | flow = F.interpolate( 175 | input=self.compute_flow(ref, supp), 176 | size=(h, w), 177 | mode='bilinear', 178 | align_corners=False) 179 | 180 | # adjust the flow values 181 | flow[:, 0, :, :] *= float(w) / float(w_up) 182 | flow[:, 1, :, :] *= float(h) / float(h_up) 183 | 184 | return flow 185 | 186 | 187 | class SPyNetBasicModule(nn.Module): 188 | """Basic Module for SPyNet. 
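    Example (illustrative)::

        >>> module = SPyNetBasicModule()
        >>> # 8 input channels: reference image (3), warped supporting
        >>> # image (3) and the current flow estimate (2)
        >>> x = torch.rand(1, 8, 64, 64)
        >>> residual_flow = module(x)  # shape (1, 2, 64, 64)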
189 | 190 | Paper: 191 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 192 | """ 193 | 194 | def __init__(self): 195 | super().__init__() 196 | 197 | self.basic_module = nn.Sequential( 198 | ConvModule( 199 | in_channels=8, 200 | out_channels=32, 201 | kernel_size=7, 202 | stride=1, 203 | padding=3, 204 | norm_cfg=None, 205 | act_cfg=dict(type='ReLU')), 206 | ConvModule( 207 | in_channels=32, 208 | out_channels=64, 209 | kernel_size=7, 210 | stride=1, 211 | padding=3, 212 | norm_cfg=None, 213 | act_cfg=dict(type='ReLU')), 214 | ConvModule( 215 | in_channels=64, 216 | out_channels=32, 217 | kernel_size=7, 218 | stride=1, 219 | padding=3, 220 | norm_cfg=None, 221 | act_cfg=dict(type='ReLU')), 222 | ConvModule( 223 | in_channels=32, 224 | out_channels=16, 225 | kernel_size=7, 226 | stride=1, 227 | padding=3, 228 | norm_cfg=None, 229 | act_cfg=dict(type='ReLU')), 230 | ConvModule( 231 | in_channels=16, 232 | out_channels=2, 233 | kernel_size=7, 234 | stride=1, 235 | padding=3, 236 | norm_cfg=None, 237 | act_cfg=None)) 238 | 239 | def forward(self, tensor_input): 240 | """ 241 | Args: 242 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w). 243 | 8 channels contain: 244 | [reference image (3), neighbor image (3), initial flow (2)]. 245 | 246 | Returns: 247 | Tensor: Refined flow with shape (b, 2, h, w). 248 | """ 249 | return self.basic_module(tensor_input) 250 | -------------------------------------------------------------------------------- /mmedit/apis/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | import tempfile 6 | 7 | import mmcv 8 | import torch 9 | import torch.distributed as dist 10 | from mmcv.runner import get_dist_info 11 | 12 | 13 | def single_gpu_test(model, 14 | data_loader, 15 | save_image=False, 16 | save_path=None, 17 | iteration=None): 18 | """Test model with a single gpu. 19 | 20 | This method tests the model with a single gpu and displays a test progress bar. 21 | 22 | Args: 23 | model (nn.Module): Model to be tested. 24 | data_loader (DataLoader): PyTorch data loader. 25 | save_image (bool): Whether to save images. Default: False. 26 | save_path (str): The directory to save images. Default: None. 27 | iteration (int): Iteration number. It is used for the saved image name. 28 | Default: None. 29 | 30 | Returns: 31 | list: The prediction results. 32 | """ 33 | if save_image and save_path is None: 34 | raise ValueError( 35 | "When 'save_image' is True, you should also set 'save_path'.") 36 | 37 | model.eval() 38 | results = [] 39 | dataset = data_loader.dataset 40 | prog_bar = mmcv.ProgressBar(len(dataset)) 41 | for data in data_loader: 42 | with torch.no_grad(): 43 | result = model( 44 | test_mode=True, 45 | save_image=save_image, 46 | save_path=save_path, 47 | iteration=iteration, 48 | **data) 49 | results.append(result) 50 | 51 | # get batch size 52 | for _, v in data.items(): 53 | if isinstance(v, torch.Tensor): 54 | batch_size = v.size(0) 55 | break 56 | for _ in range(batch_size): 57 | prog_bar.update() 58 | return results 59 | 60 | 61 | def multi_gpu_test(model, 62 | data_loader, 63 | tmpdir=None, 64 | gpu_collect=False, 65 | save_image=False, 66 | save_path=None, 67 | iteration=None, 68 | empty_cache=False): 69 | """Test model with multiple gpus. 70 | 71 | This method tests the model with multiple gpus and collects the results 72 | under two different modes: gpu and cpu. Setting 'gpu_collect=True' 73 | encodes results to gpu tensors and uses gpu communication for results 74 | collection; in cpu mode it saves the results from different gpus to 75 | 'tmpdir' and the rank 0 worker collects them. 
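    Example (illustrative sketch; assumes the distributed environment has
    been initialized, e.g. via ``mmcv.runner.init_dist``, and that ``model``
    is wrapped in ``MMDistributedDataParallel``)::

        >>> results = multi_gpu_test(model, data_loader,
        ...                          tmpdir='.dist_test',
        ...                          gpu_collect=False)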
76 | 77 | Args: 78 | model (nn.Module): Model to be tested. 79 | data_loader (DataLoader): PyTorch data loader. 80 | tmpdir (str): Path of the directory to save the temporary results from 81 | different gpus under cpu mode. 82 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 83 | save_image (bool): Whether to save images. Default: False. 84 | save_path (str): The directory to save images. Default: None. 85 | iteration (int): Iteration number. It is used for the saved image name. 86 | Default: None. 87 | empty_cache (bool): Whether to empty the GPU cache after every iteration. Default: False. 88 | 89 | Returns: 90 | list: The prediction results. 91 | """ 92 | 93 | if save_image and save_path is None: 94 | raise ValueError( 95 | "When 'save_image' is True, you should also set 'save_path'.") 96 | model.eval() 97 | results = [] 98 | dataset = data_loader.dataset 99 | rank, world_size = get_dist_info() 100 | if rank == 0: 101 | prog_bar = mmcv.ProgressBar(len(dataset)) 102 | for data in data_loader: 103 | with torch.no_grad(): 104 | result = model( 105 | test_mode=True, 106 | save_image=save_image, 107 | save_path=save_path, 108 | iteration=iteration, 109 | **data) 110 | results.append(result) 111 | if empty_cache: 112 | torch.cuda.empty_cache() 113 | if rank == 0: 114 | # get batch size 115 | for _, v in data.items(): 116 | if isinstance(v, torch.Tensor): 117 | batch_size = v.size(0) 118 | break 119 | for _ in range(batch_size * world_size): 120 | prog_bar.update() 121 | # collect results from all ranks 122 | if gpu_collect: 123 | results = collect_results_gpu(results, len(dataset)) 124 | else: 125 | results = collect_results_cpu(results, len(dataset), tmpdir) 126 | return results 127 | 128 | 129 | def collect_results_cpu(result_part, size, tmpdir=None): 130 | """Collect results in cpu mode. 131 | 132 | It saves the results from different gpus to 'tmpdir', and the rank 0 133 | worker collects them. 134 | 135 | Args: 136 | result_part (list): Results to be collected. 137 | size (int): Result size. 138 | tmpdir (str): Path of the directory to save the temporary results from 139 | different gpus under cpu mode. Default: None. 140 | 141 | Returns: 142 | list: Ordered results. 
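    Example (illustrative; with 2 ranks and a round-robin
    ``DistributedSampler`` split, the ``zip(*part_list)`` step below
    interleaves the per-rank parts back into dataset order)::

        >>> # rank 0 holds results for samples [0, 2],
        >>> # rank 1 holds results for samples [1, 3];
        >>> # the collected list is ordered as [0, 1, 2, 3].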
143 | """ 144 | 145 | rank, world_size = get_dist_info() 146 | # create a tmp dir if it is not specified 147 | if tmpdir is None: 148 | MAX_LEN = 512 149 | # 32 is whitespace 150 | dir_tensor = torch.full((MAX_LEN, ), 151 | 32, 152 | dtype=torch.uint8, 153 | device='cuda') 154 | if rank == 0: 155 | mmcv.mkdir_or_exist('.dist_test') 156 | tmpdir = tempfile.mkdtemp(dir='.dist_test') 157 | tmpdir = torch.tensor( 158 | bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') 159 | dir_tensor[:len(tmpdir)] = tmpdir 160 | dist.broadcast(dir_tensor, 0) 161 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 162 | else: 163 | mmcv.mkdir_or_exist(tmpdir) 164 | # synchronizes all processes to make sure tmpdir exist 165 | dist.barrier() 166 | # dump the part result to the dir 167 | mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank))) 168 | # synchronizes all processes for loading pickle file 169 | dist.barrier() 170 | # collect all parts 171 | if rank != 0: 172 | return None 173 | 174 | # load results of all parts from tmp dir 175 | part_list = [] 176 | for i in range(world_size): 177 | part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) 178 | part_list.append(mmcv.load(part_file)) 179 | # sort the results 180 | ordered_results = [] 181 | for res in zip(*part_list): 182 | ordered_results.extend(list(res)) 183 | # the dataloader may pad some samples 184 | ordered_results = ordered_results[:size] 185 | # remove tmp dir 186 | shutil.rmtree(tmpdir) 187 | return ordered_results 188 | 189 | 190 | def collect_results_gpu(result_part, size): 191 | """Collect results in gpu mode. 192 | 193 | It encodes results to gpu tensors and use gpu communication for results 194 | collection. 195 | 196 | Args: 197 | result_part (list): Results to be collected 198 | size (int): Result size. 199 | 200 | Returns: 201 | list: Ordered results. 202 | """ 203 | 204 | rank, world_size = get_dist_info() 205 | # dump result part to tensor with pickle 206 | part_tensor = torch.tensor( 207 | bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') 208 | # gather all result part tensor shape 209 | shape_tensor = torch.tensor(part_tensor.shape, device='cuda') 210 | shape_list = [shape_tensor.clone() for _ in range(world_size)] 211 | dist.all_gather(shape_list, shape_tensor) 212 | # padding result part tensor to max length 213 | shape_max = torch.tensor(shape_list).max() 214 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') 215 | part_send[:shape_tensor[0]] = part_tensor 216 | part_recv_list = [ 217 | part_tensor.new_zeros(shape_max) for _ in range(world_size) 218 | ] 219 | # gather all result part 220 | dist.all_gather(part_recv_list, part_send) 221 | 222 | if rank != 0: 223 | return None 224 | 225 | part_list = [] 226 | for recv, shape in zip(part_recv_list, shape_list): 227 | part_list.append(pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) 228 | # sort the results 229 | ordered_results = [] 230 | for res in zip(*part_list): 231 | ordered_results.extend(list(res)) 232 | # the dataloader may pad some samples 233 | ordered_results = ordered_results[:size] 234 | return ordered_results 235 | -------------------------------------------------------------------------------- /mmedit/datasets/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from collections.abc import Sequence 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | from mmcv.parallel import DataContainer as DC 8 | from torch.nn import functional as F 9 | 10 | from ..registry import PIPELINES 11 | 12 | 13 | def to_tensor(data): 14 | """Convert objects of various python types to :obj:`torch.Tensor`. 15 | 16 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 17 | :class:`Sequence`, :class:`int` and :class:`float`. 18 | """ 19 | if isinstance(data, torch.Tensor): 20 | return data 21 | if isinstance(data, np.ndarray): 22 | return torch.from_numpy(data) 23 | if isinstance(data, Sequence) and not mmcv.is_str(data): 24 | return torch.tensor(data) 25 | if isinstance(data, int): 26 | return torch.LongTensor([data]) 27 | if isinstance(data, float): 28 | return torch.FloatTensor([data]) 29 | 30 | raise TypeError(f'type {type(data)} cannot be converted to tensor.') 31 | 32 | 33 | @PIPELINES.register_module() 34 | class ToTensor: 35 | """Convert some values in results dict to `torch.Tensor` type in data 36 | loader pipeline. 37 | 38 | Args: 39 | keys (Sequence[str]): Required keys to be converted. 40 | """ 41 | 42 | def __init__(self, keys): 43 | self.keys = keys 44 | 45 | def __call__(self, results): 46 | """Call function. 47 | 48 | Args: 49 | results (dict): A dict containing the necessary information and 50 | data for augmentation. 51 | 52 | Returns: 53 | dict: A dict containing the processed data and information. 54 | """ 55 | for key in self.keys: 56 | results[key] = to_tensor(results[key]) 57 | return results 58 | 59 | def __repr__(self): 60 | return self.__class__.__name__ + f'(keys={self.keys})' 61 | 62 | 63 | @PIPELINES.register_module() 64 | class ImageToTensor: 65 | """Convert image type to `torch.Tensor` type. 66 | 67 | Args: 68 | keys (Sequence[str]): Required keys to be converted. 69 | to_float32 (bool): Whether to convert the numpy image array to 70 | np.float32 before converting it to a tensor. Default: True. 71 | """ 72 | 73 | def __init__(self, keys, to_float32=True): 74 | self.keys = keys 75 | self.to_float32 = to_float32 76 | 77 | def __call__(self, results): 78 | """Call function. 79 | 80 | Args: 81 | results (dict): A dict containing the necessary information and 82 | data for augmentation. 83 | 84 | Returns: 85 | dict: A dict containing the processed data and information. 86 | """ 87 | for key in self.keys: 88 | # deal with grayscale images: expand a color channel 89 | if len(results[key].shape) == 2: 90 | results[key] = results[key][..., None] 91 | if self.to_float32 and results[key].dtype != np.float32: 92 | results[key] = results[key].astype(np.float32) 93 | results[key] = to_tensor(results[key].transpose(2, 0, 1)) 94 | return results 95 | 96 | def __repr__(self): 97 | return self.__class__.__name__ + ( 98 | f'(keys={self.keys}, to_float32={self.to_float32})') 99 | 100 | 101 | @PIPELINES.register_module() 102 | class FramesToTensor(ImageToTensor): 103 | """Convert frames type to `torch.Tensor` type. 104 | 105 | It accepts a list of frames, converts each to `torch.Tensor` type and then 106 | stacks them along a new dimension (dim=0). 107 | 108 | Args: 109 | keys (Sequence[str]): Required keys to be converted. 110 | to_float32 (bool): Whether to convert the numpy image array to 111 | np.float32 before converting it to a tensor. Default: True. 112 | """ 113 | 114 | def __call__(self, results): 115 | """Call function. 116 | 117 | Args: 118 | results (dict): A dict containing the necessary information and 119 | data for augmentation. 
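        Example (illustrative)::

            >>> pipe = FramesToTensor(keys=['lq'])
            >>> frames = [np.random.rand(16, 16, 3) for _ in range(5)]
            >>> out = pipe(dict(lq=frames))
            >>> out['lq'].shape  # frames stacked along a new dim=0
            torch.Size([5, 3, 16, 16])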
120 | 121 | Returns: 122 | dict: A dict containing the processed data and information. 123 | """ 124 | for key in self.keys: 125 | if not isinstance(results[key], list): 126 | raise TypeError(f'results["{key}"] should be a list, ' 127 | f'but got {type(results[key])}') 128 | for idx, v in enumerate(results[key]): 129 | # deal with grayscale images: expand a color channel 130 | if len(v.shape) == 2: 131 | v = v[..., None] 132 | if self.to_float32 and v.dtype != np.float32: 133 | v = v.astype(np.float32) 134 | results[key][idx] = to_tensor(v.transpose(2, 0, 1)) 135 | results[key] = torch.stack(results[key], dim=0) 136 | if results[key].size(0) == 1: 137 | results[key].squeeze_()  # drops every singleton dim for single-frame inputs 138 | return results 139 | 140 | 141 | @PIPELINES.register_module() 142 | class GetMaskedImage: 143 | """Get masked image. 144 | 145 | Args: 146 | img_name (str): Key for clean image. 147 | mask_name (str): Key for mask image. The mask shape should be 148 | (h, w, 1), where 1 indicates holes and 0 indicates valid 149 | regions. 150 | """ 151 | 152 | def __init__(self, img_name='gt_img', mask_name='mask'): 153 | self.img_name = img_name 154 | self.mask_name = mask_name 155 | 156 | def __call__(self, results): 157 | """Call function. 158 | 159 | Args: 160 | results (dict): A dict containing the necessary information and 161 | data for augmentation. 162 | 163 | Returns: 164 | dict: A dict containing the processed data and information. 165 | """ 166 | clean_img = results[self.img_name] 167 | mask = results[self.mask_name] 168 | 169 | masked_img = clean_img * (1. - mask) 170 | results['masked_img'] = masked_img 171 | 172 | return results 173 | 174 | def __repr__(self): 175 | return self.__class__.__name__ + ( 176 | f"(img_name='{self.img_name}', mask_name='{self.mask_name}')") 177 | 178 | 179 | @PIPELINES.register_module() 180 | class FormatTrimap: 181 | """Convert trimap (tensor) to one-hot representation. 182 | 183 | It transforms the trimap label from (0, 128, 255) to (0, 1, 2). If 184 | ``to_onehot`` is set to True, the trimap will be converted to a one-hot 185 | tensor of shape (3, H, W). Required key is "trimap"; added or modified 186 | keys are "trimap" and "to_onehot". 187 | 188 | Args: 189 | to_onehot (bool): Whether to convert the trimap to a one-hot tensor. 190 | Default: ``False``. 191 | """ 192 | 193 | def __init__(self, to_onehot=False): 194 | self.to_onehot = to_onehot 195 | 196 | def __call__(self, results): 197 | """Call function. 198 | 199 | Args: 200 | results (dict): A dict containing the necessary information and 201 | data for augmentation. 202 | 203 | Returns: 204 | dict: A dict containing the processed data and information. 205 | """ 206 | trimap = results['trimap'].squeeze() 207 | trimap[trimap == 128] = 1 208 | trimap[trimap == 255] = 2 209 | if self.to_onehot: 210 | trimap = F.one_hot(trimap.to(torch.long), num_classes=3) 211 | trimap = trimap.permute(2, 0, 1) 212 | else: 213 | trimap = trimap[None, ...]  # expand the channels dimension 214 | results['trimap'] = trimap.float() 215 | results['meta'].data['to_onehot'] = self.to_onehot 216 | return results 217 | 218 | def __repr__(self): 219 | return self.__class__.__name__ + f'(to_onehot={self.to_onehot})' 220 | 221 | 222 | @PIPELINES.register_module() 223 | class Collect: 224 | """Collect data from the loader relevant to the specific task. 225 | 226 | This is usually the last stage of the data loader pipeline. Typically 227 | ``keys`` is set to some subset of "img" and "gt_labels". 228 | 229 | The "meta" item is always populated. 
The contents of the "meta" 230 | dictionary depend on "meta_keys". 231 | 232 | Args: 233 | keys (Sequence[str]): Required keys to be collected. 234 | meta_keys (Sequence[str]): Required keys to be collected to "meta". 235 | Default: None (collect no meta keys). 236 | """ 237 | 238 | def __init__(self, keys, meta_keys=None): 239 | self.keys = keys 240 | self.meta_keys = meta_keys if meta_keys is not None else ()  # guard the None default 241 | 242 | def __call__(self, results): 243 | """Call function. 244 | 245 | Args: 246 | results (dict): A dict containing the necessary information and 247 | data for augmentation. 248 | 249 | Returns: 250 | dict: A dict containing the processed data and information. 251 | """ 252 | data = {} 253 | img_meta = {} 254 | for key in self.meta_keys: 255 | img_meta[key] = results[key] 256 | data['meta'] = DC(img_meta, cpu_only=True) 257 | for key in self.keys: 258 | data[key] = results[key] 259 | return data 260 | 261 | def __repr__(self): 262 | return self.__class__.__name__ + ( 263 | f'(keys={self.keys}, meta_keys={self.meta_keys})') 264 | --------------------------------------------------------------------------------
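Taken together, the formatting stages above are chained through the registered
PIPELINES system. A minimal sketch of how a test pipeline might compose them,
assuming ``Compose`` is exported from ``mmedit.datasets.pipelines`` (it lives
in ``compose.py`` in the tree above); the stage list and the key names are
illustrative assumptions, not a config taken from this repository:

    from mmedit.datasets.pipelines import Compose

    # Hypothetical composition: only FramesToTensor and Collect are defined
    # in formating.py above; 'lq'/'gt' and 'lq_path' follow the key naming
    # used elsewhere in this codebase (e.g. forward_test in the derainer).
    test_pipeline = Compose([
        dict(type='FramesToTensor', keys=['lq', 'gt']),
        dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path']),
    ])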