├── LICENSE ├── README.md ├── __pycache__ ├── hsi_setup.cpython-37.pyc └── indexes.cpython-37.pyc ├── checkpoints ├── houston │ ├── iid.pth │ ├── mixture.pth │ └── niid.pth └── icvl │ ├── iid.pth │ ├── mixture.pth │ └── niid.pth ├── datasets ├── test │ └── .test └── training │ └── .training ├── eval_iid.sh ├── eval_mix.sh ├── eval_niid.sh ├── hsi_denoising_complex.py ├── hsi_denoising_gauss_iid.py ├── hsi_denoising_gauss_niid.py ├── hsi_eval.py ├── hsi_setup.py ├── hsi_test.py ├── indexes.py ├── matlab ├── Data │ ├── _meta_complex.mat │ ├── _meta_complex_2.mat │ ├── _meta_gauss.mat │ └── _meta_gauss_2.mat ├── HSIData.m ├── HSIEval.m ├── HSI_eval.m ├── HSI_test.m ├── HSI_visualize.m ├── Main_Complex.m ├── Main_Gauss.m ├── Main_Real.m ├── README.md ├── Result_Complex.m ├── Result_Gauss.m ├── benchmarks.bib ├── demo_fun.m ├── eval_dataset.m ├── generate_dataset.m ├── generate_dataset_blind.m ├── generate_dataset_complex.m ├── generate_dataset_complex_backup.m ├── generate_dataset_deadline.m ├── generate_dataset_impulse.m ├── generate_dataset_mixture.m ├── generate_dataset_mixture_backup.m ├── generate_dataset_noniid.m └── generate_dataset_stripe.m ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── combinations.cpython-37.pyc │ └── im2col.cpython-37.pyc ├── combinations.py ├── im2col.py ├── networks_other.py ├── nssnn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── cc.cpython-37.pyc │ │ └── nssnn.cpython-37.pyc │ ├── cc.py │ └── nssnn.py ├── qrnn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── combinations.cpython-37.pyc │ │ ├── combinations.cpython-38.pyc │ │ ├── qrnn3d.cpython-37.pyc │ │ ├── qrnn3d.cpython-38.pyc │ │ ├── redc3d.cpython-37.pyc │ │ ├── redc3d.cpython-38.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── resnet.cpython-38.pyc │ │ ├── utils.cpython-37.pyc │ │ └── utils.cpython-38.pyc │ ├── combinations.py │ ├── qrnn3d.py │ ├── redc3d.py │ ├── 
resnet.py │ └── utils.py ├── sru │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── combinations.cpython-37.pyc │ │ └── sru3d.cpython-37.pyc │ ├── combinations.py │ └── sru3d.py ├── sync_batchnorm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── batchnorm.cpython-38.pyc │ │ ├── comm.cpython-37.pyc │ │ ├── comm.cpython-38.pyc │ │ ├── replicate.cpython-37.pyc │ │ └── replicate.cpython-38.pyc │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── unet │ ├── __pycache__ │ │ ├── buildingblocks.cpython-37.pyc │ │ ├── buildingblocks.cpython-38.pyc │ │ ├── unet.cpython-37.pyc │ │ ├── unet.cpython-38.pyc │ │ ├── utils.cpython-37.pyc │ │ └── utils.cpython-38.pyc │ ├── buildingblocks.py │ ├── unet.py │ └── utils.py └── utils.py ├── torch_37.yaml ├── train_iid.sh ├── train_mixtrue.sh ├── train_niid.sh └── utility ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── data_parallel.cpython-37.pyc ├── dataloaders_hsi_test.cpython-37.pyc ├── dataloaders_hsi_test.cpython-38.pyc ├── dataset.cpython-37.pyc ├── dataset.cpython-38.pyc ├── gauss.cpython-37.pyc ├── gauss.cpython-38.pyc ├── helper.cpython-37.pyc ├── helper.cpython-38.pyc ├── indexes.cpython-37.pyc ├── indexes.cpython-38.pyc ├── lmdb_dataset.cpython-37.pyc ├── lmdb_dataset.cpython-38.pyc ├── read_HSI.cpython-37.pyc ├── read_HSI.cpython-38.pyc ├── refold.cpython-37.pyc ├── refold.cpython-38.pyc ├── ssim.cpython-37.pyc ├── ssim.cpython-38.pyc ├── util.cpython-36.pyc ├── util.cpython-37.pyc └── util.cpython-38.pyc ├── data_parallel.py ├── dataloaders_hsi.py ├── dataloaders_hsi_test.py ├── dataset copy.py ├── dataset.py ├── gauss.py ├── helper.py ├── indexes.py ├── indexes_back.py ├── lmdb_data.py ├── lmdb_data_ori.py ├── lmdb_dataset.py ├── mat_data.py ├── read_HSI.py ├── readme.py ├── refold.py ├── ssim.py └── util.py /README.md: 
-------------------------------------------------------------------------------- 1 | # NSSNN 2 | 3 | The implementation of the TGRS 2022 paper ["Nonlocal Spatial-Spectral Neural Network for Hyperspectral Image Denoising"](https://ieeexplore.ieee.org/abstract/document/9930129/) 4 | 5 | ## Requirements 6 | * See ```torch_37.yaml``` 7 | 8 | ## Quick Start 9 | 10 | ### 1. Preparing your training/testing datasets 11 | 12 | * Download HSIs from [here](https://njusteducn-my.sharepoint.com/:f:/g/personal/119106032867_njust_edu_cn/EhlvptVmZohEpjkNnu9P_xQBCJfpcSzXTg_omD2YCvXuIA?e=Xzy29C). 13 | 14 | #### Training dataset 15 | 16 | * Create the training datasets by running ```python utility/lmdb_data.py``` 17 | 18 | #### Testing dataset 19 | 20 | *Note: MATLAB is required to execute the following instructions.* 21 | 22 | * You can use the testing set we prepared for you in ```datasets/test/``` 23 | 24 | * Read the MATLAB code in ```matlab/generate_dataset*``` to understand how we generate noisy HSIs. 25 | 26 | * Read and modify the MATLAB code in ```matlab/HSIData.m``` to generate your own testing dataset. 27 | 28 | ### 2. Testing with pretrained models 29 | 30 | * Our pretrained models are in ```checkpoints/```; you can use the scripts ```eval*.sh``` to test them. 31 | 32 | ### 3. Training from scratch 33 | 34 | * Use the training scripts ```train*.sh``` to train your own models.
35 | 36 | ## Citation 37 | If you find this work useful for your research, please cite: 38 | ``` 39 | @ARTICLE{fu2022nssnn, 40 | author={Fu, Guanyiman and Xiong, Fengchao and Lu, Jianfeng and Zhou, Jun and Qian, Yuntao}, 41 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 42 | title={Nonlocal Spatial–Spectral Neural Network for Hyperspectral Image Denoising}, 43 | year={2022}, 44 | volume={60}, 45 | number={}, 46 | pages={1-16}, 47 | doi={10.1109/TGRS.2022.3217097}} 48 | ``` 49 | ## Contact 50 | Please contact me if you have any questions (gym.fu@njust.edu.cn) 51 | -------------------------------------------------------------------------------- /__pycache__/hsi_setup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/__pycache__/hsi_setup.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/indexes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/__pycache__/indexes.cpython-37.pyc -------------------------------------------------------------------------------- /checkpoints/houston/iid.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/houston/iid.pth -------------------------------------------------------------------------------- /checkpoints/houston/mixture.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/houston/mixture.pth -------------------------------------------------------------------------------- /checkpoints/houston/niid.pth:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/houston/niid.pth -------------------------------------------------------------------------------- /checkpoints/icvl/iid.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/icvl/iid.pth -------------------------------------------------------------------------------- /checkpoints/icvl/mixture.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/icvl/mixture.pth -------------------------------------------------------------------------------- /checkpoints/icvl/niid.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/checkpoints/icvl/niid.pth -------------------------------------------------------------------------------- /datasets/test/.test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/datasets/test/.test -------------------------------------------------------------------------------- /datasets/training/.training: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/datasets/training/.training -------------------------------------------------------------------------------- /eval_iid.sh: -------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 3 |
--gpu-ids 2 -ofd results/nssnn/icvl/iid/50/ \ 4 | -r -rp checkpoints/icvl/iid.pth \ 5 | -tr datasets/test/ICVL/iid/50/ -gr datasets/test/ICVL/gt/ 6 | 7 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 8 | --gpu-ids 2 -ofd results/nssnn/icvl/iid/70/ \ 9 | -r -rp checkpoints/icvl/iid.pth \ 10 | -tr datasets/test/ICVL/iid/70/ -gr datasets/test/ICVL/gt/ 11 | 12 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 13 | --gpu-ids 2 -ofd results/nssnn/icvl/iid/90/ \ 14 | -r -rp checkpoints/icvl/iid.pth \ 15 | -tr datasets/test/ICVL/iid/90/ -gr datasets/test/ICVL/gt/ 16 | -------------------------------------------------------------------------------- /eval_mix.sh: -------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 3 | --gpu-ids 2 -ofd results/nssnn/icvl/mixture/95_mixture/ \ 4 | -r -rp checkpoints/icvl/mixture.pth \ 5 | -tr datasets/test/ICVL/mixture/95_mixture/ -gr datasets/test/ICVL/gt/ -------------------------------------------------------------------------------- /eval_niid.sh: -------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 3 | --gpu-ids 2 -ofd results/nssnn/icvl/niid/15/ \ 4 | -r -rp checkpoints/icvl/niid.pth \ 5 | -tr datasets/test/ICVL/niid/15/ -gr datasets/test/ICVL/gt/ 6 | 7 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 8 | --gpu-ids 2 -ofd results/nssnn/icvl/niid/55/ \ 9 | -r -rp checkpoints/icvl/niid.pth \ 10 | -tr datasets/test/ICVL/niid/55/ -gr datasets/test/ICVL/gt/ 11 | 12 | python hsi_eval.py -a nssnn -ofn nssnn -s output \ 13 | --gpu-ids 2 -ofd results/nssnn/icvl/niid/95/ \ 14 | -r -rp checkpoints/icvl/niid.pth \ 15 | -tr datasets/test/ICVL/niid/95/ -gr datasets/test/ICVL/gt/ 16 | -------------------------------------------------------------------------------- /hsi_denoising_complex.py:
-------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import argparse 9 | 10 | from utility import * 11 | from hsi_setup import Engine, train_options, make_dataset 12 | from utility import dataloaders_hsi_test ###modified### 13 | 14 | if __name__ == '__main__': 15 | """Training settings""" 16 | parser = argparse.ArgumentParser( 17 | description='Hyperspectral Image Denoising (Complex noise)') 18 | opt = train_options(parser) 19 | print(opt) 20 | 21 | print(torch.cuda.device_count()) 22 | """Setup Engine""" 23 | engine = Engine(opt) 24 | 25 | """Dataset Setting""" 26 | HSI2Tensor = partial(HSI2Tensor, use_2dconv=engine.get_net().use_2dconv) 27 | 28 | add_noniid_noise = Compose([ 29 | AddNoiseDynamic(95), 30 | SequentialSelect( 31 | transforms=[ 32 | lambda x: x, 33 | AddNoiseImpulse(), 34 | AddNoiseStripe(), 35 | AddNoiseDeadline() 36 | ] 37 | ) 38 | ]) 39 | common_transform_1 = lambda x: x 40 | common_transform = Compose([ 41 | partial(rand_crop, cropx=32, cropy=32), 42 | ]) 43 | 44 | target_transform = HSI2Tensor() 45 | 46 | train_transform = Compose([ 47 | add_noniid_noise, 48 | HSI2Tensor() 49 | ]) 50 | 51 | print('==> Preparing data..') 52 | 53 | icvl_64_31_TL = make_dataset( 54 | opt, train_transform, 55 | target_transform, common_transform_1, 16) 56 | 57 | """Test-Dev""" 58 | basefolder = opt.testroot 59 | mat_loaders = [] 60 | test_path = os.path.join(basefolder) 61 | #print('noise: ',noise,end='') 62 | mat_loaders.append(dataloaders_hsi_test.get_dataloaders([test_path],verbose=True,grey=False)['test']) 63 | 64 | base_lr = opt.lr 65 | base_lr = 1e-4 66 | epoch_per_save = 10 67 | if opt.resetepoch != -1: 68 | engine.epoch = opt.resetepoch 69 | adjust_learning_rate(engine.optimizer, opt.lr) 70 | # from epoch 50 to 100 71 | 
while engine.epoch < 100: 72 | display_learning_rate(engine.optimizer) 73 | np.random.seed() 74 | if engine.epoch == 85: 75 | adjust_learning_rate(engine.optimizer, base_lr*0.3) 76 | 77 | if engine.epoch == 95: 78 | adjust_learning_rate(engine.optimizer, base_lr*0.1) 79 | 80 | print("Training with complex") 81 | engine.train(icvl_64_31_TL) 82 | if engine.epoch == 100: 83 | MSIQAs=engine.validate_MSIQA(mat_loaders[0],folder='results/nssrnn/icvl/mix/',name='nssnn_mix',size=50) 84 | print("%.4f %.4f %.4f"%( MSIQAs[0],MSIQAs[1],MSIQAs[2])) 85 | 86 | else: 87 | MSIQAs=engine.validate_MSIQA(mat_loaders[0],name='nssnn_mix',size=2) 88 | print("%.4f %.4f %.4f"%( MSIQAs[0],MSIQAs[1],MSIQAs[2])) 89 | 90 | 91 | print('\nLatest Result Saving...') 92 | model_latest_path = os.path.join(engine.basedir, engine.prefix, 'model_latest.pth') 93 | engine.save_checkpoint( 94 | model_out_path=model_latest_path 95 | ) 96 | 97 | 98 | if engine.epoch % epoch_per_save == 0:###modified### 99 | engine.save_checkpoint() 100 | 101 | MSIQAs = [] 102 | for mat_loader in mat_loaders: 103 | MSIQAs.append(engine.validate_MSIQA(mat_loader,folder=opt.output_fold,name=opt.output_file_name)) 104 | for MSIQA in MSIQAs: 105 | for index in MSIQA: 106 | print("%.4f"%(index)) -------------------------------------------------------------------------------- /hsi_denoising_gauss_iid.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import argparse 9 | 10 | from utility import * 11 | from hsi_setup import Engine, train_options, make_dataset 12 | from utility import dataloaders_hsi_test ###modified### 13 | 14 | if __name__ == '__main__': 15 | """Training settings""" 16 | parser = argparse.ArgumentParser( 17 | description='Hyperspectral Image Denoising (Gaussian Noise)') 18 
| opt = train_options(parser) 19 | print(opt) 20 | 21 | print(torch.cuda.device_count()) 22 | """Setup Engine""" 23 | engine = Engine(opt) 24 | 25 | """Dataset Setting""" 26 | HSI2Tensor = partial(HSI2Tensor, use_2dconv=engine.get_net().use_2dconv) 27 | 28 | common_transform_1 = lambda x: x 29 | 30 | common_transform_2 = Compose([ 31 | partial(rand_crop, cropx=32, cropy=32), 32 | ]) 33 | 34 | target_transform = HSI2Tensor() 35 | train_transform_0 = Compose([ 36 | AddNoise(50), 37 | HSI2Tensor() 38 | ]) 39 | train_transform_1 = Compose([ 40 | AddNoise(10), 41 | HSI2Tensor() 42 | ]) 43 | train_transform_2 = Compose([ 44 | AddNoise(30), 45 | HSI2Tensor() 46 | ]) 47 | train_transform_3 = Compose([ 48 | AddNoise(50), 49 | HSI2Tensor() 50 | ]) 51 | train_transform_4 = Compose([ 52 | AddNoise(70), 53 | HSI2Tensor() 54 | ]) 55 | train_transform_5 = Compose([ 56 | AddNoise(90), 57 | HSI2Tensor() 58 | ]) 59 | train_transform_6 = Compose([ 60 | AddNoiseList((10,30,50,70,90)), 61 | HSI2Tensor() 62 | ]) 63 | ''' 64 | train_transform_1 = Compose([ 65 | AddNoise(50), 66 | HSI2Tensor() 67 | ]) 68 | 69 | train_transform_2 = Compose([ 70 | AddNoiseBlind([10, 30, 50, 70]), 71 | HSI2Tensor() 72 | ]) 73 | ''' 74 | print('==> Preparing data..') 75 | 76 | icvl_64_31_TL_0 = make_dataset( 77 | opt, train_transform_0, 78 | target_transform, common_transform_1,opt.batchSize ) 79 | icvl_64_31_TL_1 = make_dataset( 80 | opt, train_transform_1, 81 | target_transform, common_transform_1, opt.batchSize) 82 | icvl_64_31_TL_2 = make_dataset( 83 | opt, train_transform_2, 84 | target_transform, common_transform_1,opt.batchSize) 85 | icvl_64_31_TL_3 = make_dataset( 86 | opt, train_transform_3, 87 | target_transform, common_transform_1, opt.batchSize) 88 | # icvl_64_31_TL_4 = make_dataset( 89 | # opt, train_transform_4, 90 | # target_transform, common_transform_2, opt.batchSize*4) 91 | icvl_64_31_TL_4 = make_dataset( 92 | opt, train_transform_4, 93 | target_transform, common_transform_1, opt.batchSize) 
94 | 95 | icvl_64_31_TL_5 = make_dataset( 96 | opt, train_transform_5, 97 | target_transform, common_transform_1, opt.batchSize) 98 | 99 | icvl_64_31_TL_6 = make_dataset( 100 | opt, train_transform_6, 101 | target_transform, common_transform_1, opt.batchSize) 102 | ''' 103 | icvl_64_31_TL_2 = make_dataset( 104 | opt, train_transform_2, 105 | target_transform, common_transform_2, 64) 106 | ''' 107 | """Test-Dev""" 108 | 109 | ###modified### 110 | basefolder = opt.testroot 111 | mat_names = ['icvl_dynamic_512_50','icvl_dynamic_512_70','icvl_dynamic_512_90'] 112 | #mat_names = ['icvl_512_30', 'icvl_512_50'] 113 | mat_loaders = [] 114 | for noise in (50,70,90): 115 | test_path = os.path.join(basefolder, str(noise)+'/') 116 | #print('noise: ',noise,end='') 117 | mat_loaders.append(dataloaders_hsi_test.get_dataloaders([test_path],verbose=True,grey=False)['test']) 118 | ###modified### 119 | 120 | #print(icvl_64_31_TL_0.__len__()) 121 | max=30*7 122 | if icvl_64_31_TL_0.__len__()*opt.batchSize > 2000: 123 | max_epoch = max//2 124 | if_100 = 1 125 | if_eval = 1 126 | epoch_per_save = 1 127 | testsize = 10 128 | else: 129 | max_epoch = max 130 | if_100 = 0 131 | if_eval = 1 132 | epoch_per_save = 10 133 | testsize = 5 134 | print('max_epoch: ',max_epoch) 135 | """Main loop""" 136 | base_lr = opt.lr 137 | if_val_any = 1 138 | if opt.resetepoch != -1: 139 | engine.epoch = opt.resetepoch 140 | while engine.epoch < max_epoch: 141 | if if_100: 142 | epoch = engine.epoch * 2 143 | else: 144 | epoch = engine.epoch 145 | display_learning_rate(engine.optimizer) 146 | np.random.seed() # reset seed per epoch, otherwise the noise will be added with a specific pattern 147 | if epoch == 0: 148 | adjust_learning_rate(engine.optimizer, opt.lr) 149 | elif epoch == 10: 150 | adjust_learning_rate(engine.optimizer, base_lr*0.1) 151 | elif epoch == 20: 152 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 153 | elif epoch % 30 == 0 and epoch >29 and epoch < max: 154 | 
adjust_learning_rate(engine.optimizer, base_lr*0.1) 155 | elif epoch % 30 == 14 and epoch >29 and epoch < max: 156 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 157 | elif epoch == max: 158 | adjust_learning_rate(engine.optimizer, base_lr*0.001) 159 | ''' 160 | elif engine.epoch % 30 == 1 and engine.epoch != 1: 161 | adjust_learning_rate(engine.optimizer, base_lr) 162 | 163 | elif engine.epoch % 30 == 0 and engine.epoch != 0: 164 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 165 | ''' 166 | #print(if_100) 167 | if epoch < 30: 168 | #engine.validate(mat_loaders[0], 'icvl-validate-50') 169 | print("Training with unbindwise noise 50dB") 170 | engine.train(icvl_64_31_TL_0) 171 | if if_val_any: 172 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 173 | if if_eval: 174 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 175 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 176 | #engine.validate(mat_loaders[1], 'icvl-validate-50') 177 | elif epoch < 60: 178 | print("Training with 10dB") 179 | engine.train(icvl_64_31_TL_1) 180 | if if_val_any: 181 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 182 | if if_eval: 183 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 184 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 185 | #engine.validate(mat_loaders[0], 'icvl-validate-50') 186 | #engine.validate(mat_loaders[0], 'icvl-validate-30') 187 | #engine.validate(mat_loaders[1], 'icvl-validate-50') 188 | elif epoch < 90: 189 | print("Training with 30dB") 190 | engine.train(icvl_64_31_TL_2) 191 | if if_val_any: 192 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 193 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 194 | if if_eval: 195 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 196 | elif epoch < 120: 197 | 
print("Training with 50dB") 198 | engine.train(icvl_64_31_TL_3) 199 | if if_val_any: 200 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 201 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 202 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 203 | elif epoch < 150: 204 | print("Training with 70dB") 205 | engine.train(icvl_64_31_TL_4) 206 | if if_val_any: 207 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 208 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 209 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 210 | elif epoch < 180: 211 | print("Training with 90dB") 212 | engine.train(icvl_64_31_TL_5) 213 | if if_val_any: 214 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 215 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 216 | engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 217 | else: 218 | print("Training with random noise") 219 | engine.train(icvl_64_31_TL_6) 220 | if engine.epoch == max_epoch and engine.epoch == 15*7: 221 | testsize = 50 222 | MSIQAs = [] 223 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-50',folder='nssnn_iid')) 224 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-70',folder='nssnn_iid')) 225 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-90',folder='nssnn_iid')) 226 | print(" PSNR SSIM SAM") 227 | print("50dB: %.4f %.4f %.4f"%( MSIQAs[0][0],MSIQAs[0][1],MSIQAs[0][2])) 228 | print("70dB: %.4f %.4f %.4f"%( MSIQAs[1][0],MSIQAs[1][1],MSIQAs[1][2])) 229 | print("90dB: %.4f %.4f %.4f"%( MSIQAs[2][0],MSIQAs[2][1],MSIQAs[2][2])) 230 | else: 231 | if if_val_any: 232 | engine.validate(mat_loaders[0], 'icvl-validate-50',testsize)###modified### 233 | engine.validate(mat_loaders[1], 'icvl-validate-70',testsize)###modified### 234 | 
engine.validate(mat_loaders[2], 'icvl-validate-90',testsize)###modified### 235 | 236 | print('\nLatest Result Saving...') 237 | model_latest_path = os.path.join(engine.basedir, engine.prefix, 'model_latest.pth') 238 | engine.save_checkpoint( 239 | model_out_path=model_latest_path 240 | ) 241 | 242 | 243 | if engine.epoch % epoch_per_save == 0:###modified### 244 | engine.save_checkpoint() 245 | -------------------------------------------------------------------------------- /hsi_denoising_gauss_niid.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import argparse 9 | 10 | from utility import * 11 | from hsi_setup import Engine, train_options, make_dataset 12 | from utility import dataloaders_hsi_test ###modified### 13 | 14 | if __name__ == '__main__': 15 | """Training settings""" 16 | parser = argparse.ArgumentParser( 17 | description='Hyperspectral Image Denoising (Gaussian Noise)') 18 | opt = train_options(parser) 19 | print(opt) 20 | 21 | print(torch.cuda.device_count()) 22 | """Setup Engine""" 23 | engine = Engine(opt) 24 | 25 | """Dataset Setting""" 26 | HSI2Tensor = partial(HSI2Tensor, use_2dconv=engine.get_net().use_2dconv) 27 | 28 | common_transform_1 = lambda x: x 29 | 30 | common_transform_2 = Compose([ 31 | partial(rand_crop, cropx=32, cropy=32), 32 | ]) 33 | 34 | target_transform = HSI2Tensor() 35 | train_transform_0 = Compose([ 36 | AddNoise(50), 37 | HSI2Tensor() 38 | ]) 39 | train_transform_1 = Compose([ 40 | AddNoiseDynamic(15), 41 | HSI2Tensor() 42 | ]) 43 | train_transform_2 = Compose([ 44 | AddNoiseDynamic(55), 45 | HSI2Tensor() 46 | ]) 47 | train_transform_3 = Compose([ 48 | AddNoiseDynamic(95), 49 | HSI2Tensor() 50 | ]) 51 | train_transform_4 = Compose([ 52 | AddNoiseDynamicList((15,55,95)), 53 | 
HSI2Tensor() 54 | ]) 55 | ''' 56 | train_transform_1 = Compose([ 57 | AddNoise(50), 58 | HSI2Tensor() 59 | ]) 60 | 61 | train_transform_2 = Compose([ 62 | AddNoiseBlind([10, 30, 50, 70]), 63 | HSI2Tensor() 64 | ]) 65 | ''' 66 | print('==> Preparing data..') 67 | 68 | icvl_64_31_TL_0 = make_dataset( 69 | opt, train_transform_0, 70 | target_transform, common_transform_1,opt.batchSize ) 71 | icvl_64_31_TL_1 = make_dataset( 72 | opt, train_transform_1, 73 | target_transform, common_transform_1, opt.batchSize) 74 | icvl_64_31_TL_2 = make_dataset( 75 | opt, train_transform_2, 76 | target_transform, common_transform_1,opt.batchSize) 77 | icvl_64_31_TL_3 = make_dataset( 78 | opt, train_transform_3, 79 | target_transform, common_transform_1, opt.batchSize) 80 | icvl_64_31_TL_4 = make_dataset( 81 | opt, train_transform_4, 82 | target_transform, common_transform_2, opt.batchSize*4) 83 | icvl_64_31_TL_5 = make_dataset( 84 | opt, train_transform_4, 85 | target_transform, common_transform_1, opt.batchSize) 86 | ''' 87 | icvl_64_31_TL_2 = make_dataset( 88 | opt, train_transform_2, 89 | target_transform, common_transform_2, 64) 90 | ''' 91 | """Test-Dev""" 92 | 93 | ###modified### 94 | basefolder = opt.testroot 95 | mat_names = ['icvl_dynamic_512_15','icvl_dynamic_512_55','icvl_dynamic_512_95'] 96 | #mat_names = ['icvl_512_30', 'icvl_512_50'] 97 | mat_loaders = [] 98 | for noise in (15,55,95): 99 | # test_path = os.path.join(basefolder, str(noise)+'dB/') 100 | test_path = os.path.join(basefolder, str(noise)+'/') 101 | #print('noise: ',noise,end='') 102 | mat_loaders.append(dataloaders_hsi_test.get_dataloaders([test_path],verbose=True,grey=False)['test']) 103 | ###modified### 104 | 105 | #print(icvl_64_31_TL_0.__len__()) 106 | if icvl_64_31_TL_0.__len__()*opt.batchSize > 2000: 107 | max_epoch = 75 108 | if_100 = 1 109 | if_eval =1 110 | epoch_per_save = 1 111 | testsize = 10 112 | else: 113 | max_epoch = 150 114 | if_100 = 0 115 | if_eval = 1 116 | epoch_per_save = 10 117 | 
testsize = 5 118 | print(max_epoch) 119 | """Main loop""" 120 | base_lr = opt.lr 121 | if_val_any = 1 122 | if opt.resetepoch != -1: 123 | engine.epoch = opt.resetepoch 124 | while engine.epoch < max_epoch: 125 | if if_100: 126 | epoch = engine.epoch * 2 127 | else: 128 | epoch = engine.epoch 129 | display_learning_rate(engine.optimizer) 130 | np.random.seed() # reset seed per epoch, otherwise the noise will be added with a specific pattern 131 | if epoch == 0: 132 | adjust_learning_rate(engine.optimizer, opt.lr) 133 | elif epoch == 10: 134 | adjust_learning_rate(engine.optimizer, base_lr*0.1) 135 | elif epoch == 20: 136 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 137 | elif epoch % 30 == 0 and epoch >29 and epoch < 150: 138 | adjust_learning_rate(engine.optimizer, base_lr*0.1) 139 | elif epoch % 30 == 14 and epoch >29 and epoch < 150: 140 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 141 | elif epoch == 150: 142 | adjust_learning_rate(engine.optimizer, base_lr*0.001) 143 | ''' 144 | elif engine.epoch % 30 == 1 and engine.epoch != 1: 145 | adjust_learning_rate(engine.optimizer, base_lr) 146 | 147 | elif engine.epoch % 30 == 0 and engine.epoch != 0: 148 | adjust_learning_rate(engine.optimizer, base_lr*0.01) 149 | ''' 150 | #print(if_100) 151 | if epoch < 30: 152 | #engine.validate(mat_loaders[0], 'icvl-validate-15') 153 | print("Training with unbindwise noise 50dB") 154 | engine.train(icvl_64_31_TL_0) 155 | if if_val_any: 156 | engine.validate(mat_loaders[0], 'icvl-validate-15',testsize)###modified### 157 | engine.validate(mat_loaders[1], 'icvl-validate-55',testsize)###modified### 158 | engine.validate(mat_loaders[2], 'icvl-validate-95',testsize)###modified### 159 | 160 | #engine.validate(mat_loaders[1], 'icvl-validate-50') 161 | elif epoch < 60: 162 | print("Training with 15dB") 163 | engine.train(icvl_64_31_TL_1) 164 | if if_val_any: 165 | engine.validate(mat_loaders[0], 'icvl-validate-15',testsize)###modified### 166 | 
engine.validate(mat_loaders[1], 'icvl-validate-55',testsize)###modified### 167 | engine.validate(mat_loaders[2], 'icvl-validate-95',testsize)###modified### 168 | #engine.validate(mat_loaders[0], 'icvl-validate-15') 169 | #engine.validate(mat_loaders[0], 'icvl-validate-30') 170 | #engine.validate(mat_loaders[1], 'icvl-validate-50') 171 | elif epoch < 90: 172 | print("Training with 55dB") 173 | engine.train(icvl_64_31_TL_2) 174 | if if_val_any: 175 | engine.validate(mat_loaders[0], 'icvl-validate-15',testsize)###modified### 176 | engine.validate(mat_loaders[1], 'icvl-validate-55',testsize)###modified### 177 | engine.validate(mat_loaders[2], 'icvl-validate-95',testsize)###modified### 178 | elif epoch < 120: 179 | print("Training with 95dB") 180 | engine.train(icvl_64_31_TL_3) 181 | if if_val_any: 182 | engine.validate(mat_loaders[0], 'icvl-validate-15',testsize)###modified### 183 | engine.validate(mat_loaders[1], 'icvl-validate-55',testsize)###modified### 184 | engine.validate(mat_loaders[2], 'icvl-validate-95',testsize)###modified### 185 | else: 186 | print("Training with random noise") 187 | engine.train(icvl_64_31_TL_5) 188 | if engine.epoch == max_epoch and engine.epoch == 75: 189 | testsize = 50 190 | MSIQAs = [] 191 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-15',folder='nssnn_niid')) 192 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-55',folder='nssnn_niid')) 193 | MSIQAs.append(engine.validate_MSIQA(mat_loaders[0], 'icvl-validate-95',folder='nssnn_niid')) 194 | print(" PSNR SSIM SAM") 195 | print("15dB: %.4f %.4f %.4f"%( MSIQAs[0][0],MSIQAs[0][1],MSIQAs[0][2])) 196 | print("55dB: %.4f %.4f %.4f"%( MSIQAs[1][0],MSIQAs[1][1],MSIQAs[1][2])) 197 | print("95dB: %.4f %.4f %.4f"%( MSIQAs[2][0],MSIQAs[2][1],MSIQAs[2][2])) 198 | else: 199 | if if_val_any: 200 | engine.validate(mat_loaders[0], 'icvl-validate-15',testsize)###modified### 201 | engine.validate(mat_loaders[1], 'icvl-validate-55',testsize)###modified### 202 | 
engine.validate(mat_loaders[2], 'icvl-validate-95',testsize)###modified### 203 | 204 | print('\nLatest Result Saving...') 205 | model_latest_path = os.path.join(engine.basedir, engine.prefix, 'model_latest.pth') 206 | engine.save_checkpoint( 207 | model_out_path=model_latest_path 208 | ) 209 | 210 | 211 | if engine.epoch % epoch_per_save == 0:###modified### 212 | engine.save_checkpoint() 213 | -------------------------------------------------------------------------------- /hsi_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import os 7 | import argparse 8 | from utility import dataloaders_hsi_test 9 | from utility import * 10 | from hsi_setup import Engine, train_options 11 | import models 12 | from indexes import MSIQA 13 | 14 | model_names = sorted(name for name in models.__dict__ 15 | if name.islower() and not name.startswith("__") 16 | and callable(models.__dict__[name])) 17 | 18 | 19 | prefix = 'test' 20 | 21 | if __name__ == '__main__': 22 | """Training settings""" 23 | parser = argparse.ArgumentParser( 24 | description='Hyperspectral Image Denoising') 25 | opt = train_options(parser) 26 | print(opt) 27 | 28 | cuda = not opt.no_cuda 29 | opt.no_log = True 30 | 31 | """Setup Engine""" 32 | engine = Engine(opt) 33 | ###modified### 34 | MSIQAs = [] 35 | basefolder = opt.testroot 36 | psnrs = [] 37 | test_path = os.path.join(basefolder) 38 | test = dataloaders_hsi_test.get_dataloaders([test_path],verbose=True,grey=False) 39 | MSIQAs.append(engine.validate_MSIQA(test['test'],folder=opt.output_fold,name=opt.output_file_name)) 40 | # res_arr, input_arr = engine.test_develop(mat_loader, savedir=resdir, verbose=True) 41 | # print(res_arr.mean(axis=0)) 42 | # print(opt.output_file_name,opt.output_fold) 43 | for MSIQA in MSIQAs: 44 | for index in MSIQA: 45 | print("%.4f"%(index)) 
-------------------------------------------------------------------------------- /hsi_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import os 7 | import argparse 8 | 9 | from utility import * 10 | from hsi_setup import Engine, train_options 11 | import models 12 | 13 | 14 | model_names = sorted(name for name in models.__dict__ 15 | if name.islower() and not name.startswith("__") 16 | and callable(models.__dict__[name])) 17 | 18 | 19 | prefix = 'test' 20 | 21 | if __name__ == '__main__': 22 | """Training settings""" 23 | parser = argparse.ArgumentParser( 24 | description='Hyperspectral Image Denoising') 25 | opt = train_options(parser) 26 | print(opt) 27 | 28 | cuda = not opt.no_cuda 29 | opt.no_log = True 30 | 31 | """Setup Engine""" 32 | engine = Engine(opt) 33 | 34 | mat_dataset = MatDataFromFolder('/home/fugym/HDD/fugym/QRNN3D/matlab/Data/icvl_dynamic_512_15') 35 | 36 | mat_transform = Compose([ 37 | LoadMatKey(key='img'), # for testing 38 | lambda x: x[:,:220,:256][None], 39 | minmax_normalize, 40 | ]) 41 | 42 | mat_dataset = TransformDataset(mat_dataset, mat_transform) 43 | 44 | mat_loader = DataLoader( 45 | mat_dataset, 46 | batch_size=1, shuffle=False, 47 | num_workers=1, pin_memory=cuda 48 | ) 49 | 50 | # print(engine.net) 51 | 52 | engine.test_real(mat_loader, savedir=None) 53 | -------------------------------------------------------------------------------- /indexes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage.measure import compare_ssim, compare_psnr 4 | from functools import partial 5 | from utility.gauss import fspecial_gauss 6 | from scipy import signal 7 | 8 | class Bandwise(object): 9 | def __init__(self, index_fn): 10 | self.index_fn = index_fn 11 | 12 | def 
__call__(self, X, Y): 13 | C = X.shape[-3] 14 | bwindex = [] 15 | for ch in range(C): 16 | x = torch.squeeze(X[...,ch,:,:].data).cpu().numpy() 17 | y = torch.squeeze(Y[...,ch,:,:].data).cpu().numpy() 18 | index = self.index_fn(x, y) 19 | bwindex.append(index) 20 | return bwindex 21 | 22 | def ssim(img1, img2, cs_map=False): 23 | """Return the Structural Similarity Map corresponding to input images img1 24 | and img2 (images are assumed to be uint8) 25 | 26 | This function attempts to mimic precisely the functionality of ssim.m a 27 | MATLAB provided by the author's of SSIM 28 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m 29 | """ 30 | img1 = img1.astype(np.float64) 31 | img2 = img2.astype(np.float64) 32 | size = 11 33 | sigma = 1.5 34 | window = fspecial_gauss(size, sigma) 35 | K1 = 0.01 36 | K2 = 0.03 37 | L = 255 # bitdepth of image 38 | C1 = (K1 * L) ** 2 39 | C2 = (K2 * L) ** 2 40 | mu1 = signal.fftconvolve(window, img1, mode='valid') 41 | mu2 = signal.fftconvolve(window, img2, mode='valid') 42 | mu1_sq = mu1 * mu1 43 | mu2_sq = mu2 * mu2 44 | mu1_mu2 = mu1 * mu2 45 | sigma1_sq = signal.fftconvolve(window, img1 * img1, mode='valid') - mu1_sq 46 | sigma2_sq = signal.fftconvolve(window, img2 * img2, mode='valid') - mu2_sq 47 | sigma12 = signal.fftconvolve(window, img1 * img2, mode='valid') - mu1_mu2 48 | if cs_map: 49 | return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 50 | (sigma1_sq + sigma2_sq + C2)), 51 | (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) 52 | else: 53 | return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 54 | (sigma1_sq + sigma2_sq + C2)) 55 | def mse (GT,P): 56 | """calculates mean squared error (mse). 57 | 58 | :param GT: first (original) input image. 59 | :param P: second (deformed) input image. 60 | 61 | :returns: float -- mse value. 
62 | """ 63 | # GT,P = _initial_check(GT,P) 64 | return np.mean((GT.astype(np.float32)-P.astype(np.float32))**2) 65 | cal_bwssim = Bandwise(compare_ssim) 66 | cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=1)) 67 | 68 | 69 | def cal_sam(X, Y, eps=1e-8): 70 | X = torch.squeeze(X.data).cpu().numpy() 71 | Y = torch.squeeze(Y.data).cpu().numpy() 72 | tmp = (np.sum(X*Y, axis=0) + eps) /( (np.sqrt(np.sum(X**2, axis=0)))* (np.sqrt(np.sum(Y**2, axis=0))) + eps) 73 | return np.mean(np.real(np.arccos(tmp))) 74 | 75 | def cal_ssim(im_true,im_test,eps=13-8): 76 | # print(im_true.shape) 77 | im_true=im_true.squeeze(0).squeeze(0).cpu().numpy() 78 | im_test = im_test.squeeze(0).squeeze(0).cpu().numpy() 79 | c,_,_=im_true.shape 80 | bwindex = [] 81 | for i in range(c): 82 | bwindex.append(ssim(im_true[i,:,:]*255, im_test[i,:,:,]*255)) 83 | return np.mean(bwindex) 84 | def MSIQA(X, Y): 85 | 86 | psnr = np.mean(cal_bwpsnr(X, Y)) 87 | ssim = cal_ssim(Y,X) 88 | sam = cal_sam(X, Y) 89 | return psnr, ssim, sam 90 | -------------------------------------------------------------------------------- /matlab/Data/_meta_complex.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/matlab/Data/_meta_complex.mat -------------------------------------------------------------------------------- /matlab/Data/_meta_complex_2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/matlab/Data/_meta_complex_2.mat -------------------------------------------------------------------------------- /matlab/Data/_meta_gauss.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/matlab/Data/_meta_gauss.mat 
-------------------------------------------------------------------------------- /matlab/Data/_meta_gauss_2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/matlab/Data/_meta_gauss_2.mat -------------------------------------------------------------------------------- /matlab/HSIData.m: -------------------------------------------------------------------------------- 1 | %% generate dataset 2 | rng(0) 3 | addpath(genpath('lib')); 4 | 5 | basedir = 'Data'; 6 | sz = 512; 7 | preprocess = @(x)(center_crop(rot90(x), sz, sz)); 8 | % preprocess = @(x)(center_crop(normalized(x), 340, 340)); % Pavia 9 | 10 | %%% for Pavia 11 | % datadir = '/media/kaixuan/DATA/Papers/Code/Matlab/ImageRestoration/ITSReg/code of ITSReg MSI denoising/data/real/new'; 12 | % % datadir = '/media/kaixuan/DATA/Papers/Code/QRNN3D/CNN_HSIC_MRF/Data'; 13 | % newdir = fullfile(basedir, ['Pavia_mixture_full']); 14 | % fns = {'PaviaU.mat'}; 15 | % sigmas = [10, 30, 50, 70]; 16 | % generate_dataset_mixture(datadir, fns, newdir, sigmas, 'hsi', preprocess); 17 | 18 | 19 | % %%% for iid gaussian 20 | datadir = '/media/kaixuan/DATA/Papers/Code/Data/ICVL/Testing'; 21 | g1 = load(fullfile(basedir, '_meta_gauss.mat')); % load fns 22 | g2 = load(fullfile(basedir, '_meta_gauss_2.mat')); 23 | fns = [g1.fns;g2.fns]; 24 | fns = fns(1:1); 25 | 26 | % for sigma = [30,50,70] 27 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', num2str(sigma)]); 28 | % generate_dataset(datadir, fns, newdir, sigma, 'rad', preprocess); 29 | % end 30 | 31 | newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', 'blind']); 32 | generate_dataset_blind(datadir, fns, newdir, 'rad', preprocess); 33 | 34 | % %%% for non-iid gaussian 35 | % g1 = load(fullfile(basedir, '_meta_complex.mat')); % load fns 36 | % g2 = load(fullfile(basedir, '_meta_complex_2.mat')); 37 | % fns = [g1.fns;g2.fns]; 38 | 39 | % datadir = 
'/media/kaixuan/DATA/Papers/Code/Data/ICVL/Testing'; 40 | % sigmas = [95]; 41 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', 'noniid']); 42 | % generate_dataset_noniid(datadir, fns, newdir, sigmas, 'rad', preprocess); 43 | % %%% for non-iid gaussian + stripe 44 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', 'stripe']); 45 | % generate_dataset_stripe(datadir, fns, newdir, sigmas, 'rad', preprocess); 46 | % %%% for non-iid gaussian + deadline 47 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', 'deadline']); 48 | % generate_dataset_deadline(datadir, fns, newdir, sigmas, 'rad', preprocess); 49 | % %%% for non-iid gaussian + impulse 50 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_', 'impulse']); 51 | % generate_dataset_impulse(datadir, fns, newdir, sigmas, 'rad', preprocess); 52 | % %%% for mixture noise 53 | % newdir = fullfile(basedir, ['icvl_', num2str(sz), '_','mixture']); 54 | % generate_dataset_mixture(datadir, fns, newdir, sigmas, 'rad', preprocess); 55 | -------------------------------------------------------------------------------- /matlab/HSIEval.m: -------------------------------------------------------------------------------- 1 | %% Evaluate the result based on Result w.r.t.
5 PQIs 2 | basedir = 'Data'; 3 | method = 'QRNN3D'; 4 | % ds_names = {'icvl_512_10', 'icvl_512_30', 'icvl_512_50', 'icvl_512_70', 'icvl_512_blind'}; 5 | ds_names = {'icvl_512_noniid', 'icvl_512_stripe', 'icvl_512_deadline', 'icvl_512_impulse', 'icvl_512_mixture'}; 6 | 7 | % g1 = load(fullfile(basedir, '_meta_gauss.mat')); % load fns 8 | % g2 = load(fullfile(basedir, '_meta_gauss_2.mat')); 9 | g1 = load(fullfile(basedir, '_meta_complex.mat')); % load fns 10 | g2 = load(fullfile(basedir, '_meta_complex_2.mat')); % load fns 11 | fns = [g1.fns;g2.fns]; 12 | 13 | % methodname = {'None','LRMR','LRTV','NMoG','LRTDTV', 'HSID-CNN', 'MemNet', 'QRNN3D'}; 14 | extra_methods = {'HSID-CNN', 'MemNet', 'QRNN3D'}; 15 | 16 | %% 17 | % fns = {'PaviaU.mat'}; 18 | % dataset_name = 'Pavia_mixture_full'; 19 | % datadir = fullfile(basedir, dataset_name); 20 | % resdir = fullfile('Result', dataset_name); 21 | % extra_res = ECCV_eval(datadir, resdir, fns, method); 22 | % load(fullfile(resdir, 'res_arr_final')); % load res_arr 23 | % res_arr(end+1, :,1:5) = extra_res; 24 | % save(fullfile(resdir, 'res_arr_final2'), 'res_arr'); 25 | % clear res_arr 26 | 27 | % methodname = {'None','LRMR','LRTV','NMoG','LRTDTV', 'HSID-CNN', 'QRNN3D-P', 'QRNN3D-F'}; 28 | %% 29 | 30 | % res_arr = zeros(length(methodname), 1, 5); 31 | % for i = 1:length(methodname) 32 | % method = methodname{i}; 33 | % res = ECCV_eval(datadir, resdir, fns, method); 34 | % res_arr(i,:,:) = res; 35 | % end 36 | % 37 | % save(fullfile(resdir, 'res_arr'), 'res_arr'); 38 | 39 | for i = 1:numel(ds_names) 40 | dataset_name = ds_names{i}; 41 | datadir = fullfile(basedir, dataset_name); 42 | resdir = fullfile('Result', dataset_name); 43 | load(fullfile(resdir, 'res_arr')); % load res_arr 44 | for k=1:length(extra_methods) 45 | method = extra_methods{k}; 46 | extra_res = HSI_eval(datadir, resdir, fns, method); 47 | 48 | res_arr(end+1, :, 1:5) = extra_res; % HSI_eval returns num_data x 5 metrics, while res_arr saved by HSI_test has a 6th (time) 
column 49 | save(fullfile(resdir, 'res_arr_new'), 'res_arr'); 50 | end 51 | end
-------------------------------------------------------------------------------- /matlab/HSI_eval.m: -------------------------------------------------------------------------------- 1 | function [ res_arr ] = HSI_eval( datadir, resdir, fns, method ) 2 | num_data = length(fns); 3 | res_arr = zeros(num_data, 5); 4 | for k = 1:num_data 5 | fn = fns{k}; 6 | disp(['evaluate ' method ' in pos (' num2str(k) ')' ]); 7 | [~, imgname] = fileparts(fn); 8 | filepath = fullfile(datadir, fn); 9 | mat = load(filepath); % contain (input, gt, sigma) 10 | hsi = mat.gt; 11 | imgdir = fullfile(resdir, imgname); 12 | load(fullfile(imgdir, method)); % load R_hsi 13 | [psnr, ssim, fsim, ergas, sam] = MSIQA(hsi*255, R_hsi*255); 14 | fprintf('psnr: %.3f\n', psnr); 15 | fprintf('ssim: %.3f\n', ssim); 16 | res_arr(k, :) = [psnr, ssim, fsim, ergas, sam]; 17 | end 18 | 19 | end 20 | 21 | -------------------------------------------------------------------------------- /matlab/HSI_test.m: -------------------------------------------------------------------------------- 1 | function [ res_arr, err ] = HSI_test( datadir, resdir, fns, methodname ) 2 | num_method = length(methodname); 3 | num_data = length(fns); 4 | res_fp = fullfile(resdir, 'res_arr.mat'); 5 | if ~exist(resdir, 'dir') 6 | mkdir(resdir); 7 | end 8 | if ~exist(res_fp, 'file') 9 | disp('init res_arr ...'); 10 | res_arr = zeros(num_method, num_data, 6); % result holds a table with columns 'psnr', 'ssim', 'fsim', 'ergas', 'sam' and 'time' 11 | else 12 | disp('load res_arr ...'); 13 | load(res_fp) 14 | if size(res_arr, 1) ~= num_method 15 | res_arr(num_method,1,1) = 0; 16 | end 17 | if size(res_arr, 2) ~= num_data 18 | res_arr(1,num_data,1) = 0; 19 | end 20 | end 21 | 22 | err = {}; 23 | 24 | for k = 1:num_data 25 | for m = intersect([1,2,3,4,6,7,8,9], 1:num_method) % drop selected indices beyond length(methodname), which would otherwise fail in methodname{m} 26 | % for m = 1:num_method 27 | % for m = [2,3,4,5,6,10,13,15] 28 | % for m = [7,9] 29 | % for m = [1,2,3,4,5,6,7,9,10,11,12,13,14,15] 30 | eval_method(m,k); 31 | end 32 | end 33 | 34 | function [ psnr,
ssim, fsim, ergas, sam ] = eval_method( m, k ) 35 | fn = fns{k}; 36 | method = methodname{m}; 37 | if abs(res_arr(m,k,1)) > 1e-5 % result reuse 38 | disp(['reuse precomputed result in pos (' num2str([m k]) ')']); 39 | fprintf('psnr: %.3f\n', res_arr(m,k,1)); 40 | fprintf('ssim: %.3f\n', res_arr(m,k,2)); 41 | else 42 | disp(['perform ' method ' in pos (' num2str([m k]) ')' ]); 43 | filepath = fullfile(datadir, fn); 44 | [~, imgname] = fileparts(fn); 45 | 46 | mat = load(filepath); % contain (input, gt, sigma) 47 | hsi = mat.gt; 48 | noisy_hsi = mat.input; 49 | 50 | if ~isfield(mat, 'sigma') 51 | sigma_ratio = NoiseLevel(noisy_hsi); 52 | else 53 | sigma_ratio = mat.sigma / 255; 54 | end 55 | 56 | imgdir = fullfile(resdir, imgname); 57 | if ~exist(imgdir, 'dir') 58 | mkdir(imgdir); 59 | end 60 | 61 | try 62 | [R_hsi, time] = demo_fun(noisy_hsi, sigma_ratio, method); 63 | save(fullfile(imgdir, method), 'R_hsi'); 64 | [psnr, ssim, fsim, ergas, sam] = MSIQA(hsi*255, R_hsi*255); 65 | fprintf('psnr: %.3f\n', psnr); 66 | fprintf('ssim: %.3f\n', ssim); 67 | res_arr(m, k, :) = [psnr, ssim, fsim, ergas, sam, time]; 68 | save(res_fp, 'res_arr'); 69 | catch Error 70 | disp(['error occured in ' num2str([m k])]); % num2str: concatenating raw doubles onto a char array yields char codes, not digits 71 | disp(Error); 72 | err{end+1} = [m k]; 73 | end 74 | end 75 | end 76 | end 77 | 78 | -------------------------------------------------------------------------------- /matlab/HSI_visualize.m: -------------------------------------------------------------------------------- 1 | function [vis_img, R_hsi] = HSI_visualize( datadir, resdir, visdir, fn, method, band, leftup, hw, pos, amp_f, translate ) 2 | %HSI_VISUALIZE visualize hsi (groundtruth when strcmp(method,'gt')) 3 | if ~exist('amp_f', 'var') 4 | amp_f = ones(size(pos))*6; 5 | end 6 | if ~exist('translate', 'var') 7 | translate = method; 8 | end 9 | if ~exist(visdir, 'dir') 10 | mkdir(visdir); 11 | end 12 | disp(['visualize ' method]); 13 | filepath = fullfile(datadir, fn); 14 | mat = load(filepath); % contain (input, gt,
sigma) 15 | if isfield(mat, 'gt') 16 | hsi = mat.gt; 17 | else 18 | hsi = mat.hsi; 19 | end 20 | % hsi = hsi * 0.6; 21 | 22 | [~, imgname] = fileparts(fn); 23 | imgdir = fullfile(resdir, imgname); 24 | savedir = fullfile(visdir, imgname); 25 | savepath = fullfile(savedir, [translate '.png']); 26 | if ~exist(savedir, 'dir') 27 | mkdir(savedir); 28 | end 29 | 30 | if ~strcmp(method, 'gt') 31 | load(fullfile(imgdir, method)); % load R_hsi 32 | else 33 | R_hsi = hsi; 34 | end 35 | 36 | % spectrum = R_hsi(60,60,:); 37 | % plot(spectrum(:), 'DisplayName',method); 38 | % hold on 39 | % legend('show') 40 | % return 41 | 42 | img = hsi(:,:,band); 43 | % img = R_hsi(:,:,band); % for real 44 | maxI = max(img(:)); 45 | minI = min(img(:)); 46 | % disp([maxI, minI]); 47 | 48 | % rightbottom = leftup + hw; 49 | vis_img = (R_hsi(:,:,band)-minI)/(maxI-minI); 50 | 51 | % vis_img = imadjust(vis_img); 52 | [y, x] = size(vis_img); 53 | yn = y + 48; xn = x + 48; 54 | new_img = ones(yn, xn); 55 | 56 | startx = max(floor(xn/2)-floor(x/2), 1); 57 | starty = floor(yn/2)-floor(y/2); 58 | new_img(starty:starty+y-1,startx:startx+x-1) = vis_img; 59 | 60 | if length(pos) == 1 61 | new_img = vis_img; 62 | end 63 | 64 | new_img = ShowEnlargedRectangle(new_img, leftup{1}, leftup{1}+hw{1}, amp_f(1), pos(1), 2, [255,0,0]); 65 | % leftup = [166, 266]; % k = 32 66 | % leftup = [150, 320]; % k = 6 67 | % leftup = [55 45]; % for urban 68 | if length(pos) > 1 69 | new_img = ShowEnlargedRectangle(new_img, leftup{2}, leftup{2}+hw{2}, amp_f(2), pos(2), 2, [0,255,0]); 70 | end 71 | figure 72 | imshow(new_img); 73 | % h = imagesc(R_hsi(:,:,band)-img, [-0.03 0.03]); 74 | % colormap jet 75 | if ~exist(savepath, 'file') 76 | imwrite(new_img, savepath); 77 | else 78 | disp([method ' has already existed']); 79 | end 80 | % colorbar('Ticks',[-0.05, -0.025, 0, 0.025, 0.05]) 81 | axis off 82 | % title(['\fontsize{25} ' translate]); 83 | % saveas(h, savepath); 84 | % export_fig(savepath) 85 | end 86 | 87 |
-------------------------------------------------------------------------------- /matlab/Main_Complex.m: -------------------------------------------------------------------------------- 1 | %========================================================================== 2 | % clc; 3 | clear; 4 | close all; 5 | addpath(genpath('lib')); 6 | 7 | %% Data init 8 | basedir = 'Data'; 9 | g1 = load(fullfile(basedir, '_meta_complex.mat')); % load fns 10 | g2 = load(fullfile(basedir, '_meta_complex_2.mat')); % load fns 11 | fns = [g1.fns;g2.fns]; 12 | methodname = {'None','BWBM3D','BM4D','LRMR','LRTV','NMoG','LRTDTV'}; 13 | 14 | 15 | %% 16 | 17 | ds_names = {'icvl_512_noniid', 'icvl_512_stripe', 'icvl_512_deadline', 'icvl_512_impulse', 'icvl_512_mixture'}; 18 | 19 | for i = 1:numel(ds_names) 20 | dataset_name = ds_names{i}; 21 | datadir = fullfile(basedir, dataset_name); 22 | resdir = fullfile('Result', dataset_name); 23 | HSI_test(datadir, resdir, fns, methodname); 24 | end 25 | 26 | -------------------------------------------------------------------------------- /matlab/Main_Gauss.m: -------------------------------------------------------------------------------- 1 | %========================================================================== 2 | % clc; 3 | clear; 4 | close all; 5 | addpath(genpath('lib')); 6 | 7 | %% Data init 8 | basedir = 'Data'; 9 | g1 = load(fullfile(basedir, '_meta_gauss.mat')); % load fns 10 | g2 = load(fullfile(basedir, '_meta_gauss_2.mat')); 11 | fns = [g1.fns;g2.fns]; 12 | 13 | methodname = {'None','BM4D','TDL', 'ITSReg','LLRT'}; 14 | 15 | if isempty(gcp('nocreate')) % plain gcp may auto-start a default pool, which would bypass the configured parpool below 16 | parpool(4,'IdleTimeout', inf); % If your computer's memory is less than 8G, do not use more than 4 workers.
17 | end 18 | 19 | ds_names = {'icvl_512_30', 'icvl_512_50', 'icvl_512_70', 'icvl_512_blind'}; 20 | 21 | for i = 4 22 | dataset_name = ds_names{i}; 23 | datadir = fullfile(basedir, dataset_name); 24 | resdir = fullfile('Result', dataset_name); 25 | HSI_test(datadir, resdir, fns, methodname); 26 | end 27 | -------------------------------------------------------------------------------- /matlab/Main_Real.m: -------------------------------------------------------------------------------- 1 | addpath(genpath('lib')); 2 | % datadir = fullfile('Data','Indian'); 3 | datadir = fullfile('Data','Urban'); 4 | % datadir = fullfile('Data','Harvard'); 5 | 6 | % resdir = fullfile('Result', 'Indian'); 7 | resdir = fullfile('Result', 'Urban'); 8 | % resdir = fullfile('Result', 'Harvard'); 9 | 10 | % fns = {'Indian_pines.mat'}; 11 | fns = {'Urban183.mat'}; 12 | % fns = {'img1.mat'}; 13 | methodname = {'None', 'BM4D','TDL', 'ITSReg', 'LLRT','LRMR','LRTV','NMoG','LRTDTV'}; 14 | 15 | for m = 1:length(methodname) 16 | method = methodname{m}; 17 | for k = 1:length(fns) 18 | fn = fns{k}; 19 | [~, imgname] = fileparts(fn); 20 | imgdir = fullfile(resdir, imgname); 21 | savepath = fullfile(imgdir, [methodname{m}, '.mat']); 22 | if ~exist(imgdir, 'dir') 23 | mkdir(imgdir); 24 | end 25 | if exist(savepath, 'file') 26 | disp(['reuse precomputed result in pos (' num2str([m k]) ')']); 27 | continue % skip only this file; break would also skip the remaining files for this method 28 | end 29 | load(fullfile(datadir, fn)) 30 | sigma_ratio = double(real(NoiseLevel(hsi))); 31 | disp(['perform ' method ' in pos (' num2str([m k]) ')' ]); 32 | R_hsi = demo_fun(hsi, sigma_ratio, methodname{m}); 33 | save(savepath, 'R_hsi'); 34 | end 35 | end 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /matlab/README.md: -------------------------------------------------------------------------------- 1 | # Hyperspectral Image Denoising - A Comprehensive Benchmark 2 | 3 | A **unified interface** for 
benchmarking HSI denoising algorithms on various
datasets under different noise settings. 4 | 5 | ## Quick Start 6 | 7 | * Download the library of HSI denoising algorithms from [OneDrive](https://1drv.ms/u/s!AqddfvhavTRii3kvriy6C14ub-SH?e=wdNnRf) and put the ```lib``` directory in the ```matlab``` folder. 8 | 9 | * Type ```addpath(genpath('lib'));``` in the MATLAB command line; you can then use any of the algorithms through the unified interface defined in ```demo_fun.m```. For instance, you can use ```BM4D``` to denoise a hyperspectral image via ```demo_fun(noisy_hsi, sigma_ratio, 'BM4D')```. 10 | 11 | * Benchmark your selected algorithm on the whole dataset with ```Main_Gauss/Complex/Real.m```; get the quantitative results via ```Result_Gauss/Complex.m```. 12 | 13 | ## Citation 14 | 15 | The BibTeX entries of the collected algorithms can be found in ```benchmarks.bib```. 16 | 17 | ## Acknowledgments 18 | * The code of the different HSI denoising algorithms was collected online. 19 | * Special thanks to the authors for making their source code publicly available. 20 | 21 | 22 | ## Contributions 23 | 24 | Adding your own algorithm is easy: simply put your source code in the ```lib``` directory, then write your algorithm's interface in ```demo_fun.m```.
25 | -------------------------------------------------------------------------------- /matlab/Result_Complex.m: -------------------------------------------------------------------------------- 1 | %% Result Analysis 2 | addpath(genpath('lib')); 3 | 4 | basedir = 'Data'; 5 | methodname = {'None','BWBM3D','BM4D','LRMR','LRTV','NMoG','LRTDTV','HSID-CNN', 'MemNet','QRNN3D'}; 6 | num_method = length(methodname); 7 | ds_names = {'icvl_512_noniid', 'icvl_512_stripe', 'icvl_512_deadline', 'icvl_512_impulse', 'icvl_512_mixture'}; 8 | titles = {'non-i.i.d.', 'stripe', 'deadline', 'impulse', 'mixture'}; 9 | ms = 1:num_method; % displayed method 10 | 11 | % columnLabels= methodname(ms); 12 | rowLabels = {'PSNR', 'SSIM', 'SAM'}; 13 | 14 | g1 = load(fullfile(basedir, '_meta_complex.mat')); % load fns 15 | g2 = load(fullfile(basedir, '_meta_complex_2.mat')); 16 | fns = [g1.fns;g2.fns]; 17 | 18 | %% 19 | 20 | for d = 1%[1,2,3,4,5] 21 | dataset_name = ds_names{d}; 22 | resdir = fullfile('Result', dataset_name); 23 | load(fullfile(resdir, 'res_arr_new')); % load res_arr 24 | res_table = zeros(length(ms), 5, 2); % (method, metric, stat): metrics psnr/ssim/fsim/ergas/sam, stat 1 = mean, 2 = std 25 | 26 | disp(['============= ' dataset_name ' =============']) 27 | for i = 1:length(ms) 28 | m = ms(i); 29 | disp(['============= ' methodname{m} ' =============']) 30 | psnr = nonzeros(res_arr(m,:,1)); % nonzeros drops zero entries (presumably images not evaluated yet) before averaging 31 | ssim = nonzeros(res_arr(m,:,2)); 32 | % fsim = nonzeros(res_arr(m,:,3)); 33 | % ergas = nonzeros(res_arr(m,:,4)); 34 | sam = nonzeros(res_arr(m,:,5)); 35 | 36 | res_table(i, 1, 1) = mean(psnr); 37 | res_table(i, 2, 1) = mean(ssim); 38 | % res_table(i, 3, 1) = mean(fsim); 39 | % res_table(i, 4, 1) = mean(ergas); 40 | res_table(i, 5, 1) = mean(sam); 41 | res_table(i, 1, 2) = std(psnr); 42 | res_table(i, 2, 2) = std(ssim); 43 | % res_table(i, 3, 2) = std(fsim); 44 | % res_table(i, 4, 2) = std(ergas); 45 | res_table(i, 5, 2) = std(sam); 46 | 47 | disp(['PSNR: ' num2str(mean(psnr)) '( ' num2str(std(psnr)) ' )']); 48 | disp(['SSIM: ' num2str(mean(ssim)) '( ' 
num2str(std(ssim)) ' )']); 49 | % disp(['FSIM: ' num2str(mean(fsim)) '( ' num2str(std(fsim)) ' )']); 50 | % disp(['ERGAS: ' num2str(mean(ergas)) '( ' num2str(std(ergas)) ' )']); 51 | disp(['SAM: ' num2str(mean(sam)) '( ' num2str(std(sam)) ' )']); 52 | end 53 | end 54 | -------------------------------------------------------------------------------- /matlab/Result_Gauss.m: -------------------------------------------------------------------------------- 1 | %% Result Analysis 2 | addpath(genpath('lib')); 3 | 4 | basedir = 'Data'; 5 | methodname = {'None','BM4D','TDL', 'ITSReg','LLRT', 'HSID-CNN','MemNet', 'QRNN3D'}; 6 | 7 | num_method = length(methodname); 8 | ds_names = {'icvl_512_30', 'icvl_512_50', 'icvl_512_70', 'icvl_512_blind'}; 9 | 10 | titles = {'30', '50', '70', 'blind'}; 11 | ms = 1:num_method; % displayed method 12 | 13 | 14 | columnLabels= methodname(ms); 15 | % rowLabels = {'PSNR', 'SSIM', 'FSIM', 'ERGAS', 'SAM'}; 16 | rowLabels = {'PSNR', 'SSIM', 'SAM'}; 17 | delete('table.tex'); % NOTE(review): nothing visible in this script writes table.tex afterwards — confirm this is still needed 18 | 19 | g1 = load(fullfile(basedir, '_meta_gauss.mat')); % load fns 20 | g2 = load(fullfile(basedir, '_meta_gauss_2.mat')); 21 | fns = [g1.fns;g2.fns]; 22 | %% 23 | 24 | for d = 4 25 | dataset_name = ds_names{d}; 26 | resdir = fullfile('Result', dataset_name); 27 | load(fullfile(resdir, 'res_arr_final')); % load res_arr 28 | res_table = zeros(length(ms), 5, 2); % (method, metric, stat): metrics psnr/ssim/fsim/ergas/sam, stat 1 = mean, 2 = std 29 | 30 | disp(['============= ' dataset_name ' =============']) 31 | for i = 1:length(ms) 32 | m = ms(i); 33 | disp(['============= ' methodname{m} ' =============']) 34 | psnr = nonzeros(res_arr(m,:,1)); 35 | ssim = nonzeros(res_arr(m,:,2)); 36 | fsim = nonzeros(res_arr(m,:,3)); 37 | ergas = nonzeros(res_arr(m,:,4)); 38 | sam = nonzeros(res_arr(m,:,5)); 39 | % time = nonzeros(res_arr(m,:,6)); 40 | 41 | res_table(i, 1, 1) = 
mean(psnr); 42 | res_table(i, 2, 1) = mean(ssim); 43 | res_table(i, 3, 1) = mean(fsim); 44 | res_table(i, 4, 1) = mean(ergas); 45 | res_table(i, 5, 1) = mean(sam); 46 | % res_table(i, 6, 1) =
mean(time); 47 | res_table(i, 1, 2) = std(psnr); 48 | res_table(i, 2, 2) = std(ssim); 49 | res_table(i, 3, 2) = std(fsim); 50 | res_table(i, 4, 2) = std(ergas); 51 | res_table(i, 5, 2) = std(sam); 52 | % res_table(i, 6, 1) = std(time); 53 | 54 | disp(['PSNR: ' num2str(mean(psnr)) '( ' num2str(std(psnr)) ' )']); 55 | disp(['SSIM: ' num2str(mean(ssim)) '( ' num2str(std(ssim)) ' )']); 56 | disp(['FSIM: ' num2str(mean(fsim)) '( ' num2str(std(fsim)) ' )']); 57 | disp(['ERGAS: ' num2str(mean(ergas)) '( ' num2str(std(ergas)) ' )']); 58 | disp(['SAM: ' num2str(mean(sam)) '( ' num2str(std(sam)) ' )']); 59 | % disp(['TIME: ' num2str(mean(time)) '( ' num2str(std(time)) ' )']); 60 | end 61 | end 62 | -------------------------------------------------------------------------------- /matlab/benchmarks.bib: -------------------------------------------------------------------------------- 1 | 2 | # BWBM3D 3 | @article{dabov2007BM3D, 4 | title={Image denoising by sparse 3-D transform-domain collaborative filtering}, 5 | author={Dabov, Kostadin and Foi, Alessandro and Katkovnik, Vladimir and Egiazarian, Karen}, 6 | journal={IEEE Transactions on Image Processing}, 7 | volume={16}, 8 | number={8}, 9 | pages={2080--2095}, 10 | year={2007}, 11 | publisher={IEEE} 12 | } 13 | 14 | # TDL 15 | @inproceedings{peng2014TDL, 16 | title={Decomposable nonlocal tensor dictionary learning for multispectral image denoising}, 17 | author={Peng, Yi and Meng, Deyu and Xu, Zongben and Gao, Chenqiang and Yang, Yi and Zhang, Biao}, 18 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 19 | pages={2949--2956}, 20 | year={2014} 21 | } 22 | 23 | # BM4D 24 | @article{maggioni2013nonlocal, 25 | title={Nonlocal transform-domain filter for volumetric data denoising and reconstruction}, 26 | author={Maggioni, Matteo and Katkovnik, Vladimir and Egiazarian, Karen and Foi, Alessandro}, 27 | journal={IEEE Transactions on Image Processing}, 28 | volume={22}, 29 |
number={1}, 30 | pages={119--133}, 31 | year={2013}, 32 | publisher={IEEE} 33 | } 34 | 35 | # ITSReg 36 | @inproceedings{xie2016multispectral, 37 | title={Multispectral images denoising by intrinsic tensor sparsity regularization}, 38 | author={Xie, Qi and Zhao, Qian and Meng, Deyu and Xu, Zongben and Gu, Shuhang and Zuo, Wangmeng and Zhang, Lei}, 39 | booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 40 | pages={1692--1700}, 41 | year={2016} 42 | } 43 | 44 | # BCTF-HSI 45 | @article{WEI2019412, 46 | title = "Low-rank Bayesian tensor factorization for hyperspectral image denoising", 47 | journal = "Neurocomputing", 48 | volume = "331", 49 | pages = "412 - 423", 50 | year = "2019", 51 | issn = "0925-2312", 52 | doi = "https://doi.org/10.1016/j.neucom.2018.10.023", 53 | url = "http://www.sciencedirect.com/science/article/pii/S0925231218312116", 54 | author = "Kaixuan Wei and Ying Fu", 55 | } 56 | 57 | # LLRT 58 | @inproceedings{chang2017hyper, 59 | title={Hyper-laplacian regularized unidirectional low-rank tensor recovery for multispectral image denoising}, 60 | author={Chang, Yi and Yan, Luxin and Zhong, Sheng}, 61 | booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 62 | pages={4260--4268}, 63 | year={2017} 64 | } 65 | 66 | # NMoG 67 | @article{chen2017denoising, 68 | title={Denoising Hyperspectral Image with Non-iid Noise Structure}, 69 | author={Chen, Yang and Cao, Xiangyong and Zhao, Qian and Meng, Deyu and Xu, Zongben}, 70 | journal={arXiv preprint arXiv:1702.00098}, 71 | year={2017} 72 | } 73 | 74 | # LRTV 75 | @article{he2016total, 76 | title={Total-variation-regularized low-rank matrix factorization for hyperspectral image restoration}, 77 | author={He, Wei and Zhang, Hongyan and Zhang, Liangpei and Shen, Huanfeng}, 78 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 79 | volume={54}, 80 | number={1}, 81 | pages={178--188}, 82 | year={2016}, 83 | publisher={IEEE} 84 | } 85 | 86 |
# LRMR 87 | @article{zhang2014hyperspectral, 88 | title={Hyperspectral image restoration using low-rank matrix recovery}, 89 | author={Zhang, Hongyan and He, Wei and Zhang, Liangpei and Shen, Huanfeng and Yuan, Qiangqiang}, 90 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 91 | volume={52}, 92 | number={8}, 93 | pages={4729--4743}, 94 | year={2014}, 95 | publisher={IEEE} 96 | } 97 | 98 | # LRTDTV 99 | @article{wang2017hyperspectral, 100 | title={Hyperspectral image restoration via total variation regularized low-rank tensor decomposition}, 101 | author={Wang, Yao and Peng, Jiangjun and Zhao, Qian and Leung, Yee and Zhao, Xi-Le and Meng, Deyu}, 102 | journal={IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing}, 103 | year={2017}, 104 | publisher={IEEE} 105 | } 106 | 107 | 108 | # PARAFAC 109 | @article{liu2012denoising, 110 | title={Denoising of hyperspectral images using the PARAFAC model and statistical performance analysis}, 111 | author={Liu, Xuefeng and Bourennane, Salah and Fossati, Caroline}, 112 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 113 | volume={50}, 114 | number={10}, 115 | pages={3717--3724}, 116 | year={2012}, 117 | publisher={IEEE} 118 | } 119 | 120 | 121 | # tSVD 122 | @InProceedings{Zhang_2014_CVPR, 123 | author = {Zhang, Zemin and Ely, Gregory and Aeron, Shuchin and Hao, Ning and Kilmer, Misha}, 124 | title = {Novel Methods for Multilinear Data Completion and De-noising Based on Tensor-SVD}, 125 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 126 | month = {June}, 127 | year = {2014} 128 | } 129 | 130 | # LLRGTV 131 | @article{he2018hyperspectral, 132 | title={Hyperspectral Image Denoising Using Local Low-Rank Matrix Recovery and Global Spatial--Spectral Total Variation}, 133 | author={He, Wei and Zhang, Hongyan and Shen, Huanfeng and Zhang, Liangpei}, 134 | journal={IEEE Journal of Selected Topics in Applied Earth Observations and Remote 
Sensing}, 135 | volume={11}, 136 | number={3}, 137 | pages={713--729}, 138 | year={2018}, 139 | publisher={IEEE} 140 | } 141 | 142 | 143 | # GLF 144 | @inproceedings{zhuang2017hyperspectral, 145 | title={Hyperspectral image denoising based on global and non-local low-rank factorizations}, 146 | author={Zhuang, Lina and Bioucas-Dias, Jos{\'e} M}, 147 | booktitle={Image Processing (ICIP), 2017 IEEE International Conference on}, 148 | pages={1900--1904}, 149 | year={2017}, 150 | organization={IEEE} 151 | } 152 | -------------------------------------------------------------------------------- /matlab/demo_fun.m: -------------------------------------------------------------------------------- 1 | function [ Re_hsi, time ] = demo_fun( noisy_hsi, sigma_ratio, methodname ) 2 | %DEMO_FUN 3 | hsi_sz = size(noisy_hsi); 4 | Re_hsi = noisy_hsi; 5 | tic; 6 | if strcmp(methodname, 'None') 7 | % Re_hsi(Re_hsi>1) = 1; 8 | % Re_hsi(Re_hsi<0) = 0; 9 | time = toc; 10 | return 11 | 12 | elseif strcmp(methodname, 'BWBM3D') 13 | if length(sigma_ratio) == 1 14 | sigma_ratio = repmat(sigma_ratio, hsi_sz(3), 1); 15 | end 16 | for ch = 1:hsi_sz(3) 17 | [~, Re_hsi(:, :, ch)] = BM3D(1, noisy_hsi(:, :, ch), sigma_ratio(ch)*255); 18 | end 19 | 20 | elseif strcmp(methodname, 'TDL') 21 | vstbmtf_params.peak_value = 1; 22 | vstbmtf_params.nsigma = mean(sigma_ratio); 23 | Re_hsi = TensorDL(noisy_hsi, vstbmtf_params); 24 | 25 | elseif strcmp(methodname, 'BM4D') 26 | if length(sigma_ratio) == 1 27 | [~, Re_hsi] = bm4d(1, noisy_hsi, sigma_ratio); 28 | else 29 | % [~, Re_hsi] = bm4d(1, noisy_hsi, 0); % enable automatical sigma estimation 30 | [~, Re_hsi] = bm4d(1, noisy_hsi, mean(sigma_ratio)); 31 | end 32 | 33 | elseif strcmp(methodname, 'ITSReg') 34 | Re_hsi = ITS_DeNoising(noisy_hsi,mean(sigma_ratio), 1); 35 | 36 | elseif strcmp(methodname, 'BCTF-HSI') 37 | Re_hsi = BCTF_DeNoising(noisy_hsi,mean(sigma_ratio)*255); 38 | 39 | elseif strcmp(methodname, 'BCTF') 40 | model = BCPF(noisy_hsi, 'init', 
'rand', 'maxRank', 30, 'dimRed', 1, 'tol', 1e-3, 'maxiters', 25, 'verbose', 0); 41 | Re_hsi = double(model.X); 42 | 43 | elseif strcmp(methodname, 'LLRT') 44 | Par = LLRT_ParSet(mean(sigma_ratio)*255); 45 | Re_hsi = LLRT_DeNoising(noisy_hsi*255, Par)/255 ; 46 | 47 | elseif strcmp(methodname, 'NMoG') 48 | if size(noisy_hsi, 3) > 100 49 | r = 5; % objective rank of low rank component 50 | param.initial_rank = 30; 51 | param.rankDeRate = 7; 52 | param.mog_k = 5; 53 | param.lr_init = 'SVD'; 54 | param.maxiter = 30; 55 | param.tol = 1e-4; 56 | param.display = 0; 57 | else 58 | r = 3; 59 | param.initial_rank = 30; % initial rank of low rank component 60 | param.rankDeRate = 7; % the number of rank reduced in each iteration 61 | param.mog_k = 3; % the number of component reduced in each band 62 | param.lr_init = 'SVD'; 63 | param.maxiter = 30; 64 | param.tol = 1e-3; 65 | param.display = 0; 66 | end 67 | Re_hsi = NMoG(noisy_hsi, r, param); 68 | 69 | elseif strcmp(methodname, 'LRTV') 70 | if size(noisy_hsi,3) > 100 71 | tau = 0.015; 72 | lambda = 20/sqrt(hsi_sz(1)*hsi_sz(2)); 73 | rank = 10; 74 | else 75 | % ICVL 76 | tau = 0.01; 77 | lambda = 10/sqrt(hsi_sz(1)*hsi_sz(2)); 78 | rank = 5; 79 | end 80 | Re_hsi = LRTV(noisy_hsi, tau, lambda, rank); 81 | 82 | elseif strcmp(methodname, 'LRMR') 83 | if size(noisy_hsi, 3) > 100 84 | r = 7; 85 | slide =20; 86 | s = 0.1; 87 | stepsize = 8; 88 | else 89 | % ICVL 90 | % r = 2; 91 | % slide = 30; 92 | % s = 0.05; 93 | % stepsize = 4; 94 | % CAVE 95 | r = 3; 96 | slide = 20; 97 | s = 0.00; 98 | stepsize = 4; 99 | end 100 | Re_hsi = LRMR_HSI_denoise( noisy_hsi,r,slide,s,stepsize ); 101 | 102 | elseif strcmp(methodname, 'LRTDTV') 103 | if size(noisy_hsi,3) > 100 104 | Re_hsi = LRTDTV(noisy_hsi, 1, 10, [ceil(0.8*hsi_sz(1)), ceil(0.8*hsi_sz(2)), 10]); 105 | else 106 | Re_hsi = LRTDTV(noisy_hsi, 1, 10, [ceil(0.1*hsi_sz(1)), ceil(0.1*hsi_sz(2)), 3]); 107 | end 108 | 109 | elseif strcmp(methodname, 'LRTA') 110 | Re_hsi = 
double(LRTA(tensor(noisy_hsi))); 111 | elseif strcmp(methodname, 'PARAFAC') 112 | if size(noisy_hsi,3) > 100 113 | Re_hsi = PARAFAC(tensor(double(noisy_hsi)), 2e-6, 2e-5); 114 | else 115 | Re_hsi = PARAFAC(tensor(double(noisy_hsi))); 116 | end 117 | elseif strcmp(methodname, 'tSVD') 118 | Re_hsi = tSVD_DeNoising(noisy_hsi,mean(sigma_ratio), 1); 119 | elseif strcmp(methodname, 'LLRGTV') 120 | if size(noisy_hsi,3) > 100 121 | par.lambda = 0.20; 122 | par.tau = 0.005; 123 | par.r = 2; 124 | par.blocksize = 20; 125 | par.stepsize = 10; 126 | par.maxIter = 50; 127 | par.tol = 1e-6; 128 | else 129 | par.lambda = 0.13; 130 | par.tau = 0.013; 131 | par.r = 2; 132 | par.blocksize = 20; 133 | par.stepsize = 17; 134 | par.maxIter = 50; 135 | par.tol = 1e-6; 136 | end 137 | Re_hsi = LLRGTV(noisy_hsi, par); 138 | elseif strcmp(methodname, 'FastHyDe') 139 | noise_type = 'additive'; 140 | p_subspace = 3; %Dimension of the subspace 141 | iid = 1; 142 | Re_hsi = FastHyDe(noisy_hsi, noise_type, iid, p_subspace); 143 | elseif strcmp(methodname, 'GLF') 144 | noise_type = 'additive'; 145 | p_subspace = 10; %Dimension of the subspace 146 | Re_hsi = GLF_denoiser(noisy_hsi, p_subspace, noise_type) ; 147 | else 148 | error('Error: no matched method'); 149 | end 150 | 151 | Re_hsi(Re_hsi>1) = 1; 152 | Re_hsi(Re_hsi<0) = 0; 153 | time = toc; 154 | end 155 | -------------------------------------------------------------------------------- /matlab/eval_dataset.m: -------------------------------------------------------------------------------- 1 | function [ PSNR, SSIM, FSIM, ERGAS ] = eval_dataset( datadir, fns, method, preprocess ) 2 | %EVAL_DATASET Summary of this function goes here 3 | 4 | for k = 1:length(fns) 5 | fn = fns{k}; 6 | disp(['eval ' fn]); 7 | filepath = fullfile(datadir, fn); 8 | mat = load(filepath); % contain (input, gt, sigma) 9 | hsi = mat.gt; 10 | noisy_hsi = mat.input; 11 | if exist('preprocess', 'var') 12 | hsi = preprocess(hsi); 13 | noisy_hsi = preprocess(noisy_hsi); 14 
| end 15 | if ~isfield(mat, 'sigma') 16 | sigma_ratio = NoiseLevel(noisy_hsi); 17 | else 18 | sigma_ratio = mat.sigma / 255; 19 | end 20 | 21 | R_hsi = demo_fun(noisy_hsi, sigma_ratio, method); 22 | [psnr, ssim, fsim, ergas] = MSIQA(hsi*255, R_hsi*255); 23 | fprintf('psnr: %.3f\n', psnr); 24 | fprintf('ssim: %.3f\n', ssim); 25 | PSNR(k) = psnr; 26 | SSIM(k) = ssim; 27 | FSIM(k) = fsim; 28 | ERGAS(k) = ergas; 29 | end 30 | disp(mean(PSNR)); 31 | end 32 | 33 | -------------------------------------------------------------------------------- /matlab/generate_dataset.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset( datadir, fns, newdir, sigma, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | k = 1; 4 | if ~exist(newdir, 'dir') 5 | mkdir(newdir) 6 | end 7 | 8 | for k = 1:length(fns) 9 | fn = fns{k}; 10 | fprintf('generate data(%d/%d)\n', k, length(fns)); 11 | filepath = fullfile(datadir, fn); 12 | mat = load(filepath); % contain gt_key 13 | gt = getfield(mat, gt_key); 14 | 15 | if exist('preprocess', 'var') 16 | gt = preprocess(gt); 17 | end 18 | 19 | gt = normalized(gt); 20 | 21 | s = reshape(sigma, 1, 1, length(sigma)); 22 | input = gt + s/255 .* randn(size(gt)); 23 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 24 | end 25 | end 26 | -------------------------------------------------------------------------------- /matlab/generate_dataset_blind.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_blind( datadir, fns, newdir, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | 4 | if ~exist(newdir, 'dir') 5 | mkdir(newdir) 6 | end 7 | 8 | for k = 1:length(fns) 9 | fn = fns{k}; 10 | fprintf('generate data(%d/%d)\n', k, length(fns)); 11 | filepath = fullfile(datadir, fn); 12 | mat = load(filepath); % contain gt_key 13 | gt = getfield(mat, gt_key); 14 | 15 | if 
exist('preprocess', 'var') 16 | gt = preprocess(gt); 17 | end 18 | 19 | gt = normalized(gt); 20 | sigma = rand(1,1) * 60 + 10; 21 | % sigma = rand(1,1) * 85 + 15; 22 | fprintf('sigma: %.2f\n', sigma); 23 | s = reshape(sigma, 1, 1, length(sigma)); 24 | input = gt + s/255 .* randn(size(gt)); 25 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /matlab/generate_dataset_complex.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_complex( noise_type, datadir, newdir, sigmas, gt_key) 2 | fn=dir([datadir '*.mat']); 3 | fns={fn.name}; 4 | preprocess = @(x)(x); 5 | %GENERATE_DATASET Summary of this function goes here 6 | if strcmp(noise_type, 'noniid') 7 | generate_dataset_noniid( datadir, fns, newdir, sigmas, gt_key, preprocess ) 8 | elseif strcmp(noise_type, 'stripe') 9 | generate_dataset_stripe( datadir, fns, newdir, sigmas, gt_key, preprocess ) 10 | elseif strcmp(noise_type, 'deadline') 11 | generate_dataset_deadline( datadir, fns, newdir, sigmas, gt_key, preprocess ) 12 | elseif strcmp(noise_type, 'impulse') 13 | generate_dataset_impulse( datadir, fns, newdir, sigmas, gt_key, preprocess ) 14 | elseif strcmp(noise_type, 'mixture') 15 | generate_dataset_mixture( datadir, fns, newdir, sigmas, gt_key, preprocess ) 16 | end 17 | end 18 | -------------------------------------------------------------------------------- /matlab/generate_dataset_complex_backup.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_complex_backup( noise_type, datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | if strcmp(noise_type, 'noniid') 4 | generate_dataset_noniid( datadir, fns, newdir, sigmas, gt_key, preprocess ) 5 | elseif strcmp(noise_type, 'stripe') 6 | generate_dataset_stripe( datadir, 
fns, newdir, sigmas, gt_key, preprocess ) 7 | elseif strcmp(noise_type, 'deadline') 8 | generate_dataset_deadline( datadir, fns, newdir, sigmas, gt_key, preprocess ) 9 | elseif strcmp(noise_type, 'impulse') 10 | generate_dataset_impulse( datadir, fns, newdir, sigmas, gt_key, preprocess ) 11 | elseif strcmp(noise_type, 'mixture') 12 | generate_dataset_mixture( datadir, fns, newdir, sigmas, gt_key, preprocess ) 13 | end 14 | end 15 | -------------------------------------------------------------------------------- /matlab/generate_dataset_deadline.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_deadline( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | min_amount = 0.05; 4 | max_amount = 0.15; 5 | if ~exist(newdir, 'dir') 6 | mkdir(newdir) 7 | end 8 | 9 | for k = 1:length(fns) 10 | fn = fns{k}; 11 | fprintf('generate data(%d/%d)\n', k, length(fns)); 12 | filepath = fullfile(datadir, fn); 13 | mat = load(filepath); % contain gt_key 14 | gt = getfield(mat, gt_key); 15 | 16 | if exist('preprocess', 'var') 17 | gt = preprocess(gt); 18 | end 19 | 20 | gt = normalized(gt); 21 | % sample sigma uniformly from sigmas 22 | idx = randi(length(sigmas), size(gt,3), 1); 23 | sigma = sigmas(idx); 24 | disp(sigma) 25 | 26 | s = reshape(sigma, 1, 1, length(sigma)); 27 | input = gt + s/255 .* randn(size(gt)); 28 | 29 | [~, N, B] = size(gt); 30 | band = randperm(B); 31 | band = band(1:10); 32 | 33 | deadlinenum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band), 1); 34 | disp(deadlinenum); 35 | for i=1:length(band) 36 | loc = randperm(N); 37 | loc = loc(1:deadlinenum(i)); 38 | input(:,loc,band(i)) = 0; 39 | end 40 | 41 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 42 | end 43 | end 44 | -------------------------------------------------------------------------------- /matlab/generate_dataset_impulse.m: 
-------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_impulse( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | ratios = [0.1, 0.3, 0.5, 0.7]; 4 | if ~exist(newdir, 'dir') 5 | mkdir(newdir) 6 | end 7 | 8 | for k = 1:length(fns) 9 | fn = fns{k}; 10 | fprintf('generate data(%d/%d)\n', k, length(fns)); 11 | filepath = fullfile(datadir, fn); 12 | mat = load(filepath); % contain gt_key 13 | gt = getfield(mat, gt_key); 14 | 15 | if exist('preprocess', 'var') 16 | gt = preprocess(gt); 17 | end 18 | 19 | gt = normalized(gt); 20 | % sample sigma uniformly from sigmas 21 | idx = randi(length(sigmas), size(gt,3), 1); 22 | sigma = sigmas(idx); 23 | disp(sigma) 24 | 25 | s = reshape(sigma, 1, 1, length(sigma)); 26 | input = gt + s/255 .* randn(size(gt)); 27 | 28 | [~, N, B] = size(gt); 29 | 30 | band = randperm(B); 31 | band = band(1:10); 32 | idx = randi(length(ratios), length(band), 1); 33 | ratio = ratios(idx); 34 | disp(ratio); 35 | 36 | for i=1:length(band) 37 | input(:,:,band(i)) = imnoise(input(:,:,band(i)),'salt & pepper',ratio(i)); 38 | end 39 | 40 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /matlab/generate_dataset_mixture.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_mixture( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | ratios = [0.1, 0.3, 0.5, 0.7]; 4 | min_amount = 0.05; 5 | max_amount = 0.15; 6 | if ~exist(newdir, 'dir') 7 | mkdir(newdir) 8 | end 9 | 10 | for k = 1:length(fns) 11 | fn = fns{k}; 12 | fprintf('generate data(%d/%d)\n', k, length(fns)); 13 | filepath = fullfile(datadir, fn); 14 | mat = load(filepath); % contain gt_key 15 | gt = getfield(mat, gt_key); 16 | 17 | if 
exist('preprocess', 'var') 18 | gt = preprocess(gt); 19 | end 20 | 21 | gt = normalized(gt); 22 | % sample sigma uniformly from sigmas 23 | idx = randi(length(sigmas), size(gt,3), 1); 24 | sigma = sigmas(idx); 25 | sigma = sigma(1); 26 | s = sigma*rand(1,size(gt,3)); 27 | s = reshape(s, 1, 1, length(s)); 28 | input = gt + s/255 .* randn(size(gt)); 29 | 30 | [~, N, B] = size(gt); 31 | 32 | % add stripe 33 | all_band = randperm(B); 34 | b = floor(B/3); 35 | band_stripe = all_band(1:b); 36 | 37 | stripnum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band_stripe), 1); 38 | fprintf('Stripes:\n'); 39 | disp(stripnum); 40 | for i=1:length(band_stripe) 41 | loc = randperm(N); 42 | loc = loc(1:stripnum(i)); 43 | stripe = rand(1,length(loc))*0.5-0.25; 44 | input(:,loc,band_stripe(i)) = input(:,loc,band_stripe(i)) - stripe; 45 | end 46 | 47 | % add deadline 48 | band_deadline = all_band(b+1:2*b); 49 | deadlinenum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band_deadline), 1); 50 | fprintf('Deadline:\n'); 51 | disp(deadlinenum); 52 | for i=1:length(band_deadline) 53 | loc = randperm(N); 54 | loc = loc(1:deadlinenum(i)); 55 | input(:,loc,band_deadline(i)) = 0; 56 | end 57 | 58 | % add impulse 59 | fprintf('impulse:\n'); 60 | band_impulse = all_band(2*b+1:3*b); 61 | idx = randi(length(ratios), length(band_impulse), 1); 62 | ratio = ratios(idx); 63 | disp(ratio); 64 | for i=1:length(band_impulse) 65 | input(:,:,band_impulse(i)) = imnoise(input(:,:,band_impulse(i)),'salt & pepper',ratio(i)); 66 | end 67 | DataCube = input; 68 | save(fullfile(newdir, fn), 'DataCube','s','band_stripe','stripnum','band_deadline','deadlinenum','band_impulse','ratio'); 69 | end 70 | end 71 | 72 | 73 | function gt=normalized(gt) 74 | gt=gt./max(gt(:)); 75 | end 76 | -------------------------------------------------------------------------------- /matlab/generate_dataset_mixture_backup.m: -------------------------------------------------------------------------------- 1 | 
function [ ] = generate_dataset_mixture( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | ratios = [0.1, 0.3, 0.5, 0.7]; 4 | min_amount = 0.05; 5 | max_amount = 0.15; 6 | if ~exist(newdir, 'dir') 7 | mkdir(newdir) 8 | end 9 | 10 | for k = 1:length(fns) 11 | fn = fns{k}; 12 | fprintf('generate data(%d/%d)\n', k, length(fns)); 13 | filepath = fullfile(datadir, fn); 14 | mat = load(filepath); % contain gt_key 15 | gt = getfield(mat, gt_key); 16 | 17 | if exist('preprocess', 'var') 18 | gt = preprocess(gt); 19 | end 20 | 21 | gt = normalized(gt); 22 | % sample sigma uniformly from sigmas 23 | idx = randi(length(sigmas), size(gt,3), 1); 24 | sigma = sigmas(idx); 25 | disp(sigma) 26 | 27 | s = reshape(sigma, 1, 1, length(sigma)); 28 | input = gt + s/255 .* randn(size(gt)); 29 | 30 | [~, N, B] = size(gt); 31 | 32 | % add stripe 33 | all_band = randperm(B); 34 | b = floor(B/3); 35 | band = all_band(1:b); 36 | 37 | stripnum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band), 1); 38 | fprintf('Stripes:\n'); 39 | disp(stripnum); 40 | for i=1:length(band) 41 | loc = randperm(N); 42 | loc = loc(1:stripnum(i)); 43 | stripe = rand(1,length(loc))*0.5-0.25; 44 | input(:,loc,band(i)) = input(:,loc,band(i)) - stripe; 45 | end 46 | 47 | % add deadline 48 | band = all_band(b+1:2*b); 49 | deadlinenum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band), 1); 50 | fprintf('Deadline:\n'); 51 | disp(deadlinenum); 52 | for i=1:length(band) 53 | loc = randperm(N); 54 | loc = loc(1:deadlinenum(i)); 55 | input(:,loc,band(i)) = 0; 56 | end 57 | 58 | % add impulse 59 | fprintf('impulse:\n'); 60 | band = all_band(2*b+1:3*b); 61 | idx = randi(length(ratios), length(band), 1); 62 | ratio = ratios(idx); 63 | disp(ratio); 64 | for i=1:length(band) 65 | input(:,:,band(i)) = imnoise(input(:,:,band(i)),'salt & pepper',ratio(i)); 66 | end 67 | 68 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 69 | end 70 | 
end 71 | 72 | 73 | function gt=normalized(gt) 74 | gt=gt./max(gt(:)); 75 | end 76 | -------------------------------------------------------------------------------- /matlab/generate_dataset_noniid.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_noniid( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | if ~exist(newdir, 'dir') 4 | mkdir(newdir) 5 | end 6 | 7 | for k = 1:length(fns) 8 | fn = fns{k}; 9 | fprintf('generate data(%d/%d)\n', k, length(fns)); 10 | filepath = fullfile(datadir, fn); 11 | mat = load(filepath); % contain gt_key 12 | gt = getfield(mat, gt_key); 13 | 14 | if exist('preprocess', 'var') 15 | gt = preprocess(gt); 16 | end 17 | 18 | gt = normalized(gt); 19 | % sample sigma uniformly from sigmas 20 | idx = randi(length(sigmas), size(gt,3), 1); 21 | sigma = sigmas(idx); 22 | disp(sigma) 23 | length(sigma) 24 | s = reshape(sigma, 1, 1, length(sigma)); 25 | input = gt + s/255 .* randn(size(gt)); 26 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /matlab/generate_dataset_stripe.m: -------------------------------------------------------------------------------- 1 | function [ ] = generate_dataset_stripe( datadir, fns, newdir, sigmas, gt_key, preprocess ) 2 | %GENERATE_DATASET Summary of this function goes here 3 | min_amount = 0.05; 4 | max_amount = 0.15; 5 | if ~exist(newdir, 'dir') 6 | mkdir(newdir) 7 | end 8 | 9 | for k = 1:length(fns) 10 | fn = fns{k}; 11 | fprintf('generate data(%d/%d)\n', k, length(fns)); 12 | filepath = fullfile(datadir, fn); 13 | mat = load(filepath); % contain gt_key 14 | gt = getfield(mat, gt_key); 15 | 16 | if exist('preprocess', 'var') 17 | gt = preprocess(gt); 18 | end 19 | 20 | gt = normalized(gt); 21 | % sample sigma uniformly from sigmas 22 | idx = randi(length(sigmas), size(gt,3), 
1); 23 | sigma = sigmas(idx); 24 | disp(sigma) 25 | 26 | s = reshape(sigma, 1, 1, length(sigma)); 27 | input = gt + s/255 .* randn(size(gt)); 28 | 29 | [~, N, B] = size(gt); 30 | band = randperm(B); 31 | band = band(1:10); 32 | 33 | stripnum = randi([ceil(min_amount * N), ceil(max_amount * N)], length(band), 1); 34 | disp(stripnum); 35 | for i=1:length(band) 36 | loc = randperm(N); 37 | loc = loc(1:stripnum(i)); 38 | stripe = rand(1,length(loc))*0.5-0.25; 39 | input(:,loc,band(i)) = input(:,loc,band(i)) - stripe; 40 | end 41 | save(fullfile(newdir, fn), 'gt', 'input', 'sigma'); 42 | end 43 | end 44 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .qrnn import REDC3D 2 | from .qrnn import QRNNREDC3D 3 | #attention 4 | from .sru.sru3d import SRUREDC3D 5 | from models.nssnn.nssnn import NSSNN 6 | 7 | """Define commonly used architecture""" 8 | 9 | def nssnn(): 10 | net = NSSNN(1,16,5,[1,3],has_ad=True,bn=False) 11 | net.use_2dconv = False 12 | net.bandwise = False 13 | return net 14 | 15 | def residualunet3d(): 16 | net = REDC3D(1, 16, 5, 3) 17 | net.use_2dconv = False 18 | net.bandwise = False 19 | return net 20 | 21 | def nssnn_7layers(): 22 | net = NSSNN(1, 16, 7, [1,3,5], has_ad=True,bn=False) 23 | net.use_2dconv = False 24 | net.bandwise = False 25 | return net 26 | 27 | def nssnn_3layers(): 28 | net = NSSNN(1, 16, 3, [1], has_ad=True,bn=False) 29 | net.use_2dconv = False 30 | net.bandwise = False 31 | return net 32 | 33 | def sru3d_nobn(): 34 | net = SRUREDC3D(1, 16, 5, [1,3], has_ad=True,bn=False) 35 | net.use_2dconv = False 36 | net.bandwise = False 37 | return net 38 | 39 | def qrnn3d(): 40 | net = QRNNREDC3D(1, 16, 5, [1,3], has_ad=True,bn=True) 41 | net.use_2dconv = False 42 | net.bandwise = False 43 | return net 44 | 45 | def qrnn3d_nobn(): 46 | net = QRNNREDC3D(1, 16, 5, [1,3], has_ad=True,bn=False) 47 | 
net.use_2dconv = False 48 | net.bandwise = False 49 | return net 50 | 51 | def qrnn2d(): 52 | net = QRNNREDC3D(1, 16, 5, [1,3], has_ad=True, is_2d=True) 53 | net.use_2dconv = False 54 | net.bandwise = False 55 | return net 56 | 57 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/combinations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/__pycache__/combinations.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/im2col.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/__pycache__/im2col.cpython-37.pyc -------------------------------------------------------------------------------- /models/im2col.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import torch 3 | from torch.nn.modules.utils import _pair 4 | import math 5 | 6 | 7 | def Im2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,): 8 | batch = input_tensor.shape[0] 9 | out = F.unfold(input_tensor, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 10 | 11 | if tensorized: 12 | lh,lw = im2col_shape(input_tensor.shape[1:],kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation)[-2:] 13 | out = out.view(batch,-1,lh,lw) 14 | return 
out 15 | def Cube2Col(input_tensor, kernel_size, stride, padding,dilation=1,tensorized=False,device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): 16 | input_sz=input_tensor.shape 17 | _t=(input_sz[1]-kernel_size[0])//stride[0]+1 18 | out=torch.zeros(input_sz[0],kernel_size[0]*kernel_size[1]*kernel_size[2],(input_sz[1]-kernel_size[0])//stride[0]+1,(input_sz[2]-kernel_size[1])//stride[1]+1,(input_sz[3]-kernel_size[2])//stride[2]+1).to(device) 19 | for i in range(_t): 20 | ind1=i*stride[0] 21 | ind2=i*stride[0]+kernel_size[0] 22 | temp=Im2Col(input_tensor[:,ind1:ind2,:,:], (kernel_size[1],kernel_size[2]), (stride[1],stride[2]), padding, dilation, tensorized) 23 | out[:,:,i,:,:]=temp 24 | #out[:,:,i,:,:]=Im2Col(input_tensor[:,ind1:ind2,:,:], kernel_size, stride, padding, dilation, tensorized) 25 | return out 26 | 27 | def Col2Cube(input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False,device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): 28 | batch = input_tensor.shape[0] 29 | _t = (output_size[0] - kernel_size[0] )//stride[0]+ 1 30 | out = torch.zeros([batch,output_size[0],output_size[1],output_size[2]]).to(input_tensor.device) 31 | me=torch.zeros_like(out).to(input_tensor.device) 32 | for i in range(_t): 33 | j = i*stride[0] 34 | ind1 = j 35 | ind2 = j + kernel_size[0] 36 | if input_tensorized: 37 | #temp_ = input_tensor[:,:,i,:,:].flatten(2,3) 38 | #print(input_tensor[:,:,i,:,:].max()) 39 | temp_tensor = input_tensor[:,:,i,:,:].flatten(2,3) 40 | #print(temp_tensor.max()) 41 | #temp__ =F.fold(temp_tensor, output_size=output_size[1:], kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 42 | test =F.fold(temp_tensor, output_size=output_size[1:], kernel_size=(kernel_size[1],kernel_size[2]), padding=padding, stride=(stride[1],stride[2]),dilation=dilation) 43 | #nptest = test.cpu().numpy() 44 | #print(test.max()) 45 | out[:,ind1:ind2,:,:]+=test 46 | #ttt = 
torch.sum(test) 47 | #print(ttt) 48 | #out[:,ind1:ind2,:,:] += F.fold(temp_tensor, output_size=output_size[1:], kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 49 | #temp___ = F.fold(torch.ones_like(temp_tensor), output_size=output_size[1:], kernel_size=kernel_size, 50 | # padding=padding, stride=stride, dilation=dilation).numpy() 51 | temp = F.fold(torch.ones_like(temp_tensor), output_size=output_size[1:], kernel_size=(kernel_size[1],kernel_size[2]), padding=padding, stride=(stride[1],stride[2]),dilation=dilation) 52 | #nptemp = temp.cpu().numpy() 53 | #temp = torch.ones_like(test) 54 | #print(temp.max()) 55 | me[:,ind1:ind2,:,:] += temp 56 | #tt2 = torch.sum(temp) 57 | #print(tt2) 58 | pass 59 | 60 | 61 | 62 | if avg: 63 | me[me==0]=1 # !!!!!!! 64 | 65 | #print(me.max()) 66 | ''' 67 | test= me[0,15:65,15:65,15] 68 | test = test.cpu().numpy() 69 | test_out = out[0,15:65,15:65,15] 70 | test_out = test_out.cpu().numpy() 71 | ''' 72 | #print(test) 73 | out = out / me 74 | 75 | # me_ = F.conv_transpose2d(torch.ones_like(input_tensor),torch.ones(1,1,kernel_size,kernel_size)) 76 | 77 | return out 78 | 79 | 80 | def Col2Im(input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 81 | batch = input_tensor.shape[0] 82 | 83 | if input_tensorized: 84 | input_tensor = input_tensor.flatten(2,3) 85 | out = F.fold(input_tensor, output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 86 | 87 | if avg: 88 | me = F.fold(torch.ones_like(input_tensor), output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride,dilation=dilation) 89 | # me[me==0]=1 # !!!!!!! 
90 | out = out / me 91 | 92 | # me_ = F.conv_transpose2d(torch.ones_like(input_tensor),torch.ones(1,1,kernel_size,kernel_size)) 93 | 94 | return out 95 | 96 | 97 | class Col2Im_(torch.nn.Module): 98 | 99 | def __init__(self,input_shape, output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 100 | super(Col2Im_,self).__init__() 101 | 102 | xshape = tuple(input_shape) 103 | 104 | if input_tensorized: 105 | xshape = xshape[0:2]+(xshape[2]*xshape[3],) 106 | 107 | if avg: 108 | me = F.fold(torch.ones(xshape), output_size=output_size, kernel_size=kernel_size, 109 | padding=padding, stride=stride, dilation=dilation) 110 | me[me == 0] = 1 111 | self.me = me 112 | 113 | def forward(self, input_tensor,output_size, kernel_size, stride, padding, dilation=1, avg=False,input_tensorized=False): 114 | if input_tensorized: 115 | input_tensor = input_tensor.flatten(2, 3) 116 | out = F.fold(input_tensor, output_size=output_size, kernel_size=kernel_size, padding=padding, stride=stride, 117 | dilation=dilation) 118 | if avg: 119 | out /= self.me 120 | return out 121 | 122 | # def im2col_shape(size, kernel_size, stride, padding): 123 | # ksize_h, ksize_w = _pair(kernel_size) 124 | # stride_h, stride_w = _pair(stride) 125 | # pad_h, pad_w = _pair(padding) 126 | # n_input_plane, height, width = size 127 | # height_col = (height + 2 * pad_h - ksize_h) // stride_h + 1 128 | # width_col = (width + 2 * pad_w - ksize_w) // stride_w + 1 129 | # return n_input_plane, ksize_h, ksize_w, height_col, width_col 130 | 131 | def im2col_shape(size, kernel_size, stride, padding, dilation): 132 | ksize_h, ksize_w = _pair(kernel_size) 133 | stride_h, stride_w = _pair(stride) 134 | dil_h, dil_w = _pair(dilation) 135 | pad_h, pad_w = _pair(padding) 136 | n_input_plane, height, width = size 137 | height_col = (height + 2 * pad_h - dil_h * (ksize_h-1)-1) / stride_h + 1 138 | width_col = (width + 2 * pad_w - dil_w * (ksize_w-1)-1) / stride_w + 1 139 | return n_input_plane, 
ksize_h, ksize_w, math.floor(height_col), math.floor(width_col) 140 | 141 | 142 | def col2im_shape(size, kernel_size, stride, padding, input_size=None): 143 | ksize_h, ksize_w = _pair(kernel_size) 144 | stride_h, stride_w = _pair(stride) 145 | pad_h, pad_w = _pair(padding) 146 | n_input_plane, ksize_h, ksize_w, height_col, width_col = size 147 | if input_size is not None: 148 | height, width = input_size 149 | else: 150 | height = (height_col - 1) * stride_h - 2 * pad_h + ksize_h 151 | width = (width_col - 1) * stride_w - 2 * pad_w + ksize_w 152 | return n_input_plane, height, width -------------------------------------------------------------------------------- /models/nssnn/__init__.py: -------------------------------------------------------------------------------- 1 | import models.nssnn.cc -------------------------------------------------------------------------------- /models/nssnn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/nssnn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/nssnn/__pycache__/cc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/nssnn/__pycache__/cc.cpython-37.pyc -------------------------------------------------------------------------------- /models/nssnn/__pycache__/nssnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/nssnn/__pycache__/nssnn.cpython-37.pyc -------------------------------------------------------------------------------- /models/nssnn/cc.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Softmax 5 | # print("CC Moudle") 6 | def INF(B,H,W,device): 7 | return -torch.diag(torch.tensor(float("inf")).to(device).repeat(H),0).unsqueeze(0).repeat(B*W,1,1) 8 | class CC_module(nn.Module): 9 | def __init__(self,in_dim): 10 | super(CC_module, self).__init__() 11 | if in_dim >=8: 12 | out_channels_query = in_dim//8 13 | else: 14 | out_channels_query = 1 15 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=out_channels_query, kernel_size=1) 16 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=out_channels_query, kernel_size=1) 17 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 18 | self.softmax = Softmax(dim=3) 19 | self.INF = INF 20 | self.gamma = nn.Parameter(torch.zeros(1)) 21 | def forward(self, x): 22 | m_batchsize, _, height, width = x.size() 23 | proj_query = self.query_conv(x) 24 | proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1) 25 | proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1) 26 | proj_key = self.key_conv(x) 27 | proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) 28 | proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) 29 | proj_value = self.value_conv(x) 30 | proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height) 31 | proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width) 32 | energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width,x.device)).view(m_batchsize,width,height,height).permute(0,2,1,3) 33 | energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width) 34 | concate = 
self.softmax(torch.cat([energy_H, energy_W], 3)) 35 | 36 | att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height) 37 | #print(concate) 38 | 39 | #print(att_H) 40 | att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width) 41 | out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1) 42 | out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3) 43 | #print(out_H.size(),out_W.size()) 44 | return self.gamma*(out_H + out_W) + x 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | model = CC_module(5).to('cuda:0') 50 | x = torch.randn(1, 5, 1, 64, 64).to('cuda:0') 51 | for i in range(31): 52 | x[:,:,i,:,:]= model(x[:,:,i,:,:]) 53 | print(x.shape) 54 | -------------------------------------------------------------------------------- /models/qrnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .qrnn3d import QRNNREDC3D 2 | from .redc3d import REDC3D 3 | from .resnet import ResQRNN3D -------------------------------------------------------------------------------- /models/qrnn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/combinations.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/combinations.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/combinations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/combinations.cpython-38.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/qrnn3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/qrnn3d.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/qrnn3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/qrnn3d.cpython-38.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/redc3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/redc3d.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/redc3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/redc3d.cpython-38.pyc 
-------------------------------------------------------------------------------- /models/qrnn/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/qrnn/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/qrnn/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/qrnn/combinations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional 4 | from models.sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 5 | 6 | BatchNorm3d = SynchronizedBatchNorm3d 7 | 8 | 9 | class BNReLUConv3d(nn.Sequential): 10 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 11 | super(BNReLUConv3d, self).__init__() 12 | self.add_module('bn', 
BatchNorm3d(in_channels)) 13 | self.add_module('relu', nn.ReLU(inplace=inplace)) 14 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 15 | 16 | 17 | class BNReLUDeConv3d(nn.Sequential): 18 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 19 | super(BNReLUDeConv3d, self).__init__() 20 | self.add_module('bn', BatchNorm3d(in_channels)) 21 | self.add_module('relu', nn.ReLU(inplace=inplace)) 22 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 23 | 24 | 25 | class BNReLUUpsampleConv3d(nn.Sequential): 26 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False): 27 | super(BNReLUUpsampleConv3d, self).__init__() 28 | self.add_module('bn', BatchNorm3d(in_channels)) 29 | self.add_module('relu', nn.ReLU(inplace=inplace)) 30 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 31 | 32 | 33 | class UpsampleConv3d(torch.nn.Module): 34 | """UpsampleConvLayer 35 | Upsamples the input and then does a convolution. This method gives better results 36 | compared to ConvTranspose2d. 
37 | ref: http://distill.pub/2016/deconv-checkerboard/ 38 | """ 39 | 40 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 41 | super(UpsampleConv3d, self).__init__() 42 | self.upsample = upsample 43 | if upsample: 44 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True) 45 | 46 | self.conv3d = torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 47 | 48 | def forward(self, x): 49 | x_in = x 50 | if self.upsample: 51 | x_in = self.upsample_layer(x_in) 52 | out = self.conv3d(x_in) 53 | return out 54 | 55 | 56 | class BasicConv3d(nn.Sequential): 57 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 58 | super(BasicConv3d, self).__init__() 59 | if bn: 60 | self.add_module('bn', BatchNorm3d(in_channels)) 61 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)) 62 | 63 | 64 | class BasicDeConv3d(nn.Sequential): 65 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 66 | super(BasicDeConv3d, self).__init__() 67 | if bn: 68 | self.add_module('bn', BatchNorm3d(in_channels)) 69 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)) 70 | 71 | 72 | class BasicUpsampleConv3d(nn.Sequential): 73 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True): 74 | super(BasicUpsampleConv3d, self).__init__() 75 | if bn: 76 | self.add_module('bn', BatchNorm3d(in_channels)) 77 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 78 | -------------------------------------------------------------------------------- /models/qrnn/qrnn3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as FF 4 | import numpy as np 5 | 6 | from functools import 
partial 7 | 8 | if __name__ == '__main__': 9 | from combinations import * 10 | from utils import * 11 | else: 12 | from .combinations import * 13 | from .utils import * 14 | 15 | 16 | """F pooling""" 17 | class QRNN3DLayer(nn.Module): 18 | def __init__(self, in_channels, hidden_channels, conv_layer, act='tanh'): 19 | super(QRNN3DLayer, self).__init__() 20 | self.in_channels = in_channels 21 | self.hidden_channels = hidden_channels 22 | # quasi_conv_layer 23 | self.conv = conv_layer 24 | self.act = act 25 | 26 | def _conv_step(self, inputs): 27 | gates = self.conv(inputs) 28 | Z, F = gates.split(split_size=self.hidden_channels, dim=1) 29 | if self.act == 'tanh': 30 | return Z.tanh(), F.sigmoid() 31 | elif self.act == 'relu': 32 | return Z.relu(), F.sigmoid() 33 | elif self.act == 'none': 34 | return Z, F.sigmoid 35 | else: 36 | raise NotImplementedError 37 | 38 | def _rnn_step(self, z, f, h): 39 | # uses 'f pooling' at each time step 40 | h_ = (1 - f) * z if h is None else f * h + (1 - f) * z 41 | return h_ 42 | 43 | def forward(self, inputs, reverse=False): 44 | h = None 45 | Z, F = self._conv_step(inputs) 46 | h_time = [] 47 | 48 | if not reverse: 49 | for time, (z, f) in enumerate(zip(Z.split(1, 2), F.split(1, 2))): # split along timestep 50 | h = self._rnn_step(z, f, h) 51 | h_time.append(h) 52 | else: 53 | for time, (z, f) in enumerate((zip( 54 | reversed(Z.split(1, 2)), reversed(F.split(1, 2)) 55 | ))): # split along timestep 56 | h = self._rnn_step(z, f, h) 57 | h_time.insert(0, h) 58 | 59 | # return concatenated hidden states 60 | return torch.cat(h_time, dim=2) 61 | 62 | def extra_repr(self): 63 | return 'act={}'.format(self.act) 64 | 65 | 66 | class BiQRNN3DLayer(QRNN3DLayer): 67 | def _conv_step(self, inputs): 68 | gates = self.conv(inputs) 69 | Z, F1, F2 = gates.split(split_size=self.hidden_channels, dim=1) 70 | if self.act == 'tanh': 71 | return Z.tanh(), F1.sigmoid(), F2.sigmoid() 72 | elif self.act == 'relu': 73 | return Z.relu(), F1.sigmoid(), 
F2.sigmoid() 74 | elif self.act == 'none': 75 | return Z, F1.sigmoid(), F2.sigmoid() 76 | else: 77 | raise NotImplementedError 78 | 79 | def forward(self, inputs, fname=None): 80 | h = None 81 | Z, F1, F2 = self._conv_step(inputs) 82 | hsl = [] ; hsr = [] 83 | zs = Z.split(1, 2) 84 | 85 | for time, (z, f) in enumerate(zip(zs, F1.split(1, 2))): # split along timestep 86 | h = self._rnn_step(z, f, h) 87 | hsl.append(h) 88 | 89 | h = None 90 | for time, (z, f) in enumerate((zip( 91 | reversed(zs), reversed(F2.split(1, 2)) 92 | ))): # split along timestep 93 | h = self._rnn_step(z, f, h) 94 | hsr.insert(0, h) 95 | 96 | # return concatenated hidden states 97 | hsl = torch.cat(hsl, dim=2) 98 | hsr = torch.cat(hsr, dim=2) 99 | 100 | if fname is not None: 101 | stats_dict = {'z':Z, 'fl':F1, 'fr':F2, 'hsl':hsl, 'hsr':hsr} 102 | torch.save(stats_dict, fname) 103 | return hsl + hsr 104 | 105 | 106 | class BiQRNNConv3D(BiQRNN3DLayer): 107 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 108 | super(BiQRNNConv3D, self).__init__( 109 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels*3, k, s, p, bn=bn), act=act) 110 | 111 | 112 | class BiQRNNDeConv3D(BiQRNN3DLayer): 113 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bias=False, bn=True, act='tanh'): 114 | super(BiQRNNDeConv3D, self).__init__( 115 | in_channels, hidden_channels, BasicDeConv3d(in_channels, hidden_channels*3, k, s, p, bias=bias, bn=bn), act=act) 116 | 117 | 118 | class QRNNConv3D(QRNN3DLayer): 119 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 120 | super(QRNNConv3D, self).__init__( 121 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels*2, k, s, p, bn=bn), act=act) 122 | 123 | 124 | class QRNNDeConv3D(QRNN3DLayer): 125 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 126 | super(QRNNDeConv3D, self).__init__( 127 | in_channels, 
hidden_channels, BasicDeConv3d(in_channels, hidden_channels*2, k, s, p, bn=bn), act=act) 128 | 129 | 130 | class QRNNUpsampleConv3d(QRNN3DLayer): 131 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True, act='tanh'): 132 | super(QRNNUpsampleConv3d, self).__init__( 133 | in_channels, hidden_channels, BasicUpsampleConv3d(in_channels, hidden_channels*2, k, s, p, upsample, bn=bn), act=act) 134 | 135 | 136 | QRNN3DEncoder = partial( 137 | QRNN3DEncoder, 138 | QRNNConv3D=QRNNConv3D) 139 | 140 | QRNN3DDecoder = partial( 141 | QRNN3DDecoder, 142 | QRNNDeConv3D=QRNNDeConv3D, 143 | QRNNUpsampleConv3d=QRNNUpsampleConv3d) 144 | 145 | QRNNREDC3D = partial( 146 | QRNNREDC3D, 147 | BiQRNNConv3D=BiQRNNConv3D, 148 | BiQRNNDeConv3D=BiQRNNDeConv3D, 149 | QRNN3DEncoder=QRNN3DEncoder, 150 | QRNN3DDecoder=QRNN3DDecoder 151 | ) 152 | 153 | -------------------------------------------------------------------------------- /models/qrnn/redc3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | if __name__ == '__main__': 7 | from combinations import * 8 | else: 9 | from .combinations import * 10 | 11 | 12 | class REDC3D(torch.nn.Module): 13 | """Residual Encoder-Decoder Convolution 3D 14 | Args: 15 | downsample: downsample times, None denotes no downsample""" 16 | def __init__(self, in_channels, channels, num_half_layer, downsample=None): 17 | super(REDC3D, self).__init__() 18 | # Encoder 19 | assert downsample is None or 0 < downsample <= num_half_layer 20 | interval = num_half_layer // downsample if downsample else num_half_layer+1 21 | 22 | self.feature_extractor = BNReLUConv3d(in_channels, channels) 23 | self.encoder = nn.ModuleList() 24 | for i in range(1, num_half_layer+1): 25 | if i % interval: 26 | encoder_layer = BNReLUConv3d(channels, channels) 27 | else: 28 | encoder_layer = 
BNReLUConv3d(channels, 2*channels, k=3, s=(1,2,2), p=1) 29 | channels *= 2 30 | self.encoder.append(encoder_layer) 31 | # Decoder 32 | self.decoder = nn.ModuleList() 33 | for i in range(1,num_half_layer+1): 34 | if i % interval: 35 | decoder_layer = BNReLUDeConv3d(channels, channels) 36 | else: 37 | decoder_layer = BNReLUUpsampleConv3d(channels, channels//2) 38 | channels //= 2 39 | self.decoder.append(decoder_layer) 40 | self.reconstructor = BNReLUDeConv3d(channels, in_channels) 41 | 42 | def forward(self, x): 43 | num_half_layer = len(self.encoder) 44 | xs = [x] 45 | out = self.feature_extractor(xs[0]) 46 | xs.append(out) 47 | for i in range(num_half_layer-1): 48 | out = self.encoder[i](out) 49 | xs.append(out) 50 | out = self.encoder[-1](out) 51 | out = self.decoder[0](out) 52 | for i in range(1, num_half_layer): 53 | out = out + xs.pop() 54 | out = self.decoder[i](out) 55 | out = out + xs.pop() 56 | out = self.reconstructor(out) 57 | out = out + xs.pop() 58 | return out 59 | -------------------------------------------------------------------------------- /models/qrnn/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | 6 | if __name__ == '__main__': 7 | from qrnn3d import * 8 | else: 9 | from .qrnn3d import * 10 | 11 | 12 | class ResQRNN3D(nn.Module): 13 | def __init__(self, in_channels, channels, n_resblocks): 14 | super(ResQRNN3D, self).__init__() 15 | 16 | bn = True 17 | act = 'tanh' 18 | 19 | # define head module 20 | m_head = [BiQRNNConv3D(in_channels, channels, bn=bn, act=act)] 21 | 22 | # define body module 23 | m_body = [ 24 | ResBlock( 25 | QRNNConv3D, channels, bn=bn, act=act 26 | ) for i in range(n_resblocks) 27 | ] 28 | 29 | # define tail module 30 | m_tail = [ 31 | BiQRNNConv3D(channels, in_channels, bn=bn, act='none') 32 | ] 33 | 34 | self.head = nn.Sequential(*m_head) 35 | self.body = nn.Sequential(*m_body) 36 | self.tail = 
nn.Sequential(*m_tail) 37 | 38 | def forward(self, x): 39 | x = self.head(x) 40 | res = self.body(x) 41 | res += x 42 | x = self.tail(res) 43 | return x 44 | 45 | 46 | class ResBlock(nn.Module): 47 | def __init__( 48 | self, block, channels, **kwargs): 49 | super(ResBlock, self).__init__() 50 | self.layer1 = block(channels, channels, **kwargs) 51 | self.layer2 = block(channels, channels, **kwargs) 52 | 53 | def forward(self, x, reverse=False): 54 | res = self.layer1(x, reverse) 55 | res = self.layer2(x, not reverse) 56 | res += x 57 | 58 | return res 59 | -------------------------------------------------------------------------------- /models/qrnn/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import scipy.io as io 5 | 6 | class QRNNREDC3D(nn.Module): 7 | def __init__(self, in_channels, channels, num_half_layer, sample_idx, 8 | BiQRNNConv3D=None, BiQRNNDeConv3D=None, 9 | QRNN3DEncoder=None, QRNN3DDecoder=None, is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 10 | super(QRNNREDC3D, self).__init__() 11 | assert sample_idx is None or isinstance(sample_idx, list) 12 | 13 | self.enable_ad = has_ad 14 | if sample_idx is None: sample_idx = [] 15 | if is_2d: 16 | self.feature_extractor = BiQRNNConv3D(in_channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 17 | else: 18 | self.feature_extractor = BiQRNNConv3D(in_channels, channels, bn=bn, act=act) 19 | 20 | self.encoder = QRNN3DEncoder(channels, num_half_layer, sample_idx, is_2d=is_2d, has_ad=has_ad, bn=bn, act=act, plain=plain) 21 | self.decoder = QRNN3DDecoder(channels*(2**len(sample_idx)), num_half_layer, sample_idx, is_2d=is_2d, has_ad=has_ad, bn=bn, act=act, plain=plain) 22 | 23 | if act == 'relu': 24 | act = 'none' 25 | 26 | if is_2d: 27 | self.reconstructor = BiQRNNDeConv3D(channels, in_channels, bias=True, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 28 | else: 29 | self.reconstructor = 
BiQRNNDeConv3D(channels, in_channels, bias=True, bn=bn, act=act) 30 | 31 | def forward(self, x): 32 | xs = [x] 33 | out = self.feature_extractor(xs[0]) 34 | xs.append(out) 35 | if self.enable_ad: 36 | out, reverse = self.encoder(out, xs, reverse=False) 37 | out = self.decoder(out, xs, reverse=(reverse)) 38 | else: 39 | out = self.encoder(out, xs) 40 | out = self.decoder(out, xs) 41 | out = out + xs.pop() 42 | out = self.reconstructor(out) 43 | out = out + xs.pop() 44 | return out 45 | 46 | 47 | class QRNN3DEncoder(nn.Module): 48 | def __init__(self, channels, num_half_layer, sample_idx, QRNNConv3D=None, 49 | is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 50 | super(QRNN3DEncoder, self).__init__() 51 | # Encoder 52 | self.layers = nn.ModuleList() 53 | self.enable_ad = has_ad 54 | for i in range(num_half_layer): 55 | if i not in sample_idx: 56 | if is_2d: 57 | encoder_layer = QRNNConv3D(channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 58 | else: 59 | encoder_layer = QRNNConv3D(channels, channels, bn=bn, act=act) 60 | else: 61 | if is_2d: 62 | encoder_layer = QRNNConv3D(channels, 2*channels, k=(1,3,3), s=(1,2,2), p=(0,1,1), bn=bn, act=act) 63 | else: 64 | if not plain: 65 | encoder_layer = QRNNConv3D(channels, 2*channels, k=3, s=(1,2,2), p=1, bn=bn, act=act) 66 | else: 67 | encoder_layer = QRNNConv3D(channels, 2*channels, k=3, s=(1,1,1), p=1, bn=bn, act=act) 68 | 69 | channels *= 2 70 | self.layers.append(encoder_layer) 71 | 72 | def forward(self, x, xs, reverse=False): 73 | if not self.enable_ad: 74 | num_half_layer = len(self.layers) 75 | for i in range(num_half_layer-1): 76 | x = self.layers[i](x) 77 | xs.append(x) 78 | x = self.layers[-1](x) 79 | 80 | return x 81 | else: 82 | num_half_layer = len(self.layers) 83 | for i in range(num_half_layer-1): 84 | x = self.layers[i](x, reverse=reverse) 85 | reverse = not reverse 86 | xs.append(x) 87 | x = self.layers[-1](x, reverse=reverse) 88 | reverse = not reverse 89 | 90 | return x, reverse 91 
| 92 | 93 | class QRNN3DDecoder(nn.Module): 94 | def __init__(self, channels, num_half_layer, sample_idx, QRNNDeConv3D=None, QRNNUpsampleConv3d=None, 95 | is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 96 | super(QRNN3DDecoder, self).__init__() 97 | # Decoder 98 | self.layers = nn.ModuleList() 99 | self.enable_ad = has_ad 100 | for i in reversed(range(num_half_layer)): 101 | if i not in sample_idx: 102 | if is_2d: 103 | decoder_layer = QRNNDeConv3D(channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 104 | else: 105 | decoder_layer = QRNNDeConv3D(channels, channels, bn=bn, act=act) 106 | else: 107 | if is_2d: 108 | decoder_layer = QRNNUpsampleConv3d(channels, channels//2, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 109 | else: 110 | if not plain: 111 | decoder_layer = QRNNUpsampleConv3d(channels, channels//2, bn=bn, act=act) 112 | else: 113 | decoder_layer = QRNNDeConv3D(channels, channels//2, bn=bn, act=act) 114 | 115 | channels //= 2 116 | self.layers.append(decoder_layer) 117 | 118 | 119 | def forward(self, x, xs, reverse=False): 120 | if not self.enable_ad: 121 | num_half_layer = len(self.layers) 122 | x = self.layers[0](x) 123 | for i in range(1, num_half_layer): 124 | x = x + xs.pop() 125 | x = self.layers[i](x) 126 | return x 127 | else: 128 | num_half_layer = len(self.layers) 129 | x = self.layers[0](x, reverse=reverse) 130 | reverse = not reverse 131 | for i in range(1, num_half_layer): 132 | x = x + xs.pop() 133 | x = self.layers[i](x, reverse=reverse) 134 | 135 | # tmp = x 136 | # out_np = tmp[0,11,:,:,:].permute(1,2,0).cpu().numpy() 137 | # io.savemat('/nas_data/fugym/projects_python/SRU3D/results/qrnn3d/qrnn3dfeature'+str(i)+'.mat',{'qrnn3dfeature':out_np}) 138 | # print(x.shape) 139 | reverse = not reverse 140 | return x 141 | -------------------------------------------------------------------------------- /models/sru/__init__.py: -------------------------------------------------------------------------------- 1 | from .sru3d 
import SRUREDC3D -------------------------------------------------------------------------------- /models/sru/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sru/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/sru/__pycache__/combinations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sru/__pycache__/combinations.cpython-37.pyc -------------------------------------------------------------------------------- /models/sru/__pycache__/sru3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sru/__pycache__/sru3d.cpython-37.pyc -------------------------------------------------------------------------------- /models/sru/combinations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional 4 | from models.sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 5 | 6 | #BatchNorm3d = nn.BatchNorm3d 7 | BatchNorm3d = SynchronizedBatchNorm3d 8 | 9 | 10 | class BNReLUConv3d(nn.Sequential): 11 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 12 | super(BNReLUConv3d, self).__init__() 13 | self.add_module('bn', BatchNorm3d(in_channels)) 14 | self.add_module('relu', nn.ReLU(inplace=inplace)) 15 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 16 | 17 | 18 | class BNReLUDeConv3d(nn.Sequential): 19 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 20 | 
super(BNReLUDeConv3d, self).__init__() 21 | self.add_module('bn', BatchNorm3d(in_channels)) 22 | self.add_module('relu', nn.ReLU(inplace=inplace)) 23 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 24 | 25 | 26 | class BNReLUUpsampleConv3d(nn.Sequential): 27 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False): 28 | super(BNReLUUpsampleConv3d, self).__init__() 29 | self.add_module('bn', BatchNorm3d(in_channels)) 30 | self.add_module('relu', nn.ReLU(inplace=inplace)) 31 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 32 | 33 | 34 | class UpsampleConv3d(torch.nn.Module): 35 | """UpsampleConvLayer 36 | Upsamples the input and then does a convolution. This method gives better results 37 | compared to ConvTranspose2d. 38 | ref: http://distill.pub/2016/deconv-checkerboard/ 39 | """ 40 | 41 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 42 | super(UpsampleConv3d, self).__init__() 43 | self.upsample = upsample 44 | if upsample: 45 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True) 46 | 47 | self.conv3d = torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 48 | 49 | def forward(self, x): 50 | x_in = x 51 | if self.upsample: 52 | x_in = self.upsample_layer(x_in) 53 | out = self.conv3d(x_in) 54 | return out 55 | 56 | 57 | class BasicConv3d(nn.Sequential): 58 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 59 | super(BasicConv3d, self).__init__() 60 | if bn: 61 | self.add_module('bn', BatchNorm3d(in_channels)) 62 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)) 63 | 64 | 65 | class BasicDeConv3d(nn.Sequential): 66 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 67 | super(BasicDeConv3d, 
self).__init__() 68 | if bn: 69 | self.add_module('bn', BatchNorm3d(in_channels)) 70 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)) 71 | 72 | 73 | class BasicUpsampleConv3d(nn.Sequential): 74 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True): 75 | super(BasicUpsampleConv3d, self).__init__() 76 | if bn: 77 | self.add_module('bn', BatchNorm3d(in_channels)) 78 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 79 | -------------------------------------------------------------------------------- /models/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/replicate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/sync_batchnorm/__pycache__/replicate.cpython-38.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. 
Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 
81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /models/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/unet/__pycache__/buildingblocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/buildingblocks.cpython-37.pyc -------------------------------------------------------------------------------- /models/unet/__pycache__/buildingblocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/buildingblocks.cpython-38.pyc -------------------------------------------------------------------------------- /models/unet/__pycache__/unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/unet.cpython-37.pyc -------------------------------------------------------------------------------- /models/unet/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/unet.cpython-38.pyc 
-------------------------------------------------------------------------------- /models/unet/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/unet/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/models/unet/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/unet/unet.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.unet.buildingblocks import DoubleConv, ExtResNetBlock, create_encoders, \ 6 | create_decoders 7 | from models.unet.utils import number_of_features_per_level 8 | 9 | 10 | class Abstract3DUNet(nn.Module): 11 | """ 12 | Base class for standard and residual UNet. 13 | 14 | Args: 15 | in_channels (int): number of input channels 16 | out_channels (int): number of output segmentation masks; 17 | Note that that the of out_channels might correspond to either 18 | different semantic classes or to different binary segmentation mask. 19 | It's up to the user of the class to interpret the out_channels and 20 | use the proper loss criterion during training (i.e. 
CrossEntropyLoss (multi-class) 21 | or BCEWithLogitsLoss (two-class) respectively) 22 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 23 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 24 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 25 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 26 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 27 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) 28 | layer_order (string): determines the order of layers 29 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 30 | See `SingleConv` for more info 31 | num_groups (int): number of groups for the GroupNorm 32 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) 33 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied 34 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end 35 | testing (bool): if True (testing mode) the `final_activation` (if present, i.e. 
`is_segmentation=true`) 36 | will be applied as the last operation during the forward pass; if False the model is in training mode 37 | and the `final_activation` (even if present) won't be applied; default: False 38 | conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module 39 | pool_kernel_size (int or tuple): the size of the window 40 | conv_padding (int or tuple): add zero-padding added to all three sides of the input 41 | """ 42 | 43 | def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', 44 | num_groups=8, num_levels=4, is_segmentation=True, testing=False, 45 | conv_kernel_size=3, pool_kernel_size=2, conv_padding=1, **kwargs): 46 | super(Abstract3DUNet, self).__init__() 47 | 48 | self.testing = testing 49 | 50 | if isinstance(f_maps, int): 51 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) 52 | 53 | assert isinstance(f_maps, list) or isinstance(f_maps, tuple) 54 | assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" 55 | 56 | # create encoder path 57 | self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, 58 | num_groups, pool_kernel_size) 59 | 60 | # create decoder path 61 | self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, 62 | upsample=True,scale_factor=(1,2,2)) 63 | 64 | # in the last layer a 1×1 convolution reduces the number of output 65 | # channels to the number of labels 66 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 67 | 68 | if is_segmentation: 69 | # semantic segmentation problem 70 | if final_sigmoid: 71 | self.final_activation = nn.Sigmoid() 72 | else: 73 | self.final_activation = nn.Softmax(dim=1) 74 | else: 75 | # regression problem 76 | self.final_activation = None 77 | 78 | def forward(self, x): 79 | # encoder part 80 | encoders_features = [] 81 | for encoder in self.encoders: 82 | x = encoder(x) 83 
| # reverse the encoder outputs to be aligned with the decoder 84 | encoders_features.insert(0, x) 85 | 86 | # remove the last encoder's output from the list 87 | # !!remember: it's the 1st in the list 88 | encoders_features = encoders_features[1:] 89 | 90 | # decoder part 91 | for decoder, encoder_features in zip(self.decoders, encoders_features): 92 | # pass the output from the corresponding encoder and the output 93 | # of the previous decoder 94 | x = decoder(encoder_features, x) 95 | 96 | x = self.final_conv(x) 97 | 98 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs 99 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 100 | if self.testing and self.final_activation is not None: 101 | x = self.final_activation(x) 102 | 103 | return x 104 | 105 | 106 | class UNet3D(Abstract3DUNet): 107 | """ 108 | 3DUnet model from 109 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 110 | `. 111 | 112 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder 113 | """ 114 | 115 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 116 | num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): 117 | super(UNet3D, self).__init__(in_channels=in_channels, 118 | out_channels=out_channels, 119 | final_sigmoid=final_sigmoid, 120 | basic_module=DoubleConv, 121 | f_maps=f_maps, 122 | layer_order=layer_order, 123 | num_groups=num_groups, 124 | num_levels=num_levels, 125 | is_segmentation=is_segmentation, 126 | conv_padding=conv_padding, 127 | **kwargs) 128 | 129 | 130 | class ResidualUNet3D(Abstract3DUNet): 131 | """ 132 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 
133 | Uses ExtResNetBlock as a basic building block, summation joining instead 134 | of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). 135 | Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 136 | """ 137 | 138 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 139 | num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, **kwargs): 140 | super(ResidualUNet3D, self).__init__(in_channels=in_channels, 141 | out_channels=out_channels, 142 | final_sigmoid=final_sigmoid, 143 | basic_module=ExtResNetBlock, 144 | f_maps=f_maps, 145 | layer_order=layer_order, 146 | num_groups=num_groups, 147 | num_levels=num_levels, 148 | is_segmentation=is_segmentation, 149 | pool_kernel_size=(1,2,2), 150 | conv_padding=1, 151 | **kwargs) 152 | 153 | 154 | class UNet2D(Abstract3DUNet): 155 | """ 156 | Just a standard 2D Unet. Arises naturally by specifying conv_kernel_size=(1, 3, 3), pool_kernel_size=(1, 2, 2). 
157 | """ 158 | 159 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 160 | num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): 161 | if conv_padding == 1: 162 | conv_padding = (0, 1, 1) 163 | super(UNet2D, self).__init__(in_channels=in_channels, 164 | out_channels=out_channels, 165 | final_sigmoid=final_sigmoid, 166 | basic_module=DoubleConv, 167 | f_maps=f_maps, 168 | layer_order=layer_order, 169 | num_groups=num_groups, 170 | num_levels=num_levels, 171 | is_segmentation=is_segmentation, 172 | conv_kernel_size=(1, 3, 3), 173 | pool_kernel_size=(1, 2, 2), 174 | conv_padding=conv_padding, 175 | **kwargs) 176 | 177 | 178 | def get_model(model_config): 179 | def _model_class(class_name): 180 | modules = ['pytorch3dunet.unet3d.model'] 181 | for module in modules: 182 | m = importlib.import_module(module) 183 | clazz = getattr(m, class_name, None) 184 | if clazz is not None: 185 | return clazz 186 | 187 | model_class = _model_class(model_config['name']) 188 | return model_class(**model_config) 189 | 190 | if __name__ == '__main__': 191 | unet = ResidualUNet3D(1,1,True,16,num_levels=5) 192 | rand = torch.randn((1,1,31,512,512)) 193 | out = unet(rand) 194 | pass -------------------------------------------------------------------------------- /torch_37.yaml: -------------------------------------------------------------------------------- 1 | name: torch_37 2 | channels: 3 | - defaults 4 | - pytorch 5 | - anaconda 6 | - conda-forge 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _tflow_select=2.3.0=mkl 10 | - absl-py=0.13.0=py37h06a4308_0 11 | - astor=0.8.1=py37h06a4308_0 12 | - backcall=0.2.0=pyh9f0ad1d_0 13 | - backports=1.0=py_2 14 | - backports.functools_lru_cache=1.6.1=py_0 15 | - blas=1.0=mkl 16 | - boost=1.69.0=py37h8619c78_1001 17 | - boost-cpp=1.69.0=h11c811c_1000 18 | - bzip2=1.0.8=h516909a_3 19 | - c-ares=1.17.1=h27cfd23_0 20 | - ca-certificates=2021.9.30=h06a4308_1 21 | - 
caffe=1.0=py37hbab4207_5 22 | - cairo=1.16.0=h18b612c_1001 23 | - certifi=2021.10.8=py37h06a4308_0 24 | - cffi=1.14.3=py37he30daa8_0 25 | - cloudpickle=1.6.0=py_0 26 | - coverage=5.5=py37h27cfd23_2 27 | - cpuonly=1.0=0 28 | - cudatoolkit=10.1.243=h6bb024c_0 29 | - cycler=0.10.0=py_2 30 | - cython=0.29.24=py37h295c915_0 31 | - cytoolz=0.11.0=py37h8f50634_1 32 | - dask-core=2.30.0=py_0 33 | - dbus=1.13.6=he372182_0 34 | - decorator=4.4.2=py_0 35 | - einops=0.3.2=pyhd8ed1ab_0 36 | - expat=2.2.9=he1b5a44_2 37 | - ffmpeg=4.0=hcdf2ecd_0 38 | - fontconfig=2.13.1=he4413a7_1000 39 | - freeglut=3.0.0=hf484d3e_1005 40 | - freetype=2.10.4=h5ab3b9f_0 41 | - gast=0.2.2=py37_0 42 | - gflags=2.2.2=he1b5a44_1004 43 | - glib=2.66.1=h92f7085_0 44 | - glog=0.3.5=hf484d3e_1001 45 | - google-pasta=0.2.0=pyhd3eb1b0_0 46 | - graphite2=1.3.13=h58526e2_1001 47 | - grpcio=1.36.1=py37h2157cd5_1 48 | - gst-plugins-base=1.14.0=hbbd80ab_1 49 | - gstreamer=1.14.0=hb31296c_0 50 | - h5py=2.8.0=py37h989c5e5_3 51 | - harfbuzz=1.8.8=hffaf4a1_0 52 | - hdf5=1.10.2=hc401514_3 53 | - icu=58.2=hf484d3e_1000 54 | - imageio=2.9.0=py_0 55 | - importlib-metadata=4.8.1=py37h06a4308_0 56 | - intel-openmp=2020.2=254 57 | - ipython=7.18.1=py37h5ca1d4c_0 58 | - ipython_genutils=0.2.0=py_1 59 | - jasper=2.0.14=h07fcdf6_1 60 | - jedi=0.17.2=py37h89c1867_1 61 | - joblib=1.0.1=pyhd3eb1b0_0 62 | - jpeg=9b=h024ee3a_2 63 | - keras=2.3.1=0 64 | - keras-applications=1.0.8=py_1 65 | - keras-base=2.3.1=py37_0 66 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 67 | - kiwisolver=1.2.0=py37h99015e2_1 68 | - lcms2=2.11=h396b838_0 69 | - ld_impl_linux-64=2.33.1=h53a641e_7 70 | - leveldb=1.20=h6416369_1001 71 | - libblas=3.8.0=20_mkl 72 | - libboost=1.67.0=h46d08c1_4 73 | - libcblas=3.8.0=20_mkl 74 | - libedit=3.1.20191231=h14c3975_1 75 | - libffi=3.3=he6710b0_2 76 | - libgcc-ng=9.1.0=hdf63c60_0 77 | - libgfortran=3.0.0=1 78 | - libgfortran-ng=7.3.0=hdf63c60_0 79 | - libglu=9.0.0=he1b5a44_1001 80 | - liblapack=3.8.0=20_mkl 81 | - 
libopencv=3.4.2=hb342d67_1 82 | - libopus=1.3.1=h7b6447c_0 83 | - libpng=1.6.37=hbc83047_0 84 | - libprotobuf=3.13.0.1=h8b12597_0 85 | - libstdcxx-ng=9.1.0=hdf63c60_0 86 | - libtiff=4.1.0=h2733197_1 87 | - libuuid=2.32.1=h14c3975_1000 88 | - libuv=1.40.0=h7b6447c_0 89 | - libvpx=1.7.0=h439df22_0 90 | - libxcb=1.13=h14c3975_1002 91 | - libxml2=2.9.10=hb55368b_3 92 | - lmdb=0.9.29=h2531618_0 93 | - lz4-c=1.9.2=heb0550a_3 94 | - markdown=3.3.4=py37h06a4308_0 95 | - matplotlib=3.3.2=py37hc8dfbb8_1 96 | - matplotlib-base=3.3.2=py37hc9afd2a_1 97 | - mkl=2020.2=256 98 | - mkl-service=2.3.0=py37he904b0f_0 99 | - mkl_fft=1.2.0=py37h23d657b_0 100 | - mkl_random=1.1.1=py37h0573a6f_0 101 | - ncurses=6.2=he6710b0_1 102 | - networkx=2.5=py_0 103 | - ninja=1.10.1=py37hfd86e86_0 104 | - numpy-base=1.19.2=py37hfa32c7d_0 105 | - olefile=0.46=py37_0 106 | - opencv=3.4.2=py37h6fd60c2_1 107 | - openssl=1.1.1l=h7f8727e_0 108 | - opt_einsum=3.3.0=pyhd3eb1b0_1 109 | - pandas=1.1.4=py37h10a2094_0 110 | - parso=0.7.1=pyh9f0ad1d_0 111 | - pcre=8.44=he1b5a44_0 112 | - pexpect=4.8.0=pyh9f0ad1d_2 113 | - pickleshare=0.7.5=py_1003 114 | - pillow=8.0.0=py37h9a89aac_0 115 | - pip=20.2.4=py37_0 116 | - pixman=0.38.0=h516909a_1003 117 | - prompt-toolkit=3.0.8=pyha770c72_0 118 | - pthread-stubs=0.4=h14c3975_1001 119 | - ptyprocess=0.6.0=py_1001 120 | - py-boost=1.67.0=py37h04863e7_4 121 | - py-opencv=3.4.2=py37hb342d67_1 122 | - pycparser=2.20=py_2 123 | - pygments=2.7.2=py_0 124 | - pyparsing=2.4.7=pyh9f0ad1d_0 125 | - pyqt=5.9.2=py37hcca6a23_4 126 | - python=3.7.9=h7579374_0 127 | - python-dateutil=2.8.1=py_0 128 | - python-gflags=3.1.2=py_0 129 | - python-leveldb=0.201=py37he6710b0_0 130 | - python-lmdb=1.2.1=py37h2531618_1 131 | - python-spams=2.6.1=py37h55324e4_1204 132 | - python_abi=3.7=1_cp37m 133 | - pytz=2020.4=pyhd8ed1ab_0 134 | - pywavelets=1.1.1=py37h161383b_3 135 | - pyyaml=5.3.1=py37hb5d75c8_1 136 | - qt=5.9.7=h5867ecd_1 137 | - readline=8.0=h7b6447c_0 138 | - 
scikit-image=0.17.2=py37hdf5156a_0 139 | - scikit-learn=0.24.2=py37ha9443f7_0 140 | - scipy=1.5.2=py37h0b6359f_0 141 | - setuptools=50.3.0=py37hb0f4dca_1 142 | - sip=4.19.8=py37hf484d3e_0 143 | - six=1.15.0=py_0 144 | - snappy=1.1.8=he1b5a44_3 145 | - sqlite=3.33.0=h62c20be_0 146 | - tensorboard=1.15.0=pyhb230dea_0 147 | - tensorflow=1.15.0=mkl_py37h28c19af_0 148 | - tensorflow-estimator=1.15.1=pyh2649769_0 149 | - termcolor=1.1.0=py37h06a4308_1 150 | - threadpoolctl=2.2.0=pyh0d69192_0 151 | - tifffile=2020.10.1=py37hdd07704_2 152 | - tk=8.6.10=hbc83047_0 153 | - toolz=0.11.1=py_0 154 | - tornado=6.0.4=py37h8f50634_2 155 | - tqdm=4.51.0=pyh9f0ad1d_0 156 | - traitlets=5.0.5=py_0 157 | - typing_extensions=3.10.0.2=pyh06a4308_0 158 | - wcwidth=0.2.5=pyh9f0ad1d_2 159 | - webencodings=0.5.1=py37_1 160 | - werkzeug=0.16.1=py_0 161 | - wheel=0.35.1=py_0 162 | - wrapt=1.12.1=py37h7b6447c_1 163 | - xorg-fixesproto=5.0=h14c3975_1002 164 | - xorg-inputproto=2.3.2=h14c3975_1002 165 | - xorg-kbproto=1.0.7=h14c3975_1002 166 | - xorg-libice=1.0.10=h516909a_0 167 | - xorg-libsm=1.2.3=h84519dc_1000 168 | - xorg-libx11=1.6.12=h516909a_0 169 | - xorg-libxau=1.0.9=h14c3975_0 170 | - xorg-libxdmcp=1.1.3=h516909a_0 171 | - xorg-libxext=1.3.4=h516909a_0 172 | - xorg-libxfixes=5.0.3=h516909a_1004 173 | - xorg-libxi=1.7.10=h516909a_0 174 | - xorg-libxrender=0.9.10=h516909a_1002 175 | - xorg-renderproto=0.11.1=h14c3975_1002 176 | - xorg-xextproto=7.3.0=h14c3975_1002 177 | - xorg-xproto=7.0.31=h14c3975_1007 178 | - xz=5.2.5=h7b6447c_0 179 | - yaml=0.2.5=h516909a_0 180 | - zipp=3.5.0=pyhd3eb1b0_0 181 | - zlib=1.2.11=h7b6447c_3 182 | - zstd=1.4.5=h9ceee32_0 183 | - pip: 184 | - anykeystore==0.2 185 | - apex==0.1 186 | - bm3d==3.0.7 187 | - chardet==3.0.4 188 | - cryptacular==1.6.2 189 | - defusedxml==0.7.1 190 | - future==0.18.2 191 | - greenlet==1.1.2 192 | - hupper==1.10.3 193 | - idna==2.10 194 | - inplace-abn==1.1.0 195 | - jsonpatch==1.26 196 | - jsonpointer==2.0 197 | - llvmlite==0.35.0 
198 | - markupsafe==2.0.1 199 | - nibabel==3.2.1 200 | - numba==0.52.0 201 | - numpy==1.16.0 202 | - oauthlib==3.1.1 203 | - packaging==21.0 204 | - pastedeploy==2.1.1 205 | - pbkdf2==1.3 206 | - plaster==1.0 207 | - plaster-pastedeploy==0.7 208 | - protobuf==3.14.0 209 | - pyramid==2.0 210 | - pyramid-mailer==0.15.1 211 | - pysptools==0.15.0 212 | - python3-openid==3.2.0 213 | - pyzmq==20.0.0 214 | - repoze-sendmail==4.4.1 215 | - requests==2.25.0 216 | - requests-oauthlib==1.3.0 217 | - sqlalchemy==1.4.25 218 | - tensorboardx==2.1 219 | - tensorflow-gpu==1.15.0 220 | - torch==1.6.0+cu101 221 | - torchaudio==0.8.2 222 | - torchfile==0.1.0 223 | - torchnet==0.0.4 224 | - torchvision==0.7.0+cu101 225 | - transaction==3.0.1 226 | - translationstring==1.4 227 | - urllib3==1.26.2 228 | - velruse==1.1.1 229 | - venusian==3.0.0 230 | - visdom==0.1.8.9 231 | - webob==1.8.7 232 | - websocket-client==0.57.0 233 | - wtforms==2.3.3 234 | - wtforms-recaptcha==0.3.2 235 | - zope-deprecation==4.4.0 236 | - zope-interface==5.4.0 237 | - zope-sqlalchemy==1.6 238 | prefix: /home/ironkitty/data/miniconda3/envs/torch_37 239 | -------------------------------------------------------------------------------- /train_iid.sh: -------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_denoising_gauss_iid.py --batchSize 16 -a nssnn -p 2mats_iid \ 3 | --dataroot ./datasets/ICVL64_31_2mats.db --gpu-ids 3 \ 4 | -tr datasets/test/ICVL/iid/ -gr datasets/test/ICVL/gt/ \ 5 | --lr 1e-3 6 | 7 | python hsi_denoising_gauss_iid.py --batchSize 16 -a nssnn -p iid \ 8 | --dataroot ./datasets/ICVL64_31.db --gpu-ids 3 \ 9 | -tr datasets/test/ICVL/iid/ -gr datasets/test/ICVL/gt/ \ 10 | --lr 1e-3 \ 11 | -r -rp checkpoints/nssnn/2mats_iid/model_latest.pth \ 12 | --resetepoch 15 -------------------------------------------------------------------------------- /train_mixtrue.sh: 
-------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_denoising_complex.py --batchSize 16 -a nssnn -p mxiture \ 3 | --dataroot ./datasets/ICVL64_31.db --gpu-ids 2 \ 4 | -tr datasets/test/ICVL/mixture/95_mixture/ -gr datasets/test/ICVL/gt/ \ 5 | --lr 1e-4 \ 6 | -r -rp checkpoints/nssnn/niid/model_latest.pth \ 7 | --resetepoch 50 8 | 9 | -------------------------------------------------------------------------------- /train_niid.sh: -------------------------------------------------------------------------------- 1 | conda activate torch_37 2 | python hsi_denoising_gauss_niid.py --batchSize 16 -a nssnn -p 2mats_niid \ 3 | --dataroot ./datasets/ICVL64_31_2mats.db --gpu-ids 0 \ 4 | -tr datasets/test/ICVL/niid/ -gr datasets/test/ICVL/gt/ \ 5 | --lr 1e-3 6 | 7 | python hsi_denoising_gauss_niid.py --batchSize 16 -a nssnn -p niid \ 8 | --dataroot ./datasets/ICVL64_31.db --gpu-ids 4 \ 9 | -tr datasets/test/ICVL/niid/ -gr datasets/test/ICVL/gt/ \ 10 | --lr 1e-3 \ 11 | -r -rp checkpoints/nssnn/2mats_niid/model_latest.pth \ 12 | --resetepoch 15 -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .util import * 3 | from .helper import * 4 | from .lmdb_dataset import LMDBDataset 5 | from .indexes import * 6 | from .gauss import * 7 | from .read_HSI import * 8 | from .refold import * -------------------------------------------------------------------------------- /utility/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/data_parallel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/data_parallel.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/dataloaders_hsi_test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/dataloaders_hsi_test.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/dataloaders_hsi_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/dataloaders_hsi_test.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/dataset.cpython-38.pyc 
-------------------------------------------------------------------------------- /utility/__pycache__/gauss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/gauss.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/gauss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/gauss.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/helper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/helper.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/helper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/helper.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/indexes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/indexes.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/indexes.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/indexes.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/lmdb_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/lmdb_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/lmdb_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/lmdb_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/read_HSI.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/read_HSI.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/read_HSI.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/read_HSI.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/refold.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/refold.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/refold.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/refold.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/ssim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/ssim.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/ssim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/ssim.cpython-38.pyc -------------------------------------------------------------------------------- /utility/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /utility/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lronkitty/NSSNN/16b1dded446e151458aee992f24ef8f85c2a293f/utility/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /utility/data_parallel.py: 
-------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DataParallel 2 | import torch 3 | from torch.nn.parallel._functions import Scatter 4 | from torch.nn.parallel.parallel_apply import parallel_apply 5 | 6 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 7 | r""" 8 | Slices tensors into approximately equal chunks and 9 | distributes them across given GPUs. Duplicates 10 | references to objects that are not tensors. 11 | """ 12 | def scatter_map(obj): 13 | if isinstance(obj, torch.Tensor): 14 | try: 15 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 16 | except: 17 | print('obj', obj.size()) 18 | print('dim', dim) 19 | print('chunk_sizes', chunk_sizes) 20 | quit() 21 | if isinstance(obj, tuple) and len(obj) > 0: 22 | return list(zip(*map(scatter_map, obj))) 23 | if isinstance(obj, list) and len(obj) > 0: 24 | return list(map(list, zip(*map(scatter_map, obj)))) 25 | if isinstance(obj, dict) and len(obj) > 0: 26 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 27 | return [obj for targets in target_gpus] 28 | 29 | # After scatter_map is called, a scatter_map cell will exist. This cell 30 | # has a reference to the actual function scatter_map, which has references 31 | # to a closure that has a reference to the scatter_map cell (because the 32 | # fn is recursive). 
To avoid this reference cycle, we set the function to 33 | # None, clearing the cell 34 | try: 35 | return scatter_map(inputs) 36 | finally: 37 | scatter_map = None 38 | 39 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 40 | r"""Scatter with support for kwargs dictionary""" 41 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 42 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 43 | if len(inputs) < len(kwargs): 44 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 45 | elif len(kwargs) < len(inputs): 46 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 47 | inputs = tuple(inputs) 48 | kwargs = tuple(kwargs) 49 | return inputs, kwargs 50 | 51 | class BalancedDataParallel(DataParallel): 52 | def __init__(self, gpu0_bsz, *args, **kwargs): 53 | self.gpu0_bsz = gpu0_bsz 54 | super().__init__(*args, **kwargs) 55 | 56 | def forward(self, *inputs, **kwargs): 57 | if not self.device_ids: 58 | return self.module(*inputs, **kwargs) 59 | if self.gpu0_bsz == 0: 60 | device_ids = self.device_ids[1:] 61 | else: 62 | device_ids = self.device_ids 63 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 64 | if len(self.device_ids) == 1: 65 | return self.module(*inputs[0], **kwargs[0]) 66 | replicas = self.replicate(self.module, self.device_ids) 67 | if self.gpu0_bsz == 0: 68 | replicas = replicas[1:] 69 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 70 | return self.gather(outputs, self.output_device) 71 | 72 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 73 | return parallel_apply(replicas, inputs, kwargs, device_ids) 74 | 75 | def scatter(self, inputs, kwargs, device_ids): 76 | bsz = inputs[0].size(self.dim) 77 | num_dev = len(self.device_ids) 78 | gpu0_bsz = self.gpu0_bsz 79 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 80 | if gpu0_bsz < bsz_unit: 81 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 82 | delta = bsz - 
sum(chunk_sizes) 83 | for i in range(delta): 84 | chunk_sizes[i + 1] += 1 85 | if gpu0_bsz == 0: 86 | chunk_sizes = chunk_sizes[1:] 87 | else: 88 | return super().scatter(inputs, kwargs, device_ids) 89 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) -------------------------------------------------------------------------------- /utility/dataloaders_hsi_test.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import Dataset 3 | from os import listdir, path 4 | from PIL import Image 5 | import torch 6 | import math 7 | import torchvision.transforms.functional as TF 8 | import random 9 | from typing import Sequence 10 | from itertools import repeat 11 | import scipy.io as scio 12 | import numpy as np 13 | import torch 14 | import re 15 | from torch._six import container_abcs, string_classes, int_classes 16 | 17 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 18 | def repeater(data_loader): 19 | for loader in repeat(data_loader): 20 | for data in loader: 21 | yield data 22 | class MyResize: 23 | def __init__(self, scale,crop): 24 | self.scale = scale 25 | self.crop = crop 26 | 27 | 28 | def __call__(self, x): 29 | bands = x.shape[2] 30 | if bands > 31: 31 | bs = int(np.random.rand(1) * bands) 32 | if bs + 31 > bands: 33 | bs = bands - 31 34 | x = x[:, :, bs:bs + 31] 35 | im_sz=x.shape 36 | rs=[int(im_sz[0]*self.scale),int(im_sz[1]*self.scale)] 37 | if rs[0] _w: 126 | x2 = _w 127 | x = _w - self.size 128 | if y2 > _h: 129 | y2 = _h 130 | y = _h - self.size 131 | cropImg = img[(x):(x2), (y):(y2), :] 132 | return cropImg 133 | 134 | # return self.cropit(img,self.size) 135 | # return img 136 | def cropit(image, crop_size): 137 | _w, _h, _b = image.shape 138 | x = random.randint(1, _w) 139 | y = random.randint(1, _h) 140 | x2 = x + crop_size 141 | y2 = y + crop_size 142 | if x2 > _w: 143 | x2 = _w 144 | x = _w - crop_size 145 | if y2 > _h: 146 
| y2 = _h 147 | y = _h - crop_size 148 | cropImg = image[(x):(x2), (y):(y2), :] 149 | return cropImg 150 | class MyToTensor(object): 151 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 152 | 153 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 154 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 155 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 156 | or if the numpy.ndarray has dtype = np.uint8 157 | 158 | In the other cases, tensors are returned without scaling. 159 | """ 160 | 161 | def __call__(self, pic): 162 | """ 163 | Args: 164 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 165 | 166 | Returns: 167 | Tensor: Converted image. 168 | """ 169 | return TF.to_tensor(pic.copy()) 170 | 171 | def __repr__(self): 172 | return self.__class__.__name__ + '()' 173 | 174 | 175 | 176 | 177 | class Dataset(Dataset): 178 | def __init__(self, root_dirs, transform=None, verbose=False, grey=False): 179 | self.root_dirs = root_dirs 180 | self.transform = transform 181 | self.images_path = [] 182 | for cur_path in root_dirs: 183 | self.images_path += [path.join(cur_path, file) for file in listdir(cur_path) if file.endswith(('tif','png','jpg','jpeg','bmp','mat'))] 184 | self.verbose = verbose 185 | self.grey = grey 186 | 187 | def __len__(self): 188 | return len(self.images_path) 189 | 190 | def __getitem__(self, idx): 191 | img_name = self.images_path[idx] 192 | 193 | if self.grey: 194 | image = Image.open(img_name).convert('L') 195 | else: 196 | # image = Image.open(img_name).convert('RGB') 197 | image = scio.loadmat(img_name)['DataCube'].astype(np.float32) 198 | # image=image/image.max() 199 | # image = flipit(flipit(cropit(image,crop_size=128),[0,1]),[1,0]) 200 | 201 | # image=transforms.ToPILImage(image) 202 | if self.transform: 203 | image = self.transform(image) 204 | 205 | 206 | if self.verbose: 207 | return image, img_name.split('/')[-1] 208 | 209 | 
return image 210 | def get_gt(gt_path, img_name,verbose=False, grey=False): 211 | tfs = [] 212 | tfs += [ 213 | # MyRotation90(), 214 | # MyCenterCrop(), 215 | MyToTensor() 216 | ] 217 | gt_transforms = transforms.Compose(tfs) 218 | image = scio.loadmat(gt_path+img_name)['DataCube'].astype(np.float32) 219 | image = gt_transforms(image) 220 | # image=image/image.max() 221 | return image 222 | 223 | def get_dataloaders(test_path_list, crop_size=96, batch_size=1, downscale=0, 224 | drop_last=True, concat=True, n_worker=0, scale_min=0.001, scale_max=0.1, verbose=False, grey=False): 225 | 226 | batch_sizes = {'test':1, 'gt': 1} 227 | test_transforms = transforms.Compose([MyToTensor()]) 228 | data_transforms = {'test': test_transforms} 229 | image_datasets = {'test': Dataset(test_path_list, data_transforms['test'], verbose=verbose, grey=grey)} 230 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_sizes[x], 231 | num_workers=n_worker,drop_last=drop_last, shuffle=False) for x in ['test']} 232 | return dataloaders 233 | 234 | def flipit(image, axes): 235 | 236 | if axes[0]: 237 | image = np.fliplr(image) 238 | if axes[1]: 239 | image = np.flipud(image) 240 | 241 | return image 242 | default_collate_err_msg_format = ( 243 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 244 | "dicts or lists; found {}") 245 | 246 | 247 | # def cropit(image, seg=None, margin=5): 248 | # 249 | # fixedaxes = np.argmin(image.shape[:2]) 250 | # trimaxes = 0 if fixedaxes == 1 else 1 251 | # trim = image.shape[fixedaxes] 252 | # center = image.shape[trimaxes] // 2 253 | # if seg is not None: 254 | # 255 | # hits = np.where(seg != 0) 256 | # mins = np.argmin(hits, axis=1) 257 | # maxs = np.argmax(hits, axis=1) 258 | # 259 | # if center - (trim // 2) > mins[0]: 260 | # while center - (trim // 2) > mins[0]: 261 | # center = center - 1 262 | # center = center + margin 263 | # 264 | # if center + (trim // 2) < maxs[0]: 265 | # while center + 
(trim // 2) < maxs[0]: 266 | # center = center + 1 267 | # center = center + margin 268 | # 269 | # top = max(0, center - (trim // 2)) 270 | # bottom = trim if top == 0 else center + (trim // 2) 271 | # 272 | # if bottom > image.shape[trimaxes]: 273 | # bottom = image.shape[trimaxes] 274 | # top = image.shape[trimaxes] - trim 275 | # 276 | # if trimaxes == 0: 277 | # image = image[top: bottom, :, :] 278 | # else: 279 | # image = image[:, top: bottom, :] 280 | # 281 | # if seg is not None: 282 | # if trimaxes == 0: 283 | # seg = seg[top: bottom, :, :] 284 | # else: 285 | # seg = seg[:, top: bottom, :] 286 | # 287 | # return image, seg 288 | # else: 289 | # return image 290 | 291 | -------------------------------------------------------------------------------- /utility/gauss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module providing functionality surrounding gaussian function. 3 | """ 4 | SVN_REVISION = '$LastChangedRevision: 16541 $' 5 | 6 | import sys 7 | import numpy 8 | 9 | 10 | def gaussian2(size, sigma): 11 | """Returns a normalized circularly symmetric 2D gauss kernel array 12 | 13 | f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where 14 | 15 | A = 1/(2*pi*sigma^2) 16 | 17 | as define by Wolfram Mathworld 18 | http://mathworld.wolfram.com/GaussianFunction.html 19 | """ 20 | A = 1 / (2.0 * numpy.pi * sigma ** 2) 21 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 22 | g = A * numpy.exp(-((x ** 2 / (2.0 * sigma ** 2)) + (y ** 2 / (2.0 * sigma ** 2)))) 23 | return g 24 | 25 | 26 | def fspecial_gauss(size, sigma): 27 | """Function to mimic the 'fspecial' gaussian MATLAB function 28 | """ 29 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 30 | g = numpy.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) 31 | return g / g.sum() 32 | 33 | 34 | def main(): 35 | """Show simple use cases for functionality provided by this module.""" 
36 | from mpl_toolkits.mplot3d.axes3d import Axes3D 37 | import pylab 38 | argv = sys.argv 39 | if len(argv) != 3: 40 | print >> sys.stderr, 'usage: python -m pim.sp.gauss size sigma' 41 | sys.exit(2) 42 | size = int(argv[1]) 43 | sigma = float(argv[2]) 44 | x, y = numpy.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 45 | 46 | fig = pylab.figure() 47 | fig.suptitle('Some 2-D Gauss Functions') 48 | ax = fig.add_subplot(2, 1, 1, projection='3d') 49 | ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1, 50 | linewidth=0, antialiased=False, cmap=pylab.jet()) 51 | ax = fig.add_subplot(2, 1, 2, projection='3d') 52 | ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1, 53 | linewidth=0, antialiased=False, cmap=pylab.jet()) 54 | pylab.show() 55 | return 0 56 | 57 | 58 | if __name__ == '__main__': 59 | sys.exit(main()) 60 | # {"mode": "full", "isActive": false}% -------------------------------------------------------------------------------- /utility/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | 10 | from tensorboardX import SummaryWriter 11 | import socket 12 | from datetime import datetime 13 | 14 | from models.sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 15 | 16 | 17 | def adjust_learning_rate(optimizer, lr): 18 | print('Adjust Learning Rate => %.4e' %lr) 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = lr 21 | # param_group['initial_lr'] = lr 22 | 23 | 24 | def display_learning_rate(optimizer): 25 | lrs = [] 26 | for i, param_group in enumerate(optimizer.param_groups): 27 | lr = param_group['lr'] 28 | print('learning rate of group %d: %.4e' % (i, lr)) 29 | lrs.append(lr) 30 | return lrs 31 | 32 | 33 | def adjust_opt_params(optimizer, param_dict): 34 | print('Adjust Optimizer 
Parameters => %s' %param_dict) 35 | for param_group in optimizer.param_groups: 36 | for k, v in param_dict.items(): 37 | param_group[k] = v 38 | 39 | 40 | def display_opt_params(optimizer, keys): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | for k in keys: 43 | v = param_group[k] 44 | print('%s of group %d: %.4e' % (k,i,v)) 45 | 46 | 47 | def set_bn_eval(m): 48 | classname = m.__class__.__name__ 49 | if classname.find('BatchNorm') != -1: 50 | m.weight.requires_grad = False 51 | m.bias.requires_grad = False 52 | m.eval() 53 | 54 | 55 | def get_summary_writer(log_dir, prefix=None): 56 | # log_dir = './checkpoints/%s/logs'%(arch) 57 | if not os.path.exists(log_dir): 58 | os.mkdir(log_dir) 59 | if prefix is None: 60 | log_dir = os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S')+'_'+socket.gethostname()) 61 | else: 62 | log_dir = os.path.join(log_dir, prefix+'_'+datetime.now().strftime('%b%d_%H-%M-%S')+'_'+socket.gethostname()) 63 | if not os.path.exists(log_dir): 64 | os.mkdir(log_dir) 65 | writer = SummaryWriter(log_dir) 66 | return writer 67 | 68 | 69 | def init_params(net, init_type='kn'): 70 | print('use init scheme: %s' %init_type) 71 | if init_type != 'edsr': 72 | for m in net.modules(): 73 | if isinstance(m, (nn.Conv2d, nn.Conv3d)): 74 | if init_type == 'kn': 75 | init.kaiming_normal_(m.weight, mode='fan_out') 76 | if init_type == 'ku': 77 | init.kaiming_uniform_(m.weight, mode='fan_out') 78 | if init_type == 'xn': 79 | init.xavier_normal_(m.weight) 80 | if init_type == 'xu': 81 | init.xavier_uniform_(m.weight) 82 | if m.bias is not None: 83 | init.constant_(m.bias, 0) 84 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d)): 85 | init.constant_(m.weight, 1) 86 | if m.bias is not None: 87 | init.constant_(m.bias, 0) 88 | elif isinstance(m, nn.Linear): 89 | init.normal_(m.weight, std=1e-3) 90 | if m.bias is not None: 91 | init.constant_(m.bias, 0) 92 | 93 | 94 | _, term_width = 
os.popen('stty size', 'r').read().split() 95 | term_width = int(term_width) 96 | 97 | TOTAL_BAR_LENGTH = 30. 98 | last_time = time.time() 99 | begin_time = last_time 100 | def progress_bar(current, total, msg=None): 101 | global last_time, begin_time 102 | if current == 0: 103 | begin_time = time.time() # Reset for new bar. 104 | 105 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 106 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 107 | 108 | sys.stdout.write(' [') 109 | for i in range(cur_len): 110 | sys.stdout.write('=') 111 | sys.stdout.write('>') 112 | for i in range(rest_len): 113 | sys.stdout.write('.') 114 | sys.stdout.write(']') 115 | 116 | cur_time = time.time() 117 | step_time = cur_time - last_time 118 | last_time = cur_time 119 | tot_time = cur_time - begin_time 120 | 121 | L = [] 122 | L.append(' Step: %s' % format_time(step_time)) 123 | L.append(' | Tot: %s' % format_time(tot_time)) 124 | if msg: 125 | L.append(' | ' + msg) 126 | 127 | msg = ''.join(L) 128 | sys.stdout.write(msg) 129 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 130 | sys.stdout.write(' ') 131 | 132 | # Go back to the center of the bar. 
133 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 134 | sys.stdout.write('\b') 135 | sys.stdout.write(' %d/%d ' % (current+1, total)) 136 | 137 | if current < total-1: 138 | sys.stdout.write('\r') 139 | else: 140 | sys.stdout.write('\n') 141 | sys.stdout.flush() 142 | 143 | def format_time(seconds): 144 | days = int(seconds / 3600/24) 145 | seconds = seconds - days*3600*24 146 | hours = int(seconds / 3600) 147 | seconds = seconds - hours*3600 148 | minutes = int(seconds / 60) 149 | seconds = seconds - minutes*60 150 | secondsf = int(seconds) 151 | seconds = seconds - secondsf 152 | millis = int(seconds*1000) 153 | 154 | f = '' 155 | i = 1 156 | if days > 0: 157 | f += str(days) + 'D' 158 | i += 1 159 | if hours > 0 and i <= 2: 160 | f += str(hours) + 'h' 161 | i += 1 162 | if minutes > 0 and i <= 2: 163 | f += str(minutes) + 'm' 164 | i += 1 165 | if secondsf > 0 and i <= 2: 166 | f += str(secondsf) + 's' 167 | i += 1 168 | if millis > 0 and i <= 2: 169 | f += str(millis) + 'ms' 170 | i += 1 171 | if f == '': 172 | f = '0ms' 173 | return f 174 | -------------------------------------------------------------------------------- /utility/indexes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage.measure import compare_ssim, compare_psnr 4 | from functools import partial 5 | from utility.gauss import fspecial_gauss 6 | from scipy import signal 7 | 8 | class Bandwise(object): 9 | def __init__(self, index_fn): 10 | self.index_fn = index_fn 11 | 12 | def __call__(self, X, Y): 13 | C = X.shape[-3] 14 | bwindex = [] 15 | for ch in range(C): 16 | x = torch.squeeze(X[...,ch,:,:].data).cpu().numpy() 17 | y = torch.squeeze(Y[...,ch,:,:].data).cpu().numpy() 18 | index = self.index_fn(x, y) 19 | bwindex.append(index) 20 | return bwindex 21 | 22 | def ssim(img1, img2, cs_map=False): 23 | """Return the Structural Similarity Map corresponding to input images img1 24 | and img2 (images 
are assumed to be uint8) 25 | 26 | This function attempts to mimic precisely the functionality of ssim.m a 27 | MATLAB provided by the author's of SSIM 28 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m 29 | """ 30 | img1 = img1.astype(np.float64) 31 | img2 = img2.astype(np.float64) 32 | size = 11 33 | sigma = 1.5 34 | window = fspecial_gauss(size, sigma) 35 | K1 = 0.01 36 | K2 = 0.03 37 | L = 255 # bitdepth of image 38 | C1 = (K1 * L) ** 2 39 | C2 = (K2 * L) ** 2 40 | mu1 = signal.fftconvolve(window, img1, mode='valid') 41 | mu2 = signal.fftconvolve(window, img2, mode='valid') 42 | mu1_sq = mu1 * mu1 43 | mu2_sq = mu2 * mu2 44 | mu1_mu2 = mu1 * mu2 45 | sigma1_sq = signal.fftconvolve(window, img1 * img1, mode='valid') - mu1_sq 46 | sigma2_sq = signal.fftconvolve(window, img2 * img2, mode='valid') - mu2_sq 47 | sigma12 = signal.fftconvolve(window, img1 * img2, mode='valid') - mu1_mu2 48 | if cs_map: 49 | return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 50 | (sigma1_sq + sigma2_sq + C2)), 51 | (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) 52 | else: 53 | return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 54 | (sigma1_sq + sigma2_sq + C2)) 55 | def mse (GT,P): 56 | """calculates mean squared error (mse). 57 | 58 | :param GT: first (original) input image. 59 | :param P: second (deformed) input image. 60 | 61 | :returns: float -- mse value. 
62 | """ 63 | # GT,P = _initial_check(GT,P) 64 | return np.mean((GT.astype(np.float32)-P.astype(np.float32))**2) 65 | cal_bwssim = Bandwise(compare_ssim) 66 | cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=1)) 67 | 68 | 69 | def cal_sam(X, Y, eps=1e-8): 70 | X = torch.squeeze(X.data).cpu().numpy() 71 | Y = torch.squeeze(Y.data).cpu().numpy() 72 | tmp = (np.sum(X*Y, axis=0) + eps) /( (np.sqrt(np.sum(X**2, axis=0)))* (np.sqrt(np.sum(Y**2, axis=0))) + eps) 73 | return np.mean(np.real(np.arccos(tmp))) 74 | 75 | def cal_ssim(im_true,im_test,eps=13-8): 76 | # print(im_true.shape) 77 | im_true=im_true.squeeze(0).squeeze(0).cpu().numpy() 78 | im_test = im_test.squeeze(0).squeeze(0).cpu().numpy() 79 | c,_,_=im_true.shape 80 | bwindex = [] 81 | for i in range(c): 82 | bwindex.append(ssim(im_true[i,:,:]*255, im_test[i,:,:,]*255)) 83 | return np.mean(bwindex) 84 | def MSIQA(X, Y): 85 | 86 | psnr = np.mean(cal_bwpsnr(X, Y)) 87 | ssim = cal_ssim(Y,X) 88 | sam = cal_sam(X, Y) 89 | return psnr, ssim, sam 90 | -------------------------------------------------------------------------------- /utility/indexes_back.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage.measure import compare_ssim, compare_psnr 4 | from functools import partial 5 | 6 | 7 | class Bandwise(object): 8 | def __init__(self, index_fn): 9 | self.index_fn = index_fn 10 | 11 | def __call__(self, X, Y): 12 | C = X.shape[-3] 13 | bwindex = [] 14 | for ch in range(C): 15 | x = torch.squeeze(X[...,ch,:,:].data).cpu().numpy() 16 | y = torch.squeeze(Y[...,ch,:,:].data).cpu().numpy() 17 | index = self.index_fn(x, y) 18 | bwindex.append(index) 19 | return bwindex 20 | 21 | 22 | cal_bwssim = Bandwise(compare_ssim) 23 | cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=1)) 24 | 25 | 26 | def cal_sam(X, Y, eps=1e-8): 27 | X = torch.squeeze(X.data).cpu().numpy() 28 | Y = torch.squeeze(Y.data).cpu().numpy() 29 | tmp = 
(np.sum(X*Y, axis=0) + eps) / (np.sqrt(np.sum(X**2, axis=0)) + eps) / (np.sqrt(np.sum(Y**2, axis=0)) + eps) 30 | return np.mean(np.real(np.arccos(tmp))) 31 | 32 | 33 | def MSIQA(X, Y): 34 | psnr = np.mean(cal_bwpsnr(X, Y)) 35 | ssim = np.mean(cal_bwssim(X, Y)) 36 | sam = cal_sam(X, Y) 37 | return psnr, ssim, sam 38 | -------------------------------------------------------------------------------- /utility/lmdb_data.py: -------------------------------------------------------------------------------- 1 | """Create lmdb dataset""" 2 | from termios import XCASE 3 | from util import * 4 | import lmdb 5 | import caffe 6 | import scipy.io 7 | 8 | def create_lmdb_train( 9 | datadir, fns, name, matkey, 10 | crop_sizes, scales, ksizes, strides, 11 | load=h5py.File, augment=True, 12 | seed=2017,trans=1,norm=0,map_size = None): 13 | """ 14 | Create Augmented Dataset 15 | """ 16 | def preprocess(data): 17 | new_data = [] 18 | if trans == 1: 19 | data[data < 0] = 0 20 | data = minmax_normalize(data) 21 | data = np.rot90(data, k=2, axes=(1,2)) # ICVL 22 | # data = minmax_normalize(data.transpose((2,0,1))) # for Remote Sensing 23 | # Visualize3D(data) 24 | if crop_sizes is not None: 25 | data = crop_center(data, crop_sizes[0], crop_sizes[1]) 26 | 27 | for i in range(len(scales)): 28 | if scales[i] != 1: 29 | temp = zoom(data, zoom=(1, scales[i], scales[i])) 30 | else: 31 | temp = data 32 | # print(temp.shape) 33 | temp = Data2Volume(temp, ksizes=ksizes, strides=list(strides[i])) 34 | new_data.append(temp) 35 | new_data = np.concatenate(new_data, axis=0) 36 | if augment: 37 | for i in range(new_data.shape[0]): 38 | new_data[i,...] 
= data_augmentation(new_data[i, ...]) 39 | 40 | return new_data.astype(np.float32) 41 | 42 | np.random.seed(seed) 43 | scales = list(scales) 44 | ksizes = list(ksizes) 45 | assert len(scales) == len(strides) 46 | # calculate the shape of dataset 47 | data = load(datadir + fns[0]) 48 | data= data[matkey] 49 | if trans: 50 | data = np.transpose(data,(2,0,1)) 51 | # print(data.shape) 52 | if map_size is None: 53 | data = preprocess(data) 54 | N = data.shape[0] 55 | 56 | # print(data.shape) 57 | map_size = data.nbytes * len(fns) * 1.2 58 | print('map size (GB):', map_size / 1024 / 1024 / 1024) 59 | 60 | #import ipdb#; ipdb.set_trace() 61 | if os.path.exists(name+'.db'): 62 | raise Exception('database already exist!') 63 | env = lmdb.open(name+'.db', map_size=map_size, writemap=True) 64 | with env.begin(write=True) as txn: 65 | # txn is a Transaction object 66 | k = 0 67 | for i, fn in enumerate(fns): 68 | try: 69 | X = load(datadir + fn)[matkey] 70 | if trans: 71 | # print(X.shape) 72 | X = np.transpose(X,(2,0,1)) 73 | except: 74 | print('loading', datadir+fn, 'fail') 75 | continue 76 | X = preprocess(X) 77 | N = X.shape[0] 78 | for j in range(N): 79 | # print(X[j].max(), X[j].min()) 80 | if X[j].min() < -100: 81 | continue 82 | elif X[j].max() == 0: 83 | continue 84 | datum = caffe.proto.caffe_pb2.Datum() 85 | datum.channels = X.shape[1] 86 | datum.height = X.shape[2] 87 | datum.width = X.shape[3] 88 | # print(X[j].max(), X[j].min()) 89 | if norm == 0: 90 | datum.data = X[j].tobytes() 91 | else: 92 | datum.data = minmax_normalize(X[j]).tobytes() 93 | str_id = '{:08}'.format(k) 94 | k += 1 95 | txn.put(str_id.encode('ascii'), datum.SerializeToString()) 96 | print('load mat (%d/%d): %s' %(i,len(fns),fn)) 97 | print(k) 98 | 99 | print('done') 100 | 101 | 102 | # Create 2mats ICVL training dataset 103 | def create_icvl64_31_2mats(): 104 | print('create icvl64_31...') 105 | datadir = 'datasets/training/ICVL/train2_2mats/' # your own data address 106 | fns = 
os.listdir(datadir) 107 | fns = [fn.split('.')[0]+'.mat' for fn in fns] 108 | 109 | create_lmdb_train( 110 | datadir, fns, 'datasets/ICVL64_31_2mats', 'rad', # your own dataset address 111 | crop_sizes=(1024, 1024), 112 | scales=(1, 0.5, 0.25), 113 | ksizes=(31, 64, 64), 114 | strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 115 | load=h5py.File, augment=True,trans=0 116 | ) 117 | 118 | # Create ICVL training dataset 119 | def create_icvl64_31(): 120 | print('create icvl64_31...') 121 | datadir = 'datasets/training/ICVL/train/' # your own data address 122 | fns = os.listdir(datadir) 123 | fns = [fn.split('.')[0]+'.mat' for fn in fns] 124 | 125 | create_lmdb_train( 126 | datadir, fns, 'datasets/ICVL64_31', 'rad', # your own dataset address 127 | crop_sizes=(1024, 1024), 128 | scales=(1, 0.5, 0.25), 129 | ksizes=(31, 64, 64), 130 | strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 131 | load=h5py.File, augment=True,trans=0 132 | ) 133 | 134 | def create_houston64_stride_24_46_norm(): 135 | print('create houston64_46_norm...') 136 | datadir = 'datasets/training/Houston/' # your own data address 137 | fns = os.listdir(datadir) 138 | fns = [fn.split('.')[0]+'.mat' for fn in fns] 139 | 140 | create_lmdb_train( 141 | datadir, fns, 'datasets/houston64_stride_24_46_norm', 'houston', # your own dataset address 142 | crop_sizes=None, 143 | scales=(1, 0.5, 0.25), 144 | ksizes=(46, 64, 64), 145 | strides=[(10, 24, 24), (5, 12, 12), (5, 12, 12)], 146 | load=scipy.io.loadmat, augment=True,norm=1,map_size =15827848396 147 | ) 148 | 149 | if __name__ == '__main__': 150 | # create_icvl64_31_2mats() 151 | # create_icvl64_31() 152 | create_houston64_stride_24_46_norm() 153 | pass 154 | -------------------------------------------------------------------------------- /utility/lmdb_data_ori.py: -------------------------------------------------------------------------------- 1 | """Create lmdb dataset""" 2 | from util import * 3 | import lmdb 4 | import caffe 5 | 6 | 7 | def 
create_lmdb_train( 8 | datadir, fns, name, matkey, 9 | crop_sizes, scales, ksizes, strides, 10 | load=h5py.File, augment=True, 11 | seed=2017): 12 | """ 13 | Create Augmented Dataset 14 | """ 15 | def preprocess(data): 16 | new_data = [] 17 | data = minmax_normalize(data) 18 | data = np.rot90(data, k=2, axes=(1,2)) # ICVL 19 | # data = minmax_normalize(data.transpose((2,0,1))) # for Remote Sensing 20 | # Visualize3D(data) 21 | if crop_sizes is not None: 22 | data = crop_center(data, crop_sizes[0], crop_sizes[1]) 23 | 24 | for i in range(len(scales)): 25 | if scales[i] != 1: 26 | temp = zoom(data, zoom=(1, scales[i], scales[i])) 27 | else: 28 | temp = data 29 | temp = Data2Volume(temp, ksizes=ksizes, strides=list(strides[i])) 30 | new_data.append(temp) 31 | new_data = np.concatenate(new_data, axis=0) 32 | if augment: 33 | for i in range(new_data.shape[0]): 34 | new_data[i,...] = data_augmentation(new_data[i, ...]) 35 | 36 | return new_data.astype(np.float32) 37 | 38 | np.random.seed(seed) 39 | scales = list(scales) 40 | ksizes = list(ksizes) 41 | assert len(scales) == len(strides) 42 | # calculate the shape of dataset 43 | data = load(datadir + fns[0])[matkey] 44 | data = preprocess(data) 45 | N = data.shape[0] 46 | 47 | print(data.shape) 48 | map_size = data.nbytes * len(fns) * 1.2 49 | print('map size (GB):', map_size / 1024 / 1024 / 1024) 50 | 51 | if os.path.exists(name+'.db'): 52 | raise Exception('database already exist!') 53 | env = lmdb.open(name+'.db', map_size=map_size, writemap=True) 54 | with env.begin(write=True) as txn: 55 | # txn is a Transaction object 56 | k = 0 57 | for i, fn in enumerate(fns): 58 | try: 59 | X = load(datadir + fn)[matkey] 60 | except: 61 | print('loading', datadir+fn, 'fail') 62 | continue 63 | X = preprocess(X) 64 | N = X.shape[0] 65 | for j in range(N): 66 | datum = caffe.proto.caffe_pb2.Datum() 67 | datum.channels = X.shape[1] 68 | datum.height = X.shape[2] 69 | datum.width = X.shape[3] 70 | datum.data = X[j].tobytes() 71 | 
str_id = '{:08}'.format(k) 72 | k += 1 73 | txn.put(str_id.encode('ascii'), datum.SerializeToString()) 74 | print('load mat (%d/%d): %s' %(i,len(fns),fn)) 75 | 76 | print('done') 77 | 78 | 79 | # Create Pavia Centre dataset 80 | def create_PaviaCentre(): 81 | print('create Pavia Centre...') 82 | datadir = './data/PaviaCentre/' 83 | fns = os.listdir(datadir) 84 | fns = [fn.split('.')[0]+'.mat' for fn in fns] 85 | 86 | create_lmdb_train( 87 | datadir, fns, '/home/kaixuan/Dataset/PaviaCentre', 'hsi', # your own dataset address 88 | crop_sizes=None, 89 | scales=(1,), 90 | ksizes=(101, 64, 64), 91 | strides=[(101, 32, 32)], 92 | load=loadmat, augment=True, 93 | ) 94 | 95 | # Create ICVL training dataset 96 | def create_icvl64_31(): 97 | print('create icvl64_31...') 98 | datadir = 'datasets/training/ICVL/train/' # your own data address 99 | fns = os.listdir(datadir) 100 | fns = [fn.split('.')[0]+'.mat' for fn in fns] 101 | 102 | create_lmdb_train( 103 | datadir, fns, 'datasets/ICVL64_31', 'rad', # your own dataset address 104 | crop_sizes=(1024, 1024), 105 | scales=(1, 0.5, 0.25), 106 | ksizes=(31, 64, 64), 107 | strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 108 | load=h5py.File, augment=True, 109 | ) 110 | 111 | 112 | if __name__ == '__main__': 113 | create_icvl64_31() 114 | # create_PaviaCentre() 115 | pass 116 | -------------------------------------------------------------------------------- /utility/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import six 7 | import string 8 | import sys 9 | import caffe 10 | if sys.version_info[0] == 2: 11 | import cPickle as pickle 12 | else: 13 | import pickle 14 | 15 | 16 | class LMDBDataset(data.Dataset): 17 | def __init__(self, db_path, repeat=1): 18 | import lmdb 19 | self.db_path = db_path 20 | self.env = lmdb.open(db_path, max_readers=1, 
readonly=True, lock=False, 21 | readahead=False, meminit=False) 22 | with self.env.begin(write=False) as txn: 23 | self.length = txn.stat()['entries'] 24 | self.repeat = repeat 25 | # cache_file = '_cache_' + db_path.replace('/', '_') 26 | # if os.path.isfile(cache_file): 27 | # self.keys = pickle.load(open(cache_file, "rb")) 28 | # else: 29 | # with self.env.begin(write=False) as txn: 30 | # self.keys = [key for key, _ in txn.cursor()] 31 | # pickle.dump(self.keys, open(cache_file, "wb")) 32 | 33 | def __getitem__(self, index): 34 | index = index % self.length 35 | env = self.env 36 | with env.begin(write=False) as txn: 37 | raw_datum = txn.get('{:08}'.format(index).encode('ascii')) 38 | 39 | datum = caffe.proto.caffe_pb2.Datum() 40 | datum.ParseFromString(raw_datum) 41 | 42 | flat_x = np.fromstring(datum.data, dtype=np.float32) 43 | # flat_x = np.fromstring(datum.data, dtype=np.float64) 44 | x = flat_x.reshape(datum.channels, datum.height, datum.width) 45 | 46 | return x 47 | 48 | def __len__(self): 49 | return self.length * self.repeat 50 | 51 | def __repr__(self): 52 | return self.__class__.__name__ + ' (' + self.db_path + ')' 53 | 54 | def savedataset(): 55 | dataset = LMDBDataset('/home/ironkitty/nas_data/datasets/houston/houston512_46.db') 56 | import scipy.io as io 57 | for i in range(len(dataset)): 58 | x = dataset[i] 59 | x = x.transpose(1, 2, 0) 60 | io.savemat("/home/ironkitty/nas_data/datasets/houston/houston512_46_mac/houston_ori_"+str(i)+".mat", {'DataCube':x}) 61 | 62 | if __name__ == '__main__': 63 | # dataset = LMDBDataset('Data/ICVL/ICVL32.db') 64 | # dataset = LMDBDataset('/home/kaixuan/Dataset/ICVL32_28.db') 65 | # dataset = LMDBDataset('/home/kaixuan/Dataset/CAVE512_31.db') 66 | savedataset() 67 | # dataset = LMDBDataset('/home/ironkitty/nas_data/datasets/houston/houston64_stride_24_46.db') 68 | 69 | # print(len(dataset)) 70 | # data = dataset[3] 71 | # print(data.shape) 72 | # print(np.max(data[...]), np.min(data[...])) 73 | # from util 
import Visualize3D 74 | # Visualize3D(data) 75 | 76 | # train_loader = data.DataLoader(dataset, batch_size=128, num_workers=4) 77 | # print(iter(train_loader).next().shape) 78 | -------------------------------------------------------------------------------- /utility/mat_data.py: -------------------------------------------------------------------------------- 1 | """generate testing mat dataset""" 2 | import os 3 | import numpy as np 4 | import h5py 5 | from os.path import join, exists 6 | from scipy.io import loadmat, savemat 7 | 8 | from util import crop_center, Visualize3D, minmax_normalize 9 | 10 | 11 | def create_mat_dataset(datadir, fnames, newdir, matkey, func=None, load=h5py.File): 12 | if not exists(newdir): 13 | os.mkdir(newdir) 14 | 15 | for i, fn in enumerate(fnames): 16 | print('generate data(%d/%d)' %(i+1, len(fnames))) 17 | filepath = join(datadir, fn) 18 | mat = load(filepath) 19 | 20 | data = func(mat[matkey][...]) 21 | # Visualize3D(data) 22 | # import ipdb; ipdb.set_trace() 23 | 24 | # if not exists(join(newdir, fn)): 25 | savemat(join(newdir, fn), {'data':data.transpose((2,1,0))}) 26 | 27 | 28 | def create_icvl_sr(): 29 | basedir = '/media/kaixuan/DATA/Papers/Code/Matlab/ITSReg/code of ITSReg MSI denoising/data' 30 | datadir = join(basedir, 'icvl_test') 31 | newdir = join(basedir, 'icvl_256_sr') 32 | fnames = os.listdir(datadir) 33 | 34 | def func(data): 35 | data = np.rot90(data, k=-1, axes=(1,2)) 36 | 37 | data = crop_center(data, 256, 256) 38 | 39 | data = minmax_normalize(data) 40 | return data 41 | 42 | create_mat_dataset(datadir, fnames, newdir, 'rad', func=func) 43 | 44 | 45 | if __name__ == '__main__': 46 | # create_icvl_sr() 47 | pass 48 | -------------------------------------------------------------------------------- /utility/read_HSI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def read_HSI(data,kernel_size=(56,56,31), stride=(15,15,15),device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): 5 | import torch 6 | from models import im2col 7 | data_shape = [] 8 | data_shape.append(data.shape) 9 | #device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu") 10 | # print(data.shape[0],kernel_size[0],stride[0]) 11 | # if data.shape[0] < 31: 12 | if False: 13 | pad_x = 31 -data.shape[0] 14 | else: 15 | pad_x = (-data.shape[0]+kernel_size[0])%stride[0] 16 | # pad_x = 15 17 | pad_y = (-data.shape[1]+kernel_size[1])%stride[1] 18 | #pad_z =0 19 | pad_z = (-data.shape[2]+kernel_size[2])%stride[2] 20 | # print(pad_x,pad_y,pad_z) 21 | data =np.pad(data,((0, pad_x),(0, pad_y),(0,pad_z)),'edge') 22 | # print(data.shape) 23 | data = torch.from_numpy(data).to(device) 24 | data_shape.append(data.shape) 25 | #pad = torch.nn.ReplicationPad3d([0, pad_right, 0, pad_down,0,pad_z]) 26 | #data = pad(data) 27 | col_data = im2col.Cube2Col(data.reshape((1,data.shape[0],data.shape[1],data.shape[2])),kernel_size, stride,padding=0,tensorized=True,device=device) 28 | # print(col_data.shape) 29 | data_shape.append(col_data.shape) 30 | #print(col_data[0][:][-1][-1][-1]) 31 | 32 | col_data = col_data.view(1,kernel_size[0],kernel_size[1],kernel_size[2],col_data.shape[2],col_data.shape[3],col_data.shape[4]) #[1,56,56,31,18,18,1] 33 | data_shape.append(col_data.shape) 34 | col_data = col_data.permute(0, 4, 5, 6, 1, 2, 3) # 35 | data_shape.append(col_data.shape) 36 | col_data = col_data.view(col_data.shape[0]*col_data.shape[1]*col_data.shape[2]*col_data.shape[3],1,col_data.shape[4],col_data.shape[5],col_data.shape[6]) 37 | data_shape.append(col_data.shape) 38 | #print(data_shape) 39 | #test = col_data[:,:,:,:,0,0,0] 40 | return col_data,data_shape 41 | 42 | if __name__ == "__main__": 43 | import os 44 | import sys 45 | import scipy.io as scio 46 | os.chdir(sys.path[0]) 47 | dataFiles = [] 48 | dataFiles.append('train/4cam_0411-1640-1.mat') 49 | for dataFile in dataFiles: 50 | data = 
scio.loadmat(dataFile)['DataCube'] 51 | test = np.sum(data) 52 | col_data = read_HSI(data)[0] 53 | test2 = torch.sum(col_data) 54 | col_data = col_data.cpu().numpy() 55 | #data.reshape((1,data.shape[0],data.shape[1],data.shape[2])) 56 | scio.savemat('col_Data.mat', {'colData':col_data}) 57 | pass -------------------------------------------------------------------------------- /utility/readme.py: -------------------------------------------------------------------------------- 1 | from read_HSI import read_HSI 2 | from refold import refold 3 | kernel_size = (x,y,z) 4 | stride =(x_,y_,z_) 5 | col_data,data_shape = read_HSI('''cube''',kernel_size=kernel_size,stride=stride) 6 | #col_data.shape = [n,1,x,y,z] 对col_data进行各种运算,保证shape不变 7 | '''cube''' = refold(col_data,data_shape=data_shape, kernel_size=kernel_size,stride=stride) 8 | -------------------------------------------------------------------------------- /utility/refold.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def refold(col_data,data_shape, kernel_size, stride,device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")): 4 | #import torch 5 | import os 6 | import sys 7 | os.chdir(sys.path[0]) 8 | from models import im2col 9 | # print(col_data.shape, data_shape[-2]) 10 | col_data = col_data.view(data_shape[-2]) 11 | #print(col_data.shape) 12 | col_data = col_data.permute(0, 4, 5, 6, 1, 2, 3) 13 | #print(col_data.shape) 14 | col_data = col_data.reshape(data_shape[-4]) 15 | #print(col_data.shape) 16 | cube_data = im2col.Col2Cube(col_data.to(device),data_shape[1], kernel_size, stride, padding=0, dilation=1, avg=True,input_tensorized=True,device=device) 17 | #print(cube_data.shape) 18 | cube_data = cube_data[0,:data_shape[0][0],:data_shape[0][1],:data_shape[0][2]] 19 | #print(cube_data.shape) 20 | 21 | return cube_data 22 | 23 | if __name__ == "__main__": 24 | import scipy.io as scio 25 | import os 26 | import sys 27 | import numpy as np 
28 | import torch 29 | os.chdir(sys.path[0]) 30 | col_data_LISTA = scio.loadmat('col_Data.mat')['colData'] 31 | #col_data_LISTA = np.transpose(col_data_LISTA) 32 | #col_data_LISTA = col_data_LISTA.reshape((1,125,66,66,118)) 33 | col_data_LISTA = torch.from_numpy(col_data_LISTA) 34 | #scio.savemat('col_data_LISTA.mat', {'colData':col_data_LISTA}) 35 | refold_data = refold(col_data_LISTA,output_size=(206, 206, 31), kernel_size=(56,56,31), stride=(15,15,15), padding=0, dilation=1, avg=True,input_tensorized=True) 36 | refold_data = refold_data[:,:200,:200,:31] 37 | refold_data=refold_data.cpu().numpy() 38 | test = np.sum(refold_data) 39 | scio.savemat('refold_data.mat', {'refoldData':refold_data}) 40 | #torch.Size([1, 125, 66, 66, 118])186890.55419061845 41 | -------------------------------------------------------------------------------- /utility/util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import cv2 6 | 7 | import h5py 8 | import os 9 | import random 10 | 11 | import threading 12 | from itertools import product 13 | from scipy.io import loadmat 14 | from functools import partial 15 | from scipy.ndimage import zoom 16 | from matplotlib.widgets import Slider 17 | from PIL import Image 18 | 19 | 20 | def Data2Volume(data, ksizes, strides): 21 | """ 22 | Construct Volumes from Original High Dimensional (D) Data 23 | """ 24 | dshape = data.shape 25 | PatNum = lambda l, k, s: (np.floor( (l - k) / s ) + 1) 26 | 27 | TotalPatNum = 1 28 | for i in range(len(ksizes)): 29 | TotalPatNum = TotalPatNum * PatNum(dshape[i], ksizes[i], strides[i]) 30 | 31 | V = np.zeros([int(TotalPatNum)]+ksizes); # create D+1 dimension volume 32 | 33 | args = [range(kz) for kz in ksizes] 34 | for s in product(*args): 35 | s1 = (slice(None),) + s 36 | s2 = tuple([slice(key, -ksizes[i]+key+1 or None, strides[i]) for i, key in enumerate(s)]) 37 | V[s1] = 
np.reshape(data[s2], (-1,)) 38 | 39 | return V 40 | 41 | def crop_center(img,cropx,cropy): 42 | _,y,x = img.shape 43 | startx = x//2-(cropx//2) 44 | starty = y//2-(cropy//2) 45 | return img[:, starty:starty+cropy,startx:startx+cropx] 46 | 47 | 48 | def rand_crop(img, cropx, cropy): 49 | _,y,x = img.shape 50 | x1 = random.randint(0, x - cropx) 51 | y1 = random.randint(0, y - cropy) 52 | return img[:, y1:y1+cropy, x1:x1+cropx] 53 | 54 | 55 | def sequetial_process(*fns): 56 | """ 57 | Integerate all process functions 58 | """ 59 | def processor(data): 60 | for f in fns: 61 | data = f(data) 62 | return data 63 | return processor 64 | 65 | 66 | def minmax_normalize(array): 67 | amin = np.min(array) 68 | amax = np.max(array) 69 | return (array - amin) / (amax - amin) 70 | 71 | 72 | def frame_diff(frames): 73 | diff_frames = frames[1:, ...] - frames[:-1, ...] 74 | return diff_frames 75 | 76 | 77 | def visualize(filename, matkey, load=loadmat, preprocess=None): 78 | """ 79 | Visualize a preprecessed hyperspectral image 80 | """ 81 | if not preprocess: 82 | preprocess = lambda identity: identity 83 | mat = load(filename) 84 | data = preprocess(mat[matkey]) 85 | # print(data.shape) 86 | # print(np.max(data), np.min(data)) 87 | 88 | data = np.squeeze(data[:,:,:]) 89 | Visualize3D(data) 90 | # Visualize3D(np.squeeze(data[:,0,:,:])) 91 | 92 | def Visualize3D(data, meta=None): 93 | data = np.squeeze(data) 94 | 95 | for ch in range(data.shape[0]): 96 | data[ch, ...] = minmax_normalize(data[ch, ...]) 97 | 98 | print(np.max(data), np.min(data)) 99 | 100 | ax = plt.subplot(111) 101 | plt.subplots_adjust(left=0.25, bottom=0.25) 102 | 103 | frame = 0 104 | # l = plt.imshow(data[frame,:,:]) 105 | 106 | l = plt.imshow(data[frame,:,:], cmap='gray') #shows 256x256 image, i.e. 
0th frame 107 | # plt.colorbar() 108 | axcolor = 'lightgoldenrodyellow' 109 | axframe = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor) 110 | sframe = Slider(axframe, 'Frame', 0, data.shape[0]-1, valinit=0) 111 | 112 | def update(val): 113 | frame = int(np.around(sframe.val)) 114 | l.set_data(data[frame,:,:]) 115 | if meta is not None: 116 | axframe.set_title(meta[frame]) 117 | 118 | sframe.on_changed(update) 119 | 120 | plt.show() 121 | 122 | 123 | def data_augmentation(image, mode=None): 124 | """ 125 | Args: 126 | image: np.ndarray, shape: C X H X W 127 | """ 128 | axes = (-2, -1) 129 | flipud = lambda x: x[:, ::-1, :] 130 | 131 | if mode is None: 132 | mode = random.randint(0, 7) 133 | if mode == 0: 134 | # original 135 | image = image 136 | elif mode == 1: 137 | # flip up and down 138 | image = flipud(image) 139 | elif mode == 2: 140 | # rotate counterwise 90 degree 141 | image = np.rot90(image, axes=axes) 142 | elif mode == 3: 143 | # rotate 90 degree and flip up and down 144 | image = np.rot90(image, axes=axes) 145 | image = flipud(image) 146 | elif mode == 4: 147 | # rotate 180 degree 148 | image = np.rot90(image, k=2, axes=axes) 149 | elif mode == 5: 150 | # rotate 180 degree and flip 151 | image = np.rot90(image, k=2, axes=axes) 152 | image = flipud(image) 153 | elif mode == 6: 154 | # rotate 270 degree 155 | image = np.rot90(image, k=3, axes=axes) 156 | elif mode == 7: 157 | # rotate 270 degree and flip 158 | image = np.rot90(image, k=3, axes=axes) 159 | image = flipud(image) 160 | 161 | # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D. 162 | # disable it when training 2D CNN, e.g. 
MemNet 163 | if random.random() < 0.5: 164 | image = image[::-1, :, :] 165 | 166 | return np.ascontiguousarray(image) 167 | 168 | 169 | class LockedIterator(object): 170 | def __init__(self, it): 171 | self.lock = threading.Lock() 172 | self.it = it.__iter__() 173 | 174 | def __iter__(self): return self 175 | 176 | def __next__(self): 177 | self.lock.acquire() 178 | try: 179 | return next(self.it) 180 | finally: 181 | self.lock.release() 182 | 183 | 184 | if __name__ == '__main__': 185 | """Code Usage Example""" 186 | """ICVL""" 187 | # hsi_rot = partial(np.rot90, k=-1, axes=(1,2)) 188 | # crop = lambda img: img[:,-1024:, -1024:] 189 | # zoom_512 = partial(zoom, zoom=[1, 0.5, 0.5]) 190 | # d2v = partial(Data2Volume, ksizes=[31,64,64], strides=[1,28,28]) 191 | # preprocess = sequetial_process(hsi_rot, crop, minmax_normalize, d2v) 192 | 193 | # preprocess = sequetial_process(hsi_rot, crop, minmax_normalize) 194 | # datadir = 'Data/ICVL/Training/' 195 | # fns = os.listdir(datadir) 196 | # mat = h5py.File(os.path.join(datadir, fns[1])) 197 | # data = preprocess(mat['rad']) 198 | # data = np.linalg.norm(data, ord=2, axis=(1,2)) 199 | 200 | """Common""" 201 | # print(data) 202 | # fig = plt.figure() 203 | # ax = fig.add_subplot(111) 204 | # ax.plot(data) 205 | # plt.show() 206 | 207 | # preprocess = sequetial_process(hsi_rot, crop, minmax_normalize, frame_diff) 208 | # visualize(os.path.join(datadir, fns[0]), 'rad', load=h5py.File, preprocess=preprocess) 209 | # visualize('Data/BSD/TrainingPatches/imdb_40_128.mat', 'inputs', load=h5py.File, preprocess=None) 210 | 211 | # preprocess = lambda x: np.transpose(x[4][0],(2,0,1)) 212 | # preprocess = lambda x: minmax_normalize(np.transpose(np.array(x,dtype=np.float),(2,0,1))) 213 | 214 | # visualize('/media/kaixuan/DATA/Papers/Code/Data/PIRM18/sample/true_hr', 'hsi', load=loadmat, preprocess=preprocess) 215 | # visualize('/media/kaixuan/DATA/Papers/Code/Data/PIRM18/sample/img_1', 'true_hr', load=loadmat, preprocess=preprocess) 
216 | 217 | # visualize('/media/kaixuan/DATA/Papers/Code/Matlab/ITSReg/code of ITSReg MSI denoising/data/real/new/Indian/Indian_pines.mat', 'hsi', load=loadmat, preprocess=preprocess) 218 | # visualize('/media/kaixuan/DATA/Papers/Code/Matlab/ECCV2018/Result/Indian/Indian_pines/QRNN3D-f.mat', 'R_hsi', load=loadmat, preprocess=preprocess) 219 | # visualize('/media/kaixuan/DATA/Papers/Code/Matlab/ECCV2018/Data/Pavia/PaviaU', 'input', load=loadmat, preprocess=preprocess) 220 | 221 | pass --------------------------------------------------------------------------------