├── README.md
├── assets
│   ├── AAAI_DiffCLIP.pdf
│   └── DiffCLIP.png
├── augment.py
├── bpe_simple_vocab_16e6.txt.gz
├── build_EMP.py
├── classnames_houston.txt
├── clip_model.py
├── config.py
├── data_read.py
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   ├── respace.py
│   └── timestep_sampler.py
├── generate_pic.py
├── hyper_dataset.py
├── model.py
├── pos_embed.py
├── record.py
├── simple_tokenizer.py
├── train.py
└── util_CNN_clip.py
/README.md:
--------------------------------------------------------------------------------
1 | # DiffCLIP
2 | 
3 | **DiffCLIP: Few-shot Language-driven Multimodal Classifier** (AAAI 2025)
4 | 
9 | ## **Overview**
10 |
11 |
12 |
13 |
14 |
15 | ## **Getting Started**
16 |
17 | **Step 1: Clone the DiffCLIP repository:**
18 |
19 | To get started, first clone the DiffCLIP repository and navigate to the project directory:
20 |
21 | ```bash
22 | git clone https://github.com/icey-zhang/DiffCLIP
23 | cd DiffCLIP
24 | ```
25 |
26 | **Step 2: Environment Setup:**
27 |
28 | We recommend setting up a conda environment and installing the dependencies via pip. Use the following commands to set up your environment:
29 |
30 | ***Create and activate a new conda environment***
31 |
32 | ```bash
33 | conda create -n DiffCLIP python=3.9.17
34 | conda activate DiffCLIP
35 | ```
36 |
37 | ***Install the necessary packages***
38 | ```bash
39 | pip install torch
40 | ......
41 | ```
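
The repository does not pin its dependencies. Judging from the imports across the source files (`torch`, `numpy`, `scipy`, `scikit-learn`, `scikit-image`, `opencv-python`, `timm`), an environment along these lines should be sufficient; versions are an assumption, adjust as needed:

```bash
pip install torch numpy scipy scikit-learn scikit-image opencv-python timm
```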
42 |
43 | ### Prepare the dataset
44 |
45 | ```
46 | root
47 | ├── Trento
48 | │   ├── HSI.mat
49 | │   ├── LiDAR.mat
50 | │   ├── TRLabel.mat
51 | │   └── TSLabel.mat
52 | ├── ......
53 | 
54 | ```
55 |
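For reference, `data_read.py` reads each scene directly from these `.mat` files. A minimal sketch of what the Trento/Houston2013 branches of `load_data` expect (variable keys `HSI`, `LiDAR`, `TRLabel`, `TSLabel`) looks like this:

```python
import scipy.io as sio
import numpy as np

# Keys follow the Trento / Houston2013 branches of data_read.load_data
hsi   = sio.loadmat('Trento/HSI.mat')['HSI'].astype(np.float32)      # H x W x bands hyperspectral cube
lidar = sio.loadmat('Trento/LiDAR.mat')['LiDAR'].astype(np.float32)  # H x W LiDAR raster (a channel axis is added later)
label = (sio.loadmat('Trento/TRLabel.mat')['TRLabel']
         + sio.loadmat('Trento/TSLabel.mat')['TSLabel']).astype(np.int64)  # 0 = unlabeled, 1..C = classes
```
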
56 | ### Begin to train
57 |
58 | ```bash
59 | python train.py
60 | ```
61 |
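Training hyperparameters are defined in `config.py`. The command below simply makes the defaults explicit; set `--dataset` to `Trento`, `2013houston`, or `Muufl` depending on the scene you prepared:

```bash
python train.py --dataset 2013houston --train_num_perclass 40 --windowsize 11 --batch_size 128 --epochs 400 --lr 1e-4
```
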
62 | ## Citation
63 | If our code is helpful to you, please cite:
64 |
65 | ```
66 | @article{zhang2024multimodal,
67 | title={Multimodal Informative ViT: Information Aggregation and Distribution for Hyperspectral and LiDAR Classification},
68 | author={Zhang, Jiaqing and Lei, Jie and Xie, Weiying and Yang, Geng and Li, Daixun and Li, Yunsong},
69 | journal={IEEE Transactions on Circuits and Systems for Video Technology},
70 | year={2024},
71 | publisher={IEEE}
72 | }
73 |
74 | @inproceedings{zhange2025DiffCLIP,
75 | title={DiffCLIP: Few-shot Language-driven Multimodal Classifier},
76 | author={Zhang, Jiaqing and Cao, Mingxiang and Jiang, Kai and Yang, Xue},
77 | booktitle={AAAI},
78 | year={2025}
79 | }
79 |
80 | ```
81 |
82 |
--------------------------------------------------------------------------------
/assets/AAAI_DiffCLIP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/assets/AAAI_DiffCLIP.pdf
--------------------------------------------------------------------------------
/assets/DiffCLIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/assets/DiffCLIP.png
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 |
5 |
6 | class CenterResizeCrop(object):
7 | '''
8 |     Class that performs CenterResizeCrop: crops a centered window of random (odd) size between
9 |     scale_begin and windowsize, then resizes it back to the original patch size.
10 |     -------------------------------------------------------------------------------------
11 |     scale_begin: minimum crop size (pixels)
12 |     windowsize: patch window size (maximum crop size)
13 |
14 | -------------------------------------------------------------------------------------
15 | '''
16 | def __init__(self, scale_begin = 23, windowsize = 27):
17 | self.scale_begin = scale_begin
18 | self.windowsize = windowsize
19 |
20 | def __call__(self, image):
21 |
22 | length = np.array(range(self.scale_begin, self.windowsize+1, 2))
23 |
24 | row_center = int((self.windowsize-1)/2)
25 | col_center = int((self.windowsize-1)/2)
26 | row = image.shape[1]
27 | col = image.shape[2]
28 | # band = image.shape[0]
29 | s = np.random.choice(length, size = 1)
30 | halfsize_row = int((s-1)/2)
31 | halfsize_col = int((s-1)/2)
32 | r_image = image[:, row_center-halfsize_row : row_center+halfsize_row+1, col_center-halfsize_col : col_center+halfsize_col+1]
33 | r_image_transpose = cv2.resize(np.transpose(r_image, [1, 2, 0]), (row, col))
34 | if r_image_transpose.ndim != 3:
35 | r_image_transpose = np.expand_dims(r_image_transpose,2)
36 | r_image = np.transpose(r_image_transpose, [2,0,1])
37 | return r_image
38 |
39 | class RandomResizeCrop(object):
40 | def __init__(self, scale = [0.5, 1], probability = 0.5):
41 | self.scale = scale
42 | self.probability = probability
43 |
44 | def __call__(self, image):
45 |
46 | if random.uniform(0, 1) > self.probability:
47 | return image
48 | else:
49 | row = image.shape[1]
50 | col = image.shape[2]
51 | s = np.random.uniform(self.scale[0], self.scale[1])
52 | r_row = round(row * s)
53 | r_col = round(col * s)
54 | halfsize_row = int((r_row-1)/2)
55 | halfsize_col = int((r_col-1)/2)
56 | row_center =random.randint(halfsize_row, r_row - halfsize_row-1)
57 | col_center =random.randint(halfsize_col, r_col - halfsize_col-1)
58 | r_image = image[:, row_center-halfsize_row : row_center+halfsize_row+1, col_center-halfsize_col : col_center+halfsize_col+1]
59 | r_image = np.transpose(cv2.resize(np.transpose(r_image,[1,2,0]), (row, col)), [2,0,1])
60 | return r_image
61 |
62 | class CenterCrop(object):
63 | '''
64 |     Class that performs CenterCrop: crops a centered window of random size and zero-pads it back to the original patch size.
65 | -------------------------------------------------------------------------------------
66 | probability: The probability that the operation will be performed.
67 | scale_begin: min scale size
68 | windowsize: max scale size
69 |
70 | -------------------------------------------------------------------------------------
71 | '''
72 | def __init__(self, scale_begin = 23, windowsize = 27, probability = 0.5):
73 | self.scale_begin = scale_begin
74 | self.windowsize = windowsize
75 | self.probability = probability
76 |
77 | def __call__(self, image):
78 |
79 | if random.uniform(0, 1) > self.probability:
80 | return image
81 | else:
82 | length = np.array(range(self.scale_begin, self.windowsize, 2))
83 |
84 | row_center = int((self.windowsize-1)/2)
85 | col_center = int((self.windowsize-1)/2)
86 | s = np.random.choice(length, size = 1)
87 | halfsize_row = int((s-1)/2)
88 | halfsize_col = int((s-1)/2)
89 | r_image = image[:, row_center-halfsize_row : row_center+halfsize_row+1, col_center-halfsize_col : col_center+halfsize_col+1]
90 |
91 | # r_image = np.pad(r_image, ((0, 0), (row_center - halfsize_row, row_center - halfsize_row), (col_center - halfsize_col, col_center - halfsize_col)), 'edge')
92 | r_image = np.pad(r_image, ((0, 0), (row_center - halfsize_row, row_center - halfsize_row), (col_center - halfsize_col, col_center - halfsize_col)), 'constant', constant_values=0)
93 | return r_image
94 |
95 | class Cutout(object):
96 | """Randomly mask out one or more patches from an image.
97 | Args:
98 | n_holes (int): Number of patches to cut out of each image.
99 | length (int): The length (in pixels) of each square patch.
100 | """
101 | def __init__(self, n_holes, length):
102 | self.n_holes = n_holes
103 | self.length = length
104 |
105 | def __call__(self, img):
106 | """
107 | Args:
108 | img (Tensor): Tensor image of size (C, H, W).
109 | Returns:
110 | Tensor: Image with n_holes of dimension length x length cut out of it.
111 | """
112 | h = img.shape[1]
113 | w = img.shape[2]
114 | c = img.shape[0]
115 |
116 | mask = np.ones((h, w), np.float32)
117 |
118 | for n in range(self.n_holes):
119 |             # (x, y) is the center of the square patch
120 | y = np.random.randint(h)
121 | x = np.random.randint(w)
122 |
123 | y1 = np.clip(y - self.length // 2, 0, h)
124 | y2 = np.clip(y + self.length // 2, 0, h)
125 | x1 = np.clip(x - self.length // 2, 0, w)
126 | x2 = np.clip(x + self.length // 2, 0, w)
127 |
128 | mask[y1: y2, x1: x2] = 0.
129 |
130 |
131 | mask = np.tile(mask[np.newaxis,:,:], (c,1,1))
132 | img = img * mask
133 |
134 | return img
135 |
136 | def resize(train_image, size = (27,27)):
137 | r_image = np.zeros([train_image.shape[0], train_image.shape[1], size[0], size[1]], dtype = np.float32)
138 | for i in range(train_image.shape[0]):
139 | r_image[i] = np.transpose(cv2.resize(np.transpose(train_image[i],[1,2,0]), size), [2,0,1])
140 | return r_image
141 |
142 | def take_elements(image, location, windowsize):
143 | if windowsize == 1:
144 | if len(image.shape) == 3:
145 | spectral = np.zeros([location.shape[0],image.shape[2]], dtype = image.dtype)
146 | for i in range(location.shape[0]):
147 | spectral[i] = image[location[i][0], location[i][1]]
148 | else:
149 | spectral = np.zeros(location.shape[0], dtype = np.int32)
150 | for i in range(location.shape[0]):
151 | spectral[i] = image[location[i][0], location[i][1]]
152 | else:
153 | if len(image.shape) == 3:
154 | halfsize = int((windowsize - 1)/2)
155 | spectral = np.zeros([location.shape[0], windowsize, windowsize, image.shape[2]], dtype = image.dtype)
156 | for i in range(location.shape[0]):
157 | spectral[i,:,:,:] = image[location[i][0]-halfsize : location[i][0]+halfsize+1 , location[i][1]-halfsize : location[i][1]+halfsize+1, :]
158 | else:
159 | halfsize = int((windowsize - 1)/2)
160 | spectral = np.zeros([location.shape[0], windowsize, windowsize], dtype = image.dtype)
161 | for i in range(location.shape[0]):
162 | spectral[i,:,:] = image[location[i][0]-halfsize : location[i][0]+halfsize+1 , location[i][1]-halfsize : location[i][1]+halfsize+1]
163 | return spectral
164 |
165 |
166 |
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/icey-zhang/DiffCLIP/413caad6246caa63799fbd5053f7740ceb9a18c0/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/build_EMP.py:
--------------------------------------------------------------------------------
1 | # Building the Extended Morphological Profiles (EMP)
2 | from skimage.morphology import reconstruction
3 | from skimage.morphology import erosion
4 | from skimage.morphology import disk
5 | from skimage import util
6 | import numpy as np
7 | import skimage.morphology as sm
8 | def opening_by_reconstruction(image, se):
9 | """
10 | Performs an Opening by Reconstruction.
11 |
12 | Parameters:
13 | image: 2D matrix.
14 | se: structuring element
15 | Returns:
16 | 2D matrix of the reconstructed image.
17 | """
18 | eroded = erosion(image, se)
19 | reconstructed = reconstruction(eroded, image)
20 | return reconstructed
21 |
22 |
23 | def closing_by_reconstruction(image, se):
24 | """
25 | Performs a Closing by Reconstruction.
26 |
27 | Parameters:
28 | image: 2D matrix.
29 | se: structuring element
30 | Returns:
31 | 2D matrix of the reconstructed image.
32 | """
33 | obr = opening_by_reconstruction(image, se)
34 |
35 | obr_inverted = util.invert(obr)
36 | obr_inverted_eroded = erosion(obr_inverted, se)
37 | obr_inverted_eroded_rec = reconstruction(
38 | obr_inverted_eroded, obr_inverted)
39 | obr_inverted_eroded_rec_inverted = util.invert(obr_inverted_eroded_rec)
40 | return obr_inverted_eroded_rec_inverted
41 |
42 |
43 | def build_morphological_profiles(image, se_size=4, se_size_increment=2, num_openings_closings=4):
44 | """
45 | Build the morphological profiles for a given image.
46 |
47 | Parameters:
48 | base_image: 2d matrix, it is the spectral information part of the MP.
49 | se_size: int, initial size of the structuring element (or kernel). Structuring Element used: disk
50 | se_size_increment: int, structuring element increment step
51 | num_openings_closings: int, number of openings and closings by reconstruction to perform.
52 | Returns:
53 | emp: 3d matrix with both spectral (from the base_image) and spatial information
54 | """
55 | x, y = image.shape
56 |
57 | cbr = np.zeros(shape=(x, y, num_openings_closings))
58 | obr = np.zeros(shape=(x, y, num_openings_closings))
59 |
60 | it = 0
61 | tam = se_size
62 | while it < num_openings_closings:
63 | se = disk(tam)
64 | temp = closing_by_reconstruction(image, se)
65 | cbr[:, :, it] = temp[:, :]
66 | temp = opening_by_reconstruction(image, se)
67 | obr[:, :, it] = temp[:, :]
68 | tam += se_size_increment
69 | it += 1
70 |
71 | mp = np.zeros(shape=(x, y, (num_openings_closings*2)+1))
72 | cont = num_openings_closings - 1
73 | for i in range(num_openings_closings):
74 | mp[:, :, i] = cbr[:, :, cont]
75 | cont = cont - 1
76 |
77 | mp[:, :, num_openings_closings] = image[:, :]
78 |
79 | cont = 0
80 | for i in range(num_openings_closings+1, num_openings_closings*2+1):
81 | mp[:, :, i] = obr[:, :, cont]
82 | cont += 1
83 |
84 | return mp
85 |
86 |
87 | def build_emp(base_image, se_size=4, se_size_increment=2, num_openings_closings=4):
88 | """
89 | Build the extended morphological profiles for a given set of images.
90 |
91 | Parameters:
92 | base_image: 3d matrix, each 'channel' is considered for applying the morphological profile. It is the spectral information part of the EMP.
93 | se_size: int, initial size of the structuring element (or kernel). Structuring Element used: disk
94 | se_size_increment: int, structuring element increment step
95 | num_openings_closings: int, number of openings and closings by reconstruction to perform.
96 | Returns:
97 | emp: 3d matrix with both spectral (from the base_image) and spatial information
98 | """
99 | base_image_rows, base_image_columns, base_image_channels = base_image.shape
100 | se_size = se_size
101 | se_size_increment = se_size_increment
102 | num_openings_closings = num_openings_closings
103 | morphological_profile_size = (num_openings_closings * 2) + 1
104 | emp_size = morphological_profile_size * base_image_channels
105 | emp = np.zeros(
106 | shape=(base_image_rows, base_image_columns, emp_size))
107 |
108 | cont = 0
109 | for i in range(base_image_channels):
110 | # build MPs
111 | mp_temp = build_morphological_profiles(
112 | base_image[:, :, i], se_size, se_size_increment, num_openings_closings)
113 |
114 | aux = morphological_profile_size * (i+1)
115 |
116 | # build the EMP
117 | cont_aux = 0
118 | for k in range(cont, aux):
119 | emp[:, :, k] = mp_temp[:, :, cont_aux]
120 | cont_aux += 1
121 |
122 | cont = morphological_profile_size * (i+1)
123 |
124 | return emp
125 | def build_emp1(base_image, nScale=3):
126 | row = base_image.shape[0]
127 | col = base_image.shape[1]
128 | nPC = base_image.shape[2]
129 | emp = np.zeros([row, col,2*nScale*nPC])
130 | i = 0
131 | for iScale in range(nScale):
132 | for iPC in range(nPC):
133 | se = sm.disk( 3 * (iScale + 1))
134 | x=sm.opening(base_image[:, : , iPC],se)
135 | y=sm.closing(base_image[:, :, iPC ],se)
136 | emp[:,:,i] = x
137 | emp[:,:,i+1] = y
138 | i = i + 2
139 | return emp
140 |
141 |
142 |
143 |
144 |
--------------------------------------------------------------------------------
/classnames_houston.txt:
--------------------------------------------------------------------------------
1 | 0 Healthy grass
2 | 1 Stressed grass
3 | 2 Synthetic grass
4 | 3 Tree
5 | 4 Soil
6 | 5 Water
7 | 6 Residential
8 | 7 Commercial
9 | 8 Road
10 | 9 Highway
11 | 10 Railway
12 | 11 Park lot 1
13 | 12 Park lot 2
14 | 13 Tennis court
15 | 14 Running track
--------------------------------------------------------------------------------
/clip_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | import math
9 | from pos_embed import interpolate_pos_embed
10 | from timm.models.layers import trunc_normal_
11 | from model import vit_HSI_LIDAR_patch3
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, inplanes, planes, stride=1):
16 | super().__init__()
17 |
18 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu1 = nn.ReLU(inplace=True)
22 |
23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.relu2 = nn.ReLU(inplace=True)
26 |
27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28 |
29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31 | self.relu3 = nn.ReLU(inplace=True)
32 |
33 | self.downsample = None
34 | self.stride = stride
35 |
36 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38 | self.downsample = nn.Sequential(OrderedDict([
39 | ("-1", nn.AvgPool2d(stride)),
40 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41 | ("1", nn.BatchNorm2d(planes * self.expansion))
42 | ]))
43 |
44 | def forward(self, x: torch.Tensor):
45 | identity = x
46 |
47 | out = self.relu1(self.bn1(self.conv1(x)))
48 | out = self.relu2(self.bn2(self.conv2(out)))
49 | out = self.avgpool(out)
50 | out = self.bn3(self.conv3(out))
51 |
52 | if self.downsample is not None:
53 | identity = self.downsample(x)
54 |
55 | out += identity
56 | out = self.relu3(out)
57 | return out
58 |
59 |
60 | class AttentionPool2d(nn.Module):
61 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62 | super().__init__()
63 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64 | self.k_proj = nn.Linear(embed_dim, embed_dim)
65 | self.q_proj = nn.Linear(embed_dim, embed_dim)
66 | self.v_proj = nn.Linear(embed_dim, embed_dim)
67 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68 | self.num_heads = num_heads
69 |
70 | def forward(self, x):
71 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
72 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74 | x, _ = F.multi_head_attention_forward(
75 | query=x[:1], key=x, value=x,
76 | embed_dim_to_check=x.shape[-1],
77 | num_heads=self.num_heads,
78 | q_proj_weight=self.q_proj.weight,
79 | k_proj_weight=self.k_proj.weight,
80 | v_proj_weight=self.v_proj.weight,
81 | in_proj_weight=None,
82 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83 | bias_k=None,
84 | bias_v=None,
85 | add_zero_attn=False,
86 | dropout_p=0,
87 | out_proj_weight=self.c_proj.weight,
88 | out_proj_bias=self.c_proj.bias,
89 | use_separate_proj_weight=True,
90 | training=self.training,
91 | need_weights=False
92 | )
93 | return x.squeeze(0)
94 |
95 |
96 | class ModifiedResNet(nn.Module):
97 | """
98 | A ResNet class that is similar to torchvision's but contains the following changes:
99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101 | - The final pooling layer is a QKV attention instead of an average pool
102 | """
103 |
104 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105 | super().__init__()
106 | self.output_dim = output_dim
107 | self.input_resolution = input_resolution
108 |
109 | # the 3-layer stem
110 | self.conv1 = nn.Conv2d(width, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111 | self.bn1 = nn.BatchNorm2d(width // 2)
112 | self.relu1 = nn.ReLU(inplace=True)
113 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
114 | self.bn2 = nn.BatchNorm2d(width // 2)
115 | self.relu2 = nn.ReLU(inplace=True)
116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
117 | self.bn3 = nn.BatchNorm2d(width)
118 | self.relu3 = nn.ReLU(inplace=True)
119 | self.avgpool = nn.AvgPool2d(2)
120 |
121 | # residual layers
122 | self._inplanes = width # this is a *mutable* variable used during construction
123 | self.layer1 = self._make_layer(width, layers[0])
124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
127 |
128 | embed_dim = width * 32 # the ResNet feature dimension
129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
130 |
131 | def _make_layer(self, planes, blocks, stride=1):
132 | layers = [Bottleneck(self._inplanes, planes, stride)]
133 |
134 | self._inplanes = planes * Bottleneck.expansion
135 | for _ in range(1, blocks):
136 | layers.append(Bottleneck(self._inplanes, planes))
137 |
138 | return nn.Sequential(*layers)
139 |
140 | def forward(self, x):
141 | def stem(x):
142 | x = self.relu1(self.bn1(self.conv1(x)))
143 | x = self.relu2(self.bn2(self.conv2(x)))
144 | x = self.relu3(self.bn3(self.conv3(x)))
145 | x = self.avgpool(x)
146 | return x
147 |
148 | x = x.type(self.conv1.weight.dtype)
149 | x = stem(x)
150 | x = self.layer1(x)
151 | x = self.layer2(x)
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 | x = self.attnpool(x)
155 |
156 | return x
157 |
158 |
159 | class LayerNorm(nn.LayerNorm):
160 | """Subclass torch's LayerNorm to handle fp16."""
161 |
162 | def forward(self, x: torch.Tensor):
163 | orig_type = x.dtype
164 | ret = super().forward(x.type(torch.float32))
165 | return ret.type(orig_type)
166 |
167 |
168 | class QuickGELU(nn.Module):
169 | def forward(self, x: torch.Tensor):
170 | return x * torch.sigmoid(1.702 * x)
171 |
172 |
173 | class ResidualAttentionBlock(nn.Module):
174 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
175 | super().__init__()
176 |
177 | self.attn = nn.MultiheadAttention(d_model, n_head)
178 | self.ln_1 = LayerNorm(d_model)
179 | self.mlp = nn.Sequential(OrderedDict([
180 | ("c_fc", nn.Linear(d_model, d_model * 4)),
181 | ("gelu", QuickGELU()),
182 | ("c_proj", nn.Linear(d_model * 4, d_model)),
183 | ]))
184 | self.ln_2 = LayerNorm(d_model)
185 | self.attn_mask = attn_mask
186 |
187 | def attention(self, x: torch.Tensor):
188 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
189 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
190 |
191 | def forward(self, x: torch.Tensor):
192 | x = x + self.attention(self.ln_1(x))
193 | x = x + self.mlp(self.ln_2(x))
194 | return x
195 |
196 |
197 | class Transformer(nn.Module):
198 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
199 | super().__init__()
200 | self.width = width
201 | self.layers = layers
202 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])
203 | def forward(self, x: torch.Tensor):
204 | return self.resblocks(x)
205 |
206 |
207 | class cnn_encoder_A_and_B(nn.Module):
208 | def __init__(self, l1, l2,vision_patch_size,out_channels):
209 | super(cnn_encoder_A_and_B, self).__init__()
210 |
211 | self.conv1 = conv_bn_relu(l1, 32, 3, 1, 1)
212 | self.conv2 = conv_bn_relu(l2, 32, 3, 1, 1)
213 | self.conv1_1 = conv_bn_relu_max(32, out_channels, 3, 1, 1, vision_patch_size)
214 | self.conv2_1 = conv_bn_relu_max(32, out_channels, 3, 1, 1, vision_patch_size)
215 |
216 |
217 | self.xishu1 = torch.nn.Parameter(torch.Tensor([0.5])) # lamda
218 | self.xishu2 = torch.nn.Parameter(torch.Tensor([0.5])) # 1 - lamda
219 |
220 |
221 | def forward(self, x11, x21):
222 | x11 = self.conv1(x11)
223 | x21 = self.conv2(x21)
224 | x1_1 = self.conv1_1(x11) #64,64,8,8
225 | x2_1 = self.conv2_1(x21)
226 | x_add = x1_1*self.xishu1 + x2_1*self.xishu2
227 |
228 | return x_add,x1_1,x2_1
229 |
230 | class VisionTransformer(nn.Module):
231 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
232 | super().__init__()
233 | self.input_resolution = input_resolution
234 | self.output_dim = output_dim
235 |
236 | scale = width ** -0.5
237 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
238 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
239 | self.ln_pre = LayerNorm(width)
240 |
241 | self.transformer = Transformer(width, layers, heads)
242 |
243 | self.ln_post = LayerNorm(width)
244 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
245 |
246 | def forward(self, x: torch.Tensor,csa=True):
247 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
248 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
249 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
250 | x = x + self.positional_embedding.to(x.dtype)
251 | x = self.ln_pre(x)
252 |
253 | x = x.permute(1, 0, 2) # NLD -> LND
254 | x = self.transformer(x)
255 |
256 | x = x.permute(1, 0, 2) # LND -> NLD
257 | x_vit = x
258 |
259 |
260 | x = self.ln_post(x[:, 0, :])
261 |
262 | if self.proj is not None:
263 | x = x @ self.proj
264 |
265 | return x,x_vit
266 |
267 |
268 |
269 | class CLIP(nn.Module):
270 | def __init__(self,img_size,
271 | in_chans, in_chans_LIDAR,hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool,
272 | # text
273 | context_length: int,
274 | vocab_size: int,
275 | transformer_width: int,
276 | transformer_heads: int,
277 | transformer_layers: int
278 | ):
279 | super().__init__()
280 |
281 | self.context_length = context_length
282 |
283 | self.visual = vit_HSI_LIDAR_patch3(img_size=img_size,
284 | in_chans=in_chans, in_chans_LIDAR=in_chans_LIDAR,hid_chans=hid_chans,
285 | hid_chans_LIDAR=hid_chans_LIDAR, embed_dim=embed_dim, depth=depth, num_heads=num_heads,
286 | mlp_ratio=mlp_ratio,num_classes=num_classes, global_pool=global_pool).cuda()
287 |
288 | checkpoint = torch.load('./net.pt')
289 | checkpoint_model = checkpoint['model']
290 |
291 | state_dict = self.visual.state_dict()
292 | for k in ['head.weight', 'head.bias','head_LIDAR.weight', 'head_LIDAR.bias']:
293 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
294 | print(f"Removing key {k} from pretrained checkpoint")
295 | del checkpoint_model[k]
296 |
297 | interpolate_pos_embed(self.visual, checkpoint_model)
298 | msg = self.visual.load_state_dict(checkpoint_model, strict=False)
299 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
300 |
301 | # manually initialize fc layer
302 | trunc_normal_(self.visual.head.weight, std=2e-5)
303 |
304 | self.transformer = Transformer(
305 | width=transformer_width,
306 | layers=transformer_layers,
307 | heads=transformer_heads,
308 | attn_mask=self.build_attention_mask()
309 | )
310 |
311 | self.vocab_size = vocab_size
312 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
313 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
314 | self.ln_final = LayerNorm(transformer_width)
315 |
316 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
317 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.7))
318 |
319 | self.initialize_parameters()
320 |
321 | def initialize_parameters(self):
322 | nn.init.normal_(self.token_embedding.weight, std=0.02)
323 | nn.init.normal_(self.positional_embedding, std=0.01)
324 |
325 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
326 | attn_std = self.transformer.width ** -0.5
327 | fc_std = (2 * self.transformer.width) ** -0.5
328 | for block in self.transformer.resblocks:
329 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
330 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
331 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
332 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
333 | for block in self.transformer.resblocks[-1:]:
334 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
335 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
336 |
337 | if self.text_projection is not None:
338 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
339 |
340 |
341 | def build_attention_mask(self):
342 | # lazily create causal attention mask, with full attention between the vision tokens
343 | # pytorch uses additive attention mask; fill with -inf
344 | mask = torch.empty(self.context_length, self.context_length)
345 | mask.fill_(float("-inf"))
346 |         mask.triu_(1)  # keep -inf strictly above the diagonal; the diagonal and below become 0 (causal mask)
347 | return mask
348 |
349 | @property
350 | def dtype(self):
351 | return self.visual.dimen_redu[0].weight.dtype
352 |
353 |
354 | def encode_text(self, text):
355 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
356 |
357 | x = x + self.positional_embedding.type(self.dtype)
358 | x = x.permute(1, 0, 2) # NLD -> LND
359 | x = self.transformer(x)
360 | x = x.permute(1, 0, 2) # LND -> NLD
361 | x = self.ln_final(x).type(self.dtype)
362 |
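        # take features from the end-of-text (EOT) token position (argmax works because EOT has the highest token id)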
363 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
364 |
365 | return x
366 |
367 | def forward(self, image,image_li,t, text):
368 | text_features = self.encode_text(text)
369 | image_features_x = self.visual(image,image_li,t)
370 | image_features_x = image_features_x / image_features_x.norm(dim=1, keepdim=True)
371 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
372 |
373 |
374 | # cosine similarity as logits
375 | logit_scale = self.logit_scale.exp()
376 | logits_per_image_x= logit_scale * image_features_x @ text_features.t()
377 |
378 | logits_per_text = logits_per_image_x.t()
379 |
380 | return logits_per_image_x,logits_per_text
381 |
382 |
383 | def convert_weights(model: nn.Module):
384 | """Convert applicable model parameters to fp16"""
385 |
386 | def _convert_weights_to_fp16(l):
387 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
388 | l.weight.data = l.weight.data.half()
389 | if l.bias is not None:
390 | l.bias.data = l.bias.data.half()
391 |
392 | if isinstance(l, nn.MultiheadAttention):
393 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
394 | tensor = getattr(l, attr)
395 | if tensor is not None:
396 | tensor.data = tensor.data.half()
397 |
398 | for name in ["text_projection", "proj"]:
399 | if hasattr(l, name):
400 | attr = getattr(l, name)
401 | if attr is not None:
402 | attr.data = attr.data.half()
403 |
404 | model.apply(_convert_weights_to_fp16)
405 |
406 |
407 | def build_model(img_size, in_chans, in_chans_LIDAR, hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool=False):
408 |
409 | context_length = 77
410 | vocab_size = 49408
411 | transformer_width = 128
412 | transformer_heads = 8
413 | transformer_layers = 3
414 | model = CLIP(img_size,in_chans, in_chans_LIDAR,hid_chans, hid_chans_LIDAR, embed_dim, depth, num_heads, mlp_ratio,num_classes, global_pool,
415 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
416 | )
417 | return model
418 |
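# Illustrative call only (values assumed from config.py defaults and the Houston 2013 scene, not taken from train.py):
# model = build_model(img_size=11, in_chans=144, in_chans_LIDAR=1, hid_chans=128, hid_chans_LIDAR=128,
#                     embed_dim=128, depth=4, num_heads=8, mlp_ratio=2.0, num_classes=15)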
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn
3 |
4 | def load_args():
5 | parser = argparse.ArgumentParser()
6 |
7 | # Pre training
8 | parser.add_argument('--train_num_perclass', type=int, default=40)
9 | # Pre training
10 |
11 | parser.add_argument('--windowsize', type=int, default=11)
12 | parser.add_argument('--batch_size', type=int, default=128)
13 | parser.add_argument('--epochs', type=int, default=400)
14 | parser.add_argument('--fine_tuned_epochs', type=int, default=150)
15 | parser.add_argument('--lr', type=float, default=1e-4)
16 | parser.add_argument('--fine_tuned_lr', type=float, default=1e-4)
17 |
18 | parser.add_argument('--type', type=str, default='none')
19 | parser.add_argument('--dataset', type=str, default='2013houston') #Trento 2013houston Muufl Berlin
20 |
21 | # Network
22 | parser.add_argument('--mask_ratio', default=0.7, type=float,
23 |                         help='mask ratio (default: 0.7)')
24 | parser.add_argument('--mlp_ratio', default=2.0, type=float,
25 | help='mlp ratio of encoder/decoder (default: 2.0)')
26 | parser.add_argument('--hid_chans', default=128, type=int,
27 | help='hidden channels for dimension reduction (default: 128)')
28 | parser.add_argument('--hid_chans_LIDAR', default=128, type=int,
29 | help='hidden channels for dimension reduction (default: 128)')
30 | # Augmentation
31 | parser.add_argument('--augment', default=True, type=bool,
32 |                         help='whether to use data augmentation (default: True)')
33 | parser.add_argument('--scale', default=9, type=int,
34 |                         help='the minimum scale for center crop (default: 9)')
35 |
36 | # MAE encoder specifics
37 | parser.add_argument('--encoder_dim', default=128, type=int,
38 |                         help='feature dimension for encoder (default: 128)')
39 | parser.add_argument('--encoder_depth', default = 4, type=int,
40 | help='encoder_depth; number of blocks ')
41 | parser.add_argument('--encoder_num_heads', default=8, type=int,
42 | help='number of heads of encoder (default: 8)')
43 |
44 | # MAE decoder specifics
45 | parser.add_argument('--decoder_dim', default= 128, type=int,
46 |                         help='feature dimension for decoder (default: 128)')
47 | parser.add_argument('--decoder_depth', default = 3, type=int,
48 | help='decoder_depth; number of blocks ')
49 | parser.add_argument('--decoder_num_heads', default=8, type=int,
50 | help='number of heads of decoder (default: 8)')
51 |
52 | # options for supervised MAE
53 | parser.add_argument('--temperature', default=1.0, type=float,
54 | help='temperature for classification logits')
55 | parser.add_argument('--cls_loss_ratio', default=0.005, type=float,
56 | help='ratio for classification loss')
57 | args = parser.parse_args()
58 | return args
59 |
--------------------------------------------------------------------------------
/data_read.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import numpy as np
3 | from sklearn.decomposition import PCA
4 | from build_EMP import build_emp
5 |
6 | def pca_whitening(image, number_of_pc):
7 | shape = image.shape
8 | image = np.reshape(image, [shape[0]*shape[1], shape[2]])
9 | number_of_rows = shape[0]
10 | number_of_columns = shape[1]
11 | pca = PCA(n_components = number_of_pc)
12 | image = pca.fit_transform(image)
13 | pc_images = np.zeros(shape=(number_of_rows, number_of_columns, number_of_pc),dtype=np.float32)
14 | for i in range(number_of_pc):
15 | pc_images[:, :, i] = np.reshape(image[:, i], (number_of_rows, number_of_columns))
16 | return pc_images
17 |
18 | def load_data(dataset):
19 | if dataset == 'Trento':
20 | image_file_HSI = r'Trento/HSI.mat'
21 | image_file_LiDAR = r'Trento/LiDAR.mat'
22 | label_file_tr = r'Trento/TRLabel.mat'
23 | label_file_ts = r'Trento/TSLabel.mat'
24 | image_data_HSI = sio.loadmat(image_file_HSI)
25 | image_data_LiDAR = sio.loadmat(image_file_LiDAR)
26 | label_data_tr = sio.loadmat(label_file_tr)
27 | label_data_ts = sio.loadmat(label_file_ts)
28 | image_HSI = image_data_HSI['HSI']
29 | image_LiDAR = image_data_LiDAR['LiDAR']
30 | label = label_data_tr['TRLabel']+label_data_ts['TSLabel']
31 | elif dataset == '2013houston':
32 | image_file_HSI = r'./Houston2013/HSI.mat'
33 | image_file_LiDAR = r'./Houston2013/LiDAR.mat'
34 | label_file_tr = r'./Houston2013/TRLabel.mat'
35 | label_file_ts = r'./Houston2013/TSLabel.mat'
36 | image_data_HSI = sio.loadmat(image_file_HSI)
37 | image_data_LiDAR = sio.loadmat(image_file_LiDAR)
38 | label_data_tr = sio.loadmat(label_file_tr)
39 | label_data_ts = sio.loadmat(label_file_ts)
40 | image_HSI = image_data_HSI['HSI']
41 | image_LiDAR = image_data_LiDAR['LiDAR']
42 | label = label_data_tr['TRLabel']+label_data_ts['TSLabel']
43 | elif dataset == 'Muufl':
44 | image_file_HSI = r'Muufl/hsi.mat'
45 | image_file_LiDAR = r'Muufl/lidar.mat'
46 | label_file = r'Muufl/train_test_gt.mat'
47 | image_data_HSI = sio.loadmat(image_file_HSI)
48 | image_data_LiDAR = sio.loadmat(image_file_LiDAR)
49 | label_data = sio.loadmat(label_file)
50 | image_HSI = image_data_HSI['HSI']
51 | image_LiDAR = image_data_LiDAR['lidar']
52 | label = label_data['trainlabels']+label_data['testlabels']
53 | else:
54 |         raise Exception('dataset not found')
55 | image_HSI = image_HSI.astype(np.float32)
56 | image_LiDAR = image_LiDAR.astype(np.float32)
57 | label = label.astype(np.int64)
58 | return image_HSI, image_LiDAR, label
59 |
60 |
61 | def readdata(type, dataset, windowsize, train_num, val_num, num):
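    # type: preprocessing mode ('PCA', 'EMP', or 'none'); dataset: scene name; windowsize: spatial patch size;
    # train_num: training samples per class; val_num: total validation-sample budget; num: random seed for sampling.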
62 |
63 | or_image_HSI, or_image_LiDAR, or_label = load_data(dataset)
64 | # image = np.expand_dims(image, 2)
65 | halfsize = int((windowsize-1)/2)
66 | number_class = np.max(or_label).astype(np.int64)
67 | if dataset == 'Augsburg_SAR':
68 | pass
69 | elif dataset == 'Berlin' or dataset == 'Muufl':
70 | pass
71 | else:
72 | or_image_LiDAR = np.expand_dims(or_image_LiDAR, 2)
73 | # or_image_LiDAR = np.expand_dims(or_image_LiDAR, 2)
74 | image = np.pad(or_image_HSI, ((halfsize, halfsize), (halfsize, halfsize), (0, 0)), 'edge')
75 | image_LiDAR = np.pad(or_image_LiDAR, ((halfsize, halfsize), (halfsize, halfsize), (0, 0)), 'edge')
76 | label = np.pad(or_label, ((halfsize, halfsize), (halfsize, halfsize)), 'constant',constant_values=0)
77 |
78 | if type == 'PCA':
79 | image1 = pca_whitening(image, number_of_pc = 30)
80 | image_LiDAR1 = np.copy(image_LiDAR)
81 |     elif type == 'EMP':
82 |         image1 = pca_whitening(image, number_of_pc = 4)
83 |         num_openings_closings = 3
84 |         emp_image = build_emp(base_image=image1, num_openings_closings=num_openings_closings)
85 |         image1 = emp_image
86 |         image_LiDAR1 = np.copy(image_LiDAR)  # keep LiDAR unchanged so the normalization below has a defined input
86 | elif type == 'none':
87 | image1 = np.copy(image)
88 | image_LiDAR1 = np.copy(image_LiDAR)
89 | else:
90 |         raise Exception('type not found')
91 | image = (image1 - np.min(image1)) / (np.max(image1) - np.min(image1))
92 | image_LiDAR = (image_LiDAR1 - np.min(image_LiDAR1)) / (np.max(image_LiDAR1) - np.min(image_LiDAR1))
93 | #set the manner of selecting training samples
94 |
95 |
96 | n = np.zeros(number_class,dtype=np.int64)
97 | for i in range(number_class):
98 | temprow, tempcol = np.where(label == i + 1)
99 | n[i] = len(temprow)
100 | total_num = np.sum(n)
101 |
102 | nTrain_perClass = np.ones(number_class,dtype=np.int64) * train_num
103 | for i in range(number_class):
104 | if n[i] <= nTrain_perClass[i]:
105 | nTrain_perClass[i] = 15
106 |     ### number of validation samples per class
107 | nValidation_perClass = (n/total_num)*val_num
108 | nvalid_perClass = nValidation_perClass.astype(np.int32)
109 |
110 | index = []
111 | flag = 0
112 | fl = 0
113 |
114 |
115 | bands = np.size(image,2)
116 | bands_LIDAR = np.size(image_LiDAR,2)
117 | validation_image = np.zeros([np.sum(nvalid_perClass), windowsize, windowsize, bands], dtype=np.float32)
118 | validation_image_LIDAR = np.zeros([np.sum(nvalid_perClass), windowsize, windowsize, bands_LIDAR], dtype=np.float32)
119 | validation_label = np.zeros(np.sum(nvalid_perClass), dtype=np.int64)
120 | train_image = np.zeros([np.sum(nTrain_perClass), windowsize, windowsize, bands], dtype=np.float32)
121 | train_image_LIDAR = np.zeros([np.sum(nTrain_perClass), windowsize, windowsize, bands_LIDAR], dtype=np.float32)
122 | train_label = np.zeros(np.sum(nTrain_perClass),dtype=np.int64)
123 | train_index = np.zeros([np.sum(nTrain_perClass), 2], dtype = np.int32)
124 | val_index = np.zeros([np.sum(nvalid_perClass), 2], dtype = np.int32)
125 |
126 | for i in range(number_class):
127 | temprow, tempcol = np.where(label == i + 1)
128 | matrix = np.zeros([len(temprow),2], dtype=np.int64)
129 | matrix[:,0] = temprow
130 | matrix[:,1] = tempcol
131 | np.random.seed(num)
132 | np.random.shuffle(matrix)
133 |
134 | temprow = matrix[:,0]
135 | tempcol = matrix[:,1]
136 | index.append(matrix)
137 |
138 | for j in range(nTrain_perClass[i]):
139 | train_image[flag + j, :, :, :] = image[(temprow[j] - halfsize):(temprow[j] + halfsize + 1),
140 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)]
141 | train_image_LIDAR[flag + j, :, :, :] = image_LiDAR[(temprow[j] - halfsize):(temprow[j] + halfsize + 1),
142 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)]
143 | train_label[flag + j] = i
144 | train_index[flag + j] = matrix[j,:]
145 | flag = flag + nTrain_perClass[i]
146 |
147 | for j in range(nTrain_perClass[i], nTrain_perClass[i] + nvalid_perClass[i]):
148 | validation_image[fl + j-nTrain_perClass[i], :, :,:] = image[(temprow[j] - halfsize):(temprow[j] + halfsize + 1),
149 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)]
150 | validation_image_LIDAR[fl + j-nTrain_perClass[i], :, :,:] = image_LiDAR[(temprow[j] - halfsize):(temprow[j] + halfsize + 1),
151 | (tempcol[j] - halfsize):(tempcol[j] + halfsize + 1)]
152 | validation_label[fl + j-nTrain_perClass[i] ] = i
153 | val_index[fl + j-nTrain_perClass[i]] = matrix[j,:]
154 | fl =fl + nvalid_perClass[i]
155 |
156 |
157 | return train_image, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR, validation_label,\
158 | nTrain_perClass, nvalid_perClass,train_index, val_index, index, image, image_LiDAR, label,total_num
159 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
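# Example (following the OpenAI/DiT diffusion utilities this file is adapted from):
#   diffusion = create_diffusion(timestep_respacing="")   # 1000 steps with the linear beta schedule (the defaults below)
# A respacing string such as "250" keeps a 250-step subset of the schedule via space_timesteps().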
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=False,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.START_X if not predict_xstart else gd.ModelMeanType.EPSILON
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | )
46 |
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 |
7 | import math
8 |
9 | import numpy as np
10 | import torch
11 | import torch as th
12 | import enum
13 |
14 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
15 |
16 |
17 | def mean_flat(tensor):
18 | """
19 | Take the mean over all non-batch dimensions.
20 | """
21 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
22 |
23 |
24 | class ModelMeanType(enum.Enum):
25 | """
26 | Which type of output the model predicts.
27 | """
28 |
29 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
30 | START_X = enum.auto() # the model predicts x_0
31 | EPSILON = enum.auto() # the model predicts epsilon
32 |
33 |
34 | class ModelVarType(enum.Enum):
35 | """
36 | What is used as the model's output variance.
37 | The LEARNED_RANGE option has been added to allow the model to predict
38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39 | """
40 |
41 | LEARNED = enum.auto()
42 | FIXED_SMALL = enum.auto()
43 | FIXED_LARGE = enum.auto()
44 | LEARNED_RANGE = enum.auto()
45 |
46 |
47 | class LossType(enum.Enum):
48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49 | RESCALED_MSE = (
50 | enum.auto()
51 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
52 | KL = enum.auto() # use the variational lower-bound
53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54 |
55 | def is_vb(self):
56 | return self == LossType.KL or self == LossType.RESCALED_KL
57 |
58 |
59 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
62 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63 | return betas
64 |
65 |
66 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
67 | """
68 | This is the deprecated API for creating beta schedules.
69 | See get_named_beta_schedule() for the new library of schedules.
70 | """
71 | if beta_schedule == "quad":
72 | betas = (
73 | np.linspace(
74 | beta_start ** 0.5,
75 | beta_end ** 0.5,
76 | num_diffusion_timesteps,
77 | dtype=np.float64,
78 | )
79 | ** 2
80 | )
81 | elif beta_schedule == "linear":
82 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
83 | elif beta_schedule == "warmup10":
84 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
85 | elif beta_schedule == "warmup50":
86 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
87 | elif beta_schedule == "const":
88 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
89 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
90 | betas = 1.0 / np.linspace(
91 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
92 | )
93 | else:
94 | raise NotImplementedError(beta_schedule)
95 | assert betas.shape == (num_diffusion_timesteps,)
96 | return betas
97 |
98 |
99 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
100 | """
101 | Get a pre-defined beta schedule for the given name.
102 | The beta schedule library consists of beta schedules which remain similar
103 | in the limit of num_diffusion_timesteps.
104 | Beta schedules may be added, but should not be removed or changed once
105 | they are committed to maintain backwards compatibility.
106 | """
107 | if schedule_name == "linear":
108 | # Linear schedule from Ho et al, extended to work for any number of
109 | # diffusion steps.
110 | scale = 1000 / num_diffusion_timesteps
111 | return get_beta_schedule(
112 | "linear",
113 | beta_start=scale * 0.0001,
114 | beta_end=scale * 0.02,
115 | num_diffusion_timesteps=num_diffusion_timesteps,
116 | )
117 | elif schedule_name == "squaredcos_cap_v2":
118 | return betas_for_alpha_bar(
119 | num_diffusion_timesteps,
120 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
121 | )
122 | else:
123 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
124 |
125 |
126 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
127 | """
128 | Create a beta schedule that discretizes the given alpha_t_bar function,
129 | which defines the cumulative product of (1-beta) over time from t = [0,1].
130 | :param num_diffusion_timesteps: the number of betas to produce.
131 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
132 | produces the cumulative product of (1-beta) up to that
133 | part of the diffusion process.
134 | :param max_beta: the maximum beta to use; use values lower than 1 to
135 | prevent singularities.
136 | """
137 | betas = []
138 | for i in range(num_diffusion_timesteps):
139 | t1 = i / num_diffusion_timesteps
140 | t2 = (i + 1) / num_diffusion_timesteps
141 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
142 | return np.array(betas)
143 |
144 |
145 | class GaussianDiffusion:
146 | """
147 | Utilities for training and sampling diffusion models.
148 | Original ported from this codebase:
149 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
151 | starting at T and going to 1.
152 | """
153 |
154 | def __init__(
155 | self,
156 | *,
157 | betas,
158 | model_mean_type,
159 | model_var_type,
160 | loss_type
161 | ):
162 |
163 | self.model_mean_type = model_mean_type
164 | self.model_var_type = model_var_type
165 | self.loss_type = loss_type
166 |
167 | # Use float64 for accuracy.
168 | betas = np.array(betas, dtype=np.float64)
169 | self.betas = betas
170 | assert len(betas.shape) == 1, "betas must be 1-D"
171 | assert (betas > 0).all() and (betas <= 1).all()
172 |
173 | self.num_timesteps = int(betas.shape[0])
174 |
175 | alphas = 1.0 - betas
176 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
177 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
178 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
179 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
180 |
181 | # calculations for diffusion q(x_t | x_{t-1}) and others
182 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
183 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
184 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
185 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
186 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
187 |
188 | # calculations for posterior q(x_{t-1} | x_t, x_0)
189 | self.posterior_variance = (
190 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
191 | )
192 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
193 | self.posterior_log_variance_clipped = np.log(
194 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
195 | ) if len(self.posterior_variance) > 1 else np.array([])
196 |
197 | self.posterior_mean_coef1 = (
198 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
199 | )
200 | self.posterior_mean_coef2 = (
201 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
202 | )
203 |
204 | def q_mean_variance(self, x_start, t):
205 | """
206 | Get the distribution q(x_t | x_0).
207 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
208 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210 | """
211 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
212 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
214 | return mean, variance, log_variance
215 |
216 | def q_sample(self, x_start, t, noise=None):
217 | """
218 | Diffuse the data for a given number of diffusion steps.
219 | In other words, sample from q(x_t | x_0).
220 | :param x_start: the initial data batch.
221 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
222 | :param noise: if specified, the split-out normal noise.
223 | :return: A noisy version of x_start.
224 | """
225 | if noise is None:
226 | noise = th.randn_like(x_start)
227 | assert noise.shape == x_start.shape
228 | return (
229 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
230 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
231 | )
232 |
233 | def q_posterior_mean_variance(self, x_start, x_t, t):
234 | """
235 | Compute the mean and variance of the diffusion posterior:
236 | q(x_{t-1} | x_t, x_0)
237 | """
238 | assert x_start.shape == x_t.shape
239 | posterior_mean = (
240 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
241 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
242 | )
243 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
244 | posterior_log_variance_clipped = _extract_into_tensor(
245 | self.posterior_log_variance_clipped, t, x_t.shape
246 | )
247 | assert (
248 | posterior_mean.shape[0]
249 | == posterior_variance.shape[0]
250 | == posterior_log_variance_clipped.shape[0]
251 | == x_start.shape[0]
252 | )
253 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
254 |
255 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
256 | """
257 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
258 | the initial x, x_0.
259 | :param model: the model, which takes a signal and a batch of timesteps
260 | as input.
261 | :param x: the [N x C x ...] tensor at time t.
262 | :param t: a 1-D Tensor of timesteps.
263 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
264 | :param denoised_fn: if not None, a function which applies to the
265 | x_start prediction before it is used to sample. Applies before
266 | clip_denoised.
267 | :param model_kwargs: if not None, a dict of extra keyword arguments to
268 | pass to the model. This can be used for conditioning.
269 | :return: a dict with the following keys:
270 | - 'mean': the model mean output.
271 | - 'variance': the model variance output.
272 | - 'log_variance': the log of 'variance'.
273 | - 'pred_xstart': the prediction for x_0.
274 | """
275 | if model_kwargs is None:
276 | model_kwargs = {}
277 |
278 | B, C = x.shape[:2]
279 | assert t.shape == (B,)
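        # NOTE: `x_LIDAR` is referenced below but is not a parameter of p_mean_variance as
        # written; it has to be supplied from the enclosing scope for this method to run.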
280 | loss, model_output, model_output_LIDAR, logits, ids_restore, ids_restore_LIDAR, Cross_rec_loss = model(x, x_LIDAR, t, **model_kwargs)
281 | if isinstance(model_output, tuple):
282 | model_output, extra = model_output
283 | else:
284 | extra = None
285 |
286 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
287 |             assert model_output.shape == (B, C * 2, *x.shape[2:])
288 | model_output, model_var_values = th.split(model_output, C, dim=1)
289 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
290 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
291 | # The model_var_values is [-1, 1] for [min_var, max_var].
292 | frac = (model_var_values + 1) / 2
293 | model_log_variance = frac * max_log + (1 - frac) * min_log
294 | model_variance = th.exp(model_log_variance)
295 |
296 | model_output_LIDAR, model_var_values_LIDAR = th.split(model_output_LIDAR, 1, dim=1)
297 | min_log_LIDAR = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_LIDAR.shape)
298 | max_log_LIDAR = _extract_into_tensor(np.log(self.betas), t, x_LIDAR.shape)
299 | # The model_var_values is [-1, 1] for [min_var, max_var].
300 | frac_LIDAR = (model_var_values_LIDAR + 1) / 2
301 | model_log_variance_LIDAR = frac_LIDAR * max_log_LIDAR + (1 - frac_LIDAR) * min_log_LIDAR
302 | model_variance_LIDAR = th.exp(model_log_variance_LIDAR)
303 | else:
304 | model_variance, model_log_variance = {
305 | # for fixedlarge, we set the initial (log-)variance like so
306 | # to get a better decoder log likelihood.
307 | ModelVarType.FIXED_LARGE: (
308 | np.append(self.posterior_variance[1], self.betas[1:]),
309 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
310 | ),
311 | ModelVarType.FIXED_SMALL: (
312 | self.posterior_variance,
313 | self.posterior_log_variance_clipped,
314 | ),
315 | }[self.model_var_type]
316 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
317 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
318 |
319 | def process_xstart(x):
320 | if denoised_fn is not None:
321 | x = denoised_fn(x)
322 | if clip_denoised:
323 | return x.clamp(-1, 1)
324 | return x
325 |
326 | if self.model_mean_type == ModelMeanType.START_X:
327 | pred_xstart = process_xstart(model_output)
328 | pred_xstart_LIDAR = process_xstart(model_output_LIDAR)
329 | else:
330 | pred_xstart = process_xstart(
331 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
332 | )
333 | pred_xstart_LIDAR = process_xstart(
334 | self._predict_xstart_from_eps(x_t=x_LIDAR, t=t, eps=model_output_LIDAR)
335 | )
336 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
337 | model_mean_LIDAR, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart_LIDAR, x_t=x_LIDAR, t=t)
338 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
339 | return {
340 | "mean": model_mean,
341 | "variance": model_variance,
342 | "log_variance": model_log_variance,
343 | "pred_xstart": pred_xstart,
344 | "extra": extra,
345 | }
346 |
347 | def _predict_xstart_from_eps(self, x_t, t, eps):
348 | assert x_t.shape == eps.shape
349 | return (
350 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
351 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
352 | )
353 |
354 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
355 | return (
356 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
357 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
358 |
359 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
360 | """
361 | Compute the mean for the previous step, given a function cond_fn that
362 | computes the gradient of a conditional log probability with respect to
363 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
364 | condition on y.
365 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
366 | """
367 | gradient = cond_fn(x, t, **model_kwargs)
368 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
369 | return new_mean
370 |
371 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
372 | """
373 | Compute what the p_mean_variance output would have been, should the
374 | model's score function be conditioned by cond_fn.
375 | See condition_mean() for details on cond_fn.
376 | Unlike condition_mean(), this instead uses the conditioning strategy
377 | from Song et al (2020).
378 | """
379 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
380 |
381 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
382 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
383 |
384 | out = p_mean_var.copy()
385 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
386 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
387 | return out
388 |
389 | def p_sample(
390 | self,
391 | model,
392 | x,
393 | t,
394 | clip_denoised=True,
395 | denoised_fn=None,
396 | cond_fn=None,
397 | model_kwargs=None,
398 | ):
399 | """
400 | Sample x_{t-1} from the model at the given timestep.
401 | :param model: the model to sample from.
402 | :param x: the current tensor at x_{t-1}.
403 | :param t: the value of t, starting at 0 for the first diffusion step.
404 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
405 | :param denoised_fn: if not None, a function which applies to the
406 | x_start prediction before it is used to sample.
407 | :param cond_fn: if not None, this is a gradient function that acts
408 | similarly to the model.
409 | :param model_kwargs: if not None, a dict of extra keyword arguments to
410 | pass to the model. This can be used for conditioning.
411 | :return: a dict containing the following keys:
412 | - 'sample': a random sample from the model.
413 | - 'pred_xstart': a prediction of x_0.
414 | """
415 | out = self.p_mean_variance(
416 | model,
417 | x,
418 | t,
419 | clip_denoised=clip_denoised,
420 | denoised_fn=denoised_fn,
421 | model_kwargs=model_kwargs,
422 | )
423 | noise = th.randn_like(x)
424 | nonzero_mask = (
425 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
426 | ) # no noise when t == 0
427 | if cond_fn is not None:
428 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
429 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
430 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
431 |
432 | def p_sample_loop(
433 | self,
434 | model,
435 | shape,
436 | noise=None,
437 | clip_denoised=True,
438 | denoised_fn=None,
439 | cond_fn=None,
440 | model_kwargs=None,
441 | device=None,
442 | progress=False,
443 | ):
444 | """
445 | Generate samples from the model.
446 | :param model: the model module.
447 | :param shape: the shape of the samples, (N, C, H, W).
448 | :param noise: if specified, the noise from the encoder to sample.
449 | Should be of the same shape as `shape`.
450 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
451 | :param denoised_fn: if not None, a function which applies to the
452 | x_start prediction before it is used to sample.
453 | :param cond_fn: if not None, this is a gradient function that acts
454 | similarly to the model.
455 | :param model_kwargs: if not None, a dict of extra keyword arguments to
456 | pass to the model. This can be used for conditioning.
457 | :param device: if specified, the device to create the samples on.
458 | If not specified, use a model parameter's device.
459 | :param progress: if True, show a tqdm progress bar.
460 | :return: a non-differentiable batch of samples.
461 | """
462 | final = None
463 | for sample in self.p_sample_loop_progressive(
464 | model,
465 | shape,
466 | noise=noise,
467 | clip_denoised=clip_denoised,
468 | denoised_fn=denoised_fn,
469 | cond_fn=cond_fn,
470 | model_kwargs=model_kwargs,
471 | device=device,
472 | progress=progress,
473 | ):
474 | final = sample
475 | return final["sample"]
476 |
477 | def p_sample_loop_progressive(
478 | self,
479 | model,
480 | shape,
481 | noise=None,
482 | clip_denoised=True,
483 | denoised_fn=None,
484 | cond_fn=None,
485 | model_kwargs=None,
486 | device=None,
487 | progress=False,
488 | ):
489 | """
490 | Generate samples from the model and yield intermediate samples from
491 | each timestep of diffusion.
492 | Arguments are the same as p_sample_loop().
493 | Returns a generator over dicts, where each dict is the return value of
494 | p_sample().
495 | """
496 | if device is None:
497 | device = next(model.parameters()).device
498 | assert isinstance(shape, (tuple, list))
499 | if noise is not None:
500 | img = noise
501 | else:
502 | img = th.randn(*shape, device=device)
503 | indices = list(range(self.num_timesteps))[::-1]
504 |
505 | if progress:
506 | # Lazy import so that we don't depend on tqdm.
507 | from tqdm.auto import tqdm
508 |
509 | indices = tqdm(indices)
510 |
511 | for i in indices:
512 | t = th.tensor([i] * shape[0], device=device)
513 | with th.no_grad():
514 | out = self.p_sample(
515 | model,
516 | img,
517 | t,
518 | clip_denoised=clip_denoised,
519 | denoised_fn=denoised_fn,
520 | cond_fn=cond_fn,
521 | model_kwargs=model_kwargs,
522 | )
523 | yield out
524 | img = out["sample"]
525 |
526 | def ddim_sample(
527 | self,
528 | model,
529 | x,
530 | t,
531 | clip_denoised=True,
532 | denoised_fn=None,
533 | cond_fn=None,
534 | model_kwargs=None,
535 | eta=0.0,
536 | ):
537 | """
538 | Sample x_{t-1} from the model using DDIM.
539 | Same usage as p_sample().
540 | """
541 | out = self.p_mean_variance(
542 | model,
543 | x,
544 | t,
545 | clip_denoised=clip_denoised,
546 | denoised_fn=denoised_fn,
547 | model_kwargs=model_kwargs,
548 | )
549 | if cond_fn is not None:
550 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
551 |
552 | # Usually our model outputs epsilon, but we re-derive it
553 | # in case we used x_start or x_prev prediction.
554 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
555 |
556 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
557 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
558 | sigma = (
559 | eta
560 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
561 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
562 | )
563 | # Equation 12.
564 | noise = th.randn_like(x)
565 | mean_pred = (
566 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
567 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
568 | )
569 | nonzero_mask = (
570 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
571 | ) # no noise when t == 0
572 | sample = mean_pred + nonzero_mask * sigma * noise
573 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
574 |
575 | def ddim_reverse_sample(
576 | self,
577 | model,
578 | x,
579 | t,
580 | clip_denoised=True,
581 | denoised_fn=None,
582 | cond_fn=None,
583 | model_kwargs=None,
584 | eta=0.0,
585 | ):
586 | """
587 | Sample x_{t+1} from the model using DDIM reverse ODE.
588 | """
589 | assert eta == 0.0, "Reverse ODE only for deterministic path"
590 | out = self.p_mean_variance(
591 | model,
592 | x,
593 | t,
594 | clip_denoised=clip_denoised,
595 | denoised_fn=denoised_fn,
596 | model_kwargs=model_kwargs,
597 | )
598 | if cond_fn is not None:
599 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
600 | # Usually our model outputs epsilon, but we re-derive it
601 | # in case we used x_start or x_prev prediction.
602 | eps = (
603 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
604 | - out["pred_xstart"]
605 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
606 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
607 |
608 | # Equation 12. reversed
609 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
610 |
611 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
612 |
613 | def ddim_sample_loop(
614 | self,
615 | model,
616 | shape,
617 | noise=None,
618 | clip_denoised=True,
619 | denoised_fn=None,
620 | cond_fn=None,
621 | model_kwargs=None,
622 | device=None,
623 | progress=False,
624 | eta=0.0,
625 | ):
626 | """
627 | Generate samples from the model using DDIM.
628 | Same usage as p_sample_loop().
629 | """
630 | final = None
631 | for sample in self.ddim_sample_loop_progressive(
632 | model,
633 | shape,
634 | noise=noise,
635 | clip_denoised=clip_denoised,
636 | denoised_fn=denoised_fn,
637 | cond_fn=cond_fn,
638 | model_kwargs=model_kwargs,
639 | device=device,
640 | progress=progress,
641 | eta=eta,
642 | ):
643 | final = sample
644 | return final["sample"]
645 |
646 | def ddim_sample_loop_progressive(
647 | self,
648 | model,
649 | shape,
650 | noise=None,
651 | clip_denoised=True,
652 | denoised_fn=None,
653 | cond_fn=None,
654 | model_kwargs=None,
655 | device=None,
656 | progress=False,
657 | eta=0.0,
658 | ):
659 | """
660 | Use DDIM to sample from the model and yield intermediate samples from
661 | each timestep of DDIM.
662 | Same usage as p_sample_loop_progressive().
663 | """
664 | if device is None:
665 | device = next(model.parameters()).device
666 | assert isinstance(shape, (tuple, list))
667 | if noise is not None:
668 | img = noise
669 | else:
670 | img = th.randn(*shape, device=device)
671 | indices = list(range(self.num_timesteps))[::-1]
672 |
673 | if progress:
674 | # Lazy import so that we don't depend on tqdm.
675 | from tqdm.auto import tqdm
676 |
677 | indices = tqdm(indices)
678 |
679 | for i in indices:
680 | t = th.tensor([i] * shape[0], device=device)
681 | with th.no_grad():
682 | out = self.ddim_sample(
683 | model,
684 | img,
685 | t,
686 | clip_denoised=clip_denoised,
687 | denoised_fn=denoised_fn,
688 | cond_fn=cond_fn,
689 | model_kwargs=model_kwargs,
690 | eta=eta,
691 | )
692 | yield out
693 | img = out["sample"]
694 |
695 | def _vb_terms_bpd(
696 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
697 | ):
698 | """
699 | Get a term for the variational lower-bound.
700 | The resulting units are bits (rather than nats, as one might expect).
701 | This allows for comparison to other papers.
702 | :return: a dict with the following keys:
703 | - 'output': a shape [N] tensor of NLLs or KLs.
704 | - 'pred_xstart': the x_0 predictions.
705 | """
706 | true_mean, _, true_log_variance_clipped = \
707 | self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
708 | out = self.p_mean_variance(
709 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
710 | )
711 | kl = normal_kl(
712 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
713 | )
714 | kl = mean_flat(kl) / np.log(2.0)
715 |
716 | decoder_nll = -discretized_gaussian_log_likelihood(
717 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
718 | )
719 | assert decoder_nll.shape == x_start.shape
720 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
721 | # At the first timestep return the decoder NLL,
722 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
723 | output = th.where((t == 0), decoder_nll, kl)
724 | return {"output": output, "pred_xstart": out["pred_xstart"]}
725 |
726 | def patchify(self, imgs, imgs_LIDAR):
727 | """
728 | imgs: (N, 3, H, W)
729 | x: (N, L, patch_size**2 *3)
730 | """
731 | p = 3
732 | p_LIDAR = 3
733 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
734 |
735 | h = imgs.shape[2] // p
736 | w = imgs.shape[3] // p
737 |
738 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p))
739 | x = torch.einsum('nchpwq->nhwpqc', x)
740 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1]))
741 |
742 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR))
743 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR)
744 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1]))
745 | return x, x_LIDAR
746 |
747 | def training_losses(self, model, x_start, x_start_LIDAR, t, model_kwargs=None, noise=None):
748 | """
749 | Compute training losses for a single timestep.
750 | :param model: the model to evaluate loss on.
751 | :param x_start: the [N x C x ...] tensor of inputs.
752 | :param t: a batch of timestep indices.
753 | :param model_kwargs: if not None, a dict of extra keyword arguments to
754 | pass to the model. This can be used for conditioning.
755 | :param noise: if specified, the specific Gaussian noise to try to remove.
756 | :return: a dict with the key "loss" containing a tensor of shape [N].
757 | Some mean or variance settings may also have other keys.
758 | """
759 | if model_kwargs is None:
760 | model_kwargs = {}
761 | if noise is None:
762 | noise = th.randn_like(x_start)
763 | noise_LIDAR = th.randn_like(x_start_LIDAR)
764 | x_t = self.q_sample(x_start, t, noise=noise)
765 | x_t_LIDAR = self.q_sample(x_start_LIDAR, t, noise=noise_LIDAR)
766 |
767 | terms = {}
768 |
769 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
770 | terms["loss"] = self._vb_terms_bpd(
771 | model=model,
772 | x_start=x_start,
773 | x_t=x_t,
774 | t=t,
775 | clip_denoised=False,
776 | model_kwargs=model_kwargs,
777 | )["output"]
778 | if self.loss_type == LossType.RESCALED_KL:
779 | terms["loss"] *= self.num_timesteps
780 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
781 |             rec_loss, model_output, model_output_LIDAR, logits, mask, mask_LIDAR, Cross_rec_loss, Cross_pred_imgs, Cross_pred_imgs_LIDAR = model(x_t, x_t_LIDAR, t, **model_kwargs)
782 |
783 | if self.model_var_type in [
784 | ModelVarType.LEARNED,
785 | ModelVarType.LEARNED_RANGE,
786 | ]:
787 | B, C = x_t.shape[:2]
788 |                 assert model_output.shape == (B, C * 2, *x_t.shape[2:])
789 | model_output, model_var_values = th.split(model_output, C, dim=1)
790 | # Learn the variance using the variational bound, but don't let
791 | # it affect our mean prediction.
792 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
793 | terms["vb"] = self._vb_terms_bpd(
794 | model=lambda *args, r=frozen_out: r,
795 | x_start=x_start,
796 | x_t=x_t,
797 | t=t,
798 | clip_denoised=False,
799 | )["output"]
800 | if self.loss_type == LossType.RESCALED_MSE:
801 | # Divide by 1000 for equivalence with initial implementation.
802 | # Without a factor of 1/1000, the VB term hurts the MSE term.
803 | terms["vb"] *= self.num_timesteps / 1000.0
804 | target = {
805 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
806 | x_start=x_start,
807 | x_t=x_t,
808 | t=t
809 | )[0],
810 | ModelMeanType.START_X: x_start,
811 | ModelMeanType.EPSILON: noise,
812 | }[self.model_mean_type]
813 | target_LIDAR = {
814 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
815 | x_start=x_start_LIDAR,
816 | x_t=x_t_LIDAR,
817 | t=t
818 | )[0],
819 | ModelMeanType.START_X: x_start_LIDAR,
820 | ModelMeanType.EPSILON: noise_LIDAR,
821 | }[self.model_mean_type]
822 | assert model_output.shape == target.shape == x_start.shape
823 |
824 |
825 | model_output, model_output_LIDAR = self.patchify(model_output, model_output_LIDAR)
826 | target, target_LIDAR = self.patchify(target, target_LIDAR)
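                # Build binary "visible" masks (1 = kept/visible patch, 0 = masked patch) by
                # inverting the MAE masks; the reconstruction MSE below is averaged over the
                # visible patches only.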
827 | visible = torch.zeros_like(mask)
828 | visible_LIDAR = torch.zeros_like(mask_LIDAR)
829 | zeros_mask = torch.eq(mask, 0)
830 | ones_mask = torch.logical_not(zeros_mask)
831 | visible[zeros_mask] = 1
832 | visible[ones_mask] = 0
833 | zeros_mask = torch.eq(mask_LIDAR, 0)
834 | ones_mask = torch.logical_not(zeros_mask)
835 | visible_LIDAR[zeros_mask] = 1
836 | visible_LIDAR[ones_mask] = 0
837 |
838 | loss_mse = ((target - model_output) ** 2).mean(dim=-1)
839 | loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1)
840 | terms["mse"] = (loss_mse * visible).sum() / visible.sum()
841 | terms["mse_LIDAR"] = (loss_mse_LIDAR * visible_LIDAR).sum() / visible_LIDAR.sum()
842 | if "vb" in terms:
843 | terms["loss"] = terms["mse"] + terms["vb"]
844 | else:
845 | terms["loss"] = terms["mse"]
846 | else:
847 | raise NotImplementedError(self.loss_type)
848 |
849 | return terms, rec_loss, logits, Cross_rec_loss
850 |
851 | def _prior_bpd(self, x_start):
852 | """
853 | Get the prior KL term for the variational lower-bound, measured in
854 | bits-per-dim.
855 | This term can't be optimized, as it only depends on the encoder.
856 | :param x_start: the [N x C x ...] tensor of inputs.
857 | :return: a batch of [N] KL values (in bits), one per batch element.
858 | """
859 | batch_size = x_start.shape[0]
860 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
861 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
862 | kl_prior = normal_kl(
863 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
864 | )
865 | return mean_flat(kl_prior) / np.log(2.0)
866 |
867 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
868 | """
869 | Compute the entire variational lower-bound, measured in bits-per-dim,
870 | as well as other related quantities.
871 | :param model: the model to evaluate loss on.
872 | :param x_start: the [N x C x ...] tensor of inputs.
873 | :param clip_denoised: if True, clip denoised samples.
874 | :param model_kwargs: if not None, a dict of extra keyword arguments to
875 | pass to the model. This can be used for conditioning.
876 | :return: a dict containing the following keys:
877 | - total_bpd: the total variational lower-bound, per batch element.
878 | - prior_bpd: the prior term in the lower-bound.
879 | - vb: an [N x T] tensor of terms in the lower-bound.
880 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
881 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
882 | """
883 | device = x_start.device
884 | batch_size = x_start.shape[0]
885 |
886 | vb = []
887 | xstart_mse = []
888 | mse = []
889 | for t in list(range(self.num_timesteps))[::-1]:
890 | t_batch = th.tensor([t] * batch_size, device=device)
891 | noise = th.randn_like(x_start)
892 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
893 | # Calculate VLB term at the current timestep
894 | with th.no_grad():
895 | out = self._vb_terms_bpd(
896 | model,
897 | x_start=x_start,
898 | x_t=x_t,
899 | t=t_batch,
900 | clip_denoised=clip_denoised,
901 | model_kwargs=model_kwargs,
902 | )
903 | vb.append(out["output"])
904 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
905 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
906 | mse.append(mean_flat((eps - noise) ** 2))
907 |
908 | vb = th.stack(vb, dim=1)
909 | xstart_mse = th.stack(xstart_mse, dim=1)
910 | mse = th.stack(mse, dim=1)
911 |
912 | prior_bpd = self._prior_bpd(x_start)
913 | total_bpd = vb.sum(dim=1) + prior_bpd
914 | return {
915 | "total_bpd": total_bpd,
916 | "prior_bpd": prior_bpd,
917 | "vb": vb,
918 | "xstart_mse": xstart_mse,
919 | "mse": mse,
920 | }
921 |
922 |
923 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
924 | """
925 | Extract values from a 1-D numpy array for a batch of indices.
926 | :param arr: the 1-D numpy array.
927 | :param timesteps: a tensor of indices into the array to extract.
928 | :param broadcast_shape: a larger shape of K dimensions with the batch
929 | dimension equal to the length of timesteps.
930 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
931 | """
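    # Illustrative example: arr = np.array([0.1, 0.2, 0.3]), timesteps = tensor([2, 0]) and
    # broadcast_shape = (2, 1, 4, 4) yield a (2, 1, 4, 4) tensor filled with 0.3 for the
    # first batch element and 0.1 for the second.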
932 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
933 | while len(res.shape) < len(broadcast_shape):
934 | res = res[..., None]
935 | return res + th.zeros(broadcast_shape, device=timesteps.device)
936 |
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
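    # Illustrative examples:
    #   space_timesteps(100, [10])  -> {0, 11, 22, 33, 44, 55, 66, 77, 88, 99}
    #   space_timesteps(8, "ddim4") -> {0, 2, 4, 6}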
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
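        # For every retained step i, 1 - new_beta = alphas_cumprod[i] / alphas_cumprod[last kept step],
        # so the running product of (1 - new_beta) over the kept steps reproduces the original
        # alphas_cumprod at each retained timestep.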
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, x_LIDAR, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, x_LIDAR, new_ts, **kwargs)
130 |
--------------------------------------------------------------------------------
/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
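        # The weights are w(t) = 1 / (T * p(t)), so E_{t~p}[w(t) * loss(t)] = (1/T) * sum_t loss(t):
        # the resampled objective remains an unbiased estimate of the uniform-average objective.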
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
/generate_pic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn.preprocessing
3 |
4 | import matplotlib.pyplot as plt
5 |
6 | indianpines_colors = np.array([[0, 0, 0],
7 | [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0],
8 | [0, 255, 255], [255, 0, 255], [192, 192, 192], [128, 128, 128],
9 | [128, 0, 0], [128, 128, 0], [0, 128, 0], [128, 0, 128],
10 | [0, 128, 128], [0, 0, 128], [255, 165, 0], [255, 215, 0]])
11 | indianpines_colors = sklearn.preprocessing.minmax_scale(indianpines_colors, feature_range=(0, 1))
12 |
13 | def classification_map(img, ground_truth, dpi, save_path):
14 | fig = plt.figure(frameon=False)
15 | fig.set_size_inches(ground_truth.shape[1] * 2.0 / dpi, ground_truth.shape[0] * 2.0 / dpi)
16 |
17 | ax = plt.Axes(fig, [0., 0., 1., 1.])
18 | ax.set_axis_off()
19 | ax.xaxis.set_visible(False)
20 | ax.yaxis.set_visible(False)
21 | fig.add_axes(ax)
22 |
23 | ax.imshow(img)
24 | fig.savefig(save_path, dpi=dpi)
25 |
26 | return 0
27 |
28 |
29 | def generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, OA, halfsize, dataset, day_str, num, model_name):
30 | number_of_rows = np.size(image,0)
31 | number_of_columns = np.size(image,1)
32 |
33 | gt_thematic_map = np.zeros(shape=(number_of_rows, number_of_columns, 3))
34 | predicted_thematic_map = np.zeros(shape=(number_of_rows, number_of_columns, 3))
35 | for i in range(number_of_rows):
36 | for j in range(number_of_columns):
37 | gt_thematic_map[i, j, :] = indianpines_colors[gt[i,j]]
38 | predicted_thematic_map[i, j, :] = indianpines_colors[gt[i,j]]
39 | nclass = np.max(gt)
40 |
41 | fl = 0
42 | for i in range(nclass):
43 |         print('test label of class:', i)
44 | matrix = index[i]
45 | temprow = matrix[:,0]
46 | tempcol = matrix[:,1]
47 | m = len(temprow)
48 | fl = fl - nTrain_perClass[i] - nvalid_perClass[i]
49 | for j in range(nTrain_perClass[i] + nvalid_perClass[i], m):
50 | predicted_thematic_map[temprow[j], tempcol[j], :] = indianpines_colors[test_pred[fl + j]+1]
51 | fl = fl + m
52 |
53 |
54 | predicted_thematic_map = predicted_thematic_map[halfsize:number_of_rows -halfsize,halfsize:number_of_columns-halfsize,: ]
55 | gt_thematic_map = gt_thematic_map[halfsize:number_of_rows -halfsize,halfsize:number_of_columns-halfsize,: ]
56 | path = '.'
57 | classification_map(predicted_thematic_map, gt, 600,
58 | path + '/classification_maps/' + dataset + '_' + day_str +'_' + str(num) +'_OA_'+ str(round(OA, 2)) + '_' + 'predicted' + '.png')
59 | classification_map(gt_thematic_map, gt, 600,
60 | path + '/classification_maps/' + dataset + '_' + day_str +'_' + str(num) +'_OA_'+ str(round(OA, 2)) + '_' + 'gt' + '.png')
61 |
62 | return predicted_thematic_map, gt_thematic_map
63 |
64 |
65 |
--------------------------------------------------------------------------------
/hyper_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data.dataset import Dataset
4 |
5 |
6 | class HyperData(Dataset):
7 | def __init__(self, dataset, transfor):
8 | self.data = dataset[0].astype(np.float32)
9 | self.data_LIDAR = dataset[1].astype(np.float32)
10 | self.transformer = transfor
11 | self.labels = []
12 | for n in dataset[2]:
13 | self.labels += [int(n)]
14 |
15 | def __getitem__(self, index):
16 | label = self.labels[index]
17 |         if self.transformer is None:
18 | img = torch.from_numpy(np.asarray(self.data[index,:,:,:]))
19 | img_LIDAR = torch.from_numpy(np.asarray(self.data_LIDAR[index,:,:,:]))
20 | return img, img_LIDAR, label
21 | elif len(self.transformer) == 2:
22 | img = torch.from_numpy(np.asarray(self.transformer[1](self.transformer[0](self.data[index,:,:,:]))))
23 | img_LIDAR = torch.from_numpy(np.asarray(self.transformer[1](self.transformer[0](self.data_LIDAR[index,:,:,:]))))
24 | return img, img_LIDAR, label
25 | else:
26 | img = torch.from_numpy(np.asarray(self.transformer[0](self.data[index,:,:,:])))
27 | img_LIDAR = torch.from_numpy(np.asarray(self.transformer[0](self.data_LIDAR[index,:,:,:])))
28 | return img, img_LIDAR, label
29 |
30 | def __len__(self):
31 | return len(self.labels)
32 |
33 | def __labels__(self):
34 | return self.labels
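
# Usage sketch (illustrative names; shapes follow the indexing in __getitem__ above):
#   dataset = (hsi_patches, lidar_patches, labels)   # arrays shaped [N, C, H, W], [N, C', H, W], [N]
#   train_set = HyperData(dataset, transfor=None)    # or a (transform,) / (transform_a, transform_b) tuple
#   img, img_lidar, label = train_set[0]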
35 |
36 |
37 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | import scipy.io as sio
4 | import torch
5 | import torch.nn as nn
6 | from torch import _assert
7 | from timm.models.vision_transformer import Block
8 | from timm.models.layers import to_2tuple, DropPath
9 | # from timm.models.vision_transformer import PatchEmbed, Block
10 | from pos_embed import get_2d_sincos_pos_embed
11 |
12 | class PatchEmbed(nn.Module):
13 | """ 2D Image to Patch Embedding
14 | """
15 | def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
16 | super().__init__()
17 | img_size = to_2tuple(img_size)
18 | patch_size = to_2tuple(patch_size)
19 |
20 | self.img_size = img_size
21 | self.patch_size = patch_size
22 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
23 | self.num_patches = self.grid_size[0] * self.grid_size[1]
24 | self.flatten = flatten
25 |
26 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
27 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
28 |
29 | def forward(self, x):
30 | B, C, H, W = x.shape
31 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
32 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
33 | x = self.proj(x)
34 | if self.flatten:
35 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
36 | x = self.norm(x)
37 | return x
38 |
39 | class TimestepEmbedder(nn.Module):
40 | """
41 | Embeds scalar timesteps into vector representations.
42 | """
43 | def __init__(self, hidden_size, frequency_embedding_size=256):
44 | super().__init__()
45 | self.mlp = nn.Sequential(
46 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
47 | nn.SiLU(),
48 | nn.Linear(hidden_size, hidden_size, bias=True),
49 | )
50 | self.frequency_embedding_size = frequency_embedding_size
51 |
52 | @staticmethod
53 | def timestep_embedding(t, dim, max_period=10000):
54 | """
55 | Create sinusoidal timestep embeddings.
56 | :param t: a 1-D Tensor of N indices, one per batch element.
57 | These may be fractional.
58 | :param dim: the dimension of the output.
59 | :param max_period: controls the minimum frequency of the embeddings.
60 | :return: an (N, D) Tensor of positional embeddings.
61 | """
62 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
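        # Frequencies are w_k = max_period ** (-k / half) for k = 0..half-1; the output is
        # [cos(t * w_k), sin(t * w_k)] concatenated along the last dimension.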
63 | half = dim // 2
64 | freqs = torch.exp(
65 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
66 | ).to(device=t.device)
67 | args = t[:, None].float() * freqs[None]
68 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
69 | if dim % 2:
70 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
71 | return embedding
72 |
73 | def forward(self, t):
74 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
75 | t_emb = self.mlp(t_freq)
76 | return t_emb
77 |
78 | class MaskedAutoencoderViT(nn.Module):
79 | """ Masked Autoencoder with VisionTransformer backbone
80 | """
81 | def __init__(self, img_size=224, patch_size=16, in_chans=3, in_chans_LIDAR = 1,hid_chans = 32,hid_chans_LIDAR = 32,
82 | embed_dim=1024, depth=24, num_heads=16,
83 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
84 | mlp_ratio=4., drop_rate=0.,attn_drop_rate=0., drop_path_rate=0.,
85 | norm_layer=nn.LayerNorm, norm_pix_loss=False,
86 | cls_hidden_mlp=256, nb_classes=1000, global_pool=False,
87 | mlp_depth=2):
88 | super().__init__()
89 | # --------------------------------------------------------------------------
90 | #HSI
91 | # MAE dimensionality reduction/expansion specifics
92 | self.dimen_redu = nn.Sequential(
93 | nn.Conv2d(in_chans, hid_chans, kernel_size=1, stride=1, padding=0, bias=True),
94 | nn.BatchNorm2d(hid_chans),
95 | nn.ReLU(),
96 | nn.Conv2d(hid_chans, hid_chans, 1, 1, 0, bias=True),
97 | nn.BatchNorm2d(hid_chans),
98 | nn.ReLU(),
99 | )
100 |
101 | self.t_embedder = TimestepEmbedder(embed_dim)
102 | self.dimen_expa = nn.Conv2d(hid_chans, in_chans, kernel_size=1, stride=1, padding=0, bias=True)
103 | self.h = img_size[0]// patch_size
104 | self.w = img_size[1]// patch_size
105 | # --------------------------------------------------------------------------
106 | # MAE encoder specifics
107 | self.patch_embed = PatchEmbed(img_size, patch_size, hid_chans , embed_dim)
108 | num_patches = self.patch_embed.num_patches
109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
111 | self.blocks = nn.ModuleList([
112 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate,attn_drop=attn_drop_rate, drop_path=drop_path_rate, qkv_bias=True, norm_layer=norm_layer)
113 | for i in range(depth)])
114 | self.norm = norm_layer(embed_dim)
115 | self.projection_en_mask = nn.Linear(embed_dim, embed_dim)
116 | self.projection_en_visible = nn.Linear(embed_dim, embed_dim)
117 | # --------------------------------------------------------------------------
118 |
119 | # --------------------------------------------------------------------------
120 | # MAE decoder specifics
121 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
122 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
123 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
124 | self.decoder_blocks = nn.ModuleList([
125 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
126 | for i in range(decoder_depth)])
127 | self.decoder_norm = norm_layer(decoder_embed_dim)
128 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * hid_chans, bias=True) # decoder to patch
129 | self.projection_de = nn.Linear(decoder_embed_dim, decoder_embed_dim)
130 | if cls_hidden_mlp == 0:
131 | self.cls_head = nn.Linear(embed_dim, nb_classes)
132 | else:
133 | assert mlp_depth in [2], "mlp depth should be 2"
134 | if mlp_depth == 2:
135 | self.cls_head = nn.Sequential(
136 | nn.Linear(embed_dim*2, cls_hidden_mlp),
137 | nn.BatchNorm1d(cls_hidden_mlp),
138 | nn.ReLU(inplace=True),
139 | nn.Linear(cls_hidden_mlp, embed_dim),
140 | nn.BatchNorm1d(embed_dim),
141 | nn.ReLU(inplace=True),
142 | nn.Linear(embed_dim, nb_classes),
143 | )
144 | # --------------------------------------------------------------------------
145 | self.global_pool = global_pool
146 | self.norm_pix_loss = norm_pix_loss
147 |
148 | # --------------------------------------------------------------------------
149 | # LIDAR
150 | # MAE dimensionality reduction/expansion specifics
151 | self.dimen_redu_LIDAR = nn.Sequential(
152 | nn.Conv2d(in_chans_LIDAR, hid_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True),
153 | nn.BatchNorm2d(hid_chans_LIDAR),
154 | nn.ReLU(),
155 |
156 | nn.Conv2d(hid_chans_LIDAR, hid_chans_LIDAR, 1, 1, 0, bias=True),
157 | nn.BatchNorm2d(hid_chans_LIDAR),
158 | nn.ReLU(),
159 | )
160 | self.t_embedder_LIDAR = TimestepEmbedder(embed_dim)
161 | self.dimen_expa_LIDAR = nn.Conv2d(hid_chans_LIDAR, in_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True)
162 | self.h = img_size[0] // patch_size
163 | self.w = img_size[1] // patch_size
164 | # --------------------------------------------------------------------------
165 | # MAE encoder specifics
166 | self.patch_embed_LIDAR = PatchEmbed(img_size, patch_size, hid_chans_LIDAR, embed_dim)
167 | num_patches = self.patch_embed_LIDAR.num_patches
168 |
169 | self.cls_token_LIDAR = nn.Parameter(torch.zeros(1, 1, embed_dim))
170 | self.pos_embed_LIDAR = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
171 | requires_grad=False) # fixed sin-cos embedding
172 |
173 | self.blocks_LIDAR = nn.ModuleList([
174 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
175 | qkv_bias=True, norm_layer=norm_layer)
176 | for i in range(depth)])
177 | self.norm_LIDAR = norm_layer(embed_dim)
178 | self.projection_en_mask_LIDAR = nn.Linear(embed_dim, embed_dim)
179 | self.projection_en_visible_LIDAR = nn.Linear(embed_dim, embed_dim)
180 | # --------------------------------------------------------------------------
181 |
182 | # --------------------------------------------------------------------------
183 | # MAE decoder specifics
184 | self.decoder_embed_LIDAR = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
185 |
186 | self.mask_token_LIDAR = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
187 |
188 | self.decoder_pos_embed_LIDAR = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
189 | requires_grad=False) # fixed sin-cos embedding
190 |
191 | self.decoder_blocks_LIDAR = nn.ModuleList([
192 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
193 | for i in range(decoder_depth)])
194 |
195 | self.decoder_norm_LIDAR = norm_layer(decoder_embed_dim)
196 |         self.decoder_pred_LIDAR = nn.Linear(decoder_embed_dim, patch_size ** 2 * hid_chans_LIDAR, bias=True) # decoder to patch
197 | self.projection_de_LIDAR = nn.Linear(decoder_embed_dim, decoder_embed_dim)
198 | # --------------------------------------------------------------------------
199 | self.global_pool_LIDAR = global_pool
200 | self.norm_pix_loss_LIDAR = norm_pix_loss
201 |
202 |
203 | self.initialize_weights()
204 |
205 | def initialize_weights(self):
206 |         # initialization
207 |
208 | # initialize (and freeze) pos_embed by sin-cos embedding
209 | #HSI
210 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
211 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
212 |
213 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
214 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
215 |
216 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
217 | w = self.patch_embed.proj.weight.data
218 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
219 |
220 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
221 | torch.nn.init.normal_(self.cls_token, std=.02)
222 | # Initialize timestep embedding MLP:
223 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
224 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
225 |
226 | #LIDAR
227 | pos_embed_LIDAR = get_2d_sincos_pos_embed(self.pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches**.5), cls_token=True)
228 | self.pos_embed_LIDAR.data.copy_(torch.from_numpy(pos_embed_LIDAR).float().unsqueeze(0))
229 |
230 | decoder_pos_embed_LIDAR = get_2d_sincos_pos_embed(self.decoder_pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches**.5), cls_token=True)
231 | self.decoder_pos_embed_LIDAR.data.copy_(torch.from_numpy(decoder_pos_embed_LIDAR).float().unsqueeze(0))
232 |
233 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
234 | w_LIDAR = self.patch_embed_LIDAR.proj.weight.data
235 | torch.nn.init.xavier_uniform_(w_LIDAR.view([w_LIDAR.shape[0], -1]))
236 |
237 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
238 | torch.nn.init.normal_(self.cls_token_LIDAR, std=.02)
239 |
240 | # Initialize timestep embedding MLP:
241 | nn.init.normal_(self.t_embedder_LIDAR.mlp[0].weight, std=0.02)
242 | nn.init.normal_(self.t_embedder_LIDAR.mlp[2].weight, std=0.02)
243 |
244 | # initialize nn.Linear and nn.LayerNorm
245 | self.apply(self._init_weights)
246 |
247 | def _init_weights(self, m):
248 | if isinstance(m, nn.Linear):
249 | # we use xavier_uniform following official JAX ViT:
250 | torch.nn.init.xavier_uniform_(m.weight)
251 | if isinstance(m, nn.Linear) and m.bias is not None:
252 | nn.init.constant_(m.bias, 0)
253 | elif isinstance(m, nn.LayerNorm):
254 | nn.init.constant_(m.bias, 0)
255 | nn.init.constant_(m.weight, 1.0)
256 |
257 | def patchify(self, imgs, imgs_LIDAR):
258 | """
259 | imgs: (N, 3, H, W)
260 | x: (N, L, patch_size**2 *3)
261 | """
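        # Illustrative shapes: with patch_size = 3, imgs of shape (N, 32, 9, 9) become
        # x of shape (N, 9, 288), i.e. h = w = 3 patches, each flattened to 3*3*32 values.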
262 | p = self.patch_embed.patch_size[0]
263 | p_LIDAR = self.patch_embed_LIDAR.patch_size[0]
264 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
265 |
266 | h = imgs.shape[2] // p
267 | w = imgs.shape[3] // p
268 |
269 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p))
270 | x = torch.einsum('nchpwq->nhwpqc', x)
271 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1]))
272 |
273 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR))
274 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR)
275 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1]))
276 | return x, x_LIDAR
277 |
278 | def unpatchify(self, x, x_LIDAR):
279 | """
280 | x: (N, L, patch_size**2 *3)
281 | imgs: (N, 3, H, W)
282 | """
283 | p = self.patch_embed.patch_size[0]
284 | p_LIDAR = self.patch_embed_LIDAR.patch_size[0]
285 | h = self.h
286 | w = self.w
287 | assert h * w == x.shape[1]
288 | assert h * w == x_LIDAR.shape[1]
289 |
290 | hid_chans = int(x.shape[2]/(p**2))
291 | hid_chans_LIDAR = int(x_LIDAR.shape[2]/(p_LIDAR**2))
292 |
293 | x = x.reshape(shape=(x.shape[0], h, w, p, p, hid_chans))
294 | x = torch.einsum('nhwpqc->nchpwq', x)
295 | imgs = x.reshape(shape=(x.shape[0], hid_chans, h * p, w * p))
296 |
297 | x_LIDAR = x_LIDAR.reshape(shape=(x_LIDAR.shape[0], h, w, p_LIDAR, p_LIDAR, hid_chans_LIDAR))
298 | x_LIDAR = torch.einsum('nhwpqc->nchpwq', x_LIDAR)
299 | imgs_LIDAR = x_LIDAR.reshape(shape=(x_LIDAR.shape[0], hid_chans_LIDAR, h * p_LIDAR, w * p_LIDAR))
300 | return imgs, imgs_LIDAR
301 |
302 | def random_masking(self, x, x_LIDAR, mask_ratio):
303 | """
304 | Perform per-sample random masking by per-sample shuffling.
305 | Per-sample shuffling is done by argsort random noise.
306 | x: [N, L, D], sequence
307 | """
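        # Illustrative example: L = 9 patches with mask_ratio = 0.75 gives len_keep = 2,
        # i.e. 2 visible patches and 7 masked ones per sample (mask entry 1 = removed).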
308 | N, L, D = x.shape # batch, length, dim
309 | len_keep = int(L * (1 - mask_ratio))
310 |
311 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
312 |
313 | # sort noise for each sample
314 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
315 | ids_restore = torch.argsort(ids_shuffle, dim=1)
316 |
317 | # keep the first subset
318 | ids_keep = ids_shuffle[:, :len_keep]
319 | x_visible = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
320 |
321 |
322 | noise_LIDAR = torch.rand(N, L, device=x.device)
323 | ids_shuffle_LIDAR = torch.argsort(noise_LIDAR, dim=1) # ascend: small is keep, large is remove
324 | ids_restore_LIDAR = torch.argsort(ids_shuffle_LIDAR, dim=1)
325 |
326 | # keep the first subset
327 |         ids_keep_LIDAR = ids_shuffle_LIDAR[:, :len_keep]
328 | x_LIDAR_visible = torch.gather(x_LIDAR, dim=1, index=ids_keep_LIDAR.unsqueeze(-1).repeat(1, 1, D))
329 |
330 | # generate the binary mask: 0 is keep, 1 is remove
331 | mask = torch.ones([N, L], device=x.device)
332 | mask[:, :len_keep] = 0
333 | mask_LIDAR = torch.ones([N, L], device=x.device)
334 | mask_LIDAR[:, :len_keep] = 0
335 | # unshuffle to get the binary mask
336 | mask = torch.gather(mask, dim=1, index=ids_restore)
337 | mask_LIDAR = torch.gather(mask_LIDAR, dim=1, index=ids_restore_LIDAR)
338 | return x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR
339 |
340 | def preprocessing(self, x, x_LIDAR, mask_ratio):
341 | # embed patches
342 | x = self.dimen_redu(x)
343 | x = self.patch_embed(x)
344 |
345 | x_LIDAR = self.dimen_redu_LIDAR(x_LIDAR)
346 | x_LIDAR = self.patch_embed_LIDAR(x_LIDAR)
347 |
348 | # add pos embed w/o cls token
349 | x = x + self.pos_embed[:, 1:, :]
350 | x_LIDAR = x_LIDAR + self.pos_embed_LIDAR[:, 1:, :]
351 |
352 | # masking: length -> length * mask_ratio
353 | x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR = self.random_masking(x, x_LIDAR, mask_ratio)
354 | return x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR
355 |
356 | def forward_encoder(self, x, x_LIDAR, t):
357 | # append cls token
358 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
359 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
360 |
361 | t_token = self.t_embedder(torch.squeeze(t))
362 | t_token = torch.unsqueeze(t_token, dim=1)
363 | cls_tokens = cls_tokens + t_token
364 |
365 | x = torch.cat((cls_tokens, x), dim=1)
366 |
367 | cls_token_LIDAR = self.cls_token_LIDAR + self.pos_embed_LIDAR[:, :1, :]
368 | cls_tokens_LIDAR = cls_token_LIDAR.expand(x_LIDAR.shape[0], -1, -1)
369 |
370 | t_token_LIDAR = self.t_embedder_LIDAR(torch.squeeze(t))
371 | t_token_LIDAR = torch.unsqueeze(t_token_LIDAR, dim=1)
372 | cls_tokens_LIDAR = cls_tokens_LIDAR + t_token_LIDAR
373 |
374 | x_LIDAR = torch.cat((cls_tokens_LIDAR, x_LIDAR), dim=1)
375 |
376 | # apply Transformer blocks
377 | for blk in self.blocks:
378 | x = blk(x)
379 | x = self.norm(x)
380 |
381 | for blk_LIDAR in self.blocks_LIDAR:  # LiDAR tokens go through the LiDAR-specific blocks (assumed defined in __init__, mirroring the decoder's *_LIDAR modules)
382 | x_LIDAR = blk_LIDAR(x_LIDAR)
383 | x_LIDAR = self.norm_LIDAR(x_LIDAR)
384 |
385 | return x, x_LIDAR
386 |
387 | def forward_decoder_A(self, x, ids_restore):
388 | # embed token
389 | x = self.decoder_embed(x)
390 |
391 | # append mask tokens to sequence
392 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
393 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
394 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
395 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
396 |
397 | # add pos embed
398 | x = x + self.decoder_pos_embed[:, :x.shape[1], :]
399 |
400 | # apply Transformer blocks
401 | for blk in self.decoder_blocks:
402 | x = blk(x)
403 | x = self.decoder_norm(x)
404 |
405 | x = self.projection_de(x)
406 | return x
407 |
408 | def forward_decoder_B(self, x_LIDAR, ids_restore_LIDAR):
409 | # embed token
410 | x_LIDAR = self.decoder_embed_LIDAR(x_LIDAR)
411 |
412 | # append mask tokens to sequence
413 | mask_tokens_LIDAR = self.mask_token_LIDAR.repeat(x_LIDAR.shape[0], ids_restore_LIDAR.shape[1] + 1 - x_LIDAR.shape[1], 1)
414 | x_LIDAR_ = torch.cat([x_LIDAR[:, 1:, :], mask_tokens_LIDAR], dim=1) # no cls token
415 | x_LIDAR_ = torch.gather(x_LIDAR_, dim=1, index=ids_restore_LIDAR.unsqueeze(-1).repeat(1, 1, x_LIDAR.shape[2])) # unshuffle
416 | x_LIDAR = torch.cat([x_LIDAR[:, :1, :], x_LIDAR_], dim=1) # append cls token
417 |
418 | # add pos embed
419 | x_LIDAR = x_LIDAR + self.decoder_pos_embed_LIDAR[:, :x_LIDAR.shape[1], :]
420 |
421 | # apply Transformer blocks
422 | for blk_LIDAR in self.decoder_blocks_LIDAR:
423 | x_LIDAR = blk_LIDAR(x_LIDAR)
424 | x_LIDAR = self.decoder_norm_LIDAR(x_LIDAR)
425 |
426 | x_LIDAR = self.projection_de_LIDAR(x_LIDAR)
427 | return x_LIDAR
428 |
429 |
430 | def reconstruction(self, x, x_LIDAR):
431 | x = self.decoder_pred(x)
432 | x_LIDAR = self.decoder_pred_LIDAR(x_LIDAR)
433 | # # remove cls token
434 | x = x[:, 1:, :]
435 | x_LIDAR = x_LIDAR[:, 1:, :]
436 |
437 |
438 | x, x_LIDAR = self.unpatchify(x, x_LIDAR)
439 | x = self.dimen_expa(x)
440 | x_LIDAR = self.dimen_expa_LIDAR(x_LIDAR)
441 |
442 | pred_Reconstruction, pred_LIDAR_Reconstruction = self.patchify(x, x_LIDAR)
443 | return x, x_LIDAR, pred_Reconstruction, pred_LIDAR_Reconstruction
444 |
445 | def forward_classification(self, x, x_LIDAR):
446 | if self.global_pool:
447 | feat = x[:, 1:, :].mean(dim=1) # global pool without cls token
448 | else:
449 | feat = x[:, 0, :] # with cls token
450 |
451 | if self.global_pool_LIDAR:
452 | feat_LIDAR = x_LIDAR[:, 1:, :].mean(dim=1) # global pool without cls token
453 | else:
454 | feat_LIDAR = x_LIDAR[:, 0, :] # with cls token
455 | feat_all = torch.cat((feat, feat_LIDAR),dim=1)
456 | logits = self.cls_head(feat_all)
457 | return logits
458 |
459 | def Reconstruction_loss(self, imgs, pred, imgs_LIDAR, pred_LIDAR, mask, mask_LIDAR):
460 | """
461 | imgs: [N, 3, H, W]
462 | pred: [N, L, p*p*3]
463 | mask: [N, L], 0 is keep, 1 is remove,
464 | """
465 | target, target_LIDAR = self.patchify(imgs, imgs_LIDAR)
466 | if self.norm_pix_loss:
467 | mean = target.mean(dim=-1, keepdim=True)
468 | var = target.var(dim=-1, keepdim=True)
469 | target = (target - mean) / (var + 1.e-6)**.5
470 |
471 | loss = (pred - target) ** 2
472 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
473 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
474 |
475 | if self.norm_pix_loss_LIDAR:
476 | mean_LIDAR = target_LIDAR.mean(dim=-1, keepdim=True)
477 | var_LIDAR = target_LIDAR.var(dim=-1, keepdim=True)
478 | target_LIDAR = (target_LIDAR - mean_LIDAR) / (var_LIDAR + 1.e-6)**.5
479 |
480 | loss_LIDAR = (pred_LIDAR - target_LIDAR) ** 2
481 | loss_LIDAR = loss_LIDAR.mean(dim=-1) # [N, L], mean loss per patch
482 | loss_LIDAR = (loss_LIDAR * mask_LIDAR).sum() / mask_LIDAR.sum() # mean loss on removed patches
483 |
484 | loss_all = loss + loss_LIDAR
485 | return loss_all
486 |
487 | def forward(self, imgs, imgs_LIDAR, t, y, mask_ratio):
488 | #preprocessing
489 | x_visible, x_LIDAR_visible, mask, ids_restore, mask_LIDAR, ids_restore_LIDAR = self.preprocessing(imgs, imgs_LIDAR, mask_ratio)
490 |
491 | #visible_process
492 | feature_visible, feature_visible_LIDAR = self.forward_encoder(x_visible, x_LIDAR_visible, t)
493 | pred = self.forward_decoder_A(feature_visible, ids_restore)
494 | pred_LIDAR = self.forward_decoder_B(feature_visible_LIDAR, ids_restore_LIDAR)
495 |
496 | # Reconstruction branch
497 | pred_imgs, pred_imgs_LIDAR, pred_Reconstruction, pred_LIDAR_Reconstruction = self.reconstruction(pred, pred_LIDAR) # [N, L, p*p*3]
498 |
499 | # Cross Reconstruction branch
500 | # Cross_pred = self.forward_decoder_A(feature_visible_LIDAR, ids_restore_LIDAR)
501 | # Cross_pred_LIDAR = self.forward_decoder_B(feature_visible, ids_restore)
502 | # Cross_pred_imgs, Cross_pred_imgs_LIDAR, Cross_pred_Reconstruction, Cross_pred_LIDAR_Reconstruction = self.reconstruction(Cross_pred, Cross_pred_LIDAR)
503 |
504 | # Classification branch
505 | logits = self.forward_classification(feature_visible, feature_visible_LIDAR)
506 | # return pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR, Cross_pred_imgs, Cross_pred_imgs_LIDAR
507 | return pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR
508 |
509 |
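`forward` does not add noise itself: `train.py` first noises both modalities with the diffusion process and passes the noised patches plus the timestep here. A usage sketch, assuming the constructor keyword arguments used in `train.py` (window size, channel counts and dimensions below are illustrative, not the paper's settings):

```python
import torch
from model import mae_vit_HSIandLIDAR_patch3

# illustrative configuration: 11x11 window, 40 HSI bands, 1 LiDAR channel, 10 classes
net = mae_vit_HSIandLIDAR_patch3(
    img_size=(11, 11), in_chans=40, in_chans_LIDAR=1,
    hid_chans=32, hid_chans_LIDAR=32,
    embed_dim=64, depth=2, num_heads=4, mlp_ratio=4.,
    decoder_embed_dim=64, decoder_depth=1, decoder_num_heads=4,
    nb_classes=10, global_pool=False)

x_t = torch.randn(8, 40, 11, 11)        # noised HSI patches
x_lidar_t = torch.randn(8, 1, 11, 11)   # noised LiDAR patches
t = torch.randint(0, 1000, (8,))        # diffusion timesteps
y = torch.randint(0, 10, (8,))          # class labels

pred_imgs, pred_imgs_lidar, logits, mask, mask_lidar = net(x_t, x_lidar_t, t, y, mask_ratio=0.7)
print(pred_imgs.shape, pred_imgs_lidar.shape, logits.shape)
```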
510 | def patchify(imgs, imgs_LIDAR, size):
511 | """
512 | imgs: (N, 3, H, W)
513 | x: (N, L, patch_size**2 *3)
514 | """
515 | p = size
516 | p_LIDAR = size
517 | # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
518 |
519 | h = imgs.shape[2] // p
520 | w = imgs.shape[3] // p
521 |
522 | x = imgs.reshape(shape=(imgs.shape[0], imgs.shape[1], h, p, w, p))
523 | x = torch.einsum('nchpwq->nhwpqc', x)
524 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * imgs.shape[1]))
525 |
526 | x_LIDAR = imgs_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], imgs_LIDAR.shape[1], h, p_LIDAR, w, p_LIDAR))
527 | x_LIDAR = torch.einsum('nchpwq->nhwpqc', x_LIDAR)
528 | x_LIDAR = x_LIDAR.reshape(shape=(imgs_LIDAR.shape[0], h * w, p_LIDAR**2 * imgs_LIDAR.shape[1]))
529 | return x, x_LIDAR
530 |
531 | def DDPM_LOSS(model_output, model_output_LIDAR, target, target_LIDAR, mask , mask_LIDAR, size):
532 | model_output, model_output_LIDAR = patchify(model_output, model_output_LIDAR, size)
533 | target, target_LIDAR = patchify(target, target_LIDAR, size)
534 |
535 | loss_mse = ((target - model_output) ** 2).mean(dim=-1)
536 | loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1)
537 | loss_mse_m = (loss_mse * mask).sum() / mask.sum()
538 | loss_mse_LIDAR_m = (loss_mse_LIDAR * mask_LIDAR).sum() / mask_LIDAR.sum()
539 | visible = torch.zeros_like(mask)
540 | visible_LIDAR = torch.zeros_like(mask_LIDAR)
541 | zeros_mask = torch.eq(mask, 0)
542 | ones_mask = torch.logical_not(zeros_mask)
543 | visible[zeros_mask] = 1
544 | visible[ones_mask] = 0
545 | zeros_mask = torch.eq(mask_LIDAR, 0)
546 | ones_mask = torch.logical_not(zeros_mask)
547 | visible_LIDAR[zeros_mask] = 1
548 | visible_LIDAR[ones_mask] = 0
549 | loss_mse_v = (loss_mse * visible).sum() / visible.sum()
550 | loss_mse_LIDAR_v = (loss_mse_LIDAR * visible_LIDAR).sum() / visible_LIDAR.sum()
551 |
552 | # loss_mse = ((target - model_output) ** 2).mean(dim=-1).mean()
553 | # loss_mse_LIDAR = ((target_LIDAR - model_output_LIDAR) ** 2).mean(dim=-1).mean()
554 | return loss_mse_m, loss_mse_LIDAR_m, loss_mse_v, loss_mse_LIDAR_v
555 |
556 | class vit_HSI_LIDAR(nn.Module):
557 | """ Masked Autoencoder's'backbone
558 | """
559 |
560 | def __init__(self, img_size=(224, 224), patch_size=16, num_classes=1000, in_chans=3, in_chans_LIDAR = 1, hid_chans=32,
561 | hid_chans_LIDAR=128,embed_dim=1024, depth=24, num_heads=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
562 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, global_pool=False):
563 | super().__init__()
564 | self.patch_size = patch_size
565 |
566 | # --------------------------------------------------------------------------
567 | # HSI
568 | # MAE encoder specifics
569 | self.dimen_redu = nn.Sequential(
570 | nn.Conv2d(in_chans, hid_chans, kernel_size=1, stride=1, padding=0, bias=True),
571 | nn.BatchNorm2d(hid_chans),
572 | nn.ReLU(),
573 |
574 | nn.Conv2d(hid_chans, hid_chans, 1, 1, 0, bias=True),
575 | nn.BatchNorm2d(hid_chans),
576 | nn.ReLU(),
577 | )
578 |
579 | # --------------------------------------------------------------------------
580 | # MAE encoder specifics
581 | self.patch_embed = PatchEmbed(img_size, patch_size, hid_chans, embed_dim)
582 | num_patches = self.patch_embed.num_patches
583 |
584 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
585 | self.t_embedder = TimestepEmbedder(embed_dim)
586 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
587 | requires_grad=True)  # learnable position embedding, initialized with sin-cos values in initialize_weights()
588 |
589 | self.blocks = nn.ModuleList([
590 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
591 | qkv_bias=True, norm_layer=norm_layer)
592 | for i in range(depth)])
593 | self.norm = norm_layer(embed_dim)
594 | self.head = nn.Linear(embed_dim * 2, embed_dim, bias=True)
595 |
596 | self.global_pool = global_pool
597 | if self.global_pool:
598 | self.fc_norm = norm_layer(embed_dim)
599 | del self.norm
600 |
601 | # LIDAR
602 | # MAE encoder specifics
603 | self.dimen_redu_LIDAR = nn.Sequential(
604 | nn.Conv2d(in_chans_LIDAR, hid_chans_LIDAR, kernel_size=1, stride=1, padding=0, bias=True),
605 | nn.BatchNorm2d(hid_chans_LIDAR),
606 | nn.ReLU(),
607 |
608 | nn.Conv2d(hid_chans_LIDAR, hid_chans_LIDAR, 1, 1, 0, bias=True),
609 | nn.BatchNorm2d(hid_chans_LIDAR),
610 | nn.ReLU(),
611 | )
612 |
613 | # --------------------------------------------------------------------------
614 | # MAE encoder specifics
615 | self.patch_embed_LIDAR = PatchEmbed(img_size, patch_size, hid_chans_LIDAR, embed_dim)
616 | num_patches = self.patch_embed_LIDAR.num_patches
617 |
618 | self.cls_token_LIDAR = nn.Parameter(torch.zeros(1, 1, embed_dim))
619 | self.t_embedder_LIDAR = TimestepEmbedder(embed_dim)
620 | self.pos_embed_LIDAR = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
621 | requires_grad=True)  # learnable position embedding, initialized with sin-cos values in initialize_weights()
622 |
623 | self.blocks_LIDAR = nn.ModuleList([
624 | Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
625 | qkv_bias=True, norm_layer=norm_layer)
626 | for i in range(depth)])
627 | self.norm_LIDAR = norm_layer(embed_dim)
628 | self.global_pool_LIDAR = global_pool
629 | if self.global_pool_LIDAR:
630 | self.fc_norm_LIDAR = norm_layer(embed_dim)
631 | del self.norm_LIDAR
632 |
633 | def initialize_weights(self):
634 | # initialization
635 | # initialize (and freeze) pos_embed by sin-cos embedding
636 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
637 | cls_token=True)
638 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
639 |
640 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
641 | w = self.patch_embed.proj.weight.data
642 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
643 |
644 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
645 | torch.nn.init.normal_(self.cls_token, std=.02)
646 |
647 | # initialize (and freeze) pos_embed by sin-cos embedding
648 | pos_embed_LIDAR = get_2d_sincos_pos_embed(self.pos_embed_LIDAR.shape[-1], int(self.patch_embed_LIDAR.num_patches ** .5),
649 | cls_token=True)
650 | self.pos_embed_LIDAR.data.copy_(torch.from_numpy(pos_embed_LIDAR).float().unsqueeze(0))
651 |
652 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
653 | w_LIDAR = self.patch_embed_LIDAR.proj.weight.data
654 | torch.nn.init.xavier_uniform_(w_LIDAR.view([w_LIDAR.shape[0], -1]))
655 |
656 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
657 | torch.nn.init.normal_(self.cls_token_LIDAR, std=.02)
658 |
659 | # Initialize both timestep embedding MLPs (HSI and LiDAR):
660 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02); nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
661 | nn.init.normal_(self.t_embedder_LIDAR.mlp[0].weight, std=0.02); nn.init.normal_(self.t_embedder_LIDAR.mlp[2].weight, std=0.02)
662 |
663 | # initialize nn.Linear and nn.LayerNorm
664 | self.apply(self._init_weights)
665 |
666 | def _init_weights(self, m):
667 | if isinstance(m, nn.Linear):
668 | # we use xavier_uniform following official JAX ViT:
669 | torch.nn.init.xavier_uniform_(m.weight)
670 | if isinstance(m, nn.Linear) and m.bias is not None:
671 | nn.init.constant_(m.bias, 0)
672 | elif isinstance(m, nn.LayerNorm):
673 | nn.init.constant_(m.bias, 0)
674 | nn.init.constant_(m.weight, 1.0)
675 |
676 | def forward_features(self, x, x_LIDAR, t):
677 | x = self.dimen_redu(x)
678 | x_LIDAR = self.dimen_redu_LIDAR(x_LIDAR)
679 |
680 | # embed patches
681 | x = self.patch_embed(x)
682 | x_LIDAR = self.patch_embed_LIDAR(x_LIDAR)
683 |
684 | # add pos embed w/o cls token
685 | x = x + self.pos_embed[:, 1:, :]
686 | x_LIDAR = x_LIDAR + self.pos_embed_LIDAR[:, 1:, :]
687 |
688 | # append cls token
689 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
690 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
691 |
692 | t_token = self.t_embedder(t)
693 | t_token = torch.unsqueeze(t_token, dim=1)
694 | cls_tokens = cls_tokens + t_token
695 |
696 | x = torch.cat((cls_tokens, x), dim=1)
697 |
698 |
699 | cls_token_LIDAR = self.cls_token_LIDAR + self.pos_embed_LIDAR[:, :1, :]
700 | cls_tokens_LIDAR = cls_token_LIDAR.expand(x_LIDAR.shape[0], -1, -1)
701 |
702 | t_token_LIDAR = self.t_embedder_LIDAR(t)
703 | t_token_LIDAR = torch.unsqueeze(t_token_LIDAR, dim=1)
704 | cls_tokens_LIDAR = cls_tokens_LIDAR + t_token_LIDAR
705 |
706 | x_LIDAR = torch.cat((cls_tokens_LIDAR, x_LIDAR), dim=1)
707 |
708 | # apply Transformer blocks
709 | for blk in self.blocks:
710 | x = blk(x)
711 | if self.global_pool:
712 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
713 | outcome = self.fc_norm(x)
714 | else:
715 | x = self.norm(x)
716 | outcome = x[:, 0]
717 |
718 | for blk_LIDAR in self.blocks_LIDAR:  # LiDAR branch uses its own encoder blocks
719 | x_LIDAR = blk_LIDAR(x_LIDAR)
720 | if self.global_pool_LIDAR:
721 | x_LIDAR = x_LIDAR[:, 1:, :].mean(dim=1) # global pool without cls token
722 | outcome_LIDAR = self.fc_norm_LIDAR(x_LIDAR)
723 | else:
725 | x_LIDAR = self.norm_LIDAR(x_LIDAR)
725 | outcome_LIDAR = x_LIDAR[:, 0]
726 |
727 | outcome_all = torch.cat((outcome, outcome_LIDAR),dim = 1)
728 | return outcome_all
729 |
730 | def forward(self, x, x_LIDAR, t):
731 | x = self.forward_features(x, x_LIDAR, t)
732 | x = self.head(x)
733 | return x
734 |
735 |
736 | def mae_vit_HSIandLIDAR_patch3(**kwargs):
737 | model = MaskedAutoencoderViT(
738 | patch_size=1,
739 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
740 | return model
741 |
742 | def vit_HSI_LIDAR_patch3(**kwargs):
743 | model = vit_HSI_LIDAR(
744 | patch_size=1,
745 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
746 | return model
--------------------------------------------------------------------------------
/pos_embed.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Position embedding utils
3 | # --------------------------------------------------------
4 |
5 | import numpy as np
6 |
7 | import torch
8 |
9 | # --------------------------------------------------------
10 | # 2D sine-cosine position embedding
11 | # References:
12 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
13 | # MoCo v3: https://github.com/facebookresearch/moco-v3
14 | # --------------------------------------------------------
15 | def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
16 | """
17 | grid_size: int of the grid height and width
18 | return:
19 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
20 | """
21 | grid_h = np.arange(grid_size, dtype=np.float32)
22 | grid_w = np.arange(grid_size, dtype=np.float32)
23 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
24 | grid = np.stack(grid, axis=0)
25 |
26 | grid = grid.reshape([2, 1, grid_size, grid_size])
27 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
28 | # NOTE: despite the name, this currently builds the same 2-D grid embedding as get_2d_sincos_pos_embed below
29 | if cls_token:
30 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
31 | return pos_embed
32 |
33 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
34 | """
35 | grid_size: int of the grid height and width
36 | return:
37 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
38 | """
39 | grid_h = np.arange(grid_size, dtype=np.float32)
40 | grid_w = np.arange(grid_size, dtype=np.float32)
41 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
42 | grid = np.stack(grid, axis=0)
43 |
44 | grid = grid.reshape([2, 1, grid_size, grid_size])
45 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
46 | if cls_token:
47 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
48 | return pos_embed
49 |
50 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
51 | assert embed_dim % 2 == 0
52 |
53 | # use half of dimensions to encode grid_h
54 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
55 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
56 |
57 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
58 | return emb
59 |
60 |
61 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
62 | """
63 | embed_dim: output dimension for each position
64 | pos: a list of positions to be encoded: size (M,)
65 | out: (M, D)
66 | """
67 | assert embed_dim % 2 == 0
68 | omega = np.arange(embed_dim // 2, dtype=np.float32)
69 | omega /= embed_dim / 2.
70 | omega = 1. / 10000**omega # (D/2,)
71 |
72 | pos = pos.reshape(-1) # (M,)
73 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
74 |
75 | emb_sin = np.sin(out) # (M, D/2)
76 | emb_cos = np.cos(out) # (M, D/2)
77 |
78 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
79 | return emb
80 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
81 | """
82 | grid_size: int, length of the 1-D grid
83 | return:
84 | pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
85 | """
86 | grid_l = np.arange(grid_size, dtype=np.float32)
87 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
88 | if cls_token:
89 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
90 | return pos_embed
91 |
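All of these builders reduce to the classic sinusoidal formula: position `pos` at frequency index `i` is encoded as `sin(pos * omega_i)` and `cos(pos * omega_i)` with `omega_i = 1 / 10000**(2i/d)`. A quick shape sanity check (assuming this file is importable as `pos_embed`; the numbers are only for illustration):

```python
import numpy as np
from pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid

pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=11, cls_token=True)
print(pe.shape)                  # (1 + 11*11, 64): an all-zero cls-token row, then one row per grid cell
assert np.allclose(pe[0], 0)     # cls-token row is zeros

pe1d = get_1d_sincos_pos_embed_from_grid(8, np.arange(5, dtype=np.float32))
print(pe1d.shape)                # (5, 8): sin half followed by cos half for each position
```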
92 | # --------------------------------------------------------
93 | # Interpolate position embeddings for high-resolution
94 | # References:
95 | # DeiT: https://github.com/facebookresearch/deit
96 | # --------------------------------------------------------
97 | def interpolate_pos_embed(model, checkpoint_model):
98 | if 'pos_embed' in checkpoint_model:
99 | pos_embed_checkpoint = checkpoint_model['pos_embed']
100 | embedding_size = pos_embed_checkpoint.shape[-1]
101 | num_patches = model.patch_embed.num_patches
102 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
103 | # height (== width) for the checkpoint position embedding
104 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
105 | # height (== width) for the new position embedding
106 | new_size = int(num_patches ** 0.5)
107 | # class_token and dist_token are kept unchanged
108 | if orig_size != new_size:
109 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
110 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
111 | # only the position tokens are interpolated
112 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
113 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
114 | pos_tokens = torch.nn.functional.interpolate(
115 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
116 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
117 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
118 | checkpoint_model['pos_embed'] = new_pos_embed
119 |
120 | if 'pos_embed_LIDAR' in checkpoint_model:
121 | pos_embed_LIDAR_checkpoint = checkpoint_model['pos_embed_LIDAR']
122 | embedding_size_LIDAR = pos_embed_LIDAR_checkpoint.shape[-1]
123 | num_patches_LIDAR = model.patch_embed_LIDAR.num_patches
124 | num_extra_tokens_LIDAR = model.pos_embed_LIDAR.shape[-2] - num_patches_LIDAR
125 | # height (== width) for the checkpoint position embedding
126 | orig_size_LIDAR = int((pos_embed_LIDAR_checkpoint.shape[-2] - num_extra_tokens_LIDAR) ** 0.5)
127 | # height (== width) for the new position embedding
128 | new_size_LIDAR = int(num_patches_LIDAR ** 0.5)
129 | # class_token and dist_token are kept unchanged
130 | if orig_size_LIDAR != new_size_LIDAR:
131 | print("Position interpolate from %dx%d to %dx%d" % (orig_size_LIDAR, orig_size_LIDAR, new_size_LIDAR, new_size_LIDAR))
132 | extra_tokens_LIDAR = pos_embed_LIDAR_checkpoint[:, :num_extra_tokens_LIDAR]
133 | # only the position tokens are interpolated
134 | pos_tokens_LIDAR = pos_embed_LIDAR_checkpoint[:, num_extra_tokens_LIDAR:]
135 | pos_tokens_LIDAR = pos_tokens_LIDAR.reshape(-1, orig_size_LIDAR, orig_size_LIDAR, embedding_size_LIDAR).permute(0, 3, 1, 2)
136 | pos_tokens_LIDAR = torch.nn.functional.interpolate(
137 | pos_tokens_LIDAR, size=(new_size_LIDAR, new_size_LIDAR), mode='bicubic', align_corners=False)
138 | pos_tokens_LIDAR = pos_tokens_LIDAR.permute(0, 2, 3, 1).flatten(1, 2)
139 | new_pos_embed_LIDAR = torch.cat((extra_tokens_LIDAR, pos_tokens_LIDAR), dim=1)
140 | checkpoint_model['pos_embed_LIDAR'] = new_pos_embed_LIDAR
141 |
--------------------------------------------------------------------------------
/record.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 |
5 | def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, training_time_ae, testing_time_ae, path):
6 | f = open(path, 'a')
7 |
8 | sentence0 = 'OAs for each iteration are:' + str(oa_ae) + '\n'
9 | f.write(sentence0)
10 | sentence1 = 'AAs for each iteration are:' + str(aa_ae) + '\n'
11 | f.write(sentence1)
12 | sentence2 = 'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n'
13 | f.write(sentence2)
14 | sentence3 = 'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)) + ' ± ' + str(np.std(oa_ae)) + '\n'
15 | f.write(sentence3)
16 | sentence4 = 'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)) + ' ± ' + str(np.std(aa_ae)) + '\n'
17 | f.write(sentence4)
18 | sentence5 = 'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n'
19 | f.write(sentence5)
20 | sentence6 = 'Total average Training time is: ' + str(np.mean(training_time_ae)) + '\n'
21 | f.write(sentence6)
22 | sentence7 = 'Total average Testing time is: ' + str(np.mean(testing_time_ae)) + '\n' + '\n'
23 | f.write(sentence7)
24 |
25 | element_mean = np.mean(element_acc_ae, axis=0)
26 | element_std = np.std(element_acc_ae, axis=0)
27 | sentence8 = "Mean of all elements in confusion matrix: " + str(element_mean) + '\n'
28 | f.write(sentence8)
29 | sentence9 = "Standard deviation of all elements in confusion matrix: " + str(element_std) + '\n'
30 | f.write(sentence9)
31 |
32 | f.close()
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
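A short round-trip sketch of the tokenizer (assuming `bpe_simple_vocab_16e6.txt.gz` sits next to this module, as it does in the repo). Note that `encode` does not add the start/end tokens; the `tokenize` wrapper in `util_CNN_clip.py` does that, and `decode` lowercases and re-spaces the text, so the round trip is close but not byte-identical:

```python
from simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a multimodal fusion patch of a Grass Healthy.")
print(ids)                    # BPE token ids, no <|startoftext|>/<|endoftext|> added here
print(tokenizer.decode(ids))  # roughly "a multimodal fusion patch of a grass healthy ."
```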
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from scipy.io import savemat
2 |
3 | import record
4 | import scipy.io as sio
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torch.utils.data import DataLoader, TensorDataset
9 |
10 | import time
11 | import datetime
12 | import numpy as np
13 | import os
14 |
15 | from config import load_args
16 | from data_read import readdata
17 | from diffusion import create_diffusion
18 | from generate_pic import generate
19 | from hyper_dataset import HyperData
20 | from augment import CenterResizeCrop
21 | from util_CNN_clip import test_batch, pre_train
22 | from pos_embed import interpolate_pos_embed
23 | from timm.models.layers import trunc_normal_
24 | from model import mae_vit_HSIandLIDAR_patch3, vit_HSI_LIDAR_patch3, DDPM_LOSS
25 | from clip_model import build_model
26 | from collections import OrderedDict
27 |
28 | # get class name
29 | def read_classnames(text_file):
30 | """Return a dictionary containing
31 | key-value pairs of <folder name>: <class name>.
32 | """
33 | classnames = OrderedDict()
34 | with open(text_file, "r") as f:
35 | lines = f.readlines()
36 | for line in lines:
37 | line = line.strip().split(" ")
38 | folder = line[0]
39 | classname = " ".join(line[1:])
40 | classnames[folder] = classname
41 | return classnames
42 |
43 | args = load_args()
44 |
45 | mask_ratio = args.mask_ratio
46 | windowsize = args.windowsize
47 | dataset = args.dataset
48 | type = args.type
49 | num_epoch = args.epochs
50 | num_fine_tuned =args.fine_tuned_epochs
51 | lr = args.lr
52 | train_num_per = args.train_num_perclass
53 | num_of_ex = 10
54 | batch_size= args.batch_size
55 |
56 | net_name = 'LDS2AE'
57 | day = datetime.datetime.now()
58 | day_str = day.strftime('%m_%d_%H_%M')
59 | halfsize = int((windowsize-1)/2)
60 | val_num = 1000
61 | Seed = 0
62 | _, _, _, _, _,_,_, _, _, _, _,_, _,gt,s = readdata(type, dataset, windowsize,train_num_per, val_num, 0)
63 | num_of_samples = int(s * 0.2)
64 | nclass = np.max(gt).astype(np.int64)
65 | print(nclass)
66 |
67 | if args.dataset == 'Muufl':
68 | in_chans_LIDAR = 2
69 | else:
70 | in_chans_LIDAR = 1
71 |
72 | size = 1
73 |
74 | KAPPA = []
75 | OA = []
76 | AA = []
77 | TRAINING_TIME = []
78 | TESTING_TIME = []
79 | ELEMENT_ACC = np.zeros((num_of_ex, nclass))
80 | af_result = np.zeros([nclass+3, num_of_ex])
81 | criterion = nn.CrossEntropyLoss()
82 |
83 | for num in range(0,num_of_ex):
84 | print('num:', num)
85 | train_image, train_image_LIDAR, train_label, validation_image1, validation_image_LIDAR1, validation_label1, nTrain_perClass, nvalid_perClass, \
86 | train_index, val_index, index, image, image_LiDAR, gt,s = readdata(type, dataset, windowsize,train_num_per,num_of_samples,num)
87 | ind = np.random.choice(validation_image1.shape[0], 200, replace = False)
88 | validation_image = validation_image1[ind]
89 | validation_image_LIDAR = validation_image_LIDAR1[ind]
90 | validation_label= validation_label1[ind]
91 | nvalid_perClass = np.zeros_like(nvalid_perClass)
92 | nband = train_image.shape[3]
93 |
94 |
95 | train_num = train_image.shape[0]
96 | train_image = np.transpose(train_image,(0,3,1,2))
97 | train_image_LIDAR = np.transpose(train_image_LIDAR,(0,3,1,2))
98 |
99 | validation_image = np.transpose(validation_image,(0,3,1,2))
100 | validation_image1 = np.transpose(validation_image1,(0,3,1,2))
101 | validation_image_LIDAR = np.transpose(validation_image_LIDAR,(0,3,1,2))
102 | validation_image_LIDAR1 = np.transpose(validation_image_LIDAR1,(0,3,1,2))
103 |
104 | if args.augment:
105 | transform_train = [CenterResizeCrop(scale_begin = args.scale, windowsize = windowsize)]
106 | untrain_dataset = HyperData((train_image, train_image_LIDAR, train_label), transform_train)
107 | else:
108 | untrain_dataset = TensorDataset(torch.tensor(train_image), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
109 | untrain_loader = DataLoader(dataset = untrain_dataset, batch_size = batch_size, shuffle = True)
110 |
111 | print("=> creating model '{}'".format(net_name))
112 |
113 |
114 | ######################## pre-train ########################
115 | diffusion = create_diffusion(timestep_respacing="1000")
116 | net = mae_vit_HSIandLIDAR_patch3(img_size=(windowsize,windowsize), in_chans=nband, in_chans_LIDAR=in_chans_LIDAR, hid_chans = args.hid_chans, hid_chans_LIDAR = args.hid_chans_LIDAR, embed_dim=args.encoder_dim, depth=args.encoder_depth, num_heads=args.encoder_num_heads, mlp_ratio=args.mlp_ratio,
117 | decoder_embed_dim=args.decoder_dim, decoder_depth=args.decoder_depth, decoder_num_heads=args.decoder_num_heads, nb_classes=nclass, global_pool=False)
118 |
119 | net.cuda()
120 | optimizer = optim.Adam(net.parameters(),lr = lr, weight_decay= 1e-4)
121 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5,T_mult=2)
122 |
123 | tic1 = time.time()
124 | for epoch in range(num_epoch):
125 | net.train()
126 | total_loss = 0
127 | for idx, (x, x_LIDAR, y) in enumerate(untrain_loader):
128 |
129 | x = x.cuda()
130 | x_LIDAR = x_LIDAR.cuda()
131 | y = y.cuda()
132 | t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],)).cuda()
133 | model_kwargs = dict(y=y)
134 | model_kwargs['mask_ratio'] = mask_ratio
135 |
136 | # Method 1 of forward process
137 | noise = torch.randn_like(x)
138 | x_t = diffusion.q_sample(x, t, noise)
139 | noise_LIDAR = torch.randn_like(x_LIDAR)
140 | x_LIDAR_t = diffusion.q_sample(x_LIDAR, t, noise_LIDAR)
141 |
142 | pred_imgs, pred_imgs_LIDAR, logits, mask, mask_LIDAR = net(x_t, x_LIDAR_t, t, **model_kwargs)
143 |
144 | cls_loss = criterion(logits / args.temperature, y)
145 | loss_mse_m, loss_mse_LIDAR_m, loss_mse_v, loss_mse_LIDAR_v = DDPM_LOSS(pred_imgs, pred_imgs_LIDAR, x, x_LIDAR, mask, mask_LIDAR, size)
146 |
147 | # pre-training uses only the diffusion reconstruction terms; cls_loss is not added here, and the masked-patch MSE is down-weighted by 0.1 relative to the visible-patch MSE
148 | loss = 0.1 * (loss_mse_m + loss_mse_LIDAR_m) + loss_mse_v + loss_mse_LIDAR_v
149 |
150 | optimizer.zero_grad()
151 | loss.backward()
152 | optimizer.step()
153 | total_loss = total_loss + loss
154 |
155 | scheduler.step()
156 | total_loss = total_loss/(idx+1)
157 | state = {'model':net.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
158 | print('epoch:',epoch,
159 | 'loss:',total_loss.data.cpu().numpy())
160 |
161 | toc1 = time.time()
162 | torch.save(state, './net.pt')
163 |
164 |
165 | # ######################## finetune # ########################
166 | classnames_dict = read_classnames("./classnames_houston.txt")
167 | classnames = list(classnames_dict.values())
168 | model = build_model(img_size=(windowsize,windowsize), in_chans=nband, in_chans_LIDAR=in_chans_LIDAR, hid_chans = args.hid_chans, hid_chans_LIDAR = args.hid_chans_LIDAR, embed_dim=args.encoder_dim, depth=args.encoder_depth, num_heads=args.encoder_num_heads, mlp_ratio=args.mlp_ratio,num_classes = nclass, global_pool=False).cuda()
169 |
170 | tic2 = time.time()
171 |
172 | optimizer = optim.Adam(model.parameters(), lr = args.fine_tuned_lr, weight_decay= 1e-5)
173 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [100, 131], gamma = 0.1, last_epoch=-1)
174 | model = pre_train(model, train_image, train_image_LIDAR, train_label,validation_image, validation_image_LIDAR, validation_label, num_fine_tuned, optimizer, scheduler, batch_size, diffusion,classnames, val = False)
175 | toc2 = time.time()
176 | model.load_state_dict(torch.load('Best_val_model/net_params.pkl'))
177 | state_finetune = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
178 | true_cla, overall_accuracy, average_accuracy, kappa, true_label, test_pred, test_index, cm, pred_array = test_batch(model.eval(), image, image_LiDAR, index, 128, nTrain_perClass,nvalid_perClass, halfsize, diffusion,classnames)
179 | toc3 = time.time()
180 |
181 | af_result[:nclass,num] = true_cla
182 | af_result[nclass,num] = overall_accuracy
183 | af_result[nclass+1,num] = average_accuracy
184 | af_result[nclass+2,num] = kappa
185 |
186 | OA.append(overall_accuracy)
187 | AA.append(average_accuracy)
188 | KAPPA.append(kappa)
189 | TRAINING_TIME.append(toc1 - tic1 + toc2 - tic2)
190 | TESTING_TIME.append(toc3 - toc2)
191 | ELEMENT_ACC[num, :] = true_cla
192 | classification_map, gt_map = generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, overall_accuracy, halfsize, dataset, day_str, num, net_name)
193 | file_name = 'result/'+ dataset + '/' + '40_11_0.7_p3' + str(overall_accuracy)+'.mat'
194 | savemat(file_name, {'map':classification_map})
195 | torch.save(state_finetune, 'model/'+ dataset + '/' + '40_11_0.7_p3' + str(overall_accuracy)+'net.pt')
196 | result = np.mean(af_result, axis = 1)
197 | print(result)
198 | print("--------" + net_name + " Training Finished-----------")
199 | record.record_output(OA, AA, KAPPA, ELEMENT_ACC, TRAINING_TIME, TESTING_TIME,
200 | 'records/' + dataset + '/' + net_name + '_' + day_str+ '_' + str(args.epochs)+ '_' + str(args.fine_tuned_epochs) + '_train_num:' + str(train_image.shape[0]) +'_windowsize:' + str(windowsize)+'_mask_ratio_' + str(mask_ratio) + '_temperature_' + str(args.temperature) +
201 | '_augment_' + str(args.augment) +'_aug_scale_' + str(args.scale) + '_loss_ratio_' + str(args.cls_loss_ratio) +'_decoder_dim_' + str(args.decoder_dim) + '_decoder_depth_' + str(args.decoder_depth)+ '.txt')
202 |
203 |
204 |
205 |
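The pre-training loop never runs the reverse diffusion chain: it only samples a timestep `t` and noises both modalities with `diffusion.q_sample`, which in the standard DDPM formulation is the closed-form forward step `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`. A minimal standalone sketch of that step (the beta schedule and shapes below are illustrative, not the repo's configuration, which comes from `create_diffusion`):

```python
import torch

def q_sample(x0, t, noise, alphas_cumprod):
    """Closed-form DDPM forward step: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)          # broadcast over (N, C, H, W)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

betas = torch.linspace(1e-4, 0.02, 1000)                 # illustrative linear schedule, 1000 steps
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(4, 40, 11, 11)                          # "clean" HSI patches (illustrative shape)
t = torch.randint(0, 1000, (4,))
x_t = q_sample(x0, t, torch.randn_like(x0), alphas_cumprod)
print(x_t.shape)                                         # same shape as the input
```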
--------------------------------------------------------------------------------
/util_CNN_clip.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from sklearn import metrics
5 | from torch.utils.data import TensorDataset, DataLoader
6 | import torch.nn as nn
7 |
8 | ce_loss = nn.CrossEntropyLoss()
9 |
10 | def text_valid(classnames):
11 |
12 | ctx_init_1 = "a patch of a" #p1
13 | ctx_init_2 = "a fine grained patch of a" #p2
14 | ctx_init_3 = "a multimodal fusion patch of a" #p3
15 | ctx_init_4 = "a patch of a" #p4
16 | classnames = [name.replace("_", " ") for name in classnames]
17 | prompts = [ctx_init_3 + " " + name + "." for name in classnames]
18 | tokenized_prompts_fuse = torch.cat([tokenize(p) for p in prompts]) # (n_cls, n_tkn)
19 |
20 |
21 | return tokenized_prompts_fuse
22 | def tr_acc(model, image, image_LIDAR, label, diffusion,classnames):
23 | train_dataset = TensorDataset(torch.tensor(image), torch.tensor(image_LIDAR), torch.tensor(label))
24 | train_loader = DataLoader(dataset = train_dataset, batch_size = 64, shuffle = False)
25 | train_loss = 0
26 | corr_num = 0
27 | for idx, (image_batch, image_LIDAR_batch, label_batch) in enumerate(train_loader):
28 | trans_image_batch = image_batch.cuda()
29 | trans_image_LIDAR_batch = image_LIDAR_batch.cuda()
30 | label_batch = label_batch.cuda()
31 | t = torch.randint(0, diffusion.num_timesteps, (trans_image_batch.shape[0],)).cuda()
32 | text = text_valid(classnames).cuda()
33 | logits,_ = model(trans_image_batch, trans_image_LIDAR_batch, t,text)
34 |
35 | if isinstance(logits,tuple):
36 | logits = logits[-1]
37 | pred = torch.max(logits, dim=1)[1]
38 | loss = ce_loss(logits, label_batch)
39 | train_loss = train_loss + loss.cpu().data.numpy()
40 | corr_num = torch.eq(pred, label_batch).float().sum().cpu().numpy() + corr_num
41 | return corr_num/image.shape[0], train_loss/(idx+1)
42 |
43 | from simple_tokenizer import SimpleTokenizer as _Tokenizer
44 | _tokenizer = _Tokenizer()
45 | from pkg_resources import packaging
46 | def tokenize(texts, context_length: int = 77, truncate: bool = False):
47 | """
48 | Returns the tokenized representation of given input string(s)
49 |
50 | Parameters
51 | ----------
52 | texts : Union[str, List[str]]
53 | An input string or a list of input strings to tokenize
54 |
55 | context_length : int
56 | The context length to use; all CLIP models use 77 as the context length
57 |
58 | truncate: bool
59 | Whether to truncate the text in case its encoding is longer than the context length
60 |
61 | Returns
62 | -------
63 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
64 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
65 | """
66 | if isinstance(texts, str):
67 | texts = [texts]
68 |
69 | sot_token = _tokenizer.encoder["<|startoftext|>"]
70 | eot_token = _tokenizer.encoder["<|endoftext|>"]
71 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
72 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
73 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
74 | else:
75 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
76 |
77 | for i, tokens in enumerate(all_tokens):
78 | if len(tokens) > context_length:
79 | if truncate:
80 | tokens = tokens[:context_length]
81 | tokens[-1] = eot_token
82 | else:
83 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
84 | result[i, :len(tokens)] = torch.tensor(tokens)
85 |
86 | return result
87 |
88 | def text_target(classnames,batch_target):
89 |
90 | ctx_init_1 = "a patch of a" #p1
91 | ctx_init_2 = "a fine grained patch of a" #p2
92 | ctx_init_3 = "a multimodal fusion patch of a" #p3
93 | ctx_init_4 = "a patch of a" #p4
94 | classnames = [name.replace("_", " ") for name in classnames]
95 | prompts = [ctx_init_3 + " " + classnames[batch_target[i].item()] + "." for i in range(len(batch_target))]
96 | tokenized_prompts_fuse = torch.cat([tokenize(p) for p in prompts]) # (n_cls, n_tkn)
97 |
98 |
99 | return tokenized_prompts_fuse
100 |
101 |
102 |
103 | def pre_train(model, train_image, train_image_LIDAR, train_label,validation_image, validation_image_LIDAR, validation_label, epoch, optimizer, scheduler, bs, diffusion,classnames, val = False, ):
104 | train_dataset = TensorDataset(torch.tensor(train_image), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
105 | train_loader = DataLoader(dataset = train_dataset, batch_size = bs, shuffle = True)
106 | Train_loss = []
107 | Train_acc = []
108 | Val_loss = []
109 | Val_acc = []
110 | BestAcc = 0
111 | for i in range(epoch):
112 | model.train()
113 | train_loss = 0
114 | for idx, (image_batch, image_LIDAR_batch, label_batch) in enumerate(train_loader):
115 |
116 | trans_image_batch = image_batch.cuda()
117 | trans_image_LIDAR_batch = image_LIDAR_batch.cuda()
118 | label_batch = label_batch.cuda()
119 | text = text_target(classnames,label_batch).cuda()
120 | t = torch.randint(0, diffusion.num_timesteps, (trans_image_batch.shape[0],)).cuda()
121 | # logits = model(trans_image_batch, trans_image_LIDAR_batch, t,text)
122 | # loss = ce_loss(logits, label_batch)
123 | logits_per_image_x,logits_per_text = model(trans_image_batch, trans_image_LIDAR_batch, t,text)
124 | labels = torch.arange(len(logits_per_image_x)).to(logits_per_image_x.device)
125 | loss = ce_loss(logits_per_image_x, labels)
126 | loss += ce_loss(logits_per_text, labels)
127 | optimizer.zero_grad()
128 | loss.backward()
129 | optimizer.step()
130 |
131 | train_loss = train_loss + loss
132 | scheduler.step()
133 | train_loss = train_loss / (idx + 1)
134 | train_acc, tr_loss = tr_acc(model.eval(), train_image, train_image_LIDAR, train_label, diffusion,classnames)
135 | val_acc, val_loss= tr_acc(model.eval(), validation_image, validation_image_LIDAR, validation_label, diffusion,classnames)
136 | if val_acc > BestAcc:
137 | torch.save(model.state_dict(), 'Best_val_model/' + 'net_params.pkl')
138 | BestAcc = val_acc
139 | print("epoch {}, training loss: {:.4f}, train acc:{:.4f}, valid acc:{:.4f}".format(i, train_loss.item(), train_acc*100, val_acc*100))
140 |
141 | if val:
142 | Train_loss.append(tr_loss)
143 | Val_loss.append(val_loss)
144 | Train_acc.append(train_acc)
145 | Val_acc.append(val_acc)
146 | if val:
147 | return model, [Train_loss, Train_acc, Val_loss, Val_acc]
148 | else:
149 | return model
150 |
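The fine-tuning objective in `pre_train` is a CLIP-style symmetric contrastive loss: the model returns image-to-text and text-to-image logit matrices, and each sample should match its own position in the batch. A standalone sketch of that loss on dummy features (the feature dimension, normalization, and temperature here are illustrative; the actual logits come from the model built in `clip_model.py`):

```python
import torch
import torch.nn.functional as F

def clip_style_loss(image_feat, text_feat, temperature=0.07):
    """Symmetric cross-entropy over cosine-similarity logits."""
    image_feat = F.normalize(image_feat, dim=-1)
    text_feat = F.normalize(text_feat, dim=-1)
    logits_per_image = image_feat @ text_feat.t() / temperature   # (N, N)
    logits_per_text = logits_per_image.t()
    labels = torch.arange(image_feat.shape[0], device=image_feat.device)
    return F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)

loss = clip_style_loss(torch.randn(8, 512), torch.randn(8, 512))
print(loss.item())
```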
151 | def test_batch(model, image, image_LIDAR, index, BATCH_SIZE, nTrain_perClass, nvalid_perClass, halfsize, diffusion,classnames):
152 | ind = index[0][nTrain_perClass[0]+ nvalid_perClass[0]:,:]
153 | nclass = len(index)
154 | true_label = np.zeros(ind.shape[0], dtype = np.int32)
155 | for i in range(1, nclass):
156 | ddd = index[i][nTrain_perClass[i] + nvalid_perClass[i]:,:]
157 | ind = np.concatenate((ind, ddd), axis = 0)
158 | tr_label = np.ones(ddd.shape[0], dtype = np.int32) * i
159 | true_label = np.concatenate((true_label, tr_label), axis = 0)
160 | test_index = np.copy(ind)
161 | length = ind.shape[0]
162 | if length % BATCH_SIZE != 0:
163 | add_num = BATCH_SIZE - length % BATCH_SIZE
164 | ff = range(length)
165 | add_ind = np.random.choice(ff, add_num, replace = False)
166 | add_ind = ind[add_ind]
167 | ind = np.concatenate((ind,add_ind), axis =0)
168 |
169 | pred_array = np.zeros([ind.shape[0],nclass], dtype = np.float32)
170 | n = ind.shape[0] // BATCH_SIZE
171 | windowsize = 2 * halfsize + 1
172 | image_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image.shape[2]], dtype=np.float32)
173 | image_LIDAR_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image_LIDAR.shape[2]], dtype=np.float32)
174 | for i in range(n):
175 | for j in range(BATCH_SIZE):
176 | m = ind[BATCH_SIZE*i+j, :]
177 | image_batch[j,:,:,:] = image[(m[0] - halfsize):(m[0] + halfsize + 1),
178 | (m[1] - halfsize):(m[1] + halfsize + 1),:]
179 | image_b = np.transpose(image_batch,(0,3,1,2))
180 | image_LIDAR_batch[j,:,:,:] = image_LIDAR[(m[0] - halfsize):(m[0] + halfsize + 1),
181 | (m[1] - halfsize):(m[1] + halfsize + 1),:]
182 | image_LIDAR_b = np.transpose(image_LIDAR_batch,(0,3,1,2))
183 |
184 | t = torch.randint(0, diffusion.num_timesteps, (image_b.shape[0],)).cuda()
185 | text = text_valid(classnames).cuda()
186 | logits,_ = model(torch.tensor(image_b).cuda(), torch.tensor(image_LIDAR_b).cuda(), t,text)
187 | if isinstance(logits,tuple):
188 | logits = logits[-1]
189 | pred_array[i*BATCH_SIZE:(i+1)*BATCH_SIZE] = torch.softmax(logits, dim = 1).cpu().data.numpy()
190 | pred_array = pred_array[range(length)]
191 | predict_label = np.argmax(pred_array, axis=1)
192 |
193 |
194 | confusion_matrix = metrics.confusion_matrix(true_label, predict_label)
195 | overall_accuracy = metrics.accuracy_score(true_label, predict_label)
196 |
197 | true_cla = np.zeros(nclass, dtype=np.int64)
198 | for i in range(nclass):
199 | true_cla[i] = confusion_matrix[i,i]
200 | test_num_class = np.sum(confusion_matrix,1)
201 | test_num = np.sum(test_num_class)
202 | num1 = np.sum(confusion_matrix,0)
203 | po = overall_accuracy
204 | pe = np.sum(test_num_class*num1)/(test_num*test_num)
205 | kappa = (po-pe)/(1-pe)*100
206 | true_cla = np.true_divide(true_cla,test_num_class)*100
207 | average_accuracy = np.average(true_cla)
208 | print('overall_accuracy: {0:f}'.format(overall_accuracy*100))
209 | print('average_accuracy: {0:f}'.format(average_accuracy))
210 | print('kappa:{0:f}'.format(kappa))
211 | return true_cla, overall_accuracy*100, average_accuracy, kappa, true_label, predict_label, test_index, confusion_matrix, pred_array
212 |
--------------------------------------------------------------------------------