├── requirements
│   ├── mminstall.txt
│   ├── readthedocs.txt
│   ├── tests.txt
│   ├── docs.txt
│   └── runtime.txt
├── requirements.txt
├── metrics
│   ├── utils_image.py
│   └── cal-iqa.py
├── mmedit
│   ├── core
│   │   ├── evaluation
│   │   │   ├── niqe_pris_params.npz
│   │   │   ├── __init__.py
│   │   │   ├── metric_utils.py
│   │   │   ├── eval_hooks.py
│   │   │   └── inceptions.py
│   │   ├── optimizer
│   │   │   ├── __init__.py
│   │   │   └── builder.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   └── dist_utils.py
│   │   ├── export
│   │   │   ├── __init__.py
│   │   │   └── wrappers.py
│   │   ├── scheduler
│   │   │   └── __init__.py
│   │   ├── registry.py
│   │   ├── hooks
│   │   │   ├── __init__.py
│   │   │   ├── visualization.py
│   │   │   └── ema.py
│   │   ├── __init__.py
│   │   ├── misc.py
│   │   └── distributed_wrapper.py
│   ├── models
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── derain_backbones
│   │   │       ├── __init__.py
│   │   │       └── derain_net.py
│   │   ├── derainers
│   │   │   ├── __init__.py
│   │   │   └── derainer.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── pixelwise_loss.py
│   │   ├── registry.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── flow_warp.py
│   │   │   └── sr_backbone_utils.py
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   └── base.py
│   ├── datasets
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   └── distributed_sampler.py
│   │   ├── registry.py
│   │   ├── __init__.py
│   │   ├── dataset_wrappers.py
│   │   ├── pipelines
│   │   │   ├── compose.py
│   │   │   ├── __init__.py
│   │   │   ├── normalization.py
│   │   │   ├── random_down_sampling.py
│   │   │   ├── utils.py
│   │   │   ├── generate_assistant.py
│   │   │   └── formating.py
│   │   ├── sr_folder_gt_dataset.py
│   │   ├── base_dataset.py
│   │   ├── base_sr_dataset.py
│   │   ├── sr_folder_multiple_gt_dataset.py
│   │   └── builder.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   ├── cli.py
│   │   ├── logger.py
│   │   ├── setup_env.py
│   │   └── misc.py
│   ├── version.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── restoration_inference.py
│   │   ├── restoration_face_inference.py
│   │   ├── restoration_video_inference.py
│   │   └── test.py
│   └── __init__.py
├── tools
│   ├── dist_train.sh
│   ├── dist_test.sh
│   ├── slurm_test.sh
│   ├── slurm_train.sh
│   ├── deployment
│   │   ├── test_torchserver.py
│   │   ├── mmedit_handler.py
│   │   └── mmedit2torchserve.py
│   ├── test.py
│   └── train.py
├── setup.cfg
├── modules
│   └── DeformableAlignment.py
├── README.md
└── configs
    └── derainers
        └── ViMPNet
            └── ViMPNet.py
/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcv-full>=1.3.17
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/runtime.txt
2 | -r requirements/tests.txt
3 |
--------------------------------------------------------------------------------
/metrics/utils_image.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyHongtaoWu/ViMP-Net/HEAD/metrics/utils_image.py
--------------------------------------------------------------------------------
/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | lmdb
2 | mmcv
3 | regex
4 | scikit-image
5 | titlecase
6 | torch
7 | torchvision
8 |
--------------------------------------------------------------------------------
/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | codecov
2 | flake8
3 | interrogate
4 | isort==5.10.1
5 | onnxruntime
6 | pytest
7 | pytest-runner
8 | yapf
9 |
--------------------------------------------------------------------------------
/mmedit/core/evaluation/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TonyHongtaoWu/ViMP-Net/HEAD/mmedit/core/evaluation/niqe_pris_params.npz
--------------------------------------------------------------------------------
/mmedit/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import build_optimizers
3 |
4 | __all__ = ['build_optimizers']
5 |
--------------------------------------------------------------------------------
/mmedit/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dist_utils import sync_random_seed
3 |
4 | __all__ = ['sync_random_seed']
5 |
--------------------------------------------------------------------------------
/mmedit/core/export/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .wrappers import ONNXRuntimeEditing
3 |
4 | __all__ = ['ONNXRuntimeEditing']
5 |
--------------------------------------------------------------------------------
/mmedit/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 | from .derain_backbones import ViMPNet
4 |
5 | __all__ = [
6 | 'ViMPNet'
7 | ]
8 |
--------------------------------------------------------------------------------
/mmedit/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .distributed_sampler import DistributedSampler
3 |
4 | __all__ = ['DistributedSampler']
5 |
--------------------------------------------------------------------------------
/mmedit/models/backbones/derain_backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 | from .ViMPNet import ViMPNet
4 |
5 | __all__ = [
6 | 'ViMPNet'
7 | ]
8 |
--------------------------------------------------------------------------------
/mmedit/datasets/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import Registry
3 |
4 | DATASETS = Registry('dataset')
5 | PIPELINES = Registry('pipeline')
6 |
--------------------------------------------------------------------------------
/mmedit/core/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .lr_updater import LinearLrUpdaterHook, ReduceLrUpdaterHook
3 |
4 | __all__ = ['LinearLrUpdaterHook', 'ReduceLrUpdaterHook']
5 |
--------------------------------------------------------------------------------
/mmedit/models/derainers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .derainer import Derainer
3 | from .DrainNet import DrainNet
4 |
5 |
6 | __all__ = [
7 | 'Derainer', 'DrainNet'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmedit/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss
3 |
4 | __all__ = [
5 |     'L1Loss', 'MSELoss', 'CharbonnierLoss', 'MaskedTVLoss'
6 | ]
7 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | docutils==0.16.0
2 | mmcls==0.10.0
3 | myst_parser
4 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
5 | sphinx==4.0.2
6 | sphinx-copybutton
7 | sphinx_markdown_tables
8 |
--------------------------------------------------------------------------------
/mmedit/models/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.cnn import MODELS as MMCV_MODELS
3 | from mmcv.utils import Registry
4 |
5 | MODELS = Registry('model', parent=MMCV_MODELS)
6 | BACKBONES = MODELS
7 | COMPONENTS = MODELS
8 | LOSSES = MODELS
9 |
--------------------------------------------------------------------------------
/mmedit/core/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import Registry, build_from_cfg
3 |
4 | METRICS = Registry('metric')
5 |
6 |
7 | def build_metric(cfg):
8 | """Build a metric calculator."""
9 | return build_from_cfg(cfg, METRICS)
10 |
--------------------------------------------------------------------------------
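A minimal usage sketch of the METRICS registry and build_metric, assuming mmedit is importable; ToyMetric is a hypothetical class used here only for illustration:

from mmedit.core.registry import METRICS, build_metric


@METRICS.register_module()
class ToyMetric:
    """Hypothetical metric, not part of this repository."""

    def __init__(self, crop_border=0):
        self.crop_border = crop_border


# build_from_cfg pops 'type', looks it up in METRICS, forwards the rest.
metric = build_metric(dict(type='ToyMetric', crop_border=4))
assert isinstance(metric, ToyMetric) and metric.crop_border == 4

--------------------------------------------------------------------------------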
/mmedit/core/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ema import ExponentialMovingAverageHook
3 | from .visualization import MMEditVisualizationHook, VisualizationHook
4 |
5 | __all__ = [
6 | 'VisualizationHook', 'MMEditVisualizationHook',
7 | 'ExponentialMovingAverageHook'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmedit/models/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 | from .flow_warp import flow_warp
4 | from .sr_backbone_utils import (ResidualBlockNoBN,
5 | make_layer)
6 |
7 |
8 |
9 |
10 | __all__ = [
11 | 'flow_warp', 'ResidualBlockNoBN', 'make_layer'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmedit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .cli import modify_args
3 | from .logger import get_root_logger
4 | from .misc import deprecated_function
5 | from .setup_env import setup_multi_processes
6 |
7 | __all__ = [
8 | 'get_root_logger', 'setup_multi_processes', 'modify_args',
9 | 'deprecated_function'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmedit/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_dataset import BaseDataset
3 | from .builder import build_dataloader, build_dataset
4 | from .registry import DATASETS, PIPELINES
5 | from .sr_folder_multiple_gt_dataset import SRFolderMultipleGTDataset
6 |
7 | __all__ = [
8 | 'DATASETS', 'PIPELINES', 'build_dataset', 'build_dataloader',
9 | 'BaseDataset', 'SRFolderMultipleGTDataset'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmedit/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .eval_hooks import DistEvalIterHook, EvalIterHook
3 | from .inceptions import FID, KID, InceptionV3
4 | from .metrics import (L1Evaluation, connectivity, gradient_error, mae, mse,
5 | niqe, psnr, reorder_image, sad, ssim)
6 |
7 | __all__ = [
8 | 'mse', 'sad', 'psnr', 'reorder_image', 'ssim', 'EvalIterHook',
9 | 'DistEvalIterHook', 'L1Evaluation', 'gradient_error', 'connectivity',
10 | 'niqe', 'mae', 'FID', 'KID', 'InceptionV3'
11 | ]
12 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-18600}
8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9 |
10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
11 | python -m torch.distributed.launch \
12 | --nnodes=$NNODES \
13 | --node_rank=$NODE_RANK \
14 | --master_addr=$MASTER_ADDR \
15 | --nproc_per_node=$GPUS \
16 | --master_port=$PORT \
17 | $(dirname "$0")/train.py \
18 | $CONFIG \
19 | --seed 0 \
20 | --launcher pytorch ${@:3}
21 |
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | av
2 | av==8.0.3; python_version < '3.7'
3 | facexlib
4 | lmdb
5 | mmcv-full>=1.3.13 # To support DCN on CPU
6 | numpy
7 | opencv-python!=4.5.5.62,!=4.5.5.64
8 | # MMCV depends on opencv-python rather than the headless variant, so we install opencv-python
9 | # Due to an upstream bug, we skip these two versions:
10 | # https://github.com/opencv/opencv-python/issues/602
11 | # https://github.com/opencv/opencv/issues/21366
12 | # It seems to be fixed in https://github.com/opencv/opencv/pull/21382
13 | Pillow
14 | tensorboard
15 | torch
16 | torchvision
17 |
--------------------------------------------------------------------------------
/mmedit/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 |
3 | __version__ = '0.16.1'
4 |
5 |
6 | def parse_version_info(version_str):
7 | ver_info = []
8 | for x in version_str.split('.'):
9 | if x.isdigit():
10 | ver_info.append(int(x))
11 | elif x.find('rc') != -1:
12 | patch_version = x.split('rc')
13 | ver_info.append(int(patch_version[0]))
14 | ver_info.append(f'rc{patch_version[1]}')
15 | return tuple(ver_info)
16 |
17 |
18 | version_info = parse_version_info(__version__)
19 |
--------------------------------------------------------------------------------
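A quick sketch of parse_version_info behavior (assuming mmedit is importable):

from mmedit.version import parse_version_info

assert parse_version_info('0.16.1') == (0, 16, 1)
# Release-candidate suffixes are kept as a trailing string component.
assert parse_version_info('1.0rc2') == (1, 0, 'rc2')

--------------------------------------------------------------------------------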
/mmedit/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import collect_env as collect_base_env
3 | from mmcv.utils import get_git_hash
4 |
5 | import mmedit
6 |
7 |
8 | def collect_env():
9 | """Collect the information of the running environments."""
10 | env_info = collect_base_env()
11 | env_info['MMEditing'] = f'{mmedit.__version__}+{get_git_hash()[:7]}'
12 |
13 | return env_info
14 |
15 |
16 | if __name__ == '__main__':
17 | for name, val in collect_env().items():
18 | print('{}: {}'.format(name, val))
19 |
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | NNODES=${NNODES:-1}
7 | NODE_RANK=${NODE_RANK:-0}
8 | PORT=${PORT:-29500}
9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
10 |
11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
12 | python -m torch.distributed.launch \
13 | --nnodes=$NNODES \
14 | --node_rank=$NODE_RANK \
15 | --master_addr=$MASTER_ADDR \
16 | --nproc_per_node=$GPUS \
17 | --master_port=$PORT \
18 | $(dirname "$0")/test.py \
19 | $CONFIG \
20 | $CHECKPOINT \
21 | --launcher pytorch \
22 | ${@:4}
23 |
--------------------------------------------------------------------------------
/mmedit/utils/cli.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import re
3 | import sys
4 | import warnings
5 |
6 |
7 | def modify_args():
8 | for i, v in enumerate(sys.argv):
9 | if i == 0:
10 | assert v.endswith('.py')
11 | elif re.match(r'--\w+_.*', v):
12 | new_arg = v.replace('_', '-')
13 | warnings.warn(
14 | f'command line argument {v} is deprecated, '
15 | f'please use {new_arg} instead.',
16 | category=DeprecationWarning,
17 | )
18 | sys.argv[i] = new_arg
19 |
--------------------------------------------------------------------------------
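A sketch of what modify_args does to sys.argv; the '--work_dir' flag here is only an example:

import sys

from mmedit.utils.cli import modify_args

sys.argv = ['train.py', '--work_dir=/tmp/exp']
modify_args()  # emits a DeprecationWarning for the underscore-style flag
assert sys.argv[1] == '--work-dir=/tmp/exp'

--------------------------------------------------------------------------------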
/mmedit/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .backbones import * # noqa: F401, F403
3 | from .base import BaseModel
4 | from .builder import (build, build_backbone, build_component, build_loss,
5 | build_model)
6 | from .common import * # noqa: F401, F403
7 | from .losses import * # noqa: F401, F403
8 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS
9 | from .derainers import Derainer
10 |
11 |
12 | __all__ = [
13 | 'BaseModel', 'Derainer', 'build',
14 | 'build_backbone', 'build_component', 'build_loss', 'build_model',
15 | 'BACKBONES', 'COMPONENTS', 'LOSSES', 'MODELS',
16 | ]
17 |
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 |
4 | [aliases]
5 | test=pytest
6 |
7 | [tool:pytest]
8 | addopts=tests/
9 |
10 | [yapf]
11 | based_on_style = pep8
12 | blank_line_before_nested_class_or_def = true
13 | split_before_expression_after_opening_paren = true
14 | split_penalty_import_names=0
15 | SPLIT_PENALTY_AFTER_OPENING_BRACKET=888
16 |
17 | [isort]
18 | line_length = 79
19 | multi_line_output = 0
20 | extra_standard_library = setuptools
21 | known_first_party = mmedit
22 | known_third_party = PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,packaging,pymatting,pytest,pytorch_sphinx_theme,requests,scipy,titlecase,torch,torchvision,ts
23 | no_lines_before = STDLIB,LOCALFOLDER
24 | default_section = THIRDPARTY
25 |
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
3 |
4 | set -x
5 |
6 | PARTITION=$1
7 | JOB_NAME=$2
8 | CONFIG=$3
9 | WORK_DIR=$4
10 | GPUS=${GPUS:-8}
11 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
12 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
13 | PY_ARGS=${@:5}
14 | SRUN_ARGS=${SRUN_ARGS:-""}
15 |
16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
17 | srun -p ${PARTITION} \
18 | --job-name=${JOB_NAME} \
19 | --gres=gpu:${GPUS_PER_NODE} \
20 | --ntasks=${GPUS} \
21 | --ntasks-per-node=${GPUS_PER_NODE} \
22 | --cpus-per-task=${CPUS_PER_TASK} \
23 | --kill-on-bad-exit=1 \
24 | ${SRUN_ARGS} \
25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
26 |
--------------------------------------------------------------------------------
/mmedit/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .evaluation import (FID, KID, DistEvalIterHook, EvalIterHook, InceptionV3,
3 | L1Evaluation, mae, mse, psnr, reorder_image, sad,
4 | ssim)
5 | from .hooks import MMEditVisualizationHook, VisualizationHook
6 | from .misc import tensor2img
7 | from .optimizer import build_optimizers
8 | from .registry import build_metric
9 | from .scheduler import LinearLrUpdaterHook, ReduceLrUpdaterHook
10 |
11 | __all__ = [
12 | 'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
13 | 'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
14 | 'VisualizationHook', 'MMEditVisualizationHook', 'L1Evaluation', 'FID',
15 | 'KID', 'InceptionV3', 'build_metric', 'ReduceLrUpdaterHook', 'mae'
16 | ]
17 |
--------------------------------------------------------------------------------
/mmedit/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # from .generation_inference import generation_inference
3 | # from .inpainting_inference import inpainting_inference
4 | # from .matting_inference import init_model, matting_inference
5 | from .restoration_face_inference import restoration_face_inference
6 | from .restoration_inference import restoration_inference
7 | from .restoration_video_inference import restoration_video_inference
8 | from .test import multi_gpu_test, single_gpu_test
9 | from .train import init_random_seed, set_random_seed, train_model
10 | # from .video_interpolation_inference import video_interpolation_inference
11 |
12 | __all__ = [
13 |     'train_model', 'set_random_seed',
14 | 'restoration_inference',
15 | 'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference',
16 | 'restoration_face_inference', 'init_random_seed'
17 | ]
18 |
--------------------------------------------------------------------------------
/mmedit/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | from mmcv.utils import get_logger
5 |
6 |
7 | def get_root_logger(log_file=None, log_level=logging.INFO):
8 | """Get the root logger.
9 |
10 | The logger will be initialized if it has not been initialized. By default a
11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
12 | also be added. The name of the root logger is the top-level package name,
13 | e.g., "mmedit".
14 |
15 | Args:
16 | log_file (str | None): The log filename. If specified, a FileHandler
17 | will be added to the root logger.
18 | log_level (int): The root logger level. Note that only the process of
19 | rank 0 is affected, while other processes will set the level to
20 | "Error" and be silent most of the time.
21 |
22 | Returns:
23 | logging.Logger: The root logger.
24 | """
25 | # root logger name: mmedit
26 | logger = get_logger(__name__.split('.')[0], log_file, log_level)
27 | return logger
28 |
--------------------------------------------------------------------------------
/mmedit/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 |
4 | from .version import __version__, version_info
5 |
6 | try:
7 | from mmcv.utils import digit_version
8 | except ImportError:
9 |
10 | def digit_version(version_str):
11 | digit_ver = []
12 | for x in version_str.split('.'):
13 | if x.isdigit():
14 | digit_ver.append(int(x))
15 | elif x.find('rc') != -1:
16 | patch_version = x.split('rc')
17 | digit_ver.append(int(patch_version[0]) - 1)
18 | digit_ver.append(int(patch_version[1]))
19 | return digit_ver
20 |
21 |
22 | MMCV_MIN = '1.3.13'
23 | MMCV_MAX = '1.8'
24 |
25 | mmcv_min_version = digit_version(MMCV_MIN)
26 | mmcv_max_version = digit_version(MMCV_MAX)
27 | mmcv_version = digit_version(mmcv.__version__)
28 |
29 |
30 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
31 | f'mmcv=={mmcv.__version__} is used but incompatible. ' \
32 |     f'Please install mmcv-full>={MMCV_MIN}, <{MMCV_MAX}.'
33 |
34 | __all__ = ['__version__', 'version_info']
35 |
--------------------------------------------------------------------------------
/mmedit/datasets/dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .registry import DATASETS
3 |
4 |
5 | @DATASETS.register_module()
6 | class RepeatDataset:
7 | """A wrapper of repeated dataset.
8 |
9 | The length of repeated dataset will be `times` larger than the original
10 | dataset. This is useful when the data loading time is long but the dataset
11 | is small. Using RepeatDataset can reduce the data loading time between
12 | epochs.
13 |
14 | Args:
15 | dataset (:obj:`Dataset`): The dataset to be repeated.
16 | times (int): Repeat times.
17 | """
18 |
19 | def __init__(self, dataset, times):
20 | self.dataset = dataset
21 | self.times = times
22 |
23 | self._ori_len = len(self.dataset)
24 |
25 | def __getitem__(self, idx):
26 | """Get item at each call.
27 |
28 | Args:
29 | idx (int): Index for getting each item.
30 | """
31 | return self.dataset[idx % self._ori_len]
32 |
33 | def __len__(self):
34 | """Length of the dataset.
35 |
36 | Returns:
37 | int: Length of the dataset.
38 | """
39 | return self.times * self._ori_len
40 |
--------------------------------------------------------------------------------
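A minimal sketch of RepeatDataset; any object with __len__ and __getitem__ works as the wrapped dataset:

from mmedit.datasets.dataset_wrappers import RepeatDataset

repeated = RepeatDataset(dataset=[10, 11, 12], times=4)
assert len(repeated) == 12
assert repeated[7] == 11  # index 7 wraps to 7 % 3 == 1 in the original

--------------------------------------------------------------------------------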
/mmedit/core/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | import torch.distributed as dist
5 | from mmcv.runner import get_dist_info
6 |
7 |
8 | def sync_random_seed(seed=None, device='cuda'):
9 | """Make sure different ranks share the same seed.
10 |
11 | All workers must call this function, otherwise it will deadlock.
12 | This method is generally used in `DistributedSampler`,
13 | because the seed should be identical across all processes
14 | in the distributed group.
15 | Args:
16 | seed (int, Optional): The seed. Default to None.
17 | device (str): The device where the seed will be put on.
18 | Default to 'cuda'.
19 | Returns:
20 | int: Seed to be used.
21 | """
22 | if seed is None:
23 | seed = np.random.randint(2**31)
24 | assert isinstance(seed, int)
25 |
26 | rank, world_size = get_dist_info()
27 |
28 | if world_size == 1:
29 | return seed
30 |
31 | if rank == 0:
32 | random_num = torch.tensor(seed, dtype=torch.int32, device=device)
33 | else:
34 | random_num = torch.tensor(0, dtype=torch.int32, device=device)
35 | dist.broadcast(random_num, src=0)
36 | return random_num.item()
37 |
--------------------------------------------------------------------------------
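In a non-distributed run (world_size == 1), sync_random_seed simply returns the seed without any broadcast; a sketch:

from mmedit.core.utils import sync_random_seed

assert sync_random_seed(42) == 42  # passthrough outside torch.distributed
seed = sync_random_seed()          # a random int drawn from [0, 2**31)

--------------------------------------------------------------------------------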
/tools/deployment/test_torchserver.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from argparse import ArgumentParser
3 |
4 | import cv2
5 | import numpy as np
6 | import requests
7 | from PIL import Image
8 |
9 |
10 | def parse_args():
11 | parser = ArgumentParser()
12 | parser.add_argument('model_name', help='The model name in the server')
13 | parser.add_argument(
14 | '--inference-addr',
15 | default='127.0.0.1:8080',
16 | help='Address and port of the inference server')
17 | parser.add_argument('--img-path', type=str, help='The input LQ image.')
18 | parser.add_argument(
19 | '--save-path', type=str, help='Path to save the generated GT image.')
20 | args = parser.parse_args()
21 | return args
22 |
23 |
24 | def save_results(content, save_path, ori_shape):
25 | ori_len = np.prod(ori_shape)
26 | scale = int(np.sqrt(len(content) / ori_len))
27 | target_size = [int(size * scale) for size in ori_shape[:2][::-1]]
28 | # Convert to RGB and save image
29 | img = Image.frombytes('RGB', target_size, content, 'raw', 'BGR', 0, 0)
30 | img.save(save_path)
31 |
32 |
33 | def main(args):
34 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
35 | ori_shape = cv2.imread(args.img_path).shape
36 | with open(args.img_path, 'rb') as image:
37 | response = requests.post(url, image)
38 | save_results(response.content, args.save_path, ori_shape)
39 |
40 |
41 | if __name__ == '__main__':
42 | parsed_args = parse_args()
43 | main(parsed_args)
44 |
--------------------------------------------------------------------------------
/mmedit/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv import build_from_cfg
4 |
5 | from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS
6 |
7 |
8 | def build(cfg, registry, default_args=None):
9 | """Build module function.
10 |
11 | Args:
12 | cfg (dict): Configuration for building modules.
13 | registry (obj): ``registry`` object.
14 | default_args (dict, optional): Default arguments. Defaults to None.
15 | """
16 | if isinstance(cfg, list):
17 | modules = [
18 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
19 | ]
20 | return nn.Sequential(*modules)
21 |
22 | return build_from_cfg(cfg, registry, default_args)
23 |
24 |
25 | def build_backbone(cfg):
26 | """Build backbone.
27 |
28 | Args:
29 | cfg (dict): Configuration for building backbone.
30 | """
31 | return build(cfg, BACKBONES)
32 |
33 |
34 | def build_component(cfg):
35 | """Build component.
36 |
37 | Args:
38 | cfg (dict): Configuration for building component.
39 | """
40 | return build(cfg, COMPONENTS)
41 |
42 |
43 | def build_loss(cfg):
44 | """Build loss.
45 |
46 | Args:
47 | cfg (dict): Configuration for building loss.
48 | """
49 | return build(cfg, LOSSES)
50 |
51 |
52 | def build_model(cfg, train_cfg=None, test_cfg=None):
53 | """Build model.
54 |
55 | Args:
56 | cfg (dict): Configuration for building model.
57 | train_cfg (dict): Training configuration. Default: None.
58 | test_cfg (dict): Testing configuration. Default: None.
59 | """
60 | return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
61 |
--------------------------------------------------------------------------------
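A sketch of building components from config dicts; ToyBlock is a hypothetical module registered only for illustration:

import torch.nn as nn

from mmedit.models.builder import build_component
from mmedit.models.registry import COMPONENTS


@COMPONENTS.register_module()
class ToyBlock(nn.Module):
    """Hypothetical component, not part of this repository."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)


block = build_component(dict(type='ToyBlock', channels=8))

Note that passing a list of configs to `build` returns the constructed modules wrapped in an nn.Sequential, per the isinstance(cfg, list) branch above.

--------------------------------------------------------------------------------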
/metrics/cal-iqa.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from os import path as osp
4 |
5 | import cv2
6 | import lpips
7 | import numpy as np
8 | import torch
9 | from natsort import natsorted
10 | from rich.progress import track
11 |
12 | import utils_image as util
13 |
14 |
15 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
16 |     return torch.Tensor((image / factor - cent)
17 |                         [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
18 |
19 |
20 | path = './result/'
21 | gt_path = './testgt/'
22 | folders = os.listdir(path)
23 | print(path)
24 |
25 | psnr = []
26 | ssim = []
27 | lpips_ = []
28 | loss_fn_alex = lpips.LPIPS(net='alex').cuda()
29 |
30 | for folder in folders:
31 |     print(folder)
32 |     imgs = natsorted(glob.glob(osp.join(path, folder, '*.png')))
33 |     imgs_gt = natsorted(glob.glob(osp.join(gt_path, folder, '*.png')))
34 |     psnr_folder = []
35 |     ssim_folder = []
36 |     lpips_folder = []
37 |
38 |     for i in track(range(len(imgs))):
39 |         output = cv2.imread(imgs[i])
40 |         gt = cv2.imread(imgs_gt[i])
41 |         if output.shape != gt.shape:
42 |             print(output.shape, gt.shape)
43 |         psnr_folder.append(util.calculate_psnr(output, gt))
44 |         ssim_folder.append(util.calculate_ssim(output, gt))
45 |         lpips_folder.append(
46 |             loss_fn_alex(im2tensor(output).cuda(),
47 |                          im2tensor(gt).cuda()).item())
48 |
49 |     psnr.append(np.mean(psnr_folder))
50 |     ssim.append(np.mean(ssim_folder))
51 |     lpips_.append(np.mean(lpips_folder))
52 |
53 | print('psnr: ', np.mean(psnr))
54 | print('ssim: ', np.mean(ssim))
55 | print('lpips: ', np.mean(lpips_))
56 |
--------------------------------------------------------------------------------
/mmedit/datasets/pipelines/compose.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from collections.abc import Sequence
3 |
4 | from mmcv.utils import build_from_cfg
5 |
6 | from ..registry import PIPELINES
7 |
8 |
9 | @PIPELINES.register_module()
10 | class Compose:
11 | """Compose a data pipeline with a sequence of transforms.
12 |
13 | Args:
14 | transforms (list[dict | callable]):
15 | Either config dicts of transforms or transform objects.
16 | """
17 |
18 | def __init__(self, transforms):
19 | assert isinstance(transforms, Sequence)
20 | self.transforms = []
21 | for transform in transforms:
22 | if isinstance(transform, dict):
23 | transform = build_from_cfg(transform, PIPELINES)
24 | self.transforms.append(transform)
25 | elif callable(transform):
26 | self.transforms.append(transform)
27 | else:
28 | raise TypeError(f'transform must be callable or a dict, '
29 | f'but got {type(transform)}')
30 |
31 | def __call__(self, data):
32 | """Call function.
33 |
34 | Args:
35 | data (dict): A dict containing the necessary information and
36 | data for augmentation.
37 |
38 | Returns:
39 | dict: A dict containing the processed data and information.
40 | """
41 | for t in self.transforms:
42 | data = t(data)
43 | if data is None:
44 | return None
45 | return data
46 |
47 | def __repr__(self):
48 | format_string = self.__class__.__name__ + '('
49 | for t in self.transforms:
50 | format_string += '\n'
51 | format_string += f' {t}'
52 | format_string += '\n)'
53 | return format_string
54 |
--------------------------------------------------------------------------------
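Compose accepts both config dicts and plain callables; a sketch using two callables so no registry lookup is needed:

from mmedit.datasets.pipelines import Compose

pipeline = Compose([
    lambda results: dict(results, img=results['img'] * 2),
    lambda results: dict(results, img=results['img'] + 1),
])
assert pipeline(dict(img=3))['img'] == 7  # (3 * 2) + 1

--------------------------------------------------------------------------------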
/mmedit/apis/restoration_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcv.parallel import collate, scatter
4 |
5 | from mmedit.datasets.pipelines import Compose
6 |
7 |
8 | def restoration_inference(model, img, ref=None):
9 | """Inference image with the model.
10 |
11 | Args:
12 | model (nn.Module): The loaded model.
13 | img (str): File path of input image.
14 | ref (str | None): File path of reference image. Default: None.
15 |
16 | Returns:
17 | Tensor: The predicted restoration result.
18 | """
19 | cfg = model.cfg
20 | device = next(model.parameters()).device # model device
21 | # remove gt from test_pipeline
22 | keys_to_remove = ['gt', 'gt_path']
23 | for key in keys_to_remove:
24 | for pipeline in list(cfg.test_pipeline):
25 | if 'key' in pipeline and key == pipeline['key']:
26 | cfg.test_pipeline.remove(pipeline)
27 | if 'keys' in pipeline and key in pipeline['keys']:
28 | pipeline['keys'].remove(key)
29 | if len(pipeline['keys']) == 0:
30 | cfg.test_pipeline.remove(pipeline)
31 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
32 | pipeline['meta_keys'].remove(key)
33 | # build the data pipeline
34 | test_pipeline = Compose(cfg.test_pipeline)
35 | # prepare data
36 | if ref: # Ref-SR
37 | data = dict(lq_path=img, ref_path=ref)
38 | else: # SISR
39 | data = dict(lq_path=img)
40 | data = test_pipeline(data)
41 | data = collate([data], samples_per_gpu=1)
42 | if 'cuda' in str(device):
43 | data = scatter(data, [device])[0]
44 | # forward the model
45 | with torch.no_grad():
46 | result = model(test_mode=True, **data)
47 |
48 | return result['output']
49 |
--------------------------------------------------------------------------------
/mmedit/core/optimizer/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.runner import build_optimizer
3 |
4 |
5 | def build_optimizers(model, cfgs):
6 | """Build multiple optimizers from configs.
7 |
8 |     If `cfgs` contains several dicts for optimizers, then a dict of the
9 |     constructed optimizers will be returned.
10 | If `cfgs` only contains one optimizer config, the constructed optimizer
11 | itself will be returned.
12 |
13 | For example,
14 |
15 | 1) Multiple optimizer configs:
16 |
17 | .. code-block:: python
18 |
19 | optimizer_cfg = dict(
20 | model1=dict(type='SGD', lr=lr),
21 | model2=dict(type='SGD', lr=lr))
22 |
23 | The return dict is
24 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``
25 |
26 | 2) Single optimizer config:
27 |
28 | .. code-block:: python
29 |
30 | optimizer_cfg = dict(type='SGD', lr=lr)
31 |
32 | The return is ``torch.optim.Optimizer``.
33 |
34 | Args:
35 | model (:obj:`nn.Module`): The model with parameters to be optimized.
36 | cfgs (dict): The config dict of the optimizer.
37 |
38 | Returns:
39 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
40 | The initialized optimizers.
41 | """
42 | optimizers = {}
43 | if hasattr(model, 'module'):
44 | model = model.module
45 | # determine whether 'cfgs' has several dicts for optimizers
46 | is_dict_of_dict = True
47 | for key, cfg in cfgs.items():
48 | if not isinstance(cfg, dict):
49 | is_dict_of_dict = False
50 |
51 | if is_dict_of_dict:
52 | for key, cfg in cfgs.items():
53 | cfg_ = cfg.copy()
54 | module = getattr(model, key)
55 | optimizers[key] = build_optimizer(module, cfg_)
56 | return optimizers
57 |
58 | return build_optimizer(model, cfgs)
59 |
--------------------------------------------------------------------------------
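A sketch of the two build_optimizers modes on a toy model; the submodule names must match the keys of the config dict:

import torch.nn as nn

from mmedit.core.optimizer import build_optimizers


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(4, 4)
        self.discriminator = nn.Linear(4, 1)


model = ToyModel()

# 1) Dict of per-submodule configs -> dict of optimizers.
optims = build_optimizers(
    model,
    dict(
        generator=dict(type='Adam', lr=1e-4),
        discriminator=dict(type='Adam', lr=4e-4)))

# 2) A single config -> one optimizer over all model parameters.
optim = build_optimizers(model, dict(type='SGD', lr=0.01, momentum=0.9))

--------------------------------------------------------------------------------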
/mmedit/models/common/flow_warp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def flow_warp(x,
7 | flow,
8 | interpolation='bilinear',
9 | padding_mode='zeros',
10 | align_corners=True):
11 | """Warp an image or a feature map with optical flow.
12 |
13 | Args:
14 | x (Tensor): Tensor with size (n, c, h, w).
15 |         flow (Tensor): Tensor with size (n, h, w, 2). The last dimension
16 |             holds two channels, denoting relative offsets in width and height.
17 | Note that the values are not normalized to [-1, 1].
18 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
19 | Default: 'bilinear'.
20 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
21 | Default: 'zeros'.
22 | align_corners (bool): Whether align corners. Default: True.
23 |
24 | Returns:
25 | Tensor: Warped image or feature map.
26 | """
27 | if x.size()[-2:] != flow.size()[1:3]:
28 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
29 | f'flow ({flow.size()[1:3]}) are not the same.')
30 | _, _, h, w = x.size()
31 | # create mesh grid
32 | device = flow.device
33 | grid_y, grid_x = torch.meshgrid(
34 | torch.arange(0, h, device=device, dtype=x.dtype),
35 | torch.arange(0, w, device=device, dtype=x.dtype))
36 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2
37 | grid.requires_grad = False
38 |
39 | grid_flow = grid + flow
40 | # scale grid_flow to [-1,1]
41 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
42 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
43 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
44 | output = F.grid_sample(
45 | x,
46 | grid_flow,
47 | mode=interpolation,
48 | padding_mode=padding_mode,
49 | align_corners=align_corners)
50 | return output
51 |
--------------------------------------------------------------------------------
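A sanity-check sketch: warping with an all-zero flow is the identity mapping, up to interpolation error:

import torch

from mmedit.models.common import flow_warp

x = torch.rand(1, 3, 16, 16)
flow = torch.zeros(1, 16, 16, 2)  # zero offsets everywhere
assert torch.allclose(flow_warp(x, flow), x, atol=1e-6)

--------------------------------------------------------------------------------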
/mmedit/datasets/sr_folder_gt_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_sr_dataset import BaseSRDataset
3 | from .registry import DATASETS
4 |
5 |
6 | @DATASETS.register_module()
7 | class SRFolderGTDataset(BaseSRDataset):
8 | """General ground-truth image folder dataset for image restoration.
9 |
10 | The dataset loads gt (Ground-Truth) image only,
11 | applies specified transforms and finally returns a dict containing paired
12 | data and other information.
13 |
14 |     This is the "gt folder mode", which requires specifying the gt
15 |     folder path, with the folder containing the corresponding images.
16 |     The image list is generated automatically.
17 |
18 | For example, we have a folder with the following structure:
19 |
20 | ::
21 |
22 | data_root
23 | ├── gt
24 | │ ├── 0001.png
25 | │ ├── 0002.png
26 |
27 | then, you need to set:
28 |
29 | .. code-block:: python
30 |
31 | gt_folder = data_root/gt
32 |
33 | Args:
34 | gt_folder (str | :obj:`Path`): Path to a gt folder.
35 | pipeline (List[dict | callable]): A sequence of data transformations.
36 | scale (int | tuple): Upsampling scale or upsampling scale range.
37 | test_mode (bool): Store `True` when building test dataset.
38 | Default: `False`.
39 | """
40 |
41 | def __init__(self,
42 | gt_folder,
43 | pipeline,
44 | scale,
45 | test_mode=False,
46 | filename_tmpl='{}'):
47 | super().__init__(pipeline, scale, test_mode)
48 | self.gt_folder = str(gt_folder)
49 | self.filename_tmpl = filename_tmpl
50 | self.data_infos = self.load_annotations()
51 |
52 | def load_annotations(self):
53 | """Load annotations for SR dataset.
54 |
55 | It loads the GT image path from folder.
56 |
57 | Returns:
58 | list[dict]: A list of dicts for path of GT.
59 | """
60 | data_infos = []
61 | gt_paths = self.scan_folder(self.gt_folder)
62 | for gt_path in gt_paths:
63 | data_infos.append(dict(gt_path=gt_path))
64 | return data_infos
65 |
--------------------------------------------------------------------------------
/mmedit/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | from abc import ABCMeta, abstractmethod
4 |
5 | from torch.utils.data import Dataset
6 |
7 | from .pipelines import Compose
8 |
9 |
10 | class BaseDataset(Dataset, metaclass=ABCMeta):
11 | """Base class for datasets.
12 |
13 | All datasets should subclass it.
14 | All subclasses should overwrite:
15 |
16 | ``load_annotations``, supporting to load information and generate
17 | image lists.
18 |
19 | Args:
20 | pipeline (list[dict | callable]): A sequence of data transforms.
21 | test_mode (bool): If True, the dataset will work in test mode.
22 | Otherwise, in train mode.
23 | """
24 |
25 | def __init__(self, pipeline, test_mode=False):
26 | super().__init__()
27 | self.test_mode = test_mode
28 | self.pipeline = Compose(pipeline)
29 |
30 | @abstractmethod
31 | def load_annotations(self):
32 | """Abstract function for loading annotation.
33 |
34 | All subclasses should overwrite this function
35 | """
36 |
37 | def prepare_train_data(self, idx):
38 | """Prepare training data.
39 |
40 | Args:
41 | idx (int): Index of the training batch data.
42 |
43 | Returns:
44 | dict: Returned training batch.
45 | """
46 | results = copy.deepcopy(self.data_infos[idx])
47 | return self.pipeline(results)
48 |
49 | def prepare_test_data(self, idx):
50 | """Prepare testing data.
51 |
52 | Args:
53 | idx (int): Index for getting each testing batch.
54 |
55 | Returns:
56 | Tensor: Returned testing batch.
57 | """
58 | results = copy.deepcopy(self.data_infos[idx])
59 | return self.pipeline(results)
60 |
61 | def __len__(self):
62 | """Length of the dataset.
63 |
64 | Returns:
65 | int: Length of the dataset.
66 | """
67 | return len(self.data_infos)
68 |
69 | def __getitem__(self, idx):
70 | """Get item at each call.
71 |
72 | Args:
73 | idx (int): Index for getting each item.
74 | """
75 | if self.test_mode:
76 | return self.prepare_test_data(idx)
77 |
78 | return self.prepare_train_data(idx)
79 |
--------------------------------------------------------------------------------
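A minimal concrete subclass sketch; load_annotations is the only required override, and the annotation keys here are hypothetical:

from mmedit.datasets.base_dataset import BaseDataset


class ToyDataset(BaseDataset):
    """Hypothetical dataset, not part of this repository."""

    def __init__(self, img_paths, pipeline, test_mode=False):
        super().__init__(pipeline, test_mode)
        self.img_paths = img_paths
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        return [dict(lq_path=p) for p in self.img_paths]


dataset = ToyDataset(['a.png', 'b.png'], pipeline=[])
assert len(dataset) == 2  # an empty pipeline returns the info dicts as-is

--------------------------------------------------------------------------------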
/tools/deployment/mmedit_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import random
4 | import string
5 | from io import BytesIO
6 |
7 | import PIL.Image as Image
8 | import torch
9 | from ts.torch_handler.base_handler import BaseHandler
10 |
11 | from mmedit.apis import init_model, restoration_inference
12 | from mmedit.core import tensor2img
13 |
14 |
15 | class MMEditHandler(BaseHandler):
16 |
17 | def initialize(self, context):
18 | print('MMEditHandler.initialize is called')
19 | properties = context.system_properties
20 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
21 | self.device = torch.device(self.map_location + ':' +
22 | str(properties.get('gpu_id')) if torch.cuda.
23 | is_available() else self.map_location)
24 | self.manifest = context.manifest
25 |
26 | model_dir = properties.get('model_dir')
27 | serialized_file = self.manifest['model']['serializedFile']
28 | checkpoint = os.path.join(model_dir, serialized_file)
29 | self.config_file = os.path.join(model_dir, 'config.py')
30 |
31 | self.model = init_model(self.config_file, checkpoint, self.device)
32 | self.initialized = True
33 |
34 | def preprocess(self, data, *args, **kwargs):
35 | body = data[0].get('data') or data[0].get('body')
36 | result = Image.open(BytesIO(body))
37 | # data preprocess is in inference.
38 | return result
39 |
40 | def inference(self, data, *args, **kwargs):
41 | # generate temp image path for restoration_inference
42 | temp_name = ''.join(
43 | random.sample(string.ascii_letters + string.digits, 18))
44 | temp_path = f'./{temp_name}.png'
45 | data.save(temp_path)
46 | results = restoration_inference(self.model, temp_path)
47 | # delete the temp image path
48 | os.remove(temp_path)
49 | return results
50 |
51 | def postprocess(self, data):
52 | # convert torch tensor to numpy and then convert to bytes
53 | output_list = []
54 | for data_ in data:
55 | data_np = tensor2img(data_)
56 | data_byte = data_np.tobytes()
57 | output_list.append(data_byte)
58 |
59 | return output_list
60 |
--------------------------------------------------------------------------------
/mmedit/utils/setup_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import platform
4 | import warnings
5 |
6 | import cv2
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def setup_multi_processes(cfg):
11 | """Setup multi-processing environment variables."""
12 | # set multi-process start method as `fork` to speed up the training
13 | if platform.system() != 'Windows':
14 | mp_start_method = cfg.get('mp_start_method', 'fork')
15 | current_method = mp.get_start_method(allow_none=True)
16 | if current_method is not None and current_method != mp_start_method:
17 | warnings.warn(
18 | f'Multi-processing start method `{mp_start_method}` is '
19 |                 f'different from the previous setting `{current_method}`. '
20 |                 f'It will be forcibly set to `{mp_start_method}`. You can '
21 |                 f'change this behavior via `mp_start_method` in your config.')
22 | mp.set_start_method(mp_start_method, force=True)
23 |
24 | # disable opencv multithreading to avoid system being overloaded
25 | opencv_num_threads = cfg.get('opencv_num_threads', 0)
26 | cv2.setNumThreads(opencv_num_threads)
27 |
28 | # setup OMP threads
29 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
30 | if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
31 | omp_num_threads = 1
32 | warnings.warn(
33 | f'Setting OMP_NUM_THREADS environment variable for each process '
34 |             f'to be {omp_num_threads} by default, to avoid your system being '
35 |             f'overloaded. Please further tune the variable for optimal '
36 | f'performance in your application as needed.')
37 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
38 |
39 | # setup MKL threads
40 | if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
41 | mkl_num_threads = 1
42 | warnings.warn(
43 | f'Setting MKL_NUM_THREADS environment variable for each process '
44 |             f'to be {mkl_num_threads} by default, to avoid your system being '
45 |             f'overloaded. Please further tune the variable for optimal '
46 | f'performance in your application as needed.')
47 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
48 |
--------------------------------------------------------------------------------
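setup_multi_processes expects an mmcv Config with a data.workers_per_gpu entry; a minimal sketch:

from mmcv import Config

from mmedit.utils import setup_multi_processes

cfg = Config(
    dict(
        mp_start_method='spawn',
        opencv_num_threads=0,
        data=dict(workers_per_gpu=4)))
setup_multi_processes(cfg)  # also sets OMP/MKL thread env vars if unset

--------------------------------------------------------------------------------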
/mmedit/core/evaluation/metric_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | def gaussian(x, sigma):
7 | """Gaussian function.
8 |
9 | Args:
10 | x (array_like): The independent variable.
11 | sigma (float): Standard deviation of the gaussian function.
12 |
13 | Return:
14 | ndarray or scalar: Gaussian value of `x`.
15 | """
16 | return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
17 |
18 |
19 | def dgaussian(x, sigma):
20 | """Gradient of gaussian.
21 |
22 | Args:
23 | x (array_like): The independent variable.
24 | sigma (float): Standard deviation of the gaussian function.
25 |
26 | Return:
27 | ndarray or scalar: Gradient of gaussian of `x`.
28 | """
29 | return -x * gaussian(x, sigma) / sigma**2
30 |
31 |
32 | def gauss_filter(sigma, epsilon=1e-2):
33 |     """Build gaussian gradient filters.
34 |
35 | Args:
36 | sigma (float): Standard deviation of the gaussian kernel.
37 | epsilon (float): Small value used when calculating kernel size.
38 | Default: 1e-2.
39 |
40 | Return:
41 | tuple[ndarray]: Gaussian filter along x and y axis.
42 | """
43 | half_size = np.ceil(
44 | sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
45 | size = int(2 * half_size + 1)
46 |
47 | # create filter in x axis
48 | filter_x = np.zeros((size, size))
49 | for i in range(size):
50 | for j in range(size):
51 | filter_x[i, j] = gaussian(i - half_size, sigma) * dgaussian(
52 | j - half_size, sigma)
53 |
54 | # normalize filter
55 | norm = np.sqrt((filter_x**2).sum())
56 | filter_x = filter_x / norm
57 | filter_y = np.transpose(filter_x)
58 |
59 | return filter_x, filter_y
60 |
61 |
62 | def gauss_gradient(img, sigma):
63 | """Gaussian gradient.
64 |
65 | From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/
66 | submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/
67 | index.html
68 |
69 | Args:
70 | img (ndarray): Input image.
71 | sigma (float): Standard deviation of the gaussian kernel.
72 |
73 | Return:
74 | ndarray: Gaussian gradient of input `img`.
75 | """
76 | filter_x, filter_y = gauss_filter(sigma)
77 | img_filtered_x = cv2.filter2D(
78 | img, -1, filter_x, borderType=cv2.BORDER_REPLICATE)
79 | img_filtered_y = cv2.filter2D(
80 | img, -1, filter_y, borderType=cv2.BORDER_REPLICATE)
81 | return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
82 |
--------------------------------------------------------------------------------
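A small sketch of gauss_gradient on a synthetic step edge; the response should peak at the columns adjacent to the edge:

import numpy as np

from mmedit.core.evaluation.metric_utils import gauss_gradient

img = np.zeros((32, 32), dtype=np.float32)
img[:, 16:] = 1.0  # vertical step edge between columns 15 and 16
grad = gauss_gradient(img, sigma=1.4)
assert grad.argmax() % 32 in (15, 16)  # strongest response at the edge

--------------------------------------------------------------------------------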
/mmedit/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .augmentation import (BinarizeImage, ColorJitter, CopyValues, Flip,
3 | GenerateFrameIndices,
4 | GenerateFrameIndiceswithPadding,
5 | GenerateSegmentIndices, MirrorSequence, Pad,
6 | Quantize, RandomAffine, RandomJitter,
7 | RandomMaskDilation, RandomTransposeHW, Resize,
8 | TemporalReverse, UnsharpMasking)
9 | from .compose import Compose
10 | from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
11 | CropLike, FixedCrop, ModCrop, PairedRandomCrop,
12 | RandomResizedCrop)
13 | from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
14 | ToTensor)
15 | from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
16 | from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
17 | LoadImageFromFileList, LoadMask, LoadPairedImageFromFile,
18 | RandomLoadResizeBg)
19 | from .matlab_like_resize import MATLABLikeResize
20 | from .matting_aug import (CompositeFg, GenerateSeg, GenerateSoftSeg,
21 | GenerateTrimap, GenerateTrimapWithDistTransform,
22 | MergeFgAndBg, PerturbBg, TransformTrimap)
23 | from .normalization import Normalize, RescaleToZeroOne
24 | from .random_degradations import (DegradationsWithShuffle, RandomBlur,
25 | RandomJPEGCompression, RandomNoise,
26 | RandomResize, RandomVideoCompression)
27 | from .random_down_sampling import RandomDownSampling
28 |
29 | __all__ = [
30 | 'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
31 | 'RandomLoadResizeBg', 'Compose', 'ImageToTensor', 'ToTensor',
32 | 'GetMaskedImage', 'BinarizeImage', 'Flip', 'Pad', 'RandomAffine',
33 | 'RandomJitter', 'ColorJitter', 'RandomMaskDilation', 'RandomTransposeHW',
34 | 'Resize', 'RandomResizedCrop', 'Crop', 'CropAroundCenter',
35 | 'CropAroundUnknown', 'ModCrop', 'PairedRandomCrop', 'Normalize',
36 | 'RescaleToZeroOne', 'GenerateTrimap', 'MergeFgAndBg', 'CompositeFg',
37 | 'TemporalReverse', 'LoadImageFromFileList', 'GenerateFrameIndices',
38 | 'GenerateFrameIndiceswithPadding', 'FixedCrop', 'LoadPairedImageFromFile',
39 | 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg',
40 | 'GetSpatialDiscountMask', 'RandomDownSampling',
41 | 'GenerateTrimapWithDistTransform', 'TransformTrimap',
42 | 'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
43 | 'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues',
44 | 'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise',
45 | 'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking',
46 | 'RandomVideoCompression'
47 | ]
48 |
--------------------------------------------------------------------------------
/mmedit/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import functools
3 | import logging
4 | import re
5 | import textwrap
6 | from typing import Callable
7 |
8 | from mmcv import print_log
9 |
10 |
11 | def deprecated_function(since: str, removed_in: str,
12 | instructions: str) -> Callable:
13 | """Marks functions as deprecated.
14 |
15 | Throw a warning when a deprecated function is called, and add a note in the
16 | docstring. Modified from https://github.com/pytorch/pytorch/blob/master/torch/onnx/_deprecation.py
17 | Args:
18 | since (str): The version when the function was first deprecated.
19 | removed_in (str): The version when the function will be removed.
20 | instructions (str): The action users should take.
21 | Returns:
22 | Callable: A new function, which will be deprecated soon.
23 | """ # noqa: E501
24 |
25 | def decorator(function):
26 |
27 | @functools.wraps(function)
28 | def wrapper(*args, **kwargs):
29 | print_log(
30 | f"'{function.__module__}.{function.__name__}' "
31 | f'is deprecated in version {since} and will be '
32 | f'removed in version {removed_in}. Please {instructions}.',
33 | # logger='current',
34 | level=logging.WARNING,
35 | )
36 | return function(*args, **kwargs)
37 |
38 | indent = ' '
39 | # Add a deprecation note to the docstring.
40 | docstring = function.__doc__ or ''
41 | # Add a note to the docstring.
42 | deprecation_note = textwrap.dedent(f"""\
43 | .. deprecated:: {since}
44 | Deprecated and will be removed in version {removed_in}.
45 | Please {instructions}.
46 | """)
47 | # Split docstring at first occurrence of newline
48 | pattern = '\n\n'
49 | summary_and_body = re.split(pattern, docstring, 1)
50 |
51 | if len(summary_and_body) > 1:
52 | summary, body = summary_and_body
53 | body = textwrap.indent(textwrap.dedent(body), indent)
54 | summary = '\n'.join(
55 | [textwrap.dedent(string) for string in summary.split('\n')])
56 | summary = textwrap.indent(summary, prefix=indent)
57 | # Dedent the body. We cannot do this with the presence of the
58 | # summary because the body contains leading whitespaces when the
59 | # summary does not.
60 | new_docstring_parts = [
61 | deprecation_note, '\n\n', summary, '\n\n', body
62 | ]
63 | else:
64 | summary = summary_and_body[0]
65 | summary = '\n'.join(
66 | [textwrap.dedent(string) for string in summary.split('\n')])
67 | summary = textwrap.indent(summary, prefix=indent)
68 | new_docstring_parts = [deprecation_note, '\n\n', summary]
69 |
70 | wrapper.__doc__ = ''.join(new_docstring_parts)
71 |
72 | return wrapper
73 |
74 | return decorator
75 |
--------------------------------------------------------------------------------
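A sketch of deprecated_function as a decorator; the wrapped function warns via print_log and then runs unchanged:

from mmedit.utils import deprecated_function


@deprecated_function(
    since='0.16.0', removed_in='1.0.0', instructions='use `new_add` instead')
def old_add(a, b):
    """Add two numbers."""
    return a + b


assert old_add(1, 2) == 3  # logs a deprecation warning, then executes

--------------------------------------------------------------------------------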
/mmedit/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from __future__ import division
3 | import math
4 |
5 | import torch
6 | from torch.utils.data import DistributedSampler as _DistributedSampler
7 |
8 | from mmedit.core.utils import sync_random_seed
9 |
10 |
11 | class DistributedSampler(_DistributedSampler):
12 | """DistributedSampler inheriting from
13 | `torch.utils.data.DistributedSampler`.
14 |
15 |     In lower versions of PyTorch, there is no `shuffle` argument. This
16 |     subclass adds one to DistributedSampler.
17 | """
18 |
19 | def __init__(self,
20 | dataset,
21 | num_replicas=None,
22 | rank=None,
23 | shuffle=True,
24 | samples_per_gpu=1,
25 | seed=0):
26 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
27 | self.shuffle = shuffle
28 | self.samples_per_gpu = samples_per_gpu
29 | # fix the bug of the official implementation
30 | self.num_samples_per_replica = int(
31 | math.ceil(
32 | len(self.dataset) * 1.0 / self.num_replicas / samples_per_gpu))
33 | self.num_samples = self.num_samples_per_replica * self.samples_per_gpu
34 | self.total_size = self.num_samples * self.num_replicas
35 |
36 | # In distributed sampling, different ranks should sample
37 | # non-overlapped data in the dataset. Therefore, this function
38 | # is used to make sure that each rank shuffles the data indices
39 | # in the same order based on the same seed. Then different ranks
40 | # could use different indices to select non-overlapped data from the
41 | # same data list.
42 | self.seed = sync_random_seed(seed)
43 |
44 | # to avoid padding bug when meeting too small dataset
45 | if len(dataset) < self.num_replicas * samples_per_gpu:
46 | raise ValueError(
47 |                 'The dataset is too small for the distributed sampler '
48 |                 'to pad it correctly. Please use fewer GPUs or a '
49 |                 'larger dataset.')
50 |
51 | def __iter__(self):
52 | # deterministically shuffle based on epoch
53 | if self.shuffle:
54 | g = torch.Generator()
55 | # When :attr:`shuffle=True`, this ensures all replicas
56 | # use a different random ordering for each epoch.
57 | # Otherwise, the next iteration of this sampler will
58 | # yield the same ordering.
59 | g.manual_seed(self.epoch + self.seed)
60 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
61 | else:
62 | indices = torch.arange(len(self.dataset)).tolist()
63 |
64 | # add extra samples to make it evenly divisible
65 | indices += indices[:(self.total_size - len(indices))]
66 | assert len(indices) == self.total_size
67 |
68 | # subsample
69 | indices = indices[self.rank:self.total_size:self.num_replicas]
70 | assert len(indices) == self.num_samples
71 |
72 | return iter(indices)
73 |
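A hedged sketch of driving the sampler in a loop; the toy list dataset stands in for a real one, and `sync_random_seed` is assumed to fall back to the given seed in a single-process run:

```python
from torch.utils.data import DataLoader

dataset = list(range(100))   # toy dataset; any sized sequence works
sampler = DistributedSampler(
    dataset, num_replicas=2, rank=0, shuffle=True,
    samples_per_gpu=4, seed=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the shuffle each epoch
    for batch in loader:
        pass  # one training step per batch
```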
--------------------------------------------------------------------------------
/mmedit/core/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
10 | """Convert torch Tensors into image numpy arrays.
11 |
12 | After clamping to (min, max), image values will be normalized to [0, 1].
13 |
14 | For different tensor shapes, this function will have different behaviors:
15 |
16 | 1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
17 | Use `make_grid` to stitch images in the batch dimension, and then
18 | convert it to numpy array.
19 | 2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
20 | Directly change to numpy array.
21 |
22 | Note that the image channel in input tensors should be RGB order. This
23 | function will convert it to cv2 convention, i.e., (H x W x C) with BGR
24 | order.
25 |
26 | Args:
27 | tensor (Tensor | list[Tensor]): Input tensors.
28 | out_type (numpy type): Output types. If ``np.uint8``, transform outputs
29 | to uint8 type with range [0, 255]; otherwise, float type with
30 | range [0, 1]. Default: ``np.uint8``.
31 | min_max (tuple): min and max values for clamp.
32 |
33 | Returns:
34 |         (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D
35 |         ndarray of shape (H x W).
36 | """
37 | if not (torch.is_tensor(tensor) or
38 | (isinstance(tensor, list)
39 | and all(torch.is_tensor(t) for t in tensor))):
40 | raise TypeError(
41 | f'tensor or list of tensors expected, got {type(tensor)}')
42 |
43 | if torch.is_tensor(tensor):
44 | tensor = [tensor]
45 | result = []
46 | for _tensor in tensor:
47 | # Squeeze two times so that:
48 |         # 1. (1, 1, h, w) -> (h, w) or
49 |         # 2. (1, 3, h, w) -> (3, h, w) or
50 |         # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w) (unchanged)
51 | _tensor = _tensor.squeeze(0).squeeze(0)
52 |
54 | _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
55 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
56 | n_dim = _tensor.dim()
57 | if n_dim == 4:
58 | img_np = make_grid(
59 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
60 | normalize=False).numpy()
61 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
62 | elif n_dim == 3:
63 | img_np = _tensor.numpy()
64 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
65 | elif n_dim == 2:
66 | img_np = _tensor.numpy()
67 | else:
68 | raise ValueError('Only support 4D, 3D or 2D tensor. '
69 | f'But received with dimension: {n_dim}')
70 | if out_type == np.uint8:
71 |             # Unlike MATLAB, numpy.uint8() will NOT round by default.
72 | img_np = (img_np * 255.0).round()
73 | img_np = img_np.astype(out_type)
74 | result.append(img_np)
75 | result = result[0] if len(result) == 1 else result
76 | return result
77 |
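A small sketch of `tensor2img` on random data, covering the 3D and 4D branches:

```python
import numpy as np
import torch

t = torch.rand(1, 3, 64, 64)   # (N, C, H, W), RGB, values in [0, 1]
img = tensor2img(t)            # -> (64, 64, 3) uint8 array, BGR order
assert img.dtype == np.uint8 and img.shape == (64, 64, 3)

batch = torch.rand(4, 3, 64, 64)
grid = tensor2img(batch)       # 4D inputs are stitched into a 2x2 grid
```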
--------------------------------------------------------------------------------
/mmedit/models/common/sr_backbone_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.cnn import constant_init, kaiming_init
4 | from mmcv.utils.parrots_wrapper import _BatchNorm
5 |
6 |
7 | def default_init_weights(module, scale=1):
8 | """Initialize network weights.
9 |
10 | Args:
11 |         module (nn.Module): Module to be initialized.
12 |         scale (float): Scale factor for the initialized weights, useful
13 |             for residual blocks. Default: 1.
14 | """
15 | for m in module.modules():
16 | if isinstance(m, nn.Conv2d):
17 | kaiming_init(m, a=0, mode='fan_in', bias=0)
18 | m.weight.data *= scale
19 | elif isinstance(m, nn.Linear):
20 | kaiming_init(m, a=0, mode='fan_in', bias=0)
21 | m.weight.data *= scale
22 | elif isinstance(m, _BatchNorm):
23 | constant_init(m.weight, val=1, bias=0)
24 |
25 |
26 | def make_layer(block, num_blocks, **kwargs):
27 |     """Make layers by stacking the same blocks.
28 | 
29 |     Args:
30 |         block (nn.Module): Class of the basic block to stack.
31 |         num_blocks (int): Number of blocks to stack.
32 | 
33 |     Returns:
34 |         nn.Sequential: Stacked blocks in nn.Sequential.
35 |     """
36 |     layers = []
37 |     for _ in range(num_blocks):
38 |         layers.append(block(**kwargs))
39 |     return nn.Sequential(*layers)
40 |
41 |
42 | class ResidualBlockNoBN(nn.Module):
43 | """Residual block without BN.
44 |
45 | It has a style of:
46 |
47 | ::
48 |
49 | ---Conv-ReLU-Conv-+-
50 | |________________|
51 |
52 | Args:
53 | mid_channels (int): Channel number of intermediate features.
54 | Default: 64.
55 |         res_scale (float): Used to scale the residual before addition. Default: 1.0.
56 |         groups (int): Same as nn.Conv2d. Default: 1.
57 |     """
58 |
59 | def __init__(self, mid_channels=64, res_scale=1.0, groups=1):
60 | super().__init__()
61 | self.res_scale = res_scale
62 | self.conv1 = nn.Conv2d(
63 | mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups)
64 | self.conv2 = nn.Conv2d(
65 | mid_channels, mid_channels, 3, 1, 1, bias=True, groups=groups)
66 |
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | # if res_scale < 1.0, use the default initialization, as in EDSR.
70 | # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
71 | if res_scale == 1.0:
72 | self.init_weights()
73 |
74 | def init_weights(self):
75 | """Initialize weights for ResidualBlockNoBN.
76 |
77 | Initialization methods like `kaiming_init` are for VGG-style modules.
78 | For modules with residual paths, using smaller std is better for
79 | stability and performance. We empirically use 0.1. See more details in
80 | "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks"
81 | """
82 |
83 | for m in [self.conv1, self.conv2]:
84 | default_init_weights(m, 0.1)
85 |
86 | def forward(self, x):
87 | """Forward function.
88 |
89 | Args:
90 | x (Tensor): Input tensor with shape (n, c, h, w).
91 |
92 | Returns:
93 | Tensor: Forward results.
94 | """
95 |
96 | identity = x
97 | out = self.conv2(self.relu(self.conv1(x)))
98 | return identity + out * self.res_scale
99 |
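A small sketch composing the helpers above; the block count mirrors `num_blocks=9` from the ViMPNet config in this repo:

```python
import torch

trunk = make_layer(ResidualBlockNoBN, 9, mid_channels=64)
feat = torch.rand(1, 64, 32, 32)
out = trunk(feat)              # residual blocks preserve the shape
assert out.shape == feat.shape
```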
--------------------------------------------------------------------------------
/mmedit/models/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class BaseModel(nn.Module, metaclass=ABCMeta):
10 | """Base model.
11 |
12 | All models should subclass it.
13 |     All subclasses should overwrite:
14 | 
15 |     ``init_weights``, to initialize the model weights.
16 | 
17 |     ``forward_train``, the forward pass used in training.
18 | 
19 |     ``forward_test``, the forward pass used in testing.
20 | 
21 |     ``train_step``, to run one training step.
22 | """
23 |
24 | @abstractmethod
25 | def init_weights(self):
26 | """Abstract method for initializing weight.
27 |
28 |         All subclasses should overwrite it.
29 | """
30 |
31 | @abstractmethod
32 | def forward_train(self, imgs, labels):
33 | """Abstract method for training forward.
34 |
35 |         All subclasses should overwrite it.
36 | """
37 |
38 | @abstractmethod
39 | def forward_test(self, imgs):
40 | """Abstract method for testing forward.
41 |
42 |         All subclasses should overwrite it.
43 | """
44 |
45 | def forward(self, imgs, labels, test_mode, **kwargs):
46 | """Forward function for base model.
47 |
48 | Args:
49 | imgs (Tensor): Input image(s).
50 | labels (Tensor): Ground-truth label(s).
51 | test_mode (bool): Whether in test mode.
52 | kwargs (dict): Other arguments.
53 |
54 | Returns:
55 | Tensor: Forward results.
56 | """
57 |
58 | if test_mode:
59 | return self.forward_test(imgs, **kwargs)
60 |
61 | return self.forward_train(imgs, labels, **kwargs)
62 |
63 | @abstractmethod
64 | def train_step(self, data_batch, optimizer):
65 | """Abstract method for one training step.
66 |
67 |         All subclasses should overwrite it.
68 | """
69 |
70 | def val_step(self, data_batch, **kwargs):
71 | """Abstract method for one validation step.
72 |
73 | All subclass should overwrite it.
74 | """
75 | output = self.forward_test(**data_batch, **kwargs)
76 | return output
77 |
78 | def parse_losses(self, losses):
79 | """Parse losses dict for different loss variants.
80 |
81 | Args:
82 | losses (dict): Loss dict.
83 |
84 | Returns:
85 |             loss (Tensor): Sum of all loss terms whose name contains 'loss'.
86 |             log_vars (dict): Scalar loss values for logging.
87 | """
88 | log_vars = OrderedDict()
89 | for loss_name, loss_value in losses.items():
90 | if isinstance(loss_value, torch.Tensor):
91 | log_vars[loss_name] = loss_value.mean()
92 | elif isinstance(loss_value, list):
93 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
94 | else:
95 | raise TypeError(
96 | f'{loss_name} is not a tensor or list of tensors')
97 |
98 | loss = sum(_value for _key, _value in log_vars.items()
99 | if 'loss' in _key)
100 |
101 | log_vars['loss'] = loss
102 | for name in log_vars:
103 | log_vars[name] = log_vars[name].item()
104 |
105 | return loss, log_vars
106 |
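A minimal sketch of a concrete subclass; the single-conv "network" and L1 loss are placeholders, not the repository's actual derainer:

```python
import torch.nn as nn
import torch.nn.functional as F


class ToyRestorer(BaseModel):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.conv.weight, std=0.01)
        nn.init.zeros_(self.conv.bias)

    def forward_train(self, imgs, labels):
        return dict(loss_pix=F.l1_loss(self.conv(imgs), labels))

    def forward_test(self, imgs):
        return dict(output=self.conv(imgs))

    def train_step(self, data_batch, optimizer):
        losses = self.forward_train(data_batch['imgs'], data_batch['labels'])
        loss, log_vars = self.parse_losses(losses)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return dict(log_vars=log_vars, num_samples=len(data_batch['imgs']))
```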
--------------------------------------------------------------------------------
/modules/DeformableAlignment.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on:
3 | BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from mmcv.cnn import constant_init
9 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
10 |
11 |
12 |
15 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
16 | """Second-order deformable alignment module.
17 |
18 | Args:
19 | in_channels (int): Same as nn.Conv2d.
20 | out_channels (int): Same as nn.Conv2d.
21 | kernel_size (int or tuple[int]): Same as nn.Conv2d.
22 | stride (int or tuple[int]): Same as nn.Conv2d.
23 | padding (int or tuple[int]): Same as nn.Conv2d.
24 | dilation (int or tuple[int]): Same as nn.Conv2d.
25 | groups (int): Same as nn.Conv2d.
26 |         bias (bool): Same as nn.Conv2d. (This standalone module takes no
27 |             norm_cfg, so the mmcv-style 'auto' option does not apply
28 |             here.)
29 | max_residue_magnitude (int): The maximum magnitude of the offset
30 | residue (Eq. 6 in paper). Default: 10.
31 | """
32 |
33 | def __init__(self, *args, **kwargs):
34 | self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
35 |
36 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
37 |
38 | self.conv_offset = nn.Sequential(
39 |             nn.Conv2d(3 * self.out_channels + 5, self.out_channels, 3, 1, 1),  # 3 feats + two 2-ch flows + 1-ch mask
40 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
41 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
42 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
43 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
44 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
45 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
46 | )
47 |
48 | self.init_offset()
49 |
50 | def init_offset(self):
51 | constant_init(self.conv_offset[-1], val=0, bias=0)
52 |
53 | def forward(self, x, extra_feat, flow_1, flow_2, streakmask_n1):
55 | extra_feat = torch.cat([extra_feat, flow_1, flow_2, streakmask_n1], dim=1)
56 | out = self.conv_offset(extra_feat)
57 | o1, o2, mask = torch.chunk(out, 3, dim=1)
58 |
59 | # offset
60 | offset = self.max_residue_magnitude * torch.tanh(
61 | torch.cat((o1, o2), dim=1))
62 | offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
63 |         # use the channel-reversed flows as base offsets, tiled over
64 |         # all deformable sampling points
65 |         offset_1 = offset_1 + flow_1.flip(1).repeat(
66 |             1, offset_1.size(1) // 2, 1, 1)
67 |         offset_2 = offset_2 + flow_2.flip(1).repeat(
68 |             1, offset_2.size(1) // 2, 1, 1)
69 | offset = torch.cat([offset_1, offset_2], dim=1)
70 |
71 | # mask
72 | mask = torch.sigmoid(mask)
73 |
74 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
75 | self.stride, self.padding,
76 | self.dilation, self.groups,
77 | self.deform_groups)
78 |
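A rough sketch of calling the module; shapes follow the `conv_offset` definition above (`extra_feat` carries 3 * out_channels feature channels, each flow 2 channels, the streak mask 1), and a CUDA build of mmcv-full may be required for the deformable conv op:

```python
import torch

align = SecondOrderDeformableAlignment(
    64, 64, 3, padding=1, deform_groups=16, max_residue_magnitude=10)

x = torch.rand(1, 64, 32, 32)
extra_feat = torch.rand(1, 3 * 64, 32, 32)
flow_1 = torch.zeros(1, 2, 32, 32)
flow_2 = torch.zeros(1, 2, 32, 32)
streak_mask = torch.zeros(1, 1, 32, 32)

out = align(x, extra_feat, flow_1, flow_2, streak_mask)  # (1, 64, 32, 32)
```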
--------------------------------------------------------------------------------
/mmedit/apis/restoration_face_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | from mmcv.parallel import collate, scatter
5 |
6 | from mmedit.datasets.pipelines import Compose
7 |
8 | try:
9 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
10 | has_facexlib = True
11 | except ImportError:
12 | has_facexlib = False
13 |
14 |
15 | def restoration_face_inference(model, img, upscale_factor=1, face_size=1024):
16 | """Inference image with the model.
17 |
18 | Args:
19 | model (nn.Module): The loaded model.
20 | img (str): File path of input image.
21 | upscale_factor (int, optional): The number of times the input image
22 | is upsampled. Default: 1.
23 | face_size (int, optional): The size of the cropped and aligned faces.
24 | Default: 1024.
25 |
26 | Returns:
27 | Tensor: The predicted restoration result.
28 | """
29 | device = next(model.parameters()).device # model device
30 |
31 | # build the data pipeline
32 | if model.cfg.get('demo_pipeline', None):
33 | test_pipeline = model.cfg.demo_pipeline
34 | elif model.cfg.get('test_pipeline', None):
35 | test_pipeline = model.cfg.test_pipeline
36 | else:
37 | test_pipeline = model.cfg.val_pipeline
38 |
39 | # remove gt from test_pipeline
40 | keys_to_remove = ['gt', 'gt_path']
41 | for key in keys_to_remove:
42 | for pipeline in list(test_pipeline):
43 | if 'key' in pipeline and key == pipeline['key']:
44 | test_pipeline.remove(pipeline)
45 | if 'keys' in pipeline and key in pipeline['keys']:
46 | pipeline['keys'].remove(key)
47 | if len(pipeline['keys']) == 0:
48 | test_pipeline.remove(pipeline)
49 | if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
50 | pipeline['meta_keys'].remove(key)
51 | # build the data pipeline
52 | test_pipeline = Compose(test_pipeline)
53 |
54 | # face helper for detecting and aligning faces
55 | assert has_facexlib, 'Please install FaceXLib to use the demo.'
56 | face_helper = FaceRestoreHelper(
57 | upscale_factor,
58 | face_size=face_size,
59 | crop_ratio=(1, 1),
60 | det_model='retinaface_resnet50',
61 | template_3points=True,
62 | save_ext='png',
63 | device=device)
64 |
65 | face_helper.read_image(img)
66 | # get face landmarks for each face
67 | face_helper.get_face_landmarks_5(
68 | only_center_face=False, eye_dist_threshold=None)
69 | # align and warp each face
70 | face_helper.align_warp_face()
71 |
72 |     for i, face in enumerate(face_helper.cropped_faces):
73 |         # prepare data; `face` avoids shadowing the `img` argument
74 |         data = dict(lq=face.astype(np.float32))
75 | data = test_pipeline(data)
76 | data = collate([data], samples_per_gpu=1)
77 | if 'cuda' in str(device):
78 | data = scatter(data, [device])[0]
79 |
80 | with torch.no_grad():
81 | output = model(test_mode=True, **data)['output']
82 | output = torch.clamp(output, min=0, max=1)
83 |
84 | output = output.squeeze(0).permute(1, 2, 0)[:, :, [2, 1, 0]]
85 | output = output.cpu().numpy() * 255 # (0, 255)
86 | face_helper.add_restored_face(output)
87 |
88 | face_helper.get_inverse_affine(None)
89 | restored_img = face_helper.paste_faces_to_input_image(upsample_img=None)
90 |
91 | return restored_img
92 |
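A hedged usage sketch; `init_model` is the upstream MMEditing helper for loading a model from a config and checkpoint (not shown in this tree), and the paths are placeholders:

```python
from mmedit.apis import init_model  # assumed upstream helper

model = init_model('configs/some_face_config.py', 'checkpoint.pth',
                   device='cuda:0')
restored = restoration_face_inference(
    model, 'degraded_face.jpg', upscale_factor=2, face_size=1024)
```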
--------------------------------------------------------------------------------
/mmedit/core/hooks/visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | import mmcv
5 | import torch
6 | from mmcv.runner import HOOKS, Hook
7 | from mmcv.runner.dist_utils import master_only
8 | from torchvision.utils import save_image
9 |
10 | from mmedit.utils import deprecated_function
11 |
12 |
13 | @HOOKS.register_module()
14 | class MMEditVisualizationHook(Hook):
15 | """Visualization hook.
16 |
17 | In this hook, we use the official api `save_image` in torchvision to save
18 | the visualization results.
19 |
20 | Args:
21 | output_dir (str): The file path to store visualizations.
22 | res_name_list (str): The list contains the name of results in outputs
23 | dict. The results in outputs dict must be a torch.Tensor with shape
24 | (n, c, h, w).
25 | interval (int): The interval of calling this hook. If set to -1,
26 | the visualization hook will not be called. Default: -1.
27 |         filename_tmpl (str): Format string for the saved file names; it is
28 |             formatted with the iteration number. Default: 'iter_{}.png'.
29 |         rerange (bool): Whether to rerange the output value from [-1, 1] to
30 |             [0, 1]. We recommend that users preprocess the visualization
31 |             results on their own; this is only a simple interface.
32 |             Default: True.
33 | bgr2rgb (bool): Whether to reformat the channel dimension from BGR to
34 | RGB. The final image we will save is following RGB style.
35 | Default: True.
36 | nrow (int): The number of samples in a row. Default: 1.
37 |         padding (int): The number of padding pixels between samples.
38 | Default: 4.
39 | """
40 |
41 | def __init__(self,
42 | output_dir,
43 | res_name_list,
44 | interval=-1,
45 | filename_tmpl='iter_{}.png',
46 | rerange=True,
47 | bgr2rgb=True,
48 | nrow=1,
49 | padding=4):
50 | assert mmcv.is_list_of(res_name_list, str)
51 | self.output_dir = output_dir
52 | self.res_name_list = res_name_list
53 | self.interval = interval
54 | self.filename_tmpl = filename_tmpl
55 | self.bgr2rgb = bgr2rgb
56 | self.rerange = rerange
57 | self.nrow = nrow
58 | self.padding = padding
59 |
60 | mmcv.mkdir_or_exist(self.output_dir)
61 |
62 | @master_only
63 | def after_train_iter(self, runner):
64 | """The behavior after each train iteration.
65 |
66 | Args:
67 | runner (object): The runner.
68 | """
69 | if not self.every_n_iters(runner, self.interval):
70 | return
71 | results = runner.outputs['results']
72 |
73 | filename = self.filename_tmpl.format(runner.iter + 1)
74 |
75 | img_list = [results[k] for k in self.res_name_list if k in results]
76 | img_cat = torch.cat(img_list, dim=3).detach()
77 | if self.rerange:
78 | img_cat = ((img_cat + 1) / 2)
79 | if self.bgr2rgb:
80 | img_cat = img_cat[:, [2, 1, 0], ...]
81 | img_cat = img_cat.clamp_(0, 1)
82 | save_image(
83 | img_cat,
84 | osp.join(self.output_dir, filename),
85 | nrow=self.nrow,
86 | padding=self.padding)
87 |
88 |
89 | @HOOKS.register_module()
90 | class VisualizationHook(MMEditVisualizationHook):
91 |
92 | @deprecated_function('0.16.0', '0.20.0', 'use \'MMEditVisualizationHook\'')
93 | def __init__(self, *args, **kwargs):
94 | super().__init__(*args, **kwargs)
95 |
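A sketch of how the hook might be configured; in MMEditing-style configs the `visual_config` field (set to `None` in the ViMPNet config later in this dump) carries these arguments, and the result keys here are illustrative:

```python
visual_config = dict(
    type='MMEditVisualizationHook',
    output_dir='work_dirs/visual',
    res_name_list=['lq', 'gt', 'output'],
    interval=1000,
    rerange=False,  # skip if outputs are already in [0, 1]
    bgr2rgb=True)
```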
--------------------------------------------------------------------------------
/mmedit/datasets/pipelines/normalization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import numpy as np
4 |
5 | from ..registry import PIPELINES
6 |
7 |
8 | @PIPELINES.register_module()
9 | class Normalize:
10 | """Normalize images with the given mean and std value.
11 |
12 |     Required keys are the keys in attribute "keys"; modified keys are the
13 |     keys in "keys", and the key 'img_norm_cfg' is added.
14 | It also supports normalizing a list of images.
15 |
16 | Args:
17 | keys (Sequence[str]): The images to be normalized.
18 | mean (np.ndarray): Mean values of different channels.
19 | std (np.ndarray): Std values of different channels.
20 | to_rgb (bool): Whether to convert channels from BGR to RGB.
21 | """
22 |
23 | def __init__(self, keys, mean, std, to_rgb=False, save_original=False):
24 | self.keys = keys
25 | self.mean = np.array(mean, dtype=np.float32)
26 | self.std = np.array(std, dtype=np.float32)
27 | self.to_rgb = to_rgb
28 | self.save_original = save_original
29 |
30 | def __call__(self, results):
31 | """Call function.
32 |
33 | Args:
34 | results (dict): A dict containing the necessary information and
35 | data for augmentation.
36 |
37 | Returns:
38 | dict: A dict containing the processed data and information.
39 | """
40 | for key in self.keys:
41 | if isinstance(results[key], list):
42 | if self.save_original:
43 | results[key + '_unnormalised'] = [
44 | v.copy() for v in results[key]
45 | ]
46 | results[key] = [
47 | mmcv.imnormalize(v, self.mean, self.std, self.to_rgb)
48 | for v in results[key]
49 | ]
50 | else:
51 | if self.save_original:
52 | results[key + '_unnormalised'] = results[key].copy()
53 | results[key] = mmcv.imnormalize(results[key], self.mean,
54 | self.std, self.to_rgb)
55 |
56 | results['img_norm_cfg'] = dict(
57 | mean=self.mean, std=self.std, to_rgb=self.to_rgb)
58 | return results
59 |
60 | def __repr__(self):
61 | repr_str = self.__class__.__name__
62 | repr_str += (f'(keys={self.keys}, mean={self.mean}, std={self.std}, '
63 | f'to_rgb={self.to_rgb})')
64 |
65 | return repr_str
66 |
67 |
68 | @PIPELINES.register_module()
69 | class RescaleToZeroOne:
70 | """Transform the images into a range between 0 and 1.
71 |
72 | Required keys are the keys in attribute "keys", added or modified keys are
73 | the keys in attribute "keys".
74 | It also supports rescaling a list of images.
75 |
76 | Args:
77 | keys (Sequence[str]): The images to be transformed.
78 | """
79 |
80 | def __init__(self, keys):
81 | self.keys = keys
82 |
83 | def __call__(self, results):
84 | """Call function.
85 |
86 | Args:
87 | results (dict): A dict containing the necessary information and
88 | data for augmentation.
89 |
90 | Returns:
91 | dict: A dict containing the processed data and information.
92 | """
93 | for key in self.keys:
94 | if isinstance(results[key], list):
95 | results[key] = [
96 | v.astype(np.float32) / 255. for v in results[key]
97 | ]
98 | else:
99 | results[key] = results[key].astype(np.float32) / 255.
100 | return results
101 |
102 | def __repr__(self):
103 | return self.__class__.__name__ + f'(keys={self.keys})'
104 |
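A small sketch of the two pipeline transforms above on a random image:

```python
import numpy as np

results = dict(lq=np.random.randint(0, 256, (16, 16, 3), dtype=np.uint8))
results = RescaleToZeroOne(keys=['lq'])(results)  # uint8 -> float32, [0, 1]

norm = Normalize(keys=['lq'], mean=[0.5] * 3, std=[0.5] * 3)
results = norm(results)
print(results['img_norm_cfg'])  # records mean/std/to_rgb for later use
```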
--------------------------------------------------------------------------------
/mmedit/datasets/base_sr_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | import os.path as osp
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | from mmcv import scandir
9 |
10 | from mmedit.core.registry import build_metric
11 | from .base_dataset import BaseDataset
12 |
13 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
14 | '.PPM', '.bmp', '.BMP', '.tif', '.TIF', '.tiff', '.TIFF')
15 | FEATURE_BASED_METRICS = ['FID', 'KID']
16 |
17 |
18 | class BaseSRDataset(BaseDataset):
19 | """Base class for super resolution datasets."""
20 |
21 | def __init__(self, pipeline, scale, test_mode=False):
22 | super().__init__(pipeline, test_mode)
23 | self.scale = scale
24 |
25 | @staticmethod
26 | def scan_folder(path):
27 | """Obtain image path list (including sub-folders) from a given folder.
28 |
29 | Args:
30 | path (str | :obj:`Path`): Folder path.
31 |
32 | Returns:
33 |             list[str]: Image list obtained from the given folder.
34 | """
35 |
36 | if isinstance(path, (str, Path)):
37 | path = str(path)
38 | else:
39 | raise TypeError("'path' must be a str or a Path object, "
40 | f'but received {type(path)}.')
41 |
42 | images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
43 | images = [osp.join(path, v) for v in images]
44 | assert images, f'{path} has no valid image file.'
45 | return images
46 |
47 | def __getitem__(self, idx):
48 | """Get item at each call.
49 |
50 | Args:
51 | idx (int): Index for getting each item.
52 | """
53 | results = copy.deepcopy(self.data_infos[idx])
54 | results['scale'] = self.scale
55 | return self.pipeline(results)
56 |
57 | def evaluate(self, results, logger=None):
58 | """Evaluate with different metrics.
59 |
60 | Args:
61 | results (list[tuple]): The output of forward_test() of the model.
62 |
63 |         Returns:
64 | dict: Evaluation results dict.
65 | """
66 | if not isinstance(results, list):
67 | raise TypeError(f'results must be a list, but got {type(results)}')
68 | assert len(results) == len(self), (
69 | 'The length of results is not equal to the dataset len: '
70 | f'{len(results)} != {len(self)}')
71 |
72 | results = [res['eval_result'] for res in results] # a list of dict
73 | eval_result = defaultdict(list) # a dict of list
74 |
75 | for res in results:
76 | for metric, val in res.items():
77 | eval_result[metric].append(val)
78 | for metric, val_list in eval_result.items():
79 | assert len(val_list) == len(self), (
80 | f'Length of evaluation result of {metric} is {len(val_list)}, '
81 | f'should be {len(self)}')
82 |
83 | # average the results
84 | eval_result.update({
85 | metric: sum(values) / len(self)
86 | for metric, values in eval_result.items()
87 | if metric not in ['_inception_feat'] + FEATURE_BASED_METRICS
88 | })
89 |
90 | # evaluate feature-based metrics
91 | if '_inception_feat' in eval_result:
92 | feat1, feat2 = [], []
93 | for f1, f2 in eval_result['_inception_feat']:
94 | feat1.append(f1)
95 | feat2.append(f2)
96 | feat1 = np.concatenate(feat1, 0)
97 | feat2 = np.concatenate(feat2, 0)
98 |
99 | for metric in FEATURE_BASED_METRICS:
100 | if metric in eval_result:
101 | metric_func = build_metric(eval_result[metric].pop())
102 | eval_result[metric] = metric_func(feat1, feat2)
103 |
104 | # delete a redundant key for clean logging
105 | del eval_result['_inception_feat']
106 |
107 | return eval_result
108 |
--------------------------------------------------------------------------------
/tools/deployment/mmedit2torchserve.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from argparse import ArgumentParser, Namespace
3 | from pathlib import Path
4 | from tempfile import TemporaryDirectory
5 |
6 | import mmcv
7 |
8 | try:
9 | from model_archiver.model_packaging import package_model
10 | from model_archiver.model_packaging_utils import ModelExportUtils
11 | except ImportError:
12 | package_model = None
13 |
14 |
15 | def mmedit2torchserve(
16 | config_file: str,
17 | checkpoint_file: str,
18 | output_folder: str,
19 | model_name: str,
20 | model_version: str = '1.0',
21 | force: bool = False,
22 | ):
23 | """Converts MMEditing model (config + checkpoint) to TorchServe `.mar`.
24 |
25 | Args:
26 | config_file:
27 | In MMEditing config format.
28 | The contents vary for each task repository.
29 | checkpoint_file:
30 | In MMEditing checkpoint format.
31 | The contents vary for each task repository.
32 | output_folder:
33 | Folder where `{model_name}.mar` will be created.
34 | The file created will be in TorchServe archive format.
35 | model_name:
36 | If not None, used for naming the `{model_name}.mar` file
37 | that will be created under `output_folder`.
38 | If None, `{Path(checkpoint_file).stem}` will be used.
39 | model_version:
40 | Model's version.
41 | force:
42 | If True, if there is an existing `{model_name}.mar`
43 | file under `output_folder` it will be overwritten.
44 | """
45 | mmcv.mkdir_or_exist(output_folder)
46 |
47 | config = mmcv.Config.fromfile(config_file)
48 |
49 | with TemporaryDirectory() as tmpdir:
50 | config.dump(f'{tmpdir}/config.py')
51 |
52 | args_ = Namespace(
53 | **{
54 | 'model_file': f'{tmpdir}/config.py',
55 | 'serialized_file': checkpoint_file,
56 | 'handler': f'{Path(__file__).parent}/mmedit_handler.py',
57 | 'model_name': model_name or Path(checkpoint_file).stem,
58 | 'version': model_version,
59 | 'export_path': output_folder,
60 | 'force': force,
61 | 'requirements_file': None,
62 | 'extra_files': None,
63 | 'runtime': 'python',
64 | 'archive_format': 'default'
65 | })
66 |         print(f'Packaging model: {args_.model_name}')
67 | manifest = ModelExportUtils.generate_manifest_json(args_)
68 | package_model(args_, manifest)
69 |
70 |
71 | def parse_args():
72 | parser = ArgumentParser(
73 | description='Convert MMEditing models to TorchServe `.mar` format.')
74 | parser.add_argument('config', type=str, help='config file path')
75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path')
76 | parser.add_argument(
77 | '--output-folder',
78 | type=str,
79 | required=True,
80 | help='Folder where `{model_name}.mar` will be created.')
81 | parser.add_argument(
82 | '--model-name',
83 | type=str,
84 | default=None,
85 | help='If not None, used for naming the `{model_name}.mar`'
86 | 'file that will be created under `output_folder`.'
87 | 'If None, `{Path(checkpoint_file).stem}` will be used.')
88 | parser.add_argument(
89 | '--model-version',
90 | type=str,
91 | default='1.0',
92 | help='Number used for versioning.')
93 | parser.add_argument(
94 | '-f',
95 | '--force',
96 | action='store_true',
97 | help='overwrite the existing `{model_name}.mar`')
98 | args_ = parser.parse_args()
99 |
100 | return args_
101 |
102 |
103 | if __name__ == '__main__':
104 | args = parse_args()
105 |
106 | if package_model is None:
107 |         raise ImportError('`torch-model-archiver` is required. '
108 |                           'Try: pip install torch-model-archiver')
109 |
110 | mmedit2torchserve(args.config, args.checkpoint, args.output_folder,
111 | args.model_name, args.model_version, args.force)
112 |
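A hedged sketch of calling the converter directly from Python; all paths and the model name are placeholders:

```python
mmedit2torchserve(
    config_file='configs/derainers/ViMPNet/ViMPNet.py',
    checkpoint_file='work_dirs/model.pth',
    output_folder='./torchserve',
    model_name='vimpnet',
    model_version='1.0',
    force=True)
```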
--------------------------------------------------------------------------------
/mmedit/models/losses/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import functools
3 |
4 | import torch.nn.functional as F
5 |
6 |
7 | def reduce_loss(loss, reduction):
8 | """Reduce loss as specified.
9 |
10 | Args:
11 | loss (Tensor): Elementwise loss tensor.
12 | reduction (str): Options are "none", "mean" and "sum".
13 |
14 | Returns:
15 | Tensor: Reduced loss tensor.
16 | """
17 | reduction_enum = F._Reduction.get_enum(reduction)
18 | # none: 0, elementwise_mean:1, sum: 2
19 | if reduction_enum == 0:
20 | return loss
21 | if reduction_enum == 1:
22 | return loss.mean()
23 |
24 | return loss.sum()
25 |
26 |
27 | def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False):
28 | """Apply element-wise weight and reduce loss.
29 |
30 | Args:
31 | loss (Tensor): Element-wise loss.
32 | weight (Tensor): Element-wise weights. Default: None.
33 | reduction (str): Same as built-in losses of PyTorch. Options are
34 | "none", "mean" and "sum". Default: 'mean'.
35 | sample_wise (bool): Whether calculate the loss sample-wise. This
36 | argument only takes effect when `reduction` is 'mean' and `weight`
37 | (argument of `forward()`) is not None. It will first reduces loss
38 | with 'mean' per-sample, and then it means over all the samples.
39 | Default: False.
40 |
41 | Returns:
42 | Tensor: Processed loss values.
43 | """
44 | # if weight is specified, apply element-wise weight
45 | if weight is not None:
46 | assert weight.dim() == loss.dim()
47 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
48 | loss = loss * weight
49 |
50 | # if weight is not specified or reduction is sum, just reduce the loss
51 | if weight is None or reduction == 'sum':
52 | loss = reduce_loss(loss, reduction)
53 | # if reduction is mean, then compute mean over masked region
54 | elif reduction == 'mean':
55 | # expand weight from N1HW to NCHW
56 | if weight.size(1) == 1:
57 | weight = weight.expand_as(loss)
58 | # small value to prevent division by zero
59 | eps = 1e-12
60 |
61 | # perform sample-wise mean
62 | if sample_wise:
63 | weight = weight.sum(dim=[1, 2, 3], keepdim=True) # NCHW to N111
64 | loss = (loss / (weight + eps)).sum() / weight.size(0)
65 | # perform pixel-wise mean
66 | else:
67 | loss = loss.sum() / (weight.sum() + eps)
68 |
69 | return loss
70 |
71 |
72 | def masked_loss(loss_func):
73 | """Create a masked version of a given loss function.
74 |
75 | To use this decorator, the loss function must have the signature like
76 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
77 | element-wise loss without any reduction. This decorator will add weight
78 | and reduction arguments to the function. The decorated function will have
79 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
80 | avg_factor=None, **kwargs)`.
81 |
82 | :Example:
83 |
84 | >>> import torch
85 | >>> @masked_loss
86 | >>> def l1_loss(pred, target):
87 | >>> return (pred - target).abs()
88 |
89 | >>> pred = torch.Tensor([0, 2, 3])
90 | >>> target = torch.Tensor([1, 1, 1])
91 | >>> weight = torch.Tensor([1, 0, 1])
92 |
93 | >>> l1_loss(pred, target)
94 | tensor(1.3333)
95 | >>> l1_loss(pred, target, weight)
96 | tensor(1.5000)
97 | >>> l1_loss(pred, target, reduction='none')
98 | tensor([1., 1., 2.])
99 | >>> l1_loss(pred, target, weight, reduction='sum')
100 | tensor(3.)
101 | """
102 |
103 | @functools.wraps(loss_func)
104 | def wrapper(pred,
105 | target,
106 | weight=None,
107 | reduction='mean',
108 | sample_wise=False,
109 | **kwargs):
110 | # get element-wise loss
111 | loss = loss_func(pred, target, **kwargs)
112 | loss = mask_reduce_loss(loss, weight, reduction, sample_wise)
113 | return loss
114 |
115 | return wrapper
116 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos (ACM MM 2023)
2 | Hongtao Wu, Yijun Yang, Haoyu Chen, Jingjing Ren, Lei Zhu
3 |
4 | This repo is the official PyTorch implementation of [Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos](https://dl.acm.org/doi/abs/10.1145/3581783.3612001).
5 | It provides the first dataset and approach for joint rain streak and raindrop removal in videos.
6 |
7 |
8 |
9 |
10 | > **Abstract:** *Videos captured in rainy weather are unavoidably corrupted by both rain streaks and raindrops in driving scenarios, and it is desirable yet challenging to recover background details obscured by rain streaks and raindrops. However, existing video rain removal methods often address either video rain streak removal or video raindrop removal, and thereby suffer from degraded performance when dealing with both simultaneously. The bottleneck is the lack of a video dataset in which each video frame contains both rain streaks and raindrops. To address this issue, in this work we generate a synthesized dataset, namely VRDS, with 102 rainy videos from diverse scenarios, where each video frame has the corresponding rain streak map, raindrop mask, and the underlying rain-free clean image (ground truth). Moreover, we devise a mask-guided progressive video deraining network (ViMP-Net) to remove both rain streaks and raindrops from each video frame. Specifically, we develop an intensity-guided alignment block to predict the rain streak intensity map and remove the rain streaks of the input rainy video at the first stage. Then, we predict a raindrop mask and pass it into a devised mask-guided dual transformer block to learn inter-frame and intra-frame transformer features, which are then fed into a decoder for further eliminating raindrops. Experimental results demonstrate that our ViMP-Net outperforms state-of-the-art methods on our synthetic dataset and real-world rainy videos.*
11 |
12 |
13 |
14 | ## Our Dataset
15 | Our VRDS dataset can be downloaded [here](https://hkustgz-my.sharepoint.com/:f:/g/personal/hwu375_connect_hkust-gz_edu_cn/EmI_nfrnMyNAohEwNtnq50MB22RWxp-x_mtp264aVzOxlA?e=CjP3kO).
16 |
17 | ## Installation
18 |
19 | This implementation is based on [MMEditing](https://github.com/open-mmlab/mmediting),
20 | which is an open-source image and video editing toolbox.
21 |
22 |
23 | Below are quick steps for installation.
24 |
25 | **Step 1.**
26 | Install PyTorch following [official instructions](https://pytorch.org/get-started/locally/).
27 |
28 | **Step 2.**
29 | Install MMCV with [MIM](https://github.com/open-mmlab/mim).
30 |
31 | ```shell
32 | pip3 install openmim
33 | mim install mmcv-full
34 | ```
35 |
36 | **Step 3.**
37 | Install ViMP-Net from source.
38 |
39 | ```shell
40 | git clone https://github.com/TonyHongtaoWu/ViMP-Net.git
41 | cd ViMP-Net
42 | pip3 install -e .
43 | ```
44 |
45 | Please refer to [MMEditing Installation](https://github.com/open-mmlab/mmediting/blob/master/docs/en/install.md) for more detailed instruction.
46 |
47 |
48 | ## Training and Testing
49 | You may need to adjust the dataset path and dataloader before starting.
50 |
51 | You can train ViMP-Net on the VRDS dataset with 4 GPUs using the command below:
52 |
53 | ```shell
54 | CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/derainers/ViMPNet/ViMPNet.py 4
55 | ```
56 |
57 | You can use the following command with 1 GPU to test your trained model `model.pth`:
58 |
59 | ```shell
60 | CUDA_VISIBLE_DEVICES=0 ./tools/dist_test.sh configs/derainers/ViMPNet/ViMPNet.py "model.pth" 1 --save-path './save_path/'
61 | ```
62 |
63 | You can find a model checkpoint trained on the VRDS dataset [here](https://drive.google.com/drive/folders/1Iu_sxlN3nUpi99QUxWAnRP1a0mNNm2JU?usp=sharing).
64 |
65 |
66 |
67 | ## Our Results
68 | The visual results of ViMP-Net can be downloaded from [Google Drive](https://drive.google.com/file/d/1yEFbQbh45hWOu2g4HR9-SUvZZpyJJd7l/view?usp=sharing) and [OneDrive](https://hkustgz-my.sharepoint.com/:u:/g/personal/hwu375_connect_hkust-gz_edu_cn/EVM_XI3KcE9DgQaE9hbXvLQBjhnMP0rvQnSVcnOFcsMyTA?e=7tE2Kk).
69 |
70 |
71 | ## Contact
72 | Should you have any questions or suggestions, please contact hwu375@connect.hkust-gz.edu.cn.
73 |
74 | ## Acknowledgement
75 | This code is based on [MMEditing](https://github.com/open-mmlab/mmagic) and [FuseFormer](https://github.com/ruiliu-ai/FuseFormer).
76 |
77 | ## Citation
78 | If you find this repository helpful to your research, please consider citing the following:
79 | ```
80 | @inproceedings{wu2023mask,
81 | title={Mask-Guided Progressive Network for Joint Raindrop and Rain Streak Removal in Videos},
82 | author={Wu, Hongtao and Yang, Yijun and Chen, Haoyu and Ren, Jingjing and Zhu, Lei},
83 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
84 | pages={7216--7225},
85 | year={2023}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/mmedit/core/evaluation/eval_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | from mmcv.runner import Hook
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class EvalIterHook(Hook):
9 | """Non-Distributed evaluation hook for iteration-based runner.
10 |
11 | This hook will regularly perform evaluation in a given interval when
12 | performing in non-distributed environment.
13 |
14 | Args:
15 | dataloader (DataLoader): A PyTorch dataloader.
16 | interval (int): Evaluation interval. Default: 1.
17 | eval_kwargs (dict): Other eval kwargs. It contains:
18 |             save_image (bool): Whether to save images.
19 | save_path (str): The path to save image.
20 | """
21 |
22 | def __init__(self, dataloader, interval=1, **eval_kwargs):
23 | if not isinstance(dataloader, DataLoader):
24 |             raise TypeError('dataloader must be a pytorch DataLoader, '
25 |                             f'but got {type(dataloader)}')
26 | self.dataloader = dataloader
27 | self.interval = interval
28 | self.eval_kwargs = eval_kwargs
29 | self.save_image = self.eval_kwargs.pop('save_image', False)
30 | self.save_path = self.eval_kwargs.pop('save_path', None)
31 |
32 | def after_train_iter(self, runner):
33 | """The behavior after each train iteration.
34 |
35 | Args:
36 | runner (``mmcv.runner.BaseRunner``): The runner.
37 | """
38 | if not self.every_n_iters(runner, self.interval):
39 | return
40 | runner.log_buffer.clear()
41 | from mmedit.apis import single_gpu_test
42 | results = single_gpu_test(
43 | runner.model,
44 | self.dataloader,
45 | save_image=self.save_image,
46 | save_path=self.save_path,
47 | iteration=runner.iter)
48 | self.evaluate(runner, results)
49 |
50 | def evaluate(self, runner, results):
51 | """Evaluation function.
52 |
53 | Args:
54 | runner (``mmcv.runner.BaseRunner``): The runner.
55 |             results (list): Model forward results.
56 | """
57 | eval_res = self.dataloader.dataset.evaluate(
58 | results, logger=runner.logger, **self.eval_kwargs)
59 | for name, val in eval_res.items():
60 | if isinstance(val, dict):
61 | runner.log_buffer.output.update(val)
62 | continue
63 | runner.log_buffer.output[name] = val
64 | runner.log_buffer.ready = True
65 |         # Call `after_val_epoch` after evaluation.
66 |         # This is a hack: epochs do not naturally exist in
67 |         # IterBasedRunner, so we treat the end of an evaluation as the
68 |         # end of an epoch. With this hack, we can support epoch-based
69 |         # hooks.
70 | if 'iter' in runner.__class__.__name__.lower():
71 | runner.call_hook('after_val_epoch')
72 |
73 |
74 | class DistEvalIterHook(EvalIterHook):
75 | """Distributed evaluation hook.
76 |
77 | Args:
78 | dataloader (DataLoader): A PyTorch dataloader.
79 | interval (int): Evaluation interval. Default: 1.
80 | tmpdir (str | None): Temporary directory to save the results of all
81 | processes. Default: None.
82 | gpu_collect (bool): Whether to use gpu or cpu to collect results.
83 | Default: False.
84 | eval_kwargs (dict): Other eval kwargs. It may contain:
85 |             save_image (bool): Whether to save images.
86 | save_path (str): The path to save image.
87 | """
88 |
89 | def __init__(self,
90 | dataloader,
91 | interval=1,
92 | gpu_collect=False,
93 | **eval_kwargs):
94 | super().__init__(dataloader, interval, **eval_kwargs)
95 | self.gpu_collect = gpu_collect
96 |
97 | def after_train_iter(self, runner):
98 | """The behavior after each train iteration.
99 |
100 | Args:
101 | runner (``mmcv.runner.BaseRunner``): The runner.
102 | """
103 | if not self.every_n_iters(runner, self.interval):
104 | return
105 | runner.log_buffer.clear()
106 | from mmedit.apis import multi_gpu_test
107 | results = multi_gpu_test(
108 | runner.model,
109 | self.dataloader,
110 | tmpdir=osp.join(runner.work_dir, '.eval_hook'),
111 | gpu_collect=self.gpu_collect,
112 | save_image=self.save_image,
113 | save_path=self.save_path,
114 | iteration=runner.iter)
115 | if runner.rank == 0:
116 | print('\n')
117 | self.evaluate(runner, results)
118 |
--------------------------------------------------------------------------------
/configs/derainers/ViMPNet/ViMPNet.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 | type='DrainNet',
4 | generator=dict(
5 | type='ViMPNet',
6 | mid_channels=64,
7 | num_blocks=9,
8 | spynet_pretrained='https://download.openmmlab.com/mmediting/restorers/'
9 | 'basicvsr/spynet_20210409-c6c1bd09.pth'),
10 | pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
11 | )
12 | # model training and testing settings
13 | train_cfg = dict(fix_iter=5000)
14 | test_cfg = dict(metrics=['PSNR'], crop_border=0)
15 |
16 | train_dataset_type = 'SRFolderMultipleGTDataset'
17 | val_dataset_type = 'SRFolderMultipleGTDataset'
18 |
19 | train_pipeline = [
20 | dict(
21 | type='GenerateSegmentIndices',
22 | interval_list=[1],
23 |         filename_tmpl='{:08d}.png',
24 | start_idx=0),
25 | dict(
26 | type='LoadImageFromFileList',
27 | io_backend='disk',
28 | key='lq',
29 | channel_order='rgb'),
30 | dict(
31 | type='LoadImageFromFileList',
32 | io_backend='disk',
33 | key='gt',
34 | channel_order='rgb'),
35 | dict(
36 | type='LoadImageFromFileList',
37 | io_backend='disk',
38 | key='mask',
39 | flag='grayscale',
40 | channel_order='rgb'),
41 | dict(
42 | type='LoadImageFromFileList',
43 | io_backend='disk',
44 | key='drop',
45 | channel_order='rgb'),
46 | dict(
47 | type='LoadImageFromFileList',
48 | io_backend='disk',
49 | key='streak',
50 | flag='grayscale',
51 | channel_order='rgb'),
52 |
53 | dict(type='RescaleToZeroOne', keys=['lq', 'gt', 'mask', 'drop', 'streak']),
54 | dict(type='PairedRandomCrop', gt_patch_size=256),
55 | dict(
56 | type='Flip', keys=['lq', 'gt', 'mask', 'drop', 'streak'], flip_ratio=0.5,
57 | direction='horizontal'),
58 | dict(type='Flip', keys=['lq', 'gt', 'mask', 'drop', 'streak'], flip_ratio=0.5, direction='vertical'),
59 | dict(type='RandomTransposeHW', keys=['lq', 'gt', 'mask', 'drop', 'streak'], transpose_ratio=0.5),
60 | dict(type='FramesToTensor', keys=['lq', 'gt', 'mask', 'drop', 'streak']),
61 | dict(type='Collect', keys=['lq', 'gt', 'mask', 'drop', 'streak'], meta_keys=['lq_path', 'gt_path', 'mask_path', 'drop_path', 'streak_path'])
62 | ]
63 |
64 | test_pipeline = [
65 | dict(type='GenerateSegmentIndices', interval_list=[1], filename_tmpl='{:08d}.png'),
66 | dict(
67 | type='LoadImageFromFileList',
68 | io_backend='disk',
69 | key='lq',
70 | channel_order='rgb'),
71 | dict(
72 | type='LoadImageFromFileList',
73 | io_backend='disk',
74 | key='gt',
75 | channel_order='rgb'),
76 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
77 | dict(type='FramesToTensor', keys=['lq', 'gt']),
78 | dict(
79 | type='Collect',
80 | keys=['lq', 'gt'],
81 | meta_keys=['lq_path', 'gt_path', 'key'])
82 | ]
83 |
84 | data = dict(
85 | workers_per_gpu=5,
86 | train_dataloader=dict(samples_per_gpu=1, drop_last=True),
87 | val_dataloader=dict(samples_per_gpu=1),
88 | test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=1),
89 |
90 | # train
91 | train=dict(
92 | type='RepeatDataset',
93 | times=1000,
94 | dataset=dict(
95 | type=train_dataset_type,
96 | lq_folder="../data/lq",
97 | gt_folder="../data/gt",
98 | mask_folder="../data/01drop",
99 | drop_folder="../data/drop",
100 | streak_folder="../data/mstreak",
101 | pipeline=train_pipeline,
102 | scale=1,
103 | num_input_frames=5,
104 | test_mode=False)),
105 |
106 | # # test
107 | test=dict(
108 | type=val_dataset_type,
109 | lq_folder='../data/testlq',
110 | gt_folder='../data/testgt',
111 | pipeline=test_pipeline,
112 | scale=1,
113 | test_mode=True),
114 | )
115 |
116 | # optimizer
117 | optimizers = dict(
118 | generator=dict(
119 | type='Adam',
120 | lr=1e-4,
121 | betas=(0.9, 0.99),
122 | paramwise_cfg=dict(custom_keys={'spynet': dict(lr_mult=0.25)})))
123 |
124 | # learning policy
125 | total_iters = 500000
126 | lr_config = dict(
127 | policy='CosineRestart',
128 | by_epoch=False,
129 | periods=[500000],
130 | restart_weights=[1],
131 | min_lr=1e-9)
132 |
133 | checkpoint_config = dict(interval=5000, save_optimizer=True, by_epoch=False)
134 |
135 | log_config = dict(
136 | interval=1000,
137 | hooks=[
138 | dict(type='TextLoggerHook', by_epoch=False),
139 | ])
140 | visual_config = None
141 |
142 | # runtime settings
143 | dist_params = dict(backend='nccl')
144 | log_level = 'INFO'
145 |
146 | work_dir = '../ViMPNettrain/'
147 | load_from = None
148 | resume_from = None
149 | workflow = [('train', 1)]
150 | find_unused_parameters = True
151 |
--------------------------------------------------------------------------------
/mmedit/apis/restoration_video_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import glob
3 | import os.path as osp
4 |
5 | import mmcv
6 | import numpy as np
7 | import torch
8 |
9 | from mmedit.datasets.pipelines import Compose
10 |
11 | VIDEO_EXTENSIONS = ('.mp4', '.mov')
12 |
13 |
14 | def pad_sequence(data, window_size):
15 | padding = window_size // 2
16 |
17 | data = torch.cat([
18 | data[:, 1 + padding:1 + 2 * padding].flip(1), data,
19 | data[:, -1 - 2 * padding:-1 - padding].flip(1)
20 | ],
21 | dim=1)
22 |
23 | return data
24 |
25 |
26 | def restoration_video_inference(model,
27 | img_dir,
28 | window_size,
29 | start_idx,
30 | filename_tmpl,
31 | max_seq_len=None):
32 | """Inference image with the model.
33 |
34 | Args:
35 | model (nn.Module): The loaded model.
36 | img_dir (str): Directory of the input video.
37 | window_size (int): The window size used in sliding-window framework.
38 | This value should be set according to the settings of the network.
39 | A value smaller than 0 means using recurrent framework.
40 | start_idx (int): The index corresponds to the first frame in the
41 | sequence.
42 | filename_tmpl (str): Template for file name.
43 | max_seq_len (int | None): The maximum sequence length that the model
44 | processes. If the sequence length is larger than this number,
45 | the sequence is split into multiple segments. If it is None,
46 | the entire sequence is processed at once.
47 |
48 | Returns:
49 | Tensor: The predicted restoration result.
50 | """
51 |
52 | device = next(model.parameters()).device # model device
53 |
54 | # build the data pipeline
55 | if model.cfg.get('demo_pipeline', None):
56 | test_pipeline = model.cfg.demo_pipeline
57 | elif model.cfg.get('test_pipeline', None):
58 | test_pipeline = model.cfg.test_pipeline
59 | else:
60 | test_pipeline = model.cfg.val_pipeline
61 |
62 | # check if the input is a video
63 | file_extension = osp.splitext(img_dir)[1]
64 | if file_extension in VIDEO_EXTENSIONS:
65 | video_reader = mmcv.VideoReader(img_dir)
66 | # load the images
67 | data = dict(lq=[], lq_path=None, key=img_dir)
68 | for frame in video_reader:
69 |         data['lq'].append(np.flip(frame, axis=2))  # BGR -> RGB
70 |
71 | # remove the data loading pipeline
72 | tmp_pipeline = []
73 | for pipeline in test_pipeline:
74 | if pipeline['type'] not in [
75 | 'GenerateSegmentIndices', 'LoadImageFromFileList'
76 | ]:
77 | tmp_pipeline.append(pipeline)
78 | test_pipeline = tmp_pipeline
79 | else:
80 | # the first element in the pipeline must be 'GenerateSegmentIndices'
81 | if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
82 | raise TypeError('The first element in the pipeline must be '
83 | f'"GenerateSegmentIndices", but got '
84 | f'"{test_pipeline[0]["type"]}".')
85 |
86 | # specify start_idx and filename_tmpl
87 | test_pipeline[0]['start_idx'] = start_idx
88 | test_pipeline[0]['filename_tmpl'] = filename_tmpl
89 |
90 | # prepare data
91 | sequence_length = len(glob.glob(osp.join(img_dir, '*')))
92 | lq_folder = osp.dirname(img_dir)
93 | key = osp.basename(img_dir)
94 | data = dict(
95 | lq_path=lq_folder,
96 | gt_path='',
97 | key=key,
98 | sequence_length=sequence_length)
99 |
100 | # compose the pipeline
101 | test_pipeline = Compose(test_pipeline)
102 | data = test_pipeline(data)
103 | data = data['lq'].unsqueeze(0) # in cpu
104 |
105 | # forward the model
106 | with torch.no_grad():
107 | if window_size > 0: # sliding window framework
108 | data = pad_sequence(data, window_size)
109 | result = []
110 | for i in range(0, data.size(1) - 2 * (window_size // 2)):
111 | data_i = data[:, i:i + window_size].to(device)
112 | result.append(model(lq=data_i, test_mode=True)['output'].cpu())
113 | result = torch.stack(result, dim=1)
114 | else: # recurrent framework
115 | if max_seq_len is None:
116 | result = model(
117 | lq=data.to(device), test_mode=True)['output'].cpu()
118 | else:
119 | result = []
120 | for i in range(0, data.size(1), max_seq_len):
121 | result.append(
122 | model(
123 | lq=data[:, i:i + max_seq_len].to(device),
124 | test_mode=True)['output'].cpu())
125 | result = torch.cat(result, dim=1)
126 | return result
127 |
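A hedged usage sketch; `model` is an already-loaded restorer, and the frame folder plus template are placeholders matching the `GenerateSegmentIndices` step of the test pipeline:

```python
output = restoration_video_inference(
    model, './demo/rainy_clip', window_size=-1,
    start_idx=0, filename_tmpl='{:08d}.png', max_seq_len=20)
print(output.shape)  # typically (1, T, C, H, W) for recurrent models
```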
--------------------------------------------------------------------------------
/mmedit/core/hooks/ema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 | from copy import deepcopy
4 | from functools import partial
5 |
6 | import mmcv
7 | import torch
8 | from mmcv.parallel import is_module_wrapper
9 | from mmcv.runner import HOOKS, Hook
10 |
11 |
12 | @HOOKS.register_module()
13 | class ExponentialMovingAverageHook(Hook):
14 | """Exponential Moving Average Hook.
15 |
16 |     Exponential moving average is a trick widely used in the recent GAN
17 |     literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to
18 |     maintain a model with the same architecture whose parameters are
19 |     updated as a moving average of the trained weights of the original model.
20 | In general, the model with moving averaged weights achieves better
21 | performance.
22 |
23 | Args:
24 |         module_keys (str | tuple[str]): The name(s) of the ema model(s). Note
25 |             that these keys must end with '_ema' so that the original model
26 |             can be found by discarding the last four characters.
27 | interp_mode (str, optional): Mode of the interpolation method.
28 | Defaults to 'lerp'.
29 | interp_cfg (dict | None, optional): Set arguments of the interpolation
30 | function. Defaults to None.
31 | interval (int, optional): Evaluation interval (by iterations).
32 | Default: -1.
33 | start_iter (int, optional): Start iteration for ema. If the start
34 | iteration is not reached, the weights of ema model will maintain
35 | the same as the original one. Otherwise, its parameters are updated
36 | as a moving average of the trained weights in the original model.
37 | Default: 0.
38 | """
39 |
40 | def __init__(self,
41 | module_keys,
42 | interp_mode='lerp',
43 | interp_cfg=None,
44 | interval=-1,
45 | start_iter=0):
46 | super().__init__()
47 | assert isinstance(module_keys, str) or mmcv.is_tuple_of(
48 | module_keys, str)
49 | self.module_keys = (module_keys, ) if isinstance(module_keys,
50 | str) else module_keys
51 | # sanity check for the format of module keys
52 | for k in self.module_keys:
53 | assert k.endswith(
54 | '_ema'), 'You should give keys that end with "_ema".'
55 | self.interp_mode = interp_mode
56 | self.interp_cfg = dict() if interp_cfg is None else deepcopy(
57 | interp_cfg)
58 | self.interval = interval
59 | self.start_iter = start_iter
60 |
61 | assert hasattr(
62 | self, interp_mode
63 | ), f'Currently, we do not support {self.interp_mode} for EMA.'
64 | self.interp_func = partial(
65 | getattr(self, interp_mode), **self.interp_cfg)
66 |
67 | @staticmethod
68 | def lerp(a, b, momentum=0.999, momentum_nontrainable=0., trainable=True):
69 | m = momentum if trainable else momentum_nontrainable
70 | return a + (b - a) * m
71 |
72 | def every_n_iters(self, runner, n):
73 | if runner.iter < self.start_iter:
74 | return True
75 | return (runner.iter + 1 - self.start_iter) % n == 0 if n > 0 else False
76 |
77 | @torch.no_grad()
78 | def after_train_iter(self, runner):
79 | if not self.every_n_iters(runner, self.interval):
80 | return
81 |
82 | model = runner.model.module if is_module_wrapper(
83 | runner.model) else runner.model
84 |
85 | for key in self.module_keys:
86 | # get current ema states
87 | ema_net = getattr(model, key)
88 | states_ema = ema_net.state_dict(keep_vars=False)
89 | # get currently original states
90 | net = getattr(model, key[:-4])
91 | states_orig = net.state_dict(keep_vars=True)
92 |
93 | for k, v in states_orig.items():
94 | if runner.iter < self.start_iter:
95 | states_ema[k].data.copy_(v.data)
96 | else:
97 | states_ema[k] = self.interp_func(
98 | v, states_ema[k], trainable=v.requires_grad).detach()
99 | ema_net.load_state_dict(states_ema, strict=True)
100 |
101 | def before_run(self, runner):
102 | model = runner.model.module if is_module_wrapper(
103 | runner.model) else runner.model
104 | # sanity check for ema model
105 | for k in self.module_keys:
106 | if not hasattr(model, k) and not hasattr(model, k[:-4]):
107 | raise RuntimeError(
108 |                     f'Cannot find either {k[:-4]} or {k} network for EMA hook.')
109 | if not hasattr(model, k) and hasattr(model, k[:-4]):
110 | setattr(model, k, deepcopy(getattr(model, k[:-4])))
111 | warnings.warn(
112 |                     f'We do not suggest constructing and initializing EMA '
113 |                     f'model {k} in the hook. You may define it explicitly.')
114 |
--------------------------------------------------------------------------------
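
A minimal usage sketch for the hook above, assuming a model that exposes both
`generator` and `generator_ema` sub-modules (the `_ema` suffix is what lets the
hook recover the original module name), registered through a standard mmcv
`custom_hooks` config entry:

    # Hypothetical config snippet: after each training iteration,
    # generator_ema <- lerp(generator, generator_ema, momentum=0.999),
    # i.e. the EMA weights keep 99.9% of their previous value.
    custom_hooks = [
        dict(
            type='ExponentialMovingAverageHook',
            module_keys=('generator_ema', ),
            interp_mode='lerp',
            interp_cfg=dict(momentum=0.999),
            interval=1,
            start_iter=0)
    ]
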
/mmedit/datasets/pipelines/random_down_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | from mmcv import imresize
7 |
8 | from ..registry import PIPELINES
9 |
10 |
11 | @PIPELINES.register_module()
12 | class RandomDownSampling:
13 | """Generate LQ image from GT (and crop), which will randomly pick a scale.
14 |
15 | Args:
16 | scale_min (float): The minimum of upsampling scale, inclusive.
17 | Default: 1.0.
18 | scale_max (float): The maximum of upsampling scale, exclusive.
19 | Default: 4.0.
20 | patch_size (int): The cropped lr patch size.
21 |             Default: None, which means no crop.
22 | interpolation (str): Interpolation method, accepted values are
23 | "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
24 | backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
25 | "hamming" for 'pillow' backend.
26 | Default: "bicubic".
27 | backend (str | None): The image resize backend type. Options are `cv2`,
28 | `pillow`, `None`. If backend is None, the global imread_backend
29 | specified by ``mmcv.use_backend()`` will be used.
30 | Default: "pillow".
31 |
32 | Scale will be picked in the range of [scale_min, scale_max).
33 | """
34 |
35 | def __init__(self,
36 | scale_min=1.0,
37 | scale_max=4.0,
38 | patch_size=None,
39 | interpolation='bicubic',
40 | backend='pillow'):
41 | assert scale_max >= scale_min
42 | self.scale_min = scale_min
43 | self.scale_max = scale_max
44 | self.patch_size = patch_size
45 | self.interpolation = interpolation
46 | self.backend = backend
47 |
48 | def __call__(self, results):
49 | """Call function.
50 |
51 | Args:
52 | results (dict): A dict containing the necessary information and
53 | data for augmentation. 'gt' is required.
54 |
55 | Returns:
56 | dict: A dict containing the processed data and information.
57 | modified 'gt', supplement 'lq' and 'scale' to keys.
58 | """
59 | img = results['gt']
60 | scale = np.random.uniform(self.scale_min, self.scale_max)
61 |
62 | if self.patch_size is None:
63 | h_lr = math.floor(img.shape[-3] / scale + 1e-9)
64 | w_lr = math.floor(img.shape[-2] / scale + 1e-9)
65 | img = img[:round(h_lr * scale), :round(w_lr * scale), :]
66 | img_down = resize_fn(img, (w_lr, h_lr), self.interpolation,
67 | self.backend)
68 | crop_lr, crop_hr = img_down, img
69 | else:
70 | w_lr = self.patch_size
71 | w_hr = round(w_lr * scale)
72 | x0 = np.random.randint(0, img.shape[-3] - w_hr)
73 | y0 = np.random.randint(0, img.shape[-2] - w_hr)
74 | crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :]
75 | crop_lr = resize_fn(crop_hr, w_lr, self.interpolation,
76 | self.backend)
77 | results['gt'] = crop_hr
78 | results['lq'] = crop_lr
79 | results['scale'] = scale
80 |
81 | return results
82 |
83 | def __repr__(self):
84 | repr_str = self.__class__.__name__
85 | repr_str += (f' scale_min={self.scale_min}, '
86 | f'scale_max={self.scale_max}, '
87 | f'patch_size={self.patch_size}, '
88 | f'interpolation={self.interpolation}, '
89 | f'backend={self.backend}')
90 |
91 | return repr_str
92 |
93 |
94 | def resize_fn(img, size, interpolation='bicubic', backend='pillow'):
95 | """Resize the given image to a given size.
96 |
97 | Args:
98 | img (ndarray | torch.Tensor): The input image.
99 | size (int | tuple[int]): Target size w or (w, h).
100 | interpolation (str): Interpolation method, accepted values are
101 | "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
102 | backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
103 | "hamming" for 'pillow' backend.
104 | Default: "bicubic".
105 | backend (str | None): The image resize backend type. Options are `cv2`,
106 | `pillow`, `None`. If backend is None, the global imread_backend
107 | specified by ``mmcv.use_backend()`` will be used.
108 | Default: "pillow".
109 |
110 | Returns:
111 | ndarray | torch.Tensor: `resized_img`, whose type is same as `img`.
112 | """
113 | if isinstance(size, int):
114 | size = (size, size)
115 | if isinstance(img, np.ndarray):
116 | return imresize(
117 | img, size, interpolation=interpolation, backend=backend)
118 |     elif isinstance(img, torch.Tensor):
119 |         image = imresize(
120 |             img.numpy(), size, interpolation=interpolation, backend=backend)
121 |         return torch.from_numpy(image)
122 |     else:
123 |         raise TypeError('img should be np.ndarray or torch.Tensor, '
124 |                         f'but got {type(img)}')
125 |
--------------------------------------------------------------------------------
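
A quick sketch of the pipeline step above on a synthetic GT image; the import
path is assumed from the package layout:

    import numpy as np

    # from mmedit.datasets.pipelines import RandomDownSampling  # assumed path
    down = RandomDownSampling(scale_min=1.0, scale_max=4.0, patch_size=48)
    results = down(dict(gt=np.random.rand(256, 256, 3).astype(np.float32)))
    # 'lq' is a 48x48 LR crop; 'gt' becomes the matching HR patch whose side
    # is round(48 * scale) for the randomly picked scale in [1, 4).
    print(results['lq'].shape, results['gt'].shape, results['scale'])
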
/mmedit/datasets/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | from mmcv.utils import print_log
7 |
8 | _integer_types = (
9 | np.byte,
10 | np.ubyte, # 8 bits
11 | np.short,
12 | np.ushort, # 16 bits
13 | np.intc,
14 | np.uintc, # 16 or 32 or 64 bits
15 | np.int_,
16 | np.uint, # 32 or 64 bits
17 | np.longlong,
18 | np.ulonglong) # 64 bits
19 |
20 | _integer_ranges = {
21 | t: (np.iinfo(t).min, np.iinfo(t).max)
22 | for t in _integer_types
23 | }
24 |
25 | dtype_range = {
26 | np.bool_: (False, True),
27 | np.bool8: (False, True),
28 | np.float16: (-1, 1),
29 | np.float32: (-1, 1),
30 | np.float64: (-1, 1)
31 | }
32 | dtype_range.update(_integer_ranges)
33 |
34 |
35 | def dtype_limits(image, clip_negative=False):
36 | """Return intensity limits, i.e. (min, max) tuple, of the image's dtype.
37 |
38 | This function is adopted from skimage:
39 | https://github.com/scikit-image/scikit-image/blob/
40 | 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/util/dtype.py#L35
41 |
42 | Args:
43 | image (ndarray): Input image.
44 | clip_negative (bool, optional): If True, clip the negative range
45 | (i.e. return 0 for min intensity) even if the image dtype allows
46 | negative values.
47 |
48 |     Returns:
49 | tuple: Lower and upper intensity limits.
50 | """
51 | imin, imax = dtype_range[image.dtype.type]
52 | if clip_negative:
53 | imin = 0
54 | return imin, imax
55 |
56 |
57 | def adjust_gamma(image, gamma=1, gain=1):
58 | """Performs Gamma Correction on the input image.
59 |
60 | This function is adopted from skimage:
61 | https://github.com/scikit-image/scikit-image/blob/
62 | 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/exposure/
63 | exposure.py#L439-L494
64 |
65 | Also known as Power Law Transform.
66 | This function transforms the input image pixelwise according to the
67 | equation ``O = I**gamma`` after scaling each pixel to the range 0 to 1.
68 |
69 | Args:
70 | image (ndarray): Input image.
71 | gamma (float, optional): Non negative real number. Defaults to 1.
72 | gain (float, optional): The constant multiplier. Defaults to 1.
73 |
74 | Returns:
75 | ndarray: Gamma corrected output image.
76 | """
77 | if np.any(image < 0):
78 | raise ValueError('Image Correction methods work correctly only on '
79 | 'images with non-negative values. Use '
80 | 'skimage.exposure.rescale_intensity.')
81 |
82 | dtype = image.dtype.type
83 |
84 | if gamma < 0:
85 | raise ValueError('Gamma should be a non-negative real number.')
86 |
87 | scale = float(dtype_limits(image, True)[1] - dtype_limits(image, True)[0])
88 |
89 | out = ((image / scale)**gamma) * scale * gain
90 | return out.astype(dtype)
91 |
92 |
93 | def random_choose_unknown(unknown, crop_size):
94 | """Randomly choose an unknown start (top-left) point for a given crop_size.
95 |
96 | Args:
97 | unknown (np.ndarray): The binary unknown mask.
98 | crop_size (tuple[int]): The given crop size.
99 |
100 | Returns:
101 | tuple[int]: The top-left point of the chosen bbox.
102 | """
103 | h, w = unknown.shape
104 | crop_h, crop_w = crop_size
105 | delta_h = center_h = crop_h // 2
106 | delta_w = center_w = crop_w // 2
107 |
108 |     # mark the valid area for selecting the cropping center
109 | mask = np.zeros_like(unknown)
110 | mask[delta_h:h - delta_h, delta_w:w - delta_w] = 1
111 | if np.any(unknown & mask):
112 | center_h_list, center_w_list = np.where(unknown & mask)
113 | elif np.any(unknown):
114 | center_h_list, center_w_list = np.where(unknown)
115 | else:
116 | print_log('No unknown pixels found!', level=logging.WARNING)
117 | center_h_list = [center_h]
118 | center_w_list = [center_w]
119 | num_unknowns = len(center_h_list)
120 | rand_ind = np.random.randint(num_unknowns)
121 | center_h = center_h_list[rand_ind]
122 | center_w = center_w_list[rand_ind]
123 |
124 | # make sure the top-left point is valid
125 | top = np.clip(center_h - delta_h, 0, h - crop_h)
126 | left = np.clip(center_w - delta_w, 0, w - crop_w)
127 |
128 | return top, left
129 |
130 |
131 | def make_coord(shape, ranges=None, flatten=True):
132 | """Make coordinates at grid centers.
133 |
134 | Args:
135 | shape (tuple): shape of image.
136 | ranges (tuple): range of coordinate value. Default: None.
137 | flatten (bool): flatten to (n, 2) or Not. Default: True.
138 |
139 | return:
140 | coord (Tensor): coordinates.
141 | """
142 | coord_seqs = []
143 | for i, n in enumerate(shape):
144 | if ranges is None:
145 | v0, v1 = -1, 1
146 | else:
147 | v0, v1 = ranges[i]
148 | r = (v1 - v0) / (2 * n)
149 | seq = v0 + r + (2 * r) * torch.arange(n).float()
150 | coord_seqs.append(seq)
151 | coord = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
152 | if flatten:
153 | coord = coord.view(-1, coord.shape[-1])
154 | return coord
155 |
--------------------------------------------------------------------------------
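
For intuition about `make_coord` above: along an axis with n cells the centers
sit at -1 + (2k + 1) / n for k = 0..n-1, so a tiny sketch gives:

    import torch

    coord = make_coord((2, 3))  # 2x3 grid, flattened to shape (6, 2)
    # The first column takes values [-0.5, 0.5] (n=2); the second column
    # takes values [-2/3, 0, 2/3] (n=3).
    print(coord)
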
/mmedit/datasets/sr_folder_multiple_gt_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import glob
3 | import os
4 | import os.path as osp
5 |
6 | import mmcv
7 |
8 | from .base_sr_dataset import BaseSRDataset
9 | from .registry import DATASETS
10 |
11 |
12 | @DATASETS.register_module()
13 | class SRFolderMultipleGTDataset(BaseSRDataset):
14 | """General dataset for video super resolution, used for recurrent networks.
15 |
16 | The dataset loads several LQ (Low-Quality) frames and GT (Ground-Truth)
17 | frames. Then it applies specified transforms and finally returns a dict
18 | containing paired data and other information.
19 |
20 | This dataset takes an annotation file specifying the sequences used in
21 | training or test. If no annotation file is provided, it assumes all video
22 |     sequences under the root directory are used for training or test.
23 |
24 | In the annotation file (.txt), each line contains:
25 |
26 | 1. folder name;
27 | 2. number of frames in this sequence (in the same folder)
28 |
29 | Examples:
30 |
31 | ::
32 |
33 | calendar 41
34 | city 34
35 | foliage 49
36 | walk 47
37 |
38 | Args:
39 | lq_folder (str | :obj:`Path`): Path to a lq folder.
40 | gt_folder (str | :obj:`Path`): Path to a gt folder.
41 | pipeline (list[dict | callable]): A sequence of data transformations.
42 | scale (int): Upsampling scale ratio.
43 |         ann_file (str): The path to the annotation file. If None, we assume
44 |             that all sequences in the folder are used. Default: None.
45 | num_input_frames (None | int): The number of frames per iteration.
46 | If None, the whole clip is extracted. If it is a positive integer,
47 | a sequence of 'num_input_frames' frames is extracted from the clip.
48 | Note that non-positive integers are not accepted. Default: None.
49 | test_mode (bool): Store `True` when building test dataset.
50 | Default: `True`.
51 | """
52 |
53 | def __init__(self,
54 | lq_folder,
55 | gt_folder,
56 | mask_folder,
57 | drop_folder,
58 | streak_folder,
59 | pipeline,
60 | scale,
61 | ann_file=None,
62 | num_input_frames=None,
63 | test_mode=True):
64 | super().__init__(pipeline, scale, test_mode)
65 |
66 | self.lq_folder = str(lq_folder)
67 | self.gt_folder = str(gt_folder)
68 | self.mask_folder = str(mask_folder)
69 | self.drop_folder = str(drop_folder)
70 | self.streak_folder = str(streak_folder)
71 |
72 | self.ann_file = ann_file
73 | if num_input_frames is not None and num_input_frames <= 0:
74 | raise ValueError('"num_input_frames" must be None or positive, '
75 | f'but got {num_input_frames}.')
76 | self.num_input_frames = num_input_frames
77 |
78 | self.data_infos = self.load_annotations()
79 |
80 | def _load_annotations_from_file(self):
81 | data_infos = []
82 |
83 | ann_list = mmcv.list_from_file(self.ann_file)
84 | for ann in ann_list:
85 | key, sequence_length = ann.strip().split(' ')
86 | if self.num_input_frames is None:
87 | num_input_frames = sequence_length
88 | else:
89 | num_input_frames = self.num_input_frames
90 | data_infos.append(
91 | dict(
92 | lq_path=self.lq_folder,
93 | gt_path=self.gt_folder,
94 | mask_path=self.mask_folder,
95 | drop_path=self.drop_folder,
96 | streak_path=self.streak_folder,
97 | key=key,
98 | num_input_frames=int(num_input_frames),
99 | sequence_length=int(sequence_length)))
100 |
101 | return data_infos
102 |
103 | def load_annotations(self):
104 | """Load annotations for the dataset.
105 |
106 | Returns:
107 | list[dict]: Returned list of dicts for paired paths of LQ and GT.
108 | """
109 |
110 | if self.ann_file:
111 | return self._load_annotations_from_file()
112 |
113 | sequences = sorted(glob.glob(osp.join(self.lq_folder, '*')))
114 | data_infos = []
115 | for sequence in sequences:
116 | sequence_length = len(glob.glob(osp.join(sequence, '*.png')))
117 | if self.num_input_frames is None:
118 | num_input_frames = sequence_length
119 | else:
120 | num_input_frames = self.num_input_frames
121 | data_infos.append(
122 | dict(
123 | lq_path=self.lq_folder,
124 | gt_path=self.gt_folder,
125 | mask_path=self.mask_folder,
126 | drop_path=self.drop_folder,
127 | streak_path=self.streak_folder,
128 | key=sequence.replace(f'{self.lq_folder}{os.sep}', ''),
129 | num_input_frames=num_input_frames,
130 | sequence_length=sequence_length))
131 |
132 | return data_infos
133 |
--------------------------------------------------------------------------------
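
A hypothetical test-set entry for the dataset above; all folder names are
placeholders and `test_pipeline` is assumed to be defined elsewhere in the
config:

    data = dict(
        test=dict(
            type='SRFolderMultipleGTDataset',
            lq_folder='data/test/rain',
            gt_folder='data/test/gt',
            mask_folder='data/test/mask',
            drop_folder='data/test/drop',
            streak_folder='data/test/streak',
            pipeline=test_pipeline,
            scale=1,
            ann_file=None,
            num_input_frames=None,
            test_mode=True))
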
/mmedit/core/export/wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | import warnings
4 |
5 | import numpy as np
6 | import onnxruntime as ort
7 | import torch
8 | from torch import nn
9 |
10 | from mmedit.models import BaseMattor, BasicRestorer, build_model
11 |
12 |
13 | def inference_with_session(sess, io_binding, output_names, input_tensor):
14 | device_type = input_tensor.device.type
15 | device_id = input_tensor.device.index
16 | device_id = 0 if device_id is None else device_id
17 | io_binding.bind_input(
18 | name='input',
19 | device_type=device_type,
20 | device_id=device_id,
21 | element_type=np.float32,
22 | shape=input_tensor.shape,
23 | buffer_ptr=input_tensor.data_ptr())
24 | for name in output_names:
25 | io_binding.bind_output(name)
26 | sess.run_with_iobinding(io_binding)
27 | pred = io_binding.copy_outputs_to_cpu()
28 | return pred
29 |
30 |
31 | class ONNXRuntimeMattor(nn.Module):
32 |
33 | def __init__(self, sess, io_binding, output_names, base_model):
34 | super(ONNXRuntimeMattor, self).__init__()
35 | self.sess = sess
36 | self.io_binding = io_binding
37 | self.output_names = output_names
38 | self.base_model = base_model
39 |
40 | def forward(self,
41 | merged,
42 | trimap,
43 | meta,
44 | test_mode=False,
45 | save_image=False,
46 | save_path=None,
47 | iteration=None):
48 | input_tensor = torch.cat((merged, trimap), 1).contiguous()
49 | pred_alpha = inference_with_session(self.sess, self.io_binding,
50 | self.output_names, input_tensor)[0]
51 |
52 | pred_alpha = pred_alpha.squeeze()
53 | pred_alpha = self.base_model.restore_shape(pred_alpha, meta)
54 | eval_result = self.base_model.evaluate(pred_alpha, meta)
55 |
56 | if save_image:
57 | self.base_model.save_image(pred_alpha, meta, save_path, iteration)
58 |
59 | return {'pred_alpha': pred_alpha, 'eval_result': eval_result}
60 |
61 |
62 | class RestorerGenerator(nn.Module):
63 |
64 | def __init__(self, sess, io_binding, output_names):
65 | super(RestorerGenerator, self).__init__()
66 | self.sess = sess
67 | self.io_binding = io_binding
68 | self.output_names = output_names
69 |
70 | def forward(self, x):
71 | pred = inference_with_session(self.sess, self.io_binding,
72 | self.output_names, x)[0]
73 | pred = torch.from_numpy(pred)
74 | return pred
75 |
76 |
77 | class ONNXRuntimeRestorer(nn.Module):
78 |
79 | def __init__(self, sess, io_binding, output_names, base_model):
80 | super(ONNXRuntimeRestorer, self).__init__()
81 | self.sess = sess
82 | self.io_binding = io_binding
83 | self.output_names = output_names
84 | self.base_model = base_model
85 | restorer_generator = RestorerGenerator(self.sess, self.io_binding,
86 | self.output_names)
87 | base_model.generator = restorer_generator
88 |
89 | def forward(self, lq, gt=None, test_mode=False, **kwargs):
90 | return self.base_model(lq, gt=gt, test_mode=test_mode, **kwargs)
91 |
92 |
93 | class ONNXRuntimeEditing(nn.Module):
94 |
95 | def __init__(self, onnx_file, cfg, device_id):
96 | super(ONNXRuntimeEditing, self).__init__()
97 | ort_custom_op_path = ''
98 | try:
99 | from mmcv.ops import get_onnxruntime_op_path
100 | ort_custom_op_path = get_onnxruntime_op_path()
101 | except (ImportError, ModuleNotFoundError):
102 |             warnings.warn('If input model has custom op from mmcv, you may '
103 |                           'have to build mmcv with ONNXRuntime from source.')
104 | session_options = ort.SessionOptions()
105 | # register custom op for onnxruntime
106 | if osp.exists(ort_custom_op_path):
107 | session_options.register_custom_ops_library(ort_custom_op_path)
108 | sess = ort.InferenceSession(onnx_file, session_options)
109 | providers = ['CPUExecutionProvider']
110 | options = [{}]
111 | is_cuda_available = ort.get_device() == 'GPU'
112 | if is_cuda_available:
113 | providers.insert(0, 'CUDAExecutionProvider')
114 | options.insert(0, {'device_id': device_id})
115 |
116 | sess.set_providers(providers, options)
117 |
118 | self.sess = sess
119 | self.device_id = device_id
120 | self.io_binding = sess.io_binding()
121 | self.output_names = [_.name for _ in sess.get_outputs()]
122 |
123 | base_model = build_model(
124 | cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
125 |
126 |         if isinstance(base_model, BaseMattor):
127 |             WrapperClass = ONNXRuntimeMattor
128 |         elif isinstance(base_model, BasicRestorer):
129 |             WrapperClass = ONNXRuntimeRestorer
130 |         else:
131 |             raise NotImplementedError(
132 |                 f'Unsupported model type {type(base_model)} for ONNX wrapping.')
133 |         self.wrapper = WrapperClass(self.sess, self.io_binding,
134 |                                     self.output_names, base_model)
135 |
136 |     def forward(self, **kwargs):
137 |         return self.wrapper(**kwargs)
138 |
--------------------------------------------------------------------------------
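
A minimal sketch of driving the wrapper above; `model.onnx` is a placeholder
for an exported restorer and the config is assumed to define `test_cfg`:

    import torch
    from mmcv import Config

    cfg = Config.fromfile('configs/derainers/ViMPNet/ViMPNet.py')
    ort_model = ONNXRuntimeEditing('model.onnx', cfg=cfg, device_id=0)
    lq = torch.rand(1, 3, 64, 64)  # (n, c, h, w) input in [0, 1]
    results = ort_model(lq=lq, test_mode=True)
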
/mmedit/core/evaluation/inceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | from scipy import linalg
5 |
6 | from ..registry import METRICS
7 | from .inception_utils import load_inception
8 |
9 |
10 | class InceptionV3:
11 | """Feature extractor features using InceptionV3 model.
12 |
13 | Args:
14 | style (str): The model style to run Inception model. it must be either
15 | 'StyleGAN' or 'pytorch'.
16 | device (torch.device): device to extract feature.
17 | inception_kwargs (**kwargs): kwargs for InceptionV3.
18 | """
19 |
20 | def __init__(self, style='StyleGAN', device='cpu', **inception_kwargs):
21 | self.inception = load_inception(
22 | style=style, **inception_kwargs).eval().to(device)
23 | self.style = style
24 | self.device = device
25 |
26 | def __call__(self, img1, img2, crop_border=0):
27 | """Extract features of real and fake images.
28 |
29 | Args:
30 |             img1, img2 (np.ndarray): Images with range [0, 255] and shape
31 |                 (H, W, C). ``crop_border`` is accepted but unused here.
32 |
33 | Returns:
34 | (tuple): Pair of features extracted from InceptionV3 model.
35 | """
36 | return (
37 | self.forward_inception(self.img2tensor(img1)).numpy(),
38 | self.forward_inception(self.img2tensor(img2)).numpy(),
39 | )
40 |
41 | def img2tensor(self, img):
42 | img = np.expand_dims(img.transpose((2, 0, 1)), axis=0)
43 | if self.style == 'StyleGAN':
44 | return torch.tensor(img).to(device=self.device, dtype=torch.uint8)
45 |
46 | return torch.from_numpy(img / 255.).to(
47 | device=self.device, dtype=torch.float32)
48 |
49 | def forward_inception(self, x):
50 | if self.style == 'StyleGAN':
51 | return self.inception(x).cpu()
52 |
53 | return self.inception(x)[-1].view(x.shape[0], -1).cpu()
54 |
55 |
56 | def frechet_distance(X, Y):
57 | """Compute the frechet distance."""
58 |
59 | muX, covX = np.mean(X, axis=0), np.cov(X, rowvar=False)
60 | muY, covY = np.mean(Y, axis=0), np.cov(Y, rowvar=False)
61 |
62 | cov_sqrt = linalg.sqrtm(covX.dot(covY))
63 | frechet_distance = np.square(muX - muY).sum() + np.trace(covX) + np.trace(
64 | covY) - 2 * np.trace(cov_sqrt)
65 | return np.real(frechet_distance)
66 |
67 |
68 | @METRICS.register_module()
69 | class FID:
70 | """FID metric."""
71 |
72 | def __call__(self, X, Y):
73 | """Calculate FID.
74 |
75 | Args:
76 | X (np.ndarray): Input feature X with shape (n_samples, dims).
77 | Y (np.ndarray): Input feature Y with shape (n_samples, dims).
78 |
79 | Returns:
80 | (float): fid value.
81 | """
82 | return frechet_distance(X, Y)
83 |
84 |
85 | def polynomial_kernel(X, Y=None, degree=3, gamma=None, coef=1):
86 | """Create a polynomial kernel."""
87 | Y = X if Y is None else Y
88 | if gamma is None:
89 | gamma = 1.0 / X.shape[1]
90 | K = ((X @ Y.T) * gamma + coef)**degree
91 | return K
92 |
93 |
94 | def mmd2(X, Y, unbiased=True):
95 | """Compute the Maximum Mean Discrepancy."""
96 | XX = polynomial_kernel(X, X)
97 | YY = polynomial_kernel(Y, Y)
98 | XY = polynomial_kernel(X, Y)
99 |
100 | m = X.shape[0]
101 | if not unbiased:
102 | return (np.sum(XX) + np.sum(YY) - 2 * np.sum(XY)) / m**2
103 |
104 | trX = np.trace(XX)
105 | trY = np.trace(YY)
106 | return (np.sum(XX) - trX + np.sum(YY) -
107 | trY) / (m * (m - 1)) - 2 * np.sum(XY) / m**2
108 |
109 |
110 | @METRICS.register_module()
111 | class KID:
112 | """Implementation of `KID `.
113 |
114 | Args:
115 | num_repeats (int): The number of repetitions. Default: 100.
116 | sample_size (int): Size to sample. Default: 1000.
117 | use_unbiased_estimator (bool): Whether to use KID as an unbiased
118 | estimator. Using an unbiased estimator is desirable in the case of
119 |             finite sample size, especially when the number of samples is
120 | small. Using an unbiased estimator is recommended in most cases.
121 | Default: True
122 | """
123 |
124 | def __init__(self,
125 | num_repeats=100,
126 | sample_size=1000,
127 | use_unbiased_estimator=True):
128 | self.num_repeats = num_repeats
129 | self.sample_size = sample_size
130 | self.unbiased = use_unbiased_estimator
131 |
132 | def __call__(self, X, Y):
133 | """Calculate KID.
134 |
135 | Args:
136 | X (np.ndarray): Input feature X with shape (n_samples, dims).
137 | Y (np.ndarray): Input feature Y with shape (n_samples, dims).
138 |
139 | Returns:
140 | (dict): dict containing mean and std of KID values.
141 | """
142 | num_samples = X.shape[0]
143 | kid = list()
144 | for i in range(self.num_repeats):
145 | X_ = X[np.random.choice(
146 | num_samples, self.sample_size, replace=False)]
147 | Y_ = Y[np.random.choice(
148 | num_samples, self.sample_size, replace=False)]
149 | kid.append(mmd2(X_, Y_, unbiased=self.unbiased))
150 | kid = np.array(kid)
151 | return dict(KID_MEAN=kid.mean(), KID_STD=kid.std())
152 |
--------------------------------------------------------------------------------
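
Both metrics above consume pre-extracted feature arrays; a toy sketch with
random features standing in for Inception activations:

    import numpy as np

    X = np.random.randn(2000, 2048)  # stand-in 'real' features (n, dims)
    Y = np.random.randn(2000, 2048)  # stand-in 'generated' features

    fid = FID()(X, Y)                                 # scalar distance
    kid = KID(num_repeats=10, sample_size=500)(X, Y)  # mean/std over resamples
    print(fid, kid['KID_MEAN'], kid['KID_STD'])
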
/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 |
5 | import mmcv
6 | import torch
7 | from mmcv import Config, DictAction
8 | from mmcv.parallel import MMDataParallel
9 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint
10 |
11 | from mmedit.apis import multi_gpu_test, set_random_seed, single_gpu_test
12 | from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
13 | from mmedit.datasets import build_dataloader, build_dataset
14 | from mmedit.models import build_model
15 | from mmedit.utils import setup_multi_processes
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='mmediting tester')
20 | parser.add_argument('config', help='test config file path')
21 | parser.add_argument('checkpoint', help='checkpoint file')
22 | parser.add_argument('--seed', type=int, default=None, help='random seed')
23 | parser.add_argument(
24 | '--deterministic',
25 | action='store_true',
26 | help='whether to set deterministic options for CUDNN backend.')
27 | parser.add_argument('--out', help='output result pickle file')
28 | parser.add_argument(
29 | '--gpu-collect',
30 | action='store_true',
31 | help='whether to use gpu to collect results')
32 | parser.add_argument(
33 | '--save-path',
34 | default=None,
35 | type=str,
36 |         help='path to store images; if not given, images will not be saved')
37 | parser.add_argument('--tmpdir', help='tmp dir for writing some results')
38 | parser.add_argument(
39 | '--cfg-options',
40 | nargs='+',
41 | action=DictAction,
42 | help='override some settings in the used config, the key-value pair '
43 | 'in xxx=yyy format will be merged into config file. If the value to '
44 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
45 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
46 | 'Note that the quotation marks are necessary and that no white space '
47 | 'is allowed.')
48 | parser.add_argument(
49 | '--launcher',
50 | choices=['none', 'pytorch', 'slurm', 'mpi'],
51 | default='none',
52 | help='job launcher')
53 | parser.add_argument('--local_rank', type=int, default=0)
54 | args = parser.parse_args()
55 | if 'LOCAL_RANK' not in os.environ:
56 | os.environ['LOCAL_RANK'] = str(args.local_rank)
57 | return args
58 |
59 |
60 | def main():
61 | args = parse_args()
62 |
63 | cfg = Config.fromfile(args.config)
64 |
65 | if args.cfg_options is not None:
66 | cfg.merge_from_dict(args.cfg_options)
67 |
68 | # set multi-process settings
69 | setup_multi_processes(cfg)
70 |
71 | # set cudnn_benchmark
72 | if cfg.get('cudnn_benchmark', False):
73 | torch.backends.cudnn.benchmark = True
74 | cfg.model.pretrained = None
75 |
76 | # init distributed env first, since logger depends on the dist info.
77 | if args.launcher == 'none':
78 | distributed = False
79 | else:
80 | distributed = True
81 | init_dist(args.launcher, **cfg.dist_params)
82 |
83 | rank, _ = get_dist_info()
84 |
85 | # set random seeds
86 | if args.seed is not None:
87 | if rank == 0:
88 | print('set random seed to', args.seed)
89 | set_random_seed(args.seed, deterministic=args.deterministic)
90 |
91 | # build the dataloader
92 | # TODO: support multiple images per gpu (only minor changes are needed)
93 | dataset = build_dataset(cfg.data.test)
94 |
95 | loader_cfg = {
96 | **dict((k, cfg.data[k]) for k in ['workers_per_gpu'] if k in cfg.data),
97 | **dict(
98 | samples_per_gpu=1,
99 | drop_last=False,
100 | shuffle=False,
101 | dist=distributed),
102 | **cfg.data.get('test_dataloader', {})
103 | }
104 |
105 | data_loader = build_dataloader(dataset, **loader_cfg)
106 |
107 | # build the model and load checkpoint
108 | model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
109 |
110 | args.save_image = args.save_path is not None
111 | empty_cache = cfg.get('empty_cache', False)
112 | if not distributed:
113 | _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
114 | model = MMDataParallel(model, device_ids=[0])
115 | outputs = single_gpu_test(
116 | model,
117 | data_loader,
118 | save_path=args.save_path,
119 | save_image=args.save_image)
120 | else:
121 | find_unused_parameters = cfg.get('find_unused_parameters', False)
122 | model = DistributedDataParallelWrapper(
123 | model,
124 | device_ids=[torch.cuda.current_device()],
125 | broadcast_buffers=False,
126 | find_unused_parameters=find_unused_parameters)
127 |
128 | device_id = torch.cuda.current_device()
129 | _ = load_checkpoint(
130 | model,
131 | args.checkpoint,
132 | map_location=lambda storage, loc: storage.cuda(device_id))
133 | outputs = multi_gpu_test(
134 | model,
135 | data_loader,
136 | args.tmpdir,
137 | args.gpu_collect,
138 | save_path=args.save_path,
139 | save_image=args.save_image,
140 | empty_cache=empty_cache)
141 |
142 | if rank == 0 and 'eval_result' in outputs[0]:
143 | print('')
144 | # print metrics
145 | stats = dataset.evaluate(outputs)
146 | for stat in stats:
147 | print('Eval-{}: {}'.format(stat, stats[stat]))
148 |
149 | # save result pickle
150 | if args.out:
151 | print('writing results to {}'.format(args.out))
152 | mmcv.dump(outputs, args.out)
153 |
154 |
155 | if __name__ == '__main__':
156 | main()
157 |
--------------------------------------------------------------------------------
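
For example, a single-GPU evaluation could be launched as
`python tools/test.py configs/derainers/ViMPNet/ViMPNet.py path/to/checkpoint.pth --save-path work_dirs/results`
(the checkpoint path is a placeholder), while `tools/dist_test.sh` wraps the
same entry point with `--launcher pytorch` for distributed testing.
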
/mmedit/core/distributed_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
5 | from mmcv.parallel.scatter_gather import scatter_kwargs
6 | from torch.cuda._utils import _get_device_index
7 |
8 |
9 | @MODULE_WRAPPERS.register_module()
10 | class DistributedDataParallelWrapper(nn.Module):
11 | """A DistributedDataParallel wrapper for models in MMediting.
12 |
13 |     In MMEditing, there is a need to wrap different modules in the models
14 |     with separate DistributedDataParallel. Otherwise, it will cause
15 |     errors for GAN training.
16 |     More specifically, a GAN model usually has two sub-modules:
17 |     generator and discriminator. If we wrap both of them in one
18 |     standard DistributedDataParallel, it will cause errors during training,
19 |     because when we update the parameters of the generator (or discriminator),
20 |     the parameters of the discriminator (or generator) are not updated, which
21 |     is not allowed for DistributedDataParallel.
22 | So we design this wrapper to separately wrap DistributedDataParallel
23 | for generator and discriminator.
24 |
25 | In this wrapper, we perform two operations:
26 | 1. Wrap the modules in the models with separate MMDistributedDataParallel.
27 | Note that only modules with parameters will be wrapped.
28 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
29 |
30 |     Note that the arguments of this wrapper are the same as those in
31 | `torch.nn.parallel.distributed.DistributedDataParallel`.
32 |
33 | Args:
34 | module (nn.Module): Module that needs to be wrapped.
35 | device_ids (list[int | `torch.device`]): Same as that in
36 | `torch.nn.parallel.distributed.DistributedDataParallel`.
37 | dim (int, optional): Same as that in the official scatter function in
38 | pytorch. Defaults to 0.
39 | broadcast_buffers (bool): Same as that in
40 | `torch.nn.parallel.distributed.DistributedDataParallel`.
41 | Defaults to False.
42 | find_unused_parameters (bool, optional): Same as that in
43 | `torch.nn.parallel.distributed.DistributedDataParallel`.
44 | Traverse the autograd graph of all tensors contained in returned
45 | value of the wrapped module’s forward function. Defaults to False.
46 | kwargs (dict): Other arguments used in
47 | `torch.nn.parallel.distributed.DistributedDataParallel`.
48 | """
49 |
50 | def __init__(self,
51 | module,
52 | device_ids,
53 | dim=0,
54 | broadcast_buffers=False,
55 | find_unused_parameters=False,
56 | **kwargs):
57 | super().__init__()
58 | assert len(device_ids) == 1, (
59 |             'Currently, DistributedDataParallelWrapper only supports a '
60 |             'single CUDA device for each process. '
61 | f'The length of device_ids must be 1, but got {len(device_ids)}.')
62 | self.module = module
63 | self.dim = dim
64 | self.to_ddp(
65 | device_ids=device_ids,
66 | dim=dim,
67 | broadcast_buffers=broadcast_buffers,
68 | find_unused_parameters=find_unused_parameters,
69 | **kwargs)
70 | self.output_device = _get_device_index(device_ids[0], True)
71 |
72 | def to_ddp(self, device_ids, dim, broadcast_buffers,
73 | find_unused_parameters, **kwargs):
74 | """Wrap models with separate MMDistributedDataParallel.
75 |
76 | It only wraps the modules with parameters.
77 | """
78 | for name, module in self.module._modules.items():
79 | if next(module.parameters(), None) is None:
80 | module = module.cuda()
81 | elif all(not p.requires_grad for p in module.parameters()):
82 | module = module.cuda()
83 | else:
84 | module = MMDistributedDataParallel(
85 | module.cuda(),
86 | device_ids=device_ids,
87 | dim=dim,
88 | broadcast_buffers=broadcast_buffers,
89 | find_unused_parameters=find_unused_parameters,
90 | **kwargs)
91 | self.module._modules[name] = module
92 |
93 | def scatter(self, inputs, kwargs, device_ids):
94 | """Scatter function.
95 |
96 | Args:
97 |             inputs (tuple): Input data.
98 |             kwargs (dict): Args for
99 |                 ``mmcv.parallel.scatter_gather.scatter_kwargs``.
100 |             device_ids (list[int]): Device ids.
101 | """
102 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
103 |
104 | def forward(self, *inputs, **kwargs):
105 | """Forward function.
106 |
107 | Args:
108 | inputs (tuple): Input data.
109 | kwargs (dict): Args for
110 | ``mmcv.parallel.scatter_gather.scatter_kwargs``.
111 | """
112 | inputs, kwargs = self.scatter(inputs, kwargs,
113 | [torch.cuda.current_device()])
114 | return self.module(*inputs[0], **kwargs[0])
115 |
116 | def train_step(self, *inputs, **kwargs):
117 | """Train step function.
118 |
119 | Args:
120 |             inputs (tuple): Input data.
121 | kwargs (dict): Args for
122 | ``mmcv.parallel.scatter_gather.scatter_kwargs``.
123 | """
124 | inputs, kwargs = self.scatter(inputs, kwargs,
125 | [torch.cuda.current_device()])
126 | output = self.module.train_step(*inputs[0], **kwargs[0])
127 | return output
128 |
129 | def val_step(self, *inputs, **kwargs):
130 | """Validation step function.
131 |
132 | Args:
133 | inputs (tuple): Input data.
134 | kwargs (dict): Args for ``scatter_kwargs``.
135 | """
136 | inputs, kwargs = self.scatter(inputs, kwargs,
137 | [torch.cuda.current_device()])
138 | output = self.module.val_step(*inputs[0], **kwargs[0])
139 | return output
140 |
--------------------------------------------------------------------------------
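
A sketch of how the wrapper above is applied, mirroring its use in
`tools/test.py`; it assumes a parsed `cfg`, an already-initialized process
group, and a single GPU per process:

    import torch
    from mmedit.models import build_model

    model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    model = DistributedDataParallelWrapper(
        model,
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=False)
    # forward/train_step/val_step now scatter inputs to the local device and
    # dispatch to the per-submodule MMDistributedDataParallel instances.
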
/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import copy
4 | import os
5 | import os.path as osp
6 | import time
7 |
8 | import mmcv
9 | import torch
10 | import torch.distributed as dist
11 | from mmcv import Config, DictAction
12 | from mmcv.runner import init_dist
13 |
14 | from mmedit import __version__
15 | from mmedit.apis import init_random_seed, set_random_seed, train_model
16 | from mmedit.datasets import build_dataset
17 | from mmedit.models import build_model
18 | from mmedit.utils import collect_env, get_root_logger, setup_multi_processes
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='Train an editor')
23 | parser.add_argument('config', help='train config file path')
24 | parser.add_argument('--work-dir', help='the dir to save logs and models')
25 | parser.add_argument(
26 | '--resume-from', help='the checkpoint file to resume from')
27 | parser.add_argument(
28 | '--no-validate',
29 | action='store_true',
30 | help='whether not to evaluate the checkpoint during training')
31 | parser.add_argument(
32 | '--gpus',
33 | type=int,
34 | default=1,
35 | help='number of gpus to use '
36 | '(only applicable to non-distributed training)')
37 | parser.add_argument('--seed', type=int, default=None, help='random seed')
38 | parser.add_argument(
39 | '--diff_seed',
40 | action='store_true',
41 |         help='whether or not to set different seeds for different ranks')
42 | parser.add_argument(
43 | '--deterministic',
44 | action='store_true',
45 | help='whether to set deterministic options for CUDNN backend.')
46 | parser.add_argument(
47 | '--cfg-options',
48 | nargs='+',
49 | action=DictAction,
50 | help='override some settings in the used config, the key-value pair '
51 | 'in xxx=yyy format will be merged into config file. If the value to '
52 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
53 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
54 | 'Note that the quotation marks are necessary and that no white space '
55 | 'is allowed.')
56 | parser.add_argument(
57 | '--launcher',
58 | choices=['none', 'pytorch', 'slurm', 'mpi'],
59 | default='none',
60 | help='job launcher')
61 | parser.add_argument('--local_rank', type=int, default=0)
62 | parser.add_argument(
63 | '--autoscale-lr',
64 | action='store_true',
65 | help='automatically scale lr with the number of gpus')
66 | args = parser.parse_args()
67 | if 'LOCAL_RANK' not in os.environ:
68 | os.environ['LOCAL_RANK'] = str(args.local_rank)
69 |
70 | return args
71 |
72 |
73 | def main():
74 | args = parse_args()
75 |
76 | cfg = Config.fromfile(args.config)
77 |
78 | if args.cfg_options is not None:
79 | cfg.merge_from_dict(args.cfg_options)
80 |
81 | # set multi-process settings
82 | setup_multi_processes(cfg)
83 |
84 | # set cudnn_benchmark
85 | if cfg.get('cudnn_benchmark', False):
86 | torch.backends.cudnn.benchmark = True
87 | # update configs according to CLI args
88 | if args.work_dir is not None:
89 | cfg.work_dir = args.work_dir
90 | if args.resume_from is not None:
91 | cfg.resume_from = args.resume_from
92 | cfg.gpus = args.gpus
93 |
94 | if args.autoscale_lr:
95 | # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
96 | cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8
97 |
98 | # init distributed env first, since logger depends on the dist info.
99 | if args.launcher == 'none':
100 | distributed = False
101 | else:
102 | distributed = True
103 | init_dist(args.launcher, **cfg.dist_params)
104 |
105 | # create work_dir
106 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
107 | # dump config
108 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
109 | # init the logger before other steps
110 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
111 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
112 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
113 |
114 | # log env info
115 | env_info_dict = collect_env.collect_env()
116 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
117 | dash_line = '-' * 60 + '\n'
118 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
119 | dash_line)
120 |
121 | # log some basic info
122 | logger.info('Distributed training: {}'.format(distributed))
123 | logger.info('mmedit Version: {}'.format(__version__))
124 | logger.info('Config:\n{}'.format(cfg.text))
125 |
126 | # set random seeds
127 | seed = init_random_seed(args.seed)
128 | seed = seed + dist.get_rank() if args.diff_seed else seed
129 | logger.info('Set random seed to {}, deterministic: {}'.format(
130 | seed, args.deterministic))
131 | set_random_seed(seed, deterministic=args.deterministic)
132 | cfg.seed = seed
133 |
134 | model = build_model(
135 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
136 |
137 | datasets = [build_dataset(cfg.data.train)]
138 | if len(cfg.workflow) == 2:
139 | val_dataset = copy.deepcopy(cfg.data.val)
140 | val_dataset.pipeline = cfg.data.train.pipeline
141 | datasets.append(build_dataset(val_dataset))
142 | if cfg.checkpoint_config is not None:
143 | # save version, config file content and class names in
144 | # checkpoints as meta data
145 | cfg.checkpoint_config.meta = dict(
146 | mmedit_version=__version__,
147 | config=cfg.text,
148 | )
149 |
150 | # meta information
151 | meta = dict()
152 | if cfg.get('exp_name', None) is None:
153 | cfg['exp_name'] = osp.splitext(osp.basename(cfg.work_dir))[0]
154 | meta['exp_name'] = cfg.exp_name
155 | meta['mmedit Version'] = __version__
156 | meta['seed'] = seed
157 | meta['env_info'] = env_info
158 |
159 | # add an attribute for visualization convenience
160 | train_model(
161 | model,
162 | datasets,
163 | cfg,
164 | distributed=distributed,
165 | validate=(not args.no_validate),
166 | timestamp=timestamp,
167 | meta=meta)
168 |
169 |
170 |
171 | if __name__ == '__main__':
172 | main()
173 |
--------------------------------------------------------------------------------
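
For example, single-GPU training could be launched as
`python tools/train.py configs/derainers/ViMPNet/ViMPNet.py --work-dir work_dirs/vimpnet --seed 0 --deterministic`,
while `tools/dist_train.sh` launches the same script with `--launcher pytorch`
across multiple processes.
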
/mmedit/datasets/pipelines/generate_assistant.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
5 | from ..registry import PIPELINES
6 | from .utils import make_coord
7 |
8 |
9 | @PIPELINES.register_module()
10 | class GenerateHeatmap:
11 | """Generate heatmap from keypoint.
12 |
13 | Args:
14 | keypoint (str): Key of keypoint in dict.
15 | ori_size (int | Tuple[int]): Original image size of keypoint.
16 | target_size (int | Tuple[int]): Target size of heatmap.
17 | sigma (float): Sigma parameter of heatmap. Default: 1.0
18 | """
19 |
20 | def __init__(self, keypoint, ori_size, target_size, sigma=1.0):
21 | if isinstance(ori_size, int):
22 | ori_size = (ori_size, ori_size)
23 | else:
24 | ori_size = ori_size[:2]
25 | if isinstance(target_size, int):
26 | target_size = (target_size, target_size)
27 | else:
28 | target_size = target_size[:2]
29 | self.size_ratio = (target_size[0] / ori_size[0],
30 | target_size[1] / ori_size[1])
31 | self.keypoint = keypoint
32 | self.sigma = sigma
33 | self.target_size = target_size
34 | self.ori_size = ori_size
35 |
36 | def __call__(self, results):
37 | """Call function.
38 |
39 | Args:
40 | results (dict): A dict containing the necessary information and
41 | data for augmentation. Require keypoint.
42 |
43 | Returns:
44 | dict: A dict containing the processed data and information.
45 | Add 'heatmap'.
46 | """
47 | keypoint_list = [(keypoint[0] * self.size_ratio[0],
48 | keypoint[1] * self.size_ratio[1])
49 | for keypoint in results[self.keypoint]]
50 | heatmap_list = [
51 | self._generate_one_heatmap(keypoint) for keypoint in keypoint_list
52 | ]
53 | results['heatmap'] = np.stack(heatmap_list, axis=2)
54 | return results
55 |
56 | def _generate_one_heatmap(self, keypoint):
57 | """Generate One Heatmap.
58 |
59 | Args:
60 | landmark (Tuple[float]): Location of a landmark.
61 |
62 | results:
63 | heatmap (np.ndarray): A heatmap of landmark.
64 | """
65 | w, h = self.target_size
66 |
67 | x_range = np.arange(start=0, stop=w, dtype=int)
68 | y_range = np.arange(start=0, stop=h, dtype=int)
69 | grid_x, grid_y = np.meshgrid(x_range, y_range)
70 | dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2
71 | exponent = dist2 / 2.0 / self.sigma / self.sigma
72 | heatmap = np.exp(-exponent)
73 | return heatmap
74 |
75 | def __repr__(self):
76 | return (f'{self.__class__.__name__}, '
77 | f'keypoint={self.keypoint}, '
78 | f'ori_size={self.ori_size}, '
79 | f'target_size={self.target_size}, '
80 | f'sigma={self.sigma}')
81 |
82 |
83 | @PIPELINES.register_module()
84 | class GenerateCoordinateAndCell:
85 | """Generate coordinate and cell.
86 |
87 | Generate coordinate from the desired size of SR image.
88 | Train or val:
89 | 1. Generate coordinate from GT.
90 |         2. Reshape GT image to (Hg*Wg, 3) and transpose to (3, Hg*Wg),
91 |            where `Hg` and `Wg` represent the height and width of GT.
92 | Test:
93 | Generate coordinate from LQ and scale or target_size.
94 | Then generate cell from coordinate.
95 |
96 | Args:
97 |         sample_quantity (int): The number of coordinates to sample, used to
98 |             ensure that the GT tensors in a batch have the same dimensions.
99 |             Default: None.
100 | scale (float): Scale of upsampling.
101 | Default: None.
102 | target_size (tuple[int]): Size of target image.
103 | Default: None.
104 |
105 |     The priority for determining the size of the target image is:
106 |         1. results['gt'].shape[-2:]
107 |         2. results['lq'].shape[-2:] * scale
108 |         3. target_size
109 | """
110 |
111 | def __init__(self, sample_quantity=None, scale=None, target_size=None):
112 | self.sample_quantity = sample_quantity
113 | self.scale = scale
114 | self.target_size = target_size
115 |
116 | def __call__(self, results):
117 | """Call function.
118 |
119 | Args:
120 | results (dict): A dict containing the necessary information and
121 | data for augmentation.
122 | Require either in results:
123 | 1. 'lq' (tensor), whose shape is similar as (3, H, W).
124 | 2. 'gt' (tensor), whose shape is similar as (3, H, W).
125 | 3. None, the premise is
126 | self.target_size and len(self.target_size) >= 2.
127 |
128 | Returns:
129 | dict: A dict containing the processed data and information.
130 | Reshape 'gt' to (-1, 3) and transpose to (3, -1) if 'gt'
131 | in results.
132 | Add 'coord' and 'cell'.
133 | """
134 | # generate hr_coord (and hr_rgb)
135 | if 'gt' in results:
136 | crop_hr = results['gt']
137 | self.target_size = crop_hr.shape
138 | hr_rgb = crop_hr.contiguous().view(3, -1).permute(1, 0)
139 | results['gt'] = hr_rgb
140 | elif self.scale is not None and 'lq' in results:
141 | _, h_lr, w_lr = results['lq'].shape
142 | self.target_size = (round(h_lr * self.scale),
143 | round(w_lr * self.scale))
144 | else:
145 | assert self.target_size is not None
146 | assert len(self.target_size) >= 2
147 | hr_coord = make_coord(self.target_size[-2:])
148 |
149 | if self.sample_quantity is not None and 'gt' in results:
150 | sample_lst = np.random.choice(
151 | len(hr_coord), self.sample_quantity, replace=False)
152 | hr_coord = hr_coord[sample_lst]
153 | results['gt'] = results['gt'][sample_lst]
154 |
155 | # Preparations for cell decoding
156 | cell = torch.ones_like(hr_coord)
157 | cell[:, 0] *= 2 / self.target_size[-2]
158 | cell[:, 1] *= 2 / self.target_size[-1]
159 |
160 | results['coord'] = hr_coord
161 | results['cell'] = cell
162 |
163 | return results
164 |
165 | def __repr__(self):
166 | repr_str = self.__class__.__name__
167 |         repr_str += (f' sample_quantity={self.sample_quantity}, '
168 | f'scale={self.scale}, target_size={self.target_size}')
169 | return repr_str
170 |
--------------------------------------------------------------------------------
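
A sketch of the coordinate/cell generation above for a training sample; the
import path is assumed from the package layout:

    import torch

    gen = GenerateCoordinateAndCell(sample_quantity=1024)
    results = gen(dict(gt=torch.rand(3, 48, 48)))
    # 1024 coordinates are sampled from the 48x48 grid, 'gt' is reshaped to
    # (Hg*Wg, 3) and subsampled to match, and each cell row stores the
    # normalized pixel extent (2/48, 2/48).
    print(results['coord'].shape, results['cell'].shape, results['gt'].shape)
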
/mmedit/datasets/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | import platform
4 | import random
5 | from functools import partial
6 |
7 | import numpy as np
8 | import torch
9 | from mmcv.parallel import collate
10 | from mmcv.runner import get_dist_info
11 | from mmcv.utils import build_from_cfg
12 | from packaging import version
13 | from torch.utils.data import ConcatDataset, DataLoader
14 |
15 | from .dataset_wrappers import RepeatDataset
16 | from .registry import DATASETS
17 | from .samplers import DistributedSampler
18 |
19 | if platform.system() != 'Windows':
20 | # https://github.com/pytorch/pytorch/issues/973
21 | import resource
22 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
23 | base_soft_limit = rlimit[0]
24 | hard_limit = rlimit[1]
25 | soft_limit = min(max(4096, base_soft_limit), hard_limit)
26 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
27 |
28 |
29 | def _concat_dataset(cfg, default_args=None):
30 | """Concat datasets with different ann_file but the same type.
31 |
32 | Args:
33 | cfg (dict): The config of dataset.
34 | default_args (dict, optional): Default initialization arguments.
35 | Default: None.
36 |
37 | Returns:
38 | Dataset: The concatenated dataset.
39 | """
40 | ann_files = cfg['ann_file']
41 |
42 | datasets = []
43 | num_dset = len(ann_files)
44 | for i in range(num_dset):
45 | data_cfg = copy.deepcopy(cfg)
46 | data_cfg['ann_file'] = ann_files[i]
47 | datasets.append(build_dataset(data_cfg, default_args))
48 |
49 | return ConcatDataset(datasets)
50 |
51 |
52 | def build_dataset(cfg, default_args=None):
53 | """Build a dataset from config dict.
54 |
55 |     It supports a variety of dataset configs. If ``cfg`` is a sequence (list
56 |     or tuple), it will be a concatenated dataset of the datasets specified by
57 |     the sequence. If it is a ``RepeatDataset``, then it will repeat the
58 |     dataset ``cfg['dataset']`` for ``cfg['times']`` times. If the ``ann_file``
59 |     of the dataset is a sequence, then it will build a concatenated dataset
60 | with the same dataset type but different ``ann_file``.
61 |
62 | Args:
63 | cfg (dict): Config dict. It should at least contain the key "type".
64 | default_args (dict, optional): Default initialization arguments.
65 | Default: None.
66 |
67 | Returns:
68 | Dataset: The constructed dataset.
69 | """
70 | if isinstance(cfg, (list, tuple)):
71 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
72 | elif cfg['type'] == 'RepeatDataset':
73 | dataset = RepeatDataset(
74 | build_dataset(cfg['dataset'], default_args), cfg['times'])
75 | elif isinstance(cfg.get('ann_file'), (list, tuple)):
76 | dataset = _concat_dataset(cfg, default_args)
77 | else:
78 | dataset = build_from_cfg(cfg, DATASETS, default_args)
79 |
80 | return dataset
81 |
82 |
83 | def build_dataloader(dataset,
84 | samples_per_gpu,
85 | workers_per_gpu,
86 | num_gpus=1,
87 | dist=True,
88 | shuffle=True,
89 | seed=None,
90 | drop_last=False,
91 | pin_memory=True,
92 | persistent_workers=True,
93 | **kwargs):
94 | """Build PyTorch DataLoader.
95 |
96 | In distributed training, each GPU/process has a dataloader.
97 | In non-distributed training, there is only one dataloader for all GPUs.
98 |
99 | Args:
100 | dataset (:obj:`Dataset`): A PyTorch dataset.
101 | samples_per_gpu (int): Number of samples on each GPU, i.e.,
102 | batch size of each GPU.
103 | workers_per_gpu (int): How many subprocesses to use for data
104 | loading for each GPU.
105 | num_gpus (int): Number of GPUs. Only used in non-distributed
106 | training. Default: 1.
107 | dist (bool): Distributed training/test or not. Default: True.
108 | shuffle (bool): Whether to shuffle the data at every epoch.
109 | Default: True.
110 | seed (int | None): Seed to be used. Default: None.
111 | drop_last (bool): Whether to drop the last incomplete batch in epoch.
112 | Default: False
113 | pin_memory (bool): Whether to use pin_memory in DataLoader.
114 | Default: True
115 | persistent_workers (bool): If True, the data loader will not shutdown
116 | the worker processes after a dataset has been consumed once.
117 |             This keeps the workers' Dataset instances alive. The argument
118 |             only takes effect in PyTorch>=1.7.0.
119 | Default: True
120 | kwargs (dict, optional): Any keyword argument to be used to initialize
121 | DataLoader.
122 |
123 | Returns:
124 | DataLoader: A PyTorch dataloader.
125 | """
126 | rank, world_size = get_dist_info()
127 | if dist:
128 | sampler = DistributedSampler(
129 | dataset,
130 | world_size,
131 | rank,
132 | shuffle=shuffle,
133 | samples_per_gpu=samples_per_gpu,
134 | seed=seed)
135 | shuffle = False
136 | batch_size = samples_per_gpu
137 | num_workers = workers_per_gpu
138 | else:
139 | sampler = None
140 | batch_size = num_gpus * samples_per_gpu
141 | num_workers = num_gpus * workers_per_gpu
142 |
143 | init_fn = partial(
144 | worker_init_fn, num_workers=num_workers, rank=rank,
145 | seed=seed) if seed is not None else None
146 |
147 | if version.parse(torch.__version__) >= version.parse('1.7.0'):
148 | kwargs['persistent_workers'] = persistent_workers
149 |
150 | data_loader = DataLoader(
151 | dataset,
152 | batch_size=batch_size,
153 | sampler=sampler,
154 | num_workers=num_workers,
155 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
156 | pin_memory=pin_memory,
157 | shuffle=shuffle,
158 | worker_init_fn=init_fn,
159 | drop_last=drop_last,
160 | **kwargs)
161 |
162 | return data_loader
163 |
164 |
165 | def worker_init_fn(worker_id, num_workers, rank, seed):
166 | """Function to initialize each worker.
167 |
168 |     The seed of each worker equals
169 |     ``num_workers * rank + worker_id + user_seed``.
170 |
171 | Args:
172 | worker_id (int): Id for each worker.
173 | num_workers (int): Number of workers.
174 | rank (int): Rank in distributed training.
175 | seed (int): Random seed.
176 | """
177 |
178 | worker_seed = num_workers * rank + worker_id + seed
179 | np.random.seed(worker_seed)
180 | random.seed(worker_seed)
181 | torch.manual_seed(worker_seed)
182 |
--------------------------------------------------------------------------------
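
A sketch combining the two builders above for a non-distributed test loader,
assuming `cfg` is a parsed config:

    dataset = build_dataset(cfg.data.test)
    loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=42)  # each worker seeded as num_workers * rank + worker_id + 42
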
/mmedit/models/derainers/derainer.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import os.path as osp
3 | import torch
4 | import mmcv
5 | from mmcv.runner import auto_fp16
6 | import torch.nn as nn
7 | from mmedit.core import psnr, ssim, tensor2img
8 | from ..base import BaseModel
9 | from ..builder import build_backbone, build_loss
10 | from ..registry import MODELS
11 |
12 |
13 |
14 | @MODELS.register_module()
15 | class Derainer(BaseModel):
16 |
17 | allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}
18 |
19 | def __init__(self,
20 | generator,
21 | pixel_loss,
22 |
23 | train_cfg=None,
24 | test_cfg=None,
25 | pretrained=None):
26 | super().__init__()
27 |
28 | self.train_cfg = train_cfg
29 | self.test_cfg = test_cfg
30 |
31 | self.fp16_enabled = False
32 |
33 | # generator
34 | self.generator = build_backbone(generator)
35 |
36 | self.init_weights(pretrained)
37 |
38 | # loss
39 | self.pixel_loss = build_loss(pixel_loss)
40 | self.mask_loss = nn.BCELoss()
41 | self.streak_loss = nn.MSELoss()
42 | self.l1_loss = nn.L1Loss()
43 |
44 |
45 | def init_weights(self, pretrained=None):
46 | """Init weights for models.
47 |
48 | Args:
49 | pretrained (str, optional): Path for pretrained weights. If given
50 | None, pretrained weights will not be loaded. Defaults to None.
51 | """
52 | self.generator.init_weights(pretrained)
53 |
54 | @auto_fp16(apply_to=('lq', ))
55 | def forward(self, lq, gt=None, mask=None, drop=None, streak=None, test_mode=False, **kwargs):
56 | """Forward function.
57 |
58 | Args:
59 | lq (Tensor): Input lq images.
60 | gt (Tensor): Ground-truth image. Default: None.
61 | test_mode (bool): Whether in test mode or not. Default: False.
62 | kwargs (dict): Other arguments.
63 | """
64 |
65 | if test_mode:
66 | return self.forward_test(lq, gt, **kwargs)
67 |
68 | return self.forward_train(lq, gt, mask, drop, streak)
69 |
70 | def forward_train(self, lq, gt, mask, drop, streak):
71 | """Training forward function.
72 |
73 | Args:
74 | lq (Tensor): LQ Tensor with shape (n, c, h, w).
75 |             gt (Tensor): GT Tensor with shape (n, c, h, w).
76 |             mask, drop, streak (Tensor): Auxiliary supervision targets.
77 |         Returns:
78 |             dict: Output losses and results.
79 |         """
80 | losses = dict()
81 |
82 | output, outputmask, streakmask, finalresult = self.generator(lq)
83 |
84 | loss_pix = self.pixel_loss(output, drop)
85 | loss_mask = 0.8 * self.mask_loss(outputmask, mask)
86 | loss_pix2 = 1.2 * self.pixel_loss(finalresult, gt)
87 | hole_loss = self.l1_loss(finalresult * mask, gt * mask)
88 |         hole_loss = 0.5 * hole_loss / (torch.mean(mask) + 1e-8)
89 |         loss_streak = 0.5 * self.streak_loss(streakmask, streak)
90 |
91 |
92 | losses['loss_pix'] = loss_pix
93 | losses['loss_mask'] = loss_mask
94 | losses['loss_reconstruct'] = loss_pix2
95 | losses['loss_hole'] = hole_loss
96 | losses['loss_streak'] = loss_streak
97 |
98 |
99 | outputs = dict(
100 | losses=losses,
101 | num_samples=len(gt.data),
102 | results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
103 | return outputs
104 |
105 | def evaluate(self, output, gt):
106 | """Evaluation function.
107 |
108 | Args:
109 | output (Tensor): Model output with shape (n, c, h, w).
110 | gt (Tensor): GT Tensor with shape (n, c, h, w).
111 |
112 | Returns:
113 | dict: Evaluation results.
114 | """
115 | crop_border = self.test_cfg.crop_border
116 |
117 | output = tensor2img(output)
118 | gt = tensor2img(gt)
119 |
120 | eval_result = dict()
121 | for metric in self.test_cfg.metrics:
122 | eval_result[metric] = self.allowed_metrics[metric](output, gt,
123 | crop_border)
124 | return eval_result
125 |
126 | def forward_test(self,
127 | lq,
128 | gt=None,
129 | meta=None,
130 | save_image=False,
131 | save_path=None,
132 | iteration=None):
133 | """Testing forward function.
134 |
135 | Args:
136 | lq (Tensor): LQ Tensor with shape (n, c, h, w).
137 |             gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
138 |             meta (list[dict]): Meta information of samples; 'lq_path' is
139 |                 used for the saved image name. Default: None.
140 |             save_image (bool): Whether to save image. Default: False.
141 |             save_path (str): Path to save image. Default: None.
142 |             iteration (int): Iteration for the saving image name. Default: None.
143 | Returns:
144 | dict: Output results.
145 | """
146 | output = self.generator(lq)
147 | if self.test_cfg is not None and self.test_cfg.get('metrics', None):
148 | assert gt is not None, (
149 | 'evaluation with metrics must have gt images.')
150 | results = dict(eval_result=self.evaluate(output, gt))
151 | else:
152 | results = dict(lq=lq.cpu(), output=output.cpu())
153 | if gt is not None:
154 | results['gt'] = gt.cpu()
155 |
156 | # save image
157 | if save_image:
158 | lq_path = meta[0]['lq_path']
159 | folder_name = osp.splitext(osp.basename(lq_path))[0]
160 | if isinstance(iteration, numbers.Number):
161 | save_path = osp.join(save_path, folder_name,
162 | f'{folder_name}-{iteration + 1:06d}.png')
163 | elif iteration is None:
164 | save_path = osp.join(save_path, f'{folder_name}.png')
165 | else:
166 | raise ValueError('iteration should be number or None, '
167 | f'but got {type(iteration)}')
168 | mmcv.imwrite(tensor2img(output), save_path)
169 |
170 | return results
171 |
172 | def forward_dummy(self, img):
173 | """Used for computing network FLOPs.
174 |
175 | Args:
176 | img (Tensor): Input image.
177 |
178 | Returns:
179 | Tensor: Output image.
180 | """
181 | out = self.generator(img)
182 | return out
183 |
184 | def train_step(self, data_batch, optimizer):
185 | """Train step.
186 |
187 | Args:
188 | data_batch (dict): A batch of data.
189 |             optimizer (dict): Dict of optimizers; 'generator' is used.
190 |
191 | Returns:
192 | dict: Returned output.
193 | """
194 | outputs = self(**data_batch, test_mode=False)
195 | loss, log_vars = self.parse_losses(outputs.pop('losses'))
196 |
197 | # optimize
198 | optimizer['generator'].zero_grad()
199 | loss.backward()
200 | optimizer['generator'].step()
201 |
202 | outputs.update({'log_vars': log_vars})
203 | return outputs
204 |
205 | def val_step(self, data_batch, **kwargs):
206 | """Validation step.
207 |
208 | Args:
209 | data_batch (dict): A batch of data.
210 | kwargs (dict): Other arguments for ``val_step``.
211 |
212 | Returns:
213 | dict: Returned output.
214 | """
215 | output = self.forward_test(**data_batch, **kwargs)
216 | return output
217 |
218 |
--------------------------------------------------------------------------------
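For orientation, a minimal sketch of how `Derainer` is built from a config dict via the MODELS registry. The 'DerainNet' backbone name is hypothetical, and `build_model` is assumed to follow the usual mmedit builder signature from mmedit/models/builder.py; treat this as a sketch, not the repository's exact config:

    import mmcv
    from mmedit.models import build_model

    model = build_model(
        dict(
            type='Derainer',
            generator=dict(type='DerainNet'),  # hypothetical backbone name
            pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0)),
        train_cfg=None,
        # ConfigDict supports the attribute access used in Derainer.evaluate()
        test_cfg=mmcv.ConfigDict(metrics=['PSNR', 'SSIM'], crop_border=0))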
/mmedit/models/losses/pixelwise_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from ..registry import LOSSES
7 | from .utils import masked_loss
8 |
9 | _reduction_modes = ['none', 'mean', 'sum']
10 |
11 |
12 | @masked_loss
13 | def l1_loss(pred, target):
14 | """L1 loss.
15 |
16 | Args:
17 | pred (Tensor): Prediction Tensor with shape (n, c, h, w).
18 |         target (Tensor): Target Tensor with shape (n, c, h, w).
19 |
20 | Returns:
21 | Tensor: Calculated L1 loss.
22 | """
23 | return F.l1_loss(pred, target, reduction='none')
24 |
25 |
26 | @masked_loss
27 | def mse_loss(pred, target):
28 | """MSE loss.
29 |
30 | Args:
31 | pred (Tensor): Prediction Tensor with shape (n, c, h, w).
32 |         target (Tensor): Target Tensor with shape (n, c, h, w).
33 |
34 | Returns:
35 | Tensor: Calculated MSE loss.
36 | """
37 | return F.mse_loss(pred, target, reduction='none')
38 |
39 |
40 | @masked_loss
41 | def charbonnier_loss(pred, target, eps=1e-12):
42 | """Charbonnier loss.
43 |
44 | Args:
45 | pred (Tensor): Prediction Tensor with shape (n, c, h, w).
46 |         target (Tensor): Target Tensor with shape (n, c, h, w).
47 |         eps (float): A value used to control the curvature near zero.
48 | Returns:
49 | Tensor: Calculated Charbonnier loss.
50 | """
51 | return torch.sqrt((pred - target)**2 + eps)
52 |
53 |
54 | @LOSSES.register_module()
55 | class L1Loss(nn.Module):
56 | """L1 (mean absolute error, MAE) loss.
57 |
58 | Args:
59 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
60 | reduction (str): Specifies the reduction to apply to the output.
61 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
62 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
63 |             argument only takes effect when `reduction` is 'mean' and `weight`
64 |             (argument of `forward()`) is not None. It first reduces the loss
65 |             per sample with 'mean', then averages over all samples.
66 | Default: False.
67 | """
68 |
69 | def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
70 | super().__init__()
71 | if reduction not in ['none', 'mean', 'sum']:
72 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
73 | f'Supported ones are: {_reduction_modes}')
74 |
75 | self.loss_weight = loss_weight
76 | self.reduction = reduction
77 | self.sample_wise = sample_wise
78 |
79 | def forward(self, pred, target, weight=None, **kwargs):
80 | """Forward Function.
81 |
82 | Args:
83 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
84 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
85 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
86 | weights. Default: None.
87 | """
88 | return self.loss_weight * l1_loss(
89 | pred,
90 | target,
91 | weight,
92 | reduction=self.reduction,
93 | sample_wise=self.sample_wise)
94 |
95 |
96 | @LOSSES.register_module()
97 | class MSELoss(nn.Module):
98 | """MSE (L2) loss.
99 |
100 | Args:
101 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
102 | reduction (str): Specifies the reduction to apply to the output.
103 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
104 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
105 |             argument only takes effect when `reduction` is 'mean' and `weight`
106 |             (argument of `forward()`) is not None. It first reduces the loss
107 |             per sample with 'mean', then averages over all samples.
108 | Default: False.
109 | """
110 |
111 | def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
112 | super().__init__()
113 | if reduction not in ['none', 'mean', 'sum']:
114 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
115 | f'Supported ones are: {_reduction_modes}')
116 |
117 | self.loss_weight = loss_weight
118 | self.reduction = reduction
119 | self.sample_wise = sample_wise
120 |
121 | def forward(self, pred, target, weight=None, **kwargs):
122 | """Forward Function.
123 |
124 | Args:
125 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
126 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
127 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
128 | weights. Default: None.
129 | """
130 | return self.loss_weight * mse_loss(
131 | pred,
132 | target,
133 | weight,
134 | reduction=self.reduction,
135 | sample_wise=self.sample_wise)
136 |
137 |
138 | @LOSSES.register_module()
139 | class CharbonnierLoss(nn.Module):
140 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable variant
141 | of L1Loss).
142 |
143 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
144 | Super-Resolution".
145 |
146 | Args:
147 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
148 | reduction (str): Specifies the reduction to apply to the output.
149 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
150 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
151 |             argument only takes effect when `reduction` is 'mean' and `weight`
152 |             (argument of `forward()`) is not None. It first reduces the loss
153 |             per sample with 'mean', then averages over all samples.
154 | Default: False.
155 | eps (float): A value used to control the curvature near zero.
156 | Default: 1e-12.
157 | """
158 |
159 | def __init__(self,
160 | loss_weight=1.0,
161 | reduction='mean',
162 | sample_wise=False,
163 | eps=1e-12):
164 | super().__init__()
165 | if reduction not in ['none', 'mean', 'sum']:
166 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
167 | f'Supported ones are: {_reduction_modes}')
168 |
169 | self.loss_weight = loss_weight
170 | self.reduction = reduction
171 | self.sample_wise = sample_wise
172 | self.eps = eps
173 |
174 | def forward(self, pred, target, weight=None, **kwargs):
175 | """Forward Function.
176 |
177 | Args:
178 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
179 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
180 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
181 | weights. Default: None.
182 | """
183 | return self.loss_weight * charbonnier_loss(
184 | pred,
185 | target,
186 | weight,
187 | eps=self.eps,
188 | reduction=self.reduction,
189 | sample_wise=self.sample_wise)
190 |
191 |
192 | @LOSSES.register_module()
193 | class MaskedTVLoss(L1Loss):
194 | """Masked TV loss.
195 |
196 | Args:
197 | loss_weight (float, optional): Loss weight. Defaults to 1.0.
198 | """
199 |
200 | def __init__(self, loss_weight=1.0):
201 | super().__init__(loss_weight=loss_weight)
202 |
203 | def forward(self, pred, mask=None):
204 | """Forward function.
205 |
206 | Args:
207 | pred (torch.Tensor): Tensor with shape of (n, c, h, w).
208 | mask (torch.Tensor, optional): Tensor with shape of (n, 1, h, w).
209 | Defaults to None.
210 |
211 | Returns:
212 |             torch.Tensor: Calculated masked TV loss.
213 | """
214 |         mask_y = None if mask is None else mask[:, :, :-1, :]
215 |         mask_x = None if mask is None else mask[:, :, :, :-1]
216 |         y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=mask_y)
217 |         x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=mask_x)
218 |
219 | loss = x_diff + y_diff
220 |
221 | return loss
222 |
--------------------------------------------------------------------------------
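A short usage sketch of the losses above (assuming they are re-exported by losses/__init__.py and that the `masked_loss` decorator broadcasts a single-channel `weight` over the color channels, as in upstream mmedit):

    import torch
    from mmedit.models.losses import CharbonnierLoss, MaskedTVLoss

    pred = torch.rand(2, 3, 8, 8)
    target = torch.rand(2, 3, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float()

    charb = CharbonnierLoss(loss_weight=1.0, reduction='mean', eps=1e-12)
    loss = charb(pred, target, weight=mask)  # Charbonnier over masked pixels

    tv = MaskedTVLoss(loss_weight=0.1)
    smooth = tv(pred, mask=mask)             # TV smoothness inside the mask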
/mmedit/models/backbones/derain_backbones/derain_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from mmcv.cnn import ConvModule
6 | from mmcv.runner import load_checkpoint
7 |
8 | from mmedit.models.common import (ResidualBlockNoBN, flow_warp,
9 |                                   make_layer)
10 | from mmedit.models.registry import BACKBONES
11 | from mmedit.utils import get_root_logger
12 |
13 |
14 |
15 |
16 | class ResidualBlocksWithInputConv(nn.Module):
17 | """Residual blocks with a convolution in front.
18 |
19 | Args:
20 | in_channels (int): Number of input channels of the first conv.
21 | out_channels (int): Number of channels of the residual blocks.
22 | Default: 64.
23 | num_blocks (int): Number of residual blocks. Default: 30.
24 | """
25 |
26 | def __init__(self, in_channels, out_channels=64, num_blocks=30):
27 | super().__init__()
28 |
29 | main = []
30 |
31 | # a convolution used to match the channels of the residual blocks
32 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
33 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
34 |
35 | # residual blocks
36 | main.append(
37 | make_layer(
38 | ResidualBlockNoBN, num_blocks, mid_channels=out_channels))
39 |
40 | self.main = nn.Sequential(*main)
41 |
42 | def forward(self, feat):
43 | """Forward function for ResidualBlocksWithInputConv.
44 |
45 | Args:
46 | feat (Tensor): Input feature with shape (n, in_channels, h, w)
47 |
48 | Returns:
49 | Tensor: Output feature with shape (n, out_channels, h, w)
50 | """
51 | return self.main(feat)
52 |
53 |
54 | class SPyNet(nn.Module):
55 | """SPyNet network structure.
56 |
57 |     The difference from the SPyNet in [tof.py] is that
58 |     1. more SPyNetBasicModules are used in this version, and
59 |     2. no batch normalization is used in this version.
60 |
61 | Paper:
62 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
63 |
64 | Args:
65 | pretrained (str): path for pre-trained SPyNet. Default: None.
66 | """
67 |
68 | def __init__(self, pretrained):
69 | super().__init__()
70 |
71 | self.basic_module = nn.ModuleList(
72 | [SPyNetBasicModule() for _ in range(6)])
73 |
74 | if isinstance(pretrained, str):
75 | logger = get_root_logger()
76 | load_checkpoint(self, pretrained, strict=True, logger=logger)
77 | elif pretrained is not None:
78 | raise TypeError('[pretrained] should be str or None, '
79 | f'but got {type(pretrained)}.')
80 |
81 | self.register_buffer(
82 | 'mean',
83 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
84 | self.register_buffer(
85 | 'std',
86 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
87 |
88 | def compute_flow(self, ref, supp):
89 | """Compute flow from ref to supp.
90 |
91 | Note that in this function, the images are already resized to a
92 | multiple of 32.
93 |
94 | Args:
95 | ref (Tensor): Reference image with shape of (n, 3, h, w).
96 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
97 |
98 | Returns:
99 | Tensor: Estimated optical flow: (n, 2, h, w).
100 | """
101 | n, _, h, w = ref.size()
102 |
103 | # normalize the input images
104 | ref = [(ref - self.mean) / self.std]
105 | supp = [(supp - self.mean) / self.std]
106 |
107 | # generate downsampled frames
108 | for level in range(5):
109 | ref.append(
110 | F.avg_pool2d(
111 | input=ref[-1],
112 | kernel_size=2,
113 | stride=2,
114 | count_include_pad=False))
115 | supp.append(
116 | F.avg_pool2d(
117 | input=supp[-1],
118 | kernel_size=2,
119 | stride=2,
120 | count_include_pad=False))
121 | ref = ref[::-1]
122 | supp = supp[::-1]
123 |
124 | # flow computation
125 | flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
126 | for level in range(len(ref)):
127 | if level == 0:
128 | flow_up = flow
129 | else:
130 | flow_up = F.interpolate(
131 | input=flow,
132 | scale_factor=2,
133 | mode='bilinear',
134 | align_corners=True) * 2.0
135 |
136 | # add the residue to the upsampled flow
137 | flow = flow_up + self.basic_module[level](
138 | torch.cat([
139 | ref[level],
140 | flow_warp(
141 | supp[level],
142 | flow_up.permute(0, 2, 3, 1),
143 | padding_mode='border'), flow_up
144 | ], 1))
145 |
146 | return flow
147 |
148 | def forward(self, ref, supp):
149 | """Forward function of SPyNet.
150 |
151 | This function computes the optical flow from ref to supp.
152 |
153 | Args:
154 | ref (Tensor): Reference image with shape of (n, 3, h, w).
155 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
156 |
157 | Returns:
158 | Tensor: Estimated optical flow: (n, 2, h, w).
159 | """
160 |
161 | # upsize to a multiple of 32
162 | h, w = ref.shape[2:4]
163 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
164 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
165 | ref = F.interpolate(
166 | input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
167 | supp = F.interpolate(
168 | input=supp,
169 | size=(h_up, w_up),
170 | mode='bilinear',
171 | align_corners=False)
172 |
173 | # compute flow, and resize back to the original resolution
174 | flow = F.interpolate(
175 | input=self.compute_flow(ref, supp),
176 | size=(h, w),
177 | mode='bilinear',
178 | align_corners=False)
179 |
180 | # adjust the flow values
181 | flow[:, 0, :, :] *= float(w) / float(w_up)
182 | flow[:, 1, :, :] *= float(h) / float(h_up)
183 |
184 | return flow
185 |
186 |
187 | class SPyNetBasicModule(nn.Module):
188 | """Basic Module for SPyNet.
189 |
190 | Paper:
191 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
192 | """
193 |
194 | def __init__(self):
195 | super().__init__()
196 |
197 | self.basic_module = nn.Sequential(
198 | ConvModule(
199 | in_channels=8,
200 | out_channels=32,
201 | kernel_size=7,
202 | stride=1,
203 | padding=3,
204 | norm_cfg=None,
205 | act_cfg=dict(type='ReLU')),
206 | ConvModule(
207 | in_channels=32,
208 | out_channels=64,
209 | kernel_size=7,
210 | stride=1,
211 | padding=3,
212 | norm_cfg=None,
213 | act_cfg=dict(type='ReLU')),
214 | ConvModule(
215 | in_channels=64,
216 | out_channels=32,
217 | kernel_size=7,
218 | stride=1,
219 | padding=3,
220 | norm_cfg=None,
221 | act_cfg=dict(type='ReLU')),
222 | ConvModule(
223 | in_channels=32,
224 | out_channels=16,
225 | kernel_size=7,
226 | stride=1,
227 | padding=3,
228 | norm_cfg=None,
229 | act_cfg=dict(type='ReLU')),
230 | ConvModule(
231 | in_channels=16,
232 | out_channels=2,
233 | kernel_size=7,
234 | stride=1,
235 | padding=3,
236 | norm_cfg=None,
237 | act_cfg=None))
238 |
239 | def forward(self, tensor_input):
240 | """
241 | Args:
242 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
243 | 8 channels contain:
244 | [reference image (3), neighbor image (3), initial flow (2)].
245 |
246 | Returns:
247 | Tensor: Refined flow with shape (b, 2, h, w)
248 | """
249 | return self.basic_module(tensor_input)
250 |
--------------------------------------------------------------------------------
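To make the resizing logic in `SPyNet.forward` concrete: the pyramid has six levels (five 2x downsamplings), so inputs are first resized to multiples of 32 = 2**5, and the estimated flow is resized back and rescaled per channel. A worked example of the size arithmetic for a 180x320 frame:

    h, w = 180, 320
    w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)  # 320 (already a multiple)
    h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)  # 192
    # flow is estimated at (192, 320) and interpolated back to (180, 320);
    # the channel-wise rescaling then undoes the stretch:
    #   flow[:, 0] *= 320 / 320  (x displacements, unchanged)
    #   flow[:, 1] *= 180 / 192  (y displacements, shrunk with the height)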
/mmedit/apis/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | import pickle
4 | import shutil
5 | import tempfile
6 |
7 | import mmcv
8 | import torch
9 | import torch.distributed as dist
10 | from mmcv.runner import get_dist_info
11 |
12 |
13 | def single_gpu_test(model,
14 | data_loader,
15 | save_image=False,
16 | save_path=None,
17 | iteration=None):
18 | """Test model with a single gpu.
19 |
20 |     This method tests the model with a single gpu and displays a progress bar.
21 |
22 | Args:
23 | model (nn.Module): Model to be tested.
24 |         data_loader (DataLoader): PyTorch data loader.
25 |         save_image (bool): Whether to save image. Default: False.
26 | save_path (str): The path to save image. Default: None.
27 | iteration (int): Iteration number. It is used for the save image name.
28 | Default: None.
29 |
30 | Returns:
31 | list: The prediction results.
32 | """
33 | if save_image and save_path is None:
34 | raise ValueError(
35 | "When 'save_image' is True, you should also set 'save_path'.")
36 |
37 | model.eval()
38 | results = []
39 | dataset = data_loader.dataset
40 | prog_bar = mmcv.ProgressBar(len(dataset))
41 | for data in data_loader:
42 | with torch.no_grad():
43 | result = model(
44 | test_mode=True,
45 | save_image=save_image,
46 | save_path=save_path,
47 | iteration=iteration,
48 | **data)
49 | results.append(result)
50 |
51 | # get batch size
52 | for _, v in data.items():
53 | if isinstance(v, torch.Tensor):
54 | batch_size = v.size(0)
55 | break
56 | for _ in range(batch_size):
57 | prog_bar.update()
58 | return results
59 |
60 |
61 | def multi_gpu_test(model,
62 | data_loader,
63 | tmpdir=None,
64 | gpu_collect=False,
65 | save_image=False,
66 | save_path=None,
67 | iteration=None,
68 | empty_cache=False):
69 | """Test model with multiple gpus.
70 |
71 |     This method tests the model with multiple gpus and collects the results
72 |     under two different modes: gpu and cpu. By setting 'gpu_collect=True',
73 |     it encodes results to gpu tensors and uses gpu communication for
74 |     results collection. In cpu mode, it saves the results on different
75 |     gpus to 'tmpdir' and collects them with the rank 0 worker.
76 |
77 | Args:
78 | model (nn.Module): Model to be tested.
79 |         data_loader (DataLoader): PyTorch data loader.
80 | tmpdir (str): Path of directory to save the temporary results from
81 | different gpus under cpu mode.
82 | gpu_collect (bool): Option to use either gpu or cpu to collect results.
83 |         save_image (bool): Whether to save image. Default: False.
84 |         save_path (str): The path to save image. Default: None.
85 |         iteration (int): Iteration number. It is used for the save image name.
86 |             Default: None.
87 |         empty_cache (bool): Whether to empty cuda cache per iteration. Default: False.
88 |
89 | Returns:
90 | list: The prediction results.
91 | """
92 |
93 | if save_image and save_path is None:
94 | raise ValueError(
95 | "When 'save_image' is True, you should also set 'save_path'.")
96 | model.eval()
97 | results = []
98 | dataset = data_loader.dataset
99 | rank, world_size = get_dist_info()
100 | if rank == 0:
101 | prog_bar = mmcv.ProgressBar(len(dataset))
102 | for data in data_loader:
103 | with torch.no_grad():
104 | result = model(
105 | test_mode=True,
106 | save_image=save_image,
107 | save_path=save_path,
108 | iteration=iteration,
109 | **data)
110 | results.append(result)
111 | if empty_cache:
112 | torch.cuda.empty_cache()
113 | if rank == 0:
114 | # get batch size
115 | for _, v in data.items():
116 | if isinstance(v, torch.Tensor):
117 | batch_size = v.size(0)
118 | break
119 | for _ in range(batch_size * world_size):
120 | prog_bar.update()
121 | # collect results from all ranks
122 | if gpu_collect:
123 | results = collect_results_gpu(results, len(dataset))
124 | else:
125 | results = collect_results_cpu(results, len(dataset), tmpdir)
126 | return results
127 |
128 |
129 | def collect_results_cpu(result_part, size, tmpdir=None):
130 | """Collect results in cpu mode.
131 |
132 | It saves the results on different gpus to 'tmpdir' and collects
133 | them by the rank 0 worker.
134 |
135 | Args:
136 | result_part (list): Results to be collected
137 | size (int): Result size.
138 | tmpdir (str): Path of directory to save the temporary results from
139 | different gpus under cpu mode. Default: None
140 |
141 | Returns:
142 | list: Ordered results.
143 | """
144 |
145 | rank, world_size = get_dist_info()
146 | # create a tmp dir if it is not specified
147 | if tmpdir is None:
148 | MAX_LEN = 512
149 |         # 32 is the ASCII code of a space, used to pad dir_tensor
150 | dir_tensor = torch.full((MAX_LEN, ),
151 | 32,
152 | dtype=torch.uint8,
153 | device='cuda')
154 | if rank == 0:
155 | mmcv.mkdir_or_exist('.dist_test')
156 | tmpdir = tempfile.mkdtemp(dir='.dist_test')
157 | tmpdir = torch.tensor(
158 | bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
159 | dir_tensor[:len(tmpdir)] = tmpdir
160 | dist.broadcast(dir_tensor, 0)
161 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
162 | else:
163 | mmcv.mkdir_or_exist(tmpdir)
164 | # synchronizes all processes to make sure tmpdir exist
165 | dist.barrier()
166 | # dump the part result to the dir
167 | mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
168 | # synchronizes all processes for loading pickle file
169 | dist.barrier()
170 | # collect all parts
171 | if rank != 0:
172 | return None
173 |
174 | # load results of all parts from tmp dir
175 | part_list = []
176 | for i in range(world_size):
177 | part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
178 | part_list.append(mmcv.load(part_file))
179 | # sort the results
180 | ordered_results = []
181 | for res in zip(*part_list):
182 | ordered_results.extend(list(res))
183 | # the dataloader may pad some samples
184 | ordered_results = ordered_results[:size]
185 | # remove tmp dir
186 | shutil.rmtree(tmpdir)
187 | return ordered_results
188 |
189 |
190 | def collect_results_gpu(result_part, size):
191 | """Collect results in gpu mode.
192 |
193 | It encodes results to gpu tensors and use gpu communication for results
194 | collection.
195 |
196 | Args:
197 | result_part (list): Results to be collected
198 | size (int): Result size.
199 |
200 | Returns:
201 | list: Ordered results.
202 | """
203 |
204 | rank, world_size = get_dist_info()
205 | # dump result part to tensor with pickle
206 | part_tensor = torch.tensor(
207 | bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
208 | # gather all result part tensor shape
209 | shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
210 | shape_list = [shape_tensor.clone() for _ in range(world_size)]
211 | dist.all_gather(shape_list, shape_tensor)
212 | # padding result part tensor to max length
213 | shape_max = torch.tensor(shape_list).max()
214 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
215 | part_send[:shape_tensor[0]] = part_tensor
216 | part_recv_list = [
217 | part_tensor.new_zeros(shape_max) for _ in range(world_size)
218 | ]
219 | # gather all result part
220 | dist.all_gather(part_recv_list, part_send)
221 |
222 | if rank != 0:
223 | return None
224 |
225 | part_list = []
226 | for recv, shape in zip(part_recv_list, shape_list):
227 | part_list.append(pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
228 | # sort the results
229 | ordered_results = []
230 | for res in zip(*part_list):
231 | ordered_results.extend(list(res))
232 | # the dataloader may pad some samples
233 | ordered_results = ordered_results[:size]
234 | return ordered_results
235 |
--------------------------------------------------------------------------------
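Both collectors restore dataset order with `zip(*part_list)`. This works because the default distributed sampler deals indices round-robin across ranks (rank r gets samples r, r + world_size, r + 2 * world_size, ...), so interleaving the per-rank lists recovers the original order; the trailing `[:size]` slice then trims any padding the sampler added. A plain-Python sketch of that reordering, assuming two ranks and six samples:

    # rank 0 saw samples 0, 2, 4; rank 1 saw samples 1, 3, 5
    part_list = [['s0', 's2', 's4'], ['s1', 's3', 's5']]
    ordered = []
    for res in zip(*part_list):
        ordered.extend(res)
    assert ordered == ['s0', 's1', 's2', 's3', 's4', 's5']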
/mmedit/datasets/pipelines/formating.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from collections.abc import Sequence
3 |
4 | import mmcv
5 | import numpy as np
6 | import torch
7 | from mmcv.parallel import DataContainer as DC
8 | from torch.nn import functional as F
9 |
10 | from ..registry import PIPELINES
11 |
12 |
13 | def to_tensor(data):
14 | """Convert objects of various python types to :obj:`torch.Tensor`.
15 |
16 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
17 | :class:`Sequence`, :class:`int` and :class:`float`.
18 | """
19 | if isinstance(data, torch.Tensor):
20 | return data
21 | if isinstance(data, np.ndarray):
22 | return torch.from_numpy(data)
23 | if isinstance(data, Sequence) and not mmcv.is_str(data):
24 | return torch.tensor(data)
25 | if isinstance(data, int):
26 | return torch.LongTensor([data])
27 | if isinstance(data, float):
28 | return torch.FloatTensor([data])
29 |
30 | raise TypeError(f'type {type(data)} cannot be converted to tensor.')
31 |
32 |
33 | @PIPELINES.register_module()
34 | class ToTensor:
35 | """Convert some values in results dict to `torch.Tensor` type in data
36 | loader pipeline.
37 |
38 | Args:
39 | keys (Sequence[str]): Required keys to be converted.
40 | """
41 |
42 | def __init__(self, keys):
43 | self.keys = keys
44 |
45 | def __call__(self, results):
46 | """Call function.
47 |
48 | Args:
49 | results (dict): A dict containing the necessary information and
50 | data for augmentation.
51 |
52 | Returns:
53 | dict: A dict containing the processed data and information.
54 | """
55 | for key in self.keys:
56 | results[key] = to_tensor(results[key])
57 | return results
58 |
59 | def __repr__(self):
60 | return self.__class__.__name__ + f'(keys={self.keys})'
61 |
62 |
63 | @PIPELINES.register_module()
64 | class ImageToTensor:
65 | """Convert image type to `torch.Tensor` type.
66 |
67 | Args:
68 | keys (Sequence[str]): Required keys to be converted.
69 |         to_float32 (bool): Whether to convert the numpy image array to
70 |             np.float32 before it is converted to a tensor. Default: True.
71 | """
72 |
73 | def __init__(self, keys, to_float32=True):
74 | self.keys = keys
75 | self.to_float32 = to_float32
76 |
77 | def __call__(self, results):
78 | """Call function.
79 |
80 | Args:
81 | results (dict): A dict containing the necessary information and
82 | data for augmentation.
83 |
84 | Returns:
85 | dict: A dict containing the processed data and information.
86 | """
87 | for key in self.keys:
88 | # deal with gray scale img: expand a color channel
89 | if len(results[key].shape) == 2:
90 | results[key] = results[key][..., None]
91 | if self.to_float32 and not isinstance(results[key], np.float32):
92 | results[key] = results[key].astype(np.float32)
93 | results[key] = to_tensor(results[key].transpose(2, 0, 1))
94 | return results
95 |
96 | def __repr__(self):
97 | return self.__class__.__name__ + (
98 | f'(keys={self.keys}, to_float32={self.to_float32})')
99 |
100 |
101 | @PIPELINES.register_module()
102 | class FramesToTensor(ImageToTensor):
103 | """Convert frames type to `torch.Tensor` type.
104 |
105 | It accepts a list of frames, converts each to `torch.Tensor` type and then
106 | concatenates in a new dimension (dim=0).
107 |
108 | Args:
109 | keys (Sequence[str]): Required keys to be converted.
110 |         to_float32 (bool): Whether to convert the numpy image array to
111 |             np.float32 before it is converted to a tensor. Default: True.
112 | """
113 |
114 | def __call__(self, results):
115 | """Call function.
116 |
117 | Args:
118 | results (dict): A dict containing the necessary information and
119 | data for augmentation.
120 |
121 | Returns:
122 | dict: A dict containing the processed data and information.
123 | """
124 | for key in self.keys:
125 | if not isinstance(results[key], list):
126 | raise TypeError(f'results["{key}"] should be a list, '
127 | f'but got {type(results[key])}')
128 | for idx, v in enumerate(results[key]):
129 | # deal with gray scale img: expand a color channel
130 | if len(v.shape) == 2:
131 | v = v[..., None]
132 | if self.to_float32 and not isinstance(v, np.float32):
133 | v = v.astype(np.float32)
134 | results[key][idx] = to_tensor(v.transpose(2, 0, 1))
135 | results[key] = torch.stack(results[key], dim=0)
136 | if results[key].size(0) == 1:
137 | results[key].squeeze_()
138 | return results
139 |
140 |
141 | @PIPELINES.register_module()
142 | class GetMaskedImage:
143 | """Get masked image.
144 |
145 | Args:
146 | img_name (str): Key for clean image.
147 | mask_name (str): Key for mask image. The mask shape should be
148 |             (h, w, 1), where '1' indicates holes and '0' indicates valid
149 | regions.
150 | """
151 |
152 | def __init__(self, img_name='gt_img', mask_name='mask'):
153 | self.img_name = img_name
154 | self.mask_name = mask_name
155 |
156 | def __call__(self, results):
157 | """Call function.
158 |
159 | Args:
160 | results (dict): A dict containing the necessary information and
161 | data for augmentation.
162 |
163 | Returns:
164 | dict: A dict containing the processed data and information.
165 | """
166 | clean_img = results[self.img_name]
167 | mask = results[self.mask_name]
168 |
169 | masked_img = clean_img * (1. - mask)
170 | results['masked_img'] = masked_img
171 |
172 | return results
173 |
174 | def __repr__(self):
175 | return self.__class__.__name__ + (
176 | f"(img_name='{self.img_name}', mask_name='{self.mask_name}')")
177 |
178 |
179 | @PIPELINES.register_module()
180 | class FormatTrimap:
181 | """Convert trimap (tensor) to one-hot representation.
182 |
183 | It transforms the trimap label from (0, 128, 255) to (0, 1, 2). If
184 |     ``to_onehot`` is set to True, the trimap will be converted to a one-hot
185 |     tensor of shape (3, H, W). The required key is "trimap"; the added or
186 |     modified keys are "trimap" and "to_onehot".
187 |
188 | Args:
189 |         to_onehot (bool): Whether to convert the trimap to a one-hot
190 |             tensor. Default: ``False``.
191 | """
192 |
193 | def __init__(self, to_onehot=False):
194 | self.to_onehot = to_onehot
195 |
196 | def __call__(self, results):
197 | """Call function.
198 |
199 | Args:
200 | results (dict): A dict containing the necessary information and
201 | data for augmentation.
202 |
203 | Returns:
204 | dict: A dict containing the processed data and information.
205 | """
206 | trimap = results['trimap'].squeeze()
207 | trimap[trimap == 128] = 1
208 | trimap[trimap == 255] = 2
209 | if self.to_onehot:
210 | trimap = F.one_hot(trimap.to(torch.long), num_classes=3)
211 | trimap = trimap.permute(2, 0, 1)
212 | else:
213 | trimap = trimap[None, ...] # expand the channels dimension
214 | results['trimap'] = trimap.float()
215 | results['meta'].data['to_onehot'] = self.to_onehot
216 | return results
217 |
218 | def __repr__(self):
219 | return self.__class__.__name__ + f'(to_onehot={self.to_onehot})'
220 |
221 |
222 | @PIPELINES.register_module()
223 | class Collect:
224 | """Collect data from the loader relevant to the specific task.
225 |
226 |     This is usually the last stage of the data loader pipeline. Typically,
227 |     ``keys`` is set to some subset of "img" and "gt_labels".
228 | 
229 |     The "meta" item is always populated. The contents of the "meta"
230 |     dictionary depend on ``meta_keys``.
231 |
232 | Args:
233 | keys (Sequence[str]): Required keys to be collected.
234 | meta_keys (Sequence[str]): Required keys to be collected to "meta".
235 | Default: None.
236 | """
237 |
238 | def __init__(self, keys, meta_keys=None):
239 | self.keys = keys
240 | self.meta_keys = meta_keys
241 |
242 | def __call__(self, results):
243 | """Call function.
244 |
245 | Args:
246 | results (dict): A dict containing the necessary information and
247 | data for augmentation.
248 |
249 | Returns:
250 | dict: A dict containing the processed data and information.
251 | """
252 | data = {}
253 | img_meta = {}
254 |         for key in self.meta_keys or ():
255 | img_meta[key] = results[key]
256 | data['meta'] = DC(img_meta, cpu_only=True)
257 | for key in self.keys:
258 | data[key] = results[key]
259 | return data
260 |
261 | def __repr__(self):
262 | return self.__class__.__name__ + (
263 | f'(keys={self.keys}, meta_keys={self.meta_keys})')
264 |
--------------------------------------------------------------------------------
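A minimal end-to-end sketch of these pipeline stages (assuming `Compose` from compose.py is exported by pipelines/__init__.py, as upstream; the 'lq'/'gt'/'lq_path' keys follow the conventions used elsewhere in this repo):

    import numpy as np
    from mmedit.datasets.pipelines import Compose

    pipeline = Compose([
        dict(type='ImageToTensor', keys=['lq', 'gt']),
        dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path']),
    ])
    results = dict(
        lq=np.random.rand(16, 16, 3).astype(np.float32),
        gt=np.random.rand(16, 16, 3).astype(np.float32),
        lq_path='demo/frame_0001.png')
    data = pipeline(results)
    # data['lq'] is a (3, 16, 16) float tensor; data['meta'] is a
    # DataContainer wrapping {'lq_path': 'demo/frame_0001.png'}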