├── GCN.py
├── Graphormer.py
├── MAIN_Mamba_ip.py
├── MAIN_Mamba_sa.py
├── MAIN_Mamba_uh2018.py
├── Mamba.py
├── README.md
├── code_acc.py
├── functions.py
└── rope.py
/GCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
5 | from einops import rearrange, repeat
6 |
7 |
8 |
9 | class GCNLayer(nn.Module):
10 |     def __init__(self, input_dim: int, output_dim: int):
11 |         super(GCNLayer, self).__init__()
12 |         self.BN = nn.BatchNorm1d(input_dim)
13 |         self.Activition = nn.LeakyReLU()
14 |         self.sigma1 = torch.nn.Parameter(torch.tensor([0.1], requires_grad=True))
15 |         self.GCN_liner_theta_1 = nn.Sequential(nn.Linear(input_dim, 256))
16 |         self.GCN_liner_out_1 = nn.Sequential(nn.Linear(input_dim, output_dim))
17 |
18 |
19 |     def A_to_D_inv(self, A: torch.Tensor):
20 |         D = A.sum(2)
21 |         batch,l=D.shape
22 |         D1=torch.reshape(D, (batch * l,1))
23 |         D1=D1.squeeze(1)
24 |         D2=torch.pow(D1, -0.5)
25 |         D2=torch.reshape(D2,(batch,l))
26 |         D_hat=torch.zeros([batch,l,l],dtype=torch.float)
27 |         for i in range(batch):
28 |             D_hat[i] = torch.diag(D2[i])
29 |         return D_hat.to(A.device)  # keep D_hat on the same device as A instead of hard-coding .cuda()
30 |
31 |     def forward(self, H, A ):
32 |         nodes_count = A.shape[1]
33 |         I = torch.eye(nodes_count, nodes_count, requires_grad=False).to(device)
34 |         A = A + I
35 |         (batch, l, c) = H.shape
36 |         H1 = torch.reshape(H,(batch*l, c))
37 |         H2 = self.BN(H1)
38 |         H=torch.reshape(H2,(batch,l, c))
39 |         D_hat = self.A_to_D_inv(A)
40 |         A_hat = torch.matmul(D_hat, torch.matmul(A,D_hat))  # symmetric normalization D^-1/2 (A+I) D^-1/2
41 |         # A_hat = I + A_hat
42 |         output = torch.matmul(A_hat, self.GCN_liner_out_1(H))  # matrix multiplication
43 |         output = self.Activition(output)
44 |         return output
45 |
46 |
47 | class GCN(nn.Module):
48 |     def __init__(self, height: int, width: int, changel: int, layers_count: int):
49 |         super(GCN, self).__init__()
50 |         self.channel = changel
51 |         self.height = height
52 |         self.width = width
53 |         self.GCN_Branch = nn.Sequential()
54 |         for i in range(layers_count):
55 |             self.GCN_Branch.add_module('GCN_Branch' + str(i), GCNLayer(self.channel, self.channel))
56 |
57 |         # self.Softmax_linear = nn.Sequential(nn.Linear(64, self.class_count))
58 |
59 |         self.BN = nn.BatchNorm1d(64)
60 |
61 |     def forward(self, x: torch.Tensor,A: torch.Tensor):
62 |         H = x
63 |         for i in range(len(self.GCN_Branch)):
64 |             H = self.GCN_Branch[i](H, A)
65 |         return H
66 |
--------------------------------------------------------------------------------
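A minimal sketch of the propagation rule GCN.py implements: self-loops are added first, then the adjacency is symmetrically normalized as D^-1/2 (A+I) D^-1/2. The toy graph below is hypothetical, and the normalization is re-implemented standalone so it runs anywhere:

import torch

def normalize_adjacency(A: torch.Tensor) -> torch.Tensor:
    # A: (batch, n, n). Add self-loops, then D^-1/2 (A+I) D^-1/2,
    # mirroring GCNLayer.forward + A_to_D_inv above.
    n = A.shape[1]
    A = A + torch.eye(n, dtype=A.dtype)
    D = A.sum(dim=2)                           # node degrees, (batch, n)
    D_inv_sqrt = torch.diag_embed(D.pow(-0.5))
    return D_inv_sqrt @ A @ D_inv_sqrt

A = torch.tensor([[[0., 1., 0.],
                   [1., 0., 1.],
                   [0., 1., 0.]]])             # toy 3-node path graph
print(normalize_adjacency(A))                  # symmetric; eigenvalues lie in (-1, 1], which keeps stacked layers stable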
/Graphormer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | import torch.nn.functional as F
7 | import math
8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9 | from einops import rearrange, repeat
10 |
11 |
12 | def gain_neighborhood_band(x_train, band, band_patch, patch_all):
13 |     nn = band_patch // 2
14 |     pp = (patch_all) // 2
15 |     x_train_band = torch.zeros((x_train.shape[0], patch_all*band_patch, band),dtype=float)  # 64*27*200
16 |     # center block
17 |     x_train_band[:,nn*patch_all:(nn+1)*patch_all,:] = x_train
18 |     # mirror on the left
19 |     for i in range(nn):
20 |         if pp > 0:
21 |             x_train_band[:,i*patch_all:(i+1)*patch_all,:i+1] = x_train[:,:,band-i-1:]
22 |             x_train_band[:,i*patch_all:(i+1)*patch_all,i+1:] = x_train[:,:,:band-i-1]
23 |         else:
24 |             x_train_band[:,i:(i+1),:(nn-i)] = x_train[:,0:1,(band-nn+i):]
25 |             x_train_band[:,i:(i+1),(nn-i):] = x_train[:,0:1,:(band-nn+i)]
26 |     # mirror on the right
27 |     for i in range(nn):
28 |         if pp > 0:
29 |             x_train_band[:,(nn+i+1)*patch_all:(nn+i+2)*patch_all,:band-i-1] = x_train[:,:,i+1:]
30 |             x_train_band[:,(nn+i+1)*patch_all:(nn+i+2)*patch_all,band-i-1:] = x_train[:,:,:i+1]
31 |         else:
32 |             x_train_band[:,(nn+1+i):(nn+2+i),(band-i-1):] = x_train[:,0:1,:(i+1)]
33 |             x_train_band[:,(nn+1+i):(nn+2+i),:(band-i-1)] = x_train[:,0:1,(i+1):]
34 |     return x_train_band
35 |
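A quick shape check for gain_neighborhood_band (sizes below are hypothetical; assumes einops is installed so Graphormer.py imports cleanly). With band_patch=3 the patch is replicated once on each side with the spectral axis rotated by one band:

import torch
from Graphormer import gain_neighborhood_band

x = torch.rand(4, 81, 30)        # (batch, patch*patch, band) for a 9x9 patch
out = gain_neighborhood_band(x, band=30, band_patch=3, patch_all=81)
print(out.shape)                 # torch.Size([4, 243, 30]) = (batch, 81*3, band)
# Rows 81:162 are x itself; the two outer groups are x with the band axis
# wrapped around by one position (the "mirror" above is a circular shift).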
36 |
37 | class Residual(nn.Module):
38 |     def __init__(self, fn):
39 |         super().__init__()
40 |         self.fn = fn
41 |     def forward(self, x, **kwargs):
42 |         return self.fn(x, **kwargs) + x
43 |
44 | class PreNorm(nn.Module):
45 |     def __init__(self, dim, fn):
46 |         super().__init__()
47 |         self.norm = nn.LayerNorm(dim)
48 |         self.fn = fn
49 |     def forward(self, x, **kwargs):
50 |         return self.fn(self.norm(x), **kwargs)
51 |
52 | class FeedForward(nn.Module):
53 |     def __init__(self, dim, hidden_dim, dropout = 0.):
54 |         super().__init__()
55 |         self.net = nn.Sequential(
56 |             nn.Linear(dim, hidden_dim),
57 |             nn.GELU(),
58 |             nn.Dropout(dropout),
59 |             nn.Linear(hidden_dim, dim),
60 |             nn.Dropout(dropout)
61 |         )
62 |     def forward(self, x):
63 |         return self.net(x)
64 |
65 | class Attention(nn.Module):
66 |     def __init__(self, dim, heads, dim_head, dropout,dis,D2,edge):
67 |         super().__init__()
68 |
69 |         # self.degree_encoder = nn.Embedding(10, dim, padding_idx=0)  # bucket nodes by degree (0-80) and embed; same dimension as the tokens
70 |         # self.spatial_pos_encoder = nn.Embedding(8, heads, padding_idx=0)  # Euclidean distances inside a 5*5 window, embedded; same dimension as the heads
71 |         # self.edge_dis_encoder = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
72 |         # self.edge_weight = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
73 |
74 |         inner_dim = dim_head * heads
75 |         self.heads = heads
76 |         self.scale = dim_head ** -0.5
77 |
78 |         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
79 |         self.to_out = nn.Sequential(
80 |             nn.Linear(inner_dim, dim),
81 |             nn.Dropout(dropout)
82 |         )
83 |         # self.degree_encoder = nn.Embedding(10, dim, padding_idx=0)  # bucket nodes by degree (0-80) and embed; same dimension as the tokens
84 |         # self.spatial_pos_encoder = nn.Embedding(8, heads, padding_idx=0)  # Euclidean distances inside a 5*5 window, embedded; same dimension as the heads
85 |         # self.edge_dis_encoder = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
86 |         # self.edge_weight = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
87 |
88 |     def forward(self, x, degree,mask = None):
89 |         # x:[b,n,dim]
90 |         b, n, _, h = *x.shape, self.heads
91 |         ####################### centrality encoding ###############################
92 |         # x = x + self.degree_encoder(self.D2)  # centrality encoding
93 |         ######################### centrality encoding ################################
94 |         # get qkv tuple:([b,n,head_num*head_dim],[...],[...])
95 |         qkv = self.to_qkv(x).chunk(3, dim = -1)
96 |         # split q,k,v from [b,n,head_num*head_dim] -> [b,head_num,n,head_dim]
97 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
98 |
99 |         # transpose(k) * q / sqrt(head_dim) -> [b,head_num,n,n]
100 |         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
101 |         ##################### distance encoding #### edge encoding ##############################
102 |         # edge = torch.tensor([0, 1, 2, 3]).cuda()
103 |         # spatial_pos_encoder = self.spatial_pos_encoder(self.dis).unsqueeze(0).permute(0, 3, 1, 2)
104 |         # edge_dis_encoder = torch.mul(self.edge_dis_encoder(edge), self.edge_weight(edge))  # 8,8,4
105 |         # edge_dis_encoder = torch.matmul(self.edge, edge_dis_encoder)
106 |         # edge_dis_encoder = edge_dis_encoder.unsqueeze(0).permute(0, 3, 1, 2)
107 |         # dots = dots + spatial_pos_encoder  # distance encoding
108 |         # dots = dots + edge_dis_encoder  # edge encoding
109 |         ######################################################################
110 |         mask_value = -torch.finfo(dots.dtype).max
111 |
112 |         # mask value: -inf
113 |         if mask is not None:
114 |             mask = F.pad(mask.flatten(1), (1, 0), value = True)
115 |             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
116 |             mask = mask[:, None, :] * mask[:, :, None]
117 |             dots.masked_fill_(~mask, mask_value)
118 |             del mask
119 |
120 |         # softmax normalization -> attention matrix
121 |         attn = dots.softmax(dim=-1)
122 |         # value * attention matrix -> output
123 |         out = torch.einsum('bhij,bhjd->bhid', attn, v)
124 |         # cat all output -> [b, n, head_num*head_dim]
125 |         out = rearrange(out, 'b h n d -> b n (h d)')
126 |         out = self.to_out(out)
127 |         return out
128 |
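The commented-out pieces above (and in the disabled Attention variant below) sketch Graphormer-style structural encodings: a degree (centrality) embedding added to the tokens, and a spatial-distance embedding added per head to the raw attention scores. A standalone illustration of that idea, with every size hypothetical:

import torch
import torch.nn as nn

n, dim, heads = 81, 64, 4
degree_encoder = nn.Embedding(10, dim)        # centrality encoding, added to tokens
spatial_pos_encoder = nn.Embedding(8, heads)  # distance bias, added to attention scores

x = torch.rand(2, n, dim)                     # token features
degree = torch.randint(0, 10, (2, n))         # per-node degree buckets
dis = torch.randint(0, 8, (n, n))             # pairwise distance buckets

x = x + degree_encoder(degree)                            # (2, n, dim)
bias = spatial_pos_encoder(dis).permute(2, 0, 1)          # (heads, n, n)
dots = torch.rand(2, heads, n, n) + bias.unsqueeze(0)     # biased scores, (2, 4, 81, 81)
print(dots.shape)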
129 | # class Attention(nn.Module):
130 | #     def __init__(self, dim, heads, dim_head, dropout,dis,D2,edge):
131 | #         super().__init__()
132 | #         inner_dim = dim_head * heads
133 | #         self.heads = heads
134 | #         self.scale = dim_head ** -0.5
135 | #         self.dis = dis
136 | #         self.D2 = D2
137 | #         self.edge=edge
138 | #         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
139 | #         self.to_out = nn.Sequential(
140 | #             nn.Linear(inner_dim, dim),
141 | #             nn.Dropout(dropout)
142 | #         )
143 | #         self.degree_encoder = nn.Embedding(10, dim, padding_idx=0)  # bucket nodes by degree (0-80) and embed; same dimension as the tokens
144 | #         self.spatial_pos_encoder = nn.Embedding(8, heads, padding_idx=0)  # Euclidean distances inside a 5*5 window, embedded; same dimension as the heads
145 | #         self.edge_dis_encoder = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
146 | #         self.edge_weight = nn.Embedding(4, heads, padding_idx=0)  # encode each edge into a weight, then multiply the distance by that weight
147 | #     def forward(self, x, degree,mask = None):
148 | #         b, n, _, h = *x.shape, self.heads
149 | #         # edge= torch.tensor([1,2,3,4,5,6,7,8]).cuda()
150 | #         # edge = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]).cuda()
151 | #         edge = torch.tensor([0, 1, 2, 3]).cuda()
152 | #         # edge = torch.tensor([0, 1, 2, 3]).cuda()
153 | #         # c = self.degree_encoder(self.D2)  # centrality encoding
154 | #         # cc=self.degree_encoder(self.D2).unsqueeze(0)  # centrality encoding
155 | #         # x=x+self.degree_encoder( .D2)  # centrality encoding
156 | #         qkv = self.to_qkv(x).chunk(3, dim = -1)
157 | #         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
158 | #         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale  # [64,4,81,81]
159 | #         spatial_pos_encoder=self.spatial_pos_encoder(self.dis).unsqueeze(0).permute(0, 3, 1, 2)
160 | #         edge_dis_encoder=torch.mul(self.edge_dis_encoder(edge),self.edge_weight(edge))  # 8,8,4
161 | #         edge_dis_encoder=torch.matmul(self.edge,edge_dis_encoder )
162 | #         edge_dis_encoder=edge_dis_encoder.unsqueeze(0).permute(0, 3, 1, 2)
163 | #         # dots = dots + spatial_pos_encoder  # distance encoding
164 | #         # dots = dots + edge_dis_encoder  # edge encoding
165 | #         mask_value = -torch.finfo(dots.dtype).max
166 | #         # mask value: -inf
167 | #         if mask is not None:
168 | #             mask = F.pad(mask.flatten(1), (1, 0), value = True)
169 | #             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
170 | #             mask = mask[:, None, :] * mask[:, :, None]
171 | #             dots.masked_fill_(~mask, mask_value)
172 | #             del mask
173 | #
174 | #         # softmax normalization -> attention matrix
175 | #         attn = dots.softmax(dim=-1)
176 | #         # value * attention matrix -> output
177 | #         out = torch.einsum('bhij,bhjd->bhid', attn, v)
178 | #         # cat all output -> [b, n, head_num*head_dim]
179 | #         out = rearrange(out, 'b h n d -> b n (h d)')
180 | #         out = self.to_out(out)
181 | #         return out
182 |
183 | class Transformer(nn.Module):
184 |     def __init__(self, dim, depth, heads, dim_head, mlp_head, dropout, num_channel, mode, dis,D2,edge):
185 |         super().__init__()
186 |
187 |         self.layers = nn.ModuleList([])
188 |         for _ in range(depth):
189 |             self.layers.append(nn.ModuleList([
190 |                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, dis=dis,D2=D2,edge=edge))),
191 |                 Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
192 |             ]))
193 |
194 |         self.mode = mode
195 |         self.skipcat = nn.ModuleList([])
196 |         for _ in range(depth-2):
197 |             self.skipcat.append(nn.Conv2d(num_channel, num_channel, [1, 2], 1, 0))
198 |
199 |     def forward(self, x, degree, mask = None):
200 |         if self.mode == 'ViT':
201 |             for attn, ff in self.layers:
202 |                 x = attn(x,degree=degree, mask = mask)
203 |                 x = ff(x)
204 |         elif self.mode == 'CAF':
205 |             last_output = []
206 |             nl = 0
207 |             for attn, ff in self.layers:
208 |                 last_output.append(x)
209 |                 if nl > 1:
210 |                     x = self.skipcat[nl-2](torch.cat([x.unsqueeze(3), last_output[nl-2].unsqueeze(3)], dim=3)).squeeze(3)
211 |                 x = attn(x,degree=degree, mask = mask)
212 |                 x = ff(x)
213 |                 nl += 1
214 |
215 |         return x
216 |
217 | class ViT(nn.Module):
218 |     def __init__(self, band, num_token, num_classes, dim, depth, heads, mlp_dim,dis,D2,edge, pool='cls', channels=1, dim_head = 16, dropout=0., emb_dropout=0., mode='ViT'):
219 |         super().__init__()
220 |
221 |         self.num_classes=num_classes
222 |         self.pos_embedding = nn.Parameter(torch.randn(1, num_token, dim))  # 1,201,64
223 |         self.patch_to_embedding = nn.Linear(band, dim)
224 |         # self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
225 |
226 |         self.dropout = nn.Dropout(emb_dropout)
227 |         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_token, mode, dis,D2,edge)
228 |
229 |         self.pool = pool
230 |         self.to_latent = nn.Identity()
231 |
232 |         self.mlp_head = nn.Sequential(
233 |             nn.LayerNorm(dim),
234 |             nn.Linear(dim, num_classes)
235 |         )
236 |     def forward(self, x, center_pos,degree,mask = None):
237 |
238 |         x=x.to(torch.float32)
239 |         x = self.patch_to_embedding(x)
240 |         batch, n, _ = x.shape
241 |         pos=self.pos_embedding[:, :n]
242 |         x += pos
243 |         x = self.dropout(x)
244 |         x = self.transformer(x,degree, mask)
245 |         x = self.mlp_head(x)  # [64,81,16]
246 |         x_out=torch.zeros((batch, self.num_classes),dtype=float).to(device)
247 |         for i in range(batch):
248 |             x_out[i]=x[i,center_pos[i],:]
249 |         return x_out
250 |
251 |
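In 'CAF' mode the Transformer fuses each layer's input with the output from two layers back. The skipcat convolution treats the token axis as channels and the (current, earlier) pair as a width-2 spatial dimension; a shape-level sketch with hypothetical sizes:

import torch
import torch.nn as nn

B, N, D = 2, 81, 64
x_now  = torch.rand(B, N, D)                  # current layer input
x_prev = torch.rand(B, N, D)                  # output from two layers back

skipcat = nn.Conv2d(N, N, kernel_size=(1, 2), stride=1, padding=0)
fused = skipcat(torch.cat([x_now.unsqueeze(3), x_prev.unsqueeze(3)], dim=3)).squeeze(3)
print(fused.shape)                            # torch.Size([2, 81, 64])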
requires_grad=False).to(device)
278 |         # Scheme 1: first-order Chebyshev
279 |         (batch, l, c) = H.shape
280 |         H1 = torch.reshape(H,(batch*l, c))
281 |         H2 = self.BN(H1)
282 |         H=torch.reshape(H2,(batch,l, c))
283 |         D_hat = self.A_to_D_inv(A)
284 |         A_hat = torch.matmul(D_hat, torch.matmul(A,D_hat))  # symmetric normalization D^-1/2 A D^-1/2
285 |         A_hat = I + A_hat
286 |         output = torch.matmul(A_hat, self.GCN_liner_out_1(H))  # matrix multiplication
287 |         output = self.Activition(output)
288 |         return output
289 |         # Scheme 2: second-order Chebyshev
290 |         # H = H.to(torch.float16)
291 |         # H = self.BN(H).to(torch.float16)
292 |         # A1 = self.A1.to(torch.float16)
293 |         # A2 = self.A2.to(torch.float16)
294 |         # D1_hat = self.A_to_D_inv(A1).to(torch.float16)
295 |         # A1_hat = torch.matmul(D1_hat, torch.matmul(A1, D1_hat))  # symmetric normalization
296 |         # M = self.I + A1_hat + torch.matmul(A1_hat.to(torch.float16), A1_hat.to(torch.float16))
297 |         # W = math.exp(-1) / (math.exp(-1) + math.exp(-4)) * A1 + math.exp(-4) / (math.exp(-1) + math.exp(-4)) * (
298 |         #     A2 - A1) + self.I
299 |         # M = M.mul(W)  # element-wise multiplication
300 |         # output = torch.mm(M.to(torch.float16), self.GCN_liner_out_1(H.to(torch.float32)).to(torch.float16))  # matrix multiplication
301 |         # output = self.Activition(output)
302 |         # return output, A1
303 |
304 | class neigh_Conv(nn.Module):
305 |     def __init__(self, channel, neigh_number):
306 |         super(neigh_Conv, self).__init__()
307 |         self.neigh_Branch = nn.Sequential()
308 |         self.neigh_number=neigh_number
309 |         for i in range(channel-neigh_number+1):
310 |             self.neigh_Branch.add_module('neigh_Branch' + str(i), nn.Conv2d(neigh_number, 1, kernel_size = (1,1), stride=1))
311 |
312 |     def forward(self, x):
313 |         batch,c,w,h = x.shape
314 |         for i in range(c-self.neigh_number+1):
315 |             if i==0:
316 |                 A=self.neigh_Branch[i](x[:,i:i+self.neigh_number,:,:])  # [64 1 21 1]
317 |             if i>0:
318 |                 B= self.neigh_Branch[i](x[:, i:i + self.neigh_number, :, :])  # [64 1 21 1]
319 |                 A = torch.cat((A,B),1)
320 |         return A
321 |
322 | class neigh_Conv2(nn.Module):
323 |     def __init__(self, channel, neigh_number):
324 |         super(neigh_Conv2, self).__init__()
325 |         self.neigh_Branch = nn.Sequential()
326 |         self.neigh_number=neigh_number
327 |         for i in range(channel):
328 |             self.neigh_Branch.add_module('neigh_Branch' + str(i), nn.Conv2d(neigh_number, 1, kernel_size = (1,1), stride=1))
329 |
330 |     def forward(self, x):
331 |         batch,c,w,h = x.shape
332 |         start=int((self.neigh_number-1)/2)  # 3 1
333 |         end = int(c-1-start)  # c-1
334 |         for i in range(c):
335 |             self_c = x[:, i, :, :]
336 |             self_c=self_c.unsqueeze(1)
337 |             if i==0:
338 |                 A=self_c+self.neigh_Branch[i](x[:,i:i+self.neigh_number,:,:])  # [64 1 21 1]
339 |             if i>0:
340 |                 if i<start:
341 |                     B= self_c + self.neigh_Branch[i](x[:, 0:self.neigh_number, :, :])  # [64 1 21 1]
342 |                 if i>=start and i<=end:
343 |                     B= self_c + self.neigh_Branch[i](x[:, (i-start):(i-start+ self.neigh_number), :, :])  # [64 1 21 1]
344 |                 if i>end:
345 |                     B= self_c + self.neigh_Branch[i](x[:, c-self.neigh_number:c , :, :])  # [64 1 21 1]
346 |                 A = torch.cat((A,B),1)
347 |         return A
348 |
349 |
350 | class ChannelAttention(nn.Module):
351 |     def __init__(self, in_planes, ratio=16):
352 |         super(ChannelAttention, self).__init__()
353 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
354 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
355 |
356 |         self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),  # use `ratio` rather than a hard-coded 16
357 |                                 nn.ReLU(),
358 |                                 nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
359 |         self.sigmoid = nn.Sigmoid()
360 |
361 |     def forward(self, x):
362 |         avg_out = self.fc(self.avg_pool(x))
363 |         max_out = self.fc(self.max_pool(x))
364 |         out = avg_out + max_out
365 |         return self.sigmoid(out)
366 |
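ChannelAttention is a squeeze-and-excitation-style gate (used as self.ca in the GCN class below). A usage sketch; the feature-map shape here is an assumption for illustration:

import torch
from Graphormer import ChannelAttention

ca = ChannelAttention(in_planes=64, ratio=16)
feat = torch.rand(8, 64, 81, 1)      # (batch, channels, tokens, 1), hypothetical
gate = ca(feat)                      # (8, 64, 1, 1), channel weights in (0, 1)
print((feat * gate).shape)           # torch.Size([8, 64, 81, 1])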
369 | class GCN(nn.Module):
370 |     def __init__(self, height: int, width: int, changel: int, class_count: int):
371 |         super(GCN, self).__init__()
372 |         # number of classes, i.e. the network's final output channels
373 |         self.class_count = class_count  # number of classes
374 |         # input data size
375 |         self.channel = changel  # 200
376 |         self.height = height  # 145
377 |         self.width = width  # 145
378 |         layers_count = 4
379 |         # Superpixel-level Graph Sub-Network
380 |         self.GCN_Branch = nn.Sequential()
381 |         for i in range(layers_count):
382 |             # self.GCN_Branch.add_module('GCN_Branch' + str(i), GCNLayer(self.channel, self.channel))
383 |             if i < layers_count - 1:
384 |                 if i==0:
385 |                     self.GCN_Branch.add_module('GCN_Branch' + str(i), GCNLayer(self.channel, 128))
386 |                 else:
387 |                     self.GCN_Branch.add_module('GCN_Branch' + str(i), GCNLayer(128, 128))
388 |             else:
389 |                 self.GCN_Branch.add_module('GCN_Branch' + str(i), GCNLayer(128, 64))
390 |         # Softmax layer
391 |         self.Softmax_linear = nn.Sequential(nn.Linear(64, self.class_count))
392 |
393 |         self.ca = ChannelAttention(64)
394 |         self.neigh_C = neigh_Conv2(64,3)
395 |         self.BN = nn.BatchNorm1d(64)
396 |
397 |     def forward(self, x: torch.Tensor,A: torch.Tensor,indexs_train,band_patch):
398 |         (batch,h, w, c) = x.shape
399 |         _, in_num=indexs_train.shape
400 |
401 |         H = torch.reshape(x,(batch,h*w, c))  # 145*145*200 -> 21025*200
402 |         for i in range(len(self.GCN_Branch)):
403 |             H = self.GCN_Branch[i](H, A)
404 |
--------------------------------------------------------------------------------
/MAIN_Mamba_ip.py:
--------------------------------------------------------------------------------
136 |     if (epoch % args.test_freq == 0) | (epoch == args.epoches - 1) and epoch>=args.epoches*0.8:
137 |
138 |
139 |         tr_net.eval()
140 |         tar_v, pre_v = valid_epoch( tr_net, label_test_loader, criterion)
141 |         OA2, AA_mean2, Kappa2, AA2 = output_metric(tar_v, pre_v)
142 |         if OA2 >= best_OA2 :
143 |             best_OA2 = OA2
144 |             best_AA_mean2 = AA_mean2
145 |             best_Kappa2 = Kappa2
146 |             best_AA2 = AA2
147 |             # run_results = metrics(best_OA2, best_AA_mean2, best_Kappa2,AA2)
148 |             # show_results(
149 |             #     run_results,agregated=False)
150 |             # results.append(run_results)
151 | toc = time.time()
152 |
153 | f = open('./result/' + args.dataset + '_results.txt', 'a+')
154 |
155 | str_results = '\n\n************************************************' \
156 |               + '\nseed_value={}'.format(seed_value) \
157 |               + '\nepoch={}'.format(epoch) \
158 |               + '\nPCA_band={}'.format(args.PCA_band) \
159 |               + '\nOA={:.2f}'.format(best_OA2*100) \
160 |               + '\nAA={:.2f}'.format(best_AA_mean2*100) \
161 |               + '\nKappa={:.2f}'.format(best_Kappa2*100) \
162 |               + '\nbest_AA2=' + str(np.around(best_AA2*100, 2))
163 |
164 | f.write(str_results)
165 | f.close()
166 |
167 | print('\nbest_OA2={}'.format(best_OA2))
168 | print('\nbest_AA_mean2={}'.format(best_AA_mean2))
169 | print('\nbest_Kappa2={}'.format(best_Kappa2))
170 |
--------------------------------------------------------------------------------
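One subtlety shared by all three MAIN scripts: in the test-gating condition above, Python's bitwise | binds tighter than `and`, so `(a) | (b) and c` parses as `(a | b) and c`. Evaluation therefore runs only in the last stretch of training (here the final 20% of epochs), at the test frequency or on the final epoch. A two-line check:

a, b, c = True, False, False
print((a | b) and c)     # False - how Python parses `a | b and c`
print(a | (b and c))     # True  - NOT what the scripts compute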
parser.add_argument("--num_run", type=int, default=10) 21 | parser.add_argument('--epoches', type=int, default=200, help='epoch number') 22 | parser.add_argument('--patches', type=int, default=15, help='number of patches')#奇数#ip11*11 sa 11*11 hu 7*7 23 | parser.add_argument('--PCA_band', type=int, default=30, help='pca_components')#40 94.11 50 94.77 60 93.84 70 93.17 24 | parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate') 25 | parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay') 26 | parser.add_argument('--gamma', type=float, default=0.9, help='gamma') 27 | 28 | parser.add_argument('--gpu_id', default='0', help='gpu id') 29 | parser.add_argument('--seed', type=int, default=0, help='number of seed') 30 | parser.add_argument('--batch_size', type=int, default=128, help='number of batch size') 31 | parser.add_argument('--test_freq', type=int, default=10, help='number of evaluation') 32 | args = parser.parse_args() 33 | 34 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 35 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#选择cpu或者GPU 36 | 37 | seed_value=1 38 | np.random.seed(seed_value) 39 | random.seed(seed_value) 40 | os.environ['PYTHONHASHSEED'] = str(seed_value) 41 | 42 | torch.manual_seed(seed_value) 43 | torch.cuda.manual_seed(seed_value) 44 | torch.cuda.manual_seed_all(seed_value) 45 | 46 | torch.backends.cudnn.deterministic = True 47 | 48 | 49 | # ------------------------------------------------------------------------------- 50 | # 定位训练和测试样本 51 | # Parameter Setting 52 | np.random.seed(args.seed) 53 | torch.manual_seed(args.seed) 54 | torch.cuda.manual_seed(args.seed) 55 | cudnn.deterministic = True 56 | cudnn.benchmark = False 57 | # prepare data 58 | 59 | input, num_classes, total_pos_train, total_pos_test, total_pos_true, y_train, y_test, y_true = get_data(args.dataset) 60 | ##########得到原始图像 训练测试以及所有点坐标 每一类训练测试的个数############ 61 | ################################################################################################ 62 | # normalize data by band norm 63 | input = applyPCA(input, numComponents=args.PCA_band) 64 | ################################################################################################ 65 | input_normalize = normalize(input) 66 | height, width, band = input_normalize.shape # 145*145*200 67 | print("height={0},width={1},band={2}".format(height, width, band)) 68 | input_normalize = torch.from_numpy(input_normalize.astype(np.float32)).to(device) 69 | # ------------------------------------------------------------------------------- 70 | # obtain train and test data 71 | x_train_band, x_test_band, x_true_band, corner_train, corner_test, corner_true, center_pos_train,center_pos_test,center_pos_ture = train_and_test_data( 72 | input_normalize, band, total_pos_train, total_pos_test, total_pos_true, patch=args.patches, w=height, h=width) 73 | ##########得到训练测试以及所有点的光谱############ 74 | 75 | 76 | A_train, dgree_train,_,D2 = GET_A2(x_train_band, input_normalize, corner=corner_train,patches=args.patches , l=3,sigma=10) 77 | dis =GET_dis(args.patches,l=5) 78 | edge = get_edge_A(args.patches) 79 | 80 | y_train = torch.from_numpy(y_train).type(torch.LongTensor) # [695] 81 | Label_train = Data.TensorDataset(x_train_band, y_train, center_pos_train,A_train,dgree_train) 82 | 83 | A_test, dgree_test,D_max,D2= GET_A2(x_test_band, input_normalize, corner=corner_test, patches=args.patches ,l=3, sigma=10) 84 | 85 | y_test = torch.from_numpy(y_test).type(torch.LongTensor) # [9671] 
86 | Label_test = Data.TensorDataset( x_test_band, y_test, center_pos_test,A_test,dgree_test)
87 |
88 |
89 |
90 | label_train_loader = Data.DataLoader(Label_train, batch_size=args.batch_size, shuffle=True)
91 | ########## spectra and labels of the training set ##########
92 | label_test_loader = Data.DataLoader(Label_test, batch_size=args.batch_size, shuffle=True)
93 | # -------------------------------------------------------------------------------
94 |
95 | results = []
96 | best_AA2 = []
97 | for run in range(args.num_run):
98 |     best_OA2 = 0.0
99 |     best_AA_mean2 = 0.0
100 |     best_Kappa2 = 0.0
101 |     gcn_net = GCN(height, width, band, num_classes)
102 |     gcn_net = gcn_net.cuda()
103 |     # criterion
104 |     criterion = nn.CrossEntropyLoss().cuda()
105 |     # optimizer
106 |     optimizer = torch.optim.Adam(gcn_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
107 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epoches // 10, gamma=args.gamma)
108 |     # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epoches*0.5,args.epoches*0.75,args.epoches*0.9],gamma=0.2)  # learning rate decay
109 |     # -------------------------------------------------------------------------------
110 |
111 |     tr_net = VisionMamba(
112 |         img_size=args.patches,
113 |         depth=5,
114 |         embed_dim=64,
115 |         channels=band,
116 |         num_classes=num_classes,
117 |         rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
118 |         final_pool_type='all', if_abs_pos_embed=True, if_rope=False, if_rope_residual=True, bimamba_type="v2")
119 |     tr_net = tr_net.cuda()
120 |     # optimizer
121 |     optimizer2 = torch.optim.Adam(tr_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
122 |     # scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=args.epoches//2, gamma=args.gamma)
123 |     scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=args.epoches // 10, gamma=args.gamma)
124 |
125 |     print("start training")
126 |     tic = time.time()
127 |     for epoch in range(args.epoches):
128 |         scheduler.step()
129 |         scheduler2.step()
130 |         # train model
131 |         gcn_net.train()
132 |         tr_net.train()
133 |         train_acc, train_obj, tar_t, pre_t = train_epoch(tr_net, label_train_loader, criterion,
134 |                                                          optimizer2)
135 |         OA1, AA_mean1, Kappa1, AA1 = output_metric(tar_t, pre_t)
136 |         if (epoch % args.test_freq == 0):
137 |             print("Epoch: {:03d} train_loss: {:.4f} train_acc: {:.4f}".format(epoch + 1, train_obj, train_acc))
138 |
139 |         if (epoch % args.test_freq == 0) | (epoch == args.epoches - 1) and epoch>=args.epoches*0.9:
140 |
141 |             gcn_net.eval()
142 |             tr_net.eval()
143 |             tar_v, pre_v = valid_epoch( tr_net, label_test_loader, criterion)
144 |             OA2, AA_mean2, Kappa2, AA2 = output_metric(tar_v, pre_v)
145 |             if OA2 >= best_OA2 :
146 |                 best_OA2 = OA2
147 |                 best_AA_mean2 = AA_mean2
148 |                 best_Kappa2 = Kappa2
149 |                 best_AA2 = AA2
150 |                 # run_results = metrics(best_OA2, best_AA_mean2, best_Kappa2,AA2)
151 |                 # show_results(
152 |                 #     run_results,agregated=False)
153 |                 # results.append(run_results)
154 |     toc = time.time()
155 |
156 |     f = open('./result/' + args.dataset + '_results.txt', 'a+')
157 |
158 |     str_results = '\n\n************************************************' \
159 |                   + '\nseed_value={}'.format(seed_value) \
160 |                   + '\nrun={}'.format(run) \
161 |                   + '\nepoch={}'.format(epoch) \
162 |                   + '\nPCA_band={}'.format(args.PCA_band) \
163 |                   + '\nOA={:.2f}'.format(best_OA2*100) \
164 |                   + '\nAA={:.2f}'.format(best_AA_mean2*100) \
165 |                   + '\nKappa={:.2f}'.format(best_Kappa2*100) \
166 |                   + '\nbest_AA2=' + str(np.around(best_AA2*100, 2))
167 |
168 |
169 |     f.write(str_results)
170 |     f.close()
171 |
172 |     print('\nbest_OA2={}'.format(best_OA2))
173 |     print('\nbest_AA_mean2={}'.format(best_AA_mean2))
174 |     print('\nbest_Kappa2={}'.format(best_Kappa2))
175 |
--------------------------------------------------------------------------------
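With the Salinas defaults (epoches=200), StepLR(step_size=200 // 10, gamma=0.9) multiplies the learning rate by 0.9 every 20 epochs. A standalone trace of the schedule used above:

import torch

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.9)
for epoch in range(200):
    sched.step()                   # the scripts step once per epoch
print(opt.param_groups[0]['lr'])   # 5e-4 * 0.9**10, roughly 1.74e-4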
/MAIN_Mamba_uh2018.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import argparse
5 | import torch.nn as nn
6 | import torch.utils.data as Data
7 | import torch.backends.cudnn as cudnn
8 | from Graphormer import ViT,GCN
9 | from functions import metrics,show_results,train_and_test_data,train_epoch,valid_epoch,output_metric,applyPCA,GET_A2,get_data,normalize,GET_dis,get_edge_A
10 | import numpy as np
11 | import time
12 | import os
13 | from Mamba import VisionMamba
14 |
15 | parser = argparse.ArgumentParser("HSI")
16 | parser.add_argument('--dataset', choices=['Indian', 'PaviaU', 'Pavia', 'Salinas', 'KSC', 'Botswana', 'HoustonU', 'Houston'],
17 |                     default='HoustonU', help='dataset to use')
18 | parser.add_argument('--mode', choices=['ViT', 'CAF'], default='CAF', help='mode choice')
19 |
20 | parser.add_argument("--num_run", type=int, default=10)
21 | parser.add_argument('--epoches', type=int, default=120, help='epoch number')
22 | parser.add_argument('--patches', type=int, default=9, help='number of patches')  # odd number # ip 11*11 sa 11*11 hu 7*7
23 | parser.add_argument('--PCA_band', type=int, default=20, help='pca_components')  # ip 70 sa 70 hu 50
24 | parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate')
25 | parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay')
26 | parser.add_argument('--gamma', type=float, default=0.9, help='gamma')
27 |
28 |
29 | parser.add_argument('--seed', type=int, default=16, help='number of seed')
30 | parser.add_argument('--gpu_id', default='0', help='gpu id')
31 | parser.add_argument('--batch_size', type=int, default=128, help='number of batch size')
32 | parser.add_argument('--test_freq', type=int, default=10, help='number of evaluation')
33 | args = parser.parse_args()
34 |
35 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
36 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # choose CPU or GPU
37 |
38 | seed_value=args.seed
39 | np.random.seed(seed_value)
40 | random.seed(seed_value)
41 | os.environ['PYTHONHASHSEED'] = str(seed_value)
42 |
43 | torch.manual_seed(seed_value)
44 | torch.cuda.manual_seed(seed_value)
45 | torch.cuda.manual_seed_all(seed_value)
46 |
47 | torch.backends.cudnn.deterministic = True
48 |
49 |
50 | # -------------------------------------------------------------------------------
51 | # locate the training and test samples
52 | # Parameter Setting
53 | np.random.seed(args.seed)
54 | torch.manual_seed(args.seed)
55 | torch.cuda.manual_seed(args.seed)
56 | cudnn.deterministic = True
57 | cudnn.benchmark = False
58 | # prepare data
59 |
60 | input, num_classes, total_pos_train, total_pos_test, total_pos_true, y_train, y_test, y_true = get_data(args.dataset)
61 | ########## get the raw image, train/test/all-pixel coordinates, and per-class sample counts ############
62 | ################################################################################################
63 | # normalize data by band norm
64 | input = applyPCA(input, numComponents=args.PCA_band)
65 | ################################################################################################
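applyPCA comes from functions.py (past the end of this excerpt); it reduces the HSI cube to PCA_band components. A standard standalone version of that step, where the whiten flag is an assumption:

import numpy as np
from sklearn.decomposition import PCA

def apply_pca(cube: np.ndarray, num_components: int) -> np.ndarray:
    # cube: (H, W, bands) -> (H, W, num_components)
    h, w, b = cube.shape
    flat = cube.reshape(-1, b)
    reduced = PCA(n_components=num_components, whiten=True).fit_transform(flat)
    return reduced.reshape(h, w, num_components)

print(apply_pca(np.random.rand(10, 10, 48), 20).shape)   # (10, 10, 20)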
66 | input_normalize = normalize(input)
67 | height, width, band = input_normalize.shape  # 145*145*200
68 | print("height={0},width={1},band={2}".format(height, width, band))
69 | input_normalize = torch.from_numpy(input_normalize.astype(np.float32)).to(device)
70 | # -------------------------------------------------------------------------------
71 | # obtain train and test data
72 | x_train_band, x_test_band, x_true_band, corner_train, corner_test, corner_true, center_pos_train,center_pos_test,center_pos_ture = train_and_test_data(
73 |     input_normalize, band, total_pos_train, total_pos_test, total_pos_true, patch=args.patches, w=height, h=width)
74 | ########## get the spectra of the training, test, and all pixels ############
75 |
76 |
77 | A_train, dgree_train,_,D2 = GET_A2(x_train_band, input_normalize, corner=corner_train,patches=args.patches , l=3,sigma=10)
78 | dis =GET_dis(args.patches,l=5)
79 | edge = get_edge_A(args.patches)
80 |
81 | y_train = torch.from_numpy(y_train).type(torch.LongTensor)  # [695]
82 | Label_train = Data.TensorDataset(x_train_band, y_train, center_pos_train,A_train,dgree_train)
83 |
84 | A_test, dgree_test,D_max,D2= GET_A2(x_test_band, input_normalize, corner=corner_test, patches=args.patches ,l=3, sigma=10)
85 |
86 | y_test = torch.from_numpy(y_test).type(torch.LongTensor)  # [9671]
87 | Label_test = Data.TensorDataset( x_test_band, y_test, center_pos_test,A_test,dgree_test)
88 |
89 |
90 |
91 | label_train_loader = Data.DataLoader(Label_train, batch_size=args.batch_size, shuffle=True)
92 | ########## spectra and labels of the training set ##########
93 | label_test_loader = Data.DataLoader(Label_test, batch_size=args.batch_size, shuffle=True)
94 | # -------------------------------------------------------------------------------
95 |
96 | results = []
97 | best_AA2 = []
98 | for run in range(args.num_run):
99 |     best_OA2 = 0.0
100 |     best_AA_mean2 = 0.0
101 |     best_Kappa2 = 0.0
102 |     gcn_net = GCN(height, width, band, num_classes)
103 |     gcn_net = gcn_net.cuda()
104 |     # criterion
105 |     criterion = nn.CrossEntropyLoss().cuda()
106 |     # optimizer
107 |     optimizer = torch.optim.Adam(gcn_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
108 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epoches // 10, gamma=args.gamma)
109 |     # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epoches*0.5,args.epoches*0.75,args.epoches*0.9],gamma=0.2)  # learning rate decay
110 |     # -------------------------------------------------------------------------------
111 |
112 |     tr_net = VisionMamba(
113 |         img_size=args.patches,
114 |         depth=3,
115 |         embed_dim=64,
116 |         channels=band,
117 |         num_classes=num_classes,
118 |         rms_norm=True, residual_in_fp32=True, fused_add_norm=True,
119 |         final_pool_type='all', if_abs_pos_embed=True, if_rope=False, if_rope_residual=True, bimamba_type="v2")
120 |     tr_net = tr_net.cuda()
121 |     # optimizer
122 |     optimizer2 = torch.optim.Adam(tr_net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
123 |     # scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=args.epoches//2, gamma=args.gamma)
124 |     scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=args.epoches // 10, gamma=args.gamma)
125 |
126 |     print("start training")
127 |     tic = time.time()
128 |     for epoch in range(args.epoches):
129 |         scheduler.step()
130 |         scheduler2.step()
131 |         # train model
132 |         gcn_net.train()
133 |         tr_net.train()
134 |         train_acc, train_obj, tar_t, pre_t = train_epoch(tr_net, label_train_loader, criterion,
135 |                                                          optimizer2)
136 |         OA1, AA_mean1, Kappa1, AA1 =
output_metric(tar_t, pre_t) 137 | if (epoch % args.test_freq == 0) | (epoch == args.epoches - 1): 138 | print("Epoch: {:03d} train_loss: {:.4f} train_acc: {:.4f}".format(epoch + 1, train_obj, train_acc)) 139 | 140 | if (epoch % args.test_freq == 0) | (epoch == args.epoches - 1)and epoch>=args.epoches*0.8: 141 | 142 | gcn_net.eval() 143 | tr_net.eval() 144 | tar_v, pre_v = valid_epoch( tr_net, label_test_loader, criterion) 145 | OA2, AA_mean2, Kappa2, AA2 = output_metric(tar_v, pre_v) 146 | if OA2 >= best_OA2 : 147 | best_OA2 = OA2 148 | best_AA_mean2 = AA_mean2 149 | best_Kappa2 = Kappa2 150 | best_AA2 = AA2 151 | # run_results = metrics(best_OA2, best_AA_mean2, best_Kappa2,AA2) 152 | # show_results( 153 | # run_results,agregated=False) 154 | # results.append(run_results) 155 | toc = time.time() 156 | 157 | f = open('./result/' + args.dataset + '_results.txt', 'a+') 158 | 159 | str_results = '\n\n************************************************' \ 160 | + '\nseed_value={}'.format(seed_value) \ 161 | + '\nepoch={}'.format(epoch) \ 162 | + '\nPCA_band={}'.format(args.PCA_band) \ 163 | + '\nOA={:.2f}'.format(best_OA2*100) \ 164 | + '\nAA={:.2f}'.format(best_AA_mean2*100) \ 165 | + '\nKappa={:.2f}'.format(best_Kappa2*100) \ 166 | + '\nbest_AA2=' + str(np.around(best_AA2*100, 2)) 167 | 168 | 169 | f.write(str_results) 170 | f.close() 171 | 172 | print('\nbest_OA2={}'.format(best_OA2)) 173 | print('\nbest_AA_mean2={}'.format(best_AA_mean2)) 174 | print('\nbest_Kappa2={}'.format(best_Kappa2)) 175 | 176 | if args.num_run > 1: 177 | show_results(results, agregated=True) 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /Mamba.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | from torch import Tensor 7 | from typing import Optional 8 | 9 | from timm.models.vision_transformer import VisionTransformer, _cfg 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_ 12 | 13 | from timm.models.layers import DropPath, PatchEmbed 14 | from timm.models.vision_transformer import _load_weights 15 | 16 | import math 17 | from GCN import GCN 18 | 19 | from collections import namedtuple 20 | 21 | from mamba_ssm.modules.mamba_simple import Mamba 22 | from mamba_ssm.utils.generation import GenerationMixin 23 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 24 | 25 | from rope import * 26 | import random 27 | 28 | try: 29 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 30 | except ImportError: 31 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 32 | 33 | 34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 35 | 36 | 37 | class Block(nn.Module): 38 | def __init__( 39 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0., 40 | ): 41 | """ 42 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 43 | 44 | This Block has a slightly different structure compared to a regular 45 | prenorm Transformer block. 46 | The standard block is: LN -> MHA/MLP -> Add. 47 | [Ref: https://arxiv.org/abs/2002.04745] 48 | Here we have: Add -> LN -> Mixer, returning both 49 | the hidden_states (output of the mixer) and the residual. 
50 | This is purely for performance reasons, as we can fuse add and LayerNorm. 51 | The residual needs to be provided (except for the very first block). 52 | """ 53 | super().__init__() 54 | self.residual_in_fp32 = residual_in_fp32 55 | self.fused_add_norm = fused_add_norm 56 | self.mixer = mixer_cls(dim) 57 | self.norm = norm_cls(dim) 58 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 59 | if self.fused_add_norm: 60 | assert RMSNorm is not None, "RMSNorm import fails" 61 | assert isinstance( 62 | self.norm, (nn.LayerNorm, RMSNorm) 63 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 64 | 65 | def forward( 66 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 67 | ): 68 | r"""Pass the input through the encoder layer. 69 | 70 | Args: 71 | hidden_states: the sequence to the encoder layer (required). 72 | residual: hidden_states = Mixer(LN(residual)) 73 | """ 74 | if not self.fused_add_norm: 75 | if residual is None: 76 | residual = hidden_states 77 | else: 78 | residual = residual + self.drop_path(hidden_states) 79 | 80 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 81 | if self.residual_in_fp32: 82 | residual = residual.to(torch.float32) 83 | else: 84 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 85 | if residual is None: 86 | hidden_states, residual = fused_add_norm_fn( 87 | hidden_states, 88 | self.norm.weight, 89 | self.norm.bias, 90 | residual=residual, 91 | prenorm=True, 92 | residual_in_fp32=self.residual_in_fp32, 93 | eps=self.norm.eps, 94 | ) 95 | else: 96 | hidden_states, residual = fused_add_norm_fn( 97 | self.drop_path(hidden_states), 98 | self.norm.weight, 99 | self.norm.bias, 100 | residual=residual, 101 | prenorm=True, 102 | residual_in_fp32=self.residual_in_fp32, 103 | eps=self.norm.eps, 104 | ) 105 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 106 | return hidden_states, residual 107 | 108 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 109 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 110 | 111 | 112 | def create_block( 113 | d_model, 114 | ssm_cfg=None, 115 | norm_epsilon=1e-5, 116 | drop_path=0., 117 | rms_norm=False, 118 | residual_in_fp32=False, 119 | fused_add_norm=False, 120 | layer_idx=None, 121 | device=None, 122 | dtype=None, 123 | bimamba_type="none", 124 | ): 125 | if ssm_cfg is None: 126 | ssm_cfg = {} 127 | factory_kwargs = {"device": device, "dtype": dtype} 128 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 129 | norm_cls = partial( 130 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 131 | ) 132 | block = Block( 133 | d_model, 134 | mixer_cls, 135 | norm_cls=norm_cls, 136 | drop_path=drop_path, 137 | fused_add_norm=fused_add_norm, 138 | residual_in_fp32=residual_in_fp32, 139 | ) 140 | block.layer_idx = layer_idx 141 | return block 142 | 143 | 144 | 145 | 146 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 147 | def _init_weights( 148 | module, 149 | n_layer, 150 | initializer_range=0.02, # Now only used for embedding layer. 
151 | rescale_prenorm_residual=True, 152 | n_residuals_per_layer=1, # Change to 2 if we have MLP 153 | ): 154 | if isinstance(module, nn.Linear): 155 | if module.bias is not None: 156 | if not getattr(module.bias, "_no_reinit", False): 157 | nn.init.zeros_(module.bias) 158 | elif isinstance(module, nn.Embedding): 159 | nn.init.normal_(module.weight, std=initializer_range) 160 | 161 | if rescale_prenorm_residual: 162 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 163 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 164 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 165 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 166 | # 167 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 168 | for name, p in module.named_parameters(): 169 | if name in ["out_proj.weight", "fc2.weight"]: 170 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 171 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 172 | # We need to reinit p since this code could be called multiple times 173 | # Having just p *= scale would repeatedly scale it down 174 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 175 | with torch.no_grad(): 176 | p /= math.sqrt(n_residuals_per_layer * n_layer) 177 | 178 | 179 | def segm_init_weights(m): 180 | if isinstance(m, nn.Linear): 181 | trunc_normal_(m.weight, std=0.02) 182 | if isinstance(m, nn.Linear) and m.bias is not None: 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.LayerNorm): 185 | nn.init.constant_(m.bias, 0) 186 | nn.init.constant_(m.weight, 1.0) 187 | 188 | class VisionMamba(nn.Module): 189 | def __init__(self, 190 | img_size=224, 191 | depth=24, 192 | embed_dim=192, 193 | channels=3, 194 | num_classes=1000, 195 | ssm_cfg=None, 196 | drop_rate=0., 197 | drop_path_rate=0.1, 198 | norm_epsilon: float = 1e-5, 199 | rms_norm: bool = False, 200 | initializer_cfg=None, 201 | fused_add_norm=False, 202 | residual_in_fp32=False, 203 | device=None, 204 | dtype=None, 205 | ft_seq_len=None, 206 | pt_hw_seq_len=14, 207 | final_pool_type='none', 208 | if_abs_pos_embed=False, 209 | if_rope=False, 210 | if_rope_residual=False, 211 | bimamba_type="none", 212 | if_cls_token=False, 213 | **kwargs): 214 | factory_kwargs = {"device": device, "dtype": dtype} 215 | # add factory_kwargs into kwargs 216 | kwargs.update(factory_kwargs) 217 | super().__init__() 218 | self.residual_in_fp32 = residual_in_fp32 219 | self.fused_add_norm = fused_add_norm 220 | self.final_pool_type = final_pool_type 221 | self.if_abs_pos_embed = if_abs_pos_embed 222 | self.if_rope = if_rope 223 | self.if_rope_residual = if_rope_residual 224 | self.if_cls_token = if_cls_token 225 | self.num_tokens = 1 if if_cls_token else 0 226 | 227 | # pretrain parameters 228 | self.num_classes = num_classes 229 | self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 230 | 231 | # self.patch_embed = PatchEmbed( 232 | # img_size=img_size, patch_size=patch_size, in_chans=channels, embed_dim=embed_dim) 233 | # num_patches = self.patch_embed.num_patches 234 | 235 | self.patch_to_embedding = nn.Linear(channels, embed_dim) 236 | num_patches = img_size*img_size 237 | 238 | if if_cls_token: 239 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 240 | 241 | if if_abs_pos_embed: 242 | # self.pos_embed 
= nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim)) 243 | self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, self.embed_dim)) # 1,201,64 244 | self.pos_drop = nn.Dropout(p=drop_rate) 245 | 246 | if if_rope: 247 | half_head_dim = embed_dim // 2 248 | hw_seq_len = img_size 249 | self.rope = VisionRotaryEmbeddingFast( 250 | dim=half_head_dim, 251 | pt_seq_len=pt_hw_seq_len, 252 | ft_seq_len=hw_seq_len 253 | ) 254 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 255 | 256 | 257 | # TODO: release this comment 258 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 259 | # import ipdb;ipdb.set_trace() 260 | inter_dpr = [0.0] + dpr 261 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 262 | # transformer blocks 263 | self.layers = nn.ModuleList( 264 | [ 265 | create_block( 266 | embed_dim, 267 | ssm_cfg=ssm_cfg, 268 | norm_epsilon=norm_epsilon, 269 | rms_norm=rms_norm, 270 | residual_in_fp32=residual_in_fp32, 271 | fused_add_norm=fused_add_norm, 272 | layer_idx=i, 273 | bimamba_type=bimamba_type, 274 | drop_path=inter_dpr[i], 275 | **factory_kwargs, 276 | ) 277 | for i in range(depth) 278 | ] 279 | ) 280 | self.layer_GCN = nn.Sequential() 281 | 282 | for i in range(depth): 283 | self.layer_GCN.add_module('GCN_Branch' + str(i), GCN(height=img_size, width=img_size, changel= embed_dim, layers_count=3)) 284 | 285 | 286 | # output head 287 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 288 | embed_dim, eps=norm_epsilon, **factory_kwargs 289 | ) 290 | 291 | self.pre_logits = nn.Identity() 292 | 293 | # original init 294 | self.apply(segm_init_weights) 295 | self.head.apply(segm_init_weights) 296 | if if_abs_pos_embed: 297 | trunc_normal_(self.pos_embed, std=.02) 298 | 299 | # mamba init 300 | self.apply( 301 | partial( 302 | _init_weights, 303 | n_layer=depth, 304 | **(initializer_cfg if initializer_cfg is not None else {}), 305 | ) 306 | ) 307 | 308 | self.skipcat = nn.ModuleList([]) 309 | for _ in range(depth-2): 310 | self.skipcat.append(nn.Conv2d(img_size*img_size, img_size*img_size, [1, 2], 1, 0)) 311 | 312 | 313 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 314 | return { 315 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 316 | for i, layer in enumerate(self.layers) 317 | } 318 | 319 | 320 | 321 | def forward_features(self, x, batch_A, inference_params=None): 322 | 323 | # x = self.patch_embed(x)#(64,196,192) (64,3,224,224)-(64,192,14,14)-(64,196,192) (B,H*W,C) 324 | 325 | x = self.patch_to_embedding(x) 326 | 327 | # if self.if_abs_pos_embed: 328 | # x = x + self.pos_embed 329 | # x = self.pos_drop(x) 330 | 331 | 332 | # mamba impl 333 | residual = None 334 | hidden_states = x 335 | 336 | last_output = [] 337 | nl = 0 338 | 339 | # for layer in self.layers: 340 | # last_output.append(hidden_states) 341 | # if nl > 1: 342 | # hidden_states = self.skipcat[nl-2](torch.cat([hidden_states.unsqueeze(3), last_output[nl-2].unsqueeze(3)], dim=3)).squeeze(3) 343 | # hidden_states, residual = layer( 344 | # hidden_states, residual, inference_params=inference_params 345 | # ) 346 | # nl += 1 347 | 348 | for (Mamba,GCN) in zip(self.layers,self.layer_GCN): 349 | last_output.append(hidden_states) 350 | if nl > 1: 351 | hidden_states = self.skipcat[nl-2](torch.cat([hidden_states.unsqueeze(3), last_output[nl-2].unsqueeze(3)], dim=3)).squeeze(3) 352 | 
hidden_states, residual = Mamba(
353 |                 hidden_states, residual, inference_params=inference_params
354 |             )
355 |             hidden_states = GCN(hidden_states, batch_A)
356 |
357 |             nl += 1
358 |
359 |
360 |         if not self.fused_add_norm:
361 |             if residual is None:
362 |                 residual = hidden_states
363 |             else:
364 |                 residual = residual + self.drop_path(hidden_states)
365 |             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
366 |         else:
367 |             # Set prenorm=False here since we don't need the residual
368 |             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
369 |             hidden_states = fused_add_norm_fn(
370 |                 self.drop_path(hidden_states),
371 |                 self.norm_f.weight,
372 |                 self.norm_f.bias,
373 |                 eps=self.norm_f.eps,
374 |                 residual=residual,
375 |                 prenorm=False,
376 |                 residual_in_fp32=self.residual_in_fp32,
377 |             )
378 |
379 |         return hidden_states
380 |
381 |
382 |     def forward(self, x, center_pos, batch_A, return_features=False, inference_params=None):
383 |         # x=x.permute(0, 3, 1, 2)
384 |         x = self.forward_features(x, batch_A, inference_params)  # (64,3,224,224)-(64,192)
385 |         if return_features:
386 |             return x
387 |         x = self.head(x)  # (64,10)
388 |
389 |         batch, _, _ = x.shape
390 |         x_out=torch.zeros((batch, self.num_classes),dtype=float).to(device)
391 |         for i in range(batch):
392 |             x_out[i]=x[i,center_pos[i],:]
393 |
394 |         return x_out
395 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GraphMamba: An Efficient Graph Structure Learning Vision Mamba for Hyperspectral Image Classification
2 |
3 | Aitao Yang; Min Li; Yao Ding; Leyuan Fang; Yaoming Cai; Yujie He
4 |
5 | ___________
6 |
7 | The code in this toolbox implements ["GraphMamba: An Efficient Graph Structure Learning Vision Mamba for Hyperspectral Image Classification"](https://ieeexplore.ieee.org/document/10746459).
8 |
9 |
10 |
11 | Citation
12 | ---------------------
13 |
14 | **Please kindly cite the paper if this code is useful and helpful for your research.**
15 |
16 | A. Yang, M. Li, Y. Ding, L. Fang, Y. Cai and Y. He, "GraphMamba: An Efficient Graph Structure Learning Vision Mamba for Hyperspectral Image Classification," in IEEE Transactions on Geoscience and Remote Sensing, vol. 62, pp. 1-14, 2024, Art no. 5537414, doi: 10.1109/TGRS.2024.3493101.
17 |
18 | @ARTICLE{10746459,
19 |   author={Yang, Aitao and Li, Min and Ding, Yao and Fang, Leyuan and Cai, Yaoming and He, Yujie},
20 |   journal={IEEE Transactions on Geoscience and Remote Sensing},
21 |   title={GraphMamba: An Efficient Graph Structure Learning Vision Mamba for Hyperspectral Image Classification},
22 |   year={2024},
23 |   volume={62},
24 |   number={},
25 |   pages={1-14},
26 |   keywords={Feature extraction;Transformers;Semantics;Data mining;Vectors;Hyperspectral imaging;Computational efficiency;Training;Kernel;Encoding;Graph convolutional network (GCN);hyperspectral image (HSI) classification;mamba;remote sensing;state space model (SSM)},
27 |   doi={10.1109/TGRS.2024.3493101}}
28 |
29 |
30 | System-specific notes
31 | ---------------------
32 | The networks were tested with PyTorch 2.1.1 (CUDA 11.8) and Python 3.8 on Ubuntu.
33 |
34 | How to use it?
35 | ---------------------
36 | Directly run the **MAIN_Mamba_*.py** scripts with different network parameter settings to produce the results. Please note that, owing to the randomness of parameter initialization, the experimental results might differ slightly from those reported in the paper.
37 |
38 | For the datasets:
39 | Add your dataset path to the function `load_dataset` in functions.py
40 |
41 | On the Indian Pines dataset, you can re-train by running:
42 | `python MAIN_Mamba_ip.py`
43 |
44 | On the Salinas dataset, you can re-train by running:
45 | `python MAIN_Mamba_sa.py`
46 |
47 | On the UH2018 dataset, you can re-train by running:
48 | `python MAIN_Mamba_uh2018.py`
49 |
50 |
51 |
--------------------------------------------------------------------------------
/code_acc.py:
--------------------------------------------------------------------------------
1 | parser = argparse.ArgumentParser("HSI")
2 | parser.add_argument('--dataset', choices=['Indian', 'PaviaU', 'Pavia', 'Salinas', 'KSC', 'Botswana', 'Houston'],
3 |                     default='Indian', help='dataset to use')
4 | parser.add_argument('--flag_test', choices=['test', 'train'], default='train', help='testing mark')
5 | parser.add_argument('--mode', choices=['ViT', 'CAF'], default='CAF', help='mode choice')
6 |
7 | parser.add_argument('--epoches', type=int, default=200, help='epoch number')
8 | parser.add_argument('--patches', type=int, default=9, help='number of patches')  # odd number
9 | parser.add_argument('--band_patches', type=int, default=3, help='number of related band')  # odd number
10 | parser.add_argument('--n_gcn', type=int, default=21, help='number of related pix')
11 | parser.add_argument('--pca_band', type=int, default=70, help='pca_components')
12 |
13 | parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate')
14 | parser.add_argument('--gamma', type=float, default=0.9, help='gamma')
15 | parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay')
16 | parser.add_argument('--gpu_id', default='0', help='gpu id')
17 | parser.add_argument('--seed', type=int, default=0, help='number of seed')
18 | parser.add_argument('--batch_size', type=int, default=64, help='number of batch size')
19 | parser.add_argument('--test_freq', type=int, default=10, help='number of evaluation')
20 | # Final result:
21 | # OA: 0.9287 | AA: 0.9552 | Kappa: 0.9186
22 | # [0.96774194 0.86051502 0.90375    0.98550725 0.92715232 0.97714286
23 | #  1.         1.         1.         0.93524416 0.87092784 0.93428064
24 | #  1.         0.99271255 0.96067416 0.96825397]
25 | # **************************************************
26 |
27 |
28 | parser.add_argument('--epoches', type=int, default=300, help='epoch number')
29 | parser.add_argument('--patches', type=int, default=11, help='number of patches')  # odd number
30 | parser.add_argument('--band_patches', type=int, default=3, help='number of related band')  # odd number
31 | parser.add_argument('--n_gcn', type=int, default=21, help='number of related pix')
32 | parser.add_argument('--pca_band', type=int, default=70, help='pca_components')
33 | parser.add_argument('--weight_decay', type=float, default=5e-3, help='weight_decay')
34 |
35 | # **************************************************
36 | # Final result:
37 | # OA: 0.9314 | AA: 0.9551 | Kappa: 0.9217
38 | # [0.96774194 0.91917024 0.91375    0.97101449 0.92935982 0.96428571
39 | #  0.92307692 1.         1.         0.89278132 0.85814433 0.88987567
40 | #  0.97714286 0.99109312 0.99719101 0.98412698]
41 | # **************************************************
42 |
43 |
44 | parser.add_argument('--epoches', type=int, default=300, help='epoch number')
45 | parser.add_argument('--patches', type=int, default=9, help='number of patches')  # odd number
46 | parser.add_argument('--band_patches', type=int, default=3, help='number of related band')  # odd number
47 | parser.add_argument('--n_gcn', type=int, default=21, help='number of related pix')
48 | parser.add_argument('--pca_band', type=int, default=70, help='pca_components')
49 | parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay')
50 | # **************************************************
51 | # Final result:
52 | # OA: 0.9379 | AA: 0.9652 | Kappa: 0.9291
53 | # [1.         0.88841202 0.90375    1.         0.93818985 0.97571429
54 | #  1.         0.99776786 1.         0.93949045 0.9142268  0.89165187
55 | #  1.         0.99757085 0.99719101 1.        ]
56 | # **************************************************
57 |
58 |
59 |
60 | parser.add_argument('--epoches', type=int, default=300, help='epoch number')
61 | parser.add_argument('--patches', type=int, default=9, help='number of patches')  # odd number
62 | parser.add_argument('--band_patches', type=int, default=3, help='number of related band')  # odd number
63 | parser.add_argument('--n_gcn', type=int, default=21, help='number of related pix')
64 | parser.add_argument('--pca_band', type=int, default=70, help='pca_components')
65 | parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay')
66 |
67 | ####### GCN with four layers ##############
68 |
69 | # **************************************************
70 | # Final result:
71 | # OA: 0.9416 | AA: 0.9651 | Kappa: 0.9333
72 | # [1.         0.83690987 0.9075     0.99516908 0.92273731 0.97285714
73 | #  0.84615385 1.         1.         0.9596603  0.88412371 0.90586146
--------------------------------------------------------------------------------
/functions.py:
--------------------------------------------------------------------------------
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
import numpy as np
import scipy.io as sio
import torch
import math
from sklearn import preprocessing
import h5py


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # choose CPU or GPU

def chooose_train_and_test_point(train_data, test_data, true_data, num_classes):
    number_train = []
    pos_train = {}
    number_test = []
    pos_test = {}
    number_true = []
    pos_true = {}
    #-------------------------for train data------------------------------------
    for i in range(num_classes):
        each_class = np.argwhere(train_data==(i+1))
        number_train.append(each_class.shape[0])
        pos_train[i] = each_class

    total_pos_train = pos_train[0]
    for i in range(1, num_classes):
        total_pos_train = np.r_[total_pos_train, pos_train[i]]  # (695,2)
    total_pos_train = total_pos_train.astype(int)
    #--------------------------for test data------------------------------------
    for i in range(num_classes):
        each_class = np.argwhere(test_data==(i+1))
        number_test.append(each_class.shape[0])
        pos_test[i] = each_class

    total_pos_test = pos_test[0]
    for i in range(1, num_classes):
        total_pos_test = np.r_[total_pos_test, pos_test[i]]  # (9671,2)
    total_pos_test = total_pos_test.astype(int)
    #--------------------------for true data------------------------------------
    for i in range(num_classes+1):
        each_class = np.argwhere(true_data==i)
        number_true.append(each_class.shape[0])
        pos_true[i] = each_class

    total_pos_true = pos_true[0]
    for i in range(1, num_classes+1):
        total_pos_true = np.r_[total_pos_true, pos_true[i]]
    total_pos_true = total_pos_true.astype(int)

    return total_pos_train, total_pos_test, total_pos_true, number_train, number_test, number_true
#-------------------------------------------------------------------------------
# boundary extension via mirroring
def mirror_hsi(height, width, band, input_normalize, patch=5):
    padding = patch//2
    mirror_hsi = np.zeros((height+2*padding, width+2*padding, band), dtype=float)
    # center region
    mirror_hsi[padding:(padding+height), padding:(padding+width), :] = input_normalize
    # mirror the left edge
    for i in range(padding):
        mirror_hsi[padding:(height+padding), i, :] = input_normalize[:, padding-i-1, :]
    # mirror the right edge
    for i in range(padding):
        mirror_hsi[padding:(height+padding), width+padding+i, :] = input_normalize[:, width-1-i, :]
    # mirror the top edge
    for i in range(padding):
        mirror_hsi[i, :, :] = mirror_hsi[padding*2-i-1, :, :]
    # mirror the bottom edge
    for i in range(padding):
        mirror_hsi[height+padding+i, :, :] = mirror_hsi[height+padding-1-i, :, :]

    print("**************************************************")
    print("patch is : {}".format(patch))
    print("mirror_image shape : [{0},{1},{2}]".format(mirror_hsi.shape[0], mirror_hsi.shape[1], mirror_hsi.shape[2]))
    print("**************************************************")
    return mirror_hsi
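# Quick shape check (hypothetical toy input): with patch=5 the image is
# padded by patch//2 = 2 pixels on each side, so a 10x10x4 cube becomes
# 14x14x4:
#   toy = np.random.rand(10, 10, 4)
#   assert mirror_hsi(10, 10, 4, toy, patch=5).shape == (14, 14, 4)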
#-------------------------------------------------------------------------------
# sort by spectral distance and take the top indices
def choose_top(image, cornor_index, x, y, patch, b, n_top):
    sort = image.reshape(patch * patch, b)
    sort = torch.from_numpy(sort).type(torch.FloatTensor)
    pos = (x - cornor_index[0]) * patch + (y - cornor_index[1])
    Q = torch.sum(torch.pow(sort[pos] - sort, 2), dim=1)  # squared spectral distance to the center pixel
    _, indices = Q.topk(k=n_top, dim=0, largest=False, sorted=True)
    return indices
#-------------------------------------------------------------------------------
# extract the patch of image data around a sample point
def gain_neighborhood_pixel(pca_image, point, i, patch, W, H):
    x = point[i, 0]
    y = point[i, 1]
    m = int((patch-1)/2)  # patch must be odd
    _, _, b = pca_image.shape
    if x <= m:
        if y <= m:
            temp_image = pca_image[0:patch, 0:patch, :]
            cornor_index = [0, 0]
        if y >= (H-m):
            temp_image = pca_image[0:patch, H-patch:H, :]
            cornor_index = [0, H-patch]
        if y > m and y < (H-m):
            temp_image = pca_image[0:patch, y-m:y+m+1, :]
            cornor_index = [0, y-m]
    if x >= (W-m):
        if y <= m:
            temp_image = pca_image[W-patch:W, 0:patch, :]
            cornor_index = [W-patch, 0]
        if y >= (H-m):
            temp_image = pca_image[W-patch:W, H-patch:H, :]
            cornor_index = [W-patch, H-patch]
        if y > m and y < (H-m):
            temp_image = pca_image[W-patch:W, y-m:y+m+1, :]
            cornor_index = [W-patch, y-m]
    if x > m and x < (W-m):
        if y <= m:
            temp_image = pca_image[x-m:x+m+1, 0:patch, :]
            cornor_index = [x-m, 0]
        if y >= (H-m):
            temp_image = pca_image[x-m:x+m+1, H-patch:H, :]
            cornor_index = [x-m, H-patch]
        if y > m and y < (H-m):
            temp_image = pca_image[x-m:x+m+1, y-m:y+m+1, :]
            cornor_index = [x-m, y-m]
    return temp_image, cornor_index

# [...] a long stretch of functions.py is missing from this dump; only the
# tail of a_star_search survives below. Its signature is inferred from the
# call in path_search; the setup of heuristic, delta, cost, cell,
# close_matrix, and action_matrix falls inside the missing span.
def a_star_search(grid, begin_point, target_point):
            # (fragment resumes inside the neighbor-expansion loop)
            if x2 >= 0 and x2 < len(grid) and y2 >= 0 and y2 < len(grid[0]):  # check whether that cell can be passed
                if close_matrix[x2][y2] == 0 and grid[x2][y2] == 0:
                    g2 = g + cost
                    f2 = g2 + heuristic[x2][y2]
                    cell.append([f2, g2, x2, y2])
                    close_matrix[x2][y2] = 1
                    action_matrix[x2][y2] = i
    invpath = []
    x = target_point[0]
    y = target_point[1]
    invpath.append([x, y])  # we get the reverse path from here
    while x != begin_point[0] or y != begin_point[1]:
        x2 = x - delta[action_matrix[x][y]][0]
        y2 = y - delta[action_matrix[x][y]][1]
        x = x2
        y = y2
        invpath.append([x, y])

    path = []
    for i in range(len(invpath)):
        path.append(invpath[len(invpath) - 1 - i])
    return path, action_matrix

def path_search(grid, begin, target):
    a_star_path, action_matrix = a_star_search(grid, begin, target)
    a_star_path = np.array(a_star_path)
    edge_number = np.zeros(4)
    # count how many steps of the A* path go in each of the four directions
    for i in range(a_star_path.shape[0]-1):
        x = a_star_path[i+1, 0] - a_star_path[i, 0]
        y = a_star_path[i+1, 1] - a_star_path[i, 1]
        if x == -1 and y == 0:
            edge_number[0] += 1
        if x == 0 and y == -1:
            edge_number[1] += 1
        if x == 1 and y == 0:
            edge_number[2] += 1
        if x == 0 and y == 1:
            edge_number[3] += 1
    return edge_number

def GET_dis(patches, l):  # l is the neighborhood range
    dis = torch.zeros((patches*patches, patches*patches), dtype=torch.int64)
    h = patches
    w = patches
    center = (l - 1) / 2
    for i in range(h):          # image rows (h rows, w columns)
        for j in range(w):      # image columns
            m = int(i * w + j)  # row index in the adjacency matrix
            for k in range(l):      # neighborhood row offset
                for q in range(l):  # neighborhood column offset
                    n = int((i + (k - (l - 1) / 2)) * w + (j + (q - (l - 1) / 2)))  # locate the neighbor and convert it to a column index in the adjacency matrix
                    if 0 <= i + (k - (l - 1) / 2) < h and 0 <= (j + (q - (l - 1) / 2)) < w:
                        if abs(k-center) == 0 and abs(q-center) == 0:
                            dis[m, n] = 1
                        if (abs(k-center) == 0 and abs(q-center) == 1) or (abs(k-center) == 1 and abs(q-center) == 0):
                            dis[m, n] = 2
                        if abs(k-center) == 1 and abs(q-center) == 1:
                            dis[m, n] = 3
                        if (abs(k-center) == 0 and abs(q-center) == 2) or (abs(k-center) == 2 and abs(q-center) == 0):
                            dis[m, n] = 4
                        if (abs(k-center) == 2 and abs(q-center) == 1) or (abs(k-center) == 1 and abs(q-center) == 2):
                            dis[m, n] = 5
                        if abs(k-center) == 2 and abs(q-center) == 2:
                            dis[m, n] = 6
    return dis.cuda()
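# Usage sketch (assumed sizes): for a 9x9 patch with a 5x5 neighborhood,
# GET_dis(9, 5) returns an 81x81 CUDA tensor whose entries code the ring
# distance between pixel pairs (1 = self, 2 = direct 4-neighbors, ...,
# 6 = diagonal corners of the 5x5 window, 0 = outside the window):
#   dis = GET_dis(9, 5)   # requires a CUDA device
#   print(dis.shape)      # torch.Size([81, 81])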
def get_edge_A(patch):
    edge_A = np.zeros((patch*patch, patch*patch, 4))
    grid = [[0 for j in range(patch)] for i in range(patch)]
    begin = [0, 0]
    target = [0, 0]
    for i in range(patch*patch):
        for j in range(patch*patch):
            begin[0] = i // patch
            begin[1] = i % patch
            target[0] = j // patch
            target[1] = j % patch
            edge_A[i, j, :] = path_search(grid, begin, target)
    edge_A = torch.from_numpy(edge_A).cuda()
    edge_A = torch.reshape(edge_A, (patch*patch*patch*patch, 4))
    # normalize each pair's direction counts into a distribution over the four directions
    M = torch.sum(edge_A, dim=1)
    for i in range(patch*patch*patch*patch):
        if M[i] != 0:
            edge_A[i, :] = edge_A[i, :] / M[i]
    edge_A = torch.reshape(edge_A, (patch*patch, patch*patch, 4))
    return edge_A.type(torch.float32)
--------------------------------------------------------------------------------
/rope.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# EVA-02: A Visual Representation for Neon Genesis
# Github source: https://github.com/baaivision/EVA/EVA02
# Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Yuxin Fang
#
# Based on https://github.com/lucidrains/rotary-embedding-torch
# --------------------------------------------------------

from math import pi

import torch
from torch import nn

from einops import rearrange, repeat



def broadcat(tensors, dim = -1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim = dim)



def rotate_half(x):
    # pair up the feature dim as (d r) with r=2 and rotate each pair (x1, x2) -> (-x2, x1)
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')
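# Example (hypothetical values): rotate_half(torch.tensor([1., 2., 3., 4.]))
# returns tensor([-2., 1., -4., 3.]); each adjacent pair (x1, x2) maps to
# (-x2, x1), the 90-degree rotation that RoPE needs.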


class VisionRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index = 0):
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        # rotate only the middle rot_dim features; pass the rest through untouched
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim = -1)



class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None: ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        # flatten the (h, w) grid so the buffers line up with a flattened token sequence
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        # variant that skips a leading cls token when the sequence length is odd:
        # if t.shape[1] % 2 != 0:
        #     t_spatial = t[:, 1:, :]
        #     t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
        #     return torch.cat((t[:, :1, :], t_spatial), dim=1)
        # else:
        #     return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
--------------------------------------------------------------------------------
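For orientation, a minimal usage sketch of `VisionRotaryEmbeddingFast` (the sizes are made up: with `dim=32` and an 8x8 token grid the rotary buffers come out as 64x64, so the token features must be 64-dimensional):

```python
import torch

# assumes rope.py is importable from the repository root
from rope import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=8)  # 8x8 grid of tokens
x = torch.randn(2, 64, 64)  # (batch, tokens, features): 64 tokens, 64-dim features
y = rope(x)                 # same shape, with rotary position information mixed in
print(y.shape)              # torch.Size([2, 64, 64])
```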