├── .idea ├── .gitignore ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── SuperRestoration.iml ├── datasets ├── T91 │ ├── t1.png │ ├── t10.png │ ├── t11.png │ ├── t12.png │ ├── t13.png │ ├── t14.png │ ├── t15.png │ ├── t16.png │ ├── t17.png │ ├── t18.png │ ├── t19.png │ ├── t2.png │ ├── t20.png │ ├── t21.png │ ├── t22.png │ ├── t23.png │ ├── t24.png │ ├── t25.png │ ├── t26.png │ ├── t27.png │ ├── t28.png │ ├── t29.png │ ├── t3.png │ ├── t30.png │ ├── t31.png │ ├── t32.png │ ├── t33.png │ ├── t34.png │ ├── t35.png │ ├── t36.png │ ├── t37.png │ ├── t38.png │ ├── t39.png │ ├── t4.png │ ├── t40.png │ ├── t42.png │ ├── t43.png │ ├── t44.png │ ├── t45.png │ ├── t46.png │ ├── t47.png │ ├── t48.png │ ├── t49.png │ ├── t5.png │ ├── t50.png │ ├── t51.png │ ├── t52.png │ ├── t53.png │ ├── t54.png │ ├── t55.png │ ├── t56.png │ ├── t57.png │ ├── t58.png │ ├── t59.png │ ├── t6.png │ ├── t60.png │ ├── t61.png │ ├── t62.png │ ├── t63.png │ ├── t64.png │ ├── t65.png │ ├── t66.png │ ├── t7.png │ ├── t8.png │ ├── t9.png │ ├── tt1.png │ ├── tt2.png │ ├── tt3.png │ ├── tt4.png │ ├── tt5.png │ ├── tt6.png │ ├── tt7.png │ ├── tt8.png │ ├── tt9.png │ ├── tt10.png │ ├── tt12.png │ ├── tt13.png │ ├── tt14.png │ ├── tt15.png │ ├── tt16.png │ ├── tt17.png │ ├── tt18.png │ ├── tt19.png │ ├── tt20.png │ ├── tt21.png │ ├── tt22.png │ ├── tt23.png │ ├── tt24.png │ ├── tt25.png │ ├── tt26.png │ └── tt27.png ├── Set14 │ ├── face.bmp │ ├── man.bmp │ ├── ppt3.bmp │ ├── baboon.bmp │ ├── bridge.bmp │ ├── comic.bmp │ ├── lenna.bmp │ ├── pepper.bmp │ ├── zebra.bmp │ ├── barbara.bmp │ ├── flowers.bmp │ ├── foreman.bmp │ ├── monarch.bmp │ └── coastguard.bmp └── Set5 │ ├── baby_GT.bmp │ ├── bird_GT.bmp │ ├── head_GT.bmp │ ├── woman_GT.bmp │ └── butterfly_GT.bmp ├── niqe_pris_params.npz ├── SRResNet ├── result │ ├── comic.bmp │ ├── baboon.bmp │ ├── baby_GT.bmp │ ├── flowers.bmp │ ├── head_GT.bmp │ ├── woman_GT.bmp │ ├── coastguard.bmp │ ├── baboon_SRGAN_x4.bmp │ 
├── comic_SRGAN_x4.bmp │ ├── baboon_SRResNet_x4.bmp │ ├── baboon_bicubic_x4.bmp │ ├── baby_GT_SRGAN_x4.bmp │ ├── baby_GT_bicubic_x4.bmp │ ├── comic_SRResNet_x4.bmp │ ├── comic_bicubic_x4.bmp │ ├── flowers_SRGAN_x4.bmp │ ├── flowers_bicubic_x4.bmp │ ├── head_GT_SRGAN_x4.bmp │ ├── head_GT_bicubic_x4.bmp │ ├── woman_GT_SRGAN_x4.bmp │ ├── baby_GT_SRResNet_x4.bmp │ ├── coastguard_SRGAN_x4.bmp │ ├── flowers_SRResNet_x4.bmp │ ├── head_GT_SRResNet_x4.bmp │ ├── woman_GT_SRResNet_x4.bmp │ ├── woman_GT_bicubic_x4.bmp │ ├── coastguard_SRResNet_x4.bmp │ └── coastguard_bicubic_x4.bmp ├── script.py ├── SRResNet_x4_DIV2Ksub_iter=400000.py ├── VGGLoss.py ├── SRGAN_x4_MSRA_DIV2Kaug_lr=e-4_batch=16_out=96.py ├── eval.py ├── train2.py ├── SRResNetdatasets.py ├── gen_datasets.py ├── train_SRGAN_WGAN.py ├── train_SRGAN_iter.py └── train.py ├── pth2csv.py ├── ViewLoss.py ├── SRCNN ├── SRCNN_x2.py ├── models.py ├── MatVsPy.py ├── SRCNNdatasets.py ├── gen_datasets.py ├── train.py └── test.py ├── FSRCNN ├── FSRCNN_x3_MSRA_T91_lr=e-1_batch=128_out=19_res.py ├── N2-10-4_x3_MSRA_T91res_lr=e-1_batch=128_Huber.py ├── FSRCNNdatasets.py ├── train.py ├── train_deconv.py ├── gen_datasets.py └── model.py ├── ESRGAN ├── ESRGAN3_x4_DFO_lr=e-4_batch=16_out=128.py ├── VGGLoss2.py ├── gen_datasets.py ├── demo_realesrgan.py ├── test_realesrgan.py ├── test_realesrgan_metrics.py ├── train_ESRGAN_iter2.py ├── train_ESRGAN_iter.py ├── ESRGANdatasets.py └── train.py ├── FFT.py ├── SRFlow ├── glow_arch.py ├── thops.py ├── Permutations.py ├── SRFlow_model.py ├── timer.py ├── SRFlow_DF2K_4X.yml ├── Split.py ├── module_util.py └── test_srflow.py ├── create_lmdb.py ├── HCFlow ├── thops.py ├── test_SR_DF2K_4X_HCFlow.yml ├── HCFlow_SR_model.py ├── FlowStep.py ├── networks.py ├── HCFlowNet_SR_arch.py ├── test_hcflow.py ├── loss.py ├── module_util.py └── ActNorms.py ├── test_metrics.py ├── SwinIR ├── test_classical_swinir.py └── test_swinir_metrics.py ├── BSRGAN ├── test_bsrgan.py ├── bsrgan_model.py └── 
test_bsrgan_metrics.py ├── metrics.py └── README.md /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /datasets/T91/t1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t1.png -------------------------------------------------------------------------------- /datasets/T91/t10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t10.png -------------------------------------------------------------------------------- /datasets/T91/t11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t11.png -------------------------------------------------------------------------------- /datasets/T91/t12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t12.png -------------------------------------------------------------------------------- /datasets/T91/t13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t13.png -------------------------------------------------------------------------------- /datasets/T91/t14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t14.png -------------------------------------------------------------------------------- /datasets/T91/t15.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t15.png -------------------------------------------------------------------------------- /datasets/T91/t16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t16.png -------------------------------------------------------------------------------- /datasets/T91/t17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t17.png -------------------------------------------------------------------------------- /datasets/T91/t18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t18.png -------------------------------------------------------------------------------- /datasets/T91/t19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t19.png -------------------------------------------------------------------------------- /datasets/T91/t2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t2.png -------------------------------------------------------------------------------- /datasets/T91/t20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t20.png -------------------------------------------------------------------------------- /datasets/T91/t21.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t21.png -------------------------------------------------------------------------------- /datasets/T91/t22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t22.png -------------------------------------------------------------------------------- /datasets/T91/t23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t23.png -------------------------------------------------------------------------------- /datasets/T91/t24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t24.png -------------------------------------------------------------------------------- /datasets/T91/t25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t25.png -------------------------------------------------------------------------------- /datasets/T91/t26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t26.png -------------------------------------------------------------------------------- /datasets/T91/t27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t27.png -------------------------------------------------------------------------------- /datasets/T91/t28.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t28.png -------------------------------------------------------------------------------- /datasets/T91/t29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t29.png -------------------------------------------------------------------------------- /datasets/T91/t3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t3.png -------------------------------------------------------------------------------- /datasets/T91/t30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t30.png -------------------------------------------------------------------------------- /datasets/T91/t31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t31.png -------------------------------------------------------------------------------- /datasets/T91/t32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t32.png -------------------------------------------------------------------------------- /datasets/T91/t33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t33.png -------------------------------------------------------------------------------- /datasets/T91/t34.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t34.png -------------------------------------------------------------------------------- /datasets/T91/t35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t35.png -------------------------------------------------------------------------------- /datasets/T91/t36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t36.png -------------------------------------------------------------------------------- /datasets/T91/t37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t37.png -------------------------------------------------------------------------------- /datasets/T91/t38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t38.png -------------------------------------------------------------------------------- /datasets/T91/t39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t39.png -------------------------------------------------------------------------------- /datasets/T91/t4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t4.png -------------------------------------------------------------------------------- /datasets/T91/t40.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t40.png -------------------------------------------------------------------------------- /datasets/T91/t42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t42.png -------------------------------------------------------------------------------- /datasets/T91/t43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t43.png -------------------------------------------------------------------------------- /datasets/T91/t44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t44.png -------------------------------------------------------------------------------- /datasets/T91/t45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t45.png -------------------------------------------------------------------------------- /datasets/T91/t46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t46.png -------------------------------------------------------------------------------- /datasets/T91/t47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t47.png -------------------------------------------------------------------------------- /datasets/T91/t48.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t48.png -------------------------------------------------------------------------------- /datasets/T91/t49.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t49.png -------------------------------------------------------------------------------- /datasets/T91/t5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t5.png -------------------------------------------------------------------------------- /datasets/T91/t50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t50.png -------------------------------------------------------------------------------- /datasets/T91/t51.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t51.png -------------------------------------------------------------------------------- /datasets/T91/t52.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t52.png -------------------------------------------------------------------------------- /datasets/T91/t53.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t53.png -------------------------------------------------------------------------------- /datasets/T91/t54.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t54.png -------------------------------------------------------------------------------- /datasets/T91/t55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t55.png -------------------------------------------------------------------------------- /datasets/T91/t56.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t56.png -------------------------------------------------------------------------------- /datasets/T91/t57.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t57.png -------------------------------------------------------------------------------- /datasets/T91/t58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t58.png -------------------------------------------------------------------------------- /datasets/T91/t59.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t59.png -------------------------------------------------------------------------------- /datasets/T91/t6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t6.png -------------------------------------------------------------------------------- /datasets/T91/t60.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t60.png -------------------------------------------------------------------------------- /datasets/T91/t61.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t61.png -------------------------------------------------------------------------------- /datasets/T91/t62.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t62.png -------------------------------------------------------------------------------- /datasets/T91/t63.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t63.png -------------------------------------------------------------------------------- /datasets/T91/t64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t64.png -------------------------------------------------------------------------------- /datasets/T91/t65.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t65.png -------------------------------------------------------------------------------- /datasets/T91/t66.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t66.png -------------------------------------------------------------------------------- /datasets/T91/t7.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t7.png -------------------------------------------------------------------------------- /datasets/T91/t8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t8.png -------------------------------------------------------------------------------- /datasets/T91/t9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/t9.png -------------------------------------------------------------------------------- /datasets/T91/tt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt1.png -------------------------------------------------------------------------------- /datasets/T91/tt2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt2.png -------------------------------------------------------------------------------- /datasets/T91/tt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt3.png -------------------------------------------------------------------------------- /datasets/T91/tt4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt4.png -------------------------------------------------------------------------------- /datasets/T91/tt5.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt5.png -------------------------------------------------------------------------------- /datasets/T91/tt6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt6.png -------------------------------------------------------------------------------- /datasets/T91/tt7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt7.png -------------------------------------------------------------------------------- /datasets/T91/tt8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt8.png -------------------------------------------------------------------------------- /datasets/T91/tt9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt9.png -------------------------------------------------------------------------------- /niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/niqe_pris_params.npz -------------------------------------------------------------------------------- /datasets/Set14/face.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/face.bmp -------------------------------------------------------------------------------- /datasets/Set14/man.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/man.bmp -------------------------------------------------------------------------------- /datasets/Set14/ppt3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/ppt3.bmp -------------------------------------------------------------------------------- /datasets/T91/tt10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt10.png -------------------------------------------------------------------------------- /datasets/T91/tt12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt12.png -------------------------------------------------------------------------------- /datasets/T91/tt13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt13.png -------------------------------------------------------------------------------- /datasets/T91/tt14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt14.png -------------------------------------------------------------------------------- /datasets/T91/tt15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt15.png -------------------------------------------------------------------------------- /datasets/T91/tt16.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt16.png -------------------------------------------------------------------------------- /datasets/T91/tt17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt17.png -------------------------------------------------------------------------------- /datasets/T91/tt18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt18.png -------------------------------------------------------------------------------- /datasets/T91/tt19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt19.png -------------------------------------------------------------------------------- /datasets/T91/tt20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt20.png -------------------------------------------------------------------------------- /datasets/T91/tt21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt21.png -------------------------------------------------------------------------------- /datasets/T91/tt22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt22.png -------------------------------------------------------------------------------- /datasets/T91/tt23.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt23.png -------------------------------------------------------------------------------- /datasets/T91/tt24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt24.png -------------------------------------------------------------------------------- /datasets/T91/tt25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt25.png -------------------------------------------------------------------------------- /datasets/T91/tt26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt26.png -------------------------------------------------------------------------------- /datasets/T91/tt27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/T91/tt27.png -------------------------------------------------------------------------------- /SRResNet/result/comic.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic.bmp -------------------------------------------------------------------------------- /datasets/Set14/baboon.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/baboon.bmp -------------------------------------------------------------------------------- /datasets/Set14/bridge.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/bridge.bmp -------------------------------------------------------------------------------- /datasets/Set14/comic.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/comic.bmp -------------------------------------------------------------------------------- /datasets/Set14/lenna.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/lenna.bmp -------------------------------------------------------------------------------- /datasets/Set14/pepper.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/pepper.bmp -------------------------------------------------------------------------------- /datasets/Set14/zebra.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/zebra.bmp -------------------------------------------------------------------------------- /datasets/Set5/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/baby_GT.bmp -------------------------------------------------------------------------------- /datasets/Set5/bird_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/bird_GT.bmp -------------------------------------------------------------------------------- 
/datasets/Set5/head_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/head_GT.bmp -------------------------------------------------------------------------------- /SRResNet/result/baboon.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon.bmp -------------------------------------------------------------------------------- /SRResNet/result/baby_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT.bmp -------------------------------------------------------------------------------- /SRResNet/result/flowers.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers.bmp -------------------------------------------------------------------------------- /SRResNet/result/head_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT.bmp -------------------------------------------------------------------------------- /SRResNet/result/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT.bmp -------------------------------------------------------------------------------- /datasets/Set14/barbara.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/barbara.bmp 
-------------------------------------------------------------------------------- /datasets/Set14/flowers.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/flowers.bmp -------------------------------------------------------------------------------- /datasets/Set14/foreman.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/foreman.bmp -------------------------------------------------------------------------------- /datasets/Set14/monarch.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/monarch.bmp -------------------------------------------------------------------------------- /datasets/Set5/woman_GT.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/woman_GT.bmp -------------------------------------------------------------------------------- /SRResNet/result/coastguard.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard.bmp -------------------------------------------------------------------------------- /datasets/Set14/coastguard.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set14/coastguard.bmp -------------------------------------------------------------------------------- /datasets/Set5/butterfly_GT.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/datasets/Set5/butterfly_GT.bmp -------------------------------------------------------------------------------- /SRResNet/result/baboon_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/comic_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/baboon_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/baboon_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baboon_bicubic_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/baby_GT_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/baby_GT_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_bicubic_x4.bmp 
-------------------------------------------------------------------------------- /SRResNet/result/comic_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/comic_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/comic_bicubic_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/flowers_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/flowers_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_bicubic_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/head_GT_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/head_GT_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_bicubic_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/woman_GT_SRGAN_x4.bmp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/baby_GT_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/baby_GT_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/coastguard_SRGAN_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_SRGAN_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/flowers_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/flowers_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/head_GT_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/head_GT_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/woman_GT_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/woman_GT_bicubic_x4.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/woman_GT_bicubic_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/coastguard_SRResNet_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_SRResNet_x4.bmp -------------------------------------------------------------------------------- /SRResNet/result/coastguard_bicubic_x4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sadisticheaven/SuperRestoration/HEAD/SRResNet/result/coastguard_bicubic_x4.bmp -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/SuperRestoration.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /pth2csv.py: -------------------------------------------------------------------------------- 1 | 
import csv 2 | import os 3 | import torch 4 | 5 | csvFile = open("./SRCNN_x3_lr=1e-02_batch=128.csv", 'w', newline='') 6 | 7 | try: 8 | writer = csv.writer(csvFile) 9 | writer.writerow(('epoch', 'loss', 'psnr')) 10 | 11 | pthDir = './SRCNN/weight_file/9-1-5_python2/x3/' 12 | pthList = os.listdir(pthDir) 13 | if 'best.pth' in pthList: 14 | pthList.remove('best.pth') 15 | for pthName in pthList: 16 | pth = torch.load(pthDir + pthName) 17 | writer.writerow((pth['epoch'], pth['loss'], pth['psnr'])) 18 | finally: 19 | csvFile.close() 20 | -------------------------------------------------------------------------------- /ViewLoss.py: -------------------------------------------------------------------------------- 1 | from mpl_toolkits.mplot3d import Axes3D 2 | import matplotlib.pyplot as plt 3 | from matplotlib import cm 4 | from matplotlib.ticker import LinearLocator, FormatStrFormatter 5 | import numpy as np 6 | 7 | fig = plt.figure() 8 | ax = fig.gca(projection='3d') 9 | 10 | # Make data. 11 | X = np.arange(0.01, 1, 0.01) 12 | Y = np.arange(0.01, 1, 0.01) 13 | X, Y = np.meshgrid(X, Y) 14 | Z = -0.5*(np.log(1-X)+np.log(Y)) 15 | 16 | # Plot the surface. 17 | surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, 18 | linewidth=0, antialiased=False) 19 | # Add a color bar which maps values to colors. 
20 | fig.colorbar(surf, shrink=0.5, aspect=5) 21 | plt.show() 22 | -------------------------------------------------------------------------------- /SRCNN/SRCNN_x2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from train import train_model 3 | import os 4 | if __name__ == '__main__': 5 | program = os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 2 7 | config = {'train_file': f'../datasets/T91_aug_train_SRCNNx{scale}.h5', 8 | # config = {'train_file': 'G:/Document/datasets/T91_aug_x3.h5', 9 | 'val_file': f'../datasets/Set5_val_SRCNNx{scale}.h5', 10 | 'outputs_dir': f'./weight_file/{program}/x{scale}/', 11 | 'csv_name': f'{program}.csv', 12 | 'scale': scale, 13 | 'lr': 1e-2, 14 | 'batch_size': 128, 15 | 'num_epochs': 1000, 16 | 'num_workers': 2, 17 | 'seed': 123, 18 | 'weight_file': f'./weight_file/{program}/x{scale}/latest.pth', 19 | 'Gpu': '0', 20 | 'residual': False, 21 | 'auto_lr': False 22 | } 23 | train_model(config, from_pth=True, useVisdom=True) -------------------------------------------------------------------------------- /SRResNet/script.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.backends import cudnn 3 | from model import G, D 4 | import os 5 | 6 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 7 | if __name__ == '__main__': 8 | config = {'weight_file': './weight_file/SRGAN7_x4_DIVsub_WGAN/', 9 | 'scale': 4 10 | } 11 | scale = config['scale'] 12 | padding = scale 13 | # weight_file = config['weight_file'] + f'best.pth' 14 | weight_file = config['weight_file'] + f'x{scale}/latest.pth' 15 | if not os.path.exists(weight_file): 16 | print(f'Weight file not exist!\n{weight_file}\n') 17 | raise "Error" 18 | 19 | cudnn.benchmark = True 20 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 21 | 22 | checkpoint = torch.load(weight_file) 23 | 24 | disc = D().to(device) 25 | disc = torch.nn.DataParallel(disc) 
26 | disc.load_state_dict(checkpoint['disc']) 27 | disc = disc.module 28 | checkpoint['disc'] = disc.state_dict() 29 | torch.save(checkpoint, weight_file) 30 | -------------------------------------------------------------------------------- /SRResNet/SRResNet_x4_DIV2Ksub_iter=400000.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from train2 import train_model 4 | if __name__ == '__main__': 5 | program = os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 4 7 | label_size = 96 8 | config = {'train_file': f'../datasets/DIV2K_train_HR/', 9 | 'val_file': f'../datasets/Set5_label={label_size}_val_SRResNetx{scale}.h5', 10 | 'outputs_dir': f'./weight_file/{program}/x{scale}/', 11 | 'logs_dir': f'./logs/{program}/', 12 | 'csv_name': f'{program}.csv', 13 | 'weight_file': f'./weight_file/{program}/x{scale}/latest.pth', 14 | 'scale': scale, 15 | 'out_size': label_size, 16 | 'lr': 1e-4, 17 | 'batch_size': 16, 18 | 'num_steps': 400000, 19 | 'num_workers': 16, 20 | 'seed': 123, 21 | 'init': 'MSRA', 22 | 'Gpu': '0', 23 | 'auto_lr': False, 24 | 'milestone': [1000] 25 | } 26 | 27 | train_model(config, from_pth=False, useVisdom=False) -------------------------------------------------------------------------------- /FSRCNN/FSRCNN_x3_MSRA_T91_lr=e-1_batch=128_out=19_res.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from train import train_model 4 | if __name__ == '__main__': 5 | program = os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 3 7 | label_size = 19 8 | config = {'train_file': f'../datasets/T91_aug_label={label_size}_train_FSRCNNx{scale}_res.h5', 9 | 'val_file': f'../datasets/Set5_label={label_size}_val_FSRCNNx{scale}_res.h5', 10 | 'outputs_dir': f'../weight_file/{program}/', 11 | 'logs_dir': f'../logs/{program}/', 12 | 'csv_name': f'{program}.csv', 13 | 'weight_file': f'./weight_file/{program}/latest.pth', 14 | 'scale': 
scale, 15 | 'in_size': 11, 16 | 'out_size': label_size, 17 | 'd': 56, 18 | 's': 12, 19 | 'm': 4, 20 | 'lr': 1e-1, 21 | 'batch_size': 128, 22 | 'num_epochs': 100000, 23 | 'num_workers': 4, 24 | 'seed': 123, 25 | 'init': 'MSRA', 26 | 'Gpu': '0', 27 | 'auto_lr': False, 28 | 'residual': True 29 | } 30 | 31 | train_model(config, from_pth=True) -------------------------------------------------------------------------------- /SRResNet/VGGLoss.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | from torchvision.models import vgg19 5 | import torch 6 | from torchvision.models import vgg as vgg 7 | # phi_5,4 5th conv layer before maxpooling but after activation 8 | 9 | class VGGLoss(nn.Module): 10 | def __init__(self, device, feature_loss=nn.MSELoss()): 11 | super().__init__() 12 | print('===> Loading VGG model') 13 | netVGG = vgg19() 14 | netVGG.load_state_dict(torch.load('../VGG19/vgg19-dcbb9e9d.pth')) 15 | self.vgg = netVGG.features[:36].eval().to(device) # VGG54, before activation 16 | self.loss = feature_loss 17 | self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 18 | self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 19 | for param in self.vgg.parameters(): 20 | param.requires_grad = False 21 | 22 | def forward(self, pred, target): 23 | pred = (pred - self.mean) / self.std 24 | target = (target - self.mean) / self.std 25 | vgg_input_features = self.vgg(pred) 26 | vgg_target_features = self.vgg(target) 27 | return self.loss(vgg_input_features, vgg_target_features) 28 | -------------------------------------------------------------------------------- /SRResNet/SRGAN_x4_MSRA_DIV2Kaug_lr=e-4_batch=16_out=96.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from train_SRGAN import train_model 4 | if __name__ == '__main__': 5 | program = 
os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 4 7 | label_size = 96 8 | config = {'train_file': f'../datasets/test/', 9 | 'val_file': f'../datasets/Set5_label={label_size}_val_SRResNetx{scale}.h5', 10 | 'outputs_dir': f'./weight_file/{program}/x{scale}/', 11 | 'csv_name': f'{program}.csv', 12 | 'weight_file': f'./weight_file/{program}/x{scale}/latest.pth', 13 | 'logs_dir': f'./logs/{program}/', 14 | 'scale': scale, 15 | 'in_size': 24, 16 | 'out_size': label_size, 17 | 'gen_lr': 1e-4, 18 | 'disc_lr': 1e-4, 19 | 'batch_size': 5, 20 | 'num_epochs': 10000, 21 | 'num_workers': 0, 22 | 'seed': 123, 23 | 'init': 'MSRA', 24 | 'Gpu': '0', 25 | 'gen_k': 1, 26 | 'disc_k': 1, 27 | } 28 | 29 | train_model(config, pre_train=None, from_pth=False) 30 | # train_model(config, pre_train='./best.pth', from_pth=False, use_visdom=False) -------------------------------------------------------------------------------- /ESRGAN/ESRGAN3_x4_DFO_lr=e-4_batch=16_out=128.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from train_ESRGAN import train_model 4 | if __name__ == '__main__': 5 | program = os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 4 7 | label_size = 128 8 | config = {'val_file': f'../datasets/Set5_label={label_size}_val_ESRGANx{scale}.h5', 9 | 'outputs_dir': f'../weight_file/{program}/x{scale}/', 10 | 'csv_name': f'{program}.csv', 11 | 'weight_file': f'../weight_file/{program}/x{scale}/latest.pth', 12 | 'logs_dir': f'../logs/{program}/', 13 | 'scale': scale, 14 | 'in_size': 32, 15 | 'out_size': label_size, 16 | 'gen_lr': 1e-4, 17 | 'disc_lr': 1e-4, 18 | 'batch_size': 16, 19 | 'num_epochs': 10000, 20 | 'num_workers': 32, 21 | 'seed': 0, 22 | 'Gpu': '0', 23 | 'auto_lr': True, 24 | 'gen_k': 1, 25 | 'disc_k': 1, 26 | 'adversarial_weight': 5e-3, 27 | 'pixel_weight': 1e-2 28 | } 29 | 30 | train_model(config, pre_train=None, from_pth=False) 31 | # train_model(config, pre_train='./best.pth', 
from_pth=False) -------------------------------------------------------------------------------- /SRCNN/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SRCNN(nn.Module): 5 | def __init__(self, padding=False, num_channels=1): 6 | super(SRCNN, self).__init__() 7 | self.conv1 = nn.Sequential(nn.Conv2d(num_channels, 64, kernel_size=9, padding=4*int(padding), padding_mode='replicate'), 8 | nn.ReLU(inplace=True)) 9 | self.conv2 = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, padding=0), # n1 * 1 * 1 * n2 10 | nn.ReLU(inplace=True)) 11 | self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2*int(padding), padding_mode='replicate') 12 | 13 | def forward(self, x): 14 | x = self.conv1(x) 15 | x = self.conv2(x) 16 | x = self.conv3(x) 17 | return x 18 | 19 | def init_weights(self): 20 | for L in self.conv1: 21 | if isinstance(L, nn.Conv2d): 22 | L.weight.data.normal_(mean=0.0, std=0.001) 23 | L.bias.data.zero_() 24 | for L in self.conv2: 25 | if isinstance(L, nn.Conv2d): 26 | L.weight.data.normal_(mean=0.0, std=0.001) 27 | L.bias.data.zero_() 28 | self.conv3.weight.data.normal_(mean=0.0, std=0.001) 29 | self.conv3.bias.data.zero_() 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /FFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import cv2 6 | import sys 7 | from PIL import Image 8 | from imresize import imresize 9 | import utils 10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 11 | 12 | 13 | img = Image.open('./datasets/test/t1.png').convert('RGB') 14 | img = np.array(img).astype(np.float32) 15 | img = utils.rgb2ycbcr(img)[..., 0] 16 | # img = cv2.imread('./datasets/test/t1.png', 0) 17 | bic = imresize(img, 0.5, 'bicubic') 18 | 19 | dft = cv2.dft(np.float32(img), flags=cv2.DFT_COMPLEX_OUTPUT) 20 | 
dft_shift = np.fft.fftshift(dft) 21 | magnitude_spectrum = 20 * np.log(cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1])) 22 | 23 | dft2 = cv2.dft(np.float32(bic), flags=cv2.DFT_COMPLEX_OUTPUT) 24 | dft2 = np.fft.fftshift(dft2) 25 | dft2 = 20 * np.log(cv2.magnitude(dft2[:, :, 0], dft2[:, :, 1])) 26 | 27 | plt.subplot(221), plt.imshow(img, cmap='gray') 28 | plt.title('Input Image'), plt.xticks([]), plt.yticks([]) 29 | plt.subplot(222), plt.imshow(magnitude_spectrum, cmap='gray') 30 | plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([]) 31 | plt.subplot(223), plt.imshow(bic, cmap='gray') 32 | plt.title('bicubic'), plt.xticks([]), plt.yticks([]) 33 | plt.subplot(224), plt.imshow(dft2, cmap='gray') 34 | plt.title('bicubic_fft'), plt.xticks([]), plt.yticks([]) 35 | plt.show() 36 | -------------------------------------------------------------------------------- /FSRCNN/N2-10-4_x3_MSRA_T91res_lr=e-1_batch=128_Huber.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from train_N2_10_4 import train_model 4 | if __name__ == '__main__': 5 | program = os.path.basename(sys.argv[0]).split('.')[0] 6 | scale = 3 7 | label_size = 19 8 | config = {'train_file': f'../datasets/T91_aug_label={label_size}_train_FSRCNNx{scale}_res.h5', 9 | 'val_file': f'../datasets/Set5_label={label_size}_val_FSRCNNx{scale}_res.h5', 10 | 'outputs_dir': f'./weight_file/{program}/', 11 | 'logs_dir': f'./weight_file/{program}/', 12 | 'csv_name': f'{program}.csv', 13 | 'weight_file': f'./weight_file/{program}/latest.pth', 14 | 'scale': scale, 15 | 'in_size': 11, 16 | 'out_size': label_size, 17 | 'd': 10, 18 | 'm': 4, 19 | 'lr': 1e-1, 20 | 'step_size': 1500, 21 | 'gamma': 0.1, 22 | 'weight_decay': 0, 23 | 'batch_size': 128, 24 | 'num_epochs': 50000, 25 | 'num_workers': 4, 26 | 'seed': 123, 27 | 'init': 'MSRA', 28 | 'Gpu': '0', 29 | 'auto_lr': False, 30 | 'residual': True, 31 | 'Loss': 'Huber', 32 | 'delta': 0.8 33 | } 34 | 35 | 
train_model(config, from_pth=False) -------------------------------------------------------------------------------- /SRFlow/glow_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch.nn as nn 18 | 19 | 20 | def f_conv2d_bias(in_channels, out_channels): 21 | def padding_same(kernel, stride): 22 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)] 23 | 24 | padding = padding_same([3, 3], [1, 1]) 25 | assert padding == [1, 1], padding 26 | return nn.Sequential( 27 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1, 28 | bias=True)) 29 | -------------------------------------------------------------------------------- /ESRGAN/VGGLoss2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models import vgg19 3 | import torch 4 | import torch.nn.functional as F 5 | # phi_5,4 5th conv layer before maxpooling but after activation 6 | 7 | class VGGLoss(nn.Module): 8 | def __init__(self, device): 9 | super().__init__() 10 | print('===> Loading VGG model') 11 | netVGG = vgg19() 12 | netVGG.load_state_dict(torch.load('./VGG19/vgg19-dcbb9e9d.pth')) 13 | self.vgg = netVGG.features[:35].eval().to(device) # VGG54, before activation 14 | self.loss = nn.L1Loss() 15 | 16 | for param in self.vgg.parameters(): 17 | param.requires_grad = False 18 | 19 | # # The preprocessing method of the input data. This is the preprocessing method of the VGG model on the ImageNet dataset. 
20 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 21 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 22 | 23 | def forward(self, input, target): 24 | # input = (input - self.mean) / self.std 25 | # target = (target - self.mean) / self.std 26 | vgg_input_features = self.vgg(input) 27 | vgg_target_features = self.vgg(target) 28 | return self.loss(vgg_input_features, vgg_target_features) 29 | 30 | 31 | if __name__ == '__main__': 32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 33 | v = VGGLoss(device) 34 | -------------------------------------------------------------------------------- /create_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import lmdb 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | root_dirs = ['./datasets/Set5/'] 8 | lmdb_save_path = './datasets/Set5.lmdb' 9 | # 创建数据库文件 10 | img_list = [] 11 | for root_dir in root_dirs: 12 | img_names = os.listdir(root_dir) 13 | for name in img_names: 14 | img_list.append(root_dir + name) 15 | dataset = [] 16 | data_size = 0 17 | 18 | print('Read images...') 19 | with tqdm(total=len(img_list)) as t: 20 | for i, v in enumerate(img_list): 21 | img = np.array(Image.open(v).convert('RGB')) 22 | dataset.append(img) 23 | data_size += img.nbytes 24 | t.update(1) 25 | env = lmdb.open(lmdb_save_path, max_dbs=2, map_size=data_size * 2) 26 | print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list))) 27 | # 创建对应的数据库 28 | train_data = env.open_db("train_data".encode('ascii')) 29 | train_shape = env.open_db("train_shape".encode('ascii')) 30 | # 把图像数据写入到LMDB中 31 | with env.begin(write=True) as txn: 32 | with tqdm(total=len(dataset)) as t: 33 | for idx, img in enumerate(dataset): 34 | H, W, C = img.shape 35 | txn.put(str(idx).encode('ascii'), img, db=train_data) 36 | meta_key = (str(idx) + '.meta').encode('ascii') 37 | meta = '{:d}, 
{:d}, {:d}'.format(H, W, C) 38 | txn.put(meta_key, meta.encode('ascii'), db=train_shape) 39 | t.update(1) 40 | env.close() 41 | print('Finish writing lmdb.') -------------------------------------------------------------------------------- /HCFlow/thops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sum(tensor, dim=None, keepdim=False): 5 | if dim is None: 6 | # sum up all dim 7 | return torch.sum(tensor) 8 | else: 9 | if isinstance(dim, int): 10 | dim = [dim] 11 | dim = sorted(dim) 12 | for d in dim: 13 | tensor = tensor.sum(dim=d, keepdim=True) 14 | if not keepdim: 15 | for i, d in enumerate(dim): 16 | tensor.squeeze_(d-i) 17 | return tensor 18 | 19 | 20 | def mean(tensor, dim=None, keepdim=False): 21 | if dim is None: 22 | # mean all dim 23 | return torch.mean(tensor) 24 | else: 25 | if isinstance(dim, int): 26 | dim = [dim] 27 | dim = sorted(dim) 28 | for d in dim: 29 | tensor = tensor.mean(dim=d, keepdim=True) 30 | if not keepdim: 31 | for i, d in enumerate(dim): 32 | tensor.squeeze_(d-i) 33 | return tensor 34 | 35 | 36 | 37 | def split_feature(tensor, type="split"): 38 | """ 39 | type = ["split", "cross"] 40 | """ 41 | C = tensor.size(1) 42 | if type == "split": 43 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] 44 | elif type == "cross": 45 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
46 | 47 | 48 | def cat_feature(tensor_a, tensor_b): 49 | return torch.cat((tensor_a, tensor_b), dim=1) 50 | 51 | 52 | def pixels(tensor): 53 | return int(tensor.size(2) * tensor.size(3)) -------------------------------------------------------------------------------- /FSRCNN/FSRCNNdatasets.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TrainDataset(Dataset): 7 | def __init__(self, h5_file): 8 | super(TrainDataset, self).__init__() 9 | self.h5_file = h5_file 10 | 11 | def __getitem__(self, idx): 12 | with h5py.File(self.h5_file, 'r') as f: 13 | return np.expand_dims(f['data'][idx] / 255., 0), np.expand_dims(f['label'][idx] / 255., 0) 14 | 15 | def __len__(self): 16 | with h5py.File(self.h5_file, 'r') as f: 17 | return len(f['data']) 18 | 19 | 20 | class ValDataset(Dataset): 21 | def __init__(self, h5_file): 22 | super(ValDataset, self).__init__() 23 | self.h5_file = h5_file 24 | 25 | def __getitem__(self, idx): 26 | with h5py.File(self.h5_file, 'r') as f: 27 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), np.expand_dims(f['label'][str(idx)][:, :] / 255., 0) 28 | 29 | def __len__(self): 30 | with h5py.File(self.h5_file, 'r') as f: 31 | return len(f['data']) 32 | 33 | 34 | class ResValDataset(Dataset): 35 | def __init__(self, h5_file): 36 | super(ResValDataset, self).__init__() 37 | self.h5_file = h5_file 38 | 39 | def __getitem__(self, idx): 40 | with h5py.File(self.h5_file, 'r') as f: 41 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), \ 42 | np.expand_dims(f['label'][str(idx)][:, :] / 255., 0),\ 43 | np.expand_dims(f['bicubic'][str(idx)][:, :] / 255., 0) 44 | 45 | def __len__(self): 46 | with h5py.File(self.h5_file, 'r') as f: 47 | return len(f['data']) 48 | -------------------------------------------------------------------------------- /SRCNN/MatVsPy.py: 
-------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | from scipy.io import loadmat 4 | import utils 5 | import numpy as np 6 | import transplant 7 | import os 8 | from PIL import Image 9 | import imresize 10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 11 | 12 | 13 | 14 | 15 | if __name__ == '__main__': 16 | # h5_path = '../datasets/Matlab_Set5_val_SRCNNx3.h5' 17 | # with h5py.File(h5_path, 'r') as f: 18 | # len = len(f['data']) 19 | # for idx in range(len): 20 | # input = f['data'][idx] 21 | # label = f['label'][idx] 22 | 23 | # m = loadmat("G:/Document/文献资料/超分辨率/SRCNN_train/SRCNN/model/9-1-5(91 images)/x2.mat") 24 | # x = m['weights_conv1'] 25 | # x = np.reshape(x, (9, 9, 64)) 26 | # ax = utils.viz_layer2(x, 64) 27 | 28 | hrfile = 'G:/Document/pythonProj/SuperRestoration/datasets/T91_aug/t10_0_x0.6.png' 29 | hrfile2 = 'G:/Document/pythonProj/SuperRestoration/datasets/T91_augt10_rot0_s6.bmp' 30 | matlab = transplant.Matlab(executable='G:/Software/Matlab2020b/bin/matlab.exe') 31 | hr_py = Image.open(hrfile).convert('RGB') # RGBA->RGB 32 | hr_py = np.array(hr_py) 33 | hr_mat = matlab.imread(hrfile) 34 | hr_mat = hr_mat[0][:, :, :] 35 | diff = hr_py - hr_mat # 无差别 36 | hr_mat = matlab.rgb2ycbcr(hr_py) 37 | hr_mat_y = hr_mat[:, :, 0] 38 | hr_py_ycbcr = utils.rgb2ycbcr(hr_py) 39 | hr_py_ycbcr = hr_py_ycbcr.astype(np.uint8) 40 | hr_py_y = hr_py_ycbcr[:, :, 0] 41 | diff = hr_mat - hr_py_ycbcr # 有很大差别,所以不能使用python的ycbcr 42 | 43 | hr_py = Image.open(hrfile).convert('RGB') 44 | hr_mat = matlab.imread(hrfile) 45 | hr_mat = hr_mat[0][:, :, :] 46 | scale = 3 47 | hr_py = np.array(hr_py).astype(np.uint8) 48 | lr_mat = matlab.imresize(hr_py, 1 / scale, 'bicubic')[0] 49 | lr_mat2 = imresize.imresize(hr_py, 1 / scale) 50 | diff = lr_mat2 - lr_mat 51 | 52 | 53 | -------------------------------------------------------------------------------- /test_metrics.py: 
-------------------------------------------------------------------------------- 1 | import csv 2 | from tqdm import tqdm 3 | import utils 4 | import metrics 5 | import os 6 | 7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 8 | if __name__ == '__main__': 9 | model_name = 'RealESRGAN' 10 | scale = 4 11 | itplt = 'bilinear' 12 | # itplt = 'bicubic' 13 | dataset = f'degraded4_{itplt}_heavy' 14 | # dataset = f'degraded5_{itplt}_medium' 15 | # dataset = f'degraded6_{itplt}_slight' 16 | 17 | sr_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/' 18 | gt_dir = './datasets/PIPAL/' 19 | deg_types = os.listdir(sr_root_dir_degradation) 20 | 21 | csv_file = f'./test_res/{model_name}_{dataset}.csv' 22 | csv_file = open(csv_file, 'w', newline='') 23 | writer = csv.writer(csv_file) 24 | writer.writerow(('name', 'psnr', 'niqe', 'ssim', 'lpips')) 25 | for deg_type in deg_types: 26 | sr_dir = sr_root_dir_degradation + deg_type + '/' 27 | sr_lists = os.listdir(sr_dir) 28 | Avg_psnr = utils.AverageMeter() 29 | Avg_niqe = utils.AverageMeter() 30 | Avg_ssim = utils.AverageMeter() 31 | Avg_lpips = utils.AverageMeter() 32 | with tqdm(total=len(sr_lists)) as t: 33 | t.set_description(f"Processing: {deg_type}") 34 | for imgName in sr_lists: 35 | SR = utils.loadIMG_crop(sr_dir + imgName, scale) 36 | GT = utils.loadIMG_crop(gt_dir + imgName, scale) 37 | my_metrics = metrics.calc_metric(SR, GT, ['lpips']) 38 | Avg_lpips.update(my_metrics['lpips'], 1) 39 | # Avg_psnr.update(my_metrics['psnr'], 1) 40 | # Avg_niqe.update(my_metrics['niqe'], 1) 41 | # Avg_ssim.update(my_metrics['ssim'], 1) 42 | t.update(1) 43 | # writer.writerow((deg_type, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg, Avg_lpips.avg)) 44 | -------------------------------------------------------------------------------- /SRCNN/SRCNNdatasets.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from torch.utils.data import Dataset, 
DataLoader 4 | 5 | 6 | class TrainDataset(Dataset): 7 | def __init__(self, h5_file): 8 | super(TrainDataset, self).__init__() 9 | self.h5_file = h5_file 10 | 11 | def __getitem__(self, idx): 12 | with h5py.File(self.h5_file, 'r') as f: 13 | return np.expand_dims(f['data'][idx] / 255., 0), np.expand_dims(f['label'][idx] / 255., 0) 14 | 15 | def __len__(self): 16 | with h5py.File(self.h5_file, 'r') as f: 17 | return len(f['data']) 18 | 19 | 20 | class ValDataset(Dataset): 21 | def __init__(self, h5_file): 22 | super(ValDataset, self).__init__() 23 | self.h5_file = h5_file 24 | 25 | def __getitem__(self, idx): 26 | with h5py.File(self.h5_file, 'r') as f: 27 | return np.expand_dims(f['data'][str(idx)][:, :] / 255., 0), np.expand_dims(f['label'][str(idx)][:, :] / 255., 0) 28 | 29 | def __len__(self): 30 | with h5py.File(self.h5_file, 'r') as f: 31 | return len(f['data']) 32 | 33 | 34 | class MatlabTrainDataset(Dataset): 35 | def __init__(self, h5_file): 36 | super(MatlabTrainDataset, self).__init__() 37 | self.h5_file = h5_file 38 | 39 | def __getitem__(self, idx): 40 | with h5py.File(self.h5_file, 'r') as f: 41 | return f['data'][idx], f['label'][idx] 42 | 43 | def __len__(self): 44 | with h5py.File(self.h5_file, 'r') as f: 45 | return len(f['data']) 46 | 47 | 48 | class MatlabValidDataset(Dataset): 49 | def __init__(self, h5_file): 50 | super(MatlabValidDataset, self).__init__() 51 | self.h5_file = h5_file 52 | 53 | def __getitem__(self, idx): 54 | with h5py.File(self.h5_file, 'r') as f: 55 | return f['data'][idx], f['label'][idx] 56 | 57 | def __len__(self): 58 | with h5py.File(self.h5_file, 'r') as f: 59 | return len(f['data']) -------------------------------------------------------------------------------- /ESRGAN/gen_datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | import utils 5 | from imresize import imresize 6 | from PIL import Image 7 | 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 8 | 9 | 10 | def gen_valdata(config): 11 | scale = config["scale"] 12 | size_output = config['size_output'] 13 | method = config['method'] 14 | h5savepath = config["hrDir"] + f'_label={size_output}_val_ESRGANx{scale}.h5' 15 | hrDir = config["hrDir"] + '/' 16 | h5_file = h5py.File(h5savepath, 'w') 17 | lr_group = h5_file.create_group('data') 18 | hr_group = h5_file.create_group('label') 19 | imgList = os.listdir(hrDir) 20 | for i, imgName in enumerate(imgList): 21 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale).convert('RGB') 22 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True).astype(np.float32) 23 | lr = imresize(np.array(hrIMG).astype(np.float32), 1 / scale, method) 24 | 25 | data = lr.astype(np.float32).transpose([2, 0, 1]) 26 | label = hr.transpose([2, 0, 1]) 27 | 28 | lr_group.create_dataset(str(i), data=data) 29 | hr_group.create_dataset(str(i), data=label) 30 | h5_file.close() 31 | 32 | 33 | def scale_OST(): 34 | # some OST image has a size lower than 128*128, resize it with scale 2 35 | root_dir = '../datasets/OST/' 36 | 37 | img_names = os.listdir(root_dir) 38 | for name in img_names: 39 | img_path = root_dir + name 40 | image = Image.open(img_path) 41 | 42 | if image.width < 128 or image.height < 128: 43 | print(img_path) 44 | image.save('../datasets/OST_backup/' + name) 45 | image = np.array(image) 46 | image = imresize(image, 2, 'bicubic') 47 | image = Image.fromarray(image.astype(np.uint8)) 48 | image.save(img_path) 49 | 50 | 51 | if __name__ == '__main__': 52 | config = {'hrDir': '../datasets/Set14', 'scale': 4, 'size_output': 128, 'method': 'bicubic'} 53 | gen_valdata(config) 54 | -------------------------------------------------------------------------------- /SRResNet/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.backends import cudnn 4 | import utils 5 | from PIL import Image 6 | from imresize import 
imresize 7 | import os 8 | import niqe 9 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 10 | if __name__ == '__main__': 11 | model_name = 'ESRGAN' 12 | rootdir = '../datasets/' 13 | gnd_data = 'BSDS100/' 14 | test_data = f'SRGAN_official/{gnd_data}' 15 | gnd_dir = rootdir + gnd_data 16 | test_dir = rootdir + test_data 17 | scale = 4 18 | padding = scale 19 | 20 | 21 | cudnn.benchmark = True 22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | 24 | 25 | imglist = os.listdir(gnd_dir) 26 | testlist = os.listdir(test_dir) 27 | Avg_psnr = utils.AverageMeter() 28 | Avg_niqe = utils.AverageMeter() 29 | for idx, imgName in enumerate(imglist): 30 | image = utils.loadIMG_crop(gnd_dir + imgName, scale) 31 | SR = utils.loadIMG_crop(test_dir + testlist[idx], scale) 32 | img_mode = image.mode 33 | if img_mode == 'L': 34 | gray_img = np.array(image) 35 | hr_image = np.array(image.convert('RGB')).astype(np.float32) 36 | hr_image = hr_image[padding: -padding, padding: -padding, ...] 37 | 38 | SR = np.array(SR).astype(np.float32) 39 | SR = SR[padding: -padding, padding: -padding, ...] 40 | if img_mode != 'L': 41 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255. 42 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0]/255. 43 | else: 44 | gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...] 45 | hr_y = gray_img / 255. 46 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L') 47 | SR_y = np.array(SR).astype(np.float32) / 255. 
48 | psnr = utils.calc_psnr(hr_y, SR_y) 49 | NIQE = niqe.calculate_niqe(SR_y) 50 | Avg_psnr.update(psnr, 1) 51 | Avg_niqe.update(NIQE, 1) 52 | print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}'.format(psnr.item(), NIQE)) 53 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg)) -------------------------------------------------------------------------------- /ESRGAN/demo_realesrgan.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | import torch 5 | from torch.backends import cudnn 6 | import niqe 7 | import utils 8 | from model import G, G2 9 | from PIL import Image 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 13 | if __name__ == '__main__': 14 | model_name = 'RealESRGAN' 15 | weight_file = '../weight_file/RealESRGAN_x4plus.pth' 16 | img_dir = '../datasets/Real-ESRGAN_input/' 17 | outputs_dir = './test_res/realESRGAN_x4/' 18 | utils.mkdirs(outputs_dir) 19 | scale = 4 20 | padding = scale 21 | 22 | if not os.path.exists(weight_file): 23 | print(f'Weight file not exist!\n{weight_file}\n') 24 | raise "Error" 25 | 26 | cudnn.benchmark = True 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | 29 | model = G2().to(device) 30 | checkpoint = torch.load(weight_file) 31 | if model_name == 'ESRGAN': 32 | model.load_state_dict(checkpoint['gen']) 33 | else: 34 | model.load_state_dict(checkpoint['params_ema']) 35 | model.eval() 36 | imglist = os.listdir(img_dir) 37 | 38 | for imgName in imglist: 39 | image = utils.loadIMG_crop(img_dir + imgName, scale) 40 | img_mode = image.mode 41 | if img_mode == 'L': 42 | gray_img = np.array(image) 43 | image = image.convert('RGB') 44 | lr_image = np.array(image) 45 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 46 | lr /= 255. 
47 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 48 | 49 | with torch.no_grad(): 50 | SR = model(lr) 51 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 52 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 53 | if img_mode != 'L': 54 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255. 55 | else: 56 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L') 57 | SR_y = np.array(SR).astype(np.float32) / 255. 58 | # GPU tensor -> CPU tensor -> numpy 59 | output = np.array(SR).astype(np.uint8) 60 | output = Image.fromarray(output) # hw -> wh 61 | output.save(outputs_dir + imgName) 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /HCFlow/test_SR_DF2K_4X_HCFlow.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_HCFlow_DF2K_x4_bicSR_test 3 | suffix: ~ 4 | use_tb_logger: false 5 | model: HCFlow_SR 6 | distortion: sr 7 | scale: 4 8 | quant: 64 9 | gpu_ids: [0] 10 | 11 | 12 | 13 | datasets: 14 | test0: 15 | name: example 16 | mode: GTLQ 17 | dataroot_GT: ../datasets/example_general_4X/HR 18 | dataroot_LQ: ../datasets/example_general_4X/LR 19 | 20 | # test_1: 21 | # name: Set5 22 | # mode: GTLQx 23 | # dataroot_GT: ../datasets/Set5/HR 24 | # dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 25 | 26 | # test_2: 27 | # name: Set14 28 | # mode: GTLQx 29 | # dataroot_GT: ../datasets/Set14/HR 30 | # dataroot_LQ: ../datasets/Set14/LR_bicubic/X4 31 | # 32 | # test_3: 33 | # name: BSD100 34 | # mode: GTLQx 35 | # dataroot_GT: ../datasets/BSD100/HR 36 | # dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4 37 | # 38 | # test_4: 39 | # name: Urban100 40 | # mode: GTLQx 41 | # dataroot_GT: ../datasets/Urban100/HR 42 | # dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4 43 | # 44 | # test_5: 45 | # name: DIV2K-va-4X 46 | # mode: GTLQ 47 | # dataroot_GT: ../datasets/srflow_datasets/div2k-validation-modcrop8-gt 48 | # dataroot_LQ: 
../datasets/srflow_datasets/div2k-validation-modcrop8-x4 49 | 50 | 51 | #### network structures 52 | network_G: 53 | which_model_G: HCFlowNet_SR 54 | in_nc: 3 55 | out_nc: 3 56 | act_norm_start_step: 100 57 | 58 | flowDownsampler: 59 | K: 26 60 | L: 2 61 | flow_permutation: invconv 62 | flow_coupling: Affine 63 | nn_module: FCN 64 | hidden_channels: 64 65 | cond_channels: ~ 66 | splitOff: 67 | enable: true 68 | after_flowstep: [13, 13] 69 | flow_permutation: invconv 70 | flow_coupling: Affine 71 | nn_module: FCN 72 | nn_module_last: Conv2dZeros 73 | hidden_channels: 64 74 | RRDB_nb: [7, 7] 75 | RRDB_nf: 64 76 | RRDB_gc: 32 77 | 78 | 79 | #### validation settings 80 | val: 81 | heats: [0.0] 82 | n_sample: 3 83 | 84 | 85 | path: 86 | strict_load: true 87 | load_submodule: ~ 88 | # pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth 89 | # pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow+.pth 90 | pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow++.pth 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /SRFlow/thops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | 19 | 20 | def sum(tensor, dim=None, keepdim=False): 21 | if dim is None: 22 | # sum up all dim 23 | return torch.sum(tensor) 24 | else: 25 | if isinstance(dim, int): 26 | dim = [dim] 27 | dim = sorted(dim) 28 | for d in dim: 29 | tensor = tensor.sum(dim=d, keepdim=True) 30 | if not keepdim: 31 | for i, d in enumerate(dim): 32 | tensor.squeeze_(d-i) 33 | return tensor 34 | 35 | 36 | def mean(tensor, dim=None, keepdim=False): 37 | if dim is None: 38 | # mean all dim 39 | return torch.mean(tensor) 40 | else: 41 | if isinstance(dim, int): 42 | dim = [dim] 43 | dim = sorted(dim) 44 | for d in dim: 45 | tensor = tensor.mean(dim=d, keepdim=True) 46 | if not keepdim: 47 | for i, d in enumerate(dim): 48 | tensor.squeeze_(d-i) 49 | return tensor 50 | 51 | 52 | def split_feature(tensor, type="split"): 53 | """ 54 | type = ["split", "cross"] 55 | """ 56 | C = tensor.size(1) 57 | if type == "split": 58 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] 59 | elif type == "cross": 60 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 61 | 62 | 63 | def cat_feature(tensor_a, tensor_b): 64 | return torch.cat((tensor_a, tensor_b), dim=1) 65 | 66 | 67 | def pixels(tensor): 68 | return int(tensor.size(2) * tensor.size(3)) -------------------------------------------------------------------------------- /SRFlow/Permutations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn as nn 20 | from torch.nn import functional as F 21 | import thops 22 | 23 | 24 | class InvertibleConv1x1(nn.Module): 25 | def __init__(self, num_channels, LU_decomposed=False): 26 | super().__init__() 27 | w_shape = [num_channels, num_channels] 28 | w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) 29 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 30 | self.w_shape = w_shape 31 | self.LU = LU_decomposed 32 | 33 | def get_weight(self, input, reverse): 34 | w_shape = self.w_shape 35 | pixels = thops.pixels(input) 36 | dlogdet = torch.slogdet(self.weight)[1] * pixels 37 | if not reverse: 38 | weight = self.weight.view(w_shape[0], w_shape[1], 1, 1) 39 | else: 40 | weight = torch.inverse(self.weight.double()).float() \ 41 | .view(w_shape[0], w_shape[1], 1, 1) 42 | return weight, dlogdet 43 | def forward(self, input, logdet=None, reverse=False): 44 | """ 45 | log-det = log|abs(|W|)| * pixels 46 | """ 47 | weight, dlogdet = self.get_weight(input, reverse) 48 | if not reverse: 49 | z = F.conv2d(input, weight) 50 | if logdet is not None: 51 | logdet = logdet + dlogdet 52 | return z, logdet 53 | else: 54 | z = F.conv2d(input, weight) 55 | if logdet is not None: 56 | 
logdet = logdet - dlogdet 57 | return z, logdet 58 | -------------------------------------------------------------------------------- /HCFlow/HCFlow_SR_model.py: -------------------------------------------------------------------------------- 1 | # base model for HCFlow 2 | import logging 3 | from collections import OrderedDict 4 | import torch 5 | from HCFlowNet_SR_arch import HCFlowNet_SR 6 | from base_model import BaseModel 7 | logger = logging.getLogger('base') 8 | 9 | 10 | class HCFlowSRModel(BaseModel): 11 | def __init__(self, opt, heats=[0.0], step=0): 12 | super(HCFlowSRModel, self).__init__(opt) 13 | self.opt = opt 14 | 15 | self.rank = -1 # non dist training 16 | 17 | # define network and load pretrained models 18 | self.netG = HCFlowNet_SR(opt, step).to(self.device) 19 | self.heats = heats 20 | # val 21 | if 'val' in opt: 22 | # self.heats = opt['val']['heats'] 23 | self.n_sample = opt['val']['n_sample'] 24 | self.sr_mode = opt['val']['sr_mode'] 25 | 26 | def feed_data(self, data, need_GT=True): 27 | self.var_L = data['LQ'].to(self.device) # LQ 28 | if need_GT: 29 | self.real_H = data['GT'].to(self.device) # GT 30 | else: 31 | self.real_H = None 32 | 33 | def test(self): 34 | self.netG.eval() 35 | self.fake_H = {} 36 | 37 | with torch.no_grad(): 38 | if self.real_H is None: 39 | nll = torch.zeros(1) 40 | else: 41 | # hr->lr+z, calculate nll 42 | self.fake_L_from_H, nll = self.netG(hr=self.real_H, lr=self.var_L, u=None, reverse=False, training=False) 43 | 44 | # lr+z->hr 45 | for heat in self.heats: 46 | for sample in range(self.n_sample): 47 | # z = self.get_z(heat, seed=1, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) 48 | self.fake_H[(heat, sample)] = self.netG(lr=self.var_L, 49 | z=None, u=None, eps_std=heat, reverse=True, training=False) 50 | 51 | self.netG.train() 52 | 53 | return nll.mean().item() 54 | 55 | def get_current_visuals(self, need_GT=True): 56 | out_dict = OrderedDict() 57 | out_dict['LQ'] = 
self.var_L.detach()[0].float().cpu() 58 | for heat in self.heats: 59 | for i in range(self.n_sample): 60 | out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() 61 | 62 | if need_GT: 63 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 64 | out_dict['LQ_fromH'] = self.fake_L_from_H.detach()[0].float().cpu() 65 | 66 | return out_dict 67 | 68 | -------------------------------------------------------------------------------- /HCFlow/FlowStep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | import ActNorms, Permutations, AffineCouplings 5 | 6 | 7 | class FlowStep(nn.Module): 8 | def __init__(self, in_channels, cond_channels=None, flow_permutation='invconv', flow_coupling='Affine', LRvsothers=True, 9 | actnorm_scale=1.0, LU_decomposed=False, opt=None): 10 | super().__init__() 11 | self.flow_permutation = flow_permutation 12 | self.flow_coupling = flow_coupling 13 | 14 | # 1. actnorm 15 | self.actnorm = ActNorms.ActNorm2d(in_channels, actnorm_scale) 16 | 17 | # 2. permute # todo: maybe hurtful for downsampling; presever the structure of downsampling 18 | if self.flow_permutation == "invconv": 19 | self.permute = Permutations.InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed) 20 | elif self.flow_permutation == "none": 21 | self.permute = None 22 | 23 | # 3. 
coupling 24 | if self.flow_coupling == "AffineInjector": 25 | self.affine = AffineCouplings.AffineCouplingInjector(in_channels=in_channels, cond_channels=cond_channels, opt=opt) 26 | elif self.flow_coupling == "noCoupling": 27 | pass 28 | elif self.flow_coupling == "Affine": 29 | self.affine = AffineCouplings.AffineCoupling(in_channels=in_channels, cond_channels=cond_channels, opt=opt) 30 | elif self.flow_coupling == "Affine3shift": 31 | self.affine = AffineCouplings.AffineCoupling3shift(in_channels=in_channels, cond_channels=cond_channels, LRvsothers=LRvsothers, opt=opt) 32 | 33 | def forward(self, z, u=None, logdet=None, reverse=False): 34 | if not reverse: 35 | return self.normal_flow(z, u, logdet) 36 | else: 37 | return self.reverse_flow(z, u) 38 | 39 | def normal_flow(self, z, u=None, logdet=None): 40 | # 1. actnorm 41 | z, logdet = self.actnorm(z, logdet=logdet, reverse=False) 42 | 43 | # 2. permute 44 | if self.permute is not None: 45 | z, logdet = self.permute( z, logdet=logdet, reverse=False) 46 | 47 | # 3. coupling 48 | z, logdet = self.affine(z, u=u, logdet=logdet, reverse=False) 49 | 50 | return z, logdet 51 | 52 | def reverse_flow(self, z, u=None, logdet=None): 53 | # 1.coupling 54 | z, _ = self.affine(z, u=u, reverse=True) 55 | 56 | # 2. permute 57 | if self.permute is not None: 58 | z, _ = self.permute(z, reverse=True) 59 | 60 | # 3. actnorm 61 | z, _ = self.actnorm(z, reverse=True) 62 | 63 | return z, logdet 64 | 65 | -------------------------------------------------------------------------------- /SRFlow/SRFlow_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import torch 18 | from SRFlowNet_arch import SRFlowNet 19 | from base_model import BaseModel 20 | 21 | class SRFlowModel(BaseModel): 22 | def __init__(self, opt, step=0): 23 | super(SRFlowModel, self).__init__(opt) 24 | self.opt = opt 25 | 26 | self.heats = opt['val']['heats'] 27 | self.n_sample = opt['val']['n_sample'] 28 | # define network and load pretrained models 29 | opt_net = opt['network_G'] 30 | self.netG = SRFlowNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], 31 | scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step).to(self.device) 32 | 33 | def get_sr(self, lq, heat=None, seed=None, z=None, epses=None): 34 | return self.get_sr_with_z(lq, heat, seed, z, epses)[0] 35 | 36 | def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None): 37 | self.netG.eval() 38 | 39 | z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z 40 | 41 | with torch.no_grad(): 42 | sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses) 43 | self.netG.train() 44 | return sr, z 45 | 46 | def get_z(self, heat, seed=None, batch_size=1, lr_shape=None): 47 | if seed: torch.manual_seed(seed) 48 | C = self.netG.flowUpsamplerNet.C 49 | H = 
int(self.opt['scale'] * lr_shape[2] // self.netG.flowUpsamplerNet.scaleH) 50 | W = int(self.opt['scale'] * lr_shape[3] // self.netG.flowUpsamplerNet.scaleW) 51 | z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros( 52 | (batch_size, C, H, W)) 53 | return z 54 | -------------------------------------------------------------------------------- /HCFlow/networks.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import torch 4 | import discriminator_vgg_arch as SRGAN_arch 5 | 6 | logger = logging.getLogger('base') 7 | 8 | 9 | def find_model_using_name(model_name): 10 | model_filename = "models.modules." + model_name + "_arch" 11 | modellib = importlib.import_module(model_filename) 12 | 13 | model = None 14 | target_model_name = model_name.replace('_Net', '') 15 | for name, cls in modellib.__dict__.items(): 16 | if name.lower() == target_model_name.lower(): 17 | model = cls 18 | 19 | if model is None: 20 | print( 21 | "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." 
% ( 22 | model_filename, target_model_name)) 23 | exit(0) 24 | 25 | return model 26 | 27 | def define_Flow(opt, step): 28 | opt_net = opt['network_G'] 29 | which_model = opt_net['which_model_G'] 30 | 31 | Arch = find_model_using_name(which_model) 32 | netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 33 | nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step) 34 | return netG 35 | 36 | def define_G(opt, step): 37 | which_model = opt['network_G']['which_model_G'] 38 | 39 | Arch = find_model_using_name(which_model) 40 | netG = Arch(opt=opt, step=step) 41 | return netG 42 | 43 | #### Discriminator 44 | def define_D(opt): 45 | opt_net = opt['network_D'] 46 | which_model = opt_net['which_model_D'] 47 | 48 | if which_model == 'discriminator_vgg_128': 49 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 50 | elif which_model == 'discriminator_vgg_160': 51 | netD = SRGAN_arch.Discriminator_VGG_160(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 52 | elif which_model == 'PatchGANDiscriminator': 53 | netD = SRGAN_arch.PatchGANDiscriminator(in_nc=opt_net['in_nc'], ndf=opt_net['ndf'], n_layers=opt_net['n_layers'],) 54 | else: 55 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 56 | return netD 57 | 58 | 59 | #### Define Network used for Perceptual Loss 60 | def define_F(opt, use_bn=False): 61 | gpu_ids = opt['gpu_ids'] 62 | device = torch.device('cuda' if gpu_ids else 'cpu') 63 | # PyTorch pretrained VGG19-54, before ReLU. 
64 | if use_bn: 65 | feature_layer = 49 66 | else: 67 | feature_layer = 34 68 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 69 | use_input_norm=True, device=device) 70 | netF.eval() # No need to train 71 | return netF 72 | -------------------------------------------------------------------------------- /SRFlow/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import time 18 | 19 | 20 | class ScopeTimer: 21 | def __init__(self, name): 22 | self.name = name 23 | 24 | def __enter__(self): 25 | self.start = time.time() 26 | return self 27 | 28 | def __exit__(self, *args): 29 | self.end = time.time() 30 | self.interval = self.end - self.start 31 | print("{} {:.3E}".format(self.name, self.interval)) 32 | 33 | 34 | class Timer: 35 | def __init__(self): 36 | self.times = [] 37 | 38 | def tick(self): 39 | self.times.append(time.time()) 40 | 41 | def get_average_and_reset(self): 42 | if len(self.times) < 2: 43 | return -1 44 | avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1) 45 | self.times = [self.times[-1]] 46 | return avg 47 | 48 | def get_last_iteration(self): 49 | if len(self.times) < 2: 50 | return 0 51 | return self.times[-1] - self.times[-2] 52 | 53 | 54 | class TickTock: 55 | def __init__(self): 56 | self.time_pairs = [] 57 | self.current_time = None 58 | 59 | def tick(self): 60 | self.current_time = time.time() 61 | 62 | def tock(self): 63 | assert self.current_time is not None, self.current_time 64 | self.time_pairs.append([self.current_time, time.time()]) 65 | self.current_time = None 66 | 67 | def get_average_and_reset(self): 68 | if len(self.time_pairs) == 0: 69 | return -1 70 | deltas = [t2 - t1 for t1, t2 in self.time_pairs] 71 | avg = sum(deltas) / len(deltas) 72 | self.time_pairs = [] 73 | return avg 74 | 75 | def get_last_iteration(self): 76 | if len(self.time_pairs) == 0: 77 | return -1 78 | return self.time_pairs[-1][1] - self.time_pairs[-1][0] 79 | -------------------------------------------------------------------------------- /HCFlow/HCFlowNet_SR_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from util import opt_get 7 
| import Basic, thops 8 | 9 | 10 | 11 | class HCFlowNet_SR(nn.Module): 12 | def __init__(self, opt, step=None): 13 | super(HCFlowNet_SR, self).__init__() 14 | self.opt = opt 15 | self.quant = opt_get(opt, ['quant'], 256) 16 | 17 | hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160) 18 | hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3) 19 | scale = opt_get(opt, ['scale']) 20 | 21 | if scale == 4: 22 | from FlowNet_SR_x4 import FlowNet 23 | else: 24 | raise NotImplementedError('Scale {} is not implemented'.format(scale)) 25 | 26 | # hr->lr+z 27 | self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt) 28 | 29 | self.quantization = Basic.Quantization() 30 | 31 | # hr: HR image, lr: LR image, z: latent variable, u: conditional variable 32 | def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None, 33 | add_gt_noise=False, step=None, reverse=False, training=True): 34 | 35 | # hr->z 36 | if not reverse: 37 | return self.normal_flow_diracLR(hr, lr, u, step=step, training=training) 38 | # z->hr 39 | else: 40 | return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training) 41 | 42 | 43 | #########################################diracLR 44 | # hr->lr+z, diracLR 45 | def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True): 46 | # 1. quantitize HR 47 | pixels = thops.pixels(hr) 48 | 49 | # according to Glow and ours, it should be u~U(0,a) (0.06 better in practice), not u~U(-0.5,0.5) (though better in theory) 50 | hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant 51 | logdet = torch.zeros_like(hr[:, 0, 0, 0]) + float(-np.log(self.quant) * pixels) 52 | 53 | # 2. hr->lr+z 54 | fake_lr_from_hr, logdet = self.flow(hr=hr, u=u, logdet=logdet, reverse=False, training=training) 55 | 56 | # note in rescaling, we use LR for LR loss before quantization 57 | fake_lr_from_hr = self.quantization(fake_lr_from_hr) 58 | 59 | # 3. loss, Gaussian with small variance to approximate Dirac delta function of LR. 
60 | # for the second term, using small log-variance may lead to svd problem, for both exp and tanh version 61 | objective = logdet + Basic.GaussianDiag.logp(lr, -torch.ones_like(lr)*6, fake_lr_from_hr) 62 | 63 | nll = ((-objective) / float(np.log(2.) * pixels)).mean() 64 | 65 | return torch.clamp(fake_lr_from_hr, 0, 1), nll 66 | 67 | # lr+z->hr 68 | def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True): 69 | 70 | # lr+z->hr 71 | fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training) 72 | 73 | return torch.clamp(fake_hr, 0, 1) 74 | -------------------------------------------------------------------------------- /SwinIR/test_classical_swinir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.backends import cudnn 4 | from tqdm import tqdm 5 | 6 | import utils 7 | from swinir_model import SwinIR 8 | from PIL import Image 9 | from imresize import imresize 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 13 | cudnn.benchmark = True 14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 15 | 16 | if __name__ == '__main__': 17 | scale = 4 18 | model_name = 'SwinIR-Classical' 19 | weight_file = '../weight_file/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth' 20 | dataset = 'degraded_srflow' 21 | root_dir = f'../datasets/{dataset}/' 22 | out_root_dir = f'./test_res/{model_name}_{dataset}/' 23 | hr_dir = '/data0/jli/datasets/PIPAL/' 24 | 25 | lr_dirs = os.listdir(root_dir) 26 | for dir in lr_dirs: 27 | utils.mkdirs(out_root_dir + dir) 28 | 29 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8, 30 | img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], 31 | mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv').to(device) 32 | checkpoint = torch.load(weight_file) 33 | model.load_state_dict(checkpoint['params']) 34 | model.eval() 35 | 36 | 37 | for dir in 
lr_dirs: 38 | outputs_dir = out_root_dir + dir + '/' 39 | lr_dir = root_dir + dir + '/' 40 | lr_lists = os.listdir(lr_dir) 41 | with tqdm(total=len(lr_lists)) as t: 42 | t.set_description(f"Processing: {dir}") 43 | for imgName in lr_lists: 44 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 45 | if image.mode != 'L': 46 | image = image.convert('RGB') 47 | lr_image = np.array(image) 48 | 49 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 50 | lr /= 255. 51 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 52 | 53 | with torch.no_grad(): 54 | SR = model(lr) 55 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 56 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 57 | # GPU tensor -> CPU tensor -> numpy 58 | SR = np.array(SR).astype(np.uint8) 59 | SR = Image.fromarray(SR) # hw -> wh 60 | SR.save(outputs_dir + imgName) 61 | t.update(1) 62 | 63 | sr_dirs = os.listdir(out_root_dir) 64 | for dir in sr_dirs: 65 | sr_dir = out_root_dir + dir + '/' 66 | sr_lists = os.listdir(sr_dir) 67 | with tqdm(total=len(sr_lists)) as t: 68 | t.set_description(f"Processing: {dir}") 69 | for imgName in sr_lists: 70 | image = utils.loadIMG_crop(sr_dir + imgName, scale) 71 | if image.mode != 'L': 72 | image = image.convert('RGB') 73 | lr_image = np.array(image) 74 | 75 | 76 | -------------------------------------------------------------------------------- /SRCNN/gen_datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | from tqdm import tqdm 5 | import utils 6 | from imresize import imresize 7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 8 | 9 | 10 | def gen_traindata(config): 11 | scale = config["scale"] 12 | stride = config["stride"] 13 | h5savepath = config["hrDir"] + f'_train_SRCNNx{scale}.h5' 14 | hrDir = config["hrDir"] + '/' 15 | size_input = config["size_input"] 16 | size_label = config["size_label"] 17 | padding = int(abs(size_input - size_label) / 2) 18 | 19 | 
h5_file = h5py.File(h5savepath, 'w') 20 | imgList = os.listdir(hrDir) 21 | lr_subimgs = [] 22 | hr_subimgs = [] 23 | with tqdm(total=len(imgList)) as t: 24 | for imgName in imgList: 25 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale) 26 | hr_y = utils.img2ycbcr(hrIMG)[..., 0] 27 | lr = imresize(hr_y, 1 / scale, 'bicubic') 28 | lr = imresize(lr, scale, 'bicubic').astype(np.float32) 29 | 30 | for r in range(0, hr_y.shape[0] - size_input + 1, stride): # hr.height 31 | for c in range(0, hr_y.shape[1] - size_input + 1, stride): # hr.width 32 | lr_subimgs.append(lr[r: r + size_input, c: c + size_input]) 33 | hr_subimgs.append(hr_y[r + padding: r + padding + size_label, c + padding: c + padding + size_label]) 34 | t.update(1) 35 | 36 | lr_subimgs = np.array(lr_subimgs) 37 | hr_subimgs = np.array(hr_subimgs) 38 | 39 | h5_file.create_dataset('data', data=lr_subimgs) 40 | h5_file.create_dataset('label', data=hr_subimgs) 41 | 42 | h5_file.close() 43 | 44 | 45 | def gen_valdata(config): 46 | scale = config["scale"] 47 | h5savepath = config["hrDir"] + f'_val_SRCNNx{scale}.h5' 48 | hrDir = config["hrDir"] + '/' 49 | size_input = config["size_input"] 50 | size_label = config["size_label"] 51 | padding = int(abs(size_input - size_label) / 2) 52 | 53 | h5_file = h5py.File(h5savepath, 'w') 54 | lr_group = h5_file.create_group('data') 55 | hr_group = h5_file.create_group('label') 56 | imgList = os.listdir(hrDir) 57 | for i, imgName in enumerate(imgList): 58 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale) 59 | hr_y = utils.img2ycbcr(hrIMG)[..., 0] 60 | 61 | data = imresize(hr_y, 1 / scale, 'bicubic') 62 | data = imresize(data, scale, 'bicubic') 63 | data = data.astype(np.float32) 64 | label = hr_y[padding: -padding, padding: -padding] 65 | lr_group.create_dataset(str(i), data=data) 66 | hr_group.create_dataset(str(i), data=label) 67 | 68 | h5_file.close() 69 | 70 | 71 | if __name__ == '__main__': 72 | # config = {'hrDir': '../datasets/T91_aug', 'scale': 3, "stride": 14, 
"size_input": 33, "size_label": 21} 73 | config = {'hrDir': '../datasets/T91_aug', 'scale': 2, "stride": 14, "size_input": 22, "size_label": 10} 74 | gen_traindata(config) 75 | config['hrDir'] = '../datasets/Set5' 76 | gen_valdata(config) 77 | -------------------------------------------------------------------------------- /BSRGAN/test_bsrgan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.backends import cudnn 4 | from tqdm import tqdm 5 | import utils 6 | from bsrgan_model import RRDBNet 7 | from PIL import Image 8 | import os 9 | import torch.nn.functional as F 10 | 11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 12 | cudnn.benchmark = True 13 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 14 | if __name__ == '__main__': 15 | model_name = 'BSRGAN' 16 | scale = 4 17 | itplt = 'bilinear' 18 | # itplt = 'bicubic' 19 | # dataset = f'degraded5_{itplt}_medium' 20 | # dataset = f'degraded6_{itplt}_slight' 21 | # dataset = f'degraded4_{itplt}_heavy' 22 | # dataset = f'degraded1_offset2' 23 | dataset = f'degraded1_bilinear_heavy' 24 | 25 | weight_file = '../weight_file/BSRGAN.pth' 26 | root_dir = f'../datasets/{dataset}/' 27 | out_root_dir_ImgName = f'./test_res/{model_name}_{dataset}/sort_with_ImgName/' 28 | out_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/' 29 | hr_dir = '../datasets/PIPAL/' 30 | 31 | lr_dirs = os.listdir(root_dir) 32 | # Sort with ImgName 33 | out1_dirs = os.listdir(hr_dir) 34 | for dir in out1_dirs: 35 | dir = dir.split('.')[0] + '/' 36 | utils.mkdirs(out_root_dir_ImgName + dir) 37 | # Sort with degradation 38 | for dir in lr_dirs: 39 | utils.mkdirs(out_root_dir_degradation + dir) 40 | 41 | model = RRDBNet().to(device) 42 | checkpoint = torch.load(weight_file) 43 | model.load_state_dict(checkpoint) 44 | model.eval() 45 | 46 | for dir in lr_dirs: 47 | out_by_deg = out_root_dir_degradation + dir + 
'/' 48 | lr_dir = root_dir + dir + '/' 49 | lr_lists = os.listdir(lr_dir) 50 | with tqdm(total=len(lr_lists)) as t: 51 | t.set_description(f"Processing: {dir}") 52 | for imgName in lr_lists: 53 | out_by_name = out_root_dir_ImgName + imgName.split('.')[0] + '/' 54 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 55 | # image = utils.ImgOffSet(image, offset, offset) 56 | if image.mode == 'L': 57 | image = image.convert('RGB') 58 | lr_image = np.array(image) 59 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 60 | lr /= 255. 61 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 62 | lr = F.pad(lr, pad=[1, 1, 1, 1], mode='constant') 63 | with torch.no_grad(): 64 | SR = model(lr) 65 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 66 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 67 | # GPU tensor -> CPU tensor -> numpy 68 | SR = np.array(SR).astype(np.uint8) 69 | SR = Image.fromarray(SR) # hw -> wh 70 | SR.save(out_by_name + dir + '.bmp') 71 | SR.save(out_by_deg + imgName) 72 | t.update(1) 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /SRResNet/train2.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from tqdm import tqdm 3 | import model_utils 4 | import utils 5 | import os 6 | import torch 7 | from torch.backends import cudnn 8 | from model import G 9 | from torch import nn, optim 10 | from SRResNetdatasets import SRResNetValDataset, DIV2KDataset, DIV2KSubDataset 11 | from torch.utils.data.dataloader import DataLoader 12 | 13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 14 | 15 | 16 | def train_model(config, from_pth=False): 17 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 18 | 19 | outputs_dir = config['outputs_dir'] 20 | batch_size = config['batch_size'] 21 | utils.mkdirs(outputs_dir) 22 | csv_file = outputs_dir + config['csv_name'] 23 | logs_dir = config['logs_dir'] 24 | 
cudnn.benchmark = True 25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 26 | torch.manual_seed(config['seed']) 27 | 28 | # ----需要修改部分------ 29 | print("===> Loading datasets") 30 | train_dataset = DIV2KSubDataset() 31 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 32 | batch_size=batch_size, shuffle=True, pin_memory=True) 33 | val_dataset = SRResNetValDataset(config['val_file']) 34 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 35 | 36 | print("===> Building model") 37 | model = G() 38 | if not from_pth: 39 | model.init_weight() 40 | criterion = nn.MSELoss().cuda() 41 | optimizer = optim.Adam(model.parameters(), lr=config['lr']) 42 | # ----END------ 43 | start_step, best_step, best_psnr, writer, csv_file = \ 44 | model_utils.load_checkpoint_iter(config['weight_file'], model, optimizer, csv_file, 45 | from_pth, auto_lr=config['auto_lr']) 46 | 47 | if torch.cuda.device_count() > 1: 48 | print("Using GPUs.\n") 49 | model = torch.nn.DataParallel(model) 50 | model = model.to(device) 51 | 52 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 53 | 'test': SummaryWriter(f"{logs_dir}/test")} 54 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 55 | num_steps = config['num_steps'] 56 | iter_of_epoch = 1000 57 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 'milestone': [2e5], 58 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps, 59 | 'best_psnr': best_psnr, 'best_step': best_step, 'iter_of_epoch': iter_of_epoch} 60 | 61 | while global_info['step'] < config['num_steps']: 62 | with tqdm(total=len(train_dataset)) as t: 63 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1)) 64 | global_info['t'] = t 65 | model_utils.train_iter(model, dataloaders, optimizer, criterion, global_info) 66 | print('best step: {}, psnr: 
{:.2f}'.format(global_info['best_step'], global_info['best_psnr'])) 67 | csv_file.close() 68 | -------------------------------------------------------------------------------- /SRFlow/SRFlow_DF2K_4X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SRFlow 21 | distortion: sr 22 | scale: 4 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/DF2K-tr.pklv4 31 | dataroot_LQ: ../datasets/DF2K-tr_X4.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 12 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | val: 41 | name: CelebA_160_va 42 | mode: LRHR_PKL 43 | dataroot_GT: ../datasets/DIV2K-va.pklv4 44 | dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4 45 | quant: 32 46 | n_max: 20 47 | 48 | #### Test Settings 49 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt 50 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x4 51 | model_path: ../pretrained_models/SRFlow_DF2K_4X.pth 52 | heat: 0.9 # This is the standard deviation of the latent vectors 53 | 54 | #### network structures 55 | network_G: 56 | which_model_G: SRFlowNet 57 | in_nc: 3 58 | out_nc: 3 59 | nf: 64 60 | nb: 23 61 | upscale: 4 62 | train_RRDB: false 63 | train_RRDB_delay: 0.5 64 | 65 | flow: 66 | K: 16 67 | L: 3 68 | noInitialInj: true 69 | coupling: CondAffineSeparatedAndCond 70 | additionalFlowNoAffine: 2 71 | split: 72 | enable: true 73 | fea_up0: true 74 | stackRRDB: 75 | blocks: [ 1, 8, 15, 22 ] 76 | concat: true 77 | 78 | #### path 79 | path: 80 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_4X.pth 81 | strict_load: true 82 | resume_state: auto 83 | 84 | #### training settings: learning rate scheme, loss 85 | train: 86 | manual_seed: 10 87 | lr_G: !!float 2.5e-4 88 | weight_decay_G: 0 89 | beta1: 0.9 90 | beta2: 0.99 91 | lr_scheme: MultiStepLR 92 | warmup_iter: -1 # no warm up 93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] 94 | lr_gamma: 0.5 95 | 96 | niter: 200000 97 | val_freq: 40000 98 | 99 | #### 
validation settings 100 | val: 101 | heats: [ 1.0 ] 102 | n_sample: 1 103 | 104 | #### logger 105 | logger: 106 | print_freq: 100 107 | save_checkpoint_freq: !!float 1e3 108 | -------------------------------------------------------------------------------- /HCFlow/test_hcflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | import options 5 | import torch 6 | from torch.backends import cudnn 7 | import utils 8 | from HCFlow_SR_model import HCFlowSRModel 9 | from PIL import Image 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 13 | if __name__ == '__main__': 14 | model_name = 'HCFlowSR' 15 | weight_file = '../weight_file/SR_DF2K_X4_HCFlow++.pth' 16 | root_dir = '../datasets/degraded_srflow/' 17 | out_root_dir = f'../test_res/{model_name}/' 18 | hr_dir = '../datasets/PIPAL/' 19 | 20 | lr_dirs = os.listdir(root_dir) 21 | for dir in lr_dirs: 22 | utils.mkdirs(out_root_dir + dir) 23 | 24 | scale = 4 25 | padding = scale 26 | 27 | if not os.path.exists(weight_file): 28 | print(f'Weight file not exist!\n{weight_file}\n') 29 | raise "Error" 30 | 31 | cudnn.benchmark = True 32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 33 | opt = options.parse('./test_SR_DF2K_4X_HCFlow.yml', is_train=False) 34 | opt['gpu_ids'] = '0' 35 | opt = options.dict_to_nonedict(opt) 36 | heat = 0.8 37 | model = HCFlowSRModel(opt, [heat]) 38 | checkpoint = torch.load(weight_file) 39 | model.netG.load_state_dict(checkpoint) 40 | offset = 0 41 | 42 | for dir in lr_dirs: 43 | outputs_dir = out_root_dir + dir + '/' 44 | lr_dir = root_dir + dir + '/' 45 | lr_lists = os.listdir(lr_dir) 46 | with tqdm(total=len(lr_lists)) as t: 47 | t.set_description(f"Processing: {dir}") 48 | for imgName in lr_lists: 49 | # outputs_dir = out_root_dir + imgName.split('.')[0] + '/' 50 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 51 | image = utils.ImgOffSet(image, offset, 
offset) 52 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale) 53 | hr_image = utils.ImgOffSet(hr_image, offset*scale, offset*scale) 54 | 55 | img_mode = image.mode 56 | if img_mode == 'L': 57 | gray_img = np.array(image) 58 | image = image.convert('RGB') 59 | lr_image = np.array(image) 60 | hr_image = np.array(hr_image) 61 | 62 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 63 | lr /= 255. 64 | lr = torch.from_numpy(lr).unsqueeze(0) 65 | 66 | data = {'LQ': lr} 67 | model.feed_data(data, need_GT=False) 68 | model.test() 69 | visuals = model.get_current_visuals(need_GT=False) 70 | for sample in range(len(visuals)-1): 71 | SR = visuals['SR', heat, sample] 72 | SR = SR.mul(255.0).cpu().numpy() 73 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 74 | 75 | # GPU tensor -> CPU tensor -> numpy 76 | output = np.array(SR).astype(np.uint8) 77 | output = Image.fromarray(output) # hw -> wh 78 | tmp = outputs_dir + f'sample={sample}/' 79 | utils.mkdirs(tmp) 80 | output.save(tmp + imgName) 81 | t.update(1) 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /SRCNN/train.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | from tqdm import tqdm 5 | import utils 6 | import os 7 | import torch 8 | from torch.backends import cudnn 9 | from models import SRCNN 10 | from torch import nn, optim 11 | from SRCNNdatasets import TrainDataset, ValDataset 12 | from torch.utils.data.dataloader import DataLoader 13 | import model_utils 14 | 15 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 16 | 17 | 18 | def train_model(config, from_pth=False): 19 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 20 | 21 | outputs_dir = config['outputs_dir'] 22 | lr = config['lr'] 23 | batch_size = config['batch_size'] 24 | num_epochs = config['num_epochs'] 25 | csv_file = outputs_dir + config['csv_name'] 26 | logs_dir 
= config['logs_dir'] 27 | utils.mkdirs(outputs_dir) 28 | utils.mkdirs(logs_dir) 29 | 30 | cudnn.benchmark = True 31 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 32 | torch.manual_seed(config['seed']) 33 | 34 | model = SRCNN() 35 | if not from_pth: 36 | model.init_weights() 37 | criterion = nn.MSELoss() 38 | optimizer = optim.SGD([ 39 | {'params': model.conv1.parameters()}, 40 | {'params': model.conv2.parameters()}, 41 | {'params': model.conv3.parameters(), 'lr': lr * 0.1} 42 | ], lr=lr, momentum=0.9) # 前两层学习率lr, 最后一层学习率lr*0.1 43 | 44 | train_dataset = TrainDataset(config['train_file']) 45 | train_dataloader = DataLoader(dataset=train_dataset, 46 | batch_size=batch_size, 47 | shuffle=True, 48 | num_workers=config['num_workers'], 49 | pin_memory=True, 50 | drop_last=True) 51 | val_dataset = ValDataset(config['val_file']) 52 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 53 | 54 | start_epoch, best_epoch, best_psnr, writer, csv_file = \ 55 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, from_pth) 56 | 57 | if torch.cuda.device_count() > 1: 58 | print("Using GPUs.\n") 59 | model = torch.nn.DataParallel(model) 60 | model = model.to(device) 61 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar") 62 | 63 | for epoch in range(start_epoch, num_epochs): 64 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t: 65 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}') 66 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t) 67 | 68 | if isinstance(model, torch.nn.DataParallel): 69 | model = model.module 70 | 71 | epoch_psnr = model_utils.validate(model, val_dataloader, device) 72 | 73 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch) 74 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch) 75 | 76 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses, 77 | epoch_psnr, best_psnr, 
best_epoch, outputs_dir, writer) 78 | 79 | csv_file.close() 80 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /ESRGAN/test_realesrgan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.backends import cudnn 4 | from tqdm import tqdm 5 | import utils 6 | from realesrgan_model import G 7 | from PIL import Image 8 | import os 9 | import torch.nn.functional as F 10 | from imresize import imresize 11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 12 | if __name__ == '__main__': 13 | model_name = 'RealESRGAN' 14 | scale = 2 15 | itplt = 'bilinear' 16 | # itplt = 'bicubic' 17 | # dataset = f'degraded5_{itplt}_medium' 18 | # dataset = f'degraded6_{itplt}_slight' 19 | # dataset = f'degraded4_{itplt}_heavy' 20 | # dataset = f'degraded1_bilinear_heavy' 21 | # dataset = f'degraded1_offset_bilinear_heavy' 22 | dataset = f'PIPAL_offset_test' 23 | 24 | weight_file = f'../weight_file/RealESRGAN_x{scale}plus.pth' 25 | root_dir = f'../datasets/{dataset}/' 26 | out_root_dir_ImgName = f'./test_res/{model_name}_{dataset}/sort_with_ImgName/' 27 | out_root_dir_degradation = f'./test_res/{model_name}_{dataset}/sort_with_degradation/' 28 | hr_dir = '../datasets/PIPAL/' 29 | 30 | lr_dirs = os.listdir(root_dir) 31 | # Sort with ImgName 32 | out1_dirs = os.listdir(hr_dir) 33 | for dir in out1_dirs: 34 | dir = dir.split('.')[0] + '/' 35 | utils.mkdirs(out_root_dir_ImgName + dir) 36 | # Sort with degradation 37 | for dir in lr_dirs: 38 | utils.mkdirs(out_root_dir_degradation + dir) 39 | 40 | offset = 1 41 | 42 | if not os.path.exists(weight_file): 43 | print(f'Weight file not exist!\n{weight_file}\n') 44 | raise "Error" 45 | 46 | cudnn.benchmark = True 47 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 48 | 49 | model = 
G(scale=scale).to(device) 50 | checkpoint = torch.load(weight_file) 51 | model.load_state_dict(checkpoint['params_ema']) 52 | model.eval() 53 | total_dir = len(lr_dirs) 54 | for idx, dir in enumerate(lr_dirs): 55 | out_by_deg = out_root_dir_degradation + dir + '/' 56 | lr_dir = root_dir + dir + '/' 57 | lr_lists = os.listdir(lr_dir) 58 | with tqdm(total=len(lr_lists)) as t: 59 | t.set_description(f"Processing {idx}/{total_dir}: {dir}") 60 | for imgName in lr_lists: 61 | out_by_name = out_root_dir_ImgName + imgName.split('.')[0] + '/' 62 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 63 | # image = utils.ImgOffSet(image, offset, offset) 64 | if image.mode == 'L': 65 | image = image.convert('RGB') 66 | lr_image = np.array(image) 67 | lr_image = imresize(lr_image, 1. / scale, 'bilinear') 68 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 69 | lr /= 255. 70 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 71 | lr = F.pad(lr, pad=[1, 1, 1, 1], mode='constant') 72 | with torch.no_grad(): 73 | SR = model(lr) 74 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 75 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 76 | # GPU tensor -> CPU tensor -> numpy 77 | SR = np.array(SR).astype(np.uint8) 78 | SR = Image.fromarray(SR) # hw -> wh 79 | # SR.save(out_by_name + dir + '.bmp') 80 | SR.save(out_by_deg + imgName) 81 | t.update(1) -------------------------------------------------------------------------------- /SRFlow/Split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. 
For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | from torch import nn as nn 19 | 20 | import thops 21 | from flow import Conv2dZeros, GaussianDiag 22 | 23 | 24 | class Split2d(nn.Module): 25 | def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): 26 | super().__init__() 27 | 28 | self.num_channels_consume = int(round(num_channels * consume_ratio)) 29 | self.num_channels_pass = num_channels - self.num_channels_consume 30 | 31 | self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels, 32 | out_channels=self.num_channels_consume * 2) 33 | self.logs_eps = logs_eps 34 | self.position = position 35 | self.opt = opt 36 | 37 | def split2d_prior(self, z, ft): 38 | if ft is not None: 39 | z = torch.cat([z, ft], dim=1) 40 | h = self.conv(z) 41 | return thops.split_feature(h, "cross") 42 | 43 | def exp_eps(self, logs): 44 | return torch.exp(logs) + self.logs_eps 45 | 46 | def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None): 47 | if not reverse: 48 | # self.input = input 49 | z1, z2 = self.split_ratio(input) 50 | mean, logs = self.split2d_prior(z1, ft) 51 | 52 | eps = (z2 - mean) / self.exp_eps(logs) 53 | 54 | logdet = logdet + self.get_logdet(logs, mean, z2) 55 | 56 | # print(logs.shape, mean.shape, z2.shape) 57 | # self.eps = eps 58 | # print('split, enc eps:', eps) 59 | return z1, logdet, eps 60 | else: 61 | z1 = input 62 | mean, logs = self.split2d_prior(z1, ft) 63 | 
64 | if eps is None: 65 | #print("WARNING: eps is None, generating eps untested functionality!") 66 | eps = GaussianDiag.sample_eps(mean.shape, eps_std) 67 | 68 | eps = eps.to(mean.device) 69 | z2 = mean + self.exp_eps(logs) * eps 70 | 71 | z = thops.cat_feature(z1, z2) 72 | logdet = logdet - self.get_logdet(logs, mean, z2) 73 | 74 | return z, logdet 75 | # return z, logdet, eps 76 | 77 | def get_logdet(self, logs, mean, z2): 78 | logdet_diff = GaussianDiag.logp(mean, logs, z2) 79 | # print("Split2D: logdet diff", logdet_diff.item()) 80 | return logdet_diff 81 | 82 | def split_ratio(self, input): 83 | z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] 84 | return z1, z2 -------------------------------------------------------------------------------- /ESRGAN/test_realesrgan_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import metrics 3 | import numpy as np 4 | import torch 5 | from torch.backends import cudnn 6 | from tqdm import tqdm 7 | import utils 8 | from realesrgan_model import G 9 | from PIL import Image 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 13 | if __name__ == '__main__': 14 | model_name = 'RealESRGAN' 15 | itplt = 'bilinear' 16 | # itplt = 'bicubic' 17 | dataset = f'degraded4_{itplt}_heavy' 18 | # dataset = f'degraded5_{itplt}_medium' 19 | # dataset = f'degraded6_{itplt}_slight' 20 | csv_file = f'./test_res/{model_name}_{dataset}.csv' 21 | csv_file = open(csv_file, 'w', newline='') 22 | writer = csv.writer(csv_file) 23 | writer.writerow(('name', 'psnr', 'niqe', 'ssim', 'lpips')) 24 | 25 | weight_file = '/data0/jli/project/BasicSR-iqa/experiments/pretrained_models/RealESRGAN_x4plus.pth' 26 | root_dir = f'/data0/jli/datasets/{dataset}/' 27 | out_root_dir = f'./test_res/{model_name}_{dataset}_metrics/' 28 | hr_dir = '/data0/jli/datasets/PIPAL/' 29 | 30 | lr_dirs = os.listdir(root_dir) 31 | for dir in lr_dirs: 32 | 
utils.mkdirs(out_root_dir + dir) 33 | 34 | scale = 4 35 | padding = scale 36 | 37 | if not os.path.exists(weight_file): 38 | print(f'Weight file not exist!\n{weight_file}\n') 39 | raise "Error" 40 | 41 | cudnn.benchmark = True 42 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 43 | 44 | model = G().to(device) 45 | checkpoint = torch.load(weight_file) 46 | model.load_state_dict(checkpoint['params_ema']) 47 | 48 | model.eval() 49 | 50 | for dir in lr_dirs: 51 | outputs_dir = out_root_dir + dir + '/' 52 | lr_dir = root_dir + dir + '/' 53 | lr_lists = os.listdir(lr_dir) 54 | Avg_psnr = utils.AverageMeter() 55 | Avg_niqe = utils.AverageMeter() 56 | Avg_ssim = utils.AverageMeter() 57 | Avg_lpips = utils.AverageMeter() 58 | with tqdm(total=len(lr_lists)) as t: 59 | t.set_description(f"Processing: {dir}") 60 | for imgName in lr_lists: 61 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 62 | GT = utils.loadIMG_crop(hr_dir + imgName, scale) 63 | img_mode = image.mode 64 | if image.mode == 'L': 65 | image = image.convert('RGB') 66 | GT = GT.convert('RGB') 67 | 68 | lr_image = np.array(image) 69 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 70 | lr /= 255. 
71 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 72 | 73 | with torch.no_grad(): 74 | SR = model(lr) 75 | 76 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 77 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 78 | 79 | SR = np.array(SR).astype(np.uint8) 80 | SR = Image.fromarray(SR) # hw -> wh 81 | SR.save(outputs_dir + imgName) 82 | 83 | metric = metrics.calc_metric(SR, GT, ['psnr', 'ssim', 'niqe', 'lpips']) 84 | Avg_lpips.update(metric['lpips'], 1) 85 | Avg_psnr.update(metric['psnr'], 1) 86 | Avg_niqe.update(metric['niqe'], 1) 87 | Avg_ssim.update(metric['ssim'], 1) 88 | t.update(1) 89 | 90 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg, Avg_lpips.avg)) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /HCFlow/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [gan(vanilla) | lsgan | wgan-gp | ragan] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23 | self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type 
[{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 68 | grad_outputs=grad_outputs, create_graph=True, 69 | retain_graph=True, only_inputs=True)[0] 70 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 71 | grad_interp_norm = grad_interp.norm(2, dim=1) 72 | 73 | loss = ((grad_interp_norm - 1)**2).mean() 74 | return loss 75 | 76 | class ReconstructionLoss(nn.Module): 77 | def __init__(self, losstype='l2', eps=1e-6): 78 | super(ReconstructionLoss, self).__init__() 79 | self.losstype = losstype 80 | self.eps = eps 81 | 82 | def forward(self, x, target): 83 | if self.losstype == 'l2': 84 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) 85 | elif self.losstype == 'l1': 86 | diff = x - target 87 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) 88 | else: 89 | print("reconstruction loss type error!") 90 | return 0 91 | 92 | 
# --------------------------------------------------------------------------------
# /FSRCNN/train.py
# --------------------------------------------------------------------------------
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import model_utils
import utils
import os
import torch
from torch.backends import cudnn
from model import FSRCNN
from torch import nn, optim
from FSRCNNdatasets import TrainDataset, ValDataset, ResValDataset
from torch.utils.data.dataloader import DataLoader


