├── README.md ├── assets ├── AAAI_DiffCLIP.pdf └── DiffCLIP.png ├── augment.py ├── bpe_simple_vocab_16e6.txt.gz ├── build_EMP.py ├── classnames_houston.txt ├── clip_model.py ├── config.py ├── data_read.py ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py ├── respace.py └── timestep_sampler.py ├── generate_pic.py ├── hyper_dataset.py ├── model.py ├── pos_embed.py ├── record.py ├── simple_tokenizer.py ├── train.py └── util_CNN_clip.py /README.md: -------------------------------------------------------------------------------- 1 |
2 |

DiffCLIP

3 |

DiffCLIP: Few-shot Language-driven Multimodal Classifier

4 |

AAAI 2025

5 | 6 |
7 | 8 | 9 | ## **Overview** 10 | 11 |

12 | overview 13 |

14 | 15 | ## **Getting Started** 16 | 17 | **Step 1: Clone the DiffCLIP repository:** 18 | 19 | To get started, first clone the DiffCLIP repository and navigate to the project directory: 20 | 21 | ```bash 22 | git clone https://github.com/icey-zhang/DiffCLIP 23 | cd DiffCLIP 24 | ``` 25 | 26 | **Step 2: Environment Setup:** 27 | 28 | We recommend setting up a conda environment and installing dependencies via pip. Use the following commands to set up your environment: 29 | 30 | ***Create and activate a new conda environment*** 31 | 32 | ```bash 33 | conda create -n DiffCLIP python=3.9.17 34 | conda activate DiffCLIP 35 | ``` 36 | 37 | ***Install the required packages*** 38 | ```bash 39 | pip install torch 40 | ...... 41 | ``` 42 | 43 | ### Prepare the dataset 44 | 45 | ```text 46 | root 47 | ├── Trento 48 | │ ├── HSI.mat 49 | │ ├── LiDAR.mat 50 | │ ├── TRLabel.mat 51 | │ ├── TSLabel.mat 52 | ├── ...... 53 | 54 | ``` 55 | 56 | ### Train the model 57 | 58 | ```bash 59 | python train.py 60 | ``` 61 | 62 | ## Citation 63 | If our code is helpful to you, please cite our work: 64 | 65 | ``` 66 | @article{zhang2024multimodal, 67 | title={Multimodal Informative ViT: Information Aggregation and Distribution for Hyperspectral and LiDAR Classification}, 68 | author={Zhang, Jiaqing and Lei, Jie and Xie, Weiying and Yang, Geng and Li, Daixun and Li, Yunsong}, 69 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 70 | year={2024}, 71 | publisher={IEEE} 72 | } 73 | 74 | @inproceedings{zhange2025DiffCLIP, 75 | title={DiffCLIP: Few-shot Language-driven Multimodal Classifier }, 76 | author={Zhang, Jiaqing and Cao, Mingxiang and Jiang, Kai and Yang, Xue}, 77 | booktitle={AAAI2025} 78 | } 79 | 80 | ``` 81 | 82 | 88 | 94 | Star History Chart 98 | height="500" 99 | /> 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /assets/AAAI_DiffCLIP.pdf:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/assets/AAAI_DiffCLIP.pdf
# --------------------------------------------------------------------------------
# /assets/DiffCLIP.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/assets/DiffCLIP.png
# --------------------------------------------------------------------------------
# /augment.py:
# --------------------------------------------------------------------------------
import cv2
import numpy as np
import random


def _resize_chw(image, height, width):
    """Resize a channel-first (C, H, W) array to (C, height, width).

    cv2.resize works on channel-last arrays, takes its target size as
    (width, height), and returns a 2-D array when there is only one channel.
    All three quirks are handled here so callers always get a channel-first
    array back.
    """
    resized = cv2.resize(np.transpose(image, [1, 2, 0]), (width, height))
    if resized.ndim != 3:
        # Single-channel input: restore the channel axis cv2 dropped.
        resized = np.expand_dims(resized, 2)
    return np.transpose(resized, [2, 0, 1])


class CenterResizeCrop(object):
    '''
    Crop a random-size square centred on the patch and resize it back to the
    input's spatial size.
    -------------------------------------------------------------------------------------
    scale_begin: smallest crop side length; odd sizes scale_begin, scale_begin+2,
                 ..., up to windowsize (inclusive) are sampled uniformly
    windowsize: largest crop side length; the crop is centred on pixel
                int((windowsize-1)/2) in both axes, so the input is presumably
                a windowsize x windowsize patch (TODO confirm with callers)
    -------------------------------------------------------------------------------------
    '''
    def __init__(self, scale_begin=23, windowsize=27):
        self.scale_begin = scale_begin
        self.windowsize = windowsize

    def __call__(self, image):
        """Return a (C, H, W) array with the same H and W as ``image``."""
        length = np.array(range(self.scale_begin, self.windowsize + 1, 2))

        row_center = int((self.windowsize - 1) / 2)
        col_center = int((self.windowsize - 1) / 2)
        row = image.shape[1]
        col = image.shape[2]
        # size=1 keeps the NumPy random stream identical to earlier versions;
        # taking [0] avoids the deprecated ndim>0 array-to-int conversion.
        s = int(np.random.choice(length, size=1)[0])
        halfsize = int((s - 1) / 2)
        r_image = image[:,
                        row_center - halfsize: row_center + halfsize + 1,
                        col_center - halfsize: col_center + halfsize + 1]
        return _resize_chw(r_image, row, col)


class RandomResizeCrop(object):
    '''
    With the given probability, crop a region at a random position whose side
    lengths are a random fraction of the input's, then resize it back to the
    input's spatial size.
    -------------------------------------------------------------------------------------
    scale: (min, max) fraction of the input height/width used for the crop
    probability: chance that the operation is applied; otherwise the input
                 is returned unchanged
    -------------------------------------------------------------------------------------
    '''
    def __init__(self, scale=(0.5, 1), probability=0.5):
        # Tuple default: a list default would be shared between instances.
        self.scale = scale
        self.probability = probability

    def __call__(self, image):
        if random.uniform(0, 1) > self.probability:
            return image
        row = image.shape[1]
        col = image.shape[2]
        s = np.random.uniform(self.scale[0], self.scale[1])
        halfsize_row = int((round(row * s) - 1) / 2)
        halfsize_col = int((round(col * s) - 1) / 2)
        # The centre may lie anywhere the crop still fits inside the *full*
        # image. Bounding it by the crop size (as before) pinned every crop to
        # the top-left corner.
        row_center = random.randint(halfsize_row, row - halfsize_row - 1)
        col_center = random.randint(halfsize_col, col - halfsize_col - 1)
        r_image = image[:,
                        row_center - halfsize_row: row_center + halfsize_row + 1,
                        col_center - halfsize_col: col_center + halfsize_col + 1]
        return _resize_chw(r_image, row, col)


class CenterCrop(object):
    '''
    With the given probability, keep a random-size square centred on the
    patch and zero out the border around it.
    -------------------------------------------------------------------------------------
    scale_begin: smallest kept side length; odd sizes scale_begin, scale_begin+2,
                 ... up to but excluding windowsize are sampled uniformly
    windowsize: patch side length; the kept square is zero-padded back to
                2*int((windowsize-1)/2)+1 pixels per side (windowsize when odd)
    probability: chance that the operation is applied; otherwise the input
                 is returned unchanged
    -------------------------------------------------------------------------------------
    '''
    def __init__(self, scale_begin=23, windowsize=27, probability=0.5):
        self.scale_begin = scale_begin
        self.windowsize = windowsize
        self.probability = probability

    def __call__(self, image):
        if random.uniform(0, 1) > self.probability:
            return image
        length = np.array(range(self.scale_begin, self.windowsize, 2))

        row_center = int((self.windowsize - 1) / 2)
        col_center = int((self.windowsize - 1) / 2)
        # size=1 keeps the NumPy random stream identical to earlier versions.
        s = int(np.random.choice(length, size=1)[0])
        halfsize = int((s - 1) / 2)
        r_image = image[:,
                        row_center - halfsize: row_center + halfsize + 1,
                        col_center - halfsize: col_center + halfsize + 1]

        pad_row = row_center - halfsize
        pad_col = col_center - halfsize
        return np.pad(r_image, ((0, 0), (pad_row, pad_row), (pad_col, pad_col)),
                      'constant', constant_values=0)


