├── SSRN ├── net │ └── a.v ├── records │ └── a.v ├── classification_maps │ └── a.v ├── bsconv.txt └── indian.py ├── datasets └── h.py ├── global_module ├── Utils │ ├── __init__.py │ ├── extract_samll_cubic.py │ ├── record.py │ └── extract_samll_cubic_save_RAM.py ├── d2lzh_pytorch │ ├── __init__.py │ └── utils.py ├── activation.py ├── train.py ├── generate_pic.py └── network.py ├── assets ├── carbon.jpg ├── CNN-page-001.jpg ├── bsnetsIN3D-page-001.jpg ├── architecture-MLP-page-001.jpg ├── architecture-Conv-page-001.jpg ├── indian-svm-oa-3-30-page-001.jpg ├── MSD-Indian-3-30-band-page-001.jpg ├── architecture-overall-page-001.jpg ├── top15bands-entropy-all-BS-Indian-page-001.jpg ├── MLP-loss-acc-Indian-5band-100epoch-L2-01-page-001.jpg └── loss-acc-Indian-5band-100epoch-L10-01-best-page-001.jpg ├── requirements.txt ├── .deepsource.toml ├── .gitpod.yml ├── .gitpod.Dockerfile ├── .github ├── workflows │ └── greetings.yml └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── LICENSE └── README.md /SSRN/net/a.v: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/h.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSRN/records/a.v: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSRN/classification_maps/a.v: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /global_module/Utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /global_module/d2lzh_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | 3 | -------------------------------------------------------------------------------- /assets/carbon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/carbon.jpg -------------------------------------------------------------------------------- /assets/CNN-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/CNN-page-001.jpg -------------------------------------------------------------------------------- /SSRN/bsconv.txt: -------------------------------------------------------------------------------- 1 | [80, 97, 43, 71, 72, 7, 92, 151, 134, 87, 100, 15, 10, 84, 95, 38, 106, 122, 3, 34, 37, 48, 36, 86, 190] 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | scikit-learn==0.22.2 3 | numpy==1.18.1 4 | spectral 5 | torchsummary 6 | torchvision==0.5.1 -------------------------------------------------------------------------------- /assets/bsnetsIN3D-page-001.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/bsnetsIN3D-page-001.jpg -------------------------------------------------------------------------------- /assets/architecture-MLP-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/architecture-MLP-page-001.jpg -------------------------------------------------------------------------------- /assets/architecture-Conv-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/architecture-Conv-page-001.jpg -------------------------------------------------------------------------------- /assets/indian-svm-oa-3-30-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/indian-svm-oa-3-30-page-001.jpg -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "python" 5 | enabled = true 6 | 7 | [analyzers.meta] 8 | runtime_version = "3.x.x" 9 | -------------------------------------------------------------------------------- /assets/MSD-Indian-3-30-band-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/MSD-Indian-3-30-band-page-001.jpg -------------------------------------------------------------------------------- /assets/architecture-overall-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/architecture-overall-page-001.jpg -------------------------------------------------------------------------------- /assets/top15bands-entropy-all-BS-Indian-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/top15bands-entropy-all-BS-Indian-page-001.jpg -------------------------------------------------------------------------------- /assets/MLP-loss-acc-Indian-5band-100epoch-L2-01-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/MLP-loss-acc-Indian-5band-100epoch-L2-01-page-001.jpg -------------------------------------------------------------------------------- /assets/loss-acc-Indian-5band-100epoch-L10-01-best-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucalyptus/BS-Nets-Implementation-Pytorch/HEAD/assets/loss-acc-Indian-5band-100epoch-L10-01-best-page-001.jpg -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | tasks: 2 | - init: echo "Replace me with a build script for the project." 3 | command: echo "Replace me with something that should run on every start, or just 4 | remove me entirely." 
5 | image: 6 | file: .gitpod.Dockerfile 7 | -------------------------------------------------------------------------------- /.gitpod.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full 2 | 3 | USER gitpod 4 | 5 | # Install custom tools, runtime, etc. using apt-get 6 | # For example, the command below would install "bastet" - a command line tetris clone: 7 | # 8 | # RUN sudo apt-get -q update && # sudo apt-get install -yq bastet && # sudo rm -rf /var/lib/apt/lists/* 9 | # 10 | # More information: https://www.gitpod.io/docs/config-docker/ 11 | -------------------------------------------------------------------------------- /.github/workflows/greetings.yml: -------------------------------------------------------------------------------- 1 | name: Greetings 2 | # https://github.com/marketplace/actions/first-interaction 3 | 4 | on: [issues] # pull_request 5 | 6 | jobs: 7 | greeting: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/first-interaction@v1 11 | with: 12 | repo-token: ${{ secrets.GITHUB_TOKEN }} 13 | issue-message: 'Hi! Thanks for your contribution, great first issue!' 14 | pr-message: 'Hey, thanks for the input! Please give us a bit of time to review it!' 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here.
39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sayantan Das 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /global_module/Utils/extract_samll_cubic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def index_assignment(index, row, col, pad_length): # map flat pixel indices to (row, col) positions in the padded image 5 | new_assign = {} 6 | for counter, value in enumerate(index): 7 | assign_0 = value // col + pad_length 8 | assign_1 = value % col + pad_length 9 | new_assign[counter] = [assign_0, assign_1] 10 | return new_assign 11 | 12 | 13 | def assignment_index(assign_0, assign_1, col): 14 | new_index = assign_0 * col + assign_1 15 | return new_index 16 | 17 | 18 | def select_patch(matrix, pos_row, pos_col, ex_len): # crop a (2*ex_len+1) x (2*ex_len+1) window centred on (pos_row, pos_col) 19 | selected_rows = matrix[range(pos_row-ex_len, pos_row+ex_len+1)] 20 | selected_patch = selected_rows[:, range(pos_col-ex_len, pos_col+ex_len+1)] 21 | return selected_patch 22 | 23 | 24 | def select_small_cubic(data_size, data_indices, whole_data, patch_length, padded_data, dimension): # gather one spatial patch cube per labelled sample from the padded data 25 | small_cubic_data = np.zeros((data_size, 2 * patch_length + 1, 2 * patch_length + 1, dimension)) 26 | data_assign = index_assignment(data_indices, whole_data.shape[0], whole_data.shape[1], patch_length) 27 | for i in range(len(data_assign)): 28 | small_cubic_data[i] = select_patch(padded_data, data_assign[i][0], data_assign[i][1], patch_length) 29 | return small_cubic_data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Gitpod Ready-to-Code](https://img.shields.io/badge/Gitpod-Ready--to--Code-blue?logo=gitpod)](https://gitpod.io/#https://github.com/ucalyptus/BS-Nets-Implementation-Pytorch) 2 | 3 | # BS-Nets-Implementation-Pytorch [![HitCount](http://hits.dwyl.io/ucalyptus/BS-Nets-Implementation-Pytorch.svg)](http://hits.dwyl.io/ucalyptus/BS-Nets-Implementation-Pytorch) 4 | 5 | 6 | # Setup 7 | - `conda activate bsnets` 8 | - `pip install -r requirements.txt` 9 | 10 | # SSRN Classification 11 | `cd SSRN/` 12 | `python indian.py` 13 | 14 | 15 | # Plots 16 | ![](assets/CNN-page-001.jpg) 17 | ![](assets/architecture-Conv-page-001.jpg) 18 | ![](https://github.com/ucalyptus/BS-Nets-Implementation-Pytorch/blob/e50a34df2cc45d08979383a29d6c41535a965453/assets/top15bands-entropy-all-BS-Indian-page-001.jpg) 19 | ![](https://github.com/ucalyptus/BS-Nets-Implementation-Pytorch/blob/e50a34df2cc45d08979383a29d6c41535a965453/assets/architecture-MLP-page-001.jpg) 20 | ![](https://github.com/ucalyptus/BS-Nets-Implementation-Pytorch/blob/e50a34df2cc45d08979383a29d6c41535a965453/assets/loss-acc-Indian-5band-100epoch-L10-01-best-page-001.jpg) 21 | # Confusion Matrix 22 | ![](assets/bsnetsIN3D-page-001.jpg) 23 | 24 | # Architecture Code 25 | ![](assets/carbon.jpg) 26 | 27 |
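28 | # Datasets 29 | The loaders in `global_module/generate_pic.py` read the hyperspectral `.mat` cubes from `datasets/` (for example `Indian_pines_corrected.mat` and `Indian_pines_gt.mat` for the `IN` option), so place the files there before running. 30 | At the filename prompt in `indian.py`, enter `Def` or `All Bands` to keep every band, or point to a selected-band list such as `SSRN/bsconv.txt`. 31 |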
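43 | 44 | # Usage note: the dict maps names to activation constructors or plain functions, so 45 | # class entries need instantiating first, e.g. act = ACT2FN["swish"](); y = act(x), 46 | # while ACT2FN["relu"] is already a function and can be called on a tensor directly.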
-------------------------------------------------------------------------------- /global_module/Utils/record.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, training_time_ae, testing_time_ae, path): 6 | f = open(path, 'a') 7 | 8 | sentence0 = 'OAs for each iteration are:' + str(oa_ae) + '\n' 9 | f.write(sentence0) 10 | sentence1 = 'AAs for each iteration are:' + str(aa_ae) + '\n' 11 | f.write(sentence1) 12 | sentence2 = 'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n' 13 | f.write(sentence2) 14 | sentence3 = 'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)*100) + ' ± ' + str(np.std(oa_ae)*100) + '\n' 15 | f.write(sentence3) 16 | sentence4 = 'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)*100) + ' ± ' + str(np.std(aa_ae)*100) + '\n' 17 | f.write(sentence4) 18 | sentence5 = 'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n' 19 | f.write(sentence5) 20 | sentence6 = 'Total Training time is: ' + str(np.sum(training_time_ae)) + '\n' 21 | f.write(sentence6) 22 | sentence7 = 'Total Testing time is: ' + str(np.sum(testing_time_ae)) + '\n' + '\n' 23 | f.write(sentence7) 24 | 25 | element_mean = np.mean(element_acc_ae, axis=0) 26 | element_std = np.std(element_acc_ae, axis=0) 27 | sentence8 = "Mean of all elements in confusion matrix: " + str(element_mean*100) + '\n' 28 | f.write(sentence8) 29 | sentence9 = "Standard deviation of all elements in confusion matrix: " + str(element_std*100) + '\n' 30 | f.write(sentence9) 31 | 32 | f.close() 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /global_module/Utils/extract_samll_cubic_save_RAM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gc 3 | 4 | 5 | def index_assignment(index, row, col, pad_length): 6 | new_assign = {} 7 | for counter, value in enumerate(index): 8 | assign_0 = value // col + pad_length 9 | assign_1 = value % col + pad_length 10 | new_assign[counter] = [assign_0, assign_1] 11 | return new_assign 12 | 13 | 14 | def assignment_index(assign_0, assign_1, col): 15 | new_index = assign_0 * col + assign_1 16 | return new_index 17 | 18 | 19 | def select_patch(matrix, pos_row, pos_col, ex_len): 20 | selected_rows = matrix[range(pos_row-ex_len, pos_row+ex_len+1)] 21 | selected_patch = selected_rows[:, range(pos_col-ex_len, pos_col+ex_len+1)] 22 | del(matrix) 23 | del(selected_rows) 24 | # gc.collect() 25 | return selected_patch 26 | 27 | 28 | def select_small_cubic(data_size, data_indices, whole_data, patch_length, padded_data, dimension): 29 | small_cubic_data = np.zeros((data_size, 2 * patch_length + 1, 2 * patch_length + 1, dimension)) 30 | data_assign = index_assignment(data_indices, whole_data.shape[0], whole_data.shape[1], patch_length) 31 | 32 | # selected_rows = padded_data[range(data_assign[0][0] - patch_length, data_assign[0][0] + patch_length + 1)] 33 | # selected_patch = selected_rows[:, range(data_assign[0][1] - patch_length, data_assign[0][1] + patch_length + 1)] 34 | 35 | for i in range(len(data_assign)): 36 | selected_rows = padded_data[range(data_assign[i][0] - patch_length, data_assign[i][0] + patch_length + 1)] 37 | small_cubic_data[i] = selected_rows[:, range(data_assign[i][1] - patch_length, data_assign[i][1] + patch_length + 1)] 38 | #small_cubic_data[i] = select_patch(padded_data, data_assign[i][0], data_assign[i][1], patch_length) 39 | return small_cubic_data 40 | 41 | 42 | -------------------------------------------------------------------------------- /SSRN/indian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import collections 4 | from torch import optim 5 | import torch 6 | from sklearn import metrics, preprocessing 7 | import datetime 8 | from torchsummary import summary 9 | 10 | 11 | import sys 12 | sys.path.append('../global_module/') 13 | import network 14 | import train 15 | from generate_pic import aa_and_each_accuracy, sampling,load_dataset, generate_png, generate_iter 16 | from Utils import record, extract_samll_cubic 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | # for Monte Carlo runs 21 | seeds = [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341] 22 | ensemble = 1 23 | 24 | day = datetime.datetime.now() 25 | day_str = day.strftime('%m_%d_%H_%M') 26 | 27 | print('-----Importing Dataset-----') 28 | 29 | 30 | 31 | global Dataset # UP, IN, SV (the three options supported by load_dataset) 32 | dataset = 'IN' 33 | Dataset = dataset.upper() 34 | data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE,VALIDATION_SPLIT,method = load_dataset(Dataset) 35 | 36 | print(data_hsi.shape) 37 | image_x, image_y, BAND = data_hsi.shape 38 | data = data_hsi.reshape(np.prod(data_hsi.shape[:2]), np.prod(data_hsi.shape[2:])) 39 | gt = gt_hsi.reshape(np.prod(gt_hsi.shape[:2]),)
40 | CLASSES_NUM = max(gt) 41 | print('The class numbers of the HSI data is:', CLASSES_NUM) 42 | 43 | print('-----Importing Setting Parameters-----') 44 | ITER = int(input("Enter num of iterations ")) 45 | PATCH_LENGTH = 3 46 | # number of training samples per class 47 | #lr, num_epochs, batch_size = 0.0001, 200, 32 48 | lr, num_epochs, batch_size = 0.0005, 200, 16 49 | loss = torch.nn.CrossEntropyLoss() 50 | 51 | img_rows = 2*PATCH_LENGTH+1 52 | img_cols = 2*PATCH_LENGTH+1 53 | img_channels = data_hsi.shape[2] 54 | INPUT_DIMENSION = data_hsi.shape[2] 55 | ALL_SIZE = data_hsi.shape[0] * data_hsi.shape[1] 56 | VAL_SIZE = int(TRAIN_SIZE) 57 | TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE 58 | 59 | 60 | KAPPA = [] 61 | OA = [] 62 | AA = [] 63 | TRAINING_TIME = [] 64 | TESTING_TIME = [] 65 | ELEMENT_ACC = np.zeros((ITER, CLASSES_NUM)) 66 | 67 | data = preprocessing.scale(data) 68 | data_ = data.reshape(data_hsi.shape[0], data_hsi.shape[1], data_hsi.shape[2]) 69 | whole_data = data_ 70 | padded_data = np.lib.pad(whole_data, ((PATCH_LENGTH, PATCH_LENGTH), (PATCH_LENGTH, PATCH_LENGTH), (0, 0)), 71 | 72 | 'constant', constant_values=0) 73 | net = network.SSRN_network(BAND, CLASSES_NUM).to(device) 74 | summary(net,input_size=(1,img_rows,img_cols,BAND)) 75 | for index_iter in range(ITER): 76 | print(f"ITER : {index_iter+1}") 77 | 78 | optimizer = optim.Adam(net.parameters(), lr=lr) # , weight_decay=0.0001) 79 | time_1 = int(time.time()) 80 | np.random.seed(seeds[index_iter]) 81 | train_indices, test_indices = sampling(VALIDATION_SPLIT, gt) 82 | _, total_indices = sampling(1, gt) 83 | 84 | TRAIN_SIZE = len(train_indices) 85 | print('Train size: ', TRAIN_SIZE) 86 | TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE 87 | print('Test size: ', TEST_SIZE) 88 | VAL_SIZE = int(TRAIN_SIZE) 89 | print('Validation size: ', VAL_SIZE) 90 | 91 | print('-----Selecting Small Pieces from the Original Cube Data-----') 92 | 93 | train_iter, valida_iter, test_iter, all_iter = generate_iter(TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE, total_indices, VAL_SIZE, 94 | whole_data, PATCH_LENGTH, padded_data, INPUT_DIMENSION, batch_size, gt) 95 | 96 | tic1 = time.perf_counter() # time.clock() was deprecated and removed in Python 3.8 97 | train.train(net, train_iter, valida_iter, loss, optimizer, device, epochs=num_epochs) 98 | toc1 = time.perf_counter() 99 | 100 | pred_test_fdssc = [] 101 | tic2 = time.perf_counter() 102 | with torch.no_grad(): 103 | for X, y in test_iter: 104 | X = X.to(device) 105 | net.eval() 106 | y_hat = net(X) 107 | # print(net(X)) 108 | pred_test_fdssc.extend(np.array(y_hat.cpu().argmax(axis=1))) # reuse y_hat instead of a second forward pass 109 | toc2 = time.perf_counter() 110 | collections.Counter(pred_test_fdssc) 111 | gt_test = gt[test_indices] - 1 112 | 113 | 114 | overall_acc_fdssc = metrics.accuracy_score(pred_test_fdssc, gt_test[:-VAL_SIZE]) 115 | confusion_matrix_fdssc = metrics.confusion_matrix(pred_test_fdssc, gt_test[:-VAL_SIZE]) 116 | print(confusion_matrix_fdssc) 117 | each_acc_fdssc, average_acc_fdssc = aa_and_each_accuracy(confusion_matrix_fdssc) 118 | kappa = metrics.cohen_kappa_score(pred_test_fdssc, gt_test[:-VAL_SIZE]) 119 | 120 | torch.save(net.state_dict(), "./net/" + str(round(overall_acc_fdssc, 3)) + '.pt') 121 | KAPPA.append(kappa) 122 | OA.append(overall_acc_fdssc) 123 | AA.append(average_acc_fdssc) 124 | TRAINING_TIME.append(toc1 - tic1) 125 | TESTING_TIME.append(toc2 - tic2) 126 | ELEMENT_ACC[index_iter, :] = each_acc_fdssc 127 | 128 | print("--------" + net.name + " Training Finished-----------") 129 | record.record_output(OA, AA, KAPPA, ELEMENT_ACC, TRAINING_TIME, TESTING_TIME,
130 | 'records/' + method + '_' + Dataset + '_' +str(BAND)+ '_' + str(VALIDATION_SPLIT) + '.txt') 131 | location = 'records/' + method + '_' + Dataset + '_' +str(BAND)+ '_' + str(VALIDATION_SPLIT) + '.txt' 132 | 133 | 134 | 135 | generate_png(all_iter, net, gt_hsi, Dataset, device, total_indices) 136 | print("location=\"",end="") 137 | print("./records/"+ method + '_' + Dataset + '_' +str(BAND)+ '_' + str(VALIDATION_SPLIT) + '.txt',end="") 138 | print("\"") 139 | -------------------------------------------------------------------------------- /global_module/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import sys 5 | sys.path.append('../global_module/') 6 | import d2lzh_pytorch as d2l 7 | 8 | 9 | def evaluate_accuracy(data_iter, net, loss, device): 10 | acc_sum, n, test_l_sum, test_num = 0.0, 0, 0, 0 11 | with torch.no_grad(): 12 | for X, y in data_iter: 13 | # accumulators are initialized once above (they used to be re-set every batch, so only the last batch's loss was returned) 14 | X = X.to(device) 15 | y = y.to(device) 16 | net.eval() # evaluation mode: disables dropout 17 | y_hat = net(X) 18 | l = loss(y_hat, y.long()) 19 | acc_sum += (y_hat.argmax(dim=1) == y.to(device)).float().sum().cpu().item() 20 | test_l_sum += l 21 | test_num += 1 22 | net.train() # switch back to training mode 23 | n += y.shape[0] 24 | return [acc_sum / n, test_l_sum] # / test_num] 25 | 26 | def train(net, train_iter, valida_iter, loss, optimizer, device, epochs=30, early_stopping=True, 27 | early_num=20): 28 | loss_list = [100] 29 | early_epoch = 0 30 | 31 | net = net.to(device) 32 | print("training on ", device) 33 | start = time.time() 34 | train_loss_list = [] 35 | valida_loss_list = [] 36 | train_acc_list = [] 37 | valida_acc_list = [] 38 | for epoch in range(epochs): 39 | train_acc_sum, n, batch_count, train_l_sum = 0.0, 0, 0, 0 40 | time_epoch = time.time() 41 | lr_adjust = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, eta_min=0.0, last_epoch=-1) 42 | #lr_adjust = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=2000, step_size_down=None, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1) 43 | 44 | 45 | for X, y in train_iter: 46 | # per-epoch accumulators are initialized above (they used to be re-set every batch) 47 | X = X.to(device) 48 | y = y.to(device) 49 | y_hat = net(X) 50 | # print('y_hat', y_hat) 51 | # print('y', y) 52 | l = loss(y_hat, y.long()) 53 | 54 | optimizer.zero_grad() 55 | l.backward() 56 | optimizer.step() 57 | 58 | train_l_sum += l.cpu().item() 59 | train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item() 60 | n += y.shape[0] 61 | batch_count += 1 62 | lr_adjust.step() 63 | valida_acc, valida_loss = evaluate_accuracy(valida_iter, net, loss, device) 64 | loss_list.append(valida_loss) 65 | 66 | 67 | train_loss_list.append(train_l_sum) # / batch_count) 68 | train_acc_list.append(train_acc_sum / n) 69 | valida_loss_list.append(valida_loss) 70 | valida_acc_list.append(valida_acc) 71 | 72 | print('epoch %d, train loss %.6f, train acc %.3f, valida loss %.6f, valida acc %.3f, time %.1f sec' 73 | % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, valida_loss, valida_acc, time.time() - time_epoch)) 74 | 75 | PATH = "./net_DBA.pt" 76 | # if loss_list[-1] <= 0.01 and valida_acc >= 0.95: 77 | # torch.save(net.state_dict(), PATH) 78 | # break
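79 | # Early stopping: loss_list tracks the validation loss; on the first epoch it worsens, the weights are checkpointed to PATH, and after early_num consecutive worsening epochs the checkpoint is restored and training stops.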
80 | if early_stopping and loss_list[-2] < loss_list[-1]: # < 0.05) and (loss_list[-1] <= 0.05): 81 | if early_epoch == 0: # and valida_acc > 0.9: 82 | torch.save(net.state_dict(), PATH) 83 | early_epoch += 1 84 | loss_list[-1] = loss_list[-2] 85 | if early_epoch == early_num: 86 | net.load_state_dict(torch.load(PATH)) 87 | break 88 | else: 89 | early_epoch = 0 90 | 91 | d2l.set_figsize() 92 | d2l.plt.figure(figsize=(8, 8.5)) 93 | train_accuracy = d2l.plt.subplot(221) 94 | train_accuracy.set_title('train_accuracy') 95 | d2l.plt.plot(np.linspace(1, epoch, len(train_acc_list)), train_acc_list, color='green') 96 | d2l.plt.xlabel('epoch') 97 | d2l.plt.ylabel('train_accuracy') 98 | # train_acc_plot = np.array(train_acc_plot) 99 | # for x, y in zip(num_epochs, train_acc_plot): 100 | # d2l.plt.text(x, y + 0.05, '%.0f' % y, ha='center', va='bottom', fontsize=11) 101 | 102 | test_accuracy = d2l.plt.subplot(222) 103 | test_accuracy.set_title('valida_accuracy') 104 | d2l.plt.plot(np.linspace(1, epoch, len(valida_acc_list)), valida_acc_list, color='deepskyblue') 105 | d2l.plt.xlabel('epoch') 106 | d2l.plt.ylabel('test_accuracy') 107 | # test_acc_plot = np.array(test_acc_plot) 108 | # for x, y in zip(num_epochs, test_acc_plot): 109 | # d2l.plt.text(x, y + 0.05, '%.0f' % y, ha='center', va='bottom', fontsize=11) 110 | 111 | loss_sum = d2l.plt.subplot(223) 112 | loss_sum.set_title('train_loss') 113 | d2l.plt.plot(np.linspace(1, epoch, len(train_loss_list)), train_loss_list, color='red') # plot the training loss (was mistakenly plotting valida_acc_list) 114 | d2l.plt.xlabel('epoch') 115 | d2l.plt.ylabel('train loss') 116 | # ls_plot = np.array(ls_plot) 117 | 118 | test_loss = d2l.plt.subplot(224) 119 | test_loss.set_title('valida_loss') 120 | d2l.plt.plot(np.linspace(1, epoch, len(valida_loss_list)), valida_loss_list, color='gold') 121 | d2l.plt.xlabel('epoch') 122 | d2l.plt.ylabel('valida loss') 123 | # ls_plot = np.array(ls_plot) 124 | 125 | d2l.plt.show() 126 | print('epoch %d, loss %.4f, train acc %.3f, time %.1f sec' 127 | % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, time.time() - start)) 128 | -------------------------------------------------------------------------------- /global_module/generate_pic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from operator import truediv 4 | import scipy.io as sio 5 | import torch 6 | import math 7 | from Utils import extract_samll_cubic 8 | import torch.utils.data as Data 9 | 10 | 11 | import ast 12 | import os 13 | mydir = "/content/BS-Nets-Implementation-Pytorch/SSRN/" 14 | for file in os.listdir(mydir): 15 | if file.endswith(".txt"): 16 | pass 17 | #print(file) 18 | print("All Bands") 19 | print("Def") 20 | 21 | allband=False 22 | filename = input("Enter filename ") 23 | if filename=="Def": 24 | allband=True 25 | split=0.95 26 | else: 27 | split = float(input("Enter VALIDATION_SPLIT ")) 28 | if split > 1.00 or split <= 0.05: 29 | print("Split was wrong, defaulting to 0.95") 30 | split=0.95 31 | if filename=="All Bands": 32 | allband=True 33 | else: 34 | nbands = int(input("Select Number of bands ")) 35 | 36 | with open(filename, 'r') as f: 37 | BANDLIST = ast.literal_eval(f.read()) 38 | 39 | if nbands>len(BANDLIST) or nbands<5: 40 | print("You entered more bands than the band list provides, or fewer than 5.") 41 | exit() 42 |
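43 | # Example: with SSRN/bsconv.txt as the band list and nbands=5, the transforms below keep channels [80, 97, 43, 71, 72] (the first five entries of the list).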
44 | def pavia_transform(ARRAY,BANDLIST): 45 | if BANDLIST is not None: 46 | BANDLIST=BANDLIST[:15] 47 | BANDLIST = BANDLIST[:nbands] 48 | assert ARRAY.shape[2] ==103 49 | tensor_list = [] 50 | for i in range(0,len(BANDLIST)): 51 | tensor_list.append(ARRAY[:,:,BANDLIST[i]]) 52 | return np.stack(tensor_list,axis=2) 53 | 54 | def salinas_transform(ARRAY,BANDLIST): 55 | if BANDLIST is not None: 56 | BANDLIST=BANDLIST[:20] 57 | BANDLIST = BANDLIST[:nbands] 58 | assert ARRAY.shape[2] ==204 59 | tensor_list = [] 60 | for i in range(0,len(BANDLIST)): 61 | tensor_list.append(ARRAY[:,:,BANDLIST[i]]) 62 | return np.stack(tensor_list,axis=2) 63 | 64 | 65 | def indian_transform(ARRAY,BANDLIST): 66 | if BANDLIST is not None: 67 | BANDLIST=BANDLIST[:25] 68 | BANDLIST = BANDLIST[:nbands] 69 | assert ARRAY.shape[2] ==200 70 | tensor_list = [] 71 | for i in range(0,len(BANDLIST)): 72 | tensor_list.append(ARRAY[:,:,BANDLIST[i]]) 73 | return np.stack(tensor_list,axis=2) 74 | 75 | def load_dataset(Dataset): 76 | if Dataset == 'IN': 77 | mat_data = sio.loadmat('../datasets/Indian_pines_corrected.mat') 78 | mat_gt = sio.loadmat('../datasets/Indian_pines_gt.mat') 79 | data_hsi = mat_data['indian_pines_corrected'] 80 | gt_hsi = mat_gt['indian_pines_gt'] 81 | if not allband: 82 | data_hsi = indian_transform(data_hsi,BANDLIST) 83 | TOTAL_SIZE = 10249 84 | VALIDATION_SPLIT = split 85 | TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT) 86 | 87 | if Dataset == 'UP': 88 | uPavia = sio.loadmat('../datasets/PaviaU.mat') 89 | gt_uPavia = sio.loadmat('../datasets/PaviaU_gt.mat') 90 | data_hsi = uPavia['paviaU'] 91 | gt_hsi = gt_uPavia['paviaU_gt'] 92 | if not allband: data_hsi = pavia_transform(data_hsi,BANDLIST) # guard as in the IN branch; BANDLIST is undefined when all bands are kept 93 | TOTAL_SIZE = 42776 94 | VALIDATION_SPLIT = split 95 | TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT) 96 | 97 | if Dataset == 'SV': 98 | SV = sio.loadmat('../datasets/Salinas_corrected.mat') 99 | gt_SV = sio.loadmat('../datasets/Salinas_gt.mat') 100 | data_hsi = SV['salinas_corrected'] 101 | gt_hsi = gt_SV['salinas_gt'] 102 | if not allband: data_hsi = salinas_transform(data_hsi,BANDLIST) # guard as in the IN branch 103 | TOTAL_SIZE = 54129 104 | VALIDATION_SPLIT = split 105 | TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT) 106 | 107 | return data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT,filename 108 | 109 | def save_cmap(img, cmap, fname): 110 | sizes = np.shape(img) 111 | height = float(sizes[0]) 112 | width = float(sizes[1]) 113 | 114 | fig = plt.figure() 115 | fig.set_size_inches(width / height, 1, forward=False) 116 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 117 | ax.set_axis_off() 118 | fig.add_axes(ax) 119 | 120 | ax.imshow(img, cmap=cmap) 121 | plt.savefig(fname, dpi=height) 122 | plt.close() 123 | 124 | def sampling(proportion, ground_truth): 125 | train = {} 126 | test = {} 127 | labels_loc = {} 128 | m = max(ground_truth) 129 | for i in range(m): 130 | indexes = [j for j, x in enumerate(ground_truth.ravel().tolist()) if x == i + 1] 131 | np.random.shuffle(indexes) 132 | labels_loc[i] = indexes 133 | if proportion != 1: 134 | nb_val = max(int((1 - proportion) * len(indexes)), 3) 135 | else: 136 | nb_val = 0 137 | # print(i, nb_val, indexes[:nb_val]) 138 | # train[i] = indexes[:-nb_val] 139 | # test[i] = indexes[-nb_val:] 140 | train[i] = indexes[:nb_val] 141 | test[i] = indexes[nb_val:] 142 | train_indexes = [] 143 | test_indexes = [] 144 | for i in range(m): 145 | train_indexes += train[i] 146 | test_indexes += test[i] 147 | np.random.shuffle(train_indexes) 148 | np.random.shuffle(test_indexes) 149 | return train_indexes, test_indexes 150 | 151 | def aa_and_each_accuracy(confusion_matrix): 152 | list_diag = np.diag(confusion_matrix) 153 | list_raw_sum = np.sum(confusion_matrix, axis=1) 154 | each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum)) 155 | average_acc = np.mean(each_acc) 156 | return each_acc, average_acc 157 | 158 | 159 | def classification_map(map, ground_truth, dpi, save_path): 160 | fig = plt.figure(frameon=False) 161 |
fig.set_size_inches(ground_truth.shape[1] * 2.0 / dpi, ground_truth.shape[0] * 2.0 / dpi) 162 | 163 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 164 | ax.set_axis_off() 165 | ax.xaxis.set_visible(False) 166 | ax.yaxis.set_visible(False) 167 | fig.add_axes(ax) 168 | 169 | ax.imshow(map) 170 | fig.savefig(save_path, dpi=dpi) 171 | 172 | return 0 173 | 174 | 175 | def list_to_colormap(x_list): 176 | y = np.zeros((x_list.shape[0], 3)) 177 | for index, item in enumerate(x_list): 178 | if item == 0: 179 | y[index] = np.array([255, 0, 0]) / 255. 180 | if item == 1: 181 | y[index] = np.array([0, 255, 0]) / 255. 182 | if item == 2: 183 | y[index] = np.array([0, 0, 255]) / 255. 184 | if item == 3: 185 | y[index] = np.array([255, 255, 0]) / 255. 186 | if item == 4: 187 | y[index] = np.array([0, 255, 255]) / 255. 188 | if item == 5: 189 | y[index] = np.array([255, 0, 255]) / 255. 190 | if item == 6: 191 | y[index] = np.array([192, 192, 192]) / 255. 192 | if item == 7: 193 | y[index] = np.array([128, 128, 128]) / 255. 194 | if item == 8: 195 | y[index] = np.array([128, 0, 0]) / 255. 196 | if item == 9: 197 | y[index] = np.array([128, 128, 0]) / 255. 198 | if item == 10: 199 | y[index] = np.array([0, 128, 0]) / 255. 200 | if item == 11: 201 | y[index] = np.array([128, 0, 128]) / 255. 202 | if item == 12: 203 | y[index] = np.array([0, 128, 128]) / 255. 204 | if item == 13: 205 | y[index] = np.array([0, 0, 128]) / 255. 206 | if item == 14: 207 | y[index] = np.array([255, 165, 0]) / 255. 208 | if item == 15: 209 | y[index] = np.array([255, 215, 0]) / 255. 210 | if item == 16: 211 | y[index] = np.array([0, 0, 0]) / 255. 212 | if item == 17: 213 | y[index] = np.array([215, 255, 0]) / 255. 214 | if item == 18: 215 | y[index] = np.array([0, 255, 215]) / 255. 216 | if item == -1: 217 | y[index] = np.array([0, 0, 0]) / 255. 
218 | return y 219 | 220 | 221 | def generate_iter(TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE, total_indices, VAL_SIZE, 222 | whole_data, PATCH_LENGTH, padded_data, INPUT_DIMENSION, batch_size, gt): 223 | 224 | gt_all = gt[total_indices] - 1 225 | y_train = gt[train_indices] - 1 226 | y_test = gt[test_indices] - 1 227 | 228 | all_data = extract_samll_cubic.select_small_cubic(TOTAL_SIZE, total_indices, whole_data, 229 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 230 | 231 | train_data = extract_samll_cubic.select_small_cubic(TRAIN_SIZE, train_indices, whole_data, 232 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 233 | test_data = extract_samll_cubic.select_small_cubic(TEST_SIZE, test_indices, whole_data, 234 | PATCH_LENGTH, padded_data, INPUT_DIMENSION) 235 | x_train = train_data.reshape(train_data.shape[0], train_data.shape[1], train_data.shape[2], INPUT_DIMENSION) 236 | x_test_all = test_data.reshape(test_data.shape[0], test_data.shape[1], test_data.shape[2], INPUT_DIMENSION) 237 | 238 | x_val = x_test_all[-VAL_SIZE:] 239 | y_val = y_test[-VAL_SIZE:] 240 | 241 | x_test = x_test_all[:-VAL_SIZE] 242 | y_test = y_test[:-VAL_SIZE] 243 | # print('y_train', np.unique(y_train)) 244 | # print('y_val', np.unique(y_val)) 245 | # print('y_test', np.unique(y_test)) 246 | # print(y_val) 247 | # print(y_test) 248 | 249 | # K.clear_session() # clear session before next loop 250 | 251 | # print(y1_train) 252 | #y1_train = to_categorical(y1_train) # to one-hot labels 253 | x1_tensor_train = torch.from_numpy(x_train).type(torch.FloatTensor).unsqueeze(1) 254 | y1_tensor_train = torch.from_numpy(y_train).type(torch.FloatTensor) 255 | torch_dataset_train = Data.TensorDataset(x1_tensor_train, y1_tensor_train) 256 | 257 | x1_tensor_valida = torch.from_numpy(x_val).type(torch.FloatTensor).unsqueeze(1) 258 | y1_tensor_valida = torch.from_numpy(y_val).type(torch.FloatTensor) 259 | torch_dataset_valida = Data.TensorDataset(x1_tensor_valida, y1_tensor_valida) 260 | 261 | x1_tensor_test = torch.from_numpy(x_test).type(torch.FloatTensor).unsqueeze(1) 262 | y1_tensor_test = torch.from_numpy(y_test).type(torch.FloatTensor) 263 | torch_dataset_test = Data.TensorDataset(x1_tensor_test,y1_tensor_test) 264 | 265 | all_data = all_data.reshape(all_data.shape[0], all_data.shape[1], all_data.shape[2], INPUT_DIMENSION) # assign the result; reshape is not in-place 266 | all_tensor_data = torch.from_numpy(all_data).type(torch.FloatTensor).unsqueeze(1) 267 | all_tensor_data_label = torch.from_numpy(gt_all).type(torch.FloatTensor) 268 | torch_dataset_all = Data.TensorDataset(all_tensor_data, all_tensor_data_label) 269 | 270 | 271 | train_iter = Data.DataLoader( 272 | dataset=torch_dataset_train, # torch TensorDataset format 273 | batch_size=batch_size, # mini batch size 274 | shuffle=True, # shuffle the data (recommended for training) 275 | num_workers=0, # number of worker processes for loading data (0 = main process) 276 | ) 277 | valiada_iter = Data.DataLoader( 278 | dataset=torch_dataset_valida, # torch TensorDataset format 279 | batch_size=batch_size, # mini batch size 280 | shuffle=True, # shuffle the data (recommended for training) 281 | num_workers=0, # number of worker processes for loading data (0 = main process) 282 | ) 283 | test_iter = Data.DataLoader( 284 | dataset=torch_dataset_test, # torch TensorDataset format 285 | batch_size=batch_size, # mini batch size 286 | shuffle=False, # keep evaluation order deterministic 287 | num_workers=0, # number of worker processes for loading data (0 = main process) 288 | ) 289 | all_iter = Data.DataLoader( 290 | dataset=torch_dataset_all, # torch TensorDataset format 291 | batch_size=batch_size, # mini batch size 292 | shuffle=False, # keep evaluation order deterministic 293 | num_workers=0, # number of worker processes for loading data (0 = main process) 294 | ) 295 | return train_iter, valiada_iter, test_iter, all_iter #, y_test
296 | 297 | def generate_png(all_iter, net, gt_hsi, Dataset, device, total_indices): 298 | pred_test = [] 299 | for X, y in all_iter: 300 | X = X.to(device) 301 | net.eval() # evaluation mode: disables dropout 302 | # print(net(X)) 303 | pred_test.extend(np.array(net(X).detach().cpu().numpy().argmax(axis=1))) 304 | 305 | gt = gt_hsi.flatten() 306 | x_label = np.zeros(gt.shape) 307 | for i in range(len(gt)): 308 | if gt[i] == 0: 309 | gt[i] = 17 310 | # x[i] = 16 311 | x_label[i] = 16 312 | # else: 313 | # x_label[i] = pred_test[label_list] 314 | # label_list += 1 315 | gt = gt[:] - 1 316 | x_label[total_indices] = pred_test 317 | x = np.ravel(x_label) 318 | 319 | # print('-------Save the result in mat format--------') 320 | # x_re = np.reshape(x, (gt_hsi.shape[0], gt_hsi.shape[1])) 321 | # sio.savemat('mat/' + Dataset + '_' + '.mat', {Dataset: x_re}) 322 | 323 | y_list = list_to_colormap(x) 324 | y_gt = list_to_colormap(gt) 325 | 326 | y_re = np.reshape(y_list, (gt_hsi.shape[0], gt_hsi.shape[1], 3)) 327 | gt_re = np.reshape(y_gt, (gt_hsi.shape[0], gt_hsi.shape[1], 3)) 328 | 329 | path = '../' + net.name 330 | 331 | classification_map(y_re, gt_hsi, 300, 332 | path + '/classification_maps/' + '_' + Dataset + '.pdf') 333 | classification_map(gt_re, gt_hsi, 300, 334 | path + '/classification_maps/' + Dataset + '_gt.pdf') 335 | print('------Get classification maps successful-------') 336 | -------------------------------------------------------------------------------- /global_module/d2lzh_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import os 4 | import random 5 | import sys 6 | import tarfile 7 | import time 8 | import zipfile 9 | from tqdm import tqdm 10 | 11 | from IPython import display 12 | 13 | from matplotlib import pyplot as plt 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | import torchvision 18 | import torchvision.transforms as transforms 19 | import torchtext 20 | import torchtext.vocab as Vocab 21 | import numpy as np 22 | 23 | 24 | VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 25 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 26 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 27 | 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor'] 28 | 29 | 30 | VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 31 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 32 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 33 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 34 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 35 | [0, 64, 128]] 36 | 37 | 38 | 39 | # ###################### 3.2 ############################ 40 | def set_figsize(figsize=(3.5, 2.5)): 41 | use_svg_display() 42 | # set the figure size 43 | plt.rcParams['figure.figsize'] = figsize 44 | 45 | def use_svg_display(): 46 | """Use svg format to display plot in jupyter""" 47 | display.set_matplotlib_formats('svg') 48 | 49 | def data_iter(batch_size, features, labels): 50 | num_examples = len(features) 51 | indices = list(range(num_examples)) 52 | random.shuffle(indices) # samples are read in random order 53 | for i in range(0, num_examples, batch_size): 54 | j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) # the last batch may be smaller than batch_size 55 | yield features.index_select(0, j), labels.index_select(0, j) 56 | 57 | def linreg(X, w, b): 58 | return torch.mm(X, w) + b 59 | 60 | def squared_loss(y_hat, y): 61 | # note: this returns a vector; also, PyTorch's MSELoss does not divide by 2 62 | return ((y_hat - y.view(y_hat.size())) ** 2) / 2
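63 | # Quick check: squared_loss(torch.tensor([2.0]), torch.tensor([3.0])) returns tensor([0.5000]), i.e. (2 - 3) ** 2 / 2.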
64 | def sgd(params, lr, batch_size): 65 | # To stay consistent with the original book, the update divides by batch_size, but this is 66 | # strictly unnecessary: PyTorch's loss functions already average over the batch dimension by default. 67 | for param in params: 68 | param.data -= lr * param.grad / batch_size # note: param.data is used so the update is not tracked by autograd 69 | 70 | 71 | 72 | # ######################3##### 3.5 ############################# 73 | def get_fashion_mnist_labels(labels): 74 | text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 75 | 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] 76 | return [text_labels[int(i)] for i in labels] 77 | 78 | def show_fashion_mnist(images, labels): 79 | use_svg_display() 80 | # the underscore denotes a variable we ignore (do not use) 81 | _, figs = plt.subplots(1, len(images), figsize=(12, 12)) 82 | for f, img, lbl in zip(figs, images, labels): 83 | f.imshow(img.view((28, 28)).numpy()) 84 | f.set_title(lbl) 85 | f.axes.get_xaxis().set_visible(False) 86 | f.axes.get_yaxis().set_visible(False) 87 | # plt.show() 88 | 89 | # revised in section 5.6 90 | # def load_data_fashion_mnist(batch_size, root='~/Datasets/FashionMNIST'): 91 | # """Download the fashion mnist dataset and then load into memory.""" 92 | # transform = transforms.ToTensor() 93 | # mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform) 94 | # mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform) 95 | # if sys.platform.startswith('win'): 96 | # num_workers = 0 # 0 means no extra worker processes for data loading 97 | # else: 98 | # num_workers = 4 99 | # train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) 100 | # test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers) 101 | 102 | # return train_iter, test_iter 103 | 104 | 105 | 106 | 107 | # ########################### 3.6 ############################### 108 | # (revised in section 3.13) 109 | # def evaluate_accuracy(data_iter, net): 110 | # acc_sum, n = 0.0, 0 111 | # for X, y in data_iter: 112 | # acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 113 | # n += y.shape[0] 114 | # return acc_sum / n 115 | 116 | 117 | def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, 118 | params=None, lr=None, optimizer=None): 119 | for epoch in range(num_epochs): 120 | train_l_sum, train_acc_sum, n = 0.0, 0.0, 0 121 | for X, y in train_iter: 122 | y_hat = net(X) 123 | l = loss(y_hat, y).sum() 124 | 125 | # zero the gradients 126 | if optimizer is not None: 127 | optimizer.zero_grad() 128 | elif params is not None and params[0].grad is not None: 129 | for param in params: 130 | param.grad.data.zero_() 131 | 132 | l.backward() 133 | if optimizer is None: 134 | sgd(params, lr, batch_size) 135 | else: 136 | optimizer.step() # used in the "concise implementation of softmax regression" section 137 | 138 | 139 | train_l_sum += l.item() 140 | train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item() 141 | n += y.shape[0] 142 | test_acc = evaluate_accuracy(test_iter, net) 143 | print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' 144 | % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc)) 145 | 146 | 147 | 148 | 149 | # ########################### 3.7 #####################################3 150 | class FlattenLayer(torch.nn.Module): 151 | def __init__(self): 152 | super(FlattenLayer, self).__init__() 153 | def forward(self, x): # x shape: (batch, *, *, ...) 154 | return x.view(x.shape[0], -1) 155 | 156 | 157 | # ########################### 3.11 ############################### 158 | def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, 159 | legend=None, figsize=(3.5, 2.5)): 160 | set_figsize(figsize) 161 | plt.xlabel(x_label) 162 | plt.ylabel(y_label) 163 | plt.semilogy(x_vals, y_vals) 164 | if x2_vals and y2_vals: 165 | plt.semilogy(x2_vals, y2_vals, linestyle=':') 166 | plt.legend(legend) 167 | # plt.show() 168 | 169 | 170 | 171 | 172 | # ############################# 3.13 ############################## 173 | # revised in section 5.5 174 | # def evaluate_accuracy(data_iter, net): 175 | # acc_sum, n = 0.0, 0 176 | # for X, y in data_iter: 177 | # if isinstance(net, torch.nn.Module): 178 | # net.eval() # evaluation mode: disables dropout 179 | # acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 180 | # net.train() # switch back to training mode 181 | # else: # custom model 182 | # if('is_training' in net.__code__.co_varnames): # if the function has an is_training argument 183 | # # set is_training to False 184 | # acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 185 | # else: 186 | # acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 187 | # n += y.shape[0] 188 | # return acc_sum / n 189 | 190 | 191 | 192 | 193 | 194 | 195 | # ########################### 5.1 ######################### 196 | def corr2d(X, K): 197 | h, w = K.shape 198 | Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1)) 199 | for i in range(Y.shape[0]): 200 | for j in range(Y.shape[1]): 201 | Y[i, j] = (X[i: i + h, j: j + w] * K).sum() 202 | return Y
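203 | # Worked example (the one used in the d2l text): 204 | # corr2d(torch.tensor([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]), torch.tensor([[0., 1.], [2., 3.]])) 205 | # -> tensor([[19., 25.], [37., 43.]])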
206 | # ############################ 5.5 ######################### 207 | def evaluate_accuracy(data_iter, net, 208 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')): 209 | acc_sum, n = 0.0, 0 210 | with torch.no_grad(): 211 | for X, y in data_iter: 212 | if isinstance(net, torch.nn.Module): 213 | net.eval() # evaluation mode: disables dropout 214 | acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item() 215 | net.train() # switch back to training mode 216 | else: # custom model, not used after section 3.13, GPU not considered 217 | if('is_training' in net.__code__.co_varnames): # if the function has an is_training argument 218 | # set is_training to False 219 | acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 220 | else: 221 | acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 222 | n += y.shape[0] 223 | return acc_sum / n 224 | 225 | def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs): 226 | net = net.to(device) 227 | print("training on ", device) 228 | loss = torch.nn.CrossEntropyLoss() 229 | batch_count = 0 230 | for epoch in range(num_epochs): 231 | train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time() 232 | for X, y in train_iter: 233 | X = X.to(device) 234 | y = y.to(device) 235 | y_hat = net(X) 236 | l = loss(y_hat, y) 237 | optimizer.zero_grad() 238 | l.backward() 239 | optimizer.step() 240 | train_l_sum += l.cpu().item() 241 | train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item() 242 | n += y.shape[0] 243 | batch_count += 1 244 | test_acc = evaluate_accuracy(test_iter, net) 245 | print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' 246 | % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start)) 247 | 248 | 249 | 250 | # ########################## 5.6 #########################3 251 | def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'): 252 | """Download the fashion mnist dataset and then load into memory."""
253 | trans = [] 254 | if resize: 255 | trans.append(torchvision.transforms.Resize(size=resize)) 256 | trans.append(torchvision.transforms.ToTensor()) 257 | 258 | transform = torchvision.transforms.Compose(trans) 259 | mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform) 260 | mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform) 261 | if sys.platform.startswith('win'): 262 | num_workers = 0 # 0 means no extra worker processes for data loading 263 | else: 264 | num_workers = 4 265 | train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) 266 | test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers) 267 | 268 | return train_iter, test_iter 269 | 270 | 271 | 272 | ############################# 5.8 ############################## 273 | class GlobalAvgPool2d(nn.Module): 274 | # global average pooling: implemented by setting the pooling window to the input's height and width 275 | def __init__(self): 276 | super(GlobalAvgPool2d, self).__init__() 277 | def forward(self, x): 278 | return F.avg_pool2d(x, kernel_size=x.size()[2:]) 279 | 280 | 281 | 282 | # ########################### 5.11 ################################ 283 | class Residual(nn.Module): 284 | def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1): 285 | super(Residual, self).__init__() 286 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) 287 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 288 | if use_1x1conv: 289 | self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride) 290 | else: 291 | self.conv3 = None 292 | self.bn1 = nn.BatchNorm2d(out_channels) 293 | self.bn2 = nn.BatchNorm2d(out_channels) 294 | 295 | def forward(self, X): 296 | Y = F.relu(self.bn1(self.conv1(X))) 297 | Y = self.bn2(self.conv2(Y)) 298 | if self.conv3: 299 | X = self.conv3(X) 300 | return F.relu(Y + X) 301 | 302 | def resnet_block(in_channels, out_channels, num_residuals, first_block=False): 303 | if first_block: 304 | assert in_channels == out_channels # the first block keeps the channel count equal to the input's 305 | blk = [] 306 | for i in range(num_residuals): 307 | if i == 0 and not first_block: 308 | blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2)) 309 | else: 310 | blk.append(Residual(out_channels, out_channels)) 311 | return nn.Sequential(*blk) 312 | 313 | def resnet18(output=10, in_channels=3): 314 | net = nn.Sequential( 315 | nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3), 316 | nn.BatchNorm2d(64), 317 | nn.ReLU(), 318 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 319 | net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True)) 320 | net.add_module("resnet_block2", resnet_block(64, 128, 2)) 321 | net.add_module("resnet_block3", resnet_block(128, 256, 2)) 322 | net.add_module("resnet_block4", resnet_block(256, 512, 2)) 323 | net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d output: (Batch, 512, 1, 1) 324 | net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(512, output))) 325 | return net 326 | 327 | 328 | 329 | # ############################## 6.3 ##################################3 330 | def load_data_jay_lyrics(): 331 | """Load the Jay Chou lyrics dataset.""" 332 | with zipfile.ZipFile('../../data/jaychou_lyrics.txt.zip') as zin: 333 | with zin.open('jaychou_lyrics.txt') as f: 334 | corpus_chars = f.read().decode('utf-8') 335 | corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ') 336 | corpus_chars = corpus_chars[0:10000] 337 | idx_to_char = list(set(corpus_chars)) 338 | char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)]) 339 | vocab_size = len(char_to_idx) 340 | corpus_indices = [char_to_idx[char] for char in corpus_chars] 341 | return corpus_indices, char_to_idx, idx_to_char, vocab_size
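342 | # Both samplers below yield (X, Y) with Y one step ahead of X; e.g. data_iter_consecutive(list(range(10)), batch_size=2, num_steps=3) yields X = tensor([[0., 1., 2.], [5., 6., 7.]]) and Y = tensor([[1., 2., 3.], [6., 7., 8.]]).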
343 | def data_iter_random(corpus_indices, batch_size, num_steps, device=None): 344 | # subtract 1 because the label sequence is the input sequence shifted one position ahead 345 | num_examples = (len(corpus_indices) - 1) // num_steps 346 | epoch_size = num_examples // batch_size 347 | example_indices = list(range(num_examples)) 348 | random.shuffle(example_indices) 349 | 350 | # return the length-num_steps sequence starting at pos 351 | def _data(pos): 352 | return corpus_indices[pos: pos + num_steps] 353 | if device is None: 354 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 355 | 356 | for i in range(epoch_size): 357 | # read batch_size random examples each time 358 | i = i * batch_size 359 | batch_indices = example_indices[i: i + batch_size] 360 | X = [_data(j * num_steps) for j in batch_indices] 361 | Y = [_data(j * num_steps + 1) for j in batch_indices] 362 | yield torch.tensor(X, dtype=torch.float32, device=device), torch.tensor(Y, dtype=torch.float32, device=device) 363 | 364 | def data_iter_consecutive(corpus_indices, batch_size, num_steps, device=None): 365 | if device is None: 366 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 367 | corpus_indices = torch.tensor(corpus_indices, dtype=torch.float32, device=device) 368 | data_len = len(corpus_indices) 369 | batch_len = data_len // batch_size 370 | indices = corpus_indices[0: batch_size*batch_len].view(batch_size, batch_len) 371 | epoch_size = (batch_len - 1) // num_steps 372 | for i in range(epoch_size): 373 | i = i * num_steps 374 | X = indices[:, i: i + num_steps] 375 | Y = indices[:, i + 1: i + num_steps + 1] 376 | yield X, Y 377 | 378 | 379 | 380 | 381 | 382 | # ###################################### 6.4 ###################################### 383 | def one_hot(x, n_class, dtype=torch.float32): 384 | # X shape: (batch), output shape: (batch, n_class) 385 | x = x.long() 386 | res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) 387 | res.scatter_(1, x.view(-1, 1), 1) 388 | return res 389 | 390 | def to_onehot(X, n_class): 391 | # X shape: (batch, seq_len), output: seq_len elements of (batch, n_class) 392 | return [one_hot(X[:, i], n_class) for i in range(X.shape[1])] 393 | 394 | def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state, 395 | num_hiddens, vocab_size, device, idx_to_char, char_to_idx): 396 | state = init_rnn_state(1, num_hiddens, device) 397 | output = [char_to_idx[prefix[0]]] 398 | for t in range(num_chars + len(prefix) - 1): 399 | # use the previous time step's output as the current input 400 | X = to_onehot(torch.tensor([[output[-1]]], device=device), vocab_size) 401 | # compute the output and update the hidden state 402 | (Y, state) = rnn(X, state, params) 403 | # the next input is the next prefix character, or else the current best prediction 404 | if t < len(prefix) - 1: 405 | output.append(char_to_idx[prefix[t + 1]]) 406 | else: 407 | output.append(int(Y[0].argmax(dim=1).item())) 408 | return ''.join([idx_to_char[i] for i in output]) 409 | 410 | def grad_clipping(params, theta, device): 411 | norm = torch.tensor([0.0], device=device) 412 | for param in params: 413 | norm += (param.grad.data ** 2).sum() 414 | norm = norm.sqrt().item() 415 | if norm > theta: 416 | for param in params: 417 | param.grad.data *= (theta / norm)
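418 | # grad_clipping rescales gradients so the global norm sqrt(sum_i ||g_i||^2) never exceeds theta: g <- g * (theta / norm) whenever norm > theta.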
init_rnn_state, num_hiddens, 420 | vocab_size, device, corpus_indices, idx_to_char, 421 | char_to_idx, is_random_iter, num_epochs, num_steps, 422 | lr, clipping_theta, batch_size, pred_period, 423 | pred_len, prefixes): 424 | if is_random_iter: 425 | data_iter_fn = data_iter_random 426 | else: 427 | data_iter_fn = data_iter_consecutive 428 | params = get_params() 429 | loss = nn.CrossEntropyLoss() 430 | 431 | for epoch in range(num_epochs): 432 | if not is_random_iter: # 如使用相邻采样,在epoch开始时初始化隐藏状态 433 | state = init_rnn_state(batch_size, num_hiddens, device) 434 | l_sum, n, start = 0.0, 0, time.time() 435 | data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device) 436 | for X, Y in data_iter: 437 | if is_random_iter: # 如使用随机采样,在每个小批量更新前初始化隐藏状态 438 | state = init_rnn_state(batch_size, num_hiddens, device) 439 | else: 440 | # 否则需要使用detach函数从计算图分离隐藏状态, 这是为了 441 | # 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大) 442 | for s in state: 443 | s.detach_() 444 | 445 | inputs = to_onehot(X, vocab_size) 446 | # outputs有num_steps个形状为(batch_size, vocab_size)的矩阵 447 | (outputs, state) = rnn(inputs, state, params) 448 | # 拼接之后形状为(num_steps * batch_size, vocab_size) 449 | outputs = torch.cat(outputs, dim=0) 450 | # Y的形状是(batch_size, num_steps),转置后再变成长度为 451 | # batch * num_steps 的向量,这样跟输出的行一一对应 452 | y = torch.transpose(Y, 0, 1).contiguous().view(-1) 453 | # 使用交叉熵损失计算平均分类误差 454 | l = loss(outputs, y.long()) 455 | 456 | # 梯度清0 457 | if params[0].grad is not None: 458 | for param in params: 459 | param.grad.data.zero_() 460 | l.backward() 461 | grad_clipping(params, clipping_theta, device) # 裁剪梯度 462 | sgd(params, lr, 1) # 因为误差已经取过均值,梯度不用再做平均 463 | l_sum += l.item() * y.shape[0] 464 | n += y.shape[0] 465 | 466 | if (epoch + 1) % pred_period == 0: 467 | print('epoch %d, perplexity %f, time %.2f sec' % ( 468 | epoch + 1, math.exp(l_sum / n), time.time() - start)) 469 | for prefix in prefixes: 470 | print(' -', predict_rnn(prefix, pred_len, rnn, params, init_rnn_state, 471 | num_hiddens, vocab_size, device, idx_to_char, char_to_idx)) 472 | 473 | 474 | 475 | 476 | # ################################### 6.5 ################################################ 477 | class RNNModel(nn.Module): 478 | def __init__(self, rnn_layer, vocab_size): 479 | super(RNNModel, self).__init__() 480 | self.rnn = rnn_layer 481 | self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 482 | self.vocab_size = vocab_size 483 | self.dense = nn.Linear(self.hidden_size, vocab_size) 484 | self.state = None 485 | 486 | def forward(self, inputs, state): # inputs: (batch, seq_len) 487 | # 获取one-hot向量表示 488 | X = to_onehot(inputs, self.vocab_size) # X是个list 489 | Y, self.state = self.rnn(torch.stack(X), state) 490 | # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出 491 | # 形状为(num_steps * batch_size, vocab_size) 492 | output = self.dense(Y.view(-1, Y.shape[-1])) 493 | return output, self.state 494 | 495 | def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char, 496 | char_to_idx): 497 | state = None 498 | output = [char_to_idx[prefix[0]]] # output会记录prefix加上输出 499 | for t in range(num_chars + len(prefix) - 1): 500 | X = torch.tensor([output[-1]], device=device).view(1, 1) 501 | if state is not None: 502 | if isinstance(state, tuple): # LSTM, state:(h, c) 503 | state = (state[0].to(device), state[1].to(device)) 504 | else: 505 | state = state.to(device) 506 | 507 | (Y, state) = model(X, state) # 前向计算不需要传入模型参数 508 | if t < len(prefix) - 1: 509 | output.append(char_to_idx[prefix[t + 1]]) 
510 | else: 511 | output.append(int(Y.argmax(dim=1).item())) 512 | return ''.join([idx_to_char[i] for i in output]) 513 | 514 | def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, 515 | corpus_indices, idx_to_char, char_to_idx, 516 | num_epochs, num_steps, lr, clipping_theta, 517 | batch_size, pred_period, pred_len, prefixes): 518 | loss = nn.CrossEntropyLoss() 519 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 520 | model.to(device) 521 | state = None 522 | for epoch in range(num_epochs): 523 | l_sum, n, start = 0.0, 0, time.time() 524 | data_iter = data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样 525 | for X, Y in data_iter: 526 | if state is not None: 527 | # 使用detach函数从计算图分离隐藏状态, 这是为了 528 | # 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大) 529 | if isinstance (state, tuple): # LSTM, state:(h, c) 530 | state = (state[0].detach(), state[1].detach()) 531 | else: 532 | state = state.detach() 533 | 534 | (output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size) 535 | 536 | # Y的形状是(batch_size, num_steps),转置后再变成长度为 537 | # batch * num_steps 的向量,这样跟输出的行一一对应 538 | y = torch.transpose(Y, 0, 1).contiguous().view(-1) 539 | l = loss(output, y.long()) 540 | 541 | optimizer.zero_grad() 542 | l.backward() 543 | # 梯度裁剪 544 | grad_clipping(model.parameters(), clipping_theta, device) 545 | optimizer.step() 546 | l_sum += l.item() * y.shape[0] 547 | n += y.shape[0] 548 | 549 | try: 550 | perplexity = math.exp(l_sum / n) 551 | except OverflowError: 552 | perplexity = float('inf') 553 | if (epoch + 1) % pred_period == 0: 554 | print('epoch %d, perplexity %f, time %.2f sec' % ( 555 | epoch + 1, perplexity, time.time() - start)) 556 | for prefix in prefixes: 557 | print(' -', predict_rnn_pytorch( 558 | prefix, pred_len, model, vocab_size, device, idx_to_char, 559 | char_to_idx)) 560 | 561 | 562 | 563 | 564 | # ######################################## 7.2 ############################################### 565 | def train_2d(trainer): 566 | x1, x2, s1, s2 = -5, -2, 0, 0 # s1和s2是自变量状态,本章后续几节会使用 567 | results = [(x1, x2)] 568 | for i in range(20): 569 | x1, x2, s1, s2 = trainer(x1, x2, s1, s2) 570 | results.append((x1, x2)) 571 | print('epoch %d, x1 %f, x2 %f' % (i + 1, x1, x2)) 572 | return results 573 | 574 | def show_trace_2d(f, results): 575 | plt.plot(*zip(*results), '-o', color='#ff7f0e') 576 | x1, x2 = np.meshgrid(np.arange(-5.5, 1.0, 0.1), np.arange(-3.0, 1.0, 0.1)) 577 | plt.contour(x1, x2, f(x1, x2), colors='#1f77b4') 578 | plt.xlabel('x1') 579 | plt.ylabel('x2') 580 | 581 | 582 | 583 | 584 | # ######################################## 7.3 ############################################### 585 | def get_data_ch7(): 586 | data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\t') 587 | data = (data - data.mean(axis=0)) / data.std(axis=0) 588 | return torch.tensor(data[:1500, :-1], dtype=torch.float32), \ 589 | torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征) 590 | 591 | def train_ch7(optimizer_fn, states, hyperparams, features, labels, 592 | batch_size=10, num_epochs=2): 593 | # 初始化模型 594 | net, loss = linreg, squared_loss 595 | 596 | w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32), 597 | requires_grad=True) 598 | b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True) 599 | 600 | def eval_loss(): 601 | return loss(net(features, w, b), labels).mean().item() 602 | 603 | ls = [eval_loss()] 604 | data_iter = 
torch.utils.data.DataLoader( 605 | torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True) 606 | 607 | for _ in range(num_epochs): 608 | start = time.time() 609 | for batch_i, (X, y) in enumerate(data_iter): 610 | l = loss(net(X, w, b), y).mean() # 使用平均损失 611 | 612 | # 梯度清零 613 | if w.grad is not None: 614 | w.grad.data.zero_() 615 | b.grad.data.zero_() 616 | 617 | l.backward() 618 | optimizer_fn([w, b], states, hyperparams) # 迭代模型参数 619 | if (batch_i + 1) * batch_size % 100 == 0: 620 | ls.append(eval_loss()) # 每100个样本记录下当前训练误差 621 | # 打印结果和作图 622 | print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start)) 623 | set_figsize() 624 | plt.plot(np.linspace(0, num_epochs, len(ls)), ls) 625 | plt.xlabel('epoch') 626 | plt.ylabel('loss') 627 | 628 | # 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字 629 | # 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05} 630 | def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels, 631 | batch_size=10, num_epochs=2): 632 | # 初始化模型 633 | net = nn.Sequential( 634 | nn.Linear(features.shape[-1], 1) 635 | ) 636 | loss = nn.MSELoss() 637 | optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams) 638 | 639 | def eval_loss(): 640 | return loss(net(features).view(-1), labels).item() / 2 641 | 642 | ls = [eval_loss()] 643 | data_iter = torch.utils.data.DataLoader( 644 | torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True) 645 | 646 | for _ in range(num_epochs): 647 | start = time.time() 648 | for batch_i, (X, y) in enumerate(data_iter): 649 | # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2 650 | l = loss(net(X).view(-1), y) / 2 651 | 652 | optimizer.zero_grad() 653 | l.backward() 654 | optimizer.step() 655 | if (batch_i + 1) * batch_size % 100 == 0: 656 | ls.append(eval_loss()) 657 | # 打印结果和作图 658 | print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start)) 659 | set_figsize() 660 | plt.plot(np.linspace(0, num_epochs, len(ls)), ls) 661 | plt.xlabel('epoch') 662 | plt.ylabel('loss') 663 | 664 | 665 | 666 | 667 | ############################## 8.3 ################################## 668 | class Benchmark(): 669 | def __init__(self, prefix=None): 670 | self.prefix = prefix + ' ' if prefix else '' 671 | 672 | def __enter__(self): 673 | self.start = time.time() 674 | 675 | def __exit__(self, *args): 676 | print('%stime: %.4f sec' % (self.prefix, time.time() - self.start)) 677 | 678 | 679 | 680 | 681 | 682 | # ########################### 9.1 ######################################## 683 | def show_images(imgs, num_rows, num_cols, scale=2): 684 | figsize = (num_cols * scale, num_rows * scale) 685 | _, axes = plt.subplots(num_rows, num_cols, figsize=figsize) 686 | for i in range(num_rows): 687 | for j in range(num_cols): 688 | axes[i][j].imshow(imgs[i * num_cols + j]) 689 | axes[i][j].axes.get_xaxis().set_visible(False) 690 | axes[i][j].axes.get_yaxis().set_visible(False) 691 | return axes 692 | 693 | def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs): 694 | net = net.to(device) 695 | print("training on ", device) 696 | batch_count = 0 697 | for epoch in range(num_epochs): 698 | train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time() 699 | for X, y in train_iter: 700 | X = X.to(device) 701 | y = y.to(device) 702 | y_hat = net(X) 703 | l = loss(y_hat, y) 704 | optimizer.zero_grad() 705 | l.backward() 706 | optimizer.step() 707 | train_l_sum += l.cpu().item() 708 | train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item() 709 | n += y.shape[0] 710 | 
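# Bookkeeping note for the loop above: l is the *mean* loss of the current batch,
# so the train_l_sum / batch_count printed at the end of each epoch is an average
# batch loss, while train_acc_sum / n is per-sample accuracy; evaluate_accuracy is
# assumed to be defined earlier in this utils module.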
batch_count += 1
711 | test_acc = evaluate_accuracy(test_iter, net)
712 | print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
713 | % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
714 |
715 |
716 |
717 |
718 |
719 | ############################## 9.3 #####################
720 | def bbox_to_rect(bbox, color):
721 | # Convert a bounding box from (upper-left x, upper-left y, lower-right x, lower-right y)
722 | # format into the matplotlib format: ((upper-left x, upper-left y), width, height)
723 | return plt.Rectangle(
724 | xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
725 | fill=False, edgecolor=color, linewidth=2)
726 |
727 |
728 |
729 |
730 |
731 | # ############################# 10.7 ##########################
732 | def read_imdb(folder='train', data_root="/S1/CSCL/tangss/Datasets/aclImdb"):
733 | data = []
734 | for label in ['pos', 'neg']:
735 | folder_name = os.path.join(data_root, folder, label)
736 | for file in tqdm(os.listdir(folder_name)):
737 | with open(os.path.join(folder_name, file), 'rb') as f:
738 | review = f.read().decode('utf-8').replace('\n', '').lower()
739 | data.append([review, 1 if label == 'pos' else 0])
740 | random.shuffle(data)
741 | return data
742 |
743 | def get_tokenized_imdb(data):
744 | """
745 | data: list of [string, label]
746 | """
747 | def tokenizer(text):
748 | return [tok.lower() for tok in text.split(' ')]
749 | return [tokenizer(review) for review, _ in data]
750 |
751 | def get_vocab_imdb(data):
752 | tokenized_data = get_tokenized_imdb(data)
753 | counter = collections.Counter([tk for st in tokenized_data for tk in st])
754 | return torchtext.vocab.Vocab(counter, min_freq=5)
755 |
756 | def preprocess_imdb(data, vocab):
757 | max_l = 500 # truncate or zero-pad every review so that its length becomes 500
758 |
759 | def pad(x):
760 | return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))
761 |
762 | tokenized_data = get_tokenized_imdb(data)
763 | features = torch.tensor([pad([vocab.stoi[word] for word in words]) for words in tokenized_data])
764 | labels = torch.tensor([score for _, score in data])
765 | return features, labels
766 |
767 | def load_pretrained_embedding(words, pretrained_vocab):
768 | """Extract the word vectors corresponding to words from a pretrained vocab"""
769 | embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0]) # initialized to 0
770 | oov_count = 0 # out of vocabulary
771 | for i, word in enumerate(words):
772 | try:
773 | idx = pretrained_vocab.stoi[word]
774 | embed[i, :] = pretrained_vocab.vectors[idx]
775 | except KeyError:
776 | oov_count += 1
777 | if oov_count > 0:
778 | print("There are %d oov words." % oov_count)
779 | return embed
780 |
781 | def predict_sentiment(net, vocab, sentence):
782 | """sentence is a list of words"""
783 | device = list(net.parameters())[0].device
784 | sentence = torch.tensor([vocab.stoi[word] for word in sentence], device=device)
785 | label = torch.argmax(net(sentence.view((1, -1))), dim=1)
786 | return 'positive' if label.item() == 1 else 'negative'
--------------------------------------------------------------------------------
/global_module/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import math
5 | from sklearn.svm import SVC
6 | from sklearn.model_selection import GridSearchCV
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 | import sys
11 | sys.path.append('../global_module/')
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 |
15 | class Residual_2D(nn.Module): # This class is also saved in the d2lzh_pytorch package for later use
16 | def __init__(self, in_channels,
out_channels, kernel_size, padding, batch_normal = False, stride=1): 17 | super(Residual_2D, self).__init__() 18 | self.conv1 = nn.Conv2d(in_channels, out_channels, 19 | kernel_size=kernel_size, padding=padding, stride=stride) 20 | self.conv2 = nn.Conv2d(out_channels, out_channels, 21 | kernel_size=kernel_size, padding=padding,stride=stride) 22 | if batch_normal: 23 | self.bn = nn.Sequential( 24 | nn.ReLU(), 25 | nn.BatchNorm2d(out_channels) 26 | ) 27 | else: 28 | self.bn = nn.ReLU() 29 | def forward(self, X): 30 | Y = F.relu(self.conv1(self.bn(X))) 31 | Y = self.conv2(Y) 32 | return F.relu(Y + X) 33 | 34 | 35 | from torch.nn.modules.conv import _ConvNd 36 | from torch.nn.modules.utils import _pair 37 | 38 | class GaborConv2d(_ConvNd): 39 | 40 | def __init__(self, in_channels, out_channels, kernel_size, device="cpu", stride=1, 41 | padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros'): 42 | kernel_size = _pair(kernel_size) 43 | stride = _pair(stride) 44 | padding = _pair(padding) 45 | dilation = _pair(dilation) 46 | 47 | super(GaborConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, 48 | _pair(0), groups, bias, padding_mode) 49 | self.freq = nn.Parameter( 50 | (3.14 / 2) * 1.41 ** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor)) 51 | self.theta = nn.Parameter((3.14 / 8) * torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor)) 52 | self.psi = nn.Parameter(3.14 * torch.rand(out_channels, in_channels)) 53 | self.sigma = nn.Parameter(3.14 / self.freq) 54 | self.x0 = torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0] 55 | self.y0 = torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0] 56 | self.device = device 57 | 58 | def forward(self, input_image): 59 | y, x = torch.meshgrid([torch.linspace(-self.x0 + 1, self.x0, self.kernel_size[0]), 60 | torch.linspace(-self.y0 + 1, self.y0, self.kernel_size[1])]) 61 | x = x.to(self.device) 62 | y = y.to(self.device) 63 | weight = torch.empty(self.weight.shape, requires_grad=False).to(self.device) 64 | for i in range(self.out_channels): 65 | for j in range(self.in_channels): 66 | sigma = self.sigma[i, j].expand_as(y) 67 | freq = self.freq[i, j].expand_as(y) 68 | theta = self.theta[i, j].expand_as(y) 69 | psi = self.psi[i, j].expand_as(y) 70 | 71 | rotx = x * torch.cos(theta) + y * torch.sin(theta) 72 | roty = -x * torch.sin(theta) + y * torch.cos(theta) 73 | 74 | g = torch.zeros(y.shape) 75 | 76 | g = torch.exp(-0.5 * ((rotx ** 2 + roty ** 2) / (sigma + 1e-3) ** 2)) 77 | g = g * torch.cos(freq * rotx + psi) 78 | g = g / (2 * 3.14 * sigma ** 2) 79 | weight[i, j] = g 80 | self.weight.data[i, j] = g 81 | return F.conv2d(input_image, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 82 | 83 | 84 | class Residual(nn.Module): # 本类已保存在d2lzh_pytorch包中方便以后使用 85 | def __init__(self, in_channels, out_channels, kernel_size, padding, use_1x1conv=False, stride=1): 86 | super(Residual, self).__init__() 87 | self.conv1 = nn.Sequential( 88 | nn.Conv3d(in_channels, out_channels, 89 | kernel_size=kernel_size, padding=padding, stride=stride), 90 | nn.ReLU() 91 | ) 92 | self.conv2 = nn.Conv3d(out_channels, out_channels, 93 | kernel_size=kernel_size, padding=padding,stride=stride) 94 | if use_1x1conv: 95 | self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride) 96 | else: 97 | self.conv3 = None 98 | self.bn1 = nn.BatchNorm3d(out_channels) 99 | self.bn2 = nn.BatchNorm3d(out_channels) 100 | 101 | def forward(self, X): 102 | Y = 
F.relu(self.bn1(self.conv1(X))) 103 | Y = self.bn2(self.conv2(Y)) 104 | if self.conv3: 105 | X = self.conv3(X) 106 | return F.relu(Y + X) 107 | 108 | 109 | class Separable_Convolution(nn.Module): 110 | def __init__(self, in_channels, out_channels, padding=0, kernel_size=1, stride=1): 111 | super(Separable_Convolution, self).__init__() 112 | self.depth_conv = nn.Conv3d( 113 | in_channels=in_channels, 114 | out_channels=in_channels, 115 | kernel_size=kernel_size, 116 | stride=stride, 117 | padding=padding, 118 | groups=in_channels 119 | ) 120 | self.point_conv = nn.Conv3d( 121 | in_channels=in_channels, 122 | out_channels=out_channels, 123 | kernel_size=1, 124 | stride=1, 125 | padding=0, 126 | groups=1 127 | ) 128 | 129 | def forward(self, input): 130 | out = self.depth_conv(input) 131 | out = self.point_conv(out) 132 | return out 133 | 134 | 135 | class RES_AVE(nn.Module): 136 | def __init__(self, band, classes): 137 | super(RES_AVE, self).__init__() 138 | self.conv11 = nn.Conv2d(in_channels=band, out_channels=32, 139 | kernel_size=(3, 3), stride=(1, 1)) 140 | self.conv12 = nn.Conv2d(in_channels=band, out_channels=32, 141 | kernel_size=(3, 3), stride=(1, 1)) 142 | self.batch_normal = nn.Sequential( 143 | nn.BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True), 144 | nn.ReLU() 145 | ) 146 | self.conv2 = nn.Sequential( 147 | nn.Conv2d(in_channels=64, out_channels=64, padding=(1, 1), 148 | kernel_size=(3, 3), stride=(1, 1)), 149 | nn.ReLU() 150 | ) 151 | self.conv3 = nn.Sequential( 152 | nn.Conv2d(in_channels=64, out_channels=64, padding=(1, 1), 153 | kernel_size=(3, 3), stride=(1, 1)), 154 | ) 155 | self.Avg_pooling = nn.AdaptiveAvgPool2d(1) 156 | self.full_connection = nn.Sequential( 157 | nn.Linear(1 * 64, classes), 158 | #nn.Sigmoid() 159 | ) 160 | def forward(self, X): 161 | X = X.permute(0, 4, 2, 3, 1) 162 | X = X.squeeze(-1) 163 | x1 = self.conv11(X) 164 | x2 = self.conv12(X) 165 | x3 = torch.cat((x1, x2), dim=1) 166 | x4 = self.batch_normal(x3) 167 | x4 = self.conv2(x4) 168 | x4 = self.conv3(x4) 169 | x5 = torch.add(x3, x4) 170 | # print('x5', x5.shape) 171 | x6 = self.Avg_pooling(x5) 172 | x6 = x6.view(x6.shape[0], -1) 173 | return self.full_connection(x6) 174 | 175 | 176 | class CDCNN_network(nn.Module): 177 | def __init__(self, band, classes): 178 | super(CDCNN_network, self).__init__() 179 | self.name = 'CDCNN' 180 | 181 | self.conv11 = nn.Sequential( 182 | nn.Conv2d(in_channels=band, out_channels=128, kernel_size=(1,1)), 183 | nn.MaxPool2d(kernel_size=(5, 5)) 184 | ) 185 | self.conv12 = nn.Sequential( 186 | nn.Conv2d(in_channels=band, out_channels=128, kernel_size=(3, 3)), 187 | nn.MaxPool2d(kernel_size=(3, 3)) 188 | ) 189 | self.conv13 = nn.Conv2d(in_channels=band, out_channels=128, kernel_size=(5, 5)) 190 | 191 | self.batch_normal1 = nn.Sequential( 192 | nn.ReLU(inplace=True), 193 | nn.BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True) 194 | ) 195 | self.conv2 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=(1, 1)) 196 | self.res_net1 = Residual_2D(128, 128, (1, 1), (0, 0), batch_normal=True) 197 | self.res_net2 = Residual_2D(128, 128, (1, 1), (0, 0)) 198 | 199 | self.conv3 = nn.Sequential( 200 | nn.ReLU(), 201 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(1, 1)) 202 | ) 203 | self.conv4 = nn.Sequential( 204 | nn.ReLU(), 205 | nn.Dropout(0.5), 206 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(1, 1)) 207 | ) 208 | self.conv5 = nn.Sequential( 209 | nn.ReLU(), 210 | nn.Dropout(0.5), 211 | nn.Conv2d(in_channels=128, out_channels=128, 
kernel_size=(1, 1)) 212 | ) 213 | 214 | 215 | self.full_connection = nn.Sequential( 216 | nn.Linear(128, classes) 217 | #nn.Sigmoid() 218 | ) 219 | 220 | def forward(self, X): 221 | X = X.squeeze(1).permute(0, 3, 1, 2) 222 | x11 = self.conv11(X) 223 | x12 = self.conv12(X) 224 | x13 = self.conv13(X) 225 | # print(x11.shape) 226 | # print(x12.shape) 227 | # print(x13.shape) 228 | x1 = torch.cat((x11, x12, x13), dim=1) 229 | x1 = self.conv2(x1) 230 | x1 = self.res_net1(x1) 231 | x1 = self.res_net2(x1) 232 | x1 = self.conv3(x1) 233 | x1 = self.conv4(x1) 234 | x1 = self.conv5(x1) 235 | x1 = x1.view(x1.shape[0], -1) 236 | # print(x1.shape) 237 | return self.full_connection(x1) 238 | 239 | 240 | class DBDA_Separable_network(nn.Module): 241 | def __init__(self, band, classes): 242 | super(DBDA_Separable_network, self).__init__() 243 | 244 | self.conv11 = Separable_Convolution(in_channels=1, out_channels=24, padding=0, 245 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 246 | # Dense block 247 | self.batch_norm11 = nn.Sequential( 248 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 249 | nn.ReLU(inplace=True) 250 | ) 251 | self.conv12 = Separable_Convolution(in_channels=24, out_channels=24, padding=(0, 0, 3), 252 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 253 | self.batch_norm12 = nn.Sequential( 254 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 255 | nn.ReLU(inplace=True) 256 | ) 257 | self.conv13 = Separable_Convolution(in_channels=48, out_channels=24, padding=(0, 0, 3), 258 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 259 | self.batch_norm13 = nn.Sequential( 260 | nn.BatchNorm3d(72, eps=0.001, momentum=0.1, affine=True), 261 | nn.ReLU(inplace=True) 262 | ) 263 | self.conv14 = Separable_Convolution(in_channels=72, out_channels=24, padding=(0, 0, 3), 264 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 265 | self.batch_norm14 = nn.Sequential( 266 | nn.BatchNorm3d(96, eps=0.001, momentum=0.1, affine=True), 267 | nn.ReLU(inplace=True) 268 | ) 269 | kernel_3d = math.floor((band - 6) / 2) 270 | self.conv15 = Separable_Convolution(in_channels=96, out_channels=60, 271 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 272 | 273 | # 注意力机制模块 274 | 275 | # self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 276 | # self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 277 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 278 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 279 | 280 | self.shared_mlp = nn.Sequential( 281 | nn.Conv3d(in_channels=60, out_channels=30, 282 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 283 | nn.Conv3d(in_channels=30, out_channels=60, 284 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 285 | ) 286 | 287 | self.activation1 = nn.Sigmoid() 288 | 289 | # Spatial Branch 290 | self.conv21 = Separable_Convolution(in_channels=1, out_channels=24, 291 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 292 | # Dense block 293 | self.batch_norm21 = nn.Sequential( 294 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 295 | nn.ReLU(inplace=True) 296 | ) 297 | self.conv22 = Separable_Convolution(in_channels=24, out_channels=12, padding=(1, 1, 0), 298 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 299 | self.batch_norm22 = nn.Sequential( 300 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 301 | nn.ReLU(inplace=True) 302 | ) 303 | self.conv23 = Separable_Convolution(in_channels=36, out_channels=12, padding=(1, 1, 0), 304 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 305 | self.batch_norm23 = nn.Sequential( 306 | nn.BatchNorm3d(48, eps=0.001, 
momentum=0.1, affine=True), 307 | nn.ReLU(inplace=True) 308 | ) 309 | self.conv24 = Separable_Convolution(in_channels=48, out_channels=12, padding=(1, 1, 0), 310 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 311 | 312 | # 注意力机制模块 313 | 314 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 60)) 315 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 316 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 317 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 318 | 319 | self.conv25 = nn.Sequential( 320 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 321 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 322 | nn.Sigmoid() 323 | ) 324 | 325 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 326 | self.full_connection = nn.Sequential( 327 | nn.Linear(120, classes) # , 328 | # nn.Softmax() 329 | ) 330 | 331 | self.attention_spectral = CAM_Module(60) 332 | self.attention_spatial = PAM_Module(60) 333 | 334 | # fc = Dense(classes, activation='softmax', name='output1', 335 | # kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 336 | 337 | def forward(self, X): 338 | # spectral 339 | x11 = self.conv11(X) 340 | # print('x11', x11.shape) 341 | x12 = self.batch_norm11(x11) 342 | x12 = self.conv12(x12) 343 | # print('x12', x12.shape) 344 | 345 | x13 = torch.cat((x11, x12), dim=1) 346 | # print('x13', x13.shape) 347 | x13 = self.batch_norm12(x13) 348 | x13 = self.conv13(x13) 349 | # print('x13', x13.shape) 350 | 351 | x14 = torch.cat((x11, x12, x13), dim=1) 352 | x14 = self.batch_norm13(x14) 353 | x14 = self.conv14(x14) 354 | 355 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 356 | # print('x15', x15.shape) 357 | 358 | x16 = self.batch_norm14(x15) 359 | x16 = self.conv15(x16) 360 | # print('x16', x16.shape) # 7*7*97, 60 361 | 362 | # print('x16', x16.shape) 363 | # 光谱注意力通道 364 | x1 = self.attention_spectral(x16) 365 | x1 = torch.mul(x1, x16) 366 | 367 | # spatial 368 | # print('x', X.shape) 369 | x21 = self.conv21(X) 370 | x22 = self.batch_norm21(x21) 371 | x22 = self.conv22(x22) 372 | 373 | x23 = torch.cat((x21, x22), dim=1) 374 | x23 = self.batch_norm22(x23) 375 | x23 = self.conv23(x23) 376 | 377 | x24 = torch.cat((x21, x22, x23), dim=1) 378 | x24 = self.batch_norm23(x24) 379 | x24 = self.conv24(x24) 380 | 381 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 382 | # print('x25', x25.shape) 383 | # x25 = x25.permute(0, 4, 2, 3, 1) 384 | # print('x25', x25.shape) 385 | 386 | # 空间注意力机制 387 | x2 = self.attention_spatial(x25) 388 | x2 = torch.mul(x2, x25) 389 | 390 | # model1 391 | x1 = self.global_pooling(x1) 392 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 393 | x2 = self.global_pooling(x2) 394 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 395 | 396 | x_pre = torch.cat((x1, x2), dim=1) 397 | # print('x_pre', x_pre.shape) 398 | 399 | # model2 400 | # x1 = torch.mul(x2, x16) 401 | # x2 = torch.mul(x2, x25) 402 | # x_pre = x1 + x2 403 | # 404 | # 405 | # x_pre = x_pre.view(x_pre.shape[0], -1) 406 | output = self.full_connection(x_pre) 407 | # output = self.fc(x_pre) 408 | return output 409 | 410 | 411 | 412 | 413 | class GaborNN(nn.Module): 414 | def __init__(self,band,classes): 415 | super(GaborNN, self).__init__() 416 | self.name = 'GaborNN' 417 | self.g0 = GaborConv2d(in_channels=band, out_channels=96, kernel_size=(3, 3), device=device) 418 | self.m1 = nn.MaxPool2d(kernel_size=(1,1)) 419 | self.c1 = nn.Conv2d(96, 128, (1,1)) 420 | self.m2 = nn.MaxPool2d(kernel_size=(2,2)) 421 | self.fc1 = nn.Linear(128*2*2, 128) 422 | self.fc2 = nn.Linear(128, classes) 423 | 424 | def forward(self, x): 425 | 
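# Input-layout note (a sketch, assuming 7x7 HSI patches): x arrives as
# (batch, 1, H, W, band); the squeeze(1) + permute below turn it into the
# (batch, band, H, W) layout that GaborConv2d expects. With 7x7 patches the
# 3x3 Gabor conv leaves 5x5, the 1x1 max-pool keeps 5x5, and the 2x2 max-pool
# floors it to 2x2, matching the 128*2*2 input that fc1 above was sized for.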
x = x.squeeze(1).permute(0,3,1,2) 426 | x = F.leaky_relu(self.g0(x)) 427 | 428 | x = self.m1(x) 429 | 430 | x = F.leaky_relu(self.c1(x)) 431 | 432 | x = self.m2(x) 433 | 434 | x = x.view(-1, 128*2*2) 435 | 436 | x = F.leaky_relu(self.fc1(x)) 437 | 438 | x = self.fc2(x) 439 | 440 | return x 441 | class DBMA_network(nn.Module): 442 | def __init__(self, band, classes): 443 | super(DBMA_network, self).__init__() 444 | 445 | # spectral branch 446 | self.name = 'DBMA' 447 | self.conv11 = nn.Conv3d(in_channels=1, out_channels=24, 448 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 449 | # Dense block 450 | self.batch_norm11 = nn.Sequential( 451 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 452 | nn.ReLU(inplace=True) 453 | ) 454 | self.conv12 = nn.Conv3d(in_channels=24, out_channels=24, padding=(0, 0, 3), 455 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 456 | self.batch_norm12 = nn.Sequential( 457 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 458 | nn.ReLU(inplace=True) 459 | ) 460 | self.conv13 = nn.Conv3d(in_channels=48, out_channels=24, padding=(0, 0, 3), 461 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 462 | self.batch_norm13 = nn.Sequential( 463 | nn.BatchNorm3d(72, eps=0.001, momentum=0.1, affine=True), 464 | nn.ReLU(inplace=True) 465 | ) 466 | self.conv14 = nn.Conv3d(in_channels=72, out_channels=24, padding=(0, 0, 3), 467 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 468 | self.batch_norm14 = nn.Sequential( 469 | nn.BatchNorm3d(96, eps=0.001, momentum=0.1, affine=True), 470 | nn.ReLU(inplace=True) 471 | ) 472 | kernel_3d = math.floor((band - 6) / 2) 473 | self.conv15 = nn.Conv3d(in_channels=96, out_channels=60, 474 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 475 | 476 | #注意力机制模块 477 | 478 | #self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 479 | #self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 480 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 481 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 482 | 483 | self.shared_mlp = nn.Sequential( 484 | nn.Conv3d(in_channels=60, out_channels=30, 485 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 486 | nn.Conv3d(in_channels=30, out_channels=60, 487 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 488 | ) 489 | #self.fc11 = Dense(30, activation=None, kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 490 | #self.fc12 = Dense(60, activation=None, kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 491 | 492 | self.activation1 = nn.Sigmoid() 493 | 494 | 495 | # Spatial Branch 496 | self.conv21 = nn.Conv3d(in_channels=1, out_channels=24, 497 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 498 | # Dense block 499 | self.batch_norm21 = nn.Sequential( 500 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 501 | nn.ReLU(inplace=True) 502 | ) 503 | self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0), 504 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 505 | self.batch_norm22 = nn.Sequential( 506 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 507 | nn.ReLU(inplace=True) 508 | ) 509 | self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0), 510 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 511 | self.batch_norm23 = nn.Sequential( 512 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 513 | nn.ReLU(inplace=True) 514 | ) 515 | self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0), 516 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 517 | 518 | # 注意力机制模块 519 | 520 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 
60)) 521 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 522 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 523 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 524 | 525 | self.conv25 = nn.Sequential( 526 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 527 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 528 | nn.Sigmoid() 529 | ) 530 | 531 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 532 | self.full_connection = nn.Sequential( 533 | nn.Linear(120, classes) # , 534 | # nn.Softmax() 535 | ) 536 | 537 | def forward(self, X): 538 | # spectral 539 | x11 = self.conv11(X) 540 | #print('x11', x11.shape) 541 | x12 = self.batch_norm11(x11) 542 | x12 = self.conv12(x12) 543 | #print('x12', x12.shape) 544 | 545 | x13 = torch.cat((x11, x12), dim=1) 546 | #print('x13', x13.shape) 547 | x13 = self.batch_norm12(x13) 548 | x13 = self.conv13(x13) 549 | #print('x13', x13.shape) 550 | 551 | x14 = torch.cat((x11, x12, x13), dim=1) 552 | x14 = self.batch_norm13(x14) 553 | x14 = self.conv14(x14) 554 | 555 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 556 | #print('x15', x15.shape) 557 | 558 | x16 = self.batch_norm14(x15) 559 | x16 = self.conv15(x16) 560 | #print('x16', x16.shape) # 7*7*97, 60 561 | 562 | #print('x16', x16.shape) 563 | # 光谱注意力通道 564 | x_max1 = self.max_pooling1(x16) 565 | x_avg1 = self.avg_pooling1(x16) 566 | #print('x_max1', x_max1.shape) 567 | 568 | 569 | # x_max1 = self.fc11(x_max1) 570 | # x_max1 = self.fc12(x_max1) 571 | # 572 | # x1_avg1 = self.fc11(x_avg1) 573 | # x1_avg1 = self.fc12(x_avg1) 574 | #print('x_max1', x_max1.shape) 575 | # x_max1 = x_max1.view(x_max1.size(0), -1) 576 | # x_avg1 = x_avg1.view(x_avg1.size(0), -1) 577 | # print('x_max1', x_max1.shape) 578 | x_max1 = self.shared_mlp(x_max1) 579 | x_avg1 = self.shared_mlp(x_avg1) 580 | #print('x_max1', x_max1.shape) 581 | x1 = torch.add(x_max1, x_avg1) 582 | x1 = self.activation1(x1) 583 | 584 | #x1 = x1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 585 | #print('x1', x1.shape) 586 | #print('x16', x16.shape) 587 | 588 | # x1 = multiply([x1, x16]) 589 | # x1 = self.activation1(x1) 590 | x1 = torch.mul(x1, x16) 591 | #print('x1', x1.shape) 592 | x1 = self.global_pooling(x1) 593 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 594 | #print('x1', x1.shape) 595 | # x1 = Reshape(target_shape=(7, 7, 1, 60))(x1) 596 | #x1 = GlobalAveragePooling3D()(x1) 597 | 598 | # spatial 599 | #print('x', X.shape) 600 | x21 = self.conv21(X) 601 | x22 = self.batch_norm21(x21) 602 | x22 = self.conv22(x22) 603 | 604 | x23 = torch.cat((x21, x22), dim=1) 605 | x23 = self.batch_norm22(x23) 606 | x23 = self.conv23(x23) 607 | 608 | x24 = torch.cat((x21, x22, x23), dim=1) 609 | x24 = self.batch_norm23(x24) 610 | x24 = self.conv24(x24) 611 | 612 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 613 | #print('x25', x25.shape) 614 | # x25 = x25.permute(0, 4, 2, 3, 1) 615 | #print('x25', x25.shape) 616 | 617 | # 空间注意力机制 618 | #x_max2 = self.max_pooling2(x25) 619 | #x_avg2 = self.avg_pooling2(x25) 620 | # x_avg2 = x_avg2.permute(0, 4, 2, 3, 1) 621 | x_avg2 = torch.mean(x25, dim=1, keepdim=True) 622 | x_max2, _ = torch.max(x25, dim=1, keepdim=True) 623 | #print('x_avg2', x_avg2.shape) 624 | 625 | x2 = torch.cat((x_max2, x_avg2), dim=-1) 626 | x2 = self.conv25(x2) 627 | #print('x2', x2.shape) 628 | #print('x25', x25.shape) 629 | 630 | 631 | x2 = torch.mul(x2, x25) 632 | #print('x2', x2.shape) 633 | x2 = self.global_pooling(x2) 634 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 635 | # x2 = Reshape(target_shape=(7, 7, 1, 60))(x2) 636 | # x2 = 
GlobalAveragePooling3D()(x2) 637 | 638 | #print('x1', x1.shape) 639 | #print('x2', x2.shape) 640 | 641 | x_pre = torch.cat((x1, x2), dim=1) 642 | #print('x_pre', x_pre.shape) 643 | x_pre = x_pre.view(x_pre.shape[0], -1) 644 | output = self.full_connection(x_pre) 645 | # output = self.fc(x_pre) 646 | return output 647 | 648 | class DBDA_network(nn.Module): 649 | def __init__(self, band, classes): 650 | super(DBDA_network, self).__init__() 651 | 652 | # spectral branch 653 | 654 | self.conv11 = nn.Conv3d(in_channels=1, out_channels=24, 655 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 656 | # Dense block 657 | self.batch_norm11 = nn.Sequential( 658 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 659 | nn.ReLU(inplace=True) 660 | ) 661 | self.conv12 = nn.Conv3d(in_channels=24, out_channels=24, padding=(0, 0, 3), 662 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 663 | self.batch_norm12 = nn.Sequential( 664 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 665 | nn.ReLU(inplace=True) 666 | ) 667 | self.conv13 = nn.Conv3d(in_channels=48, out_channels=24, padding=(0, 0, 3), 668 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 669 | self.batch_norm13 = nn.Sequential( 670 | nn.BatchNorm3d(72, eps=0.001, momentum=0.1, affine=True), 671 | nn.ReLU(inplace=True) 672 | ) 673 | self.conv14 = nn.Conv3d(in_channels=72, out_channels=24, padding=(0, 0, 3), 674 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 675 | self.batch_norm14 = nn.Sequential( 676 | nn.BatchNorm3d(96, eps=0.001, momentum=0.1, affine=True), 677 | nn.ReLU(inplace=True) 678 | ) 679 | kernel_3d = math.floor((band - 6) / 2) 680 | self.conv15 = nn.Conv3d(in_channels=96, out_channels=60, 681 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 682 | 683 | #注意力机制模块 684 | 685 | #self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 686 | #self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 687 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 688 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 689 | 690 | self.shared_mlp = nn.Sequential( 691 | nn.Conv3d(in_channels=60, out_channels=30, 692 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 693 | nn.Conv3d(in_channels=30, out_channels=60, 694 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 695 | ) 696 | 697 | self.activation1 = nn.Sigmoid() 698 | 699 | 700 | # Spatial Branch 701 | self.conv21 = nn.Conv3d(in_channels=1, out_channels=24, 702 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 703 | # Dense block 704 | self.batch_norm21 = nn.Sequential( 705 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 706 | nn.ReLU(inplace=True) 707 | ) 708 | self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0), 709 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 710 | self.batch_norm22 = nn.Sequential( 711 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 712 | nn.ReLU(inplace=True) 713 | ) 714 | self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0), 715 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 716 | self.batch_norm23 = nn.Sequential( 717 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 718 | nn.ReLU(inplace=True) 719 | ) 720 | self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0), 721 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 722 | 723 | # 注意力机制模块 724 | 725 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 60)) 726 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 727 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 728 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 729 | 730 | 
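# conv25 below is the CBAM-style spatial gate carried over from DBMA_network:
# it expects the channel-wise max map and mean map concatenated along the last
# axis (hence the kernel depth of 2) and squashes them into a single sigmoid
# mask. In this network's forward pass, however, conv25 is defined but never
# called, since PAM_Module supplies the spatial attention instead.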
self.conv25 = nn.Sequential( 731 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 732 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 733 | nn.Sigmoid() 734 | ) 735 | 736 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 737 | self.full_connection = nn.Sequential( 738 | # nn.Dropout(p=0.5), 739 | nn.Linear(120, classes) # , 740 | # nn.Softmax() 741 | ) 742 | 743 | self.attention_spectral = CAM_Module(60) 744 | self.attention_spatial = PAM_Module(60) 745 | 746 | #fc = Dense(classes, activation='softmax', name='output1', 747 | # kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 748 | 749 | def forward(self, X): 750 | # spectral 751 | x11 = self.conv11(X) 752 | #print('x11', x11.shape) 753 | x12 = self.batch_norm11(x11) 754 | x12 = self.conv12(x12) 755 | #print('x12', x12.shape) 756 | 757 | x13 = torch.cat((x11, x12), dim=1) 758 | #print('x13', x13.shape) 759 | x13 = self.batch_norm12(x13) 760 | x13 = self.conv13(x13) 761 | #print('x13', x13.shape) 762 | 763 | x14 = torch.cat((x11, x12, x13), dim=1) 764 | x14 = self.batch_norm13(x14) 765 | x14 = self.conv14(x14) 766 | 767 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 768 | # print('x15', x15.shape) 769 | 770 | x16 = self.batch_norm14(x15) 771 | x16 = self.conv15(x16) 772 | #print('x16', x16.shape) # 7*7*97, 60 773 | 774 | #print('x16', x16.shape) 775 | # 光谱注意力通道 776 | x1 = self.attention_spectral(x16) 777 | x1 = torch.mul(x1, x16) 778 | 779 | 780 | # spatial 781 | #print('x', X.shape) 782 | x21 = self.conv21(X) 783 | x22 = self.batch_norm21(x21) 784 | x22 = self.conv22(x22) 785 | 786 | x23 = torch.cat((x21, x22), dim=1) 787 | x23 = self.batch_norm22(x23) 788 | x23 = self.conv23(x23) 789 | 790 | x24 = torch.cat((x21, x22, x23), dim=1) 791 | x24 = self.batch_norm23(x24) 792 | x24 = self.conv24(x24) 793 | 794 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 795 | #print('x25', x25.shape) 796 | # x25 = x25.permute(0, 4, 2, 3, 1) 797 | #print('x25', x25.shape) 798 | 799 | # 空间注意力机制 800 | x2 = self.attention_spatial(x25) 801 | x2 = torch.mul(x2, x25) 802 | 803 | # model1 804 | x1 = self.global_pooling(x1) 805 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 806 | x2= self.global_pooling(x2) 807 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 808 | 809 | x_pre = torch.cat((x1, x2), dim=1) 810 | #print('x_pre', x_pre.shape) 811 | 812 | # model2 813 | # x1 = torch.mul(x2, x16) 814 | # x2 = torch.mul(x2, x25) 815 | # x_pre = x1 + x2 816 | # 817 | # 818 | # x_pre = x_pre.view(x_pre.shape[0], -1) 819 | output = self.full_connection(x_pre) 820 | # output = self.fc(x_pre) 821 | return output 822 | 823 | 824 | class DBDA_network_simplified(nn.Module): 825 | def __init__(self, band, classes): 826 | super(DBDA_network_simplified, self).__init__() 827 | 828 | # spectral branch 829 | 830 | self.conv11 = nn.Conv3d(in_channels=1, out_channels=24, 831 | kernel_size=(7, 7, 7), stride=(1, 1, 2)) 832 | # Dense block 833 | self.batch_norm11 = nn.Sequential( 834 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 835 | nn.ReLU(inplace=True) 836 | ) 837 | self.conv12 = nn.Conv3d(in_channels=24, out_channels=24, padding=(0, 0, 3), 838 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 839 | self.batch_norm12 = nn.Sequential( 840 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 841 | nn.ReLU(inplace=True) 842 | ) 843 | self.conv13 = nn.Conv3d(in_channels=48, out_channels=24, padding=(0, 0, 3), 844 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 845 | self.batch_norm13 = nn.Sequential( 846 | nn.BatchNorm3d(72, eps=0.001, momentum=0.1, affine=True), 
847 | nn.ReLU(inplace=True) 848 | ) 849 | self.conv14 = nn.Conv3d(in_channels=72, out_channels=24, padding=(0, 0, 3), 850 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 851 | self.batch_norm14 = nn.Sequential( 852 | nn.BatchNorm3d(96, eps=0.001, momentum=0.1, affine=True), 853 | nn.ReLU(inplace=True) 854 | ) 855 | kernel_3d = math.floor((band - 6) / 2) 856 | self.conv15 = nn.Conv3d(in_channels=96, out_channels=60, 857 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 858 | 859 | #注意力机制模块 860 | 861 | #self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 862 | #self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 863 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 864 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 865 | 866 | self.shared_mlp = nn.Sequential( 867 | nn.Conv3d(in_channels=60, out_channels=30, 868 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 869 | nn.Conv3d(in_channels=30, out_channels=60, 870 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 871 | ) 872 | 873 | self.activation1 = nn.Sigmoid() 874 | 875 | 876 | # Spatial Branch 877 | self.conv21 = nn.Conv3d(in_channels=1, out_channels=24, 878 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 879 | # Dense block 880 | self.batch_norm21 = nn.Sequential( 881 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 882 | nn.ReLU(inplace=True) 883 | ) 884 | self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0), 885 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 886 | self.batch_norm22 = nn.Sequential( 887 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 888 | nn.ReLU(inplace=True) 889 | ) 890 | self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0), 891 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 892 | self.batch_norm23 = nn.Sequential( 893 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 894 | nn.ReLU(inplace=True) 895 | ) 896 | self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0), 897 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 898 | 899 | # 注意力机制模块 900 | 901 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 60)) 902 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 903 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 904 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 905 | 906 | self.conv25 = nn.Sequential( 907 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 908 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 909 | nn.Sigmoid() 910 | ) 911 | 912 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 913 | self.full_connection = nn.Sequential( 914 | nn.Linear(120, classes) # , 915 | # nn.Softmax() 916 | ) 917 | 918 | self.attention_spectral = CAM_Module(60) 919 | self.attention_spatial = PAM_Module(60) 920 | 921 | #fc = Dense(classes, activation='softmax', name='output1', 922 | # kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 923 | 924 | def forward(self, X): 925 | # spectral 926 | x11 = self.conv11(X) 927 | #print('x11', x11.shape) 928 | x12 = self.batch_norm11(x11) 929 | x12 = self.conv12(x12) 930 | #print('x12', x12.shape) 931 | 932 | x13 = torch.cat((x11, x12), dim=1) 933 | #print('x13', x13.shape) 934 | x13 = self.batch_norm12(x13) 935 | x13 = self.conv13(x13) 936 | #print('x13', x13.shape) 937 | 938 | x14 = torch.cat((x11, x12, x13), dim=1) 939 | x14 = self.batch_norm13(x14) 940 | x14 = self.conv14(x14) 941 | 942 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 943 | # print('x15', x15.shape) 944 | 945 | x16 = self.batch_norm14(x15) 946 | x16 = self.conv15(x16) 947 | #print('x16', x16.shape) # 7*7*97, 60 948 | 
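# Shape sketch for this simplified variant (hypothetical input: band = 200 and
# 9x9 patches, i.e. X of shape (B, 1, 9, 9, 200)): the (7, 7, 7) stem with
# stride (1, 1, 2) yields (B, 24, 3, 3, 97), since floor((200 - 7) / 2) + 1 = 97;
# the padded (1, 1, 7) dense convs preserve the 97, and conv15 with
# kernel_3d = floor((200 - 6) / 2) = 97 collapses the spectral axis, so x16 is
# (B, 60, 3, 3, 1) entering the channel attention below.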
949 | #print('x16', x16.shape) 950 | # 光谱注意力通道 951 | x1 = self.attention_spectral(x16) 952 | x1 = torch.mul(x1, x16) 953 | 954 | 955 | # spatial 956 | #print('x', X.shape) 957 | x21 = self.conv21(X) 958 | x22 = self.batch_norm21(x21) 959 | x22 = self.conv22(x22) 960 | 961 | x23 = torch.cat((x21, x22), dim=1) 962 | x23 = self.batch_norm22(x23) 963 | x23 = self.conv23(x23) 964 | 965 | x24 = torch.cat((x21, x22, x23), dim=1) 966 | x24 = self.batch_norm23(x24) 967 | x24 = self.conv24(x24) 968 | 969 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 970 | #print('x25', x25.shape) 971 | # x25 = x25.permute(0, 4, 2, 3, 1) 972 | #print('x25', x25.shape) 973 | 974 | # 空间注意力机制 975 | x2 = self.attention_spatial(x25) 976 | x2 = torch.mul(x2, x25) 977 | 978 | # model1 979 | x1 = self.global_pooling(x1) 980 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 981 | x2= self.global_pooling(x2) 982 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 983 | 984 | x_pre = torch.cat((x1, x2), dim=1) 985 | #print('x_pre', x_pre.shape) 986 | 987 | # model2 988 | # x1 = torch.mul(x2, x16) 989 | # x2 = torch.mul(x2, x25) 990 | # x_pre = x1 + x2 991 | # 992 | # 993 | # x_pre = x_pre.view(x_pre.shape[0], -1) 994 | output = self.full_connection(x_pre) 995 | # output = self.fc(x_pre) 996 | return output 997 | 998 | class DBDA_network_conv(nn.Module): 999 | def __init__(self, band, classes): 1000 | super(DBDA_network_conv, self).__init__() 1001 | 1002 | # spectral branch 1003 | 1004 | self.conv11 = nn.Conv3d(in_channels=1, out_channels=24, 1005 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 1006 | # Dense block 1007 | self.batch_norm11 = nn.Sequential( 1008 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 1009 | nn.ReLU(inplace=True) 1010 | ) 1011 | self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3), 1012 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1013 | self.batch_norm12 = nn.Sequential( 1014 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 1015 | nn.ReLU(inplace=True) 1016 | ) 1017 | self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3), 1018 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1019 | self.batch_norm13 = nn.Sequential( 1020 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 1021 | nn.ReLU(inplace=True) 1022 | ) 1023 | self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3), 1024 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1025 | self.batch_norm14 = nn.Sequential( 1026 | nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True), 1027 | nn.ReLU(inplace=True) 1028 | ) 1029 | kernel_3d = math.floor((band - 6) / 2) 1030 | self.conv15 = nn.Conv3d(in_channels=60, out_channels=60, 1031 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 1032 | 1033 | #注意力机制模块 1034 | 1035 | #self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 1036 | #self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 1037 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 1038 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 1039 | 1040 | self.shared_mlp = nn.Sequential( 1041 | nn.Conv3d(in_channels=60, out_channels=30, 1042 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 1043 | nn.Conv3d(in_channels=30, out_channels=60, 1044 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 1045 | ) 1046 | 1047 | self.activation1 = nn.Sigmoid() 1048 | 1049 | 1050 | # Spatial Branch 1051 | self.conv21 = nn.Conv3d(in_channels=1, out_channels=24, 1052 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 1053 | # Dense block 1054 | self.batch_norm21 = nn.Sequential( 1055 | 
nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 1056 | nn.ReLU(inplace=True) 1057 | ) 1058 | self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0), 1059 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1060 | self.batch_norm22 = nn.Sequential( 1061 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 1062 | nn.ReLU(inplace=True) 1063 | ) 1064 | self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0), 1065 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1066 | self.batch_norm23 = nn.Sequential( 1067 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 1068 | nn.ReLU(inplace=True) 1069 | ) 1070 | self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0), 1071 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1072 | 1073 | # 注意力机制模块 1074 | 1075 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 60)) 1076 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 1077 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 1078 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 1079 | 1080 | self.conv25 = nn.Sequential( 1081 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 1082 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 1083 | nn.Sigmoid() 1084 | ) 1085 | 1086 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 1087 | self.full_connection = nn.Sequential( 1088 | # nn.Dropout(p=0.5), 1089 | nn.Linear(120, classes) # , 1090 | # nn.Softmax() 1091 | ) 1092 | 1093 | self.attention_spectral = CAM_Module(60) 1094 | self.attention_spatial = PAM_Module(60) 1095 | 1096 | #fc = Dense(classes, activation='softmax', name='output1', 1097 | # kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 1098 | 1099 | def forward(self, X): 1100 | # spectral 1101 | x11 = self.conv11(X) 1102 | #print('x11', x11.shape) 1103 | x12 = self.batch_norm11(x11) 1104 | x12 = self.conv12(x12) 1105 | #print('x12', x12.shape) 1106 | 1107 | x13 = torch.cat((x11, x12), dim=1) 1108 | #print('x13', x13.shape) 1109 | x13 = self.batch_norm12(x13) 1110 | x13 = self.conv13(x13) 1111 | #print('x13', x13.shape) 1112 | 1113 | x14 = torch.cat((x11, x12, x13), dim=1) 1114 | x14 = self.batch_norm13(x14) 1115 | x14 = self.conv14(x14) 1116 | 1117 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 1118 | # print('x15', x15.shape) 1119 | 1120 | x16 = self.batch_norm14(x15) 1121 | x16 = self.conv15(x16) 1122 | #print('x16', x16.shape) # 7*7*97, 60 1123 | 1124 | #print('x16', x16.shape) 1125 | # 光谱注意力通道 1126 | x1 = self.attention_spectral(x16) 1127 | x1 = torch.mul(x1, x16) 1128 | 1129 | 1130 | # spatial 1131 | #print('x', X.shape) 1132 | x21 = self.conv21(X) 1133 | x22 = self.batch_norm21(x21) 1134 | x22 = self.conv22(x22) 1135 | 1136 | x23 = torch.cat((x21, x22), dim=1) 1137 | x23 = self.batch_norm22(x23) 1138 | x23 = self.conv23(x23) 1139 | 1140 | x24 = torch.cat((x21, x22, x23), dim=1) 1141 | x24 = self.batch_norm23(x24) 1142 | x24 = self.conv24(x24) 1143 | 1144 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 1145 | #print('x25', x25.shape) 1146 | # x25 = x25.permute(0, 4, 2, 3, 1) 1147 | #print('x25', x25.shape) 1148 | 1149 | # 空间注意力机制 1150 | x2 = self.attention_spatial(x25) 1151 | x2 = torch.mul(x2, x25) 1152 | 1153 | # model1 1154 | x1 = self.global_pooling(x1) 1155 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 1156 | x2= self.global_pooling(x2) 1157 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 1158 | 1159 | x_pre = torch.cat((x1, x2), dim=1) 1160 | #print('x_pre', x_pre.shape) 1161 | 1162 | # model2 1163 | # x1 = torch.mul(x2, x16) 1164 | # x2 = torch.mul(x2, x25) 1165 | # 
x_pre = x1 + x2 1166 | # 1167 | # 1168 | # x_pre = x_pre.view(x_pre.shape[0], -1) 1169 | output = self.full_connection(x_pre) 1170 | # output = self.fc(x_pre) 1171 | return output 1172 | 1173 | class DBZA_network(nn.Module): # 删除注意力模块 1174 | def __init__(self, band, classes): 1175 | super(DBZA_network, self).__init__() 1176 | 1177 | # spectral branch 1178 | self.name = 'DBZA' 1179 | self.conv11 = nn.Conv3d(in_channels=1, out_channels=24, 1180 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 1181 | # Dense block 1182 | self.batch_norm11 = nn.Sequential( 1183 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 1184 | nn.ReLU(inplace=True) 1185 | ) 1186 | self.conv12 = nn.Conv3d(in_channels=24, out_channels=24, padding=(0, 0, 3), 1187 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1188 | self.batch_norm12 = nn.Sequential( 1189 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 1190 | nn.ReLU(inplace=True) 1191 | ) 1192 | self.conv13 = nn.Conv3d(in_channels=48, out_channels=24, padding=(0, 0, 3), 1193 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1194 | self.batch_norm13 = nn.Sequential( 1195 | nn.BatchNorm3d(72, eps=0.001, momentum=0.1, affine=True), 1196 | nn.ReLU(inplace=True) 1197 | ) 1198 | self.conv14 = nn.Conv3d(in_channels=72, out_channels=24, padding=(0, 0, 3), 1199 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1200 | self.batch_norm14 = nn.Sequential( 1201 | nn.BatchNorm3d(96, eps=0.001, momentum=0.1, affine=True), 1202 | nn.ReLU(inplace=True) 1203 | ) 1204 | kernel_3d = math.floor((band - 6) / 2) 1205 | self.conv15 = nn.Conv3d(in_channels=96, out_channels=60, 1206 | kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1)) # kernel size随数据变化 1207 | 1208 | #注意力机制模块 1209 | 1210 | #self.max_pooling1 = nn.MaxPool3d(kernel_size=(7, 7, 1)) 1211 | #self.avg_pooling1 = nn.AvgPool3d(kernel_size=(7, 7, 1)) 1212 | self.max_pooling1 = nn.AdaptiveAvgPool3d(1) 1213 | self.avg_pooling1 = nn.AdaptiveAvgPool3d(1) 1214 | 1215 | self.shared_mlp = nn.Sequential( 1216 | nn.Conv3d(in_channels=60, out_channels=30, 1217 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 1218 | nn.Conv3d(in_channels=30, out_channels=60, 1219 | kernel_size=(1, 1, 1), stride=(1, 1, 1)), 1220 | ) 1221 | 1222 | self.activation1 = nn.Sigmoid() 1223 | 1224 | 1225 | # Spatial Branch 1226 | self.conv21 = nn.Conv3d(in_channels=1, out_channels=24, 1227 | kernel_size=(1, 1, band), stride=(1, 1, 1)) 1228 | # Dense block 1229 | self.batch_norm21 = nn.Sequential( 1230 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), 1231 | nn.ReLU(inplace=True) 1232 | ) 1233 | self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0), 1234 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1235 | self.batch_norm22 = nn.Sequential( 1236 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 1237 | nn.ReLU(inplace=True) 1238 | ) 1239 | self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0), 1240 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1241 | self.batch_norm23 = nn.Sequential( 1242 | nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True), 1243 | nn.ReLU(inplace=True) 1244 | ) 1245 | self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0), 1246 | kernel_size=(3, 3, 1), stride=(1, 1, 1)) 1247 | 1248 | # 注意力机制模块 1249 | 1250 | # self.max_pooling2 = nn.MaxPool3d(kernel_size=(1, 1, 60)) 1251 | # self.avg_pooling2 = nn.AvgPool3d(kernel_size=(1, 1, 60)) 1252 | # self.max_pooling2 = nn.AdaptiveAvgPool3d(1) 1253 | # self.avg_pooling2 = nn.AdaptiveAvgPool3d(1) 1254 | 1255 | self.conv25 = nn.Sequential( 
1256 | nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0), 1257 | kernel_size=(3, 3, 2), stride=(1, 1, 1)), 1258 | nn.Sigmoid() 1259 | ) 1260 | 1261 | self.global_pooling = nn.AdaptiveAvgPool3d(1) 1262 | self.full_connection = nn.Sequential( 1263 | nn.Linear(120, classes) # , 1264 | # nn.Softmax() 1265 | ) 1266 | 1267 | self.attention_spectral = CAM_Module(60) 1268 | self.attention_spatial = PAM_Module(60) 1269 | 1270 | #fc = Dense(classes, activation='softmax', name='output1', 1271 | # kernel_initializer=RandomNormal(mean=0.0, stddev=0.01)) 1272 | 1273 | def forward(self, X): 1274 | # spectral 1275 | x11 = self.conv11(X) 1276 | #print('x11', x11.shape) 1277 | x12 = self.batch_norm11(x11) 1278 | x12 = self.conv12(x12) 1279 | #print('x12', x12.shape) 1280 | 1281 | x13 = torch.cat((x11, x12), dim=1) 1282 | #print('x13', x13.shape) 1283 | x13 = self.batch_norm12(x13) 1284 | x13 = self.conv13(x13) 1285 | #print('x13', x13.shape) 1286 | 1287 | x14 = torch.cat((x11, x12, x13), dim=1) 1288 | x14 = self.batch_norm13(x14) 1289 | x14 = self.conv14(x14) 1290 | 1291 | x15 = torch.cat((x11, x12, x13, x14), dim=1) 1292 | # print('x15', x15.shape) 1293 | 1294 | x16 = self.batch_norm14(x15) 1295 | x16 = self.conv15(x16) 1296 | #print('x16', x16.shape) # 7*7*97, 60 1297 | 1298 | #print('x16', x16.shape) 1299 | # 光谱注意力通道 1300 | # x1 = self.attention_spectral(x16) 1301 | # x1 = torch.mul(x1, x16) 1302 | 1303 | 1304 | # spatial 1305 | #print('x', X.shape) 1306 | x21 = self.conv21(X) 1307 | x22 = self.batch_norm21(x21) 1308 | x22 = self.conv22(x22) 1309 | 1310 | x23 = torch.cat((x21, x22), dim=1) 1311 | x23 = self.batch_norm22(x23) 1312 | x23 = self.conv23(x23) 1313 | 1314 | x24 = torch.cat((x21, x22, x23), dim=1) 1315 | x24 = self.batch_norm23(x24) 1316 | x24 = self.conv24(x24) 1317 | 1318 | x25 = torch.cat((x21, x22, x23, x24), dim=1) 1319 | #print('x25', x25.shape) 1320 | # x25 = x25.permute(0, 4, 2, 3, 1) 1321 | #print('x25', x25.shape) 1322 | 1323 | # 空间注意力机制 1324 | # x2 = self.attention_spatial(x25) 1325 | # x2 = torch.mul(x2, x25) 1326 | 1327 | # model1 1328 | x1 = self.global_pooling(x16) 1329 | x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1) 1330 | x2= self.global_pooling(x25) 1331 | x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1) 1332 | 1333 | x_pre = torch.cat((x1, x2), dim=1) 1334 | #print('x_pre', x_pre.shape) 1335 | 1336 | # model2 1337 | # x1 = torch.mul(x2, x16) 1338 | # x2 = torch.mul(x2, x25) 1339 | # x_pre = x1 + x2 1340 | # 1341 | # 1342 | # x_pre = x_pre.view(x_pre.shape[0], -1) 1343 | output = self.full_connection(x_pre) 1344 | # output = self.fc(x_pre) 1345 | return output # # #qu 1346 | 1347 | 1348 | class FDSSC_network(nn.Module): 1349 | def __init__(self, band, classes): 1350 | super(FDSSC_network, self).__init__() 1351 | 1352 | # spectral branch 1353 | self.name = 'FDSSC' 1354 | self.conv1 = nn.Conv3d(in_channels=1, out_channels=24, 1355 | kernel_size=(1, 1, 7), stride=(1, 1, 2)) 1356 | # Dense block 1357 | self.batch_norm1 = nn.Sequential( 1358 | nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True), # 动量默认值为0.1 1359 | nn.PReLU() 1360 | ) 1361 | self.conv2 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3), 1362 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1363 | self.batch_norm2 = nn.Sequential( 1364 | nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True), 1365 | nn.PReLU() 1366 | ) 1367 | self.conv3 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3), 1368 | kernel_size=(1, 1, 7), stride=(1, 1, 1)) 1369 | self.batch_norm3 = nn.Sequential( 1370 | 
class FDSSC_network(nn.Module):
    def __init__(self, band, classes):
        super(FDSSC_network, self).__init__()

        # spectral branch
        self.name = 'FDSSC'
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=24,
                               kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm1 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.PReLU()
        )
        self.conv2 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm2 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )
        self.conv3 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm3 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )
        self.conv4 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm4 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        kernel_3d = math.ceil((band - 6) / 2)
        # print(kernel_3d)
        self.conv5 = nn.Conv3d(in_channels=60, out_channels=200, padding=(0, 0, 0),
                               kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))

        self.batch_norm5 = nn.Sequential(
            nn.BatchNorm3d(1, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )
        self.conv6 = nn.Conv3d(in_channels=1, out_channels=24, padding=(1, 1, 0),
                               kernel_size=(3, 3, 200), stride=(1, 1, 1))
        self.batch_norm6 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.PReLU()
        )
        self.conv7 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm7 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )
        self.conv8 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm8 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )
        self.conv9 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                               kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm9 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            nn.PReLU()
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(60, classes)
            # nn.Softmax()
        )

    def forward(self, X):
        # spectral
        x1 = self.conv1(X)
        # print('x1', x1.shape)
        x2 = self.batch_norm1(x1)
        x2 = self.conv2(x2)
        # print('x2', x2.shape)

        x3 = torch.cat((x1, x2), dim=1)
        x3 = self.batch_norm2(x3)
        x3 = self.conv3(x3)
        # print('x3', x3.shape)

        x4 = torch.cat((x1, x2, x3), dim=1)
        x4 = self.batch_norm3(x4)
        x4 = self.conv4(x4)

        x5 = torch.cat((x1, x2, x3, x4), dim=1)
        # print('x5', x5.shape)

        x6 = self.batch_norm4(x5)
        x6 = self.conv5(x6)
        # print('x6', x6.shape)

        # move the 200 spectral feature maps onto the depth axis for the spatial block
        x6 = x6.permute(0, 4, 2, 3, 1)
        # print(x6.shape)

        x7 = self.batch_norm5(x6)
        x7 = self.conv6(x7)

        x8 = self.batch_norm6(x7)
        x8 = self.conv7(x8)

        x9 = torch.cat((x7, x8), dim=1)
        x9 = self.batch_norm7(x9)
        x9 = self.conv8(x9)

        x10 = torch.cat((x7, x8, x9), dim=1)
        x10 = self.batch_norm8(x10)
        x10 = self.conv9(x10)

        x10 = torch.cat((x7, x8, x9, x10), dim=1)
        x10 = self.batch_norm9(x10)
        x10 = self.global_pooling(x10)
        x10 = x10.view(x10.size(0), -1)

        output = self.full_connection(x10)
        # output = self.fc(x_pre)
        return output

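# Illustrative shape check for FDSSC_network (an addition; not in the original
# file). The 3x3 spatial convolutions use padding (1, 1, 0) and the pooling is
# adaptive, so the patch size is flexible; 9x9 is commonly used for FDSSC:
def _check_fdssc(band=200, classes=16):
    net = FDSSC_network(band, classes)
    x = torch.randn(2, 1, 9, 9, band)
    return net(x).shape  # expected: torch.Size([2, classes])
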
class DBDA_network_drop(nn.Module):
    def __init__(self, band, classes):
        super(DBDA_network_drop, self).__init__()

        # spectral branch
        self.name = 'DBDA'
        self.conv11 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm11 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.ReLU(inplace=True)
        )
        self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm12 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm13 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm14 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        kernel_3d = math.floor((band - 6) / 2)
        self.conv15 = nn.Conv3d(in_channels=60, out_channels=60,
                                kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))  # kernel size depends on the number of bands

        # Spatial Branch
        self.conv21 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, band), stride=(1, 1, 1))
        # Dense block
        self.batch_norm21 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm22 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm23 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))

        self.conv25 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0),
                      kernel_size=(3, 3, 2), stride=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.batch_norm_spectral = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5)
        )
        self.batch_norm_spatial = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5)
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(120, classes)  # ,
            # nn.Softmax()
        )

        self.attention_spectral = CAM_Module(60)
        self.attention_spatial = PAM_Module(60)

        # fc = Dense(classes, activation='softmax', name='output1',
        #            kernel_initializer=RandomNormal(mean=0.0, stddev=0.01))

    def forward(self, X):
        # spectral
        x11 = self.conv11(X)
        # print('x11', x11.shape)
        x12 = self.batch_norm11(x11)
        x12 = self.conv12(x12)
        # print('x12', x12.shape)

        x13 = torch.cat((x11, x12), dim=1)
        x13 = self.batch_norm12(x13)
        x13 = self.conv13(x13)
        # print('x13', x13.shape)

        x14 = torch.cat((x11, x12, x13), dim=1)
        x14 = self.batch_norm13(x14)
        x14 = self.conv14(x14)

        x15 = torch.cat((x11, x12, x13, x14), dim=1)
        # print('x15', x15.shape)

        x16 = self.batch_norm14(x15)
        x16 = self.conv15(x16)
        # print('x16', x16.shape)  # 7*7*97, 60

        # spectral attention module
        x1 = self.attention_spectral(x16)
        x1 = torch.mul(x1, x16)

        # spatial
        # print('x', X.shape)
        x21 = self.conv21(X)
        x22 = self.batch_norm21(x21)
        x22 = self.conv22(x22)

        x23 = torch.cat((x21, x22), dim=1)
        x23 = self.batch_norm22(x23)
        x23 = self.conv23(x23)

        x24 = torch.cat((x21, x22, x23), dim=1)
        x24 = self.batch_norm23(x24)
        x24 = self.conv24(x24)

        x25 = torch.cat((x21, x22, x23, x24), dim=1)
        # print('x25', x25.shape)
        # x25 = x25.permute(0, 4, 2, 3, 1)

        # spatial attention module
        x2 = self.attention_spatial(x25)
        x2 = torch.mul(x2, x25)

        # model1
        x1 = self.batch_norm_spectral(x1)
        x1 = self.global_pooling(x1)
        x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1)
        x2 = self.batch_norm_spatial(x2)
        x2 = self.global_pooling(x2)
        x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1)

        x_pre = torch.cat((x1, x2), dim=1)
        # print('x_pre', x_pre.shape)

        # model2
        # x1 = torch.mul(x2, x16)
        # x2 = torch.mul(x2, x25)
        # x_pre = x1 + x2
        #
        #
        # x_pre = x_pre.view(x_pre.shape[0], -1)
        output = self.full_connection(x_pre)
        # output = self.fc(x_pre)
        return output

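# DBDA_network_drop is the full dual-branch dual-attention model, with a
# BatchNorm + ReLU + Dropout(0.5) head applied to each branch before pooling.
# Minimal smoke test (an addition; not in the original file), assuming 7x7
# patches and the CAM_Module/PAM_Module definitions used above:
def _check_dbda_drop(band=200, classes=16):
    net = DBDA_network_drop(band, classes)
    x = torch.randn(2, 1, 7, 7, band)
    return net(x).shape  # expected: torch.Size([2, classes])
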
class DBDA_network_MISH(nn.Module):
    def __init__(self, band, classes):
        super(DBDA_network_MISH, self).__init__()

        # spectral branch
        self.name = 'DBDA_MISH'
        self.conv11 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm11 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new()
            # swish()
            mish()
        )
        self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm12 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm13 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm14 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        kernel_3d = math.floor((band - 6) / 2)
        self.conv15 = nn.Conv3d(in_channels=60, out_channels=60,
                                kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))  # kernel size depends on the number of bands

        # Spatial Branch
        self.conv21 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, band), stride=(1, 1, 1))
        # Dense block
        self.batch_norm21 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm22 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm23 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))

        self.conv25 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0),
                      kernel_size=(3, 3, 2), stride=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.batch_norm_spectral = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )
        self.batch_norm_spatial = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(120, classes)  # ,
            # nn.Softmax()
        )

        self.attention_spectral = CAM_Module(60)
        self.attention_spatial = PAM_Module(60)

        # fc = Dense(classes, activation='softmax', name='output1',
        #            kernel_initializer=RandomNormal(mean=0.0, stddev=0.01))

    def forward(self, X):
        # spectral
        x11 = self.conv11(X)
        # print('x11', x11.shape)
        x12 = self.batch_norm11(x11)
        x12 = self.conv12(x12)
        # print('x12', x12.shape)

        x13 = torch.cat((x11, x12), dim=1)
        x13 = self.batch_norm12(x13)
        x13 = self.conv13(x13)
        # print('x13', x13.shape)

        x14 = torch.cat((x11, x12, x13), dim=1)
        x14 = self.batch_norm13(x14)
        x14 = self.conv14(x14)

        x15 = torch.cat((x11, x12, x13, x14), dim=1)
        # print('x15', x15.shape)

        x16 = self.batch_norm14(x15)
        x16 = self.conv15(x16)
        # print('x16', x16.shape)  # 7*7*97, 60

        # spectral attention module
        x1 = self.attention_spectral(x16)
        x1 = torch.mul(x1, x16)

        # spatial
        # print('x', X.shape)
        x21 = self.conv21(X)
        x22 = self.batch_norm21(x21)
        x22 = self.conv22(x22)

        x23 = torch.cat((x21, x22), dim=1)
        x23 = self.batch_norm22(x23)
        x23 = self.conv23(x23)

        x24 = torch.cat((x21, x22, x23), dim=1)
        x24 = self.batch_norm23(x24)
        x24 = self.conv24(x24)

        x25 = torch.cat((x21, x22, x23, x24), dim=1)
        # print('x25', x25.shape)
        # x25 = x25.permute(0, 4, 2, 3, 1)

        # spatial attention module
        x2 = self.attention_spatial(x25)
        x2 = torch.mul(x2, x25)

        # model1
        x1 = self.batch_norm_spectral(x1)
        x1 = self.global_pooling(x1)
        x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1)
        x2 = self.batch_norm_spatial(x2)
        x2 = self.global_pooling(x2)
        x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1)

        x_pre = torch.cat((x1, x2), dim=1)
        # print('x_pre', x_pre.shape)

        # model2
        # x1 = torch.mul(x2, x16)
        # x2 = torch.mul(x2, x25)
        # x_pre = x1 + x2
        #
        #
        # x_pre = x_pre.view(x_pre.shape[0], -1)
        output = self.full_connection(x_pre)
        # output = self.fc(x_pre)
        return output

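# The MISH variant above swaps every ReLU/PReLU for the Mish activation,
# x * tanh(softplus(x)). The mish() module itself is expected to come from
# this repo's global_module/activation.py; a minimal stand-in with the same
# behaviour (shown only for reference, commented out so it does not shadow
# the imported class) would be:
#
#     import torch.nn.functional as F
#     class mish(nn.Module):
#         def forward(self, x):
#             return x * torch.tanh(F.softplus(x))
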
class SSRN_network(nn.Module):
    def __init__(self, band, classes):
        super(SSRN_network, self).__init__()
        self.name = 'SSRN'
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=24,
                               kernel_size=(1, 1, 7), stride=(1, 1, 2))
        self.batch_norm1 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            nn.ReLU(inplace=True)
        )

        self.res_net1 = Residual(24, 24, (1, 1, 7), (0, 0, 3))
        self.res_net2 = Residual(24, 24, (1, 1, 7), (0, 0, 3))
        self.res_net3 = Residual(24, 24, (3, 3, 1), (1, 1, 0))
        self.res_net4 = Residual(24, 24, (3, 3, 1), (1, 1, 0))

        kernel_3d = math.ceil((band - 6) / 2)

        self.conv2 = nn.Conv3d(in_channels=24, out_channels=128, padding=(0, 0, 0),
                               kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))
        self.batch_norm2 = nn.Sequential(
            nn.BatchNorm3d(128, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Conv3d(in_channels=1, out_channels=24, padding=(0, 0, 0),
                               kernel_size=(3, 3, 128), stride=(1, 1, 1))
        self.batch_norm3 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            nn.ReLU(inplace=True)
        )

        self.avg_pooling = nn.AvgPool3d(kernel_size=(5, 5, 1))
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(24, classes)  # ,
            # nn.Softmax()
        )

    def forward(self, X):
        x1 = self.batch_norm1(self.conv1(X))
        # print('x1', x1.shape)

        x2 = self.res_net1(x1)
        x2 = self.res_net2(x2)
        x2 = self.batch_norm2(self.conv2(x2))
        x2 = x2.permute(0, 4, 2, 3, 1)
        x2 = self.batch_norm3(self.conv3(x2))

        x3 = self.res_net3(x2)
        x3 = self.res_net4(x3)
        x4 = self.avg_pooling(x3)
        x4 = x4.view(x4.size(0), -1)
        # print(x4.shape)
        return self.full_connection(x4)

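# Unlike the adaptive-pooling networks above, SSRN_network ends in a fixed
# AvgPool3d over a 5x5 window, so it expects 7x7 input patches: the unpadded
# 3x3 step in conv3 reduces 7x7 to 5x5. Smoke test (an addition; not in the
# original file, and it assumes the Residual block defined elsewhere in this
# module):
def _check_ssrn(band=200, classes=16):
    net = SSRN_network(band, classes)
    x = torch.randn(2, 1, 7, 7, band)
    return net(x).shape  # expected: torch.Size([2, classes])
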
class DBDA_network_SPECTRAL(nn.Module):
    def __init__(self, band, classes):
        super(DBDA_network_SPECTRAL, self).__init__()

        # spectral branch
        self.name = 'DBDA_SPECTRAL'
        self.conv11 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm11 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new()
            # swish()
            mish()
        )
        self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm12 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm13 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm14 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        kernel_3d = math.floor((band - 6) / 2)
        self.conv15 = nn.Conv3d(in_channels=60, out_channels=60,
                                kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))  # kernel size depends on the number of bands

        # Spatial Branch
        self.conv21 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, band), stride=(1, 1, 1))
        # Dense block
        self.batch_norm21 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm22 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm23 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))

        self.conv25 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0),
                      kernel_size=(3, 3, 2), stride=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.batch_norm_spectral = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )
        self.batch_norm_spatial = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(120, classes)  # ,
            # nn.Softmax()
        )

        self.attention_spectral = CAM_Module(60)
        self.attention_spatial = PAM_Module(60)

        # fc = Dense(classes, activation='softmax', name='output1',
        #            kernel_initializer=RandomNormal(mean=0.0, stddev=0.01))

    def forward(self, X):
        # spectral
        x11 = self.conv11(X)
        # print('x11', x11.shape)
        x12 = self.batch_norm11(x11)
        x12 = self.conv12(x12)
        # print('x12', x12.shape)

        x13 = torch.cat((x11, x12), dim=1)
        x13 = self.batch_norm12(x13)
        x13 = self.conv13(x13)
        # print('x13', x13.shape)

        x14 = torch.cat((x11, x12, x13), dim=1)
        x14 = self.batch_norm13(x14)
        x14 = self.conv14(x14)

        x15 = torch.cat((x11, x12, x13, x14), dim=1)
        # print('x15', x15.shape)

        x16 = self.batch_norm14(x15)
        x16 = self.conv15(x16)
        # print('x16', x16.shape)  # 7*7*97, 60

        # spectral attention module
        x1 = self.attention_spectral(x16)
        x1 = torch.mul(x1, x16)

        # spatial
        # print('x', X.shape)
        x21 = self.conv21(X)
        x22 = self.batch_norm21(x21)
        x22 = self.conv22(x22)

        x23 = torch.cat((x21, x22), dim=1)
        x23 = self.batch_norm22(x23)
        x23 = self.conv23(x23)

        x24 = torch.cat((x21, x22, x23), dim=1)
        x24 = self.batch_norm23(x24)
        x24 = self.conv24(x24)

        x25 = torch.cat((x21, x22, x23, x24), dim=1)
        # print('x25', x25.shape)
        # x25 = x25.permute(0, 4, 2, 3, 1)

        # spatial attention module (bypassed: identity)
        x2 = x25

        # model1
        x1 = self.batch_norm_spectral(x1)
        x1 = self.global_pooling(x1)
        x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1)
        x2 = self.batch_norm_spatial(x2)
        x2 = self.global_pooling(x2)
        x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1)

        x_pre = torch.cat((x1, x2), dim=1)
        # print('x_pre', x_pre.shape)

        # model2
        # x1 = torch.mul(x2, x16)
        # x2 = torch.mul(x2, x25)
        # x_pre = x1 + x2
        #
        #
        # x_pre = x_pre.view(x_pre.shape[0], -1)
        output = self.full_connection(x_pre)
        # output = self.fc(x_pre)
        return output

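# Ablation note: despite the full dual-branch layout above, only the spectral
# (channel) attention is applied in DBDA_network_SPECTRAL's forward(); the
# spatial branch output x25 passes through unweighted (x2 = x25).
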
class DBDA_network_SPATIAL(nn.Module):
    def __init__(self, band, classes):
        super(DBDA_network_SPATIAL, self).__init__()

        # spectral branch
        self.name = 'DBDA_SPATIAL'
        self.conv11 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm11 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new()
            # swish()
            mish()
        )
        self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm12 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm13 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm14 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        kernel_3d = math.floor((band - 6) / 2)
        self.conv15 = nn.Conv3d(in_channels=60, out_channels=60,
                                kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))  # kernel size depends on the number of bands

        # Spatial Branch
        self.conv21 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, band), stride=(1, 1, 1))
        # Dense block
        self.batch_norm21 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm22 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm23 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))

        self.conv25 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0),
                      kernel_size=(3, 3, 2), stride=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.batch_norm_spectral = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )
        self.batch_norm_spatial = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(120, classes)  # ,
            # nn.Softmax()
        )

        self.attention_spectral = CAM_Module(60)
        self.attention_spatial = PAM_Module(60)

        # fc = Dense(classes, activation='softmax', name='output1',
        #            kernel_initializer=RandomNormal(mean=0.0, stddev=0.01))

    def forward(self, X):
        # spectral
        x11 = self.conv11(X)
        # print('x11', x11.shape)
        x12 = self.batch_norm11(x11)
        x12 = self.conv12(x12)
        # print('x12', x12.shape)

        x13 = torch.cat((x11, x12), dim=1)
        x13 = self.batch_norm12(x13)
        x13 = self.conv13(x13)
        # print('x13', x13.shape)

        x14 = torch.cat((x11, x12, x13), dim=1)
        x14 = self.batch_norm13(x14)
        x14 = self.conv14(x14)

        x15 = torch.cat((x11, x12, x13, x14), dim=1)
        # print('x15', x15.shape)

        x16 = self.batch_norm14(x15)
        x16 = self.conv15(x16)
        # print('x16', x16.shape)  # 7*7*97, 60

        # spectral attention module (bypassed: identity)
        x1 = x16

        # spatial
        # print('x', X.shape)
        x21 = self.conv21(X)
        x22 = self.batch_norm21(x21)
        x22 = self.conv22(x22)

        x23 = torch.cat((x21, x22), dim=1)
        x23 = self.batch_norm22(x23)
        x23 = self.conv23(x23)

        x24 = torch.cat((x21, x22, x23), dim=1)
        x24 = self.batch_norm23(x24)
        x24 = self.conv24(x24)

        x25 = torch.cat((x21, x22, x23, x24), dim=1)
        # print('x25', x25.shape)
        # x25 = x25.permute(0, 4, 2, 3, 1)

        # spatial attention module
        x2 = self.attention_spatial(x25)
        x2 = torch.mul(x2, x25)

        # model1
        x1 = self.batch_norm_spectral(x1)
        x1 = self.global_pooling(x1)
        x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1)
        x2 = self.batch_norm_spatial(x2)
        x2 = self.global_pooling(x2)
        x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1)

        x_pre = torch.cat((x1, x2), dim=1)
        # print('x_pre', x_pre.shape)

        # model2
        # x1 = torch.mul(x2, x16)
        # x2 = torch.mul(x2, x25)
        # x_pre = x1 + x2
        #
        #
        # x_pre = x_pre.view(x_pre.shape[0], -1)
        output = self.full_connection(x_pre)
        # output = self.fc(x_pre)
        return output

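# DBDA_network_SPATIAL is the mirror-image ablation: only the spatial (PAM)
# attention is applied, and the spectral branch passes through unweighted.
# Both ablations keep the Linear(120, classes) head, so they can be compared
# directly (sketch added; not in the original file):
def _check_ablations(band=200, classes=16):
    x = torch.randn(2, 1, 7, 7, band)
    nets = (DBDA_network_SPECTRAL(band, classes),
            DBDA_network_SPATIAL(band, classes))
    return [net(x).shape for net in nets]  # both torch.Size([2, classes])
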
# NOTE: this second DBZA_network definition shadows the ReLU-based one earlier
# in this file; only this mish-based, attention-free variant is visible at
# import time.
class DBZA_network(nn.Module):
    def __init__(self, band, classes):
        super(DBZA_network, self).__init__()

        # spectral branch
        self.name = 'DBZA'
        self.conv11 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, 7), stride=(1, 1, 2))
        # Dense block
        self.batch_norm11 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new()
            # swish()
            mish()
        )
        self.conv12 = nn.Conv3d(in_channels=24, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm12 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv13 = nn.Conv3d(in_channels=36, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm13 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv14 = nn.Conv3d(in_channels=48, out_channels=12, padding=(0, 0, 3),
                                kernel_size=(1, 1, 7), stride=(1, 1, 1))
        self.batch_norm14 = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        kernel_3d = math.floor((band - 6) / 2)
        self.conv15 = nn.Conv3d(in_channels=60, out_channels=60,
                                kernel_size=(1, 1, kernel_3d), stride=(1, 1, 1))  # kernel size depends on the number of bands

        # Spatial Branch
        self.conv21 = nn.Conv3d(in_channels=1, out_channels=24,
                                kernel_size=(1, 1, band), stride=(1, 1, 1))
        # Dense block
        self.batch_norm21 = nn.Sequential(
            nn.BatchNorm3d(24, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv22 = nn.Conv3d(in_channels=24, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm22 = nn.Sequential(
            nn.BatchNorm3d(36, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv23 = nn.Conv3d(in_channels=36, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))
        self.batch_norm23 = nn.Sequential(
            nn.BatchNorm3d(48, eps=0.001, momentum=0.1, affine=True),
            # gelu_new()
            # swish()
            mish()
        )
        self.conv24 = nn.Conv3d(in_channels=48, out_channels=12, padding=(1, 1, 0),
                                kernel_size=(3, 3, 1), stride=(1, 1, 1))

        self.conv25 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=1, padding=(1, 1, 0),
                      kernel_size=(3, 3, 2), stride=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.batch_norm_spectral = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )
        self.batch_norm_spatial = nn.Sequential(
            nn.BatchNorm3d(60, eps=0.001, momentum=0.1, affine=True),  # default momentum is 0.1
            # gelu_new(),
            # swish(),
            mish(),
            nn.Dropout(p=0.5)
        )

        self.global_pooling = nn.AdaptiveAvgPool3d(1)
        self.full_connection = nn.Sequential(
            # nn.Dropout(p=0.5),
            nn.Linear(120, classes)  # ,
            # nn.Softmax()
        )

        self.attention_spectral = CAM_Module(60)
        self.attention_spatial = PAM_Module(60)

        # fc = Dense(classes, activation='softmax', name='output1',
        #            kernel_initializer=RandomNormal(mean=0.0, stddev=0.01))

    def forward(self, X):
        # spectral
        x11 = self.conv11(X)
        # print('x11', x11.shape)
        x12 = self.batch_norm11(x11)
        x12 = self.conv12(x12)
        # print('x12', x12.shape)

        x13 = torch.cat((x11, x12), dim=1)
        x13 = self.batch_norm12(x13)
        x13 = self.conv13(x13)
        # print('x13', x13.shape)

        x14 = torch.cat((x11, x12, x13), dim=1)
        x14 = self.batch_norm13(x14)
        x14 = self.conv14(x14)

        x15 = torch.cat((x11, x12, x13, x14), dim=1)
        # print('x15', x15.shape)

        x16 = self.batch_norm14(x15)
        x16 = self.conv15(x16)
        # print('x16', x16.shape)  # 7*7*97, 60

        # spectral attention module (bypassed: identity)
        x1 = x16

        # spatial
        # print('x', X.shape)
        x21 = self.conv21(X)
        x22 = self.batch_norm21(x21)
        x22 = self.conv22(x22)

        x23 = torch.cat((x21, x22), dim=1)
        x23 = self.batch_norm22(x23)
        x23 = self.conv23(x23)

        x24 = torch.cat((x21, x22, x23), dim=1)
        x24 = self.batch_norm23(x24)
        x24 = self.conv24(x24)

        x25 = torch.cat((x21, x22, x23, x24), dim=1)
        # print('x25', x25.shape)
        # x25 = x25.permute(0, 4, 2, 3, 1)

        # spatial attention module (bypassed: identity)
        x2 = x25

        # model1
        x1 = self.batch_norm_spectral(x1)
        x1 = self.global_pooling(x1)
        x1 = x1.squeeze(-1).squeeze(-1).squeeze(-1)
        x2 = self.batch_norm_spatial(x2)
        x2 = self.global_pooling(x2)
        x2 = x2.squeeze(-1).squeeze(-1).squeeze(-1)

        x_pre = torch.cat((x1, x2), dim=1)
        # print('x_pre', x_pre.shape)

        # model2
        # x1 = torch.mul(x2, x16)
        # x2 = torch.mul(x2, x25)
        # x_pre = x1 + x2
        #
        #
        # x_pre = x_pre.view(x_pre.shape[0], -1)
        output = self.full_connection(x_pre)
        # output = self.fc(x_pre)
        return output

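# Smoke test for the mish-based DBZA baseline (an addition; not in the
# original file). Because of the duplicate class name, DBZA_network here
# resolves to the definition directly above:
def _check_dbza_mish(band=200, classes=16):
    net = DBZA_network(band, classes)
    x = torch.randn(2, 1, 7, 7, band)
    return net(x).shape  # expected: torch.Size([2, classes])
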
class svm_rbf:
    def __init__(self, data, label):
        self.name = 'SVM_RBF'
        self.trainx = data
        self.trainy = label

    def train(self):
        # coarse grid search: C in 2^-3 .. 2^9, gamma in 2^-5 .. 2^3 (log-spaced)
        cost = []
        gamma = []
        for i in range(-3, 10, 2):
            cost.append(np.power(2.0, i))
        for i in range(-5, 4, 2):
            gamma.append(np.power(2.0, i))

        parameters = {'C': cost, 'gamma': gamma}
        svm = SVC(verbose=0, kernel='rbf')
        clf = GridSearchCV(svm, parameters, cv=3)
        clf.fit(self.trainx, self.trainy)

        # print(clf.best_params_)
        bestc = clf.best_params_['C']
        bestg = clf.best_params_['gamma']
        # fine grid search around the coarse optimum, in quarter-octave steps
        tmpc = [-1.75, -1.5, -1.25, -1, -0.75, -0.5, -0.25, 0.0,
                0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
        cost = []
        gamma = []
        for i in tmpc:
            cost.append(bestc * np.power(2.0, i))
            gamma.append(bestg * np.power(2.0, i))
        parameters = {'C': cost, 'gamma': gamma}
        svm = SVC(verbose=0, kernel='rbf')
        clf = GridSearchCV(svm, parameters, cv=3)
        clf.fit(self.trainx, self.trainy)
        # print(clf.best_params_)
        p = clf.best_estimator_
        return p

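# Usage sketch for svm_rbf (illustrative; not in the original file). It expects
# flattened per-pixel spectra, i.e. X of shape (n_samples, n_bands) with
# integer labels y, as scikit-learn requires:
#
#     clf = svm_rbf(X_train, y_train).train()
#     y_pred = clf.predict(X_test)
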
--------------------------------------------------------------------------------