├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── SuperRestoration.iml
├── datasets
│   ├── T91
│   │   ├── t1.png
│   │   ├── t10.png
│   │   ├── t11.png
│   │   ├── t12.png
│   │   ├── t13.png
│   │   ├── t14.png
│   │   ├── t15.png
│   │   ├── t16.png
│   │   ├── t17.png
│   │   ├── t18.png
│   │   ├── t19.png
│   │   ├── t2.png
│   │   ├── t20.png
│   │   ├── t21.png
│   │   ├── t22.png
│   │   ├── t23.png
│   │   ├── t24.png
│   │   ├── t25.png
│   │   ├── t26.png
│   │   ├── t27.png
│   │   ├── t28.png
│   │   ├── t29.png
│   │   ├── t3.png
│   │   ├── t30.png
│   │   ├── t31.png
│   │   ├── t32.png
│   │   ├── t33.png
│   │   ├── t34.png
│   │   ├── t35.png
│   │   ├── t36.png
│   │   ├── t37.png
│   │   ├── t38.png
│   │   ├── t39.png
│   │   ├── t4.png
│   │   ├── t40.png
│   │   ├── t42.png
│   │   ├── t43.png
│   │   ├── t44.png
│   │   ├── t45.png
│   │   ├── t46.png
│   │   ├── t47.png
│   │   ├── t48.png
│   │   ├── t49.png
│   │   ├── t5.png
│   │   ├── t50.png
│   │   ├── t51.png
│   │   ├── t52.png
│   │   ├── t53.png
│   │   ├── t54.png
│   │   ├── t55.png
│   │   ├── t56.png
│   │   ├── t57.png
│   │   ├── t58.png
│   │   ├── t59.png
│   │   ├── t6.png
│   │   ├── t60.png
│   │   ├── t61.png
│   │   ├── t62.png
│   │   ├── t63.png
│   │   ├── t64.png
│   │   ├── t65.png
│   │   ├── t66.png
│   │   ├── t7.png
│   │   ├── t8.png
│   │   ├── t9.png
│   │   ├── tt1.png
│   │   ├── tt2.png
│   │   ├── tt3.png
│   │   ├── tt4.png
│   │   ├── tt5.png
│   │   ├── tt6.png
│   │   ├── tt7.png
│   │   ├── tt8.png
│   │   ├── tt9.png
│   │   ├── tt10.png
│   │   ├── tt12.png
│   │   ├── tt13.png
│   │   ├── tt14.png
│   │   ├── tt15.png
│   │   ├── tt16.png
│   │   ├── tt17.png
│   │   ├── tt18.png
│   │   ├── tt19.png
│   │   ├── tt20.png
│   │   ├── tt21.png
│   │   ├── tt22.png
│   │   ├── tt23.png
│   │   ├── tt24.png
│   │   ├── tt25.png
│   │   ├── tt26.png
│   │   └── tt27.png
│   ├── Set14
│   │   ├── face.bmp
│   │   ├── man.bmp
│   │   ├── ppt3.bmp
│   │   ├── baboon.bmp
│   │   ├── bridge.bmp
│   │   ├── comic.bmp
│   │   ├── lenna.bmp
│   │   ├── pepper.bmp
│   │   ├── zebra.bmp
│   │   ├── barbara.bmp
│   │   ├── flowers.bmp
│   │   ├── foreman.bmp
│   │   ├── monarch.bmp
│   │   └── coastguard.bmp
│   └── Set5
│       ├── baby_GT.bmp
│       ├── bird_GT.bmp
│       ├── head_GT.bmp
│       ├── woman_GT.bmp
│       └── butterfly_GT.bmp
├── niqe_pris_params.npz
├── SRResNet
│   ├── result
│   │   ├── comic.bmp
│   │   ├── baboon.bmp
│   │   ├── baby_GT.bmp
│   │   ├── flowers.bmp
│   │   ├── head_GT.bmp
│   │   ├── woman_GT.bmp
│   │   ├── coastguard.bmp
│   │   ├── baboon_SRGAN_x4.bmp
│   │   ├── comic_SRGAN_x4.bmp
│   │   ├── baboon_SRResNet_x4.bmp
│   │   ├── baboon_bicubic_x4.bmp
│   │   ├── baby_GT_SRGAN_x4.bmp
│   │   ├── baby_GT_bicubic_x4.bmp
│   │   ├── comic_SRResNet_x4.bmp
│   │   ├── comic_bicubic_x4.bmp
│   │   ├── flowers_SRGAN_x4.bmp
│   │   ├── flowers_bicubic_x4.bmp
│   │   ├── head_GT_SRGAN_x4.bmp
│   │   ├── head_GT_bicubic_x4.bmp
│   │   ├── woman_GT_SRGAN_x4.bmp
│   │   ├── baby_GT_SRResNet_x4.bmp
│   │   ├── coastguard_SRGAN_x4.bmp
│   │   ├── flowers_SRResNet_x4.bmp
│   │   ├── head_GT_SRResNet_x4.bmp
│   │   ├── woman_GT_SRResNet_x4.bmp
│   │   ├── woman_GT_bicubic_x4.bmp
│   │   ├── coastguard_SRResNet_x4.bmp
│   │   └── coastguard_bicubic_x4.bmp
│   ├── script.py
│   ├── SRResNet_x4_DIV2Ksub_iter=400000.py
│   ├── VGGLoss.py
│   ├── SRGAN_x4_MSRA_DIV2Kaug_lr=e-4_batch=16_out=96.py
│   ├── eval.py
│   ├── train2.py
│   ├── SRResNetdatasets.py
│   ├── gen_datasets.py
│   ├── train_SRGAN_WGAN.py
│   ├── train_SRGAN_iter.py
│   └── train.py
├── pth2csv.py
├── ViewLoss.py
├── SRCNN
│   ├── SRCNN_x2.py
│   ├── models.py
│   ├── MatVsPy.py
│   ├── SRCNNdatasets.py
│   ├── gen_datasets.py
│   ├── train.py
│   └── test.py
├── FSRCNN
│   ├── FSRCNN_x3_MSRA_T91_lr=e-1_batch=128_out=19_res.py
│   ├── N2-10-4_x3_MSRA_T91res_lr=e-1_batch=128_Huber.py
│   ├── FSRCNNdatasets.py
│   ├── train.py
│   ├── train_deconv.py
│   ├── gen_datasets.py
│   └── model.py
├── ESRGAN
│   ├── ESRGAN3_x4_DFO_lr=e-4_batch=16_out=128.py
│   ├── VGGLoss2.py
│   ├── gen_datasets.py
│   ├── demo_realesrgan.py
│   ├── test_realesrgan.py
│   ├── test_realesrgan_metrics.py
│   ├── train_ESRGAN_iter2.py
│   ├── train_ESRGAN_iter.py
│   ├── ESRGANdatasets.py
│   └── train.py
├── FFT.py
├── SRFlow
│   ├── glow_arch.py
│   ├── thops.py
│   ├── Permutations.py
│   ├── SRFlow_model.py
│   ├── timer.py
│   ├── SRFlow_DF2K_4X.yml
│   ├── Split.py
│   ├── module_util.py
│   └── test_srflow.py
├── create_lmdb.py
├── HCFlow
│   ├── thops.py
│   ├── test_SR_DF2K_4X_HCFlow.yml
│   ├── HCFlow_SR_model.py
│   ├── FlowStep.py
│   ├── networks.py
│   ├── HCFlowNet_SR_arch.py
│   ├── test_hcflow.py
│   ├── loss.py
│   ├── module_util.py
│   └── ActNorms.py
├── test_metrics.py
├── SwinIR
│   ├── test_classical_swinir.py
│   └── test_swinir_metrics.py
├── BSRGAN
│   ├── test_bsrgan.py
│   ├── bsrgan_model.py
│   └── test_bsrgan_metrics.py
├── metrics.py
└── README.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/datasets/T91/t1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t1.png
--------------------------------------------------------------------------------
/datasets/T91/t10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t10.png
--------------------------------------------------------------------------------
/datasets/T91/t11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t11.png
--------------------------------------------------------------------------------
/datasets/T91/t12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t12.png
--------------------------------------------------------------------------------
/datasets/T91/t13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t13.png
--------------------------------------------------------------------------------
/datasets/T91/t14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t14.png
--------------------------------------------------------------------------------
/datasets/T91/t15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t15.png
--------------------------------------------------------------------------------
/datasets/T91/t16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t16.png
--------------------------------------------------------------------------------
/datasets/T91/t17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t17.png
--------------------------------------------------------------------------------
/datasets/T91/t18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t18.png
--------------------------------------------------------------------------------
/datasets/T91/t19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t19.png
--------------------------------------------------------------------------------
/datasets/T91/t2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t2.png
--------------------------------------------------------------------------------
/datasets/T91/t20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t20.png
--------------------------------------------------------------------------------
/datasets/T91/t21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t21.png
--------------------------------------------------------------------------------
/datasets/T91/t22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t22.png
--------------------------------------------------------------------------------
/datasets/T91/t23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t23.png
--------------------------------------------------------------------------------
/datasets/T91/t24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t24.png
--------------------------------------------------------------------------------
/datasets/T91/t25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t25.png
--------------------------------------------------------------------------------
/datasets/T91/t26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t26.png
--------------------------------------------------------------------------------
/datasets/T91/t27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t27.png
--------------------------------------------------------------------------------
/datasets/T91/t28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t28.png
--------------------------------------------------------------------------------
/datasets/T91/t29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t29.png
--------------------------------------------------------------------------------
/datasets/T91/t3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t3.png
--------------------------------------------------------------------------------
/datasets/T91/t30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t30.png
--------------------------------------------------------------------------------
/datasets/T91/t31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t31.png
--------------------------------------------------------------------------------
/datasets/T91/t32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t32.png
--------------------------------------------------------------------------------
/datasets/T91/t33.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t33.png
--------------------------------------------------------------------------------
/datasets/T91/t34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t34.png
--------------------------------------------------------------------------------
/datasets/T91/t35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t35.png
--------------------------------------------------------------------------------
/datasets/T91/t36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t36.png
--------------------------------------------------------------------------------
/datasets/T91/t37.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t37.png
--------------------------------------------------------------------------------
/datasets/T91/t38.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t38.png
--------------------------------------------------------------------------------
/datasets/T91/t39.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t39.png
--------------------------------------------------------------------------------
/datasets/T91/t4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t4.png
--------------------------------------------------------------------------------
/datasets/T91/t40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t40.png
--------------------------------------------------------------------------------
/datasets/T91/t42.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t42.png
--------------------------------------------------------------------------------
/datasets/T91/t43.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t43.png
--------------------------------------------------------------------------------
/datasets/T91/t44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t44.png
--------------------------------------------------------------------------------
/datasets/T91/t45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t45.png
--------------------------------------------------------------------------------
/datasets/T91/t46.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t46.png
--------------------------------------------------------------------------------
/datasets/T91/t47.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t47.png
--------------------------------------------------------------------------------
/datasets/T91/t48.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t48.png
--------------------------------------------------------------------------------
/datasets/T91/t49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t49.png
--------------------------------------------------------------------------------
/datasets/T91/t5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t5.png
--------------------------------------------------------------------------------
/datasets/T91/t50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t50.png
--------------------------------------------------------------------------------
/datasets/T91/t51.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t51.png
--------------------------------------------------------------------------------
/datasets/T91/t52.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t52.png
--------------------------------------------------------------------------------
/datasets/T91/t53.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t53.png
--------------------------------------------------------------------------------
/datasets/T91/t54.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t54.png
--------------------------------------------------------------------------------
/datasets/T91/t55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t55.png
--------------------------------------------------------------------------------
/datasets/T91/t56.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t56.png
--------------------------------------------------------------------------------
/datasets/T91/t57.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t57.png
--------------------------------------------------------------------------------
/datasets/T91/t58.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t58.png
--------------------------------------------------------------------------------
/datasets/T91/t59.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t59.png
--------------------------------------------------------------------------------
/datasets/T91/t6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t6.png
--------------------------------------------------------------------------------
/datasets/T91/t60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t60.png
--------------------------------------------------------------------------------
/datasets/T91/t61.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t61.png
--------------------------------------------------------------------------------
/datasets/T91/t62.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t62.png
--------------------------------------------------------------------------------
/datasets/T91/t63.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t63.png
--------------------------------------------------------------------------------
/datasets/T91/t64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t64.png
--------------------------------------------------------------------------------
/datasets/T91/t65.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t65.png
--------------------------------------------------------------------------------
/datasets/T91/t66.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t66.png
--------------------------------------------------------------------------------
/datasets/T91/t7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t7.png
--------------------------------------------------------------------------------
/datasets/T91/t8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t8.png
--------------------------------------------------------------------------------
/datasets/T91/t9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t9.png
--------------------------------------------------------------------------------
/datasets/T91/tt1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt1.png
--------------------------------------------------------------------------------
/datasets/T91/tt2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt2.png
--------------------------------------------------------------------------------
/datasets/T91/tt3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt3.png
--------------------------------------------------------------------------------
/datasets/T91/tt4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt4.png
--------------------------------------------------------------------------------
/datasets/T91/tt5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt5.png
--------------------------------------------------------------------------------
/datasets/T91/tt6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt6.png
--------------------------------------------------------------------------------
/datasets/T91/tt7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt7.png
--------------------------------------------------------------------------------
/datasets/T91/tt8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt8.png
--------------------------------------------------------------------------------
/datasets/T91/tt9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt9.png
--------------------------------------------------------------------------------
/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/niqe_pris_params.npz
--------------------------------------------------------------------------------
/datasets/Set14/face.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/face.bmp
--------------------------------------------------------------------------------
/datasets/Set14/man.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/man.bmp
--------------------------------------------------------------------------------
/datasets/Set14/ppt3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/ppt3.bmp
--------------------------------------------------------------------------------
/datasets/T91/tt10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt10.png
--------------------------------------------------------------------------------
/datasets/T91/tt12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt12.png
--------------------------------------------------------------------------------
/datasets/T91/tt13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt13.png
--------------------------------------------------------------------------------
/datasets/T91/tt14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt14.png
--------------------------------------------------------------------------------
/datasets/T91/tt15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt15.png
--------------------------------------------------------------------------------
/datasets/T91/tt16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt16.png
--------------------------------------------------------------------------------
/datasets/T91/tt17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt17.png
--------------------------------------------------------------------------------
/datasets/T91/tt18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt18.png
--------------------------------------------------------------------------------
/datasets/T91/tt19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt19.png
--------------------------------------------------------------------------------
/datasets/T91/tt20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt20.png
--------------------------------------------------------------------------------
/datasets/T91/tt21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt21.png
--------------------------------------------------------------------------------
/datasets/T91/tt22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt22.png
--------------------------------------------------------------------------------
/datasets/T91/tt23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt23.png
--------------------------------------------------------------------------------
/datasets/T91/tt24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt24.png
--------------------------------------------------------------------------------
/datasets/T91/tt25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt25.png
--------------------------------------------------------------------------------
/datasets/T91/tt26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt26.png
--------------------------------------------------------------------------------
/datasets/T91/tt27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt27.png
--------------------------------------------------------------------------------
/SRResNet/result/comic.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic.bmp
--------------------------------------------------------------------------------
/datasets/Set14/baboon.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/baboon.bmp
--------------------------------------------------------------------------------
/datasets/Set14/bridge.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/bridge.bmp
--------------------------------------------------------------------------------
/datasets/Set14/comic.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/comic.bmp
--------------------------------------------------------------------------------
/datasets/Set14/lenna.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/lenna.bmp
--------------------------------------------------------------------------------
/datasets/Set14/pepper.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/pepper.bmp
--------------------------------------------------------------------------------
/datasets/Set14/zebra.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/zebra.bmp
--------------------------------------------------------------------------------
/datasets/Set5/baby_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/baby_GT.bmp
--------------------------------------------------------------------------------
/datasets/Set5/bird_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/bird_GT.bmp
--------------------------------------------------------------------------------
/datasets/Set5/head_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/head_GT.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baboon.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baby_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT.bmp
--------------------------------------------------------------------------------
/SRResNet/result/flowers.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers.bmp
--------------------------------------------------------------------------------
/SRResNet/result/head_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT.bmp
--------------------------------------------------------------------------------
/SRResNet/result/woman_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT.bmp
--------------------------------------------------------------------------------
/datasets/Set14/barbara.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/barbara.bmp
--------------------------------------------------------------------------------
/datasets/Set14/flowers.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/flowers.bmp
--------------------------------------------------------------------------------
/datasets/Set14/foreman.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/foreman.bmp
--------------------------------------------------------------------------------
/datasets/Set14/monarch.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/monarch.bmp
--------------------------------------------------------------------------------
/datasets/Set5/woman_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/woman_GT.bmp
--------------------------------------------------------------------------------
/SRResNet/result/coastguard.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard.bmp
--------------------------------------------------------------------------------
/datasets/Set14/coastguard.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/coastguard.bmp
--------------------------------------------------------------------------------
/datasets/Set5/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/butterfly_GT.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baboon_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/comic_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baboon_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baboon_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baby_GT_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baby_GT_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/comic_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/comic_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/flowers_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/flowers_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/head_GT_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/head_GT_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/woman_GT_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/baby_GT_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/coastguard_SRGAN_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_SRGAN_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/flowers_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/head_GT_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/woman_GT_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/woman_GT_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_bicubic_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/coastguard_SRResNet_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_SRResNet_x4.bmp
--------------------------------------------------------------------------------
/SRResNet/result/coastguard_bicubic_x4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_bicubic_x4.bmp
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | <component name="InspectionProjectProfileManager">
2 |   <settings>
3 |     <option name="USE_PROJECT_PROFILE" value="false" />
4 |     <version value="1.0" />
5 |   </settings>
6 | </component>
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="ProjectModuleManager">
4 |     <modules>
5 |       <module fileurl="file://$PROJECT_DIR$/.idea/SuperRestoration.iml" filepath="$PROJECT_DIR$/.idea/SuperRestoration.iml" />
6 |     </modules>
7 |   </component>
8 | </project>
--------------------------------------------------------------------------------
/.idea/SuperRestoration.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/pth2csv.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import torch
4 |
5 | csvFile = open("./SRCNN_x3_lr=1e-02_batch=128.csv", 'w', newline='')
6 |
7 | try:
8 |     writer = csv.writer(csvFile)
9 |     writer.writerow(('epoch', 'loss', 'psnr'))
10 |
11 |     pthDir = './SRCNN/weight_file/9-1-5_python2/x3/'
12 |     pthList = os.listdir(pthDir)
13 |     if 'best.pth' in pthList:
14 |         pthList.remove('best.pth')
15 |     for pthName in pthList:
16 |         pth = torch.load(pthDir + pthName, map_location='cpu')  # snapshots may have been saved on GPU
17 |         writer.writerow((pth['epoch'], pth['loss'], pth['psnr']))
18 | finally:
19 |     csvFile.close()
20 |
--------------------------------------------------------------------------------
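pth2csv.py expects every .pth snapshot in pthDir to be a dict carrying at least the keys 'epoch', 'loss' and 'psnr'. A minimal sketch of the save side that would produce compatible files (the key set beyond those three is an assumption, not the repo's actual training code):

import torch

snapshot = {
    'epoch': 10,     # epoch index written by the training loop
    'loss': 0.0042,  # training loss at that epoch
    'psnr': 32.71,   # validation PSNR in dB
    # a real snapshot would normally also carry model/optimizer state_dicts
}
torch.save(snapshot, './SRCNN/weight_file/9-1-5_python2/x3/epoch_10.pth')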
/ViewLoss.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib import cm
3 | import numpy as np
4 |
5 | fig = plt.figure()
6 | ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') was removed in matplotlib 3.6
7 |
8 | # Make data: with X = D(fake) and Y = D(real), Z has the form of the
9 | # (averaged) GAN discriminator loss.
10 | X = np.arange(0.01, 1, 0.01)
11 | Y = np.arange(0.01, 1, 0.01)
12 | X, Y = np.meshgrid(X, Y)
13 | Z = -0.5*(np.log(1-X)+np.log(Y))
14 |
15 | # Plot the surface.
16 | surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
17 |                        linewidth=0, antialiased=False)
18 | # Add a color bar which maps values to colors.
19 | fig.colorbar(surf, shrink=0.5, aspect=5)
20 | plt.show()
21 |
--------------------------------------------------------------------------------
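Since Z = -0.5*(log(1-X) + log(Y)) is increasing in X and decreasing in Y on (0, 1), the plotted surface bottoms out as X -> 0 and Y -> 1; reading X as D(fake) and Y as D(real), that is the regime where a GAN discriminator classifies perfectly.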
/SRCNN/SRCNN_x2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from train import train_model
3 | import os
4 | if __name__ == '__main__':
5 |     program = os.path.basename(sys.argv[0]).split('.')[0]
6 |     scale = 2
7 |     config = {'train_file': f'../datasets/T91_aug_train_SRCNNx{scale}.h5',
8 |               # config = {'train_file': 'G:/Document/datasets/T91_aug_x3.h5',
9 |               'val_file': f'../datasets/Set5_val_SRCNNx{scale}.h5',
10 |               'outputs_dir': f'./weight_file/{program}/x{scale}/',
11 |               'csv_name': f'{program}.csv',
12 |               'scale': scale,
13 |               'lr': 1e-2,
14 |               'batch_size': 128,
15 |               'num_epochs': 1000,
16 |               'num_workers': 2,
17 |               'seed': 123,
18 |               'weight_file': f'./weight_file/{program}/x{scale}/latest.pth',
19 |               'Gpu': '0',
20 |               'residual': False,
21 |               'auto_lr': False
22 |               }
23 |     train_model(config, from_pth=True, useVisdom=True)
--------------------------------------------------------------------------------
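Every runner script in this repo follows the same pattern: derive program from the script's own filename, then splice program and scale into each path. For this file the f-strings resolve as below (a quick check, not repo code):

program = 'SRCNN_x2'  # os.path.basename('SRCNN_x2.py').split('.')[0]
scale = 2
print(f'../datasets/T91_aug_train_SRCNNx{scale}.h5')   # ../datasets/T91_aug_train_SRCNNx2.h5
print(f'./weight_file/{program}/x{scale}/latest.pth')  # ./weight_file/SRCNN_x2/x2/latest.pth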
/SRResNet/script.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.backends import cudnn
3 | from model import G, D
4 | import os
5 |
6 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
7 | if __name__ == '__main__':
8 |     config = {'weight_file': './weight_file/SRGAN7_x4_DIVsub_WGAN/',
9 |               'scale': 4
10 |               }
11 |     scale = config['scale']
12 |     padding = scale
13 |     # weight_file = config['weight_file'] + f'best.pth'
14 |     weight_file = config['weight_file'] + f'x{scale}/latest.pth'
15 |     if not os.path.exists(weight_file):
16 |         print(f'Weight file does not exist!\n{weight_file}\n')
17 |         raise FileNotFoundError(weight_file)
18 |
19 |     cudnn.benchmark = True
20 |     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
21 |
22 |     checkpoint = torch.load(weight_file)
23 |
24 |     # Load the discriminator through DataParallel, then re-save its unwrapped
25 |     # state_dict so the checkpoint no longer carries the 'module.' key prefix.
26 |     disc = D().to(device)
27 |     disc = torch.nn.DataParallel(disc)
28 |     disc.load_state_dict(checkpoint['disc'])
29 |     disc = disc.module
30 |     checkpoint['disc'] = disc.state_dict()
31 |     torch.save(checkpoint, weight_file)
32 |
--------------------------------------------------------------------------------
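script.py unwraps a torch.nn.DataParallel discriminator by loading the checkpoint into a wrapped module and re-saving module.state_dict(). The same cleanup can be done without instantiating D at all, by stripping the 'module.' prefix that DataParallel prepends to every parameter name; a sketch under that assumption (the checkpoint path is illustrative):

from collections import OrderedDict

import torch

def strip_module_prefix(state_dict):
    # DataParallel saves keys as 'module.<name>'; drop the prefix if present.
    out = OrderedDict()
    for key, value in state_dict.items():
        out[key[len('module.'):] if key.startswith('module.') else key] = value
    return out

checkpoint = torch.load('latest.pth', map_location='cpu')
checkpoint['disc'] = strip_module_prefix(checkpoint['disc'])
torch.save(checkpoint, 'latest.pth')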
/SRResNet/SRResNet_x4_DIV2Ksub_iter=400000.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from train2 import train_model
4 | if __name__ == '__main__':
5 |     program = os.path.basename(sys.argv[0]).split('.')[0]
6 |     scale = 4
7 |     label_size = 96
8 |     config = {'train_file': f'../datasets/DIV2K_train_HR/',
9 |               'val_file': f'../datasets/Set5_label={label_size}_val_SRResNetx{scale}.h5',
10 |               'outputs_dir': f'./weight_file/{program}/x{scale}/',
11 |               'logs_dir': f'./logs/{program}/',
12 |               'csv_name': f'{program}.csv',
13 |               'weight_file': f'./weight_file/{program}/x{scale}/latest.pth',
14 |               'scale': scale,
15 |               'out_size': label_size,
16 |               'lr': 1e-4,
17 |               'batch_size': 16,
18 |               'num_steps': 400000,
19 |               'num_workers': 16,
20 |               'seed': 123,
21 |               'init': 'MSRA',
22 |               'Gpu': '0',
23 |               'auto_lr': False,
24 |               'milestone': [1000]
25 |               }
26 |
27 |     train_model(config, from_pth=False, useVisdom=False)
--------------------------------------------------------------------------------
/FSRCNN/FSRCNN_x3_MSRA_T91_lr=e-1_batch=128_out=19_res.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from train import train_model
4 | if __name__ == '__main__':
5 |     program = os.path.basename(sys.argv[0]).split('.')[0]
6 |     scale = 3
7 |     label_size = 19
8 |     config = {'train_file': f'../datasets/T91_aug_label={label_size}_train_FSRCNNx{scale}_res.h5',
9 |               'val_file': f'../datasets/Set5_label={label_size}_val_FSRCNNx{scale}_res.h5',
10 |               'outputs_dir': f'../weight_file/{program}/',
11 |               'logs_dir': f'../logs/{program}/',
12 |               'csv_name': f'{program}.csv',
13 |               'weight_file': f'./weight_file/{program}/latest.pth',
14 |               'scale': scale,
15 |               'in_size': 11,
16 |               'out_size': label_size,
17 |               'd': 56,
18 |               's': 12,
19 |               'm': 4,
20 |               'lr': 1e-1,
21 |               'batch_size': 128,
22 |               'num_epochs': 100000,
23 |               'num_workers': 4,
24 |               'seed': 123,
25 |               'init': 'MSRA',
26 |               'Gpu': '0',
27 |               'auto_lr': False,
28 |               'residual': True
29 |               }
30 |
31 |     train_model(config, from_pth=True)
--------------------------------------------------------------------------------
/SRResNet/VGGLoss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models import vgg19
3 | import torch
4 | # phi_5,4: feature map of the 4th conv in the 5th block, after activation,
5 | # before max-pooling (VGG54 in the SRGAN paper)
6 |
7 |
8 | class VGGLoss(nn.Module):
9 |     def __init__(self, device, feature_loss=nn.MSELoss()):
10 |         super().__init__()
11 |         print('===> Loading VGG model')
12 |         netVGG = vgg19()
13 |         netVGG.load_state_dict(torch.load('../VGG19/vgg19-dcbb9e9d.pth'))
14 |         self.vgg = netVGG.features[:36].eval().to(device)  # up to relu5_4, i.e. after activation
15 |         self.loss = feature_loss
16 |         self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
17 |         self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
18 |         for param in self.vgg.parameters():
19 |             param.requires_grad = False
20 |
21 |     def forward(self, pred, target):
22 |         # Normalize with ImageNet statistics before extracting VGG features.
23 |         pred = (pred - self.mean) / self.std
24 |         target = (target - self.mean) / self.std
25 |         vgg_input_features = self.vgg(pred)
26 |         vgg_target_features = self.vgg(target)
27 |         return self.loss(vgg_input_features, vgg_target_features)
28 |
--------------------------------------------------------------------------------
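A minimal usage sketch for VGGLoss (the device handling and patch size are assumptions; inputs are expected in [0, 1] because forward normalizes them with ImageNet mean/std):

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = VGGLoss(device)                    # loads ../VGG19/vgg19-dcbb9e9d.pth from disk
sr = torch.rand(1, 3, 96, 96, device=device)   # stand-in for the generator output
hr = torch.rand(1, 3, 96, 96, device=device)   # stand-in for the ground-truth patch
perceptual_loss = criterion(sr, hr)            # MSE between relu5_4 feature maps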
/SRResNet/SRGAN_x4_MSRA_DIV2Kaug_lr=e-4_batch=16_out=96.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from train_SRGAN import train_model
4 | if __name__ == '__main__':
5 |     program = os.path.basename(sys.argv[0]).split('.')[0]
6 |     scale = 4
7 |     label_size = 96
8 |     config = {'train_file': f'../datasets/test/',
9 |               'val_file': f'../datasets/Set5_label={label_size}_val_SRResNetx{scale}.h5',
10 |               'outputs_dir': f'./weight_file/{program}/x{scale}/',
11 |               'csv_name': f'{program}.csv',
12 |               'weight_file': f'./weight_file/{program}/x{scale}/latest.pth',
13 |               'logs_dir': f'./logs/{program}/',
14 |               'scale': scale,
15 |               'in_size': 24,
16 |               'out_size': label_size,
17 |               'gen_lr': 1e-4,
18 |               'disc_lr': 1e-4,
19 |               'batch_size': 5,  # NB: differs from the batch=16 in this file's name
20 |               'num_epochs': 10000,
21 |               'num_workers': 0,
22 |               'seed': 123,
23 |               'init': 'MSRA',
24 |               'Gpu': '0',
25 |               'gen_k': 1,
26 |               'disc_k': 1,
27 |               }
28 |
29 |     train_model(config, pre_train=None, from_pth=False)
30 |     # train_model(config, pre_train='./best.pth', from_pth=False, use_visdom=False)
--------------------------------------------------------------------------------
/ESRGAN/ESRGAN3_x4_DFO_lr=e-4_batch=16_out=128.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from train_ESRGAN import train_model
4 | if __name__ == '__main__':
5 |     program = os.path.basename(sys.argv[0]).split('.')[0]
6 |     scale = 4
7 |     label_size = 128
8 |     config = {'val_file': f'../datasets/Set5_label={label_size}_val_ESRGANx{scale}.h5',
9 |               'outputs_dir': f'../weight_file/{program}/x{scale}/',
10 |               'csv_name': f'{program}.csv',
11 |               'weight_file': f'../weight_file/{program}/x{scale}/latest.pth',
12 |               'logs_dir': f'../logs/{program}/',
13 |               'scale': scale,
14 |               'in_size': 32,
15 |               'out_size': label_size,
16 |               'gen_lr': 1e-4,
17 |               'disc_lr': 1e-4,
18 |               'batch_size': 16,
19 |               'num_epochs': 10000,
20 |               'num_workers': 32,
21 |               'seed': 0,
22 |               'Gpu': '0',
23 |               'auto_lr': True,
24 |               'gen_k': 1,
25 |               'disc_k': 1,
26 |               'adversarial_weight': 5e-3,
27 |               'pixel_weight': 1e-2
28 |               }
29 |
30 |     train_model(config, pre_train=None, from_pth=False)
31 |     # train_model(config, pre_train='./best.pth', from_pth=False)
--------------------------------------------------------------------------------
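The two weights match the generator objective of the ESRGAN paper, L_G = L_percep + lambda * L_adv + eta * L_1, with lambda = 5e-3 ('adversarial_weight') and eta = 1e-2 ('pixel_weight').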
/SRCNN/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class SRCNN(nn.Module):
5 |     def __init__(self, padding=False, num_channels=1):
6 |         super(SRCNN, self).__init__()
7 |         self.conv1 = nn.Sequential(nn.Conv2d(num_channels, 64, kernel_size=9, padding=4*int(padding), padding_mode='replicate'),
8 |                                    nn.ReLU(inplace=True))
9 |         self.conv2 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, padding=0),  # n1 * 1 * 1 * n2
10 |                                    nn.ReLU(inplace=True))
11 |         self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2*int(padding), padding_mode='replicate')
12 |
13 |     def forward(self, x):
14 |         x = self.conv1(x)
15 |         x = self.conv2(x)
16 |         x = self.conv3(x)
17 |         return x
18 |
19 |     def init_weights(self):
20 |         # Gaussian init with std 0.001 and zero biases, as in the SRCNN paper.
21 |         for L in self.conv1:
22 |             if isinstance(L, nn.Conv2d):
23 |                 L.weight.data.normal_(mean=0.0, std=0.001)
24 |                 L.bias.data.zero_()
25 |         for L in self.conv2:
26 |             if isinstance(L, nn.Conv2d):
27 |                 L.weight.data.normal_(mean=0.0, std=0.001)
28 |                 L.bias.data.zero_()
29 |         self.conv3.weight.data.normal_(mean=0.0, std=0.001)
30 |         self.conv3.bias.data.zero_()
31 |
--------------------------------------------------------------------------------
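A quick shape check for the 9-1-5 SRCNN above (the patch size is arbitrary):

import torch

model = SRCNN(padding=True)          # padding=True preserves spatial size through all three convs
model.init_weights()
y = model(torch.rand(1, 1, 33, 33))  # single-channel (luma) input patch
print(y.shape)                       # torch.Size([1, 1, 33, 33]); with padding=False
                                     # the unpadded 9-1-5 stack crops 33 -> 21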
/FFT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import cv2
6 | import sys
7 | from PIL import Image
8 | from imresize import imresize
9 | import utils
10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
11 |
12 |
13 | img = Image.open('./datasets/test/t1.png').convert('RGB')
14 | img = np.array(img).astype(np.float32)
15 | img = utils.rgb2ycbcr(img)[..., 0]
16 | # img = cv2.imread('./datasets/test/t1.png', 0)
17 | bic = imresize(img, 0.5, 'bicubic')
18 |
19 | dft = cv2.dft(np.float32(img), flags=cv2.DFT_COMPLEX_OUTPUT)
20 | dft_shift = np.fft.fftshift(dft)
21 | magnitude_spectrum = 20 * np.log(cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1]))
22 |
23 | dft2 = cv2.dft(np.float32(bic), flags=cv2.DFT_COMPLEX_OUTPUT)
24 | dft2 = np.fft.fftshift(dft2)
25 | dft2 = 20 * np.log(cv2.magnitude(dft2[:, :, 0], dft2[:, :, 1]))
26 |
27 | plt.subplot(221), plt.imshow(img, cmap='gray')
28 | plt.title('Input Image'), plt.xticks([]), plt.yticks([])
29 | plt.subplot(222), plt.imshow(magnitude_spectrum, cmap='gray')
30 | plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([])
31 | plt.subplot(223), plt.imshow(bic, cmap='gray')
32 | plt.title('bicubic'), plt.xticks([]), plt.yticks([])
33 | plt.subplot(224), plt.imshow(dft2, cmap='gray')
34 | plt.title('bicubic_fft'), plt.xticks([]), plt.yticks([])
35 | plt.show()
36 |
--------------------------------------------------------------------------------
/FSRCNN/N2-10-4_x3_MSRA_T91res_lr=e-1_batch=128_Huber.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from train_N2_10_4 import train_model
4 | if __name__ == '__main__':
5 | program = os.path.basename(sys.argv[0]).split('.')[0]
6 | scale = 3
7 | label_size = 19
8 | config = {'train_file': f'../datasets/T91_aug_label={label_size}_train_FSRCNNx{scale}_res.h5',
9 | 'val_file': f'../datasets/Set5_label={label_size}_val_FSRCNNx{scale}_res.h5',
10 | 'outputs_dir': f'./weight_file/{program}/',
11 | 'logs_dir': f'./weight_file/{program}/',
12 | 'csv_name': f'{program}.csv',
13 | 'weight_file': f'./weight_file/{program}/latest.pth',
14 | 'scale': scale,
15 | 'in_size': 11,
16 | 'out_size': label_size,
17 | 'd': 10,
18 | 'm': 4,
19 | 'lr': 1e-1,
20 | 'step_size': 1500,
21 | 'gamma': 0.1,
22 | 'weight_decay': 0,
23 | 'batch_size': 128,
24 | 'num_epochs': 50000,
25 | 'num_workers': 4,
26 | 'seed': 123,
27 | 'init': 'MSRA',
28 | 'Gpu': '0',
29 | 'auto_lr': False,
30 | 'residual': True,
31 | 'Loss': 'Huber',
32 | 'delta': 0.8
33 | }
34 |
35 | train_model(config, from_pth=False)
--------------------------------------------------------------------------------
/SRFlow/glow_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch.nn as nn
18 |
19 |
20 | def f_conv2d_bias(in_channels, out_channels):
21 | def padding_same(kernel, stride):
22 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)]
23 |
24 | padding = padding_same([3, 3], [1, 1])
25 | assert padding == [1, 1], padding
26 | return nn.Sequential(
27 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
28 | bias=True))
29 |
--------------------------------------------------------------------------------
/ESRGAN/VGGLoss2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models import vgg19
3 | import torch
4 | import torch.nn.functional as F
5 | # phi_5,4: feature map of the 4th conv in the 5th block (conv5_4), taken before activation (ESRGAN-style VGG54)
6 |
7 | class VGGLoss(nn.Module):
8 | def __init__(self, device):
9 | super().__init__()
10 | print('===> Loading VGG model')
11 | netVGG = vgg19()
12 | netVGG.load_state_dict(torch.load('./VGG19/vgg19-dcbb9e9d.pth'))
13 | self.vgg = netVGG.features[:35].eval().to(device) # VGG54, before activation
14 | self.loss = nn.L1Loss()
15 |
16 | for param in self.vgg.parameters():
17 | param.requires_grad = False
18 |
19 | # # The preprocessing method of the input data. This is the preprocessing method of the VGG model on the ImageNet dataset.
20 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
21 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
22 |
23 | def forward(self, input, target):
24 | # input = (input - self.mean) / self.std
25 | # target = (target - self.mean) / self.std
26 | vgg_input_features = self.vgg(input)
27 | vgg_target_features = self.vgg(target)
28 | return self.loss(vgg_input_features, vgg_target_features)
29 |
30 |
31 | if __name__ == '__main__':
32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
33 | v = VGGLoss(device)
34 |
--------------------------------------------------------------------------------
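A usage sketch (assuming the local VGG19 weight file referenced above is present); it feeds two random tensors in [0, 1] as stand-ins for SR output and ground truth and returns their L1 feature distance:

    import torch
    from VGGLoss2 import VGGLoss

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    perceptual = VGGLoss(device)
    sr = torch.rand(1, 3, 96, 96, device=device)  # stand-in generator output
    hr = torch.rand(1, 3, 96, 96, device=device)  # stand-in ground truth
    print(perceptual(sr, hr).item())              # L1 distance between conv5_4 features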
/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import lmdb
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | root_dirs = ['./datasets/Set5/']
8 | lmdb_save_path = './datasets/Set5.lmdb'
9 | # Create the database file
10 | img_list = []
11 | for root_dir in root_dirs:
12 | img_names = os.listdir(root_dir)
13 | for name in img_names:
14 | img_list.append(root_dir + name)
15 | dataset = []
16 | data_size = 0
17 |
18 | print('Read images...')
19 | with tqdm(total=len(img_list)) as t:
20 | for i, v in enumerate(img_list):
21 | img = np.array(Image.open(v).convert('RGB'))
22 | dataset.append(img)
23 | data_size += img.nbytes
24 | t.update(1)
25 | env = lmdb.open(lmdb_save_path, max_dbs=2, map_size=data_size * 2)
26 | print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list)))
27 | # Create the corresponding named sub-databases
28 | train_data = env.open_db("train_data".encode('ascii'))
29 | train_shape = env.open_db("train_shape".encode('ascii'))
30 | # Write the image data into the LMDB
31 | with env.begin(write=True) as txn:
32 | with tqdm(total=len(dataset)) as t:
33 | for idx, img in enumerate(dataset):
34 | H, W, C = img.shape
35 | txn.put(str(idx).encode('ascii'), img, db=train_data)
36 | meta_key = (str(idx) + '.meta').encode('ascii')
37 | meta = '{:d}, {:d}, {:d}'.format(H, W, C)
38 | txn.put(meta_key, meta.encode('ascii'), db=train_shape)
39 | t.update(1)
40 | env.close()
41 | print('Finish writing lmdb.')
--------------------------------------------------------------------------------
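A read-back sketch for the layout written above (hypothetical, not part of the repo): each raw buffer is paired with its `idx.meta` shape record to recover the array.

    import lmdb
    import numpy as np

    env = lmdb.open('./datasets/Set5.lmdb', max_dbs=2, readonly=True)
    train_data = env.open_db(b'train_data', create=False)
    train_shape = env.open_db(b'train_shape', create=False)
    with env.begin() as txn:
        buf = txn.get(b'0', db=train_data)
        meta = txn.get(b'0.meta', db=train_shape).decode('ascii')
        H, W, C = (int(s) for s in meta.split(','))
        img = np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)
    print(img.shape)
    env.close()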
/HCFlow/thops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def sum(tensor, dim=None, keepdim=False):
5 | if dim is None:
6 | # sum up all dim
7 | return torch.sum(tensor)
8 | else:
9 | if isinstance(dim, int):
10 | dim = [dim]
11 | dim = sorted(dim)
12 | for d in dim:
13 | tensor = tensor.sum(dim=d, keepdim=True)
14 | if not keepdim:
15 | for i, d in enumerate(dim):
16 | tensor.squeeze_(d-i)
17 | return tensor
18 |
19 |
20 | def mean(tensor, dim=None, keepdim=False):
21 | if dim is None:
22 | # mean all dim
23 | return torch.mean(tensor)
24 | else:
25 | if isinstance(dim, int):
26 | dim = [dim]
27 | dim = sorted(dim)
28 | for d in dim:
29 | tensor = tensor.mean(dim=d, keepdim=True)
30 | if not keepdim:
31 | for i, d in enumerate(dim):
32 | tensor.squeeze_(d-i)
33 | return tensor
34 |
35 |
36 |
37 | def split_feature(tensor, type="split"):
38 | """
39 | type = ["split", "cross"]
40 | """
41 | C = tensor.size(1)
42 | if type == "split":
43 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
44 | elif type == "cross":
45 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
46 |
47 |
48 | def cat_feature(tensor_a, tensor_b):
49 | return torch.cat((tensor_a, tensor_b), dim=1)
50 |
51 |
52 | def pixels(tensor):
53 | return int(tensor.size(2) * tensor.size(3))
--------------------------------------------------------------------------------
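A small behavior check (hypothetical snippet, run next to this file): sum reduces over a list of dims and squeezes them out, split_feature("cross") interleaves even/odd channels, and pixels counts spatial positions:

    import torch
    import thops

    x = torch.ones(2, 4, 8, 8)
    s = thops.sum(x, dim=[1, 2, 3])          # per-sample sum, shape (2,)
    a, b = thops.split_feature(x, "cross")   # even / odd channels, each (2, 2, 8, 8)
    print(s.shape, a.shape, thops.pixels(x)) # torch.Size([2]) torch.Size([2, 2, 8, 8]) 64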
/FSRCNN/FSRCNNdatasets.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class TrainDataset(Dataset):
7 | def __init__(self, h5_file):
8 | super(TrainDataset, self).__init__()
9 | self.h5_file = h5_file
10 |
11 | def __getitem__(self, idx):
12 | with h5py.File(self.h5_file, 'r') as f:
13 | return np.expand_dims(f['data'][idx] / 255., 0), np.expand_dims(f['label'][idx] / 255., 0)
14 |
15 | def __len__(self):
16 | with h5py.File(self.h5_file, 'r') as f:
17 | return len(f['data'])
18 |
19 |
20 | class ValDataset(Dataset):
21 | def __init__(self, h5_file):
22 | super(ValDataset, self).__init__()
23 | self.h5_file = h5_file
24 |
25 | def __getitem__(self, idx):
26 | with h5py.File(self.h5_file, 'r') as f:
27 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), np.expand_dims(f['label'][str(idx)][:, :] / 255., 0)
28 |
29 | def __len__(self):
30 | with h5py.File(self.h5_file, 'r') as f:
31 | return len(f['data'])
32 |
33 |
34 | class ResValDataset(Dataset):
35 | def __init__(self, h5_file):
36 | super(ResValDataset, self).__init__()
37 | self.h5_file = h5_file
38 |
39 | def __getitem__(self, idx):
40 | with h5py.File(self.h5_file, 'r') as f:
41 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), \
42 | np.expand_dims(f['label'][str(idx)][:, :] / 255., 0),\
43 | np.expand_dims(f['bicubic'][str(idx)][:, :] / 255., 0)
44 |
45 | def __len__(self):
46 | with h5py.File(self.h5_file, 'r') as f:
47 | return len(f['data'])
48 |
--------------------------------------------------------------------------------
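A loading sketch (assuming the residual training file produced by the x3 config above exists); note the arrays are divided by 255. as float64, so training code typically casts the batches to float:

    from torch.utils.data import DataLoader
    from FSRCNNdatasets import TrainDataset

    dataset = TrainDataset('../datasets/T91_aug_label=19_train_FSRCNNx3_res.h5')
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
    lr, hr = next(iter(loader))
    print(lr.shape, hr.shape)  # (128, 1, 11, 11) and (128, 1, 19, 19) per the in/out sizes above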
/SRCNN/MatVsPy.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | from scipy.io import loadmat
4 | import utils
5 | import numpy as np
6 | import transplant
7 | import os
8 | from PIL import Image
9 | import imresize
10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
11 |
12 |
15 | if __name__ == '__main__':
16 | # h5_path = '../datasets/Matlab_Set5_val_SRCNNx3.h5'
17 | # with h5py.File(h5_path, 'r') as f:
18 | # len = len(f['data'])
19 | # for idx in range(len):
20 | # input = f['data'][idx]
21 | # label = f['label'][idx]
22 |
23 | # m = loadmat("G:/Document/文献资料/超分辨率/SRCNN_train/SRCNN/model/9-1-5(91 images)/x2.mat")
24 | # x = m['weights_conv1']
25 | # x = np.reshape(x, (9, 9, 64))
26 | # ax = utils.viz_layer2(x, 64)
27 |
28 | hrfile = 'G:/Document/pythonProj/SuperRestoration/datasets/T91_aug/t10_0_x0.6.png'
29 | hrfile2 = 'G:/Document/pythonProj/SuperRestoration/datasets/T91_augt10_rot0_s6.bmp'
30 | matlab = transplant.Matlab(executable='G:/Software/Matlab2020b/bin/matlab.exe')
31 | hr_py = Image.open(hrfile).convert('RGB') # RGBA->RGB
32 | hr_py = np.array(hr_py)
33 | hr_mat = matlab.imread(hrfile)
34 | hr_mat = hr_mat[0][:, :, :]
35 | diff = hr_py - hr_mat # no difference
36 | hr_mat = matlab.rgb2ycbcr(hr_py)
37 | hr_mat_y = hr_mat[:, :, 0]
38 | hr_py_ycbcr = utils.rgb2ycbcr(hr_py)
39 | hr_py_ycbcr = hr_py_ycbcr.astype(np.uint8)
40 | hr_py_y = hr_py_ycbcr[:, :, 0]
41 | diff = hr_mat - hr_py_ycbcr # large difference, so Python's rgb2ycbcr cannot be used here
42 |
43 | hr_py = Image.open(hrfile).convert('RGB')
44 | hr_mat = matlab.imread(hrfile)
45 | hr_mat = hr_mat[0][:, :, :]
46 | scale = 3
47 | hr_py = np.array(hr_py).astype(np.uint8)
48 | lr_mat = matlab.imresize(hr_py, 1 / scale, 'bicubic')[0]
49 | lr_mat2 = imresize.imresize(hr_py, 1 / scale)
50 | diff = lr_mat2 - lr_mat
51 |
--------------------------------------------------------------------------------
/test_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from tqdm import tqdm
3 | import utils
4 | import metrics
5 | import os
6 |
7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
8 | if __name__ == '__main__':
9 | model_name = 'RealESRGAN'
10 | scale = 4
11 | itplt = 'bilinear'
12 | # itplt = 'bicubic'
13 | dataset = f'degraded4_{itplt}_heavy'
14 | # dataset = f'degraded5_{itplt}_medium'
15 | # dataset = f'degraded6_{itplt}_slight'
16 |
17 | sr_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/'
18 | gt_dir = './datasets/PIPAL/'
19 | deg_types = os.listdir(sr_root_dir_degradation)
20 |
21 | csv_path = f'./test_res/{model_name}_{dataset}.csv'
22 | csv_file = open(csv_path, 'w', newline='')
23 | writer = csv.writer(csv_file)
24 | writer.writerow(('name', 'psnr', 'niqe', 'ssim', 'lpips'))
25 | for deg_type in deg_types:
26 | sr_dir = sr_root_dir_degradation + deg_type + '/'
27 | sr_lists = os.listdir(sr_dir)
28 | Avg_psnr = utils.AverageMeter()
29 | Avg_niqe = utils.AverageMeter()
30 | Avg_ssim = utils.AverageMeter()
31 | Avg_lpips = utils.AverageMeter()
32 | with tqdm(total=len(sr_lists)) as t:
33 | t.set_description(f"Processing: {deg_type}")
34 | for imgName in sr_lists:
35 | SR = utils.loadIMG_crop(sr_dir + imgName, scale)
36 | GT = utils.loadIMG_crop(gt_dir + imgName, scale)
37 | my_metrics = metrics.calc_metric(SR, GT, ['lpips'])
38 | Avg_lpips.update(my_metrics['lpips'], 1)
39 | # Avg_psnr.update(my_metrics['psnr'], 1)
40 | # Avg_niqe.update(my_metrics['niqe'], 1)
41 | # Avg_ssim.update(my_metrics['ssim'], 1)
42 | t.update(1)
43 | # writer.writerow((deg_type, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg, Avg_lpips.avg))
44 | csv_file.close()
--------------------------------------------------------------------------------
/SRCNN/SRCNNdatasets.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 |
5 |
6 | class TrainDataset(Dataset):
7 | def __init__(self, h5_file):
8 | super(TrainDataset, self).__init__()
9 | self.h5_file = h5_file
10 |
11 | def __getitem__(self, idx):
12 | with h5py.File(self.h5_file, 'r') as f:
13 | return np.expand_dims(f['data'][idx] / 255., 0), np.expand_dims(f['label'][idx] / 255., 0)
14 |
15 | def __len__(self):
16 | with h5py.File(self.h5_file, 'r') as f:
17 | return len(f['data'])
18 |
19 |
20 | class ValDataset(Dataset):
21 | def __init__(self, h5_file):
22 | super(ValDataset, self).__init__()
23 | self.h5_file = h5_file
24 |
25 | def __getitem__(self, idx):
26 | with h5py.File(self.h5_file, 'r') as f:
27 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), np.expand_dims(f['label'][str(idx)][:, :] / 255., 0)
28 |
29 | def __len__(self):
30 | with h5py.File(self.h5_file, 'r') as f:
31 | return len(f['data'])
32 |
33 |
34 | class MatlabTrainDataset(Dataset):
35 | def __init__(self, h5_file):
36 | super(MatlabTrainDataset, self).__init__()
37 | self.h5_file = h5_file
38 |
39 | def __getitem__(self, idx):
40 | with h5py.File(self.h5_file, 'r') as f:
41 | return f['data'][idx], f['label'][idx]
42 |
43 | def __len__(self):
44 | with h5py.File(self.h5_file, 'r') as f:
45 | return len(f['data'])
46 |
47 |
48 | class MatlabValidDataset(Dataset):
49 | def __init__(self, h5_file):
50 | super(MatlabValidDataset, self).__init__()
51 | self.h5_file = h5_file
52 |
53 | def __getitem__(self, idx):
54 | with h5py.File(self.h5_file, 'r') as f:
55 | return f['data'][idx], f['label'][idx]
56 |
57 | def __len__(self):
58 | with h5py.File(self.h5_file, 'r') as f:
59 | return len(f['data'])
--------------------------------------------------------------------------------
/ESRGAN/gen_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 | import utils
5 | from imresize import imresize
6 | from PIL import Image
7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
8 |
9 |
10 | def gen_valdata(config):
11 | scale = config["scale"]
12 | size_output = config['size_output']
13 | method = config['method']
14 | h5savepath = config["hrDir"] + f'_label={size_output}_val_ESRGANx{scale}.h5'
15 | hrDir = config["hrDir"] + '/'
16 | h5_file = h5py.File(h5savepath, 'w')
17 | lr_group = h5_file.create_group('data')
18 | hr_group = h5_file.create_group('label')
19 | imgList = os.listdir(hrDir)
20 | for i, imgName in enumerate(imgList):
21 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale).convert('RGB')
22 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True).astype(np.float32)
23 | lr = imresize(np.array(hrIMG).astype(np.float32), 1 / scale, method)
24 |
25 | data = lr.astype(np.float32).transpose([2, 0, 1])
26 | label = hr.transpose([2, 0, 1])
27 |
28 | lr_group.create_dataset(str(i), data=data)
29 | hr_group.create_dataset(str(i), data=label)
30 | h5_file.close()
31 |
32 |
33 | def scale_OST():
34 | # Some OST images are smaller than 128x128; upscale them by a factor of 2
35 | root_dir = '../datasets/OST/'
36 |
37 | img_names = os.listdir(root_dir)
38 | for name in img_names:
39 | img_path = root_dir + name
40 | image = Image.open(img_path)
41 |
42 | if image.width < 128 or image.height < 128:
43 | print(img_path)
44 | image.save('../datasets/OST_backup/' + name)
45 | image = np.array(image)
46 | image = imresize(image, 2, 'bicubic')
47 | image = Image.fromarray(image.astype(np.uint8))
48 | image.save(img_path)
49 |
50 |
51 | if __name__ == '__main__':
52 | config = {'hrDir': '../datasets/Set14', 'scale': 4, 'size_output': 128, 'method': 'bicubic'}
53 | gen_valdata(config)
54 |
--------------------------------------------------------------------------------
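A sketch of consuming the file written by gen_valdata (hypothetical; the filename follows the config in __main__): each key holds one full-image pair, the LR input as float32 RGB CHW and the label as float32 ycbcr CHW:

    import h5py

    with h5py.File('../datasets/Set14_label=128_val_ESRGANx4.h5', 'r') as f:
        for key in f['data']:
            lr = f['data'][key][:]    # CHW float32 RGB
            hr = f['label'][key][:]   # CHW float32 ycbcr
            print(key, lr.shape, hr.shape)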
/SRResNet/eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.backends import cudnn
4 | import utils
5 | from PIL import Image
6 | from imresize import imresize
7 | import os
8 | import niqe
9 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
10 | if __name__ == '__main__':
11 | model_name = 'ESRGAN'
12 | rootdir = '../datasets/'
13 | gnd_data = 'BSDS100/'
14 | test_data = f'SRGAN_official/{gnd_data}'
15 | gnd_dir = rootdir + gnd_data
16 | test_dir = rootdir + test_data
17 | scale = 4
18 | padding = scale
19 |
20 |
21 | cudnn.benchmark = True
22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
23 |
24 |
25 | imglist = os.listdir(gnd_dir)
26 | testlist = os.listdir(test_dir)
27 | Avg_psnr = utils.AverageMeter()
28 | Avg_niqe = utils.AverageMeter()
29 | for idx, imgName in enumerate(imglist):
30 | image = utils.loadIMG_crop(gnd_dir + imgName, scale)
31 | SR = utils.loadIMG_crop(test_dir + testlist[idx], scale)
32 | img_mode = image.mode
33 | if img_mode == 'L':
34 | gray_img = np.array(image)
35 | hr_image = np.array(image.convert('RGB')).astype(np.float32)
36 | hr_image = hr_image[padding: -padding, padding: -padding, ...]
37 |
38 | SR = np.array(SR).astype(np.float32)
39 | SR = SR[padding: -padding, padding: -padding, ...]
40 | if img_mode != 'L':
41 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.
42 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0]/255.
43 | else:
44 | gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...]
45 | hr_y = gray_img / 255.
46 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L')
47 | SR_y = np.array(SR).astype(np.float32) / 255.
48 | psnr = utils.calc_psnr(hr_y, SR_y)
49 | NIQE = niqe.calculate_niqe(SR_y)
50 | Avg_psnr.update(psnr, 1)
51 | Avg_niqe.update(NIQE, 1)
52 | print(f'{imgName}, ' + 'PSNR: {:.2f}, NIQE: {:.4f}'.format(psnr.item(), NIQE))
53 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg))
--------------------------------------------------------------------------------
/ESRGAN/demo_realesrgan.py:
--------------------------------------------------------------------------------
3 | import numpy as np
4 | import torch
5 | from torch.backends import cudnn
7 | import utils
8 | from model import G, G2
9 | from PIL import Image
10 | import os
11 |
12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
13 | if __name__ == '__main__':
14 | model_name = 'RealESRGAN'
15 | weight_file = '../weight_file/RealESRGAN_x4plus.pth'
16 | img_dir = '../datasets/Real-ESRGAN_input/'
17 | outputs_dir = './test_res/realESRGAN_x4/'
18 | utils.mkdirs(outputs_dir)
19 | scale = 4
20 | padding = scale
21 |
22 | if not os.path.exists(weight_file):
23 | print(f'Weight file does not exist!\n{weight_file}\n')
24 | raise FileNotFoundError(weight_file)
25 |
26 | cudnn.benchmark = True
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 |
29 | model = G2().to(device)
30 | checkpoint = torch.load(weight_file)
31 | if model_name == 'ESRGAN':
32 | model.load_state_dict(checkpoint['gen'])
33 | else:
34 | model.load_state_dict(checkpoint['params_ema'])
35 | model.eval()
36 | imglist = os.listdir(img_dir)
37 |
38 | for imgName in imglist:
39 | image = utils.loadIMG_crop(img_dir + imgName, scale)
40 | img_mode = image.mode
41 | if img_mode == 'L':
42 | gray_img = np.array(image)
43 | image = image.convert('RGB')
44 | lr_image = np.array(image)
45 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
46 | lr /= 255.
47 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
48 |
49 | with torch.no_grad():
50 | SR = model(lr)
51 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
52 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
53 | if img_mode != 'L':
54 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.
55 | else:
56 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L')
57 | SR_y = np.array(SR).astype(np.float32) / 255.
58 | # GPU tensor -> CPU tensor -> numpy
59 | output = np.array(SR).astype(np.uint8)
60 | output = Image.fromarray(output)  # uint8 HWC array -> PIL image
61 | output.save(outputs_dir + imgName)
62 |
--------------------------------------------------------------------------------
/HCFlow/test_SR_DF2K_4X_HCFlow.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: 001_HCFlow_DF2K_x4_bicSR_test
3 | suffix: ~
4 | use_tb_logger: false
5 | model: HCFlow_SR
6 | distortion: sr
7 | scale: 4
8 | quant: 64
9 | gpu_ids: [0]
10 |
11 |
12 |
13 | datasets:
14 | test0:
15 | name: example
16 | mode: GTLQ
17 | dataroot_GT: ../datasets/example_general_4X/HR
18 | dataroot_LQ: ../datasets/example_general_4X/LR
19 |
20 | # test_1:
21 | # name: Set5
22 | # mode: GTLQx
23 | # dataroot_GT: ../datasets/Set5/HR
24 | # dataroot_LQ: ../datasets/Set5/LR_bicubic/X4
25 |
26 | # test_2:
27 | # name: Set14
28 | # mode: GTLQx
29 | # dataroot_GT: ../datasets/Set14/HR
30 | # dataroot_LQ: ../datasets/Set14/LR_bicubic/X4
31 | #
32 | # test_3:
33 | # name: BSD100
34 | # mode: GTLQx
35 | # dataroot_GT: ../datasets/BSD100/HR
36 | # dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4
37 | #
38 | # test_4:
39 | # name: Urban100
40 | # mode: GTLQx
41 | # dataroot_GT: ../datasets/Urban100/HR
42 | # dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4
43 | #
44 | # test_5:
45 | # name: DIV2K-va-4X
46 | # mode: GTLQ
47 | # dataroot_GT: ../datasets/srflow_datasets/div2k-validation-modcrop8-gt
48 | # dataroot_LQ: ../datasets/srflow_datasets/div2k-validation-modcrop8-x4
49 |
50 |
51 | #### network structures
52 | network_G:
53 | which_model_G: HCFlowNet_SR
54 | in_nc: 3
55 | out_nc: 3
56 | act_norm_start_step: 100
57 |
58 | flowDownsampler:
59 | K: 26
60 | L: 2
61 | flow_permutation: invconv
62 | flow_coupling: Affine
63 | nn_module: FCN
64 | hidden_channels: 64
65 | cond_channels: ~
66 | splitOff:
67 | enable: true
68 | after_flowstep: [13, 13]
69 | flow_permutation: invconv
70 | flow_coupling: Affine
71 | nn_module: FCN
72 | nn_module_last: Conv2dZeros
73 | hidden_channels: 64
74 | RRDB_nb: [7, 7]
75 | RRDB_nf: 64
76 | RRDB_gc: 32
77 |
78 |
79 | #### validation settings
80 | val:
81 | heats: [0.0]
82 | n_sample: 3
83 |
84 |
85 | path:
86 | strict_load: true
87 | load_submodule: ~
88 | # pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth
89 | # pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow+.pth
90 | pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow++.pth
91 |
--------------------------------------------------------------------------------
/SRFlow/thops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 |
19 |
20 | def sum(tensor, dim=None, keepdim=False):
21 | if dim is None:
22 | # sum up all dim
23 | return torch.sum(tensor)
24 | else:
25 | if isinstance(dim, int):
26 | dim = [dim]
27 | dim = sorted(dim)
28 | for d in dim:
29 | tensor = tensor.sum(dim=d, keepdim=True)
30 | if not keepdim:
31 | for i, d in enumerate(dim):
32 | tensor.squeeze_(d-i)
33 | return tensor
34 |
35 |
36 | def mean(tensor, dim=None, keepdim=False):
37 | if dim is None:
38 | # mean all dim
39 | return torch.mean(tensor)
40 | else:
41 | if isinstance(dim, int):
42 | dim = [dim]
43 | dim = sorted(dim)
44 | for d in dim:
45 | tensor = tensor.mean(dim=d, keepdim=True)
46 | if not keepdim:
47 | for i, d in enumerate(dim):
48 | tensor.squeeze_(d-i)
49 | return tensor
50 |
51 |
52 | def split_feature(tensor, type="split"):
53 | """
54 | type = ["split", "cross"]
55 | """
56 | C = tensor.size(1)
57 | if type == "split":
58 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
59 | elif type == "cross":
60 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
61 |
62 |
63 | def cat_feature(tensor_a, tensor_b):
64 | return torch.cat((tensor_a, tensor_b), dim=1)
65 |
66 |
67 | def pixels(tensor):
68 | return int(tensor.size(2) * tensor.size(3))
--------------------------------------------------------------------------------
/SRFlow/Permutations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import numpy as np
18 | import torch
19 | from torch import nn as nn
20 | from torch.nn import functional as F
21 | import thops
22 |
23 |
24 | class InvertibleConv1x1(nn.Module):
25 | def __init__(self, num_channels, LU_decomposed=False):
26 | super().__init__()
27 | w_shape = [num_channels, num_channels]
28 | w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
29 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
30 | self.w_shape = w_shape
31 | self.LU = LU_decomposed
32 |
33 | def get_weight(self, input, reverse):
34 | w_shape = self.w_shape
35 | pixels = thops.pixels(input)
36 | dlogdet = torch.slogdet(self.weight)[1] * pixels
37 | if not reverse:
38 | weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
39 | else:
40 | weight = torch.inverse(self.weight.double()).float() \
41 | .view(w_shape[0], w_shape[1], 1, 1)
42 | return weight, dlogdet
43 | def forward(self, input, logdet=None, reverse=False):
44 | """
45 | log-det = log|abs(|W|)| * pixels
46 | """
47 | weight, dlogdet = self.get_weight(input, reverse)
48 | if not reverse:
49 | z = F.conv2d(input, weight)
50 | if logdet is not None:
51 | logdet = logdet + dlogdet
52 | return z, logdet
53 | else:
54 | z = F.conv2d(input, weight)
55 | if logdet is not None:
56 | logdet = logdet - dlogdet
57 | return z, logdet
58 |
--------------------------------------------------------------------------------
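A round-trip sketch (hypothetical): applying the 1x1 convolution forward and then in reverse should recover the input, and the added and subtracted log-determinants cancel exactly:

    import torch
    from Permutations import InvertibleConv1x1

    conv = InvertibleConv1x1(num_channels=4)
    x = torch.randn(2, 4, 8, 8)
    z, logdet = conv(x, logdet=torch.zeros(2), reverse=False)
    x_rec, logdet = conv(z, logdet=logdet, reverse=True)
    print(torch.allclose(x, x_rec, atol=1e-4), logdet.abs().max().item())  # True, 0.0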
/HCFlow/HCFlow_SR_model.py:
--------------------------------------------------------------------------------
1 | # base model for HCFlow
2 | import logging
3 | from collections import OrderedDict
4 | import torch
5 | from HCFlowNet_SR_arch import HCFlowNet_SR
6 | from base_model import BaseModel
7 | logger = logging.getLogger('base')
8 |
9 |
10 | class HCFlowSRModel(BaseModel):
11 | def __init__(self, opt, heats=[0.0], step=0):
12 | super(HCFlowSRModel, self).__init__(opt)
13 | self.opt = opt
14 |
15 | self.rank = -1 # non dist training
16 |
17 | # define network and load pretrained models
18 | self.netG = HCFlowNet_SR(opt, step).to(self.device)
19 | self.heats = heats
20 | # val
21 | if 'val' in opt:
22 | # self.heats = opt['val']['heats']
23 | self.n_sample = opt['val']['n_sample']
24 | self.sr_mode = opt['val']['sr_mode']
25 |
26 | def feed_data(self, data, need_GT=True):
27 | self.var_L = data['LQ'].to(self.device) # LQ
28 | if need_GT:
29 | self.real_H = data['GT'].to(self.device) # GT
30 | else:
31 | self.real_H = None
32 |
33 | def test(self):
34 | self.netG.eval()
35 | self.fake_H = {}
36 |
37 | with torch.no_grad():
38 | if self.real_H is None:
39 | nll = torch.zeros(1)
40 | else:
41 | # hr->lr+z, calculate nll
42 | self.fake_L_from_H, nll = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False, training=False)
43 |
44 | # lr+z->hr
45 | for heat in self.heats:
46 | for sample in range(self.n_sample):
47 | # z = self.get_z(heat, seed=1, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
48 | self.fake_H[(heat, sample)] = self.netG(lr=self.var_L,
49 | z=None, u=None, eps_std=heat, reverse=True, training=False)
50 |
51 | self.netG.train()
52 |
53 | return nll.mean().item()
54 |
55 | def get_current_visuals(self, need_GT=True):
56 | out_dict = OrderedDict()
57 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
58 | for heat in self.heats:
59 | for i in range(self.n_sample):
60 | out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
61 |
62 | if need_GT:
63 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
64 | out_dict['LQ_fromH'] = self.fake_L_from_H.detach()[0].float().cpu()
65 |
66 | return out_dict
67 |
68 |
--------------------------------------------------------------------------------
/HCFlow/FlowStep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | import ActNorms, Permutations, AffineCouplings
5 |
6 |
7 | class FlowStep(nn.Module):
8 | def __init__(self, in_channels, cond_channels=None, flow_permutation='invconv', flow_coupling='Affine', LRvsothers=True,
9 | actnorm_scale=1.0, LU_decomposed=False, opt=None):
10 | super().__init__()
11 | self.flow_permutation = flow_permutation
12 | self.flow_coupling = flow_coupling
13 |
14 | # 1. actnorm
15 | self.actnorm = ActNorms.ActNorm2d(in_channels, actnorm_scale)
16 |
17 | # 2. permute # todo: may be harmful for downsampling; preserve the structure of downsampling
18 | if self.flow_permutation == "invconv":
19 | self.permute = Permutations.InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed)
20 | elif self.flow_permutation == "none":
21 | self.permute = None
22 |
23 | # 3. coupling
24 | if self.flow_coupling == "AffineInjector":
25 | self.affine = AffineCouplings.AffineCouplingInjector(in_channels=in_channels, cond_channels=cond_channels, opt=opt)
26 | elif self.flow_coupling == "noCoupling":
27 | pass
28 | elif self.flow_coupling == "Affine":
29 | self.affine = AffineCouplings.AffineCoupling(in_channels=in_channels, cond_channels=cond_channels, opt=opt)
30 | elif self.flow_coupling == "Affine3shift":
31 | self.affine = AffineCouplings.AffineCoupling3shift(in_channels=in_channels, cond_channels=cond_channels, LRvsothers=LRvsothers, opt=opt)
32 |
33 | def forward(self, z, u=None, logdet=None, reverse=False):
34 | if not reverse:
35 | return self.normal_flow(z, u, logdet)
36 | else:
37 | return self.reverse_flow(z, u)
38 |
39 | def normal_flow(self, z, u=None, logdet=None):
40 | # 1. actnorm
41 | z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
42 |
43 | # 2. permute
44 | if self.permute is not None:
45 | z, logdet = self.permute( z, logdet=logdet, reverse=False)
46 |
47 | # 3. coupling
48 | z, logdet = self.affine(z, u=u, logdet=logdet, reverse=False)
49 |
50 | return z, logdet
51 |
52 | def reverse_flow(self, z, u=None, logdet=None):
53 | # 1.coupling
54 | z, _ = self.affine(z, u=u, reverse=True)
55 |
56 | # 2. permute
57 | if self.permute is not None:
58 | z, _ = self.permute(z, reverse=True)
59 |
60 | # 3. actnorm
61 | z, _ = self.actnorm(z, reverse=True)
62 |
63 | return z, logdet
64 |
65 |
--------------------------------------------------------------------------------
/SRFlow/SRFlow_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import torch
18 | from SRFlowNet_arch import SRFlowNet
19 | from base_model import BaseModel
20 |
21 | class SRFlowModel(BaseModel):
22 | def __init__(self, opt, step=0):
23 | super(SRFlowModel, self).__init__(opt)
24 | self.opt = opt
25 |
26 | self.heats = opt['val']['heats']
27 | self.n_sample = opt['val']['n_sample']
28 | # define network and load pretrained models
29 | opt_net = opt['network_G']
30 | self.netG = SRFlowNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
31 | scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step).to(self.device)
32 |
33 | def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
34 | return self.get_sr_with_z(lq, heat, seed, z, epses)[0]
35 |
36 | def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
37 | self.netG.eval()
38 |
39 | z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z
40 |
41 | with torch.no_grad():
42 | sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
43 | self.netG.train()
44 | return sr, z
45 |
46 | def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
47 | if seed: torch.manual_seed(seed)
48 | C = self.netG.flowUpsamplerNet.C
49 | H = int(self.opt['scale'] * lr_shape[2] // self.netG.flowUpsamplerNet.scaleH)
50 | W = int(self.opt['scale'] * lr_shape[3] // self.netG.flowUpsamplerNet.scaleW)
51 | z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros(
52 | (batch_size, C, H, W))
53 | return z
54 |
--------------------------------------------------------------------------------
/HCFlow/networks.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import torch
4 | import discriminator_vgg_arch as SRGAN_arch
5 |
6 | logger = logging.getLogger('base')
7 |
8 |
9 | def find_model_using_name(model_name):
10 | model_filename = "models.modules." + model_name + "_arch"
11 | modellib = importlib.import_module(model_filename)
12 |
13 | model = None
14 | target_model_name = model_name.replace('_Net', '')
15 | for name, cls in modellib.__dict__.items():
16 | if name.lower() == target_model_name.lower():
17 | model = cls
18 |
19 | if model is None:
20 | print(
21 | "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
22 | model_filename, target_model_name))
23 | exit(1)
24 |
25 | return model
26 |
27 | def define_Flow(opt, step):
28 | opt_net = opt['network_G']
29 | which_model = opt_net['which_model_G']
30 |
31 | Arch = find_model_using_name(which_model)
32 | netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
33 | nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step)
34 | return netG
35 |
36 | def define_G(opt, step):
37 | which_model = opt['network_G']['which_model_G']
38 |
39 | Arch = find_model_using_name(which_model)
40 | netG = Arch(opt=opt, step=step)
41 | return netG
42 |
43 | #### Discriminator
44 | def define_D(opt):
45 | opt_net = opt['network_D']
46 | which_model = opt_net['which_model_D']
47 |
48 | if which_model == 'discriminator_vgg_128':
49 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
50 | elif which_model == 'discriminator_vgg_160':
51 | netD = SRGAN_arch.Discriminator_VGG_160(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
52 | elif which_model == 'PatchGANDiscriminator':
53 | netD = SRGAN_arch.PatchGANDiscriminator(in_nc=opt_net['in_nc'], ndf=opt_net['ndf'], n_layers=opt_net['n_layers'],)
54 | else:
55 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
56 | return netD
57 |
58 |
59 | #### Define Network used for Perceptual Loss
60 | def define_F(opt, use_bn=False):
61 | gpu_ids = opt['gpu_ids']
62 | device = torch.device('cuda' if gpu_ids else 'cpu')
63 | # PyTorch pretrained VGG19-54, before ReLU.
64 | if use_bn:
65 | feature_layer = 49
66 | else:
67 | feature_layer = 34
68 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
69 | use_input_norm=True, device=device)
70 | netF.eval() # No need to train
71 | return netF
72 |
--------------------------------------------------------------------------------
/SRFlow/timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import time
18 |
19 |
20 | class ScopeTimer:
21 | def __init__(self, name):
22 | self.name = name
23 |
24 | def __enter__(self):
25 | self.start = time.time()
26 | return self
27 |
28 | def __exit__(self, *args):
29 | self.end = time.time()
30 | self.interval = self.end - self.start
31 | print("{} {:.3E}".format(self.name, self.interval))
32 |
33 |
34 | class Timer:
35 | def __init__(self):
36 | self.times = []
37 |
38 | def tick(self):
39 | self.times.append(time.time())
40 |
41 | def get_average_and_reset(self):
42 | if len(self.times) < 2:
43 | return -1
44 | avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1)
45 | self.times = [self.times[-1]]
46 | return avg
47 |
48 | def get_last_iteration(self):
49 | if len(self.times) < 2:
50 | return 0
51 | return self.times[-1] - self.times[-2]
52 |
53 |
54 | class TickTock:
55 | def __init__(self):
56 | self.time_pairs = []
57 | self.current_time = None
58 |
59 | def tick(self):
60 | self.current_time = time.time()
61 |
62 | def tock(self):
63 | assert self.current_time is not None, self.current_time
64 | self.time_pairs.append([self.current_time, time.time()])
65 | self.current_time = None
66 |
67 | def get_average_and_reset(self):
68 | if len(self.time_pairs) == 0:
69 | return -1
70 | deltas = [t2 - t1 for t1, t2 in self.time_pairs]
71 | avg = sum(deltas) / len(deltas)
72 | self.time_pairs = []
73 | return avg
74 |
75 | def get_last_iteration(self):
76 | if len(self.time_pairs) == 0:
77 | return -1
78 | return self.time_pairs[-1][1] - self.time_pairs[-1][0]
79 |
--------------------------------------------------------------------------------
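A usage sketch (hypothetical): TickTock times explicit tick/tock pairs, so the sleep below stands in for the work being measured:

    import time
    from timer import TickTock

    tt = TickTock()
    for _ in range(3):
        tt.tick()
        time.sleep(0.01)  # stand-in for a timed forward pass
        tt.tock()
    print('avg: {:.3f}s'.format(tt.get_average_and_reset()))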
/HCFlow/HCFlowNet_SR_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from util import opt_get
7 | import Basic, thops
8 |
9 |
10 |
11 | class HCFlowNet_SR(nn.Module):
12 | def __init__(self, opt, step=None):
13 | super(HCFlowNet_SR, self).__init__()
14 | self.opt = opt
15 | self.quant = opt_get(opt, ['quant'], 256)
16 |
17 | hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160)
18 | hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3)
19 | scale = opt_get(opt, ['scale'])
20 |
21 | if scale == 4:
22 | from FlowNet_SR_x4 import FlowNet
23 | else:
24 | raise NotImplementedError('Scale {} is not implemented'.format(scale))
25 |
26 | # hr->lr+z
27 | self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt)
28 |
29 | self.quantization = Basic.Quantization()
30 |
31 | # hr: HR image, lr: LR image, z: latent variable, u: conditional variable
32 | def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None,
33 | add_gt_noise=False, step=None, reverse=False, training=True):
34 |
35 | # hr->z
36 | if not reverse:
37 | return self.normal_flow_diracLR(hr, lr, u, step=step, training=training)
38 | # z->hr
39 | else:
40 | return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training)
41 |
42 |
43 | #########################################diracLR
44 | # hr->lr+z, diracLR
45 | def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True):
46 | # 1. quantize HR
47 | pixels = thops.pixels(hr)
48 |
49 | # according to Glow and ours, it should be u~U(0,a) (0.06 better in practice), not u~U(-0.5,0.5) (though better in theory)
50 | hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant
51 | logdet = torch.zeros_like(hr[:, 0, 0, 0]) + float(-np.log(self.quant) * pixels)
52 |
53 | # 2. hr->lr+z
54 | fake_lr_from_hr, logdet = self.flow(hr=hr, u=u, logdet=logdet, reverse=False, training=training)
55 |
56 | # note in rescaling, we use LR for LR loss before quantization
57 | fake_lr_from_hr = self.quantization(fake_lr_from_hr)
58 |
59 | # 3. loss, Gaussian with small variance to approximate Dirac delta function of LR.
60 | # for the second term, using small log-variance may lead to svd problem, for both exp and tanh version
61 | objective = logdet + Basic.GaussianDiag.logp(lr, -torch.ones_like(lr)*6, fake_lr_from_hr)
62 |
63 | nll = ((-objective) / float(np.log(2.) * pixels)).mean()
64 |
65 | return torch.clamp(fake_lr_from_hr, 0, 1), nll
66 |
67 | # lr+z->hr
68 | def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True):
69 |
70 | # lr+z->hr
71 | fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training)
72 |
73 | return torch.clamp(fake_hr, 0, 1)
74 |
--------------------------------------------------------------------------------
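Note on normal_flow_diracLR above: adding uniform noise in [0, 1/quant) to hr and initializing logdet at -log(quant) * pixels is the standard dequantization correction, and dividing the negated objective by log(2) * pixels reports the NLL in bits per pixel.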
/SwinIR/test_classical_swinir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.backends import cudnn
4 | from tqdm import tqdm
5 |
6 | import utils
7 | from swinir_model import SwinIR
8 | from PIL import Image
9 | from imresize import imresize
10 | import os
11 |
12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
13 | cudnn.benchmark = True
14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15 |
16 | if __name__ == '__main__':
17 | scale = 4
18 | model_name = 'SwinIR-Classical'
19 | weight_file = '../weight_file/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth'
20 | dataset = 'degraded_srflow'
21 | root_dir = f'../datasets/{dataset}/'
22 | out_root_dir = f'./test_res/{model_name}_{dataset}/'
23 | hr_dir = '/data0/jli/datasets/PIPAL/'
24 |
25 | lr_dirs = os.listdir(root_dir)
26 | for dir in lr_dirs:
27 | utils.mkdirs(out_root_dir + dir)
28 |
29 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
30 | img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
31 | mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv').to(device)
32 | checkpoint = torch.load(weight_file)
33 | model.load_state_dict(checkpoint['params'])
34 | model.eval()
35 |
36 |
37 | for dir in lr_dirs:
38 | outputs_dir = out_root_dir + dir + '/'
39 | lr_dir = root_dir + dir + '/'
40 | lr_lists = os.listdir(lr_dir)
41 | with tqdm(total=len(lr_lists)) as t:
42 | t.set_description(f"Processing: {dir}")
43 | for imgName in lr_lists:
44 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
45 | if image.mode != 'L':
46 | image = image.convert('RGB')
47 | lr_image = np.array(image)
48 |
49 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
50 | lr /= 255.
51 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
52 |
53 | with torch.no_grad():
54 | SR = model(lr)
55 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
56 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
57 | # GPU tensor -> CPU tensor -> numpy
58 | SR = np.array(SR).astype(np.uint8)
59 | SR = Image.fromarray(SR)  # uint8 HWC array -> PIL image
60 | SR.save(outputs_dir + imgName)
61 | t.update(1)
62 |
63 | sr_dirs = os.listdir(out_root_dir)
64 | for dir in sr_dirs:
65 | sr_dir = out_root_dir + dir + '/'
66 | sr_lists = os.listdir(sr_dir)
67 | with tqdm(total=len(sr_lists)) as t:
68 | t.set_description(f"Processing: {dir}")
69 | for imgName in sr_lists:
70 | image = utils.loadIMG_crop(sr_dir + imgName, scale)
71 | if image.mode != 'L':
72 | image = image.convert('RGB')
73 | lr_image = np.array(image)
--------------------------------------------------------------------------------
/SRCNN/gen_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 | from tqdm import tqdm
5 | import utils
6 | from imresize import imresize
7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
8 |
9 |
10 | def gen_traindata(config):
11 | scale = config["scale"]
12 | stride = config["stride"]
13 | h5savepath = config["hrDir"] + f'_train_SRCNNx{scale}.h5'
14 | hrDir = config["hrDir"] + '/'
15 | size_input = config["size_input"]
16 | size_label = config["size_label"]
17 | padding = int(abs(size_input - size_label) / 2)
18 |
19 | h5_file = h5py.File(h5savepath, 'w')
20 | imgList = os.listdir(hrDir)
21 | lr_subimgs = []
22 | hr_subimgs = []
23 | with tqdm(total=len(imgList)) as t:
24 | for imgName in imgList:
25 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale)
26 | hr_y = utils.img2ycbcr(hrIMG)[..., 0]
27 | lr = imresize(hr_y, 1 / scale, 'bicubic')
28 | lr = imresize(lr, scale, 'bicubic').astype(np.float32)
29 |
30 | for r in range(0, hr_y.shape[0] - size_input + 1, stride): # hr.height
31 | for c in range(0, hr_y.shape[1] - size_input + 1, stride): # hr.width
32 | lr_subimgs.append(lr[r: r + size_input, c: c + size_input])
33 | hr_subimgs.append(hr_y[r + padding: r + padding + size_label, c + padding: c + padding + size_label])
34 | t.update(1)
35 |
36 | lr_subimgs = np.array(lr_subimgs)
37 | hr_subimgs = np.array(hr_subimgs)
38 |
39 | h5_file.create_dataset('data', data=lr_subimgs)
40 | h5_file.create_dataset('label', data=hr_subimgs)
41 |
42 | h5_file.close()
43 |
44 |
45 | def gen_valdata(config):
46 | scale = config["scale"]
47 | h5savepath = config["hrDir"] + f'_val_SRCNNx{scale}.h5'
48 | hrDir = config["hrDir"] + '/'
49 | size_input = config["size_input"]
50 | size_label = config["size_label"]
51 | padding = int(abs(size_input - size_label) / 2)
52 |
53 | h5_file = h5py.File(h5savepath, 'w')
54 | lr_group = h5_file.create_group('data')
55 | hr_group = h5_file.create_group('label')
56 | imgList = os.listdir(hrDir)
57 | for i, imgName in enumerate(imgList):
58 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale)
59 | hr_y = utils.img2ycbcr(hrIMG)[..., 0]
60 |
61 | data = imresize(hr_y, 1 / scale, 'bicubic')
62 | data = imresize(data, scale, 'bicubic')
63 | data = data.astype(np.float32)
64 | label = hr_y[padding: -padding, padding: -padding]
65 | lr_group.create_dataset(str(i), data=data)
66 | hr_group.create_dataset(str(i), data=label)
67 |
68 | h5_file.close()
69 |
70 |
71 | if __name__ == '__main__':
72 | # config = {'hrDir': '../datasets/T91_aug', 'scale': 3, "stride": 14, "size_input": 33, "size_label": 21}
73 | config = {'hrDir': '../datasets/T91_aug', 'scale': 2, "stride": 14, "size_input": 22, "size_label": 10}
74 | gen_traindata(config)
75 | config['hrDir'] = '../datasets/Set5'
76 | gen_valdata(config)
77 |
--------------------------------------------------------------------------------
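Worked numbers for the configs above: with size_input=33 and size_label=21, padding = (33 - 21) / 2 = 6 per side, matching the 12 pixels (8 from the 9x9 conv, 4 from the 5x5) that a padding-free 9-1-5 SRCNN trims in total (33 - 12 = 21); the x2 config likewise gives (22 - 10) / 2 = 6.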
/BSRGAN/test_bsrgan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.backends import cudnn
4 | from tqdm import tqdm
5 | import utils
6 | from bsrgan_model import RRDBNet
7 | from PIL import Image
8 | import os
9 | import torch.nn.functional as F
10 |
11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
12 | cudnn.benchmark = True
13 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
14 | if __name__ == '__main__':
15 | model_name = 'BSRGAN'
16 | scale = 4
17 | itplt = 'bilinear'
18 | # itplt = 'bicubic'
19 | # dataset = f'degraded5_{itplt}_medium'
20 | # dataset = f'degraded6_{itplt}_slight'
21 | # dataset = f'degraded4_{itplt}_heavy'
22 | # dataset = f'degraded1_offset2'
23 | dataset = f'degraded1_bilinear_heavy'
24 |
25 | weight_file = '../weight_file/BSRGAN.pth'
26 | root_dir = f'../datasets/{dataset}/'
27 | out_root_dir_ImgName = f'./test_res/{model_name}_{dataset}/sort_with_ImgName/'
28 | out_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/'
29 | hr_dir = '../datasets/PIPAL/'
30 |
31 | lr_dirs = os.listdir(root_dir)
32 | # Sort with ImgName
33 | out1_dirs = os.listdir(hr_dir)
34 | for dir in out1_dirs:
35 | dir = dir.split('.')[0] + '/'
36 | utils.mkdirs(out_root_dir_ImgName + dir)
37 | # Sort with degradation
38 | for dir in lr_dirs:
39 | utils.mkdirs(out_root_dir_degradation + dir)
40 |
41 | model = RRDBNet().to(device)
42 | checkpoint = torch.load(weight_file)
43 | model.load_state_dict(checkpoint)
44 | model.eval()
45 |
46 | for dir in lr_dirs:
47 | out_by_deg = out_root_dir_degradation + dir + '/'
48 | lr_dir = root_dir + dir + '/'
49 | lr_lists = os.listdir(lr_dir)
50 | with tqdm(total=len(lr_lists)) as t:
51 | t.set_description(f"Processing: {dir}")
52 | for imgName in lr_lists:
53 | out_by_name = out_root_dir_ImgName + imgName.split('.')[0] + '/'
54 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
55 | # image = utils.ImgOffSet(image, offset, offset)
56 | if image.mode == 'L':
57 | image = image.convert('RGB')
58 | lr_image = np.array(image)
59 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
60 | lr /= 255.
61 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
62 | lr = F.pad(lr, pad=[1, 1, 1, 1], mode='constant')
63 | with torch.no_grad():
64 | SR = model(lr)
65 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
66 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
67 | # GPU tensor -> CPU tensor -> numpy
68 | SR = np.array(SR).astype(np.uint8)
69 | SR = Image.fromarray(SR)  # uint8 HWC array -> PIL image
70 | SR.save(out_by_name + dir + '.bmp')
71 | SR.save(out_by_deg + imgName)
72 | t.update(1)
73 |
--------------------------------------------------------------------------------
/SRResNet/train2.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from tqdm import tqdm
3 | import model_utils
4 | import utils
5 | import os
6 | import torch
7 | from torch.backends import cudnn
8 | from model import G
9 | from torch import nn, optim
10 | from SRResNetdatasets import SRResNetValDataset, DIV2KDataset, DIV2KSubDataset
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
14 |
15 |
16 | def train_model(config, from_pth=False):
17 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
18 |
19 | outputs_dir = config['outputs_dir']
20 | batch_size = config['batch_size']
21 | utils.mkdirs(outputs_dir)
22 | csv_file = outputs_dir + config['csv_name']
23 | logs_dir = config['logs_dir']
24 | cudnn.benchmark = True
25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | torch.manual_seed(config['seed'])
27 |
28 | # ---- section to modify per experiment ----
29 | print("===> Loading datasets")
30 | train_dataset = DIV2KSubDataset()
31 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
32 | batch_size=batch_size, shuffle=True, pin_memory=True)
33 | val_dataset = SRResNetValDataset(config['val_file'])
34 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
35 |
36 | print("===> Building model")
37 | model = G()
38 | if not from_pth:
39 | model.init_weight()
40 | criterion = nn.MSELoss().cuda()
41 | optimizer = optim.Adam(model.parameters(), lr=config['lr'])
42 | # ----END------
43 | start_step, best_step, best_psnr, writer, csv_file = \
44 | model_utils.load_checkpoint_iter(config['weight_file'], model, optimizer, csv_file,
45 | from_pth, auto_lr=config['auto_lr'])
46 |
47 | if torch.cuda.device_count() > 1:
48 | print("Using GPUs.\n")
49 | model = torch.nn.DataParallel(model)
50 | model = model.to(device)
51 |
52 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"),
53 | 'test': SummaryWriter(f"{logs_dir}/test")}
54 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
55 | num_steps = config['num_steps']
56 | iter_of_epoch = 1000
57 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 'milestone': [2e5],
58 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
59 | 'best_psnr': best_psnr, 'best_step': best_step, 'iter_of_epoch': iter_of_epoch}
60 |
61 | while global_info['step'] < config['num_steps']:
62 | with tqdm(total=len(train_dataset)) as t:
63 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
64 | global_info['t'] = t
65 | model_utils.train_iter(model, dataloaders, optimizer, criterion, global_info)
66 | print('best step: {}, psnr: {:.2f}'.format(global_info['best_step'], global_info['best_psnr']))
67 | csv_file.close()
68 |
--------------------------------------------------------------------------------
/SRFlow/SRFlow_DF2K_4X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SRFlow
21 | distortion: sr
22 | scale: 4
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/DF2K-tr.pklv4
31 | dataroot_LQ: ../datasets/DF2K-tr_X4.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 12
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 | val:
41 | name: CelebA_160_va
42 | mode: LRHR_PKL
43 | dataroot_GT: ../datasets/DIV2K-va.pklv4
44 | dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4
45 | quant: 32
46 | n_max: 20
47 |
48 | #### Test Settings
49 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
50 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x4
51 | model_path: ../pretrained_models/SRFlow_DF2K_4X.pth
52 | heat: 0.9 # This is the standard deviation of the latent vectors
53 |
54 | #### network structures
55 | network_G:
56 | which_model_G: SRFlowNet
57 | in_nc: 3
58 | out_nc: 3
59 | nf: 64
60 | nb: 23
61 | upscale: 4
62 | train_RRDB: false
63 | train_RRDB_delay: 0.5
64 |
65 | flow:
66 | K: 16
67 | L: 3
68 | noInitialInj: true
69 | coupling: CondAffineSeparatedAndCond
70 | additionalFlowNoAffine: 2
71 | split:
72 | enable: true
73 | fea_up0: true
74 | stackRRDB:
75 | blocks: [ 1, 8, 15, 22 ]
76 | concat: true
77 |
78 | #### path
79 | path:
80 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_4X.pth
81 | strict_load: true
82 | resume_state: auto
83 |
84 | #### training settings: learning rate scheme, loss
85 | train:
86 | manual_seed: 10
87 | lr_G: !!float 2.5e-4
88 | weight_decay_G: 0
89 | beta1: 0.9
90 | beta2: 0.99
91 | lr_scheme: MultiStepLR
92 | warmup_iter: -1 # no warm up
93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
94 | lr_gamma: 0.5
95 |
96 | niter: 200000
97 | val_freq: 40000
98 |
99 | #### validation settings
100 | val:
101 | heats: [ 1.0 ]
102 | n_sample: 1
103 |
104 | #### logger
105 | logger:
106 | print_freq: 100
107 | save_checkpoint_freq: !!float 1e3
108 |
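
Editor's note: the relative milestones in lr_steps_rel resolve against niter. A quick sketch of reading this file with plain PyYAML (illustrative only; the project presumably loads it through its own options parser):

    import yaml

    with open('SRFlow_DF2K_4X.yml') as f:
        opt = yaml.safe_load(f)

    niter = opt['train']['niter']                                  # 200000
    steps = [int(r * niter) for r in opt['train']['lr_steps_rel']]
    print(steps)   # [100000, 150000, 180000, 190000], each decayed by lr_gamma=0.5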
--------------------------------------------------------------------------------
/HCFlow/test_hcflow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 |
4 | import options
5 | import torch
6 | from torch.backends import cudnn
7 | import utils
8 | from HCFlow_SR_model import HCFlowSRModel
9 | from PIL import Image
10 | import os
11 |
12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
13 | if __name__ == '__main__':
14 | model_name = 'HCFlowSR'
15 | weight_file = '../weight_file/SR_DF2K_X4_HCFlow++.pth'
16 | root_dir = '../datasets/degraded_srflow/'
17 | out_root_dir = f'../test_res/{model_name}/'
18 | hr_dir = '../datasets/PIPAL/'
19 |
20 | lr_dirs = os.listdir(root_dir)
21 | for dir in lr_dirs:
22 | utils.mkdirs(out_root_dir + dir)
23 |
24 | scale = 4
25 | padding = scale
26 |
27 | if not os.path.exists(weight_file):
28 | print(f'Weight file does not exist!\n{weight_file}\n')
29 | raise FileNotFoundError(weight_file)  # raising a string is invalid in Python 3
30 |
31 | cudnn.benchmark = True
32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
33 | opt = options.parse('./test_SR_DF2K_4X_HCFlow.yml', is_train=False)
34 | opt['gpu_ids'] = '0'
35 | opt = options.dict_to_nonedict(opt)
36 | heat = 0.8
37 | model = HCFlowSRModel(opt, [heat])
38 | checkpoint = torch.load(weight_file)
39 | model.netG.load_state_dict(checkpoint)
40 | offset = 0
41 |
42 | for dir in lr_dirs:
43 | outputs_dir = out_root_dir + dir + '/'
44 | lr_dir = root_dir + dir + '/'
45 | lr_lists = os.listdir(lr_dir)
46 | with tqdm(total=len(lr_lists)) as t:
47 | t.set_description(f"Processing: {dir}")
48 | for imgName in lr_lists:
49 | # outputs_dir = out_root_dir + imgName.split('.')[0] + '/'
50 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
51 | image = utils.ImgOffSet(image, offset, offset)
52 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale)
53 | hr_image = utils.ImgOffSet(hr_image, offset*scale, offset*scale)
54 |
55 | img_mode = image.mode
56 | if img_mode == 'L':
57 | gray_img = np.array(image)  # kept for reference; not used below
58 | image = image.convert('RGB')
59 | lr_image = np.array(image)
60 | hr_image = np.array(hr_image)
61 |
62 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
63 | lr /= 255.
64 | lr = torch.from_numpy(lr).unsqueeze(0)
65 |
66 | data = {'LQ': lr}
67 | model.feed_data(data, need_GT=False)
68 | model.test()
69 | visuals = model.get_current_visuals(need_GT=False)
70 | for sample in range(len(visuals)-1):
71 | SR = visuals['SR', heat, sample]
72 | SR = SR.mul(255.0).cpu().numpy()
73 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
74 |
75 | # GPU tensor -> CPU tensor -> numpy
76 | output = np.array(SR).astype(np.uint8)
77 | output = Image.fromarray(output)  # HWC ndarray -> PIL image
78 | tmp = outputs_dir + f'sample={sample}/'
79 | utils.mkdirs(tmp)
80 | output.save(tmp + imgName)
81 | t.update(1)
82 |
83 |
84 |
85 |
86 |
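
Editor's note: the LR preprocessing above (HWC uint8 image to a normalized 1xCxHxW tensor) recurs across the test scripts in this repo. A self-contained helper capturing that convention might look like this (hypothetical, not part of the repo):

    import numpy as np
    import torch
    from PIL import Image

    def pil_to_tensor(img):
        # HWC uint8 in [0, 255] -> 1 x C x H x W float tensor in [0, 1]
        arr = np.array(img.convert('RGB')).astype(np.float32).transpose([2, 0, 1])
        return torch.from_numpy(arr / 255.).unsqueeze(0)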
--------------------------------------------------------------------------------
/SRCNN/train.py:
--------------------------------------------------------------------------------
1 | import csv
2 |
3 | from torch.utils.tensorboard import SummaryWriter
4 | from tqdm import tqdm
5 | import utils
6 | import os
7 | import torch
8 | from torch.backends import cudnn
9 | from models import SRCNN
10 | from torch import nn, optim
11 | from SRCNNdatasets import TrainDataset, ValDataset
12 | from torch.utils.data.dataloader import DataLoader
13 | import model_utils
14 |
15 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
16 |
17 |
18 | def train_model(config, from_pth=False):
19 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
20 |
21 | outputs_dir = config['outputs_dir']
22 | lr = config['lr']
23 | batch_size = config['batch_size']
24 | num_epochs = config['num_epochs']
25 | csv_file = outputs_dir + config['csv_name']
26 | logs_dir = config['logs_dir']
27 | utils.mkdirs(outputs_dir)
28 | utils.mkdirs(logs_dir)
29 |
30 | cudnn.benchmark = True
31 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
32 | torch.manual_seed(config['seed'])
33 |
34 | model = SRCNN()
35 | if not from_pth:
36 | model.init_weights()
37 | criterion = nn.MSELoss()
38 | optimizer = optim.SGD([
39 | {'params': model.conv1.parameters()},
40 | {'params': model.conv2.parameters()},
41 | {'params': model.conv3.parameters(), 'lr': lr * 0.1}
42 | ], lr=lr, momentum=0.9)  # first two layers use lr; the last layer uses lr * 0.1
43 |
44 | train_dataset = TrainDataset(config['train_file'])
45 | train_dataloader = DataLoader(dataset=train_dataset,
46 | batch_size=batch_size,
47 | shuffle=True,
48 | num_workers=config['num_workers'],
49 | pin_memory=True,
50 | drop_last=True)
51 | val_dataset = ValDataset(config['val_file'])
52 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
53 |
54 | start_epoch, best_epoch, best_psnr, writer, csv_file = \
55 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, from_pth)
56 |
57 | if torch.cuda.device_count() > 1:
58 | print("Using GPUs.\n")
59 | model = torch.nn.DataParallel(model)
60 | model = model.to(device)
61 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
62 |
63 | for epoch in range(start_epoch, num_epochs):
64 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
65 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
66 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)
67 |
68 | if isinstance(model, torch.nn.DataParallel):
69 | model = model.module
70 |
71 | epoch_psnr = model_utils.validate(model, val_dataloader, device)
72 |
73 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
74 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)
75 |
76 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses,
77 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
78 |
79 | csv_file.close()
80 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/ESRGAN/test_realesrgan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.backends import cudnn
4 | from tqdm import tqdm
5 | import utils
6 | from realesrgan_model import G
7 | from PIL import Image
8 | import os
9 | import torch.nn.functional as F
10 | from imresize import imresize
11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
12 | if __name__ == '__main__':
13 | model_name = 'RealESRGAN'
14 | scale = 2
15 | itplt = 'bilinear'
16 | # itplt = 'bicubic'
17 | # dataset = f'degraded5_{itplt}_medium'
18 | # dataset = f'degraded6_{itplt}_slight'
19 | # dataset = f'degraded4_{itplt}_heavy'
20 | # dataset = f'degraded1_bilinear_heavy'
21 | # dataset = f'degraded1_offset_bilinear_heavy'
22 | dataset = f'PIPAL_offset_test'
23 |
24 | weight_file = f'../weight_file/RealESRGAN_x{scale}plus.pth'
25 | root_dir = f'../datasets/{dataset}/'
26 | out_root_dir_ImgName = f'./test_res/{model_name}_{dataset}/sort_with_ImgName/'
27 | out_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/'
28 | hr_dir = '../datasets/PIPAL/'
29 |
30 | lr_dirs = os.listdir(root_dir)
31 | # Sort with ImgName
32 | out1_dirs = os.listdir(hr_dir)
33 | for dir in out1_dirs:
34 | dir = dir.split('.')[0] + '/'
35 | utils.mkdirs(out_root_dir_ImgName + dir)
36 | # Sort with degradation
37 | for dir in lr_dirs:
38 | utils.mkdirs(out_root_dir_degradation + dir)
39 |
40 | offset = 1
41 |
42 | if not os.path.exists(weight_file):
43 | print(f'Weight file does not exist!\n{weight_file}\n')
44 | raise FileNotFoundError(weight_file)  # raising a string is invalid in Python 3
45 |
46 | cudnn.benchmark = True
47 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
48 |
49 | model = G(scale=scale).to(device)
50 | checkpoint = torch.load(weight_file)
51 | model.load_state_dict(checkpoint['params_ema'])
52 | model.eval()
53 | total_dir = len(lr_dirs)
54 | for idx, dir in enumerate(lr_dirs):
55 | out_by_deg = out_root_dir_degradation + dir + '/'
56 | lr_dir = root_dir + dir + '/'
57 | lr_lists = os.listdir(lr_dir)
58 | with tqdm(total=len(lr_lists)) as t:
59 | t.set_description(f"Processing {idx}/{total_dir}: {dir}")
60 | for imgName in lr_lists:
61 | out_by_name = out_root_dir_ImgName + imgName.split('.')[0] + '/'
62 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
63 | # image = utils.ImgOffSet(image, offset, offset)
64 | if image.mode == 'L':
65 | image = image.convert('RGB')
66 | lr_image = np.array(image)
67 | lr_image = imresize(lr_image, 1. / scale, 'bilinear')
68 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
69 | lr /= 255.
70 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
71 | lr = F.pad(lr, pad=[1, 1, 1, 1], mode='constant')
72 | with torch.no_grad():
73 | SR = model(lr)
74 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
75 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
76 | # GPU tensor -> CPU tensor -> numpy
77 | SR = np.array(SR).astype(np.uint8)
78 | SR = Image.fromarray(SR)  # HWC ndarray -> PIL image
79 | # SR.save(out_by_name + dir + '.bmp')
80 | SR.save(out_by_deg + imgName)
81 | t.update(1)
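
Editor's note: the LR tensor is padded by one pixel per side (F.pad) but the SR output is saved without trimming the corresponding border. If that border is unwanted, a crop on the model output right after inference would remove it; a sketch assuming the pad of 1 used above:

    pad = 1
    # each padded LR pixel becomes `scale` SR pixels, so trim pad*scale per side
    SR = SR[:, :, pad * scale:-pad * scale, pad * scale:-pad * scale]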
--------------------------------------------------------------------------------
/SRFlow/Split.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | from torch import nn as nn
19 |
20 | import thops
21 | from flow import Conv2dZeros, GaussianDiag
22 |
23 |
24 | class Split2d(nn.Module):
25 | def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
26 | super().__init__()
27 |
28 | self.num_channels_consume = int(round(num_channels * consume_ratio))
29 | self.num_channels_pass = num_channels - self.num_channels_consume
30 |
31 | self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
32 | out_channels=self.num_channels_consume * 2)
33 | self.logs_eps = logs_eps
34 | self.position = position
35 | self.opt = opt
36 |
37 | def split2d_prior(self, z, ft):
38 | if ft is not None:
39 | z = torch.cat([z, ft], dim=1)
40 | h = self.conv(z)
41 | return thops.split_feature(h, "cross")
42 |
43 | def exp_eps(self, logs):
44 | return torch.exp(logs) + self.logs_eps
45 |
46 | def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
47 | if not reverse:
48 | # self.input = input
49 | z1, z2 = self.split_ratio(input)
50 | mean, logs = self.split2d_prior(z1, ft)
51 |
52 | eps = (z2 - mean) / self.exp_eps(logs)
53 |
54 | logdet = logdet + self.get_logdet(logs, mean, z2)
55 |
56 | # print(logs.shape, mean.shape, z2.shape)
57 | # self.eps = eps
58 | # print('split, enc eps:', eps)
59 | return z1, logdet, eps
60 | else:
61 | z1 = input
62 | mean, logs = self.split2d_prior(z1, ft)
63 |
64 | if eps is None:
65 | #print("WARNING: eps is None, generating eps untested functionality!")
66 | eps = GaussianDiag.sample_eps(mean.shape, eps_std)
67 |
68 | eps = eps.to(mean.device)
69 | z2 = mean + self.exp_eps(logs) * eps
70 |
71 | z = thops.cat_feature(z1, z2)
72 | logdet = logdet - self.get_logdet(logs, mean, z2)
73 |
74 | return z, logdet
75 | # return z, logdet, eps
76 |
77 | def get_logdet(self, logs, mean, z2):
78 | logdet_diff = GaussianDiag.logp(mean, logs, z2)
79 | # print("Split2D: logdet diff", logdet_diff.item())
80 | return logdet_diff
81 |
82 | def split_ratio(self, input):
83 | z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
84 | return z1, z2
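
Editor's note: the forward and reverse branches above are exact inverses given the same eps. A tensor-only sketch of that round trip (stand-in mean/logs, since the real ones come from the conditional prior; logs_eps taken as 0):

    import torch

    z = torch.randn(1, 8, 4, 4)
    z1, z2 = z[:, :4], z[:, 4:]                  # split_ratio with consume_ratio=0.5
    mean, logs = torch.zeros_like(z2), torch.zeros_like(z2)  # stand-in prior
    eps = (z2 - mean) / torch.exp(logs)          # forward: encode z2 as eps
    z2_rec = mean + torch.exp(logs) * eps        # reverse: decode eps back to z2
    assert torch.allclose(z2, z2_rec)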
--------------------------------------------------------------------------------
/ESRGAN/test_realesrgan_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import metrics
3 | import numpy as np
4 | import torch
5 | from torch.backends import cudnn
6 | from tqdm import tqdm
7 | import utils
8 | from realesrgan_model import G
9 | from PIL import Image
10 | import os
11 |
12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
13 | if __name__ == '__main__':
14 | model_name = 'RealESRGAN'
15 | itplt = 'bilinear'
16 | # itplt = 'bicubic'
17 | dataset = f'degraded4_{itplt}_heavy'
18 | # dataset = f'degraded5_{itplt}_medium'
19 | # dataset = f'degraded6_{itplt}_slight'
20 | csv_file = f'./test_res/{model_name}_{dataset}.csv'
21 | csv_file = open(csv_file, 'w', newline='')
22 | writer = csv.writer(csv_file)
23 | writer.writerow(('name', 'psnr', 'niqe', 'ssim', 'lpips'))
24 |
25 | weight_file = '/data0/jli/project/BasicSR-iqa/experiments/pretrained_models/RealESRGAN_x4plus.pth'
26 | root_dir = f'/data0/jli/datasets/{dataset}/'
27 | out_root_dir = f'./test_res/{model_name}_{dataset}_metrics/'
28 | hr_dir = '/data0/jli/datasets/PIPAL/'
29 |
30 | lr_dirs = os.listdir(root_dir)
31 | for dir in lr_dirs:
32 | utils.mkdirs(out_root_dir + dir)
33 |
34 | scale = 4
35 | padding = scale
36 |
37 | if not os.path.exists(weight_file):
38 | print(f'Weight file does not exist!\n{weight_file}\n')
39 | raise FileNotFoundError(weight_file)  # raising a string is invalid in Python 3
40 |
41 | cudnn.benchmark = True
42 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
43 |
44 | model = G().to(device)
45 | checkpoint = torch.load(weight_file)
46 | model.load_state_dict(checkpoint['params_ema'])
47 |
48 | model.eval()
49 |
50 | for dir in lr_dirs:
51 | outputs_dir = out_root_dir + dir + '/'
52 | lr_dir = root_dir + dir + '/'
53 | lr_lists = os.listdir(lr_dir)
54 | Avg_psnr = utils.AverageMeter()
55 | Avg_niqe = utils.AverageMeter()
56 | Avg_ssim = utils.AverageMeter()
57 | Avg_lpips = utils.AverageMeter()
58 | with tqdm(total=len(lr_lists)) as t:
59 | t.set_description(f"Processing: {dir}")
60 | for imgName in lr_lists:
61 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
62 | GT = utils.loadIMG_crop(hr_dir + imgName, scale)
63 | img_mode = image.mode
64 | if image.mode == 'L':
65 | image = image.convert('RGB')
66 | GT = GT.convert('RGB')
67 |
68 | lr_image = np.array(image)
69 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
70 | lr /= 255.
71 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
72 |
73 | with torch.no_grad():
74 | SR = model(lr)
75 |
76 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
77 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
78 |
79 | SR = np.array(SR).astype(np.uint8)
80 | SR = Image.fromarray(SR)  # HWC ndarray -> PIL image
81 | SR.save(outputs_dir + imgName)
82 |
83 | metric = metrics.calc_metric(SR, GT, ['psnr', 'ssim', 'niqe', 'lpips'])
84 | Avg_lpips.update(metric['lpips'], 1)
85 | Avg_psnr.update(metric['psnr'], 1)
86 | Avg_niqe.update(metric['niqe'], 1)
87 | Avg_ssim.update(metric['ssim'], 1)
88 | t.update(1)
89 |
90 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg, Avg_lpips.avg))
91 |
92 |
93 |
94 |
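
Editor's note: utils.AverageMeter is not included in this listing. A minimal version consistent with how it is used here (update(value, n) and .avg) could be:

    class AverageMeter:
        # running mean: accumulate a weighted sum and a count
        def __init__(self):
            self.sum, self.count, self.avg = 0.0, 0, 0.0

        def update(self, val, n=1):
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count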
--------------------------------------------------------------------------------
/HCFlow/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """Charbonnier Loss (L1)"""
7 |
8 | def __init__(self, eps=1e-6):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = eps
11 |
12 | def forward(self, x, y):
13 | diff = x - y
14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps))
15 | return loss
16 |
17 |
18 | # Define GAN loss: [gan(vanilla) | lsgan | wgan-gp | ragan]
19 | class GANLoss(nn.Module):
20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
21 | super(GANLoss, self).__init__()
22 | self.gan_type = gan_type.lower()
23 | self.real_label_val = real_label_val
24 | self.fake_label_val = fake_label_val
25 |
26 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
27 | self.loss = nn.BCEWithLogitsLoss()
28 | elif self.gan_type == 'lsgan':
29 | self.loss = nn.MSELoss()
30 | elif self.gan_type == 'wgan-gp':
31 |
32 | def wgan_loss(input, target):
33 | # target is boolean
34 | return -1 * input.mean() if target else input.mean()
35 |
36 | self.loss = wgan_loss
37 | else:
38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
39 |
40 | def get_target_label(self, input, target_is_real):
41 | if self.gan_type == 'wgan-gp':
42 | return target_is_real
43 | if target_is_real:
44 | return torch.empty_like(input).fill_(self.real_label_val)
45 | else:
46 | return torch.empty_like(input).fill_(self.fake_label_val)
47 |
48 | def forward(self, input, target_is_real):
49 | target_label = self.get_target_label(input, target_is_real)
50 | loss = self.loss(input, target_label)
51 | return loss
52 |
53 |
54 | class GradientPenaltyLoss(nn.Module):
55 | def __init__(self, device=torch.device('cpu')):
56 | super(GradientPenaltyLoss, self).__init__()
57 | self.register_buffer('grad_outputs', torch.Tensor())
58 | self.grad_outputs = self.grad_outputs.to(device)
59 |
60 | def get_grad_outputs(self, input):
61 | if self.grad_outputs.size() != input.size():
62 | self.grad_outputs.resize_(input.size()).fill_(1.0)
63 | return self.grad_outputs
64 |
65 | def forward(self, interp, interp_crit):
66 | grad_outputs = self.get_grad_outputs(interp_crit)
67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
68 | grad_outputs=grad_outputs, create_graph=True,
69 | retain_graph=True, only_inputs=True)[0]
70 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
71 | grad_interp_norm = grad_interp.norm(2, dim=1)
72 |
73 | loss = ((grad_interp_norm - 1)**2).mean()
74 | return loss
75 |
76 | class ReconstructionLoss(nn.Module):
77 | def __init__(self, losstype='l2', eps=1e-6):
78 | super(ReconstructionLoss, self).__init__()
79 | self.losstype = losstype
80 | self.eps = eps
81 |
82 | def forward(self, x, target):
83 | if self.losstype == 'l2':
84 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3)))
85 | elif self.losstype == 'l1':
86 | diff = x - target
87 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3)))
88 | else:
89 | raise NotImplementedError(
90 | f'reconstruction loss type [{self.losstype}] is not supported')
91 |
92 |
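
Editor's note: a quick usage sketch of GANLoss for a vanilla discriminator update (stand-in logits, no real model):

    import torch

    criterion = GANLoss('gan')
    d_real = torch.randn(4, 1)   # stand-in discriminator outputs (logits)
    d_fake = torch.randn(4, 1)
    d_loss = criterion(d_real, True) + criterion(d_fake, False)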
--------------------------------------------------------------------------------
/FSRCNN/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from tqdm import tqdm
3 | import model_utils
4 | import utils
5 | import os
6 | import torch
7 | from torch.backends import cudnn
8 | from model import FSRCNN
9 | from torch import nn, optim
10 | from FSRCNNdatasets import TrainDataset, ValDataset, ResValDataset
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 |
14 | def train_model(config, from_pth=False):
15 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
16 | outputs_dir = config['outputs_dir']
17 | lr = config['lr']
18 | batch_size = config['batch_size']
19 | num_epochs = config['num_epochs']
20 | logs_dir = config['logs_dir']
21 | csv_file = outputs_dir + config['csv_name']
22 |
23 | utils.mkdirs(outputs_dir)
24 | utils.mkdirs(logs_dir)
25 |
26 | cudnn.benchmark = True
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 | torch.manual_seed(config['seed'])
29 | # ---- section to modify ------
30 | model = FSRCNN(config['scale'], config['in_size'], config['out_size'],
31 | num_channels=1, d=config['d'], s=config['s'], m=config['m'])
32 | if not from_pth:
33 | model.init_weights(method=config['init'])
34 | criterion = nn.MSELoss().to(device)
35 |
36 | optimizer = optim.SGD([
37 | {'params': model.extract_layer.parameters()},
38 | {'params': model.mid_part.parameters()},
39 | {'params': model.deconv_layer.parameters(), 'lr': lr * 0.1}
40 | ], lr=lr, momentum=0.9)  # first two layers use lr; the last layer uses lr * 0.1
41 | train_dataset = TrainDataset(config['train_file'])
42 | train_dataloader = DataLoader(dataset=train_dataset,
43 | batch_size=batch_size,
44 | shuffle=True,
45 | num_workers=config['num_workers'],
46 | pin_memory=True)
47 | if config['residual']:
48 | val_dataset = ResValDataset(config['val_file'])
49 | else:
50 | val_dataset = ValDataset(config['val_file'])
51 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
52 | # ----END------
53 |
54 | start_epoch, best_epoch, best_psnr, writer, csv_file = \
55 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file,
56 | from_pth, config['auto_lr'])
57 |
58 | if torch.cuda.device_count() > 1:
59 | print("Using GPUs.\n")
60 | model = torch.nn.DataParallel(model)
61 | model = model.to(device)
62 |
63 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
64 | for epoch in range(start_epoch, num_epochs):
65 | lr = optimizer.state_dict()['param_groups'][0]['lr']
66 | print(f'learning rate: {lr}\n')
67 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
68 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
69 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)
70 |
71 | if isinstance(model, torch.nn.DataParallel):
72 | model = model.module
73 |
74 | epoch_psnr = model_utils.validate(model, val_dataloader, device, config['residual'])
75 |
76 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
77 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)
78 |
79 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses,
80 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
81 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
--------------------------------------------------------------------------------
/HCFlow/module_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 | def initialize_weights_xavier(net_l, scale=1):
27 | if not isinstance(net_l, list):
28 | net_l = [net_l]
29 | for net in net_l:
30 | for m in net.modules():
31 | if isinstance(m, nn.Conv2d):
32 | init.xavier_normal_(m.weight)
33 | m.weight.data *= scale # for residual block
34 | if m.bias is not None:
35 | m.bias.data.zero_()
36 | elif isinstance(m, nn.Linear):
37 | init.xavier_normal_(m.weight)
38 | m.weight.data *= scale
39 | if m.bias is not None:
40 | m.bias.data.zero_()
41 | elif isinstance(m, nn.BatchNorm2d):
42 | init.constant_(m.weight, 1)
43 | init.constant_(m.bias.data, 0.0)
44 |
45 |
46 | def make_layer(block, n_layers):
47 | layers = []
48 | for _ in range(n_layers):
49 | layers.append(block())
50 | return nn.Sequential(*layers)
51 |
52 |
53 | class ResidualBlock_noBN(nn.Module):
54 | '''Residual block w/o BN
55 | ---Conv-ReLU-Conv-+-
56 | |________________|
57 | '''
58 |
59 | def __init__(self, nf=64):
60 | super(ResidualBlock_noBN, self).__init__()
61 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
62 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
63 |
64 | # initialization
65 | initialize_weights([self.conv1, self.conv2], 0.1)
66 |
67 | def forward(self, x):
68 | identity = x
69 | out = F.relu(self.conv1(x), inplace=True)
70 | out = self.conv2(out)
71 | return identity + out
72 |
73 |
74 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
75 | """Warp an image or feature map with optical flow
76 | Args:
77 | x (Tensor): size (N, C, H, W)
78 | flow (Tensor): size (N, H, W, 2), normal value
79 | interp_mode (str): 'nearest' or 'bilinear'
80 | padding_mode (str): 'zeros' or 'border' or 'reflection'
81 |
82 | Returns:
83 | Tensor: warped image or feature map
84 | """
85 | assert x.size()[-2:] == flow.size()[1:3]
86 | B, C, H, W = x.size()
87 | # mesh grid
88 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
89 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
90 | grid.requires_grad = False
91 | grid = grid.type_as(x)
92 | vgrid = grid + flow
93 | # scale grid to [-1,1]
94 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
95 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
96 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
97 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=True)  # matches the [-1, 1] normalization above
98 | return output
99 |
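
Editor's note: as a sanity check of flow_warp, a zero flow field should reproduce the input exactly (given the align_corners=True sampling noted above); sketch:

    import torch

    x = torch.rand(1, 3, 8, 8)
    flow = torch.zeros(1, 8, 8, 2)
    assert torch.allclose(flow_warp(x, flow), x, atol=1e-6)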
--------------------------------------------------------------------------------
/SRCNN/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.backends import cudnn
4 | import utils
5 | from models import SRCNN
6 | from PIL import Image
7 | import os
8 | from imresize import imresize
9 |
10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
11 | if __name__ == '__main__':
12 | scale = 3
13 | config = {'weight_file': f'./weight_file/SRCNN_x{scale}_lr=1e-02_batch=128/',
14 | 'img_dir': '../datasets/Set5/',
15 | 'outputs_dir': f'./test_res/test_x{scale}_Set5/',
16 | 'visual_filter': False
17 | }
18 |
19 | outputs_dir = config['outputs_dir']
20 |
21 | padding = scale
22 | # weight_file = config['weight_file'] + f'best.pth'
23 | # weight_file = config['weight_file'] + f'SRCNNx3_data=276864_lr=1e-2.pth'
24 | weight_file = config['weight_file'] + f'x{scale}/best.pth'
25 | img_dir = config['img_dir']
26 | outputs_dir = outputs_dir + f'x{scale}/'
27 | utils.mkdirs(outputs_dir)
28 | if not os.path.exists(weight_file):
29 | print(f'Weight file does not exist!\n{weight_file}\n')
30 | raise FileNotFoundError(weight_file)  # raising a string is invalid in Python 3
31 | if not os.path.exists(img_dir):
32 | raise "Image file not exist!\n"
33 |
34 | cudnn.benchmark = True
35 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
36 |
37 | model = SRCNN(padding=True).to(device)
38 | checkpoint = torch.load(weight_file)
39 | model.load_state_dict(checkpoint['model'])
40 |
41 | if config['visual_filter']:
42 | ax = utils.viz_layer(model.conv1[0].weight.cpu(), 64)
43 | model.eval()
44 | imglist = os.listdir(img_dir)
45 | Avg_psnr = utils.AverageMeter()
46 | for imgName in imglist:
47 | image = utils.loadIMG_crop(img_dir + imgName, scale)
48 | if image.mode != 'L': # gray image don't need to convert
49 | image = image.convert('RGB')
50 | hr_image = np.array(image)
51 |
52 | lr_image = imresize(hr_image, 1. / scale, 'bicubic')
53 | bic_image = imresize(lr_image, scale, 'bicubic')
54 | bic_pil = Image.fromarray(bic_image.astype(np.uint8)[padding: -padding, padding: -padding, ...])
55 | bic_pil.save(outputs_dir + imgName.replace('.bmp', f'_bicubic_x{scale}.png'))
56 |
57 | bic_y, bic_ycbcr = utils.preprocess(bic_image, device, image.mode)
58 | hr_y, _ = utils.preprocess(hr_image, device, image.mode)
59 | with torch.no_grad():
60 | SR = model(bic_y).clamp(0.0, 1.0)
61 | hr_y = hr_y[..., padding: -padding, padding: -padding]
62 | SR = SR[..., padding: -padding, padding: -padding]
63 | bic_ycbcr = bic_ycbcr[..., padding: -padding, padding: -padding]
64 | bic_y = bic_y[..., padding: -padding, padding: -padding]
65 |
66 | psnr = utils.calc_psnr(hr_y, SR)
67 | psnr2 = utils.calc_psnr(hr_y, bic_y)
68 | Avg_psnr.update(psnr, 1)
69 | # Avg_psnr.update(psnr2, 1)
70 | print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item()))
71 | print(f'{imgName}, ' + 'PSNR_bicubic: {:.2f}'.format(psnr2.item()))
72 | # GPU tensor -> CPU tensor -> numpy
73 | SR = SR.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
74 | if image.mode == 'L':
75 | output = np.clip(SR, 0.0, 255.0).astype(np.uint8)  # single-channel image, no transpose needed
76 | else:
77 | bic_ycbcr = bic_ycbcr.mul(255.0).cpu().numpy().squeeze(0).transpose([1, 2, 0])
78 | output = np.array([SR, bic_ycbcr[..., 1], bic_ycbcr[..., 2]]).transpose([1, 2, 0]) # chw -> hwc
79 | output = np.clip(utils.ycbcr2rgb(output), 0.0, 255.0).astype(np.uint8)
80 | output = Image.fromarray(output)  # ndarray -> PIL image
81 | output.save(outputs_dir + imgName.replace('.bmp', f'_SRCNNx{scale}.png'))
82 | print('Average_PSNR: {:.2f}'.format(Avg_psnr.avg))
--------------------------------------------------------------------------------
/SRResNet/SRResNetdatasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import albumentations as A
5 | from imresize import imresize
6 | import h5py
7 | from torch.utils.data import Dataset, DataLoader
8 | import torch
9 |
10 |
11 | class SRResNetTrainDataset(Dataset):
12 | def __init__(self, h5_file):
13 | super(SRResNetTrainDataset, self).__init__()
14 | self.h5_file = h5_file
15 |
16 | def __getitem__(self, idx):
17 | with h5py.File(self.h5_file, 'r') as f:
18 | return torch.from_numpy(f['data'][idx] / 255.), torch.from_numpy(f['label'][idx] / 255. * 2 - 1)
19 |
20 | def __len__(self):
21 | with h5py.File(self.h5_file, 'r') as f:
22 | return len(f['data'])
23 |
24 |
25 | class SRResNetValDataset(Dataset):
26 | def __init__(self, h5_file):
27 | super(SRResNetValDataset, self).__init__()
28 | self.h5_file = h5_file
29 |
30 | def __getitem__(self, idx):
31 | with h5py.File(self.h5_file, 'r') as f:
32 | return torch.from_numpy(f['data'][str(idx)][:, :, :] / 255.), \
33 | torch.from_numpy(f['label'][str(idx)][:, :, :] / 255.)
34 |
35 | def __len__(self):
36 | with h5py.File(self.h5_file, 'r') as f:
37 | return len(f['data'])
38 |
39 |
40 | class DIV2KDataset(Dataset):
41 | def __init__(self, root_dir='../datasets/DIV2K_train_HR/'):
42 | super(DIV2KDataset, self).__init__()
43 | self.data = []
44 | self.img_names = os.listdir(root_dir)
45 |
46 | for name in self.img_names:
47 | self.data.append(root_dir + name)
48 |
49 | def __len__(self):
50 | return len(self.data)
51 |
52 | def __getitem__(self, index):
53 | img_path = self.data[index]
54 |
55 | image = np.array(Image.open(img_path))
56 | transpose = A.RandomCrop(width=96, height=96)
57 | image = transpose(image=image)["image"]
58 | # transpose2 = A.Normalize(0.5, 0.5)
59 | # label = transpose2(image=image)["image"]
60 | # label = torch.from_numpy(label.astype(np.float32).transpose([2, 0, 1]))
61 | label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1])/127.5-1)
62 | data = imresize(image, 1 / 4, 'bicubic')
63 | data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255)
64 | return data, label
65 |
66 |
67 | class DIV2KSubDataset(Dataset):
68 | def __init__(self, hr_dir='../datasets/DIV2K_sub/HR/', lr_dir='../datasets/DIV2K_sub/LRx4/'):
69 | super(DIV2KSubDataset, self).__init__()
70 | self.hr_data = []
71 | self.lr_data = []
72 | self.img_names = os.listdir(hr_dir)
73 |
74 | for name in self.img_names:
75 | self.hr_data.append(hr_dir + name)
76 | self.lr_data.append(lr_dir + name)
77 |
78 | def __len__(self):
79 | return len(self.hr_data)
80 |
81 | def __getitem__(self, index):
82 | hr_img_path = self.hr_data[index]
83 | lr_img_path = self.lr_data[index]
84 |
85 | hr_image = np.array(Image.open(hr_img_path))
86 | lr_image = np.array(Image.open(lr_img_path))
87 | label = torch.from_numpy(hr_image.astype(np.float32).transpose([2, 0, 1])/127.5-1)
88 | data = torch.from_numpy(lr_image.astype(np.float32).transpose([2, 0, 1]) / 255)
89 | return data, label
90 |
91 |
92 | def test():
93 | dataset = DIV2KDataset(root_dir='../datasets/T91/')
94 | loader = DataLoader(dataset, batch_size=1, num_workers=0)
95 |
96 | for low_res, high_res in loader:
97 | print(low_res.shape)
98 | print(high_res.shape)
99 |
100 |
101 | if __name__ == "__main__":
102 | test()
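
Editor's note: these datasets return LR inputs in [0, 1] but HR labels in [-1, 1] (the / 127.5 - 1 mapping), so anything consuming the generator output has to undo the label normalization before saving; a sketch:

    import torch

    def denormalize(out):
        # [-1, 1] float tensor -> uint8 pixel values in [0, 255]
        return ((out.clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)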
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | import cv2
3 | import torch
4 | import utils
5 | import numpy as np
6 | import lpips
7 | import niqe
8 | calc_lpips = lpips.LPIPS(net='alex', model_path='./weight_file/alexnet-owt-7be5be79.pth')
9 |
10 |
11 | def calc_psnr(img1, img2):
12 | if isinstance(img1, torch.Tensor):
13 | return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
14 | else:
15 | return 10. * np.log10(1. / np.mean((img1 - img2) ** 2))
16 |
17 |
18 | def calculate_ssim(img, img2):
19 | """Calculate SSIM (structural similarity) for one channel images.
20 |
21 | It is called by func:`calculate_ssim`.
22 |
23 | Args:
24 | img (ndarray): Images with range [0, 255] with order 'HWC'.
25 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
26 |
27 | Returns:
28 | float: ssim result.
29 | """
30 |
31 | c1 = (0.01 * 255) ** 2
32 | c2 = (0.03 * 255) ** 2
33 |
34 | img = img.astype(np.float64)
35 | img2 = img2.astype(np.float64)
36 | kernel = cv2.getGaussianKernel(11, 1.5)
37 | window = np.outer(kernel, kernel.transpose())
38 |
39 | mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
40 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
41 | mu1_sq = mu1 ** 2
42 | mu2_sq = mu2 ** 2
43 | mu1_mu2 = mu1 * mu2
44 | sigma1_sq = cv2.filter2D(img ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
45 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
46 | sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
47 |
48 | ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
49 | return ssim_map.mean()
50 |
51 |
52 | def calc_metric(sr_image, gt_image, metric, crop=0):
53 | """
54 | Args:
55 | sr_image: super-resolved image, RGB or gray, range [0, 255]
56 | gt_image: ground-truth image, RGB or gray (same size as sr_image)
57 | metric: list of metric names among 'psnr', 'ssim', 'niqe', 'lpips'
58 | crop: number of border pixels to crop
59 | Returns:
60 | dict mapping each requested metric to its float value
61 | """
62 |
63 | res = dict()
64 | # convert RGBA to RGB
65 | if sr_image.mode != 'L':
66 | sr_image = sr_image.convert('RGB')
67 | gt_image = gt_image.convert('RGB')
68 | # crop border (max(crop, 1) means even crop=0 drops the last row/column)
69 | sr = np.array(sr_image).astype(np.float32)[crop: -max(crop, 1), crop: -max(crop, 1), ...]
70 | gt = np.array(gt_image).astype(np.float32)[crop: -max(crop, 1), crop: -max(crop, 1), ...]
71 | # RGB to YCbCr
72 | if sr_image.mode == 'L':
73 | sr_y = sr / 255.
74 | gt_y = gt / 255.
75 | else:
76 | sr_y = utils.rgb2ycbcr(sr).astype(np.float32)[..., 0] / 255.
77 | gt_y = utils.rgb2ycbcr(gt).astype(np.float32)[..., 0] / 255.
78 | # calculate metrics
79 | if 'psnr' in metric:
80 | res.update({'psnr': calc_psnr(sr_y, gt_y)})
81 | if 'niqe' in metric:
82 | res.update({'niqe': niqe.calculate_niqe(sr_y)})
83 | if 'ssim' in metric:
84 | res.update({'ssim': calculate_ssim(sr_y * 255, gt_y * 255)})
85 | if 'lpips' in metric:
86 | res.update({'lpips': calc_lpips(
87 | torch.from_numpy((sr / 255.).transpose([2, 0, 1])).unsqueeze(0),
88 | torch.from_numpy((gt / 255.).transpose([2, 0, 1])).unsqueeze(0),
89 | normalize=True
90 | ).item()})
91 | return res
92 |
93 | def test_calc_metric():
94 | SR = utils.loadIMG_crop('./test_res/HCFlowSR/bicubicx4/sample=0/A0012.bmp', 4).convert('RGB')
95 | GT = utils.loadIMG_crop('./datasets/PIPAL/A0012.bmp', 4).convert('RGB')
96 | metric = calc_metric(SR, GT, ['psnr', 'ssim', 'niqe', 'lpips'])
97 |
98 | if __name__ == '__main__':
99 | test_calc_metric()
100 |
101 |
102 |
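
Editor's note: calc_psnr assumes inputs scaled to [0, 1]; a worked example:

    import numpy as np

    a = np.full((8, 8), 0.5, dtype=np.float32)
    b = a + 0.1                 # uniform error, so MSE = 0.01
    print(calc_psnr(a, b))      # 10 * log10(1 / 0.01) = 20 dB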
--------------------------------------------------------------------------------
/ESRGAN/train_ESRGAN_iter2.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import model_utils
3 | import utils
4 | import os
5 | import torch
6 | from torch.backends import cudnn
7 | from torch.utils.tensorboard import SummaryWriter
8 | from VGGLoss import VGGLoss
9 | from model import G, D
10 | from torch import nn, optim
11 | from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset
12 | from torch.utils.data.dataloader import DataLoader
13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
14 |
15 |
16 | def train_model(config, pre_train, from_pth=False):
17 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
18 | outputs_dir = config['outputs_dir']
19 | batch_size = config['batch_size']
20 | utils.mkdirs(outputs_dir)
21 | csv_file = outputs_dir + config['csv_name']
22 | logs_dir = config['logs_dir']
23 |
24 | cudnn.benchmark = True
25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | torch.manual_seed(config['seed'])
27 |
28 | # ---- section to modify ------
29 | print("===> Loading datasets")
30 | train_dataset = ESRGANTrainDataset()
31 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
32 | batch_size=batch_size, shuffle=True, pin_memory=True)
33 | val_dataset = ESRGANValDataset(config['val_file'])
34 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
35 |
36 | print("===> Building model")
37 | gen = G()
38 | disc = D()
39 | if not from_pth:
40 | disc.init_weight()
41 | criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.L1Loss().to(device),
42 | 'vgg_loss': VGGLoss(device, nn.L1Loss())}
43 | gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr'])
44 | disc_opt = optim.Adam(disc.parameters(), lr=config['disc_lr'])
45 | # ----END------
46 | start_step, best_step, best_niqe, writer, csv_file = \
47 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt,
48 | disc, disc_opt, csv_file, from_pth, config['auto_lr'])
49 |
50 | if torch.cuda.device_count() > 1:
51 | print("Using GPUs.\n")
52 | gen = torch.nn.DataParallel(gen)
53 | disc = torch.nn.DataParallel(disc)
54 | gen = gen.to(device)
55 | disc = disc.to(device)
56 |
57 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")}
58 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
59 | num_steps = config['num_steps']
60 | iter_of_epoch = 100
61 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 'milestone': config['milestone'],
62 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
63 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe,
64 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'],
65 | 'disc_k': config['disc_k'], 'gen_k': config['gen_k']}
66 | log_dict = {'D_losses': 0., 'G_losses': 0.,
67 | 'D_losses_real': 0., 'D_losses_fake': 0.,
68 | 'pixel_loss': 0., 'gan_loss': 0.,
69 | 'percep_loss': 0.,
70 | 'F_prob': 0., 'R_prob': 0.}
71 | while global_info['step'] < config['num_steps']:
72 | with tqdm(total=len(train_dataset)) as t:
73 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
74 | global_info['t'] = t
75 | model_utils.train_ESRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict)
76 |
77 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe']))
78 |
--------------------------------------------------------------------------------
/SRFlow/module_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.init as init
20 | import torch.nn.functional as F
21 |
22 |
23 | def initialize_weights(net_l, scale=1):
24 | if not isinstance(net_l, list):
25 | net_l = [net_l]
26 | for net in net_l:
27 | for m in net.modules():
28 | if isinstance(m, nn.Conv2d):
29 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
30 | m.weight.data *= scale # for residual block
31 | if m.bias is not None:
32 | m.bias.data.zero_()
33 | elif isinstance(m, nn.Linear):
34 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
35 | m.weight.data *= scale
36 | if m.bias is not None:
37 | m.bias.data.zero_()
38 | elif isinstance(m, nn.BatchNorm2d):
39 | init.constant_(m.weight, 1)
40 | init.constant_(m.bias.data, 0.0)
41 |
42 |
43 | def make_layer(block, n_layers):
44 | layers = []
45 | for _ in range(n_layers):
46 | layers.append(block())
47 | return nn.Sequential(*layers)
48 |
49 |
50 | class ResidualBlock_noBN(nn.Module):
51 | '''Residual block w/o BN
52 | ---Conv-ReLU-Conv-+-
53 | |________________|
54 | '''
55 |
56 | def __init__(self, nf=64):
57 | super(ResidualBlock_noBN, self).__init__()
58 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
59 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
60 |
61 | # initialization
62 | initialize_weights([self.conv1, self.conv2], 0.1)
63 |
64 | def forward(self, x):
65 | identity = x
66 | out = F.relu(self.conv1(x), inplace=True)
67 | out = self.conv2(out)
68 | return identity + out
69 |
70 |
71 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
72 | """Warp an image or feature map with optical flow
73 | Args:
74 | x (Tensor): size (N, C, H, W)
75 | flow (Tensor): size (N, H, W, 2), normal value
76 | interp_mode (str): 'nearest' or 'bilinear'
77 | padding_mode (str): 'zeros' or 'border' or 'reflection'
78 |
79 | Returns:
80 | Tensor: warped image or feature map
81 | """
82 | assert x.size()[-2:] == flow.size()[1:3]
83 | B, C, H, W = x.size()
84 | # mesh grid
85 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
86 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
87 | grid.requires_grad = False
88 | grid = grid.type_as(x)
89 | vgrid = grid + flow
90 | # scale grid to [-1,1]
91 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
92 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
93 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
94 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=True)  # matches the [-1, 1] normalization above
95 | return output
96 |
--------------------------------------------------------------------------------
/SRResNet/gen_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 | import utils
5 | from imresize import imresize
6 | from tqdm import tqdm
7 | from PIL import Image
8 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
9 |
10 |
11 | def gen_traindata(config):
12 | scale = config["scale"]
13 | stride = config["stride"]
14 | size_input = config["size_input"]
15 | size_label = size_input * scale
16 | size_output = config['size_output']
17 | method = config['method']
18 | h5savepath = config["hrDir"] + f'_label={size_output}_train_SRResNetx{scale}.h5'
19 | hrDir = config["hrDir"] + '/'
20 |
21 | h5_file = h5py.File(h5savepath, 'w')
22 | imgList = os.listdir(hrDir)
23 | lr_subimgs = []
24 | hr_subimgs = []
25 | total = len(imgList)
26 | with tqdm(total=total) as t:
27 | for imgName in imgList:
28 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale)
29 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True)
30 | lr = imresize(hr, 1 / scale, method)
31 | input = lr.astype(np.float32)
32 |
33 | for r in range(0, lr.shape[0] - size_input + 1, stride): # height
34 | for c in range(0, lr.shape[1] - size_input + 1, stride): # width
35 | lr_subimgs.append(input[r: r + size_input, c: c + size_input].transpose([2, 0, 1])) # hwc -> chw
36 | label = hr[r * scale: r * scale + size_label, c * scale: c * scale + size_label].transpose([2, 0, 1])
37 | hr_subimgs.append(label)
38 | t.update(1)
39 |
40 | lr_subimgs = np.array(lr_subimgs).astype(np.float32)
41 | h5_file.create_dataset('data', data=lr_subimgs)
42 | lr_subimgs = []
43 | # the HR dataset is too large to convert to an ndarray at once, so write it in several parts.
44 | num = len(hr_subimgs)
45 | seg = num // scale
46 | hr_imgs = np.array(hr_subimgs[0: seg]).astype(np.float32)
47 | dl = h5_file.create_dataset('label', data=hr_imgs, maxshape=[num, 3, 96, 96])
48 | for i in range(1, scale):
49 | hr_imgs = []
50 | hr_imgs = np.array(hr_subimgs[i*seg: (i+1)*seg]).astype(np.float32)
51 | dl.resize(((i+1)*seg, 3, 96, 96))
52 | dl[i*seg:] = hr_imgs
53 | hr_imgs = []
54 | hr_imgs = np.array(hr_subimgs[scale*seg: num]).astype(np.float32)
55 | dl.resize((num, 3, 96, 96))
56 | dl[scale*seg:] = hr_imgs
57 | hr_imgs = []
58 | h5_file.close()
59 |
60 |
61 | def gen_valdata(config):
62 | scale = config["scale"]
63 | size_input = config["size_input"]
64 | size_label = size_input * scale
65 | size_output = config['size_output']
66 | method = config['method']
67 | h5savepath = config["hrDir"] + f'_label={size_output}_val_SRResNetx{scale}.h5'
68 | hrDir = config["hrDir"] + '/'
69 | h5_file = h5py.File(h5savepath, 'w')
70 | lr_group = h5_file.create_group('data')
71 | hr_group = h5_file.create_group('label')
72 | imgList = os.listdir(hrDir)
73 | for i, imgName in enumerate(imgList):
74 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale)
75 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True).astype(np.float32)
76 | lr = imresize(np.array(hrIMG).astype(np.float32), 1 / scale, method)
77 |
78 | data = lr.astype(np.float32).transpose([2, 0, 1])
79 | label = hr.transpose([2, 0, 1])
80 |
81 | lr_group.create_dataset(str(i), data=data)
82 | hr_group.create_dataset(str(i), data=label)
83 | h5_file.close()
84 |
85 |
86 | if __name__ == '__main__':
87 | # config = {'hrDir': './test/flower', 'scale': 3, "stride": 14, "size_input": 33, "size_label": 21}
88 | config = {'hrDir': '../datasets/291_aug', 'scale': 4, 'stride': 12, "size_input": 24, "size_output": 96, 'method': 'bicubic'}
89 | # gen_traindata(config)  # not needed when using DIV2K
90 | config['hrDir'] = '../datasets/Set5'
91 | gen_valdata(config)
92 |
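
Editor's note: the chunked label writing in gen_traindata follows h5py's resizable-dataset pattern; a compact sketch of the same idea (hypothetical helper, not the project's code):

    import h5py
    import numpy as np

    def append_in_chunks(h5, name, arrays, chunk=10000, item_shape=(3, 96, 96)):
        # create an empty, growable dataset, then extend it chunk by chunk
        ds = h5.create_dataset(name, shape=(0, *item_shape),
                               maxshape=(None, *item_shape), dtype=np.float32)
        for start in range(0, len(arrays), chunk):
            block = np.asarray(arrays[start:start + chunk], dtype=np.float32)
            ds.resize(ds.shape[0] + len(block), axis=0)
            ds[-len(block):] = block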
--------------------------------------------------------------------------------
/SRResNet/train_SRGAN_WGAN.py:
--------------------------------------------------------------------------------
1 |
2 | from tqdm import tqdm
3 | import model_utils
4 | import utils
5 | import os
6 | import torch
7 | from torch.backends import cudnn
8 | from torch.utils.tensorboard import SummaryWriter
9 | from VGGLoss import VGGLoss
10 | from model import G, D
11 | from torch import nn, optim
12 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset, DIV2KSubDataset
13 | from torch.utils.data.dataloader import DataLoader
14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
15 |
16 |
17 | def train_model(config, pre_train, from_pth=False):
18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
19 | outputs_dir = config['outputs_dir']
20 | batch_size = config['batch_size']
21 | utils.mkdirs(outputs_dir)
22 | csv_file = outputs_dir + config['csv_name']
23 | logs_dir = config['logs_dir']
24 | cudnn.benchmark = True
25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | torch.manual_seed(config['seed'])
27 |
28 | # ---- section to modify ------
29 | print("===> Loading datasets")
30 | train_dataset = DIV2KSubDataset()
31 | # train_dataset = DIV2KDataset(config['train_file'])
32 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
33 | batch_size=batch_size, shuffle=True, pin_memory=True)
34 | val_dataset = SRResNetValDataset(config['val_file'])
35 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
36 |
37 | print("===> Building model")
38 | gen = G()
39 | disc = D()
40 | if not from_pth:
41 | disc.init_weight()
42 | criterion = {'pixel_loss': nn.MSELoss().to(device), 'vgg_loss': VGGLoss(device)}
43 | gen_opt = optim.RMSprop(gen.parameters(), lr=config['gen_lr'])
44 | disc_opt = optim.RMSprop(disc.parameters(), lr=config['disc_lr'])
45 | # ----END------
46 | start_step, best_step, best_niqe, writer, csv_file = \
47 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt, disc, disc_opt,
48 | csv_file, from_pth, config['auto_lr'])
49 |
50 | if torch.cuda.device_count() > 1:
51 | print("Using GPUs.\n")
52 | gen = torch.nn.DataParallel(gen)
53 | disc = torch.nn.DataParallel(disc)
54 | gen = gen.to(device)
55 | disc = disc.to(device)
56 |
57 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")}
58 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
59 | num_steps = config['num_steps']
60 | iter_of_epoch = 1000
61 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'],
62 | 'milestone': config['milestone'],
63 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
64 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe,
65 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'],
66 | 'disc_k': config['disc_k'], 'gen_k': config['gen_k']}
67 | log_dict = {'D_losses': 0., 'G_losses': 0.,
68 | 'D_losses_real': 0., 'D_losses_fake': 0.,
69 | 'pixel_loss': 0., 'gan_loss': 0.,
70 | 'percep_loss': 0.,
71 | 'F_prob': 0., 'R_prob': 0.}
72 | while global_info['step'] < config['num_steps']:
73 | with tqdm(total=len(train_dataset)) as t:
74 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
75 | global_info['t'] = t
76 | model_utils.train_SRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict, use_WGAN=True)
77 |
78 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe']))
79 |
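
Editor's note: the use_WGAN=True path is implemented in model_utils, which is not part of this listing. For orientation only, a generic WGAN critic step with weight clipping (the usual reason RMSprop is chosen here) looks roughly like:

    def wgan_critic_step(disc, disc_opt, real, fake, clip=0.01):
        # clip weights to keep the critic approximately Lipschitz (original WGAN)
        for p in disc.parameters():
            p.data.clamp_(-clip, clip)
        loss = disc(fake.detach()).mean() - disc(real).mean()
        disc_opt.zero_grad()
        loss.backward()
        disc_opt.step()
        return loss.item()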
--------------------------------------------------------------------------------
/ESRGAN/train_ESRGAN_iter.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import model_utils
3 | import utils
4 | import os
5 | import torch
6 | from torch.backends import cudnn
7 | from torch.utils.tensorboard import SummaryWriter
8 | from VGGLoss import VGGLoss, PerceptualLoss
9 | from model import G, D
10 | from torch import nn, optim
11 | from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset
12 | from torch.utils.data.dataloader import DataLoader
13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
14 |
15 |
16 | def train_model(config, pre_train, from_pth=False):
17 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
18 | outputs_dir = config['outputs_dir']
19 | batch_size = config['batch_size']
20 | utils.mkdirs(outputs_dir)
21 | csv_file = outputs_dir + config['csv_name']
22 | logs_dir = config['logs_dir']
23 |
24 | cudnn.benchmark = True
25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
26 | torch.manual_seed(config['seed'])
27 |
28 | # ----需要修改部分------
29 | print("===> Loading datasets")
30 | train_dataset = ESRGANTrainDataset()
31 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
32 | batch_size=batch_size, shuffle=True, pin_memory=True)
33 | val_dataset = ESRGANValDataset(config['val_file'])
34 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
35 |
36 | print("===> Building model")
37 | gen = G()
38 | disc = D()
39 | if not from_pth:
40 | disc.init_weight()
41 | criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.L1Loss().to(device),
42 | 'vgg_loss': PerceptualLoss({'conv5_4': 1})}
43 | # 'vgg_loss': VGGLoss(device, nn.L1Loss())}
44 | gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr'])
45 | disc_opt = optim.Adam(disc.parameters(), lr=config['disc_lr'])
46 | # ----END------
47 | start_step, best_step, best_niqe, writer, csv_file = \
48 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt,
49 | disc, disc_opt, csv_file, from_pth, config['auto_lr'])
50 |
51 | if torch.cuda.device_count() > 1:
52 | print("Using GPUs.\n")
53 | gen = torch.nn.DataParallel(gen)
54 | disc = torch.nn.DataParallel(disc)
55 | gen = gen.to(device)
56 | disc = disc.to(device)
57 |
58 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")}
59 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
60 | num_steps = config['num_steps']
61 | iter_of_epoch = 100
62 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 'milestone': config['milestone'],
63 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
64 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe,
65 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'],
66 | 'disc_k': config['disc_k'], 'gen_k': config['gen_k']}
67 | log_dict = {'D_losses': 0., 'G_losses': 0.,
68 | 'D_losses_real': 0., 'D_losses_fake': 0.,
69 | 'pixel_loss': 0., 'gan_loss': 0.,
70 | 'percep_loss': 0.,
71 | 'F_prob': 0., 'R_prob': 0.}
72 | while global_info['step'] < config['num_steps']:
73 | with tqdm(total=len(train_dataset)) as t:
74 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
75 | global_info['t'] = t
76 | model_utils.train_ESRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict)
77 |
78 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe']))
79 |
--------------------------------------------------------------------------------
/SRResNet/train_SRGAN_iter.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torchvision
4 | from tqdm import tqdm
5 | import model_utils
6 | import utils
7 | import os
8 | import torch
9 | from torch.backends import cudnn
10 | from torch.utils.tensorboard import SummaryWriter
11 | from VGGLoss import VGGLoss
12 | from model import G, D
13 | from torch import nn, optim
14 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset, DIV2KSubDataset
15 | from torch.utils.data.dataloader import DataLoader
16 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
17 |
18 |
19 | def train_model(config, pre_train, from_pth=False):
20 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
21 | outputs_dir = config['outputs_dir']
22 | batch_size = config['batch_size']
23 | utils.mkdirs(outputs_dir)
24 | csv_file = outputs_dir + config['csv_name']
25 | logs_dir = config['logs_dir']
26 |
27 | cudnn.benchmark = True
28 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
29 | torch.manual_seed(config['seed'])
30 |
31 |     # ---- modify this part as needed ----
32 | print("===> Loading datasets")
33 | train_dataset = DIV2KSubDataset()
34 | # train_dataset = DIV2KDataset(config['train_file'])
35 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
36 | batch_size=batch_size, shuffle=True, pin_memory=True)
37 | val_dataset = SRResNetValDataset(config['val_file'])
38 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
39 |
40 | print("===> Building model")
41 | gen = G()
42 | disc = D()
43 | if not from_pth:
44 | disc.init_weight()
45 | criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.MSELoss().to(device), 'vgg_loss': VGGLoss(device)}
46 | gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr'])
47 | disc_opt = optim.Adam(disc.parameters(), lr=config['disc_lr'])
48 | # ----END------
49 | start_step, best_step, best_niqe, writer, csv_file = \
50 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt,
51 | disc, disc_opt, csv_file, from_pth, config['auto_lr'])
52 |
53 | if torch.cuda.device_count() > 1:
54 | print("Using GPUs.\n")
55 | gen = torch.nn.DataParallel(gen)
56 | disc = torch.nn.DataParallel(disc)
57 | gen = gen.to(device)
58 | disc = disc.to(device)
59 |
60 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")}
61 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
62 | num_steps = config['num_steps']
63 | iter_of_epoch = 1000
64 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'],
65 | 'milestone': config['milestone'],
66 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
67 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe,
68 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'],
69 |                    'disc_k': config['disc_k'], 'gen_k': config['gen_k']}
70 | log_dict = {'D_losses': 0., 'G_losses': 0.,
71 | 'D_losses_real': 0., 'D_losses_fake': 0.,
72 | 'pixel_loss': 0., 'gan_loss': 0.,
73 | 'percep_loss': 0.,
74 | 'F_prob': 0., 'R_prob': 0.}
75 | while global_info['step'] < config['num_steps']:
76 | with tqdm(total=len(train_dataset)) as t:
77 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
78 | global_info['t'] = t
79 | model_utils.train_SRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict)
80 |
81 |     print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe']))
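
Note: `model_utils.train_SRGAN_iter` itself is not included in this dump. A minimal sketch of the alternating update its arguments suggest, assuming `criterion['vgg_loss']` is callable as `(sr, hr)`, `disc_k`/`gen_k` are both >= 1, and one plausible weighting of the loss terms:

    import torch

    def gan_step(gen, disc, lr_img, hr_img, gen_opt, disc_opt, criterion, info):
        lr_img, hr_img = lr_img.to(info['device']), hr_img.to(info['device'])

        for _ in range(info['disc_k']):  # k discriminator updates per step
            disc_opt.zero_grad()
            with torch.no_grad():
                sr = gen(lr_img)
            d_real, d_fake = disc(hr_img), disc(sr)
            d_loss = criterion['bce'](d_real, torch.ones_like(d_real)) \
                   + criterion['bce'](d_fake, torch.zeros_like(d_fake))
            d_loss.backward()
            disc_opt.step()

        for _ in range(info['gen_k']):  # k generator updates per step
            gen_opt.zero_grad()
            sr = gen(lr_img)
            d_fake = disc(sr)
            g_loss = info['pixel_weight'] * criterion['pixel_loss'](sr, hr_img) \
                   + criterion['vgg_loss'](sr, hr_img) \
                   + info['adversarial_weight'] * criterion['bce'](d_fake, torch.ones_like(d_fake))
            g_loss.backward()
            gen_opt.step()
        return d_loss.item(), g_loss.item()
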
82 |
--------------------------------------------------------------------------------
/ESRGAN/ESRGANdatasets.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import lmdb
4 | from PIL import Image
5 | import numpy as np
6 | import albumentations as A
7 | from imresize import imresize
8 | import h5py
9 | from torch.utils.data import Dataset, DataLoader
10 | import torch
11 | import matplotlib.pyplot as plt
12 |
13 | class ESRGANTrainDataset(Dataset):
14 | def __init__(self, root_dirs=['../datasets/DIV2K_train_HR/',
15 | '../datasets/Flickr2K/Flickr2K_HR/',
16 | '../datasets/OST/'
17 | ]):
18 | super(ESRGANTrainDataset, self).__init__()
19 | self.data = []
20 |
21 | for root_dir in root_dirs:
22 |             img_names = os.listdir(root_dir)
23 |             for name in img_names:
24 |                 self.data.append(root_dir + name)
25 |
26 | def __len__(self):
27 | return len(self.data)
28 |
29 | def __getitem__(self, index):
30 | img_path = self.data[index]
31 |
32 | image = np.array(Image.open(img_path))
33 |         transform = A.Compose([
34 |             A.RandomCrop(width=128, height=128),
35 |             A.HorizontalFlip(p=0.5),
36 |             A.RandomRotate90(p=0.5),
37 |         ])
38 |         image = transform(image=image)["image"]
39 | label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1]) / 255.)
40 | data = imresize(image, 1 / 4, 'bicubic')
41 | data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255.)
42 | return data, label
43 |
44 |
45 | class ESRGANValDataset(Dataset):
46 | def __init__(self, h5_file):
47 | super(ESRGANValDataset, self).__init__()
48 | self.h5_file = h5_file
49 |
50 | def __getitem__(self, idx):
51 | with h5py.File(self.h5_file, 'r') as f:
52 | return torch.from_numpy(f['data'][str(idx)][:, :, :] / 255.), \
53 | torch.from_numpy(f['label'][str(idx)][:, :, :] / 255.)
54 |
55 | def __len__(self):
56 | with h5py.File(self.h5_file, 'r') as f:
57 | return len(f['data'])
58 |
59 |
60 | class lmdbDataset(Dataset):
61 | def __init__(self):
62 | super(lmdbDataset, self).__init__()
63 | env = lmdb.open('../datasets/Set5.lmdb', max_dbs=2, readonly=True)
64 | self.data = env.open_db("train_data".encode('ascii'))
65 | self.shape = env.open_db("train_shape".encode('ascii'))
66 | self.txn = env.begin()
67 | self._length = int(self.txn.stat(db=self.data)["entries"] / 2)
68 |
69 | def __getitem__(self, idx):
70 | idx = str(idx)
71 | image = self.txn.get(idx.encode('ascii'), db=self.data)
72 | image = np.frombuffer(image, 'uint8')
73 | buf_meta = self.txn.get((idx+'.meta').encode('ascii'), db=self.shape)
74 | buf_meta = buf_meta.decode('ascii')
75 | H, W, C = [int(s) for s in buf_meta.split(',')]
76 | image = image.reshape(H, W, C)
77 |
78 |         transform = A.Compose([
79 |             A.RandomCrop(width=128, height=128),
80 |             A.HorizontalFlip(p=0.5),
81 |             A.RandomRotate90(p=0.5),
82 |         ])
83 |         image = transform(image=image)["image"]
84 | label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1]) / 255.)
85 | data = imresize(image, 1 / 4, 'bicubic')
86 | data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255.)
87 | return data, label
88 |
89 | def __len__(self):
90 | return self._length
91 |
92 |
93 | def test():
94 | dataset = lmdbDataset()
95 | loader = DataLoader(dataset, batch_size=1, num_workers=0)
96 |
97 |     for low_res, high_res in loader:
98 | # plt.imshow(low_res.cpu().numpy().squeeze(0).transpose([1,2,0]))
99 | # plt.show()
100 | print(low_res.shape)
101 | print(high_res.shape)
102 |
103 |
104 | if __name__ == "__main__":
105 | test()
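
Note: `lmdbDataset` above expects two sub-databases: raw uint8 image bytes in `train_data` under keys "0", "1", ..., and matching "H,W,C" shape strings in `train_shape` under "0.meta", "1.meta", ... A hedged writer sketch for that layout (the path and `map_size` are placeholders; also note `lmdbDataset.__len__` halves the entry count, so with this one-entry-per-image writer you would drop that `/ 2`):

    import lmdb
    import numpy as np
    from PIL import Image

    def write_lmdb(img_paths, lmdb_path='../datasets/Set5.lmdb'):
        env = lmdb.open(lmdb_path, max_dbs=2, map_size=1 << 32)
        data_db = env.open_db(b'train_data')
        shape_db = env.open_db(b'train_shape')
        with env.begin(write=True) as txn:
            for i, path in enumerate(img_paths):
                img = np.array(Image.open(path).convert('RGB'))
                txn.put(str(i).encode('ascii'), img.tobytes(), db=data_db)
                meta = '{},{},{}'.format(*img.shape)  # "H,W,C"
                txn.put(f'{i}.meta'.encode('ascii'), meta.encode('ascii'), db=shape_db)
        env.close()
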
106 |
--------------------------------------------------------------------------------
/SRFlow/test_srflow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import options
3 | import torch
4 | from torch.backends import cudnn
5 | import utils
6 | from SRFlow_model import SRFlowModel
7 | from PIL import Image
8 | import os
9 |
10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
11 | if __name__ == '__main__':
12 | model_name = 'SRFlow'
13 | weight_file = '../weight_file/SRFlow_DF2K_4X.pth'
14 | root_dir = '../datasets/degraded9/'
15 | out_root_dir = f'../test_res/{model_name}_degraded4/'
16 | hr_dir = '../datasets/PIPAL/'
17 |
18 | lr_dirs = os.listdir(root_dir)
19 |     # create one output folder per source image name
20 | out_dirs = os.listdir(hr_dir)
21 | for dir in out_dirs:
22 | dir = dir.split('.')[0] + '/'
23 | utils.mkdirs(out_root_dir + dir)
24 |
25 | scale = 4
26 | padding = scale
27 |
28 | if not os.path.exists(weight_file):
29 |         print(f'Weight file does not exist!\n{weight_file}\n')
30 |         raise FileNotFoundError(weight_file)
31 |
32 | cudnn.benchmark = True
33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
34 | opt = options.parse('./SRFlow_DF2K_4X.yml', is_train=False)
35 | opt['gpu_ids'] = None
36 | opt = options.dict_to_nonedict(opt)
37 | heat = opt['heat']
38 | model = SRFlowModel(opt)
39 | checkpoint = torch.load(weight_file)
40 | model.netG.load_state_dict(checkpoint)
41 | offset = 0
42 | for dir in lr_dirs:
43 | # outputs_dir = out_root_dir + dir + '/'
44 | lr_dir = root_dir + dir + '/'
45 | lr_lists = os.listdir(lr_dir)
46 | Avg_psnr = utils.AverageMeter()
47 | Avg_niqe = utils.AverageMeter()
48 | for imgName in lr_lists:
49 | outputs_dir = out_root_dir + imgName.split('.')[0] + '/'
50 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
51 | image = utils.ImgOffSet(image, offset, offset)
52 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale)
53 | hr_image = utils.ImgOffSet(hr_image, offset*scale, offset*scale)
54 |
55 | img_mode = image.mode
56 | if img_mode == 'L':
57 | gray_img = np.array(image)
58 | image = image.convert('RGB')
59 | lr_image = np.array(image)
60 | hr_image = np.array(hr_image)
61 |
62 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
63 | lr /= 255.
64 | lr = torch.from_numpy(lr).unsqueeze(0)
65 |
66 | # hr_image = hr_image[padding: -padding, padding: -padding, ...]
67 |
68 | SR = model.get_sr(lr, heat)
69 | # SR = SR[..., padding: -padding, padding: -padding]
70 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
71 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
72 | if img_mode != 'L':
73 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.
74 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255.
75 | else:
76 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...]
77 | hr_y = gray_img / 255.
78 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L')
79 | SR_y = np.array(SR).astype(np.float32) / 255.
80 | psnr = utils.calc_psnr(hr_y, SR_y)
81 | # NIQE = niqe.calculate_niqe(SR_y)
82 | Avg_psnr.update(psnr, 1)
83 | # Avg_niqe.update(NIQE, 1)
84 | print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item()))
85 | # print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}'.format(psnr.item(), NIQE))
86 | # GPU tensor -> CPU tensor -> numpy
87 | output = np.array(SR).astype(np.uint8)
88 | output = Image.fromarray(output) # hw -> wh
89 | # output.save(outputs_dir + imgName)
90 | output.save(outputs_dir + dir + '.bmp')
91 | # output.save(outputs_dir + imgName.replace('.', f'_{model_name}_x{scale}.'))
92 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg))
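
Note: the NIQE updates above are commented out, so the printed Average_NIQE is not meaningful here. The repository's `utils.calc_psnr` is also not shown in this excerpt; a minimal torch version consistent with the `psnr.item()` calls and the [0, 1] Y-channel inputs used throughout:

    import torch

    def calc_psnr(img1, img2):
        # inputs in [0, 1]; returns a 0-dim tensor (hence .item() at call sites)
        img1 = torch.as_tensor(img1, dtype=torch.float64)
        img2 = torch.as_tensor(img2, dtype=torch.float64)
        mse = torch.mean((img1 - img2) ** 2)
        return 10 * torch.log10(1.0 / mse)
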
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/ESRGAN/train.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch.utils.tensorboard import SummaryWriter
3 | from tqdm import tqdm
4 | import model_utils
5 | import utils
6 | import os
7 | import torch
8 | from torch.backends import cudnn
9 | from model import G, D
10 | from torch import nn, optim
11 | from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset
12 | from torch.utils.data.dataloader import DataLoader
13 | import numpy as np
14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
15 |
16 |
17 | def train_model(config, from_pth=False):
18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
19 |
20 | outputs_dir = config['outputs_dir']
21 | batch_size = config['batch_size']
22 | num_epochs = config['num_epochs']
23 | utils.mkdirs(outputs_dir)
24 | csv_file = outputs_dir + config['csv_name']
25 | logs_dir = config['logs_dir']
26 | cudnn.benchmark = True
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 | torch.manual_seed(config['seed'])
29 |
30 |     # ---- modify this part as needed ----
31 | print("===> Loading datasets")
32 | train_dataset = ESRGANTrainDataset()
33 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
34 | batch_size=batch_size, shuffle=True, pin_memory=True)
35 | val_dataset = ESRGANValDataset(config['val_file'])
36 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
37 |
38 | print("===> Building model")
39 | model = G()
40 | if not from_pth:
41 | model.init_weight()
42 |     criterion = nn.L1Loss().to(device)
43 | optimizer = optim.Adam(model.parameters(), lr=config['lr'])
44 | # ----END------
45 | start_epoch, best_epoch, best_psnr, writer, csv_file = \
46 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file,
47 | from_pth, auto_lr=config['auto_lr'])
48 |
49 | if torch.cuda.device_count() > 1:
50 | print("Using GPUs.\n")
51 | model = torch.nn.DataParallel(model)
52 | model = model.to(device)
53 |
54 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
55 | writer_test = SummaryWriter(f"{logs_dir}/test")
56 |
57 | for epoch in range(start_epoch, num_epochs):
58 | lr = optimizer.state_dict()['param_groups'][0]['lr']
59 | print(f'learning rate: {lr}\n')
60 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
61 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
62 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)
63 |
64 | if isinstance(model, torch.nn.DataParallel):
65 | model = model.module
66 |
67 | model.eval()
68 | epoch_psnr = utils.AverageMeter()
69 | for idx, data in enumerate(val_dataloader):
70 | inputs, labels = data
71 | inputs = inputs.to(device)
72 | with torch.no_grad():
73 | preds = model(inputs)
74 | img_grid_fake = torchvision.utils.make_grid(preds, normalize=True)
75 | writer_test.add_image(f"Test Fake{idx}", img_grid_fake, global_step=epoch)
76 | preds = preds.mul(255.0).cpu().numpy().squeeze(0)
77 | preds = preds.transpose([1, 2, 0]) #chw->hwc
78 | preds = np.clip(preds, 0.0, 255.0)
79 | preds = utils.rgb2ycbcr(preds).astype(np.float32)[..., 0]/255.
80 | epoch_psnr.update(utils.calc_psnr(preds, labels.numpy()[0, 0, ...]).item(), len(inputs))
81 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
82 | if config['auto_lr'] and (epoch+1) % 232 == 0:
83 | model_utils.update_lr(optimizer, 0.5)
84 |
85 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)
86 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
87 |
88 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses,
89 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
90 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
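
Note: `ESRGAN/gen_datasets.py` is not shown in this excerpt. Inferring from `ESRGANValDataset` and the validation loop above (which reads `labels[0, 0, ...]` as the Y channel), a hedged generator for the validation `.h5` might look like this; the layout, not the exact helper, is the point:

    import os
    import h5py
    import numpy as np
    from PIL import Image
    import utils
    from imresize import imresize

    def gen_esrgan_val_h5(hr_dir, h5_path, scale=4):
        # one CHW dataset per string index, pixel values in [0, 255];
        # the label keeps only the luma channel, matching labels[0, 0, ...] above
        with h5py.File(h5_path, 'w') as f:
            data_g, label_g = f.create_group('data'), f.create_group('label')
            for i, name in enumerate(sorted(os.listdir(hr_dir))):
                hr = np.array(Image.open(os.path.join(hr_dir, name)).convert('RGB'))
                lr = imresize(hr, 1 / scale, 'bicubic').astype(np.float32)
                hr_y = utils.rgb2ycbcr(hr).astype(np.float32)[..., 0]
                data_g.create_dataset(str(i), data=lr.transpose([2, 0, 1]))
                label_g.create_dataset(str(i), data=hr_y[None, ...])
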
--------------------------------------------------------------------------------
/FSRCNN/train_deconv.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from tqdm import tqdm
3 | import model_utils
4 | import utils
5 | import os
6 | import torch
7 | from torch.backends import cudnn
8 | from torch import nn, optim
9 | from model import N2_10_4, CLoss, HuberLoss, FSRCNN
10 | from FSRCNNdatasets import TrainDataset, ValDataset, ResValDataset
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 |
14 | def train_model(config, from_pth=False, pre_train=None):
15 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
16 | outputs_dir = config['outputs_dir']
17 | batch_size = config['batch_size']
18 | lr = config['lr']
19 | num_epochs = config['num_epochs']
20 | logs_dir = config['logs_dir']
21 | utils.mkdirs(outputs_dir)
22 | utils.mkdirs(logs_dir)
23 |
24 | csv_file = outputs_dir + config['csv_name']
25 |
26 | cudnn.benchmark = True
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 | torch.manual_seed(config['seed'])
29 |     # ---- modify this part as needed ----
30 | model = N2_10_4(config['scale'], config['in_size'], config['out_size'],
31 | num_channels=1, d=config['d'], m=config['m'])
32 | # model = FSRCNN(scale, in_size, out_size, num_channels=1, d=config['d'], m=config['m'])
33 |     if config['Loss'] == 'CLoss':
34 |         criterion = CLoss(delta=config['delta']).to(device)
35 |     elif config['Loss'] == 'Huber':
36 |         criterion = HuberLoss(delta=config['delta']).to(device)
37 |     else:
38 |         criterion = nn.MSELoss().to(device)
39 |
40 |     optimizer = optim.SGD([
41 |         {'params': model.deconv_layer.parameters(), 'lr': lr},
42 |     ], lr=lr, momentum=0.9)  # only the deconv layer is optimized; the other layers keep their pre-trained weights
43 | train_dataset = TrainDataset(config['train_file'])
44 | train_dataloader = DataLoader(dataset=train_dataset,
45 | batch_size=batch_size,
46 | shuffle=True,
47 | num_workers=config['num_workers'],
48 | pin_memory=True)
49 | if config['residual']:
50 | val_dataset = ResValDataset(config['val_file'])
51 | else:
52 | val_dataset = ValDataset(config['val_file'])
53 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
54 | # ----END------
55 | start_epoch, best_epoch, best_psnr, writer, csv_file = \
56 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, from_pth, config['auto_lr'])
57 |
58 | if not from_pth:
59 | if not os.path.exists(pre_train):
60 |             print(f'Weight file does not exist!\n{pre_train}\n')
61 |             raise FileNotFoundError(pre_train)
62 | checkpoint = torch.load(pre_train)
63 | model.load_state_dict(checkpoint['model'])
64 | model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001)
65 | model.deconv_layer.bias.data.zero_()
66 |
67 | if torch.cuda.device_count() > 1:
68 | print("Using GPUs.\n")
69 | model = torch.nn.DataParallel(model)
70 | model = model.to(device)
71 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
72 |
73 | for epoch in range(start_epoch, num_epochs):
74 | lr = optimizer.state_dict()['param_groups'][0]['lr']
75 | print(f'learning rate: {lr}\n')
76 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
77 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
78 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)
79 |
80 | if isinstance(model, torch.nn.DataParallel):
81 | model = model.module
82 |
83 | epoch_psnr = model_utils.validate(model, val_dataloader, device, config['residual'])
84 |
85 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
86 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)
87 |
88 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses,
89 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
90 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
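
Note: since only `model.deconv_layer.parameters()` are handed to the optimizer above, the pre-trained extraction and mapping layers never update. Optionally freezing them also skips their gradient computation; a short sketch (an optimization, not something the original does):

    for p in model.extract_layer.parameters():
        p.requires_grad_(False)
    for p in model.mid_part.parameters():
        p.requires_grad_(False)
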
--------------------------------------------------------------------------------
/SRResNet/train.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch.utils.tensorboard import SummaryWriter
3 | from tqdm import tqdm
4 | import model_utils
5 | import utils
6 | import os
7 | import torch
8 | from torch.backends import cudnn
9 | from model import G, D
10 | from torch import nn, optim
11 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset
12 | from torch.utils.data.dataloader import DataLoader
13 | import numpy as np
14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
15 |
16 |
17 | def train_model(config, from_pth=False):
18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
19 |
20 | outputs_dir = config['outputs_dir']
21 | batch_size = config['batch_size']
22 | num_epochs = config['num_epochs']
23 | utils.mkdirs(outputs_dir)
24 | csv_file = outputs_dir + config['csv_name']
25 | logs_dir = config['logs_dir']
26 | cudnn.benchmark = True
27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
28 | torch.manual_seed(config['seed'])
29 |
30 |     # ---- modify this part as needed ----
31 | print("===> Loading datasets")
32 | train_dataset = DIV2KDataset(config['train_file'])
33 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
34 | batch_size=batch_size, shuffle=True, pin_memory=True)
35 | val_dataset = SRResNetValDataset(config['val_file'])
36 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
37 |
38 | print("===> Building model")
39 | model = G()
40 | if not from_pth:
41 | model.init_weight()
42 |     criterion = nn.MSELoss().to(device)
43 | optimizer = optim.Adam(model.parameters(), lr=config['lr'])
44 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=50)
45 | # ----END------
46 | start_epoch, best_epoch, best_psnr, writer, csv_file = \
47 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file,
48 | from_pth, auto_lr=config['auto_lr'])
49 |
50 | if torch.cuda.device_count() > 1:
51 | print("Using GPUs.\n")
52 | model = torch.nn.DataParallel(model)
53 | model = model.to(device)
54 |
55 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
56 | writer_test = SummaryWriter(f"{logs_dir}/test")
57 |
58 | for epoch in range(start_epoch, num_epochs):
59 | lr = optimizer.state_dict()['param_groups'][0]['lr']
60 | print(f'learning rate: {lr}\n')
61 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
62 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
63 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)
64 |
65 | if isinstance(model, torch.nn.DataParallel):
66 | model = model.module
67 |
68 | model.eval()
69 | epoch_psnr = utils.AverageMeter()
70 | for idx, data in enumerate(val_dataloader):
71 | inputs, labels = data
72 | inputs = inputs.to(device)
73 | with torch.no_grad():
74 | preds = model(inputs) * 0.5 + 0.5
75 | img_grid_fake = torchvision.utils.make_grid(preds, normalize=True)
76 | writer_test.add_image(f"Test Fake{idx}", img_grid_fake, global_step=epoch)
77 | preds = preds.mul(255.0).cpu().numpy().squeeze(0)
78 | preds = preds.transpose([1, 2, 0]) #chw->hwc
79 | preds = np.clip(preds, 0.0, 255.0)
80 | preds = utils.rgb2ycbcr(preds).astype(np.float32)[..., 0]/255.
81 | epoch_psnr.update(utils.calc_psnr(preds, labels.numpy()[0, 0, ...]).item(), len(inputs))
82 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
83 | if config['auto_lr']:
84 | scheduler.step(epoch_psnr.avg)
85 |
86 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)
87 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
88 |
89 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses,
90 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
91 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
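
Note: `ReduceLROnPlateau` with `mode='max'` decays the learning rate (default factor 0.1) once the monitored PSNR stops improving for `patience` epochs, which is why the loop feeds it `epoch_psnr.avg`. A self-contained illustration:

    import torch

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', patience=50)
    for psnr in [30.0] * 60:          # a flat PSNR curve triggers one decay
        sched.step(psnr)
    print(opt.param_groups[0]['lr'])  # 1e-5 once the patience window expires
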
--------------------------------------------------------------------------------
/BSRGAN/bsrgan_model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 |
7 |
8 | def initialize_weights(net_l, scale=1):
9 | if not isinstance(net_l, list):
10 | net_l = [net_l]
11 | for net in net_l:
12 | for m in net.modules():
13 | if isinstance(m, nn.Conv2d):
14 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
15 | m.weight.data *= scale # for residual block
16 | if m.bias is not None:
17 | m.bias.data.zero_()
18 | elif isinstance(m, nn.Linear):
19 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 | m.weight.data *= scale
21 | if m.bias is not None:
22 | m.bias.data.zero_()
23 | elif isinstance(m, nn.BatchNorm2d):
24 | init.constant_(m.weight, 1)
25 | init.constant_(m.bias.data, 0.0)
26 |
27 |
28 | def make_layer(block, n_layers):
29 | layers = []
30 | for _ in range(n_layers):
31 | layers.append(block())
32 | return nn.Sequential(*layers)
33 |
34 |
35 | class ResidualDenseBlock_5C(nn.Module):
36 | def __init__(self, nf=64, gc=32, bias=True):
37 | super(ResidualDenseBlock_5C, self).__init__()
38 | # gc: growth channel, i.e. intermediate channels
39 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
40 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
41 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
42 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
43 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
44 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
45 |
46 | # initialization
47 | initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
48 |
49 | def forward(self, x):
50 | x1 = self.lrelu(self.conv1(x))
51 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
52 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
53 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
54 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
55 | return x5 * 0.2 + x
56 |
57 |
58 | class RRDB(nn.Module):
59 | '''Residual in Residual Dense Block'''
60 |
61 | def __init__(self, nf, gc=32):
62 | super(RRDB, self).__init__()
63 | self.RDB1 = ResidualDenseBlock_5C(nf, gc)
64 | self.RDB2 = ResidualDenseBlock_5C(nf, gc)
65 | self.RDB3 = ResidualDenseBlock_5C(nf, gc)
66 |
67 | def forward(self, x):
68 | out = self.RDB1(x)
69 | out = self.RDB2(out)
70 | out = self.RDB3(out)
71 | return out * 0.2 + x
72 |
73 |
74 | class RRDBNet(nn.Module):
75 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
76 | super(RRDBNet, self).__init__()
77 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
78 | self.sf = sf
79 | print([in_nc, out_nc, nf, nb, gc, sf])
80 |
81 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
82 | self.RRDB_trunk = make_layer(RRDB_block_f, nb)
83 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
84 | #### upsampling
85 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
86 | if self.sf==4:
87 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
88 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
89 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
90 |
91 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
92 |
93 | def forward(self, x):
94 | fea = self.conv_first(x)
95 | trunk = self.trunk_conv(self.RRDB_trunk(fea))
96 | fea = fea + trunk
97 |
98 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
99 | if self.sf==4:
100 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
101 | out = self.conv_last(self.lrelu(self.HRconv(fea)))
102 |
103 | return out
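
Note: a quick shape check for the generator above (for `sf=4` it applies two nearest-neighbor 2x stages; each RRDB adds its input back scaled by 0.2):

    import torch

    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)
    with torch.no_grad():
        y = net(torch.randn(1, 3, 32, 32))
    print(y.shape)  # torch.Size([1, 3, 128, 128])
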
--------------------------------------------------------------------------------
/BSRGAN/test_bsrgan_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 |
3 | import numpy as np
4 | import torch
5 | from torch.backends import cudnn
6 |
7 | import niqe
8 | import utils
9 | from bsrgan_model import RRDBNet
10 | from PIL import Image
11 | from imresize import imresize
12 | import os
13 |
14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
15 | if __name__ == '__main__':
16 | model_name = 'BSRGAN'
17 | weight_file = './weight_file/BSRGAN.pth'
18 | root_dir = '/data0/jli/datasets/degraded_17/'
19 | out_root_dir = f'./test_res/{model_name}_degraded17/'
20 | hr_dir = '/data0/jli/datasets/PIPAL/'
21 |
22 | csv_file = f'./test_res/{model_name}_x4_degraded17.csv'
23 | csv_file = open(csv_file, 'w', newline='')
24 | writer = csv.writer(csv_file)
25 | writer.writerow(('name', 'psnr', 'niqe', 'ssim'))
26 |
27 | lr_dirs = os.listdir(root_dir)
28 | for dir in lr_dirs:
29 | utils.mkdirs(out_root_dir + dir)
30 |
31 | scale = 4
32 | padding = scale
33 |
34 |
35 | if not os.path.exists(weight_file):
36 |         print(f'Weight file does not exist!\n{weight_file}\n')
37 |         raise FileNotFoundError(weight_file)
38 |
39 | cudnn.benchmark = True
40 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
41 |
42 | model = RRDBNet().to(device)
43 | checkpoint = torch.load(weight_file)
44 | model.load_state_dict(checkpoint)
45 | model.eval()
46 |
47 | for dir in lr_dirs:
48 | outputs_dir = out_root_dir + dir + '/'
49 | lr_dir = root_dir + dir + '/'
50 | lr_lists = os.listdir(lr_dir)
51 | Avg_psnr = utils.AverageMeter()
52 | Avg_niqe = utils.AverageMeter()
53 | Avg_ssim = utils.AverageMeter()
54 | for imgName in lr_lists:
55 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
56 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale)
57 | img_mode = image.mode
58 | if img_mode == 'L':
59 | gray_img = np.array(image)
60 | image = image.convert('RGB')
61 | lr_image = np.array(image)
62 | hr_image = np.array(hr_image)
63 |
64 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
65 | lr /= 255.
66 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
67 |
68 | # hr_image = hr_image[padding: -padding, padding: -padding, ...]
69 |
70 | with torch.no_grad():
71 | SR = model(lr)
72 | # SR = SR[..., padding: -padding, padding: -padding]
73 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
74 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
75 | if img_mode != 'L':
76 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.
77 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255.
78 | else:
79 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...]
80 | hr_y = gray_img / 255.
81 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L')
82 | SR_y = np.array(SR).astype(np.float32) / 255.
83 | psnr = utils.calc_psnr(hr_y, SR_y)
84 | NIQE = niqe.calculate_niqe(SR_y)
85 | ssim = utils.calculate_ssim(hr_y * 255, SR_y * 255)
86 | Avg_psnr.update(psnr, 1)
87 | Avg_niqe.update(NIQE, 1)
88 | Avg_ssim.update(ssim, 1)
89 | # print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item()))
90 | print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}, ssim: {:.4f}'.format(psnr.item(), NIQE, ssim))
91 | # GPU tensor -> CPU tensor -> numpy
92 | output = np.array(SR).astype(np.uint8)
93 | output = Image.fromarray(output) # hw -> wh
94 | output.save(outputs_dir + imgName.replace('.', '_{:.2f}.'.format(psnr.item())))
95 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}, Average_ssim: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg,
96 | Avg_ssim.avg))
97 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg))
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/FSRCNN/gen_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 | from tqdm import tqdm
5 | import utils
6 | from imresize import imresize
7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
8 |
9 |
10 | def gen_traindata(config):
11 | scale = config["scale"]
12 | stride = config["stride"]
13 | size_input = config["size_input"]
14 | size_label = size_input * scale
15 | size_output = config['size_output']
16 | padding = abs(size_label - size_output) // 2
17 | if scale == 3:
18 | padding2 = padding
19 | else:
20 | padding2 = padding + 1
21 | method = config['method']
22 | if config['residual']:
23 | h5savepath = config["hrDir"] + f'_label={size_output}_train_FSRCNNx{scale}_res.h5'
24 | else:
25 | h5savepath = config["hrDir"] + f'_label={size_output}_train_FSRCNNx{scale}.h5'
26 | hrDir = config["hrDir"] + '/'
27 | h5_file = h5py.File(h5savepath, 'w')
28 | imgList = os.listdir(hrDir)
29 | lr_subimgs = []
30 | hr_subimgs = []
31 | with tqdm(total=len(imgList)) as t:
32 | for imgName in imgList:
33 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale)
34 | hr_y = utils.img2ycbcr(hrIMG)[..., 0]
35 | lr_y = imresize(hr_y, 1 / scale, method).astype(np.float32)
36 | # residual
37 | if config['residual']:
38 | lr_y_upscale = imresize(lr_y, scale, method).astype(np.float32)
39 | hr_y = hr_y - lr_y_upscale
40 |
41 | for r in range(0, lr_y.shape[0] - size_input + 1, stride):
42 | for c in range(0, lr_y.shape[1] - size_input + 1, stride):
43 | lr_subimgs.append(lr_y[r: r + size_input, c: c + size_input])
44 | label = hr_y[r * scale: r * scale + size_label, c * scale: c * scale + size_label]
45 | label = label[padding: -padding2, padding: -padding2]
46 | hr_subimgs.append(label)
47 | t.update(1)
48 |
49 | lr_subimgs = np.array(lr_subimgs).astype(np.float32)
50 | hr_subimgs = np.array(hr_subimgs).astype(np.float32)
51 |
52 | h5_file.create_dataset('data', data=lr_subimgs)
53 | h5_file.create_dataset('label', data=hr_subimgs)
54 |
55 | h5_file.close()
56 |
57 |
58 | def gen_valdata(config):
59 | scale = config["scale"]
60 | size_input = config["size_input"]
61 | size_label = size_input * scale
62 | size_output = config['size_output']
63 | padding = (size_label - size_output) // 2
64 | method = config['method']
65 | if scale == 3:
66 | padding2 = padding
67 | else:
68 | padding2 = padding + 1
69 | if config['residual']:
70 | h5savepath = config["hrDir"] + f'_label={size_output}_val_FSRCNNx{scale}_res.h5'
71 | else:
72 | h5savepath = config["hrDir"] + f'_label={size_output}_val_FSRCNNx{scale}.h5'
73 | hrDir = config["hrDir"] + '/'
74 | h5_file = h5py.File(h5savepath, 'w')
75 | lr_group = h5_file.create_group('data')
76 | hr_group = h5_file.create_group('label')
77 | if config['residual']:
78 | bic_group = h5_file.create_group('bicubic')
79 | imgList = os.listdir(hrDir)
80 | for i, imgName in enumerate(imgList):
81 | hrIMG = utils.loadIMG_crop(hrDir+imgName, scale)
82 | hr_y = utils.img2ycbcr(hrIMG)[..., 0]
83 |
84 | lr_y = imresize(hr_y, 1 / scale, method).astype(np.float32)
85 | # residual
86 | if config['residual']:
87 | bic_y = imresize(lr_y, scale, method).astype(np.float32)
88 | bic_y = bic_y[padding: -padding2, padding: -padding2]
89 | label = hr_y.astype(np.float32)[padding: -padding2, padding: -padding2]
90 |
91 | lr_group.create_dataset(str(i), data=lr_y)
92 | hr_group.create_dataset(str(i), data=label)
93 | if config['residual']:
94 | bic_group.create_dataset(str(i), data=bic_y)
95 | h5_file.close()
96 |
97 | if __name__ == '__main__':
98 | # config = {'hrDir': './test/flower', 'scale': 3, "stride": 14, "size_input": 33, "size_label": 21}
99 | config = {'hrDir': '../datasets/T91_aug', 'scale': 4, 'stride': 10, "size_input": 10, "size_output": 21, "residual": True, 'method': 'bicubic'}
100 | gen_traindata(config)
101 | config['hrDir'] = '../datasets/Set5'
102 | gen_valdata(config)
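
Note: a worked example of the crop arithmetic above for the `config` used here (`scale=4`, `size_input=10`, `size_output=21`):

    size_input, scale, size_output = 10, 4, 21
    size_label = size_input * scale                     # 40
    padding = abs(size_label - size_output) // 2        # 9
    padding2 = padding if scale == 3 else padding + 1   # 10
    # label[padding:-padding2] keeps 40 - 9 - 10 = 21 pixels per axis
    assert size_label - padding - padding2 == size_output
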
103 |
--------------------------------------------------------------------------------
/HCFlow/ActNorms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | import thops
5 |
6 |
7 | class _ActNorm(nn.Module):
8 | """
9 | Activation Normalization
10 | Initialize the bias and scale with a given minibatch,
11 |     so that the per-channel output has zero mean and unit variance for that batch.
12 |
13 | After initialization, `bias` and `logs` will be trained as parameters.
14 | """
15 |
16 | def __init__(self, num_features, scale=1.):
17 | super().__init__()
18 | # register mean and scale
19 | size = [1, num_features, 1, 1]
20 | self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
21 | self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
22 | self.num_features = num_features
23 | self.scale = float(scale)
24 | self.inited = False
25 |
26 | def _check_input_dim(self, input):
27 |         raise NotImplementedError
28 |
29 | def initialize_parameters(self, input):
30 | self._check_input_dim(input)
31 | if not self.training:
32 | return
33 | if (self.bias != 0).any():
34 | self.inited = True
35 | return
36 | assert input.device == self.bias.device, (input.device, self.bias.device)
37 | with torch.no_grad():
38 | bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
39 | vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
40 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
41 | self.bias.data.copy_(bias.data)
42 | self.logs.data.copy_(logs.data)
43 | self.inited = True
44 |
45 | def _center(self, input, reverse=False, offset=None):
46 | bias = self.bias
47 |
48 | if offset is not None:
49 | bias = bias + offset
50 |
51 | if not reverse:
52 | return input + bias
53 | else:
54 | return input - bias
55 |
56 | def _scale(self, input, logdet=None, reverse=False, offset=None):
57 | logs = self.logs
58 |
59 | if offset is not None:
60 | logs = logs + offset
61 |
62 | if not reverse:
63 | input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
64 | # input = input * torch.exp(logs+logs_offset)
65 | else:
66 | input = input * torch.exp(-logs)
67 | if logdet is not None:
68 | """
69 | logs is log_std of `mean of channels`
70 | so we need to multiply pixels
71 | """
72 | dlogdet = thops.sum(logs) * thops.pixels(input)
73 | if reverse:
74 | dlogdet *= -1
75 | logdet = logdet + dlogdet
76 | return input, logdet
77 |
78 | def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
79 | if not self.inited:
80 | self.initialize_parameters(input)
81 |
82 | if offset_mask is not None:
83 | logs_offset *= offset_mask
84 | bias_offset *= offset_mask
85 | # no need to permute dims as old version
86 | if not reverse:
87 | # center and scale
88 | input = self._center(input, reverse, bias_offset)
89 | input, logdet = self._scale(input, logdet, reverse, logs_offset)
90 | else:
91 | # scale and center
92 | input, logdet = self._scale(input, logdet, reverse, logs_offset)
93 | input = self._center(input, reverse, bias_offset)
94 | return input, logdet
95 |
96 |
97 | class ActNorm2d(_ActNorm):
98 | def __init__(self, num_features, scale=1.):
99 | super().__init__(num_features, scale)
100 |
101 | def _check_input_dim(self, input):
102 | assert len(input.size()) == 4
103 | assert input.size(1) == self.num_features, (
104 | "[ActNorm]: input should be in shape as `BCHW`,"
105 | " channels should be {} rather than {}".format(
106 | self.num_features, input.size()))
107 |
108 |
109 | class MaskedActNorm2d(ActNorm2d):
110 | def __init__(self, num_features, scale=1.):
111 | super().__init__(num_features, scale)
112 |
113 | def forward(self, input, mask, logdet=None, reverse=False):
114 |
115 | assert mask.dtype == torch.bool
116 | output, logdet_out = super().forward(input, logdet, reverse)
117 |
118 | input[mask] = output[mask]
119 | logdet[mask] = logdet_out[mask]
120 |
121 | return input, logdet
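
Note: the log-determinant bookkeeping in `_scale` uses the fact that scaling every pixel of channel c by exp(logs_c) contributes sum(logs) * H * W per sample, which is what `thops.sum(logs) * thops.pixels(input)` computes. A small numerical check:

    import torch

    C, H, W = 4, 8, 8
    logs = torch.randn(1, C, 1, 1)
    analytic = logs.sum() * H * W                # value used in _scale()
    elementwise = logs.expand(1, C, H, W).sum()  # log|det J| summed entry by entry
    assert torch.allclose(analytic, elementwise)
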
122 |
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SuperRestoration
2 | Includes: SRCNN, FSRCNN, SRResNet, SRGAN
3 |
4 | # Introduction
5 | ## Why I wrote this
6 | There are many ***PyTorch*** implementations of these networks on the web, but they often deviate from what the papers describe, so their results differ noticeably from the published ones.
7 | So I provide a version that stays as close to the papers as possible.
8 | Hopefully it will help those interested in **Super-Resolution** networks get started.
9 | For more details: https://zhuanlan.zhihu.com/p/431724297
10 | # Install
11 | You can clone this repository and run it directly; no installation step is required.
12 |
13 | ## Running Environment
14 | * Python 3.7, 64-bit (PyTorch)
15 | * Windows 10
16 |
17 | ## Reference
18 | Because `bicubic` interpolation in Python differs from MATLAB's,
19 | while the papers use MATLAB to generate datasets and evaluate PSNR, I use a ***Python*** implementation of MATLAB's `imresize()`.
20 | [Here is the author's repository.](https://github.com/fatheral/matlab_imresize.git)
21 | Similarly, I provide Python versions of MATLAB's `rgb2ycbcr()` and `ycbcr2rgb()`.
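For reference, a minimal MATLAB-compatible `rgb2ycbcr()` looks like this (the coefficients follow MATLAB's convention; the repository ships its own version in `utils`):

```python
import numpy as np

def rgb2ycbcr(img):
    """img: HxWx3 array in [0, 255]; returns YCbCr in [0, 255], MATLAB convention."""
    img = img / 255.0
    y  =  16.0 +  65.481 * img[..., 0] + 128.553 * img[..., 1] +  24.966 * img[..., 2]
    cb = 128.0 -  37.797 * img[..., 0] -  74.203 * img[..., 1] + 112.0   * img[..., 2]
    cr = 128.0 + 112.0   * img[..., 0] -  93.786 * img[..., 1] -  18.214 * img[..., 2]
    return np.stack([y, cb, cr], axis=-1)
```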
22 |
23 | # Usage
24 | ## Prepare Datasets
25 | * Run `data_aug.py` to augment datasets.
26 | * Run `gen_datasets.py` to generate training and validation data. (You may need to modify parameters in `config`.)
27 | ## Train
28 | Take **SRCNN** as an example: run `SRCNN_x2.py` to train SRCNN. You can modify the training parameters to your needs following this template.
29 | ## Test
30 | Run `test.py` to generate test results and calculate PSNR. (You can modify parameters to specify test sets.)
31 | ## Visualize
32 | Run `csv2visdom.py` to visualize the convergence curve with visdom. (You need to install `visdom` and start it in advance.)
33 | Then visit `localhost:8097`.
34 |
35 | # Result: PSNR
36 | ## SRCNN x3
37 | |Set5|Paper|Ours|
38 | ----|---|---|
39 | baby|35.01|34.96|
40 | bird|34.91|34.95|
41 | butterfly|27.58|27.77|
42 | head|33.55|33.51|
43 | woman|30.92|30.99|
44 | |average|32.39|32.43|
45 |
46 | |Set14|Paper|Ours|
47 | ----|-----|------|
48 | baboon|23.60|23.60|
49 | barbara|26.66|26.71|
50 | bridge|25.07|25.08|
51 | coastguard|27.20|27.17|
52 | comic|24.39|24.42|
53 | face|33.58|33.54|
54 | flowers|28.97|29.01|
55 | foreman|33.35|33.32|
56 | lenna|33.39|33.40|
57 | man|28.18|28.18|
58 | monarch|32.39|32.54|
59 | pepper|34.35|34.24|
60 | ppt3|26.02|26.14|
61 | zebra|28.87|28.80|
62 | |average|29.00|29.01|
63 |
64 | ## FSRCNN x3
65 | Train on 91-images.
66 |
67 | | |Paper| Ours|
68 | ----|---|---|
69 | Set5|33.06|33.06|
70 | Set14|29.37|29.35|
71 | BSDS200|28.55|28.95|
72 |
73 | ## SRResNet x4
74 | Train on DIV2K.
75 |
76 | | |Paper| Ours|
77 | ----|---|---|
78 | Set5|32.05|32.12|
79 | Set14|28.49|28.50|
80 | BSDS100|27.58|27.54|
81 |
82 | ## SRGAN x4
83 | Train on DIV2K.
84 |
85 | | |Paper| Ours|
86 | ----|---|---|
87 | Set5|29.40|30.19|
88 | Set14|26.02|26.94|
89 | BSDS100|25.16|25.82|
90 |
91 | SRGAN cannot be evaluated by PSNR alone, so I list some test results.
92 | Obviously, SRGAN generates sharper results than SRResNet and looks more convincing.
93 |
94 | *(Comparison images: bicubic / SRResNet / SRGAN / original; omitted here.)*
103 |
104 |
--------------------------------------------------------------------------------
/FSRCNN/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import nn
3 | import torch
4 |
5 |
6 | class FSRCNN(nn.Module):
7 | def __init__(self, scale_factor, in_size, out_size, num_channels=1, d=56, s=12, m=4):
8 | super(FSRCNN, self).__init__()
9 | self.extract_layer = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=2, padding_mode='replicate'),
10 | nn.PReLU())
11 |
12 | self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU()]
13 | for i in range(m):
14 | self.mid_part.extend([nn.ReplicationPad2d(1),
15 | nn.Conv2d(s, s, kernel_size=3),
16 | nn.PReLU()])
17 | self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU()])
18 | self.mid_part = nn.Sequential(*self.mid_part)
19 |
20 |         # deconv maps the in_size feature map to out_size (stride = scale_factor)
21 | self.deconv_layer = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor,
22 | padding=(9 + (in_size-1)*scale_factor - out_size)//2)
23 |
24 | def init_weights(self, method='MSRA'):
25 | if method == 'MSRA':
26 | init_weights_MSRA(self)
27 | else:
28 | init_weights_Xavier(self)
29 |
30 | def forward(self, x):
31 | x = self.extract_layer(x)
32 | x = self.mid_part(x)
33 | x = self.deconv_layer(x)
34 | return x
35 |
36 |
37 | class N2_10_4(nn.Module):
38 | def __init__(self, scale_factor, in_size, out_size, num_channels=1, d=10, m=4):
39 | super(N2_10_4, self).__init__()
40 | self.extract_layer = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=3, padding=1, padding_mode='replicate'),
41 | nn.PReLU())
42 | self.mid_part = []
43 | for i in range(m):
44 | self.mid_part.extend([nn.Conv2d(d, d, kernel_size=3, padding=1, padding_mode='replicate'),
45 | nn.PReLU()])
46 | self.mid_part = nn.Sequential(*self.mid_part)
47 |
48 |         # deconv maps the in_size feature map to out_size (stride = scale_factor)
49 | self.deconv_layer = nn.ConvTranspose2d(d, num_channels, kernel_size=7, stride=scale_factor,
50 | padding=(7 + (in_size-1)*scale_factor - out_size)//2)
51 |
52 | def forward(self, x):
53 | x = self.extract_layer(x)
54 | x = self.mid_part(x)
55 | x = self.deconv_layer(x)
56 | return x
57 |
58 | def init_weights(self, method='MSRA'):
59 | if method == 'MSRA':
60 | init_weights_MSRA(self)
61 | else:
62 | init_weights_Xavier(self)
63 |
64 |
65 | class CLoss(nn.Module):
66 |     """L1 Charbonnier loss."""
67 | def __init__(self, delta=1e-3):
68 | super(CLoss, self).__init__()
69 | self.delta2 = delta * delta
70 |
71 | def forward(self, X, Y):
72 | diff = torch.add(X, -Y)
73 | error = torch.sqrt(diff * diff + self.delta2)
74 | loss = torch.mean(error)
75 | return loss
76 |
77 |
78 | class HuberLoss(nn.Module):
79 | """Huber loss."""
80 | def __init__(self, delta=2e-4):
81 | super(HuberLoss, self).__init__()
82 | self.delta = delta
83 | self.delta2 = delta * delta
84 |
85 | def forward(self, X, Y):
86 |         abs_diff = torch.abs(X - Y)
87 |         cond = torch.less_equal(abs_diff, self.delta)
88 |         quadratic_loss = 0.5 * abs_diff * abs_diff
89 |         linear_loss = self.delta * abs_diff - 0.5 * self.delta2
90 |         error = torch.where(cond, quadratic_loss, linear_loss)
91 | loss = torch.mean(error)
92 | return loss
93 |
94 |
95 | def init_weights_MSRA(Model):
96 | for L in Model.extract_layer:
97 | if isinstance(L, nn.Conv2d):
98 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels * L.weight.data[0][0].numel())))
99 | L.bias.data.zero_()
100 | for L in Model.mid_part:
101 | if isinstance(L, nn.Conv2d):
102 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels * L.weight.data[0][0].numel())))
103 | L.bias.data.zero_()
104 | Model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001)
105 | Model.deconv_layer.bias.data.zero_()
106 |
107 | def init_weights_Xavier(Model):
108 | for L in Model.extract_layer:
109 | if isinstance(L, nn.Conv2d):
110 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels + L.in_channels)))
111 | L.bias.data.zero_()
112 | for L in Model.mid_part:
113 | if isinstance(L, nn.Conv2d):
114 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels + L.in_channels)))
115 | L.bias.data.zero_()
116 | Model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001)
117 | Model.deconv_layer.bias.data.zero_()
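
Note: `CLoss` is the Charbonnier loss sqrt(d^2 + delta^2), a smooth approximation of L1, while `HuberLoss` is quadratic below `delta` and linear above it. A hedged sanity check against PyTorch's built-in `torch.nn.HuberLoss` (available in PyTorch >= 1.9):

    import torch
    from torch import nn

    x, y = torch.randn(4, 1, 8, 8), torch.randn(4, 1, 8, 8)
    ours = HuberLoss(delta=0.1)(x, y)
    ref = nn.HuberLoss(delta=0.1)(x, y)
    assert torch.allclose(ours, ref, atol=1e-6)
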
--------------------------------------------------------------------------------
/SwinIR/test_swinir_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 |
3 | import numpy as np
4 | import torch
5 | from torch.backends import cudnn
6 | import niqe
7 | import utils
8 | from swinir_model import SwinIR
9 | from PIL import Image
10 | import os
11 |
12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
13 | if __name__ == '__main__':
14 | model_name = 'SwinIR-M'
15 | csv_file = f'./test_res/{model_name}_x4_degraded17.csv'
16 | csv_file = open(csv_file, 'w', newline='')
17 | writer = csv.writer(csv_file)
18 | writer.writerow(('name', 'psnr', 'niqe', 'ssim'))
19 |
20 | weight_file = './weight_file/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth'
21 | is_large = False
22 | root_dir = '/data0/jli/datasets/degraded_17/'
23 | out_root_dir = f'./test_res/{model_name}_degraded17/'
24 | hr_dir = '/data0/jli/datasets/PIPAL/'
25 |
26 | lr_dirs = os.listdir(root_dir)
27 | for dir in lr_dirs:
28 | utils.mkdirs(out_root_dir + dir)
29 |
30 | scale = 4
31 | padding = scale
32 |
33 | if not os.path.exists(weight_file):
34 |         print(f'Weight file does not exist!\n{weight_file}\n')
35 |         raise FileNotFoundError(weight_file)
36 |
37 | cudnn.benchmark = True
38 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
39 | if not is_large:
40 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
41 | img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
42 | mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv').to(device)
43 | else:
44 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
45 | img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
46 | num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
47 | mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv').to(device)
48 | checkpoint = torch.load(weight_file)
49 | model.load_state_dict(checkpoint['params_ema'])
50 | model.eval()
51 |
52 | for dir in lr_dirs:
53 | outputs_dir = out_root_dir + dir + '/'
54 | lr_dir = root_dir + dir + '/'
55 | lr_lists = os.listdir(lr_dir)
56 | Avg_psnr = utils.AverageMeter()
57 | Avg_niqe = utils.AverageMeter()
58 | Avg_ssim = utils.AverageMeter()
59 | for imgName in lr_lists:
60 | image = utils.loadIMG_crop(lr_dir + imgName, scale)
61 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale)
62 | img_mode = image.mode
63 | if img_mode == 'L':
64 | gray_img = np.array(image)
65 | image = image.convert('RGB')
66 | lr_image = np.array(image)
67 | hr_image = np.array(hr_image)
68 |
69 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw
70 | lr /= 255.
71 | lr = torch.from_numpy(lr).to(device).unsqueeze(0)
72 |
73 | # hr_image = hr_image[padding: -padding, padding: -padding, ...]
74 |
75 | with torch.no_grad():
76 | SR = model(lr)
77 | # SR = SR[..., padding: -padding, padding: -padding]
78 | SR = SR.mul(255.0).cpu().numpy().squeeze(0)
79 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
80 | if img_mode != 'L':
81 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.
82 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255.
83 | else:
84 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...]
85 | hr_y = gray_img / 255.
86 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L')
87 | SR_y = np.array(SR).astype(np.float32) / 255.
88 | psnr = utils.calc_psnr(hr_y, SR_y)
89 | NIQE = niqe.calculate_niqe(SR_y)
90 | ssim = utils.calculate_ssim(hr_y * 255, SR_y * 255)
91 | Avg_psnr.update(psnr, 1)
92 | Avg_niqe.update(NIQE, 1)
93 | Avg_ssim.update(ssim, 1)
94 | # print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item()))
95 | print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}, ssim: {:.4f}'.format(psnr.item(), NIQE, ssim))
96 | # GPU tensor -> CPU tensor -> numpy
97 | output = np.array(SR).astype(np.uint8)
98 | output = Image.fromarray(output) # hw -> wh
99 | output.save(outputs_dir + imgName.replace('.', '_{:.2f}.'.format(psnr.item())))
100 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}, Average_ssim: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg,
101 | Avg_ssim.avg))
102 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg))
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------