├── .gitignore ├── .python-version ├── convert.py ├── create_lmdb.py ├── install_linux.sh ├── install_windows.ps1 ├── license.txt ├── licenses └── readme.md ├── neosr ├── __init__.py ├── archs │ ├── __init__.py │ ├── arch_util.py │ ├── asid_arch.py │ ├── atd_arch.py │ ├── catanet_arch.py │ ├── cfsr_arch.py │ ├── compact_arch.py │ ├── craft_arch.py │ ├── cugan_arch.py │ ├── dat_arch.py │ ├── dct_arch.py │ ├── dctlsa_arch.py │ ├── ditn_arch.py │ ├── drct_arch.py │ ├── dunet_arch.py │ ├── ea2fpn_arch.py │ ├── eimn_arch.py │ ├── esc_arch.py │ ├── esrgan_arch.py │ ├── flexnet_arch.py │ ├── grformer_arch.py │ ├── hasn_arch.py │ ├── hat_arch.py │ ├── hitsrf_arch.py │ ├── hma_arch.py │ ├── krgn_arch.py │ ├── lmlt_arch.py │ ├── man_arch.py │ ├── metagan_arch.py │ ├── moesr_arch.py │ ├── mosrv2_arch.py │ ├── msdan_arch.py │ ├── ninasr_arch.py │ ├── omnisr_arch.py │ ├── patchgan_arch.py │ ├── plainusr_arch.py │ ├── plksr_arch.py │ ├── rcan_arch.py │ ├── realplksr_arch.py │ ├── rgt_arch.py │ ├── safmn_arch.py │ ├── sebica_arch.py │ ├── span_arch.py │ ├── spanplus_arch.py │ ├── srformer_arch.py │ ├── swinir_arch.py │ ├── unet_arch.py │ └── vgg_arch.py ├── data │ ├── __init__.py │ ├── augmentations.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── file_client.py │ ├── otf_dataset.py │ ├── paired_dataset.py │ ├── prefetch_dataloader.py │ ├── single_dataset.py │ └── transforms.py ├── losses │ ├── .gitattributes │ ├── __init__.py │ ├── basic_loss.py │ ├── consistency_loss.py │ ├── dists_loss.py │ ├── fdl_loss.py │ ├── ff_loss.py │ ├── gan_loss.py │ ├── kl_loss.py │ ├── ldl_loss.py │ ├── msswd_loss.py │ ├── ncc_loss.py │ ├── ssim_loss.py │ ├── vgg_perceptual_loss.py │ └── wavelet_guided.py ├── metrics │ ├── .gitattributes │ ├── __init__.py │ ├── calculate.py │ ├── metric_util.py │ └── topiq.py ├── models │ ├── __init__.py │ ├── base.py │ ├── image.py │ └── otf.py ├── optimizers │ ├── __init__.py │ ├── adamw_sf.py │ ├── adamw_win.py │ ├── adan.py │ ├── 
adan_sf.py │ ├── fsam.py │ └── soap_sf.py └── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── misc.py │ ├── options.py │ ├── registry.py │ └── rng.py ├── options ├── test_asid.toml ├── test_atd.toml ├── test_catanet.toml ├── test_cfsr.toml ├── test_compact.toml ├── test_craft.toml ├── test_cugan.toml ├── test_dat.toml ├── test_dct.toml ├── test_dctlsa.toml ├── test_ditn.toml ├── test_drct.toml ├── test_eimn.toml ├── test_esc.toml ├── test_esrgan.toml ├── test_flexnet.toml ├── test_grformer.toml ├── test_hasn.toml ├── test_hat.toml ├── test_hitsrf.toml ├── test_hma.toml ├── test_krgn.toml ├── test_lmlt.toml ├── test_man.toml ├── test_moesr.toml ├── test_mosrv2.toml ├── test_msdan.toml ├── test_ninasr.toml ├── test_omnisr.toml ├── test_plainusr.toml ├── test_plksr.toml ├── test_rcan.toml ├── test_realplksr.toml ├── test_rgt.toml ├── test_safmn.toml ├── test_scnet.toml ├── test_sebica.toml ├── test_span.toml ├── test_spanplus.toml ├── test_srformer.toml ├── test_swinir.toml ├── train_asid.toml ├── train_asid_otf.toml ├── train_atd.toml ├── train_atd_otf.toml ├── train_catanet.toml ├── train_catanet_otf.toml ├── train_cfsr.toml ├── train_cfsr_otf.toml ├── train_compact.toml ├── train_compact_otf.toml ├── train_craft.toml ├── train_craft_otf.toml ├── train_cugan.toml ├── train_cugan_otf.toml ├── train_dat.toml ├── train_dat_otf.toml ├── train_dct.toml ├── train_dct_otf.toml ├── train_dctlsa.toml ├── train_dctlsa_otf.toml ├── train_ditn.toml ├── train_ditn_otf.toml ├── train_drct.toml ├── train_drct_otf.toml ├── train_eimn.toml ├── train_eimn_otf.toml ├── train_esc.toml ├── train_esc_otf.toml ├── train_esrgan.toml ├── train_esrgan_otf.toml ├── train_flexnet.toml ├── train_flexnet_otf.toml ├── train_grformer.toml ├── train_grformer_otf.toml ├── train_hasn.toml ├── train_hasn_otf.toml ├── train_hat.toml ├── train_hat_otf.toml ├── train_hitsrf.toml ├── train_hitsrf_otf.toml ├── 
train_hma.toml ├── train_hma_otf.toml ├── train_krgn.toml ├── train_krgn_otf.toml ├── train_lmlt.toml ├── train_lmlt_otf.toml ├── train_man.toml ├── train_man_otf.toml ├── train_moesr.toml ├── train_moesr_otf.toml ├── train_mosrv2.toml ├── train_mosrv2_otf.toml ├── train_msdan.toml ├── train_msdan_otf.toml ├── train_ninasr.toml ├── train_ninasr_otf.toml ├── train_omnisr.toml ├── train_omnisr_otf.toml ├── train_plainusr.toml ├── train_plainusr_otf.toml ├── train_plksr.toml ├── train_plksr_otf.toml ├── train_rcan.toml ├── train_rcan_otf.toml ├── train_realplksr.toml ├── train_realplksr_otf.toml ├── train_rgt.toml ├── train_rgt_otf.toml ├── train_safmn.toml ├── train_safmn_otf.toml ├── train_sebica.toml ├── train_sebica_otf.toml ├── train_span.toml ├── train_span_otf.toml ├── train_spanplus.toml ├── train_spanplus_otf.toml ├── train_srformer.toml ├── train_srformer_otf.toml ├── train_swinir.toml └── train_swinir_otf.toml ├── pyproject.toml ├── readme.md ├── test.py ├── train.py └── uv.lock