def train_model(config, from_pth=False):
    """Train FSRCNN with SGD and keep the best checkpoint by validation PSNR.

    Args:
        config: dict of hyper-parameters and paths. Keys read here: 'Gpu',
            'outputs_dir', 'lr', 'batch_size', 'num_epochs', 'logs_dir', 'csv_name',
            'seed', 'scale', 'in_size', 'out_size', 'd', 's', 'm', 'init',
            'train_file', 'val_file', 'residual', 'num_workers', 'weight_file',
            'auto_lr'.
        from_pth: when True, skip random initialization; weights and optimizer state
            are restored by model_utils.load_checkpoint.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
    outputs_dir = config['outputs_dir']
    lr = config['lr']
    batch_size = config['batch_size']
    num_epochs = config['num_epochs']
    logs_dir = config['logs_dir']
    csv_file = outputs_dir + config['csv_name']

    utils.mkdirs(outputs_dir)
    utils.mkdirs(logs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(config['seed'])
    # ---- Experiment-specific part ----
    model = FSRCNN(config['scale'], config['in_size'], config['out_size'],
                   num_channels=1, d=config['d'], s=config['s'], m=config['m'])
    if not from_pth:
        model.init_weights(method=config['init'])
    # `.to(device)` instead of `.cuda()`, which crashed on CPU-only hosts.
    criterion = nn.MSELoss().to(device)

    # The first two parts train with `lr`; the deconvolution layer with `lr * 0.1`.
    optimizer = optim.SGD([
        {'params': model.extract_layer.parameters()},
        {'params': model.mid_part.parameters()},
        {'params': model.deconv_layer.parameters(), 'lr': lr * 0.1}
    ], lr=lr, momentum=0.9)
    train_dataset = TrainDataset(config['train_file'])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=config['num_workers'],
                                  pin_memory=True)
    if config['residual']:
        val_dataset = ResValDataset(config['val_file'])
    else:
        val_dataset = ValDataset(config['val_file'])
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)
    # ---- END ----

    start_epoch, best_epoch, best_psnr, writer, csv_file = \
        model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file,
                                    from_pth, config['auto_lr'])

    if torch.cuda.device_count() > 1:
        print("Using GPUs.\n")
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    # Unwrapped network for validation and checkpointing. Previously `model` itself was
    # reassigned to `model.module` inside the loop, which silently dropped the
    # DataParallel wrapper (multi-GPU training) after the first epoch.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model

    writer_scalar = SummaryWriter(f"{logs_dir}/scalar")
    for epoch in range(start_epoch, num_epochs):
        lr = optimizer.param_groups[0]['lr']
        print(f'learning rate: {lr}\n')
        with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
            t.set_description(f'epoch:{epoch}/{num_epochs - 1}')
            epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t)

        epoch_psnr = model_utils.validate(net, val_dataloader, device, config['residual'])

        writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch)
        writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch)

        best_epoch, best_psnr = model_utils.save_checkpoint(net, optimizer, epoch, epoch_losses,
                                                            epoch_psnr, best_psnr, best_epoch, outputs_dir, writer)
        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    writer_scalar.close()


# --------------------------------------------------------------------------------
# /HCFlow/module_util.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def _initialize(net_l, scale, weight_init):
    """Shared initializer: apply `weight_init` to every Conv2d/Linear weight, multiply
    it by `scale`, zero the biases, and set BatchNorm2d weight=1 / bias=0."""
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                weight_init(m.weight)
                m.weight.data *= scale  # e.g. 0.1 for residual blocks
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def initialize_weights(net_l, scale=1):
    """Kaiming-normal (fan_in) init for a module or list of modules; weights scaled by `scale`."""
    _initialize(net_l, scale, lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'))


def initialize_weights_xavier(net_l, scale=1):
    """Xavier-normal init for a module or list of modules; weights scaled by `scale`."""
    _initialize(net_l, scale, init.xavier_normal_)


def make_layer(block, n_layers):
    """Stack `n_layers` fresh instances of `block()` into an nn.Sequential."""
    return nn.Sequential(*[block() for _ in range(n_layers)])


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # Small (x0.1) init keeps each residual branch close to identity at the start.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        out = F.relu(self.conv1(x), inplace=True)
        return x + self.conv2(out)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
    """Warp an image or feature map with optical flow
    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), normal value
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # Pixel-coordinate mesh grid, last dim ordered (x, y).
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # H, W, 2
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow
    # Normalize to [-1, 1] as expected by grid_sample.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)


