├── fig ├── .DS_Store ├── result.png ├── ex_pair.png ├── overall.png └── ex_unpair.png ├── code ├── base │ ├── .DS_Store │ ├── rain100H │ │ ├── .DS_Store │ │ ├── config │ │ │ ├── .DS_Store │ │ │ ├── function │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ │ ├── subtraction.py │ │ │ │ │ ├── aggregation.py │ │ │ │ │ └── subtraction2.py │ │ │ │ ├── __pycache__ │ │ │ │ │ └── functional.cpython-37.pyc │ │ │ │ ├── functions │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── utils.cpython-37.pyc │ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ │ └── subtraction2_zeropad.cpython-37.pyc │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── utils.py │ │ │ │ └── functional.py │ │ │ ├── clean.sh │ │ │ ├── compile.py │ │ │ ├── tensorboard.sh │ │ │ ├── settings.py │ │ │ ├── cal_ssim.py │ │ │ ├── eval.py │ │ │ ├── show.py │ │ │ ├── dataset.py │ │ │ └── train.py │ │ └── models │ │ │ ├── .DS_Store │ │ │ └── README.md │ └── rain100L │ │ ├── .DS_Store │ │ ├── config │ │ ├── .DS_Store │ │ ├── function │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ ├── subtraction.py │ │ │ │ ├── aggregation.py │ │ │ │ └── subtraction2.py │ │ │ ├── __pycache__ │ │ │ │ └── functional.cpython-37.pyc │ │ │ ├── functions │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── utils.cpython-37.pyc │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ ├── 
aggregation_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ └── subtraction2_zeropad.cpython-37.pyc │ │ │ │ ├── __init__.py │ │ │ │ └── utils.py │ │ │ └── functional.py │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── tensorboard.sh │ │ ├── settings.py │ │ ├── cal_ssim.py │ │ ├── eval.py │ │ ├── show.py │ │ └── dataset.py │ │ └── models │ │ ├── .DS_Store │ │ └── README.md ├── diff_loss │ ├── .DS_Store │ ├── mae │ │ ├── .DS_Store │ │ ├── config │ │ │ ├── .DS_Store │ │ │ ├── function │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ │ ├── aggregation.cpython-36.pyc │ │ │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ │ │ ├── subtraction.cpython-36.pyc │ │ │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2.cpython-36.pyc │ │ │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ │ ├── subtraction.py │ │ │ │ │ ├── aggregation.py │ │ │ │ │ └── subtraction2.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── functional.cpython-36.pyc │ │ │ │ │ └── functional.cpython-37.pyc │ │ │ │ ├── functions │ │ │ │ │ ├── __pycache__ │ │ │ │ │ │ ├── utils.cpython-36.pyc │ │ │ │ │ │ ├── utils.cpython-37.pyc │ │ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_refpad.cpython-36.pyc │ │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ │ ├── aggregation_zeropad.cpython-36.pyc │ │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2_refpad.cpython-36.pyc │ │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_refpad.cpython-36.pyc │ │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction_zeropad.cpython-36.pyc │ │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ │ ├── subtraction2_zeropad.cpython-36.pyc │ │ │ │ │ │ └── 
subtraction2_zeropad.cpython-37.pyc │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── utils.py │ │ │ │ └── functional.py │ │ │ ├── clean.sh │ │ │ ├── compile.py │ │ │ ├── tensorboard.sh │ │ │ ├── settings.py │ │ │ ├── cal_ssim.py │ │ │ ├── eval.py │ │ │ ├── show.py │ │ │ └── dataset.py │ │ └── models │ │ │ ├── .DS_Store │ │ │ └── README.md │ └── mse │ │ ├── .DS_Store │ │ ├── config │ │ ├── .DS_Store │ │ ├── function │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation.cpython-37.pyc │ │ │ │ │ ├── subtraction.cpython-37.pyc │ │ │ │ │ └── subtraction2.cpython-37.pyc │ │ │ │ ├── subtraction.py │ │ │ │ ├── aggregation.py │ │ │ │ └── subtraction2.py │ │ │ ├── __pycache__ │ │ │ │ └── functional.cpython-37.pyc │ │ │ ├── functions │ │ │ │ ├── __pycache__ │ │ │ │ │ ├── utils.cpython-37.pyc │ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ │ ├── aggregation_refpad.cpython-37.pyc │ │ │ │ │ ├── aggregation_zeropad.cpython-37.pyc │ │ │ │ │ ├── subtraction2_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_refpad.cpython-37.pyc │ │ │ │ │ ├── subtraction_zeropad.cpython-37.pyc │ │ │ │ │ └── subtraction2_zeropad.cpython-37.pyc │ │ │ │ ├── __init__.py │ │ │ │ └── utils.py │ │ │ └── functional.py │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── tensorboard.sh │ │ ├── settings.py │ │ ├── cal_ssim.py │ │ ├── eval.py │ │ ├── show.py │ │ └── dataset.py │ │ └── models │ │ ├── .DS_Store │ │ └── README.md └── ablation │ ├── r2 │ ├── config │ │ ├── .DS_Store │ │ ├── clean.sh │ │ ├── compile.py │ │ ├── tensorboard.sh │ │ ├── settings.py │ │ ├── cal_ssim.py │ │ ├── eval.py │ │ ├── show.py │ │ ├── dataset.py │ │ ├── model.py │ │ └── train.py │ └── models │ │ ├── .DS_Store │ │ └── README.md │ └── r1 │ ├── config │ ├── clean.sh │ ├── compile.py │ ├── tensorboard.sh │ ├── settings.py │ ├── cal_ssim.py │ ├── model.py │ ├── eval.py │ ├── show.py │ ├── dataset.py │ └── train.py │ └── models │ └── README.md └── README.md /fig/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/fig/.DS_Store -------------------------------------------------------------------------------- /fig/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/fig/result.png -------------------------------------------------------------------------------- /fig/ex_pair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/fig/ex_pair.png -------------------------------------------------------------------------------- /fig/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/fig/overall.png -------------------------------------------------------------------------------- /code/base/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/.DS_Store -------------------------------------------------------------------------------- /fig/ex_unpair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/fig/ex_unpair.png -------------------------------------------------------------------------------- /code/diff_loss/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/.DS_Store 
-------------------------------------------------------------------------------- /code/base/rain100L/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/.DS_Store -------------------------------------------------------------------------------- /code/ablation/r2/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/ablation/r2/config/.DS_Store -------------------------------------------------------------------------------- /code/ablation/r2/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/ablation/r2/models/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/models/.DS_Store 
-------------------------------------------------------------------------------- /code/base/rain100L/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100L/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/models/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mae/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/models/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/.DS_Store -------------------------------------------------------------------------------- /code/diff_loss/mse/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/models/.DS_Store -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from 
.subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /code/ablation/r1/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/ablation/r2/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/base/rain100H/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/base/rain100L/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/clean.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/* 7 | -------------------------------------------------------------------------------- /code/ablation/r1/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/ablation/r2/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/base/rain100H/config/compile.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/base/rain100L/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/diff_loss/mae/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/diff_loss/mse/config/compile.py: -------------------------------------------------------------------------------- 1 | import re 2 | rec_data="+MIPLOBSERVE:0,68220,1,3303,0,-1" 3 | msgidRegex = re.compile(r',(\d)+,') 4 | mo = msgidRegex.search(rec_data) 5 | print(mo.group()) -------------------------------------------------------------------------------- /code/ablation/r1/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 
4 | -------------------------------------------------------------------------------- /code/ablation/r2/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/__pycache__/functional.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/base/rain100L/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/diff_loss/mae/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 4 | -------------------------------------------------------------------------------- /code/diff_loss/mse/models/README.md: -------------------------------------------------------------------------------- 1 | # JDNet:Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | The training model or pre-training model will be placed here. 
4 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- 
/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/modules/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/modules/__pycache__/subtraction.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/modules/__pycache__/subtraction2.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- 
/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-36.pyc 
-------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/aggregation_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_refpad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/subtraction_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100H/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/base/rain100L/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-36.pyc -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mae/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- 
/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ohraincu/JDNet/HEAD/code/diff_loss/mse/config/function/functions/__pycache__/subtraction2_zeropad.cpython-37.pyc -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /code/ablation/r1/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/ablation/r2/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/base/rain100H/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/base/rain100L/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port $rnd --reload_interval 3 11 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | 
-------------------------------------------------------------------------------- /code/base/rain100L/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def 
Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.util.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /code/base/rain100H/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out 
= functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/base/rain100L/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | 
return out 38 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/function/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /code/ablation/r1/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | 
scale_attention = False 13 | ssim_loss = True 14 | ######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/wangcong/dataset/rain12' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '2,1' 39 | 40 | epoch = 1000 41 | batch_size = 28 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/ablation/r2/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | 
######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/wangcong/dataset/rain12' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '2,1' 39 | 40 | epoch = 1000 41 | batch_size = 24 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/base/rain100H/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | 
######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/wangcong/dataset/rain12' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '1,2' 39 | 40 | epoch = 1000 41 | batch_size = 12 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | 
######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/dataset/haze_blur_rain/rain100H' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '0,1' 39 | 40 | epoch = 1000 41 | batch_size = 12 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/base/rain100L/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | 
######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/dataset/haze_blur_rain/rain100L' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 4 37 | 38 | device_id = '0,1,2,3' 39 | 40 | epoch = 1000 41 | batch_size = 32 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | ####################### 网络参数设置 ############################################################ 6 | channel = 32 7 | feature_map_num = 32 8 | res_conv_num = 4 9 | unit_num = 32 10 | #scale_num = 4 11 | num_scale_attention = 4 12 | scale_attention = False 13 | ssim_loss = True 14 | 
######################################################################################## 15 | aug_data = False # Set as False for fair comparison 16 | 17 | patch_size = 64 18 | pic_is_pair = True #input picture is pair or single 19 | 20 | lr = 0.0005 21 | 22 | data_dir = '/data1/datasets/haze_blur_rain/rain100H' 23 | if pic_is_pair is False: 24 | data_dir = '/data1/wangcong/dataset/real-world-images' 25 | log_dir = '../logdir' 26 | show_dir = '../showdir' 27 | model_dir = '../models' 28 | show_dir_feature = '../showdir_feature' 29 | 30 | log_level = 'info' 31 | model_path = os.path.join(model_dir, 'latest_net') 32 | save_steps = 400 33 | 34 | num_workers = 8 35 | 36 | num_GPU = 2 37 | 38 | device_id = '1,2' 39 | 40 | epoch = 1000 41 | batch_size = 12 42 | 43 | if pic_is_pair: 44 | root_dir = os.path.join(data_dir, 'train') 45 | mat_files = os.listdir(root_dir) 46 | num_datasets = len(mat_files) 47 | l1 = int(3/5 * epoch * num_datasets / batch_size) 48 | l2 = int(4/5 * epoch * num_datasets / batch_size) 49 | one_epoch = int(num_datasets/batch_size) 50 | total_step = int((epoch * num_datasets)/batch_size) 51 | 52 | logger = logging.getLogger('train') 53 | logger.setLevel(logging.INFO) 54 | 55 | ch = logging.StreamHandler() 56 | ch.setLevel(logging.INFO) 57 | 58 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 59 | ch.setFormatter(formatter) 60 | logger.addHandler(ch) 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/ablation/r1/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, 
window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/ablation/r2/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return 
ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/base/rain100H/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, 
window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | 
window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/base/rain100L/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | 
self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 
= F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- 
/code/diff_loss/mse/config/cal_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 3 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | #print(img1.size()) 49 | (_, channel, _, 
_) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /code/ablation/r1/config/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import settings 5 | from itertools import combinations,product 6 | import math 7 | 8 | class Residual_Block(nn.Module): 9 | def __init__(self, in_ch, out_ch, stride): 10 | super(Residual_Block, self).__init__() 11 | self.channel_num = settings.channel 12 | self.convs = nn.ModuleList() 13 | self.relus = nn.ModuleList() 14 | self.convert = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, stride, 1), 16 | nn.LeakyReLU(0.2) 17 | ) 18 | self.res = nn.Sequential( 19 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 20 | nn.LeakyReLU(0.2), 21 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 22 | ) 23 | 24 | def forward(self, x): 25 | convert = self.convert(x) 26 | out = convert + self.res(convert) 27 | return out 28 | 29 | class Scale_attention(nn.Module): 30 | def __init__(self): 31 | super(Scale_attention, self).__init__() 32 | self.scale_attention = nn.ModuleList() 33 | self.res_list = nn.ModuleList() 34 | self.channel = 
settings.channel 35 | if settings.scale_attention is True: 36 | for i in range(settings.num_scale_attention): 37 | self.scale_attention.append( 38 | nn.Sequential( 39 | nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)), 40 | nn.Conv2d(self.channel, self.channel, 1, 1), 41 | nn.Sigmoid() 42 | ) 43 | ) 44 | for i in range(settings.num_scale_attention): 45 | self.res_list.append( 46 | Residual_Block(self.channel, self.channel, 2) 47 | ) 48 | 49 | self.conv11 = nn.Sequential( 50 | nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1), 51 | nn.LeakyReLU(0.2) 52 | ) 53 | 54 | def forward(self, x): 55 | b, c, h, w = x.size() 56 | temp = x 57 | out = [] 58 | out.append(temp) 59 | if settings.scale_attention is True: 60 | for i in range(settings.num_scale_attention): 61 | temp = self.res_list[i](temp) 62 | b0,c0,h0,w0 = temp.size() 63 | temp = temp * F.upsample(self.scale_attention[i](x), [h0, w0]) 64 | up = temp 65 | out.append(F.upsample(up, [h, w])) 66 | fusion = self.conv11(torch.cat(out, dim=1)) 67 | 68 | else: 69 | for i in range(settings.num_scale_attention): 70 | temp = self.res_list[i](temp) 71 | up = temp 72 | out.append(F.upsample(up, [h, w])) 73 | fusion = self.conv11(torch.cat(out, dim=1)) 74 | return fusion + x 75 | 76 | class DenseConnection(nn.Module): 77 | def __init__(self, unit, unit_num): 78 | super(DenseConnection, self).__init__() 79 | self.unit_num = unit_num 80 | self.channel = settings.channel 81 | self.units = nn.ModuleList() 82 | self.conv1x1 = nn.ModuleList() 83 | for i in range(self.unit_num): 84 | self.units.append(unit()) 85 | self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))) 86 | 87 | def forward(self, x): 88 | cat = [] 89 | cat.append(x) 90 | out = x 91 | for i in range(self.unit_num): 92 | tmp = self.units[i](out) 93 | cat.append(tmp) 94 | out = self.conv1x1[i](torch.cat(cat,dim=1)) 95 | return out 96 | 97 | 98 | class ODE_DerainNet(nn.Module): 99 | def __init__(self): 
100 | super(ODE_DerainNet, self).__init__() 101 | self.channel = settings.channel 102 | self.unit_num = settings.unit_num 103 | self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)) 104 | self.derain_net = DenseConnection(Scale_attention, self.unit_num) 105 | self.exitBlock = nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2)) 106 | 107 | 108 | def forward(self, x): 109 | image_feature = self.enterBlock(x) 110 | rain_feature = self.derain_net(image_feature) 111 | rain = self.exitBlock(rain_feature) 112 | derain = x - rain 113 | return derain 114 | -------------------------------------------------------------------------------- /code/ablation/r1/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/ablation/r2/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/base/rain100H/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | 79 | def inf_batch(self, name, batch): 80 | O, B = batch['O'].cuda(), batch['B'].cuda() 81 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 82 | 83 | with torch.no_grad(): 84 | derain = self.net(O) 85 | 86 | l1_loss = self.l1(derain, B) 87 | ssim = self.ssim(derain, B) 88 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 89 | losses = { 'L1 loss' : l1_loss } 90 | ssimes = { 'ssim' : ssim } 91 | losses.update(ssimes) 92 | 93 | return losses, psnr 94 | 95 | 96 | def run_test(ckp_name): 97 | sess = Session() 98 | sess.net.eval() 99 | sess.load_checkpoints(ckp_name) 100 | dt = sess.get_dataloader('test') 101 | psnr_all = 0 102 | all_num = 0 103 | all_losses = {} 104 | for i, batch in enumerate(dt): 105 | losses,psnr= sess.inf_batch('test', batch) 106 | psnr_all=psnr_all+psnr 107 | batch_size = batch['O'].size(0) 108 | all_num += batch_size 109 | for key, val in losses.items(): 110 | if i == 0: 111 | all_losses[key] = 0. 
112 | all_losses[key] += val * batch_size 113 | logger.info('batch %d mse %s: %f' % (i, key, val)) 114 | 115 | for key, val in all_losses.items(): 116 | logger.info('total mse %s: %f' % (key, val / all_num)) 117 | #psnr=sum(psnr_all) 118 | #print(psnr) 119 | print('psnr_ll:%8f'%(psnr_all/all_num)) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-m', '--model', default='latest_net') 124 | 125 | args = parser.parse_args(sys.argv[1:]) 126 | run_test(args.model) 127 | 128 | -------------------------------------------------------------------------------- /code/base/rain100L/config/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import MSELoss 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | from tensorboardX import SummaryWriter 15 | 16 | import settings 17 | from dataset import TestDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | def PSNR(img1, img2): 32 | b,_,_,_=img1.shape 33 | #mse=0 34 | #for i in range(b): 35 | img1=np.clip(img1,0,255) 36 | img2=np.clip(img2,0,255) 37 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 38 | if mse == 0: 39 | return 100 40 | #mse=mse/b 41 | PIXEL_MAX = 1 42 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 43 | class Session: 44 | def __init__(self): 45 | self.log_dir = settings.log_dir 46 | self.model_dir = settings.model_dir 47 | ensure_dir(settings.log_dir) 48 | ensure_dir(settings.model_dir) 49 | logger.info('set log dir as %s' % settings.log_dir) 50 | logger.info('set model dir as %s' % settings.model_dir) 51 | if len(settings.device_id) >1: 52 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 53 | else: 54 | torch.cuda.set_device(settings.device_id[0]) 55 | self.net = ODE_DerainNet().cuda() 56 | self.l2 = MSELoss().cuda() 57 | self.l1 = nn.L1Loss().cuda() 58 | self.ssim = SSIM().cuda() 59 | self.dataloaders = {} 60 | 61 | def get_dataloader(self, dataset_name): 62 | dataset = TestDataset(dataset_name) 63 | if not dataset_name in self.dataloaders: 64 | self.dataloaders[dataset_name] = \ 65 | DataLoader(dataset, batch_size=1, 66 | shuffle=False, num_workers=1, drop_last=False) 67 | return self.dataloaders[dataset_name] 68 | 69 | def load_checkpoints(self, name): 70 | ckp_path = os.path.join(self.model_dir, name) 71 | try: 72 | obj = torch.load(ckp_path) 73 | logger.info('Load checkpoint %s' % ckp_path) 74 | except FileNotFoundError: 75 | logger.info('No checkpoint %s!!' 
% ckp_path) 76 | return 77 | self.net.load_state_dict(obj['net']) 78 | def loss_vgg(self,input,groundtruth): 79 | vgg_gt = self.vgg.forward(groundtruth) 80 | eval = self.vgg.forward(input) 81 | loss_vgg = [self.l1(eval[m], vgg_gt[m]) for m in range(len(vgg_gt))] 82 | loss = sum(loss_vgg) 83 | return loss 84 | 85 | def inf_batch(self, name, batch): 86 | O, B = batch['O'].cuda(), batch['B'].cuda() 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | 89 | with torch.no_grad(): 90 | derain = self.net(O) 91 | 92 | l1_loss = self.l1(derain, B) 93 | ssim = self.ssim(derain, B) 94 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 95 | losses = { 'L1 loss' : l1_loss } 96 | ssimes = { 'ssim' : ssim } 97 | losses.update(ssimes) 98 | 99 | return losses, psnr 100 | 101 | 102 | def run_test(ckp_name): 103 | sess = Session() 104 | sess.net.eval() 105 | sess.load_checkpoints(ckp_name) 106 | dt = sess.get_dataloader('test') 107 | psnr_all = 0 108 | all_num = 0 109 | all_losses = {} 110 | for i, batch in enumerate(dt): 111 | losses,psnr= sess.inf_batch('test', batch) 112 | psnr_all=psnr_all+psnr 113 | batch_size = batch['O'].size(0) 114 | all_num += batch_size 115 | for key, val in losses.items(): 116 | if i == 0: 117 | all_losses[key] = 0. 
118 | all_losses[key] += val * batch_size 119 | logger.info('batch %d mse %s: %f' % (i, key, val)) 120 | 121 | for key, val in all_losses.items(): 122 | logger.info('total mse %s: %f' % (key, val / all_num)) 123 | #psnr=sum(psnr_all) 124 | #print(psnr) 125 | print('psnr_ll:%8f'%(psnr_all/all_num)) 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('-m', '--model', default='latest_net') 130 | 131 | args = parser.parse_args(sys.argv[1:]) 132 | run_test(args.model) 133 | 134 | -------------------------------------------------------------------------------- /code/base/rain100L/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-w' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/ablation/r1/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, 
dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' % ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>423: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, 
file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/ablation/r2/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/base/rain100H/config/show.py: -------------------------------------------------------------------------------- 1 | import os 
2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def 
get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' % ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, 
dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/show.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) 
** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' 
% ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/show.py: -------------------------------------------------------------------------------- 1 | import os 
2 | import sys 3 | import cv2 4 | import argparse 5 | import math 6 | import numpy as np 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import DataParallel 12 | from torch.optim import Adam 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | 16 | import settings 17 | from dataset import ShowDataset 18 | from model import ODE_DerainNet 19 | from cal_ssim import SSIM 20 | 21 | os.environ['CUDA_VISIBLE_DEVICES'] = settings.device_id 22 | logger = settings.logger 23 | torch.cuda.manual_seed_all(66) 24 | torch.manual_seed(66) 25 | #torch.cuda.set_device(settings.device_id) 26 | 27 | 28 | def ensure_dir(dir_path): 29 | if not os.path.isdir(dir_path): 30 | os.makedirs(dir_path) 31 | 32 | def PSNR(img1, img2): 33 | b,_,_,_=img1.shape 34 | #mse=0 35 | #for i in range(b): 36 | img1=np.clip(img1,0,255) 37 | img2=np.clip(img2,0,255) 38 | mse = np.mean((img1/ 255. - img2/ 255.) ** 2)#+mse 39 | if mse == 0: 40 | return 100 41 | #mse=mse/b 42 | PIXEL_MAX = 1 43 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 44 | 45 | class Session: 46 | def __init__(self): 47 | self.show_dir = settings.show_dir 48 | self.model_dir = settings.model_dir 49 | ensure_dir(settings.show_dir) 50 | ensure_dir(settings.model_dir) 51 | logger.info('set show dir as %s' % settings.show_dir) 52 | logger.info('set model dir as %s' % settings.model_dir) 53 | 54 | if len(settings.device_id) >1: 55 | self.net = nn.DataParallel(ODE_DerainNet()).cuda() 56 | #self.l2 = nn.DataParallel(MSELoss(),settings.device_id) 57 | #self.l1 = nn.DataParallel(nn.L1Loss(),settings.device_id) 58 | #self.ssim = nn.DataParallel(SSIM(),settings.device_id) 59 | #self.vgg = nn.DataParallel(VGG(),settings.device_id) 60 | else: 61 | torch.cuda.set_device(settings.device_id[0]) 62 | self.net = ODE_DerainNet().cuda() 63 | self.ssim = SSIM().cuda() 64 | self.dataloaders = {} 65 | self.ssim=SSIM().cuda() 66 | self.a=0 67 | self.t=0 68 | def 
get_dataloader(self, dataset_name): 69 | dataset = ShowDataset(dataset_name) 70 | self.dataloaders[dataset_name] = \ 71 | DataLoader(dataset, batch_size=1, 72 | shuffle=False, num_workers=1) 73 | return self.dataloaders[dataset_name] 74 | 75 | def load_checkpoints(self, name): 76 | ckp_path = os.path.join(self.model_dir, name) 77 | try: 78 | obj = torch.load(ckp_path) 79 | logger.info('Load checkpoint %s' % ckp_path) 80 | except FileNotFoundError: 81 | logger.info('No checkpoint %s!!' % ckp_path) 82 | return 83 | self.net.load_state_dict(obj['net']) 84 | def inf_batch(self, name, batch,i): 85 | O, B, file_name= batch['O'].cuda(), batch['B'].cuda(), batch['file_name'] 86 | file_name = str(file_name[0]) 87 | O, B = Variable(O, requires_grad=False), Variable(B, requires_grad=False) 88 | with torch.no_grad(): 89 | import time 90 | t0=time.time() 91 | derain = self.net(O) 92 | t1 = time.time() 93 | comput_time=t1-t0 94 | print(comput_time) 95 | ssim = self.ssim(derain, B).data.cpu().numpy() 96 | psnr = PSNR(derain.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255) 97 | print('psnr:%4f-------------ssim:%4f'%(psnr, ssim)) 98 | return derain, psnr, ssim, file_name 99 | 100 | def save_image(self, No, imgs, name, psnr, ssim, file_name): 101 | for i, img in enumerate(imgs): 102 | img = (img.cpu().data * 255).numpy() 103 | img = np.clip(img, 0, 255) 104 | img = np.transpose(img, (1, 2, 0)) 105 | h, w, c = img.shape 106 | 107 | img_file = os.path.join(self.show_dir, '%s.png' % (file_name)) 108 | print(img_file) 109 | cv2.imwrite(img_file, img) 110 | 111 | 112 | def run_show(ckp_name): 113 | sess = Session() 114 | sess.load_checkpoints(ckp_name) 115 | sess.net.eval() 116 | dataset = 'test' 117 | if settings.pic_is_pair is False: 118 | dataset = 'train-small' 119 | dt = sess.get_dataloader(dataset) 120 | 121 | for i, batch in enumerate(dt): 122 | logger.info(i) 123 | if i>-1: 124 | imgs,psnr,ssim, file_name= sess.inf_batch('test', batch,i) 125 | sess.save_image(i, imgs, 
dataset, psnr, ssim, file_name) 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-m', '--model', default='latest_net') 132 | 133 | args = parser.parse_args(sys.argv[1:]) 134 | 135 | run_show(args.model) 136 | 137 | -------------------------------------------------------------------------------- /code/ablation/r1/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - 
p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | 
file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/ablation/r2/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 
| 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = 
settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, 
v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/base/rain100H/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = 
img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | 
img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/base/rain100L/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = 
self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 
88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | 
print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/diff_loss/mae/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 | img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = 
img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, 
file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/diff_loss/mse/config/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | 7 | import settings 8 | 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.rand_state = RandomState(66) 14 | self.root_dir = os.path.join(settings.data_dir, name) 15 | self.mat_files = os.listdir(self.root_dir) 16 | self.patch_size = settings.patch_size 17 | self.file_num = len(self.mat_files) 18 | 19 | def __len__(self): 20 | return self.file_num * 100 21 | 22 | def __getitem__(self, idx): 23 | file_name = self.mat_files[idx % self.file_num] 24 
| img_file = os.path.join(self.root_dir, file_name) 25 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 26 | 27 | if settings.aug_data: 28 | O, B = self.crop(img_pair, aug=True) 29 | O, B = self.flip(O, B) 30 | O, B = self.rotate(O, B) 31 | else: 32 | O, B = self.crop(img_pair, aug=False) 33 | 34 | O = np.transpose(O, (2, 0, 1)) 35 | B = np.transpose(B, (2, 0, 1)) 36 | sample = {'O': O, 'B': B} 37 | 38 | return sample 39 | 40 | def crop(self, img_pair, aug): 41 | patch_size = self.patch_size 42 | h, ww, c = img_pair.shape 43 | w = int(ww / 2) 44 | 45 | if aug: 46 | mini = - 1 / 4 * self.patch_size 47 | maxi = 1 / 4 * self.patch_size + 1 48 | p_h = patch_size + self.rand_state.randint(mini, maxi) 49 | p_w = patch_size + self.rand_state.randint(mini, maxi) 50 | else: 51 | p_h, p_w = patch_size, patch_size 52 | 53 | r = self.rand_state.randint(0, h - p_h) 54 | c = self.rand_state.randint(0, w - p_w) 55 | O = img_pair[r: r+p_h, c+w: c+p_w+w] 56 | B = img_pair[r: r+p_h, c: c+p_w] 57 | 58 | if aug: 59 | O = cv2.resize(O, (patch_size, patch_size)) 60 | B = cv2.resize(B, (patch_size, patch_size)) 61 | 62 | return O, B 63 | 64 | def flip(self, O, B): 65 | if self.rand_state.rand() > 0.5: 66 | O = np.flip(O, axis=1) 67 | B = np.flip(B, axis=1) 68 | return O, B 69 | 70 | def rotate(self, O, B): 71 | angle = self.rand_state.randint(-30, 30) 72 | patch_size = self.patch_size 73 | center = (int(patch_size / 2), int(patch_size / 2)) 74 | M = cv2.getRotationMatrix2D(center, angle, 1) 75 | O = cv2.warpAffine(O, M, (patch_size, patch_size)) 76 | B = cv2.warpAffine(B, M, (patch_size, patch_size)) 77 | return O, B 78 | 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, name): 82 | super().__init__() 83 | self.rand_state = RandomState(66) 84 | self.root_dir = os.path.join(settings.data_dir, name) 85 | self.mat_files = os.listdir(self.root_dir) 86 | self.patch_size = settings.patch_size 87 | self.file_num = len(self.mat_files) 88 | 89 | def __len__(self): 90 | 
return self.file_num 91 | 92 | def __getitem__(self, idx): 93 | file_name = self.mat_files[idx % self.file_num] 94 | img_file = os.path.join(self.root_dir, file_name) 95 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 96 | h, ww, c = img_pair.shape 97 | w = int(ww / 2) 98 | #h_8=h%8 99 | #w_8=w%8 100 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 101 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 102 | sample = {'O': O, 'B': B} 103 | 104 | return sample 105 | 106 | 107 | class ShowDataset(Dataset): 108 | def __init__(self, name): 109 | super().__init__() 110 | self.rand_state = RandomState(66) 111 | self.root_dir = os.path.join(settings.data_dir, name) 112 | self.img_files = sorted(os.listdir(self.root_dir)) 113 | self.file_num = len(self.img_files) 114 | 115 | def __len__(self): 116 | return self.file_num 117 | 118 | def __getitem__(self, idx): 119 | file_name = self.img_files[idx % self.file_num] 120 | img_file = os.path.join(self.root_dir, file_name) 121 | print(img_file) 122 | img_pair = cv2.imread(img_file).astype(np.float32) / 255 123 | 124 | h, ww, c = img_pair.shape 125 | w = int(ww / 2) 126 | 127 | #h_8 = h % 8 128 | #w_8 = w % 8 129 | if settings.pic_is_pair: 130 | O = np.transpose(img_pair[:, w:], (2, 0, 1)) 131 | B = np.transpose(img_pair[:, :w], (2, 0, 1)) 132 | else: 133 | O = np.transpose(img_pair[:, :], (2, 0, 1)) 134 | B = np.transpose(img_pair[:, :], (2, 0, 1)) 135 | sample = {'O': O, 'B': B,'file_name':file_name[:-4]} 136 | 137 | return sample 138 | 139 | 140 | if __name__ == '__main__': 141 | dt = TrainValDataset('val') 142 | print('TrainValDataset') 143 | for i in range(10): 144 | smp = dt[i] 145 | for k, v in smp.items(): 146 | print(k, v.shape, v.dtype, v.mean()) 147 | 148 | print() 149 | dt = TestDataset('test') 150 | print('TestDataset') 151 | for i in range(10): 152 | smp = dt[i] 153 | for k, v in smp.items(): 154 | print(k, v.shape, v.dtype, v.mean()) 155 | 156 | print() 157 | print('ShowDataset') 158 | dt = ShowDataset('test') 
159 | for i in range(10): 160 | smp = dt[i] 161 | for k, v in smp.items(): 162 | print(k, v.shape, v.dtype, v.mean()) 163 | -------------------------------------------------------------------------------- /code/ablation/r2/config/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import settings 5 | from itertools import combinations,product 6 | import math 7 | 8 | class Residual_Block(nn.Module): 9 | def __init__(self, in_ch, out_ch, stride): 10 | super(Residual_Block, self).__init__() 11 | self.channel_num = settings.channel 12 | self.convs = nn.ModuleList() 13 | self.relus = nn.ModuleList() 14 | self.convert = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, stride, 1), 16 | nn.LeakyReLU(0.2) 17 | ) 18 | self.res = nn.Sequential( 19 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 20 | nn.LeakyReLU(0.2), 21 | nn.Conv2d(out_ch, out_ch, 3, 1, 1), 22 | ) 23 | 24 | def forward(self, x): 25 | convert = self.convert(x) 26 | out = convert + self.res(convert) 27 | return out 28 | 29 | class SCConv(nn.Module): 30 | def __init__(self, planes, pooling_r): 31 | super(SCConv, self).__init__() 32 | self.k2 = nn.Sequential( 33 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), 34 | nn.Conv2d(planes, planes, 3, 1, 1), 35 | ) 36 | self.k3 = nn.Sequential( 37 | nn.Conv2d(planes, planes, 3, 1, 1), 38 | ) 39 | self.k4 = nn.Sequential( 40 | nn.Conv2d(planes, planes, 3, 1, 1), 41 | nn.LeakyReLU(0.2), 42 | ) 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2) 48 | out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2) 49 | out = self.k4(out) # k4 50 | 51 | return out 52 | 53 | class SCBottleneck(nn.Module): 54 | #expansion = 4 55 | pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv. 
56 | 57 | def __init__(self, in_planes, planes): 58 | super(SCBottleneck, self).__init__() 59 | planes = int(planes / 2) 60 | 61 | self.conv1_a = nn.Conv2d(in_planes, planes, 1, 1) 62 | self.k1 = nn.Sequential( 63 | nn.Conv2d(planes, planes, 3, 1, 1), 64 | nn.LeakyReLU(0.2), 65 | ) 66 | 67 | self.conv1_b = nn.Conv2d(in_planes, planes, 1, 1) 68 | 69 | self.scconv = SCConv(planes, self.pooling_r) 70 | 71 | self.conv3 = nn.Conv2d(planes * 2, planes * 2, 1, 1) 72 | self.relu = nn.LeakyReLU(0.2) 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out_a= self.conv1_a(x) 78 | out_a = self.relu(out_a) 79 | 80 | out_a = self.k1(out_a) 81 | 82 | out_b = self.conv1_b(x) 83 | out_b = self.relu(out_b) 84 | 85 | out_b = self.scconv(out_b) 86 | 87 | out = self.conv3(torch.cat([out_a, out_b], dim=1)) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Scale_attention(nn.Module): 95 | def __init__(self): 96 | super(Scale_attention, self).__init__() 97 | self.scale_attention = nn.ModuleList() 98 | self.res_list = nn.ModuleList() 99 | self.channel = settings.channel 100 | if settings.scale_attention is True: 101 | for i in range(settings.num_scale_attention): 102 | self.scale_attention.append( 103 | nn.Sequential( 104 | nn.MaxPool2d(2 ** (i + 1), 2 ** (i + 1)), 105 | nn.Conv2d(self.channel, self.channel, 1, 1), 106 | nn.Sigmoid() 107 | ) 108 | ) 109 | for i in range(settings.num_scale_attention): 110 | self.res_list.append( 111 | Residual_Block(self.channel, self.channel, 2) 112 | ) 113 | 114 | self.conv11 = nn.Sequential( 115 | nn.Conv2d((settings.num_scale_attention + 1) * self.channel, self.channel, 1, 1), 116 | nn.LeakyReLU(0.2) 117 | ) 118 | self.scn = SCBottleneck(self.channel, self.channel) 119 | 120 | def forward(self, x): 121 | b, c, h, w = x.size() 122 | temp = x 123 | out = [] 124 | out.append(temp) 125 | if settings.scale_attention is True: 126 | for i in range(settings.num_scale_attention): 127 | temp = self.res_list[i](temp) 
128 | b0,c0,h0,w0 = temp.size() 129 | temp = temp * F.upsample(self.scale_attention[i](x), [h0, w0]) 130 | up = temp 131 | out.append(F.upsample(up, [h, w])) 132 | fusion = self.conv11(torch.cat(out, dim=1)) 133 | 134 | else: 135 | for i in range(settings.num_scale_attention): 136 | temp = self.res_list[i](temp) 137 | up = temp 138 | out.append(F.upsample(up, [h, w])) 139 | fusion = self.conv11(torch.cat(out, dim=1)) 140 | out = self.scn(fusion + x) 141 | return out 142 | 143 | class DenseConnection(nn.Module): 144 | def __init__(self, unit, unit_num): 145 | super(DenseConnection, self).__init__() 146 | self.unit_num = unit_num 147 | self.channel = settings.channel 148 | self.units = nn.ModuleList() 149 | self.conv1x1 = nn.ModuleList() 150 | for i in range(self.unit_num): 151 | self.units.append(unit()) 152 | self.conv1x1.append(nn.Sequential(nn.Conv2d((i+2)*self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))) 153 | 154 | def forward(self, x): 155 | cat = [] 156 | cat.append(x) 157 | out = x 158 | for i in range(self.unit_num): 159 | tmp = self.units[i](out) 160 | cat.append(tmp) 161 | out = self.conv1x1[i](torch.cat(cat,dim=1)) 162 | return out 163 | 164 | 165 | class ODE_DerainNet(nn.Module): 166 | def __init__(self): 167 | super(ODE_DerainNet, self).__init__() 168 | self.channel = settings.channel 169 | self.unit_num = settings.unit_num 170 | self.enterBlock = nn.Sequential(nn.Conv2d(3, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)) 171 | self.derain_net = DenseConnection(Scale_attention, self.unit_num) 172 | self.exitBlock = nn.Sequential(nn.Conv2d(self.channel, 3, 3, 1, 1), nn.LeakyReLU(0.2)) 173 | 174 | 175 | def forward(self, x): 176 | image_feature = self.enterBlock(x) 177 | rain_feature = self.derain_net(image_feature) 178 | rain = self.exitBlock(rain_feature) 179 | derain = x - rain 180 | return derain 181 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # JDNet: Joint Self-Attention and Scale-Aggregation for Self-Calibrated Deraining Network 2 | 3 | [Cong Wang](https://supercong94.wixsite.com/supercong94)\*, [Yutong Wu](https://ohraincu.github.io/)\*, [Zhixun Su](http://faculty.dlut.edu.cn/ZhixunSu/zh_CN/index/759047/list/index.htm) †, Junyang Chen 4 | 5 | <\* Both authors contributed equally to this research. † Corresponding author.> 6 | 7 | This work has been accepted by ACM MM 2020. [\[arXiv\]](https://arxiv.org/abs/2008.02763) 8 | 9 |
11 |
12 | Fig. 1: An example from real-world datasets.
13 |
20 |
21 | Fig. 2: The architecture of the Joint Network for Deraining (JDNet).
22 | ![]() |
104 | ![]() |
105 |
| Fig. 3: Paired image. | 108 | Fig. 4: Unpaired image. | 109 |