├── Eval_MACNet.py ├── ICVL_test_gauss.txt ├── ICVL_train.txt ├── README.md ├── dataloaders.py ├── dataloaders_hsi.py ├── dataloaders_hsi_test.py ├── model ├── MACNet.py ├── __pycache__ │ ├── CBAM.cpython-37.pyc │ ├── MACNet.cpython-37.pyc │ ├── SubCNN.cpython-37.pyc │ ├── SubCNN_NL.cpython-37.pyc │ ├── SubCNN_NLM.cpython-37.pyc │ ├── SubCNN_NL_TIP.cpython-37.pyc │ ├── SubCNN_QRNN.cpython-37.pyc │ ├── SubSCNN_BNRED.cpython-37.pyc │ ├── SubSCNN_BNREDCBAM.cpython-37.pyc │ ├── SubSCNN_BNREDRES.cpython-37.pyc │ ├── combinations.cpython-37.pyc │ ├── non_local.cpython-37.pyc │ └── pyramidpooling.cpython-37.pyc ├── combinations.py ├── non_local.py └── sync_batchnorm │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── batchnorm.cpython-37.pyc │ ├── comm.cpython-37.pyc │ └── replicate.cpython-37.pyc │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── model_loader.py ├── ops ├── __pycache__ │ ├── gauss.cpython-37.pyc │ ├── im2col.cpython-37.pyc │ ├── im2col.py │ ├── utils.cpython-37.pyc │ └── utils_blocks.cpython-37.pyc ├── gauss.py ├── im2col.py ├── utils.py ├── utils_blocks.py └── utils_plot.py ├── test.py ├── train.py └── trained_model ├── ckpt_15 ├── ckpt_55 └── ckpt_95 /Eval_MACNet.py: -------------------------------------------------------------------------------- 1 | import dataloaders_hsi_test 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | import argparse 6 | import os 7 | from collections import OrderedDict 8 | import torch.nn.functional as F 9 | import time 10 | import scipy.io as scio 11 | from ops.utils_blocks import block_module 12 | from ops.utils import show_mem, generate_key, save_checkpoint, str2bool, step_lr, get_lr, MSIQA 13 | parser = argparse.ArgumentParser() 14 | #model 15 | parser.add_argument("--mode", type=str, default='sc',help='[group, sc]') 16 | parser.add_argument("--noise_level", type=int, dest="noise_level", help="Should be an int in the range [0,255]", default=0) 17 | parser.add_argument("--bandwise", type=str2bool, default=1, help='bandwise noise') 18 | parser.add_argument("--num_half_layer", type=int, dest="num_half_layer", help="Number of LISTA step unfolded", default=6) 19 | parser.add_argument("--downsample", type=int, dest="downsample", help="Down Sample", default=2) 20 | parser.add_argument("--channels", type=int, dest="channels", help="Should be an int in the range [0,255]", default=16) 21 | parser.add_argument("--nl", type=str2bool, dest="nl", help="If Nonlocal", default=1) 22 | parser.add_argument("--patch_size", type=int, dest="patch_size", help="Size of image blocks to process", default=56) 23 | parser.add_argument("--rescaling_init_val", type=float, default=1.0) 24 | parser.add_argument("--gpus", '--list',action='append', type=int, help='GPU') 25 | 26 | #training 27 | parser.add_argument("--lr", type=float, dest="lr", help="ADAM Learning rate", default=1e-4) 28 | parser.add_argument("--lr_step", type=int, dest="lr_step", help="ADAM Learning rate step for decay", default=80) 29 | parser.add_argument("--lr_decay", type=float, dest="lr_decay", help="ADAM Learning rate decay (on step)", default=0.35) 30 | parser.add_argument("--backtrack_decay", type=float, help='decay when backtracking',default=0.8) 31 | parser.add_argument("--eps", type=float, dest="eps", help="ADAM epsilon parameter", default=1e-3) 32 | parser.add_argument("--validation_every", type=int, default=300, help='validation frequency on training set (if using backtracking)') 33 | parser.add_argument("--backtrack", type=str2bool, 
default=1, help='use backtrack to prevent model divergence') 34 | parser.add_argument("--num_epochs", type=int, dest="num_epochs", help="Total number of epochs to train", default=300) 35 | parser.add_argument("--train_batch", type=int, default=2, help='batch size during training') 36 | parser.add_argument("--test_batch", type=int, default=3, help='batch size during eval') 37 | parser.add_argument("--aug_scale", type=int, default=0) 38 | 39 | #data 40 | parser.add_argument("--out_dir", type=str, dest="out_dir", help="Results' dir path", default='./trained_model') 41 | parser.add_argument("--model_name", type=str, dest="model_name", help="The name of the model to be saved.", default='trained_model_25_bandwise/MTMF_patch_56Layer_12lr_0.00100000/ckpt') 42 | parser.add_argument("--test_path", type=str, help="Path to the dir containing the testing datasets.", default="data/") 43 | parser.add_argument("--gt_path", type=str, help="Path to the dir containing the ground truth datasets.", default="gt/") 44 | parser.add_argument("--resume", type=str2bool, dest="resume", help='Resume training of the model',default=True) 45 | parser.add_argument("--dummy", type=str2bool, dest="dummy", default=False) 46 | parser.add_argument("--tqdm", type=str2bool, default=False) 47 | parser.add_argument('--log_dir', type=str, default='log', help='log directory') 48 | 49 | #inference 50 | parser.add_argument("--kernel_size", type=int, default=12, help='stride of overlapping image blocks [4,8,16,24,48] kernel_//stride') 51 | parser.add_argument("--stride_test", type=int, default=12, help='stride of overlapping image blocks [4,8,16,24,48] kernel_//stride') 52 | parser.add_argument("--stride_val", type=int, default=40, help='stride of overlapping image blocks for validation [4,8,16,24,48] kernel_//stride') 53 | parser.add_argument("--test_every", type=int, default=300, help='report performance on test set every X epochs') 54 | parser.add_argument("--block_inference", type=str2bool, default=False,help='if true, process blocks of a large image in parallel') 55 | parser.add_argument("--pad_image", type=str2bool, default=0,help='padding strategy for inference') 56 | parser.add_argument("--pad_block", type=str2bool, default=1,help='padding strategy for inference') 57 | parser.add_argument("--pad_patch", type=str2bool, default=0,help='padding strategy for inference') 58 | parser.add_argument("--no_pad", type=str2bool, default=False, help='padding strategy for inference') 59 | parser.add_argument("--custom_pad", type=int, default=None,help='padding strategy for inference') 60 | 61 | #variance reduction 62 | #var reg 63 | parser.add_argument("--nu_var", type=float, default=0.01) 64 | parser.add_argument("--freq_var", type=int, default=3) 65 | parser.add_argument("--var_reg", type=str2bool, default=False) 66 | 67 | parser.add_argument("--verbose", type=str2bool, default=1) 68 | 69 | args = parser.parse_args() 70 | # os.environ['CUDA_VISIBLE_DEVICES']= '6,7' 71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 72 | #device=torch.device("cpu") 73 | device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu' 74 | #device_name='cpu' 75 | capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else os.cpu_count() 76 | # device= torch.device("cpu") 77 | test_path = [args.test_path] 78 | gt_path = args.gt_path 79 | print(f'test data : {test_path}') 80 | print(f'gt data : {gt_path}') 81 | train_path = val_path = [] 82 | 83 | noise_std = args.noise_level / 255 84 |
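# Editor's note (an annotation, not part of the original script): --noise_level is an 8-bit sigma, so the division by 255 maps it onto the [0, 1] intensity range used by the loaders, e.g. noise_level=15 -> noise_std ~= 0.0588 and noise_level=55 -> noise_std ~= 0.2157.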
85 | loaders = dataloaders_hsi_test.get_dataloaders(test_path,crop_size=args.patch_size, 86 | batch_size=args.train_batch, downscale=args.aug_scale, concat=1,verbose=True,grey=False) 87 | 88 | from model.MACNet import Params 89 | from model.MACNet import MACNet 90 | 91 | params = Params(in_channels=1, channels=args.channels, 92 | num_half_layer=args.num_half_layer,downsample=args.downsample,sigma=args.noise_level,nl=args.nl) 93 | model = MACNet(params).to(device=device) 94 | # model = SubCNN_NL(params).to(device=device) 95 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 96 | print(f'Arguments: {vars(args)}') 97 | print('Nb tensors: ',len(list(model.named_parameters())), "; Trainable Params: ", pytorch_total_params, "; device: ", device, 98 | "; name : ", device_name) 99 | 100 | psnr = {x: np.zeros(args.num_epochs) for x in ['train', 'test', 'val']} 101 | model_name = args.model_name 102 | out_dir = os.path.join(model_name) 103 | ckpt_path = os.path.join(out_dir) 104 | config_dict = vars(args) 105 | 106 | if args.resume: 107 | if os.path.isfile(ckpt_path): 108 | try: 109 | print('\n existing ckpt detected') 110 | checkpoint = torch.load(ckpt_path, map_location=device) 111 | start_epoch = checkpoint['epoch'] 112 | psnr_validation = checkpoint['psnr_validation'] 113 | state_dict = checkpoint['state_dict'] 114 | new_state_dict = OrderedDict() 115 | for k, v in state_dict.items(): 116 | # name = k 117 | name = k[7:] # remove 'module.' of dataparallel 118 | new_state_dict[name] = v 119 | model.load_state_dict(new_state_dict, strict=True) 120 | print(f"=> loaded checkpoint '{ckpt_path}' (epoch {start_epoch})") 121 | # print(f"=> loaded checkpoint '{ckpt_path}' (epoch {start_epoch})") 122 | except Exception as e: 123 | print(e) 124 | print(f'ckpt loading failed @{ckpt_path}, exit ...') 125 | exit() 126 | 127 | else: 128 | print(f'\nno ckpt found @{ckpt_path}') 129 | exit() 130 | gpus=args.gpus 131 | if torch.cuda.is_available(): 132 | torch.backends.cudnn.benchmark = True 133 | if device.type=='cuda': 134 | torch.cuda.set_device('cuda:{}'.format(gpus[0])) 135 | model.to(device=device) 136 | if device.type=='cuda': 137 | model = torch.nn.DataParallel(model.to(device=device), device_ids=gpus) 138 | # l = args.kernel_size // 2 139 | tic = time.time() 140 | phase = 'test' 141 | print(f'\nstarting eval on test set with stride {args.stride_test}...') 142 | model.eval() # Set model to evaluate mode 143 | 144 | num_iters = 0 145 | psnr_tot = [] 146 | ssim_tot = [] 147 | sam_tot=[] 148 | stride_test = args.stride_test 149 | 150 | loader = loaders['test'] 151 | for batch,fname in tqdm(loader,disable=not args.tqdm): 152 | batch = batch.to(device=device) 153 | fname=fname[0] 154 | print(fname) 155 | noisy_batch = batch 156 | # scio.savemat(fname + 'Noisy.mat', {'output': noisy_batch.detach().cpu().numpy()}) 157 | with torch.set_grad_enabled(False): 158 | if args.block_inference: 159 | params = { 160 | 'crop_out_blocks': 0, 161 | 'ponderate_out_blocks': 1, 162 | 'sum_blocks': 0, 163 | 'pad_even': 1, # otherwise pad with 0 for the last block 164 | 'centered_pad': 0, # corner pixels have only one estimate 165 | 'pad_block': args.pad_block, # pad so each pixel has S**2 estimates 166 | 'pad_patch': args.pad_patch, # pad so each pixel from the image has at least S**2 estimates from 1 block 167 | 'no_pad': args.no_pad, 168 | 'custom_pad': args.custom_pad, 169 | 'avg': 1} 170 | block = block_module(args.patch_size, stride_test, args.kernel_size, params) 171 | batch_noisy_blocks = block._make_blocks(noisy_batch) 172 |
patch_loader = torch.utils.data.DataLoader(batch_noisy_blocks, batch_size=args.test_batch, drop_last=False) 173 | batch_out_blocks = torch.zeros_like(batch_noisy_blocks) 174 | for i, inp in enumerate(patch_loader): # if it doesnt fit in memory 175 | id_from, id_to = i * patch_loader.batch_size, (i + 1) * patch_loader.batch_size 176 | batch_out_blocks[id_from:id_to] = model(inp) 177 | 178 | output = block._agregate_blocks(batch_out_blocks) 179 | else: 180 | output = model(noisy_batch) 181 | gt=dataloaders_hsi_test.get_gt(gt_path,fname); 182 | gt=gt.to(device=device) 183 | if device_name=='cpu': 184 | psnr_batch, ssim_batch, sam_batch = MSIQA(gt.detach().numpy(), 185 | output.squeeze(0).detach().numpy()) 186 | # scio.savemat(fname + 'Res.mat', {'output': output.squeeze(0).detach().numpy()}) 187 | else: 188 | psnr_batch, ssim_batch, sam_batch = MSIQA(gt.detach().cpu().numpy(), 189 | output.squeeze(0).detach().cpu().numpy()) 190 | # scio.savemat(fname + 'Res.mat', {'output': output.squeeze(0).detach().cpu().numpy()}) 191 | 192 | # psnr_batch, ssim_batch, sam_batch=MSIQA(gt.detach().cpu().numpy(),output.squeeze(0).detach().cpu().numpy()) 193 | psnr_tot.append(psnr_batch) 194 | ssim_tot.append(ssim_batch) 195 | sam_tot.append(sam_batch) 196 | num_iters += 1 197 | tqdm.write(f'psnr avg {psnr_batch} ssim avg {ssim_batch} sam avg {sam_batch} ') 198 | if args.dummy: 199 | break 200 | tac = time.time() 201 | psnr_mean = np.mean(psnr_tot) 202 | ssim_mean = np.mean(ssim_tot) 203 | sam_mean = np.mean(sam_tot) 204 | scio.savemat(args.out_dir + str(args.noise_level)+'.mat', {'psnr': psnr_tot, 'ssim': ssim_tot, 'sam': sam_tot}) 205 | tqdm.write( 206 | f'psnr: {psnr_mean:0.4f} ssim: {ssim_mean:0.4f} sam: {sam_mean:0.4f}({(tac - tic) / num_iters:0.3f} s/iter)') 207 | -------------------------------------------------------------------------------- /ICVL_test_gauss.txt: -------------------------------------------------------------------------------- 1 | CC_40D_2_1103-0917.mat 2 | IDS_COLORCHECK_1020-1215-1.mat 3 | Labtest_0910-1504.mat 4 | Labtest_0910-1509.mat 5 | Labtest_0910-1513.mat 6 | Lehavim_0910-1636.mat 7 | Lehavim_0910-1716.mat 8 | Lehavim_0910-1717.mat 9 | Master5000K_2900K.mat 10 | bulb_0822-0909.mat 11 | gavyam_0823-0930.mat 12 | gavyam_0823-0933.mat 13 | gavyam_0823-0944.mat 14 | lehavim_0910-1600.mat 15 | lehavim_0910-1602.mat 16 | lehavim_0910-1610.mat 17 | nachal_0823-1047.mat 18 | nachal_0823-1118.mat 19 | nachal_0823-1121.mat 20 | nachal_0823-1127.mat 21 | nachal_0823-1144.mat 22 | nachal_0823-1149.mat 23 | nachal_0823-1214.mat 24 | nachal_0823-1217.mat 25 | nachal_0823-1222.mat 26 | nachal_0823-1223.mat 27 | negev_0823-1003.mat 28 | objects_0924-1650.mat 29 | rmt_0328-1241-1.mat 30 | rmt_0328-1249-1.mat 31 | selfie_0822-0906.mat 32 | ulm_0328-1118.mat 33 | Lehavim_0910-1626.mat 34 | Lehavim_0910-1635.mat 35 | Master2900k.mat 36 | Master5000K.mat 37 | objects_0924-1550.mat 38 | objects_0924-1623.mat 39 | objects_0924-1557.mat 40 | objects_0924-1602.mat 41 | objects_0924-1607.mat 42 | objects_0924-1611.mat 43 | objects_0924-1614.mat 44 | objects_0924-1617.mat 45 | objects_0924-1625.mat 46 | objects_0924-1628.mat 47 | objects_0924-1632.mat 48 | objects_0924-1637.mat 49 | objects_0924-1641.mat 50 | objects_0924-1648.mat 51 | -------------------------------------------------------------------------------- /ICVL_train.txt: -------------------------------------------------------------------------------- 1 | 4cam_0411-1640-1.mat 2 | 4cam_0411-1648.mat 3 | bguCAMP_0514-1659.mat 4 | 
bguCAMP_0514-1711.mat 5 | bguCAMP_0514-1712.mat 6 | bguCAMP_0514-1718.mat 7 | bguCAMP_0514-1723.mat 8 | bguCAMP_0514-1724.mat 9 | BGU_0403-1419-1.mat 10 | bgu_0403-1439.mat 11 | bgu_0403-1444.mat 12 | bgu_0403-1459.mat 13 | bgu_0403-1511.mat 14 | bgu_0403-1523.mat 15 | bgu_0403-1525.mat 16 | BGU_0522-1113-1.mat 17 | BGU_0522-1127.mat 18 | BGU_0522-1136.mat 19 | BGU_0522-1201.mat 20 | BGU_0522-1203.mat 21 | BGU_0522-1211.mat 22 | BGU_0522-1216.mat 23 | BGU_0522-1217.mat 24 | eve_0331-1551.mat 25 | eve_0331-1601.mat 26 | eve_0331-1602.mat 27 | eve_0331-1606.mat 28 | eve_0331-1618.mat 29 | eve_0331-1632.mat 30 | eve_0331-1633.mat 31 | eve_0331-1646.mat 32 | eve_0331-1647.mat 33 | eve_0331-1656.mat 34 | eve_0331-1657.mat 35 | eve_0331-1702.mat 36 | eve_0331-1705.mat 37 | Flower_0325-1336.mat 38 | hill_0325-1219.mat 39 | hill_0325-1228.mat 40 | hill_0325-1235.mat 41 | hill_0325-1242.mat 42 | Lehavim_0910-1630.mat 43 | eve_0331-1549.mat 44 | grf_0328-0949.mat 45 | lst_0408-0950.mat 46 | omer_0331-1055.mat 47 | peppers_0503-1311.mat 48 | prk_0328-1025.mat 49 | lst_0408-1004.mat 50 | lst_0408-1012.mat 51 | Master20150112_f2_colorchecker.mat 52 | Maz0326-1038.mat 53 | maz_0326-1048.mat 54 | mor_0328-1209-2.mat 55 | nachal_0823-1110.mat 56 | omer_0331-1102.mat 57 | omer_0331-1104.mat 58 | omer_0331-1118.mat 59 | omer_0331-1119.mat 60 | omer_0331-1130.mat 61 | omer_0331-1131.mat 62 | omer_0331-1135.mat 63 | omer_0331-1150.mat 64 | omer_0331-1159.mat 65 | peppers_0503-1308.mat 66 | peppers_0503-1315.mat 67 | peppers_0503-1330.mat 68 | peppers_0503-1332.mat 69 | pepper_0503-1228.mat 70 | pepper_0503-1229.mat 71 | pepper_0503-1236.mat 72 | plt_0411-1037.mat 73 | plt_0411-1046.mat 74 | plt_0411-1116.mat 75 | plt_0411-1155.mat 76 | plt_0411-1200-1.mat 77 | plt_0411-1207.mat 78 | plt_0411-1210.mat 79 | plt_0411-1211.mat 80 | plt_0411-1232-1.mat 81 | prk_0328-0945.mat 82 | prk_0328-1031.mat 83 | prk_0328-1034.mat 84 | prk_0328-1037.mat 85 | prk_0328-1045.mat 86 | Ramot0325-1364.mat 87 | ramot_0325-1322.mat 88 | rsh2_0406-1505.mat 89 | rsh_0406-1343.mat 90 | rsh_0406-1356.mat 91 | rsh_0406-1413.mat 92 | rsh_0406-1427.mat 93 | rsh_0406-1441-1.mat 94 | rsh_0406-1443.mat 95 | sami_0331-1019.mat 96 | sat_0406-1107.mat 97 | sat_0406-1129.mat 98 | sat_0406-1130.mat 99 | sat_0406-1157-1.mat 100 | strt_0331-1027.mat 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAC-Net: Model Aided Nonlocal Neural Network for Hyperspectral Image Denoising 2 | ## Fengchao Xiong; Jun Zhou; Qinling Zhao; Jianfeng Lu; Yuntao Qian 3 | 4 | [Link to paper](https://ieeexplore.ieee.org/abstract/document/9631264/) 5 | 6 | # Abstract 7 | 8 | Hyperspectral image (HSI) denoising is an ill-posed inverse problem. The underlying physical model is always important to tackle this problem, which is unfortunately ignored by most of the current deep learning (DL)-based methods, producing poor denoising performance. To address this issue, this paper introduces an end-to-end model aided nonlocal neural network (MAC-Net) which simultaneously takes the spectral low-rank model and spatial deep prior into account for HSI noise reduction. Specifically, motivated by the success of the spectral low-rank model in depicting the strong spectral correlations and the nonlocal similarity prior in capturing spatial long-range dependencies, we first build a spectral low-rank model and then integrate a nonlocal U-Net into the model. 
In this way, we obtain a hybrid model-based and DL-based HSI denoising method where the spatial local and nonlocal multi-scale and spectral low-rank structures are effectively exploited. After that, we cast the optimization and denoising procedure of the hybrid method as a forward process of a neural network and introduce a set of learnable modules to yield our MAC-Net. Compared with traditional model-based methods, our MAC-Net overcomes the difficulties of accurate modeling thanks to the strong learning and representation ability of DL. Unlike most “black-box” DL-based methods, the spectral low-rank model is beneficial to increase the generalization ability of the network and decrease the requirement of training samples. Experimental results on the natural and remote sensing HSIs show that MAC-Net achieves state-of-the-art performance over both model-based and DL-based methods. 9 | 10 | # Requirements 11 | 12 | We tested the implementation in Python 3.7. 13 | 14 | 15 | 16 | 17 | # Datasets 18 | 19 | * The ICVL dataset can be downloaded from [Link to Dataset](http://icvl.cs.bgu.ac.il/hyperspectral/). 20 | 21 | * The 100 HSIs used in our training can be found in [ICVL_train.txt](https://github.com/bearshng/mac-net/blob/master/ICVL_train.txt). 22 | 23 | 24 | * The 50 HSIs used for testing can be found in [ICVL_test_gauss.txt](https://github.com/bearshng/mac-net/blob/master/ICVL_test_gauss.txt). 25 | 26 | * The sample WDC testing HSI can be accessed via [Google Drive](https://drive.google.com/drive/folders/1XI2S-AVCsx1jNyO4-XvQnfY8n7sXj1mW?usp=sharing). 27 | 28 | 29 | # Test 30 | 31 | ## Test with known noise level 32 | 33 | `python test.py --test_path '---' --channels 16 --num_half_layer 5 --blind 0 --noise_level 15 --gt_path '---' --gpus 0 --verbose 1` 34 | 35 | ## Test with unknown noise level 36 | 37 | `python test.py --test_path '---' --channels 16 --num_half_layer 5 --blind 1 --noise_level 0 --gt_path '---' --gpus 0 --verbose 1` 38 | 39 | ## Test with real-world remote sensing HSI 40 | `python test.py --test_path '---' --channels 16 --num_half_layer 5 --blind 1 --noise_level 0 --gt_path '---' --gpus 0 --verbose 1 --save 1 --rs_real 1` 41 | 42 | 43 | ### Important arguments 44 | 45 | * `test_path `: path to the testing dataset 46 | * `channels `: the dimension of the feature extractor 47 | * `blind `: blind denoising 48 | * `noise_level `: the maximum noise level 49 | * `gt_path `: path to the ground truth images (for calculating PSNR, SSIM and SAM) 50 | * `gpus `: device id 51 | * `save` : save results 52 | * `rs_real `: real-world remote sensing HSI 53 | 54 | 55 | 56 | 57 | # Train 58 | 59 | ## Train your own model 60 | 61 | `python train.py --noise_level 15 --lr 5e-3 --patch_size 64 --train_path 'your_path' --test_path 'your_path' --log_dir './log' --out_dir './trained_model' --verbose 1 --validation_every 300 --gpus 0 --num_epochs 300 --bandwise 1 --train_batch 16 --num_half_layer 5` 62 | 63 | 64 | ### Important arguments 65 | 66 | * `noise_level `: maximum noise level 67 | * `bandwise `: add band-wise noise with a range of sigma values (see the sketch below) 68 | * `lr`: learning rate 69 | * `patch_size`: patch size 70 | 71 | * `train_path`: path to the training set 72 | * `test_path `: path to the validation set 73 | * `log_dir `: path to logs 74 | * `out_dir `: path to the trained models 75 | * `num_epochs `: maximum epochs 76 | * `train_batch `: batch size 77 |
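As a rough illustration of what `noise_level` and `bandwise` imply during training (a minimal sketch under our reading of these flags, not the repository's actual training code; `add_bandwise_noise` is a hypothetical helper):

```python
import torch

def add_bandwise_noise(x: torch.Tensor, noise_level: int = 15) -> torch.Tensor:
    """x: clean HSI of shape (bands, H, W) in [0, 1]; draws one sigma per band, each in [0, noise_level/255]."""
    sigmas = torch.rand(x.size(0), 1, 1, device=x.device) * noise_level / 255.0
    return x + torch.randn_like(x) * sigmas
```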
78 | ## Bibtex 79 | 80 | `@ARTICLE{macnet2021, author={Xiong, Fengchao and Zhou, Jun and Zhao, Qinling and Lu, Jianfeng and Qian, Yuntao}, journal={IEEE Transactions on Geoscience and Remote Sensing}, title={MAC-Net: Model Aided Nonlocal Neural Network for Hyperspectral Image Denoising}, year={2021}, volume={}, number={}, pages={1-1}, doi={10.1109/TGRS.2021.3131878}}` 81 | 82 | ## Update 83 | 84 | Because of the difference between MATLAB and Python in the SSIM index calculation, the value produced by Python is slightly higher than that produced by MATLAB. In MATLAB, the SSIM values of MAC-Net on the ICVL dataset are: 85 | 86 | | Sigma | SSIM | 87 | |:----------|:----------| 88 | | [0-15] | 0.9945 | 89 | | [0-55] | 0.9802 | 90 | | [0-95] | 0.9560 | 91 | | [blind] | 0.9700 | 92 | 93 | We reimplemented the `ssim` function to keep it consistent with MATLAB. 94 | ## Contact Information: 95 | 96 | Fengchao Xiong: fcxiong@njust.edu.cn 97 | 98 | School of Computer Science and Engineering 99 | 100 | Nanjing University of Science and Technology 101 | 102 | -------------------------------------------------------------------------------- /dataloaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | from os import listdir, path 4 | from PIL import Image 5 | import torch 6 | import torchvision.transforms.functional as TF 7 | import random 8 | from typing import Sequence 9 | from itertools import repeat 10 | 11 | def repeater(data_loader): 12 | for loader in repeat(data_loader): 13 | for data in loader: 14 | yield data 15 | 16 | class MyRotateTransform: 17 | def __init__(self, angles: Sequence[int]): 18 | self.angles = angles 19 | 20 | def __call__(self, x): 21 | angle = random.choice(self.angles) 22 | return TF.rotate(x, angle) 23 | 24 | 25 | class Dataset(Dataset): 26 | def __init__(self, root_dirs, transform=None, verbose=False, grey=False): 27 | self.root_dirs = root_dirs 28 | self.transform = transform 29 | self.images_path = [] 30 | for cur_path in root_dirs: 31 | self.images_path += [path.join(cur_path, file) for file in listdir(cur_path) if file.endswith(('tif','png','jpg','jpeg','bmp'))] 32 | self.verbose = verbose 33 | self.grey = grey 34 | 35 | def __len__(self): 36 | return len(self.images_path) 37 | 38 | def __getitem__(self, idx): 39 | img_name = self.images_path[idx] 40 | 41 | if self.grey: 42 | image = Image.open(img_name).convert('L') 43 | else: 44 | image = Image.open(img_name).convert('RGB') 45 | 46 | if self.transform: 47 | image = self.transform(image) 48 | 49 | if self.verbose: 50 | return image, img_name.split('/')[-1] 51 | 52 | return image 53 | 54 | 55 | def get_dataloaders(train_path_list, test_path_list, val_path_list, crop_size=128, batch_size=1, downscale=0, 56 | drop_last=True, concat=True, n_worker=0, scale_min=0.001, scale_max=0.1, verbose=False, grey=False): 57 | 58 | batch_sizes = {'train': batch_size, 'test':1, 'val': 1} 59 | tfs = [] 60 | if downscale==0: 61 | tfs = [transforms.RandomCrop(crop_size)] 62 | elif downscale==1: 63 | tfs += [transforms.RandomResizedCrop(crop_size, scale=(scale_min,scale_max), ratio=(1.0,1.0))] 64 | elif downscale==2: 65 | print('mode 2') 66 | tfs += [transforms.Resize(300)] 67 | tfs += [transforms.RandomCrop(crop_size)] 68 | 69 | tfs += [transforms.RandomCrop(crop_size), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.RandomVerticalFlip(), 72 | transforms.ToTensor()] 73 | 74 | train_transforms = transforms.Compose(tfs) 75 | test_transforms = transforms.Compose([transforms.ToTensor()]) 76 | 77 | data_transforms = {'train': train_transforms, 'test': test_transforms, 'val': test_transforms} 78 | 79 | if
concat: 80 | train = torch.utils.data.ConcatDataset( 81 | [Dataset(train_path_list, data_transforms['train'], verbose=verbose, grey=grey) for _ in range(batch_sizes['train'])]) 82 | else: 83 | train = Dataset(train_path_list, data_transforms['train'], verbose=verbose, grey=grey) 84 | 85 | image_datasets = {'train': train, 86 | 'test': Dataset(test_path_list, data_transforms['test'], verbose=verbose, grey=grey), 87 | 'val': Dataset(val_path_list, data_transforms['test'], verbose=verbose, grey=grey)} 88 | 89 | if len(val_path_list) == 0 or len(train_path_list) == 0: 90 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 91 | num_workers=n_worker, drop_last=drop_last, shuffle=(x == 'train')) 92 | for x in ['test']} 93 | else: 94 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 95 | num_workers=n_worker,drop_last=drop_last, shuffle=(x == 'train')) for x in ['train', 'test', 'val']} 96 | return dataloaders 97 | -------------------------------------------------------------------------------- /dataloaders_hsi.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | from os import listdir, path 4 | from PIL import Image 5 | import torch 6 | import torchvision.transforms.functional as TF 7 | import random 8 | from typing import Sequence 9 | from itertools import repeat 10 | import scipy.io as scio 11 | import numpy as np 12 | import torch 13 | import re 14 | from torch._six import container_abcs, string_classes, int_classes 15 | 16 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 17 | def repeater(data_loader): 18 | for loader in repeat(data_loader): 19 | for data in loader: 20 | yield data 21 | class MyResize: 22 | def __init__(self, scale,crop): 23 | self.scale = scale 24 | self.crop = crop 25 | 26 | 27 | def __call__(self, x): 28 | bands = x.shape[2] 29 | # if bands > 31: 30 | # bs = int(np.random.rand(1) * bands) 31 | # if bs + 31 > bands: 32 | # bs = bands - 31 33 | # x = x[:, :, bs:bs + 31] 34 | im_sz=x.shape 35 | rs=[int(im_sz[0]*self.scale),int(im_sz[1]*self.scale)] 36 | if rs[0] _w: 121 | x2 = _w 122 | x = _w - self.size 123 | if y2 > _h: 124 | y2 = _h 125 | y = _h - self.size 126 | cropImg = img[(x):(x2), (y):(y2), :] 127 | return cropImg 128 | 129 | # return self.cropit(img,self.size) 130 | # return img 131 | def cropit(image, crop_size): 132 | _w, _h, _b = image.shape 133 | x = random.randint(1, _w) 134 | y = random.randint(1, _h) 135 | x2 = x + crop_size 136 | y2 = y + crop_size 137 | if x2 > _w: 138 | x2 = _w 139 | x = _w - crop_size 140 | if y2 > _h: 141 | y2 = _h 142 | y = _h - crop_size 143 | cropImg = image[(x):(x2), (y):(y2), :] 144 | return cropImg 145 | class MyToTensor(object): 146 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 147 | 148 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 149 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 150 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 151 | or if the numpy.ndarray has dtype = np.uint8 152 | 153 | In the other cases, tensors are returned without scaling. 154 | """ 155 | 156 | def __call__(self, pic): 157 | """ 158 | Args: 159 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 160 | 161 | Returns: 162 | Tensor: Converted image. 
163 | """ 164 | return TF.to_tensor(pic.copy()) 165 | 166 | def __repr__(self): 167 | return self.__class__.__name__ + '()' 168 | 169 | 170 | 171 | 172 | class Dataset(Dataset): 173 | def __init__(self, root_dirs, transform=None, verbose=False, grey=False): 174 | self.root_dirs = root_dirs 175 | self.transform = transform 176 | self.images_path = [] 177 | for cur_path in root_dirs: 178 | self.images_path += [path.join(cur_path, file) for file in listdir(cur_path) if file.endswith(('tif','png','jpg','jpeg','bmp','mat'))] 179 | self.verbose = verbose 180 | self.grey = grey 181 | 182 | def __len__(self): 183 | return len(self.images_path) 184 | 185 | def __getitem__(self, idx): 186 | img_name = self.images_path[idx] 187 | 188 | if self.grey: 189 | image = Image.open(img_name).convert('L') 190 | else: 191 | # image = Image.open(img_name).convert('RGB') 192 | image = scio.loadmat(img_name)['DataCube'].astype(np.float32) 193 | image=image/image.max() 194 | # image = flipit(flipit(cropit(image,crop_size=128),[0,1]),[1,0]) 195 | 196 | # image=transforms.ToPILImage(image) 197 | if self.transform: 198 | image = self.transform(image) 199 | 200 | 201 | if self.verbose: 202 | return image, img_name.split('/')[-1] 203 | 204 | return image 205 | 206 | 207 | def get_dataloaders(train_path_list, test_path_list, val_path_list, crop_size=96, batch_size=1, downscale=0, 208 | drop_last=True, concat=True, n_worker=0, scale_min=0.001, scale_max=0.1, verbose=False, grey=False): 209 | 210 | batch_sizes = {'train': batch_size, 'test':1, 'val': 1} 211 | tfs = [] 212 | # if downscale==0: 213 | # tfs = [MyRandomCrop(crop_size)] 214 | # elif downscale==1: 215 | # tfs += [transforms.RandomResizedCrop(crop_size, scale=(scale_min,scale_max), ratio=(1.0,1.0))] 216 | # elif downscale==2: 217 | # print('mode 2') 218 | # tfs += [transforms.Resize(300)] 219 | # tfs += [transforms.RandomCrop(crop_size)] 220 | scale=np.random.rand(1) 221 | # rs=int(scale) 222 | # =np.floor([,crop_size*scale]) 223 | # 224 | tfs += [ 225 | MyResize(scale, crop_size), 226 | MyRandomCrop(crop_size), 227 | MyRandomHorizontalFlip(), 228 | MyRandomVerticalFlip(), 229 | MyToTensor() 230 | ] 231 | 232 | train_transforms = transforms.Compose(tfs) 233 | test_transforms = transforms.Compose([MyToTensor()]) 234 | 235 | data_transforms = {'train': train_transforms, 'test': test_transforms, 'val': test_transforms} 236 | 237 | if concat: 238 | train = torch.utils.data.ConcatDataset( 239 | [Dataset(train_path_list, data_transforms['train'], verbose=verbose, grey=grey) for _ in range(batch_sizes['train'])]) 240 | else: 241 | train = Dataset(train_path_list, data_transforms['train'], verbose=verbose, grey=grey) 242 | 243 | image_datasets = {'train': train, 244 | 'test': Dataset(test_path_list, data_transforms['test'], verbose=verbose, grey=grey), 245 | 'val': Dataset(val_path_list, data_transforms['test'], verbose=verbose, grey=grey)} 246 | 247 | if len(val_path_list) == 0 or len(train_path_list) == 0: 248 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 249 | # num_workers=n_worker,collate_fn=collate_wrapper, drop_last=drop_last, shuffle=(x == 'train')) 250 | # for x in ['test']} 251 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 252 | num_workers=n_worker, drop_last=drop_last, shuffle=(x == 'train')) 253 | for x in ['test']} 254 | else: 255 | # dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 256 | # num_workers=n_worker, 
drop_last=drop_last,collate_fn=collate_wrapper, shuffle=(x == 'train')) 257 | # for x in ['train', 'test', 'val']} 258 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 259 | num_workers=n_worker,drop_last=drop_last, shuffle=(x == 'train')) for x in ['train', 'test', 'val']} 260 | return dataloaders 261 | 262 | def flipit(image, axes): 263 | 264 | if axes[0]: 265 | image = np.fliplr(image) 266 | if axes[1]: 267 | image = np.flipud(image) 268 | 269 | return image 270 | default_collate_err_msg_format = ( 271 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 272 | "dicts or lists; found {}") 273 | 274 | def collate_wrapper(batch): 275 | elem = batch[0] 276 | elem_type = type(elem) 277 | if isinstance(elem, torch.Tensor): 278 | out = None 279 | if torch.utils.data.get_worker_info() is not None: 280 | # If we're in a background process, concatenate directly into a 281 | # shared memory tensor to avoid an extra copy 282 | numel = sum([x.numel() for x in batch]) 283 | storage = elem.storage()._new_shared(numel) 284 | out = elem.new(storage) 285 | return torch.cat(batch, 0, out=out) 286 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 287 | and elem_type.__name__ != 'string_': 288 | elem = batch[0] 289 | if elem_type.__name__ == 'ndarray': 290 | # array of string classes and object 291 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 292 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 293 | 294 | return collate_wrapper([torch.as_tensor(b) for b in batch]) 295 | elif elem.shape == (): # scalars 296 | return torch.as_tensor(batch) 297 | elif isinstance(elem, float): 298 | return torch.tensor(batch, dtype=torch.float64) 299 | elif isinstance(elem, int_classes): 300 | return torch.tensor(batch) 301 | elif isinstance(elem, string_classes): 302 | return batch 303 | elif isinstance(elem, container_abcs.Mapping): 304 | return {key: default_collate([d[key] for d in batch]) for key in elem} 305 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 306 | return elem_type(*(default_collate(samples) for samples in zip(*batch))) 307 | elif isinstance(elem, container_abcs.Sequence): 308 | transposed = zip(*batch) 309 | return [default_collate(samples) for samples in transposed] 310 | 311 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 312 | 313 | 314 | # def cropit(image, seg=None, margin=5): 315 | # 316 | # fixedaxes = np.argmin(image.shape[:2]) 317 | # trimaxes = 0 if fixedaxes == 1 else 1 318 | # trim = image.shape[fixedaxes] 319 | # center = image.shape[trimaxes] // 2 320 | # if seg is not None: 321 | # 322 | # hits = np.where(seg != 0) 323 | # mins = np.argmin(hits, axis=1) 324 | # maxs = np.argmax(hits, axis=1) 325 | # 326 | # if center - (trim // 2) > mins[0]: 327 | # while center - (trim // 2) > mins[0]: 328 | # center = center - 1 329 | # center = center + margin 330 | # 331 | # if center + (trim // 2) < maxs[0]: 332 | # while center + (trim // 2) < maxs[0]: 333 | # center = center + 1 334 | # center = center + margin 335 | # 336 | # top = max(0, center - (trim // 2)) 337 | # bottom = trim if top == 0 else center + (trim // 2) 338 | # 339 | # if bottom > image.shape[trimaxes]: 340 | # bottom = image.shape[trimaxes] 341 | # top = image.shape[trimaxes] - trim 342 | # 343 | # if trimaxes == 0: 344 | # image = image[top: bottom, :, :] 345 | # else: 346 | # image = image[:, top: bottom, :] 347 | # 348 | # if seg is not None: 349 | # if 
trimaxes == 0: 350 | # seg = seg[top: bottom, :, :] 351 | # else: 352 | # seg = seg[:, top: bottom, :] 353 | # 354 | # return image, seg 355 | # else: 356 | # return image 357 | 358 | -------------------------------------------------------------------------------- /dataloaders_hsi_test.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | from os import listdir, path 4 | from PIL import Image 5 | import torch 6 | import math 7 | import torchvision.transforms.functional as TF 8 | import random 9 | from typing import Sequence 10 | from itertools import repeat 11 | import scipy.io as scio 12 | import numpy as np 13 | import torch 14 | import re 15 | from torch._six import container_abcs, string_classes, int_classes 16 | 17 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 18 | def repeater(data_loader): 19 | for loader in repeat(data_loader): 20 | for data in loader: 21 | yield data 22 | class MyResize: 23 | def __init__(self, scale,crop): 24 | self.scale = scale 25 | self.crop = crop 26 | 27 | 28 | def __call__(self, x): 29 | bands = x.shape[2] 30 | if bands > 31: 31 | bs = int(np.random.rand(1) * bands) 32 | if bs + 31 > bands: 33 | bs = bands - 31 34 | x = x[:, :, bs:bs + 31] 35 | im_sz=x.shape 36 | rs=[int(im_sz[0]*self.scale),int(im_sz[1]*self.scale)] 37 | if rs[0] _w: 126 | x2 = _w 127 | x = _w - self.size 128 | if y2 > _h: 129 | y2 = _h 130 | y = _h - self.size 131 | cropImg = img[(x):(x2), (y):(y2), :] 132 | return cropImg 133 | 134 | # return self.cropit(img,self.size) 135 | # return img 136 | def cropit(image, crop_size): 137 | _w, _h, _b = image.shape 138 | x = random.randint(1, _w) 139 | y = random.randint(1, _h) 140 | x2 = x + crop_size 141 | y2 = y + crop_size 142 | if x2 > _w: 143 | x2 = _w 144 | x = _w - crop_size 145 | if y2 > _h: 146 | y2 = _h 147 | y = _h - crop_size 148 | cropImg = image[(x):(x2), (y):(y2), :] 149 | return cropImg 150 | class MyToTensor(object): 151 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 152 | 153 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 154 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 155 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 156 | or if the numpy.ndarray has dtype = np.uint8 157 | 158 | In the other cases, tensors are returned without scaling. 159 | """ 160 | 161 | def __call__(self, pic): 162 | """ 163 | Args: 164 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 165 | 166 | Returns: 167 | Tensor: Converted image. 
168 | """ 169 | return TF.to_tensor(pic.copy()) 170 | 171 | def __repr__(self): 172 | return self.__class__.__name__ + '()' 173 | 174 | 175 | 176 | 177 | class Dataset(Dataset): 178 | def __init__(self, root_dirs, transform=None, verbose=False, grey=False): 179 | self.root_dirs = root_dirs 180 | self.transform = transform 181 | self.images_path = [] 182 | for cur_path in root_dirs: 183 | self.images_path += [path.join(cur_path, file) for file in listdir(cur_path) if file.endswith(('tif','png','jpg','jpeg','bmp','mat'))] 184 | self.verbose = verbose 185 | self.grey = grey 186 | 187 | def __len__(self): 188 | return len(self.images_path) 189 | 190 | def __getitem__(self, idx): 191 | img_name = self.images_path[idx] 192 | 193 | if self.grey: 194 | image = Image.open(img_name).convert('L') 195 | else: 196 | # image = Image.open(img_name).convert('RGB') 197 | image = scio.loadmat(img_name)['DataCube'].astype(np.float32) 198 | # image=image/image.max() 199 | # image = flipit(flipit(cropit(image,crop_size=128),[0,1]),[1,0]) 200 | 201 | # image=transforms.ToPILImage(image) 202 | if self.transform: 203 | image = self.transform(image) 204 | 205 | 206 | if self.verbose: 207 | return image, img_name.split('/')[-1] 208 | 209 | return image 210 | def get_gt(gt_path, img_name,verbose=False, grey=False): 211 | tfs = [] 212 | tfs += [ 213 | # MyRotation90(), 214 | # MyCenterCrop(), 215 | MyToTensor() 216 | ] 217 | gt_transforms = transforms.Compose(tfs) 218 | image = scio.loadmat(gt_path+img_name)['DataCube'].astype(np.float32) 219 | image = gt_transforms(image) 220 | image=image/image.max() 221 | return image 222 | 223 | def get_dataloaders(test_path_list, crop_size=96, batch_size=1, downscale=0, 224 | drop_last=True, concat=True, n_worker=0, scale_min=0.001, scale_max=0.1, verbose=False, grey=False): 225 | 226 | batch_sizes = {'test':1, 'gt': 1} 227 | test_transforms = transforms.Compose([MyToTensor()]) 228 | data_transforms = {'test': test_transforms} 229 | image_datasets = {'test': Dataset(test_path_list, data_transforms['test'], verbose=verbose, grey=grey)} 230 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 231 | num_workers=n_worker,drop_last=drop_last, shuffle=False) for x in ['test']} 232 | return dataloaders 233 | 234 | def flipit(image, axes): 235 | 236 | if axes[0]: 237 | image = np.fliplr(image) 238 | if axes[1]: 239 | image = np.flipud(image) 240 | 241 | return image 242 | default_collate_err_msg_format = ( 243 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 244 | "dicts or lists; found {}") 245 | 246 | 247 | # def cropit(image, seg=None, margin=5): 248 | # 249 | # fixedaxes = np.argmin(image.shape[:2]) 250 | # trimaxes = 0 if fixedaxes == 1 else 1 251 | # trim = image.shape[fixedaxes] 252 | # center = image.shape[trimaxes] // 2 253 | # if seg is not None: 254 | # 255 | # hits = np.where(seg != 0) 256 | # mins = np.argmin(hits, axis=1) 257 | # maxs = np.argmax(hits, axis=1) 258 | # 259 | # if center - (trim // 2) > mins[0]: 260 | # while center - (trim // 2) > mins[0]: 261 | # center = center - 1 262 | # center = center + margin 263 | # 264 | # if center + (trim // 2) < maxs[0]: 265 | # while center + (trim // 2) < maxs[0]: 266 | # center = center + 1 267 | # center = center + margin 268 | # 269 | # top = max(0, center - (trim // 2)) 270 | # bottom = trim if top == 0 else center + (trim // 2) 271 | # 272 | # if bottom > image.shape[trimaxes]: 273 | # bottom = image.shape[trimaxes] 274 | # top = image.shape[trimaxes] - trim 
275 | # 276 | # if trimaxes == 0: 277 | # image = image[top: bottom, :, :] 278 | # else: 279 | # image = image[:, top: bottom, :] 280 | # 281 | # if seg is not None: 282 | # if trimaxes == 0: 283 | # seg = seg[top: bottom, :, :] 284 | # else: 285 | # seg = seg[:, top: bottom, :] 286 | # 287 | # return image, seg 288 | # else: 289 | # return image 290 | 291 | -------------------------------------------------------------------------------- /model/MACNet.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from ops.utils import est_noise,count 3 | # from model.qrnn.combinations import * 4 | from model.non_local import NLBlockND,EfficientNL 5 | from model.combinations import * 6 | Params = namedtuple('Params', ['in_channels', 'channels', 'num_half_layer','rs']) 7 | from skimage.restoration import denoise_nl_means,estimate_sigma 8 | class MACNet(nn.Module): 9 | ''' 10 | Tied lista with coupling 11 | ''' 12 | 13 | def __init__(self, params: Params): 14 | super(MACNet, self).__init__() 15 | self.params=params 16 | self.net=REDC3DBNRES_NL(in_channels=params.in_channels,channels=params.channels,num_half_layer=params.num_half_layer) 17 | 18 | def forward(self, I, writer=None, epoch=None, return_patches=False): 19 | 20 | return self.pro_sub(I) 21 | 22 | def pro_sub(self, I): 23 | R = list() 24 | Ek = list() 25 | Rw = list() 26 | I_iid = list() 27 | sigma_est = 0 28 | I_size = I.shape 29 | for _I in I: 30 | _I = _I.permute([1, 2, 0]) 31 | _, _, w, _Rw = count(_I) # count subspace 32 | _I = torch.matmul(_I, torch.inverse(_Rw).sqrt()) # spectral iid 33 | I_nlm = _I.cpu().numpy() 34 | sigma_est = estimate_sigma(I_nlm, multichannel=True, average_sigmas=True) 35 | I_nlm = denoise_nl_means(I_nlm, patch_size=7, patch_distance=9, h=0.08, multichannel=True, 36 | fast_mode=True, sigma=sigma_est) 37 | I_nlm = torch.FloatTensor(I_nlm).to(device=_I.device) 38 | _R, _Ek, _, _ = count(I_nlm) 39 | if self.params.rs: 40 | _R = _R // 3 41 | # _R = max(_R, torch.FloatTensor(3).to(I.device)) 42 | R.append(_R) 43 | Ek.append(_Ek) 44 | Rw.append(_Rw) 45 | I_iid.append(_I) 46 | dim = max(torch.stack(R).max(), 3) 47 | Ek = torch.stack(Ek, dim=0) 48 | I_iid = torch.stack(I_iid, dim=0) 49 | Ek = Ek[:, :, 0:dim] 50 | Rw = torch.stack(Rw, dim=0) 51 | I_sub = torch.bmm(I_iid.view(I_size[0], -1, I_size[1]), Ek) 52 | I_sub = I_sub.view(I_size[0], I_size[2], I_size[3], -1).permute([0, 3, 1, 2]) 53 | CNN_sub = self.net(I_sub.unsqueeze(1)).squeeze(1) 54 | CNN_sub = CNN_sub.view(I_size[0], dim, -1) 55 | output = torch.bmm(Rw.sqrt(), torch.bmm(Ek, CNN_sub)) 56 | output = output.view(I_size) 57 | return output 58 | class REDC3DBNRES_NL(torch.nn.Module): 59 | """Residual Encoder-Decoder Convolution 3D 60 | Args: 61 | downsample: downsample times, None denotes no downsample""" 62 | 63 | def __init__(self, in_channels, channels, num_half_layer, downsample=None): 64 | super(REDC3DBNRES_NL, self).__init__() 65 | # Encoder 66 | # assert downsample is None or 0 < downsample <= num_half_layer 67 | interval = 2 68 | 69 | self.feature_extractor = BNReLUConv3d(in_channels, channels) 70 | self.encoder = nn.ModuleList() 71 | for i in range(1, num_half_layer + 1): 72 | if i % interval: 73 | encoder_layer = BNReLUConv3d(channels, channels) 74 | else: 75 | encoder_layer = BNReLUConv3d(channels, 2 * channels, k=3, s=(1, 2, 2), p=1) 76 | channels *= 2 77 | self.encoder.append(encoder_layer) 78 | # Decoder 79 | self.decoder = nn.ModuleList() 80 | for i in range(1, num_half_layer + 
1): 81 | if i % interval: 82 | decoder_layer = BNReLUDeConv3d(channels, channels) 83 | else: 84 | decoder_layer = BNReLUUpsampleConv3d(channels, channels // 2) 85 | channels //= 2 86 | self.decoder.append(decoder_layer) 87 | self.reconstructor = BNReLUDeConv3d(channels, in_channels) 88 | # self.enl_1 = EfficientNL(in_channels=channels) 89 | self.enl_2 = EfficientNL(in_channels=channels) 90 | self.enl_3 = EfficientNL(in_channels=1,key_channels=1,value_channels=1,head_count=1) 91 | 92 | # = None, head_count = None, = None 93 | def forward(self, x): 94 | num_half_layer = len(self.encoder) 95 | xs = [x] 96 | out = self.feature_extractor(xs[0]) 97 | xs.append(out) 98 | for i in range(num_half_layer - 1): 99 | out = self.encoder[i](out) 100 | xs.append(out) 101 | out = self.encoder[-1](out) 102 | # out = self.nl_1(out) 103 | out = self.decoder[0](out) 104 | for i in range(1, num_half_layer): 105 | out = out + xs.pop() 106 | out = self.decoder[i](out) 107 | out = self.enl_2(out) + xs.pop() 108 | out = self.reconstructor(out) 109 | out = self.enl_3(out) + xs.pop() 110 | return out 111 | -------------------------------------------------------------------------------- /model/__pycache__/CBAM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/CBAM.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/MACNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/MACNet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubCNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubCNN.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubCNN_NL.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubCNN_NL.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubCNN_NLM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubCNN_NLM.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubCNN_NL_TIP.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubCNN_NL_TIP.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubCNN_QRNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubCNN_QRNN.cpython-37.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/SubSCNN_BNRED.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubSCNN_BNRED.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubSCNN_BNREDCBAM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubSCNN_BNREDCBAM.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/SubSCNN_BNREDRES.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/SubSCNN_BNREDRES.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/combinations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/combinations.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/non_local.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/non_local.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/pyramidpooling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/__pycache__/pyramidpooling.cpython-37.pyc -------------------------------------------------------------------------------- /model/combinations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional 4 | from model.sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 5 | 6 | BatchNorm3d = SynchronizedBatchNorm3d 7 | BatchNorm2d=SynchronizedBatchNorm2d 8 | 9 | class BNReLUConv3d(nn.Sequential): 10 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 11 | super(BNReLUConv3d, self).__init__() 12 | self.add_module('bn', BatchNorm3d(in_channels)) 13 | self.add_module('relu', nn.ReLU(inplace=inplace)) 14 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 15 | class BNReLUConv2d(nn.Sequential): 16 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 17 | super(BNReLUConv2d, self).__init__() 18 | self.add_module('bn', BatchNorm2d(in_channels)) 19 | self.add_module('relu', nn.ReLU(inplace=inplace)) 20 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)) 21 | class Conv3dBNReLU(nn.Sequential): 22 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 23 | super(Conv3dBNReLU, self).__init__() 24 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 25 | self.add_module('bn', BatchNorm3d(channels)) 26 | 
self.add_module('relu', nn.ReLU(inplace=inplace)) 27 | class Conv2dBNReLU(nn.Sequential): 28 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 29 | super(Conv2dBNReLU, self).__init__() 30 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)) 31 | self.add_module('bn', BatchNorm2d(channels)) 32 | self.add_module('relu', nn.ReLU(inplace=inplace)) 33 | 34 | class BNReLUDeConv3d(nn.Sequential): 35 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 36 | super(BNReLUDeConv3d, self).__init__() 37 | self.add_module('bn', BatchNorm3d(in_channels)) 38 | self.add_module('relu', nn.ReLU(inplace=inplace)) 39 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 40 | class BNReLUDeConv2d(nn.Sequential): 41 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 42 | super(BNReLUDeConv2d, self).__init__() 43 | self.add_module('bn', BatchNorm2d(in_channels)) 44 | self.add_module('relu', nn.ReLU(inplace=inplace)) 45 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=False)) 46 | 47 | class DeConv3dBNReLU(nn.Sequential): 48 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 49 | super(DeConv3dBNReLU, self).__init__() 50 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 51 | self.add_module('bn', BatchNorm3d(channels)) 52 | self.add_module('relu', nn.ReLU(inplace=inplace)) 53 | class DeConv2dBNReLU(nn.Sequential): 54 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 55 | super(DeConv2dBNReLU, self).__init__() 56 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=False)) 57 | self.add_module('bn', BatchNorm2d(channels)) 58 | self.add_module('relu', nn.ReLU(inplace=inplace)) 59 | 60 | 61 | 62 | class ReLUDeConv3d(nn.Sequential): 63 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 64 | super(ReLUDeConv3d, self).__init__() 65 | self.add_module('relu', nn.ReLU(inplace=inplace)) 66 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 67 | class ReLUDeConv2d(nn.Sequential): 68 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 69 | super(ReLUDeConv2d, self).__init__() 70 | self.add_module('relu', nn.ReLU(inplace=inplace)) 71 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=False)) 72 | 73 | class BNReLUUpsampleConv3d(nn.Sequential): 74 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False): 75 | super(BNReLUUpsampleConv3d, self).__init__() 76 | self.add_module('bn', BatchNorm3d(in_channels)) 77 | self.add_module('relu', nn.ReLU(inplace=inplace)) 78 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 79 | class BNReLUUpsampleConv2d(nn.Sequential): 80 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2,2), inplace=False): 81 | super(BNReLUUpsampleConv2d, self).__init__() 82 | self.add_module('bn', BatchNorm2d(in_channels)) 83 | self.add_module('relu', nn.ReLU(inplace=inplace)) 84 | self.add_module('upsample_conv', UpsampleConv2d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 85 | 86 | class UpsampleConv3dBNReLU(nn.Sequential): 87 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False): 88 | super(UpsampleConv3dBNReLU, self).__init__() 89 |
self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 90 | self.add_module('bn', BatchNorm3d(channels)) 91 | self.add_module('relu', nn.ReLU(inplace=inplace)) 92 | class UpsampleConv2dBNReLU(nn.Sequential): 93 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2,2), inplace=False): 94 | super(UpsampleConv2dBNReLU, self).__init__() 95 | self.add_module('upsample_conv', UpsampleConv2d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 96 | self.add_module('bn', BatchNorm2d(channels)) 97 | self.add_module('relu', nn.ReLU(inplace=inplace)) 98 | 99 | 100 | 101 | class Conv3dReLU(nn.Sequential): 102 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False,bn=False): 103 | super(Conv3dReLU, self).__init__() 104 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 105 | if bn: 106 | self.add_module('bn', BatchNorm3d(channels)) 107 | self.add_module('relu', nn.ReLU(inplace=inplace)) 108 | class Conv2dReLU(nn.Sequential): 109 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False,bn=False): 110 | super(Conv2dReLU, self).__init__() 111 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)) 112 | if bn: 113 | self.add_module('bn', BatchNorm2d(channels)) 114 | self.add_module('relu', nn.ReLU(inplace=inplace)) 115 | 116 | 117 | class DeConv3dReLU(nn.Sequential): 118 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False,bn=False): 119 | super(DeConv3dReLU, self).__init__() 120 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 121 | if bn: 122 | self.add_module('bn', BatchNorm3d(channels)) 123 | self.add_module('relu', nn.ReLU(inplace=inplace)) 124 | 125 | 126 | class DeConv2dReLU(nn.Sequential): 127 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False,bn=False): 128 | super(DeConv2dReLU, self).__init__() 129 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=False)) 130 | if bn: 131 | self.add_module('bn', BatchNorm2d(channels)) 132 | self.add_module('relu', nn.ReLU(inplace=inplace)) 133 | 134 | class UpsampleConv3dReLU(nn.Sequential): 135 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False,bn=False): 136 | super(UpsampleConv3dReLU, self).__init__() 137 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 138 | if bn: 139 | self.add_module('bn', BatchNorm3d(channels)) 140 | self.add_module('relu', nn.ReLU(inplace=inplace)) 141 | class UpsampleConv2dReLU(nn.Sequential): 142 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2,2), inplace=False): 143 | super(UpsampleConv2dReLU, self).__init__() 144 | self.add_module('upsample_conv', UpsampleConv2d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 145 | self.add_module('relu', nn.ReLU(inplace=inplace)) 146 |
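# --- Editor's annotation (not part of the original file) ---
# The classes above are thin nn.Sequential factories in two orderings: pre-activation
# (BNReLU*: BN -> ReLU -> conv/deconv) and post-activation (*BNReLU: conv/deconv -> BN -> ReLU).
# For example, BNReLUConv3d(16, 16) keeps the spatial size of an (N, 16, D, H, W) input
# (k=3, s=1, p=1), while the UpsampleConv3d/UpsampleConv2d modules below replace transposed
# convolution with upsample-then-conv to avoid checkerboard artifacts (see the distill.pub reference).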
151 | ref: http://distill.pub/2016/deconv-checkerboard/ 152 | """ 153 | 154 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 155 | super(UpsampleConv3d, self).__init__() 156 | self.upsample = upsample 157 | if upsample: 158 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True) 159 | 160 | self.conv3d = torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 161 | 162 | def forward(self, x): 163 | x_in = x 164 | if self.upsample: 165 | x_in = self.upsample_layer(x_in) 166 | out = self.conv3d(x_in) 167 | return out 168 | 169 | 170 | class UpsampleConv2d(torch.nn.Module): 171 | """UpsampleConvLayer 172 | Upsamples the input and then does a convolution. This method gives better results 173 | compared to ConvTranspose2d. 174 | ref: http://distill.pub/2016/deconv-checkerboard/ 175 | """ 176 | 177 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 178 | super(UpsampleConv2d, self).__init__() 179 | self.upsample = upsample 180 | if upsample: 181 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='bilinear', align_corners=True) 182 | 183 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 184 | 185 | def forward(self, x): 186 | x_in = x 187 | if self.upsample: 188 | x_in = self.upsample_layer(x_in) 189 | out = self.conv2d(x_in) 190 | return out 191 | 192 | 193 | class BasicConv3d(nn.Sequential): 194 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 195 | super(BasicConv3d, self).__init__() 196 | if bn: 197 | self.add_module('bn', BatchNorm3d(in_channels)) 198 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)) 199 | 200 | class BasicConv2d(nn.Sequential): 201 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 202 | super(BasicConv2d, self).__init__() 203 | if bn: 204 | self.add_module('bn', BatchNorm2d(in_channels)) 205 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=bias)) 206 | 207 | class BasicDeConv3d(nn.Sequential): 208 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 209 | super(BasicDeConv3d, self).__init__() 210 | if bn: 211 | self.add_module('bn', BatchNorm3d(in_channels)) 212 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)) 213 | 214 | class BasicDeConv2d(nn.Sequential): 215 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 216 | super(BasicDeConv2d, self).__init__() 217 | if bn: 218 | self.add_module('bn', BatchNorm2d(in_channels)) 219 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=bias)) 220 | 221 | 222 | class BasicUpsampleConv3d(nn.Sequential): 223 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True): 224 | super(BasicUpsampleConv3d, self).__init__() 225 | if bn: 226 | self.add_module('bn', BatchNorm3d(in_channels)) 227 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 228 | class BasicUpsampleConv2d(nn.Sequential): 229 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2,2), bn=True): 230 | super(BasicUpsampleConv2d, self).__init__() 231 | if bn: 232 | self.add_module('bn', BatchNorm2d(in_channels)) 233 | self.add_module('upsample_conv', UpsampleConv2d(in_channels, 
channels, k, s, p, bias=False, upsample=upsample)) 234 | -------------------------------------------------------------------------------- /model/non_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | class EfficientNL(nn.Module): 5 | def __init__(self, in_channels, key_channels=None, head_count=None, value_channels=None): 6 | super(EfficientNL, self).__init__() 7 | self.in_channels = in_channels 8 | self.key_channels = key_channels 9 | self.head_count = head_count 10 | self.value_channels = value_channels 11 | if self.key_channels is None: 12 | self.key_channels=self.in_channels//2 13 | if self.value_channels is None: 14 | self.value_channels = self.in_channels // 2 15 | if self.head_count is None: 16 | self.head_count = 2 17 | self.keys = nn.Conv3d( self.in_channels, self.key_channels, 1) 18 | self.queries = nn.Conv3d( self.in_channels, self.key_channels, 1) 19 | self.values = nn.Conv3d( self.in_channels, self.value_channels, 1) 20 | self.reprojection = nn.Conv3d(self.value_channels, self.in_channels, 1) 21 | 22 | def forward(self, input_): 23 | n, _,c, h, w = input_.size() 24 | keys = self.keys(input_).reshape((n, self.key_channels,-1)) 25 | queries = self.queries(input_).reshape(n, self.key_channels, -1) 26 | values = self.values(input_).reshape((n, self.value_channels, -1)) 27 | head_key_channels = self.key_channels // self.head_count 28 | head_value_channels = self.value_channels // self.head_count 29 | 30 | attended_values = [] 31 | for i in range(self.head_count): 32 | key = F.softmax(keys[ 33 | :, 34 | i * head_key_channels: (i + 1) * head_key_channels, 35 | : 36 | ], dim=2) 37 | query = F.softmax(queries[ 38 | :, 39 | i * head_key_channels: (i + 1) * head_key_channels, 40 | : 41 | ], dim=1) 42 | value = values[ 43 | :, 44 | i * head_value_channels: (i + 1) * head_value_channels, 45 | : 46 | ] 47 | context = key @ value.transpose(1, 2) 48 | attended_value = ( 49 | context.transpose(1, 2) @ query 50 | ).reshape(n, head_value_channels,c, h, w) 51 | attended_values.append(attended_value) 52 | 53 | aggregated_values = torch.cat(attended_values, dim=1) 54 | reprojected_value = self.reprojection(aggregated_values) 55 | attention = reprojected_value + input_ 56 | return attention 57 | 58 | class NLBlockND(nn.Module): 59 | def __init__(self, in_channels, inter_channels=None, mode='embedded', 60 | dimension=3, bn_layer=True, levels=None): 61 | """Implementation of Non-Local Block with 4 different pairwise functions 62 | args: 63 | in_channels: original channel size (1024 in the paper) 64 | inter_channels: channel size inside the block; if not specified, reduced to in_channels // 4 below (512, i.e. half, in the paper) 65 | mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation 66 | dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal) 67 | bn_layer: whether to add batch norm 68 | """ 69 | super(NLBlockND, self).__init__() 70 | 71 | assert dimension in [1, 2, 3] 72 | 73 | if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']: 74 | raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`') 75 | 76 | self.mode = mode 77 | self.dimension = dimension 78 | self.in_channels = in_channels 79 | self.inter_channels = inter_channels 80 | if levels is not None: 81 | self.ssp=True 82 | self.p = SpatialPyramidPooling(levels=[2*i+1 for i in range(0,levels)]) # assumed to come from the repo's pyramidpooling module (import not shown in this dump) 83 | else: 84 | self.ssp = False 85 | # the channel size is reduced 
inside the block 86 | if self.inter_channels is None: 87 | self.inter_channels = in_channels // 4 88 | if self.inter_channels == 0: 89 | self.inter_channels = 1 90 | 91 | # assign appropriate convolutional, max pool, and batch norm layers for different dimensions 92 | if dimension == 3: 93 | conv_nd = nn.Conv3d 94 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 95 | bn = nn.BatchNorm3d 96 | elif dimension == 2: 97 | conv_nd = nn.Conv2d 98 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 99 | bn = nn.BatchNorm2d 100 | else: 101 | conv_nd = nn.Conv1d 102 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 103 | bn = nn.BatchNorm1d 104 | 105 | # function g in the paper which goes through conv. with kernel size 1 106 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 107 | # add BatchNorm layer after the last conv layer 108 | if bn_layer: 109 | self.W_z = nn.Sequential( 110 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1), 111 | bn(self.in_channels) 112 | ) 113 | nn.init.constant_(self.W_z[1].weight, 0) 114 | nn.init.constant_(self.W_z[1].bias, 0) 115 | else: 116 | self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1) 117 | nn.init.constant_(self.W_z.weight, 0) 118 | nn.init.constant_(self.W_z.bias, 0) 119 | 120 | # define theta and phi for all operations except gaussian 121 | if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate": 122 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 123 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1) 124 | 125 | if self.mode == "concatenate": 126 | self.W_f = nn.Sequential( 127 | nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1), 128 | nn.ReLU() 129 | ) 130 | # print() 131 | def forward(self, x): 132 | """ 133 | args 134 | x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1 135 | """ 136 | 137 | batch_size,c,t,h,w = x.size() 138 | 139 | # (N, C, THW) 140 | g_x = self.g(x).view(batch_size, -1, h,w) 141 | if self.ssp: 142 | g_x = self.p(g_x) 143 | g_x=g_x.view(batch_size, self.inter_channels, -1) 144 | g_x = g_x.permute(0, 2, 1) 145 | # print(self.mode) 146 | if self.mode == "gaussian": 147 | theta_x = x.view(batch_size, self.in_channels, -1) 148 | phi_x = x.view(batch_size, self.in_channels, -1) 149 | theta_x = theta_x.permute(0, 2, 1) 150 | f = torch.matmul(theta_x, phi_x) 151 | 152 | elif self.mode == "embedded" or self.mode == "dot": 153 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 154 | phi_x = self.phi(x).view(batch_size, -1, h,w) 155 | if self.ssp: 156 | phi_x=self.p(phi_x) 157 | phi_x=phi_x.view(batch_size, self.inter_channels, -1) 158 | theta_x = theta_x.permute(0, 2, 1) 159 | f = torch.matmul(theta_x, phi_x) 160 | 161 | elif self.mode == "concatenate": 162 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 163 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 164 | 165 | h = theta_x.size(2) 166 | w = phi_x.size(3) 167 | theta_x = theta_x.repeat(1, 1, 1, w) 168 | phi_x = phi_x.repeat(1, 1, h, 1) 169 | 170 | concat = torch.cat([theta_x, phi_x], dim=1) 171 | f = self.W_f(concat) 172 | f = f.view(f.size(0), f.size(2), f.size(3)) 173 | 174 | if self.mode == "gaussian" or self.mode == "embedded": 175 | f_div_C = F.softmax(f, dim=-1) 176 | elif self.mode == "dot" or 
self.mode == "concatenate": 177 | N = f.size(-1) # number of position in x 178 | f_div_C = f / N 179 | # print(f_div_C.shape) 180 | # print(g_x.shape) 181 | y = torch.matmul(f_div_C, g_x) 182 | 183 | # contiguous here just allocates contiguous chunk of memory 184 | y = y.permute(0, 2, 1).contiguous() 185 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 186 | 187 | W_y = self.W_z(y) 188 | # residual connection 189 | z = W_y + x 190 | 191 | return z 192 | 193 | 194 | if __name__ == '__main__': 195 | import torch 196 | 197 | # for bn_layer in [True, False]: 198 | # img = torch.zeros(2, 3, 20) 199 | # net = NLBlockND(in_channels=3, mode='concatenate', dimension=1, bn_layer=bn_layer) 200 | # out = net(img) 201 | # print(out.size()) 202 | # 203 | # img = torch.zeros(2, 3, 20, 20) 204 | # net = NLBlockND(in_channels=3, mode='concatenate', dimension=2, bn_layer=bn_layer) 205 | # out = net(img) 206 | # print(out.size()) 207 | 208 | img = torch.randn(1, 16, 31, 512, 512) 209 | net = EfficientNL(in_channels=16) 210 | out = net(img) 211 | print(out.size()) 212 | 213 | -------------------------------------------------------------------------------- /model/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /model/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /model/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /model/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/model/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /model/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : 
maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalize the tensor on each device using 141 | the statistics only on that device, which accelerated the computation and 142 | is also easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters. 
Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalize the tensor on each device using 204 | the statistics only on that device, which accelerated the computation and 205 | is also easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters. 
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalize the tensor on each device using 267 | the statistics only on that device, which accelerated the computation and 268 | is also easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters. 
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | -------------------------------------------------------------------------------- /model/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register a slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually the device id. 
84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages are first collected from each device (including the master device), and then 100 | a callback is invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belong to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /model/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /model/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /model_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | from model.MACNet import Params 5 | from model.MACNet import MACNet 6 | def init_model(in_channels,channels,num_half_layer,rs): 7 | params = Params(in_channels=in_channels, channels=channels, 8 | num_half_layer=num_half_layer,rs=rs) 9 | model = MACNet(params) 10 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 11 | print('Nb tensors: ',len(list(model.named_parameters())), "; Trainable Params: ", pytorch_total_params) 12 | return model 13 | def load_model(model_name,model,device_name): 14 | out_dir = os.path.join(model_name) 15 | ckpt_path = os.path.join(out_dir) 16 | if os.path.isfile(ckpt_path): 17 | try: 18 | print('\n existing ckpt detected') 19 | checkpoint = torch.load(ckpt_path) 20 | state_dict = checkpoint['state_dict'] 21 | new_state_dict = OrderedDict() 22 | for k, v in state_dict.items(): 23 | if device_name=="cpu": 24 | name = k[7:] # remove 'module.' of dataparallel 25 | else: 26 | name = k 27 | new_state_dict[name] = v 28 | model.load_state_dict(new_state_dict, strict=True) 29 | except Exception as e: 30 | print(e) 31 | print(f'ckpt loading failed @{ckpt_path}, exit ...') 32 | exit() 33 | 34 | else: 35 | print(f'\nno ckpt found @{ckpt_path}') 36 | exit() 37 | if torch.cuda.is_available(): 38 | torch.backends.cudnn.benchmark = True 39 | -------------------------------------------------------------------------------- /ops/__pycache__/gauss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/ops/__pycache__/gauss.cpython-37.pyc -------------------------------------------------------------------------------- /ops/__pycache__/im2col.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/ops/__pycache__/im2col.cpython-37.pyc -------------------------------------------------------------------------------- /ops/__pycache__/im2col.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import torch 3 | from torch.nn.modules.utils import _pair 4 | import math 5 | 6 | 7 | def Im2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,): 8 | batch = input_tensor.shape[0] 9 | out = F.unfold(input_tensor, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 10 | 11 | if tensorized: 12 | lh,lw = im2col_shape(input_tensor.shape[1:],kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation)[-2:] 13 | out = out.view(batch,-1,lh,lw) 14 | return out 15 | def 
Cube2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,): 16 | input_sz=input_tensor.shape 17 | _t=input_sz[1]-kernel_size+1 18 | out=torch.zeros(input_sz[0],kernel_size**3,input_sz[1]-kernel_size+1,input_sz[2]-kernel_size+1,input_sz[3]-kernel_size+1, device=input_tensor.device) # allocate on the input's device so CPU and GPU inputs both work 19 | for i in range(_t): 20 | ind1=i 21 | ind2=i+kernel_size 22 | out[:,:,i,:,:]=Im2Col(input_tensor[:,ind1:ind2,:,:], kernel_size, stride, padding, dilation, tensorized) 23 | return out 24 | def Col2Cube(input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 25 | batch = input_tensor.shape[0] 26 | _t = output_size[0] - kernel_size + 1 27 | out = torch.zeros([batch,output_size[0],output_size[1],output_size[2]], device=input_tensor.device) 28 | me=torch.zeros_like(out) 29 | # if input_tensor.is_cuda: 30 | # out.to(device="cuda") 31 | # me.to(device="cuda") 32 | for i in range(_t): 33 | ind1 = i 34 | ind2 = i + kernel_size 35 | if input_tensorized: 36 | temp_tensor = input_tensor[:,:,i,:,:].flatten(2,3) 37 | out[:,ind1:ind2,:,:] += F.fold(temp_tensor, output_size=output_size[1:], kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 38 | me[:,ind1:ind2,:,:] += F.fold(torch.ones_like(temp_tensor), output_size=output_size[1:], kernel_size=kernel_size, 39 | padding=padding, stride=stride, dilation=dilation) 40 | 41 | 42 | if avg: 43 | # me[me==0]=1 # !!!!!!! 44 | out = out / me 45 | 46 | # me_ = F.conv_transpose2d(torch.ones_like(input_tensor),torch.ones(1,1,kernel_size,kernel_size)) 47 | 48 | return out 49 | 50 | 51 | def Col2Im(input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 52 | batch = input_tensor.shape[0] 53 | 54 | if input_tensorized: 55 | input_tensor = input_tensor.flatten(2,3) 56 | out = F.fold(input_tensor, output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 57 | 58 | if avg: 59 | me = F.fold(torch.ones_like(input_tensor), output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 60 | # me[me==0]=1 # !!!!!!! 
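# Overlap averaging (what the `avg` branch computes): `me` is the fold of an
# all-ones tensor with the same unfold geometry, i.e. the per-pixel count of
# overlapping patches, so the division below turns the summed reconstruction
# into a mean over all patches covering each pixel.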
61 | out = out / me 62 | 63 | # me_ = F.conv_transpose2d(torch.ones_like(input_tensor),torch.ones(1,1,kernel_size,kernel_size)) 64 | 65 | return out 66 | 67 | 68 | class Col2Im_(torch.nn.Module): 69 | 70 | def __init__(self,input_shape, output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 71 | super(Col2Im_,self).__init__() 72 | 73 | xshape = tuple(input_shape) 74 | 75 | if input_tensorized: 76 | xshape = xshape[0:2]+(xshape[2]*xshape[3],) 77 | 78 | if avg: 79 | me = F.fold(torch.ones(xshape), output_size=output_size, kernel_size=kernel_size, 80 | padding=padding, stride=stride, dilation=dilation) 81 | me[me == 0] = 1 82 | self.me = me 83 | 84 | def forward(self, input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 85 | if input_tensorized: 86 | input_tensor = input_tensor.flatten(2, 3) 87 | out = F.fold(input_tensor, output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride, 88 | dilation=dilation) 89 | if avg: 90 | out /= self.me 91 | return out 92 | 93 | # def im2col_shape(size, kernel_size, stride, padding): 94 | # ksize_h, ksize_w = _pair(kernel_size) 95 | # stride_h, stride_w = _pair(stride) 96 | # pad_h, pad_w = _pair(padding) 97 | # n_input_plane, height, width = size 98 | # height_col = (height + 2 * pad_h - ksize_h) // stride_h + 1 99 | # width_col = (width + 2 * pad_w - ksize_w) // stride_w + 1 100 | # return n_input_plane, ksize_h, ksize_w, height_col, width_col 101 | 102 | def im2col_shape(size, kernel_size, stride, padding, dilation): 103 | ksize_h, ksize_w = _pair(kernel_size) 104 | stride_h, stride_w = _pair(stride) 105 | dil_h, dil_w = _pair(dilation) 106 | pad_h, pad_w = _pair(padding) 107 | n_input_plane, height, width = size 108 | height_col = (height + 2 * pad_h - dil_h * (ksize_h-1)-1) / stride_h + 1 109 | width_col = (width + 2 * pad_w - dil_w * (ksize_w-1)-1) / stride_w + 1 110 | return n_input_plane, ksize_h, ksize_w, math.floor(height_col), math.floor(width_col) 111 | 112 | 113 | def col2im_shape(size, kernel_size, stride, padding, input_size=None): 114 | ksize_h, ksize_w = _pair(kernel_size) 115 | stride_h, stride_w = _pair(stride) 116 | pad_h, pad_w = _pair(padding) 117 | n_input_plane, ksize_h, ksize_w, height_col, width_col = size 118 | if input_size is not None: 119 | height, width = input_size 120 | else: 121 | height = (height_col - 1) * stride_h - 2 * pad_h + ksize_h 122 | width = (width_col - 1) * stride_w - 2 * pad_w + ksize_w 123 | return n_input_plane, height, width -------------------------------------------------------------------------------- /ops/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/ops/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ops/__pycache__/utils_blocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/ops/__pycache__/utils_blocks.cpython-37.pyc -------------------------------------------------------------------------------- /ops/gauss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module providing functionality surrounding gaussian function. 
3 | """ 4 | SVN_REVISION = '$LastChangedRevision: 16541 $' 5 | 6 | import sys 7 | import numpy 8 | 9 | 10 | def gaussian2(size, sigma): 11 | """Returns a normalized circularly symmetric 2D gauss kernel array 12 | 13 | f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where 14 | 15 | A = 1/(2*pi*sigma^2) 16 | 17 | as define by Wolfram Mathworld 18 | http://mathworld.wolfram.com/GaussianFunction.html 19 | """ 20 | A = 1 / (2.0 * numpy.pi * sigma ** 2) 21 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 22 | g = A * numpy.exp(-((x ** 2 / (2.0 * sigma ** 2)) + (y ** 2 / (2.0 * sigma ** 2)))) 23 | return g 24 | 25 | 26 | def fspecial_gauss(size, sigma): 27 | """Function to mimic the 'fspecial' gaussian MATLAB function 28 | """ 29 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 30 | g = numpy.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) 31 | return g / g.sum() 32 | 33 | 34 | def main(): 35 | """Show simple use cases for functionality provided by this module.""" 36 | from mpl_toolkits.mplot3d.axes3d import Axes3D 37 | import pylab 38 | argv = sys.argv 39 | if len(argv) != 3: 40 | print >> sys.stderr, 'usage: python -m pim.sp.gauss size sigma' 41 | sys.exit(2) 42 | size = int(argv[1]) 43 | sigma = float(argv[2]) 44 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 45 | 46 | fig = pylab.figure() 47 | fig.suptitle('Some 2-D Gauss Functions') 48 | ax = fig.add_subplot(2, 1, 1, projection='3d') 49 | ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1, 50 | linewidth=0, antialiased=False, cmap=pylab.jet()) 51 | ax = fig.add_subplot(2, 1, 2, projection='3d') 52 | ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1, 53 | linewidth=0, antialiased=False, cmap=pylab.jet()) 54 | pylab.show() 55 | return 0 56 | 57 | 58 | if __name__ == '__main__': 59 | sys.exit(main()) 60 | # {"mode": "full", "isActive": false} -------------------------------------------------------------------------------- /ops/im2col.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import torch 3 | from torch.nn.modules.utils import _pair 4 | import math 5 | 6 | 7 | def Im2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,): 8 | batch = input_tensor.shape[0] 9 | out = F.unfold(input_tensor, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 10 | 11 | if tensorized: 12 | lh,lw = im2col_shape(input_tensor.shape[1:],kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation)[-2:] 13 | out = out.view(batch,-1,lh,lw) 14 | return out 15 | def Cube2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,): 16 | input_sz=input_tensor.shape 17 | if input_sz[1]acbd", A, B).view(A.size(0)*B.size(0), A.size(1)*B.size(1)) 14 | def gen_bayer_mask(h,w): 15 | x = torch.zeros(1, 3, h, w) 16 | 17 | x[:, 0, 1::2, 1::2] = 1 # r 18 | x[:, 1, ::2, 1::2] = 1 19 | x[:, 1, 1::2, ::2] = 1 # g 20 | x[:, 2, ::2, ::2] = 1 # b 21 | 22 | return x 23 | 24 | def togray(tensor): 25 | b, c, h, w = tensor.shape 26 | tensor = tensor.view(b, 3, -1, h, w) 27 | tensor = tensor.sum(1) 28 | return tensor 29 | 30 | def torch_to_np(img_var): 31 | return img_var.detach().cpu().numpy() 32 | 33 | def plot_tensor(img, **kwargs): 34 | inp_shape = tuple(img.shape) 35 | print(inp_shape) 36 | img_np = torch_to_np(img) 37 | if inp_shape[1]==3: 38 | img_np_ = img_np.transpose([1,2,0]) 39 | 
plt.imshow(img_np_) 40 | 41 | elif inp_shape[1]==1: 42 | img_np_ = np.squeeze(img_np) 43 | plt.imshow(img_np_, **kwargs) 44 | 45 | else: 46 | # raise NotImplementedError 47 | plt.imshow(img_np, **kwargs) 48 | plt.axis('off') 49 | 50 | 51 | def get_mask(A): 52 | mask = A.clone().detach() 53 | mask[A != 0] = 1 54 | return mask.byte() 55 | 56 | def sparsity(A): 57 | return get_mask(A).sum().item()/A.numel() 58 | 59 | def soft_threshold(x, lambd): 60 | return nn.functional.relu(x - lambd,inplace=True) - nn.functional.relu(-x - lambd,inplace=True) 61 | def nn_threshold(x, lambd): 62 | return nn.functional.relu(x - lambd) 63 | 64 | def fastSoftThrs(x, lmbda): 65 | return x + 0.5 * (torch.abs(x-torch.abs(lmbda))-torch.abs(x+torch.abs(lmbda))) 66 | 67 | def save_checkpoint(state,ckpt_path): 68 | torch.save(state, ckpt_path) 69 | 70 | def generate_key(): 71 | return '{}'.format(randint(0, 100000)) 72 | 73 | def show_mem(): 74 | mem = torch.cuda.memory_allocated() * 1e-6 75 | max_mem = torch.cuda.max_memory_allocated() * 1e-6 76 | return mem, max_mem 77 | 78 | def str2bool(v): 79 | if isinstance(v, bool): 80 | return v 81 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 82 | return True 83 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 84 | return False 85 | else: 86 | raise argparse.ArgumentTypeError('Boolean value expected.') 87 | 88 | def step_lr(optimizer, lr_decay): 89 | lr = optimizer.param_groups[0]['lr'] 90 | optimizer.param_groups[0]['lr'] = lr * lr_decay 91 | def set_lr(optimizer, lr): 92 | # lr = optimizer.param_groups[0]['lr'] 93 | optimizer.param_groups[0]['lr'] = lr 94 | 95 | def step_lr_als(optimizer, lr_decay): 96 | lr = optimizer.param_groups[0]['lr'] 97 | optimizer.param_groups[0]['lr'] = lr * lr_decay 98 | optimizer.param_groups[1]['lr'] *= lr_decay 99 | 100 | def get_lr(optimizer): 101 | return optimizer.param_groups[0]['lr'] 102 | 103 | 104 | def gen_mask_windows(h, w): 105 | ''' 106 | return mask for block window 107 | :param h: 108 | :param w: 109 | :return: (h,w,h,w) 110 | ''' 111 | mask = torch.zeros(2 * h, 2 * w, h, w) 112 | for i in range(h): 113 | for j in range(w): 114 | mask[i:i + h, j:j + w, i, j] = 1 115 | 116 | return mask[h // 2:-h // 2, w // 2:-w // 2, :, :] 117 | 118 | 119 | def gen_linear_mask_windows(h, w, h_,w_): 120 | ''' 121 | return mask for block window 122 | :param h: 123 | :param w: 124 | :return: (h,w,h,w) 125 | ''' 126 | 127 | x = torch.ones(1, 1, h - h_ + 1, w - w_ + 1) 128 | k = torch.ones(1, 1, h_, w_) 129 | kernel = F.conv_transpose2d(x, k) 130 | kernel /= kernel.max() 131 | mask = torch.zeros(2 * h, 2 * w, h, w) 132 | for i in range(h): 133 | for j in range(w): 134 | mask[i:i + h, j:j + w, i, j] = kernel 135 | 136 | return mask[h // 2:-h // 2, w // 2:-w // 2, :, :] 137 | 138 | def gen_quadra_mask_windows(h, w, h_,w_): 139 | ''' 140 | return mask for block window 141 | :param h: 142 | :param w: 143 | :return: (h,w,h,w) 144 | ''' 145 | 146 | x = torch.ones(1, 1, h - h_ + 1, w - w_ + 1) 147 | k = torch.ones(1, 1, h_, w_) 148 | kernel = F.conv_transpose2d(x, k) **2 149 | kernel /= kernel.max() 150 | mask = torch.zeros(2 * h, 2 * w, h, w) 151 | for i in range(h): 152 | for j in range(w): 153 | mask[i:i + h, j:j + w, i, j] = kernel 154 | 155 | return mask[h // 2:-h // 2, w // 2:-w // 2, :, :] 156 | 157 | def pil_to_np(img_PIL): 158 | '''Converts image in PIL format to np.array. 
159 | 160 | From W x H x C [0...255] to C x W x H [0..1] 161 | ''' 162 | ar = np.array(img_PIL) 163 | 164 | if len(ar.shape) == 3: 165 | ar = ar.transpose(2, 0, 1) 166 | else: 167 | ar = ar[None, ...] 168 | 169 | return ar.astype(np.float32) / 255. 170 | 171 | 172 | def np_to_pil(img_np): 173 | '''Converts image in np.array format to PIL image. 174 | 175 | From C x W x H [0..1] to W x H x C [0...255] 176 | ''' 177 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 178 | 179 | if img_np.shape[0] == 1: 180 | ar = ar[0] 181 | else: 182 | ar = ar.transpose(1, 2, 0) 183 | 184 | return Image.fromarray(ar) 185 | def Init_DCT(n, m): 186 | """ Compute the Overcomplete Discrete Cosinus Transform. """ 187 | n=int(n) 188 | m=int(m) 189 | Dictionary = np.zeros((n,m)) 190 | for k in range(m): 191 | V = np.cos(np.arange(0, n) * k * np.pi / m) 192 | if k > 0: 193 | V = V - np.mean(V) 194 | Dictionary[:, k] = V / np.linalg.norm(V) 195 | # Dictionary = np.kron(Dictionary, Dictionary) 196 | # Dictionary = Dictionary.dot(np.diag(1 / np.sqrt(np.sum(Dictionary ** 2, axis=0)))) 197 | # idx = np.arange(0, n ** 2) 198 | # idx = idx.reshape(n, n, order="F") 199 | # idx = idx.reshape(n ** 2, order="C") 200 | # Dictionary = Dictionary[idx, :] 201 | Dictionary = torch.from_numpy(Dictionary).float() 202 | return Dictionary 203 | 204 | def est_noise(y, noise_type='additive'): 205 | """ 206 | This function infers the noise in a 207 | hyperspectral data set, by assuming that the 208 | reflectance at a given band is well modelled 209 | by a linear regression on the remaining bands. 210 | 211 | Parameters: 212 | y: `numpy array` 213 | a HSI cube ((m*n) x p) 214 | 215 | noise_type: `string [optional 'additive'|'poisson']` 216 | 217 | Returns: `tuple numpy array, numpy array` 218 | * the noise estimates for every pixel (N x p) 219 | * the noise correlation matrix estimates (p x p) 220 | 221 | Copyright: 222 | Jose Nascimento (zen@isel.pt) and Jose Bioucas-Dias (bioucas@lx.it.pt) 223 | For any comments contact the authors 224 | """ 225 | # def est_additive_noise(r): 226 | # small = 1e-6 227 | # L, N = r.shape 228 | # w=np.zeros((L,N), dtype=np.float) 229 | # RR=np.dot(r,r.T) 230 | # RRi = np.linalg.pinv(RR+small*np.eye(L)) 231 | # RRi = np.matrix(RRi) 232 | # for i in range(L): 233 | # XX = RRi - (RRi[:,i]*RRi[i,:]) / RRi[i,i] 234 | # RRa = RR[:,i] 235 | # RRa[i] = 0 236 | # beta = np.dot(XX, RRa) 237 | # beta[0,i]=0; 238 | # w[i,:] = r[i,:] - np.dot(beta,r) 239 | # Rw = np.diag(np.diag(np.dot(w,w.T) / N)) 240 | # return w, Rw 241 | def est_additive_noise(r): 242 | small = 1e-6 243 | L, N = r.shape 244 | w=torch.zeros((L,N), dtype=torch.float,device=r.device) 245 | RR=r@r.T 246 | # print((small*torch.eye(L,device=r.device)).device) 247 | temp=RR+small*torch.eye(L,device=r.device) 248 | # print(temp.device) 249 | RRi = torch.inverse(temp) 250 | 251 | # RRi = np.matrix(RRi) 252 | for i in range(L): 253 | XX = RRi - (RRi[:,i].unsqueeze(1)*RRi[i,:].unsqueeze(0)) / RRi[i,i] 254 | RRa = RR[:,i] 255 | RRa[i] = 0 256 | beta =XX@RRa 257 | beta[i]=0; 258 | w[i,:] = r[i,:] - beta@r 259 | Rw = torch.diag(torch.diag((w@w.T) / N)) 260 | return w, Rw 261 | 262 | h, w, numBands = y.shape 263 | y = torch.reshape(y, (w * h, numBands)) 264 | # y = np.reshape(y, (w * h, numBands)) 265 | y = y.T 266 | L, N = y.shape 267 | # verb = 'poisson' 268 | if noise_type == 'poisson': 269 | sqy = torch.sqrt(y * (y > 0)) 270 | u, Ru = est_additive_noise(sqy) 271 | x = (sqy - u) ** 2 272 | w = torch.sqrt(x) * u * 2 273 | Rw = (w@w.T) / N 274 | # additive 
275 | else: 276 | w, Rw = est_additive_noise(y) 277 | return w.T, Rw.T 278 | 279 | # y = y.T 280 | # L, N = y.shape 281 | # #verb = 'poisson' 282 | # if noise_type == 'poisson': 283 | # sqy = np.sqrt(y * (y > 0)) 284 | # u, Ru = est_additive_noise(sqy) 285 | # x = (sqy - u)**2 286 | # w = np.sqrt(x)*u*2 287 | # Rw = np.dot(w,w.T) / N 288 | # # additive 289 | # else: 290 | # w, Rw = est_additive_noise(y) 291 | # return w.T, Rw.T 292 | 293 | 294 | def hysime(y, n, Rn): 295 | """ 296 | Hyperspectral signal subspace estimation 297 | 298 | Parameters: 299 | y: `numpy array` 300 | hyperspectral data set (each row is a pixel) 301 | with ((m*n) x p), where p is the number of bands 302 | and (m*n) the number of pixels. 303 | 304 | n: `numpy array` 305 | ((m*n) x p) matrix with the noise in each pixel. 306 | 307 | Rn: `numpy array` 308 | noise correlation matrix (p x p) 309 | 310 | Returns: `tuple integer, numpy array` 311 | * kf signal subspace dimension 312 | * Ek matrix which columns are the eigenvectors that span 313 | the signal subspace. 314 | 315 | Copyright: 316 | Jose Nascimento (zen@isel.pt) & Jose Bioucas-Dias (bioucas@lx.it.pt) 317 | For any comments contact the authors 318 | """ 319 | h, w, numBands = y.shape 320 | y = torch.reshape(y, (w * h, numBands)) 321 | y=y.T 322 | n=n.T 323 | Rn=Rn.T 324 | L, N = y.shape 325 | Ln, Nn = n.shape 326 | d1, d2 = Rn.shape 327 | 328 | x = y - n; 329 | 330 | Ry = y@y.T / N 331 | Rx = x@x.T/ N 332 | E, dx, V =torch.svd(Rx.cpu()) 333 | E=E.to(device=y.device) 334 | # print(V) 335 | Rn = Rn+torch.sum(torch.diag(Rx))/L/10**5 * torch.eye(L,device=y.device) 336 | Py = torch.diag(E.T@(Ry@E)) 337 | Pn = torch.diag(E.T@(Rn@E)) 338 | cost_F = -Py + 2 * Pn 339 | kf = torch.sum(cost_F < 0) 340 | ind_asc = torch.argsort(cost_F) 341 | Ek = E[:, ind_asc[0:kf]] 342 | # h, w, numBands = y.shape 343 | # y = np.reshape(y, (w * h, numBands)) 344 | # y = y.T 345 | # n = n.T 346 | # Rn = Rn.T 347 | # L, N = y.shape 348 | # Ln, Nn = n.shape 349 | # d1, d2 = Rn.shape 350 | # 351 | # x = y - n; 352 | # 353 | # Ry = np.dot(y, y.T) / N 354 | # Rx = np.dot(x, x.T) / N 355 | # E, dx, V = np.linalg.svd(Rx) 356 | # 357 | # Rn = Rn + np.sum(np.diag(Rx)) / L / 10 ** 5 * np.eye(L) 358 | # Py = np.diag(np.dot(E.T, np.dot(Ry, E))) 359 | # Pn = np.diag(np.dot(E.T, np.dot(Rn, E))) 360 | # cost_F = -Py + 2 * Pn 361 | # kf = np.sum(cost_F < 0) 362 | # ind_asc = np.argsort(cost_F) 363 | # Ek = E[:, ind_asc[0:kf]] 364 | return kf, E # Ek.T ? 
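# `count` (below) chains the two estimators defined above: est_noise infers the
# per-pixel noise w and its correlation matrix Rw, and hysime uses them to return
# the signal-subspace dimension kf and an eigenvector basis. A minimal usage
# sketch, assuming an (H, W, B) float tensor `hsi` as in the __main__ demo at the
# end of this file:
#
#   w, Rw = est_noise(hsi)       # per-pixel noise (H*W, B) and correlation (B, B)
#   kf, E = hysime(hsi, w, Rw)   # subspace dimension and eigenvector basis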
365 | def count(M):
366 |     w, Rw = est_noise(M)
367 |     kf, Ek = hysime(M, w, Rw)
368 |     return kf, Ek, w, Rw
369 | 
370 | def cal_sam(X, Y, eps=1e-8):
371 |     # X = torch.squeeze(X.data).cpu().numpy()
372 |     # Y = torch.squeeze(Y.data).cpu().numpy()
373 |     tmp = (np.sum(X*Y, axis=0) + eps) / ((np.sqrt(np.sum(X**2, axis=0)) + eps) * (np.sqrt(np.sum(Y**2, axis=0)) + eps) + eps)
374 |     return np.mean(np.real(np.arccos(tmp)))
375 | def cal_psnr(im_true, im_test, eps=1e-8):
376 |     c, _, _ = im_true.shape
377 |     bwindex = []
378 |     for i in range(c):
379 |         bwindex.append(compare_psnr(im_true[i, :, :], im_test[i, :, :]))
380 |     return np.mean(bwindex)
381 | def ssim(img1, img2, cs_map=False):
382 |     """Return the Structural Similarity Map corresponding to input images img1
383 |     and img2 (images are assumed to be uint8)
384 | 
385 |     This function attempts to mimic precisely the functionality of ssim.m,
386 |     the MATLAB implementation provided by the authors of SSIM:
387 |     https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m
388 |     """
389 |     img1 = img1.astype(np.float64)
390 |     img2 = img2.astype(np.float64)
391 |     size = 11
392 |     sigma = 1.5
393 |     window = fspecial_gauss(size, sigma)
394 |     K1 = 0.01
395 |     K2 = 0.03
396 |     L = 255  # bitdepth of image
397 |     C1 = (K1 * L) ** 2
398 |     C2 = (K2 * L) ** 2
399 |     mu1 = signal.fftconvolve(window, img1, mode='valid')
400 |     mu2 = signal.fftconvolve(window, img2, mode='valid')
401 |     mu1_sq = mu1 * mu1
402 |     mu2_sq = mu2 * mu2
403 |     mu1_mu2 = mu1 * mu2
404 |     sigma1_sq = signal.fftconvolve(window, img1 * img1, mode='valid') - mu1_sq
405 |     sigma2_sq = signal.fftconvolve(window, img2 * img2, mode='valid') - mu2_sq
406 |     sigma12 = signal.fftconvolve(window, img1 * img2, mode='valid') - mu1_mu2
407 |     if cs_map:
408 |         return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
409 |                                                              (sigma1_sq + sigma2_sq + C2)),
410 |                 (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
411 |     else:
412 |         return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
413 |                                                             (sigma1_sq + sigma2_sq + C2))
414 | 
415 | def cal_ssim(im_true, im_test, eps=1e-8):
416 |     # print(im_true.shape)
417 |     # print(im_true.shape)
418 |     # print(im_test.shape)
419 |     # im_true=im_true.cpu().numpy()
420 |     # im_test = im_test.cpu().numpy()
421 |     c, _, _ = im_true.shape
422 |     bwindex = []
423 |     for i in range(c):
424 |         bwindex.append(ssim(im_true[i, :, :]*255, im_test[i, :, :]*255))
425 |     return np.mean(bwindex)
426 | # def cal_ssim(im_true,im_test,eps=13-8):
427 | #     c,_,_=im_true.shape
428 | #     bwindex = []
429 | #     for i in range(c):
430 | #         bwindex.append(compare_ssim(im_true[i,:,:], im_test[i,:,:,]))
431 | #     return np.mean(bwindex)
432 | 
433 | # class Bandwise(object):
434 | #     def __init__(self, index_fn):
435 | #         self.index_fn = index_fn
436 | #
437 | #     def __call__(self, X, Y):
438 | #         C = X.shape[-3]
439 | #         bwindex = []
440 | #         for ch in range(C):
441 | #             x = torch.squeeze(X[...,ch,:,:].data).cpu().numpy()
442 | #             y = torch.squeeze(Y[...,ch,:,:].data).cpu().numpy()
443 | #             index = self.index_fn(x, y)
444 | #             bwindex.append(index)
445 | #         return bwindex
446 | 
447 | 
448 | def MSIQA(X, Y):
449 |     # print(X.shape)
450 |     # print(Y.shape)
451 |     psnr = cal_psnr(X, Y)
452 |     ssim_val = cal_ssim(X, Y)  # renamed locally to avoid shadowing the ssim() function above
453 |     sam = cal_sam(X, Y)
454 |     return psnr, ssim_val, sam
455 | if __name__ == '__main__':
456 |     hsi = torch.rand(200, 200, 198)
457 |     w, Rw = est_noise(hsi)
458 |     kf, E = hysime(hsi, w, Rw)
459 |     print(kf)
460 | 
461 | 
462 | 
463 | 
464 | 
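# Example for Init_DCT defined earlier in this file (a sketch, not executed
# on import; the sizes are illustrative):
#
#   D = Init_DCT(16, 64)     # overcomplete DCT, torch.FloatTensor of shape (16, 64)
#   D.norm(dim=0)            # every atom is unit-norm by construction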
--------------------------------------------------------------------------------
/ops/utils_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from ops.im2col import Im2Col, Col2Im, Col2Cube, Cube2Col
4 | 
5 | 
6 | def shape_pad_even(tensor_shape, patch, stride):
7 |     assert len(tensor_shape) == 4
8 |     b, c, h, w = tensor_shape
9 |     required_pad_h = stride - (h - patch) % stride
10 |     required_pad_w = stride - (w - patch) % stride
11 |     return required_pad_h, required_pad_w
12 | 
13 | 
14 | class block_module():
15 | 
16 |     def __init__(self, block_size, block_stride, kernel_size, params):
17 |         super().__init__()
18 |         self.params = params
19 |         self.kernel_size = kernel_size
20 |         self.block_size = block_size
21 |         self.block_stride = block_stride
22 |         # self.channel_size = channel_size
23 | 
24 |     def _make_blocks(self, image, return_padded=False):
25 |         '''
26 |         :param image: (1,C,H,W)
27 |         :return: raw blocks (batch, C, block_size, block_size); the augmented image shape is stored in self.augmented_shape
28 |         '''
29 |         params = self.params
30 | 
31 |         self.channel_size = image.shape[1]
32 | 
33 |         if params['pad_block']:
34 |             pad = (self.block_size - 1,) * 4
35 |         elif params['pad_patch']:
36 |             pad = (self.kernel_size,) * 4
37 |         elif params['no_pad']:
38 |             pad = (0,) * 4
39 |         elif params['custom_pad'] is not None:
40 |             pad = (params['custom_pad'],) * 4
41 | 
42 |         else:
43 |             raise NotImplementedError
44 | 
45 |         image_mirror_padded = F.pad(image, pad, mode='reflect')
46 |         pad_even = shape_pad_even(image_mirror_padded.shape, self.block_size, self.block_stride)
47 |         pad_h, pad_w = pad_even
48 |         if params['centered_pad']:
49 |             pad_ = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
50 |         else:
51 |             pad_ = (0, pad_w, 0, pad_h)
52 |         pad = tuple([x + y for x, y in zip(pad, pad_)])
53 |         self.pad = pad
54 | 
55 |         image_mirror_padded_even = F.pad(image, pad, mode='reflect')  # add half kernel because block edges are poorly estimated
56 | 
57 |         self.augmented_shape = image_mirror_padded_even.shape
58 | 
59 |         if return_padded:
60 |             return image_mirror_padded
61 | 
62 |         batch_blocks = Im2Col(image_mirror_padded_even,
63 |                               kernel_size=self.block_size,
64 |                               stride=self.block_stride,
65 |                               padding=0)
66 | 
67 |         batch_blocks = batch_blocks.permute(2, 0, 1)
68 |         batch_blocks = batch_blocks.view(-1, self.channel_size, self.block_size, self.block_size)
69 |         return batch_blocks
70 |     def _make_cubes(self, image, return_padded=False):  # NOTE: currently identical to _make_blocks
71 |         '''
72 |         :param image: (1,C,H,W)
73 |         :return: raw blocks (batch_spa, batch_spec, block_size, block_size, block_size); the augmented image shape is stored in self.augmented_shape
74 |         '''
75 |         params = self.params
76 | 
77 |         self.channel_size = image.shape[1]
78 | 
79 |         if params['pad_block']:
80 |             pad = (self.block_size - 1,) * 4
81 |         elif params['pad_patch']:
82 |             pad = (self.kernel_size,) * 4
83 |         elif params['no_pad']:
84 |             pad = (0,) * 4
85 |         elif params['custom_pad'] is not None:
86 |             pad = (params['custom_pad'],) * 4
87 | 
88 |         else:
89 |             raise NotImplementedError
90 | 
91 |         image_mirror_padded = F.pad(image, pad, mode='reflect')
92 |         pad_even = shape_pad_even(image_mirror_padded.shape, self.block_size, self.block_stride)
93 |         pad_h, pad_w = pad_even
94 |         if params['centered_pad']:
95 |             pad_ = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
96 |         else:
97 |             pad_ = (0, pad_w, 0, pad_h)
98 |         pad = tuple([x + y for x, y in zip(pad, pad_)])
99 |         self.pad = pad
100 | 
101 |         image_mirror_padded_even = F.pad(image, pad, mode='reflect')  # add half kernel because block edges are poorly estimated
102 | 
103 |         self.augmented_shape = image_mirror_padded_even.shape
104 | 
105 |         if return_padded:
106 |             return image_mirror_padded
107 | 
108 |         batch_blocks = Im2Col(image_mirror_padded_even,
109 |                               kernel_size=self.block_size,
110 |                               stride=self.block_stride,
111 |                               padding=0)
112 | 
113 |         batch_blocks = batch_blocks.permute(2, 0, 1)
114 |         batch_blocks = batch_blocks.view(-1, self.channel_size, self.block_size, self.block_size)
115 |         return batch_blocks
116 | 
117 |     def _agregate_blocks(self, batch_out_blocks):
118 |         '''
119 |         :param batch_out_blocks: processed blocks
120 |         :return: image of averaged estimates
121 |         '''
122 |         h_pad, w_pad = self.augmented_shape[2:]
123 |         params = self.params
124 |         l = self.kernel_size // 2
125 |         device = batch_out_blocks.device
126 | 
127 |         # batch_out_blocks_flatten = batch_out_blocks.flatten(2, 3).permute(1, 2, 0)
128 |         batch_out_blocks_flatten = batch_out_blocks.view(-1, self.channel_size * self.block_size**2).transpose(0, 1).unsqueeze(0)
129 |         # print(self.block_size)
130 |         # print(self.kernel_size)
131 |         if params['ponderate_out_blocks']:
132 |             if self.kernel_size % 2 == 0:
133 |                 mask = F.conv_transpose2d(torch.ones((1, 1) + (self.block_size - 2 * l,) * 2),
134 |                                           torch.ones((1, 1) + (self.kernel_size + 1,) * 2))
135 |             else:
136 |                 mask = F.conv_transpose2d(torch.ones((1, 1) + (self.block_size - 2 * l,) * 2),
137 |                                           torch.ones((1, 1) + (self.kernel_size,) * 2))
138 |             mask = mask.to(device=device)
139 |             # print(batch_out_blocks.shape)
140 |             # print(mask.shape)
141 |             batch_out_blocks *= mask  # in-place, so batch_out_blocks_flatten (a view) sees the weighting too
142 | 
143 |             # batch_out_blocks_flatten = batch_out_blocks.flatten(2, 3).permute(1, 2, 0)
144 | 
145 |             output_padded = Col2Im(batch_out_blocks_flatten,
146 |                                    output_size=(h_pad, w_pad),
147 |                                    kernel_size=self.block_size,
148 |                                    stride=self.block_stride,
149 |                                    padding=0,
150 |                                    avg=False)
151 | 
152 |             batch_out_blocks_ones = torch.ones_like(batch_out_blocks) * mask
153 |             # batch_out_blocks_flatten_ones = batch_out_blocks_ones.flatten(2, 3).permute(1, 2, 0)
154 |             batch_out_blocks_flatten_ones = batch_out_blocks_ones.view(-1, self.channel_size * self.block_size ** 2).transpose(0, 1).unsqueeze(0)
155 | 
156 |             if params['avg']:
157 |                 mask_ = Col2Im(batch_out_blocks_flatten_ones,
158 |                                output_size=(h_pad, w_pad),
159 |                                kernel_size=self.block_size,
160 |                                stride=self.block_stride,
161 |                                padding=0,
162 |                                avg=False)
163 |                 output_padded /= mask_
164 | 
165 |         elif params['crop_out_blocks']:
166 |             kernel_ = self.block_size - 2 * l
167 |             # batch_out_blocks_flatten = batch_out_blocks.flatten(2, 3).permute(1, 2, 0)
168 |             output_padded = Col2Im(batch_out_blocks_flatten,
169 |                                    output_size=(h_pad - 2 * l, w_pad - 2 * l),
170 |                                    kernel_size=kernel_,
171 |                                    stride=self.block_size,
172 |                                    padding=0,
173 |                                    avg=params['avg'])
174 | 
175 |         elif params['sum_blocks']:
176 |             # batch_out_blocks_flatten = batch_out_blocks.flatten(2, 3).permute(1, 2, 0)
177 |             output_padded = Col2Im(batch_out_blocks_flatten,
178 |                                    output_size=(h_pad, w_pad),
179 |                                    kernel_size=self.block_size,
180 |                                    stride=self.block_stride,
181 |                                    padding=0,
182 |                                    avg=params['avg'])
183 |         else:
184 |             raise NotImplementedError
185 | 
186 |         pad = self.pad
187 |         output = output_padded[:, :, pad[2]:-pad[3], pad[0]:-pad[1]]
188 | 
189 |         return output
190 | 
191 | 
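# Usage sketch for block_module (the parameter values are illustrative; they
# mirror the test-time settings in test.py further down in this dump):
#
#   params = {'crop_out_blocks': 0, 'ponderate_out_blocks': 1, 'sum_blocks': 0,
#             'pad_even': 1, 'centered_pad': 0, 'pad_block': 1, 'pad_patch': 0,
#             'no_pad': False, 'custom_pad': None, 'avg': 1}
#   block = block_module(56, 28, 12, params)     # block_size, block_stride, kernel_size
#   blocks = block._make_blocks(image)           # (1, C, H, W) -> (N, C, 56, 56)
#   ...process each block with the model...
#   restored = block._agregate_blocks(blocks)    # weighted overlap-average back to (1, C, H, W)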
--------------------------------------------------------------------------------
/ops/utils_plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from PIL import Image
4 | from torchvision.utils import make_grid
5 | from ops.im2col import *
6 | from ops.utils import get_mask
7 | import torch, math  # explicit imports; otherwise these names must leak through the wildcard import above
8 | def plot_tensor(img, **kwargs):
9 |     inp_shape = tuple(img.shape)
10 |     # print(inp_shape)
11 |     img_np = torch_to_np(img)
12 |     if inp_shape[1] == 3:
13 |         img_np_ = img_np.transpose([1, 2, 0])
14 |         plt.imshow(img_np_)
15 | 
16 |     elif inp_shape[1] == 1:
17 |         img_np_ = np.squeeze(img_np)
18 |         plt.imshow(img_np_, **kwargs)
19 | 
20 |     else:
21 |         # raise NotImplementedError
22 |         plt.imshow(img_np, **kwargs)
23 |     plt.axis('off')
24 | 
25 | 
26 | def hist_tensor(img, **kwargs):
27 |     inp_shape = tuple(img.shape)
28 |     # print(inp_shape)
29 |     img_np = torch_to_np(img)
30 |     return plt.hist(img_np.flatten(), **kwargs)
31 | 
32 | def np_to_torch(img_np):
33 |     '''Converts image in numpy.array to torch.Tensor.
34 | 
35 |     From C x W x H [0..1] to C x W x H [0..1]
36 |     '''
37 |     return torch.from_numpy(img_np)[None, :]
38 | 
39 | 
40 | def torch_to_np(img_var):
41 |     '''Converts an image in torch.Tensor format to np.array.
42 | 
43 |     From 1 x C x W x H [0..1] to C x W x H [0..1]
44 |     '''
45 |     return img_var.detach().cpu().numpy()[0]
46 | 
47 | def pil_to_np(img_PIL):
48 |     '''Converts image in PIL format to np.array.
49 | 
50 |     From W x H x C [0...255] to C x W x H [0..1]
51 |     '''
52 |     ar = np.array(img_PIL)
53 | 
54 |     if len(ar.shape) == 3:
55 |         ar = ar.transpose(2, 0, 1)
56 |     else:
57 |         ar = ar[None, ...]
58 | 
59 |     return ar.astype(np.float32) / 255.
60 | 
61 | 
62 | def np_to_pil(img_np):
63 |     '''Converts image in np.array format to PIL image.
64 | 
65 |     From C x W x H [0..1] to W x H x C [0...255]
66 |     '''
67 |     ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
68 | 
69 |     if img_np.shape[0] == 1:
70 |         ar = ar[0]
71 |     else:
72 |         ar = ar.transpose(1, 2, 0)
73 | 
74 |     return Image.fromarray(ar)
75 | 
76 | 
77 | 
78 | def show_dict(m, a=None, norm_grid=False, sort_freq=True, norm=True):
79 |     n_elem, _, s = m.shape
80 |     s_ = int(math.sqrt(s))
81 |     m = m.view(n_elem, 1, s_, s_)
82 |     if norm:
83 |         m = normalize_patches(m)
84 |     if sort_freq:
85 |         if a is None:
86 |             raise ValueError("provide code array to sort dicts by usage frequency")
87 |         idx = sort_patches(a)
88 |         m = m[idx]
89 | 
90 |     grid = make_grid(m, normalize=norm_grid, padding=2, nrow=int(math.sqrt(n_elem)))
91 |     return grid
92 | 
93 | def whiten_col(tx, eps=1e-4):
94 |     shape = tx.shape
95 |     tx = tx.squeeze()
96 |     D = torch.mm(tx, tx.t()) / len(tx)
97 |     diag, v = torch.symeig(D, eigenvectors=True)  # deprecated in recent PyTorch; torch.linalg.eigh(D) is the modern equivalent
98 |     diag[diag < eps] = 1
99 |     diag = diag ** 0.5
100 |     diag = 1 / diag
101 |     S = torch.diag(diag)
102 |     out = v @ S @ v.t() @ tx
103 |     out = out.view(shape)
104 |     return out
105 | 
106 | def normalize_patches(D):
107 |     p = 3.5
108 |     M = D.max()
109 |     m = D.min()
110 |     if m >= 0:
111 |         me = 0
112 |     else:
113 |         me = D.mean()
114 |     sig = torch.sqrt(((D - me)**2).mean())
115 |     D = torch.min(torch.max(D, -p*sig), p*sig)
116 |     M = D.max()
117 |     m = D.min()
118 |     D = (D - m) / (M - m)
119 |     return D
120 | 
121 | def sort_patches(a):
122 |     code = get_mask(a).float()
123 |     code_freq = code.mean([0, 2, 3]).flatten()
124 |     _, idx = code_freq.sort(descending=True)
125 |     return idx
126 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | 
5 | import numpy as np
6 | import scipy.io as scio
7 | import torch
8 | from tqdm import tqdm
9 | 
10 | import dataloaders_hsi_test
11 | from ops.utils import MSIQA
12 | from ops.utils import str2bool
13 | from ops.utils_blocks import block_module
14 | 
15 | parser = argparse.ArgumentParser()
16 | from skimage.restoration import estimate_sigma
17 | from model_loader import init_model, load_model
18 | 
19 | # model
20 | parser.add_argument("--noise_level", type=int, dest="noise_level", help="Should be an int in the range [0,255]",
21 |                     default=0)
22 | parser.add_argument("--bandwise", type=str2bool, default=1, help='bandwise noise')
23 | parser.add_argument("--num_half_layer", type=int, dest="num_half_layer", help="Number of LISTA steps unfolded",
24 |                     default=6)
25 | parser.add_argument("--channels", type=int, dest="channels", help="Number of feature channels", default=16)
26 | parser.add_argument("--nl", type=str2bool, dest="nl", help="Whether to use the nonlocal module", default=1)
27 | parser.add_argument("--patch_size", type=int, dest="patch_size", help="Size of image blocks to process", default=56)
28 | parser.add_argument("--rescaling_init_val", type=float, default=1.0)
29 | parser.add_argument("--nu_init", type=float, default=1, help='convex combination of correlation map init value')
30 | parser.add_argument("--corr_update", type=int, default=3,
31 |                     help='choose update method in [2,3] without or with patch averaging')
32 | parser.add_argument("--multi_theta", type=str2bool, default=1,
33 |                     help='whether to use a sequence of lambda [1] or a single vector during lista [0]')
34 | parser.add_argument("--diag_rescale_gamma", type=str2bool, default=0, help='diag rescaling code correlation map')
35 | parser.add_argument("--diag_rescale_patch", type=str2bool, default=1, help='diag rescaling patch correlation map')
36 | parser.add_argument("--freq_corr_update", type=int, default=6, help='freq update correlation_map')
37 | parser.add_argument("--mask_windows", type=int, default=1, help='binary or quadratic mask [1,2]')
38 | parser.add_argument("--center_windows", type=str2bool, default=1,
39 |                     help='compute correlation with neighbors only within a block')
40 | parser.add_argument("--multi_std", type=str2bool, default=0)
41 | parser.add_argument("--gpus", '--list', action='append', type=int, help='GPU')
42 | parser.add_argument("--rs_real", type=str2bool, default=0)
43 | parser.add_argument("--blind", type=str2bool, default=0)
44 | 
45 | # training
46 | parser.add_argument("--lr", type=float, dest="lr", help="ADAM Learning rate", default=1e-4)
47 | parser.add_argument("--lr_step", type=int, dest="lr_step", help="ADAM Learning rate step for decay", default=80)
48 | parser.add_argument("--lr_decay", type=float, dest="lr_decay", help="ADAM Learning rate decay (on step)", default=0.35)
49 | parser.add_argument("--backtrack_decay", type=float, help='decay when backtracking', default=0.8)
50 | parser.add_argument("--eps", type=float, dest="eps", help="ADAM epsilon parameter", default=1e-3)
51 | parser.add_argument("--validation_every", type=int, default=300,
52 |                     help='validation frequency on training set (if using backtracking)')
53 | parser.add_argument("--backtrack", type=str2bool, default=1, help='use backtrack to prevent model divergence')
54 | parser.add_argument("--num_epochs", type=int, dest="num_epochs", help="Total number of epochs to train", default=300)
55 | parser.add_argument("--train_batch", type=int, default=2, help='batch size during training')
56 | parser.add_argument("--test_batch", type=int, default=3, help='batch size during eval')
57 | parser.add_argument("--aug_scale", type=int, default=0)
58 | 
59 | # data
60 | parser.add_argument("--out_dir", type=str, dest="out_dir", help="Results' dir path", default='./trained_model')
61 | parser.add_argument("--model_name", type=str, dest="model_name", help="The name of the model to be saved.",
62 |                     default='trained_model_25_bandwise/MTMF_patch_56Layer_12lr_0.00100000/ckpt')
63 | parser.add_argument("--test_path", type=str, help="Path to the dir containing the testing datasets.", default="data/")
64 | parser.add_argument("--gt_path", type=str, help="Path to the dir containing the ground truth datasets.", default="gt/")
65 | parser.add_argument("--resume", type=str2bool, dest="resume", help='Resume training of the model', default=True)
66 | parser.add_argument("--dummy", type=str2bool, dest="dummy", default=False)
67 | parser.add_argument("--tqdm", type=str2bool, default=False)
68 | parser.add_argument('--log_dir', type=str, default='log', help='log directory')
69 | 
70 | # inference
71 | parser.add_argument("--kernel_size", type=int, default=12,
72 |                     help='kernel size of overlapping image blocks [4,8,16,24,48] kernel_//stride')
73 | 
74 | # parser.add_argument("--stride_test", type=int, default=12, help='stride of overlapping image blocks [4,8,16,24,48] kernel_//stride')
75 | parser.add_argument("--stride_val", type=int, default=40,
76 |                     help='stride of overlapping image blocks for validation [4,8,16,24,48] kernel_//stride')
77 | parser.add_argument("--test_every", type=int, default=300, help='report performance on test set every X epochs')
78 | parser.add_argument("--block_inference", type=str2bool, default=False,
79 |                     help='if true, process blocks of a large image in parallel')
80 | parser.add_argument("--pad_image", type=str2bool, default=0, help='padding strategy for inference')
81 | parser.add_argument("--pad_block", type=str2bool, default=1, help='padding strategy for inference')
82 | parser.add_argument("--pad_patch", type=str2bool, default=0, help='padding strategy for inference')
83 | parser.add_argument("--no_pad", type=str2bool, default=False, help='padding strategy for inference')
84 | parser.add_argument("--custom_pad", type=int, default=None, help='padding strategy for inference')
85 | parser.add_argument("--save", type=str2bool, default=0, help='save denoised outputs to .mat files')
86 | 
87 | # variance reduction
88 | # var reg
89 | parser.add_argument("--nu_var", type=float, default=0.01)
90 | parser.add_argument("--freq_var", type=int, default=3)
91 | parser.add_argument("--var_reg", type=str2bool, default=False)
92 | 
93 | parser.add_argument("--verbose", type=str2bool, default=1)
94 | 
95 | args = parser.parse_args()
96 | # os.environ['CUDA_VISIBLE_DEVICES']= '6,7'
97 | if args.gpus is not None and len(args.gpus):
98 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99 |     device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'
100 |     capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else os.cpu_count()
101 |     gpus = args.gpus
102 |     if torch.cuda.is_available():
103 |         torch.backends.cudnn.benchmark = True
104 |     if device.type == 'cuda':
105 |         torch.cuda.set_device('cuda:{}'.format(gpus[0]))
106 | else:
107 |     device = torch.device("cpu")
108 |     device_name = 'cpu'
109 |     capability = os.cpu_count()
110 | 
111 | test_path = [args.test_path]
112 | gt_path = args.gt_path
113 | print(f'test data : {test_path}')
114 | print(f'gt data : {gt_path}')
115 | train_path = val_path = []
116 | 
117 | noise_std = args.noise_level / 255
118 | 
119 | loaders = dataloaders_hsi_test.get_dataloaders(test_path, crop_size=args.patch_size,
120 |                                                batch_size=args.train_batch, downscale=args.aug_scale, concat=1,
121 |                                                verbose=True, grey=False)
122 | model = init_model(in_channels=1, channels=args.channels,
123 |                    num_half_layer=args.num_half_layer,
124 |                    rs=args.rs_real)
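# Noise-adaptive checkpoint selection (summary of the loop below): the given
# or estimated noise level picks one of three pretrained models,
#   sigma <= 15       -> trained_model/ckpt_15
#   15 < sigma <= 55  -> trained_model/ckpt_55
#   sigma > 55        -> trained_model/ckpt_95
# With --blind, sigma is estimated per image via skimage's estimate_sigma
# (scaled by 255); with --rs_real, ckpt_95 is forced and block inference runs
# with a 128-pixel patch.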
125 | if device.type == 'cuda':
126 |     model = torch.nn.DataParallel(model.to(device=device), device_ids=gpus)
127 | args.model_15 = 'trained_model/ckpt_15'
128 | args.model_55 = 'trained_model/ckpt_55'
129 | args.model_95 = 'trained_model/ckpt_95'
130 | 
131 | tic = time.time()
132 | phase = 'test'
133 | 
134 | num_iters = 0
135 | psnr_tot = []
136 | ssim_tot = []
137 | sam_tot = []
138 | stride_test = args.patch_size // 2
139 | loader = loaders['test']
140 | for batch, fname in tqdm(loader, disable=not args.tqdm):
141 |     batch = batch.to(device=device)
142 |     fname = fname[0]
143 |     print(fname)
144 |     noisy_batch = batch
145 |     if args.blind:
146 |         sigma_est = np.array(estimate_sigma(noisy_batch.squeeze(0).permute([1, 2, 0]).detach().cpu(), multichannel=True,
147 |                                             average_sigmas=False)).max() * 255  # multichannel= is deprecated in newer scikit-image (channel_axis=-1 is the replacement)
148 |     else:
149 |         sigma_est = args.noise_level
150 |     if sigma_est <= 15:
151 |         load_model(model_name=args.model_15, model=model, device_name=device_name)
152 |     elif sigma_est <= 55:
153 |         load_model(model_name=args.model_55, model=model, device_name=device_name)
154 |     else:
155 |         load_model(model_name=args.model_95, model=model, device_name=device_name)
156 |     if args.rs_real:
157 |         load_model(model_name=args.model_95, model=model, device_name=device_name)
158 |         args.block_inference = 1
159 |         args.patch_size = 128
160 |         stride_test = args.patch_size // 2
161 |     model.eval()  # Set model to evaluate mode
162 |     if args.block_inference:
163 |         if args.patch_size > noisy_batch.shape[-1] // 2 or args.patch_size > noisy_batch.shape[-2] // 2:
164 |             stride_test = min(args.patch_size // 8, 8)
165 |     with torch.set_grad_enabled(False):
166 |         if args.block_inference:
167 |             params = {
168 |                 'crop_out_blocks': 0,
169 |                 'ponderate_out_blocks': 1,
170 |                 'sum_blocks': 0,
171 |                 'pad_even': 1,  # otherwise pad with 0 for last block
172 |                 'centered_pad': 0,  # corner pixels have only one estimate
173 |                 'pad_block': args.pad_block,  # pad so each pixel has S**2 estimates
174 |                 'pad_patch': args.pad_patch,  # pad so each pixel from the image has at least S**2 estimates from 1 block
175 |                 'no_pad': args.no_pad,
176 |                 'custom_pad': args.custom_pad,
177 |                 'avg': 1}
178 |             block = block_module(args.patch_size, stride_test, args.kernel_size, params)
179 |             batch_noisy_blocks = block._make_blocks(noisy_batch)
180 |             patch_loader = torch.utils.data.DataLoader(batch_noisy_blocks, batch_size=args.test_batch, drop_last=False)
181 |             batch_out_blocks = torch.zeros_like(batch_noisy_blocks)
182 |             for i, inp in enumerate(patch_loader):  # in case it doesn't fit in memory
183 |                 id_from, id_to = i * patch_loader.batch_size, (i + 1) * patch_loader.batch_size
184 |                 batch_out_blocks[id_from:id_to] = model(inp)
185 | 
186 |             output = block._agregate_blocks(batch_out_blocks)
187 |             # print(torch.isnan(output).sum())
188 |         else:
189 |             output = model(noisy_batch)
190 |     gt = dataloaders_hsi_test.get_gt(gt_path, fname)
191 |     gt = gt.to(device=device)
192 |     if device_name == 'cpu':
193 |         psnr_batch, ssim_batch, sam_batch = MSIQA(gt.detach().numpy(),
194 |                                                   output.squeeze(0).detach().numpy())
195 |         if args.save:
196 |             scio.savemat(fname + 'Res.mat', {'output': output.squeeze(0).detach().numpy()})
197 |     else:
198 |         psnr_batch, ssim_batch, sam_batch = MSIQA(gt.detach().cpu().numpy(),
199 |                                                   output.squeeze(0).detach().cpu().numpy())
200 | 
201 |         if args.save:
202 |             scio.savemat(fname + 'Res.mat', {'output': output.squeeze(0).detach().cpu().numpy()})
203 |     psnr_tot.append(psnr_batch)
204 |     ssim_tot.append(ssim_batch)
205 |     sam_tot.append(sam_batch)
206 |     num_iters += 1
207 |     tqdm.write(f'psnr {psnr_batch} ssim {ssim_batch} sam {sam_batch}')
208 |     if args.dummy:
209 |         break
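# Reported metrics (computed by MSIQA in ops/utils.py): PSNR and SSIM are
# averaged band by band; SAM (spectral angle mapper) is the mean per-pixel
# angle, in radians, between the reference and reconstructed spectra.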
210 | tac = time.time()
211 | psnr_mean = np.mean(psnr_tot)
212 | ssim_mean = np.mean(ssim_tot)
213 | sam_mean = np.mean(sam_tot)
214 | # scio.savemat(args.out_dir + 'GT.mat', {'psnr': psnr_tot, 'ssim': ssim_tot, 'sam': sam_tot})
215 | # psnr_tot = psnr_tot.item()
216 | 
217 | tqdm.write(
218 |     f'psnr: {psnr_mean:0.4f} ssim: {ssim_mean:0.4f} sam: {sam_mean:0.4f} ({(tac - tic) / num_iters:0.3f} s/iter)')
219 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import dataloaders_hsi
2 | import torch
3 | import numpy as np
4 | from tqdm import tqdm
5 | import argparse
6 | import os
7 | import torch.nn.functional as F
8 | import time
9 | from ops.utils_blocks import block_module
10 | from ops.utils import show_mem, generate_key, save_checkpoint, str2bool, step_lr, get_lr
11 | from model.MACNet import Params
12 | from model.MACNet import MACNet
13 | parser = argparse.ArgumentParser()
14 | # model
15 | parser.add_argument("--noise_level", type=int, dest="noise_level", help="Should be an int in the range [0,255]", default=25)
16 | parser.add_argument("--nl_level", type=int, dest="nl_level", help="Should be an int in the range [0,255]", default=5)
17 | 
18 | parser.add_argument("--channels", type=int, dest="channels", help="Number of feature channels", default=16)
19 | 
20 | parser.add_argument("--bandwise", type=str2bool, default=1, help='bandwise noise')
21 | parser.add_argument("--num_half_layer", type=int, dest="num_half_layer", help="Number of LISTA steps unfolded", default=3)
22 | 
23 | parser.add_argument("--patch_size", type=int, dest="patch_size", help="Size of image blocks to process", default=64)
24 | parser.add_argument("--rescaling_init_val", type=float, default=1.0)
25 | parser.add_argument("--nu_init", type=float, default=1, help='convex combination of correlation map init value')
26 | parser.add_argument("--corr_update", type=int, default=3, help='choose update method in [2,3] without or with patch averaging')
27 | parser.add_argument("--multi_theta", type=str2bool, default=1, help='whether to use a sequence of lambda [1] or a single vector during lista [0]')
28 | parser.add_argument("--diag_rescale_gamma", type=str2bool, default=0, help='diag rescaling code correlation map')
29 | parser.add_argument("--diag_rescale_patch", type=str2bool, default=1, help='diag rescaling patch correlation map')
30 | parser.add_argument("--freq_corr_update", type=int, default=6, help='freq update correlation_map')
31 | parser.add_argument("--mask_windows", type=int, default=1, help='binary or quadratic mask [1,2]')
32 | parser.add_argument("--center_windows", type=str2bool, default=1, help='compute correlation with neighbors only within a block')
33 | parser.add_argument("--multi_std", type=str2bool, default=0)
34 | parser.add_argument("--gpus", '--list', action='append', type=int, help='GPU')
35 | 
36 | # training
37 | parser.add_argument("--lr", type=float, dest="lr", help="ADAM Learning rate", default=1e-3)
38 | parser.add_argument("--lr_step", type=int, dest="lr_step", help="ADAM Learning rate step for decay", default=80)
39 | parser.add_argument("--lr_decay", type=float, dest="lr_decay", help="ADAM Learning rate decay (on step)", default=0.35)
40 | parser.add_argument("--backtrack_decay", type=float, help='decay when backtracking', default=0.8)
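# Backtracking (see the validation phase further below): if the validation
# PSNR drops by more than 0.2 dB relative to the last saved evaluation, the
# model and optimizer are reloaded from '<ckpt>_lasteval' and the learning
# rate is multiplied by backtrack_decay once per consecutive reload.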
parser.add_argument("--backtrack_decay", type=float, help='decay when backtracking',default=0.8) 41 | parser.add_argument("--eps", type=float, dest="eps", help="ADAM epsilon parameter", default=1e-3) 42 | parser.add_argument("--validation_every", type=int, default=300, help='validation frequency on training set (if using backtracking)') 43 | parser.add_argument("--backtrack", type=str2bool, default=1, help='use backtrack to prevent model divergence') 44 | parser.add_argument("--num_epochs", type=int, dest="num_epochs", help="Total number of epochs to train", default=300) 45 | parser.add_argument("--train_batch", type=int, default=2, help='batch size during training') 46 | parser.add_argument("--test_batch", type=int, default=3, help='batch size during eval') 47 | parser.add_argument("--aug_scale", type=int, default=0) 48 | parser.add_argument("--rs_real", type=str2bool, default=0) 49 | 50 | #data 51 | parser.add_argument("--out_dir", type=str, dest="out_dir", help="Results' dir path", default='./trained_model') 52 | parser.add_argument("--model_name", type=str, dest="model_name", help="The name of the model to be saved.", default=None) 53 | parser.add_argument("--test_path", type=str, help="Path to the dir containing the testing datasets.", default="data/") 54 | parser.add_argument("--train_path", type=str, help="Path to the dir containing the training datasets.", default="data/") 55 | parser.add_argument("--resume", type=str2bool, dest="resume", help='Resume training of the model',default=True) 56 | parser.add_argument("--dummy", type=str2bool, dest="dummy", default=False) 57 | parser.add_argument("--tqdm", type=str2bool, default=False) 58 | parser.add_argument('--log_dir', type=str, default='log', help='log directory') 59 | 60 | #inference 61 | parser.add_argument("--stride_test", type=int, default=12, help='stride of overlapping image blocks [4,8,16,24,48] kernel_//stride') 62 | parser.add_argument("--stride_val", type=int, default=40, help='stride of overlapping image blocks for validation [4,8,16,24,48] kernel_//stride') 63 | parser.add_argument("--test_every", type=int, default=300, help='report performance on test set every X epochs') 64 | parser.add_argument("--block_inference", type=str2bool, default=True,help='if true process blocks of large image in paralel') 65 | parser.add_argument("--pad_image", type=str2bool, default=0,help='padding strategy for inference') 66 | parser.add_argument("--pad_block", type=str2bool, default=1,help='padding strategy for inference') 67 | parser.add_argument("--pad_patch", type=str2bool, default=0,help='padding strategy for inference') 68 | parser.add_argument("--no_pad", type=str2bool, default=False, help='padding strategy for inference') 69 | parser.add_argument("--custom_pad", type=int, default=None,help='padding strategy for inference') 70 | 71 | #variance reduction 72 | #var reg 73 | parser.add_argument("--nu_var", type=float, default=0.01) 74 | parser.add_argument("--freq_var", type=int, default=3) 75 | parser.add_argument("--var_reg", type=str2bool, default=False) 76 | 77 | parser.add_argument("--verbose", type=str2bool, default=1) 78 | 79 | args = parser.parse_args() 80 | # os.environ['CUDA_VISIBLE_DEVICES']= '6,7' 81 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 82 | device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu' 83 | capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else os.cpu_count() 84 | gpus=args.gpus 85 | if torch.cuda.is_available(): 86 | 
87 | if device.type == 'cuda':
88 |     torch.cuda.set_device('cuda:{}'.format(gpus[0]))
89 | if args.stride_val > args.patch_size:
90 |     args.stride_val = args.patch_size // 2
91 | if args.stride_test > args.patch_size:
92 |     args.stride_test = args.patch_size // 2
93 | test_path = [f'{args.test_path}']
94 | train_path = [f'{args.train_path}']
95 | val_path = train_path
96 | noise_std = args.noise_level / 255
97 | args.log_dir = args.log_dir + "_" + str(args.noise_level)
98 | args.out_dir = args.out_dir + "_" + str(args.noise_level)
99 | 
100 | if args.bandwise:
101 |     args.log_dir = args.log_dir + "_bandwise"
102 |     args.out_dir += "_bandwise"
103 | if not os.path.exists(args.log_dir):
104 |     os.makedirs(args.log_dir)
105 | log_file_name = "./%s/MACNet_patch_%dLayer_%dlr_%.8f.txt" % (
106 |     args.log_dir, args.patch_size, args.num_half_layer*2, args.lr)
107 | 
108 | loaders = dataloaders_hsi.get_dataloaders(train_path, test_path, val_path, crop_size=args.patch_size,
109 |                                           batch_size=args.train_batch, downscale=args.aug_scale, concat=1, grey=False)
110 | 
111 | 
112 | 
113 | 
114 | 
115 | params = Params(in_channels=1, channels=args.channels,
116 |                 num_half_layer=args.num_half_layer, rs=args.rs_real)
117 | model = MACNet(params).to(device=device)
118 | if device.type == 'cuda':
119 |     model = torch.nn.DataParallel(model.to(device=device), device_ids=gpus)
120 | 
121 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=args.eps)
122 | 
123 | if args.backtrack:
124 |     reload_counter = 0
125 | 
126 | pytorch_total_params = sum(p.numel() for p in model.parameters())
127 | print(f'Arguments: {vars(args)}')
128 | print('Nb tensors: ', len(list(model.named_parameters())), "; Trainable Params: ", pytorch_total_params, "; device: ", device,
129 |       "; name : ", device_name)
130 | 
131 | psnr = {x: np.zeros(args.num_epochs) for x in ['train', 'test', 'val']}
132 | 
133 | # model_name = args.model_name if args.model_name is not None else generate_key()  # unused: overwritten just below
134 | model_name = "MACNet_patch_%dLayer_%dlr_%.8f" % (args.patch_size, args.num_half_layer*2, args.lr)
135 | out_dir = os.path.join(args.out_dir, model_name)
136 | if not os.path.exists(out_dir):
137 |     os.makedirs(out_dir)
138 | ckpt_path = os.path.join(out_dir + '/ckpt')
139 | config_dict = vars(args)
140 | if args.resume:
141 |     if os.path.isfile(ckpt_path):
142 |         try:
143 |             print('\n existing ckpt detected')
144 |             checkpoint = torch.load(ckpt_path)
145 |             start_epoch = checkpoint['epoch']
146 |             psnr_validation = checkpoint['psnr_validation']
147 |             model.load_state_dict(checkpoint['state_dict'])
148 |             optimizer.load_state_dict(checkpoint['optimizer'])
149 |             print(f"=> loaded checkpoint '{ckpt_path}' (epoch {start_epoch})")
150 |         except Exception as e:
151 |             print(e)
152 |             print(f'ckpt loading failed @{ckpt_path}, exit ...')
153 |             exit()
154 | 
155 |     else:
156 |         print(f'\nno ckpt found @{ckpt_path}')
157 |         start_epoch = 0
158 |         psnr_validation = 22.0
159 |         if args.backtrack:
160 |             state = {'psnr_validation': psnr_validation,
161 |                      'epoch': 0,
162 |                      'config': config_dict,
163 |                      'state_dict': model.state_dict(),
164 |                      'optimizer': optimizer.state_dict(), }
165 |             torch.save(state, ckpt_path + '_lasteval')
166 | 
167 | print(f'... starting training ...\n')
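# Bandwise noise synthesis (see the training loop below): with --bandwise,
# each band receives Gaussian noise with its own standard deviation drawn
# uniformly from [0, noise_level/255]; with --bandwise 0, a single standard
# deviation noise_level/255 is shared by all bands.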
168 | 
169 | 
170 | epoch = start_epoch
171 | 
172 | while epoch < args.num_epochs:
173 | 
174 |     tic = time.time()
175 | 
176 |     phases = ['train', 'val', 'test']
177 | 
178 |     for phase in phases:
179 |         if phase == 'train':
180 |             if (epoch % args.lr_step) == 0 and (epoch != 0):
181 |                 step_lr(optimizer, args.lr_decay)
182 |             model.train()
183 | 
184 |         elif phase == 'val':
185 |             if not (args.backtrack and ((epoch+1) % args.validation_every == 0)):
186 |                 continue
187 |             model.eval()  # Set model to evaluate mode
188 |             print(f'\nstarting validation on train set with stride {args.stride_val}...')
189 | 
190 | 
191 |         elif phase == 'test':
192 |             if (epoch+1) % args.test_every != 0:
193 |                 continue  # test every k epochs
194 |             print(f'\nstarting eval on test set with stride {args.stride_test}...')
195 |             model.eval()  # Set model to evaluate mode
196 | 
197 | 
198 |         # Iterate over data.
199 |         num_iters = 0
200 |         psnr_set = 0
201 |         loss_set = 0
202 | 
203 |         loader = loaders[phase]
204 | 
205 |         for batch in tqdm(loader, disable=not args.tqdm):
206 |             batch = batch.to(device=device)
207 |             if args.bandwise:
208 |                 bands = batch.shape[1]
209 |                 noise = torch.randn_like(batch)
210 |                 for i in range(bands):
211 |                     noise[:, i, :, :] = torch.randn_like(batch[:, i, :, :]) * torch.rand(1).to(device=device) * noise_std
212 |             else:
213 |                 noise = torch.randn_like(batch) * noise_std
214 |             noisy_batch = batch + noise
215 |             optimizer.zero_grad()
216 | 
217 |             with torch.set_grad_enabled(phase == 'train'):
218 | 
219 |                 # Block inference during test phase
220 |                 if phase == 'test' or phase == 'val':
221 | 
222 |                     if phase == 'val':
223 |                         stride_test = args.stride_val
224 |                     else:
225 |                         stride_test = args.stride_test
226 | 
227 |                     if args.block_inference:
228 |                         params = {
229 |                             'crop_out_blocks': 0,
230 |                             'ponderate_out_blocks': 1,
231 |                             'sum_blocks': 0,
232 |                             'pad_even': 1,  # otherwise pad with 0 for last block
233 |                             'centered_pad': 0,  # corner pixels have only one estimate
234 |                             'pad_block': args.pad_block,  # pad so each pixel has S**2 estimates
235 |                             'pad_patch': args.pad_patch,  # pad so each pixel from the image has at least S**2 estimates from 1 block
236 |                             'no_pad': args.no_pad,
237 |                             'custom_pad': args.custom_pad,
238 |                             'avg': 1}
239 |                         block = block_module(args.patch_size, stride_test, args.kernel_size, params)
240 |                         batch_noisy_blocks = block._make_blocks(noisy_batch)
241 |                         patch_loader = torch.utils.data.DataLoader(batch_noisy_blocks, batch_size=args.test_batch, drop_last=False)
242 |                         batch_out_blocks = torch.zeros_like(batch_noisy_blocks)
243 | 
244 |                         for i, inp in enumerate(patch_loader):  # in case it doesn't fit in memory
245 |                             id_from, id_to = i * patch_loader.batch_size, (i + 1) * patch_loader.batch_size
246 |                             batch_out_blocks[id_from:id_to] = model(inp)
247 | 
248 |                         output = block._agregate_blocks(batch_out_blocks)
249 |                         # print(torch.isnan(output).sum())
250 |                     else:
251 |                         output = model(noisy_batch)
252 |                     loss = (output.clamp(0., 1.) - batch).pow(2).sum() / batch.shape[0]
253 |                     loss_psnr = -10 * torch.log10((output.clamp(0., 1.) - batch).pow(2).mean([1, 2, 3])).mean()
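                    # PSNR from MSE: with signals in [0, 1] the peak value is 1, so
                    # PSNR = 10*log10(1/MSE) = -10*log10(MSE); mean([1, 2, 3]) gives the
                    # per-image MSE over bands and pixels, then .mean() averages the batch.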
254 | 
255 |                 if phase == 'train':
256 | 
257 |                     output = model(noisy_batch)
258 |                     loss = (output - batch).pow(2).sum() / batch.shape[0]
259 |                     loss_psnr = -10 * torch.log10((output - batch).pow(2).mean([1, 2, 3])).mean()
260 |                     loss.backward()
261 |                     optimizer.step()
262 |                     # print("loss: \n", loss.item())
263 | 
264 |             psnr_set += loss_psnr.item()
265 |             loss_set += loss.item()
266 |             num_iters += 1
267 | 
268 |             if args.dummy:
269 |                 break
270 | 
271 |         tac = time.time()
272 |         psnr_set /= num_iters
273 |         loss_set /= num_iters
274 | 
275 |         psnr[phase][epoch] = psnr_set
276 | 
277 |         if phase == 'val':
278 |             r_err = -(psnr_set - psnr_validation)
279 |             print(
280 |                 f'validation psnr {psnr_set:0.4f}, {psnr_validation:0.4f}, absolute_delta {-r_err:0.2e}, reload counter {reload_counter}')
281 |             path = ckpt_path + '_lasteval'
282 | 
283 |             if r_err > 0.2:  # test divergence
284 |                 if os.path.isfile(path):
285 |                     try:
286 |                         print('backtracking: previous ckpt detected')
287 |                         checkpoint = torch.load(path)
288 |                         epoch = checkpoint['epoch']
289 |                         model.load_state_dict(checkpoint['state_dict'])
290 |                         optimizer.load_state_dict(checkpoint['optimizer'])
291 |                         [step_lr(optimizer, args.backtrack_decay) for _ in range(reload_counter + 1)]
292 |                         print(f"loaded checkpoint '{path}' (epoch {epoch}), decreasing lr ==> {get_lr(optimizer):0.2e}")
293 |                         reload_counter += 1
294 |                     except Exception as e:
295 |                         print('caught exception:')
296 |                         print(e)
297 |                         print(f'ckpt loading failed @{path}')
298 |                 else:
299 |                     print('no ckpt found for backtrack')
300 |             else:
301 |                 reload_counter = 0
302 |                 state = {'psnr_validation': psnr_validation,
303 |                          'epoch': epoch,
304 |                          'config': config_dict,
305 |                          'state_dict': model.state_dict(),
306 |                          'optimizer': optimizer.state_dict(), }
307 |                 torch.save(state, ckpt_path + '_lasteval')
308 |             psnr_validation = psnr_set
309 | 
310 |         if torch.cuda.is_available():
311 |             mem_used, max_mem = show_mem()
312 |             tqdm.write(f'epoch {epoch} - {phase} psnr: {psnr[phase][epoch]:0.4f} ({tac-tic:0.1f} s, {(tac - tic) / num_iters:0.3f} s/iter, max gpu mem allocated {max_mem:0.1f} Mb, lr {get_lr(optimizer):0.1e})')
313 |         else:
314 |             tqdm.write(f'epoch {epoch} - {phase} psnr: {psnr[phase][epoch]:0.4f} loss: {loss_set:0.4f} ({(tac-tic)/num_iters:0.3f} s/iter, lr {get_lr(optimizer):0.2e})')
315 |         with open(f'{log_file_name}', 'a') as log_file:
316 |             # log_file = open(log_file_name, 'a')  # not needed: the with-statement already provides log_file
317 |             log_file.write(
318 |                 f'epoch {epoch} - {phase} psnr: {psnr[phase][epoch]:0.4f} loss: {loss_set:0.4f} ({(tac - tic) / num_iters:0.3f} s/iter, lr {get_lr(optimizer):0.2e})\n')
319 |         # output_file.close()
320 |         with open(f'{out_dir}/{phase}.psnr', 'a') as psnr_file:
321 |             psnr_file.write(f'{psnr[phase][epoch]:0.4f}\n')
322 | 
323 | 
324 |     epoch += 1
325 |     ##################### saving #################
326 |     if epoch % 10 == 0:
327 |         save_checkpoint({'epoch': epoch,
328 |                          'config': config_dict,
329 |                          'state_dict': model.state_dict(),
330 |                          'optimizer': optimizer.state_dict(),
331 |                          'psnr_validation': psnr_validation}, os.path.join(out_dir + '/ckpt_' + str(epoch)))
332 |     save_checkpoint({'epoch': epoch,
333 |                      'config': config_dict,
334 |                      'state_dict': model.state_dict(),
335 |                      'optimizer': optimizer.state_dict(),
336 |                      'psnr_validation': psnr_validation}, ckpt_path)
--------------------------------------------------------------------------------
/trained_model/ckpt_15:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/trained_model/ckpt_15 -------------------------------------------------------------------------------- /trained_model/ckpt_55: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/trained_model/ckpt_55 -------------------------------------------------------------------------------- /trained_model/ckpt_95: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearshng/mac-net/0c8a72a2d2ca5154e8ae6c697727ad5a24a8774e/trained_model/ckpt_95 --------------------------------------------------------------------------------