# --------------------------------------------------------------------------------
# /SRCNN/test.py
# --------------------------------------------------------------------------------
import numpy as np
import torch
from torch.backends import cudnn
import utils
from models import SRCNN
from PIL import Image
import os
from imresize import imresize

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
if __name__ == '__main__':
    # Evaluate SRCNN on Set5: bicubic-upscale each downscaled image, refine the Y
    # channel with the network, report PSNR versus bicubic, and save both results.
    scale = 3
    config = {'weight_file': f'./weight_file/SRCNN_x{scale}_lr=1e-02_batch=128/',
              'img_dir': '../datasets/Set5/',
              'outputs_dir': f'./test_res/test_x{scale}_Set5/',
              'visual_filter': False
              }

    outputs_dir = config['outputs_dir']

    # Border width excluded from scoring and saved images.
    padding = scale
    # weight_file = config['weight_file'] + f'best.pth'
    # weight_file = config['weight_file'] + f'SRCNNx3_data=276864_lr=1e-2.pth'
    weight_file = config['weight_file'] + f'x{scale}/best.pth'
    img_dir = config['img_dir']
    outputs_dir = outputs_dir + f'x{scale}/'
    utils.mkdirs(outputs_dir)
    # Raising a plain string is a TypeError in Python 3; use real exception types.
    if not os.path.exists(weight_file):
        print(f'Weight file not exist!\n{weight_file}\n')
        raise FileNotFoundError(weight_file)
    if not os.path.exists(img_dir):
        raise FileNotFoundError("Image file not exist!\n")

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN(padding=True).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(weight_file, map_location=device)
    model.load_state_dict(checkpoint['model'])

    if config['visual_filter']:
        utils.viz_layer(model.conv1[0].weight.cpu(), 64)
    model.eval()
    imglist = os.listdir(img_dir)
    Avg_psnr = utils.AverageMeter()
    for imgName in imglist:
        image = utils.loadIMG_crop(img_dir + imgName, scale)
        if image.mode != 'L':  # grayscale images are kept as-is
            image = image.convert('RGB')
        hr_image = np.array(image)

        lr_image = imresize(hr_image, 1. / scale, 'bicubic')
        bic_image = imresize(lr_image, scale, 'bicubic')
        bic_pil = Image.fromarray(bic_image.astype(np.uint8)[padding: -padding, padding: -padding, ...])
        bic_pil.save(outputs_dir + imgName.replace('.bmp', f'_bicubic_x{scale}.png'))

        bic_y, bic_ycbcr = utils.preprocess(bic_image, device, image.mode)
        hr_y, _ = utils.preprocess(hr_image, device, image.mode)
        with torch.no_grad():
            SR = model(bic_y).clamp(0.0, 1.0)
        # Drop a `padding`-pixel border from everything before scoring / saving.
        hr_y = hr_y[..., padding: -padding, padding: -padding]
        SR = SR[..., padding: -padding, padding: -padding]
        bic_ycbcr = bic_ycbcr[..., padding: -padding, padding: -padding]
        bic_y = bic_y[..., padding: -padding, padding: -padding]

        psnr = utils.calc_psnr(hr_y, SR)
        psnr2 = utils.calc_psnr(hr_y, bic_y)
        Avg_psnr.update(psnr, 1)
        # Avg_psnr.update(psnr2, 1)
        print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item()))
        print(f'{imgName}, ' + 'PSNR_bicubic: {:.2f}'.format(psnr2.item()))
        # GPU tensor -> CPU numpy; drop the batch and channel dims to get the HxW Y plane.
        SR = SR.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
        if image.mode == 'L':
            output = np.clip(SR, 0.0, 255.0).astype(np.uint8)
        else:
            # Recombine the predicted Y with the bicubic Cb/Cr, then convert back to RGB.
            bic_ycbcr = bic_ycbcr.mul(255.0).cpu().numpy().squeeze(0).transpose([1, 2, 0])
            output = np.array([SR, bic_ycbcr[..., 1], bic_ycbcr[..., 2]]).transpose([1, 2, 0])  # chw -> hwc
            output = np.clip(utils.ycbcr2rgb(output), 0.0, 255.0).astype(np.uint8)
        output = Image.fromarray(output)
        output.save(outputs_dir + imgName.replace('.bmp', f'_SRCNNx{scale}.png'))
    print('Average_PSNR: {:.2f}'.format(Avg_psnr.avg))


