├── .DS_Store ├── Enhancement ├── cal_metrics_with_imgs.py ├── eval.py ├── eval_uciqe_uiqm.py ├── evalv2_speedup.py ├── speed_test.py └── utils.py ├── LICENSE ├── Options ├── CG_UNet_LOLv1.yml ├── CG_UNet_LOLv2Real.yml ├── CG_UNet_LOLv2Syn.yml ├── CG_UNet_UIEB.yml ├── IE_UNet_LOLv1.yml ├── IE_UNet_LOLv2Real.yml ├── IE_UNet_LOLv2Syn.yml └── IE_UNet_UIEB.yml ├── README.md ├── VERSION ├── analysis ├── .DS_Store ├── erf.py ├── flops_param.py ├── model_zoo │ ├── HWMNet.py │ ├── LLFormer.py │ ├── RetinexFormer.py │ ├── UVMNet.py │ ├── edsr.py │ ├── hat.py │ ├── mambaIR.py │ ├── rcan.py │ └── swinIR.py ├── plot.py ├── show │ └── erf │ │ ├── erf.png │ │ ├── train1000_erf.png │ │ └── trained_erf.png └── util.py ├── assets ├── RetinexFormer_LOL_v2_real.png ├── RetinexFormer_LOL_v2_synthetic.png ├── RetinexFormer_NTIRE.png ├── RetinexFormer_SDSD_indoor.png ├── RetinexFormer_SDSD_outdoor.png ├── RetinexFormer_SID.png ├── RetinexFormer_SMID.png ├── clip_bright.gif ├── clip_default.gif ├── clip_noise.gif ├── clip_quality.gif ├── dnnvsbnn.png ├── input_demo.png ├── pipeline.png ├── pipelinev2.png ├── pred_process.png ├── result_lol.png ├── result_uie.png ├── result_upaired5sets.png ├── twostagev2.png ├── vis_5sets.png ├── vis_hd.png └── vis_uie.png ├── basicsr ├── __init__.py ├── archs │ ├── UMamba_arch.py │ ├── UMambav2_arch.py │ ├── UTransformerv2_arch.py │ ├── __init__.py │ ├── arch_util.py │ └── vgg_arch.py ├── bayesian │ ├── __init__.py │ ├── base_layer.py │ ├── conv.py │ ├── linear.py │ ├── norm.py │ └── tools.py ├── data │ ├── SID_image_dataset.py │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── 
meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── realesrgan_dataset.py │ ├── realesrgan_paired_dataset.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── losses │ ├── __init__.py │ ├── basic_loss.py │ ├── gan_loss.py │ ├── loss_util.py │ └── my_loss.py ├── metrics │ ├── README.md │ ├── README_CN.md │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ ├── psnr_ssim.py │ ├── test_metrics │ │ └── test_psnr_ssim.py │ └── uciqe_uiqm.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── condition_generator_model.py │ ├── condition_generatorv2_model.py │ ├── image_enhancer_model.py │ ├── image_enhancerv2_model.py │ └── lr_scheduler.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── gaussian_downsample.py │ ├── histogram.py │ ├── hog.py │ ├── img_process_util.py │ ├── img_util.py │ ├── labelnoise.py │ ├── lmdb_util.py │ ├── logger.py │ ├── mask.py │ ├── matlab_functions.py │ ├── misc.py │ ├── mixing_augment.py │ ├── noise_cal.py │ ├── options.py │ ├── plot_util.py │ ├── poisson_gaussian.py │ └── registry.py ├── version.py └── vmamba │ ├── __init__.py │ ├── config.py │ ├── configs │ ├── vssm │ │ ├── vmambav0_base_224.yaml │ │ ├── 
vmambav0_small_224.yaml │ │ ├── vmambav0_tiny_224.yaml │ │ ├── vmambav2_base_224.yaml │ │ ├── vmambav2_small_224.yaml │ │ ├── vmambav2_tiny_224.yaml │ │ ├── vmambav2v_base_224.yaml │ │ ├── vmambav2v_small_224.yaml │ │ └── vmambav2v_tiny_224.yaml │ ├── vssmab │ │ ├── vmambav0_tiny_224_a0.yaml │ │ ├── vmambav0_tiny_224_a01.yaml │ │ ├── vmambav0_tiny_224_a0seq.yaml │ │ ├── vmambav0_tiny_224_a1.yaml │ │ ├── vmambav0_tiny_224_a2.yaml │ │ ├── vmambav0_tiny_224_a3.yaml │ │ ├── vmambav0_tiny_224_a7.yaml │ │ ├── vmambav0_tiny_224_a8.yaml │ │ ├── vmambav2_tiny_224_a9d.yaml │ │ ├── vmambav2_tiny_224_bidi.yaml │ │ ├── vmambav2_tiny_224_bidi_ndw.yaml │ │ ├── vmambav2_tiny_224_cas2d.yaml │ │ ├── vmambav2_tiny_224_cas2d_ndw.yaml │ │ ├── vmambav2_tiny_224_ds16.yaml │ │ ├── vmambav2_tiny_224_ds2.yaml │ │ ├── vmambav2_tiny_224_ds4.yaml │ │ ├── vmambav2_tiny_224_ds8.yaml │ │ ├── vmambav2_tiny_224_gelu.yaml │ │ ├── vmambav2_tiny_224_init1.yaml │ │ ├── vmambav2_tiny_224_init2.yaml │ │ ├── vmambav2_tiny_224_m2s2h.yaml │ │ ├── vmambav2_tiny_224_m3s1h.yaml │ │ ├── vmambav2_tiny_224_ndw.yaml │ │ ├── vmambav2_tiny_224_ondw.yaml │ │ ├── vmambav2_tiny_224_onone.yaml │ │ ├── vmambav2_tiny_224_onsoftmax.yaml │ │ ├── vmambav2_tiny_224_posndw.yaml │ │ ├── vmambav2_tiny_224_relu.yaml │ │ ├── vmambav2_tiny_224_sr1hl5.yaml │ │ ├── vmambav2_tiny_224_sr1l5.yaml │ │ ├── vmambav2_tiny_224_unidi.yaml │ │ └── vmambav2_tiny_224_unidi_ndw.yaml │ └── wasted │ │ ├── vssm01 │ │ ├── vmambav2_tiny_224.yaml │ │ ├── vssm_base_224_a0.yaml │ │ ├── vssm_base_224_a6.yaml │ │ ├── vssm_base_224_aav1.yaml │ │ ├── vssm_base_224_ahv1_0423.yaml │ │ ├── vssm_base_224_ahv3.yaml │ │ ├── vssm_small_224_a0.yaml │ │ ├── vssm_small_224_a6.yaml │ │ ├── vssm_small_224_aav1.yaml │ │ ├── vssm_small_224_ahv3.yaml │ │ ├── vssm_tiny_224_a9v1.yaml │ │ ├── vssm_tiny_224_a9v2.yaml │ │ ├── vssm_tiny_224_a9v3.yaml │ │ ├── vssm_tiny_224_aaa.yaml │ │ ├── vssm_tiny_224_aav1.yaml │ │ ├── vssm_tiny_224_aav2.yaml │ │ ├── vssm_tiny_224_abv2.yaml │ │ 
├── vssm_tiny_224_abv3.yaml │ │ ├── vssm_tiny_224_abv4.yaml │ │ ├── vssm_tiny_224_aca.yaml │ │ ├── vssm_tiny_224_acv1.yaml │ │ ├── vssm_tiny_224_acv1_61.yaml │ │ ├── vssm_tiny_224_acv1_66.yaml │ │ ├── vssm_tiny_224_acv1_67.yaml │ │ ├── vssm_tiny_224_acv1_68.yaml │ │ ├── vssm_tiny_224_acv2.yaml │ │ ├── vssm_tiny_224_acv3.yaml │ │ ├── vssm_tiny_224_acv4.yaml │ │ ├── vssm_tiny_224_adv1_mini.yaml │ │ ├── vssm_tiny_224_adv1_mini2.yaml │ │ ├── vssm_tiny_224_ahv3_0420.yaml │ │ └── vssm_tiny_224_aiv1.yaml │ │ ├── vssm1 │ │ ├── vssm_base_224.yaml │ │ ├── vssm_mini_224.yaml │ │ ├── vssm_small_224.yaml │ │ ├── vssm_tiny_224.yaml │ │ └── vssm_tiny_224_0220.yaml │ │ ├── vssm_base_224_ahv1.yaml │ │ ├── vssm_base_224_ahv1_0421.yaml │ │ ├── vssm_base_224_ahv1_0422.yaml │ │ ├── vssm_base_224_aiv1.yaml │ │ ├── vssm_base_224_aiv1_dp06.yaml │ │ ├── vssm_small_224_ahv1.yaml │ │ ├── vssm_small_224_ahv1_0421.yaml │ │ ├── vssm_small_224_ahv1_0422.yaml │ │ ├── vssm_small_224_aiv1.yaml │ │ ├── vssm_small_224_aiv1_dp04.yaml │ │ ├── vssm_tiny_224_0211.yaml │ │ ├── vssm_tiny_224_0211v1.yaml │ │ ├── vssm_tiny_224_0212.yaml │ │ ├── vssm_tiny_224_0213.yaml │ │ ├── vssm_tiny_224_0215.yaml │ │ ├── vssm_tiny_224_0216.yaml │ │ ├── vssm_tiny_224_0217.yaml │ │ ├── vssm_tiny_224_0218.yaml │ │ ├── vssm_tiny_224_0219.yaml │ │ ├── vssm_tiny_224_0221.yaml │ │ ├── vssm_tiny_224_0222.yaml │ │ ├── vssm_tiny_224_0223.yaml │ │ ├── vssm_tiny_224_0224.yaml │ │ ├── vssm_tiny_224_0225.yaml │ │ ├── vssm_tiny_224_0229.yaml │ │ ├── vssm_tiny_224_0229flex.yaml │ │ ├── vssm_tiny_224_0230.yaml │ │ ├── vssm_tiny_224_0230ab1d.yaml │ │ ├── vssm_tiny_224_0230ab2d.yaml │ │ ├── vssm_tiny_224_0309.yaml │ │ ├── vssm_tiny_224_0310.yaml │ │ ├── vssm_tiny_224_0311.yaml │ │ ├── vssm_tiny_224_0312.yaml │ │ ├── vssm_tiny_224_0313.yaml │ │ ├── vssm_tiny_224_0314.yaml │ │ ├── vssm_tiny_224_0315.yaml │ │ ├── vssm_tiny_224_0316.yaml │ │ ├── vssm_tiny_224_0317.yaml │ │ ├── vssm_tiny_224_0318.2.yaml │ │ ├── vssm_tiny_224_0318.yaml │ │ ├── 
vssm_tiny_224_0319.yaml │ │ ├── vssm_tiny_224_0320.yaml │ │ ├── vssm_tiny_224_0321.yaml │ │ ├── vssm_tiny_224_0322.yaml │ │ ├── vssm_tiny_224_0323.yaml │ │ ├── vssm_tiny_224_0324.yaml │ │ ├── vssm_tiny_224_0325.yaml │ │ ├── vssm_tiny_224_0326.yaml │ │ ├── vssm_tiny_224_0327.yaml │ │ ├── vssm_tiny_224_1.yaml │ │ ├── vssm_tiny_224_1v1.yaml │ │ ├── vssm_tiny_224_a8d.yaml │ │ ├── vssm_tiny_224_a9.yaml │ │ ├── vssm_tiny_224_a9a.yaml │ │ ├── vssm_tiny_224_aa.yaml │ │ ├── vssm_tiny_224_abv1.yaml │ │ ├── vssm_tiny_224_acb.yaml │ │ ├── vssm_tiny_224_acv1_0401.yaml │ │ ├── vssm_tiny_224_acv1_0403.yaml │ │ ├── vssm_tiny_224_acv1_0405.yaml │ │ ├── vssm_tiny_224_acv1_0406.yaml │ │ ├── vssm_tiny_224_acv1_0407.yaml │ │ ├── vssm_tiny_224_acv1_0408.yaml │ │ ├── vssm_tiny_224_acv1_0409.yaml │ │ ├── vssm_tiny_224_acv1_0410.yaml │ │ ├── vssm_tiny_224_acv1_6.yaml │ │ ├── vssm_tiny_224_acv1_62.yaml │ │ ├── vssm_tiny_224_acv1_62_0415.yaml │ │ ├── vssm_tiny_224_acv1_63.yaml │ │ ├── vssm_tiny_224_acv1_64.yaml │ │ ├── vssm_tiny_224_acv1_65.yaml │ │ ├── vssm_tiny_224_adv1.yaml │ │ ├── vssm_tiny_224_adv1c.yaml │ │ ├── vssm_tiny_224_aev1.yaml │ │ ├── vssm_tiny_224_aev1c.yaml │ │ ├── vssm_tiny_224_afv1.yaml │ │ ├── vssm_tiny_224_agv1.yaml │ │ ├── vssm_tiny_224_ahv1.yaml │ │ ├── vssm_tiny_224_ahv3.yaml │ │ └── vssm_tiny_224_ahv3_0418.yaml │ ├── data │ ├── __init__.py │ ├── build.py │ ├── cached_image_folder.py │ ├── data_simmim_ft.py │ ├── data_simmim_pt.py │ ├── imagenet22k_dataset.py │ ├── map22kto1k.txt │ ├── samplers.py │ └── zipreader.py │ ├── main.py │ ├── models │ ├── __init__.py │ ├── csm_triton.py │ ├── csms6s.py │ ├── mamba2 │ │ ├── __init__.py │ │ ├── k_activations.py │ │ ├── layer_norm.py │ │ ├── layernorm_gated.py │ │ ├── selective_state_update.py │ │ ├── ssd_bmm.py │ │ ├── ssd_chunk_scan.py │ │ ├── ssd_chunk_state.py │ │ ├── ssd_combined.py │ │ ├── ssd_minimal.py │ │ └── ssd_state_passing.py │ ├── vmamba.py │ ├── vmamba_checks.py │ └── vmamba_v02.py │ ├── readme.md │ ├── 
requirements.txt │ └── utils │ ├── cosine_lr.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── optimizer.py │ └── utils.py ├── kernels └── selective_scan │ ├── README.md │ ├── csrc │ └── selective_scan │ │ ├── cub_extra.cuh │ │ ├── cus │ │ ├── selective_scan.cpp │ │ ├── selective_scan_bwd_kernel.cuh │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ └── selective_scan_fwd_kernel.cuh │ │ ├── cusndstate │ │ ├── selective_scan_bwd_kernel_ndstate.cuh │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ ├── selective_scan_fwd_kernel_ndstate.cuh │ │ ├── selective_scan_ndstate.cpp │ │ └── selective_scan_ndstate.h │ │ ├── cusnrow │ │ ├── selective_scan_bwd_kernel_nrow.cuh │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_bwd2.cu │ │ ├── selective_scan_core_bwd3.cu │ │ ├── selective_scan_core_bwd4.cu │ │ ├── selective_scan_core_fwd.cu │ │ ├── selective_scan_core_fwd2.cu │ │ ├── selective_scan_core_fwd3.cu │ │ ├── selective_scan_core_fwd4.cu │ │ ├── selective_scan_fwd_kernel_nrow.cuh │ │ └── selective_scan_nrow.cpp │ │ ├── cusoflex │ │ ├── selective_scan_bwd_kernel_oflex.cuh │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ ├── selective_scan_fwd_kernel_oflex.cuh │ │ └── selective_scan_oflex.cpp │ │ ├── reverse_scan.cuh │ │ ├── selective_scan.h │ │ ├── selective_scan_common.h │ │ ├── static_switch.h │ │ └── uninitialized_copy.cuh │ ├── setup.py │ ├── test_selective_scan.py │ ├── test_selective_scan_easy.py │ └── test_selective_scan_speed.py ├── requirements.txt ├── setup.cfg └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/.DS_Store -------------------------------------------------------------------------------- /Enhancement/cal_metrics_with_imgs.py: -------------------------------------------------------------------------------- 1 | 
import os
import sys