/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | __pycache__/
3 | .ruff_cache/
4 | .mypy_cache/
5 | experiments/*
6 | options/tmp/
7 | neosr/metrics/topiq_fr_weights.pth
8 | neosr/losses/dists_weights.pth
9 | check.sh
10 |

--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |

--------------------------------------------------------------------------------
/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from neosr.utils import scandir
4 | from neosr.utils.lmdb_util import make_lmdb_from_imgs
5 |
6 |
7 | def prepare_keys(folder_path):
8 |     """Prepare image path list and keys.
9 |
10 |     Only ``*.png`` files directly inside ``folder_path`` are collected
11 |     (the folder is not scanned recursively).
12 |
13 |     Args:
14 |     ----
15 |         folder_path (str): Folder path.
16 |
17 |     Returns:
18 |     -------
19 |         list[str]: Sorted image path list.
20 |         list[str]: Key list, i.e. each image path without its ".png"
21 |             extension, in the same order as the image path list.
22 |
23 |     """
24 |     print("Reading image path list ...")
25 |     img_path_list = sorted(scandir(folder_path, suffix="png", recursive=False))
26 |     # strip only the trailing extension: split(".png")[0] would truncate
27 |     # names such as "a.png_crop.png" to "a" and could produce duplicate keys
28 |     keys = [img_path.removesuffix(".png") for img_path in img_path_list]
29 |
30 |     return img_path_list, keys
31 |
32 |
33 | def create_lmdb(folder_path=None, lmdb_path=None):
34 |     """Create an lmdb database from the png images of a folder.
35 |
36 |     Args:
37 |     ----
38 |         folder_path (str | None): Folder containing the png images.
39 |             Defaults to the ``--input`` command-line argument.
40 |         lmdb_path (str | None): Path of the lmdb to create.
41 |             Defaults to the ``--output`` command-line argument.
42 |
43 |     """
44 |     # fall back to the parsed CLI arguments for backward compatibility with
45 |     # callers that invoke create_lmdb() without arguments
46 |     if folder_path is None:
47 |         folder_path = args.input
48 |     if lmdb_path is None:
49 |         lmdb_path = args.output
50 |     img_path_list, keys = prepare_keys(folder_path)
51 |     make_lmdb_from_imgs(
52 |         folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True
53 |     )
54 |
55 |
56 | if __name__ == "__main__":
57 |     parser = argparse.ArgumentParser()
58 |
59 |     parser.add_argument("--input", type=str, required=True, help="Input Path")
60 |     parser.add_argument("--output", type=str, required=True, help="Output Path")
61 |     args = parser.parse_args()
62 |     create_lmdb(args.input, args.output)
63 |

--------------------------------------------------------------------------------
/install_linux.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # exit on error
4 | set -e
5 |
6 | echo "--- starting neosr installation..."
7 |
8 | # function to prompt for package installation
9 | prompt_install() {
10 |     package=$1
11 |     echo "--- the package '$package' is required but not installed."
12 |     printf "--- would you like to install it? [y/N] "
13 |     read -r answer
14 |     case "$answer" in
15 |         [Yy]*)
16 |             return 0
17 |             ;;
18 |         *)
19 |             return 1
20 |             ;;
21 |     esac
22 | }
23 |
24 | # print an error message in bold red followed by a newline; the message is
25 | # passed as a printf argument so '%' characters in it are printed literally
26 | print_error() {
27 |     printf '\033[1;31m%s\033[0m\n' "$1"
28 | }
29 |
30 | # check if git is installed
31 | if ! command -v git >/dev/null 2>&1; then
32 |     if prompt_install "git"; then
33 |         if command -v apt-get >/dev/null 2>&1; then
34 |             sudo apt-get update && sudo apt-get install -y git
35 |         elif command -v dnf >/dev/null 2>&1; then
36 |             sudo dnf install -y git
37 |         elif command -v pacman >/dev/null 2>&1; then
38 |             sudo pacman -Sy git --noconfirm
39 |         else
40 |             print_error "--- error: could not install git, please install git manually."
41 |             exit 1
42 |         fi
43 |     else
44 |         print_error "--- git is required for installation, exiting."
45 |         exit 1
46 |     fi
47 | fi
48 |
49 | # create and move to installation directory
50 | INSTALL_DIR="$PWD/neosr"
51 |
52 | # Handle existing installation
53 | if [ -d "$INSTALL_DIR" ]; then
54 |     if [ -d "$INSTALL_DIR/.git" ]; then
55 |         cd "$INSTALL_DIR"
56 |         git pull --autostash
57 |     else
58 |         print_error "--- directory $INSTALL_DIR exists but is not a git repository."
59 |         print_error "--- please remove or rename it and run the script again."
60 |         exit 1
61 |     fi
62 | else
63 |     # clone output is silenced, so report failures explicitly instead of
64 |     # letting 'set -e' end the script without any message
65 |     if ! git clone https://github.com/neosr-project/neosr "$INSTALL_DIR" >/dev/null 2>&1; then
66 |         print_error "--- error: could not clone the neosr repository."
67 |         exit 1
68 |     fi
69 |     cd "$INSTALL_DIR"
70 | fi
71 |
72 | # install uv
73 | UV_INSTALLER="https://astral.sh/uv/install.sh"
74 | if command -v curl >/dev/null 2>&1; then
75 |     curl -LsSf "$UV_INSTALLER" | sh >/dev/null 2>&1
76 | elif command -v wget >/dev/null 2>&1; then
77 |     wget -qO- "$UV_INSTALLER" | sh >/dev/null 2>&1
78 | else
79 |     if prompt_install "curl"; then
80 |         if command -v apt-get >/dev/null 2>&1; then
81 |             sudo apt-get update && sudo apt-get install -y curl
82 |         elif command -v dnf >/dev/null 2>&1; then
83 |             sudo dnf install -y curl
84 |         elif command -v pacman >/dev/null 2>&1; then
85 |             sudo pacman -Sy curl --noconfirm
86 |         else
87 |             print_error "--- error: could not install curl, please install curl or wget manually."
88 |             exit 1
89 |         fi
90 |         curl -LsSf "$UV_INSTALLER" | sh >/dev/null 2>&1
91 |     else
92 |         print_error "--- either curl or wget is required for installation, exiting."
93 |         exit 1
94 |     fi
95 | fi
96 |
97 | # the uv installer runs in a child process and can only update PATH for
98 | # shells started later, so add its default install locations (~/.local/bin,
99 | # or $XDG_BIN_HOME when set; ~/.cargo/bin for older uv releases) here too
100 | export PATH="${XDG_BIN_HOME:-$HOME/.local/bin}:$HOME/.cargo/bin:$PATH"
101 | if ! command -v uv >/dev/null 2>&1; then
102 |     print_error "--- error: uv was not found after installation, please install uv manually."
103 |     exit 1
104 | fi
105 |
106 | # maintenance steps are best-effort and must not abort the installation
107 | uv self update >/dev/null 2>&1 || true
108 | uv cache clean >/dev/null 2>&1 || true
109 | printf "\033[1m--- syncing dependencies (this might take several minutes)...\033[0m\n"
110 | uv sync
111 |
112 | # create aliases
113 | echo "--- adding aliases..."
114 | ALIAS_FILE="$HOME/.neosr_aliases"
115 | cat > "$ALIAS_FILE" << 'EOF'
116 | alias neosr-train='uv run --isolated train.py -opt'
117 | alias neosr-test='uv run --isolated test.py -opt'
118 | alias neosr-convert='uv run --isolated convert.py'
119 | alias neosr-update='git pull --autostash && uv self update && uv sync && uv cache prune'
120 | EOF
121 | # add the alias file to shell config files; use the POSIX '.' command since
122 | # ~/.profile may be read by a plain sh that has no 'source' builtin, and
123 | # match on the file name alone so lines added by older versions still count
124 | for SHELL_RC in "$HOME/.bashrc" "$HOME/.zshrc" "$HOME/.profile"; do
125 |     if [ -f "$SHELL_RC" ] && ! grep -qF "$ALIAS_FILE" "$SHELL_RC"; then
126 |         echo ". \"$ALIAS_FILE\"" >> "$SHELL_RC"
127 |     fi
128 | done
129 |
130 | printf "\033[1;32m--- neosr installation complete!\033[0m\n\n"
131 |

--------------------------------------------------------------------------------
/neosr/__init__.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | from neosr.archs import *
3 | from neosr.data import *
4 | from neosr.losses import *
5 | from neosr.metrics import *
6 | from neosr.models import *
7 | from neosr.utils import *
8 |

--------------------------------------------------------------------------------
/neosr/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from pathlib import Path
4 | from typing import Any
5 |
6 | from torch import nn
7 |
8 | from neosr.utils import get_root_logger, scandir
9 | from neosr.utils.registry import ARCH_REGISTRY
10 |
11 | __all__ = ["build_network"]
12 |
13 |
14 | def _import_arch_modules() -> None:
15 |     """Import every module ending with '_arch.py' in this folder.
16 |
17 |     Importing a module runs its ``@ARCH_REGISTRY.register()`` decorators,
18 |     which is what makes its networks available to ``build_network``.
19 |     """
20 |     arch_folder = Path(__file__).resolve().parent
21 |     for file_path in scandir(str(arch_folder)):
22 |         if file_path.endswith("_arch.py"):
23 |             importlib.import_module(f"neosr.archs.{Path(file_path).stem}")
24 |
25 |
26 | def build_network(opt: dict[str, Any]) -> nn.Module | object:
27 |     """Build a network from options.
28 |
29 |     Args:
30 |     ----
31 |         opt (dict): Configuration. It must contain:
32 |             type (str): Registered network name.
33 |         All remaining keys are passed to the network constructor.
34 |
35 |     Returns:
36 |     -------
37 |         nn.Module | object: The constructed network.
38 |
39 |     Raises:
40 |     ------
41 |         ValueError: If ``opt`` is None.
42 |
43 |     """
44 |     if opt is None:
45 |         msg = "build_network requires network options, got None."
46 |         raise ValueError(msg)
47 |     # modules are imported on every call; importlib caches them in
48 |     # sys.modules, so only the first call actually executes them
49 |     _import_arch_modules()
50 |     opt = deepcopy(opt)
51 |     network_type = opt.pop("type")
52 |     net = ARCH_REGISTRY.get(network_type)(**opt)  # type: ignore[operator]
53 |     logger = get_root_logger()
54 |     logger.info(f"Using network [{net.__class__.__name__}].")
55 |     return net
56 |

--------------------------------------------------------------------------------
/neosr/archs/compact_arch.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from neosr.archs.arch_util import net_opt
6 | from neosr.utils.registry import ARCH_REGISTRY
7 |
8 | upscale, __ = net_opt()
9 |
10 |
11 | def _make_activation(act_type, num_feat):
12 |     """Create a new activation layer of the requested type.
13 |
14 |     Args:
15 |     ----
16 |         act_type (str): One of 'relu', 'prelu', 'leakyrelu'.
17 |         num_feat (int): Channel number, used as the PReLU parameter count.
18 |
19 |     Raises:
20 |     ------
21 |         ValueError: If act_type is not one of the options above.
22 |
23 |     """
24 |     if act_type == "relu":
25 |         return nn.ReLU(inplace=True)
26 |     if act_type == "prelu":
27 |         return nn.PReLU(num_parameters=num_feat)
28 |     if act_type == "leakyrelu":
29 |         return nn.LeakyReLU(negative_slope=0.1, inplace=True)
30 |     msg = f"Unsupported act_type '{act_type}', options: 'relu', 'prelu', 'leakyrelu'."
31 |     raise ValueError(msg)
32 |
33 |
34 | @ARCH_REGISTRY.register()
35 | class compact(nn.Module):
36 |     """A compact VGG-style network structure for super-resolution.
37 |
38 |     It is a compact network structure, which performs upsampling in the last layer and no convolution is
39 |     conducted on the HR feature space.
40 |
41 |     Args:
42 |     ----
43 |         num_in_ch (int): Channel number of inputs. Default: 3.
44 |         num_out_ch (int): Channel number of outputs. Default: 3.
45 |         num_feat (int): Channel number of intermediate features. Default: 64.
46 |         num_conv (int): Number of convolution layers in the body network. Default: 16.
47 |         upscale (int): Upsampling factor. Default: the scale returned by net_opt().
48 |         act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
49 |
50 |     Raises:
51 |     ------
52 |         ValueError: If act_type is not one of the supported options.
53 |
54 |     """
55 |
56 |     def __init__(
57 |         self,
58 |         num_in_ch=3,
59 |         num_out_ch=3,
60 |         num_feat=64,
61 |         num_conv=16,
62 |         upscale=upscale,
63 |         act_type="prelu",
64 |         **kwargs,
65 |     ):
66 |         super().__init__()
67 |         self.num_in_ch = num_in_ch
68 |         self.num_out_ch = num_out_ch
69 |         self.num_feat = num_feat
70 |         self.num_conv = num_conv
71 |         self.upscale = upscale
72 |         self.act_type = act_type
73 |
74 |         # layers are appended in a fixed order (conv, act, conv, act, ...) and
75 |         # ModuleList names them by index, so checkpoint keys body.N stay stable
76 |         self.body = nn.ModuleList()
77 |         # the first conv and its activation
78 |         self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
79 |         self.body.append(_make_activation(act_type, num_feat))
80 |
81 |         # the body structure
82 |         for _ in range(num_conv):
83 |             self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
84 |             self.body.append(_make_activation(act_type, num_feat))
85 |
86 |         # the last conv
87 |         self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
88 |         # upsample
89 |         self.upsampler = nn.PixelShuffle(upscale)
90 |
91 |     def forward(self, x):
92 |         out = x
93 |         for layer in self.body:
94 |             out = layer(out)
95 |
96 |         out = self.upsampler(out)
97 |         # add the nearest upsampled image, so that the network learns the residual
98 |         base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
99 |         out += base
100 |         return out
101 |

--------------------------------------------------------------------------------
/neosr/archs/dunet_arch.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 |
from torch import Tensor, nn 3 | from torch.nn.utils import spectral_norm 4 | 5 | from neosr.archs.arch_util import DySample 6 | from neosr.utils.registry import ARCH_REGISTRY 7 | 8 | 9 | class Down(nn.Sequential): 10 | def __init__(self, dim): 11 | super().__init__(spectral_norm(nn.Conv2d(dim, dim * 2, 3, 2, 1)), nn.Mish(True)) 12 | 13 | 14 | class Up(nn.Sequential): 15 | def __init__(self, dim): 16 | super().__init__( 17 | DySample(dim, dim, 2, 4, False), 18 | spectral_norm(nn.Conv2d(dim, dim // 2, 3, 1, 1)), 19 | ) 20 | 21 | 22 | @ARCH_REGISTRY.register() 23 | class dunet(nn.Module): 24 | """Code from: 25 | https://github.com/umzi2/DUnet 26 | """ 27 | 28 | def __init__(self, in_ch: int = 3, dim: int = 64): 29 | super().__init__() 30 | self.in_to_dim = nn.Conv2d(in_ch, dim, 3, 1, 1) 31 | # encode x 32 | self.e_x1 = Down(dim) 33 | self.e_x2 = Down(dim * 2) 34 | self.e_x3 = Down(dim * 4) 35 | # up 36 | self.up1 = Up(dim * 8) 37 | self.up2 = Up(dim * 4) 38 | self.up3 = Up(dim * 2) 39 | # end conv 40 | self.end_conv = nn.Sequential( 41 | spectral_norm(nn.Conv2d(dim, dim, 3, 1, 1, bias=False)), 42 | nn.Mish(True), 43 | spectral_norm(nn.Conv2d(dim, dim, 3, 1, 1, bias=False)), 44 | nn.Mish(True), 45 | nn.Conv2d(dim, 1, 3, 1, 1), 46 | ) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | x0 = self.in_to_dim(x) 50 | x1 = self.e_x1(x0) 51 | x2 = self.e_x2(x1) 52 | x3 = self.e_x3(x2) 53 | x = self.up1(x3) + x2 54 | x = self.up2(x) + x1 55 | x = self.up3(x) + x0 56 | return self.end_conv(x) 57 | -------------------------------------------------------------------------------- /neosr/archs/unet_arch.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import spectral_norm 5 | 6 | from neosr.utils.registry import ARCH_REGISTRY 7 | 8 | 9 | @ARCH_REGISTRY.register() 10 | class unet(nn.Module): 11 | """Defines a U-Net discriminator with 
spectral normalization (SN). 12 | 13 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 14 | 15 | Arg: 16 | num_in_ch (int): Channel number of inputs. Default: 3. 17 | num_feat (int): Channel number of base intermediate features. Default: 64. 18 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 19 | """ 20 | 21 | def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True) -> None: 22 | super().__init__() 23 | self.skip_connection = skip_connection 24 | norm = spectral_norm 25 | # the first convolution 26 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 27 | # downsample 28 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 29 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 30 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 31 | # upsample 32 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 33 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 34 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 35 | # extra convolutions 36 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 37 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 38 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 39 | 40 | def forward(self, x): 41 | # downsample 42 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 43 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 44 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 45 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 46 | 47 | # upsample 48 | x3 = F.interpolate(x3, scale_factor=2, mode="bilinear", align_corners=False) 49 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 50 | 51 | if self.skip_connection: 52 | x4 = x4 
+ x2 53 | x4 = F.interpolate(x4, scale_factor=2, mode="bilinear", align_corners=False) 54 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 55 | 56 | if self.skip_connection: 57 | x5 = x5 + x1 58 | x5 = F.interpolate(x5, scale_factor=2, mode="bilinear", align_corners=False) 59 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 60 | 61 | if self.skip_connection: 62 | x6 = x6 + x0 63 | 64 | # extra convolutions 65 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 66 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 67 | return self.conv9(out) 68 | -------------------------------------------------------------------------------- /neosr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Iterator 3 | 4 | import torch 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class EnlargedSampler(Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | 11 | Modified from torch.utils.data.distributed.DistributedSampler 12 | Support enlarging the dataset for iteration-based training, for saving 13 | time when restart the dataloader after each epoch 14 | 15 | Args: 16 | ---- 17 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 18 | num_replicas (int | None): Number of processes participating in 19 | the training. It is usually the world_size. 20 | rank (int | None): Rank of the current process within num_replicas. 21 | ratio (int): Enlarging ratio. Default: 1. 
22 | 23 | """ 24 | 25 | def __init__( 26 | self, dataset, num_replicas: int = 1, rank: int = 1, ratio: int = 1 27 | ) -> None: 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 33 | self.total_size = self.num_samples * self.num_replicas 34 | 35 | def set_epoch(self, epoch: int) -> None: 36 | self.epoch = epoch 37 | 38 | def __iter__(self) -> Iterator[int]: 39 | # deterministically shuffle based on epoch 40 | g = torch.Generator(device="cuda") 41 | g.manual_seed(self.epoch) 42 | indices = torch.randperm(self.total_size, generator=g, device="cuda").tolist() 43 | 44 | dataset_size = len(self.dataset) 45 | indices = [v % dataset_size for v in indices] 46 | 47 | # subsample 48 | indices = indices[self.rank : self.total_size : self.num_replicas] 49 | assert len(indices) == self.num_samples 50 | 51 | return iter(indices) 52 | 53 | def __len__(self) -> int: 54 | return self.num_samples 55 | -------------------------------------------------------------------------------- /neosr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | from collections.abc import Iterator 3 | from threading import Thread 4 | from typing import Any 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | class PrefetchGenerator(Thread): 11 | """A general prefetch generator. 12 | 13 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 14 | 15 | Args: 16 | ---- 17 | generator: Python generator. 18 | num_prefetch_queue (int): Number of prefetch queue. 
19 | 20 | """ 21 | 22 | def __init__(self, generator, num_prefetch_queue: int) -> None: 23 | Thread.__init__(self) 24 | self.queue: Queue.Queue[Any] = Queue.Queue(num_prefetch_queue) 25 | self.generator = generator 26 | self.daemon = True 27 | self.start() 28 | 29 | def run(self) -> None: 30 | for item in self.generator: 31 | self.queue.put(item) 32 | self.queue.put(None) 33 | 34 | def __next__(self) -> Any: 35 | next_item = self.queue.get() # type: Any 36 | if next_item is None: 37 | raise StopIteration 38 | return next_item 39 | 40 | def __iter__(self) -> Iterator: 41 | return self 42 | 43 | 44 | class PrefetchDataLoader(DataLoader): 45 | """Prefetch version of dataloader. 46 | 47 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 48 | 49 | Todo: 50 | ---- 51 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 52 | ddp. 53 | 54 | Args: 55 | ---- 56 | num_prefetch_queue (int): Number of prefetch queue. 57 | kwargs (dict): Other arguments for dataloader. 58 | 59 | """ 60 | 61 | def __init__(self, num_prefetch_queue: int, **kwargs) -> None: 62 | self.num_prefetch_queue = num_prefetch_queue 63 | super().__init__(**kwargs) 64 | 65 | def __iter__(self): # type: ignore[reportIncompatibleMethodOverride] 66 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 67 | 68 | 69 | class CUDAPrefetcher: 70 | """CUDA prefetcher. 71 | 72 | Reference: https://github.com/NVIDIA/apex/issues/304# 73 | 74 | It may consume more GPU memory. 75 | 76 | Args: 77 | ---- 78 | loader: Dataloader. 79 | opt (dict): Options. 
80 | 81 | """ 82 | 83 | def __init__(self, loader: DataLoader, opt: dict[str, Any]) -> None: 84 | self.ori_loader = loader 85 | self.loader = iter(loader) 86 | self.opt = opt 87 | self.stream = torch.cuda.Stream() 88 | self.device = torch.device("cuda") 89 | self.preload() 90 | 91 | def preload(self) -> None: 92 | try: 93 | self.batch = next(self.loader) # self.batch is a dict 94 | except StopIteration: 95 | self.batch = None 96 | return 97 | # put tensors to gpu 98 | with torch.cuda.stream(self.stream): # type: ignore[reportArgumentType] 99 | for k, v in self.batch.items(): 100 | if torch.is_tensor(v): 101 | self.batch[k] = self.batch[k].to( 102 | device=self.device, non_blocking=True 103 | ) 104 | 105 | def next(self): 106 | torch.cuda.current_stream().wait_stream(self.stream) 107 | batch = self.batch 108 | self.preload() 109 | return batch 110 | 111 | def reset(self) -> None: 112 | self.loader = iter(self.ori_loader) 113 | self.preload() 114 | -------------------------------------------------------------------------------- /neosr/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torch.utils import data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from neosr.data.data_util import paths_from_lmdb 8 | from neosr.data.file_client import FileClient 9 | from neosr.utils import imfrombytes, img2tensor, scandir 10 | from neosr.utils.registry import DATASET_REGISTRY 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class single(data.Dataset): 15 | """Read only lq images in the test phase. 16 | 17 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 18 | 19 | There are two modes: 20 | 1. 'meta_info_file': Use meta information file to generate paths. 21 | 2. 'folder': Scan folders to generate paths. 22 | 23 | Args: 24 | ---- 25 | opt (dict): Config for train datasets. 
It contains the following keys: 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 29 | 30 | """ 31 | 32 | def __init__(self, opt: dict[str, Any]) -> None: 33 | super().__init__() 34 | self.opt = opt 35 | # file client (io backend) 36 | self.file_client: FileClient | None = None 37 | self.mean = opt.get("mean") 38 | self.std = opt.get("std") 39 | self.lq_folder = opt["dataroot_lq"] 40 | self.color = self.opt.get("color", None) != "y" 41 | 42 | # sets flag for file_client.py 43 | if self.lq_folder.endswith("lmdb"): 44 | self.io_backend_opt: dict[str, str] = {"type": "lmdb"} 45 | lmdb = True 46 | else: 47 | self.io_backend_opt: dict[str, str] = {"type": "disk"} 48 | lmdb = False 49 | 50 | if lmdb: 51 | self.io_backend_opt["db_paths"] = [self.lq_folder] # type: ignore[assignment] 52 | self.io_backend_opt["client_keys"] = ["lq"] # type: ignore[assignment] 53 | self.paths = paths_from_lmdb(self.lq_folder) 54 | elif "meta_info_file" in self.opt: 55 | with Path(str(self.opt["meta_info_file"])).open(encoding="utf-8") as fin: 56 | self.paths = [ 57 | str(Path(self.lq_folder) / line.rstrip()).split(" ")[0] 58 | for line in fin 59 | ] 60 | else: 61 | self.paths = sorted(scandir(self.lq_folder, full_path=True)) 62 | 63 | def __getitem__(self, index): 64 | if self.file_client is None: 65 | self.file_client = FileClient( 66 | self.io_backend_opt.pop("type"), # type: ignore[union-attr] 67 | **self.io_backend_opt, 68 | ) 69 | 70 | # load lq image 71 | lq_path = self.paths[index] 72 | img_bytes = self.file_client.get(lq_path, "lq") # type: ignore[attr-defined] 73 | 74 | try: 75 | img_lq = imfrombytes(img_bytes, float32=True) 76 | except AttributeError: 77 | raise AttributeError(lq_path) 78 | 79 | # BGR to RGB, HWC to CHW, numpy to tensor 80 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True, color=self.color) 81 | # normalize 82 | if self.mean is not None or 
self.std is not None: 83 | normalize(img_lq, self.mean, self.std, inplace=True) # type: ignore[reportAssignmentType] 84 | return {"lq": img_lq, "lq_path": lq_path} 85 | 86 | def __len__(self) -> int: 87 | return len(self.paths) 88 | -------------------------------------------------------------------------------- /neosr/losses/.gitattributes: -------------------------------------------------------------------------------- 1 | dists_weights.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /neosr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | from torch import nn 7 | 8 | from neosr.utils import get_root_logger, scandir 9 | from neosr.utils.registry import LOSS_REGISTRY 10 | 11 | __all__ = ["build_loss"] 12 | 13 | # automatically scan and import loss modules for registry 14 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 15 | loss_folder = Path(__file__).resolve().parent 16 | loss_filenames = [ 17 | Path(v).stem for v in scandir(str(loss_folder)) if v.endswith("_loss.py") 18 | ] 19 | # import all the loss modules 20 | _model_modules = [ 21 | importlib.import_module(f"neosr.losses.{file_name}") for file_name in loss_filenames 22 | ] 23 | 24 | 25 | def build_loss(opt: dict[str, Any]) -> nn.Module | object: 26 | """Build loss from options. 27 | 28 | Args: 29 | ---- 30 | opt (dict): Configuration; every key except ``type`` is passed to the loss constructor. It must contain: 31 | type (str): Loss type, looked up in LOSS_REGISTRY.
32 | 33 | """ 34 | opt = deepcopy(opt) 35 | loss_type = opt.pop("type") 36 | loss = LOSS_REGISTRY.get(loss_type)(**opt) # type: ignore[operator] 37 | logger = get_root_logger() 38 | logger.info(f"Loss [{loss.__class__.__name__}] enabled.") 39 | return loss 40 | -------------------------------------------------------------------------------- /neosr/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | 3 | from neosr.utils.registry import LOSS_REGISTRY 4 | 5 | 6 | @LOSS_REGISTRY.register() 7 | class gan_loss(nn.Module): 8 | """GAN loss. 9 | 10 | Args: 11 | ---- 12 | gan_type (str): One of 'bce' (on logits, via BCEWithLogitsLoss), 'mse' (l2), 'huber'. 13 | real_label_val (float): The value for real label. Default: 1.0. 14 | fake_label_val (float): The value for fake label. Default: 0.0. 15 | loss_weight (float): Loss weight. Default: 0.1. 16 | Note that loss_weight is only for generators; and it is always 1.0 17 | for discriminators. 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | gan_type: str = "bce", 24 | real_label_val: float = 1.0, 25 | fake_label_val: float = 0.0, 26 | loss_weight: float = 0.1, 27 | ) -> None: 28 | super().__init__() 29 | self.gan_type = gan_type 30 | self.loss_weight = loss_weight 31 | self.real_label_val = real_label_val 32 | self.fake_label_val = fake_label_val 33 | self.loss: nn.BCEWithLogitsLoss | nn.MSELoss | nn.HuberLoss 34 | 35 | if self.gan_type == "bce": 36 | self.loss = nn.BCEWithLogitsLoss() 37 | elif self.gan_type == "mse": 38 | self.loss = nn.MSELoss() 39 | elif self.gan_type == "huber": 40 | self.loss = nn.HuberLoss() 41 | else: 42 | msg = f"GAN type {self.gan_type} is not implemented." 43 | raise NotImplementedError(msg) 44 | 45 | def get_target_label(self, net_output: Tensor, target_is_real: bool) -> Tensor: 46 | """Get target label. 47 | 48 | Args: 49 | ---- 50 | net_output (Tensor): Network prediction; the label tensor copies its shape, dtype and device. 51 | target_is_real (bool): Whether the target is real or fake.
52 | 53 | Returns: 54 | ------- 55 | Tensor: Label tensor of the same shape as net_output, filled with real_label_val or fake_label_val. 56 | 57 | """ 58 | target_val = self.real_label_val if target_is_real else self.fake_label_val 59 | return net_output.new_full(net_output.size(), target_val) 60 | 61 | def forward( 62 | self, net_output: Tensor, target_is_real: bool, is_disc: bool = False 63 | ) -> Tensor: 64 | """Args: 65 | ---- 66 | net_output (Tensor): The input for the loss module, i.e., the network 67 | prediction. 68 | target_is_real (bool): Whether the target is real or fake. 69 | is_disc (bool): Whether the loss is computed for the discriminator. 70 | Default: False. 71 | 72 | Returns 73 | ------- 74 | Tensor: GAN loss value, multiplied by loss_weight unless is_disc is True. 75 | 76 | """ 77 | target_label = self.get_target_label(net_output, target_is_real) 78 | loss = self.loss(net_output, target_label) 79 | 80 | # loss_weight is always 1.0 for discriminators 81 | return loss if is_disc else loss * self.loss_weight 82 | -------------------------------------------------------------------------------- /neosr/losses/kl_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from torch.nn import functional as F 4 | 5 | from neosr.utils.registry import LOSS_REGISTRY 6 | 7 | 8 | @LOSS_REGISTRY.register() 9 | class kl_loss(nn.Module): 10 | """KL-Divergence loss. 11 | 12 | Args: 13 | ---- 14 | loss_weight (float): weight for the loss.
Default: 1.0 15 | """ 16 | 17 | def __init__(self, loss_weight: float = 1.0) -> None: 18 | super().__init__() 19 | self.loss_weight = loss_weight 20 | 21 | def forward(self, net_output: Tensor, gt: Tensor): 22 | # Convert net_output and gt to probability distributions 23 | net_output_prob = F.softmax(net_output, dim=1) 24 | gt_prob = F.softmax(gt, dim=1) 25 | # Compute log probabilities 26 | net_output_log_prob = torch.log(net_output_prob + 1e-8) 27 | # Compute KL divergence 28 | loss = F.kl_div(net_output_log_prob, gt_prob, reduction="batchmean") 29 | # balance loss to avoid issues 30 | loss = loss * 0.03 31 | return loss * self.loss_weight 32 | -------------------------------------------------------------------------------- /neosr/losses/ldl_loss.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch.nn import functional as F 6 | 7 | from neosr.losses.basic_loss import chc_loss 8 | from neosr.utils.registry import LOSS_REGISTRY 9 | 10 | if TYPE_CHECKING: 11 | from collections.abc import Callable 12 | 13 | 14 | @LOSS_REGISTRY.register() 15 | class ldl_loss(nn.Module): 16 | """LDL loss. Adapted from 'Details or Artifacts: A Locally Discriminative 17 | Learning Approach to Realistic Image Super-Resolution': 18 | https://arxiv.org/abs/2203.09195. 19 | 20 | Args: 21 | ---- 22 | criterion (str): loss type, one of 'l1', 'l2', 'huber' or 'chc'. Default: 'l1' 23 | loss_weight (float): weight for the loss. Default: 1.0 24 | ksize (int): size of the local window.
Default: 7 25 | 26 | """ 27 | 28 | def __init__( 29 | self, criterion: str = "l1", loss_weight: float = 1.0, ksize: int = 7 30 | ) -> None: 31 | super().__init__() 32 | self.loss_weight = loss_weight 33 | self.ksize = ksize 34 | self.criterion_type = criterion 35 | self.criterion: nn.L1Loss | nn.MSELoss | nn.HuberLoss | Callable 36 | 37 | if self.criterion_type == "l1": 38 | self.criterion = nn.L1Loss() 39 | elif self.criterion_type == "l2": 40 | self.criterion = nn.MSELoss() 41 | elif self.criterion_type == "huber": 42 | self.criterion = nn.HuberLoss() 43 | elif self.criterion_type == "chc": 44 | self.criterion = chc_loss(loss_lambda=0, clip_min=0, clip_max=1) # type: ignore[reportCallIssue] 45 | else: 46 | msg = f"{criterion} criterion has not been supported." 47 | raise NotImplementedError(msg) 48 | 49 | def get_local_weights(self, residual: Tensor) -> Tensor: 50 | """Get local weights for generating the artifact map of LDL. 51 | 52 | It is only called by the `get_refined_artifact_map` function. 53 | 54 | Args: 55 | ---- 56 | residual (Tensor): Residual between predicted and ground truth images. 57 | 58 | Returns: 59 | ------- 60 | Tensor: weight for each pixel to be discriminated as an artifact pixel 61 | 62 | """ 63 | pad = (self.ksize - 1) // 2 64 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode="reflect") 65 | 66 | unfolded_residual = residual_pad.unfold(2, self.ksize, 1).unfold( 67 | 3, self.ksize, 1 68 | ) 69 | return ( 70 | torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True) 71 | .squeeze(-1) 72 | .squeeze(-1) 73 | ) 74 | 75 | def get_refined_artifact_map(self, img_gt: Tensor, img_output: Tensor) -> Tensor: 76 | """Calculate the artifact map of LDL 77 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022). 78 | 79 | Args: 80 | ---- 81 | img_gt (Tensor): ground truth images. 82 | img_output (Tensor): output images given by the optimizing model.
83 | 84 | Returns: 85 | ------- 86 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 87 | (calculated based on both local and global observations). 88 | 89 | """ 90 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 91 | 92 | patch_level_weight = torch.var( 93 | residual_sr.clone(), dim=(-1, -2, -3), keepdim=True 94 | ) ** (1 / 5) 95 | pixel_level_weight = self.get_local_weights(residual_sr.clone()) 96 | return patch_level_weight * pixel_level_weight 97 | 98 | def forward(self, net_output: Tensor, gt: Tensor) -> Tensor: 99 | overall_weight = self.get_refined_artifact_map(gt, net_output) 100 | self.output = torch.mul(overall_weight, net_output) 101 | self.gt = torch.mul(overall_weight, gt) 102 | 103 | return self.criterion(self.output, self.gt) * self.loss_weight 104 | -------------------------------------------------------------------------------- /neosr/losses/ncc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from neosr.utils.registry import LOSS_REGISTRY 5 | 6 | 7 | @LOSS_REGISTRY.register() 8 | class ncc_loss(nn.Module): 9 | """Normalized Cross-Correlation loss. 10 | 11 | Args: 12 | ---- 13 | loss_weight (float): weight for the loss.
Default: 1.0 14 | """ 15 | 16 | def __init__(self, loss_weight: float = 1.0) -> None: 17 | super().__init__() 18 | self.loss_weight = loss_weight 19 | 20 | def _cc(self, net_output: Tensor, gt: Tensor): 21 | # assumes (B, C, H, W); move channels first so each row of the (C, B*H*W) view is one channel across the batch 22 | net_output_reshaped = net_output.transpose(0, 1).reshape(net_output.shape[1], -1) 23 | gt_reshaped = gt.transpose(0, 1).reshape(gt.shape[1], -1) 24 | # calculate mean 25 | mean_net_output = torch.mean(net_output_reshaped, 1).unsqueeze(1) 26 | mean_gt = torch.mean(gt_reshaped, 1).unsqueeze(1) 27 | # cross-correlation; the variance product is clamped so a constant channel gives a finite value instead of NaN 28 | cc = torch.sum( 29 | (net_output_reshaped - mean_net_output) * (gt_reshaped - mean_gt), 1 30 | ) / torch.sqrt( 31 | (torch.sum((net_output_reshaped - mean_net_output) ** 2, 1) 32 | * torch.sum((gt_reshaped - mean_gt) ** 2, 1)).clamp_min(1e-12) 33 | ) 34 | return torch.mean(cc) 35 | 36 | def forward(self, net_output: Tensor, gt: Tensor): 37 | cc_value = self._cc(net_output, gt) 38 | return (1 - ((cc_value + 1) * 0.5)) * self.loss_weight 39 | -------------------------------------------------------------------------------- /neosr/metrics/.gitattributes: -------------------------------------------------------------------------------- 1 | topiq_fr_weights.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /neosr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any 3 | 4 | from neosr.metrics.calculate import ( 5 | calculate_dists, 6 | calculate_psnr, 7 | calculate_ssim, 8 | calculate_topiq, 9 | ) 10 | from neosr.utils.registry import METRIC_REGISTRY 11 | 12 | __all__ = ["calculate_dists", "calculate_psnr", "calculate_ssim", "calculate_topiq"] 13 | 14 | 15 | def calculate_metric(data, opt: dict[str, Any]) -> float: 16 | """Calculate metric from data and options. 17 | 18 | Args: 19 | ---- 20 | opt (dict): Configuration. It must contain: 21 | type (str): Metric type, looked up in METRIC_REGISTRY.
22 | 23 | """ 24 | opt = deepcopy(opt) 25 | metric_type = opt.pop("type") 26 | return METRIC_REGISTRY.get(metric_type)(**data, **opt) # type: ignore[operator,return-value] 27 | -------------------------------------------------------------------------------- /neosr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from neosr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img: np.ndarray, input_order: str = "HWC") -> np.ndarray: 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input shape is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | ---- 15 | img (ndarray): Input image. 16 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 17 | If the input image shape is (h, w), input_order will not have 18 | effects. Default: 'HWC'. 19 | 20 | Returns: 21 | ------- 22 | ndarray: reordered image. 23 | 24 | """ 25 | if input_order not in {"HWC", "CHW"}: 26 | msg = f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" 27 | raise ValueError(msg) 28 | if len(img.shape) == 2: 29 | img = img[..., None] # a 2-D image has no channel axis to move, so input_order is ignored 30 | elif input_order == "CHW": 31 | img = img.transpose(1, 2, 0) 32 | return img 33 | 34 | 35 | def to_y_channel(img: np.ndarray) -> np.ndarray: 36 | """Change to Y channel of YCbCr. 37 | 38 | Args: 39 | ---- 40 | img (ndarray): Images with range [0, 255]. 41 | 42 | Returns: 43 | ------- 44 | (ndarray): Images with range [0, 255] (float type) without round.
45 | 46 | """ 47 | img = img.astype(np.float32) / 255.0 48 | if img.ndim == 3 and img.shape[2] == 3: 49 | img = bgr2ycbcr(img, y_only=True) 50 | img = img[..., None] 51 | return img * 255.0 52 | -------------------------------------------------------------------------------- /neosr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from collections.abc import Callable 3 | from copy import deepcopy 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | from neosr.utils import get_root_logger, scandir 8 | from neosr.utils.registry import MODEL_REGISTRY 9 | 10 | __all__ = ["build_model"] 11 | 12 | # automatically scan and import model modules for registry 13 | # collect the python files under the 'models' folder, skipping this __init__.py (importing it as 'neosr.models.__init__' would execute this file a second time) 14 | model_folder = Path(__file__).resolve().parent 15 | model_filenames = [Path(v).stem for v in scandir(str(model_folder)) if v.endswith(".py") and Path(v).stem != "__init__"] 16 | 17 | # import all the model modules 18 | _model_modules = [ 19 | importlib.import_module(f"neosr.models.{file_name}") 20 | for file_name in model_filenames 21 | ] 22 | 23 | 24 | def build_model(opt: dict[str, Any]) -> Callable | object: 25 | """Build model from options. 26 | 27 | Args: 28 | ---- 29 | opt (dict): Configuration. It must contain: 30 | model_type (str): Model type.
31 | 32 | """ 33 | opt = deepcopy(opt) 34 | model = MODEL_REGISTRY.get(opt["model_type"])(opt) # type: ignore[operator] 35 | logger = get_root_logger() 36 | logger.info(f"Using model [{model.__class__.__name__}].") 37 | return model 38 | -------------------------------------------------------------------------------- /neosr/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from neosr.optimizers.adamw_sf import adamw_sf 2 | from neosr.optimizers.adamw_win import adamw_win 3 | from neosr.optimizers.adan import adan 4 | from neosr.optimizers.adan_sf import adan_sf 5 | from neosr.optimizers.fsam import fsam 6 | from neosr.optimizers.soap_sf import soap_sf 7 | 8 | __all__ = ["adamw_sf", "adamw_win", "adan", "adan_sf", "fsam", "soap_sf"] 9 | -------------------------------------------------------------------------------- /neosr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from neosr.utils.color_util import ( 2 | bgr2ycbcr, 3 | rgb2ycbcr, 4 | rgb2ycbcr_pt, 5 | ycbcr2bgr, 6 | ycbcr2rgb, 7 | ) 8 | from neosr.utils.diffjpeg import DiffJPEG # type: ignore[attr-defined] 9 | from neosr.utils.img_util import ( 10 | crop_border, 11 | imfrombytes, 12 | img2tensor, 13 | imwrite, 14 | tensor2img, 15 | ) 16 | from neosr.utils.logger import ( 17 | AvgTimer, 18 | MessageLogger, 19 | get_root_logger, 20 | init_tb_logger, 21 | init_wandb_logger, 22 | ) 23 | from neosr.utils.misc import ( 24 | check_disk_space, 25 | check_resume, 26 | get_time_str, 27 | make_exp_dirs, 28 | mkdir_and_rename, 29 | scandir, 30 | set_random_seed, 31 | sizeof_fmt, 32 | tc, 33 | ) 34 | from neosr.utils.options import toml_load 35 | from neosr.utils.registry import Registry 36 | 37 | __all__ = [ 38 | "AvgTimer", 39 | # diffjpeg 40 | "DiffJPEG", 41 | # logger.py 42 | "MessageLogger", 43 | # registry 44 | "Registry", 45 | # color_util.py 46 | "bgr2ycbcr", 47 | # misc.py 48 | "check_disk_space", 49
| "check_resume", 50 | "crop_border", 51 | "get_root_logger", 52 | "get_time_str", 53 | "imfrombytes", 54 | # img_util.py 55 | "img2tensor", 56 | "imwrite", 57 | "init_tb_logger", 58 | "init_wandb_logger", 59 | "make_exp_dirs", 60 | "mkdir_and_rename", 61 | "rgb2ycbcr", 62 | "rgb2ycbcr_pt", 63 | "scandir", 64 | "set_random_seed", 65 | "sizeof_fmt", 66 | "tc", 67 | "tensor2img", 68 | # options 69 | "toml_load", 70 | "ycbcr2bgr", 71 | "ycbcr2rgb", 72 | ] 73 | -------------------------------------------------------------------------------- /neosr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py 2 | import functools 3 | import os 4 | import subprocess # noqa: S404 5 | from collections.abc import Callable 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | 11 | 12 | def _init_dist_pytorch(backend: str, **kwargs) -> None: 13 | rank = int(os.environ["RANK"]) 14 | num_gpus = torch.cuda.device_count() 15 | torch.cuda.set_device(rank % num_gpus) 16 | dist.init_process_group(backend=backend, **kwargs) 17 | 18 | 19 | def _init_dist_slurm(backend: str, port: int | None = None) -> None: 20 | """Initialize slurm distributed training environment. 21 | 22 | If argument ``port`` is not specified, then the master port will be system 23 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 24 | environment variable, then a default port ``29500`` will be used. 25 | 26 | Args: 27 | ---- 28 | backend (str): Backend of torch.distributed. 29 | port (int, optional): Master port. Defaults to None.
30 | 31 | """ 32 | proc_id = int(os.environ["SLURM_PROCID"]) 33 | ntasks = int(os.environ["SLURM_NTASKS"]) 34 | node_list = os.environ["SLURM_NODELIST"] 35 | num_gpus = torch.cuda.device_count() 36 | torch.cuda.set_device(proc_id % num_gpus) 37 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") # noqa: S605 38 | # specify master port 39 | if port is not None: 40 | os.environ["MASTER_PORT"] = str(port) 41 | elif "MASTER_PORT" in os.environ: 42 | pass # use MASTER_PORT in the environment variable 43 | else: 44 | # 29500 is torch.distributed default port 45 | os.environ["MASTER_PORT"] = "29500" 46 | os.environ["MASTER_ADDR"] = addr 47 | os.environ["WORLD_SIZE"] = str(ntasks) 48 | os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) 49 | os.environ["RANK"] = str(proc_id) 50 | dist.init_process_group(backend=backend) 51 | 52 | 53 | def init_dist(launcher, backend: str = "nccl", **kwargs) -> None: 54 | if mp.get_start_method(allow_none=True) is None: 55 | mp.set_start_method("spawn") 56 | if launcher == "pytorch": 57 | _init_dist_pytorch(backend, **kwargs) 58 | elif launcher == "slurm": 59 | _init_dist_slurm(backend, **kwargs) 60 | else: 61 | msg = f"Invalid launcher type: {launcher}" 62 | raise ValueError(msg) 63 | 64 | 65 | def get_dist_info() -> tuple[int, int]: 66 | initialized = dist.is_initialized() if dist.is_available() else False 67 | if initialized: 68 | rank = dist.get_rank() 69 | world_size = dist.get_world_size() 70 | else: 71 | rank = 0 72 | world_size = 1 73 | return rank, world_size 74 | 75 | 76 | def master_only(func: Callable) -> Callable: 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | return None 83 | 84 | return wrapper 85 | -------------------------------------------------------------------------------- /neosr/utils/registry.py: -------------------------------------------------------------------------------- 1 | #
Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py 2 | 3 | # pyright: strict 4 | from collections.abc import Callable, Iterable, Iterator 5 | from typing import Any 6 | 7 | 8 | class Registry: 9 | """The registry that provides name -> object mapping, to support third-party 10 | users' custom modules. 11 | 12 | To create a registry (e.g. a backbone registry): 13 | 14 | .. code-block:: python 15 | 16 | BACKBONE_REGISTRY = Registry('BACKBONE') 17 | 18 | To register an object: 19 | 20 | .. code-block:: python 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class MyBackbone(): 24 | ... 25 | 26 | Or: 27 | 28 | .. code-block:: python 29 | 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name: str) -> None: 34 | """Args: 35 | ---- 36 | name (str): the name of this registry 37 | 38 | """ 39 | self._name = name 40 | self._obj_map: dict[str, Callable[..., object] | type[str] | str] = {} 41 | 42 | def _do_register( 43 | self, 44 | name: str, 45 | obj: Callable[..., object] | type[str] | str, 46 | suffix: str | None = None, 47 | ) -> None: 48 | if isinstance(suffix, str): 49 | name = name + "_" + suffix 50 | 51 | if name in self._obj_map: 52 | msg = f"An object named '{name}' was already registered " f"in '{self._name}' registry!" 53 | # raise explicitly: an assert would be stripped under `python -O`, silently overwriting the entry 54 | raise KeyError(msg) 55 | self._obj_map[name] = obj 56 | 57 | def register( 58 | self, 59 | obj: Callable[..., object] | type[str] | None = None, 60 | suffix: str | None = None, 61 | ) -> Callable[..., object]: 62 | """Register the given object under the name `obj.__name__`. 63 | Can be used as either a decorator or not. 64 | See docstring of this class for usage.
65 | """ 66 | if obj is None: 67 | # used as a decorator 68 | def deco( 69 | func_or_class: Callable[..., object] | type[str], 70 | ) -> Callable[..., object] | type[str]: 71 | name = func_or_class.__name__ 72 | self._do_register(name, func_or_class, suffix) 73 | return func_or_class 74 | 75 | return deco 76 | 77 | # used as a function call 78 | name = obj if isinstance(obj, str) else obj.__name__ 79 | self._do_register(name, obj, suffix) 80 | return None # type: ignore[return-value] 81 | 82 | def get( 83 | self, name: str, suffix: str = "neosr" 84 | ) -> Callable[..., object] | type[str] | str: 85 | ret = self._obj_map.get(name) 86 | if ret is None: 87 | ret = self._obj_map.get(name + "_" + suffix) 88 | if ret is None: 89 | msg = f"No object named '{name}' found in '{self._name}' registry!" 90 | raise KeyError(msg) 91 | return ret 92 | 93 | def keys(self) -> Iterable[str]: 94 | return self._obj_map.keys() 95 | 96 | def __iter__(self) -> Iterator[tuple[str, Callable[..., Any] | type[str] | str]]: 97 | return iter(self._obj_map.items()) 98 | 99 | def __contains__(self, name: str) -> bool: 100 | return name in self._obj_map 101 | 102 | 103 | DATASET_REGISTRY = Registry("dataset") 104 | ARCH_REGISTRY = Registry("arch") 105 | MODEL_REGISTRY = Registry("model") 106 | LOSS_REGISTRY = Registry("loss") 107 | METRIC_REGISTRY = Registry("metric") 108 | -------------------------------------------------------------------------------- /neosr/utils/rng.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from pathlib import Path 3 | 4 | from numpy.random import default_rng 5 | 6 | from neosr.utils.options import parse_options 7 | 8 | 9 | def rng() -> Callable: 10 | root_path = Path(__file__).parents[2] 11 | opt, __ = parse_options(str(root_path), is_train=True) 12 | seed: int | None = None if opt is None else opt.get("manual_seed", None) 13 | # without a parsed config, fall back to an unseeded generator instead of implicitly returning None 14 | return default_rng(seed=seed) if seed is not None else
default_rng() # type: ignore[return-value] 15 | -------------------------------------------------------------------------------- /options/test_asid.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_asid" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "asid" 18 | #type = "asid_d8" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_atd.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_atd" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "atd" 18 | #type = "atd_light" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_catanet.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_catanet" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "catanet" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_cfsr.toml:
-------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_cfsr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "cfsr" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_compact.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_compact" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "compact" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_craft.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_craft" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "craft" 18 | #flash_attn = false 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_cugan.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_cugan" 4 | model_type = "image" 5 | 
scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "cugan" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_dat.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_dat" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "dat_s" 18 | #type = "dat_m" 19 | #type = "dat_2" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_dct.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_dct" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "dct" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_dctlsa.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_dctlsa" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | 
[network_g] 17 | type = "dctlsa" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_ditn.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_ditn" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "ditn" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_drct.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_drct" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "drct" 18 | #type = "drct_l" 19 | #type = "drct_xl" 20 | #type = "drct_s" 21 | 22 | [path] 23 | pretrain_network_g = 'C:\model.pth' 24 | -------------------------------------------------------------------------------- /options/test_eimn.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_eimn" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "eimn" 18 | #type = "eimn_a" 19 | #type = "eimn_l" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | 
-------------------------------------------------------------------------------- /options/test_esc.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_esc" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "esc" 18 | #type = "esc_light" 19 | #type = "esc_large" 20 | #deployment = true 21 | #do_compile = true 22 | 23 | [path] 24 | pretrain_network_g = 'C:\model.pth' 25 | -------------------------------------------------------------------------------- /options/test_esrgan.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_esrgan" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "esrgan" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_flexnet.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_flexnet" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "flexnet" 18 | #type = "metaflexnet" 19 | #flash_attn = false 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | 
-------------------------------------------------------------------------------- /options/test_grformer.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_grformer" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "grformer" 18 | #type = "grformer_medium" 19 | #type = "grformer_large" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_hasn.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_hasn" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "hasn" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_hat.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_hat" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "hat_s" 18 | #type = "hat_m" 19 | #type = "hat_l" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_hitsrf.toml: 
-------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_hitsrf" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "hit_srf" 18 | #type = "hit_srf_medium" 19 | #type = "hit_srf_large" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_hma.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_hma" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "hma" 18 | #type = "hma_medium" 19 | #type = "hma_large" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_krgn.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_krgn" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "krgn" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_lmlt.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to 
neosr/experiments/results/ 2 | 3 | name = "test_lmlt" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "lmlt" 18 | #type = "lmlt_tiny" 19 | #type = "lmlt_large" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_man.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_man" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "man" 18 | #type = "man_tiny" 19 | #type = "man_light" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_moesr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_moesr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "moesr" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_mosrv2.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_mosrv2" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = 
true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "mosrv2" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_msdan.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_msdan" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "msdan" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_ninasr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_ninasr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "ninasr" 18 | #type = "ninasr_b0" 19 | #type = "ninasr_b2" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_omnisr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_omnisr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type 
= "omnisr" 18 | upsampling = 4 19 | window_size = 8 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_plainusr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_plainusr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "plainusr" 18 | #type = "plainusr_ultra" 19 | #type = "plainusr_large" 20 | 21 | [path] 22 | pretrain_network_g = 'C:\model.pth' 23 | -------------------------------------------------------------------------------- /options/test_plksr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_plksr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "plksr" 18 | #type = "plksr_tiny" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_rcan.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_rcan" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "rcan" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | 
-------------------------------------------------------------------------------- /options/test_realplksr.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_realplksr" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "realplksr" 18 | #type = "realplksr_s" 19 | #type = "realplksr_l" 20 | #dysample = true 21 | 22 | [path] 23 | pretrain_network_g = 'C:\model.pth' 24 | -------------------------------------------------------------------------------- /options/test_rgt.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_rgt" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "rgt" 18 | #type = "rgt_s" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_safmn.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_safmn" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "safmn" 18 | #type = "safmn_l" 19 | #type = "light_safmnpp" 20 | #bcie = true 21 | 22 | [path] 23 | pretrain_network_g = 'C:\model.pth' 24 | 
-------------------------------------------------------------------------------- /options/test_scnet.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_scnet" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "scnet" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_sebica.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_sebica" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "sebica" 18 | #type = "sebica_mini" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_span.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_span" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "span" 18 | 19 | [path] 20 | pretrain_network_g = 'C:\model.pth' 21 | -------------------------------------------------------------------------------- /options/test_spanplus.toml: -------------------------------------------------------------------------------- 1 | # 
Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_spanplus" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "spanplus" 18 | #type = "spanplus_sts" 19 | #type = "spanplus_s" 20 | #type = "spanplus_st" 21 | 22 | [path] 23 | pretrain_network_g = 'C:\model.pth' 24 | -------------------------------------------------------------------------------- /options/test_srformer.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_srformer" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "srformer_light" 18 | #type = "srformer_medium" 19 | 20 | [path] 21 | pretrain_network_g = 'C:\model.pth' 22 | -------------------------------------------------------------------------------- /options/test_swinir.toml: -------------------------------------------------------------------------------- 1 | # Results will be saved to neosr/experiments/results/ 2 | 3 | name = "test_swinir" 4 | model_type = "image" 5 | scale = 4 6 | #use_amp = true 7 | #compile = true 8 | 9 | [datasets.test_1] 10 | name = "val_1" 11 | type = "single" 12 | dataroot_lq = 'C:\datasets\val\' 13 | [val] 14 | #tile = 200 15 | 16 | [network_g] 17 | type = "swinir_small" 18 | #type = "swinir_medium" 19 | #type = "swinir_large" 20 | #flash_attn = true 21 | 22 | [path] 23 | pretrain_network_g = 'C:\model.pth' 24 | -------------------------------------------------------------------------------- /options/train_asid.toml: -------------------------------------------------------------------------------- 1 | 2 | 
name = "train_asid" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "asid" 45 | #type = "asid_d8" 46 | #drop = 0.5 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = 
"fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_atd.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_atd" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | 
#[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "atd" 45 | #type = "atd_light" 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_init = 80000 53 | match_lq_colors = true 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #wavelet_guided = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = 
true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_catanet.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_catanet" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "catanet" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 
78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.5 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_cfsr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_cfsr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 
| #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "cfsr" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 8e-4 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 
121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_cfsr_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_cfsr_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 =
[ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "cfsr" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 8e-4 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion 
= "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 | -------------------------------------------------------------------------------- /options/train_compact.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_compact" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "compact" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | grad_clip = false 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 
80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_craft.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_craft" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = 
true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "craft" 45 | #flash_attn = false 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 
96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_cugan.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_cugan" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | 
#pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "cugan" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | clamp = false 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | 
-------------------------------------------------------------------------------- /options/train_dat.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_dat" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "dat_m" 45 | #type = "dat_s" 46 | #type = "dat_2" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 
82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_dct.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_dct" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | 
#[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "dct" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | 
#[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_dct_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_dct_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 32 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 
63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "dct" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 1e-3 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 
| ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 | -------------------------------------------------------------------------------- /options/train_dctlsa.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_dctlsa" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "dctlsa" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | #sam = "fsam" 52 | #sam_init = 1000 53 | #eco = true 54 | #eco_init = 15000 55 | 
#wavelet_guided = true 56 | #wavelet_init = 80000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.gw_opt] 119 | #type = "gw_loss" 120 | #loss_weight = 0.1 121 | 122 | #[train.ncc_opt] 123 | #type = "ncc_loss" 124 | #loss_weight = 1.0 125 | 126 | #[train.kl_opt] 127 | #type = "kl_loss" 128 | #loss_weight = 1.0 129 | 130 | [logger] 131 | total_iter = 1000000 132 | save_checkpoint_freq = 1000 133 | use_tb_logger = true 134 | #save_tb_img = true 135 | #print_freq = 100 136 | -------------------------------------------------------------------------------- /options/train_ditn.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_ditn" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true
6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "ditn" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = 
"gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_ditn_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_ditn_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | 
kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "ditn" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 1e-3 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | 
[train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 | -------------------------------------------------------------------------------- /options/train_drct.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_drct" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 
200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "drct" 45 | #type = "drct_l" 46 | #type = "drct_xl" 47 | #type = "drct_s" 48 | 49 | 50 | [network_d] 51 | type = "metagan" 52 | 53 | [train] 54 | ema = 0.999 55 | wavelet_guided = true 56 | wavelet_init = 80000 57 | #sam = "fsam" 58 | #sam_init = 1000 59 | #eco = true 60 | #eco_init = 15000 61 | #match_lq_colors = true 62 | 63 | [train.optim_g] 64 | type = "adan_sf" 65 | lr = 1e-3 66 | betas = [ 0.98, 0.92, 0.99 ] 67 | weight_decay = 0.01 68 | schedule_free = true 69 | warmup_steps = 1600 70 | 71 | [train.optim_d] 72 | type = "adan_sf" 73 | lr = 1e-4 74 | betas = [ 0.98, 0.92, 0.99 ] 75 | weight_decay = 0.01 76 | schedule_free = true 77 | warmup_steps = 600 78 | 79 | # losses 80 | [train.mssim_opt] 81 | type = "mssim_loss" 82 | loss_weight = 1.0 83 | 84 | [train.consistency_opt] 85 | type = "consistency_loss" 86 | loss_weight = 1.0 87 | 88 | [train.ldl_opt] 89 | type = "ldl_loss" 90 | loss_weight = 1.0 91 | 92 | [train.fdl_opt] 93 | type = "fdl_loss" 94 | model = "dinov2" # "vgg", "resnet", "effnet" 95 | loss_weight = 0.75 96 | 97 | [train.gan_opt] 98 | type = "gan_loss" 99 | gan_type = "bce" 100 | loss_weight = 0.3 101 | 102 | #[train.msswd_opt] 103 | #type = "msswd_loss" 104 | #loss_weight = 1.0 105 | 106 | #[train.perceptual_opt] 107 | #type = "vgg_perceptual_loss" 108 | #loss_weight = 0.5 109 | #criterion = "huber" 110 | ##patchloss = true 111 | ##ipk = true 112 | ##patch_weight = 1.0 113 | 114 | #[train.dists_opt] 115 | #type = "dists_loss" 116 | #loss_weight = 0.5 117 | 118 | #[train.ff_opt] 119 | #type = "ff_loss" 120 | #loss_weight = 0.35 121 | 122 | 
#[train.ncc_opt] 123 | #type = "ncc_loss" 124 | #loss_weight = 1.0 125 | 126 | #[train.kl_opt] 127 | #type = "kl_loss" 128 | #loss_weight = 1.0 129 | 130 | [logger] 131 | total_iter = 1000000 132 | save_checkpoint_freq = 1000 133 | use_tb_logger = true 134 | #save_tb_img = true 135 | #print_freq = 100 136 | -------------------------------------------------------------------------------- /options/train_eimn.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_eimn" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "eimn" 45 | #type = "eimn_a" 46 | #type = "eimn_l" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 5e-4 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | 
schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | 117 | #[train.ff_opt] 118 | #type = "ff_loss" 119 | #loss_weight = 0.35 120 | 121 | #[train.ncc_opt] 122 | #type = "ncc_loss" 123 | #loss_weight = 1.0 124 | 125 | #[train.kl_opt] 126 | #type = "kl_loss" 127 | #loss_weight = 1.0 128 | 129 | [logger] 130 | total_iter = 1000000 131 | save_checkpoint_freq = 1000 132 | use_tb_logger = true 133 | #save_tb_img = true 134 | #print_freq = 100 135 | -------------------------------------------------------------------------------- /options/train_esc.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_esc" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | 
augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "esc" 45 | #type = "esc_light" 46 | #type = "esc_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 5e-4 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | 
#loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_esrgan.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_esrgan" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "esrgan" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | clamp = false 52 | wavelet_guided = 
true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 8e-4 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_flexnet.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_flexnet" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | 
bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "flexnet" 45 | #type = "metaflexnet" 46 | #flash_attn = false 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 5e-4 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | 
loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_grformer.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_grformer" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | 
[path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "grformer" 45 | #type = "grformer_medium" 46 | #type = "grformer_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | 
#save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_hasn.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_hasn" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "hasn" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 
1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_hasn_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_hasn_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | 
jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "hasn" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 1e-3 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free 
= true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 | -------------------------------------------------------------------------------- /options/train_hat.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_hat" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | 
#accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "hat_m" 45 | #type = "hat_s" 46 | #type = "hat_l" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 
106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_hitsrf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_hitsrf" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "hit_srf" 45 | #type = "hit_srf_medium" 46 | #type = "hit_srf_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | 
[train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_hma.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_hma" 3 | model_type = "image" 4 
| scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "hma" 45 | #type = "hma_medium" 46 | #type = "hma_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", 
"resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_krgn.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_krgn" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = 
"calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "krgn" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 5e-4 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | 
-------------------------------------------------------------------------------- /options/train_krgn_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_krgn_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 
'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "krgn" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 5e-4 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | 
#loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 | -------------------------------------------------------------------------------- /options/train_lmlt.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_lmlt" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "lmlt" 45 | #type = "lmlt_tiny" 46 | #type = "lmlt_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 
0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_man.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_man" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size 
= 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "man" 45 | #type = "man_tiny" 46 | #type = "man_light" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = 
"vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_moesr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_moesr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "moesr" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | 
wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_mosrv2.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_mosrv2" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = 
true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "mosrv2" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | 
type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_msdan.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_msdan" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 
41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "msdan" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- 
/options/train_ninasr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_ninasr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "ninasr" 45 | #type = "ninasr_b0" 46 | #type = "ninasr_b2" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | clamp = false 54 | wavelet_guided = true 55 | wavelet_init = 80000 56 | #sam = "fsam" 57 | #sam_init = 1000 58 | #eco = true 59 | #eco_init = 15000 60 | #match_lq_colors = true 61 | 62 | [train.optim_g] 63 | type = "adan_sf" 64 | lr = 1e-3 65 | betas = [ 0.98, 0.92, 0.99 ] 66 | weight_decay = 0.01 67 | schedule_free = true 68 | warmup_steps = 1600 69 | 70 | [train.optim_d] 71 | type = "adan_sf" 72 | lr = 1e-4 73 | betas = [ 0.98, 0.92, 0.99 ] 74 | weight_decay = 0.01 75 | schedule_free = true 76 | warmup_steps = 600 77 | 78 | # losses 79 | [train.mssim_opt] 80 | type = "mssim_loss" 81 | loss_weight = 1.0 82 | 83 | [train.consistency_opt] 84 | type = 
"consistency_loss" 85 | loss_weight = 1.0 86 | 87 | [train.ldl_opt] 88 | type = "ldl_loss" 89 | loss_weight = 1.0 90 | 91 | [train.fdl_opt] 92 | type = "fdl_loss" 93 | model = "dinov2" # "vgg", "resnet", "effnet" 94 | loss_weight = 0.75 95 | 96 | [train.gan_opt] 97 | type = "gan_loss" 98 | gan_type = "bce" 99 | loss_weight = 0.3 100 | 101 | #[train.msswd_opt] 102 | #type = "msswd_loss" 103 | #loss_weight = 1.0 104 | 105 | #[train.perceptual_opt] 106 | #type = "vgg_perceptual_loss" 107 | #loss_weight = 0.5 108 | #criterion = "huber" 109 | ##patchloss = true 110 | ##ipk = true 111 | ##patch_weight = 1.0 112 | 113 | #[train.dists_opt] 114 | #type = "dists_loss" 115 | #loss_weight = 0.5 116 | 117 | #[train.ff_opt] 118 | #type = "ff_loss" 119 | #loss_weight = 0.35 120 | 121 | #[train.ncc_opt] 122 | #type = "ncc_loss" 123 | #loss_weight = 1.0 124 | 125 | #[train.kl_opt] 126 | #type = "kl_loss" 127 | #loss_weight = 1.0 128 | 129 | [logger] 130 | total_iter = 1000000 131 | save_checkpoint_freq = 1000 132 | use_tb_logger = true 133 | #save_tb_img = true 134 | #print_freq = 100 135 | -------------------------------------------------------------------------------- /options/train_omnisr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_omnisr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = 
"calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "omnisr" 45 | upsampling = 4 46 | window_size = 8 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 
124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_plainusr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_plainusr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "plainusr" 45 | #type = "plainusr_ultra" 46 | #type = "plainusr_large" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 5e-4 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | schedule_free = true 67 | warmup_steps = 1600 
68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_plksr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_plksr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", 
"resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "plksr" 45 | #type = "plksr_tiny" 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 5e-4 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg" # "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | 
##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_rcan.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_rcan" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "rcan" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 54 | #sam_init = 1000 55 | #eco = true 56 | 
#eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_rcan_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_rcan_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | 
[datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | #pretrain_network_g = 'experiments\pretrain_g.pth' 87 | 
#pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "rcan" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq_colors = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 1e-3 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 |
-------------------------------------------------------------------------------- /options/train_realplksr.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_realplksr" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "realplksr" 45 | #type = "realplksr_s" 46 | #type = "realplksr_l" 47 | #dysample = true 48 | 49 | [network_d] 50 | type = "metagan" 51 | 52 | [train] 53 | ema = 0.999 54 | wavelet_guided = true 55 | wavelet_init = 80000 56 | #sam = "fsam" 57 | #sam_init = 1000 58 | #eco = true 59 | #eco_init = 15000 60 | #match_lq_colors = true 61 | 62 | [train.optim_g] 63 | type = "adan_sf" 64 | lr = 5e-4 65 | betas = [ 0.98, 0.92, 0.99 ] 66 | weight_decay = 0.01 67 | schedule_free = true 68 | warmup_steps = 1600 69 | 70 | [train.optim_d] 71 | type = "adan_sf" 72 | lr = 1e-4 73 | betas = [ 0.98, 0.92, 0.99 ] 74 | weight_decay = 0.01 75 | schedule_free = true 76 | warmup_steps = 600 77 | 78 | # losses 79 | [train.mssim_opt] 80 | 
type = "mssim_loss" 81 | loss_weight = 1.0 82 | 83 | [train.consistency_opt] 84 | type = "consistency_loss" 85 | loss_weight = 1.0 86 | 87 | [train.ldl_opt] 88 | type = "ldl_loss" 89 | loss_weight = 1.0 90 | 91 | [train.fdl_opt] 92 | type = "fdl_loss" 93 | model = "dinov2" # "vgg", "resnet", "effnet" 94 | loss_weight = 0.75 95 | 96 | [train.gan_opt] 97 | type = "gan_loss" 98 | gan_type = "bce" 99 | loss_weight = 0.3 100 | 101 | #[train.msswd_opt] 102 | #type = "msswd_loss" 103 | #loss_weight = 1.0 104 | 105 | #[train.perceptual_opt] 106 | #type = "vgg_perceptual_loss" 107 | #loss_weight = 0.5 108 | #criterion = "huber" 109 | ##patchloss = true 110 | ##ipk = true 111 | ##patch_weight = 1.0 112 | 113 | #[train.dists_opt] 114 | #type = "dists_loss" 115 | #loss_weight = 0.5 116 | 117 | #[train.ff_opt] 118 | #type = "ff_loss" 119 | #loss_weight = 0.35 120 | 121 | #[train.ncc_opt] 122 | #type = "ncc_loss" 123 | #loss_weight = 1.0 124 | 125 | #[train.kl_opt] 126 | #type = "kl_loss" 127 | #loss_weight = 1.0 128 | 129 | [logger] 130 | total_iter = 1000000 131 | save_checkpoint_freq = 1000 132 | use_tb_logger = true 133 | #save_tb_img = true 134 | #print_freq = 100 135 | -------------------------------------------------------------------------------- /options/train_rgt.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_rgt" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 
27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "rgt" 45 | #type = "rgt_s" 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 
| #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_safmn.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_safmn" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "safmn" 45 | #type = "safmn_l" 46 | #type = "light_safmnpp" 47 | 48 | [network_d] 49 | type = "metagan" 50 | 51 | [train] 52 | ema = 0.999 53 | wavelet_guided = true 54 | wavelet_init = 80000 55 | #sam = "fsam" 56 | #sam_init = 1000 57 | #eco = true 58 | #eco_init = 15000 59 | #match_lq_colors = true 60 | 61 | [train.optim_g] 62 | type = "adan_sf" 63 | lr = 1e-3 64 | betas = [ 0.98, 0.92, 0.99 ] 65 | weight_decay = 0.01 66 | 
schedule_free = true 67 | warmup_steps = 1600 68 | 69 | [train.optim_d] 70 | type = "adan_sf" 71 | lr = 1e-4 72 | betas = [ 0.98, 0.92, 0.99 ] 73 | weight_decay = 0.01 74 | schedule_free = true 75 | warmup_steps = 600 76 | 77 | # losses 78 | [train.mssim_opt] 79 | type = "mssim_loss" 80 | loss_weight = 1.0 81 | 82 | [train.consistency_opt] 83 | type = "consistency_loss" 84 | loss_weight = 1.0 85 | 86 | [train.ldl_opt] 87 | type = "ldl_loss" 88 | loss_weight = 1.0 89 | 90 | [train.fdl_opt] 91 | type = "fdl_loss" 92 | model = "dinov2" # "vgg", "resnet", "effnet" 93 | loss_weight = 0.75 94 | 95 | [train.gan_opt] 96 | type = "gan_loss" 97 | gan_type = "bce" 98 | loss_weight = 0.3 99 | 100 | #[train.msswd_opt] 101 | #type = "msswd_loss" 102 | #loss_weight = 1.0 103 | 104 | #[train.perceptual_opt] 105 | #type = "vgg_perceptual_loss" 106 | #loss_weight = 0.5 107 | #criterion = "huber" 108 | ##patchloss = true 109 | ##ipk = true 110 | ##patch_weight = 1.0 111 | 112 | #[train.dists_opt] 113 | #type = "dists_loss" 114 | #loss_weight = 0.5 115 | 116 | #[train.ff_opt] 117 | #type = "ff_loss" 118 | #loss_weight = 0.35 119 | 120 | #[train.ncc_opt] 121 | #type = "ncc_loss" 122 | #loss_weight = 1.0 123 | 124 | #[train.kl_opt] 125 | #type = "kl_loss" 126 | #loss_weight = 1.0 127 | 128 | [logger] 129 | total_iter = 1000000 130 | save_checkpoint_freq = 1000 131 | use_tb_logger = true 132 | #save_tb_img = true 133 | #print_freq = 100 134 | -------------------------------------------------------------------------------- /options/train_sebica.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_sebica" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | 
augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "sebica" 45 | #type = "sebica_mini" 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 5e-4 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | 
#criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_span.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_span" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "span" 45 | 46 | [network_d] 47 | type = "metagan" 48 | 49 | [train] 50 | ema = 0.999 51 | wavelet_guided = true 52 | wavelet_init = 80000 53 | #sam = "fsam" 
54 | #sam_init = 1000 55 | #eco = true 56 | #eco_init = 15000 57 | #match_lq_colors = true 58 | 59 | [train.optim_g] 60 | type = "adan_sf" 61 | lr = 1e-3 62 | betas = [ 0.98, 0.92, 0.99 ] 63 | weight_decay = 0.01 64 | schedule_free = true 65 | warmup_steps = 1600 66 | 67 | [train.optim_d] 68 | type = "adan_sf" 69 | lr = 1e-4 70 | betas = [ 0.98, 0.92, 0.99 ] 71 | weight_decay = 0.01 72 | schedule_free = true 73 | warmup_steps = 600 74 | 75 | # losses 76 | [train.mssim_opt] 77 | type = "mssim_loss" 78 | loss_weight = 1.0 79 | 80 | [train.consistency_opt] 81 | type = "consistency_loss" 82 | loss_weight = 1.0 83 | 84 | [train.ldl_opt] 85 | type = "ldl_loss" 86 | loss_weight = 1.0 87 | 88 | [train.fdl_opt] 89 | type = "fdl_loss" 90 | model = "dinov2" # "vgg", "resnet", "effnet" 91 | loss_weight = 0.75 92 | 93 | [train.gan_opt] 94 | type = "gan_loss" 95 | gan_type = "bce" 96 | loss_weight = 0.3 97 | 98 | #[train.msswd_opt] 99 | #type = "msswd_loss" 100 | #loss_weight = 1.0 101 | 102 | #[train.perceptual_opt] 103 | #type = "vgg_perceptual_loss" 104 | #loss_weight = 0.5 105 | #criterion = "huber" 106 | ##patchloss = true 107 | ##ipk = true 108 | ##patch_weight = 1.0 109 | 110 | #[train.dists_opt] 111 | #type = "dists_loss" 112 | #loss_weight = 0.5 113 | 114 | #[train.ff_opt] 115 | #type = "ff_loss" 116 | #loss_weight = 0.35 117 | 118 | #[train.ncc_opt] 119 | #type = "ncc_loss" 120 | #loss_weight = 1.0 121 | 122 | #[train.kl_opt] 123 | #type = "kl_loss" 124 | #loss_weight = 1.0 125 | 126 | [logger] 127 | total_iter = 1000000 128 | save_checkpoint_freq = 1000 129 | use_tb_logger = true 130 | #save_tb_img = true 131 | #print_freq = 100 132 | -------------------------------------------------------------------------------- /options/train_span_otf.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_span_otf" 3 | model_type = "otf" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = 
true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "otf" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | patch_size = 64 15 | batch_size = 8 16 | #accumulate = 1 17 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 18 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 19 | 20 | [degradations] 21 | resize_prob = [ 0.3, 0.4, 0.3 ] 22 | resize_range = [ 0.5, 1.5 ] 23 | gaussian_noise_prob = 0.2 24 | noise_range = [ 0, 2 ] 25 | poisson_scale_range = [ 0.05, 0.25 ] 26 | gray_noise_prob = 0.1 27 | jpeg_range = [ 40, 95 ] 28 | second_blur_prob = 0.4 29 | resize_prob2 = [ 0.3, 0.4, 0.3 ] 30 | resize_range2 = [ 0.3, 1.5 ] 31 | gaussian_noise_prob2 = 0.2 32 | noise_range2 = [ 0, 2 ] 33 | poisson_scale_range2 = [ 0.05, 0.1 ] 34 | gray_noise_prob2 = 0.1 35 | jpeg_range2 = [ 35, 95 ] 36 | 37 | blur_kernel_size = 7 38 | kernel_list = [ 39 | "iso", 40 | "aniso", 41 | "generalized_iso", 42 | "generalized_aniso", 43 | "plateau_iso", 44 | "plateau_aniso" 45 | ] 46 | kernel_prob = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 47 | sinc_prob = 0.1 48 | blur_sigma = [ 0.2, 3 ] 49 | betag_range = [ 0.5, 4 ] 50 | betap_range = [ 1, 2 ] 51 | blur_kernel_size2 = 9 52 | kernel_list2 = [ 53 | "iso", 54 | "aniso", 55 | "generalized_iso", 56 | "generalized_aniso", 57 | "plateau_iso", 58 | "plateau_aniso" 59 | ] 60 | kernel_prob2 = [ 0.45, 0.25, 0.12, 0.03, 0.12, 0.03 ] 61 | sinc_prob2 = 0.1 62 | blur_sigma2 = [ 0.2, 1.5 ] 63 | betag_range2 = [ 0.5, 4 ] 64 | betap_range2 = [ 1, 2 ] 65 | final_sinc_prob = 0.8 66 | 67 | [datasets.val] 68 | name = "val" 69 | type = "paired" 70 | dataroot_gt = 'C:\datasets\val\gt\' 71 | dataroot_lq = 'C:\datasets\val\lq\' 72 | [val] 73 | val_freq = 1000 74 | #tile = 200 75 | #[val.metrics.psnr] 76 | #type = "calculate_psnr" 77 | #[val.metrics.ssim] 78 | #type = "calculate_ssim" 79 | #[val.metrics.dists] 80 | #type = "calculate_dists" 81 | #better = "lower" 82 | #[val.metrics.topiq] 83 | #type = "calculate_topiq" 84 | 85 | [path] 86 | 
#pretrain_network_g = 'experiments\pretrain_g.pth' 87 | #pretrain_network_d = 'experiments\pretrain_d.pth' 88 | 89 | [network_g] 90 | type = "span" 91 | 92 | [network_d] 93 | type = "metagan" 94 | 95 | [train] 96 | ema = 0.999 97 | wavelet_guided = true 98 | wavelet_init = 80000 99 | #sam = "fsam" 100 | #sam_init = 1000 101 | #eco = true 102 | #eco_init = 15000 103 | #match_lq_colors = true 104 | 105 | [train.optim_g] 106 | type = "adan_sf" 107 | lr = 1e-3 108 | betas = [ 0.98, 0.92, 0.99 ] 109 | weight_decay = 0.01 110 | schedule_free = true 111 | warmup_steps = 1600 112 | 113 | [train.optim_d] 114 | type = "adan_sf" 115 | lr = 1e-4 116 | betas = [ 0.98, 0.92, 0.99 ] 117 | weight_decay = 0.01 118 | schedule_free = true 119 | warmup_steps = 600 120 | 121 | # losses 122 | [train.mssim_opt] 123 | type = "mssim_loss" 124 | loss_weight = 1.0 125 | 126 | [train.consistency_opt] 127 | type = "consistency_loss" 128 | loss_weight = 1.0 129 | 130 | [train.ldl_opt] 131 | type = "ldl_loss" 132 | loss_weight = 1.0 133 | 134 | [train.fdl_opt] 135 | type = "fdl_loss" 136 | model = "dinov2" # "vgg", "resnet", "effnet" 137 | loss_weight = 0.75 138 | 139 | [train.gan_opt] 140 | type = "gan_loss" 141 | gan_type = "bce" 142 | loss_weight = 0.3 143 | 144 | #[train.msswd_opt] 145 | #type = "msswd_loss" 146 | #loss_weight = 1.0 147 | 148 | #[train.perceptual_opt] 149 | #type = "vgg_perceptual_loss" 150 | #loss_weight = 0.5 151 | #criterion = "huber" 152 | ##patchloss = true 153 | ##ipk = true 154 | ##patch_weight = 1.0 155 | 156 | #[train.dists_opt] 157 | #type = "dists_loss" 158 | #loss_weight = 0.5 159 | 160 | #[train.ff_opt] 161 | #type = "ff_loss" 162 | #loss_weight = 0.35 163 | 164 | #[train.ncc_opt] 165 | #type = "ncc_loss" 166 | #loss_weight = 1.0 167 | 168 | #[train.kl_opt] 169 | #type = "kl_loss" 170 | #loss_weight = 1.0 171 | 172 | [logger] 173 | total_iter = 1000000 174 | save_checkpoint_freq = 1000 175 | use_tb_logger = true 176 | #save_tb_img = true 177 | #print_freq = 100 178 |
-------------------------------------------------------------------------------- /options/train_spanplus.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_spanplus" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 64 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "spanplus" 45 | #type = "spanplus_sts" 46 | #type = "spanplus_s" 47 | #type = "spanplus_st" 48 | 49 | [network_d] 50 | type = "metagan" 51 | 52 | [train] 53 | ema = 0.999 54 | wavelet_guided = true 55 | wavelet_init = 80000 56 | #sam = "fsam" 57 | #sam_init = 1000 58 | #eco = true 59 | #eco_init = 15000 60 | #match_lq_colors = true 61 | 62 | [train.optim_g] 63 | type = "adan_sf" 64 | lr = 1e-3 65 | betas = [ 0.98, 0.92, 0.99 ] 66 | weight_decay = 0.01 67 | schedule_free = true 68 | warmup_steps = 1600 69 | 70 | [train.optim_d] 71 | type = "adan_sf" 72 | lr = 1e-4 73 | betas = [ 0.98, 0.92, 0.99 ] 74 | weight_decay = 0.01 75 | schedule_free = true 76 | warmup_steps = 600 77 | 78 | # losses 79 | [train.mssim_opt] 80 
| type = "mssim_loss" 81 | loss_weight = 1.0 82 | 83 | [train.consistency_opt] 84 | type = "consistency_loss" 85 | loss_weight = 1.0 86 | 87 | [train.ldl_opt] 88 | type = "ldl_loss" 89 | loss_weight = 1.0 90 | 91 | [train.fdl_opt] 92 | type = "fdl_loss" 93 | model = "dinov2" # "vgg", "resnet", "effnet" 94 | loss_weight = 0.75 95 | 96 | [train.gan_opt] 97 | type = "gan_loss" 98 | gan_type = "bce" 99 | loss_weight = 0.3 100 | 101 | #[train.msswd_opt] 102 | #type = "msswd_loss" 103 | #loss_weight = 1.0 104 | 105 | #[train.perceptual_opt] 106 | #type = "vgg_perceptual_loss" 107 | #loss_weight = 0.5 108 | #criterion = "huber" 109 | ##patchloss = true 110 | ##ipk = true 111 | ##patch_weight = 1.0 112 | 113 | #[train.dists_opt] 114 | #type = "dists_loss" 115 | #loss_weight = 0.5 116 | 117 | #[train.ff_opt] 118 | #type = "ff_loss" 119 | #loss_weight = 0.35 120 | 121 | #[train.ncc_opt] 122 | #type = "ncc_loss" 123 | #loss_weight = 1.0 124 | 125 | #[train.kl_opt] 126 | #type = "kl_loss" 127 | #loss_weight = 1.0 128 | 129 | [logger] 130 | total_iter = 1000000 131 | save_checkpoint_freq = 1000 132 | use_tb_logger = true 133 | #save_tb_img = true 134 | #print_freq = 100 135 | -------------------------------------------------------------------------------- /options/train_srformer.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_srformer" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 
26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "srformer_light" 45 | #type = "srformer_medium" 46 | 47 | [network_d] 48 | type = "metagan" 49 | 50 | [train] 51 | ema = 0.999 52 | wavelet_guided = true 53 | wavelet_init = 80000 54 | #sam = "fsam" 55 | #sam_init = 1000 56 | #eco = true 57 | #eco_init = 15000 58 | #match_lq_colors = true 59 | 60 | [train.optim_g] 61 | type = "adan_sf" 62 | lr = 1e-3 63 | betas = [ 0.98, 0.92, 0.99 ] 64 | weight_decay = 0.01 65 | schedule_free = true 66 | warmup_steps = 1600 67 | 68 | [train.optim_d] 69 | type = "adan_sf" 70 | lr = 1e-4 71 | betas = [ 0.98, 0.92, 0.99 ] 72 | weight_decay = 0.01 73 | schedule_free = true 74 | warmup_steps = 600 75 | 76 | # losses 77 | [train.mssim_opt] 78 | type = "mssim_loss" 79 | loss_weight = 1.0 80 | 81 | [train.consistency_opt] 82 | type = "consistency_loss" 83 | loss_weight = 1.0 84 | 85 | [train.ldl_opt] 86 | type = "ldl_loss" 87 | loss_weight = 1.0 88 | 89 | [train.fdl_opt] 90 | type = "fdl_loss" 91 | model = "dinov2" # "vgg", "resnet", "effnet" 92 | loss_weight = 0.75 93 | 94 | [train.gan_opt] 95 | type = "gan_loss" 96 | gan_type = "bce" 97 | loss_weight = 0.3 98 | 99 | #[train.msswd_opt] 100 | #type = "msswd_loss" 101 | #loss_weight = 1.0 102 | 103 | #[train.perceptual_opt] 104 | #type = "vgg_perceptual_loss" 105 | #loss_weight = 0.5 106 | #criterion = "huber" 107 | ##patchloss = true 108 | ##ipk = true 109 | ##patch_weight = 1.0 110 | 111 | #[train.dists_opt] 112 | #type = "dists_loss" 113 | #loss_weight = 0.5 114 | 115 | #[train.ff_opt] 116 | #type = "ff_loss" 117 | #loss_weight = 0.35 
118 | 119 | #[train.ncc_opt] 120 | #type = "ncc_loss" 121 | #loss_weight = 1.0 122 | 123 | #[train.kl_opt] 124 | #type = "kl_loss" 125 | #loss_weight = 1.0 126 | 127 | [logger] 128 | total_iter = 1000000 129 | save_checkpoint_freq = 1000 130 | use_tb_logger = true 131 | #save_tb_img = true 132 | #print_freq = 100 133 | -------------------------------------------------------------------------------- /options/train_swinir.toml: -------------------------------------------------------------------------------- 1 | 2 | name = "train_swinir" 3 | model_type = "image" 4 | scale = 4 5 | use_amp = true 6 | bfloat16 = true 7 | fast_matmul = true 8 | #compile = true 9 | #manual_seed = 1024 10 | 11 | [datasets.train] 12 | type = "paired" 13 | dataroot_gt = 'C:\datasets\gt\' 14 | dataroot_lq = 'C:\datasets\lq\' 15 | patch_size = 32 16 | batch_size = 8 17 | #accumulate = 1 18 | augmentation = [ "none", "mixup", "cutmix", "resizemix", "cutblur" ] 19 | aug_prob = [ 0.5, 0.1, 0.1, 0.1, 0.5 ] 20 | 21 | [datasets.val] 22 | name = "val" 23 | type = "paired" 24 | dataroot_gt = 'C:\datasets\val\gt\' 25 | dataroot_lq = 'C:\datasets\val\lq\' 26 | [val] 27 | val_freq = 1000 28 | #tile = 200 29 | #[val.metrics.psnr] 30 | #type = "calculate_psnr" 31 | #[val.metrics.ssim] 32 | #type = "calculate_ssim" 33 | #[val.metrics.dists] 34 | #type = "calculate_dists" 35 | #better = "lower" 36 | #[val.metrics.topiq] 37 | #type = "calculate_topiq" 38 | 39 | [path] 40 | #pretrain_network_g = 'experiments\pretrain_g.pth' 41 | #pretrain_network_d = 'experiments\pretrain_d.pth' 42 | 43 | [network_g] 44 | type = "swinir_small" 45 | #type = "swinir_medium" 46 | #type = "swinir_large" 47 | #flash_attn = true 48 | 49 | [network_d] 50 | type = "metagan" 51 | 52 | [train] 53 | ema = 0.999 54 | wavelet_guided = true 55 | wavelet_init = 80000 56 | #sam = "fsam" 57 | #sam_init = 1000 58 | #eco = true 59 | #eco_init = 15000 60 | #match_lq_colors = true 61 | 62 | [train.optim_g] 63 | type = "adan_sf" 64 | lr = 1e-3 65 | 
betas = [ 0.98, 0.92, 0.99 ] 66 | weight_decay = 0.01 67 | schedule_free = true 68 | warmup_steps = 1600 69 | 70 | [train.optim_d] 71 | type = "adan_sf" 72 | lr = 1e-4 73 | betas = [ 0.98, 0.92, 0.995 ] 74 | weight_decay = 0.01 75 | schedule_free = true 76 | warmup_steps = 600 77 | 78 | # losses 79 | [train.mssim_opt] 80 | type = "mssim_loss" 81 | loss_weight = 1.0 82 | 83 | [train.consistency_opt] 84 | type = "consistency_loss" 85 | loss_weight = 1.0 86 | 87 | [train.ldl_opt] 88 | type = "ldl_loss" 89 | loss_weight = 1.0 90 | 91 | [train.fdl_opt] 92 | type = "fdl_loss" 93 | model = "dinov2" # "vgg", "resnet", "effnet" 94 | loss_weight = 0.75 95 | 96 | [train.gan_opt] 97 | type = "gan_loss" 98 | gan_type = "bce" 99 | loss_weight = 0.3 100 | 101 | #[train.msswd_opt] 102 | #type = "msswd_loss" 103 | #loss_weight = 1.0 104 | 105 | #[train.perceptual_opt] 106 | #type = "vgg_perceptual_loss" 107 | #loss_weight = 0.5 108 | #criterion = "huber" 109 | ##patchloss = true 110 | ##ipk = true 111 | ##patch_weight = 1.0 112 | 113 | #[train.dists_opt] 114 | #type = "dists_loss" 115 | #loss_weight = 0.5 116 | 117 | #[train.ff_opt] 118 | #type = "ff_loss" 119 | #loss_weight = 0.35 120 | 121 | #[train.ncc_opt] 122 | #type = "ncc_loss" 123 | #loss_weight = 1.0 124 | 125 | #[train.kl_opt] 126 | #type = "kl_loss" 127 | #loss_weight = 1.0 128 | 129 | [logger] 130 | total_iter = 1000000 131 | save_checkpoint_freq = 1000 132 | use_tb_logger = true 133 | #save_tb_img = true 134 | #print_freq = 100 135 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "neosr" 3 | version = "1.0.0" 4 | description = "neosr is an open-source framework for training super-resolution models. 
It provides a comprehensive and reproducible environment for achieving state-of-the-art image restoration results, making it suitable for community enthusiasts, professionals, and academic machine learning researchers alike. It serves as a versatile platform and aims to bridge the gap between practical application and academic research in the field." 5 | readme = "readme.md" 6 | requires-python = ">=3.12,<3.13" 7 | keywords = [ 8 | "neosr", "super-resolution", "machine-learning", "image-restoration" 9 | ] 10 | authors = [{ name = "neosr-project", email = "132400428+neosr-project@users.noreply.github.com" }] 11 | classifiers = [ 12 | "License :: OSI Approved :: Apache Software License", 13 | "Programming Language :: Python :: 3.12", 14 | ] 15 | 16 | dependencies = [ 17 | "einops>=0.8.1", 18 | "lmdb>=1.6.2", 19 | "numpy>=2.2.4", 20 | "onnx>=1.17.0", 21 | "onnxconverter-common>=1.14.0", 22 | "onnxruntime-gpu>=1.21.0", 23 | "opencv-python-headless>=4.11.0.86", 24 | "pywavelets>=1.8.0", 25 | "scipy>=1.15.2", 26 | "tb-nightly>=2.20.0a", 27 | "torch>=2.6", 28 | "torchvision>=0.21", 29 | "tqdm>=4.67.1", 30 | "triton>=3.2.0; sys_platform == 'linux'", 31 | ] 32 | 33 | [project.urls] 34 | Repository = "https://github.com/neosr-project/neosr" 35 | Documentation = "https://github.com/neosr-project/neosr/wiki" 36 | 37 | [tool.uv] 38 | package = false 39 | preview = true 40 | environments = [ 41 | "sys_platform == 'win32'", 42 | "sys_platform == 'linux'", 43 | ] 44 | 45 | [tool.uv.sources] 46 | torch = { index = "pytorch" } 47 | torchvision = { index = "pytorch" } 48 | 49 | [[tool.uv.index]] 50 | name = "pytorch" 51 | url = "https://download.pytorch.org/whl/cu126" 52 | explicit = true 53 | 54 | [tool.ruff] 55 | lint.select = ["ALL"] 56 | lint.fixable = ["ALL"] 57 | lint.ignore = [ 58 | "ANN", 59 | "B904", 60 | "C90", 61 | "COM812", 62 | "CPY", 63 | "D", 64 | "DOC", 65 | "ERA001", 66 | "E501", 67 | "E722", 68 | "E741", 69 | "FIX", 70 | "FBT001", 71 | "FBT002", 72 | "G004", 73 |
"ISC001", 74 | "N8", 75 | "PLR", 76 | "PLC0206", 77 | "PGH003", 78 | "S101", 79 | "S110", 80 | "S311", 81 | "S403", 82 | "SLF001", 83 | "T201", 84 | "TD", 85 | ] 86 | exclude = ["*_arch.py", ".venv/*"] 87 | preview = true 88 | 89 | [tool.ruff.format] 90 | skip-magic-trailing-comma = true 91 | line-ending = "lf" 92 | quote-style = "double" 93 | 94 | [tool.ruff.lint.isort] 95 | split-on-trailing-comma = false 96 | 97 | [tool.ruff.lint.per-file-ignores] 98 | "neosr/__init__.py" = ["F403"] 99 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from os import path as osp 4 | from pathlib import Path 5 | from time import time 6 | 7 | import torch 8 | 9 | from neosr.data import build_dataloader, build_dataset 10 | from neosr.models import build_model 11 | from neosr.utils import get_root_logger, get_time_str, make_exp_dirs, tc 12 | from neosr.utils.options import parse_options 13 | 14 | 15 | def test_pipeline(root_path: str) -> None: 16 | # parse options, set distributed setting, set ramdom seed 17 | opt, _ = parse_options(root_path, is_train=False) 18 | 19 | torch.set_default_device("cuda") 20 | torch.backends.cudnn.benchmark = True 21 | 22 | # mkdir and initialize loggers 23 | make_exp_dirs(opt) 24 | log_file = Path(opt["path"]["log"]) / f"test_{opt['name']}_{get_time_str()}.log" 25 | logger = get_root_logger( 26 | logger_name="neosr", log_level=logging.INFO, log_file=str(log_file) 27 | ) 28 | 29 | # create test dataset and dataloader 30 | test_loaders = [] 31 | for _, dataset_opt in sorted(opt["datasets"].items()): 32 | test_set = build_dataset(dataset_opt) 33 | num_gpu = opt.get("num_gpu", "auto") 34 | test_loader = build_dataloader( 35 | test_set, # type: ignore[reportArgumentType] 36 | dataset_opt, 37 | num_gpu=num_gpu, 38 | dist=opt["dist"], 39 | sampler=None, 40 | seed=opt["manual_seed"], 41 | ) 42 | 
logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") # type: ignore[reportArgumentType] 43 | test_loaders.append(test_loader) 44 | 45 | # create model 46 | model = build_model(opt) 47 | 48 | try: 49 | for test_loader in test_loaders: 50 | test_set_name = test_loader.dataset.opt["name"] # type: ignore[attr-defined] 51 | logger.info(f"Testing {test_set_name}...") 52 | start_time = time() 53 | model.validation( # type: ignore[reportAttributeAccessIssue,attr-defined] 54 | test_loader, 55 | current_iter=opt["name"], 56 | tb_logger=None, 57 | save_img=opt["val"].get("save_img", True), 58 | ) 59 | end_time = time() 60 | total_time = end_time - start_time 61 | n_img = len(test_loader.dataset) # type: ignore[arg-type] 62 | fps = n_img / total_time 63 | logger.info( 64 | f"{tc.light_green}Inference took {total_time:.2f} seconds, at {fps:.2f} fps.{tc.end}" 65 | ) 66 | except KeyboardInterrupt: 67 | logger.info(f"{tc.red}Interrupted.{tc.end}") 68 | sys.exit(0) 69 | 70 | 71 | if __name__ == "__main__": 72 | root_path = Path.resolve(Path(__file__) / osp.pardir) 73 | test_pipeline(str(root_path)) 74 | --------------------------------------------------------------------------------