├── .gitignore ├── README.md └── code ├── model ├── DefineNet.py ├── ISDTDNet.py └── modules │ ├── Diffusion.py │ ├── Esr.py │ ├── IDN.py │ ├── INR_align.py │ ├── INR_sr.py │ ├── MIFA.py │ ├── SFFI.py │ ├── Style.py │ └── half_IDN.py ├── train_DCHFR.py ├── train_ISDTD.py └── utils ├── BCEDiceloss.py ├── DCHFRDataset.py ├── ISDTDDataset.py ├── Imetrics.py ├── Pmetric.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official code for "Diffusion-Based Continuous Feature Representation for Infrared Small-Dim Target Detection" 2 | ![3-3_DCHFR-Net](https://github.com/flyannie/DCFR-Net/assets/162861421/00e27e9f-b37c-4fba-88b7-59b8e4ec9248) 3 | 4 | # 1. Data preparation 5 | 6 | This experiment includes multiple public datasets, which are introduced below one by one. The quadruple bicubic interpolation required in the DCHFR branch is included in the code. Additionally, since DCHFR incorporates a diffusion model, its testing process is relatively slower. To avoid affecting the training progress, you may consider extending the testing interval or reducing the number of test images. 7 | 8 | ## 1.1 "BigReal" and "BigSim" 9 | 10 | * dataset: https://xzbai.buaa.edu.cn/datasets.html 11 | 12 | * paper doi: 10.1109/TGRS.2023.3235150 13 | 14 | * description: IRDST dataset consists of 142,727 real and simulation frames (40,650 real frames in 85 scenes and 102,077 simulation frames in 317 scenes). 15 | 16 | * path format 17 | 18 | images: "your root path/images/xxx.png", where "xxx" is included in the corresponding TXT file. 19 | 20 | labels: "your root path/masks/xxx.png", where "xxx" is included in the corresponding TXT file. 21 | 22 | TXT files: "your root path/train.txt", "your root path/test_hr.txt", "your root path/test.txt" 23 | 24 | ## 1.2 "NUAA" 25 | 26 | * dataset: https://github.com/YimianDai/sirst 27 | 28 | * paper doi: 10.1109/WACV48630.2021.00099 29 | 30 | * description: SIRST is a dataset specially constructed for single-frame infrared small target detection, in which the images are selected from hundreds of infrared sequences for different scenarios. 31 | 32 | * path format 33 | 34 | images: "your root path/train_imgs/xxx.png", "your root path/test_imgs/xxx.png", where "xxx" will be automatically retrieved in the code. 35 | 36 | labels: "your root path/train_labels/xxx.png", "your root path/test_labels/xxx.png", where "xxx" will be automatically retrieved in the code. 
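As noted at the start of this section, the 4x low-resolution inputs used by the DCHFR branch are produced by bicubic downsampling inside the dataset code (utils/DCHFRDataset.py). A minimal stand-alone sketch of that degradation, assuming (B, C, H, W) tensors in [0, 1] (the helper name `make_lr` is illustrative, not from the repository):

```python
import torch
import torch.nn.functional as F

def make_lr(hr: torch.Tensor, scale: int = 4) -> torch.Tensor:
    # quadruple bicubic degradation: HR -> LR at 1/4 the spatial resolution
    h, w = hr.shape[-2:]
    lr = F.interpolate(hr, size=(h // scale, w // scale), mode='bicubic', align_corners=False)
    return lr.clamp(0, 1)
```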
37 | 38 | ## 1.3 "NUDT" 39 | 40 | * dataset: https://github.com/YeRen123455/Infrared-Small-Target-Detection 41 | 42 | * paper doi: 10.1109/TIP.2022.3199107 43 | 44 | * description: The NUDT-SIRST dataset is a synthesized dataset containing 1,327 images with a resolution of 256x256. 45 | 46 | * path format 47 | 48 | images: "your root path/train_imgs/xxx.png", "your root path/test_imgs/xxx.png", where "xxx" will be automatically retrieved in the code. 49 | 50 | labels: "your root path/train_labels/xxx.png", "your root path/test_labels/xxx.png", where "xxx" will be automatically retrieved in the code. 51 | 52 | ## 1.4 "IRSTD" 53 | 54 | * dataset: https://github.com/RuiZhang97/ISNet 55 | 56 | * paper doi: 10.1109/CVPR52688.2022.00095 57 | 58 | * description: The IRSTD-1k dataset is a realistic infrared small target detection dataset, which consists of 1,001 manually labeled realistic images with various target shapes, different target sizes, and rich clutter backgrounds from diverse scenes. 59 | 60 | * path format 61 | 62 | images: "your root path/train_imgs/xxx.png", "your root path/test_imgs/xxx.png", where "xxx" will be automatically retrieved in the code. 63 | 64 | labels: "your root path/train_labels/xxx.png", "your root path/test_labels/xxx.png", where "xxx" will be automatically retrieved in the code. 65 | 66 | # 2. Training DCHFR branch 67 | 68 | If resources are insufficient to train this branch, you can simply initialize the encoder weights randomly in "train_ISDTD.py". This approach still achieves results close to the state of the art (SOTA); with sufficient resources, more targeted results can be obtained. 69 | 70 | * change "--root" to your root path 71 | 72 | * choose "--dataset_type", "--phase" 73 | 74 | * set appropriate parameters, including "--base_size", "--crop_size", "--batch_size", "--val_batch_size", "--num_worker", and "--n_epoch" 75 | 76 | * you can also modify the result save location ("--results_hr", "--results_sr", "--results_lr") and checkpoint file save location ("--checkpoint") if needed 77 | 78 | * If you need to use WANDB for visualization, please set your key first (os.environ["WANDB_API_KEY"] = "xxxx"). 79 | 80 | `python train_DCHFR.py` 81 | 82 | # 3. Training ISDTD branch 83 | 84 | * change "--root" to your root path 85 | 86 | * choose "--dataset_type", "--phase" 87 | 88 | * enter the file path of the trained DCHFR/ISDTD checkpoint into "--checkpoint_path" / "--ISDTD_checkpoint_path" 89 | 90 | * set appropriate parameters, including "--base_size", "--batch_size", "--val_batch_size", "--num_worker", "--n_epoch" and "--val_freq" 91 | 92 | * you can also modify the result save location ("--results_mask") and weight file save location ("--save_path") if needed 93 | 94 | * If you need to use WANDB for visualization, please set your key first (os.environ["WANDB_API_KEY"] = "xxxx").
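For reference, a full invocation with explicit flags might look like the following (every path and value here is a placeholder; adjust them to your own setup):

`python train_ISDTD.py --root /your/root/path --dataset_type NUAA --phase train --checkpoint_path ./checkpoints/DCHFR_best.pth --base_size 256 --batch_size 8 --val_batch_size 1 --num_worker 8 --n_epoch 300 --val_freq 10`

or simply with the default arguments: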
95 | 96 | `python train_ISDTD.py` 97 | -------------------------------------------------------------------------------- /code/model/DefineNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | from .modules.IDN import IDN 4 | from .modules.Esr import Esr 5 | from .modules.Diffusion import Diffusion 6 | from .ISDTDNet import ISDTDNet 7 | 8 | 9 | def define_DCHFRnet(args): 10 | IDN_model = IDN(in_channel=6, out_channel=3, norm_groups=32, inner_channel=64, channel_mults=[1,2,4,8,8], attn_res=[32], res_blocks=2, dropout=0.2, image_size=args.base_size) 11 | encoder = Esr(n_resblocks=16, n_feats=64, res_scale=1, no_upsampling=False, rgb_range=1) 12 | DCHFR_net = Diffusion(encoder, denoise_fn=IDN_model, image_size=args.base_size, channels=3, conditional=True) 13 | init_weights(DCHFR_net, init_type='orthogonal') 14 | assert torch.cuda.is_available() 15 | return DCHFR_net 16 | 17 | def define_trained_DCHFRnet(args): 18 | IDN_model = IDN(in_channel=6, out_channel=3, norm_groups=32, inner_channel=64, channel_mults=[1,2,4,8,8], attn_res=[32], res_blocks=2, dropout=0.2, image_size=args.base_size) 19 | encoder = Esr(n_resblocks=16, n_feats=64, res_scale=1, scale=4, no_upsampling=False, rgb_range=1) 20 | net = Diffusion(encoder, denoise_fn=IDN_model, image_size=args.base_size, channels=3, conditional=True) 21 | assert torch.cuda.is_available() 22 | return net 23 | 24 | def define_ISDTDnet(args): 25 | ISDTD_net = ISDTDNet(args) 26 | assert torch.cuda.is_available() 27 | return ISDTD_net 28 | 29 | def init_weights(net, init_type='orthogonal'): 30 | if init_type == 'orthogonal': 31 | net.apply(weights_init_orthogonal) 32 | else: 33 | raise NotImplementedError( 34 | 'initialization method [{:s}] not implemented'.format(init_type)) 35 | 36 | def weights_init_orthogonal(m): 37 | classname = m.__class__.__name__ 38 | if classname.find('Conv') != -1: 39 | init.orthogonal_(m.weight.data, gain=1) 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif classname.find('Linear') != -1: 43 | init.orthogonal_(m.weight.data, gain=1) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | elif classname.find('BatchNorm2d') != -1: 47 | init.constant_(m.weight.data, 1.0) 48 | init.constant_(m.bias.data, 0.0) -------------------------------------------------------------------------------- /code/model/ISDTDNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .modules.SFFI import SFFI 3 | from .modules.MIFA import MIFA 4 | from torch.nn import functional as F 5 | from .modules.half_IDN import half_IDN 6 | 7 | 8 | class ISDTDNet(nn.Module): 9 | 10 | def __init__(self, args): 11 | super(ISDTDNet, self).__init__() 12 | self.encoder = half_IDN(in_channel=3, norm_groups=32, inner_channel=64, channel_mults=[1,2,4,8,16], attn_res=[32], res_blocks=2, dropout=0.2, image_size=args.base_size) 13 | self.MIFA = MIFA() 14 | self.SFFI = SFFI(input_dim=3, hidden_dim=64) 15 | 16 | def forward(self, feat, img): 17 | _, _, h, w = feat.size() 18 | feat = self.encoder(feat,img) 19 | out = self.MIFA(feat) 20 | resize_out = F.upsample(input=out, size=(h, w), mode='bilinear', align_corners=True) 21 | mask = self.SFFI(resize_out) 22 | return mask 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /code/model/modules/Diffusion.py: -------------------------------------------------------------------------------- 1 | 
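# Notation used throughout this module: the buffers registered in set_new_noise_schedule
# implement the standard DDPM identities over a 2000-step linear beta schedule, i.e.
#   q_sample (forward):         x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#   predict_start_from_noise:   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps_theta
#   q_posterior (reverse mean): mu  = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
# where alpha_bar_t is the cumulative product of (1 - beta_t).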
import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from functools import partial 6 | from tqdm.contrib import tzip 7 | from utils.utils import default, compute_alpha 8 | 9 | 10 | class Diffusion(nn.Module): 11 | def __init__(self, encoder, denoise_fn, image_size, channels=3, conditional=True, feat_unfold=False, local_ensemble=False, cell_decode=False, schedule_opt=None): 12 | super().__init__() 13 | self.channels = channels 14 | self.image_size = image_size 15 | self.encoder = encoder 16 | self.denoise_fn = denoise_fn 17 | self.conditional = conditional 18 | self.feat_unfold = feat_unfold 19 | self.local_ensemble = local_ensemble 20 | self.cell_decode = cell_decode 21 | 22 | def set_loss(self, device): 23 | self.loss_func = nn.L1Loss(reduction='sum').to(device) 24 | 25 | def set_new_noise_schedule(self, device): 26 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 27 | betas = np.linspace(1e-6, 1e-2, 2000, dtype=np.float64) 28 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 29 | alphas = 1. - betas 30 | alphas_cumprod = np.cumprod(alphas, axis=0) 31 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 32 | self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod)) 33 | timesteps, = betas.shape 34 | self.num_timesteps = int(timesteps) 35 | self.register_buffer('betas', to_torch(betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 44 | self.register_buffer('posterior_variance', to_torch(posterior_variance)) 45 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 46 | self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 47 | self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 48 | self.register_buffer('ddim_c1', torch.sqrt(to_torch((1. 
- alphas_cumprod / alphas_cumprod_prev) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)))) 49 | 50 | def forward(self, x, *args, **kwargs): 51 | return self.p_losses(x, *args, **kwargs) 52 | 53 | def p_losses(self, x_in, noise=None): 54 | inp, scaler = x_in['lr'], x_in['scaler'] 55 | x_feat = self.gen_feat(inp, x_in['hr'].shape[2:]) 56 | x_con = x_feat 57 | x_start = x_in['hr'] 58 | [b, c, h, w] = x_start.shape 59 | t = np.random.randint(1, self.num_timesteps + 1) 60 | continuous_sqrt_alpha_cumprod = torch.FloatTensor(np.random.uniform(self.sqrt_alphas_cumprod_prev[t-1], self.sqrt_alphas_cumprod_prev[t], size=b)).to(x_start.device) 61 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1) 62 | noise = default(noise, lambda: torch.randn_like(x_start)) 63 | x_noisy = self.q_sample(x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise) 64 | if not self.conditional: 65 | x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod) 66 | else: 67 | x_recon = self.denoise_fn(torch.cat([x_con, x_noisy], dim=1), x_con, scaler, continuous_sqrt_alpha_cumprod) 68 | loss = self.loss_func(noise, x_recon) 69 | return loss 70 | 71 | def gen_feat(self, inp, shape): 72 | feat = self.encoder(inp, shape) 73 | return feat 74 | 75 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None): 76 | noise = default(noise, lambda: torch.randn_like(x_start)) 77 | return (continuous_sqrt_alpha_cumprod * x_start + (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise) 78 | 79 | @torch.no_grad() 80 | def super_resolution(self, x_in, continous=False, use_ddim=False): 81 | if not use_ddim: 82 | return self.p_sample_loop(x_in, continous) 83 | else: 84 | return self.generalized_steps(x_in, conditional_input=None, continous=continous) 85 | 86 | @torch.no_grad() 87 | def p_sample_loop(self, x_in, continous=False): 88 | device = self.betas.device 89 | sample_inter = (1 | (self.num_timesteps // 10)) 90 | if not self.conditional: 91 | shape = x_in 92 | img = torch.randn(shape, device=device) 93 | ret_img = img 94 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 95 | img = self.p_sample(img, i) 96 | if i % sample_inter == 0: 97 | ret_img = torch.cat([ret_img, img], dim=0) 98 | else: 99 | x, scaler = x_in['lr'], x_in['scaler'] 100 | shape = x.shape 101 | gt_shape = list(x_in['hr'].shape) 102 | img = torch.randn(gt_shape, device=device) 103 | x_feat = self.gen_feat(x, gt_shape[2:]) 104 | x_con = x_feat 105 | ret_img = x_con 106 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 107 | img = self.p_sample(img, i, scaler, condition_x=x_con) 108 | if i % sample_inter == 0: 109 | ret_img = torch.cat([ret_img, img], dim=0) 110 | if continous: 111 | return ret_img 112 | else: 113 | return ret_img[-1] 114 | 115 | @torch.no_grad() 116 | def p_sample(self, x, t, scaler, clip_denoised=True, condition_x=None): 117 | model_mean, model_log_variance = self.p_mean_variance(x=x, t=t, scaler=scaler, clip_denoised=clip_denoised, condition_x=condition_x) 118 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 119 | return model_mean + noise * (0.5 * model_log_variance).exp() 120 | 121 | def p_mean_variance(self, x, t, scaler, clip_denoised: bool, condition_x=None): 122 | batch_size = x.shape[0] 123 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device) 124 | if 
condition_x is not None: 125 | x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), condition_x, scaler, noise_level)) 126 | else: 127 | x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x, noise_level)) 128 | if clip_denoised: 129 | x_recon.clamp_(-1., 1.) 130 | model_mean, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 131 | return model_mean, posterior_log_variance 132 | 133 | def predict_start_from_noise(self, x_t, t, noise): 134 | return self.sqrt_recip_alphas_cumprod[t] * x_t - self.sqrt_recipm1_alphas_cumprod[t] * noise 135 | 136 | def q_posterior(self, x_start, x_t, t): 137 | posterior_mean = self.posterior_mean_coef1[t] * x_start + self.posterior_mean_coef2[t] * x_t 138 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 139 | return posterior_mean, posterior_log_variance_clipped 140 | 141 | @torch.no_grad() 142 | def generalized_steps(self, x_in, conditional_input=None, continous=False): 143 | device = self.betas.device 144 | skip = self.num_timesteps // 200 145 | seq = range(0, self.num_timesteps, skip) 146 | seq_next = [-1] + list(seq[:-1]) 147 | x, scaler = x_in['lr'], x_in['scaler'] 148 | b = x.size(0) 149 | gt_shape = list(x_in['hr'].shape) 150 | img = torch.randn(gt_shape, device=device) 151 | x_feat = self.gen_feat(x, gt_shape[2:]) 152 | conditional_input = x_feat 153 | ret_img = img 154 | for i, j in tzip(reversed(seq), reversed(seq_next)): 155 | if i == 0: 156 | break 157 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[i + 1]]).repeat(b, 1).to(x.device) 158 | t = (torch.ones(b) * i).to(x.device) 159 | next_t = (torch.ones(b) * j).to(x.device) 160 | at = compute_alpha(self.betas, t.long()) 161 | at_next = compute_alpha(self.betas, next_t.long()) 162 | xt = ret_img[-1] # [c,h,w] 163 | et = self.denoise_fn(torch.cat([conditional_input, xt.unsqueeze(0)], dim=1), conditional_input, scaler, noise_level) 164 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 165 | c1 = (0 * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()) 166 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 167 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(img) + c2 * et 168 | ret_img = torch.cat([ret_img, xt_next], dim=0) 169 | if continous: 170 | return ret_img 171 | else: 172 | return ret_img[-1] 173 | 174 | 175 | -------------------------------------------------------------------------------- /code/model/modules/Esr.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils.utils import MeanShift, ResBlock 4 | 5 | 6 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 7 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 8 | 9 | 10 | class Esr(nn.Module): 11 | def __init__(self, n_colors=3, n_resblocks=16, n_feats=64, res_scale=1, scale=2, no_upsampling=False, rgb_range=1, conv=default_conv): 12 | super(Esr, self).__init__() 13 | n_resblocks = n_resblocks 14 | n_feats = n_feats 15 | kernel_size = 3 16 | act = nn.ReLU(True) 17 | self.no_upsampling = no_upsampling 18 | self.sub_mean = MeanShift(rgb_range) 19 | self.add_mean = MeanShift(rgb_range, sign=1) 20 | m_head = [conv(n_colors, n_feats, kernel_size)] 21 | m_body = [ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale) for _ in range(n_resblocks)] 22 | m_body.append(conv(n_feats, n_feats, kernel_size)) 23 | self.head = 
nn.Sequential(*m_head) 24 | self.body = nn.Sequential(*m_body) 25 | if self.no_upsampling: 26 | self.out_dim = n_feats 27 | else: 28 | self.out_dim = n_colors 29 | m_tail = [conv(n_feats, n_colors, kernel_size)] 30 | self.tail = nn.Sequential(*m_tail) 31 | 32 | def forward(self, x, shape): 33 | x = self.head(x) 34 | res = self.body(x) 35 | res += x 36 | if self.no_upsampling: 37 | x = res 38 | print("EDSR_no_up", x.shape) 39 | else: 40 | res = F.interpolate(res, shape) 41 | x = self.tail(res) 42 | return x 43 | 44 | def load_state_dict(self, state_dict, strict=True): 45 | own_state = self.state_dict() 46 | for name, param in state_dict.items(): 47 | if name in own_state: 48 | if isinstance(param, nn.Parameter): 49 | param = param.data 50 | try: 51 | own_state[name].copy_(param) 52 | except Exception: 53 | if name.find('tail') == -1: 54 | raise RuntimeError('While copying the parameter named {}, ' 55 | 'whose dimensions in the model are {} and ' 56 | 'whose dimensions in the checkpoint are {}.' 57 | .format(name, own_state[name].size(), param.size())) 58 | elif strict: 59 | if name.find('tail') == -1: 60 | raise KeyError('unexpected key "{}" in state_dict'.format(name)) 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /code/model/modules/IDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.utils import PositionalEncoding, Swish, ResnetBlocWithAttn, Downsample, Block, default, exists 4 | from .Style import StyleLayer, EqualLinear 5 | from .INR_sr import INR_sr 6 | from einops import rearrange 7 | 8 | 9 | class IDN(nn.Module): 10 | def __init__(self,in_channel=6,out_channel=3,inner_channel=64,norm_groups=32,channel_mults=(1, 2, 4, 8, 8),attn_res=(32),res_blocks=2,dropout=0.2,image_size=256): 11 | super().__init__() 12 | noise_level_channel = inner_channel 13 | self.noise_level_mlp = nn.Sequential(PositionalEncoding(inner_channel), nn.Linear(inner_channel, inner_channel * 4), Swish(), nn.Linear(inner_channel * 4, inner_channel)) 14 | num_mults = len(channel_mults) 15 | pre_channel = inner_channel 16 | feat_channels = [pre_channel] 17 | now_res = image_size 18 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)] 19 | self.conv_body_first = StyleLayer(3, pre_channel, 3, bias=True, activate=True) 20 | self.conv_body_down = nn.ModuleList() 21 | self.condition_scale1 = nn.ModuleList() 22 | self.condition_scale2 = nn.ModuleList() 23 | self.condition_shift = nn.ModuleList() 24 | for ind in range(num_mults): 25 | is_last = (ind == num_mults - 1) 26 | use_attn = (now_res in attn_res) 27 | channel_mult = inner_channel * channel_mults[ind] 28 | self.conv_body_down.append(StyleLayer(pre_channel, channel_mult, 3, downsample=True)) 29 | self.condition_scale1.append(EqualLinear(1, channel_mult, bias=True, bias_init_val=1, activation=None)) 30 | self.condition_scale2.append(EqualLinear(1, channel_mult, bias=True, bias_init_val=1, activation=None)) 31 | self.condition_shift.append(StyleLayer(pre_channel, channel_mult, 3, bias=True, activate=False)) 32 | for _ in range(0, res_blocks): 33 | downs.append(ResnetBlocWithAttn(pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 34 | feat_channels.append(channel_mult) 35 | pre_channel = channel_mult 36 | if not is_last: 37 | downs.append(Downsample(pre_channel)) 38 | feat_channels.append(pre_channel) 39 | 
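# each Downsample halves the spatial size, so update the tracked resolution here;
# use_attn compares this running value against attn_res to decide where attention is applied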
now_res = now_res // 2 40 | self.downs = nn.ModuleList(downs) 41 | self.final_down1 = StyleLayer(512, 512, 3, downsample=False) 42 | self.final_down2 = StyleLayer(512, 256, 3, downsample=True) 43 | self.num_latent, self.num_style_feat = 4, 512 44 | self.final_linear = EqualLinear(2 *2 * 256, self.num_style_feat * self.num_latent, bias=True, activation='fused_lrelu') 45 | self.final_styleconv = StyleLayer(512, 512, 3) 46 | self.mid = nn.ModuleList([ 47 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=True), 48 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=False) 49 | ]) 50 | ups = [] 51 | for ind in reversed(range(num_mults)): 52 | is_last = (ind < 1) 53 | use_attn = (now_res in attn_res) 54 | channel_mult = inner_channel * channel_mults[ind] 55 | for _ in range(0, res_blocks + 1): 56 | ups.append(ResnetBlocWithAttn(pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 57 | pre_channel = channel_mult 58 | if not is_last: 59 | ups.append(INR_sr(pre_channel)) 60 | now_res = now_res * 2 61 | self.ups = nn.ModuleList(ups) 62 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 63 | 64 | def forward(self, x, lr, scaler, time): 65 | t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None 66 | feat = self.conv_body_first(lr) 67 | scales1, scales2, shifts = [], [], [] 68 | scale1 = self.condition_scale1[0](scaler) 69 | scales1.append(scale1.clone()) 70 | scale2 = self.condition_scale2[0](scaler) 71 | scales2.append(scale2.clone()) 72 | shift = self.condition_shift[0](feat) 73 | shifts.append(shift.clone()) 74 | j = 1 75 | for i in range(len(self.conv_body_down)): 76 | feat = self.conv_body_down[i](feat) 77 | if j < len(self.condition_scale1) : 78 | scale1 = self.condition_scale1[j](scaler) 79 | scales1.append(scale1.clone()) 80 | scale2 = self.condition_scale2[j](scaler) 81 | scales2.append(scale2.clone()) 82 | shift = self.condition_shift[j](feat) 83 | shifts.append(shift.clone()) 84 | j += 1 85 | feats = [] 86 | for i,layer in enumerate(self.downs): 87 | if isinstance(layer, ResnetBlocWithAttn): 88 | x = layer(x, t) 89 | else: 90 | x = layer(x) 91 | feats.append(x) 92 | for layer in self.mid: 93 | if isinstance(layer, ResnetBlocWithAttn): 94 | x = layer(x, t) 95 | else: 96 | x = layer(x) 97 | 98 | for i, layer in enumerate(self.ups): 99 | if isinstance(layer, ResnetBlocWithAttn): 100 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 101 | else: 102 | x = layer(x, feats[-1].shape[2:], scales1.pop(), scales2.pop(), shifts.pop()) 103 | x = rearrange(x, 'b (h w) c -> b c h w', h=feats[-1].shape[-1]) 104 | return self.final_conv(x) 105 | 106 | -------------------------------------------------------------------------------- /code/model/modules/INR_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.utils import SpatialEncoding, gen_feat 4 | 5 | 6 | class INR_align(nn.Module): 7 | def __init__(self, pos_dim=24, stride=1, require_grad=True): 8 | super(INR_align, self).__init__() 9 | self.pos_dim = pos_dim 10 | self.stride = stride 11 | norm_layer = nn.BatchNorm1d 12 | self.pos1 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad) 13 | self.pos2 = SpatialEncoding(2, 
self.pos_dim, require_grad=require_grad) 14 | self.pos3 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad) 15 | self.pos4 = SpatialEncoding(2, self.pos_dim, require_grad=require_grad) 16 | self.pos_dim += 2 17 | in_dim = 4 * (256 + self.pos_dim) 18 | self.inr = nn.Sequential( 19 | nn.Conv1d(in_dim, 512, 1), norm_layer(512), nn.ReLU(), 20 | nn.Conv1d(512, 256, 1), norm_layer(256), nn.ReLU(), 21 | nn.Conv1d(256, 256, 1), norm_layer(256), nn.ReLU(), 22 | nn.Conv1d(256, 3, 1) 23 | ) 24 | 25 | def forward(self, x, size, level=0, after_cat=False): 26 | h, w = size 27 | if not after_cat: 28 | rel_coord, q_feat = gen_feat(x, [h, w]) 29 | rel_coord = eval('self.pos' + str(level))(rel_coord) 30 | x = torch.cat([rel_coord, q_feat], dim=-1) 31 | else: 32 | x = self.inr(x) 33 | x = x.view(x.shape[0], -1, h, w) 34 | return x 35 | 36 | 37 | -------------------------------------------------------------------------------- /code/model/modules/INR_sr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from utils.utils import make_coord 5 | from .Style import StyleLayer_norm_scale_shift 6 | 7 | class INR_sr(nn.Module): 8 | def __init__(self, dim, feat_unfold=False, local_ensemble=False, cell_decode=False): 9 | super().__init__() 10 | self.feat_unfold = feat_unfold 11 | self.local_ensemble = local_ensemble 12 | self.cell_decode = cell_decode 13 | self.style = StyleLayer_norm_scale_shift(dim, dim, kernel_size=3, num_style_feat=512, demodulate=True, sample_mode=None, resample_kernel=(1, 3, 3, 1)) 14 | if self.cell_decode: 15 | self.inr = nn.Sequential(nn.Linear(dim + 2 + 2 , 256),nn.Linear(256, dim)) 16 | else: 17 | self.inr = nn.Sequential(nn.Linear(dim + 2, 256),nn.Linear(256, dim)) 18 | def forward(self, x, shape, scale1, scale2, shift): 19 | coord = make_coord(shape).repeat(x.shape[0], 1, 1).to('cuda') 20 | cell = torch.ones_like(coord) 21 | cell[:, 0] *= 2 / shape[-2] 22 | cell[:, 1] *= 2 / shape[-1] 23 | return self.query_rgb(x, scale1, scale2, shift, coord, cell) 24 | 25 | def query_rgb(self, x_feat, scale1, scale2, shift, coord, cell=None): 26 | 27 | feat = self.style(x_feat, noise=None, scale1=scale1, scale2=scale2, shift=shift) 28 | if self.feat_unfold: 29 | feat = F.unfold(feat, 3, padding=1).view( 30 | feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) 31 | 32 | if self.local_ensemble: 33 | vx_lst = [-1, 1] 34 | vy_lst = [-1, 1] 35 | eps_shift = 1e-6 36 | else: 37 | vx_lst, vy_lst, eps_shift = [0], [0], 0 38 | 39 | rx = 2 / feat.shape[-2] / 2 40 | ry = 2 / feat.shape[-1] / 2 41 | 42 | feat_coord = make_coord(feat.shape[-2:], flatten=False).to('cuda').permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) 43 | 44 | preds = [] 45 | areas = [] 46 | for vx in vx_lst: 47 | for vy in vy_lst: 48 | coord_ = coord.clone() 49 | coord_[:, :, 0] += vx * rx + eps_shift 50 | coord_[:, :, 1] += vy * ry + eps_shift 51 | coord_.clamp_(-1 + 1e-6, 1 - 1e-6) 52 | q_feat = F.grid_sample( 53 | feat, coord_.flip(-1).unsqueeze(1), 54 | mode='nearest', align_corners=False)[:, :, 0, :] \ 55 | .permute(0, 2, 1) 56 | q_coord = F.grid_sample( 57 | feat_coord, coord_.flip(-1).unsqueeze(1), 58 | mode='nearest', align_corners=False)[:, :, 0, :] \ 59 | .permute(0, 2, 1) 60 | rel_coord = coord - q_coord 61 | # print(rel_coord) 62 | rel_coord[:, :, 0] *= feat.shape[-2] 63 | rel_coord[:, :, 1] *= feat.shape[-1] 64 | inp = torch.cat([q_feat, rel_coord], dim=-1) 65 | 66 | if 
self.cell_decode: 67 | rel_cell = cell.clone() 68 | rel_cell[:, :, 0] *= feat.shape[-2] 69 | rel_cell[:, :, 1] *= feat.shape[-1] 70 | inp = torch.cat([inp, rel_cell], dim=-1) 71 | 72 | bs, q = coord.shape[:2] 73 | pred = self.inr(inp.view(bs * q, -1)).view(bs, q, -1) 74 | preds.append(pred) 75 | 76 | area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1]) 77 | areas.append(area + 1e-9) 78 | 79 | tot_area = torch.stack(areas).sum(dim=0) 80 | if self.local_ensemble: 81 | t = areas[0]; areas[0] = areas[3]; areas[3] = t 82 | t = areas[1]; areas[1] = areas[2]; areas[2] = t 83 | ret = 0 84 | for pred, area in zip(preds, areas): 85 | ret = ret + pred * (area / tot_area).unsqueeze(-1) 86 | return ret -------------------------------------------------------------------------------- /code/model/modules/MIFA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .INR_align import INR_align 4 | 5 | 6 | class MIFA(nn.Module): 7 | 8 | def __init__(self): 9 | super(MIFA, self).__init__() 10 | self.INR_align = INR_align() 11 | norm_layer = nn.BatchNorm2d 12 | self.head = nn.Sequential(nn.Conv2d(2048, 256, kernel_size=1), norm_layer(256), nn.ReLU(inplace=True)) 13 | self.enc1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1), norm_layer(256), nn.ReLU(inplace=True)) 14 | self.enc2 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=1), norm_layer(256), nn.ReLU(inplace=True)) 15 | self.enc3 = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=1), norm_layer(256), nn.ReLU(inplace=True)) 16 | 17 | def forward(self, x): 18 | x1, x2, x3, x4 = x 19 | aspp_out = self.head(x4) 20 | x1 = self.enc1(x1) 21 | x2 = self.enc2(x2) 22 | x3 = self.enc3(x3) 23 | context = [] 24 | h, w = x1.shape[-2], x1.shape[-1] 25 | target_feat = [x1, x2, x3, aspp_out] 26 | for i, feat in enumerate(target_feat): 27 | context.append(self.INR_align(feat, size=[h, w], level=i+1)) 28 | context = torch.cat(context, dim=-1).permute(0,2,1) 29 | res = self.INR_align(context, size=[h, w], after_cat=True) 30 | return res 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /code/model/modules/SFFI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | from einops import rearrange 5 | 6 | class NonLocalAttention(nn.Module): 7 | def __init__(self, in_channels) -> None: 8 | super().__init__() 9 | self.conv = nn.Conv2d(in_channels, 1, 1) 10 | self.softmax = nn.Softmax(dim=1) 11 | 12 | def forward(self, feat) -> torch.Tensor: 13 | b, c, h, w = feat.shape 14 | out_feat = self.conv(feat) 15 | out_feat = rearrange(out_feat, 'b c h w -> b (h w) c') 16 | out_feat = torch.unsqueeze(out_feat, -1) 17 | out_feat = self.softmax(out_feat) 18 | out_feat = torch.squeeze(out_feat, -1) 19 | identity = rearrange(feat, 'b c h w -> b c (h w)') 20 | out_feat = torch.matmul(identity, out_feat) 21 | out_feat = torch.unsqueeze(out_feat, -1) 22 | return out_feat 23 | 24 | 25 | class NonLocalAttentionBlock(nn.Module): 26 | 27 | def __init__(self, in_channels) -> None: 28 | super().__init__() 29 | self.nonlocal_attention = NonLocalAttention(in_channels) 30 | self.global_transform = nn.Sequential( 31 | nn.Conv2d(in_channels, in_channels, 1), 32 | nn.LeakyReLU(inplace=True), 33 | nn.Conv2d(in_channels, in_channels, 1), 34 | nn.LeakyReLU(inplace=True), 35 | ) 36 | 37 | def forward(self, feat): 38 | out_feat = self.nonlocal_attention(feat) 39 | out_feat 
= self.global_transform(out_feat) 40 | return feat + out_feat 41 | 42 | 43 | class SpectralTransformer(nn.Module): 44 | def __init__(self, in_channels): 45 | super().__init__() 46 | self.conv = nn.Conv2d(in_channels * 2, in_channels * 2, 3, padding=1) 47 | self.lrelu = nn.LeakyReLU(inplace=True) 48 | 49 | def forward(self, feat: torch.Tensor) -> torch.Tensor: 50 | b, c, h, w = feat.shape 51 | out_feat = fft.rfft2(feat) 52 | out_feat = torch.cat([out_feat.real, out_feat.imag], dim=1) 53 | out_feat = self.conv(out_feat) 54 | out_feat = self.lrelu(out_feat) 55 | c = out_feat.shape[1] 56 | out_feat = torch.complex(out_feat[:, : c // 2], out_feat[:, c // 2 :]) 57 | out_feat = fft.irfft2(out_feat) 58 | return out_feat 59 | 60 | 61 | class FourierConvolutionBlock(nn.Module): 62 | def __init__(self, in_channels): 63 | super().__init__() 64 | self.half_channels = in_channels // 2 65 | self.func_g_to_g = SpectralTransformer(self.half_channels) 66 | self.func_g_to_l = nn.Sequential( 67 | nn.Conv2d(self.half_channels, self.half_channels, kernel_size=3, padding=1), 68 | nn.LeakyReLU(inplace=True), 69 | ) 70 | self.func_l_to_g = nn.Sequential( 71 | nn.Conv2d(self.half_channels, self.half_channels, kernel_size=1), 72 | NonLocalAttentionBlock(self.half_channels), 73 | ) 74 | self.func_l_to_l = nn.Sequential( 75 | nn.Conv2d(self.half_channels, self.half_channels, kernel_size=3, padding=1), 76 | nn.LeakyReLU(inplace=True), 77 | ) 78 | 79 | def forward(self, feat: torch.Tensor) -> torch.Tensor: 80 | global_feat = feat[:, self.half_channels :] 81 | local_feat = feat[:, : self.half_channels] 82 | out_global_feat = self.func_l_to_g(local_feat) + self.func_g_to_g(global_feat) 83 | out_local_feat = self.func_g_to_l(global_feat) + self.func_l_to_l(local_feat) 84 | return torch.cat([out_global_feat, out_local_feat], 1) 85 | 86 | 87 | class SFFI(nn.Module): 88 | def __init__(self, input_dim=3,hidden_dim = 64) -> None: 89 | super().__init__() 90 | self.first_block = nn.Sequential( 91 | nn.Conv2d(input_dim, hidden_dim, 1), 92 | FourierConvolutionBlock(hidden_dim), 93 | nn.Conv2d(hidden_dim, hidden_dim, 1), 94 | FourierConvolutionBlock(hidden_dim), 95 | nn.Conv2d(hidden_dim, hidden_dim, 1), 96 | ) 97 | self.second_block = nn.Sequential( 98 | FourierConvolutionBlock(hidden_dim), 99 | nn.Conv2d(hidden_dim, hidden_dim, 1), 100 | FourierConvolutionBlock(hidden_dim), 101 | nn.Conv2d(hidden_dim, hidden_dim, 1), 102 | ) 103 | self.third_block = nn.Sequential( 104 | FourierConvolutionBlock(hidden_dim), 105 | nn.Conv2d(hidden_dim, hidden_dim, 1), 106 | FourierConvolutionBlock(hidden_dim), 107 | nn.Conv2d(hidden_dim, hidden_dim, 1), 108 | FourierConvolutionBlock(hidden_dim), 109 | nn.Conv2d(hidden_dim, hidden_dim, 1), 110 | FourierConvolutionBlock(hidden_dim), 111 | nn.Conv2d(hidden_dim, hidden_dim, 1), 112 | ) 113 | self.fourth_block = nn.Sequential( 114 | FourierConvolutionBlock(hidden_dim), 115 | nn.Conv2d(hidden_dim, hidden_dim//4, kernel_size=3, padding=1, bias=False), 116 | nn.BatchNorm2d(num_features=hidden_dim//4), 117 | nn.ReLU(),# nn.Sigmoid, 118 | nn.Conv2d( hidden_dim//4, 1, kernel_size=1, padding=0, bias=False), 119 | ) 120 | def forward(self, feat): 121 | first_feat = self.first_block(feat) 122 | second_feat = first_feat + self.second_block(first_feat) 123 | third_feat = second_feat + self.third_block(second_feat) 124 | result=self.fourth_block(third_feat) 125 | return result 126 | 127 | 128 | -------------------------------------------------------------------------------- /code/model/modules/Style.py: 
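The Style.py module listed next collects StyleGAN2-style building blocks: FIR resampling (upfirdn2d), equalized-learning-rate convolution/linear layers, and a modulated convolution. A hedged usage sketch of two of these blocks, assuming the repository root is on PYTHONPATH (the import path below is an assumption; inside the package the modules use relative imports):

```python
import torch
from code.model.modules.Style import EqualLinear, StyleLayer  # import path is an assumption

# scalar per-sample condition -> per-channel scale, as in IDN's condition_scale layers
scale_mlp = EqualLinear(1, 64, bias=True, bias_init_val=1, activation=None)
scales = scale_mlp(torch.rand(4, 1))      # -> (4, 64)

# FIR-smoothed strided downsampling block, as in the conv_body_down path
down = StyleLayer(64, 128, 3, downsample=True)
out = down(torch.randn(4, 64, 32, 32))    # -> (4, 128, 16, 16)
```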
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 7 | out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 8 | return out 9 | 10 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 11 | input = input.permute(0, 2, 3, 1) 12 | _, in_h, in_w, minor = input.shape 13 | kernel_h, kernel_w = kernel.shape 14 | out = input.view(-1, in_h, 1, in_w, 1, minor) 15 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 16 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 17 | 18 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 19 | out = out[:,max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),:,] 20 | out = out.permute(0, 3, 1, 2) 21 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 22 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 23 | out = F.conv2d(out, w) 24 | out = out.reshape(-1,minor,in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,) 25 | return out[:, :, ::down_y, ::down_x] 26 | 27 | class FusedLeakyReLU(nn.Module): 28 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 29 | super().__init__() 30 | 31 | self.bias = nn.Parameter(torch.zeros(channel)) 32 | self.negative_slope = negative_slope 33 | self.scale = scale 34 | 35 | def forward(self, input): 36 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 37 | 38 | 39 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 40 | return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope) 41 | 42 | 43 | 44 | class StyleLayer(nn.Sequential): 45 | def __init__(self, in_channels, out_channels, kernel_size, downsample=False, resample_kernel=(1, 3, 1), bias=True, activate=True): 46 | layers = [] 47 | if downsample: 48 | layers.append( 49 | UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)) 50 | stride = 2 51 | self.padding = 1 52 | else: 53 | stride = 1 54 | self.padding = kernel_size // 2 55 | 56 | layers.append(EqualConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias and not activate)) 57 | if activate: 58 | if bias: 59 | layers.append(FusedLeakyReLU(out_channels)) 60 | else: 61 | layers.append(ScaledLeakyReLU(0.2)) 62 | 63 | super(StyleLayer, self).__init__(*layers) 64 | 65 | 66 | 67 | class UpFirDnSmooth(nn.Module): 68 | def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1): 69 | super(UpFirDnSmooth, self).__init__() 70 | self.upsample_factor = upsample_factor 71 | self.downsample_factor = downsample_factor 72 | self.kernel = make_resample_kernel(resample_kernel) 73 | if upsample_factor > 1: 74 | self.kernel = self.kernel * (upsample_factor**2) 75 | 76 | if upsample_factor > 1: 77 | pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1) 78 | self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1) 79 | elif downsample_factor > 1: 80 | pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1) 81 | self.pad = ((pad - 1) // 2, pad // 2) 82 | else: 83 | raise NotImplementedError 84 | 85 | def forward(self, x): 86 | out = 
upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad) 87 | return out 88 | 89 | def __repr__(self): 90 | return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}' 91 | f', downsample_factor={self.downsample_factor})') 92 | 93 | def make_resample_kernel(k): 94 | k = torch.tensor(k, dtype=torch.float32) 95 | if k.ndim == 1: 96 | k = k[None, :] * k[:, None] 97 | k /= k.sum() 98 | return k 99 | 100 | class EqualConv2d(nn.Module): 101 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): 102 | super(EqualConv2d, self).__init__() 103 | self.in_channels = in_channels 104 | self.out_channels = out_channels 105 | self.kernel_size = kernel_size 106 | self.stride = stride 107 | self.padding = padding 108 | self.scale = 1 / math.sqrt(in_channels * kernel_size**2) 109 | self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) 110 | if bias: 111 | self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) 112 | else: 113 | self.register_parameter('bias', None) # None 114 | 115 | def forward(self, x): 116 | out = F.conv2d(x, self.weight * self.scale,bias=self.bias, stride=self.stride, padding=self.padding,) 117 | return out 118 | 119 | def __repr__(self): 120 | return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' 121 | f'out_channels={self.out_channels}, ' 122 | f'kernel_size={self.kernel_size},' 123 | f' stride={self.stride}, padding={self.padding}, ' 124 | f'bias={self.bias is not None})') 125 | 126 | class ScaledLeakyReLU(nn.Module): 127 | def __init__(self, negative_slope=0.2): 128 | super(ScaledLeakyReLU, self).__init__() 129 | self.negative_slope = negative_slope 130 | 131 | def forward(self, x): 132 | out = F.leaky_relu(x, negative_slope=self.negative_slope) 133 | return out * math.sqrt(2) 134 | 135 | class EqualLinear(nn.Module): 136 | def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): 137 | super(EqualLinear, self).__init__() 138 | self.in_channels = in_channels 139 | self.out_channels = out_channels 140 | self.lr_mul = lr_mul 141 | self.activation = activation 142 | if self.activation not in ['fused_lrelu', None]: 143 | raise ValueError(f'Wrong activation value in EqualLinear: {activation}' 144 | "Supported ones are: ['fused_lrelu', None].") 145 | self.scale = (1 / math.sqrt(in_channels)) * lr_mul 146 | self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) 147 | if bias: 148 | self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) 149 | else: 150 | self.register_parameter('bias', None) 151 | 152 | def forward(self, x): 153 | if self.bias is None: 154 | bias = None 155 | else: 156 | bias = self.bias * self.lr_mul 157 | if self.activation == 'fused_lrelu': 158 | out = F.linear(x, self.weight * self.scale) 159 | out = fused_leaky_relu(out, bias) 160 | else: 161 | out = F.linear(x, self.weight * self.scale, bias=bias) 162 | return out 163 | 164 | def __repr__(self): 165 | return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' 166 | f'out_channels={self.out_channels}, bias={self.bias is not None})') 167 | 168 | 169 | class StyleLayer_norm_scale_shift(nn.Module): 170 | def __init__(self, 171 | in_channels, 172 | out_channels, 173 | kernel_size, 174 | num_style_feat, 175 | demodulate=True, 176 | sample_mode=None, 177 | resample_kernel=(1, 3, 3, 1)): 178 | super(StyleLayer_norm_scale_shift, self).__init__() 179 | self.modulated_conv = 
ModulatedLayer( 180 | in_channels, 181 | out_channels, 182 | kernel_size, 183 | num_style_feat, 184 | demodulate=demodulate, 185 | sample_mode=sample_mode, 186 | resample_kernel=resample_kernel) 187 | self.weight = nn.Parameter(torch.zeros(1)) 188 | self.activate = FusedLeakyReLU(out_channels) 189 | self.norm = Norm2Scale() 190 | 191 | def forward(self, x, noise=None, scale1=None, scale2=None, shift=None): 192 | scale1, scale2 = self.norm(scale1, scale2) 193 | out = x * scale1.view(-1, x.size(1), 1, 1) + shift * scale2.view(-1, x.size(1), 1, 1) 194 | out = self.activate(out) 195 | return out 196 | 197 | class ModulatedLayer(nn.Module): 198 | def __init__(self, 199 | in_channels, 200 | out_channels, 201 | kernel_size, 202 | num_style_feat, 203 | demodulate=True, 204 | sample_mode=None, 205 | resample_kernel=(1, 3, 3, 1), 206 | eps=1e-8): 207 | super(ModulatedLayer, self).__init__() 208 | self.in_channels = in_channels 209 | self.out_channels = out_channels 210 | self.kernel_size = kernel_size 211 | self.demodulate = demodulate 212 | self.sample_mode = sample_mode 213 | self.eps = eps 214 | 215 | if self.sample_mode == 'upsample': 216 | self.smooth = UpFirDnSmooth( 217 | resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size) 218 | elif self.sample_mode == 'downsample': 219 | self.smooth = UpFirDnSmooth( 220 | resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size) 221 | elif self.sample_mode is None: 222 | pass 223 | else: 224 | raise ValueError(f'Wrong sample mode {self.sample_mode}, ' 225 | "supported ones are ['upsample', 'downsample', None].") 226 | self.scale = 1 / math.sqrt(in_channels * kernel_size**2) 227 | self.modulation = EqualLinear( 228 | num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) 229 | self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) 230 | self.padding = kernel_size // 2 231 | 232 | def forward(self, x, style): 233 | b, c, h, w = x.shape 234 | style = self.modulation(style).view(b, 1, c, 1, 1) 235 | weight = self.scale * self.weight * style 236 | if self.demodulate: 237 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) 238 | weight = weight * demod.view(b, self.out_channels, 1, 1, 1) 239 | weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) 240 | if self.sample_mode == 'upsample': 241 | x = x.view(1, b * c, h, w) 242 | weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size) 243 | weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size) 244 | out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b) 245 | out = out.view(b, self.out_channels, *out.shape[2:4]) 246 | out = self.smooth(out) 247 | elif self.sample_mode == 'downsample': 248 | x = self.smooth(x) 249 | x = x.view(1, b * c, *x.shape[2:4]) 250 | out = F.conv2d(x, weight, padding=0, stride=2, groups=b) 251 | out = out.view(b, self.out_channels, *out.shape[2:4]) 252 | else: 253 | x = x.view(1, b * c, h, w) 254 | out = F.conv2d(x, weight, padding=self.padding, groups=b) 255 | out = out.view(b, self.out_channels, *out.shape[2:4]) 256 | return out 257 | 258 | def __repr__(self): 259 | return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' 260 | f'out_channels={self.out_channels}, ' 261 | f'kernel_size={self.kernel_size}, ' 262 | f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') 263 | 264 | class Norm2Scale(nn.Module): 265 | def 
forward(self, scale1, scale2): 266 | scales_norm = scale1 ** 2 + scale2 ** 2 + 1e-8 267 | return scale1 * torch.rsqrt(scales_norm), scale2 * torch.rsqrt(scales_norm) 268 | 269 | -------------------------------------------------------------------------------- /code/model/modules/half_IDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .Style import StyleLayer 4 | from utils.utils import Uhalf_ResnetBlocWithAttn, Downsample 5 | 6 | class half_IDN(nn.Module): 7 | def __init__(self, in_channel=6, inner_channel=64, norm_groups=32, channel_mults=(1, 2, 4, 8, 16), attn_res=(32), res_blocks=2, dropout=0.2, image_size=256): 8 | super().__init__() 9 | num_mults = len(channel_mults) 10 | pre_channel = inner_channel 11 | feat_channels = [pre_channel] 12 | now_res = image_size 13 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)] 14 | self.conv_body_first = StyleLayer(3, pre_channel, 3, bias=True, activate=True) 15 | self.conv_body_down = nn.ModuleList() 16 | self.condition_shift = nn.ModuleList() 17 | for ind in range(num_mults): 18 | is_last = (ind == num_mults - 1) 19 | use_attn = (now_res in attn_res) 20 | channel_mult = inner_channel * channel_mults[ind] 21 | self.conv_body_down.append(StyleLayer(pre_channel, channel_mult, 3, downsample=True)) 22 | self.condition_shift.append(StyleLayer(pre_channel, channel_mult, 3, bias=True, activate=False)) 23 | for _ in range(0, res_blocks): 24 | downs.append(Uhalf_ResnetBlocWithAttn(pre_channel, channel_mult, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 25 | feat_channels.append(channel_mult) 26 | pre_channel = channel_mult 27 | if not is_last: 28 | downs.append(Downsample(pre_channel)) 29 | feat_channels.append(pre_channel) 30 | now_res = now_res // 2 31 | self.downs = nn.ModuleList(downs) 32 | 33 | def forward(self, feat, img): 34 | feat = self.conv_body_first(feat) 35 | # g0 = self.condition_shift[0](feat) 36 | feat = self.conv_body_down[0](feat) 37 | g1 = self.condition_shift[1](feat) 38 | feat = self.conv_body_down[1](feat) 39 | g2 = self.condition_shift[2](feat) 40 | feat = self.conv_body_down[2](feat) 41 | g3 = self.condition_shift[3](feat) 42 | feat = self.conv_body_down[3](feat) 43 | g4 = self.condition_shift[4](feat) 44 | feat = self.downs[0](img) 45 | f0_1 = self.downs[1](feat) 46 | f0 = self.downs[2](f0_1) 47 | f1_1 = self.downs[3](f0) 48 | f1_2 = self.downs[4](f1_1) 49 | f1 = self.downs[5](f1_2) 50 | f2_1 = self.downs[6](f1) 51 | f2_2 = self.downs[7](f2_1) 52 | f2 = self.downs[8](f2_2) 53 | f3_1 = self.downs[9](f2) 54 | f3_2 = self.downs[10](f3_1) 55 | f3 = self.downs[11](f3_2) 56 | f4_1 = self.downs[12](f3) 57 | f4_2 = self.downs[13](f4_1) 58 | f4 = self.downs[14](f4_2) 59 | m1 = torch.cat([g1, f1], dim=1) 60 | m2 = torch.cat([g2, f2], dim=1) 61 | m3 = torch.cat([g3, f3], dim=1) 62 | m4 = torch.cat([g4, f4], dim=1) 63 | return [m1, m2, m3, m4] 64 | 65 | -------------------------------------------------------------------------------- /code/train_DCHFR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import wandb 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from datetime import datetime 9 | from torch.utils.data import DataLoader 10 | from torch.optim import lr_scheduler 11 | from utils.DCHFRDataset import DCHFRDataset 12 | from utils.utils import set_device, data_process, get_current_visuals, save_network, 
tensor2img, calculate_psnr 13 | from model.DefineNet import define_DCHFRnet 14 | 15 | def main(args): 16 | device = torch.device('cuda') 17 | 18 | # data set 19 | train_set = DCHFRDataset(args.root, 'train', if_transform=False, dataset_type=args.dataset_type, base_size=args.base_size) 20 | val_set = DCHFRDataset(args.root, 'val', if_transform=False, dataset_type=args.dataset_type, base_size=args.base_size) 21 | print("---------------finished loading dataset---------------") 22 | print('train data number:',len(train_set)) 23 | print('val data number:', len(val_set)) 24 | 25 | # data loader 26 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_worker) 27 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_worker) 28 | print("---------------finished loading dataloader---------------") 29 | 30 | # model 31 | DCHFRNet = set_device(define_DCHFRnet(args),device) 32 | DCHFRNet.set_loss(device) 33 | DCHFRNet.set_new_noise_schedule(device) 34 | 35 | DCHFRNet.train() 36 | optim_params = list(DCHFRNet.parameters()) 37 | DCHFRopt = torch.optim.Adam(optim_params, lr=1e-4) 38 | scheduler = lr_scheduler.MultiStepLR(DCHFRopt, milestones=[1000], gamma=0.2) 39 | print("---------------finished loading model---------------") 40 | 41 | # train 42 | current_step = 0 43 | print('begin step',current_step) 44 | current_epoch = 0 45 | print('begin epoch', current_epoch) 46 | n_epoch = args.n_epoch 47 | print("---------------finished setting noise schedule---------------") 48 | if args.phase == 'train': 49 | max_psnr = -1e18 50 | while current_epoch < n_epoch: 51 | train_loss = 0.0 52 | scheduler.step() 53 | for _, train_data in tqdm(enumerate(train_loader), total=len(train_loader)): 54 | current_step += 1 55 | feed_data = data_process(train_data, mode='train') 56 | feed_data = set_device(feed_data,device) 57 | loss = DCHFRNet(feed_data) 58 | b, c, h, w = feed_data['hr'].shape 59 | loss = loss.sum() / int(b * c * h * w) 60 | DCHFRopt.zero_grad() 61 | loss.backward() 62 | DCHFRopt.step() 63 | train_loss += loss 64 | wandb.log({'train_epoch': current_epoch, 'train_loss': train_loss/len(train_loader)}) 65 | current_epoch += 1 66 | if current_epoch < 301: 67 | val_freq = 1 68 | else: 69 | val_freq = 30 70 | if current_epoch % val_freq == 0: 71 | print("---------------validation---------------") 72 | avg_psnr = 0.0 73 | idx = 0 74 | if args.dataset_type == "BigReal" or args.dataset_type == "BigSim": 75 | mean = [-0.1246] 76 | std = [1.0923] 77 | elif args.dataset_type == "NUAA" or args.dataset_type == "NUDT" or args.dataset_type == "IRSTD": 78 | mean = [.485, .456, .406] 79 | std = [.229, .224, .225] 80 | result_hr_path = args.results_hr.rsplit('/', 1)[0] + '/{}/'.format(current_epoch) + args.results_hr.rsplit('/', 1)[1] 81 | result_sr_path = args.results_sr.rsplit('/', 1)[0] + '/{}/'.format(current_epoch) + args.results_sr.rsplit('/', 1)[1] 82 | result_lr_path = args.results_lr.rsplit('/', 1)[0] + '/{}/'.format(current_epoch) + args.results_lr.rsplit('/', 1)[1] 83 | os.makedirs('{}'.format(result_hr_path), exist_ok=True) 84 | os.makedirs('{}'.format(result_sr_path), exist_ok=True) 85 | os.makedirs('{}'.format(result_lr_path), exist_ok=True) 86 | DCHFRNet.set_new_noise_schedule(device) 87 | for _, val_data in tqdm(enumerate(val_loader), total=len(val_loader)): 88 | idx += 1 89 | feed_data = data_process(val_data, mode='val') 90 | feed_data = set_device(feed_data, device) 91 | DCHFRNet.eval() 92 | with torch.no_grad(): 93 
| SR_imgs = DCHFRNet.super_resolution(feed_data, continous=False, use_ddim=True) 94 | visuals = get_current_visuals(SR_imgs,feed_data) 95 | sr_img = tensor2img(visuals['SR']) 96 | hr_img = tensor2img(visuals['HR']) 97 | avg_psnr += calculate_psnr(sr_img, hr_img) 98 | 99 | sr_img = visuals['SR'].squeeze().cpu().numpy() 100 | for i in range(3): 101 | sr_img[i] = sr_img[i] * std[i] + mean[i] 102 | sr_img = sr_img * 255 103 | sr_img = np.transpose(sr_img, (1, 2, 0)) 104 | 105 | hr_img = visuals['HR'].squeeze().cpu().numpy() 106 | for i in range(3): 107 | hr_img[i] = hr_img[i] * std[i] + mean[i] 108 | hr_img = hr_img * 255 109 | hr_img = np.transpose(hr_img, (1, 2, 0)) 110 | 111 | lr_img = visuals['LR'].squeeze().cpu().numpy() 112 | for i in range(3): 113 | lr_img[i] = lr_img[i] * std[i] + mean[i] 114 | lr_img = lr_img * 255 115 | lr_img = np.transpose(lr_img, (1, 2, 0)) 116 | cv2.imwrite('{}/{}_hr.png'.format(result_hr_path, idx), cv2.cvtColor(hr_img, cv2.COLOR_RGB2GRAY)) 117 | cv2.imwrite('{}/{}_sr.png'.format(result_sr_path, idx), cv2.cvtColor(sr_img, cv2.COLOR_RGB2GRAY)) 118 | cv2.imwrite('{}/{}_lr.png'.format(result_lr_path, idx), cv2.cvtColor(lr_img, cv2.COLOR_RGB2GRAY)) 119 | print("---------------result---------------") 120 | avg_psnr = avg_psnr / idx 121 | wandb.log({'train_epoch': current_epoch, 'psnr': avg_psnr}) 122 | print("psnr:", avg_psnr) 123 | if avg_psnr >= max_psnr: 124 | max_psnr = avg_psnr 125 | save_network(args.checkpoint, current_epoch, current_step, DCHFRNet, SRopt, best='psnr_{}'.format(max_psnr)) 126 | wandb.log({'train_epoch': current_epoch, 'max_psnr': max_psnr}) 127 | print("---------------finished saving best weights---------------") 128 | DCHFRNet.train() 129 | DCHFRNet.set_new_noise_schedule(device) 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument('--root', type=str, default='/Datasets/', help='choose datasets') 135 | parser.add_argument('--dataset_type', type=str, default='BigReal', help='BigReal or BigSim or NUAA or NUDT or IRSTD') 136 | parser.add_argument('--base_size', type=int, default=512, help='base_size') 137 | parser.add_argument('--crop_size', type=int, default=480, help='crop_size') 138 | parser.add_argument('--batch_size', type=int, default=8, help='train_batch_size') 139 | parser.add_argument('--val_batch_size', type=int, default=1, help='val_batch_size') 140 | parser.add_argument('--num_worker', type=int, default=8, help='num_workers') 141 | parser.add_argument('--n_epoch', type=int, default=2000, help='iter number') 142 | parser.add_argument('--phase', type=str, default='train', help='train or val') 143 | parser.add_argument('--results_hr', type=str, default='results/hr', help='hr result fold') 144 | parser.add_argument('--results_sr', type=str, default='results/sr', help='sr result fold') 145 | parser.add_argument('--results_lr', type=str, default='results/lr', help='lr result fold') 146 | parser.add_argument('--checkpoint', type=str, default='checkpoints', help='checkpoint fold') 147 | args = parser.parse_args() 148 | os.environ["WANDB_API_KEY"] = "xxxx" 149 | wandb.login() 150 | nowtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 151 | wandb.init(project='train_DCHFR', config=args.__dict__, name='BigReal' + nowtime, save_code=False) 152 | num_gpus = torch.cuda.device_count() 153 | print("Number of available GPUs:", num_gpus) 154 | main(args) 155 | -------------------------------------------------------------------------------- /code/train_ISDTD.py: 
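This script trains the ISDTD detection network on top of a frozen, pretrained DCHFR branch: the DCHFR weights are loaded from --checkpoint_path, all of their gradients are disabled, and the detector receives both the encoder features and the raw image, supervised with a combined BCE+Dice loss. A minimal, self-contained sketch of that training step is given here; the two nn.Conv2d modules are hypothetical stand-ins for the real encoder and head, and channel concatenation stands in for the script's two-argument call net(feat, img).

import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 16, 3, padding=1)    # stand-in for the pretrained DCHFR encoder
head = nn.Conv2d(16 + 3, 1, 3, padding=1)   # stand-in for the ISDTD network

for p in encoder.parameters():              # freeze the pretrained branch
    p.requires_grad = False

opt = torch.optim.Adam(head.parameters(), lr=1e-4)
img = torch.rand(2, 3, 64, 64)                     # toy infrared images
label = (torch.rand(2, 1, 64, 64) > 0.99).float()  # toy sparse target masks

feat = encoder(img)                                 # frozen features, no graph is built through the encoder
pred = head(torch.cat([feat, img], dim=1))          # head sees features plus the raw image
loss = nn.functional.binary_cross_entropy_with_logits(pred, label)  # the script uses BCEDiceloss instead
opt.zero_grad()
loss.backward()                                     # gradients reach only the head parameters
opt.step()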
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import wandb 4 | import argparse 5 | from tqdm import tqdm 6 | import os.path as ops 7 | from datetime import datetime 8 | from utils.ISDTDDataset import ISDTDDataset 9 | from torch.utils.data import DataLoader 10 | from model.DefineNet import define_trained_DCHFRnet, define_ISDTDnet 11 | from utils.utils import set_device 12 | from utils.BCEDiceloss import BCEDiceloss 13 | from torch.optim import lr_scheduler 14 | from utils.Imetrics import SigmoidMetric, SamplewiseSigmoidMetric 15 | from utils.Pmetric import ROCMetric, PD_FA 16 | 17 | def main(args): 18 | device = torch.device('cuda') 19 | 20 | train_set = ISDTDDataset(args.root, 'train', if_transform=True, base_size=args.base_size, dataset_type=args.dataset_type) 21 | val_set = ISDTDDataset(args.root, 'val', if_transform=True, base_size=args.base_size, dataset_type=args.dataset_type) 22 | print("---------------finished loading dataset---------------") 23 | print('train data number:',len(train_set)) 24 | print('val data number:', len(val_set)) 25 | 26 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_worker) 27 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_worker) 28 | print("---------------finished loading dataloader---------------") 29 | 30 | DCHFRnet = set_device(define_trained_DCHFRnet(args), device) 31 | DCHFR_checkpoint_path = args.checkpoint_path 32 | DCHFR_checkpoint = torch.load(DCHFR_checkpoint_path) 33 | DCHFRnet.load_state_dict(DCHFR_checkpoint, strict=False) 34 | for param in DCHFRnet.parameters(): 35 | param.requires_grad = False 36 | print("---------------finished loading DCHFR checkpoint---------------") 37 | 38 | net = set_device(define_ISDTDnet(args), device) 39 | loss_func = BCEDiceloss().to(device) 40 | eval_iou = SigmoidMetric() 41 | eval_niou = SamplewiseSigmoidMetric(1, score_thresh=0.5) 42 | eval_ROC = ROCMetric(1, 10) 43 | eval_PD_FA = PD_FA(1, 10) 44 | net.train() 45 | optim_params = list(net.parameters()) 46 | opt = torch.optim.Adam(optim_params, lr=1e-4) 47 | scheduler = lr_scheduler.MultiStepLR(opt, milestones=[50,100,150], gamma=0.5) 48 | print("---------------finished loading ISDTD model---------------") 49 | 50 | current_epoch = 0 51 | n_epoch = args.n_epoch 52 | best_iou = 0 53 | best_niou = 0 54 | current_step = 0 55 | if args.phase == 'train': 56 | while current_epoch < n_epoch: 57 | train_loss = 0.0 58 | tbar = tqdm(train_loader) 59 | for _, train_data in enumerate(tbar): 60 | train_data = set_device(train_data, device) 61 | img = train_data['img'] 62 | feat = DCHFRnet.encoder(img, img.shape[2:]) 63 | pred = net(feat,img) 64 | label = train_data['label'] 65 | label = label[:, 0:1, :, :]/255 66 | loss_ISDTD, bce, dice = loss_func(pred,label) 67 | current_step += 1 68 | opt.zero_grad() 69 | loss_ISDTD.backward() 70 | opt.step() 71 | scheduler.step() 72 | train_loss += loss_ISDTD 73 | tbar.set_description('Epoch %d, train loss %.4f, bce %.4f, dice %.4f' % (current_epoch, loss_ISDTD, bce, dice)) 74 | wandb.log({'train_epoch': current_epoch, 'train_loss': train_loss/len(train_loader)}) 75 | current_epoch += 1 76 | if current_epoch % args.val_freq == 0: 77 | print("---------------validation---------------") 78 | val_tbar = tqdm(val_loader) 79 | eval_iou.reset() 80 | eval_niou.reset() 81 | eval_PD_FA.reset() 82 | eval_ROC.reset() 83 | eval_loss = 0.0 84 | net.eval() 85 | for _, val_data 
in enumerate(val_tbar): 86 | val_img = val_data['img'] 87 | val_img = set_device(val_img, device) 88 | with torch.no_grad(): 89 | val_feat = DCHFRnet.encoder(val_img, val_img.shape[2:]) 90 | val_pred = net(val_feat,val_img) 91 | val_label = val_data['label'] 92 | val_label = set_device(val_label, device) 93 | val_label = val_label[:, 0:1, :, :]/255 94 | val_pred = val_pred 95 | val_loss_ISDTD, val_bce, val_dice = loss_func(val_pred, val_label) 96 | eval_loss += val_loss_ISDTD 97 | val_label = val_label.cpu() 98 | val_pred = val_pred.cpu() 99 | eval_iou.update(val_pred, val_label) 100 | eval_niou.update(val_pred, val_label) 101 | eval_ROC.update(val_pred, val_label) 102 | eval_PD_FA.update(val_pred, val_label) 103 | val_tbar.set_description('Epoch %d, val loss %.4f, bce %.4f, dice %.4f' % (current_epoch, val_loss_ISDTD, val_bce, val_dice)) 104 | FA, PD = eval_PD_FA.get(len(val_loader)) 105 | _, IoU = eval_iou.get() 106 | _, nIoU = eval_niou.get() 107 | _, _, _, _, F1_score = eval_ROC.get() 108 | FA = FA[5] * 1000000 109 | PD = PD[5] * 100 110 | F1_score = F1_score[5] 111 | wandb.log({'epoch': current_epoch, "ioU": IoU, "nioU": nIoU, 'test_loss': eval_loss/len(val_loader), "PD": PD, "FA": FA, "F1_score": F1_score}) 112 | if IoU > best_iou: 113 | best_iou = IoU 114 | if IoU > 7.0: 115 | pkl_name = 'best-Epoch-%3d_IoU-%.4f_nIoU-%.4f.pkl' % (current_epoch, best_iou, nIoU) 116 | torch.save(net.state_dict(), ops.join(args.save_path, pkl_name)) 117 | if nIoU > best_niou: 118 | best_niou = nIoU 119 | wandb.log({'epoch': current_epoch, "best_ioU": best_iou, "best_nioU": best_niou}) 120 | net.train() 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('--root', type=str, default='/Datasets/', help='choose datasets') 126 | parser.add_argument('--dataset_type', type=str, default='BigReal', help='BigReal or BigSim or NUAA or NUDT or IRSTD') 127 | parser.add_argument('--checkpoint_path', type=str, default='/xxxx_gen.pth', help='DCHFR checkpoint path') 128 | parser.add_argument('--ISDTD_checkpoint_path', type=str, default='.pth', help='trained ISDTD checkpoint path') 129 | parser.add_argument('--base_size', type=int, default=512, help='img_size') 130 | parser.add_argument('--batch_size', type=int, default=4, help='train_batch_size') 131 | parser.add_argument('--val_batch_size', type=int, default=1, help='val_batch_size') 132 | parser.add_argument('--num_worker', type=int, default=8, help='num_workers') 133 | parser.add_argument('--n_epoch', type=int, default=200, help='iter number') 134 | parser.add_argument('--phase', type=str, default='train', help='train or test') 135 | parser.add_argument('--val_freq', type=int, default=1, help='validation frequent') 136 | parser.add_argument('--results_mask', type=str, default='results', help='mask result fold') 137 | parser.add_argument('--save_path', type=str, default='checkpoints', help='checkpoint fold') 138 | args = parser.parse_args() 139 | os.environ["WANDB_API_KEY"] = "xxxx" 140 | wandb.login() 141 | nowtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 142 | wandb.init(project='ISDTD', config=args.__dict__, name='BigReal' + nowtime, save_code=False) 143 | 144 | main(args) -------------------------------------------------------------------------------- /code/utils/BCEDiceloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def dice_loss(input, target): 6 | smooth = 1e-5 7 | num 
= target.size(0) 8 | input = input.view(num, -1) 9 | target = target.view(num, -1) 10 | intersection = (input * target) 11 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) 12 | dice = 1 - dice.sum() / num 13 | return dice 14 | 15 | class BCEDiceloss(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, input, target): 20 | input = torch.sigmoid(input) 21 | dice = dice_loss(input,target) 22 | bce = F.binary_cross_entropy(input, target) 23 | return 0.8 * bce + 0.2 * dice, bce, dice 24 | -------------------------------------------------------------------------------- /code/utils/DCHFRDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import numpy as np 5 | from torchvision import transforms 6 | from torch.utils.data import Dataset 7 | from PIL import Image, ImageOps, ImageFilter 8 | 9 | class DCHFRDataset(Dataset): 10 | def __init__(self, root, mode, if_transform, base_size, dataset_type, crop_size=288): 11 | self.mode = mode 12 | self.if_transform = if_transform 13 | self.base_size = base_size 14 | self.crop_size = crop_size 15 | self.dataset_type = dataset_type 16 | 17 | if self.dataset_type == "BigReal" or self.dataset_type == "BigSim": 18 | self.imgs_dir = os.path.join(root, 'images') 19 | self.names = [] 20 | if self.mode == 'train': 21 | self.list_dir = os.path.join(root, 'train.txt') 22 | elif self.mode == 'val': 23 | self.list_dir = os.path.join(root, 'test_hr.txt') 24 | with open(self.list_dir, 'r') as f: 25 | self.names += [line.strip() for line in f.readlines()] 26 | self.dataset_len = len(self.names) 27 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([-0.1246], [1.0923])]) 28 | elif self.dataset_type == "NUAA" or self.dataset_type == "NUDT" or self.dataset_type == "IRSTD": 29 | if self.mode == 'train': 30 | self.imgs = sorted(glob.glob(os.path.join(root, 'train_imgs', "*.png"))) 31 | elif self.mode == 'val': 32 | self.imgs = sorted(glob.glob(os.path.join(root, 'test_imgs', "*.png"))) 33 | self.dataset_len = len(self.imgs) 34 | self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) 35 | else: 36 | raise NotImplementedError( 'dataset_type is wrong!') 37 | 38 | 39 | def __len__(self): 40 | return self.dataset_len 41 | 42 | def _sync_transform(self, img): 43 | if random.random() < 0.5: 44 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 45 | crop_size = self.crop_size 46 | base_size = self.base_size 47 | long_size = random.randint(int(base_size * 0.5), int(base_size * 2.0)) 48 | w, h = img.size 49 | if h > w: 50 | oh = long_size 51 | ow = int(1.0 * w * long_size / h + 0.5) 52 | short_size = ow 53 | else: 54 | ow = long_size 55 | oh = int(1.0 * h * long_size / w + 0.5) 56 | short_size = oh 57 | img = img.resize((ow, oh), Image.BILINEAR) 58 | if short_size < crop_size: 59 | padh = crop_size - oh if oh < crop_size else 0 60 | padw = crop_size - ow if ow < crop_size else 0 61 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 62 | w, h = img.size 63 | x1 = random.randint(0, w - crop_size) 64 | y1 = random.randint(0, h - crop_size) 65 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 66 | if random.random() < 0.5: 67 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 68 | img_lr = img.resize((base_size/4, base_size/4), Image.BICUBIC) 69 | img_hr = np.array(img) 70 | img_lr = 
np.array(img_lr) 71 | return img_lr, img_hr 72 | 73 | 74 | def _testval_sync_transform(self, img): 75 | base_size = self.base_size 76 | img_hr = img.resize ((base_size, base_size), Image.BILINEAR) 77 | img_lr = img.resize((64, 64), Image.BICUBIC) 78 | img_hr = np.array(img_hr) 79 | img_lr = np.array(img_lr) 80 | return img_lr, img_hr 81 | 82 | def __getitem__(self, index): 83 | if self.dataset_type == "BigReal" or self.dataset_type == "BigSim": 84 | name = self.names[index] 85 | img_path = self.imgs_dir+'/'+name+'.png' 86 | elif self.dataset_type == "NUAA" or self.dataset_type == "NUDT" or self.dataset_type == "IRSTD": 87 | img_path = self.imgs[index] 88 | img = Image.open(img_path).convert("RGB") 89 | 90 | if self.if_transform == True: 91 | if self.mode == 'train': 92 | img_lr, img_hr = self._sync_transform(img) 93 | img_lr = self.transform(img_lr) 94 | img_hr = self.transform(img_hr) 95 | if random.random() < 0.5: 96 | img_lr = img_lr.flip(-1) 97 | img_hr = img_hr.flip(-1) 98 | return {'lr': img_lr, 'hr': img_hr} 99 | elif self.mode == 'val': 100 | img_lr, img_hr = self._testval_sync_transform(img) 101 | img_lr = self.transform(img_lr) 102 | img_hr = self.transform(img_hr) 103 | if random.random() < 0.5: 104 | img_lr = img_lr.flip(-1) 105 | img_hr = img_hr.flip(-1) 106 | return {'lr': img_lr, 'hr': img_hr} 107 | else: 108 | base_size = self.base_size 109 | lr_base_size = int(base_size/4) 110 | img_hr = img.resize ((base_size, base_size), Image.BILINEAR) 111 | img_lr = img.resize((lr_base_size, lr_base_size), Image.BICUBIC) 112 | img_hr = self.transform(img_hr) 113 | img_lr = self.transform(img_lr) 114 | return {'lr': img_lr, 'hr': img_hr} 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /code/utils/ISDTDDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import numpy as np 5 | from torchvision import transforms 6 | from torch.utils.data import Dataset 7 | from PIL import Image, ImageOps, ImageFilter 8 | 9 | # NUAA 10 | def NUAA_extract_number(filename): 11 | parts = filename.split('_') 12 | number_str = parts[-1].split('.')[0] 13 | return int(number_str) 14 | 15 | def NUAAlabel_extract_number(filename): 16 | parts = filename.split('_') 17 | number_str = parts[-2] 18 | return int(number_str) 19 | # NUDT 20 | def NUDT_extract_number(filename): 21 | parts = filename.split('/') 22 | number_str = parts[-1].split('.')[0] 23 | return int(number_str) 24 | 25 | def NUDTlabel_extract_number(filename): 26 | parts = filename.split('/') 27 | number_str = parts[-1].split('.')[0] 28 | return int(number_str) 29 | 30 | 31 | class ISDTDDataset(Dataset): 32 | def __init__(self, root, mode, if_transform, base_size, dataset_type, crop_size=480): 33 | self.mode = mode 34 | self.if_transform = if_transform 35 | self.base_size = base_size 36 | self.crop_size = crop_size 37 | self.dataset_type = dataset_type 38 | 39 | if self.dataset_type == "BigReal" or self.dataset_type == "BigSim": 40 | self.imgs_dir = os.path.join(root, 'images') 41 | self.mask_dir = os.path.join(root, 'masks') 42 | self.names = [] 43 | if self.mode == 'train': 44 | self.list_dir = os.path.join(root, 'train.txt') 45 | elif self.mode == 'val': 46 | self.list_dir = os.path.join(root, 'test.txt') 47 | with open(self.list_dir, 'r') as f: 48 | self.names += [line.strip() for line in f.readlines()] 49 | self.dataset_len = len(self.names) 50 | elif self.dataset_type == "NUAA": 51 | if 
self.mode == 'train': 52 | self.imgs = sorted(glob.glob(os.path.join(root, 'train_imgs', "*.png")),key=NUAA_extract_number) 53 | self.labels = sorted(glob.glob(os.path.join(root, 'train_labels', "*.png")),key=NUAAlabel_extract_number) 54 | if self.mode == 'val': 55 | self.imgs = sorted(glob.glob(os.path.join(root, 'test_imgs', "*.png")),key=NUAA_extract_number) 56 | self.labels = sorted(glob.glob(os.path.join(root, 'test_labels', "*.png")),key=NUAAlabel_extract_number) 57 | self.dataset_len = len(self.imgs) 58 | elif self.dataset_type == "NUDT": 59 | if self.mode == 'train': 60 | self.imgs = sorted(glob.glob(os.path.join(root, 'train_imgs', "*.png")),key=NUDT_extract_number) 61 | self.labels = sorted(glob.glob(os.path.join(root, 'train_labels', "*.png")),key=NUDTlabel_extract_number) 62 | if self.mode == 'val': 63 | self.imgs = sorted(glob.glob(os.path.join(root, 'test_imgs', "*.png")),key=NUDT_extract_number) 64 | self.labels = sorted(glob.glob(os.path.join(root, 'test_labels', "*.png")),key=NUDTlabel_extract_number) 65 | self.dataset_len = len(self.imgs) 66 | elif self.dataset_type == "IRSTD": 67 | if self.mode == 'train': 68 | self.imgs = sorted(glob.glob(os.path.join(root, 'train_imgs', "*.png"))) 69 | self.labels = sorted(glob.glob(os.path.join(root, 'train_labels', "*.png"))) 70 | if self.mode == 'val': 71 | self.imgs = sorted(glob.glob(os.path.join(root, 'test_imgs', "*.png"))) 72 | self.labels = sorted(glob.glob(os.path.join(root, 'test_labels', "*.png"))) 73 | self.dataset_len = len(self.imgs) 74 | else: 75 | raise NotImplementedError( 'dataset_type is wrong!') 76 | self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([-0.1246], [1.0923])]) # transforms.ToTensor(), 77 | 78 | def __len__(self): 79 | return self.dataset_len 80 | 81 | def _sync_transform(self, img, mask): 82 | if random.random() < 0.5: 83 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 84 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 85 | crop_size = self.crop_size 86 | long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 87 | w, h = img.size 88 | if h > w: 89 | oh = long_size 90 | ow = int(1.0 * w * long_size / h + 0.5) 91 | short_size = ow 92 | else: 93 | ow = long_size 94 | oh = int(1.0 * h * long_size / w + 0.5) 95 | short_size = oh 96 | img = img.resize((ow, oh), Image.BILINEAR) 97 | mask = mask.resize((ow, oh), Image.NEAREST) 98 | if short_size < crop_size: 99 | padh = crop_size - oh if oh < crop_size else 0 100 | padw = crop_size - ow if ow < crop_size else 0 101 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 102 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 103 | w, h = img.size 104 | x1 = random.randint(0, w - crop_size) 105 | y1 = random.randint(0, h - crop_size) 106 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 107 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 108 | if random.random() < 0.5: 109 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 110 | img, mask = np.array(img), np.array(mask, dtype=np.float32) 111 | return img, mask 112 | 113 | def _testval_sync_transform(self, img, mask): 114 | base_size = self.base_size 115 | img = img.resize ((base_size, base_size), Image.BILINEAR) 116 | mask = mask.resize((base_size, base_size), Image.NEAREST) 117 | img, mask = np.array(img), np.array(mask, dtype=np.float32) 118 | return img, mask 119 | 120 | def __getitem__(self, index): 121 | if self.dataset_type == "BigReal" or self.dataset_type == "BigSim": 122 | name = 
self.names[index] 123 | img_path = self.imgs_dir+'/'+name+'.png' 124 | label_path = self.mask_dir+'/'+name+'.png' 125 | img = Image.open(img_path).convert("RGB") 126 | label = Image.open(label_path).convert("RGB") 127 | out_name = name+'.png' 128 | elif self.dataset_type == "NUAA" or self.dataset_type == "NUDT" or self.dataset_type == "IRSTD": 129 | img = Image.open(self.imgs[index]).convert("RGB") 130 | label = Image.open(self.labels[index]).convert("RGB") 131 | out_name = self.imgs[index].split("/")[-1] 132 | 133 | 134 | if self.if_transform == True: 135 | if self.mode == 'train': 136 | img, label = self._sync_transform(img,label) 137 | img, label = self.transform(img), transforms.ToTensor()(label) 138 | if random.random() < 0.5: 139 | img = img.flip(-1) 140 | label = label.flip(-1) 141 | return {'img': img, 'label': label} 142 | 143 | elif self.mode == 'val': 144 | width = img.width 145 | height = img.height 146 | img, label = self._testval_sync_transform(img,label) 147 | img, label = self.transform(img), transforms.ToTensor()(label) 148 | return {'img': img, 'label': label, 'name': out_name, 'width': width, 'height': height} 149 | 150 | else: 151 | if self.mode == 'train': 152 | img = img.resize((self.base_size, self.base_size)) 153 | label = label.resize((self.base_size, self.base_size)) 154 | img, label = self.transform(img), transforms.ToTensor()(label) 155 | if random.random() < 0.5: 156 | img = img.flip(-1) 157 | label = label.flip(-1) 158 | return {'img': img, 'label': label} 159 | 160 | elif self.mode == 'val': 161 | width = img.width 162 | height = img.height 163 | img, label = self.transform(img), transforms.ToTensor()(label) 164 | return {'img': img, 'label': label, 'name': out_name, 'width': width, 'height': height} 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /code/utils/Imetrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class SigmoidMetric(): 4 | def __init__(self): 5 | self.reset() 6 | 7 | def update(self, pred, labels): 8 | correct, labeled = self.batch_pix_accuracy1(pred, labels) 9 | inter, union = self.batch_intersection_union1(pred, labels) 10 | 11 | self.total_correct += correct 12 | self.total_label += labeled 13 | self.total_inter += inter 14 | self.total_union += union 15 | 16 | def get(self): 17 | """Gets the current evaluation result.""" 18 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 19 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 20 | mIoU = IoU.mean() 21 | return pixAcc, mIoU 22 | 23 | def reset(self): 24 | """Resets the internal evaluation result to initial state.""" 25 | self.total_inter = 0 26 | self.total_union = 0 27 | self.total_correct = 0 28 | self.total_label = 0 29 | 30 | def batch_pix_accuracy1(self, output, target): 31 | assert output.shape == target.shape 32 | output = output.detach().numpy() 33 | target = target.detach().numpy() 34 | 35 | predict = (output > 0.0).astype('int64') 36 | pixel_labeled = np.sum(target > 0) 37 | pixel_correct = np.sum((predict == target)*(target > 0)) 38 | assert pixel_correct <= pixel_labeled 39 | return pixel_correct, pixel_labeled 40 | 41 | def batch_intersection_union1(self, output, target): 42 | mini = 1 43 | maxi = 1 44 | nbins = 1 45 | predict = (output.detach().numpy() > 0).astype('int64') 46 | target = target.detach().numpy().astype('int64') 47 | intersection = predict * (predict == target) 48 | 49 | area_inter, _ = 
np.histogram(intersection, bins=nbins, range=(mini, maxi)) 50 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 51 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 52 | area_union = area_pred + area_lab - area_inter 53 | assert (area_inter <= area_union).all() 54 | return area_inter, area_union 55 | 56 | class SamplewiseSigmoidMetric(): 57 | def __init__(self, nclass, score_thresh=0.5): 58 | self.nclass = nclass 59 | self.score_thresh = score_thresh 60 | self.reset() 61 | 62 | def update(self, preds, labels): 63 | """Updates the internal evaluation result.""" 64 | inter_arr, union_arr = self.batch_intersection_union2(preds, labels, self.nclass, self.score_thresh) 65 | self.total_inter = np.append(self.total_inter, inter_arr) 66 | self.total_union = np.append(self.total_union, union_arr) 67 | 68 | def get(self): 69 | """Gets the current evaluation result.""" 70 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 71 | mIoU = IoU.mean() 72 | return IoU, mIoU 73 | 74 | def reset(self): 75 | """Resets the internal evaluation result to initial state.""" 76 | self.total_inter = np.array([]) 77 | self.total_union = np.array([]) 78 | self.total_correct = np.array([]) 79 | self.total_label = np.array([]) 80 | 81 | def batch_intersection_union2(self, output, target, nclass, score_thresh): 82 | """mIoU""" 83 | mini = 1 84 | maxi = 1 85 | nbins = 1 86 | predict = (output.detach().numpy() > score_thresh).astype('int64') 87 | 88 | target = target.detach().numpy().astype('int64') # T 89 | intersection = predict * (predict == target) # TP 90 | 91 | num_sample = intersection.shape[0] 92 | area_inter_arr = np.zeros(num_sample) 93 | area_pred_arr = np.zeros(num_sample) 94 | area_lab_arr = np.zeros(num_sample) 95 | area_union_arr = np.zeros(num_sample) 96 | 97 | for b in range(num_sample): 98 | area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi)) 99 | area_inter_arr[b] = area_inter 100 | 101 | area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi)) 102 | area_pred_arr[b] = area_pred 103 | 104 | area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi)) 105 | area_lab_arr[b] = area_lab 106 | 107 | area_union = area_pred + area_lab - area_inter 108 | area_union_arr[b] = area_union 109 | 110 | assert (area_inter <= area_union).all() 111 | 112 | return area_inter_arr, area_union_arr 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /code/utils/Pmetric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from skimage import measure 4 | class ROCMetric(): 5 | """Computes pixAcc and mIoU metric scores 6 | """ 7 | def __init__(self, nclass, bins): 8 | super(ROCMetric, self).__init__() 9 | self.nclass = nclass 10 | self.bins = bins 11 | self.tp_arr = np.zeros(self.bins+1) 12 | self.pos_arr = np.zeros(self.bins+1) 13 | self.fp_arr = np.zeros(self.bins+1) 14 | self.neg_arr = np.zeros(self.bins+1) 15 | self.class_pos=np.zeros(self.bins+1) 16 | # self.reset() 17 | 18 | def update(self, preds, labels): 19 | 20 | for iBin in range(self.bins+1): 21 | score_thresh = (iBin + 0.0) / self.bins 22 | i_tp, i_pos, i_fp, i_neg,i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass,score_thresh) 23 | self.tp_arr[iBin] += i_tp 24 | self.pos_arr[iBin] += i_pos 25 | self.fp_arr[iBin] += i_fp 26 | self.neg_arr[iBin] += i_neg 27 | self.class_pos[iBin]+=i_class_pos 28 | 29 | def get(self): 30 | 
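        # tp_rates / fp_rates trace the ROC curve over the threshold bins accumulated in update();
        # the +0.001 terms guard against division by zero in empty bins. precision divides by
        # class_pos (= TP + FP), and f1_score is the harmonic mean of precision and recall.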
31 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 32 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 33 | 34 | recall = self.tp_arr / (self.pos_arr + 0.001) 35 | precision = self.tp_arr / (self.class_pos + 0.001) 36 | 37 | f1_score = (2*precision*recall)/(precision+recall) 38 | 39 | 40 | return tp_rates, fp_rates, recall, precision,f1_score 41 | 42 | def reset(self): 43 | 44 | self.tp_arr = np.zeros([11]) 45 | self.pos_arr = np.zeros([11]) 46 | self.fp_arr = np.zeros([11]) 47 | self.neg_arr = np.zeros([11]) 48 | self.class_pos= np.zeros([11]) 49 | 50 | 51 | 52 | class PD_FA(): 53 | def __init__(self, nclass, bins): 54 | super(PD_FA, self).__init__() 55 | self.nclass = nclass 56 | self.bins = bins 57 | self.image_area_total = [] 58 | self.image_area_match = [] 59 | self.FA = np.zeros(self.bins+1) 60 | self.PD = np.zeros(self.bins + 1) 61 | self.target= np.zeros(self.bins + 1) 62 | def update(self, preds, labels): 63 | self.W = preds.shape[3] 64 | self.W = int(self.W) 65 | 66 | for iBin in range(self.bins+1): 67 | score_thresh = iBin / self.bins 68 | predits = np.array((preds > score_thresh).cpu()).astype('int64') 69 | predits = np.reshape (predits, (self.W,self.W)) 70 | labelss = np.array((labels).cpu()).astype('int64') 71 | labelss = np.reshape (labelss , (self.W,self.W)) 72 | 73 | image = measure.label(predits, connectivity=2) 74 | coord_image = measure.regionprops(image) 75 | label = measure.label(labelss , connectivity=2) 76 | coord_label = measure.regionprops(label) 77 | 78 | self.target[iBin] += len(coord_label) 79 | self.image_area_total = [] 80 | self.image_area_match = [] 81 | self.distance_match = [] 82 | self.dismatch = [] 83 | 84 | for K in range(len(coord_image)): 85 | area_image = np.array(coord_image[K].area) 86 | self.image_area_total.append(area_image) 87 | 88 | for i in range(len(coord_label)): 89 | centroid_label = np.array(list(coord_label[i].centroid)) 90 | for m in range(len(coord_image)): 91 | centroid_image = np.array(list(coord_image[m].centroid)) 92 | distance = np.linalg.norm(centroid_image - centroid_label) 93 | area_image = np.array(coord_image[m].area) 94 | if distance < 3: 95 | self.distance_match.append(distance) 96 | self.image_area_match.append(area_image) 97 | 98 | del coord_image[m] 99 | break 100 | 101 | self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] 102 | self.FA[iBin]+=np.sum(self.dismatch) 103 | self.PD[iBin]+=len(self.distance_match) 104 | 105 | def get(self,img_num): 106 | 107 | Final_FA = self.FA / ((self.W * self.W) * img_num) 108 | Final_PD = self.PD /self.target 109 | 110 | return Final_FA,Final_PD 111 | 112 | 113 | def reset(self): 114 | self.FA = np.zeros([self.bins+1]) 115 | self.PD = np.zeros([self.bins+1]) 116 | self.target= np.zeros([self.bins+1]) 117 | 118 | 119 | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh): 120 | 121 | predict = (torch.sigmoid(output) > score_thresh).float() 122 | if len(target.shape) == 3: 123 | target = np.expand_dims(target.float(), axis=1) 124 | elif len(target.shape) == 4: 125 | target = target.float() 126 | else: 127 | raise ValueError("Unknown target dimension") 128 | 129 | intersection = predict * ((predict == target).float()) 130 | 131 | tp = intersection.sum() 132 | fp = (predict * ((predict != target).float())).sum() 133 | tn = ((1 - predict) * ((predict == target).float())).sum() 134 | fn = (((predict != target).float()) * (1 - predict)).sum() 135 | pos = tp + fn 136 | neg = fp + tn 137 | class_pos= tp+fp 138 | 139 | return tp, pos, fp, neg, class_pos 
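
A minimal sketch of how ROCMetric and PD_FA can be driven at validation time, using toy tensors in place of real predictions and masks (the import path assumes the scripts are run from the code/ directory). ROCMetric.update applies torch.sigmoid internally via cal_tp_pos_fp_neg, so it expects raw logits, while PD_FA thresholds whatever values it receives, so the sketch feeds it probabilities.

import torch
from utils.Pmetric import ROCMetric, PD_FA

pred = torch.randn(1, 1, 32, 32)            # toy network output (logits), batch size 1
label = torch.zeros(1, 1, 32, 32)
label[..., 8:11, 8:11] = 1.0                # one small target region

roc = ROCMetric(1, 10)
pd_fa = PD_FA(1, 10)
roc.update(pred, label)                     # sigmoid + per-bin thresholding happen inside
pd_fa.update(torch.sigmoid(pred), label)    # PD_FA works on values in [0, 1]
tpr, fpr, recall, precision, f1 = roc.get()
fa, pd = pd_fa.get(1)                       # FA normalised by image area, PD by target count
print(f1[5], fa[5], pd[5])                  # index 5 is the 0.5 threshold used for reporting in train_ISDTD.py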
140 | 141 | -------------------------------------------------------------------------------- /code/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import random 5 | import numpy as np 6 | from torch import nn 7 | from inspect import isfunction 8 | import torch.nn.functional as F 9 | from collections import OrderedDict 10 | from torchvision.utils import make_grid 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 14 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 15 | std = torch.Tensor(rgb_std) 16 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 17 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 18 | for p in self.parameters(): 19 | p.requires_grad = False 20 | 21 | class ResBlock(nn.Module): 22 | def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 23 | super(ResBlock, self).__init__() 24 | m = [] 25 | for i in range(2): 26 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 27 | if bn: 28 | m.append(nn.BatchNorm2d(n_feats)) 29 | if i == 0: 30 | m.append(act) 31 | self.body = nn.Sequential(*m) 32 | self.res_scale = res_scale 33 | def forward(self, x): 34 | res = self.body(x).mul(self.res_scale) 35 | res += x 36 | return res 37 | 38 | 39 | class Uhalf_ResnetBlocWithAttn(nn.Module): 40 | def __init__(self, dim, dim_out, *, norm_groups=32, dropout=0, with_attn=False): 41 | super().__init__() 42 | self.with_attn = with_attn 43 | self.res_block = Uhalf_ResnetBlock( 44 | dim, dim_out, norm_groups=norm_groups, dropout=dropout) 45 | if with_attn: 46 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 47 | 48 | def forward(self, x): 49 | x = self.res_block(x) 50 | if(self.with_attn): 51 | x = self.attn(x) 52 | return x 53 | 54 | class SelfAttention(nn.Module): 55 | def __init__(self, in_channel, n_head=1, norm_groups=32): 56 | super().__init__() 57 | 58 | self.n_head = n_head 59 | 60 | self.norm = nn.GroupNorm(norm_groups, in_channel) 61 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) 62 | self.out = nn.Conv2d(in_channel, in_channel, 1) 63 | 64 | def forward(self, input): 65 | batch, channel, height, width = input.shape 66 | n_head = self.n_head 67 | head_dim = channel // n_head 68 | 69 | norm = self.norm(input) 70 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) 71 | query, key, value = qkv.chunk(3, dim=2) 72 | 73 | attn = torch.einsum( 74 | "bnchw, bncyx -> bnhwyx", query, key 75 | ).contiguous() / math.sqrt(channel) 76 | attn = attn.view(batch, n_head, height, width, -1) 77 | attn = torch.softmax(attn, -1) 78 | attn = attn.view(batch, n_head, height, width, height, width) 79 | 80 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 81 | out = self.out(out.view(batch, channel, height, width)) 82 | 83 | return out + input 84 | 85 | class Uhalf_ResnetBlock(nn.Module): 86 | def __init__(self, dim, dim_out, dropout=0, norm_groups=32): 87 | super().__init__() 88 | self.block1 = Block(dim, dim_out, groups=norm_groups) 89 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 90 | self.res_conv = nn.Conv2d( 91 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 92 | 93 | def forward(self, x): 94 | h = self.block1(x) 95 | h = self.block2(h) 96 | return h + self.res_conv(x) 97 | 98 | class Downsample(nn.Module): 99 | 
def __init__(self, dim): 100 | super().__init__() 101 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 102 | 103 | def forward(self, x): 104 | return self.conv(x) 105 | 106 | 107 | class PositionalEncoding(nn.Module): 108 | def __init__(self, dim): 109 | super().__init__() 110 | self.dim = dim 111 | 112 | def forward(self, noise_level): 113 | count = self.dim // 2 114 | step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count 115 | encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 116 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) 117 | return encoding 118 | 119 | class Swish(nn.Module): 120 | def forward(self, x): 121 | return x * torch.sigmoid(x) 122 | 123 | class ResnetBlocWithAttn(nn.Module): 124 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 125 | super().__init__() 126 | self.with_attn = with_attn 127 | self.res_block = ResnetBlock( 128 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 129 | if with_attn: 130 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 131 | 132 | def forward(self, x, time_emb): 133 | x = self.res_block(x, time_emb) 134 | if(self.with_attn): 135 | x = self.attn(x) 136 | return x 137 | 138 | class ResnetBlock(nn.Module): 139 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): 140 | super().__init__() 141 | self.noise_func = FeatureWiseAffine( 142 | noise_level_emb_dim, dim_out, use_affine_level) 143 | self.block1 = Block(dim, dim_out, groups=norm_groups) 144 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 145 | self.res_conv = nn.Conv2d( 146 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 147 | 148 | def forward(self, x, time_emb): 149 | b, c, h, w = x.shape 150 | h = self.block1(x) 151 | h = self.noise_func(h, time_emb) 152 | h = self.block2(h) 153 | return h + self.res_conv(x) 154 | 155 | class FeatureWiseAffine(nn.Module): 156 | def __init__(self, in_channels, out_channels, use_affine_level=False): 157 | super(FeatureWiseAffine, self).__init__() 158 | self.use_affine_level = use_affine_level 159 | self.noise_func = nn.Sequential( 160 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) 161 | ) 162 | 163 | def forward(self, x, noise_embed): 164 | batch = x.shape[0] 165 | if self.use_affine_level: 166 | gamma, beta = self.noise_func(noise_embed).view( 167 | batch, -1, 1, 1).chunk(2, dim=1) 168 | x = (1 + gamma) * x + beta 169 | else: 170 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 171 | return x 172 | 173 | class Block(nn.Module): 174 | def __init__(self, dim, dim_out, groups=32, dropout=0): 175 | super().__init__() 176 | self.block = nn.Sequential( 177 | nn.GroupNorm(groups, dim), 178 | Swish(), 179 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 180 | nn.Conv2d(dim, dim_out, 3, padding=1) 181 | ) 182 | def forward(self, x): 183 | return self.block(x) 184 | 185 | def default(val, d): 186 | if exists(val): 187 | return val 188 | return d() if isfunction(d) else d 189 | 190 | def exists(x): 191 | return x is not None 192 | 193 | def make_coord(shape, ranges=None, flatten=True): 194 | coord_seqs = [] 195 | for i, n in enumerate(shape): 196 | if ranges is None: 197 | v0, v1 = -1, 1 198 | else: 199 | v0, v1 = ranges[i] 200 | r = (v1 - v0) / (2 * n) 201 | seq = v0 + r + (2 * r) * torch.arange(n).float() 202 | coord_seqs.append(seq) 203 | ret = 
torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 204 | if flatten: 205 | ret = ret.view(-1, ret.shape[-1]) 206 | return ret 207 | 208 | class SpatialEncoding(nn.Module): 209 | def __init__(self, 210 | in_dim, 211 | out_dim, 212 | sigma=6, 213 | cat_input=True, 214 | require_grad=False, ): 215 | 216 | super().__init__() 217 | assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" 218 | 219 | n = out_dim // 2 // in_dim 220 | m = 2 ** np.linspace(0, sigma, n) 221 | m = np.stack([m] + [np.zeros_like(m)] * (in_dim - 1), axis=-1) 222 | m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0) 223 | self.emb = torch.FloatTensor(m) 224 | if require_grad: 225 | self.emb = nn.Parameter(self.emb, requires_grad=True) 226 | self.in_dim = in_dim 227 | self.out_dim = out_dim 228 | self.sigma = sigma 229 | self.cat_input = cat_input 230 | self.require_grad = require_grad 231 | 232 | def forward(self, x): 233 | 234 | if not self.require_grad: 235 | self.emb = self.emb.to(x.device) 236 | y = torch.matmul(x, self.emb.T) 237 | if self.cat_input: 238 | return torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1) 239 | else: 240 | return torch.cat([torch.sin(y), torch.cos(y)], dim=-1) 241 | 242 | 243 | def gen_feat(res, size, stride=1, local=False): 244 | bs, hh, ww = res.shape[0], res.shape[-2], res.shape[-1] 245 | h, w = size 246 | coords = (make_coord((h, w)).cuda().flip(-1) + 1) / 2 247 | coords = coords.unsqueeze(0).expand(bs, *coords.shape) 248 | coords = (coords * 2 - 1).flip(-1) 249 | feat_coords = make_coord((hh, ww), flatten=False).cuda().permute(2, 0, 1).unsqueeze(0).expand(res.shape[0], 2, *(hh, ww)) 250 | if local: 251 | vx_list = [-1, 1] 252 | vy_list = [-1, 1] 253 | eps_shift = 1e-6 254 | rel_coord_list = [] 255 | q_feat_list = [] 256 | area_list = [] 257 | else: 258 | vx_list, vy_list, eps_shift = [0], [0], 0 259 | rx = stride / h 260 | ry = stride / w 261 | 262 | for vx in vx_list: 263 | for vy in vy_list: 264 | coords_ = coords.clone() 265 | coords_[:, :, 0] += vx * rx + eps_shift 266 | coords_[:, :, 1] += vy * ry + eps_shift 267 | coords_.clamp_(-1 + 1e-6, 1 - 1e-6) 268 | q_feat = F.grid_sample(res, coords_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, :, 0, 269 | :].permute(0, 2, 1) 270 | q_coord = F.grid_sample(feat_coords, coords_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)[:, 271 | :, 0, :].permute(0, 2, 1) 272 | rel_coord = coords - q_coord 273 | rel_coord[:, :, 0] *= hh 274 | rel_coord[:, :, 1] *= ww 275 | if local: 276 | rel_coord_list.append(rel_coord) 277 | q_feat_list.append(q_feat) 278 | area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1]) 279 | area_list.append(area + 1e-9) 280 | if not local: 281 | return rel_coord, q_feat 282 | else: 283 | return rel_coord_list, q_feat_list, area_list 284 | 285 | 286 | def set_device(x,device): 287 | if isinstance(x, dict): 288 | for key, item in x.items(): 289 | if item is not None: 290 | x[key] = item.to(device) 291 | elif isinstance(x, list): 292 | for item in x: 293 | if item is not None: 294 | item = item.to(device) 295 | else: 296 | x = x.to(device) 297 | return x 298 | 299 | 300 | def data_process(data,mode): 301 | if mode == 'train': 302 | p = random.random() 303 | elif mode == 'val': 304 | p = 1 305 | img_lr, img_hr = data['lr'], data['hr'] 306 | w_hr = round(img_lr.shape[-1] + (img_hr.shape[-1] - img_lr.shape[-1]) * p) 307 | img_hr = resize_fn(img_hr, w_hr) 308 | hr_coord, _ = to_pixel_samples(img_hr) 309 | cell = torch.ones_like(hr_coord) 310 | cell[:, 0] *= 2 / 
img_hr.shape[-2] 311 | cell[:, 1] *= 2 / img_hr.shape[-1] 312 | hr_coord = hr_coord.repeat(img_hr.shape[0], 1, 1) 313 | cell = cell.repeat(img_hr.shape[0], 1, 1) 314 | data = {'lr': img_lr, 'hrcoord': hr_coord, 'cell': cell, 'hr': img_hr, 'scaler': torch.from_numpy(np.array([p], dtype=np.float32))} 315 | return data 316 | 317 | def resize_fn(img, size): 318 | return F.interpolate(img, size=size, mode='bicubic', align_corners=False) 319 | 320 | def to_pixel_samples(img): 321 | coord = make_coord(img.shape[-2:]) 322 | rgb = img.view(3, -1).permute(1, 0) 323 | return coord, rgb 324 | 325 | 326 | 327 | def get_current_visuals(SR_img, data): 328 | out_dict = OrderedDict() 329 | out_dict['SR'] = SR_img.detach().float().cpu() 330 | out_dict['INF'] = data['lr'].detach().float().cpu() 331 | out_dict['HR'] = data['hr'].detach().float().cpu() 332 | out_dict['LR'] = data['lr'].detach().float().cpu() 333 | return out_dict 334 | 335 | 336 | def save_network(checkpoints, epoch, iter_step, net, opt, best=None): 337 | if best is not None: 338 | gen_path = os.path.join(checkpoints, 'best_{}_gen.pth'.format(best)) 339 | opt_path = os.path.join(checkpoints, 'best_{}_opt.pth'.format(best)) 340 | else: 341 | gen_path = os.path.join(checkpoints, 'latest_gen.pth'.format(iter_step, epoch)) 342 | opt_path = os.path.join(checkpoints, 'latest_opt.pth'.format(iter_step, epoch)) 343 | network = net 344 | state_dict = network.state_dict() 345 | for key, param in state_dict.items(): 346 | state_dict[key] = param.cpu() 347 | torch.save(state_dict, gen_path) 348 | # opt 349 | opt_state = {'epoch': epoch, 'iter': iter_step, 'scheduler': None, 'optimizer': None} 350 | opt_state['optimizer'] = opt.state_dict() 351 | torch.save(opt_state, opt_path) 352 | 353 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 354 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) 355 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) 356 | n_dim = tensor.dim() 357 | if n_dim == 4: 358 | n_img = len(tensor) 359 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 360 | img_np = np.transpose(img_np, (1, 2, 0)) 361 | elif n_dim == 3: 362 | img_np = tensor.numpy() 363 | img_np = np.transpose(img_np, (1, 2, 0)) 364 | elif n_dim == 2: 365 | img_np = tensor.numpy() 366 | else: 367 | raise TypeError( 368 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 369 | if out_type == np.uint8: 370 | img_np = (img_np * 255.0).round() 371 | return img_np.astype(out_type) 372 | 373 | def calculate_psnr(img1, img2): 374 | img1 = img1.astype(np.float64) 375 | img2 = img2.astype(np.float64) 376 | mse = np.mean((img1 - img2)**2) 377 | if mse == 0: 378 | return 0 379 | return 20 * math.log10(255.0 / math.sqrt(mse)) 380 | 381 | def compute_alpha(beta, t): 382 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 383 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 384 | return a 385 | 386 | def test_data_process(data): 387 | sub, div = torch.FloatTensor([0.5]).view(1, -1, 1, 1), torch.FloatTensor([0.5]).view(1, -1, 1, 1) 388 | data['img'] = (data['img'] -sub) / div 389 | img, label = data['img'], data['label'] 390 | name, width, height = data['name'], data['width'], data['height'] 391 | data = {'img': img, 'label': label, 'name': name, 'width': width, 'height': height} 392 | return data 393 | --------------------------------------------------------------------------------
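
For reference, the validation PSNR in train_DCHFR.py is computed on uint8 images after tensor2img maps network outputs from the (-1, 1) range back to [0, 255]; a minimal sketch with toy tensors (the import path assumes the scripts are run from the code/ directory):

import torch
from utils.utils import tensor2img, calculate_psnr

sr = torch.rand(1, 3, 64, 64) * 2 - 1   # toy super-resolved output in (-1, 1)
hr = torch.rand(1, 3, 64, 64) * 2 - 1   # toy high-resolution reference in (-1, 1)
sr_img = tensor2img(sr)                 # HxWx3 uint8 image
hr_img = tensor2img(hr)
print(calculate_psnr(sr_img, hr_img))   # PSNR in dB; note the function returns 0 when the images are identical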