class DummyFile(object):
    """Minimal stdout stand-in used to silence noisy third-party imports."""

    def write(self, x):
        pass

    def flush(self):
        # Some libraries call sys.stdout.flush(); the original class was
        # missing this method and could raise AttributeError.
        pass


# Temporarily silence stdout while importing chatty third-party packages.
sys.stdout = DummyFile()
import argparse
from tqdm import tqdm
import numpy as np
import warnings
warnings.filterwarnings("ignore")
import utils
from natsort import natsorted
from glob import glob
from skimage import img_as_ubyte
import lpips
from basicsr import calculate_niqe
sys.stdout = sys.__stdout__

parser = argparse.ArgumentParser(description='Image Enhancement')
parser.add_argument('--pred_dir', type=str, help='Dir of predicted images')
parser.add_argument('--target_dir', type=str, default='', help='Dir of targets')
parser.add_argument('--psnr', action='store_true', help='True to compute PSNR')
parser.add_argument('--ssim', action='store_true', help='True to compute SSIM')
parser.add_argument('--lpips', action='store_true', help='True to compute LPIPS')
parser.add_argument('--niqe', action='store_true', help='True to compute NIQE')

args = parser.parse_args()

# Image extensions accepted in both directories.
EXTENSIONS = ('*.png', '*.jpg', '*.bmp')


def _list_images(directory):
    """Return a naturally sorted list of image paths found in *directory*."""
    paths = []
    for pattern in EXTENSIONS:
        paths += glob(os.path.join(directory, pattern))
    return natsorted(paths)


pred_paths = _list_images(args.pred_dir)
has_targets = args.target_dir != ''
target_paths = _list_images(args.target_dir) if has_targets else []

# PSNR/SSIM/LPIPS are full-reference metrics: refuse to run them without
# targets (the original script crashed with a NameError on `target`).
if (args.psnr or args.ssim or args.lpips) and not has_targets:
    sys.exit('--psnr/--ssim/--lpips require --target_dir')
if has_targets and len(target_paths) != len(pred_paths):
    sys.exit('pred_dir and target_dir contain different numbers of images')

psnr = []
ssim = []
lpips_ = []
niqe = []

if args.lpips:
    loss_fn = lpips.LPIPS(net='alex', verbose=False)
    loss_fn.cuda()

for p_idx, pred_path in tqdm(enumerate(pred_paths), total=len(pred_paths)):
    # Images are loaded as RGB and normalized to [0, 1] floats.
    pred = np.float32(utils.load_img(pred_path)) / 255.
    if has_targets:
        target = np.float32(utils.load_img(target_paths[p_idx])) / 255.
        if args.psnr:
            psnr.append(utils.calculate_psnr(target, pred))
        if args.ssim:
            ssim.append(utils.calculate_ssim(img_as_ubyte(target), img_as_ubyte(pred)))
        if args.lpips:
            ex_p0 = lpips.im2tensor(img_as_ubyte(pred)).cuda()
            ex_ref = lpips.im2tensor(img_as_ubyte(target)).cuda()
            lpips_.append(loss_fn.forward(ex_ref, ex_p0).item())
    if args.niqe:
        # NIQE is no-reference, so it only needs the prediction.
        niqe.append(calculate_niqe(img_as_ubyte(pred), crop_border=0))

if args.psnr:
    print("Best_PSNR: {:.4f} dB".format(np.mean(np.array(psnr))))
if args.ssim:
    print("Best_SSIM: {:.4f}".format(np.mean(np.array(ssim))))
if args.lpips:
    print("Best_lpips: {:.4f}".format(np.mean(np.array(lpips_))))
if args.niqe:
    print("Best_NIQE: {:.4f}".format(np.mean(np.array(niqe))))