# --------------------------------------------------------------------------------
# /SRResNet/SRResNetdatasets.py
# --------------------------------------------------------------------------------
import os
from PIL import Image
import numpy as np
import albumentations as A
from imresize import imresize
import h5py
from torch.utils.data import Dataset, DataLoader
import torch


class SRResNetTrainDataset(Dataset):
    """Patch pairs from an HDF5 file with datasets 'data' (LR) and 'label' (HR).

    LR patches are scaled to [0, 1]; HR patches to [-1, 1]. The file is reopened on
    every access rather than held open.
    """

    def __init__(self, h5_file):
        super(SRResNetTrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return torch.from_numpy(f['data'][idx] / 255.), torch.from_numpy(f['label'][idx] / 255. * 2 - 1)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['data'])


class SRResNetValDataset(Dataset):
    """Whole-image pairs stored as datasets named '0', '1', ... in groups 'data' and 'label'.

    Both LR and HR are scaled to [0, 1].
    """

    def __init__(self, h5_file):
        super(SRResNetValDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return torch.from_numpy(f['data'][str(idx)][:, :, :] / 255.), \
                torch.from_numpy(f['label'][str(idx)][:, :, :] / 255.)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['data'])


class DIV2KDataset(Dataset):
    """Random HR crops from full-size images; the LR input is made on the fly by bicubic downscaling.

    Returns (data, label): data is the LR crop in [0, 1], label the HR crop in [-1, 1],
    both channel-first float32 tensors. `crop_size` and `scale` default to the
    previously hard-coded 96 and 4.
    """

    def __init__(self, root_dir='../datasets/DIV2K_train_HR/', crop_size=96, scale=4):
        super(DIV2KDataset, self).__init__()
        self.img_names = os.listdir(root_dir)
        self.data = [root_dir + name for name in self.img_names]
        self.scale = scale
        # Built once instead of per item; each call still draws a new random position.
        self.random_crop = A.RandomCrop(width=crop_size, height=crop_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = np.array(Image.open(self.data[index]))
        image = self.random_crop(image=image)["image"]
        label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1]) / 127.5 - 1)
        data = imresize(image, 1 / self.scale, 'bicubic')
        data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255)
        return data, label


