├── codes
│   ├── run.sh
│   ├── plot.py
│   ├── utils.py
│   ├── unet3d.py
│   ├── diffusion.py
│   ├── nn.py
│   ├── diffusion_plot.py
│   ├── train_unet.py
│   ├── fp16_util.py
│   ├── feature_extract_unet.py
│   ├── data.py
│   └── logger.py
├── framework.png
├── .gitignore
└── README.md

/codes/run.sh:
--------------------------------------------------------------------------------
1 | nohup python train_unet.py &> log.txt &
2 | 
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenning0115/spectraldiff_diffusion/HEAD/framework.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | jupyter
2 | codes/__pycache__
3 | codes/.ipynb_checkpoints/
4 | codes/save_model
5 | codes/save_feature
--------------------------------------------------------------------------------
/codes/plot.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | from torchvision import transforms
4 | import matplotlib.pyplot as plt
5 | 
6 | 
7 | def show_spectral_curve(image, X, Y, total=3):
8 |     if X is None or Y is None:
9 |         return
10 |     if type(image) != np.ndarray:
11 |         image = image.numpy()
12 |     if len(image.shape) == 4:
13 |         image = image[0, :, :, :]
14 |     if len(image.shape) == 5:
15 |         image = image[0, 0, :, :, :]
16 | 
17 |     # image shape (spectral, w, h)
18 |     X = X[0]
19 |     Y = Y[0]
20 |     num = 0
21 |     w,h = Y.shape
22 |     for i in range(w):
23 |         for j in range(h):
24 |             if num >= total:  # stop once `total` labeled pixels have been plotted
25 |                 break
26 |             if Y[i,j] > 0:
27 |                 ss = list(image[:,i,j])
28 |                 real_ss = list(X[:,i,j])
29 |                 ii = list(range(len(ss)))
30 |                 plt.plot(ii, ss, label='pred')
31 |                 plt.plot(ii, real_ss, label='real')
32 |                 num += 1
33 |         if num >= total:  # break out of the outer loop as well
34 |             break
35 |     plt.legend()
36 | 
37 | 
38 | 
39 | def show_tensor_image(image, rgb=(0,1,2)):
40 |     if type(image) != np.ndarray:
41 |         image = image.numpy()
42 |     r,g,b = rgb
43 |     if len(image.shape) == 4:
44 |         image = image[0, :, :, :]
45 |     if len(image.shape) == 5:
46 |         image = image[0, 0, :, :, :]
47 |     if image.shape[0] > 3:  # more than 3 bands: pick the three chosen bands as RGB
48 |         rimg = image[r,:,:]
49 |         gimg = image[g,:,:]
50 |         bimg = image[b,:,:]
51 |         image = np.stack([rimg,gimg,bimg])
52 | 
53 |     def trans(x):
54 |         if type(x) == np.ndarray:
55 |             return np.transpose(x, (1,2,0))
56 |         else:
57 |             return x.permute(1,2,0)
58 | 
59 |     def totype(x):
60 |         if type(x) == np.ndarray:
61 |             return x.astype(np.uint8)
62 |         else:
63 |             return x.numpy().astype(np.uint8)
64 | 
65 |     reverse_transforms = transforms.Compose([
66 |         # transforms.Lambda(lambda t: (t + 1) / 2),
67 |         transforms.Lambda(lambda t: trans(t)), # CHW to HWC
68 |         transforms.Lambda(lambda t: t * 255.),
69 |         transforms.Lambda(lambda t: totype(t)),
70 |         transforms.ToPILImage(),
71 |     ])
72 |     # render the band-selected patch as a PIL image
73 |     plt.imshow(reverse_transforms(image))
74 | 
75 | 
76 | 
--------------------------------------------------------------------------------
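For orientation, a minimal usage sketch of `show_tensor_image` on a synthetic cube; the band indices here are arbitrary placeholders:

```python
import numpy as np
import matplotlib.pyplot as plt
from plot import show_tensor_image

# synthetic (batch, channel, spectral, h, w) patch with values in [0, 1]
cube = np.random.rand(1, 1, 104, 64, 64).astype('float32')
show_tensor_image(cube, rgb=(10, 50, 90))  # the three chosen bands are rendered as RGB
plt.show()
```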
/codes/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import json, time
3 | import numpy as np
4 | from torchvision import transforms
5 | import matplotlib.pyplot as plt
6 | import torch
7 | device = "cuda:6" if torch.cuda.is_available() else "cpu"  # pick the GPU index to match your machine
8 | 
9 | def show_img(x):
10 |     def trans(x):
11 |         if type(x) == np.ndarray:
12 |             return np.transpose(x, (1,2,0))
13 |         else:
14 |             return x.permute(1,2,0)
15 | 
16 |     def totype(x):
17 |         if type(x) == np.ndarray:
18 |             return x.astype(np.uint8)
19 |         else:
20 |             return x.numpy().astype(np.uint8)
21 | 
22 |     reverse_transforms = transforms.Compose([
23 |         transforms.Lambda(lambda t: (t + 1) / 2),
24 |         transforms.Lambda(lambda t: trans(t)), # CHW to HWC
25 |         transforms.Lambda(lambda t: t * 255.),
26 |         transforms.Lambda(lambda t: totype(t)),
27 |         transforms.ToPILImage(),
28 |     ])
29 |     plt.imshow(reverse_transforms(x))
30 | 
31 | class AvgrageMeter(object):
32 |     def __init__(self):
33 |         self.reset()
34 | 
35 |     def reset(self):
36 |         self.avg = 0
37 |         self.sum = 0
38 |         self.cnt = 0
39 | 
40 |     def update(self, val, n=1):
41 |         self.sum += val * n
42 |         self.cnt += n
43 |         self.avg = self.sum / self.cnt
44 | 
45 |     def get_avg(self):
46 |         return self.avg
47 | 
48 | class Recoder(object):
49 |     def __init__(self) -> None:
50 |         self.record_data = {}
51 | 
52 |     def append_index_value(self, name, index, value):
53 |         """
54 |         index : int,
55 |         value: Any
56 |         save to dict
57 |         {index: list, value: list}
58 |         """
59 |         if name not in self.record_data:
60 |             self.record_data[name] = {
61 |                 "type": "index_value",
62 |                 "index":[],
63 |                 "value":[]
64 |             }
65 |         self.record_data[name]['index'].append(index)
66 |         self.record_data[name]['value'].append(value)
67 | 
68 |     def record_param(self, param):
69 |         self.record_data['param'] = param
70 | 
71 |     def record_eval(self, eval_obj):
72 |         self.record_data['eval'] = eval_obj
73 | 
74 |     def to_file(self, path):
75 |         time_stamp = int(time.time())
76 |         save_path = "%s_%s.json" % (path, str(time_stamp))
77 |         ss = json.dumps(self.record_data, indent=4)
78 |         with open(save_path, 'w') as fout:
79 |             fout.write(ss)
80 |             fout.flush()
81 | 
82 |     def reset(self):
83 |         self.record_data = {}
84 | 
85 | 
86 | # global recorder
87 | recorder = Recoder()
88 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models
2 | 
3 | [Ning Chen](), [Jun Yue](), [Leyuan Fang](), [Shaobo Xia]()
4 | ___________
5 | 
6 | The code in this toolbox implements ["SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models"]().
7 | 
8 | **The code for this research consists of two parts, the [spectral-spatial diffusion module](https://github.com/chenning0115/spectraldiff_diffusion/) and the [attention-based classification module](https://github.com/chenning0115/SpectralDiff#spectraldiff). This repository is for the spectral-spatial diffusion module.**
9 | 
10 | More specifically, it is detailed as follows.
11 | 
12 | ![alt text](./framework.png)
13 | 
14 | Citation
15 | ---------------------
16 | 
17 | **Please kindly cite the papers if this code is useful and helpful for your research.**
18 | 
19 | ```
20 | N. Chen, J. Yue, L. Fang and S. Xia, "SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2023.3310023.
21 | 
22 | ```
23 | 
24 | ```
25 | @ARTICLE{10234379,
26 |   author={Chen, Ning and Yue, Jun and Fang, Leyuan and Xia, Shaobo},
27 |   journal={IEEE Transactions on Geoscience and Remote Sensing}, 
28 |   title={SpectralDiff: A Generative Framework for Hyperspectral Image Classification with Diffusion Models}, 
29 |   year={2023},
30 |   volume={},
31 |   number={},
32 |   pages={1-1},
33 |   doi={10.1109/TGRS.2023.3310023}}
34 | 
35 | ```
36 | 
37 | How to use it?
38 | ---------------------
39 | 1. Prepare the data. You can get the datasets from [here](https://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes).
40 | 2. Modify the configuration for the corresponding dataset in the train_unet.py file; a loading sketch follows the code block below.
41 | ```
42 | # for PU
43 | sign = 'PU'
44 | batch_size = 20
45 | patch_size = 64
46 | select_spectral = []
47 | spe = 104
48 | channel = 1 #3d channel
49 | 
50 | # for IP
51 | # sign = 'IP'
52 | # batch_size = 20
53 | # patch_size = 64
54 | # select_spectral = []
55 | # spe = 200
56 | # channel = 1 #3d channel
57 | 
58 | # for SA
59 | # sign = 'SA'
60 | # batch_size = 20
61 | # patch_size = 64
62 | # select_spectral = []
63 | # spe = 104
64 | # channel = 1 #3d channel
65 | ```
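These globals are consumed by the data loader; as orientation, a minimal sketch mirroring the call made in train() of train_unet.py (the PU settings above correspond to data_sign "Pavia"):

```python
from data import HSIDataLoader

dataloader = HSIDataLoader(
    {"data": {"data_sign": "Pavia", "padding": False, "batch_size": 20,
              "patch_size": 64, "select_spectral": []}})
train_loader, X, Y = dataloader.generate_torch_dataset(light_split=True)
```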
66 | 3. Run the code to train the diffusion model; note that the number of training epochs should be more than 30000.
67 | ```
68 | python train_unet.py
69 | ```
70 | 
71 | 4. Modify the configuration in the feature_extract_unet.py file and run the code to extract diffusion features with the trained diffusion model.
72 | 
73 | ```
74 | python feature_extract_unet.py
75 | ```
76 | 
77 | Others
78 | ----------------------
79 | If you want to run the code on your own data, you can change the input (e.g., data, labels) accordingly and tune the parameters.
80 | 
81 | If you encounter bugs while using this code, please do not hesitate to contact us.
82 | 
83 | Licensing
84 | ---------
85 | 
86 | Copyright (C) 2023 Ning Chen
87 | 
88 | This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3 of the License.
89 | 
90 | This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
91 | 
92 | You should have received a copy of the GNU General Public License along with this program.
93 | 
--------------------------------------------------------------------------------
/codes/unet3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import matplotlib.pyplot as plt
4 | from torch import nn
5 | import math
6 | import numpy as np
7 | 
8 | class Block3d(nn.Module):
9 |     def __init__(self, in_ch, out_ch, time_emb_dim, kernel=(4,3,3), stride=(2,1,1), padding=(1,1,1), up=False):
10 |         super().__init__()
11 |         self.time_mlp = nn.Linear(time_emb_dim, out_ch)
12 |         if up:
13 |             self.conv1 = nn.Conv3d(2*in_ch, out_ch, 3, padding=1)
14 |             self.transform = nn.ConvTranspose3d(out_ch, out_ch, kernel, stride, padding)
15 |         else:
16 |             self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
17 |             self.transform = nn.Conv3d(out_ch, out_ch, kernel, stride, padding)
18 |         self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
19 |         self.bnorm1 = nn.BatchNorm3d(out_ch)
20 |         self.bnorm2 = nn.BatchNorm3d(out_ch)
21 |         self.relu = nn.ReLU()
22 | 
23 | 
24 |     def forward(self, x, t):
25 |         # First Conv
26 |         h = self.bnorm1(self.relu(self.conv1(x)))
27 |         # Time embedding
28 |         time_emb = self.relu(self.time_mlp(t))
29 |         # Extend last 3 dimensions
30 |         time_emb = time_emb[(..., ) + (None, ) * 3]
31 |         # Add time channel
32 |         h = h + time_emb
33 |         # Second Conv
34 |         h = self.bnorm2(self.relu(self.conv2(h)))
35 |         # Down or Upsample
36 |         return self.transform(h)
37 | 
38 | 
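The indexing trick on the time embedding above appends three singleton dimensions so the (batch, out_ch) vector broadcasts over a (batch, out_ch, D, H, W) feature map; a quick standalone check:

```python
import torch

t = torch.randn(2, 8)                # (batch, out_ch) time embedding
t5 = t[(..., ) + (None, ) * 3]       # -> (2, 8, 1, 1, 1)
assert t5.shape == (2, 8, 1, 1, 1)
h = torch.randn(2, 8, 4, 16, 16)     # (batch, out_ch, depth, height, width)
assert (h + t5).shape == h.shape     # broadcasts over the three trailing dims
```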
39 | class SinusoidalPositionEmbeddings(nn.Module):
40 |     def __init__(self, dim):
41 |         super().__init__()
42 |         self.dim = dim
43 | 
44 |     def forward(self, time):
45 |         device = time.device
46 |         half_dim = self.dim // 2
47 |         embeddings = math.log(10000) / (half_dim - 1)
48 |         embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
49 |         embeddings = time[:, None] * embeddings[None, :]
50 |         embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
51 |         # TODO: Double check the ordering here
52 |         return embeddings
53 | 
54 | 
55 | class SimpleUnet(nn.Module):
56 |     """
57 |     A simplified variant of the Unet architecture.
58 |     """
59 |     def __init__(self, _image_channels):
60 |         super().__init__()
61 |         image_channels = _image_channels
62 |         down_channels = (16,32,64,128)
63 |         down_params = [
64 |             [(4,5,5),(2,1,1),(1,2,2)],
65 |             [(4,5,5),(2,1,1),(1,2,2)],
66 |             [(4,5,5),(2,1,1),(1,2,2)],
67 |         ]
68 |         up_channels = (128,64,32,16)
69 |         up_params = [
70 |             [(4,5,5),(2,1,1),(1,2,2)],
71 |             [(4,5,5),(2,1,1),(1,2,2)],
72 |             [(4,5,5),(2,1,1),(1,2,2)],
73 |         ]
74 |         out_dim = 1
75 |         time_emb_dim = 32
76 |         self.features = []
77 | 
78 |         # Time embedding
79 |         self.time_mlp = nn.Sequential(
80 |             SinusoidalPositionEmbeddings(time_emb_dim),
81 |             nn.Linear(time_emb_dim, time_emb_dim),
82 |             nn.ReLU()
83 |         )
84 | 
85 |         # Initial projection
86 |         self.conv0 = nn.Conv3d(image_channels, down_channels[0], 3, padding=1)
87 | 
88 |         # Downsample
89 |         self.downs = nn.ModuleList([Block3d(down_channels[i], down_channels[i+1], time_emb_dim, \
90 |                                     down_params[i][0], down_params[i][1], down_params[i][2]) \
91 |                     for i in range(len(down_channels)-1)])
92 |         # Upsample
93 |         self.ups = nn.ModuleList([Block3d(up_channels[i], up_channels[i+1], time_emb_dim, \
94 |                                     up_params[i][0], up_params[i][1], up_params[i][2], up=True) \
95 |                     for i in range(len(up_channels)-1)])
96 | 
97 |         self.output = nn.Conv3d(up_channels[-1], image_channels, out_dim)
98 | 
99 |     def forward(self, x, timestep, feature=False):
100 |         # Embed time
101 |         t = self.time_mlp(timestep)
102 |         # Initial conv
103 |         x = self.conv0(x)
104 |         # Unet
105 |         residual_inputs = []
106 |         for down in self.downs:
107 |             x = down(x, t)
108 |             residual_inputs.append(x)
109 |         for up in self.ups:
110 |             residual_x = residual_inputs.pop()
111 |             # print("down=",residual_x.shape, "up=", x.shape)
112 |             # Add residual x as additional channels
113 |             x = torch.cat((x, residual_x), dim=1)
114 |             if feature:
115 |                 self.features.append(x.detach().cpu().numpy())
116 |             x = up(x, t)
117 |         return self.output(x)
118 | 
119 |     def return_features(self):
120 |         temp_features = []
121 |         temp_features = self.features[:]
122 |         self.features = []
123 |         return temp_features
124 | 
125 | 
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     model = SimpleUnet(_image_channels=1)
130 |     t = torch.full((1,), 100, dtype=torch.long)
131 |     a = torch.randn((100,1,104,16,16))
132 |     print(model(a, t))
133 | 
--------------------------------------------------------------------------------
/codes/diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import matplotlib.pyplot as plt
4 | from torch.optim import Adam
5 | import torch.nn.functional as F
6 | from data import HSIDataLoader
7 | import numpy as np
8 | from plot import show_tensor_image
9 | from utils import device
10 | 
11 | 
12 | class Diffusion(object):
13 |     def __init__(self, T=1000) -> None:
14 |         self.T = T
15 |         self.betas = self._linear_beta_schedule(timesteps=self.T)
16 |         # Pre-calculate different terms for closed form
17 |         self.alphas = 1. - self.betas
18 |         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
19 |         self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
20 |         self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
21 |         self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
22 |         self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
23 |         self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
24 | 
25 | 
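The tables precomputed above implement the closed-form forward process q(x_t | x_0) = N(sqrt(a-bar_t) x_0, (1 - a-bar_t) I), which forward_diffusion_sample below evaluates; a self-contained sketch with toy shapes:

```python
import torch

T = 500
betas = torch.linspace(0.0001, 0.02, T)
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

x0 = torch.randn(2, 1, 104, 16, 16)   # toy (batch, channel, spectral, h, w) patches
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)
shape = (-1, 1, 1, 1, 1)              # broadcast the per-sample scalars over the cube
xt = alphas_cumprod[t].sqrt().view(shape) * x0 \
     + (1. - alphas_cumprod[t]).sqrt().view(shape) * noise
print(xt.shape)                       # torch.Size([2, 1, 104, 16, 16])
```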
26 |     def _linear_beta_schedule(self, timesteps, start=0.0001, end=0.02):
27 |         return torch.linspace(start, end, timesteps)
28 | 
29 |     def _get_index_from_list(self, vals, t, x_shape):
30 |         """
31 |         Returns a specific index t of a passed list of values vals
32 |         while considering the batch dimension.
33 |         """
34 |         batch_size = t.shape[0]
35 |         out = vals.gather(-1, t.cpu())
36 |         return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
37 | 
38 |     def forward_diffusion_sample(self, x_0, t, device="cpu"):
39 |         """
40 |         Takes an image and a timestep as input and
41 |         returns the noisy version of it
42 |         """
43 |         noise = torch.randn_like(x_0)
44 |         sqrt_alphas_cumprod_t = self._get_index_from_list(self.sqrt_alphas_cumprod, t, x_0.shape)
45 |         sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
46 |             self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
47 |         )
48 |         # mean + variance
49 |         return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
50 |             + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
51 | 
52 | 
53 |     def get_loss(self, model, x_0, t):
54 |         x_noisy, noise = self.forward_diffusion_sample(x_0, t, device)
55 |         noise_pred = model(x_noisy, t)
56 |         return F.l1_loss(noise, noise_pred), x_noisy, noise, noise_pred
57 | 
58 | 
59 |     @torch.no_grad()
60 |     def sample_timestep(self, x, t, model):
61 |         """
62 |         Calls the model to predict the noise in the image and returns
63 |         the denoised image.
64 |         Applies noise to this image, if we are not in the last step yet.
65 | 
66 |         x is x_t, t is the timestep
67 |         return x_{t-1}
68 |         """
69 |         betas_t = self._get_index_from_list(self.betas, t, x.shape)
70 |         sqrt_one_minus_alphas_cumprod_t = self._get_index_from_list(
71 |             self.sqrt_one_minus_alphas_cumprod, t, x.shape
72 |         )
73 |         sqrt_recip_alphas_t = self._get_index_from_list(self.sqrt_recip_alphas, t, x.shape)
74 | 
75 |         # Call model (current image - noise prediction)
76 |         model_mean = sqrt_recip_alphas_t * (
77 |             x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
78 |         )
79 |         posterior_variance_t = self._get_index_from_list(self.posterior_variance, t, x.shape)
80 | 
81 |         if t == 0:
82 |             return model_mean
83 |         else:
84 |             noise = torch.randn_like(x)
85 |             return model_mean + torch.sqrt(posterior_variance_t) * noise
86 | 
87 |     @torch.no_grad()
88 |     def reconstruct(self, model, xt=None, tempT=None, num = 5, from_noise=False, shape=None):
89 |         '''
90 |         Progressively recover the signal, either from pure noise or from a given xt.
91 |         If xt is not given, random noise is used instead.
92 |         When xt is given, tempT must also be given to indicate how many noising steps produced xt.
93 |         '''
94 |         if tempT is None:  # default must be set before stepsize is computed
95 |             tempT = self.T
96 |         # works for both a python int and a one-element tensor tempT
97 |         stepsize = max(int(tempT) // num, 1)
98 |         index = []
99 |         res = []
100 |         # Sample noise
101 |         if from_noise:
102 |             img = torch.randn(shape, device=device)
103 |         else:
104 |             img = xt
105 | 
106 |         for i in range(0, int(tempT))[::-1]:
107 |             t = torch.full((1,), i, device=device, dtype=torch.long)
108 |             img = self.sample_timestep(img, t, model)
109 |             if i % stepsize == 0:
110 |                 index.append(i)
111 |                 res.append(img.detach().cpu())
112 |         index.append(i)
113 |         res.append(img.detach().cpu())
114 |         return index, res
115 | 
116 |     @torch.no_grad()
117 |     def reconstruct_v2(self, model, xt=None, tempT=None, use_index=[], from_noise=False, shape=None):
118 |         '''
119 |         Progressively recover the signal, either from pure noise or from a given xt.
120 |         If xt is not given, random noise is used instead.
121 |         When xt is given, tempT must also be given to indicate how many noising steps produced xt.
122 |         '''
123 |         index = []
124 |         res = []
125 |         # Sample noise
126 |         if from_noise:
127 |             img = torch.randn(shape, device=device)
128 |         else:
129 |             img = xt
130 | 
131 |         if tempT is None:
132 |             tempT = self.T
133 
| 134 | for i in range(0, tempT)[::-1]: 135 | t = torch.full((1,), i, device=device, dtype=torch.long) 136 | img = self.sample_timestep(img, t, model) 137 | if i in use_index: 138 | index.append(i) 139 | res.append(img.detach().cpu()) 140 | index.append(i) 141 | res.append(img.detach().cpu()) 142 | return index, res 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /codes/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels, num_groups=32): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(num_groups, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 
112 |     """
113 |     half = dim // 2
114 |     freqs = th.exp(
115 |         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 |     ).to(device=timesteps.device)
117 |     args = timesteps[:, None].float() * freqs[None]
118 |     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 |     if dim % 2:
120 |         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 |     return embedding
122 | 
123 | 
124 | def checkpoint(func, inputs, params, flag):
125 |     """
126 |     Evaluate a function without caching intermediate activations, allowing for
127 |     reduced memory at the expense of extra compute in the backward pass.
128 | 
129 |     :param func: the function to evaluate.
130 |     :param inputs: the argument sequence to pass to `func`.
131 |     :param params: a sequence of parameters `func` depends on but does not
132 |         explicitly take as arguments.
133 |     :param flag: if False, disable gradient checkpointing.
134 |     """
135 |     if flag:
136 |         args = tuple(inputs) + tuple(params)
137 |         return CheckpointFunction.apply(func, len(inputs), *args)
138 |     else:
139 |         return func(*inputs)
140 | 
141 | 
142 | class CheckpointFunction(th.autograd.Function):
143 |     @staticmethod
144 |     def forward(ctx, run_function, length, *args):
145 |         ctx.run_function = run_function
146 |         ctx.input_tensors = list(args[:length])
147 |         ctx.input_params = list(args[length:])
148 |         with th.no_grad():
149 |             output_tensors = ctx.run_function(*ctx.input_tensors)
150 |         return output_tensors
151 | 
152 |     @staticmethod
153 |     def backward(ctx, *output_grads):
154 |         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 |         with th.enable_grad():
156 |             # Fixes a bug where the first op in run_function modifies the
157 |             # Tensor storage in place, which is not allowed for detach()'d
158 |             # Tensors.
159 |             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 |             output_tensors = ctx.run_function(*shallow_copies)
161 |         input_grads = th.autograd.grad(
162 |             output_tensors,
163 |             ctx.input_tensors + ctx.input_params,
164 |             output_grads,
165 |             allow_unused=True,
166 |         )
167 |         del ctx.input_tensors
168 |         del ctx.input_params
169 |         del output_tensors
170 |         return (None, None) + input_grads
171 | 
--------------------------------------------------------------------------------
/codes/diffusion_plot.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | #os.environ["CUDA_VISIBLE_DEVICES"]="1"
3 | import torch
4 | import torchvision
5 | from torchvision import transforms
6 | from torch import nn
7 | from torch.optim import Adam
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import math
11 | import matplotlib.pyplot as plt
12 | 
13 | from data import HSIDataLoader, TestDS, TrainDS  # the repo ships data.py; there is no new_data module
14 | from unet3d import SimpleUnet
15 | # from spectral_transformer import SpectralTransNet  # not included in this repository
16 | # from unet import UNetModel  # not included in this repository
17 | from diffusion import Diffusion
18 | from utils import AvgrageMeter, recorder, show_img
19 | 
20 | batch_size = 128
21 | patch_size = 16
22 | select_spectral = []
23 | spe = 144
24 | channel = 1 #3d channel
25 | 
26 | epochs = 100000 # Try more! 
27 | lr = 1e-4 28 | T=500 29 | 30 | rgb = [0,99,199] 31 | model_load_path = "../../data/save_model/unet3d_patch16_without_downsample_kernal5_fix" 32 | model_name = "unet3d_27000.pkl" 33 | 34 | TList = [5, 10, 100, 200, 400] 35 | 36 | 37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | 39 | def plot_by_imgs(imgs, rgb=[1,100,199]): 40 | assert len(imgs) > 0 41 | batch, c, s, h, w = imgs[0].shape 42 | for i in range(batch): 43 | plt.figure(figsize=(12,8)) 44 | for j in range(len(imgs)): 45 | plt.subplot(1,len(imgs),j+1) 46 | img = imgs[j][i,0,rgb,:,:] 47 | show_img(img) 48 | plt.show() 49 | 50 | def plot_by_images_v2(imgs, rgb=[1,100,199]): 51 | ''' 52 | input image shape is (spectral, height, width) 53 | ''' 54 | assert len(imgs) > 0 55 | s,h,w = imgs[0].shape 56 | plt.figure(figsize=(12,8)) 57 | for j in range(len(imgs)): 58 | plt.subplot(1,len(imgs),j+1) 59 | img = imgs[j][rgb,:,:] 60 | show_img(img) 61 | plt.show() 62 | 63 | def plot_spectral(x0, recon_x0, num=3): 64 | ''' 65 | x0, recon_x0 shape is (batch, channel, spectral, h, w) 66 | ''' 67 | batch, c, s, h ,w = x0.shape 68 | step = h // num 69 | plt.figure(figsize=(20,5)) 70 | for ii in range(num): 71 | i = ii * step 72 | x0_spectral = x0[0,0,:,i,i] 73 | recon_x0_spectral = recon_x0[0,0,:,i,i] 74 | plt.subplot(1,num,ii+1) 75 | plt.plot(x0_spectral, label="x0") 76 | plt.plot(recon_x0_spectral, label="recon") 77 | plt.legend() 78 | plt.show() 79 | 80 | 81 | def recon_all_fig(diffusion, model, splitX, dataloader, big_img_size=[145, 145]): 82 | ''' 83 | X shape is (spectral, h, w) => (batch, channel=1, 200, 145, 145) 84 | ''' 85 | # 1. reconstruct 86 | t = torch.full((1,), diffusion.T-1, device=device, dtype=torch.long) 87 | xt, tmp_noise = diffusion.forward_diffusion_sample(torch.from_numpy(splitX.astype('float32')), t, device) 88 | _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5) 89 | 90 | # ---just for test--- 91 | # recon_from_xt.append(torch.from_numpy(splitX.astype('float32'))) 92 | # plot_by_imgs(recon_from_xt, rgb=rgb) 93 | 94 | # --------- 95 | 96 | res_xt_list = [] 97 | for tempxt in recon_from_xt: 98 | big_xt = dataloader.split_to_big_image(tempxt.numpy()) 99 | res_xt_list.append(big_xt) 100 | ori_data, _ = dataloader.get_ori_data() 101 | res_xt_list.append(ori_data) 102 | plot_by_images_v2(res_xt_list, rgb=rgb) 103 | 104 | def sample_by_t(diffusion, model, X): 105 | num = 10 106 | choose_index = [3] 107 | x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float() 108 | 109 | step = diffusion.T // num 110 | for ti in range(10, diffusion.T, step): 111 | t = torch.full((1,), ti, device=device, dtype=torch.long) 112 | xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device) 113 | _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5) 114 | recon_x0 = recon_from_xt[-1] 115 | recon_from_xt.append(x0) 116 | print('---',ti,'---') 117 | plot_by_imgs(recon_from_xt, rgb=rgb) 118 | print("x0", x0.shape, "recon_x0", recon_x0.shape) 119 | plot_spectral(x0, recon_x0) 120 | 121 | 122 | def inference_by_t(dataloader, diffusion, model, X, ti): 123 | ''' 124 | X shape is (batch, channel, spe, h, w) 125 | ''' 126 | 127 | X = torch.from_numpy(X).float() 128 | t = torch.full((1,), ti, device=device, dtype=torch.long) 129 | xt, tmp_noise = diffusion.forward_diffusion_sample(X, t, device) 130 | 131 | # 2. 
Try to fully recover x0 from this t to sanity-check the model
132 |     choose_index = [3]
133 |     show_x0 = X[choose_index,:,:,:,:]
134 |     show_xt = xt[choose_index, :,:,:,:]
135 |     _, recon_from_xt = diffusion.reconstruct(model, xt=show_xt, tempT=t, num = 5) # recon_from_xt[0] shape (batch, channel, spe, h, w)
136 |     recon_x0 = recon_from_xt[-1]
137 |     recon_from_xt.append(show_x0)
138 |     print('---',ti,'---')
139 |     plot_by_imgs(recon_from_xt, rgb=rgb)
140 |     plot_spectral(show_x0, recon_x0)
141 | 
142 | 
143 | 
144 | def sample_eval(diffusion, model, X):
145 |     all_size, channel, spe, h, w = X.shape
146 |     num = 16
147 |     step = all_size // num
148 |     r,g,b = 1, 100, 199
149 |     choose_index = list(range(0, all_size, step))
150 |     x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float()
151 | 
152 |     use_t = 499
153 |     # from xt
154 |     t = torch.full((1,), use_t, device=device, dtype=torch.long)
155 |     xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device)
156 |     _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 10)
157 |     recon_from_xt.append(x0)
158 |     plot_by_imgs(recon_from_xt, rgb=rgb)
159 | 
160 |     # from noise
161 |     t = torch.full((1,), use_t, device=device, dtype=torch.long)
162 | 
163 |     _, recon_from_noise = diffusion.reconstruct(model, xt=x0, tempT=t, num = 10, from_noise=True, shape=x0.shape)
164 |     plot_by_imgs(recon_from_noise, rgb=rgb)
165 | 
166 | 
167 | def save_model(model, path):
168 |     torch.save(model.state_dict(), path)
169 |     print("save model done. path=%s" % path)
170 | 
171 | 
172 | def plot():
173 |     dataloader = HSIDataLoader({"data":{"data_sign":"Indian", "padding":False, "batch_size":batch_size, "patch_size":patch_size, "select_spectral":select_spectral}})
174 |     train_loader,X,Y = dataloader.generate_torch_dataset(light_split=True)
175 |     diffusion = Diffusion(T=T)
176 | 
177 |     model = SimpleUnet(_image_channels=channel)
178 | 
179 |     model_path = "%s/%s" % (model_load_path, model_name)
180 |     model.load_state_dict(torch.load(model_path))
181 | 
182 |     model.to(device)
183 | 
184 |     # for ti in TList:
185 |     #     inference_by_t(dataloader, diffusion, model, X, ti)
186 |     #     print("feature extract t=%s done." 
% ti) 187 | 188 | sample_eval(diffusion, model, X) 189 | print('done.') 190 | 191 | if __name__ == "__main__": 192 | plot() 193 | -------------------------------------------------------------------------------- /codes/train_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import transforms 4 | from torch import nn 5 | from torch.optim import Adam 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import math 9 | import matplotlib.pyplot as plt 10 | import os,sys 11 | 12 | from data import HSIDataLoader, TestDS, TrainDS 13 | from unet3d import SimpleUnet 14 | from diffusion import Diffusion 15 | from utils import AvgrageMeter, recorder, show_img 16 | from utils import device 17 | 18 | # for PU 19 | sign = 'PU' 20 | batch_size = 20 21 | patch_size = 64 22 | select_spectral = [] 23 | spe = 104 24 | channel = 1 #3d channel 25 | 26 | # for IP 27 | # sign = 'IP' 28 | # batch_size = 20 29 | # patch_size = 64 30 | # select_spectral = [] 31 | # spe = 200 32 | # channel = 1 #3d channel 33 | 34 | # for SA 35 | # sign = 'SA' 36 | # batch_size = 20 37 | # patch_size = 64 38 | # select_spectral = [] 39 | # spe = 104 40 | # channel = 1 #3d channel 41 | 42 | 43 | 44 | epochs = 100000 # more than 30000 45 | lr = 1e-4 46 | T=500 47 | 48 | rgb = [30,50,90] 49 | path_prefix = "./save_model/%s_diffusion" % sign 50 | 51 | 52 | #device = "cuda" if torch.cuda.is_available() else "cpu" 53 | 54 | def plot_by_imgs(imgs, rgb=[1,100,199]): 55 | assert len(imgs) > 0 56 | batch, c, s, h, w = imgs[0].shape 57 | for i in range(batch): 58 | plt.figure(figsize=(12,8)) 59 | for j in range(len(imgs)): 60 | plt.subplot(1,len(imgs),j+1) 61 | img = imgs[j][i,0,rgb,:,:] 62 | show_img(img) 63 | plt.show() 64 | 65 | def plot_by_images_v2(imgs, rgb=[1,100,199]): 66 | ''' 67 | input image shape is (spectral, height, width) 68 | ''' 69 | assert len(imgs) > 0 70 | s,h,w = imgs[0].shape 71 | plt.figure(figsize=(12,8)) 72 | for j in range(len(imgs)): 73 | plt.subplot(1,len(imgs),j+1) 74 | img = imgs[j][rgb,:,:] 75 | show_img(img) 76 | plt.show() 77 | 78 | def plot_spectral(x0, recon_x0, num=3): 79 | ''' 80 | x0, recon_x0 shape is (batch, channel, spectral, h, w) 81 | ''' 82 | batch, c, s, h ,w = x0.shape 83 | step = h // num 84 | plt.figure(figsize=(20,5)) 85 | for ii in range(num): 86 | i = ii * step 87 | x0_spectral = x0[0,0,:,i,i] 88 | recon_x0_spectral = recon_x0[0,0,:,i,i] 89 | plt.subplot(1,num,ii+1) 90 | plt.plot(x0_spectral, label="x0") 91 | plt.plot(recon_x0_spectral, label="recon") 92 | plt.legend() 93 | plt.show() 94 | 95 | def recon_all_fig(diffusion, model, splitX, dataloader, big_img_size=[145, 145]): 96 | ''' 97 | X shape is (spectral, h, w) => (batch, channel=1, 200, 145, 145) 98 | ''' 99 | # 1. 
reconstruct 100 | t = torch.full((1,), diffusion.T-1, device=device, dtype=torch.long) 101 | xt, tmp_noise = diffusion.forward_diffusion_sample(torch.from_numpy(splitX.astype('float32')), t, device) 102 | _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5) 103 | 104 | # ---just for test--- 105 | # recon_from_xt.append(torch.from_numpy(splitX.astype('float32'))) 106 | # plot_by_imgs(recon_from_xt, rgb=rgb) 107 | 108 | # --------- 109 | 110 | res_xt_list = [] 111 | for tempxt in recon_from_xt: 112 | big_xt = dataloader.split_to_big_image(tempxt.numpy()) 113 | res_xt_list.append(big_xt) 114 | ori_data, _ = dataloader.get_ori_data() 115 | res_xt_list.append(ori_data) 116 | plot_by_images_v2(res_xt_list, rgb=rgb) 117 | 118 | def sample_by_t(diffusion, model, X): 119 | num = 10 120 | choose_index = [3] 121 | x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float() 122 | 123 | step = diffusion.T // num 124 | for ti in range(10, diffusion.T, step): 125 | t = torch.full((1,), ti, device=device, dtype=torch.long) 126 | xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device) 127 | _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5) 128 | recon_x0 = recon_from_xt[-1] 129 | recon_from_xt.append(x0) 130 | print('---',ti,'---') 131 | plot_by_imgs(recon_from_xt, rgb=rgb) 132 | plot_spectral(x0, recon_x0) 133 | 134 | def sample_eval(diffusion, model, X): 135 | all_size, channel, spe, h, w = X.shape 136 | num = 5 137 | step = all_size // num 138 | r,g,b = 1, 100, 199 139 | choose_index = list(range(0, all_size, step)) 140 | x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float() 141 | 142 | use_t = 499 143 | # from xt 144 | t = torch.full((1,), use_t, device=device, dtype=torch.long) 145 | xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device) 146 | _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 10) 147 | recon_from_xt.append(x0) 148 | plot_by_imgs(recon_from_xt, rgb=rgb) 149 | 150 | # from noise 151 | t = torch.full((1,), use_t, device=device, dtype=torch.long) 152 | _, recon_from_noise = diffusion.reconstruct(model, xt=x0, tempT=t, num = 10, from_noise=True, shape=x0.shape) 153 | plot_by_imgs(recon_from_noise, rgb=rgb) 154 | 155 | 156 | def save_model(model, path): 157 | torch.save(model.state_dict(), path) 158 | print("save model done. 
path=%s" % path) 159 | 160 | 161 | def train(): 162 | dataloader = HSIDataLoader( 163 | {"data":{"data_sign":"Pavia", "padding":False, "batch_size":batch_size, "patch_size":patch_size, "select_spectral":select_spectral}}) 164 | train_loader,X,Y = dataloader.generate_torch_dataset(light_split=True) 165 | diffusion = Diffusion(T=T) 166 | model = SimpleUnet(_image_channels=channel) 167 | model.to(device) 168 | optimizer = Adam(model.parameters(), lr=lr) 169 | 170 | loss_metric = AvgrageMeter() 171 | 172 | assert not os.path.exists(path_prefix) 173 | os.makedirs(path_prefix) 174 | 175 | for epoch in range(epochs): 176 | loss_metric.reset() 177 | for step, (batch, _) in enumerate(train_loader): 178 | batch = batch.to(device) 179 | optimizer.zero_grad() 180 | cur_batch_size = batch.shape[0] 181 | t = torch.randint(0, diffusion.T , (cur_batch_size,), device=device).long() 182 | loss, temp_xt, temp_noise, temp_noise_pred = diffusion.get_loss(model, batch, t) 183 | loss.backward() 184 | optimizer.step() 185 | loss_metric.update(loss.item(), batch.shape[0]) 186 | 187 | if step % 10 == 0: 188 | print(f"[Epoch-step] {epoch} | step {step:03d} Loss: {loss.item()} ") 189 | print("[TRAIN EPOCH %s] loss=%s" % (epoch, loss_metric.get_avg())) 190 | 191 | if epoch % 100 == 0: 192 | #sample_by_t(diffusion, model, X) 193 | #sample_eval(diffusion, model, X) 194 | _, splitX, splitY = dataloader.generate_torch_dataset(split=True) 195 | # recon_all_fig(diffusion, model, splitX, dataloader, big_img_size=[145, 145]) 196 | path = "%s/unet3d_%s.pkl" % (path_prefix, epoch) 197 | save_model(model, path) 198 | 199 | 200 | if __name__ == "__main__": 201 | train() 202 | -------------------------------------------------------------------------------- /codes/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 
56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | 
self.param_groups_and_shapes = None
164 |         self.lg_loss_scale = initial_lg_loss_scale
165 | 
166 |         if self.use_fp16:
167 |             self.param_groups_and_shapes = get_param_groups_and_shapes(
168 |                 self.model.named_parameters()
169 |             )
170 |             self.master_params = make_master_params(self.param_groups_and_shapes)
171 |             self.model.convert_to_fp16()
172 | 
173 |     def zero_grad(self):
174 |         zero_grad(self.model_params)
175 | 
176 |     def backward(self, loss: th.Tensor):
177 |         if self.use_fp16:
178 |             loss_scale = 2 ** self.lg_loss_scale
179 |             (loss * loss_scale).backward()
180 |         else:
181 |             loss.backward()
182 | 
183 |     def optimize(self, opt: th.optim.Optimizer):
184 |         if self.use_fp16:
185 |             return self._optimize_fp16(opt)
186 |         else:
187 |             return self._optimize_normal(opt)
188 | 
189 |     def _optimize_fp16(self, opt: th.optim.Optimizer):
190 |         logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 |         model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 |         grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 |         if check_overflow(grad_norm):
194 |             self.lg_loss_scale -= 1
195 |             logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 |             zero_master_grads(self.master_params)
197 |             return False
198 | 
199 |         logger.logkv_mean("grad_norm", grad_norm)
200 |         logger.logkv_mean("param_norm", param_norm)
201 | 
202 |         for p in self.master_params:
203 |             p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 |         opt.step()
205 |         zero_master_grads(self.master_params)
206 |         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207 |         self.lg_loss_scale += self.fp16_scale_growth
208 |         return True
209 | 
210 |     def _optimize_normal(self, opt: th.optim.Optimizer):
211 |         grad_norm, param_norm = self._compute_norms()
212 |         logger.logkv_mean("grad_norm", grad_norm)
213 |         logger.logkv_mean("param_norm", param_norm)
214 |         opt.step()
215 |         return True
216 | 
217 |     def _compute_norms(self, grad_scale=1.0):
218 |         grad_norm = 0.0
219 |         param_norm = 0.0
220 |         for p in self.master_params:
221 |             with th.no_grad():
222 |                 param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223 |                 if p.grad is not None:
224 |                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225 |         return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226 | 
227 |     def master_params_to_state_dict(self, master_params):
228 |         return master_params_to_state_dict(
229 |             self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230 |         )
231 | 
232 |     def state_dict_to_master_params(self, state_dict):
233 |         return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234 | 
235 | 
236 | def check_overflow(value):
237 |     return (value == float("inf")) or (value == -float("inf")) or (value != value)
238 | 
--------------------------------------------------------------------------------
/codes/feature_extract_unet.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | #os.environ["CUDA_VISIBLE_DEVICES"]="1"
3 | import torch
4 | import torchvision
5 | from torchvision import transforms
6 | from torch import nn
7 | from torch.optim import Adam
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import math
11 | import matplotlib.pyplot as plt
12 | 
13 | from data import HSIDataLoader, TestDS, TrainDS
14 | from unet3d import SimpleUnet
15 | # from unet import UNetModel  # not included in this repository; SimpleUnet is used below
16 | from diffusion import Diffusion
17 | from utils import AvgrageMeter, recorder, show_img
18 | from utils import device
19 | 
20 | 
batch_size = 8 21 | patch_size = 64 22 | select_spectral = [] 23 | spe = 104 24 | channel = 1 #3d channel 25 | 26 | epochs = 100000 # Try more! 27 | lr = 1e-4 28 | T=500 29 | 30 | rgb = [50,60,100] 31 | model_load_path = "./save_model/pavia_unet3d_patch64_without_downsample_kernal5_fix" 32 | model_name = "unet3d_31900.pkl" #loss=0.0258 33 | save_feature_path_prefix = "./save_feature/pavia_unet3d_patch64_without_downsample_kernal5_fix_31900/save_feature" 34 | 35 | TList = [5, 10, 50, 100, 200, 400] 36 | 37 | 38 | 39 | def plot_by_imgs(imgs, rgb=[1,100,199]): 40 | assert len(imgs) > 0 41 | batch, c, s, h, w = imgs[0].shape 42 | for i in range(batch): 43 | plt.figure(figsize=(12,8)) 44 | for j in range(len(imgs)): 45 | plt.subplot(1,len(imgs),j+1) 46 | img = imgs[j][i,0,rgb,:,:] 47 | show_img(img) 48 | plt.show() 49 | 50 | def plot_by_images_v2(imgs, rgb=[1,100,199]): 51 | ''' 52 | input image shape is (spectral, height, width) 53 | ''' 54 | assert len(imgs) > 0 55 | s,h,w = imgs[0].shape 56 | plt.figure(figsize=(12,8)) 57 | for j in range(len(imgs)): 58 | plt.subplot(1,len(imgs),j+1) 59 | img = imgs[j][rgb,:,:] 60 | show_img(img) 61 | plt.show() 62 | 63 | def plot_spectral(x0, recon_x0, num=3): 64 | ''' 65 | x0, recon_x0 shape is (batch, channel, spectral, h, w) 66 | ''' 67 | batch, c, s, h ,w = x0.shape 68 | step = h // num 69 | plt.figure(figsize=(20,5)) 70 | for ii in range(num): 71 | i = ii * step 72 | x0_spectral = x0[0,0,:,i,i] 73 | recon_x0_spectral = recon_x0[0,0,:,i,i] 74 | plt.subplot(1,num,ii+1) 75 | plt.plot(x0_spectral, label="x0") 76 | plt.plot(recon_x0_spectral, label="recon") 77 | plt.legend() 78 | plt.show() 79 | 80 | 81 | def recon_all_fig(diffusion, model, splitX, dataloader, big_img_size=[145, 145]): 82 | ''' 83 | X shape is (spectral, h, w) => (batch, channel=1, 200, 145, 145) 84 | ''' 85 | # 1. 
reconstruct
86 |     t = torch.full((1,), diffusion.T-1, device=device, dtype=torch.long)
87 |     xt, tmp_noise = diffusion.forward_diffusion_sample(torch.from_numpy(splitX.astype('float32')), t, device)
88 |     _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5)
89 | 
90 |     # ---just for test---
91 |     # recon_from_xt.append(torch.from_numpy(splitX.astype('float32')))
92 |     # plot_by_imgs(recon_from_xt, rgb=rgb)
93 | 
94 |     # ---------
95 | 
96 |     res_xt_list = []
97 |     for tempxt in recon_from_xt:
98 |         big_xt = dataloader.split_to_big_image(tempxt.numpy())
99 |         res_xt_list.append(big_xt)
100 |     ori_data, _ = dataloader.get_ori_data()
101 |     res_xt_list.append(ori_data)
102 |     plot_by_images_v2(res_xt_list, rgb=rgb)
103 | 
104 | def sample_by_t(diffusion, model, X):
105 |     num = 10
106 |     choose_index = [3]
107 |     x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float()
108 | 
109 |     step = diffusion.T // num
110 |     for ti in range(10, diffusion.T, step):
111 |         t = torch.full((1,), ti, device=device, dtype=torch.long)
112 |         xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device)
113 |         _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 5)
114 |         recon_x0 = recon_from_xt[-1]
115 |         recon_from_xt.append(x0)
116 |         print('---',ti,'---')
117 |         plot_by_imgs(recon_from_xt, rgb=rgb)
118 |         print("x0", x0.shape, "recon_x0", recon_x0.shape)
119 |         plot_spectral(x0, recon_x0)
120 | 
121 | def inference_mini_batch(model, xt, t):  # run the model in small chunks to bound GPU memory
122 |     mini_batch_size = 4
123 |     batch, channel, c, h, w = xt.shape
124 |     step = batch // mini_batch_size + 1
125 | 
126 |     res_feature_t_list = []
127 |     for i in range(step):
128 |         start = i * mini_batch_size
129 |         end = (i+1) * mini_batch_size
130 |         temp_xt = xt[start:end, :, :, :, :]
131 |         if temp_xt.shape[0] <= 0:
132 |             break
133 |         noise_pred = model(temp_xt, t, feature=True)
134 |         temp_feature_t_list = model.return_features()
135 |         if len(res_feature_t_list) == 0:
136 |             res_feature_t_list = temp_feature_t_list[:]
137 |         else:
138 |             assert len(res_feature_t_list) == len(temp_feature_t_list)
139 |             temp_res = []
140 |             for j in range(len(temp_feature_t_list)):
141 |                 temp_res.append(np.concatenate([res_feature_t_list[j], temp_feature_t_list[j]]))
142 |             res_feature_t_list = temp_res[:]
143 |     for fea in res_feature_t_list:
144 |         print(fea.shape)
145 |     return res_feature_t_list
146 | 
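inference_by_t below flattens each saved feature cube channel-and-spectral-wise before stitching patches back into a full image; a quick standalone check of that reshape with toy shapes:

```python
import numpy as np

fb, fc, fs, fh, fw = 2, 32, 13, 64, 64   # toy (batch, channel, spectral, h, w) feature
feature = np.random.rand(fb, fc, fs, fh, fw)
flat = feature.reshape((fb, fc * fs, fh, fw)).transpose((0, 2, 3, 1))
print(flat.shape)                        # (2, 64, 64, 416): (batch, h, w, channel*spectral)
```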
147 | def inference_by_t(dataloader, diffusion, model, X, ti):
148 |     '''
149 |     X shape is (batch, channel, spe, h, w)
150 |     '''
151 | 
152 |     X = torch.from_numpy(X).float()
153 |     t = torch.full((1,), ti, device=device, dtype=torch.long)
154 |     xt, tmp_noise = diffusion.forward_diffusion_sample(X, t, device)
155 | 
156 |     # 1. Explicitly call the model to collect the hidden-layer features
157 |     # noise_pred = model(xt, t, feature=True)
158 |     # feature_t_list = model.return_features()
159 |     feature_t_list = inference_mini_batch(model, xt, t)
160 |     for index, feature_matrix in enumerate(feature_t_list):
161 |         path = "%s/t%s_%s.pkl" % (save_feature_path_prefix, ti, index)  # np.save appends ".npy"
162 |         np.save(path, feature_matrix)
163 |         print("save matrix t=%s, index=%s done." % (ti, index))
164 |         # feature_matrix shape is (batch, channel, spe, h, w)
165 |         fb, fc, fs, fh, fw = feature_matrix.shape
166 |         temp = feature_matrix.reshape((fb,fc*fs, fh, fw)).transpose((0,2,3,1))
167 |         full_feature_img = dataloader.reconstruct_image_by_light_split(temp, patch_size=patch_size)
168 |         path = "%s/t%s_%s_full.pkl" % (save_feature_path_prefix, ti, index)
169 |         np.save(path, full_feature_img)
170 |         print("save full matrix done. t=%s, index=%s, shape=%s" % (ti, index, str(full_feature_img.shape)))
171 | 
172 |     # 2. Try to fully recover x0 from this t to sanity-check the model
173 |     choose_index = [3]
174 |     show_x0 = X[choose_index,:,:,:,:]
175 |     show_xt = xt[choose_index, :,:,:,:]
176 |     _, recon_from_xt = diffusion.reconstruct(model, xt=show_xt, tempT=t, num = 5) # recon_from_xt[0] shape (batch, channel, spe, h, w)
177 |     recon_x0 = recon_from_xt[-1]
178 |     recon_from_xt.append(show_x0)
179 |     print('---',ti,'---')
180 |     plot_by_imgs(recon_from_xt, rgb=rgb)
181 |     plot_spectral(show_x0, recon_x0)
182 | 
183 | 
184 | 
185 | def sample_eval(diffusion, model, X):
186 |     all_size, channel, spe, h, w = X.shape
187 |     num = 5
188 |     step = all_size // num
189 |     r,g,b = 1, 100, 199
190 |     choose_index = list(range(0, all_size, step))
191 |     x0 = torch.from_numpy(X[choose_index,:,:,:,:]).float()
192 | 
193 |     use_t = 499
194 |     # from xt
195 |     t = torch.full((1,), use_t, device=device, dtype=torch.long)
196 |     xt, tmp_noise = diffusion.forward_diffusion_sample(x0, t, device)
197 |     _, recon_from_xt = diffusion.reconstruct(model, xt=xt, tempT=t, num = 10)
198 |     recon_from_xt.append(x0)
199 |     plot_by_imgs(recon_from_xt, rgb=rgb)
200 | 
201 |     # from noise
202 |     t = torch.full((1,), use_t, device=device, dtype=torch.long)
203 | 
204 |     _, recon_from_noise = diffusion.reconstruct(model, xt=x0, tempT=t, num = 10, from_noise=True, shape=x0.shape)
205 |     plot_by_imgs(recon_from_noise, rgb=rgb)
206 | 
207 | 
208 | def save_model(model, path):
209 |     torch.save(model.state_dict(), path)
210 |     print("save model done. path=%s" % path)
211 | 
212 | 
213 | def extract():
214 |     dataloader = HSIDataLoader({"data":{"data_sign":"Pavia", "padding":False, "batch_size":batch_size, "patch_size":patch_size, "select_spectral":select_spectral}})
215 |     train_loader,X,Y = dataloader.generate_torch_dataset(light_split=True)
216 |     diffusion = Diffusion(T=T)
217 | 
218 |     model = SimpleUnet(_image_channels=channel)
219 |     assert os.path.exists(model_load_path)
220 |     if not os.path.exists(save_feature_path_prefix):
221 |         os.makedirs(save_feature_path_prefix)
222 | 
223 |     model_path = "%s/%s" % (model_load_path, model_name)
224 |     print('model path is ', model_path)
225 |     model.load_state_dict(torch.load(model_path, map_location=device))
226 | 
227 |     model.to(device)
228 |     print("load model done. model_path=%s" % model_path)
229 | 
230 |     for ti in TList:
231 |         inference_by_t(dataloader, diffusion, model, X, ti)
232 |         print("feature extract t=%s done." % ti)
233 | 
234 |     print('done.')
235 | 
236 | if __name__ == "__main__":
237 |     extract()
238 | 
--------------------------------------------------------------------------------
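The extracted features land under save_feature_path_prefix; a hedged loading sketch (the t value and layer index are hypothetical, and note that np.save appends ".npy" to file names that do not already end with it):

```python
import numpy as np

prefix = "./save_feature/pavia_unet3d_patch64_without_downsample_kernal5_fix_31900/save_feature"
full_feature_img = np.load("%s/t5_0_full.pkl.npy" % prefix)  # hypothetical t=5, index=0
print(full_feature_img.shape)  # (h, w, channel*spectral) stitched feature image
```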
/codes/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | from sklearn.decomposition import PCA
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from operator import truediv
10 | import time
11 | 
12 | """ Training dataset"""
13 | 
14 | class TrainDS(torch.utils.data.Dataset):
15 | 
16 |     def __init__(self, Xtrain, ytrain):
17 | 
18 |         self.len = Xtrain.shape[0]
19 |         self.x_data = torch.FloatTensor(Xtrain)
20 |         self.y_data = torch.LongTensor(ytrain)
21 | 
22 |     def __getitem__(self, index):
23 | 
24 |         # return the sample and its label for the given index
25 |         return self.x_data[index], self.y_data[index]
26 |     def __len__(self):
27 | 
28 |         # return the number of samples
29 |         return self.len
30 | 
31 | """ Testing dataset"""
32 | 
33 | class TestDS(torch.utils.data.Dataset):
34 | 
35 |     def __init__(self, Xtest, ytest):
36 | 
37 |         self.len = Xtest.shape[0]
38 |         self.x_data = torch.FloatTensor(Xtest)
39 |         self.y_data = torch.LongTensor(ytest)
40 | 
41 |     def __getitem__(self, index):
42 | 
43 |         # return the sample and its label for the given index
44 |         return self.x_data[index], self.y_data[index]
45 | 
46 |     def __len__(self):
47 | 
48 |         # return the number of samples
49 |         return self.len
50 | 
51 | 
52 | 
53 | class HSIDataLoader(object):
54 |     def __init__(self, param={}) -> None:
55 |         self.data_param = param.get('data', {})
56 |         self.data = None # raw input X data, shape=(h,w,c)
57 |         self.labels = None # raw input Y labels, shape=(h,w,1)
58 | 
59 |         # parameter settings
60 |         self.data_sign = self.data_param.get('data_sign', 'Indian')
61 |         self.patch_size = self.data_param.get('patch_size', 32) # n * n
62 |         self.padding = self.data_param.get('padding', True) # n * n
63 |         self.remove_zeros = self.data_param.get('remove_zeros', False)
64 |         self.batch_size = self.data_param.get('batch_size', 256)
65 |         self.select_spectral = self.data_param.get('select_spectral', []) # [] means all spectral bands are selected
66 | 
67 |         self.squzze = True
68 | 
69 |         self.split_row = 0
70 |         self.split_col = 0
71 | 
72 |         self.light_split_ori_shape = None
73 |         self.light_split_map = []
74 | 
75 | 
76 | 
77 |     def load_data(self):
78 |         data, labels = None, None
79 |         if self.data_sign == "Indian":
80 |             data = sio.loadmat('../data/Indian_pines_corrected.mat')['indian_pines_corrected']
81 |             labels = sio.loadmat('../data/Indian_pines_gt.mat')['indian_pines_gt']
82 |         elif self.data_sign == "Pavia":
83 |             data = sio.loadmat('../data/PaviaU.mat')['paviaU']
84 |             labels = sio.loadmat('../data/PaviaU_gt.mat')['paviaU_gt']
85 |         elif self.data_sign == "Houston":
86 |             data = sio.loadmat('../data/Houston.mat')['img']
87 |             labels = sio.loadmat('../data/Houston_gt.mat')['Houston_gt']
88 |         elif self.data_sign == "Salinas":
89 |             data = sio.loadmat('../data/Salinas_corrected.mat')['salinas_corrected']
90 |             labels = sio.loadmat('../data/Salinas_gt.mat')['salinas_gt']
91 |         else:
92 |             pass
93 |         print("ori data load shape is", data.shape, labels.shape)
94 |         if len(self.select_spectral) > 0:  # user-selected spectral bands
95 |             data = data[:,:,self.select_spectral]
96 |         return data, labels
97 | 
98 |     def get_ori_data(self):
99 |         return np.transpose(self.data, (2,0,1)), self.labels
100 | 
101 |     def _padding(self, X, margin=2):
102 |         # padding with zeros
103 |         w,h,c = X.shape
104 |         new_x, new_h, new_c = w+margin*2, h+margin*2, c
105 |         returnX = np.zeros((new_x, new_h, new_c))
106 |         start_x, start_y = margin, margin
107 |         returnX[start_x:start_x+w, start_y:start_y+h,:] = X
108 |         return returnX
109 | 
110 |     def get_patches_by_light_split(self, X, Y, patch_size=1):
111 |         h, w, c = X.shape
112 |         row = h // patch_size
113 |         if h % patch_size != 0:
114 |             row += 1
115 |         col = w // patch_size
116 |         if w % patch_size != 0:
117 |             col += 1
118 |         res = np.zeros((row*col, patch_size, patch_size, c))
119 |         self.light_split_ori_shape = X.shape
120 |         resY = np.zeros((row*col))
121 |         index = 0
122 |         for i in range(row):
123 |             for j in range(col):
124 |                 start_row = i*patch_size
125 |                 if start_row + patch_size > h:
126 |                     start_row = h - patch_size
127 |                 start_col = j*patch_size
128 |                 if start_col + patch_size > w:
129 |                     start_col = w - patch_size
130 | 
131 |                 res[index, :,:,:] = X[start_row:start_row+patch_size, start_col:start_col+patch_size, :]
132 |                 self.light_split_map.append([index, start_row, start_row+patch_size, start_col, start_col+patch_size])
133 |                 index += 1
134 |         return res, resY
135 | 
136 |     def reconstruct_image_by_light_split(self, inputX, patch_size=1):
137 |         '''
138 |         input shape is (batch, h, w, c)
139 |         '''
140 |         assert self.light_split_ori_shape is not None
141 |         ori_h, ori_w, ori_c = self.light_split_ori_shape
142 |         batch, h, w, c = inputX.shape
143 |         assert batch == len(self.light_split_map)  # light_split_map must have one entry per patch in the batch
144 |         X = np.zeros((ori_h, ori_w, c))
145 |         for tup in self.light_split_map:
146 |             index, start_row, end_row, start_col, end_col = tup
147 |             X[start_row:end_row, start_col:end_col, :] = inputX[index, :, :, :]
148 |         return X
149 | 
150 | 
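A minimal round-trip sketch of the two light-split helpers above on a synthetic cube; edge patches are shifted inward rather than zero-padded, so the reconstruction matches the original exactly:

```python
import numpy as np
from data import HSIDataLoader

loader = HSIDataLoader({"data": {"patch_size": 16}})
X = np.random.rand(145, 145, 30)                 # synthetic (h, w, spectral) cube
patches, _ = loader.get_patches_by_light_split(X, None, patch_size=16)
print(patches.shape)                             # (100, 16, 16, 30): 10 x 10 shifted tiles
recon = loader.reconstruct_image_by_light_split(patches, patch_size=16)
assert recon.shape == X.shape and np.allclose(recon, X)
```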
150 |     def get_patches_by_split(self, X, Y, patch_size=1):
151 |         h, w, c = X.shape
152 |         row = h // patch_size
153 |         col = w // patch_size
154 |         newX = X
155 |         res = np.zeros((row * col, patch_size, patch_size, c))
156 |         resY = np.zeros((row * col))  # labels are unused here; resY is a zero placeholder
157 |         index = 0
158 |         for i in range(row):
159 |             for j in range(col):
160 |                 res[index, :, :, :] = newX[i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size, :]
161 |                 index += 1
162 |         self.split_row = row
163 |         self.split_col = col
164 |         return res, resY
165 | 
166 |     def split_to_big_image(self, splitX):
167 |         '''
168 |         input splitX shape (batch, 1, spe, h, w)
169 |         return newX shape (spe, bigh, bigw)
170 |         '''
171 |         patch_size = self.patch_size
172 |         batch, channel, spe, h, w = splitX.shape
173 |         assert self.split_row * self.split_col == batch
174 |         newX = np.zeros((spe, self.split_row * patch_size, self.split_col * patch_size))
175 |         index = 0
176 |         for i in range(self.split_row):
177 |             for j in range(self.split_col):
178 |                 index = i * self.split_col + j
179 |                 newX[:, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = splitX[index, 0, :, :, :]
180 |         return newX
181 | 
182 |     def re_build_split(self, X_patches, patch_size):
183 |         '''
184 |         X_patches shape is (batch, channel=1, spectral, height, width)
185 |         return shape is (height, width, spectral)
186 |         '''
187 |         h, w, c = self.data.shape
188 |         row = h // patch_size
189 |         if h % patch_size > 0:
190 |             row += 1
191 |         col = w // patch_size
192 |         if w % patch_size > 0:
193 |             col += 1
194 |         newX = np.zeros((c, row * patch_size, col * patch_size))
195 |         for i in range(row):
196 |             for j in range(col):
197 |                 newX[:, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = X_patches[i * col + j, 0, :, :, :]
198 |         return np.transpose(newX, (1, 2, 0))
199 | 
200 |     def get_patches(self, X, Y, patch_size=1, remove_zero=False):
201 |         w, h, c = X.shape
202 |         #1. pad so that every pixel gets a full patch around it
203 |         margin = (patch_size - 1) // 2
204 |         if self.padding:
205 |             X_padding = self._padding(X, margin=margin)
206 |         else:
207 |             X_padding = X
208 | 
209 |         #2. one patch per pixel
210 |         temp_w, temp_h, temp_c = X_padding.shape
211 |         row = temp_w - patch_size + 1
212 |         col = temp_h - patch_size + 1
213 |         X_patchs = np.zeros((row * col, patch_size, patch_size, c))  # one pixel one patch with padding
214 |         Y_patchs = np.zeros((row * col))
215 |         patch_index = 0
216 |         for r in range(0, row):
217 |             for cc in range(0, col):  # cc, not c: do not shadow the channel count
218 |                 temp_patch = X_padding[r:r + patch_size, cc:cc + patch_size, :]
219 |                 X_patchs[patch_index, :, :, :] = temp_patch
220 |                 if Y is not None:  # label of a patch = label of its center pixel
221 |                     Y_patchs[patch_index] = Y[r, cc] if self.padding else Y[r + margin, cc + margin]
222 |                 patch_index += 1
223 | 
224 |         if remove_zero:
225 |             X_patchs = X_patchs[Y_patchs > 0, :, :, :]
226 |             Y_patchs = Y_patchs[Y_patchs > 0]
227 |             Y_patchs -= 1
228 |         return X_patchs, Y_patchs  # (batch, w, h, c), (batch)
229 | 
230 |     def custom_process(self, data):
231 |         '''
232 |         Dataset-specific band adjustment so the band count fits the network:
233 |         for Pavia, one band is appended (103 -> 104) that mirrors band 103;
234 |         for Salinas, four zero bands are added, two on each side (204 -> 208).
235 |         data shape is [h, w, spe]
236 |         '''
237 |         if self.data_sign == "Pavia":
238 |             h, w, spe = data.shape
239 |             new_data = np.zeros((h, w, spe + 1))
240 |             new_data[:, :, :spe] = data
241 |             new_data[:, :, spe] = data[:, :, spe - 1]
242 |             return new_data
243 |         if self.data_sign == "Salinas":
244 |             h, w, spe = data.shape
245 |             new_data = np.zeros((h, w, spe + 4))
246 |             new_data[:, :, 2:spe + 2] = data
247 |             return new_data
248 |         return data
249 | 
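250 |     # Shape check for custom_process (illustrative only; PaviaU is 610 x 340 x 103):
251 |     #
252 |     #   loader = HSIDataLoader({"data": {"data_sign": "Pavia"}})
253 |     #   fake = np.zeros((610, 340, 103))
254 |     #   out = loader.custom_process(fake)
255 |     #   assert out.shape == (610, 340, 104)
256 |     #   assert np.array_equal(out[:, :, 103], fake[:, :, 102])  # appended band mirrors the last one
257 | 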
258 |     def generate_torch_dataset(self, split=False, light_split=False):
259 |         #1. load data according to data_sign
260 |         self.data, self.labels = self.load_data()
261 | 
262 |         #1.1 normalize each band independently to [-1, 1]
263 |         norm_data = np.zeros(self.data.shape)
264 |         for i in range(self.data.shape[2]):
265 |             input_max = np.max(self.data[:, :, i])
266 |             input_min = np.min(self.data[:, :, i])
267 |             norm_data[:, :, i] = (self.data[:, :, i] - input_min) / (input_max - input_min) * 2 - 1  # [-1, 1]
268 | 
269 |         print('[data] load data shape data=%s, label=%s' % (str(norm_data.shape), str(self.labels.shape)))
270 |         self.data = norm_data
271 | 
272 |         #1.2 dataset-specific band padding for special datasets (see custom_process)
273 |         norm_data = self.custom_process(norm_data)
274 | 
275 |         #2. extract patches
276 |         if not split and not light_split:
277 |             X_patchs, Y_patchs = self.get_patches(norm_data, self.labels, patch_size=self.patch_size, remove_zero=False)
278 |             print('[data not split] data patches shape data=%s, label=%s' % (str(X_patchs.shape), str(Y_patchs.shape)))
279 |         elif split:
280 |             X_patchs, Y_patchs = self.get_patches_by_split(norm_data, self.labels, patch_size=self.patch_size)
281 |             print('[data split] data patches shape data=%s, label=%s' % (str(X_patchs.shape), str(Y_patchs.shape)))
282 |         elif light_split:
283 |             X_patchs, Y_patchs = self.get_patches_by_light_split(norm_data, self.labels, patch_size=self.patch_size)
284 |             print('[data light split] data patches shape data=%s, label=%s' % (str(X_patchs.shape), str(Y_patchs.shape)))
285 | 
286 |         #3. reshape to the (batch, 1, spectral, h, w) layout torch expects
287 |         X_all = X_patchs.transpose((0, 3, 1, 2))
288 |         X_all = np.expand_dims(X_all, axis=1)
289 |         print('------[data] after transpose train, test------')
290 |         print("X.shape=", X_all.shape)
291 |         print("Y.shape=", Y_patchs.shape)
292 | 
293 |         trainset = TrainDS(X_all, Y_patchs)
294 |         train_loader = torch.utils.data.DataLoader(dataset=trainset,
295 |                                                    batch_size=self.batch_size,
296 |                                                    shuffle=True,
297 |                                                    num_workers=0,
298 |                                                    )
299 |         return train_loader, X_all, Y_patchs
300 | 
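301 | # To map generated samples back to the raw value range, invert the per-band
302 | # normalization above (a sketch; the caller must keep input_min/input_max per
303 | # band, this class does not store them):
304 | #
305 | #   x = (x_norm + 1) / 2 * (input_max - input_min) + input_min
306 | 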
307 | if __name__ == "__main__":
308 |     # dataloader = HSIDataLoader({"data":{"padding":False, "select_spectral":[1,99,199]}})
309 |     # train_loader = dataloader.generate_torch_dataset()
310 |     # train_loader,X,Y = dataloader.generate_torch_dataset(split=True)
311 |     # newX = dataloader.re_build_split(X, dataloader.patch_size)
312 |     # print(newX.shape)
313 | 
314 |     # dataloader = HSIDataLoader(
315 |     #     {"data":{"data_sign":"Pavia", "padding":False, "batch_size":256, "patch_size":16, "select_spectral":[]}})
316 |     # train_loader,X,Y = dataloader.generate_torch_dataset(light_split=True)
317 |     # print(X.shape)
318 | 
319 |     dataloader = HSIDataLoader(
320 |         {"data":{"data_sign":"Houston", "padding":False, "batch_size":256, "patch_size":16, "select_spectral":[]}})
321 |     train_loader, X, Y = dataloader.generate_torch_dataset(light_split=True)
322 |     print(X.shape)
323 | 
--------------------------------------------------------------------------------
/codes/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 | 
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 | 
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 | 
23 | DISABLED = 50
24 | 
25 | 
26 | class KVWriter(object):
27 |     def writekvs(self, kvs):
28 |         raise NotImplementedError
29 | 
30 | 
31 | class SeqWriter(object):
32 |     def writeseq(self, seq):
33 |         raise NotImplementedError
34 | 
35 | 
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 |     def __init__(self, filename_or_file):
38 |         if isinstance(filename_or_file, str):
39 |             self.file = open(filename_or_file, "wt")
40 |             self.own_file = True
41 |         else:
42 |             assert hasattr(filename_or_file, "read"), (
43 |                 "expected file or str, got %s" % filename_or_file
44 |             )
45 |             self.file = filename_or_file
46 |             self.own_file = False
47 | 
48 |     def writekvs(self, kvs):
49 |         # Create strings for printing
50 |         key2str = {}
51 |         for (key, val) in sorted(kvs.items()):
52 |             if hasattr(val, "__float__"):
53 |                 valstr = "%-8.3g" % val
54 |             else:
55 |                 valstr = str(val)
56 |             key2str[self._truncate(key)] = self._truncate(valstr)
57 | 
58 |         # Find max widths
59 |         if len(key2str) == 0:
60 |             print("WARNING: tried to write empty key-value dict")
61 |             return
62 |         else:
63 |             keywidth = max(map(len, key2str.keys()))
64 |             valwidth = max(map(len, key2str.values()))
65 | 
66 |         # Write out the data
67 |         dashes = "-" * (keywidth + valwidth + 7)
68 |         lines = [dashes]
69 |         for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 |             lines.append(
71 |                 "| %s%s | %s%s |"
72 |                 % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 |             )
74 | 
lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | --------------------------------------------------------------------------------
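A minimal usage sketch for `codes/logger.py` (only functions defined above are used; `./log_demo` and the key names are arbitrary examples — with no `dir`, `configure()` falls back to `$OPENAI_LOGDIR` or a temporary directory):

```python
import logger

logger.configure(dir="./log_demo", format_strs=["stdout", "csv"])
logger.logkv("epoch", 1)
logger.logkv_mean("loss", 0.123)  # repeated calls within an iteration are averaged
logger.dumpkvs()                  # flushes one row to stdout and ./log_demo/progress.csv
```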