def calculate_psnr(img1, img2, data_range=1.):
    """Peak Signal-to-Noise Ratio between two images.

    Args:
        img1, img2: arrays of identical shape on the same intensity scale.
            Inputs are cast to float64 first, so integer (e.g. uint8)
            inputs no longer overflow in the squared difference.
        data_range: maximum possible pixel value — 1. for [0, 1] floats
            (the default, preserving the original behavior), 255. for
            [0, 255] images.

    Returns:
        PSNR in dB; 100 when the images are identical (kept instead of
        returning inf for backward compatibility).
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    mse_ = np.mean((img1 - img2) ** 2)
    if mse_ == 0:
        return 100
    return 10 * math.log10(data_range ** 2 / mse_)


def calculate_ssim(img1, img2, border=0):
    '''Calculate SSIM, producing the same outputs as MATLAB's implementation.

    img1, img2: [0, 255] arrays, grayscale (HxW) or HxWx1 / HxWx3.
    border: number of border pixels excluded on each side before scoring.
    Raises ValueError on shape mismatch or unsupported dimensionality.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            # Color image: average the per-channel SSIM scores.
            return np.mean([ssim(img1[:, :, i], img2[:, :, i]) for i in range(3)])
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """Single-channel SSIM matching MATLAB's reference implementation.

    img1, img2: 2-D arrays on a [0, 255] scale. Returns the mean SSIM over
    the 'valid' region (11x11 Gaussian window, sigma=1.5, borders cropped).
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def load_img(filepath):
    """Load an image from disk as an RGB ndarray.

    Raises FileNotFoundError when the file is missing or unreadable
    (cv2.imread silently returns None, which previously surfaced as a
    cryptic cvtColor error).
    """
    img = cv2.imread(filepath)
    if img is None:
        raise FileNotFoundError('Cannot read image: {}'.format(filepath))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def save_img(filepath, img):
    """Save an RGB ndarray to disk (converted back to BGR for OpenCV)."""
    cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def load_gray_img(filepath):
    """Load a grayscale image as an (H, W, 1) ndarray.

    Raises FileNotFoundError when the file is missing or unreadable.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError('Cannot read image: {}'.format(filepath))
    return np.expand_dims(img, axis=2)


def save_gray_img(filepath, img):
    """Save a single-channel image to disk."""
    cv2.imwrite(filepath, img)


def visualization(feature, save_path, type='max', colormap=cv2.COLORMAP_JET):
    '''Save a pseudo-color visualization of a feature map.

    :param feature: [C,H,W] tensor (CPU or CUDA)
    :param save_path: saving path
    :param type: 'mean' or 'max' channel reduction
    :param colormap: the type of the pseudocolor map
    '''
    feature = feature.cpu().numpy()
    if type == 'mean':
        feature = np.mean(feature, axis=0)
    else:
        feature = np.max(feature, axis=0)
    # Guard against a constant feature map: (max - min) == 0 previously
    # produced NaNs and an all-garbage uint8 image.
    feat_range = feature.max() - feature.min()
    if feat_range == 0:
        normed_feat = np.zeros_like(feature)
    else:
        normed_feat = (feature - feature.min()) / feat_range
    normed_feat = (normed_feat * 255).astype('uint8')
    color_feat = cv2.applyColorMap(normed_feat, colormap)
    cv2.imwrite(save_path, color_feat)
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anonymous1563 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Options/CG_UNet_LOLv1.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CG_UNet_LOLv1 3 | model_type: ConditionGenerator 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | condition: &condition 8 | type: mean # mean using downsampled image, histogram using histogram difference 9 | scale_down: 16 10 | weights_from_phaseI: ~ 11 | sigma_init: 0.05 12 | selective: false 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | train: 17 | name: TrainSet 18 | type: Dataset_PairedImage_Mask 19 | dataroot_gt: "/mnt/e/datasets/LOLv1/Train/target" 20 | dataroot_lq: "/mnt/e/datasets/LOLv1/Train/input" 21 | geometric_augs: true 22 | condition: *condition 23 | labelnoise: 24 | tem_mean: 1.0 25 | tem_var: 0.01 26 | bright_mean: 1.0 27 | bright_var: 0.01 28 | contrast_mean: 1. 29 | contrast_var: 0.01 30 | 31 | filename_tmpl: '{}' 32 | io_backend: 33 | type: disk 34 | 35 | # data loader 36 | use_shuffle: true 37 | num_worker_per_gpu: 8 38 | batch_size_per_gpu: 8 39 | 40 | ### ------- Training on single fixed-patch size--------- 41 | mini_batch_sizes: [8] 42 | iters: [300000] 43 | gt_size: 384 44 | gt_sizes: [384] 45 | ### ------------------------------------------------------------ 46 | 47 | dataset_enlarge_ratio: 1 48 | prefetch_mode: ~ 49 | 50 | val: 51 | name: ValSet 52 | type: Dataset_PairedImage_Mask 53 | dataroot_gt: "/mnt/e/datasets/LOLv1/Test/target" 54 | dataroot_lq: "/mnt/e/datasets/LOLv1/Test/input" 55 | condition: *condition 56 | io_backend: 57 | type: disk 58 | 59 | # network structures 60 | network_g: 61 | type: Network 62 | in_channels: 3 63 | out_channels: 3 64 | n_feat: 40 65 | d_state: [1,1,1] 66 | ssm_ratio: 1 67 | mlp_ratio: 4 68 | mlp_type: gdmlp 69 | use_pixelshuffle: true 70 | drop_path: 0. 
71 | stage: 1 72 | num_blocks: [2,2,2] 73 | 74 | # path 75 | path: 76 | pretrain_network_g: ~ 77 | strict_load_g: true 78 | resume_state: ~ 79 | 80 | # training settings 81 | train: 82 | total_iter: 300000 83 | warmup_iter: -1 # no warm up 84 | max_grad_norm: 1 85 | 86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [200000, 50000, 50000] 89 | restart_weights: [1, 1, 1] 90 | eta_mins: [0.0002, 0.0002, 0.000001] 91 | 92 | optim_g: 93 | type: Adam 94 | lr: 0.0002 95 | # weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | mixing_augs: 99 | mixup: false 100 | # mixup_beta: 1.2 101 | # use_identity: true 102 | 103 | 104 | 105 | pixel_opt: 106 | type: MSELoss 107 | loss_weight: 1 108 | reduction: mean 109 | 110 | 111 | # validation settings 112 | val: 113 | window_size: 4 114 | val_freq: !!float 1e3 115 | save_img: true 116 | rgb2bgr: true 117 | use_image: true 118 | max_minibatch: 8 119 | 120 | metrics: 121 | psnr: # metric name, can be arbitrary 122 | type: calculate_psnr 123 | crop_border: 0 124 | test_y_channel: false 125 | # ssim: 126 | # type: calculate_ssim 127 | # crop_border: 0 128 | # test_y_channel: false 129 | 130 | # logging settings 131 | logger: 132 | print_freq: 100 133 | save_checkpoint_freq: !!float 1e3 134 | use_tb_logger: true 135 | record_grad: false 136 | wandb: 137 | project: underwater 138 | resume_id: ~ 139 | 140 | # dist training settings 141 | dist_params: 142 | backend: nccl 143 | port: 29500 144 | -------------------------------------------------------------------------------- /Options/CG_UNet_LOLv2Syn.yml: -------------------------------------------------------------------------------- 1 | name: CG_UNet_LOLv2Syn 2 | model_type: ConditionGenerator 3 | scale: 1 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 100 6 | condition: &condition 7 | type: mean # mean using downsampled image, histogram using histogram difference 8 | scale_down: 16 9 | weights_from_phaseI: ~ 10 | sigma_init: 0.05 11 | 
selective: false 12 | 13 | # dataset and data loader settings 14 | datasets: 15 | train: 16 | name: TrainSet 17 | type: Dataset_PairedImage_Mask 18 | dataroot_gt: /mnt/e/datasets/LOLv2/Synthetic/Train/Normal 19 | dataroot_lq: /mnt/e/datasets/LOLv2/Synthetic/Train/Low 20 | geometric_augs: true 21 | condition: *condition 22 | labelnoise: 23 | tem_mean: 1.0 24 | tem_var: 0.01 25 | bright_mean: 1.0 26 | bright_var: 0.01 27 | contrast_mean: 1. 28 | contrast_var: 0.01 29 | 30 | filename_tmpl: '{}' 31 | io_backend: 32 | type: disk 33 | 34 | # data loader 35 | use_shuffle: true 36 | num_worker_per_gpu: 8 37 | batch_size_per_gpu: 8 38 | 39 | ### ------- Training on single fixed-patch size--------- 40 | mini_batch_sizes: [8] 41 | iters: [300000] 42 | gt_size: 384 43 | gt_sizes: [384] 44 | ### ------------------------------------------------------------ 45 | 46 | dataset_enlarge_ratio: 1 47 | prefetch_mode: ~ 48 | 49 | val: 50 | name: ValSet 51 | type: Dataset_PairedImage_Mask 52 | dataroot_gt: /mnt/e/datasets/LOLv2/Synthetic/Test/Normal 53 | dataroot_lq: /mnt/e/datasets/LOLv2/Synthetic/Test/Low 54 | condition: *condition 55 | io_backend: 56 | type: disk 57 | 58 | # network structures 59 | network_g: 60 | type: Network 61 | in_channels: 3 62 | out_channels: 3 63 | n_feat: 40 64 | d_state: [1,1,1] 65 | ssm_ratio: 1 66 | mlp_ratio: 4 67 | mlp_type: gdmlp 68 | use_pixelshuffle: true 69 | drop_path: 0. 
70 | stage: 1 71 | num_blocks: [2,2,2] 72 | 73 | # path 74 | path: 75 | pretrain_network_g: ~ 76 | strict_load_g: true 77 | resume_state: ~ 78 | 79 | # training settings 80 | train: 81 | total_iter: 300000 82 | warmup_iter: -1 # no warm up 83 | max_grad_norm: 1 84 | 85 | scheduler: 86 | type: CosineAnnealingRestartCyclicLR 87 | periods: [150000, 46000, 104000] 88 | restart_weights: [1, 1, 1] 89 | eta_mins: [0.0002, 0.0002, 0.000001] 90 | 91 | optim_g: 92 | type: Adam 93 | lr: 0.0002 94 | # weight_decay: !!float 1e-4 95 | betas: [0.9, 0.999] 96 | 97 | mixing_augs: 98 | mixup: false 99 | # mixup_beta: 1.2 100 | # use_identity: true 101 | 102 | pixel_opt: 103 | type: MSELoss 104 | loss_weight: 1 105 | reduction: mean 106 | 107 | 108 | # validation settings 109 | val: 110 | window_size: 4 111 | val_freq: !!float 1e3 112 | save_img: true 113 | rgb2bgr: true 114 | use_image: true 115 | max_minibatch: 8 116 | 117 | metrics: 118 | psnr: # metric name, can be arbitrary 119 | type: calculate_psnr 120 | crop_border: 0 121 | test_y_channel: false 122 | # ssim: 123 | # type: calculate_ssim 124 | # crop_border: 0 125 | # test_y_channel: false 126 | 127 | # logging settings 128 | logger: 129 | print_freq: 100 130 | save_checkpoint_freq: !!float 1e3 131 | use_tb_logger: true 132 | record_grad: false 133 | wandb: 134 | project: underwater 135 | resume_id: ~ 136 | 137 | # dist training settings 138 | dist_params: 139 | backend: nccl 140 | port: 29500 141 | -------------------------------------------------------------------------------- /Options/CG_UNet_UIEB.yml: -------------------------------------------------------------------------------- 1 | 2 | # general settings 3 | name: CG_UNet_UIEB 4 | model_type: ConditionGenerator 5 | scale: 1 6 | num_gpu: 1 # set num_gpu: 0 for cpu mode 7 | manual_seed: 100 8 | condition: &condition 9 | type: mean # mean using downsampled image, histogram using histogram difference 10 | scale_down: 16 11 | weights_from_phaseI: ~ 12 | sigma_init: 0.05 13 
| 14 | # dataset and data loader settings 15 | datasets: 16 | train: 17 | name: TrainSet 18 | type: Dataset_PairedImage_Mask 19 | dataroot_gt: /mnt/e/datasets/UIEB/Train/gt 20 | dataroot_lq: /mnt/e/datasets/UIEB/Train/input 21 | geometric_augs: true 22 | condition: *condition 23 | labelnoise: 24 | tem_mean: 1.0 25 | tem_var: 0.01 26 | bright_mean: 1.0 27 | bright_var: 0.01 28 | contrast_mean: 1. 29 | contrast_var: 0.01 30 | 31 | filename_tmpl: '{}' 32 | io_backend: 33 | type: disk 34 | 35 | # data loader 36 | use_shuffle: true 37 | num_worker_per_gpu: 8 38 | batch_size_per_gpu: 8 39 | 40 | ### ------- Training on single fixed-patch size--------- 41 | mini_batch_sizes: [8] 42 | iters: [300000] 43 | gt_size: 384 44 | gt_sizes: [384] 45 | ### ------------------------------------------------------------ 46 | 47 | dataset_enlarge_ratio: 1 48 | prefetch_mode: ~ 49 | 50 | val: 51 | name: ValSet 52 | type: Dataset_PairedImage_Mask 53 | dataroot_gt: /mnt/e/datasets/UIEB/Test/gt 54 | dataroot_lq: /mnt/e/datasets/UIEB/Test/input 55 | condition: *condition 56 | io_backend: 57 | type: disk 58 | 59 | # network structures 60 | network_g: 61 | type: Network 62 | in_channels: 3 63 | out_channels: 3 64 | n_feat: 40 65 | d_state: [1,1,1] 66 | ssm_ratio: 1 67 | mlp_ratio: 4 68 | mlp_type: gdmlp 69 | use_pixelshuffle: true 70 | drop_path: 0. 
71 | stage: 1 72 | num_blocks: [2,2,2] 73 | 74 | # path 75 | path: 76 | pretrain_network_g: ~ 77 | strict_load_g: true 78 | resume_state: ~ 79 | 80 | # training settings 81 | train: 82 | total_iter: 300000 83 | warmup_iter: -1 # no warm up 84 | max_grad_norm: 1 85 | 86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [150000, 46000, 104000] 89 | restart_weights: [1, 1, 1] 90 | eta_mins: [0.0002, 0.0002, 0.000001] 91 | 92 | optim_g: 93 | type: Adam 94 | lr: 0.0002 95 | # weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | mixing_augs: 99 | mixup: false 100 | # mixup_beta: 1.2 101 | # use_identity: true 102 | 103 | 104 | 105 | pixel_opt: 106 | type: MSELoss 107 | loss_weight: 1 108 | reduction: mean 109 | 110 | 111 | 112 | # validation settings 113 | val: 114 | window_size: 4 115 | val_freq: !!float 3e3 116 | save_img: true 117 | rgb2bgr: true 118 | use_image: true 119 | max_minibatch: 8 120 | 121 | metrics: 122 | psnr: # metric name, can be arbitrary 123 | type: calculate_psnr 124 | crop_border: 0 125 | test_y_channel: false 126 | # ssim: 127 | # type: calculate_ssim 128 | # crop_border: 0 129 | # test_y_channel: false 130 | 131 | # logging settings 132 | logger: 133 | print_freq: 100 134 | save_checkpoint_freq: !!float 3e3 135 | use_tb_logger: true 136 | record_grad: false 137 | wandb: 138 | project: underwater 139 | resume_id: ~ 140 | 141 | # dist training settings 142 | dist_params: 143 | backend: nccl 144 | port: 29500 145 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.2.0 2 | -------------------------------------------------------------------------------- /analysis/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/analysis/.DS_Store 
import basicsr.archs as archs
# Alternative baselines — uncomment the import and the matching builder
# below to profile a different model:
# from model_zoo.HWMNet import buildHWMNet
# from basicsr.archs.RetinexMamba_arch import RetinexMamba
# from model_zoo.RetinexFormer import buildRetinexFormer
# from model_zoo.UVMNet import buildUVMNet
# from model_zoo.LLFormer import buildLLFormer
from analysis.util import FLOPs
from analysis.util import Throughput
from basicsr.bayesian import convert2bnn, convert2bnn_selective
import logging
import torch
from torch.utils.data import Dataset, DataLoader


class RandomTensorDataset(Dataset):
    """Fixed-length dataset of random CUDA tensors.

    Used purely to feed the throughput benchmark below with dummy inputs
    of a given shape.
    """

    def __init__(self, size=(3, 256, 256), length=30):
        self.size = size
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Samples are created directly on the GPU, paired with a dummy label.
        return torch.randn(self.size).cuda(), 0


if __name__ == '__main__':

    H, W = 16, 16

    # Build the Bayesian U-Net and convert selected layers to BNN form.
    model = archs.BUNet_arch.build_model()
    bnn_config = {
        "sigma_init": 0.05,
        "decay": 0.998,
        "pretrain": False,
    }
    convert2bnn_selective(model, bnn_config)
    # model = archs.UNet_arch.build_model()
    # model = buildHWMNet()
    # model = RetinexMamba(in_channels=3, out_channels=3, n_feat=40, stage=1, num_blocks=[1,2,2])
    # model = buildRetinexFormer()
    # model = buildUVMNet()
    # model = buildLLFormer()

    model.cuda()

    # Report FLOPs (fvcore) and parameter count for a single HxW input.
    n_param, flops = FLOPs.fvcore_flop_count(model, input_shape=(1, 3, H, W), verbose=False)
    print(f'FLOPs:{flops:.3f}G')
    print(f'Params:{n_param/(1000*1000):.3f}M')

    # Measure throughput over a small stream of random inputs.
    dataset = RandomTensorDataset(size=(3, H, W), length=30)
    dataloader = DataLoader(dataset, batch_size=1)
    Throughput.throughput(data_loader=dataloader, model=model, logger=logging)
-------------------------------------------------------------------------------- /analysis/model_zoo/edsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | import sys 4 | sys.path.append('..') 5 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | 9 | 10 | class EDSR(nn.Module): 11 | """EDSR network structure. 12 | 13 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 14 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 15 | 16 | Args: 17 | num_in_ch (int): Channel number of inputs. 18 | num_out_ch (int): Channel number of outputs. 19 | num_feat (int): Channel number of intermediate features. 20 | Default: 64. 21 | num_block (int): Block number in the trunk network. Default: 16. 22 | upscale (int): Upsampling factor. Support 2^n and 3. 23 | Default: 4. 24 | res_scale (float): Used to scale the residual in residual block. 25 | Default: 1. 26 | img_range (float): Image range. Default: 255. 27 | rgb_mean (tuple[float]): Image mean in RGB orders. 28 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
29 | """ 30 | 31 | def __init__(self, 32 | num_in_ch=3, 33 | num_out_ch=3, 34 | num_feat=64, 35 | num_block=16, 36 | upscale=2, 37 | res_scale=1, 38 | img_range=255., 39 | rgb_mean=(0.4488, 0.4371, 0.4040)): 40 | super(EDSR, self).__init__() 41 | 42 | self.img_range = img_range 43 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 44 | 45 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 46 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 47 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 48 | self.upsample = Upsample(upscale, num_feat) 49 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 50 | 51 | def forward(self, x): 52 | self.mean = self.mean.type_as(x) 53 | 54 | x = (x - self.mean) * self.img_range 55 | x = self.conv_first(x) 56 | res = self.conv_after_body(self.body(x)) 57 | res += x 58 | 59 | x = self.conv_last(self.upsample(res)) 60 | x = x / self.img_range + self.mean 61 | 62 | return x 63 | 64 | 65 | def buildEDSR(): 66 | return EDSR() 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /analysis/show/erf/erf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/analysis/show/erf/erf.png -------------------------------------------------------------------------------- /analysis/show/erf/train1000_erf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/analysis/show/erf/train1000_erf.png -------------------------------------------------------------------------------- /analysis/show/erf/trained_erf.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/analysis/show/erf/trained_erf.png -------------------------------------------------------------------------------- /assets/RetinexFormer_LOL_v2_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_LOL_v2_real.png -------------------------------------------------------------------------------- /assets/RetinexFormer_LOL_v2_synthetic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_LOL_v2_synthetic.png -------------------------------------------------------------------------------- /assets/RetinexFormer_NTIRE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_NTIRE.png -------------------------------------------------------------------------------- /assets/RetinexFormer_SDSD_indoor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_SDSD_indoor.png -------------------------------------------------------------------------------- /assets/RetinexFormer_SDSD_outdoor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_SDSD_outdoor.png -------------------------------------------------------------------------------- /assets/RetinexFormer_SID.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_SID.png -------------------------------------------------------------------------------- /assets/RetinexFormer_SMID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/RetinexFormer_SMID.png -------------------------------------------------------------------------------- /assets/clip_bright.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/clip_bright.gif -------------------------------------------------------------------------------- /assets/clip_default.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/clip_default.gif -------------------------------------------------------------------------------- /assets/clip_noise.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/clip_noise.gif -------------------------------------------------------------------------------- /assets/clip_quality.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/clip_quality.gif -------------------------------------------------------------------------------- /assets/dnnvsbnn.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/dnnvsbnn.png -------------------------------------------------------------------------------- /assets/input_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/input_demo.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/pipeline.png -------------------------------------------------------------------------------- /assets/pipelinev2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/pipelinev2.png -------------------------------------------------------------------------------- /assets/pred_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/pred_process.png -------------------------------------------------------------------------------- /assets/result_lol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/result_lol.png -------------------------------------------------------------------------------- /assets/result_uie.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/result_uie.png -------------------------------------------------------------------------------- /assets/result_upaired5sets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/result_upaired5sets.png -------------------------------------------------------------------------------- /assets/twostagev2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/twostagev2.png -------------------------------------------------------------------------------- /assets/vis_5sets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/vis_5sets.png -------------------------------------------------------------------------------- /assets/vis_hd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/vis_hd.png -------------------------------------------------------------------------------- /assets/vis_uie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/assets/vis_uie.png -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | 
from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | # from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /basicsr/bayesian/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_layer import * 2 | from .conv import * 3 | from .norm import * 4 | from .linear import * 5 | from .tools import * 6 | -------------------------------------------------------------------------------- /basicsr/bayesian/base_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.distributions as distributions 4 | from itertools import repeat 5 | import collections 6 | from abc import abstractmethod 7 | 8 | class BaseLayer_(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, input): 13 | if not self.deterministic: 14 | return self._forward_uncertain(input) 15 | else: 16 | return self._forward_det(input) 17 | 18 | @abstractmethod 19 | def _forward_uncertain(self, input): 20 | pass 21 | 22 | @abstractmethod 23 | def _forward_det(self, input): 24 | pass 25 | 26 | def kl_div(self, mu_q, sigma_q, mu_p, sigma_p): 27 | """ 28 | Calculates kl divergence between two gaussians (Q || P) 29 | 30 | Parameters: 31 | * mu_q: torch.Tensor -> mu parameter of distribution Q 32 | * sigma_q: torch.Tensor -> sigma parameter of distribution Q 33 | * mu_p: float -> mu parameter of distribution P 34 | * sigma_p: float -> sigma parameter of distribution P 35 | 36 | returns torch.Tensor of shape 0 37 | """ 38 | kl = torch.log(sigma_p) - torch.log(sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5 39 | return kl.mean() 40 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. 
Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return 
len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import 
normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, 
float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 
25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/my_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | from pytorch_msssim import ms_ssim 6 | from pytorch_msssim import ssim 7 | import torchvision.transforms as T 8 | 9 | class VGGPerceptualLoss(nn.Module): 10 | def __init__(self, device): 11 | super(VGGPerceptualLoss, self).__init__() 12 | vgg = models.vgg19(weights=True).features[:16] # Until block3_conv3 13 | self.loss_model = vgg.to(device).eval() 14 | for param in self.loss_model.parameters(): 15 | param.requires_grad = False 16 | 17 | def forward(self, y_true, y_pred): 18 | y_true, y_pred = y_true.to(next(self.loss_model.parameters()).device), y_pred.to(next(self.loss_model.parameters()).device) 19 | return F.mse_loss(self.loss_model(y_true), self.loss_model(y_pred)) 20 | 21 | 22 | def color_loss(y_true, y_pred): 23 | return torch.mean(torch.abs(torch.mean(y_true, dim=[1, 2, 3]) - torch.mean(y_pred, dim=[1, 2, 3]))) 24 | 25 | def psnr_loss(y_true, y_pred): 26 | mse = F.mse_loss(y_true, y_pred) 27 | psnr = 20 * torch.log10(1.0 / torch.sqrt(mse)) 28 | return 40.0 - torch.mean(psnr) 29 | 30 | def smooth_l1_loss(y_true, y_pred): 31 | return F.smooth_l1_loss(y_true, y_pred) 32 | 33 | def multiscale_ssim_loss(y_true, y_pred, max_val=1.0, power_factors=[0.5, 0.5]): 34 | return 1.0 - ms_ssim(y_true, y_pred, data_range=max_val, size_average=True) 35 | 36 | 37 | def ssim_loss(y_true, y_pred, max_val=1.0, power_factors=[0.5, 0.5]): 38 | return 1.0 - ssim(y_true, y_pred, data_range=max_val, size_average=True) 39 | 40 | def 
histogram_loss(y_true, y_pred, bins=256): 41 | y_true_hist = torch.histc(y_true, bins=bins, min=0.0, max=1.0) 42 | y_pred_hist = torch.histc(y_pred, bins=bins, min=0.0, max=1.0) 43 | 44 | y_true_hist = y_true_hist / y_true_hist.sum() 45 | y_pred_hist = y_pred_hist / y_pred_hist.sum() 46 | 47 | hist_distance = torch.mean(torch.abs(y_true_hist - y_pred_hist)) 48 | 49 | return hist_distance 50 | 51 | class CombinedLoss(nn.Module): 52 | def __init__(self, device): 53 | super(CombinedLoss, self).__init__() 54 | self.perceptual_loss_model = VGGPerceptualLoss(device) 55 | self.alpha1 = 1.00 56 | self.alpha2 = 0.06 57 | self.alpha3 = 0.05 58 | self.alpha4 = 0.5 59 | self.alpha5 = 0.0083 60 | self.alpha6 = 0.25 61 | 62 | def forward(self, y_true, y_pred): 63 | smooth_l1_l = smooth_l1_loss(y_true, y_pred) 64 | ms_ssim_l = ssim_loss(y_true, y_pred) 65 | perc_l = self.perceptual_loss_model(y_true, y_pred) 66 | hist_l = histogram_loss(y_true, y_pred) 67 | psnr_l = psnr_loss(y_true, y_pred) 68 | color_l = color_loss(y_true, y_pred) 69 | 70 | total_loss = (self.alpha1 * smooth_l1_l + self.alpha2 * perc_l + 71 | self.alpha3 * hist_l + self.alpha5 * psnr_l + 72 | self.alpha6 * color_l + self.alpha4 * ms_ssim_l) 73 | 74 | return torch.mean(total_loss) -------------------------------------------------------------------------------- /basicsr/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 
一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /basicsr/metrics/README_CN.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | - [约定](#约定) 6 | - [PSNR 和 SSIM](#psnr-和-ssim) 7 | 8 | ## 约定 9 | 10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定: 11 | 12 | - Numpy 类型 (一般是 cv2 的结果) 13 | - UINT8: BGR, [0, 255], (h, w, c) 14 | - float: BGR, [0, 1], (h, w, c). 
from copy import deepcopy

from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .uciqe_uiqm import getUCIQE, getUIQM
from .psnr_ssim import calculate_psnr, calculate_psnr_pt, calculate_ssim

# Fix: getUCIQE / getUIQM were imported above but missing from __all__,
# so `from basicsr.metrics import *` silently dropped them.
__all__ = ['calculate_psnr', 'calculate_psnr_pt', 'calculate_ssim', 'calculate_niqe', 'getUCIQE', 'getUIQM']


def calculate_metric(data, opt):
    """Calculate a metric from data and options.

    Args:
        data (dict): Keyword inputs forwarded to the metric function
            (e.g. the images to compare).
        opt (dict): Configuration. It must contain:
            type (str): Metric type registered in METRIC_REGISTRY.
            Remaining keys are passed to the metric function as kwargs.

    Returns:
        The value returned by the registered metric function.
    """
    # Copy so popping 'type' does not mutate the caller's config.
    opt = deepcopy(opt)
    metric_type = opt.pop('type')
    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
    return metric
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance (FID).

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is:
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
        sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.
        eps (float): Value added to the covariance diagonals when their
            product is (nearly) singular. Default: 1e-6.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')

    # disp=False returns (sqrtm_result, error_estimate) instead of printing.
    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular: regularize the diagonals and retry.
    if not np.isfinite(cov_sqrt).all():
        # Bug fix: the original message was a plain string literal, so the
        # literal text '{eps}' was printed instead of the value of eps.
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
def test(img_path, img_path2, crop_border, test_y_channel=False):
    """Print PSNR/SSIM for one image pair across every backend.

    Runs the Numpy, PyTorch CPU, PyTorch GPU, and batched-GPU
    implementations so their outputs can be compared side by side.
    """
    im_np = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    im2_np = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)

    # --------------------- Numpy ---------------------
    psnr = calculate_psnr(im_np, im2_np, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
    ssim = calculate_ssim(im_np, im2_np, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
    print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

    # --------------------- PyTorch (CPU) ---------------------
    im_t = img2tensor(im_np / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
    im2_t = img2tensor(im2_np / 255., bgr2rgb=True, float32=True).unsqueeze_(0)

    psnr_pth = calculate_psnr_pt(im_t, im2_t, crop_border=crop_border, test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(im_t, im2_t, crop_border=crop_border, test_y_channel=test_y_channel)
    print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    # --------------------- PyTorch (GPU) ---------------------
    im_t = im_t.cuda()
    im2_t = im2_t.cuda()
    psnr_pth = calculate_psnr_pt(im_t, im2_t, crop_border=crop_border, test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(im_t, im2_t, crop_border=crop_border, test_y_channel=test_y_channel)
    print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    # --------------------- PyTorch (GPU, batch of two) ---------------------
    im_batch = torch.repeat_interleave(im_t, 2, dim=0)
    im2_batch = torch.repeat_interleave(im2_t, 2, dim=0)
    psnr_pth = calculate_psnr_pt(im_batch, im2_batch, crop_border=crop_border, test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(im_batch, im2_batch, crop_border=crop_border, test_y_channel=test_y_channel)
    print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
          f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')


if __name__ == '__main__':
    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)

    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
# Automatically import every '*_model.py' module in this folder so that each
# model class registers itself with MODEL_REGISTRY as a side effect.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]


def build_model(opt):
    """Build a model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type registered in MODEL_REGISTRY.

    Returns:
        The instantiated model object.
    """
    # Work on a copy so the caller's options dict is never mutated.
    cfg = deepcopy(opt)
    model_cls = MODEL_REGISTRY.get(cfg['model_type'])
    return model_cls(cfg)
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501

import os
import torch
from torch import nn
from torch.autograd import Function

# If BASICSR_JIT is set, compile the CUDA extension just-in-time from the
# bundled sources; otherwise fall back to an importable prebuilt module.
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    fused_act_ext = load(
        'fused',
        sources=[
            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
        ],
    )
else:
    try:
        from . import fused_act_ext
    except ImportError:
        # Deliberately swallowed: CPU-only installs have no compiled
        # extension; failure surfaces only if the op is actually used.
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class FusedLeakyReLUFunctionBackward(Function):
    """Backward of the fused bias + leaky-ReLU op, wrapped as its own
    autograd Function (its `backward` provides the second-order gradient)."""

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        # `out` (the forward activation) is saved so the kernel can tell
        # which side of the leaky ReLU each element came from.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        # Empty tensor = "no bias" marker for the extension call.
        empty = grad_output.new_empty(0)

        # act=3, grad=1: extension computes the input gradient.
        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)

        # Bias gradient: reduce over every dim except dim 1 (channels).
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        # Second-order gradient reuses the same fused kernel path (grad=1).
        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
                                                    ctx.scale)

        # Gradients w.r.t. out / negative_slope / scale are not needed.
        return gradgrad_out, None, None, None
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    """Functional fused bias-add + leaky ReLU + scaling (CUDA extension)."""
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)


class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with a learnable per-channel bias and output scaling.

    Args:
        channel (int): Number of channels (size of the learnable bias).
        negative_slope (float): Slope for negative inputs. Default: 0.2.
        scale (float): Multiplier applied to the output. Default: sqrt(2).
    """

    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        # Delegate to the functional form so Module and function share code.
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
//
// NOTE(review): the checked-in copy lost every '<...>' token (template
// arguments, kernel launch config, '#include <...>' header names),
// presumably via HTML stripping during export, leaving the file
// uncompilable. They are restored below from the upstream reference.

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