class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each patch. A patch
            covers [centre - length // 2, centre + length // 2) in each axis,
            clipped to the image border.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img: array of shape (C, H, W).
        Returns:
            ``img`` multiplied by a float32 {0, 1} mask with ``n_holes``
            zeroed patches; the same mask is applied to every channel.
        """
        h = img.shape[1]
        w = img.shape[2]

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            # (x, y) is the centre of the square patch.
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        # Broadcasting over the channel axis replaces an explicit np.tile.
        return img * mask[np.newaxis, :, :]


def resize(train_image, size=(27, 27)):
    """Resize every (C, H, W) patch of a batch to (C, size[0], size[1]).

    Args:
        train_image: array of shape (N, C, H, W).
        size: target (height, width).
    Returns:
        float32 array of shape (N, C, size[0], size[1]).
    """
    r_image = np.zeros([train_image.shape[0], train_image.shape[1], size[0], size[1]],
                       dtype=np.float32)
    for i in range(train_image.shape[0]):
        # _resize_chw converts (height, width) to cv2's (width, height) order,
        # so non-square targets match the output buffer.
        r_image[i] = _resize_chw(train_image[i], size[0], size[1])
    return r_image


def take_elements(image, location, windowsize):
    """Gather single pixels or square windows at the given locations.

    Args:
        image: 3-D array indexed (row, col, band), or a 2-D array indexed
            (row, col) (presumably a label map, given the int32 output below).
        location: array whose i-th row starts with (row, col); shape (N, >=2).
        windowsize: window side length; 1 gathers single pixels. Windows use
            halfsize = int((windowsize - 1) / 2) and must lie inside ``image``.
    Returns:
        windowsize == 1: (N, bands) with ``image``'s dtype for 3-D input,
            or (N,) int32 for 2-D input.
        otherwise: (N, windowsize, windowsize[, bands]) with ``image``'s dtype.
    """
    n = location.shape[0]
    if windowsize == 1:
        dtype = image.dtype if image.ndim == 3 else np.int32
        spectral = np.zeros((n,) + image.shape[2:3], dtype=dtype)
        for i, loc in enumerate(location):
            spectral[i] = image[loc[0], loc[1]]
        return spectral

    halfsize = int((windowsize - 1) / 2)
    # image.shape[2:3] is (bands,) for a cube and () for a 2-D map, so one
    # code path handles both cases.
    spectral = np.zeros((n, windowsize, windowsize) + image.shape[2:3], dtype=image.dtype)
    for i, loc in enumerate(location):
        r, c = loc[0], loc[1]
        spectral[i] = image[r - halfsize: r + halfsize + 1, c - halfsize: c + halfsize + 1]
    return spectral


# --------------------------------------------------------------------------------
# /bpe_simple_vocab_16e6.txt.gz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/bpe_simple_vocab_16e6.txt.gz
# --------------------------------------------------------------------------------
# /build_EMP.py:
# --------------------------------------------------------------------------------
# Building the Extended Morphological Profiles (EMP)
from skimage.morphology import reconstruction
from skimage.morphology import erosion
from skimage.morphology import disk
from skimage import util
import numpy as np
import skimage.morphology as sm


def opening_by_reconstruction(image, se):
    """
    Performs an Opening by Reconstruction.

    Parameters:
        image: 2D matrix.
        se: structuring element
    Returns:
        2D matrix of the reconstructed image.
    """
    eroded = erosion(image, se)
    # reconstruction() uses dilation by default: grow the eroded marker back
    # under the original image as mask.
    reconstructed = reconstruction(eroded, image)
    return reconstructed


def closing_by_reconstruction(image, se):
    """
    Performs a Closing by Reconstruction.

    NOTE: the closing is applied to the opening-by-reconstruction of
    ``image`` (not to ``image`` itself), via the complement:
    invert -> erode -> reconstruct -> invert.

    Parameters:
        image: 2D matrix.
        se: structuring element
    Returns:
        2D matrix of the reconstructed image.
    """
    obr = opening_by_reconstruction(image, se)

    obr_inverted = util.invert(obr)
    obr_inverted_eroded = erosion(obr_inverted, se)
    obr_inverted_eroded_rec = reconstruction(
        obr_inverted_eroded, obr_inverted)
    obr_inverted_eroded_rec_inverted = util.invert(obr_inverted_eroded_rec)
    return obr_inverted_eroded_rec_inverted


def build_morphological_profiles(image, se_size=4, se_size_increment=2, num_openings_closings=4):
    """
    Build the morphological profile of a single 2D band.

    Parameters:
        image: 2d matrix, the spectral information part of the MP.
        se_size: int, initial radius of the disk structuring element.
        se_size_increment: int, radius increment between successive steps.
        num_openings_closings: int, number of openings and closings by
            reconstruction to perform.
    Returns:
        mp: float array of shape (rows, cols, 2 * num_openings_closings + 1).
            Channels are the closings in reverse order of computation (the
            largest disk first when se_size_increment > 0), then ``image``
            itself, then the openings in order of computation.
    """
    x, y = image.shape
    n = num_openings_closings
    mp = np.zeros(shape=(x, y, (n * 2) + 1))

    for k in range(n):
        se = disk(se_size + k * se_size_increment)
        # Closings are mirrored around the centre channel; openings follow it.
        mp[:, :, n - 1 - k] = closing_by_reconstruction(image, se)
        mp[:, :, n + 1 + k] = opening_by_reconstruction(image, se)

    mp[:, :, n] = image[:, :]
    return mp


def build_emp(base_image, se_size=4, se_size_increment=2, num_openings_closings=4):
    """
    Build the extended morphological profiles for a given set of images.

    Parameters:
        base_image: 3d matrix, each 'channel' is considered for applying the
            morphological profile. It is the spectral information part of the EMP.
        se_size: int, initial radius of the disk structuring element.
        se_size_increment: int, radius increment between successive steps.
        num_openings_closings: int, number of openings and closings by
            reconstruction to perform.
    Returns:
        emp: float array of shape (rows, cols, channels * (2 * num_openings_closings + 1));
            channel i's profile occupies a contiguous slice of the last axis.
    """
    base_image_rows, base_image_columns, base_image_channels = base_image.shape
    morphological_profile_size = (num_openings_closings * 2) + 1
    emp = np.zeros(
        shape=(base_image_rows, base_image_columns,
               morphological_profile_size * base_image_channels))

    for i in range(base_image_channels):
        start = morphological_profile_size * i
        emp[:, :, start:start + morphological_profile_size] = build_morphological_profiles(
            base_image[:, :, i], se_size, se_size_increment, num_openings_closings)

    return emp


def build_emp1(base_image, nScale=3):
    """
    Build a simpler EMP from plain (non-reconstruction) openings and closings.

    Parameters:
        base_image: 3d matrix indexed (row, col, channel).
        nScale: number of disk radii used: 3, 6, ..., 3 * nScale.
    Returns:
        emp: float array of shape (rows, cols, 2 * nScale * channels), ordered
            scale-major; for each (scale, channel) pair an opening then a closing.
    """
    row = base_image.shape[0]
    col = base_image.shape[1]
    nPC = base_image.shape[2]
    emp = np.zeros([row, col, 2 * nScale * nPC])
    i = 0
    for iScale in range(nScale):
        # The structuring element depends only on the scale; build it once.
        se = sm.disk(3 * (iScale + 1))
        for iPC in range(nPC):
            emp[:, :, i] = sm.opening(base_image[:, :, iPC], se)
            emp[:, :, i + 1] = sm.closing(base_image[:, :, iPC], se)
            i = i + 2
    return emp
# --------------------------------------------------------------------------------
# /classnames_houston.txt:
# --------------------------------------------------------------------------------
# 0 Healthy grass
# 1 Stressed grass
# 2 Synthetic grass
# 3 Tree
# 4 Soil
# 5 Water
# 6 Residential
# 7 Commercial
# 8 Road
# 9 Highway
# 10 Railway
# 11 Park lot 1
# 12 Park lot 2
# 13 Tennis court
# 14 Running track
# --------------------------------------------------------------------------------
# /clip_model.py:
# --------------------------------------------------------------------------------
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import math
from pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
from model import vit_HSI_LIDAR_patch3


class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with avg-pool downsampling.

    Every convolution runs at stride 1. When ``stride > 1``, an average pool
    follows the 3x3 convolution on the main path and precedes the 1x1
    projection on the shortcut path.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        widened = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, widened, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        shape_changes = stride > 1 or inplanes != planes * Bottleneck.expansion
        if shape_changes:
            # Shortcut projection: avg-pool first, then a stride-1 1x1 conv.
            shortcut_layers = [
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, widened, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(widened)),
            ]
            self.downsample = nn.Sequential(OrderedDict(shortcut_layers))

    def forward(self, x: torch.Tensor):
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.relu2(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(self.avgpool(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu3(branch)


class AttentionPool2d(nn.Module):
    """Pool an (N, C, H, W) feature map into one vector per sample.

    The spatial mean is prepended as an extra token. That token is the only
    query of a multi-head attention over all tokens, with a learned positional
    embedding added first.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        token_count = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(torch.randn(token_count, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (HW+1)NC

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1], key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return pooled.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A ResNet variant in the style of CLIP's image encoder:
    - a 3-convolution stem followed by an average pool (instead of one conv and
      a max pool); here the first stem convolution takes ``width`` input
      channels rather than RGB;
    - anti-aliased downsampling: convolutions with stride > 1 are replaced by
      an avgpool followed by a stride-1 convolution;
    - an attention pool (AttentionPool2d) as the final pooling layer.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        half_width = width // 2

        # stem: three 3x3 convolutions, only the first one strided
        self.conv1 = nn.Conv2d(width, half_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(half_width, half_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual stages; _inplanes is updated by _make_layer while building
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # channels leaving layer4 (width * 8 * expansion)
        self.attnpool = AttentionPool2d(input_resolution // 32, feature_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) block plus ``blocks - 1`` plain blocks."""
        head_block = Bottleneck(self._inplanes, planes, stride)
        self._inplanes = planes * Bottleneck.expansion
        tail_blocks = [Bottleneck(self._inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(head_block, *tail_blocks)

    def forward(self, x):
        x = x.type(self.conv1.weight.dtype)

        stem_units = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem_units:
            x = act(norm(conv(x)))
        x = self.avgpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)


class LayerNorm(nn.LayerNorm):
    """LayerNorm evaluated in float32 and cast back to the input dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = super().forward(x.type(torch.float32))
        return normed.type(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).

    Inputs are laid out sequence-first (L, N, D), the default layout of
    nn.MultiheadAttention. ``attn_mask`` (optional) is an additive mask that
    is moved to the input's dtype/device on every call.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden = d_model * 4
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        outputs = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return outputs[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))


class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks sharing one attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
205 | 206 | 207 | class cnn_encoder_A_and_B(nn.Module): 208 | def __init__(self, l1, l2,vision_patch_size,out_channels): 209 | super(cnn_encoder_A_and_B, self).__init__() 210 | 211 | self.conv1 = conv_bn_relu(l1, 32, 3, 1, 1) 212 | self.conv2 = conv_bn_relu(l2, 32, 3, 1, 1) 213 | self.conv1_1 = conv_bn_relu_max(32, out_channels, 3, 1, 1, vision_patch_size) 214 | self.conv2_1 = conv_bn_relu_max(32, out_channels, 3, 1, 1, vision_patch_size) 215 | 216 | 217 | self.xishu1 = torch.nn.Parameter(torch.Tensor([0.5])) # lamda 218 | self.xishu2 = torch.nn.Parameter(torch.Tensor([0.5])) # 1 - lamda 219 | 220 | 221 | def forward(self, x11, x21): 222 | x11 = self.conv1(x11) 223 | x21 = self.conv2(x21) 224 | x1_1 = self.conv1_1(x11) #64,64,8,8 225 | x2_1 = self.conv2_1(x21) 226 | x_add = x1_1*self.xishu1 + x2_1*self.xishu2 227 | 228 | return x_add,x1_1,x2_1 229 | 230 | class VisionTransformer(nn.Module): 231 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 232 | super().__init__() 233 | self.input_resolution = input_resolution 234 | self.output_dim = output_dim 235 | 236 | scale = width ** -0.5 237 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 238 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 239 | self.ln_pre = LayerNorm(width) 240 | 241 | self.transformer = Transformer(width, layers, heads) 242 | 243 | self.ln_post = LayerNorm(width) 244 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 245 | 246 | def forward(self, x: torch.Tensor,csa=True): 247 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 248 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 249 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 250 | x = x + self.positional_embedding.to(x.dtype) 251 
| x = self.ln_pre(x) 252 | 253 | x = x.permute(1, 0, 2) # NLD -> LND 254 | x = self.transformer(x) 255 | 256 | x = x.permute(1, 0, 2) # LND -> NLD 257 | x_vit = x 258 | 259 | 260 | x = self.ln_post(x[:, 0, :]) 261 | 262 | if self.proj is not None: 263 | x = x @ self.proj 264 | 265 | return x,x_vit 266 | 267 | 268 | 269 | class CLIP(nn.Module): 270 | def __init__(self,img_size, 271 | in_chans, in_chans_LIDAR,hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool, 272 | # text 273 | context_length: int, 274 | vocab_size: int, 275 | transformer_width: int, 276 | transformer_heads: int, 277 | transformer_layers: int 278 | ): 279 | super().__init__() 280 | 281 | self.context_length = context_length 282 | 283 | self.visual = vit_HSI_LIDAR_patch3(img_size=img_size, 284 | in_chans=in_chans, in_chans_LIDAR=in_chans_LIDAR,hid_chans=hid_chans, 285 | hid_chans_LIDAR=hid_chans_LIDAR, embed_dim=embed_dim, depth=depth, num_heads=num_heads, 286 | mlp_ratio=mlp_ratio,num_classes=num_classes, global_pool=global_pool).cuda() 287 | 288 | checkpoint = torch.load('./net.pt') 289 | checkpoint_model = checkpoint['model'] 290 | 291 | state_dict = self.visual.state_dict() 292 | for k in ['head.weight', 'head.bias','head_LIDAR.weight', 'head_LIDAR.bias']: 293 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 294 | print(f"Removing key {k} from pretrained checkpoint") 295 | del checkpoint_model[k] 296 | 297 | interpolate_pos_embed(self.visual, checkpoint_model) 298 | msg = self.visual.load_state_dict(checkpoint_model, strict=False) 299 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 300 | 301 | # manually initialize fc layer 302 | trunc_normal_(self.visual.head.weight, std=2e-5) 303 | 304 | self.transformer = Transformer( 305 | width=transformer_width, 306 | layers=transformer_layers, 307 | heads=transformer_heads, 308 | attn_mask=self.build_attention_mask() 309 | ) 310 | 311 | self.vocab_size = vocab_size 312 
| self.token_embedding = nn.Embedding(vocab_size, transformer_width) 313 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 314 | self.ln_final = LayerNorm(transformer_width) 315 | 316 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 317 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.7)) 318 | 319 | self.initialize_parameters() 320 | 321 | def initialize_parameters(self): 322 | nn.init.normal_(self.token_embedding.weight, std=0.02) 323 | nn.init.normal_(self.positional_embedding, std=0.01) 324 | 325 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 326 | attn_std = self.transformer.width ** -0.5 327 | fc_std = (2 * self.transformer.width) ** -0.5 328 | for block in self.transformer.resblocks: 329 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 330 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 331 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 332 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 333 | for block in self.transformer.resblocks[-1:]: 334 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 335 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 336 | 337 | if self.text_projection is not None: 338 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 339 | 340 | 341 | def build_attention_mask(self): 342 | # lazily create causal attention mask, with full attention between the vision tokens 343 | # pytorch uses additive attention mask; fill with -inf 344 | mask = torch.empty(self.context_length, self.context_length) 345 | mask.fill_(float("-inf")) 346 | mask.triu_(1) # zero out the lower diagonal 347 | return mask 348 | 349 | @property 350 | def dtype(self): 351 | return self.visual.dimen_redu[0].weight.dtype 352 | 353 | 354 | def encode_text(self, text): 355 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 356 | 357 | x = x + 
self.positional_embedding.type(self.dtype) 358 | x = x.permute(1, 0, 2) # NLD -> LND 359 | x = self.transformer(x) 360 | x = x.permute(1, 0, 2) # LND -> NLD 361 | x = self.ln_final(x).type(self.dtype) 362 | 363 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 364 | 365 | return x 366 | 367 | def forward(self, image,image_li,t, text): 368 | text_features = self.encode_text(text) 369 | image_features_x = self.visual(image,image_li,t) 370 | image_features_x = image_features_x / image_features_x.norm(dim=1, keepdim=True) 371 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 372 | 373 | 374 | # cosine similarity as logits 375 | logit_scale = self.logit_scale.exp() 376 | logits_per_image_x= logit_scale * image_features_x @ text_features.t() 377 | 378 | logits_per_text = logits_per_image_x.t() 379 | 380 | return logits_per_image_x,logits_per_text 381 | 382 | 383 | def convert_weights(model: nn.Module): 384 | """Convert applicable model parameters to fp16""" 385 | 386 | def _convert_weights_to_fp16(l): 387 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 388 | l.weight.data = l.weight.data.half() 389 | if l.bias is not None: 390 | l.bias.data = l.bias.data.half() 391 | 392 | if isinstance(l, nn.MultiheadAttention): 393 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 394 | tensor = getattr(l, attr) 395 | if tensor is not None: 396 | tensor.data = tensor.data.half() 397 | 398 | for name in ["text_projection", "proj"]: 399 | if hasattr(l, name): 400 | attr = getattr(l, name) 401 | if attr is not None: 402 | attr.data = attr.data.half() 403 | 404 | model.apply(_convert_weights_to_fp16) 405 | 406 | 407 | def build_model(img_size, in_chans, in_chans_LIDAR, hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool=False): 408 | 409 | context_length = 77 410 | vocab_size = 49408 411 | transformer_width = 128 412 | transformer_heads 
= 8 413 | transformer_layers = 3 414 | model = CLIP(img_size,in_chans, in_chans_LIDAR,hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool, 415 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 416 | ) 417 | return model 418 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn 3 | 4 | def load_args(): 5 | parser = argparse.ArgumentParser() 6 | 7 | # Pre training 8 | parser.add_argument('--train_num_perclass', type=int, default=40) 9 | # Pre training 10 | 11 | parser.add_argument('--windowsize', type=int, default=11) 12 | parser.add_argument('--batch_size', type=int, default=128) 13 | parser.add_argument('--epochs', type=int, default=400) 14 | parser.add_argument('--fine_tuned_epochs', type=int, default=150) 15 | parser.add_argument('--lr', type=float, default=1e-4) 16 | parser.add_argument('--fine_tuned_lr', type=float, default=1e-4) 17 | 18 | parser.add_argument('--type', type=str, default='none') 19 | parser.add_argument('--dataset', type=str, default='2013houston') #Trento 2013houston Muufl Berlin 20 | 21 | # Network 22 | parser.add_argument('--mask_ratio', default=0.7, type=float, 23 | help='mask_ratio (default: 0.8)') 24 | parser.add_argument('--mlp_ratio', default=2.0, type=float, 25 | help='mlp ratio of encoder/decoder (default: 2.0)') 26 | parser.add_argument('--hid_chans', default=128, type=int, 27 | help='hidden channels for dimension reduction (default: 128)') 28 | parser.add_argument('--hid_chans_LIDAR', default=128, type=int, 29 | help='hidden channels for dimension reduction (default: 128)') 30 | # Augmentation 31 | parser.add_argument('--augment', default=True, type=bool, 32 | help='either use data augmentation or not (default: False)') 33 | parser.add_argument('--scale', default=9, type=int, 34 | help='the minimum 
scale for center crop (default: 19)') 35 | 36 | # MAE encoder specifics 37 | parser.add_argument('--encoder_dim', default=128, type=int, 38 | help='feature dimension for encoder (default: 64)') 39 | parser.add_argument('--encoder_depth', default = 4, type=int, 40 | help='encoder_depth; number of blocks ') 41 | parser.add_argument('--encoder_num_heads', default=8, type=int, 42 | help='number of heads of encoder (default: 8)') 43 | 44 | # MAE decoder specifics 45 | parser.add_argument('--decoder_dim', default= 128, type=int, 46 | help='feature dimension for decoder (default: 64)') 47 | parser.add_argument('--decoder_depth', default = 3, type=int, 48 | help='decoder_depth; number of blocks ') 49 | parser.add_argument('--decoder_num_heads', default=8, type=int, 50 | help='number of heads of decoder (default: 8)') 51 | 52 | # options for supervised MAE 53 | parser.add_argument('--temperature', default=1.0, type=float, 54 | help='temperature for classification logits') 55 | parser.add_argument('--cls_loss_ratio', default=0.005, type=float, 56 | help='ratio for classification loss') 57 | args = parser.parse_args() 58 | return args 59 | -------------------------------------------------------------------------------- /data_read.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | from sklearn.decomposition import PCA 4 | from build_EMP import build_emp 5 | 6 | def pca_whitening(image, number_of_pc): 7 | shape = image.shape 8 | image = np.reshape(image, [shape[0]*shape[1], shape[2]]) 9 | number_of_rows = shape[0] 10 | number_of_columns = shape[1] 11 | pca = PCA(n_components = number_of_pc) 12 | image = pca.fit_transform(image) 13 | pc_images = np.zeros(shape=(number_of_rows, number_of_columns, number_of_pc),dtype=np.float32) 14 | for i in range(number_of_pc): 15 | pc_images[:, :, i] = np.reshape(image[:, i], (number_of_rows, number_of_columns)) 16 | return pc_images 17 | 18 | def 
load_data(dataset): 19 | if dataset == 'Trento': 20 | image_file_HSI = r'Trento/HSI.mat' 21 | image_file_LiDAR = r'Trento/LiDAR.mat' 22 | label_file_tr = r'Trento/TRLabel.mat' 23 | label_file_ts = r'Trento/TSLabel.mat' 24 | image_data_HSI = sio.loadmat(image_file_HSI) 25 | image_data_LiDAR = sio.loadmat(image_file_LiDAR) 26 | label_data_tr = sio.loadmat(label_file_tr) 27 | label_data_ts = sio.loadmat(label_file_ts) 28 | image_HSI = image_data_HSI['HSI'] 29 | image_LiDAR = image_data_LiDAR['LiDAR'] 30 | label = label_data_tr['TRLabel']+label_data_ts['TSLabel'] 31 | elif dataset == '2013houston': 32 | image_file_HSI = r'./Houston2013/HSI.mat' 33 | image_file_LiDAR = r'./Houston2013/LiDAR.mat' 34 | label_file_tr = r'./Houston2013/TRLabel.mat' 35 | label_file_ts = r'./Houston2013/TSLabel.mat' 36 | image_data_HSI = sio.loadmat(image_file_HSI) 37 | image_data_LiDAR = sio.loadmat(image_file_LiDAR) 38 | label_data_tr = sio.loadmat(label_file_tr) 39 | label_data_ts = sio.loadmat(label_file_ts) 40 | image_HSI = image_data_HSI['HSI'] 41 | image_LiDAR = image_data_LiDAR['LiDAR'] 42 | label = label_data_tr['TRLabel']+label_data_ts['TSLabel'] 43 | elif dataset == 'Muufl': 44 | image_file_HSI = r'Muufl/hsi.mat' 45 | image_file_LiDAR = r'Muufl/lidar.mat' 46 | label_file = r'Muufl/train_test_gt.mat' 47 | image_data_HSI = sio.loadmat(image_file_HSI) 48 | image_data_LiDAR = sio.loadmat(image_file_LiDAR) 49 | label_data = sio.loadmat(label_file) 50 | image_HSI = image_data_HSI['HSI'] 51 | image_LiDAR = image_data_LiDAR['lidar'] 52 | label = label_data['trainlabels']+label_data['testlabels'] 53 | else: 54 | raise Exception('dataset does not find') 55 | image_HSI = image_HSI.astype(np.float32) 56 | image_LiDAR = image_LiDAR.astype(np.float32) 57 | label = label.astype(np.int64) 58 | return image_HSI, image_LiDAR, label 59 | 60 | 61 | def readdata(type, dataset, windowsize, train_num, val_num, num): 62 | 63 | or_image_HSI, or_image_LiDAR, or_label = load_data(dataset) 64 | # image = 
np.expand_dims(image, 2) 65 | halfsize = int((windowsize-1)/2) 66 | number_class = np.max(or_label).astype(np.int64) 67 | if dataset == 'Augsburg_SAR': 68 | pass 69 | elif dataset == 'Berlin' or dataset == 'Muufl': 70 | pass 71 | else: 72 | or_image_LiDAR = np.expand_dims(or_image_LiDAR, 2) 73 | # or_image_LiDAR = np.expand_dims(or_image_LiDAR, 2) 74 | image = np.pad(or_image_HSI, ((halfsize, halfsize), (halfsize, halfsize), (0, 0)), 'edge') 75 | image_LiDAR = np.pad(or_image_LiDAR, ((halfsize, halfsize), (halfsize, halfsize), (0, 0)), 'edge') 76 | label = np.pad(or_label, ((halfsize, halfsize), (halfsize, halfsize)), 'constant',constant_values=0) 77 | 78 | if type == 'PCA': 79 | image1 = pca_whitening(image, number_of_pc = 30) 80 | image_LiDAR1 = np.copy(image_LiDAR) 81 | elif type == 'EMP': 82 | image1 = pca_whitening(image, number_of_pc = 4) 83 | num_openings_closings = 3 84 | emp_image = build_emp(base_image=image1, num_openings_closings=num_openings_closings) 85 | image1 = emp_image 86 | elif type == 'none': 87 | image1 = np.copy(image) 88 | image_LiDAR1 = np.copy(image_LiDAR) 89 | else: 90 | raise Exception('type does not find') 91 | image = (image1 - np.min(image1)) / (np.max(image1) - np.min(image1)) 92 | image_LiDAR = (image_LiDAR1 - np.min(image_LiDAR1)) / (np.max(image_LiDAR1) - np.min(image_LiDAR1)) 93 | #set the manner of selecting training samples 94 | 95 | 96 | n = np.zeros(number_class,dtype=np.int64) 97 | for i in range(number_class): 98 | temprow, tempcol = np.where(label == i + 1) 99 | n[i] = len(temprow) 100 | total_num = np.sum(n) 101 | 102 | nTrain_perClass = np.ones(number_class,dtype=np.int64) * train_num 103 | for i in range(number_class): 104 | if n[i] <= nTrain_perClass[i]: 105 | nTrain_perClass[i] = 15 106 | ###验证机数目 107 | nValidation_perClass = (n/total_num)*val_num 108 | nvalid_perClass = nValidation_perClass.astype(np.int32) 109 | 110 | index = [] 111 | flag = 0 112 | fl = 0 113 | 114 | 115 | bands = np.size(image,2) 116 | bands_LIDAR 
= np.size(image_LiDAR,2) 117 | validation_image = np.zeros([np.sum(nvalid_perClass), windowsize, windowsize, bands], dtype=np.float32) 118 | validation_image_LIDAR = np.zeros([np.sum(nvalid_perClass), windowsize, windowsize, bands_LIDAR], dtype=np.float32) 119 | validation_label = np.zeros(np.sum(nvalid_perClass), dtype=np.int64) 120 | train_image = np.zeros([np.sum(nTrain_perClass), windowsize, windowsize, bands], dtype=np.float32) 121 | train_image_LIDAR = np.zeros([np.sum(nTrain_perClass), windowsize, windowsize, bands_LIDAR], dtype=np.float32) 122 | train_label = np.zeros(np.sum(nTrain_perClass),dtype=np.int64) 123 | train_index = np.zeros([np.sum(nTrain_perClass), 2], dtype = np.int32) 124 | val_index = np.zeros([np.sum(nvalid_perClass), 2], dtype = np.int32) 125 | 126 | for i in range(number_class): 127 | temprow, tempcol = np.where(label == i + 1) 128 | matrix = np.zeros([len(temprow),2], dtype=np.int64) 129 | matrix[:,0] = temprow 130 | matrix[:,1] = tempcol 131 | np.random.seed(num) 132 | np.random.shuffle(matrix) 133 | 134 | temprow = matrix[:,0] 135 | tempcol = matrix[:,1] 136 | index.append(matrix) 137 | 138 | for j in range(nTrain_perClass[i]): 139 | train_image[flag + j, :, :, :] = image[(temprow[j] - halfsize):(temprow[j] + halfsize + 1), 140 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)] 141 | train_image_LIDAR[flag + j, :, :, :] = image_LiDAR[(temprow[j] - halfsize):(temprow[j] + halfsize + 1), 142 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)] 143 | train_label[flag + j] = i 144 | train_index[flag + j] = matrix[j,:] 145 | flag = flag + nTrain_perClass[i] 146 | 147 | for j in range(nTrain_perClass[i], nTrain_perClass[i] + nvalid_perClass[i]): 148 | validation_image[fl + j-nTrain_perClass[i], :, :,:] = image[(temprow[j] - halfsize):(temprow[j] + halfsize + 1), 149 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)] 150 | validation_image_LIDAR[fl + j-nTrain_perClass[i], :, :,:] = image_LiDAR[(temprow[j] - halfsize):(temprow[j] + 
halfsize + 1), 151 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)] 152 | validation_label[fl + j-nTrain_perClass[i] ] = i 153 | val_index[fl + j-nTrain_perClass[i]] = matrix[j,:] 154 | fl =fl + nvalid_perClass[i] 155 | 156 | 157 | return train_image, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR, validation_label,\ 158 | nTrain_perClass, nvalid_perClass,train_index, val_index, index, image, image_LiDAR, label,total_num 159 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=False, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.START_X if not predict_xstart else gd.ModelMeanType.EPSILON 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | 
if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | ) 46 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 
43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch as th 12 | import enum 13 | 14 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 15 | 16 | 17 | def mean_flat(tensor): 18 | """ 19 | Take the mean over all non-batch dimensions. 20 | """ 21 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 22 | 23 | 24 | class ModelMeanType(enum.Enum): 25 | """ 26 | Which type of output the model predicts. 27 | """ 28 | 29 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 30 | START_X = enum.auto() # the model predicts x_0 31 | EPSILON = enum.auto() # the model predicts epsilon 32 | 33 | 34 | class ModelVarType(enum.Enum): 35 | """ 36 | What is used as the model's output variance. 
37 | The LEARNED_RANGE option has been added to allow the model to predict 38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 39 | """ 40 | 41 | LEARNED = enum.auto() 42 | FIXED_SMALL = enum.auto() 43 | FIXED_LARGE = enum.auto() 44 | LEARNED_RANGE = enum.auto() 45 | 46 | 47 | class LossType(enum.Enum): 48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 49 | RESCALED_MSE = ( 50 | enum.auto() 51 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 52 | KL = enum.auto() # use the variational lower-bound 53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 54 | 55 | def is_vb(self): 56 | return self == LossType.KL or self == LossType.RESCALED_KL 57 | 58 | 59 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 60 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 61 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 62 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 63 | return betas 64 | 65 | 66 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 67 | """ 68 | This is the deprecated API for creating beta schedules. 69 | See get_named_beta_schedule() for the new library of schedules. 
70 | """ 71 | if beta_schedule == "quad": 72 | betas = ( 73 | np.linspace( 74 | beta_start ** 0.5, 75 | beta_end ** 0.5, 76 | num_diffusion_timesteps, 77 | dtype=np.float64, 78 | ) 79 | ** 2 80 | ) 81 | elif beta_schedule == "linear": 82 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 83 | elif beta_schedule == "warmup10": 84 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 85 | elif beta_schedule == "warmup50": 86 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 87 | elif beta_schedule == "const": 88 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 89 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 90 | betas = 1.0 / np.linspace( 91 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 92 | ) 93 | else: 94 | raise NotImplementedError(beta_schedule) 95 | assert betas.shape == (num_diffusion_timesteps,) 96 | return betas 97 | 98 | 99 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 100 | """ 101 | Get a pre-defined beta schedule for the given name. 102 | The beta schedule library consists of beta schedules which remain similar 103 | in the limit of num_diffusion_timesteps. 104 | Beta schedules may be added, but should not be removed or changed once 105 | they are committed to maintain backwards compatibility. 106 | """ 107 | if schedule_name == "linear": 108 | # Linear schedule from Ho et al, extended to work for any number of 109 | # diffusion steps. 
110 | scale = 1000 / num_diffusion_timesteps 111 | return get_beta_schedule( 112 | "linear", 113 | beta_start=scale * 0.0001, 114 | beta_end=scale * 0.02, 115 | num_diffusion_timesteps=num_diffusion_timesteps, 116 | ) 117 | elif schedule_name == "squaredcos_cap_v2": 118 | return betas_for_alpha_bar( 119 | num_diffusion_timesteps, 120 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 121 | ) 122 | else: 123 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 124 | 125 | 126 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 127 | """ 128 | Create a beta schedule that discretizes the given alpha_t_bar function, 129 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 130 | :param num_diffusion_timesteps: the number of betas to produce. 131 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 132 | produces the cumulative product of (1-beta) up to that 133 | part of the diffusion process. 134 | :param max_beta: the maximum beta to use; use values lower than 1 to 135 | prevent singularities. 136 | """ 137 | betas = [] 138 | for i in range(num_diffusion_timesteps): 139 | t1 = i / num_diffusion_timesteps 140 | t2 = (i + 1) / num_diffusion_timesteps 141 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 142 | return np.array(betas) 143 | 144 | 145 | class GaussianDiffusion: 146 | """ 147 | Utilities for training and sampling diffusion models. 148 | Original ported from this codebase: 149 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 150 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 151 | starting at T and going to 1. 
152 | """ 153 | 154 | def __init__( 155 | self, 156 | *, 157 | betas, 158 | model_mean_type, 159 | model_var_type, 160 | loss_type 161 | ): 162 | 163 | self.model_mean_type = model_mean_type 164 | self.model_var_type = model_var_type 165 | self.loss_type = loss_type 166 | 167 | # Use float64 for accuracy. 168 | betas = np.array(betas, dtype=np.float64) 169 | self.betas = betas 170 | assert len(betas.shape) == 1, "betas must be 1-D" 171 | assert (betas > 0).all() and (betas <= 1).all() 172 | 173 | self.num_timesteps = int(betas.shape[0]) 174 | 175 | alphas = 1.0 - betas 176 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 177 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 178 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 179 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 180 | 181 | # calculations for diffusion q(x_t | x_{t-1}) and others 182 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 183 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 184 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 185 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 186 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 187 | 188 | # calculations for posterior q(x_{t-1} | x_t, x_0) 189 | self.posterior_variance = ( 190 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 191 | ) 192 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 193 | self.posterior_log_variance_clipped = np.log( 194 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 195 | ) if len(self.posterior_variance) > 1 else np.array([]) 196 | 197 | self.posterior_mean_coef1 = ( 198 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 199 | ) 200 | self.posterior_mean_coef2 = ( 201 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - 
self.alphas_cumprod) 202 | ) 203 | 204 | def q_mean_variance(self, x_start, t): 205 | """ 206 | Get the distribution q(x_t | x_0). 207 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 208 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 209 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 210 | """ 211 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 212 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 213 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 214 | return mean, variance, log_variance 215 | 216 | def q_sample(self, x_start, t, noise=None): 217 | """ 218 | Diffuse the data for a given number of diffusion steps. 219 | In other words, sample from q(x_t | x_0). 220 | :param x_start: the initial data batch. 221 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 222 | :param noise: if specified, the split-out normal noise. 223 | :return: A noisy version of x_start. 
224 | """ 225 | if noise is None: 226 | noise = th.randn_like(x_start) 227 | assert noise.shape == x_start.shape 228 | return ( 229 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 230 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 231 | ) 232 | 233 | def q_posterior_mean_variance(self, x_start, x_t, t): 234 | """ 235 | Compute the mean and variance of the diffusion posterior: 236 | q(x_{t-1} | x_t, x_0) 237 | """ 238 | assert x_start.shape == x_t.shape 239 | posterior_mean = ( 240 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 241 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 242 | ) 243 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 244 | posterior_log_variance_clipped = _extract_into_tensor( 245 | self.posterior_log_variance_clipped, t, x_t.shape 246 | ) 247 | assert ( 248 | posterior_mean.shape[0] 249 | == posterior_variance.shape[0] 250 | == posterior_log_variance_clipped.shape[0] 251 | == x_start.shape[0] 252 | ) 253 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 254 | 255 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 256 | """ 257 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 258 | the initial x, x_0. 259 | :param model: the model, which takes a signal and a batch of timesteps 260 | as input. 261 | :param x: the [N x C x ...] tensor at time t. 262 | :param t: a 1-D Tensor of timesteps. 263 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 264 | :param denoised_fn: if not None, a function which applies to the 265 | x_start prediction before it is used to sample. Applies before 266 | clip_denoised. 267 | :param model_kwargs: if not None, a dict of extra keyword arguments to 268 | pass to the model. This can be used for conditioning. 
269 | :return: a dict with the following keys: 270 | - 'mean': the model mean output. 271 | - 'variance': the model variance output. 272 | - 'log_variance': the log of 'variance'. 273 | - 'pred_xstart': the prediction for x_0. 274 | """ 275 | if model_kwargs is None: 276 | model_kwargs = {} 277 | 278 | B, C = x.shape[:2] 279 | assert t.shape == (B,) 280 | loss, model_output, model_output_LIDAR, logits, ids_restore, ids_restore_LIDAR, Cross_rec_loss = model(x, x_LIDAR, t, **model_kwargs) 281 | if isinstance(model_output, tuple): 282 | model_output, extra = model_output 283 | else: 284 | extra = None 285 | 286 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 287 | assert model_output.shape == (B, C *2, *x.shape[2:]) 288 | model_output, model_var_values = th.split(model_output, C, dim=1) 289 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 290 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 291 | # The model_var_values is [-1, 1] for [min_var, max_var]. 292 | frac = (model_var_values + 1) / 2 293 | model_log_variance = frac * max_log + (1 - frac) * min_log 294 | model_variance = th.exp(model_log_variance) 295 | 296 | model_output_LIDAR, model_var_values_LIDAR = th.split(model_output_LIDAR, 1, dim=1) 297 | min_log_LIDAR = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_LIDAR.shape) 298 | max_log_LIDAR = _extract_into_tensor(np.log(self.betas), t, x_LIDAR.shape) 299 | # The model_var_values is [-1, 1] for [min_var, max_var]. 300 | frac_LIDAR = (model_var_values_LIDAR + 1) / 2 301 | model_log_variance_LIDAR = frac_LIDAR * max_log_LIDAR + (1 - frac_LIDAR) * min_log_LIDAR 302 | model_variance_LIDAR = th.exp(model_log_variance_LIDAR) 303 | else: 304 | model_variance, model_log_variance = { 305 | # for fixedlarge, we set the initial (log-)variance like so 306 | # to get a better decoder log likelihood. 
307 | ModelVarType.FIXED_LARGE: ( 308 | np.append(self.posterior_variance[1], self.betas[1:]), 309 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 310 | ), 311 | ModelVarType.FIXED_SMALL: ( 312 | self.posterior_variance, 313 | self.posterior_log_variance_clipped, 314 | ), 315 | }[self.model_var_type] 316 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 317 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 318 | 319 | def process_xstart(x): 320 | if denoised_fn is not None: 321 | x = denoised_fn(x) 322 | if clip_denoised: 323 | return x.clamp(-1, 1) 324 | return x 325 | 326 | if self.model_mean_type == ModelMeanType.START_X: 327 | pred_xstart = process_xstart(model_output) 328 | pred_xstart_LIDAR = process_xstart(model_output_LIDAR) 329 | else: 330 | pred_xstart = process_xstart( 331 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 332 | ) 333 | pred_xstart_LIDAR = process_xstart( 334 | self._predict_xstart_from_eps(x_t=x_LIDAR, t=t, eps=model_output_LIDAR) 335 | ) 336 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 337 | model_mean_LIDAR, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart_LIDAR, x_t=x_LIDAR, t=t) 338 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 339 | return { 340 | "mean": model_mean, 341 | "variance": model_variance, 342 | "log_variance": model_log_variance, 343 | "pred_xstart": pred_xstart, 344 | "extra": extra, 345 | } 346 | 347 | def _predict_xstart_from_eps(self, x_t, t, eps): 348 | assert x_t.shape == eps.shape 349 | return ( 350 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 351 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 352 | ) 353 | 354 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 355 | return ( 356 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 357 | ) / 
_extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 358 | 359 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 360 | """ 361 | Compute the mean for the previous step, given a function cond_fn that 362 | computes the gradient of a conditional log probability with respect to 363 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 364 | condition on y. 365 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 366 | """ 367 | gradient = cond_fn(x, t, **model_kwargs) 368 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 369 | return new_mean 370 | 371 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 372 | """ 373 | Compute what the p_mean_variance output would have been, should the 374 | model's score function be conditioned by cond_fn. 375 | See condition_mean() for details on cond_fn. 376 | Unlike condition_mean(), this instead uses the conditioning strategy 377 | from Song et al (2020). 378 | """ 379 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 380 | 381 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 382 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 383 | 384 | out = p_mean_var.copy() 385 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 386 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 387 | return out 388 | 389 | def p_sample( 390 | self, 391 | model, 392 | x, 393 | t, 394 | clip_denoised=True, 395 | denoised_fn=None, 396 | cond_fn=None, 397 | model_kwargs=None, 398 | ): 399 | """ 400 | Sample x_{t-1} from the model at the given timestep. 401 | :param model: the model to sample from. 402 | :param x: the current tensor at x_{t-1}. 403 | :param t: the value of t, starting at 0 for the first diffusion step. 404 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 
405 | :param denoised_fn: if not None, a function which applies to the 406 | x_start prediction before it is used to sample. 407 | :param cond_fn: if not None, this is a gradient function that acts 408 | similarly to the model. 409 | :param model_kwargs: if not None, a dict of extra keyword arguments to 410 | pass to the model. This can be used for conditioning. 411 | :return: a dict containing the following keys: 412 | - 'sample': a random sample from the model. 413 | - 'pred_xstart': a prediction of x_0. 414 | """ 415 | out = self.p_mean_variance( 416 | model, 417 | x, 418 | t, 419 | clip_denoised=clip_denoised, 420 | denoised_fn=denoised_fn, 421 | model_kwargs=model_kwargs, 422 | ) 423 | noise = th.randn_like(x) 424 | nonzero_mask = ( 425 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 426 | ) # no noise when t == 0 427 | if cond_fn is not None: 428 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 429 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 430 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 431 | 432 | def p_sample_loop( 433 | self, 434 | model, 435 | shape, 436 | noise=None, 437 | clip_denoised=True, 438 | denoised_fn=None, 439 | cond_fn=None, 440 | model_kwargs=None, 441 | device=None, 442 | progress=False, 443 | ): 444 | """ 445 | Generate samples from the model. 446 | :param model: the model module. 447 | :param shape: the shape of the samples, (N, C, H, W). 448 | :param noise: if specified, the noise from the encoder to sample. 449 | Should be of the same shape as `shape`. 450 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 451 | :param denoised_fn: if not None, a function which applies to the 452 | x_start prediction before it is used to sample. 453 | :param cond_fn: if not None, this is a gradient function that acts 454 | similarly to the model. 
455 | :param model_kwargs: if not None, a dict of extra keyword arguments to 456 | pass to the model. This can be used for conditioning. 457 | :param device: if specified, the device to create the samples on. 458 | If not specified, use a model parameter's device. 459 | :param progress: if True, show a tqdm progress bar. 460 | :return: a non-differentiable batch of samples. 461 | """ 462 | final = None 463 | for sample in self.p_sample_loop_progressive( 464 | model, 465 | shape, 466 | noise=noise, 467 | clip_denoised=clip_denoised, 468 | denoised_fn=denoised_fn, 469 | cond_fn=cond_fn, 470 | model_kwargs=model_kwargs, 471 | device=device, 472 | progress=progress, 473 | ): 474 | final = sample 475 | return final["sample"] 476 | 477 | def p_sample_loop_progressive( 478 | self, 479 | model, 480 | shape, 481 | noise=None, 482 | clip_denoised=True, 483 | denoised_fn=None, 484 | cond_fn=None, 485 | model_kwargs=None, 486 | device=None, 487 | progress=False, 488 | ): 489 | """ 490 | Generate samples from the model and yield intermediate samples from 491 | each timestep of diffusion. 492 | Arguments are the same as p_sample_loop(). 493 | Returns a generator over dicts, where each dict is the return value of 494 | p_sample(). 495 | """ 496 | if device is None: 497 | device = next(model.parameters()).device 498 | assert isinstance(shape, (tuple, list)) 499 | if noise is not None: 500 | img = noise 501 | else: 502 | img = th.randn(*shape, device=device) 503 | indices = list(range(self.num_timesteps))[::-1] 504 | 505 | if progress: 506 | # Lazy import so that we don't depend on tqdm. 
507 | from tqdm.auto import tqdm 508 | 509 | indices = tqdm(indices) 510 | 511 | for i in indices: 512 | t = th.tensor([i] * shape[0], device=device) 513 | with th.no_grad(): 514 | out = self.p_sample( 515 | model, 516 | img, 517 | t, 518 | clip_denoised=clip_denoised, 519 | denoised_fn=denoised_fn, 520 | cond_fn=cond_fn, 521 | model_kwargs=model_kwargs, 522 | ) 523 | yield out 524 | img = out["sample"] 525 | 526 | def ddim_sample( 527 | self, 528 | model, 529 | x, 530 | t, 531 | clip_denoised=True, 532 | denoised_fn=None, 533 | cond_fn=None, 534 | model_kwargs=None, 535 | eta=0.0, 536 | ): 537 | """ 538 | Sample x_{t-1} from the model using DDIM. 539 | Same usage as p_sample(). 540 | """ 541 | out = self.p_mean_variance( 542 | model, 543 | x, 544 | t, 545 | clip_denoised=clip_denoised, 546 | denoised_fn=denoised_fn, 547 | model_kwargs=model_kwargs, 548 | ) 549 | if cond_fn is not None: 550 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 551 | 552 | # Usually our model outputs epsilon, but we re-derive it 553 | # in case we used x_start or x_prev prediction. 554 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 555 | 556 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 557 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 558 | sigma = ( 559 | eta 560 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 561 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 562 | ) 563 | # Equation 12. 
564 | noise = th.randn_like(x) 565 | mean_pred = ( 566 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 567 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 568 | ) 569 | nonzero_mask = ( 570 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 571 | ) # no noise when t == 0 572 | sample = mean_pred + nonzero_mask * sigma * noise 573 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 574 | 575 | def ddim_reverse_sample( 576 | self, 577 | model, 578 | x, 579 | t, 580 | clip_denoised=True, 581 | denoised_fn=None, 582 | cond_fn=None, 583 | model_kwargs=None, 584 | eta=0.0, 585 | ): 586 | """ 587 | Sample x_{t+1} from the model using DDIM reverse ODE. 588 | """ 589 | assert eta == 0.0, "Reverse ODE only for deterministic path" 590 | out = self.p_mean_variance( 591 | model, 592 | x, 593 | t, 594 | clip_denoised=clip_denoised, 595 | denoised_fn=denoised_fn, 596 | model_kwargs=model_kwargs, 597 | ) 598 | if cond_fn is not None: 599 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 600 | # Usually our model outputs epsilon, but we re-derive it 601 | # in case we used x_start or x_prev prediction. 602 | eps = ( 603 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 604 | - out["pred_xstart"] 605 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 606 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 607 | 608 | # Equation 12. reversed 609 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 610 | 611 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 612 | 613 | def ddim_sample_loop( 614 | self, 615 | model, 616 | shape, 617 | noise=None, 618 | clip_denoised=True, 619 | denoised_fn=None, 620 | cond_fn=None, 621 | model_kwargs=None, 622 | device=None, 623 | progress=False, 624 | eta=0.0, 625 | ): 626 | """ 627 | Generate samples from the model using DDIM. 628 | Same usage as p_sample_loop(). 
629 | """ 630 | final = None 631 | for sample in self.ddim_sample_loop_progressive( 632 | model, 633 | shape, 634 | noise=noise, 635 | clip_denoised=clip_denoised, 636 | denoised_fn=denoised_fn, 637 | cond_fn=cond_fn, 638 | model_kwargs=model_kwargs, 639 | device=device, 640 | progress=progress, 641 | eta=eta, 642 | ): 643 | final = sample 644 | return final["sample"] 645 | 646 | def ddim_sample_loop_progressive( 647 | self, 648 | model, 649 | shape, 650 | noise=None, 651 | clip_denoised=True, 652 | denoised_fn=None, 653 | cond_fn=None, 654 | model_kwargs=None, 655 | device=None, 656 | progress=False, 657 | eta=0.0, 658 | ): 659 | """ 660 | Use DDIM to sample from the model and yield intermediate samples from 661 | each timestep of DDIM. 662 | Same usage as p_sample_loop_progressive(). 663 | """ 664 | if device is None: 665 | device = next(model.parameters()).device 666 | assert isinstance(shape, (tuple, list)) 667 | if noise is not None: 668 | img = noise 669 | else: 670 | img = th.randn(*shape, device=device) 671 | indices = list(range(self.num_timesteps))[::-1] 672 | 673 | if progress: 674 | # Lazy import so that we don't depend on tqdm. 675 | from tqdm.auto import tqdm 676 | 677 | indices = tqdm(indices) 678 | 679 | for i in indices: 680 | t = th.tensor([i] * shape[0], device=device) 681 | with th.no_grad(): 682 | out = self.ddim_sample( 683 | model, 684 | img, 685 | t, 686 | clip_denoised=clip_denoised, 687 | denoised_fn=denoised_fn, 688 | cond_fn=cond_fn, 689 | model_kwargs=model_kwargs, 690 | eta=eta, 691 | ) 692 | yield out 693 | img = out["sample"] 694 | 695 | def _vb_terms_bpd( 696 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 697 | ): 698 | """ 699 | Get a term for the variational lower-bound. 700 | The resulting units are bits (rather than nats, as one might expect). 701 | This allows for comparison to other papers. 702 | :return: a dict with the following keys: 703 | - 'output': a shape [N] tensor of NLLs or KLs. 
704 | - 'pred_xstart': the x_0 predictions. 705 | """ 706 | true_mean, _, true_log_variance_clipped = \ 707 | self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) 708 | out = self.p_mean_variance( 709 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 710 | ) 711 | kl = normal_kl( 712 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 713 | ) 714 | kl = mean_flat(kl) / np.log(2.0) 715 | 716 | decoder_nll = -discretized_gaussian_log_likelihood( 717 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 718 | ) 719 | assert decoder_nll.shape == x_start.shape 720 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 721 | # At the first timestep return the decoder NLL, 722 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 723 | output = th.where((t == 0), decoder_nll, kl) 724 | return {"output": output, "pred_xstart": out["pred_xstart"]} 725 | 726 | def patchify(self, imgs, imgs_LIDAR): 727 | """ 728 | imgs: (N, 3, H, W) 729 | x: (N, L, patch_size**2 *3) 730 | """ 731 | p = 3 732 | p_LIDAR = 3 733 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 734 | 735 | h = imgs.shape[2] // p 736 | w = imgs.shape[3] // p 737 | 738 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p)) 739 | x = torch.einsum('nchpwq->nhwpqc', x) 740 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1])) 741 | 742 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR)) 743 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR) 744 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1])) 745 | return x, x_LIDAR 746 | 747 | def training_losses(self, model, x_start, x_start_LIDAR, t, model_kwargs=None, noise=None): 748 | """ 749 | Compute training losses for a single timestep. 750 | :param model: the model to evaluate loss on. 751 | :param x_start: the [N x C x ...] tensor of inputs. 
752 | :param t: a batch of timestep indices. 753 | :param model_kwargs: if not None, a dict of extra keyword arguments to 754 | pass to the model. This can be used for conditioning. 755 | :param noise: if specified, the specific Gaussian noise to try to remove. 756 | :return: a dict with the key "loss" containing a tensor of shape [N]. 757 | Some mean or variance settings may also have other keys. 758 | """ 759 | if model_kwargs is None: 760 | model_kwargs = {} 761 | if noise is None: 762 | noise = th.randn_like(x_start) 763 | noise_LIDAR = th.randn_like(x_start_LIDAR) 764 | x_t = self.q_sample(x_start, t, noise=noise) 765 | x_t_LIDAR = self.q_sample(x_start_LIDAR, t, noise=noise_LIDAR) 766 | 767 | terms = {} 768 | 769 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 770 | terms["loss"] = self._vb_terms_bpd( 771 | model=model, 772 | x_start=x_start, 773 | x_t=x_t, 774 | t=t, 775 | clip_denoised=False, 776 | model_kwargs=model_kwargs, 777 | )["output"] 778 | if self.loss_type == LossType.RESCALED_KL: 779 | terms["loss"] *= self.num_timesteps 780 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 781 | rec_loss, model_output, model_output_LIDAR, logits, mask, mask_LIDAR, Cross_rec_loss,Cross_pred_imgs, Cross_pred_imgs_LIDAR = model(x_t, x_t_LIDAR, t, **model_kwargs) 782 | 783 | if self.model_var_type in [ 784 | ModelVarType.LEARNED, 785 | ModelVarType.LEARNED_RANGE, 786 | ]: 787 | B, C = x_t.shape[:2] 788 | assert model_output.shape == (B, C * 2 , *x_t.shape[2:]) 789 | model_output, model_var_values = th.split(model_output, C, dim=1) 790 | # Learn the variance using the variational bound, but don't let 791 | # it affect our mean prediction. 
792 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 793 | terms["vb"] = self._vb_terms_bpd( 794 | model=lambda *args, r=frozen_out: r, 795 | x_start=x_start, 796 | x_t=x_t, 797 | t=t, 798 | clip_denoised=False, 799 | )["output"] 800 | if self.loss_type == LossType.RESCALED_MSE: 801 | # Divide by 1000 for equivalence with initial implementation. 802 | # Without a factor of 1/1000, the VB term hurts the MSE term. 803 | terms["vb"] *= self.num_timesteps / 1000.0 804 | target = { 805 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 806 | x_start=x_start, 807 | x_t=x_t, 808 | t=t 809 | )[0], 810 | ModelMeanType.START_X: x_start, 811 | ModelMeanType.EPSILON: noise, 812 | }[self.model_mean_type] 813 | target_LIDAR = { 814 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 815 | x_start=x_start_LIDAR, 816 | x_t=x_t_LIDAR, 817 | t=t 818 | )[0], 819 | ModelMeanType.START_X: x_start_LIDAR, 820 | ModelMeanType.EPSILON: noise_LIDAR, 821 | }[self.model_mean_type] 822 | assert model_output.shape == target.shape == x_start.shape 823 | 824 | 825 | model_output, model_output_LIDAR = self.patchify(model_output, model_output_LIDAR) 826 | target, target_LIDAR = self.patchify(target, target_LIDAR) 827 | visible = torch.zeros_like(mask) 828 | visible_LIDAR = torch.zeros_like(mask_LIDAR) 829 | zeros_mask = torch.eq(mask, 0) 830 | ones_mask = torch.logical_not(zeros_mask) 831 | visible[zeros_mask] = 1 832 | visible[ones_mask] = 0 833 | zeros_mask = torch.eq(mask_LIDAR, 0) 834 | ones_mask = torch.logical_not(zeros_mask) 835 | visible_LIDAR[zeros_mask] = 1 836 | visible_LIDAR[ones_mask] = 0 837 | 838 | loss_mse = ((target - model_output) ** 2).mean(dim=-1) 839 | loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1) 840 | terms["mse"] = (loss_mse * visible).sum() / visible.sum() 841 | terms["mse_LIDAR"] = (loss_mse_LIDAR * visible_LIDAR).sum() / visible_LIDAR.sum() 842 | if "vb" in terms: 843 | terms["loss"] = terms["mse"] + 
terms["vb"] 844 | else: 845 | terms["loss"] = terms["mse"] 846 | else: 847 | raise NotImplementedError(self.loss_type) 848 | 849 | return terms, rec_loss, logits, Cross_rec_loss 850 | 851 | def _prior_bpd(self, x_start): 852 | """ 853 | Get the prior KL term for the variational lower-bound, measured in 854 | bits-per-dim. 855 | This term can't be optimized, as it only depends on the encoder. 856 | :param x_start: the [N x C x ...] tensor of inputs. 857 | :return: a batch of [N] KL values (in bits), one per batch element. 858 | """ 859 | batch_size = x_start.shape[0] 860 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 861 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 862 | kl_prior = normal_kl( 863 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 864 | ) 865 | return mean_flat(kl_prior) / np.log(2.0) 866 | 867 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 868 | """ 869 | Compute the entire variational lower-bound, measured in bits-per-dim, 870 | as well as other related quantities. 871 | :param model: the model to evaluate loss on. 872 | :param x_start: the [N x C x ...] tensor of inputs. 873 | :param clip_denoised: if True, clip denoised samples. 874 | :param model_kwargs: if not None, a dict of extra keyword arguments to 875 | pass to the model. This can be used for conditioning. 876 | :return: a dict containing the following keys: 877 | - total_bpd: the total variational lower-bound, per batch element. 878 | - prior_bpd: the prior term in the lower-bound. 879 | - vb: an [N x T] tensor of terms in the lower-bound. 880 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 881 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
882 | """ 883 | device = x_start.device 884 | batch_size = x_start.shape[0] 885 | 886 | vb = [] 887 | xstart_mse = [] 888 | mse = [] 889 | for t in list(range(self.num_timesteps))[::-1]: 890 | t_batch = th.tensor([t] * batch_size, device=device) 891 | noise = th.randn_like(x_start) 892 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 893 | # Calculate VLB term at the current timestep 894 | with th.no_grad(): 895 | out = self._vb_terms_bpd( 896 | model, 897 | x_start=x_start, 898 | x_t=x_t, 899 | t=t_batch, 900 | clip_denoised=clip_denoised, 901 | model_kwargs=model_kwargs, 902 | ) 903 | vb.append(out["output"]) 904 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 905 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 906 | mse.append(mean_flat((eps - noise) ** 2)) 907 | 908 | vb = th.stack(vb, dim=1) 909 | xstart_mse = th.stack(xstart_mse, dim=1) 910 | mse = th.stack(mse, dim=1) 911 | 912 | prior_bpd = self._prior_bpd(x_start) 913 | total_bpd = vb.sum(dim=1) + prior_bpd 914 | return { 915 | "total_bpd": total_bpd, 916 | "prior_bpd": prior_bpd, 917 | "vb": vb, 918 | "xstart_mse": xstart_mse, 919 | "mse": mse, 920 | } 921 | 922 | 923 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 924 | """ 925 | Extract values from a 1-D numpy array for a batch of indices. 926 | :param arr: the 1-D numpy array. 927 | :param timesteps: a tensor of indices into the array to extract. 928 | :param broadcast_shape: a larger shape of K dimensions with the batch 929 | dimension equal to the length of timesteps. 930 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
931 | """ 932 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 933 | while len(res.shape) < len(broadcast_shape): 934 | res = res[..., None] 935 | return res + th.zeros(broadcast_shape, device=timesteps.device) 936 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 
71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, x_LIDAR, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, x_LIDAR, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 
31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 
81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /generate_pic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | indianpines_colors = np.array([[0, 0, 0], 7 | [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], 8 | [0, 255, 255], [255, 0, 255], [192, 192, 192], [128, 128, 128], 9 | [128, 0, 0], [128, 128, 0], [0, 128, 0], [128, 0, 128], 10 | [0, 128, 128], [0, 0, 128], [255, 165, 0], [255, 215, 0]]) 11 | indianpines_colors = sklearn.preprocessing.minmax_scale(indianpines_colors, feature_range=(0, 1)) 12 | 13 | def 
classification_map(img, ground_truth, dpi, save_path): 14 | fig = plt.figure(frameon=False) 15 | fig.set_size_inches(ground_truth.shape[1] * 2.0 / dpi, ground_truth.shape[0] * 2.0 / dpi) 16 | 17 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 18 | ax.set_axis_off() 19 | ax.xaxis.set_visible(False) 20 | ax.yaxis.set_visible(False) 21 | fig.add_axes(ax) 22 | 23 | ax.imshow(img) 24 | fig.savefig(save_path, dpi=dpi) 25 | 26 | return 0 27 | 28 | 29 | def generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, OA, halfsize, dataset, day_str, num, model_name): 30 | number_of_rows = np.size(image,0) 31 | number_of_columns = np.size(image,1) 32 | 33 | gt_thematic_map = np.zeros(shape=(number_of_rows, number_of_columns, 3)) 34 | predicted_thematic_map = np.zeros(shape=(number_of_rows, number_of_columns, 3)) 35 | for i in range(number_of_rows): 36 | for j in range(number_of_columns): 37 | gt_thematic_map[i, j, :] = indianpines_colors[gt[i,j]] 38 | predicted_thematic_map[i, j, :] = indianpines_colors[gt[i,j]] 39 | nclass = np.max(gt) 40 | 41 | fl = 0 42 | for i in range(nclass): 43 | print('test lable of class:',i) 44 | matrix = index[i] 45 | temprow = matrix[:,0] 46 | tempcol = matrix[:,1] 47 | m = len(temprow) 48 | fl = fl - nTrain_perClass[i] - nvalid_perClass[i] 49 | for j in range(nTrain_perClass[i] + nvalid_perClass[i], m): 50 | predicted_thematic_map[temprow[j], tempcol[j], :] = indianpines_colors[test_pred[fl + j]+1] 51 | fl = fl + m 52 | 53 | 54 | predicted_thematic_map = predicted_thematic_map[halfsize:number_of_rows -halfsize,halfsize:number_of_columns-halfsize,: ] 55 | gt_thematic_map = gt_thematic_map[halfsize:number_of_rows -halfsize,halfsize:number_of_columns-halfsize,: ] 56 | path = '.' 
57 | classification_map(predicted_thematic_map, gt, 600, 58 | path + '/classification_maps/' + dataset + '_' + day_str +'_' + str(num) +'_OA_'+ str(round(OA, 2)) + '_' + 'predicted' + '.png') 59 | classification_map(gt_thematic_map, gt, 600, 60 | path + '/classification_maps/' + dataset + '_' + day_str +'_' + str(num) +'_OA_'+ str(round(OA, 2)) + '_' + 'gt' + '.png') 61 | 62 | return predicted_thematic_map, gt_thematic_map 63 | 64 | 65 | -------------------------------------------------------------------------------- /hyper_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data.dataset import Dataset 4 | 5 | 6 | class HyperData(Dataset): 7 | def __init__(self, dataset, transfor): 8 | self.data = dataset[0].astype(np.float32) 9 | self.data_LIDAR = dataset[1].astype(np.float32) 10 | self.transformer = transfor 11 | self.labels = [] 12 | for n in dataset[2]: 13 | self.labels += [int(n)] 14 | 15 | def __getitem__(self, index): 16 | label = self.labels[index] 17 | if self.transformer == None: 18 | img = torch.from_numpy(np.asarray(self.data[index,:,:,:])) 19 | img_LIDAR = torch.from_numpy(np.asarray(self.data_LIDAR[index,:,:,:])) 20 | return img, img_LIDAR, label 21 | elif len(self.transformer) == 2: 22 | img = torch.from_numpy(np.asarray(self.transformer[1](self.transformer[0](self.data[index,:,:,:])))) 23 | img_LIDAR = torch.from_numpy(np.asarray(self.transformer[1](self.transformer[0](self.data_LIDAR[index,:,:,:])))) 24 | return img, img_LIDAR, label 25 | else: 26 | img = torch.from_numpy(np.asarray(self.transformer[0](self.data[index,:,:,:]))) 27 | img_LIDAR = torch.from_numpy(np.asarray(self.transformer[0](self.data_LIDAR[index,:,:,:]))) 28 | return img, img_LIDAR, label 29 | 30 | def __len__(self): 31 | return len(self.labels) 32 | 33 | def __labels__(self): 34 | return self.labels 35 | 36 | 37 | 
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | import scipy.io as sio 4 | import torch 5 | import torch.nn as nn 6 | from torch import _assert 7 | from timm.models.vision_transformer import Block 8 | from timm.models.layers import to_2tuple, DropPath 9 | # from timm.models.vision_transformer import PatchEmbed, Block 10 | from pos_embed import get_2d_sincos_pos_embed 11 | 12 | class PatchEmbed(nn.Module): 13 | """ 2D Image to Patch Embedding 14 | """ 15 | def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 16 | super().__init__() 17 | img_size = to_2tuple(img_size) 18 | patch_size = to_2tuple(patch_size) 19 | 20 | self.img_size = img_size 21 | self.patch_size = patch_size 22 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 23 | self.num_patches = self.grid_size[0] * self.grid_size[1] 24 | self.flatten = flatten 25 | 26 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 27 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 28 | 29 | def forward(self, x): 30 | B, C, H, W = x.shape 31 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 32 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 33 | x = self.proj(x) 34 | if self.flatten: 35 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 36 | x = self.norm(x) 37 | return x 38 | 39 | class TimestepEmbedder(nn.Module): 40 | """ 41 | Embeds scalar timesteps into vector representations. 
42 | """ 43 | def __init__(self, hidden_size, frequency_embedding_size=256): 44 | super().__init__() 45 | self.mlp = nn.Sequential( 46 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 47 | nn.SiLU(), 48 | nn.Linear(hidden_size, hidden_size, bias=True), 49 | ) 50 | self.frequency_embedding_size = frequency_embedding_size 51 | 52 | @staticmethod 53 | def timestep_embedding(t, dim, max_period=10000): 54 | """ 55 | Create sinusoidal timestep embeddings. 56 | :param t: a 1-D Tensor of N indices, one per batch element. 57 | These may be fractional. 58 | :param dim: the dimension of the output. 59 | :param max_period: controls the minimum frequency of the embeddings. 60 | :return: an (N, D) Tensor of positional embeddings. 61 | """ 62 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 63 | half = dim // 2 64 | freqs = torch.exp( 65 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 66 | ).to(device=t.device) 67 | args = t[:, None].float() * freqs[None] 68 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 69 | if dim % 2: 70 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 71 | return embedding 72 | 73 | def forward(self, t): 74 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 75 | t_emb = self.mlp(t_freq) 76 | return t_emb 77 | 78 | class MaskedAutoencoderViT(nn.Module): 79 | """ Masked Autoencoder with VisionTransformer backbone 80 | """ 81 | def __init__(self, img_size=224, patch_size=16, in_chans=3, in_chans_LIDAR = 1,hid_chans = 32,hid_chans_LIDAR = 32, 82 | embed_dim=1024, depth=24, num_heads=16, 83 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 84 | mlp_ratio=4., drop_rate=0.,attn_drop_rate=0., drop_path_rate=0., 85 | norm_layer=nn.LayerNorm, norm_pix_loss=False, 86 | cls_hidden_mlp=256, nb_classes=1000, global_pool=False, 87 | mlp_depth=2): 88 | super().__init__() 89 | # 
-------------------------------------------------------------------------- 90 | #HSI 91 | # MAE dimensionality reduction/expansion specifics 92 | self.dimen_redu = nn.Sequential( 93 | nn.Conv2d(in_chans, hid_chans, kernel_size=1, stride=1, padding=0, bias=True), 94 | nn.BatchNorm2d(hid_chans), 95 | nn.ReLU(), 96 | nn.Conv2d(hid_chans, hid_chans, 1, 1, 0, bias=True), 97 | nn.BatchNorm2d(hid_chans), 98 | nn.ReLU(), 99 | ) 100 | 101 | self.t_embedder = TimestepEmbedder(embed_dim) 102 | self.dimen_expa = nn.Conv2d(hid_chans, in_chans, kernel_size=1, stride=1, padding=0, bias=True) 103 | self.h = img_size[0]// patch_size 104 | self.w = img_size[1]// patch_size 105 | # -------------------------------------------------------------------------- 106 | # MAE encoder specifics 107 | self.patch_embed = PatchEmbed(img_size, patch_size, hid_chans , embed_dim) 108 | num_patches = self.patch_embed.num_patches 109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 111 | self.blocks = nn.ModuleList([ 112 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate,attn_drop=attn_drop_rate, drop_path=drop_path_rate, qkv_bias=True, norm_layer=norm_layer) 113 | for i in range(depth)]) 114 | self.norm = norm_layer(embed_dim) 115 | self.projection_en_mask = nn.Linear(embed_dim, embed_dim) 116 | self.projection_en_visible = nn.Linear(embed_dim, embed_dim) 117 | # -------------------------------------------------------------------------- 118 | 119 | # -------------------------------------------------------------------------- 120 | # MAE decoder specifics 121 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 122 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 123 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 124 | 
self.decoder_blocks = nn.ModuleList([ 125 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 126 | for i in range(decoder_depth)]) 127 | self.decoder_norm = norm_layer(decoder_embed_dim) 128 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * hid_chans, bias=True) # decoder to patch 129 | self.projection_de = nn.Linear(decoder_embed_dim, decoder_embed_dim) 130 | if cls_hidden_mlp == 0: 131 | self.cls_head = nn.Linear(embed_dim, nb_classes) 132 | else: 133 | assert mlp_depth in [2], "mlp depth should be 2" 134 | if mlp_depth == 2: 135 | self.cls_head = nn.Sequential( 136 | nn.Linear(embed_dim*2, cls_hidden_mlp), 137 | nn.BatchNorm1d(cls_hidden_mlp), 138 | nn.ReLU(inplace=True), 139 | nn.Linear(cls_hidden_mlp, embed_dim), 140 | nn.BatchNorm1d(embed_dim), 141 | nn.ReLU(inplace=True), 142 | nn.Linear(embed_dim, nb_classes), 143 | ) 144 | # -------------------------------------------------------------------------- 145 | self.global_pool = global_pool 146 | self.norm_pix_loss = norm_pix_loss 147 | 148 | # -------------------------------------------------------------------------- 149 | # LIDAR 150 | # MAE dimensionality reduction/expansion specifics 151 | self.dimen_redu_LIDAR = nn.Sequential( 152 | nn.Conv2d(in_chans_LIDAR, hid_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True), 153 | nn.BatchNorm2d(hid_chans_LIDAR), 154 | nn.ReLU(), 155 | 156 | nn.Conv2d(hid_chans_LIDAR, hid_chans_LIDAR, 1, 1, 0, bias=True), 157 | nn.BatchNorm2d(hid_chans_LIDAR), 158 | nn.ReLU(), 159 | ) 160 | self.t_embedder_LIDAR = TimestepEmbedder(embed_dim) 161 | self.dimen_expa_LIDAR = nn.Conv2d(hid_chans_LIDAR, in_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True) 162 | self.h = img_size[0] // patch_size 163 | self.w = img_size[1] // patch_size 164 | # -------------------------------------------------------------------------- 165 | # MAE encoder specifics 166 | self.patch_embed_LIDAR = PatchEmbed(img_size, patch_size, 
hid_chans_LIDAR, embed_dim) 167 | num_patches = self.patch_embed_LIDAR.num_patches 168 | 169 | self.cls_token_LIDAR = nn.Parameter(torch.zeros(1, 1, embed_dim)) 170 | self.pos_embed_LIDAR = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 171 | requires_grad=False) # fixed sin-cos embedding 172 | 173 | self.blocks_LIDAR = nn.ModuleList([ 174 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, 175 | qkv_bias=True, norm_layer=norm_layer) 176 | for i in range(depth)]) 177 | self.norm_LIDAR = norm_layer(embed_dim) 178 | self.projection_en_mask_LIDAR = nn.Linear(embed_dim, embed_dim) 179 | self.projection_en_visible_LIDAR = nn.Linear(embed_dim, embed_dim) 180 | # -------------------------------------------------------------------------- 181 | 182 | # -------------------------------------------------------------------------- 183 | # MAE decoder specifics 184 | self.decoder_embed_LIDAR = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 185 | 186 | self.mask_token_LIDAR = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 187 | 188 | self.decoder_pos_embed_LIDAR = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), 189 | requires_grad=False) # fixed sin-cos embedding 190 | 191 | self.decoder_blocks_LIDAR = nn.ModuleList([ 192 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 193 | for i in range(decoder_depth)]) 194 | 195 | self.decoder_norm_LIDAR = norm_layer(decoder_embed_dim) 196 | self.decoder_pred_LIDAR = nn.Linear(decoder_embed_dim, patch_size ** 2 * hid_chans, bias=True) # decoder to patch 197 | self.projection_de_LIDAR = nn.Linear(decoder_embed_dim, decoder_embed_dim) 198 | # -------------------------------------------------------------------------- 199 | self.global_pool_LIDAR = global_pool 200 | self.norm_pix_loss_LIDAR = norm_pix_loss 201 | 202 | 203 | self.initialize_weights() 204 | 205 | def initialize_weights(self): 206 | # 
initialization` 207 | 208 | # initialize (and freeze) pos_embed by sin-cos embedding 209 | #HSI 210 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 211 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 212 | 213 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 214 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 215 | 216 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 217 | w = self.patch_embed.proj.weight.data 218 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 219 | 220 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 221 | torch.nn.init.normal_(self.cls_token, std=.02) 222 | # Initialize timestep embedding MLP: 223 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 224 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 225 | 226 | #LIDAR 227 | pos_embed_LIDAR = get_2d_sincos_pos_embed(self.pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches**.5), cls_token=True) 228 | self.pos_embed_LIDAR.data.copy_(torch.from_numpy(pos_embed_LIDAR).float().unsqueeze(0)) 229 | 230 | decoder_pos_embed_LIDAR = get_2d_sincos_pos_embed(self.decoder_pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches**.5), cls_token=True) 231 | self.decoder_pos_embed_LIDAR.data.copy_(torch.from_numpy(decoder_pos_embed_LIDAR).float().unsqueeze(0)) 232 | 233 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 234 | w_LIDAR = self.patch_embed_LIDAR.proj.weight.data 235 | torch.nn.init.xavier_uniform_(w_LIDAR.view([w_LIDAR.shape[0], -1])) 236 | 237 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
238 | torch.nn.init.normal_(self.cls_token_LIDAR, std=.02) 239 | 240 | # Initialize timestep embedding MLP: 241 | nn.init.normal_(self.t_embedder_LIDAR.mlp[0].weight, std=0.02) 242 | nn.init.normal_(self.t_embedder_LIDAR.mlp[2].weight, std=0.02) 243 | 244 | # initialize nn.Linear and nn.LayerNorm 245 | self.apply(self._init_weights) 246 | 247 | def _init_weights(self, m): 248 | if isinstance(m, nn.Linear): 249 | # we use xavier_uniform following official JAX ViT: 250 | torch.nn.init.xavier_uniform_(m.weight) 251 | if isinstance(m, nn.Linear) and m.bias is not None: 252 | nn.init.constant_(m.bias, 0) 253 | elif isinstance(m, nn.LayerNorm): 254 | nn.init.constant_(m.bias, 0) 255 | nn.init.constant_(m.weight, 1.0) 256 | 257 | def patchify(self, imgs, imgs_LIDAR): 258 | """ 259 | imgs: (N, 3, H, W) 260 | x: (N, L, patch_size**2 *3) 261 | """ 262 | p = self.patch_embed.patch_size[0] 263 | p_LIDAR = self.patch_embed_LIDAR.patch_size[0] 264 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 265 | 266 | h = imgs.shape[2] // p 267 | w = imgs.shape[3] // p 268 | 269 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p)) 270 | x = torch.einsum('nchpwq->nhwpqc', x) 271 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1])) 272 | 273 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR)) 274 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR) 275 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1])) 276 | return x, x_LIDAR 277 | 278 | def unpatchify(self, x, x_LIDAR): 279 | """ 280 | x: (N, L, patch_size**2 *3) 281 | imgs: (N, 3, H, W) 282 | """ 283 | p = self.patch_embed.patch_size[0] 284 | p_LIDAR = self.patch_embed_LIDAR.patch_size[0] 285 | h = self.h 286 | w = self.w 287 | assert h * w == x.shape[1] 288 | assert h * w == x_LIDAR.shape[1] 289 | 290 | hid_chans = int(x.shape[2]/(p**2)) 291 | hid_chans_LIDAR = 
int(x_LIDAR.shape[2]/(p_LIDAR**2)) 292 | 293 | x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans)) 294 | x = torch.einsum('nhwpqc->nchpwq', x) 295 | imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p)) 296 | 297 | x_LIDAR = x_LIDAR.reshape(shape=(x_LIDAR.shape[0], h, w, p_LIDAR, p_LIDAR, hid_chans_LIDAR)) 298 | x_LIDAR = torch.einsum('nhwpqc->nchpwq', x_LIDAR) 299 | imgs_LIDAR = x_LIDAR.reshape(shape=(x_LIDAR.shape[0], hid_chans_LIDAR, h * p_LIDAR, w * p_LIDAR)) 300 | return imgs, imgs_LIDAR 301 | 302 | def random_masking(self, x, x_LIDAR, mask_ratio): 303 | """ 304 | Perform per-sample random masking by per-sample shuffling. 305 | Per-sample shuffling is done by argsort random noise. 306 | x: [N, L, D], sequence 307 | """ 308 | N, L, D = x.shape # batch, length, dim 309 | len_keep = int(L * (1 - mask_ratio)) 310 | 311 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 312 | 313 | # sort noise for each sample 314 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 315 | ids_restore = torch.argsort(ids_shuffle, dim=1) 316 | 317 | # keep the first subset 318 | ids_keep = ids_shuffle[:, :len_keep] 319 | x_visible = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 320 | 321 | 322 | noise_LIDAR = torch.rand(N, L, device=x.device) 323 | ids_shuffle_LIDAR = torch.argsort(noise_LIDAR, dim=1) # ascend: small is keep, large is remove 324 | ids_restore_LIDAR = torch.argsort(ids_shuffle_LIDAR, dim=1) 325 | 326 | # keep the first subset 327 | ids_keep_LIDAR = ids_restore_LIDAR[:, :len_keep] 328 | x_LIDAR_visible = torch.gather(x_LIDAR, dim=1, index=ids_keep_LIDAR.unsqueeze(-1).repeat(1, 1, D)) 329 | 330 | # generate the binary mask: 0 is keep, 1 is remove 331 | mask = torch.ones([N, L], device=x.device) 332 | mask[:, :len_keep] = 0 333 | mask_LIDAR = torch.ones([N, L], device=x.device) 334 | mask_LIDAR[:, :len_keep] = 0 335 | # unshuffle to get the binary mask 336 | mask = torch.gather(mask, dim=1, 
index=ids_restore) 337 | mask_LIDAR = torch.gather(mask_LIDAR, dim=1, index=ids_restore_LIDAR) 338 | return x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR 339 | 340 | def preprocessing(self, x, x_LIDAR, mask_ratio): 341 | # embed patches 342 | x = self.dimen_redu(x) 343 | x = self.patch_embed(x) 344 | 345 | x_LIDAR = self.dimen_redu_LIDAR(x_LIDAR) 346 | x_LIDAR = self.patch_embed_LIDAR(x_LIDAR) 347 | 348 | # add pos embed w/o cls token 349 | x = x + self.pos_embed[:, 1:, :] 350 | x_LIDAR = x_LIDAR + self.pos_embed_LIDAR[:, 1:, :] 351 | 352 | # masking: length -> length * mask_ratio 353 | x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR = self.random_masking(x, x_LIDAR, mask_ratio) 354 | return x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR 355 | 356 | def forward_encoder(self, x, x_LIDAR, t): 357 | # append cls token 358 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 359 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 360 | 361 | t_token = self.t_embedder(torch.squeeze(t)) 362 | t_token = torch.unsqueeze(t_token, dim=1) 363 | cls_tokens = cls_tokens + t_token 364 | 365 | x = torch.cat((cls_tokens, x), dim=1) 366 | 367 | cls_token_LIDAR = self.cls_token_LIDAR + self.pos_embed_LIDAR[:, :1, :] 368 | cls_tokens_LIDAR = cls_token_LIDAR.expand(x_LIDAR.shape[0], -1, -1) 369 | 370 | t_token_LIDAR = self.t_embedder_LIDAR(torch.squeeze(t)) 371 | t_token_LIDAR = torch.unsqueeze(t_token_LIDAR, dim=1) 372 | cls_tokens_LIDAR = cls_tokens_LIDAR + t_token_LIDAR 373 | 374 | x_LIDAR = torch.cat((cls_tokens_LIDAR, x_LIDAR), dim=1) 375 | 376 | # apply Transformer blocks 377 | for blk in self.blocks: 378 | x = blk(x) 379 | x = self.norm(x) 380 | 381 | for blk_LIDAR in self.blocks: 382 | x_LIDAR = blk_LIDAR(x_LIDAR) 383 | x_LIDAR = self.norm(x_LIDAR) 384 | 385 | return x, x_LIDAR 386 | 387 | def forward_decoder_A(self, x, ids_restore): 388 | # embed token 389 | x = self.decoder_embed(x) 
390 | 391 | # append mask tokens to sequence 392 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 393 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 394 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 395 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 396 | 397 | # add pos embed 398 | x = x + self.decoder_pos_embed[:, :x.shape[1], :] 399 | 400 | # apply Transformer blocks 401 | for blk in self.decoder_blocks: 402 | x = blk(x) 403 | x = self.decoder_norm(x) 404 | 405 | x = self.projection_de(x) 406 | return x 407 | 408 | def forward_decoder_B(self, x_LIDAR, ids_restore_LIDAR): 409 | # embed token 410 | x_LIDAR = self.decoder_embed_LIDAR(x_LIDAR) 411 | 412 | # append mask tokens to sequence 413 | mask_tokens_LIDAR = self.mask_token_LIDAR.repeat(x_LIDAR.shape[0], ids_restore_LIDAR.shape[1] + 1 - x_LIDAR.shape[1], 1) 414 | x_LIDAR_ = torch.cat([x_LIDAR[:, 1:, :], mask_tokens_LIDAR], dim=1) # no cls token 415 | x_LIDAR_ = torch.gather(x_LIDAR_, dim=1, index=ids_restore_LIDAR.unsqueeze(-1).repeat(1, 1, x_LIDAR.shape[2])) # unshuffle 416 | x_LIDAR = torch.cat([x_LIDAR[:, :1, :], x_LIDAR_], dim=1) # append cls token 417 | 418 | # add pos embed 419 | x_LIDAR = x_LIDAR + self.decoder_pos_embed_LIDAR[:, :x_LIDAR.shape[1], :] 420 | 421 | # apply Transformer blocks 422 | for blk_LIDAR in self.decoder_blocks_LIDAR: 423 | x_LIDAR = blk_LIDAR(x_LIDAR) 424 | x_LIDAR = self.decoder_norm_LIDAR(x_LIDAR) 425 | 426 | x_LIDAR = self.projection_de_LIDAR(x_LIDAR) 427 | return x_LIDAR 428 | 429 | 430 | def reconstruction(self, x, x_LIDAR): 431 | x = self.decoder_pred(x) 432 | x_LIDAR = self.decoder_pred_LIDAR(x_LIDAR) 433 | # # remove cls token 434 | x = x[:, 1:, :] 435 | x_LIDAR = x_LIDAR[:, 1:, :] 436 | 437 | 438 | x, x_LIDAR = self.unpatchify(x, x_LIDAR) 439 | x = self.dimen_expa(x) 440 | x_LIDAR = self.dimen_expa_LIDAR(x_LIDAR) 441 | 442 | 
pred_Reconstruction, pred_LIDAR_Reconstruction = self.patchify(x, x_LIDAR) 443 | return x, x_LIDAR, pred_Reconstruction, pred_LIDAR_Reconstruction 444 | 445 | def forward_classification(self, x, x_LIDAR): 446 | if self.global_pool: 447 | feat = x[:, 1:, :].mean(dim=1) # global pool without cls token 448 | else: 449 | feat = x[:, 0, :] # with cls token 450 | 451 | if self.global_pool_LIDAR: 452 | feat_LIDAR = x_LIDAR[:, 1:, :].mean(dim=1) # global pool without cls token 453 | else: 454 | feat_LIDAR = x_LIDAR[:, 0, :] # with cls token 455 | feat_all = torch.cat((feat, feat_LIDAR),dim=1) 456 | logits = self.cls_head(feat_all) 457 | return logits 458 | 459 | def Reconstruction_loss(self, imgs, pred, imgs_LIDAR, pred_LIDAR, mask, mask_LIDAR): 460 | """ 461 | imgs: [N, 3, H, W] 462 | pred: [N, L, p*p*3] 463 | mask: [N, L], 0 is keep, 1 is remove, 464 | """ 465 | target, target_LIDAR = self.patchify(imgs, imgs_LIDAR) 466 | if self.norm_pix_loss: 467 | mean = target.mean(dim=-1, keepdim=True) 468 | var = target.var(dim=-1, keepdim=True) 469 | target = (target - mean) / (var + 1.e-6)**.5 470 | 471 | loss = (pred - target) ** 2 472 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 473 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 474 | 475 | if self.norm_pix_loss_LIDAR: 476 | mean_LIDAR = target_LIDAR.mean(dim=-1, keepdim=True) 477 | var_LIDAR = target_LIDAR.var(dim=-1, keepdim=True) 478 | target_LIDAR = (target_LIDAR - mean_LIDAR) / (var_LIDAR + 1.e-6)**.5 479 | 480 | loss_LIDAR = (pred_LIDAR - target_LIDAR) ** 2 481 | loss_LIDAR = loss_LIDAR.mean(dim=-1) # [N, L], mean loss per patch 482 | loss_LIDAR = (loss_LIDAR * mask_LIDAR).sum() / mask_LIDAR.sum() # mean loss on removed patches 483 | 484 | loss_all = loss + loss_LIDAR 485 | return loss_all 486 | 487 | def forward(self, imgs, imgs_LIDAR, t, y, mask_ratio): 488 | #preprocessing 489 | x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR = self.preprocessing(imgs, 
imgs_LIDAR, mask_ratio) 490 | 491 | #visible_process 492 | feature_visible, feature_visible_LIDAR = self.forward_encoder(x_visible, x_LIDAR_visible, t) 493 | pred = self.forward_decoder_A(feature_visible, ids_restore) 494 | pred_LIDAR = self.forward_decoder_B(feature_visible_LIDAR, ids_restore_LIDAR) 495 | 496 | # Reconstruction branch 497 | pred_imgs, pred_imgs_LIDAR, pred_Reconstruction, pred_LIDAR_Reconstruction = self.reconstruction(pred, pred_LIDAR) # [N, L, p*p*3] 498 | 499 | # Cross Reconstruction branch 500 | # Cross_pred = self.forward_decoder_A(feature_visible_LIDAR, ids_restore_LIDAR) 501 | # Cross_pred_LIDAR = self.forward_decoder_B(feature_visible, ids_restore) 502 | # Cross_pred_imgs, Cross_pred_imgs_LIDAR, Cross_pred_Reconstruction, Cross_pred_LIDAR_Reconstruction = self.reconstruction(Cross_pred, Cross_pred_LIDAR) 503 | 504 | # Classification branch 505 | logits = self.forward_classification(feature_visible, feature_visible_LIDAR) 506 | # return pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR, Cross_pred_imgs, Cross_pred_imgs_LIDAR 507 | return pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR 508 | 509 | 510 | def patchify(imgs, imgs_LIDAR, size): 511 | """ 512 | imgs: (N, 3, H, W) 513 | x: (N, L, patch_size**2 *3) 514 | """ 515 | p = size 516 | p_LIDAR = size 517 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 518 | 519 | h = imgs.shape[2] // p 520 | w = imgs.shape[3] // p 521 | 522 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p)) 523 | x = torch.einsum('nchpwq->nhwpqc', x) 524 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1])) 525 | 526 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR)) 527 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR) 528 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1])) 529 | return x, x_LIDAR 530 | 531 | def DDPM_LOSS(model_output, model_output_LIDAR, 
target, target_LIDAR, mask , mask_LIDAR, size): 532 | model_output, model_output_LIDAR = patchify(model_output, model_output_LIDAR, size) 533 | target, target_LIDAR = patchify(target, target_LIDAR, size) 534 | 535 | loss_mse = ((target - model_output) ** 2).mean(dim=-1) 536 | loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1) 537 | loss_mse_m = (loss_mse * mask).sum() / mask.sum() 538 | loss_mse_LIDAR_m = (loss_mse_LIDAR * mask_LIDAR).sum() / mask_LIDAR.sum() 539 | visible = torch.zeros_like(mask) 540 | visible_LIDAR = torch.zeros_like(mask_LIDAR) 541 | zeros_mask = torch.eq(mask, 0) 542 | ones_mask = torch.logical_not(zeros_mask) 543 | visible[zeros_mask] = 1 544 | visible[ones_mask] = 0 545 | zeros_mask = torch.eq(mask_LIDAR, 0) 546 | ones_mask = torch.logical_not(zeros_mask) 547 | visible_LIDAR[zeros_mask] = 1 548 | visible_LIDAR[ones_mask] = 0 549 | loss_mse_v = (loss_mse * visible).sum() / visible.sum() 550 | loss_mse_LIDAR_v = (loss_mse_LIDAR * visible_LIDAR).sum() / visible_LIDAR.sum() 551 | 552 | # loss_mse = ((target - model_output) ** 2).mean(dim=-1).mean() 553 | # loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1).mean() 554 | return loss_mse_m, loss_mse_LIDAR_m, loss_mse_v, loss_mse_LIDAR_v 555 | 556 | class vit_HSI_LIDAR(nn.Module): 557 | """ Masked Autoencoder's'backbone 558 | """ 559 | 560 | def __init__(self, img_size=(224, 224), patch_size=16, num_classes=1000, in_chans=3, in_chans_LIDAR = 1, hid_chans=32, 561 | hid_chans_LIDAR=128,embed_dim=1024, depth=24, num_heads=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 562 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, global_pool=False): 563 | super().__init__() 564 | self.patch_size = patch_size 565 | 566 | # -------------------------------------------------------------------------- 567 | # HSI 568 | # MAE encoder specifics 569 | self.dimen_redu = nn.Sequential( 570 | nn.Conv2d(in_chans, hid_chans, kernel_size=1, stride=1, padding=0, 
bias=True), 571 | nn.BatchNorm2d(hid_chans), 572 | nn.ReLU(), 573 | 574 | nn.Conv2d(hid_chans, hid_chans, 1, 1, 0, bias=True), 575 | nn.BatchNorm2d(hid_chans), 576 | nn.ReLU(), 577 | ) 578 | 579 | # -------------------------------------------------------------------------- 580 | # MAE encoder specifics 581 | self.patch_embed = PatchEmbed(img_size, patch_size, hid_chans, embed_dim) 582 | num_patches = self.patch_embed.num_patches 583 | 584 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 585 | self.t_embedder = TimestepEmbedder(embed_dim) 586 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 587 | requires_grad=True) # fixed sin-cos embedding 588 | 589 | self.blocks = nn.ModuleList([ 590 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, 591 | qkv_bias=True, norm_layer=norm_layer) 592 | for i in range(depth)]) 593 | self.norm = norm_layer(embed_dim) 594 | self.head = nn.Linear(embed_dim * 2, embed_dim, bias=True) 595 | 596 | self.global_pool = global_pool 597 | if self.global_pool: 598 | self.fc_norm = norm_layer(embed_dim) 599 | del self.norm 600 | 601 | # LIDAR 602 | # MAE encoder specifics 603 | self.dimen_redu_LIDAR = nn.Sequential( 604 | nn.Conv2d(in_chans_LIDAR, hid_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True), 605 | nn.BatchNorm2d(hid_chans_LIDAR), 606 | nn.ReLU(), 607 | 608 | nn.Conv2d(hid_chans_LIDAR, hid_chans_LIDAR, 1, 1, 0, bias=True), 609 | nn.BatchNorm2d(hid_chans_LIDAR), 610 | nn.ReLU(), 611 | ) 612 | 613 | # -------------------------------------------------------------------------- 614 | # MAE encoder specifics 615 | self.patch_embed_LIDAR = PatchEmbed(img_size, patch_size, hid_chans_LIDAR, embed_dim) 616 | num_patches = self.patch_embed_LIDAR.num_patches 617 | 618 | self.cls_token_LIDAR = nn.Parameter(torch.zeros(1, 1, embed_dim)) 619 | self.t_embedder_LIDAR = TimestepEmbedder(embed_dim) 620 | self.pos_embed_LIDAR = 
nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 621 | requires_grad=True) # fixed sin-cos embedding 622 | 623 | self.blocks_LIDAR = nn.ModuleList([ 624 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, 625 | qkv_bias=True, norm_layer=norm_layer) 626 | for i in range(depth)]) 627 | self.norm_LIDAR = norm_layer(embed_dim) 628 | self.global_pool_LIDAR = global_pool 629 | if self.global_pool_LIDAR: 630 | self.fc_norm_LIDAR = norm_layer(embed_dim) 631 | del self.norm_LIDAR 632 | 633 | def initialize_weights(self): 634 | # initialization 635 | # initialize (and freeze) pos_embed by sin-cos embedding 636 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), 637 | cls_token=True) 638 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 639 | 640 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 641 | w = self.patch_embed.proj.weight.data 642 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 643 | 644 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 645 | torch.nn.init.normal_(self.cls_token, std=.02) 646 | 647 | # initialize (and freeze) pos_embed by sin-cos embedding 648 | pos_embed_LIDAR = get_2d_sincos_pos_embed(self.pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches ** .5), 649 | cls_token=True) 650 | self.pos_embed_LIDAR.data.copy_(torch.from_numpy(pos_embed_LIDAR).float().unsqueeze(0)) 651 | 652 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 653 | w_LIDAR = self.patch_embed_LIDAR.proj.weight.data 654 | torch.nn.init.xavier_uniform_(w_LIDAR.view([w_LIDAR.shape[0], -1])) 655 | 656 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
657 | torch.nn.init.normal_(self.cls_token_LIDAR, std=.02) 658 | 659 | # Initialize timestep embedding MLP: 660 | nn.init.normal_(self.t_embedder_LIDAR.mlp[0].weight, std=0.02) 661 | nn.init.normal_(self.t_embedder_LIDAR.mlp[2].weight, std=0.02) 662 | 663 | # initialize nn.Linear and nn.LayerNorm 664 | self.apply(self._init_weights) 665 | 666 | def _init_weights(self, m): 667 | if isinstance(m, nn.Linear): 668 | # we use xavier_uniform following official JAX ViT: 669 | torch.nn.init.xavier_uniform_(m.weight) 670 | if isinstance(m, nn.Linear) and m.bias is not None: 671 | nn.init.constant_(m.bias, 0) 672 | elif isinstance(m, nn.LayerNorm): 673 | nn.init.constant_(m.bias, 0) 674 | nn.init.constant_(m.weight, 1.0) 675 | 676 | def forward_features(self, x, x_LIDAR, t): 677 | x = self.dimen_redu(x) 678 | x_LIDAR = self.dimen_redu_LIDAR(x_LIDAR) 679 | 680 | # embed patches 681 | x = self.patch_embed(x) 682 | x_LIDAR = self.patch_embed_LIDAR(x_LIDAR) 683 | 684 | # add pos embed w/o cls token 685 | x = x + self.pos_embed[:, 1:, :] 686 | x_LIDAR = x_LIDAR + self.pos_embed_LIDAR[:, 1:, :] 687 | 688 | # append cls token 689 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 690 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 691 | 692 | t_token = self.t_embedder(t) 693 | t_token = torch.unsqueeze(t_token, dim=1) 694 | cls_tokens = cls_tokens + t_token 695 | 696 | x = torch.cat((cls_tokens, x), dim=1) 697 | 698 | 699 | cls_token_LIDAR = self.cls_token_LIDAR + self.pos_embed_LIDAR[:, :1, :] 700 | cls_tokens_LIDAR = cls_token_LIDAR.expand(x_LIDAR.shape[0], -1, -1) 701 | 702 | t_token_LIDAR = self.t_embedder_LIDAR(t) 703 | t_token_LIDAR = torch.unsqueeze(t_token_LIDAR, dim=1) 704 | cls_tokens_LIDAR = cls_tokens_LIDAR + t_token_LIDAR 705 | 706 | x_LIDAR = torch.cat((cls_tokens_LIDAR, x_LIDAR), dim=1) 707 | 708 | # apply Transformer blocks 709 | for blk in self.blocks: 710 | x = blk(x) 711 | if self.global_pool: 712 | x = x[:, 1:, :].mean(dim=1) # global pool without cls 
token 713 | outcome = self.fc_norm(x) 714 | else: 715 | x = self.norm(x) 716 | outcome = x[:, 0] 717 | 718 | for blk_LIDAR in self.blocks: 719 | x_LIDAR = blk_LIDAR(x_LIDAR) 720 | if self.global_pool_LIDAR: 721 | x_LIDAR = x_LIDAR[:, 1:, :].mean(dim=1) # global pool without cls token 722 | outcome_LIDAR = self.fc_norm(x_LIDAR) 723 | else: 724 | x_LIDAR = self.norm(x_LIDAR) 725 | outcome_LIDAR = x_LIDAR[:, 0] 726 | 727 | outcome_all = torch.cat((outcome, outcome_LIDAR),dim = 1) 728 | return outcome_all 729 | 730 | def forward(self, x, x_LIDAR, t): 731 | x = self.forward_features(x, x_LIDAR, t) 732 | x = self.head(x) 733 | return x 734 | 735 | 736 | def mae_vit_HSIandLIDAR_patch3(**kwargs): 737 | model = MaskedAutoencoderViT( 738 | patch_size=1, 739 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 740 | return model 741 | 742 | def vit_HSI_LIDAR_patch3(**kwargs): 743 | model = vit_HSI_LIDAR( 744 | patch_size=1, 745 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 746 | return model -------------------------------------------------------------------------------- /pos_embed.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Position embedding utils 3 | # -------------------------------------------------------- 4 | 5 | import numpy as np 6 | 7 | import torch 8 | 9 | # -------------------------------------------------------- 10 | # 2D sine-cosine position embedding 11 | # References: 12 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 13 | # MoCo v3: https://github.com/facebookresearch/moco-v3 14 | # -------------------------------------------------------- 15 | def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 16 | """ 17 | grid_size: int of the grid height and width 18 | return: 19 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o 
cls_token) 20 | """ 21 | grid_h = np.arange(grid_size, dtype=np.float32) 22 | grid_w = np.arange(grid_size, dtype=np.float32) 23 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 24 | grid = np.stack(grid, axis=0) 25 | 26 | grid = grid.reshape([2, 1, grid_size, grid_size]) 27 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 28 | # pos_embed = 29 | if cls_token: 30 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 31 | return pos_embed 32 | 33 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 34 | """ 35 | grid_size: int of the grid height and width 36 | return: 37 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 38 | """ 39 | grid_h = np.arange(grid_size, dtype=np.float32) 40 | grid_w = np.arange(grid_size, dtype=np.float32) 41 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 42 | grid = np.stack(grid, axis=0) 43 | 44 | grid = grid.reshape([2, 1, grid_size, grid_size]) 45 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 46 | if cls_token: 47 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 48 | return pos_embed 49 | 50 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 51 | assert embed_dim % 2 == 0 52 | 53 | # use half of dimensions to encode grid_h 54 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 55 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 56 | 57 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 58 | return emb 59 | 60 | 61 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 62 | """ 63 | embed_dim: output dimension for each position 64 | pos: a list of positions to be encoded: size (M,) 65 | out: (M, D) 66 | """ 67 | assert embed_dim % 2 == 0 68 | omega = np.arange(embed_dim // 2, dtype=np.float32) 69 | omega /= embed_dim / 2. 70 | omega = 1. 
/ 10000**omega # (D/2,) 71 | 72 | pos = pos.reshape(-1) # (M,) 73 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 74 | 75 | emb_sin = np.sin(out) # (M, D/2) 76 | emb_cos = np.cos(out) # (M, D/2) 77 | 78 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 79 | return emb 80 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 81 | """ 82 | grid_size: int of the grid height and width 83 | return: 84 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 85 | """ 86 | grid_l = np.arange(grid_size, dtype=np.float32) 87 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l) 88 | if cls_token: 89 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 90 | return pos_embed 91 | 92 | # -------------------------------------------------------- 93 | # Interpolate position embeddings for high-resolution 94 | # References: 95 | # DeiT: https://github.com/facebookresearch/deit 96 | # -------------------------------------------------------- 97 | def interpolate_pos_embed(model, checkpoint_model): 98 | if 'pos_embed' in checkpoint_model: 99 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 100 | embedding_size = pos_embed_checkpoint.shape[-1] 101 | num_patches = model.patch_embed.num_patches 102 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 103 | # height (== width) for the checkpoint position embedding 104 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 105 | # height (== width) for the new position embedding 106 | new_size = int(num_patches ** 0.5) 107 | # class_token and dist_token are kept unchanged 108 | if orig_size != new_size: 109 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 110 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 111 | # only the position tokens are interpolated 112 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 
113 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 114 | pos_tokens = torch.nn.functional.interpolate( 115 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 116 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 117 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 118 | checkpoint_model['pos_embed'] = new_pos_embed 119 | 120 | if 'pos_embed_LIDAR' in checkpoint_model: 121 | pos_embed_LIDAR_checkpoint = checkpoint_model['pos_embed_LIDAR'] 122 | embedding_size_LIDAR = pos_embed_LIDAR_checkpoint.shape[-1] 123 | num_patches_LIDAR = model.patch_embed_LIDAR.num_patches 124 | num_extra_tokens_LIDAR = model.pos_embed_LIDAR.shape[-2] - num_patches_LIDAR 125 | # height (== width) for the checkpoint position embedding 126 | orig_size_LIDAR = int((pos_embed_LIDAR_checkpoint.shape[-2] - num_extra_tokens_LIDAR) ** 0.5) 127 | # height (== width) for the new position embedding 128 | new_size_LIDAR = int(num_patches_LIDAR ** 0.5) 129 | # class_token and dist_token are kept unchanged 130 | if orig_size_LIDAR != new_size_LIDAR: 131 | print("Position interpolate from %dx%d to %dx%d" % (orig_size_LIDAR, orig_size_LIDAR, new_size_LIDAR, new_size_LIDAR)) 132 | extra_tokens_LIDAR = pos_embed_LIDAR_checkpoint[:, :num_extra_tokens_LIDAR] 133 | # only the position tokens are interpolated 134 | pos_tokens_LIDAR = pos_embed_LIDAR_checkpoint[:, num_extra_tokens_LIDAR:] 135 | pos_tokens_LIDAR = pos_tokens_LIDAR.reshape(-1, orig_size_LIDAR, orig_size_LIDAR, embedding_size_LIDAR).permute(0, 3, 1, 2) 136 | pos_tokens_LIDAR = torch.nn.functional.interpolate( 137 | pos_tokens_LIDAR, size=(new_size_LIDAR, new_size_LIDAR), mode='bicubic', align_corners=False) 138 | pos_tokens_LIDAR = pos_tokens_LIDAR.permute(0, 2, 3, 1).flatten(1, 2) 139 | new_pos_embed_LIDAR = torch.cat((extra_tokens_LIDAR, pos_tokens_LIDAR), dim=1) 140 | checkpoint_model['pos_embed_LIDAR'] = new_pos_embed_LIDAR 141 | 
# --------------------------------------------------------------------------------
# /record.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import numpy as np


def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, training_time_ae, testing_time_ae, path):
    """Append a summary of repeated experiments to the text file at ``path``.

    Args:
        oa_ae, aa_ae, kappa_ae: per-run overall accuracy, average accuracy and
            kappa values; written verbatim, then summarized as mean ± std.
        element_acc_ae: 2-D array-like of per-run rows; its column-wise mean
            and std are written.
        training_time_ae, testing_time_ae: per-run timings; their means are written.
        path: file to append to (created if missing).
    """
    element_mean = np.mean(element_acc_ae, axis=0)
    element_std = np.std(element_acc_ae, axis=0)
    lines = [
        'OAs for each iteration are:' + str(oa_ae) + '\n',
        'AAs for each iteration are:' + str(aa_ae) + '\n',
        'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n',
        'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)) + ' ± ' + str(np.std(oa_ae)) + '\n',
        'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)) + ' ± ' + str(np.std(aa_ae)) + '\n',
        'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n',
        'Total average Training time is: ' + str(np.mean(training_time_ae)) + '\n',
        'Total average Testing time is: ' + str(np.mean(testing_time_ae)) + '\n' + '\n',
        "Mean of all elements in confusion matrix: " + str(element_mean) + '\n',
        "Standard deviation of all elements in confusion matrix: " + str(element_std) + '\n',
    ]
    # Explicit UTF-8 so the '±' characters do not depend on the platform's
    # default locale encoding; `with` guarantees the handle is closed on error.
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(lines)


# --------------------------------------------------------------------------------
# /simple_tokenizer.py
# --------------------------------------------------------------------------------
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re
@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE merges file stored next to this module."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a unicode character.

    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    Printable bytes map to themselves; the rest are shifted above 255 so that
    no byte maps to a whitespace/control character the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(c) for c in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word[:-1], word[1:]))


def basic_clean(text):
    """Fix mojibake with ftfy, unescape (possibly double-escaped) HTML entities and strip."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip the ends."""
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer reading a CLIP-style gzipped merges file."""

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split('\n')
        # Skip the header line; keep the merges that fit a 49152-entry vocab
        # alongside the 256 byte symbols and the 2 special tokens.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Each byte symbol also exists in an end-of-word variant marked '</w>'.
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to ``token``; return space-separated BPE symbols."""
        if token in self.cache:
            return self.cache[token]
        # The last symbol carries the end-of-word marker so word-final merges apply.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Merge the lowest-ranked (earliest learned) pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean, lower-case and tokenize ``text``; return a list of token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; end-of-word markers become spaces."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


# --------------------------------------------------------------------------------
# /train.py
# --------------------------------------------------------------------------------
from scipy.io import savemat

import record
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import time
import datetime
import numpy as np
import os

from config import load_args
from data_read import readdata
from diffusion import create_diffusion
from generate_pic import generate
from hyper_dataset import HyperData
from augment import CenterResizeCrop
from util_CNN_clip import test_batch, pre_train
from pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
from model import mae_vit_HSIandLIDAR_patch3, vit_HSI_LIDAR_patch3, DDPM_LOSS
from clip_model import build_model
from collections import OrderedDict


# get class name
def read_classnames(text_file):
    """Return an OrderedDict of ``<folder>: <class name>`` pairs, in file order.

    Each non-empty line is split on single spaces: the first field is the key,
    the remaining fields joined by spaces are the class name. Blank lines are
    skipped so they cannot introduce an empty extra class.
    """
    classnames = OrderedDict()
    with open(text_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            folder, *name_parts = line.split(" ")
            classnames[folder] = " ".join(name_parts)
    return classnames


args = load_args()

mask_ratio = args.mask_ratio
windowsize = args.windowsize
dataset = args.dataset
data_type = args.type  # not named `type`: that would shadow the builtin
num_epoch = args.epochs
num_fine_tuned = args.fine_tuned_epochs
lr = args.lr
train_num_per = args.train_num_perclass
num_of_ex = 10
batch_size = args.batch_size

net_name = 'LDS2AE'
day = datetime.datetime.now()
day_str = day.strftime('%m_%d_%H_%M')
halfsize = (windowsize - 1) // 2
val_num = 1000
# Initial read: only the ground-truth map and sample count are needed here.
_, _, _, _, _, _, _, _, _, _, _, _, _, gt, s = readdata(data_type, dataset, windowsize, train_num_per, val_num, 0)
num_of_samples = int(s * 0.2)
nclass = np.max(gt).astype(np.int64)
print(nclass)

in_chans_LIDAR = 2 if args.dataset == 'Muufl' else 1

size = 1

KAPPA = []
OA = []
AA = []
TRAINING_TIME = []
TESTING_TIME = []
ELEMENT_ACC = np.zeros((num_of_ex, nclass))
af_result = np.zeros([nclass + 3, num_of_ex])

# savemat / torch.save / open() do not create missing directories.
for out_dir in ('result/' + dataset, 'model/' + dataset, 'records/' + dataset):
    os.makedirs(out_dir, exist_ok=True)

for num in range(0, num_of_ex):
    print('num:', num)
    (train_image, train_image_LIDAR, train_label, validation_image1, validation_image_LIDAR1, validation_label1,
     nTrain_perClass, nvalid_perClass, train_index, val_index, index, image, image_LiDAR, gt, s) = readdata(
        data_type, dataset, windowsize, train_num_per, num_of_samples, num)
    # Keep at most 200 validation patches for per-epoch model selection
    # (choice without replacement fails if fewer samples exist).
    n_val = min(200, validation_image1.shape[0])
    ind = np.random.choice(validation_image1.shape[0], n_val, replace=False)
    validation_image = validation_image1[ind]
    validation_image_LIDAR = validation_image_LIDAR1[ind]
    validation_label = validation_label1[ind]
    # With zero validation counts, test_batch evaluates every non-training pixel.
    nvalid_perClass = np.zeros_like(nvalid_perClass)
    nband = train_image.shape[3]

    # Move the last (channel) axis to position 1, as test_batch also does.
    train_image = np.transpose(train_image, (0, 3, 1, 2))
    train_image_LIDAR = np.transpose(train_image_LIDAR, (0, 3, 1, 2))

    validation_image = np.transpose(validation_image, (0, 3, 1, 2))
    validation_image1 = np.transpose(validation_image1, (0, 3, 1, 2))
    validation_image_LIDAR = np.transpose(validation_image_LIDAR, (0, 3, 1, 2))
    validation_image_LIDAR1 = np.transpose(validation_image_LIDAR1, (0, 3, 1, 2))

    if args.augment:
        transform_train = [CenterResizeCrop(scale_begin=args.scale, windowsize=windowsize)]
        untrain_dataset = HyperData((train_image, train_image_LIDAR, train_label), transform_train)
    else:
        untrain_dataset = TensorDataset(torch.tensor(train_image), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
    untrain_loader = DataLoader(dataset=untrain_dataset, batch_size=batch_size, shuffle=True)

    print("=> creating model '{}'".format(net_name))

    ######################## pre-train ########################
    diffusion = create_diffusion(timestep_respacing="1000")
    net = mae_vit_HSIandLIDAR_patch3(
        img_size=(windowsize, windowsize), in_chans=nband, in_chans_LIDAR=in_chans_LIDAR,
        hid_chans=args.hid_chans, hid_chans_LIDAR=args.hid_chans_LIDAR,
        embed_dim=args.encoder_dim, depth=args.encoder_depth, num_heads=args.encoder_num_heads,
        mlp_ratio=args.mlp_ratio, decoder_embed_dim=args.decoder_dim, decoder_depth=args.decoder_depth,
        decoder_num_heads=args.decoder_num_heads, nb_classes=nclass, global_pool=False)

    net.cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

    tic1 = time.time()
    for epoch in range(num_epoch):
        net.train()
        total_loss = 0.0
        for idx, (x, x_LIDAR, y) in enumerate(untrain_loader):

            x = x.cuda()
            x_LIDAR = x_LIDAR.cuda()
            y = y.cuda()
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],)).cuda()
            model_kwargs = dict(y=y)
            model_kwargs['mask_ratio'] = mask_ratio

            # Method 1 of forward process: noise both modalities at the same timestep t.
            noise = torch.randn_like(x)
            x_t = diffusion.q_sample(x, t, noise)
            noise_LIDAR = torch.randn_like(x_LIDAR)
            x_LIDAR_t = diffusion.q_sample(x_LIDAR, t, noise_LIDAR)

            pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR = net(x_t, x_LIDAR_t, t, **model_kwargs)

            loss_mse_m, loss_mse_LIDAR_m, loss_mse_v, loss_mse_LIDAR_v = DDPM_LOSS(pred_imgs, pred_imgs_LIDAR, x, x_LIDAR, mask, mask_LIDAR, size)

            # Reconstruction-only objective: the classification loss is deliberately not used.
            loss = 0.1 * (loss_mse_m + loss_mse_LIDAR_m) + loss_mse_v + loss_mse_LIDAR_v

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item(): summing tensors would keep every batch's autograd graph alive.
            total_loss += loss.item()

        scheduler.step()
        total_loss = total_loss / (idx + 1)
        print('epoch:', epoch,
              'loss:', total_loss)

    toc1 = time.time()
    # Built after training so it is never stale (or undefined when num_epoch == 0).
    state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': num_epoch - 1}
    torch.save(state, './net.pt')

    # ######################## finetune # ########################
    # NOTE(review): the Houston class-name file is used for every dataset.
    classnames_dict = read_classnames("./classnames_houston.txt")
    classnames = list(classnames_dict.values())
    model = build_model(img_size=(windowsize, windowsize), in_chans=nband, in_chans_LIDAR=in_chans_LIDAR,
                        hid_chans=args.hid_chans, hid_chans_LIDAR=args.hid_chans_LIDAR,
                        embed_dim=args.encoder_dim, depth=args.encoder_depth, num_heads=args.encoder_num_heads,
                        mlp_ratio=args.mlp_ratio, num_classes=nclass, global_pool=False).cuda()

    tic2 = time.time()

    optimizer = optim.Adam(model.parameters(), lr=args.fine_tuned_lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 131], gamma=0.1, last_epoch=-1)
    model = pre_train(model, train_image, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR,
                      validation_label, num_fine_tuned, optimizer, scheduler, batch_size, diffusion, classnames,
                      val=False)
    toc2 = time.time()
    # Restore the best-validation checkpoint written by pre_train.
    model.load_state_dict(torch.load('Best_val_model/net_params.pkl'))
    state_finetune = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    (true_cla, overall_accuracy, average_accuracy, kappa, true_label, test_pred, test_index, cm,
     pred_array) = test_batch(model.eval(), image, image_LiDAR, index, 128, nTrain_perClass, nvalid_perClass,
                              halfsize, diffusion, classnames)
    toc3 = time.time()

    af_result[:nclass, num] = true_cla
    af_result[nclass, num] = overall_accuracy
    af_result[nclass + 1, num] = average_accuracy
    af_result[nclass + 2, num] = kappa

    OA.append(overall_accuracy)
    AA.append(average_accuracy)
    KAPPA.append(kappa)
    TRAINING_TIME.append(toc1 - tic1 + toc2 - tic2)
    TESTING_TIME.append(toc3 - toc2)
    ELEMENT_ACC[num, :] = true_cla
    classification_map, gt_map = generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred,
                                          overall_accuracy, halfsize, dataset, day_str, num, net_name)
    result_tag = '40_11_0.7_p3' + str(overall_accuracy)
    file_name = 'result/' + dataset + '/' + result_tag + '.mat'
    savemat(file_name, {'map': classification_map})
    torch.save(state_finetune, 'model/' + dataset + '/' + result_tag + 'net.pt')
result = np.mean(af_result, axis=1)
print(result)
print("--------" + net_name + " Training Finished-----------")
record.record_output(OA, AA, KAPPA, ELEMENT_ACC, TRAINING_TIME, TESTING_TIME,
                     'records/' + dataset + '/' + net_name + '_' + day_str + '_' + str(args.epochs) + '_'
                     + str(args.fine_tuned_epochs) + '_train_num:' + str(train_image.shape[0])
                     + '_windowsize:' + str(windowsize) + '_mask_ratio_' + str(mask_ratio)
                     + '_temperature_' + str(args.temperature) + '_augment_' + str(args.augment)
                     + '_aug_scale_' + str(args.scale) + '_loss_ratio_' + str(args.cls_loss_ratio)
                     + '_decoder_dim_' + str(args.decoder_dim) + '_decoder_depth_' + str(args.decoder_depth)
                     + '.txt')


# --------------------------------------------------------------------------------
# /util_CNN_clip.py
# --------------------------------------------------------------------------------
import os

import torch
import numpy as np
from sklearn import metrics
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from pkg_resources import packaging

from simple_tokenizer import SimpleTokenizer as _Tokenizer

ce_loss = nn.CrossEntropyLoss()
_tokenizer = _Tokenizer()

# Prompt prefix for the text branch. Variants that were tried:
#   p1/p4 "a patch of a", p2 "a fine grained patch of a", p3 (used) below.
CTX_INIT = "a multimodal fusion patch of a"


def text_valid(classnames, ctx_init=CTX_INIT):
    """Tokenize one prompt per class, in class order.

    Each prompt is ``"<ctx_init> <class name>."`` with underscores replaced
    by spaces. Returns the concatenation of `tokenize` outputs, i.e. a
    (n_cls, 77) integer tensor.
    """
    prompts = [ctx_init + " " + name.replace("_", " ") + "." for name in classnames]
    return torch.cat([tokenize(p) for p in prompts])  # (n_cls, n_tkn)


def tr_acc(model, image, image_LIDAR, label, diffusion, classnames):
    """Evaluate ``model`` against every class prompt.

    Returns ``(accuracy, mean cross-entropy loss over batches)``. Runs without
    autograd; the caller is responsible for putting the model in eval mode.
    """
    train_dataset = TensorDataset(torch.tensor(image), torch.tensor(image_LIDAR), torch.tensor(label))
    train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=False)
    train_loss = 0
    corr_num = 0
    # Class prompts do not depend on the batch, so tokenize them once.
    text = text_valid(classnames).cuda()
    with torch.no_grad():
        for idx, (image_batch, image_LIDAR_batch, label_batch) in enumerate(train_loader):
            trans_image_batch = image_batch.cuda()
            trans_image_LIDAR_batch = image_LIDAR_batch.cuda()
            label_batch = label_batch.cuda()
            t = torch.randint(0, diffusion.num_timesteps, (trans_image_batch.shape[0],)).cuda()
            logits, _ = model(trans_image_batch, trans_image_LIDAR_batch, t, text)

            if isinstance(logits, tuple):
                logits = logits[-1]
            pred = torch.max(logits, dim=1)[1]
            loss = ce_loss(logits, label_batch)
            train_loss = train_loss + loss.cpu().data.numpy()
            corr_num = torch.eq(pred, label_batch).float().sum().cpu().numpy() + corr_num
    return corr_num / image.shape[0], train_loss / (idx + 1)


def tokenize(texts, context_length: int = 77, truncate: bool = False):
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

    Raises
    ------
    RuntimeError if a text is longer than ``context_length`` and ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def text_target(classnames, batch_target, ctx_init=CTX_INIT):
    """Tokenize one prompt per sample, using the class name of each label.

    ``batch_target`` is a 1-D tensor of class indices into ``classnames``.
    Returns a (batch, 77) integer tensor, row i describing sample i.
    """
    names = [name.replace("_", " ") for name in classnames]
    prompts = [ctx_init + " " + names[label] + "." for label in batch_target.tolist()]
    return torch.cat([tokenize(p) for p in prompts])  # (batch, n_tkn)


def pre_train(model, train_image, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR, validation_label, epoch, optimizer, scheduler, bs, diffusion, classnames, val=False, ):
    """Fine-tune ``model`` with a symmetric image/text contrastive loss.

    After every epoch, train and validation accuracy are computed with
    `tr_acc`. The weights with the best validation accuracy so far are saved
    to ``Best_val_model/net_params.pkl``.

    Returns the model, plus ``[Train_loss, Train_acc, Val_loss, Val_acc]``
    histories when ``val`` is True.
    """
    train_dataset = TensorDataset(torch.tensor(train_image), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
    train_loader = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
    Train_loss = []
    Train_acc = []
    Val_loss = []
    Val_acc = []
    # -1 guarantees the first epoch writes a checkpoint (even at 0% accuracy);
    # the caller reloads this file unconditionally.
    BestAcc = -1
    os.makedirs('Best_val_model', exist_ok=True)
    for i in range(epoch):
        model.train()
        train_loss = 0.0
        for idx, (image_batch, image_LIDAR_batch, label_batch) in enumerate(train_loader):

            trans_image_batch = image_batch.cuda()
            trans_image_LIDAR_batch = image_LIDAR_batch.cuda()
            label_batch = label_batch.cuda()
            text = text_target(classnames, label_batch).cuda()
            t = torch.randint(0, diffusion.num_timesteps, (trans_image_batch.shape[0],)).cuda()
            logits_per_image_x, logits_per_text = model(trans_image_batch, trans_image_LIDAR_batch, t, text)
            # Row k of the image logits should match prompt k, which text_target built from sample k.
            labels = torch.arange(len(logits_per_image_x)).to(logits_per_image_x.device)
            loss = ce_loss(logits_per_image_x, labels)
            loss += ce_loss(logits_per_text, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item(): summing tensors would keep every batch's autograd graph alive.
            train_loss += loss.item()
        scheduler.step()
        train_loss = train_loss / (idx + 1)
        train_acc, tr_loss = tr_acc(model.eval(), train_image, train_image_LIDAR, train_label, diffusion, classnames)
        val_acc, val_loss = tr_acc(model.eval(), validation_image, validation_image_LIDAR, validation_label, diffusion, classnames)
        if val_acc > BestAcc:
            torch.save(model.state_dict(), 'Best_val_model/' + 'net_params.pkl')
            BestAcc = val_acc
        print("epoch {}, training loss: {:.4f}, train acc:{:.4f}, valid acc:{:.4f}".format(i, train_loss, train_acc * 100, val_acc * 100))

        if val:
            Train_loss.append(tr_loss)
            Val_loss.append(val_loss)
            Train_acc.append(train_acc)
            Val_acc.append(val_acc)
    if val:
        return model, [Train_loss, Train_acc, Val_loss, Val_acc]
    return model


def test_batch(model, image, image_LIDAR, index, BATCH_SIZE, nTrain_perClass, nvalid_perClass, halfsize, diffusion, classnames):
    """Classify every test pixel and compute accuracy metrics.

    For class c, the test pixels are the rows of ``index[c]`` after its first
    ``nTrain_perClass[c] + nvalid_perClass[c]`` entries; each row gives the
    (row, col) centre of a ``2*halfsize+1`` square window cut from ``image``
    and ``image_LIDAR``.

    Returns ``(per-class accuracy %, OA %, AA %, kappa %, true_label,
    predict_label, test_index, confusion_matrix, softmax scores)``.
    """
    nclass = len(index)
    ind_parts = []
    label_parts = []
    for c in range(nclass):
        part = index[c][nTrain_perClass[c] + nvalid_perClass[c]:, :]
        ind_parts.append(part)
        label_parts.append(np.full(part.shape[0], c, dtype=np.int32))
    ind = np.concatenate(ind_parts, axis=0)
    true_label = np.concatenate(label_parts, axis=0)
    test_index = np.copy(ind)
    length = ind.shape[0]
    # Pad with randomly repeated test pixels so every batch is full; the padded
    # predictions are dropped again below.
    if length % BATCH_SIZE != 0:
        add_num = BATCH_SIZE - length % BATCH_SIZE
        # Sampling without replacement is impossible when add_num > length.
        add_ind = np.random.choice(range(length), add_num, replace=add_num > length)
        ind = np.concatenate((ind, ind[add_ind]), axis=0)

    pred_array = np.zeros([ind.shape[0], nclass], dtype=np.float32)
    n = ind.shape[0] // BATCH_SIZE
    windowsize = 2 * halfsize + 1
    image_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image.shape[2]], dtype=np.float32)
    image_LIDAR_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image_LIDAR.shape[2]], dtype=np.float32)
    # Class prompts are identical for every batch; tokenize once.
    text = text_valid(classnames).cuda()
    with torch.no_grad():
        for i in range(n):
            for j in range(BATCH_SIZE):
                m = ind[BATCH_SIZE * i + j, :]
                rows = slice(m[0] - halfsize, m[0] + halfsize + 1)
                cols = slice(m[1] - halfsize, m[1] + halfsize + 1)
                image_batch[j, :, :, :] = image[rows, cols, :]
                image_LIDAR_batch[j, :, :, :] = image_LIDAR[rows, cols, :]
            # Channels-last batch -> channels-first for the model.
            image_b = np.transpose(image_batch, (0, 3, 1, 2))
            image_LIDAR_b = np.transpose(image_LIDAR_batch, (0, 3, 1, 2))

            t = torch.randint(0, diffusion.num_timesteps, (image_b.shape[0],)).cuda()
            logits, _ = model(torch.tensor(image_b).cuda(), torch.tensor(image_LIDAR_b).cuda(), t, text)
            if isinstance(logits, tuple):
                logits = logits[-1]
            pred_array[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] = torch.softmax(logits, dim=1).cpu().data.numpy()
    pred_array = pred_array[:length]
    predict_label = np.argmax(pred_array, axis=1)

    # Fixed label set: keeps the matrix nclass x nclass even when some class is
    # absent from both true_label and predict_label.
    confusion_matrix = metrics.confusion_matrix(true_label, predict_label, labels=np.arange(nclass))
    overall_accuracy = metrics.accuracy_score(true_label, predict_label)

    test_num_class = np.sum(confusion_matrix, 1)
    test_num = np.sum(test_num_class)
    num1 = np.sum(confusion_matrix, 0)
    po = overall_accuracy
    # Chance agreement for Cohen's kappa from the row/column marginals.
    pe = np.sum(test_num_class * num1) / (test_num * test_num)
    kappa = (po - pe) / (1 - pe) * 100
    true_cla = np.true_divide(np.diag(confusion_matrix).astype(np.int64), test_num_class) * 100
    average_accuracy = np.average(true_cla)
    print('overall_accuracy: {0:f}'.format(overall_accuracy * 100))
    print('average_accuracy: {0:f}'.format(average_accuracy))
    print('kappa:{0:f}'.format(kappa))
    return true_cla, overall_accuracy * 100, average_accuracy, kappa, true_label, predict_label, test_index, confusion_matrix, pred_array
# --------------------------------------------------------------------------------