├── DFINet.ipynb
└── README.md


/DFINet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "DFINet.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "code",
21 | "metadata": {
22 | "id": "suRi9jsX-mbK",
23 | "colab": {
24 | "base_uri": "https://localhost:8080/"
25 | },
26 | "outputId": "16151316-2241-4ae9-8b8e-6e2ea85b70d1"
27 | },
28 | "source": [
29 | "pip install spectral"
30 | ],
31 | "execution_count": null,
32 | "outputs": [
33 | {
34 | "output_type": "stream",
35 | "text": [
36 | "Requirement already satisfied: spectral in /usr/local/lib/python3.7/dist-packages (0.22.2)\n",
37 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from spectral) (1.19.5)\n"
38 | ],
39 | "name": "stdout"
40 | }
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "metadata": {
46 | "id": "3lAKqodT-sKk"
47 | },
48 | "source": [
49 | "import os\n",
50 | "import numpy as np\n",
51 | "import matplotlib.pyplot as plt\n",
52 | "import scipy.io as sio\n",
53 | "from sklearn.decomposition import PCA\n",
54 | "from sklearn.model_selection import train_test_split\n",
55 | "from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score\n",
56 | "import spectral\n",
57 | "import torch, math\n",
58 | "from torchvision import datasets, transforms\n",
59 | "import torch.nn as nn\n",
60 | "import torch.nn.functional as F\n",
61 | "import torch.optim as optim\n",
62 | "from scipy.io import loadmat, savemat\n",
63 | "import random\n",
64 | "from time import time\n",
65 | "import h5py"
66 | ],
67 | "execution_count": null,
68 | "outputs": []
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "id": "9zHmDD44-5RR"
74 | },
75 | "source": [
76 | "Setting the random seed and the model parameters"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "bAKfj4w7-3Zv"
83 | },
84 | "source": [
85 | "# def seed_torch(seed = 612):\n",
86 | "# \trandom.seed(seed)\n",
87 | "# \tos.environ['PYTHONHASHSEED'] = str(seed) \n",
88 | "# \tnp.random.seed(seed)\n",
89 | "# \ttorch.manual_seed(seed)\n",
90 | "# \ttorch.cuda.manual_seed(seed)\n",
91 | "# \ttorch.cuda.manual_seed_all(seed) # if you are using multi-GPU.\n",
92 | "# \ttorch.backends.cudnn.benchmark = False\n",
93 | "# \ttorch.backends.cudnn.deterministic = True\n",
94 | "# seed_torch() \n",
95 | "\n",
96 | "# Setting the parameters of the model\n",
97 | "data_path = '/content/drive/MyDrive/data/'\n",
98 | "\n",
99 | "# the number of bands\n",
100 | "channel_hsi = 63\n",
101 | "channel_msi = 2\n",
102 | "\n",
103 | "# parameters of the loss functions\n",
104 | "alpha = 0.01\n",
105 | "beta = 0.01\n",
106 | "\n",
107 | "windowSize = 11\n",
108 | "valRatio = 0.2 # ratio of the validation set\n",
109 | "class_num = 20\n",
110 | "batch_size = 64\n",
111 | "\n",
112 | "# parameters of the optimizer\n",
113 | "lr = 0.001 # learning rate\n",
114 | "momentum = 0.9 # momentum of SGD\n",
115 | "betas = (0.9, 0.999) # betas of Adam\n",
116 | "num_epochs = 100"
117 | ],
118 | "execution_count": null,
119 | "outputs": []
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {
124 | "id": "rDjYmMhU_XTC"
125 | },
126 | "source": [
127 | "# 1. Feature extraction network"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "metadata": {
133 | "id": "jaTZVb3s-864"
134 | },
135 | "source": [
136 | "class HSINet(nn.Module):\n",
137 | " def __init__(self, channel_hsi):\n",
138 | " super(HSINet, self).__init__()\n",
139 | "\n",
140 | " self.conv1 = nn.Conv2d(channel_hsi, 256, 3, padding=1)\n",
141 | " self.bn1 = nn.BatchNorm2d(256)\n",
142 | "\n",
143 | " self.conv2 = nn.Conv2d(256, 128, 3)\n",
144 | " self.bn2 = nn.BatchNorm2d(128)\n",
145 | " self.conv3 = nn.Conv2d(128, 128, 3)\n",
146 | " self.bn3 = nn.BatchNorm2d(128)\n",
147 | "\n",
148 | " def forward(self, x):\n",
149 | "\n",
150 | " x = F.relu(self.bn1(self.conv1(x)))\n",
151 | " x = F.relu(self.bn2(self.conv2(x)))\n",
152 | " x = F.relu(self.bn3(self.conv3(x)))\n",
153 | " return x\n",
154 | "\n",
155 | "class MSINet(nn.Module):\n",
156 | " def __init__(self, channel_msi):\n",
157 | " super(MSINet, self).__init__()\n",
158 | "\n",
159 | " self.conv1 = nn.Conv2d(channel_msi, 128, 3, padding=1)\n",
160 | " self.bn1 = nn.BatchNorm2d(128)\n",
161 | "\n",
162 | " self.conv2 = nn.Conv2d(128, 128, 3)\n",
163 | " self.bn2 = nn.BatchNorm2d(128)\n",
164 | "\n",
165 | " self.conv3 = nn.Conv2d(128, 128, 3)\n",
166 | " self.bn3 = nn.BatchNorm2d(128)\n",
167 | "\n",
168 | " def forward(self, x):\n",
169 | "\n",
170 | " x = F.relu(self.bn1(self.conv1(x)))\n",
171 | " x = F.relu(self.bn2(self.conv2(x)))\n",
172 | " x = F.relu(self.bn3(self.conv3(x)))\n",
173 | " return x"
174 | ],
175 | "execution_count": null,
176 | "outputs": []
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {
181 | "id": "KNGYOU-n_Wln"
182 | },
183 | "source": [
184 | "Define the normalization and dropout layers"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "metadata": {
190 | "id": "HpP2m1lx_dhl"
191 | },
192 | "source": [
193 | "class LayerNorm(nn.Module):\n",
194 | " def __init__(self, size, eps=1e-6):\n",
195 | " super(LayerNorm, self).__init__()\n",
196 | " self.eps = eps\n",
197 | " self.a_2 = nn.Parameter(torch.ones(size))\n",
198 | " self.b_2 = nn.Parameter(torch.zeros(size))\n",
199 | "\n",
200 | " def forward(self, x):\n",
201 | " mean = x.mean(-1, keepdim=True)\n",
202 | " std = x.std(-1, keepdim=True)\n",
203 | " return self.a_2 * (x - mean) / (std + self.eps) + self.b_2\n",
204 | "\n",
205 | "class Dropout(nn.Module):\n",
206 | " def __init__(self):\n",
207 | " super(Dropout, self).__init__()\n",
208 | "\n",
209 | " def forward(self, x):\n",
210 | " out = F.dropout(x, p=0.2, training=self.training)\n",
211 | " return out"
212 | ],
213 | "execution_count": null,
214 | "outputs": []
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {
219 | "id": "H3sdRkzu_VIV"
220 | },
221 | "source": [
222 | "Define the cross attention layer"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "metadata": {
228 | "id": "laGokoIy_y--"
229 | },
230 | "source": [
231 | "class CAM(nn.Module):\n",
232 | " def __init__(self):\n",
233 | " super(CAM, self).__init__()\n",
234 | " k_size = 3\n",
235 | " self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)\n",
236 | " # self.conv1 = nn.Conv2d(9, 7, 1) # 81 is the spatial size of features\n",
237 | " # self.conv2 = nn.Conv2d(7, 49, 1, stride=1, padding=0)\n",
238 | "\n",
239 | " for m in self.modules():\n",
240 | " if isinstance(m, nn.Conv2d):\n",
241 | " n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
242 | " m.weight.data.normal_(0, math.sqrt(2. / n))\n",
243 | "\n",
244 | " def get_attention(self, a):\n",
245 | "\n",
246 | " input_a = a\n",
247 | " a = a.mean(3)\n",
248 | " a = a.transpose(1, 3)\n",
249 | " # a= F.relu(self.conv1(a))\n",
250 | " # a= self.conv2(a)\n",
251 | " a = self.conv(a.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)\n",
252 | " a = a.transpose(1, 3)\n",
253 | "\n",
254 | " a = a.unsqueeze(3)\n",
255 | " a = torch.mean(input_a * a, -1)\n",
256 | " a = F.softmax(a / 0.025, dim=-1) + 1\n",
257 | " return a\n",
258 | "\n",
259 | " def forward(self, f1, f2):\n",
260 | "\n",
261 | " b, n1, c, h, w = f1.size()\n",
262 | " n2 = f2.size(1)\n",
263 | "\n",
264 | " f1 = f1.view(b, n1, c, -1)\n",
265 | " f2 = f2.view(b, n2, c, -1)\n",
266 | "\n",
267 | " f1_norm = F.normalize(f1, p=2, dim=2, eps=1e-12)\n",
268 | " f2_norm = F.normalize(f2, p=2, dim=2, eps=1e-12)\n",
269 | "\n",
270 | " f1_norm = f1_norm.transpose(2, 3).unsqueeze(2)\n",
271 | " f2_norm = f2_norm.unsqueeze(1)\n",
272 | "\n",
273 | " a1 = torch.matmul(f1_norm, f2_norm)\n",
274 | " a2 = a1.transpose(3, 4)\n",
275 | "\n",
276 | " a1 = self.get_attention(a1)\n",
277 | " a2 = self.get_attention(a2)\n",
278 | " f1 = f1 * a1\n",
279 | " f1 = f1.view(b, c, h, w)\n",
280 | " f2 = f2 * a2\n",
281 | " f2 = f2.view(b, c, h, w)\n",
282 | " return f1, f2"
283 | ],
284 | "execution_count": null,
285 | "outputs": []
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "metadata": {
290 | "id": "lmAGqpMmAa-S"
291 | },
292 | "source": [
293 | "# 2. The proposed deep feature interaction network"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "metadata": {
299 | "id": "14WtUYaPAKIJ"
300 | },
301 | "source": [
302 | "class Net(nn.Module):\n",
303 | " def __init__(self, channel_hsi, channel_msi, class_num):\n",
304 | " super(Net, self).__init__()\n",
305 | "\n",
306 | " self.featnet1 = HSINet(channel_hsi)\n",
307 | " self.featnet2 = MSINet(channel_msi)\n",
308 | " self.cam = CAM()\n",
309 | " self.proj_norm = LayerNorm(64)\n",
310 | " self.fc1 = nn.Linear(1 * 1 * 128, 64)\n",
311 | " self.fc2 = nn.Linear(64, class_num)\n",
312 | " self.dropout = nn.Dropout()\n",
313 | "\n",
314 | " def forward(self, x, y):\n",
315 | "\n",
316 | " # Pre-process Image Feature\n",
317 | " feature_1 = self.featnet1(x)\n",
318 | " feature_2 = self.featnet2(y)\n",
319 | "\n",
320 | " hsi_feat = feature_1.unsqueeze(1)\n",
321 | " lidar_feat = feature_2.unsqueeze(1)\n",
322 | " hsi, lidar = self.cam(hsi_feat, lidar_feat)\n",
323 | " x = self.xcorr_depthwise(hsi, lidar)\n",
324 | " y = self.xcorr_depthwise(lidar, hsi)\n",
325 | " x1 = x.contiguous().view(x.size(0), -1)\n",
326 | " y1 = y.contiguous().view(y.size(0), -1)\n",
327 | " x = x1 + y1\n",
328 | " x = F.relu(self.proj_norm(self.fc1(x)))\n",
329 | "\n",
330 | " x = self.dropout(x)\n",
331 | " x = self.fc2(x)\n",
332 | " # hsi = hsi.contiguous().view(x.size(0), -1)\n",
333 | " # lidar = lidar.contiguous().view(x.size(0), -1)\n",
334 | " return feature_1, feature_2, x1, y1, x\n",
335 | "\n",
336 | " def xcorr_depthwise11(self, x, kernel):\n",
337 | " batch = kernel.size(0)\n",
338 | " channel = kernel.size(1)\n",
339 | " x = x.view(1, batch * channel, x.size(2), x.size(3))\n",
340 | " kernel = kernel.view(batch, channel, kernel.size(2), kernel.size(3))\n",
341 | " out = F.conv2d(x, kernel, groups=1)\n",
342 | " # out = F.relu(out)\n",
343 | " out = out.view(batch, 1, out.size(2), out.size(3))\n",
344 | " return out\n",
345 | "\n",
346 | " def xcorr_depthwise(self, x, kernel):\n",
347 | " batch = kernel.size(0)\n",
348 | " channel = kernel.size(1)\n",
349 | " x = x.view(1, batch * channel, x.size(2), x.size(3))\n",
350 | " kernel = kernel.view(batch * channel, 1, kernel.size(2), kernel.size(3))\n",
351 | " out = F.conv2d(x, kernel, groups=batch*channel)\n",
352 | " # out = F.relu(out)\n",
353 | " out = out.view(batch, channel, out.size(2), out.size(3))\n",
354 | " return out"
355 | ],
356 | "execution_count": null,
357 | "outputs": []
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {
362 | "id": "SM4kogzPAniC"
363 | },
364 | "source": [
365 | "# 3. Data processing functions"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "metadata": {
371 | "id": "1A_Ndtr3AuKV"
372 | },
373 | "source": [
374 | "def max_min_mean(img):\n",
375 | " \"\"\"\n",
376 | " calculate the max, min, and mean values of the image.\n",
377 | " \"\"\"\n",
378 | " print('max: ',np.max(img),'min: ',np.min(img),'mean: ',np.mean(img))\n",
379 | "\n",
380 | "def c(img):\n",
381 | " \"\"\"\n",
382 | " map the image to [0,255]\n",
383 | " \"\"\"\n",
384 | " return ( img - np.min(img) ) / ( np.max(img)-np.min(img) ) * 255\n",
385 | "\n",
386 | "\n",
387 | "def applyPCA(X, numComponents):\n",
388 | " \"\"\"\n",
389 | " apply PCA to the image to reduce dimensionality\n",
390 | " \"\"\"\n",
391 | " newX = np.reshape(X, (-1, X.shape[2]))\n",
392 | " pca = PCA(n_components=numComponents, whiten=True)\n",
393 | " newX = pca.fit_transform(newX)\n",
394 | " newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents))\n",
395 | " return newX\n",
396 | "\n",
397 | "def addZeroPadding(X, margin=2):\n",
398 | " \"\"\"\n",
399 | " add zero padding to the image\n",
400 | " \"\"\"\n",
401 | " newX = np.zeros((\n",
402 | " X.shape[0] + 2 * margin,\n",
403 | " X.shape[1] + 2 * margin,\n",
404 | " X.shape[2]\n",
405 | " ))\n",
406 | " newX[margin:X.shape[0]+margin, margin:X.shape[1]+margin, :] = X\n",
407 | " return newX\n",
408 | "\n",
409 | "\n",
410 | "def createImgCube(X, gt, pos:list, windowSize=25):\n",
411 | " \"\"\"\n",
412 | " create image cubes from the pos list\n",
413 | " returns (imageCube, gt, nextPos)\n",
414 | " \"\"\"\n",
415 | " margin = (windowSize-1)//2\n",
416 | " zeroPaddingX = addZeroPadding(X, margin=margin)\n",
417 | " dataPatches = np.zeros((pos.__len__(), windowSize, windowSize, X.shape[2]))\n",
418 | " if( pos[-1][1]+1 != X.shape[1] ):\n",
419 | " nextPos = (pos[-1][0] ,pos[-1][1]+1)\n",
420 | " elif( pos[-1][0]+1 != X.shape[0] ):\n",
421 | " nextPos = (pos[-1][0]+1 ,0)\n",
422 | " else:\n",
423 | " nextPos = (0,0)\n",
424 | " return np.array([zeroPaddingX[i:i+windowSize, j:j+windowSize, :] for i,j in pos ]),\\\n",
425 | " np.array([gt[i,j] for i,j in pos]) ,\\\n",
426 | " nextPos\n",
427 | "\n",
428 | "\n",
429 | "def createPos(shape:tuple, pos:tuple, num:int):\n",
430 | " \"\"\"\n",
431 | " create the pos list after the given pos\n",
432 | " \"\"\"\n",
433 | " if (pos[0]+1)*(pos[1]+1)+num >shape[0]*shape[1]:\n",
434 | " num = shape[0]*shape[1]-( (pos[0])*shape[1] + pos[1] )\n",
435 | " return [(pos[0]+(pos[1]+i)//shape[1] , (pos[1]+i)%shape[1] ) for i in range(num) ]\n",
436 | "\n",
437 | "def createPosWithoutZero(hsi, gt):\n",
438 | " \"\"\"\n",
439 | " create the pos list without zero labels\n",
440 | " \"\"\"\n",
441 | " mask = gt > 0\n",
442 | " return [(i,j) for i , row in enumerate(mask) for j , row_element in enumerate(row) if row_element]\n",
443 | "\n",
444 | "def createImgPatch(lidar, pos:list, windowSize=25):\n",
445 | " \"\"\"\n",
446 | " return LiDAR image patches\n",
447 | " \"\"\"\n",
448 | " margin = (windowSize-1)//2\n",
449 | " zeroPaddingLidar = np.zeros((\n",
450 | " lidar.shape[0] + 2 * margin,\n",
451 | " lidar.shape[1] + 2 * margin\n",
452 | " ))\n",
453 | " zeroPaddingLidar[margin:lidar.shape[0]+margin, margin:lidar.shape[1]+margin] = lidar\n",
454 | " return np.array([zeroPaddingLidar[i:i+windowSize, j:j+windowSize] for i,j in pos ])\n",
455 | "\n",
456 | "def splitTrainTestSet(X, gt, testRatio, randomState=111):\n",
457 | " \"\"\"\n",
458 | " randomly split the data set\n",
459 | " \"\"\"\n",
460 | " X_train, X_test, gt_train, gt_test = train_test_split(\n",
461 | " X, gt, test_size=testRatio, random_state=randomState, stratify=gt)\n",
462 | " return X_train, X_test, gt_train, gt_test\n",
463 | "\n",
464 | "def minmax_normalize(array):\n",
465 | " amin = np.min(array)\n",
466 | " amax = np.max(array)\n",
467 | " return (array - amin) / (amax - amin)"
468 | ],
469 | "execution_count": null,
470 | "outputs": []
471 | },
472 | {
473 | "cell_type": "markdown",
474 | "metadata": {
475 | "id": "OZ4XnH3qBBCe"
476 | },
477 | "source": [
478 | "# 4. Create dataloader"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "metadata": {
484 | "id": "TTJxj8OyBBXX",
485 | "colab": {
486 | "base_uri": "https://localhost:8080/"
487 | },
488 | "outputId": "9556d637-5bf7-464c-f552-763533db44f0"
489 | },
490 | "source": [
491 | "\n",
492 | "data_traingt = sio.loadmat(os.path.join(data_path, 'trento_mask_train.mat'))['mask_train']\n",
493 | "data_testgt = sio.loadmat(os.path.join(data_path, 'trento_mask_test.mat'))['mask_test']\n",
494 | "data_hsi = sio.loadmat(os.path.join(data_path, 'trento_hsi.mat'))['trento_hsi']\n",
495 | "data_msi = sio.loadmat(os.path.join(data_path, 'trento_lidar.mat'))['trento_lidar']\n",
496 | "# data_msi = h5py.File(os.path.join(data_path, 'HHK_msi.mat'))\n",
497 | "# data_msi = data_msi['HHK_msi'][:]\n",
498 | "# data_msi= np.transpose(data_msi,(2,1,0))\n",
499 | "\n",
500 | "data_hsi = minmax_normalize(data_hsi)\n",
501 | "data_msi = minmax_normalize(data_msi)\n",
502 | "height, width, c = data_msi.shape\n",
503 | "\n",
504 | "# training / testing set for 2D-CNN\n",
505 | "\n",
506 | "train_hsiCube, train_labels, _ = createImgCube(data_hsi, data_traingt, createPosWithoutZero(data_hsi, data_traingt), windowSize=windowSize)\n",
507 | "train_patches, _, _ = createImgCube(data_msi, data_traingt, createPosWithoutZero(data_msi, data_traingt), windowSize=windowSize)\n",
508 | "\n",
509 | "# data augmentation if needed\n",
510 | "\n",
511 | "Xh = []\n",
512 | "Xl = []\n",
513 | "y = []\n",
514 | "for i in range(train_hsiCube.shape[0]):\n",
515 | " Xh.append(train_hsiCube[i])\n",
516 | " Xl.append(train_patches[i])\n",
517 | "\n",
518 | " noise = np.random.normal(0.0, 0.01, size=train_hsiCube[0].shape)\n",
519 | " noise2 = np.random.normal(0.0, 0.01, size=train_patches[0].shape)\n",
520 | " Xh.append(np.flip(train_hsiCube[i] + noise, axis=1))\n",
521 | " Xl.append(np.flip(train_patches[i] + noise2, axis=1))\n",
522 | "\n",
523 | " k = np.random.randint(4)\n",
524 | " Xh.append(np.rot90(train_hsiCube[i], k=k))\n",
525 | " Xl.append(np.rot90(train_patches[i], k=k))\n",
526 | "\n",
527 | " y.append(train_labels[i])\n",
528 | " y.append(train_labels[i])\n",
529 | " y.append(train_labels[i])\n",
530 | "\n",
531 | "train_labels = np.asarray(y, dtype=np.int8)\n",
532 | "train_hsiCube = np.asarray(Xh, dtype=np.float32)\n",
533 | "train_patches = np.asarray(Xl, dtype=np.float32)\n",
534 | "train_hsiCube = torch.from_numpy(train_hsiCube.transpose(0, 3, 1, 2)).float()\n",
| "train_patches = torch.from_numpy(train_patches.transpose(0, 3, 1, 2)).float()\n", 536 | "\n", 537 | "X_train, X_test, gt_train, gt_test = splitTrainTestSet(train_hsiCube, train_labels, valRatio, randomState=111)\n", 538 | "X_train_2, X_test_2, _, _ = splitTrainTestSet(train_patches, train_labels, valRatio, randomState=111)\n", 539 | "\n", 540 | "print(X_train.shape)\n", 541 | "print(X_test.shape)\n", 542 | "print(\"Creating dataloader\")\n", 543 | "\n", 544 | "class TrainDS(torch.utils.data.Dataset):\n", 545 | " def __init__(self):\n", 546 | " self.len = gt_train.shape[0]\n", 547 | " self.hsi = torch.FloatTensor(X_train)\n", 548 | " self.lidar = torch.FloatTensor(X_train_2)\n", 549 | " self.labels = torch.LongTensor(gt_train - 1)\n", 550 | " def __getitem__(self, index):\n", 551 | " return self.hsi[index], self.lidar[index], self.labels[index]\n", 552 | " def __len__(self):\n", 553 | " return self.len\n", 554 | "\n", 555 | "\"\"\" Testing dataset\"\"\"\n", 556 | "\n", 557 | "class TestDS(torch.utils.data.Dataset):\n", 558 | " def __init__(self):\n", 559 | " self.len = gt_test.shape[0]\n", 560 | " self.hsi = torch.FloatTensor(X_test)\n", 561 | " self.lidar = torch.FloatTensor(X_test_2)\n", 562 | " self.labels = torch.LongTensor(gt_test - 1)\n", 563 | " def __getitem__(self, index):\n", 564 | " return self.hsi[index], self.lidar[index], self.labels[index]\n", 565 | " def __len__(self):\n", 566 | " return self.len\n", 567 | "\n", 568 | "# create trainloader and valloader -- (testload)\n", 569 | "trainset = TrainDS()\n", 570 | "valset = TestDS()\n", 571 | "\n", 572 | "train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True, num_workers=0)\n", 573 | "val_loader = torch.utils.data.DataLoader(dataset=valset, batch_size=64, shuffle=False, num_workers=0)\n", 574 | "\n", 575 | "print(\"Finish!\")" 576 | ], 577 | "execution_count": null, 578 | "outputs": [ 579 | { 580 | "output_type": "stream", 581 | "text": [ 582 | "torch.Size([1965, 63, 11, 11])\n", 583 | "torch.Size([492, 63, 11, 11])\n", 584 | "Creating dataloader\n", 585 | "Finish!\n" 586 | ], 587 | "name": "stdout" 588 | } 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": { 594 | "id": "gxQ57lsoBsIk" 595 | }, 596 | "source": [ 597 | "# 5. The loss function\n", 598 | "\n", 599 | "1. Loss1: The consistency loss\n", 600 | "2. Loss2: The distinctive loss\n", 601 | "3. 
601 | "3. Loss3: The classification loss\n",
602 | "\n",
603 | "\n",
604 | "\n"
605 | ]
606 | },
607 | {
608 | "cell_type": "code",
609 | "metadata": {
610 | "id": "qgM_E3XHBg7C"
611 | },
612 | "source": [
613 | "def calc_label_sim(label_1, label_2):\n",
614 | "\n",
615 | " batch_size = label_1.shape[0]\n",
616 | " label = torch.zeros(batch_size, class_num).scatter_(1, label_1.unsqueeze(1).cpu(), 1)\n",
617 | " sim = label.float().mm(label.float().t()).cuda()\n",
618 | " return sim\n",
619 | "def calc_loss(feature_1, feature_2, hsi_1, lidar_1, outputs, labels, alpha, beta):\n",
620 | "\n",
621 | " cos = lambda x, y: x.mm(y.t()) / ((x ** 2).sum(1, keepdim=True).sqrt().mm((y ** 2).sum(1, keepdim=True).sqrt().t())).clamp(min=1e-6) / 2.\n",
622 | " theta = cos(hsi_1, lidar_1)\n",
623 | " sim = calc_label_sim(labels, labels)\n",
624 | " theta1 = cos(hsi_1, hsi_1)\n",
625 | " theta2 = cos(lidar_1, lidar_1)\n",
626 | "\n",
627 | " term1 = ((1 + torch.exp(theta)).log() + sim * theta).mean()\n",
628 | " term2 = ((1 + torch.exp(theta1)).log() + sim * theta1).mean()\n",
629 | " term3 = ((1 + torch.exp(theta2)).log() + sim * theta2).mean()\n",
630 | " loss2 = term1 + term2 + term3\n",
631 | "\n",
632 | " criterion = nn.CrossEntropyLoss()\n",
633 | " loss3 = criterion(outputs, labels)\n",
634 | " loss1 = torch.mean(torch.pow(feature_1 - feature_2, 2))\n",
635 | "\n",
636 | " loss_sum = loss3 + alpha * loss2 + beta * loss1\n",
637 | " return loss_sum.mean()"
638 | ],
639 | "execution_count": null,
640 | "outputs": []
641 | },
642 | {
643 | "cell_type": "markdown",
644 | "metadata": {
645 | "id": "kdMnyTbeCO6T"
646 | },
647 | "source": [
648 | "Define the training and testing functions"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "metadata": {
654 | "id": "mjBG9og2CPBV"
655 | },
656 | "source": [
657 | "def train(model, device, train_loader, optimizer, epoch):\n",
658 | " model.train()\n",
659 | " total_loss = 0\n",
660 | " for i, (inputs_1, inputs_2, labels) in enumerate(train_loader):\n",
661 | " inputs_1, inputs_2 = inputs_1.to(device), inputs_2.to(device)\n",
662 | "\n",
663 | " labels = labels.to(device)\n",
664 | "\n",
665 | " optimizer.zero_grad()\n",
666 | " feature_1, feature_2, hsi_1, lidar_1, outputs = model(inputs_1, inputs_2)\n",
667 | " loss = calc_loss(feature_1, feature_2, hsi_1, lidar_1, outputs, labels, alpha=alpha, beta=beta)\n",
668 | " loss.backward()\n",
669 | " optimizer.step()\n",
670 | " total_loss += loss.item()\n",
671 | "\n",
672 | " print('[Epoch: %d] [loss avg: %.4f] [current loss: %.4f]' %(epoch + 1, total_loss/(i+1), loss.item()))\n",
673 | "\n",
674 | "def test(model, device, val_loader):\n",
675 | " model.eval()\n",
676 | " count = 0\n",
677 | " feature = []\n",
678 | " flabel = []\n",
679 | " for inputs_1, inputs_2, labels in val_loader:\n",
680 | "\n",
681 | " inputs_1, inputs_2 = inputs_1.to(device), inputs_2.to(device)\n",
682 | " _, _, _, _, outputs = model(inputs_1, inputs_2)\n",
683 | " feature.append(outputs.detach().cpu().numpy())\n",
684 | " outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)\n",
685 | " if count == 0:\n",
686 | " y_pred_test = outputs\n",
687 | " test_labels = labels\n",
688 | " count = 1\n",
689 | " else:\n",
690 | " y_pred_test = np.concatenate((y_pred_test, outputs))\n",
691 | " test_labels = np.concatenate((test_labels, labels))\n",
692 | " classification = classification_report(test_labels, y_pred_test, digits=4)\n",
693 | "\n",
694 | " sio.savemat('feature.mat', {'feature': feature})\n",
695 | " a = 0\n",
696 | " for c in range(len(y_pred_test)):\n",
697 | " if test_labels[c]==y_pred_test[c]:\n",
698 | " a = a+1\n",
699 | " sio.savemat('test_labels.mat', {'test_labels': test_labels})\n",
700 | " print(classification)\n",
701 | " acc = a/len(y_pred_test)*100\n",
702 | " print('%.2f' % acc)\n",
703 | " return acc"
704 | ],
705 | "execution_count": null,
706 | "outputs": []
707 | },
708 | {
709 | "cell_type": "markdown",
710 | "metadata": {
711 | "id": "7oSAc_kbLAxP"
712 | },
713 | "source": [
714 | "# 6. Running"
715 | ]
716 | },
717 | {
718 | "cell_type": "code",
719 | "metadata": {
720 | "id": "HSy2Vxx_LA5P"
721 | },
722 | "source": [
723 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
724 | "model = Net(channel_hsi, channel_msi, class_num).to(device)\n",
725 | "\n",
726 | "params_to_update = list(model.parameters())\n",
727 | "\n",
728 | "# optimizer = torch.optim.Adam(params_to_update, lr=lr, betas=betas, eps=1e-8, weight_decay=0.0005)\n",
729 | "optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=0.0005)\n",
730 | "\n",
731 | "# if num_epochs % 50 == 0:\n",
732 | "# for p in optimizer.param_groups:p['lr'] *= 0.9\n",
733 | "# lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])\n",
734 | "# print (lr_list)\n",
735 | "\n",
736 | "best_acc = 0\n",
737 | "for epoch in range(num_epochs):\n",
738 | " train(model, device, train_loader, optimizer, epoch)\n",
739 | " acc = test(model, device, val_loader)\n",
740 | " if acc >= best_acc:\n",
741 | " best_acc = acc\n",
742 | " print(\"save model\")\n",
743 | " torch.save(model.state_dict(), './model/model.pth')"
744 | ],
745 | "execution_count": null,
746 | "outputs": []
747 | },
748 | {
749 | "cell_type": "markdown",
750 | "metadata": {
751 | "id": "WR3aPwlILr4T"
752 | },
753 | "source": [
754 | "# 7. Record the final test results"
755 | ]
756 | },
757 | {
758 | "cell_type": "code",
759 | "metadata": {
760 | "id": "_jkC7NQTLsA4"
761 | },
762 | "source": [
763 | "model = Net(channel_hsi, channel_msi, class_num).eval().cuda()\n",
764 | "model.load_state_dict(torch.load('./model/model.pth'))\n",
765 | "\n",
766 | "margin = (windowSize-1)//2\n",
767 | "data_hsi = addZeroPadding(data_hsi, margin=margin)\n",
768 | "data_msi = addZeroPadding(data_msi, margin=margin)\n",
769 | "\n",
770 | "# Prediction\n",
771 | "outputs = np.zeros((height, width))\n",
772 | "# feature = np.zeros(test_hsiCube.shape[0], class_num)\n",
773 | "for i in range(height):\n",
774 | " for j in range(width):\n",
775 | " if int(data_testgt[i, j]) != 0:\n",
776 | " # continue\n",
777 | " # else :\n",
778 | " image_patch = data_hsi[i:i+windowSize, j:j+windowSize, :]\n",
779 | " image_patch = image_patch.reshape(1, image_patch.shape[0], image_patch.shape[1], image_patch.shape[2])\n",
780 | " X_test_image = torch.FloatTensor(image_patch.transpose(0, 3, 1, 2)).to(device)\n",
781 | "\n",
782 | " image_patch1 = data_msi[i:i+windowSize, j:j+windowSize, :]\n",
783 | " image_patch1 = image_patch1.reshape(1, image_patch1.shape[0], image_patch1.shape[1], image_patch1.shape[2])\n",
784 | " X_test_image1 = torch.FloatTensor(image_patch1.transpose(0, 3, 1, 2)).to(device)\n",
785 | "\n",
786 | " _, _, _, _, prediction = model(X_test_image, X_test_image1)\n",
787 | " prediction = np.argmax(prediction.detach().cpu().numpy(), axis=1)\n",
788 | " outputs[i][j] = prediction + 1\n",
789 | " if i % 20 == 0:\n",
790 | " print('... ... row ', i, ' handling ... ...')\n",
791 | "sio.savemat('result.mat', {'output': outputs})\n",
792 | "print('All finished!')"
793 | ],
794 | "execution_count": null,
795 | "outputs": []
796 | }
797 | ]
798 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HSI-and-MSI-Classification
2 |
3 | This repository contains the code for our paper "Hyperspectral and Multispectral Classification for Coastal Wetland Using Depthwise Feature Interaction Network".
4 | The paper was published in IEEE Transactions on Geoscience and Remote Sensing, vol. 60, pp. 1-15, 2022, Art no. 5512615, doi: 10.1109/TGRS.2021.3097093.
5 |
6 | If you have any questions, please contact us. Email: gaoyunhao128@163.com
7 |
8 | @ARTICLE{9494718,
9 | author={Gao, Yunhao and Li, Wei and Zhang, Mengmeng and Wang, Jianbu and Sun, Weiwei and Tao, Ran and Du, Qian},
10 | journal={IEEE Transactions on Geoscience and Remote Sensing},
11 | title={Hyperspectral and Multispectral Classification for Coastal Wetland Using Depthwise Feature Interaction Network},
12 | year={2022},
13 | volume={60},
14 | number={},
15 | pages={1-15},
16 | doi={10.1109/TGRS.2021.3097093}}
17 |
18 | #### Colab
19 |
20 | 1. Go to Google Colab and sign in: https://colab.research.google.com/
21 | 2. Open DFINet.ipynb and run it!
22 |
--------------------------------------------------------------------------------
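
#### Notes

The dataloader cell (section 4 of the notebook) reads four MAT files from `data_path`. A minimal sketch of the expected layout — the keys are exactly the ones the notebook indexes, while the shapes are inferred from the parameter cell (63 HSI bands, 2 LiDAR channels) and should be treated as assumptions:

```python
import os
import scipy.io as sio

data_path = '/content/drive/MyDrive/data/'  # same path the notebook uses

expected = {  # file name               -> key the notebook reads
    'trento_hsi.mat':        'trento_hsi',    # assumed H x W x 63 hyperspectral cube
    'trento_lidar.mat':      'trento_lidar',  # assumed H x W x 2 LiDAR channels
    'trento_mask_train.mat': 'mask_train',    # H x W training mask, 0 = unlabeled
    'trento_mask_test.mat':  'mask_test',     # H x W test mask, 0 = unlabeled
}
for fname, key in expected.items():
    arr = sio.loadmat(os.path.join(data_path, fname))[key]
    print(fname, key, arr.shape)
```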
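
The cross attention module (`CAM`, section 1 of the notebook) reweights each spatial position of one modality by its similarity to the other. In rough notation (ours, not the paper's): with L2-normalized feature maps flattened over space, the notebook forms the correlation matrix, pools it over channels, smooths it with a small 1-D convolution, and turns the result into an attention map:

```latex
R = \hat{F}_h^{\top}\hat{F}_m, \qquad
a = \mathrm{softmax}\!\left(\frac{\bar{R}}{T}\right) + 1, \quad T = 0.025, \qquad
F'_h = F_h \odot a_h, \quad F'_m = F_m \odot a_m
```

where $\bar{R}$ stands for the pooled and Conv1d-smoothed correlation, and the $+1$ keeps a residual path so no spatial position is completely suppressed.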
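
The two attended feature maps are then fused by depthwise cross-correlation: each channel of one modality serves as a convolution kernel that slides over the matching channel of the other. A self-contained sketch equivalent to the notebook's `Net.xcorr_depthwise` (variable names here are ours):

```python
import torch
import torch.nn.functional as F

def xcorr_depthwise(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Channel c of `kernel` is correlated with channel c of `x`,
    independently for every sample in the batch."""
    batch, channel = kernel.size(0), kernel.size(1)
    # fold the batch into the channel axis so one grouped conv handles everything
    x = x.view(1, batch * channel, x.size(2), x.size(3))
    kernel = kernel.view(batch * channel, 1, kernel.size(2), kernel.size(3))
    out = F.conv2d(x, kernel, groups=batch * channel)
    return out.view(batch, channel, out.size(2), out.size(3))

# 11x11 patches become 7x7 feature maps after two unpadded 3x3 convs;
# correlating two same-sized 7x7 maps leaves a single 1x1 response per channel.
a = torch.randn(4, 128, 7, 7)
b = torch.randn(4, 128, 7, 7)
print(xcorr_depthwise(a, b).shape)  # torch.Size([4, 128, 1, 1])
```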
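
Reading `calc_loss`, the training objective combines the three terms listed in section 5 exactly as the code writes them (our notation: the overline is a mean, $\Theta$ are the pairwise cosine-similarity matrices over HSI/MSI embeddings, $S$ is the label-similarity matrix from `calc_label_sim`, and $\alpha = \beta = 0.01$ in the parameter cell):

```latex
\mathcal{L}
= \underbrace{\mathrm{CE}(\hat{y}, y)}_{\text{classification}}
+ \alpha \underbrace{\sum_{\Theta \in \{\Theta_{hm},\, \Theta_{hh},\, \Theta_{mm}\}}
    \overline{\log\!\bigl(1 + e^{\Theta}\bigr) + S \odot \Theta}}_{\text{distinctive term}}
+ \beta \underbrace{\overline{\bigl(F_h - F_m\bigr)^{2}}}_{\text{consistency term}}
```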
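
Finally, a quick shape check you can run after executing the model-definition cells (sections 1–2). The sizes follow from the notebook's settings: 11×11 patches shrink to 7×7 inside `HSINet`/`MSINet`, and the depthwise correlation collapses them to 1×1 per channel, which is why the head is `nn.Linear(1 * 1 * 128, 64)`:

```python
model = Net(channel_hsi=63, channel_msi=2, class_num=20)
hsi = torch.randn(4, 63, 11, 11)  # a batch of HSI patches
msi = torch.randn(4, 2, 11, 11)   # the co-registered LiDAR patches
feat_h, feat_m, x1, y1, logits = model(hsi, msi)
print(feat_h.shape, x1.shape, logits.shape)
# torch.Size([4, 128, 7, 7]) torch.Size([4, 128]) torch.Size([4, 20])
```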