// Elementwise fused bias-add + activation. `act` selects the activation
// (1 = linear, 3 = leaky ReLU) and `grad` the derivative order (0 = forward,
// 1 = first, 2 = second); `p_ref` carries the forward output for grad passes.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    // Each thread processes up to loop_x elements, advancing by blockDim.x.
    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // Bias is broadcast along the channel dimension.
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;                                  // linear, forward
            case 11: y = x; break;                                  // linear, 1st grad
            case 12: y = 0.0; break;                                // linear, 2nd grad

            case 30: y = (x > 0.0) ? x : x * alpha; break;          // lrelu, forward
            case 31: y = (ref > 0.0) ? x : x * alpha; break;        // lrelu, 1st grad
            case 32: y = 0.0; break;                                // lrelu, 2nd grad
        }

        out[xi] = y * scale;
    }
}


// Host-side launcher: makes inputs contiguous, derives the bias broadcast
// stride, and dispatches the kernel on the current CUDA stream.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    // Empty bias/reference tensors act as "disabled" flags for the kernel.
    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // Bias indexes dim 1: step_b is the product of all trailing dims.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
def test_pipeline(root_path):
    """Run the full test/evaluation pipeline.

    Parses the options file, builds one dataloader per entry under
    ``datasets``, builds the model, and runs validation on each loader.

    Args:
        root_path (str): Repository root used to resolve relative paths
            in the parsed options.
    """
    # parse options, set distributed setting, set random seed
    opt, _ = parse_options(root_path, is_train=False)

    # Benchmark mode speeds up fixed-size inference; the deterministic
    # alternative is left disabled.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    # sorted() gives a deterministic dataset order regardless of dict order.
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = build_dataset(dataset_opt)
        test_loader = build_dataloader(
            test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
        logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = build_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        # current_iter is only used for labeling saved results here, so the
        # experiment name is passed instead of an iteration count.
        model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])