class DIV2KSubDataset(Dataset):
    """Pre-cropped HR/LR pairs matched by file name across `hr_dir` and `lr_dir`.

    Returns (data, label): LR in [0, 1], HR in [-1, 1], channel-first float32 tensors.
    """

    def __init__(self, hr_dir='../datasets/DIV2K_sub/HR/', lr_dir='../datasets/DIV2K_sub/LRx4/'):
        super(DIV2KSubDataset, self).__init__()
        self.img_names = os.listdir(hr_dir)
        self.hr_data = [hr_dir + name for name in self.img_names]
        self.lr_data = [lr_dir + name for name in self.img_names]

    def __len__(self):
        return len(self.hr_data)

    def __getitem__(self, index):
        hr_image = np.array(Image.open(self.hr_data[index]))
        lr_image = np.array(Image.open(self.lr_data[index]))
        label = torch.from_numpy(hr_image.astype(np.float32).transpose([2, 0, 1]) / 127.5 - 1)
        data = torch.from_numpy(lr_image.astype(np.float32).transpose([2, 0, 1]) / 255)
        return data, label


def test():
    dataset = DIV2KDataset(root_dir='../datasets/T91/')
    loader = DataLoader(dataset, batch_size=1, num_workers=0)

    for low_res, high_res in loader:
        print(low_res.shape)
        print(high_res.shape)


if __name__ == "__main__":
    test()


# --------------------------------------------------------------------------------
# /metrics.py
# --------------------------------------------------------------------------------
# PyTorch
import cv2
import torch
import utils
import numpy as np
import lpips
import niqe
calc_lpips = lpips.LPIPS(net='alex', model_path='./weight_file/alexnet-owt-7be5be79.pth')


def calc_psnr(img1, img2):
    """PSNR in dB for inputs scaled to [0, 1]; accepts torch tensors or numpy arrays."""
    if isinstance(img1, torch.Tensor):
        return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))
    else:
        return 10. * np.log10(1. / np.mean((img1 - img2) ** 2))


def calculate_ssim(img, img2):
    """Calculate SSIM (structural similarity) between two single-channel images.

    Uses an 11x11 Gaussian window (sigma 1.5) and discards a 5-pixel border of the
    SSIM map before averaging.

    Args:
        img (ndarray): 2-D image with range [0, 255].
        img2 (ndarray): 2-D image with range [0, 255], same shape as `img`.

    Returns:
        float: mean SSIM.
    """

    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()


def _to_lpips_input(img):
    """Convert an HxWxC (or HxW grayscale) array in [0, 255] to a 1xCxHxW tensor in [0, 1]."""
    if img.ndim == 2:
        # The LPIPS backbone takes 3 channels; replicate a gray plane into RGB.
        img = np.stack([img] * 3, axis=-1)
    return torch.from_numpy((img / 255.).transpose([2, 0, 1])).unsqueeze(0)


def calc_metric(sr_image, gt_image, metric, crop=0):
    """Compute image-quality metrics of a super-resolved image against its ground truth.

    Args:
        sr_image: PIL image to evaluate, [r,g,b] or gray, values in [0, 255].
        gt_image: PIL ground-truth image, same size as `sr_image`.
        metric: collection of metric names: 'psnr', 'ssim', 'niqe', 'lpips'.
        crop: pixels removed from every border before scoring (0 keeps the whole image).

    Returns:
        dict mapping each requested metric name to its value. PSNR, SSIM and NIQE use
        the Y channel; LPIPS uses the (cropped) RGB image.
    """

    res = dict()
    # convert RGBA to RGB
    if sr_image.mode != 'L':
        sr_image = sr_image.convert('RGB')
        gt_image = gt_image.convert('RGB')
    # Crop the border. The former `[crop: -max(crop, 1)]` also removed the last row
    # and column when crop == 0; slice(None) keeps the full image in that case.
    border = slice(crop, -crop) if crop > 0 else slice(None)
    sr = np.array(sr_image).astype(np.float32)[border, border, ...]
    gt = np.array(gt_image).astype(np.float32)[border, border, ...]
    # Luma (Y) channel in [0, 1].
    if sr_image.mode == 'L':
        sr_y = sr / 255.
        gt_y = gt / 255.
    else:
        sr_y = utils.rgb2ycbcr(sr).astype(np.float32)[..., 0] / 255.
        gt_y = utils.rgb2ycbcr(gt).astype(np.float32)[..., 0] / 255.
    if 'psnr' in metric:
        res['psnr'] = calc_psnr(sr_y, gt_y)
    if 'niqe' in metric:
        res['niqe'] = niqe.calculate_niqe(sr_y)
    if 'ssim' in metric:
        res['ssim'] = calculate_ssim(sr_y * 255, gt_y * 255)
    if 'lpips' in metric:
        res['lpips'] = calc_lpips(_to_lpips_input(sr), _to_lpips_input(gt), normalize=True).item()
    return res


