├── model ├── __init__.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── CPB.cpython-39.pyc │ │ ├── KAN.cpython-39.pyc │ │ ├── LAN.cpython-39.pyc │ │ ├── SSF.cpython-39.pyc │ │ ├── drop.cpython-39.pyc │ │ ├── KANConv.cpython-39.pyc │ │ ├── cswin.cpython-39.pyc │ │ ├── urwkv.cpython-39.pyc │ │ ├── vrwkv.cpython-39.pyc │ │ ├── ConvLSTM.cpython-39.pyc │ │ ├── KANLinear.cpython-39.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── adaptLDA.cpython-39.pyc │ │ ├── lca_block.cpython-39.pyc │ │ ├── conv_block.cpython-39.pyc │ │ ├── convolution.cpython-39.pyc │ │ ├── enhanceBlock.cpython-39.pyc │ │ ├── fusionLayer.cpython-39.pyc │ │ ├── lca_block2.cpython-39.pyc │ │ ├── ParallelCPB_P1.cpython-39.pyc │ │ ├── ParallelCPB_P2.cpython-39.pyc │ │ ├── ParallelCPB_P3.cpython-39.pyc │ │ ├── gaussianFilter.cpython-39.pyc │ │ ├── BilateralCPB_B1.cpython-39.pyc │ │ ├── BilateralCPB_B2.cpython-39.pyc │ │ ├── SequentialCPB_S1.cpython-39.pyc │ │ ├── SequentialCPB_S2.cpython-39.pyc │ │ ├── adaptTransformer.cpython-39.pyc │ │ ├── baseTransformer.cpython-39.pyc │ │ ├── transformer_block.cpython-39.pyc │ │ ├── MixedChannelFusion.cpython-39.pyc │ │ └── adaptDualAttention.cpython-39.pyc │ ├── cuda │ │ └── wkv_op.cpp │ ├── drop.py │ ├── SSF.py │ └── LAN.py ├── __pycache__ │ ├── loss.cpython-39.pyc │ ├── decoder.cpython-39.pyc │ ├── encoder.cpython-39.pyc │ ├── LLFormer.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── model_builder.cpython-39.pyc │ ├── model_builder_wGRT.cpython-39.pyc │ ├── model_builder_wLCA.cpython-39.pyc │ ├── model_builder_ParallelCPB.cpython-39.pyc │ ├── model_builder_SequentialCPB.cpython-39.pyc │ └── model_builder_GlobalLocalCPB.cpython-39.pyc ├── model_builder.py ├── encoder.py ├── decoder.py └── loss.py ├── custom_utils ├── metrics │ ├── __init__.py │ ├── svr_brisque.joblib │ ├── niqe_image_params.mat │ ├── __pycache__ │ │ ├── piqe.cpython-39.pyc │ │ ├── brisque.cpython-39.pyc │ │ └── __init__.cpython-39.pyc │ ├── brisque.py │ ├── niqe.py │ └── piqe.py ├── data_loaders │ ├── __init__.py │ ├── __pycache__ │ │ ├── lol.cpython-37.pyc │ │ ├── lol.cpython-39.pyc │ │ ├── lol_v1.cpython-37.pyc │ │ ├── lol_v2.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── data_RGB.cpython-39.pyc │ │ ├── exposure.cpython-37.pyc │ │ ├── exposure.cpython-39.pyc │ │ ├── lol_v1_new.cpython-39.pyc │ │ ├── dataset_RGB.cpython-39.pyc │ │ ├── load_derain.cpython-39.pyc │ │ └── lol_v1_whole.cpython-39.pyc │ ├── exposure.py │ ├── mit5k.py │ └── lol.py ├── warmup_scheduler │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── scheduler.cpython-37.pyc │ │ └── scheduler.cpython-39.pyc │ ├── run.py │ └── scheduler.py ├── WT │ ├── __init__.py │ └── transform.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── losses.cpython-39.pyc │ ├── plotting.cpython-39.pyc │ ├── visdom.cpython-39.pyc │ ├── dir_utils.cpython-39.pyc │ ├── img_resize.cpython-39.pyc │ ├── validation.cpython-39.pyc │ ├── dataset_utils.cpython-39.pyc │ ├── image_utils.cpython-39.pyc │ ├── lr_scheduler.cpython-39.pyc │ ├── model_utils.cpython-39.pyc │ └── preprocessing.cpython-39.pyc ├── UHD_load │ ├── __pycache__ │ │ ├── val_data_train.cpython-39.pyc │ │ └── train_data_aug_local.cpython-39.pyc │ ├── val_data_train.py │ ├── val_data.py │ └── train_data_aug_local.py ├── dir_utils.py ├── dataset_utils.py ├── img_resize.py ├── GaussianBlur.py ├── model_utils.py ├── validation.py ├── model_load.py ├── preprocessing.py ├── plotting.py ├── image_utils.py ├── losses.py └── lr_scheduler.py ├── 
README_md_files ├── 6cf966f0-5190-11f0-847b-8bd8db6e5334.jpeg ├── 9e45a430-5190-11f0-847b-8bd8db6e5334.jpeg └── e4f9c500-5190-11f0-847b-8bd8db6e5334.jpeg ├── configs ├── FiveK.yaml ├── SMID.yaml ├── LOL_blur.yaml ├── SDSD_indoor.yaml ├── LOL_v2_real.yaml ├── SDSD_outdoor.yaml ├── LOL_v1.yaml ├── SID.yaml └── LOL_v2_synthetic.yaml ├── tools ├── test.py ├── measure.py └── train.py ├── test.sh ├── README.md ├── train.sh └── LICENSE /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /custom_utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /custom_utils/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /custom_utils/WT/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import * -------------------------------------------------------------------------------- /model/__pycache__/loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/loss.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/metrics/svr_brisque.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/metrics/svr_brisque.joblib -------------------------------------------------------------------------------- /model/__pycache__/decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/decoder.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/encoder.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/metrics/niqe_image_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/metrics/niqe_image_params.mat -------------------------------------------------------------------------------- /model/__pycache__/LLFormer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/LLFormer.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/CPB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/CPB.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/KAN.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/KAN.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/LAN.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/LAN.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/SSF.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/SSF.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/drop.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/drop.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .image_utils import * 3 | from .model_utils import * 4 | from .dataset_utils import * -------------------------------------------------------------------------------- /custom_utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/losses.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/losses.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/plotting.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/plotting.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/visdom.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/visdom.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/KANConv.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/KANConv.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/cswin.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/cswin.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/urwkv.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/urwkv.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/vrwkv.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/vrwkv.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/dir_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/dir_utils.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/img_resize.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/img_resize.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/validation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/validation.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/ConvLSTM.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/ConvLSTM.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/KANLinear.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/KANLinear.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/adaptLDA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/adaptLDA.cpython-39.pyc 
-------------------------------------------------------------------------------- /model/modules/__pycache__/lca_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/lca_block.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/dataset_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/dataset_utils.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/image_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/image_utils.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/lr_scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/lr_scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/model_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/model_utils.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/__pycache__/preprocessing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/__pycache__/preprocessing.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/metrics/__pycache__/piqe.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/metrics/__pycache__/piqe.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder_wGRT.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder_wGRT.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder_wLCA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder_wLCA.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/conv_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/conv_block.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/convolution.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/convolution.cpython-39.pyc 
-------------------------------------------------------------------------------- /model/modules/__pycache__/enhanceBlock.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/enhanceBlock.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/fusionLayer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/fusionLayer.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/lca_block2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/lca_block2.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/metrics/__pycache__/brisque.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/metrics/__pycache__/brisque.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/ParallelCPB_P1.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/ParallelCPB_P1.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/ParallelCPB_P2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/ParallelCPB_P2.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/ParallelCPB_P3.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/ParallelCPB_P3.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/gaussianFilter.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/gaussianFilter.cpython-39.pyc -------------------------------------------------------------------------------- /README_md_files/6cf966f0-5190-11f0-847b-8bd8db6e5334.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/README_md_files/6cf966f0-5190-11f0-847b-8bd8db6e5334.jpeg -------------------------------------------------------------------------------- /README_md_files/9e45a430-5190-11f0-847b-8bd8db6e5334.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/README_md_files/9e45a430-5190-11f0-847b-8bd8db6e5334.jpeg -------------------------------------------------------------------------------- /README_md_files/e4f9c500-5190-11f0-847b-8bd8db6e5334.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/README_md_files/e4f9c500-5190-11f0-847b-8bd8db6e5334.jpeg -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol.cpython-37.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/metrics/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/metrics/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder_ParallelCPB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder_ParallelCPB.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/BilateralCPB_B1.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/BilateralCPB_B1.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/BilateralCPB_B2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/BilateralCPB_B2.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/SequentialCPB_S1.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/SequentialCPB_S1.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/SequentialCPB_S2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/SequentialCPB_S2.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/adaptTransformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/adaptTransformer.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/baseTransformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/baseTransformer.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/transformer_block.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/transformer_block.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol_v1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol_v1.cpython-37.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol_v2.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder_SequentialCPB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder_SequentialCPB.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/MixedChannelFusion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/MixedChannelFusion.cpython-39.pyc -------------------------------------------------------------------------------- /model/modules/__pycache__/adaptDualAttention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/modules/__pycache__/adaptDualAttention.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/UHD_load/__pycache__/val_data_train.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/UHD_load/__pycache__/val_data_train.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/data_RGB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/data_RGB.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/exposure.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/exposure.cpython-37.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/exposure.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/exposure.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol_v1_new.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol_v1_new.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_builder_GlobalLocalCPB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/model/__pycache__/model_builder_GlobalLocalCPB.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/dataset_RGB.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/dataset_RGB.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/load_derain.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/load_derain.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/data_loaders/__pycache__/lol_v1_whole.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/data_loaders/__pycache__/lol_v1_whole.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/warmup_scheduler/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/warmup_scheduler/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/__pycache__/scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/warmup_scheduler/__pycache__/scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/__pycache__/scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/warmup_scheduler/__pycache__/scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/UHD_load/__pycache__/train_data_aug_local.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FZU-N/URWKV/HEAD/custom_utils/UHD_load/__pycache__/train_data_aug_local.cpython-39.pyc -------------------------------------------------------------------------------- /custom_utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | 6 | def mkdirs(paths): 7 | if isinstance(paths, list) and not isinstance(paths, str): 8 | for path in paths: 9 | mkdir(path) 10 | else: 11 | mkdir(paths) 12 | 13 | 14 | def mkdir(path): 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | 18 | 19 | def get_last_path(path, session): 20 | x = natsorted(glob(os.path.join(path, '*%s' % session)))[-1] 21 | return x 22 | -------------------------------------------------------------------------------- /configs/FiveK.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [0] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'MIT_5K' 8 | 9 | # Optimization arguments. 10 | OPTIM: 11 | BATCH: 8 12 | EPOCHS: 1000 13 | # EPOCH_DECAY: [10] 14 | LR_INITIAL: 1e-4 15 | LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 10 20 | RESUME: False 21 | PATCH_SIZES: [128, 256, 384] 22 | BATCH_SIZES: [4, 2, 2] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/MIT-Adobe-5K-512/train' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/MIT-Adobe-5K-512/test' # path to validation data 26 | SAVE_DIR: './checkpoints/MIT_5K/' # path to save models and images 27 | -------------------------------------------------------------------------------- /configs/SMID.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'SMID' 8 | 9 | # Optimization arguments. 10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 10 # 1 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384, 512] 22 | BATCH_SIZES: [2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/SMID_png/train/' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/SMID_png/eval/' # path to validation data 26 | SAVE_DIR: './checkpoints/SMID/' # path to save models and images 27 | -------------------------------------------------------------------------------- /configs/LOL_blur.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'LOL_blur' 8 | 9 | # Optimization arguments. 
10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 10 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384, 512] 22 | BATCH_SIZES: [2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/LOL_blur/train' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/LOL_blur/eval/' # path to validation data 26 | SAVE_DIR: './checkpoints/LOL_blur/' # path to save models and images 27 | -------------------------------------------------------------------------------- /configs/SDSD_indoor.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'SDSD_indoor' 8 | 9 | # Optimization arguments. 10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 1 # 1 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384, 512] 22 | BATCH_SIZES: [2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/SDSD_indoor_png/train/' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/SDSD_indoor_png/eval/' # path to validation data 26 | SAVE_DIR: './checkpoints/SDSD_indoor/' # path to save models and images 27 | -------------------------------------------------------------------------------- /configs/LOL_v2_real.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'LOL_v2_real' 8 | 9 | # Optimization arguments. 10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 1 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384, 400] 22 | BATCH_SIZES: [2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/LOL_v2/Real_captured/Train' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/LOL_v2/Real_captured/Test' # path to validation data 26 | SAVE_DIR: './checkpoints/LOL_v2_real/' # path to save models and images 27 | -------------------------------------------------------------------------------- /configs/SDSD_outdoor.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'SDSD_outdoor' 8 | 9 | # Optimization arguments. 
10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 1 # 1 20 | RESUME: False # False 21 | PATCH_SIZES: [128, 256, 384, 512] 22 | BATCH_SIZES: [4, 2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200, 200] 24 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/SDSD_outdoor_png/train/' # path to training data 25 | VAL_DIR: '/data/xr/Dataset/light_dataset/SDSD_outdoor_png/eval/' # path to validation data 26 | SAVE_DIR: './checkpoints/SDSD_outdoor/' # path to save models and images 27 | -------------------------------------------------------------------------------- /custom_utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from prefetch_generator import BackgroundGenerator 4 | 5 | class MixUp_AUG: 6 | def __init__(self): 7 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) 8 | 9 | def aug(self, rgb_gt, rgb_noisy): 10 | bs = rgb_gt.size(0) 11 | indices = torch.randperm(bs) 12 | rgb_gt2 = rgb_gt[indices] 13 | rgb_noisy2 = rgb_noisy[indices] 14 | 15 | lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).cuda() 16 | 17 | rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2 18 | rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2 19 | 20 | return rgb_gt, rgb_noisy 21 | 22 | class DataLoaderX(DataLoader): 23 | def __iter__(self): 24 | return BackgroundGenerator(super().__iter__()) 25 | -------------------------------------------------------------------------------- /configs/LOL_v1.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'LOL_v1' 8 | 9 | # Optimization arguments. 10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 1 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384, 400] 22 | BATCH_SIZES: [2, 2, 1] 23 | EPOCHS_PER_SIZE: [400, 400, 200] 24 | # PATCH_SIZES: [128, 192, 256, 320, 384, 400] 25 | # BATCH_SIZES: [16, 16, 8, 8, 4, 2] 26 | # EPOCHS_PER_SIZE: [300, 200, 200, 100, 100 ,100] 27 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/LOL_v1/our485' # path to training data 28 | VAL_DIR: '/data/xr/Dataset/light_dataset/LOL_v1/eval15' # path to validation data 29 | SAVE_DIR: './checkpoints/LOL_v1/' # path to save models and images 30 | -------------------------------------------------------------------------------- /configs/SID.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'SID' 8 | 9 | # Optimization arguments. 
10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 10 # 1 20 | RESUME: False # False 21 | PATCH_SIZES: [128, 256, 384, 600] 22 | BATCH_SIZES: [4, 2, 2, 1] 23 | EPOCHS_PER_SIZE: [250, 250, 250, 250] 24 | # PATCH_SIZES: [128, 192, 256, 320, 384, 400] 25 | # BATCH_SIZES: [16, 16, 8, 8, 4, 2] 26 | # EPOCHS_PER_SIZE: [300, 200, 200, 100, 100 ,100] 27 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/SID_png/train/' # path to training data 28 | VAL_DIR: '/data/xr/Dataset/light_dataset/SID_png/eval/' # path to validation data 29 | SAVE_DIR: './checkpoints/SID/' # path to save models and images 30 | -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from .scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | # scheduler_warmup is chained with schduler_steplr 13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 15 | 16 | # this zero gradient update is needed to avoid a warning message, issue #8. 17 | optim.zero_grad() 18 | optim.step() 19 | 20 | for epoch in range(1, 20): 21 | scheduler_warmup.step(epoch) 22 | print(epoch, optim.param_groups[0]['lr']) 23 | 24 | optim.step() # backward pass (update network) 25 | -------------------------------------------------------------------------------- /configs/LOL_v2_synthetic.yaml: -------------------------------------------------------------------------------- 1 | # Training configuration 2 | GPU: [1] 3 | 4 | VERBOSE: False 5 | 6 | MODEL: 7 | MODE: 'LOL_v2_synthetic' 8 | 9 | # Optimization arguments. 
10 | OPTIM: 11 | # BATCH: 4 # 8 12 | # EPOCHS: 1000 #200 13 | # # EPOCH_DECAY: [10] 14 | # LR_INITIAL: 2e-4 # 1e-4 15 | # LR_MIN: 1e-6 16 | # # BETA1: 0.9 17 | 18 | TRAINING: 19 | VAL_AFTER_EVERY: 1 20 | RESUME: False # False 21 | PATCH_SIZES: [256, 384] 22 | BATCH_SIZES: [2, 2] 23 | EPOCHS_PER_SIZE: [600, 400] 24 | # PATCH_SIZES: [128, 192, 256, 320, 384, 400] 25 | # BATCH_SIZES: [16, 16, 8, 8, 4, 2] 26 | # EPOCHS_PER_SIZE: [300, 200, 200, 100, 100 ,100] 27 | TRAIN_DIR: '/data/xr/Dataset/light_dataset/LOL_v2/Synthetic/Train' # path to training data 28 | VAL_DIR: '/data/xr/Dataset/light_dataset/LOL_v2/Synthetic/Test' # path to validation data 29 | SAVE_DIR: './checkpoints/LOL_v2_sync/' # path to save models and images 30 | -------------------------------------------------------------------------------- /model/modules/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /model/model_builder.py: -------------------------------------------------------------------------------- 1 | # torch imports 2 | import torch 3 | from torch import nn 4 | from torchvision import datasets, transforms 5 | from timm.models.layers import trunc_normal_ 6 | import torch.nn.functional as F 7 | 8 | # other imports 9 | import numpy as np 10 | import os 11 | import math 12 | 13 | # own files import 14 | from .encoder import Encoder 15 | from .decoder import Decoder 16 | from .modules.urwkv import URWKV 17 | 18 | # recursive network based on residual units 19 | class LLENet(nn.Module): 20 | def __init__(self, dim): 21 | super().__init__() 22 | self.dim = dim 23 | 24 | self.encoder = Encoder(dim=self.dim) # 3 -> 32 -> 64 -> 128 25 | self.decoder = Decoder(dim=self.dim) # 128 -> 64 -> 32 -> 3 26 | 27 | self.apply(self._init_weights) # Correctly apply init_weights to all submodules 28 | 29 | def _init_weights(self, m): 30 | if isinstance(m, nn.Linear): 31 | trunc_normal_(m.weight, std=.02) 32 | if isinstance(m, nn.Linear) and m.bias is not None: 33 | nn.init.constant_(m.bias, 0) 34 | elif isinstance(m, nn.LayerNorm): 35 | nn.init.constant_(m.bias, 0) 36 | nn.init.constant_(m.weight, 1.0) 37 | 38 | def forward(self, x): 39 | outer_shortcut = x 40 | inter_feat = [] 41 | encode_list, inter_feat = self.encoder(x, inter_feat) 42 | 43 | x = encode_list[-1] 44 | x = 
self.decoder(x, encode_list, inter_feat) 45 | x=torch.add(x, outer_shortcut) 46 | 47 | return x 48 | 49 | 50 | -------------------------------------------------------------------------------- /model/modules/drop.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 5 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 6 | 7 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 8 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 9 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 10 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 11 | 'survival rate' as the argument. 12 | 13 | """ 14 | if drop_prob == 0. or not training: 15 | return x 16 | keep_prob = 1 - drop_prob 17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 18 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 19 | if keep_prob > 0.0 and scale_by_keep: 20 | random_tensor.div_(keep_prob) 21 | return x * random_tensor 22 | 23 | 24 | class DropPath(nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 26 | """ 27 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | self.scale_by_keep = scale_by_keep 31 | 32 | def forward(self, x): 33 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 34 | 35 | def extra_repr(self): 36 | return f'drop_prob={round(self.drop_prob,3):0.3f}' -------------------------------------------------------------------------------- /custom_utils/UHD_load/val_data_train.py: -------------------------------------------------------------------------------- 1 | # --- Imports --- # 2 | import torch.utils.data as data 3 | from PIL import Image 4 | from torchvision.transforms import Compose, ToTensor, Normalize, Resize 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # --- Validation/test dataset --- # 11 | class ValData_train(data.Dataset): 12 | def __init__(self, val_data_dir): 13 | super().__init__() 14 | val_list = val_data_dir + 'data_list.txt'#'final_test_datalist.txt'eval15/datalist.txt 15 | with open(val_list) as f: 16 | contents = f.readlines() 17 | lowlight_names = [i.strip() for i in contents] 18 | gt_names = lowlight_names#[i.split('_')[0] + '.png' for i in lowlight_names] 19 | 20 | self.lowlight_names = lowlight_names 21 | self.gt_names = gt_names 22 | self.val_data_dir = val_data_dir 23 | self.data_list=val_list 24 | def get_images(self, index): 25 | lowlight_name = self.lowlight_names[index] 26 | gt_name = self.gt_names[index] 27 | lowlight_img = Image.open(self.val_data_dir + 'low/' + lowlight_name)#eval15/low/ 28 | gt_img = Image.open(self.val_data_dir + 'high/' + gt_name) #eval15/high/ 29 | transform_lowlight = Compose([ToTensor()]) 30 | transform_gt = Compose([ToTensor()]) 31 | lowlight = transform_lowlight(lowlight_img) 32 | gt = transform_gt(gt_img) 33 | return lowlight, gt,lowlight_name # 34 | 35 | def __getitem__(self, index): 36 | res = self.get_images(index) 37 | return res 38 | 39 | def __len__(self): 40 | 
return len(self.lowlight_names) 41 | -------------------------------------------------------------------------------- /custom_utils/img_resize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | import cv2 9 | import os 10 | import math 11 | 12 | 13 | def pad_input(net_input, local_window_dim=2): 14 | _, _, h_old, w_old = net_input.size() 15 | h_original = h_old 16 | w_original = w_old 17 | multiplier = max(h_old // local_window_dim + 1, w_old // local_window_dim + 1) 18 | h_pad = (multiplier) * local_window_dim - h_old 19 | w_pad = (multiplier) * local_window_dim - w_old 20 | net_input = torch.cat([net_input, torch.flip(net_input, [2])], 2)[:, :, :h_old + h_pad, :] 21 | net_input = torch.cat([net_input, torch.flip(net_input, [3])], 3)[:, :, :, :w_old + w_pad] 22 | 23 | if h_pad > h_old or w_pad > w_old: 24 | _, _, h_old, w_old = net_input.size() 25 | multiplier = max(h_old // local_window_dim + 1, w_old // local_window_dim + 1) 26 | h_pad = (multiplier) * local_window_dim - h_old 27 | w_pad = (multiplier) * local_window_dim - w_old 28 | net_input = torch.cat([net_input, torch.flip(net_input, [2])], 2)[:, :, :h_old + h_pad, :] 29 | net_input = torch.cat([net_input, torch.flip(net_input, [3])], 3)[:, :, :, :w_old + w_pad] 30 | 31 | return net_input 32 | 33 | 34 | def crop_output(net_input, net_output): 35 | _, _, h_old, w_old = net_input.size() 36 | h_original = h_old 37 | w_original = w_old 38 | net_output = net_output[:,:,:h_original, :w_original] 39 | # output_data = net_output.cpu().detach().numpy() #B C H W 40 | # output_data = np.transpose(output_data, (0,2,3,1)) #B H W C 41 | output_data = net_output 42 | 43 | return output_data -------------------------------------------------------------------------------- /custom_utils/GaussianBlur.py: -------------------------------------------------------------------------------- 1 | import torch; import torch.nn as nn 2 | import math 3 | 4 | def get_gaussian_kernel(kernel_size=21, sigma=5, channels=3): 5 | #if not kernel_size: kernel_size = int(2*np.ceil(2*sigma)+1) 6 | #print("Kernel is: ",kernel_size) 7 | #print("Sigma is: ",sigma) 8 | padding = kernel_size//2 9 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 10 | x_coord = torch.arange(kernel_size) 11 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) 12 | y_grid = x_grid.t() 13 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 14 | 15 | mean = (kernel_size - 1)/2. 16 | variance = sigma**2. 17 | 18 | # Calculate the 2-dimensional gaussian kernel which is 19 | # the product of two gaussian distributions for two different 20 | # variables (in this case called x and y) 21 | gaussian_kernel = (1./(2.*math.pi*variance)) *\ 22 | torch.exp( 23 | -torch.sum((xy_grid - mean)**2., dim=-1) /\ 24 | (2*variance) 25 | ) 26 | 27 | # Make sure sum of values in gaussian kernel equals 1.
28 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 29 | 30 | # Reshape to 2d depthwise convolutional weight 31 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 32 | gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) 33 | 34 | gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels, 35 | kernel_size=kernel_size, groups=channels, bias=False) 36 | 37 | gaussian_filter.weight.data = gaussian_kernel 38 | gaussian_filter.weight.requires_grad = False 39 | 40 | return gaussian_filter, padding 41 | -------------------------------------------------------------------------------- /custom_utils/WT/transform.py: -------------------------------------------------------------------------------- 1 | ## Ultra-High-Definition Low-Light Image Enhancement: A Benchmark and Transformer-Based Method 2 | ## Tao Wang, Kaihao Zhang, Tianrun Shen, Wenhan Luo, Bjorn Stenger, Tong Lu 3 | ## https://arxiv.org/pdf/2212.11548.pdf 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def dwt_init(x): 9 | x01 = x[:, :, 0::2, :] / 2 10 | x02 = x[:, :, 1::2, :] / 2 11 | x1 = x01[:, :, :, 0::2] 12 | x2 = x02[:, :, :, 0::2] 13 | x3 = x01[:, :, :, 1::2] 14 | x4 = x02[:, :, :, 1::2] 15 | x_LL = x1 + x2 + x3 + x4 16 | x_HL = -x1 - x2 + x3 + x4 17 | x_LH = -x1 + x2 - x3 + x4 18 | x_HH = x1 - x2 - x3 + x4 19 | # print(x_HH[:, 0, :, :]) 20 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 21 | 22 | def iwt_init(x): 23 | r = 2 24 | in_batch, in_channel, in_height, in_width = x.size() 25 | out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r ** 2)), r * in_height, r * in_width 26 | x1 = x[:, 0:out_channel, :, :] / 2 27 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2 28 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 29 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 30 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).cuda() # 31 | 32 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 33 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 34 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 35 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 36 | 37 | return h 38 | 39 | 40 | class DWT(nn.Module): 41 | def __init__(self): 42 | super(DWT, self).__init__() 43 | self.requires_grad = True 44 | 45 | def forward(self, x): 46 | return dwt_init(x) 47 | 48 | 49 | class IWT(nn.Module): 50 | def __init__(self): 51 | super(IWT, self).__init__() 52 | self.requires_grad = True 53 | 54 | def forward(self, x): 55 | return iwt_init(x) 56 | 57 | 58 | -------------------------------------------------------------------------------- /custom_utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def freeze(model): 7 | for p in model.parameters(): 8 | p.requires_grad = False 9 | 10 | 11 | def unfreeze(model): 12 | for p in model.parameters(): 13 | p.requires_grad = True 14 | 15 | 16 | def is_frozen(model): 17 | x = [p.requires_grad for p in model.parameters()] 18 | return not all(x) 19 | 20 | 21 | def save_checkpoint(model_dir, state, session): 22 | epoch = state['epoch'] 23 | model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session)) 24 | torch.save(state, model_out_path) 25 | 26 | 27 | def load_checkpoint(model, weights): 28 | checkpoint = torch.load(weights) 29 | try: 30 | model.load_state_dict(checkpoint["state_dict"]) 31 | except: 32 | state_dict = checkpoint["state_dict"] 33 | new_state_dict = OrderedDict() 
34 | for k, v in state_dict.items(): 35 | name = k[7:] # remove `module.` 36 | new_state_dict[name] = v 37 | model.load_state_dict(new_state_dict) 38 | 39 | 40 | def load_checkpoint_multigpu(model, weights): 41 | checkpoint = torch.load(weights) 42 | state_dict = checkpoint["state_dict"] 43 | new_state_dict = OrderedDict() 44 | for k, v in state_dict.items(): 45 | name = k[7:] # remove `module.` 46 | new_state_dict[name] = v 47 | model.load_state_dict(new_state_dict) 48 | 49 | 50 | def load_start_epoch(weights): 51 | checkpoint = torch.load(weights) 52 | epoch = checkpoint["epoch"] 53 | return epoch 54 | 55 | 56 | def load_optim(optimizer, weights): 57 | checkpoint = torch.load(weights) 58 | optimizer.load_state_dict(checkpoint['optimizer']) 59 | # for p in optimizer.param_groups: lr = p['lr'] 60 | # return lr 61 | 62 | 63 | def network_parameters(nets): 64 | num_params = sum(param.numel() for param in nets.parameters()) 65 | return num_params 66 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .modules.urwkv import URWKV 5 | 6 | class Encoder(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | self.dim = dim 10 | self.patch_size = 4 11 | self.head = nn.Conv2d(3, self.dim, kernel_size=3, stride=1, padding=1, bias=False) 12 | 13 | # stage1 14 | self.enhanceBlock1 = URWKV(patch_size=3, in_channels=self.dim, embed_dims=self.dim, depth=3) 15 | self.proj1 = nn.Sequential( 16 | nn.Conv2d(self.dim, self.dim, kernel_size=3, stride=1, padding=1, bias=False), 17 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 18 | ) 19 | 20 | # stage2 21 | self.enhanceBlock2 = URWKV(patch_size=3, in_channels=self.dim, embed_dims=self.dim, depth=3) 22 | self.proj2 = nn.Sequential( 23 | nn.Conv2d(self.dim, self.dim*2, kernel_size=3, stride=1, padding=1, bias=False), 24 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 25 | ) 26 | 27 | # stage3 28 | self.enhanceBlock3 = URWKV(patch_size=3, in_channels=self.dim*2, embed_dims=self.dim*2, depth=3) 29 | self.proj3 = nn.Sequential( 30 | nn.Conv2d(self.dim*2, self.dim*4, kernel_size=3, stride=1, padding=1, bias=False), 31 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 32 | ) 33 | 34 | def forward(self, x, inter_feat): 35 | B, C, H, W = x.shape 36 | x = self.head(x) 37 | inter_feat.append(x) # [1, 32, 256, 256] 38 | 39 | # stage1 40 | x1, inter_feat = self.enhanceBlock1(x, inter_feat) 41 | inter_feat.append(x1) # [1, 32, 256, 256], [1, 32, 256, 256] 42 | x1 = self.proj1(x1) # C, H, W 43 | x1_out = F.interpolate(x1, scale_factor=0.5, mode='bilinear') # down 1/2 44 | 45 | # stage2 46 | x1_out, inter_feat = self.enhanceBlock2(x1_out, inter_feat) 47 | inter_feat.append(x1_out) # [1, 32, 256, 256], [1, 32, 256, 256], [1, 32, 128, 128] 48 | x2 = self.proj2(x1_out) # 2C, H/2, W/2 49 | x2_out = F.interpolate(x2, scale_factor=0.5, mode='bilinear') # down 1/4 50 | 51 | # stage3 52 | x2_out, inter_feat = self.enhanceBlock3(x2_out, inter_feat) 53 | inter_feat.append(x2_out) # [1, 32, 256, 256], [1, 32, 256, 256], [1, 32, 128, 128], [1, 64, 64, 64] 54 | x3 = self.proj3(x2_out) # 4C, H/4, W/4 55 | x3_out = F.interpolate(x3, scale_factor=0.5, mode='bilinear') # down 1/8, 4C, H/8, W/8 56 | 57 | feat_list = [x1, x2, x3, x3_out] 58 | return feat_list, inter_feat -------------------------------------------------------------------------------- 
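Usage note (illustrative sketch, not a repository file): the snippet below shows how the Encoder above feeds the Decoder that follows, assuming the package is importable from the repository root and that the URWKV blocks (which depend on the wkv CUDA extension under model/modules/cuda/) can be instantiated. The shapes are derived from the projections in encoder.py with dim=32 and a 256x256 input; variable names here are placeholders.

    import torch
    from model.encoder import Encoder
    from model.decoder import Decoder

    encoder = Encoder(dim=32)   # channel progression 3 -> 32 -> 64 -> 128 (see model_builder.py)
    decoder = Decoder(dim=32)   # mirrors the encoder back down to 3 channels

    x = torch.randn(1, 3, 256, 256)
    encode_list, inter_feat = encoder(x, [])
    # encode_list shapes: [1, 32, 256, 256], [1, 64, 128, 128], [1, 128, 64, 64], [1, 128, 32, 32]
    out = decoder(encode_list[-1], encode_list, inter_feat)  # [1, 3, 256, 256]
    # LLENet in model_builder.py then adds the input back: out + x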
/model/decoder.py: -------------------------------------------------------------------------------- 1 | # troch imports 2 | import torch 3 | from torch import nn 4 | from torchvision import datasets, transforms 5 | from timm.models.layers import trunc_normal_ 6 | import torch.nn.functional as F 7 | 8 | # other imports 9 | import numpy as np 10 | import os 11 | import math 12 | from .modules.SSF import SSF 13 | from .modules.urwkv import URWKV 14 | 15 | class Decoder(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.dim = dim 19 | self.patch_size = 4 20 | self.residual_depth = [1, 1, 1] 21 | self.recursive_depth = [1, 1, 1] 22 | self.enhanceBlock1 = URWKV(patch_size=3, in_channels=self.dim*4, embed_dims=self.dim*4, depth=2) 23 | self.proj1 = nn.Sequential( 24 | nn.Conv2d(self.dim*4, self.dim*2, kernel_size=3, stride=1, padding=1, bias=False), 25 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 26 | ) 27 | 28 | 29 | self.enhanceBlock2 = URWKV(patch_size=3, in_channels=self.dim*2, embed_dims=self.dim*2, depth=2) 30 | self.proj2 = nn.Sequential( 31 | nn.Conv2d(self.dim*2, self.dim, kernel_size=3, stride=1, padding=1, bias=False), 32 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 33 | ) 34 | 35 | 36 | self.enhanceBlock3 = URWKV(patch_size=3, in_channels=self.dim, embed_dims=self.dim, depth=2) 37 | self.proj3 = nn.Sequential( 38 | nn.Conv2d(self.dim, self.dim, kernel_size=3, stride=1, padding=1, bias=False), 39 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 40 | ) 41 | 42 | self.tail = nn.Conv2d(self.dim, 3, kernel_size=3, stride=1, padding=1, bias=False) 43 | 44 | 45 | self.multiscale_fuse1 = SSF(num_feats=3, encode_channels=[self.dim*4, self.dim*2, self.dim], target_channels=self.dim) 46 | self.multiscale_fuse2 = SSF(num_feats=3, encode_channels=[self.dim*4, self.dim*2, self.dim], target_channels=self.dim*2) 47 | self.multiscale_fuse3 = SSF(num_feats=3, encode_channels=[self.dim*4, self.dim*2, self.dim], target_channels=self.dim*4) 48 | 49 | self.upSample = nn.Upsample(scale_factor=2, mode="bilinear") 50 | 51 | def forward(self, x, encode_list, inter_feat): 52 | feat_1s, feat_2s, feat_4s = encode_list[0], encode_list[1], encode_list[2] 53 | 54 | 55 | x1 = self.multiscale_fuse3(self.upSample(x),encode_list[:3]) 56 | x1, inter_feat = self.enhanceBlock1(x1, inter_feat) 57 | inter_feat.append(x1) 58 | x1 = self.proj1(x1) 59 | 60 | x2 = self.multiscale_fuse2(self.upSample(x1),encode_list[:3]) 61 | x2, inter_feat = self.enhanceBlock2(x2, inter_feat) 62 | inter_feat.append(x2) 63 | x2 = self.proj2(x2) 64 | 65 | x3 = self.multiscale_fuse1(self.upSample(x2),encode_list[:3]) 66 | x3, inter_feat = self.enhanceBlock3(x3, inter_feat) 67 | inter_feat.append(x3) 68 | x3 = self.proj3(x3) 69 | out = self.tail(x3) 70 | return out 71 | -------------------------------------------------------------------------------- /custom_utils/validation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | import cv2 9 | import os 10 | import math 11 | from IQA_pytorch import SSIM, MS_SSIM 12 | from .img_resize import pad_input, crop_output 13 | 14 | EPS = 1e-3 15 | PI = 22.0 / 7.0 16 | # calculate PSNR 17 | class PSNR(nn.Module): 18 | def __init__(self, max_val=0): 19 | super().__init__() 20 | 21 | base10 = torch.log(torch.tensor(10.0)) 22 | max_val = torch.tensor(max_val).float() 23 | 24 | 
self.register_buffer('base10', base10) 25 | self.register_buffer('max_val', 20 * torch.log(max_val) / base10) 26 | 27 | def __call__(self, a, b): 28 | mse = torch.mean((a.float() - b.float()) ** 2) 29 | 30 | if mse == 0: 31 | return 0 32 | 33 | return 10 * torch.log10((1.0 / mse)) 34 | 35 | 36 | ssim = SSIM() 37 | psnr = PSNR() 38 | 39 | def validation(model, val_loader): 40 | 41 | ssim = SSIM() 42 | psnr = PSNR() 43 | ssim_list = [] 44 | psnr_list = [] 45 | for i, imgs in enumerate(val_loader): 46 | with torch.no_grad(): 47 | low_img, high_img = imgs[0].cuda(), imgs[1].cuda() 48 | if low_img.shape[2] == low_img.shape[3]: 49 | enhanced_img = model(low_img) 50 | else: 51 | # low_img, high_img = pad_input(low_img), pad_input(high_img) ### 52 | enhanced_img = model(low_img) 53 | # enhanced_img, high_img = crop_output(imgs[0].cuda(), enhanced_img), crop_output(imgs[1].cuda(), high_img) ### 54 | ssim_value = ssim(enhanced_img, high_img, as_loss=False).item() 55 | #ssim_value = ssim(enhanced_img, high_img).item() 56 | psnr_value = psnr(enhanced_img, high_img).item() 57 | # print('The %d image SSIM value is %d:' %(i, ssim_value)) 58 | ssim_list.append(ssim_value) 59 | psnr_list.append(psnr_value) 60 | 61 | SSIM_mean = np.mean(ssim_list) 62 | PSNR_mean = np.mean(psnr_list) 63 | print('The SSIM Value is:', SSIM_mean) 64 | print('The PSNR Value is:', PSNR_mean) 65 | return SSIM_mean, PSNR_mean 66 | 67 | def validation_shadow(model, val_loader): 68 | 69 | ssim = SSIM() 70 | psnr = PSNR() 71 | ssim_list = [] 72 | psnr_list = [] 73 | for i, imgs in enumerate(val_loader): 74 | with torch.no_grad(): 75 | low_img, high_img, mask = imgs[0].cuda(), imgs[1].cuda(), imgs[2].cuda() 76 | _, _, enhanced_img = model(low_img, mask) 77 | # print(enhanced_img.shape) 78 | ssim_value = ssim(enhanced_img, high_img, as_loss=False).item() 79 | #ssim_value = ssim(enhanced_img, high_img).item() 80 | psnr_value = psnr(enhanced_img, high_img).item() 81 | # print('The %d image SSIM value is %d:' %(i, ssim_value)) 82 | ssim_list.append(ssim_value) 83 | psnr_list.append(psnr_value) 84 | 85 | SSIM_mean = np.mean(ssim_list) 86 | PSNR_mean = np.mean(psnr_list) 87 | print('The SSIM Value is:', SSIM_mean) 88 | print('The PSNR Value is:', PSNR_mean) 89 | return SSIM_mean, PSNR_mean 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /custom_utils/warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /custom_utils/model_load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import logging 9 | 10 | import torch 11 | 12 | 13 | logger = logging.getLogger('global') 14 | 15 | 16 | def check_keys(model, pretrained_state_dict): 17 | ckpt_keys = set(pretrained_state_dict.keys()) 18 | model_keys = set(model.state_dict().keys()) 19 | used_pretrained_keys = model_keys & ckpt_keys 20 | unused_pretrained_keys = ckpt_keys - model_keys 21 | missing_keys = model_keys - ckpt_keys 22 | # filter 'num_batches_tracked' 23 | missing_keys = [x for x in missing_keys 24 | if not x.endswith('num_batches_tracked')] 25 | if len(missing_keys) > 0: 26 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 27 | logger.info('missing keys:{}'.format(len(missing_keys))) 28 | if len(unused_pretrained_keys) > 0: 29 | logger.info('[Warning] unused_pretrained_keys: {}'.format( 30 | unused_pretrained_keys)) 31 | logger.info('unused checkpoint keys:{}'.format( 32 | len(unused_pretrained_keys))) 33 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 34 | assert len(used_pretrained_keys) > 0, \ 35 | 'load NONE from pretrained checkpoint' 36 | return True 37 | 38 | 39 | def remove_prefix(state_dict, prefix): 40 | ''' Old style model is stored with all names of parameters 41 | share common prefix 'module.' ''' 42 | logger.info('remove prefix \'{}\''.format(prefix)) 43 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 44 | return {f(key): value for key, value in state_dict.items()} 45 | 46 | 47 | def load_pretrain(model, pretrained_path): 48 | logger.info('load pretrained model from {}'.format(pretrained_path)) 49 | device = torch.cuda.current_device() 50 | pretrained_dict = torch.load(pretrained_path, 51 | map_location=lambda storage, loc: storage.cuda(device)) 52 | if "state_dict" in pretrained_dict.keys(): 53 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 54 | 'module.') 55 | else: 56 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 57 | 58 | try: 59 | check_keys(model, pretrained_dict) 60 | except: 61 | logger.info('[Warning]: using pretrain as features.\ 62 | Adding "features." as prefix') 63 | new_dict = {} 64 | for k, v in pretrained_dict.items(): 65 | k = 'features.' 
+ k 66 | new_dict[k] = v 67 | pretrained_dict = new_dict 68 | check_keys(model, pretrained_dict) 69 | model.load_state_dict(pretrained_dict, strict=False) 70 | return model 71 | 72 | 73 | def restore_from(model, optimizer, ckpt_path): 74 | device = torch.cuda.current_device() 75 | ckpt = torch.load(ckpt_path, 76 | map_location=lambda storage, loc: storage.cuda(device)) 77 | epoch = ckpt['epoch'] 78 | 79 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 80 | check_keys(model, ckpt_model_dict) 81 | model.load_state_dict(ckpt_model_dict, strict=False) 82 | 83 | check_keys(optimizer, ckpt['optimizer']) 84 | optimizer.load_state_dict(ckpt['optimizer']) 85 | return model, optimizer, epoch 86 | -------------------------------------------------------------------------------- /custom_utils/UHD_load/val_data.py: -------------------------------------------------------------------------------- 1 | # --- Imports --- # 2 | import torch.utils.data as data 3 | from PIL import Image 4 | from torchvision.transforms import Compose, ToTensor, Normalize, Resize 5 | import numpy as np 6 | import torch 7 | import os 8 | 9 | # --- Validation/test dataset --- # 10 | class ValData(data.Dataset): 11 | def __init__(self, dataset_name,val_data_dir): 12 | super().__init__() 13 | self.dataset_name = dataset_name 14 | val_list = os.path.join(val_data_dir, 'data_list.txt') 15 | with open(val_list) as f: 16 | contents = f.readlines() 17 | lowlight_names = [i.strip() for i in contents] 18 | if self.dataset_name=='UHD' or self.dataset_name=='LOLv1' or self.dataset_name=='LOLv2': 19 | gt_names = lowlight_names # 20 | else: 21 | gt_names = None 22 | print('The dataset is not included in this work.') 23 | self.lowlight_names = lowlight_names 24 | self.gt_names = gt_names 25 | self.val_data_dir = val_data_dir 26 | self.data_list=val_list 27 | def get_images(self, index): 28 | lowlight_name = self.lowlight_names[index] 29 | padding = 8 30 | # build the folder of validation/test data in our way 31 | if os.path.exists(os.path.join(self.val_data_dir, 'low')): 32 | lowlight_img = Image.open(os.path.join(self.val_data_dir, 'low', lowlight_name)) 33 | if os.path.exists(os.path.join(self.val_data_dir, 'high')) : 34 | gt_name = self.gt_names[index] 35 | gt_img = Image.open(os.path.join(self.val_data_dir, 'high', gt_name)) ## 36 | a = lowlight_img.size 37 | 38 | a_0 =a[1] - np.mod(a[1],padding) 39 | a_1 =a[0] - np.mod(a[0],padding) 40 | lowlight_crop_img = lowlight_img.crop((0, 0, 0 + a_1, 0+a_0)) 41 | gt_crop_img = gt_img.crop((0, 0, 0 + a_1, 0+a_0)) 42 | transform_lowlight = Compose([ToTensor()]) 43 | transform_gt = Compose([ToTensor()]) 44 | lowlight_img = transform_lowlight(lowlight_crop_img) 45 | gt_img = transform_gt(gt_crop_img) 46 | else: 47 | # the inputs is used to calculate PSNR. 
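                # With no ground-truth folder, the input itself stands in as the reference.
                # Height and width are still cropped down to the nearest multiple of `padding`
                # (8) so the image survives the network's repeated 2x down/upsampling without a
                # size mismatch, e.g. a 403x605 image is cropped to 400x600.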
48 | a = lowlight_img.size 49 | a_0 =a[1] - np.mod(a[1],padding) 50 | a_1 =a[0] - np.mod(a[0],padding) 51 | lowlight_crop_img = lowlight_img.crop((0, 0, 0 + a_1, 0+a_0)) 52 | gt_crop_img = lowlight_crop_img 53 | transform_lowlight = Compose([ToTensor() ]) 54 | transform_gt = Compose([ToTensor()]) 55 | lowlight_img = transform_lowlight(lowlight_crop_img) 56 | gt_img = transform_gt(gt_crop_img) 57 | # Any folder containing validation/test images 58 | else: 59 | lowlight_img = Image.open(os.path.join(self.val_data_dir, lowlight_name)) 60 | a = lowlight_img.size 61 | a_0 =a[1] - np.mod(a[1],padding) 62 | a_1 =a[0] - np.mod(a[0],padding) 63 | lowlight_crop_img = lowlight_img.crop((0, 0, 0 + a_1, 0+a_0)) 64 | gt_crop_img = lowlight_crop_img 65 | transform_lowlight = Compose([ToTensor()]) 66 | transform_gt = Compose([ToTensor()]) 67 | lowlight_img = transform_lowlight(lowlight_crop_img) 68 | gt_img = transform_gt(gt_crop_img) 69 | return lowlight_img, gt_img, lowlight_name 70 | 71 | 72 | def __getitem__(self, index): 73 | res = self.get_images(index) 74 | return res 75 | 76 | def __len__(self): 77 | return len(self.lowlight_names) 78 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = "1" #str(args.gpu_id) 4 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 5 | # torch imports 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.nn.functional as F 12 | from collections import OrderedDict 13 | 14 | # other imports 15 | 16 | import sys 17 | sys.path.append('.') 18 | import argparse 19 | from thop import profile 20 | import numpy as np 21 | from IQA_pytorch import SSIM, MS_SSIM 22 | from tqdm import tqdm 23 | import pyiqa 24 | import cv2 25 | import matplotlib as plt 26 | import PIL.ImageDraw as ImageDraw 27 | import PIL.ImageFont as ImageFont 28 | from PIL import Image 29 | 30 | # ours imports 31 | import custom_utils 32 | from custom_utils.data_loaders.lol import wholeDataLoader 33 | from model.model_builder import LLENet 34 | from custom_utils.img_resize import pad_input, crop_output 35 | from custom_utils.validation import PSNR, validation 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--gpu_id', type=str, default=0) 39 | parser.add_argument('--testSet_path', type=str, default='/data/xr/Dataset/light_dataset/LOL_v1/eval15') 40 | parser.add_argument('--save', type=bool, default=True) 41 | parser.add_argument('--channel', type=int, default=32) 42 | parser.add_argument('--model_name', type=str, default='testpath') 43 | parser.add_argument('--weight_path', type=str, default='./checkpoints/LOL_v1/testpath/models/model_latest.pth') 44 | parser.add_argument('--save_path', type=str, default='./results/LOL_v1') 45 | args = parser.parse_args() 46 | 47 | print(args) 48 | 49 | 50 | def load_checkpoint(model, weights): 51 | checkpoint = torch.load(weights) 52 | # model.load_state_dict(checkpoint["state_dict"]) 53 | try: 54 | model.load_state_dict(checkpoint["state_dict"]) 55 | except: 56 | state_dict = checkpoint["state_dict"] 57 | new_state_dict = OrderedDict() 58 | for k, v in state_dict.items(): 59 | name = k[7:] # remove `module.` 60 | new_state_dict[name] = v 61 | model.load_state_dict(new_state_dict) 62 | 63 | def eval(): 64 | val_dataset = wholeDataLoader(images_path=args.testSet_path, mode='test') 65 | 
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) 66 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 67 | model = LLENet(dim=args.channel).cuda() 68 | 69 | test_tensor = torch.randn(1, 3, 256, 256) 70 | flops, params = profile(model.cuda(), ((test_tensor.cuda()),)) 71 | print('flops: %.2f G, params: %.2f M' % (flops / 1000000000.0, params / 1000000.0)) 72 | 73 | params_number = params / 1000000.0 74 | flops_number = flops / 1000000000.0 75 | 76 | # model.load_state_dict(torch.load(args.weight_path)) 77 | load_checkpoint(model, args.weight_path) 78 | model.eval() 79 | 80 | 81 | ssim = SSIM() 82 | psnr = PSNR() 83 | ssim_list = [] 84 | psnr_list = [] 85 | lpips_list = [] 86 | niqe_list = [] 87 | 88 | if args.save: 89 | result_path = os.path.join(args.save_path, args.model_name) 90 | if not os.path.exists(result_path): 91 | custom_utils.mkdir(result_path) 92 | 93 | if os.path.exists(result_path + '/result.txt'): 94 | os.remove(result_path + '/result.txt') 95 | 96 | with torch.no_grad(): 97 | for i, imgs in enumerate(tqdm(val_loader), 0): 98 | low_img, high_img, name = imgs[0].cuda(), imgs[1].cuda(), str(imgs[2][0]) 99 | enhanced_img = model(low_img) 100 | 101 | 102 | # 图像上不输出相关信息 103 | if args.save: 104 | save_path = os.path.join(args.save_path, args.model_name) 105 | save_name = str(name) + '.png' 106 | save_file = os.path.join(save_path, save_name) 107 | torchvision.utils.save_image(enhanced_img, save_file) 108 | 109 | if __name__ == "__main__": 110 | 111 | eval() -------------------------------------------------------------------------------- /custom_utils/UHD_load/train_data_aug_local.py: -------------------------------------------------------------------------------- 1 | # --- Imports --- # 2 | import torch.utils.data as data 3 | from PIL import Image,ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | 6 | import torchvision.transforms as tfs 7 | from torchvision.transforms import functional as FF 8 | import imghdr 9 | import random 10 | import torch 11 | import numpy as np 12 | # from basicsr.utils import DiffJPEG, USMSharp 13 | from skimage import io, color 14 | import PIL 15 | import torchvision 16 | 17 | # --- Training dataset --- # 18 | class TrainData(data.Dataset): 19 | def __init__(self, crop_size, train_data_dir): 20 | super().__init__() 21 | # self.usm_sharpener = USMSharp().cuda() # do usm sharpening 22 | torch.multiprocessing.set_start_method('spawn', force=True) 23 | 24 | train_list = train_data_dir +'data_list.txt' #'datalist.txt''train_list_recap.txt' 'fitered_trainingdata.txt' 25 | with open(train_list) as f: 26 | contents = f.readlines() 27 | lowlight_names = [i.strip() for i in contents] 28 | gt_names = lowlight_names#[i.split('_')[0] for i in lowlight_names] 29 | 30 | self.lowlight_names = lowlight_names 31 | self.gt_names = gt_names 32 | self.crop_size = crop_size 33 | self.size_w = crop_size[0] 34 | self.size_h = crop_size[1] 35 | self.train_data_dir = train_data_dir 36 | def get_images(self, index): 37 | lowlight_name = self.lowlight_names[index] 38 | gt_name = self.gt_names[index] 39 | 40 | lowlight = Image.open(self.train_data_dir + 'low/' + lowlight_name).convert('RGB') #'input_unprocess_aligned/' v 41 | clear = Image.open(self.train_data_dir + 'high/' + gt_name ).convert('RGB') #'gt_unprocess_aligned/''high/' 42 | 43 | if not isinstance(self.crop_size,str): 44 | i,j,h,w=tfs.RandomCrop.get_params(lowlight,output_size=(self.size_w,self.size_h)) 45 | # 
i,j,h,w=tfs.RandomCrop.get_params(lowlight,output_size=(2160,3840)) 46 | lowlight=FF.crop(lowlight,i,j,h,w) 47 | clear=FF.crop(clear,i,j,h,w) 48 | 49 | 50 | data, target=self.augData(lowlight.convert("RGB") ,clear.convert("RGB") ) 51 | 52 | return data, target #, lowlight.resize((width/8, height/8)),gt.resize((width/8, height/8))#,factor 53 | def augData(self,data,target): 54 | #if self.train: 55 | if 1: 56 | rand_hor=random.randint(0,1) 57 | rand_rot=random.randint(0,3) 58 | data=tfs.RandomHorizontalFlip(rand_hor)(data) 59 | target=tfs.RandomHorizontalFlip(rand_hor)(target) 60 | if rand_rot: 61 | data=FF.rotate(data,90*rand_rot) 62 | target=FF.rotate(target,90*rand_rot) 63 | 64 | data=tfs.ToTensor()(data) 65 | target=tfs.ToTensor()(target) 66 | 67 | return data, target 68 | def __getitem__(self, index): 69 | res = self.get_images(index) 70 | return res 71 | 72 | def __len__(self): 73 | return len(self.lowlight_names) 74 | 75 | 76 | def cutblur(self, im1, im2, prob=1.0, alpha=1.0): 77 | if im1.size() != im2.size(): 78 | raise ValueError("im1 and im2 have to be the same resolution.") 79 | 80 | if alpha <= 0 or np.random.rand(1) >= prob: 81 | return im1, im2 82 | 83 | cut_ratio = np.random.randn()* 0.1+ alpha 84 | 85 | h, w = im2.size(0), im2.size(1) 86 | ch, cw = np.int(h*cut_ratio), np.int(w*cut_ratio) 87 | cy = np.random.randint(0, h-ch+1) 88 | cx = np.random.randint(0, w-cw+1) 89 | 90 | # apply CutBlur to inside or outside 91 | if np.random.random() > 0.3: #0.5 92 | im2[cy:cy+ch, cx:cx+cw,:] = im1[cy:cy+ch, cx:cx+cw,:] 93 | 94 | return im1, im2 95 | 96 | def tensor_to_image(self,tensor): 97 | tensor = tensor*255 98 | tensor = np.array(tensor, dtype=np.uint8) 99 | if np.ndim(tensor)>3: 100 | assert tensor.shape[0] == 1 101 | tensor = tensor[0] 102 | return PIL.Image.fromarray(tensor) -------------------------------------------------------------------------------- /custom_utils/metrics/brisque.py: -------------------------------------------------------------------------------- 1 | import math 2 | import scipy.special 3 | import numpy as np 4 | import cv2 5 | import scipy as sp 6 | from joblib import load 7 | 8 | gamma_range = np.arange(0.2, 10, 0.001) 9 | a = scipy.special.gamma(2.0/gamma_range) 10 | a *= a 11 | b = scipy.special.gamma(1.0/gamma_range) 12 | c = scipy.special.gamma(3.0/gamma_range) 13 | prec_gammas = a/(b*c) 14 | 15 | 16 | def aggd_features(imdata): 17 | # flatten imdata 18 | imdata.shape = (len(imdata.flat),) 19 | imdata2 = imdata*imdata 20 | left_data = imdata2[imdata < 0] 21 | right_data = imdata2[imdata >= 0] 22 | left_mean_sqrt = 0 23 | right_mean_sqrt = 0 24 | if len(left_data) > 0: 25 | left_mean_sqrt = np.sqrt(np.average(left_data)) 26 | if len(right_data) > 0: 27 | right_mean_sqrt = np.sqrt(np.average(right_data)) 28 | 29 | if right_mean_sqrt != 0: 30 | gamma_hat = left_mean_sqrt/right_mean_sqrt 31 | else: 32 | gamma_hat = np.inf 33 | # solve r-hat norm 34 | 35 | imdata2_mean = np.mean(imdata2) 36 | if imdata2_mean != 0: 37 | r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2)) 38 | else: 39 | r_hat = np.inf 40 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * 41 | (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) 42 | 43 | # solve alpha by guessing values that minimize ro 44 | pos = np.argmin((prec_gammas - rhat_norm)**2) 45 | alpha = gamma_range[pos] 46 | 47 | gam1 = scipy.special.gamma(1.0/alpha) 48 | gam2 = scipy.special.gamma(2.0/alpha) 49 | gam3 = scipy.special.gamma(3.0/alpha) 50 | 51 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3) 52 | 
bl = aggdratio * left_mean_sqrt 53 | br = aggdratio * right_mean_sqrt 54 | 55 | # mean parameter 56 | N = (br - bl)*(gam2 / gam1) # *aggdratio 57 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) 58 | 59 | 60 | def ggd_features(imdata): 61 | nr_gam = 1/prec_gammas 62 | sigma_sq = np.var(imdata) 63 | E = np.mean(np.abs(imdata)) 64 | rho = sigma_sq/E**2 65 | pos = np.argmin(np.abs(nr_gam - rho)) 66 | return gamma_range[pos], sigma_sq 67 | 68 | 69 | def paired_product(new_im): 70 | shift1 = np.roll(new_im.copy(), 1, axis=1) 71 | shift2 = np.roll(new_im.copy(), 1, axis=0) 72 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) 73 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) 74 | 75 | H_img = shift1 * new_im 76 | V_img = shift2 * new_im 77 | D1_img = shift3 * new_im 78 | D2_img = shift4 * new_im 79 | 80 | return (H_img, V_img, D1_img, D2_img) 81 | 82 | 83 | def calculate_mscn(dis_image): 84 | dis_image = dis_image.astype(np.float32) # 类型转换十分重要 85 | ux = cv2.GaussianBlur(dis_image, (7, 7), 7/6) 86 | ux_sq = ux*ux 87 | sigma = np.sqrt(np.abs(cv2.GaussianBlur(dis_image**2, (7, 7), 7/6)-ux_sq)) 88 | 89 | mscn = (dis_image-ux)/(1+sigma) 90 | 91 | return mscn 92 | 93 | 94 | def ggd_features(imdata): 95 | nr_gam = 1/prec_gammas 96 | sigma_sq = np.var(imdata) 97 | E = np.mean(np.abs(imdata)) 98 | rho = sigma_sq/E**2 99 | pos = np.argmin(np.abs(nr_gam - rho)) 100 | return gamma_range[pos], sigma_sq 101 | 102 | 103 | def extract_brisque_feats(mscncoefs): 104 | alpha_m, sigma_sq = ggd_features(mscncoefs.copy()) 105 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs) 106 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1) 107 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2) 108 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3) 109 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4) 110 | # print(alpha_m, alpha1) 111 | return [ 112 | alpha_m, sigma_sq, 113 | alpha1, N1, lsq1**2, rsq1**2, # (V) 114 | alpha2, N2, lsq2**2, rsq2**2, # (H) 115 | alpha3, N3, lsq3**2, rsq3**2, # (D1) 116 | alpha4, N4, lsq4**2, rsq4**2, # (D2) 117 | ] 118 | 119 | 120 | def brisque(im): 121 | mscncoefs = calculate_mscn(im) 122 | features1 = extract_brisque_feats(mscncoefs) 123 | lowResolution = cv2.resize(im, (0, 0), fx=0.5, fy=0.5) 124 | features2 = extract_brisque_feats(lowResolution) 125 | 126 | return np.array(features1+features2) 127 | 128 | def brisque_val(im): 129 | im = np.transpose(im.cpu().numpy(), (2,3,1,0)) 130 | im = im.squeeze() 131 | feature = brisque(im) 132 | feature = feature.reshape(1, -1) 133 | clf = load('toolkit/metrics/svr_brisque.joblib') 134 | score = clf.predict(feature)[0] 135 | return score 136 | -------------------------------------------------------------------------------- /model/modules/SSF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class InceptionModule(nn.Module): 6 | def __init__(self, in_channels, out_channels): 7 | super(InceptionModule, self).__init__() 8 | self.branch1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 9 | self.branch3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 10 | self.branch5x5 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2) 11 | self.branch_pool = nn.Conv2d(in_channels, out_channels, kernel_size=1) 12 | self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 13 | 14 | self.conv = nn.Conv2d(in_channels * 4, 1, kernel_size=1) 15 | 
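        # The four branch outputs are concatenated into in_channels * 4 maps and squeezed to a
        # single-channel map; SSF later passes this map through a sigmoid to form a spatial gate.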
16 | self._initialize_weights() 17 | 18 | def _initialize_weights(self): 19 | # Custom initialization for each convolution layer 20 | nn.init.kaiming_normal_(self.branch1x1.weight, mode='fan_out', nonlinearity='relu') 21 | nn.init.kaiming_normal_(self.branch3x3.weight, mode='fan_out', nonlinearity='relu') 22 | nn.init.kaiming_normal_(self.branch5x5.weight, mode='fan_out', nonlinearity='relu') 23 | nn.init.kaiming_normal_(self.branch_pool.weight, mode='fan_out', nonlinearity='relu') 24 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 25 | 26 | def forward(self, x): 27 | branch1x1 = self.branch1x1(x) 28 | branch3x3 = self.branch3x3(x) 29 | branch5x5 = self.branch5x5(x) 30 | branch_pool = self.branch_pool(self.pool(x)) 31 | 32 | outputs = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1) 33 | outputs = self.conv(outputs) 34 | return outputs 35 | 36 | # State-aware Selective Fusion (SSF) module 37 | class SSF(nn.Module): 38 | def __init__(self, num_feats, encode_channels, target_channels): 39 | super(SSF, self).__init__() 40 | self.num_feats = num_feats 41 | self.target_channels = target_channels 42 | 43 | # Alignment convolution layers for each encoder feature map 44 | self.align_convs = nn.ModuleList([ 45 | nn.Conv2d(in_channels, target_channels, kernel_size=1) 46 | for in_channels in encode_channels 47 | ]) 48 | 49 | self.conv_fusion = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1) 50 | self.inception = InceptionModule(num_feats, num_feats) 51 | self.final_conv = nn.Conv2d(target_channels * 2, target_channels, kernel_size=3, padding=1) 52 | 53 | self._initialize_weights() 54 | 55 | def _initialize_weights(self): 56 | for conv in self.align_convs: 57 | nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu') 58 | nn.init.kaiming_normal_(self.conv_fusion.weight, mode='fan_out', nonlinearity='relu') 59 | nn.init.kaiming_normal_(self.final_conv.weight, mode='fan_out', nonlinearity='relu') 60 | 61 | 62 | def forward(self, x, feat_list): 63 | target_height, target_width = x.size(2), x.size(3) 64 | feat_list.reverse() 65 | # Process each feature map in feat_list 66 | aligned_feats = [] 67 | encoder_feat = torch.zeros_like(x) 68 | 69 | for i, feat in enumerate(feat_list): 70 | if feat.size(2) == target_height and feat.size(3) == target_width: 71 | encoder_feat = feat 72 | 73 | feat = torch.mean(feat, dim=1, keepdim=True) 74 | feat = F.interpolate(feat, size=(target_height, target_width), mode='bilinear', align_corners=False) 75 | aligned_feats.append(feat) 76 | 77 | # Stack the aligned feature maps along the channel dimension 78 | stacked_feats = torch.cat(aligned_feats, dim=1) 79 | 80 | # Fuse features along the N dimension 81 | fused_feat = self.conv_fusion(stacked_feats) 82 | 83 | # Apply the Inception module for multi-scale feature extraction 84 | inception_feat = self.inception(fused_feat) 85 | inception_feat = torch.sigmoid(inception_feat) 86 | 87 | guided_feat = inception_feat * encoder_feat 88 | 89 | output = self.final_conv(torch.cat([guided_feat, x], dim=1)) 90 | 91 | return output 92 | 93 | class NormLayer(nn.Module): 94 | def __init__(self, num_channels): 95 | super(NormLayer, self).__init__() 96 | 97 | # Learnable scaling and bias parameters 98 | self.scale = nn.Parameter(torch.ones(1, num_channels, 1, 1)) 99 | self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 100 | 101 | def forward(self, x): 102 | # Apply normalization with learned scaling and bias 103 | return x * self.scale + self.bias 104 | 
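# A small usage sketch for SSF (shapes are illustrative only, mirroring the decoder's second
# fusion stage with dim=32): every encoder feature is channel-averaged, resized to the target
# resolution and fused into a single sigmoid gate, which then modulates the encoder feature
# whose resolution already matches the decoder feature before the final 3x3 fusion conv.
if __name__ == '__main__':
    x1 = torch.randn(1, 32, 256, 256)    # stage-1 encoder feature (C, H, W)
    x2 = torch.randn(1, 64, 128, 128)    # stage-2 encoder feature (2C, H/2, W/2)
    x3 = torch.randn(1, 128, 64, 64)     # stage-3 encoder feature (4C, H/4, W/4)
    decoder_feat = torch.randn(1, 64, 128, 128)   # upsampled decoder feature to be refined
    fuse = SSF(num_feats=3, encode_channels=[128, 64, 32], target_channels=64)
    out = fuse(decoder_feat, [x1, x2, x3])
    print(out.shape)  # torch.Size([1, 64, 128, 128])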
-------------------------------------------------------------------------------- /custom_utils/data_loaders/exposure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import torch 5 | import torch.utils.data as data 6 | 7 | import numpy as np 8 | import glob 9 | import random 10 | import cv2 11 | from glob import glob 12 | 13 | # By Ziteng Cui, cui@mi.t.u-tokyo.ac.jp 14 | random.seed(1143) 15 | 16 | 17 | # input: low light image path 18 | # return: train image ids, test image ids 19 | 20 | def populate_train_list(images_path, mode='train'): 21 | train_list = [os.path.basename(f) for f in glob(os.path.join(images_path, '*.JPG'))] 22 | train_list.sort() 23 | 24 | if mode == 'train': 25 | random.shuffle(train_list) 26 | 27 | return train_list 28 | 29 | 30 | class exposure_loader(data.Dataset): 31 | 32 | def __init__(self, images_path, mode='train', expert='c', normalize=False): 33 | self.train_list = populate_train_list(images_path, mode) 34 | # self.h, self.w = int(img_size[0]), int(img_size[1]) 35 | self.mode = mode # train or test 36 | self.data_list = self.train_list 37 | self.low_path = images_path 38 | if self.mode == 'train' or self.mode == 'val': 39 | self.high_path = images_path.replace('INPUT_IMAGES', 'GT_IMAGES') 40 | elif self.mode == 'test': 41 | self.high_path = images_path.replace('INPUT_IMAGES', 'expert_'+expert+'_testing_set') 42 | self.normalize = normalize 43 | self.resize = True 44 | self.image_size = 1200 45 | self.image_size_w = 900 46 | #self.image_size = 512 47 | #self.image_size_w = 512 48 | # self.test_resize = True 49 | print("Total examples:", len(self.data_list)) 50 | # print("Total testing examples:", len(self.test_list)) 51 | # self.transform_train = transforms.Compose() 52 | 53 | def FLIP_aug(self, low, high): 54 | if random.random() > 0.5: 55 | low = cv2.flip(low, 0) 56 | high = cv2.flip(high, 0) 57 | 58 | if random.random() > 0.5: 59 | low = cv2.flip(low, 1) 60 | high = cv2.flip(high, 1) 61 | 62 | return low, high 63 | 64 | def get_params(self, low): 65 | self.h, self.w = low.shape[0], low.shape[1] # 900, 1200 66 | # print(self.h, self.w) 67 | # self.crop_height = random.randint(self.h / 2, self.h) # random.randint(self.MinCropHeight, self.MaxCropHeight) 68 | # self.crop_width = random.randint(self.w / 2, self.w) # random.randint(self.MinCropWidth,self.MaxCropWidth) 69 | self.crop_height = self.h / 2 # random.randint(self.MinCropHeight, self.MaxCropHeight) 70 | self.crop_width = self.w / 2 # random.randint(self.MinCropWidth,self.MaxCropWidth) 71 | 72 | i = random.randint(0, self.h - self.crop_height) 73 | j = random.randint(0, self.w - self.crop_width) 74 | return i, j 75 | 76 | def Random_Crop(self, low, high): 77 | self.i, self.j = self.get_params(low) 78 | self.i, self.j = int(self.i), int(self.j) 79 | # if random.random() > 0.5: 80 | low = low[self.i: self.i + int(self.crop_height), self.j: self.j + int(self.crop_width)] 81 | high = high[self.i: self.i + int(self.crop_height), self.j: self.j + int(self.crop_width)] 82 | return low, high 83 | 84 | def __getitem__(self, index): 85 | img_id = self.data_list[index] 86 | a = img_id.rfind('_') 87 | img_id_gt = img_id[:a] 88 | 89 | data_lowlight = cv2.imread(osp.join(self.low_path, img_id), cv2.IMREAD_UNCHANGED) 90 | data_highlight = cv2.imread(osp.join(self.high_path, img_id_gt+'.jpg'), cv2.IMREAD_UNCHANGED) 91 | 92 | if data_lowlight.shape[0] >= data_lowlight.shape[1]: 93 | data_lowlight = cv2.transpose(data_lowlight) 94 | 
data_highlight = cv2.transpose(data_highlight) 95 | 96 | if self.resize: 97 | data_lowlight = cv2.resize(data_lowlight, (self.image_size, self.image_size_w)) 98 | data_highlight = cv2.resize(data_highlight, (self.image_size, self.image_size_w)) 99 | # print(data_lowlight.shape) 100 | if self.mode == 'train': # data augmentation 101 | data_lowlight, data_highlight = self.FLIP_aug(data_lowlight, data_highlight) 102 | # data_lowlight, data_highlight = self.Random_Crop(data_lowlight, data_highlight) 103 | # print(data_lowlight.shape) 104 | data_lowlight = (np.asarray(data_lowlight[..., ::-1]) / 255.0) 105 | data_highlight = (np.asarray(data_highlight[..., ::-1]) / 255.0) 106 | 107 | data_lowlight = torch.from_numpy(data_lowlight).float() # float32 108 | data_highlight = torch.from_numpy(data_highlight).float() # float32 109 | 110 | return data_lowlight.permute(2, 0, 1), data_highlight.permute(2, 0, 1) 111 | 112 | def __len__(self): 113 | return len(self.data_list) 114 | 115 | 116 | if __name__ == "__main__": 117 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 118 | train_path = '/data/unagi0/cui_data/light_dataset/Exposure_CVPR21/train/INPUT_IMAGES' 119 | train_dataset = exposure_loader(train_path, mode='train') 120 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1, 121 | pin_memory=True) 122 | for iteration, imgs in enumerate(train_loader): 123 | print(iteration) 124 | print(imgs[0].shape) 125 | print(imgs[1].shape) 126 | low_img = imgs[0] 127 | high_img = imgs[1] -------------------------------------------------------------------------------- /custom_utils/data_loaders/mit5k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | 5 | import torch 6 | import torch.utils.data as data 7 | 8 | import numpy as np 9 | from PIL import Image, ImageOps 10 | import glob 11 | import random 12 | import cv2 13 | import torchvision 14 | from torchvision import transforms 15 | from torchvision.transforms import Compose, ToTensor, Normalize, ConvertImageDtype 16 | from glob import glob 17 | # By Ziteng Cui, cui@mi.t.u-tokyo.ac.jp 18 | random.seed(1143) 19 | 20 | 21 | # input: low light image path 22 | # return: train image ids, test image ids 23 | 24 | def populate_train_list(images_path, mode='train'): 25 | 26 | train_list = [os.path.basename(f) for f in glob(os.path.join(images_path, '*.jpg'))] 27 | train_list.sort() 28 | 29 | if mode == 'train': 30 | random.shuffle(train_list) 31 | 32 | return train_list 33 | 34 | 35 | class adobe5k_loader(data.Dataset): 36 | 37 | def __init__(self, images_path, mode='train', normalize=False): 38 | self.train_list = populate_train_list(images_path, mode) 39 | # self.h, self.w = int(img_size[0]), int(img_size[1]) 40 | self.mode = mode # train or test 41 | self.data_list = self.train_list 42 | self.low_path = images_path 43 | self.high_path = images_path.replace('Inputs_jpg', 'Experts_C') 44 | self.normalize = normalize 45 | self.resize = True 46 | #self.image_size = 1200 47 | #self.image_size_w = 900 48 | self.image_size = 600 49 | self.image_size_w = 450 50 | #self.test_resize = True 51 | print("Total examples:", len(self.data_list)) 52 | #print("Total testing examples:", len(self.test_list)) 53 | # self.transform_train = transforms.Compose() 54 | 55 | def FLIP_aug(self, low, high): 56 | if random.random() > 0.5: 57 | low = cv2.flip(low, 0) 58 | high = cv2.flip(high, 0) 59 | 60 | if random.random() > 0.5: 61 | low = cv2.flip(low, 1) 62 | 
high = cv2.flip(high, 1) 63 | 64 | return low, high 65 | 66 | 67 | def get_params(self, low): 68 | self.h, self.w = low.shape[0], low.shape[1] # 900, 1200 69 | #print(self.h, self.w) 70 | #self.crop_height = random.randint(self.h / 2, self.h) # random.randint(self.MinCropHeight, self.MaxCropHeight) 71 | #self.crop_width = random.randint(self.w / 2, self.w) # random.randint(self.MinCropWidth,self.MaxCropWidth) 72 | self.crop_height = self.h / 2 #random.randint(self.MinCropHeight, self.MaxCropHeight) 73 | self.crop_width = self.w / 2 #random.randint(self.MinCropWidth,self.MaxCropWidth) 74 | 75 | i = random.randint(0, self.h - self.crop_height) 76 | j = random.randint(0, self.w - self.crop_width) 77 | return i, j 78 | 79 | def Random_Crop(self, low, high): 80 | self.i, self.j = self.get_params(low) 81 | self.i, self.j = int(self.i), int(self.j) 82 | #if random.random() > 0.5: 83 | low = low[self.i: self.i + int(self.crop_height), self.j: self.j + int(self.crop_width)] 84 | high = high[self.i: self.i + int(self.crop_height), self.j: self.j + int(self.crop_width)] 85 | return low, high 86 | 87 | def __getitem__(self, index): 88 | img_id = self.data_list[index] 89 | 90 | #data_lowlight = Image.open(osp.join(self.low_path, img_id)) 91 | data_lowlight = cv2.imread(osp.join(self.low_path, img_id), cv2.IMREAD_UNCHANGED) 92 | data_highlight = cv2.imread(osp.join(self.high_path, img_id), cv2.IMREAD_UNCHANGED) 93 | 94 | if data_lowlight.shape[0] >= data_lowlight.shape[1]: 95 | data_lowlight = cv2.transpose(data_lowlight) 96 | data_highlight = cv2.transpose(data_highlight) 97 | 98 | if self.resize: 99 | data_lowlight = cv2.resize(data_lowlight, (self.image_size, self.image_size_w)) 100 | data_highlight = cv2.resize(data_highlight, (self.image_size, self.image_size_w)) 101 | #print(data_lowlight.shape) 102 | if self.mode == 'train': #data augmentation 103 | data_lowlight, data_highlight = self.FLIP_aug(data_lowlight, data_highlight) 104 | #data_lowlight, data_highlight = self.Random_Crop(data_lowlight, data_highlight) 105 | #print(data_lowlight.shape) 106 | data_lowlight = (np.asarray(data_lowlight[..., ::-1]) / 255.0) 107 | data_highlight = (np.asarray(data_highlight[..., ::-1]) / 255.0) 108 | 109 | data_lowlight = torch.from_numpy(data_lowlight).float() # float32 110 | data_highlight = torch.from_numpy(data_highlight).float() # float32 111 | 112 | return data_lowlight.permute(2, 0, 1), data_highlight.permute(2, 0, 1) 113 | 114 | def __len__(self): 115 | return len(self.data_list) 116 | 117 | 118 | if __name__ == "__main__": 119 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 120 | train_path = '/home/czt/DataSets/five5k_dataset/Inputs_jpg' 121 | test_path = '/home/czt/DataSets/five5k_dataset/UPE_testset/Inputs_jpg' 122 | test_dataset = adobe5k_loader(train_path, mode='train') 123 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1, 124 | pin_memory=True) 125 | for iteration, imgs in enumerate(test_loader): 126 | print(iteration) 127 | print(imgs[0].shape) 128 | print(imgs[1].shape) 129 | low_img = imgs[0] 130 | high_img = imgs[1] 131 | # visualization(low_img, 'show/low', iteration) 132 | # visualization(high_img, 'show/high', iteration) -------------------------------------------------------------------------------- /custom_utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def numpy_to_torch(a: np.ndarray): 7 | return 
torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) 8 | 9 | 10 | def torch_to_numpy(a: torch.Tensor): 11 | return a.squeeze(0).permute(1,2,0).numpy() 12 | 13 | 14 | def sample_patch_transformed(im, pos, scale, image_sz, transforms, is_mask=False): 15 | """Extract transformed image samples. 16 | args: 17 | im: Image. 18 | pos: Center position for extraction. 19 | scale: Image scale to extract features from. 20 | image_sz: Size to resize the image samples to before extraction. 21 | transforms: A set of image transforms to apply. 22 | """ 23 | 24 | # Get image patche 25 | im_patch, _ = sample_patch(im, pos, scale*image_sz, image_sz, is_mask=is_mask) 26 | 27 | # Apply transforms 28 | im_patches = torch.cat([T(im_patch, is_mask=is_mask) for T in transforms]) 29 | 30 | return im_patches 31 | 32 | 33 | def sample_patch_multiscale(im, pos, scales, image_sz, mode: str='replicate', max_scale_change=None): 34 | """Extract image patches at multiple scales. 35 | args: 36 | im: Image. 37 | pos: Center position for extraction. 38 | scales: Image scales to extract image patches from. 39 | image_sz: Size to resize the image samples to 40 | mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major' 41 | max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode 42 | """ 43 | if isinstance(scales, (int, float)): 44 | scales = [scales] 45 | 46 | # Get image patches 47 | patch_iter, coord_iter = zip(*(sample_patch(im, pos, s*image_sz, image_sz, mode=mode, 48 | max_scale_change=max_scale_change) for s in scales)) 49 | im_patches = torch.cat(list(patch_iter)) 50 | patch_coords = torch.cat(list(coord_iter)) 51 | 52 | return im_patches, patch_coords 53 | 54 | 55 | def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None, 56 | mode: str = 'replicate', max_scale_change=None, is_mask=False): 57 | """Sample an image patch. 
58 | 59 | args: 60 | im: Image 61 | pos: center position of crop 62 | sample_sz: size to crop 63 | output_sz: size to resize to 64 | mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major' 65 | max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode 66 | """ 67 | 68 | # if mode not in ['replicate', 'inside']: 69 | # raise ValueError('Unknown border mode \'{}\'.'.format(mode)) 70 | 71 | # copy and convert 72 | posl = pos.long().clone() 73 | 74 | pad_mode = mode 75 | 76 | # Get new sample size if forced inside the image 77 | if mode == 'inside' or mode == 'inside_major': 78 | pad_mode = 'replicate' 79 | im_sz = torch.Tensor([im.shape[2], im.shape[3]]) 80 | shrink_factor = (sample_sz.float() / im_sz) 81 | if mode == 'inside': 82 | shrink_factor = shrink_factor.max() 83 | elif mode == 'inside_major': 84 | shrink_factor = shrink_factor.min() 85 | shrink_factor.clamp_(min=1, max=max_scale_change) 86 | sample_sz = (sample_sz.float() / shrink_factor).long() 87 | 88 | # Compute pre-downsampling factor 89 | if output_sz is not None: 90 | resize_factor = torch.min(sample_sz.float() / output_sz.float()).item() 91 | df = int(max(int(resize_factor - 0.1), 1)) 92 | else: 93 | df = int(1) 94 | 95 | sz = sample_sz.float() / df # new size 96 | 97 | # Do downsampling 98 | if df > 1: 99 | os = posl % df # offset 100 | posl = (posl - os) / df # new position 101 | im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample 102 | else: 103 | im2 = im 104 | 105 | # compute size to crop 106 | szl = torch.max(sz.round(), torch.Tensor([2])).long() 107 | 108 | # Extract top and bottom coordinates 109 | tl = posl - (szl - 1)/2 110 | br = posl + szl/2 + 1 111 | 112 | # Shift the crop to inside 113 | if mode == 'inside' or mode == 'inside_major': 114 | im2_sz = torch.LongTensor([im2.shape[2], im2.shape[3]]) 115 | shift = (-tl).clamp(0) - (br - im2_sz).clamp(0) 116 | tl += shift 117 | br += shift 118 | 119 | outside = ((-tl).clamp(0) + (br - im2_sz).clamp(0)) // 2 120 | shift = (-tl - outside) * (outside > 0).long() 121 | tl += shift 122 | br += shift 123 | 124 | # Get image patch 125 | # im_patch = im2[...,tl[0].item():br[0].item(),tl[1].item():br[1].item()] 126 | 127 | # Get image patch 128 | if not is_mask: 129 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2]), pad_mode) 130 | else: 131 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2])) 132 | 133 | # Get image coordinates 134 | patch_coord = df * torch.cat((tl, br)).view(1,4) 135 | 136 | if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]): 137 | return im_patch.clone(), patch_coord 138 | 139 | # Resample 140 | if not is_mask: 141 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear') 142 | else: 143 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='nearest') 144 | 145 | return im_patch, patch_coord 146 | -------------------------------------------------------------------------------- /model/modules/LAN.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import math, os 3 | 4 | import logging 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | import torch.utils.checkpoint as cp 10 | import numbers 11 | from .drop import DropPath 12 | 13 | 14 | ''' 15 | # Global 
Average Pooling (GAP) as a Luminance Estimation 16 | - The use of Global Average Pooling (GAP) on features, both on intermediate feature maps (inter_feat) and the input feature map (x), 17 | simulates a form of global luminance perception. The GAP extracts a summary of the overall brightness of an image, 18 | which is akin to how the human eye averages out brightness over large areas, 19 | especially in low-light conditions where local contrast is reduced. 20 | 21 | - The GAP operation takes an average over spatial dimensions to create a single luminance descriptor for each channel, 22 | which can be considered as a global "luminance estimation." 23 | ''' 24 | 25 | 26 | class LuminanceAdaptiveNorm(nn.Module): 27 | def __init__(self, dim, channel_first=True, seed=42): 28 | super().__init__() 29 | 30 | torch.manual_seed(seed) 31 | 32 | self.alpha = nn.Parameter(torch.ones([1, 1, dim])) 33 | self.beta = nn.Parameter(torch.zeros([1, 1, dim])) 34 | self.color = nn.Parameter(torch.eye(dim)) 35 | self.channel_first = channel_first 36 | 37 | self.gap = nn.AdaptiveAvgPool2d(1) # Global Average Pooling to [B, C', 1, 1] 38 | 39 | self.mlp = nn.Sequential( 40 | nn.Linear(3 * dim, dim), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(dim, dim) 43 | ) 44 | 45 | self._initialize_weights() 46 | 47 | def _initialize_weights(self): 48 | nn.init.kaiming_uniform_(self.mlp[0].weight, a=0, mode='fan_in', nonlinearity='relu') 49 | nn.init.xavier_uniform_(self.mlp[2].weight) 50 | 51 | def _gap_and_pad_features(self, inter_feat, x_): 52 | """Applies GAP to input features and pads channels to match the maximum.""" 53 | gap_feats = [self.gap(feat) for feat in inter_feat] # [B, C', 1, 1] 54 | gap_x = self.gap(x_) 55 | gap_feats.append(gap_x) 56 | 57 | # Find maximum channels in inter_feat and apply GAP with padding 58 | max_channels = max([feat.shape[1] for feat in gap_feats]) 59 | 60 | # Zero-pad features to match the maximum number of channels 61 | padded_feats = [] 62 | for feat in gap_feats: 63 | B, C_gap, _, _ = feat.shape 64 | if C_gap < max_channels: 65 | padding = torch.zeros(B, max_channels - C_gap, 1, 1, device=feat.device) 66 | feat = torch.cat([feat, padding], dim=1) # Concatenate along the channel dimension 67 | padded_feats.append(feat) 68 | 69 | return torch.stack(padded_feats, dim=0).squeeze(-1).squeeze(-1) # [num_feats, B, max_C] 70 | 71 | def _apply_convolutions(self, stacked_feats, num_feats, x_device): 72 | """Applies convolution operations with different kernel sizes.""" 73 | conv_results = [] 74 | 75 | for kernel_size in [1, 3, 5]: 76 | conv_layer = nn.Conv2d(1, 1, kernel_size=(num_feats, kernel_size), padding=0, bias=False).to(x_device) 77 | conv_out = conv_layer(stacked_feats).squeeze(2).squeeze(1) # Squeeze unnecessary dimensions 78 | conv_results.append(conv_out) 79 | 80 | return conv_results 81 | 82 | def _apply_linear_layers(self, conv_results, C, x_device): 83 | """Apply Linear layers to match the feature channels with C.""" 84 | linear_results = [] 85 | 86 | for conv_out in conv_results: 87 | linear_layer = nn.Linear(conv_out.shape[-1], C).to(x_device) 88 | linear_results.append(linear_layer(conv_out)) 89 | 90 | return torch.cat(linear_results, dim=-1) 91 | 92 | def forward(self, x, inter_feat, patch_resolution): 93 | if x.dim() == 4: 94 | B, _, N, C = x.shape 95 | x = x.view(B, N, C) 96 | else: 97 | B, N, C = x.shape 98 | H_x, W_x = patch_resolution 99 | 100 | # Reshape and permute x to match the input format for GAP 101 | x_reshaped = x.view(B, H_x, W_x, C).permute(0, 3, 1, 2) 102 | 103 | 
stacked_feats = self._gap_and_pad_features(inter_feat, x_reshaped) 104 | stacked_feats = stacked_feats.permute(1, 0, 2).unsqueeze(1) # [B, 1, num_feats, max_C] 105 | 106 | # Apply convolutions with different kernel sizes (1, 3, 5) 107 | conv_results = self._apply_convolutions(stacked_feats, len(inter_feat) + 1, x.device) 108 | 109 | # Apply linear layers to match the feature dimension C 110 | concatenated_features = self._apply_linear_layers(conv_results, C, x.device) 111 | 112 | conv_out = torch.tanh(self.mlp(concatenated_features)) # [B, C] 113 | 114 | # Adjust alpha using the learned conv_out features 115 | adjusted_alpha = self.alpha + conv_out.view(B, 1, C) 116 | 117 | # Normalize x and apply the adjustments 118 | mu = x.mean(dim=-1, keepdim=True) 119 | sigma = x.std(dim=-1, keepdim=True) 120 | x_normalized = (x - mu) / (sigma + 1e-3) 121 | 122 | # Apply color transform and adjustments based on channel_first flag 123 | if self.channel_first: 124 | x_transformed = torch.tensordot(x_normalized, self.color, dims=[[-1], [-1]]) 125 | x_out = x_transformed * adjusted_alpha + self.beta 126 | else: 127 | x_out = x_normalized * adjusted_alpha + self.beta 128 | x_out = torch.tensordot(x_out, self.color, dims=[[-1], [-1]]) 129 | 130 | return x_out.view(B, N, C) 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | # torch import 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import torchvision.models as models 7 | 8 | # other import 9 | import os 10 | import math 11 | import cv2 12 | import numpy as np 13 | from math import exp 14 | import pytorch_msssim 15 | 16 | 17 | class MeanShift(nn.Conv2d): 18 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 19 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 20 | std = torch.Tensor(rgb_std) 21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 22 | self.weight.data.div_(std.view(3, 1, 1, 1)) 23 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 24 | self.bias.data.div_(std) 25 | self.requires_grad = False 26 | 27 | 28 | class VGGLoss(nn.Module): 29 | def __init__(self, conv_index='54', rgb_range=1): 30 | super(VGGLoss, self).__init__() 31 | vgg_features = models.vgg19(pretrained=True).features 32 | modules = [m for m in vgg_features] 33 | if conv_index == '22': 34 | self.vgg = nn.Sequential(*modules[:8]) 35 | self.vgg.cuda() 36 | elif conv_index == '54': 37 | self.vgg = nn.Sequential(*modules[:35]) 38 | self.vgg.cuda() 39 | 40 | vgg_mean = (0.485, 0.456, 0.406) 41 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 42 | self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std).cuda() 43 | self.vgg.requires_grad = False 44 | 45 | def forward(self, sr, hr): 46 | def _forward(x): 47 | x = self.sub_mean(x) 48 | x = self.vgg(x) 49 | return x 50 | 51 | vgg_sr = _forward(sr) 52 | with torch.no_grad(): 53 | vgg_hr = _forward(hr.detach()) 54 | 55 | loss = F.mse_loss(vgg_sr, vgg_hr) 56 | 57 | return loss 58 | 59 | def gaussian(window_size, sigma): 60 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 61 | return gauss / gauss.sum() 62 | 63 | 64 | def create_window(window_size, channel): 65 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 66 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 67 | window = 
Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 68 | return window 69 | 70 | 71 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 72 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 73 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 74 | 75 | mu1_sq = mu1.pow(2) 76 | mu2_sq = mu2.pow(2) 77 | mu1_mu2 = mu1 * mu2 78 | 79 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 80 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 81 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 82 | 83 | C1 = 0.01 ** 2 84 | C2 = 0.03 ** 2 85 | 86 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 87 | 88 | if size_average: 89 | return ssim_map.mean() 90 | else: 91 | return ssim_map.mean(1).mean(1).mean(1) 92 | 93 | 94 | 95 | class SSIM_loss(torch.nn.Module): 96 | def __init__(self, window_size=11, size_average=True): 97 | super(SSIM_loss, self).__init__() 98 | self.window_size = window_size 99 | self.size_average = size_average 100 | self.channel = 1 101 | self.window = create_window(window_size, self.channel) 102 | 103 | def forward(self, img1, img2): 104 | (_, channel, _, _) = img1.size() 105 | 106 | if channel == self.channel and self.window.data.type() == img1.data.type(): 107 | window = self.window 108 | else: 109 | window = create_window(self.window_size, channel) 110 | 111 | if img1.is_cuda: 112 | window = window.cuda(img1.get_device()) 113 | window = window.type_as(img1) 114 | 115 | self.window = window 116 | self.channel = channel 117 | 118 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 119 | 120 | 121 | 122 | 123 | class L1_Charbonnier_loss(torch.nn.Module): 124 | """L1 Charbonnierloss.""" 125 | def __init__(self): 126 | super(L1_Charbonnier_loss, self).__init__() 127 | self.eps = 1e-6 128 | 129 | def forward(self, X, Y): 130 | diff = torch.add(X, -Y) 131 | error = torch.sqrt(diff * diff + self.eps) 132 | loss = torch.mean(error) 133 | return loss 134 | 135 | 136 | # Perpectual Loss 137 | class LossNetwork(torch.nn.Module): 138 | def __init__(self, vgg_model): 139 | super(LossNetwork, self).__init__() 140 | self.vgg_layers = vgg_model 141 | self.layer_name_mapping = { 142 | '3': "relu1_2", 143 | '8': "relu2_2", 144 | '15': "relu3_3" 145 | } 146 | 147 | def output_features(self, x): 148 | output = {} 149 | for name, module in self.vgg_layers._modules.items(): 150 | x = module(x) 151 | if name in self.layer_name_mapping: 152 | output[self.layer_name_mapping[name]] = x 153 | return list(output.values()) 154 | 155 | def forward(self, pred_im, gt): 156 | loss = [] 157 | pred_im_features = self.output_features(pred_im) 158 | gt_features = self.output_features(gt) 159 | for pred_im_feature, gt_feature in zip(pred_im_features, gt_features): 160 | loss.append(F.mse_loss(pred_im_feature, gt_feature)) 161 | 162 | return sum(loss)/len(loss) 163 | 164 | -------------------------------------------------------------------------------- /custom_utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | 7 | def draw_figure(fig): 8 | fig.canvas.draw() 9 | fig.canvas.flush_events() 10 | plt.pause(0.001) 11 | 12 | 13 | def show_tensor(a: torch.Tensor, 
fig_num = None, title = None, range=(None, None), ax=None): 14 | """Display a 2D tensor. 15 | args: 16 | fig_num: Figure number. 17 | title: Title of figure. 18 | """ 19 | a_np = a.squeeze().cpu().clone().detach().numpy() 20 | 21 | if a_np.ndim == 3: 22 | a_np = np.transpose(a_np, (1, 2, 0)) 23 | 24 | 25 | if ax is None: 26 | fig = plt.figure(fig_num) 27 | plt.tight_layout() 28 | plt.cla() 29 | plt.imshow(a_np, vmin=range[0], vmax=range[1]) 30 | plt.axis('off') 31 | plt.axis('equal') 32 | if title is not None: 33 | plt.title(title) 34 | draw_figure(fig) 35 | else: 36 | ax.cla() 37 | ax.imshow(a_np, vmin=range[0], vmax=range[1]) 38 | ax.set_axis_off() 39 | ax.axis('equal') 40 | if title is not None: 41 | ax.set_title(title) 42 | draw_figure(plt.gcf()) 43 | 44 | 45 | def plot_graph(a: torch.Tensor, fig_num = None, title = None): 46 | """Plot graph. Data is a 1D tensor. 47 | args: 48 | fig_num: Figure number. 49 | title: Title of figure. 50 | """ 51 | a_np = a.squeeze().cpu().clone().detach().numpy() 52 | if a_np.ndim > 1: 53 | raise ValueError 54 | fig = plt.figure(fig_num) 55 | # plt.tight_layout() 56 | plt.cla() 57 | plt.plot(a_np) 58 | if title is not None: 59 | plt.title(title) 60 | draw_figure(fig) 61 | 62 | 63 | def show_image_with_boxes(im, boxes, iou_pred=None, disp_ids=None): 64 | im_np = im.clone().cpu().squeeze().numpy() 65 | im_np = np.ascontiguousarray(im_np.transpose(1, 2, 0).astype(np.uint8)) 66 | 67 | boxes = boxes.view(-1, 4).cpu().numpy().round().astype(int) 68 | 69 | # Draw proposals 70 | for i_ in range(boxes.shape[0]): 71 | if disp_ids is None or disp_ids[i_]: 72 | bb = boxes[i_, :] 73 | disp_color = (i_*38 % 256, (255 - i_*97) % 256, (123 + i_*66) % 256) 74 | cv2.rectangle(im_np, (bb[0], bb[1]), (bb[0] + bb[2], bb[1] + bb[3]), 75 | disp_color, 1) 76 | 77 | if iou_pred is not None: 78 | text_pos = (bb[0], bb[1] - 5) 79 | cv2.putText(im_np, 'ID={} IOU = {:3.2f}'.format(i_, iou_pred[i_]), text_pos, 80 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, bottomLeftOrigin=False) 81 | 82 | im_tensor = torch.from_numpy(im_np.transpose(2, 0, 1)).float() 83 | 84 | return im_tensor 85 | 86 | 87 | 88 | def _pascal_color_map(N=256, normalized=False): 89 | """ 90 | Python implementation of the color map function for the PASCAL VOC data set. 91 | Official Matlab version can be found in the PASCAL VOC devkit 92 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 93 | """ 94 | 95 | def bitget(byteval, idx): 96 | return (byteval & (1 << idx)) != 0 97 | 98 | dtype = 'float32' if normalized else 'uint8' 99 | cmap = np.zeros((N, 3), dtype=dtype) 100 | for i in range(N): 101 | r = g = b = 0 102 | c = i 103 | for j in range(8): 104 | r = r | (bitget(c, 0) << 7 - j) 105 | g = g | (bitget(c, 1) << 7 - j) 106 | b = b | (bitget(c, 2) << 7 - j) 107 | c = c >> 3 108 | 109 | cmap[i] = np.array([r, g, b]) 110 | 111 | cmap = cmap / 255 if normalized else cmap 112 | return cmap 113 | 114 | 115 | def overlay_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 116 | """ Overlay mask over image. 117 | Source: https://github.com/albertomontesg/davis-interactive/blob/master/davisinteractive/utils/visualization.py 118 | This function allows you to overlay a mask over an image with some 119 | transparency. 120 | # Arguments 121 | im: Numpy Array. Array with the image. The shape must be (H, W, 3) and 122 | the pixels must be represented as `np.uint8` data type. 123 | ann: Numpy Array. Array with the mask. The shape must be (H, W) and the 124 | values must be intergers 125 | alpha: Float. 
Proportion of alpha to apply at the overlaid mask. 126 | colors: Numpy Array. Optional custom colormap. It must have shape (N, 3) 127 | where N is the maximum number of colors to represent. 128 | contour_thickness: Integer. Thickness of each object index contour drawn 129 | over the overlay. This function requires the package 130 | `opencv-python` to be installed. 131 | # Returns 132 | Numpy Array: Image of the overlay with shape (H, W, 3) and data type 133 | `np.uint8`. 134 | """ 135 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int) 136 | if im.shape[:-1] != ann.shape: 137 | raise ValueError('First two dimensions of `im` and `ann` must match') 138 | if im.shape[-1] != 3: 139 | raise ValueError('`im` must have three channels in the last dimension') 140 | 141 | colors = colors or _pascal_color_map() 142 | colors = np.asarray(colors, dtype=np.uint8) 143 | 144 | mask = colors[ann] 145 | fg = im * alpha + (1 - alpha) * mask 146 | 147 | img = im.copy() 148 | img[ann > 0] = fg[ann > 0] 149 | 150 | if contour_thickness: # pragma: no cover 151 | import cv2 152 | for obj_id in np.unique(ann[ann > 0]): 153 | contours = cv2.findContours((ann == obj_id).astype( 154 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 155 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 156 | contour_thickness) 157 | return img 158 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Print usage instructions -> set parameter 0 4 | print_usage() { 5 | echo "Usage: $0 <dataset_number>" 6 | echo " dataset_number:" 7 | echo " 1 -> LOL_v1" 8 | echo " 2 -> LOL_v2_real" 9 | echo " 3 -> LOL_v2_sync" 10 | echo " 4 -> MIT_5K" 11 | echo " 5 -> SID" 12 | echo " 6 -> SMID" 13 | echo " 7 -> SDSD_indoor" 14 | echo " 8 -> SDSD_outdoor" 15 | echo " 9 -> LOL_blur" 16 | } 17 | 18 | # Define function for testing LOL-v1 19 | test_LOL_v1() { 20 | python ./tools/test.py \ 21 | --gpu_id 0 \ 22 | --model_name URWKV \ 23 | --testSet_path '/data/xr/Dataset/light_dataset/LOL_v1/eval15/' \ 24 | --weight_path './checkpoints/LOL_v1/URWKV/models/model_bestPSNR.pth' \ 25 | --save_path './results/LOL_v1' 26 | 27 | python ./tools/measure.py \ 28 | --gpu_id 0 \ 29 | --model_name URWKV \ 30 | --dataset_name LOL_v1 31 | } 32 | 33 | # Define function for testing LOL_v2_real 34 | test_LOL_v2_real() { 35 | python ./tools/test.py \ 36 | --gpu_id 0 \ 37 | --model_name URWKV \ 38 | --testSet_path '/data/xr/Dataset/light_dataset/LOL_v2/Real_captured/Test/' \ 39 | --weight_path './checkpoints/LOL_v2_real/URWKV/models/model_bestPSNR.pth' \ 40 | --save_path './results/LOL_v2_real' 41 | 42 | python ./tools/measure.py \ 43 | --gpu_id 0 \ 44 | --model_name URWKV \ 45 | --dataset_name LOL_v2_real 46 | } 47 | 48 | # Define function for testing LOL_v2_sync 49 | test_LOL_v2_sync() { 50 | python ./tools/test.py \ 51 | --gpu_id 0 \ 52 | --model_name URWKV \ 53 | --testSet_path '/data/xr/Dataset/light_dataset/LOL_v2/Synthetic/Test/' \ 54 | --weight_path './checkpoints/LOL_v2_sync/URWKV/models/model_bestPSNR.pth' \ 55 | --save_path './results/LOL_v2_sync' 56 | 57 | python ./tools/measure.py \ 58 | --gpu_id 0 \ 59 | --model_name URWKV \ 60 | --dataset_name LOL_v2_sync 61 | } 62 | 63 | # Define function for testing MIT_5K 64 | test_MIT_5K() { 65 | python ./tools/test_MIT5K.py \ 66 | --gpu_id 0 \ 67 | --model_name URWKV \ 68 | --testSet_path '/data/xr/Dataset/light_dataset/MIT-Adobe-5K-512/test/' \ 69 |
--weight_path './checkpoints/MIT_5K/URWKV/models/model_bestPSNR.pth' \ 70 | --save_path './results/MIT_5K' 71 | 72 | python ./tools/measure.py \ 73 | --gpu_id 0 \ 74 | --model_name URWKV \ 75 | --dataset_name MIT_5K 76 | } 77 | 78 | # Define function for testing SID 79 | test_SID() { 80 | python ./tools/test.py \ 81 | --gpu_id 0 \ 82 | --model_name URWKV \ 83 | --testSet_path '/data/xr/Dataset/light_dataset/SID_png/eval/' \ 84 | --weight_path './checkpoints/SID/URWKV/models/model_bestPSNR.pth' \ 85 | --save_path './results/SID' 86 | 87 | python ./tools/measure.py \ 88 | --gpu_id 0 \ 89 | --model_name URWKV \ 90 | --dataset_name SID 91 | } 92 | 93 | # Define function for testing SMID 94 | test_SMID() { 95 | python ./tools/test.py \ 96 | --gpu_id 0 \ 97 | --model_name URWKV \ 98 | --testSet_path '/data/xr/Dataset/light_dataset/SMID_png/eval/' \ 99 | --weight_path './checkpoints/SMID/URWKV/models/model_bestPSNR.pth' \ 100 | --save_path './results/SMID' 101 | 102 | python ./tools/measure.py \ 103 | --gpu_id 0 \ 104 | --model_name URWKV \ 105 | --dataset_name SMID 106 | } 107 | 108 | # Define function for testing SDSD_indoor 109 | test_SDSD_indoor() { 110 | python ./tools/test.py \ 111 | --gpu_id 0 \ 112 | --model_name URWKV \ 113 | --testSet_path '/data/xr/Dataset/light_dataset/SDSD_indoor_png/eval/' \ 114 | --weight_path './checkpoints/SDSD_indoor/URWKV/models/model_bestPSNR.pth' \ 115 | --save_path './results/SDSD_indoor' 116 | 117 | python ./tools/measure.py \ 118 | --gpu_id 0 \ 119 | --model_name URWKV \ 120 | --dataset_name SDSD_indoor 121 | } 122 | 123 | # Define function for testing SDSD_outdoor 124 | test_SDSD_outdoor() { 125 | python ./tools/test.py \ 126 | --gpu_id 0 \ 127 | --model_name URWKV \ 128 | --testSet_path '/data/xr/Dataset/light_dataset/SDSD_outdoor_png/eval/' \ 129 | --weight_path './checkpoints/SDSD_outdoor/URWKV/models/model_bestPSNR.pth' \ 130 | --save_path './results/SDSD_outdoor' 131 | 132 | python ./tools/measure.py \ 133 | --gpu_id 0 \ 134 | --model_name URWKV \ 135 | --dataset_name SDSD_outdoor 136 | } 137 | 138 | # Define function for testing LOL-blur 139 | test_LOL_blur() { 140 | python ./tools/test.py \ 141 | --gpu_id 0 \ 142 | --model_name URWKV \ 143 | --testSet_path '/data/xr/Dataset/light_dataset/LOL_blur/eval' \ 144 | --weight_path './checkpoints/LOL_blur/URWKV/models/model_bestSSIM.pth' \ 145 | --save_path './results/LOL_blur' 146 | 147 | python ./tools/measure.py \ 148 | --gpu_id 0 \ 149 | --model_name URWKV \ 150 | --dataset_name LOL_blur 151 | } 152 | 153 | # Check if argument is provided 154 | if [ $# -ne 1 ]; then 155 | print_usage 156 | exit 1 157 | fi 158 | 159 | # Parse command line argument 160 | case $1 in 161 | 1) 162 | test_LOL_v1 163 | ;; 164 | 2) 165 | test_LOL_v2_real 166 | ;; 167 | 3) 168 | test_LOL_v2_sync 169 | ;; 170 | 4) 171 | test_MIT_5K 172 | ;; 173 | 5) 174 | test_SID 175 | ;; 176 | 6) 177 | test_SMID 178 | ;; 179 | 7) 180 | test_SDSD_indoor 181 | ;; 182 | 8) 183 | test_SDSD_outdoor 184 | ;; 185 | 9) 186 | test_LOL_blur 187 | ;; 188 | *) 189 | echo "Invalid dataset number. Please provide a number between 1 to 4." 190 | print_usage 191 | exit 1 192 | ;; 193 | esac -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # URWKV: Unified RWKV Model with Multi-state Perspective for Low-light Image Restoration 2 | 3 | 📢 This paper has been accepted to CVPR 2025! 
🎉 4 | 5 | [main paper](https://openaccess.thecvf.com/content/CVPR2025/papers/Xu_URWKV_Unified_RWKV_Model_with_Multi-state_Perspective_for_Low-light_Image_CVPR_2025_paper.pdf) | [supplementary materials](https://openaccess.thecvf.com/content/CVPR2025/supplemental/Xu_URWKV_Unified_RWKV_CVPR_2025_supplemental.pdf) | [poster](https://pan.baidu.com/s/18Z84hr2_HlXGzy1XXcZMIw?pwd=56u9) 6 | 7 | **TODO:** 8 | 9 | * [x] Release the official implementation of URWKV, including training and inference scripts. This is a relatively rough version, so you may need some time to configure the environment and paths. 10 | 11 | * [x] Release pre-trained weights for reproducibility. 12 | 13 | * [ ] Add visual comparisons with SOTA methods across various benchmark datasets. 14 | 15 | * [ ] Refactor and document code for clarity and reproducibility. 16 | 17 | **Notes and Links:** 18 | 19 | * **Results:** Visual results of URWKV can be downloaded from [here](https://pan.baidu.com/s/1EiuCvuj_Ycw0YEDpzhFLJg?pwd=kn23). 20 | 21 | * **Pre-trained weights:** The weights for SMID and MIT-5K may have been overwritten. You can either train them yourself or wait for us to re-train and upload them later. Pre-trained weights for other datasets can be downloaded from [here](https://pan.baidu.com/s/1UuKmG6WcaCWdwkj3_jsPPg?pwd=5ady). 22 | 23 | * **Datasets:** All datasets used in this work can be downloaded from [here](https://pan.baidu.com/s/1R0L4QEXw0uOyWyVp1x6Zig?pwd=2x5i). 24 | 25 | * **Hyperparameter tuning:** Since we haven't done much hyperparameter tuning, you are encouraged to explore better configurations to potentially improve the model's performance. 26 | 27 | ## Abstract 28 | 29 | Existing low-light image enhancement (LLIE) and joint LLIE and deblurring (LLIE-deblur) models have made strides in addressing predefined degradations, yet they are often constrained by dynamically coupled degradations. To address these challenges, we introduce a Unified Receptance Weighted Key Value (URWKV) model with multi-state perspective, enabling flexible and effective degradation restoration for low-light images. Specifically, we customize the core URWKV block to perceive and analyze complex degradations by leveraging multiple intra- and inter-stage states. First, inspired by the pupil mechanism in the human visual system, we propose Luminance-adaptive Normalization (LAN) that adjusts normalization parameters based on rich inter-stage states, allowing for adaptive, scene-aware luminance modulation. Second, we aggregate multiple intra-stage states through exponential moving average approach, effectively capturing subtle variations while mitigating information loss inherent in the single-state mechanism. To reduce the degradation effects commonly associated with conventional skip connections, we propose the State-aware Selective Fusion (SSF) module, which dynamically aligns and integrates multi-state features across encoder stages, selectively fusing contextual information. In comparison to state-of-the-art models, our URWKV model achieves superior performance on various benchmarks, while requiring significantly fewer parameters and computational resources. 30 | 31 | ## Overview 32 | 33 | ![](README_md_files/6cf966f0-5190-11f0-847b-8bd8db6e5334.jpeg?v=1&type=image) 34 | 35 | ## Main Results 36 | 37 | Consistent with [BiFormer](https://github.com/FZU-N/BiFormer), results are measured using `measure_pair.py`. It should be noted that all metrics in our method are computed in the sRGB space, and no GT Mean-related techniques are applied. 
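To make this protocol concrete, the snippet below sketches the distinction in plain NumPy (the function name `psnr_srgb` and the `use_gt_mean` flag are illustrative only and are not part of this repository; the reported numbers are produced with `measure_pair.py` / `tools/measure.py`):

```python
import numpy as np

def psnr_srgb(pred, gt, use_gt_mean=False):
    """Illustrative PSNR between two sRGB images in [0, 1] with shape (H, W, 3)."""
    if use_gt_mean:
        # GT-Mean variant (not applied in our evaluation): rescale the prediction
        # so its mean brightness matches the ground truth before measuring.
        pred = np.clip(pred * (gt.mean() / (pred.mean() + 1e-8)), 0.0, 1.0)
    mse = np.mean((pred - gt) ** 2)
    return 10.0 * np.log10(1.0 / (mse + 1e-12))
```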
38 | 39 | ![](README_md_files/9e45a430-5190-11f0-847b-8bd8db6e5334.jpeg?v=1&type=image) 40 | 41 | ![](README_md_files/e4f9c500-5190-11f0-847b-8bd8db6e5334.jpeg?v=1&type=image) 42 | 43 | To ensure fairness, if a comparison method does not provide pretrained weights, we retrain it using the recommended settings provided by the authors. Otherwise, we use the officially released pretrained weights for evaluation. All results are evaluated using a unified script, `measure_pair.py`. In this paper, the following methods were retrained: SNR-Net, FourLLIE, UHDFour, LLFormer, Retinexformer, BiFormer, RetinexMamba, LEDNet, PDHAT, MIRNet, Restormer, and MambaIR. The corresponding visual comparison results will be released later. 44 | 45 | ## Environment Setup 46 | 47 | ### 1. Create and Activate a Conda Environment 48 | 49 | ```markup 50 | conda create --name URWKV python=3.9 51 | conda activate URWKV 52 | ``` 53 | 54 | ### 2. Install PyTorch and Dependencies 55 | 56 | Install PyTorch with CUDA 11.3 support: 57 | 58 | ```markup 59 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch 60 | ``` 61 | 62 | Install cuDNN: 63 | 64 | ```markup 65 | conda install cudnn 66 | ``` 67 | 68 | ### 3. Install Python Dependencies 69 | 70 | ```markup 71 | pip install pyyaml yacs tqdm colorama pandas natsort 72 | pip install matplotlib tensorboardX 73 | pip install cython thop prefetch_generator 74 | pip install opencv-python scikit-image 75 | pip install timm einops mmcls 76 | pip install pytorch_msssim IQA_pytorch pyiqa lpips 77 | pip install numpy==1.26.4 78 | ``` 79 | 80 | ### 4. Install MMCV 81 | 82 | ```markup 83 | pip install -U openmim 84 | mim install mmcv==1.7.1 85 | ``` 86 | 87 | > ⚠️ Note: If you encounter an error related to Ninja while compiling C++ extensions (e.g., Ninja is required to load C++ extensions), install Ninja with: `sudo apt-get install ninja-build` 88 | 89 | ## Citation 90 | 91 | If you find this work useful for your research, please cite: 92 | 93 | ```markup 94 | @inproceedings{xu2025urwkv, 95 | title={URWKV: Unified RWKV Model with Multi-state Perspective for Low-light Image Restoration}, 96 | author={Xu, Rui and Niu, Yuzhen and Li, Yuezhou and Xu, Huangbiao and Liu, Wenxi and Chen, Yuzhong}, 97 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 98 | pages={21267--21276}, 99 | year={2025} 100 | } 101 | ``` 102 | 103 | ## License 104 | 105 | This repository is released under the Apache 2.0 license as found in the [LICENSE](https://github.com/FZU-N/URWKV/blob/main/LICENSE) file. 106 | 107 | ## Acknowledgement 108 | 109 | URWKV is built with reference to the code of the following projects: [RWKV](https://github.com/BlinkDL/RWKV-LM), [Vision-RWKV](https://github.com/OpenGVLab/Vision-RWKV), and [BiFormer](https://github.com/FZU-N/BiFormer). Thanks for their awesome work! 
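> 💡 Quick start: based on the argument parsing in the bundled `train.sh` and `test.sh`, training URWKV on LOL_v1 with GPU 0 can be launched with `bash train.sh 0 LOL_v1`, and evaluation with `bash test.sh 1` (the number selects the dataset; running either script without arguments prints the full usage). The dataset and checkpoint paths hard-coded in the scripts need to be adapted to your local setup first.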
110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /custom_utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | from pytorch_msssim import ssim 5 | import math 6 | 7 | def calculate_psnr(img1, img2, border=0): 8 | # img1 and img2 have range [0, 255] 9 | #img1 = img1.squeeze() 10 | #img2 = img2.squeeze() 11 | if not img1.shape == img2.shape: 12 | raise ValueError('Input images must have the same dimensions.') 13 | h, w = img1.shape[:2] 14 | img1 = img1[border:h-border, border:w-border] 15 | img2 = img2[border:h-border, border:w-border] 16 | 17 | img1 = img1.astype(np.float64) 18 | img2 = img2.astype(np.float64) 19 | mse = np.mean((img1 - img2)**2) 20 | if mse == 0: 21 | return float('inf') 22 | return 20 * math.log10(255.0 / math.sqrt(mse)) 23 | 24 | 25 | # -------------------------------------------- 26 | # SSIM 27 | # -------------------------------------------- 28 | def calculate_ssim(img1, img2, border=0): 29 | '''calculate SSIM 30 | the same outputs as MATLAB's 31 | img1, img2: [0, 255] 32 | ''' 33 | #img1 = img1.squeeze() 34 | #img2 = img2.squeeze() 35 | if not img1.shape == img2.shape: 36 | raise ValueError('Input images must have the same dimensions.') 37 | h, w = img1.shape[:2] 38 | img1 = img1[border:h-border, border:w-border] 39 | img2 = img2[border:h-border, border:w-border] 40 | 41 | if img1.ndim == 2: 42 | return ssim(img1, img2) 43 | elif img1.ndim == 3: 44 | if img1.shape[2] == 3: 45 | ssims = [] 46 | for i in range(3): 47 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 48 | return np.array(ssims).mean() 49 | elif img1.shape[2] == 1: 50 | return ssim(np.squeeze(img1), np.squeeze(img2)) 51 | else: 52 | raise ValueError('Wrong input image dimensions.') 53 | 54 | def load_img(filepath): 55 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 56 | 57 | def torchPSNR(tar_img, prd_img): 58 | imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1) 59 | rmse = (imdff**2).mean().sqrt() 60 | ps = 20*torch.log10(1/rmse) 61 | return ps 62 | 63 | def batch_PSNR(img1, img2, data_range=None): 64 | PSNR = [] 65 | for im1, im2 in zip(img1, img2): 66 | psnr = torchPSNR(im1, im2) 67 | PSNR.append(psnr) 68 | return sum(PSNR)/len(PSNR) 69 | 70 | def torchSSIM(tar_img, prd_img): 71 | return ssim(tar_img, prd_img, data_range=1.0, size_average=True) 72 | 73 | def save_img(filepath, img): 74 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 75 | 76 | def numpyPSNR(tar_img, prd_img): 77 | imdff = np.float32(prd_img) - np.float32(tar_img) 78 | rmse = np.sqrt(np.mean(imdff**2)) 79 | ps = 20*np.log10(255/rmse) 80 | return ps 81 | import torch 82 | 83 | """ 84 | The following reference from: 85 | https://github.com/oblime/RGB_HSV_HSL 86 | """ 87 | def rgb2hsl_torch(rgb: torch.Tensor) -> torch.Tensor: 88 | cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True) 89 | cmin = torch.min(rgb, dim=1, keepdim=True)[0] 90 | delta = cmax - cmin 91 | hsl_h = torch.empty_like(rgb[:, 0:1, :, :]) 92 | cmax_idx[delta == 0] = 3 93 | hsl_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0] 94 | hsl_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1] 95 | hsl_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2] 96 | hsl_h[cmax_idx == 3] = 0. 97 | hsl_h /= 6. 98 | 99 | hsl_l = (cmax + cmin) / 2. 
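    # Lightness is the mean of the max and min channels; the saturation computed
    # below follows the standard HSL definition: S = (cmax - cmin) / (2 * L) for
    # L <= 0.5 and S = (cmax - cmin) / (2 - 2 * L) for L > 0.5, with S = 0 at L = 0 or 1.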
100 | hsl_s = torch.empty_like(hsl_h) 101 | hsl_s[hsl_l == 0] = 0 102 | hsl_s[hsl_l == 1] = 0 103 | hsl_l_ma = torch.bitwise_and(hsl_l > 0, hsl_l < 1) 104 | hsl_l_s0_5 = torch.bitwise_and(hsl_l_ma, hsl_l <= 0.5) 105 | hsl_l_l0_5 = torch.bitwise_and(hsl_l_ma, hsl_l > 0.5) 106 | hsl_s[hsl_l_s0_5] = ((cmax - cmin) / (hsl_l * 2.))[hsl_l_s0_5] 107 | hsl_s[hsl_l_l0_5] = ((cmax - cmin) / (- hsl_l * 2. + 2.))[hsl_l_l0_5] 108 | return torch.cat([hsl_h, hsl_s, hsl_l], dim=1) 109 | 110 | 111 | def rgb2hsv_torch(rgb: torch.Tensor) -> torch.Tensor: 112 | cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True) 113 | cmin = torch.min(rgb, dim=1, keepdim=True)[0] 114 | delta = cmax - cmin 115 | hsv_h = torch.empty_like(rgb[:, 0:1, :, :]) 116 | cmax_idx[delta == 0] = 3 117 | hsv_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0] 118 | hsv_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1] 119 | hsv_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2] 120 | hsv_h[cmax_idx == 3] = 0. 121 | hsv_h /= 6. 122 | hsv_s = torch.where(cmax == 0, torch.tensor(0.).type_as(rgb), delta / cmax) 123 | hsv_v = cmax 124 | return torch.cat([hsv_h, hsv_s, hsv_v], dim=1) 125 | 126 | 127 | def hsv2rgb_torch(hsv: torch.Tensor) -> torch.Tensor: 128 | hsv_h, hsv_s, hsv_l = hsv[:, 0:1], hsv[:, 1:2], hsv[:, 2:3] 129 | _c = hsv_l * hsv_s 130 | _x = _c * (- torch.abs(hsv_h * 6. % 2. - 1) + 1.) 131 | _m = hsv_l - _c 132 | _o = torch.zeros_like(_c) 133 | idx = (hsv_h * 6.).type(torch.uint8) 134 | idx = (idx % 6).expand(-1, 3, -1, -1) 135 | rgb = torch.empty_like(hsv) 136 | rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0] 137 | rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1] 138 | rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2] 139 | rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3] 140 | rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4] 141 | rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5] 142 | rgb += _m 143 | return rgb 144 | 145 | 146 | def hsl2rgb_torch(hsl: torch.Tensor) -> torch.Tensor: 147 | hsl_h, hsl_s, hsl_l = hsl[:, 0:1], hsl[:, 1:2], hsl[:, 2:3] 148 | _c = (-torch.abs(hsl_l * 2. - 1.) + 1) * hsl_s 149 | _x = _c * (-torch.abs(hsl_h * 6. % 2. - 1) + 1.) 150 | _m = hsl_l - _c / 2. 
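    # Standard HSL -> RGB reconstruction: _c is the chroma, _x the second-largest
    # color component, and _m the offset added to every channel so the result
    # matches the requested lightness; idx below selects one of the six hue sectors.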
151 | idx = (hsl_h * 6.).type(torch.uint8) 152 | idx = (idx % 6).expand(-1, 3, -1, -1) 153 | rgb = torch.empty_like(hsl) 154 | _o = torch.zeros_like(_c) 155 | rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0] 156 | rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1] 157 | rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2] 158 | rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3] 159 | rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4] 160 | rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5] 161 | rgb += _m 162 | return rgb -------------------------------------------------------------------------------- /custom_utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from typing import Tuple 6 | 7 | def gaussian(window_size, sigma): 8 | def gauss_fcn(x): 9 | return -(x - window_size // 2)**2 / float(2 * sigma**2) 10 | gauss = torch.stack( 11 | [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)]) 12 | return gauss / gauss.sum() 13 | 14 | def get_gaussian_kernel(ksize: int, sigma: float) -> torch.Tensor: 15 | if not isinstance(ksize, int) or ksize % 2 == 0 or ksize <= 0: 16 | raise TypeError("ksize must be an odd positive integer. Got {}" 17 | .format(ksize)) 18 | window_1d: torch.Tensor = gaussian(ksize, sigma) 19 | return window_1d 20 | 21 | def get_gaussian_kernel2d(ksize: Tuple[int, int], 22 | sigma: Tuple[float, float]) -> torch.Tensor: 23 | if not isinstance(ksize, tuple) or len(ksize) != 2: 24 | raise TypeError("ksize must be a tuple of length two. Got {}" 25 | .format(ksize)) 26 | if not isinstance(sigma, tuple) or len(sigma) != 2: 27 | raise TypeError("sigma must be a tuple of length two. Got {}" 28 | .format(sigma)) 29 | ksize_x, ksize_y = ksize 30 | sigma_x, sigma_y = sigma 31 | kernel_x: torch.Tensor = get_gaussian_kernel(ksize_x, sigma_x) 32 | kernel_y: torch.Tensor = get_gaussian_kernel(ksize_y, sigma_y) 33 | kernel_2d: torch.Tensor = torch.matmul( 34 | kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) 35 | return kernel_2d 36 | 37 | 38 | class PSNRLoss(nn.Module): 39 | 40 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 41 | super(PSNRLoss, self).__init__() 42 | assert reduction == 'mean' 43 | self.loss_weight = loss_weight 44 | self.scale = 10 / np.log(10) 45 | self.toY = toY 46 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 47 | self.first = True 48 | 49 | def forward(self, pred, target): 50 | assert len(pred.size()) == 4 51 | if self.toY: 52 | if self.first: 53 | self.coef = self.coef.to(pred.device) 54 | self.first = False 55 | 56 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 57 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 58 | 59 | pred, target = pred / 255., target / 255. 
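        # Note: (65.481, 128.553, 24.966) with the +16 offset is the ITU-R BT.601
        # RGB -> Y conversion, so with toY=True the PSNR is measured on the luma
        # (Y) channel only.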
60 | pass 61 | assert len(pred.size()) == 4 62 | loss = -(self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()) 63 | return loss 64 | 65 | 66 | class SSIMLoss(nn.Module): 67 | def __init__(self, window_size: int = 11, reduction: str = 'mean', max_val: float = 1.0) -> None: 68 | super(SSIMLoss, self).__init__() 69 | self.window_size: int = window_size 70 | self.max_val: float = max_val 71 | self.reduction: str = reduction 72 | 73 | self.window: torch.Tensor = get_gaussian_kernel2d( 74 | (window_size, window_size), (1.5, 1.5)) 75 | self.padding: int = self.compute_zero_padding(window_size) 76 | 77 | self.C1: float = (0.01 * self.max_val) ** 2 78 | self.C2: float = (0.03 * self.max_val) ** 2 79 | 80 | @staticmethod 81 | def compute_zero_padding(kernel_size: int) -> int: 82 | """Computes zero padding.""" 83 | return (kernel_size - 1) // 2 84 | 85 | def filter2D( 86 | self, 87 | input: torch.Tensor, 88 | kernel: torch.Tensor, 89 | channel: int) -> torch.Tensor: 90 | return F.conv2d(input, kernel, padding=self.padding, groups=channel) 91 | 92 | def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: 93 | # prepare kernel 94 | b, c, h, w = img1.shape 95 | tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype) 96 | kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1) 97 | 98 | # compute local mean per channel 99 | mu1: torch.Tensor = self.filter2D(img1, kernel, c) 100 | mu2: torch.Tensor = self.filter2D(img2, kernel, c) 101 | 102 | mu1_sq = mu1.pow(2) 103 | mu2_sq = mu2.pow(2) 104 | mu1_mu2 = mu1 * mu2 105 | 106 | # compute local sigma per channel 107 | sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq 108 | sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq 109 | sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2 110 | 111 | ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \ 112 | ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)) 113 | 114 | loss = torch.clamp(1. - ssim_map, min=0, max=1) / 2. 
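        # ssim_map lies in [-1, 1]; clamp(1 - ssim_map, 0, 1) / 2 gives a per-pixel
        # loss in [0, 0.5] that is zero for identical images, before the chosen
        # reduction is applied below.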
115 | 116 | if self.reduction == 'mean': 117 | loss = torch.mean(loss) 118 | elif self.reduction == 'sum': 119 | loss = torch.sum(loss) 120 | elif self.reduction == 'none': 121 | pass 122 | return loss 123 | # ------------------------------------------------------------------------------ 124 | 125 | class CharbonnierLoss(nn.Module): 126 | """Charbonnier Loss (L1)""" 127 | 128 | def __init__(self, eps=1e-3): 129 | super(CharbonnierLoss, self).__init__() 130 | self.eps = eps 131 | 132 | def forward(self, x, y): 133 | diff = x - y 134 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 135 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 136 | return loss 137 | 138 | class EdgeLoss(nn.Module): 139 | def __init__(self): 140 | super(EdgeLoss, self).__init__() 141 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 142 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 143 | if torch.cuda.is_available(): 144 | self.kernel = self.kernel.cuda() 145 | self.loss = CharbonnierLoss() 146 | 147 | def conv_gauss(self, img): 148 | n_channels, _, kw, kh = self.kernel.shape 149 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 150 | return F.conv2d(img, self.kernel, groups=n_channels) 151 | 152 | def laplacian_kernel(self, current): 153 | filtered = self.conv_gauss(current) # filter 154 | down = filtered[:,:,::2,::2] # downsample 155 | new_filter = torch.zeros_like(filtered) 156 | new_filter[:, :, ::2, ::2] = down*4 # upsample 157 | filtered = self.conv_gauss(new_filter) # filter 158 | diff = current - filtered 159 | return diff 160 | 161 | def forward(self, x, y): 162 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) 163 | return loss 164 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Print usage instructions 4 | print_usage() { 5 | echo "Usage: $0 <gpu_id> <dataset> [<model_name>] [<block_num>] [<recursive_num>] [<batch_size>] [<epochs>] [<lr_init>] [<lr_min>]" 6 | echo " dataset options:" 7 | echo " LOL_v1 -> Use LOL_v1 dataset for training" 8 | echo " LOL_v2_real -> Use LOL_v2_real dataset for training" 9 | echo " LOL_v2_sync -> Use LOL_v2_sync dataset for training" 10 | echo " MIT_5K -> Use MIT_5K dataset for training" 11 | echo " SID -> Use SID dataset for training" 12 | echo " SMID -> Use SMID dataset for training" 13 | echo " SDSD_indoor -> Use SDSD_indoor dataset for training" 14 | echo " SDSD_outdoor -> Use SDSD_outdoor dataset for training" 15 | echo " LOL_blur -> Use LOL_blur dataset for training" 16 | echo "Optional arguments (default values are used if not provided):" 17 | echo " model_name: Model name (default: URWKV)" 18 | echo " block_num: Number of blocks (default: 2)" 19 | echo " recursive_num: Number of recursive blocks (default: 3)" 20 | echo " batch_size: Batch size (default: 8)" 21 | echo " epochs: Number of epochs (default: 1000)" 22 | echo " lr_init: Initial learning rate (default: 0.0002)" 23 | echo " lr_min: Minimum learning rate (default: 1e-6)" 24 | } 25 | 26 | # Define function for training LOL-v1 27 | train_LOL_v1() { 28 | python ./tools/train.py \ 29 | --gpu_id $gpu_id \ 30 | --model_name $model_name \ 31 | --yml_path './configs/LOL_v1.yaml' \ 32 | --pretrain_weights '' \ 33 | --batch_size $batch_size \ 34 | --epochs $epochs \ 35 | --lr_init $lr_init \ 36 | --lr_min $lr_min \ 37 | --channel $channel 38 | 39 | # train_MIT_5K 40 | } 41 | 42 | # Define function for training LOL-v2-real 43 | train_LOL_v2_real() { 44 | python ./tools/train.py \ 45
| --gpu_id $gpu_id \ 46 | --model_name $model_name \ 47 | --yml_path './configs/LOL_v2_real.yaml' \ 48 | --pretrain_weights '' \ 49 | --batch_size $batch_size \ 50 | --epochs $epochs \ 51 | --lr_init $lr_init \ 52 | --lr_min $lr_min \ 53 | --channel $channel 54 | # train_LOL_v2_sync 55 | } 56 | 57 | # Define function for training LOL-v2-sync 58 | train_LOL_v2_sync() { 59 | python ./tools/train.py \ 60 | --gpu_id $gpu_id \ 61 | --model_name $model_name \ 62 | --yml_path './configs/LOL_v2_synthetic.yaml' \ 63 | --pretrain_weights '' \ 64 | --batch_size $batch_size \ 65 | --epochs $epochs \ 66 | --lr_init $lr_init \ 67 | --lr_min $lr_min \ 68 | --channel $channel 69 | } 70 | 71 | # Define function for training MIT_5K 72 | train_MIT_5K() { 73 | python ./tools/train_MIT5K.py \ 74 | --gpu_id $gpu_id \ 75 | --model_name $model_name \ 76 | --yml_path './configs/FiveK.yaml' \ 77 | --pretrain_weights '' \ 78 | --batch_size $batch_size \ 79 | --epochs $epochs \ 80 | --lr_init $lr_init \ 81 | --lr_min $lr_min \ 82 | --channel $channel 83 | # train_SMID 84 | } 85 | 86 | # Define function for training SID 87 | train_SID() { 88 | python ./tools/train.py \ 89 | --gpu_id $gpu_id \ 90 | --model_name $model_name \ 91 | --yml_path './configs/SID.yaml' \ 92 | --pretrain_weights '' \ 93 | --batch_size $batch_size \ 94 | --epochs $epochs \ 95 | --lr_init $lr_init \ 96 | --lr_min $lr_min \ 97 | --channel $channel 98 | # train_SDSD_outdoor 99 | } 100 | 101 | # Define function for training SMID 102 | train_SMID() { 103 | python ./tools/train.py \ 104 | --gpu_id $gpu_id \ 105 | --model_name $model_name \ 106 | --yml_path './configs/SMID.yaml' \ 107 | --pretrain_weights '' \ 108 | --batch_size $batch_size \ 109 | --epochs $epochs \ 110 | --lr_init $lr_init \ 111 | --lr_min $lr_min \ 112 | --channel $channel 113 | } 114 | 115 | # Define function for training SDSD_indoor 116 | train_SDSD_indoor() { 117 | python ./tools/train.py \ 118 | --gpu_id $gpu_id \ 119 | --model_name $model_name \ 120 | --yml_path './configs/SDSD_indoor.yaml' \ 121 | --pretrain_weights '' \ 122 | --batch_size $batch_size \ 123 | --epochs $epochs \ 124 | --lr_init $lr_init \ 125 | --lr_min $lr_min \ 126 | --channel $channel 127 | } 128 | 129 | # Define function for training SDSD_outdoor 130 | train_SDSD_outdoor() { 131 | python ./tools/train.py \ 132 | --gpu_id $gpu_id \ 133 | --model_name $model_name \ 134 | --yml_path './configs/SDSD_outdoor.yaml' \ 135 | --pretrain_weights '' \ 136 | --batch_size $batch_size \ 137 | --epochs $epochs \ 138 | --lr_init $lr_init \ 139 | --lr_min $lr_min \ 140 | --channel $channel 141 | } 142 | 143 | # Define function for training LOL-blur 144 | train_LOL_blur() { 145 | python ./tools/train.py \ 146 | --gpu_id $gpu_id \ 147 | --model_name $model_name \ 148 | --yml_path './configs/LOL_blur.yaml' \ 149 | --pretrain_weights '' \ 150 | --batch_size $batch_size \ 151 | --epochs $epochs \ 152 | --lr_init $lr_init \ 153 | --lr_min $lr_min \ 154 | --channel $channel 155 | } 156 | 157 | 158 | 159 | # Check if at least two arguments are provided 160 | if [ $# -lt 2 ]; then 161 | print_usage 162 | exit 1 163 | fi 164 | 165 | # Parse command line arguments 166 | gpu_id=$1 167 | dataset=$2 168 | model_name=${3:-"URWKV"} # default value: "URWKV" 169 | batch_size=${4:-8} # default value: 8 170 | epochs=${5:-1000} # default value: 1000 171 | lr_init=${6:-0.0002} # default value: 0.0002 172 | lr_min=${7:-1e-6} # default value: 1e-6 173 | channel=${8:-32} # default value: 32 174 | 175 | # Execute corresponding function based on 
dataset selection 176 | case $dataset in 177 | "LOL_v1") 178 | train_LOL_v1 179 | ;; 180 | "LOL_v2_real") 181 | train_LOL_v2_real 182 | ;; 183 | "LOL_v2_sync") 184 | train_LOL_v2_sync 185 | ;; 186 | "MIT_5K") 187 | train_MIT_5K 188 | ;; 189 | "SID") 190 | train_SID 191 | ;; 192 | "SMID") 193 | train_SMID 194 | ;; 195 | "SDSD_indoor") 196 | train_SDSD_indoor 197 | ;; 198 | "SDSD_outdoor") 199 | train_SDSD_outdoor 200 | ;; 201 | "LOL_blur") 202 | train_LOL_blur 203 | ;; 204 | *) 205 | echo "Invalid dataset selection. Please choose one of the following: LOL_v1, LOL_v2_real, LOL_v2_sync, MIT_5K." 206 | print_usage 207 | exit 1 208 | ;; 209 | esac -------------------------------------------------------------------------------- /custom_utils/metrics/niqe.py: -------------------------------------------------------------------------------- 1 | import math 2 | from os.path import dirname, join 3 | 4 | import cv2 5 | import numpy as np 6 | import scipy 7 | import scipy.io 8 | import scipy.misc 9 | import scipy.ndimage 10 | import scipy.special 11 | from PIL import Image 12 | 13 | gamma_range = np.arange(0.2, 10, 0.001) 14 | a = scipy.special.gamma(2.0/gamma_range) 15 | a *= a 16 | b = scipy.special.gamma(1.0/gamma_range) 17 | c = scipy.special.gamma(3.0/gamma_range) 18 | prec_gammas = a/(b*c) 19 | 20 | 21 | def aggd_features(imdata): 22 | # flatten imdata 23 | imdata.shape = (len(imdata.flat),) 24 | imdata2 = imdata*imdata 25 | left_data = imdata2[imdata < 0] 26 | right_data = imdata2[imdata >= 0] 27 | left_mean_sqrt = 0 28 | right_mean_sqrt = 0 29 | if len(left_data) > 0: 30 | left_mean_sqrt = np.sqrt(np.average(left_data)) 31 | if len(right_data) > 0: 32 | right_mean_sqrt = np.sqrt(np.average(right_data)) 33 | 34 | if right_mean_sqrt != 0: 35 | gamma_hat = left_mean_sqrt/right_mean_sqrt 36 | else: 37 | gamma_hat = np.inf 38 | # solve r-hat norm 39 | 40 | imdata2_mean = np.mean(imdata2) 41 | if imdata2_mean != 0: 42 | r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2)) 43 | else: 44 | r_hat = np.inf 45 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * 46 | (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) 47 | 48 | # solve alpha by guessing values that minimize ro 49 | pos = np.argmin((prec_gammas - rhat_norm)**2) 50 | alpha = gamma_range[pos] 51 | 52 | gam1 = scipy.special.gamma(1.0/alpha) 53 | gam2 = scipy.special.gamma(2.0/alpha) 54 | gam3 = scipy.special.gamma(3.0/alpha) 55 | 56 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3) 57 | bl = aggdratio * left_mean_sqrt 58 | br = aggdratio * right_mean_sqrt 59 | 60 | # mean parameter 61 | N = (br - bl)*(gam2 / gam1) # *aggdratio 62 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) 63 | 64 | 65 | def ggd_features(imdata): 66 | nr_gam = 1/prec_gammas 67 | sigma_sq = np.var(imdata) 68 | E = np.mean(np.abs(imdata)) 69 | rho = sigma_sq/E**2 70 | pos = np.argmin(np.abs(nr_gam - rho)) 71 | return gamma_range[pos], sigma_sq 72 | 73 | 74 | def paired_product(new_im): 75 | shift1 = np.roll(new_im.copy(), 1, axis=1) 76 | shift2 = np.roll(new_im.copy(), 1, axis=0) 77 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) 78 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) 79 | 80 | H_img = shift1 * new_im 81 | V_img = shift2 * new_im 82 | D1_img = shift3 * new_im 83 | D2_img = shift4 * new_im 84 | 85 | return (H_img, V_img, D1_img, D2_img) 86 | 87 | 88 | def gen_gauss_window(lw, sigma): 89 | sd = np.float32(sigma) 90 | lw = int(lw) 91 | weights = [0.0] * (2 * lw + 1) 92 | weights[lw] = 1.0 93 | sum = 1.0 
94 | sd *= sd 95 | for ii in range(1, lw + 1): 96 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd) 97 | weights[lw + ii] = tmp 98 | weights[lw - ii] = tmp 99 | sum += 2.0 * tmp 100 | for ii in range(2 * lw + 1): 101 | weights[ii] /= sum 102 | return weights 103 | 104 | 105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'): 106 | if avg_window is None: 107 | avg_window = gen_gauss_window(3, 7.0/6.0) 108 | assert len(np.shape(image)) == 2 109 | h, w = np.shape(image) 110 | mu_image = np.zeros((h, w), dtype=np.float32) 111 | var_image = np.zeros((h, w), dtype=np.float32) 112 | image = np.array(image).astype('float32') 113 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode) 114 | scipy.ndimage.correlate1d(mu_image, avg_window, 1, 115 | mu_image, mode=extend_mode) 116 | scipy.ndimage.correlate1d(image**2, avg_window, 0, 117 | var_image, mode=extend_mode) 118 | scipy.ndimage.correlate1d(var_image, avg_window, 119 | 1, var_image, mode=extend_mode) 120 | var_image = np.sqrt(np.abs(var_image - mu_image**2)) 121 | return (image - mu_image)/(var_image + C), var_image, mu_image 122 | 123 | 124 | def _niqe_extract_subband_feats(mscncoefs): 125 | # alpha_m, = extract_ggd_features(mscncoefs) 126 | alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy()) 127 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs) 128 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1) 129 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2) 130 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3) 131 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4) 132 | return np.array([alpha_m, (bl+br)/2.0, 133 | alpha1, N1, bl1, br1, # (V) 134 | alpha2, N2, bl2, br2, # (H) 135 | alpha3, N3, bl3, bl3, # (D1) 136 | alpha4, N4, bl4, bl4, # (D2) 137 | ]) 138 | 139 | 140 | def get_patches_train_features(img, patch_size, stride=8): 141 | return _get_patches_generic(img, patch_size, 1, stride) 142 | 143 | 144 | def get_patches_test_features(img, patch_size, stride=8): 145 | return _get_patches_generic(img, patch_size, 0, stride) 146 | 147 | 148 | def extract_on_patches(img, patch_size): 149 | h, w = img.shape 150 | patch_size = np.int(patch_size) 151 | patches = [] 152 | for j in range(0, h-patch_size+1, patch_size): 153 | for i in range(0, w-patch_size+1, patch_size): 154 | patch = img[j:j+patch_size, i:i+patch_size] 155 | patches.append(patch) 156 | 157 | patches = np.array(patches) 158 | 159 | patch_features = [] 160 | for p in patches: 161 | patch_features.append(_niqe_extract_subband_feats(p)) 162 | patch_features = np.array(patch_features) 163 | 164 | return patch_features 165 | 166 | 167 | def _get_patches_generic(img, patch_size, is_train, stride): 168 | h, w = np.shape(img) 169 | if h < patch_size or w < patch_size: 170 | print("Input image is too small") 171 | exit(0) 172 | 173 | # ensure that the patch divides evenly into img 174 | hoffset = (h % patch_size) 175 | woffset = (w % patch_size) 176 | 177 | if hoffset > 0: 178 | img = img[:-hoffset, :] 179 | if woffset > 0: 180 | img = img[:, :-woffset] 181 | 182 | img = img.astype(np.float32) 183 | # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F') 184 | img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5) 185 | 186 | mscn1, var, mu = compute_image_mscn_transform(img) 187 | mscn1 = mscn1.astype(np.float32) 188 | 189 | mscn2, _, _ = compute_image_mscn_transform(img2) 190 | mscn2 = mscn2.astype(np.float32) 191 | 192 | feats_lvl1 = extract_on_patches(mscn1, patch_size) 193 | feats_lvl2 = 
extract_on_patches(mscn2, patch_size/2) 194 | 195 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3)) 196 | 197 | return feats 198 | 199 | 200 | def niqe(inputImgData): 201 | 202 | patch_size = 96 203 | module_path = dirname(__file__) 204 | 205 | # TODO: memoize 206 | params = scipy.io.loadmat( 207 | join(module_path, 'niqe_image_params.mat')) 208 | pop_mu = np.ravel(params["pop_mu"]) 209 | pop_cov = params["pop_cov"] 210 | 211 | if inputImgData.ndim == 3: 212 | inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY) 213 | M, N = inputImgData.shape 214 | 215 | # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,) 216 | assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 217 | assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 218 | 219 | feats = get_patches_test_features(inputImgData, patch_size) 220 | sample_mu = np.mean(feats, axis=0) 221 | sample_cov = np.cov(feats.T) 222 | 223 | X = sample_mu - pop_mu 224 | covmat = ((pop_cov+sample_cov)/2.0) 225 | pinvmat = scipy.linalg.pinv(covmat) 226 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) 227 | 228 | return niqe_score 229 | -------------------------------------------------------------------------------- /custom_utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import math 9 | 10 | import numpy as np 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | # from pylle.core.config import cfg 14 | 15 | def adjust_learning_rate(optimizer, epoch, lr_decay=0.5): 16 | 17 | # --- Decay learning rate --- # 18 | step = 20 19 | 20 | if not epoch % step and epoch > 0: 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] *= lr_decay 23 | print('Learning rate sets to {}.'.format(param_group['lr'])) 24 | else: 25 | for param_group in optimizer.param_groups: 26 | print('Learning rate sets to {}.'.format(param_group['lr'])) 27 | 28 | # 动态调整学习率 29 | class LRScheduler(_LRScheduler): 30 | def __init__(self, optimizer, last_epoch=-1): 31 | if 'lr_spaces' not in self.__dict__: 32 | raise Exception('lr_spaces must be set in "LRSchduler"') 33 | super(LRScheduler, self).__init__(optimizer, last_epoch) 34 | 35 | def get_cur_lr(self): 36 | return self.lr_spaces[self.last_epoch] 37 | 38 | def get_lr(self): 39 | epoch = self.last_epoch 40 | return [self.lr_spaces[epoch] * pg['initial_lr'] / self.start_lr 41 | for pg in self.optimizer.param_groups] 42 | 43 | def __repr__(self): 44 | return "({}) lr spaces: \n{}".format(self.__class__.__name__, 45 | self.lr_spaces) 46 | 47 | 48 | class LogScheduler(LRScheduler): 49 | def __init__(self, optimizer, start_lr=0.03, end_lr=5e-4, 50 | epochs=50, last_epoch=-1, **kwargs): 51 | self.start_lr = start_lr 52 | self.end_lr = end_lr 53 | self.epochs = epochs 54 | self.lr_spaces = np.logspace(math.log10(start_lr), 55 | math.log10(end_lr), 56 | epochs) 57 | 58 | super(LogScheduler, self).__init__(optimizer, last_epoch) 59 | 60 | 61 | class StepScheduler(LRScheduler): 62 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, 63 | step=10, mult=0.1, epochs=50, last_epoch=-1, 
**kwargs): 64 | if end_lr is not None: 65 | if start_lr is None: 66 | start_lr = end_lr / (mult ** (epochs // step)) 67 | else: # for warm up policy 68 | mult = math.pow(end_lr/start_lr, 1. / (epochs // step)) 69 | self.start_lr = start_lr 70 | self.lr_spaces = self.start_lr * (mult**(np.arange(epochs) // step)) 71 | self.mult = mult 72 | self._step = step 73 | 74 | super(StepScheduler, self).__init__(optimizer, last_epoch) 75 | 76 | 77 | class MultiStepScheduler(LRScheduler): 78 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, 79 | steps=[10, 20, 30, 40], mult=0.5, epochs=50, 80 | last_epoch=-1, **kwargs): 81 | if end_lr is not None: 82 | if start_lr is None: 83 | start_lr = end_lr / (mult ** (len(steps))) 84 | else: 85 | mult = math.pow(end_lr/start_lr, 1. / len(steps)) 86 | self.start_lr = start_lr 87 | self.lr_spaces = self._build_lr(start_lr, steps, mult, epochs) 88 | self.mult = mult 89 | self.steps = steps 90 | 91 | super(MultiStepScheduler, self).__init__(optimizer, last_epoch) 92 | 93 | def _build_lr(self, start_lr, steps, mult, epochs): 94 | lr = [0] * epochs 95 | lr[0] = start_lr 96 | for i in range(1, epochs): 97 | lr[i] = lr[i-1] 98 | if i in steps: 99 | lr[i] *= mult 100 | return np.array(lr, dtype=np.float32) 101 | 102 | 103 | class LinearStepScheduler(LRScheduler): 104 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, 105 | epochs=50, last_epoch=-1, **kwargs): 106 | self.start_lr = start_lr 107 | self.end_lr = end_lr 108 | self.lr_spaces = np.linspace(start_lr, end_lr, epochs) 109 | super(LinearStepScheduler, self).__init__(optimizer, last_epoch) 110 | 111 | 112 | class CosStepScheduler(LRScheduler): 113 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, 114 | epochs=50, last_epoch=-1, **kwargs): 115 | self.start_lr = start_lr 116 | self.end_lr = end_lr 117 | self.lr_spaces = self._build_lr(start_lr, end_lr, epochs) 118 | 119 | super(CosStepScheduler, self).__init__(optimizer, last_epoch) 120 | 121 | def _build_lr(self, start_lr, end_lr, epochs): 122 | index = np.arange(epochs).astype(np.float32) 123 | lr = end_lr + (start_lr - end_lr) * \ 124 | (1. 
+ np.cos(index * np.pi / epochs)) * 0.5 125 | return lr.astype(np.float32) 126 | 127 | 128 | class WarmUPScheduler(LRScheduler): 129 | def __init__(self, optimizer, warmup, normal, epochs=50, last_epoch=-1): 130 | warmup = warmup.lr_spaces # [::-1] 131 | normal = normal.lr_spaces 132 | self.lr_spaces = np.concatenate([warmup, normal]) 133 | self.start_lr = normal[0] 134 | 135 | super(WarmUPScheduler, self).__init__(optimizer, last_epoch) 136 | 137 | 138 | LRs = { 139 | 'log': LogScheduler, 140 | 'step': StepScheduler, 141 | 'multi-step': MultiStepScheduler, 142 | 'linear': LinearStepScheduler, 143 | 'cos': CosStepScheduler} 144 | 145 | 146 | def _build_lr_scheduler(optimizer, config, epochs=50, last_epoch=-1): 147 | return LRs[config.TYPE](optimizer, last_epoch=last_epoch, 148 | epochs=epochs, **config.KWARGS) 149 | 150 | 151 | # def _build_warm_up_scheduler(optimizer, epochs=50, last_epoch=-1): 152 | # warmup_epoch = cfg.TRAIN.LR_WARMUP.EPOCH 153 | # sc1 = _build_lr_scheduler(optimizer, cfg.TRAIN.LR_WARMUP, 154 | # warmup_epoch, last_epoch) 155 | # sc2 = _build_lr_scheduler(optimizer, cfg.TRAIN.LR, 156 | # epochs - warmup_epoch, last_epoch) 157 | # return WarmUPScheduler(optimizer, sc1, sc2, epochs, last_epoch) 158 | 159 | 160 | # def build_lr_scheduler(optimizer, epochs=50, last_epoch=-1): 161 | # if cfg.TRAIN.LR_WARMUP.WARMUP: 162 | # return _build_warm_up_scheduler(optimizer, epochs, last_epoch) 163 | # else: 164 | # return _build_lr_scheduler(optimizer, cfg.TRAIN.LR, 165 | # epochs, last_epoch) 166 | 167 | 168 | if __name__ == '__main__': 169 | import torch.nn as nn 170 | from torch.optim import SGD 171 | 172 | class Net(nn.Module): 173 | def __init__(self): 174 | super(Net, self).__init__() 175 | self.conv = nn.Conv2d(10, 10, kernel_size=3) 176 | net = Net().parameters() 177 | optimizer = SGD(net, lr=0.01) 178 | 179 | # test1 180 | step = { 181 | 'type': 'step', 182 | 'start_lr': 0.01, 183 | 'step': 10, 184 | 'mult': 0.1 185 | } 186 | lr = build_lr_scheduler(optimizer, step) 187 | print(lr) 188 | 189 | log = { 190 | 'type': 'log', 191 | 'start_lr': 0.03, 192 | 'end_lr': 5e-4, 193 | } 194 | lr = build_lr_scheduler(optimizer, log) 195 | 196 | print(lr) 197 | 198 | log = { 199 | 'type': 'multi-step', 200 | "start_lr": 0.01, 201 | "mult": 0.1, 202 | "steps": [10, 15, 20] 203 | } 204 | lr = build_lr_scheduler(optimizer, log) 205 | print(lr) 206 | 207 | cos = { 208 | "type": 'cos', 209 | 'start_lr': 0.01, 210 | 'end_lr': 0.0005, 211 | } 212 | lr = build_lr_scheduler(optimizer, cos) 213 | print(lr) 214 | 215 | step = { 216 | 'type': 'step', 217 | 'start_lr': 0.001, 218 | 'end_lr': 0.03, 219 | 'step': 1, 220 | } 221 | 222 | warmup = log.copy() 223 | warmup['warmup'] = step 224 | warmup['warmup']['epoch'] = 5 225 | lr = build_lr_scheduler(optimizer, warmup, epochs=55) 226 | print(lr) 227 | 228 | lr.step() 229 | print(lr.last_epoch) 230 | 231 | lr.step(5) 232 | print(lr.last_epoch) 233 | -------------------------------------------------------------------------------- /custom_utils/metrics/piqe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy.special import gamma 4 | 5 | 6 | def calculate_mscn(dis_image): 7 | dis_image = dis_image.astype(np.float32) # 类型转换十分重要 8 | ux = cv2.GaussianBlur(dis_image, (7, 7), 7/6) 9 | ux_sq = ux*ux 10 | sigma = np.sqrt(np.abs(cv2.GaussianBlur(dis_image**2, (7, 7), 7/6)-ux_sq)) 11 | 12 | mscn = (dis_image-ux)/(1+sigma) 13 | 14 | return mscn 15 | 16 | # Function to segment block 
edges 17 | 18 | 19 | def segmentEdge(blockEdge, nSegments, blockSize, windowSize): 20 | # Segment is defined as a collection of 6 contiguous pixels in a block edge 21 | segments = np.zeros((nSegments, windowSize)) 22 | for i in range(nSegments): 23 | segments[i, :] = blockEdge[i:windowSize] 24 | if(windowSize <= (blockSize+1)): 25 | windowSize = windowSize+1 26 | 27 | return segments 28 | 29 | 30 | def noticeDistCriterion(Block, nSegments, blockSize, windowSize, blockImpairedThreshold, N): 31 | # Top edge of block 32 | topEdge = Block[0, :] 33 | segTopEdge = segmentEdge(topEdge, nSegments, blockSize, windowSize) 34 | 35 | # Right side edge of block 36 | rightSideEdge = Block[:, N-1] 37 | rightSideEdge = np.transpose(rightSideEdge) 38 | segRightSideEdge = segmentEdge( 39 | rightSideEdge, nSegments, blockSize, windowSize) 40 | 41 | # Down side edge of block 42 | downSideEdge = Block[N-1, :] 43 | segDownSideEdge = segmentEdge( 44 | downSideEdge, nSegments, blockSize, windowSize) 45 | 46 | # Left side edge of block 47 | leftSideEdge = Block[:, 0] 48 | leftSideEdge = np.transpose(leftSideEdge) 49 | segLeftSideEdge = segmentEdge( 50 | leftSideEdge, nSegments, blockSize, windowSize) 51 | 52 | # Compute standard deviation of segments in left, right, top and down side edges of a block 53 | segTopEdge_stdDev = np.std(segTopEdge, axis=1) 54 | segRightSideEdge_stdDev = np.std(segRightSideEdge, axis=1) 55 | segDownSideEdge_stdDev = np.std(segDownSideEdge, axis=1) 56 | segLeftSideEdge_stdDev = np.std(segLeftSideEdge, axis=1) 57 | 58 | # Check for segment in block exhibits impairedness, if the standard deviation of the segment is less than blockImpairedThreshold. 59 | blockImpaired = 0 60 | for segIndex in range(segTopEdge.shape[0]): 61 | if((segTopEdge_stdDev[segIndex] < blockImpairedThreshold) or 62 | (segRightSideEdge_stdDev[segIndex] < blockImpairedThreshold) or 63 | (segDownSideEdge_stdDev[segIndex] < blockImpairedThreshold) or 64 | (segLeftSideEdge_stdDev[segIndex] < blockImpairedThreshold)): 65 | blockImpaired = 1 66 | break 67 | 68 | return blockImpaired 69 | 70 | 71 | def noiseCriterion(Block, blockSize, blockVar): 72 | # Compute block standard deviation[h,w,c]=size(I) 73 | blockSigma = np.sqrt(blockVar) 74 | # Compute ratio of center and surround standard deviation 75 | cenSurDev = centerSurDev(Block, blockSize) 76 | # Relation between center-surround deviation and the block standard deviation 77 | blockBeta = (abs(blockSigma-cenSurDev))/(max(blockSigma, cenSurDev)) 78 | 79 | return blockSigma, blockBeta 80 | 81 | # Function to compute center surround Deviation of a block 82 | 83 | 84 | def centerSurDev(Block, blockSize): 85 | # block center 86 | center1 = int((blockSize+1)/2)-1 87 | center2 = center1+1 88 | center = np.vstack((Block[:, center1], Block[:, center2])) 89 | # block surround 90 | Block = np.delete(Block, center1, axis=1) 91 | Block = np.delete(Block, center1, axis=1) 92 | 93 | # Compute standard deviation of block center and block surround 94 | center_std = np.std(center) 95 | surround_std = np.std(Block) 96 | 97 | # Ratio of center and surround standard deviation 98 | cenSurDev = (center_std/surround_std) 99 | 100 | # Check for nan's 101 | # if(isnan(cenSurDev)): 102 | # cenSurDev = 0 103 | 104 | return cenSurDev 105 | 106 | 107 | def piqe(im): 108 | blockSize = 16 # Considered 16x16 block size for overall analysis 109 | activityThreshold = 0.1 # Threshold used to identify high spatially prominent blocks 110 | blockImpairedThreshold = 0.1 # Threshold identify blocks having 
noticeable artifacts 111 | windowSize = 6 # Considered segment size in a block edge. 112 | nSegments = blockSize-windowSize+1 # Number of segments for each block edge 113 | distBlockScores = 0 # Accumulation of distorted block scores 114 | NHSA = 0 # Number of high spatial active blocks. 115 | 116 | # pad if size is not divisible by blockSize 117 | if len(im.shape) == 3: 118 | im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) 119 | originalSize = im.shape 120 | rows, columns = originalSize 121 | rowsPad = rows % blockSize 122 | columnsPad = columns % blockSize 123 | isPadded = False 124 | if(rowsPad > 0 or columnsPad > 0): 125 | if rowsPad > 0: 126 | rowsPad = blockSize-rowsPad 127 | if columnsPad > 0: 128 | columnsPad = blockSize-columnsPad 129 | isPadded = True 130 | padSize = [rowsPad, columnsPad] 131 | im = np.pad(im, ((0, rowsPad), (0, columnsPad)), 'edge') 132 | 133 | # Normalize image to zero mean and ~unit std 134 | # used circularly-symmetric Gaussian weighting function sampled out 135 | # to 3 standard deviations. 136 | imnorm = calculate_mscn(im) 137 | 138 | # Preallocation for masks 139 | NoticeableArtifactsMask = np.zeros(imnorm.shape) 140 | NoiseMask = np.zeros(imnorm.shape) 141 | ActivityMask = np.zeros(imnorm.shape) 142 | 143 | # Start of block by block processing 144 | total_var = [] 145 | total_bscore = [] 146 | total_ndc = [] 147 | total_nc = [] 148 | 149 | BlockScores = [] 150 | for i in np.arange(0, imnorm.shape[0]-1, blockSize): 151 | for j in np.arange(0, imnorm.shape[1]-1, blockSize): 152 | # Weights Initialization 153 | WNDC = 0 154 | WNC = 0 155 | 156 | # Compute block variance 157 | Block = imnorm[i:i+blockSize, j:j+blockSize] 158 | blockVar = np.var(Block) 159 | 160 | if(blockVar > activityThreshold): 161 | ActivityMask[i:i+blockSize, j:j+blockSize] = 1 162 | NHSA = NHSA+1 163 | 164 | # Analyze Block for noticeable artifacts 165 | blockImpaired = noticeDistCriterion( 166 | Block, nSegments, blockSize-1, windowSize, blockImpairedThreshold, blockSize) 167 | 168 | if(blockImpaired): 169 | WNDC = 1 170 | NoticeableArtifactsMask[i:i + 171 | blockSize, j:j+blockSize] = blockVar 172 | 173 | # Analyze Block for guassian noise distortions 174 | [blockSigma, blockBeta] = noiseCriterion( 175 | Block, blockSize-1, blockVar) 176 | 177 | if((blockSigma > 2*blockBeta)): 178 | WNC = 1 179 | NoiseMask[i:i+blockSize, j:j+blockSize] = blockVar 180 | 181 | # Pooling/ distortion assigment 182 | # distBlockScores = distBlockScores + \ 183 | # WNDC*pow(1-blockVar, 2) + WNC*pow(blockVar, 2) 184 | 185 | if WNDC*pow(1-blockVar, 2) + WNC*pow(blockVar, 2) > 0: 186 | BlockScores.append( 187 | WNDC*pow(1-blockVar, 2) + WNC*pow(blockVar, 2)) 188 | 189 | total_var = [total_var, blockVar] 190 | total_bscore = [total_bscore, WNDC * 191 | (1-blockVar) + WNC*(blockVar)] 192 | total_ndc = [total_ndc, WNDC] 193 | total_nc = [total_nc, WNC] 194 | 195 | BlockScores = sorted(BlockScores) 196 | lowSum = sum(BlockScores[:int(0.1*len(BlockScores))]) 197 | Sum = sum(BlockScores) 198 | Scores = [(s*10*lowSum)/Sum for s in BlockScores] 199 | C = 1 200 | Score = ((sum(Scores) + C)/(C + NHSA))*100 201 | 202 | # if input image is padded then remove those portions from ActivityMask, 203 | # NoticeableArtifactsMask and NoiseMask and ensure that size of these masks 204 | # are always M-by-N. 
205 | if(isPadded): 206 | NoticeableArtifactsMask = NoticeableArtifactsMask[0:originalSize[0], 207 | 0:originalSize[1]] 208 | NoiseMask = NoiseMask[0:originalSize[0], 0:originalSize[1]] 209 | ActivityMask = ActivityMask[0:originalSize[0], 1:originalSize[1]] 210 | 211 | return Score, NoticeableArtifactsMask, NoiseMask, ActivityMask 212 | 213 | 214 | -------------------------------------------------------------------------------- /custom_utils/data_loaders/lol.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | import torch 4 | from PIL import Image 5 | import torchvision.transforms.functional as TF 6 | import random 7 | import numpy as np 8 | from glob import glob 9 | 10 | random.seed(1143) 11 | 12 | #import torch.nn.functional as F 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 15 | 16 | def load_img(filepath): 17 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 18 | 19 | class PatchDataLoaderTrain(Dataset): 20 | def __init__(self, rgb_dir, img_options=None): 21 | super(PatchDataLoaderTrain, self).__init__() 22 | 23 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'low'))) 24 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'high'))) 25 | 26 | self.inp_filenames = [os.path.join(rgb_dir, 'low', x) for x in inp_files if is_image_file(x)] 27 | self.tar_filenames = [os.path.join(rgb_dir, 'high', x) for x in tar_files if is_image_file(x)] 28 | 29 | self.img_options = img_options 30 | self.sizex = len(self.tar_filenames) # get the size of target 31 | 32 | self.ps = self.img_options['patch_size'] 33 | 34 | def __len__(self): 35 | return self.sizex 36 | 37 | def __getitem__(self, index): 38 | index_ = index % self.sizex 39 | ps = self.ps 40 | 41 | inp_path = self.inp_filenames[index_] 42 | tar_path = self.tar_filenames[index_] 43 | 44 | inp_img = Image.open(inp_path).convert('RGB') 45 | tar_img = Image.open(tar_path).convert('RGB') 46 | 47 | w, h = tar_img.size 48 | padw = ps - w if w < ps else 0 49 | padh = ps - h if h < ps else 0 50 | 51 | # Reflect Pad in case image is smaller than patch_size 52 | if padw != 0 or padh != 0: 53 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect') 54 | tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect') 55 | 56 | inp_img = TF.to_tensor(inp_img) 57 | tar_img = TF.to_tensor(tar_img) 58 | 59 | hh, ww = tar_img.shape[1], tar_img.shape[2] 60 | 61 | rr = random.randint(0, hh - ps) 62 | cc = random.randint(0, ww - ps) 63 | aug = random.randint(0, 8) 64 | 65 | # Crop patch 66 | inp_img = inp_img[:, rr:rr + ps, cc:cc + ps] 67 | tar_img = tar_img[:, rr:rr + ps, cc:cc + ps] 68 | 69 | # Data Augmentations 70 | if aug == 1: 71 | inp_img = inp_img.flip(1) 72 | tar_img = tar_img.flip(1) 73 | elif aug == 2: 74 | inp_img = inp_img.flip(2) 75 | tar_img = tar_img.flip(2) 76 | elif aug == 3: 77 | inp_img = torch.rot90(inp_img, dims=(1, 2)) 78 | tar_img = torch.rot90(tar_img, dims=(1, 2)) 79 | elif aug == 4: 80 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=2) 81 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=2) 82 | elif aug == 5: 83 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=3) 84 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=3) 85 | elif aug == 6: 86 | inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2)) 87 | tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2)) 88 | elif aug == 7: 89 | inp_img = torch.rot90(inp_img.flip(2), 
dims=(1, 2)) 90 | tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2)) 91 | 92 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 93 | 94 | return inp_img, tar_img, filename 95 | 96 | 97 | 98 | 99 | class PatchDataLoaderVal(Dataset): 100 | def __init__(self, rgb_dir, img_options=None, rgb_dir2=None): 101 | super(PatchDataLoaderVal, self).__init__() 102 | 103 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'low'))) 104 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'high'))) 105 | 106 | self.inp_filenames = [os.path.join(rgb_dir, 'low', x) for x in inp_files if is_image_file(x)] 107 | self.tar_filenames = [os.path.join(rgb_dir, 'high', x) for x in tar_files if is_image_file(x)] 108 | 109 | self.img_options = img_options 110 | self.sizex = len(self.tar_filenames) # get the size of target 111 | self.mul = 16 112 | 113 | def __len__(self): 114 | return self.sizex 115 | 116 | def __getitem__(self, index): 117 | index_ = index % self.sizex 118 | 119 | inp_path = self.inp_filenames[index_] 120 | tar_path = self.tar_filenames[index_] 121 | 122 | inp_img = Image.open(inp_path).convert('RGB') 123 | tar_img = Image.open(tar_path).convert('RGB') 124 | #inp_img = TF.to_tensor(inp_img) 125 | #tar_img = TF.to_tensor(tar_img) 126 | w, h = inp_img.size 127 | #h, w = inp_img.shape[2], inp_img.shape[3] 128 | H, W = ((h + self.mul) // self.mul) * self.mul, ((w + self.mul) // self.mul) * self.mul 129 | padh = H - h if h % self.mul != 0 else 0 130 | padw = W - w if w % self.mul != 0 else 0 131 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect') 132 | inp_img = TF.to_tensor(inp_img) 133 | tar_img = TF.to_tensor(tar_img) 134 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 135 | 136 | return inp_img, tar_img, filename 137 | 138 | 139 | def populate_train_list(images_path, mode='train'): 140 | # print(images_path) 141 | image_list_lowlight = glob(images_path + '*.png') 142 | train_list = image_list_lowlight 143 | if mode == 'train': 144 | random.shuffle(train_list) 145 | 146 | return train_list 147 | 148 | class wholeDataLoader(Dataset): 149 | 150 | def __init__(self, images_path, mode='train'): 151 | images_path = images_path + '/low/' 152 | self.train_list = populate_train_list(images_path, mode) 153 | self.mode = mode 154 | self.data_list = self.train_list 155 | print("Total examples:", len(self.train_list)) 156 | 157 | 158 | def __getitem__(self, index): 159 | data_lowlight_path = self.data_list[index] 160 | #ps = 256 # Training Patch Size 161 | if self.mode == 'train': 162 | data_lowlight = Image.open(data_lowlight_path).convert('RGB') 163 | data_highlight = Image.open(data_lowlight_path.replace('low', 'high')).convert('RGB') 164 | w, h = data_lowlight.size 165 | data_lowlight = TF.to_tensor(data_lowlight) 166 | data_highlight = TF.to_tensor(data_highlight) 167 | hh, ww = data_highlight.shape[1], data_highlight.shape[2] 168 | 169 | # rr = random.randint(0, hh - ps) 170 | # cc = random.randint(0, ww - ps) 171 | aug = random.randint(0, 3) 172 | 173 | # Crop patch 174 | # data_lowlight = data_lowlight[:, rr:rr + ps, cc:cc + ps] 175 | # data_highlight = data_highlight[:, rr:rr + ps, cc:cc + ps] 176 | 177 | # Data Augmentations 178 | if aug == 1: 179 | data_lowlight = data_lowlight.flip(1) 180 | data_highlight = data_highlight.flip(1) 181 | elif aug == 2: 182 | data_lowlight = data_lowlight.flip(2) 183 | data_highlight = data_highlight.flip(2) 184 | # elif aug == 3: 185 | # data_lowlight = torch.rot90(data_lowlight, dims=(1, 2)) 186 | # data_highlight = 
torch.rot90(data_highlight, dims=(1, 2)) 187 | # elif aug == 4: 188 | # data_lowlight = torch.rot90(data_lowlight, dims=(1, 2), k=2) 189 | # data_highlight = torch.rot90(data_highlight, dims=(1, 2), k=2) 190 | # elif aug == 5: 191 | # data_lowlight = torch.rot90(data_lowlight, dims=(1, 2), k=3) 192 | # data_highlight = torch.rot90(data_highlight, dims=(1, 2), k=3) 193 | # elif aug == 6: 194 | # data_lowlight = torch.rot90(data_lowlight.flip(1), dims=(1, 2)) 195 | # data_highlight = torch.rot90(data_highlight.flip(1), dims=(1, 2)) 196 | # elif aug == 7: 197 | # data_lowlight = torch.rot90(data_lowlight.flip(2), dims=(1, 2)) 198 | # data_highlight = torch.rot90(data_highlight.flip(2), dims=(1, 2)) 199 | 200 | filename = os.path.splitext(os.path.split(data_lowlight_path)[-1])[0] 201 | 202 | return data_lowlight, data_highlight, filename 203 | 204 | elif self.mode == 'val': 205 | data_lowlight = Image.open(data_lowlight_path).convert('RGB') 206 | data_highlight = Image.open(data_lowlight_path.replace('low', 'high')).convert('RGB') 207 | # Validate on center crop 208 | 209 | data_lowlight = TF.to_tensor(data_lowlight) 210 | data_highlight = TF.to_tensor(data_highlight) 211 | 212 | filename = os.path.splitext(os.path.split(data_lowlight_path)[-1])[0] 213 | 214 | return data_lowlight, data_highlight, filename 215 | 216 | elif self.mode == 'test': 217 | data_lowlight = Image.open(data_lowlight_path).convert('RGB') 218 | data_highlight = Image.open(data_lowlight_path.replace('low', 'high')).convert('RGB') 219 | 220 | data_lowlight = TF.to_tensor(data_lowlight) 221 | data_highlight = TF.to_tensor(data_highlight) 222 | 223 | filename = os.path.splitext(os.path.split(data_lowlight_path)[-1])[0] 224 | #print(filename) 225 | return data_lowlight, data_highlight, filename 226 | 227 | def __len__(self): 228 | return len(self.data_list) -------------------------------------------------------------------------------- /tools/measure.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import time 4 | from collections import OrderedDict 5 | import sys 6 | sys.path.append('.') 7 | import numpy as np 8 | import torch 9 | import cv2 10 | import argparse 11 | import custom_utils 12 | 13 | from natsort import natsort 14 | from skimage.metrics import structural_similarity as ssim 15 | from skimage.metrics import peak_signal_noise_ratio as psnr 16 | import torchvision.transforms.functional as TF 17 | import torch.nn.functional as F 18 | from math import log10, sqrt 19 | import pyiqa 20 | from pytorch_msssim import ssim as ssim2 21 | from IQA_pytorch import SSIM 22 | from skimage.metrics import structural_similarity as compare_ssim 23 | import scipy.misc 24 | from scipy.ndimage import gaussian_filter 25 | import lpips 26 | 27 | import numpy as np 28 | import skimage 29 | from skimage import data, img_as_float 30 | from skimage.io import imread 31 | from skimage.transform import resize 32 | from skimage.color import rgb2gray 33 | from scipy.ndimage.filters import convolve 34 | 35 | 36 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 37 | 38 | class Measure(): 39 | def __init__(self, net='alex', use_gpu=False): 40 | self.device = 'cuda' if use_gpu else 'cpu' 41 | self.niqe_val = pyiqa.create_metric("niqe", device=torch.device('cuda')) 42 | self.model = lpips.LPIPS(net=net) 43 | self.model.to(self.device) 44 | self.ssim_val = SSIM() 45 | 46 | def measure(self, imgA, imgB): 47 | return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips, self.mae, 
self.niqe]] 48 | 49 | 50 | def psnr(self, imgA, imgB): 51 | psnr_val = psnr(imgA, imgB) 52 | 53 | return psnr_val 54 | 55 | def ssim(self, imgA, imgB, gray_scale=False): 56 | if gray_scale: 57 | score, diff = ssim(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True) 58 | # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged. 59 | else: 60 | score, diff = ssim(imgA, imgB, full=True, multichannel=True,channel_axis=2) 61 | return score 62 | 63 | def lpips(self, imgA, imgB, model=None): 64 | tA = t(imgA).to(self.device) 65 | tB = t(imgB).to(self.device) 66 | dist01 = self.model.forward(tA, tB).item() 67 | return dist01 68 | 69 | # 此计算方式同2023 AAAI_LLFormer_Ultra-High-Definition Low-Light Image Enhancement:A Benchmark and Transformer-Based Method 70 | def mae(self, imgA, imgB): 71 | 72 | imgA = TF.to_tensor(imgA).float() 73 | imgB = TF.to_tensor(imgB).float() 74 | 75 | # 计算差异 76 | diff = torch.abs(imgA - imgB) 77 | 78 | # 计算MAE 79 | mae = torch.mean(diff) 80 | 81 | return mae.item() 82 | 83 | 84 | def niqe(self, gt, enhance): 85 | gt = TF.to_tensor(gt).float() 86 | enhance = TF.to_tensor(enhance).float() 87 | gt, enhance = gt.unsqueeze(0), enhance.unsqueeze(0) 88 | score = self.niqe_val(enhance).item() 89 | return score 90 | 91 | 92 | # def niqe(self, gt, enhance): 93 | # img = enhance 94 | # img = img_as_float(img) 95 | 96 | # if img.ndim == 3: 97 | # img = rgb2gray(img) 98 | 99 | # img = resize(img, (384, 384), anti_aliasing=True) 100 | 101 | # mu = convolve(img, np.ones((7, 7)) / 49, mode='reflect') 102 | # mu_sq = mu * mu 103 | # sigma = np.sqrt(abs(convolve(img * img, np.ones((7, 7)) / 49, mode='reflect') - mu_sq)) 104 | # structdis = convolve(abs(img - mu), np.ones((7, 7)) / 49, mode='reflect') 105 | # niqe_score = np.mean(sigma / (structdis + 1e-12)) 106 | 107 | # return niqe_score 108 | 109 | 110 | def t(img): 111 | def to_4d(img): 112 | assert len(img.shape) == 3 113 | assert img.dtype == np.uint8 114 | img_new = np.expand_dims(img, axis=0) 115 | assert len(img_new.shape) == 4 116 | return img_new 117 | 118 | def to_CHW(img): 119 | return np.transpose(img, [2, 0, 1]) 120 | 121 | def to_tensor(img): 122 | return torch.Tensor(img) 123 | 124 | return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1 125 | 126 | 127 | def fiFindByWildcard(wildcard): 128 | return natsort.natsorted(glob.glob(wildcard, recursive=True)) 129 | 130 | 131 | def imread(path): 132 | return cv2.imread(path)[:, :, [2, 1, 0]] 133 | 134 | 135 | def format_result(psnr, ssim, lpips, mae, niqe): 136 | return f'{psnr:0.4f}, {ssim:0.4f}, {lpips:0.4f}, {mae:0.4f}, {niqe:0.4f}' 137 | 138 | def measure_dirs(dirA, dirB, dataset_name, txt_path, use_gpu, verbose=False): 139 | if verbose: 140 | vprint = lambda x: print(x) 141 | else: 142 | vprint = lambda x: None 143 | 144 | 145 | t_init = time.time() 146 | 147 | paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}')) 148 | paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}')) 149 | 150 | vprint("Comparing: ") 151 | vprint(dirA) 152 | vprint(dirB) 153 | 154 | measure = Measure(use_gpu=use_gpu) 155 | 156 | 157 | results = [] 158 | for pathA, pathB in zip(paths_A, paths_B): 159 | result = OrderedDict() 160 | 161 | t = time.time() 162 | high = imread(pathA) 163 | low = imread(pathB) 164 | img_name = os.path.basename(pathA) 165 | 166 | result['psnr'], result['ssim'], result['lpips'], result['mae'], result['niqe'] = 
measure.measure(high, low) 167 | # d = time.time() - t 168 | # vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}") 169 | 170 | d = time.time() - t 171 | # output_str = f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]} >>\t " 172 | output_str = f"{pathA.split('/')[-1]} >>\t " 173 | output_str += f"PSNR: {result['psnr']:.2f} \t SSIM: {result['ssim']:.3f} \t LPIPS: {result['lpips']:.3f} \t MAE: {result['mae']:.3f} \t NIQE: {result['niqe']:.3f} \t " 174 | output_str += f"Time taken: {d:0.1f}s" 175 | vprint(output_str) 176 | 177 | results.append(result) 178 | 179 | with open(txt_path, 'a+') as f: 180 | f.write("Image: {} \t >> PSNR: {:.2f} \t SSIM: {:.4f} \t LPIPS: {:.4f} \t NIQE: {:.2f} \n".format(img_name, result['psnr'], result['ssim'], result['lpips'], result['mae'], result['niqe'])) 181 | 182 | 183 | psnr = np.mean([result['psnr'] for result in results]) 184 | ssim = np.mean([result['ssim'] for result in results]) 185 | lpips = np.mean([result['lpips'] for result in results]) 186 | mae = np.mean([result['mae'] for result in results]) 187 | niqe = np.mean([result['niqe'] for result in results]) 188 | 189 | 190 | with open(txt_path, 'a+') as f: 191 | f.write("\n------------------------------ \n\n{} \t >> PSNR: {:.2f} \t SSIM: {:.3f} \t LPIPS: {:.3f} \t MAE: {:.3f}\t NIQE: {:.3f}\n\n------------------------------".format(dataset_name, psnr, ssim, lpips, mae, niqe)) 192 | 193 | # vprint(f"Final Result: {format_result(psnr, ssim, lpips, mae, niqe)}, {time.time() - t_init:0.1f}s") 194 | result = format_result(psnr, ssim, lpips, mae, niqe) 195 | result_names = {'PSNR': psnr, 'SSIM': ssim, 'LPIPS': lpips, 'MAE': mae, 'NIQE': niqe} 196 | 197 | output_str = "\nFinal Result>> \t" 198 | for name, value in result_names.items(): 199 | if name == 'PSNR': 200 | output_str += f"{name}: {value:.2f} \t" # 保留两位小数 201 | else: 202 | output_str += f"{name}: {value:.3f} \t" # 保留三位小数 203 | 204 | output_str += f"Time taken: {time.time() - t_init:0.1f}s" 205 | vprint(output_str) 206 | 207 | 208 | if __name__ == "__main__": 209 | parser = argparse.ArgumentParser() 210 | 211 | parser.add_argument('--dataset_name', default='MIT_5K', type=str) 212 | parser.add_argument('--model_name', default='ReTrust', type=str) 213 | parser.add_argument('--result_dir', default='./results', type=str) 214 | 215 | parser.add_argument('-type', default='png') 216 | parser.add_argument('--gpu_id', default=False) 217 | args = parser.parse_args() 218 | 219 | if args.dataset_name == 'LOL_v1': 220 | dirA = '/data/xr/Dataset/light_dataset/LOL_v1/eval15/high/' 221 | dirB = os.path.join('results/LOL_v1/', args.model_name) 222 | elif args.dataset_name == 'LOL_v2_real': 223 | dirA = '/data/xr/Dataset/light_dataset/LOL_v2/Real_captured/Test/high/' 224 | dirB = os.path.join('results/LOL_v2_real/', args.model_name) 225 | elif args.dataset_name == 'LOL_v2_sync': 226 | dirA = '/data/xr/Dataset/light_dataset/LOL_v2/Synthetic/Test/high/' 227 | dirB = os.path.join('results/LOL_v2_sync/', args.model_name) 228 | elif args.dataset_name == 'MIT_5K': 229 | dirA = '/data/xr/Dataset/light_dataset/MIT-Adobe-5K-512/test/high/' 230 | dirB = os.path.join('results/MIT_5K/', args.model_name) 231 | elif args.dataset_name == 'LOL_blur': 232 | dirA = '/data/xr/Dataset/light_dataset/LOL_blur/eval/high/' 233 | dirB = os.path.join('results/LOL_blur/', args.model_name) 234 | elif args.dataset_name == 'SID': 235 | dirA = '/data/xr/Dataset/light_dataset/SID_png/eval/high/' 236 | dirB = os.path.join('results/SID/', args.model_name) 237 | 
elif args.dataset_name == 'SMID': 238 | dirA = '/data/xr/Dataset/light_dataset/SMID_png/eval/high/' 239 | dirB = os.path.join('results/SMID/', args.model_name) 240 | elif args.dataset_name == 'SDSD_indoor': 241 | dirA = '/data/xr/Dataset/light_dataset/SDSD_indoor_png/eval/high/' 242 | dirB = os.path.join('results/SDSD_indoor/', args.model_name) 243 | elif args.dataset_name == 'SDSD_outdoor': 244 | dirA = '/data/xr/Dataset/light_dataset/SDSD_outdoor_png/eval/high/' 245 | dirB = os.path.join('results/SDSD_outdoor/', args.model_name) 246 | 247 | 248 | 249 | result_path = os.path.join(args.result_dir, args.dataset_name) 250 | # if not os.path.exists(result_path): 251 | # custom_utils.mkdir(result_path) 252 | txt_path = result_path + '/result.txt' 253 | 254 | if os.path.exists(txt_path): 255 | os.remove(txt_path) 256 | 257 | type = args.type 258 | use_gpu = args.gpu_id 259 | 260 | if len(dirA) > 0 and len(dirB) > 0: 261 | measure_dirs(dirA, dirB, dataset_name=args.dataset_name, txt_path=txt_path, use_gpu=use_gpu, verbose=True) 262 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import random 4 | import numpy as np 5 | import yaml 6 | import argparse 7 | from thop import profile 8 | from tqdm import tqdm 9 | from thop import clever_format 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Hyper-parameters for URWKV') 14 | parser.add_argument('--gpu_id', type=str, default=0) 15 | parser.add_argument('--model_name', type=str, default='testpath') 16 | parser.add_argument('--yml_path', default="./configs/LOL_v1.yaml", type=str) 17 | parser.add_argument('--pretrain_weights', default='', type=str, help='Path to weights') 18 | parser.add_argument('--channel', type=int, default=32) 19 | parser.add_argument('--batch_size', type=int, default=8) 20 | parser.add_argument('--epochs', type=int, default=1000) 21 | parser.add_argument('--lr_init', type=float, default=0.0002) 22 | parser.add_argument('--lr_min', type=float, default=1e-6) 23 | args = parser.parse_args() 24 | 25 | # other imports 26 | import os 27 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 28 | import sys 29 | sys.path.append('.') 30 | 31 | # torch imports 32 | import torch 33 | import torch.nn as nn 34 | import torch.optim as optim 35 | from collections import OrderedDict 36 | from torch.utils.data import DataLoader 37 | 38 | 39 | # custom imports 40 | from model.loss import VGGLoss, SSIM_loss 41 | from model.model_builder import LLENet 42 | import custom_utils 43 | from custom_utils.dataset_utils import DataLoaderX 44 | from custom_utils.data_loaders.lol import PatchDataLoaderTrain, PatchDataLoaderVal, wholeDataLoader 45 | from custom_utils.warmup_scheduler.scheduler import GradualWarmupScheduler 46 | from custom_utils import network_parameters 47 | 48 | 49 | ## Set Seeds 50 | torch.backends.cudnn.benchmark = True 51 | random.seed(42) 52 | np.random.seed(42) 53 | torch.manual_seed(42) 54 | torch.cuda.manual_seed_all(42) 55 | 56 | ## Load yaml configuration file 57 | yaml_file = args.yml_path 58 | 59 | with open(yaml_file, 'r') as config: 60 | opt = yaml.safe_load(config) 61 | print("load training yaml file: %s"%(yaml_file)) 62 | 63 | Train = opt['TRAINING'] 64 | print(Train) 65 | OPT = opt['OPTIM'] 66 | 67 | ## Build Model 68 | print('==> Build the model') 69 | model_restored = LLENet(dim=args.channel) 70 | 71 | 72 | test_tensor = torch.randn(1, 3, 256, 256) 73 | flops, params = profile(model_restored.cuda(), ((test_tensor.cuda()),)) 74 | params_number = params / 1000000.0 75 | flops_number = flops / 1000000000.0 76 | model_restored.cuda() 77 | 78 | ## Training model path direction 79 | mode = args.model_name 80 | model_dir = os.path.join(Train['SAVE_DIR'], mode, 'models') 81 | custom_utils.mkdir(model_dir) 82 | train_dir = Train['TRAIN_DIR'] 83 | val_dir = Train['VAL_DIR'] 84 | 85 | # ## GPU 86 | # gpus = ','.join([str(i) for i in opt['GPU']]) 87 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 88 | # os.environ["CUDA_VISIBLE_DEVICES"] = gpus 89 | # device_ids = [i for i in range(torch.cuda.device_count())] 90 | # if torch.cuda.device_count() > 1: 91 | # print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") 92 | # if len(device_ids) > 1: 93 | # model_restored = nn.DataParallel(model_restored, device_ids=device_ids) 94 | 95 | ## Optimizer 96 | start_epoch = 1 97 | new_lr = float(args.lr_init) 98 | optimizer = 
optim.Adam(model_restored.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8) 99 | 100 | ## Scheduler (Strategy) 101 | warmup_epochs = 3 102 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs - warmup_epochs, 103 | eta_min=float(args.lr_min)) 104 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 105 | scheduler.step() 106 | 107 | ## Resume (Continue training by a pretrained model) 108 | if Train['RESUME']: 109 | path_chk_rest = custom_utils.get_last_path(model_dir, '_latest.pth') 110 | custom_utils.load_checkpoint(model_restored, path_chk_rest) 111 | start_epoch = custom_utils.load_start_epoch(path_chk_rest) + 1 112 | custom_utils.load_optim(optimizer, path_chk_rest) 113 | 114 | for i in range(1, start_epoch): 115 | scheduler.step() 116 | new_lr = scheduler.get_lr()[0] 117 | print('------------------------------------------------------------------') 118 | print("==> Resuming Training with learning rate:", new_lr) 119 | print('------------------------------------------------------------------') 120 | 121 | # pretrain 122 | if args.pretrain_weights: 123 | checkpoint = torch.load(args.pretrain_weights) 124 | model_restored.load_state_dict(checkpoint["state_dict"]) 125 | # state_dict = checkpoint["state_dict"] 126 | # new_state_dict = OrderedDict() 127 | # for k, v in state_dict.items(): 128 | # name = k[7:] # remove `module.` 129 | # new_state_dict[name] = v 130 | # model_restored.load_state_dict(new_state_dict) 131 | 132 | 133 | ## Loss 134 | L1_loss, ssim_loss, vgg_loss = nn.L1Loss(), SSIM_loss(), VGGLoss() 135 | 136 | def load_data(train_dir, train_patchsize, val_dir, val_patchsize, train_batch_size, val_batch_size, shuffle=True, num_workers=16, drop_last=False): 137 | 138 | ## DataLoaders 139 | print('==> Loading datasets') 140 | if train_patchsize%64 == 0: 141 | train_dataset = PatchDataLoaderTrain(train_dir, {'patch_size': train_patchsize}) 142 | train_loader = DataLoaderX(dataset=train_dataset, batch_size=train_batch_size, 143 | shuffle=shuffle, num_workers=num_workers, drop_last=drop_last) 144 | val_dataset = PatchDataLoaderVal(val_dir, {'patch_size': val_patchsize}) 145 | val_loader = DataLoaderX(dataset=val_dataset, batch_size=val_batch_size, shuffle=shuffle, num_workers=num_workers, 146 | drop_last=drop_last) 147 | else: 148 | train_dataset = wholeDataLoader(images_path=train_dir) 149 | train_loader = DataLoaderX(train_dataset, batch_size=train_batch_size, shuffle=shuffle, num_workers=num_workers, 150 | pin_memory=True) 151 | val_dataset = wholeDataLoader(images_path=val_dir, mode='val') 152 | val_loader = DataLoaderX(val_dataset, batch_size=val_batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 153 | 154 | 155 | # Show the training configuration 156 | print(f'''==> Training details: 157 | ------------------------------------------------------------------ 158 | Restoration mode: {mode} 159 | Train patches size: {str(train_patchsize) + 'x' + str(train_patchsize)} 160 | Val patches size: {str(val_patchsize) + 'x' + str(val_patchsize)} 161 | Model parameters: {str(round(params_number,2)) + 'M'} 162 | Model FLOPs: {str(round(flops_number,2)) + 'G'} 163 | Start/End epochs: {str(start_epoch) + '~' + str(args.epochs)} 164 | Batch sizes: {train_batch_size} 165 | Learning rate: {args.lr_init} 166 | GPU: {'GPU' + str(args.gpu_id)}''') 167 | print('------------------------------------------------------------------') 168 | 169 | return train_loader, val_loader 170 | 171 
| 172 | train_loader, val_loader = load_data(train_dir=train_dir, train_patchsize=Train['PATCH_SIZES'][0], 173 | val_dir=val_dir, val_patchsize=Train['PATCH_SIZES'][0], train_batch_size=Train['BATCH_SIZES'][0], val_batch_size=Train['BATCH_SIZES'][0]) 174 | 175 | 176 | best_psnr = 0 177 | best_ssim = 0 178 | best_epoch_psnr = 0 179 | best_epoch_ssim = 0 180 | total_start_time = time.time() 181 | 182 | ## Log 183 | log_dir = os.path.join(Train['SAVE_DIR'], mode, 'log') 184 | custom_utils.mkdir(log_dir) 185 | writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_{mode}') 186 | 187 | 188 | patch_sizes = Train['PATCH_SIZES'] 189 | batch_sizes = Train['BATCH_SIZES'] 190 | epochs_per_size = Train['EPOCHS_PER_SIZE'] 191 | 192 | num_sizes = len(patch_sizes) 193 | # 初始化当前尺度的索引和轮次计数 194 | current_size_index = 0 195 | current_size_epochs = 0 196 | 197 | # Start training! 198 | print('==> Multi-scale Training start with patch_sizes: ', patch_sizes, 'Batchsizes: ', batch_sizes) 199 | 200 | for epoch in range(start_epoch, args.epochs + 1): 201 | epoch_start_time = time.time() 202 | epoch_loss = 0 203 | train_id = 1 204 | 205 | if current_size_epochs >= epochs_per_size[current_size_index]: 206 | current_size_index += 1 207 | current_size_epochs = 0 208 | 209 | if current_size_index >= num_sizes: 210 | print('==> All scales have been trained, finishing early.') 211 | break 212 | 213 | train_patchsize = patch_sizes[current_size_index] 214 | train_batch_size = batch_sizes[current_size_index] 215 | val_patchsize = train_patchsize 216 | train_loader, val_loader = load_data(train_dir=train_dir, train_patchsize=train_patchsize, 217 | val_dir=val_dir, val_patchsize=val_patchsize, train_batch_size=train_batch_size, val_batch_size=train_batch_size) 218 | else: 219 | train_patchsize = patch_sizes[current_size_index] 220 | val_patchsize = train_patchsize 221 | 222 | current_size_epochs += 1 223 | 224 | 225 | model_restored.train() 226 | for i, data in enumerate(tqdm(train_loader), 0): 227 | # Forward propagation 228 | for param in model_restored.parameters(): 229 | param.grad = None 230 | input_ = data[0].cuda() 231 | target = data[1].cuda() 232 | restored = model_restored(input_) 233 | 234 | # Compute loss 235 | loss = L1_loss(restored, target)+ (1 - ssim_loss(restored, target)) + 0.01*vgg_loss(restored, target) 236 | 237 | # Back propagation 238 | loss.backward() 239 | torch.nn.utils.clip_grad_norm_(model_restored.parameters(), max_norm=1.0) 240 | optimizer.step() 241 | epoch_loss += loss.item() 242 | 243 | ## Evaluation (Validation) 244 | if epoch % Train['VAL_AFTER_EVERY'] == 0: 245 | model_restored.eval() 246 | psnr_val_rgb = [] 247 | ssim_val_rgb = [] 248 | for ii, data_val in enumerate(val_loader, 0): 249 | input_ = data_val[0].cuda() 250 | target = data_val[1].cuda() 251 | h, w = target.shape[2], target.shape[3] 252 | with torch.no_grad(): 253 | restored = model_restored(input_) 254 | restored = restored[:, :, :h, :w] 255 | for res, tar in zip(restored, target): 256 | psnr_val_rgb.append(custom_utils.torchPSNR(res, tar)) 257 | ssim_val_rgb.append(custom_utils.torchSSIM(restored, target)) 258 | 259 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() 260 | ssim_val_rgb = torch.stack(ssim_val_rgb).mean().item() 261 | 262 | # Save the best PSNR model of validation 263 | if psnr_val_rgb > best_psnr: 264 | best_psnr = psnr_val_rgb 265 | best_epoch_psnr = epoch 266 | torch.save({'epoch': epoch, 267 | 'state_dict': model_restored.state_dict(), 268 | 'optimizer': optimizer.state_dict() 269 | }, 
os.path.join(model_dir, "model_bestPSNR.pth")) 270 | print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % ( 271 | epoch, psnr_val_rgb, best_epoch_psnr, best_psnr)) 272 | 273 | # Save the best SSIM model of validation 274 | if ssim_val_rgb > best_ssim: 275 | best_ssim = ssim_val_rgb 276 | best_epoch_ssim = epoch 277 | torch.save({'epoch': epoch, 278 | 'state_dict': model_restored.state_dict(), 279 | 'optimizer': optimizer.state_dict() 280 | }, os.path.join(model_dir, "model_bestSSIM.pth")) 281 | print("[epoch %d SSIM: %.4f --- best_epoch %d Best_SSIM %.4f]" % ( 282 | epoch, ssim_val_rgb, best_epoch_ssim, best_ssim)) 283 | 284 | writer.add_scalar('val/PSNR', psnr_val_rgb, epoch) 285 | writer.add_scalar('val/SSIM', ssim_val_rgb, epoch) 286 | scheduler.step() 287 | 288 | print("------------------------------------------------------------------") 289 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time, 290 | epoch_loss, scheduler.get_lr()[0])) 291 | print("------------------------------------------------------------------") 292 | 293 | # Save the last model 294 | torch.save({'epoch': epoch, 295 | 'state_dict': model_restored.state_dict(), 296 | 'optimizer': optimizer.state_dict() 297 | }, os.path.join(model_dir, "model_latest.pth")) 298 | 299 | writer.add_scalar('train/loss', epoch_loss, epoch) 300 | writer.add_scalar('train/lr', scheduler.get_lr()[0], epoch) 301 | writer.close() 302 | 303 | total_finish_time = (time.time() - total_start_time) # seconds 304 | print('Total training time: {:.1f} hours'.format((total_finish_time / 60 / 60))) 305 | --------------------------------------------------------------------------------
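A minimal usage sketch, assuming the commands are run from the repository root; 'sample_low_light.png' and the model names below are placeholders rather than files shipped with the repository. The import path follows the package layout above (custom_utils/metrics/piqe.py), and the command lines in the comments mirror the argparse options defined in tools/train.py and tools/measure.py; note that the dataset directories inside tools/measure.py are hard-coded absolute paths and would need to be adapted to a local setup.

# Sketch: scoring a single enhanced image with the no-reference PIQE metric.
import cv2
from custom_utils.metrics.piqe import piqe

img = cv2.imread('sample_low_light.png')                      # BGR uint8 image (hypothetical path)
score, artifacts_mask, noise_mask, activity_mask = piqe(img)  # lower score indicates better perceptual quality
print(f'PIQE: {score:.2f}')

# Training / evaluation entry points (flags taken from the argparse definitions above;
# the model name and GPU id are placeholders):
#   python tools/train.py --gpu_id 0 --yml_path ./configs/LOL_v1.yaml --model_name URWKV_LOLv1 --batch_size 8
#   python tools/measure.py --dataset_name LOL_v1 --model_name URWKV_LOLv1 -type png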