if __name__ == '__main__':
    # Repository root is two directories above this file.
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    test_pipeline(root_path)
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher type.

    Args:
        launcher (str): Either 'pytorch' or 'slurm'.
        backend (str): torch.distributed backend. Default: 'nccl'.
        **kwargs: Forwarded to the launcher-specific initializer.

    Raises:
        ValueError: If ``launcher`` is not a supported type.
    """
    # CUDA-safe worker processes require the 'spawn' start method.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to a GPU (RANK mod #GPUs) and join the group."""
    local_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    torch.cuda.set_device(local_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D (reflect-padded 2D convolution).

    Args:
        img (Tensor): (b, c, h, w).
        kernel (Tensor): (b, k, k) or (1, k, k); k must be odd. A single
            kernel (first dim 1) is shared across the whole batch.

    Returns:
        Tensor: Filtered image of shape (b, c, h, w).

    Raises:
        ValueError: If the kernel size is even.
    """
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 == 1:
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    else:
        raise ValueError('Wrong kernel size')

    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # Apply the same kernel to all batch images.
        img = img.view(b * c, 1, ph, pw)
        kernel = kernel.view(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
    else:
        # One kernel per batch element: use a grouped convolution.
        img = img.view(1, b * c, ph, pw)
        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)


def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp mask) sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) * 255 > threshold, else: 0
    3. Blur the mask to obtain a soft mask.
    4. Out = Mask * sharp + (1 - Mask) * I

    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharpening weight. Default: 0.5.
            (Docstring fix: the previous doc claimed a default of 1.)
        radius (float): Kernel size of Gaussian blur; forced odd. Default: 50.
        threshold (int): Residual threshold on the 0-255 scale below which
            pixels are left unsharpened. Default: 10.

    Returns:
        Numpy array: Sharpened image, same shape and range as the input.
    """
    if radius % 2 == 0:
        radius += 1  # cv2.GaussianBlur requires an odd kernel size
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = img + weight * residual
    sharp = np.clip(sharp, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img


class USMSharp(torch.nn.Module):
    """Tensor version of usm_sharp with a precomputed Gaussian kernel buffer.

    Args:
        radius (int): Gaussian kernel size; forced odd. Default: 50.
        sigma (float): Gaussian sigma; 0 lets OpenCV derive it from radius.
    """

    def __init__(self, radius=50, sigma=0):
        super(USMSharp, self).__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = cv2.getGaussianKernel(radius, sigma)
        # Outer product of the 1D kernel gives the 2D Gaussian; registered
        # as a buffer so it moves with the module across devices.
        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
        self.register_buffer('kernel', kernel)

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen a batched image tensor (b, c, h, w) in [0, 1]."""
        blur = filter2D(img, self.kernel)
        residual = img - blur

        mask = torch.abs(residual) * 255 > threshold
        mask = mask.float()
        soft_mask = filter2D(mask, self.kernel)
        sharp = img + weight * residual
        sharp = torch.clip(sharp, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
def adjust_brightness_nonlinear(image, gamma):
    """Nonlinear (gamma-style) brightness: clip(image ** gamma, 0, 1)."""
    return np.clip(np.power(image.astype(np.float32), gamma), 0, 1)


def add_label_noise(image_np,
                    tem_mean=1, tem_var=0.03,
                    bright_mean=1.15, bright_var=0.15,
                    contrast_mean=1.15, contrast_var=0.15):
    """Perturb a reference image with random color-temperature, brightness
    and contrast shifts.

    Each factor is drawn from a normal distribution with the given mean and
    std; a stage is skipped entirely when its mean is 1 and its variance 0.
    """
    if not (tem_mean == 1 and tem_var == 0):
        image_np = adjust_color_temperature(image_np, np.random.normal(tem_mean, tem_var))
    if not (bright_mean == 1 and bright_var == 0):
        image_np = adjust_brightness(image_np, factor=np.random.normal(bright_mean, bright_var))
    if not (contrast_mean == 1 and contrast_var == 0):
        image_np = adjust_contrast(image_np, np.random.normal(contrast_mean, contrast_var))

    return image_np
mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 21 | mask = np.zeros(self.token_count, dtype=int) 22 | mask[mask_idx] = 1 23 | 24 | mask = mask.reshape((self.rand_size, self.rand_size)) 25 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 26 | 27 | return mask -------------------------------------------------------------------------------- /basicsr/utils/mixing_augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Mixing_Augment: 4 | def __init__(self, mixup_beta, use_identity, device): 5 | self.dist = torch.distributions.beta.Beta( 6 | torch.tensor([mixup_beta]), torch.tensor([mixup_beta])) 7 | self.device = device 8 | 9 | self.use_identity = use_identity 10 | 11 | self.augments = [self.mixup] 12 | 13 | def mixup(self, target, input_): 14 | lam = self.dist.rsample((1, 1)).item() 15 | 16 | r_index = torch.randperm(target.size(0)).to(self.device) 17 | 18 | target = lam * target + (1 - lam) * target[r_index, :] 19 | input_ = lam * input_ + (1 - lam) * input_[r_index, :] 20 | 21 | return target, input_ 22 | 23 | def __call__(self, target, input_): 24 | if self.use_identity: 25 | augment = random.randint(0, len(self.augments)) 26 | if augment < len(self.augments): 27 | target, input_ = self.augments[augment](target, input_) 28 | else: 29 | augment = random.randint(0, len(self.augments) - 1) 30 | target, input_ = self.augments[augment](target, input_) 31 | return target, input_ -------------------------------------------------------------------------------- /basicsr/utils/noise_cal.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | def calculate_noise_map(x): 5 | """ 6 | args: 7 | x (Tensor) : B C H W 8 | return: 9 | noise_map (Tensor): return Noise Map of shape B C H W 10 | 11 | """ 12 | 13 | def gradient(x): 14 | def sub_gradient(x): 15 | left_shift_x, right_shift_x, grad = torch.zeros_like( 16 | 
x), torch.zeros_like(x), torch.zeros_like(x) 17 | left_shift_x[:, :, 0:-1] = x[:, :, 1:] 18 | right_shift_x[:, :, 1:] = x[:, :, 0:-1] 19 | grad = 0.5 * (left_shift_x - right_shift_x) 20 | return grad 21 | 22 | return sub_gradient(x), sub_gradient(torch.transpose(x, 2, 3)).transpose(2, 3) 23 | low_after_awb = x.exp() 24 | color_map = low_after_awb / (low_after_awb.sum(dim=1, keepdims=True) + 1e-4) 25 | dx, dy = gradient(color_map) 26 | noise_map = torch.max(torch.stack([dx.abs(), dy.abs()], dim=0), dim=0)[0] 27 | 28 | return noise_map -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 
30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 
76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /basicsr/utils/poisson_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Define the Poisson-Gaussian noise addition function 4 | def add_poisson_gaussian_noise(input_image, poisson_scale=500.0, gaussian_std=0.01): 5 | """ 6 | Add Poisson-Gaussian noise to the input image. 7 | 8 | Args: 9 | input_image: Input image tensor with shape (B, C, H, W). 10 | poisson_scale: Scaling factor for Poisson noise; higher values result in stronger Poisson noise. 11 | gaussian_std: Standard deviation of the Gaussian noise. 12 | 13 | Returns: 14 | Noisy image tensor. 15 | """ 16 | # Normalize the input image to a non-negative range (Poisson distribution requires non-negative values) 17 | normalized_image = input_image - input_image.min() 18 | normalized_image = normalized_image / (normalized_image.max() + 1e-8) # Prevent division by zero 19 | 20 | # Add Poisson noise (generate a Poisson-distributed random value for each pixel) 21 | poisson_noise = torch.poisson(normalized_image * poisson_scale) / poisson_scale 22 | 23 | # Add Gaussian noise 24 | gaussian_noise = torch.randn_like(input_image) * gaussian_std 25 | 26 | # Combine the noise 27 | noisy_image = poisson_noise + gaussian_noise 28 | 29 | return noisy_image 30 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides 
name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sun Sep 22 01:53:49 2024 3 | __version__ = '1.2.0+d419381' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /basicsr/vmamba/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BinCVER/Bayesian-Enhancement-Model/1abee7d6d05478094c857b9dc9902dcab43f5e9c/basicsr/vmamba/__init__.py -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav0_base_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: 
vssm 3 | NAME: vssm_base 4 | DROP_PATH_RATE: 0.5 5 | # DROP_PATH_RATE: 0.6 6 | VSSM: 7 | EMBED_DIM: 128 8 | DEPTHS: [ 2, 2, 27, 2 ] 9 | SSM_D_STATE: 16 10 | SSM_DT_RANK: "auto" 11 | SSM_RATIO: 2.0 12 | SSM_FORWARDTYPE: "v0" 13 | MLP_RATIO: 0.0 14 | DOWNSAMPLE: "v1" 15 | PATCHEMBED: "v1" 16 | 17 | 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav0_small_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 27, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v0" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav0_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v0" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav2_base_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_base_0229 4 | DROP_PATH_RATE: 0.6 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/vssm/vmambav2_small_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_small_0229 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav2_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav2v_base_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_base_0229s 4 | DROP_PATH_RATE: 0.5 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav2v_small_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_small_0229s 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | 
EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssm/vmambav2v_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a0.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_v0 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v0" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a01.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a01 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v01" # csm_torch 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a0seq.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_v0seq 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v0seq" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v02" # csm_triton 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v04" # csm_triton + i16o32 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v05" # csm_triton + i16o32 + noeinsum + layout 12 | MLP_RATIO: 0.0 13 | 
DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | NORM_LAYER: "ln2d" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a7.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a7 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 2, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav0_tiny_224_a8.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a8ln 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" # "ln" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_a9d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05_noz" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_bidi.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: 
vssm1_tiny_0230ab2d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] # [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v052d_noz" # "v32d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_bidi_ndw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab2d_ndw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] # [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v052d_noz" # "v32d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_cas2d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab2dc 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] # [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v052dc_noz" # "v32dc_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_cas2d_ndw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab2dc_ndw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 
# [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v052dc_noz" # "v32dc_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ds16.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_ds16 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ds2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_ds2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 2 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ds4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_ds4 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 4 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | 
DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ds8.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_ds8 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 8 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.5 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_gelu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_gelu 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | # SSM_INIT: "v2" 19 | SSM_ACT_LAYER: "gelu" 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_init1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_init1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | SSM_INIT: "v1" 19 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_init2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_init2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | SSM_INIT: "v2" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_m2s2h.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_m2s2h 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.5 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 2.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | # SSM_INIT: "v2" 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_m3s1h.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_m3s1h 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.5 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 3.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | # SSM_INIT: "v2" 19 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ndw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_ndw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_ondw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_ondw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_ondwconv3_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_onone.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_onone 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_onnone_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_onsoftmax.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: 
vssm1_tiny_0230s_ondw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_onsoftmax_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_posndw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_posndw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | POSEMBED: true 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_relu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s_relu 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | # SSM_INIT: "v2" 19 | SSM_ACT_LAYER: "relu" 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_sr1hl5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_sr1hl5 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | 
SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.5 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_sr1l5.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230_sr1l5 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_unidi.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab1d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] # [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v051d_noz" # "v31d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/vssmab/vmambav2_tiny_224_unidi_ndw.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab1d_ndw 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] # [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 # 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v051d_noz" # 
"v31d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vmambav2_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_base_224_a0.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base 4 | DROP_PATH_RATE: 0.5 5 | # DROP_PATH_RATE: 0.6 6 | VSSM: 7 | EMBED_DIM: 128 8 | DEPTHS: [ 2, 2, 27, 2 ] 9 | SSM_D_STATE: 16 10 | SSM_DT_RANK: "auto" 11 | SSM_RATIO: 2.0 12 | SSM_FORWARDTYPE: "v0" 13 | MLP_RATIO: 0.0 14 | DOWNSAMPLE: "v1" 15 | PATCHEMBED: "v1" 16 | 17 | # SSM_FORWARDTYPE: "v0" # if you want exactly the same 18 | 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_base_224_a6.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_a6 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 27, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v05" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | NORM_LAYER: "ln2d" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_base_224_aav1.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_aav1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_base_224_ahv1_0423.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_ahv1_0423 4 | DROP_PATH_RATE: 0.5 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.001 20 | 21 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_base_224_ahv3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_ahv3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_small_224_a0.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | 
EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 27, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v0" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | # SSM_FORWARDTYPE: "v0" # if you want exactly the same 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_small_224_a6.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_a6 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 27, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_FORWARDTYPE: "v05" 12 | MLP_RATIO: 0.0 13 | DOWNSAMPLE: "v1" 14 | PATCHEMBED: "v1" 15 | NORM_LAYER: "ln2d" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_small_224_aav1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_aav1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_small_224_ahv3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_ahv3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_a9v1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9v1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 2 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_a9v2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9v2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 4 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_a9v3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9v3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 8 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_aaa.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aaa 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 
| DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05_noz_oact" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_aav1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aav1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_aav2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aav2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: true 13 | SSM_FORWARDTYPE: "v05_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_abv2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_abv2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_abv3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_abv3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_abv4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_abv4 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_aca.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aca 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv1.yaml: -------------------------------------------------------------------------------- 1 | 
MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv1_61.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_61 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv1_66.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_66 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ca1" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv1_67.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_67 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | 
SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv1_68.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_68 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv2 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv3.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_acv4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv4 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_adv1_mini.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_mini 4 | DROP_PATH_RATE: 0.0 5 | VSSM: 6 | EMBED_DIM: 192 7 | DEPTHS: [12] 8 | PATCH_SIZE: 16 9 | SSM_D_STATE: 1 10 | SSM_DT_RANK: "auto" 11 | SSM_RATIO: 2.0 12 | SSM_CONV: 3 13 | SSM_CONV_BIAS: false 14 | SSM_FORWARDTYPE: "xv1a_act" 15 | MLP_RATIO: 4.0 16 | DOWNSAMPLE: "v3" 17 | PATCHEMBED: "v2" 18 | NORM_LAYER: "ln2d" 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_adv1_mini2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_mini 4 | DROP_PATH_RATE: 0.0 5 | VSSM: 6 | EMBED_DIM: 192 7 | DEPTHS: [36] 8 | PATCH_SIZE: 16 9 | SSM_D_STATE: 1 10 | SSM_DT_RANK: "auto" 11 | SSM_RATIO: 2.0 12 | SSM_CONV: 3 13 | SSM_CONV_BIAS: false 14 | SSM_FORWARDTYPE: "xv1a_act" 15 | MLP_RATIO: 0.0 16 | DOWNSAMPLE: "v3" 17 | PATCHEMBED: "v1" 18 | NORM_LAYER: "ln2d" 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_ahv3_0420.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_ahv3_0420 4 | DROP_PATH_RATE: 0.15 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm01/vssm_tiny_224_aiv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aiv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_oncnorm" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm1/vssm_base_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_base_0229 4 | DROP_PATH_RATE: 0.6 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v3_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | 18 | # 89.0 + 15.2 + 118min/e + 48G 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm1/vssm_mini_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0222 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 
4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2" # "v2softmaxnozact", "v2sigmoidnozact",... 14 | MLP_RATIO: -1.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | # 17.56 + 2.73 -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm1/vssm_small_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_small_0229 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v3_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | 18 | # 50.4 + 8.6 + 90min/e + 36G -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm1/vssm_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.2 11 | MLP_RATIO: 4.0 12 | 13 | 14 | # PRINT_FREQ: 1 # for debug 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm1/vssm_tiny_224_0220.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0220 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/wasted/vssm_base_224_ahv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_ahv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_base_224_ahv1_0421.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_ahv1_0421 4 | DROP_PATH_RATE: 0.5 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_base_224_ahv1_0422.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_ahv1_0422 4 | DROP_PATH_RATE: 0.6 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_base_224_aiv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_aiv1 4 | DROP_PATH_RATE: 0.5 
5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_oncnorm" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_base_224_aiv1_dp06.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_base_aiv1_dp06 4 | DROP_PATH_RATE: 0.6 5 | VSSM: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_oncnorm" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_small_224_ahv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_ahv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_small_224_ahv1_0421.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_ahv1_0421 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: 
"xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_small_224_ahv1_0422.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_ahv1_0422 4 | DROP_PATH_RATE: 0.4 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_small_224_aiv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_aiv1 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_oncnorm" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_small_224_aiv1_dp04.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_small_aiv1_dp04 4 | DROP_PATH_RATE: 0.4 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 20, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_oncnorm" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/wasted/vssm_tiny_224_0211.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0211 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v2" 15 | PATCHEMBED: "v1" 16 | 17 | # PRINT_FREQ: 1 # for debug 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0211v1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0211v1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 2 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v2" 15 | PATCHEMBED: "v1" 16 | 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0212.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0212 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 64 7 | # DEPTHS: [ 4, 4, 18, 4 ] # 36 + 6.12 8 | # DEPTHS: [3, 4, 12, 4] # 30 + 4.7 9 | DEPTHS: [3, 3, 12, 3] # 26 + 4.3 10 | SSM_D_STATE: 1 11 | SSM_DT_RANK: "auto" 12 | SSM_RATIO: 2.0 13 | SSM_CONV: -1 14 | SSM_FORWARDTYPE: "v2" 15 | MLP_RATIO: 4.0 16 | DOWNSAMPLE: "v2" 17 | PATCHEMBED: "v1" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0213.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0213 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 64 7 | DEPTHS: [3, 3, 12, 3] # 26 + 4.3 8 | 
SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v1" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0215.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0215 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v1" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0216.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0216 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0217.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0217 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | # SSM_RATIO: 2.0 12 | # SSM_RANK_RATIO: 1.6 # add SSM_RANK_RATIO will introduce extra nn.linear, which costs 1.5+ GFlops 13 | SSM_CONV: -1 14 | SSM_FORWARDTYPE: "v2" 15 | MLP_RATIO: 4.0 16 | DOWNSAMPLE: "v3" 17 | PATCHEMBED: "v2" 18 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/wasted/vssm_tiny_224_0218.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0218 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.2 11 | # SSM_RATIO: 2.0 12 | # SSM_RANK_RATIO: 1.2 # add SSM_RANK_RATIO will introduce extra nn.linear, which costs 1.5+ GFlops 13 | SSM_CONV: -1 14 | SSM_FORWARDTYPE: "v2" 15 | MLP_RATIO: 4.0 16 | DOWNSAMPLE: "v3" 17 | PATCHEMBED: "v2" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0219.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0219 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0221.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0221 # to compare with 0218 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.2 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0222.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0222 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | 
SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_INIT: "v1" # original: SSM_SIMPLE_INIT: true 14 | SSM_FORWARDTYPE: "v2" 15 | MLP_RATIO: 4.0 16 | DOWNSAMPLE: "v3" 17 | PATCHEMBED: "v2" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0223.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0223 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2_nozact" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0224 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 9, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2" 14 | MLP_RATIO: -1.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0225.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0225 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0229.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0229 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v2_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0229flex.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0229flex 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v3_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0230.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v3_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0230ab1d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab1d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 
96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v31d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0230ab2d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230ab2d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v32d_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0309.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0309 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0310.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0310 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/wasted/vssm_tiny_224_0311.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0311 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | # EMBED_DIM: 128 9 | # DEPTHS: [ 2, 2, 12, 2 ] 10 | SSM_D_STATE: 1 11 | SSM_DT_RANK: "auto" 12 | SSM_RATIO: 2.0 13 | # SSM_RATIO: 2.4 14 | SSM_CONV: 3 15 | SSM_CONV_BIAS: false 16 | SSM_FORWARDTYPE: "xv2_noz" 17 | MLP_RATIO: 4.0 18 | DOWNSAMPLE: "v3" 19 | PATCHEMBED: "v2" 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0312.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0312 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3_noz" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0313.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0313 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv4" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0314.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0314 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: 
[2, 2, 5, 2] # [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv5" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0315.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0315 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] # [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v2" 16 | PATCHEMBED: "v1" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0316.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0316 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv4" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v2" 16 | PATCHEMBED: "v1" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0317.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0317 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] # [ 3,3,27,3 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv6" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v2" # "v3" 16 | PATCHEMBED: "v1" # "v2" 17 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0318.2.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_03182 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv61" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0318.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0318 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv61" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v2" # "v3" 16 | PATCHEMBED: "v1" # "v2" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0319.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0319 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv7" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0320.yaml: -------------------------------------------------------------------------------- 1 | 
MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0320 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv7" 14 | GMLP: true 15 | MLP_RATIO: 2.5 16 | DOWNSAMPLE: "v3" # "v3" 17 | PATCHEMBED: "v2" # "v2" 18 | NORM_LAYER: "ln2d" 19 | 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0321.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0321 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0322.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0322 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0323.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0323 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | 
SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0324.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0324 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0325.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0325 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_act_mul" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0326.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0326 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | 
NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_0327.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0327 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 # 128 7 | DEPTHS: [2, 2, 5, 2] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv2a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" # "v3" 16 | PATCHEMBED: "v2" # "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.2 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v2" 15 | PATCHEMBED: "v1" 16 | 17 | 18 | # PRINT_FREQ: 1 # for debug 19 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_1v1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_1v1 4 | DROP_PATH_RATE: 0.1 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 4, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.2 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v2" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v2" 15 | PATCHEMBED: "v1" 16 | 17 | # PRINT_FREQ: 1 # for debug 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_a8d.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 
| NAME: vssm_tiny_a8d 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 0.9 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_a9.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.6 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_a9a.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_a9a 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 16 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05_noz_oact" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_aa.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aa 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 12 | SSM_FORWARDTYPE: "v05_noz" 13 | MLP_RATIO: 4.0 14 | DOWNSAMPLE: "v3" 15 | PATCHEMBED: "v2" 16 | NORM_LAYER: "ln2d" 17 | 
-------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_abv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_abv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acb.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acb 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 8 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0401.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0401 4 | DROP_PATH_RATE: 0.25 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0403.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | 
TYPE: vssm 3 | NAME: vssm_tiny_acv1_0403 4 | DROP_PATH_RATE: 0.15 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0405.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0405 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.0004 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0406.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0406 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.001 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0407.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0407 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | 
SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.002 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0408.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0408 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.003 -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0409.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0409 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.0003 -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_0410.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_0410 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: -1 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | 
MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | TRAIN: 19 | BASE_LR: 0.0002 -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_6.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_6 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.8 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_62.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_62 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_62_0415.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_62_0415 4 | DROP_PATH_RATE: 0.15 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- 
/basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_63.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_63 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_64.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_64 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 7 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_acv1_65.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_acv1_65 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ca1" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_adv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_adv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 
| DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ca_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_adv1c.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_adv1c 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_aev1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aev1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ocov_ca_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_aev1c.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_aev1c 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ocov_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: 
"v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_afv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_afv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ocov2_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_agv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_agv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 3 # 3 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_cpos_act" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_ahv1.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_ahv1 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv1a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | 19 | 20 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_ahv3.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_ahv3 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/configs/wasted/vssm_tiny_224_ahv3_0418.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm_tiny_ahv3_0418 4 | DROP_PATH_RATE: 0.25 5 | VSSM: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 8, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 1.0 11 | SSM_CONV: -1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "xv3a_ondwconv3" 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | -------------------------------------------------------------------------------- /basicsr/vmamba/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader as _build_loader 2 | from .data_simmim_pt import build_loader_simmim 3 | from .data_simmim_ft import build_loader_finetune 4 | 5 | 6 | def build_loader(config, simmim=False, is_pretrain=False): 7 | if not simmim: 8 | return _build_loader(config) 9 | if is_pretrain: 10 | return build_loader_simmim(config) 11 | else: 12 | return build_loader_finetune(config) 13 | -------------------------------------------------------------------------------- /basicsr/vmamba/data/imagenet22k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch.utils.data as data 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import warnings 8 | 9 | 
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 10 | 11 | 12 | class IN22KDATASET(data.Dataset): 13 | def __init__(self, root, ann_file='', transform=None, target_transform=None): 14 | super(IN22KDATASET, self).__init__() 15 | 16 | self.data_path = root 17 | self.ann_path = os.path.join(self.data_path, ann_file) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | # id & label: https://github.com/google-research/big_transfer/issues/7 21 | # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 22 | self.database = json.load(open(self.ann_path)) 23 | 24 | def _load_image(self, path): 25 | try: 26 | im = Image.open(path) 27 | except: 28 | print("ERROR IMG LOADED: ", path) 29 | random_img = np.random.rand(224, 224, 3) * 255 30 | im = Image.fromarray(np.uint8(random_img)) 31 | return im 32 | 33 | def __getitem__(self, index): 34 | """ 35 | Args: 36 | index (int): Index 37 | Returns: 38 | tuple: (image, target) where target is class_index of the target class. 
39 | """ 40 | idb = self.database[index] 41 | 42 | # images 43 | images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') 44 | if self.transform is not None: 45 | images = self.transform(images) 46 | 47 | # target 48 | target = int(idb[1]) 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | return images, target 53 | 54 | def __len__(self): 55 | return len(self.database) 56 | -------------------------------------------------------------------------------- /basicsr/vmamba/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 
13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /basicsr/vmamba/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | import torch 4 | 5 | from .vmamba import VSSM 6 | 7 | 8 | def build_vssm_model(config, **kwargs): 9 | model_type = config.MODEL.TYPE 10 | if model_type in ["vssm"]: 11 | model = VSSM( 12 | patch_size=config.MODEL.VSSM.PATCH_SIZE, 13 | in_chans=config.MODEL.VSSM.IN_CHANS, 14 | num_classes=config.MODEL.NUM_CLASSES, 15 | depths=config.MODEL.VSSM.DEPTHS, 16 | dims=config.MODEL.VSSM.EMBED_DIM, 17 | # =================== 18 | ssm_d_state=config.MODEL.VSSM.SSM_D_STATE, 19 | ssm_ratio=config.MODEL.VSSM.SSM_RATIO, 20 | ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO, 21 | ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)), 22 | ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER, 23 | ssm_conv=config.MODEL.VSSM.SSM_CONV, 24 | ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS, 25 | ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE, 26 | ssm_init=config.MODEL.VSSM.SSM_INIT, 27 | forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE, 28 | # =================== 29 | mlp_ratio=config.MODEL.VSSM.MLP_RATIO, 30 | mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER, 31 | mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE, 32 | # =================== 33 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 34 | patch_norm=config.MODEL.VSSM.PATCH_NORM, 35 | norm_layer=config.MODEL.VSSM.NORM_LAYER, 36 | 
downsample_version=config.MODEL.VSSM.DOWNSAMPLE, 37 | patchembed_version=config.MODEL.VSSM.PATCHEMBED, 38 | gmlp=config.MODEL.VSSM.GMLP, 39 | use_checkpoint=config.TRAIN.USE_CHECKPOINT, 40 | # =================== 41 | posembed=config.MODEL.VSSM.POSEMBED, 42 | imgsize=config.DATA.IMG_SIZE, 43 | ) 44 | return model 45 | 46 | return None 47 | 48 | 49 | def build_model(config, is_pretrain=False): 50 | model = None 51 | if model is None: 52 | model = build_vssm_model(config, is_pretrain) 53 | return model 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /basicsr/vmamba/models/mamba2/__init__.py: -------------------------------------------------------------------------------- 1 | # all the code in this folder is copied from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ 2 | 3 | -------------------------------------------------------------------------------- /basicsr/vmamba/readme.md: -------------------------------------------------------------------------------- 1 | ## origins 2 | 3 | based on https://github.com/microsoft/Swin-Transformer#20240103 4 | 5 | `main.py` and `utils/utils_ema.py` is modified from https://github.com/microsoft/Swin-Transformer#20240103, based on https://github.com/facebookresearch/ConvNeXt#20240103 6 | 7 | -------------------------------------------------------------------------------- /basicsr/vmamba/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | packaging 5 | triton 6 | timm==0.4.12 7 | pytest 8 | chardet 9 | yacs 10 | termcolor 11 | submitit 12 | tensorboardX 13 | fvcore 14 | seaborn 15 | -------------------------------------------------------------------------------- /basicsr/vmamba/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | 
# Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cub_extra.cuh: -------------------------------------------------------------------------------- 1 | // WarpMask is copied from /usr/local/cuda-12.1/include/cub/util_ptx.cuh 2 | // PowerOfTwo is copied from /usr/local/cuda-12.1/include/cub/util_type.cuh 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | /** 12 | * \brief Statically determine if N is a power-of-two 13 | */ 14 | template 15 | struct 
PowerOfTwo 16 | { 17 | enum { VALUE = ((N & (N - 1)) == 0) }; 18 | }; 19 | 20 | 21 | /** 22 | * @brief Returns the warp mask for a warp of @p LOGICAL_WARP_THREADS threads 23 | * 24 | * @par 25 | * If the number of threads assigned to the virtual warp is not a power of two, 26 | * it's assumed that only one virtual warp exists. 27 | * 28 | * @tparam LOGICAL_WARP_THREADS [optional] The number of threads per 29 | * "logical" warp (may be less than the number of 30 | * hardware warp threads). 31 | * @param warp_id Id of virtual warp within architectural warp 32 | */ 33 | template 34 | __host__ __device__ __forceinline__ 35 | unsigned int WarpMask(unsigned int warp_id) 36 | { 37 | constexpr bool is_pow_of_two = PowerOfTwo::VALUE; 38 | constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0); 39 | 40 | unsigned int member_mask = 0xFFFFFFFFu >> 41 | (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS); 42 | 43 | if (is_pow_of_two && !is_arch_warp) 44 | { 45 | member_mask <<= warp_id * LOGICAL_WARP_THREADS; 46 | } 47 | 48 | return member_mask; 49 | } 50 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_ndstate.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_ndstate.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | 32 | bool delta_softplus; 33 | 34 | index_t A_d_stride; 35 | index_t B_batch_stride; 36 | index_t B_d_stride; 37 | index_t B_group_stride; 38 | index_t C_batch_stride; 39 | index_t C_d_stride; 40 | index_t C_group_stride; 41 | index_t u_batch_stride; 42 | index_t u_d_stride; 43 | index_t delta_batch_stride; 44 | index_t delta_d_stride; 45 | index_t out_batch_stride; 46 | index_t out_d_stride; 47 | 48 | // Common data pointers. 
49 | void *__restrict__ A_ptr; 50 | void *__restrict__ B_ptr; 51 | void *__restrict__ C_ptr; 52 | void *__restrict__ D_ptr; 53 | void *__restrict__ u_ptr; 54 | void *__restrict__ delta_ptr; 55 | void *__restrict__ delta_bias_ptr; 56 | void *__restrict__ out_ptr; 57 | void *__restrict__ x_ptr; 58 | }; 59 | 60 | struct SSMParamsBwd: public SSMParamsBase { 61 | index_t dout_batch_stride; 62 | index_t dout_d_stride; 63 | index_t dA_d_stride; 64 | index_t dB_batch_stride; 65 | index_t dB_group_stride; 66 | index_t dB_d_stride; 67 | index_t dC_batch_stride; 68 | index_t dC_group_stride; 69 | index_t dC_d_stride; 70 | index_t du_batch_stride; 71 | index_t du_d_stride; 72 | index_t ddelta_batch_stride; 73 | index_t ddelta_d_stride; 74 | 75 | // Common data pointers. 76 | void *__restrict__ dout_ptr; 77 | void *__restrict__ dA_ptr; 78 | void *__restrict__ dB_ptr; 79 | void *__restrict__ dC_ptr; 80 | void *__restrict__ dD_ptr; 81 | void *__restrict__ du_ptr; 82 | void *__restrict__ ddelta_ptr; 83 | void *__restrict__ ddelta_bias_ptr; 84 | }; 85 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_oflex.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | template void selective_scan_bwd_cuda<1, at::Half, float, at::Half>(SSMParamsBwd ¶ms, cudaStream_t stream); 10 | template void selective_scan_bwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBwd ¶ms, cudaStream_t stream); 11 | 12 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_oflex.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | template void selective_scan_fwd_cuda<1, at::Half, float, at::Half>(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream); 11 | 12 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 
18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | 32 | bool delta_softplus; 33 | 34 | index_t A_d_stride; 35 | index_t A_dstate_stride; 36 | index_t B_batch_stride; 37 | index_t B_d_stride; 38 | index_t B_dstate_stride; 39 | index_t B_group_stride; 40 | index_t C_batch_stride; 41 | index_t C_d_stride; 42 | index_t C_dstate_stride; 43 | index_t C_group_stride; 44 | index_t u_batch_stride; 45 | index_t u_d_stride; 46 | index_t delta_batch_stride; 47 | index_t delta_d_stride; 48 | index_t out_batch_stride; 49 | index_t out_d_stride; 50 | 51 | // Common data pointers. 52 | void *__restrict__ A_ptr; 53 | void *__restrict__ B_ptr; 54 | void *__restrict__ C_ptr; 55 | void *__restrict__ D_ptr; 56 | void *__restrict__ u_ptr; 57 | void *__restrict__ delta_ptr; 58 | void *__restrict__ delta_bias_ptr; 59 | void *__restrict__ out_ptr; 60 | void *__restrict__ x_ptr; 61 | }; 62 | 63 | struct SSMParamsBwd: public SSMParamsBase { 64 | index_t dout_batch_stride; 65 | index_t dout_d_stride; 66 | index_t dA_d_stride; 67 | index_t dA_dstate_stride; 68 | index_t dB_batch_stride; 69 | index_t dB_group_stride; 70 | index_t dB_d_stride; 71 | index_t dB_dstate_stride; 72 | index_t dC_batch_stride; 73 | index_t dC_group_stride; 74 | index_t dC_d_stride; 75 | index_t dC_dstate_stride; 76 | index_t du_batch_stride; 77 | index_t du_d_stride; 78 | index_t ddelta_batch_stride; 79 | index_t ddelta_d_stride; 80 | 81 | // Common data pointers. 
82 | void *__restrict__ dout_ptr; 83 | void *__restrict__ dA_ptr; 84 | void *__restrict__ dB_ptr; 85 | void *__restrict__ dC_ptr; 86 | void *__restrict__ dD_ptr; 87 | void *__restrict__ du_ptr; 88 | void *__restrict__ ddelta_ptr; 89 | void *__restrict__ ddelta_bias_ptr; 90 | }; 91 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 
8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | flax==0.8.5 3 | fvcore==0.1.5.post20221221 4 | h5py==3.11.0 5 | httplib2==0.22.0 6 | huggingface-hub==0.23.5 7 | libsvm-official==3.35.0 8 | lmdb==1.4.1 9 | lpips==0.1.4 10 | matplotlib==3.9.0 11 | networkx==3.3 12 | ninja==1.11.1.1 13 | numpy==1.24.1 14 | numpyro==0.15.1 15 | opencv-python==4.10.0.84 16 | opt-einsum==3.3.0 17 | optax==0.2.3 18 | orbax-checkpoint==0.5.22 19 | orderedmultidict==1.0.1 20 | pandas==2.2.2 21 | pillow==10.3.0 22 | pytorch-msssim==1.0.0 23 | safetensors==0.4.3 24 | scikit-image==0.23.2 25 | scikit-learn==1.5.0 26 | scipy==1.13.1 27 | seaborn==0.13.2 28 | tensorboard==2.17.0 29 | tensorboardX==2.6.2.2 30 | timm==0.4.12 31 | tokenizers==0.19.1 32 | torchmetrics==1.4.1 33 | tqdm==4.66.4 34 | causal_conv1d==1.0.0 35 | mamba_ssm==1.0.1 36 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before 
binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=79 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | blank_line_before_nested_class_or_def = true 12 | split_before_expression_after_opening_paren = true 13 | 14 | [isort] 15 | line_length = 79 16 | multi_line_output = 0 17 | known_standard_library = pkg_resources,setuptools 18 | known_first_party = basicsr 19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml 20 | no_lines_before = STDLIB,LOCALFOLDER 21 | default_section = THIRDPARTY 22 | --------------------------------------------------------------------------------