def test_calc_metric():
    SR = utils.loadIMG_crop('./test_res/HCFlowSR/bicubicx4/sample=0/A0012.bmp', 4).convert('RGB')
    GT = utils.loadIMG_crop('./datasets/PIPAL/A0012.bmp', 4).convert('RGB')
    metric = calc_metric(SR, GT, ['psnr', 'ssim', 'niqe', 'lpips'])
    print(metric)


if __name__ == '__main__':
    test_calc_metric()


# --------------------------------------------------------------------------------
# /ESRGAN/train_ESRGAN_iter2.py
# --------------------------------------------------------------------------------
from tqdm import tqdm
import model_utils
import utils
import os
import torch
from torch.backends import cudnn
from torch.utils.tensorboard import SummaryWriter
from VGGLoss import VGGLoss
from model import G, D
from torch import nn, optim
from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset
from torch.utils.data.dataloader import DataLoader
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


def train_model(config, pre_train, from_pth=False):
    """Step-based ESRGAN training.

    Builds data loaders, generator/discriminator, losses and Adam optimizers, restores
    state via model_utils.load_GAN_checkpoint_iter, then calls
    model_utils.train_ESRGAN_iter once per pass over the training set until
    global_info['step'] reaches config['num_steps'] (presumably advanced inside
    train_ESRGAN_iter — confirm there).

    Args:
        config: dict of hyper-parameters and paths (e.g. 'Gpu', 'outputs_dir',
            'batch_size', 'gen_lr', 'disc_lr', 'num_steps', 'gen_k', 'disc_k', ...).
        pre_train: forwarded to model_utils.load_GAN_checkpoint_iter.
        from_pth: when True, skip discriminator initialization (state comes from a checkpoint).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu']
    outputs_dir = config['outputs_dir']
    batch_size = config['batch_size']
    utils.mkdirs(outputs_dir)
    csv_file = outputs_dir + config['csv_name']
    logs_dir = config['logs_dir']

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(config['seed'])

    # ---- Experiment-specific part ----
    print("===> Loading datasets")
    train_dataset = ESRGANTrainDataset()
    train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'],
                                  batch_size=batch_size, shuffle=True, pin_memory=True)
    val_dataset = ESRGANValDataset(config['val_file'])
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=1)

    print("===> Building model")
    gen = G()
    disc = D()
    if not from_pth:
        disc.init_weight()
    criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.L1Loss().to(device),
                 'vgg_loss': VGGLoss(device, nn.L1Loss())}
    gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr'])
    disc_opt = optim.Adam(disc.parameters(), lr=config['disc_lr'])
    # ---- END ----
    start_step, best_step, best_niqe, writer, csv_file = \
        model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt,
                                             disc, disc_opt, csv_file, from_pth, config['auto_lr'])

    if torch.cuda.device_count() > 1:
        print("Using GPUs.\n")
        gen = torch.nn.DataParallel(gen)
        disc = torch.nn.DataParallel(disc)
    gen = gen.to(device)
    disc = disc.to(device)

    tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")}
    dataloaders = {'train': train_dataloader, 'val': val_dataloader}
    num_steps = config['num_steps']
    iter_of_epoch = 100
    # 'disc_k' / 'gen_k' were cross-wired (disc_k <- config['gen_k'] and vice versa);
    # each key now takes its own config value.
    global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'],
                   'milestone': config['milestone'],
                   'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps,
                   'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe,
                   'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'],
                   'disc_k': config['disc_k'], 'gen_k': config['gen_k']}
    log_dict = {'D_losses': 0., 'G_losses': 0.,
                'D_losses_real': 0., 'D_losses_fake': 0.,
                'pixel_loss': 0., 'gan_loss': 0.,
                'percep_loss': 0.,
                'F_prob': 0., 'R_prob': 0.}
    while global_info['step'] < config['num_steps']:
        with tqdm(total=len(train_dataset)) as t:
            t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1))
            global_info['t'] = t
            model_utils.train_ESRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict)

    print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe']))


# --------------------------------------------------------------------------------
# /SRFlow/module_util.py
# --------------------------------------------------------------------------------
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.init as init 20 | import torch.nn.functional as F 21 | 22 | 23 | def initialize_weights(net_l, scale=1): 24 | if not isinstance(net_l, list): 25 | net_l = [net_l] 26 | for net in net_l: 27 | for m in net.modules(): 28 | if isinstance(m, nn.Conv2d): 29 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 30 | m.weight.data *= scale # for residual block 31 | if m.bias is not None: 32 | m.bias.data.zero_() 33 | elif isinstance(m, nn.Linear): 34 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 35 | m.weight.data *= scale 36 | if m.bias is not None: 37 | m.bias.data.zero_() 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant_(m.weight, 1) 40 | init.constant_(m.bias.data, 0.0) 41 | 42 | 43 | def make_layer(block, n_layers): 44 | layers = [] 45 | for _ in range(n_layers): 46 | layers.append(block()) 47 | return nn.Sequential(*layers) 48 | 49 | 50 | class ResidualBlock_noBN(nn.Module): 51 | '''Residual block w/o BN 52 | ---Conv-ReLU-Conv-+- 53 | |________________| 54 | ''' 55 | 56 | def __init__(self, nf=64): 57 | super(ResidualBlock_noBN, self).__init__() 58 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 59 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 60 | 61 | # initialization 62 | initialize_weights([self.conv1, self.conv2], 0.1) 63 | 64 | def forward(self, x): 65 | identity = x 66 | out = F.relu(self.conv1(x), inplace=True) 67 | out = self.conv2(out) 68 | return identity + out 69 | 70 | 71 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 72 | """Warp an image or feature map with optical flow 73 | Args: 74 | x (Tensor): size (N, C, H, W) 75 | flow (Tensor): size (N, H, W, 2), normal value 76 | interp_mode (str): 'nearest' or 'bilinear' 77 | padding_mode (str): 'zeros' or 'border' or 'reflection' 78 | 79 | Returns: 80 | Tensor: warped image 
or feature map 81 | """ 82 | assert x.size()[-2:] == flow.size()[1:3] 83 | B, C, H, W = x.size() 84 | # mesh grid 85 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 86 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 87 | grid.requires_grad = False 88 | grid = grid.type_as(x) 89 | vgrid = grid + flow 90 | # scale grid to [-1,1] 91 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 92 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 93 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 94 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 95 | return output 96 | -------------------------------------------------------------------------------- /SRResNet/gen_datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | import utils 5 | from imresize import imresize 6 | from tqdm import tqdm 7 | from PIL import Image 8 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 9 | 10 | 11 | def gen_traindata(config): 12 | scale = config["scale"] 13 | stride = config["stride"] 14 | size_input = config["size_input"] 15 | size_label = size_input * scale 16 | size_output = config['size_output'] 17 | method = config['method'] 18 | h5savepath = config["hrDir"] + f'_label={size_output}_train_SRResNetx{scale}.h5' 19 | hrDir = config["hrDir"] + '/' 20 | 21 | h5_file = h5py.File(h5savepath, 'w') 22 | imgList = os.listdir(hrDir) 23 | lr_subimgs = [] 24 | hr_subimgs = [] 25 | total = len(imgList) 26 | with tqdm(total=total) as t: 27 | for imgName in imgList: 28 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale) 29 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True) 30 | lr = imresize(hr, 1 / scale, method) 31 | input = lr.astype(np.float32) 32 | 33 | for r in range(0, lr.shape[0] - size_input + 1, stride): # height 34 | for c in range(0, lr.shape[1] - size_input + 1, stride): # width 35 | 
lr_subimgs.append(input[r: r + size_input, c: c + size_input].transpose([2, 0, 1])) # hwc -> chw 36 | label = hr[r * scale: r * scale + size_label, c * scale: c * scale + size_label].transpose([2, 0, 1]) 37 | hr_subimgs.append(label) 38 | t.update(1) 39 | 40 | lr_subimgs = np.array(lr_subimgs).astype(np.float32) 41 | h5_file.create_dataset('data', data=lr_subimgs) 42 | lr_subimgs = [] 43 | # HR dataset is too large to convert to ndarray, so convert it by several parts. 44 | num = len(hr_subimgs) 45 | seg = num // scale 46 | hr_imgs = np.array(hr_subimgs[0: seg]).astype(np.float32) 47 | dl = h5_file.create_dataset('label', data=hr_imgs, maxshape=[num, 3, 96, 96]) 48 | for i in range(1, scale): 49 | hr_imgs = [] 50 | hr_imgs = np.array(hr_subimgs[i*seg: (i+1)*seg]).astype(np.float32) 51 | dl.resize(((i+1)*seg, 3, 96, 96)) 52 | dl[i*seg:] = hr_imgs 53 | hr_imgs = [] 54 | hr_imgs = np.array(hr_subimgs[scale*seg: num]).astype(np.float32) 55 | dl.resize((num, 3, 96, 96)) 56 | dl[scale*seg:] = hr_imgs 57 | hr_imgs = [] 58 | h5_file.close() 59 | 60 | 61 | def gen_valdata(config): 62 | scale = config["scale"] 63 | size_input = config["size_input"] 64 | size_label = size_input * scale 65 | size_output = config['size_output'] 66 | method = config['method'] 67 | h5savepath = config["hrDir"] + f'_label={size_output}_val_SRResNetx{scale}.h5' 68 | hrDir = config["hrDir"] + '/' 69 | h5_file = h5py.File(h5savepath, 'w') 70 | lr_group = h5_file.create_group('data') 71 | hr_group = h5_file.create_group('label') 72 | imgList = os.listdir(hrDir) 73 | for i, imgName in enumerate(imgList): 74 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale) 75 | hr = utils.img2ycbcr(hrIMG, gray2rgb=True).astype(np.float32) 76 | lr = imresize(np.array(hrIMG).astype(np.float32), 1 / scale, method) 77 | 78 | data = lr.astype(np.float32).transpose([2, 0, 1]) 79 | label = hr.transpose([2, 0, 1]) 80 | 81 | lr_group.create_dataset(str(i), data=data) 82 | hr_group.create_dataset(str(i), data=label) 83 | 
h5_file.close() 84 | 85 | 86 | if __name__ == '__main__': 87 | # config = {'hrDir': './test/flower', 'scale': 3, "stride": 14, "size_input": 33, "size_label": 21} 88 | config = {'hrDir': '../datasets/291_aug', 'scale': 4, 'stride': 12, "size_input": 24, "size_output": 96, 'method': 'bicubic'} 89 | #gen_traindata(config) # if use DIV2K, don't need to run this line. 90 | config['hrDir'] = '../datasets/Set5' 91 | gen_valdata(config) 92 | -------------------------------------------------------------------------------- /SRResNet/train_SRGAN_WGAN.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | import model_utils 4 | import utils 5 | import os 6 | import torch 7 | from torch.backends import cudnn 8 | from torch.utils.tensorboard import SummaryWriter 9 | from VGGLoss import VGGLoss 10 | from model import G, D 11 | from torch import nn, optim 12 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset, DIV2KSubDataset 13 | from torch.utils.data.dataloader import DataLoader 14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 15 | 16 | 17 | def train_model(config, pre_train, from_pth=False): 18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 19 | outputs_dir = config['outputs_dir'] 20 | batch_size = config['batch_size'] 21 | utils.mkdirs(outputs_dir) 22 | csv_file = outputs_dir + config['csv_name'] 23 | logs_dir = config['logs_dir'] 24 | cudnn.benchmark = True 25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 26 | torch.manual_seed(config['seed']) 27 | 28 | # ----需要修改部分------ 29 | print("===> Loading datasets") 30 | train_dataset = DIV2KSubDataset() 31 | # train_dataset = DIV2KDataset(config['train_file']) 32 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 33 | batch_size=batch_size, shuffle=True, pin_memory=True) 34 | val_dataset = SRResNetValDataset(config['val_file']) 35 | val_dataloader = 
DataLoader(dataset=val_dataset, batch_size=1) 36 | 37 | print("===> Building model") 38 | gen = G() 39 | disc = D() 40 | if not from_pth: 41 | disc.init_weight() 42 | criterion = {'pixel_loss': nn.MSELoss().to(device), 'vgg_loss': VGGLoss(device)} 43 | gen_opt = optim.RMSprop(gen.parameters(), lr=config['gen_lr']) 44 | disc_opt = optim.RMSprop(disc.parameters(), lr=config['disc_lr']) 45 | # ----END------ 46 | start_step, best_step, best_niqe, writer, csv_file = \ 47 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt, disc, disc_opt, 48 | csv_file, from_pth, config['auto_lr']) 49 | 50 | if torch.cuda.device_count() > 1: 51 | print("Using GPUs.\n") 52 | gen = torch.nn.DataParallel(gen) 53 | disc = torch.nn.DataParallel(disc) 54 | gen = gen.to(device) 55 | disc = disc.to(device) 56 | 57 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")} 58 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 59 | num_steps = config['num_steps'] 60 | iter_of_epoch = 1000 61 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 62 | 'milestone': config['milestone'], 63 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps, 64 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe, 65 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'], 66 | 'disc_k': config['gen_k'], 'gen_k': config['disc_k']} 67 | log_dict = {'D_losses': 0., 'G_losses': 0., 68 | 'D_losses_real': 0., 'D_losses_fake': 0., 69 | 'pixel_loss': 0., 'gan_loss': 0., 70 | 'percep_loss': 0., 71 | 'F_prob': 0., 'R_prob': 0.} 72 | while global_info['step'] < config['num_steps']: 73 | with tqdm(total=len(train_dataset)) as t: 74 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1)) 75 | global_info['t'] = t 76 | 
model_utils.train_SRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict, use_WGAN=True) 77 | 78 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_niqe'])) 79 | -------------------------------------------------------------------------------- /ESRGAN/train_ESRGAN_iter.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import model_utils 3 | import utils 4 | import os 5 | import torch 6 | from torch.backends import cudnn 7 | from torch.utils.tensorboard import SummaryWriter 8 | from VGGLoss import VGGLoss, PerceptualLoss 9 | from model import G, D 10 | from torch import nn, optim 11 | from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset 12 | from torch.utils.data.dataloader import DataLoader 13 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 14 | 15 | 16 | def train_model(config, pre_train, from_pth=False): 17 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 18 | outputs_dir = config['outputs_dir'] 19 | batch_size = config['batch_size'] 20 | utils.mkdirs(outputs_dir) 21 | csv_file = outputs_dir + config['csv_name'] 22 | logs_dir = config['logs_dir'] 23 | 24 | cudnn.benchmark = True 25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 26 | torch.manual_seed(config['seed']) 27 | 28 | # ----需要修改部分------ 29 | print("===> Loading datasets") 30 | train_dataset = ESRGANTrainDataset() 31 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 32 | batch_size=batch_size, shuffle=True, pin_memory=True) 33 | val_dataset = ESRGANValDataset(config['val_file']) 34 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 35 | 36 | print("===> Building model") 37 | gen = G() 38 | disc = D() 39 | if not from_pth: 40 | disc.init_weight() 41 | criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.L1Loss().to(device), 42 | 'vgg_loss': 
PerceptualLoss({'conv5_4': 1})} 43 | # 'vgg_loss': VGGLoss(device, nn.L1Loss())} 44 | gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr']) 45 | disc_opt = optim.Adam(disc.parameters(), lr=config['disc_lr']) 46 | # ----END------ 47 | start_step, best_step, best_niqe, writer, csv_file = \ 48 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt, 49 | disc, disc_opt, csv_file, from_pth, config['auto_lr']) 50 | 51 | if torch.cuda.device_count() > 1: 52 | print("Using GPUs.\n") 53 | gen = torch.nn.DataParallel(gen) 54 | disc = torch.nn.DataParallel(disc) 55 | gen = gen.to(device) 56 | disc = disc.to(device) 57 | 58 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")} 59 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 60 | num_steps = config['num_steps'] 61 | iter_of_epoch = 100 62 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 'milestone': config['milestone'], 63 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps, 64 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe, 65 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'], 66 | 'disc_k': config['gen_k'], 'gen_k': config['disc_k']} 67 | log_dict = {'D_losses': 0., 'G_losses': 0., 68 | 'D_losses_real': 0., 'D_losses_fake': 0., 69 | 'pixel_loss': 0., 'gan_loss': 0., 70 | 'percep_loss': 0., 71 | 'F_prob': 0., 'R_prob': 0.} 72 | while global_info['step'] < config['num_steps']: 73 | with tqdm(total=len(train_dataset)) as t: 74 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1)) 75 | global_info['t'] = t 76 | model_utils.train_ESRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict) 77 | 78 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], 
global_info['best_niqe'])) 79 | -------------------------------------------------------------------------------- /SRResNet/train_SRGAN_iter.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torchvision 4 | from tqdm import tqdm 5 | import model_utils 6 | import utils 7 | import os 8 | import torch 9 | from torch.backends import cudnn 10 | from torch.utils.tensorboard import SummaryWriter 11 | from VGGLoss import VGGLoss 12 | from model import G, D 13 | from torch import nn, optim 14 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset, DIV2KSubDataset 15 | from torch.utils.data.dataloader import DataLoader 16 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 17 | 18 | 19 | def train_model(config, pre_train, from_pth=False): 20 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 21 | outputs_dir = config['outputs_dir'] 22 | batch_size = config['batch_size'] 23 | utils.mkdirs(outputs_dir) 24 | csv_file = outputs_dir + config['csv_name'] 25 | logs_dir = config['logs_dir'] 26 | 27 | cudnn.benchmark = True 28 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 29 | torch.manual_seed(config['seed']) 30 | 31 | # ----需要修改部分------ 32 | print("===> Loading datasets") 33 | train_dataset = DIV2KSubDataset() 34 | # train_dataset = DIV2KDataset(config['train_file']) 35 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 36 | batch_size=batch_size, shuffle=True, pin_memory=True) 37 | val_dataset = SRResNetValDataset(config['val_file']) 38 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 39 | 40 | print("===> Building model") 41 | gen = G() 42 | disc = D() 43 | if not from_pth: 44 | disc.init_weight() 45 | criterion = {'bce': nn.BCEWithLogitsLoss().to(device), 'pixel_loss': nn.MSELoss().to(device), 'vgg_loss': VGGLoss(device)} 46 | gen_opt = optim.Adam(gen.parameters(), lr=config['gen_lr']) 47 | disc_opt = 
optim.Adam(disc.parameters(), lr=config['disc_lr']) 48 | # ----END------ 49 | start_step, best_step, best_niqe, writer, csv_file = \ 50 | model_utils.load_GAN_checkpoint_iter(pre_train, config['weight_file'], gen, gen_opt, 51 | disc, disc_opt, csv_file, from_pth, config['auto_lr']) 52 | 53 | if torch.cuda.device_count() > 1: 54 | print("Using GPUs.\n") 55 | gen = torch.nn.DataParallel(gen) 56 | disc = torch.nn.DataParallel(disc) 57 | gen = gen.to(device) 58 | disc = disc.to(device) 59 | 60 | tb_writer = {'scalar': SummaryWriter(f"{logs_dir}/scalar"), 'test': SummaryWriter(f"{logs_dir}/test")} 61 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 62 | num_steps = config['num_steps'] 63 | iter_of_epoch = 1000 64 | global_info = {'device': device, 'step': start_step, 't': None, 'auto_lr': config['auto_lr'], 65 | 'milestone': config['milestone'], 66 | 'tb_writer': tb_writer, 'outputs_dir': outputs_dir, 'csv_writer': writer, 'num_steps': num_steps, 67 | 'best_step': best_step, 'batch_no': -1, 'iter_of_epoch': iter_of_epoch, 'best_niqe': best_niqe, 68 | 'pixel_weight': config['pixel_weight'], 'adversarial_weight': config['adversarial_weight'], 69 | 'disc_k': config['gen_k'], 'gen_k': config['disc_k']} 70 | log_dict = {'D_losses': 0., 'G_losses': 0., 71 | 'D_losses_real': 0., 'D_losses_fake': 0., 72 | 'pixel_loss': 0., 'gan_loss': 0., 73 | 'percep_loss': 0., 74 | 'F_prob': 0., 'R_prob': 0.} 75 | while global_info['step'] < config['num_steps']: 76 | with tqdm(total=len(train_dataset)) as t: 77 | t.set_description('step:{}/{}'.format(global_info['step'], num_steps - 1)) 78 | global_info['t'] = t 79 | model_utils.train_SRGAN_iter(gen, disc, dataloaders, gen_opt, disc_opt, criterion, global_info, log_dict) 80 | 81 | print('best step: {}, niqe: {:.2f}'.format(global_info['best_step'], global_info['best_psnr'])) 82 | -------------------------------------------------------------------------------- /ESRGAN/ESRGANdatasets.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import lmdb 4 | from PIL import Image 5 | import numpy as np 6 | import albumentations as A 7 | from imresize import imresize 8 | import h5py 9 | from torch.utils.data import Dataset, DataLoader 10 | import torch 11 | import matplotlib.pyplot as plt 12 | 13 | class ESRGANTrainDataset(Dataset): 14 | def __init__(self, root_dirs=['../datasets/DIV2K_train_HR/', 15 | '../datasets/Flickr2K/Flickr2K_HR/', 16 | '../datasets/OST/' 17 | ]): 18 | super(ESRGANTrainDataset, self).__init__() 19 | self.data = [] 20 | 21 | for root_dir in root_dirs: 22 | self.img_names = os.listdir(root_dir) 23 | for name in self.img_names: 24 | self.data.append(root_dir + name) 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | def __getitem__(self, index): 30 | img_path = self.data[index] 31 | 32 | image = np.array(Image.open(img_path)) 33 | transpose = A.Compose([ 34 | A.RandomCrop(width=128, height=128), 35 | A.HorizontalFlip(p=0.5), 36 | A.RandomRotate90(p=0.5), 37 | ]) 38 | image = transpose(image=image)["image"] 39 | label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1]) / 255.) 40 | data = imresize(image, 1 / 4, 'bicubic') 41 | data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255.) 42 | return data, label 43 | 44 | 45 | class ESRGANValDataset(Dataset): 46 | def __init__(self, h5_file): 47 | super(ESRGANValDataset, self).__init__() 48 | self.h5_file = h5_file 49 | 50 | def __getitem__(self, idx): 51 | with h5py.File(self.h5_file, 'r') as f: 52 | return torch.from_numpy(f['data'][str(idx)][:, :, :] / 255.), \ 53 | torch.from_numpy(f['label'][str(idx)][:, :, :] / 255.) 
54 | 55 | def __len__(self): 56 | with h5py.File(self.h5_file, 'r') as f: 57 | return len(f['data']) 58 | 59 | 60 | class lmdbDataset(Dataset): 61 | def __init__(self): 62 | super(lmdbDataset, self).__init__() 63 | env = lmdb.open('../datasets/Set5.lmdb', max_dbs=2, readonly=True) 64 | self.data = env.open_db("train_data".encode('ascii')) 65 | self.shape = env.open_db("train_shape".encode('ascii')) 66 | self.txn = env.begin() 67 | self._length = int(self.txn.stat(db=self.data)["entries"] / 2) 68 | 69 | def __getitem__(self, idx): 70 | idx = str(idx) 71 | image = self.txn.get(idx.encode('ascii'), db=self.data) 72 | image = np.frombuffer(image, 'uint8') 73 | buf_meta = self.txn.get((idx+'.meta').encode('ascii'), db=self.shape) 74 | buf_meta = buf_meta.decode('ascii') 75 | H, W, C = [int(s) for s in buf_meta.split(',')] 76 | image = image.reshape(H, W, C) 77 | 78 | transpose = A.Compose([ 79 | A.RandomCrop(width=128, height=128), 80 | A.HorizontalFlip(p=0.5), 81 | A.RandomRotate90(p=0.5), 82 | ]) 83 | image = transpose(image=image)["image"] 84 | label = torch.from_numpy(image.astype(np.float32).transpose([2, 0, 1]) / 255.) 85 | data = imresize(image, 1 / 4, 'bicubic') 86 | data = torch.from_numpy(data.astype(np.float32).transpose([2, 0, 1]) / 255.) 
87 | return data, label 88 | 89 | def __len__(self): 90 | return self._length 91 | 92 | 93 | def test(): 94 | dataset = lmdbDataset() 95 | loader = DataLoader(dataset, batch_size=1, num_workers=0) 96 | 97 | for low_res , high_res in loader: 98 | # plt.imshow(low_res.cpu().numpy().squeeze(0).transpose([1,2,0])) 99 | # plt.show() 100 | print(low_res.shape) 101 | print(high_res.shape) 102 | 103 | 104 | if __name__ == "__main__": 105 | test() 106 | -------------------------------------------------------------------------------- /SRFlow/test_srflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import options 3 | import torch 4 | from torch.backends import cudnn 5 | import utils 6 | from SRFlow_model import SRFlowModel 7 | from PIL import Image 8 | import os 9 | 10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 11 | if __name__ == '__main__': 12 | model_name = 'SRFlow' 13 | weight_file = '../weight_file/SRFlow_DF2K_4X.pth' 14 | root_dir = '../datasets/degraded9/' 15 | out_root_dir = f'../test_res/{model_name}_degraded4/' 16 | hr_dir = '../datasets/PIPAL/' 17 | 18 | lr_dirs = os.listdir(root_dir) 19 | # 根据图片名称分类输出 20 | out_dirs = os.listdir(hr_dir) 21 | for dir in out_dirs: 22 | dir = dir.split('.')[0] + '/' 23 | utils.mkdirs(out_root_dir + dir) 24 | 25 | scale = 4 26 | padding = scale 27 | 28 | if not os.path.exists(weight_file): 29 | print(f'Weight file not exist!\n{weight_file}\n') 30 | raise "Error" 31 | 32 | cudnn.benchmark = True 33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 34 | opt = options.parse('./SRFlow_DF2K_4X.yml', is_train=False) 35 | opt['gpu_ids'] = None 36 | opt = options.dict_to_nonedict(opt) 37 | heat = opt['heat'] 38 | model = SRFlowModel(opt) 39 | checkpoint = torch.load(weight_file) 40 | model.netG.load_state_dict(checkpoint) 41 | offset = 0 42 | for dir in lr_dirs: 43 | # outputs_dir = out_root_dir + dir + '/' 44 | lr_dir = root_dir + dir + '/' 45 | 
lr_lists = os.listdir(lr_dir) 46 | Avg_psnr = utils.AverageMeter() 47 | Avg_niqe = utils.AverageMeter() 48 | for imgName in lr_lists: 49 | outputs_dir = out_root_dir + imgName.split('.')[0] + '/' 50 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 51 | image = utils.ImgOffSet(image, offset, offset) 52 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale) 53 | hr_image = utils.ImgOffSet(hr_image, offset*scale, offset*scale) 54 | 55 | img_mode = image.mode 56 | if img_mode == 'L': 57 | gray_img = np.array(image) 58 | image = image.convert('RGB') 59 | lr_image = np.array(image) 60 | hr_image = np.array(hr_image) 61 | 62 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 63 | lr /= 255. 64 | lr = torch.from_numpy(lr).unsqueeze(0) 65 | 66 | # hr_image = hr_image[padding: -padding, padding: -padding, ...] 67 | 68 | SR = model.get_sr(lr, heat) 69 | # SR = SR[..., padding: -padding, padding: -padding] 70 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 71 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 72 | if img_mode != 'L': 73 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255. 74 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255. 75 | else: 76 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...] 77 | hr_y = gray_img / 255. 78 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L') 79 | SR_y = np.array(SR).astype(np.float32) / 255. 
80 | psnr = utils.calc_psnr(hr_y, SR_y) 81 | # NIQE = niqe.calculate_niqe(SR_y) 82 | Avg_psnr.update(psnr, 1) 83 | # Avg_niqe.update(NIQE, 1) 84 | print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item())) 85 | # print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}'.format(psnr.item(), NIQE)) 86 | # GPU tensor -> CPU tensor -> numpy 87 | output = np.array(SR).astype(np.uint8) 88 | output = Image.fromarray(output) # hw -> wh 89 | # output.save(outputs_dir + imgName) 90 | output.save(outputs_dir + dir + '.bmp') 91 | # output.save(outputs_dir + imgName.replace('.', f'_{model_name}_x{scale}.')) 92 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg)) 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /ESRGAN/train.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.tensorboard import SummaryWriter 3 | from tqdm import tqdm 4 | import model_utils 5 | import utils 6 | import os 7 | import torch 8 | from torch.backends import cudnn 9 | from model import G, D 10 | from torch import nn, optim 11 | from ESRGANdatasets import ESRGANValDataset, ESRGANTrainDataset 12 | from torch.utils.data.dataloader import DataLoader 13 | import numpy as np 14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 15 | 16 | 17 | def train_model(config, from_pth=False): 18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 19 | 20 | outputs_dir = config['outputs_dir'] 21 | batch_size = config['batch_size'] 22 | num_epochs = config['num_epochs'] 23 | utils.mkdirs(outputs_dir) 24 | csv_file = outputs_dir + config['csv_name'] 25 | logs_dir = config['logs_dir'] 26 | cudnn.benchmark = True 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | torch.manual_seed(config['seed']) 29 | 30 | # ----需要修改部分------ 31 | print("===> Loading datasets") 32 | train_dataset = ESRGANTrainDataset() 33 | train_dataloader = 
DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 34 | batch_size=batch_size, shuffle=True, pin_memory=True) 35 | val_dataset = ESRGANValDataset(config['val_file']) 36 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 37 | 38 | print("===> Building model") 39 | model = G() 40 | if not from_pth: 41 | model.init_weight() 42 | criterion = nn.L1Loss().cuda() 43 | optimizer = optim.Adam(model.parameters(), lr=config['lr']) 44 | # ----END------ 45 | start_epoch, best_epoch, best_psnr, writer, csv_file = \ 46 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, 47 | from_pth, auto_lr=config['auto_lr']) 48 | 49 | if torch.cuda.device_count() > 1: 50 | print("Using GPUs.\n") 51 | model = torch.nn.DataParallel(model) 52 | model = model.to(device) 53 | 54 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar") 55 | writer_test = SummaryWriter(f"{logs_dir}/test") 56 | 57 | for epoch in range(start_epoch, num_epochs): 58 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 59 | print(f'learning rate: {lr}\n') 60 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t: 61 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}') 62 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t) 63 | 64 | if isinstance(model, torch.nn.DataParallel): 65 | model = model.module 66 | 67 | model.eval() 68 | epoch_psnr = utils.AverageMeter() 69 | for idx, data in enumerate(val_dataloader): 70 | inputs, labels = data 71 | inputs = inputs.to(device) 72 | with torch.no_grad(): 73 | preds = model(inputs) 74 | img_grid_fake = torchvision.utils.make_grid(preds, normalize=True) 75 | writer_test.add_image(f"Test Fake{idx}", img_grid_fake, global_step=epoch) 76 | preds = preds.mul(255.0).cpu().numpy().squeeze(0) 77 | preds = preds.transpose([1, 2, 0]) #chw->hwc 78 | preds = np.clip(preds, 0.0, 255.0) 79 | preds = utils.rgb2ycbcr(preds).astype(np.float32)[..., 0]/255. 
80 | epoch_psnr.update(utils.calc_psnr(preds, labels.numpy()[0, 0, ...]).item(), len(inputs)) 81 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) 82 | if config['auto_lr'] and (epoch+1) % 232 == 0: 83 | model_utils.update_lr(optimizer, 0.5) 84 | 85 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch) 86 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch) 87 | 88 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses, 89 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer) 90 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) -------------------------------------------------------------------------------- /FSRCNN/train_deconv.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from tqdm import tqdm 3 | import model_utils 4 | import utils 5 | import os 6 | import torch 7 | from torch.backends import cudnn 8 | from torch import nn, optim 9 | from model import N2_10_4, CLoss, HuberLoss, FSRCNN 10 | from FSRCNNdatasets import TrainDataset, ValDataset, ResValDataset 11 | from torch.utils.data.dataloader import DataLoader 12 | 13 | 14 | def train_model(config, from_pth=False, pre_train=None): 15 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 16 | outputs_dir = config['outputs_dir'] 17 | batch_size = config['batch_size'] 18 | lr = config['lr'] 19 | num_epochs = config['num_epochs'] 20 | logs_dir = config['logs_dir'] 21 | utils.mkdirs(outputs_dir) 22 | utils.mkdirs(logs_dir) 23 | 24 | csv_file = outputs_dir + config['csv_name'] 25 | 26 | cudnn.benchmark = True 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | torch.manual_seed(config['seed']) 29 | # ----需要修改部分------ 30 | model = N2_10_4(config['scale'], config['in_size'], config['out_size'], 31 | num_channels=1, d=config['d'], m=config['m']) 32 | # model = FSRCNN(scale, in_size, out_size, num_channels=1, d=config['d'], 
m=config['m']) 33 | if config['Loss'] == 'CLoss': 34 | criterion = CLoss(delta=config['delta']).cuda() 35 | elif config['Loss'] == 'Huber': 36 | criterion = HuberLoss(delta=config['delta']).cuda() 37 | else: 38 | criterion = nn.MSELoss().cuda() 39 | 40 | optimizer = optim.SGD([ 41 | {'params': model.deconv_layer.parameters(), 'lr': lr}, 42 | ], lr=lr, momentum=0.9) # 前两层学习率lr, 最后一层学习率lr*0.1 43 | train_dataset = TrainDataset(config['train_file']) 44 | train_dataloader = DataLoader(dataset=train_dataset, 45 | batch_size=batch_size, 46 | shuffle=True, 47 | num_workers=config['num_workers'], 48 | pin_memory=True) 49 | if config['residual']: 50 | val_dataset = ResValDataset(config['val_file']) 51 | else: 52 | val_dataset = ValDataset(config['val_file']) 53 | val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 54 | # ----END------ 55 | start_epoch, best_epoch, best_psnr, writer, csv_file = \ 56 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, from_pth, config['auto_lr']) 57 | 58 | if not from_pth: 59 | if not os.path.exists(pre_train): 60 | print(f'Weight file not exist!\n{pre_train}\n') 61 | raise "Error" 62 | checkpoint = torch.load(pre_train) 63 | model.load_state_dict(checkpoint['model']) 64 | model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001) 65 | model.deconv_layer.bias.data.zero_() 66 | 67 | if torch.cuda.device_count() > 1: 68 | print("Using GPUs.\n") 69 | model = torch.nn.DataParallel(model) 70 | model = model.to(device) 71 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar") 72 | 73 | for epoch in range(start_epoch, num_epochs): 74 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 75 | print(f'learning rate: {lr}\n') 76 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t: 77 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}') 78 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t) 79 | 80 | if isinstance(model, torch.nn.DataParallel): 
81 | model = model.module 82 | 83 | epoch_psnr = model_utils.validate(model, val_dataloader, device, config['residual']) 84 | 85 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch) 86 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch) 87 | 88 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses, 89 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer) 90 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) -------------------------------------------------------------------------------- /SRResNet/train.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.tensorboard import SummaryWriter 3 | from tqdm import tqdm 4 | import model_utils 5 | import utils 6 | import os 7 | import torch 8 | from torch.backends import cudnn 9 | from model import G, D 10 | from torch import nn, optim 11 | from SRResNetdatasets import SRResNetValDataset, SRResNetTrainDataset, DIV2KDataset 12 | from torch.utils.data.dataloader import DataLoader 13 | import numpy as np 14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 15 | 16 | 17 | def train_model(config, from_pth=False): 18 | os.environ['CUDA_VISIBLE_DEVICES'] = config['Gpu'] 19 | 20 | outputs_dir = config['outputs_dir'] 21 | batch_size = config['batch_size'] 22 | num_epochs = config['num_epochs'] 23 | utils.mkdirs(outputs_dir) 24 | csv_file = outputs_dir + config['csv_name'] 25 | logs_dir = config['logs_dir'] 26 | cudnn.benchmark = True 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | torch.manual_seed(config['seed']) 29 | 30 | # ----需要修改部分------ 31 | print("===> Loading datasets") 32 | train_dataset = DIV2KDataset(config['train_file']) 33 | train_dataloader = DataLoader(dataset=train_dataset, num_workers=config['num_workers'], 34 | batch_size=batch_size, shuffle=True, pin_memory=True) 35 | val_dataset = SRResNetValDataset(config['val_file']) 36 | 
val_dataloader = DataLoader(dataset=val_dataset, batch_size=1) 37 | 38 | print("===> Building model") 39 | model = G() 40 | if not from_pth: 41 | model.init_weight() 42 | criterion = nn.MSELoss().cuda() 43 | optimizer = optim.Adam(model.parameters(), lr=config['lr']) 44 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=50) 45 | # ----END------ 46 | start_epoch, best_epoch, best_psnr, writer, csv_file = \ 47 | model_utils.load_checkpoint(config['weight_file'], model, optimizer, csv_file, 48 | from_pth, auto_lr=config['auto_lr']) 49 | 50 | if torch.cuda.device_count() > 1: 51 | print("Using GPUs.\n") 52 | model = torch.nn.DataParallel(model) 53 | model = model.to(device) 54 | 55 | writer_scalar = SummaryWriter(f"{logs_dir}/scalar") 56 | writer_test = SummaryWriter(f"{logs_dir}/test") 57 | 58 | for epoch in range(start_epoch, num_epochs): 59 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 60 | print(f'learning rate: {lr}\n') 61 | with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t: 62 | t.set_description(f'epoch:{epoch}/{num_epochs - 1}') 63 | epoch_losses = model_utils.train(model, train_dataloader, optimizer, criterion, device, t) 64 | 65 | if isinstance(model, torch.nn.DataParallel): 66 | model = model.module 67 | 68 | model.eval() 69 | epoch_psnr = utils.AverageMeter() 70 | for idx, data in enumerate(val_dataloader): 71 | inputs, labels = data 72 | inputs = inputs.to(device) 73 | with torch.no_grad(): 74 | preds = model(inputs) * 0.5 + 0.5 75 | img_grid_fake = torchvision.utils.make_grid(preds, normalize=True) 76 | writer_test.add_image(f"Test Fake{idx}", img_grid_fake, global_step=epoch) 77 | preds = preds.mul(255.0).cpu().numpy().squeeze(0) 78 | preds = preds.transpose([1, 2, 0]) #chw->hwc 79 | preds = np.clip(preds, 0.0, 255.0) 80 | preds = utils.rgb2ycbcr(preds).astype(np.float32)[..., 0]/255. 
81 | epoch_psnr.update(utils.calc_psnr(preds, labels.numpy()[0, 0, ...]).item(), len(inputs)) 82 | print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) 83 | if config['auto_lr']: 84 | scheduler.step(epoch_psnr.avg) 85 | 86 | writer_scalar.add_scalar('PSNR', epoch_psnr.avg, epoch) 87 | writer_scalar.add_scalar('Loss', epoch_losses.avg, epoch) 88 | 89 | best_epoch, best_psnr = model_utils.save_checkpoint(model, optimizer, epoch, epoch_losses, 90 | epoch_psnr, best_psnr, best_epoch, outputs_dir, writer) 91 | print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) -------------------------------------------------------------------------------- /BSRGAN/bsrgan_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | 8 | def initialize_weights(net_l, scale=1): 9 | if not isinstance(net_l, list): 10 | net_l = [net_l] 11 | for net in net_l: 12 | for m in net.modules(): 13 | if isinstance(m, nn.Conv2d): 14 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 15 | m.weight.data *= scale # for residual block 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | m.weight.data *= scale 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.BatchNorm2d): 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def make_layer(block, n_layers): 29 | layers = [] 30 | for _ in range(n_layers): 31 | layers.append(block()) 32 | return nn.Sequential(*layers) 33 | 34 | 35 | class ResidualDenseBlock_5C(nn.Module): 36 | def __init__(self, nf=64, gc=32, bias=True): 37 | super(ResidualDenseBlock_5C, self).__init__() 38 | # gc: growth channel, i.e. 
intermediate channels 39 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 40 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 41 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 42 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 43 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 44 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 45 | 46 | # initialization 47 | initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 48 | 49 | def forward(self, x): 50 | x1 = self.lrelu(self.conv1(x)) 51 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 52 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 53 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 54 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 55 | return x5 * 0.2 + x 56 | 57 | 58 | class RRDB(nn.Module): 59 | '''Residual in Residual Dense Block''' 60 | 61 | def __init__(self, nf, gc=32): 62 | super(RRDB, self).__init__() 63 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 64 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 65 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 66 | 67 | def forward(self, x): 68 | out = self.RDB1(x) 69 | out = self.RDB2(out) 70 | out = self.RDB3(out) 71 | return out * 0.2 + x 72 | 73 | 74 | class RRDBNet(nn.Module): 75 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): 76 | super(RRDBNet, self).__init__() 77 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 78 | self.sf = sf 79 | print([in_nc, out_nc, nf, nb, gc, sf]) 80 | 81 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 82 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 83 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 84 | #### upsampling 85 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 86 | if self.sf==4: 87 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 88 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 89 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, 
bias=True) 90 | 91 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 92 | 93 | def forward(self, x): 94 | fea = self.conv_first(x) 95 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 96 | fea = fea + trunk 97 | 98 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 99 | if self.sf==4: 100 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 101 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 102 | 103 | return out -------------------------------------------------------------------------------- /BSRGAN/test_bsrgan_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | import torch 5 | from torch.backends import cudnn 6 | 7 | import niqe 8 | import utils 9 | from bsrgan_model import RRDBNet 10 | from PIL import Image 11 | from imresize import imresize 12 | import os 13 | 14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 15 | if __name__ == '__main__': 16 | model_name = 'BSRGAN' 17 | weight_file = './weight_file/BSRGAN.pth' 18 | root_dir = '/data0/jli/datasets/degraded_17/' 19 | out_root_dir = f'./test_res/{model_name}_degraded17/' 20 | hr_dir = '/data0/jli/datasets/PIPAL/' 21 | 22 | csv_file = f'./test_res/{model_name}_x4_degraded17.csv' 23 | csv_file = open(csv_file, 'w', newline='') 24 | writer = csv.writer(csv_file) 25 | writer.writerow(('name', 'psnr', 'niqe', 'ssim')) 26 | 27 | lr_dirs = os.listdir(root_dir) 28 | for dir in lr_dirs: 29 | utils.mkdirs(out_root_dir + dir) 30 | 31 | scale = 4 32 | padding = scale 33 | 34 | 35 | if not os.path.exists(weight_file): 36 | print(f'Weight file not exist!\n{weight_file}\n') 37 | raise "Error" 38 | 39 | cudnn.benchmark = True 40 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 41 | 42 | model = RRDBNet().to(device) 43 | checkpoint = torch.load(weight_file) 44 | model.load_state_dict(checkpoint) 45 | model.eval() 46 | 47 | for dir in 
lr_dirs: 48 | outputs_dir = out_root_dir + dir + '/' 49 | lr_dir = root_dir + dir + '/' 50 | lr_lists = os.listdir(lr_dir) 51 | Avg_psnr = utils.AverageMeter() 52 | Avg_niqe = utils.AverageMeter() 53 | Avg_ssim = utils.AverageMeter() 54 | for imgName in lr_lists: 55 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 56 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale) 57 | img_mode = image.mode 58 | if img_mode == 'L': 59 | gray_img = np.array(image) 60 | image = image.convert('RGB') 61 | lr_image = np.array(image) 62 | hr_image = np.array(hr_image) 63 | 64 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 65 | lr /= 255. 66 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 67 | 68 | # hr_image = hr_image[padding: -padding, padding: -padding, ...] 69 | 70 | with torch.no_grad(): 71 | SR = model(lr) 72 | # SR = SR[..., padding: -padding, padding: -padding] 73 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 74 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 75 | if img_mode != 'L': 76 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255. 77 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255. 78 | else: 79 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...] 80 | hr_y = gray_img / 255. 81 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L') 82 | SR_y = np.array(SR).astype(np.float32) / 255. 
83 | psnr = utils.calc_psnr(hr_y, SR_y) 84 | NIQE = niqe.calculate_niqe(SR_y) 85 | ssim = utils.calculate_ssim(hr_y * 255, SR_y * 255) 86 | Avg_psnr.update(psnr, 1) 87 | Avg_niqe.update(NIQE, 1) 88 | Avg_ssim.update(ssim, 1) 89 | # print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item())) 90 | print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}, ssim: {:.4f}'.format(psnr.item(), NIQE, ssim)) 91 | # GPU tensor -> CPU tensor -> numpy 92 | output = np.array(SR).astype(np.uint8) 93 | output = Image.fromarray(output) # hw -> wh 94 | output.save(outputs_dir + imgName.replace('.', '_{:.2f}.'.format(psnr.item()))) 95 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}, Average_ssim: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg, 96 | Avg_ssim.avg)) 97 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg)) 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /FSRCNN/gen_datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | from tqdm import tqdm 5 | import utils 6 | from imresize import imresize 7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 8 | 9 | 10 | def gen_traindata(config): 11 | scale = config["scale"] 12 | stride = config["stride"] 13 | size_input = config["size_input"] 14 | size_label = size_input * scale 15 | size_output = config['size_output'] 16 | padding = abs(size_label - size_output) // 2 17 | if scale == 3: 18 | padding2 = padding 19 | else: 20 | padding2 = padding + 1 21 | method = config['method'] 22 | if config['residual']: 23 | h5savepath = config["hrDir"] + f'_label={size_output}_train_FSRCNNx{scale}_res.h5' 24 | else: 25 | h5savepath = config["hrDir"] + f'_label={size_output}_train_FSRCNNx{scale}.h5' 26 | hrDir = config["hrDir"] + '/' 27 | h5_file = h5py.File(h5savepath, 'w') 28 | imgList = os.listdir(hrDir) 29 | lr_subimgs = [] 30 | hr_subimgs = [] 31 | with tqdm(total=len(imgList)) as t: 32 
| for imgName in imgList: 33 | hrIMG = utils.loadIMG_crop(hrDir + imgName, scale) 34 | hr_y = utils.img2ycbcr(hrIMG)[..., 0] 35 | lr_y = imresize(hr_y, 1 / scale, method).astype(np.float32) 36 | # residual 37 | if config['residual']: 38 | lr_y_upscale = imresize(lr_y, scale, method).astype(np.float32) 39 | hr_y = hr_y - lr_y_upscale 40 | 41 | for r in range(0, lr_y.shape[0] - size_input + 1, stride): 42 | for c in range(0, lr_y.shape[1] - size_input + 1, stride): 43 | lr_subimgs.append(lr_y[r: r + size_input, c: c + size_input]) 44 | label = hr_y[r * scale: r * scale + size_label, c * scale: c * scale + size_label] 45 | label = label[padding: -padding2, padding: -padding2] 46 | hr_subimgs.append(label) 47 | t.update(1) 48 | 49 | lr_subimgs = np.array(lr_subimgs).astype(np.float32) 50 | hr_subimgs = np.array(hr_subimgs).astype(np.float32) 51 | 52 | h5_file.create_dataset('data', data=lr_subimgs) 53 | h5_file.create_dataset('label', data=hr_subimgs) 54 | 55 | h5_file.close() 56 | 57 | 58 | def gen_valdata(config): 59 | scale = config["scale"] 60 | size_input = config["size_input"] 61 | size_label = size_input * scale 62 | size_output = config['size_output'] 63 | padding = (size_label - size_output) // 2 64 | method = config['method'] 65 | if scale == 3: 66 | padding2 = padding 67 | else: 68 | padding2 = padding + 1 69 | if config['residual']: 70 | h5savepath = config["hrDir"] + f'_label={size_output}_val_FSRCNNx{scale}_res.h5' 71 | else: 72 | h5savepath = config["hrDir"] + f'_label={size_output}_val_FSRCNNx{scale}.h5' 73 | hrDir = config["hrDir"] + '/' 74 | h5_file = h5py.File(h5savepath, 'w') 75 | lr_group = h5_file.create_group('data') 76 | hr_group = h5_file.create_group('label') 77 | if config['residual']: 78 | bic_group = h5_file.create_group('bicubic') 79 | imgList = os.listdir(hrDir) 80 | for i, imgName in enumerate(imgList): 81 | hrIMG = utils.loadIMG_crop(hrDir+imgName, scale) 82 | hr_y = utils.img2ycbcr(hrIMG)[..., 0] 83 | 84 | lr_y = imresize(hr_y, 1 / 
scale, method).astype(np.float32) 85 | # residual 86 | if config['residual']: 87 | bic_y = imresize(lr_y, scale, method).astype(np.float32) 88 | bic_y = bic_y[padding: -padding2, padding: -padding2] 89 | label = hr_y.astype(np.float32)[padding: -padding2, padding: -padding2] 90 | 91 | lr_group.create_dataset(str(i), data=lr_y) 92 | hr_group.create_dataset(str(i), data=label) 93 | if config['residual']: 94 | bic_group.create_dataset(str(i), data=bic_y) 95 | h5_file.close() 96 | 97 | if __name__ == '__main__': 98 | # config = {'hrDir': './test/flower', 'scale': 3, "stride": 14, "size_input": 33, "size_label": 21} 99 | config = {'hrDir': '../datasets/T91_aug', 'scale': 4, 'stride': 10, "size_input": 10, "size_output": 21, "residual": True, 'method': 'bicubic'} 100 | gen_traindata(config) 101 | config['hrDir'] = '../datasets/Set5' 102 | gen_valdata(config) 103 | -------------------------------------------------------------------------------- /HCFlow/ActNorms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | import thops 5 | 6 | 7 | class _ActNorm(nn.Module): 8 | """ 9 | Activation Normalization 10 | Initialize the bias and scale with a given minibatch, 11 | so that the output per-channel have zero mean and unit variance for that. 12 | 13 | After initialization, `bias` and `logs` will be trained as parameters. 
14 | """ 15 | 16 | def __init__(self, num_features, scale=1.): 17 | super().__init__() 18 | # register mean and scale 19 | size = [1, num_features, 1, 1] 20 | self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) 21 | self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) 22 | self.num_features = num_features 23 | self.scale = float(scale) 24 | self.inited = False 25 | 26 | def _check_input_dim(self, input): 27 | return NotImplemented 28 | 29 | def initialize_parameters(self, input): 30 | self._check_input_dim(input) 31 | if not self.training: 32 | return 33 | if (self.bias != 0).any(): 34 | self.inited = True 35 | return 36 | assert input.device == self.bias.device, (input.device, self.bias.device) 37 | with torch.no_grad(): 38 | bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 39 | vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 40 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 41 | self.bias.data.copy_(bias.data) 42 | self.logs.data.copy_(logs.data) 43 | self.inited = True 44 | 45 | def _center(self, input, reverse=False, offset=None): 46 | bias = self.bias 47 | 48 | if offset is not None: 49 | bias = bias + offset 50 | 51 | if not reverse: 52 | return input + bias 53 | else: 54 | return input - bias 55 | 56 | def _scale(self, input, logdet=None, reverse=False, offset=None): 57 | logs = self.logs 58 | 59 | if offset is not None: 60 | logs = logs + offset 61 | 62 | if not reverse: 63 | input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 64 | # input = input * torch.exp(logs+logs_offset) 65 | else: 66 | input = input * torch.exp(-logs) 67 | if logdet is not None: 68 | """ 69 | logs is log_std of `mean of channels` 70 | so we need to multiply pixels 71 | """ 72 | dlogdet = thops.sum(logs) * thops.pixels(input) 73 | if reverse: 74 | dlogdet *= -1 75 | logdet = logdet + dlogdet 76 | return input, logdet 77 | 78 | def forward(self, input, logdet=None, 
reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): 79 | if not self.inited: 80 | self.initialize_parameters(input) 81 | 82 | if offset_mask is not None: 83 | logs_offset *= offset_mask 84 | bias_offset *= offset_mask 85 | # no need to permute dims as old version 86 | if not reverse: 87 | # center and scale 88 | input = self._center(input, reverse, bias_offset) 89 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 90 | else: 91 | # scale and center 92 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 93 | input = self._center(input, reverse, bias_offset) 94 | return input, logdet 95 | 96 | 97 | class ActNorm2d(_ActNorm): 98 | def __init__(self, num_features, scale=1.): 99 | super().__init__(num_features, scale) 100 | 101 | def _check_input_dim(self, input): 102 | assert len(input.size()) == 4 103 | assert input.size(1) == self.num_features, ( 104 | "[ActNorm]: input should be in shape as `BCHW`," 105 | " channels should be {} rather than {}".format( 106 | self.num_features, input.size())) 107 | 108 | 109 | class MaskedActNorm2d(ActNorm2d): 110 | def __init__(self, num_features, scale=1.): 111 | super().__init__(num_features, scale) 112 | 113 | def forward(self, input, mask, logdet=None, reverse=False): 114 | 115 | assert mask.dtype == torch.bool 116 | output, logdet_out = super().forward(input, logdet, reverse) 117 | 118 | input[mask] = output[mask] 119 | logdet[mask] = logdet_out[mask] 120 | 121 | return input, logdet 122 | 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SuperRestoration 2 | Include:SRCNN、FSRCNN、SRResNet、SRGAN 3 | 4 | # Introduce 5 | ## Why wrote this? 6 | There are many ***Pytorch*** implementations of these networks on the web, but they do not appear exactly as described in the paper, so the results are quite different from the paper. 
7 | So I've provided a version that's as close to the paper as possible. 8 | Hopefully it will help those interested in **Super-Resolution** networks get started. 9 | For more details: https://zhuanlan.zhihu.com/p/431724297 10 | # Install 11 | You can clone this repository and run it directly; no installation is required. 12 | 13 | ## Running Environment 14 | * Python 3.7 64-bit 15 | * Windows 10 16 | 17 | ## Reference 18 | Because `bicubic` interpolation in Python differs from MATLAB's, 19 | while the papers use MATLAB to generate datasets and evaluate PSNR, I use a ***Python*** implementation of the ***MATLAB*** function `imresize()`. 20 | [Here is the author's repository.](https://github.com/fatheral/matlab_imresize.git) 21 | Similarly, I provide Python versions of MATLAB's `rgb2ycbcr()` and `ycbcr2rgb()`. 22 | 23 | # Usage 24 | ## Prepare Datasets 25 | * Run `data_aug.py` to augment the datasets. 26 | * Run `gen_datasets.py` to generate training and validation data. (You may need to modify the parameters in `config`.) 27 | ## Train 28 | Take **SRCNN** as an example: run `SRCNN_x2.py` to train SRCNN. You can modify the training parameters to suit your needs by following this template. 29 | ## Test 30 | Run `test.py` to generate test results and calculate PSNR. (You can modify the parameters to specify the test sets.) 31 | ## Visualize 32 | Run `csv2visdom.py` to visualize the convergence curve with visdom. (You need to install `visdom` and start its server in advance.) 33 | Then visit `localhost:8097`. 
34 | 35 | # Result: PSNR 36 | ## SRCNN x3 37 | |Set5|Paper| Ours| 38 | ----|---|---| 39 | baby|35.01|34.96| 40 | bird|34.91|34.95| 41 | butterfly|27.58|27.77| 42 | head|33.55|33.51| 43 | woman|30.92|30.99| 44 | |Average|32.39|32.43| 45 | 46 | |Set14|Paper| Ours | 47 | ----|-----|------| 48 | baboon|23.60|23.60| 49 | barbara|26.66|26.71| 50 | bridge|25.07|25.08| 51 | coastguard|27.20|27.17| 52 | comic|24.39|24.42| 53 | face|33.58|33.54| 54 | flowers|28.97|29.01| 55 | foreman|33.35|33.32| 56 | lenna|33.39|33.40| 57 | man|28.18|28.18| 58 | monarch|32.39|32.54| 59 | pepper|34.35|34.24| 60 | ppt3|26.02|26.14| 61 | zebra|28.87|28.80| 62 | |Average|29.00|29.01| 63 | 64 | ## FSRCNN x3 65 | Trained on the 91-image dataset (T91). 66 | 67 | | |Paper| Ours| 68 | ----|---|---| 69 | Set5|33.06|33.06| 70 | Set14|29.37|29.35| 71 | BSDS200|28.55|28.95| 72 | 73 | ## SRResNet x4 74 | Trained on DIV2K. 75 | 76 | | |Paper| Ours| 77 | ----|---|---| 78 | Set5|32.05|32.12| 79 | Set14|28.49|28.50| 80 | BSDS100|27.58|27.54| 81 | 82 | ## SRGAN x4 83 | Trained on DIV2K. 84 | 85 | | |Paper| Ours| 86 | ----|---|---| 87 | Set5|29.40|30.19| 88 | Set14|26.02|26.94| 89 | BSDS100|25.16|25.82| 90 | 91 | SRGAN cannot be evaluated by PSNR alone, so some test results are shown below. 92 | SRGAN clearly generates sharper results than SRResNet, and they look more convincing. 
93 | 94 | |bicubic|SRResNet|SRGAN|original| 95 | ---|---|---|---| 96 | ![Image text](./SRResNet/result/baby_GT_bicubic_x4.bmp)|![Image text](./SRResNet/result/baby_GT_SRResNet_x4.bmp)|![Image text](./SRResNet/result/baby_GT_SRGAN_x4.bmp)|![Image text](./SRResNet/result/baby_GT.bmp) 97 | ![Image text](./SRResNet/result/woman_GT_bicubic_x4.bmp)|![Image text](./SRResNet/result/woman_GT_SRResNet_x4.bmp)|![Image text](./SRResNet/result/woman_GT_SRGAN_x4.bmp)|![Image text](./SRResNet/result/woman_GT.bmp) 98 | ![Image text](./SRResNet/result/head_GT_bicubic_x4.bmp)|![Image text](./SRResNet/result/head_GT_SRResNet_x4.bmp)|![Image text](./SRResNet/result/head_GT_SRGAN_x4.bmp)|![Image text](./SRResNet/result/head_GT.bmp) 99 | ![Image text](./SRResNet/result/baboon_bicubic_x4.bmp)|![Image text](./SRResNet/result/baboon_SRResNet_x4.bmp)|![Image text](./SRResNet/result/baboon_SRGAN_x4.bmp)|![Image text](./SRResNet/result/baboon.bmp) 100 | ![Image text](./SRResNet/result/coastguard_bicubic_x4.bmp)|![Image text](./SRResNet/result/coastguard_SRResNet_x4.bmp)|![Image text](./SRResNet/result/coastguard_SRGAN_x4.bmp)|![Image text](./SRResNet/result/coastguard.bmp) 101 | ![Image text](./SRResNet/result/comic_bicubic_x4.bmp)|![Image text](./SRResNet/result/comic_SRResNet_x4.bmp)|![Image text](./SRResNet/result/comic_SRGAN_x4.bmp)|![Image text](./SRResNet/result/comic.bmp) 102 | ![Image text](./SRResNet/result/flowers_bicubic_x4.bmp)|![Image text](./SRResNet/result/flowers_SRResNet_x4.bmp)|![Image text](./SRResNet/result/flowers_SRGAN_x4.bmp)|![Image text](./SRResNet/result/flowers.bmp) 103 | 104 | -------------------------------------------------------------------------------- /FSRCNN/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch 4 | 5 | 6 | class FSRCNN(nn.Module): 7 | def __init__(self, scale_factor, in_size, out_size, num_channels=1, d=56, s=12, m=4): 8 | super(FSRCNN, 
self).__init__() 9 | self.extract_layer = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=2, padding_mode='replicate'), 10 | nn.PReLU()) 11 | 12 | self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU()] 13 | for i in range(m): 14 | self.mid_part.extend([nn.ReplicationPad2d(1), 15 | nn.Conv2d(s, s, kernel_size=3), 16 | nn.PReLU()]) 17 | self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU()]) 18 | self.mid_part = nn.Sequential(*self.mid_part) 19 | 20 | # 11->out 21 | self.deconv_layer = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, 22 | padding=(9 + (in_size-1)*scale_factor - out_size)//2) 23 | 24 | def init_weights(self, method='MSRA'): 25 | if method == 'MSRA': 26 | init_weights_MSRA(self) 27 | else: 28 | init_weights_Xavier(self) 29 | 30 | def forward(self, x): 31 | x = self.extract_layer(x) 32 | x = self.mid_part(x) 33 | x = self.deconv_layer(x) 34 | return x 35 | 36 | 37 | class N2_10_4(nn.Module): 38 | def __init__(self, scale_factor, in_size, out_size, num_channels=1, d=10, m=4): 39 | super(N2_10_4, self).__init__() 40 | self.extract_layer = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=3, padding=1, padding_mode='replicate'), 41 | nn.PReLU()) 42 | self.mid_part = [] 43 | for i in range(m): 44 | self.mid_part.extend([nn.Conv2d(d, d, kernel_size=3, padding=1, padding_mode='replicate'), 45 | nn.PReLU()]) 46 | self.mid_part = nn.Sequential(*self.mid_part) 47 | 48 | # 11->out 49 | self.deconv_layer = nn.ConvTranspose2d(d, num_channels, kernel_size=7, stride=scale_factor, 50 | padding=(7 + (in_size-1)*scale_factor - out_size)//2) 51 | 52 | def forward(self, x): 53 | x = self.extract_layer(x) 54 | x = self.mid_part(x) 55 | x = self.deconv_layer(x) 56 | return x 57 | 58 | def init_weights(self, method='MSRA'): 59 | if method == 'MSRA': 60 | init_weights_MSRA(self) 61 | else: 62 | init_weights_Xavier(self) 63 | 64 | 65 | class CLoss(nn.Module): 66 | """L1 Charbonnierloss.""" 67 | def __init__(self, 
delta=1e-3): 68 | super(CLoss, self).__init__() 69 | self.delta2 = delta * delta 70 | 71 | def forward(self, X, Y): 72 | diff = torch.add(X, -Y) 73 | error = torch.sqrt(diff * diff + self.delta2) 74 | loss = torch.mean(error) 75 | return loss 76 | 77 | 78 | class HuberLoss(nn.Module): 79 | """Huber loss.""" 80 | def __init__(self, delta=2e-4): 81 | super(HuberLoss, self).__init__() 82 | self.delta = delta 83 | self.delta2 = delta * delta 84 | 85 | def forward(self, X, Y): 86 | abs_diff = abs(torch.add(X, -Y)) 87 | cond = torch.less_equal(abs_diff, self.delta) 88 | large_loss = 0.5 * abs_diff * abs_diff 89 | small_loss = self.delta * abs_diff - 0.5 * self.delta2 90 | error = torch.where(cond, large_loss, small_loss) 91 | loss = torch.mean(error) 92 | return loss 93 | 94 | 95 | def init_weights_MSRA(Model): 96 | for L in Model.extract_layer: 97 | if isinstance(L, nn.Conv2d): 98 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels * L.weight.data[0][0].numel()))) 99 | L.bias.data.zero_() 100 | for L in Model.mid_part: 101 | if isinstance(L, nn.Conv2d): 102 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels * L.weight.data[0][0].numel()))) 103 | L.bias.data.zero_() 104 | Model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001) 105 | Model.deconv_layer.bias.data.zero_() 106 | 107 | def init_weights_Xavier(Model): 108 | for L in Model.extract_layer: 109 | if isinstance(L, nn.Conv2d): 110 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels + L.in_channels))) 111 | L.bias.data.zero_() 112 | for L in Model.mid_part: 113 | if isinstance(L, nn.Conv2d): 114 | L.weight.data.normal_(mean=0.0, std=math.sqrt(2 / (L.out_channels + L.in_channels))) 115 | L.bias.data.zero_() 116 | Model.deconv_layer.weight.data.normal_(mean=0.0, std=0.001) 117 | Model.deconv_layer.bias.data.zero_() -------------------------------------------------------------------------------- /SwinIR/test_swinir_metrics.py: 
-------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | import torch 5 | from torch.backends import cudnn 6 | import niqe 7 | import utils 8 | from swinir_model import SwinIR 9 | from PIL import Image 10 | import os 11 | 12 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 13 | if __name__ == '__main__': 14 | model_name = 'SwinIR-M' 15 | csv_file = f'./test_res/{model_name}_x4_degraded17.csv' 16 | csv_file = open(csv_file, 'w', newline='') 17 | writer = csv.writer(csv_file) 18 | writer.writerow(('name', 'psnr', 'niqe', 'ssim')) 19 | 20 | weight_file = './weight_file/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth' 21 | is_large = False 22 | root_dir = '/data0/jli/datasets/degraded_17/' 23 | out_root_dir = f'./test_res/{model_name}_degraded17/' 24 | hr_dir = '/data0/jli/datasets/PIPAL/' 25 | 26 | lr_dirs = os.listdir(root_dir) 27 | for dir in lr_dirs: 28 | utils.mkdirs(out_root_dir + dir) 29 | 30 | scale = 4 31 | padding = scale 32 | 33 | if not os.path.exists(weight_file): 34 | print(f'Weight file not exist!\n{weight_file}\n') 35 | raise "Error" 36 | 37 | cudnn.benchmark = True 38 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 39 | if not is_large: 40 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8, 41 | img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], 42 | mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv').to(device) 43 | else: 44 | model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8, 45 | img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240, 46 | num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], 47 | mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv').to(device) 48 | checkpoint = torch.load(weight_file) 49 | model.load_state_dict(checkpoint['params_ema']) 50 | model.eval() 51 | 52 | for dir in lr_dirs: 53 | outputs_dir = out_root_dir + dir + '/' 54 | lr_dir = root_dir + dir + 
'/' 55 | lr_lists = os.listdir(lr_dir) 56 | Avg_psnr = utils.AverageMeter() 57 | Avg_niqe = utils.AverageMeter() 58 | Avg_ssim = utils.AverageMeter() 59 | for imgName in lr_lists: 60 | image = utils.loadIMG_crop(lr_dir + imgName, scale) 61 | hr_image = utils.loadIMG_crop(hr_dir + imgName, scale) 62 | img_mode = image.mode 63 | if img_mode == 'L': 64 | gray_img = np.array(image) 65 | image = image.convert('RGB') 66 | lr_image = np.array(image) 67 | hr_image = np.array(hr_image) 68 | 69 | lr = lr_image.astype(np.float32).transpose([2, 0, 1]) # hwc -> chw 70 | lr /= 255. 71 | lr = torch.from_numpy(lr).to(device).unsqueeze(0) 72 | 73 | # hr_image = hr_image[padding: -padding, padding: -padding, ...] 74 | 75 | with torch.no_grad(): 76 | SR = model(lr) 77 | # SR = SR[..., padding: -padding, padding: -padding] 78 | SR = SR.mul(255.0).cpu().numpy().squeeze(0) 79 | SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0]) 80 | if img_mode != 'L': 81 | SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255. 82 | hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255. 83 | else: 84 | # gray_img = gray_img.astype(np.float32)[padding: -padding, padding: -padding, ...] 85 | hr_y = gray_img / 255. 86 | SR = Image.fromarray(SR.astype(np.uint8)).convert('L') 87 | SR_y = np.array(SR).astype(np.float32) / 255. 
88 | psnr = utils.calc_psnr(hr_y, SR_y) 89 | NIQE = niqe.calculate_niqe(SR_y) 90 | ssim = utils.calculate_ssim(hr_y * 255, SR_y * 255) 91 | Avg_psnr.update(psnr, 1) 92 | Avg_niqe.update(NIQE, 1) 93 | Avg_ssim.update(ssim, 1) 94 | # print(f'{imgName}, ' + 'PSNR: {:.2f}'.format(psnr.item())) 95 | print(f'{imgName}, ' + 'PSNR: {:.2f} , NIQE: {:.4f}, ssim: {:.4f}'.format(psnr.item(), NIQE, ssim)) 96 | # GPU tensor -> CPU tensor -> numpy 97 | output = np.array(SR).astype(np.uint8) 98 | output = Image.fromarray(output) # hw -> wh 99 | output.save(outputs_dir + imgName.replace('.', '_{:.2f}.'.format(psnr.item()))) 100 | print('Average_PSNR: {:.2f}, Average_NIQE: {:.4f}, Average_ssim: {:.4f}'.format(Avg_psnr.avg, Avg_niqe.avg, 101 | Avg_ssim.avg)) 102 | writer.writerow((dir, Avg_psnr.avg, Avg_niqe.avg, Avg_ssim.avg)) 103 | 104 | 105 | 106 | --------------------------------------------------------------------------------